mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-06-30 06:56:08 +02:00
81 lines
3.3 KiB
Scala
81 lines
3.3 KiB
Scala
|
package com.twitter.home_mixer.functional_component.feature_hydrator
|
||
|
|
||
|
import com.twitter.finagle.stats.StatsReceiver
|
||
|
import com.twitter.home_mixer.functional_component.feature_hydrator.adapters.twhin_embeddings.TwhinUserEngagementEmbeddingsAdapter
|
||
|
import com.twitter.home_mixer.param.HomeMixerInjectionNames.TwhinUserEngagementFeatureRepository
|
||
|
import com.twitter.ml.api.DataRecord
|
||
|
import com.twitter.ml.api.RichDataRecord
|
||
|
import com.twitter.ml.api.util.ScalaToJavaDataRecordConversions
|
||
|
import com.twitter.ml.api.{thriftscala => ml}
|
||
|
import com.twitter.product_mixer.core.feature.datarecord.DataRecordInAFeature
|
||
|
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
|
||
|
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
|
||
|
import com.twitter.product_mixer.core.feature.Feature
|
||
|
import com.twitter.product_mixer.core.feature.FeatureWithDefaultOnFailure
|
||
|
import com.twitter.product_mixer.core.functional_component.feature_hydrator.QueryFeatureHydrator
|
||
|
import com.twitter.product_mixer.core.model.common.identifier.FeatureHydratorIdentifier
|
||
|
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||
|
import com.twitter.servo.repository.KeyValueRepository
|
||
|
import com.twitter.stitch.Stitch
|
||
|
import com.twitter.util.Return
|
||
|
import com.twitter.util.Throw
|
||
|
import javax.inject.Inject
|
||
|
import javax.inject.Named
|
||
|
import javax.inject.Singleton
|
||
|
|
||
|
/**
 * Query-level feature whose value is a `DataRecord` holding the TwHIN user engagement
 * embedding.
 *
 * Mixes in `FeatureWithDefaultOnFailure` and declares an empty `DataRecord` as the default
 * value.
 */
object TwhinUserEngagementFeature
    extends DataRecordInAFeature[PipelineQuery]
    with FeatureWithDefaultOnFailure[PipelineQuery, DataRecord] {

  // Kept as a `def`: every access builds a fresh record rather than sharing one instance.
  override def defaultValue: DataRecord = new DataRecord
}
|
||
|
|
||
|
/**
 * Query feature hydrator that looks up the requesting user's TwHIN engagement embedding in a
 * key-value repository and publishes it as a `DataRecord` under `TwhinUserEngagementFeature`.
 *
 * @param client        repository queried with a sequence of user ids; each id resolves to an
 *                      optional `ml.FloatTensor`
 * @param statsReceiver parent stats receiver; this hydrator's counters are scoped beneath it
 *                      using the class's simple name
 */
@Singleton
class TwhinUserEngagementQueryFeatureHydrator @Inject() (
  @Named(TwhinUserEngagementFeatureRepository)
  client: KeyValueRepository[Seq[Long], Long, ml.FloatTensor],
  statsReceiver: StatsReceiver)
    extends QueryFeatureHydrator[PipelineQuery] {

  override val identifier: FeatureHydratorIdentifier =
    FeatureHydratorIdentifier("TwhinUserEngagement")

  override val features: Set[Feature[_, _]] = Set(TwhinUserEngagementFeature)

  private val scopedStatsReceiver = statsReceiver.scope(getClass.getSimpleName)

  // Lookup succeeded and returned a tensor with at least one float.
  private val keyFoundCounter = scopedStatsReceiver.counter("key/found")

  // Lookup succeeded but returned no tensor, or a tensor with no floats.
  private val keyLossCounter = scopedStatsReceiver.counter("key/loss")

  // Lookup for the key resolved to a Throw.
  private val keyFailureCounter = scopedStatsReceiver.counter("key/failure")

  /**
   * Fetches the embedding for the query's required user id and wraps it in a `DataRecord`.
   *
   * When no tensor is available (missing key or failed lookup) the resulting record is left
   * without the embedding feature set.
   *
   * @param query pipeline query; its required user id is the repository key
   * @return a feature map containing only `TwhinUserEngagementFeature`
   */
  override def hydrate(query: PipelineQuery): Stitch[FeatureMap] = {
    val userId = query.getRequiredUserId

    Stitch.callFuture {
      client(Seq(userId)).map { results =>
        // The per-key result is a com.twitter.util.Try, which is sealed over Return and Throw,
        // so these two cases are exhaustive and no catch-all arm is needed.
        val embedding: Option[ml.FloatTensor] = results(userId) match {
          case Return(value) =>
            if (value.exists(_.floats.nonEmpty)) keyFoundCounter.incr()
            else keyLossCounter.incr()
            // The value is passed through unchanged, even when it is an empty tensor.
            value
          case Throw(_) =>
            keyFailureCounter.incr()
            None
        }

        val dataRecord =
          new RichDataRecord(new DataRecord, TwhinUserEngagementEmbeddingsAdapter.getFeatureContext)

        embedding.foreach { floatTensor =>
          dataRecord.setFeatureValue(
            TwhinUserEngagementEmbeddingsAdapter.twhinEmbeddingsFeature,
            ScalaToJavaDataRecordConversions.scalaTensor2Java(
              ml.GeneralTensor.FloatTensor(floatTensor))
          )
        }

        FeatureMapBuilder()
          .add(TwhinUserEngagementFeature, dataRecord.getRecord)
          .build()
      }
    }
  }
}
|