the-algorithm/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateInjections.scala

47 lines
1.9 KiB
Scala

package com.twitter.timelines.data_processing.ml_util.aggregation_framework
import com.twitter.dal.personal_data.thriftscala.PersonalDataType
import com.twitter.ml.api.DataRecord
import com.twitter.ml.api.Feature
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.Batched
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.JavaCompactThrift
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.genericInjection
import com.twitter.summingbird.batch.BatchID
import scala.collection.JavaConverters._
object OfflineAggregateInjections {

  /**
   * Key/value injection for offline aggregates that carries no personal-data-type
   * annotations.
   *
   * Keys are handled by [[AggregationKeyInjection]]. Values, a `(BatchID, DataRecord)`
   * pair, are handled by a `Batched` wrapper around `JavaCompactThrift[DataRecord]`
   * (presumably Thrift compact-protocol encoding — confirm against KeyValInjection).
   * Unlike [[getInjection]], neither side is given a PDT set.
   */
  val offlineDataRecordAggregateInjection: KeyValInjection[AggregationKey, (BatchID, DataRecord)] =
    KeyValInjection(
      genericInjection(AggregationKeyInjection),
      Batched(JavaCompactThrift[DataRecord])
    )

  /**
   * Collects the personal data types (PDTs) declared on the features of a set of groups.
   *
   * The groups are de-duplicated first. For each group, `featureExtractor` supplies its
   * features. Each feature's Java-side PDT values are converted with
   * `PersonalDataType.get`. Values for which that lookup yields nothing are skipped.
   *
   * @param aggregateGroups  groups whose features should be inspected
   * @param featureExtractor selects the features of a single group
   * @tparam T               the group type
   * @return the union of all converted PDTs, or `None` if that union is empty
   */
  private[aggregation_framework] def getPdts[T](
    aggregateGroups: Iterable[T],
    featureExtractor: T => Iterable[Feature[_]]
  ): Option[Set[PersonalDataType]] = {
    val collected: Set[PersonalDataType] =
      aggregateGroups.toSet[T].flatMap { group =>
        featureExtractor(group).flatMap { feature =>
          // NOTE(review): getPersonalDataTypes appears to return an optional Java set;
          // asSet() presumably yields zero or one inner set, so a feature without
          // declared PDTs contributes nothing — confirm against Feature.
          feature.getPersonalDataTypes.asSet().asScala.flatMap { javaPdts =>
            javaPdts.asScala.flatMap(javaPdt => PersonalDataType.get(javaPdt.getValue))
          }
        }
      }
    // Report "no PDTs" as None rather than as Some(empty set).
    Some(collected).filter(_.nonEmpty)
  }

  /**
   * Builds the key/value injection for `aggregateGroups`, annotated with their PDTs.
   *
   * Key PDTs come from every group's `allOutputKeys`. Value PDTs come from every group's
   * `allOutputFeatures`. Each set (or `None` when empty) is handed to `genericInjection`
   * alongside the same key/value injections used by
   * [[offlineDataRecordAggregateInjection]].
   *
   * @param aggregateGroups the typed aggregate groups whose outputs are being written
   * @return an injection from [[AggregationKey]] to `(BatchID, DataRecord)`
   */
  def getInjection(
    aggregateGroups: Set[TypedAggregateGroup[_]]
  ): KeyValInjection[AggregationKey, (BatchID, DataRecord)] = {
    val keyPersonalDataTypes = getPdts[TypedAggregateGroup[_]](
      aggregateGroups = aggregateGroups,
      featureExtractor = _.allOutputKeys
    )
    val valuePersonalDataTypes = getPdts[TypedAggregateGroup[_]](
      aggregateGroups = aggregateGroups,
      featureExtractor = _.allOutputFeatures
    )
    KeyValInjection(
      genericInjection(AggregationKeyInjection, keyPersonalDataTypes),
      genericInjection(Batched(JavaCompactThrift[DataRecord]), valuePersonalDataTypes)
    )
  }
}