the-algorithm/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2Adapter.scala

package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion

import com.twitter.algebird.DecayedValue
import com.twitter.algebird.DecayedValueMonoid
import com.twitter.algebird.Monoid
import com.twitter.ml.api._
import com.twitter.ml.api.constant.SharedFeatures
import com.twitter.ml.api.util.FDsl._
import com.twitter.ml.api.util.SRichDataRecord
import com.twitter.summingbird.batch.BatchID
import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey
import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup
import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregateFeature
import com.twitter.util.Duration
import java.lang.{Double => JDouble}
import java.lang.{Long => JLong}
import scala.collection.JavaConverters._
import scala.collection.mutable
import java.{util => ju}

object AggregatesV2Adapter {
  type AggregatesV2Tuple = (AggregationKey, (BatchID, DataRecord))

  val Epsilon: Double = 1e-6
  val decayedValueMonoid: Monoid[DecayedValue] = DecayedValueMonoid(Epsilon)
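
  /*
   * Background on the decay model: Algebird's DecayedValue applies
   * exponential decay, so a value observed dt milliseconds before the target
   * time is scaled by 2^(-dt / halfLife). Epsilon acts as a noise floor:
   * once a decayed value falls below it, the monoid treats the value as zero
   * instead of carrying a negligible residue forward.
   */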

  /*
   * Decays the storedValue from timestamp -> sourceVersion
   *
   * @param storedValue value read from the aggregates v2 output store
   * @param timestamp timestamp corresponding to the stored value
   * @param sourceVersion timestamp of version to decay all values to uniformly
   * @param halfLife half life duration to use for applying decay
   *
   * By applying this function, the feature values for all users are decayed
   * to sourceVersion. This is important to ensure that a user whose aggregates
   * were updated long in the past does not have an artificially inflated count
   * compared to one whose aggregates were updated (and hence decayed) more recently.
   */
  def decayValueToSourceVersion(
    storedValue: Double,
    timestamp: Long,
    sourceVersion: Long,
    halfLife: Duration
  ): Double =
    if (timestamp > sourceVersion) {
      storedValue
    } else {
      decayedValueMonoid
        .plus(
          DecayedValue.build(storedValue, timestamp, halfLife.inMilliseconds),
          DecayedValue.build(0, sourceVersion, halfLife.inMilliseconds)
        )
        .value
    }
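
  /*
   * A minimal sketch of the decay in action (values are illustrative, not
   * from any production config): a stored value decayed across exactly one
   * half life comes back roughly halved.
   * {{{
   *   val halfLife = Duration.fromMilliseconds(24L * 3600 * 1000) // one day
   *   decayValueToSourceVersion(
   *     storedValue = 1.0,
   *     timestamp = 0L,
   *     sourceVersion = halfLife.inMilliseconds,
   *     halfLife = halfLife
   *   ) // ~0.5
   * }}}
   */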

  /*
   * Decays all the aggregate features occurring in the ''inputRecord''
   * to a given timestamp, and mutates the ''outputRecord'' accordingly.
   * Note that inputRecord and outputRecord can be the same record if you
   * want to mutate the input in place; the function handles this correctly.
   *
   * @param inputRecord Input record to get features from
   * @param aggregateFeaturesAndHalfLives Aggregate features to decay, paired with their half lives
   * @param decayTo Timestamp to decay to
   * @param trimThreshold Drop features whose decayed magnitude falls below this threshold
   * @param outputRecord Output record to mutate
   * @return the mutated outputRecord
   */
  def mutateDecay(
    inputRecord: DataRecord,
    aggregateFeaturesAndHalfLives: List[(Feature[_], Duration)],
    decayTo: Long,
    trimThreshold: Double,
    outputRecord: DataRecord
  ): DataRecord = {
    val timestamp = inputRecord.getFeatureValue(SharedFeatures.TIMESTAMP).toLong

    aggregateFeaturesAndHalfLives.foreach {
      case (aggregateFeature: Feature[_], halfLife: Duration) =>
        if (aggregateFeature.getFeatureType == FeatureType.CONTINUOUS) {
          val continuousFeature = aggregateFeature.asInstanceOf[Feature[JDouble]]
          if (inputRecord.hasFeature(continuousFeature)) {
            val storedValue = inputRecord.getFeatureValue(continuousFeature).toDouble
            val decayedValue = decayValueToSourceVersion(storedValue, timestamp, decayTo, halfLife)
            if (math.abs(decayedValue) > trimThreshold) {
              outputRecord.setFeatureValue(continuousFeature, decayedValue)
            }
          }
        }
    }

    /* Update timestamp to version (now that we've decayed all aggregates) */
    outputRecord.setFeatureValue(SharedFeatures.TIMESTAMP, decayTo)
    outputRecord
  }
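
  /*
   * A usage sketch (the record and half-life list are assumed to be in
   * scope, not defined here): decay every continuous aggregate feature of a
   * record in place, up to the current time, trimming decayed values whose
   * magnitude drops below 0.001.
   * {{{
   *   val decayed = mutateDecay(
   *     inputRecord = record,
   *     aggregateFeaturesAndHalfLives = halfLives,
   *     decayTo = System.currentTimeMillis,
   *     trimThreshold = 0.001,
   *     outputRecord = record // in-place mutation is supported
   *   )
   * }}}
   */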
}

class AggregatesV2Adapter(
  aggregates: Set[TypedAggregateGroup[_]],
  sourceVersion: Long,
  trimThreshold: Double)
    extends IRecordOneToManyAdapter[AggregatesV2Adapter.AggregatesV2Tuple] {
  import AggregatesV2Adapter._

  val keyFeatures: List[Feature[_]] = aggregates.flatMap(_.allOutputKeys).toList
  val aggregateFeatures: List[Feature[_]] = aggregates.flatMap(_.allOutputFeatures).toList
  val timestampFeatures: List[Feature[JLong]] = List(SharedFeatures.TIMESTAMP)
  val allFeatures: List[Feature[_]] = keyFeatures ++ aggregateFeatures ++ timestampFeatures
  val featureContext: FeatureContext = new FeatureContext(allFeatures.asJava)

  override def getFeatureContext: FeatureContext = featureContext
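
  // Each aggregate output feature carries its own half life, recovered here
  // via AggregateFeature.parseHalfLife; the exact encoding of the half life
  // on the feature is an implementation detail of AggregateFeature.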
  val aggregateFeaturesAndHalfLives: List[(Feature[_], Duration)] =
    aggregateFeatures.map { aggregateFeature: Feature[_] =>
      val halfLife = AggregateFeature.parseHalfLife(aggregateFeature)
      (aggregateFeature, halfLife)
    }

  override def adaptToDataRecords(tuple: AggregatesV2Tuple): ju.List[DataRecord] = tuple match {
    case (key: AggregationKey, (batchId: BatchID, record: DataRecord)) => {
      val resultRecord = new SRichDataRecord(new DataRecord, featureContext)

      /* Clear out any continuous features in the result record that are not
       * aggregate output features. */
      val itr = resultRecord.continuousFeaturesIterator()
      val featuresToClear = mutable.Set[Feature[JDouble]]()
      while (itr.moveNext()) {
        val nextFeature = itr.getFeature
        if (!aggregateFeatures.contains(nextFeature)) {
          featuresToClear += nextFeature
        }
      }
      featuresToClear.foreach(resultRecord.clearFeature)

      /* Copy the discrete and string key features from the aggregation key
       * into the result record. */
      keyFeatures.foreach { keyFeature: Feature[_] =>
        if (keyFeature.getFeatureType == FeatureType.DISCRETE) {
          resultRecord.setFeatureValue(
            keyFeature.asInstanceOf[Feature[JLong]],
            key.discreteFeaturesById(keyFeature.getDenseFeatureId)
          )
        } else if (keyFeature.getFeatureType == FeatureType.STRING) {
          resultRecord.setFeatureValue(
            keyFeature.asInstanceOf[Feature[String]],
            key.textFeaturesById(keyFeature.getDenseFeatureId)
          )
        }
      }

      /* A record without a timestamp cannot be decayed, so it is dropped. */
      if (record.hasFeature(SharedFeatures.TIMESTAMP)) {
        mutateDecay(
          record,
          aggregateFeaturesAndHalfLives,
          sourceVersion,
          trimThreshold,
          resultRecord
        )
        List(resultRecord.getRecord).asJava
      } else {
        List.empty[DataRecord].asJava
      }
    }
  }
}
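
/*
 * A wiring sketch (the aggregate groups and store parameters below are
 * assumed, not taken from any production job): adapt one
 * (key, (batch, record)) tuple read from the aggregates v2 store into
 * decayed data records.
 * {{{
 *   val adapter = new AggregatesV2Adapter(
 *     aggregates = myAggregateGroups, // Set[TypedAggregateGroup[_]]
 *     sourceVersion = storeVersionMillis, // timestamp to decay all values to
 *     trimThreshold = 0.0
 *   )
 *   val records: ju.List[DataRecord] = adapter.adaptToDataRecords(tuple)
 * }}}
 */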