the-algorithm/timelines/data_processing/ml_util/aggregation_framework/Utils.scala

123 lines
4.5 KiB
Scala

package com.twitter.timelines.data_processing.ml_util.aggregation_framework
import com.twitter.algebird.ScMapMonoid
import com.twitter.algebird.Semigroup
import com.twitter.ml.api._
import com.twitter.ml.api.constant.SharedFeatures
import com.twitter.ml.api.DataRecord
import com.twitter.ml.api.Feature
import com.twitter.ml.api.FeatureType
import com.twitter.ml.api.util.SRichDataRecord
import java.lang.{Long => JLong}
import scala.collection.{Map => ScMap}
/**
 * Helpers for the aggregation framework: extracting numeric keys from [[DataRecord]]s and
 * merging keyed records and keyed record maps.
 */
object Utils {

  /** Merger used for every [[DataRecord]] merge performed by this object. */
  val dataRecordMerger: DataRecordMerger = new DataRecordMerger

  /** Returns a newly allocated, empty [[DataRecord]] on every call (this is not a singleton). */
  def EmptyDataRecord: DataRecord = new DataRecord()

  /**
   * Monoid over maps keyed by Long whose values are combined with [[dataRecordMerger]] when
   * both maps contain the same key.
   *
   * NOTE(review): `plus` returns its left operand after calling `merge`, so the merge is
   * presumably applied in place to `x` -- confirm against DataRecordMerger.
   */
  private val keyedDataRecordMapMonoid = {
    val dataRecordMergerSg = new Semigroup[DataRecord] {
      override def plus(x: DataRecord, y: DataRecord): DataRecord = {
        dataRecordMerger.merge(x, y)
        x
      }
    }
    new ScMapMonoid[Long, DataRecord]()(dataRecordMergerSg)
  }

  /** Reads the value of a Long-valued `feature` from `record` and unboxes it. */
  def keyFromLong(record: DataRecord, feature: Feature[JLong]): Long =
    SRichDataRecord(record).getFeatureValue(feature).longValue

  /**
   * Parses the string value of `feature` in `record` as a Long.
   *
   * @return the parsed value, or 0L when parsing throws a NumberFormatException
   */
  def keyFromString(record: DataRecord, feature: Feature[String]): Long =
    try {
      SRichDataRecord(record).getFeatureValue(feature).toLong
    } catch {
      case _: NumberFormatException => 0L
    }

  /** Uses the `hashCode` of the string value of `feature` in `record`, widened to Long. */
  def keyFromHash(record: DataRecord, feature: Feature[String]): Long =
    SRichDataRecord(record).getFeatureValue(feature).hashCode.toLong

  /**
   * Derives a Long key from `secondaryKey` in `record`, dispatching on the feature type.
   *
   * @param record       record to read the feature value from
   * @param secondaryKey feature whose value becomes the key; STRING and DISCRETE are supported
   * @param shouldHash   for STRING features only: hash the string instead of parsing it
   * @return the extracted key
   * @throws IllegalArgumentException if the feature type is neither STRING nor DISCRETE
   */
  def extractSecondary[T](
    record: DataRecord,
    secondaryKey: Feature[T],
    shouldHash: Boolean = false
  ): Long = secondaryKey.getFeatureType match {
    case FeatureType.STRING =>
      if (shouldHash) keyFromHash(record, secondaryKey.asInstanceOf[Feature[String]])
      else keyFromString(record, secondaryKey.asInstanceOf[Feature[String]])
    case FeatureType.DISCRETE => keyFromLong(record, secondaryKey.asInstanceOf[Feature[JLong]])
    case f => throw new IllegalArgumentException(s"Feature type $f is not supported.")
  }

  /**
   * Merges all defined keyed records into a single keyed record.
   *
   * The first defined record is used as the merge accumulator and is the record returned
   * inside the result.
   *
   * @param args optional keyed records; undefined entries are ignored
   * @return None when no argument is defined, otherwise the merged keyed record
   * @throws IllegalArgumentException if the defined records do not share one aggregate type
   */
  def mergeKeyedRecordOpts(args: Option[KeyedRecord]*): Option[KeyedRecord] = {
    val keyedRecords = args.flatten
    if (keyedRecords.isEmpty) {
      None
    } else {
      val keys = keyedRecords.map(_.aggregateType)
      require(keys.toSet.size == 1, "All merged records must have the same aggregate key.")
      val mergedRecord = mergeRecords(keyedRecords.map(_.record): _*)
      Some(KeyedRecord(keys.head, mergedRecord))
    }
  }

  /**
   * Merges every record in `args` into the first one and returns that first record; returns
   * a new empty record when `args` is empty.
   */
  private def mergeRecords(args: DataRecord*): DataRecord =
    if (args.isEmpty) EmptyDataRecord
    else {
      // Folding from a new DataRecord would also work for both cases; starting from args.head
      // avoids allocating one. The accumulator is therefore the caller's first record.
      args.tail.foldLeft(args.head) { (merged, record) =>
        dataRecordMerger.merge(merged, record)
        merged
      }
    }

  /**
   * Merges two optional keyed record maps, downsampling keys when their union exceeds
   * `maxSize`.
   *
   * @param opt1    first optional map
   * @param opt2    second optional map
   * @param maxSize target upper bound on the number of distinct keys in the result
   * @return None when both inputs are empty, otherwise the merged keyed record map
   * @throws IllegalArgumentException if the inputs have different aggregate types or
   *                                  `maxSize` is negative
   */
  def mergeKeyedRecordMapOpts(
    opt1: Option[KeyedRecordMap],
    opt2: Option[KeyedRecordMap],
    maxSize: Int = Int.MaxValue
  ): Option[KeyedRecordMap] = {
    if (opt1.isEmpty && opt2.isEmpty) {
      None
    } else {
      val keys = Seq(opt1, opt2).flatten.map(_.aggregateType)
      require(keys.toSet.size == 1, "All merged records must have the same aggregate key.")
      val mergedRecordMap = mergeMapOpts(opt1.map(_.recordMap), opt2.map(_.recordMap), maxSize)
      Some(KeyedRecordMap(keys.head, mergedRecordMap))
    }
  }

  /**
   * Merges two optional record maps with [[keyedDataRecordMapMonoid]].
   *
   * When the union of the key sets is larger than `maxSize`, both maps are first downsampled
   * at rate `maxSize / unionSize`. Each key is kept or dropped independently, so the merged
   * size is only expected, not guaranteed, to stay within `maxSize`.
   */
  private def mergeMapOpts(
    opt1: Option[ScMap[Long, DataRecord]],
    opt2: Option[ScMap[Long, DataRecord]],
    maxSize: Int = Int.MaxValue
  ): ScMap[Long, DataRecord] = {
    require(maxSize >= 0)
    val keySet = opt1.map(_.keySet).getOrElse(Set.empty) ++ opt2.map(_.keySet).getOrElse(Set.empty)
    val totalSize = keySet.size
    val rate = if (totalSize <= maxSize) 1.0 else maxSize.toDouble / totalSize
    val prunedOpt1 = opt1.map(downsample(_, rate))
    val prunedOpt2 = opt2.map(downsample(_, rate))
    Seq(prunedOpt1, prunedOpt2).flatten
      .foldLeft(keyedDataRecordMapMonoid.zero)(keyedDataRecordMapMonoid.plus)
  }

  /**
   * Keeps each entry of `m` with probability `samplingRate`, decided deterministically from
   * the key and the rate.
   *
   * @param m            map to downsample
   * @param samplingRate fraction of keys to keep; values >= 1.0 return `m` unchanged and
   *                     values <= 0 return an empty map
   * @return the retained entries
   */
  def downsample[K, T](m: ScMap[K, T], samplingRate: Double): ScMap[K, T] = {
    if (samplingRate >= 1.0) {
      m
    } else if (samplingRate <= 0) {
      Map.empty
    } else {
      m.filter {
        case (key, _) =>
          // It is important that the same key with the same sampling rate be deterministically
          // selected or rejected. Otherwise, mergeMapOpts will choose different keys for the
          // two input maps and their union will be larger than the limit we want.
          //
          // A fresh generator is seeded per key instead of reseeding a shared one: setSeed
          // followed by nextDouble on a shared instance is not atomic, so concurrent callers
          // could interleave and break this determinism, and reseeding the global
          // scala.util.Random would also disturb unrelated users of it. A newly seeded
          // generator yields the same first draw as reseeding did.
          val seed = (key.hashCode, samplingRate.hashCode).hashCode
          new scala.util.Random(seed.toLong).nextDouble() < samplingRate
      }
    }
  }
}