117 lines
4.3 KiB
Scala
117 lines
4.3 KiB
Scala
package com.twitter.representationscorer.columns
|
|
|
|
import com.twitter.representationscorer.thriftscala.ListScoreId
|
|
import com.twitter.representationscorer.thriftscala.ListScoreResponse
|
|
import com.twitter.representationscorer.scorestore.ScoreStore
|
|
import com.twitter.representationscorer.thriftscala.ScoreResult
|
|
import com.twitter.simclusters_v2.common.SimClustersEmbeddingId.LongInternalId
|
|
import com.twitter.simclusters_v2.common.SimClustersEmbeddingId.LongSimClustersEmbeddingId
|
|
import com.twitter.simclusters_v2.thriftscala.Score
|
|
import com.twitter.simclusters_v2.thriftscala.ScoreId
|
|
import com.twitter.simclusters_v2.thriftscala.ScoreInternalId
|
|
import com.twitter.simclusters_v2.thriftscala.SimClustersEmbeddingId
|
|
import com.twitter.simclusters_v2.thriftscala.SimClustersEmbeddingPairScoreId
|
|
import com.twitter.stitch
|
|
import com.twitter.stitch.Stitch
|
|
import com.twitter.strato.catalog.OpMetadata
|
|
import com.twitter.strato.config.ContactInfo
|
|
import com.twitter.strato.config.Policy
|
|
import com.twitter.strato.data.Conv
|
|
import com.twitter.strato.data.Description.PlainText
|
|
import com.twitter.strato.data.Lifecycle
|
|
import com.twitter.strato.fed._
|
|
import com.twitter.strato.thrift.ScroogeConv
|
|
import com.twitter.util.Future
|
|
import com.twitter.util.Return
|
|
import com.twitter.util.Throw
|
|
import javax.inject.Inject
|
|
|
|
class ListScoreColumn @Inject() (scoreStore: ScoreStore)
|
|
extends StratoFed.Column("recommendations/representation_scorer/listScore")
|
|
with StratoFed.Fetch.Stitch {
|
|
|
|
override val policy: Policy = Common.rsxReadPolicy
|
|
|
|
override type Key = ListScoreId
|
|
override type View = Unit
|
|
override type Value = ListScoreResponse
|
|
|
|
override val keyConv: Conv[Key] = ScroogeConv.fromStruct[ListScoreId]
|
|
override val viewConv: Conv[View] = Conv.ofType
|
|
override val valueConv: Conv[Value] = ScroogeConv.fromStruct[ListScoreResponse]
|
|
|
|
override val contactInfo: ContactInfo = Info.contactInfo
|
|
|
|
override val metadata: OpMetadata = OpMetadata(
|
|
lifecycle = Some(Lifecycle.Production),
|
|
description = Some(
|
|
PlainText(
|
|
"Scoring for multiple candidate entities against a single target entity"
|
|
))
|
|
)
|
|
|
|
override def fetch(key: Key, view: View): Stitch[Result[Value]] = {
|
|
|
|
val target = SimClustersEmbeddingId(
|
|
embeddingType = key.targetEmbeddingType,
|
|
modelVersion = key.modelVersion,
|
|
internalId = key.targetId
|
|
)
|
|
val scoreIds = key.candidateIds.map { candidateId =>
|
|
val candidate = SimClustersEmbeddingId(
|
|
embeddingType = key.candidateEmbeddingType,
|
|
modelVersion = key.modelVersion,
|
|
internalId = candidateId
|
|
)
|
|
ScoreId(
|
|
algorithm = key.algorithm,
|
|
internalId = ScoreInternalId.SimClustersEmbeddingPairScoreId(
|
|
SimClustersEmbeddingPairScoreId(target, candidate)
|
|
)
|
|
)
|
|
}
|
|
|
|
Stitch
|
|
.callFuture {
|
|
val (keys: Iterable[ScoreId], vals: Iterable[Future[Option[Score]]]) =
|
|
scoreStore.uniformScoringStore.multiGet(scoreIds.toSet).unzip
|
|
val results: Future[Iterable[Option[Score]]] = Future.collectToTry(vals.toSeq) map {
|
|
tryOptVals =>
|
|
tryOptVals map {
|
|
case Return(Some(v)) => Some(v)
|
|
case Return(None) => None
|
|
case Throw(_) => None
|
|
}
|
|
}
|
|
val scoreMap: Future[Map[Long, Double]] = results.map { scores =>
|
|
keys
|
|
.zip(scores).collect {
|
|
case (
|
|
ScoreId(
|
|
_,
|
|
ScoreInternalId.SimClustersEmbeddingPairScoreId(
|
|
SimClustersEmbeddingPairScoreId(
|
|
_,
|
|
LongSimClustersEmbeddingId(candidateId)))),
|
|
Some(score)) =>
|
|
(candidateId, score.score)
|
|
}.toMap
|
|
}
|
|
scoreMap
|
|
}
|
|
.map { (scores: Map[Long, Double]) =>
|
|
val orderedScores = key.candidateIds.collect {
|
|
case LongInternalId(id) => ScoreResult(scores.get(id))
|
|
case _ =>
|
|
// This will return None scores for candidates which don't have Long ids, but that's fine:
|
|
// at the moment we're only scoring for Tweets
|
|
ScoreResult(None)
|
|
}
|
|
found(ListScoreResponse(orderedScores))
|
|
}
|
|
.handle {
|
|
case stitch.NotFound => missing
|
|
}
|
|
}
|
|
}
|