166 lines
6.3 KiB
Scala
166 lines
6.3 KiB
Scala
package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion
|
|
|
|
import com.twitter.algebird.DecayedValue
|
|
import com.twitter.algebird.DecayedValueMonoid
|
|
import com.twitter.algebird.Monoid
|
|
import com.twitter.ml.api._
|
|
import com.twitter.ml.api.constant.SharedFeatures
|
|
import com.twitter.ml.api.util.FDsl._
|
|
import com.twitter.ml.api.util.SRichDataRecord
|
|
import com.twitter.summingbird.batch.BatchID
|
|
import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey
|
|
import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup
|
|
import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregateFeature
|
|
import com.twitter.util.Duration
|
|
import java.lang.{Double => JDouble}
|
|
import java.lang.{Long => JLong}
|
|
import scala.collection.JavaConverters._
|
|
import scala.collection.mutable
|
|
import java.{util => ju}
|
|
|
|
object AggregatesV2Adapter {
  type AggregatesV2Tuple = (AggregationKey, (BatchID, DataRecord))

  // Epsilon handed to DecayedValueMonoid below.
  val Epsilon: Double = 1e-6
  val decayedValueMonoid: Monoid[DecayedValue] = DecayedValueMonoid(Epsilon)

  /**
   * Decays `storedValue` from `timestamp` to `sourceVersion`.
   *
   * By applying this function, the feature values for all users are decayed to sourceVersion.
   * This is important to ensure that a user whose aggregates were updated long in the past does
   * not have an artificially inflated count compared to one whose aggregates were updated (and
   * hence decayed) more recently.
   *
   * NOTE(review): `timestamp` and `sourceVersion` are presumably epoch milliseconds, to match
   * `halfLife.inMilliseconds` — TODO confirm.
   *
   * @param storedValue   value read from the aggregates v2 output store
   * @param timestamp     timestamp corresponding to the stored value
   * @param sourceVersion timestamp of the version to decay all values to uniformly
   * @param halfLife      half-life duration to use for applying decay
   * @return the value decayed to `sourceVersion`, or `storedValue` unchanged when
   *         `timestamp` is later than `sourceVersion`
   */
  def decayValueToSourceVersion(
    storedValue: Double,
    timestamp: Long,
    sourceVersion: Long,
    halfLife: Duration
  ): Double =
    if (timestamp > sourceVersion) {
      storedValue
    } else {
      // Summing with a zero-valued DecayedValue stamped at sourceVersion makes the monoid
      // express the stored value at sourceVersion (DecayedValueMonoid.plus scales both operands
      // to the later of the two times).
      decayedValueMonoid
        .plus(
          DecayedValue.build(storedValue, timestamp, halfLife.inMilliseconds),
          DecayedValue.build(0, sourceVersion, halfLife.inMilliseconds)
        )
        .value
    }

  /**
   * Decays the aggregate features listed in `aggregateFeaturesAndHalfLives` that occur in
   * `inputRecord` to the timestamp `decayTo`, and mutates `outputRecord` accordingly.
   *
   * `inputRecord` and `outputRecord` may be the same record to mutate the input in place. Only
   * CONTINUOUS aggregate features are decayed; features of other types are ignored. A feature
   * whose decayed absolute value is not greater than `trimThreshold` is removed from
   * `outputRecord`. Without the removal, an in-place call would keep the stale, undecayed value
   * next to the updated timestamp. Finally, SharedFeatures.TIMESTAMP on `outputRecord` is set
   * to `decayTo`.
   *
   * @param inputRecord                   record to read aggregate features from; its
   *                                      SharedFeatures.TIMESTAMP is read unconditionally, so it
   *                                      is expected to be present
   * @param aggregateFeaturesAndHalfLives aggregate features to decay, each paired with the
   *                                      half-life to decay it with
   * @param decayTo                       timestamp to decay to
   * @param trimThreshold                 drop features whose decayed absolute value is at or
   *                                      below this threshold
   * @param outputRecord                  record to mutate
   * @return the mutated outputRecord
   */
  def mutateDecay(
    inputRecord: DataRecord,
    aggregateFeaturesAndHalfLives: List[(Feature[_], Duration)],
    decayTo: Long,
    trimThreshold: Double,
    outputRecord: DataRecord
  ): DataRecord = {
    val timestamp = inputRecord.getFeatureValue(SharedFeatures.TIMESTAMP).toLong

    aggregateFeaturesAndHalfLives.foreach {
      case (aggregateFeature, halfLife) =>
        if (aggregateFeature.getFeatureType() == FeatureType.CONTINUOUS) {
          val continuousFeature = aggregateFeature.asInstanceOf[Feature[JDouble]]
          if (inputRecord.hasFeature(continuousFeature)) {
            val storedValue = inputRecord.getFeatureValue(continuousFeature).toDouble
            val decayedValue = decayValueToSourceVersion(storedValue, timestamp, decayTo, halfLife)
            if (math.abs(decayedValue) > trimThreshold) {
              outputRecord.setFeatureValue(continuousFeature, decayedValue)
            } else {
              // Trimmed: make sure no (possibly undecayed) value survives in the output.
              outputRecord.clearFeature(continuousFeature)
            }
          }
        }
    }

    // Update timestamp to version (now that we've decayed all aggregates).
    outputRecord.setFeatureValue(SharedFeatures.TIMESTAMP, decayTo)

    outputRecord
  }
}
|
|
|
|
/**
 * Adapts aggregates v2 store entries `(key, (batchId, record))` into DataRecords whose aggregate
 * features are decayed uniformly to `sourceVersion`.
 *
 * Each output record starts empty. It receives the DISCRETE and STRING key features from the
 * AggregationKey and the decayed CONTINUOUS aggregate features that survive trimming.
 * SharedFeatures.TIMESTAMP is set to `sourceVersion`. Input records without
 * SharedFeatures.TIMESTAMP produce no output.
 *
 * @param aggregates    aggregate groups whose output keys and output features make up the
 *                      feature context
 * @param sourceVersion timestamp to decay all aggregate values to
 * @param trimThreshold decayed values whose absolute value is at or below this are dropped
 */
class AggregatesV2Adapter(
  aggregates: Set[TypedAggregateGroup[_]],
  sourceVersion: Long,
  trimThreshold: Double)
    extends IRecordOneToManyAdapter[AggregatesV2Adapter.AggregatesV2Tuple] {

  import AggregatesV2Adapter._

  val keyFeatures: List[Feature[_]] = aggregates.flatMap(_.allOutputKeys).toList
  val aggregateFeatures: List[Feature[_]] = aggregates.flatMap(_.allOutputFeatures).toList
  val timestampFeatures: List[Feature[JLong]] = List(SharedFeatures.TIMESTAMP)
  val allFeatures: List[Feature[_]] = keyFeatures ++ aggregateFeatures ++ timestampFeatures

  val featureContext: FeatureContext = new FeatureContext(allFeatures.asJava)

  override def getFeatureContext: FeatureContext = featureContext

  /** Each aggregate feature paired with its half-life, parsed once here rather than per record. */
  val aggregateFeaturesAndHalfLives: List[(Feature[_], Duration)] =
    aggregateFeatures.map { aggregateFeature =>
      (aggregateFeature, AggregateFeature.parseHalfLife(aggregateFeature))
    }

  /**
   * Converts one store entry into at most one decayed DataRecord.
   *
   * @param tuple aggregation key paired with the batch id (unused) and the stored record
   * @return a single-element list with the decayed record, or an empty list when the stored
   *         record has no SharedFeatures.TIMESTAMP
   */
  override def adaptToDataRecords(tuple: AggregatesV2Tuple): ju.List[DataRecord] = tuple match {
    case (key, (_, record)) =>
      // A brand-new record holds no features, so nothing has to be cleared before filling it.
      val resultRecord = new SRichDataRecord(new DataRecord, featureContext)
      setKeyFeatures(key, resultRecord)

      if (record.hasFeature(SharedFeatures.TIMESTAMP)) {
        mutateDecay(
          record,
          aggregateFeaturesAndHalfLives,
          sourceVersion,
          trimThreshold,
          resultRecord)
        List(resultRecord.getRecord).asJava
      } else {
        List.empty[DataRecord].asJava
      }
  }

  /**
   * Copies the DISCRETE and STRING key feature values from `key` into `resultRecord`.
   * Key features of any other type are skipped.
   * NOTE(review): assumes `key` holds a value for every key feature's dense id — the lookups
   * below do not guard against a missing id.
   */
  private def setKeyFeatures(key: AggregationKey, resultRecord: SRichDataRecord): Unit =
    keyFeatures.foreach { keyFeature =>
      keyFeature.getFeatureType match {
        case FeatureType.DISCRETE =>
          resultRecord.setFeatureValue(
            keyFeature.asInstanceOf[Feature[JLong]],
            key.discreteFeaturesById(keyFeature.getDenseFeatureId)
          )
        case FeatureType.STRING =>
          resultRecord.setFeatureValue(
            keyFeature.asInstanceOf[Feature[String]],
            key.textFeaturesById(keyFeature.getDenseFeatureId)
          )
        case _ =>
          ()
      }
    }
}
|