47 lines
1.9 KiB
Scala
47 lines
1.9 KiB
Scala
package com.twitter.timelines.data_processing.ml_util.aggregation_framework
|
|
|
|
import com.twitter.dal.personal_data.thriftscala.PersonalDataType
|
|
import com.twitter.ml.api.DataRecord
|
|
import com.twitter.ml.api.Feature
|
|
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection
|
|
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.Batched
|
|
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.JavaCompactThrift
|
|
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.genericInjection
|
|
import com.twitter.summingbird.batch.BatchID
|
|
import scala.collection.JavaConverters._
|
|
|
|
object OfflineAggregateInjections {

  /**
   * Key/value injection for offline aggregates: keys are [[AggregationKey]]s and values are
   * `(BatchID, DataRecord)` pairs. No personal data types are supplied to either side here;
   * use [[getInjection]] when they should be derived from a set of aggregate groups.
   */
  val offlineDataRecordAggregateInjection: KeyValInjection[AggregationKey, (BatchID, DataRecord)] =
    KeyValInjection(
      genericInjection(AggregationKeyInjection),
      Batched(JavaCompactThrift[DataRecord])
    )

  /**
   * Gathers the distinct personal data types declared by the features of the given groups.
   *
   * @param aggregateGroups  groups to inspect; duplicates are collapsed via `toSet` first
   * @param featureExtractor selects which features of a group are examined
   * @tparam T               type of the aggregate groups
   * @return `Some` holding the non-empty set of personal data types found, or `None` when
   *         no feature yields any
   */
  private[aggregation_framework] def getPdts[T](
    aggregateGroups: Iterable[T],
    featureExtractor: T => Iterable[Feature[_]]
  ): Option[Set[PersonalDataType]] = {
    // Personal data types of a single feature, translated from the Java enum to the Scala one
    // by numeric value. Values for which `PersonalDataType.get` yields nothing are skipped.
    def scalaPdtsOf(feature: Feature[_]): Iterable[PersonalDataType] =
      feature.getPersonalDataTypes
        .asSet()
        .asScala
        .flatMap(javaPdts => javaPdts.asScala)
        .flatMap(javaPdt => PersonalDataType.get(javaPdt.getValue))

    val collected: Set[PersonalDataType] =
      aggregateGroups
        .toSet[T]
        .flatMap(group => featureExtractor(group).flatMap(scalaPdtsOf))

    // An empty result is reported as None rather than Some(Set.empty).
    Some(collected).filter(_.nonEmpty)
  }

  /**
   * Builds the key/value injection for the given aggregate groups, passing along the personal
   * data types found on their output keys (key side) and output features (value side).
   *
   * @param aggregateGroups aggregate groups whose output keys and features are inspected
   * @return injection from [[AggregationKey]] to `(BatchID, DataRecord)`
   */
  def getInjection(
    aggregateGroups: Set[TypedAggregateGroup[_]]
  ): KeyValInjection[AggregationKey, (BatchID, DataRecord)] = {
    val pdtsOnKeys =
      getPdts[TypedAggregateGroup[_]](aggregateGroups, group => group.allOutputKeys)
    val pdtsOnValues =
      getPdts[TypedAggregateGroup[_]](aggregateGroups, group => group.allOutputFeatures)

    KeyValInjection(
      genericInjection(AggregationKeyInjection, pdtsOnKeys),
      genericInjection(Batched(JavaCompactThrift[DataRecord]), pdtsOnValues)
    )
  }
}
|