the-algorithm/product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/deepbird/TensorflowPredictionEngineScorer.scala
twitter-team ef4c5eb65e Twitter Recommendation Algorithm
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
2023-03-31 17:36:31 -05:00

59 lines
2.5 KiB
Scala

package com.twitter.product_mixer.component_library.scorer.deepbird
import com.twitter.cortex.deepbird.runtime.prediction_engine.TensorflowPredictionEngine
import com.twitter.cortex.deepbird.thriftjava.ModelSelector
import com.twitter.ml.prediction_service.BatchPredictionRequest
import com.twitter.ml.prediction_service.BatchPredictionResponse
import com.twitter.product_mixer.core.feature.datarecord.BaseDataRecordFeature
import com.twitter.product_mixer.core.feature.featuremap.datarecord.FeaturesScope
import com.twitter.product_mixer.core.model.common.UniversalNoun
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
import com.twitter.product_mixer.core.pipeline.PipelineQuery
import com.twitter.util.Future
/**
* Configurable Scorer that calls a TensorflowPredictionEngine.
* @param identifier Unique identifier for the scorer
* @param tensorflowPredictionEngine The TensorFlow Prediction Engine
* @param queryFeatures The Query Features to convert and pass to the deepbird model.
* @param candidateFeatures The Candidate Features to convert and pass to the deepbird model.
* @param resultFeatures The Candidate features returned by the model.
* @tparam Query Type of pipeline query.
* @tparam Candidate Type of candidates to score.
* @tparam QueryFeatures type of the query level features consumed by the scorer.
* @tparam CandidateFeatures type of the candidate level features consumed by the scorer.
* @tparam ResultFeatures type of the candidate level features returned by the scorer.
*/
class TensorflowPredictionEngineScorer[
Query <: PipelineQuery,
Candidate <: UniversalNoun[Any],
QueryFeatures <: BaseDataRecordFeature[Query, _],
CandidateFeatures <: BaseDataRecordFeature[Candidate, _],
ResultFeatures <: BaseDataRecordFeature[Candidate, _]
](
override val identifier: ScorerIdentifier,
tensorflowPredictionEngine: TensorflowPredictionEngine,
queryFeatures: FeaturesScope[QueryFeatures],
candidateFeatures: FeaturesScope[CandidateFeatures],
resultFeatures: Set[ResultFeatures])
extends BaseDeepbirdV2Scorer[
Query,
Candidate,
QueryFeatures,
CandidateFeatures,
ResultFeatures
](
identifier,
{ _: Query =>
None
},
queryFeatures,
candidateFeatures,
resultFeatures) {
override def getBatchPredictions(
request: BatchPredictionRequest,
modelSelector: ModelSelector
): Future[BatchPredictionResponse] = tensorflowPredictionEngine.getBatchPrediction(request)
}