the-algorithm/cr-mixer/server/src/main/scala/com/twitter/cr_mixer/similarity_engine/EarlybirdTensorflowBasedSim...

139 lines
5.8 KiB
Scala

package com.twitter.cr_mixer.similarity_engine
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.search.earlybird.thriftscala.EarlybirdRequest
import com.twitter.search.earlybird.thriftscala.EarlybirdService
import com.twitter.search.earlybird.thriftscala.ThriftSearchQuery
import com.twitter.util.Time
import com.twitter.search.common.query.thriftjava.thriftscala.CollectorParams
import com.twitter.search.common.ranking.thriftscala.ThriftRankingParams
import com.twitter.search.common.ranking.thriftscala.ThriftScoringFunctionType
import com.twitter.search.earlybird.thriftscala.ThriftSearchRelevanceOptions
import javax.inject.Inject
import javax.inject.Singleton
import EarlybirdSimilarityEngineBase._
import com.twitter.cr_mixer.config.TimeoutConfig
import com.twitter.cr_mixer.similarity_engine.EarlybirdTensorflowBasedSimilarityEngine.EarlybirdTensorflowBasedSearchQuery
import com.twitter.cr_mixer.util.EarlybirdSearchUtil.EarlybirdClientId
import com.twitter.cr_mixer.util.EarlybirdSearchUtil.FacetsToFetch
import com.twitter.cr_mixer.util.EarlybirdSearchUtil.GetCollectorTerminationParams
import com.twitter.cr_mixer.util.EarlybirdSearchUtil.GetEarlybirdQuery
import com.twitter.cr_mixer.util.EarlybirdSearchUtil.MetadataOptions
import com.twitter.cr_mixer.util.EarlybirdSearchUtil.GetNamedDisjunctions
import com.twitter.search.earlybird.thriftscala.ThriftSearchRankingMode
import com.twitter.simclusters_v2.common.TweetId
import com.twitter.simclusters_v2.common.UserId
import com.twitter.util.Duration
/**
 * Similarity engine that converts an [[EarlybirdTensorflowBasedSearchQuery]] into an
 * [[EarlybirdRequest]] whose results are ranked by the Tensorflow-based relevance options
 * defined in the companion object.
 *
 * @param earlybirdSearchClient Earlybird thrift client (not referenced directly in this class
 *                              body; presumably consumed by the base trait -- confirm there)
 * @param timeoutConfig         source of the Earlybird server timeout applied to each request
 * @param stats                 parent stats receiver; scoped below by this class's simple name
 */
@Singleton
case class EarlybirdTensorflowBasedSimilarityEngine @Inject() (
  earlybirdSearchClient: EarlybirdService.MethodPerEndpoint,
  timeoutConfig: TimeoutConfig,
  stats: StatsReceiver)
    extends EarlybirdSimilarityEngineBase[EarlybirdTensorflowBasedSearchQuery] {
  import EarlybirdTensorflowBasedSimilarityEngine._

  override val statsReceiver: StatsReceiver = stats.scope(this.getClass.getSimpleName)

  /**
   * Builds the Earlybird request for the given query.
   *
   * @param query search parameters; its seed user ids are sent as the followed-user list
   * @return `None` when the query carries no seed user ids, otherwise the populated request
   */
  override def getEarlybirdRequest(
    query: EarlybirdTensorflowBasedSearchQuery
  ): Option[EarlybirdRequest] =
    if (query.seedUserIds.isEmpty) {
      // Without seed users there is nothing to search for, so no request is issued.
      None
    } else {
      // The same timeout bounds both the per-shard collector and the overall request.
      val serverTimeout = timeoutConfig.earlybirdServerTimeout
      val request = EarlybirdRequest(
        searchQuery = getThriftSearchQuery(query, serverTimeout),
        clientHost = None,
        clientRequestID = None,
        clientId = Some(EarlybirdClientId),
        clientRequestTimeMs = Some(Time.now.inMilliseconds),
        cachingParams = None,
        timeoutMs = serverTimeout.inMilliseconds.intValue(),
        facetRequest = None,
        termStatisticsRequest = None,
        debugMode = 0,
        debugOptions = None,
        searchSegmentId = None,
        returnStatusType = None,
        successfulResponseThreshold = None,
        querySource = None,
        getOlderResults = Some(false),
        followedUserIds = Some(query.seedUserIds),
        adjustedProtectedRequestParams = None,
        adjustedFullArchiveRequestParams = None,
        getProtectedTweetsOnly = Some(false),
        retokenizeSerializedQuery = None,
        skipVeryRecentTweets = true,
        experimentClusterToUse = None
      )
      Some(request)
    }
}
/**
 * Query model and thrift-building helpers for the Tensorflow-based Earlybird similarity engine.
 */
object EarlybirdTensorflowBasedSimilarityEngine {

  // Number of results each Earlybird shard returns to the search root.
  private val ResultsReturnedPerShard: Int = 1000

  // Value sent as `maxHitsToProcess` in the relevance options.
  private val RelevanceMaxHitsToProcess: Int = 500

  // Name of the Tensorflow model selected for scoring.
  private val RankingModelName: String = "timelines_rectweet_replica"

  /**
   * Parameters of a Tensorflow-ranked Earlybird search.
   *
   * @param searcherUserId              forwarded as the thrift query's `searcherId`
   * @param seedUserIds                 authors to restrict the search to (`fromUserIDFilter64`)
   * @param maxNumTweets                forwarded as the thrift query's `numResults`
   * @param beforeTweetIdExclusive      optional bound passed to the Earlybird query builder
   * @param afterTweetIdExclusive       optional bound passed to the Earlybird query builder
   * @param filterOutRetweetsAndReplies passed to the Earlybird query builder
   * @param useTensorflowRanking        NOTE(review): not read anywhere in this object -- confirm
   *                                    whether callers rely on it elsewhere
   * @param excludedTweetIds            passed to the query builder and to the named disjunctions
   * @param maxNumHitsPerShard          passed to the collector termination params
   */
  case class EarlybirdTensorflowBasedSearchQuery(
    searcherUserId: Option[UserId],
    seedUserIds: Seq[UserId],
    maxNumTweets: Int,
    beforeTweetIdExclusive: Option[TweetId],
    afterTweetIdExclusive: Option[TweetId],
    filterOutRetweetsAndReplies: Boolean,
    useTensorflowRanking: Boolean,
    excludedTweetIds: Set[TweetId],
    maxNumHitsPerShard: Int)
      extends EarlybirdSearchQuery

  /**
   * Assembles the relevance-ranked [[ThriftSearchQuery]] for the given search parameters.
   *
   * @param query             search parameters
   * @param processingTimeout timeout handed to the collector termination params
   */
  private def getThriftSearchQuery(
    query: EarlybirdTensorflowBasedSearchQuery,
    processingTimeout: Duration
  ): ThriftSearchQuery = {
    val serialized = GetEarlybirdQuery(
      query.beforeTweetIdExclusive,
      query.afterTweetIdExclusive,
      query.excludedTweetIds,
      query.filterOutRetweetsAndReplies
    ).map(_.serialize)

    val relevance = getRelevanceOptions

    val collector = CollectorParams(
      numResultsToReturn = ResultsReturnedPerShard,
      // terminationParams.maxHitsToProcess early-terminates result fetching on each shard.
      terminationParams =
        GetCollectorTerminationParams(query.maxNumHitsPerShard, processingTimeout)
    )

    ThriftSearchQuery(
      serializedQuery = serialized,
      fromUserIDFilter64 = Some(query.seedUserIds),
      numResults = query.maxNumTweets,
      // Conversation-ID collection is intentionally left out for now:
      // collectConversationId = Gate.True(), // true for Home
      rankingMode = ThriftSearchRankingMode.Relevance,
      relevanceOptions = Some(relevance),
      collectorParams = Some(collector),
      facetFieldNames = Some(FacetsToFetch),
      resultMetadataOptions = Some(MetadataOptions),
      searcherId = query.searcherUserId,
      searchStatusIds = None,
      namedDisjunctionMap = GetNamedDisjunctions(query.excludedTweetIds)
    )
  }

  /**
   * Relevance options used for every query built by this object.
   *
   * The specific recap relevance/reranking values correspond to
   * experiment: enable_recap_reranking_2988,timeline_internal_disable_recap_filter
   * bucket    : enable_rerank,disable_filter
   */
  private def getRelevanceOptions: ThriftSearchRelevanceOptions =
    ThriftSearchRelevanceOptions(
      proximityScoring = true,
      maxConsecutiveSameUser = Some(2),
      rankingParams = Some(getTensorflowBasedRankingParams),
      maxHitsToProcess = Some(RelevanceMaxHitsToProcess),
      maxUserBlendCount = Some(3),
      proximityPhraseWeight = 9.0,
      returnAllResults = Some(true)
    )

  /**
   * Ranking params that select the Tensorflow-based scoring function and model, with boosts
   * disabled and a very low minimum score.
   */
  private def getTensorflowBasedRankingParams: ThriftRankingParams =
    ThriftRankingParams(
      `type` = Some(ThriftScoringFunctionType.TensorflowBased),
      selectedTensorflowModel = Some(RankingModelName),
      minScore = -1.0e100,
      applyBoosts = false,
      authorSpecificScoreAdjustments = None
    )
}