package com.twitter.home_mixer.functional_component.feature_hydrator import com.twitter.finagle.stats.StatsReceiver import com.twitter.home_mixer.functional_component.feature_hydrator.adapters.twhin_embeddings.TwhinUserFollowEmbeddingsAdapter import com.twitter.home_mixer.param.HomeMixerInjectionNames.TwhinUserFollowFeatureRepository import com.twitter.ml.api.DataRecord import com.twitter.ml.api.RichDataRecord import com.twitter.ml.api.util.ScalaToJavaDataRecordConversions import com.twitter.ml.api.{thriftscala => ml} import com.twitter.product_mixer.core.feature.datarecord.DataRecordInAFeature import com.twitter.product_mixer.core.feature.featuremap.FeatureMap import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder import com.twitter.product_mixer.core.feature.Feature import com.twitter.product_mixer.core.feature.FeatureWithDefaultOnFailure import com.twitter.product_mixer.core.functional_component.feature_hydrator.QueryFeatureHydrator import com.twitter.product_mixer.core.model.common.identifier.FeatureHydratorIdentifier import com.twitter.product_mixer.core.pipeline.PipelineQuery import com.twitter.servo.repository.KeyValueRepository import com.twitter.stitch.Stitch import com.twitter.util.Return import com.twitter.util.Throw import javax.inject.Inject import javax.inject.Named import javax.inject.Singleton object TwhinUserFollowFeature extends DataRecordInAFeature[PipelineQuery] with FeatureWithDefaultOnFailure[PipelineQuery, DataRecord] { override def defaultValue: DataRecord = new DataRecord() } @Singleton class TwhinUserFollowQueryFeatureHydrator @Inject() ( @Named(TwhinUserFollowFeatureRepository) client: KeyValueRepository[Seq[Long], Long, ml.FloatTensor], statsReceiver: StatsReceiver) extends QueryFeatureHydrator[PipelineQuery] { override val identifier: FeatureHydratorIdentifier = FeatureHydratorIdentifier("TwhinUserFollow") override val features: Set[Feature[_, _]] = Set(TwhinUserFollowFeature) private val scopedStatsReceiver = statsReceiver.scope(getClass.getSimpleName) private val keyFoundCounter = scopedStatsReceiver.counter("key/found") private val keyLossCounter = scopedStatsReceiver.counter("key/loss") private val keyFailureCounter = scopedStatsReceiver.counter("key/failure") override def hydrate(query: PipelineQuery): Stitch[FeatureMap] = { val userId = query.getRequiredUserId Stitch.callFuture( client(Seq(userId)).map { results => val embedding: Option[ml.FloatTensor] = results(userId) match { case Return(value) => if (value.exists(_.floats.nonEmpty)) keyFoundCounter.incr() else keyLossCounter.incr() value case Throw(_) => keyFailureCounter.incr() None case _ => None } val dataRecord = new RichDataRecord(new DataRecord, TwhinUserFollowEmbeddingsAdapter.getFeatureContext) embedding.foreach { floatTensor => dataRecord.setFeatureValue( TwhinUserFollowEmbeddingsAdapter.twhinEmbeddingsFeature, ScalaToJavaDataRecordConversions.scalaTensor2Java( ml.GeneralTensor .FloatTensor(floatTensor))) } FeatureMapBuilder() .add(TwhinUserFollowFeature, dataRecord.getRecord) .build() } ) } }