the-algorithm/src/scala/com/twitter/recos/user_tweet_entity_graph/TweetRecommendationsRunner....

323 lines
12 KiB
Scala

package com.twitter.recos.user_tweet_entity_graph
import java.util.Random
import com.twitter.concurrent.AsyncQueue
import com.twitter.conversions.DurationOps._
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.graphjet.algorithms._
import com.twitter.graphjet.algorithms.filters._
import com.twitter.graphjet.algorithms.counting.TopSecondDegreeByCountResponse
import com.twitter.graphjet.algorithms.counting.tweet.TopSecondDegreeByCountForTweet
import com.twitter.graphjet.algorithms.counting.tweet.TopSecondDegreeByCountRequestForTweet
import com.twitter.graphjet.bipartite.NodeMetadataLeftIndexedMultiSegmentBipartiteGraph
import com.twitter.logging.Logger
import com.twitter.recos.graph_common.FinagleStatsReceiverWrapper
import com.twitter.recos.model.SalsaQueryRunner.SalsaRunnerConfig
import com.twitter.recos.recos_common.thriftscala.SocialProofType
import com.twitter.recos.user_tweet_entity_graph.thriftscala.RecommendTweetEntityRequest
import com.twitter.recos.user_tweet_entity_graph.thriftscala.TweetEntityDisplayLocation
import com.twitter.recos.user_tweet_entity_graph.thriftscala.TweetType
import com.twitter.recos.util.Stats.trackBlockStats
import com.twitter.util.Future
import com.twitter.util.JavaTimer
import com.twitter.util.Try
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap
import it.unimi.dsi.fastutil.longs.LongOpenHashSet
import scala.collection.JavaConverters._
import com.twitter.graphjet.algorithms.RecommendationType
import com.twitter.recos.user_tweet_entity_graph.thriftscala.{
RecommendationType => ThriftRecommendationType
}
import scala.collection.Map
import scala.collection.Set
object TweetRecommendationsRunner {

  /** Tweet types used when a request does not specify any. */
  private val DefaultTweetTypes: Seq[TweetType] = Seq(
    TweetType.Regular,
    TweetType.Summary,
    TweetType.Photo,
    TweetType.Player
  )

  /** Social proof size passed to the exact-size filter for the MagicRecsF1 display location. */
  private val DefaultF1ExactSocialProofSize = 1

  /** Recency window (milliseconds) passed to the edge metadata filter for MagicRecsRareTweet. */
  private val DefaultRareTweetRecencyMillis: Long = 7.days.inMillis

  /**
   * Converts the client-supplied social proof type unions into byte arrays, one array per union,
   * with each social proof type represented by its thrift value as a byte.
   *
   * @param socialProofTypeUnions optional unions of social proof types from the thrift request
   * @return one byte array per union, or an empty set when the request carries no unions
   */
  private def getSocialProofTypeUnions(
    socialProofTypeUnions: Option[Set[Seq[SocialProofType]]]
  ): Set[Array[Byte]] =
    socialProofTypeUnions match {
      case Some(unions) =>
        unions.map { proofTypes =>
          proofTypes.map(proofType => proofType.getValue.toByte).toArray
        }
      case None =>
        Set.empty
    }

  /**
   * Translates thrift recommendation types into the GraphJet enum.
   *
   * @param recommendationTypes recommendation types requested over thrift
   * @return the distinct GraphJet recommendation types
   * @throws Exception when a type other than Tweet, Hashtag or Url is encountered
   */
  private def getRecommendationTypes(
    recommendationTypes: Seq[ThriftRecommendationType]
  ): Set[RecommendationType] =
    recommendationTypes.map {
      case ThriftRecommendationType.Tweet => RecommendationType.TWEET
      case ThriftRecommendationType.Hashtag => RecommendationType.HASHTAG
      case ThriftRecommendationType.Url => RecommendationType.URL
      case _ =>
        throw new Exception("Unmatched Recommendation Type in getRecommendationTypes")
    }.toSet

  /**
   * Re-keys a per-recommendation-type map from the thrift enum to the GraphJet enum and boxes the
   * values as java.lang.Integer.
   *
   * @param maxResults optional map of per-type integer settings from the thrift request
   * @return the converted map, or an empty map when the option is not set
   * @throws Exception when a key other than Tweet, Hashtag or Url is encountered
   */
  private def convertThriftEnumsToJavaEnums(
    maxResults: Option[Map[ThriftRecommendationType, Int]]
  ): Map[RecommendationType, Integer] =
    maxResults match {
      case Some(valuesByType) =>
        valuesByType.map {
          case (ThriftRecommendationType.Tweet, value) =>
            (RecommendationType.TWEET, value: Integer)
          case (ThriftRecommendationType.Hashtag, value) =>
            (RecommendationType.HASHTAG, value: Integer)
          case (ThriftRecommendationType.Url, value) =>
            (RecommendationType.URL, value: Integer)
          case _ =>
            throw new Exception("Unmatched Recommendation Type in convertThriftEnumsToJavaEnums")
        }
      case None =>
        Map.empty
    }
}
/**
 * Serves tweet entity recommendation requests from a pool of MagicRecs readers.
 *
 * A fixed number of [[TopSecondDegreeByCountForTweet]] instances (referred to as "MagicRecs" in
 * this class) is created up front and kept in an [[AsyncQueue]]. Each request polls one instance
 * from the queue, uses it to compute recommendations on the graph, and offers it back to the
 * queue afterwards.
 *
 * @param bipartiteGraph graph the MagicRecs readers and the graph-backed filters operate on
 * @param salsaRunnerConfig supplies the pool size, the expected number of nodes to hit, and the
 *                          poll timeout
 * @param statsReceiverWrapper stats receiver wrapper handed to the readers and the filters
 */
class TweetRecommendationsRunner(
  bipartiteGraph: NodeMetadataLeftIndexedMultiSegmentBipartiteGraph,
  salsaRunnerConfig: SalsaRunnerConfig,
  statsReceiverWrapper: FinagleStatsReceiverWrapper) {

  import TweetRecommendationsRunner._

  private val log: Logger = Logger()

  private val stats = statsReceiverWrapper.statsReceiver.scope(this.getClass.getSimpleName)
  private val magicRecsFailureCounter = stats.counter("failure")
  private val pollCounter = stats.counter("poll")
  private val pollTimeoutCounter = stats.counter("pollTimeout")
  private val offerCounter = stats.counter("offer")
  private val pollLatencyStat = stats.stat("pollLatency")

  // Pool of MagicRecs readers; one is checked out per request in apply() and returned afterwards.
  private val magicRecsQueue = new AsyncQueue[TopSecondDegreeByCountForTweet]

  (0 until salsaRunnerConfig.numSalsaRunners).foreach { _ =>
    magicRecsQueue.offer(
      new TopSecondDegreeByCountForTweet(
        bipartiteGraph,
        salsaRunnerConfig.expectedNodesToHitInSalsa,
        statsReceiverWrapper.scope(this.getClass.getSimpleName)
      )
    )
  }

  // Kept in implicit scope; nothing in this class refers to it by name.
  private implicit val timer: JavaTimer = new JavaTimer(true)

  /**
   * Builds the filters that are applied to every request regardless of display location.
   *
   * @param staleTweetDuration tweet age limit passed to the recent-tweet filter
   * @param tweetTypes tweet types the card filter should let through
   * @return the base filters, in the order they are placed in the filter chain
   */
  private def getBaseFilters(
    staleTweetDuration: Long,
    tweetTypes: Seq[TweetType]
  ) = {
    List(
      // Keep RecentTweetFilter first since it's the cheapest
      new RecentTweetFilter(staleTweetDuration, statsReceiverWrapper),
      new TweetCardFilter(
        tweetTypes.contains(TweetType.Regular),
        tweetTypes.contains(TweetType.Summary),
        tweetTypes.contains(TweetType.Photo),
        tweetTypes.contains(TweetType.Player),
        false, // no promoted tweets
        statsReceiverWrapper
      ),
      new DirectInteractionsFilter(bipartiteGraph, statsReceiverWrapper),
      new RequestedSetFilter(statsReceiverWrapper),
      new SocialProofTypesFilter(statsReceiverWrapper)
    )
  }

  /**
   * Helper method to interpret the output of the MagicRecs computation.
   *
   * @param magicRecsResponse the response from running MagicRecs, if the computation succeeded
   * @return the ranked recommendations of the response, or Nil when there is no response
   */
  private def transformMagicRecsResponse(
    magicRecsResponse: Option[TopSecondDegreeByCountResponse]
  ): Seq[RecommendationInfo] = {
    magicRecsResponse
      .map(_.getRankedRecommendations.asScala.toSeq)
      .getOrElse(Nil)
  }

  /**
   * Helper function to determine the post-process filtering logic in GraphJet, based on the
   * display location of the request.
   *
   * @param displayLocation display location of the request
   * @param whitelistAuthors author ids passed to the author filter as the whitelist
   * @param blacklistAuthors author ids passed to the author filter as the blacklist
   * @param validSocialProofs social proof types passed to the exact social proof size filter
   * @return the display-location specific filters, appended after the base filters by the caller
   */
  private def getFiltersByDisplayLocations(
    displayLocation: TweetEntityDisplayLocation,
    whitelistAuthors: LongOpenHashSet,
    blacklistAuthors: LongOpenHashSet,
    validSocialProofs: Array[Byte]
  ) = {
    displayLocation match {
      case TweetEntityDisplayLocation.MagicRecsF1 =>
        Seq(
          // Whitelist-only author filter combined with the exact social proof size filter.
          new ANDFilters(
            List[ResultFilter](
              new TweetAuthorFilter(
                bipartiteGraph,
                whitelistAuthors,
                new LongOpenHashSet(),
                statsReceiverWrapper),
              new ExactUserSocialProofSizeFilter(
                DefaultF1ExactSocialProofSize,
                validSocialProofs,
                statsReceiverWrapper
              )
            ).asJava,
            statsReceiverWrapper
          ),
          // Blacklist filter must be applied separately from F1's AND filter chain
          new TweetAuthorFilter(
            bipartiteGraph,
            new LongOpenHashSet(),
            blacklistAuthors,
            statsReceiverWrapper)
        )
      case TweetEntityDisplayLocation.MagicRecsRareTweet =>
        Seq(
          new TweetAuthorFilter(
            bipartiteGraph,
            whitelistAuthors,
            blacklistAuthors,
            statsReceiverWrapper),
          new RecentEdgeMetadataFilter(
            DefaultRareTweetRecencyMillis,
            UserTweetEdgeTypeMask.Tweet.id.toByte,
            statsReceiverWrapper
          )
        )
      case _ =>
        Seq(
          new TweetAuthorFilter(
            bipartiteGraph,
            whitelistAuthors,
            blacklistAuthors,
            statsReceiverWrapper))
    }
  }

  /**
   * Helper method to run the MagicRecs computation and convert the result to an Option.
   *
   * @param magicRecs the MagicRecs reader on the bipartite graph
   * @param magicRecsRequest the MagicRecs request
   * @return the response, or None when the computation threw
   */
  private def getMagicRecsResponse(
    magicRecs: TopSecondDegreeByCountForTweet,
    magicRecsRequest: TopSecondDegreeByCountRequestForTweet
  )(
    implicit statsReceiver: StatsReceiver
  ): Option[TopSecondDegreeByCountResponse] = {
    trackBlockStats(stats) {
      val random = new Random()
      // Count and log computation failures here; the caller only sees None for a failed run.
      val magicRecsAttempt =
        Try(magicRecs.computeRecommendations(magicRecsRequest, random)).onFailure { e =>
          magicRecsFailureCounter.incr()
          log.error(e, "MagicRecs computation failed")
        }
      magicRecsAttempt.toOption
    }
  }

  /**
   * Translates a thrift request into the GraphJet request, including the result filter chain.
   * Optional request fields that are not set fall back to RecosConfig or to the defaults in the
   * companion object.
   *
   * @param request the thrift recommendation request
   * @return the GraphJet request to hand to a MagicRecs reader
   * @throws Exception when the request contains a recommendation type that cannot be converted
   */
  private def getMagicRecsRequest(
    request: RecommendTweetEntityRequest
  ): TopSecondDegreeByCountRequestForTweet = {
    val requesterId = request.requesterId
    // Keys and values come from the same map, so the two arrays are expected to line up.
    val leftSeedNodes = new Long2DoubleOpenHashMap(
      request.seedsWithWeights.keys.toArray,
      request.seedsWithWeights.values.toArray
    )
    val tweetsToExcludeArray = new LongOpenHashSet(request.excludedTweetIds.getOrElse(Nil).toArray)
    val staleTweetDuration = request.maxTweetAgeInMillis.getOrElse(RecosConfig.maxTweetAgeInMillis)
    val staleEngagementDuration =
      request.maxEngagementAgeInMillis.getOrElse(RecosConfig.maxEngagementAgeInMillis)
    val tweetTypes = request.tweetTypes.getOrElse(DefaultTweetTypes)
    val tweetAuthors = new LongOpenHashSet(request.tweetAuthors.getOrElse(Nil).toArray)
    val excludedTweetAuthors = new LongOpenHashSet(
      request.excludedTweetAuthors.getOrElse(Nil).toArray)
    val validSocialProofs =
      UserTweetEdgeTypeMask.getUserTweetGraphSocialProofTypes(request.socialProofTypes)

    // Base filters first, then the filters specific to the request's display location.
    val resultFilterChain = new ResultFilterChain(
      (
        getBaseFilters(staleTweetDuration, tweetTypes) ++
          getFiltersByDisplayLocations(
            displayLocation = request.displayLocation,
            whitelistAuthors = tweetAuthors,
            blacklistAuthors = excludedTweetAuthors,
            validSocialProofs = validSocialProofs
          )
      ).asJava
    )

    new TopSecondDegreeByCountRequestForTweet(
      requesterId,
      leftSeedNodes,
      tweetsToExcludeArray,
      getRecommendationTypes(request.recommendationTypes).asJava,
      convertThriftEnumsToJavaEnums(request.maxResultsByType).asJava,
      UserTweetEdgeTypeMask.SIZE,
      request.maxUserSocialProofSize.getOrElse(RecosConfig.maxUserSocialProofSize),
      request.maxTweetSocialProofSize.getOrElse(RecosConfig.maxTweetSocialProofSize),
      convertThriftEnumsToJavaEnums(request.minUserSocialProofSizes).asJava,
      validSocialProofs,
      staleTweetDuration,
      staleEngagementDuration,
      resultFilterChain,
      getSocialProofTypeUnions(request.socialProofTypeUnions).asJava
    )
  }

  /**
   * Computes recommendations for a request using one MagicRecs reader from the pool.
   *
   * The reader is always offered back to the queue, whether the computation succeeded, failed,
   * or was skipped. The returned sequence is empty when waiting for a reader took at least
   * `salsaRunnerConfig.timeoutSalsaRunner`, or when building or running the request failed.
   *
   * @param request the thrift recommendation request
   * @return a Future of the ranked recommendations, or of Nil in the cases described above
   */
  def apply(request: RecommendTweetEntityRequest): Future[Seq[RecommendationInfo]] = {
    pollCounter.incr()
    val t0 = System.currentTimeMillis
    magicRecsQueue.poll().map { magicRecs =>
      val pollTime = System.currentTimeMillis - t0
      pollLatencyStat.add(pollTime)
      val magicRecsResponse = Try {
        if (pollTime < salsaRunnerConfig.timeoutSalsaRunner) {
          val magicRecsRequest = getMagicRecsRequest(request)
          transformMagicRecsResponse(
            getMagicRecsResponse(magicRecs, magicRecsRequest)(statsReceiverWrapper.statsReceiver)
          )
        } else {
          // if we did not get a magicRecs in time, then fail fast here and immediately put it back
          log.warning("magicRecsQueue polling timeout")
          pollTimeoutCounter.incr()
          Nil
        }
      } onFailure { e =>
        // Reached when building the request or reading the response throws; computation errors
        // are already handled inside getMagicRecsResponse. Without this the error would be
        // turned into an empty result without any trace.
        magicRecsFailureCounter.incr()
        log.error(e, "Failed to build or process MagicRecs request")
      } ensure {
        magicRecsQueue.offer(magicRecs)
        offerCounter.incr()
      }
      magicRecsResponse.toOption.getOrElse(Nil)
    }
  }
}