323 lines
12 KiB
Scala
323 lines
12 KiB
Scala
package com.twitter.recos.user_tweet_entity_graph
|
|
|
|
import java.util.Random
|
|
import com.twitter.concurrent.AsyncQueue
|
|
import com.twitter.conversions.DurationOps._
|
|
import com.twitter.finagle.stats.StatsReceiver
|
|
import com.twitter.graphjet.algorithms._
|
|
import com.twitter.graphjet.algorithms.filters._
|
|
import com.twitter.graphjet.algorithms.counting.TopSecondDegreeByCountResponse
|
|
import com.twitter.graphjet.algorithms.counting.tweet.TopSecondDegreeByCountForTweet
|
|
import com.twitter.graphjet.algorithms.counting.tweet.TopSecondDegreeByCountRequestForTweet
|
|
import com.twitter.graphjet.bipartite.NodeMetadataLeftIndexedMultiSegmentBipartiteGraph
|
|
import com.twitter.logging.Logger
|
|
import com.twitter.recos.graph_common.FinagleStatsReceiverWrapper
|
|
import com.twitter.recos.model.SalsaQueryRunner.SalsaRunnerConfig
|
|
import com.twitter.recos.recos_common.thriftscala.SocialProofType
|
|
import com.twitter.recos.user_tweet_entity_graph.thriftscala.RecommendTweetEntityRequest
|
|
import com.twitter.recos.user_tweet_entity_graph.thriftscala.TweetEntityDisplayLocation
|
|
import com.twitter.recos.user_tweet_entity_graph.thriftscala.TweetType
|
|
import com.twitter.recos.util.Stats.trackBlockStats
|
|
import com.twitter.util.Future
|
|
import com.twitter.util.JavaTimer
|
|
import com.twitter.util.Try
|
|
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap
|
|
import it.unimi.dsi.fastutil.longs.LongOpenHashSet
|
|
import scala.collection.JavaConverters._
|
|
|
|
import com.twitter.graphjet.algorithms.RecommendationType
|
|
import com.twitter.recos.user_tweet_entity_graph.thriftscala.{
|
|
RecommendationType => ThriftRecommendationType
|
|
}
|
|
import scala.collection.Map
|
|
import scala.collection.Set
|
|
|
|
object TweetRecommendationsRunner {
|
|
private val DefaultTweetTypes: Seq[TweetType] =
|
|
Seq(TweetType.Regular, TweetType.Summary, TweetType.Photo, TweetType.Player)
|
|
private val DefaultF1ExactSocialProofSize = 1
|
|
private val DefaultRareTweetRecencyMillis: Long = 7.days.inMillis
|
|
|
|
/**
|
|
* Map valid social proof types specified by clients to an array of bytes. If clients do not
|
|
* specify any social proof type unions in thrift, it will return an empty set by default.
|
|
*/
|
|
private def getSocialProofTypeUnions(
|
|
socialProofTypeUnions: Option[Set[Seq[SocialProofType]]]
|
|
): Set[Array[Byte]] = {
|
|
socialProofTypeUnions
|
|
.map {
|
|
_.map {
|
|
_.map {
|
|
_.getValue.toByte
|
|
}.toArray
|
|
}
|
|
}
|
|
.getOrElse(Set.empty)
|
|
}
|
|
|
|
private def getRecommendationTypes(
|
|
recommendationTypes: Seq[ThriftRecommendationType]
|
|
): Set[RecommendationType] = {
|
|
recommendationTypes.flatMap {
|
|
_ match {
|
|
case ThriftRecommendationType.Tweet => Some(RecommendationType.TWEET)
|
|
case ThriftRecommendationType.Hashtag => Some(RecommendationType.HASHTAG)
|
|
case ThriftRecommendationType.Url => Some(RecommendationType.URL)
|
|
case _ =>
|
|
throw new Exception("Unmatched Recommendation Type in getRecommendationTypes")
|
|
}
|
|
}.toSet
|
|
}
|
|
|
|
private def convertThriftEnumsToJavaEnums(
|
|
maxResults: Option[Map[ThriftRecommendationType, Int]]
|
|
): Map[RecommendationType, Integer] = {
|
|
maxResults
|
|
.map {
|
|
_.flatMap {
|
|
_ match {
|
|
case (ThriftRecommendationType.Tweet, v) => Some((RecommendationType.TWEET, v: Integer))
|
|
case (ThriftRecommendationType.Hashtag, v) =>
|
|
Some((RecommendationType.HASHTAG, v: Integer))
|
|
case (ThriftRecommendationType.Url, v) => Some((RecommendationType.URL, v: Integer))
|
|
case _ =>
|
|
throw new Exception("Unmatched Recommendation Type in convertThriftEnumsToJavaEnums")
|
|
}
|
|
}
|
|
}
|
|
.getOrElse(Map.empty)
|
|
}
|
|
|
|
}
|
|
|
|
/**
|
|
* The MagicRecsRunner creates a queue of reader threads, MagicRecs, and each one reads from the
|
|
* graph and computes recommendations.
|
|
*/
|
|
class TweetRecommendationsRunner(
|
|
bipartiteGraph: NodeMetadataLeftIndexedMultiSegmentBipartiteGraph,
|
|
salsaRunnerConfig: SalsaRunnerConfig,
|
|
statsReceiverWrapper: FinagleStatsReceiverWrapper) {
|
|
|
|
import TweetRecommendationsRunner._
|
|
|
|
private val log: Logger = Logger()
|
|
|
|
private val stats = statsReceiverWrapper.statsReceiver.scope(this.getClass.getSimpleName)
|
|
private val magicRecsFailureCounter = stats.counter("failure")
|
|
private val pollCounter = stats.counter("poll")
|
|
private val pollTimeoutCounter = stats.counter("pollTimeout")
|
|
private val offerCounter = stats.counter("offer")
|
|
private val pollLatencyStat = stats.stat("pollLatency")
|
|
|
|
private val magicRecsQueue = new AsyncQueue[TopSecondDegreeByCountForTweet]
|
|
(0 until salsaRunnerConfig.numSalsaRunners).foreach { _ =>
|
|
magicRecsQueue.offer(
|
|
new TopSecondDegreeByCountForTweet(
|
|
bipartiteGraph,
|
|
salsaRunnerConfig.expectedNodesToHitInSalsa,
|
|
statsReceiverWrapper.scope(this.getClass.getSimpleName)
|
|
)
|
|
)
|
|
}
|
|
|
|
private implicit val timer: JavaTimer = new JavaTimer(true)
|
|
|
|
private def getBaseFilters(
|
|
staleTweetDuration: Long,
|
|
tweetTypes: Seq[TweetType]
|
|
) = {
|
|
List(
|
|
// Keep RecentTweetFilter first since it's the cheapest
|
|
new RecentTweetFilter(staleTweetDuration, statsReceiverWrapper),
|
|
new TweetCardFilter(
|
|
tweetTypes.contains(TweetType.Regular),
|
|
tweetTypes.contains(TweetType.Summary),
|
|
tweetTypes.contains(TweetType.Photo),
|
|
tweetTypes.contains(TweetType.Player),
|
|
false, // no promoted tweets
|
|
statsReceiverWrapper
|
|
),
|
|
new DirectInteractionsFilter(bipartiteGraph, statsReceiverWrapper),
|
|
new RequestedSetFilter(statsReceiverWrapper),
|
|
new SocialProofTypesFilter(statsReceiverWrapper)
|
|
)
|
|
}
|
|
|
|
/**
|
|
* Helper method to interpret the output of MagicRecs graph
|
|
*
|
|
* @param magicRecsResponse is the response from running MagicRecs
|
|
* @return a sequence of candidate ids, with score and list of social proofs
|
|
*/
|
|
private def transformMagicRecsResponse(
|
|
magicRecsResponse: Option[TopSecondDegreeByCountResponse]
|
|
): Seq[RecommendationInfo] = {
|
|
val responses = magicRecsResponse match {
|
|
case Some(response) => response.getRankedRecommendations.asScala.toSeq
|
|
case _ => Nil
|
|
}
|
|
responses
|
|
}
|
|
|
|
/**
|
|
* Helper function to determine different post-process filtering logic in GraphJet,
|
|
* based on display locations
|
|
*/
|
|
private def getFiltersByDisplayLocations(
|
|
displayLocation: TweetEntityDisplayLocation,
|
|
whitelistAuthors: LongOpenHashSet,
|
|
blacklistAuthors: LongOpenHashSet,
|
|
validSocialProofs: Array[Byte]
|
|
) = {
|
|
displayLocation match {
|
|
case TweetEntityDisplayLocation.MagicRecsF1 =>
|
|
Seq(
|
|
new ANDFilters(
|
|
List[ResultFilter](
|
|
new TweetAuthorFilter(
|
|
bipartiteGraph,
|
|
whitelistAuthors,
|
|
new LongOpenHashSet(),
|
|
statsReceiverWrapper),
|
|
new ExactUserSocialProofSizeFilter(
|
|
DefaultF1ExactSocialProofSize,
|
|
validSocialProofs,
|
|
statsReceiverWrapper
|
|
)
|
|
).asJava,
|
|
statsReceiverWrapper
|
|
),
|
|
// Blacklist filter must be applied separately from F1's AND filter chain
|
|
new TweetAuthorFilter(
|
|
bipartiteGraph,
|
|
new LongOpenHashSet(),
|
|
blacklistAuthors,
|
|
statsReceiverWrapper)
|
|
)
|
|
case TweetEntityDisplayLocation.MagicRecsRareTweet =>
|
|
Seq(
|
|
new TweetAuthorFilter(
|
|
bipartiteGraph,
|
|
whitelistAuthors,
|
|
blacklistAuthors,
|
|
statsReceiverWrapper),
|
|
new RecentEdgeMetadataFilter(
|
|
DefaultRareTweetRecencyMillis,
|
|
UserTweetEdgeTypeMask.Tweet.id.toByte,
|
|
statsReceiverWrapper
|
|
)
|
|
)
|
|
case _ =>
|
|
Seq(
|
|
new TweetAuthorFilter(
|
|
bipartiteGraph,
|
|
whitelistAuthors,
|
|
blacklistAuthors,
|
|
statsReceiverWrapper))
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Helper method to run salsa computation and convert the results to Option
|
|
*
|
|
* @param magicRecs is magicRecs reader on bipartite graph
|
|
* @param magicRecsRequest is the magicRecs request
|
|
* @return is an option of MagicRecsResponse
|
|
*/
|
|
private def getMagicRecsResponse(
|
|
magicRecs: TopSecondDegreeByCountForTweet,
|
|
magicRecsRequest: TopSecondDegreeByCountRequestForTweet
|
|
)(
|
|
implicit statsReceiver: StatsReceiver
|
|
): Option[TopSecondDegreeByCountResponse] = {
|
|
trackBlockStats(stats) {
|
|
val random = new Random()
|
|
// compute recs -- need to catch and print exceptions here otherwise they are swallowed
|
|
val magicRecsAttempt =
|
|
Try(magicRecs.computeRecommendations(magicRecsRequest, random)).onFailure { e =>
|
|
magicRecsFailureCounter.incr()
|
|
log.error(e, "MagicRecs computation failed")
|
|
}
|
|
magicRecsAttempt.toOption
|
|
}
|
|
}
|
|
|
|
private def getMagicRecsRequest(
|
|
request: RecommendTweetEntityRequest
|
|
): TopSecondDegreeByCountRequestForTweet = {
|
|
val requesterId = request.requesterId
|
|
val leftSeedNodes = new Long2DoubleOpenHashMap(
|
|
request.seedsWithWeights.keys.toArray,
|
|
request.seedsWithWeights.values.toArray
|
|
)
|
|
val tweetsToExcludeArray = new LongOpenHashSet(request.excludedTweetIds.getOrElse(Nil).toArray)
|
|
val staleTweetDuration = request.maxTweetAgeInMillis.getOrElse(RecosConfig.maxTweetAgeInMillis)
|
|
val staleEngagementDuration =
|
|
request.maxEngagementAgeInMillis.getOrElse(RecosConfig.maxEngagementAgeInMillis)
|
|
val tweetTypes = request.tweetTypes.getOrElse(DefaultTweetTypes)
|
|
val tweetAuthors = new LongOpenHashSet(request.tweetAuthors.getOrElse(Nil).toArray)
|
|
val excludedTweetAuthors = new LongOpenHashSet(
|
|
request.excludedTweetAuthors.getOrElse(Nil).toArray)
|
|
val validSocialProofs =
|
|
UserTweetEdgeTypeMask.getUserTweetGraphSocialProofTypes(request.socialProofTypes)
|
|
|
|
val resultFilterChain = new ResultFilterChain(
|
|
(
|
|
getBaseFilters(staleTweetDuration, tweetTypes) ++
|
|
getFiltersByDisplayLocations(
|
|
displayLocation = request.displayLocation,
|
|
whitelistAuthors = tweetAuthors,
|
|
blacklistAuthors = excludedTweetAuthors,
|
|
validSocialProofs = validSocialProofs
|
|
)
|
|
).asJava
|
|
)
|
|
|
|
new TopSecondDegreeByCountRequestForTweet(
|
|
requesterId,
|
|
leftSeedNodes,
|
|
tweetsToExcludeArray,
|
|
getRecommendationTypes(request.recommendationTypes).asJava,
|
|
convertThriftEnumsToJavaEnums(request.maxResultsByType).asJava,
|
|
UserTweetEdgeTypeMask.SIZE,
|
|
request.maxUserSocialProofSize.getOrElse(RecosConfig.maxUserSocialProofSize),
|
|
request.maxTweetSocialProofSize.getOrElse(RecosConfig.maxTweetSocialProofSize),
|
|
convertThriftEnumsToJavaEnums(request.minUserSocialProofSizes).asJava,
|
|
validSocialProofs,
|
|
staleTweetDuration,
|
|
staleEngagementDuration,
|
|
resultFilterChain,
|
|
getSocialProofTypeUnions(request.socialProofTypeUnions).asJava
|
|
)
|
|
}
|
|
|
|
def apply(request: RecommendTweetEntityRequest): Future[Seq[RecommendationInfo]] = {
|
|
pollCounter.incr()
|
|
val t0 = System.currentTimeMillis
|
|
magicRecsQueue.poll().map { magicRecs =>
|
|
val pollTime = System.currentTimeMillis - t0
|
|
pollLatencyStat.add(pollTime)
|
|
val magicRecsResponse = Try {
|
|
if (pollTime < salsaRunnerConfig.timeoutSalsaRunner) {
|
|
val magicRecsRequest = getMagicRecsRequest(request)
|
|
transformMagicRecsResponse(
|
|
getMagicRecsResponse(magicRecs, magicRecsRequest)(statsReceiverWrapper.statsReceiver)
|
|
)
|
|
} else {
|
|
// if we did not get a magicRecs in time, then fail fast here and immediately put it back
|
|
log.warning("magicRecsQueue polling timeout")
|
|
pollTimeoutCounter.incr()
|
|
throw new RuntimeException("magicRecs poll timeout")
|
|
Nil
|
|
}
|
|
} ensure {
|
|
magicRecsQueue.offer(magicRecs)
|
|
offerCounter.incr()
|
|
}
|
|
magicRecsResponse.toOption getOrElse Nil
|
|
}
|
|
}
|
|
}
|