diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/util/TweetypieContentFeaturesProvider.docx b/timelineranker/server/src/main/scala/com/twitter/timelineranker/util/TweetypieContentFeaturesProvider.docx new file mode 100644 index 000000000..d591d595e Binary files /dev/null and b/timelineranker/server/src/main/scala/com/twitter/timelineranker/util/TweetypieContentFeaturesProvider.docx differ diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/util/TweetypieContentFeaturesProvider.scala b/timelineranker/server/src/main/scala/com/twitter/timelineranker/util/TweetypieContentFeaturesProvider.scala deleted file mode 100644 index 509f903d9..000000000 --- a/timelineranker/server/src/main/scala/com/twitter/timelineranker/util/TweetypieContentFeaturesProvider.scala +++ /dev/null @@ -1,114 +0,0 @@ -package com.twitter.timelineranker.util - -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.servo.util.Gate -import com.twitter.timelineranker.contentfeatures.ContentFeaturesProvider -import com.twitter.timelineranker.core.HydratedTweets -import com.twitter.timelineranker.model.RecapQuery -import com.twitter.timelineranker.recap.model.ContentFeatures -import com.twitter.timelines.clients.tweetypie.TweetyPieClient -import com.twitter.timelines.model.TweetId -import com.twitter.timelines.model.tweet.HydratedTweet -import com.twitter.timelines.util.FailOpenHandler -import com.twitter.tweetypie.thriftscala.MediaEntity -import com.twitter.tweetypie.thriftscala.TweetInclude -import com.twitter.tweetypie.thriftscala.{Tweet => TTweet} -import com.twitter.util.Future - -object TweetypieContentFeaturesProvider { - val DefaultTweetyPieFieldsToHydrate: Set[TweetInclude] = TweetyPieClient.CoreTweetFields ++ - TweetyPieClient.MediaFields ++ - TweetyPieClient.SelfThreadFields ++ - Set[TweetInclude](TweetInclude.MediaEntityFieldId(MediaEntity.AdditionalMetadataField.id)) - - //add Tweet fields from semantic core - val TweetyPieFieldsToHydrate: Set[TweetInclude] = DefaultTweetyPieFieldsToHydrate ++ - Set[TweetInclude](TweetInclude.TweetFieldId(TTweet.EscherbirdEntityAnnotationsField.id)) - val EmptyHydratedTweets: HydratedTweets = - HydratedTweets(Seq.empty[HydratedTweet], Seq.empty[HydratedTweet]) - val EmptyHydratedTweetsFuture: Future[HydratedTweets] = Future.value(EmptyHydratedTweets) -} - -class TweetypieContentFeaturesProvider( - tweetHydrator: TweetHydrator, - enableContentFeaturesGate: Gate[RecapQuery], - enableTokensInContentFeaturesGate: Gate[RecapQuery], - enableTweetTextInContentFeaturesGate: Gate[RecapQuery], - enableConversationControlContentFeaturesGate: Gate[RecapQuery], - enableTweetMediaHydrationGate: Gate[RecapQuery], - statsReceiver: StatsReceiver) - extends ContentFeaturesProvider { - val scopedStatsReceiver: StatsReceiver = statsReceiver.scope("TweetypieContentFeaturesProvider") - - override def apply( - query: RecapQuery, - tweetIds: Seq[TweetId] - ): Future[Map[TweetId, ContentFeatures]] = { - import TweetypieContentFeaturesProvider._ - - val tweetypieHydrationHandler = new FailOpenHandler(scopedStatsReceiver) - val hydratePenguinTextFeatures = enableContentFeaturesGate(query) - val hydrateSemanticCoreFeatures = enableContentFeaturesGate(query) - val hydrateTokens = enableTokensInContentFeaturesGate(query) - val hydrateTweetText = enableTweetTextInContentFeaturesGate(query) - val hydrateConversationControl = enableConversationControlContentFeaturesGate(query) - - val userId = query.userId - - val hydratedTweetsFuture = 
tweetypieHydrationHandler { - // tweetyPie fields to hydrate given hydrateSemanticCoreFeatures - val fieldsToHydrateWithSemanticCore = if (hydrateSemanticCoreFeatures) { - TweetyPieFieldsToHydrate - } else { - DefaultTweetyPieFieldsToHydrate - } - - // tweetyPie fields to hydrate given hydrateSemanticCoreFeatures & hydrateConversationControl - val fieldsToHydrateWithConversationControl = if (hydrateConversationControl) { - fieldsToHydrateWithSemanticCore ++ TweetyPieClient.ConversationControlField - } else { - fieldsToHydrateWithSemanticCore - } - - tweetHydrator.hydrate(Some(userId), tweetIds, fieldsToHydrateWithConversationControl) - - } { e: Throwable => EmptyHydratedTweetsFuture } - - hydratedTweetsFuture.map[Map[TweetId, ContentFeatures]] { hydratedTweets => - hydratedTweets.outerTweets.map { hydratedTweet => - val contentFeaturesFromTweet = ContentFeatures.Empty.copy( - selfThreadMetadata = hydratedTweet.tweet.selfThreadMetadata - ) - - val contentFeaturesWithText = TweetTextFeaturesExtractor.addTextFeaturesFromTweet( - contentFeaturesFromTweet, - hydratedTweet.tweet, - hydratePenguinTextFeatures, - hydrateTokens, - hydrateTweetText - ) - val contentFeaturesWithMedia = TweetMediaFeaturesExtractor.addMediaFeaturesFromTweet( - contentFeaturesWithText, - hydratedTweet.tweet, - enableTweetMediaHydrationGate(query) - ) - val contentFeaturesWithAnnotations = TweetAnnotationFeaturesExtractor - .addAnnotationFeaturesFromTweet( - contentFeaturesWithMedia, - hydratedTweet.tweet, - hydrateSemanticCoreFeatures - ) - // add conversationControl to content features if hydrateConversationControl is true - if (hydrateConversationControl) { - val contentFeaturesWithConversationControl = contentFeaturesWithAnnotations.copy( - conversationControl = hydratedTweet.tweet.conversationControl - ) - hydratedTweet.tweetId -> contentFeaturesWithConversationControl - } else { - hydratedTweet.tweetId -> contentFeaturesWithAnnotations - } - - }.toMap - } - } -} diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/BUILD b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/BUILD deleted file mode 100644 index 29e1c63cb..000000000 --- a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -scala_library( - sources = ["*.scala"], - compiler_option_sets = ["fatal_warnings"], - tags = ["bazel-compatible"], - dependencies = [ - "servo/repo/src/main/scala", - "src/thrift/com/twitter/wtf/candidate:wtf-candidate-scala", - "timelineranker/server/src/main/scala/com/twitter/timelineranker/core", - "timelines:visibility", - "timelines/src/main/scala/com/twitter/timelines/clients/socialgraph", - "timelines/src/main/scala/com/twitter/timelines/util", - "timelines/src/main/scala/com/twitter/timelines/util/stats", - "util/util-core:util-core-util", - "util/util-logging/src/main/scala", - "util/util-stats/src/main/scala", - ], -) diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/BUILD.docx b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/BUILD.docx new file mode 100644 index 000000000..ba379e07c Binary files /dev/null and b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/BUILD.docx differ diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/FollowGraphDataProvider.docx b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/FollowGraphDataProvider.docx new file mode 
100644 index 000000000..931a50e32 Binary files /dev/null and b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/FollowGraphDataProvider.docx differ diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/FollowGraphDataProvider.scala b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/FollowGraphDataProvider.scala deleted file mode 100644 index 4db2c4216..000000000 --- a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/FollowGraphDataProvider.scala +++ /dev/null @@ -1,25 +0,0 @@ -package com.twitter.timelineranker.visibility - -import com.twitter.timelineranker.core.FollowGraphData -import com.twitter.timelineranker.core.FollowGraphDataFuture -import com.twitter.timelines.model.UserId -import com.twitter.util.Future - -trait FollowGraphDataProvider { - - /** - * Gets follow graph data for the given user. - * - * @param userId user whose follow graph details are to be obtained. - * @param maxFollowingCount Maximum number of followed user IDs to fetch. - * If the given user follows more than these many users, - * then the most recent maxFollowingCount users are returned. - */ - def get(userId: UserId, maxFollowingCount: Int): Future[FollowGraphData] - - def getAsync(userId: UserId, maxFollowingCount: Int): FollowGraphDataFuture - - def getFollowing(userId: UserId, maxFollowingCount: Int): Future[Seq[UserId]] - - def getMutuallyFollowingUserIds(userId: UserId, followingIds: Seq[UserId]): Future[Set[UserId]] -} diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/RealGraphFollowGraphDataProvider.docx b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/RealGraphFollowGraphDataProvider.docx new file mode 100644 index 000000000..6e579cc29 Binary files /dev/null and b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/RealGraphFollowGraphDataProvider.docx differ diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/RealGraphFollowGraphDataProvider.scala b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/RealGraphFollowGraphDataProvider.scala deleted file mode 100644 index f4728c3a1..000000000 --- a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/RealGraphFollowGraphDataProvider.scala +++ /dev/null @@ -1,134 +0,0 @@ -package com.twitter.timelineranker.visibility - -import com.twitter.finagle.stats.Stat -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.servo.repository.KeyValueRepository -import com.twitter.servo.util.Gate -import com.twitter.timelineranker.core.FollowGraphData -import com.twitter.timelineranker.core.FollowGraphDataFuture -import com.twitter.timelines.clients.socialgraph.SocialGraphClient -import com.twitter.timelines.model.UserId -import com.twitter.timelines.util.FailOpenHandler -import com.twitter.util.Future -import com.twitter.util.Stopwatch -import com.twitter.wtf.candidate.thriftscala.CandidateSeq - -object RealGraphFollowGraphDataProvider { - val EmptyRealGraphResponse = CandidateSeq(Nil) -} - -/** - * Wraps an underlying FollowGraphDataProvider (which in practice will usually be a - * [[SgsFollowGraphDataProvider]]) and supplements the list of followings provided by the - * underlying provider with additional followings fetched from RealGraph if it looks like the - * underlying provider did not get the full list of the user's followings. 
- * - * First checks whether the size of the underlying following list is >= the max requested following - * count, which implies that there were additional followings beyond the max requested count. If so, - * fetches the full set of followings from RealGraph (go/realgraph), which will be at most 2000. - * - * Because the RealGraph dataset is not realtime and thus can potentially include stale followings, - * the provider confirms that the followings fetched from RealGraph are valid using SGS's - * getFollowOverlap method, and then merges the valid RealGraph followings with the underlying - * followings. - * - * Note that this supplementing is expected to be very rare as most users do not have more than - * the max followings we fetch from SGS. Also note that this class is mainly intended for use - * in the home timeline materialization path, with the goal of preventing a case where users - * who follow a very large number of accounts may not see Tweets from their earlier follows if we - * used SGS-based follow fetching alone. - */ -class RealGraphFollowGraphDataProvider( - underlying: FollowGraphDataProvider, - realGraphClient: KeyValueRepository[Seq[UserId], UserId, CandidateSeq], - socialGraphClient: SocialGraphClient, - supplementFollowsWithRealGraphGate: Gate[UserId], - statsReceiver: StatsReceiver) - extends FollowGraphDataProvider { - import RealGraphFollowGraphDataProvider._ - - private[this] val scopedStatsReceiver = statsReceiver.scope("realGraphFollowGraphDataProvider") - private[this] val requestCounter = scopedStatsReceiver.counter("requests") - private[this] val atMaxCounter = scopedStatsReceiver.counter("followsAtMax") - private[this] val totalLatencyStat = scopedStatsReceiver.stat("totalLatencyWhenSupplementing") - private[this] val supplementLatencyStat = scopedStatsReceiver.stat("supplementFollowsLatency") - private[this] val realGraphResponseSizeStat = scopedStatsReceiver.stat("realGraphFollows") - private[this] val realGraphEmptyCounter = scopedStatsReceiver.counter("realGraphEmpty") - private[this] val nonOverlappingSizeStat = scopedStatsReceiver.stat("nonOverlappingFollows") - - private[this] val failOpenHandler = new FailOpenHandler(scopedStatsReceiver) - - override def get(userId: UserId, maxFollowingCount: Int): Future[FollowGraphData] = { - getAsync(userId, maxFollowingCount).get() - } - - override def getAsync(userId: UserId, maxFollowingCount: Int): FollowGraphDataFuture = { - val startTime = Stopwatch.timeMillis() - val underlyingResult = underlying.getAsync(userId, maxFollowingCount) - if (supplementFollowsWithRealGraphGate(userId)) { - val supplementedFollows = underlyingResult.followedUserIdsFuture.flatMap { sgsFollows => - supplementFollowsWithRealGraph(userId, maxFollowingCount, sgsFollows, startTime) - } - underlyingResult.copy(followedUserIdsFuture = supplementedFollows) - } else { - underlyingResult - } - } - - override def getFollowing(userId: UserId, maxFollowingCount: Int): Future[Seq[UserId]] = { - val startTime = Stopwatch.timeMillis() - val underlyingFollows = underlying.getFollowing(userId, maxFollowingCount) - if (supplementFollowsWithRealGraphGate(userId)) { - underlying.getFollowing(userId, maxFollowingCount).flatMap { sgsFollows => - supplementFollowsWithRealGraph(userId, maxFollowingCount, sgsFollows, startTime) - } - } else { - underlyingFollows - } - } - - private[this] def supplementFollowsWithRealGraph( - userId: UserId, - maxFollowingCount: Int, - sgsFollows: Seq[Long], - startTime: Long - ): Future[Seq[UserId]] = { - 
requestCounter.incr() - if (sgsFollows.size >= maxFollowingCount) { - atMaxCounter.incr() - val supplementedFollowsFuture = realGraphClient(Seq(userId)) - .map(_.getOrElse(userId, EmptyRealGraphResponse)) - .map(_.candidates.map(_.userId)) - .flatMap { - case realGraphFollows if realGraphFollows.nonEmpty => - realGraphResponseSizeStat.add(realGraphFollows.size) - // Filter out "stale" follows from realgraph by checking them against SGS - val verifiedRealGraphFollows = - socialGraphClient.getFollowOverlap(userId, realGraphFollows) - verifiedRealGraphFollows.map { follows => - val combinedFollows = (sgsFollows ++ follows).distinct - val additionalFollows = combinedFollows.size - sgsFollows.size - if (additionalFollows > 0) nonOverlappingSizeStat.add(additionalFollows) - combinedFollows - } - case _ => - realGraphEmptyCounter.incr() - Future.value(sgsFollows) - } - .onSuccess { _ => totalLatencyStat.add(Stopwatch.timeMillis() - startTime) } - - Stat.timeFuture(supplementLatencyStat) { - failOpenHandler(supplementedFollowsFuture) { _ => Future.value(sgsFollows) } - } - } else { - Future.value(sgsFollows) - } - } - - override def getMutuallyFollowingUserIds( - userId: UserId, - followingIds: Seq[UserId] - ): Future[Set[UserId]] = { - underlying.getMutuallyFollowingUserIds(userId, followingIds) - } -} diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/SgsFollowGraphDataProvider.docx b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/SgsFollowGraphDataProvider.docx new file mode 100644 index 000000000..a834b7b6d Binary files /dev/null and b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/SgsFollowGraphDataProvider.docx differ diff --git a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/SgsFollowGraphDataProvider.scala b/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/SgsFollowGraphDataProvider.scala deleted file mode 100644 index 99497fc8b..000000000 --- a/timelineranker/server/src/main/scala/com/twitter/timelineranker/visibility/SgsFollowGraphDataProvider.scala +++ /dev/null @@ -1,266 +0,0 @@ -package com.twitter.timelineranker.visibility - -import com.twitter.finagle.stats.Stat -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.timelineranker.core.FollowGraphData -import com.twitter.timelineranker.core.FollowGraphDataFuture -import com.twitter.timelines.clients.socialgraph.ScopedSocialGraphClientFactory -import com.twitter.timelines.model._ -import com.twitter.timelines.util.FailOpenHandler -import com.twitter.timelines.util.stats._ -import com.twitter.timelines.visibility._ -import com.twitter.util.Future - -object SgsFollowGraphDataProvider { - val EmptyUserIdsSet: Set[UserId] = Set.empty[UserId] - val EmptyUserIdsSetFuture: Future[Set[UserId]] = Future.value(EmptyUserIdsSet) - val EmptyUserIdsSeq: Seq[UserId] = Seq.empty[UserId] - val EmptyUserIdsSeqFuture: Future[Seq[UserId]] = Future.value(EmptyUserIdsSeq) - val EmptyVisibilityProfiles: Map[UserId, VisibilityProfile] = Map.empty[UserId, VisibilityProfile] - val EmptyVisibilityProfilesFuture: Future[Map[UserId, VisibilityProfile]] = - Future.value(EmptyVisibilityProfiles) -} - -object SgsFollowGraphDataFields extends Enumeration { - val FollowedUserIds: Value = Value - val MutuallyFollowingUserIds: Value = Value - val MutedUserIds: Value = Value - val RetweetsMutedUserIds: Value = Value - - val None: ValueSet = SgsFollowGraphDataFields.ValueSet() - - def throwIfInvalid(fields: 
SgsFollowGraphDataFields.ValueSet): Unit = { - if (fields.contains(MutuallyFollowingUserIds) && !fields.contains(FollowedUserIds)) { - throw new IllegalArgumentException( - "MutuallyFollowingUserIds field requires FollowedUserIds field to be defined." - ) - } - } -} - -/** - * Provides information on the follow graph of a given user. - */ -class SgsFollowGraphDataProvider( - socialGraphClientFactory: ScopedSocialGraphClientFactory, - visibilityProfileHydratorFactory: VisibilityProfileHydratorFactory, - fieldsToFetch: SgsFollowGraphDataFields.ValueSet, - scope: RequestScope, - statsReceiver: StatsReceiver) - extends FollowGraphDataProvider - with RequestStats { - - SgsFollowGraphDataFields.throwIfInvalid(fieldsToFetch) - - private[this] val stats = scope.stats("followGraphDataProvider", statsReceiver) - private[this] val scopedStatsReceiver = stats.scopedStatsReceiver - - private[this] val followingScope = scopedStatsReceiver.scope("following") - private[this] val followingLatencyStat = followingScope.stat(LatencyMs) - private[this] val followingSizeStat = followingScope.stat(Size) - private[this] val followingTruncatedCounter = followingScope.counter("numTruncated") - - private[this] val mutuallyFollowingScope = scopedStatsReceiver.scope("mutuallyFollowing") - private[this] val mutuallyFollowingLatencyStat = mutuallyFollowingScope.stat(LatencyMs) - private[this] val mutuallyFollowingSizeStat = mutuallyFollowingScope.stat(Size) - - private[this] val visibilityScope = scopedStatsReceiver.scope("visibility") - private[this] val visibilityLatencyStat = visibilityScope.stat(LatencyMs) - private[this] val mutedStat = visibilityScope.stat("muted") - private[this] val retweetsMutedStat = visibilityScope.stat("retweetsMuted") - - private[this] val socialGraphClient = socialGraphClientFactory.scope(scope) - private[this] val visibilityProfileHydrator = - createVisibilityProfileHydrator(visibilityProfileHydratorFactory, scope, fieldsToFetch) - - private[this] val failOpenScope = scopedStatsReceiver.scope("failOpen") - private[this] val mutuallyFollowingHandler = - new FailOpenHandler(failOpenScope, "mutuallyFollowing") - - private[this] val obtainVisibilityProfiles = fieldsToFetch.contains( - SgsFollowGraphDataFields.MutedUserIds - ) || fieldsToFetch.contains(SgsFollowGraphDataFields.RetweetsMutedUserIds) - - /** - * Gets follow graph data for the given user. - * - * @param userId user whose follow graph details are to be obtained. - * @param maxFollowingCount Maximum number of followed user IDs to fetch. - * If the given user follows more than these many users, - * then the most recent maxFollowingCount users are returned. 
- */ - def get( - userId: UserId, - maxFollowingCount: Int - ): Future[FollowGraphData] = { - getAsync( - userId, - maxFollowingCount - ).get() - } - - def getAsync( - userId: UserId, - maxFollowingCount: Int - ): FollowGraphDataFuture = { - - stats.statRequest() - val followedUserIdsFuture = - if (fieldsToFetch.contains(SgsFollowGraphDataFields.FollowedUserIds)) { - getFollowing(userId, maxFollowingCount) - } else { - SgsFollowGraphDataProvider.EmptyUserIdsSeqFuture - } - - val mutuallyFollowingUserIdsFuture = - if (fieldsToFetch.contains(SgsFollowGraphDataFields.MutuallyFollowingUserIds)) { - followedUserIdsFuture.flatMap { followedUserIds => - getMutuallyFollowingUserIds(userId, followedUserIds) - } - } else { - SgsFollowGraphDataProvider.EmptyUserIdsSetFuture - } - - val visibilityProfilesFuture = if (obtainVisibilityProfiles) { - followedUserIdsFuture.flatMap { followedUserIds => - getVisibilityProfiles(userId, followedUserIds) - } - } else { - SgsFollowGraphDataProvider.EmptyVisibilityProfilesFuture - } - - val mutedUserIdsFuture = if (fieldsToFetch.contains(SgsFollowGraphDataFields.MutedUserIds)) { - getMutedUsers(visibilityProfilesFuture).map { mutedUserIds => - mutedStat.add(mutedUserIds.size) - mutedUserIds - } - } else { - SgsFollowGraphDataProvider.EmptyUserIdsSetFuture - } - - val retweetsMutedUserIdsFuture = - if (fieldsToFetch.contains(SgsFollowGraphDataFields.RetweetsMutedUserIds)) { - getRetweetsMutedUsers(visibilityProfilesFuture).map { retweetsMutedUserIds => - retweetsMutedStat.add(retweetsMutedUserIds.size) - retweetsMutedUserIds - } - } else { - SgsFollowGraphDataProvider.EmptyUserIdsSetFuture - } - - FollowGraphDataFuture( - userId, - followedUserIdsFuture, - mutuallyFollowingUserIdsFuture, - mutedUserIdsFuture, - retweetsMutedUserIdsFuture - ) - } - - private[this] def getVisibilityProfiles( - userId: UserId, - followingIds: Seq[UserId] - ): Future[Map[UserId, VisibilityProfile]] = { - Stat.timeFuture(visibilityLatencyStat) { - visibilityProfileHydrator(Some(userId), Future.value(followingIds.toSeq)) - } - } - - def getFollowing(userId: UserId, maxFollowingCount: Int): Future[Seq[UserId]] = { - Stat.timeFuture(followingLatencyStat) { - // We fetch 1 more than the limit so that we can decide if we ended up - // truncating the followings. - val followingIdsFuture = socialGraphClient.getFollowing(userId, Some(maxFollowingCount + 1)) - followingIdsFuture.map { followingIds => - followingSizeStat.add(followingIds.length) - if (followingIds.length > maxFollowingCount) { - followingTruncatedCounter.incr() - followingIds.take(maxFollowingCount) - } else { - followingIds - } - } - } - } - - def getMutuallyFollowingUserIds( - userId: UserId, - followingIds: Seq[UserId] - ): Future[Set[UserId]] = { - Stat.timeFuture(mutuallyFollowingLatencyStat) { - mutuallyFollowingHandler { - val mutuallyFollowingIdsFuture = - socialGraphClient.getFollowOverlap(followingIds.toSeq, userId) - mutuallyFollowingIdsFuture.map { mutuallyFollowingIds => - mutuallyFollowingSizeStat.add(mutuallyFollowingIds.size) - } - mutuallyFollowingIdsFuture - } { e: Throwable => SgsFollowGraphDataProvider.EmptyUserIdsSetFuture } - } - } - - private[this] def getRetweetsMutedUsers( - visibilityProfilesFuture: Future[Map[UserId, VisibilityProfile]] - ): Future[Set[UserId]] = { - // If the hydrator is not able to fetch retweets-muted status, we default to true. 
- getUsersMatchingVisibilityPredicate( - visibilityProfilesFuture, - (visibilityProfile: VisibilityProfile) => visibilityProfile.areRetweetsMuted.getOrElse(true) - ) - } - - private[this] def getMutedUsers( - visibilityProfilesFuture: Future[Map[UserId, VisibilityProfile]] - ): Future[Set[UserId]] = { - // If the hydrator is not able to fetch muted status, we default to true. - getUsersMatchingVisibilityPredicate( - visibilityProfilesFuture, - (visibilityProfile: VisibilityProfile) => visibilityProfile.isMuted.getOrElse(true) - ) - } - - private[this] def getUsersMatchingVisibilityPredicate( - visibilityProfilesFuture: Future[Map[UserId, VisibilityProfile]], - predicate: (VisibilityProfile => Boolean) - ): Future[Set[UserId]] = { - visibilityProfilesFuture.map { visibilityProfiles => - visibilityProfiles - .filter { - case (_, visibilityProfile) => - predicate(visibilityProfile) - } - .collect { case (userId, _) => userId } - .toSet - } - } - - private[this] def createVisibilityProfileHydrator( - factory: VisibilityProfileHydratorFactory, - scope: RequestScope, - fieldsToFetch: SgsFollowGraphDataFields.ValueSet - ): VisibilityProfileHydrator = { - val hydrationProfileRequest = HydrationProfileRequest( - getMuted = fieldsToFetch.contains(SgsFollowGraphDataFields.MutedUserIds), - getRetweetsMuted = fieldsToFetch.contains(SgsFollowGraphDataFields.RetweetsMutedUserIds) - ) - factory(hydrationProfileRequest, scope) - } -} - -class ScopedSgsFollowGraphDataProviderFactory( - socialGraphClientFactory: ScopedSocialGraphClientFactory, - visibilityProfileHydratorFactory: VisibilityProfileHydratorFactory, - fieldsToFetch: SgsFollowGraphDataFields.ValueSet, - statsReceiver: StatsReceiver) - extends ScopedFactory[SgsFollowGraphDataProvider] { - - override def scope(scope: RequestScope): SgsFollowGraphDataProvider = { - new SgsFollowGraphDataProvider( - socialGraphClientFactory, - visibilityProfileHydratorFactory, - fieldsToFetch, - scope, - statsReceiver - ) - } -} diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/BUILD b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/BUILD deleted file mode 100644 index 5ce990999..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/BUILD +++ /dev/null @@ -1,8 +0,0 @@ -target( - name = "earlybird_ranking", - dependencies = [ - "timelines/data_processing/ad_hoc/earlybird_ranking/common", - "timelines/data_processing/ad_hoc/earlybird_ranking/model_evaluation", - "timelines/data_processing/ad_hoc/earlybird_ranking/training_data_generation", - ], -) diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/BUILD.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/BUILD.docx new file mode 100644 index 000000000..f1e19d7d6 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/BUILD.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/BUILD b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/BUILD deleted file mode 100644 index f7dcdcfd5..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -scala_library( - name = "common", - sources = ["*.scala"], - platform = "java8", - tags = [ - "bazel-compatible", - "bazel-compatible:migrated", - ], - dependencies = [ - "src/java/com/twitter/ml/api:api-base", - "src/java/com/twitter/ml/api/constant", - 
"src/java/com/twitter/ml/api/transform", - "src/java/com/twitter/search/modeling/tweet_ranking", - "src/scala/com/twitter/ml/api/util", - "src/scala/com/twitter/timelines/prediction/features/common", - "src/scala/com/twitter/timelines/prediction/features/itl", - "src/scala/com/twitter/timelines/prediction/features/real_graph", - "src/scala/com/twitter/timelines/prediction/features/recap", - "src/scala/com/twitter/timelines/prediction/features/request_context", - "src/scala/com/twitter/timelines/prediction/features/time_features", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:transform-java", - ], -) diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/BUILD.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/BUILD.docx new file mode 100644 index 000000000..14c704929 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/BUILD.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingConfiguration.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingConfiguration.docx new file mode 100644 index 000000000..7b21fab27 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingConfiguration.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingConfiguration.scala b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingConfiguration.scala deleted file mode 100644 index 201864038..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingConfiguration.scala +++ /dev/null @@ -1,271 +0,0 @@ -package com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common - -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import com.twitter.ml.api.FeatureContext -import com.twitter.ml.api.ITransform -import com.twitter.ml.api.transform.CascadeTransform -import com.twitter.ml.api.transform.TransformFactory -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.search.common.features.SearchResultFeature -import com.twitter.search.common.features.ExternalTweetFeature -import com.twitter.search.common.features.TweetFeature -import com.twitter.timelines.prediction.features.recap.RecapFeatures -import com.twitter.timelines.prediction.features.request_context.RequestContextFeatures -import com.twitter.timelines.prediction.features.time_features.TimeDataRecordFeatures -import com.twitter.timelines.prediction.features.common.TimelinesSharedFeatures -import com.twitter.timelines.prediction.features.real_graph.RealGraphDataRecordFeatures -import scala.collection.JavaConverters._ -import java.lang.{Boolean => JBoolean} - -case class LabelInfo(name: String, downsampleFraction: Double, importance: Double) - -case class LabelInfoWithFeature(info: LabelInfo, feature: Feature[JBoolean]) - -trait EarlybirdTrainingConfiguration { - - protected def labels: Map[String, Feature.Binary] - - protected def weights: Map[String, Double] = Map( - "detail_expanded" -> 0.3, - "favorited" -> 1.0, - "open_linked" -> 0.1, - "photo_expanded" -> 0.03, - "profile_clicked" -> 1.0, - "replied" -> 9.0, - "retweeted" -> 1.0, - "video_playback50" -> 0.01 - ) - - // we basically should not 
downsample any of the precious positive data. - // importances are currently set to match the full model's weights. - protected def PositiveSamplingRate: Double = 1.0 - private def NegativeSamplingRate: Double = PositiveSamplingRate * 0.08 - - // we basically should not downsample any of the precious positive data. - // importances are currently set to match the full model's weights. - final lazy val LabelInfos: List[LabelInfoWithFeature] = { - assert(labels.keySet == weights.keySet) - labels.keySet.map(makeLabelInfoWithFeature).toList - } - - def makeLabelInfoWithFeature(labelName: String): LabelInfoWithFeature = { - LabelInfoWithFeature( - LabelInfo(labelName, PositiveSamplingRate, weights(labelName)), - labels(labelName)) - } - - final lazy val NegativeInfo: LabelInfo = LabelInfo("negative", NegativeSamplingRate, 1.0) - - // example of features available in schema based namespace: - protected def featureToSearchResultFeatureMap: Map[Feature[_], SearchResultFeature] = Map( - RecapFeatures.TEXT_SCORE -> TweetFeature.TEXT_SCORE, - RecapFeatures.REPLY_COUNT -> TweetFeature.REPLY_COUNT, - RecapFeatures.RETWEET_COUNT -> TweetFeature.RETWEET_COUNT, - RecapFeatures.FAV_COUNT -> TweetFeature.FAVORITE_COUNT, - RecapFeatures.HAS_CARD -> TweetFeature.HAS_CARD_FLAG, - RecapFeatures.HAS_CONSUMER_VIDEO -> TweetFeature.HAS_CONSUMER_VIDEO_FLAG, - RecapFeatures.HAS_PRO_VIDEO -> TweetFeature.HAS_PRO_VIDEO_FLAG, - // no corresponding HAS_NATIVE_VIDEO feature in TweetFeature - RecapFeatures.HAS_VINE -> TweetFeature.HAS_VINE_FLAG, - RecapFeatures.HAS_PERISCOPE -> TweetFeature.HAS_PERISCOPE_FLAG, - RecapFeatures.HAS_NATIVE_IMAGE -> TweetFeature.HAS_NATIVE_IMAGE_FLAG, - RecapFeatures.HAS_IMAGE -> TweetFeature.HAS_IMAGE_URL_FLAG, - RecapFeatures.HAS_NEWS -> TweetFeature.HAS_NEWS_URL_FLAG, - RecapFeatures.HAS_VIDEO -> TweetFeature.HAS_VIDEO_URL_FLAG, - RecapFeatures.HAS_TREND -> TweetFeature.HAS_TREND_FLAG, - RecapFeatures.HAS_MULTIPLE_HASHTAGS_OR_TRENDS -> TweetFeature.HAS_MULTIPLE_HASHTAGS_OR_TRENDS_FLAG, - RecapFeatures.IS_OFFENSIVE -> TweetFeature.IS_OFFENSIVE_FLAG, - RecapFeatures.IS_REPLY -> TweetFeature.IS_REPLY_FLAG, - RecapFeatures.IS_RETWEET -> TweetFeature.IS_RETWEET_FLAG, - RecapFeatures.IS_AUTHOR_BOT -> TweetFeature.IS_USER_BOT_FLAG, - RecapFeatures.FROM_VERIFIED_ACCOUNT -> TweetFeature.FROM_VERIFIED_ACCOUNT_FLAG, - RecapFeatures.USER_REP -> TweetFeature.USER_REPUTATION, - RecapFeatures.EMBEDS_IMPRESSION_COUNT -> TweetFeature.EMBEDS_IMPRESSION_COUNT, - RecapFeatures.EMBEDS_URL_COUNT -> TweetFeature.EMBEDS_URL_COUNT, - // RecapFeatures.VIDEO_VIEW_COUNT deprecated - RecapFeatures.FAV_COUNT_V2 -> TweetFeature.FAVORITE_COUNT_V2, - RecapFeatures.RETWEET_COUNT_V2 -> TweetFeature.RETWEET_COUNT_V2, - RecapFeatures.REPLY_COUNT_V2 -> TweetFeature.REPLY_COUNT_V2, - RecapFeatures.IS_SENSITIVE -> TweetFeature.IS_SENSITIVE_CONTENT, - RecapFeatures.HAS_MULTIPLE_MEDIA -> TweetFeature.HAS_MULTIPLE_MEDIA_FLAG, - RecapFeatures.IS_AUTHOR_PROFILE_EGG -> TweetFeature.PROFILE_IS_EGG_FLAG, - RecapFeatures.IS_AUTHOR_NEW -> TweetFeature.IS_USER_NEW_FLAG, - RecapFeatures.NUM_MENTIONS -> TweetFeature.NUM_MENTIONS, - RecapFeatures.NUM_HASHTAGS -> TweetFeature.NUM_HASHTAGS, - RecapFeatures.HAS_VISIBLE_LINK -> TweetFeature.HAS_VISIBLE_LINK_FLAG, - RecapFeatures.HAS_LINK -> TweetFeature.HAS_LINK_FLAG, - //note: DISCRETE features are not supported by the modelInterpreter tool. 
- // for the following features, we will create separate CONTINUOUS features instead of renaming - //RecapFeatures.LINK_LANGUAGE - //RecapFeatures.LANGUAGE - TimelinesSharedFeatures.HAS_QUOTE -> TweetFeature.HAS_QUOTE_FLAG, - TimelinesSharedFeatures.QUOTE_COUNT -> TweetFeature.QUOTE_COUNT, - TimelinesSharedFeatures.WEIGHTED_FAV_COUNT -> TweetFeature.WEIGHTED_FAVORITE_COUNT, - TimelinesSharedFeatures.WEIGHTED_QUOTE_COUNT -> TweetFeature.WEIGHTED_QUOTE_COUNT, - TimelinesSharedFeatures.WEIGHTED_REPLY_COUNT -> TweetFeature.WEIGHTED_REPLY_COUNT, - TimelinesSharedFeatures.WEIGHTED_RETWEET_COUNT -> TweetFeature.WEIGHTED_RETWEET_COUNT, - TimelinesSharedFeatures.DECAYED_FAVORITE_COUNT -> TweetFeature.DECAYED_FAVORITE_COUNT, - TimelinesSharedFeatures.DECAYED_RETWEET_COUNT -> TweetFeature.DECAYED_RETWEET_COUNT, - TimelinesSharedFeatures.DECAYED_REPLY_COUNT -> TweetFeature.DECAYED_RETWEET_COUNT, - TimelinesSharedFeatures.DECAYED_QUOTE_COUNT -> TweetFeature.DECAYED_QUOTE_COUNT, - TimelinesSharedFeatures.FAKE_FAVORITE_COUNT -> TweetFeature.FAKE_FAVORITE_COUNT, - TimelinesSharedFeatures.FAKE_RETWEET_COUNT -> TweetFeature.FAKE_RETWEET_COUNT, - TimelinesSharedFeatures.FAKE_REPLY_COUNT -> TweetFeature.FAKE_REPLY_COUNT, - TimelinesSharedFeatures.FAKE_QUOTE_COUNT -> TweetFeature.FAKE_QUOTE_COUNT, - TimelinesSharedFeatures.EMBEDS_IMPRESSION_COUNT_V2 -> TweetFeature.EMBEDS_IMPRESSION_COUNT_V2, - TimelinesSharedFeatures.EMBEDS_URL_COUNT_V2 -> TweetFeature.EMBEDS_URL_COUNT_V2, - TimelinesSharedFeatures.LABEL_ABUSIVE_FLAG -> TweetFeature.LABEL_ABUSIVE_FLAG, - TimelinesSharedFeatures.LABEL_ABUSIVE_HI_RCL_FLAG -> TweetFeature.LABEL_ABUSIVE_HI_RCL_FLAG, - TimelinesSharedFeatures.LABEL_DUP_CONTENT_FLAG -> TweetFeature.LABEL_DUP_CONTENT_FLAG, - TimelinesSharedFeatures.LABEL_NSFW_HI_PRC_FLAG -> TweetFeature.LABEL_NSFW_HI_PRC_FLAG, - TimelinesSharedFeatures.LABEL_NSFW_HI_RCL_FLAG -> TweetFeature.LABEL_NSFW_HI_RCL_FLAG, - TimelinesSharedFeatures.LABEL_SPAM_FLAG -> TweetFeature.LABEL_SPAM_FLAG, - TimelinesSharedFeatures.LABEL_SPAM_HI_RCL_FLAG -> TweetFeature.LABEL_SPAM_HI_RCL_FLAG - ) - - protected def derivedFeaturesAdder: ITransform = - new ITransform { - private val hasEnglishTweetDiffUiLangFeature = - featureInstanceFromSearchResultFeature(ExternalTweetFeature.HAS_ENGLISH_TWEET_DIFF_UI_LANG) - .asInstanceOf[Feature.Binary] - private val hasEnglishUiDiffTweetLangFeature = - featureInstanceFromSearchResultFeature(ExternalTweetFeature.HAS_ENGLISH_UI_DIFF_TWEET_LANG) - .asInstanceOf[Feature.Binary] - private val hasDiffLangFeature = - featureInstanceFromSearchResultFeature(ExternalTweetFeature.HAS_DIFF_LANG) - .asInstanceOf[Feature.Binary] - private val isSelfTweetFeature = - featureInstanceFromSearchResultFeature(ExternalTweetFeature.IS_SELF_TWEET) - .asInstanceOf[Feature.Binary] - private val tweetAgeInSecsFeature = - featureInstanceFromSearchResultFeature(ExternalTweetFeature.TWEET_AGE_IN_SECS) - .asInstanceOf[Feature.Continuous] - private val authorSpecificScoreFeature = - featureInstanceFromSearchResultFeature(ExternalTweetFeature.AUTHOR_SPECIFIC_SCORE) - .asInstanceOf[Feature.Continuous] - - // see comments above - private val linkLanguageFeature = new Feature.Continuous(TweetFeature.LINK_LANGUAGE.getName) - private val languageFeature = new Feature.Continuous(TweetFeature.LANGUAGE.getName) - - override def transformContext(featureContext: FeatureContext): FeatureContext = - featureContext.addFeatures( - authorSpecificScoreFeature, - // used when training against the full score; see EarlybirdModelEvaluationJob.scala - // 
TimelinesSharedFeatures.PREDICTED_SCORE_LOG, - hasEnglishTweetDiffUiLangFeature, - hasEnglishUiDiffTweetLangFeature, - hasDiffLangFeature, - isSelfTweetFeature, - tweetAgeInSecsFeature, - linkLanguageFeature, - languageFeature - ) - - override def transform(record: DataRecord): Unit = { - val srecord = SRichDataRecord(record) - - srecord.getFeatureValueOpt(RealGraphDataRecordFeatures.WEIGHT).map { realgraphWeight => - srecord.setFeatureValue(authorSpecificScoreFeature, realgraphWeight) - } - - // use this when training against the log of the full score - // srecord.getFeatureValueOpt(TimelinesSharedFeatures.PREDICTED_SCORE).map { score => - // if (score > 0.0) { - // srecord.setFeatureValue(TimelinesSharedFeatures.PREDICTED_SCORE_LOG, Math.log(score)) - // } - // } - - if (srecord.hasFeature(RequestContextFeatures.LANGUAGE_CODE) && srecord.hasFeature( - RecapFeatures.LANGUAGE)) { - val uilangIsEnglish = srecord - .getFeatureValue(RequestContextFeatures.LANGUAGE_CODE).toString == "en" - val tweetIsEnglish = srecord.getFeatureValue(RecapFeatures.LANGUAGE) == 5 - srecord.setFeatureValue( - hasEnglishTweetDiffUiLangFeature, - tweetIsEnglish && !uilangIsEnglish - ) - srecord.setFeatureValue( - hasEnglishUiDiffTweetLangFeature, - uilangIsEnglish && !tweetIsEnglish - ) - } - srecord.getFeatureValueOpt(RecapFeatures.MATCH_UI_LANG).map { match_ui_lang => - srecord.setFeatureValue( - hasDiffLangFeature, - !match_ui_lang - ) - } - - for { - author_id <- srecord.getFeatureValueOpt(SharedFeatures.AUTHOR_ID) - user_id <- srecord.getFeatureValueOpt(SharedFeatures.USER_ID) - } srecord.setFeatureValue( - isSelfTweetFeature, - author_id == user_id - ) - - srecord.getFeatureValueOpt(TimeDataRecordFeatures.TIME_SINCE_TWEET_CREATION).map { - time_since_tweet_creation => - srecord.setFeatureValue( - tweetAgeInSecsFeature, - time_since_tweet_creation / 1000.0 - ) - } - - srecord.getFeatureValueOpt(RecapFeatures.LINK_LANGUAGE).map { link_language => - srecord.setFeatureValue( - linkLanguageFeature, - link_language.toDouble - ) - } - srecord.getFeatureValueOpt(RecapFeatures.LANGUAGE).map { language => - srecord.setFeatureValue( - languageFeature, - language.toDouble - ) - } - } - } - - protected def featureInstanceFromSearchResultFeature( - tweetFeature: SearchResultFeature - ): Feature[_] = { - val featureType = tweetFeature.getType - val featureName = tweetFeature.getName - - require( - !tweetFeature.isDiscrete && ( - featureType == com.twitter.search.common.features.thrift.ThriftSearchFeatureType.BOOLEAN_VALUE || - featureType == com.twitter.search.common.features.thrift.ThriftSearchFeatureType.DOUBLE_VALUE || - featureType == com.twitter.search.common.features.thrift.ThriftSearchFeatureType.INT32_VALUE - ) - ) - - if (featureType == com.twitter.search.common.features.thrift.ThriftSearchFeatureType.BOOLEAN_VALUE) - new Feature.Binary(featureName) - else - new Feature.Continuous(featureName) - } - - lazy val EarlybirdFeatureRenamer: ITransform = { - val earlybirdFeatureRenameMap: Map[Feature[_], Feature[_]] = - featureToSearchResultFeatureMap.map { - case (originalFeature, tweetFeature) => - originalFeature -> featureInstanceFromSearchResultFeature(tweetFeature) - }.toMap - - new CascadeTransform( - List( - derivedFeaturesAdder, - TransformFactory.produceTransform( - TransformFactory.produceFeatureRenameTransformSpec( - earlybirdFeatureRenameMap.asJava - ) - ) - ).asJava - ) - } -} diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRecapConfiguration.docx 
b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRecapConfiguration.docx new file mode 100644 index 000000000..62347cd51 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRecapConfiguration.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRecapConfiguration.scala b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRecapConfiguration.scala deleted file mode 100644 index 59e046235..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRecapConfiguration.scala +++ /dev/null @@ -1,17 +0,0 @@ -package com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common - -import com.twitter.ml.api.Feature -import com.twitter.timelines.prediction.features.recap.RecapFeatures - -class EarlybirdTrainingRecapConfiguration extends EarlybirdTrainingConfiguration { - override val labels: Map[String, Feature.Binary] = Map( - "detail_expanded" -> RecapFeatures.IS_CLICKED, - "favorited" -> RecapFeatures.IS_FAVORITED, - "open_linked" -> RecapFeatures.IS_OPEN_LINKED, - "photo_expanded" -> RecapFeatures.IS_PHOTO_EXPANDED, - "profile_clicked" -> RecapFeatures.IS_PROFILE_CLICKED, - "replied" -> RecapFeatures.IS_REPLIED, - "retweeted" -> RecapFeatures.IS_RETWEETED, - "video_playback50" -> RecapFeatures.IS_VIDEO_PLAYBACK_50 - ) -} diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRectweetConfiguration.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRectweetConfiguration.docx new file mode 100644 index 000000000..1f2b4639b Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRectweetConfiguration.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRectweetConfiguration.scala b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRectweetConfiguration.scala deleted file mode 100644 index fbe61ef4d..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/common/EarlybirdTrainingRectweetConfiguration.scala +++ /dev/null @@ -1,100 +0,0 @@ -package com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common - -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import com.twitter.ml.api.FeatureContext -import com.twitter.ml.api.ITransform -import com.twitter.ml.api.transform.CascadeTransform -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.search.common.features.SearchResultFeature -import com.twitter.search.common.features.TweetFeature -import com.twitter.timelines.prediction.features.itl.ITLFeatures._ -import scala.collection.JavaConverters._ - -class EarlybirdTrainingRectweetConfiguration extends EarlybirdTrainingConfiguration { - - override val labels: Map[String, Feature.Binary] = Map( - "detail_expanded" -> IS_CLICKED, - "favorited" -> IS_FAVORITED, - "open_linked" -> IS_OPEN_LINKED, - "photo_expanded" -> IS_PHOTO_EXPANDED, - "profile_clicked" -> IS_PROFILE_CLICKED, - "replied" -> IS_REPLIED, - "retweeted" -> IS_RETWEETED, - "video_playback50" -> IS_VIDEO_PLAYBACK_50 - ) - - override val PositiveSamplingRate: Double = 0.5 - - override def featureToSearchResultFeatureMap: 
Map[Feature[_], SearchResultFeature] = - super.featureToSearchResultFeatureMap ++ Map( - TEXT_SCORE -> TweetFeature.TEXT_SCORE, - REPLY_COUNT -> TweetFeature.REPLY_COUNT, - RETWEET_COUNT -> TweetFeature.RETWEET_COUNT, - FAV_COUNT -> TweetFeature.FAVORITE_COUNT, - HAS_CARD -> TweetFeature.HAS_CARD_FLAG, - HAS_CONSUMER_VIDEO -> TweetFeature.HAS_CONSUMER_VIDEO_FLAG, - HAS_PRO_VIDEO -> TweetFeature.HAS_PRO_VIDEO_FLAG, - HAS_VINE -> TweetFeature.HAS_VINE_FLAG, - HAS_PERISCOPE -> TweetFeature.HAS_PERISCOPE_FLAG, - HAS_NATIVE_IMAGE -> TweetFeature.HAS_NATIVE_IMAGE_FLAG, - HAS_IMAGE -> TweetFeature.HAS_IMAGE_URL_FLAG, - HAS_NEWS -> TweetFeature.HAS_NEWS_URL_FLAG, - HAS_VIDEO -> TweetFeature.HAS_VIDEO_URL_FLAG, - // some features that exist for recap are not available in rectweet - // HAS_TREND - // HAS_MULTIPLE_HASHTAGS_OR_TRENDS - // IS_OFFENSIVE - // IS_REPLY - // IS_RETWEET - IS_AUTHOR_BOT -> TweetFeature.IS_USER_BOT_FLAG, - IS_AUTHOR_SPAM -> TweetFeature.IS_USER_SPAM_FLAG, - IS_AUTHOR_NSFW -> TweetFeature.IS_USER_NSFW_FLAG, - // FROM_VERIFIED_ACCOUNT - USER_REP -> TweetFeature.USER_REPUTATION, - // EMBEDS_IMPRESSION_COUNT - // EMBEDS_URL_COUNT - // VIDEO_VIEW_COUNT - FAV_COUNT_V2 -> TweetFeature.FAVORITE_COUNT_V2, - RETWEET_COUNT_V2 -> TweetFeature.RETWEET_COUNT_V2, - REPLY_COUNT_V2 -> TweetFeature.REPLY_COUNT_V2, - IS_SENSITIVE -> TweetFeature.IS_SENSITIVE_CONTENT, - HAS_MULTIPLE_MEDIA -> TweetFeature.HAS_MULTIPLE_MEDIA_FLAG, - IS_AUTHOR_PROFILE_EGG -> TweetFeature.PROFILE_IS_EGG_FLAG, - IS_AUTHOR_NEW -> TweetFeature.IS_USER_NEW_FLAG, - NUM_MENTIONS -> TweetFeature.NUM_MENTIONS, - NUM_HASHTAGS -> TweetFeature.NUM_HASHTAGS, - HAS_VISIBLE_LINK -> TweetFeature.HAS_VISIBLE_LINK_FLAG, - HAS_LINK -> TweetFeature.HAS_LINK_FLAG - ) - - override def derivedFeaturesAdder: CascadeTransform = { - // only LINK_LANGUAGE available in rectweet. 
no LANGUAGE feature - val linkLanguageTransform = new ITransform { - private val linkLanguageFeature = new Feature.Continuous(TweetFeature.LINK_LANGUAGE.getName) - - override def transformContext(featureContext: FeatureContext): FeatureContext = - featureContext.addFeatures( - linkLanguageFeature - ) - - override def transform(record: DataRecord): Unit = { - val srecord = SRichDataRecord(record) - - srecord.getFeatureValueOpt(LINK_LANGUAGE).map { link_language => - srecord.setFeatureValue( - linkLanguageFeature, - link_language.toDouble - ) - } - } - } - - new CascadeTransform( - List( - super.derivedFeaturesAdder, - linkLanguageTransform - ).asJava - ) - } -} diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/BUILD b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/BUILD deleted file mode 100644 index 5fb8ef7f6..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -scala_library( - name = "model_evaluation", - sources = ["*.scala"], - platform = "java8", - strict_deps = False, - dependencies = [ - "3rdparty/src/jvm/com/twitter/scalding:json", - "src/scala/com/twitter/ml/api:api-base", - "src/scala/com/twitter/ml/api/prediction_engine", - "src/scala/com/twitter/ml/api/util", - "src/scala/com/twitter/scalding_internal/job", - "src/scala/com/twitter/timelines/prediction/adapters/recap", - "src/scala/com/twitter/timelines/prediction/features/recap", - "timelines/data_processing/ad_hoc/earlybird_ranking/common", - "timelines/data_processing/util:rich-request", - "timelines/data_processing/util/example", - "timelines/data_processing/util/execution", - "twadoop_config/configuration/log_categories/group/timelines:timelineservice_injection_request_log-scala", - ], -) - -hadoop_binary( - name = "bin", - basename = "earlybird_model_evaluation-deploy", - main = "com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.model_evaluation.EarlybirdModelEvaluationJob", - platform = "java8", - runtime_platform = "java8", - tags = [ - "bazel-compatible", - "bazel-compatible:migrated", - "bazel-only", - ], - dependencies = [ - ":model_evaluation", - ], -) diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/BUILD.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/BUILD.docx new file mode 100644 index 000000000..9eafd60fd Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/BUILD.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdEvaluationMetric.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdEvaluationMetric.docx new file mode 100644 index 000000000..d46ed5757 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdEvaluationMetric.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdEvaluationMetric.scala b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdEvaluationMetric.scala deleted file mode 100644 index 3885fdc61..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdEvaluationMetric.scala +++ /dev/null @@ -1,203 +0,0 @@ 
-package com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.model_evaluation - -import scala.collection.GenTraversableOnce - -case class CandidateRecord(tweetId: Long, fullScore: Double, earlyScore: Double, served: Boolean) - -/** - * A metric that compares scores generated by a "full" prediction - * model to a "light" (Earlybird) model. The metric is calculated for candidates - * from a single request. - */ -sealed trait EarlybirdEvaluationMetric { - def name: String - def apply(candidates: Seq[CandidateRecord]): Option[Double] -} - -/** - * Picks the set of `k` top candidates using light scores, and calculates - * recall of these light-score based candidates among set of `k` top candidates - * using full scores. - * - * If there are fewer than `k` candidates, then we can choose to filter out requests (will - * lower value of recall) or keep them by trivially computing recall as 1.0. - */ -case class TopKRecall(k: Int, filterFewerThanK: Boolean) extends EarlybirdEvaluationMetric { - override val name: String = s"top_${k}_recall${if (filterFewerThanK) "_filtered" else ""}" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = { - if (candidates.size <= k) { - if (filterFewerThanK) None else Some(1.0) - } else { - val topFull = candidates.sortBy(-_.fullScore).take(k) - val topLight = candidates.sortBy(-_.earlyScore).take(k) - val overlap = topFull.map(_.tweetId).intersect(topLight.map(_.tweetId)) - val truePos = overlap.size.toDouble - Some(truePos / k.toDouble) - } - } -} - -/** - * Calculates the probability that a random pair of candidates will be ordered the same by the - * full and earlybird models. - * - * Note: A pair with same scores for one model and different for the other will contribute 1 - * to the sum. Pairs that are strictly ordered the same, will contribute 2. - * It follows that the score for a constant model is 0.5, which is approximately equal to a - * random model as expected. - */ -case object ProbabilityOfCorrectOrdering extends EarlybirdEvaluationMetric { - - def fractionOf[A](trav: GenTraversableOnce[A])(p: A => Boolean): Double = { - if (trav.isEmpty) - 0.0 - else { - val (numPos, numElements) = trav.foldLeft((0, 0)) { - case ((numPosAcc, numElementsAcc), elem) => - (if (p(elem)) numPosAcc + 1 else numPosAcc, numElementsAcc + 1) - } - numPos.toDouble / numElements - } - } - - override def name: String = "probability_of_correct_ordering" - - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = { - if (candidates.size < 2) - None - else { - val pairs = for { - left <- candidates.iterator - right <- candidates.iterator - if left != right - } yield (left, right) - - val probabilityOfCorrect = fractionOf(pairs) { - case (left, right) => - (left.fullScore > right.fullScore) == (left.earlyScore > right.earlyScore) - } - - Some(probabilityOfCorrect) - } - } -} - -/** - * Like `TopKRecall`, but uses `n` % of top candidates instead. 
- */ -case class TopNPercentRecall(percent: Double) extends EarlybirdEvaluationMetric { - override val name: String = s"top_${percent}_pct_recall" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = { - val k = Math.floor(candidates.size * percent).toInt - if (k > 0) { - val topFull = candidates.sortBy(-_.fullScore).take(k) - val topLight = candidates.sortBy(-_.earlyScore).take(k) - val overlap = topFull.map(_.tweetId).intersect(topLight.map(_.tweetId)) - val truePos = overlap.size.toDouble - Some(truePos / k.toDouble) - } else { - None - } - } -} - -/** - * Picks the set of `k` top candidates using light scores, and calculates - * recall of selected light-score based candidates among set of actual - * shown candidates. - */ -case class ShownTweetRecall(k: Int) extends EarlybirdEvaluationMetric { - override val name: String = s"shown_tweet_recall_$k" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = { - if (candidates.size <= k) { - None - } else { - val topLight = candidates.sortBy(-_.earlyScore).take(k) - val truePos = topLight.count(_.served).toDouble - val allPos = candidates.count(_.served).toDouble - if (allPos > 0) Some(truePos / allPos) - else None - } - } -} - -/** - * Like `ShownTweetRecall`, but uses `n` % of top candidates instead. - */ -case class ShownTweetPercentRecall(percent: Double) extends EarlybirdEvaluationMetric { - override val name: String = s"shown_tweet_recall_${percent}_pct" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = { - val k = Math.floor(candidates.size * percent).toInt - val topLight = candidates.sortBy(-_.earlyScore).take(k) - val truePos = topLight.count(_.served).toDouble - val allPos = candidates.count(_.served).toDouble - if (allPos > 0) Some(truePos / allPos) - else None - } -} - -/** - * Like `ShownTweetRecall`, but calculated using *full* scores. This is a sanity metric, - * because by definition the top full-scored candidates will be served. If the value is - * < 1, this is due to the ranked section being smaller than k. - */ -case class ShownTweetRecallWithFullScores(k: Int) extends EarlybirdEvaluationMetric { - override val name: String = s"shown_tweet_recall_with_full_scores_$k" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = { - if (candidates.size <= k) { - None - } else { - val topFull = candidates.sortBy(-_.fullScore).take(k) - val truePos = topFull.count(_.served).toDouble - val allPos = candidates.count(_.served).toDouble - if (allPos > 0) Some(truePos / allPos) - else None - } - } -} - -/** - * Picks the set of `k` top candidates using the light scores, and calculates - * average full score for the candidates. - */ -case class AverageFullScoreForTopLight(k: Int) extends EarlybirdEvaluationMetric { - override val name: String = s"average_full_score_for_top_light_$k" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = { - if (candidates.size <= k) { - None - } else { - val topLight = candidates.sortBy(-_.earlyScore).take(k) - Some(topLight.map(_.fullScore).sum / topLight.size) - } - } -} - -/** - * Picks the set of `k` top candidates using the light scores, and calculates - * sum of full scores for those. Divides that by sum of `k` top full scores, - * overall, to get a "score recall". 
- */ -case class SumScoreRecallForTopLight(k: Int) extends EarlybirdEvaluationMetric { - override val name: String = s"sum_score_recall_for_top_light_$k" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = { - if (candidates.size <= k) { - None - } else { - val sumFullScoresForTopLight = candidates.sortBy(-_.earlyScore).take(k).map(_.fullScore).sum - val sumScoresForTopFull = candidates.sortBy(-_.fullScore).take(k).map(_.fullScore).sum - Some(sumFullScoresForTopLight / sumScoresForTopFull) - } - } -} - -case class HasFewerThanKCandidates(k: Int) extends EarlybirdEvaluationMetric { - override val name: String = s"has_fewer_than_${k}_candidates" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = - Some(if (candidates.size <= k) 1.0 else 0.0) -} - -case object NumberOfCandidates extends EarlybirdEvaluationMetric { - override val name: String = s"number_of_candidates" - override def apply(candidates: Seq[CandidateRecord]): Option[Double] = - Some(candidates.size.toDouble) -} diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdModelEvaluationJob.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdModelEvaluationJob.docx new file mode 100644 index 000000000..dfca26f08 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdModelEvaluationJob.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdModelEvaluationJob.scala b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdModelEvaluationJob.scala deleted file mode 100644 index 2203146b8..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/model_evaluation/EarlybirdModelEvaluationJob.scala +++ /dev/null @@ -1,214 +0,0 @@ -package com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.model_evaluation - -import com.twitter.algebird.Aggregator -import com.twitter.algebird.AveragedValue -import com.twitter.ml.api.prediction_engine.PredictionEnginePlugin -import com.twitter.ml.api.util.FDsl -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.IRecordOneToManyAdapter -import com.twitter.scalding.Args -import com.twitter.scalding.DateRange -import com.twitter.scalding.Execution -import com.twitter.scalding.TypedJson -import com.twitter.scalding.TypedPipe -import com.twitter.scalding_internal.dalv2.DAL -import com.twitter.scalding_internal.job.TwitterExecutionApp -import com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common.EarlybirdTrainingRecapConfiguration -import com.twitter.timelines.data_processing.util.RequestImplicits.RichRequest -import com.twitter.timelines.data_processing.util.example.RecapTweetExample -import com.twitter.timelines.data_processing.util.execution.UTCDateRangeFromArgs -import com.twitter.timelines.prediction.adapters.recap.RecapSuggestionRecordAdapter -import com.twitter.timelines.prediction.features.recap.RecapFeatures -import com.twitter.timelines.suggests.common.record.thriftscala.SuggestionRecord -import com.twitter.timelineservice.suggests.logging.recap.thriftscala.HighlightTweet -import com.twitter.timelineservice.suggests.logging.thriftscala.SuggestsRequestLog -import scala.collection.JavaConverters._ -import scala.language.reflectiveCalls -import scala.util.Random -import 
twadoop_config.configuration.log_categories.group.timelines.TimelineserviceInjectionRequestLogScalaDataset - -/** - * Evaluates an Earlybird model using 1% injection request logs. - * - * Arguments: - * --model_base_path path to Earlybird model snapshots - * --models list of model names to evaluate - * --output path to output stats - * --parallelism (default: 3) number of tasks to run in parallel - * --topks (optional) list of values of `k` (integers) for top-K metrics - * --topn_fractions (optional) list of values of `n` (doubles) for top-N-fraction metrics - * --seed (optional) seed for random number generator - */ -object EarlybirdModelEvaluationJob extends TwitterExecutionApp with UTCDateRangeFromArgs { - - import FDsl._ - import PredictionEnginePlugin._ - - private[this] val averager: Aggregator[Double, AveragedValue, Double] = - AveragedValue.aggregator - private[this] val recapAdapter: IRecordOneToManyAdapter[SuggestionRecord] = - new RecapSuggestionRecordAdapter(checkDwellTime = false) - - override def job: Execution[Unit] = { - for { - args <- Execution.getArgs - dateRange <- dateRangeEx - metrics = getMetrics(args) - random = buildRandom(args) - modelBasePath = args("model_base_path") - models = args.list("models") - parallelism = args.int("parallelism", 3) - logs = logsHavingCandidates(dateRange) - modelScoredCandidates = models.map { model => - (model, scoreCandidatesUsingModel(logs, s"$modelBasePath/$model")) - } - functionScoredCandidates = List( - ("random", scoreCandidatesUsingFunction(logs, _ => Some(random.nextDouble()))), - ("original_earlybird", scoreCandidatesUsingFunction(logs, extractOriginalEarlybirdScore)), - ("blender", scoreCandidatesUsingFunction(logs, extractBlenderScore)) - ) - allCandidates = modelScoredCandidates ++ functionScoredCandidates - statsExecutions = allCandidates.map { - case (name, pipe) => - for { - saved <- pipe.forceToDiskExecution - stats <- computeMetrics(saved, metrics, parallelism) - } yield (name, stats) - } - stats <- Execution.withParallelism(statsExecutions, parallelism) - _ <- TypedPipe.from(stats).writeExecution(TypedJson(args("output"))) - } yield () - } - - private[this] def computeMetrics( - requests: TypedPipe[Seq[CandidateRecord]], - metricsToCompute: Seq[EarlybirdEvaluationMetric], - parallelism: Int - ): Execution[Map[String, Double]] = { - val metricExecutions = metricsToCompute.map { metric => - val metricEx = requests.flatMap(metric(_)).aggregate(averager).toOptionExecution - metricEx.map { value => value.map((metric.name, _)) } - } - Execution.withParallelism(metricExecutions, parallelism).map(_.flatten.toMap) - } - - private[this] def getMetrics(args: Args): Seq[EarlybirdEvaluationMetric] = { - val topKs = args.list("topks").map(_.toInt) - val topNFractions = args.list("topn_fractions").map(_.toDouble) - val topKMetrics = topKs.flatMap { topK => - Seq( - TopKRecall(topK, filterFewerThanK = false), - TopKRecall(topK, filterFewerThanK = true), - ShownTweetRecall(topK), - AverageFullScoreForTopLight(topK), - SumScoreRecallForTopLight(topK), - HasFewerThanKCandidates(topK), - ShownTweetRecallWithFullScores(topK), - ProbabilityOfCorrectOrdering - ) - } - val topNPercentMetrics = topNFractions.flatMap { topNPercent => - Seq( - TopNPercentRecall(topNPercent), - ShownTweetPercentRecall(topNPercent) - ) - } - topKMetrics ++ topNPercentMetrics ++ Seq(NumberOfCandidates) - } - - private[this] def buildRandom(args: Args): Random = { - val seedOpt = args.optional("seed").map(_.toLong) - seedOpt.map(new Random(_)).getOrElse(new 
Random()) - } - - private[this] def logsHavingCandidates(dateRange: DateRange): TypedPipe[SuggestsRequestLog] = - DAL - .read(TimelineserviceInjectionRequestLogScalaDataset, dateRange) - .toTypedPipe - .filter(_.recapCandidates.exists(_.nonEmpty)) - - /** - * Uses a model defined at `earlybirdModelPath` to score candidates and - * returns a Seq[CandidateRecord] for each request. - */ - private[this] def scoreCandidatesUsingModel( - logs: TypedPipe[SuggestsRequestLog], - earlybirdModelPath: String - ): TypedPipe[Seq[CandidateRecord]] = { - logs - .usingScorer(earlybirdModelPath) - .map { - case (scorer: PredictionEngineScorer, log: SuggestsRequestLog) => - val suggestionRecords = - RecapTweetExample - .extractCandidateTweetExamples(log) - .map(_.asSuggestionRecord) - val servedTweetIds = log.servedHighlightTweets.flatMap(_.tweetId).toSet - val renamer = (new EarlybirdTrainingRecapConfiguration).EarlybirdFeatureRenamer - suggestionRecords.flatMap { suggestionRecord => - val dataRecordOpt = recapAdapter.adaptToDataRecords(suggestionRecord).asScala.headOption - dataRecordOpt.foreach(renamer.transform) - for { - tweetId <- suggestionRecord.itemId - fullScore <- suggestionRecord.recapFeatures.flatMap(_.combinedModelScore) - earlybirdScore <- dataRecordOpt.flatMap(calculateLightScore(_, scorer)) - } yield CandidateRecord( - tweetId = tweetId, - fullScore = fullScore, - earlyScore = earlybirdScore, - served = servedTweetIds.contains(tweetId) - ) - } - } - } - - /** - * Uses a simple function to score candidates and returns a Seq[CandidateRecord] for each - * request. - */ - private[this] def scoreCandidatesUsingFunction( - logs: TypedPipe[SuggestsRequestLog], - earlyScoreExtractor: HighlightTweet => Option[Double] - ): TypedPipe[Seq[CandidateRecord]] = { - logs - .map { log => - val tweetCandidates = log.recapTweetCandidates.getOrElse(Nil) - val servedTweetIds = log.servedHighlightTweets.flatMap(_.tweetId).toSet - for { - candidate <- tweetCandidates - tweetId <- candidate.tweetId - fullScore <- candidate.recapFeatures.flatMap(_.combinedModelScore) - earlyScore <- earlyScoreExtractor(candidate) - } yield CandidateRecord( - tweetId = tweetId, - fullScore = fullScore, - earlyScore = earlyScore, - served = servedTweetIds.contains(tweetId) - ) - } - } - - private[this] def extractOriginalEarlybirdScore(candidate: HighlightTweet): Option[Double] = - for { - recapFeatures <- candidate.recapFeatures - tweetFeatures <- recapFeatures.tweetFeatures - } yield tweetFeatures.earlybirdScore - - private[this] def extractBlenderScore(candidate: HighlightTweet): Option[Double] = - for { - recapFeatures <- candidate.recapFeatures - tweetFeatures <- recapFeatures.tweetFeatures - } yield tweetFeatures.blenderScore - - private[this] def calculateLightScore( - dataRecord: DataRecord, - scorer: PredictionEngineScorer - ): Option[Double] = { - val scoredRecord = scorer(dataRecord) - if (scoredRecord.hasFeature(RecapFeatures.PREDICTED_IS_UNIFIED_ENGAGEMENT)) { - Some(scoredRecord.getFeatureValue(RecapFeatures.PREDICTED_IS_UNIFIED_ENGAGEMENT).toDouble) - } else { - None - } - } -} diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/BUILD b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/BUILD deleted file mode 100644 index e49e08758..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/BUILD +++ /dev/null @@ -1,89 +0,0 @@ -create_datarecord_datasets( - base_name = 
"earlybird_recap_data_records", - platform = "java8", - role = "timelines", - segment_type = "partitioned", - tags = [ - "bazel-compatible", - "bazel-compatible:migrated", - ], -) - -create_datarecord_datasets( - base_name = "earlybird_rectweet_data_records", - platform = "java8", - role = "timelines", - segment_type = "partitioned", - tags = [ - "bazel-compatible", - "bazel-compatible:migrated", - ], -) - -scala_library( - name = "training_data_generation", - sources = ["*.scala"], - platform = "java8", - strict_deps = True, - tags = [ - "bazel-compatible", - "bazel-compatible:migrated", - ], - dependencies = [ - ":earlybird_recap_data_records-java", - ":earlybird_rectweet_data_records-java", - "3rdparty/jvm/com/ibm/icu:icu4j", - "3rdparty/src/jvm/com/twitter/scalding:json", - "src/java/com/twitter/ml/api:api-base", - "src/java/com/twitter/ml/api/constant", - "src/java/com/twitter/ml/api/matcher", - "src/java/com/twitter/search/common/features", - "src/scala/com/twitter/ml/api:api-base", - "src/scala/com/twitter/ml/api/analytics", - "src/scala/com/twitter/ml/api/util", - "src/scala/com/twitter/scalding_internal/dalv2", - "src/scala/com/twitter/scalding_internal/dalv2/dataset", - "src/scala/com/twitter/scalding_internal/job", - "src/scala/com/twitter/scalding_internal/job/analytics_batch", - "src/scala/com/twitter/timelines/prediction/features/common", - "src/scala/com/twitter/timelines/prediction/features/recap", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:dataset-analytics-java", - "timelines/data_processing/ad_hoc/earlybird_ranking/common", - "timelines/data_processing/ad_hoc/recap/dataset_utils", - "timelines/data_processing/ad_hoc/recap/offline_execution", - "timelines/data_processing/util/execution", - ], -) - -hadoop_binary( - name = "bin", - basename = "earlybird_training_data_generation-deploy", - main = "com.twitter.scalding.Tool", - platform = "java8", - runtime_platform = "java8", - tags = [ - "bazel-compatible", - "bazel-compatible:migrated", - "bazel-only", - ], - dependencies = [ - ":training_data_generation", - ], -) - -hadoop_binary( - name = "earlybird_training_data_generation_prod", - basename = "earlybird_training_data_generation_prod-deploy", - main = "com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.training_data_generation.EarlybirdTrainingDataProdJob", - platform = "java8", - runtime_platform = "java8", - tags = [ - "bazel-compatible", - "bazel-compatible:migrated", - "bazel-only", - ], - dependencies = [ - ":training_data_generation", - ], -) diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/BUILD.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/BUILD.docx new file mode 100644 index 000000000..9aa9e3f81 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/BUILD.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdExampleSampler.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdExampleSampler.docx new file mode 100644 index 000000000..ac4ee9b26 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdExampleSampler.docx differ diff --git 
a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdExampleSampler.scala b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdExampleSampler.scala deleted file mode 100644 index b1aff5bd4..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdExampleSampler.scala +++ /dev/null @@ -1,65 +0,0 @@ -package com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.training_data_generation - -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.DataSetPipe -import com.twitter.ml.api.Feature -import com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common.LabelInfo -import com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common.LabelInfoWithFeature -import com.twitter.timelines.prediction.features.recap.RecapFeatures -import java.lang.{Double => JDouble} -import scala.util.Random - -/** - * Adds an IsGlobalEngagement label to records containing any recap label, and adjusts - * weights accordingly. See [[weightAndSample]] for details on operation. - */ -class EarlybirdExampleSampler( - random: Random, - labelInfos: List[LabelInfoWithFeature], - negativeInfo: LabelInfo) { - - import com.twitter.ml.api.util.FDsl._ - - private[this] val ImportanceFeature: Feature[JDouble] = - SharedFeatures.RECORD_WEIGHT_FEATURE_BUILDER - .extensionBuilder() - .addExtension("type", "earlybird") - .build() - - private[this] def uniformSample(labelInfo: LabelInfo) = - random.nextDouble() < labelInfo.downsampleFraction - - private[this] def weightedImportance(labelInfo: LabelInfo) = - labelInfo.importance / labelInfo.downsampleFraction - - /** - * Generates a IsGlobalEngagement label for records that contain any - * recap label. Adds an "importance" value per recap label found - * in the record. Simultaneously, downsamples positive and negative examples based on provided - * downsample rates. 
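The weighting above boils down to keeping a label with probability downsampleFraction and compensating with a weight of importance / downsampleFraction, so the expected total weight is unchanged. Below is a minimal standalone sketch of that arithmetic; LabelInfo is reduced to the two fields used here, and the label settings are hypothetical.

import scala.util.Random

// Simplified stand-in for LabelInfo: only the fields the weighting above depends on.
case class LabelInfo(importance: Double, downsampleFraction: Double)

object EarlybirdSamplingExample extends App {
  val random = new Random(12345L)

  // Hypothetical settings: keep half of the favorite positives, all reply positives, 1% of negatives.
  val favInfo = LabelInfo(importance = 1.0, downsampleFraction = 0.5)
  val replyInfo = LabelInfo(importance = 2.0, downsampleFraction = 1.0)
  val negativeInfo = LabelInfo(importance = 1.0, downsampleFraction = 0.01)

  def uniformSample(info: LabelInfo): Boolean = random.nextDouble() < info.downsampleFraction
  def weightedImportance(info: LabelInfo): Double = info.importance / info.downsampleFraction

  // A positive record carrying both labels: each surviving label contributes
  // importance / downsampleFraction, keeping the expected contribution equal to importance.
  val labelsOnRecord = List(favInfo, replyInfo)
  val sampled = labelsOnRecord.filter(uniformSample)
  if (sampled.nonEmpty) {
    val recordWeight = sampled.map(weightedImportance).sum // 4.0 if both survive, 2.0 otherwise
    println(s"keep positive record with weight $recordWeight")
  }

  // A negative record is kept ~1% of the time but up-weighted 100x, so totals stay unbiased.
  if (uniformSample(negativeInfo)) {
    println(s"keep negative record with weight ${weightedImportance(negativeInfo)}")
  }
}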
- */ - def weightAndSample(data: DataSetPipe): DataSetPipe = { - val updatedRecords = data.records.flatMap { record => - val featuresOn = labelInfos.filter(labelInfo => record.hasFeature(labelInfo.feature)) - if (featuresOn.nonEmpty) { - val sampled = featuresOn.map(_.info).filter(uniformSample) - if (sampled.nonEmpty) { - record.setFeatureValue(RecapFeatures.IS_EARLYBIRD_UNIFIED_ENGAGEMENT, true) - Some(record.setFeatureValue(ImportanceFeature, sampled.map(weightedImportance).sum)) - } else { - None - } - } else if (uniformSample(negativeInfo)) { - Some(record.setFeatureValue(ImportanceFeature, weightedImportance(negativeInfo))) - } else { - None - } - } - - DataSetPipe( - updatedRecords, - data.featureContext - .addFeatures(ImportanceFeature, RecapFeatures.IS_EARLYBIRD_UNIFIED_ENGAGEMENT) - ) - } -} diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdStatsJob.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdStatsJob.docx new file mode 100644 index 000000000..528dad97b Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdStatsJob.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdStatsJob.scala b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdStatsJob.scala deleted file mode 100644 index 140ee6f94..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdStatsJob.scala +++ /dev/null @@ -1,63 +0,0 @@ -package com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.training_data_generation - -import com.twitter.ml.api.analytics.DataSetAnalyticsPlugin -import com.twitter.ml.api.matcher.FeatureMatcher -import com.twitter.ml.api.util.FDsl -import com.twitter.ml.api.DailySuffixFeatureSource -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.DataSetPipe -import com.twitter.ml.api.FeatureStats -import com.twitter.ml.api.IMatcher -import com.twitter.scalding.typed.TypedPipe -import com.twitter.scalding.Execution -import com.twitter.scalding.TypedJson -import com.twitter.scalding_internal.job.TwitterExecutionApp -import com.twitter.timelines.data_processing.util.execution.UTCDateRangeFromArgs -import com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common.EarlybirdTrainingConfiguration -import com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common.EarlybirdTrainingRecapConfiguration -import com.twitter.timelines.prediction.features.recap.RecapFeatures -import scala.collection.JavaConverters._ - -/** - * Compute counts and fractions for all labels in a Recap data source. 
- * - * Arguments: - * --input recap data source (containing all labels) - * --output path to output JSON file containing stats - */ -object EarlybirdStatsJob extends TwitterExecutionApp with UTCDateRangeFromArgs { - - import DataSetAnalyticsPlugin._ - import FDsl._ - import RecapFeatures.IS_EARLYBIRD_UNIFIED_ENGAGEMENT - - lazy val constants: EarlybirdTrainingConfiguration = new EarlybirdTrainingRecapConfiguration - private[this] def addGlobalEngagementLabel(record: DataRecord) = { - if (constants.LabelInfos.exists { labelInfo => record.hasFeature(labelInfo.feature) }) { - record.setFeatureValue(IS_EARLYBIRD_UNIFIED_ENGAGEMENT, true) - } - record - } - - private[this] def labelFeatureMatcher: IMatcher = { - val allLabels = - (IS_EARLYBIRD_UNIFIED_ENGAGEMENT :: constants.LabelInfos.map(_.feature)).map(_.getFeatureName) - FeatureMatcher.names(allLabels.asJava) - } - - private[this] def computeStats(data: DataSetPipe): TypedPipe[FeatureStats] = { - data - .viaRecords { _.map(addGlobalEngagementLabel) } - .project(labelFeatureMatcher) - .collectFeatureStats() - } - - override def job: Execution[Unit] = { - for { - args <- Execution.getArgs - dateRange <- dateRangeEx - data = DailySuffixFeatureSource(args("input"))(dateRange).read - _ <- computeStats(data).writeExecution(TypedJson(args("output"))) - } yield () - } -} diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdTrainingDataJob.docx b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdTrainingDataJob.docx new file mode 100644 index 000000000..87a3662e8 Binary files /dev/null and b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdTrainingDataJob.docx differ diff --git a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdTrainingDataJob.scala b/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdTrainingDataJob.scala deleted file mode 100644 index 6b614d78f..000000000 --- a/timelines/data_processing/ad_hoc/earlybird_ranking/earlybird_ranking/training_data_generation/EarlybirdTrainingDataJob.scala +++ /dev/null @@ -1,92 +0,0 @@ -package com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.training_data_generation - -import com.twitter.ml.api.HourlySuffixFeatureSource -import com.twitter.ml.api.IRecord -import com.twitter.scalding.Args -import com.twitter.scalding.DateRange -import com.twitter.scalding.Days -import com.twitter.scalding.Execution -import com.twitter.scalding.ExecutionUtil -import com.twitter.scalding_internal.dalv2.DALWrite.D -import com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common.EarlybirdTrainingRecapConfiguration -import com.twitter.timelines.data_processing.ad_hoc.earlybird_ranking.common.EarlybirdTrainingRectweetConfiguration -import com.twitter.timelines.data_processing.ad_hoc.recap.offline_execution.OfflineAdhocExecution -import com.twitter.timelines.data_processing.ad_hoc.recap.offline_execution.OfflineAnalyticsBatchExecution -import com.twitter.timelines.data_processing.ad_hoc.recap.offline_execution.OfflineExecution -import scala.util.Random -import com.twitter.scalding_internal.dalv2.dataset.DALWrite._ -import com.twitter.timelines.prediction.features.common.TimelinesSharedFeatures -import timelines.data_processing.ad_hoc.earlybird_ranking.training_data_generation._ - -/** - * Generates data for training an 
Earlybird-friendly model. - * Produces a single "global" engagement, and samples data accordingly. - * Also converts features from Earlybird to their original Earlybird - * feature names so they can be used as is in EB. - * - * Arguments: - * --input path to raw Recap training data (all labels) - * --output path to write sampled Earlybird-friendly training data - * --seed (optional) for random number generator (in sampling) - * --parallelism (default: 1) number of days to generate data for in parallel - * [splits long date range into single days] - */ -trait GenerateEarlybirdTrainingData { _: OfflineExecution => - - def isEligibleForEarlybirdScoring(record: IRecord): Boolean = { - // The rationale behind this logic is available in TQ-9678. - record.getFeatureValue(TimelinesSharedFeatures.EARLYBIRD_SCORE) <= 100.0 - } - - override def executionFromParams(args: Args)(implicit dateRange: DateRange): Execution[Unit] = { - val seedOpt = args.optional("seed").map(_.toLong) - val parallelism = args.int("parallelism", 1) - val rectweet = args.boolean("rectweet") - - ExecutionUtil - .runDateRangeWithParallelism(Days(1), parallelism) { splitRange => - val data = HourlySuffixFeatureSource(args("input"))(splitRange).read - .filter(isEligibleForEarlybirdScoring _) - - lazy val rng = seedOpt.map(new Random(_)).getOrElse(new Random()) - - val (constants, sink) = - if (rectweet) - (new EarlybirdTrainingRectweetConfiguration, EarlybirdRectweetDataRecordsJavaDataset) - else (new EarlybirdTrainingRecapConfiguration, EarlybirdRecapDataRecordsJavaDataset) - - val earlybirdSampler = - new EarlybirdExampleSampler( - random = rng, - labelInfos = constants.LabelInfos, - negativeInfo = constants.NegativeInfo - ) - val outputPath = args("output") - earlybirdSampler - .weightAndSample(data) - .transform(constants.EarlybirdFeatureRenamer) - // shuffle row-wise in order to get rid of clustered replies - // also keep number of part files small - .viaRecords { record => - record - .groupRandomly(partitions = 500) - .sortBy { _ => rng.nextDouble() } - .values - } - .writeDALExecution( - sink, - D.Daily, - D.Suffix(outputPath), - D.EBLzo() - )(splitRange) - }(dateRange).unit - } -} - -object EarlybirdTrainingDataAdHocJob - extends OfflineAdhocExecution - with GenerateEarlybirdTrainingData - -object EarlybirdTrainingDataProdJob - extends OfflineAnalyticsBatchExecution - with GenerateEarlybirdTrainingData diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregateGroup.docx b/timelines/data_processing/ml_util/aggregation_framework/AggregateGroup.docx new file mode 100644 index 000000000..3fe70c323 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/AggregateGroup.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregateGroup.scala b/timelines/data_processing/ml_util/aggregation_framework/AggregateGroup.scala deleted file mode 100644 index 6797d838a..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/AggregateGroup.scala +++ /dev/null @@ -1,124 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.ml.api._ -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetric -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.EasyMetric -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.MaxMetric -import 
com.twitter.timelines.data_processing.ml_util.transforms.OneToSomeTransform -import com.twitter.util.Duration -import java.lang.{Boolean => JBoolean} -import java.lang.{Long => JLong} -import scala.language.existentials - -/** - * A wrapper for [[com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup]] - * (see TypedAggregateGroup.scala) with some convenient syntactic sugar that avoids - * the user having to specify different groups for different types of features. - * Gets translated into multiple strongly typed TypedAggregateGroup(s) - * by the buildTypedAggregateGroups() method defined below. - * - * @param inputSource Source to compute this aggregate over - * @param preTransforms Sequence of [[ITransform]] that is applied to - * data records pre-aggregation (e.g. discretization, renaming) - * @param samplingTransformOpt Optional [[OneToSomeTransform]] that samples data record - * @param aggregatePrefix Prefix to use for naming resultant aggregate features - * @param keys Features to group by when computing the aggregates - * (e.g. USER_ID, AUTHOR_ID). These must be either discrete, string or sparse binary. - * Grouping by a sparse binary feature is different than grouping by a discrete or string - * feature. For example, if you have a sparse binary feature WORDS_IN_TWEET which is - * a set of all words in a tweet, then grouping by this feature generates a - * separate aggregate mean/count/etc for each value of the feature (each word), and - * not just a single aggregate count for different "sets of words" - * @param features Features to aggregate (e.g. blender_score or is_photo). - * @param labels Labels to cross the features with to make pair features, if any. - * @param metrics Aggregation metrics to compute (e.g. count, mean) - * @param halfLives Half lives to use for the aggregations, to be crossed with the above. - * use Duration.Top for "forever" aggregations over an infinite time window (no decay). - * @param outputStore Store to output this aggregate to - * @param includeAnyFeature Aggregate label counts for any feature value - * @param includeAnyLabel Aggregate feature counts for any label value (e.g. 
all impressions) - * @param includeTimestampFeature compute max aggregate on timestamp feature - * @param aggExclusionRegex Sequence of Regexes, which define features to - */ -case class AggregateGroup( - inputSource: AggregateSource, - aggregatePrefix: String, - keys: Set[Feature[_]], - features: Set[Feature[_]], - labels: Set[_ <: Feature[JBoolean]], - metrics: Set[EasyMetric], - halfLives: Set[Duration], - outputStore: AggregateStore, - preTransforms: Seq[OneToSomeTransform] = Seq.empty, - includeAnyFeature: Boolean = true, - includeAnyLabel: Boolean = true, - includeTimestampFeature: Boolean = false, - aggExclusionRegex: Seq[String] = Seq.empty) { - - private def toStrongType[T]( - metrics: Set[EasyMetric], - features: Set[Feature[_]], - featureType: FeatureType - ): TypedAggregateGroup[_] = { - val underlyingMetrics: Set[AggregationMetric[T, _]] = - metrics.flatMap(_.forFeatureType[T](featureType)) - val underlyingFeatures: Set[Feature[T]] = features - .map(_.asInstanceOf[Feature[T]]) - - TypedAggregateGroup[T]( - inputSource = inputSource, - aggregatePrefix = aggregatePrefix, - keysToAggregate = keys, - featuresToAggregate = underlyingFeatures, - labels = labels, - metrics = underlyingMetrics, - halfLives = halfLives, - outputStore = outputStore, - preTransforms = preTransforms, - includeAnyFeature, - includeAnyLabel, - aggExclusionRegex - ) - } - - private def timestampTypedAggregateGroup: TypedAggregateGroup[_] = { - val metrics: Set[AggregationMetric[JLong, _]] = - Set(MaxMetric.forFeatureType[JLong](TypedAggregateGroup.timestampFeature.getFeatureType).get) - - TypedAggregateGroup[JLong]( - inputSource = inputSource, - aggregatePrefix = aggregatePrefix, - keysToAggregate = keys, - featuresToAggregate = Set(TypedAggregateGroup.timestampFeature), - labels = Set.empty, - metrics = metrics, - halfLives = Set(Duration.Top), - outputStore = outputStore, - preTransforms = preTransforms, - includeAnyFeature = false, - includeAnyLabel = true, - aggExclusionRegex = Seq.empty - ) - } - - def buildTypedAggregateGroups(): List[TypedAggregateGroup[_]] = { - val typedAggregateGroupsList = { - if (features.isEmpty) { - List(toStrongType(metrics, features, FeatureType.BINARY)) - } else { - features - .groupBy(_.getFeatureType()) - .toList - .map { - case (featureType, features) => - toStrongType(metrics, features, featureType) - } - } - } - - val optionalTimestampTypedAggregateGroup = - if (includeTimestampFeature) List(timestampTypedAggregateGroup) else List() - - typedAggregateGroupsList ++ optionalTimestampTypedAggregateGroup - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregateSource.docx b/timelines/data_processing/ml_util/aggregation_framework/AggregateSource.docx new file mode 100644 index 000000000..a4cc03894 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/AggregateSource.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregateSource.scala b/timelines/data_processing/ml_util/aggregation_framework/AggregateSource.scala deleted file mode 100644 index 7fb239c65..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/AggregateSource.scala +++ /dev/null @@ -1,9 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.ml.api.Feature -import java.lang.{Long => JLong} - -trait AggregateSource extends Serializable { - def name: String - def timestampFeature: Feature[JLong] -} diff --git 
a/timelines/data_processing/ml_util/aggregation_framework/AggregateStore.docx b/timelines/data_processing/ml_util/aggregation_framework/AggregateStore.docx new file mode 100644 index 000000000..d7a225b5a Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/AggregateStore.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregateStore.scala b/timelines/data_processing/ml_util/aggregation_framework/AggregateStore.scala deleted file mode 100644 index 1c09b33f0..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/AggregateStore.scala +++ /dev/null @@ -1,5 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -trait AggregateStore extends Serializable { - def name: String -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregationConfig.docx b/timelines/data_processing/ml_util/aggregation_framework/AggregationConfig.docx new file mode 100644 index 000000000..b0e1cc431 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/AggregationConfig.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregationConfig.scala b/timelines/data_processing/ml_util/aggregation_framework/AggregationConfig.scala deleted file mode 100644 index 2b117ddbd..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/AggregationConfig.scala +++ /dev/null @@ -1,5 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -trait AggregationConfig { - def aggregatesToCompute: Set[TypedAggregateGroup[_]] -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregationKey.docx b/timelines/data_processing/ml_util/aggregation_framework/AggregationKey.docx new file mode 100644 index 000000000..909f35b77 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/AggregationKey.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/AggregationKey.scala b/timelines/data_processing/ml_util/aggregation_framework/AggregationKey.scala deleted file mode 100644 index c3aafef69..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/AggregationKey.scala +++ /dev/null @@ -1,50 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.bijection.Bufferable -import com.twitter.bijection.Injection -import scala.util.Try - -/** - * Case class that represents the "grouping" key for any aggregate feature. - * Used by Summingbird to output aggregates to the key-value "store" using sumByKey() - * - * @discreteFeaturesById All discrete featureids (+ values) that are part of this key - * @textFeaturesById All string featureids (+ values) that are part of this key - * - * Example 1: the user aggregate features in aggregatesv1 all group by USER_ID, - * which is a discrete feature. When storing these features, the key would be: - * - * discreteFeaturesById = Map(hash(USER_ID) -> ), textFeaturesById = Map() - * - * Ex 2: If aggregating grouped by USER_ID, AUTHOR_ID, tweet link url, the key would be: - * - * discreteFeaturesById = Map(hash(USER_ID) -> , hash(AUTHOR_ID) -> ), - * textFeaturesById = Map(hash(URL_FEATURE) -> ) - * - * I could have just used a DataRecord for the key, but I wanted to make it strongly typed - * and only support grouping by discrete and string features, so using a case class instead. 
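To make the two grouping examples in the comment concrete, here is a standalone sketch that builds the corresponding keys. The feature ids and values are made-up placeholders; in production the ids are hashes of the dense feature names.

// Local copy of the key shape described above, so the sketch runs standalone.
case class AggregationKey(
  discreteFeaturesById: Map[Long, Long],
  textFeaturesById: Map[Long, String])

object AggregationKeyExample extends App {
  // Made-up feature ids standing in for hash(USER_ID), hash(AUTHOR_ID), hash(URL_FEATURE).
  val UserIdFeatureId = 100L
  val AuthorIdFeatureId = 200L
  val UrlFeatureId = 300L

  // Example 1: user aggregates grouped by USER_ID only.
  val userKey = AggregationKey(
    discreteFeaturesById = Map(UserIdFeatureId -> 12L),
    textFeaturesById = Map.empty)

  // Example 2: grouped by USER_ID, AUTHOR_ID and the tweet link url.
  val userAuthorUrlKey = AggregationKey(
    discreteFeaturesById = Map(UserIdFeatureId -> 12L, AuthorIdFeatureId -> 34L),
    textFeaturesById = Map(UrlFeatureId -> "https://example.com"))

  println(userKey)
  println(userAuthorUrlKey)
}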
- * - * Re: efficiency, storing the hash of the feature in addition to just the feature value - * is somewhat more inefficient than only storing the feature value in the key, but it - * adds flexibility to group multiple types of aggregates in the same output store. If we - * decide this isn't a good tradeoff to make later, we can reverse/refactor this decision. - */ -case class AggregationKey( - discreteFeaturesById: Map[Long, Long], - textFeaturesById: Map[Long, String]) - -/** - * A custom injection for the above case class, - * so that Summingbird knows how to store it in Manhattan. - */ -object AggregationKeyInjection extends Injection[AggregationKey, Array[Byte]] { - /* Injection from tuple representation of AggregationKey to Array[Byte] */ - val featureMapsInjection: Injection[(Map[Long, Long], Map[Long, String]), Array[Byte]] = - Bufferable.injectionOf[(Map[Long, Long], Map[Long, String])] - - def apply(aggregationKey: AggregationKey): Array[Byte] = - featureMapsInjection(AggregationKey.unapply(aggregationKey).get) - - def invert(ab: Array[Byte]): Try[AggregationKey] = - featureMapsInjection.invert(ab).map(AggregationKey.tupled(_)) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/BUILD b/timelines/data_processing/ml_util/aggregation_framework/BUILD deleted file mode 100644 index aff488116..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/BUILD +++ /dev/null @@ -1,101 +0,0 @@ -scala_library( - name = "common_types", - sources = ["*.scala"], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/com/google/guava", - "3rdparty/jvm/com/twitter/algebird:bijection", - "3rdparty/jvm/com/twitter/algebird:core", - "3rdparty/jvm/com/twitter/algebird:util", - "3rdparty/jvm/com/twitter/bijection:core", - "3rdparty/jvm/com/twitter/bijection:json", - "3rdparty/jvm/com/twitter/bijection:macros", - "3rdparty/jvm/com/twitter/bijection:netty", - "3rdparty/jvm/com/twitter/bijection:scrooge", - "3rdparty/jvm/com/twitter/bijection:thrift", - "3rdparty/jvm/com/twitter/bijection:util", - "3rdparty/jvm/org/apache/thrift:libthrift", - "3rdparty/src/jvm/com/twitter/scalding:date", - "3rdparty/src/jvm/com/twitter/summingbird:batch", - "src/java/com/twitter/ml/api:api-base", - "src/java/com/twitter/ml/api/constant", - "src/scala/com/twitter/dal/client/dataset", - "src/scala/com/twitter/ml/api/util:datarecord", - "src/scala/com/twitter/scalding_internal/dalv2/vkvs", - "src/scala/com/twitter/scalding_internal/multiformat/format/keyval", - "src/scala/com/twitter/storehaus_internal/manhattan/config", - "src/scala/com/twitter/storehaus_internal/offline", - "src/scala/com/twitter/storehaus_internal/util", - "src/scala/com/twitter/summingbird_internal/bijection:bijection-implicits", - "src/scala/com/twitter/summingbird_internal/runner/store_config", - "src/thrift/com/twitter/dal/personal_data:personal_data-java", - "src/thrift/com/twitter/dal/personal_data:personal_data-scala", - "src/thrift/com/twitter/ml/api:data-java", - "timelines/data_processing/ml_util/aggregation_framework/metrics", - "timelines/data_processing/ml_util/transforms", - "util/util-core:util-core-util", - ], -) - -target( - name = "common_online_stores", - dependencies = [ - "src/scala/com/twitter/storehaus_internal/memcache", - ], -) - -target( - name = "common_offline_stores", - dependencies = [ - "src/scala/com/twitter/storehaus_internal/manhattan", - ], -) - -target( - name = "user_job", - dependencies = [ - 
"timelines/data_processing/ml_util/aggregation_framework/job", - ], -) - -target( - name = "scalding", - dependencies = [ - "timelines/data_processing/ml_util/aggregation_framework/scalding", - ], -) - -target( - name = "conversion", - dependencies = [ - "timelines/data_processing/ml_util/aggregation_framework/conversion", - ], -) - -target( - name = "query", - dependencies = [ - "timelines/data_processing/ml_util/aggregation_framework/query", - ], -) - -target( - name = "heron", - dependencies = [ - "timelines/data_processing/ml_util/aggregation_framework/heron", - ], -) - -target( - dependencies = [ - ":common_offline_stores", - ":common_online_stores", - ":common_types", - ":conversion", - ":heron", - ":query", - ":scalding", - ], -) diff --git a/timelines/data_processing/ml_util/aggregation_framework/BUILD.docx b/timelines/data_processing/ml_util/aggregation_framework/BUILD.docx new file mode 100644 index 000000000..296397160 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/BUILD.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/DataRecordAggregationMonoid.docx b/timelines/data_processing/ml_util/aggregation_framework/DataRecordAggregationMonoid.docx new file mode 100644 index 000000000..bcaeb62b3 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/DataRecordAggregationMonoid.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/DataRecordAggregationMonoid.scala b/timelines/data_processing/ml_util/aggregation_framework/DataRecordAggregationMonoid.scala deleted file mode 100644 index bc37c8e05..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/DataRecordAggregationMonoid.scala +++ /dev/null @@ -1,92 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.algebird.Monoid -import com.twitter.ml.api._ -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.util.SRichDataRecord -import scala.collection.mutable -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon._ - -/** - * Monoid to aggregate over DataRecord objects. - * - * @param aggregates Set of ''TypedAggregateGroup'' case classes* - * to compute using this monoid (see TypedAggregateGroup.scala) - */ -trait DataRecordMonoid extends Monoid[DataRecord] { - - val aggregates: Set[TypedAggregateGroup[_]] - - def zero(): DataRecord = new DataRecord - - /* - * Add two datarecords using this monoid. 
- * - * @param left Left datarecord to add - * @param right Right datarecord to add - * @return Sum of the two datarecords as a DataRecord - */ - def plus(left: DataRecord, right: DataRecord): DataRecord = { - val result = zero() - aggregates.foreach(_.mutatePlus(result, left, right)) - val leftTimestamp = getTimestamp(left) - val rightTimestamp = getTimestamp(right) - SRichDataRecord(result).setFeatureValue( - SharedFeatures.TIMESTAMP, - leftTimestamp.max(rightTimestamp) - ) - result - } -} - -case class DataRecordAggregationMonoid(aggregates: Set[TypedAggregateGroup[_]]) - extends DataRecordMonoid { - - private def sumBuffer(buffer: mutable.ArrayBuffer[DataRecord]): Unit = { - val bufferSum = zero() - buffer.toIterator.foreach { value => - val leftTimestamp = getTimestamp(bufferSum) - val rightTimestamp = getTimestamp(value) - aggregates.foreach(_.mutatePlus(bufferSum, bufferSum, value)) - SRichDataRecord(bufferSum).setFeatureValue( - SharedFeatures.TIMESTAMP, - leftTimestamp.max(rightTimestamp) - ) - } - - buffer.clear() - buffer += bufferSum - } - - /* - * Efficient batched aggregation of datarecords using - * this monoid + a buffer, for performance. - * - * @param dataRecordIter An iterator of datarecords to sum - * @return A datarecord option containing the sum - */ - override def sumOption(dataRecordIter: TraversableOnce[DataRecord]): Option[DataRecord] = { - if (dataRecordIter.isEmpty) { - None - } else { - var buffer = mutable.ArrayBuffer[DataRecord]() - val BatchSize = 1000 - - dataRecordIter.foreach { u => - if (buffer.size > BatchSize) sumBuffer(buffer) - buffer += u - } - - if (buffer.size > 1) sumBuffer(buffer) - Some(buffer(0)) - } - } -} - -/* - * This class is used when there is no need to use sumBuffer functionality, as in the case of - * online aggregation of datarecords where using a buffer on a small number of datarecords - * would add some performance overhead. - */ -case class DataRecordAggregationMonoidNoBuffer(aggregates: Set[TypedAggregateGroup[_]]) - extends DataRecordMonoid {} diff --git a/timelines/data_processing/ml_util/aggregation_framework/KeyedRecord.docx b/timelines/data_processing/ml_util/aggregation_framework/KeyedRecord.docx new file mode 100644 index 000000000..c0ddf7c0a Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/KeyedRecord.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/KeyedRecord.scala b/timelines/data_processing/ml_util/aggregation_framework/KeyedRecord.scala deleted file mode 100644 index bb3096767..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/KeyedRecord.scala +++ /dev/null @@ -1,27 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.ml.api.DataRecord - -/** - * Keyed record that is used to reprsent the aggregation type and its corresponding data record. - * - * @constructor creates a new keyed record. - * - * @param aggregateType the aggregate type - * @param record the data record associated with the key - **/ -case class KeyedRecord(aggregateType: AggregateType.Value, record: DataRecord) - -/** - * Keyed record map with multiple data record. - * - * @constructor creates a new keyed record map. 
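For orientation, a tiny standalone sketch of the two carrier types described here; AggregateType is stubbed as a local enumeration and String stands in for DataRecord, since the real definitions live elsewhere in the framework.

object KeyedRecordSketch extends App {
  // Stubbed aggregate types; the real AggregateType enumeration is defined elsewhere.
  object AggregateType extends Enumeration { val User, UserAuthor = Value }

  case class KeyedRecord(aggregateType: AggregateType.Value, record: String)
  case class KeyedRecordMap(aggregateType: AggregateType.Value, recordMap: Map[Long, String])

  val single = KeyedRecord(AggregateType.User, record = "user aggregate record")
  val indexed = KeyedRecordMap(
    AggregateType.UserAuthor,
    recordMap = Map(1L -> "record for author 1", 2L -> "record for author 2"))

  println(single)
  println(indexed)
}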
- * - * @param aggregateType the aggregate type - * @param recordMap a map with key of type Long and value of type DataRecord - * where the key indicates the index and the value indicating the record - * - **/ -case class KeyedRecordMap( - aggregateType: AggregateType.Value, - recordMap: scala.collection.Map[Long, DataRecord]) diff --git a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateInjections.docx b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateInjections.docx new file mode 100644 index 000000000..a3ce00d49 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateInjections.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateInjections.scala b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateInjections.scala deleted file mode 100644 index 7ab1233c1..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateInjections.scala +++ /dev/null @@ -1,46 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.dal.personal_data.thriftscala.PersonalDataType -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection -import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.Batched -import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.JavaCompactThrift -import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.genericInjection -import com.twitter.summingbird.batch.BatchID -import scala.collection.JavaConverters._ - -object OfflineAggregateInjections { - val offlineDataRecordAggregateInjection: KeyValInjection[AggregationKey, (BatchID, DataRecord)] = - KeyValInjection( - genericInjection(AggregationKeyInjection), - Batched(JavaCompactThrift[DataRecord]) - ) - - private[aggregation_framework] def getPdts[T]( - aggregateGroups: Iterable[T], - featureExtractor: T => Iterable[Feature[_]] - ): Option[Set[PersonalDataType]] = { - val pdts: Set[PersonalDataType] = for { - group <- aggregateGroups.toSet[T] - feature <- featureExtractor(group) - pdtSet <- feature.getPersonalDataTypes.asSet().asScala - javaPdt <- pdtSet.asScala - scalaPdt <- PersonalDataType.get(javaPdt.getValue) - } yield { - scalaPdt - } - if (pdts.nonEmpty) Some(pdts) else None - } - - def getInjection( - aggregateGroups: Set[TypedAggregateGroup[_]] - ): KeyValInjection[AggregationKey, (BatchID, DataRecord)] = { - val keyPdts = getPdts[TypedAggregateGroup[_]](aggregateGroups, _.allOutputKeys) - val valuePdts = getPdts[TypedAggregateGroup[_]](aggregateGroups, _.allOutputFeatures) - KeyValInjection( - genericInjection(AggregationKeyInjection, keyPdts), - genericInjection(Batched(JavaCompactThrift[DataRecord]), valuePdts) - ) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateSource.docx b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateSource.docx new file mode 100644 index 000000000..df5bc4855 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateSource.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateSource.scala b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateSource.scala deleted file mode 100644 index 116f553c4..000000000 --- 
a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateSource.scala +++ /dev/null @@ -1,21 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.dal.client.dataset.TimePartitionedDALDataset -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import java.lang.{Long => JLong} - -case class OfflineAggregateSource( - override val name: String, - override val timestampFeature: Feature[JLong], - scaldingHdfsPath: Option[String] = None, - scaldingSuffixType: Option[String] = None, - dalDataSet: Option[TimePartitionedDALDataset[DataRecord]] = None, - withValidation: Boolean = true) // context: https://jira.twitter.biz/browse/TQ-10618 - extends AggregateSource { - /* - * Th help transition callers to use DAL.read, we check that either the HDFS - * path is defined, or the dalDataset. Both options cannot be set at the same time. - */ - assert(!(scaldingHdfsPath.isDefined && dalDataSet.isDefined)) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateStore.docx b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateStore.docx new file mode 100644 index 000000000..f07d50219 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateStore.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateStore.scala b/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateStore.scala deleted file mode 100644 index 0bba08a94..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/OfflineAggregateStore.scala +++ /dev/null @@ -1,128 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.dal.client.dataset.KeyValDALDataset -import com.twitter.ml.api.DataRecord -import com.twitter.scalding.DateParser -import com.twitter.scalding.RichDate -import com.twitter.scalding_internal.multiformat.format.keyval.KeyVal -import com.twitter.storehaus_internal.manhattan._ -import com.twitter.storehaus_internal.util.ApplicationID -import com.twitter.storehaus_internal.util.DatasetName -import com.twitter.storehaus_internal.util.HDFSPath -import com.twitter.summingbird.batch.BatchID -import com.twitter.summingbird.batch.Batcher -import com.twitter.summingbird_internal.runner.store_config._ -import java.util.TimeZone -import com.twitter.summingbird.batch.MillisecondBatcher - -/* - * Configuration common to all offline aggregate stores - * - * @param outputHdfsPathPrefix HDFS prefix to store all output aggregate types offline - * @param dummyAppId Dummy manhattan app id required by summingbird (unused) - * @param dummyDatasetPrefix Dummy manhattan dataset prefix required by summingbird (unused) - * @param startDate Start date for summingbird job to begin computing aggregates - */ -case class OfflineAggregateStoreCommonConfig( - outputHdfsPathPrefix: String, - dummyAppId: String, - dummyDatasetPrefix: String, - startDate: String) - -/** - * A trait inherited by any object that defines - * a HDFS prefix to write output data to. E.g. timelines has its own - * output prefix to write aggregates_v2 results, your team can create - * its own. 
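The comment above invites each team to provide its own output prefix; the sketch below shows what such a config object might look like. It uses local copies of the two types so it compiles on its own, and the HDFS prefix, app id and dataset prefix are made-up placeholders.

object OfflineStoreConfigSketch extends App {
  // Local copies of the config types above so the sketch runs standalone.
  case class OfflineAggregateStoreCommonConfig(
    outputHdfsPathPrefix: String,
    dummyAppId: String,
    dummyDatasetPrefix: String,
    startDate: String)

  trait OfflineStoreCommonConfig extends Serializable {
    def apply(startDate: String): OfflineAggregateStoreCommonConfig
  }

  // Hypothetical team-specific config; all values are placeholders.
  object MyTeamOfflineStoreCommonConfig extends OfflineStoreCommonConfig {
    override def apply(startDate: String): OfflineAggregateStoreCommonConfig =
      OfflineAggregateStoreCommonConfig(
        outputHdfsPathPrefix = "/user/my_team/aggregates_v2",
        dummyAppId = "my_team_summingbird",
        dummyDatasetPrefix = "my_team_aggregates",
        startDate = startDate
      )
  }

  println(MyTeamOfflineStoreCommonConfig("2023-01-01"))
}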
- */ -trait OfflineStoreCommonConfig extends Serializable { - /* - * @param startDate Date to create config for - * @return OfflineAggregateStoreCommonConfig object with all config details for output populated - */ - def apply(startDate: String): OfflineAggregateStoreCommonConfig -} - -/** - * @param name Uniquely identifiable human-readable name for this output store - * @param startDate Start date for this output store from which aggregates should be computed - * @param commonConfig Provider of other common configuration details - * @param batchesToKeep Retention policy on output (number of batches to keep) - */ -abstract class OfflineAggregateStoreBase - extends OfflineStoreOnlyConfig[ManhattanROConfig] - with AggregateStore { - - override def name: String - def startDate: String - def commonConfig: OfflineStoreCommonConfig - def batchesToKeep: Int - def maxKvSourceFailures: Int - - val datedCommonConfig: OfflineAggregateStoreCommonConfig = commonConfig.apply(startDate) - val manhattan: ManhattanROConfig = ManhattanROConfig( - /* This is a sample config, will be replaced with production config later */ - HDFSPath(s"${datedCommonConfig.outputHdfsPathPrefix}/${name}"), - ApplicationID(datedCommonConfig.dummyAppId), - DatasetName(s"${datedCommonConfig.dummyDatasetPrefix}_${name}_1"), - com.twitter.storehaus_internal.manhattan.Adama - ) - - val batcherSize = 24 - val batcher: MillisecondBatcher = Batcher.ofHours(batcherSize) - - val startTime: RichDate = - RichDate(datedCommonConfig.startDate)(TimeZone.getTimeZone("UTC"), DateParser.default) - - val offline: ManhattanROConfig = manhattan -} - -/** - * Defines an aggregates store which is composed of DataRecords - * @param name Uniquely identifiable human-readable name for this output store - * @param startDate Start date for this output store from which aggregates should be computed - * @param commonConfig Provider of other common configuration details - * @param batchesToKeep Retention policy on output (number of batches to keep) - */ -case class OfflineAggregateDataRecordStore( - override val name: String, - override val startDate: String, - override val commonConfig: OfflineStoreCommonConfig, - override val batchesToKeep: Int = 7, - override val maxKvSourceFailures: Int = 0) - extends OfflineAggregateStoreBase { - - def toOfflineAggregateDataRecordStoreWithDAL( - dalDataset: KeyValDALDataset[KeyVal[AggregationKey, (BatchID, DataRecord)]] - ): OfflineAggregateDataRecordStoreWithDAL = - OfflineAggregateDataRecordStoreWithDAL( - name = name, - startDate = startDate, - commonConfig = commonConfig, - dalDataset = dalDataset, - maxKvSourceFailures = maxKvSourceFailures - ) -} - -trait withDALDataset { - def dalDataset: KeyValDALDataset[KeyVal[AggregationKey, (BatchID, DataRecord)]] -} - -/** - * Defines an aggregates store which is composed of DataRecords and writes using DAL. - * @param name Uniquely identifiable human-readable name for this output store - * @param startDate Start date for this output store from which aggregates should be computed - * @param commonConfig Provider of other common configuration details - * @param dalDataset The KeyValDALDataset for this output store - * @param batchesToKeep Unused, kept for interface compatibility. You must define a separate Oxpecker - * retention policy to maintain the desired number of versions. 
- */ -case class OfflineAggregateDataRecordStoreWithDAL( - override val name: String, - override val startDate: String, - override val commonConfig: OfflineStoreCommonConfig, - override val dalDataset: KeyValDALDataset[KeyVal[AggregationKey, (BatchID, DataRecord)]], - override val batchesToKeep: Int = -1, - override val maxKvSourceFailures: Int = 0) - extends OfflineAggregateStoreBase - with withDALDataset diff --git a/timelines/data_processing/ml_util/aggregation_framework/README.docx b/timelines/data_processing/ml_util/aggregation_framework/README.docx new file mode 100644 index 000000000..3111a4a2b Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/README.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/README.md b/timelines/data_processing/ml_util/aggregation_framework/README.md deleted file mode 100644 index ea9a4b446..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/README.md +++ /dev/null @@ -1,39 +0,0 @@ -Overview -======== - - -The **aggregation framework** is a set of libraries and utilities that allows teams to flexibly -compute aggregate (counting) features in both batch and in real-time. Aggregate features can capture -historical interactions between on arbitrary entities (and sets thereof), conditional on provided features -and labels. - -These types of engineered aggregate features have proven to be highly impactful across different teams at Twitter. - - -What are some features we can compute? --------------------------------------- - -The framework supports computing aggregate features on provided grouping keys. The only constraint is that these keys are sparse binary features (or are sets thereof). - -For example, a common use case is to calculate a user's past engagement history with various types of tweets (photo, video, retweets, etc.), specific authors, specific in-network engagers or any other entity the user has interacted with and that could provide signal. In this case, the underlying aggregation keys are `userId`, `(userId, authorId)` or `(userId, engagerId)`. - -In Timelines and MagicRecs, we also compute custom aggregate engagement counts on every `tweetId`. Similary, other aggregations are possible, perhaps on `advertiserId` or `mediaId` as long as the grouping key is sparse binary. - - -What implementations are supported? ------------------------------------ - -Offline, we support the daily batch processing of DataRecords containing all required input features to generate -aggregate features. These are then uploaded to Manhattan for online hydration. - -Online, we support the real-time aggregation of DataRecords through Storm with a backing memcache that can be queried -for the real-time aggregate features. - -Additional documentation exists in the [docs folder](docs) - - -Where is this used? --------------------- - -The Home Timeline heavy ranker uses a varierty of both [batch and real time features](../../../../src/scala/com/twitter/timelines/prediction/common/aggregates/README.md) generated by this framework. -These features are also used for email and other recommendations. 
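To make the README's grouping-key examples concrete, here is a schematic standalone sketch of a user-author engagement aggregate. Plain strings stand in for the real Feature, EasyMetric, Duration and store types taken by AggregateGroup, so this mirrors the shape of such a definition rather than the actual API; the feature and label names are illustrative.

object UserAuthorAggregateSketch extends App {
  // Schematic stand-in for an AggregateGroup definition; all fields are plain strings here.
  case class AggregateGroupSpec(
    aggregatePrefix: String,
    keys: Set[String],
    features: Set[String],
    labels: Set[String],
    metrics: Set[String],
    halfLives: Set[String],
    outputStore: String)

  // A user's engagement history with a given author, counted over a week and over all time.
  val userAuthorEngagements = AggregateGroupSpec(
    aggregatePrefix = "user_author_aggregate",
    keys = Set("USER_ID", "AUTHOR_ID"),
    features = Set("is_photo", "blender_score"),
    labels = Set("is_favorited", "is_retweeted"),
    metrics = Set("count", "mean"),
    halfLives = Set("7 days", "forever (Duration.Top)"),
    outputStore = "user_author_aggregates")

  println(userAuthorEngagements)
}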
\ No newline at end of file diff --git a/timelines/data_processing/ml_util/aggregation_framework/StoreConfig.docx b/timelines/data_processing/ml_util/aggregation_framework/StoreConfig.docx new file mode 100644 index 000000000..02280b084 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/StoreConfig.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/StoreConfig.scala b/timelines/data_processing/ml_util/aggregation_framework/StoreConfig.scala deleted file mode 100644 index 703d5893c..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/StoreConfig.scala +++ /dev/null @@ -1,68 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.Feature -import com.twitter.ml.api.FeatureType - -/** - * Convenience class to describe the stores that make up a particular type of aggregate. - * - * For example, as of 2018/07, user aggregates are generate by merging the individual - * "user_aggregates", "rectweet_user_aggregates", and, "twitter_wide_user_aggregates". - * - * @param storeNames Name of the stores. - * @param aggregateType Type of aggregate, usually differentiated by the aggregation key. - * @param shouldHash Used at TimelineRankingAggregatesUtil.extractSecondary when extracting the - * secondary key value. - */ -case class StoreConfig[T]( - storeNames: Set[String], - aggregateType: AggregateType.Value, - shouldHash: Boolean = false -)( - implicit storeMerger: StoreMerger) { - require(storeMerger.isValidToMerge(storeNames)) - - private val representativeStore = storeNames.head - - val aggregationKeyIds: Set[Long] = storeMerger.getAggregateKeys(representativeStore) - val aggregationKeyFeatures: Set[Feature[_]] = - storeMerger.getAggregateKeyFeatures(representativeStore) - val secondaryKeyFeatureOpt: Option[Feature[_]] = storeMerger.getSecondaryKey(representativeStore) -} - -trait StoreMerger { - def aggregationConfig: AggregationConfig - - def getAggregateKeyFeatures(storeName: String): Set[Feature[_]] = - aggregationConfig.aggregatesToCompute - .filter(_.outputStore.name == storeName) - .flatMap(_.keysToAggregate) - - def getAggregateKeys(storeName: String): Set[Long] = - TypedAggregateGroup.getKeyFeatureIds(getAggregateKeyFeatures(storeName)) - - def getSecondaryKey(storeName: String): Option[Feature[_]] = { - val keys = getAggregateKeyFeatures(storeName) - require(keys.size <= 2, "Only singleton or binary aggregation keys are supported.") - require(keys.contains(SharedFeatures.USER_ID), "USER_ID must be one of the aggregation keys.") - keys - .filterNot(_ == SharedFeatures.USER_ID) - .headOption - .map { possiblySparseKey => - if (possiblySparseKey.getFeatureType != FeatureType.SPARSE_BINARY) { - possiblySparseKey - } else { - TypedAggregateGroup.sparseFeature(possiblySparseKey) - } - } - } - - /** - * Stores may only be merged if they have the same aggregation key. 
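The merge rule stated above (stores may only be merged if they share the same aggregation keys) reduces to a simple check over each store's key set. Below is a standalone sketch mirroring the logic of isValidToMerge; the feature ids are made-up placeholders, and the store names follow the examples in the StoreConfig comment plus one hypothetical user-author store.

object StoreMergeCheckSketch extends App {
  // Made-up feature ids: 100L stands in for hash(USER_ID), 200L for hash(AUTHOR_ID).
  val aggregateKeysByStore: Map[String, Set[Long]] = Map(
    "user_aggregates" -> Set(100L),
    "rectweet_user_aggregates" -> Set(100L),
    "user_author_aggregates" -> Set(100L, 200L)
  )

  // Mergeable only if every store resolves to the same key set as the first one.
  def isValidToMerge(storeNames: Set[String]): Boolean = {
    val expectedKeyOpt = storeNames.headOption.map(aggregateKeysByStore)
    storeNames.forall(name => aggregateKeysByStore(name) == expectedKeyOpt.get)
  }

  println(isValidToMerge(Set("user_aggregates", "rectweet_user_aggregates"))) // true
  println(isValidToMerge(Set("user_aggregates", "user_author_aggregates")))   // false
}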
- */ - def isValidToMerge(storeNames: Set[String]): Boolean = { - val expectedKeyOpt = storeNames.headOption.map(getAggregateKeys) - storeNames.forall(v => getAggregateKeys(v) == expectedKeyOpt.get) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/StoreRegister.docx b/timelines/data_processing/ml_util/aggregation_framework/StoreRegister.docx new file mode 100644 index 000000000..2c289ac0b Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/StoreRegister.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/StoreRegister.scala b/timelines/data_processing/ml_util/aggregation_framework/StoreRegister.scala deleted file mode 100644 index a7e9cd535..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/StoreRegister.scala +++ /dev/null @@ -1,13 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -trait StoreRegister { - def allStores: Set[StoreConfig[_]] - - lazy val storeMap: Map[AggregateType.Value, StoreConfig[_]] = allStores - .map(store => (store.aggregateType, store)) - .toMap - - lazy val storeNameToTypeMap: Map[String, AggregateType.Value] = allStores - .flatMap(store => store.storeNames.map(name => (name, store.aggregateType))) - .toMap -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/TypedAggregateGroup.docx b/timelines/data_processing/ml_util/aggregation_framework/TypedAggregateGroup.docx new file mode 100644 index 000000000..3ab880c50 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/TypedAggregateGroup.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/TypedAggregateGroup.scala b/timelines/data_processing/ml_util/aggregation_framework/TypedAggregateGroup.scala deleted file mode 100644 index 92afc4137..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/TypedAggregateGroup.scala +++ /dev/null @@ -1,486 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.ml.api._ -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregateFeature -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetric -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon._ -import com.twitter.timelines.data_processing.ml_util.transforms.OneToSomeTransform -import com.twitter.util.Duration -import com.twitter.util.Try -import java.lang.{Boolean => JBoolean} -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} -import java.util.{Set => JSet} -import scala.annotation.tailrec -import scala.language.existentials -import scala.collection.JavaConverters._ -import scala.util.matching.Regex - -/** - * A case class contained precomputed data useful to quickly - * process operations over an aggregate. 
- * - * @param query The underlying feature being aggregated - * @param metric The aggregation metric - * @param outputFeatures The output features that aggregation will produce - * @param outputFeatureIds The precomputed hashes of the above outputFeatures - */ -case class PrecomputedAggregateDescriptor[T]( - query: AggregateFeature[T], - metric: AggregationMetric[T, _], - outputFeatures: List[Feature[_]], - outputFeatureIds: List[JLong]) - -object TypedAggregateGroup { - - /** - * Recursive function that generates all combinations of value - * assignments for a collection of sparse binary features. - * - * @param sparseBinaryIdValues list of sparse binary feature ids and possible values they can take - * @return A set of maps, where each map represents one possible assignment of values to ids - */ - def sparseBinaryPermutations( - sparseBinaryIdValues: List[(Long, Set[String])] - ): Set[Map[Long, String]] = sparseBinaryIdValues match { - case (id, values) +: rest => - tailRecSparseBinaryPermutations( - existingPermutations = values.map(value => Map(id -> value)), - remainingIdValues = rest - ) - case Nil => Set.empty - } - - @tailrec private[this] def tailRecSparseBinaryPermutations( - existingPermutations: Set[Map[Long, String]], - remainingIdValues: List[(Long, Set[String])] - ): Set[Map[Long, String]] = remainingIdValues match { - case Nil => existingPermutations - case (id, values) +: rest => - tailRecSparseBinaryPermutations( - existingPermutations.flatMap { existingIdValueMap => - values.map(value => existingIdValueMap ++ Map(id -> value)) - }, - rest - ) - } - - val SparseFeatureSuffix = ".member" - def sparseFeature(sparseBinaryFeature: Feature[_]): Feature[String] = - new Feature.Text( - sparseBinaryFeature.getDenseFeatureName + SparseFeatureSuffix, - AggregationMetricCommon.derivePersonalDataTypes(Some(sparseBinaryFeature))) - - /* Throws exception if obj not an instance of U */ - private[this] def validate[U](obj: Any): U = { - require(obj.isInstanceOf[U]) - obj.asInstanceOf[U] - } - - private[this] def getFeatureOpt[U](dataRecord: DataRecord, feature: Feature[U]): Option[U] = - Option(SRichDataRecord(dataRecord).getFeatureValue(feature)).map(validate[U](_)) - - /** - * Get a mapping from feature ids - * (including individual sparse elements of a sparse feature) to values - * from the given data record, for a given feature type. - * - * @param dataRecord Data record to get features from - * @param keysToAggregate key features to get id-value mappings for - * @param featureType Feature type to get id-value maps for - */ - def getKeyFeatureIdValues[U]( - dataRecord: DataRecord, - keysToAggregate: Set[Feature[_]], - featureType: FeatureType - ): Set[(Long, Option[U])] = { - val featuresOfThisType: Set[Feature[U]] = keysToAggregate - .filter(_.getFeatureType == featureType) - .map(validate[Feature[U]]) - - featuresOfThisType - .map { feature: Feature[U] => - val featureId: Long = getDenseFeatureId(feature) - val featureOpt: Option[U] = getFeatureOpt(dataRecord, feature) - (featureId, featureOpt) - } - } - - // TypedAggregateGroup may transform the aggregate keys for internal use. This method generates - // denseFeatureIds for the transformed feature. 
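  // Illustrations (values invented for clarity):
  //  * sparseBinaryPermutations(List((10L, Set("a", "b")), (20L, Set("x")))) returns
  //    Set(Map(10L -> "a", 20L -> "x"), Map(10L -> "b", 20L -> "x")): one key assignment
  //    per combination of sparse-binary values.
  //  * For a SPARSE_BINARY key such as "words.in.tweet", getDenseFeatureId below returns the id
  //    of the derived STRING feature "words.in.tweet.member" (see sparseFeature / SparseFeatureSuffix).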
- def getDenseFeatureId(feature: Feature[_]): Long = - if (feature.getFeatureType != FeatureType.SPARSE_BINARY) { - feature.getDenseFeatureId - } else { - sparseFeature(feature).getDenseFeatureId - } - - /** - * Return denseFeatureIds for the input features after applying the custom transformation that - * TypedAggregateGroup applies to its keysToAggregate. - * - * @param keysToAggregate key features to get id for - */ - def getKeyFeatureIds(keysToAggregate: Set[Feature[_]]): Set[Long] = - keysToAggregate.map(getDenseFeatureId) - - def checkIfAllKeysExist[U](featureIdValueMap: Map[Long, Option[U]]): Boolean = - featureIdValueMap.forall { case (_, valueOpt) => valueOpt.isDefined } - - def liftOptions[U](featureIdValueMap: Map[Long, Option[U]]): Map[Long, U] = - featureIdValueMap - .flatMap { - case (id, valueOpt) => - valueOpt.map { value => (id, value) } - } - - val timestampFeature: Feature[JLong] = SharedFeatures.TIMESTAMP - - /** - * Builds all valid aggregation keys (for the output store) from - * a datarecord and a spec listing the keys to aggregate. There - * can be multiple aggregation keys generated from a single data - * record when grouping by sparse binary features, for which multiple - * values can be set within the data record. - * - * @param dataRecord Data record to read values for key features from - * @return A set of AggregationKeys encoding the values of all keys - */ - def buildAggregationKeys( - dataRecord: DataRecord, - keysToAggregate: Set[Feature[_]] - ): Set[AggregationKey] = { - val discreteAggregationKeys = getKeyFeatureIdValues[Long]( - dataRecord, - keysToAggregate, - FeatureType.DISCRETE - ).toMap - - val textAggregationKeys = getKeyFeatureIdValues[String]( - dataRecord, - keysToAggregate, - FeatureType.STRING - ).toMap - - val sparseBinaryIdValues = getKeyFeatureIdValues[JSet[String]]( - dataRecord, - keysToAggregate, - FeatureType.SPARSE_BINARY - ).map { - case (id, values) => - ( - id, - values - .map(_.asScala.toSet) - .getOrElse(Set.empty[String]) - ) - }.toList - - if (checkIfAllKeysExist(discreteAggregationKeys) && - checkIfAllKeysExist(textAggregationKeys)) { - if (sparseBinaryIdValues.nonEmpty) { - sparseBinaryPermutations(sparseBinaryIdValues).map { sparseBinaryTextKeys => - AggregationKey( - discreteFeaturesById = liftOptions(discreteAggregationKeys), - textFeaturesById = liftOptions(textAggregationKeys) ++ sparseBinaryTextKeys - ) - } - } else { - Set( - AggregationKey( - discreteFeaturesById = liftOptions(discreteAggregationKeys), - textFeaturesById = liftOptions(textAggregationKeys) - ) - ) - } - } else Set.empty[AggregationKey] - } - -} - -/** - * Specifies one or more related aggregate(s) to compute in the summingbird job. - * - * @param inputSource Source to compute this aggregate over - * @param preTransforms Sequence of [[com.twitter.ml.api.RichITransform]] that transform - * data records pre-aggregation (e.g. discretization, renaming) - * @param samplingTransformOpt Optional [[OneToSomeTransform]] that transform data - * record to optional data record (e.g. for sampling) before aggregation - * @param aggregatePrefix Prefix to use for naming resultant aggregate features - * @param keysToAggregate Features to group by when computing the aggregates - * (e.g. USER_ID, AUTHOR_ID) - * @param featuresToAggregate Features to aggregate (e.g. blender_score or is_photo) - * @param labels Labels to cross the features with to make pair features, if any. - * use Label.All if you don't want to cross with a label. 
- * @param metrics Aggregation metrics to compute (e.g. count, mean) - * @param halfLives Half lives to use for the aggregations, to be crossed with the above. - * use Duration.Top for "forever" aggregations over an infinite time window (no decay). - * @param outputStore Store to output this aggregate to - * @param includeAnyFeature Aggregate label counts for any feature value - * @param includeAnyLabel Aggregate feature counts for any label value (e.g. all impressions) - * - * The overall config for the summingbird job consists of a list of "AggregateGroup" - * case class objects, which get translated into strongly typed "TypedAggregateGroup" - * case class objects. A single TypedAggregateGroup always groups input data records from - * ''inputSource'' by a single set of aggregation keys (''featuresToAggregate''). - * Within these groups, we perform a comprehensive cross of: - * - * ''featuresToAggregate'' x ''labels'' x ''metrics'' x ''halfLives'' - * - * All the resultant aggregate features are assigned a human-readable feature name - * beginning with ''aggregatePrefix'', and are written to DataRecords that get - * aggregated and written to the store specified by ''outputStore''. - * - * Illustrative example. Suppose we define our spec as follows: - * - * TypedAggregateGroup( - * inputSource = "timelines_recap_daily", - * aggregatePrefix = "user_author_aggregate", - * keysToAggregate = Set(USER_ID, AUTHOR_ID), - * featuresToAggregate = Set(RecapFeatures.TEXT_SCORE, RecapFeatures.BLENDER_SCORE), - * labels = Set(RecapFeatures.IS_FAVORITED, RecapFeatures.IS_REPLIED), - * metrics = Set(CountMetric, MeanMetric), - * halfLives = Set(7.Days, 30.Days), - * outputStore = "user_author_aggregate_store" - * ) - * - * This will process data records from the source named "timelines_recap_daily" - * (see AggregateSource.scala for more details on how to add your own source) - * It will produce a total of 2x2x2x2 = 16 aggregation features, named like: - * - * user_author_aggregate.pair.recap.engagement.is_favorited.recap.searchfeature.blender_score.count.7days - * user_author_aggregate.pair.recap.engagement.is_favorited.recap.searchfeature.blender_score.count.30days - * user_author_aggregate.pair.recap.engagement.is_favorited.recap.searchfeature.blender_score.mean.7days - * - * ... (and so on) - * - * and all the result features will be stored in DataRecords, summed up, and written - * to the output store defined by the name "user_author_aggregate_store". - * (see AggregateStore.scala for details on how to add your own store). - * - * If you do not want a full cross, split up your config into multiple TypedAggregateGroup - * objects. Splitting is strongly advised to avoid blowing up and creating invalid - * or unnecessary combinations of aggregate features (note that some combinations - * are useless or invalid e.g. computing the mean of a binary feature). Splitting - * also does not cost anything in terms of real-time performance, because all - * Aggregate objects in the master spec that share the same ''keysToAggregate'', the - * same ''inputSource'' and the same ''outputStore'' are grouped by the summingbird - * job logic and stored into a single DataRecord in the output store. Overlapping - * aggregates will also automatically be deduplicated so don't worry about overlaps. 
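 *
 * As a rough, name-only sketch of where the 2x2x2x2 = 16 features come from (the exact strings are
 * illustrative; the real names are produced by the metrics layer):
 * {{{
 * val features  = Seq("recap.searchfeature.text_score", "recap.searchfeature.blender_score")
 * val labels    = Seq("recap.engagement.is_favorited", "recap.engagement.is_replied")
 * val metrics   = Seq("count", "mean")
 * val halfLives = Seq("7days", "30days")
 * val names =
 *   for (l <- labels; f <- features; m <- metrics; h <- halfLives)
 *     yield s"user_author_aggregate.pair.$l.$f.$m.$h"
 * // names.size == 16
 * }}}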
- */ -case class TypedAggregateGroup[T]( - inputSource: AggregateSource, - aggregatePrefix: String, - keysToAggregate: Set[Feature[_]], - featuresToAggregate: Set[Feature[T]], - labels: Set[_ <: Feature[JBoolean]], - metrics: Set[AggregationMetric[T, _]], - halfLives: Set[Duration], - outputStore: AggregateStore, - preTransforms: Seq[OneToSomeTransform] = Seq.empty, - includeAnyFeature: Boolean = true, - includeAnyLabel: Boolean = true, - aggExclusionRegex: Seq[String] = Seq.empty) { - import TypedAggregateGroup._ - - val compiledRegexes = aggExclusionRegex.map(new Regex(_)) - - // true if should drop, false if should keep - def filterOutAggregateFeature( - feature: PrecomputedAggregateDescriptor[_], - regexes: Seq[Regex] - ): Boolean = { - if (regexes.nonEmpty) - feature.outputFeatures.exists { feature => - regexes.exists { re => re.findFirstMatchIn(feature.getDenseFeatureName).nonEmpty } - } - else false - } - - def buildAggregationKeys( - dataRecord: DataRecord - ): Set[AggregationKey] = { - TypedAggregateGroup.buildAggregationKeys(dataRecord, keysToAggregate) - } - - /** - * This val precomputes descriptors for all individual aggregates in this group - * (of type ''AggregateFeature''). Also precompute hashes of all aggregation - * "output" features generated by these operators for faster - * run-time performance (this turns out to be a primary CPU bottleneck). - * Ex: for the mean operator, "sum" and "count" are output features - */ - val individualAggregateDescriptors: Set[PrecomputedAggregateDescriptor[T]] = { - /* - * By default, in additional to all feature-label crosses, also - * compute in aggregates over each feature and label without crossing - */ - val labelOptions = labels.map(Option(_)) ++ - (if (includeAnyLabel) Set(None) else Set.empty) - val featureOptions = featuresToAggregate.map(Option(_)) ++ - (if (includeAnyFeature) Set(None) else Set.empty) - for { - feature <- featureOptions - label <- labelOptions - metric <- metrics - halfLife <- halfLives - } yield { - val query = AggregateFeature[T](aggregatePrefix, feature, label, halfLife) - - val aggregateOutputFeatures = metric.getOutputFeatures(query) - val aggregateOutputFeatureIds = metric.getOutputFeatureIds(query) - PrecomputedAggregateDescriptor( - query, - metric, - aggregateOutputFeatures, - aggregateOutputFeatureIds - ) - } - }.filterNot(filterOutAggregateFeature(_, compiledRegexes)) - - /* Precomputes a map from all generated aggregate feature ids to their half lives. */ - val continuousFeatureIdsToHalfLives: Map[Long, Duration] = - individualAggregateDescriptors.flatMap { descriptor => - descriptor.outputFeatures - .flatMap { feature => - if (feature.getFeatureType() == FeatureType.CONTINUOUS) { - Try(feature.asInstanceOf[Feature[JDouble]]).toOption - .map(feature => (feature.getFeatureId(), descriptor.query.halfLife)) - } else None - } - }.toMap - - /* - * Sparse binary keys become individual string keys in the output. - * e.g. 
group by "words.in.tweet", output key: "words.in.tweet.member" - */ - val allOutputKeys: Set[Feature[_]] = keysToAggregate.map { key => - if (key.getFeatureType == FeatureType.SPARSE_BINARY) sparseFeature(key) - else key - } - - val allOutputFeatures: Set[Feature[_]] = individualAggregateDescriptors.flatMap { - case PrecomputedAggregateDescriptor( - query, - metric, - outputFeatures, - outputFeatureIds - ) => - outputFeatures - } - - val aggregateContext: FeatureContext = new FeatureContext(allOutputFeatures.toList.asJava) - - /** - * Adds all aggregates in this group found in the two input data records - * into a result, mutating the result. Uses a while loop for an - * approximately 10% gain in speed over a for comprehension. - * - * WARNING: mutates ''result'' - * - * @param result The output data record to mutate - * @param left The left data record to add - * @param right The right data record to add - */ - def mutatePlus(result: DataRecord, left: DataRecord, right: DataRecord): Unit = { - val featureIterator = individualAggregateDescriptors.iterator - while (featureIterator.hasNext) { - val descriptor = featureIterator.next - descriptor.metric.mutatePlus( - result, - left, - right, - descriptor.query, - Some(descriptor.outputFeatureIds) - ) - } - } - - /** - * Apply preTransforms sequentially. If any transform results in a dropped (None) - * DataRecord, then entire tranform sequence will result in a dropped DataRecord. - * Note that preTransforms are order-dependent. - */ - private[this] def sequentiallyTransform(dataRecord: DataRecord): Option[DataRecord] = { - val recordOpt = Option(new DataRecord(dataRecord)) - preTransforms.foldLeft(recordOpt) { - case (Some(previousRecord), preTransform) => - preTransform(previousRecord) - case _ => Option.empty[DataRecord] - } - } - - /** - * Given a data record, apply transforms and fetch the incremental contributions to - * each configured aggregate from this data record, and store these in an output data record. - * - * @param dataRecord Input data record to aggregate. 
- * @return A set of tuples (AggregationKey, DataRecord) whose first entry is an - * AggregationKey indicating what keys we're grouping by, and whose second entry - * is an output data record with incremental contributions to the aggregate value(s) - */ - def computeAggregateKVPairs(dataRecord: DataRecord): Set[(AggregationKey, DataRecord)] = { - sequentiallyTransform(dataRecord) - .flatMap { dataRecord => - val aggregationKeys = buildAggregationKeys(dataRecord) - val increment = new DataRecord - - val isNonEmptyIncrement = individualAggregateDescriptors - .map { descriptor => - descriptor.metric.setIncrement( - output = increment, - input = dataRecord, - query = descriptor.query, - timestampFeature = inputSource.timestampFeature, - aggregateOutputs = Some(descriptor.outputFeatureIds) - ) - } - .exists(identity) - - if (isNonEmptyIncrement) { - SRichDataRecord(increment).setFeatureValue( - timestampFeature, - getTimestamp(dataRecord, inputSource.timestampFeature) - ) - Some(aggregationKeys.map(key => (key, increment))) - } else { - None - } - } - .getOrElse(Set.empty[(AggregationKey, DataRecord)]) - } - - def outputFeaturesToRenamedOutputFeatures(prefix: String): Map[Feature[_], Feature[_]] = { - require(prefix.nonEmpty) - - allOutputFeatures.map { feature => - if (feature.isSetFeatureName) { - val renamedFeatureName = prefix + feature.getDenseFeatureName - val personalDataTypes = - if (feature.getPersonalDataTypes.isPresent) feature.getPersonalDataTypes.get() - else null - - val renamedFeature = feature.getFeatureType match { - case FeatureType.BINARY => - new Feature.Binary(renamedFeatureName, personalDataTypes) - case FeatureType.DISCRETE => - new Feature.Discrete(renamedFeatureName, personalDataTypes) - case FeatureType.STRING => - new Feature.Text(renamedFeatureName, personalDataTypes) - case FeatureType.CONTINUOUS => - new Feature.Continuous(renamedFeatureName, personalDataTypes) - case FeatureType.SPARSE_BINARY => - new Feature.SparseBinary(renamedFeatureName, personalDataTypes) - case FeatureType.SPARSE_CONTINUOUS => - new Feature.SparseContinuous(renamedFeatureName, personalDataTypes) - } - feature -> renamedFeature - } else { - feature -> feature - } - }.toMap - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/Utils.docx b/timelines/data_processing/ml_util/aggregation_framework/Utils.docx new file mode 100644 index 000000000..a2d04f4e6 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/Utils.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/Utils.scala b/timelines/data_processing/ml_util/aggregation_framework/Utils.scala deleted file mode 100644 index 60196fc62..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/Utils.scala +++ /dev/null @@ -1,122 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -import com.twitter.algebird.ScMapMonoid -import com.twitter.algebird.Semigroup -import com.twitter.ml.api._ -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import com.twitter.ml.api.FeatureType -import com.twitter.ml.api.util.SRichDataRecord -import java.lang.{Long => JLong} -import scala.collection.{Map => ScMap} - -object Utils { - val dataRecordMerger: DataRecordMerger = new DataRecordMerger - def EmptyDataRecord: DataRecord = new DataRecord() - - private val random = scala.util.Random - private val keyedDataRecordMapMonoid = { - val dataRecordMergerSg = new 
Semigroup[DataRecord] { - override def plus(x: DataRecord, y: DataRecord): DataRecord = { - dataRecordMerger.merge(x, y) - x - } - } - new ScMapMonoid[Long, DataRecord]()(dataRecordMergerSg) - } - - def keyFromLong(record: DataRecord, feature: Feature[JLong]): Long = - SRichDataRecord(record).getFeatureValue(feature).longValue - - def keyFromString(record: DataRecord, feature: Feature[String]): Long = - try { - SRichDataRecord(record).getFeatureValue(feature).toLong - } catch { - case _: NumberFormatException => 0L - } - - def keyFromHash(record: DataRecord, feature: Feature[String]): Long = - SRichDataRecord(record).getFeatureValue(feature).hashCode.toLong - - def extractSecondary[T]( - record: DataRecord, - secondaryKey: Feature[T], - shouldHash: Boolean = false - ): Long = secondaryKey.getFeatureType match { - case FeatureType.STRING => - if (shouldHash) keyFromHash(record, secondaryKey.asInstanceOf[Feature[String]]) - else keyFromString(record, secondaryKey.asInstanceOf[Feature[String]]) - case FeatureType.DISCRETE => keyFromLong(record, secondaryKey.asInstanceOf[Feature[JLong]]) - case f => throw new IllegalArgumentException(s"Feature type $f is not supported.") - } - - def mergeKeyedRecordOpts(args: Option[KeyedRecord]*): Option[KeyedRecord] = { - val keyedRecords = args.flatten - if (keyedRecords.isEmpty) { - None - } else { - val keys = keyedRecords.map(_.aggregateType) - require(keys.toSet.size == 1, "All merged records must have the same aggregate key.") - val mergedRecord = mergeRecords(keyedRecords.map(_.record): _*) - Some(KeyedRecord(keys.head, mergedRecord)) - } - } - - private def mergeRecords(args: DataRecord*): DataRecord = - if (args.isEmpty) EmptyDataRecord - else { - // can just do foldLeft(new DataRecord) for both cases, but try reusing the EmptyDataRecord singleton as much as possible - args.tail.foldLeft(args.head) { (merged, record) => - dataRecordMerger.merge(merged, record) - merged - } - } - - def mergeKeyedRecordMapOpts( - opt1: Option[KeyedRecordMap], - opt2: Option[KeyedRecordMap], - maxSize: Int = Int.MaxValue - ): Option[KeyedRecordMap] = { - if (opt1.isEmpty && opt2.isEmpty) { - None - } else { - val keys = Seq(opt1, opt2).flatten.map(_.aggregateType) - require(keys.toSet.size == 1, "All merged records must have the same aggregate key.") - val mergedRecordMap = mergeMapOpts(opt1.map(_.recordMap), opt2.map(_.recordMap), maxSize) - Some(KeyedRecordMap(keys.head, mergedRecordMap)) - } - } - - private def mergeMapOpts( - opt1: Option[ScMap[Long, DataRecord]], - opt2: Option[ScMap[Long, DataRecord]], - maxSize: Int = Int.MaxValue - ): ScMap[Long, DataRecord] = { - require(maxSize >= 0) - val keySet = opt1.map(_.keySet).getOrElse(Set.empty) ++ opt2.map(_.keySet).getOrElse(Set.empty) - val totalSize = keySet.size - val rate = if (totalSize <= maxSize) 1.0 else maxSize.toDouble / totalSize - val prunedOpt1 = opt1.map(downsample(_, rate)) - val prunedOpt2 = opt2.map(downsample(_, rate)) - Seq(prunedOpt1, prunedOpt2).flatten - .foldLeft(keyedDataRecordMapMonoid.zero)(keyedDataRecordMapMonoid.plus) - } - - def downsample[K, T](m: ScMap[K, T], samplingRate: Double): ScMap[K, T] = { - if (samplingRate >= 1.0) { - m - } else if (samplingRate <= 0) { - Map.empty - } else { - m.filter { - case (key, _) => - // It is important that the same user with the same sampling rate be deterministically - // selected or rejected. Otherwise, mergeMapOpts will choose different keys for the - // two input maps and their union will be larger than the limit we want. 
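              // In other words: because the seed is a pure function of (key, samplingRate), two calls to
              // downsample with the same rate keep exactly the same subset of any keys the maps share, so
              // the union of the pruned maps stays close to maxSize instead of inflating past it.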
- random.setSeed((key.hashCode, samplingRate.hashCode).hashCode) - random.nextDouble < samplingRate - } - } - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2Adapter.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2Adapter.docx new file mode 100644 index 000000000..9b2f97673 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2Adapter.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2Adapter.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2Adapter.scala deleted file mode 100644 index f5b7d1814..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2Adapter.scala +++ /dev/null @@ -1,165 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.twitter.algebird.DecayedValue -import com.twitter.algebird.DecayedValueMonoid -import com.twitter.algebird.Monoid -import com.twitter.ml.api._ -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.util.FDsl._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.summingbird.batch.BatchID -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregateFeature -import com.twitter.util.Duration -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} -import scala.collection.JavaConverters._ -import scala.collection.mutable -import java.{util => ju} - -object AggregatesV2Adapter { - type AggregatesV2Tuple = (AggregationKey, (BatchID, DataRecord)) - - val Epsilon: Double = 1e-6 - val decayedValueMonoid: Monoid[DecayedValue] = DecayedValueMonoid(Epsilon) - - /* - * Decays the storedValue from timestamp -> sourceVersion - * - * @param storedValue value read from the aggregates v2 output store - * @param timestamp timestamp corresponding to store value - * @param sourceVersion timestamp of version to decay all values to uniformly - * @param halfLife Half life duration to use for applying decay - * - * By applying this function, the feature values for all users are decayed - * to sourceVersion. This is important to ensure that a user whose aggregates - * were updated long in the past does not have an artifically inflated count - * compared to one whose aggregates were updated (and hence decayed) more recently. - */ - def decayValueToSourceVersion( - storedValue: Double, - timestamp: Long, - sourceVersion: Long, - halfLife: Duration - ): Double = - if (timestamp > sourceVersion) { - storedValue - } else { - decayedValueMonoid - .plus( - DecayedValue.build(storedValue, timestamp, halfLife.inMilliseconds), - DecayedValue.build(0, sourceVersion, halfLife.inMilliseconds) - ) - .value - } - - /* - * Decays all the aggregate features occurring in the ''inputRecord'' - * to a given timestamp, and mutates the ''outputRecord'' accordingly. - * Note that inputRecord and outputRecord can be the same if you want - * to mutate the input in place, the function does this correctly. 
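 *
 * Numerically, for a stored value whose timestamp is at or before the target version,
 * decayValueToSourceVersion above amounts (up to the monoid's Epsilon cutoff) to the closed form:
 * {{{
 * val decayed = storedValue * math.pow(0.5, (sourceVersion - timestamp).toDouble / halfLife.inMilliseconds)
 * }}}
 * e.g. a value of 8.0 stored exactly one half-life before sourceVersion decays to 4.0; values newer
 * than sourceVersion are returned unchanged.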
- * - * @param inputRecord Input record to get features from - * @param aggregates Aggregates to decay - * @param decayTo Timestamp to decay to - * @param trimThreshold Drop features below this trim threshold - * @param outputRecord Output record to mutate - * @return the mutated outputRecord - */ - def mutateDecay( - inputRecord: DataRecord, - aggregateFeaturesAndHalfLives: List[(Feature[_], Duration)], - decayTo: Long, - trimThreshold: Double, - outputRecord: DataRecord - ): DataRecord = { - val timestamp = inputRecord.getFeatureValue(SharedFeatures.TIMESTAMP).toLong - - aggregateFeaturesAndHalfLives.foreach { - case (aggregateFeature: Feature[_], halfLife: Duration) => - if (aggregateFeature.getFeatureType() == FeatureType.CONTINUOUS) { - val continuousFeature = aggregateFeature.asInstanceOf[Feature[JDouble]] - if (inputRecord.hasFeature(continuousFeature)) { - val storedValue = inputRecord.getFeatureValue(continuousFeature).toDouble - val decayedValue = decayValueToSourceVersion(storedValue, timestamp, decayTo, halfLife) - if (math.abs(decayedValue) > trimThreshold) { - outputRecord.setFeatureValue(continuousFeature, decayedValue) - } - } - } - } - - /* Update timestamp to version (now that we've decayed all aggregates) */ - outputRecord.setFeatureValue(SharedFeatures.TIMESTAMP, decayTo) - - outputRecord - } -} - -class AggregatesV2Adapter( - aggregates: Set[TypedAggregateGroup[_]], - sourceVersion: Long, - trimThreshold: Double) - extends IRecordOneToManyAdapter[AggregatesV2Adapter.AggregatesV2Tuple] { - - import AggregatesV2Adapter._ - - val keyFeatures: List[Feature[_]] = aggregates.flatMap(_.allOutputKeys).toList - val aggregateFeatures: List[Feature[_]] = aggregates.flatMap(_.allOutputFeatures).toList - val timestampFeatures: List[Feature[JLong]] = List(SharedFeatures.TIMESTAMP) - val allFeatures: List[Feature[_]] = keyFeatures ++ aggregateFeatures ++ timestampFeatures - - val featureContext: FeatureContext = new FeatureContext(allFeatures.asJava) - - override def getFeatureContext: FeatureContext = featureContext - - val aggregateFeaturesAndHalfLives: List[(Feature[_$3], Duration) forSome { type _$3 }] = - aggregateFeatures.map { aggregateFeature: Feature[_] => - val halfLife = AggregateFeature.parseHalfLife(aggregateFeature) - (aggregateFeature, halfLife) - } - - override def adaptToDataRecords(tuple: AggregatesV2Tuple): ju.List[DataRecord] = tuple match { - case (key: AggregationKey, (batchId: BatchID, record: DataRecord)) => { - val resultRecord = new SRichDataRecord(new DataRecord, featureContext) - - val itr = resultRecord.continuousFeaturesIterator() - val featuresToClear = mutable.Set[Feature[JDouble]]() - while (itr.moveNext()) { - val nextFeature = itr.getFeature - if (!aggregateFeatures.contains(nextFeature)) { - featuresToClear += nextFeature - } - } - - featuresToClear.foreach(resultRecord.clearFeature) - - keyFeatures.foreach { keyFeature: Feature[_] => - if (keyFeature.getFeatureType == FeatureType.DISCRETE) { - resultRecord.setFeatureValue( - keyFeature.asInstanceOf[Feature[JLong]], - key.discreteFeaturesById(keyFeature.getDenseFeatureId) - ) - } else if (keyFeature.getFeatureType == FeatureType.STRING) { - resultRecord.setFeatureValue( - keyFeature.asInstanceOf[Feature[String]], - key.textFeaturesById(keyFeature.getDenseFeatureId) - ) - } - } - - if (record.hasFeature(SharedFeatures.TIMESTAMP)) { - mutateDecay( - record, - aggregateFeaturesAndHalfLives, - sourceVersion, - trimThreshold, - resultRecord) - List(resultRecord.getRecord).asJava - } else { - 
List.empty[DataRecord].asJava - } - } - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2FeatureSource.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2FeatureSource.docx new file mode 100644 index 000000000..46fc84132 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2FeatureSource.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2FeatureSource.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2FeatureSource.scala deleted file mode 100644 index 5e196a43e..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/AggregatesV2FeatureSource.scala +++ /dev/null @@ -1,171 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.twitter.bijection.Injection -import com.twitter.bijection.thrift.CompactThriftCodec -import com.twitter.ml.api.AdaptedFeatureSource -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.IRecordOneToManyAdapter -import com.twitter.ml.api.TypedFeatureSource -import com.twitter.scalding.DateRange -import com.twitter.scalding.RichDate -import com.twitter.scalding.TypedPipe -import com.twitter.scalding.commons.source.VersionedKeyValSource -import com.twitter.scalding.commons.tap.VersionedTap.TapMode -import com.twitter.summingbird.batch.BatchID -import com.twitter.summingbird_internal.bijection.BatchPairImplicits -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKeyInjection -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup -import org.apache.hadoop.mapred.JobConf -import scala.collection.JavaConverters._ -import AggregatesV2Adapter._ - -object AggregatesV2AdaptedSource { - val DefaultTrimThreshold = 0 -} - -trait AggregatesV2AdaptedSource extends AggregatesV2AdaptedSourceBase[DataRecord] { - override def storageFormatCodec: Injection[DataRecord, Array[Byte]] = - CompactThriftCodec[DataRecord] - override def toDataRecord(v: DataRecord): DataRecord = v -} - -trait AggregatesV2AdaptedSourceBase[StorageFormat] - extends TypedFeatureSource[AggregatesV2Tuple] - with AdaptedFeatureSource[AggregatesV2Tuple] - with BatchPairImplicits { - - /* Output root path of aggregates v2 job, excluding store name and version */ - def rootPath: String - - /* Name of store under root path to read */ - def storeName: String - - // max bijection failures - def maxFailures: Int = 0 - - /* Aggregate config used to generate above output */ - def aggregates: Set[TypedAggregateGroup[_]] - - /* trimThreshold Trim all aggregates below a certain threshold to save memory */ - def trimThreshold: Double - - def toDataRecord(v: StorageFormat): DataRecord - - def sourceVersionOpt: Option[Long] - - def enableMostRecentBeforeSourceVersion: Boolean = false - - implicit private val aggregationKeyInjection: Injection[AggregationKey, Array[Byte]] = - AggregationKeyInjection - implicit def storageFormatCodec: Injection[StorageFormat, Array[Byte]] - - private def filteredAggregates = aggregates.filter(_.outputStore.name == storeName) - def storePath: String = List(rootPath, storeName).mkString("/") - - def mostRecentVkvs: VersionedKeyValSource[_, _] = { - VersionedKeyValSource[AggregationKey, (BatchID, StorageFormat)]( - path = 
storePath, - sourceVersion = None, - maxFailures = maxFailures - ) - } - - private def availableVersions: Seq[Long] = - mostRecentVkvs - .getTap(TapMode.SOURCE) - .getStore(new JobConf(true)) - .getAllVersions() - .asScala - .map(_.toLong) - - private def mostRecentVersion: Long = { - require(!availableVersions.isEmpty, s"$storeName has no available versions") - availableVersions.max - } - - def versionToUse: Long = - if (enableMostRecentBeforeSourceVersion) { - sourceVersionOpt - .map(sourceVersion => - availableVersions.filter(_ <= sourceVersion) match { - case Seq() => - throw new IllegalArgumentException( - "No version older than version: %s, available versions: %s" - .format(sourceVersion, availableVersions) - ) - case versionList => versionList.max - }) - .getOrElse(mostRecentVersion) - } else { - sourceVersionOpt.getOrElse(mostRecentVersion) - } - - override lazy val adapter: IRecordOneToManyAdapter[AggregatesV2Tuple] = - new AggregatesV2Adapter(filteredAggregates, versionToUse, trimThreshold) - - override def getData: TypedPipe[AggregatesV2Tuple] = { - val vkvsToUse: VersionedKeyValSource[AggregationKey, (BatchID, StorageFormat)] = { - VersionedKeyValSource[AggregationKey, (BatchID, StorageFormat)]( - path = storePath, - sourceVersion = Some(versionToUse), - maxFailures = maxFailures - ) - } - TypedPipe.from(vkvsToUse).map { - case (key, (batch, value)) => (key, (batch, toDataRecord(value))) - } - } -} - -/* - * Adapted data record feature source from aggregates v2 manhattan output - * Params documented in parent trait. - */ -case class AggregatesV2FeatureSource( - override val rootPath: String, - override val storeName: String, - override val aggregates: Set[TypedAggregateGroup[_]], - override val trimThreshold: Double = 0, - override val maxFailures: Int = 0, -)( - implicit val dateRange: DateRange) - extends AggregatesV2AdaptedSource { - - // Increment end date by 1 millisec since summingbird output for date D is stored at (D+1)T00 - override val sourceVersionOpt: Some[Long] = Some(dateRange.end.timestamp + 1) -} - -/* - * Reads most recent available AggregatesV2FeatureSource. - * There is no constraint on recency. - * Params documented in parent trait. - */ -case class AggregatesV2MostRecentFeatureSource( - override val rootPath: String, - override val storeName: String, - override val aggregates: Set[TypedAggregateGroup[_]], - override val trimThreshold: Double = AggregatesV2AdaptedSource.DefaultTrimThreshold, - override val maxFailures: Int = 0) - extends AggregatesV2AdaptedSource { - - override val sourceVersionOpt: None.type = None -} - -/* - * Reads most recent available AggregatesV2FeatureSource - * on or before the specified beforeDate. - * Params documented in parent trait. 
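 *
 * Hypothetical usage sketch; the root path, store name, aggregate set and timestamp below are
 * placeholders, not values taken from this repository:
 * {{{
 * val source = AggregatesV2MostRecentFeatureSourceBeforeDate(
 *   rootPath = "/placeholder/aggregates_v2/root",
 *   storeName = "user_aggregates",
 *   aggregates = myAggregateGroups,          // a Set[TypedAggregateGroup[_]] defined elsewhere
 *   beforeDate = RichDate(beforeTimestampMs) // newest store version at or before this date is read
 * )
 * val tuples = source.getData // TypedPipe[(AggregationKey, (BatchID, DataRecord))]
 * }}}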
- */ -case class AggregatesV2MostRecentFeatureSourceBeforeDate( - override val rootPath: String, - override val storeName: String, - override val aggregates: Set[TypedAggregateGroup[_]], - override val trimThreshold: Double = AggregatesV2AdaptedSource.DefaultTrimThreshold, - beforeDate: RichDate, - override val maxFailures: Int = 0) - extends AggregatesV2AdaptedSource { - - override val enableMostRecentBeforeSourceVersion = true - override val sourceVersionOpt: Some[Long] = Some(beforeDate.timestamp + 1) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/BUILD b/timelines/data_processing/ml_util/aggregation_framework/conversion/BUILD deleted file mode 100644 index d6c86cc12..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/BUILD +++ /dev/null @@ -1,71 +0,0 @@ -scala_library( - sources = ["*.scala"], - platform = "java8", - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/com/twitter/algebird:core", - "3rdparty/jvm/com/twitter/algebird:util", - "3rdparty/jvm/com/twitter/bijection:core", - "3rdparty/jvm/com/twitter/bijection:json", - "3rdparty/jvm/com/twitter/bijection:netty", - "3rdparty/jvm/com/twitter/bijection:scrooge", - "3rdparty/jvm/com/twitter/bijection:thrift", - "3rdparty/jvm/com/twitter/bijection:util", - "3rdparty/jvm/com/twitter/storehaus:algebra", - "3rdparty/jvm/com/twitter/storehaus:core", - "3rdparty/src/jvm/com/twitter/scalding:commons", - "3rdparty/src/jvm/com/twitter/scalding:core", - "3rdparty/src/jvm/com/twitter/scalding:date", - "3rdparty/src/jvm/com/twitter/summingbird:batch", - "3rdparty/src/jvm/com/twitter/summingbird:core", - "src/java/com/twitter/ml/api:api-base", - "src/java/com/twitter/ml/api/constant", - "src/scala/com/twitter/ml/api:api-base", - "src/scala/com/twitter/ml/api/util", - "src/scala/com/twitter/summingbird_internal/bijection:bijection-implicits", - "src/thrift/com/twitter/dal/personal_data:personal_data-java", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:interpretable-model-java", - "src/thrift/com/twitter/summingbird", - "timelines/data_processing/ml_util/aggregation_framework:common_types", - "timelines/data_processing/ml_util/aggregation_framework/metrics", - "util/util-core:scala", - ], -) - -scala_library( - name = "for-timelines", - sources = [ - "CombineCountsPolicy.scala", - "SparseBinaryMergePolicy.scala", - ], - platform = "java8", - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/com/twitter/algebird:core", - "3rdparty/jvm/com/twitter/algebird:util", - "3rdparty/jvm/com/twitter/bijection:core", - "3rdparty/jvm/com/twitter/bijection:json", - "3rdparty/jvm/com/twitter/bijection:netty", - "3rdparty/jvm/com/twitter/bijection:scrooge", - "3rdparty/jvm/com/twitter/bijection:thrift", - "3rdparty/jvm/com/twitter/bijection:util", - "3rdparty/jvm/com/twitter/storehaus:algebra", - "3rdparty/jvm/com/twitter/storehaus:core", - "3rdparty/src/jvm/com/twitter/scalding:commons", - "3rdparty/src/jvm/com/twitter/scalding:core", - "3rdparty/src/jvm/com/twitter/scalding:date", - "3rdparty/src/jvm/com/twitter/summingbird:batch", - "3rdparty/src/jvm/com/twitter/summingbird:core", - "src/java/com/twitter/ml/api:api-base", - "src/java/com/twitter/ml/api/constant", - "src/scala/com/twitter/summingbird_internal/bijection:bijection-implicits", - "src/thrift/com/twitter/dal/personal_data:personal_data-java", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:interpretable-model-java", - 
"src/thrift/com/twitter/summingbird", - "timelines/data_processing/ml_util/aggregation_framework:common_types", - "timelines/data_processing/ml_util/aggregation_framework/metrics", - "util/util-core:scala", - ], -) diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/BUILD.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/BUILD.docx new file mode 100644 index 000000000..06167389f Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/BUILD.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/CombineCountsPolicy.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/CombineCountsPolicy.docx new file mode 100644 index 000000000..2eb584811 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/CombineCountsPolicy.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/CombineCountsPolicy.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/CombineCountsPolicy.scala deleted file mode 100644 index eb1690231..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/CombineCountsPolicy.scala +++ /dev/null @@ -1,223 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.google.common.annotations.VisibleForTesting -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.ml.api.FeatureContext -import com.twitter.ml.api._ -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.TypedCountMetric -import java.lang.{Double => JDouble} -import scala.collection.JavaConverters._ - -case class CombinedFeatures( - sum: Feature[JDouble], - nonzero: Feature[JDouble], - mean: Feature[JDouble], - topK: Seq[Feature[JDouble]]) - -trait CombineCountsBase { - val SparseSum = "sparse_sum" - val SparseNonzero = "sparse_nonzero" - val SparseMean = "sparse_mean" - val SparseTop = "sparse_top" - - def topK: Int - def hardLimit: Option[Int] - def precomputedCountFeatures: Seq[Feature[_]] - - lazy val precomputedFeaturesMap: Map[Feature[_], CombinedFeatures] = - precomputedCountFeatures.map { countFeature => - val derivedPersonalDataTypes = - AggregationMetricCommon.derivePersonalDataTypes(Some(countFeature)) - val sum = new Feature.Continuous( - countFeature.getDenseFeatureName + "." + SparseSum, - derivedPersonalDataTypes) - val nonzero = new Feature.Continuous( - countFeature.getDenseFeatureName + "." + SparseNonzero, - derivedPersonalDataTypes) - val mean = new Feature.Continuous( - countFeature.getDenseFeatureName + "." + SparseMean, - derivedPersonalDataTypes) - val topKFeatures = (1 to topK).map { k => - new Feature.Continuous( - countFeature.getDenseFeatureName + "." 
+ SparseTop + k, - derivedPersonalDataTypes) - } - (countFeature, CombinedFeatures(sum, nonzero, mean, topKFeatures)) - }.toMap - - lazy val outputFeaturesPostMerge: Set[Feature[JDouble]] = - precomputedFeaturesMap.values.flatMap { combinedFeatures: CombinedFeatures => - Seq( - combinedFeatures.sum, - combinedFeatures.nonzero, - combinedFeatures.mean - ) ++ combinedFeatures.topK - }.toSet - - private case class ComputedStats(sum: Double, nonzero: Double, mean: Double) - - private def preComputeStats(featureValues: Seq[Double]): ComputedStats = { - val (sum, nonzero) = featureValues.foldLeft((0.0, 0.0)) { - case ((accSum, accNonzero), value) => - (accSum + value, if (value > 0.0) accNonzero + 1.0 else accNonzero) - } - ComputedStats(sum, nonzero, if (nonzero > 0.0) sum / nonzero else 0.0) - } - - private def computeSortedFeatureValues(featureValues: List[Double]): List[Double] = - featureValues.sortBy(-_) - - private def extractKth(sortedFeatureValues: Seq[Double], k: Int): Double = - sortedFeatureValues - .lift(k - 1) - .getOrElse(0.0) - - private def setContinuousFeatureIfNonZero( - record: SRichDataRecord, - feature: Feature[JDouble], - value: Double - ): Unit = - if (value != 0.0) { - record.setFeatureValue(feature, value) - } - - def hydrateCountFeatures( - richRecord: SRichDataRecord, - features: Seq[Feature[_]], - featureValuesMap: Map[Feature[_], List[Double]] - ): Unit = - for { - feature <- features - featureValues <- featureValuesMap.get(feature) - } { - mergeRecordFromCountFeature( - countFeature = feature, - featureValues = featureValues, - richInputRecord = richRecord - ) - } - - def mergeRecordFromCountFeature( - richInputRecord: SRichDataRecord, - countFeature: Feature[_], - featureValues: List[Double] - ): Unit = { - // In majority of calls to this method from timeline scorer - // the featureValues list is empty. - // While with empty list each operation will be not that expensive, these - // small things do add up. By adding early stop here we can avoid sorting - // empty list, allocating several options and making multiple function - // calls. In addition to that, we won't iterate over [1, topK]. 
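    // Worked example (illustrative): with featureValues = List(3.0, 1.0, 2.0), topK = 2 and no hardLimit,
    // sortedFeatureValues is [3.0, 2.0, 1.0], so the record gets sparse_sum = 6.0, sparse_nonzero = 3.0,
    // sparse_mean = 2.0, sparse_top1 = 3.0 and sparse_top2 = 2.0 (each suffixed onto the count feature name).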
- if (featureValues.nonEmpty) { - val sortedFeatureValues = hardLimit - .map { limit => - computeSortedFeatureValues(featureValues).take(limit) - }.getOrElse(computeSortedFeatureValues(featureValues)).toIndexedSeq - val computed = preComputeStats(sortedFeatureValues) - - val combinedFeatures = precomputedFeaturesMap(countFeature) - setContinuousFeatureIfNonZero( - richInputRecord, - combinedFeatures.sum, - computed.sum - ) - setContinuousFeatureIfNonZero( - richInputRecord, - combinedFeatures.nonzero, - computed.nonzero - ) - setContinuousFeatureIfNonZero( - richInputRecord, - combinedFeatures.mean, - computed.mean - ) - (1 to topK).foreach { k => - setContinuousFeatureIfNonZero( - richInputRecord, - combinedFeatures.topK(k - 1), - extractKth(sortedFeatureValues, k) - ) - } - } - } -} - -object CombineCountsPolicy { - def getCountFeatures(aggregateContext: FeatureContext): Seq[Feature[_]] = - aggregateContext.getAllFeatures.asScala.toSeq - .filter { feature => - feature.getFeatureType == FeatureType.CONTINUOUS && - feature.getDenseFeatureName.endsWith(TypedCountMetric[JDouble]().operatorName) - } - - @VisibleForTesting - private[conversion] def getFeatureValues( - dataRecordsWithCounts: List[DataRecord], - countFeature: Feature[_] - ): List[Double] = - dataRecordsWithCounts.map(new SRichDataRecord(_)).flatMap { record => - Option(record.getFeatureValue(countFeature)).map(_.asInstanceOf[JDouble].toDouble) - } -} - -/** - * A merge policy that works whenever all aggregate features are - * counts (computed using CountMetric), and typically represent - * either impressions or engagements. For each such input count - * feature, the policy outputs the following (3+k) derived features - * into the output data record: - * - * Sum of the feature's value across all aggregate records - * Number of aggregate records that have the feature set to non-zero - * Mean of the feature's value across all aggregate records - * topK values of the feature across all aggregate records - * - * @param topK topK values to compute - * @param hardLimit when set, records are sorted and only the top values will be used for aggregation if - * the number of records are higher than this hard limit. 
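 *
 * Hypothetical construction sketch (the context, records and limits below are placeholders). For each
 * precomputed count feature <name>, the merged record gains "<name>.sparse_sum", "<name>.sparse_nonzero",
 * "<name>.sparse_mean" and "<name>.sparse_top1" .. "<name>.sparse_top<topK>":
 * {{{
 * val policy = CombineCountsPolicy(
 *   topK = 2,
 *   aggregateContextToPrecompute = aggregateContext, // FeatureContext of the aggregate records
 *   hardLimit = Some(100)
 * )
 * policy.mergeRecord(mutableInputRecord, aggregateRecords, aggregateContext)
 * }}}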
- */ -case class CombineCountsPolicy( - override val topK: Int, - aggregateContextToPrecompute: FeatureContext, - override val hardLimit: Option[Int] = None) - extends SparseBinaryMergePolicy - with CombineCountsBase { - import CombineCountsPolicy._ - override val precomputedCountFeatures: Seq[Feature[_]] = getCountFeatures( - aggregateContextToPrecompute) - - override def mergeRecord( - mutableInputRecord: DataRecord, - aggregateRecords: List[DataRecord], - aggregateContext: FeatureContext - ): Unit = { - // Assumes aggregateContext === aggregateContextToPrecompute - mergeRecordFromCountFeatures(mutableInputRecord, aggregateRecords, precomputedCountFeatures) - } - - def defaultMergeRecord( - mutableInputRecord: DataRecord, - aggregateRecords: List[DataRecord] - ): Unit = { - mergeRecordFromCountFeatures(mutableInputRecord, aggregateRecords, precomputedCountFeatures) - } - - def mergeRecordFromCountFeatures( - mutableInputRecord: DataRecord, - aggregateRecords: List[DataRecord], - countFeatures: Seq[Feature[_]] - ): Unit = { - val richInputRecord = new SRichDataRecord(mutableInputRecord) - countFeatures.foreach { countFeature => - mergeRecordFromCountFeature( - richInputRecord = richInputRecord, - countFeature = countFeature, - featureValues = getFeatureValues(aggregateRecords, countFeature) - ) - } - } - - override def aggregateFeaturesPostMerge(aggregateContext: FeatureContext): Set[Feature[_]] = - outputFeaturesPostMerge.map(_.asInstanceOf[Feature[_]]) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/DataSetPipeSketchJoin.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/DataSetPipeSketchJoin.docx new file mode 100644 index 000000000..e666b0cdf Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/DataSetPipeSketchJoin.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/DataSetPipeSketchJoin.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/DataSetPipeSketchJoin.scala deleted file mode 100644 index 8d3dd58bb..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/DataSetPipeSketchJoin.scala +++ /dev/null @@ -1,46 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.twitter.bijection.Injection -import com.twitter.ml.api._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.scalding.TypedPipe - -object DataSetPipeSketchJoin { - val DefaultSketchNumReducers = 500 - val dataRecordMerger: DataRecordMerger = new DataRecordMerger - implicit val str2Byte: String => Array[Byte] = - implicitly[Injection[String, Array[Byte]]].toFunction - - /* Computes a left sketch join on a set of skewed keys. 
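 *
 * Hypothetical usage (the pipes and the author-id feature are placeholders): every record of the input
 * pipe is kept, and where the key values match it is merged with the features of joinFeaturesDataSet.
 * {{{
 * val joined: DataSetPipe = DataSetPipeSketchJoin(
 *   inputDataSet = labeledRecords,
 *   skewedJoinKeys = (SharedFeatures.USER_ID, authorIdFeature), // a Product of Feature[_] join keys
 *   joinFeaturesDataSet = userAuthorAggregates
 * )
 * }}}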
*/ - def apply( - inputDataSet: DataSetPipe, - skewedJoinKeys: Product, - joinFeaturesDataSet: DataSetPipe, - sketchNumReducers: Int = DefaultSketchNumReducers - ): DataSetPipe = { - val joinKeyList = skewedJoinKeys.productIterator.toList.asInstanceOf[List[Feature[_]]] - - def makeKey(record: DataRecord): String = - joinKeyList - .map(SRichDataRecord(record).getFeatureValue(_)) - .toString - - def byKey(pipe: DataSetPipe): TypedPipe[(String, DataRecord)] = - pipe.records.map(record => (makeKey(record), record)) - - val joinedRecords = byKey(inputDataSet) - .sketch(sketchNumReducers) - .leftJoin(byKey(joinFeaturesDataSet)) - .values - .map { - case (inputRecord, joinFeaturesOpt) => - joinFeaturesOpt.foreach { joinRecord => dataRecordMerger.merge(inputRecord, joinRecord) } - inputRecord - } - - DataSetPipe( - joinedRecords, - FeatureContext.merge(inputDataSet.featureContext, joinFeaturesDataSet.featureContext) - ) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/PickFirstRecordPolicy.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/PickFirstRecordPolicy.docx new file mode 100644 index 000000000..1f3a7654c Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/PickFirstRecordPolicy.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/PickFirstRecordPolicy.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/PickFirstRecordPolicy.scala deleted file mode 100644 index b022d35b0..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/PickFirstRecordPolicy.scala +++ /dev/null @@ -1,26 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.twitter.ml.api._ -import com.twitter.ml.api.FeatureContext -import scala.collection.JavaConverters._ - -/* - * A really bad default merge policy that picks all the aggregate - * features corresponding to the first sparse key value in the list. - * Does not rename any of the aggregate features for simplicity. - * Avoid using this merge policy if at all possible. 
- */ -object PickFirstRecordPolicy extends SparseBinaryMergePolicy { - val dataRecordMerger: DataRecordMerger = new DataRecordMerger - - override def mergeRecord( - mutableInputRecord: DataRecord, - aggregateRecords: List[DataRecord], - aggregateContext: FeatureContext - ): Unit = - aggregateRecords.headOption - .foreach(aggregateRecord => dataRecordMerger.merge(mutableInputRecord, aggregateRecord)) - - override def aggregateFeaturesPostMerge(aggregateContext: FeatureContext): Set[Feature[_]] = - aggregateContext.getAllFeatures.asScala.toSet -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/PickTopCtrPolicy.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/PickTopCtrPolicy.docx new file mode 100644 index 000000000..3f82dc38f Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/PickTopCtrPolicy.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/PickTopCtrPolicy.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/PickTopCtrPolicy.scala deleted file mode 100644 index 94d3ac126..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/PickTopCtrPolicy.scala +++ /dev/null @@ -1,226 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.twitter.ml.api._ -import com.twitter.ml.api.FeatureContext -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon -import java.lang.{Boolean => JBoolean} -import java.lang.{Double => JDouble} - -case class CtrDescriptor( - engagementFeature: Feature[JDouble], - impressionFeature: Feature[JDouble], - outputFeature: Feature[JDouble]) - -object PickTopCtrBuilderHelper { - - def createCtrDescriptors( - aggregatePrefix: String, - engagementLabels: Set[Feature[JBoolean]], - aggregatesToCompute: Set[TypedAggregateGroup[_]], - outputSuffix: String - ): Set[CtrDescriptor] = { - val aggregateFeatures = aggregatesToCompute - .filter(_.aggregatePrefix == aggregatePrefix) - - val impressionFeature = aggregateFeatures - .flatMap { group => - group.individualAggregateDescriptors - .filter(_.query.feature == None) - .filter(_.query.label == None) - .flatMap(_.outputFeatures) - } - .head - .asInstanceOf[Feature[JDouble]] - - val aggregateEngagementFeatures = - aggregateFeatures - .flatMap { group => - group.individualAggregateDescriptors - .filter(_.query.feature == None) - .filter { descriptor => - //TODO: we should remove the need to pass around engagementLabels and just use all the labels available. - descriptor.query.label.exists(engagementLabels.contains(_)) - } - .flatMap(_.outputFeatures) - } - .map(_.asInstanceOf[Feature[JDouble]]) - - aggregateEngagementFeatures - .map { aggregateEngagementFeature => - CtrDescriptor( - engagementFeature = aggregateEngagementFeature, - impressionFeature = impressionFeature, - outputFeature = new Feature.Continuous( - aggregateEngagementFeature.getDenseFeatureName + "." 
+ outputSuffix, - AggregationMetricCommon.derivePersonalDataTypes( - Some(aggregateEngagementFeature), - Some(impressionFeature) - ) - ) - ) - } - } -} - -object PickTopCtrPolicy { - def build( - aggregatePrefix: String, - engagementLabels: Set[Feature[JBoolean]], - aggregatesToCompute: Set[TypedAggregateGroup[_]], - smoothing: Double = 1.0, - outputSuffix: String = "ratio" - ): PickTopCtrPolicy = { - val ctrDescriptors = PickTopCtrBuilderHelper.createCtrDescriptors( - aggregatePrefix = aggregatePrefix, - engagementLabels = engagementLabels, - aggregatesToCompute = aggregatesToCompute, - outputSuffix = outputSuffix - ) - PickTopCtrPolicy( - ctrDescriptors = ctrDescriptors, - smoothing = smoothing - ) - } -} - -object CombinedTopNCtrsByWilsonConfidenceIntervalPolicy { - def build( - aggregatePrefix: String, - engagementLabels: Set[Feature[JBoolean]], - aggregatesToCompute: Set[TypedAggregateGroup[_]], - outputSuffix: String = "ratioWithWCI", - z: Double = 1.96, - topN: Int = 1 - ): CombinedTopNCtrsByWilsonConfidenceIntervalPolicy = { - val ctrDescriptors = PickTopCtrBuilderHelper.createCtrDescriptors( - aggregatePrefix = aggregatePrefix, - engagementLabels = engagementLabels, - aggregatesToCompute = aggregatesToCompute, - outputSuffix = outputSuffix - ) - CombinedTopNCtrsByWilsonConfidenceIntervalPolicy( - ctrDescriptors = ctrDescriptors, - z = z, - topN = topN - ) - } -} - -/* - * A merge policy that picks the aggregate features corresponding to - * the sparse key value with the highest engagement rate (defined - * as the ratio of two specified features, representing engagements - * and impressions). Also outputs the engagement rate to the specified - * outputFeature. - * - * This is an abstract class. We can make variants of this policy by overriding - * the calculateCtr method. - */ - -abstract class PickTopCtrPolicyBase(ctrDescriptors: Set[CtrDescriptor]) - extends SparseBinaryMergePolicy { - - private def getContinuousFeature( - aggregateRecord: DataRecord, - feature: Feature[JDouble] - ): Double = { - Option(SRichDataRecord(aggregateRecord).getFeatureValue(feature)) - .map(_.asInstanceOf[JDouble].toDouble) - .getOrElse(0.0) - } - - /** - * For every provided descriptor, compute the corresponding CTR feature - * and only hydrate this result to the provided input record. 
- */ - override def mergeRecord( - mutableInputRecord: DataRecord, - aggregateRecords: List[DataRecord], - aggregateContext: FeatureContext - ): Unit = { - ctrDescriptors - .foreach { - case CtrDescriptor(engagementFeature, impressionFeature, outputFeature) => - val sortedCtrs = - aggregateRecords - .map { aggregateRecord => - val impressions = getContinuousFeature(aggregateRecord, impressionFeature) - val engagements = getContinuousFeature(aggregateRecord, engagementFeature) - calculateCtr(impressions, engagements) - } - .sortBy { ctr => -ctr } - combineTopNCtrsToSingleScore(sortedCtrs) - .foreach { score => - SRichDataRecord(mutableInputRecord).setFeatureValue(outputFeature, score) - } - } - } - - protected def calculateCtr(impressions: Double, engagements: Double): Double - - protected def combineTopNCtrsToSingleScore(sortedCtrs: Seq[Double]): Option[Double] - - override def aggregateFeaturesPostMerge(aggregateContext: FeatureContext): Set[Feature[_]] = - ctrDescriptors - .map(_.outputFeature) - .toSet -} - -case class PickTopCtrPolicy(ctrDescriptors: Set[CtrDescriptor], smoothing: Double = 1.0) - extends PickTopCtrPolicyBase(ctrDescriptors) { - require(smoothing > 0.0) - - override def calculateCtr(impressions: Double, engagements: Double): Double = - (1.0 * engagements) / (smoothing + impressions) - - override def combineTopNCtrsToSingleScore(sortedCtrs: Seq[Double]): Option[Double] = - sortedCtrs.headOption -} - -case class CombinedTopNCtrsByWilsonConfidenceIntervalPolicy( - ctrDescriptors: Set[CtrDescriptor], - z: Double = 1.96, - topN: Int = 1) - extends PickTopCtrPolicyBase(ctrDescriptors) { - - private val zSquared = z * z - private val zSquaredDiv2 = zSquared / 2.0 - private val zSquaredDiv4 = zSquared / 4.0 - - /** - * calculates the lower bound of wilson score interval. which roughly says "the actual engagement - * rate is at least this value" with confidence designated by the z-score: - * https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval - */ - override def calculateCtr(rawImpressions: Double, engagements: Double): Double = { - // just in case engagements happens to be more than impressions... - val impressions = Math.max(rawImpressions, engagements) - - if (impressions > 0.0) { - val p = engagements / impressions - (p - + zSquaredDiv2 / impressions - - z * Math.sqrt( - (p * (1.0 - p) + zSquaredDiv4 / impressions) / impressions)) / (1.0 + zSquared / impressions) - - } else 0.0 - } - - /** - * takes the topN engagement rates, and returns the joint probability as {1.0 - Π(1.0 - p)} - * - * e.g. let's say you have 0.6 chance of clicking on a tweet shared by the user A. - * you also have 0.3 chance of clicking on a tweet shared by the user B. - * seeing a tweet shared by both A and B will not lead to 0.9 chance of you clicking on it. - * but you could say that you have 0.4*0.7 chance of NOT clicking on that tweet. 
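- * (i.e. with topN = 2 in this example, the combined score would be 1.0 - 0.4 * 0.7 = 0.72)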
- */ - override def combineTopNCtrsToSingleScore(sortedCtrs: Seq[Double]): Option[Double] = - if (sortedCtrs.nonEmpty) { - val inverseLogP = sortedCtrs - .take(topN).map { p => Math.log(1.0 - p) }.sum - Some(1.0 - Math.exp(inverseLogP)) - } else None - -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryAggregateJoin.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryAggregateJoin.docx new file mode 100644 index 000000000..763b69f4d Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryAggregateJoin.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryAggregateJoin.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryAggregateJoin.scala deleted file mode 100644 index 10c6a9096..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryAggregateJoin.scala +++ /dev/null @@ -1,199 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.twitter.ml.api._ -import com.twitter.ml.api.Feature -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.scalding.typed.TypedPipe -import com.twitter.scalding.typed.UnsortedGrouped -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup -import java.util.{Set => JSet} -import scala.collection.JavaConverters._ - -object SparseBinaryAggregateJoin { - import TypedAggregateGroup._ - - def makeKey(record: DataRecord, joinKeyList: List[Feature[_]]): String = { - joinKeyList.map { - case sparseKey: Feature.SparseBinary => - SRichDataRecord(record).getFeatureValue(sparseFeature(sparseKey)) - case nonSparseKey: Feature[_] => - SRichDataRecord(record).getFeatureValue(nonSparseKey) - }.toString - } - - /** - * @param record Data record to get all possible sparse aggregate keys from - * @param List of join key features (some can be sparse and some non-sparse) - * @return A list of string keys to use for joining - */ - def makeKeyPermutations(record: DataRecord, joinKeyList: List[Feature[_]]): List[String] = { - val allIdValues = joinKeyList.flatMap { - case sparseKey: Feature.SparseBinary => { - val id = sparseKey.getDenseFeatureId - val valuesOpt = Option(SRichDataRecord(record).getFeatureValue(sparseKey)) - .map(_.asInstanceOf[JSet[String]].asScala.toSet) - valuesOpt.map { (id, _) } - } - case nonSparseKey: Feature[_] => { - val id = nonSparseKey.getDenseFeatureId - Option(SRichDataRecord(record).getFeatureValue(nonSparseKey)).map { value => - (id, Set(value.toString)) - } - } - } - sparseBinaryPermutations(allIdValues).toList.map { idValues => - joinKeyList.map { key => idValues.getOrElse(key.getDenseFeatureId, "") }.toString - } - } - - private[this] def mkKeyIndexedAggregates( - joinFeaturesDataSet: DataSetPipe, - joinKeyList: List[Feature[_]] - ): TypedPipe[(String, DataRecord)] = - joinFeaturesDataSet.records - .map { record => (makeKey(record, joinKeyList), record) } - - private[this] def mkKeyIndexedInput( - inputDataSet: DataSetPipe, - joinKeyList: List[Feature[_]] - ): TypedPipe[(String, DataRecord)] = - inputDataSet.records - .flatMap { record => - for { - key <- makeKeyPermutations(record, joinKeyList) - } yield { (key, record) } - } - - private[this] def mkKeyIndexedInputWithUniqueId( - inputDataSet: DataSetPipe, - joinKeyList: List[Feature[_]], - uniqueIdFeatureList: List[Feature[_]] - ): 
TypedPipe[(String, String)] = - inputDataSet.records - .flatMap { record => - for { - key <- makeKeyPermutations(record, joinKeyList) - } yield { (key, makeKey(record, uniqueIdFeatureList)) } - } - - private[this] def mkRecordIndexedAggregates( - keyIndexedInput: TypedPipe[(String, DataRecord)], - keyIndexedAggregates: TypedPipe[(String, DataRecord)] - ): UnsortedGrouped[DataRecord, List[DataRecord]] = - keyIndexedInput - .join(keyIndexedAggregates) - .map { case (_, (inputRecord, aggregateRecord)) => (inputRecord, aggregateRecord) } - .group - .toList - - private[this] def mkRecordIndexedAggregatesWithUniqueId( - keyIndexedInput: TypedPipe[(String, String)], - keyIndexedAggregates: TypedPipe[(String, DataRecord)] - ): UnsortedGrouped[String, List[DataRecord]] = - keyIndexedInput - .join(keyIndexedAggregates) - .map { case (_, (inputId, aggregateRecord)) => (inputId, aggregateRecord) } - .group - .toList - - def mkJoinedDataSet( - inputDataSet: DataSetPipe, - joinFeaturesDataSet: DataSetPipe, - recordIndexedAggregates: UnsortedGrouped[DataRecord, List[DataRecord]], - mergePolicy: SparseBinaryMergePolicy - ): TypedPipe[DataRecord] = - inputDataSet.records - .map(record => (record, ())) - .leftJoin(recordIndexedAggregates) - .map { - case (inputRecord, (_, aggregateRecordsOpt)) => - aggregateRecordsOpt - .map { aggregateRecords => - mergePolicy.mergeRecord( - inputRecord, - aggregateRecords, - joinFeaturesDataSet.featureContext - ) - inputRecord - } - .getOrElse(inputRecord) - } - - def mkJoinedDataSetWithUniqueId( - inputDataSet: DataSetPipe, - joinFeaturesDataSet: DataSetPipe, - recordIndexedAggregates: UnsortedGrouped[String, List[DataRecord]], - mergePolicy: SparseBinaryMergePolicy, - uniqueIdFeatureList: List[Feature[_]] - ): TypedPipe[DataRecord] = - inputDataSet.records - .map(record => (makeKey(record, uniqueIdFeatureList), record)) - .leftJoin(recordIndexedAggregates) - .map { - case (_, (inputRecord, aggregateRecordsOpt)) => - aggregateRecordsOpt - .map { aggregateRecords => - mergePolicy.mergeRecord( - inputRecord, - aggregateRecords, - joinFeaturesDataSet.featureContext - ) - inputRecord - } - .getOrElse(inputRecord) - } - - /** - * If uniqueIdFeatures is non-empty and the join keys include a sparse binary - * key, the join will use this set of keys as a unique id to reduce - * memory consumption. You should need this option only for - * memory-intensive joins to avoid OOM errors. 
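- * (When unique ids are used, the intermediate join carries only the unique id string for each
- * input record rather than the full DataRecord, which is what saves memory.)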
- */ - def apply( - inputDataSet: DataSetPipe, - joinKeys: Product, - joinFeaturesDataSet: DataSetPipe, - mergePolicy: SparseBinaryMergePolicy = PickFirstRecordPolicy, - uniqueIdFeaturesOpt: Option[Product] = None - ): DataSetPipe = { - val joinKeyList = joinKeys.productIterator.toList.asInstanceOf[List[Feature[_]]] - val sparseBinaryJoinKeySet = - joinKeyList.toSet.filter(_.getFeatureType() == FeatureType.SPARSE_BINARY) - val containsSparseBinaryKey = !sparseBinaryJoinKeySet.isEmpty - if (containsSparseBinaryKey) { - val uniqueIdFeatureList = uniqueIdFeaturesOpt - .map(uniqueIdFeatures => - uniqueIdFeatures.productIterator.toList.asInstanceOf[List[Feature[_]]]) - .getOrElse(List.empty[Feature[_]]) - val keyIndexedAggregates = mkKeyIndexedAggregates(joinFeaturesDataSet, joinKeyList) - val joinedDataSet = if (uniqueIdFeatureList.isEmpty) { - val keyIndexedInput = mkKeyIndexedInput(inputDataSet, joinKeyList) - val recordIndexedAggregates = - mkRecordIndexedAggregates(keyIndexedInput, keyIndexedAggregates) - mkJoinedDataSet(inputDataSet, joinFeaturesDataSet, recordIndexedAggregates, mergePolicy) - } else { - val keyIndexedInput = - mkKeyIndexedInputWithUniqueId(inputDataSet, joinKeyList, uniqueIdFeatureList) - val recordIndexedAggregates = - mkRecordIndexedAggregatesWithUniqueId(keyIndexedInput, keyIndexedAggregates) - mkJoinedDataSetWithUniqueId( - inputDataSet, - joinFeaturesDataSet, - recordIndexedAggregates, - mergePolicy, - uniqueIdFeatureList - ) - } - - DataSetPipe( - joinedDataSet, - mergePolicy.mergeContext( - inputDataSet.featureContext, - joinFeaturesDataSet.featureContext - ) - ) - } else { - inputDataSet.joinWithSmaller(joinKeys, joinFeaturesDataSet) { _.pass } - } - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMergePolicy.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMergePolicy.docx new file mode 100644 index 000000000..9fe54e679 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMergePolicy.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMergePolicy.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMergePolicy.scala deleted file mode 100644 index 7201e39a2..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMergePolicy.scala +++ /dev/null @@ -1,81 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.twitter.ml.api._ -import com.twitter.ml.api.FeatureContext -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup -import scala.collection.JavaConverters._ - -/** - * When using the aggregates framework to group by sparse binary keys, - * we generate different aggregate feature values for each possible - * value of the sparse key. Hence, when joining back the aggregate - * features with a training data set, each individual training record - * has multiple aggregate features to choose from, for each value taken - * by the sparse key(s) in the training record. The merge policy trait - * below specifies how to condense/combine this variable number of - * aggregate features into a constant number of features for training. - * Some simple policies might be: pick the first feature set (randomly), - * pick the top sorted by some attribute, or take some average. 
- * - * Example: suppose we group by (ADVERTISER_ID, INTEREST_ID) where INTEREST_ID - * is the sparse key, and compute a "CTR" aggregate feature for each such - * pair measuring the click through rate on ads with (ADVERTISER_ID, INTEREST_ID). - * Say we have the following aggregate records: - * - * (ADVERTISER_ID = 1, INTEREST_ID = 1, CTR = 5%) - * (ADVERTISER_ID = 1, INTEREST_ID = 2, CTR = 15%) - * (ADVERTISER_ID = 2, INTEREST_ID = 1, CTR = 1%) - * (ADVERTISER_ID = 2, INTEREST_ID = 2, CTR = 10%) - * ... - * At training time, each training record has one value for ADVERTISER_ID, but it - * has multiple values for INTEREST_ID e.g. - * - * (ADVERTISER_ID = 1, INTEREST_IDS = (1,2)) - * - * There are multiple potential CTRs we can get when joining in the aggregate features: - * in this case 2 values (5% and 15%) but in general it could be many depending on how - * many interests the user has. When joining back the CTR features, the merge policy says how to - * combine all these CTRs to engineer features. - * - * "Pick first" would say - pick some random CTR (whatever is first in the list, maybe 5%) - * for training (probably not a good policy). "Sort by CTR" could be a policy - * that just picks the top CTR and uses it as a feature (here 15%). Similarly, you could - * imagine "Top K sorted by CTR" (use both 5 and 15%) or "Avg CTR" (10%) or other policies, - * all of which are defined as objects/case classes that override this trait. - */ -trait SparseBinaryMergePolicy { - - /** - * @param mutableInputRecord Input record to add aggregates to - * @param aggregateRecords Aggregate feature records - * @param aggregateContext Context for aggregate records - */ - def mergeRecord( - mutableInputRecord: DataRecord, - aggregateRecords: List[DataRecord], - aggregateContext: FeatureContext - ): Unit - - def aggregateFeaturesPostMerge(aggregateContext: FeatureContext): Set[Feature[_]] - - /** - * @param inputContext Context for input record - * @param aggregateContext Context for aggregate records - * @return Context for record returned by mergeRecord() - */ - def mergeContext( - inputContext: FeatureContext, - aggregateContext: FeatureContext - ): FeatureContext = new FeatureContext( - (inputContext.getAllFeatures.asScala.toSet ++ aggregateFeaturesPostMerge( - aggregateContext)).toSeq.asJava - ) - - def allOutputFeaturesPostMergePolicy[T](config: TypedAggregateGroup[T]): Set[Feature[_]] = { - val containsSparseBinary = config.keysToAggregate - .exists(_.getFeatureType == FeatureType.SPARSE_BINARY) - - if (!containsSparseBinary) config.allOutputFeatures - else aggregateFeaturesPostMerge(new FeatureContext(config.allOutputFeatures.toSeq.asJava)) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMultipleAggregateJoin.docx b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMultipleAggregateJoin.docx new file mode 100644 index 000000000..a30f20d46 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMultipleAggregateJoin.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMultipleAggregateJoin.scala b/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMultipleAggregateJoin.scala deleted file mode 100644 index d0aff7e34..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/conversion/SparseBinaryMultipleAggregateJoin.scala +++ /dev/null @@ -1,109 +0,0 @@ -package 
com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion - -import com.twitter.bijection.Injection -import com.twitter.ml.api._ -import com.twitter.ml.api.Feature -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.scalding.typed.TypedPipe -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup.sparseFeature -import scala.collection.JavaConverters._ - -case class SparseJoinConfig( - aggregates: DataSetPipe, - sparseKey: Feature.SparseBinary, - mergePolicies: SparseBinaryMergePolicy*) - -object SparseBinaryMultipleAggregateJoin { - type CommonMap = (String, ((Feature.SparseBinary, String), DataRecord)) - - def apply( - source: DataSetPipe, - commonKey: Feature[_], - joinConfigs: Set[SparseJoinConfig], - rightJoin: Boolean = false, - isSketchJoin: Boolean = false, - numSketchJoinReducers: Int = 0 - ): DataSetPipe = { - val emptyPipe: TypedPipe[CommonMap] = TypedPipe.empty - val aggregateMaps: Set[TypedPipe[CommonMap]] = joinConfigs.map { joinConfig => - joinConfig.aggregates.records.map { record => - val sparseKeyValue = - SRichDataRecord(record).getFeatureValue(sparseFeature(joinConfig.sparseKey)).toString - val commonKeyValue = SRichDataRecord(record).getFeatureValue(commonKey).toString - (commonKeyValue, ((joinConfig.sparseKey, sparseKeyValue), record)) - } - } - - val commonKeyToAggregateMap = aggregateMaps - .foldLeft(emptyPipe) { - case (union: TypedPipe[CommonMap], next: TypedPipe[CommonMap]) => - union ++ next - } - .group - .toList - .map { - case (commonKeyValue, aggregateTuples) => - (commonKeyValue, aggregateTuples.toMap) - } - - val commonKeyToRecordMap = source.records - .map { record => - val commonKeyValue = SRichDataRecord(record).getFeatureValue(commonKey).toString - (commonKeyValue, record) - } - - // rightJoin is not supported by Sketched, so rightJoin will be ignored if isSketchJoin is set - implicit val string2Byte = (value: String) => Injection[String, Array[Byte]](value) - val intermediateRecords = if (isSketchJoin) { - commonKeyToRecordMap.group - .sketch(numSketchJoinReducers) - .leftJoin(commonKeyToAggregateMap) - .toTypedPipe - } else if (rightJoin) { - commonKeyToAggregateMap - .rightJoin(commonKeyToRecordMap) - .mapValues(_.swap) - .toTypedPipe - } else { - commonKeyToRecordMap.leftJoin(commonKeyToAggregateMap).toTypedPipe - } - - val joinedRecords = intermediateRecords - .map { - case (commonKeyValue, (inputRecord, aggregateTupleMapOpt)) => - aggregateTupleMapOpt.foreach { aggregateTupleMap => - joinConfigs.foreach { joinConfig => - val sparseKeyValues = Option( - SRichDataRecord(inputRecord) - .getFeatureValue(joinConfig.sparseKey) - ).map(_.asScala.toList) - .getOrElse(List.empty[String]) - - val aggregateRecords = sparseKeyValues.flatMap { sparseKeyValue => - aggregateTupleMap.get((joinConfig.sparseKey, sparseKeyValue)) - } - - joinConfig.mergePolicies.foreach { mergePolicy => - mergePolicy.mergeRecord( - inputRecord, - aggregateRecords, - joinConfig.aggregates.featureContext - ) - } - } - } - inputRecord - } - - val joinedFeatureContext = joinConfigs - .foldLeft(source.featureContext) { - case (left, joinConfig) => - joinConfig.mergePolicies.foldLeft(left) { - case (soFar, mergePolicy) => - mergePolicy.mergeContext(soFar, joinConfig.aggregates.featureContext) - } - } - - DataSetPipe(joinedRecords, joinedFeatureContext) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/AUTOMATED_COMMIT_FILES 
b/timelines/data_processing/ml_util/aggregation_framework/docs/AUTOMATED_COMMIT_FILES deleted file mode 100644 index 80aaae8d9..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/docs/AUTOMATED_COMMIT_FILES +++ /dev/null @@ -1,5 +0,0 @@ -aggregation.rst -batch.rst -index.rst -real-time.rst -troubleshooting.rst diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/AUTOMATED_COMMIT_FILES.docx b/timelines/data_processing/ml_util/aggregation_framework/docs/AUTOMATED_COMMIT_FILES.docx new file mode 100644 index 000000000..010636351 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/docs/AUTOMATED_COMMIT_FILES.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/aggregation.docx b/timelines/data_processing/ml_util/aggregation_framework/docs/aggregation.docx new file mode 100644 index 000000000..e4b6758c7 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/docs/aggregation.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/aggregation.rst b/timelines/data_processing/ml_util/aggregation_framework/docs/aggregation.rst deleted file mode 100644 index fddd926b4..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/docs/aggregation.rst +++ /dev/null @@ -1,167 +0,0 @@ -.. _aggregation: - -Core Concepts -============= - -This page provides an overview of the aggregation framework and goes through examples on how to define aggregate features. In general, we can think of an aggregate feature as a grouped set of records, on which we incrementally update the aggregate feature values, crossed by the provided features and conditional on the provided labels. - -AggregateGroup --------------- - -An `AggregateGroup` defines a single unit of aggregate computation, similar to a SQL query. These are executed by the underlying jobs (internally, a `DataRecordAggregationMonoid `_ is applied to `DataRecords` that contain the features to aggregate). Many of these groups can exist to define different types of aggregate features. - -Let's start with the following examples of an `AggregateGroup` to discuss the meaning of each of its constructor arguments: - -.. code-block:: scala - - val UserAggregateStore = "user_aggregates" - val aggregatesToCompute: Set[TypedAggregateGroup[_]] = Set( - AggregateGroup( - inputSource = timelinesDailyRecapSource, - aggregatePrefix = "user_aggregate_v2", - preTransformOpt = Some(RemoveUserIdZero), - keys = Set(USER_ID), - features = Set(HAS_PHOTO), - labels = Set(IS_FAVORITED), - metrics = Set(CountMetric, SumMetric), - halfLives = Set(50.days), - outputStore = OfflineAggregateStore( - name = UserAggregateStore, - startDate = "2016-07-15 00:00", - commonConfig = timelinesDailyAggregateSink, - batchesToKeep = 5 - ) - ) - .flatMap(_.buildTypedAggregateGroups) - ) - -This `AggregateGroup` computes the number of times each user has faved a tweet with a photo. The aggregate count is decayed with a 50 day halflife. - -Naming and preprocessing ------------------------- - -`UserAggregateStore` is a string val that acts as a scope of a "root path" to which this group of aggregate features will be written. The root path is provided separately by the implementing job. - -`inputSource` defines the input source of `DataRecords` that we aggregate on. These records contain the relevant features required for aggregation. 
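-
-For the example above, each input `DataRecord` needs to carry the key (`USER_ID`), the feature (`HAS_PHOTO`), the label (`IS_FAVORITED`) and a timestamp used for decay. A minimal, purely illustrative sketch of such a record (the feature constants and imports are assumed to be defined elsewhere in your feature configuration):
-
-.. code-block:: scala
-
-    // Illustrative only: one input record carrying the key, feature,
-    // label and timestamp referenced by the example AggregateGroup.
-    val record = new DataRecord
-    SRichDataRecord(record).setFeatureValue(USER_ID, 123456789L)
-    SRichDataRecord(record).setFeatureValue(HAS_PHOTO, true)
-    SRichDataRecord(record).setFeatureValue(IS_FAVORITED, true)
-    SRichDataRecord(record).setFeatureValue(TIMESTAMP, System.currentTimeMillis)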
- -`aggregatePrefix` tells the framework what prefix to use for the aggregate features it generates. A descriptive naming scheme with versioning makes it easier to maintain features as you add or remove them over the long-term. - -`preTransforms` is a `Seq[com.twitter.ml.api.ITransform] `_ that can be applied to the data records read from the input source before they are fed into the `AggregateGroup` to apply aggregation. These transforms are optional but can be useful for certain preprocessing operations for a group's raw input features. - -.. admonition:: Examples - - You can downsample input data records by providing `preTransforms`. In addition, you could also join different input labels (e.g. "is_push_openend" and "is_push_favorited") and transform them into a combined label that is their union ("is_push_engaged") on which aggregate counts will be calculated. - - -Keys ----- - -`keys` is a crucial field in the config. It defines a `Set[com.twitter.ml.api.Feature]` which specifies a set of grouping keys to use for this `AggregateGroup`. - -Keys can only be of 3 supported types currently: `DISCRETE`, `STRING` and `SPARSE_BINARY`. Using a discrete or a string/text feature as a key specifies the unit to group records by before applying counting/aggregation operators. - - -.. admonition:: Examples - - .. cssclass:: shortlist - - #. If the key is `USER_ID`, this tells the framework to group all records by `USER_ID`, and then apply aggregations (sum/count/etc) within each user’s data to generate aggregate features for each user. - - #. If the key is `(USER_ID, AUTHOR_ID)`, then the `AggregateGroup` will output features for each unique user-author pair in the input data. - - #. Finally, using a sparse binary feature as key has special "flattening" or "flatMap" like semantics. For example, consider grouping by `(USER_ID, AUTHOR_INTEREST_IDS)` where `AUTHOR_INTEREST_IDS` is a sparse binary feature which represents a set of topic IDs the author may be tweeting about. This creates one record for each `(user_id, interest_id)` pair - so each record with multiple author interests is flattened before feeding it to the aggregation. - -Features --------- - -`features` specifies a `Set[com.twitter.ml.api.Feature]` to aggregate within each group (defined by the keys specified earlier). - -We support 2 types of `features`: `BINARY` and `CONTINUOUS`. - -The semantics of how the aggregation works is slightly different based on the type of “feature”, and based on the “metric” (or aggregation operation): - -.. cssclass:: shortlist - -#. Binary Feature, Count Metric: Suppose we have a binary feature `HAS_PHOTO` in this set, and are applying the “Count” metric (see below for more details on the metrics), with key `USER_ID`. The semantics is that this computes a feature which measures the count of records with `HAS_PHOTO` set to true for each user. - -#. Binary Feature, Sum Metric - Does not apply. No feature will be computed. - -#. Continuous Feature, Count Metric - The count metric treats all features as binary features ignoring their value. For example, suppose we have a continuous feature `NUM_CHARACTERS_IN_TWEET`, and key `USER_ID`. This measures the count of records that have this feature `NUM_CHARACTERS_IN_TWEET` present. - -#. Continuous Feature, Sum Metric - In the above example, the features measures the sum of (num_characters_in_tweet) over all a user’s records. Dividing this sum feature by the count feature would give the average number of characters in all tweets. - -.. 
admonition:: Unsupported feature types - - `DISCRETE` and `SPARSE` features are not supported by the Sum Metric, because there is no meaning in summing a discrete feature or a sparse feature. You can use them with the CountMetric, but they may not do what you would expect since they will be treated as binary features losing all the information within the feature. The best way to use these is as “keys” and not as “features”. - -.. admonition:: Setting includeAnyFeature - - If constructor argument `includeAnyFeature` is set, the framework will append a feature with scope `any_feature` to the set of all features you define. This additional feature simply measures the total count of records. So if you set your features to be equal to Set.empty, this will measure the count of records for a given `USER_ID`. - -Labels ------- - -`labels` specifies a set of `BINARY` features that you can cross with, prior to applying aggregations on the `features`. This essentially restricts the aggregate computation to a subset of the records within a particular key. - -We typically use this to represent engagement labels in an ML model, in this case, `IS_FAVORITED`. - -In this example, we are grouping by `USER_ID`, the feature is `HAS_PHOTO`, the label is `IS_FAVORITED`, and we are computing `CountMetric`. The system will output a feature for each user that represents the number of favorites on tweets having photos by this `userId`. - -.. admonition:: Setting includeAnyLabel - - If constructor argument `includeAnyLabel` is set (as it is by default), then similar to `any_feature`, the framework automatically appends a label of type `any_label` to the set of all labels you define, which represents not applying any filter or cross. - -In this example, `any_label` and `any_feature` are set by default and the system would actually output 4 features for each `user_id`: - -.. cssclass:: shortlist - -#. The number of `IS_FAVORITED` (favorites) on tweet impressions having `HAS_PHOTO=true` - -#. The number of `IS_FAVORITED` (favorites) on all tweet impressions (`any_feature` aggregate) - -#. The number of tweet impressions having `HAS_PHOTO=true` (`any_label` aggregate) - -#. The total number of tweet impressions for this user id (`any_feature.any_label` aggregate) - -.. admonition:: Disabling includeAnyLabel - - To disable this automatically generated feature you can use `includeAnyLabel = false` in your config. This will remove some useful features (particularly for counterfactual signal), but it can greatly save on space since it does not store every possible impressed set of keys in the output store. So use this if you are short on space, but not otherwise. - -Metrics -------- - -`metrics` specifies the aggregate operators to apply. The most commonly used are `Count`, `Sum` and `SumSq`. - -As mentioned before, `Count` can be applied to all types of features, but treats every feature as binary and ignores the value of the feature. `Sum` and `SumSq` can only be applied to Continuous features - they will ignore all other features you specify. By combining sum and sumsq and count, you can produce powerful “z-score” features or other distributional features using a post-transform. - -It is also possible to add your own aggregate operators (e.g. `LastResetMetric `_) to the framework with some additional work. - -HalfLives ---------- - -`halfLives` specifies how fast aggregate features should be decayed. 
It is important to note that the framework works on an incremental basis: in the batch implementation, the summingbird-scalding job takes in the most recently computed aggregate features, processed on data until day `N-1`, then reads new data records for day `N` and computes updated values of the aggregate features. Similarly, the decay of real-time aggregate features takes the actual time delta between the current time and the last time the aggregate feature value was updated. - -The halflife `H` specifies how fast to decay old sums/counts to simulate a sliding window of counts. The implementation is such that it will take `H` amount of time to decay an aggregate feature to half its initial value. New observed values of sums/counts are added to the aggregate feature value. - -.. admonition:: Batch and real-time - - In the batch use case where aggregate features are recomputed on a daily basis, we typically take halflives on the order of weeks or longer (in Timelines, 50 days). In the real-time use case, shorter halflives are appropriate (hours) since they are updated as client engagements are received by the summingbird job. - - -SQL Equivalent --------------- -Conceptually, you can also think of it as: - -.. code-block:: sql - - INSERT INTO . - SELECT AGG() /* AGG is , which is a exponentially decaying SUM or COUNT etc. based on the halfLifves */ - FROM ( - SELECT preTransformOpt(*) FROM - ) - GROUP BY - WHERE = True - -any_features is AGG(*). - -any_labels removes the WHERE clause. \ No newline at end of file diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/batch.docx b/timelines/data_processing/ml_util/aggregation_framework/docs/batch.docx new file mode 100644 index 000000000..266af06f4 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/docs/batch.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/batch.rst b/timelines/data_processing/ml_util/aggregation_framework/docs/batch.rst deleted file mode 100644 index f3b6ac9a5..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/docs/batch.rst +++ /dev/null @@ -1,215 +0,0 @@ -.. _batch: - -Batch aggregate feature jobs -============================ - -In the previous section, we went over the core concepts of the aggregation framework and discussed how you can set up you own `AggregateGroups` to compute aggregate features. - -Given these groups, this section will discuss how you can setup offline batch jobs to produce the corresponding aggregate features, updated daily. To accomplish this, we need to setup a summingbird-scalding job that is pointed to the input data records containing features and labels to be aggregated. - -Input Data ----------- - -In order to generate aggregate features, the relevant input features need to be available offline as a daily scalding source in `DataRecord` format (typically `DailySuffixFeatureSource `_, though `HourlySuffixFeatureSource` could also be usable but we have not tested this). - -.. admonition:: Note - - The input data source should contain the keys, features and labels you want to use in your `AggregateGroups`. - -Aggregation Config ------------------- - -Now that we have a daily data source with input features and labels, we need to setup the `AggregateGroup` config itself. This contains all aggregation groups that you would like to compute and we will go through the implementation step-by-step. - -.. 
admonition:: Example: Timelines Quality config - - `TimelinesAggregationConfig `_ imports the configured `AggregationGroups` from `TimelinesAggregationConfigDetails `_. The config is then referenced by the implementing summingbird-scalding job which we will setup below. - -OfflineAggregateSource ----------------------- - -Each `AggregateGroup` will need to define a (daily) source of input features. We use `OfflineAggregateSource` for this to tell the aggregation framework where the input data set is and the required timestamp feature that the framework uses to decay aggregate feature values: - -.. code-block:: scala - - val timelinesDailyRecapSource = OfflineAggregateSource( - name = "timelines_daily_recap", - timestampFeature = TIMESTAMP, - scaldingHdfsPath = Some("/user/timelines/processed/suggests/recap/data_records"), - scaldingSuffixType = Some("daily"), - withValidation = true - ) - -.. admonition:: Note - - .. cssclass:: shortlist - - #. The name is not important as long as it is unique. - - #. `timestampFeature` must be a discrete feature of type `com.twitter.ml.api.Feature[Long]` and represents the “time” of a given training record in milliseconds - for example, the time at which an engagement, push open event, or abuse event took place that you are trying to train on. If you do not already have such a feature in your daily training data, you need to add one. - - #. `scaldingSuffixType` can be “hourly” or “daily” depending on the type of source (`HourlySuffixFeatureSource` vs `DailySuffixFeatureSource`). - - #. Set `withValidation` to true to validate the presence of _SUCCESS file. Context: https://jira.twitter.biz/browse/TQ-10618 - -Output HDFS store ------------------ - -The output HDFS store is where the computed aggregate features are stored. This store contains all computed aggregate feature values and is incrementally updated by the aggregates job every day. - -.. code-block:: scala - - val outputHdfsPath = "/user/timelines/processed/aggregates_v2" - val timelinesOfflineAggregateSink = new OfflineStoreCommonConfig { - override def apply(startDate: String) = new OfflineAggregateStoreCommonConfig( - outputHdfsPathPrefix = outputHdfsPath, - dummyAppId = "timelines_aggregates_v2_ro", // unused - can be arbitrary - dummyDatasetPrefix = "timelines_aggregates_v2_ro", // unused - can be arbitrary - startDate = startDate - ) - } - -Note: `dummyAppId` and `dummyDatasetPrefix` are unused so can be set to any arbitrary value. They should be removed on the framework side. - -The `outputHdfsPathPrefix` is the only field that matters, and should be set to the HDFS path where you want to store the aggregate features. Make sure you have a lot of quota available at that path. - -Setting Up Aggregates Job -------------------------- - -Once you have defined a config file with the aggregates you would like to compute, the next step is to create the aggregates scalding job using the config (`example `_). This is very concise and requires only a few lines of code: - -.. code-block:: scala - - object TimelinesAggregationScaldingJob extends AggregatesV2ScaldingJob { - override val aggregatesToCompute = TimelinesAggregationConfig.aggregatesToCompute - } - -Now that the scalding job is implemented with the aggregation config, we need to setup a capesos config similar to https://cgit.twitter.biz/source/tree/science/scalding/mesos/timelines/prod.yml: - -.. 
code-block:: scala - - # Common configuration shared by all aggregates v2 jobs - __aggregates_v2_common__: &__aggregates_v2_common__ - class: HadoopSummingbirdProducer - bundle: offline_aggregation-deploy.tar.gz - mainjar: offline_aggregation-deploy.jar - pants_target: "bundle timelines/data_processing/ad_hoc/aggregate_interactions/v2/offline_aggregation:bin" - cron_collision_policy: CANCEL_NEW - use_libjar_wild_card: true - -.. code-block:: scala - - # Specific job computing user aggregates - user_aggregates_v2: - <<: *__aggregates_v2_common__ - cron_schedule: "25 * * * *" - arguments: --batches 1 --output_stores user_aggregates --job_name timelines_user_aggregates_v2 - -.. admonition:: Important - - Each AggregateGroup in your config should have its own associated offline job which specifies `output_stores` pointing to the output store name you defined in your config. - -Running The Job ---------------- - -When you run the batch job for the first time, you need to add a temporary entry to your capesos yml file that looks like this: - -.. code-block:: scala - - user_aggregates_v2_initial_run: - <<: *__aggregates_v2_common__ - cron_schedule: "25 * * * *" - arguments: --batches 1 --start-time “2017-03-03 00:00:00” --output_stores user_aggregates --job_name timelines_user_aggregates_v2 - -.. admonition:: Start Time - - The additional `--start-time` argument should match the `startDate` in your config for that AggregateGroup, but in the format `yyyy-mm-dd hh:mm:ss`. - -To invoke the initial run via capesos, we would do the following (in Timelines case): - -.. code-block:: scala - - CAPESOSPY_ENV=prod capesospy-v2 update --build_locally --start_cron user_aggregates_v2_initial_run science/scalding/mesos/timelines/prod.yml - -Once it is running smoothly, you can deschedule the initial run job and delete the temporary entry from your production yml config. - -.. code-block:: scala - - aurora cron deschedule atla/timelines/prod/user_aggregates_v2_initial_run - -Note: deschedule it preemptively to avoid repeatedly overwriting the same initial results - -Then schedule the production job from jenkins using something like this: - -.. code-block:: scala - - CAPESOSPY_ENV=prod capesospy-v2 update user_aggregates_v2 science/scalding/mesos/timelines/prod.yml - -All future runs (2nd onwards) will use the permanent entry in the capesos yml config that does not have the `start-time` specified. - -.. admonition:: Job name has to match - - It's important that the production run should share the same `--job_name` with the initial_run so that eagleeye/statebird knows how to keep track of it correctly. - -Output Aggregate Features -------------------------- - -This scalding job using the example config from the earlier section would output a VersionedKeyValSource to `/user/timelines/processed/aggregates_v2/user_aggregates` on HDFS. - -Note that `/user/timelines/processed/aggregates_v2` is the explicitly defined root path while `user_aggregates` is the output directory of the example `AggregateGroup` defined earlier. The latter can be different for different `AggregateGroups` defined in your config. - - -The VersionedKeyValSource is difficult to use directly in your jobs/offline trainings, but we provide an adapted source `AggregatesV2FeatureSource` that makes it easy to join and use in your jobs: - -.. 
code-block:: scala - - import com.twitter.timelines.data_processing.ml_util.aggregation_framework.conversion._ - - val pipe: DataSetPipe = AggregatesV2FeatureSource( - rootPath = "/user/timelines/processed/aggregates_v2", - storeName = "user_aggregates", - aggregates = TimelinesAggregationConfig.aggregatesToCompute, - trimThreshold = 0 - )(dateRange).read - -Simply replace the `rootPath`, `storeName` and `aggregates` object to whatever you defined. The `trimThreshold` tells the framework to trim all features below a certain cutoff: 0 is a safe default to use to begin with. - -.. admonition:: Usage - - This can now be used like any other `DataSetPipe` in offline ML jobs. You can write out the features to a `DailySuffixFeatureSource`, you can join them with your data offline for trainings, or you can write them to a Manhattan store for serving online. - -Aggregate Features Example --------------------------- - -Here is an example of sample of the aggregate features we just computed: - -.. code-block:: scala - - user_aggregate_v2.pair.any_label.any_feature.50.days.count: 100.0 - user_aggregate_v2.pair.any_label.tweetsource.is_quote.50.days.count: 30.0 - user_aggregate_v2.pair.is_favorited.any_feature.50.days.count: 10.0 - user_aggregate_v2.pair.is_favorited.tweetsource.is_quote.50.days.count: 6.0 - meta.user_id: 123456789 - -Aggregate feature names match a `prefix.pair.label.feature.half_life.metric` schema and correspond to what was defined in the aggregation config for each of these fields. - -.. admonition:: Example - - In this example, the above features are capturing that userId 123456789L has: - - .. - A 50-day decayed count of 100 training records with any label or feature (“tweet impressions”) - - A 50-day decayed count of 30 records that are “quote tweets” (tweetsource.is_quote = true) - - A 50-day decayed count of 10 records that are favorites on any type of tweet (is_favorited = true) - - A 50-day decayed count of 6 records that are “favorites” on “quote tweets” (both of the above are true) - -By combining the above, a model might infer that for this specific user, quote tweets comprise 30% of all impressions, have a favorite rate of 6/30 = 20%, compared to a favorite rate of 10/100 = 10% on the total population of tweets. - -Therefore, being a quote tweet makes this specific user `123456789L` approximately twice as likely to favorite the tweet, which is useful for prediction and could result in the ML model giving higher scores to & ranking quote tweets higher in a personalized fashion for this user. - -Tests for Feature Names --------------------------- -When you change or add AggregateGroup, feature names might change. And the Feature Store provides a testing mechanism to assert that the feature names change as you expect. See `tests for feature names `_. 
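-
-Outside of that Feature Store mechanism, a quick guard against accidental renames is to snapshot the generated names straight from your aggregation config. A minimal sketch, assuming the `aggregatesToCompute` from the earlier example config (the expected names below are the sample ones shown above; adjust them to your own groups):
-
-.. code-block:: scala
-
-    // Sketch only: collect the dense names of every output feature the config
-    // generates and assert that a known subset is still present.
-    val generatedNames: Set[String] =
-      TimelinesAggregationConfig.aggregatesToCompute
-        .flatMap(_.allOutputFeatures)
-        .map(_.getDenseFeatureName)
-
-    val expectedNames: Set[String] = Set(
-      "user_aggregate_v2.pair.any_label.any_feature.50.days.count",
-      "user_aggregate_v2.pair.is_favorited.tweetsource.is_quote.50.days.count"
-    )
-
-    require(
-      expectedNames.subsetOf(generatedNames),
-      s"Missing aggregate features: ${expectedNames -- generatedNames}"
-    )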
diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/conf.docx b/timelines/data_processing/ml_util/aggregation_framework/docs/conf.docx new file mode 100644 index 000000000..d288d9f47 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/docs/conf.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/conf.py b/timelines/data_processing/ml_util/aggregation_framework/docs/conf.py deleted file mode 100644 index 03996dfd7..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/docs/conf.py +++ /dev/null @@ -1,59 +0,0 @@ -# -*- coding: utf-8 -*- -# -# docbird documentation build configuration file -# Note that not all possible configuration values are present in this -# autogenerated file. -# - -from os.path import abspath, dirname, isfile, join - - -extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.intersphinx", - "sphinx.ext.ifconfig", - "sphinx.ext.graphviz", - "twitter.docbird.ext.thriftlexer", - "twitter.docbird.ext.toctree_default_caption", - "sphinxcontrib.httpdomain", -] - - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - -# The suffix of source filenames. -source_suffix = ".rst" - -# The master toctree document. -master_doc = "index" - -# General information about the project. -project = u"""Aggregation Framework""" -description = u"""""" - -# The short X.Y version. -version = u"""1.0""" -# The full version, including alpha/beta/rc tags. -release = u"""1.0""" - -exclude_patterns = ["_build"] - -pygments_style = "sphinx" - -html_theme = "default" - -html_static_path = ["_static"] - -html_logo = u"""""" - -# Automagically add project logo, if it exists -# (checks on any build, not just init) -# Scan for some common defaults (png or svg format, -# called "logo" or project name, in docs folder) -if not html_logo: - location = dirname(abspath(__file__)) - for logo_file in ["logo.png", "logo.svg", ("%s.png" % project), ("%s.svg" % project)]: - html_logo = logo_file if isfile(join(location, logo_file)) else html_logo - -graphviz_output_format = "svg" diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/index.docx b/timelines/data_processing/ml_util/aggregation_framework/docs/index.docx new file mode 100644 index 000000000..fffc8eae2 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/docs/index.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/index.rst b/timelines/data_processing/ml_util/aggregation_framework/docs/index.rst deleted file mode 100644 index af703c688..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/docs/index.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. markdowninclude:: ../README.md - -.. 
toctree:: - :maxdepth: 2 - :hidden: - - aggregation - batch - real-time - joining - troubleshooting diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/joining.docx b/timelines/data_processing/ml_util/aggregation_framework/docs/joining.docx new file mode 100644 index 000000000..20677f509 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/docs/joining.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/joining.rst b/timelines/data_processing/ml_util/aggregation_framework/docs/joining.rst deleted file mode 100644 index 2ecdf7612..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/docs/joining.rst +++ /dev/null @@ -1,72 +0,0 @@ -.. _joining: - -Joining aggregates features to records -====================================== - -After setting up either offline batch jobs or online real-time summingbird jobs to produce -aggregate features and querying them, we are left with data records containing aggregate features. -This page will go over how to join them with other data records to produce offline training data. - -(To discuss: joining aggregates to records online) - -Joining Aggregates on Discrete/String Keys ------------------------------------------- - -Joining aggregate features keyed on discrete or text features to your training data is very easy - -you can use the built in methods provided by `DataSetPipe`. For example, suppose you have aggregates -keyed by `(USER_ID, AUTHOR_ID)`: - -.. code-block:: scala - - val userAuthorAggregates: DataSetPipe = AggregatesV2FeatureSource( - rootPath = “/path/to/my/aggregates”, - storeName = “user_author_aggregates”, - aggregates = MyConfig.aggregatesToCompute, - trimThreshold = 0 - )(dateRange).read - -Offline, you can then join with your training data set as follows: - -.. code-block:: scala - - val myTrainingData: DataSetPipe = ... - val joinedData = myTrainingData.joinWithLarger((USER_ID, AUTHOR_ID), userAuthorAggregates) - -You can read from `AggregatesV2MostRecentFeatureSourceBeforeDate` in order to read the most recent aggregates -before a provided date `beforeDate`. Just note that `beforeDate` must be aligned with the date boundary so if -you’re passing in a `dateRange`, use `dateRange.end`). - -Joining Aggregates on Sparse Binary Keys ----------------------------------------- - -When joining on sparse binary keys, there can be multiple aggregate records to join to each training record in -your training data set. For example, suppose you have setup an aggregate group that is keyed on `(INTEREST_ID, AUTHOR_ID)` -capturing engagement counts of users interested in a particular `INTEREST_ID` for specific authors provided by `AUTHOR_ID`. - -Suppose now that you have a training data record representing a specific user action. This training data record contains -a sparse binary feature `INTEREST_IDS` representing all the "interests" of that user - e.g. music, sports, and so on. Each `interest_id` -translates to a different set of counting features found in your aggregates data. Therefore we need a way to merge all of -these different sets of counting features to produce a more compact, fixed-size set of features. - -.. admonition:: Merge policies - - To do this, the aggregate framework provides a trait `SparseBinaryMergePolicy `_. Classes overriding this trait define policies - that state how to merge the individual aggregate features from each sparse binary value (in this case, each `INTEREST_ID` for a user). 
- Furthermore, we provide `SparseBinaryMultipleAggregateJoin` which executes these policies to merge aggregates. - -A simple policy might simply average all the counts from the individual interests, or just take the max, or -a specific quantile. More advanced policies might use custom criteria to decide which interest is most relevant and choose -features from that interest to represent the user, or use some weighted combination of counts. - -The framework provides two simple in-built policies (`PickTopCtrPolicy `_ -and `CombineCountsPolicy `_, which keeps the topK counts per -record) that you can get started with, though you likely want to implement your own policy based on domain knowledge to get -the best results for your specific problem domain. - -.. admonition:: Offline Code Example - - The scalding job `TrainingDataWithAggV2Generator `_ shows how multiple merge policies are defined and implemented to merge aggregates on sparse binary keys to the TQ's training data records. - -.. admonition:: Online Code Example - - In our (non-FeatureStore enabled) online code path, we merge aggregates on sparse binary keys using the `CombineCountsPolicy `_. diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/real-time.docx b/timelines/data_processing/ml_util/aggregation_framework/docs/real-time.docx new file mode 100644 index 000000000..c46f6cfff Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/docs/real-time.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/real-time.rst b/timelines/data_processing/ml_util/aggregation_framework/docs/real-time.rst deleted file mode 100644 index fc853ba69..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/docs/real-time.rst +++ /dev/null @@ -1,327 +0,0 @@ -.. _real_time: - -Real-Time aggregate features -============================ - -In addition to computing batch aggregate features, the aggregation framework supports real-time aggregates as well. The framework concepts used here are identical to the batch use case, however, the underlying implementation differs and is provided by summingbird-storm jobs. - -RTA Runbook ------------ - -For operational details, please visit http://go/tqrealtimeaggregates. - -Prerequisites -------------- - -In order to start computing real-time aggregate features, the framework requires the following to be provided: - -* A backing memcached store that will hold the computed aggregate features. This is conceptually equivalent to the output HDFS store in the batch compute case. -* Implementation of `StormAggregateSource `_ that creates `DataRecords` with the necessary input features. This serves as the input to the aggregation operations. -* Definition of aggregate features by defining `AggregateGroup` in an implementation of `OnlineAggregationConfigTrait`. This is identical to the batch case. -* Job config file defining the backing memcached for feature storage and retrieval, and job-related parameters. - -We will now go through the details in setting up each required component. - -Memcached store ---------------- - -Real-time aggregates use Memcache as the backing cache to store and update aggregate features keys. Caches can be provisioned on `go/cacheboard `_. - -.. admonition:: Test and prod caches - - For development, it is sufficient to setup a test cache that your new job can query and write to. 
At the same time, a production cache request should also be submitted as these generally have significant lead times for provisioning. - -StormAggregateSource --------------------- - -To enable aggregation of your features, we need to start with defining a `StormAggregateSource` that builds a `Producer[Storm, DataRecord]`. This summingbird producer generates `DataRecords` that contain the input features and labels that the real-time aggregate job will compute aggregate features on. Conceptually, this is equivalent to the input data set in the offline batch use case. - -.. admonition:: Example - - If you are planning to aggregate on client engagements, you would need to subscribe to the `ClientEvent` kafka stream and then convert each event to a `DataRecord` that contains the key and the engagement on which to aggregate. - -Typically, we would setup a julep filter for the relevant client events that we would like to aggregate on. This gives us a `Producer[Storm, LogEvent]` object which we then convert to `Producer[Storm, DataRecord]` with adapters that we wrote: - -.. code-block:: scala - - lazy val clientEventProducer: Producer[Storm, LogEvent] = - ClientEventSourceScrooge( - appId = AppId(jobConfig.appId), - topic = "julep_client_event_suggests", - resumeAtLastReadOffset = false - ).source.name("timelines_events") - - lazy val clientEventWithCachedFeaturesProducer: Producer[Storm, DataRecord] = clientEventProducer - .flatMap(mkDataRecords) - -Note that this way of composing the storm graph gives us flexiblity in how we can hydrate input features. If you would like to join more complex features to `DataRecord`, you can do so here with additional storm components which can implement cache queries. - -.. admonition:: Timelines Quality use case - - In Timelines Quality, we aggregate client engagements on `userId` or `tweetId` and implement - `TimelinesStormAggregateSource `_. We create - `Producer[Storm,LogEvent]` of Timelines engagements to which we apply `ClientLogEventAdapter `_ which converts the event to `DataRecord` containing `userId`, `tweetId`, `timestampFeature` of the engagement and the engagement label itself. - -.. admonition:: MagicRecs use case - - MagicRecs has a very similar setup for real-time aggregate features. In addition, they also implement a more complex cache query to fetch the user's history in the `StormAggregateSource` for each observed client engagement to hydrate a richer set of input `DataRecords`: - - .. code-block:: scala - - val userHistoryStoreService: Storm#Service[Long, History] = - Storm.service(UserHistoryReadableStore) - - val clientEventDataRecordProducer: Producer[Storm, DataRecord] = - magicRecsClientEventProducer - .flatMap { ... - (userId, logEvent) - }.leftJoin(userHistoryStoreService) - .flatMap { - case (_, (logEvent, history)) => - mkDataRecords(LogEventHistoryPair(logEvent, history)) - } - -.. admonition:: EmailRecs use case - - EmailRecs shares the same cache as MagicRecs. They combine notification scribe data with email history data to identify the particular item a user engaged with in an email: - - .. code-block:: scala - - val emailHistoryStoreService: Storm#Service[Long, History] = - Storm.service(EmailHistoryReadableStore) - - val emailEventDataRecordProducer: Producer[Storm, DataRecord] = - emailEventProducer - .flatMap { ... 
- (userId, logEvent) - }.leftJoin(emailHistoryStoreService) - .flatMap { - case (_, (scribe, history)) => - mkDataRecords(ScribeHistoryPair(scribe, history)) - } - - -Aggregation config ------------------ - -The real-time aggregation config extends `OnlineAggregationConfigTrait `_ and defines the features to aggregate and the backing memcached store to which they will be written. - -Setting up real-time aggregates follows the same rules as in the offline batch use case. The major difference here is that `inputSource` should point to the `StormAggregateSource` implementation that provides the `DataRecord` containing the engagements and core features on which to aggregate. In the offline case, this would have been an `OfflineAggregateSource` pointing to an offline source of daily records. - -Finally, `RealTimeAggregateStore` defines the backing memcache to be used and should be provided here as the `outputStore`. - -.. NOTE:: - - Please make sure to provide an `AggregateGroup` for both staging and production. The main difference should be the `outputStore` where features in either environment are read from and written to. You want to make sure that a staged real-time aggregates summingbird job is reading/writing only to the test memcache store and does not mutate the production store. - -Job config ---------- - -In addition to the aggregation config that defines the features to aggregate, the final piece we need to provide is a `RealTimeAggregatesJobConfig` that specifies job values such as `appId`, `teamName` and counts for the various topology components that define the capacity of the job (`Timelines example `_). - -Once you have the job config, implementing the storm job itself is easy and almost as concise as in the batch use case: - -.. code-block:: scala - - object TimelinesRealTimeAggregatesJob extends RealTimeAggregatesJobBase { - override lazy val statsReceiver = DefaultStatsReceiver.scope("timelines_real_time_aggregates") - override lazy val jobConfigs = TimelinesRealTimeAggregatesJobConfigs - override lazy val aggregatesToCompute = TimelinesOnlineAggregationConfig.AggregatesToCompute - } - -.. NOTE:: - There are some topology settings that are currently hard-coded. In particular, we set `Config.TOPOLOGY_DROPTUPLES_UPON_BACKPRESSURE` to true for added robustness. This may be made user-definable in the future. - -Steps to hydrate RTAs --------------------- -1. Make the changes to RTAs and follow the steps for `Running the topology`. -2. Register the new RTAs with the feature store. Sample phab: https://phabricator.twitter.biz/D718120 -3. Wire the features from the feature store to TLX. This is usually done with the feature switch set to False, so it is just a code change and does not yet start hydrating the features. Merge the phab. Sample phab: https://phabricator.twitter.biz/D718424 -4. Now hydrate the features to TLX gradually, shard by shard. For this, first create a PCM and then enable the hydration. Sample PCM: https://jira.twitter.biz/browse/PCM-147814 - -Running the topology -------------------- -0. For a phab that changes the topology (such as adding new ML features), before landing the phab, please create a PCM (`example `_) and deploy the change to the devel topology first and then to prod (atla and pdxa). Once it is confirmed that the prod topology can handle the change, the phab can be landed. -1. Go to https://ci.twitter.biz/job/tq-ci/build -2. In the `commands` input, enter: - -.. code-block:: bash - - .
src/scala/com/twitter/timelines/prediction/common/aggregates/real_time/deploy_local.sh [devel|atla|pdxa] - -One can only deploy one of `devel`, `atla` (prod atla), or `pdxa` (prod pdxa) at a time. -For example, to deploy both the pdxa and atla prod topologies, one needs to build/run the above steps twice, once with `pdxa` and once with `atla`. - -The status and performance stats of the topology are found at `go/heron-ui `_. Here you can view whether the job is processing tuples, whether it is under memory pressure or backpressure, and other general observability information. - -Finally, since we enable `Config.TOPOLOGY_DROPTUPLES_UPON_BACKPRESSURE` by default in the topology, we also need to monitor and alert on the number of dropped tuples. Since this is a job generating features, a small fraction of dropped tuples is tolerable if that enables us to avoid backpressure that would hold up computation in the entire graph. - -Hydrating Real-Time Aggregate Features -------------------------------------- - -Once the job is up and running, the aggregate features will be accessible in the backing memcached store. To access these features and hydrate them into your online pipeline, we need to build a Memcache client with the right query key. - -.. admonition:: Example - - Some care needs to be taken to define the key injection and codec correctly for the memcached store. These types do not change and you can use the Timelines `memcache client builder `_ as an example. - -Aggregate features are written to the store with a `(AggregationKey, BatchID)` key. - -`AggregationKey `_ is an instance of the keys that you previously defined in `AggregateGroup`. If your aggregation key is `USER_ID`, you would need to instantiate `AggregationKey` with the `USER_ID` featureId and the userId value. - -.. admonition:: Returned features - - The `DataRecord` that is returned by the cache now contains all real-time aggregate features for the query `AggregationKey` (similar to the batch use case). If your online hydration flow produces data records, the real-time aggregate features can be joined with your existing records in a straightforward way. - -Adding features from Feature Store to RTA -------------------------------------------- -To add features from Feature Store to RTA and create real-time aggregate features based on them, one needs to follow these steps: - -**Step 1** - -Copy the Strato column for the features that one wants to explore and add a cache if needed. See details at `Customize any Columns for your Team as Needed `_. As an `example `_, we copy the Strato column recommendationsUserFeaturesProd.User.strato and add a cache for the timelines team's usage. - -**Step 2** - -Create a new ReadableStore which uses the Feature Store Client to request features from Feature Store. Implement a FeaturesAdapter which extends TimelinesAdapterBase and derives new features based on the raw features from Feature Store. As an `example `_, we create UserFeaturesReadableStore, which reads the discrete user state feature and converts it to a list of boolean user state features. - -**Step 3** - -Join these derived features from Feature Store to the timelines storm aggregate source. Depending on the characteristics of these derived features, the join key could be tweet id, user id, or something else. As an `example `_, because user state is per user, the join key is user id.
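A minimal sketch of such a join, keyed on user id, could mirror the MagicRecs pattern shown earlier. The store, helper and wrapper names here (the value type of `UserFeaturesReadableStore`, `userIdFromEvent`, `LogEventUserFeaturesPair`, `mkDataRecords`) are illustrative placeholders rather than the exact production code:

.. code-block:: scala

   // Illustrative sketch only; exact store/value types depend on your Feature Store adapter.
   val userFeaturesService: Storm#Service[Long, DataRecord] =
     Storm.service(UserFeaturesReadableStore)

   val enrichedDataRecordProducer: Producer[Storm, DataRecord] =
     clientEventProducer
       // Key each engagement by userId so the derived user features can be joined in.
       .flatMap { logEvent =>
         userIdFromEvent(logEvent).map(userId => (userId, logEvent)).toList
       }
       .leftJoin(userFeaturesService)
       .flatMap {
         case (_, (logEvent, userFeaturesOpt)) =>
           // Merge the joined user state features (when present) into the DataRecords
           // that the aggregation job will consume.
           mkDataRecords(LogEventUserFeaturesPair(logEvent, userFeaturesOpt))
       }

The `AggregateGroup` defined in the next step can then reference the derived user state features carried on these records.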
- -**Step 4** - -Define an `AggregateGroup` based on the derived features in RTA. - -Adding New Aggregate Features from an Existing Dataset ------------------------------------------------------ -To add a new aggregate feature group from an existing dataset for use in home models, use the following steps: - -1. Identify the hypothesis being tested by the addition of the features, in accordance with `go/tpfeatureguide `_. -2. Modify or add a new AggregateGroup to `TimelinesOnlineAggregationConfigBase.scala `_ to define the aggregation key, set of features, labels and metrics. An example phab to add more halflives can be found at `D204415 `_. -3. If the change is expected to be very large, it may be recommended to perform capacity estimation. See :ref:`Capacity Estimation` for more details. -4. Create feature catalog items for the new RTAs. An example phab is `D706348 `_. For approval from a feature store owner, ping #help-ml-features on Slack. -5. Add the new features to the feature store. An example phab is `D706112 `_. This change can be rolled out with feature switches or by canarying TLX, depending on the risk. An example PCM for feature switches is: `PCM-148654 `_. An example PCM for canarying is: `PCM-145753 `_. -6. Wait for the redeploy and confirm the new features are available. One way is to query in BigQuery from a table like `twitter-bq-timelines-prod.continuous_training_recap_fav`. Another way is to inspect individual records using pcat. The command looks like: - -.. code-block:: bash - - java -cp pcat-deploy.jar:$(hadoop classpath) com.twitter.ml.tool.pcat.PredictionCatTool - -path /atla/proc2/user/timelines/processed/suggests/recap/continuous_training_data_records/fav/data/YYYY/MM/DD/01/part-00000.lzo - -fc /atla/proc2/user/timelines/processed/suggests/recap/continuous_training_data_records/fav/data_spec.json - -dates YYYY-MM-DDT01 -record_limit 100 | grep [feature_group] - - -7. Create a phab with the new features and test the performance of a model with them compared to a control model without them. Test offline using `Deepbird for training `_ and `RCE Hypothesis Testing `_. Test online using a DDG. Some helpful instructions are available in `Serving Timelines Models `_ and the `Experiment Cookbook `_. - -Capacity Estimation -------------------------------- -This section describes how to approximate the capacity required for a new aggregate group. It is not expected to be exact, but should give a rough estimate. - -There are two main components that must be stored for each aggregate group. - -Key space: Each AggregationKey struct consists of two maps, one of which is populated with [Long, Long] tuples representing discrete features. This takes up 4 x 8 bytes or 32 bytes. The cache team estimates an additional 40 bytes of overhead. - -Features: An aggregate feature is represented as a pair (16 bytes) and is produced for each feature x label x metric x halflife combination. - -1. Use BigQuery to estimate how many unique values exist for the selected key (key_count). Also collect the number of features, labels, metrics, and half-lives being used. -2. Compute the number of entries to be created, which is num_entries = feature_count * label_count * metric_count * halflife_count -3. Compute the number of bytes per entry, which is num_entry_bytes = 16*num_entries + 32 bytes (key storage) + 40 bytes (overhead) -4.
Compute total space required = num_entry_bytes * key_count - -Debugging New Aggregate Features --------------------------------- - -To debug problems in the setup of your job, there are several steps you can take. - -First, ensure that data is being received from the input stream and passed through to create data records. This can be achieved by logging results at various places in your code, and especially at the point of data record creation. - -For example, suppose you want to ensure that a data record is being created with -the features you expect. With push and email features, we find that data records -are created in the adaptor, using logic like the following: - -.. code-block:: scala - - val record = new SRichDataRecord(new DataRecord) - ... - record.setFeatureValue(feature, value) - -To see what these feature values look like, we can have our adaptor class extend -Twitter's `Logging` trait, and write each created record to a log file. - -.. code-block:: scala - - class MyEventAdaptor extends TimelinesAdapterBase[MyObject] with Logging { - ... - ... - def mkDataRecord(myFeatures: MyFeatures): DataRecord = { - val record = new SRichDataRecord(new DataRecord) - ... - record.setFeatureValue(feature, value) - logger.info("data record xyz: " + record.getRecord.toString) - } - -This way, every time a data record is sent to the aggregator, it will also be -logged. To inspect these logs, you can push these changes to a staging instance, -ssh into that aurora instance, and grep the `log-files` directory for `xyz`. The -data record objects you find should resemble a map from feature ids to their -values. - -To check that steps in the aggregation are being performed, you can also inspect the job's topology on go/heronui. - -Lastly, to verify that values are being written to your cache you can check the `set` chart in your cache's viz. - -To check particular feature values for a given key, you can spin up a Scala REPL like so: - -.. code-block:: bash - - $ ssh -fN -L*:2181:sdzookeeper-read.atla.twitter.com:2181 -D *:50001 nest.atlc.twitter.com - - $ ./pants repl --jvm-repl-scala-options='-DsocksProxyHost=localhost -DsocksProxyPort=50001 -Dcom.twitter.server.resolverZkHosts=localhost:2181' timelinemixer/common/src/main/scala/com/twitter/timelinemixer/clients/real_time_aggregates_cache - -You will then need to create a connection to the cache, and a key with which to query it. - -.. 
code-block:: scala - - import com.twitter.conversions.DurationOps._ - import com.twitter.finagle.stats.{DefaultStatsReceiver, StatsReceiver} - import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey - import com.twitter.summingbird.batch.Batcher - import com.twitter.timelinemixer.clients.real_time_aggregates_cache.RealTimeAggregatesMemcacheBuilder - import com.twitter.timelines.clients.memcache_common.StorehausMemcacheConfig - - val userFeature = -1887718638306251279L // feature id corresponding to User feature - val userId = 12L // replace with a user id logged when creating your data record - val key = (AggregationKey(Map(userFeature -> userId), Map.empty), Batcher.unit.currentBatch) - - val dataset = "twemcache_magicrecs_real_time_aggregates_cache_staging" // replace with the appropriate cache name - val dest = s"/srv#/test/local/cache/twemcache_/$dataset" - - val statsReceiver: StatsReceiver = DefaultStatsReceiver - val cache = new RealTimeAggregatesMemcacheBuilder( - config = StorehausMemcacheConfig( - destName = dest, - keyPrefix = "", - requestTimeout = 10.seconds, - numTries = 1, - globalTimeout = 10.seconds, - tcpConnectTimeout = 10.seconds, - connectionAcquisitionTimeout = 10.seconds, - numPendingRequests = 250, - isReadOnly = true - ), - statsReceiver.scope(dataset) - ).build - - val result = cache.get(key) - -Another option is to create a debugger which points to the staging cache and creates a cache connection and key similar to the logic above. - -Run CQL query to find metrics/counters -------------------------------------- -We can also visualize the counters from our job to verify new features. Run a CQL query in the terminal to find the right path for the metric/counter. For example, in order to check the counter mergeNumFeatures, run: - -cql -z atla keys heron/summingbird_timelines_real_time_aggregates Tail-FlatMap | grep mergeNumFeatures - - -Then use the right path to create the viz, for example: https://monitoring.twitter.biz/tiny/2552105 diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/troubleshooting.docx b/timelines/data_processing/ml_util/aggregation_framework/docs/troubleshooting.docx new file mode 100644 index 000000000..461bc1ba1 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/docs/troubleshooting.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/docs/troubleshooting.rst b/timelines/data_processing/ml_util/aggregation_framework/docs/troubleshooting.rst deleted file mode 100644 index d9799f433..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/docs/troubleshooting.rst +++ /dev/null @@ -1,117 +0,0 @@ -.. _troubleshooting: - -Troubleshooting -================== - - -[Batch] Regenerating a corrupt version -------------------------------------- - -Symptom -~~~~~~~~~~ -The Summingbird batch job failed due to the following error: - -.. code:: bash - - Caused by: com.twitter.bijection.InversionFailure: ... - -It typically indicates corrupted records in the aggregate store (not the other side of the DataRecord source). -The following describes the method to regenerate the required (typically the latest) version: - -Solution -~~~~~~~~~~ -1. Copy **the second to last version** of the problematic data to the canaries folder. For example, if 11/20's job keeps failing, then copy the 11/19's data. - -..
code:: bash - - $ hadoop --config /etc/hadoop/hadoop-conf-proc2-atla/ \ - distcp -m 1000 \ - /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates/1605744000000 \ - /atla/proc2/user/timelines/canaries/processed/aggregates_v2/user_mention_aggregates/1605744000000 - - -2. Set up a canary run for the date of the problem with the fallback path pointing to `1605744000000` in the prod/canaries folder. - -3. Deschedule the production job and kill the current run: - -For example, - -.. code:: bash - - $ aurora cron deschedule atla/timelines/prod/user_mention_aggregates - $ aurora job killall atla/timelines/prod/user_mention_aggregates - -4. Create a backup folder and move the corrupted prod store output there: - -.. code:: bash - - $ hdfs dfs -mkdir /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates_backup - $ hdfs dfs -mv /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates/1605830400000 /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates_backup/ - $ hadoop fs -count /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates_backup/1605830400000 - - 1 1001 10829136677614 /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates_backup/1605830400000 - - -5. Copy the canary output store to the prod folder: - -.. code:: bash - - $ hadoop --config /etc/hadoop/hadoop-conf-proc2-atla/ distcp -m 1000 /atla/proc2/user/timelines/canaries/processed/aggregates_v2/user_mention_aggregates/1605830400000 /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates/1605830400000 - -We can see a slight difference in size: - -.. code:: bash - - $ hadoop fs -count /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates_backup/1605830400000 - 1 1001 10829136677614 /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates_backup/1605830400000 - $ hadoop fs -count /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates/1605830400000 - 1 1001 10829136677844 /atla/proc2/user/timelines/processed/aggregates_v2/user_mention_aggregates/1605830400000 - -6. Deploy the prod job again and observe whether it can successfully process the new output for the date of interest. - -7. Verify that the new run succeeded and the job is unblocked. - -Example -~~~~~~~~ - -There is an example in https://phabricator.twitter.biz/D591174 - - -[Batch] Skipping the offline job ahead ---------------------------------------- - -Symptom -~~~~~~~~~~ -The Summingbird batch job keeps failing and the DataRecord source is no longer available (e.g. due to retention), so there is no way for the job to succeed, **OR** - -.. -The job is stuck processing old data (more than one week old) and it will not catch up to the new data on its own if it is left alone. - -Solution -~~~~~~~~ - -We will need to skip the job ahead. Unfortunately, this involves manual effort. We also need help from the ADP team (Slack #adp). - -1. Ask the ADP team to manually insert an entry into the store via the #adp Slack channel. You may refer to https://jira.twitter.biz/browse/AIPIPE-7520 and https://jira.twitter.biz/browse/AIPIPE-9300 as references. However, please don't create and assign tickets directly to an ADP team member unless they ask you to. - -2. Copy the latest version of the store to the same HDFS directory but with a different destination name. The name MUST be the same as the above inserted version.
- -For example, if the ADP team manually inserted a version on 12/09/2020, then we can see the version by running - -.. code:: bash - - $ dalv2 segment list --name user_original_author_aggregates --role timelines --location-name proc2-atla --location-type hadoop-cluster - ... - None 2020-12-09T00:00:00Z viewfs://hadoop-proc2-nn.atla.twitter.com/user/timelines/processed/aggregates_v2/user_original_author_aggregates/1607472000000 Unknown None - -where `1607472000000` is the timestamp of 12/09/2020. -Then you will need to duplicate the latest version of the store to a dir of `1607472000000`. -For example, - -.. code:: bash - - $ hadoop --config /etc/hadoop/hadoop-conf-proc2-atla/ distcp -m 1000 /atla/proc2/user/timelines/processed/aggregates_v2/user_original_author_aggregates/1605052800000 /atla/proc2/user/timelines/processed/aggregates_v2/user_original_author_aggregates/1607472000000 - -3. Go to the EagleEye UI of the job and click on the "Skip Ahead" button to the desired datetime. In our example, it should be `2020-12-09 12am` - -4. Wait for the job to start. Now the job should be running the 2020-12-09 partition. diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/BUILD b/timelines/data_processing/ml_util/aggregation_framework/heron/BUILD deleted file mode 100644 index 0cc576e4e..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/BUILD +++ /dev/null @@ -1,74 +0,0 @@ -scala_library( - sources = ["*.scala"], - platform = "java8", - strict_deps = False, - tags = ["bazel-compatible"], - dependencies = [ - ":configs", - "3rdparty/jvm/storm:heron-oss-storm", - "3rdparty/src/jvm/com/twitter/scalding:args", - "3rdparty/src/jvm/com/twitter/summingbird:storm", - "src/java/com/twitter/heron/util", - "src/java/com/twitter/ml", - "src/scala/com/twitter/storehaus_internal/nighthawk_kv", - "src/scala/com/twitter/summingbird_internal/bijection:bijection-implicits", - "src/scala/com/twitter/summingbird_internal/runner/common", - "src/scala/com/twitter/summingbird_internal/runner/storm", - "src/scala/com/twitter/timelines/prediction/features/common", - "timelines/data_processing/ml_util/aggregation_framework:user_job", - ], -) - -scala_library( - name = "configs", - sources = [ - "NighthawkUnderlyingStoreConfig.scala", - "OnlineAggregationConfigTrait.scala", - "OnlineAggregationStoresTrait.scala", - "RealTimeAggregateStore.scala", - "RealTimeAggregatesJobConfig.scala", - "StormAggregateSource.scala", - ], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - ":base-config", - "3rdparty/jvm/storm:heron-oss-storm", - "3rdparty/src/jvm/com/twitter/summingbird:core", - "3rdparty/src/jvm/com/twitter/summingbird:storm", - "finagle/finagle-core/src/main", - "src/java/com/twitter/ml/api:api-base", - "src/scala/com/twitter/storehaus_internal/memcache", - "src/scala/com/twitter/storehaus_internal/memcache/config", - "src/scala/com/twitter/storehaus_internal/nighthawk_kv", - "src/scala/com/twitter/storehaus_internal/nighthawk_kv/config", - "src/scala/com/twitter/storehaus_internal/online", - "src/scala/com/twitter/storehaus_internal/store", - "src/scala/com/twitter/storehaus_internal/util", - "src/scala/com/twitter/summingbird_internal/runner/store_config", - "src/thrift/com/twitter/clientapp/gen:clientapp-java", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:data-scala", - "src/thrift/com/twitter/ml/api:feature_context-java", - 
"timelines/data_processing/ml_util/aggregation_framework:common_types", - "timelines/data_processing/ml_util/transforms", - "util/util-core:scala", - "util/util-core:util-core-util", - "util/util-stats/src/main/scala/com/twitter/finagle/stats", - ], -) - -scala_library( - name = "base-config", - sources = [ - "OnlineAggregationConfigTrait.scala", - ], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "src/java/com/twitter/ml/api:api-base", - "timelines/data_processing/ml_util/aggregation_framework:common_types", - ], -) diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/BUILD.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/BUILD.docx new file mode 100644 index 000000000..af447c05a Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/BUILD.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/NighthawkUnderlyingStoreConfig.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/NighthawkUnderlyingStoreConfig.docx new file mode 100644 index 000000000..baab72e18 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/NighthawkUnderlyingStoreConfig.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/NighthawkUnderlyingStoreConfig.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/NighthawkUnderlyingStoreConfig.scala deleted file mode 100644 index cf7668a20..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/NighthawkUnderlyingStoreConfig.scala +++ /dev/null @@ -1,31 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron - -import com.twitter.conversions.DurationOps._ -import com.twitter.finagle.mtls.authentication.EmptyServiceIdentifier -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.finagle.ssl.OpportunisticTls -import com.twitter.storehaus_internal.nighthawk_kv.CacheClientNighthawkConfig -import com.twitter.storehaus_internal.util.TTL -import com.twitter.storehaus_internal.util.TableName -import com.twitter.summingbird_internal.runner.store_config.OnlineStoreOnlyConfig -import com.twitter.util.Duration - -case class NighthawkUnderlyingStoreConfig( - serversetPath: String = "", - tableName: String = "", - cacheTTL: Duration = 1.day) - extends OnlineStoreOnlyConfig[CacheClientNighthawkConfig] { - - def online: CacheClientNighthawkConfig = online(EmptyServiceIdentifier) - - def online( - serviceIdentifier: ServiceIdentifier = EmptyServiceIdentifier - ): CacheClientNighthawkConfig = - CacheClientNighthawkConfig( - serversetPath, - TableName(tableName), - TTL(cacheTTL), - serviceIdentifier = serviceIdentifier, - opportunisticTlsLevel = OpportunisticTls.Required - ) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationConfigTrait.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationConfigTrait.docx new file mode 100644 index 000000000..161513cc2 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationConfigTrait.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationConfigTrait.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationConfigTrait.scala deleted file mode 100644 index aea649128..000000000 --- 
a/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationConfigTrait.scala +++ /dev/null @@ -1,28 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron - -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup -import com.twitter.ml.api.Feature - -trait OnlineAggregationConfigTrait { - def ProdAggregates: Set[TypedAggregateGroup[_]] - def StagingAggregates: Set[TypedAggregateGroup[_]] - def ProdCommonAggregates: Set[TypedAggregateGroup[_]] - - /** - * AggregateToCompute: This defines the complete set of aggregates to be - * computed by the aggregation job and to be stored in memcache. - */ - def AggregatesToCompute: Set[TypedAggregateGroup[_]] - - /** - * ProdFeatures: This defines the subset of aggregates to be extracted - * and hydrated (or adapted) by callers to the aggregates features cache. - * This should only contain production aggregates and aggregates on - * product specific engagements. - * ProdCommonFeatures: Similar to ProdFeatures but containing user-level - * aggregate features. This is provided to PredictionService just - * once per user. - */ - lazy val ProdFeatures: Set[Feature[_]] = ProdAggregates.flatMap(_.allOutputFeatures) - lazy val ProdCommonFeatures: Set[Feature[_]] = ProdCommonAggregates.flatMap(_.allOutputFeatures) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationStoresTrait.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationStoresTrait.docx new file mode 100644 index 000000000..dff90614b Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationStoresTrait.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationStoresTrait.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationStoresTrait.scala deleted file mode 100644 index 4f693190e..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/OnlineAggregationStoresTrait.scala +++ /dev/null @@ -1,6 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron - -trait OnlineAggregationStoresTrait { - def ProductionStore: RealTimeAggregateStore - def StagingStore: RealTimeAggregateStore -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregateStore.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregateStore.docx new file mode 100644 index 000000000..b983de65f Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregateStore.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregateStore.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregateStore.scala deleted file mode 100644 index 2e75039d3..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregateStore.scala +++ /dev/null @@ -1,50 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron - -import com.twitter.conversions.DurationOps._ -import com.twitter.finagle.mtls.authentication.EmptyServiceIdentifier -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.storehaus_internal.memcache.ConnectionConfig -import com.twitter.storehaus_internal.memcache.MemcacheConfig -import 
com.twitter.storehaus_internal.util.KeyPrefix -import com.twitter.storehaus_internal.util.TTL -import com.twitter.storehaus_internal.util.ZkEndPoint -import com.twitter.summingbird_internal.runner.store_config.OnlineStoreOnlyConfig -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregateStore -import com.twitter.util.Duration - -object RealTimeAggregateStore { - val twCacheWilyPrefix = "/srv#" // s2s is only supported for wily path - - def makeEndpoint( - memcacheDataSet: String, - isProd: Boolean, - twCacheWilyPrefix: String = twCacheWilyPrefix - ): String = { - val env = if (isProd) "prod" else "test" - s"$twCacheWilyPrefix/$env/local/cache/$memcacheDataSet" - } -} - -case class RealTimeAggregateStore( - memcacheDataSet: String, - isProd: Boolean = false, - cacheTTL: Duration = 1.day) - extends OnlineStoreOnlyConfig[MemcacheConfig] - with AggregateStore { - import RealTimeAggregateStore._ - - override val name: String = "" - val storeKeyPrefix: KeyPrefix = KeyPrefix(name) - val memcacheZkEndPoint: String = makeEndpoint(memcacheDataSet, isProd) - - def online: MemcacheConfig = online(serviceIdentifier = EmptyServiceIdentifier) - - def online(serviceIdentifier: ServiceIdentifier = EmptyServiceIdentifier): MemcacheConfig = - new MemcacheConfig { - val endpoint = ZkEndPoint(memcacheZkEndPoint) - override val connectionConfig = - ConnectionConfig(endpoint, serviceIdentifier = serviceIdentifier) - override val keyPrefix = storeKeyPrefix - override val ttl = TTL(Duration.fromMilliseconds(cacheTTL.inMillis)) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobBase.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobBase.docx new file mode 100644 index 000000000..b1523570f Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobBase.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobBase.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobBase.scala deleted file mode 100644 index 906f7c1be..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobBase.scala +++ /dev/null @@ -1,301 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron - -import com.twitter.algebird.Monoid -import com.twitter.bijection.Injection -import com.twitter.bijection.thrift.CompactThriftCodec -import com.twitter.conversions.DurationOps._ -import com.twitter.finagle.mtls.authentication.EmptyServiceIdentifier -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.heron.util.CommonMetric -import com.twitter.ml.api.DataRecord -import com.twitter.scalding.Args -import com.twitter.storehaus.algebra.MergeableStore -import com.twitter.storehaus.algebra.StoreAlgebra._ -import com.twitter.storehaus_internal.memcache.Memcache -import com.twitter.storehaus_internal.store.CombinedStore -import com.twitter.storehaus_internal.store.ReplicatingWritableStore -import com.twitter.summingbird.batch.BatchID -import com.twitter.summingbird.batch.Batcher -import com.twitter.summingbird.online.MergeableStoreFactory -import com.twitter.summingbird.online.option._ -import com.twitter.summingbird.option.CacheSize -import com.twitter.summingbird.option.JobId -import com.twitter.summingbird.storm.option.FlatMapStormMetrics -import 
com.twitter.summingbird.storm.option.SummerStormMetrics -import com.twitter.summingbird.storm.Storm -import com.twitter.summingbird.storm.StormMetric -import com.twitter.summingbird.Options -import com.twitter.summingbird._ -import com.twitter.summingbird_internal.runner.common.CapTicket -import com.twitter.summingbird_internal.runner.common.JobName -import com.twitter.summingbird_internal.runner.common.TeamEmail -import com.twitter.summingbird_internal.runner.common.TeamName -import com.twitter.summingbird_internal.runner.storm.ProductionStormConfig -import com.twitter.timelines.data_processing.ml_util.aggregation_framework._ -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.job.AggregatesV2Job -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.job.AggregatesV2Job -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.job.DataRecordFeatureCounter -import org.apache.heron.api.{Config => HeronConfig} -import org.apache.heron.common.basics.ByteAmount -import org.apache.storm.Config -import scala.collection.JavaConverters._ - -object RealTimeAggregatesJobBase { - lazy val commonMetric: StormMetric[CommonMetric] = - StormMetric(new CommonMetric(), CommonMetric.NAME, CommonMetric.POLL_INTERVAL) - lazy val flatMapMetrics: FlatMapStormMetrics = FlatMapStormMetrics(Iterable(commonMetric)) - lazy val summerMetrics: SummerStormMetrics = SummerStormMetrics(Iterable(commonMetric)) -} - -trait RealTimeAggregatesJobBase extends Serializable { - import RealTimeAggregatesJobBase._ - import com.twitter.summingbird_internal.bijection.BatchPairImplicits._ - - def statsReceiver: StatsReceiver - - def aggregatesToCompute: Set[TypedAggregateGroup[_]] - - def jobConfigs: RealTimeAggregatesJobConfigs - - implicit lazy val dataRecordCodec: Injection[DataRecord, Array[Byte]] = - CompactThriftCodec[DataRecord] - implicit lazy val monoid: Monoid[DataRecord] = DataRecordAggregationMonoid(aggregatesToCompute) - implicit lazy val aggregationKeyInjection: Injection[AggregationKey, Array[Byte]] = - AggregationKeyInjection - - val clusters: Set[String] = Set("atla", "pdxa") - - def buildAggregateStoreToStorm( - isProd: Boolean, - serviceIdentifier: ServiceIdentifier, - jobConfig: RealTimeAggregatesJobConfig - ): (AggregateStore => Option[Storm#Store[AggregationKey, DataRecord]]) = { - (store: AggregateStore) => - store match { - case rtaStore: RealTimeAggregateStore if rtaStore.isProd == isProd => { - lazy val primaryStore: MergeableStore[(AggregationKey, BatchID), DataRecord] = - Memcache.getMemcacheStore[(AggregationKey, BatchID), DataRecord]( - rtaStore.online(serviceIdentifier)) - - lazy val mergeableStore: MergeableStore[(AggregationKey, BatchID), DataRecord] = - if (jobConfig.enableUserReindexingNighthawkBtreeStore - || jobConfig.enableUserReindexingNighthawkHashStore) { - val reindexingNighthawkBtreeWritableDataRecordStoreList = - if (jobConfig.enableUserReindexingNighthawkBtreeStore) { - lazy val cacheClientNighthawkConfig = - jobConfig.userReindexingNighthawkBtreeStoreConfig.online(serviceIdentifier) - List( - UserReindexingNighthawkWritableDataRecordStore.getBtreeStore( - nighthawkCacheConfig = cacheClientNighthawkConfig, - // Choose a reasonably large target size as this will be equivalent to the number of unique (user, timestamp) - // keys that are returned on read on the pKey, and we may have duplicate authors and associated records. 
- targetSize = 512, - statsReceiver = statsReceiver, - // Assuming trims are relatively expensive, choose a trimRate that's not as aggressive. In this case we trim on - // 10% of all writes. - trimRate = 0.1 - )) - } else { Nil } - val reindexingNighthawkHashWritableDataRecordStoreList = - if (jobConfig.enableUserReindexingNighthawkHashStore) { - lazy val cacheClientNighthawkConfig = - jobConfig.userReindexingNighthawkHashStoreConfig.online(serviceIdentifier) - List( - UserReindexingNighthawkWritableDataRecordStore.getHashStore( - nighthawkCacheConfig = cacheClientNighthawkConfig, - // Choose a reasonably large target size as this will be equivalent to the number of unique (user, timestamp) - // keys that are returned on read on the pKey, and we may have duplicate authors and associated records. - targetSize = 512, - statsReceiver = statsReceiver, - // Assuming trims are relatively expensive, choose a trimRate that's not as aggressive. In this case we trim on - // 10% of all writes. - trimRate = 0.1 - )) - } else { Nil } - - lazy val replicatingWritableStore = new ReplicatingWritableStore( - stores = List(primaryStore) ++ reindexingNighthawkBtreeWritableDataRecordStoreList - ++ reindexingNighthawkHashWritableDataRecordStoreList - ) - - lazy val combinedStoreWithReindexing = new CombinedStore( - read = primaryStore, - write = replicatingWritableStore - ) - - combinedStoreWithReindexing.toMergeable - } else { - primaryStore - } - - lazy val storeFactory: MergeableStoreFactory[(AggregationKey, BatchID), DataRecord] = - Storm.store(mergeableStore)(Batcher.unit) - Some(storeFactory) - } - case _ => None - } - } - - def buildDataRecordSourceToStorm( - jobConfig: RealTimeAggregatesJobConfig - ): (AggregateSource => Option[Producer[Storm, DataRecord]]) = { (source: AggregateSource) => - { - source match { - case stormAggregateSource: StormAggregateSource => - Some(stormAggregateSource.build(statsReceiver, jobConfig)) - case _ => None - } - } - } - - def apply(args: Args): ProductionStormConfig = { - lazy val isProd = args.boolean("production") - lazy val cluster = args.getOrElse("cluster", "") - lazy val isDebug = args.boolean("debug") - lazy val role = args.getOrElse("role", "") - lazy val service = - args.getOrElse( - "service_name", - "" - ) // don't use the argument service, which is a reserved heron argument - lazy val environment = if (isProd) "prod" else "devel" - lazy val s2sEnabled = args.boolean("s2s") - lazy val keyedByUserEnabled = args.boolean("keyed_by_user") - lazy val keyedByAuthorEnabled = args.boolean("keyed_by_author") - - require(clusters.contains(cluster)) - if (s2sEnabled) { - require(role.length() > 0) - require(service.length() > 0) - } - - lazy val serviceIdentifier = if (s2sEnabled) { - ServiceIdentifier( - role = role, - service = service, - environment = environment, - zone = cluster - ) - } else EmptyServiceIdentifier - - lazy val jobConfig = { - val jobConfig = if (isProd) jobConfigs.Prod else jobConfigs.Devel - jobConfig.copy( - serviceIdentifier = serviceIdentifier, - keyedByUserEnabled = keyedByUserEnabled, - keyedByAuthorEnabled = keyedByAuthorEnabled) - } - - lazy val dataRecordSourceToStorm = buildDataRecordSourceToStorm(jobConfig) - lazy val aggregateStoreToStorm = - buildAggregateStoreToStorm(isProd, serviceIdentifier, jobConfig) - - lazy val JaasConfigFlag = "-Djava.security.auth.login.config=resources/jaas.conf" - lazy val JaasDebugFlag = "-Dsun.security.krb5.debug=true" - lazy val JaasConfigString = - if (isDebug) { "%s %s".format(JaasConfigFlag, 
JaasDebugFlag) } - else JaasConfigFlag - - new ProductionStormConfig { - implicit val jobId: JobId = JobId(jobConfig.name) - override val jobName = JobName(jobConfig.name) - override val teamName = TeamName(jobConfig.teamName) - override val teamEmail = TeamEmail(jobConfig.teamEmail) - override val capTicket = CapTicket("n/a") - - val configureHeronJvmSettings = { - val heronJvmOptions = new java.util.HashMap[String, AnyRef]() - jobConfig.componentToRamGigaBytesMap.foreach { - case (component, gigabytes) => - HeronConfig.setComponentRam( - heronJvmOptions, - component, - ByteAmount.fromGigabytes(gigabytes)) - } - - HeronConfig.setContainerRamRequested( - heronJvmOptions, - ByteAmount.fromGigabytes(jobConfig.containerRamGigaBytes) - ) - - jobConfig.componentsToKerberize.foreach { component => - HeronConfig.setComponentJvmOptions( - heronJvmOptions, - component, - JaasConfigString - ) - } - - jobConfig.componentToMetaSpaceSizeMap.foreach { - case (component, metaspaceSize) => - HeronConfig.setComponentJvmOptions( - heronJvmOptions, - component, - metaspaceSize - ) - } - - heronJvmOptions.asScala.toMap ++ AggregatesV2Job - .aggregateNames(aggregatesToCompute).map { - case (prefix, aggNames) => (s"extras.aggregateNames.${prefix}", aggNames) - } - } - - override def transformConfig(m: Map[String, AnyRef]): Map[String, AnyRef] = { - super.transformConfig(m) ++ List( - /** - * Disable acking by setting acker executors to 0. Tuples that come off the - * spout will be immediately acked which effectively disables retries on tuple - * failures. This should help topology throughput/availability by relaxing consistency. - */ - Config.TOPOLOGY_ACKER_EXECUTORS -> int2Integer(0), - Config.TOPOLOGY_WORKERS -> int2Integer(jobConfig.topologyWorkers), - HeronConfig.TOPOLOGY_CONTAINER_CPU_REQUESTED -> int2Integer(8), - HeronConfig.TOPOLOGY_DROPTUPLES_UPON_BACKPRESSURE -> java.lang.Boolean.valueOf(true), - HeronConfig.TOPOLOGY_WORKER_CHILDOPTS -> List( - JaasConfigString, - s"-Dcom.twitter.eventbus.client.zoneName=${cluster}", - "-Dcom.twitter.eventbus.client.EnableKafkaSaslTls=true" - ).mkString(" "), - "storm.job.uniqueId" -> jobId.get - ) ++ configureHeronJvmSettings - - } - - override lazy val getNamedOptions: Map[String, Options] = jobConfig.topologyNamedOptions ++ - Map( - "DEFAULT" -> Options() - .set(flatMapMetrics) - .set(summerMetrics) - .set(MaxWaitingFutures(1000)) - .set(FlushFrequency(30.seconds)) - .set(UseAsyncCache(true)) - .set(AsyncPoolSize(4)) - .set(SourceParallelism(jobConfig.sourceCount)) - .set(SummerBatchMultiplier(1000)), - "FLATMAP" -> Options() - .set(FlatMapParallelism(jobConfig.flatMapCount)) - .set(CacheSize(0)), - "SUMMER" -> Options() - .set(SummerParallelism(jobConfig.summerCount)) - /** - * Sets number of tuples a Summer awaits before aggregation. Set higher - * if you need to lower qps to memcache at the expense of introducing - * some (stable) latency. 
- */ - .set(CacheSize(jobConfig.cacheSize)) - ) - - val featureCounters: Seq[DataRecordFeatureCounter] = - Seq(DataRecordFeatureCounter.any(Counter(Group("feature_counter"), Name("num_records")))) - - override def graph: TailProducer[Storm, Any] = AggregatesV2Job.generateJobGraph[Storm]( - aggregateSet = aggregatesToCompute, - aggregateSourceToSummingbird = dataRecordSourceToStorm, - aggregateStoreToSummingbird = aggregateStoreToStorm, - featureCounters = featureCounters - ) - } - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobConfig.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobConfig.docx new file mode 100644 index 000000000..d500e17fe Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobConfig.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobConfig.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobConfig.scala deleted file mode 100644 index 8bed26264..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/RealTimeAggregatesJobConfig.scala +++ /dev/null @@ -1,79 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron - -import com.twitter.finagle.mtls.authentication.EmptyServiceIdentifier -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.ml.api.DataRecord -import com.twitter.summingbird.Options -import com.twitter.timelines.data_processing.ml_util.transforms.OneToSomeTransform - -/** - * - * @param appId application id for topology job - * @param topologyWorkers number of workers/containers of topology - * @param sourceCount number of parallel sprouts of topology - * @param summerCount number of Summer of topology - * @param cacheSize number of tuples a Summer awaits before aggregation. - * @param flatMapCount number of parallel FlatMap of topology - * @param containerRamGigaBytes total RAM of each worker/container has - * @param name name of topology job - * @param teamName name of team who owns topology job - * @param teamEmail email of team who owns topology job - * @param componentsToKerberize component of topology job (eg. Tail-FlatMap-Source) which enables kerberization - * @param componentToMetaSpaceSizeMap MetaSpaceSize settings for components of topology job - * @param topologyNamedOptions Sets spout allocations for named topology components - * @param serviceIdentifier represents the identifier used for Service to Service Authentication - * @param onlinePreTransforms sequential data record transforms applied to Producer of DataRecord before creating AggregateGroup. - * While preTransforms defined at AggregateGroup are applied to each aggregate group, onlinePreTransforms are applied to the whole producer source. 
- * @param keyedByUserEnabled boolean value to enable/disable merging user-level features from Feature Store - * @param keyedByAuthorEnabled boolean value to enable/disable merging author-level features from Feature Store - * @param enableUserReindexingNighthawkBtreeStore boolean value to enable reindexing RTAs on user id with btree backed nighthawk - * @param enableUserReindexingNighthawkHashStore boolean value to enable reindexing RTAs on user id with hash backed nighthawk - * @param userReindexingNighthawkBtreeStoreConfig NH btree store config used in reindexing user RTAs - * @param userReindexingNighthawkHashStoreConfig NH hash store config used in reindexing user RTAs - */ -case class RealTimeAggregatesJobConfig( - appId: String, - topologyWorkers: Int, - sourceCount: Int, - summerCount: Int, - cacheSize: Int, - flatMapCount: Int, - containerRamGigaBytes: Int, - name: String, - teamName: String, - teamEmail: String, - componentsToKerberize: Seq[String] = Seq.empty, - componentToMetaSpaceSizeMap: Map[String, String] = Map.empty, - componentToRamGigaBytesMap: Map[String, Int] = Map("Tail" -> 4), - topologyNamedOptions: Map[String, Options] = Map.empty, - serviceIdentifier: ServiceIdentifier = EmptyServiceIdentifier, - onlinePreTransforms: Seq[OneToSomeTransform] = Seq.empty, - keyedByUserEnabled: Boolean = false, - keyedByAuthorEnabled: Boolean = false, - keyedByTweetEnabled: Boolean = false, - enableUserReindexingNighthawkBtreeStore: Boolean = false, - enableUserReindexingNighthawkHashStore: Boolean = false, - userReindexingNighthawkBtreeStoreConfig: NighthawkUnderlyingStoreConfig = - NighthawkUnderlyingStoreConfig(), - userReindexingNighthawkHashStoreConfig: NighthawkUnderlyingStoreConfig = - NighthawkUnderlyingStoreConfig()) { - - /** - * Apply transforms sequentially. If any transform results in a dropped (None) - * DataRecord, then entire transform sequence will result in a dropped DataRecord. - * Note that transforms are order-dependent. 
- */ - def sequentiallyTransform(dataRecord: DataRecord): Option[DataRecord] = { - val recordOpt = Option(new DataRecord(dataRecord)) - onlinePreTransforms.foldLeft(recordOpt) { - case (Some(previousRecord), preTransform) => - preTransform(previousRecord) - case _ => Option.empty[DataRecord] - } - } -} - -trait RealTimeAggregatesJobConfigs { - def Prod: RealTimeAggregatesJobConfig - def Devel: RealTimeAggregatesJobConfig -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/StormAggregateSource.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/StormAggregateSource.docx new file mode 100644 index 000000000..e63cd5026 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/StormAggregateSource.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/StormAggregateSource.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/StormAggregateSource.scala deleted file mode 100644 index a252cf197..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/StormAggregateSource.scala +++ /dev/null @@ -1,27 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron - -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import com.twitter.summingbird._ -import com.twitter.summingbird.storm.Storm -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregateSource -import java.lang.{Long => JLong} - -/** - * Use this trait to implement online summingbird producer that subscribes to - * spouts and generates a data record. - */ -trait StormAggregateSource extends AggregateSource { - def name: String - - def timestampFeature: Feature[JLong] - - /** - * Constructs the storm Producer with the implemented topology at runtime. 
- */ - def build( - statsReceiver: StatsReceiver, - jobConfig: RealTimeAggregatesJobConfig - ): Producer[Storm, DataRecord] -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/UserReindexingNighthawkStore.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/UserReindexingNighthawkStore.docx new file mode 100644 index 000000000..98fb38f11 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/UserReindexingNighthawkStore.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/UserReindexingNighthawkStore.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/UserReindexingNighthawkStore.scala deleted file mode 100644 index a4d2adeac..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/UserReindexingNighthawkStore.scala +++ /dev/null @@ -1,309 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron - -import com.twitter.bijection.Injection -import com.twitter.bijection.thrift.CompactThriftCodec -import com.twitter.cache.client._ -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.storehaus.WritableStore -import com.twitter.storehaus_internal.nighthawk_kv.CacheClientNighthawkConfig -import com.twitter.storehaus_internal.nighthawk_kv.NighthawkStore -import com.twitter.summingbird.batch.BatchID -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.heron.UserReindexingNighthawkWritableDataRecordStore._ -import com.twitter.timelines.prediction.features.common.TimelinesSharedFeatures -import com.twitter.util.Future -import com.twitter.util.Time -import com.twitter.util.Try -import com.twitter.util.logging.Logger -import java.nio.ByteBuffer -import java.util -import scala.util.Random - -object UserReindexingNighthawkWritableDataRecordStore { - implicit val longInjection = Injection.long2BigEndian - implicit val dataRecordInjection: Injection[DataRecord, Array[Byte]] = - CompactThriftCodec[DataRecord] - val arrayToByteBuffer = Injection.connect[Array[Byte], ByteBuffer] - val longToByteBuffer = longInjection.andThen(arrayToByteBuffer) - val dataRecordToByteBuffer = dataRecordInjection.andThen(arrayToByteBuffer) - - def getBtreeStore( - nighthawkCacheConfig: CacheClientNighthawkConfig, - targetSize: Int, - statsReceiver: StatsReceiver, - trimRate: Double - ): UserReindexingNighthawkBtreeWritableDataRecordStore = - new UserReindexingNighthawkBtreeWritableDataRecordStore( - nighthawkStore = NighthawkStore[UserId, TimestampMs, DataRecord](nighthawkCacheConfig) - .asInstanceOf[NighthawkStore[UserId, TimestampMs, DataRecord]], - tableName = nighthawkCacheConfig.table.toString, - targetSize = targetSize, - statsReceiver = statsReceiver, - trimRate = trimRate - ) - - def getHashStore( - nighthawkCacheConfig: CacheClientNighthawkConfig, - targetSize: Int, - statsReceiver: StatsReceiver, - trimRate: Double - ): UserReindexingNighthawkHashWritableDataRecordStore = - new UserReindexingNighthawkHashWritableDataRecordStore( - nighthawkStore = NighthawkStore[UserId, AuthorId, DataRecord](nighthawkCacheConfig) - .asInstanceOf[NighthawkStore[UserId, AuthorId, DataRecord]], - 
tableName = nighthawkCacheConfig.table.toString, - targetSize = targetSize, - statsReceiver = statsReceiver, - trimRate = trimRate - ) - - def buildTimestampedByteBuffer(timestamp: Long, bb: ByteBuffer): ByteBuffer = { - val timestampedBb = ByteBuffer.allocate(getLength(bb) + java.lang.Long.SIZE) - timestampedBb.putLong(timestamp) - timestampedBb.put(bb) - timestampedBb - } - - def extractTimestampFromTimestampedByteBuffer(bb: ByteBuffer): Long = { - bb.getLong(0) - } - - def extractValueFromTimestampedByteBuffer(bb: ByteBuffer): ByteBuffer = { - val bytes = new Array[Byte](getLength(bb) - java.lang.Long.SIZE) - util.Arrays.copyOfRange(bytes, java.lang.Long.SIZE, getLength(bb)) - ByteBuffer.wrap(bytes) - } - - def transformAndBuildKeyValueMapping( - table: String, - userId: UserId, - authorIdsAndDataRecords: Seq[(AuthorId, DataRecord)] - ): KeyValue = { - val timestamp = Time.now.inMillis - val pkey = longToByteBuffer(userId) - val lkeysAndTimestampedValues = authorIdsAndDataRecords.map { - case (authorId, dataRecord) => - val lkey = longToByteBuffer(authorId) - // Create a byte buffer with a prepended timestamp to reduce deserialization cost - // when parsing values. We only have to extract and deserialize the timestamp in the - // ByteBuffer in order to sort the value, as opposed to deserializing the DataRecord - // and having to get a timestamp feature value from the DataRecord. - val dataRecordBb = dataRecordToByteBuffer(dataRecord) - val timestampedValue = buildTimestampedByteBuffer(timestamp, dataRecordBb) - (lkey, timestampedValue) - } - buildKeyValueMapping(table, pkey, lkeysAndTimestampedValues) - } - - def buildKeyValueMapping( - table: String, - pkey: ByteBuffer, - lkeysAndTimestampedValues: Seq[(ByteBuffer, ByteBuffer)] - ): KeyValue = { - val lkeys = lkeysAndTimestampedValues.map { case (lkey, _) => lkey } - val timestampedValues = lkeysAndTimestampedValues.map { case (_, value) => value } - val kv = KeyValue( - key = Key(table = table, pkey = pkey, lkeys = lkeys), - value = Value(timestampedValues) - ) - kv - } - - private def getLength(bb: ByteBuffer): Int = { - // capacity can be an over-estimate of the actual length (remaining - start position) - // but it's the safest to avoid overflows. - bb.capacity() - } -} - -/** - * Implements a NH store that stores aggregate feature DataRecords using userId as the primary key. - * - * This store re-indexes user-author keyed real-time aggregate (RTA) features on userId by - * writing to a userId primary key (pkey) and timestamp secondary key (lkey). To fetch user-author - * RTAs for a given user from cache, the caller just needs to make a single RPC for the userId pkey. - * The downside of a re-indexing store is that we cannot store arbitrarily many secondary keys - * under the primary key. This specific implementation using the NH btree backend also mandates - * mandates an ordering of secondary keys - we therefore use timestamp as the secondary key - * as opposed to say authorId. - * - * Note that a caller of the btree backed NH re-indexing store receives back a response where the - * secondary key is a timestamp. The associated value is a DataRecord containing user-author related - * aggregate features which was last updated at the timestamp. The caller therefore needs to handle - * the response and dedupe on unique, most recent user-author pairs. 
- * - * For a discussion on this and other implementations, please see: - * https://docs.google.com/document/d/1yVzAbQ_ikLqwSf230URxCJmSKj5yZr5dYv6TwBlQw18/edit - */ -class UserReindexingNighthawkBtreeWritableDataRecordStore( - nighthawkStore: NighthawkStore[UserId, TimestampMs, DataRecord], - tableName: String, - targetSize: Int, - statsReceiver: StatsReceiver, - trimRate: Double = 0.1 // by default, trim on 10% of puts -) extends WritableStore[(AggregationKey, BatchID), Option[DataRecord]] { - - private val scope = getClass.getSimpleName - private val failures = statsReceiver.counter(scope, "failures") - private val log = Logger.getLogger(getClass) - private val random: Random = new Random(1729L) - - override def put(kv: ((AggregationKey, BatchID), Option[DataRecord])): Future[Unit] = { - val ((aggregationKey, _), dataRecordOpt) = kv - // Fire-and-forget below because the store itself should just be a side effect - // as it's just making re-indexed writes based on the writes to the primary store. - for { - userId <- aggregationKey.discreteFeaturesById.get(SharedFeatures.USER_ID.getFeatureId) - dataRecord <- dataRecordOpt - } yield { - SRichDataRecord(dataRecord) - .getFeatureValueOpt(TypedAggregateGroup.timestampFeature) - .map(_.toLong) // convert to Scala Long - .map { timestamp => - val trim: Future[Unit] = if (random.nextDouble <= trimRate) { - val trimKey = TrimKey( - table = tableName, - pkey = longToByteBuffer(userId), - targetSize = targetSize, - ascending = true - ) - nighthawkStore.client.trim(Seq(trimKey)).unit - } else { - Future.Unit - } - // We should wait for trim to complete above - val fireAndForget = trim.before { - val kvTuple = ((userId, timestamp), Some(dataRecord)) - nighthawkStore.put(kvTuple) - } - - fireAndForget.onFailure { - case e => - failures.incr() - log.error("Failure in UserReindexingNighthawkHashWritableDataRecordStore", e) - } - } - } - // Ignore fire-and-forget result above and simply return - Future.Unit - } -} - -/** - * Implements a NH store that stores aggregate feature DataRecords using userId as the primary key. - * - * This store re-indexes user-author keyed real-time aggregate (RTA) features on userId by - * writing to a userId primary key (pkey) and authorId secondary key (lkey). To fetch user-author - * RTAs for a given user from cache, the caller just needs to make a single RPC for the userId pkey. - * The downside of a re-indexing store is that we cannot store arbitrarily - * many secondary keys under the primary key. We have to limit them in some way; - * here, we do so by randomly (based on trimRate) issuing an HGETALL command (via scan) to - * retrieve the whole hash, sort by oldest timestamp, and then remove the oldest authors to keep - * only targetSize authors (aka trim), where targetSize is configurable. - * - * @note The full hash returned from scan could be as large (or even larger) than targetSize, - * which could mean many DataRecords to deserialize, especially at high write qps. - * To reduce deserialization cost post-scan, we use timestamped values with a prepended timestamp - * in the value ByteBuffer; this allows us to only deserialize the timestamp and not the full - * DataRecord when sorting. This is necessary in order to identify the oldest values to trim. - * When we do a put for a new (user, author) pair, we also write out timestamped values. 
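The timestamped-value layout described in this note is simply an 8-byte big-endian timestamp followed by the serialized payload, so a trim only has to read the first Long of each value. A minimal standalone sketch (hypothetical helper names, not the store's own buildTimestampedByteBuffer helpers):

import java.nio.ByteBuffer

object TimestampedValueSketch {
  // Prefix the serialized payload with its update time.
  def encode(timestampMs: Long, payload: Array[Byte]): ByteBuffer = {
    val bb = ByteBuffer.allocate(java.lang.Long.BYTES + payload.length)
    bb.putLong(timestampMs)
    bb.put(payload)
    bb.flip() // switch from writing to reading
    bb
  }

  // Read only the 8-byte timestamp; no DataRecord deserialization is needed for sorting.
  def timestampOf(bb: ByteBuffer): Long = bb.getLong(0)

  // Recover the payload bytes that follow the timestamp prefix.
  def payloadOf(bb: ByteBuffer): Array[Byte] = {
    val out = new Array[Byte](bb.remaining() - java.lang.Long.BYTES)
    val dup = bb.duplicate() // leave the caller's buffer position untouched
    dup.position(java.lang.Long.BYTES)
    dup.get(out)
    out
  }
}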
- * - * For a discussion on this and other implementations, please see: - * https://docs.google.com/document/d/1yVzAbQ_ikLqwSf230URxCJmSKj5yZr5dYv6TwBlQw18/edit - */ -class UserReindexingNighthawkHashWritableDataRecordStore( - nighthawkStore: NighthawkStore[UserId, AuthorId, DataRecord], - tableName: String, - targetSize: Int, - statsReceiver: StatsReceiver, - trimRate: Double = 0.1 // by default, trim on 10% of puts -) extends WritableStore[(AggregationKey, BatchID), Option[DataRecord]] { - - private val scope = getClass.getSimpleName - private val scanMismatchErrors = statsReceiver.counter(scope, "scanMismatchErrors") - private val failures = statsReceiver.counter(scope, "failures") - private val log = Logger.getLogger(getClass) - private val random: Random = new Random(1729L) - private val arrayToByteBuffer = Injection.connect[Array[Byte], ByteBuffer] - private val longToByteBuffer = Injection.long2BigEndian.andThen(arrayToByteBuffer) - - override def put(kv: ((AggregationKey, BatchID), Option[DataRecord])): Future[Unit] = { - val ((aggregationKey, _), dataRecordOpt) = kv - // Fire-and-forget below because the store itself should just be a side effect - // as it's just making re-indexed writes based on the writes to the primary store. - for { - userId <- aggregationKey.discreteFeaturesById.get(SharedFeatures.USER_ID.getFeatureId) - authorId <- aggregationKey.discreteFeaturesById.get( - TimelinesSharedFeatures.SOURCE_AUTHOR_ID.getFeatureId) - dataRecord <- dataRecordOpt - } yield { - val scanAndTrim: Future[Unit] = if (random.nextDouble <= trimRate) { - val scanKey = ScanKey( - table = tableName, - pkey = longToByteBuffer(userId) - ) - nighthawkStore.client.scan(Seq(scanKey)).flatMap { scanResults: Seq[Try[KeyValue]] => - scanResults.headOption - .flatMap(_.toOption).map { keyValue: KeyValue => - val lkeys: Seq[ByteBuffer] = keyValue.key.lkeys - // these are timestamped bytebuffers - val timestampedValues: Seq[ByteBuffer] = keyValue.value.values - // this should fail loudly if this is not true. it would indicate - // there is a mistake in the scan. - if (lkeys.size != timestampedValues.size) scanMismatchErrors.incr() - assert(lkeys.size == timestampedValues.size) - if (lkeys.size > targetSize) { - val numToRemove = targetSize - lkeys.size - // sort by oldest and take top k oldest and remove - this is equivalent to a trim - val oldestKeys: Seq[ByteBuffer] = lkeys - .zip(timestampedValues) - .map { - case (lkey, timestampedValue) => - val timestamp = extractTimestampFromTimestampedByteBuffer(timestampedValue) - (timestamp, lkey) - } - .sortBy { case (timestamp, _) => timestamp } - .take(numToRemove) - .map { case (_, k) => k } - val pkey = longToByteBuffer(userId) - val key = Key(table = tableName, pkey = pkey, lkeys = oldestKeys) - // NOTE: `remove` is a batch API, and we group all lkeys into a single batch (batch - // size = single group of lkeys = 1). Instead, we could separate lkeys into smaller - // groups and have batch size = number of groups, but this is more complex. - // Performance implications of batching vs non-batching need to be assessed. 
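A minimal sketch of just the trim selection described above, assuming values carry the 8-byte timestamp prefix; the helper name is hypothetical, and only the oldest entries beyond targetSize are chosen for removal:

import java.nio.ByteBuffer

object ScanTrimSketch {
  // Given the scanned lkeys and their timestamped values, return the lkeys of the
  // oldest entries so that only the newest `targetSize` entries remain.
  def selectLkeysToTrim(
    lkeys: Seq[ByteBuffer],
    timestampedValues: Seq[ByteBuffer],
    targetSize: Int
  ): Seq[ByteBuffer] = {
    require(lkeys.size == timestampedValues.size, "scan returned mismatched lkey/value columns")
    val numToRemove = lkeys.size - targetSize
    if (numToRemove <= 0) {
      Seq.empty
    } else {
      lkeys
        .zip(timestampedValues)
        .map { case (lkey, value) => (value.getLong(0), lkey) } // read the timestamp prefix only
        .sortBy { case (timestampMs, _) => timestampMs }        // oldest first
        .take(numToRemove)
        .map { case (_, lkey) => lkey }
    }
  }
}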
- nighthawkStore.client - .remove(Seq(key)) - .map { responses => - responses.map(resp => nighthawkStore.processValue(resp)) - }.unit - } else { - Future.Unit - } - }.getOrElse(Future.Unit) - } - } else { - Future.Unit - } - // We should wait for scan and trim to complete above - val fireAndForget = scanAndTrim.before { - val kv = transformAndBuildKeyValueMapping(tableName, userId, Seq((authorId, dataRecord))) - nighthawkStore.client - .put(Seq(kv)) - .map { responses => - responses.map(resp => nighthawkStore.processValue(resp)) - }.unit - } - fireAndForget.onFailure { - case e => - failures.incr() - log.error("Failure in UserReindexingNighthawkHashWritableDataRecordStore", e) - } - } - // Ignore fire-and-forget result above and simply return - Future.Unit - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/package.docx b/timelines/data_processing/ml_util/aggregation_framework/heron/package.docx new file mode 100644 index 000000000..54ba76a44 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/heron/package.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/heron/package.scala b/timelines/data_processing/ml_util/aggregation_framework/heron/package.scala deleted file mode 100644 index e995cf202..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/heron/package.scala +++ /dev/null @@ -1,8 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework - -package object heron { - // NOTE: please sort alphabetically - type AuthorId = Long - type UserId = Long - type TimestampMs = Long -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/job/AggregatesV2Job.docx b/timelines/data_processing/ml_util/aggregation_framework/job/AggregatesV2Job.docx new file mode 100644 index 000000000..f5c484b5b Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/job/AggregatesV2Job.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/job/AggregatesV2Job.scala b/timelines/data_processing/ml_util/aggregation_framework/job/AggregatesV2Job.scala deleted file mode 100644 index 7d9e1946e..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/job/AggregatesV2Job.scala +++ /dev/null @@ -1,163 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.job - -import com.twitter.algebird.Semigroup -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.DataRecordMerger -import com.twitter.summingbird.Platform -import com.twitter.summingbird.Producer -import com.twitter.summingbird.TailProducer -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregateSource -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregateStore -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup - -object AggregatesV2Job { - private lazy val merger = new DataRecordMerger - - /** - * Merges all "incremental" records with the same aggregation key - * into a single record. 
- * - * @param recordsPerKey A set of (AggregationKey, DataRecord) tuples - * known to share the same AggregationKey - * @return A single merged datarecord - */ - def mergeRecords(recordsPerKey: Set[(AggregationKey, DataRecord)]): DataRecord = - recordsPerKey.foldLeft(new DataRecord) { - case (merged: DataRecord, (key: AggregationKey, elem: DataRecord)) => { - merger.merge(merged, elem) - merged - } - } - - /** - * Given a set of aggregates to compute and a datarecord, extract key-value - * pairs to output to the summingbird store. - * - * @param dataRecord input data record - * @param aggregates set of aggregates to compute - * @param featureCounters counters to apply to each input data record - * @return computed aggregates - */ - def computeAggregates( - dataRecord: DataRecord, - aggregates: Set[TypedAggregateGroup[_]], - featureCounters: Seq[DataRecordFeatureCounter] - ): Map[AggregationKey, DataRecord] = { - val computedAggregates = aggregates - .flatMap(_.computeAggregateKVPairs(dataRecord)) - .groupBy { case (aggregationKey: AggregationKey, _) => aggregationKey } - .mapValues(mergeRecords) - - featureCounters.foreach(counter => - computedAggregates.map(agg => DataRecordFeatureCounter(counter, agg._2))) - - computedAggregates - - } - - /** - * Util method to apply a filter on containment in an optional set. - * - * @param setOptional Optional set of items to check containment in. - * @param toCheck Item to check if contained in set. - * @return If the optional set is None, returns true. - */ - def setFilter[T](setOptional: Option[Set[T]], toCheck: T): Boolean = - setOptional.map(_.contains(toCheck)).getOrElse(true) - - /** - * Util for filtering a collection of `TypedAggregateGroup` - * - * @param aggregates a set of aggregates - * @param sourceNames Optional filter on which AggregateGroups to process - * based on the name of the input source. - * @param storeNames Optional filter on which AggregateGroups to process - * based on the name of the output store. - * @return filtered aggregates - */ - def filterAggregates( - aggregates: Set[TypedAggregateGroup[_]], - sourceNames: Option[Set[String]], - storeNames: Option[Set[String]] - ): Set[TypedAggregateGroup[_]] = - aggregates - .filter { aggregateGroup => - val sourceName = aggregateGroup.inputSource.name - val storeName = aggregateGroup.outputStore.name - val containsSource = setFilter(sourceNames, sourceName) - val containsStore = setFilter(storeNames, storeName) - containsSource && containsStore - } - - /** - * The core summingbird job code. - * - * For each aggregate in the set passed in, the job - * processes all datarecords in the input producer - * stream to generate "incremental" contributions to - * these aggregates, and emits them grouped by - * aggregation key so that summingbird can aggregate them. - * - * It is important that after applying the sourceNameFilter and storeNameFilter, - * all the result AggregateGroups share the same startDate, otherwise the job - * will fail or give invalid results. - * - * @param aggregateSet A set of aggregates to compute. All aggregates - * in this set that pass the sourceNameFilter and storeNameFilter - * defined below, if any, will be computed. - * @param aggregateSourceToSummingbird Function that maps from our logical - * AggregateSource abstraction to the underlying physical summingbird - * producer of data records to aggregate (e.g. 
scalding/eventbus source) - * @param aggregateStoreToSummingbird Function that maps from our logical - * AggregateStore abstraction to the underlying physical summingbird - * store to write output aggregate records to (e.g. mahattan for scalding, - * or memcache for heron) - * @param featureCounters counters to use with each input DataRecord - * @return summingbird tail producer - */ - def generateJobGraph[P <: Platform[P]]( - aggregateSet: Set[TypedAggregateGroup[_]], - aggregateSourceToSummingbird: AggregateSource => Option[Producer[P, DataRecord]], - aggregateStoreToSummingbird: AggregateStore => Option[P#Store[AggregationKey, DataRecord]], - featureCounters: Seq[DataRecordFeatureCounter] = Seq.empty - )( - implicit semigroup: Semigroup[DataRecord] - ): TailProducer[P, Any] = { - val tailProducerList: List[TailProducer[P, Any]] = aggregateSet - .groupBy { aggregate => (aggregate.inputSource, aggregate.outputStore) } - .flatMap { - case ( - (inputSource: AggregateSource, outputStore: AggregateStore), - aggregatesInThisStore - ) => { - val producerOpt = aggregateSourceToSummingbird(inputSource) - val storeOpt = aggregateStoreToSummingbird(outputStore) - - (producerOpt, storeOpt) match { - case (Some(producer), Some(store)) => - Some( - producer - .flatMap(computeAggregates(_, aggregatesInThisStore, featureCounters)) - .name("FLATMAP") - .sumByKey(store) - .name("SUMMER") - ) - case _ => None - } - } - } - .toList - - tailProducerList.reduceLeft { (left, right) => left.also(right) } - } - - def aggregateNames(aggregateSet: Set[TypedAggregateGroup[_]]) = { - aggregateSet - .map(typedGroup => - ( - typedGroup.aggregatePrefix, - typedGroup.individualAggregateDescriptors - .flatMap(_.outputFeatures.map(_.getFeatureName)).mkString(","))) - }.toMap -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/job/BUILD b/timelines/data_processing/ml_util/aggregation_framework/job/BUILD deleted file mode 100644 index 57593fa34..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/job/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -scala_library( - sources = ["*.scala"], - platform = "java8", - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/com/twitter/algebird:core", - "3rdparty/jvm/com/twitter/algebird:util", - "3rdparty/jvm/com/twitter/storehaus:algebra", - "3rdparty/jvm/com/twitter/storehaus:core", - "3rdparty/src/jvm/com/twitter/scalding:commons", - "3rdparty/src/jvm/com/twitter/scalding:core", - "3rdparty/src/jvm/com/twitter/summingbird:batch", - "3rdparty/src/jvm/com/twitter/summingbird:core", - "src/java/com/twitter/ml/api:api-base", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:interpretable-model-java", - "timelines/data_processing/ml_util/aggregation_framework:common_types", - ], -) diff --git a/timelines/data_processing/ml_util/aggregation_framework/job/BUILD.docx b/timelines/data_processing/ml_util/aggregation_framework/job/BUILD.docx new file mode 100644 index 000000000..1be10f1ee Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/job/BUILD.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/job/DataRecordFeatureCounter.docx b/timelines/data_processing/ml_util/aggregation_framework/job/DataRecordFeatureCounter.docx new file mode 100644 index 000000000..23081d6c0 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/job/DataRecordFeatureCounter.docx differ diff --git 
a/timelines/data_processing/ml_util/aggregation_framework/job/DataRecordFeatureCounter.scala b/timelines/data_processing/ml_util/aggregation_framework/job/DataRecordFeatureCounter.scala deleted file mode 100644 index eb1580a11..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/job/DataRecordFeatureCounter.scala +++ /dev/null @@ -1,39 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.job - -import com.twitter.ml.api.DataRecord -import com.twitter.summingbird.Counter - -/** - * A summingbird Counter which is associated with a predicate which operates on - * [[com.twitter.ml.api.DataRecord]] instances. - * - * For example, for a data record which represents a Tweet, one could define a predicate - * which checks whether the Tweet contains a binary feature representing the presence of - * an image. The counter can then be used to represent the the count of Tweets with - * images processed. - * - * @param predicate a predicate which gates the counter - * @param counter a summingbird Counter instance - */ -case class DataRecordFeatureCounter(predicate: DataRecord => Boolean, counter: Counter) - -object DataRecordFeatureCounter { - - /** - * Increments the counter if the record satisfies the predicate - * - * @param recordCounter a data record counter - * @param record a data record - */ - def apply(recordCounter: DataRecordFeatureCounter, record: DataRecord): Unit = - if (recordCounter.predicate(record)) recordCounter.counter.incr() - - /** - * Defines a feature counter with a predicate that is always true - * - * @param counter a summingbird Counter instance - * @return a data record counter - */ - def any(counter: Counter): DataRecordFeatureCounter = - DataRecordFeatureCounter({ _: DataRecord => true }, counter) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregateFeature.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregateFeature.docx new file mode 100644 index 000000000..d7c72471a Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregateFeature.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregateFeature.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregateFeature.scala deleted file mode 100644 index 4f80490bc..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregateFeature.scala +++ /dev/null @@ -1,51 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.util.Duration -import com.twitter.ml.api._ -import java.lang.{Boolean => JBoolean} - -/** - * Case class used as shared argument for - * getAggregateValue() and setAggregateValue() in AggregationMetric. - * - * @param aggregatePrefix Prefix for aggregate feature name - * @param feature Simple (non-aggregate) feature being aggregated. This - is optional; if None, then the label is aggregated on its own without - being crossed with any feature. - * @param label Label being paired with. This is optional; if None, then - the feature is aggregated on its own without being crossed with any label. 
- * @param halfLife Half life being used for aggregation - */ -case class AggregateFeature[T]( - aggregatePrefix: String, - feature: Option[Feature[T]], - label: Option[Feature[JBoolean]], - halfLife: Duration) { - val aggregateType = "pair" - val labelName: String = label.map(_.getDenseFeatureName()).getOrElse("any_label") - val featureName: String = feature.map(_.getDenseFeatureName()).getOrElse("any_feature") - - /* - * This val precomputes a portion of the feature name - * for faster processing. String building turns - * out to be a significant bottleneck. - */ - val featurePrefix: String = List( - aggregatePrefix, - aggregateType, - labelName, - featureName, - halfLife.toString - ).mkString(".") -} - -/* Companion object with util methods. */ -object AggregateFeature { - def parseHalfLife(aggregateFeature: Feature[_]): Duration = { - val aggregateComponents = aggregateFeature.getDenseFeatureName().split("\\.") - val numComponents = aggregateComponents.length - val halfLifeStr = aggregateComponents(numComponents - 3) + "." + - aggregateComponents(numComponents - 2) - Duration.parse(halfLifeStr) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetric.docx new file mode 100644 index 000000000..bc76ff163 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetric.scala deleted file mode 100644 index 4278c8812..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetric.scala +++ /dev/null @@ -1,184 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.util.Duration -import java.lang.{Long => JLong} - -/** - * Represents an aggregation operator (e.g. count or mean). - * Override all functions in this trait to implement your own metric. - * The operator is parameterized on an input type T, which is the type - * of feature it aggregates, and a TimedValue[A] which is - * the result type of aggregation for this metric. - */ -trait AggregationMetric[T, A] extends FeatureCache[T] { - /* - * Combines two timed aggregate values ''left'' and ''right'' - * with the specified half life ''halfLife'' to produce a result - * TimedValue - * - * @param left Left timed value - * @param right Right timed value - * @param halfLife Half life to use for adding timed values - * @return Result timed value - */ - def plus(left: TimedValue[A], right: TimedValue[A], halfLife: Duration): TimedValue[A] - - /* - * Gets increment value given a datarecord and a feature. - * - * @param dataRecord to get increment value from. - * @param feature Feature to get increment value for. If None, - then the semantics is to just aggregate the label. - * @param timestampFeature Feature to use as millisecond timestamp - for decayed value aggregation. - * @return The incremental contribution to the aggregate of ''feature'' from ''dataRecord''. - * - * For example, if the aggregation metric is count, the incremental - * contribution is always a TimedValue (1.0, time). 
If the aggregation metric - * is mean, and the feature is a continuous feature (double), the incremental - * contribution looks like a tuple (value, 1.0, time) - */ - def getIncrementValue( - dataRecord: DataRecord, - feature: Option[Feature[T]], - timestampFeature: Feature[JLong] - ): TimedValue[A] - - /* - * The "zero" value for aggregation. - * For example, the zero is 0 for the count operator. - */ - def zero(timeOpt: Option[Long] = None): TimedValue[A] - - /* - * Gets the value of aggregate feature(s) stored in a datarecord, if any. - * Different aggregate operators might store this info in the datarecord - * differently. E.g. count just stores a count, while mean needs to - * store both a sum and a count, and compile them into a TimedValue. We call - * these features stored in the record "output" features. - * - * @param record Record to get value from - * @param query AggregateFeature (see above) specifying details of aggregate - * @param aggregateOutputs An optional precomputed set of aggregation "output" - * feature hashes for this (query, metric) pair. This can be derived from ''query'', - * but we precompute and pass this in for significantly (approximately 4x = 400%) - * faster performance. If not passed in, the operator should reconstruct these features - * from scratch. - * - * @return The aggregate value if found in ''record'', else the appropriate "zero" - for this type of aggregation. - */ - def getAggregateValue( - record: DataRecord, - query: AggregateFeature[T], - aggregateOutputs: Option[List[JLong]] = None - ): TimedValue[A] - - /* - * Sets the value of aggregate feature(s) in a datarecord. Different operators - * will have different representations (see example above). - * - * @param record Record to set value in - * @param query AggregateFeature (see above) specifying details of aggregate - * @param aggregateOutputs An optional precomputed set of aggregation "output" - * features for this (query, metric) pair. This can be derived from ''query'', - * but we precompute and pass this in for significantly (approximately 4x = 400%) - * faster performance. If not passed in, the operator should reconstruct these features - * from scratch. - * - * @param value Value to set for aggregate feature in the record being passed in via ''query'' - */ - def setAggregateValue( - record: DataRecord, - query: AggregateFeature[T], - aggregateOutputs: Option[List[JLong]] = None, - value: TimedValue[A] - ): Unit - - /** - * Get features used to store aggregate output representation - * in partially aggregated data records. - * - * @query AggregateFeature (see above) specifying details of aggregate - * @return A list of "output" features used by this metric to store - * output representation. For example, for the "count" operator, we - * have only one element in this list, which is the result "count" feature. - * For the "mean" operator, we have three elements in this list: the "count" - * feature, the "sum" feature and the "mean" feature. - */ - def getOutputFeatures(query: AggregateFeature[T]): List[Feature[_]] - - /** - * Get feature hashes used to store aggregate output representation - * in partially aggregated data records. - * - * @query AggregateFeature (see above) specifying details of aggregate - * @return A list of "output" feature hashes used by this metric to store - * output representation. For example, for the "count" operator, we - * have only one element in this list, which is the result "count" feature. 
- * For the "mean" operator, we have three elements in this list: the "count" - * feature, the "sum" feature and the "mean" feature. - */ - def getOutputFeatureIds(query: AggregateFeature[T]): List[JLong] = - getOutputFeatures(query) - .map(_.getDenseFeatureId().asInstanceOf[JLong]) - - /* - * Sums the given feature in two datarecords into a result record - * WARNING: this method has side-effects; it modifies combined - * - * @param combined Result datarecord to mutate and store addition result in - * @param left Left datarecord to add - * @param right Right datarecord to add - * @param query Details of aggregate to add - * @param aggregateOutputs An optional precomputed set of aggregation "output" - * feature hashes for this (query, metric) pair. This can be derived from ''query'', - * but we precompute and pass this in for significantly (approximately 4x = 400%) - * faster performance. If not passed in, the operator should reconstruct these features - * from scratch. - */ - def mutatePlus( - combined: DataRecord, - left: DataRecord, - right: DataRecord, - query: AggregateFeature[T], - aggregateOutputs: Option[List[JLong]] = None - ): Unit = { - val leftValue = getAggregateValue(left, query, aggregateOutputs) - val rightValue = getAggregateValue(right, query, aggregateOutputs) - val combinedValue = plus(leftValue, rightValue, query.halfLife) - setAggregateValue(combined, query, aggregateOutputs, combinedValue) - } - - /** - * Helper function to get increment value from an input DataRecord - * and copy it to an output DataRecord, given an AggregateFeature query spec. - * - * @param output Datarecord to output increment to (will be mutated by this method) - * @param input Datarecord to get increment from - * @param query Details of aggregation - * @param aggregateOutputs An optional precomputed set of aggregation "output" - * feature hashes for this (query, metric) pair. This can be derived from ''query'', - * but we precompute and pass this in for significantly (approximately 4x = 400%) - * faster performance. If not passed in, the operator should reconstruct these features - * from scratch. 
- * @return True if an increment was set in the output record, else false - */ - def setIncrement( - output: DataRecord, - input: DataRecord, - query: AggregateFeature[T], - timestampFeature: Feature[JLong] = SharedFeatures.TIMESTAMP, - aggregateOutputs: Option[List[JLong]] = None - ): Boolean = { - if (query.label == None || - (query.label.isDefined && SRichDataRecord(input).hasFeature(query.label.get))) { - val incrementValue: TimedValue[A] = getIncrementValue(input, query.feature, timestampFeature) - setAggregateValue(output, query, aggregateOutputs, incrementValue) - true - } else false - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetricCommon.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetricCommon.docx new file mode 100644 index 000000000..ada271ffb Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetricCommon.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetricCommon.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetricCommon.scala deleted file mode 100644 index e7b97e07b..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/AggregationMetricCommon.scala +++ /dev/null @@ -1,55 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.algebird.DecayedValue -import com.twitter.algebird.DecayedValueMonoid -import com.twitter.algebird.Monoid -import com.twitter.dal.personal_data.thriftjava.PersonalDataType -import com.twitter.ml.api._ -import com.twitter.ml.api.constant.SharedFeatures -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.util.Duration -import java.lang.{Long => JLong} -import java.util.{HashSet => JHashSet} -import java.util.{Set => JSet} - -object AggregationMetricCommon { - /* Shared definitions and utils that can be reused by child classes */ - val Epsilon: Double = 1e-6 - val decayedValueMonoid: Monoid[DecayedValue] = DecayedValueMonoid(Epsilon) - val TimestampHash: JLong = SharedFeatures.TIMESTAMP.getDenseFeatureId() - - def toDecayedValue(tv: TimedValue[Double], halfLife: Duration): DecayedValue = { - DecayedValue.build( - tv.value, - tv.timestamp.inMilliseconds, - halfLife.inMilliseconds - ) - } - - def getTimestamp( - record: DataRecord, - timestampFeature: Feature[JLong] = SharedFeatures.TIMESTAMP - ): Long = { - Option( - SRichDataRecord(record) - .getFeatureValue(timestampFeature) - ).map(_.toLong) - .getOrElse(0L) - } - - /* - * Union the PDTs of the input featureOpts. 
- * Return null if empty, else the JSet[PersonalDataType] - */ - def derivePersonalDataTypes(features: Option[Feature[_]]*): JSet[PersonalDataType] = { - val unionPersonalDataTypes = new JHashSet[PersonalDataType]() - for { - featureOpt <- features - feature <- featureOpt - pdtSetOptional = feature.getPersonalDataTypes - if pdtSetOptional.isPresent - pdtSet = pdtSetOptional.get - } unionPersonalDataTypes.addAll(pdtSet) - if (unionPersonalDataTypes.isEmpty) null else unionPersonalDataTypes - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/BUILD b/timelines/data_processing/ml_util/aggregation_framework/metrics/BUILD deleted file mode 100644 index 676b31d81..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -scala_library( - sources = ["*.scala"], - platform = "java8", - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/com/twitter/algebird:core", - "src/java/com/twitter/ml/api:api-base", - "src/java/com/twitter/ml/api/constant", - "src/scala/com/twitter/ml/api/util:datarecord", - "src/thrift/com/twitter/dal/personal_data:personal_data-java", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:interpretable-model-java", - "util/util-core:scala", - ], -) diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/BUILD.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/BUILD.docx new file mode 100644 index 000000000..13520838b Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/BUILD.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/ConversionUtils.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/ConversionUtils.docx new file mode 100644 index 000000000..3f55fc67d Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/ConversionUtils.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/ConversionUtils.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/ConversionUtils.scala deleted file mode 100644 index b04263ea0..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/ConversionUtils.scala +++ /dev/null @@ -1,5 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -object ConversionUtils { - def booleanToDouble(value: Boolean): Double = if (value) 1.0 else 0.0 -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/CountMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/CountMetric.docx new file mode 100644 index 000000000..f1c03e413 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/CountMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/CountMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/CountMetric.scala deleted file mode 100644 index 720fa68e5..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/CountMetric.scala +++ /dev/null @@ -1,41 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.util.Time -import java.lang.{Long => JLong} - -case class TypedCountMetric[T]( -) extends TypedSumLikeMetric[T] { - import 
AggregationMetricCommon._ - import ConversionUtils._ - override val operatorName = "count" - - override def getIncrementValue( - record: DataRecord, - feature: Option[Feature[T]], - timestampFeature: Feature[JLong] - ): TimedValue[Double] = { - val featureExists: Boolean = feature match { - case Some(f) => SRichDataRecord(record).hasFeature(f) - case None => true - } - - TimedValue[Double]( - value = booleanToDouble(featureExists), - timestamp = Time.fromMilliseconds(getTimestamp(record, timestampFeature)) - ) - } -} - -/** - * Syntactic sugar for the count metric that works with - * any feature type as opposed to being tied to a specific type. - * See EasyMetric.scala for more details on why this is useful. - */ -object CountMetric extends EasyMetric { - override def forFeatureType[T]( - featureType: FeatureType, - ): Option[AggregationMetric[T, _]] = - Some(TypedCountMetric[T]()) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/EasyMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/EasyMetric.docx new file mode 100644 index 000000000..bd38ff7ba Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/EasyMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/EasyMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/EasyMetric.scala deleted file mode 100644 index 67edce7ce..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/EasyMetric.scala +++ /dev/null @@ -1,34 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ - -/** - * A "human-readable" metric that can be applied to features of multiple - * different types. Wrapper around AggregationMetric used as syntactic sugar - * for easier config. - */ -trait EasyMetric extends Serializable { - /* - * Given a feature type, fetches the corrrect underlying AggregationMetric - * to perform this operation over the given feature type, if any. If no such - * metric is available, returns None. For example, MEAN cannot be applied - * to FeatureType.String and would return None. - * - * @param featureType Type of feature to fetch metric for - * @param useFixedDecay Param to control whether the metric should use fixed decay - * logic (if appropriate) - * @return Strongly typed aggregation metric to use for this feature type - * - * For example, if the EasyMetric is MEAN and the featureType is - * FeatureType.Continuous, the underlying AggregationMetric should be a - * scalar mean. If the EasyMetric is MEAN and the featureType is - * FeatureType.SparseContinuous, the AggregationMetric returned could be a - * "vector" mean that averages sparse maps. Using the single logical name - * MEAN for both is nice syntactic sugar making for an easier to read top - * level config, though different underlying operators are used underneath - * for the actual implementation. 
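As a usage sketch of this dispatch (assuming the metrics defined in this package, such as CountMetric above and SumMetric further below), resolution happens once at config time and unsupported combinations come back as None:

import java.lang.{Double => JDouble}
import com.twitter.ml.api.FeatureType
import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.{CountMetric, SumMetric}

object EasyMetricResolutionSketch {
  // Count applies to any feature type, so a typed metric is always returned.
  val countForStrings = CountMetric.forFeatureType[String](FeatureType.STRING)
  // Sum only makes sense for continuous features; other types resolve to None.
  val sumForContinuous = SumMetric.forFeatureType[JDouble](FeatureType.CONTINUOUS)
  val sumForStrings = SumMetric.forFeatureType[String](FeatureType.STRING) // None
}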
- */ - def forFeatureType[T]( - featureType: FeatureType, - ): Option[AggregationMetric[T, _]] -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/FeatureCache.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/FeatureCache.docx new file mode 100644 index 000000000..cf9733f60 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/FeatureCache.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/FeatureCache.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/FeatureCache.scala deleted file mode 100644 index e5f384100..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/FeatureCache.scala +++ /dev/null @@ -1,72 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import scala.collection.mutable - -trait FeatureCache[T] { - /* - * Constructs feature names from scratch given an aggregate query and an output - * feature name. E.g. given mean operator and "sum". This function is slow and should - * only be called at pre-computation time. - * - * @param query Details of aggregate feature - * @name Name of "output" feature for which we want to construct feature name - * @return Full name of output feature - */ - private def uncachedFullFeatureName(query: AggregateFeature[T], name: String): String = - List(query.featurePrefix, name).mkString(".") - - /* - * A cache from (aggregate query, output feature name) -> fully qualified feature name - * lazy since it doesn't need to be serialized to the mappers - */ - private lazy val featureNameCache = mutable.Map[(AggregateFeature[T], String), String]() - - /* - * A cache from (aggregate query, output feature name) -> precomputed output feature - * lazy since it doesn't need to be serialized to the mappers - */ - private lazy val featureCache = mutable.Map[(AggregateFeature[T], String), Feature[_]]() - - /** - * Given an (aggregate query, output feature name, output feature type), - * look it up using featureNameCache and featureCache, falling back to uncachedFullFeatureName() - * as a last resort to construct a precomputed output feature. Should only be - * called at pre-computation time. 
- * - * @param query Details of aggregate feature - * @name Name of "output" feature we want to precompute - * @aggregateFeatureType type of "output" feature we want to precompute - */ - def cachedFullFeature( - query: AggregateFeature[T], - name: String, - aggregateFeatureType: FeatureType - ): Feature[_] = { - lazy val cachedFeatureName = featureNameCache.getOrElseUpdate( - (query, name), - uncachedFullFeatureName(query, name) - ) - - def uncachedFullFeature(): Feature[_] = { - val personalDataTypes = - AggregationMetricCommon.derivePersonalDataTypes(query.feature, query.label) - - aggregateFeatureType match { - case FeatureType.BINARY => new Feature.Binary(cachedFeatureName, personalDataTypes) - case FeatureType.DISCRETE => new Feature.Discrete(cachedFeatureName, personalDataTypes) - case FeatureType.STRING => new Feature.Text(cachedFeatureName, personalDataTypes) - case FeatureType.CONTINUOUS => new Feature.Continuous(cachedFeatureName, personalDataTypes) - case FeatureType.SPARSE_BINARY => - new Feature.SparseBinary(cachedFeatureName, personalDataTypes) - case FeatureType.SPARSE_CONTINUOUS => - new Feature.SparseContinuous(cachedFeatureName, personalDataTypes) - } - } - - featureCache.getOrElseUpdate( - (query, name), - uncachedFullFeature() - ) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/LastResetMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/LastResetMetric.docx new file mode 100644 index 000000000..a782bec83 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/LastResetMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/LastResetMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/LastResetMetric.scala deleted file mode 100644 index 67fe444aa..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/LastResetMetric.scala +++ /dev/null @@ -1,107 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import java.lang.{Long => JLong} -import com.twitter.ml.api._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.ConversionUtils._ -import com.twitter.util.Duration -import com.twitter.util.Time -import scala.math.max - -/** - * This metric measures how recently an action has taken place. A value of 1.0 - * indicates the action happened just now. This value decays with time if the - * action has not taken place and is reset to 1 when the action happens. So lower - * value indicates a stale or older action. - * - * For example consider an action of "user liking a video". The last reset metric - * value changes as follows for a half life of 1 day. 
- * - * ---------------------------------------------------------------------------- - * day | action | feature value | Description - * ---------------------------------------------------------------------------- - * 1 | user likes the video | 1.0 | Set the value to 1 - * 2 | user does not like video | 0.5 | Decay the value - * 3 | user does not like video | 0.25 | Decay the value - * 4 | user likes the video | 1.0 | Reset the value to 1 - * ----------------------------------------------------------------------------- - * - * @tparam T - */ -case class TypedLastResetMetric[T]() extends TimedValueAggregationMetric[T] { - import AggregationMetricCommon._ - - override val operatorName = "last_reset" - - override def getIncrementValue( - record: DataRecord, - feature: Option[Feature[T]], - timestampFeature: Feature[JLong] - ): TimedValue[Double] = { - val featureExists: Boolean = feature match { - case Some(f) => SRichDataRecord(record).hasFeature(f) - case None => true - } - - TimedValue[Double]( - value = booleanToDouble(featureExists), - timestamp = Time.fromMilliseconds(getTimestamp(record, timestampFeature)) - ) - } - private def getDecayedValue( - olderTimedValue: TimedValue[Double], - newerTimestamp: Time, - halfLife: Duration - ): Double = { - if (halfLife.inMilliseconds == 0L) { - 0.0 - } else { - val timeDelta = newerTimestamp.inMilliseconds - olderTimedValue.timestamp.inMilliseconds - val resultValue = olderTimedValue.value / math.pow(2.0, timeDelta / halfLife.inMillis) - if (resultValue > AggregationMetricCommon.Epsilon) resultValue else 0.0 - } - } - - override def plus( - left: TimedValue[Double], - right: TimedValue[Double], - halfLife: Duration - ): TimedValue[Double] = { - - val (newerTimedValue, olderTimedValue) = if (left.timestamp > right.timestamp) { - (left, right) - } else { - (right, left) - } - - val optionallyDecayedOlderValue = if (halfLife == Duration.Top) { - // Since we don't want to decay, older value is not changed - olderTimedValue.value - } else { - // Decay older value - getDecayedValue(olderTimedValue, newerTimedValue.timestamp, halfLife) - } - - TimedValue[Double]( - value = max(newerTimedValue.value, optionallyDecayedOlderValue), - timestamp = newerTimedValue.timestamp - ) - } - - override def zero(timeOpt: Option[Long]): TimedValue[Double] = TimedValue[Double]( - value = 0.0, - timestamp = Time.fromMilliseconds(0) - ) -} - -/** - * Syntactic sugar for the last reset metric that works with - * any feature type as opposed to being tied to a specific type. - * See EasyMetric.scala for more details on why this is useful. 
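The decay applied in getDecayedValue above is plain exponential half-life arithmetic. A self-contained sketch with a hypothetical helper, using millisecond Longs, that reproduces the table's values for a one-day half-life:

object HalfLifeDecaySketch {
  // A value observed at observedAtMs is worth
  // value * 2^(-(nowMs - observedAtMs) / halfLifeMs) at time nowMs.
  def decay(value: Double, observedAtMs: Long, nowMs: Long, halfLifeMs: Long): Double =
    value / math.pow(2.0, (nowMs - observedAtMs).toDouble / halfLifeMs)

  // With a one-day half-life (86,400,000 ms) this reproduces the table above:
  //   decay(1.0, 0L, 86400000L, 86400000L)     == 0.5   (day 2)
  //   decay(1.0, 0L, 2 * 86400000L, 86400000L) == 0.25  (day 3)
}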
- */ -object LastResetMetric extends EasyMetric { - override def forFeatureType[T]( - featureType: FeatureType - ): Option[AggregationMetric[T, _]] = - Some(TypedLastResetMetric[T]()) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/LatestMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/LatestMetric.docx new file mode 100644 index 000000000..9f65cdcd0 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/LatestMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/LatestMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/LatestMetric.scala deleted file mode 100644 index 08bd6483a..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/LatestMetric.scala +++ /dev/null @@ -1,69 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import com.twitter.ml.api.FeatureType -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon.getTimestamp -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetric -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.EasyMetric -import com.twitter.util.Duration -import com.twitter.util.Time -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} -import java.lang.{Number => JNumber} - -case class TypedLatestMetric[T <: JNumber](defaultValue: Double = 0.0) - extends TimedValueAggregationMetric[T] { - override val operatorName = "latest" - - override def plus( - left: TimedValue[Double], - right: TimedValue[Double], - halfLife: Duration - ): TimedValue[Double] = { - assert( - halfLife.toString == "Duration.Top", - s"halfLife must be Duration.Top when using latest metric, but ${halfLife.toString} is used" - ) - - if (left.timestamp > right.timestamp) { - left - } else { - right - } - } - - override def getIncrementValue( - dataRecord: DataRecord, - feature: Option[Feature[T]], - timestampFeature: Feature[JLong] - ): TimedValue[Double] = { - val value = feature - .flatMap(SRichDataRecord(dataRecord).getFeatureValueOpt(_)) - .map(_.doubleValue()).getOrElse(defaultValue) - val timestamp = Time.fromMilliseconds(getTimestamp(dataRecord, timestampFeature)) - TimedValue[Double](value = value, timestamp = timestamp) - } - - override def zero(timeOpt: Option[Long]): TimedValue[Double] = - TimedValue[Double]( - value = 0.0, - timestamp = Time.fromMilliseconds(0) - ) -} - -object LatestMetric extends EasyMetric { - override def forFeatureType[T]( - featureType: FeatureType - ): Option[AggregationMetric[T, _]] = { - featureType match { - case FeatureType.CONTINUOUS => - Some(TypedLatestMetric[JDouble]().asInstanceOf[AggregationMetric[T, Double]]) - case FeatureType.DISCRETE => - Some(TypedLatestMetric[JLong]().asInstanceOf[AggregationMetric[T, Double]]) - case _ => None - } - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/MaxMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/MaxMetric.docx new file mode 100644 index 000000000..79d9ab113 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/MaxMetric.docx differ diff --git 
a/timelines/data_processing/ml_util/aggregation_framework/metrics/MaxMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/MaxMetric.scala deleted file mode 100644 index b9e9176bb..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/MaxMetric.scala +++ /dev/null @@ -1,64 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon.getTimestamp -import com.twitter.util.Duration -import com.twitter.util.Time -import java.lang.{Long => JLong} -import java.lang.{Number => JNumber} -import java.lang.{Double => JDouble} -import scala.math.max - -case class TypedMaxMetric[T <: JNumber](defaultValue: Double = 0.0) - extends TimedValueAggregationMetric[T] { - override val operatorName = "max" - - override def getIncrementValue( - dataRecord: DataRecord, - feature: Option[Feature[T]], - timestampFeature: Feature[JLong] - ): TimedValue[Double] = { - val value = feature - .flatMap(SRichDataRecord(dataRecord).getFeatureValueOpt(_)) - .map(_.doubleValue()).getOrElse(defaultValue) - val timestamp = Time.fromMilliseconds(getTimestamp(dataRecord, timestampFeature)) - TimedValue[Double](value = value, timestamp = timestamp) - } - - override def plus( - left: TimedValue[Double], - right: TimedValue[Double], - halfLife: Duration - ): TimedValue[Double] = { - - assert( - halfLife.toString == "Duration.Top", - s"halfLife must be Duration.Top when using max metric, but ${halfLife.toString} is used" - ) - - TimedValue[Double]( - value = max(left.value, right.value), - timestamp = left.timestamp.max(right.timestamp) - ) - } - - override def zero(timeOpt: Option[Long]): TimedValue[Double] = - TimedValue[Double]( - value = 0.0, - timestamp = Time.fromMilliseconds(0) - ) -} - -object MaxMetric extends EasyMetric { - def forFeatureType[T]( - featureType: FeatureType, - ): Option[AggregationMetric[T, _]] = - featureType match { - case FeatureType.CONTINUOUS => - Some(TypedMaxMetric[JDouble]().asInstanceOf[AggregationMetric[T, Double]]) - case FeatureType.DISCRETE => - Some(TypedMaxMetric[JLong]().asInstanceOf[AggregationMetric[T, Double]]) - case _ => None - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumLikeMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumLikeMetric.docx new file mode 100644 index 000000000..7a20fbd84 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumLikeMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumLikeMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumLikeMetric.scala deleted file mode 100644 index 1f7aeb58a..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumLikeMetric.scala +++ /dev/null @@ -1,66 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import com.twitter.util.Duration -import com.twitter.util.Time -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} -import java.util.{Map => JMap} - -/* - * TypedSumLikeMetric aggregates a sum over any feature transform. - * TypedCountMetric, TypedSumMetric, TypedSumSqMetric are examples - * of metrics that are inherited from this trait. 
To implement a new - * "sum like" metric, override the getIncrementValue() and operatorName - * members of this trait. - * - * getIncrementValue() is inherited from the - * parent trait AggregationMetric, but not overriden in this trait, so - * it needs to be overloaded by any metric that extends TypedSumLikeMetric. - * - * operatorName is a string used for naming the resultant aggregate feature - * (e.g. "count" if its a count feature, or "sum" if a sum feature). - */ -trait TypedSumLikeMetric[T] extends TimedValueAggregationMetric[T] { - import AggregationMetricCommon._ - - def useFixedDecay = true - - override def plus( - left: TimedValue[Double], - right: TimedValue[Double], - halfLife: Duration - ): TimedValue[Double] = { - val resultValue = if (halfLife == Duration.Top) { - /* We could use decayedValueMonoid here, but - * a simple addition is slightly more accurate */ - left.value + right.value - } else { - val decayedLeft = toDecayedValue(left, halfLife) - val decayedRight = toDecayedValue(right, halfLife) - decayedValueMonoid.plus(decayedLeft, decayedRight).value - } - - TimedValue[Double]( - resultValue, - left.timestamp.max(right.timestamp) - ) - } - - override def zero(timeOpt: Option[Long]): TimedValue[Double] = { - val timestamp = - /* - * Please see TQ-11279 for documentation for this fix to the decay logic. - */ - if (useFixedDecay) { - Time.fromMilliseconds(timeOpt.getOrElse(0L)) - } else { - Time.fromMilliseconds(0L) - } - - TimedValue[Double]( - value = 0.0, - timestamp = timestamp - ) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumMetric.docx new file mode 100644 index 000000000..5105a884d Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumMetric.scala deleted file mode 100644 index bd93d5bae..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumMetric.scala +++ /dev/null @@ -1,52 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.util.Time -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} - -case class TypedSumMetric( -) extends TypedSumLikeMetric[JDouble] { - import AggregationMetricCommon._ - - override val operatorName = "sum" - - /* - * Transform feature -> its value in the given record, - * or 0 when feature = None (sum has no meaning in this case) - */ - override def getIncrementValue( - record: DataRecord, - feature: Option[Feature[JDouble]], - timestampFeature: Feature[JLong] - ): TimedValue[Double] = feature match { - case Some(f) => { - TimedValue[Double]( - value = Option(SRichDataRecord(record).getFeatureValue(f)).map(_.toDouble).getOrElse(0.0), - timestamp = Time.fromMilliseconds(getTimestamp(record, timestampFeature)) - ) - } - - case None => - TimedValue[Double]( - value = 0.0, - timestamp = Time.fromMilliseconds(getTimestamp(record, timestampFeature)) - ) - } -} - -/** - * Syntactic sugar for the sum metric that works with continuous features. - * See EasyMetric.scala for more details on why this is useful. 
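To make the decayed addition in TypedSumLikeMetric.plus above concrete, here is a small usage sketch with assumed inputs: two unit counts one day apart under a one-day half-life, where the older contribution halves by the newer timestamp:

import com.twitter.algebird.{DecayedValue, DecayedValueMonoid}

object DecayedSumSketch {
  val decayedValueMonoid = DecayedValueMonoid(1e-6)
  val halfLifeMs = 86400000.0 // one day

  val older = DecayedValue.build(1.0, 0.0, halfLifeMs)        // a count of 1.0 at t = 0
  val newer = DecayedValue.build(1.0, halfLifeMs, halfLifeMs) // a count of 1.0 one day later

  // The monoid decays the older contribution to the newer timestamp before adding,
  // so the decayed sum comes out to approximately 0.5 + 1.0 = 1.5.
  val decayedSum: Double = decayedValueMonoid.plus(older, newer).value
}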
- */ -object SumMetric extends EasyMetric { - override def forFeatureType[T]( - featureType: FeatureType - ): Option[AggregationMetric[T, _]] = - featureType match { - case FeatureType.CONTINUOUS => - Some(TypedSumMetric().asInstanceOf[AggregationMetric[T, Double]]) - case _ => None - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumSqMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumSqMetric.docx new file mode 100644 index 000000000..7113ccdb8 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumSqMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumSqMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/SumSqMetric.scala deleted file mode 100644 index b24b16377..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/SumSqMetric.scala +++ /dev/null @@ -1,53 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.util.Time -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} - -case class TypedSumSqMetric() extends TypedSumLikeMetric[JDouble] { - import AggregationMetricCommon._ - - override val operatorName = "sumsq" - - /* - * Transform feature -> its squared value in the given record - * or 0 when feature = None (sumsq has no meaning in this case) - */ - override def getIncrementValue( - record: DataRecord, - feature: Option[Feature[JDouble]], - timestampFeature: Feature[JLong] - ): TimedValue[Double] = feature match { - case Some(f) => { - val featureVal = - Option(SRichDataRecord(record).getFeatureValue(f)).map(_.toDouble).getOrElse(0.0) - TimedValue[Double]( - value = featureVal * featureVal, - timestamp = Time.fromMilliseconds(getTimestamp(record, timestampFeature)) - ) - } - - case None => - TimedValue[Double]( - value = 0.0, - timestamp = Time.fromMilliseconds(getTimestamp(record, timestampFeature)) - ) - } -} - -/** - * Syntactic sugar for the sum of squares metric that works with continuous features. - * See EasyMetric.scala for more details on why this is useful. - */ -object SumSqMetric extends EasyMetric { - override def forFeatureType[T]( - featureType: FeatureType - ): Option[AggregationMetric[T, _]] = - featureType match { - case FeatureType.CONTINUOUS => - Some(TypedSumSqMetric().asInstanceOf[AggregationMetric[T, Double]]) - case _ => None - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValue.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValue.docx new file mode 100644 index 000000000..2f9a4fc70 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValue.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValue.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValue.scala deleted file mode 100644 index 7f9fb5090..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValue.scala +++ /dev/null @@ -1,14 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.util.Time - -/** - * Case class wrapping a (value, timestamp) tuple. - * All aggregate metrics must operate over this class - * to ensure we can implement decay and half lives for them. 
- * This is translated to an algebird DecayedValue under the hood. - * - * @param value Value being wrapped - * @param timestamp Time after epoch at which value is being measured - */ -case class TimedValue[T](value: T, timestamp: Time) diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValueAggregationMetric.docx b/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValueAggregationMetric.docx new file mode 100644 index 000000000..8acf5fb8e Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValueAggregationMetric.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValueAggregationMetric.scala b/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValueAggregationMetric.scala deleted file mode 100644 index f31152a23..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/metrics/TimedValueAggregationMetric.scala +++ /dev/null @@ -1,90 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics - -import com.twitter.ml.api._ -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregateFeature -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.TimedValue -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetric -import com.twitter.util.Duration -import com.twitter.util.Time -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} -import java.util.{Map => JMap} - -/* - * TimedValueAggregationMetric overrides the AggregationMetric methods dealing - * with reading and writing continuous values from a data record. - * - * operatorName is a string used for naming the resultant aggregate feature - * (e.g. "count" if it's a count feature, or "sum" if a sum feature). - */ -trait TimedValueAggregationMetric[T] extends AggregationMetric[T, Double] { - import AggregationMetricCommon._ - - val operatorName: String - - override def getAggregateValue( - record: DataRecord, - query: AggregateFeature[T], - aggregateOutputs: Option[List[JLong]] = None - ): TimedValue[Double] = { - /* - * We know aggregateOutputs(0) will have the continuous feature, - * since we put it there in getOutputFeatureIds() - see code below. - * This helps us get a 4x speedup. Using any structure more complex - * than a list was also a performance bottleneck. 
- */ - val featureHash: JLong = aggregateOutputs - .getOrElse(getOutputFeatureIds(query)) - .head - - val continuousValueOption: Option[Double] = Option(record.continuousFeatures) - .flatMap { case jmap: JMap[JLong, JDouble] => Option(jmap.get(featureHash)) } - .map(_.toDouble) - - val timeOption = Option(record.discreteFeatures) - .flatMap { case jmap: JMap[JLong, JLong] => Option(jmap.get(TimestampHash)) } - .map(_.toLong) - - val resultOption: Option[TimedValue[Double]] = (continuousValueOption, timeOption) match { - case (Some(featureValue), Some(timesamp)) => - Some(TimedValue[Double](featureValue, Time.fromMilliseconds(timesamp))) - case _ => None - } - - resultOption.getOrElse(zero(timeOption)) - } - - override def setAggregateValue( - record: DataRecord, - query: AggregateFeature[T], - aggregateOutputs: Option[List[JLong]] = None, - value: TimedValue[Double] - ): Unit = { - /* - * We know aggregateOutputs(0) will have the continuous feature, - * since we put it there in getOutputFeatureIds() - see code below. - * This helps us get a 4x speedup. Using any structure more complex - * than a list was also a performance bottleneck. - */ - val featureHash: JLong = aggregateOutputs - .getOrElse(getOutputFeatureIds(query)) - .head - - /* Only set value if non-zero to save space */ - if (value.value != 0.0) { - record.putToContinuousFeatures(featureHash, value.value) - } - - /* - * We do not set timestamp since that might affect correctness of - * future aggregations due to the decay semantics. - */ - } - - /* Only one feature stored in the aggregated datarecord: the result continuous value */ - override def getOutputFeatures(query: AggregateFeature[T]): List[Feature[_]] = { - val feature = cachedFullFeature(query, operatorName, FeatureType.CONTINUOUS) - List(feature) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/package.docx b/timelines/data_processing/ml_util/aggregation_framework/package.docx new file mode 100644 index 000000000..81ffccc37 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/package.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/package.scala b/timelines/data_processing/ml_util/aggregation_framework/package.scala deleted file mode 100644 index 824398a7f..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/package.scala +++ /dev/null @@ -1,19 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util - -import com.twitter.ml.api.DataRecord - -package object aggregation_framework { - object AggregateType extends Enumeration { - type AggregateType = Value - val User, UserAuthor, UserEngager, UserMention, UserRequestHour, UserRequestDow, - UserOriginalAuthor, UserList, UserTopic, UserInferredTopic, UserMediaUnderstandingAnnotation = - Value - } - - type AggregateUserEntityKey = (Long, AggregateType.Value, Option[Long]) - - case class MergedRecordsDescriptor( - userId: Long, - keyedRecords: Map[AggregateType.Value, Option[KeyedRecord]], - keyedRecordMaps: Map[AggregateType.Value, Option[KeyedRecordMap]]) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/query/BUILD b/timelines/data_processing/ml_util/aggregation_framework/query/BUILD deleted file mode 100644 index 97e6d1ea7..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/query/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -scala_library( - sources = ["*.scala"], - platform = "java8", - tags = ["bazel-compatible"], - dependencies = [ - "finagle/finagle-stats", - 
"src/java/com/twitter/ml/api:api-base", - "src/thrift/com/twitter/ml/api:data-scala", - "src/thrift/com/twitter/ml/api:interpretable-model-java", - "timelines/data_processing/ml_util/aggregation_framework/metrics", - ], -) diff --git a/timelines/data_processing/ml_util/aggregation_framework/query/BUILD.docx b/timelines/data_processing/ml_util/aggregation_framework/query/BUILD.docx new file mode 100644 index 000000000..1e8721128 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/query/BUILD.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/query/ScopedAggregateBuilder.docx b/timelines/data_processing/ml_util/aggregation_framework/query/ScopedAggregateBuilder.docx new file mode 100644 index 000000000..dcf27a784 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/query/ScopedAggregateBuilder.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/query/ScopedAggregateBuilder.scala b/timelines/data_processing/ml_util/aggregation_framework/query/ScopedAggregateBuilder.scala deleted file mode 100644 index 2fcce3312..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/query/ScopedAggregateBuilder.scala +++ /dev/null @@ -1,159 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.query - -import com.twitter.dal.personal_data.thriftjava.PersonalDataType -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import com.twitter.ml.api.FeatureBuilder -import com.twitter.ml.api.FeatureContext -import com.twitter.ml.api.thriftscala.{DataRecord => ScalaDataRecord} -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.metrics.AggregationMetricCommon -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} -import scala.collection.JavaConverters._ - -/** - * Provides methods to build "scoped" aggregates, where base features generated by aggregates - * V2 are scoped with a specific key. - * - * The class provides methods that take a Map of T -> DataRecord, where T is a key type, and - * the DataRecord contains features produced by the aggregation_framework. The methods then - * generate a _new_ DataRecord, containing "scoped" aggregate features, where each scoped - * feature has the value of the scope key in the feature name, and the value of the feature - * is the value of the original aggregate feature in the corresponding value from the original - * Map. - * - * For efficiency reasons, the builder is initialized with the set of features that should be - * scoped and the set of keys for which scoping should be supported. 
- * - * To understand how scope feature names are constructed, consider the following: - * - * {{{ - * val features = Set( - * new Feature.Continuous("user_injection_aggregate.pair.any_label.any_feature.5.days.count"), - * new Feature.Continuous("user_injection_aggregate.pair.any_label.any_feature.10.days.count") - * ) - * val scopes = Set(SuggestType.Recap, SuggestType.WhoToFollow) - * val scopeName = "InjectionType" - * val scopedAggregateBuilder = ScopedAggregateBuilder(features, scopes, scopeName) - * - * }}} - * - * Then, generated scoped features would be among the following: - * - user_injection_aggregate.scoped.pair.any_label.any_feature.5.days.count/scope_name=InjectionType/scope=Recap - * - user_injection_aggregate.scoped.pair.any_label.any_feature.5.days.count/scope_name=InjectionType/scope=WhoToFollow - * - user_injection_aggregate.scoped.pair.any_label.any_feature.10.days.count/scope_name=InjectionType/scope=Recap - * - user_injection_aggregate.scoped.pair.any_label.any_feature.10.days.count/scope_name=InjectionType/scope=WhoToFollow - * - * @param featuresToScope the set of features for which one should generate scoped versions - * @param scopeKeys the set of scope keys to generate scopes with - * @param scopeName a string indicating what the scopes represent. This is also added to the scoped feature - * @tparam K the type of scope key - */ -class ScopedAggregateBuilder[K]( - featuresToScope: Set[Feature[JDouble]], - scopeKeys: Set[K], - scopeName: String) { - - private[this] def buildScopedAggregateFeature( - baseName: String, - scopeValue: String, - personalDataTypes: java.util.Set[PersonalDataType] - ): Feature[JDouble] = { - val components = baseName.split("\\.").toList - - val newName = (components.head :: "scoped" :: components.tail).mkString(".") - - new FeatureBuilder.Continuous() - .addExtensionDimensions("scope_name", "scope") - .setBaseName(newName) - .setPersonalDataTypes(personalDataTypes) - .extensionBuilder() - .addExtension("scope_name", scopeName) - .addExtension("scope", scopeValue) - .build() - } - - /** - * Index of (base aggregate feature name, key) -> key scoped count feature. - */ - private[this] val keyScopedAggregateMap: Map[(String, K), Feature[JDouble]] = { - featuresToScope.flatMap { feat => - scopeKeys.map { key => - (feat.getFeatureName, key) -> - buildScopedAggregateFeature( - feat.getFeatureName, - key.toString, - AggregationMetricCommon.derivePersonalDataTypes(Some(feat)) - ) - } - }.toMap - } - - type ContinuousFeaturesMap = Map[JLong, JDouble] - - /** - * Create key-scoped features for raw aggregate feature ID to value maps, partitioned by key. - */ - private[this] def buildAggregates(featureMapsByKey: Map[K, ContinuousFeaturesMap]): DataRecord = { - val continuousFeatures = featureMapsByKey - .flatMap { - case (key, featureMap) => - featuresToScope.flatMap { feature => - val newFeatureOpt = keyScopedAggregateMap.get((feature.getFeatureName, key)) - newFeatureOpt.flatMap { newFeature => - featureMap.get(feature.getFeatureId).map(new JLong(newFeature.getFeatureId) -> _) - } - }.toMap - } - - new DataRecord().setContinuousFeatures(continuousFeatures.asJava) - } - - /** - * Create key-scoped features for Java [[DataRecord]] aggregate records partitioned by key. 
- * - * As an example, if the provided Map includes the key `SuggestType.Recap`, and [[scopeKeys]] - * includes this key, then for a feature "xyz.pair.any_label.any_feature.5.days.count", the method - * will generate the scoped feature "xyz.scoped.pair.any_label.any_feature.5.days.count/scope_name=InjectionType/scope=Recap", - * with the value being the value of the original feature from the Map. - * - * @param aggregatesByKey a map from key to a continuous feature map (ie. feature ID -> Double) - * @return a Java [[DataRecord]] containing key-scoped features - */ - def buildAggregatesJava(aggregatesByKey: Map[K, DataRecord]): DataRecord = { - val featureMapsByKey = aggregatesByKey.mapValues(_.continuousFeatures.asScala.toMap) - buildAggregates(featureMapsByKey) - } - - /** - * Create key-scoped features for Scala [[DataRecord]] aggregate records partitioned by key. - * - * As an example, if the provided Map includes the key `SuggestType.Recap`, and [[scopeKeys]] - * includes this key, then for a feature "xyz.pair.any_label.any_feature.5.days.count", the method - * will generate the scoped feature "xyz.scoped.pair.any_label.any_feature.5.days.count/scope_name=InjectionType/scope=Recap", - * with the value being the value of the original feature from the Map. - * - * This is a convenience method for some use cases where aggregates are read from Scala - * thrift objects. Note that this still returns a Java [[DataRecord]], since most ML API - * use the Java version. - * - * @param aggregatesByKey a map from key to a continuous feature map (ie. feature ID -> Double) - * @return a Java [[DataRecord]] containing key-scoped features - */ - def buildAggregatesScala(aggregatesByKey: Map[K, ScalaDataRecord]): DataRecord = { - val featureMapsByKey = - aggregatesByKey - .mapValues { record => - val featureMap = record.continuousFeatures.getOrElse(Map[Long, Double]()).toMap - featureMap.map { case (k, v) => new JLong(k) -> new JDouble(v) } - } - buildAggregates(featureMapsByKey) - } - - /** - * Returns a [[FeatureContext]] including all possible scoped features generated using this builder. - * - * @return a [[FeatureContext]] containing all scoped features. 
- */ - def scopedFeatureContext: FeatureContext = new FeatureContext(keyScopedAggregateMap.values.asJava) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregateFeaturesMerger.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregateFeaturesMerger.docx new file mode 100644 index 000000000..c35066d45 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregateFeaturesMerger.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregateFeaturesMerger.scala b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregateFeaturesMerger.scala deleted file mode 100644 index 156168a9d..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregateFeaturesMerger.scala +++ /dev/null @@ -1,213 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.scalding - -import com.twitter.ml.api._ -import com.twitter.ml.api.constant.SharedFeatures._ -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.scalding.Stat -import com.twitter.scalding.typed.TypedPipe -import com.twitter.timelines.data_processing.ml_util.aggregation_framework._ -import com.twitter.timelines.data_processing.ml_util.sampling.SamplingUtils - -trait AggregateFeaturesMergerBase { - import Utils._ - - def samplingRateOpt: Option[Double] - def numReducers: Int = 2000 - def numReducersMerge: Int = 20000 - - def aggregationConfig: AggregationConfig - def storeRegister: StoreRegister - def storeMerger: StoreMerger - - def getAggregatePipe(storeName: String): DataSetPipe - def applyMaxSizeByTypeOpt(aggregateType: AggregateType.Value): Option[Int] = Option.empty[Int] - - def usersActiveSourcePipe: TypedPipe[Long] - def numRecords: Stat - def numFilteredRecords: Stat - - /* - * This method should only be called with a storeName that corresponds - * to a user aggregate store. - */ - def extractUserFeaturesMap(storeName: String): TypedPipe[(Long, KeyedRecord)] = { - val aggregateKey = storeRegister.storeNameToTypeMap(storeName) - samplingRateOpt - .map(rate => SamplingUtils.userBasedSample(getAggregatePipe(storeName), rate)) - .getOrElse(getAggregatePipe(storeName)) // must return store with only user aggregates - .records - .map { r: DataRecord => - val record = SRichDataRecord(r) - val userId = record.getFeatureValue(USER_ID).longValue - record.clearFeature(USER_ID) - (userId, KeyedRecord(aggregateKey, r)) - } - } - - /* - * When the secondaryKey being used is a String, then the shouldHash function should be set to true. - * Refactor such that the shouldHash parameter is removed and the behavior - * is defaulted to true. - * - * This method should only be called with a storeName that contains records with the - * desired secondaryKey. We provide secondaryKeyFilterPipeOpt against which secondary - * keys can be filtered to help prune the final merged MH dataset. 
- */ - def extractSecondaryTuples[T]( - storeName: String, - secondaryKey: Feature[T], - shouldHash: Boolean = false, - maxSizeOpt: Option[Int] = None, - secondaryKeyFilterPipeOpt: Option[TypedPipe[Long]] = None - ): TypedPipe[(Long, KeyedRecordMap)] = { - val aggregateKey = storeRegister.storeNameToTypeMap(storeName) - - val extractedRecordsBySecondaryKey = - samplingRateOpt - .map(rate => SamplingUtils.userBasedSample(getAggregatePipe(storeName), rate)) - .getOrElse(getAggregatePipe(storeName)) - .records - .map { r: DataRecord => - val record = SRichDataRecord(r) - val userId = keyFromLong(r, USER_ID) - val secondaryId = extractSecondary(r, secondaryKey, shouldHash) - record.clearFeature(USER_ID) - record.clearFeature(secondaryKey) - - numRecords.inc() - (userId, secondaryId -> r) - } - - val grouped = - (secondaryKeyFilterPipeOpt match { - case Some(secondaryKeyFilterPipe: TypedPipe[Long]) => - extractedRecordsBySecondaryKey - .map { - // In this step, we swap `userId` with `secondaryId` to join on the `secondaryId` - // It is important to swap them back after the join, otherwise the job will fail. - case (userId, (secondaryId, r)) => - (secondaryId, (userId, r)) - } - .join(secondaryKeyFilterPipe.groupBy(identity)) - .map { - case (secondaryId, ((userId, r), _)) => - numFilteredRecords.inc() - (userId, secondaryId -> r) - } - case _ => extractedRecordsBySecondaryKey - }).group - .withReducers(numReducers) - - maxSizeOpt match { - case Some(maxSize) => - grouped - .take(maxSize) - .mapValueStream(recordsIter => Iterator(KeyedRecordMap(aggregateKey, recordsIter.toMap))) - .toTypedPipe - case None => - grouped - .mapValueStream(recordsIter => Iterator(KeyedRecordMap(aggregateKey, recordsIter.toMap))) - .toTypedPipe - } - } - - def userPipes: Seq[TypedPipe[(Long, KeyedRecord)]] = - storeRegister.allStores.flatMap { storeConfig => - val StoreConfig(storeNames, aggregateType, _) = storeConfig - require(storeMerger.isValidToMerge(storeNames)) - - if (aggregateType == AggregateType.User) { - storeNames.map(extractUserFeaturesMap) - } else None - }.toSeq - - private def getSecondaryKeyFilterPipeOpt( - aggregateType: AggregateType.Value - ): Option[TypedPipe[Long]] = { - if (aggregateType == AggregateType.UserAuthor) { - Some(usersActiveSourcePipe) - } else None - } - - def userSecondaryKeyPipes: Seq[TypedPipe[(Long, KeyedRecordMap)]] = { - storeRegister.allStores.flatMap { storeConfig => - val StoreConfig(storeNames, aggregateType, shouldHash) = storeConfig - require(storeMerger.isValidToMerge(storeNames)) - - if (aggregateType != AggregateType.User) { - storeNames.flatMap { storeName => - storeConfig.secondaryKeyFeatureOpt - .map { secondaryFeature => - extractSecondaryTuples( - storeName, - secondaryFeature, - shouldHash, - applyMaxSizeByTypeOpt(aggregateType), - getSecondaryKeyFilterPipeOpt(aggregateType) - ) - } - } - } else None - }.toSeq - } - - def joinedAggregates: TypedPipe[(Long, MergedRecordsDescriptor)] = { - (userPipes ++ userSecondaryKeyPipes) - .reduce(_ ++ _) - .group - .withReducers(numReducersMerge) - .mapGroup { - case (uid, keyedRecordsAndMaps) => - /* - * For every user, partition their records by aggregate type. - * AggregateType.User should only contain KeyedRecord whereas - * other aggregate types (with secondary keys) contain KeyedRecordMap. 
- */ - val (userRecords, userSecondaryKeyRecords) = keyedRecordsAndMaps.toList - .map { record => - record match { - case record: KeyedRecord => (record.aggregateType, record) - case record: KeyedRecordMap => (record.aggregateType, record) - } - } - .groupBy(_._1) - .mapValues(_.map(_._2)) - .partition(_._1 == AggregateType.User) - - val userAggregateRecordMap: Map[AggregateType.Value, Option[KeyedRecord]] = - userRecords - .asInstanceOf[Map[AggregateType.Value, List[KeyedRecord]]] - .map { - case (aggregateType, keyedRecords) => - val mergedKeyedRecordOpt = mergeKeyedRecordOpts(keyedRecords.map(Some(_)): _*) - (aggregateType, mergedKeyedRecordOpt) - } - - val userSecondaryKeyAggregateRecordOpt: Map[AggregateType.Value, Option[KeyedRecordMap]] = - userSecondaryKeyRecords - .asInstanceOf[Map[AggregateType.Value, List[KeyedRecordMap]]] - .map { - case (aggregateType, keyedRecordMaps) => - val keyedRecordMapOpt = - keyedRecordMaps.foldLeft(Option.empty[KeyedRecordMap]) { - (mergedRecOpt, nextRec) => - applyMaxSizeByTypeOpt(aggregateType) - .map { maxSize => - mergeKeyedRecordMapOpts(mergedRecOpt, Some(nextRec), maxSize) - }.getOrElse { - mergeKeyedRecordMapOpts(mergedRecOpt, Some(nextRec)) - } - } - (aggregateType, keyedRecordMapOpt) - } - - Iterator( - MergedRecordsDescriptor( - userId = uid, - keyedRecords = userAggregateRecordMap, - keyedRecordMaps = userSecondaryKeyAggregateRecordOpt - ) - ) - }.toTypedPipe - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesStoreComparisonJob.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesStoreComparisonJob.docx new file mode 100644 index 000000000..0b5a1cb77 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesStoreComparisonJob.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesStoreComparisonJob.scala b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesStoreComparisonJob.scala deleted file mode 100644 index 054d5d428..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesStoreComparisonJob.scala +++ /dev/null @@ -1,200 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.scalding - -import com.twitter.algebird.ScMapMonoid -import com.twitter.bijection.Injection -import com.twitter.bijection.thrift.CompactThriftCodec -import com.twitter.ml.api.util.CompactDataRecordConverter -import com.twitter.ml.api.CompactDataRecord -import com.twitter.ml.api.DataRecord -import com.twitter.scalding.commons.source.VersionedKeyValSource -import com.twitter.scalding.Args -import com.twitter.scalding.Days -import com.twitter.scalding.Duration -import com.twitter.scalding.RichDate -import com.twitter.scalding.TypedPipe -import com.twitter.scalding.TypedTsv -import com.twitter.scalding_internal.job.HasDateRange -import com.twitter.scalding_internal.job.analytics_batch.AnalyticsBatchJob -import com.twitter.summingbird.batch.BatchID -import com.twitter.summingbird_internal.bijection.BatchPairImplicits -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKeyInjection -import java.lang.{Double => JDouble} -import java.lang.{Long => JLong} -import scala.collection.JavaConverters._ - -/** - * The job takes four inputs: - * - The path to a AggregateStore using the DataRecord format. 
- * - The path to an AggregateStore using the CompactDataRecord format. - * - A version that must be present in both sources. - * - A sink to write the comparison statistics. - * - * The job reads in the two stores, converts the second one to DataRecords and - * then compares each key to see if the two stores have identical DataRecords, - * modulo the loss in precision on converting the Double to Float. - */ -class AggregatesStoreComparisonJob(args: Args) - extends AnalyticsBatchJob(args) - with BatchPairImplicits - with HasDateRange { - - import AggregatesStoreComparisonJob._ - override def batchIncrement: Duration = Days(1) - override def firstTime: RichDate = RichDate(args("firstTime")) - - private val dataRecordSourcePath = args("dataRecordSource") - private val compactDataRecordSourcePath = args("compactDataRecordSource") - - private val version = args.long("version") - - private val statsSink = args("sink") - - require(dataRecordSourcePath != compactDataRecordSourcePath) - - private val dataRecordSource = - VersionedKeyValSource[AggregationKey, (BatchID, DataRecord)]( - path = dataRecordSourcePath, - sourceVersion = Some(version) - ) - private val compactDataRecordSource = - VersionedKeyValSource[AggregationKey, (BatchID, CompactDataRecord)]( - path = compactDataRecordSourcePath, - sourceVersion = Some(version) - ) - - private val dataRecordPipe: TypedPipe[((AggregationKey, BatchID), DataRecord)] = TypedPipe - .from(dataRecordSource) - .map { case (key, (batchId, record)) => ((key, batchId), record) } - - private val compactDataRecordPipe: TypedPipe[((AggregationKey, BatchID), DataRecord)] = TypedPipe - .from(compactDataRecordSource) - .map { - case (key, (batchId, compactRecord)) => - val record = compactConverter.compactDataRecordToDataRecord(compactRecord) - ((key, batchId), record) - } - - dataRecordPipe - .outerJoin(compactDataRecordPipe) - .mapValues { case (leftOpt, rightOpt) => compareDataRecords(leftOpt, rightOpt) } - .values - .sum(mapMonoid) - .flatMap(_.toList) - .write(TypedTsv(statsSink)) -} - -object AggregatesStoreComparisonJob { - - val mapMonoid: ScMapMonoid[String, Long] = new ScMapMonoid[String, Long]() - - implicit private val aggregationKeyInjection: Injection[AggregationKey, Array[Byte]] = - AggregationKeyInjection - implicit private val aggregationKeyOrdering: Ordering[AggregationKey] = AggregationKeyOrdering - implicit private val dataRecordCodec: Injection[DataRecord, Array[Byte]] = - CompactThriftCodec[DataRecord] - implicit private val compactDataRecordCodec: Injection[CompactDataRecord, Array[Byte]] = - CompactThriftCodec[CompactDataRecord] - - private val compactConverter = new CompactDataRecordConverter - - val missingRecordFromLeft = "missingRecordFromLeft" - val missingRecordFromRight = "missingRecordFromRight" - val nonContinuousFeaturesDidNotMatch = "nonContinuousFeaturesDidNotMatch" - val missingFeaturesFromLeft = "missingFeaturesFromLeft" - val missingFeaturesFromRight = "missingFeaturesFromRight" - val recordsWithUnmatchedKeys = "recordsWithUnmatchedKeys" - val featureValuesMatched = "featureValuesMatched" - val featureValuesThatDidNotMatch = "featureValuesThatDidNotMatch" - val equalRecords = "equalRecords" - val keyCount = "keyCount" - - def compareDataRecords( - leftOpt: Option[DataRecord], - rightOpt: Option[DataRecord] - ): collection.Map[String, Long] = { - val stats = collection.Map((keyCount, 1L)) - (leftOpt, rightOpt) match { - case (Some(left), Some(right)) => - if (isIdenticalNonContinuousFeatureSet(left, right)) { - 
getContinuousFeaturesStats(left, right).foldLeft(stats)(mapMonoid.add) - } else { - mapMonoid.add(stats, (nonContinuousFeaturesDidNotMatch, 1L)) - } - case (Some(_), None) => mapMonoid.add(stats, (missingRecordFromRight, 1L)) - case (None, Some(_)) => mapMonoid.add(stats, (missingRecordFromLeft, 1L)) - case (None, None) => throw new IllegalArgumentException("Should never be possible") - } - } - - /** - * For Continuous features. - */ - private def getContinuousFeaturesStats( - left: DataRecord, - right: DataRecord - ): Seq[(String, Long)] = { - val leftFeatures = Option(left.getContinuousFeatures) - .map(_.asScala.toMap) - .getOrElse(Map.empty[JLong, JDouble]) - - val rightFeatures = Option(right.getContinuousFeatures) - .map(_.asScala.toMap) - .getOrElse(Map.empty[JLong, JDouble]) - - val numMissingFeaturesLeft = (rightFeatures.keySet diff leftFeatures.keySet).size - val numMissingFeaturesRight = (leftFeatures.keySet diff rightFeatures.keySet).size - - if (numMissingFeaturesLeft == 0 && numMissingFeaturesRight == 0) { - val Epsilon = 1e-5 - val numUnmatchedValues = leftFeatures.map { - case (id, lValue) => - val rValue = rightFeatures(id) - // The approximate match is to account for the precision loss due to - // the Double -> Float -> Double conversion. - if (math.abs(lValue - rValue) <= Epsilon) 0L else 1L - }.sum - - if (numUnmatchedValues == 0) { - Seq( - (equalRecords, 1L), - (featureValuesMatched, leftFeatures.size.toLong) - ) - } else { - Seq( - (featureValuesThatDidNotMatch, numUnmatchedValues), - ( - featureValuesMatched, - math.max(leftFeatures.size, rightFeatures.size) - numUnmatchedValues) - ) - } - } else { - Seq( - (recordsWithUnmatchedKeys, 1L), - (missingFeaturesFromLeft, numMissingFeaturesLeft.toLong), - (missingFeaturesFromRight, numMissingFeaturesRight.toLong) - ) - } - } - - /** - * For feature types that are not Feature.Continuous. We expect these to match exactly in the two stores. 
- * Mutable change - */ - private def isIdenticalNonContinuousFeatureSet(left: DataRecord, right: DataRecord): Boolean = { - val booleanMatched = safeEquals(left.binaryFeatures, right.binaryFeatures) - val discreteMatched = safeEquals(left.discreteFeatures, right.discreteFeatures) - val stringMatched = safeEquals(left.stringFeatures, right.stringFeatures) - val sparseBinaryMatched = safeEquals(left.sparseBinaryFeatures, right.sparseBinaryFeatures) - val sparseContinuousMatched = - safeEquals(left.sparseContinuousFeatures, right.sparseContinuousFeatures) - val blobMatched = safeEquals(left.blobFeatures, right.blobFeatures) - val tensorsMatched = safeEquals(left.tensors, right.tensors) - val sparseTensorsMatched = safeEquals(left.sparseTensors, right.sparseTensors) - - booleanMatched && discreteMatched && stringMatched && sparseBinaryMatched && - sparseContinuousMatched && blobMatched && tensorsMatched && sparseTensorsMatched - } - - def safeEquals[T](l: T, r: T): Boolean = Option(l).equals(Option(r)) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesV2ScaldingJob.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesV2ScaldingJob.docx new file mode 100644 index 000000000..8402f8f63 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesV2ScaldingJob.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesV2ScaldingJob.scala b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesV2ScaldingJob.scala deleted file mode 100644 index aa8ae3612..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregatesV2ScaldingJob.scala +++ /dev/null @@ -1,216 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.scalding - -import com.twitter.bijection.thrift.CompactThriftCodec -import com.twitter.bijection.Codec -import com.twitter.bijection.Injection -import com.twitter.ml.api._ -import com.twitter.ml.api.constant.SharedFeatures.TIMESTAMP -import com.twitter.ml.api.util.CompactDataRecordConverter -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.scalding.Args -import com.twitter.scalding_internal.dalv2.DALWrite.D -import com.twitter.storehaus_internal.manhattan.ManhattanROConfig -import com.twitter.summingbird.batch.option.Reducers -import com.twitter.summingbird.batch.BatchID -import com.twitter.summingbird.batch.Batcher -import com.twitter.summingbird.batch.Timestamp -import com.twitter.summingbird.option._ -import com.twitter.summingbird.scalding.Scalding -import com.twitter.summingbird.scalding.batch.{BatchedStore => ScaldingBatchedStore} -import com.twitter.summingbird.Options -import com.twitter.summingbird.Producer -import com.twitter.summingbird_internal.bijection.BatchPairImplicits._ -import com.twitter.summingbird_internal.runner.common.JobName -import com.twitter.summingbird_internal.runner.scalding.GenericRunner -import com.twitter.summingbird_internal.runner.scalding.ScaldingConfig -import com.twitter.summingbird_internal.runner.scalding.StatebirdState -import com.twitter.summingbird_internal.dalv2.DAL -import com.twitter.summingbird_internal.runner.store_config._ -import com.twitter.timelines.data_processing.ml_util.aggregation_framework._ -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.scalding.sources._ -import job.AggregatesV2Job -import org.apache.hadoop.conf.Configuration -/* - * Offline scalding 
version of summingbird job to compute aggregates v2. - * This is loosely based on the template created by sb-gen. - * Extend this trait in your own scalding job, and override the val - * "aggregatesToCompute" with your own desired set of aggregates. - */ -trait AggregatesV2ScaldingJob { - val aggregatesToCompute: Set[TypedAggregateGroup[_]] - - implicit val aggregationKeyInjection: Injection[AggregationKey, Array[Byte]] = - AggregationKeyInjection - - implicit val aggregationKeyOrdering: AggregationKeyOrdering.type = AggregationKeyOrdering - - implicit val dataRecordCodec: Injection[DataRecord, Array[Byte]] = CompactThriftCodec[DataRecord] - - private implicit val compactDataRecordCodec: Injection[CompactDataRecord, Array[Byte]] = - CompactThriftCodec[CompactDataRecord] - - private val compactDataRecordConverter = new CompactDataRecordConverter() - - def numReducers: Int = -1 - - /** - * Function that maps from a logical ''AggregateSource'' - * to an underlying physical source. The physical source - * for the scalding platform is a ScaldingAggregateSource. - */ - def dataRecordSourceToScalding( - source: AggregateSource - ): Option[Producer[Scalding, DataRecord]] = { - source match { - case offlineSource: OfflineAggregateSource => - Some(ScaldingAggregateSource(offlineSource).source) - case _ => None - } - } - - /** - * Creates and returns a versioned store using the config parameters - * with a specific number of versions to keep, and which can read from - * the most recent available version on HDFS rather than a specific - * version number. The store applies a timestamp correction based on the - * number of days of aggregate data skipped over at read time to ensure - * that skipping data plays nicely with halfLife decay. - * - * @param config specifying the Manhattan store parameters - * @param versionsToKeep number of old versions to keep - */ - def getMostRecentLagCorrectingVersionedStoreWithRetention[ - Key: Codec: Ordering, - ValInStore: Codec, - ValInMemory - ]( - config: OfflineStoreOnlyConfig[ManhattanROConfig], - versionsToKeep: Int, - lagCorrector: (ValInMemory, Long) => ValInMemory, - packer: ValInMemory => ValInStore, - unpacker: ValInStore => ValInMemory - ): ScaldingBatchedStore[Key, ValInMemory] = { - MostRecentLagCorrectingVersionedStore[Key, ValInStore, ValInMemory]( - config.offline.hdfsPath.toString, - packer = packer, - unpacker = unpacker, - versionsToKeep = versionsToKeep)( - Injection.connect[(Key, (BatchID, ValInStore)), (Array[Byte], Array[Byte])], - config.batcher, - implicitly[Ordering[Key]], - lagCorrector - ).withInitialBatch(config.batcher.batchOf(config.startTime.value)) - } - - def mutablyCorrectDataRecordTimestamp( - record: DataRecord, - lagToCorrectMillis: Long - ): DataRecord = { - val richRecord = SRichDataRecord(record) - if (richRecord.hasFeature(TIMESTAMP)) { - val timestamp = richRecord.getFeatureValue(TIMESTAMP).toLong - richRecord.setFeatureValue(TIMESTAMP, timestamp + lagToCorrectMillis) - } - record - } - - /** - * Function that maps from a logical ''AggregateStore'' - * to an underlying physical store. The physical store for - * scalding is a HDFS VersionedKeyValSource dataset. 
- */ - def aggregateStoreToScalding( - store: AggregateStore - ): Option[Scalding#Store[AggregationKey, DataRecord]] = { - store match { - case offlineStore: OfflineAggregateDataRecordStore => - Some( - getMostRecentLagCorrectingVersionedStoreWithRetention[ - AggregationKey, - DataRecord, - DataRecord]( - offlineStore, - versionsToKeep = offlineStore.batchesToKeep, - lagCorrector = mutablyCorrectDataRecordTimestamp, - packer = Injection.identity[DataRecord], - unpacker = Injection.identity[DataRecord] - ) - ) - case offlineStore: OfflineAggregateDataRecordStoreWithDAL => - Some( - DAL.versionedKeyValStore[AggregationKey, DataRecord]( - dataset = offlineStore.dalDataset, - pathLayout = D.Suffix(offlineStore.offline.hdfsPath.toString), - batcher = offlineStore.batcher, - maybeStartTime = Some(offlineStore.startTime), - maxErrors = offlineStore.maxKvSourceFailures - )) - case _ => None - } - } - - def generate(args: Args): ScaldingConfig = new ScaldingConfig { - val jobName = JobName(args("job_name")) - - /* - * Add registrars for chill serialization for user-defined types. - * We use the default: an empty List(). - */ - override def registrars = List() - - /* Use transformConfig to set Hadoop options. */ - override def transformConfig(config: Map[String, AnyRef]): Map[String, AnyRef] = - super.transformConfig(config) ++ Map( - "mapreduce.output.fileoutputformat.compress" -> "true", - "mapreduce.output.fileoutputformat.compress.codec" -> "com.hadoop.compression.lzo.LzoCodec", - "mapreduce.output.fileoutputformat.compress.type" -> "BLOCK" - ) - - /* - * Use getNamedOptions to set Summingbird runtime options - * The options we set are: - * 1) Set monoid to non-commutative to disable map-side - * aggregation and force all aggregation to reducers (provides a 20% speedup) - */ - override def getNamedOptions: Map[String, Options] = Map( - "DEFAULT" -> Options() - .set(MonoidIsCommutative(false)) - .set(Reducers(numReducers)) - ) - - implicit val batcher: Batcher = Batcher.ofHours(24) - - /* State implementation that uses Statebird (go/statebird) to track the batches processed. 
*/ - def getWaitingState(hadoopConfig: Configuration, startDate: Option[Timestamp], batches: Int) = - StatebirdState( - jobName, - startDate, - batches, - args.optional("statebird_service_destination"), - args.optional("statebird_client_id_name") - )(batcher) - - val sourceNameFilter: Option[Set[String]] = - args.optional("input_sources").map(_.split(",").toSet) - val storeNameFilter: Option[Set[String]] = - args.optional("output_stores").map(_.split(",").toSet) - - val filteredAggregates = - AggregatesV2Job.filterAggregates( - aggregates = aggregatesToCompute, - sourceNames = sourceNameFilter, - storeNames = storeNameFilter - ) - - override val graph = - AggregatesV2Job.generateJobGraph[Scalding]( - filteredAggregates, - dataRecordSourceToScalding, - aggregateStoreToScalding - )(DataRecordAggregationMonoid(filteredAggregates)) - } - def main(args: Array[String]): Unit = { - GenericRunner(args, generate(_)) - - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregationKeyOrdering.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregationKeyOrdering.docx new file mode 100644 index 000000000..e32377ef4 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregationKeyOrdering.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregationKeyOrdering.scala b/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregationKeyOrdering.scala deleted file mode 100644 index af6f14ff2..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/AggregationKeyOrdering.scala +++ /dev/null @@ -1,17 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.scalding - -import com.twitter.scalding_internal.job.RequiredBinaryComparators.ordSer -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MacroEqualityOrderedSerialization - -object AggregationKeyOrdering extends Ordering[AggregationKey] { - implicit val featureMapsOrdering: MacroEqualityOrderedSerialization[ - (Map[Long, Long], Map[Long, String]) - ] = ordSer[(Map[Long, Long], Map[Long, String])] - - override def compare(left: AggregationKey, right: AggregationKey): Int = - featureMapsOrdering.compare( - AggregationKey.unapply(left).get, - AggregationKey.unapply(right).get - ) -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/BUILD b/timelines/data_processing/ml_util/aggregation_framework/scalding/BUILD deleted file mode 100644 index d03766619..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/BUILD +++ /dev/null @@ -1,72 +0,0 @@ -scala_library( - sources = ["*.scala"], - platform = "java8", - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/com/twitter/bijection:core", - "3rdparty/jvm/com/twitter/bijection:json", - "3rdparty/jvm/com/twitter/bijection:netty", - "3rdparty/jvm/com/twitter/bijection:scrooge", - "3rdparty/jvm/com/twitter/bijection:thrift", - "3rdparty/jvm/com/twitter/bijection:util", - "3rdparty/jvm/com/twitter/chill:bijection", - "3rdparty/jvm/com/twitter/storehaus:algebra", - "3rdparty/jvm/com/twitter/storehaus:core", - "3rdparty/jvm/org/apache/hadoop:hadoop-client-default", - "3rdparty/src/jvm/com/twitter/scalding:args", - "3rdparty/src/jvm/com/twitter/scalding:commons", - "3rdparty/src/jvm/com/twitter/scalding:core", - 
"3rdparty/src/jvm/com/twitter/summingbird:batch", - "3rdparty/src/jvm/com/twitter/summingbird:batch-hadoop", - "3rdparty/src/jvm/com/twitter/summingbird:chill", - "3rdparty/src/jvm/com/twitter/summingbird:core", - "3rdparty/src/jvm/com/twitter/summingbird:scalding", - "finagle/finagle-core/src/main", - "gizmoduck/snapshot/src/main/scala/com/twitter/gizmoduck/snapshot:deleted_user-scala", - "src/java/com/twitter/ml/api:api-base", - "src/java/com/twitter/ml/api/constant", - "src/scala/com/twitter/ml/api/util", - "src/scala/com/twitter/scalding_internal/dalv2", - "src/scala/com/twitter/scalding_internal/job/analytics_batch", - "src/scala/com/twitter/scalding_internal/util", - "src/scala/com/twitter/storehaus_internal/manhattan/config", - "src/scala/com/twitter/storehaus_internal/offline", - "src/scala/com/twitter/storehaus_internal/util", - "src/scala/com/twitter/summingbird_internal/bijection", - "src/scala/com/twitter/summingbird_internal/bijection:bijection-implicits", - "src/scala/com/twitter/summingbird_internal/dalv2", - "src/scala/com/twitter/summingbird_internal/runner/common", - "src/scala/com/twitter/summingbird_internal/runner/scalding", - "src/scala/com/twitter/summingbird_internal/runner/store_config", - "src/scala/com/twitter/summingbird_internal/runner/store_config/versioned_store", - "src/scala/com/twitter/summingbird_internal/sources/common", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:interpretable-model-java", - "src/thrift/com/twitter/statebird:compiled-v2-java", - "timelines/data_processing/ml_util/aggregation_framework:common_types", - "timelines/data_processing/ml_util/aggregation_framework:user_job", - "timelines/data_processing/ml_util/aggregation_framework/scalding/sources", - "timelines/data_processing/ml_util/sampling:sampling_utils", - ], - exports = [ - "3rdparty/src/jvm/com/twitter/summingbird:scalding", - "src/scala/com/twitter/storehaus_internal/manhattan/config", - "src/scala/com/twitter/summingbird_internal/runner/store_config", - ], -) - -hadoop_binary( - name = "bin", - basename = "aggregation_framework_scalding-deploy", - main = "com.twitter.scalding.Tool", - platform = "java8", - runtime_platform = "java8", - tags = [ - "bazel-compatible", - "bazel-compatible:migrated", - "bazel-only", - ], - dependencies = [ - ":scalding", - ], -) diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/BUILD.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/BUILD.docx new file mode 100644 index 000000000..710764d5a Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/BUILD.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/DeletedUserPruner.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/DeletedUserPruner.docx new file mode 100644 index 000000000..8333296ff Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/DeletedUserPruner.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/DeletedUserPruner.scala b/timelines/data_processing/ml_util/aggregation_framework/scalding/DeletedUserPruner.scala deleted file mode 100644 index 7e2f7a95c..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/DeletedUserPruner.scala +++ /dev/null @@ -1,97 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.scalding - -import com.twitter.gizmoduck.snapshot.DeletedUserScalaDataset 
-import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.Feature -import com.twitter.scalding.typed.TypedPipe -import com.twitter.scalding.DateOps -import com.twitter.scalding.DateRange -import com.twitter.scalding.Days -import com.twitter.scalding.RichDate -import com.twitter.scalding_internal.dalv2.DAL -import com.twitter.scalding_internal.dalv2.remote_access.AllowCrossClusterSameDC -import com.twitter.scalding_internal.job.RequiredBinaryComparators.ordSer -import com.twitter.scalding_internal.pruner.Pruner -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.AggregationKey -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.TypedAggregateGroup -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MacroEqualityOrderedSerialization -import java.{util => ju} - -object DeletedUserSeqPruner extends Pruner[Seq[Long]] { - implicit val tz: ju.TimeZone = DateOps.UTC - implicit val userIdSequenceOrdering: MacroEqualityOrderedSerialization[Seq[Long]] = - ordSer[Seq[Long]] - - private[scalding] def pruneDeletedUsers[T]( - input: TypedPipe[T], - extractor: T => Seq[Long], - deletedUsers: TypedPipe[Long] - ): TypedPipe[T] = { - val userIdsAndValues = input.map { t: T => - val userIds: Seq[Long] = extractor(t) - (userIds, t) - } - - // Find all valid sequences of userids in the input pipe - // that contain at least one deleted user. This is efficient - // as long as the number of deleted users is small. - val userSequencesWithDeletedUsers = userIdsAndValues - .flatMap { case (userIds, _) => userIds.map((_, userIds)) } - .leftJoin(deletedUsers.asKeys) - .collect { case (_, (userIds, Some(_))) => userIds } - .distinct - - userIdsAndValues - .leftJoin(userSequencesWithDeletedUsers.asKeys) - .collect { case (_, (t, None)) => t } - } - - override def prune[T]( - input: TypedPipe[T], - put: (T, Seq[Long]) => Option[T], - get: T => Seq[Long], - writeTime: RichDate - ): TypedPipe[T] = { - lazy val deletedUsers = DAL - .readMostRecentSnapshot(DeletedUserScalaDataset, DateRange(writeTime - Days(7), writeTime)) - .withRemoteReadPolicy(AllowCrossClusterSameDC) - .toTypedPipe - .map(_.userId) - - pruneDeletedUsers(input, get, deletedUsers) - } -} - -object AggregationKeyPruner { - - /** - * Makes a pruner that prunes aggregate records where any of the - * "userIdFeatures" set in the aggregation key correspond to a - * user who has deleted their account. Here, "userIdFeatures" is - * intended as a catch-all term for all features corresponding to - * a Twitter user in the input data record -- the feature itself - * could represent an authorId, retweeterId, engagerId, etc. - */ - def mkDeletedUsersPruner( - userIdFeatures: Seq[Feature[_]] - ): Pruner[(AggregationKey, DataRecord)] = { - val userIdFeatureIds = userIdFeatures.map(TypedAggregateGroup.getDenseFeatureId) - - def getter(tupled: (AggregationKey, DataRecord)): Seq[Long] = { - tupled match { - case (aggregationKey, _) => - userIdFeatureIds.flatMap { id => - aggregationKey.discreteFeaturesById - .get(id) - .orElse(aggregationKey.textFeaturesById.get(id).map(_.toLong)) - } - } - } - - // Setting putter to always return None here. The put function is not used within pruneDeletedUsers, this function is just needed for xmap api. 
- def putter: ((AggregationKey, DataRecord), Seq[Long]) => Option[(AggregationKey, DataRecord)] = - (t, seq) => None - - DeletedUserSeqPruner.xmap(putter, getter) - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/MostRecentVersionedStore.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/MostRecentVersionedStore.docx new file mode 100644 index 000000000..bd2efd0a9 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/MostRecentVersionedStore.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/MostRecentVersionedStore.scala b/timelines/data_processing/ml_util/aggregation_framework/scalding/MostRecentVersionedStore.scala deleted file mode 100644 index d60e67716..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/MostRecentVersionedStore.scala +++ /dev/null @@ -1,100 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.scalding - -import com.twitter.bijection.Injection -import com.twitter.scalding.commons.source.VersionedKeyValSource -import com.twitter.scalding.TypedPipe -import com.twitter.scalding.{Hdfs => HdfsMode} -import com.twitter.summingbird.batch.store.HDFSMetadata -import com.twitter.summingbird.batch.BatchID -import com.twitter.summingbird.batch.Batcher -import com.twitter.summingbird.batch.OrderedFromOrderingExt -import com.twitter.summingbird.batch.PrunedSpace -import com.twitter.summingbird.scalding._ -import com.twitter.summingbird.scalding.store.VersionedBatchStore -import org.slf4j.LoggerFactory - -object MostRecentLagCorrectingVersionedStore { - def apply[Key, ValInStore, ValInMemory]( - rootPath: String, - packer: ValInMemory => ValInStore, - unpacker: ValInStore => ValInMemory, - versionsToKeep: Int = VersionedKeyValSource.defaultVersionsToKeep, - prunedSpace: PrunedSpace[(Key, ValInMemory)] = PrunedSpace.neverPruned - )( - implicit injection: Injection[(Key, (BatchID, ValInStore)), (Array[Byte], Array[Byte])], - batcher: Batcher, - ord: Ordering[Key], - lagCorrector: (ValInMemory, Long) => ValInMemory - ): MostRecentLagCorrectingVersionedBatchStore[Key, ValInMemory, Key, (BatchID, ValInStore)] = { - new MostRecentLagCorrectingVersionedBatchStore[Key, ValInMemory, Key, (BatchID, ValInStore)]( - rootPath, - versionsToKeep, - batcher - )(lagCorrector)({ case (batchID, (k, v)) => (k, (batchID.next, packer(v))) })({ - case (k, (_, v)) => (k, unpacker(v)) - }) { - override def select(b: List[BatchID]) = List(b.last) - override def pruning: PrunedSpace[(Key, ValInMemory)] = prunedSpace - } - } -} - -/** - * @param lagCorrector lagCorrector allows one to take data from one batch and pretend as if it - * came from a different batch. - * @param pack Converts the in-memory tuples to the type used by the underlying key-val store. - * @param unpack Converts the key-val tuples from the store in the form used by the calling object. 
- */ -class MostRecentLagCorrectingVersionedBatchStore[KeyInMemory, ValInMemory, KeyInStore, ValInStore]( - rootPath: String, - versionsToKeep: Int, - override val batcher: Batcher -)( - lagCorrector: (ValInMemory, Long) => ValInMemory -)( - pack: (BatchID, (KeyInMemory, ValInMemory)) => (KeyInStore, ValInStore) -)( - unpack: ((KeyInStore, ValInStore)) => (KeyInMemory, ValInMemory) -)( - implicit @transient injection: Injection[(KeyInStore, ValInStore), (Array[Byte], Array[Byte])], - override val ordering: Ordering[KeyInMemory]) - extends VersionedBatchStore[KeyInMemory, ValInMemory, KeyInStore, ValInStore]( - rootPath, - versionsToKeep, - batcher)(pack)(unpack)(injection, ordering) { - - import OrderedFromOrderingExt._ - - @transient private val logger = - LoggerFactory.getLogger(classOf[MostRecentLagCorrectingVersionedBatchStore[_, _, _, _]]) - - override protected def lastBatch( - exclusiveUB: BatchID, - mode: HdfsMode - ): Option[(BatchID, FlowProducer[TypedPipe[(KeyInMemory, ValInMemory)]])] = { - val batchToPretendAs = exclusiveUB.prev - val versionToPretendAs = batchIDToVersion(batchToPretendAs) - logger.info( - s"Most recent lag correcting versioned batched store at $rootPath entering lastBatch method versionToPretendAs = $versionToPretendAs") - val meta = new HDFSMetadata(mode.conf, rootPath) - meta.versions - .map { ver => (versionToBatchID(ver), readVersion(ver)) } - .filter { _._1 < exclusiveUB } - .reduceOption { (a, b) => if (a._1 > b._1) a else b } - .map { - case ( - lastBatchID: BatchID, - flowProducer: FlowProducer[TypedPipe[(KeyInMemory, ValInMemory)]]) => - val lastVersion = batchIDToVersion(lastBatchID) - val lagToCorrectMillis: Long = - batchIDToVersion(batchToPretendAs) - batchIDToVersion(lastBatchID) - logger.info( - s"Most recent available version is $lastVersion, so lagToCorrectMillis is $lagToCorrectMillis") - val lagCorrectedFlowProducer = flowProducer.map { - pipe: TypedPipe[(KeyInMemory, ValInMemory)] => - pipe.map { case (k, v) => (k, lagCorrector(v, lagToCorrectMillis)) } - } - (batchToPretendAs, lagCorrectedFlowProducer) - } - } -} diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/BUILD b/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/BUILD deleted file mode 100644 index ba065ecd7..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -scala_library( - sources = ["*.scala"], - platform = "java8", - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/com/twitter/storehaus:algebra", - "3rdparty/src/jvm/com/twitter/scalding:commons", - "3rdparty/src/jvm/com/twitter/scalding:core", - "3rdparty/src/jvm/com/twitter/scalding:date", - "3rdparty/src/jvm/com/twitter/summingbird:batch", - "3rdparty/src/jvm/com/twitter/summingbird:batch-hadoop", - "3rdparty/src/jvm/com/twitter/summingbird:chill", - "3rdparty/src/jvm/com/twitter/summingbird:core", - "3rdparty/src/jvm/com/twitter/summingbird:scalding", - "src/java/com/twitter/ml/api:api-base", - "src/scala/com/twitter/ml/api:api-base", - "src/scala/com/twitter/ml/api/internal", - "src/scala/com/twitter/ml/api/util", - "src/scala/com/twitter/scalding_internal/dalv2", - "src/scala/com/twitter/scalding_internal/dalv2/remote_access", - "src/scala/com/twitter/summingbird_internal/sources/common", - "src/thrift/com/twitter/ml/api:data-java", - "src/thrift/com/twitter/ml/api:interpretable-model-java", - "timelines/data_processing/ml_util/aggregation_framework:common_types", - ], -) 
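Note on the lag correction used by MostRecentLagCorrectingVersionedBatchStore above: when the newest available store version is older than the batch being computed, each stored value's timestamp is shifted forward by the lag (mirroring mutablyCorrectDataRecordTimestamp) before half-life decay is applied, so skipped batches do not spuriously decay the aggregates. The following is a minimal, self-contained Scala sketch of that idea only; TimedCount, halfLifeDecay, and the constants are illustrative assumptions, not names from this codebase.

object LagCorrectionSketch {
  final case class TimedCount(value: Double, timestampMs: Long)

  // Exponential half-life decay of a value from its own timestamp up to nowMs.
  def halfLifeDecay(tc: TimedCount, halfLifeMs: Long, nowMs: Long): Double = {
    val ageMs = math.max(0L, nowMs - tc.timestampMs)
    tc.value * math.pow(0.5, ageMs.toDouble / halfLifeMs.toDouble)
  }

  // Analogue of the timestamp correction: pretend the value was written lagMs later.
  def lagCorrect(tc: TimedCount, lagMs: Long): TimedCount =
    tc.copy(timestampMs = tc.timestampMs + lagMs)

  def main(args: Array[String]): Unit = {
    val dayMs = 24L * 3600L * 1000L
    val stored = TimedCount(value = 10.0, timestampMs = 0L) // written by the last available version
    val lagMs = 2 * dayMs                                   // two daily batches were skipped
    val now = 9 * dayMs
    val halfLife = 7 * dayMs

    val uncorrected = halfLifeDecay(stored, halfLife, now)                  // decays over 9 days
    val corrected = halfLifeDecay(lagCorrect(stored, lagMs), halfLife, now) // decays over 7 days
    println(f"uncorrected=$uncorrected%.3f corrected=$corrected%.3f")
  }
}

Shifting timestamps rather than rescaling values keeps the correction a pure relabelling, which is presumably why the production code only rewrites the TIMESTAMP feature and leaves the aggregate values themselves untouched.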
diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/BUILD.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/BUILD.docx new file mode 100644 index 000000000..bf9b66104 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/BUILD.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/ScaldingAggregateSource.docx b/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/ScaldingAggregateSource.docx new file mode 100644 index 000000000..93574b5c6 Binary files /dev/null and b/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/ScaldingAggregateSource.docx differ diff --git a/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/ScaldingAggregateSource.scala b/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/ScaldingAggregateSource.scala deleted file mode 100644 index d1820b4fc..000000000 --- a/timelines/data_processing/ml_util/aggregation_framework/scalding/sources/ScaldingAggregateSource.scala +++ /dev/null @@ -1,77 +0,0 @@ -package com.twitter.timelines.data_processing.ml_util.aggregation_framework.scalding.sources - -import com.twitter.ml.api.DailySuffixFeatureSource -import com.twitter.ml.api.DataRecord -import com.twitter.ml.api.FixedPathFeatureSource -import com.twitter.ml.api.HourlySuffixFeatureSource -import com.twitter.ml.api.util.SRichDataRecord -import com.twitter.scalding._ -import com.twitter.scalding_internal.dalv2.DAL -import com.twitter.scalding_internal.dalv2.remote_access.AllowCrossClusterSameDC -import com.twitter.statebird.v2.thriftscala.Environment -import com.twitter.summingbird._ -import com.twitter.summingbird.scalding.Scalding.pipeFactoryExact -import com.twitter.summingbird.scalding._ -import com.twitter.summingbird_internal.sources.SourceFactory -import com.twitter.timelines.data_processing.ml_util.aggregation_framework.OfflineAggregateSource -import java.lang.{Long => JLong} - -/* - * Summingbird offline HDFS source that reads from data records on HDFS. - * - * @param offlineSource Underlying offline source that contains - * all the config info to build this platform-specific (scalding) source. 
- */ -case class ScaldingAggregateSource(offlineSource: OfflineAggregateSource) - extends SourceFactory[Scalding, DataRecord] { - - val hdfsPath: String = offlineSource.scaldingHdfsPath.getOrElse("") - val suffixType: String = offlineSource.scaldingSuffixType.getOrElse("daily") - val withValidation: Boolean = offlineSource.withValidation - def name: String = offlineSource.name - def description: String = - "Summingbird offline source that reads from data records at: " + hdfsPath - - implicit val timeExtractor: TimeExtractor[DataRecord] = TimeExtractor((record: DataRecord) => - SRichDataRecord(record).getFeatureValue[JLong, JLong](offlineSource.timestampFeature)) - - def getSourceForDateRange(dateRange: DateRange) = { - suffixType match { - case "daily" => DailySuffixFeatureSource(hdfsPath)(dateRange).source - case "hourly" => HourlySuffixFeatureSource(hdfsPath)(dateRange).source - case "fixed_path" => FixedPathFeatureSource(hdfsPath).source - case "dal" => - offlineSource.dalDataSet match { - case Some(dataset) => - DAL - .read(dataset, dateRange) - .withRemoteReadPolicy(AllowCrossClusterSameDC) - .withEnvironment(Environment.Prod) - .toTypedSource - case _ => - throw new IllegalArgumentException( - "cannot provide an empty dataset when defining DAL as the suffix type" - ) - } - } - } - - /** - * This method is similar to [[Scalding.sourceFromMappable]] except that this uses [[pipeFactoryExact]] - * instead of [[pipeFactory]]. [[pipeFactoryExact]] also invokes [[FileSource.validateTaps]] on the source. - * The validation ensures the presence of _SUCCESS file before processing. For more details, please refer to - * https://jira.twitter.biz/browse/TQ-10618 - */ - def sourceFromMappableWithValidation[T: TimeExtractor: Manifest]( - factory: (DateRange) => Mappable[T] - ): Producer[Scalding, T] = { - Producer.source[Scalding, T](pipeFactoryExact(factory)) - } - - def source: Producer[Scalding, DataRecord] = { - if (withValidation) - sourceFromMappableWithValidation(getSourceForDateRange) - else - Scalding.sourceFromMappable(getSourceForDateRange) - } -} diff --git a/topic-social-proof/README.docx b/topic-social-proof/README.docx new file mode 100644 index 000000000..328cf2597 Binary files /dev/null and b/topic-social-proof/README.docx differ diff --git a/topic-social-proof/README.md b/topic-social-proof/README.md deleted file mode 100644 index d98b7ba3b..000000000 --- a/topic-social-proof/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# Topic Social Proof Service (TSPS) -================= - -**Topic Social Proof Service** (TSPS) serves as a centralized source for verifying topics related to Timelines and Notifications. By analyzing a user's topic preferences, such as following or unfollowing, and employing semantic annotations and tweet embeddings from SimClusters or other machine learning models, TSPS delivers highly relevant topics tailored to each user's interests. - -For instance, when a tweet discusses Stephen Curry, the service determines whether the content falls under topics like "NBA" and/or "Golden State Warriors", and provides relevance scores based on SimClusters embeddings. Additionally, TSPS evaluates user-specific topic preferences to return either the comprehensive list of available topics, only those the user currently follows, or new topics the user has not followed but may find interesting if recommended on specific product surfaces. 
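For illustration, here is a rough sketch of the response shape implied by the README above, using the `TopicSocialProofResponse` and `TopicWithScore` thrift structs that appear later in this diff; the tweet id, topic ids, and scores are made-up placeholder values.

```scala
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
import com.twitter.tsp.thriftscala.TopicWithScore

// Placeholder SemanticCore entity ids for the "Stephen Curry" example:
// the tweet maps to the "NBA" and "Golden State Warriors" topics, each with a
// SimClusters-based relevance score. topicFollowType is left as None here.
val nbaTopicId = 123L
val warriorsTopicId = 456L

val response = TopicSocialProofResponse(
  socialProofs = Map(
    1234567890L -> Seq(
      TopicWithScore(topicId = nbaTopicId, score = 0.87, topicFollowType = None),
      TopicWithScore(topicId = warriorsTopicId, score = 0.61, topicFollowType = None)
    )
  )
)
```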
- - diff --git a/topic-social-proof/server/BUILD b/topic-social-proof/server/BUILD deleted file mode 100644 index 9fb977d17..000000000 --- a/topic-social-proof/server/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -jvm_binary( - name = "bin", - basename = "topic-social-proof", - main = "com.twitter.tsp.TopicSocialProofStratoFedServerMain", - runtime_platform = "java11", - tags = [ - "bazel-compatible", - ], - dependencies = [ - "strato/src/main/scala/com/twitter/strato/logging/logback", - "topic-social-proof/server/src/main/resources", - "topic-social-proof/server/src/main/scala/com/twitter/tsp", - ], -) - -# Aurora Workflows build phase convention requires a jvm_app named with ${project-name}-app -jvm_app( - name = "topic-social-proof-app", - archive = "zip", - binary = ":bin", - tags = [ - "bazel-compatible", - ], -) diff --git a/topic-social-proof/server/BUILD.docx b/topic-social-proof/server/BUILD.docx new file mode 100644 index 000000000..7d8284057 Binary files /dev/null and b/topic-social-proof/server/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/resources/BUILD b/topic-social-proof/server/src/main/resources/BUILD deleted file mode 100644 index 8f96f402c..000000000 --- a/topic-social-proof/server/src/main/resources/BUILD +++ /dev/null @@ -1,8 +0,0 @@ -resources( - sources = [ - "*.xml", - "*.yml", - "config/*.yml", - ], - tags = ["bazel-compatible"], -) diff --git a/topic-social-proof/server/src/main/resources/BUILD.docx b/topic-social-proof/server/src/main/resources/BUILD.docx new file mode 100644 index 000000000..e94b3643a Binary files /dev/null and b/topic-social-proof/server/src/main/resources/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/resources/config/decider.docx b/topic-social-proof/server/src/main/resources/config/decider.docx new file mode 100644 index 000000000..b6d6f6b91 Binary files /dev/null and b/topic-social-proof/server/src/main/resources/config/decider.docx differ diff --git a/topic-social-proof/server/src/main/resources/config/decider.yml b/topic-social-proof/server/src/main/resources/config/decider.yml deleted file mode 100644 index c40dd7080..000000000 --- a/topic-social-proof/server/src/main/resources/config/decider.yml +++ /dev/null @@ -1,61 +0,0 @@ -# Keys are sorted in an alphabetical order - -enable_topic_social_proof_score: - comment : "Enable the calculation of cosine similarity score in TopicSocialProofStore. 0 means do not calculate the score and use a random rank to generate topic social proof" - default_availability: 0 - -enable_tweet_health_score: - comment: "Enable the calculation for health scores in tweetInfo. By enabling this decider, we will compute TweetHealthModelScore" - default_availability: 0 - -enable_user_agatha_score: - comment: "Enable the calculation for health scores in tweetInfo. By enabling this decider, we will compute UserHealthModelScore" - default_availability: 0 - -enable_loadshedding_HomeTimeline: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_HomeTimelineTopicTweets: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_HomeTimelineRecommendTopicTweets: - comment: "Enable loadshedding (from 0% to 100%). 
Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_MagicRecsRecommendTopicTweets: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_TopicLandingPage: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_HomeTimelineFeatures: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_HomeTimelineTopicTweetsMetrics: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_HomeTimelineUTEGTopicTweets: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_HomeTimelineSimClusters: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_ExploreTopicTweets: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_MagicRecsTopicTweets: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 - -enable_loadshedding_Search: - comment: "Enable loadshedding (from 0% to 100%). Requests that have been shed will return an empty response" - default_availability: 0 diff --git a/topic-social-proof/server/src/main/resources/logback.docx b/topic-social-proof/server/src/main/resources/logback.docx new file mode 100644 index 000000000..01c5f2a28 Binary files /dev/null and b/topic-social-proof/server/src/main/resources/logback.docx differ diff --git a/topic-social-proof/server/src/main/resources/logback.xml b/topic-social-proof/server/src/main/resources/logback.xml deleted file mode 100644 index d08b0a965..000000000 --- a/topic-social-proof/server/src/main/resources/logback.xml +++ /dev/null @@ -1,155 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - true - - - - - ${log.service.output} - - ${log.service.output}.%i - 1 - 10 - - - 50MB - - - %date %.-3level ${DEFAULT_SERVICE_PATTERN}%n - - - - - - ${log.strato_only.output} - - ${log.strato_only.output}.%i - 1 - 10 - - - 50MB - - - %date %.-3level ${DEFAULT_SERVICE_PATTERN}%n - - - - - - true - loglens - ${log.lens.index} - ${log.lens.tag}/service - - %msg%n - - - 500 - 50 - - - manhattan-client - .*InvalidRequest.* - - - - - - - - - ${async_queue_size} - ${async_max_flush_time} - - - - - ${async_queue_size} - ${async_max_flush_time} - - - - - ${async_queue_size} - ${async_max_flush_time} - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/BUILD b/topic-social-proof/server/src/main/scala/com/twitter/tsp/BUILD deleted file mode 100644 index 2052c5047..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - tags = [ - "bazel-compatible", - ], - dependencies = [ - "finatra/inject/inject-thrift-client", - "strato/src/main/scala/com/twitter/strato/fed", - 
"strato/src/main/scala/com/twitter/strato/fed/server", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/columns", - ], -) diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/BUILD.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/BUILD.docx new file mode 100644 index 000000000..efce96897 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/TopicSocialProofStratoFedServer.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/TopicSocialProofStratoFedServer.docx new file mode 100644 index 000000000..41a372147 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/TopicSocialProofStratoFedServer.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/TopicSocialProofStratoFedServer.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/TopicSocialProofStratoFedServer.scala deleted file mode 100644 index 22d3c19f0..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/TopicSocialProofStratoFedServer.scala +++ /dev/null @@ -1,56 +0,0 @@ -package com.twitter.tsp - -import com.google.inject.Module -import com.twitter.strato.fed._ -import com.twitter.strato.fed.server._ -import com.twitter.strato.warmup.Warmer -import com.twitter.tsp.columns.TopicSocialProofColumn -import com.twitter.tsp.columns.TopicSocialProofBatchColumn -import com.twitter.tsp.handlers.UttChildrenWarmupHandler -import com.twitter.tsp.modules.RepresentationScorerStoreModule -import com.twitter.tsp.modules.GizmoduckUserModule -import com.twitter.tsp.modules.TSPClientIdModule -import com.twitter.tsp.modules.TopicListingModule -import com.twitter.tsp.modules.TopicSocialProofStoreModule -import com.twitter.tsp.modules.TopicTweetCosineSimilarityAggregateStoreModule -import com.twitter.tsp.modules.TweetInfoStoreModule -import com.twitter.tsp.modules.TweetyPieClientModule -import com.twitter.tsp.modules.UttClientModule -import com.twitter.tsp.modules.UttLocalizationModule -import com.twitter.util.Future - -object TopicSocialProofStratoFedServerMain extends TopicSocialProofStratoFedServer - -trait TopicSocialProofStratoFedServer extends StratoFedServer { - override def dest: String = "/s/topic-social-proof/topic-social-proof" - - override val modules: Seq[Module] = - Seq( - GizmoduckUserModule, - RepresentationScorerStoreModule, - TopicSocialProofStoreModule, - TopicListingModule, - TopicTweetCosineSimilarityAggregateStoreModule, - TSPClientIdModule, - TweetInfoStoreModule, - TweetyPieClientModule, - UttClientModule, - UttLocalizationModule - ) - - override def columns: Seq[Class[_ <: StratoFed.Column]] = - Seq( - classOf[TopicSocialProofColumn], - classOf[TopicSocialProofBatchColumn] - ) - - override def configureWarmer(warmer: Warmer): Unit = { - warmer.add( - "uttChildrenWarmupHandler", - () => { - handle[UttChildrenWarmupHandler]() - Future.Unit - } - ) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/BUILD b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/BUILD deleted file mode 100644 index c29b7ea35..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - tags = [ - "bazel-compatible", - ], - dependencies = [ - "stitch/stitch-storehaus", - "strato/src/main/scala/com/twitter/strato/fed", - 
"topic-social-proof/server/src/main/scala/com/twitter/tsp/service", - "topic-social-proof/server/src/main/thrift:thrift-scala", - ], -) diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/BUILD.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/BUILD.docx new file mode 100644 index 000000000..938843d2a Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofBatchColumn.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofBatchColumn.docx new file mode 100644 index 000000000..1832edf4b Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofBatchColumn.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofBatchColumn.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofBatchColumn.scala deleted file mode 100644 index f451e662a..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofBatchColumn.scala +++ /dev/null @@ -1,84 +0,0 @@ -package com.twitter.tsp.columns - -import com.twitter.stitch.SeqGroup -import com.twitter.stitch.Stitch -import com.twitter.strato.catalog.Fetch -import com.twitter.strato.catalog.OpMetadata -import com.twitter.strato.config._ -import com.twitter.strato.config.AllowAll -import com.twitter.strato.config.ContactInfo -import com.twitter.strato.config.Policy -import com.twitter.strato.data.Conv -import com.twitter.strato.data.Description.PlainText -import com.twitter.strato.data.Lifecycle.Production -import com.twitter.strato.fed.StratoFed -import com.twitter.strato.thrift.ScroogeConv -import com.twitter.tsp.thriftscala.TopicSocialProofRequest -import com.twitter.tsp.thriftscala.TopicSocialProofOptions -import com.twitter.tsp.service.TopicSocialProofService -import com.twitter.tsp.thriftscala.TopicWithScore -import com.twitter.util.Future -import com.twitter.util.Try -import javax.inject.Inject - -class TopicSocialProofBatchColumn @Inject() ( - topicSocialProofService: TopicSocialProofService) - extends StratoFed.Column(TopicSocialProofBatchColumn.Path) - with StratoFed.Fetch.Stitch { - - override val policy: Policy = - ReadWritePolicy( - readPolicy = AllowAll, - writePolicy = AllowKeyAuthenticatedTwitterUserId - ) - - override type Key = Long - override type View = TopicSocialProofOptions - override type Value = Seq[TopicWithScore] - - override val keyConv: Conv[Key] = Conv.ofType - override val viewConv: Conv[View] = ScroogeConv.fromStruct[TopicSocialProofOptions] - override val valueConv: Conv[Value] = Conv.seq(ScroogeConv.fromStruct[TopicWithScore]) - override val metadata: OpMetadata = - OpMetadata( - lifecycle = Some(Production), - Some(PlainText("Topic Social Proof Batched Federated Column"))) - - case class TspsGroup(view: View) extends SeqGroup[Long, Fetch.Result[Value]] { - override protected def run(keys: Seq[Long]): Future[Seq[Try[Result[Seq[TopicWithScore]]]]] = { - val request = TopicSocialProofRequest( - userId = view.userId, - tweetIds = keys.toSet, - displayLocation = view.displayLocation, - topicListingSetting = view.topicListingSetting, - context = view.context, - bypassModes = view.bypassModes, - tags = view.tags - ) - - val response = topicSocialProofService - .topicSocialProofHandlerStoreStitch(request) - .map(_.socialProofs) - Stitch - 
.run(response).map(r => - keys.map(key => { - Try { - val v = r.get(key) - if (v.nonEmpty && v.get.nonEmpty) { - found(v.get) - } else { - missing - } - } - })) - } - } - - override def fetch(key: Key, view: View): Stitch[Result[Value]] = { - Stitch.call(key, TspsGroup(view)) - } -} - -object TopicSocialProofBatchColumn { - val Path = "topic-signals/tsp/topic-social-proof-batched" -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofColumn.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofColumn.docx new file mode 100644 index 000000000..e4e632582 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofColumn.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofColumn.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofColumn.scala deleted file mode 100644 index 10425eccb..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/columns/TopicSocialProofColumn.scala +++ /dev/null @@ -1,47 +0,0 @@ -package com.twitter.tsp.columns - -import com.twitter.stitch -import com.twitter.stitch.Stitch -import com.twitter.strato.catalog.OpMetadata -import com.twitter.strato.config._ -import com.twitter.strato.config.AllowAll -import com.twitter.strato.config.ContactInfo -import com.twitter.strato.config.Policy -import com.twitter.strato.data.Conv -import com.twitter.strato.data.Description.PlainText -import com.twitter.strato.data.Lifecycle.Production -import com.twitter.strato.fed.StratoFed -import com.twitter.strato.thrift.ScroogeConv -import com.twitter.tsp.thriftscala.TopicSocialProofRequest -import com.twitter.tsp.thriftscala.TopicSocialProofResponse -import com.twitter.tsp.service.TopicSocialProofService -import javax.inject.Inject - -class TopicSocialProofColumn @Inject() ( - topicSocialProofService: TopicSocialProofService) - extends StratoFed.Column(TopicSocialProofColumn.Path) - with StratoFed.Fetch.Stitch { - - override type Key = TopicSocialProofRequest - override type View = Unit - override type Value = TopicSocialProofResponse - - override val keyConv: Conv[Key] = ScroogeConv.fromStruct[TopicSocialProofRequest] - override val viewConv: Conv[View] = Conv.ofType - override val valueConv: Conv[Value] = ScroogeConv.fromStruct[TopicSocialProofResponse] - override val metadata: OpMetadata = - OpMetadata(lifecycle = Some(Production), Some(PlainText("Topic Social Proof Federated Column"))) - - override def fetch(key: Key, view: View): Stitch[Result[Value]] = { - topicSocialProofService - .topicSocialProofHandlerStoreStitch(key) - .map { result => found(result) } - .handle { - case stitch.NotFound => missing - } - } -} - -object TopicSocialProofColumn { - val Path = "topic-signals/tsp/topic-social-proof" -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/BUILD b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/BUILD deleted file mode 100644 index 7b5fda3b0..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - tags = [ - "bazel-compatible", - ], - dependencies = [ - "configapi/configapi-abdecider", - "configapi/configapi-core", - "content-recommender/thrift/src/main/thrift:thrift-scala", - "decider/src/main/scala", - "discovery-common/src/main/scala/com/twitter/discovery/common/configapi", - 
"featureswitches/featureswitches-core", - "finatra/inject/inject-core/src/main/scala", - "frigate/frigate-common:base", - "frigate/frigate-common:util", - "frigate/frigate-common/src/main/scala/com/twitter/frigate/common/candidate", - "interests-service/thrift/src/main/thrift:thrift-scala", - "src/scala/com/twitter/simclusters_v2/common", - "src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift-scala", - "stitch/stitch-storehaus", - "topic-social-proof/server/src/main/thrift:thrift-scala", - ], -) diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/BUILD.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/BUILD.docx new file mode 100644 index 000000000..0a31edade Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/DeciderConstants.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/DeciderConstants.docx new file mode 100644 index 000000000..750fb897b Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/DeciderConstants.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/DeciderConstants.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/DeciderConstants.scala deleted file mode 100644 index de025128d..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/DeciderConstants.scala +++ /dev/null @@ -1,19 +0,0 @@ -package com.twitter.tsp.common - -import com.twitter.servo.decider.DeciderKeyEnum - -object DeciderConstants { - val enableTopicSocialProofScore = "enable_topic_social_proof_score" - val enableHealthSignalsScoreDeciderKey = "enable_tweet_health_score" - val enableUserAgathaScoreDeciderKey = "enable_user_agatha_score" -} - -object DeciderKey extends DeciderKeyEnum { - - val enableHealthSignalsScoreDeciderKey: Value = Value( - DeciderConstants.enableHealthSignalsScoreDeciderKey - ) - val enableUserAgathaScoreDeciderKey: Value = Value( - DeciderConstants.enableUserAgathaScoreDeciderKey - ) -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/FeatureSwitchesBuilder.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/FeatureSwitchesBuilder.docx new file mode 100644 index 000000000..2e722a522 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/FeatureSwitchesBuilder.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/FeatureSwitchesBuilder.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/FeatureSwitchesBuilder.scala deleted file mode 100644 index a3b269cba..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/FeatureSwitchesBuilder.scala +++ /dev/null @@ -1,34 +0,0 @@ -package com.twitter.tsp.common - -import com.twitter.abdecider.LoggingABDecider -import com.twitter.featureswitches.v2.FeatureSwitches -import com.twitter.featureswitches.v2.builder.{FeatureSwitchesBuilder => FsBuilder} -import com.twitter.featureswitches.v2.experimentation.NullBucketImpressor -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.util.Duration - -case class FeatureSwitchesBuilder( - statsReceiver: StatsReceiver, - abDecider: LoggingABDecider, - featuresDirectory: String, - addServiceDetailsFromAurora: Boolean, - configRepoDirectory: String = "/usr/local/config", - fastRefresh: Boolean = false, - 
impressExperiments: Boolean = true) { - - def build(): FeatureSwitches = { - val featureSwitches = FsBuilder() - .abDecider(abDecider) - .statsReceiver(statsReceiver) - .configRepoAbsPath(configRepoDirectory) - .featuresDirectory(featuresDirectory) - .limitToReferencedExperiments(shouldLimit = true) - .experimentImpressionStatsEnabled(true) - - if (!impressExperiments) featureSwitches.experimentBucketImpressor(NullBucketImpressor) - if (addServiceDetailsFromAurora) featureSwitches.serviceDetailsFromAurora() - if (fastRefresh) featureSwitches.refreshPeriod(Duration.fromSeconds(10)) - - featureSwitches.build() - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/LoadShedder.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/LoadShedder.docx new file mode 100644 index 000000000..89c7e3bc6 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/LoadShedder.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/LoadShedder.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/LoadShedder.scala deleted file mode 100644 index 2071ea07e..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/LoadShedder.scala +++ /dev/null @@ -1,44 +0,0 @@ -package com.twitter.tsp.common - -import com.twitter.decider.Decider -import com.twitter.decider.RandomRecipient -import com.twitter.util.Future -import javax.inject.Inject -import scala.util.control.NoStackTrace - -/* - Provides decider-controlled load shedding for a given displayLocation - The format of the decider keys is: - - enable_loadshedding_<displayLocation> - E.g.: - enable_loadshedding_HomeTimeline - - Deciders are fractional, so a value of 50.00 will drop 50% of responses. If a decider key is not - defined for a particular displayLocation, those requests will always be served. - - We should therefore aim to define keys for the locations we care most about in decider.yml, - so that we can control them during incidents. 
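A minimal usage sketch for the LoadShedder described in this comment (and defined just below): the wrapper method name and the empty-response fallback are placeholder assumptions, loosely modeled on the rescue logic in TopicSocialProofHandler.toReadableStore later in this diff.

```scala
import com.twitter.tsp.common.LoadShedder
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
import com.twitter.util.Future

// Either serve the request or fail fast with LoadSheddingException, which the
// caller converts into an empty response.
def serveWithLoadShedding(
  loadShedder: LoadShedder,
  displayLocation: String
): Future[TopicSocialProofResponse] =
  loadShedder(displayLocation) {
    // placeholder for the real handler call
    Future.value(TopicSocialProofResponse(Map.empty))
  }.handle {
    case LoadShedder.LoadSheddingException => TopicSocialProofResponse(Map.empty)
  }
```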
- */ -class LoadShedder @Inject() (decider: Decider) { - import LoadShedder._ - - // Fall back to False for any undefined key - private val deciderWithFalseFallback: Decider = decider.orElse(Decider.False) - private val keyPrefix = "enable_loadshedding" - - def apply[T](typeString: String)(serve: => Future[T]): Future[T] = { - /* - Per-typeString level load shedding: enable_loadshedding_HomeTimeline - Checks if per-typeString load shedding is enabled - */ - val keyTyped = s"${keyPrefix}_$typeString" - if (deciderWithFalseFallback.isAvailable(keyTyped, recipient = Some(RandomRecipient))) - Future.exception(LoadSheddingException) - else serve - } -} - -object LoadShedder { - object LoadSheddingException extends Exception with NoStackTrace -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/ParamsBuilder.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/ParamsBuilder.docx new file mode 100644 index 000000000..f7b90aef0 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/ParamsBuilder.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/ParamsBuilder.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/ParamsBuilder.scala deleted file mode 100644 index 93fe9cbaf..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/ParamsBuilder.scala +++ /dev/null @@ -1,98 +0,0 @@ -package com.twitter.tsp.common - -import com.twitter.abdecider.LoggingABDecider -import com.twitter.abdecider.UserRecipient -import com.twitter.contentrecommender.thriftscala.DisplayLocation -import com.twitter.discovery.common.configapi.FeatureContextBuilder -import com.twitter.featureswitches.FSRecipient -import com.twitter.featureswitches.Recipient -import com.twitter.featureswitches.UserAgent -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.interests.thriftscala.TopicListingViewerContext -import com.twitter.timelines.configapi -import com.twitter.timelines.configapi.Params -import com.twitter.timelines.configapi.RequestContext -import com.twitter.timelines.configapi.abdecider.LoggingABDeciderExperimentContext - -case class ParamsBuilder( - featureContextBuilder: FeatureContextBuilder, - abDecider: LoggingABDecider, - overridesConfig: configapi.Config, - statsReceiver: StatsReceiver) { - - def buildFromTopicListingViewerContext( - topicListingViewerContext: Option[TopicListingViewerContext], - displayLocation: DisplayLocation, - userRoleOverride: Option[Set[String]] = None - ): Params = { - - topicListingViewerContext.flatMap(_.userId) match { - case Some(userId) => - val userRecipient = ParamsBuilder.toFeatureSwitchRecipientWithTopicContext( - userId, - userRoleOverride, - topicListingViewerContext, - Some(displayLocation) - ) - - overridesConfig( - requestContext = RequestContext( - userId = Some(userId), - experimentContext = LoggingABDeciderExperimentContext( - abDecider, - Some(UserRecipient(userId, Some(userId)))), - featureContext = featureContextBuilder( - Some(userId), - Some(userRecipient) - ) - ), - statsReceiver - ) - case _ => - throw new IllegalArgumentException( - s"${this.getClass.getSimpleName} tried to build Param for a request without a userId" - ) - } - } -} - -object ParamsBuilder { - - def toFeatureSwitchRecipientWithTopicContext( - userId: Long, - userRolesOverride: Option[Set[String]], - context: Option[TopicListingViewerContext], - displayLocationOpt: Option[DisplayLocation] - ): Recipient = { - val userRoles = 
userRolesOverride match { - case Some(overrides) => Some(overrides) - case _ => context.flatMap(_.userRoles.map(_.toSet)) - } - - val recipient = FSRecipient( - userId = Some(userId), - userRoles = userRoles, - deviceId = context.flatMap(_.deviceId), - guestId = context.flatMap(_.guestId), - languageCode = context.flatMap(_.languageCode), - countryCode = context.flatMap(_.countryCode), - userAgent = context.flatMap(_.userAgent).flatMap(UserAgent(_)), - isVerified = None, - isTwoffice = None, - tooClient = None, - highWaterMark = None - ) - displayLocationOpt match { - case Some(displayLocation) => - recipient.withCustomFields(displayLocationCustomFieldMap(displayLocation)) - case None => - recipient - } - } - - private val DisplayLocationCustomField = "display_location" - - def displayLocationCustomFieldMap(displayLocation: DisplayLocation): (String, String) = - DisplayLocationCustomField -> displayLocation.toString - -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/RecTargetFactory.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/RecTargetFactory.docx new file mode 100644 index 000000000..5c76dcc13 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/RecTargetFactory.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/RecTargetFactory.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/RecTargetFactory.scala deleted file mode 100644 index 26eeda736..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/RecTargetFactory.scala +++ /dev/null @@ -1,65 +0,0 @@ -package com.twitter.tsp.common - -import com.twitter.abdecider.LoggingABDecider -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.frigate.common.base.TargetUser -import com.twitter.frigate.common.candidate.TargetABDecider -import com.twitter.frigate.common.util.ABDeciderWithOverride -import com.twitter.gizmoduck.thriftscala.User -import com.twitter.simclusters_v2.common.UserId -import com.twitter.storehaus.ReadableStore -import com.twitter.timelines.configapi.Params -import com.twitter.tsp.thriftscala.TopicSocialProofRequest -import com.twitter.util.Future - -case class DefaultRecTopicSocialProofTarget( - topicSocialProofRequest: TopicSocialProofRequest, - targetId: UserId, - user: Option[User], - abDecider: ABDeciderWithOverride, - params: Params -)( - implicit statsReceiver: StatsReceiver) - extends TargetUser - with TopicSocialProofRecRequest - with TargetABDecider { - override def globalStats: StatsReceiver = statsReceiver - override val targetUser: Future[Option[User]] = Future.value(user) -} - -trait TopicSocialProofRecRequest { - tuc: TargetUser => - - val topicSocialProofRequest: TopicSocialProofRequest -} - -case class RecTargetFactory( - abDecider: LoggingABDecider, - userStore: ReadableStore[UserId, User], - paramBuilder: ParamsBuilder, - statsReceiver: StatsReceiver) { - - type RecTopicSocialProofTarget = DefaultRecTopicSocialProofTarget - - def buildRecTopicSocialProofTarget( - request: TopicSocialProofRequest - ): Future[RecTopicSocialProofTarget] = { - val userId = request.userId - userStore.get(userId).map { userOpt => - val userRoles = userOpt.flatMap(_.roles.map(_.roles.toSet)) - - val context = request.context.copy(userId = Some(request.userId)) // override to make sure - - val params = paramBuilder - .buildFromTopicListingViewerContext(Some(context), request.displayLocation, userRoles) - - DefaultRecTopicSocialProofTarget( - 
request, - userId, - userOpt, - ABDeciderWithOverride(abDecider, None)(statsReceiver), - params - )(statsReceiver) - } - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofDecider.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofDecider.docx new file mode 100644 index 000000000..5a103f4b1 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofDecider.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofDecider.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofDecider.scala deleted file mode 100644 index 39a4acb89..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofDecider.scala +++ /dev/null @@ -1,26 +0,0 @@ -package com.twitter.tsp -package common - -import com.twitter.decider.Decider -import com.twitter.decider.RandomRecipient -import com.twitter.decider.Recipient -import com.twitter.simclusters_v2.common.DeciderGateBuilderWithIdHashing -import javax.inject.Inject - -case class TopicSocialProofDecider @Inject() (decider: Decider) { - - def isAvailable(feature: String, recipient: Option[Recipient]): Boolean = { - decider.isAvailable(feature, recipient) - } - - lazy val deciderGateBuilder = new DeciderGateBuilderWithIdHashing(decider) - - /** - * When useRandomRecipient is set to false, the decider is either completely on or off. - * When useRandomRecipient is set to true, the decider is on for the specified % of traffic. - */ - def isAvailable(feature: String, useRandomRecipient: Boolean = true): Boolean = { - if (useRandomRecipient) isAvailable(feature, Some(RandomRecipient)) - else isAvailable(feature, None) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofParams.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofParams.docx new file mode 100644 index 000000000..be5b9b5e4 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofParams.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofParams.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofParams.scala deleted file mode 100644 index 4effe1313..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/common/TopicSocialProofParams.scala +++ /dev/null @@ -1,104 +0,0 @@ -package com.twitter.tsp.common - -import com.twitter.finagle.stats.NullStatsReceiver -import com.twitter.logging.Logger -import com.twitter.timelines.configapi.BaseConfig -import com.twitter.timelines.configapi.BaseConfigBuilder -import com.twitter.timelines.configapi.FSBoundedParam -import com.twitter.timelines.configapi.FSParam -import com.twitter.timelines.configapi.FeatureSwitchOverrideUtil - -object TopicSocialProofParams { - - object TopicTweetsSemanticCoreVersionId - extends FSBoundedParam[Long]( - name = "topic_tweets_semantic_core_annotation_version_id", - default = 1433487161551032320L, - min = 0L, - max = Long.MaxValue - ) - object TopicTweetsSemanticCoreVersionIdsSet - extends FSParam[Set[Long]]( - name = "topic_tweets_semantic_core_annotation_version_id_allowed_set", - default = Set(TopicTweetsSemanticCoreVersionId.default)) - - /** - * Controls the Topic Social Proof cosine similarity threshold for the Topic Tweets. 
- */ - object TweetToTopicCosineSimilarityThreshold - extends FSBoundedParam[Double]( - name = "topic_tweets_cosine_similarity_threshold_tsp", - default = 0.0, - min = 0.0, - max = 1.0 - ) - - object EnablePersonalizedContextTopics // master feature switch to enable backfill - extends FSParam[Boolean]( - name = "topic_tweets_personalized_contexts_enable_personalized_contexts", - default = false - ) - - object EnableYouMightLikeTopic - extends FSParam[Boolean]( - name = "topic_tweets_personalized_contexts_enable_you_might_like", - default = false - ) - - object EnableRecentEngagementsTopic - extends FSParam[Boolean]( - name = "topic_tweets_personalized_contexts_enable_recent_engagements", - default = false - ) - - object EnableTopicTweetHealthFilterPersonalizedContexts - extends FSParam[Boolean]( - name = "topic_tweets_personalized_contexts_health_switch", - default = true - ) - - object EnableTweetToTopicScoreRanking - extends FSParam[Boolean]( - name = "topic_tweets_enable_tweet_to_topic_score_ranking", - default = true - ) - -} - -object FeatureSwitchConfig { - private val enumFeatureSwitchOverrides = FeatureSwitchOverrideUtil - .getEnumFSOverrides( - NullStatsReceiver, - Logger(getClass), - ) - - private val intFeatureSwitchOverrides = FeatureSwitchOverrideUtil.getBoundedIntFSOverrides() - - private val longFeatureSwitchOverrides = FeatureSwitchOverrideUtil.getBoundedLongFSOverrides( - TopicSocialProofParams.TopicTweetsSemanticCoreVersionId - ) - - private val doubleFeatureSwitchOverrides = FeatureSwitchOverrideUtil.getBoundedDoubleFSOverrides( - TopicSocialProofParams.TweetToTopicCosineSimilarityThreshold, - ) - - private val longSetFeatureSwitchOverrides = FeatureSwitchOverrideUtil.getLongSetFSOverrides( - TopicSocialProofParams.TopicTweetsSemanticCoreVersionIdsSet, - ) - - private val booleanFeatureSwitchOverrides = FeatureSwitchOverrideUtil.getBooleanFSOverrides( - TopicSocialProofParams.EnablePersonalizedContextTopics, - TopicSocialProofParams.EnableYouMightLikeTopic, - TopicSocialProofParams.EnableRecentEngagementsTopic, - TopicSocialProofParams.EnableTopicTweetHealthFilterPersonalizedContexts, - TopicSocialProofParams.EnableTweetToTopicScoreRanking, - ) - val config: BaseConfig = BaseConfigBuilder() - .set(enumFeatureSwitchOverrides: _*) - .set(intFeatureSwitchOverrides: _*) - .set(longFeatureSwitchOverrides: _*) - .set(doubleFeatureSwitchOverrides: _*) - .set(longSetFeatureSwitchOverrides: _*) - .set(booleanFeatureSwitchOverrides: _*) - .build() -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/BUILD b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/BUILD deleted file mode 100644 index dc280e03d..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - tags = [ - "bazel-compatible", - ], - dependencies = [ - "src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift-scala", - "stitch/stitch-storehaus", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/common", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/stores", - "topic-social-proof/server/src/main/thrift:thrift-scala", - "topiclisting/topiclisting-core/src/main/scala/com/twitter/topiclisting", - ], -) diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/BUILD.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/BUILD.docx new file mode 100644 index 000000000..aa1cb9197 Binary files 
/dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/TopicSocialProofHandler.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/TopicSocialProofHandler.docx new file mode 100644 index 000000000..9c4bf6627 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/TopicSocialProofHandler.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/TopicSocialProofHandler.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/TopicSocialProofHandler.scala deleted file mode 100644 index 848ec1d72..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/TopicSocialProofHandler.scala +++ /dev/null @@ -1,587 +0,0 @@ -package com.twitter.tsp.handlers - -import com.twitter.conversions.DurationOps._ -import com.twitter.finagle.mux.ClientDiscardedRequestException -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.frigate.common.util.StatsUtil -import com.twitter.simclusters_v2.common.SemanticCoreEntityId -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.simclusters_v2.thriftscala.EmbeddingType -import com.twitter.simclusters_v2.thriftscala.ModelVersion -import com.twitter.strato.response.Err -import com.twitter.storehaus.ReadableStore -import com.twitter.timelines.configapi.Params -import com.twitter.topic_recos.common.Configs.ConsumerTopicEmbeddingType -import com.twitter.topic_recos.common.Configs.DefaultModelVersion -import com.twitter.topic_recos.common.Configs.ProducerTopicEmbeddingType -import com.twitter.topic_recos.common.Configs.TweetEmbeddingType -import com.twitter.topiclisting.TopicListingViewerContext -import com.twitter.topic_recos.common.LocaleUtil -import com.twitter.topiclisting.AnnotationRuleProvider -import com.twitter.tsp.common.DeciderConstants -import com.twitter.tsp.common.LoadShedder -import com.twitter.tsp.common.RecTargetFactory -import com.twitter.tsp.common.TopicSocialProofDecider -import com.twitter.tsp.common.TopicSocialProofParams -import com.twitter.tsp.stores.TopicSocialProofStore -import com.twitter.tsp.stores.TopicSocialProofStore.TopicSocialProof -import com.twitter.tsp.stores.UttTopicFilterStore -import com.twitter.tsp.stores.TopicTweetsCosineSimilarityAggregateStore.ScoreKey -import com.twitter.tsp.thriftscala.MetricTag -import com.twitter.tsp.thriftscala.TopicFollowType -import com.twitter.tsp.thriftscala.TopicListingSetting -import com.twitter.tsp.thriftscala.TopicSocialProofRequest -import com.twitter.tsp.thriftscala.TopicSocialProofResponse -import com.twitter.tsp.thriftscala.TopicWithScore -import com.twitter.tsp.thriftscala.TspTweetInfo -import com.twitter.tsp.utils.HealthSignalsUtils -import com.twitter.util.Future -import com.twitter.util.Timer -import com.twitter.util.Duration -import com.twitter.util.TimeoutException - -import scala.util.Random - -class TopicSocialProofHandler( - topicSocialProofStore: ReadableStore[TopicSocialProofStore.Query, Seq[TopicSocialProof]], - tweetInfoStore: ReadableStore[TweetId, TspTweetInfo], - uttTopicFilterStore: UttTopicFilterStore, - recTargetFactory: RecTargetFactory, - decider: TopicSocialProofDecider, - statsReceiver: StatsReceiver, - loadShedder: LoadShedder, - timer: Timer) { - - import TopicSocialProofHandler._ - - def getTopicSocialProofResponse( - request: TopicSocialProofRequest - ): Future[TopicSocialProofResponse] = 
{ - val scopedStats = statsReceiver.scope(request.displayLocation.toString) - scopedStats.counter("fanoutRequests").incr(request.tweetIds.size) - scopedStats.stat("numTweetsPerRequest").add(request.tweetIds.size) - StatsUtil.trackBlockStats(scopedStats) { - recTargetFactory - .buildRecTopicSocialProofTarget(request).flatMap { target => - val enableCosineSimilarityScoreCalculation = - decider.isAvailable(DeciderConstants.enableTopicSocialProofScore) - - val semanticCoreVersionId = - target.params(TopicSocialProofParams.TopicTweetsSemanticCoreVersionId) - - val semanticCoreVersionIdsSet = - target.params(TopicSocialProofParams.TopicTweetsSemanticCoreVersionIdsSet) - - val allowListWithTopicFollowTypeFut = uttTopicFilterStore - .getAllowListTopicsForUser( - request.userId, - request.topicListingSetting, - TopicListingViewerContext - .fromThrift(request.context).copy(languageCode = - LocaleUtil.getStandardLanguageCode(request.context.languageCode)), - request.bypassModes.map(_.toSet) - ).rescue { - case _ => - scopedStats.counter("uttTopicFilterStoreFailure").incr() - Future.value(Map.empty[SemanticCoreEntityId, Option[TopicFollowType]]) - } - - val tweetInfoMapFut: Future[Map[TweetId, Option[TspTweetInfo]]] = Future - .collect( - tweetInfoStore.multiGet(request.tweetIds.toSet) - ).raiseWithin(TweetInfoStoreTimeout)(timer).rescue { - case _: TimeoutException => - scopedStats.counter("tweetInfoStoreTimeout").incr() - Future.value(Map.empty[TweetId, Option[TspTweetInfo]]) - case _ => - scopedStats.counter("tweetInfoStoreFailure").incr() - Future.value(Map.empty[TweetId, Option[TspTweetInfo]]) - } - - val definedTweetInfoMapFut = - keepTweetsWithTweetInfoAndLanguage(tweetInfoMapFut, request.displayLocation.toString) - - Future - .join(definedTweetInfoMapFut, allowListWithTopicFollowTypeFut).map { - case (tweetInfoMap, allowListWithTopicFollowType) => - val tweetIdsToQuery = tweetInfoMap.keys.toSet - val topicProofQueries = - tweetIdsToQuery.map { tweetId => - TopicSocialProofStore.Query( - TopicSocialProofStore.CacheableQuery( - tweetId = tweetId, - tweetLanguage = LocaleUtil.getSupportedStandardLanguageCodeWithDefault( - tweetInfoMap.getOrElse(tweetId, None).flatMap { - _.language - }), - enableCosineSimilarityScoreCalculation = - enableCosineSimilarityScoreCalculation - ), - allowedSemanticCoreVersionIds = semanticCoreVersionIdsSet - ) - } - - val topicSocialProofsFut: Future[Map[TweetId, Seq[TopicSocialProof]]] = { - Future - .collect(topicSocialProofStore.multiGet(topicProofQueries)).map(_.map { - case (query, results) => - query.cacheableQuery.tweetId -> results.toSeq.flatten.filter( - _.semanticCoreVersionId == semanticCoreVersionId) - }) - }.raiseWithin(TopicSocialProofStoreTimeout)(timer).rescue { - case _: TimeoutException => - scopedStats.counter("topicSocialProofStoreTimeout").incr() - Future(Map.empty[TweetId, Seq[TopicSocialProof]]) - case _ => - scopedStats.counter("topicSocialProofStoreFailure").incr() - Future(Map.empty[TweetId, Seq[TopicSocialProof]]) - } - - val random = new Random(seed = request.userId.toInt) - - topicSocialProofsFut.map { topicSocialProofs => - val filteredTopicSocialProofs = filterByAllowedList( - topicSocialProofs, - request.topicListingSetting, - allowListWithTopicFollowType.keySet - ) - - val filteredTopicSocialProofsEmptyCount: Int = - filteredTopicSocialProofs.count { - case (_, topicSocialProofs: Seq[TopicSocialProof]) => - topicSocialProofs.isEmpty - } - - scopedStats - .counter("filteredTopicSocialProofsCount").incr(filteredTopicSocialProofs.size) - 
scopedStats - .counter("filteredTopicSocialProofsEmptyCount").incr( - filteredTopicSocialProofsEmptyCount) - - if (isCrTopicTweets(request)) { - val socialProofs = filteredTopicSocialProofs.mapValues(_.flatMap { topicProof => - val topicWithScores = buildTopicWithRandomScore( - topicProof, - allowListWithTopicFollowType, - random - ) - topicWithScores - }) - TopicSocialProofResponse(socialProofs) - } else { - val socialProofs = filteredTopicSocialProofs.mapValues(_.flatMap { topicProof => - getTopicProofScore( - topicProof = topicProof, - allowListWithTopicFollowType = allowListWithTopicFollowType, - params = target.params, - random = random, - statsReceiver = statsReceiver - ) - - }.sortBy(-_.score).take(MaxCandidates)) - - val personalizedContextSocialProofs = - if (target.params(TopicSocialProofParams.EnablePersonalizedContextTopics)) { - val personalizedContextEligibility = - checkPersonalizedContextsEligibility( - target.params, - allowListWithTopicFollowType) - val filteredTweets = - filterPersonalizedContexts(socialProofs, tweetInfoMap, target.params) - backfillPersonalizedContexts( - allowListWithTopicFollowType, - filteredTweets, - request.tags.getOrElse(Map.empty), - personalizedContextEligibility) - } else { - Map.empty[TweetId, Seq[TopicWithScore]] - } - - val mergedSocialProofs = socialProofs.map { - case (tweetId, proofs) => - ( - tweetId, - proofs - ++ personalizedContextSocialProofs.getOrElse(tweetId, Seq.empty)) - } - - // Note that we will NOT filter out tweets with no TSP in either case - TopicSocialProofResponse(mergedSocialProofs) - } - } - } - }.flatten.raiseWithin(Timeout)(timer).rescue { - case _: ClientDiscardedRequestException => - scopedStats.counter("ClientDiscardedRequestException").incr() - Future.value(DefaultResponse) - case err: Err if err.code == Err.Cancelled => - scopedStats.counter("CancelledErr").incr() - Future.value(DefaultResponse) - case _ => - scopedStats.counter("FailedRequests").incr() - Future.value(DefaultResponse) - } - } - } - - /** - * Fetch the Score for each Topic Social Proof - */ - private def getTopicProofScore( - topicProof: TopicSocialProof, - allowListWithTopicFollowType: Map[SemanticCoreEntityId, Option[TopicFollowType]], - params: Params, - random: Random, - statsReceiver: StatsReceiver - ): Option[TopicWithScore] = { - val scopedStats = statsReceiver.scope("getTopicProofScores") - val enableTweetToTopicScoreRanking = - params(TopicSocialProofParams.EnableTweetToTopicScoreRanking) - - val minTweetToTopicCosineSimilarityThreshold = - params(TopicSocialProofParams.TweetToTopicCosineSimilarityThreshold) - - val topicWithScore = - if (enableTweetToTopicScoreRanking) { - scopedStats.counter("enableTweetToTopicScoreRanking").incr() - buildTopicWithValidScore( - topicProof, - TweetEmbeddingType, - Some(ConsumerTopicEmbeddingType), - Some(ProducerTopicEmbeddingType), - allowListWithTopicFollowType, - DefaultModelVersion, - minTweetToTopicCosineSimilarityThreshold - ) - } else { - scopedStats.counter("buildTopicWithRandomScore").incr() - buildTopicWithRandomScore( - topicProof, - allowListWithTopicFollowType, - random - ) - } - topicWithScore - - } - - private[handlers] def isCrTopicTweets( - request: TopicSocialProofRequest - ): Boolean = { - // CrTopic (across a variety of DisplayLocations) is the only use case with TopicListingSetting.All - request.topicListingSetting == TopicListingSetting.All - } - - /** - * Consolidate logics relevant to whether only quality topics should be enabled for Implicit Follows - */ - - /*** - * Consolidate 
logics relevant to whether Personalized Contexts backfilling should be enabled - */ - private[handlers] def checkPersonalizedContextsEligibility( - params: Params, - allowListWithTopicFollowType: Map[SemanticCoreEntityId, Option[TopicFollowType]] - ): PersonalizedContextEligibility = { - val scopedStats = statsReceiver.scope("checkPersonalizedContextsEligibility") - val isRecentFavInAllowlist = allowListWithTopicFollowType - .contains(AnnotationRuleProvider.recentFavTopicId) - - val isRecentFavEligible = - isRecentFavInAllowlist && params(TopicSocialProofParams.EnableRecentEngagementsTopic) - if (isRecentFavEligible) - scopedStats.counter("isRecentFavEligible").incr() - - val isRecentRetweetInAllowlist = allowListWithTopicFollowType - .contains(AnnotationRuleProvider.recentRetweetTopicId) - - val isRecentRetweetEligible = - isRecentRetweetInAllowlist && params(TopicSocialProofParams.EnableRecentEngagementsTopic) - if (isRecentRetweetEligible) - scopedStats.counter("isRecentRetweetEligible").incr() - - val isYMLInAllowlist = allowListWithTopicFollowType - .contains(AnnotationRuleProvider.youMightLikeTopicId) - - val isYMLEligible = - isYMLInAllowlist && params(TopicSocialProofParams.EnableYouMightLikeTopic) - if (isYMLEligible) - scopedStats.counter("isYMLEligible").incr() - - PersonalizedContextEligibility(isRecentFavEligible, isRecentRetweetEligible, isYMLEligible) - } - - private[handlers] def filterPersonalizedContexts( - socialProofs: Map[TweetId, Seq[TopicWithScore]], - tweetInfoMap: Map[TweetId, Option[TspTweetInfo]], - params: Params - ): Map[TweetId, Seq[TopicWithScore]] = { - val filters: Seq[(Option[TspTweetInfo], Params) => Boolean] = Seq( - healthSignalsFilter, - tweetLanguageFilter - ) - applyFilters(socialProofs, tweetInfoMap, params, filters) - } - - /** * - * filter tweets with None tweetInfo and undefined language - */ - private def keepTweetsWithTweetInfoAndLanguage( - tweetInfoMapFut: Future[Map[TweetId, Option[TspTweetInfo]]], - displayLocation: String - ): Future[Map[TweetId, Option[TspTweetInfo]]] = { - val scopedStats = statsReceiver.scope(displayLocation) - tweetInfoMapFut.map { tweetInfoMap => - val filteredTweetInfoMap = tweetInfoMap.filter { - case (_, optTweetInfo: Option[TspTweetInfo]) => - if (optTweetInfo.isEmpty) { - scopedStats.counter("undefinedTweetInfoCount").incr() - } - - optTweetInfo.exists { tweetInfo: TspTweetInfo => - { - if (tweetInfo.language.isEmpty) { - scopedStats.counter("undefinedLanguageCount").incr() - } - tweetInfo.language.isDefined - } - } - - } - val undefinedTweetInfoOrLangCount = tweetInfoMap.size - filteredTweetInfoMap.size - scopedStats.counter("undefinedTweetInfoOrLangCount").incr(undefinedTweetInfoOrLangCount) - - scopedStats.counter("TweetInfoCount").incr(tweetInfoMap.size) - - filteredTweetInfoMap - } - } - - /*** - * filter tweets with NO evergreen topic social proofs by their health signal scores & tweet languages - * i.e., tweets that are possible to be converted into Personalized Context topic tweets - * TBD: whether we are going to apply filters to all topic tweet candidates - */ - private def applyFilters( - socialProofs: Map[TweetId, Seq[TopicWithScore]], - tweetInfoMap: Map[TweetId, Option[TspTweetInfo]], - params: Params, - filters: Seq[(Option[TspTweetInfo], Params) => Boolean] - ): Map[TweetId, Seq[TopicWithScore]] = { - socialProofs.collect { - case (tweetId, socialProofs) if socialProofs.nonEmpty || filters.forall { filter => - filter(tweetInfoMap.getOrElse(tweetId, None), params) - } => - tweetId -> socialProofs - 
} - } - - private def healthSignalsFilter( - tweetInfoOpt: Option[TspTweetInfo], - params: Params - ): Boolean = { - !params( - TopicSocialProofParams.EnableTopicTweetHealthFilterPersonalizedContexts) || HealthSignalsUtils - .isHealthyTweet(tweetInfoOpt) - } - - private def tweetLanguageFilter( - tweetInfoOpt: Option[TspTweetInfo], - params: Params - ): Boolean = { - PersonalizedContextTopicsAllowedLanguageSet - .contains(tweetInfoOpt.flatMap(_.language).getOrElse(LocaleUtil.DefaultLanguage)) - } - - private[handlers] def backfillPersonalizedContexts( - allowListWithTopicFollowType: Map[SemanticCoreEntityId, Option[TopicFollowType]], - socialProofs: Map[TweetId, Seq[TopicWithScore]], - metricTagsMap: scala.collection.Map[TweetId, scala.collection.Set[MetricTag]], - personalizedContextEligibility: PersonalizedContextEligibility - ): Map[TweetId, Seq[TopicWithScore]] = { - val scopedStats = statsReceiver.scope("backfillPersonalizedContexts") - socialProofs.map { - case (tweetId, topicWithScores) => - if (topicWithScores.nonEmpty) { - tweetId -> Seq.empty - } else { - val metricTagContainsTweetFav = metricTagsMap - .getOrElse(tweetId, Set.empty[MetricTag]).contains(MetricTag.TweetFavorite) - val backfillRecentFav = - personalizedContextEligibility.isRecentFavEligible && metricTagContainsTweetFav - if (metricTagContainsTweetFav) - scopedStats.counter("MetricTag.TweetFavorite").incr() - if (backfillRecentFav) - scopedStats.counter("backfillRecentFav").incr() - - val metricTagContainsRetweet = metricTagsMap - .getOrElse(tweetId, Set.empty[MetricTag]).contains(MetricTag.Retweet) - val backfillRecentRetweet = - personalizedContextEligibility.isRecentRetweetEligible && metricTagContainsRetweet - if (metricTagContainsRetweet) - scopedStats.counter("MetricTag.Retweet").incr() - if (backfillRecentRetweet) - scopedStats.counter("backfillRecentRetweet").incr() - - val metricTagContainsRecentSearches = metricTagsMap - .getOrElse(tweetId, Set.empty[MetricTag]).contains( - MetricTag.InterestsRankerRecentSearches) - - val backfillYML = personalizedContextEligibility.isYMLEligible - if (backfillYML) - scopedStats.counter("backfillYML").incr() - - tweetId -> buildBackfillTopics( - allowListWithTopicFollowType, - backfillRecentFav, - backfillRecentRetweet, - backfillYML) - } - } - } - - private def buildBackfillTopics( - allowListWithTopicFollowType: Map[SemanticCoreEntityId, Option[TopicFollowType]], - backfillRecentFav: Boolean, - backfillRecentRetweet: Boolean, - backfillYML: Boolean - ): Seq[TopicWithScore] = { - Seq( - if (backfillRecentFav) { - Some( - TopicWithScore( - topicId = AnnotationRuleProvider.recentFavTopicId, - score = 1.0, - topicFollowType = allowListWithTopicFollowType - .getOrElse(AnnotationRuleProvider.recentFavTopicId, None) - )) - } else { None }, - if (backfillRecentRetweet) { - Some( - TopicWithScore( - topicId = AnnotationRuleProvider.recentRetweetTopicId, - score = 1.0, - topicFollowType = allowListWithTopicFollowType - .getOrElse(AnnotationRuleProvider.recentRetweetTopicId, None) - )) - } else { None }, - if (backfillYML) { - Some( - TopicWithScore( - topicId = AnnotationRuleProvider.youMightLikeTopicId, - score = 1.0, - topicFollowType = allowListWithTopicFollowType - .getOrElse(AnnotationRuleProvider.youMightLikeTopicId, None) - )) - } else { None } - ).flatten - } - - def toReadableStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse] = { - new ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse] { - override def get(k: TopicSocialProofRequest): 
Future[Option[TopicSocialProofResponse]] = { - val displayLocation = k.displayLocation.toString - loadShedder(displayLocation) { - getTopicSocialProofResponse(k).map(Some(_)) - }.rescue { - case LoadShedder.LoadSheddingException => - statsReceiver.scope(displayLocation).counter("LoadSheddingException").incr() - Future.None - case _ => - statsReceiver.scope(displayLocation).counter("Exception").incr() - Future.None - } - } - } - } -} - -object TopicSocialProofHandler { - - private val MaxCandidates = 10 - // Currently we do hardcode for the language check of PersonalizedContexts Topics - private val PersonalizedContextTopicsAllowedLanguageSet: Set[String] = - Set("pt", "ko", "es", "ja", "tr", "id", "en", "hi", "ar", "fr", "ru") - - private val Timeout: Duration = 200.milliseconds - private val TopicSocialProofStoreTimeout: Duration = 40.milliseconds - private val TweetInfoStoreTimeout: Duration = 60.milliseconds - private val DefaultResponse: TopicSocialProofResponse = TopicSocialProofResponse(Map.empty) - - case class PersonalizedContextEligibility( - isRecentFavEligible: Boolean, - isRecentRetweetEligible: Boolean, - isYMLEligible: Boolean) - - /** - * Calculate the Topic Scores for each (tweet, topic), filter out topic proofs whose scores do not - * pass the minimum threshold - */ - private[handlers] def buildTopicWithValidScore( - topicProof: TopicSocialProof, - tweetEmbeddingType: EmbeddingType, - maybeConsumerEmbeddingType: Option[EmbeddingType], - maybeProducerEmbeddingType: Option[EmbeddingType], - allowListWithTopicFollowType: Map[SemanticCoreEntityId, Option[TopicFollowType]], - simClustersModelVersion: ModelVersion, - minTweetToTopicCosineSimilarityThreshold: Double - ): Option[TopicWithScore] = { - - val consumerScore = maybeConsumerEmbeddingType - .flatMap { consumerEmbeddingType => - topicProof.scores.get( - ScoreKey(consumerEmbeddingType, tweetEmbeddingType, simClustersModelVersion)) - }.getOrElse(0.0) - - val producerScore = maybeProducerEmbeddingType - .flatMap { producerEmbeddingType => - topicProof.scores.get( - ScoreKey(producerEmbeddingType, tweetEmbeddingType, simClustersModelVersion)) - }.getOrElse(0.0) - - val combinedScore = consumerScore + producerScore - if (combinedScore > minTweetToTopicCosineSimilarityThreshold || topicProof.ignoreSimClusterFiltering) { - Some( - TopicWithScore( - topicId = topicProof.topicId.entityId, - score = combinedScore, - topicFollowType = - allowListWithTopicFollowType.getOrElse(topicProof.topicId.entityId, None))) - } else { - None - } - } - - private[handlers] def buildTopicWithRandomScore( - topicSocialProof: TopicSocialProof, - allowListWithTopicFollowType: Map[SemanticCoreEntityId, Option[TopicFollowType]], - random: Random - ): Option[TopicWithScore] = { - - Some( - TopicWithScore( - topicId = topicSocialProof.topicId.entityId, - score = random.nextDouble(), - topicFollowType = - allowListWithTopicFollowType.getOrElse(topicSocialProof.topicId.entityId, None) - )) - } - - /** - * Filter all the non-qualified Topic Social Proof - */ - private[handlers] def filterByAllowedList( - topicProofs: Map[TweetId, Seq[TopicSocialProof]], - setting: TopicListingSetting, - allowList: Set[SemanticCoreEntityId] - ): Map[TweetId, Seq[TopicSocialProof]] = { - setting match { - case TopicListingSetting.All => - // Return all the topics - topicProofs - case _ => - topicProofs.mapValues( - _.filter(topicProof => allowList.contains(topicProof.topicId.entityId))) - } - } -} diff --git 
a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/UttChildrenWarmupHandler.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/UttChildrenWarmupHandler.docx new file mode 100644 index 000000000..b267c1e83 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/UttChildrenWarmupHandler.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/UttChildrenWarmupHandler.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/UttChildrenWarmupHandler.scala deleted file mode 100644 index b431685c8..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers/UttChildrenWarmupHandler.scala +++ /dev/null @@ -1,40 +0,0 @@ -package com.twitter.tsp.handlers - -import com.twitter.inject.utils.Handler -import com.twitter.topiclisting.FollowableTopicProductId -import com.twitter.topiclisting.ProductId -import com.twitter.topiclisting.TopicListingViewerContext -import com.twitter.topiclisting.utt.UttLocalization -import com.twitter.util.logging.Logging -import javax.inject.Inject -import javax.inject.Singleton - -/** * - * We configure Warmer to help warm up the cache hit rate under `CachedUttClient/get_utt_taxonomy/cache_hit_rate` - * In uttLocalization.getRecommendableTopics, we fetch all topics exist in UTT, and yet the process - * is in fact fetching the complete UTT tree struct (by calling getUttChildren recursively), which could take 1 sec - * Once we have the topics, we stored them in in-memory cache, and the cache hit rate is > 99% - * - */ -@Singleton -class UttChildrenWarmupHandler @Inject() (uttLocalization: UttLocalization) - extends Handler - with Logging { - - /** Executes the function of this handler. * */ - override def handle(): Unit = { - uttLocalization - .getRecommendableTopics( - productId = ProductId.Followable, - viewerContext = TopicListingViewerContext(languageCode = Some("en")), - enableInternationalTopics = true, - followableTopicProductId = FollowableTopicProductId.AllFollowable - ) - .onSuccess { result => - logger.info(s"successfully warmed up UttChildren. TopicId length = ${result.size}") - } - .onFailure { throwable => - logger.info(s"failed to warm up UttChildren. 
Throwable = ${throwable}") - } - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/BUILD b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/BUILD deleted file mode 100644 index d68c9ad23..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - tags = [ - "bazel-compatible", - ], - dependencies = [ - "3rdparty/jvm/com/twitter/bijection:scrooge", - "3rdparty/jvm/com/twitter/storehaus:memcache", - "escherbird/src/scala/com/twitter/escherbird/util/uttclient", - "escherbird/src/thrift/com/twitter/escherbird/utt:strato-columns-scala", - "finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/authentication", - "finatra-internal/mtls-thriftmux/src/main/scala", - "finatra/inject/inject-core/src/main/scala", - "finatra/inject/inject-thrift-client", - "frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/strato", - "hermit/hermit-core/src/main/scala/com/twitter/hermit/store/common", - "src/scala/com/twitter/storehaus_internal/memcache", - "src/scala/com/twitter/storehaus_internal/util", - "src/thrift/com/twitter/gizmoduck:thrift-scala", - "src/thrift/com/twitter/gizmoduck:user-thrift-scala", - "stitch/stitch-storehaus", - "stitch/stitch-tweetypie/src/main/scala", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/common", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/stores", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/utils", - "topic-social-proof/server/src/main/thrift:thrift-scala", - "topiclisting/common/src/main/scala/com/twitter/topiclisting/clients", - "topiclisting/topiclisting-utt/src/main/scala/com/twitter/topiclisting/utt", - ], -) diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/BUILD.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/BUILD.docx new file mode 100644 index 000000000..08668b5dd Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/GizmoduckUserModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/GizmoduckUserModule.docx new file mode 100644 index 000000000..f1cbb85bb Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/GizmoduckUserModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/GizmoduckUserModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/GizmoduckUserModule.scala deleted file mode 100644 index a700d9fef..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/GizmoduckUserModule.scala +++ /dev/null @@ -1,35 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Module -import com.twitter.finagle.ThriftMux -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.finagle.mtls.client.MtlsStackClient._ -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.finagle.thrift.ClientId -import com.twitter.finatra.mtls.thriftmux.modules.MtlsClient -import com.twitter.gizmoduck.thriftscala.UserService -import com.twitter.inject.Injector -import com.twitter.inject.thrift.modules.ThriftMethodBuilderClientModule - -object GizmoduckUserModule - extends ThriftMethodBuilderClientModule[ - UserService.ServicePerEndpoint, - 
UserService.MethodPerEndpoint - ] - with MtlsClient { - - override val label: String = "gizmoduck" - override val dest: String = "/s/gizmoduck/gizmoduck" - override val modules: Seq[Module] = Seq(TSPClientIdModule) - - override def configureThriftMuxClient( - injector: Injector, - client: ThriftMux.Client - ): ThriftMux.Client = { - super - .configureThriftMuxClient(injector, client) - .withMutualTls(injector.instance[ServiceIdentifier]) - .withClientId(injector.instance[ClientId]) - .withStatsReceiver(injector.instance[StatsReceiver].scope("giz")) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/RepresentationScorerStoreModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/RepresentationScorerStoreModule.docx new file mode 100644 index 000000000..f3f6aeedb Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/RepresentationScorerStoreModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/RepresentationScorerStoreModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/RepresentationScorerStoreModule.scala deleted file mode 100644 index 329276d8d..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/RepresentationScorerStoreModule.scala +++ /dev/null @@ -1,47 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Module -import com.google.inject.Provides -import com.google.inject.Singleton -import com.twitter.app.Flag -import com.twitter.bijection.scrooge.BinaryScalaCodec -import com.twitter.conversions.DurationOps._ -import com.twitter.finagle.memcached.{Client => MemClient} -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.hermit.store.common.ObservedMemcachedReadableStore -import com.twitter.inject.TwitterModule -import com.twitter.simclusters_v2.thriftscala.Score -import com.twitter.simclusters_v2.thriftscala.ScoreId -import com.twitter.storehaus.ReadableStore -import com.twitter.strato.client.{Client => StratoClient} -import com.twitter.tsp.stores.RepresentationScorerStore - -object RepresentationScorerStoreModule extends TwitterModule { - override def modules: Seq[Module] = Seq(UnifiedCacheClient) - - private val tspRepresentationScoringColumnPath: Flag[String] = flag[String]( - name = "tsp.representationScoringColumnPath", - default = "recommendations/representation_scorer/score", - help = "Strato column path for Representation Scorer Store" - ) - - @Provides - @Singleton - def providesRepresentationScorerStore( - statsReceiver: StatsReceiver, - stratoClient: StratoClient, - tspUnifiedCacheClient: MemClient - ): ReadableStore[ScoreId, Score] = { - val underlyingStore = - RepresentationScorerStore(stratoClient, tspRepresentationScoringColumnPath(), statsReceiver) - ObservedMemcachedReadableStore.fromCacheClient( - backingStore = underlyingStore, - cacheClient = tspUnifiedCacheClient, - ttl = 2.hours - )( - valueInjection = BinaryScalaCodec(Score), - statsReceiver = statsReceiver.scope("RepresentationScorerStore"), - keyToString = { k: ScoreId => s"rsx/$k" } - ) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TSPClientIdModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TSPClientIdModule.docx new file mode 100644 index 000000000..def356157 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TSPClientIdModule.docx differ diff --git 
a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TSPClientIdModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TSPClientIdModule.scala deleted file mode 100644 index d22ef500f..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TSPClientIdModule.scala +++ /dev/null @@ -1,14 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Provides -import com.twitter.finagle.thrift.ClientId -import com.twitter.inject.TwitterModule -import javax.inject.Singleton - -object TSPClientIdModule extends TwitterModule { - private val clientIdFlag = flag("thrift.clientId", "topic-social-proof.prod", "Thrift client id") - - @Provides - @Singleton - def providesClientId: ClientId = ClientId(clientIdFlag()) -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicListingModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicListingModule.docx new file mode 100644 index 000000000..608a419b2 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicListingModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicListingModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicListingModule.scala deleted file mode 100644 index 3f2768278..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicListingModule.scala +++ /dev/null @@ -1,17 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Provides -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.inject.TwitterModule -import com.twitter.topiclisting.TopicListing -import com.twitter.topiclisting.TopicListingBuilder -import javax.inject.Singleton - -object TopicListingModule extends TwitterModule { - - @Provides - @Singleton - def providesTopicListing(statsReceiver: StatsReceiver): TopicListing = { - new TopicListingBuilder(statsReceiver.scope(namespace = "TopicListingBuilder")).build - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicSocialProofStoreModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicSocialProofStoreModule.docx new file mode 100644 index 000000000..c96a56686 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicSocialProofStoreModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicSocialProofStoreModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicSocialProofStoreModule.scala deleted file mode 100644 index fe63b0e21..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicSocialProofStoreModule.scala +++ /dev/null @@ -1,68 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Module -import com.google.inject.Provides -import com.google.inject.Singleton -import com.twitter.conversions.DurationOps._ -import com.twitter.finagle.memcached.{Client => MemClient} -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.hermit.store.common.ObservedCachedReadableStore -import com.twitter.hermit.store.common.ObservedMemcachedReadableStore -import com.twitter.hermit.store.common.ObservedReadableStore -import com.twitter.inject.TwitterModule -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.simclusters_v2.thriftscala.Score -import com.twitter.simclusters_v2.thriftscala.ScoreId 
-import com.twitter.storehaus.ReadableStore -import com.twitter.strato.client.{Client => StratoClient} -import com.twitter.tsp.stores.SemanticCoreAnnotationStore -import com.twitter.tsp.stores.TopicSocialProofStore -import com.twitter.tsp.stores.TopicSocialProofStore.TopicSocialProof -import com.twitter.tsp.utils.LZ4Injection -import com.twitter.tsp.utils.SeqObjectInjection - -object TopicSocialProofStoreModule extends TwitterModule { - override def modules: Seq[Module] = Seq(UnifiedCacheClient) - - @Provides - @Singleton - def providesTopicSocialProofStore( - representationScorerStore: ReadableStore[ScoreId, Score], - statsReceiver: StatsReceiver, - stratoClient: StratoClient, - tspUnifiedCacheClient: MemClient, - ): ReadableStore[TopicSocialProofStore.Query, Seq[TopicSocialProof]] = { - val semanticCoreAnnotationStore: ReadableStore[TweetId, Seq[ - SemanticCoreAnnotationStore.TopicAnnotation - ]] = ObservedReadableStore( - SemanticCoreAnnotationStore(SemanticCoreAnnotationStore.getStratoStore(stratoClient)) - )(statsReceiver.scope("SemanticCoreAnnotationStore")) - - val underlyingStore = TopicSocialProofStore( - representationScorerStore, - semanticCoreAnnotationStore - )(statsReceiver.scope("TopicSocialProofStore")) - - val memcachedStore = ObservedMemcachedReadableStore.fromCacheClient( - backingStore = underlyingStore, - cacheClient = tspUnifiedCacheClient, - ttl = 15.minutes, - asyncUpdate = true - )( - valueInjection = LZ4Injection.compose(SeqObjectInjection[TopicSocialProof]()), - statsReceiver = statsReceiver.scope("memCachedTopicSocialProofStore"), - keyToString = { k: TopicSocialProofStore.Query => s"tsps/${k.cacheableQuery}" } - ) - - val inMemoryCachedStore = - ObservedCachedReadableStore.from[TopicSocialProofStore.Query, Seq[TopicSocialProof]]( - memcachedStore, - ttl = 10.minutes, - maxKeys = 16777215, // ~ avg 160B, < 3000MB - cacheName = "topic_social_proof_cache", - windowSize = 10000L - )(statsReceiver.scope("InMemoryCachedTopicSocialProofStore")) - - inMemoryCachedStore - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicTweetCosineSimilarityAggregateStoreModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicTweetCosineSimilarityAggregateStoreModule.docx new file mode 100644 index 000000000..5c05bcdce Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicTweetCosineSimilarityAggregateStoreModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicTweetCosineSimilarityAggregateStoreModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicTweetCosineSimilarityAggregateStoreModule.scala deleted file mode 100644 index ac15b3746..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TopicTweetCosineSimilarityAggregateStoreModule.scala +++ /dev/null @@ -1,26 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Provides -import com.google.inject.Singleton -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.inject.TwitterModule -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.simclusters_v2.thriftscala.Score -import com.twitter.simclusters_v2.thriftscala.ScoreId -import com.twitter.simclusters_v2.thriftscala.TopicId -import com.twitter.storehaus.ReadableStore -import com.twitter.tsp.stores.TopicTweetsCosineSimilarityAggregateStore -import com.twitter.tsp.stores.TopicTweetsCosineSimilarityAggregateStore.ScoreKey - 
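An illustrative, simplified sketch of the cache layering that providesTopicSocialProofStore builds above: a memcache tier (15-minute TTL, LZ4-compressed Seq values, asynchronous update) sits in front of the underlying TopicSocialProofStore, and an in-memory tier (10-minute TTL, ~16.7M keys) sits in front of memcache. The fragment below is a hypothetical stand-in that only shows the read-through lookup order; the hermit ObservedCachedReadableStore / ObservedMemcachedReadableStore wrappers used above additionally handle TTLs, LRU eviction, serialization and stats, and ReadThroughStore is not a name from this repo.

import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future
import scala.collection.concurrent.TrieMap

class ReadThroughStore[K, V](backing: ReadableStore[K, V]) extends ReadableStore[K, V] {
  // The real modules use a bounded LRU with TTL; an unbounded map keeps the sketch short.
  private[this] val local = TrieMap.empty[K, V]

  override def get(k: K): Future[Option[V]] =
    local.get(k) match {
      case hit @ Some(_) => Future.value(hit) // served from the in-memory tier
      case None =>
        // Miss: fall through to the slower tier and populate on the way back.
        backing.get(k).onSuccess(_.foreach(v => local.update(k, v)))
    }
}

Composing this shape twice (in-memory over memcache over the underlying scoring store) yields the hierarchy the module above returns from providesTopicSocialProofStore.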
-object TopicTweetCosineSimilarityAggregateStoreModule extends TwitterModule { - - @Provides - @Singleton - def providesTopicTweetCosineSimilarityAggregateStore( - representationScorerStore: ReadableStore[ScoreId, Score], - statsReceiver: StatsReceiver, - ): ReadableStore[(TopicId, TweetId, Seq[ScoreKey]), Map[ScoreKey, Double]] = { - TopicTweetsCosineSimilarityAggregateStore(representationScorerStore)( - statsReceiver.scope("topicTweetsCosineSimilarityAggregateStore")) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetInfoStoreModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetInfoStoreModule.docx new file mode 100644 index 000000000..6d6682f0e Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetInfoStoreModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetInfoStoreModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetInfoStoreModule.scala deleted file mode 100644 index 1e08a9209..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetInfoStoreModule.scala +++ /dev/null @@ -1,130 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Module -import com.google.inject.Provides -import com.google.inject.Singleton -import com.twitter.bijection.scrooge.BinaryScalaCodec -import com.twitter.conversions.DurationOps._ -import com.twitter.finagle.memcached.{Client => MemClient} -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.frigate.common.store.health.TweetHealthModelStore -import com.twitter.frigate.common.store.health.TweetHealthModelStore.TweetHealthModelStoreConfig -import com.twitter.frigate.common.store.health.UserHealthModelStore -import com.twitter.frigate.common.store.interests.UserId -import com.twitter.frigate.thriftscala.TweetHealthScores -import com.twitter.frigate.thriftscala.UserAgathaScores -import com.twitter.hermit.store.common.DeciderableReadableStore -import com.twitter.hermit.store.common.ObservedCachedReadableStore -import com.twitter.hermit.store.common.ObservedMemcachedReadableStore -import com.twitter.inject.TwitterModule -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.stitch.tweetypie.TweetyPie -import com.twitter.storehaus.ReadableStore -import com.twitter.strato.client.{Client => StratoClient} -import com.twitter.tsp.common.DeciderKey -import com.twitter.tsp.common.TopicSocialProofDecider -import com.twitter.tsp.stores.TweetInfoStore -import com.twitter.tsp.stores.TweetyPieFieldsStore -import com.twitter.tweetypie.thriftscala.TweetService -import com.twitter.tsp.thriftscala.TspTweetInfo -import com.twitter.util.JavaTimer -import com.twitter.util.Timer - -object TweetInfoStoreModule extends TwitterModule { - override def modules: Seq[Module] = Seq(UnifiedCacheClient) - implicit val timer: Timer = new JavaTimer(true) - - @Provides - @Singleton - def providesTweetInfoStore( - decider: TopicSocialProofDecider, - serviceIdentifier: ServiceIdentifier, - statsReceiver: StatsReceiver, - stratoClient: StratoClient, - tspUnifiedCacheClient: MemClient, - tweetyPieService: TweetService.MethodPerEndpoint - ): ReadableStore[TweetId, TspTweetInfo] = { - val tweetHealthModelStore: ReadableStore[TweetId, TweetHealthScores] = { - val underlyingStore = TweetHealthModelStore.buildReadableStore( - stratoClient, - Some( - 
TweetHealthModelStoreConfig( - enablePBlock = true, - enableToxicity = true, - enablePSpammy = true, - enablePReported = true, - enableSpammyTweetContent = true, - enablePNegMultimodal = false)) - )(statsReceiver.scope("UnderlyingTweetHealthModelStore")) - - DeciderableReadableStore( - ObservedMemcachedReadableStore.fromCacheClient( - backingStore = underlyingStore, - cacheClient = tspUnifiedCacheClient, - ttl = 2.hours - )( - valueInjection = BinaryScalaCodec(TweetHealthScores), - statsReceiver = statsReceiver.scope("TweetHealthModelStore"), - keyToString = { k: TweetId => s"tHMS/$k" } - ), - decider.deciderGateBuilder.idGate(DeciderKey.enableHealthSignalsScoreDeciderKey), - statsReceiver.scope("TweetHealthModelStore") - ) - } - - val userHealthModelStore: ReadableStore[UserId, UserAgathaScores] = { - val underlyingStore = - UserHealthModelStore.buildReadableStore(stratoClient)( - statsReceiver.scope("UnderlyingUserHealthModelStore")) - - DeciderableReadableStore( - ObservedMemcachedReadableStore.fromCacheClient( - backingStore = underlyingStore, - cacheClient = tspUnifiedCacheClient, - ttl = 18.hours - )( - valueInjection = BinaryScalaCodec(UserAgathaScores), - statsReceiver = statsReceiver.scope("UserHealthModelStore"), - keyToString = { k: UserId => s"uHMS/$k" } - ), - decider.deciderGateBuilder.idGate(DeciderKey.enableUserAgathaScoreDeciderKey), - statsReceiver.scope("UserHealthModelStore") - ) - } - - val tweetInfoStore: ReadableStore[TweetId, TspTweetInfo] = { - val underlyingStore = TweetInfoStore( - TweetyPieFieldsStore.getStoreFromTweetyPie(TweetyPie(tweetyPieService, statsReceiver)), - tweetHealthModelStore: ReadableStore[TweetId, TweetHealthScores], - userHealthModelStore: ReadableStore[UserId, UserAgathaScores], - timer: Timer - )(statsReceiver.scope("tweetInfoStore")) - - val memcachedStore = ObservedMemcachedReadableStore.fromCacheClient( - backingStore = underlyingStore, - cacheClient = tspUnifiedCacheClient, - ttl = 15.minutes, - // Hydrating tweetInfo is now a required step for all candidates, - // hence we needed to tune these thresholds. - asyncUpdate = serviceIdentifier.environment == "prod" - )( - valueInjection = BinaryScalaCodec(TspTweetInfo), - statsReceiver = statsReceiver.scope("memCachedTweetInfoStore"), - keyToString = { k: TweetId => s"tIS/$k" } - ) - - val inMemoryStore = ObservedCachedReadableStore.from( - memcachedStore, - ttl = 15.minutes, - maxKeys = 8388607, // Check TweetInfo definition. size~92b. 
Around 736 MB - windowSize = 10000L, - cacheName = "tweet_info_cache", - maxMultiGetSize = 20 - )(statsReceiver.scope("inMemoryCachedTweetInfoStore")) - - inMemoryStore - } - tweetInfoStore - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetyPieClientModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetyPieClientModule.docx new file mode 100644 index 000000000..950e474c8 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetyPieClientModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetyPieClientModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetyPieClientModule.scala deleted file mode 100644 index 98d515dda..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/TweetyPieClientModule.scala +++ /dev/null @@ -1,63 +0,0 @@ -package com.twitter.tsp -package modules - -import com.google.inject.Module -import com.google.inject.Provides -import com.twitter.conversions.DurationOps.richDurationFromInt -import com.twitter.finagle.ThriftMux -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.finagle.mtls.client.MtlsStackClient.MtlsThriftMuxClientSyntax -import com.twitter.finagle.mux.ClientDiscardedRequestException -import com.twitter.finagle.service.ReqRep -import com.twitter.finagle.service.ResponseClass -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.finagle.thrift.ClientId -import com.twitter.inject.Injector -import com.twitter.inject.thrift.modules.ThriftMethodBuilderClientModule -import com.twitter.tweetypie.thriftscala.TweetService -import com.twitter.util.Duration -import com.twitter.util.Throw -import com.twitter.stitch.tweetypie.{TweetyPie => STweetyPie} -import com.twitter.finatra.mtls.thriftmux.modules.MtlsClient -import javax.inject.Singleton - -object TweetyPieClientModule - extends ThriftMethodBuilderClientModule[ - TweetService.ServicePerEndpoint, - TweetService.MethodPerEndpoint - ] - with MtlsClient { - override val label = "tweetypie" - override val dest = "/s/tweetypie/tweetypie" - override val requestTimeout: Duration = 450.milliseconds - - override val modules: Seq[Module] = Seq(TSPClientIdModule) - - // We bump the success rate from the default of 0.8 to 0.9 since we're dropping the - // consecutive failures part of the default policy. 
- override def configureThriftMuxClient( - injector: Injector, - client: ThriftMux.Client - ): ThriftMux.Client = - super - .configureThriftMuxClient(injector, client) - .withMutualTls(injector.instance[ServiceIdentifier]) - .withStatsReceiver(injector.instance[StatsReceiver].scope("clnt")) - .withClientId(injector.instance[ClientId]) - .withResponseClassifier { - case ReqRep(_, Throw(_: ClientDiscardedRequestException)) => ResponseClass.Ignorable - } - .withSessionQualifier - .successRateFailureAccrual(successRate = 0.9, window = 30.seconds) - .withResponseClassifier { - case ReqRep(_, Throw(_: ClientDiscardedRequestException)) => ResponseClass.Ignorable - } - - @Provides - @Singleton - def providesTweetyPie( - tweetyPieService: TweetService.MethodPerEndpoint - ): STweetyPie = { - STweetyPie(tweetyPieService) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UnifiedCacheClient.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UnifiedCacheClient.docx new file mode 100644 index 000000000..e1a73caea Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UnifiedCacheClient.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UnifiedCacheClient.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UnifiedCacheClient.scala deleted file mode 100644 index 8fe65fc73..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UnifiedCacheClient.scala +++ /dev/null @@ -1,33 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Provides -import com.google.inject.Singleton -import com.twitter.app.Flag -import com.twitter.finagle.memcached.Client -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.inject.TwitterModule -import com.twitter.storehaus_internal.memcache.MemcacheStore -import com.twitter.storehaus_internal.util.ClientName -import com.twitter.storehaus_internal.util.ZkEndPoint - -object UnifiedCacheClient extends TwitterModule { - val tspUnifiedCacheDest: Flag[String] = flag[String]( - name = "tsp.unifiedCacheDest", - default = "/srv#/prod/local/cache/topic_social_proof_unified", - help = "Wily path to topic social proof unified cache" - ) - - @Provides - @Singleton - def provideUnifiedCacheClient( - serviceIdentifier: ServiceIdentifier, - statsReceiver: StatsReceiver, - ): Client = - MemcacheStore.memcachedClient( - name = ClientName("topic-social-proof-unified-memcache"), - dest = ZkEndPoint(tspUnifiedCacheDest()), - statsReceiver = statsReceiver.scope("cache_client"), - serviceIdentifier = serviceIdentifier - ) -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttClientModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttClientModule.docx new file mode 100644 index 000000000..96f404147 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttClientModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttClientModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttClientModule.scala deleted file mode 100644 index ae0099b8b..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttClientModule.scala +++ /dev/null @@ -1,41 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Provides -import 
com.twitter.escherbird.util.uttclient.CacheConfigV2 -import com.twitter.escherbird.util.uttclient.CachedUttClientV2 -import com.twitter.escherbird.util.uttclient.UttClientCacheConfigsV2 -import com.twitter.escherbird.utt.strato.thriftscala.Environment -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.inject.TwitterModule -import com.twitter.strato.client.Client -import com.twitter.topiclisting.clients.utt.UttClient -import javax.inject.Singleton - -object UttClientModule extends TwitterModule { - - @Provides - @Singleton - def providesUttClient( - stratoClient: Client, - statsReceiver: StatsReceiver - ): UttClient = { - - // Save 2 ^ 18 UTTs. Promising 100% cache rate - lazy val defaultCacheConfigV2: CacheConfigV2 = CacheConfigV2(262143) - lazy val uttClientCacheConfigsV2: UttClientCacheConfigsV2 = UttClientCacheConfigsV2( - getTaxonomyConfig = defaultCacheConfigV2, - getUttTaxonomyConfig = defaultCacheConfigV2, - getLeafIds = defaultCacheConfigV2, - getLeafUttEntities = defaultCacheConfigV2 - ) - - // CachedUttClient to use StratoClient - lazy val cachedUttClientV2: CachedUttClientV2 = new CachedUttClientV2( - stratoClient = stratoClient, - env = Environment.Prod, - cacheConfigs = uttClientCacheConfigsV2, - statsReceiver = statsReceiver.scope("CachedUttClient") - ) - new UttClient(cachedUttClientV2, statsReceiver) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttLocalizationModule.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttLocalizationModule.docx new file mode 100644 index 000000000..4beac4fe6 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttLocalizationModule.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttLocalizationModule.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttLocalizationModule.scala deleted file mode 100644 index 7d8844b98..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/modules/UttLocalizationModule.scala +++ /dev/null @@ -1,27 +0,0 @@ -package com.twitter.tsp.modules - -import com.google.inject.Provides -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.inject.TwitterModule -import com.twitter.topiclisting.TopicListing -import com.twitter.topiclisting.clients.utt.UttClient -import com.twitter.topiclisting.utt.UttLocalization -import com.twitter.topiclisting.utt.UttLocalizationImpl -import javax.inject.Singleton - -object UttLocalizationModule extends TwitterModule { - - @Provides - @Singleton - def providesUttLocalization( - topicListing: TopicListing, - uttClient: UttClient, - statsReceiver: StatsReceiver - ): UttLocalization = { - new UttLocalizationImpl( - topicListing, - uttClient, - statsReceiver - ) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/BUILD b/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/BUILD deleted file mode 100644 index 372962922..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - tags = [ - "bazel-compatible", - ], - dependencies = [ - "3rdparty/jvm/javax/inject:javax.inject", - "abdecider/src/main/scala", - "content-recommender/thrift/src/main/thrift:thrift-scala", - "hermit/hermit-core/src/main/scala/com/twitter/hermit/store/common", - "hermit/hermit-core/src/main/scala/com/twitter/hermit/store/gizmoduck", - 
"src/scala/com/twitter/topic_recos/stores", - "src/thrift/com/twitter/gizmoduck:thrift-scala", - "src/thrift/com/twitter/gizmoduck:user-thrift-scala", - "src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift-scala", - "stitch/stitch-storehaus", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/common", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/handlers", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/modules", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/stores", - "topic-social-proof/server/src/main/thrift:thrift-scala", - ], -) diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/BUILD.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/BUILD.docx new file mode 100644 index 000000000..a4e3037e5 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/TopicSocialProofService.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/TopicSocialProofService.docx new file mode 100644 index 000000000..0f3befad0 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/TopicSocialProofService.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/TopicSocialProofService.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/TopicSocialProofService.scala deleted file mode 100644 index f123e819f..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/service/TopicSocialProofService.scala +++ /dev/null @@ -1,182 +0,0 @@ -package com.twitter.tsp.service - -import com.twitter.abdecider.ABDeciderFactory -import com.twitter.abdecider.LoggingABDecider -import com.twitter.tsp.thriftscala.TspTweetInfo -import com.twitter.discovery.common.configapi.FeatureContextBuilder -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.gizmoduck.thriftscala.LookupContext -import com.twitter.gizmoduck.thriftscala.QueryFields -import com.twitter.gizmoduck.thriftscala.User -import com.twitter.gizmoduck.thriftscala.UserService -import com.twitter.hermit.store.gizmoduck.GizmoduckUserStore -import com.twitter.logging.Logger -import com.twitter.simclusters_v2.common.SemanticCoreEntityId -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.simclusters_v2.common.UserId -import com.twitter.spam.rtf.thriftscala.SafetyLevel -import com.twitter.stitch.storehaus.StitchOfReadableStore -import com.twitter.storehaus.ReadableStore -import com.twitter.strato.client.{Client => StratoClient} -import com.twitter.timelines.configapi -import com.twitter.timelines.configapi.CompositeConfig -import com.twitter.tsp.common.FeatureSwitchConfig -import com.twitter.tsp.common.FeatureSwitchesBuilder -import com.twitter.tsp.common.LoadShedder -import com.twitter.tsp.common.ParamsBuilder -import com.twitter.tsp.common.RecTargetFactory -import com.twitter.tsp.common.TopicSocialProofDecider -import com.twitter.tsp.handlers.TopicSocialProofHandler -import com.twitter.tsp.stores.LocalizedUttRecommendableTopicsStore -import com.twitter.tsp.stores.LocalizedUttTopicNameRequest -import com.twitter.tsp.stores.TopicResponses -import com.twitter.tsp.stores.TopicSocialProofStore -import com.twitter.tsp.stores.TopicSocialProofStore.TopicSocialProof -import com.twitter.tsp.stores.TopicStore -import 
com.twitter.tsp.stores.UttTopicFilterStore -import com.twitter.tsp.thriftscala.TopicSocialProofRequest -import com.twitter.tsp.thriftscala.TopicSocialProofResponse -import com.twitter.util.JavaTimer -import com.twitter.util.Timer -import javax.inject.Inject -import javax.inject.Singleton -import com.twitter.topiclisting.TopicListing -import com.twitter.topiclisting.utt.UttLocalization - -@Singleton -class TopicSocialProofService @Inject() ( - topicSocialProofStore: ReadableStore[TopicSocialProofStore.Query, Seq[TopicSocialProof]], - tweetInfoStore: ReadableStore[TweetId, TspTweetInfo], - serviceIdentifier: ServiceIdentifier, - stratoClient: StratoClient, - gizmoduck: UserService.MethodPerEndpoint, - topicListing: TopicListing, - uttLocalization: UttLocalization, - decider: TopicSocialProofDecider, - loadShedder: LoadShedder, - stats: StatsReceiver) { - - import TopicSocialProofService._ - - private val statsReceiver = stats.scope("topic-social-proof-management") - - private val isProd: Boolean = serviceIdentifier.environment == "prod" - - private val optOutStratoStorePath: String = - if (isProd) "interests/optOutInterests" else "interests/staging/optOutInterests" - - private val notInterestedInStorePath: String = - if (isProd) "interests/notInterestedTopicsGetter" - else "interests/staging/notInterestedTopicsGetter" - - private val userOptOutTopicsStore: ReadableStore[UserId, TopicResponses] = - TopicStore.userOptOutTopicStore(stratoClient, optOutStratoStorePath)( - statsReceiver.scope("ints_interests_opt_out_store")) - private val explicitFollowingTopicsStore: ReadableStore[UserId, TopicResponses] = - TopicStore.explicitFollowingTopicStore(stratoClient)( - statsReceiver.scope("ints_explicit_following_interests_store")) - private val userNotInterestedInTopicsStore: ReadableStore[UserId, TopicResponses] = - TopicStore.notInterestedInTopicsStore(stratoClient, notInterestedInStorePath)( - statsReceiver.scope("ints_not_interested_in_store")) - - private lazy val localizedUttRecommendableTopicsStore: ReadableStore[ - LocalizedUttTopicNameRequest, - Set[ - SemanticCoreEntityId - ] - ] = new LocalizedUttRecommendableTopicsStore(uttLocalization) - - implicit val timer: Timer = new JavaTimer(true) - - private lazy val uttTopicFilterStore = new UttTopicFilterStore( - topicListing = topicListing, - userOptOutTopicsStore = userOptOutTopicsStore, - explicitFollowingTopicsStore = explicitFollowingTopicsStore, - notInterestedTopicsStore = userNotInterestedInTopicsStore, - localizedUttRecommendableTopicsStore = localizedUttRecommendableTopicsStore, - timer = timer, - stats = statsReceiver.scope("UttTopicFilterStore") - ) - - private lazy val scribeLogger: Option[Logger] = Some(Logger.get("client_event")) - - private lazy val abDecider: LoggingABDecider = - ABDeciderFactory( - abDeciderYmlPath = configRepoDirectory + "/abdecider/abdecider.yml", - scribeLogger = scribeLogger, - decider = None, - environment = Some("production"), - ).buildWithLogging() - - private val builder: FeatureSwitchesBuilder = FeatureSwitchesBuilder( - statsReceiver = statsReceiver.scope("featureswitches-v2"), - abDecider = abDecider, - featuresDirectory = "features/topic-social-proof/main", - configRepoDirectory = configRepoDirectory, - addServiceDetailsFromAurora = !serviceIdentifier.isLocal, - fastRefresh = !isProd - ) - - private lazy val overridesConfig: configapi.Config = { - new CompositeConfig( - Seq( - FeatureSwitchConfig.config - ) - ) - } - - private val featureContextBuilder: FeatureContextBuilder = 
FeatureContextBuilder(builder.build()) - - private val paramsBuilder: ParamsBuilder = ParamsBuilder( - featureContextBuilder, - abDecider, - overridesConfig, - statsReceiver.scope("params") - ) - - private val userStore: ReadableStore[UserId, User] = { - val queryFields: Set[QueryFields] = Set( - QueryFields.Profile, - QueryFields.Account, - QueryFields.Roles, - QueryFields.Discoverability, - QueryFields.Safety, - QueryFields.Takedowns - ) - val context: LookupContext = LookupContext(safetyLevel = Some(SafetyLevel.Recommendations)) - - GizmoduckUserStore( - client = gizmoduck, - queryFields = queryFields, - context = context, - statsReceiver = statsReceiver.scope("gizmoduck") - ) - } - - private val recTargetFactory: RecTargetFactory = RecTargetFactory( - abDecider, - userStore, - paramsBuilder, - statsReceiver - ) - - private val topicSocialProofHandler = - new TopicSocialProofHandler( - topicSocialProofStore, - tweetInfoStore, - uttTopicFilterStore, - recTargetFactory, - decider, - statsReceiver.scope("TopicSocialProofHandler"), - loadShedder, - timer) - - val topicSocialProofHandlerStoreStitch: TopicSocialProofRequest => com.twitter.stitch.Stitch[ - TopicSocialProofResponse - ] = StitchOfReadableStore(topicSocialProofHandler.toReadableStore) -} - -object TopicSocialProofService { - private val configRepoDirectory = "/usr/local/config" -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/BUILD b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/BUILD deleted file mode 100644 index a933b3782..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - tags = [ - "bazel-compatible", - ], - dependencies = [ - "3rdparty/jvm/com/twitter/storehaus:core", - "content-recommender/thrift/src/main/thrift:thrift-scala", - "escherbird/src/thrift/com/twitter/escherbird/topicannotation:topicannotation-thrift-scala", - "frigate/frigate-common:util", - "frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/health", - "frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/interests", - "frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/strato", - "hermit/hermit-core/src/main/scala/com/twitter/hermit/store/common", - "mediaservices/commons/src/main/thrift:thrift-scala", - "src/scala/com/twitter/simclusters_v2/common", - "src/scala/com/twitter/simclusters_v2/score", - "src/scala/com/twitter/topic_recos/common", - "src/scala/com/twitter/topic_recos/stores", - "src/thrift/com/twitter/frigate:frigate-common-thrift-scala", - "src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift-scala", - "src/thrift/com/twitter/spam/rtf:safety-level-scala", - "src/thrift/com/twitter/tweetypie:service-scala", - "src/thrift/com/twitter/tweetypie:tweet-scala", - "stitch/stitch-storehaus", - "stitch/stitch-tweetypie/src/main/scala", - "strato/src/main/scala/com/twitter/strato/client", - "topic-social-proof/server/src/main/scala/com/twitter/tsp/utils", - "topic-social-proof/server/src/main/thrift:thrift-scala", - "topiclisting/topiclisting-core/src/main/scala/com/twitter/topiclisting", - ], -) diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/BUILD.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/BUILD.docx new file mode 100644 index 000000000..25c32c8fe Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/BUILD.docx differ 
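TopicSocialProofHandler.toReadableStore above wraps request handling in a ReadableStore that sheds load per display location and fails open: shed or failed requests come back as None (counted, not propagated), and TopicSocialProofService then exposes that store through StitchOfReadableStore as shown. Below is a minimal sketch of that pattern, assuming only storehaus and util-core on the classpath; SimpleLoadShedder and FailOpenStore are hypothetical names, and the real LoadShedder is decider-driven rather than random.

import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future

object SimpleLoadShedder {
  case object LoadSheddingException extends Exception
}

// Sheds a fixed fraction of requests per scope; stands in for the decider-backed LoadShedder.
class SimpleLoadShedder(shedFraction: Double) {
  def apply[T](scope: String)(f: => Future[T]): Future[T] =
    if (scala.util.Random.nextDouble() < shedFraction)
      Future.exception(SimpleLoadShedder.LoadSheddingException)
    else f
}

class FailOpenStore[Req, Resp](
  loadShedder: SimpleLoadShedder,
  process: Req => Future[Resp])
    extends ReadableStore[Req, Resp] {

  // Shed or failed requests degrade to None instead of surfacing the error;
  // the real handler also bumps a per-display-location counter in each case.
  override def get(request: Req): Future[Option[Resp]] =
    loadShedder(request.toString) { process(request).map(Some(_)) }
      .rescue { case _ => Future.None }
}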
diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/LocalizedUttRecommendableTopicsStore.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/LocalizedUttRecommendableTopicsStore.docx new file mode 100644 index 000000000..97c008c90 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/LocalizedUttRecommendableTopicsStore.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/LocalizedUttRecommendableTopicsStore.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/LocalizedUttRecommendableTopicsStore.scala deleted file mode 100644 index bcac9d5f6..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/LocalizedUttRecommendableTopicsStore.scala +++ /dev/null @@ -1,30 +0,0 @@ -package com.twitter.tsp.stores - -import com.twitter.storehaus.ReadableStore -import com.twitter.topiclisting.FollowableTopicProductId -import com.twitter.topiclisting.ProductId -import com.twitter.topiclisting.SemanticCoreEntityId -import com.twitter.topiclisting.TopicListingViewerContext -import com.twitter.topiclisting.utt.UttLocalization -import com.twitter.util.Future - -case class LocalizedUttTopicNameRequest( - productId: ProductId.Value, - viewerContext: TopicListingViewerContext, - enableInternationalTopics: Boolean) - -class LocalizedUttRecommendableTopicsStore(uttLocalization: UttLocalization) - extends ReadableStore[LocalizedUttTopicNameRequest, Set[SemanticCoreEntityId]] { - - override def get( - request: LocalizedUttTopicNameRequest - ): Future[Option[Set[SemanticCoreEntityId]]] = { - uttLocalization - .getRecommendableTopics( - productId = request.productId, - viewerContext = request.viewerContext, - enableInternationalTopics = request.enableInternationalTopics, - followableTopicProductId = FollowableTopicProductId.AllFollowable - ).map { response => Some(response) } - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/RepresentationScorerStore.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/RepresentationScorerStore.docx new file mode 100644 index 000000000..d802c601d Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/RepresentationScorerStore.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/RepresentationScorerStore.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/RepresentationScorerStore.scala deleted file mode 100644 index 7d5095ca6..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/RepresentationScorerStore.scala +++ /dev/null @@ -1,31 +0,0 @@ -package com.twitter.tsp.stores - -import com.twitter.contentrecommender.thriftscala.ScoringResponse -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.frigate.common.store.strato.StratoFetchableStore -import com.twitter.hermit.store.common.ObservedReadableStore -import com.twitter.simclusters_v2.thriftscala.Score -import com.twitter.simclusters_v2.thriftscala.ScoreId -import com.twitter.storehaus.ReadableStore -import com.twitter.strato.client.Client -import com.twitter.strato.thrift.ScroogeConvImplicits._ -import com.twitter.tsp.utils.ReadableStoreWithMapOptionValues - -object RepresentationScorerStore { - - def apply( - stratoClient: Client, - scoringColumnPath: String, - stats: StatsReceiver - ): ReadableStore[ScoreId, Score] = { - val stratoFetchableStore = StratoFetchableStore - 
.withUnitView[ScoreId, ScoringResponse](stratoClient, scoringColumnPath) - - val enrichedStore = new ReadableStoreWithMapOptionValues[ScoreId, ScoringResponse, Score]( - stratoFetchableStore).mapOptionValues(_.score) - - ObservedReadableStore( - enrichedStore - )(stats.scope("representation_scorer_store")) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/SemanticCoreAnnotationStore.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/SemanticCoreAnnotationStore.docx new file mode 100644 index 000000000..e12f27b31 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/SemanticCoreAnnotationStore.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/SemanticCoreAnnotationStore.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/SemanticCoreAnnotationStore.scala deleted file mode 100644 index cfeb7722b..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/SemanticCoreAnnotationStore.scala +++ /dev/null @@ -1,64 +0,0 @@ -package com.twitter.tsp.stores - -import com.twitter.escherbird.topicannotation.strato.thriftscala.TopicAnnotationValue -import com.twitter.escherbird.topicannotation.strato.thriftscala.TopicAnnotationView -import com.twitter.frigate.common.store.strato.StratoFetchableStore -import com.twitter.simclusters_v2.common.TopicId -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.storehaus.ReadableStore -import com.twitter.strato.client.Client -import com.twitter.strato.thrift.ScroogeConvImplicits._ -import com.twitter.util.Future - -/** - * This is copied from `src/scala/com/twitter/topic_recos/stores/SemanticCoreAnnotationStore.scala` - * Unfortunately their version assumes (incorrectly) that there is no View which causes warnings. - * While these warnings may not cause any problems in practice, better safe than sorry. - */ -object SemanticCoreAnnotationStore { - private val column = "semanticCore/topicannotation/topicAnnotation.Tweet" - - def getStratoStore(stratoClient: Client): ReadableStore[TweetId, TopicAnnotationValue] = { - StratoFetchableStore - .withView[TweetId, TopicAnnotationView, TopicAnnotationValue]( - stratoClient, - column, - TopicAnnotationView()) - } - - case class TopicAnnotation( - topicId: TopicId, - ignoreSimClustersFilter: Boolean, - modelVersionId: Long) -} - -/** - * Given a tweet Id, return the list of annotations defined by the TSIG team. 
- */ -case class SemanticCoreAnnotationStore(stratoStore: ReadableStore[TweetId, TopicAnnotationValue]) - extends ReadableStore[TweetId, Seq[SemanticCoreAnnotationStore.TopicAnnotation]] { - import SemanticCoreAnnotationStore._ - - override def multiGet[K1 <: TweetId]( - ks: Set[K1] - ): Map[K1, Future[Option[Seq[TopicAnnotation]]]] = { - stratoStore - .multiGet(ks) - .mapValues(_.map(_.map { topicAnnotationValue => - topicAnnotationValue.annotationsPerModel match { - case Some(annotationWithVersions) => - annotationWithVersions.flatMap { annotations => - annotations.annotations.map { annotation => - TopicAnnotation( - annotation.entityId, - annotation.ignoreQualityFilter.getOrElse(false), - annotations.modelVersionId - ) - } - } - case _ => - Nil - } - })) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicSocialProofStore.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicSocialProofStore.docx new file mode 100644 index 000000000..f2b1db94d Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicSocialProofStore.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicSocialProofStore.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicSocialProofStore.scala deleted file mode 100644 index 6ed71ca14..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicSocialProofStore.scala +++ /dev/null @@ -1,127 +0,0 @@ -package com.twitter.tsp.stores - -import com.twitter.tsp.stores.TopicTweetsCosineSimilarityAggregateStore.ScoreKey -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.frigate.common.util.StatsUtil -import com.twitter.simclusters_v2.thriftscala._ -import com.twitter.storehaus.ReadableStore -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.tsp.stores.SemanticCoreAnnotationStore._ -import com.twitter.tsp.stores.TopicSocialProofStore.TopicSocialProof -import com.twitter.util.Future - -/** - * Provides a session-less Topic Social Proof information which doesn't rely on any User Info. - * This store is used by MemCache and In-Memory cache to achieve a higher performance. - * One Consumer embedding and Producer embedding are used to calculate raw score. 
- */ -case class TopicSocialProofStore( - representationScorerStore: ReadableStore[ScoreId, Score], - semanticCoreAnnotationStore: ReadableStore[TweetId, Seq[TopicAnnotation]] -)( - statsReceiver: StatsReceiver) - extends ReadableStore[TopicSocialProofStore.Query, Seq[TopicSocialProof]] { - import TopicSocialProofStore._ - - // Fetches the tweet's topic annotations from SemanticCore's Annotation API - override def get(query: TopicSocialProofStore.Query): Future[Option[Seq[TopicSocialProof]]] = { - StatsUtil.trackOptionStats(statsReceiver) { - for { - annotations <- - StatsUtil.trackItemsStats(statsReceiver.scope("semanticCoreAnnotationStore")) { - semanticCoreAnnotationStore.get(query.cacheableQuery.tweetId).map(_.getOrElse(Nil)) - } - - filteredAnnotations = filterAnnotationsByAllowList(annotations, query) - - scoredTopics <- - StatsUtil.trackItemMapStats(statsReceiver.scope("scoreTopicTweetsTweetLanguage")) { - // de-dup identical topicIds - val uniqueTopicIds = filteredAnnotations.map { annotation => - TopicId(annotation.topicId, Some(query.cacheableQuery.tweetLanguage), country = None) - }.toSet - - if (query.cacheableQuery.enableCosineSimilarityScoreCalculation) { - scoreTopicTweets(query.cacheableQuery.tweetId, uniqueTopicIds) - } else { - Future.value(uniqueTopicIds.map(id => id -> Map.empty[ScoreKey, Double]).toMap) - } - } - - } yield { - if (scoredTopics.nonEmpty) { - val versionedTopicProofs = filteredAnnotations.map { annotation => - val topicId = - TopicId(annotation.topicId, Some(query.cacheableQuery.tweetLanguage), country = None) - - TopicSocialProof( - topicId, - scores = scoredTopics.getOrElse(topicId, Map.empty), - annotation.ignoreSimClustersFilter, - annotation.modelVersionId - ) - } - Some(versionedTopicProofs) - } else { - None - } - } - } - } - - /*** - * When the allowList is not empty (e.g., TSP handler call, CrTopic handler call), - * the filter will be enabled and we will only keep annotations that have versionIds existing - * in the input allowedSemanticCoreVersionIds set. - * But when the allowList is empty (e.g., some debugger calls), - * we will not filter anything and pass. 
- * We limit the number of versionIds to be K = MaxNumberVersionIds - */ - private def filterAnnotationsByAllowList( - annotations: Seq[TopicAnnotation], - query: TopicSocialProofStore.Query - ): Seq[TopicAnnotation] = { - - val trimmedVersionIds = query.allowedSemanticCoreVersionIds.take(MaxNumberVersionIds) - annotations.filter { annotation => - trimmedVersionIds.isEmpty || trimmedVersionIds.contains(annotation.modelVersionId) - } - } - - private def scoreTopicTweets( - tweetId: TweetId, - topicIds: Set[TopicId] - ): Future[Map[TopicId, Map[ScoreKey, Double]]] = { - Future.collect { - topicIds.map { topicId => - val scoresFut = TopicTweetsCosineSimilarityAggregateStore.getRawScoresMap( - topicId, - tweetId, - TopicTweetsCosineSimilarityAggregateStore.DefaultScoreKeys, - representationScorerStore - ) - topicId -> scoresFut - }.toMap - } - } -} - -object TopicSocialProofStore { - - private val MaxNumberVersionIds = 9 - - case class Query( - cacheableQuery: CacheableQuery, - allowedSemanticCoreVersionIds: Set[Long] = Set.empty) // overridden by FS - - case class CacheableQuery( - tweetId: TweetId, - tweetLanguage: String, - enableCosineSimilarityScoreCalculation: Boolean = true) - - case class TopicSocialProof( - topicId: TopicId, - scores: Map[ScoreKey, Double], - ignoreSimClusterFiltering: Boolean, - semanticCoreVersionId: Long) -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicStore.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicStore.docx new file mode 100644 index 000000000..373f12d8f Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicStore.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicStore.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicStore.scala deleted file mode 100644 index 61fae8c6a..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicStore.scala +++ /dev/null @@ -1,135 +0,0 @@ -package com.twitter.tsp.stores - -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.frigate.common.store.InterestedInInterestsFetchKey -import com.twitter.frigate.common.store.strato.StratoFetchableStore -import com.twitter.hermit.store.common.ObservedReadableStore -import com.twitter.interests.thriftscala.InterestId -import com.twitter.interests.thriftscala.InterestLabel -import com.twitter.interests.thriftscala.InterestRelationship -import com.twitter.interests.thriftscala.InterestRelationshipV1 -import com.twitter.interests.thriftscala.InterestedInInterestLookupContext -import com.twitter.interests.thriftscala.InterestedInInterestModel -import com.twitter.interests.thriftscala.OptOutInterestLookupContext -import com.twitter.interests.thriftscala.UserInterest -import com.twitter.interests.thriftscala.UserInterestData -import com.twitter.interests.thriftscala.UserInterestsResponse -import com.twitter.simclusters_v2.common.UserId -import com.twitter.storehaus.ReadableStore -import com.twitter.strato.client.Client -import com.twitter.strato.thrift.ScroogeConvImplicits._ - -case class TopicResponse( - entityId: Long, - interestedInData: Seq[InterestedInInterestModel], - scoreOverride: Option[Double] = None, - notInterestedInTimestamp: Option[Long] = None, - topicFollowTimestamp: Option[Long] = None) - -case class TopicResponses(responses: Seq[TopicResponse]) - -object TopicStore { - - private val InterestedInInterestsColumn = "interests/interestedInInterests" - private 
lazy val ExplicitInterestsContext: InterestedInInterestLookupContext = - InterestedInInterestLookupContext( - explicitContext = None, - inferredContext = None, - disableImplicit = Some(true) - ) - - private def userInterestsResponseToTopicResponse( - userInterestsResponse: UserInterestsResponse - ): TopicResponses = { - val responses = userInterestsResponse.interests.interests.toSeq.flatMap { userInterests => - userInterests.collect { - case UserInterest( - InterestId.SemanticCore(semanticCoreEntity), - Some(UserInterestData.InterestedIn(data))) => - val topicFollowingTimestampOpt = data.collect { - case InterestedInInterestModel.ExplicitModel( - InterestRelationship.V1(interestRelationshipV1)) => - interestRelationshipV1.timestampMs - }.lastOption - - TopicResponse(semanticCoreEntity.id, data, None, None, topicFollowingTimestampOpt) - } - } - TopicResponses(responses) - } - - def explicitFollowingTopicStore( - stratoClient: Client - )( - implicit statsReceiver: StatsReceiver - ): ReadableStore[UserId, TopicResponses] = { - val stratoStore = - StratoFetchableStore - .withUnitView[InterestedInInterestsFetchKey, UserInterestsResponse]( - stratoClient, - InterestedInInterestsColumn) - .composeKeyMapping[UserId](uid => - InterestedInInterestsFetchKey( - userId = uid, - labels = None, - lookupContext = Some(ExplicitInterestsContext) - )) - .mapValues(userInterestsResponseToTopicResponse) - - ObservedReadableStore(stratoStore) - } - - def userOptOutTopicStore( - stratoClient: Client, - optOutStratoStorePath: String - )( - implicit statsReceiver: StatsReceiver - ): ReadableStore[UserId, TopicResponses] = { - val stratoStore = - StratoFetchableStore - .withUnitView[ - (Long, Option[Seq[InterestLabel]], Option[OptOutInterestLookupContext]), - UserInterestsResponse](stratoClient, optOutStratoStorePath) - .composeKeyMapping[UserId](uid => (uid, None, None)) - .mapValues { userInterestsResponse => - val responses = userInterestsResponse.interests.interests.toSeq.flatMap { userInterests => - userInterests.collect { - case UserInterest( - InterestId.SemanticCore(semanticCoreEntity), - Some(UserInterestData.InterestedIn(data))) => - TopicResponse(semanticCoreEntity.id, data, None) - } - } - TopicResponses(responses) - } - ObservedReadableStore(stratoStore) - } - - def notInterestedInTopicsStore( - stratoClient: Client, - notInterestedInStorePath: String - )( - implicit statsReceiver: StatsReceiver - ): ReadableStore[UserId, TopicResponses] = { - val stratoStore = - StratoFetchableStore - .withUnitView[Long, Seq[UserInterest]](stratoClient, notInterestedInStorePath) - .composeKeyMapping[UserId](identity) - .mapValues { notInterestedInInterests => - val responses = notInterestedInInterests.collect { - case UserInterest( - InterestId.SemanticCore(semanticCoreEntity), - Some(UserInterestData.NotInterested(notInterestedInData))) => - val notInterestedInTimestampOpt = notInterestedInData.collect { - case InterestRelationship.V1(interestRelationshipV1: InterestRelationshipV1) => - interestRelationshipV1.timestampMs - }.lastOption - - TopicResponse(semanticCoreEntity.id, Seq.empty, None, notInterestedInTimestampOpt) - } - TopicResponses(responses) - } - ObservedReadableStore(stratoStore) - } - -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicTweetsCosineSimilarityAggregateStore.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicTweetsCosineSimilarityAggregateStore.docx new file mode 100644 index 000000000..5b42894c4 Binary files /dev/null and 
b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicTweetsCosineSimilarityAggregateStore.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicTweetsCosineSimilarityAggregateStore.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicTweetsCosineSimilarityAggregateStore.scala deleted file mode 100644 index 3fb65d8ac..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TopicTweetsCosineSimilarityAggregateStore.scala +++ /dev/null @@ -1,99 +0,0 @@ -package com.twitter.tsp.stores - -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.simclusters_v2.thriftscala.EmbeddingType -import com.twitter.simclusters_v2.thriftscala.InternalId -import com.twitter.simclusters_v2.thriftscala.ModelVersion -import com.twitter.simclusters_v2.thriftscala.ScoreInternalId -import com.twitter.simclusters_v2.thriftscala.ScoringAlgorithm -import com.twitter.simclusters_v2.thriftscala.SimClustersEmbeddingId -import com.twitter.simclusters_v2.thriftscala.{ - SimClustersEmbeddingPairScoreId => ThriftSimClustersEmbeddingPairScoreId -} -import com.twitter.simclusters_v2.thriftscala.TopicId -import com.twitter.simclusters_v2.thriftscala.{Score => ThriftScore} -import com.twitter.simclusters_v2.thriftscala.{ScoreId => ThriftScoreId} -import com.twitter.storehaus.ReadableStore -import com.twitter.topic_recos.common._ -import com.twitter.topic_recos.common.Configs.DefaultModelVersion -import com.twitter.tsp.stores.TopicTweetsCosineSimilarityAggregateStore.ScoreKey -import com.twitter.util.Future - -object TopicTweetsCosineSimilarityAggregateStore { - - val TopicEmbeddingTypes: Seq[EmbeddingType] = - Seq( - EmbeddingType.FavTfgTopic, - EmbeddingType.LogFavBasedKgoApeTopic - ) - - // Add the new embedding types if want to test the new Tweet embedding performance. 
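Illustrative sketch (not from the original sources): the score keys consumed by getRawScoresMap below are the cross product of topic-embedding types, tweet-embedding types and model versions, so adding one extra tweet embedding type for an experiment doubles the number of cosine-similarity lookups per (topic, tweet) pair. The names and values here are simplified stand-ins for the thrift enums, and the second tweet embedding type is made up:

  object ScoreKeyCrossProductSketch {
    case class ScoreKeySketch(topicEmbeddingType: String, tweetEmbeddingType: String, modelVersion: String)

    val topicEmbeddingTypes = Seq("FavTfgTopic", "LogFavBasedKgoApeTopic")
    val tweetEmbeddingTypes = Seq("LogFavBasedTweet", "HypotheticalNewTweetEmbedding") // second entry is a made-up experimental type
    val modelVersions = Seq("DefaultModelVersion")

    // Same shape as the for-comprehension that builds DefaultScoreKeys below.
    val scoreKeys: Seq[ScoreKeySketch] = for {
      modelVersion <- modelVersions
      topicEmbeddingType <- topicEmbeddingTypes
      tweetEmbeddingType <- tweetEmbeddingTypes
    } yield ScoreKeySketch(topicEmbeddingType, tweetEmbeddingType, modelVersion)

    // scoreKeys.size == 4: 2 topic embedding types x 2 tweet embedding types x 1 model version.
    // The default configuration, with a single tweet embedding type, yields 2 keys per pair.
  }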
- val TweetEmbeddingTypes: Seq[EmbeddingType] = Seq(EmbeddingType.LogFavBasedTweet) - - val ModelVersions: Seq[ModelVersion] = - Seq(DefaultModelVersion) - - val DefaultScoreKeys: Seq[ScoreKey] = { - for { - modelVersion <- ModelVersions - topicEmbeddingType <- TopicEmbeddingTypes - tweetEmbeddingType <- TweetEmbeddingTypes - } yield { - ScoreKey( - topicEmbeddingType = topicEmbeddingType, - tweetEmbeddingType = tweetEmbeddingType, - modelVersion = modelVersion - ) - } - } - - case class ScoreKey( - topicEmbeddingType: EmbeddingType, - tweetEmbeddingType: EmbeddingType, - modelVersion: ModelVersion) - - def getRawScoresMap( - topicId: TopicId, - tweetId: TweetId, - scoreKeys: Seq[ScoreKey], - representationScorerStore: ReadableStore[ThriftScoreId, ThriftScore] - ): Future[Map[ScoreKey, Double]] = { - val scoresMapFut = scoreKeys.map { key => - val scoreInternalId = ScoreInternalId.SimClustersEmbeddingPairScoreId( - ThriftSimClustersEmbeddingPairScoreId( - buildTopicEmbedding(topicId, key.topicEmbeddingType, key.modelVersion), - SimClustersEmbeddingId( - key.tweetEmbeddingType, - key.modelVersion, - InternalId.TweetId(tweetId)) - )) - val scoreFut = representationScorerStore - .get( - ThriftScoreId( - algorithm = ScoringAlgorithm.PairEmbeddingCosineSimilarity, // Hard code as cosine sim - internalId = scoreInternalId - )) - key -> scoreFut - }.toMap - - Future - .collect(scoresMapFut).map(_.collect { - case (key, Some(ThriftScore(score))) => - (key, score) - }) - } -} - -case class TopicTweetsCosineSimilarityAggregateStore( - representationScorerStore: ReadableStore[ThriftScoreId, ThriftScore] -)( - statsReceiver: StatsReceiver) - extends ReadableStore[(TopicId, TweetId, Seq[ScoreKey]), Map[ScoreKey, Double]] { - import TopicTweetsCosineSimilarityAggregateStore._ - - override def get(k: (TopicId, TweetId, Seq[ScoreKey])): Future[Option[Map[ScoreKey, Double]]] = { - statsReceiver.counter("topicTweetsCosineSimilariltyAggregateStore").incr() - getRawScoresMap(k._1, k._2, k._3, representationScorerStore).map(Some(_)) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TweetInfoStore.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TweetInfoStore.docx new file mode 100644 index 000000000..382abf471 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TweetInfoStore.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TweetInfoStore.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TweetInfoStore.scala deleted file mode 100644 index 70cc00451..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/TweetInfoStore.scala +++ /dev/null @@ -1,230 +0,0 @@ -package com.twitter.tsp.stores - -import com.twitter.conversions.DurationOps._ -import com.twitter.tsp.thriftscala.TspTweetInfo -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.frigate.thriftscala.TweetHealthScores -import com.twitter.frigate.thriftscala.UserAgathaScores -import com.twitter.logging.Logger -import com.twitter.mediaservices.commons.thriftscala.MediaCategory -import com.twitter.mediaservices.commons.tweetmedia.thriftscala.MediaInfo -import com.twitter.mediaservices.commons.tweetmedia.thriftscala.MediaSizeType -import com.twitter.simclusters_v2.common.TweetId -import com.twitter.simclusters_v2.common.UserId -import com.twitter.spam.rtf.thriftscala.SafetyLevel -import com.twitter.stitch.Stitch -import 
com.twitter.stitch.storehaus.ReadableStoreOfStitch -import com.twitter.stitch.tweetypie.TweetyPie -import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieException -import com.twitter.storehaus.ReadableStore -import com.twitter.topiclisting.AnnotationRuleProvider -import com.twitter.tsp.utils.HealthSignalsUtils -import com.twitter.tweetypie.thriftscala.TweetInclude -import com.twitter.tweetypie.thriftscala.{Tweet => TTweet} -import com.twitter.tweetypie.thriftscala._ -import com.twitter.util.Duration -import com.twitter.util.Future -import com.twitter.util.TimeoutException -import com.twitter.util.Timer - -object TweetyPieFieldsStore { - - // Tweet fields options. Only fields specified here will be hydrated in the tweet - private val CoreTweetFields: Set[TweetInclude] = Set[TweetInclude]( - TweetInclude.TweetFieldId(TTweet.IdField.id), - TweetInclude.TweetFieldId(TTweet.CoreDataField.id), // needed for the authorId - TweetInclude.TweetFieldId(TTweet.LanguageField.id), - TweetInclude.CountsFieldId(StatusCounts.FavoriteCountField.id), - TweetInclude.CountsFieldId(StatusCounts.RetweetCountField.id), - TweetInclude.TweetFieldId(TTweet.QuotedTweetField.id), - TweetInclude.TweetFieldId(TTweet.MediaKeysField.id), - TweetInclude.TweetFieldId(TTweet.EscherbirdEntityAnnotationsField.id), - TweetInclude.TweetFieldId(TTweet.MediaField.id), - TweetInclude.TweetFieldId(TTweet.UrlsField.id) - ) - - private val gtfo: GetTweetFieldsOptions = GetTweetFieldsOptions( - tweetIncludes = CoreTweetFields, - safetyLevel = Some(SafetyLevel.Recommendations) - ) - - def getStoreFromTweetyPie( - tweetyPie: TweetyPie, - convertExceptionsToNotFound: Boolean = true - ): ReadableStore[Long, GetTweetFieldsResult] = { - val log = Logger("TweetyPieFieldsStore") - - ReadableStoreOfStitch { tweetId: Long => - tweetyPie - .getTweetFields(tweetId, options = gtfo) - .rescue { - case ex: TweetyPieException if convertExceptionsToNotFound => - log.error(ex, s"Error while hitting tweetypie ${ex.result}") - Stitch.NotFound - } - } - } -} - -object TweetInfoStore { - - case class IsPassTweetHealthFilters(tweetStrictest: Option[Boolean]) - - case class IsPassAgathaHealthFilters(agathaStrictest: Option[Boolean]) - - private val HealthStoreTimeout: Duration = 40.milliseconds - private val isPassTweetHealthFilters: IsPassTweetHealthFilters = IsPassTweetHealthFilters(None) - private val isPassAgathaHealthFilters: IsPassAgathaHealthFilters = IsPassAgathaHealthFilters(None) -} - -case class TweetInfoStore( - tweetFieldsStore: ReadableStore[TweetId, GetTweetFieldsResult], - tweetHealthModelStore: ReadableStore[TweetId, TweetHealthScores], - userHealthModelStore: ReadableStore[UserId, UserAgathaScores], - timer: Timer -)( - statsReceiver: StatsReceiver) - extends ReadableStore[TweetId, TspTweetInfo] { - - import TweetInfoStore._ - - private[this] def toTweetInfo( - tweetFieldsResult: GetTweetFieldsResult - ): Future[Option[TspTweetInfo]] = { - tweetFieldsResult.tweetResult match { - case result: TweetFieldsResultState.Found if result.found.suppressReason.isEmpty => - val tweet = result.found.tweet - - val authorIdOpt = tweet.coreData.map(_.userId) - val favCountOpt = tweet.counts.flatMap(_.favoriteCount) - - val languageOpt = tweet.language.map(_.language) - val hasImageOpt = - tweet.mediaKeys.map(_.map(_.mediaCategory).exists(_ == MediaCategory.TweetImage)) - val hasGifOpt = - tweet.mediaKeys.map(_.map(_.mediaCategory).exists(_ == MediaCategory.TweetGif)) - val isNsfwAuthorOpt = Some( - tweet.coreData.exists(_.nsfwUser) || 
tweet.coreData.exists(_.nsfwAdmin)) - val isTweetReplyOpt = tweet.coreData.map(_.reply.isDefined) - val hasMultipleMediaOpt = - tweet.mediaKeys.map(_.map(_.mediaCategory).size > 1) - - val isKGODenylist = Some( - tweet.escherbirdEntityAnnotations - .exists(_.entityAnnotations.exists(AnnotationRuleProvider.isSuppressedTopicsDenylist))) - - val isNullcastOpt = tweet.coreData.map(_.nullcast) // These are Ads. go/nullcast - - val videoDurationOpt = tweet.media.flatMap(_.flatMap { - _.mediaInfo match { - case Some(MediaInfo.VideoInfo(info)) => - Some((info.durationMillis + 999) / 1000) // video playtime always rounds up - case _ => None - } - }.headOption) - - // There are many different types of videos. To be robust to new types being added, we just use - // the videoDurationOpt to keep track of whether the item has a video or not. - val hasVideo = videoDurationOpt.isDefined - - val mediaDimensionsOpt = - tweet.media.flatMap(_.headOption.flatMap( - _.sizes.find(_.sizeType == MediaSizeType.Orig).map(size => (size.width, size.height)))) - - val mediaWidth = mediaDimensionsOpt.map(_._1).getOrElse(1) - val mediaHeight = mediaDimensionsOpt.map(_._2).getOrElse(1) - // high-resolution media: both width and height are greater than 480px - val isHighMediaResolution = mediaHeight > 480 && mediaWidth > 480 - val isVerticalAspectRatio = mediaHeight >= mediaWidth && mediaWidth > 1 - val hasUrlOpt = tweet.urls.map(_.nonEmpty) - - (authorIdOpt, favCountOpt) match { - case (Some(authorId), Some(favCount)) => - hydrateHealthScores(tweet.id, authorId).map { - case (isPassAgathaHealthFilters, isPassTweetHealthFilters) => - Some( - TspTweetInfo( - authorId = authorId, - favCount = favCount, - language = languageOpt, - hasImage = hasImageOpt, - hasVideo = Some(hasVideo), - hasGif = hasGifOpt, - isNsfwAuthor = isNsfwAuthorOpt, - isKGODenylist = isKGODenylist, - isNullcast = isNullcastOpt, - videoDurationSeconds = videoDurationOpt, - isHighMediaResolution = Some(isHighMediaResolution), - isVerticalAspectRatio = Some(isVerticalAspectRatio), - isPassAgathaHealthFilterStrictest = isPassAgathaHealthFilters.agathaStrictest, - isPassTweetHealthFilterStrictest = isPassTweetHealthFilters.tweetStrictest, - isReply = isTweetReplyOpt, - hasMultipleMedia = hasMultipleMediaOpt, - hasUrl = hasUrlOpt - )) - } - case _ => - statsReceiver.counter("missingFields").incr() - Future.None // These values should always exist.
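Illustrative sketch (not from the original sources): the media-derived fields computed above reduce to a few small pure functions; the millisecond-to-second conversion deliberately rounds up, so a 12,345 ms clip is reported as 13 seconds. The helper names below are hypothetical and not part of TweetInfoStore:

  object TweetMediaSketch {
    // Ceiling division: partial seconds always round up, e.g. 12345 ms -> 13 s.
    def videoDurationSeconds(durationMillis: Long): Long = (durationMillis + 999) / 1000

    // High resolution requires both dimensions to exceed 480px, e.g. 1280x720 qualifies.
    def isHighMediaResolution(width: Int, height: Int): Boolean = width > 480 && height > 480

    // Vertical (or square) aspect ratio; the width > 1 check guards against the defaulted value of 1.
    def isVerticalAspectRatio(width: Int, height: Int): Boolean = height >= width && width > 1
  }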
- } - case _: TweetFieldsResultState.NotFound => - statsReceiver.counter("notFound").incr() - Future.None - case _: TweetFieldsResultState.Failed => - statsReceiver.counter("failed").incr() - Future.None - case _: TweetFieldsResultState.Filtered => - statsReceiver.counter("filtered").incr() - Future.None - case _ => - statsReceiver.counter("unknown").incr() - Future.None - } - } - - private[this] def hydrateHealthScores( - tweetId: TweetId, - authorId: Long - ): Future[(IsPassAgathaHealthFilters, IsPassTweetHealthFilters)] = { - Future - .join( - tweetHealthModelStore - .multiGet(Set(tweetId))(tweetId), - userHealthModelStore - .multiGet(Set(authorId))(authorId) - ).map { - case (tweetHealthScoresOpt, userAgathaScoresOpt) => - // This stats help us understand empty rate for AgathaCalibratedNsfw / NsfwTextUserScore - statsReceiver.counter("totalCountAgathaScore").incr() - if (userAgathaScoresOpt.getOrElse(UserAgathaScores()).agathaCalibratedNsfw.isEmpty) - statsReceiver.counter("emptyCountAgathaCalibratedNsfw").incr() - if (userAgathaScoresOpt.getOrElse(UserAgathaScores()).nsfwTextUserScore.isEmpty) - statsReceiver.counter("emptyCountNsfwTextUserScore").incr() - - val isPassAgathaHealthFilters = IsPassAgathaHealthFilters( - agathaStrictest = - Some(HealthSignalsUtils.isTweetAgathaModelQualified(userAgathaScoresOpt)), - ) - - val isPassTweetHealthFilters = IsPassTweetHealthFilters( - tweetStrictest = - Some(HealthSignalsUtils.isTweetHealthModelQualified(tweetHealthScoresOpt)) - ) - - (isPassAgathaHealthFilters, isPassTweetHealthFilters) - }.raiseWithin(HealthStoreTimeout)(timer).rescue { - case _: TimeoutException => - statsReceiver.counter("hydrateHealthScoreTimeout").incr() - Future.value((isPassAgathaHealthFilters, isPassTweetHealthFilters)) - case _ => - statsReceiver.counter("hydrateHealthScoreFailure").incr() - Future.value((isPassAgathaHealthFilters, isPassTweetHealthFilters)) - } - } - - override def multiGet[K1 <: TweetId](ks: Set[K1]): Map[K1, Future[Option[TspTweetInfo]]] = { - statsReceiver.counter("tweetFieldsStore").incr(ks.size) - tweetFieldsStore - .multiGet(ks).mapValues(_.flatMap { _.map { v => toTweetInfo(v) }.getOrElse(Future.None) }) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/UttTopicFilterStore.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/UttTopicFilterStore.docx new file mode 100644 index 000000000..7aa7c8caa Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/UttTopicFilterStore.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/UttTopicFilterStore.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/UttTopicFilterStore.scala deleted file mode 100644 index 89a502008..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/stores/UttTopicFilterStore.scala +++ /dev/null @@ -1,248 +0,0 @@ -package com.twitter.tsp.stores - -import com.twitter.conversions.DurationOps._ -import com.twitter.finagle.FailureFlags.flagsOf -import com.twitter.finagle.mux.ClientDiscardedRequestException -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.frigate.common.store.interests -import com.twitter.simclusters_v2.common.UserId -import com.twitter.storehaus.ReadableStore -import com.twitter.topiclisting.ProductId -import com.twitter.topiclisting.TopicListing -import com.twitter.topiclisting.TopicListingViewerContext -import com.twitter.topiclisting.{SemanticCoreEntityId => ScEntityId} -import 
com.twitter.tsp.thriftscala.TopicFollowType -import com.twitter.tsp.thriftscala.TopicListingSetting -import com.twitter.tsp.thriftscala.TopicSocialProofFilteringBypassMode -import com.twitter.util.Duration -import com.twitter.util.Future -import com.twitter.util.TimeoutException -import com.twitter.util.Timer - -class UttTopicFilterStore( - topicListing: TopicListing, - userOptOutTopicsStore: ReadableStore[interests.UserId, TopicResponses], - explicitFollowingTopicsStore: ReadableStore[interests.UserId, TopicResponses], - notInterestedTopicsStore: ReadableStore[interests.UserId, TopicResponses], - localizedUttRecommendableTopicsStore: ReadableStore[LocalizedUttTopicNameRequest, Set[Long]], - timer: Timer, - stats: StatsReceiver) { - import UttTopicFilterStore._ - - // Set of blacklisted SemanticCore IDs that are paused. - private[this] def getPausedTopics(topicCtx: TopicListingViewerContext): Set[ScEntityId] = { - topicListing.getPausedTopics(topicCtx) - } - - private[this] def getOptOutTopics(userId: Long): Future[Set[ScEntityId]] = { - stats.counter("getOptOutTopicsCount").incr() - userOptOutTopicsStore - .get(userId).map { responseOpt => - responseOpt - .map { responses => responses.responses.map(_.entityId) }.getOrElse(Seq.empty).toSet - }.raiseWithin(DefaultOptOutTimeout)(timer).rescue { - case err: TimeoutException => - stats.counter("getOptOutTopicsTimeout").incr() - Future.exception(err) - case err: ClientDiscardedRequestException - if flagsOf(err).contains("interrupted") && flagsOf(err) - .contains("ignorable") => - stats.counter("getOptOutTopicsDiscardedBackupRequest").incr() - Future.exception(err) - case err => - stats.counter("getOptOutTopicsFailure").incr() - Future.exception(err) - } - } - - private[this] def getNotInterestedIn(userId: Long): Future[Set[ScEntityId]] = { - stats.counter("getNotInterestedInCount").incr() - notInterestedTopicsStore - .get(userId).map { responseOpt => - responseOpt - .map { responses => responses.responses.map(_.entityId) }.getOrElse(Seq.empty).toSet - }.raiseWithin(DefaultNotInterestedInTimeout)(timer).rescue { - case err: TimeoutException => - stats.counter("getNotInterestedInTimeout").incr() - Future.exception(err) - case err: ClientDiscardedRequestException - if flagsOf(err).contains("interrupted") && flagsOf(err) - .contains("ignorable") => - stats.counter("getNotInterestedInDiscardedBackupRequest").incr() - Future.exception(err) - case err => - stats.counter("getNotInterestedInFailure").incr() - Future.exception(err) - } - } - - private[this] def getFollowedTopics(userId: Long): Future[Set[TopicResponse]] = { - stats.counter("getFollowedTopicsCount").incr() - - explicitFollowingTopicsStore - .get(userId).map { responseOpt => - responseOpt.map(_.responses.toSet).getOrElse(Set.empty) - }.raiseWithin(DefaultInterestedInTimeout)(timer).rescue { - case _: TimeoutException => - stats.counter("getFollowedTopicsTimeout").incr() - Future(Set.empty) - case _ => - stats.counter("getFollowedTopicsFailure").incr() - Future(Set.empty) - } - } - - private[this] def getFollowedTopicIds(userId: Long): Future[Set[ScEntityId]] = { - getFollowedTopics(userId: Long).map(_.map(_.entityId)) - } - - private[this] def getWhitelistTopicIds( - normalizedContext: TopicListingViewerContext, - enableInternationalTopics: Boolean - ): Future[Set[ScEntityId]] = { - stats.counter("getWhitelistTopicIdsCount").incr() - - val uttRequest = LocalizedUttTopicNameRequest( - productId = ProductId.Followable, - viewerContext = normalizedContext, - enableInternationalTopics = 
enableInternationalTopics - ) - localizedUttRecommendableTopicsStore - .get(uttRequest).map { response => - response.getOrElse(Set.empty) - }.rescue { - case _ => - stats.counter("getWhitelistTopicIdsFailure").incr() - Future(Set.empty) - } - } - - private[this] def getDenyListTopicIdsForUser( - userId: UserId, - topicListingSetting: TopicListingSetting, - context: TopicListingViewerContext, - bypassModes: Option[Set[TopicSocialProofFilteringBypassMode]] - ): Future[Set[ScEntityId]] = { - - val denyListTopicIdsFuture = topicListingSetting match { - case TopicListingSetting.ImplicitFollow => - getFollowedTopicIds(userId) - case _ => - Future(Set.empty[ScEntityId]) - } - - // we don't filter opt-out topics for the implicit-follow topic listing setting - val optOutTopicIdsFuture = topicListingSetting match { - case TopicListingSetting.ImplicitFollow => Future(Set.empty[ScEntityId]) - case _ => getOptOutTopics(userId) - } - - val notInterestedTopicIdsFuture = - if (bypassModes.exists(_.contains(TopicSocialProofFilteringBypassMode.NotInterested))) { - Future(Set.empty[ScEntityId]) - } else { - getNotInterestedIn(userId) - } - val pausedTopicIdsFuture = Future.value(getPausedTopics(context)) - - Future - .collect( - List( - denyListTopicIdsFuture, - optOutTopicIdsFuture, - notInterestedTopicIdsFuture, - pausedTopicIdsFuture)).map { list => list.reduce(_ ++ _) } - } - - private[this] def getDiff( - aFut: Future[Set[ScEntityId]], - bFut: Future[Set[ScEntityId]] - ): Future[Set[ScEntityId]] = { - Future.join(aFut, bFut).map { - case (a, b) => a.diff(b) - } - } - - /** - * Calculates the difference between the whitelisted IDs and the denylisted IDs, and returns the - * set of topic IDs we can recommend from, or the topics followed by the user, depending on the - * client setting. - */ - def getAllowListTopicsForUser( - userId: UserId, - topicListingSetting: TopicListingSetting, - context: TopicListingViewerContext, - bypassModes: Option[Set[TopicSocialProofFilteringBypassMode]] - ): Future[Map[ScEntityId, Option[TopicFollowType]]] = { - - /** - * Title: an illustrative table explaining how the allow list is composed - * AllowList = WhiteList - DenyList - OptOutTopics - PausedTopics - NotInterestedInTopics - * - * TopicListingSetting: Following ImplicitFollow All Followable - * Whitelist: FollowedTopics(user) AllWhitelistedTopics Nil AllWhitelistedTopics - * DenyList: Nil FollowedTopics(user) Nil Nil - * - * P.S. for TopicListingSetting.All, the returned allow list is Nil, because no allow list is - * required when the TopicListingSetting is 'All'. - * See TopicSocialProofHandler.filterByAllowedList() for more details. - */ - - topicListingSetting match { - // "All" means every UTT entity is qualified, so there is no need to fetch the whitelist.
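Illustrative sketch (not from the original sources) of the set algebra described in the table above, using made-up topic IDs: the allow list is simply the whitelist minus the union of the deny-style sets, which is what getDenyListTopicIdsForUser and getDiff compute asynchronously over Futures.

  object AllowListSketch {
    // AllowList = WhiteList - (DenyList ++ OptOutTopics ++ PausedTopics ++ NotInterestedInTopics)
    val whitelist = Set(1L, 2L, 3L, 4L, 5L) // hypothetical topic IDs
    val denyList = Set(2L) // e.g. already-followed topics under ImplicitFollow
    val optOutTopics = Set(3L)
    val pausedTopics = Set(4L)
    val notInterestedTopics = Set(5L)

    val allowList: Set[Long] = whitelist -- (denyList ++ optOutTopics ++ pausedTopics ++ notInterestedTopics)
    // allowList == Set(1L)
  }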
- case TopicListingSetting.All => Future.value(Map.empty) - case TopicListingSetting.Following => - getFollowingTopicsForUserWithTimestamp(userId, context, bypassModes).map { - _.mapValues(_ => Some(TopicFollowType.Following)) - } - case TopicListingSetting.ImplicitFollow => - getDiff( - getWhitelistTopicIds(context, enableInternationalTopics = true), - getDenyListTopicIdsForUser(userId, topicListingSetting, context, bypassModes)).map { - _.map { scEntityId => - scEntityId -> Some(TopicFollowType.ImplicitFollow) - }.toMap - } - case _ => - val followedTopicIdsFut = getFollowedTopicIds(userId) - val allowListTopicIdsFut = getDiff( - getWhitelistTopicIds(context, enableInternationalTopics = true), - getDenyListTopicIdsForUser(userId, topicListingSetting, context, bypassModes)) - Future.join(allowListTopicIdsFut, followedTopicIdsFut).map { - case (allowListTopicId, followedTopicIds) => - allowListTopicId.map { scEntityId => - if (followedTopicIds.contains(scEntityId)) - scEntityId -> Some(TopicFollowType.Following) - else scEntityId -> Some(TopicFollowType.ImplicitFollow) - }.toMap - } - } - } - - private[this] def getFollowingTopicsForUserWithTimestamp( - userId: UserId, - context: TopicListingViewerContext, - bypassModes: Option[Set[TopicSocialProofFilteringBypassMode]] - ): Future[Map[ScEntityId, Option[Long]]] = { - - val followedTopicIdToTimestampFut = getFollowedTopics(userId).map(_.map { followedTopic => - followedTopic.entityId -> followedTopic.topicFollowTimestamp - }.toMap) - - followedTopicIdToTimestampFut.flatMap { followedTopicIdToTimestamp => - getDiff( - Future(followedTopicIdToTimestamp.keySet), - getDenyListTopicIdsForUser(userId, TopicListingSetting.Following, context, bypassModes) - ).map { - _.map { scEntityId => - scEntityId -> followedTopicIdToTimestamp.get(scEntityId).flatten - }.toMap - } - } - } -} - -object UttTopicFilterStore { - val DefaultNotInterestedInTimeout: Duration = 60.milliseconds - val DefaultOptOutTimeout: Duration = 60.milliseconds - val DefaultInterestedInTimeout: Duration = 60.milliseconds -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/BUILD b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/BUILD deleted file mode 100644 index 3f4c6f42c..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - tags = [ - "bazel-compatible", - ], - dependencies = [ - "3rdparty/jvm/org/lz4:lz4-java", - "content-recommender/thrift/src/main/thrift:thrift-scala", - "frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store", - "frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/health", - "stitch/stitch-storehaus", - "topic-social-proof/server/src/main/thrift:thrift-scala", - ], -) diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/BUILD.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/BUILD.docx new file mode 100644 index 000000000..63c381ea7 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/LZ4Injection.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/LZ4Injection.docx new file mode 100644 index 000000000..23c73d552 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/LZ4Injection.docx differ diff --git 
a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/LZ4Injection.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/LZ4Injection.scala deleted file mode 100644 index c72b6032f..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/LZ4Injection.scala +++ /dev/null @@ -1,19 +0,0 @@ -package com.twitter.tsp.utils - -import com.twitter.bijection.Injection -import scala.util.Try -import net.jpountz.lz4.LZ4CompressorWithLength -import net.jpountz.lz4.LZ4DecompressorWithLength -import net.jpountz.lz4.LZ4Factory - -object LZ4Injection extends Injection[Array[Byte], Array[Byte]] { - private val lz4Factory = LZ4Factory.fastestInstance() - private val fastCompressor = new LZ4CompressorWithLength(lz4Factory.fastCompressor()) - private val decompressor = new LZ4DecompressorWithLength(lz4Factory.fastDecompressor()) - - override def apply(a: Array[Byte]): Array[Byte] = LZ4Injection.fastCompressor.compress(a) - - override def invert(b: Array[Byte]): Try[Array[Byte]] = Try { - LZ4Injection.decompressor.decompress(b) - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/ReadableStoreWithMapOptionValues.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/ReadableStoreWithMapOptionValues.docx new file mode 100644 index 000000000..ad3c08e5b Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/ReadableStoreWithMapOptionValues.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/ReadableStoreWithMapOptionValues.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/ReadableStoreWithMapOptionValues.scala deleted file mode 100644 index ddae5a310..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/ReadableStoreWithMapOptionValues.scala +++ /dev/null @@ -1,20 +0,0 @@ -package com.twitter.tsp.utils - -import com.twitter.storehaus.AbstractReadableStore -import com.twitter.storehaus.ReadableStore -import com.twitter.util.Future - -class ReadableStoreWithMapOptionValues[K, V1, V2](rs: ReadableStore[K, V1]) { - - def mapOptionValues( - fn: V1 => Option[V2] - ): ReadableStore[K, V2] = { - val self = rs - new AbstractReadableStore[K, V2] { - override def get(k: K): Future[Option[V2]] = self.get(k).map(_.flatMap(fn)) - - override def multiGet[K1 <: K](ks: Set[K1]): Map[K1, Future[Option[V2]]] = - self.multiGet(ks).mapValues(_.map(_.flatMap(fn))) - } - } -} diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/SeqObjectInjection.docx b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/SeqObjectInjection.docx new file mode 100644 index 000000000..a53f18be5 Binary files /dev/null and b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/SeqObjectInjection.docx differ diff --git a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/SeqObjectInjection.scala b/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/SeqObjectInjection.scala deleted file mode 100644 index 96a0740e4..000000000 --- a/topic-social-proof/server/src/main/scala/com/twitter/tsp/utils/SeqObjectInjection.scala +++ /dev/null @@ -1,32 +0,0 @@ -package com.twitter.tsp.utils - -import com.twitter.bijection.Injection -import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream -import java.io.ObjectInputStream -import java.io.ObjectOutputStream -import java.io.Serializable -import scala.util.Try - -/** - * @tparam T must be a serializable class - */ -case class 
SeqObjectInjection[T <: Serializable]() extends Injection[Seq[T], Array[Byte]] { - - override def apply(seq: Seq[T]): Array[Byte] = { - val byteStream = new ByteArrayOutputStream() - val outputStream = new ObjectOutputStream(byteStream) - outputStream.writeObject(seq) - outputStream.close() - byteStream.toByteArray - } - - override def invert(bytes: Array[Byte]): Try[Seq[T]] = { - Try { - val inputStream = new ObjectInputStream(new ByteArrayInputStream(bytes)) - val seq = inputStream.readObject().asInstanceOf[Seq[T]] - inputStream.close() - seq - } - } -} diff --git a/topic-social-proof/server/src/main/thrift/BUILD b/topic-social-proof/server/src/main/thrift/BUILD deleted file mode 100644 index 9bdbb71e0..000000000 --- a/topic-social-proof/server/src/main/thrift/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -create_thrift_libraries( - base_name = "thrift", - sources = ["*.thrift"], - platform = "java8", - tags = [ - "bazel-compatible", - ], - dependency_roots = [ - "content-recommender/thrift/src/main/thrift", - "content-recommender/thrift/src/main/thrift:content-recommender-common", - "interests-service/thrift/src/main/thrift", - "src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift", - ], - generate_languages = [ - "java", - "scala", - "strato", - ], - provides_java_name = "tsp-thrift-java", - provides_scala_name = "tsp-thrift-scala", -) diff --git a/topic-social-proof/server/src/main/thrift/BUILD.docx b/topic-social-proof/server/src/main/thrift/BUILD.docx new file mode 100644 index 000000000..f83af9cad Binary files /dev/null and b/topic-social-proof/server/src/main/thrift/BUILD.docx differ diff --git a/topic-social-proof/server/src/main/thrift/service.docx b/topic-social-proof/server/src/main/thrift/service.docx new file mode 100644 index 000000000..5ed9987c0 Binary files /dev/null and b/topic-social-proof/server/src/main/thrift/service.docx differ diff --git a/topic-social-proof/server/src/main/thrift/service.thrift b/topic-social-proof/server/src/main/thrift/service.thrift deleted file mode 100644 index 70f3c5398..000000000 --- a/topic-social-proof/server/src/main/thrift/service.thrift +++ /dev/null @@ -1,104 +0,0 @@ -namespace java com.twitter.tsp.thriftjava -namespace py gen.twitter.tsp -#@namespace scala com.twitter.tsp.thriftscala -#@namespace strato com.twitter.tsp.strato - -include "com/twitter/contentrecommender/common.thrift" -include "com/twitter/simclusters_v2/identifier.thrift" -include "com/twitter/simclusters_v2/online_store.thrift" -include "topic_listing.thrift" - -enum TopicListingSetting { - All = 0 // All the existing Semantic Core Entity/Topics. ie., All topics on twitter, and may or may not have been launched yet. - Followable = 1 // All the topics which the user is allowed to follow. ie., topics that have shipped, and user may or may not be following it. - Following = 2 // Only topics the user is explicitly following - ImplicitFollow = 3 // The topics user has not followed but implicitly may follow. ie., Only topics that user has not followed. 
-} (hasPersonalData='false') - - -// used to tell the Topic Social Proof endpoint which specific filtering can be bypassed -enum TopicSocialProofFilteringBypassMode { - NotInterested = 0 -} (hasPersonalData='false') - -struct TopicSocialProofRequest { - 1: required i64 userId(personalDataType = "UserId") - 2: required set<i64> tweetIds(personalDataType = 'TweetId') - 3: required common.DisplayLocation displayLocation - 4: required TopicListingSetting topicListingSetting - 5: required topic_listing.TopicListingViewerContext context - 6: optional set<TopicSocialProofFilteringBypassMode> bypassModes - 7: optional map<i64, set<MetricTag>> tags -} - -struct TopicSocialProofOptions { - 1: required i64 userId(personalDataType = "UserId") - 2: required common.DisplayLocation displayLocation - 3: required TopicListingSetting topicListingSetting - 4: required topic_listing.TopicListingViewerContext context - 5: optional set<TopicSocialProofFilteringBypassMode> bypassModes - 6: optional map<i64, set<MetricTag>> tags -} - -struct TopicSocialProofResponse { - 1: required map<i64, list<TopicWithScore>> socialProofs -}(hasPersonalData='false') - -// Distinguishes between how a topic tweet is generated. Useful for metric tracking and debugging -enum TopicTweetType { - // CrOON candidates - UserInterestedIn = 1 - Twistly = 2 - // crTopic candidates - SkitConsumerEmbeddings = 100 - SkitProducerEmbeddings = 101 - SkitHighPrecision = 102 - SkitInterestBrowser = 103 - Certo = 104 -}(persisted='true') - -struct TopicWithScore { - 1: required i64 topicId - 2: required double score // score used to rank topics relative to one another - 3: optional TopicTweetType algorithmType // how the topic is generated - 4: optional TopicFollowType topicFollowType // Whether the topic is being explicitly or implicitly followed -}(persisted='true', hasPersonalData='false') - - -struct ScoreKey { - 1: required identifier.EmbeddingType userEmbeddingType - 2: required identifier.EmbeddingType topicEmbeddingType - 3: required online_store.ModelVersion modelVersion -}(persisted='true', hasPersonalData='false') - -struct UserTopicScore { - 1: required map<ScoreKey, double> scores -}(persisted='true', hasPersonalData='false') - - -enum TopicFollowType { - Following = 1 - ImplicitFollow = 2 -}(persisted='true') - -// Provides the tags which carry the Recommended Tweets source signal and other context. -// Warning: Please don't use this tag in any ML Features or business logic.
-enum MetricTag { - // Source Signal Tags - TweetFavorite = 0 - Retweet = 1 - - UserFollow = 101 - PushOpenOrNtabClick = 201 - - HomeTweetClick = 301 - HomeVideoView = 302 - HomeSongbirdShowMore = 303 - - - InterestsRankerRecentSearches = 401 // For Interests Candidate Expansion - - UserInterestedIn = 501 - MBCG = 503 - // Other Metric Tags -} (persisted='true', hasPersonalData='true') diff --git a/topic-social-proof/server/src/main/thrift/tweet_info.docx b/topic-social-proof/server/src/main/thrift/tweet_info.docx new file mode 100644 index 000000000..ac91a9e17 Binary files /dev/null and b/topic-social-proof/server/src/main/thrift/tweet_info.docx differ diff --git a/topic-social-proof/server/src/main/thrift/tweet_info.thrift b/topic-social-proof/server/src/main/thrift/tweet_info.thrift deleted file mode 100644 index d32b1aeac..000000000 --- a/topic-social-proof/server/src/main/thrift/tweet_info.thrift +++ /dev/null @@ -1,26 +0,0 @@ -namespace java com.twitter.tsp.thriftjava -namespace py gen.twitter.tsp -#@namespace scala com.twitter.tsp.thriftscala -#@namespace strato com.twitter.tsp.strato - -struct TspTweetInfo { - 1: required i64 authorId - 2: required i64 favCount - 3: optional string language - 6: optional bool hasImage - 7: optional bool hasVideo - 8: optional bool hasGif - 9: optional bool isNsfwAuthor - 10: optional bool isKGODenylist - 11: optional bool isNullcast - // available if the tweet contains video - 12: optional i32 videoDurationSeconds - 13: optional bool isHighMediaResolution - 14: optional bool isVerticalAspectRatio - // health signal scores - 15: optional bool isPassAgathaHealthFilterStrictest - 16: optional bool isPassTweetHealthFilterStrictest - 17: optional bool isReply - 18: optional bool hasMultipleMedia - 23: optional bool hasUrl -}(persisted='false', hasPersonalData='true') diff --git a/trust_and_safety_models/README.docx b/trust_and_safety_models/README.docx new file mode 100644 index 000000000..bac436def Binary files /dev/null and b/trust_and_safety_models/README.docx differ diff --git a/trust_and_safety_models/README.md b/trust_and_safety_models/README.md deleted file mode 100644 index c16de2d3d..000000000 --- a/trust_and_safety_models/README.md +++ /dev/null @@ -1,10 +0,0 @@ -Trust and Safety Models -======================= - -We decided to open source the training code of the following models: -- pNSFWMedia: Model to detect tweets with NSFW images. This includes adult and porn content. -- pNSFWText: Model to detect tweets with NSFW text, adult/sexual topics. -- pToxicity: Model to detect toxic tweets. Toxicity includes marginal content like insults and certain types of harassment. Toxic content does not violate Twitter's terms of service. -- pAbuse: Model to detect abusive content. This includes violations of Twitter's terms of service, including hate speech, targeted harassment and abusive behavior. - -We have several more models and rules that we are not going to open source at this time because of the adversarial nature of this area. The team is considering open sourcing more models going forward and will keep the community posted accordingly. 
diff --git a/trust_and_safety_models/abusive/abusive_model.docx b/trust_and_safety_models/abusive/abusive_model.docx new file mode 100644 index 000000000..3e1b40d53 Binary files /dev/null and b/trust_and_safety_models/abusive/abusive_model.docx differ diff --git a/trust_and_safety_models/abusive/abusive_model.py b/trust_and_safety_models/abusive/abusive_model.py deleted file mode 100644 index 06fff4ed2..000000000 --- a/trust_and_safety_models/abusive/abusive_model.py +++ /dev/null @@ -1,276 +0,0 @@ -import tensorflow as tf - -physical_devices = tf.config.list_physical_devices('GPU') -for device in physical_devices: - tf.config.experimental.set_memory_growth(device, True) - -from twitter.hmli.nimbus.modeling.model_config import FeatureType, EncodingType, Feature, Model, LogType -from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader -from twitter.cuad.representation.models.text_encoder import TextEncoder -from twitter.cuad.representation.models.optimization import create_optimizer -from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder - -import numpy as np -import pandas as pd -import utils - -cat_names = [ -... -] - -category_features = [Feature(name=cat_name, ftype=FeatureType.CONTINUOUS) for cat_name in cat_names] -features = [ - Feature(name="tweet_text_with_media_annotations", ftype=FeatureType.STRING, encoding=EncodingType.BERT), - Feature(name="precision_nsfw", ftype=FeatureType.CONTINUOUS), - Feature(name="has_media", ftype=FeatureType.BINARY), - Feature(name="num_media", ftype=FeatureType.DISCRETE) -] + category_features - -ptos_prototype = Model( - name='ptos_prototype', - export_path="...", - features=features, -) -print(ptos_prototype) - -cq_loader = BigQueryFeatureLoader(gcp_project=COMPUTE_PROJECT) -labels = [ - "has_non_punitive_action", - "has_punitive_action", - "has_punitive_action_contains_self_harm", - "has_punitive_action_encourage_self_harm", - "has_punitive_action_episodic", - "has_punitive_action_episodic_hateful_conduct", - "has_punitive_action_other_abuse_policy", - "has_punitive_action_without_self_harm" -] - -train_query = f""" -SELECT - {{feature_names}}, - {",".join(labels)}, -... -""" -val_query = f""" -SELECT - {{feature_names}}, - {",".join(labels)}, -... 
-""" - -print(train_query) -train = cq_loader.load_features(ptos_prototype, "", "", custom_query=train_query) -val = cq_loader.load_features(ptos_prototype, "", "", custom_query=val_query) -print(train.describe(model=ptos_prototype)) - -params = { - 'max_seq_lengths': 128, - 'batch_size': 196, - 'lr': 1e-5, - 'optimizer_type': 'adamw', - 'warmup_steps': 0, - 'cls_dropout_rate': 0.1, - 'epochs': 30, - 'steps_per_epoch': 5000, - 'model_type': 'twitter_multilingual_bert_base_cased_mlm', - 'mixed_precision': True, -} -params - -def parse_labeled_data(row_dict): - label = [row_dict.pop(l) for l in labels] - return row_dict, label - -mirrored_strategy = tf.distribute.MirroredStrategy() -BATCH_SIZE = params['batch_size'] * mirrored_strategy.num_replicas_in_sync - -train_ds = train.to_tf_dataset().map(parse_labeled_data).shuffle(BATCH_SIZE*100).batch(BATCH_SIZE).repeat() -val_ds = val.to_tf_dataset().map(parse_labeled_data).batch(BATCH_SIZE) - -for record in train_ds: - tf.print(record) - break - -def get_positive_weights(): - """Computes positive weights used for class imbalance from training data.""" - label_weights_df = utils.get_label_weights( - "tos-data-media-full", - project_id="twttr-abusive-interact-prod", - dataset_id="tos_policy" - ) - pos_weight_tensor = tf.cast( - label_weights_df.sort_values(by='label').positive_class_weight, - dtype=tf.float32 - ) - return pos_weight_tensor - -pos_weight_tensor = get_positive_weights() -print(pos_weight_tensor) - -class TextEncoderPooledOutput(TextEncoder): - def call(self, x): - return super().call([x])["pooled_output"] - - def get_config(self): - return super().get_config() - -with mirrored_strategy.scope(): - text_encoder_pooled_output = TextEncoderPooledOutput( - params['max_seq_lengths'], - model_type=params['model_type'], - trainable=True - ) - - fe = FeatureEncoder(train) - inputs, preprocessing_head = fe.build_model_head(model=ptos_prototype, text_encoder=text_encoder_pooled_output) - - cls_dropout = tf.keras.layers.Dropout(params['cls_dropout_rate'], name="cls_dropout") - outputs = cls_dropout(preprocessing_head) - outputs = tf.keras.layers.Dense(8, name="output", dtype="float32")(outputs) - - model = tf.keras.Model( - inputs=inputs, - outputs=outputs - ) - pr_auc = tf.keras.metrics.AUC(curve="PR", num_thresholds=1000, multi_label=True, from_logits=True) - - custom_loss = lambda y_true, y_pred: utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor) - optimizer = create_optimizer( - init_lr=params["lr"], - num_train_steps=(params["epochs"] * params["steps_per_epoch"]), - num_warmup_steps=params["warmup_steps"], - optimizer_type=params["optimizer_type"], - ) - if params.get("mixed_precision"): - optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) - - model.compile( - optimizer=optimizer, - loss=custom_loss, - metrics=[pr_auc] - ) - -model.weights -model.summary() -pr_auc.name - -import getpass -import wandb -from wandb.keras import WandbCallback -try: - wandb_key = ... - wandb.login(...) - run = wandb.init(project='ptos_with_media', - group='new-split-trains', - notes='tweet text with only (num_media, precision_nsfw). 
on full train set, new split.', - entity='absv', - config=params, - name='tweet-text-w-nsfw-1.1', - sync_tensorboard=True) -except FileNotFoundError: - print('Wandb key not found') - run = wandb.init(mode='disabled') -import datetime -import os - -start_train_time = datetime.datetime.now() -print(start_train_time.strftime("%m-%d-%Y (%H:%M:%S)")) -checkpoint_path = os.path.join("...") -print("Saving model checkpoints here: ", checkpoint_path) - -cp_callback = tf.keras.callbacks.ModelCheckpoint( - filepath=os.path.join(checkpoint_path, "model.{epoch:04d}.tf"), - verbose=1, - monitor=f'val_{pr_auc.name}', - mode='max', - save_freq='epoch', - save_best_only=True -) - -early_stopping_callback = tf.keras.callbacks.EarlyStopping(patience=7, - monitor=f"val_{pr_auc.name}", - mode="max") - -model.fit(train_ds, epochs=params["epochs"], validation_data=val_ds, callbacks=[cp_callback, early_stopping_callback], - steps_per_epoch=params["steps_per_epoch"], - verbose=2) - -import tensorflow_hub as hub - -gs_model_path = ... -reloaded_keras_layer = hub.KerasLayer(gs_model_path) -inputs = tf.keras.layers.Input(name="tweet__core__tweet__text", shape=(1,), dtype=tf.string) -output = reloaded_keras_layer(inputs) -v7_model = tf.keras.models.Model(inputs=inputs, outputs=output) -pr_auc = tf.keras.metrics.AUC(curve="PR", name="pr_auc") -roc_auc = tf.keras.metrics.AUC(curve="ROC", name="roc_auc") -v7_model.compile(metrics=[pr_auc, roc_auc]) - -model.load_weights("...") -candidate_model = model - -with mirrored_strategy.scope(): - candidate_eval = candidate_model.evaluate(val_ds) - -test_query = f""" -SELECT - {",".join(ptos_prototype.feature_names())}, - has_media, - precision_nsfw, - {",".join(labels)}, -... -""" - -test = cq_loader.load_features(ptos_prototype, "", "", custom_query=test_query) -test = test.to_tf_dataset().map(parse_labeled_data) - -print(test) - -test_only_media = test.filter(lambda x, y: tf.equal(x["has_media"], True)) -test_only_nsfw = test.filter(lambda x, y: tf.greater_equal(x["precision_nsfw"], 0.95)) -test_no_media = test.filter(lambda x, y: tf.equal(x["has_media"], False)) -test_media_not_nsfw = test.filter(lambda x, y: tf.logical_and(tf.equal(x["has_media"], True), tf.less(x["precision_nsfw"], 0.95))) -for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]: - print(d.reduce(0, lambda x, _: x + 1).numpy()) - -from notebook_eval_utils import SparseMultilabelEvaluator, EvalConfig -from dataclasses import asdict - -def display_metrics(probs, targets, labels=labels): - eval_config = EvalConfig(prediction_threshold=0.5, precision_k=0.9) - for eval_mode, y_mask in [("implicit", np.ones(targets.shape))]: - print("Evaluation mode", eval_mode) - metrics = SparseMultilabelEvaluator.evaluate( - targets, np.array(probs), y_mask, classes=labels, eval_config=eval_config - ) - metrics_df = pd.DataFrame.from_dict(asdict(metrics)["per_topic_metrics"]).transpose() - metrics_df["pos_to_neg"] = metrics_df["num_pos_samples"] / (metrics_df["num_neg_samples"] + 1) - display(metrics_df.median()) - display(metrics_df) - return metrics_df - - -def eval_model(model, df): - with mirrored_strategy.scope(): - targets = np.stack(list(df.map(lambda x, y: y).as_numpy_iterator()), axis=0) - df = df.padded_batch(BATCH_SIZE) - preds = model.predict(df) - return display_metrics(preds, targets) - -subsets = {"test": test, - "test_only_media": test_only_media, - "test_only_nsfw": test_only_nsfw, - "test_no_media": test_no_media, - "test_media_not_nsfw": test_media_not_nsfw} - -metrics = {} -for 
name, df in subsets.items(): - metrics[name] = eval_model(candidate_model, df) -[(name, m.pr_auc) for name, m in metrics.items()] -for name, x in [(name, m.pr_auc.to_string(index=False).strip().split("\n")) for name, m in metrics.items()]: - print(name) - for y in x: - print(y.strip(), end="\t") - print(".") -for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]: - print(d.reduce(0, lambda x, _: x + 1).numpy()) \ No newline at end of file diff --git a/trust_and_safety_models/nsfw/nsfw_media.docx b/trust_and_safety_models/nsfw/nsfw_media.docx new file mode 100644 index 000000000..cf8e17e41 Binary files /dev/null and b/trust_and_safety_models/nsfw/nsfw_media.docx differ diff --git a/trust_and_safety_models/nsfw/nsfw_media.py b/trust_and_safety_models/nsfw/nsfw_media.py deleted file mode 100644 index b5dfebb65..000000000 --- a/trust_and_safety_models/nsfw/nsfw_media.py +++ /dev/null @@ -1,466 +0,0 @@ -import kerastuner as kt -import math -import numpy as np -import pandas as pd -import random -import sklearn.metrics -import tensorflow as tf -import os -import glob - -from tqdm import tqdm -from matplotlib import pyplot as plt -from tensorflow.keras.models import Sequential -from tensorflow.keras.layers import Dense -from google.cloud import storage - -physical_devices = tf.config.list_physical_devices('GPU') -physical_devices - -tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU') -tf.config.get_visible_devices('GPU') - -def decode_fn_embedding(example_proto): - - feature_description = { - "embedding": tf.io.FixedLenFeature([256], dtype=tf.float32), - "labels": tf.io.FixedLenFeature([], dtype=tf.int64), - } - - example = tf.io.parse_single_example( - example_proto, - feature_description - ) - - return example - -def preprocess_embedding_example(example_dict, positive_label=1, features_as_dict=False): - labels = example_dict["labels"] - label = tf.math.reduce_any(labels == positive_label) - label = tf.cast(label, tf.int32) - embedding = example_dict["embedding"] - - if features_as_dict: - features = {"embedding": embedding} - else: - features = embedding - - return features, label -input_root = ... -sens_prev_input_root = ... 
- -use_sens_prev_data = True -has_validation_data = True -positive_label = 1 - -train_batch_size = 256 -test_batch_size = 256 -validation_batch_size = 256 - -do_resample = False -def class_func(features, label): - return label - -resample_fn = tf.data.experimental.rejection_resample( - class_func, target_dist = [0.5, 0.5], seed=0 -) -train_glob = f"{input_root}/train/tfrecord/*.tfrecord" -train_files = tf.io.gfile.glob(train_glob) - -if use_sens_prev_data: - train_sens_prev_glob = f"{sens_prev_input_root}/train/tfrecord/*.tfrecord" - train_sens_prev_files = tf.io.gfile.glob(train_sens_prev_glob) - train_files = train_files + train_sens_prev_files - -random.shuffle(train_files) - -if not len(train_files): - raise ValueError(f"Did not find any train files matching {train_glob}") - - -test_glob = f"{input_root}/test/tfrecord/*.tfrecord" -test_files = tf.io.gfile.glob(test_glob) - -if not len(test_files): - raise ValueError(f"Did not find any eval files matching {test_glob}") - -test_ds = tf.data.TFRecordDataset(test_files).map(decode_fn_embedding) -test_ds = test_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)).batch(batch_size=test_batch_size) - -if use_sens_prev_data: - test_sens_prev_glob = f"{sens_prev_input_root}/test/tfrecord/*.tfrecord" - test_sens_prev_files = tf.io.gfile.glob(test_sens_prev_glob) - - if not len(test_sens_prev_files): - raise ValueError(f"Did not find any eval files matching {test_sens_prev_glob}") - - test_sens_prev_ds = tf.data.TFRecordDataset(test_sens_prev_files).map(decode_fn_embedding) - test_sens_prev_ds = test_sens_prev_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)).batch(batch_size=test_batch_size) - -train_ds = tf.data.TFRecordDataset(train_files).map(decode_fn_embedding) -train_ds = train_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)) - -if do_resample: - train_ds = train_ds.apply(resample_fn).map(lambda _,b:(b)) - -train_ds = train_ds.batch(batch_size=256).shuffle(buffer_size=10) -train_ds = train_ds.repeat() - - -if has_validation_data: - eval_glob = f"{input_root}/validation/tfrecord/*.tfrecord" - eval_files = tf.io.gfile.glob(eval_glob) - - if use_sens_prev_data: - eval_sens_prev_glob = f"{sens_prev_input_root}/validation/tfrecord/*.tfrecord" - eval_sens_prev_files = tf.io.gfile.glob(eval_sens_prev_glob) - eval_files = eval_files + eval_sens_prev_files - - - if not len(eval_files): - raise ValueError(f"Did not find any eval files matching {eval_glob}") - - eval_ds = tf.data.TFRecordDataset(eval_files).map(decode_fn_embedding) - eval_ds = eval_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)).batch(batch_size=validation_batch_size) - -else: - - eval_ds = tf.data.TFRecordDataset(test_files).map(decode_fn_embedding) - eval_ds = eval_ds.map(lambda x: preprocess_embedding_example(x, positive_label=positive_label)).batch(batch_size=validation_batch_size) -check_ds = tf.data.TFRecordDataset(train_files).map(decode_fn_embedding) -cnt = 0 -pos_cnt = 0 -for example in tqdm(check_ds): - label = example['labels'] - if label == 1: - pos_cnt += 1 - cnt += 1 -print(f'{cnt} train entries with {pos_cnt} positive') - -metrics = [] - -metrics.append( - tf.keras.metrics.PrecisionAtRecall( - recall=0.9, num_thresholds=200, class_id=None, name=None, dtype=None - ) -) - -metrics.append( - tf.keras.metrics.AUC( - num_thresholds=200, - curve="PR", - ) -) -def build_model(hp): - model = Sequential() - - optimizer = tf.keras.optimizers.Adam( - 
learning_rate=0.001, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-08, - amsgrad=False, - name="Adam", - ) - - activation=hp.Choice("activation", ["tanh", "gelu"]) - kernel_initializer=hp.Choice("kernel_initializer", ["he_uniform", "glorot_uniform"]) - for i in range(hp.Int("num_layers", 1, 2)): - model.add(tf.keras.layers.BatchNormalization()) - - units=hp.Int("units", min_value=128, max_value=256, step=128) - - if i == 0: - model.add( - Dense( - units=units, - activation=activation, - kernel_initializer=kernel_initializer, - input_shape=(None, 256) - ) - ) - else: - model.add( - Dense( - units=units, - activation=activation, - kernel_initializer=kernel_initializer, - ) - ) - - model.add(Dense(1, activation='sigmoid', kernel_initializer=kernel_initializer)) - model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=metrics) - - return model - -tuner = kt.tuners.BayesianOptimization( - build_model, - objective=kt.Objective('val_loss', direction="min"), - max_trials=30, - directory='tuner_dir', - project_name='with_twitter_clip') - -callbacks = [tf.keras.callbacks.EarlyStopping( - monitor='val_loss', min_delta=0, patience=5, verbose=0, - mode='auto', baseline=None, restore_best_weights=True -)] - -steps_per_epoch = 400 -tuner.search(train_ds, - epochs=100, - batch_size=256, - steps_per_epoch=steps_per_epoch, - verbose=2, - validation_data=eval_ds, - callbacks=callbacks) - -tuner.results_summary() -models = tuner.get_best_models(num_models=2) -best_model = models[0] - -best_model.build(input_shape=(None, 256)) -best_model.summary() - -tuner.get_best_hyperparameters()[0].values - -optimizer = tf.keras.optimizers.Adam( - learning_rate=0.001, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-08, - amsgrad=False, - name="Adam", - ) -best_model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=metrics) -best_model.summary() - -callbacks = [tf.keras.callbacks.EarlyStopping( - monitor='val_loss', min_delta=0, patience=10, verbose=0, - mode='auto', baseline=None, restore_best_weights=True -)] -history = best_model.fit(train_ds, epochs=100, validation_data=eval_ds, steps_per_epoch=steps_per_epoch, callbacks=callbacks) - -model_name = 'twitter_hypertuned' -model_path = f'models/nsfw_Keras_with_CLIP_{model_name}' -tf.keras.models.save_model(best_model, model_path) - -def copy_local_directory_to_gcs(local_path, bucket, gcs_path): - """Recursively copy a directory of files to GCS. - - local_path should be a directory and not have a trailing slash. - """ - assert os.path.isdir(local_path) - for local_file in glob.glob(local_path + '/**'): - if not os.path.isfile(local_file): - dir_name = os.path.basename(os.path.normpath(local_file)) - copy_local_directory_to_gcs(local_file, bucket, f"{gcs_path}/{dir_name}") - else: - remote_path = os.path.join(gcs_path, local_file[1 + len(local_path) :]) - blob = bucket.blob(remote_path) - blob.upload_from_filename(local_file) - -client = storage.Client(project=...) -bucket = client.get_bucket(...) 
-copy_local_directory_to_gcs(model_path, bucket, model_path) -copy_local_directory_to_gcs('tuner_dir', bucket, 'tuner_dir') -loaded_model = tf.keras.models.load_model(model_path) -print(history.history.keys()) - -plt.figure(figsize = (20, 5)) - -plt.subplot(1, 3, 1) -plt.plot(history.history['auc']) -plt.plot(history.history['val_auc']) -plt.title('model auc') -plt.ylabel('auc') -plt.xlabel('epoch') -plt.legend(['train', 'test'], loc='upper left') - -plt.subplot(1, 3, 2) -plt.plot(history.history['loss']) -plt.plot(history.history['val_loss']) -plt.title('model loss') -plt.ylabel('loss') -plt.xlabel('epoch') -plt.legend(['train', 'test'], loc='upper left') - -plt.subplot(1, 3, 3) -plt.plot(history.history['precision_at_recall']) -plt.plot(history.history['val_precision_at_recall']) -plt.title('model precision at 0.9 recall') -plt.ylabel('precision_at_recall') -plt.xlabel('epoch') -plt.legend(['train', 'test'], loc='upper left') - -plt.savefig('history_with_twitter_clip.pdf') - -test_labels = [] -test_preds = [] - -for batch_features, batch_labels in tqdm(test_ds): - test_preds.extend(loaded_model.predict_proba(batch_features)) - test_labels.extend(batch_labels.numpy()) - -test_sens_prev_labels = [] -test_sens_prev_preds = [] - -for batch_features, batch_labels in tqdm(test_sens_prev_ds): - test_sens_prev_preds.extend(loaded_model.predict_proba(batch_features)) - test_sens_prev_labels.extend(batch_labels.numpy()) - -n_test_pos = 0 -n_test_neg = 0 -n_test = 0 - -for label in test_labels: - n_test +=1 - if label == 1: - n_test_pos +=1 - else: - n_test_neg +=1 - -print(f'n_test = {n_test}, n_pos = {n_test_pos}, n_neg = {n_test_neg}') - -n_test_sens_prev_pos = 0 -n_test_sens_prev_neg = 0 -n_test_sens_prev = 0 - -for label in test_sens_prev_labels: - n_test_sens_prev +=1 - if label == 1: - n_test_sens_prev_pos +=1 - else: - n_test_sens_prev_neg +=1 - -print(f'n_test_sens_prev = {n_test_sens_prev}, n_pos_sens_prev = {n_test_sens_prev_pos}, n_neg = {n_test_sens_prev_neg}') - -test_weights = np.ones(np.asarray(test_preds).shape) - -test_labels = np.asarray(test_labels) -test_preds = np.asarray(test_preds) -test_weights = np.asarray(test_weights) - -pr = sklearn.metrics.precision_recall_curve( - test_labels, - test_preds) - -auc = sklearn.metrics.auc(pr[1], pr[0]) -plt.plot(pr[1], pr[0]) -plt.title("nsfw (MU test set)") - -test_sens_prev_weights = np.ones(np.asarray(test_sens_prev_preds).shape) - -test_sens_prev_labels = np.asarray(test_sens_prev_labels) -test_sens_prev_preds = np.asarray(test_sens_prev_preds) -test_sens_prev_weights = np.asarray(test_sens_prev_weights) - -pr_sens_prev = sklearn.metrics.precision_recall_curve( - test_sens_prev_labels, - test_sens_prev_preds) - -auc_sens_prev = sklearn.metrics.auc(pr_sens_prev[1], pr_sens_prev[0]) -plt.plot(pr_sens_prev[1], pr_sens_prev[0]) -plt.title("nsfw (sens prev test set)") - -df = pd.DataFrame( - { - "label": test_labels.squeeze(), - "preds_keras": np.asarray(test_preds).flatten(), - }) -plt.figure(figsize=(15, 10)) -df["preds_keras"].hist() -plt.title("Keras predictions", size=20) -plt.xlabel('score') -plt.ylabel("freq") - -plt.figure(figsize = (20, 5)) -plt.subplot(1, 3, 1) - -plt.plot(pr[2], pr[0][0:-1]) -plt.xlabel("threshold") -plt.ylabel("precision") - -plt.subplot(1, 3, 2) - -plt.plot(pr[2], pr[1][0:-1]) -plt.xlabel("threshold") -plt.ylabel("recall") -plt.title("Keras", size=20) - -plt.subplot(1, 3, 3) - -plt.plot(pr[1], pr[0]) -plt.xlabel("recall") -plt.ylabel("precision") - -plt.savefig('with_twitter_clip.pdf') - -def 
get_point_for_recall(recall_value, recall, precision): - idx = np.argmin(np.abs(recall - recall_value)) - return (recall[idx], precision[idx]) - -def get_point_for_precision(precision_value, recall, precision): - idx = np.argmin(np.abs(precision - precision_value)) - return (recall[idx], precision[idx]) -precision, recall, thresholds = pr - -auc_precision_recall = sklearn.metrics.auc(recall, precision) - -print(auc_precision_recall) - -plt.figure(figsize=(15, 10)) -plt.plot(recall, precision) - -plt.xlabel("recall") -plt.ylabel("precision") - -ptAt50 = get_point_for_recall(0.5, recall, precision) -print(ptAt50) -plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r') -plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r') - -ptAt90 = get_point_for_recall(0.9, recall, precision) -print(ptAt90) -plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b') -plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b') - -ptAt50fmt = "%.4f" % ptAt50[1] -ptAt90fmt = "%.4f" % ptAt90[1] -aucFmt = "%.4f" % auc_precision_recall -plt.title( - f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...}} ({...} pos), N_test={n_test} ({n_test_pos} pos)", - size=20 -) -plt.subplots_adjust(top=0.72) -plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf') - -precision, recall, thresholds = pr_sens_prev - -auc_precision_recall = sklearn.metrics.auc(recall, precision) -print(auc_precision_recall) -plt.figure(figsize=(15, 10)) - -plt.plot(recall, precision) - -plt.xlabel("recall") -plt.ylabel("precision") - -ptAt50 = get_point_for_recall(0.5, recall, precision) -print(ptAt50) -plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r') -plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r') - -ptAt90 = get_point_for_recall(0.9, recall, precision) -print(ptAt90) -plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b') -plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b') - -ptAt50fmt = "%.4f" % ptAt50[1] -ptAt90fmt = "%.4f" % ptAt90[1] -aucFmt = "%.4f" % auc_precision_recall -plt.title( - f"Keras (nsfw sens prev test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test_sens_prev} ({n_test_sens_prev_pos} pos)", - size=20 -) -plt.subplots_adjust(top=0.72) -plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_sens_prev_test.pdf') \ No newline at end of file diff --git a/trust_and_safety_models/nsfw/nsfw_text.docx b/trust_and_safety_models/nsfw/nsfw_text.docx new file mode 100644 index 000000000..e56a7ae2f Binary files /dev/null and b/trust_and_safety_models/nsfw/nsfw_text.docx differ diff --git a/trust_and_safety_models/nsfw/nsfw_text.py b/trust_and_safety_models/nsfw/nsfw_text.py deleted file mode 100644 index 980fc8fd4..000000000 --- a/trust_and_safety_models/nsfw/nsfw_text.py +++ /dev/null @@ -1,152 +0,0 @@ -from datetime import datetime -from functools import reduce -import os -import pandas as pd -import re -from sklearn.metrics import average_precision_score, classification_report, precision_recall_curve, PrecisionRecallDisplay -from sklearn.model_selection import train_test_split -import tensorflow as tf -import matplotlib.pyplot as plt -import re - -from twitter.cuad.representation.models.optimization import create_optimizer -from twitter.cuad.representation.models.text_encoder import TextEncoder - -pd.set_option('display.max_colwidth', None) -pd.set_option('display.expand_frame_repr', False) - -print(tf.__version__) -print(tf.config.list_physical_devices()) - -log_path = os.path.join('pnsfwtweettext_model_runs', 
datetime.now().strftime('%Y-%m-%d_%H.%M.%S')) - -tweet_text_feature = 'text' - -params = { - 'batch_size': 32, - 'max_seq_lengths': 256, - 'model_type': 'twitter_bert_base_en_uncased_augmented_mlm', - 'trainable_text_encoder': True, - 'lr': 5e-5, - 'epochs': 10, -} - -REGEX_PATTERNS = [ - r'^RT @[A-Za-z0-9_]+: ', - r"@[A-Za-z0-9_]+", - r'https:\/\/t\.co\/[A-Za-z0-9]{10}', - r'@\?\?\?\?\?', -] - -EMOJI_PATTERN = re.compile( - "([" - "\U0001F1E0-\U0001F1FF" - "\U0001F300-\U0001F5FF" - "\U0001F600-\U0001F64F" - "\U0001F680-\U0001F6FF" - "\U0001F700-\U0001F77F" - "\U0001F780-\U0001F7FF" - "\U0001F800-\U0001F8FF" - "\U0001F900-\U0001F9FF" - "\U0001FA00-\U0001FA6F" - "\U0001FA70-\U0001FAFF" - "\U00002702-\U000027B0" - "])" - ) - -def clean_tweet(text): - for pattern in REGEX_PATTERNS: - text = re.sub(pattern, '', text) - - text = re.sub(EMOJI_PATTERN, r' \1 ', text) - - text = re.sub(r'\n', ' ', text) - - return text.strip().lower() - - -df['processed_text'] = df['text'].astype(str).map(clean_tweet) -df.sample(10) - -X_train, X_val, y_train, y_val = train_test_split(df[['processed_text']], df['is_nsfw'], test_size=0.1, random_state=1) - -def df_to_ds(X, y, shuffle=False): - ds = tf.data.Dataset.from_tensor_slices(( - X.values, - tf.one_hot(tf.cast(y.values, tf.int32), depth=2, axis=-1) - )) - - if shuffle: - ds = ds.shuffle(1000, seed=1, reshuffle_each_iteration=True) - - return ds.map(lambda text, label: ({ tweet_text_feature: text }, label)).batch(params['batch_size']) - -ds_train = df_to_ds(X_train, y_train, shuffle=True) -ds_val = df_to_ds(X_val, y_val) -X_train.values - -inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name=tweet_text_feature) -encoder = TextEncoder( - max_seq_lengths=params['max_seq_lengths'], - model_type=params['model_type'], - trainable=params['trainable_text_encoder'], - local_preprocessor_path='demo-preprocessor' -) -embedding = encoder([inputs])["pooled_output"] -predictions = tf.keras.layers.Dense(2, activation='softmax')(embedding) -model = tf.keras.models.Model(inputs=inputs, outputs=predictions) - -model.summary() - -optimizer = create_optimizer( - params['lr'], - params['epochs'] * len(ds_train), - 0, - weight_decay_rate=0.01, - optimizer_type='adamw' -) -bce = tf.keras.losses.BinaryCrossentropy(from_logits=False) -pr_auc = tf.keras.metrics.AUC(curve='PR', num_thresholds=1000, from_logits=False) -model.compile(optimizer=optimizer, loss=bce, metrics=[pr_auc]) - -callbacks = [ - tf.keras.callbacks.EarlyStopping( - monitor='val_loss', - mode='min', - patience=1, - restore_best_weights=True - ), - tf.keras.callbacks.ModelCheckpoint( - filepath=os.path.join(log_path, 'checkpoints', '{epoch:02d}'), - save_freq='epoch' - ), - tf.keras.callbacks.TensorBoard( - log_dir=os.path.join(log_path, 'scalars'), - update_freq='batch', - write_graph=False - ) -] -history = model.fit( - ds_train, - epochs=params['epochs'], - callbacks=callbacks, - validation_data=ds_val, - steps_per_epoch=len(ds_train) -) - -model.predict(["xxx 🍑"]) - -preds = X_val.processed_text.apply(apply_model) -print(classification_report(y_val, preds >= 0.90, digits=4)) - -precision, recall, thresholds = precision_recall_curve(y_val, preds) - -fig = plt.figure(figsize=(15, 10)) -plt.plot(precision, recall, lw=2) -plt.grid() -plt.xlim(0.2, 1) -plt.ylim(0.3, 1) -plt.xlabel("Recall", size=20) -plt.ylabel("Precision", size=20) - -average_precision_score(y_val, preds) diff --git a/trust_and_safety_models/toxicity/__init__.docx b/trust_and_safety_models/toxicity/__init__.docx new file mode 100644 index 
000000000..cd6b29278 Binary files /dev/null and b/trust_and_safety_models/toxicity/__init__.docx differ diff --git a/trust_and_safety_models/toxicity/__init__.py b/trust_and_safety_models/toxicity/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/trust_and_safety_models/toxicity/data/__init__.docx b/trust_and_safety_models/toxicity/data/__init__.docx new file mode 100644 index 000000000..cd6b29278 Binary files /dev/null and b/trust_and_safety_models/toxicity/data/__init__.docx differ diff --git a/trust_and_safety_models/toxicity/data/__init__.py b/trust_and_safety_models/toxicity/data/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/trust_and_safety_models/toxicity/data/data_preprocessing.docx b/trust_and_safety_models/toxicity/data/data_preprocessing.docx new file mode 100644 index 000000000..05793f6e9 Binary files /dev/null and b/trust_and_safety_models/toxicity/data/data_preprocessing.docx differ diff --git a/trust_and_safety_models/toxicity/data/data_preprocessing.py b/trust_and_safety_models/toxicity/data/data_preprocessing.py deleted file mode 100644 index f7da608f6..000000000 --- a/trust_and_safety_models/toxicity/data/data_preprocessing.py +++ /dev/null @@ -1,118 +0,0 @@ -from abc import ABC -import re - -from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35 - -import numpy as np - - -TOXIC_35_set = set(TOXIC_35) - -url_group = r"(\bhttps?:\/\/\S+)" -mention_group = r"(\B@\S+)" -urls_mentions_re = re.compile(url_group + r"|" + mention_group, re.IGNORECASE) -url_re = re.compile(url_group, re.IGNORECASE) -mention_re = re.compile(mention_group, re.IGNORECASE) -newline_re = re.compile(r"\n+", re.IGNORECASE) -and_re = re.compile(r"&\s?amp\s?;", re.IGNORECASE) - - -class DataframeCleaner(ABC): - def __init__(self): - pass - - def _clean(self, df): - return df - - def _systematic_preprocessing(self, df): - df.reset_index(inplace=True, drop=True) - if "media_url" in df.columns: - print(".... removing tweets with media") - df.drop(df[~df.media_url.isna()].index, inplace=True, axis=0) - else: - print("WARNING you are not removing tweets with media to train a BERT model.") - - print(".... 
deleting duplicates") - df.drop_duplicates("text", inplace=True, keep="last") - print(f"Got {df.shape[0]} after cleaning") - - return df.reset_index(inplace=False, drop=True) - - def _postprocess(self, df, *args, **kwargs): - return df - - def __call__(self, df, *args, **kwargs): - print(f"Got {df.shape[0]} before cleaning") - - df["raw_text"] = df.text - df = self._clean(df) - - df = self._systematic_preprocessing(df) - - return self._postprocess(df, *args, **kwargs) - - -def mapping_func(el): - if el.aggregated_content in TOXIC_35_set: - return 2 - if el.label == 1: - return 1 - return 0 - - -class DefaultENNoPreprocessor(DataframeCleaner): - def _postprocess(self, df, *args, **kwargs): - if "toxic_count" in df.columns and "non_toxic_count" in df.columns: - df["vote"] = df.toxic_count / (df.toxic_count + df.non_toxic_count) - df["agreement_rate"] = np.max((df.vote, 1 - df.vote), axis=0) - - if "label_column" in kwargs and kwargs["label_column"] != "label": - if kwargs["label_column"] == "aggregated_content": - print("Replacing v3 label by v3.5 label.") - if "num_classes" in kwargs and kwargs["num_classes"] < 3: - df["label"] = np.where(df.aggregated_content.isin(TOXIC_35_set), 1, 0) - elif "num_classes" in kwargs and kwargs["num_classes"] == 3: - print("Making it a 3-class pb") - df["label"] = df.apply(mapping_func, axis=1) - else: - raise NotImplementedError - elif kwargs['label_column'] in df.columns: - df['label'] = df[kwargs['label_column']] - if kwargs['class_weight'] is not None: - df["class_weight"] = np.where(df['label'] == 1, 1-kwargs['class_weight'], - kwargs['class_weight']) - else: - raise NotImplementedError - - if "filter_low_agreements" in kwargs and kwargs["filter_low_agreements"] == True: - df.drop(df[(df.agreement_rate <= 0.6)].index, axis=0, inplace=True) - raise NotImplementedError - - return df - - -class DefaultENPreprocessor(DefaultENNoPreprocessor): - def _clean(self, adhoc_df): - print( - ".... removing \\n and replacing @mentions and URLs by placeholders. " - "Emoji filtering is not done." - ) - adhoc_df["text"] = [url_re.sub("URL", tweet) for tweet in adhoc_df.raw_text.values] - adhoc_df["text"] = [mention_re.sub("MENTION", tweet) for tweet in adhoc_df.text.values] - adhoc_df["text"] = [ - newline_re.sub(" ", tweet).lstrip(" ").rstrip(" ") for tweet in adhoc_df.text.values - ] - adhoc_df["text"] = [and_re.sub("&", tweet) for tweet in adhoc_df.text.values] - - return adhoc_df - - -class Defaulti18nPreprocessor(DataframeCleaner): - def _clean(self, adhoc_df): - print(".... removing @mentions, \\n and URLs. 
Emoji filtering is not done.") - adhoc_df["text"] = [urls_mentions_re.sub("", tweet) for tweet in adhoc_df.raw_text.values] - adhoc_df["text"] = [ - newline_re.sub(" ", tweet).lstrip(" ").rstrip(" ") for tweet in adhoc_df.text.values - ] - - return adhoc_df diff --git a/trust_and_safety_models/toxicity/data/dataframe_loader.docx b/trust_and_safety_models/toxicity/data/dataframe_loader.docx new file mode 100644 index 000000000..4150efe98 Binary files /dev/null and b/trust_and_safety_models/toxicity/data/dataframe_loader.docx differ diff --git a/trust_and_safety_models/toxicity/data/dataframe_loader.py b/trust_and_safety_models/toxicity/data/dataframe_loader.py deleted file mode 100644 index f3855d6b5..000000000 --- a/trust_and_safety_models/toxicity/data/dataframe_loader.py +++ /dev/null @@ -1,348 +0,0 @@ -from abc import ABC, abstractmethod -from datetime import date -from importlib import import_module -import pickle - -from toxicity_ml_pipeline.settings.default_settings_tox import ( - CLIENT, - EXISTING_TASK_VERSIONS, - GCS_ADDRESS, - TRAINING_DATA_LOCATION, -) -from toxicity_ml_pipeline.utils.helpers import execute_command, execute_query -from toxicity_ml_pipeline.utils.queries import ( - FULL_QUERY, - FULL_QUERY_W_TWEET_TYPES, - PARSER_UDF, - QUERY_SETTINGS, -) - -import numpy as np -import pandas - - -class DataframeLoader(ABC): - - def __init__(self, project): - self.project = project - - @abstractmethod - def produce_query(self): - pass - - @abstractmethod - def load_data(self, test=False): - pass - - -class ENLoader(DataframeLoader): - def __init__(self, project, setting_file): - super(ENLoader, self).__init__(project=project) - self.date_begin = setting_file.DATE_BEGIN - self.date_end = setting_file.DATE_END - TASK_VERSION = setting_file.TASK_VERSION - if TASK_VERSION not in EXISTING_TASK_VERSIONS: - raise ValueError - self.task_version = TASK_VERSION - self.query_settings = dict(QUERY_SETTINGS) - self.full_query = FULL_QUERY - - def produce_query(self, date_begin, date_end, task_version=None, **keys): - task_version = self.task_version if task_version is None else task_version - - if task_version in keys["table"]: - table_name = keys["table"][task_version] - print(f"Loading {table_name}") - - main_query = keys["main"].format( - table=table_name, - parser_udf=PARSER_UDF[task_version], - date_begin=date_begin, - date_end=date_end, - ) - - return self.full_query.format( - main_table_query=main_query, date_begin=date_begin, date_end=date_end - ) - return "" - - def _reload(self, test, file_keyword): - query = f"SELECT * from `{TRAINING_DATA_LOCATION.format(project=self.project)}_{file_keyword}`" - - if test: - query += " ORDER BY RAND() LIMIT 1000" - try: - df = execute_query(client=CLIENT, query=query) - except Exception: - print( - "Loading from BQ failed, trying to load from GCS. " - "NB: use this option only for intermediate files, which will be deleted at the end of " - "the project." - ) - copy_cmd = f"gsutil cp {GCS_ADDRESS.format(project=self.project)}/training_data/{file_keyword}.pkl ." 
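# Reviewer annotation (not a line from the original dataframe_loader.py): execute_command
# is imported from toxicity_ml_pipeline.utils.helpers at the top of this file; its
# implementation is not part of this diff, but here it is presumably expected to shell out
# and run the gsutil copy built in copy_cmd, so that <file_keyword>.pkl lands in the
# working directory for the pickle.load fallback below when the BigQuery read fails.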
- execute_command(copy_cmd) - try: - with open(f"{file_keyword}.pkl", "rb") as file: - df = pickle.load(file) - except Exception: - return None - - if test: - df = df.sample(frac=1) - return df.iloc[:1000] - - return df - - def load_data(self, test=False, **kwargs): - if "reload" in kwargs and kwargs["reload"]: - df = self._reload(test, kwargs["reload"]) - if df is not None and df.shape[0] > 0: - return df - - df = None - query_settings = self.query_settings - if test: - query_settings = {"fairness": self.query_settings["fairness"]} - query_settings["fairness"]["main"] += " LIMIT 500" - - for table, query_info in query_settings.items(): - curr_query = self.produce_query( - date_begin=self.date_begin, date_end=self.date_end, **query_info - ) - if curr_query == "": - continue - curr_df = execute_query(client=CLIENT, query=curr_query) - curr_df["origin"] = table - df = curr_df if df is None else pandas.concat((df, curr_df)) - - df["loading_date"] = date.today() - df["date"] = pandas.to_datetime(df.date) - return df - - def load_precision_set( - self, begin_date="...", end_date="...", with_tweet_types=False, task_version=3.5 - ): - if with_tweet_types: - self.full_query = FULL_QUERY_W_TWEET_TYPES - - query_settings = self.query_settings - curr_query = self.produce_query( - date_begin=begin_date, - date_end=end_date, - task_version=task_version, - **query_settings["precision"], - ) - curr_df = execute_query(client=CLIENT, query=curr_query) - - curr_df.rename(columns={"media_url": "media_presence"}, inplace=True) - return curr_df - - -class ENLoaderWithSampling(ENLoader): - - keywords = { - "politics": [ -... - ], - "insults": [ -... - ], - "race": [ -... - ], - } - n = ... - N = ... - - def __init__(self, project): - self.raw_loader = ENLoader(project=project) - if project == ...: - self.project = project - else: - raise ValueError - - def sample_with_weights(self, df, n): - w = df["label"].value_counts(normalize=True)[1] - dist = np.full((df.shape[0],), w) - sampled_df = df.sample(n=n, weights=dist, replace=False) - return sampled_df - - def sample_keywords(self, df, N, group): - print("\nmatching", group, "keywords...") - - keyword_list = self.keywords[group] - match_df = df.loc[df.text.str.lower().str.contains("|".join(keyword_list), regex=True)] - - print("sampling N/3 from", group) - if match_df.shape[0] <= N / 3: - print( - "WARNING: Sampling only", - match_df.shape[0], - "instead of", - N / 3, - "examples from race focused tweets due to insufficient data", - ) - sample_df = match_df - - else: - print( - "sampling", - group, - "at", - round(match_df["label"].value_counts(normalize=True)[1], 3), - "% action rate", - ) - sample_df = self.sample_with_weights(match_df, int(N / 3)) - print(sample_df.shape) - print(sample_df.label.value_counts(normalize=True)) - - print("\nshape of df before dropping sampled rows after", group, "matching..", df.shape[0]) - df = df.loc[ - df.index.difference(sample_df.index), - ] - print("\nshape of df after dropping sampled rows after", group, "matching..", df.shape[0]) - - return df, sample_df - - def sample_first_set_helper(self, train_df, first_set, new_n): - if first_set == "prev": - fset = train_df.loc[train_df["origin"].isin(["prevalence", "causal prevalence"])] - print( - "sampling prev at", round(fset["label"].value_counts(normalize=True)[1], 3), "% action rate" - ) - else: - fset = train_df - - n_fset = self.sample_with_weights(fset, new_n) - print("len of sampled first set", n_fset.shape[0]) - print(n_fset.label.value_counts(normalize=True)) - - 
return n_fset - - def sample(self, df, first_set, second_set, keyword_sampling, n, N): - train_df = df[df.origin != "precision"] - val_test_df = df[df.origin == "precision"] - - print("\nsampling first set of data") - new_n = n - N if second_set is not None else n - n_fset = self.sample_first_set_helper(train_df, first_set, new_n) - - print("\nsampling second set of data") - train_df = train_df.loc[ - train_df.index.difference(n_fset.index), - ] - - if second_set is None: - print("no second set sampling being done") - df = n_fset.append(val_test_df) - return df - - if second_set == "prev": - sset = train_df.loc[train_df["origin"].isin(["prevalence", "causal prevalence"])] - - elif second_set == "fdr": - sset = train_df.loc[train_df["origin"] == "fdr"] - - else: - sset = train_df - - if keyword_sampling == True: - print("sampling based off of keywords defined...") - print("second set is", second_set, "with length", sset.shape[0]) - - sset, n_politics = self.sample_keywords(sset, N, "politics") - sset, n_insults = self.sample_keywords(sset, N, "insults") - sset, n_race = self.sample_keywords(sset, N, "race") - - n_sset = n_politics.append([n_insults, n_race]) - print("len of sampled second set", n_sset.shape[0]) - - else: - print( - "No keyword sampling. Instead random sampling from", - second_set, - "at", - round(sset["label"].value_counts(normalize=True)[1], 3), - "% action rate", - ) - n_sset = self.sample_with_weights(sset, N) - print("len of sampled second set", n_sset.shape[0]) - print(n_sset.label.value_counts(normalize=True)) - - df = n_fset.append([n_sset, val_test_df]) - df = df.sample(frac=1).reset_index(drop=True) - - return df - - def load_data( - self, first_set="prev", second_set=None, keyword_sampling=False, test=False, **kwargs - ): - n = kwargs.get("n", self.n) - N = kwargs.get("N", self.N) - - df = self.raw_loader.load_data(test=test, **kwargs) - return self.sample(df, first_set, second_set, keyword_sampling, n, N) - - -class I18nLoader(DataframeLoader): - def __init__(self): - super().__init__(project=...) - from archive.settings.... import ACCEPTED_LANGUAGES, QUERY_SETTINGS - - self.accepted_languages = ACCEPTED_LANGUAGES - self.query_settings = dict(QUERY_SETTINGS) - - def produce_query(self, language, query, dataset, table, lang): - query = query.format(dataset=dataset, table=table) - add_query = f"AND reviewed.{lang}='{language}'" - query += add_query - - return query - - def query_keys(self, language, task=2, size="50"): - if task == 2: - if language == "ar": - self.query_settings["adhoc_v2"]["table"] = "..." - elif language == "tr": - self.query_settings["adhoc_v2"]["table"] = "..." - elif language == "es": - self.query_settings["adhoc_v2"]["table"] = f"..." - else: - self.query_settings["adhoc_v2"]["table"] = "..." - - return self.query_settings["adhoc_v2"] - - if task == 3: - return self.query_settings["adhoc_v3"] - - raise ValueError(f"There are no other tasks than 2 or 3. {task} does not exist.") - - def load_data(self, language, test=False, task=2): - if language not in self.accepted_languages: - raise ValueError( - f"Language not in the data {language}. Accepted values are " f"{self.accepted_languages}" - ) - - print(".... adhoc data") - key_dict = self.query_keys(language=language, task=task) - query_adhoc = self.produce_query(language=language, **key_dict) - if test: - query_adhoc += " LIMIT 500" - adhoc_df = execute_query(CLIENT, query_adhoc) - - if not (test or language == "tr" or task == 3): - if language == "es": - print(".... 
additional adhoc data") - key_dict = self.query_keys(language=language, size="100") - query_adhoc = self.produce_query(language=language, **key_dict) - adhoc_df = pandas.concat( - (adhoc_df, execute_query(CLIENT, query_adhoc)), axis=0, ignore_index=True - ) - - print(".... prevalence data") - query_prev = self.produce_query(language=language, **self.query_settings["prevalence_v2"]) - prev_df = execute_query(CLIENT, query_prev) - prev_df["description"] = "Prevalence" - adhoc_df = pandas.concat((adhoc_df, prev_df), axis=0, ignore_index=True) - - return self.clean(adhoc_df) diff --git a/trust_and_safety_models/toxicity/data/mb_generator.docx b/trust_and_safety_models/toxicity/data/mb_generator.docx new file mode 100644 index 000000000..4d619e210 Binary files /dev/null and b/trust_and_safety_models/toxicity/data/mb_generator.docx differ diff --git a/trust_and_safety_models/toxicity/data/mb_generator.py b/trust_and_safety_models/toxicity/data/mb_generator.py deleted file mode 100644 index 58a89f8c5..000000000 --- a/trust_and_safety_models/toxicity/data/mb_generator.py +++ /dev/null @@ -1,284 +0,0 @@ -from importlib import import_module -import os - -from toxicity_ml_pipeline.settings.default_settings_tox import ( - INNER_CV, - LOCAL_DIR, - MAX_SEQ_LENGTH, - NUM_PREFETCH, - NUM_WORKERS, - OUTER_CV, - TARGET_POS_PER_EPOCH, -) -from toxicity_ml_pipeline.utils.helpers import execute_command - -import numpy as np -import pandas -from sklearn.model_selection import StratifiedKFold -import tensorflow as tf - - -try: - from transformers import AutoTokenizer, DataCollatorWithPadding -except ModuleNotFoundError: - print("...") -else: - from datasets import Dataset - - -class BalancedMiniBatchLoader(object): - def __init__( - self, - fold, - mb_size, - seed, - perc_training_tox, - scope="TOX", - project=..., - dual_head=None, - n_outer_splits=None, - n_inner_splits=None, - sample_weights=None, - huggingface=False, - ): - if 0 >= perc_training_tox or perc_training_tox > 0.5: - raise ValueError("Perc_training_tox should be in ]0; 0.5]") - - self.perc_training_tox = perc_training_tox - if not n_outer_splits: - n_outer_splits = OUTER_CV - if isinstance(n_outer_splits, int): - self.n_outer_splits = n_outer_splits - self.get_outer_fold = self._get_outer_cv_fold - if fold < 0 or fold >= self.n_outer_splits or int(fold) != fold: - raise ValueError(f"Number of fold should be an integer in [0 ; {self.n_outer_splits} [.") - - elif n_outer_splits == "time": - self.get_outer_fold = self._get_time_fold - if fold != "time": - raise ValueError( - "To avoid repeating the same run many times, the external fold" - "should be time when test data is split according to dates." - ) - try: - setting_file = import_module(f"toxicity_ml_pipeline.settings.{scope.lower()}{project}_settings") - except ModuleNotFoundError: - raise ValueError(f"You need to define a setting file for your project {project}.") - self.test_begin_date = setting_file.TEST_BEGIN_DATE - self.test_end_date = setting_file.TEST_END_DATE - - else: - raise ValueError( - f"Argument n_outer_splits should either an integer or 'time'. 
Provided: {n_outer_splits}" - ) - - self.n_inner_splits = n_inner_splits if n_inner_splits is not None else INNER_CV - - self.seed = seed - self.mb_size = mb_size - self.fold = fold - - self.sample_weights = sample_weights - self.dual_head = dual_head - self.huggingface = huggingface - if self.huggingface: - self._load_tokenizer() - - def _load_tokenizer(self): - print("Making a local copy of Bertweet-base model") - local_model_dir = os.path.join(LOCAL_DIR, "models") - cmd = f"mkdir {local_model_dir} ; gsutil -m cp -r gs://... {local_model_dir}" - execute_command(cmd) - - self.tokenizer = AutoTokenizer.from_pretrained( - os.path.join(local_model_dir, "bertweet-base"), normalization=True - ) - - def tokenize_function(self, el): - return self.tokenizer( - el["text"], - max_length=MAX_SEQ_LENGTH, - padding="max_length", - truncation=True, - add_special_tokens=True, - return_token_type_ids=False, - return_attention_mask=False, - ) - - def _get_stratified_kfold(self, n_splits): - return StratifiedKFold(shuffle=True, n_splits=n_splits, random_state=self.seed) - - def _get_time_fold(self, df): - test_begin_date = pandas.to_datetime(self.test_begin_date).date() - test_end_date = pandas.to_datetime(self.test_end_date).date() - print(f"Test is going from {test_begin_date} to {test_end_date}.") - test_data = df.query("@test_begin_date <= date <= @test_end_date") - - query = "date < @test_begin_date" - other_set = df.query(query) - return other_set, test_data - - def _get_outer_cv_fold(self, df): - labels = df.int_label - stratifier = self._get_stratified_kfold(n_splits=self.n_outer_splits) - - k = 0 - for train_index, test_index in stratifier.split(np.zeros(len(labels)), labels): - if k == self.fold: - break - k += 1 - - train_data = df.iloc[train_index].copy() - test_data = df.iloc[test_index].copy() - - return train_data, test_data - - def get_steps_per_epoch(self, nb_pos_examples): - return int(max(TARGET_POS_PER_EPOCH, nb_pos_examples) / self.mb_size / self.perc_training_tox) - - def make_huggingface_tensorflow_ds(self, group, mb_size=None, shuffle=True): - huggingface_ds = Dataset.from_pandas(group).map(self.tokenize_function, batched=True) - data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer, return_tensors="tf") - tensorflow_ds = huggingface_ds.to_tf_dataset( - columns=["input_ids"], - label_cols=["labels"], - shuffle=shuffle, - batch_size=self.mb_size if mb_size is None else mb_size, - collate_fn=data_collator, - ) - - if shuffle: - return tensorflow_ds.repeat() - return tensorflow_ds - - def make_pure_tensorflow_ds(self, df, nb_samples): - buffer_size = nb_samples * 2 - - if self.sample_weights is not None: - if self.sample_weights not in df.columns: - raise ValueError - ds = tf.data.Dataset.from_tensor_slices( - (df.text.values, df.label.values, df[self.sample_weights].values) - ) - elif self.dual_head: - label_d = {f'{e}_output': df[f'{e}_label'].values for e in self.dual_head} - label_d['content_output'] = tf.keras.utils.to_categorical(label_d['content_output'], num_classes=3) - ds = tf.data.Dataset.from_tensor_slices((df.text.values, label_d)) - - else: - ds = tf.data.Dataset.from_tensor_slices((df.text.values, df.label.values)) - ds = ds.shuffle(buffer_size, seed=self.seed, reshuffle_each_iteration=True).repeat() - return ds - - def get_balanced_dataset(self, training_data, size_limit=None, return_as_batch=True): - training_data = training_data.sample(frac=1, random_state=self.seed) - nb_samples = training_data.shape[0] if not size_limit else size_limit - - num_classes = 
training_data.int_label.nunique() - toxic_class = training_data.int_label.max() - if size_limit: - training_data = training_data[: size_limit * num_classes] - - print( - ".... {} examples, incl. {:.2f}% tox in train, {} classes".format( - nb_samples, - 100 * training_data[training_data.int_label == toxic_class].shape[0] / nb_samples, - num_classes, - ) - ) - label_groups = training_data.groupby("int_label") - if self.huggingface: - label_datasets = { - label: self.make_huggingface_tensorflow_ds(group) for label, group in label_groups - } - - else: - label_datasets = { - label: self.make_pure_tensorflow_ds(group, nb_samples=nb_samples * 2) - for label, group in label_groups - } - - datasets = [label_datasets[0], label_datasets[1]] - weights = [1 - self.perc_training_tox, self.perc_training_tox] - if num_classes == 3: - datasets.append(label_datasets[2]) - weights = [1 - self.perc_training_tox, self.perc_training_tox / 2, self.perc_training_tox / 2] - elif num_classes != 2: - raise ValueError("Currently it should not be possible to get other than 2 or 3 classes") - resampled_ds = tf.data.experimental.sample_from_datasets(datasets, weights, seed=self.seed) - - if return_as_batch and not self.huggingface: - return resampled_ds.batch( - self.mb_size, drop_remainder=True, num_parallel_calls=NUM_WORKERS, deterministic=True - ).prefetch(NUM_PREFETCH) - - return resampled_ds - - @staticmethod - def _compute_int_labels(full_df): - if full_df.label.dtype == int: - full_df["int_label"] = full_df.label - - elif "int_label" not in full_df.columns: - if full_df.label.max() > 1: - raise ValueError("Binarizing labels that should not be.") - full_df["int_label"] = np.where(full_df.label >= 0.5, 1, 0) - - return full_df - - def __call__(self, full_df, *args, **kwargs): - full_df = self._compute_int_labels(full_df) - - train_data, test_data = self.get_outer_fold(df=full_df) - - stratifier = self._get_stratified_kfold(n_splits=self.n_inner_splits) - for train_index, val_index in stratifier.split( - np.zeros(train_data.shape[0]), train_data.int_label - ): - curr_train_data = train_data.iloc[train_index] - - mini_batches = self.get_balanced_dataset(curr_train_data) - - steps_per_epoch = self.get_steps_per_epoch( - nb_pos_examples=curr_train_data[curr_train_data.int_label != 0].shape[0] - ) - - val_data = train_data.iloc[val_index].copy() - - yield mini_batches, steps_per_epoch, val_data, test_data - - def simple_cv_load(self, full_df): - full_df = self._compute_int_labels(full_df) - - train_data, test_data = self.get_outer_fold(df=full_df) - if test_data.shape[0] == 0: - test_data = train_data.iloc[:500] - - mini_batches = self.get_balanced_dataset(train_data) - steps_per_epoch = self.get_steps_per_epoch( - nb_pos_examples=train_data[train_data.int_label != 0].shape[0] - ) - - return mini_batches, test_data, steps_per_epoch - - def no_cv_load(self, full_df): - full_df = self._compute_int_labels(full_df) - - val_test = full_df[full_df.origin == "precision"].copy(deep=True) - val_data, test_data = self.get_outer_fold(df=val_test) - - train_data = full_df.drop(full_df[full_df.origin == "precision"].index, axis=0) - if test_data.shape[0] == 0: - test_data = train_data.iloc[:500] - - mini_batches = self.get_balanced_dataset(train_data) - if train_data.int_label.nunique() == 1: - raise ValueError('Should be at least two labels') - - num_examples = train_data[train_data.int_label == 1].shape[0] - if train_data.int_label.nunique() > 2: - second_most_frequent_label = train_data.loc[train_data.int_label != 0, 
'int_label'].mode().values[0] - num_examples = train_data[train_data.int_label == second_most_frequent_label].shape[0] * 2 - steps_per_epoch = self.get_steps_per_epoch(nb_pos_examples=num_examples) - - return mini_batches, steps_per_epoch, val_data, test_data diff --git a/trust_and_safety_models/toxicity/load_model.docx b/trust_and_safety_models/toxicity/load_model.docx new file mode 100644 index 000000000..b2f3623da Binary files /dev/null and b/trust_and_safety_models/toxicity/load_model.docx differ diff --git a/trust_and_safety_models/toxicity/load_model.py b/trust_and_safety_models/toxicity/load_model.py deleted file mode 100644 index 7b271066f..000000000 --- a/trust_and_safety_models/toxicity/load_model.py +++ /dev/null @@ -1,227 +0,0 @@ -import os - -from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH -try: - from toxicity_ml_pipeline.optim.losses import MaskedBCE -except ImportError: - print('No MaskedBCE loss') -from toxicity_ml_pipeline.utils.helpers import execute_command - -import tensorflow as tf - - -try: - from twitter.cuad.representation.models.text_encoder import TextEncoder -except ModuleNotFoundError: - print("No TextEncoder package") - -try: - from transformers import TFAutoModelForSequenceClassification -except ModuleNotFoundError: - print("No HuggingFace package") - -LOCAL_MODEL_DIR = os.path.join(LOCAL_DIR, "models") - - -def reload_model_weights(weights_dir, language, **kwargs): - optimizer = tf.keras.optimizers.Adam(0.01) - model_type = ( - "twitter_bert_base_en_uncased_mlm" - if language == "en" - else "twitter_multilingual_bert_base_cased_mlm" - ) - model = load(optimizer=optimizer, seed=42, model_type=model_type, **kwargs) - model.load_weights(weights_dir) - - return model - - -def _locally_copy_models(model_type): - if model_type == "twitter_multilingual_bert_base_cased_mlm": - preprocessor = "bert_multi_cased_preprocess_3" - elif model_type == "twitter_bert_base_en_uncased_mlm": - preprocessor = "bert_en_uncased_preprocess_3" - else: - raise NotImplementedError - - copy_cmd = """mkdir {local_dir} -gsutil cp -r ... 
-gsutil cp -r ...""" - execute_command( - copy_cmd.format(model_type=model_type, preprocessor=preprocessor, local_dir=LOCAL_MODEL_DIR) - ) - - return preprocessor - - -def load_encoder(model_type, trainable): - try: - model = TextEncoder( - max_seq_lengths=MAX_SEQ_LENGTH, - model_type=model_type, - cluster="gcp", - trainable=trainable, - enable_dynamic_shapes=True, - ) - except (OSError, tf.errors.AbortedError) as e: - print(e) - preprocessor = _locally_copy_models(model_type) - - model = TextEncoder( - max_seq_lengths=MAX_SEQ_LENGTH, - local_model_path=f"models/{model_type}", - local_preprocessor_path=f"models/{preprocessor}", - cluster="gcp", - trainable=trainable, - enable_dynamic_shapes=True, - ) - - return model - - -def get_loss(loss_name, from_logits, **kwargs): - loss_name = loss_name.lower() - if loss_name == "bce": - print("Binary CE loss") - return tf.keras.losses.BinaryCrossentropy(from_logits=from_logits) - - if loss_name == "cce": - print("Categorical cross-entropy loss") - return tf.keras.losses.CategoricalCrossentropy(from_logits=from_logits) - - if loss_name == "scce": - print("Sparse categorical cross-entropy loss") - return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits) - - if loss_name == "focal_bce": - gamma = kwargs.get("gamma", 2) - print("Focal binary CE loss", gamma) - return tf.keras.losses.BinaryFocalCrossentropy(gamma=gamma, from_logits=from_logits) - - if loss_name == 'masked_bce': - multitask = kwargs.get("multitask", False) - if from_logits or multitask: - raise NotImplementedError - print(f'Masked Binary Cross Entropy') - return MaskedBCE() - - if loss_name == "inv_kl_loss": - raise NotImplementedError - - raise ValueError( - f"This loss name is not valid: {loss_name}. Accepted loss names: BCE, masked BCE, CCE, sCCE, " - f"Focal_BCE, inv_KL_loss" - ) - -def _add_additional_embedding_layer(doc_embedding, glorot, seed): - doc_embedding = tf.keras.layers.Dense(768, activation="tanh", kernel_initializer=glorot)(doc_embedding) - doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding) - return doc_embedding - -def _get_bias(**kwargs): - smart_bias_value = kwargs.get('smart_bias_value', 0) - print('Smart bias init to ', smart_bias_value) - output_bias = tf.keras.initializers.Constant(smart_bias_value) - return output_bias - - -def load_inhouse_bert(model_type, trainable, seed, **kwargs): - inputs = tf.keras.layers.Input(shape=(), dtype=tf.string) - encoder = load_encoder(model_type=model_type, trainable=trainable) - doc_embedding = encoder([inputs])["pooled_output"] - doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding) - - glorot = tf.keras.initializers.glorot_uniform(seed=seed) - if kwargs.get("additional_layer", False): - doc_embedding = _add_additional_embedding_layer(doc_embedding, glorot, seed) - - if kwargs.get('content_num_classes', None): - probs = get_last_layer(glorot=glorot, last_layer_name='target_output', **kwargs)(doc_embedding) - second_probs = get_last_layer(num_classes=kwargs['content_num_classes'], - last_layer_name='content_output', - glorot=glorot)(doc_embedding) - probs = [probs, second_probs] - else: - probs = get_last_layer(glorot=glorot, **kwargs)(doc_embedding) - model = tf.keras.models.Model(inputs=inputs, outputs=probs) - - return model, False - -def get_last_layer(**kwargs): - output_bias = _get_bias(**kwargs) - if 'glorot' in kwargs: - glorot = kwargs['glorot'] - else: - glorot = tf.keras.initializers.glorot_uniform(seed=kwargs['seed']) - layer_name = 
kwargs.get('last_layer_name', 'dense_1') - - if kwargs.get('num_classes', 1) > 1: - last_layer = tf.keras.layers.Dense( - kwargs["num_classes"], activation="softmax", kernel_initializer=glorot, - bias_initializer=output_bias, name=layer_name - ) - - elif kwargs.get('num_raters', 1) > 1: - if kwargs.get('multitask', False): - raise NotImplementedError - last_layer = tf.keras.layers.Dense( - kwargs['num_raters'], activation="sigmoid", kernel_initializer=glorot, - bias_initializer=output_bias, name='probs') - - else: - last_layer = tf.keras.layers.Dense( - 1, activation="sigmoid", kernel_initializer=glorot, - bias_initializer=output_bias, name=layer_name - ) - - return last_layer - -def load_bertweet(**kwargs): - bert = TFAutoModelForSequenceClassification.from_pretrained( - os.path.join(LOCAL_MODEL_DIR, "bertweet-base"), - num_labels=1, - classifier_dropout=0.1, - hidden_size=768, - ) - if "num_classes" in kwargs and kwargs["num_classes"] > 2: - raise NotImplementedError - - return bert, True - - -def load( - optimizer, - seed, - model_type="twitter_multilingual_bert_base_cased_mlm", - loss_name="BCE", - trainable=True, - **kwargs, -): - if model_type == "bertweet-base": - model, from_logits = load_bertweet() - else: - model, from_logits = load_inhouse_bert(model_type, trainable, seed, **kwargs) - - pr_auc = tf.keras.metrics.AUC(curve="PR", name="pr_auc", from_logits=from_logits) - roc_auc = tf.keras.metrics.AUC(curve="ROC", name="roc_auc", from_logits=from_logits) - - loss = get_loss(loss_name, from_logits, **kwargs) - if kwargs.get('content_num_classes', None): - second_loss = get_loss(loss_name=kwargs['content_loss_name'], from_logits=from_logits) - loss_weights = {'content_output': kwargs['content_loss_weight'], 'target_output': 1} - model.compile( - optimizer=optimizer, - loss={'content_output': second_loss, 'target_output': loss}, - loss_weights=loss_weights, - metrics=[pr_auc, roc_auc], - ) - - else: - model.compile( - optimizer=optimizer, - loss=loss, - metrics=[pr_auc, roc_auc], - ) - print(model.summary(), "logits: ", from_logits) - - return model \ No newline at end of file diff --git a/trust_and_safety_models/toxicity/optim/__init__.docx b/trust_and_safety_models/toxicity/optim/__init__.docx new file mode 100644 index 000000000..cd6b29278 Binary files /dev/null and b/trust_and_safety_models/toxicity/optim/__init__.docx differ diff --git a/trust_and_safety_models/toxicity/optim/__init__.py b/trust_and_safety_models/toxicity/optim/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/trust_and_safety_models/toxicity/optim/callbacks.docx b/trust_and_safety_models/toxicity/optim/callbacks.docx new file mode 100644 index 000000000..90df1986b Binary files /dev/null and b/trust_and_safety_models/toxicity/optim/callbacks.docx differ diff --git a/trust_and_safety_models/toxicity/optim/callbacks.py b/trust_and_safety_models/toxicity/optim/callbacks.py deleted file mode 100644 index bbf8d7c97..000000000 --- a/trust_and_safety_models/toxicity/optim/callbacks.py +++ /dev/null @@ -1,220 +0,0 @@ -from collections import defaultdict -import os - -from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR -from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES -from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data -from toxicity_ml_pipeline.utils.helpers import compute_precision_fixed_recall, execute_command - -from sklearn.metrics import average_precision_score, roc_auc_score -import tensorflow as tf -import wandb - - 
-class NothingCallback(tf.keras.callbacks.Callback): - def on_epoch_begin(self, epoch, logs=None): - print("ici, ", epoch) - - def on_epoch_end(self, epoch, logs=None): - print("fin ", epoch) - - def on_train_batch_end(self, batch, logs=None): - print("fin de batch ", batch) - - -class ControlledStoppingCheckpointCallback(tf.keras.callbacks.ModelCheckpoint): - def __init__(self, stopping_epoch, *args, **kwargs): - super().__init__(*args, **kwargs) - self.stopping_epoch = stopping_epoch - - def on_epoch_end(self, epoch, logs=None): - super().on_epoch_end(epoch, logs) - if epoch == self.stopping_epoch: - self.model.stop_training = True - - -class SyncingTensorBoard(tf.keras.callbacks.TensorBoard): - def __init__(self, remote_logdir=None, *args, **kwargs): - super().__init__(*args, **kwargs) - self.remote_logdir = remote_logdir if remote_logdir is not None else REMOTE_LOGDIR - - def on_epoch_end(self, epoch, logs=None): - super().on_epoch_end(epoch, logs=logs) - self.synchronize() - - def synchronize(self): - base_dir = os.path.dirname(self.log_dir) - cmd = f"gsutil -m rsync -r {base_dir} {self.remote_logdir}" - execute_command(cmd) - - -class GradientLoggingTensorBoard(SyncingTensorBoard): - def __init__(self, loader, val_data, freq, *args, **kwargs): - super().__init__(*args, **kwargs) - val_dataset = loader.get_balanced_dataset( - training_data=val_data, size_limit=50, return_as_batch=False - ) - data_args = list(val_dataset.batch(32).take(1))[0] - self.x_batch, self.y_batch = data_args[0], data_args[1] - self.freq = freq - self.counter = 0 - - def _log_gradients(self): - writer = self._train_writer - - with writer.as_default(): - with tf.GradientTape() as tape: - y_pred = self.model(self.x_batch) - loss = self.model.compiled_loss(y_true=self.y_batch, y_pred=y_pred) - gradient_norm = tf.linalg.global_norm(tape.gradient(loss, self.model.trainable_weights)) - - tf.summary.scalar("gradient_norm", data=gradient_norm, step=self.counter) - writer.flush() - - def on_train_batch_end(self, batch, logs=None): - super().on_batch_end(batch, logs=logs) - self.counter += 1 - if batch % self.freq == 0: - self._log_gradients() - - -class AdditionalResultLogger(tf.keras.callbacks.Callback): - def __init__( - self, - data, - set_, - fixed_recall=0.85, - from_logits=False, - dataset_transform_func=None, - batch_size=64, - dual_head=None, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.set_ = set_ - if data is None: - return None - - self.single_head = True - try: - self.labels = data.int_label.values - except AttributeError: - self.labels = data.to_dataframe()[LABEL_NAMES].values.astype('int') - self.data = data.to_tf_dataset().map(parse_labeled_data).batch(batch_size) - self.label_names = LABEL_NAMES - else: - self.label_names = [''] - if dual_head: - self.label_names = [f'{e}_label' for e in dual_head] - self.labels = {f'{e}_output': data[f'{e}_label'].values for e in dual_head} - self.single_head = False - if dataset_transform_func is None: - self.data = data.text.values - else: - self.data = dataset_transform_func(data, mb_size=batch_size, shuffle=False) - - finally: - if len(self.label_names) == 1: - self.metric_kw = {} - else: - self.metric_kw = {'average': None} - - self.counter = 0 - self.best_metrics = defaultdict(float) - self.from_logits = from_logits - print(f"Loaded callback for {set_}, from_logits: {from_logits}, labels {self.label_names}") - - if 1 < fixed_recall <= 100: - fixed_recall = fixed_recall / 100 - elif not (0 < fixed_recall <= 100): - raise ValueError("Threshold 
should be between 0 and 1, or 0 and 100") - self.fixed_recall = fixed_recall - self.batch_size = batch_size - - def compute_precision_fixed_recall(self, labels, preds): - result, _ = compute_precision_fixed_recall(labels=labels, preds=preds, - fixed_recall=self.fixed_recall) - - return result - - def on_epoch_end(self, epoch, logs=None): - self.additional_evaluations(step=epoch, eval_time="epoch") - - def on_train_batch_end(self, batch, logs=None): - self.counter += 1 - if self.counter % 2000 == 0: - self.additional_evaluations(step=self.counter, eval_time="batch") - - def _binary_evaluations(self, preds, label_name=None, class_index=None): - mask = None - curr_labels = self.labels - if label_name is not None: - curr_labels = self.labels[label_name] - if class_index is not None: - curr_labels = (curr_labels == class_index).astype(int) - - if -1 in curr_labels: - mask = curr_labels != -1 - curr_labels = curr_labels[mask] - preds = preds[mask] - - return { - f"precision_recall{self.fixed_recall}": self.compute_precision_fixed_recall( - labels=curr_labels, preds=preds - ), - "pr_auc": average_precision_score(y_true=curr_labels, y_score=preds), - "roc_auc": roc_auc_score(y_true=curr_labels, y_score=preds), - } - - - def _multiclass_evaluations(self, preds): - pr_auc_l = average_precision_score(y_true=self.labels, y_score=preds, **self.metric_kw) - roc_auc_l = roc_auc_score(y_true=self.labels, y_score=preds, **self.metric_kw) - metrics = {} - for i, label in enumerate(self.label_names): - metrics[f'pr_auc_{label}'] = pr_auc_l[i] - metrics[f'roc_auc_{label}'] = roc_auc_l[i] - - return metrics - - def additional_evaluations(self, step, eval_time): - print("Evaluating ", self.set_, eval_time, step) - - preds = self.model.predict(x=self.data, batch_size=self.batch_size) - if self.from_logits: - preds = tf.keras.activations.sigmoid(preds.logits).numpy() - - if self.single_head: - if len(self.label_names) == 1: - metrics = self._binary_evaluations(preds) - else: - metrics = self._multiclass_evaluations(preds) - else: - if preds[0].shape[1] == 1: - binary_preds = preds[0] - multic_preds = preds[1] - else: - binary_preds = preds[1] - multic_preds = preds[0] - - binary_metrics = self._binary_evaluations(binary_preds, label_name='target_output') - metrics = {f'{k}_target': v for k, v in binary_metrics.items()} - num_classes = multic_preds.shape[1] - for class_ in range(num_classes): - binary_metrics = self._binary_evaluations(multic_preds[:, class_], label_name='content_output', class_index=class_) - metrics.update({f'{k}_content_{class_}': v for k, v in binary_metrics.items()}) - - for k, v in metrics.items(): - self.best_metrics[f"max_{k}"] = max(v, self.best_metrics[f"max_{k}"]) - - self.log_metrics(metrics, step=step, eval_time=eval_time) - - def log_metrics(self, metrics_d, step, eval_time): - commit = False if self.set_ == "validation" else True - to_report = {self.set_: {**metrics_d, **self.best_metrics}} - - if eval_time == "epoch": - to_report["epoch"] = step - - wandb.log(to_report, commit=commit) diff --git a/trust_and_safety_models/toxicity/optim/losses.docx b/trust_and_safety_models/toxicity/optim/losses.docx new file mode 100644 index 000000000..58265d10b Binary files /dev/null and b/trust_and_safety_models/toxicity/optim/losses.docx differ diff --git a/trust_and_safety_models/toxicity/optim/losses.py b/trust_and_safety_models/toxicity/optim/losses.py deleted file mode 100644 index 273c6676e..000000000 --- a/trust_and_safety_models/toxicity/optim/losses.py +++ /dev/null @@ -1,56 +0,0 @@ 
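Reviewer note: AdditionalResultLogger above reports precision at a fixed recall (0.85 by default) through compute_precision_fixed_recall, which lives in toxicity_ml_pipeline.utils.helpers and is not included in this diff. A minimal sketch of what that helper presumably returns (a precision/threshold pair), assuming sklearn is available and the evaluation set contains at least one positive label:

import numpy as np
from sklearn.metrics import precision_recall_curve


def precision_at_fixed_recall(labels, preds, fixed_recall=0.85):
    # precision_recall_curve orders points by increasing threshold, so recall is
    # non-increasing along the arrays; the last index still meeting the target
    # recall corresponds to the strictest threshold keeping recall >= fixed_recall.
    precision, recall, thresholds = precision_recall_curve(labels, preds)
    eligible = np.where(recall[:-1] >= fixed_recall)[0]
    idx = eligible[-1]
    return precision[idx], thresholds[idx]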
-import tensorflow as tf -from keras.utils import tf_utils -from keras.utils import losses_utils -from keras import backend - -def inv_kl_divergence(y_true, y_pred): - y_pred = tf.convert_to_tensor(y_pred) - y_true = tf.cast(y_true, y_pred.dtype) - y_true = backend.clip(y_true, backend.epsilon(), 1) - y_pred = backend.clip(y_pred, backend.epsilon(), 1) - return tf.reduce_sum(y_pred * tf.math.log(y_pred / y_true), axis=-1) - -def masked_bce(y_true, y_pred): - y_true = tf.cast(y_true, dtype=tf.float32) - mask = y_true != -1 - - return tf.keras.metrics.binary_crossentropy(tf.boolean_mask(y_true, mask), - tf.boolean_mask(y_pred, mask)) - - -class LossFunctionWrapper(tf.keras.losses.Loss): - def __init__(self, - fn, - reduction=losses_utils.ReductionV2.AUTO, - name=None, - **kwargs): - super().__init__(reduction=reduction, name=name) - self.fn = fn - self._fn_kwargs = kwargs - - def call(self, y_true, y_pred): - if tf.is_tensor(y_pred) and tf.is_tensor(y_true): - y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true) - - ag_fn = tf.__internal__.autograph.tf_convert(self.fn, tf.__internal__.autograph.control_status_ctx()) - return ag_fn(y_true, y_pred, **self._fn_kwargs) - - def get_config(self): - config = {} - for k, v in self._fn_kwargs.items(): - config[k] = backend.eval(v) if tf_utils.is_tensor_or_variable(v) else v - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) - -class InvKLD(LossFunctionWrapper): - def __init__(self, - reduction=losses_utils.ReductionV2.AUTO, - name='inv_kl_divergence'): - super().__init__(inv_kl_divergence, name=name, reduction=reduction) - - -class MaskedBCE(LossFunctionWrapper): - def __init__(self, - reduction=losses_utils.ReductionV2.AUTO, - name='masked_bce'): - super().__init__(masked_bce, name=name, reduction=reduction) diff --git a/trust_and_safety_models/toxicity/optim/schedulers.docx b/trust_and_safety_models/toxicity/optim/schedulers.docx new file mode 100644 index 000000000..45af2d43a Binary files /dev/null and b/trust_and_safety_models/toxicity/optim/schedulers.docx differ diff --git a/trust_and_safety_models/toxicity/optim/schedulers.py b/trust_and_safety_models/toxicity/optim/schedulers.py deleted file mode 100644 index 59f6c9afa..000000000 --- a/trust_and_safety_models/toxicity/optim/schedulers.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Callable - -import tensorflow as tf - - -class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): - def __init__( - self, - initial_learning_rate: float, - decay_schedule_fn: Callable, - warmup_steps: int, - power: float = 1.0, - name: str = "", - ): - super().__init__() - self.initial_learning_rate = initial_learning_rate - self.warmup_steps = warmup_steps - self.power = power - self.decay_schedule_fn = decay_schedule_fn - self.name = name - - def __call__(self, step): - with tf.name_scope(self.name or "WarmUp") as name: - global_step_float = tf.cast(step, tf.float32) - warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) - warmup_percent_done = global_step_float / warmup_steps_float - warmup_learning_rate = self.initial_learning_rate * tf.math.pow( - warmup_percent_done, self.power - ) - return tf.cond( - global_step_float < warmup_steps_float, - lambda: warmup_learning_rate, - lambda: self.decay_schedule_fn(step - self.warmup_steps), - name=name, - ) - - def get_config(self): - return { - "initial_learning_rate": self.initial_learning_rate, - "decay_schedule_fn": self.decay_schedule_fn, - "warmup_steps": 
self.warmup_steps, - "power": self.power, - "name": self.name, - } diff --git a/trust_and_safety_models/toxicity/rescoring.docx b/trust_and_safety_models/toxicity/rescoring.docx new file mode 100644 index 000000000..a0c3a97f8 Binary files /dev/null and b/trust_and_safety_models/toxicity/rescoring.docx differ diff --git a/trust_and_safety_models/toxicity/rescoring.py b/trust_and_safety_models/toxicity/rescoring.py deleted file mode 100644 index 71d95ed76..000000000 --- a/trust_and_safety_models/toxicity/rescoring.py +++ /dev/null @@ -1,54 +0,0 @@ -from toxicity_ml_pipeline.load_model import reload_model_weights -from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model - -import numpy as np -import tensorflow as tf - - -def score(language, df, gcs_model_path, batch_size=64, text_col="text", kw="", **kwargs): - if language != "en": - raise NotImplementedError( - "Data preprocessing not implemented here, needs to be added for i18n models" - ) - model_folder = upload_model(full_gcs_model_path=gcs_model_path) - try: - inference_func = load_inference_func(model_folder) - except OSError: - model = reload_model_weights(model_folder, language, **kwargs) - preds = model.predict(x=df[text_col], batch_size=batch_size) - if type(preds) != list: - if len(preds.shape)> 1 and preds.shape[1] > 1: - if 'num_classes' in kwargs and kwargs['num_classes'] > 1: - raise NotImplementedError - preds = np.mean(preds, 1) - - df[f"prediction_{kw}"] = preds - else: - if len(preds) > 2: - raise NotImplementedError - for preds_arr in preds: - if preds_arr.shape[1] == 1: - df[f"prediction_{kw}_target"] = preds_arr - else: - for ind in range(preds_arr.shape[1]): - df[f"prediction_{kw}_content_{ind}"] = preds_arr[:, ind] - - return df - else: - return _get_score(inference_func, df, kw=kw, batch_size=batch_size, text_col=text_col) - - -def _get_score(inference_func, df, text_col="text", kw="", batch_size=64): - score_col = f"prediction_{kw}" - beginning = 0 - end = df.shape[0] - predictions = np.zeros(shape=end, dtype=float) - - while beginning < end: - mb = df[text_col].values[beginning : beginning + batch_size] - res = inference_func(input_1=tf.constant(mb)) - predictions[beginning : beginning + batch_size] = list(res.values())[0].numpy()[:, 0] - beginning += batch_size - - df[score_col] = predictions - return df diff --git a/trust_and_safety_models/toxicity/settings/__init__.docx b/trust_and_safety_models/toxicity/settings/__init__.docx new file mode 100644 index 000000000..cd6b29278 Binary files /dev/null and b/trust_and_safety_models/toxicity/settings/__init__.docx differ diff --git a/trust_and_safety_models/toxicity/settings/__init__.py b/trust_and_safety_models/toxicity/settings/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/trust_and_safety_models/toxicity/settings/default_settings_tox.docx b/trust_and_safety_models/toxicity/settings/default_settings_tox.docx new file mode 100644 index 000000000..2de4b8c35 Binary files /dev/null and b/trust_and_safety_models/toxicity/settings/default_settings_tox.docx differ diff --git a/trust_and_safety_models/toxicity/settings/default_settings_tox.py b/trust_and_safety_models/toxicity/settings/default_settings_tox.py deleted file mode 100644 index 0968b9adc..000000000 --- a/trust_and_safety_models/toxicity/settings/default_settings_tox.py +++ /dev/null @@ -1,38 +0,0 @@ -import os - - -TEAM_PROJECT = "twttr-toxicity-prod" -try: - from google.cloud import bigquery -except (ModuleNotFoundError, ImportError): - print("No Google 
packages") - CLIENT = None -else: - from google.auth.exceptions import DefaultCredentialsError - - try: - CLIENT = bigquery.Client(project=TEAM_PROJECT) - except DefaultCredentialsError as e: - CLIENT = None - print("Issue at logging time", e) - -TRAINING_DATA_LOCATION = f"..." -GCS_ADDRESS = "..." -LOCAL_DIR = os.getcwd() -REMOTE_LOGDIR = "{GCS_ADDRESS}/logs" -MODEL_DIR = "{GCS_ADDRESS}/models" - -EXISTING_TASK_VERSIONS = {3, 3.5} - -RANDOM_SEED = ... -TRAIN_EPOCHS = 4 -MINI_BATCH_SIZE = 32 -TARGET_POS_PER_EPOCH = 5000 -PERC_TRAINING_TOX = ... -MAX_SEQ_LENGTH = 100 - -WARM_UP_PERC = 0.1 -OUTER_CV = 5 -INNER_CV = 5 -NUM_PREFETCH = 5 -NUM_WORKERS = 10 diff --git a/trust_and_safety_models/toxicity/train.docx b/trust_and_safety_models/toxicity/train.docx new file mode 100644 index 000000000..11a979339 Binary files /dev/null and b/trust_and_safety_models/toxicity/train.docx differ diff --git a/trust_and_safety_models/toxicity/train.py b/trust_and_safety_models/toxicity/train.py deleted file mode 100644 index de450ee7b..000000000 --- a/trust_and_safety_models/toxicity/train.py +++ /dev/null @@ -1,401 +0,0 @@ -from datetime import datetime -from importlib import import_module -import os - -from toxicity_ml_pipeline.data.data_preprocessing import ( - DefaultENNoPreprocessor, - DefaultENPreprocessor, -) -from toxicity_ml_pipeline.data.dataframe_loader import ENLoader, ENLoaderWithSampling -from toxicity_ml_pipeline.data.mb_generator import BalancedMiniBatchLoader -from toxicity_ml_pipeline.load_model import load, get_last_layer -from toxicity_ml_pipeline.optim.callbacks import ( - AdditionalResultLogger, - ControlledStoppingCheckpointCallback, - GradientLoggingTensorBoard, - SyncingTensorBoard, -) -from toxicity_ml_pipeline.optim.schedulers import WarmUp -from toxicity_ml_pipeline.settings.default_settings_abs import GCS_ADDRESS as ABS_GCS -from toxicity_ml_pipeline.settings.default_settings_tox import ( - GCS_ADDRESS as TOX_GCS, - MODEL_DIR, - RANDOM_SEED, - REMOTE_LOGDIR, - WARM_UP_PERC, -) -from toxicity_ml_pipeline.utils.helpers import check_gpu, set_seeds, upload_model - -import numpy as np -import tensorflow as tf - - -try: - from tensorflow_addons.optimizers import AdamW -except ModuleNotFoundError: - print("No TFA") - - -class Trainer(object): - OPTIMIZERS = ["Adam", "AdamW"] - - def __init__( - self, - optimizer_name, - weight_decay, - learning_rate, - mb_size, - train_epochs, - content_loss_weight=1, - language="en", - scope='TOX', - project=..., - experiment_id="default", - gradient_clipping=None, - fold="time", - seed=RANDOM_SEED, - log_gradients=False, - kw="", - stopping_epoch=None, - test=False, - ): - self.seed = seed - self.weight_decay = weight_decay - self.learning_rate = learning_rate - self.mb_size = mb_size - self.train_epochs = train_epochs - self.gradient_clipping = gradient_clipping - - if optimizer_name not in self.OPTIMIZERS: - raise ValueError( - f"Optimizer {optimizer_name} not implemented. Accepted values {self.OPTIMIZERS}." 
- ) - self.optimizer_name = optimizer_name - self.log_gradients = log_gradients - self.test = test - self.fold = fold - self.stopping_epoch = stopping_epoch - self.language = language - if scope == 'TOX': - GCS_ADDRESS = TOX_GCS.format(project=project) - elif scope == 'ABS': - GCS_ADDRESS = ABS_GCS - else: - raise ValueError - GCS_ADDRESS = GCS_ADDRESS.format(project=project) - try: - self.setting_file = import_module(f"toxicity_ml_pipeline.settings.{scope.lower()}{project}_settings") - except ModuleNotFoundError: - raise ValueError(f"You need to define a setting file for your project {project}.") - experiment_settings = self.setting_file.experiment_settings - - self.project = project - self.remote_logdir = REMOTE_LOGDIR.format(GCS_ADDRESS=GCS_ADDRESS, project=project) - self.model_dir = MODEL_DIR.format(GCS_ADDRESS=GCS_ADDRESS, project=project) - - if experiment_id not in experiment_settings: - raise ValueError("This is not an experiment id as defined in the settings file.") - - for var, default_value in experiment_settings["default"].items(): - override_val = experiment_settings[experiment_id].get(var, default_value) - print("Setting ", var, override_val) - self.__setattr__(var, override_val) - - self.content_loss_weight = content_loss_weight if self.dual_head else None - - self.mb_loader = BalancedMiniBatchLoader( - fold=self.fold, - seed=self.seed, - perc_training_tox=self.perc_training_tox, - mb_size=self.mb_size, - n_outer_splits="time", - scope=scope, - project=project, - dual_head=self.dual_head, - sample_weights=self.sample_weights, - huggingface=("bertweet" in self.model_type), - ) - self._init_dirnames(kw=kw, experiment_id=experiment_id) - print("------- Checking there is a GPU") - check_gpu() - - def _init_dirnames(self, kw, experiment_id): - kw = "test" if self.test else kw - hyper_param_kw = "" - if self.optimizer_name == "AdamW": - hyper_param_kw += f"{self.weight_decay}_" - if self.gradient_clipping: - hyper_param_kw += f"{self.gradient_clipping}_" - if self.content_loss_weight: - hyper_param_kw += f"{self.content_loss_weight}_" - experiment_name = ( - f"{self.language}{str(datetime.now()).replace(' ', '')[:-7]}{kw}_{experiment_id}{self.fold}_" - f"{self.optimizer_name}_" - f"{self.learning_rate}_" - f"{hyper_param_kw}" - f"{self.mb_size}_" - f"{self.perc_training_tox}_" - f"{self.train_epochs}_seed{self.seed}" - ) - print("------- Experiment name: ", experiment_name) - self.logdir = ( - f"..." - if self.test - else f"..." 
- ) - self.checkpoint_path = f"{self.model_dir}/{experiment_name}" - - @staticmethod - def _additional_writers(logdir, metric_name): - return tf.summary.create_file_writer(os.path.join(logdir, metric_name)) - - def get_callbacks(self, fold, val_data, test_data): - fold_logdir = self.logdir + f"_fold{fold}" - fold_checkpoint_path = self.checkpoint_path + f"_fold{fold}/{{epoch:02d}}" - - tb_args = { - "log_dir": fold_logdir, - "histogram_freq": 0, - "update_freq": 500, - "embeddings_freq": 0, - "remote_logdir": f"{self.remote_logdir}_{self.language}" - if not self.test - else f"{self.remote_logdir}_test", - } - tensorboard_callback = ( - GradientLoggingTensorBoard(loader=self.mb_loader, val_data=val_data, freq=10, **tb_args) - if self.log_gradients - else SyncingTensorBoard(**tb_args) - ) - - callbacks = [tensorboard_callback] - if "bertweet" in self.model_type: - from_logits = True - dataset_transform_func = self.mb_loader.make_huggingface_tensorflow_ds - else: - from_logits = False - dataset_transform_func = None - - fixed_recall = 0.85 if not self.dual_head else 0.5 - val_callback = AdditionalResultLogger( - data=val_data, - set_="validation", - from_logits=from_logits, - dataset_transform_func=dataset_transform_func, - dual_head=self.dual_head, - fixed_recall=fixed_recall - ) - if val_callback is not None: - callbacks.append(val_callback) - - test_callback = AdditionalResultLogger( - data=test_data, - set_="test", - from_logits=from_logits, - dataset_transform_func=dataset_transform_func, - dual_head=self.dual_head, - fixed_recall=fixed_recall - ) - callbacks.append(test_callback) - - checkpoint_args = { - "filepath": fold_checkpoint_path, - "verbose": 0, - "monitor": "val_pr_auc", - "save_weights_only": True, - "mode": "max", - "save_freq": "epoch", - } - if self.stopping_epoch: - checkpoint_callback = ControlledStoppingCheckpointCallback( - **checkpoint_args, - stopping_epoch=self.stopping_epoch, - save_best_only=False, - ) - callbacks.append(checkpoint_callback) - - return callbacks - - def get_lr_schedule(self, steps_per_epoch): - total_num_steps = steps_per_epoch * self.train_epochs - - warm_up_perc = WARM_UP_PERC if self.learning_rate >= 1e-3 else 0 - warm_up_steps = int(total_num_steps * warm_up_perc) - if self.linear_lr_decay: - learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( - self.learning_rate, - total_num_steps - warm_up_steps, - end_learning_rate=0.0, - power=1.0, - cycle=False, - ) - else: - print('Constant learning rate') - learning_rate_fn = self.learning_rate - - if warm_up_perc > 0: - print(f".... using warm-up for {warm_up_steps} steps") - warm_up_schedule = WarmUp( - initial_learning_rate=self.learning_rate, - decay_schedule_fn=learning_rate_fn, - warmup_steps=warm_up_steps, - ) - return warm_up_schedule - return learning_rate_fn - - def get_optimizer(self, schedule): - optim_args = { - "learning_rate": schedule, - "beta_1": 0.9, - "beta_2": 0.999, - "epsilon": 1e-6, - "amsgrad": False, - } - if self.gradient_clipping: - optim_args["global_clipnorm"] = self.gradient_clipping - - print(f".... 
{self.optimizer_name} w global clipnorm {self.gradient_clipping}") - if self.optimizer_name == "Adam": - return tf.keras.optimizers.Adam(**optim_args) - - if self.optimizer_name == "AdamW": - optim_args["weight_decay"] = self.weight_decay - return AdamW(**optim_args) - raise NotImplementedError - - def get_training_actors(self, steps_per_epoch, val_data, test_data, fold): - callbacks = self.get_callbacks(fold=fold, val_data=val_data, test_data=test_data) - schedule = self.get_lr_schedule(steps_per_epoch=steps_per_epoch) - - optimizer = self.get_optimizer(schedule) - - return optimizer, callbacks - - def load_data(self): - if self.project == 435 or self.project == 211: - if self.dataset_type is None: - data_loader = ENLoader(project=self.project, setting_file=self.setting_file) - dataset_type_args = {} - else: - data_loader = ENLoaderWithSampling(project=self.project, setting_file=self.setting_file) - dataset_type_args = self.dataset_type - - df = data_loader.load_data( - language=self.language, test=self.test, reload=self.dataset_reload, **dataset_type_args - ) - - return df - - def preprocess(self, df): - if self.project == 435 or self.project == 211: - if self.preprocessing is None: - data_prepro = DefaultENNoPreprocessor() - elif self.preprocessing == "default": - data_prepro = DefaultENPreprocessor() - else: - raise NotImplementedError - - return data_prepro( - df=df, - label_column=self.label_column, - class_weight=self.perc_training_tox if self.sample_weights == 'class_weight' else None, - filter_low_agreements=self.filter_low_agreements, - num_classes=self.num_classes, - ) - - def load_model(self, optimizer): - smart_bias_value = ( - np.log(self.perc_training_tox / (1 - self.perc_training_tox)) if self.smart_bias_init else 0 - ) - model = load( - optimizer, - seed=self.seed, - trainable=self.trainable, - model_type=self.model_type, - loss_name=self.loss_name, - num_classes=self.num_classes, - additional_layer=self.additional_layer, - smart_bias_value=smart_bias_value, - content_num_classes=self.content_num_classes, - content_loss_name=self.content_loss_name, - content_loss_weight=self.content_loss_weight - ) - - if self.model_reload is not False: - model_folder = upload_model(full_gcs_model_path=os.path.join(self.model_dir, self.model_reload)) - model.load_weights(model_folder) - if self.scratch_last_layer: - print('Putting the last layer back to scratch') - model.layers[-1] = get_last_layer(seed=self.seed, - num_classes=self.num_classes, - smart_bias_value=smart_bias_value) - - return model - - def _train_single_fold(self, mb_generator, test_data, steps_per_epoch, fold, val_data=None): - steps_per_epoch = 100 if self.test else steps_per_epoch - - optimizer, callbacks = self.get_training_actors( - steps_per_epoch=steps_per_epoch, val_data=val_data, test_data=test_data, fold=fold - ) - print("Loading model") - model = self.load_model(optimizer) - print(f"Nb of steps per epoch: {steps_per_epoch} ---- launching training") - training_args = { - "epochs": self.train_epochs, - "steps_per_epoch": steps_per_epoch, - "batch_size": self.mb_size, - "callbacks": callbacks, - "verbose": 2, - } - - model.fit(mb_generator, **training_args) - return - - def train_full_model(self): - print("Setting up random seed.") - set_seeds(self.seed) - - print(f"Loading {self.language} data") - df = self.load_data() - df = self.preprocess(df=df) - - print("Going to train on everything but the test dataset") - mini_batches, test_data, steps_per_epoch = self.mb_loader.simple_cv_load(df) - - 
self._train_single_fold( - mb_generator=mini_batches, test_data=test_data, steps_per_epoch=steps_per_epoch, fold="full" - ) - - def train(self): - print("Setting up random seed.") - set_seeds(self.seed) - - print(f"Loading {self.language} data") - df = self.load_data() - df = self.preprocess(df=df) - - print("Loading MB generator") - i = 0 - if self.project == 435 or self.project == 211: - mb_generator, steps_per_epoch, val_data, test_data = self.mb_loader.no_cv_load(full_df=df) - self._train_single_fold( - mb_generator=mb_generator, - val_data=val_data, - test_data=test_data, - steps_per_epoch=steps_per_epoch, - fold=i, - ) - else: - raise ValueError("Sure you want to do multiple fold training") - for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df): - self._train_single_fold( - mb_generator=mb_generator, - val_data=val_data, - test_data=test_data, - steps_per_epoch=steps_per_epoch, - fold=i, - ) - i += 1 - if i == 3: - break diff --git a/trust_and_safety_models/toxicity/utils/__init__.docx b/trust_and_safety_models/toxicity/utils/__init__.docx new file mode 100644 index 000000000..cd6b29278 Binary files /dev/null and b/trust_and_safety_models/toxicity/utils/__init__.docx differ diff --git a/trust_and_safety_models/toxicity/utils/__init__.py b/trust_and_safety_models/toxicity/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/trust_and_safety_models/toxicity/utils/helpers.docx b/trust_and_safety_models/toxicity/utils/helpers.docx new file mode 100644 index 000000000..17da02a72 Binary files /dev/null and b/trust_and_safety_models/toxicity/utils/helpers.docx differ diff --git a/trust_and_safety_models/toxicity/utils/helpers.py b/trust_and_safety_models/toxicity/utils/helpers.py deleted file mode 100644 index c21d7eb1c..000000000 --- a/trust_and_safety_models/toxicity/utils/helpers.py +++ /dev/null @@ -1,99 +0,0 @@ -import bisect -import os -import random as python_random -import subprocess - -from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR - -import numpy as np -from sklearn.metrics import precision_recall_curve - - -try: - import tensorflow as tf -except ModuleNotFoundError: - pass - - -def upload_model(full_gcs_model_path): - folder_name = full_gcs_model_path - if folder_name[:5] != "gs://": - folder_name = "gs://" + folder_name - - dirname = os.path.dirname(folder_name) - epoch = os.path.basename(folder_name) - - model_dir = os.path.join(LOCAL_DIR, "models") - cmd = f"mkdir {model_dir}" - try: - execute_command(cmd) - except subprocess.CalledProcessError: - pass - model_dir = os.path.join(model_dir, os.path.basename(dirname)) - cmd = f"mkdir {model_dir}" - try: - execute_command(cmd) - except subprocess.CalledProcessError: - pass - - try: - _ = int(epoch) - except ValueError: - cmd = f"gsutil rsync -r '{folder_name}' {model_dir}" - weights_dir = model_dir - - else: - cmd = f"gsutil cp '{dirname}/checkpoint' {model_dir}/" - execute_command(cmd) - cmd = f"gsutil cp '{os.path.join(dirname, epoch)}*' {model_dir}/" - weights_dir = f"{model_dir}/{epoch}" - - execute_command(cmd) - return weights_dir - -def compute_precision_fixed_recall(labels, preds, fixed_recall): - precision_values, recall_values, thresholds = precision_recall_curve(y_true=labels, probas_pred=preds) - index_recall = bisect.bisect_left(-recall_values, -1 * fixed_recall) - result = precision_values[index_recall - 1] - print(f"Precision at {recall_values[index_recall-1]} recall: {result}") - - return result, thresholds[index_recall - 1] - 
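For reference, a small standalone sketch of how the compute_precision_fixed_recall helper above behaves. The labels and scores are made up; the sketch simply repeats the helper's precision_recall_curve-plus-bisect trick outside the pipeline to show how the precision and threshold at a fixed recall are read off.

    import bisect
    import numpy as np
    from sklearn.metrics import precision_recall_curve

    labels = np.array([1, 0, 1, 1, 0, 0, 1, 0])
    preds = np.array([0.92, 0.60, 0.81, 0.35, 0.40, 0.15, 0.70, 0.05])

    precision_values, recall_values, thresholds = precision_recall_curve(
        y_true=labels, probas_pred=preds
    )
    # recall_values is non-increasing, so bisect on its negation (non-decreasing),
    # exactly as the helper above does.
    fixed_recall = 0.75
    index_recall = bisect.bisect_left(-recall_values, -fixed_recall)
    precision = precision_values[index_recall - 1]
    threshold = thresholds[index_recall - 1]
    print(f"precision={precision:.3f} at recall={recall_values[index_recall - 1]:.3f}, "
          f"threshold={threshold:.3f}")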
-def load_inference_func(model_folder): - model = tf.saved_model.load(model_folder, ["serve"]) - inference_func = model.signatures["serving_default"] - return inference_func - - -def execute_query(client, query): - job = client.query(query) - df = job.result().to_dataframe() - return df - - -def execute_command(cmd, print_=True): - s = subprocess.run(cmd, shell=True, capture_output=print_, check=True) - if print_: - print(s.stderr.decode("utf-8")) - print(s.stdout.decode("utf-8")) - - -def check_gpu(): - try: - execute_command("nvidia-smi") - except subprocess.CalledProcessError: - print("There is no GPU when there should be one.") - raise AttributeError - - l = tf.config.list_physical_devices("GPU") - if len(l) == 0: - raise ModuleNotFoundError("Tensorflow has not found the GPU. Check your installation") - print(l) - - -def set_seeds(seed): - np.random.seed(seed) - - python_random.seed(seed) - - tf.random.set_seed(seed) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/AdditionalFields.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/AdditionalFields.docx new file mode 100644 index 000000000..695f9b973 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/AdditionalFields.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/AdditionalFields.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/AdditionalFields.scala deleted file mode 100644 index 91e06e4c6..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/AdditionalFields.scala +++ /dev/null @@ -1,118 +0,0 @@ -package com.twitter.tweetypie.additionalfields - -import com.twitter.tweetypie.thriftscala.Tweet -import com.twitter.scrooge.TFieldBlob -import com.twitter.scrooge.ThriftStructField - -object AdditionalFields { - type FieldId = Short - - /** additional fields really start at 100, be we are ignoring conversation id for now */ - val StartAdditionalId = 101 - - /** all known [[Tweet]] field IDs */ - val CompiledFieldIds: Seq[FieldId] = Tweet.metaData.fields.map(_.id) - - /** all known [[Tweet]] fields in the "additional-field" range (excludes id) */ - val CompiledAdditionalFieldMetaDatas: Seq[ThriftStructField[Tweet]] = - Tweet.metaData.fields.filter(f => isAdditionalFieldId(f.id)) - - val CompiledAdditionalFieldsMap: Map[Short, ThriftStructField[Tweet]] = - CompiledAdditionalFieldMetaDatas.map(field => (field.id, field)).toMap - - /** all known [[Tweet]] field IDs in the "additional-field" range */ - val CompiledAdditionalFieldIds: Seq[FieldId] = - CompiledAdditionalFieldsMap.keys.toSeq - - /** all [[Tweet]] field IDs which should be rejected when set as additional - * fields on via PostTweetRequest.additionalFields or RetweetRequest.additionalFields */ - val RejectedFieldIds: Seq[FieldId] = Seq( - // Should be provided via PostTweetRequest.conversationControl field. go/convocontrolsbackend - Tweet.ConversationControlField.id, - // This field should only be set based on whether the client sets the right community - // tweet annotation. - Tweet.CommunitiesField.id, - // This field should not be set by clients and should opt for - // [[PostTweetRequest.ExclusiveTweetControlOptions]]. - // The exclusiveTweetControl field requires the userId to be set - // and we shouldn't trust the client to provide the right one. 
- Tweet.ExclusiveTweetControlField.id, - // This field should not be set by clients and should opt for - // [[PostTweetRequest.TrustedFriendsControlOptions]]. - // The trustedFriendsControl field requires the trustedFriendsListId to be - // set and we shouldn't trust the client to provide the right one. - Tweet.TrustedFriendsControlField.id, - // This field should not be set by clients and should opt for - // [[PostTweetRequest.CollabControlOptions]]. - // The collabControl field requires a list of Collaborators to be - // set and we shouldn't trust the client to provide the right one. - Tweet.CollabControlField.id - ) - - def isAdditionalFieldId(fieldId: FieldId): Boolean = - fieldId >= StartAdditionalId - - /** - * Provides a list of all additional field IDs on the tweet, which include all - * the compiled additional fields and all the provided passthrough fields. This includes - * compiled additional fields where the value is None. - */ - def allAdditionalFieldIds(tweet: Tweet): Seq[FieldId] = - CompiledAdditionalFieldIds ++ tweet._passthroughFields.keys - - /** - * Provides a list of all field IDs that have a value on the tweet which are not known compiled - * additional fields (excludes [[Tweet.id]]). - */ - def unsettableAdditionalFieldIds(tweet: Tweet): Seq[FieldId] = - CompiledFieldIds - .filter { id => - !isAdditionalFieldId(id) && id != Tweet.IdField.id && tweet.getFieldBlob(id).isDefined - } ++ - tweet._passthroughFields.keys - - /** - * Provides a list of all field IDs that have a value on the tweet which are explicitly disallowed - * from being set via PostTweetRequest.additionalFields and RetweetRequest.additionalFields - */ - def rejectedAdditionalFieldIds(tweet: Tweet): Seq[FieldId] = - RejectedFieldIds - .filter { id => tweet.getFieldBlob(id).isDefined } - - def unsettableAdditionalFieldIdsErrorMessage(unsettableFieldIds: Seq[FieldId]): String = - s"request may not contain fields: [${unsettableFieldIds.sorted.mkString(", ")}]" - - /** - * Provides a list of all additional field IDs that have a value on the tweet, - * compiled and passthrough (excludes Tweet.id). - */ - def nonEmptyAdditionalFieldIds(tweet: Tweet): Seq[FieldId] = - CompiledAdditionalFieldMetaDatas.collect { - case f if f.getValue(tweet) != None => f.id - } ++ tweet._passthroughFields.keys - - def additionalFields(tweet: Tweet): Seq[TFieldBlob] = - (tweet.getFieldBlobs(CompiledAdditionalFieldIds) ++ tweet._passthroughFields).values.toSeq - - /** - * Merge base tweet with additional fields. - * Non-additional fields in the additional tweet are ignored. - * @param base: a tweet that contains basic fields - * @param additional: a tweet object that carries additional fields - */ - def setAdditionalFields(base: Tweet, additional: Tweet): Tweet = - setAdditionalFields(base, additionalFields(additional)) - - def setAdditionalFields(base: Tweet, additional: Option[Tweet]): Tweet = - additional.map(setAdditionalFields(base, _)).getOrElse(base) - - def setAdditionalFields(base: Tweet, additional: Traversable[TFieldBlob]): Tweet = - additional.foldLeft(base) { case (t, f) => t.setField(f) } - - /** - * Unsets the specified fields on the given tweet. 
- */ - def unsetFields(tweet: Tweet, fieldIds: Iterable[FieldId]): Tweet = { - tweet.unsetFields(fieldIds.toSet) - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/BUILD deleted file mode 100644 index 472135458..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -scala_library( - sources = ["*.scala"], - compiler_option_sets = ["fatal_warnings"], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/org/apache/thrift:libthrift", - "mediaservices/commons/src/main/thrift:thrift-scala", - "scrooge/scrooge-core", - "src/thrift/com/twitter/escherbird:media-annotation-structs-scala", - "src/thrift/com/twitter/spam/rtf:safety-label-scala", - "tweetypie/common/src/thrift/com/twitter/tweetypie:tweet-scala", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/BUILD.docx new file mode 100644 index 000000000..957d0c76f Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/BUILD deleted file mode 100644 index 3e9bc82d8..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "finagle/finagle-memcached/src/main/scala", - "scrooge/scrooge-serializer", - "stitch/stitch-core", - "util/util-core", - "util/util-logging", - # CachedValue struct - "tweetypie/servo/repo/src/main/thrift:thrift-scala", - "util/util-slf4j-api/src/main/scala/com/twitter/util/logging", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/BUILD.docx new file mode 100644 index 000000000..9b583440f Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheOperations.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheOperations.docx new file mode 100644 index 000000000..b570d9ea2 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheOperations.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheOperations.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheOperations.scala deleted file mode 100644 index 816162fad..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheOperations.scala +++ /dev/null @@ -1,241 +0,0 @@ -package com.twitter.tweetypie.caching - -import com.twitter.finagle.service.StatsFilter -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.finagle.stats.ExceptionStatsHandler -import com.twitter.finagle.stats.Counter -import com.twitter.util.Future -import com.twitter.util.logging.Logger -import com.twitter.finagle.memcached -import scala.util.control.NonFatal - -/** - * Wrapper around a memcached client that performs serialization and - * deserialization, tracks stats, provides tracing, and provides - * per-key fresh/stale/failure/miss results. 
- * - * The operations that write values to cache will only write values - * that the ValueSerializer says are cacheable. The idea here is that - * the deserialize and serialize functions must be coherent, and no - * matter how you choose to write these values back to cache, the - * serializer will have the appropriate knowledge about whether the - * values are cacheable. - * - * For most cases, you will want to use [[StitchCaching]] rather than - * calling this wrapper directly. - * - * @param keySerializer How to convert a K value to a memcached key. - * - * @param valueSerializer How to serialize and deserialize V values, - * as well as which values are cacheable, and how long to store the - * values in cache. - */ -class CacheOperations[K, V]( - keySerializer: K => String, - valueSerializer: ValueSerializer[V], - memcachedClient: memcached.Client, - statsReceiver: StatsReceiver, - logger: Logger, - exceptionStatsHandler: ExceptionStatsHandler = StatsFilter.DefaultExceptions) { - // The memcached operations that are performed via this - // [[CacheOperations]] instance will be tracked under this stats - // receiver. - // - // We count all memcached failures together under this scope, - // because memcached operations should not fail unless there are - // communication problems, so differentiating the method that was - // being called will not give us any useful information. - private[this] val memcachedStats: StatsReceiver = statsReceiver.scope("memcached") - - // Incremented for every attempt to `get` a key from cache. - private[this] val memcachedGetCounter: Counter = memcachedStats.counter("get") - - // One of these two counters is incremented for every successful - // response returned from a `get` call to memcached. - private[this] val memcachedNotFoundCounter: Counter = memcachedStats.counter("not_found") - private[this] val memcachedFoundCounter: Counter = memcachedStats.counter("found") - - // Records the state of the cache load after serialization. The - // policy may transform a value that was successfully loaded from - // cache into any result type, which is why we explicitly track - // "found" and "not_found" above. If `stale` + `fresh` is not equal - // to `found`, then it means that the policy has translated a found - // value into a miss or failure. The policy may do this in order to - // cause the caching filter to treat the value that was found in - // cache in the way it would have treated a miss or failure from - // cache. - private[this] val resultStats: StatsReceiver = statsReceiver.scope("result") - private[this] val resultFreshCounter: Counter = resultStats.counter("fresh") - private[this] val resultStaleCounter: Counter = resultStats.counter("stale") - private[this] val resultMissCounter: Counter = resultStats.counter("miss") - private[this] val resultFailureCounter: Counter = resultStats.counter("failure") - - // Used for recording exceptions that occurred during - // deserialization. This will never be incremented if the - // deserializer returns a result, even if the result is a - // [[CacheResult.Failure]]. See the comment where this stat is - // incremented for more details. - private[this] val deserializeFailureStats: StatsReceiver = statsReceiver.scope("deserialize") - - private[this] val notSerializedCounter: Counter = statsReceiver.counter("not_serialized") - - /** - * Load a batch of values from cache. Mostly this deals with - * converting the [[memcached.GetResult]] to a - * [[Seq[CachedResult[V]]]]. 
The result is in the same order as the - * keys, and there will always be an entry for each key. This method - * should never return a [[Future.exception]]. - */ - def get(keys: Seq[K]): Future[Seq[CacheResult[V]]] = { - memcachedGetCounter.incr(keys.size) - val cacheKeys: Seq[String] = keys.map(keySerializer) - if (logger.isTraceEnabled) { - logger.trace { - val lines: Seq[String] = keys.zip(cacheKeys).map { case (k, c) => s"\n $k ($c)" } - "Starting load for keys:" + lines.mkString - } - } - - memcachedClient - .getResult(cacheKeys) - .map { getResult => - memcachedNotFoundCounter.incr(getResult.misses.size) - val results: Seq[CacheResult[V]] = - cacheKeys.map { cacheKey => - val result: CacheResult[V] = - getResult.hits.get(cacheKey) match { - case Some(memcachedValue) => - memcachedFoundCounter.incr() - try { - valueSerializer.deserialize(memcachedValue.value) - } catch { - case NonFatal(e) => - // If the serializer throws an exception, then - // the serialized value was malformed. In that - // case, we record the failure so that it can be - // detected and fixed, but treat it as a cache - // miss. The reason that we treat it as a miss - // rather than a failure is that a miss will - // cause a write back to cache, and we want to - // write a valid result back to cache to replace - // the bad entry that we just loaded. - // - // A serializer is free to return Miss itself to - // obtain this behavior if it is expected or - // desired, to avoid the logging and stats (and - // the minor overhead of catching an exception). - // - // The exceptions are tracked separately from - // other exceptions so that it is easy to see - // whether the deserializer itself ever throws an - // exception. - exceptionStatsHandler.record(deserializeFailureStats, e) - logger.warn(s"Failed deserializing value for cache key $cacheKey", e) - CacheResult.Miss - } - - case None if getResult.misses.contains(cacheKey) => - CacheResult.Miss - - case None => - val exception = - getResult.failures.get(cacheKey) match { - case None => - // To get here, this was not a hit or a miss, - // so we expect the key to be present in - // failures. If it is not, then either the - // contract of getResult was violated, or this - // method is somehow attempting to access a - // result for a key that was not - // loaded. Either of these indicates a bug, so - // we log a high priority log message. - logger.error( - s"Key $cacheKey not found in hits, misses or failures. " + - "This indicates a bug in the memcached library or " + - "CacheOperations.load" - ) - // We return this as a failure because that - // will cause the repo to be consulted and the - // value *not* to be written back to cache, - // which is probably the safest thing to do - // (if we don't know what's going on, default - // to an uncached repo). - new IllegalStateException - - case Some(e) => - e - } - exceptionStatsHandler.record(memcachedStats, exception) - CacheResult.Failure(exception) - } - - // Count each kind of CacheResult, to make it possible to - // see how effective the caching is. 
- result match { - case CacheResult.Fresh(_) => resultFreshCounter.incr() - case CacheResult.Stale(_) => resultStaleCounter.incr() - case CacheResult.Miss => resultMissCounter.incr() - case CacheResult.Failure(_) => resultFailureCounter.incr() - } - - result - } - - if (logger.isTraceEnabled) { - logger.trace { - val lines: Seq[String] = - (keys, cacheKeys, results).zipped.map { - case (key, cacheKey, result) => s"\n $key ($cacheKey) -> $result" - } - - "Cache results:" + lines.mkString - } - } - - results - } - .handle { - case e => - // If there is a failure from the memcached client, fan it - // out to each cache key, so that the caller does not need - // to handle failure of the batch differently than failure - // of individual keys. This should be rare anyway, since the - // memcached client already does this for common Finagle - // exceptions - resultFailureCounter.incr(keys.size) - val theFailure: CacheResult[V] = CacheResult.Failure(e) - keys.map { _ => - // Record this as many times as we would if it were in the GetResult - exceptionStatsHandler.record(memcachedStats, e) - theFailure - } - } - } - - // Incremented for every attempt to `set` a key in value. - private[this] val memcachedSetCounter: Counter = memcachedStats.counter("set") - - /** - * Write an entry back to cache, using `set`. If the serializer does - * not serialize the value, then this method will immediately return - * with success. - */ - def set(key: K, value: V): Future[Unit] = - valueSerializer.serialize(value) match { - case Some((expiry, serialized)) => - if (logger.isTraceEnabled) { - logger.trace(s"Writing back to cache $key -> $value (expiry = $expiry)") - } - memcachedSetCounter.incr() - memcachedClient - .set(key = keySerializer(key), flags = 0, expiry = expiry, value = serialized) - .onFailure(exceptionStatsHandler.record(memcachedStats, _)) - - case None => - if (logger.isTraceEnabled) { - logger.trace(s"Not writing back $key -> $value") - } - notSerializedCounter.incr() - Future.Done - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheResult.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheResult.docx new file mode 100644 index 000000000..5b1f8314a Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheResult.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheResult.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheResult.scala deleted file mode 100644 index c6e9500e7..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/CacheResult.scala +++ /dev/null @@ -1,45 +0,0 @@ -package com.twitter.tweetypie.caching - -/** - * Encodes the possible states of a value loaded from memcached. - * - * @see [[ValueSerializer]] and [[CacheOperations]] - */ -sealed trait CacheResult[+V] - -object CacheResult { - - /** - * Signals that the value could not be successfully loaded from - * cache. `Failure` values should not be written back to cache. - * - * This value may result from an error talking to the memcached - * instance or it may be returned from the Serializer when the value - * should not be reused, but should also not be overwritten. - */ - final case class Failure(e: Throwable) extends CacheResult[Nothing] - - /** - * Signals that the cache load attempt was successful, but there was - * not a usable value. - * - * When processing a `Miss`, the value should be written back to - * cache if it loads successfully. 
- */ - case object Miss extends CacheResult[Nothing] - - /** - * Signals that the value was found in cache. - * - * It is not necessary to load the value from the original source. - */ - case class Fresh[V](value: V) extends CacheResult[V] - - /** - * Signals that the value was found in cache. - * - * This value should be used, but it should be refreshed - * out-of-band. - */ - case class Stale[V](value: V) extends CacheResult[V] -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/Expiry.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/Expiry.docx new file mode 100644 index 000000000..433c714b4 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/Expiry.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/Expiry.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/Expiry.scala deleted file mode 100644 index 1f2a743c1..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/Expiry.scala +++ /dev/null @@ -1,34 +0,0 @@ -package com.twitter.tweetypie.caching - -import com.twitter.util.Duration -import com.twitter.util.Time - -/** - * Helpers for creating common expiry functions. - * - * An expiry function maps from the value to a time in the future when - * the value should expire from cache. These are useful in the - * implementation of a [[ValueSerializer]]. - */ -object Expiry { - - /** - * Return a time that indicates to memcached to never expire this - * value. - * - * This function takes [[Any]] so that it can be used at any value - * type, since it doesn't examine the value at all. - */ - val Never: Any => Time = - _ => Time.Top - - /** - * Return function that indicates to memcached that the value should - * not be used after the `ttl` has elapsed. - * - * This function takes [[Any]] so that it can be used at any value - * type, since it doesn't examine the value at all. - */ - def byAge(ttl: Duration): Any => Time = - _ => Time.now + ttl -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ServoCachedValueSerializer.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ServoCachedValueSerializer.docx new file mode 100644 index 000000000..862565a81 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ServoCachedValueSerializer.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ServoCachedValueSerializer.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ServoCachedValueSerializer.scala deleted file mode 100644 index 37aaa2216..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ServoCachedValueSerializer.scala +++ /dev/null @@ -1,140 +0,0 @@ -package com.twitter.tweetypie.caching - -import com.twitter.io.Buf -import com.twitter.scrooge.CompactThriftSerializer -import com.twitter.scrooge.ThriftStruct -import com.twitter.scrooge.ThriftStructCodec -import com.twitter.servo.cache.thriftscala.CachedValue -import com.twitter.servo.cache.thriftscala.CachedValueStatus -import com.twitter.stitch.NotFound -import com.twitter.util.Return -import com.twitter.util.Throw -import com.twitter.util.Time -import com.twitter.util.Try -import java.nio.ByteBuffer - -object ServoCachedValueSerializer { - - /** - * Thrown when the fields of the servo CachedValue struct do not - * satisfy the invariants expected by this serialization code. 
- */ - case class UnexpectedCachedValueState(cachedValue: CachedValue) extends Exception { - def message: String = s"Unexpected state for CachedValue. Value was: $cachedValue" - } - - val CachedValueThriftSerializer: CompactThriftSerializer[CachedValue] = CompactThriftSerializer( - CachedValue) -} - -/** - * A [[ValueSerializer]] that is compatible with the use of - * Servo's [[CachedValue]] struct by tweetypie: - * - * - The only [[CachedValueStatus]] values that are cacheable are - * [[CachedValueStatus.Found]] and [[CachedValueStatus.NotFound]]. - * - * - We only track the `cachedAtMsec` field, because tweetypie's cache - * interaction does not use the other fields, and the values that - * are cached this way are never updated, so storing readThroughAt - * or writtenThroughAt would not add any information. - * - * - When values are present, they are serialized using - * [[org.apache.thrift.protocol.TCompactProtocol]]. - * - * - The CachedValue struct itself is also serialized using TCompactProtocol. - * - * The serializer operates on [[Try]] values and will cache [[Return]] - * and `Throw(NotFound)` values. - */ -case class ServoCachedValueSerializer[V <: ThriftStruct]( - codec: ThriftStructCodec[V], - expiry: Try[V] => Time, - softTtl: SoftTtl[Try[V]]) - extends ValueSerializer[Try[V]] { - import ServoCachedValueSerializer.UnexpectedCachedValueState - import ServoCachedValueSerializer.CachedValueThriftSerializer - - private[this] val ValueThriftSerializer = CompactThriftSerializer(codec) - - /** - * Return an expiry based on the value and a - * TCompactProtocol-encoded servo CachedValue struct with the - * following fields defined: - * - * - `value`: [[None]] - * for {{{Throw(NotFound)}}, {{{Some(encodedStruct)}}} for - * [[Return]], where {{{encodedStruct}}} is a - * TCompactProtocol-encoding of the value inside of the Return. - * - * - `status`: [[CachedValueStatus.Found]] if the value is Return, - * and [[CachedValueStatus.NotFound]] if it is Throw(NotFound) - * - * - `cachedAtMsec`: The current time, accoring to [[Time.now]] - * - * No other fields will be defined. - * - * @throws IllegalArgumentException if called with a value that - * should not be cached. - */ - override def serialize(value: Try[V]): Option[(Time, Buf)] = { - def serializeCachedValue(payload: Option[ByteBuffer]) = { - val cachedValue = CachedValue( - value = payload, - status = if (payload.isDefined) CachedValueStatus.Found else CachedValueStatus.NotFound, - cachedAtMsec = Time.now.inMilliseconds) - - val serialized = Buf.ByteArray.Owned(CachedValueThriftSerializer.toBytes(cachedValue)) - - (expiry(value), serialized) - } - - value match { - case Throw(NotFound) => - Some(serializeCachedValue(None)) - case Return(struct) => - val payload = Some(ByteBuffer.wrap(ValueThriftSerializer.toBytes(struct))) - Some(serializeCachedValue(payload)) - case _ => - None - } - } - - /** - * Deserializes values serialized by [[serializeValue]]. The - * value will be [[CacheResult.Fresh]] or [[CacheResult.Stale]] - * depending on the result of {{{softTtl.isFresh}}}. 
- * - * @throws UnexpectedCachedValueState if the state of the - * [[CachedValue]] could not be produced by [[serialize]] - */ - override def deserialize(buf: Buf): CacheResult[Try[V]] = { - val cachedValue = CachedValueThriftSerializer.fromBytes(Buf.ByteArray.Owned.extract(buf)) - val hasValue = cachedValue.value.isDefined - val isValid = - (hasValue && cachedValue.status == CachedValueStatus.Found) || - (!hasValue && cachedValue.status == CachedValueStatus.NotFound) - - if (!isValid) { - // Exceptions thrown by deserialization are recorded and treated - // as a cache miss by CacheOperations, so throwing this - // exception will cause the value in cache to be - // overwritten. There will be stats recorded whenever this - // happens. - throw UnexpectedCachedValueState(cachedValue) - } - - val value = - cachedValue.value match { - case Some(valueBuffer) => - val valueBytes = new Array[Byte](valueBuffer.remaining) - valueBuffer.duplicate.get(valueBytes) - Return(ValueThriftSerializer.fromBytes(valueBytes)) - - case None => - Throw(NotFound) - } - - softTtl.toCacheResult(value, Time.fromMilliseconds(cachedValue.cachedAtMsec)) - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/SoftTtl.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/SoftTtl.docx new file mode 100644 index 000000000..d5d65fe18 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/SoftTtl.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/SoftTtl.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/SoftTtl.scala deleted file mode 100644 index ad2237924..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/SoftTtl.scala +++ /dev/null @@ -1,120 +0,0 @@ -package com.twitter.tweetypie.caching - -import com.twitter.util.Duration -import com.twitter.util.Time -import scala.util.Random -import com.twitter.logging.Logger - -/** - * Used to determine whether values successfully retrieved from cache - * are [[CacheResult.Fresh]] or [[CacheResult.Stale]]. This is useful - * in the implementation of a [[ValueSerializer]]. - */ -trait SoftTtl[-V] { - - /** - * Determines whether a cached value was fresh. - * - * @param cachedAt the time at which the value was cached. - */ - def isFresh(value: V, cachedAt: Time): Boolean - - /** - * Wraps the value in Fresh or Stale depending on the value of `isFresh`. - * - * (The type variable U exists because it is not allowed to return - * values of a contravariant type, so we must define a variable that - * is a specific subclass of V. This is worth it because it allows - * us to create polymorphic policies without having to specify the - * type. Another solution would be to make the type invariant, but - * then we would have to specify the type whenever we create an - * instance.) - */ - def toCacheResult[U <: V](value: U, cachedAt: Time): CacheResult[U] = - if (isFresh(value, cachedAt)) CacheResult.Fresh(value) else CacheResult.Stale(value) -} - -object SoftTtl { - - /** - * Regardless of the inputs, the value will always be considered - * fresh. - */ - object NeverRefresh extends SoftTtl[Any] { - override def isFresh(_unusedValue: Any, _unusedCachedAt: Time): Boolean = true - } - - /** - * Trigger refresh based on the length of time that a value has been - * stored in cache, ignoring the value. - * - * @param softTtl Items that were cached longer ago than this value - * will be refreshed when they are accessed. 
- * - * @param jitter Add nondeterminism to the soft TTL to prevent a - * thundering herd of requests refreshing the value at the same - * time. The time at which the value is considered stale will be - * uniformly spread out over a range of +/- (jitter/2). It is - * valid to set the jitter to zero, which will turn off jittering. - * - * @param logger If non-null, use this logger rather than one based - * on the class name. This logger is only used for trace-level - * logging. - */ - case class ByAge[V]( - softTtl: Duration, - jitter: Duration, - specificLogger: Logger = null, - rng: Random = Random) - extends SoftTtl[Any] { - - private[this] val logger: Logger = - if (specificLogger == null) Logger(getClass) else specificLogger - - private[this] val maxJitterMs: Long = jitter.inMilliseconds - - // this requirement is due to using Random.nextInt to choose the - // jitter, but it allows jitter of greater than 24 days - require(maxJitterMs <= (Int.MaxValue / 2)) - - // Negative jitter probably indicates misuse of the API - require(maxJitterMs >= 0) - - // we want period +/- jitter, but the random generator - // generates non-negative numbers, so we generate [0, 2 * - // maxJitter) and subtract maxJitter to obtain [-maxJitter, - // maxJitter) - private[this] val maxJitterRangeMs: Int = (maxJitterMs * 2).toInt - - // We perform all calculations in milliseconds, so convert the - // period to milliseconds out here. - private[this] val softTtlMs: Long = softTtl.inMilliseconds - - // If the value is below this age, it will always be fresh, - // regardless of jitter. - private[this] val alwaysFreshAgeMs: Long = softTtlMs - maxJitterMs - - // If the value is above this age, it will always be stale, - // regardless of jitter. - private[this] val alwaysStaleAgeMs: Long = softTtlMs + maxJitterMs - - override def isFresh(value: Any, cachedAt: Time): Boolean = { - val ageMs: Long = (Time.now - cachedAt).inMilliseconds - val fresh = - if (ageMs <= alwaysFreshAgeMs) { - true - } else if (ageMs > alwaysStaleAgeMs) { - false - } else { - val jitterMs: Long = rng.nextInt(maxJitterRangeMs) - maxJitterMs - ageMs <= softTtlMs + jitterMs - } - - logger.ifTrace( - s"Checked soft ttl: fresh = $fresh, " + - s"soft_ttl_ms = $softTtlMs, age_ms = $ageMs, value = $value") - - fresh - } - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchAsync.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchAsync.docx new file mode 100644 index 000000000..eff511c1f Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchAsync.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchAsync.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchAsync.scala deleted file mode 100644 index 45861f04c..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchAsync.scala +++ /dev/null @@ -1,65 +0,0 @@ -package com.twitter.tweetypie.caching - -import scala.collection.mutable -import com.twitter.util.Future -import com.twitter.stitch.Stitch -import com.twitter.stitch.Runner -import com.twitter.stitch.FutureRunner -import com.twitter.stitch.Group - -/** - * Workaround for a infelicity in the implementation of [[Stitch.async]]. - * - * This has the same semantics to [[Stitch.async]], with the exception - * that interrupts to the main computation will not interrupt the - * async call. 
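To make the jitter arithmetic in SoftTtl.ByAge above concrete, here is a minimal sketch; the demo object is not part of the codebase and assumes the caching package above is on the classpath. With a 10-minute soft TTL and 1 minute of jitter, anything younger than 9 minutes is always fresh, anything older than 11 minutes is always stale, and ages in between are decided by the random draw.

    package com.twitter.tweetypie.caching

    import com.twitter.util.Duration
    import com.twitter.util.Time

    object SoftTtlByAgeDemo {
      def main(args: Array[String]): Unit = {
        // 10-minute soft TTL with +/- 1 minute of jitter: refreshes of a hot key
        // are spread out instead of all firing at the same instant.
        val policy = SoftTtl.ByAge(
          softTtl = Duration.fromSeconds(600),
          jitter = Duration.fromSeconds(60))

        println(policy.isFresh((), cachedAt = Time.now - Duration.fromSeconds(120))) // true
        println(policy.isFresh((), cachedAt = Time.now - Duration.fromSeconds(720))) // false

        // toCacheResult wraps the same check in Fresh/Stale for use by a serializer.
        println(policy.toCacheResult("cached-value", cachedAt = Time.now)) // Fresh(cached-value)
      }
    }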
- * - * The problem that this implementation solves is that we do not want - * async calls grouped together with synchronous calls. See the - * mailing list thread [1] for discussion. This may eventually be - * fixed in Stitch. - */ -private[caching] object StitchAsync { - // Contains a deferred Stitch that we want to run asynchronously - private[this] class AsyncCall(deferred: => Stitch[_]) { - def call(): Stitch[_] = deferred - } - - private object AsyncGroup extends Group[AsyncCall, Unit] { - override def runner(): Runner[AsyncCall, Unit] = - new FutureRunner[AsyncCall, Unit] { - // All of the deferred calls of any type. When they are - // executed in `run`, the normal Stitch batching and deduping - // will occur. - private[this] val calls = new mutable.ArrayBuffer[AsyncCall] - - def add(call: AsyncCall): Stitch[Unit] = { - // Just remember the deferred call. - calls.append(call) - - // Since we don't wait for the completion of the effect, - // just return a constant value. - Stitch.Unit - } - - def run(): Future[_] = { - // The future returned from this innter invocation of - // Stitch.run is not linked to the returned future, so these - // effects are not linked to the outer Run in which this - // method was invoked. - Stitch.run { - Stitch.traverse(calls) { asyncCall: AsyncCall => - asyncCall - .call() - .liftToTry // So that an exception will not interrupt the other calls - } - } - Future.Unit - } - } - } - - def apply(call: => Stitch[_]): Stitch[Unit] = - // Group together all of the async calls - Stitch.call(new AsyncCall(call), AsyncGroup) -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCacheOperations.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCacheOperations.docx new file mode 100644 index 000000000..6a3a55ba4 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCacheOperations.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCacheOperations.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCacheOperations.scala deleted file mode 100644 index 8c9de67ff..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCacheOperations.scala +++ /dev/null @@ -1,62 +0,0 @@ -package com.twitter.tweetypie.caching - -import com.twitter.stitch.MapGroup -import com.twitter.stitch.Group -import com.twitter.stitch.Stitch -import com.twitter.util.Future -import com.twitter.util.Return -import com.twitter.util.Try - -/** - * Wrapper around [[CacheOperations]] providing a [[Stitch]] API. - */ -case class StitchCacheOperations[K, V](operations: CacheOperations[K, V]) { - import StitchCacheOperations.SetCall - - private[this] val getGroup: Group[K, CacheResult[V]] = - MapGroup[K, CacheResult[V]] { keys: Seq[K] => - operations - .get(keys) - .map(values => keys.zip(values).toMap.mapValues(Return(_))) - } - - def get(key: K): Stitch[CacheResult[V]] = - Stitch.call(key, getGroup) - - private[this] val setGroup: Group[SetCall[K, V], Unit] = - new MapGroup[SetCall[K, V], Unit] { - - override def run(calls: Seq[SetCall[K, V]]): Future[SetCall[K, V] => Try[Unit]] = - Future - .collectToTry(calls.map(call => operations.set(call.key, call.value))) - .map(tries => calls.zip(tries).toMap) - } - - /** - * Performs a [[CacheOperations.set]]. - */ - def set(key: K, value: V): Stitch[Unit] = - // This is implemented as a Stitch.call instead of a Stitch.future - // in order to handle the case where a batch has a duplicate - // key. 
Each copy of the duplicate key will trigger a write back - // to cache, so we dedupe the writes in order to avoid the - // extraneous RPC call. - Stitch.call(new StitchCacheOperations.SetCall(key, value), setGroup) -} - -object StitchCacheOperations { - - /** - * Used as the "call" for [[SetGroup]]. This is essentially a tuple - * where equality is defined only by the key. - */ - private class SetCall[K, V](val key: K, val value: V) { - override def equals(other: Any): Boolean = - other match { - case setCall: SetCall[_, _] => key == setCall.key - case _ => false - } - - override def hashCode: Int = key.hashCode - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCaching.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCaching.docx new file mode 100644 index 000000000..015853e12 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCaching.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCaching.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCaching.scala deleted file mode 100644 index 830bd11a2..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/StitchCaching.scala +++ /dev/null @@ -1,36 +0,0 @@ -package com.twitter.tweetypie.caching - -import com.twitter.stitch.Stitch - -/** - * Apply caching to a [[Stitch]] function. - * - * @see CacheResult for more information about the semantics - * implemented here. - */ -class StitchCaching[K, V](operations: CacheOperations[K, V], repo: K => Stitch[V]) - extends (K => Stitch[V]) { - - private[this] val stitchOps = new StitchCacheOperations(operations) - - override def apply(key: K): Stitch[V] = - stitchOps.get(key).flatMap { - case CacheResult.Fresh(value) => - Stitch.value(value) - - case CacheResult.Stale(staleValue) => - StitchAsync(repo(key).flatMap(refreshed => stitchOps.set(key, refreshed))) - .map(_ => staleValue) - - case CacheResult.Miss => - repo(key) - .applyEffect(value => StitchAsync(stitchOps.set(key, value))) - - case CacheResult.Failure(_) => - // In the case of failure, we don't attempt to write back to - // cache, because cache failure usually means communication - // failure, and sending more requests to the cache that holds - // the value for this key could make the situation worse. - repo(key) - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ValueSerializer.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ValueSerializer.docx new file mode 100644 index 000000000..7b65a986c Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ValueSerializer.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ValueSerializer.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ValueSerializer.scala deleted file mode 100644 index 42335d0ff..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/caching/ValueSerializer.scala +++ /dev/null @@ -1,47 +0,0 @@ -package com.twitter.tweetypie.caching - -import com.twitter.io.Buf -import com.twitter.util.Time - -/** - * How to store values of type V in cache. This includes whether a - * given value is cacheable, how to serialize it, when it should - * expire from cache, and how to interpret byte patterns from cache. - */ -trait ValueSerializer[V] { - - /** - * Prepare the value for storage in cache. 
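A rough sketch of how the pieces of this caching package compose: a ValueSerializer decides what is cacheable and for how long (via Expiry), CacheOperations talks to memcached and classifies each key as fresh/stale/miss/failure, and StitchCaching turns that into a read-through wrapper around a Stitch repo. The UTF-8 serializer, key scheme, and object name below are illustrative only, not code from the original tree.

    package com.twitter.tweetypie.caching

    import com.twitter.finagle.memcached
    import com.twitter.finagle.stats.NullStatsReceiver
    import com.twitter.io.Buf
    import com.twitter.stitch.Stitch
    import com.twitter.util.Duration
    import com.twitter.util.Time
    import com.twitter.util.logging.Logger

    object CachingWiringSketch {

      // Toy serializer: store Strings as UTF-8, expire after ten minutes, and
      // treat every cache hit as fresh (no soft TTL / stale handling).
      object Utf8ValueSerializer extends ValueSerializer[String] {
        private[this] val expiry: Any => Time = Expiry.byAge(Duration.fromSeconds(600))

        def serialize(value: String): Option[(Time, Buf)] =
          Some((expiry(value), Buf.Utf8(value)))

        def deserialize(serializedValue: Buf): CacheResult[String] =
          serializedValue match {
            case Buf.Utf8(str) => CacheResult.Fresh(str)
          }
      }

      // Wrap a Stitch repo with read-through caching: fresh hits are returned
      // directly, misses are written back, stale hits trigger an async refresh.
      def cachedRepo(
        client: memcached.Client,
        repo: Long => Stitch[String]
      ): Long => Stitch[String] = {
        val operations = new CacheOperations[Long, String](
          keySerializer = (id: Long) => s"sketch:$id", // made-up key scheme
          valueSerializer = Utf8ValueSerializer,
          memcachedClient = client,
          statsReceiver = NullStatsReceiver,
          logger = Logger(getClass)
        )
        new StitchCaching(operations, repo)
      }
    }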
When a [[Some]] is - * returned, the [[Buf]] should be a valid input to [[deserialize]] - * and the [[Time]] will be used as the expiry in the memcached - * command. When [[None]] is returned, it indicates that the value - * cannot or should not be written back to cache. - * - * The most common use case for returning None is caching Try - * values, where certain exceptional values encode a cacheable state - * of a value. In particular, Throw(NotFound) is commonly used to - * encode a missing value, and we usually want to cache those - * negative lookups, but we don't want to cache e.g. a timeout - * exception. - * - * @return a pair of expiry time for this cache entry and the bytes - * to store in cache. If you do not want this value to explicitly - * expire, use Time.Top as the expiry. - */ - def serialize(value: V): Option[(Time, Buf)] - - /** - * Deserialize a value found in cache. This function converts the - * bytes found in memcache to a [[CacheResult]]. In general, you - * probably want to return [[CacheResult.Fresh]] or - * [[CacheResult.Stale]], but you are free to return any of the - * range of [[CacheResult]]s, depending on the behavior that you - * want. - * - * This is a total function because in the common use case, the - * bytes stored in cache will be appropriate for the - * serializer. This method is free to throw any exception if the - * bytes are not valid. - */ - def deserialize(serializedValue: Buf): CacheResult[V] -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/BUILD deleted file mode 100644 index c29029d8c..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -scala_library( - sources = ["*.scala"], - compiler_option_sets = ["fatal_warnings"], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/authentication", - "finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/transport", - "finagle/finagle-thrift/src/main/scala", - "tweetypie/servo/util/src/main/scala:exception", - "strato/src/main/scala/com/twitter/strato/access", - "strato/src/main/scala/com/twitter/strato/data", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/BUILD.docx new file mode 100644 index 000000000..f27f6a589 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/ClientIdHelper.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/ClientIdHelper.docx new file mode 100644 index 000000000..e47046d2c Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/ClientIdHelper.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/ClientIdHelper.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/ClientIdHelper.scala deleted file mode 100644 index 8741ca80d..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/client_id/ClientIdHelper.scala +++ /dev/null @@ -1,185 +0,0 @@ -package com.twitter.tweetypie.client_id - -import com.twitter.finagle.mtls.authentication.EmptyServiceIdentifier -import com.twitter.finagle.mtls.authentication.ServiceIdentifier -import com.twitter.finagle.mtls.transport.S2STransport -import 
com.twitter.finagle.thrift.ClientId -import com.twitter.servo.util.Gate -import com.twitter.strato.access.Access -import com.twitter.strato.access.Access.ForwardedServiceIdentifier - -object ClientIdHelper { - - val UnknownClientId = "unknown" - - def default: ClientIdHelper = new ClientIdHelper(UseTransportServiceIdentifier) - - /** - * Trims off the last .element, which is usually .prod or .staging - */ - def getClientIdRoot(clientId: String): String = - clientId.lastIndexOf('.') match { - case -1 => clientId - case idx => clientId.substring(0, idx) - } - - /** - * Returns the last .element without the '.' - */ - def getClientIdEnv(clientId: String): String = - clientId.lastIndexOf('.') match { - case -1 => clientId - case idx => clientId.substring(idx + 1) - } - - private[client_id] def asClientId(s: ServiceIdentifier): String = s"${s.service}.${s.environment}" -} - -class ClientIdHelper(serviceIdentifierStrategy: ServiceIdentifierStrategy) { - - private[client_id] val ProcessPathPrefix = "/p/" - - /** - * The effective client id is used for request authorization and metrics - * attribution. For calls to Tweetypie's thrift API, the thrift ClientId - * is used and is expected in the form of "service-name.env". Federated - * Strato clients don't support configured ClientIds and instead provide - * a "process path" containing instance-specific information. So for - * calls to the federated API, we compute an effective client id from - * the ServiceIdentifier, if present, in Strato's Access principles. The - * implementation avoids computing this identifier unless necessary, - * since this method is invoked on every request. - */ - def effectiveClientId: Option[String] = { - val clientId: Option[String] = ClientId.current.map(_.name) - clientId - // Exclude process paths because they are instance-specific and aren't - // supported by tweetypie for authorization or metrics purposes. - .filterNot(_.startsWith(ProcessPathPrefix)) - // Try computing a value from the ServiceId if the thrift - // ClientId is undefined or unsupported. - .orElse(serviceIdentifierStrategy.serviceIdentifier.map(ClientIdHelper.asClientId)) - // Ultimately fall back to the ClientId value, even when given an - // unsupported format, so that error text and debug logs include - // the value passed by the caller. - .orElse(clientId) - } - - def effectiveClientIdRoot: Option[String] = effectiveClientId.map(ClientIdHelper.getClientIdRoot) - - def effectiveServiceIdentifier: Option[ServiceIdentifier] = - serviceIdentifierStrategy.serviceIdentifier -} - -/** Logic how to find a [[ServiceIdentifier]] for the purpose of crafting a client ID. */ -trait ServiceIdentifierStrategy { - def serviceIdentifier: Option[ServiceIdentifier] - - /** - * Returns the only element of given [[Set]] or [[None]]. - * - * This utility is used defensively against a set of principals collected - * from [[Access.getPrincipals]]. While the contract is that there should be at most one - * instance of each principal kind present in that set, in practice that has not been the case - * always. The safest strategy to in that case is to abandon a set completely if more than - * one principals are competing. - */ - final protected def onlyElement[T](set: Set[T]): Option[T] = - if (set.size <= 1) { - set.headOption - } else { - None - } -} - -/** - * Picks [[ServiceIdentifier]] from Finagle SSL Transport, if one exists. - * - * This works for both Thrift API calls as well as StratoFed API calls. 
Strato's - * [[Access#getPrincipals]] collection, which would typically be consulted by StratoFed - * column logic, contains the same [[ServiceIdentifier]] derived from the Finagle SSL - * transport, so there's no need to have separate strategies for Thrift vs StratoFed - * calls. - * - * This is the default behavior of using [[ServiceIdentifier]] for computing client ID. - */ -private[client_id] class UseTransportServiceIdentifier( - // overridable for testing - getPeerServiceIdentifier: => ServiceIdentifier, -) extends ServiceIdentifierStrategy { - override def serviceIdentifier: Option[ServiceIdentifier] = - getPeerServiceIdentifier match { - case EmptyServiceIdentifier => None - case si => Some(si) - } -} - -object UseTransportServiceIdentifier - extends UseTransportServiceIdentifier(S2STransport.peerServiceIdentifier) - -/** - * Picks [[ForwardedServiceIdentifier]] from Strato principals for client ID - * if [[ServiceIdentifier]] points at call coming from Strato. - * If not present, falls back to [[UseTransportServiceIdentifier]] behavior. - * - * Tweetypie utilizes the strategy to pick [[ServiceIdentifier]] for the purpose - * of generating a client ID when the client ID is absent or unknown. - * [[PreferForwardedServiceIdentifierForStrato]] looks for the [[ForwardedServiceIdentifier]] - * values set by stratoserver request. - * The reason is, stratoserver is effectively a conduit, forwarding the [[ServiceIdentifier]] - * of the _actual client_ that is calling stratoserver. - * Any direct callers not going through stratoserver will default to [[ServiceIdentfier]]. - */ -private[client_id] class PreferForwardedServiceIdentifierForStrato( - // overridable for testing - getPeerServiceIdentifier: => ServiceIdentifier, -) extends ServiceIdentifierStrategy { - val useTransportServiceIdentifier = - new UseTransportServiceIdentifier(getPeerServiceIdentifier) - - override def serviceIdentifier: Option[ServiceIdentifier] = - useTransportServiceIdentifier.serviceIdentifier match { - case Some(serviceIdentifier) if isStrato(serviceIdentifier) => - onlyElement( - Access.getPrincipals - .collect { - case forwarded: ForwardedServiceIdentifier => - forwarded.serviceIdentifier.serviceIdentifier - } - ).orElse(useTransportServiceIdentifier.serviceIdentifier) - case other => other - } - - /** - * Strato uses various service names like "stratoserver" and "stratoserver-patient". - * They all do start with "stratoserver" though, so at the point of implementing, - * the safest bet to recognize strato is to look for this prefix. - * - * This also works for staged strato instances (which it should), despite allowing - * for technically any caller to force this strategy, by creating service certificate - * with this service name. - */ - private def isStrato(serviceIdentifier: ServiceIdentifier): Boolean = - serviceIdentifier.service.startsWith("stratoserver") -} - -object PreferForwardedServiceIdentifierForStrato - extends PreferForwardedServiceIdentifierForStrato(S2STransport.peerServiceIdentifier) - -/** - * [[ServiceIdentifierStrategy]] which dispatches between two delegates based on the value - * of a unitary decider every time [[serviceIdentifier]] is called. 
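The resolution order implemented by effectiveClientId above (prefer the thrift ClientId unless it is an instance-specific process path, then fall back to "service.environment" derived from the peer's ServiceIdentifier, and finally to the raw ClientId so error text still shows what the caller sent) can be shown as a standalone sketch. The types below are simplified stand-ins rather than the Finagle/Strato ones, and the example process path is purely illustrative.

object EffectiveClientIdSketch {
  // Simplified stand-in for the mTLS service identity.
  case class ServiceIdentifier(service: String, environment: String)

  private val ProcessPathPrefix = "/p/"

  private def asClientId(s: ServiceIdentifier): String = s"${s.service}.${s.environment}"

  def effectiveClientId(
    thriftClientId: Option[String],
    peerService: Option[ServiceIdentifier]
  ): Option[String] =
    thriftClientId
      // Process paths are instance-specific, so don't use them directly.
      .filterNot(_.startsWith(ProcessPathPrefix))
      // Derive "service.env" from the peer's service identity instead.
      .orElse(peerService.map(asClientId))
      // Last resort: keep whatever the caller sent, for error text and logs.
      .orElse(thriftClientId)

  def main(args: Array[String]): Unit = {
    println(effectiveClientId(Some("tweetypie.prod"), None)) // Some(tweetypie.prod)
    println(effectiveClientId(Some("/p/role/svc/env/0"), Some(ServiceIdentifier("birdherd", "prod")))) // Some(birdherd.prod)
    println(effectiveClientId(Some("/p/role/svc/env/0"), None)) // Some(/p/role/svc/env/0)
  }
}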
- */ -class ConditionalServiceIdentifierStrategy( - private val condition: Gate[Unit], - private val ifTrue: ServiceIdentifierStrategy, - private val ifFalse: ServiceIdentifierStrategy) - extends ServiceIdentifierStrategy { - - override def serviceIdentifier: Option[ServiceIdentifier] = - if (condition()) { - ifTrue.serviceIdentifier - } else { - ifFalse.serviceIdentifier - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/context/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/context/BUILD deleted file mode 100644 index 30cef76c5..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/context/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -scala_library( - sources = ["*.scala"], - compiler_option_sets = ["fatal_warnings"], - platform = "java8", - provides = scala_artifact( - org = "com.twitter.tweetypie", - name = "context", - repo = artifactory, - ), - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "finagle/finagle-core/src/main", - "graphql/common/src/main/scala/com/twitter/graphql/common/core", - "src/thrift/com/twitter/context:twitter-context-scala", - "twitter-context/src/main/scala", - "util/util-core:scala", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/context/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/context/BUILD.docx new file mode 100644 index 000000000..b562e215c Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/context/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/context/TweetypieContext.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/context/TweetypieContext.docx new file mode 100644 index 000000000..7eb915bff Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/context/TweetypieContext.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/context/TweetypieContext.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/context/TweetypieContext.scala deleted file mode 100644 index 4d987a02c..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/context/TweetypieContext.scala +++ /dev/null @@ -1,135 +0,0 @@ -package com.twitter.tweetypie.context - -import com.twitter.context.TwitterContext -import com.twitter.finagle.Filter -import com.twitter.finagle.Service -import com.twitter.finagle.SimpleFilter -import com.twitter.finagle.context.Contexts -import com.twitter.io.Buf -import com.twitter.io.Buf.ByteArray.Owned -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.graphql.common.core.GraphQlClientApplication -import com.twitter.util.Try -import java.nio.charset.StandardCharsets.UTF_8 -import scala.util.matching.Regex - -/** - * Context and filters to help track callers of Tweetypie's endpoints. This context and its - * filters were originally added to provide visibility into callers of Tweetypie who are - * using the birdherd library to access tweets. - * - * This context data is intended to be marshalled by callers to Tweetypie, but then the - * context data is stripped (moved from broadcast to local). This happens so that the - * context data is not forwarded down tweetypie's backend rpc chains, which often result - * in transitive calls back into tweetypie. This effectively creates single-hop marshalling. 
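The broadcast key defined below marshals the context as nothing more than the via string in UTF-8, and unmarshalling is the reverse. A standalone sketch of that round trip, using a plain byte array in place of Finagle's Buf (CtxMarshallingSketch is a hypothetical name):

import java.nio.charset.StandardCharsets.UTF_8
import scala.util.Try

object CtxMarshallingSketch {
  case class Ctx(via: String)

  // Mirrors Key.marshal: the wire form is just the via string as UTF-8 bytes.
  def marshal(ctx: Ctx): Array[Byte] = ctx.via.getBytes(UTF_8)

  // Mirrors Key.tryUnmarshal: decode the bytes back into a Ctx.
  def tryUnmarshal(bytes: Array[Byte]): Try[Ctx] = Try(Ctx(new String(bytes, UTF_8)))

  def main(args: Array[String]): Unit = {
    val roundTripped = tryUnmarshal(marshal(Ctx("birdherd:/1.1/statuses/show/123.json")))
    println(roundTripped) // Success(Ctx(birdherd:/1.1/statuses/show/123.json))
  }
}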
- */ -object TweetypieContext { - // Bring Tweetypie permitted TwitterContext into scope - val TwitterContext: TwitterContext = - com.twitter.context.TwitterContext(com.twitter.tweetypie.TwitterContextPermit) - - case class Ctx(via: String) - val Empty = Ctx("") - - object Broadcast { - private[this] object Key extends Contexts.broadcast.Key[Ctx](id = Ctx.getClass.getName) { - - override def marshal(value: Ctx): Buf = - Owned(value.via.getBytes(UTF_8)) - - override def tryUnmarshal(buf: Buf): Try[Ctx] = - Try(Ctx(new String(Owned.extract(buf), UTF_8))) - } - - private[TweetypieContext] def current(): Option[Ctx] = - Contexts.broadcast.get(Key) - - def currentOrElse(default: Ctx): Ctx = - current().getOrElse(default) - - def letClear[T](f: => T): T = - Contexts.broadcast.letClear(Key)(f) - - def let[T](ctx: Ctx)(f: => T): T = - if (Empty == ctx) { - letClear(f) - } else { - Contexts.broadcast.let(Key, ctx)(f) - } - - // ctx has to be by name so we can re-evaluate it for every request (for usage in ServiceTwitter.scala) - def filter(ctx: => Ctx): Filter.TypeAgnostic = - new Filter.TypeAgnostic { - override def toFilter[Req, Rep]: Filter[Req, Rep, Req, Rep] = - (request: Req, service: Service[Req, Rep]) => Broadcast.let(ctx)(service(request)) - } - } - - object Local { - private[this] val Key = - new Contexts.local.Key[Ctx] - - private[TweetypieContext] def let[T](ctx: Option[Ctx])(f: => T): T = - ctx match { - case Some(ctx) if ctx != Empty => Contexts.local.let(Key, ctx)(f) - case None => Contexts.local.letClear(Key)(f) - } - - def current(): Option[Ctx] = - Contexts.local.get(Key) - - def filter[Req, Rep]: SimpleFilter[Req, Rep] = - (request: Req, service: Service[Req, Rep]) => { - val ctx = Broadcast.current() - Broadcast.letClear(Local.let(ctx)(service(request))) - } - - private[this] def clientAppIdToName(clientAppId: Long) = - GraphQlClientApplication.AllById.get(clientAppId).map(_.name).getOrElse("nonTOO") - - private[this] val pathRegexes: Seq[(Regex, String)] = Seq( - ("timeline_conversation_.*_json".r, "timeline_conversation__slug__json"), - ("user_timeline_.*_json".r, "user_timeline__user__json"), - ("[0-9]{2,}".r, "_id_") - ) - - // `context.via` will either be a string like: "birdherd" or "birdherd:/1.1/statuses/show/123.json, - // depending on whether birdherd code was able to determine the path of the request. 
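The via/path handling described above boils down to splitting on the first ':', replacing '/' and '.' with '_', and then collapsing volatile path segments (slugs, screen names, long digit runs) with a small list of regexes so the resulting stats scopes stay low-cardinality. A standalone sketch of that logic, using the same regexes as the pathRegexes defined just below:

import scala.util.matching.Regex

object ViaPathSketch {
  private val pathRegexes: Seq[(Regex, String)] = Seq(
    ("timeline_conversation_.*_json".r, "timeline_conversation__slug__json"),
    ("user_timeline_.*_json".r, "user_timeline__user__json"),
    ("[0-9]{2,}".r, "_id_")
  )

  /** Split "via[:path]" and normalize the path for stats scoping. */
  def getViaAndPath(via: String): (String, Option[String]) =
    via.split(":", 2) match {
      case Array(v, path) =>
        val sanitized = path.replace('/', '_').replace('.', '_')
        val normalized = pathRegexes.foldLeft(sanitized) {
          case (p, (regex, replacement)) => regex.replaceAllIn(p, replacement)
        }
        (v, Some(normalized))
      case Array(v) => (v, None)
    }

  def main(args: Array[String]): Unit = {
    println(getViaAndPath("birdherd"))
    // (birdherd,None)
    println(getViaAndPath("birdherd:/1.1/statuses/show/123.json"))
    // (birdherd,Some(_1_1_statuses_show__id__json))
  }
}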
- private[this] def getViaAndPath(via: String): (String, Option[String]) = - via.split(":", 2) match { - case Array(via, path) => - val sanitizedPath = path - .replace('/', '_') - .replace('.', '_') - - // Apply each regex in turn - val normalizedPath = pathRegexes.foldLeft(sanitizedPath) { - case (path, (regex, replacement)) => regex.replaceAllIn(path, replacement) - } - - (via, Some(normalizedPath)) - case Array(via) => (via, None) - } - - def trackStats[U](scopes: StatsReceiver*): Unit = - for { - tweetypieCtx <- TweetypieContext.Local.current() - (via, pathOpt) = getViaAndPath(tweetypieCtx.via) - twitterCtx <- TwitterContext() - clientAppId <- twitterCtx.clientApplicationId - } yield { - val clientAppName = clientAppIdToName(clientAppId) - scopes.foreach { stats => - val ctxStats = stats.scope("context") - val viaStats = ctxStats.scope("via", via) - viaStats.scope("all").counter("requests").incr() - val viaClientStats = viaStats.scope("by_client", clientAppName) - viaClientStats.counter("requests").incr() - pathOpt.foreach { path => - val viaPathStats = viaStats.scope("by_path", path) - viaPathStats.counter("requests").incr() - } - } - } - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/BUILD deleted file mode 100644 index 8c40f583a..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -scala_library( - sources = ["DeciderGates.scala"], - compiler_option_sets = ["fatal_warnings"], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "3rdparty/jvm/com/google/guava", - "decider", - "finagle/finagle-toggle/src/main/scala/com/twitter/finagle/server", - "tweetypie/servo/decider", - "tweetypie/servo/util/src/main/scala", - "util/util-core:scala", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/BUILD.docx new file mode 100644 index 000000000..3c61a4bcc Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/DeciderGates.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/DeciderGates.docx new file mode 100644 index 000000000..5060b0a47 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/DeciderGates.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/DeciderGates.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/DeciderGates.scala deleted file mode 100644 index 56df716f6..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/DeciderGates.scala +++ /dev/null @@ -1,60 +0,0 @@ -package com.twitter.tweetypie -package decider - -import com.google.common.hash.Hashing -import com.twitter.decider.Decider -import com.twitter.decider.Feature -import com.twitter.servo.gate.DeciderGate -import com.twitter.servo.util.Gate -import java.nio.charset.StandardCharsets -import scala.collection.mutable -trait DeciderGates { - def overrides: Map[String, Boolean] = Map.empty - def decider: Decider - def prefix: String - - protected val seenFeatures: mutable.HashSet[String] = new mutable.HashSet[String] - - private def deciderFeature(name: String): Feature = { - decider.feature(prefix + "_" + name) - } - - def withOverride[T](name: String, mkGate: Feature => Gate[T]): Gate[T] = { - seenFeatures += name - 
overrides.get(name).map(Gate.const).getOrElse(mkGate(deciderFeature(name))) - } - - protected def linear(name: String): Gate[Unit] = withOverride[Unit](name, DeciderGate.linear) - protected def byId(name: String): Gate[Long] = withOverride[Long](name, DeciderGate.byId) - - /** - * It returns a Gate[String] that can be used to check availability of the feature. - * The string is hashed into a Long and used as an "id" and then used to call servo's - * DeciderGate.byId - * - * @param name decider name - * @return Gate[String] - */ - protected def byStringId(name: String): Gate[String] = - byId(name).contramap { s: String => - Hashing.sipHash24().hashString(s, StandardCharsets.UTF_8).asLong() - } - - def all: Traversable[String] = seenFeatures - - def unusedOverrides: Set[String] = overrides.keySet.diff(all.toSet) - - /** - * Generate a map of name -> availability, taking into account overrides. - * Overrides are either on or off so map to 10000 or 0, respectively. - */ - def availabilityMap: Map[String, Option[Int]] = - all.map { name => - val availability: Option[Int] = overrides - .get(name) - .map(on => if (on) 10000 else 0) - .orElse(deciderFeature(name).availability) - - name -> availability - }.toMap -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/BUILD deleted file mode 100644 index a23ca66e4..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -scala_library( - sources = ["*.scala"], - compiler_option_sets = ["fatal_warnings"], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "decider", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/BUILD.docx new file mode 100644 index 000000000..cb8926abf Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/TweetyPieDeciderOverrides.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/TweetyPieDeciderOverrides.docx new file mode 100644 index 000000000..2952e5b11 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/TweetyPieDeciderOverrides.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/TweetyPieDeciderOverrides.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/TweetyPieDeciderOverrides.scala deleted file mode 100644 index 7b396f3f8..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/decider/overrides/TweetyPieDeciderOverrides.scala +++ /dev/null @@ -1,42 +0,0 @@ -package com.twitter.tweetypie.decider.overrides - -import com.twitter.decider.LocalOverrides - -object TweetyPieDeciderOverrides extends LocalOverrides.Namespace("tweetypie", "tweetypie_") { - val CheckSpamOnRetweet: LocalOverrides.Override = feature("check_spam_on_retweet") - val CheckSpamOnTweet: LocalOverrides.Override = feature("check_spam_on_tweet") - val ConversationControlUseFeatureSwitchResults: LocalOverrides.Override = feature( - "conversation_control_use_feature_switch_results") - val ConversationControlTweetCreateEnabled: LocalOverrides.Override = feature( - "conversation_control_tweet_create_enabled") - val EnableExclusiveTweetControlValidation: 
LocalOverrides.Override = feature( - "enable_exclusive_tweet_control_validation") - val EnableHotKeyCaches: LocalOverrides.Override = feature("enable_hot_key_caches") - val HydrateConversationMuted: LocalOverrides.Override = feature("hydrate_conversation_muted") - val HydrateExtensionsOnWrite: LocalOverrides.Override = feature("hydrate_extensions_on_write") - val HydrateEscherbirdAnnotations: LocalOverrides.Override = feature( - "hydrate_escherbird_annotations") - val HydrateGnipProfileGeoEnrichment: LocalOverrides.Override = feature( - "hydrate_gnip_profile_geo_enrichment") - val HydratePastedPics: LocalOverrides.Override = feature("hydrate_pasted_pics") - val HydratePerspectivesEditsForOtherSafetyLevels: LocalOverrides.Override = feature( - "hydrate_perspectives_edits_for_other_levels") - val HydrateScrubEngagements: LocalOverrides.Override = feature("hydrate_scrub_engagements") - val LogRepoExceptions: LocalOverrides.Override = feature("log_repo_exceptions") - val MediaRefsHydratorIncludePastedMedia: LocalOverrides.Override = feature( - "media_refs_hydrator_include_pasted_media") - val ShortCircuitLikelyPartialTweetReads: LocalOverrides.Override = feature( - "short_circuit_likely_partial_tweet_reads_ms") - val RateLimitByLimiterService: LocalOverrides.Override = feature("rate_limit_by_limiter_service") - val RateLimitTweetCreationFailure: LocalOverrides.Override = feature( - "rate_limit_tweet_creation_failure") - val ReplyTweetConversationControlHydrationEnabled = feature( - "reply_tweet_conversation_control_hydration_enabled" - ) - val DisableInviteViaMention = feature( - "disable_invite_via_mention" - ) - val EnableRemoveUnmentionedImplicitMentions: LocalOverrides.Override = feature( - "enable_remove_unmentioned_implicit_mentions") - val useReplicatedDeleteTweet2: LocalOverrides.Override = feature("use_replicated_delete_tweet_2") -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/BUILD deleted file mode 100644 index de6522d52..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -scala_library( - compiler_option_sets = ["fatal_warnings"], - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "finagle/finagle-core/src/main", - "incentives/jiminy/src/main/thrift/com/twitter/incentives/jiminy:thrift-scala", - "tweetypie/servo/util/src/main/scala", - "stitch/stitch-core", - "strato/src/main/scala/com/twitter/strato/client", - "tweetypie/server/src/main/scala/com/twitter/tweetypie/core", - "util/util-core", - "util/util-stats", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/BUILD.docx new file mode 100644 index 000000000..3b47c8505 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/NudgeBuilder.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/NudgeBuilder.docx new file mode 100644 index 000000000..46deee856 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/NudgeBuilder.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/NudgeBuilder.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/NudgeBuilder.scala deleted file mode 
100644 index dd123206f..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/jiminy/tweetypie/NudgeBuilder.scala +++ /dev/null @@ -1,165 +0,0 @@ -package com.twitter.tweetypie.jiminy.tweetypie - -import com.twitter.finagle.stats.CategorizingExceptionStatsHandler -import com.twitter.finagle.stats.Stat -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.incentives.jiminy.thriftscala._ -import com.twitter.servo.util.FutureArrow -import com.twitter.servo.util.Gate -import com.twitter.stitch.Stitch -import com.twitter.strato.thrift.ScroogeConvImplicits._ -import com.twitter.strato.client.{Client => StratoClient} -import com.twitter.tweetypie.core.TweetCreateFailure -import com.twitter.util.Future -import com.twitter.util.Return -import com.twitter.util.Throw - -case class NudgeBuilderRequest( - text: String, - inReplyToTweetId: Option[NudgeBuilder.TweetId], - conversationId: Option[NudgeBuilder.TweetId], - hasQuotedTweet: Boolean, - nudgeOptions: Option[CreateTweetNudgeOptions], - tweetId: Option[NudgeBuilder.TweetId]) - -trait NudgeBuilder extends FutureArrow[NudgeBuilderRequest, Unit] { - - /** - * Check whether the user should receive a nudge instead of creating - * the Tweet. If nudgeOptions is None, then no nudge check will be - * performed. - * - * @return a Future.exception containing a [[TweetCreateFailure]] if the - * user should be nudged, or Future.Unit if the user should not be - * nudged. - */ - def apply( - request: NudgeBuilderRequest - ): Future[Unit] -} - -object NudgeBuilder { - type Type = FutureArrow[NudgeBuilderRequest, Unit] - type TweetId = Long - - // darkTrafficCreateNudgeOptions ensure that our dark traffic sends a request that will - // accurately test the Jiminy backend. in this case, we specify that we want checks for all - // possible nudge types - private[this] val darkTrafficCreateNudgeOptions = Some( - CreateTweetNudgeOptions( - requestedNudgeTypes = Some( - Set( - TweetNudgeType.PotentiallyToxicTweet, - TweetNudgeType.ReviseOrMute, - TweetNudgeType.ReviseOrHideThenBlock, - TweetNudgeType.ReviseOrBlock - ) - ) - ) - ) - - private[this] def mkJiminyRequest( - request: NudgeBuilderRequest, - isDarkRequest: Boolean = false - ): CreateTweetNudgeRequest = { - val tweetType = - if (request.inReplyToTweetId.nonEmpty) TweetType.Reply - else if (request.hasQuotedTweet) TweetType.QuoteTweet - else TweetType.OriginalTweet - - CreateTweetNudgeRequest( - tweetText = request.text, - tweetType = tweetType, - inReplyToTweetId = request.inReplyToTweetId, - conversationId = request.conversationId, - createTweetNudgeOptions = - if (isDarkRequest) darkTrafficCreateNudgeOptions else request.nudgeOptions, - tweetId = request.tweetId - ) - } - - /** - * NudgeBuilder implemented by calling the strato column `incentives/createNudge`. - * - * Stats recorded: - * - latency_ms: Latency histogram (also implicitly number of - * invocations). This is counted only in the case that a nudge - * check was requested (`nudgeOptions` is non-empty) - * - * - nudge: The nudge check succeeded and a nudge was created. - * - * - no_nudge: The nudge check succeeded, but no nudge was created. - * - * - failures: Calling strato to create a nudge failed. Broken out - * by exception. 
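mkJiminyRequest above classifies the tweet for the nudge check from the shape of the request: a reply if inReplyToTweetId is set, otherwise a quote tweet if a quoted tweet is present, otherwise an original tweet. A standalone sketch of just that precedence, with a simplified stand-in for the Jiminy thrift TweetType:

object TweetTypeSketch {
  // Simplified stand-in for the Jiminy thrift enum.
  sealed trait TweetType
  object TweetType {
    case object Reply extends TweetType
    case object QuoteTweet extends TweetType
    case object OriginalTweet extends TweetType
  }

  /** Same precedence as mkJiminyRequest: reply beats quote beats original. */
  def classify(inReplyToTweetId: Option[Long], hasQuotedTweet: Boolean): TweetType =
    if (inReplyToTweetId.nonEmpty) TweetType.Reply
    else if (hasQuotedTweet) TweetType.QuoteTweet
    else TweetType.OriginalTweet

  def main(args: Array[String]): Unit = {
    println(classify(Some(123L), hasQuotedTweet = true)) // Reply
    println(classify(None, hasQuotedTweet = true))       // QuoteTweet
    println(classify(None, hasQuotedTweet = false))      // OriginalTweet
  }
}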
- */ - - def apply( - nudgeArrow: FutureArrow[CreateTweetNudgeRequest, CreateTweetNudgeResponse], - enableDarkTraffic: Gate[Unit], - stats: StatsReceiver - ): NudgeBuilder = { - new NudgeBuilder { - private[this] val nudgeLatencyStat = stats.stat("latency_ms") - private[this] val nudgeCounter = stats.counter("nudge") - private[this] val noNudgeCounter = stats.counter("no_nudge") - private[this] val darkRequestCounter = stats.counter("dark_request") - private[this] val nudgeExceptionHandler = new CategorizingExceptionStatsHandler - - override def apply( - request: NudgeBuilderRequest - ): Future[Unit] = - request.nudgeOptions match { - case None => - if (enableDarkTraffic()) { - darkRequestCounter.incr() - Stat - .timeFuture(nudgeLatencyStat) { - nudgeArrow(mkJiminyRequest(request, isDarkRequest = true)) - } - .transform { _ => - // ignore the response since it is a dark request - Future.Done - } - } else { - Future.Done - } - - case Some(_) => - Stat - .timeFuture(nudgeLatencyStat) { - nudgeArrow(mkJiminyRequest(request)) - } - .transform { - case Throw(e) => - nudgeExceptionHandler.record(stats, e) - // If we failed to invoke the nudge column, then - // just continue on with the Tweet creation. - Future.Done - - case Return(CreateTweetNudgeResponse(Some(nudge))) => - nudgeCounter.incr() - Future.exception(TweetCreateFailure.Nudged(nudge = nudge)) - - case Return(CreateTweetNudgeResponse(None)) => - noNudgeCounter.incr() - Future.Done - } - } - } - } - - def apply( - strato: StratoClient, - enableDarkTraffic: Gate[Unit], - stats: StatsReceiver - ): NudgeBuilder = { - val executer = - strato.executer[CreateTweetNudgeRequest, CreateTweetNudgeResponse]( - "incentives/createTweetNudge") - val nudgeArrow: FutureArrow[CreateTweetNudgeRequest, CreateTweetNudgeResponse] = { req => - Stitch.run(executer.execute(req)) - } - apply(nudgeArrow, enableDarkTraffic, stats) - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/BUILD deleted file mode 100644 index 52259fc54..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -scala_library( - sources = ["*.scala"], - compiler_option_sets = ["fatal_warnings"], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "src/java/com/twitter/common/text/language:language-identifier", - "src/java/com/twitter/common/text/language:locale-util", - "src/java/com/twitter/common/text/pipeline", - "src/java/com/twitter/common/text/token", - "src/java/com/twitter/common_internal/text", - "src/java/com/twitter/common_internal/text/version", - "tweetypie/src/resources/com/twitter/tweetypie/matching", - "util/util-core/src/main/scala/com/twitter/concurrent", - "util/util-core/src/main/scala/com/twitter/io", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/BUILD.docx new file mode 100644 index 000000000..b32cd569e Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TokenSequence.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TokenSequence.docx new file mode 100644 index 000000000..2e57a1aa5 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TokenSequence.docx differ diff --git 
a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TokenSequence.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TokenSequence.scala deleted file mode 100644 index 09e9695cc..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TokenSequence.scala +++ /dev/null @@ -1,92 +0,0 @@ -package com.twitter.tweetypie.matching - -object TokenSequence { - - /** - * Is `suffix` a suffix of `s`, starting at `offset` in `s`? - */ - def hasSuffixAt(s: CharSequence, suffix: CharSequence, offset: Int): Boolean = - if (offset == 0 && (s.eq(suffix) || s == suffix)) { - true - } else if (suffix.length != (s.length - offset)) { - false - } else { - @annotation.tailrec - def go(i: Int): Boolean = - if (i == suffix.length) true - else if (suffix.charAt(i) == s.charAt(offset + i)) go(i + 1) - else false - - go(0) - } - - /** - * Do two [[CharSequence]]s contain the same characters? - * - * [[CharSequence]] equality is not sufficient because - * [[CharSequence]]s of different types may not consider other - * [[CharSequence]]s containing the same characters equivalent. - */ - def sameCharacters(s1: CharSequence, s2: CharSequence): Boolean = - hasSuffixAt(s1, s2, 0) - - /** - * This method implements the product definition of a token matching a - * keyword. That definition is: - * - * - The token contains the same characters as the keyword. - * - The token contains the same characters as the keyword after - * dropping a leading '#' or '@' from the token. - * - * The intention is that a keyword matches an identical hashtag, but - * if the keyword itself is a hashtag, it only matches the hashtag - * form. - * - * The tokenization process should rule out tokens or keywords that - * start with multiple '#' characters, even though this implementation - * allows for e.g. token "##a" to match "#a". - */ - def tokenMatches(token: CharSequence, keyword: CharSequence): Boolean = - if (sameCharacters(token, keyword)) true - else if (token.length == 0) false - else { - val tokenStart = token.charAt(0) - (tokenStart == '#' || tokenStart == '@') && hasSuffixAt(token, keyword, 1) - } -} - -/** - * A sequence of normalized tokens. The sequence depends on the locale - * in which the text was parsed and the version of the penguin library - * that was used at tokenization time. - */ -case class TokenSequence private[matching] (toIndexedSeq: IndexedSeq[CharSequence]) { - import TokenSequence.tokenMatches - - private def apply(i: Int): CharSequence = toIndexedSeq(i) - - def isEmpty: Boolean = toIndexedSeq.isEmpty - def nonEmpty: Boolean = toIndexedSeq.nonEmpty - - /** - * Does the supplied sequence of keywords match a consecutive sequence - * of tokens within this sequence? 
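The matching rule documented in tokenMatches above means a plain keyword matches both its bare form and its hashtag or mention form, while a keyword that is itself a hashtag only matches the hashtag. A quick standalone illustration that reuses the same suffix-matching logic (TokenMatchSketch is a hypothetical name):

object TokenMatchSketch {
  // Same logic as TokenSequence.hasSuffixAt / tokenMatches above.
  def hasSuffixAt(s: CharSequence, suffix: CharSequence, offset: Int): Boolean =
    if (offset == 0 && (s.eq(suffix) || s == suffix)) true
    else if (suffix.length != (s.length - offset)) false
    else (0 until suffix.length).forall(i => suffix.charAt(i) == s.charAt(offset + i))

  def tokenMatches(token: CharSequence, keyword: CharSequence): Boolean =
    if (hasSuffixAt(token, keyword, 0)) true
    else if (token.length == 0) false
    else {
      val tokenStart = token.charAt(0)
      (tokenStart == '#' || tokenStart == '@') && hasSuffixAt(token, keyword, 1)
    }

  def main(args: Array[String]): Unit = {
    println(tokenMatches("#cats", "cats")) // true: keyword matches the hashtag form
    println(tokenMatches("cats", "cats"))  // true: identical token
    println(tokenMatches("cats", "#cats")) // false: a hashtag keyword only matches hashtags
  }
}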
- */ - def containsKeywordSequence(keywords: TokenSequence): Boolean = { - val finalIndex = toIndexedSeq.length - keywords.toIndexedSeq.length - - @annotation.tailrec - def matchesAt(offset: Int, i: Int): Boolean = - if (i >= keywords.toIndexedSeq.length) true - else if (tokenMatches(this(i + offset), keywords(i))) matchesAt(offset, i + 1) - else false - - @annotation.tailrec - def search(offset: Int): Boolean = - if (offset > finalIndex) false - else if (matchesAt(offset, 0)) true - else search(offset + 1) - - search(0) - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/Tokenizer.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/Tokenizer.docx new file mode 100644 index 000000000..04eb5df95 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/Tokenizer.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/Tokenizer.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/Tokenizer.scala deleted file mode 100644 index 7cb3cd315..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/Tokenizer.scala +++ /dev/null @@ -1,156 +0,0 @@ -package com.twitter.tweetypie.matching - -import com.twitter.common.text.language.LocaleUtil -import com.twitter.common_internal.text.pipeline.TwitterTextNormalizer -import com.twitter.common_internal.text.pipeline.TwitterTextTokenizer -import com.twitter.common_internal.text.version.PenguinVersion -import com.twitter.concurrent.Once -import com.twitter.io.StreamIO -import java.util.Locale -import scala.collection.JavaConverters._ - -/** - * Extract a sequence of normalized tokens from the input text. The - * normalization and tokenization are properly configured for keyword - * matching between texts. - */ -trait Tokenizer { - def tokenize(input: String): TokenSequence -} - -object Tokenizer { - - /** - * When a Penguin version is not explicitly specified, use this - * version of Penguin to perform normalization and tokenization. If - * you cache tokenized text, be sure to store the version as well, to - * avoid comparing text that was normalized with different algorithms. - */ - val DefaultPenguinVersion: PenguinVersion = PenguinVersion.PENGUIN_6 - - /** - * If you already know the locale of the text that is being tokenized, - * use this method to get a tokenizer that is much more efficient than - * the Tweet or Query tokenizer, since it does not have to perform - * language detection. - */ - def forLocale(locale: Locale): Tokenizer = get(locale, DefaultPenguinVersion) - - /** - * Obtain a `Tokenizer` that will tokenize the text for the given - * locale and version of the Penguin library. - */ - def get(locale: Locale, version: PenguinVersion): Tokenizer = - TokenizerFactories(version).forLocale(locale) - - /** - * Encapsulates the configuration and use of [[TwitterTextTokenizer]] - * and [[TwitterTextNormalizer]]. - */ - private[this] class TokenizerFactory(version: PenguinVersion) { - // The normalizer is thread-safe, so share one instance. - private[this] val normalizer = - (new TwitterTextNormalizer.Builder(version)).build() - - // The TwitterTextTokenizer is relatively expensive to build, - // and is not thread safe, so keep instances of it in a - // ThreadLocal. 
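The comment above describes a common pattern: keep the thread-safe, shareable object as a plain field, and keep the expensive, non-thread-safe one in a ThreadLocal so each thread builds exactly one instance of its own. A minimal standalone sketch with a hypothetical ExpensiveTokenizer stand-in:

object ThreadLocalSketch {
  // Hypothetical stand-in for a non-thread-safe, costly-to-build tokenizer.
  final class ExpensiveTokenizer {
    def split(text: String): Seq[String] = text.split("\\s+").toSeq
  }

  // One instance per thread, built lazily on first use by that thread.
  private val local: ThreadLocal[ExpensiveTokenizer] =
    new ThreadLocal[ExpensiveTokenizer] {
      override def initialValue(): ExpensiveTokenizer = new ExpensiveTokenizer
    }

  def tokenize(text: String): Seq[String] = local.get.split(text)

  def main(args: Array[String]): Unit =
    println(tokenize("keep one per thread").mkString(", ")) // keep, one, per, thread
}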
- private[this] val local = - new ThreadLocal[TwitterTextTokenizer] { - override def initialValue: TwitterTextTokenizer = - (new TwitterTextTokenizer.Builder(version)).build() - } - - /** - * Obtain a [[Tokenizer]] for this combination of [[PenguinVersion]] - * and [[Locale]]. - */ - def forLocale(locale: Locale): Tokenizer = - new Tokenizer { - override def tokenize(input: String): TokenSequence = { - val stream = local.get.getTwitterTokenStreamFor(locale) - stream.reset(normalizer.normalize(input, locale)) - val builder = IndexedSeq.newBuilder[CharSequence] - while (stream.incrementToken) builder += stream.term() - TokenSequence(builder.result()) - } - } - } - - /** - * Since there are a small number of Penguin versions, eagerly - * initialize a TokenizerFactory for each version, to avoid managing - * mutable state. - */ - private[this] val TokenizerFactories: PenguinVersion => TokenizerFactory = - PenguinVersion.values.map(v => v -> new TokenizerFactory(v)).toMap - - /** - * The set of locales used in warmup. These locales are mentioned in - * the logic of TwitterTextTokenizer and TwitterTextNormalizer. - */ - private[this] val WarmUpLocales: Seq[Locale] = - Seq - .concat( - Seq( - Locale.JAPANESE, - Locale.KOREAN, - LocaleUtil.UNKNOWN, - LocaleUtil.THAI, - LocaleUtil.ARABIC, - LocaleUtil.SWEDISH - ), - LocaleUtil.CHINESE_JAPANESE_LOCALES.asScala, - LocaleUtil.CJK_LOCALES.asScala - ) - .toSet - .toArray - .toSeq - - /** - * Load the default inputs that are used for warming up this library. - */ - def warmUpCorpus(): Seq[String] = { - val stream = getClass.getResourceAsStream("warmup-text.txt") - val bytes = - try StreamIO.buffer(stream) - finally stream.close() - bytes.toString("UTF-8").linesIterator.toArray.toSeq - } - - /** - * Exercise the functionality of this library on the specified - * strings. In general, prefer [[warmUp]] to this method. - */ - def warmUpWith(ver: PenguinVersion, texts: Iterable[String]): Unit = - texts.foreach { txt => - // Exercise each locale - WarmUpLocales.foreach { loc => - Tokenizer.get(loc, ver).tokenize(txt) - UserMutes.builder().withPenguinVersion(ver).withLocale(loc).validate(txt) - } - - // Exercise language detection - TweetTokenizer.get(ver).tokenize(txt) - UserMutes.builder().withPenguinVersion(ver).validate(txt) - } - - private[this] val warmUpOnce = Once(warmUpWith(DefaultPenguinVersion, warmUpCorpus())) - - /** - * The creation of the first TwitterTextTokenizer is relatively - * expensive, and tokenizing some texts may cause significant - * initialization. - * - * This method exercises the functionality of this library - * with a range of texts in order to perform as much initialization as - * possible before the library is used in a latency-sensitive way. - * - * The warmup routine will only run once. Subsequent invocations of - * `warmUp` will no do additional work, and will return once warmup is - * complete. - * - * The warmup will take on the order of seconds. 
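The run-once behaviour described above (pay the warmup cost the first time; later calls are cheap and simply return once warmup is complete) comes from com.twitter.concurrent.Once in the deleted code. As a standalone approximation of the same idea, a lazy val also initializes exactly once and blocks concurrent callers until it finishes; doWarmUp below is a stand-in for exercising the tokenizer over the warmup corpus.

object WarmUpSketch {
  private def doWarmUp(): Unit = {
    // Stand-in for the real warmup work.
    Thread.sleep(100)
    println("warmed up")
  }

  // Initialized at most once; concurrent callers block until it completes.
  private lazy val warmedUp: Unit = doWarmUp()

  def warmUp(): Unit = warmedUp

  def main(args: Array[String]): Unit = {
    warmUp() // does the work
    warmUp() // no-op, returns immediately
  }
}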
- */ - def warmUp(): Unit = warmUpOnce() -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TweetTokenizer.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TweetTokenizer.docx new file mode 100644 index 000000000..f85cb2bfe Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TweetTokenizer.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TweetTokenizer.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TweetTokenizer.scala deleted file mode 100644 index 592891235..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/TweetTokenizer.scala +++ /dev/null @@ -1,45 +0,0 @@ -package com.twitter.tweetypie.matching - -import com.twitter.common.text.pipeline.TwitterLanguageIdentifier -import com.twitter.common_internal.text.version.PenguinVersion -import java.util.Locale - -object TweetTokenizer extends Tokenizer { - type LocalePicking = Option[Locale] => Tokenizer - - /** - * Get a Tokenizer-producing function that uses the supplied locale - * to select an appropriate Tokenizer. - */ - def localePicking: LocalePicking = { - case None => TweetTokenizer - case Some(locale) => Tokenizer.forLocale(locale) - } - - private[this] val tweetLangIdentifier = - (new TwitterLanguageIdentifier.Builder).buildForTweet() - - /** - * Get a Tokenizer that performs Tweet language detection, and uses - * that result to tokenize the text. If you already know the locale of - * the tweet text, use `Tokenizer.get`, because it's much - * cheaper. - */ - def get(version: PenguinVersion): Tokenizer = - new Tokenizer { - override def tokenize(text: String): TokenSequence = { - val locale = tweetLangIdentifier.identify(text).getLocale - Tokenizer.get(locale, version).tokenize(text) - } - } - - private[this] val Default = get(Tokenizer.DefaultPenguinVersion) - - /** - * Tokenize the given text using Tweet language detection and - * `Tokenizer.DefaultPenguinVersion`. Prefer `Tokenizer.forLocale` if - * you already know the language of the text. - */ - override def tokenize(tweetText: String): TokenSequence = - Default.tokenize(tweetText) -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/UserMutes.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/UserMutes.docx new file mode 100644 index 000000000..cf4a8cae1 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/UserMutes.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/UserMutes.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/matching/UserMutes.scala deleted file mode 100644 index dc7430c86..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/matching/UserMutes.scala +++ /dev/null @@ -1,128 +0,0 @@ -package com.twitter.tweetypie.matching - -import com.twitter.common.text.pipeline.TwitterLanguageIdentifier -import com.twitter.common_internal.text.version.PenguinVersion -import java.util.Locale -import scala.collection.JavaConversions.asScalaBuffer - -object UserMutesBuilder { - private[matching] val Default = - new UserMutesBuilder(Tokenizer.DefaultPenguinVersion, None) - - private val queryLangIdentifier = - (new TwitterLanguageIdentifier.Builder).buildForQuery() -} - -class UserMutesBuilder private (penguinVersion: PenguinVersion, localeOpt: Option[Locale]) { - - /** - * Use the specified Penguin version when tokenizing a keyword mute - * string. 
In general, use the default version, unless you need to - * specify a particular version for compatibility with another system - * that is using that version. - */ - def withPenguinVersion(ver: PenguinVersion): UserMutesBuilder = - if (ver == penguinVersion) this - else new UserMutesBuilder(ver, localeOpt) - - /** - * Use the specified locale when tokenizing a keyword mute string. - */ - def withLocale(locale: Locale): UserMutesBuilder = - if (localeOpt.contains(locale)) this - else new UserMutesBuilder(penguinVersion, Some(locale)) - - /** - * When tokenizing a user mute list, detect the language of the - * text. This is significantly more expensive than using a predefined - * locale, but is appropriate when the locale is not yet known. - */ - def detectLocale(): UserMutesBuilder = - if (localeOpt.isEmpty) this - else new UserMutesBuilder(penguinVersion, localeOpt) - - private[this] lazy val tokenizer = - localeOpt match { - case None => - // No locale was specified, so use a Tokenizer that performs - // language detection before tokenizing. - new Tokenizer { - override def tokenize(text: String): TokenSequence = { - val locale = UserMutesBuilder.queryLangIdentifier.identify(text).getLocale - Tokenizer.get(locale, penguinVersion).tokenize(text) - } - } - - case Some(locale) => - Tokenizer.get(locale, penguinVersion) - } - - /** - * Given a list of the user's raw keyword mutes, return a preprocessed - * set of mutes suitable for matching against tweet text. If the input - * contains any phrases that fail validation, then they will be - * dropped. - */ - def build(rawInput: Seq[String]): UserMutes = - UserMutes(rawInput.flatMap(validate(_).right.toOption)) - - /** - * Java-friendly API for processing a user's list of raw keyword mutes - * into a preprocessed form suitable for matching against text. - */ - def fromJavaList(rawInput: java.util.List[String]): UserMutes = - build(asScalaBuffer(rawInput).toSeq) - - /** - * Validate the raw user input muted phrase. Currently, the only - * inputs that are not valid for keyword muting are those inputs that - * do not contain any keywords, because those inputs would match all - * tweets. - */ - def validate(mutedPhrase: String): Either[UserMutes.ValidationError, TokenSequence] = { - val keywords = tokenizer.tokenize(mutedPhrase) - if (keywords.isEmpty) UserMutes.EmptyPhraseError else Right(keywords) - } -} - -object UserMutes { - sealed trait ValidationError - - /** - * The phrase's tokenization did not produce any tokens - */ - case object EmptyPhrase extends ValidationError - - private[matching] val EmptyPhraseError = Left(EmptyPhrase) - - /** - * Get a [[UserMutesBuilder]] that uses the default Penguin version and - * performs language identification to choose a locale. - */ - def builder(): UserMutesBuilder = UserMutesBuilder.Default -} - -/** - * A user's muted keyword list, preprocessed into token sequences. - */ -case class UserMutes private[matching] (toSeq: Seq[TokenSequence]) { - - /** - * Do any of the users' muted keyword sequences occur within the - * supplied text? 
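The builder below drops any muted phrase that tokenizes to nothing, since an empty token sequence would match every Tweet. The standalone sketch here keeps that validation shape but substitutes trivial whitespace tokenization for Penguin, so the tokenizer is an assumption rather than the real normalization pipeline:

object UserMutesSketch {
  type TokenSeq = IndexedSeq[String]

  sealed trait ValidationError
  case object EmptyPhrase extends ValidationError

  // Assumption: whitespace tokenization stands in for Penguin normalization.
  private def tokenize(text: String): TokenSeq =
    text.trim.toLowerCase.split("\\s+").filter(_.nonEmpty).toIndexedSeq

  def validate(phrase: String): Either[ValidationError, TokenSeq] = {
    val keywords = tokenize(phrase)
    if (keywords.isEmpty) Left(EmptyPhrase) else Right(keywords)
  }

  /** Invalid phrases are silently dropped, as in UserMutesBuilder.build. */
  def build(rawInput: Seq[String]): Seq[TokenSeq] =
    rawInput.flatMap(validate(_).toOption)

  def main(args: Array[String]): Unit = {
    println(build(Seq("Rainy Days", "   ", "spoilers")))
    // List(Vector(rainy, days), Vector(spoilers)) -- the blank phrase was dropped
  }
}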
- */ - def matches(text: TokenSequence): Boolean = - toSeq.exists(text.containsKeywordSequence) - - /** - * Find all positions of matching muted keyword from the user's - * muted keyword list - */ - def find(text: TokenSequence): Seq[Int] = - toSeq.zipWithIndex.collect { - case (token, index) if text.containsKeywordSequence(token) => index - } - - def isEmpty: Boolean = toSeq.isEmpty - def nonEmpty: Boolean = toSeq.nonEmpty -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/media/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/media/BUILD deleted file mode 100644 index 2b1e9ec79..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/media/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -scala_library( - sources = ["*.scala"], - compiler_option_sets = ["fatal_warnings"], - platform = "java8", - strict_deps = True, - tags = ["bazel-compatible"], - dependencies = [ - "mediaservices/commons/src/main/thrift:thrift-scala", - "scrooge/scrooge-core/src/main/scala", - "tweetypie/servo/util/src/main/scala", - "tweetypie/common/src/thrift/com/twitter/tweetypie:media-entity-scala", - "tweetypie/common/src/thrift/com/twitter/tweetypie:tweet-scala", - "tco-util", - "tweetypie/common/src/scala/com/twitter/tweetypie/util", - "util/util-logging/src/main/scala/com/twitter/logging", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/media/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/media/BUILD.docx new file mode 100644 index 000000000..a6bcb5ddc Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/media/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/media/Media.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/media/Media.docx new file mode 100644 index 000000000..609a9a76a Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/media/Media.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/media/Media.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/media/Media.scala deleted file mode 100644 index bd0e6f4a3..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/media/Media.scala +++ /dev/null @@ -1,149 +0,0 @@ -package com.twitter.tweetypie -package media - -import com.twitter.mediaservices.commons.thriftscala.MediaCategory -import com.twitter.mediaservices.commons.tweetmedia.thriftscala._ -import com.twitter.tco_util.TcoSlug -import com.twitter.tweetypie.thriftscala._ -import com.twitter.tweetypie.util.TweetLenses - -/** - * A smörgåsbord of media-related helper methods. 
- */ -object Media { - val AnimatedGifContentType = "video/mp4 codecs=avc1.42E0" - - case class MediaTco(expandedUrl: String, url: String, displayUrl: String) - - val ImageContentTypes: Set[MediaContentType] = - Set[MediaContentType]( - MediaContentType.ImageJpeg, - MediaContentType.ImagePng, - MediaContentType.ImageGif - ) - - val AnimatedGifContentTypes: Set[MediaContentType] = - Set[MediaContentType]( - MediaContentType.VideoMp4 - ) - - val VideoContentTypes: Set[MediaContentType] = - Set[MediaContentType]( - MediaContentType.VideoGeneric - ) - - val InUseContentTypes: Set[MediaContentType] = - Set[MediaContentType]( - MediaContentType.ImageGif, - MediaContentType.ImageJpeg, - MediaContentType.ImagePng, - MediaContentType.VideoMp4, - MediaContentType.VideoGeneric - ) - - def isImage(contentType: MediaContentType): Boolean = - ImageContentTypes.contains(contentType) - - def contentTypeToString(contentType: MediaContentType): String = - contentType match { - case MediaContentType.ImageGif => "image/gif" - case MediaContentType.ImageJpeg => "image/jpeg" - case MediaContentType.ImagePng => "image/png" - case MediaContentType.VideoMp4 => "video/mp4" - case MediaContentType.VideoGeneric => "video" - case _ => throw new IllegalArgumentException(s"UnknownMediaContentType: $contentType") - } - - def stringToContentType(str: String): MediaContentType = - str match { - case "image/gif" => MediaContentType.ImageGif - case "image/jpeg" => MediaContentType.ImageJpeg - case "image/png" => MediaContentType.ImagePng - case "video/mp4" => MediaContentType.VideoMp4 - case "video" => MediaContentType.VideoGeneric - case _ => throw new IllegalArgumentException(s"Unknown Content Type String: $str") - } - - def extensionForContentType(cType: MediaContentType): String = - cType match { - case MediaContentType.ImageJpeg => "jpg" - case MediaContentType.ImagePng => "png" - case MediaContentType.ImageGif => "gif" - case MediaContentType.VideoMp4 => "mp4" - case MediaContentType.VideoGeneric => "" - case _ => "unknown" - } - - /** - * Extract a URL entity from a media entity. - */ - def extractUrlEntity(mediaEntity: MediaEntity): UrlEntity = - UrlEntity( - fromIndex = mediaEntity.fromIndex, - toIndex = mediaEntity.toIndex, - url = mediaEntity.url, - expanded = Some(mediaEntity.expandedUrl), - display = Some(mediaEntity.displayUrl) - ) - - /** - * Copy the fields from the URL entity into the media entity. 
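Among these helpers, getAspectRatio (defined just below) reduces a media size to its simplest ratio by dividing width and height by their greatest common divisor. A standalone sketch of that calculation, returning a plain tuple instead of the thrift AspectRatio:

object AspectRatioSketch {
  /** Same reduction as Media.getAspectRatio: divide both sides by their GCD. */
  def aspectRatio(width: Int, height: Int): (Int, Int) = {
    require(width != 0 && height != 0, s"Dimensions must be non zero: ($width, $height)")

    @annotation.tailrec
    def gcd(a: Int, b: Int): Int = if (b == 0) a else gcd(b, a % b)

    val g = gcd(math.max(width, height), math.min(width, height))
    (width / g, height / g)
  }

  def main(args: Array[String]): Unit = {
    println(aspectRatio(1920, 1080)) // (16,9)
    println(aspectRatio(640, 480))   // (4,3)
  }
}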
- */ - def copyFromUrlEntity(mediaEntity: MediaEntity, urlEntity: UrlEntity): MediaEntity = { - val expandedUrl = - urlEntity.expanded.orElse(Option(mediaEntity.expandedUrl)).getOrElse(urlEntity.url) - - val displayUrl = - urlEntity.url match { - case TcoSlug(slug) => MediaUrl.Display.fromTcoSlug(slug) - case _ => urlEntity.expanded.getOrElse(urlEntity.url) - } - - mediaEntity.copy( - fromIndex = urlEntity.fromIndex, - toIndex = urlEntity.toIndex, - url = urlEntity.url, - expandedUrl = expandedUrl, - displayUrl = displayUrl - ) - } - - def getAspectRatio(size: MediaSize): AspectRatio = - getAspectRatio(size.width, size.height) - - def getAspectRatio(width: Int, height: Int): AspectRatio = { - if (width == 0 || height == 0) { - throw new IllegalArgumentException(s"Dimensions must be non zero: ($width, $height)") - } - - def calculateGcd(a: Int, b: Int): Int = - if (b == 0) a else calculateGcd(b, a % b) - - val gcd = calculateGcd(math.max(width, height), math.min(width, height)) - AspectRatio((width / gcd).toShort, (height / gcd).toShort) - } - - /** - * Return just the media that belongs to this tweet - */ - def ownMedia(tweet: Tweet): Seq[MediaEntity] = - TweetLenses.media.get(tweet).filter(isOwnMedia(tweet.id, _)) - - /** - * Does the given media entity, which is was found on the tweet with the specified - * tweetId, belong to that tweet? - */ - def isOwnMedia(tweetId: TweetId, entity: MediaEntity): Boolean = - entity.sourceStatusId.forall(_ == tweetId) - - /** - * Mixed Media is any case where there is more than one media item & any of them is not an image. - */ - - def isMixedMedia(mediaEntities: Seq[MediaEntity]): Boolean = - mediaEntities.length > 1 && (mediaEntities.flatMap(_.mediaInfo).exists { - case _: MediaInfo.ImageInfo => false - case _ => true - } || - mediaEntities.flatMap(_.mediaKey).map(_.mediaCategory).exists(_ != MediaCategory.TweetImage)) -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/media/MediaUrl.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/media/MediaUrl.docx new file mode 100644 index 000000000..7be93bf43 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/media/MediaUrl.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/media/MediaUrl.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/media/MediaUrl.scala deleted file mode 100644 index eb26dfad8..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/media/MediaUrl.scala +++ /dev/null @@ -1,108 +0,0 @@ -package com.twitter.tweetypie -package media - -import com.twitter.logging.Logger -import com.twitter.tweetypie.thriftscala.MediaEntity -import com.twitter.tweetypie.thriftscala.UrlEntity - -/** - * Creating and parsing tweet media entity URLs. - * - * There are four kinds of URL in a media entity: - * - * - Display URLs: pic.twitter.com aliases for the short URL, for - * embedding in the tweet text. - * - * - Short URLs: regular t.co URLs that expand to the permalink URL. - * - * - Permalink URLs: link to a page that displays the media after - * doing authorization - * - * - Asset URLs: links to the actual media asset. - * - */ -object MediaUrl { - private[this] val log = Logger(getClass) - - /** - * The URL that should be filled in to the displayUrl field of the - * media entity. This URL behaves exactly the same as a t.co link - * (only the domain is different.) 
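Of the URL kinds listed above, the permalink is the one Tweetypie needs to parse back apart: the Permalink helper below extracts the tweet id with a regex and treats digit runs too large for a Long as no match. A standalone sketch of that parsing, using the same pattern:

object PermalinkSketch {
  // Same pattern as MediaUrl.Permalink.PermalinkRegex below.
  private val PermalinkRegex =
    """https?://twitter.com/(?:#!/)?\w+/status/(\d+)/(?:photo|video)/\d+""".r

  /** Pull the tweet id out of a media permalink, if it looks like one. */
  def getTweetId(permalink: String): Option[Long] =
    permalink match {
      case PermalinkRegex(tweetIdStr) =>
        // Ids too big to fit in a Long yield None rather than an exception.
        scala.util.Try(tweetIdStr.toLong).toOption
      case _ => None
    }

  def main(args: Array[String]): Unit = {
    println(getTweetId("https://twitter.com/someone/status/123456/photo/1")) // Some(123456)
    println(getTweetId("https://twitter.com/i/status/123456/video/1"))       // Some(123456)
    println(getTweetId("https://example.com/not-a-permalink"))               // None
  }
}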
- */ - object Display { - val Root = "pic.twitter.com/" - - def fromTcoSlug(tcoSlug: String): String = Root + tcoSlug - } - - /** - * The link target for the link in the tweet text (the expanded URL - * for the media, copied from the URL entity.) For native photos, - * this is the tweet permalink page. - * - * For users without a screen name ("handleless" or NoScreenName users) - * a permalink to /i/status/:tweet_id is used. - */ - object Permalink { - val Root = "https://twitter.com/" - val Internal = "i" - val PhotoSuffix = "/photo/1" - val VideoSuffix = "/video/1" - - def apply(screenName: String, tweetId: TweetId, isVideo: Boolean): String = - Root + - (if (screenName.isEmpty) Internal else screenName) + - "/status/" + - tweetId + - (if (isVideo) VideoSuffix else PhotoSuffix) - - private[this] val PermalinkRegex = - """https?://twitter.com/(?:#!/)?\w+/status/(\d+)/(?:photo|video)/\d+""".r - - private[this] def getTweetId(permalink: String): Option[TweetId] = - permalink match { - case PermalinkRegex(tweetIdStr) => - try { - Some(tweetIdStr.toLong) - } catch { - // Digits too big to fit in a Long - case _: NumberFormatException => None - } - case _ => None - } - - def getTweetId(urlEntity: UrlEntity): Option[TweetId] = - urlEntity.expanded.flatMap(getTweetId) - - def hasTweetId(permalink: String, tweetId: TweetId): Boolean = - getTweetId(permalink).contains(tweetId) - - def hasTweetId(mediaEntity: MediaEntity, tweetId: TweetId): Boolean = - hasTweetId(mediaEntity.expandedUrl, tweetId) - - def hasTweetId(urlEntity: UrlEntity, tweetId: TweetId): Boolean = - getTweetId(urlEntity).contains(tweetId) - } - - /** - * Converts a url that starts with "https://" to one that starts with "http://". - */ - def httpsToHttp(url: String): String = - url.replace("https://", "http://") - - /** - * Gets the last path element from an asset url. This exists temporarily to support - * the now deprecated mediaPath element in MediaEntity. - */ - def mediaPathFromUrl(url: String): String = - url.lastIndexOf('/') match { - case -1 => - log.error("Invalid media path. Could not find last element: " + url) - // Better to return a broken preview URL to the client - // than to fail the whole request. 
- "" - - case idx => - url.substring(idx + 1) - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/media/package.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/media/package.docx new file mode 100644 index 000000000..3a571dde6 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/media/package.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/media/package.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/media/package.scala deleted file mode 100644 index d8fb9b2d1..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/media/package.scala +++ /dev/null @@ -1,7 +0,0 @@ -package com.twitter.tweetypie - -package object media { - type TweetId = Long - type UserId = Long - type MediaId = Long -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/AddTweetHandler.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/AddTweetHandler.docx new file mode 100644 index 000000000..c0aaca94c Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/AddTweetHandler.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/AddTweetHandler.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/AddTweetHandler.scala deleted file mode 100644 index a0035b9e5..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/AddTweetHandler.scala +++ /dev/null @@ -1,80 +0,0 @@ -package com.twitter.tweetypie.storage - -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.stitch.Stitch -import com.twitter.storage.client.manhattan.kv.ManhattanValue -import com.twitter.tweetypie.storage.TweetUtils.collectWithRateLimitCheck -import com.twitter.tweetypie.storage_internal.thriftscala.StoredTweet -import com.twitter.tweetypie.thriftscala.Tweet -import com.twitter.util.Time - -object AddTweetHandler { - private[storage] type InternalAddTweet = ( - Tweet, - ManhattanOperations.Insert, - Scribe, - StatsReceiver, - Time - ) => Stitch[Unit] - - def apply( - insert: ManhattanOperations.Insert, - scribe: Scribe, - stats: StatsReceiver - ): TweetStorageClient.AddTweet = - tweet => doAddTweet(tweet, insert, scribe, stats, Time.now) - - def makeRecords( - storedTweet: StoredTweet, - timestamp: Time - ): Seq[TweetManhattanRecord] = { - val core = CoreFieldsCodec.fromTweet(storedTweet) - val packedCoreFieldsBlob = CoreFieldsCodec.toTFieldBlob(core) - val coreRecord = - TweetManhattanRecord( - TweetKey.coreFieldsKey(storedTweet.id), - ManhattanValue(TFieldBlobCodec.toByteBuffer(packedCoreFieldsBlob), Some(timestamp)) - ) - - val otherFieldIds = - TweetFields.nonCoreInternalFields ++ TweetFields.getAdditionalFieldIds(storedTweet) - - val otherFields = - storedTweet - .getFieldBlobs(otherFieldIds) - .map { - case (fieldId, tFieldBlob) => - TweetManhattanRecord( - TweetKey.fieldKey(storedTweet.id, fieldId), - ManhattanValue(TFieldBlobCodec.toByteBuffer(tFieldBlob), Some(timestamp)) - ) - } - .toSeq - otherFields :+ coreRecord - } - - private[storage] val doAddTweet: InternalAddTweet = ( - tweet: Tweet, - insert: ManhattanOperations.Insert, - scribe: Scribe, - stats: StatsReceiver, - timestamp: Time - ) => { - assert(tweet.coreData.isDefined, s"Tweet ${tweet.id} is missing coreData: $tweet") - - val storedTweet = StorageConversions.toStoredTweet(tweet) - val records = makeRecords(storedTweet, timestamp) - val inserts = records.map(insert) - val insertsWithRateLimitCheck = - 
Stitch.collect(inserts.map(_.liftToTry)).map(collectWithRateLimitCheck).lowerFromTry - - Stats.updatePerFieldQpsCounters( - "addTweet", - TweetFields.getAdditionalFieldIds(storedTweet), - 1, - stats - ) - - insertsWithRateLimitCheck.unit.onSuccess { _ => scribe.logAdded(storedTweet) } - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BUILD b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BUILD deleted file mode 100644 index 6a3db82e7..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -scala_library( - sources = ["*.scala"], - compiler_option_sets = ["fatal_warnings"], - platform = "java8", - strict_deps = True, - tags = [ - "bazel-compatible", - "bazel-incompatible-scaladoc", - ], - dependencies = [ - "3rdparty/jvm/com/chuusai:shapeless", - "3rdparty/jvm/com/fasterxml/jackson/core:jackson-databind", - "3rdparty/jvm/com/fasterxml/jackson/module:jackson-module-scala", - "3rdparty/jvm/com/google/guava", - "3rdparty/jvm/com/twitter/bijection:core", - "3rdparty/jvm/com/twitter/bijection:scrooge", - "3rdparty/jvm/com/twitter/bijection:thrift", - "3rdparty/jvm/commons-codec", - "3rdparty/jvm/org/apache/thrift:libthrift", - "diffshow", - "finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/authorization", - "finagle/finagle-core/src/main", - "finagle/finagle-stats", - "finagle/finagle-thriftmux/src/main/scala", - "mediaservices/commons/src/main/thrift:thrift-scala", - "scrooge/scrooge-serializer/src/main/scala", - "tweetypie/servo/repo/src/main/scala", - "tweetypie/servo/util", - "snowflake:id", - "src/thrift/com/twitter/escherbird:media-annotation-structs-scala", - "src/thrift/com/twitter/manhattan:internal-scala", - "tweetypie/common/src/thrift/com/twitter/tweetypie:media-entity-scala", - "tweetypie/common/src/thrift/com/twitter/tweetypie:service-scala", - "tweetypie/common/src/thrift/com/twitter/tweetypie:tweet-scala", - "stitch/stitch-core", - "storage/clients/manhattan/client/src/main/scala", - "tbird-thrift:scala", - "tweetypie/common/src/scala/com/twitter/tweetypie/additionalfields", - "tweetypie/common/src/scala/com/twitter/tweetypie/client_id", - "tweetypie/common/src/scala/com/twitter/tweetypie/util", - "tweetypie/common/src/thrift/com/twitter/tweetypie/storage_internal:storage_internal-scala", - "util-internal/scribe", - "util/util-core:scala", - "util/util-slf4j-api/src/main/scala/com/twitter/util/logging", - "util/util-stats/src/main/scala", - ], -) diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BUILD.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BUILD.docx new file mode 100644 index 000000000..fa90a16c1 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BUILD.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BounceDeleteHandler.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BounceDeleteHandler.docx new file mode 100644 index 000000000..d16232ac5 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BounceDeleteHandler.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BounceDeleteHandler.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BounceDeleteHandler.scala deleted file mode 100644 index 224c09cb0..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/BounceDeleteHandler.scala +++ /dev/null @@ -1,20 +0,0 @@ -package com.twitter.tweetypie.storage - -import 
com.twitter.util.Time - -object BounceDeleteHandler { - def apply( - insert: ManhattanOperations.Insert, - scribe: Scribe - ): TweetStorageClient.BounceDelete = - tweetId => { - val mhTimestamp = Time.now - val bounceDeleteRecord = TweetStateRecord - .BounceDeleted(tweetId, mhTimestamp.inMillis) - .toTweetMhRecord - - insert(bounceDeleteRecord).onSuccess { _ => - scribe.logRemoved(tweetId, mhTimestamp, isSoftDeleted = true) - } - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Codecs.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Codecs.docx new file mode 100644 index 000000000..f59c60eea Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Codecs.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Codecs.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Codecs.scala deleted file mode 100644 index 670014f26..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Codecs.scala +++ /dev/null @@ -1,242 +0,0 @@ -package com.twitter.tweetypie.storage - -import com.twitter.bijection.Conversion.asMethod -import com.twitter.bijection.Injection -import com.twitter.scrooge.TFieldBlob -import com.twitter.storage.client.manhattan.kv._ -import com.twitter.tweetypie.storage.Response.FieldResponse -import com.twitter.tweetypie.storage.Response.FieldResponseCode -import com.twitter.tweetypie.storage_internal.thriftscala.CoreFields -import com.twitter.tweetypie.storage_internal.thriftscala.InternalTweet -import com.twitter.tweetypie.storage_internal.thriftscala.StoredTweet -import java.io.ByteArrayOutputStream -import java.nio.ByteBuffer -import org.apache.thrift.protocol.TBinaryProtocol -import org.apache.thrift.transport.TIOStreamTransport -import org.apache.thrift.transport.TMemoryInputTransport -import scala.collection.immutable -import scala.util.control.NoStackTrace - -// NOTE: All field ids and Tweet structure in this file correspond to the StoredTweet struct ONLY - -object ByteArrayCodec { - def toByteBuffer(byteArray: Array[Byte]): ByteBuffer = byteArray.as[ByteBuffer] - def fromByteBuffer(buffer: ByteBuffer): Array[Byte] = buffer.as[Array[Byte]] -} - -object StringCodec { - private val string2ByteBuffer = Injection.connect[String, Array[Byte], ByteBuffer] - def toByteBuffer(strValue: String): ByteBuffer = string2ByteBuffer(strValue) - def fromByteBuffer(buffer: ByteBuffer): String = string2ByteBuffer.invert(buffer).get -} - -/** - * Terminology - * ----------- - * Tweet id field : The field number of 'tweetId' in the 'Tweet' thrift structure (i.e. "1") - * - * First AdditionalField id : The ID of the first additional field in the 'Tweet' thrift structure. All field ids less than this are - * considered internal and all the ids greater than or equal to this field id are considered 'Additional fields'. - * This is set to 100. - * - * Internal Fields : Fields with ids [1 to firstAdditionalFieldId) (excluding firstAdditionalFieldId) - * - * Core fields : (Subset of Internal fields) - Fields with ids [1 to 8, 19]. These fields are "packed" together and stored - * under a single key. This key is referred to as "CoreFieldsKey" (see @TweetKeyType.CoreFieldsKey). - * (Note: field 1 is actually skipped when packing, as it is the tweet id and need not be - * explicitly stored since the pkey already contains the tweet id.) - * - * Root Core field id : The field id under which the packed core fields are stored in Manhattan. 
(This is field Id "1") - * - * Required fields : (Subset of Core fields) - Fields with ids [1 to 5] that MUST be present on every tweet. - * - * Additional Fields : All fields with field ids >= 'firstAdditionalFieldId' - * - * Compiled Additional fields : (Subset of Additional Fields) - All fields that the storage library knows about - * (i.e. present on the latest storage_internal.thrift that is compiled-in). - * - * Passthrough fields : (Subset of Additional Fields) - The fields on storage_internal.thrift that the storage library is NOT aware of. - * These field ids are obtained by looking at the "_passThroughFields" member of the scrooge-generated - * 'Tweet' object. - * - * coreFieldsIdInInternalTweet: This is the field id of the core fields (the only field) in the InternalTweet struct - */ -object TweetFields { - val firstAdditionalFieldId: Short = 100 - val tweetIdField: Short = 1 - val geoFieldId: Short = 9 - - // The field under which all the core field values are stored (in serialized form). - val rootCoreFieldId: Short = 1 - - val coreFieldIds: immutable.IndexedSeq[FieldId] = { - val quotedTweetFieldId: Short = 19 - (1 to 8).map(_.toShort) ++ Seq(quotedTweetFieldId) - } - val requiredFieldIds: immutable.IndexedSeq[FieldId] = (1 to 5).map(_.toShort) - - val coreFieldsIdInInternalTweet: Short = 1 - - val compiledAdditionalFieldIds: Seq[FieldId] = - StoredTweet.metaData.fields.filter(_.id >= firstAdditionalFieldId).map(_.id) - val internalFieldIds: Seq[FieldId] = - StoredTweet.metaData.fields.filter(_.id < firstAdditionalFieldId).map(_.id) - val nonCoreInternalFields: Seq[FieldId] = - (internalFieldIds.toSet -- coreFieldIds.toSet).toSeq - def getAdditionalFieldIds(tweet: StoredTweet): Seq[FieldId] = - compiledAdditionalFieldIds ++ tweet._passthroughFields.keys.toSeq -} - -/** - * Helper object to convert a TFieldBlob to the ByteBuffer that gets stored in Manhattan. - * - * The following is the format in which the TFieldBlob gets stored: - * [Version][TField][TFieldBlob] - */ -object TFieldBlobCodec { - val BinaryProtocolFactory: TBinaryProtocol.Factory = new TBinaryProtocol.Factory() - val FormatVersion = 1.0 - - def toByteBuffer(tFieldBlob: TFieldBlob): ByteBuffer = { - val baos = new ByteArrayOutputStream() - val prot = BinaryProtocolFactory.getProtocol(new TIOStreamTransport(baos)) - - prot.writeDouble(FormatVersion) - prot.writeFieldBegin(tFieldBlob.field) - prot.writeBinary(ByteArrayCodec.toByteBuffer(tFieldBlob.data)) - - ByteArrayCodec.toByteBuffer(baos.toByteArray) - } - - def fromByteBuffer(buffer: ByteBuffer): TFieldBlob = { - val byteArray = ByteArrayCodec.fromByteBuffer(buffer) - val prot = BinaryProtocolFactory.getProtocol(new TMemoryInputTransport(byteArray)) - - val version = prot.readDouble() - if (version != FormatVersion) { - throw new VersionMismatchError( - "Version mismatch in decoding ByteBuffer to TFieldBlob. " + - "Actual version: " + version + ". Expected version: " + FormatVersion - ) - } - - val tField = prot.readFieldBegin() - val dataBuffer = prot.readBinary() - val data = ByteArrayCodec.fromByteBuffer(dataBuffer) - - TFieldBlob(tField, data) - } -} - -/** - * Helper object to help convert a 'CoreFields' object to/from TFieldBlob (and also to construct - * a 'CoreFields' object from a 'StoredTweet' object) - * - * More details: - * - A subset of fields on the 'StoredTweet' thrift structure (2-8,19) are 'packed' and stored - * together as a serialized TFieldBlob object under a single key in Manhattan (see TweetKeyCodec - * helper object above for more details). 
- * - * - To make packing/unpacking the fields to/from a TFieldBlob object easier, we created the following - * two helper thrift structures, 'CoreFields' and 'InternalTweet': - * - * // The field Ids and types here MUST exactly match field Ids on 'StoredTweet' thrift structure. - * struct CoreFields { - * 2: optional i64 user_id - * ... - * 8: optional i64 contributor_id - * ... - * 19: optional StoredQuotedTweet stored_quoted_tweet - * - * } - * - * // The field id of core fields MUST be "1" - * struct InternalTweet { - * 1: CoreFields coreFields - * } - * - * - Given the above two structures, packing/unpacking fields (2-8,19) on a StoredTweet object into a TFieldBlob - * becomes trivial: - * For packing: - * (i) Copy fields (2-8,19) from the StoredTweet object to a new CoreFields object - * (ii) Create a new InternalTweet object with the 'CoreFields' object constructed in step (i) above - * (iii) Extract field "1" as a TFieldBlob from InternalTweet (by calling the scrooge-generated "getFieldBlob(1)" - * function on the InternalTweet object) - * - * For unpacking: - * (i) Create an empty 'InternalTweet' object - * (ii) Call the scrooge-generated 'setField' by passing the TFieldBlob (created by the packing steps above) - * (iii) Doing step (ii) above will create a hydrated 'CoreFields' object that can be accessed via the 'coreFields' - * member of the 'InternalTweet' object. - */ -object CoreFieldsCodec { - val coreFieldIds: Seq[FieldId] = CoreFields.metaData.fields.map(_.id) - - // "Pack" the core fields, i.e. convert a 'CoreFields' object to a "packed" TFieldBlob (see description - // above for more details) - def toTFieldBlob(coreFields: CoreFields): TFieldBlob = { - InternalTweet(Some(coreFields)).getFieldBlob(TweetFields.coreFieldsIdInInternalTweet).get - } - - // "Unpack" the core fields from a packed TFieldBlob into a CoreFields object (see description above for - // more details) - def fromTFieldBlob(tFieldBlob: TFieldBlob): CoreFields = { - InternalTweet().setField(tFieldBlob).coreFields.get - } - - // "Unpack" the core fields from a packed TFieldBlob into a Map of core-fieldId -> TFieldBlob - def unpackFields(tFieldBlob: TFieldBlob): Map[Short, TFieldBlob] = - fromTFieldBlob(tFieldBlob).getFieldBlobs(coreFieldIds) - - // Create a 'CoreFields' thrift object from a 'StoredTweet' thrift object. - def fromTweet(tweet: StoredTweet): CoreFields = { - // As mentioned above, the field ids and types on the 'CoreFields' struct exactly match the - // corresponding fields on the StoredTweet structure. So it is safe to call .getFieldBlob() on the StoredTweet object - // and pass the returned tFieldBlob to 'setField' on the 'CoreFields' object. 
- coreFieldIds.foldLeft(CoreFields()) { - case (core, fieldId) => - tweet.getFieldBlob(fieldId) match { - case None => core - case Some(tFieldBlob) => core.setField(tFieldBlob) - } - } - } -} - -/** - * Helper object to convert ManhattanException to FieldResponseCode thrift object - */ -object FieldResponseCodeCodec { - import FieldResponseCodec.ValueNotFoundException - - def fromManhattanException(mhException: ManhattanException): FieldResponseCode = { - mhException match { - case _: ValueNotFoundException => FieldResponseCode.ValueNotFound - case _: InternalErrorManhattanException => FieldResponseCode.Error - case _: InvalidRequestManhattanException => FieldResponseCode.InvalidRequest - case _: DeniedManhattanException => FieldResponseCode.Error - case _: UnsatisfiableManhattanException => FieldResponseCode.Error - case _: TimeoutManhattanException => FieldResponseCode.Timeout - } - } -} - -/** - * Helper object to construct FieldResponse thrift object from an Exception. - * This is typically called to convert 'ManhattanException' object to 'FieldResponse' thrift object - */ -object FieldResponseCodec { - class ValueNotFoundException extends ManhattanException("Value not found!") with NoStackTrace - private[storage] val NotFound = new ValueNotFoundException - - def fromThrowable(e: Throwable, additionalMsg: Option[String] = None): FieldResponse = { - val (respCode, errMsg) = e match { - case mhException: ManhattanException => - (FieldResponseCodeCodec.fromManhattanException(mhException), mhException.getMessage) - case _ => (FieldResponseCode.Error, e.getMessage) - } - - val respMsg = additionalMsg.map(_ + ". " + errMsg).orElse(Some(errMsg.toString)) - FieldResponse(respCode, respMsg) - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/DeleteAdditionalFieldsHandler.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/DeleteAdditionalFieldsHandler.docx new file mode 100644 index 000000000..3e813a2f1 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/DeleteAdditionalFieldsHandler.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/DeleteAdditionalFieldsHandler.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/DeleteAdditionalFieldsHandler.scala deleted file mode 100644 index 5c89c7a5e..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/DeleteAdditionalFieldsHandler.scala +++ /dev/null @@ -1,67 +0,0 @@ -package com.twitter.tweetypie.storage - -import com.twitter.finagle.stats.StatsReceiver -import com.twitter.stitch.Stitch -import com.twitter.storage.client.manhattan.kv.DeniedManhattanException -import com.twitter.tweetypie.storage.TweetUtils._ -import com.twitter.util.Throw -import com.twitter.util.Time - -object DeleteAdditionalFieldsHandler { - def apply( - delete: ManhattanOperations.Delete, - stats: StatsReceiver - ): TweetStorageClient.DeleteAdditionalFields = - (unfilteredTweetIds: Seq[TweetId], additionalFields: Seq[Field]) => { - val tweetIds = unfilteredTweetIds.filter(_ > 0) - val additionalFieldIds = additionalFields.map(_.id) - require(additionalFields.nonEmpty, "Additional fields to delete cannot be empty") - require( - additionalFieldIds.min >= TweetFields.firstAdditionalFieldId, - s"Additional fields $additionalFields must be in additional field range (>= ${TweetFields.firstAdditionalFieldId})" - ) - - Stats.addWidthStat("deleteAdditionalFields", "tweetIds", tweetIds.size, stats) - Stats.addWidthStat( - "deleteAdditionalFields", - 
"additionalFieldIds", - additionalFieldIds.size, - stats - ) - Stats.updatePerFieldQpsCounters( - "deleteAdditionalFields", - additionalFieldIds, - tweetIds.size, - stats - ) - val mhTimestamp = Time.now - - val stitches = tweetIds.map { tweetId => - val (fieldIds, mhKeysToDelete) = - additionalFieldIds.map { fieldId => - (fieldId, TweetKey.additionalFieldsKey(tweetId, fieldId)) - }.unzip - - val deletionStitches = mhKeysToDelete.map { mhKeyToDelete => - delete(mhKeyToDelete, Some(mhTimestamp)).liftToTry - } - - Stitch.collect(deletionStitches).map { responsesTries => - val wasRateLimited = responsesTries.exists { - case Throw(e: DeniedManhattanException) => true - case _ => false - } - - val resultsPerTweet = fieldIds.zip(responsesTries).toMap - - if (wasRateLimited) { - buildTweetOverCapacityResponse("deleteAdditionalFields", tweetId, resultsPerTweet) - } else { - buildTweetResponse("deleteAdditionalFields", tweetId, resultsPerTweet) - } - } - } - - Stitch.collect(stitches) - } -} diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Field.docx b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Field.docx new file mode 100644 index 000000000..9772ce885 Binary files /dev/null and b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Field.docx differ diff --git a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Field.scala b/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Field.scala deleted file mode 100644 index 093559234..000000000 --- a/tweetypie/common/src/scala/com/twitter/tweetypie/storage/Field.scala +++ /dev/null @@ -1,41 +0,0 @@ -package com.twitter.tweetypie.storage - -import com.twitter.tweetypie.additionalfields.AdditionalFields -import com.twitter.tweetypie.storage_internal.thriftscala.StoredTweet -import com.twitter.tweetypie.thriftscala.{Tweet => TpTweet} - -/** - * A field of the stored version of a tweet to read, update, or delete. - * - * There is not a one-to-one correspondence between the fields ids of - * [[com.twitter.tweetypie.thriftscala.Tweet]] and - * [[com.twitter.tweetypie.storage_internal.thriftscala.StoredTweet]]. For example, in StoredTweet, - * the nsfwUser property is field 11; in Tweet, it is a property of the coreData struct in field 2. - * To circumvent the confusion of using one set of field ids or the other, callers use instances of - * [[Field]] to reference the part of the object to modify. - */ -class Field private[storage] (val id: Short) extends AnyVal { - override def toString: String = id.toString -} - -/** - * NOTE: Make sure `AllUpdatableCompiledFields` is kept up to date when adding any new field - */ -object Field { - import AdditionalFields.isAdditionalFieldId - val Geo: Field = new Field(StoredTweet.GeoField.id) - val HasTakedown: Field = new Field(StoredTweet.HasTakedownField.id) - val NsfwUser: Field = new Field(StoredTweet.NsfwUserField.id) - val NsfwAdmin: Field = new Field(StoredTweet.NsfwAdminField.id) - val TweetypieOnlyTakedownCountryCodes: Field = - new Field(TpTweet.TweetypieOnlyTakedownCountryCodesField.id) - val TweetypieOnlyTakedownReasons: Field = - new Field(TpTweet.TweetypieOnlyTakedownReasonsField.id) - - val AllUpdatableCompiledFields: Set[Field] = Set(Geo, HasTakedown, NsfwUser, NsfwAdmin) - - def additionalField(id: Short): Field = { - require(isAdditionalFieldId(id), "field id must be in the additional field range") - new Field(id) - } -}