Mirror of https://github.com/twitter/the-algorithm.git (synced 2024-11-16 08:29:21 +01:00)

Commit 9ebf058832 (parent f3c5ff35cb): [docx] split commit for file 3400
Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>

Binary file not shown.
@@ -1,281 +0,0 @@
package com.twitter.product_mixer.shared_library.observer

import com.twitter.finagle.stats.Counter
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.product_mixer.shared_library.observer.Observer.ArrowObserver
import com.twitter.product_mixer.shared_library.observer.Observer.FunctionObserver
import com.twitter.product_mixer.shared_library.observer.Observer.FutureObserver
import com.twitter.product_mixer.shared_library.observer.Observer.Observer
import com.twitter.product_mixer.shared_library.observer.Observer.StitchObserver
import com.twitter.stitch.Arrow
import com.twitter.stitch.Stitch
import com.twitter.util.Future
import com.twitter.util.Try

/**
 * Helper functions to observe requests, successes, failures, cancellations, exceptions, latency,
 * and result counts. Supports native functions and asynchronous operations.
 */
object ResultsObserver {
  val Total = "total"
  val Found = "found"
  val NotFound = "not_found"

  /**
   * Helper function to observe a stitch and result counts
   *
   * @see [[StitchResultsObserver]]
   */
  def stitchResults[T](
    size: T => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): StitchResultsObserver[T] = {
    new StitchResultsObserver[T](size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe a stitch and traversable (e.g. Seq, Set) result counts
   *
   * @see [[StitchResultsObserver]]
   */
  def stitchResults[T <: TraversableOnce[_]](
    statsReceiver: StatsReceiver,
    scopes: String*
  ): StitchResultsObserver[T] = {
    new StitchResultsObserver[T](_.size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe an arrow and result counts
   *
   * @see [[ArrowResultsObserver]]
   */
  def arrowResults[In, Out](
    size: Out => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): ArrowResultsObserver[In, Out] = {
    new ArrowResultsObserver[In, Out](size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe an arrow and traversable (e.g. Seq, Set) result counts
   *
   * @see [[ArrowResultsObserver]]
   */
  def arrowResults[In, Out <: TraversableOnce[_]](
    statsReceiver: StatsReceiver,
    scopes: String*
  ): ArrowResultsObserver[In, Out] = {
    new ArrowResultsObserver[In, Out](_.size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe an arrow and result counts
   *
   * @see [[TransformingArrowResultsObserver]]
   */
  def transformingArrowResults[In, Out, Transformed](
    transformer: Out => Try[Transformed],
    size: Transformed => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): TransformingArrowResultsObserver[In, Out, Transformed] = {
    new TransformingArrowResultsObserver[In, Out, Transformed](
      transformer,
      size,
      statsReceiver,
      scopes)
  }

  /**
   * Helper function to observe an arrow and traversable (e.g. Seq, Set) result counts
   *
   * @see [[TransformingArrowResultsObserver]]
   */
  def transformingArrowResults[In, Out, Transformed <: TraversableOnce[_]](
    transformer: Out => Try[Transformed],
    statsReceiver: StatsReceiver,
    scopes: String*
  ): TransformingArrowResultsObserver[In, Out, Transformed] = {
    new TransformingArrowResultsObserver[In, Out, Transformed](
      transformer,
      _.size,
      statsReceiver,
      scopes)
  }

  /**
   * Helper function to observe a future and result counts
   *
   * @see [[FutureResultsObserver]]
   */
  def futureResults[T](
    size: T => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): FutureResultsObserver[T] = {
    new FutureResultsObserver[T](size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe a future and traversable (e.g. Seq, Set) result counts
   *
   * @see [[FutureResultsObserver]]
   */
  def futureResults[T <: TraversableOnce[_]](
    statsReceiver: StatsReceiver,
    scopes: String*
  ): FutureResultsObserver[T] = {
    new FutureResultsObserver[T](_.size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe a function and result counts
   *
   * @see [[FunctionResultsObserver]]
   */
  def functionResults[T](
    size: T => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): FunctionResultsObserver[T] = {
    new FunctionResultsObserver[T](size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe a function and traversable (e.g. Seq, Set) result counts
   *
   * @see [[FunctionResultsObserver]]
   */
  def functionResults[T <: TraversableOnce[_]](
    statsReceiver: StatsReceiver,
    scopes: String*
  ): FunctionResultsObserver[T] = {
    new FunctionResultsObserver[T](_.size, statsReceiver, scopes)
  }

  /** [[StitchObserver]] that also records result size */
  class StitchResultsObserver[T](
    override val size: T => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends StitchObserver[T](statsReceiver, scopes)
      with ResultsObserver[T] {

    override def apply(stitch: => Stitch[T]): Stitch[T] =
      super
        .apply(stitch)
        .onSuccess(observeResults)
  }

  /** [[ArrowObserver]] that also records result size */
  class ArrowResultsObserver[In, Out](
    override val size: Out => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends ArrowObserver[In, Out](statsReceiver, scopes)
      with ResultsObserver[Out] {

    override def apply(arrow: Arrow[In, Out]): Arrow[In, Out] =
      super
        .apply(arrow)
        .onSuccess(observeResults)
  }

  /**
   * [[TransformingArrowResultsObserver]] functions like an [[ArrowObserver]] except
   * that it transforms the result using [[transformer]] before recording stats.
   *
   * The original non-transformed result is then returned.
   */
  class TransformingArrowResultsObserver[In, Out, Transformed](
    val transformer: Out => Try[Transformed],
    override val size: Transformed => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends Observer[Transformed]
      with ResultsObserver[Transformed] {

    /**
     * Returns a new Arrow that records stats on the result after applying [[transformer]] when it's run.
     * The original, non-transformed, result of the Arrow is passed through.
     *
     * @note the provided Arrow must contain the parts that need to be timed.
     *       Using this on just the result of the computation will make the latency stat
     *       incorrect.
     */
    def apply(arrow: Arrow[In, Out]): Arrow[In, Out] = {
      Arrow
        .time(arrow)
        .map {
          case (response, stitchRunDuration) =>
            observe(response.flatMap(transformer), stitchRunDuration)
              .onSuccess(observeResults)
            response
        }.lowerFromTry
    }
  }

  /** [[FutureObserver]] that also records result size */
  class FutureResultsObserver[T](
    override val size: T => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends FutureObserver[T](statsReceiver, scopes)
      with ResultsObserver[T] {

    override def apply(future: => Future[T]): Future[T] =
      super
        .apply(future)
        .onSuccess(observeResults)
  }

  /** [[FunctionObserver]] that also records result size */
  class FunctionResultsObserver[T](
    override val size: T => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends FunctionObserver[T](statsReceiver, scopes)
      with ResultsObserver[T] {

    override def apply(f: => T): T = observeResults(super.apply(f))
  }

  /** [[ResultsObserver]] provides methods for recording stats for the result size */
  trait ResultsObserver[T] {
    protected val statsReceiver: StatsReceiver

    /** Scopes that prefix all stats */
    protected val scopes: Seq[String]

    protected val totalCounter: Counter = statsReceiver.counter(scopes :+ Total: _*)
    protected val foundCounter: Counter = statsReceiver.counter(scopes :+ Found: _*)
    protected val notFoundCounter: Counter = statsReceiver.counter(scopes :+ NotFound: _*)

    /** Given a [[T]], returns its size. */
    protected val size: T => Int

    /** Records the size of the `results` using [[size]] and returns the original value. */
    protected def observeResults(results: T): T = {
      val resultsSize = size(results)
      observeResultsWithSize(results, resultsSize)
    }

    /**
     * Records the `resultsSize` and returns the `results`
     *
     * This is useful if the size is already available and would be expensive to recalculate.
     */
    protected def observeResultsWithSize(results: T, resultsSize: Int): T = {
      if (resultsSize > 0) {
        totalCounter.incr(resultsSize)
        foundCounter.incr()
      } else {
        notFoundCounter.incr()
      }
      results
    }
  }
}
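For context, a minimal usage sketch of the ResultsObserver helpers deleted above. The client class, the fetchUsers call, and the "user_lookup"/"fetch" scope names are hypothetical and not part of the original file; only the futureResults helper and its apply hook come from the code above.

import com.twitter.finagle.stats.StatsReceiver
import com.twitter.product_mixer.shared_library.observer.ResultsObserver
import com.twitter.util.Future

class UserLookupClient(statsReceiver: StatsReceiver) {
  // Counts total/found/not_found under the "user_lookup"/"fetch" scope, in addition to the
  // request/success/failure/latency stats recorded by the underlying FutureObserver.
  private val fetchObserver =
    ResultsObserver.futureResults[Seq[Long]](statsReceiver, "user_lookup", "fetch")

  // fetchUsers is a stand-in for a real downstream call returning Future[Seq[Long]].
  private def fetchUsers(ids: Seq[Long]): Future[Seq[Long]] = Future.value(ids)

  def fetch(ids: Seq[Long]): Future[Seq[Long]] = fetchObserver(fetchUsers(ids))
}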
Binary file not shown.
@@ -1,243 +0,0 @@
package com.twitter.product_mixer.shared_library.observer

import com.twitter.finagle.stats.Stat
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.product_mixer.shared_library.observer.Observer.ArrowObserver
import com.twitter.product_mixer.shared_library.observer.Observer.FunctionObserver
import com.twitter.product_mixer.shared_library.observer.Observer.FutureObserver
import com.twitter.product_mixer.shared_library.observer.Observer.Observer
import com.twitter.product_mixer.shared_library.observer.Observer.StitchObserver
import com.twitter.product_mixer.shared_library.observer.ResultsObserver.ResultsObserver
import com.twitter.stitch.Arrow
import com.twitter.stitch.Stitch
import com.twitter.util.Future
import com.twitter.util.Try

/**
 * Helper functions to observe requests, successes, failures, cancellations, exceptions, latency,
 * result counts, and time-series stats. Supports native functions and asynchronous operations.
 *
 * Note that since time-series stats are expensive to compute (relative to counters), prefer
 * [[ResultsObserver]] unless a time-series stat is needed.
 */
object ResultsStatsObserver {
  val Size = "size"

  /**
   * Helper function to observe a stitch and result counts and time-series stats
   */
  def stitchResultsStats[T](
    size: T => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): StitchResultsStatsObserver[T] = {
    new StitchResultsStatsObserver[T](size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe a stitch and traversable (e.g. Seq, Set) result counts and
   * time-series stats
   */
  def stitchResultsStats[T <: TraversableOnce[_]](
    statsReceiver: StatsReceiver,
    scopes: String*
  ): StitchResultsStatsObserver[T] = {
    new StitchResultsStatsObserver[T](_.size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe an arrow and result counts and time-series stats
   */
  def arrowResultsStats[T, U](
    size: U => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): ArrowResultsStatsObserver[T, U] = {
    new ArrowResultsStatsObserver[T, U](size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe an arrow and traversable (e.g. Seq, Set) result counts and
   * time-series stats
   */
  def arrowResultsStats[T, U <: TraversableOnce[_]](
    statsReceiver: StatsReceiver,
    scopes: String*
  ): ArrowResultsStatsObserver[T, U] = {
    new ArrowResultsStatsObserver[T, U](_.size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe an arrow and result counts
   *
   * @see [[TransformingArrowResultsStatsObserver]]
   */
  def transformingArrowResultsStats[In, Out, Transformed](
    transformer: Out => Try[Transformed],
    size: Transformed => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): TransformingArrowResultsStatsObserver[In, Out, Transformed] = {
    new TransformingArrowResultsStatsObserver[In, Out, Transformed](
      transformer,
      size,
      statsReceiver,
      scopes)
  }

  /**
   * Helper function to observe an arrow and traversable (e.g. Seq, Set) result counts
   *
   * @see [[TransformingArrowResultsStatsObserver]]
   */
  def transformingArrowResultsStats[In, Out, Transformed <: TraversableOnce[_]](
    transformer: Out => Try[Transformed],
    statsReceiver: StatsReceiver,
    scopes: String*
  ): TransformingArrowResultsStatsObserver[In, Out, Transformed] = {
    new TransformingArrowResultsStatsObserver[In, Out, Transformed](
      transformer,
      _.size,
      statsReceiver,
      scopes)
  }

  /**
   * Helper function to observe a future and result counts and time-series stats
   */
  def futureResultsStats[T](
    size: T => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): FutureResultsStatsObserver[T] = {
    new FutureResultsStatsObserver[T](size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe a future and traversable (e.g. Seq, Set) result counts and
   * time-series stats
   */
  def futureResultsStats[T <: TraversableOnce[_]](
    statsReceiver: StatsReceiver,
    scopes: String*
  ): FutureResultsStatsObserver[T] = {
    new FutureResultsStatsObserver[T](_.size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe a function and result counts and time-series stats
   */
  def functionResultsStats[T](
    size: T => Int,
    statsReceiver: StatsReceiver,
    scopes: String*
  ): FunctionResultsStatsObserver[T] = {
    new FunctionResultsStatsObserver[T](size, statsReceiver, scopes)
  }

  /**
   * Helper function to observe a function and traversable (e.g. Seq, Set) result counts and
   * time-series stats
   */
  def functionResultsStats[T <: TraversableOnce[_]](
    statsReceiver: StatsReceiver,
    scopes: String*
  ): FunctionResultsStatsObserver[T] = {
    new FunctionResultsStatsObserver[T](_.size, statsReceiver, scopes)
  }

  class StitchResultsStatsObserver[T](
    override val size: T => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends StitchObserver[T](statsReceiver, scopes)
      with ResultsStatsObserver[T] {

    override def apply(stitch: => Stitch[T]): Stitch[T] =
      super
        .apply(stitch)
        .onSuccess(observeResults)
  }

  class ArrowResultsStatsObserver[T, U](
    override val size: U => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends ArrowObserver[T, U](statsReceiver, scopes)
      with ResultsStatsObserver[U] {

    override def apply(arrow: Arrow[T, U]): Arrow[T, U] =
      super
        .apply(arrow)
        .onSuccess(observeResults)
  }

  /**
   * [[TransformingArrowResultsStatsObserver]] functions like an [[ArrowObserver]] except
   * that it transforms the result using [[transformer]] before recording stats.
   *
   * The original non-transformed result is then returned.
   */
  class TransformingArrowResultsStatsObserver[In, Out, Transformed](
    val transformer: Out => Try[Transformed],
    override val size: Transformed => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends Observer[Transformed]
      with ResultsStatsObserver[Transformed] {

    /**
     * Returns a new Arrow that records stats on the result after applying [[transformer]] when it's run.
     * The original, non-transformed, result of the Arrow is passed through.
     *
     * @note the provided Arrow must contain the parts that need to be timed.
     *       Using this on just the result of the computation will make the latency stat
     *       incorrect.
     */
    def apply(arrow: Arrow[In, Out]): Arrow[In, Out] = {
      Arrow
        .time(arrow)
        .map {
          case (response, stitchRunDuration) =>
            observe(response.flatMap(transformer), stitchRunDuration)
              .onSuccess(observeResults)
            response
        }.lowerFromTry
    }
  }

  class FutureResultsStatsObserver[T](
    override val size: T => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends FutureObserver[T](statsReceiver, scopes)
      with ResultsStatsObserver[T] {

    override def apply(future: => Future[T]): Future[T] =
      super
        .apply(future)
        .onSuccess(observeResults)
  }

  class FunctionResultsStatsObserver[T](
    override val size: T => Int,
    override val statsReceiver: StatsReceiver,
    override val scopes: Seq[String])
      extends FunctionObserver[T](statsReceiver, scopes)
      with ResultsStatsObserver[T] {

    override def apply(f: => T): T = {
      observeResults(super.apply(f))
    }
  }

  trait ResultsStatsObserver[T] extends ResultsObserver[T] {
    private val sizeStat: Stat = statsReceiver.stat(scopes :+ Size: _*)

    protected override def observeResults(results: T): T = {
      val resultsSize = size(results)
      sizeStat.add(resultsSize)
      observeResultsWithSize(results, resultsSize)
    }
  }
}
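The stats variant deleted above follows the same call pattern but also records a "size" time-series Stat. A hedged sketch with hypothetical scope names and a stubbed fetchTweets call, both of which are not in the original file:

import com.twitter.finagle.stats.StatsReceiver
import com.twitter.product_mixer.shared_library.observer.ResultsStatsObserver
import com.twitter.stitch.Stitch

class TweetLookup(statsReceiver: StatsReceiver) {
  // Prefer ResultsObserver unless the size distribution is actually needed,
  // since Stats cost more to compute than Counters (per the note above).
  private val fetchObserver =
    ResultsStatsObserver.stitchResultsStats[Seq[Long]](statsReceiver, "tweet_lookup", "fetch")

  // fetchTweets stands in for a real Stitch-returning downstream call.
  private def fetchTweets(ids: Seq[Long]): Stitch[Seq[Long]] = Stitch.value(ids)

  def fetch(ids: Seq[Long]): Stitch[Seq[Long]] = fetchObserver(fetchTweets(ids))
}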
@@ -1,20 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "finagle/finagle-core/src/main",
        "finagle/finagle-thriftmux/src/main/scala",
        "finatra-internal/mtls-http/src/main/scala",
        "finatra-internal/mtls-thriftmux/src/main/scala",
        "util/util-core",
    ],
    exports = [
        "finagle/finagle-core/src/main",
        "finagle/finagle-thriftmux/src/main/scala",
        "finatra-internal/mtls-http/src/main/scala",
        "finatra-internal/mtls-thriftmux/src/main/scala",
        "util/util-core",
    ],
)
Binary file not shown.
Binary file not shown.
@@ -1,198 +0,0 @@
package com.twitter.product_mixer.shared_library.thrift_client

import com.twitter.conversions.DurationOps._
import com.twitter.finagle.ThriftMux
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
import com.twitter.finagle.mtls.client.MtlsStackClient._
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.finagle.thrift.ClientId
import com.twitter.finagle.thrift.service.Filterable
import com.twitter.finagle.thrift.service.MethodPerEndpointBuilder
import com.twitter.finagle.thrift.service.ServicePerEndpointBuilder
import com.twitter.finagle.thriftmux.MethodBuilder
import com.twitter.util.Duration
import org.apache.thrift.protocol.TProtocolFactory

sealed trait Idempotency
case object NonIdempotent extends Idempotency
case class Idempotent(maxExtraLoadPercent: Double) extends Idempotency

object FinagleThriftClientBuilder {

  /**
   * Library to build a Finagle Thrift method-per-endpoint client in a less error-prone way than
   * using the builders in Finagle directly. This is achieved by requiring values for fields that
   * should always be set in practice. For example, request timeouts in Finagle are unbounded when
   * not explicitly set, so this method requires that timeout durations are passed in and set on
   * the Finagle builder.
   *
   * Usage of
   * [[com.twitter.inject.thrift.modules.ThriftMethodBuilderClientModule]] is almost always preferred,
   * and the Product Mixer component library [[com.twitter.product_mixer.component_library.module]]
   * package contains numerous examples. However, if multiple versions of a client are needed e.g.
   * for different timeout settings, this method is useful to easily provide multiple variants.
   *
   * @example
   * {{{
   *   final val SampleServiceClientName = "SampleServiceClient"
   *   @Provides
   *   @Singleton
   *   @Named(SampleServiceClientName)
   *   def provideSampleServiceClient(
   *     serviceIdentifier: ServiceIdentifier,
   *     clientId: ClientId,
   *     statsReceiver: StatsReceiver,
   *   ): SampleService.MethodPerEndpoint =
   *     buildFinagleMethodPerEndpoint[SampleService.ServicePerEndpoint, SampleService.MethodPerEndpoint](
   *       serviceIdentifier = serviceIdentifier,
   *       clientId = clientId,
   *       dest = "/s/sample/sample",
   *       label = "sample",
   *       statsReceiver = statsReceiver,
   *       idempotency = Idempotent(5.percent),
   *       timeoutPerRequest = 200.milliseconds,
   *       timeoutTotal = 400.milliseconds
   *     )
   * }}}
   * @param serviceIdentifier Service ID used for S2S Auth
   * @param clientId Client ID
   * @param dest Destination as a Wily path e.g. "/s/sample/sample"
   * @param label Label of the client
   * @param statsReceiver Stats
   * @param idempotency Idempotency semantics of the client
   * @param timeoutPerRequest Thrift client timeout per request. The Finagle default is
   *                          unbounded, which is almost never optimal.
   * @param timeoutTotal Thrift client total timeout. The Finagle default is
   *                     unbounded, which is almost never optimal.
   *                     If the client is set as idempotent, which adds a
   *                     [[com.twitter.finagle.client.BackupRequestFilter]],
   *                     be sure to leave enough room for the backup request. A
   *                     reasonable (albeit usually too large) starting point is to
   *                     make the total timeout 2x relative to the per-request timeout.
   *                     If the client is set as non-idempotent, the total timeout and
   *                     the per-request timeout should be the same, as there will be
   *                     no backup requests.
   * @param connectTimeout Thrift client transport connect timeout. The Finagle default
   *                       of one second is reasonable, but we lower it to match
   *                       acquisitionTimeout for consistency.
   * @param acquisitionTimeout Thrift client session acquisition timeout. The Finagle default
   *                           is unbounded, which is almost never optimal.
   * @param protocolFactoryOverride Override the default protocol factory
   *                                e.g. [[org.apache.thrift.protocol.TCompactProtocol.Factory]]
   * @param servicePerEndpointBuilder implicit service per endpoint builder
   * @param methodPerEndpointBuilder implicit method per endpoint builder
   *
   * @see [[https://twitter.github.io/finagle/guide/MethodBuilder.html user guide]]
   * @see [[https://twitter.github.io/finagle/guide/MethodBuilder.html#idempotency user guide]]
   * @return method per endpoint Finagle Thrift Client
   */
  def buildFinagleMethodPerEndpoint[
    ServicePerEndpoint <: Filterable[ServicePerEndpoint],
    MethodPerEndpoint
  ](
    serviceIdentifier: ServiceIdentifier,
    clientId: ClientId,
    dest: String,
    label: String,
    statsReceiver: StatsReceiver,
    idempotency: Idempotency,
    timeoutPerRequest: Duration,
    timeoutTotal: Duration,
    connectTimeout: Duration = 500.milliseconds,
    acquisitionTimeout: Duration = 500.milliseconds,
    protocolFactoryOverride: Option[TProtocolFactory] = None,
  )(
    implicit servicePerEndpointBuilder: ServicePerEndpointBuilder[ServicePerEndpoint],
    methodPerEndpointBuilder: MethodPerEndpointBuilder[ServicePerEndpoint, MethodPerEndpoint]
  ): MethodPerEndpoint = {
    val service: ServicePerEndpoint = buildFinagleServicePerEndpoint(
      serviceIdentifier = serviceIdentifier,
      clientId = clientId,
      dest = dest,
      label = label,
      statsReceiver = statsReceiver,
      idempotency = idempotency,
      timeoutPerRequest = timeoutPerRequest,
      timeoutTotal = timeoutTotal,
      connectTimeout = connectTimeout,
      acquisitionTimeout = acquisitionTimeout,
      protocolFactoryOverride = protocolFactoryOverride
    )

    ThriftMux.Client.methodPerEndpoint(service)
  }

  /**
   * Build a Finagle Thrift service per endpoint client.
   *
   * @note [[buildFinagleMethodPerEndpoint]] should be preferred over the service per endpoint variant
   *
   * @param serviceIdentifier Service ID used for S2S Auth
   * @param clientId Client ID
   * @param dest Destination as a Wily path e.g. "/s/sample/sample"
   * @param label Label of the client
   * @param statsReceiver Stats
   * @param idempotency Idempotency semantics of the client
   * @param timeoutPerRequest Thrift client timeout per request. The Finagle default is
   *                          unbounded, which is almost never optimal.
   * @param timeoutTotal Thrift client total timeout. The Finagle default is
   *                     unbounded, which is almost never optimal.
   *                     If the client is set as idempotent, which adds a
   *                     [[com.twitter.finagle.client.BackupRequestFilter]],
   *                     be sure to leave enough room for the backup request. A
   *                     reasonable (albeit usually too large) starting point is to
   *                     make the total timeout 2x relative to the per-request timeout.
   *                     If the client is set as non-idempotent, the total timeout and
   *                     the per-request timeout should be the same, as there will be
   *                     no backup requests.
   * @param connectTimeout Thrift client transport connect timeout. The Finagle default
   *                       of one second is reasonable, but we lower it to match
   *                       acquisitionTimeout for consistency.
   * @param acquisitionTimeout Thrift client session acquisition timeout. The Finagle default
   *                           is unbounded, which is almost never optimal.
   * @param protocolFactoryOverride Override the default protocol factory
   *                                e.g. [[org.apache.thrift.protocol.TCompactProtocol.Factory]]
   *
   * @return service per endpoint Finagle Thrift Client
   */
  def buildFinagleServicePerEndpoint[ServicePerEndpoint <: Filterable[ServicePerEndpoint]](
    serviceIdentifier: ServiceIdentifier,
    clientId: ClientId,
    dest: String,
    label: String,
    statsReceiver: StatsReceiver,
    idempotency: Idempotency,
    timeoutPerRequest: Duration,
    timeoutTotal: Duration,
    connectTimeout: Duration = 500.milliseconds,
    acquisitionTimeout: Duration = 500.milliseconds,
    protocolFactoryOverride: Option[TProtocolFactory] = None,
  )(
    implicit servicePerEndpointBuilder: ServicePerEndpointBuilder[ServicePerEndpoint]
  ): ServicePerEndpoint = {
    val thriftMux: ThriftMux.Client = ThriftMux.client
      .withMutualTls(serviceIdentifier)
      .withClientId(clientId)
      .withLabel(label)
      .withStatsReceiver(statsReceiver)
      .withTransport.connectTimeout(connectTimeout)
      .withSession.acquisitionTimeout(acquisitionTimeout)

    val protocolThriftMux: ThriftMux.Client = protocolFactoryOverride
      .map { protocolFactory =>
        thriftMux.withProtocolFactory(protocolFactory)
      }.getOrElse(thriftMux)

    val methodBuilder: MethodBuilder = protocolThriftMux
      .methodBuilder(dest)
      .withTimeoutPerRequest(timeoutPerRequest)
      .withTimeoutTotal(timeoutTotal)

    val idempotencyMethodBuilder: MethodBuilder = idempotency match {
      case NonIdempotent => methodBuilder.nonIdempotent
      case Idempotent(maxExtraLoad) => methodBuilder.idempotent(maxExtraLoad = maxExtraLoad)
    }

    idempotencyMethodBuilder.servicePerEndpoint[ServicePerEndpoint]
  }
}
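The scaladoc above already shows an idempotent example; for completeness, a hedged sketch of the non-idempotent case it describes, where the total timeout equals the per-request timeout because no backup request is sent. SampleService, the dest path, and the timeout values are illustrative only, reusing the names from the @example above:

val client: SampleService.MethodPerEndpoint =
  FinagleThriftClientBuilder.buildFinagleMethodPerEndpoint[
    SampleService.ServicePerEndpoint,
    SampleService.MethodPerEndpoint
  ](
    serviceIdentifier = serviceIdentifier,
    clientId = clientId,
    dest = "/s/sample/sample",
    label = "sample",
    statsReceiver = statsReceiver,
    idempotency = NonIdempotent, // no BackupRequestFilter is installed
    timeoutPerRequest = 200.milliseconds,
    timeoutTotal = 200.milliseconds // equal to the per-request timeout: no backup requests
  )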
@@ -1,48 +0,0 @@
alias(
    name = "frigate-pushservice",
    target = ":frigate-pushservice_lib",
)

target(
    name = "frigate-pushservice_lib",
    dependencies = [
        "frigate/frigate-pushservice-opensource/src/main/scala/com/twitter/frigate/pushservice",
    ],
)

jvm_binary(
    name = "bin",
    basename = "frigate-pushservice",
    main = "com.twitter.frigate.pushservice.PushServiceMain",
    runtime_platform = "java11",
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/ch/qos/logback:logback-classic",
        "finatra/inject/inject-logback/src/main/scala",
        "frigate/frigate-pushservice-opensource/src/main/scala/com/twitter/frigate/pushservice",
        "loglens/loglens-logback/src/main/scala/com/twitter/loglens/logback",
        "twitter-server/logback-classic/src/main/scala",
    ],
    excludes = [
        exclude("com.twitter.translations", "translations-twitter"),
        exclude("org.apache.hadoop", "hadoop-aws"),
        exclude("org.tensorflow"),
        scala_exclude("com.twitter", "ckoia-scala"),
    ],
)

jvm_app(
    name = "bundle",
    basename = "frigate-pushservice-package-dist",
    archive = "zip",
    binary = ":bin",
    tags = ["bazel-compatible"],
)

python3_library(
    name = "mr_model_constants",
    sources = [
        "config/deepbird/constants.py",
    ],
    tags = ["bazel-compatible"],
)
BIN pushservice/BUILD.docx (Normal file): Binary file not shown.
BIN pushservice/README.docx (Normal file): Binary file not shown.
@@ -1,45 +0,0 @@
# Pushservice

Pushservice is the main push recommendation service at Twitter, used to generate recommendation-based notifications for users. It currently powers two functionalities:

- RefreshForPushHandler: This handler determines whether to send a recommendation push to a user based on their ID. It generates the best push recommendation item and coordinates with downstream services to deliver it
- SendHandler: This handler determines whether to send the push to a user based on the given target user details and the provided push recommendation item, and manages its delivery

## Overview

### RefreshForPushHandler

RefreshForPushHandler follows these steps:

- Building Target and checking eligibility
  - Builds a target user object based on the given user ID
  - Performs target-level filterings to determine if the target is eligible for a recommendation push
- Fetch Candidates
  - Retrieves a list of potential candidates for the push by querying various candidate sources using the target
- Candidate Hydration
  - Hydrates the candidate details with batch calls to different downstream services
- Pre-rank Filtering, also called Light Filtering
  - Filters the hydrated candidates with lightweight RPC calls
- Rank
  - Performs feature hydration for candidates and target user
  - Performs light ranking on candidates
  - Performs heavy ranking on candidates
- Take Step, also called Heavy Filtering
  - Takes the top-ranked candidates one by one and applies heavy filtering until one candidate passes all filter steps
- Send
  - Calls the appropriate downstream service to deliver the eligible candidate as a push and in-app notification to the target user

### SendHandler

SendHandler follows these steps:

- Building Target
  - Builds a target user object based on the given user ID
- Candidate Hydration
  - Hydrates the candidate details with batch calls to different downstream services
- Feature Hydration
  - Performs feature hydration for candidates and target user
- Take Step, also called Heavy Filtering
  - Performs filterings and validation checks for the given candidate
- Send
  - Calls the appropriate downstream service to deliver the given candidate as a push and/or in-app notification to the target user
@@ -1,169 +0,0 @@
python37_binary(
    name = "update_warm_start_checkpoint",
    source = "update_warm_start_checkpoint.py",
    tags = ["no-mypy"],
    dependencies = [
        ":deep_norm_lib",
        "3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:update_warm_start_checkpoint",
    ],
)

python3_library(
    name = "params_lib",
    sources = ["params.py"],
    tags = ["no-mypy"],
    dependencies = [
        "3rdparty/python/pydantic:default",
        "src/python/twitter/deepbird/projects/magic_recs/v11/lib:params_lib",
    ],
)

python3_library(
    name = "features_lib",
    sources = ["features.py"],
    tags = ["no-mypy"],
    dependencies = [
        ":params_lib",
        "src/python/twitter/deepbird/projects/magic_recs/libs",
        "twml:twml-nodeps",
    ],
)

python3_library(
    name = "model_pools_lib",
    sources = ["model_pools.py"],
    tags = ["no-mypy"],
    dependencies = [
        ":features_lib",
        ":params_lib",
        "src/python/twitter/deepbird/projects/magic_recs/v11/lib:model_lib",
    ],
)

python3_library(
    name = "graph_lib",
    sources = ["graph.py"],
    tags = ["no-mypy"],
    dependencies = [
        ":params_lib",
        "src/python/twitter/deepbird/projects/magic_recs/libs",
    ],
)

python3_library(
    name = "run_args_lib",
    sources = ["run_args.py"],
    tags = ["no-mypy"],
    dependencies = [
        ":features_lib",
        ":params_lib",
        "twml:twml-nodeps",
    ],
)

python3_library(
    name = "deep_norm_lib",
    sources = ["deep_norm.py"],
    tags = ["no-mypy"],
    dependencies = [
        ":features_lib",
        ":graph_lib",
        ":model_pools_lib",
        ":params_lib",
        ":run_args_lib",
        "src/python/twitter/deepbird/projects/magic_recs/libs",
        "src/python/twitter/deepbird/util/data",
        "twml:twml-nodeps",
    ],
)

python3_library(
    name = "eval_lib",
    sources = ["eval.py"],
    tags = ["no-mypy"],
    dependencies = [
        ":features_lib",
        ":graph_lib",
        ":model_pools_lib",
        ":params_lib",
        ":run_args_lib",
        "src/python/twitter/deepbird/projects/magic_recs/libs",
        "twml:twml-nodeps",
    ],
)

python37_binary(
    name = "deep_norm",
    source = "deep_norm.py",
    dependencies = [
        ":deep_norm_lib",
        "3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:deep_norm",
        "twml",
    ],
)

python37_binary(
    name = "eval",
    source = "eval.py",
    dependencies = [
        ":eval_lib",
        "3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval",
        "twml",
    ],
)

python3_library(
    name = "mlwf_libs",
    tags = ["no-mypy"],
    dependencies = [
        ":deep_norm_lib",
        "twml",
    ],
)

python37_binary(
    name = "train_model",
    source = "deep_norm.py",
    dependencies = [
        ":deep_norm_lib",
        "3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:train_model",
    ],
)

python37_binary(
    name = "train_model_local",
    source = "deep_norm.py",
    dependencies = [
        ":deep_norm_lib",
        "3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:train_model_local",
        "twml",
    ],
)

python37_binary(
    name = "eval_model_local",
    source = "eval.py",
    dependencies = [
        ":eval_lib",
        "3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval_model_local",
        "twml",
    ],
)

python37_binary(
    name = "eval_model",
    source = "eval.py",
    dependencies = [
        ":eval_lib",
        "3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval_model",
    ],
)

python37_binary(
    name = "mlwf_model",
    source = "deep_norm.py",
    dependencies = [
        ":mlwf_libs",
        "3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:mlwf_model",
    ],
)
BIN pushservice/src/main/python/models/heavy_ranking/BUILD.docx (Normal file): Binary file not shown.
BIN pushservice/src/main/python/models/heavy_ranking/README.docx (Normal file): Binary file not shown.
@@ -1,20 +0,0 @@
# Notification Heavy Ranker Model

## Model Context

There are 4 major components of the Twitter notifications recommendation system: 1) candidate generation, 2) light ranking, 3) heavy ranking, and 4) quality control. This notification heavy ranker model is the core ranking model for the personalised notifications recommendation. It is a multi-task learning model that predicts the probabilities that the target users will open and engage with the sent notifications.


## Directory Structure
- BUILD: this file defines python library dependencies
- deep_norm.py: this file contains how to set up continuous training, model evaluation and model exporting for the notification heavy ranker model
- eval.py: the main python entry file to set up the overall model evaluation pipeline
- features.py: this file contains the imported feature list and support functions for feature engineering
- graph.py: this file defines how to build the tensorflow graph with the specified model architecture, loss function and training configuration
- model_pools.py: this file defines the available model types for the heavy ranker
- params.py: this file defines hyper-parameters used in the notification heavy ranker
- run_args.py: this file defines command line parameters to run model training & evaluation
- update_warm_start_checkpoint.py: this file contains the support to modify checkpoints of the given saved heavy ranker model
- lib/BUILD: this file defines python library dependencies for the tensorflow model architecture
- lib/layers.py: this file defines different types of convolution layers to be used in the heavy ranker model
- lib/model.py: this file defines the module containing ClemNet, the heavy ranker model type
- lib/params.py: this file defines parameters used in the heavy ranker model
BIN pushservice/src/main/python/models/heavy_ranking/__init__.docx (Normal file): Binary file not shown.
BIN pushservice/src/main/python/models/heavy_ranking/deep_norm.docx (Normal file): Binary file not shown.
@@ -1,136 +0,0 @@
"""
Training job for the heavy ranker of the push notification service.
"""
from datetime import datetime
import json
import os

import twml

from ..libs.metric_fn_utils import flip_disliked_labels, get_metric_fn
from ..libs.model_utils import read_config
from ..libs.warm_start_utils import get_feature_list_for_heavy_ranking, warm_start_checkpoint
from .features import get_feature_config
from .model_pools import ALL_MODELS
from .params import load_graph_params
from .run_args import get_training_arg_parser

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging


def main() -> None:
  args, _ = get_training_arg_parser().parse_known_args()
  logging.info(f"Parsed args: {args}")

  params = load_graph_params(args)
  logging.info(f"Loaded graph params: {params}")

  param_file = os.path.join(args.save_dir, "params.json")
  logging.info(f"Saving graph params to: {param_file}")
  with tf.io.gfile.GFile(param_file, mode="w") as file:
    json.dump(params.json(), file, ensure_ascii=False, indent=4)

  logging.info(f"Get Feature Config: {args.feature_list}")
  feature_list = read_config(args.feature_list).items()
  feature_config = get_feature_config(
    data_spec_path=args.data_spec,
    params=params,
    feature_list_provided=feature_list,
  )
  feature_list_path = args.feature_list

  warm_start_from = args.warm_start_from
  if args.warm_start_base_dir:
    logging.info(f"Get warm started model from: {args.warm_start_base_dir}.")

    continuous_binary_feat_list_save_path = os.path.join(
      args.warm_start_base_dir, "continuous_binary_feat_list.json"
    )
    warm_start_folder = os.path.join(args.warm_start_base_dir, "best_checkpoint")
    job_name = os.path.basename(args.save_dir)
    ws_output_ckpt_folder = os.path.join(args.warm_start_base_dir, f"warm_start_for_{job_name}")
    if tf.io.gfile.exists(ws_output_ckpt_folder):
      tf.io.gfile.rmtree(ws_output_ckpt_folder)

    tf.io.gfile.mkdir(ws_output_ckpt_folder)

    warm_start_from = warm_start_checkpoint(
      warm_start_folder,
      continuous_binary_feat_list_save_path,
      feature_list_path,
      args.data_spec,
      ws_output_ckpt_folder,
    )
    logging.info(f"Created warm_start_from_ckpt {warm_start_from}.")

  logging.info("Build Trainer.")
  metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)

  trainer = twml.trainers.DataRecordTrainer(
    name="magic_recs",
    params=args,
    build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
    save_dir=args.save_dir,
    run_config=None,
    feature_config=feature_config,
    metric_fn=flip_disliked_labels(metric_fn),
    warm_start_from=warm_start_from,
  )

  logging.info("Build train and eval input functions.")
  train_input_fn = trainer.get_train_input_fn(shuffle=True)
  eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)

  learn = trainer.learn
  if args.distributed or args.num_workers is not None:
    learn = trainer.train_and_evaluate

  if not args.directly_export_best:
    logging.info("Starting training")
    start = datetime.now()
    learn(
      early_stop_minimize=False,
      early_stop_metric="pr_auc_unweighted_OONC",
      early_stop_patience=args.early_stop_patience,
      early_stop_tolerance=args.early_stop_tolerance,
      eval_input_fn=eval_input_fn,
      train_input_fn=train_input_fn,
    )
    logging.info(f"Total training time: {datetime.now() - start}")
  else:
    logging.info("Directly exporting the model")

  if not args.export_dir:
    args.export_dir = os.path.join(args.save_dir, "exported_models")

  logging.info(f"Exporting the model to {args.export_dir}.")
  start = datetime.now()
  twml.contrib.export.export_fn.export_all_models(
    trainer=trainer,
    export_dir=args.export_dir,
    parse_fn=feature_config.get_parse_fn(),
    serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
    export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
  )

  logging.info(f"Total model export time: {datetime.now() - start}")
  logging.info(f"The MLP directory is: {args.save_dir}")

  continuous_binary_feat_list_save_path = os.path.join(
    args.save_dir, "continuous_binary_feat_list.json"
  )
  logging.info(
    f"Saving the list of continuous and binary features to {continuous_binary_feat_list_save_path}."
  )
  continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
    feature_list_path, args.data_spec
  )
  twml.util.write_file(
    continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
  )


if __name__ == "__main__":
  main()
  logging.info("Done.")
BIN pushservice/src/main/python/models/heavy_ranking/eval.docx (Normal file): Binary file not shown.
@@ -1,59 +0,0 @@
"""
Evaluation job for the heavy ranker of the push notification service.
"""

from datetime import datetime

import twml

from ..libs.metric_fn_utils import get_metric_fn
from ..libs.model_utils import read_config
from .features import get_feature_config
from .model_pools import ALL_MODELS
from .params import load_graph_params
from .run_args import get_eval_arg_parser

from tensorflow.compat.v1 import logging


def main():
  args, _ = get_eval_arg_parser().parse_known_args()
  logging.info(f"Parsed args: {args}")

  params = load_graph_params(args)
  logging.info(f"Loaded graph params: {params}")

  logging.info(f"Get Feature Config: {args.feature_list}")
  feature_list = read_config(args.feature_list).items()
  feature_config = get_feature_config(
    data_spec_path=args.data_spec,
    params=params,
    feature_list_provided=feature_list,
  )

  logging.info("Build DataRecordTrainer.")
  metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)

  trainer = twml.trainers.DataRecordTrainer(
    name="magic_recs",
    params=args,
    build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
    save_dir=args.save_dir,
    run_config=None,
    feature_config=feature_config,
    metric_fn=metric_fn,
  )

  logging.info("Run the evaluation.")
  start = datetime.now()
  trainer._estimator.evaluate(
    input_fn=trainer.get_eval_input_fn(repeat=False, shuffle=False),
    steps=None if (args.eval_steps is not None and args.eval_steps < 0) else args.eval_steps,
    checkpoint_path=args.eval_checkpoint,
  )
  logging.info(f"Evaluating time: {datetime.now() - start}.")


if __name__ == "__main__":
  main()
  logging.info("Job done.")
BIN pushservice/src/main/python/models/heavy_ranking/features.docx (Normal file): Binary file not shown.
@ -1,138 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from twitter.deepbird.projects.magic_recs.libs.model_utils import filter_nans_and_infs
|
|
||||||
import twml
|
|
||||||
from twml.layers import full_sparse, sparse_max_norm
|
|
||||||
|
|
||||||
from .params import FeaturesParams, GraphParams, SparseFeaturesParams
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
from tensorflow import Tensor
|
|
||||||
import tensorflow.compat.v1 as tf1
|
|
||||||
|
|
||||||
|
|
||||||
FEAT_CONFIG_DEFAULT_VAL = 0
|
|
||||||
DEFAULT_FEATURE_LIST_PATH = "./feature_list_default.yaml"
|
|
||||||
FEATURE_LIST_DEFAULT_PATH = os.path.join(
|
|
||||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_PATH
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_feature_config(data_spec_path=None, feature_list_provided=[], params: GraphParams = None):
|
|
||||||
|
|
||||||
a_string_feat_list = [feat for feat, feat_type in feature_list_provided if feat_type != "S"]
|
|
||||||
|
|
||||||
builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
|
||||||
data_spec_path=data_spec_path, debug=False
|
|
||||||
)
|
|
||||||
|
|
||||||
builder = builder.extract_feature_group(
|
|
||||||
feature_regexes=a_string_feat_list,
|
|
||||||
group_name="continuous_features",
|
|
||||||
default_value=FEAT_CONFIG_DEFAULT_VAL,
|
|
||||||
type_filter=["CONTINUOUS"],
|
|
||||||
)
|
|
||||||
|
|
||||||
builder = builder.extract_feature_group(
|
|
||||||
feature_regexes=a_string_feat_list,
|
|
||||||
group_name="binary_features",
|
|
||||||
type_filter=["BINARY"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.model.features.sparse_features:
|
|
||||||
builder = builder.extract_features_as_hashed_sparse(
|
|
||||||
feature_regexes=a_string_feat_list,
|
|
||||||
hash_space_size_bits=params.model.features.sparse_features.bits,
|
|
||||||
type_filter=["DISCRETE", "STRING", "SPARSE_BINARY"],
|
|
||||||
output_tensor_name="sparse_not_continuous",
|
|
||||||
)
|
|
||||||
|
|
||||||
builder = builder.extract_features_as_hashed_sparse(
|
|
||||||
feature_regexes=[feat for feat, feat_type in feature_list_provided if feat_type == "S"],
|
|
||||||
hash_space_size_bits=params.model.features.sparse_features.bits,
|
|
||||||
type_filter=["SPARSE_CONTINUOUS"],
|
|
||||||
output_tensor_name="sparse_continuous",
|
|
||||||
)
|
|
||||||
|
|
||||||
builder = builder.add_labels([task.label for task in params.tasks] + ["label.ntabDislike"])
|
|
||||||
|
|
||||||
if params.weight:
|
|
||||||
builder = builder.define_weight(params.weight)
|
|
||||||
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
def dense_features(features: Dict[str, Tensor], training: bool) -> Tensor:
|
|
||||||
"""
|
|
||||||
Performs feature transformations on the raw dense features (continuous and binary).
|
|
||||||
"""
|
|
||||||
with tf.name_scope("dense_features"):
|
|
||||||
x = filter_nans_and_infs(features["continuous_features"])
|
|
||||||
|
|
||||||
x = tf.sign(x) * tf.math.log(tf.abs(x) + 1)
|
|
||||||
x = tf1.layers.batch_normalization(
|
|
||||||
x, momentum=0.9999, training=training, renorm=training, axis=1
|
|
||||||
)
|
|
||||||
x = tf.clip_by_value(x, -5, 5)
|
|
||||||
|
|
||||||
transformed_continous_features = tf.where(tf.math.is_nan(x), tf.zeros_like(x), x)
|
|
||||||
|
|
||||||
binary_features = filter_nans_and_infs(features["binary_features"])
|
|
||||||
binary_features = tf.dtypes.cast(binary_features, tf.float32)
|
|
||||||
|
|
||||||
output = tf.concat([transformed_continuous_features, binary_features], axis=1)
|
|
||||||
|
|
||||||
return output
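# Hedged numeric illustration (not part of the original file) of the signed-log transform
# above: a raw continuous value of 100.0 maps to sign(100) * log(101) ≈ 4.62 and -100.0 maps
# to ≈ -4.62, before batch normalization, clipping to [-5, 5], and NaN masking.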
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_features(
|
|
||||||
features: Dict[str, Tensor], training: bool, params: SparseFeaturesParams
|
|
||||||
) -> Tensor:
|
|
||||||
"""
|
|
||||||
Performs feature transformations on the raw sparse features.
|
|
||||||
"""
|
|
||||||
|
|
||||||
with tf.name_scope("sparse_features"):
|
|
||||||
with tf.name_scope("sparse_not_continuous"):
|
|
||||||
sparse_not_continuous = full_sparse(
|
|
||||||
inputs=features["sparse_not_continuous"],
|
|
||||||
output_size=params.embedding_size,
|
|
||||||
use_sparse_grads=training,
|
|
||||||
use_binary_values=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
with tf.name_scope("sparse_continuous"):
|
|
||||||
shape_enforced_input = twml.util.limit_sparse_tensor_size(
|
|
||||||
sparse_tf=features["sparse_continuous"], input_size_bits=params.bits, mask_indices=False
|
|
||||||
)
|
|
||||||
|
|
||||||
normalized_continuous_sparse = sparse_max_norm(
|
|
||||||
inputs=shape_enforced_input, is_training=training
|
|
||||||
)
|
|
||||||
|
|
||||||
sparse_continuous = full_sparse(
|
|
||||||
inputs=normalized_continuous_sparse,
|
|
||||||
output_size=params.embedding_size,
|
|
||||||
use_sparse_grads=training,
|
|
||||||
use_binary_values=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = tf.concat([sparse_not_continuous, sparse_continuous], axis=1)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def get_features(features: Dict[str, Tensor], training: bool, params: FeaturesParams) -> Tensor:
|
|
||||||
"""
|
|
||||||
Performs feature transformations on the dense and sparse features and combines the resulting
|
|
||||||
tensors into a single one.
|
|
||||||
"""
|
|
||||||
with tf.name_scope("features"):
|
|
||||||
x = dense_features(features, training)
|
|
||||||
tf1.logging.info(f"Dense features: {x.shape}")
|
|
||||||
|
|
||||||
if params.sparse_features:
|
|
||||||
x = tf.concat([x, sparse_features(features, training, params.sparse_features)], axis=1)
|
|
||||||
|
|
||||||
return x
|
|
BIN
pushservice/src/main/python/models/heavy_ranking/graph.docx
Normal file
Binary file not shown.
@ -1,129 +0,0 @@
|
|||||||
"""
|
|
||||||
Graph class defining methods to obtain key quantities such as:
|
|
||||||
* the logits
|
|
||||||
* the probabilities
|
|
||||||
* the final score
|
|
||||||
* the loss function
|
|
||||||
* the training operator
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from twitter.deepbird.hparam import HParams
|
|
||||||
import twml
|
|
||||||
|
|
||||||
from ..libs.model_utils import generate_disliked_mask
|
|
||||||
from .params import GraphParams
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
import tensorflow.compat.v1 as tf1
|
|
||||||
|
|
||||||
|
|
||||||
class Graph(ABC):
|
|
||||||
def __init__(self, params: GraphParams):
|
|
||||||
self.params = params
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_logits(self, features: Dict[str, tf.Tensor], training: bool) -> tf.Tensor:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_probabilities(self, logits: tf.Tensor) -> tf.Tensor:
|
|
||||||
return tf.math.cumprod(tf.nn.sigmoid(logits), axis=1, name="probabilities")
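# Hedged illustration (not part of the original file): with the default two tasks the
# cumulative product turns per-task sigmoids into cascading probabilities, e.g.
# logits [[1.0, 0.0]] -> sigmoid [[0.731, 0.5]] -> cumprod [[0.731, 0.366]], where the
# second column approximates P(OONC) * P(Engagement | OONC).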
|
|
||||||
|
|
||||||
def get_task_weights(self, labels: tf.Tensor) -> tf.Tensor:
|
|
||||||
oonc_label = tf.reshape(labels[:, 0], shape=(-1, 1))
|
|
||||||
task_weights = tf.concat([tf.ones_like(oonc_label), oonc_label], axis=1)
|
|
||||||
|
|
||||||
n_labels = len(self.params.tasks)
|
|
||||||
task_weights = tf.reshape(task_weights[:, 0:n_labels], shape=(-1, n_labels))
|
|
||||||
|
|
||||||
return task_weights
|
|
||||||
|
|
||||||
def get_loss(self, labels: tf.Tensor, logits: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
||||||
with tf.name_scope("weights"):
|
|
||||||
disliked_mask = generate_disliked_mask(labels)
|
|
||||||
|
|
||||||
labels = tf.reshape(labels[:, 0:2], shape=[-1, 2])
|
|
||||||
|
|
||||||
labels = labels * tf.cast(tf.logical_not(disliked_mask), dtype=labels.dtype)
|
|
||||||
|
|
||||||
with tf.name_scope("task_weight"):
|
|
||||||
task_weights = self.get_task_weights(labels)
|
|
||||||
|
|
||||||
with tf.name_scope("batch_size"):
|
|
||||||
batch_size = tf.cast(tf.shape(labels)[0], dtype=tf.float32, name="batch_size")
|
|
||||||
|
|
||||||
weights = task_weights / batch_size
|
|
||||||
|
|
||||||
with tf.name_scope("loss"):
|
|
||||||
loss = tf.reduce_sum(
|
|
||||||
tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) * weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def get_score(self, probabilities: tf.Tensor) -> tf.Tensor:
|
|
||||||
with tf.name_scope("score_weight"):
|
|
||||||
score_weights = tf.constant([task.score_weight for task in self.params.tasks])
|
|
||||||
score_weights = score_weights / tf.reduce_sum(score_weights, axis=0)
|
|
||||||
|
|
||||||
with tf.name_scope("score"):
|
|
||||||
score = tf.reshape(tf.reduce_sum(probabilities * score_weights, axis=1), shape=[-1, 1])
|
|
||||||
|
|
||||||
return score
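# Hedged illustration (not part of the original file): with the default task score weights
# (0.9 for OONC and 0.1 for Engagement, already summing to 1), probabilities [[0.731, 0.366]]
# yield a score of 0.9 * 0.731 + 0.1 * 0.366 ≈ 0.694.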
|
|
||||||
|
|
||||||
def get_train_op(self, loss: tf.Tensor, twml_params) -> Any:
|
|
||||||
with tf.name_scope("optimizer"):
|
|
||||||
learning_rate = twml_params.learning_rate
|
|
||||||
optimizer = tf1.train.GradientDescentOptimizer(learning_rate=learning_rate)
|
|
||||||
|
|
||||||
update_ops = set(tf1.get_collection(tf1.GraphKeys.UPDATE_OPS))
|
|
||||||
with tf.control_dependencies(update_ops):
|
|
||||||
train_op = twml.optimizers.optimize_loss(
|
|
||||||
loss=loss,
|
|
||||||
variables=tf1.trainable_variables(),
|
|
||||||
global_step=tf1.train.get_global_step(),
|
|
||||||
optimizer=optimizer,
|
|
||||||
learning_rate=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return train_op
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
features: Dict[str, tf.Tensor],
|
|
||||||
labels: tf.Tensor,
|
|
||||||
mode: tf.estimator.ModeKeys,
|
|
||||||
params: HParams,
|
|
||||||
config=None,
|
|
||||||
) -> Dict[str, tf.Tensor]:
|
|
||||||
training = mode == tf.estimator.ModeKeys.TRAIN
|
|
||||||
logits = self.get_logits(features=features, training=training)
|
|
||||||
probabilities = self.get_probabilities(logits=logits)
|
|
||||||
score = None
|
|
||||||
loss = None
|
|
||||||
train_op = None
|
|
||||||
|
|
||||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
|
||||||
score = self.get_score(probabilities=probabilities)
|
|
||||||
output = {"loss": loss, "train_op": train_op, "prediction": score}
|
|
||||||
|
|
||||||
elif mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
|
|
||||||
loss = self.get_loss(labels=labels, logits=logits)
|
|
||||||
|
|
||||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
|
||||||
train_op = self.get_train_op(loss=loss, twml_params=params)
|
|
||||||
|
|
||||||
output = {"loss": loss, "train_op": train_op, "output": probabilities}
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"""
|
|
||||||
Invalid mode. Possible values are: {tf.estimator.ModeKeys.PREDICT}, {tf.estimator.ModeKeys.TRAIN}, and {tf.estimator.ModeKeys.EVAL}. Passed: {mode}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
@ -1,42 +0,0 @@
|
|||||||
python3_library(
|
|
||||||
name = "params_lib",
|
|
||||||
sources = [
|
|
||||||
"params.py",
|
|
||||||
],
|
|
||||||
tags = [
|
|
||||||
"bazel-compatible",
|
|
||||||
"no-mypy",
|
|
||||||
],
|
|
||||||
dependencies = [
|
|
||||||
"3rdparty/python/pydantic:default",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python3_library(
|
|
||||||
name = "layers_lib",
|
|
||||||
sources = [
|
|
||||||
"layers.py",
|
|
||||||
],
|
|
||||||
tags = [
|
|
||||||
"bazel-compatible",
|
|
||||||
"no-mypy",
|
|
||||||
],
|
|
||||||
dependencies = [
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python3_library(
|
|
||||||
name = "model_lib",
|
|
||||||
sources = [
|
|
||||||
"model.py",
|
|
||||||
],
|
|
||||||
tags = [
|
|
||||||
"bazel-compatible",
|
|
||||||
"no-mypy",
|
|
||||||
],
|
|
||||||
dependencies = [
|
|
||||||
":layers_lib",
|
|
||||||
":params_lib",
|
|
||||||
"3rdparty/python/absl-py:default",
|
|
||||||
],
|
|
||||||
)
|
|
BIN
pushservice/src/main/python/models/heavy_ranking/lib/BUILD.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/heavy_ranking/lib/layers.docx
Normal file
Binary file not shown.
@ -1,128 +0,0 @@
|
|||||||
"""
|
|
||||||
Different types of convolution layers to be used in ClemNet.
|
|
||||||
"""
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
|
|
||||||
class KerasConv1D(tf.keras.layers.Layer):
|
|
||||||
"""
|
|
||||||
Basic Conv1D layer in a wrapper to be compatible with ClemNet.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
kernel_size: int,
|
|
||||||
filters: int,
|
|
||||||
strides: int,
|
|
||||||
padding: str,
|
|
||||||
use_bias: bool = True,
|
|
||||||
kernel_initializer: str = "glorot_uniform",
|
|
||||||
bias_initializer: str = "zeros",
|
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
super(KerasConv1D, self).__init__(**kwargs)
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.filters = filters
|
|
||||||
self.use_bias = use_bias
|
|
||||||
self.kernel_initializer = kernel_initializer
|
|
||||||
self.bias_initializer = bias_initializer
|
|
||||||
self.strides = strides
|
|
||||||
self.padding = padding
|
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape) -> None:
|
|
||||||
assert (
|
|
||||||
len(input_shape) == 3
|
|
||||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
|
||||||
|
|
||||||
self.features = input_shape[1]
|
|
||||||
|
|
||||||
self.w = tf.keras.layers.Conv1D(
|
|
||||||
kernel_size=self.kernel_size,
|
|
||||||
filters=self.filters,
|
|
||||||
strides=self.strides,
|
|
||||||
padding=self.padding,
|
|
||||||
use_bias=self.use_bias,
|
|
||||||
kernel_initializer=self.kernel_initializer,
|
|
||||||
bias_initializer=self.bias_initializer,
|
|
||||||
name=self.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
||||||
return self.w(inputs)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelWiseDense(tf.keras.layers.Layer):
|
|
||||||
"""
|
|
||||||
Dense layer applied to each channel separately. This is more memory- and compute-efficient
|
|
||||||
than flattening the channels and applying a single dense layer over them, which is the
|
|
||||||
default behavior in tf1.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
output_size: int,
|
|
||||||
use_bias: bool,
|
|
||||||
kernel_initializer: str = "uniform_glorot",
|
|
||||||
bias_initializer: str = "zeros",
|
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
super(ChannelWiseDense, self).__init__(**kwargs)
|
|
||||||
self.output_size = output_size
|
|
||||||
self.use_bias = use_bias
|
|
||||||
self.kernel_initializer = kernel_initializer
|
|
||||||
self.bias_initializer = bias_initializer
|
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape) -> None:
|
|
||||||
assert (
|
|
||||||
len(input_shape) == 3
|
|
||||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
|
||||||
|
|
||||||
input_size = input_shape[1]
|
|
||||||
channels = input_shape[2]
|
|
||||||
|
|
||||||
self.kernel = self.add_weight(
|
|
||||||
name="kernel",
|
|
||||||
shape=(channels, input_size, self.output_size),
|
|
||||||
initializer=self.kernel_initializer,
|
|
||||||
trainable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.bias = self.add_weight(
|
|
||||||
name="bias",
|
|
||||||
shape=(channels, self.output_size),
|
|
||||||
initializer=self.bias_initializer,
|
|
||||||
trainable=self.use_bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
||||||
x = inputs
|
|
||||||
|
|
||||||
transposed_x = tf.transpose(x, perm=[2, 0, 1])
|
|
||||||
transposed_residual = (
|
|
||||||
tf.transpose(tf.matmul(transposed_x, self.kernel), perm=[1, 0, 2]) + self.bias
|
|
||||||
)
|
|
||||||
output = tf.transpose(transposed_residual, perm=[0, 2, 1])
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualLayer(tf.keras.layers.Layer):
|
|
||||||
"""
|
|
||||||
Layer implementing a 3D-residual connection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape) -> None:
|
|
||||||
assert (
|
|
||||||
len(input_shape) == 3
|
|
||||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, residual: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
||||||
shortcut = tf.keras.layers.Conv1D(
|
|
||||||
filters=int(residual.shape[2]), strides=1, kernel_size=1, padding="SAME", use_bias=False
|
|
||||||
)(inputs)
|
|
||||||
|
|
||||||
output = tf.add(shortcut, residual)
|
|
||||||
|
|
||||||
return output
|
|
BIN
pushservice/src/main/python/models/heavy_ranking/lib/model.docx
Normal file
Binary file not shown.
@ -1,76 +0,0 @@
|
|||||||
"""
|
|
||||||
Module containing ClemNet.
|
|
||||||
"""
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .layers import ChannelWiseDense, KerasConv1D, ResidualLayer
|
|
||||||
from .params import BlockParams, ClemNetParams
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
import tensorflow.compat.v1 as tf1
|
|
||||||
|
|
||||||
|
|
||||||
class Block2(tf.keras.layers.Layer):
|
|
||||||
"""
|
|
||||||
Possible ClemNet block. The architecture is as follows:
|
|
||||||
Optional(DenseLayer + BN + Act)
|
|
||||||
Optional(ConvLayer + BN + Act)
|
|
||||||
Optional(Residual Layer)
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, params: BlockParams, **kwargs: Any):
|
|
||||||
super(Block2, self).__init__(**kwargs)
|
|
||||||
self.params = params
|
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape) -> None:
|
|
||||||
assert (
|
|
||||||
len(input_shape) == 3
|
|
||||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
|
|
||||||
x = inputs
|
|
||||||
if self.params.dense:
|
|
||||||
x = ChannelWiseDense(**self.params.dense.dict())(inputs=x, training=training)
|
|
||||||
x = tf1.layers.batch_normalization(x, momentum=0.9999, training=training, axis=1)
|
|
||||||
x = tf.keras.layers.Activation(self.params.activation)(x)
|
|
||||||
|
|
||||||
if self.params.conv:
|
|
||||||
x = KerasConv1D(**self.params.conv.dict())(inputs=x, training=training)
|
|
||||||
x = tf1.layers.batch_normalization(x, momentum=0.9999, training=training, axis=1)
|
|
||||||
x = tf.keras.layers.Activation(self.params.activation)(x)
|
|
||||||
|
|
||||||
if self.params.residual:
|
|
||||||
x = ResidualLayer()(inputs=inputs, residual=x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ClemNet(tf.keras.layers.Layer):
|
|
||||||
"""
|
|
||||||
A residual network stacking residual blocks composed of dense layers and convolutions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, params: ClemNetParams, **kwargs: Any):
|
|
||||||
super(ClemNet, self).__init__(**kwargs)
|
|
||||||
self.params = params
|
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape) -> None:
|
|
||||||
assert len(input_shape) in (
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
|
|
||||||
if len(inputs.shape) < 3:
|
|
||||||
inputs = tf.expand_dims(inputs, axis=-1)
|
|
||||||
|
|
||||||
x = inputs
|
|
||||||
for block_params in self.params.blocks:
|
|
||||||
x = Block2(block_params)(inputs=x, training=training)
|
|
||||||
|
|
||||||
x = tf.keras.layers.Flatten(name="flattened")(x)
|
|
||||||
if self.params.top:
|
|
||||||
x = tf.keras.layers.Dense(units=self.params.top.n_labels, name="logits")(x)
|
|
||||||
|
|
||||||
return x
|
|
BIN
pushservice/src/main/python/models/heavy_ranking/lib/params.docx
Normal file
Binary file not shown.
@ -1,49 +0,0 @@
|
|||||||
"""
|
|
||||||
Parameters used in ClemNet.
|
|
||||||
"""
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, PositiveInt
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
|
|
||||||
class ExtendedBaseModel(BaseModel):
|
|
||||||
class Config:
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
|
|
||||||
class DenseParams(ExtendedBaseModel):
|
|
||||||
name: Optional[str]
|
|
||||||
bias_initializer: str = "zeros"
|
|
||||||
kernel_initializer: str = "glorot_uniform"
|
|
||||||
output_size: PositiveInt
|
|
||||||
use_bias: bool = Field(True)
|
|
||||||
|
|
||||||
|
|
||||||
class ConvParams(ExtendedBaseModel):
|
|
||||||
name: Optional[str]
|
|
||||||
bias_initializer: str = "zeros"
|
|
||||||
filters: PositiveInt
|
|
||||||
kernel_initializer: str = "glorot_uniform"
|
|
||||||
kernel_size: PositiveInt
|
|
||||||
padding: str = "SAME"
|
|
||||||
strides: PositiveInt = 1
|
|
||||||
use_bias: bool = Field(True)
|
|
||||||
|
|
||||||
|
|
||||||
class BlockParams(ExtendedBaseModel):
|
|
||||||
activation: Optional[str]
|
|
||||||
conv: Optional[ConvParams]
|
|
||||||
dense: Optional[DenseParams]
|
|
||||||
residual: Optional[bool]
|
|
||||||
|
|
||||||
|
|
||||||
class TopLayerParams(ExtendedBaseModel):
|
|
||||||
n_labels: PositiveInt
|
|
||||||
|
|
||||||
|
|
||||||
class ClemNetParams(ExtendedBaseModel):
|
|
||||||
blocks: List[BlockParams] = []
|
|
||||||
top: Optional[TopLayerParams]
|
|
Binary file not shown.
@ -1,34 +0,0 @@
|
|||||||
"""
|
|
||||||
Candidate architectures for each task.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from .features import get_features
|
|
||||||
from .graph import Graph
|
|
||||||
from .lib.model import ClemNet
|
|
||||||
from .params import ModelTypeEnum
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
|
|
||||||
class MagicRecsClemNet(Graph):
|
|
||||||
def get_logits(self, features: Dict[str, tf.Tensor], training: bool) -> tf.Tensor:
|
|
||||||
|
|
||||||
with tf.name_scope("logits"):
|
|
||||||
inputs = get_features(features=features, training=training, params=self.params.model.features)
|
|
||||||
|
|
||||||
with tf.name_scope("OONC_logits"):
|
|
||||||
model = ClemNet(params=self.params.model.architecture)
|
|
||||||
oonc_logit = model(inputs=inputs, training=training)
|
|
||||||
|
|
||||||
with tf.name_scope("EngagementGivenOONC_logits"):
|
|
||||||
model = ClemNet(params=self.params.model.architecture)
|
|
||||||
eng_logits = model(inputs=inputs, training=training)
|
|
||||||
|
|
||||||
return tf.concat([oonc_logit, eng_logits], axis=1)
|
|
||||||
|
|
||||||
|
|
||||||
ALL_MODELS = {ModelTypeEnum.clemnet: MagicRecsClemNet}
|
|
BIN
pushservice/src/main/python/models/heavy_ranking/params.docx
Normal file
Binary file not shown.
@ -1,89 +0,0 @@
|
|||||||
import enum
|
|
||||||
import json
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from .lib.params import BlockParams, ClemNetParams, ConvParams, DenseParams, TopLayerParams
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, NonNegativeFloat
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
|
|
||||||
class ExtendedBaseModel(BaseModel):
|
|
||||||
class Config:
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
|
|
||||||
class SparseFeaturesParams(ExtendedBaseModel):
|
|
||||||
bits: int
|
|
||||||
embedding_size: int
|
|
||||||
|
|
||||||
|
|
||||||
class FeaturesParams(ExtendedBaseModel):
|
|
||||||
sparse_features: Optional[SparseFeaturesParams]
|
|
||||||
|
|
||||||
|
|
||||||
class ModelTypeEnum(str, enum.Enum):
|
|
||||||
clemnet: str = "clemnet"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelParams(ExtendedBaseModel):
|
|
||||||
name: ModelTypeEnum
|
|
||||||
features: FeaturesParams
|
|
||||||
architecture: ClemNetParams
|
|
||||||
|
|
||||||
|
|
||||||
class TaskNameEnum(str, enum.Enum):
|
|
||||||
oonc: str = "OONC"
|
|
||||||
engagement: str = "Engagement"
|
|
||||||
|
|
||||||
|
|
||||||
class Task(ExtendedBaseModel):
|
|
||||||
name: TaskNameEnum
|
|
||||||
label: str
|
|
||||||
score_weight: NonNegativeFloat
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TASKS = [
|
|
||||||
Task(name=TaskNameEnum.oonc, label="label", score_weight=0.9),
|
|
||||||
Task(name=TaskNameEnum.engagement, label="label.engagement", score_weight=0.1),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class GraphParams(ExtendedBaseModel):
|
|
||||||
tasks: List[Task] = DEFAULT_TASKS
|
|
||||||
model: ModelParams
|
|
||||||
weight: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_ARCHITECTURE_PARAMS = ClemNetParams(
|
|
||||||
blocks=[
|
|
||||||
BlockParams(
|
|
||||||
activation="relu",
|
|
||||||
conv=ConvParams(kernel_size=3, filters=5),
|
|
||||||
dense=DenseParams(output_size=output_size),
|
|
||||||
residual=False,
|
|
||||||
)
|
|
||||||
for output_size in [1024, 512, 256, 128]
|
|
||||||
],
|
|
||||||
top=TopLayerParams(n_labels=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_GRAPH_PARAMS = GraphParams(
|
|
||||||
model=ModelParams(
|
|
||||||
name=ModelTypeEnum.clemnet,
|
|
||||||
architecture=DEFAULT_ARCHITECTURE_PARAMS,
|
|
||||||
features=FeaturesParams(sparse_features=SparseFeaturesParams(bits=18, embedding_size=50)),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_graph_params(args) -> GraphParams:
|
|
||||||
params = DEFAULT_GRAPH_PARAMS
|
|
||||||
if args.param_file:
|
|
||||||
with tf.io.gfile.GFile(args.param_file, mode="r+") as file:
|
|
||||||
params = GraphParams.parse_obj(json.load(file))
|
|
||||||
|
|
||||||
return params
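# Hedged example (not part of the original file) of a JSON file that --param_file could point
# to, mirroring the pydantic schema above; all values are illustrative only.
#
# {
#   "tasks": [
#     {"name": "OONC", "label": "label", "score_weight": 0.9},
#     {"name": "Engagement", "label": "label.engagement", "score_weight": 0.1}
#   ],
#   "model": {
#     "name": "clemnet",
#     "features": {"sparse_features": {"bits": 18, "embedding_size": 50}},
#     "architecture": {
#       "blocks": [
#         {"activation": "relu", "conv": {"kernel_size": 3, "filters": 5},
#          "dense": {"output_size": 256}, "residual": false}
#       ],
#       "top": {"n_labels": 1}
#     }
#   }
# }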
|
|
BIN
pushservice/src/main/python/models/heavy_ranking/run_args.docx
Normal file
Binary file not shown.
@ -1,59 +0,0 @@
|
|||||||
from twml.trainers import DataRecordTrainer
|
|
||||||
|
|
||||||
from .features import FEATURE_LIST_DEFAULT_PATH
|
|
||||||
|
|
||||||
|
|
||||||
def get_training_arg_parser():
|
|
||||||
parser = DataRecordTrainer.add_parser_arguments()
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--feature_list",
|
|
||||||
default=FEATURE_LIST_DEFAULT_PATH,
|
|
||||||
type=str,
|
|
||||||
help="Which features to use for training",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--param_file",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="Path to JSON file containing the graph parameters. If None, model will load default parameters.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--directly_export_best",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="whether to directly_export best_checkpoint",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--warm_start_base_dir",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="latest ckpt in this folder will be used to ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_type",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="Which type of model to train.",
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def get_eval_arg_parser():
|
|
||||||
parser = get_training_arg_parser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--eval_checkpoint",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="Which checkpoint to use for evaluation",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
Binary file not shown.
@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Module for modifying the checkpoints of the magic recs CNN model with addition, deletion, and reordering
|
|
||||||
of continuous and binary features.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from twitter.deepbird.projects.magic_recs.libs.get_feat_config import FEATURE_LIST_DEFAULT_PATH
|
|
||||||
from twitter.deepbird.projects.magic_recs.libs.warm_start_utils_v11 import (
|
|
||||||
get_feature_list_for_heavy_ranking,
|
|
||||||
mkdirp,
|
|
||||||
rename_dir,
|
|
||||||
rmdir,
|
|
||||||
warm_start_checkpoint,
|
|
||||||
)
|
|
||||||
import twml
|
|
||||||
from twml.trainers import DataRecordTrainer
|
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
from tensorflow.compat.v1 import logging
|
|
||||||
|
|
||||||
|
|
||||||
def get_arg_parser():
|
|
||||||
parser = DataRecordTrainer.add_parser_arguments()
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_type",
|
|
||||||
default="deepnorm_gbdt_inputdrop2_rescale",
|
|
||||||
type=str,
|
|
||||||
help="specify the model type to use.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_trainer_name",
|
|
||||||
default="None",
|
|
||||||
type=str,
|
|
||||||
help="deprecated, added here just for api compatibility.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--warm_start_base_dir",
|
|
||||||
default="none",
|
|
||||||
type=str,
|
|
||||||
help="latest ckpt in this folder will be used.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--output_checkpoint_dir",
|
|
||||||
default="none",
|
|
||||||
type=str,
|
|
||||||
help="Output folder for warm started ckpt. If none, it will move warm_start_base_dir to backup, and overwrite it",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--feature_list",
|
|
||||||
default="none",
|
|
||||||
type=str,
|
|
||||||
help="Which features to use for training",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--old_feature_list",
|
|
||||||
default="none",
|
|
||||||
type=str,
|
|
||||||
help="Which features to use for training",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def get_params(args=None):
|
|
||||||
parser = get_arg_parser()
|
|
||||||
if args is None:
|
|
||||||
return parser.parse_args()
|
|
||||||
else:
|
|
||||||
return parser.parse_args(args)
|
|
||||||
|
|
||||||
|
|
||||||
def _main():
|
|
||||||
opt = get_params()
|
|
||||||
logging.info("parse is: ")
|
|
||||||
logging.info(opt)
|
|
||||||
|
|
||||||
if opt.feature_list == "none":
|
|
||||||
feature_list_path = FEATURE_LIST_DEFAULT_PATH
|
|
||||||
else:
|
|
||||||
feature_list_path = opt.feature_list
|
|
||||||
|
|
||||||
if opt.warm_start_base_dir != "none" and tf.io.gfile.exists(opt.warm_start_base_dir):
|
|
||||||
if opt.output_checkpoint_dir == "none" or opt.output_checkpoint_dir == opt.warm_start_base_dir:
|
|
||||||
_warm_start_base_dir = os.path.normpath(opt.warm_start_base_dir) + "_backup_warm_start"
|
|
||||||
_output_folder_dir = opt.warm_start_base_dir
|
|
||||||
|
|
||||||
rename_dir(opt.warm_start_base_dir, _warm_start_base_dir)
|
|
||||||
tf.logging.info(f"moved {opt.warm_start_base_dir} to {_warm_start_base_dir}")
|
|
||||||
else:
|
|
||||||
_warm_start_base_dir = opt.warm_start_base_dir
|
|
||||||
_output_folder_dir = opt.output_checkpoint_dir
|
|
||||||
|
|
||||||
continuous_binary_feat_list_save_path = os.path.join(
|
|
||||||
_warm_start_base_dir, "continuous_binary_feat_list.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
if opt.old_feature_list != "none":
|
|
||||||
tf.logging.info("getting old continuous_binary_feat_list")
|
|
||||||
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
|
|
||||||
opt.old_feature_list, opt.data_spec
|
|
||||||
)
|
|
||||||
rmdir(continuous_binary_feat_list_save_path)
|
|
||||||
twml.util.write_file(
|
|
||||||
continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
|
|
||||||
)
|
|
||||||
tf.logging.info(f"Finish writting files to {continuous_binary_feat_list_save_path}")
|
|
||||||
|
|
||||||
warm_start_folder = os.path.join(_warm_start_base_dir, "best_checkpoint")
|
|
||||||
if not tf.io.gfile.exists(warm_start_folder):
|
|
||||||
warm_start_folder = _warm_start_base_dir
|
|
||||||
|
|
||||||
rmdir(_output_folder_dir)
|
|
||||||
mkdirp(_output_folder_dir)
|
|
||||||
|
|
||||||
new_ckpt = warm_start_checkpoint(
|
|
||||||
warm_start_folder,
|
|
||||||
continuous_binary_feat_list_save_path,
|
|
||||||
feature_list_path,
|
|
||||||
opt.data_spec,
|
|
||||||
_output_folder_dir,
|
|
||||||
opt.model_type,
|
|
||||||
)
|
|
||||||
logging.info(f"Created new ckpt {new_ckpt} from {warm_start_folder}")
|
|
||||||
|
|
||||||
tf.logging.info("getting new continuous_binary_feat_list")
|
|
||||||
new_continuous_binary_feat_list_save_path = os.path.join(
|
|
||||||
_output_folder_dir, "continuous_binary_feat_list.json"
|
|
||||||
)
|
|
||||||
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
|
|
||||||
feature_list_path, opt.data_spec
|
|
||||||
)
|
|
||||||
rmdir(new_continuous_binary_feat_list_save_path)
|
|
||||||
twml.util.write_file(
|
|
||||||
new_continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
|
|
||||||
)
|
|
||||||
tf.logging.info(f"Finish writting files to {new_continuous_binary_feat_list_save_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
_main()
|
|
@ -1,16 +0,0 @@
|
|||||||
python3_library(
|
|
||||||
name = "libs",
|
|
||||||
sources = ["*.py"],
|
|
||||||
tags = [
|
|
||||||
"bazel-compatible",
|
|
||||||
"no-mypy",
|
|
||||||
],
|
|
||||||
dependencies = [
|
|
||||||
"cortex/recsys/src/python/twitter/cortex/recsys/utils",
|
|
||||||
"magicpony/common/file_access/src/python/twitter/magicpony/common/file_access",
|
|
||||||
"src/python/twitter/cortex/ml/embeddings/deepbird",
|
|
||||||
"src/python/twitter/cortex/ml/embeddings/deepbird/grouped_metrics",
|
|
||||||
"src/python/twitter/deepbird/util/data",
|
|
||||||
"twml:twml-nodeps",
|
|
||||||
],
|
|
||||||
)
|
|
BIN
pushservice/src/main/python/models/libs/BUILD.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/libs/__init__.docx
Normal file
Binary file not shown.
Binary file not shown.
@ -1,56 +0,0 @@
|
|||||||
# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument
|
|
||||||
"""
|
|
||||||
Implements the Full Sparse layer, allowing use_binary_values to be specified in call() to
|
|
||||||
override the default behavior.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from twml.layers import FullSparse as defaultFullSparse
|
|
||||||
from twml.layers.full_sparse import sparse_dense_matmul
|
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
|
|
||||||
|
|
||||||
class FullSparse(defaultFullSparse):
|
|
||||||
def call(self, inputs, use_binary_values=None, **kwargs): # pylint: disable=unused-argument
|
|
||||||
"""The logic of the layer lives here.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
inputs:
|
|
||||||
A SparseTensor or a list of SparseTensors.
|
|
||||||
If `inputs` is a list, all tensors must have same `dense_shape`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- If `inputs` is `SparseTensor`, then returns `bias + inputs * dense_b`.
|
|
||||||
- If `inputs` is a `list[SparseTensor]`, then returns
|
|
||||||
`bias + add_n([sp_a * dense_b for sp_a in inputs])`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if use_binary_values is not None:
|
|
||||||
default_use_binary_values = use_binary_values
|
|
||||||
else:
|
|
||||||
default_use_binary_values = self.use_binary_values
|
|
||||||
|
|
||||||
if isinstance(default_use_binary_values, (list, tuple)):
|
|
||||||
raise ValueError(
|
|
||||||
"use_binary_values can not be %s when inputs is %s"
|
|
||||||
% (type(default_use_binary_values), type(inputs))
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = sparse_dense_matmul(
|
|
||||||
inputs,
|
|
||||||
self.weight,
|
|
||||||
self.use_sparse_grads,
|
|
||||||
default_use_binary_values,
|
|
||||||
name="sparse_mm",
|
|
||||||
partition_axis=self.partition_axis,
|
|
||||||
num_partitions=self.num_partitions,
|
|
||||||
compress_ids=self._use_compression,
|
|
||||||
cast_indices_dtype=self._cast_indices_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.bias is not None:
|
|
||||||
outputs = tf.nn.bias_add(outputs, self.bias)
|
|
||||||
|
|
||||||
if self.activation is not None:
|
|
||||||
return self.activation(outputs) # pylint: disable=not-callable
|
|
||||||
return outputs
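# Hedged usage sketch (not part of the original file): the per-call override lets a single
# layer instance treat its inputs as binary for some calls and keep raw values for others.
# The constructor arguments follow twml.layers.FullSparse and are illustrative only.
#
#   layer = FullSparse(output_size=50, use_binary_values=False)
#   embedded = layer(sparse_inputs)                                  # default behavior
#   embedded_binary = layer(sparse_inputs, use_binary_values=True)   # per-call override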
|
|
BIN
pushservice/src/main/python/models/libs/get_feat_config.docx
Normal file
Binary file not shown.
@ -1,176 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
from twitter.deepbird.projects.magic_recs.libs.metric_fn_utils import USER_AGE_FEATURE_NAME
|
|
||||||
from twitter.deepbird.projects.magic_recs.libs.model_utils import read_config
|
|
||||||
from twml.contrib import feature_config as contrib_feature_config
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
FEAT_CONFIG_DEFAULT_VAL = -1.23456789
|
|
||||||
|
|
||||||
DEFAULT_INPUT_SIZE_BITS = 18
|
|
||||||
|
|
||||||
DEFAULT_FEATURE_LIST_PATH = "./feature_list_default.yaml"
|
|
||||||
FEATURE_LIST_DEFAULT_PATH = os.path.join(
|
|
||||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_PATH
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_FEATURE_LIST_LIGHT_RANKING_PATH = "./feature_list_light_ranking.yaml"
|
|
||||||
FEATURE_LIST_DEFAULT_LIGHT_RANKING_PATH = os.path.join(
|
|
||||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_LIGHT_RANKING_PATH
|
|
||||||
)
|
|
||||||
|
|
||||||
FEATURE_LIST_DEFAULT = read_config(FEATURE_LIST_DEFAULT_PATH).items()
|
|
||||||
FEATURE_LIST_LIGHT_RANKING_DEFAULT = read_config(FEATURE_LIST_DEFAULT_LIGHT_RANKING_PATH).items()
|
|
||||||
|
|
||||||
|
|
||||||
LABELS = ["label"]
|
|
||||||
LABELS_MTL = {"OONC": ["label"], "OONC_Engagement": ["label", "label.engagement"]}
|
|
||||||
LABELS_LR = {
|
|
||||||
"Sent": ["label.sent"],
|
|
||||||
"HeavyRankPosition": ["meta.ranking.is_top3"],
|
|
||||||
"HeavyRankProbability": ["meta.ranking.weighted_oonc_model_score"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_new_feature_config_base(
|
|
||||||
data_spec_path,
|
|
||||||
labels,
|
|
||||||
add_sparse_continous=True,
|
|
||||||
add_gbdt=True,
|
|
||||||
add_user_id=False,
|
|
||||||
add_timestamp=False,
|
|
||||||
add_user_age=False,
|
|
||||||
feature_list_provided=[],
|
|
||||||
opt=None,
|
|
||||||
run_light_ranking_group_metrics_in_bq=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Getter of the feature config based on specification.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_spec_path: A string indicating the path of the data_spec.json file, which could be
|
|
||||||
either a local path or a hdfs path.
|
|
||||||
labels: A list of strings indicating the name of the label in the data spec.
|
|
||||||
add_sparse_continous: A bool indicating if sparse_continuous feature needs to be included.
|
|
||||||
add_gbdt: A bool indicating if gbdt feature needs to be included.
|
|
||||||
add_user_id: A bool indicating if user_id feature needs to be included.
|
|
||||||
add_timestamp: A bool indicating if timestamp feature needs to be included. This will be useful
|
|
||||||
for sequential models and meta learning models.
|
|
||||||
add_user_age: A bool indicating if the user age feature needs to be included.
|
|
||||||
feature_list_provided: A list of features that need to be included. If not specified, will use
|
|
||||||
FEATURE_LIST_DEFAULT by default.
|
|
||||||
opt: A namespace of arguments indicating the hyperparameters.
|
|
||||||
run_light_ranking_group_metrics_in_bq: A bool indicating if heavy ranker score info needs to be included to compute group metrics in BigQuery.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A twml feature config object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input_size_bits = DEFAULT_INPUT_SIZE_BITS if opt is None else opt.input_size_bits
|
|
||||||
|
|
||||||
feature_list = feature_list_provided if feature_list_provided != [] else FEATURE_LIST_DEFAULT
|
|
||||||
a_string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
|
||||||
|
|
||||||
builder = contrib_feature_config.FeatureConfigBuilder(data_spec_path=data_spec_path)
|
|
||||||
|
|
||||||
builder = builder.extract_feature_group(
|
|
||||||
feature_regexes=a_string_feat_list,
|
|
||||||
group_name="continuous",
|
|
||||||
default_value=FEAT_CONFIG_DEFAULT_VAL,
|
|
||||||
type_filter=["CONTINUOUS"],
|
|
||||||
)
|
|
||||||
|
|
||||||
builder = builder.extract_features_as_hashed_sparse(
|
|
||||||
feature_regexes=a_string_feat_list,
|
|
||||||
output_tensor_name="sparse_no_continuous",
|
|
||||||
hash_space_size_bits=input_size_bits,
|
|
||||||
type_filter=["BINARY", "DISCRETE", "STRING", "SPARSE_BINARY"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if add_gbdt:
|
|
||||||
builder = builder.extract_features_as_hashed_sparse(
|
|
||||||
feature_regexes=["ads\..*"],
|
|
||||||
output_tensor_name="gbdt_sparse",
|
|
||||||
hash_space_size_bits=input_size_bits,
|
|
||||||
)
|
|
||||||
|
|
||||||
if add_sparse_continous:
|
|
||||||
s_string_feat_list = [f[0] for f in feature_list if f[1] == "S"]
|
|
||||||
|
|
||||||
builder = builder.extract_features_as_hashed_sparse(
|
|
||||||
feature_regexes=s_string_feat_list,
|
|
||||||
output_tensor_name="sparse_continuous",
|
|
||||||
hash_space_size_bits=input_size_bits,
|
|
||||||
type_filter=["SPARSE_CONTINUOUS"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if add_user_id:
|
|
||||||
builder = builder.extract_feature("meta.user_id")
|
|
||||||
if add_timestamp:
|
|
||||||
builder = builder.extract_feature("meta.timestamp")
|
|
||||||
if add_user_age:
|
|
||||||
builder = builder.extract_feature(USER_AGE_FEATURE_NAME)
|
|
||||||
|
|
||||||
if run_light_ranking_group_metrics_in_bq:
|
|
||||||
builder = builder.extract_feature("meta.trace_id")
|
|
||||||
builder = builder.extract_feature("meta.ranking.weighted_oonc_model_score")
|
|
||||||
|
|
||||||
builder = builder.add_labels(labels).define_weight("meta.weight")
|
|
||||||
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
def get_feature_config_with_sparse_continuous(
|
|
||||||
data_spec_path,
|
|
||||||
feature_list_provided=[],
|
|
||||||
opt=None,
|
|
||||||
add_user_id=False,
|
|
||||||
add_timestamp=False,
|
|
||||||
add_user_age=False,
|
|
||||||
):
|
|
||||||
task_name = opt.task_name if getattr(opt, "task_name", None) is not None else "OONC"
|
|
||||||
if task_name not in LABELS_MTL:
|
|
||||||
raise ValueError("Invalid Task Name !")
|
|
||||||
|
|
||||||
return _get_new_feature_config_base(
|
|
||||||
data_spec_path=data_spec_path,
|
|
||||||
labels=LABELS_MTL[task_name],
|
|
||||||
add_sparse_continous=True,
|
|
||||||
add_user_id=add_user_id,
|
|
||||||
add_timestamp=add_timestamp,
|
|
||||||
add_user_age=add_user_age,
|
|
||||||
feature_list_provided=feature_list_provided,
|
|
||||||
opt=opt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_feature_config_light_ranking(
|
|
||||||
data_spec_path,
|
|
||||||
feature_list_provided=[],
|
|
||||||
opt=None,
|
|
||||||
add_user_id=True,
|
|
||||||
add_timestamp=False,
|
|
||||||
add_user_age=False,
|
|
||||||
add_gbdt=False,
|
|
||||||
run_light_ranking_group_metrics_in_bq=False,
|
|
||||||
):
|
|
||||||
task_name = opt.task_name if getattr(opt, "task_name", None) is not None else "HeavyRankPosition"
|
|
||||||
if task_name not in LABELS_LR:
|
|
||||||
raise ValueError("Invalid Task Name !")
|
|
||||||
if not feature_list_provided:
|
|
||||||
feature_list_provided = FEATURE_LIST_LIGHT_RANKING_DEFAULT
|
|
||||||
|
|
||||||
return _get_new_feature_config_base(
|
|
||||||
data_spec_path=data_spec_path,
|
|
||||||
labels=LABELS_LR[task_name],
|
|
||||||
add_sparse_continous=False,
|
|
||||||
add_gbdt=add_gbdt,
|
|
||||||
add_user_id=add_user_id,
|
|
||||||
add_timestamp=add_timestamp,
|
|
||||||
add_user_age=add_user_age,
|
|
||||||
feature_list_provided=feature_list_provided,
|
|
||||||
opt=opt,
|
|
||||||
run_light_ranking_group_metrics_in_bq=run_light_ranking_group_metrics_in_bq,
|
|
||||||
)
|
|
BIN
pushservice/src/main/python/models/libs/graph_utils.docx
Normal file
Binary file not shown.
@ -1,42 +0,0 @@
|
|||||||
"""
|
|
||||||
Utilities that aid in building the magic recs graph.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
|
|
||||||
|
|
||||||
def get_trainable_variables(all_trainable_variables, trainable_regexes):
|
|
||||||
"""Returns a subset of trainable variables for training.
|
|
||||||
|
|
||||||
Given a collection of trainable variables, this will return all those that match the given regexes.
|
|
||||||
Will also log those variables.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
all_trainable_variables (a collection of trainable tf.Variable): The variables to search through.
|
|
||||||
trainable_regexes (a collection of regexes): Variables that match any regex will be included.
|
|
||||||
|
|
||||||
Returns a list of tf.Variable
|
|
||||||
"""
|
|
||||||
if trainable_regexes is None or len(trainable_regexes) == 0:
|
|
||||||
tf.logging.info("No trainable regexes found. Not using get_trainable_variables behavior.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
assert any(
|
|
||||||
tf.is_tensor(var) for var in all_trainable_variables
|
|
||||||
), f"Non TF variable found: {all_trainable_variables}"
|
|
||||||
trainable_variables = list(
|
|
||||||
filter(
|
|
||||||
lambda var: any(re.match(regex, var.name, re.IGNORECASE) for regex in trainable_regexes),
|
|
||||||
all_trainable_variables,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tf.logging.info(f"Using filtered trainable variables: {trainable_variables}")
|
|
||||||
|
|
||||||
assert (
|
|
||||||
trainable_variables
|
|
||||||
), "Did not find trainable variables after filtering after filtering from {} number of vars originaly. All vars: {} and train regexes: {}".format(
|
|
||||||
len(all_trainable_variables), all_trainable_variables, trainable_regexes
|
|
||||||
)
|
|
||||||
return trainable_variables
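# Hedged usage sketch (not part of the original file): fine-tuning only the logits layer by
# filtering on a hypothetical variable-name regex.
#
#   all_vars = tf.trainable_variables()
#   logit_vars = get_trainable_variables(all_vars, trainable_regexes=["logits/.*"])
#   # logit_vars can then be passed as an optimizer's var_list so everything else stays frozen.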
|
|
BIN
pushservice/src/main/python/models/libs/group_metrics.docx
Normal file
Binary file not shown.
@ -1,114 +0,0 @@
|
|||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.computation import (
|
|
||||||
write_grouped_metrics_to_mldash,
|
|
||||||
)
|
|
||||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.configuration import (
|
|
||||||
ClassificationGroupedMetricsConfiguration,
|
|
||||||
NDCGGroupedMetricsConfiguration,
|
|
||||||
)
|
|
||||||
import twml
|
|
||||||
|
|
||||||
from .light_ranking_metrics import (
|
|
||||||
CGRGroupedMetricsConfiguration,
|
|
||||||
ExpectedLossGroupedMetricsConfiguration,
|
|
||||||
RecallGroupedMetricsConfiguration,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
from tensorflow.compat.v1 import logging
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
|
|
||||||
def run_group_metrics(trainer, data_dir, model_path, parse_fn, group_feature_name="meta.user_id"):
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
logging.info("Evaluating with group metrics.")
|
|
||||||
|
|
||||||
metrics = write_grouped_metrics_to_mldash(
|
|
||||||
trainer=trainer,
|
|
||||||
data_dir=data_dir,
|
|
||||||
model_path=model_path,
|
|
||||||
group_fn=lambda datarecord: str(
|
|
||||||
datarecord.discreteFeatures[twml.feature_id(group_feature_name)[0]]
|
|
||||||
),
|
|
||||||
parse_fn=parse_fn,
|
|
||||||
metric_configurations=[
|
|
||||||
ClassificationGroupedMetricsConfiguration(),
|
|
||||||
NDCGGroupedMetricsConfiguration(k=[5, 10, 20]),
|
|
||||||
],
|
|
||||||
total_records_to_read=1000000000,
|
|
||||||
shuffle=False,
|
|
||||||
mldash_metrics_name="grouped_metrics",
|
|
||||||
)
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
logging.info(f"Evaluated Group Metics: {metrics}.")
|
|
||||||
logging.info(f"Group metrics evaluation time {end_time - start_time}.")
|
|
||||||
|
|
||||||
|
|
||||||
def run_group_metrics_light_ranking(
|
|
||||||
trainer, data_dir, model_path, parse_fn, group_feature_name="meta.trace_id"
|
|
||||||
):
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
logging.info("Evaluating with group metrics.")
|
|
||||||
|
|
||||||
metrics = write_grouped_metrics_to_mldash(
|
|
||||||
trainer=trainer,
|
|
||||||
data_dir=data_dir,
|
|
||||||
model_path=model_path,
|
|
||||||
group_fn=lambda datarecord: str(
|
|
||||||
datarecord.discreteFeatures[twml.feature_id(group_feature_name)[0]]
|
|
||||||
),
|
|
||||||
parse_fn=parse_fn,
|
|
||||||
metric_configurations=[
|
|
||||||
CGRGroupedMetricsConfiguration(lightNs=[50, 100, 200], heavyKs=[1, 3, 10, 20, 50]),
|
|
||||||
RecallGroupedMetricsConfiguration(n=[50, 100, 200], k=[1, 3, 10, 20, 50]),
|
|
||||||
ExpectedLossGroupedMetricsConfiguration(lightNs=[50, 100, 200]),
|
|
||||||
],
|
|
||||||
total_records_to_read=10000000,
|
|
||||||
num_batches_to_load=50,
|
|
||||||
batch_size=1024,
|
|
||||||
shuffle=False,
|
|
||||||
mldash_metrics_name="grouped_metrics_for_light_ranking",
|
|
||||||
)
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
logging.info(f"Evaluated Group Metics for Light Ranking: {metrics}.")
|
|
||||||
logging.info(f"Group metrics evaluation time {end_time - start_time}.")
|
|
||||||
|
|
||||||
|
|
||||||
def run_group_metrics_light_ranking_in_bq(trainer, params, checkpoint_path):
|
|
||||||
logging.info("getting Test Predictions for Light Ranking Group Metrics in BigQuery !!!")
|
|
||||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
|
||||||
info_pool = []
|
|
||||||
|
|
||||||
for result in trainer.estimator.predict(
|
|
||||||
eval_input_fn, checkpoint_path=checkpoint_path, yield_single_examples=False
|
|
||||||
):
|
|
||||||
traceID = result["trace_id"]
|
|
||||||
pred = result["prediction"]
|
|
||||||
label = result["target"]
|
|
||||||
info = np.concatenate([traceID, pred, label], axis=1)
|
|
||||||
info_pool.append(info)
|
|
||||||
|
|
||||||
info_pool = np.concatenate(info_pool)
|
|
||||||
|
|
||||||
locname = "/tmp/000/"
|
|
||||||
if not os.path.exists(locname):
|
|
||||||
os.makedirs(locname)
|
|
||||||
|
|
||||||
locfile = locname + params.pred_file_name
|
|
||||||
columns = ["trace_id", "model_prediction", "meta__ranking__weighted_oonc_model_score"]
|
|
||||||
np.savetxt(locfile, info_pool, delimiter=",", header=",".join(columns))
|
|
||||||
tf.io.gfile.copy(locfile, params.pred_file_path + params.pred_file_name, overwrite=True)
|
|
||||||
|
|
||||||
if os.path.isfile(locfile):
|
|
||||||
os.remove(locfile)
|
|
||||||
|
|
||||||
logging.info("Done Prediction for Light Ranking Group Metrics in BigQuery.")
|
|
BIN
pushservice/src/main/python/models/libs/initializer.docx
Normal file
Binary file not shown.
@ -1,118 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
from tensorflow.keras import backend as K
|
|
||||||
|
|
||||||
|
|
||||||
class VarianceScaling(object):
|
|
||||||
"""Initializer capable of adapting its scale to the shape of weights.
|
|
||||||
With `distribution="normal"`, samples are drawn from a truncated normal
|
|
||||||
distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
|
|
||||||
- number of input units in the weight tensor, if mode = "fan_in"
|
|
||||||
- number of output units, if mode = "fan_out"
|
|
||||||
- average of the numbers of input and output units, if mode = "fan_avg"
|
|
||||||
With `distribution="uniform"`,
|
|
||||||
samples are drawn from a uniform distribution
|
|
||||||
within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
|
|
||||||
# Arguments
|
|
||||||
scale: Scaling factor (positive float).
|
|
||||||
mode: One of "fan_in", "fan_out", "fan_avg".
|
|
||||||
distribution: Random distribution to use. One of "normal", "uniform".
|
|
||||||
seed: A Python integer. Used to seed the random generator.
|
|
||||||
# Raises
|
|
||||||
ValueError: In case of an invalid value for the "scale", mode" or
|
|
||||||
"distribution" arguments."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
scale=1.0,
|
|
||||||
mode="fan_in",
|
|
||||||
distribution="normal",
|
|
||||||
seed=None,
|
|
||||||
fan_in=None,
|
|
||||||
fan_out=None,
|
|
||||||
):
|
|
||||||
self.fan_in = fan_in
|
|
||||||
self.fan_out = fan_out
|
|
||||||
if scale <= 0.0:
|
|
||||||
raise ValueError("`scale` must be a positive float. Got:", scale)
|
|
||||||
mode = mode.lower()
|
|
||||||
if mode not in {"fan_in", "fan_out", "fan_avg"}:
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid `mode` argument: " 'expected on of {"fan_in", "fan_out", "fan_avg"} ' "but got",
|
|
||||||
mode,
|
|
||||||
)
|
|
||||||
distribution = distribution.lower()
|
|
||||||
if distribution not in {"normal", "uniform"}:
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid `distribution` argument: " 'expected one of {"normal", "uniform"} ' "but got",
|
|
||||||
distribution,
|
|
||||||
)
|
|
||||||
self.scale = scale
|
|
||||||
self.mode = mode
|
|
||||||
self.distribution = distribution
|
|
||||||
self.seed = seed
|
|
||||||
|
|
||||||
def __call__(self, shape, dtype=None, partition_info=None):
|
|
||||||
fan_in = shape[-2] if self.fan_in is None else self.fan_in
|
|
||||||
fan_out = shape[-1] if self.fan_out is None else self.fan_out
|
|
||||||
|
|
||||||
scale = self.scale
|
|
||||||
if self.mode == "fan_in":
|
|
||||||
scale /= max(1.0, fan_in)
|
|
||||||
elif self.mode == "fan_out":
|
|
||||||
scale /= max(1.0, fan_out)
|
|
||||||
else:
|
|
||||||
scale /= max(1.0, float(fan_in + fan_out) / 2)
|
|
||||||
if self.distribution == "normal":
|
|
||||||
stddev = np.sqrt(scale) / 0.87962566103423978
|
|
||||||
return K.truncated_normal(shape, 0.0, stddev, dtype=dtype, seed=self.seed)
|
|
||||||
else:
|
|
||||||
limit = np.sqrt(3.0 * scale)
|
|
||||||
return K.random_uniform(shape, -limit, limit, dtype=dtype, seed=self.seed)
|
|
||||||
|
|
||||||
def get_config(self):
|
|
||||||
return {
|
|
||||||
"scale": self.scale,
|
|
||||||
"mode": self.mode,
|
|
||||||
"distribution": self.distribution,
|
|
||||||
"seed": self.seed,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def customized_glorot_uniform(seed=None, fan_in=None, fan_out=None):
|
|
||||||
"""Glorot uniform initializer, also called Xavier uniform initializer.
|
|
||||||
It draws samples from a uniform distribution within [-limit, limit]
|
|
||||||
where `limit` is `sqrt(6 / (fan_in + fan_out))`
|
|
||||||
where `fan_in` is the number of input units in the weight tensor
|
|
||||||
and `fan_out` is the number of output units in the weight tensor.
|
|
||||||
# Arguments
|
|
||||||
seed: A Python integer. Used to seed the random generator.
|
|
||||||
# Returns
|
|
||||||
An initializer."""
|
|
||||||
return VarianceScaling(
|
|
||||||
scale=1.0,
|
|
||||||
mode="fan_avg",
|
|
||||||
distribution="uniform",
|
|
||||||
seed=seed,
|
|
||||||
fan_in=fan_in,
|
|
||||||
fan_out=fan_out,
|
|
||||||
)
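# Hedged usage sketch (not part of the original file): overriding fan_in/fan_out is useful
# when a weight's shape does not reflect its effective fan, e.g. hashed sparse embeddings.
# The numbers are illustrative only.
#
#   initializer = customized_glorot_uniform(seed=0, fan_in=2 ** 18, fan_out=50)
#   weights = initializer(shape=(2 ** 18, 50))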
|
|
||||||
|
|
||||||
|
|
||||||
def customized_glorot_norm(seed=None, fan_in=None, fan_out=None):
|
|
||||||
Glorot normal initializer, also called Xavier normal initializer.
|
|
||||||
It draws samples from a truncated normal distribution centered on zero
|
|
||||||
with `stddev = sqrt(2 / (fan_in + fan_out))`
|
|
||||||
where `fan_in` is the number of input units in the weight tensor
|
|
||||||
and `fan_out` is the number of output units in the weight tensor.
|
|
||||||
# Arguments
|
|
||||||
seed: A Python integer. Used to seed the random generator.
|
|
||||||
# Returns
|
|
||||||
An initializer."""
|
|
||||||
return VarianceScaling(
|
|
||||||
scale=1.0,
|
|
||||||
mode="fan_avg",
|
|
||||||
distribution="normal",
|
|
||||||
seed=seed,
|
|
||||||
fan_in=fan_in,
|
|
||||||
fan_out=fan_out,
|
|
||||||
)
|
|
Binary file not shown.
@ -1,255 +0,0 @@
|
|||||||
from functools import partial
|
|
||||||
|
|
||||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.configuration import (
|
|
||||||
GroupedMetricsConfiguration,
|
|
||||||
)
|
|
||||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.helpers import (
|
|
||||||
extract_prediction_from_prediction_record,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
|
|
||||||
def score_loss_at_n(labels, predictions, lightN):
|
|
||||||
"""
|
|
||||||
Compute the absolute ScoreLoss ranking metric
|
|
||||||
Args:
|
|
||||||
labels (list) : A list of label values (HeavyRanking Reference)
|
|
||||||
predictions (list): A list of prediction values (LightRanking Predictions)
|
|
||||||
lightN (int): size of the initial candidate list (light ranking) at which to compute ScoreLoss.
|
|
||||||
"""
|
|
||||||
assert len(labels) == len(predictions)
|
|
||||||
|
|
||||||
if lightN <= 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
labels_with_predictions = zip(labels, predictions)
|
|
||||||
labels_with_sorted_predictions = sorted(
|
|
||||||
labels_with_predictions, key=lambda x: x[1], reverse=True
|
|
||||||
)[:lightN]
|
|
||||||
labels_top1_light = max([label for label, _ in labels_with_sorted_predictions])
|
|
||||||
labels_top1_heavy = max(labels)
|
|
||||||
|
|
||||||
return labels_top1_heavy - labels_top1_light
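# Hedged worked example (not part of the original file):
#   labels      = [0.9, 0.1, 0.4, 0.8, 0.2]   # heavy-ranker reference
#   predictions = [0.2, 0.7, 0.6, 0.9, 0.1]   # light-ranker scores
# With lightN=2 the light ranker keeps the candidates scored 0.9 and 0.7, whose best label is
# 0.8, while the best label overall is 0.9, so score_loss_at_n(...) returns 0.9 - 0.8 ≈ 0.1.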
|
|
||||||
|
|
||||||
|
|
||||||
def cgr_at_nk(labels, predictions, lightN, heavyK):
|
|
||||||
"""
|
|
||||||
Compute Cumulative Gain Ratio (CGR) ranking metric
|
|
||||||
Args:
|
|
||||||
labels (list) : A list of label values (HeavyRanking Reference)
|
|
||||||
predictions (list): A list of prediction values (LightRanking Predictions)
|
|
||||||
lightN (int): size of the initial candidate list (light ranking) at which to compute CGR.
|
|
||||||
heavyK (int): size of the refined candidate list (heavy ranking) at which to compute CGR.
|
|
||||||
"""
|
|
||||||
assert len(labels) == len(predictions)
|
|
||||||
|
|
||||||
if (not lightN) or (not heavyK):
|
|
||||||
out = None
|
|
||||||
elif lightN <= 0 or heavyK <= 0:
|
|
||||||
out = None
|
|
||||||
else:
|
|
||||||
|
|
||||||
labels_with_predictions = zip(labels, predictions)
|
|
||||||
labels_with_sorted_predictions = sorted(
|
|
||||||
labels_with_predictions, key=lambda x: x[1], reverse=True
|
|
||||||
)[:lightN]
|
|
||||||
labels_topN_light = [label for label, _ in labels_with_sorted_predictions]
|
|
||||||
|
|
||||||
if lightN <= heavyK:
|
|
||||||
cg_light = sum(labels_topN_light)
|
|
||||||
else:
|
|
||||||
labels_topK_heavy_from_light = sorted(labels_topN_light, reverse=True)[:heavyK]
|
|
||||||
cg_light = sum(labels_topK_heavy_from_light)
|
|
||||||
|
|
||||||
ideal_ordering = sorted(labels, reverse=True)
|
|
||||||
cg_heavy = sum(ideal_ordering[: min(lightN, heavyK)])
|
|
||||||
|
|
||||||
out = 0.0
|
|
||||||
if cg_heavy != 0:
|
|
||||||
out = max(cg_light / cg_heavy, 0)
|
|
||||||
|
|
||||||
return out
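# Hedged worked example (not part of the original file):
#   labels      = [3, 1, 0, 2, 5]
#   predictions = [0.9, 0.8, 0.7, 0.1, 0.2]
# With lightN=3 and heavyK=2 the light ranker's top 3 carry labels [3, 1, 0], whose best two
# sum to 4, while the ideal top-2 labels sum to 8, so cgr_at_nk(...) returns 4 / 8 = 0.5.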
|
|
||||||
|
|
||||||
|
|
||||||
def _get_weight(w, atK):
|
|
||||||
if not w:
|
|
||||||
return 1.0
|
|
||||||
elif len(w) <= atK:
|
|
||||||
return 0.0
|
|
||||||
else:
|
|
||||||
return w[atK]
|
|
||||||
|
|
||||||
|
|
||||||
def recall_at_nk(labels, predictions, n=None, k=None, w=None):
|
|
||||||
"""
|
|
||||||
Recall at N-K ranking metric
|
|
||||||
Args:
|
|
||||||
labels (list): A list of label values
|
|
||||||
predictions (list): A list of prediction values
|
|
||||||
n (int): size of the prediction list (light ranking predictions) at which to compute recall.
|
|
||||||
The default is None, in which case the length of the provided predictions is used.
|
|
||||||
k (int): size of the label list (heavy ranking) at which to compute recall.
|
|
||||||
The default is None, in which case the length of the provided labels is used as k.
|
|
||||||
w (list): weight vector indexed by label rank (labels sorted in descending order)
|
|
||||||
"""
|
|
||||||
assert len(labels) == len(predictions)
|
|
||||||
|
|
||||||
if not any(labels):
|
|
||||||
out = None
|
|
||||||
else:
|
|
||||||
|
|
||||||
safe_n = len(predictions) if not n else min(len(predictions), n)
|
|
||||||
safe_k = len(labels) if not k else min(len(labels), k)
|
|
||||||
|
|
||||||
labels_with_predictions = zip(labels, predictions)
|
|
||||||
sorted_labels_with_predictions = sorted(
|
|
||||||
labels_with_predictions, key=lambda x: x[0], reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
order_sorted_labels_predictions = zip(range(len(labels)), *zip(*sorted_labels_with_predictions))
|
|
||||||
|
|
||||||
order_with_predictions = [
|
|
||||||
(order, pred) for order, label, pred in order_sorted_labels_predictions
|
|
||||||
]
|
|
||||||
order_with_sorted_predictions = sorted(order_with_predictions, key=lambda x: x[1], reverse=True)
|
|
||||||
|
|
||||||
pred_sorted_order_at_n = [order for order, _ in order_with_sorted_predictions][:safe_n]
|
|
||||||
|
|
||||||
intersection_weight = [
|
|
||||||
_get_weight(w, order) if order < safe_k else 0 for order in pred_sorted_order_at_n
|
|
||||||
]
|
|
||||||
|
|
||||||
intersection_score = sum(intersection_weight)
|
|
||||||
full_score = sum(w) if w else float(safe_k)
|
|
||||||
|
|
||||||
out = 0.0
|
|
||||||
if full_score != 0:
|
|
||||||
out = intersection_score / full_score
|
|
||||||
|
|
||||||
return out
|
|
||||||
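# Illustrative usage (hypothetical values): of the top-2 candidates by prediction, only one is
# also in the top-2 by label, so unweighted recall is 1 / 2.
demo_labels = [0.9, 0.1, 0.7, 0.4]
demo_predictions = [0.8, 0.6, 0.2, 0.3]
recall = recall_at_nk(demo_labels, demo_predictions, n=2, k=2)  # -> 0.5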
|
|
||||||
|
|
||||||
class ExpectedLossGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
|
||||||
"""
|
|
||||||
This is the Expected Loss Grouped metric computation configuration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, lightNs=[]):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
lightNs (list): sizes of the initial (LightRanking) candidate lists at which to compute Expected Loss.
|
|
||||||
"""
|
|
||||||
self.lightNs = lightNs
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return "ExpectedLoss"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def metrics_dict(self):
|
|
||||||
metrics_to_compute = {}
|
|
||||||
for lightN in self.lightNs:
|
|
||||||
metric_name = "ExpectedLoss_atLight_" + str(lightN)
|
|
||||||
metrics_to_compute[metric_name] = partial(score_loss_at_n, lightN=lightN)
|
|
||||||
return metrics_to_compute
|
|
||||||
|
|
||||||
def extract_label(self, prec, drec, drec_label):
|
|
||||||
return drec_label
|
|
||||||
|
|
||||||
def extract_prediction(self, prec, drec, drec_label):
|
|
||||||
return extract_prediction_from_prediction_record(prec)
|
|
||||||
|
|
||||||
|
|
||||||
class CGRGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
|
||||||
"""
|
|
||||||
This is the Cumulative Gain Ratio (CGR) Grouped metric computation configuration.
|
|
||||||
CGR at the max length of each session is the default.
|
|
||||||
CGR at additional positions can be computed by specifying a list of 'n's and 'k's
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, lightNs=[], heavyKs=[]):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
lightNs (list): sizes of the initial (LightRanking) candidate lists at which to compute CGR.
|
|
||||||
heavyKs (list): sizes of the refined (HeavyRanking) candidate lists at which to compute CGR.
|
|
||||||
"""
|
|
||||||
self.lightNs = lightNs
|
|
||||||
self.heavyKs = heavyKs
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return "cgr"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def metrics_dict(self):
|
|
||||||
metrics_to_compute = {}
|
|
||||||
for lightN in self.lightNs:
|
|
||||||
for heavyK in self.heavyKs:
|
|
||||||
metric_name = "cgr_atLight_" + str(lightN) + "_atHeavy_" + str(heavyK)
|
|
||||||
metrics_to_compute[metric_name] = partial(cgr_at_nk, lightN=lightN, heavyK=heavyK)
|
|
||||||
return metrics_to_compute
|
|
||||||
|
|
||||||
def extract_label(self, prec, drec, drec_label):
|
|
||||||
return drec_label
|
|
||||||
|
|
||||||
def extract_prediction(self, prec, drec, drec_label):
|
|
||||||
return extract_prediction_from_prediction_record(prec)
|
|
||||||
|
|
||||||
|
|
||||||
class RecallGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
|
||||||
"""
|
|
||||||
This is the Recall Grouped metric computation configuration.
|
|
||||||
Recall at the max length of each session is the default.
|
|
||||||
Recall at additional positions can be computed by specifying a list of 'n's and 'k's
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, n=[], k=[], w=[]):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
n (list): A list of ints. List of prediction rank thresholds (for light)
|
|
||||||
k (list): A list of ints. List of label rank thresholds (for heavy)
w (list): Optional weight vector indexed by label rank
|
|
||||||
"""
|
|
||||||
self.predN = n
|
|
||||||
self.labelK = k
|
|
||||||
self.weight = w
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return "group_recall"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def metrics_dict(self):
|
|
||||||
metrics_to_compute = {"group_recall_unweighted": recall_at_nk}
|
|
||||||
if self.weight:
|
|
||||||
metrics_to_compute["group_recall_weighted"] = partial(recall_at_nk, w=self.weight)
|
|
||||||
|
|
||||||
if self.predN and self.labelK:
|
|
||||||
for n in self.predN:
|
|
||||||
for k in self.labelK:
|
|
||||||
if n >= k:
|
|
||||||
metrics_to_compute[
|
|
||||||
"group_recall_unweighted_at_L" + str(n) + "_at_H" + str(k)
|
|
||||||
] = partial(recall_at_nk, n=n, k=k)
|
|
||||||
if self.weight:
|
|
||||||
metrics_to_compute[
|
|
||||||
"group_recall_weighted_at_L" + str(n) + "_at_H" + str(k)
|
|
||||||
] = partial(recall_at_nk, n=n, k=k, w=self.weight)
|
|
||||||
|
|
||||||
if self.labelK and not self.predN:
|
|
||||||
for k in self.labelK:
|
|
||||||
metrics_to_compute["group_recall_unweighted_at_full_at_H" + str(k)] = partial(
|
|
||||||
recall_at_nk, k=k
|
|
||||||
)
|
|
||||||
if self.weight:
|
|
||||||
metrics_to_compute["group_recall_weighted_at_full_at_H" + str(k)] = partial(
|
|
||||||
recall_at_nk, k=k, w=self.weight
|
|
||||||
)
|
|
||||||
return metrics_to_compute
|
|
||||||
|
|
||||||
def extract_label(self, prec, drec, drec_label):
|
|
||||||
return drec_label
|
|
||||||
|
|
||||||
def extract_prediction(self, prec, drec, drec_label):
|
|
||||||
return extract_prediction_from_prediction_record(prec)
|
|
BIN
pushservice/src/main/python/models/libs/metric_fn_utils.docx
Normal file
BIN
pushservice/src/main/python/models/libs/metric_fn_utils.docx
Normal file
Binary file not shown.
@ -1,294 +0,0 @@
|
|||||||
"""
|
|
||||||
Utilities for constructing a metric_fn for magic recs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from twml.contrib.metrics.metrics import (
|
|
||||||
get_dual_binary_tasks_metric_fn,
|
|
||||||
get_numeric_metric_fn,
|
|
||||||
get_partial_multi_binary_class_metric_fn,
|
|
||||||
get_single_binary_task_metric_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .model_utils import generate_disliked_mask
|
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
|
|
||||||
|
|
||||||
METRIC_BOOK = {
|
|
||||||
"OONC": ["OONC"],
|
|
||||||
"OONC_Engagement": ["OONC", "Engagement"],
|
|
||||||
"Sent": ["Sent"],
|
|
||||||
"HeavyRankPosition": ["HeavyRankPosition"],
|
|
||||||
"HeavyRankProbability": ["HeavyRankProbability"],
|
|
||||||
}
|
|
||||||
|
|
||||||
USER_AGE_FEATURE_NAME = "accountAge"
|
|
||||||
NEW_USER_AGE_CUTOFF = 0
|
|
||||||
|
|
||||||
|
|
||||||
def remove_padding_and_flatten(tensor, valid_batch_size):
|
|
||||||
"""Remove the padding of the input padded tensor given the valid batch size tensor,
|
|
||||||
then flatten the output with respect to the first dimension.
|
|
||||||
Args:
|
|
||||||
tensor: A tensor of size [META_BATCH_SIZE, BATCH_SIZE, FEATURE_DIM].
|
|
||||||
valid_batch_size: A tensor of size [META_BATCH_SIZE], with each element indicating
|
|
||||||
the effective batch size of the BATCH_SIZE dimension.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tensor of size [tf.reduce_sum(valid_batch_size), FEATURE_DIM].
|
|
||||||
"""
|
|
||||||
unpadded_ragged_tensor = tf.RaggedTensor.from_tensor(tensor=tensor, lengths=valid_batch_size)
|
|
||||||
|
|
||||||
return unpadded_ragged_tensor.flat_values
|
|
||||||
|
|
||||||
|
|
||||||
def safe_mask(values, mask):
|
|
||||||
"""Mask values if possible.
|
|
||||||
|
|
||||||
Boolean-mask the input values if and only if values is a tensor whose first dimension matches mask (or can be broadcast to it).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
values (Any or Tensor): Input tensor to mask. Dim 0 should be size N.
|
|
||||||
mask (boolean tensor): A boolean tensor of size N.
|
|
||||||
|
|
||||||
Returns values unchanged, or values with the mask applied.
|
|
||||||
"""
|
|
||||||
if values is None:
|
|
||||||
return values
|
|
||||||
if not tf.is_tensor(values):
|
|
||||||
return values
|
|
||||||
values_shape = values.get_shape()
|
|
||||||
if not values_shape or len(values_shape) == 0:
|
|
||||||
return values
|
|
||||||
if not mask.get_shape().is_compatible_with(values_shape[0]):
|
|
||||||
return values
|
|
||||||
return tf.boolean_mask(values, mask)
|
|
||||||
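# Minimal sketch (hypothetical shapes, graph-mode tf.compat.v1 as in the rest of this file):
# tensors whose leading dimension matches the mask get filtered, everything else passes through.
demo_values = tf.constant([[0.2], [0.7], [0.9]])      # dim 0 has size 3
demo_mask = tf.constant([True, False, True])
masked = safe_mask(demo_values, demo_mask)            # keeps rows 0 and 2 -> shape [2, 1]
passthrough = safe_mask(1.0, demo_mask)               # not a tensor, returned unchanged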
|
|
||||||
|
|
||||||
def add_new_user_metrics(metric_fn):
|
|
||||||
"""Will stratify the metric_fn by adding new user metrics.
|
|
||||||
|
|
||||||
Given an input metric_fn, double every metric: one will be the original and the other will be computed only over new users.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metric_fn (python function): Base twml metric_fn.
|
|
||||||
|
|
||||||
Returns a metric_fn with new user metrics included.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def metric_fn_with_new_users(graph_output, labels, weights):
|
|
||||||
if USER_AGE_FEATURE_NAME not in graph_output:
|
|
||||||
raise ValueError(
|
|
||||||
"In order to get metrics stratified by user age, {name} feature should be added to model graph output. However, only the following output keys were found: {keys}.".format(
|
|
||||||
name=USER_AGE_FEATURE_NAME, keys=graph_output.keys()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
metric_ops = metric_fn(graph_output, labels, weights)
|
|
||||||
|
|
||||||
is_new = tf.reshape(
|
|
||||||
tf.math.less_equal(
|
|
||||||
tf.cast(graph_output[USER_AGE_FEATURE_NAME], tf.int64),
|
|
||||||
tf.cast(NEW_USER_AGE_CUTOFF, tf.int64),
|
|
||||||
),
|
|
||||||
[-1],
|
|
||||||
)
|
|
||||||
|
|
||||||
labels = safe_mask(labels, is_new)
|
|
||||||
weights = safe_mask(weights, is_new)
|
|
||||||
graph_output = {key: safe_mask(values, is_new) for key, values in graph_output.items()}
|
|
||||||
|
|
||||||
new_user_metric_ops = metric_fn(graph_output, labels, weights)
|
|
||||||
new_user_metric_ops = {name + "_new_users": ops for name, ops in new_user_metric_ops.items()}
|
|
||||||
metric_ops.update(new_user_metric_ops)
|
|
||||||
return metric_ops
|
|
||||||
|
|
||||||
return metric_fn_with_new_users
|
|
||||||
|
|
||||||
|
|
||||||
def get_meta_learn_single_binary_task_metric_fn(
|
|
||||||
metrics, classnames, top_k=(5, 5, 5), use_top_k=False
|
|
||||||
):
|
|
||||||
"""Wrapper function to use the metric_fn with meta learning evaluation scheme.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metrics: A list of string representing metric names.
|
|
||||||
classnames: A list of strings representing class names. In the case of multiple binary-class models,
|
|
||||||
the names for each class or label.
|
|
||||||
top_k: A tuple of int to specify top K metrics.
|
|
||||||
use_top_k: A boolean value indicating whether top-K metrics are computed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A customized metric_fn function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_eval_metric_ops(graph_output, labels, weights):
|
|
||||||
"""The op func of the eval_metrics. Comparing with normal version,
|
|
||||||
the difference is we flatten the output, label, and weights.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_output: A dict of tensors.
|
|
||||||
labels: An int32 tensor whose values are either 0 or 1.
|
|
||||||
weights: A tensor of float32 to indicate the per record weight.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict of metric names and values.
|
|
||||||
"""
|
|
||||||
metric_op_weighted = get_partial_multi_binary_class_metric_fn(
|
|
||||||
metrics, predcols=0, classes=classnames
|
|
||||||
)
|
|
||||||
classnames_unweighted = ["unweighted_" + classname for classname in classnames]
|
|
||||||
metric_op_unweighted = get_partial_multi_binary_class_metric_fn(
|
|
||||||
metrics, predcols=0, classes=classnames_unweighted
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_batch_size = graph_output["valid_batch_size"]
|
|
||||||
graph_output["output"] = remove_padding_and_flatten(graph_output["output"], valid_batch_size)
|
|
||||||
labels = remove_padding_and_flatten(labels, valid_batch_size)
|
|
||||||
weights = remove_padding_and_flatten(weights, valid_batch_size)
|
|
||||||
|
|
||||||
tf.ensure_shape(graph_output["output"], [None, 1])
|
|
||||||
tf.ensure_shape(labels, [None, 1])
|
|
||||||
tf.ensure_shape(weights, [None, 1])
|
|
||||||
|
|
||||||
metrics_weighted = metric_op_weighted(graph_output, labels, weights)
|
|
||||||
metrics_unweighted = metric_op_unweighted(graph_output, labels, None)
|
|
||||||
metrics_weighted.update(metrics_unweighted)
|
|
||||||
|
|
||||||
if use_top_k:
|
|
||||||
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=top_k, predcol=0, labelcol=1)
|
|
||||||
metrics_numeric = metric_op_numeric(graph_output, labels, weights)
|
|
||||||
metrics_weighted.update(metrics_numeric)
|
|
||||||
return metrics_weighted
|
|
||||||
|
|
||||||
return get_eval_metric_ops
|
|
||||||
|
|
||||||
|
|
||||||
def get_meta_learn_dual_binary_tasks_metric_fn(
|
|
||||||
metrics, classnames, top_k=(5, 5, 5), use_top_k=False
|
|
||||||
):
|
|
||||||
"""Wrapper function to use the metric_fn with meta learning evaluation scheme.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metrics: A list of string representing metric names.
|
|
||||||
classnames: A list of strings representing class names. In the case of multiple binary-class models,
|
|
||||||
the names for each class or label.
|
|
||||||
top_k: A tuple of int to specify top K metrics.
|
|
||||||
use_top_k: A boolean value indicating whether top-K metrics are computed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A customized metric_fn function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_eval_metric_ops(graph_output, labels, weights):
|
|
||||||
"""The op func of the eval_metrics. Comparing with normal version,
|
|
||||||
the difference is we flatten the output, label, and weights.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_output: A dict of tensors.
|
|
||||||
labels: An int32 tensor whose values are either 0 or 1.
|
|
||||||
weights: A tensor of float32 to indicate the per record weight.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict of metric names and values.
|
|
||||||
"""
|
|
||||||
metric_op_weighted = get_partial_multi_binary_class_metric_fn(
|
|
||||||
metrics, predcols=[0, 1], classes=classnames
|
|
||||||
)
|
|
||||||
classnames_unweighted = ["unweighted_" + classname for classname in classnames]
|
|
||||||
metric_op_unweighted = get_partial_multi_binary_class_metric_fn(
|
|
||||||
metrics, predcols=[0, 1], classes=classnames_unweighted
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_batch_size = graph_output["valid_batch_size"]
|
|
||||||
graph_output["output"] = remove_padding_and_flatten(graph_output["output"], valid_batch_size)
|
|
||||||
labels = remove_padding_and_flatten(labels, valid_batch_size)
|
|
||||||
weights = remove_padding_and_flatten(weights, valid_batch_size)
|
|
||||||
|
|
||||||
tf.ensure_shape(graph_output["output"], [None, 2])
|
|
||||||
tf.ensure_shape(labels, [None, 2])
|
|
||||||
tf.ensure_shape(weights, [None, 1])
|
|
||||||
|
|
||||||
metrics_weighted = metric_op_weighted(graph_output, labels, weights)
|
|
||||||
metrics_unweighted = metric_op_unweighted(graph_output, labels, None)
|
|
||||||
metrics_weighted.update(metrics_unweighted)
|
|
||||||
|
|
||||||
if use_top_k:
|
|
||||||
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=top_k, predcol=2, labelcol=2)
|
|
||||||
metrics_numeric = metric_op_numeric(graph_output, labels, weights)
|
|
||||||
metrics_weighted.update(metrics_numeric)
|
|
||||||
return metrics_weighted
|
|
||||||
|
|
||||||
return get_eval_metric_ops
|
|
||||||
|
|
||||||
|
|
||||||
def get_metric_fn(task_name, use_stratify_metrics, use_meta_batch=False):
|
|
||||||
"""Will retrieve the metric_fn for magic recs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_name (string): Which task is being used for this model.
|
|
||||||
use_stratify_metrics (boolean): Should we add stratified metrics (new user metrics).
|
|
||||||
use_meta_batch (boolean): If the output/label/weights are passed in 3D shape instead of
|
|
||||||
2D shape.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A metric_fn function to pass in twml Trainer.
|
|
||||||
"""
|
|
||||||
if task_name not in METRIC_BOOK:
|
|
||||||
raise ValueError(
|
|
||||||
"Task name of {task_name} not recognized. Unable to retrieve metrics.".format(
|
|
||||||
task_name=task_name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
class_names = METRIC_BOOK[task_name]
|
|
||||||
if use_meta_batch:
|
|
||||||
get_n_binary_task_metric_fn = (
|
|
||||||
get_meta_learn_single_binary_task_metric_fn
|
|
||||||
if len(class_names) == 1
|
|
||||||
else get_meta_learn_dual_binary_tasks_metric_fn
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
get_n_binary_task_metric_fn = (
|
|
||||||
get_single_binary_task_metric_fn if len(class_names) == 1 else get_dual_binary_tasks_metric_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
metric_fn = get_n_binary_task_metric_fn(metrics=None, classnames=METRIC_BOOK[task_name])
|
|
||||||
|
|
||||||
if use_stratify_metrics:
|
|
||||||
metric_fn = add_new_user_metrics(metric_fn)
|
|
||||||
|
|
||||||
return metric_fn
|
|
||||||
|
|
||||||
|
|
||||||
def flip_disliked_labels(metric_fn):
|
|
||||||
"""This function returns an adapted metric_fn which flips the labels of the OONCed evaluation data to 0 if it is disliked.
|
|
||||||
Args:
|
|
||||||
metric_fn: A metric_fn function to pass in twml Trainer.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_adapted_metric_fn: A customized metric_fn function with disliked OONC labels flipped.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _adapted_metric_fn(graph_output, labels, weights):
|
|
||||||
"""A customized metric_fn function with disliked OONC labels flipped.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_output: A dict of tensors.
|
|
||||||
labels: labels of training samples, which is a 2D tensor of shape batch_size x 3: [OONCs, engagements, dislikes]
|
|
||||||
weights: A tensor of float32 to indicate the per record weight.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict of metric names and values.
|
|
||||||
"""
|
|
||||||
# We want to multiply the label of the observation by 0 only when it is disliked
|
|
||||||
disliked_mask = generate_disliked_mask(labels)
|
|
||||||
|
|
||||||
# Extract OONC and engagement labels only.
|
|
||||||
labels = tf.reshape(labels[:, 0:2], shape=[-1, 2])
|
|
||||||
|
|
||||||
# Labels will be set to 0 if it is disliked.
|
|
||||||
adapted_labels = labels * tf.cast(tf.logical_not(disliked_mask), dtype=labels.dtype)
|
|
||||||
|
|
||||||
return metric_fn(graph_output, adapted_labels, weights)
|
|
||||||
|
|
||||||
return _adapted_metric_fn
|
|
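# Minimal sketch of the label flipping above (hypothetical values, graph-mode tf.compat.v1):
# labels are [OONC, engagement, dislike]; any record with dislike == 1 has its OONC and
# engagement labels zeroed before the wrapped metric_fn sees them.
demo_labels = tf.constant([[1.0, 0.0, 0.0],
                           [1.0, 1.0, 1.0]])          # second record was disliked
demo_mask = generate_disliked_mask(demo_labels)        # [[False], [True]]
flipped = tf.reshape(demo_labels[:, 0:2], shape=[-1, 2]) * tf.cast(
    tf.logical_not(demo_mask), dtype=demo_labels.dtype
)                                                      # [[1., 0.], [0., 0.]]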
BIN
pushservice/src/main/python/models/libs/model_args.docx
Normal file
BIN
pushservice/src/main/python/models/libs/model_args.docx
Normal file
Binary file not shown.
@ -1,231 +0,0 @@
|
|||||||
from twml.trainers import DataRecordTrainer
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
|
|
||||||
def get_arg_parser():
|
|
||||||
parser = DataRecordTrainer.add_parser_arguments()
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--input_size_bits",
|
|
||||||
type=int,
|
|
||||||
default=18,
|
|
||||||
help="number of bits allocated to the input size",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_trainer_name",
|
|
||||||
default="magic_recs_mlp_calibration_MTL_OONC_Engagement",
|
|
||||||
type=str,
|
|
||||||
help="specify the model trainer name.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_type",
|
|
||||||
default="deepnorm_gbdt_inputdrop2_rescale",
|
|
||||||
type=str,
|
|
||||||
help="specify the model type to use.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--feat_config_type",
|
|
||||||
default="get_feature_config_with_sparse_continuous",
|
|
||||||
type=str,
|
|
||||||
help="specify the feature configure function to use.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--directly_export_best",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="whether to directly_export best_checkpoint",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--warm_start_base_dir",
|
|
||||||
default="none",
|
|
||||||
type=str,
|
|
||||||
help="latest ckpt in this folder will be used to ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--feature_list",
|
|
||||||
default="none",
|
|
||||||
type=str,
|
|
||||||
help="Which features to use for training",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--momentum", default=0.99999, type=float, help="Momentum term for batch normalization"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dropout",
|
|
||||||
default=0.2,
|
|
||||||
type=float,
|
|
||||||
help="input_dropout_rate to rescale output by (1 - input_dropout_rate)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--out_layer_1_size", default=256, type=int, help="Size of MLP_branch layer 1"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--out_layer_2_size", default=128, type=int, help="Size of MLP_branch layer 2"
|
|
||||||
)
|
|
||||||
parser.add_argument("--out_layer_3_size", default=64, type=int, help="Size of MLP_branch layer 3")
|
|
||||||
parser.add_argument(
|
|
||||||
"--sparse_embedding_size", default=50, type=int, help="Dimensionality of sparse embedding layer"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dense_embedding_size", default=128, type=int, help="Dimensionality of dense embedding layer"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_uam_label",
|
|
||||||
default=False,
|
|
||||||
type=str,
|
|
||||||
help="Whether to use uam_label or not",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--task_name",
|
|
||||||
default="OONC_Engagement",
|
|
||||||
type=str,
|
|
||||||
help="specify the task name to use: OONC or OONC_Engagement.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--init_weight",
|
|
||||||
default=0.9,
|
|
||||||
type=float,
|
|
||||||
help="Initial OONC Task Weight MTL: OONC+Engagement.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_engagement_weight",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="whether to use engagement weight for base model.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--mtl_num_extra_layers",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of Hidden Layers for each TaskBranch.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--mtl_neuron_scale", type=int, default=4, help="Scaling Factor of Neurons in MTL Extra Layers."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_oonc_score",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="whether to use oonc score only or combined score.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_stratified_metrics",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Use stratified metrics: Break out new-user metrics.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--run_group_metrics",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Will run evaluation metrics grouped by user.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_full_scope",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Will add extra scope and naming to graph.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--trainable_regexes",
|
|
||||||
default=None,
|
|
||||||
nargs="*",
|
|
||||||
help="The union of variables specified by the list of regexes will be considered trainable.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fine_tuning.ckpt_to_initialize_from",
|
|
||||||
dest="fine_tuning_ckpt_to_initialize_from",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Checkpoint path from which to warm start. Indicates the pre-trained model.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fine_tuning.warm_start_scope_regex",
|
|
||||||
dest="fine_tuning_warm_start_scope_regex",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="All variables matching this will be restored.",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def get_params(args=None):
|
|
||||||
parser = get_arg_parser()
|
|
||||||
if args is None:
|
|
||||||
return parser.parse_args()
|
|
||||||
else:
|
|
||||||
return parser.parse_args(args)
|
|
||||||
|
|
||||||
|
|
||||||
def get_arg_parser_light_ranking():
|
|
||||||
parser = get_arg_parser()
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_record_weight",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="whether to use record weight for base model.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--min_record_weight", default=0.0, type=float, help="Minimum record weight to use."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--smooth_weight", default=0.0, type=float, help="Factor to smooth Rank Position Weight."
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_mlp_layers", type=int, default=3, help="Number of Hidden Layers for MLP model."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--mlp_neuron_scale", type=int, default=4, help="Scaling Factor of Neurons in MLP Layers."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--run_light_ranking_group_metrics",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Will run evaluation metrics grouped by user for Light Ranking.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_missing_sub_branch",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Whether to use missing value sub-branch for Light Ranking.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_gbdt_features",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Whether to use GBDT features for Light Ranking.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--run_light_ranking_group_metrics_in_bq",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Whether to get_predictions for Light Ranking to compute group metrics in BigQuery.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pred_file_path",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="path",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pred_file_name",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="path",
|
|
||||||
)
|
|
||||||
return parser
|
|
BIN
pushservice/src/main/python/models/libs/model_utils.docx
Normal file
BIN
pushservice/src/main/python/models/libs/model_utils.docx
Normal file
Binary file not shown.
@ -1,339 +0,0 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
import twml
|
|
||||||
|
|
||||||
from .initializer import customized_glorot_uniform
|
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
|
|
||||||
def read_config(whitelist_yaml_file):
|
|
||||||
with tf.gfile.FastGFile(whitelist_yaml_file) as f:
|
|
||||||
try:
|
|
||||||
return yaml.safe_load(f)
|
|
||||||
except yaml.YAMLError as exc:
|
|
||||||
print(exc)
|
|
||||||
sys.exit(1)
|
|
||||||
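# Illustrative usage (hypothetical path and file contents). The one-letter type codes shown are
# an assumption inferred from get_feature_list_for_light_ranking in warm_start_utils, which
# drops entries whose value is "S":
#
#     recap.engagement.is_favorited: C
#     user.is_verified: B
#     sparse_continuous.hashed_tokens: S
#
feature_map = read_config("path/to/feature_allowlist.yaml")
dense_feature_names = [name for name, code in feature_map.items() if code != "S"]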
|
|
||||||
|
|
||||||
def _sparse_feature_fixup(features, input_size_bits):
|
|
||||||
"""Rebuild a sparse tensor feature so that its dense shape attribute is present.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
features (SparseTensor): Sparse feature tensor of shape ``(B, sparse_feature_dim)``.
|
|
||||||
input_size_bits (int): Number of columns in ``log2`` scale. Must be positive.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SparseTensor: Rebuilt version of `features` with its dense shape correctly set."""
|
|
||||||
sparse_feature_dim = tf.constant(2**input_size_bits, dtype=tf.int64)
|
|
||||||
sparse_shape = tf.stack([features.dense_shape[0], sparse_feature_dim])
|
|
||||||
sparse_tf = tf.SparseTensor(features.indices, features.values, sparse_shape)
|
|
||||||
return sparse_tf
|
|
||||||
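# Minimal sketch (hypothetical values, graph-mode tf.compat.v1): a batch of 2 records with
# hashed feature ids, whose dense_shape is rebuilt to the full 2**input_size_bits width.
demo_sparse = tf.SparseTensor(
    indices=tf.constant([[0, 3], [1, 7]], dtype=tf.int64),
    values=tf.constant([1.0, 2.0]),
    dense_shape=tf.constant([2, 10], dtype=tf.int64),   # width attribute is missing/too small
)
fixed = _sparse_feature_fixup(demo_sparse, input_size_bits=18)
# fixed.dense_shape evaluates to [2, 262144]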
|
|
||||||
|
|
||||||
def self_atten_dense(input, out_dim, activation=None, use_bias=True, name=None):
|
|
||||||
def safe_concat(base, suffix):
|
|
||||||
"""Concats variables name components if base is given."""
|
|
||||||
if not base:
|
|
||||||
return base
|
|
||||||
return f"{base}:{suffix}"
|
|
||||||
|
|
||||||
input_dim = input.shape.as_list()[1]
|
|
||||||
|
|
||||||
sigmoid_out = twml.layers.FullDense(
|
|
||||||
input_dim, dtype=tf.float32, activation=tf.nn.sigmoid, name=safe_concat(name, "sigmoid_out")
|
|
||||||
)(input)
|
|
||||||
atten_input = sigmoid_out * input
|
|
||||||
mlp_out = twml.layers.FullDense(
|
|
||||||
out_dim,
|
|
||||||
dtype=tf.float32,
|
|
||||||
activation=activation,
|
|
||||||
use_bias=use_bias,
|
|
||||||
name=safe_concat(name, "mlp_out"),
|
|
||||||
)(atten_input)
|
|
||||||
return mlp_out
|
|
||||||
|
|
||||||
|
|
||||||
def get_dense_out(input, out_dim, activation, dense_type):
|
|
||||||
if dense_type == "full_dense":
|
|
||||||
out = twml.layers.FullDense(out_dim, dtype=tf.float32, activation=activation)(input)
|
|
||||||
elif dense_type == "self_atten_dense":
|
|
||||||
out = self_atten_dense(input, out_dim, activation=activation)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def get_input_trans_func(bn_normalized_dense, is_training):
|
|
||||||
gw_normalized_dense = tf.expand_dims(bn_normalized_dense, -1)
|
|
||||||
group_num = bn_normalized_dense.shape.as_list()[1]
|
|
||||||
|
|
||||||
gw_normalized_dense = GroupWiseTrans(group_num, 1, 8, name="groupwise_1", activation=tf.tanh)(
|
|
||||||
gw_normalized_dense
|
|
||||||
)
|
|
||||||
gw_normalized_dense = GroupWiseTrans(group_num, 8, 4, name="groupwise_2", activation=tf.tanh)(
|
|
||||||
gw_normalized_dense
|
|
||||||
)
|
|
||||||
gw_normalized_dense = GroupWiseTrans(group_num, 4, 1, name="groupwise_3", activation=tf.tanh)(
|
|
||||||
gw_normalized_dense
|
|
||||||
)
|
|
||||||
|
|
||||||
gw_normalized_dense = tf.squeeze(gw_normalized_dense, [-1])
|
|
||||||
|
|
||||||
bn_gw_normalized_dense = tf.layers.batch_normalization(
|
|
||||||
gw_normalized_dense,
|
|
||||||
training=is_training,
|
|
||||||
renorm_momentum=0.9999,
|
|
||||||
momentum=0.9999,
|
|
||||||
renorm=is_training,
|
|
||||||
trainable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return bn_gw_normalized_dense
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_dropout(
|
|
||||||
input_tensor,
|
|
||||||
rate,
|
|
||||||
is_training,
|
|
||||||
sparse_tensor=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Implements dropout layer for both dense and sparse input_tensor
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
input_tensor:
|
|
||||||
B x D dense tensor, or a sparse tensor
|
|
||||||
rate (float32):
|
|
||||||
dropout rate
|
|
||||||
is_training (bool):
|
|
||||||
training stage or not.
|
|
||||||
sparse_tensor (bool):
|
|
||||||
whether the input_tensor is a sparse tensor. Defaults to None; this value must be passed explicitly.
|
|
||||||
rescale_sparse_dropout (bool):
|
|
||||||
Do we need to do rescaling or not.
|
|
||||||
Returns:
|
|
||||||
tensor dropped out"""
|
|
||||||
if sparse_tensor == True:
|
|
||||||
if is_training:
|
|
||||||
with tf.variable_scope("sparse_dropout"):
|
|
||||||
values = input_tensor.values
|
|
||||||
keep_mask = tf.keras.backend.random_binomial(
|
|
||||||
tf.shape(values), p=1 - rate, dtype=tf.float32, seed=None
|
|
||||||
)
|
|
||||||
keep_mask.set_shape([None])
|
|
||||||
keep_mask = tf.cast(keep_mask, tf.bool)
|
|
||||||
|
|
||||||
keep_indices = tf.boolean_mask(input_tensor.indices, keep_mask, axis=0)
|
|
||||||
keep_values = tf.boolean_mask(values, keep_mask, axis=0)
|
|
||||||
|
|
||||||
dropped_tensor = tf.SparseTensor(keep_indices, keep_values, input_tensor.dense_shape)
|
|
||||||
return dropped_tensor
|
|
||||||
else:
|
|
||||||
return input_tensor
|
|
||||||
elif sparse_tensor == False:
|
|
||||||
return tf.layers.dropout(input_tensor, rate=rate, training=is_training)
|
|
||||||
|
|
||||||
|
|
||||||
def adaptive_transformation(bn_normalized_dense, is_training, func_type="default"):
|
|
||||||
assert func_type in [
|
|
||||||
"default",
|
|
||||||
"tiny",
|
|
||||||
], f"fun_type can only be one of default and tiny, but get {func_type}"
|
|
||||||
|
|
||||||
gw_normalized_dense = tf.expand_dims(bn_normalized_dense, -1)
|
|
||||||
group_num = bn_normalized_dense.shape.as_list()[1]
|
|
||||||
|
|
||||||
if func_type == "default":
|
|
||||||
gw_normalized_dense = FastGroupWiseTrans(
|
|
||||||
group_num, 1, 8, name="groupwise_1", activation=tf.tanh, init_multiplier=8
|
|
||||||
)(gw_normalized_dense)
|
|
||||||
|
|
||||||
gw_normalized_dense = FastGroupWiseTrans(
|
|
||||||
group_num, 8, 4, name="groupwise_2", activation=tf.tanh, init_multiplier=8
|
|
||||||
)(gw_normalized_dense)
|
|
||||||
|
|
||||||
gw_normalized_dense = FastGroupWiseTrans(
|
|
||||||
group_num, 4, 1, name="groupwise_3", activation=tf.tanh, init_multiplier=8
|
|
||||||
)(gw_normalized_dense)
|
|
||||||
elif func_type == "tiny":
|
|
||||||
gw_normalized_dense = FastGroupWiseTrans(
|
|
||||||
group_num, 1, 2, name="groupwise_1", activation=tf.tanh, init_multiplier=8
|
|
||||||
)(gw_normalized_dense)
|
|
||||||
|
|
||||||
gw_normalized_dense = FastGroupWiseTrans(
|
|
||||||
group_num, 2, 1, name="groupwise_2", activation=tf.tanh, init_multiplier=8
|
|
||||||
)(gw_normalized_dense)
|
|
||||||
|
|
||||||
gw_normalized_dense = FastGroupWiseTrans(
|
|
||||||
group_num, 1, 1, name="groupwise_3", activation=tf.tanh, init_multiplier=8
|
|
||||||
)(gw_normalized_dense)
|
|
||||||
|
|
||||||
gw_normalized_dense = tf.squeeze(gw_normalized_dense, [-1])
|
|
||||||
bn_gw_normalized_dense = tf.layers.batch_normalization(
|
|
||||||
gw_normalized_dense,
|
|
||||||
training=is_training,
|
|
||||||
renorm_momentum=0.9999,
|
|
||||||
momentum=0.9999,
|
|
||||||
renorm=is_training,
|
|
||||||
trainable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return bn_gw_normalized_dense
|
|
||||||
|
|
||||||
|
|
||||||
class FastGroupWiseTrans(object):
|
|
||||||
"""
|
|
||||||
used to apply group-wise fully connected layers to the input.
|
|
||||||
it applies a tiny, unique MLP to each individual feature."""
|
|
||||||
|
|
||||||
def __init__(self, group_num, input_dim, out_dim, name, activation=None, init_multiplier=1):
|
|
||||||
self.group_num = group_num
|
|
||||||
self.input_dim = input_dim
|
|
||||||
self.out_dim = out_dim
|
|
||||||
self.activation = activation
|
|
||||||
self.init_multiplier = init_multiplier
|
|
||||||
|
|
||||||
self.w = tf.get_variable(
|
|
||||||
name + "_group_weight",
|
|
||||||
[1, group_num, input_dim, out_dim],
|
|
||||||
initializer=customized_glorot_uniform(
|
|
||||||
fan_in=input_dim * init_multiplier, fan_out=out_dim * init_multiplier
|
|
||||||
),
|
|
||||||
trainable=True,
|
|
||||||
)
|
|
||||||
self.b = tf.get_variable(
|
|
||||||
name + "_group_bias",
|
|
||||||
[1, group_num, out_dim],
|
|
||||||
initializer=tf.constant_initializer(0.0),
|
|
||||||
trainable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, input_tensor):
|
|
||||||
"""
|
|
||||||
input_tensor: batch_size x group_num x input_dim
|
|
||||||
output_tensor: batch_size x group_num x out_dim"""
|
|
||||||
input_tensor_expand = tf.expand_dims(input_tensor, axis=-1)
|
|
||||||
|
|
||||||
output_tensor = tf.add(
|
|
||||||
tf.reduce_sum(tf.multiply(input_tensor_expand, self.w), axis=-2, keepdims=False),
|
|
||||||
self.b,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.activation is not None:
|
|
||||||
output_tensor = self.activation(output_tensor)
|
|
||||||
return output_tensor
|
|
||||||
|
|
||||||
|
|
||||||
class GroupWiseTrans(object):
|
|
||||||
"""
|
|
||||||
Used to apply group fully connected layers to the input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, group_num, input_dim, out_dim, name, activation=None):
|
|
||||||
self.group_num = group_num
|
|
||||||
self.input_dim = input_dim
|
|
||||||
self.out_dim = out_dim
|
|
||||||
self.activation = activation
|
|
||||||
|
|
||||||
w_list, b_list = [], []
|
|
||||||
for idx in range(out_dim):
|
|
||||||
this_w = tf.get_variable(
|
|
||||||
name + f"_group_weight_{idx}",
|
|
||||||
[1, group_num, input_dim],
|
|
||||||
initializer=tf.keras.initializers.glorot_uniform(),
|
|
||||||
trainable=True,
|
|
||||||
)
|
|
||||||
this_b = tf.get_variable(
|
|
||||||
name + f"_group_bias_{idx}",
|
|
||||||
[1, group_num, 1],
|
|
||||||
initializer=tf.constant_initializer(0.0),
|
|
||||||
trainable=True,
|
|
||||||
)
|
|
||||||
w_list.append(this_w)
|
|
||||||
b_list.append(this_b)
|
|
||||||
self.w_list = w_list
|
|
||||||
self.b_list = b_list
|
|
||||||
|
|
||||||
def __call__(self, input_tensor):
|
|
||||||
"""
|
|
||||||
input_tensor: batch_size x group_num x input_dim
|
|
||||||
output_tensor: batch_size x group_num x out_dim
|
|
||||||
"""
|
|
||||||
out_tensor_list = []
|
|
||||||
for idx in range(self.out_dim):
|
|
||||||
this_res = (
|
|
||||||
tf.reduce_sum(input_tensor * self.w_list[idx], axis=-1, keepdims=True) + self.b_list[idx]
|
|
||||||
)
|
|
||||||
out_tensor_list.append(this_res)
|
|
||||||
output_tensor = tf.concat(out_tensor_list, axis=-1)
|
|
||||||
|
|
||||||
if self.activation is not None:
|
|
||||||
output_tensor = self.activation(output_tensor)
|
|
||||||
return output_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def add_scalar_summary(var, name, name_scope="hist_dense_feature/"):
|
|
||||||
with tf.name_scope("summaries/"):
|
|
||||||
with tf.name_scope(name_scope):
|
|
||||||
tf.summary.scalar(name, var)
|
|
||||||
|
|
||||||
|
|
||||||
def add_histogram_summary(var, name, name_scope="hist_dense_feature/"):
|
|
||||||
with tf.name_scope("summaries/"):
|
|
||||||
with tf.name_scope(name_scope):
|
|
||||||
tf.summary.histogram(name, tf.reshape(var, [-1]))
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_clip_by_value(sparse_tf, min_val, max_val):
|
|
||||||
new_vals = tf.clip_by_value(sparse_tf.values, min_val, max_val)
|
|
||||||
return tf.SparseTensor(sparse_tf.indices, new_vals, sparse_tf.dense_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def check_numerics_with_msg(tensor, message="", sparse_tensor=False):
|
|
||||||
if sparse_tensor:
|
|
||||||
values = tf.debugging.check_numerics(tensor.values, message=message)
|
|
||||||
return tf.SparseTensor(tensor.indices, values, tensor.dense_shape)
|
|
||||||
else:
|
|
||||||
return tf.debugging.check_numerics(tensor, message=message)
|
|
||||||
|
|
||||||
|
|
||||||
def pad_empty_sparse_tensor(tensor):
|
|
||||||
dummy_tensor = tf.SparseTensor(
|
|
||||||
indices=[[0, 0]],
|
|
||||||
values=[0.00001],
|
|
||||||
dense_shape=tensor.dense_shape,
|
|
||||||
)
|
|
||||||
result = tf.cond(
|
|
||||||
tf.equal(tf.size(tensor.values), 0),
|
|
||||||
lambda: dummy_tensor,
|
|
||||||
lambda: tensor,
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def filter_nans_and_infs(tensor, sparse_tensor=False):
|
|
||||||
if sparse_tensor:
|
|
||||||
sparse_values = tensor.values
|
|
||||||
filtered_val = tf.where(
|
|
||||||
tf.logical_or(tf.is_nan(sparse_values), tf.is_inf(sparse_values)),
|
|
||||||
tf.zeros_like(sparse_values),
|
|
||||||
sparse_values,
|
|
||||||
)
|
|
||||||
return tf.SparseTensor(tensor.indices, filtered_val, tensor.dense_shape)
|
|
||||||
else:
|
|
||||||
return tf.where(
|
|
||||||
tf.logical_or(tf.is_nan(tensor), tf.is_inf(tensor)), tf.zeros_like(tensor), tensor
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_disliked_mask(labels):
|
|
||||||
"""Generate a disliked mask where only samples with dislike labels are set to 1 otherwise set to 0.
|
|
||||||
Args:
|
|
||||||
labels: labels of training samples, which is a 2D tensor of shape batch_size x 3: [OONCs, engagements, dislikes]
|
|
||||||
Returns:
|
|
||||||
Boolean tensor of shape batch_size x 1: [dislikes]
|
|
||||||
"""
|
|
||||||
return tf.equal(tf.reshape(labels[:, 2], shape=[-1, 1]), 1)
|
|
BIN
pushservice/src/main/python/models/libs/warm_start_utils.docx
Normal file
BIN
pushservice/src/main/python/models/libs/warm_start_utils.docx
Normal file
Binary file not shown.
@ -1,309 +0,0 @@
|
|||||||
from collections import OrderedDict
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from os.path import join
|
|
||||||
|
|
||||||
from twitter.magicpony.common import file_access
|
|
||||||
import twml
|
|
||||||
|
|
||||||
from .model_utils import read_config
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from scipy import stats
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_type_to_tensors_to_change_axis():
|
|
||||||
model_type_to_tensors_to_change_axis = {
|
|
||||||
"magic_recs/model/batch_normalization/beta": ([0], "continuous"),
|
|
||||||
"magic_recs/model/batch_normalization/gamma": ([0], "continuous"),
|
|
||||||
"magic_recs/model/batch_normalization/moving_mean": ([0], "continuous"),
|
|
||||||
"magic_recs/model/batch_normalization/moving_stddev": ([0], "continuous"),
|
|
||||||
"magic_recs/model/batch_normalization/moving_variance": ([0], "continuous"),
|
|
||||||
"magic_recs/model/batch_normalization/renorm_mean": ([0], "continuous"),
|
|
||||||
"magic_recs/model/batch_normalization/renorm_stddev": ([0], "continuous"),
|
|
||||||
"magic_recs/model/logits/EngagementGivenOONC_logits/clem_net_1/block2_4/channel_wise_dense_4/kernel": (
|
|
||||||
[1],
|
|
||||||
"all",
|
|
||||||
),
|
|
||||||
"magic_recs/model/logits/OONC_logits/clem_net/block2/channel_wise_dense/kernel": ([1], "all"),
|
|
||||||
}
|
|
||||||
|
|
||||||
return model_type_to_tensors_to_change_axis
|
|
||||||
|
|
||||||
|
|
||||||
def mkdirp(dirname):
|
|
||||||
if not tf.io.gfile.exists(dirname):
|
|
||||||
tf.io.gfile.makedirs(dirname)
|
|
||||||
|
|
||||||
|
|
||||||
def rename_dir(dirname, dst):
|
|
||||||
file_access.hdfs.mv(dirname, dst)
|
|
||||||
|
|
||||||
|
|
||||||
def rmdir(dirname):
|
|
||||||
if tf.io.gfile.exists(dirname):
|
|
||||||
if tf.io.gfile.isdir(dirname):
|
|
||||||
tf.io.gfile.rmtree(dirname)
|
|
||||||
else:
|
|
||||||
tf.io.gfile.remove(dirname)
|
|
||||||
|
|
||||||
|
|
||||||
def get_var_dict(checkpoint_path):
|
|
||||||
checkpoint = tf.train.get_checkpoint_state(checkpoint_path)
|
|
||||||
var_dict = OrderedDict()
|
|
||||||
with tf.Session() as sess:
|
|
||||||
all_var_list = tf.train.list_variables(checkpoint_path)
|
|
||||||
for var_name, _ in all_var_list:
|
|
||||||
# Load the variable
|
|
||||||
var = tf.train.load_variable(checkpoint_path, var_name)
|
|
||||||
var_dict[var_name] = var
|
|
||||||
return var_dict
|
|
||||||
|
|
||||||
|
|
||||||
def get_continunous_mapping_from_feat_list(old_feature_list, new_feature_list):
|
|
||||||
"""
|
|
||||||
Get the indices of features shared by both lists: indices into the old feature list and the corresponding indices into the new feature list.
|
|
||||||
"""
|
|
||||||
new_var_ind, old_var_ind = [], []
|
|
||||||
for this_new_id, this_new_name in enumerate(new_feature_list):
|
|
||||||
if this_new_name in old_feature_list:
|
|
||||||
this_old_id = old_feature_list.index(this_new_name)
|
|
||||||
new_var_ind.append(this_new_id)
|
|
||||||
old_var_ind.append(this_old_id)
|
|
||||||
return np.asarray(old_var_ind), np.asarray(new_var_ind)
|
|
||||||
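# Illustrative usage (hypothetical feature names): features present in both lists are matched
# up, so old model weights can later be copied into the right slots of the new model.
old_features = ["favCount", "replyCount", "retweetCount"]
new_features = ["replyCount", "linkCount", "favCount"]
old_ind, new_ind = get_continunous_mapping_from_feat_list(old_features, new_features)
# old_ind == [1, 0] and new_ind == [0, 2]: "replyCount" and "favCount" overlap.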
|
|
||||||
|
|
||||||
def get_continuous_mapping_from_feat_dict(old_feature_dict, new_feature_dict):
|
|
||||||
"""
|
|
||||||
Get the index mappings of features shared by the old and new feature dicts, for the continuous group and for all groups combined.
|
|
||||||
"""
|
|
||||||
old_cont = old_feature_dict["continuous"]
|
|
||||||
old_bin = old_feature_dict["binary"]
|
|
||||||
|
|
||||||
new_cont = new_feature_dict["continuous"]
|
|
||||||
new_bin = new_feature_dict["binary"]
|
|
||||||
|
|
||||||
_dummy_sparse_feat = [f"sparse_feature_{_idx}" for _idx in range(100)]
|
|
||||||
|
|
||||||
cont_old_var_ind, cont_new_var_ind = get_continunous_mapping_from_feat_list(old_cont, new_cont)
|
|
||||||
|
|
||||||
all_old_var_ind, all_new_var_ind = get_continunous_mapping_from_feat_list(
|
|
||||||
old_cont + old_bin + _dummy_sparse_feat, new_cont + new_bin + _dummy_sparse_feat
|
|
||||||
)
|
|
||||||
|
|
||||||
_res = {
|
|
||||||
"continuous": (cont_old_var_ind, cont_new_var_ind),
|
|
||||||
"all": (all_old_var_ind, all_new_var_ind),
|
|
||||||
}
|
|
||||||
|
|
||||||
return _res
|
|
||||||
|
|
||||||
|
|
||||||
def warm_start_from_var_dict(
|
|
||||||
old_ckpt_path,
|
|
||||||
var_ind_dict,
|
|
||||||
output_dir,
|
|
||||||
new_len_var,
|
|
||||||
var_to_change_dict_fn=get_model_type_to_tensors_to_change_axis,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Parameters:
|
|
||||||
old_ckpt_path (str): path to the old checkpoint path
|
|
||||||
var_ind_dict ({str: (array of int, array of int)}): maps a variable type ("continuous" or "all")
to the index arrays of overlapping features in the new and old feature lists.
|
|
||||||
|
|
||||||
output_dir (str): dir that used to write modified checkpoint
|
|
||||||
new_len_var ({str:int}): number of feature in the new feature list.
|
|
||||||
var_to_change_dict_fn (callable): A function returning a dictionary of format {var_name: (dims_to_change, var_type)}
|
|
||||||
"""
|
|
||||||
old_var_dict = get_var_dict(old_ckpt_path)
|
|
||||||
|
|
||||||
ckpt_file_name = os.path.basename(old_ckpt_path)
|
|
||||||
mkdirp(output_dir)
|
|
||||||
output_path = join(output_dir, ckpt_file_name)
|
|
||||||
|
|
||||||
tensors_to_change = var_to_change_dict_fn()
|
|
||||||
tf.compat.v1.reset_default_graph()
|
|
||||||
|
|
||||||
with tf.Session() as sess:
|
|
||||||
var_name_shape_list = tf.train.list_variables(old_ckpt_path)
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
for var_name, var_shape in var_name_shape_list:
|
|
||||||
old_var = old_var_dict[var_name]
|
|
||||||
if var_name in tensors_to_change.keys():
|
|
||||||
_info_tuple = tensors_to_change[var_name]
|
|
||||||
dims_to_remove_from, var_type = _info_tuple
|
|
||||||
|
|
||||||
new_var_ind, old_var_ind = var_ind_dict[var_type]
|
|
||||||
|
|
||||||
this_shape = list(old_var.shape)
|
|
||||||
for this_dim in dims_to_remove_from:
|
|
||||||
this_shape[this_dim] = new_len_var[var_type]
|
|
||||||
|
|
||||||
stddev = np.std(old_var)
|
|
||||||
truncated_norm_generator = stats.truncnorm(-0.5, 0.5, loc=0, scale=stddev)
|
|
||||||
size = np.prod(this_shape)
|
|
||||||
new_var = truncated_norm_generator.rvs(size).reshape(this_shape)
|
|
||||||
new_var = new_var.astype(old_var.dtype)
|
|
||||||
|
|
||||||
new_var = copy_feat_based_on_mapping(
|
|
||||||
new_var, old_var, dims_to_remove_from, new_var_ind, old_var_ind
|
|
||||||
)
|
|
||||||
count = count + 1
|
|
||||||
else:
|
|
||||||
new_var = old_var
|
|
||||||
var = tf.Variable(new_var, name=var_name)
|
|
||||||
assert count == len(tensors_to_change.keys()), "not all variables are exchanged.\n"
|
|
||||||
saver = tf.train.Saver()
|
|
||||||
sess.run(tf.global_variables_initializer())
|
|
||||||
saver.save(sess, output_path)
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def copy_feat_based_on_mapping(new_array, old_array, dims_to_remove_from, new_var_ind, old_var_ind):
|
|
||||||
if dims_to_remove_from == [0, 1]:
|
|
||||||
for this_new_ind, this_old_ind in zip(new_var_ind, old_var_ind):
|
|
||||||
new_array[this_new_ind, new_var_ind] = old_array[this_old_ind, old_var_ind]
|
|
||||||
elif dims_to_remove_from == [0]:
|
|
||||||
new_array[new_var_ind] = old_array[old_var_ind]
|
|
||||||
elif dims_to_remove_from == [1]:
|
|
||||||
new_array[:, new_var_ind] = old_array[:, old_var_ind]
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"undefined dims_to_remove_from pattern: ({dims_to_remove_from})")
|
|
||||||
return new_array
|
|
||||||
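# Minimal sketch (hypothetical shapes) for the dims_to_remove_from == [1] case: columns of an
# old kernel are copied into the matching columns of the freshly initialized new kernel.
old_kernel = np.arange(6, dtype=np.float32).reshape(2, 3)      # 3 old features
new_kernel = np.zeros((2, 4), dtype=np.float32)                # 4 new features
new_kernel = copy_feat_based_on_mapping(
    new_kernel, old_kernel, dims_to_remove_from=[1],
    new_var_ind=np.array([0, 2]), old_var_ind=np.array([1, 0])
)
# new_kernel[:, 0] now holds old_kernel[:, 1]; new_kernel[:, 2] holds old_kernel[:, 0].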
|
|
||||||
|
|
||||||
def read_file(filename, decode=False):
|
|
||||||
"""
|
|
||||||
Reads contents from a file and optionally decodes it.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
filename:
|
|
||||||
path to file where the contents will be loaded from.
|
|
||||||
Accepts HDFS and local paths.
|
|
||||||
decode:
|
|
||||||
False or 'json'. When decode='json', contents is decoded
|
|
||||||
with json.loads. When False, contents is returned as is.
|
|
||||||
"""
|
|
||||||
graph = tf.Graph()
|
|
||||||
with graph.as_default():
|
|
||||||
read = tf.read_file(filename)
|
|
||||||
|
|
||||||
with tf.Session(graph=graph) as sess:
|
|
||||||
contents = sess.run(read)
|
|
||||||
if not isinstance(contents, str):
|
|
||||||
contents = contents.decode()
|
|
||||||
|
|
||||||
if decode == "json":
|
|
||||||
contents = json.loads(contents)
|
|
||||||
|
|
||||||
return contents
|
|
||||||
|
|
||||||
|
|
||||||
def read_feat_list_from_disk(file_path):
|
|
||||||
return read_file(file_path, decode="json")
|
|
||||||
|
|
||||||
|
|
||||||
def get_feature_list_for_light_ranking(feature_list_path, data_spec_path):
|
|
||||||
feature_list = read_config(feature_list_path).items()
|
|
||||||
string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
|
||||||
|
|
||||||
feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
|
||||||
data_spec_path=data_spec_path
|
|
||||||
)
|
|
||||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
|
||||||
feature_regexes=string_feat_list,
|
|
||||||
group_name="continuous",
|
|
||||||
default_value=-1,
|
|
||||||
type_filter=["CONTINUOUS"],
|
|
||||||
)
|
|
||||||
feature_config = feature_config_builder.build()
|
|
||||||
feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[
|
|
||||||
"CONTINUOUS"
|
|
||||||
]
|
|
||||||
return feature_list
|
|
||||||
|
|
||||||
|
|
||||||
def get_feature_list_for_heavy_ranking(feature_list_path, data_spec_path):
|
|
||||||
feature_list = read_config(feature_list_path).items()
|
|
||||||
string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
|
||||||
|
|
||||||
feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
|
||||||
data_spec_path=data_spec_path
|
|
||||||
)
|
|
||||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
|
||||||
feature_regexes=string_feat_list,
|
|
||||||
group_name="continuous",
|
|
||||||
default_value=-1,
|
|
||||||
type_filter=["CONTINUOUS"],
|
|
||||||
)
|
|
||||||
|
|
||||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
|
||||||
feature_regexes=string_feat_list,
|
|
||||||
group_name="binary",
|
|
||||||
default_value=False,
|
|
||||||
type_filter=["BINARY"],
|
|
||||||
)
|
|
||||||
|
|
||||||
feature_config_builder = feature_config_builder.build()
|
|
||||||
|
|
||||||
continuous_feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[
|
|
||||||
"CONTINUOUS"
|
|
||||||
]
|
|
||||||
|
|
||||||
binary_feature_list = feature_config_builder._feature_group_extraction_configs[1].feature_map[
|
|
||||||
"BINARY"
|
|
||||||
]
|
|
||||||
return {"continuous": continuous_feature_list, "binary": binary_feature_list}
|
|
||||||
|
|
||||||
|
|
||||||
def warm_start_checkpoint(
|
|
||||||
old_best_ckpt_folder,
|
|
||||||
old_feature_list_path,
|
|
||||||
feature_allow_list_path,
|
|
||||||
data_spec_path,
|
|
||||||
output_ckpt_folder,
|
|
||||||
*args,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Reads the old checkpoint and the old feature list, and creates a new ckpt warm-started from the old ckpt using the new features.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
old_best_ckpt_folder:
|
|
||||||
path to the best_checkpoint_folder for old model
|
|
||||||
old_feature_list_path:
|
|
||||||
path to the json file that stores the list of continuous features used in old models.
|
|
||||||
feature_allow_list_path:
|
|
||||||
yaml file that contain the feature allow list.
|
|
||||||
data_spec_path:
|
|
||||||
path to the data_spec file
|
|
||||||
output_ckpt_folder:
|
|
||||||
folder that contains the modified ckpt.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
path to the modified ckpt."""
|
|
||||||
old_ckpt_path = tf.train.latest_checkpoint(old_best_ckpt_folder, latest_filename=None)
|
|
||||||
|
|
||||||
new_feature_dict = get_feature_list(feature_allow_list_path, data_spec_path)
|
|
||||||
old_feature_dict = read_feat_list_from_disk(old_feature_list_path)
|
|
||||||
|
|
||||||
var_ind_dict = get_continuous_mapping_from_feat_dict(new_feature_dict, old_feature_dict)
|
|
||||||
|
|
||||||
new_len_var = {
|
|
||||||
"continuous": len(new_feature_dict["continuous"]),
|
|
||||||
"all": len(new_feature_dict["continuous"] + new_feature_dict["binary"]) + 100,
|
|
||||||
}
|
|
||||||
|
|
||||||
warm_started_ckpt_path = warm_start_from_var_dict(
|
|
||||||
old_ckpt_path,
|
|
||||||
var_ind_dict,
|
|
||||||
output_dir=output_ckpt_folder,
|
|
||||||
new_len_var=new_len_var,
|
|
||||||
)
|
|
||||||
|
|
||||||
return warm_started_ckpt_path
|
|
@ -1,69 +0,0 @@
|
|||||||
#":mlwf_libs",
|
|
||||||
|
|
||||||
python37_binary(
|
|
||||||
name = "eval_model",
|
|
||||||
source = "eval_model.py",
|
|
||||||
dependencies = [
|
|
||||||
":libs",
|
|
||||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:eval_model",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python37_binary(
|
|
||||||
name = "train_model",
|
|
||||||
source = "deep_norm.py",
|
|
||||||
dependencies = [
|
|
||||||
":libs",
|
|
||||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:train_model",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python37_binary(
|
|
||||||
name = "train_model_local",
|
|
||||||
source = "deep_norm.py",
|
|
||||||
dependencies = [
|
|
||||||
":libs",
|
|
||||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:train_model_local",
|
|
||||||
"twml",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python37_binary(
|
|
||||||
name = "eval_model_local",
|
|
||||||
source = "eval_model.py",
|
|
||||||
dependencies = [
|
|
||||||
":libs",
|
|
||||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:eval_model_local",
|
|
||||||
"twml",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python37_binary(
|
|
||||||
name = "mlwf_model",
|
|
||||||
source = "deep_norm.py",
|
|
||||||
dependencies = [
|
|
||||||
":mlwf_libs",
|
|
||||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:mlwf_model",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python3_library(
|
|
||||||
name = "libs",
|
|
||||||
sources = ["**/*.py"],
|
|
||||||
tags = ["no-mypy"],
|
|
||||||
dependencies = [
|
|
||||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
|
||||||
"src/python/twitter/deepbird/util/data",
|
|
||||||
"twml:twml-nodeps",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python3_library(
|
|
||||||
name = "mlwf_libs",
|
|
||||||
sources = ["**/*.py"],
|
|
||||||
tags = ["no-mypy"],
|
|
||||||
dependencies = [
|
|
||||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
|
||||||
"twml",
|
|
||||||
],
|
|
||||||
)
|
|
BIN
pushservice/src/main/python/models/light_ranking/BUILD.docx
Normal file
BIN
pushservice/src/main/python/models/light_ranking/BUILD.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/light_ranking/README.docx
Normal file
BIN
pushservice/src/main/python/models/light_ranking/README.docx
Normal file
Binary file not shown.
@ -1,14 +0,0 @@
|
|||||||
# Notification Light Ranker Model
|
|
||||||
|
|
||||||
## Model Context
|
|
||||||
There are 4 major components of the Twitter notifications recommendation system: 1) candidate generation 2) light ranking 3) heavy ranking & 4) quality control. This notification light ranker model bridges candidate generation and heavy ranking by pre-selecting highly-relevant candidates from the initial huge candidate pool. It’s a lightweight model to reduce system cost during heavy ranking without hurting user experience.
|
|
||||||
|
|
||||||
## Directory Structure
|
|
||||||
- BUILD: this file defines python library dependencies
|
|
||||||
- model_pools_mlp.py: this file defines tensorflow model architecture for the notification light ranker model
|
|
||||||
- deep_norm.py: this file contains 1) how to build the tensorflow graph with specified model architecture, loss function and training configuration. 2) how to set up the overall model training & evaluation pipeline
|
|
||||||
- eval_model.py: the main python entry file to set up the overall model evaluation pipeline
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
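As a rough orientation for how the pieces listed above fit together, the sketch below condenses the training entrypoint from the deleted deep_norm.py shown later in this diff. It is a minimal sketch, not a drop-in script: the helper names (get_arg_parser_light_ranking, get_feature_config_light_ranking, read_config, get_metric_fn, DataRecordTrainer, build_graph) are taken from those sources and assumed to keep the signatures used there, and the package-relative imports only resolve inside the original light_ranking package.

```python
# Minimal sketch of the light-ranker training flow, condensed from deep_norm.py.
# All helpers below come from the deleted sources in this diff and are assumed
# to keep the signatures shown there; this illustrates the wiring, nothing more.
from twml.trainers import DataRecordTrainer

from ..libs.get_feat_config import get_feature_config_light_ranking
from ..libs.metric_fn_utils import get_metric_fn
from ..libs.model_args import get_arg_parser_light_ranking
from ..libs.model_utils import read_config
from .deep_norm import build_graph


def train_light_ranker(args=None):
    # Parse light-ranking hyper-parameters, data paths, and the task name.
    opt = get_arg_parser_light_ranking().parse_args(args)

    # Build the feature config from the provided feature list and data spec.
    feature_config = get_feature_config_light_ranking(
        data_spec_path=opt.data_spec,
        feature_list_provided=read_config(opt.feature_list).items(),
        opt=opt,
        add_gbdt=opt.use_gbdt_features,
        run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
    )

    # The trainer owns the estimator, checkpointing, and metric reporting.
    trainer = DataRecordTrainer(
        name=opt.model_trainer_name,
        params=opt,
        build_graph_fn=build_graph,
        save_dir=opt.save_dir,
        run_config=None,
        feature_config=feature_config,
        metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
    )

    # Train with early stopping on the unweighted RCE of the configured task,
    # exactly as _main() in deep_norm.py does before exporting the model.
    trainer.learn(
        early_stop_minimize=False,
        early_stop_metric="rce_unweighted_" + opt.task_name,
        early_stop_patience=opt.early_stop_patience,
        early_stop_tolerance=opt.early_stop_tolerance,
        eval_input_fn=trainer.get_eval_input_fn(repeat=False, shuffle=False),
        train_input_fn=trainer.get_train_input_fn(shuffle=True),
    )
    return trainer
```

After training, deep_norm.py exports the model with twml.contrib.export.export_fn.export_all_models and writes out the continuous feature list used at serving time; those steps are omitted from the sketch for brevity.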
BIN
pushservice/src/main/python/models/light_ranking/__init__.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/light_ranking/deep_norm.docx
Normal file
Binary file not shown.
@ -1,226 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from functools import partial
|
|
||||||
import os
|
|
||||||
|
|
||||||
from twitter.cortex.ml.embeddings.common.helpers import decode_str_or_unicode
|
|
||||||
import twml
|
|
||||||
from twml.trainers import DataRecordTrainer
|
|
||||||
|
|
||||||
from ..libs.get_feat_config import get_feature_config_light_ranking, LABELS_LR
|
|
||||||
from ..libs.graph_utils import get_trainable_variables
|
|
||||||
from ..libs.group_metrics import (
|
|
||||||
run_group_metrics_light_ranking,
|
|
||||||
run_group_metrics_light_ranking_in_bq,
|
|
||||||
)
|
|
||||||
from ..libs.metric_fn_utils import get_metric_fn
|
|
||||||
from ..libs.model_args import get_arg_parser_light_ranking
|
|
||||||
from ..libs.model_utils import read_config
|
|
||||||
from ..libs.warm_start_utils import get_feature_list_for_light_ranking
|
|
||||||
from .model_pools_mlp import light_ranking_mlp_ngbdt
|
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
from tensorflow.compat.v1 import logging
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
|
|
||||||
def build_graph(
|
|
||||||
features, label, mode, params, config=None, run_light_ranking_group_metrics_in_bq=False
|
|
||||||
):
|
|
||||||
is_training = mode == tf.estimator.ModeKeys.TRAIN
|
|
||||||
this_model_func = light_ranking_mlp_ngbdt
|
|
||||||
model_output = this_model_func(features, is_training, params, label)
|
|
||||||
|
|
||||||
logits = model_output["output"]
|
|
||||||
graph_output = {}
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# define graph output dict
|
|
||||||
# --------------------------------------------------------
|
|
||||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
|
||||||
loss = None
|
|
||||||
output_label = "prediction"
|
|
||||||
if params.task_name in LABELS_LR:
|
|
||||||
output = tf.nn.sigmoid(logits)
|
|
||||||
output = tf.clip_by_value(output, 0, 1)
|
|
||||||
|
|
||||||
if run_light_ranking_group_metrics_in_bq:
|
|
||||||
graph_output["trace_id"] = features["meta.trace_id"]
|
|
||||||
graph_output["target"] = features["meta.ranking.weighted_oonc_model_score"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid Task Name !")
|
|
||||||
|
|
||||||
else:
|
|
||||||
output_label = "output"
|
|
||||||
weights = tf.cast(features["weights"], dtype=tf.float32, name="RecordWeights")
|
|
||||||
|
|
||||||
if params.task_name in LABELS_LR:
|
|
||||||
if params.use_record_weight:
|
|
||||||
weights = tf.clip_by_value(
|
|
||||||
1.0 / (1.0 + weights + params.smooth_weight), params.min_record_weight, 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = tf.reduce_sum(
|
|
||||||
tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits) * weights
|
|
||||||
) / (tf.reduce_sum(weights))
|
|
||||||
else:
|
|
||||||
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits))
|
|
||||||
output = tf.nn.sigmoid(logits)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid Task Name !")
|
|
||||||
|
|
||||||
train_op = None
|
|
||||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# get train_op
|
|
||||||
# --------------------------------------------------------
|
|
||||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=params.learning_rate)
|
|
||||||
update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
|
||||||
variables = get_trainable_variables(
|
|
||||||
all_trainable_variables=tf.trainable_variables(), trainable_regexes=params.trainable_regexes
|
|
||||||
)
|
|
||||||
with tf.control_dependencies(update_ops):
|
|
||||||
train_op = twml.optimizers.optimize_loss(
|
|
||||||
loss=loss,
|
|
||||||
variables=variables,
|
|
||||||
global_step=tf.train.get_global_step(),
|
|
||||||
optimizer=optimizer,
|
|
||||||
learning_rate=params.learning_rate,
|
|
||||||
learning_rate_decay_fn=twml.learning_rate_decay.get_learning_rate_decay_fn(params),
|
|
||||||
)
|
|
||||||
|
|
||||||
graph_output[output_label] = output
|
|
||||||
graph_output["loss"] = loss
|
|
||||||
graph_output["train_op"] = train_op
|
|
||||||
return graph_output
|
|
||||||
|
|
||||||
|
|
||||||
def get_params(args=None):
|
|
||||||
parser = get_arg_parser_light_ranking()
|
|
||||||
if args is None:
|
|
||||||
return parser.parse_args()
|
|
||||||
else:
|
|
||||||
return parser.parse_args(args)
|
|
||||||
|
|
||||||
|
|
||||||
def _main():
|
|
||||||
opt = get_params()
|
|
||||||
logging.info("parse is: ")
|
|
||||||
logging.info(opt)
|
|
||||||
|
|
||||||
feature_list = read_config(opt.feature_list).items()
|
|
||||||
feature_config = get_feature_config_light_ranking(
|
|
||||||
data_spec_path=opt.data_spec,
|
|
||||||
feature_list_provided=feature_list,
|
|
||||||
opt=opt,
|
|
||||||
add_gbdt=opt.use_gbdt_features,
|
|
||||||
run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
|
|
||||||
)
|
|
||||||
feature_list_path = opt.feature_list
|
|
||||||
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# Create Trainer
|
|
||||||
# --------------------------------------------------------
|
|
||||||
trainer = DataRecordTrainer(
|
|
||||||
name=opt.model_trainer_name,
|
|
||||||
params=opt,
|
|
||||||
build_graph_fn=build_graph,
|
|
||||||
save_dir=opt.save_dir,
|
|
||||||
run_config=None,
|
|
||||||
feature_config=feature_config,
|
|
||||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
|
||||||
)
|
|
||||||
if opt.directly_export_best:
|
|
||||||
logging.info("Directly exporting the model without training")
|
|
||||||
else:
|
|
||||||
# ----------------------------------------------------
|
|
||||||
# Model Training & Evaluation
|
|
||||||
# ----------------------------------------------------
|
|
||||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
|
||||||
train_input_fn = trainer.get_train_input_fn(shuffle=True)
|
|
||||||
|
|
||||||
if opt.distributed or opt.num_workers is not None:
|
|
||||||
learn = trainer.train_and_evaluate
|
|
||||||
else:
|
|
||||||
learn = trainer.learn
|
|
||||||
logging.info("Training...")
|
|
||||||
start = datetime.now()
|
|
||||||
|
|
||||||
early_stop_metric = "rce_unweighted_" + opt.task_name
|
|
||||||
learn(
|
|
||||||
early_stop_minimize=False,
|
|
||||||
early_stop_metric=early_stop_metric,
|
|
||||||
early_stop_patience=opt.early_stop_patience,
|
|
||||||
early_stop_tolerance=opt.early_stop_tolerance,
|
|
||||||
eval_input_fn=eval_input_fn,
|
|
||||||
train_input_fn=train_input_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
end = datetime.now()
|
|
||||||
logging.info("Training time: " + str(end - start))
|
|
||||||
|
|
||||||
logging.info("Exporting the models...")
|
|
||||||
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# Do the model exporting
|
|
||||||
# --------------------------------------------------------
|
|
||||||
start = datetime.now()
|
|
||||||
if not opt.export_dir:
|
|
||||||
opt.export_dir = os.path.join(opt.save_dir, "exported_models")
|
|
||||||
|
|
||||||
raw_model_path = twml.contrib.export.export_fn.export_all_models(
|
|
||||||
trainer=trainer,
|
|
||||||
export_dir=opt.export_dir,
|
|
||||||
parse_fn=feature_config.get_parse_fn(),
|
|
||||||
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
|
|
||||||
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
|
|
||||||
)
|
|
||||||
export_model_dir = decode_str_or_unicode(raw_model_path)
|
|
||||||
|
|
||||||
logging.info("Model export time: " + str(datetime.now() - start))
|
|
||||||
logging.info("The saved model directory is: " + opt.save_dir)
|
|
||||||
|
|
||||||
tf.logging.info("getting default continuous_feature_list")
|
|
||||||
continuous_feature_list = get_feature_list_for_light_ranking(feature_list_path, opt.data_spec)
|
|
||||||
continous_feature_list_save_path = os.path.join(opt.save_dir, "continuous_feature_list.json")
|
|
||||||
twml.util.write_file(continous_feature_list_save_path, continuous_feature_list, encode="json")
|
|
||||||
tf.logging.info(f"Finish writting files to {continous_feature_list_save_path}")
|
|
||||||
|
|
||||||
if opt.run_light_ranking_group_metrics:
|
|
||||||
# --------------------------------------------
|
|
||||||
# Run Light Ranking Group Metrics
|
|
||||||
# --------------------------------------------
|
|
||||||
run_group_metrics_light_ranking(
|
|
||||||
trainer=trainer,
|
|
||||||
data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
|
|
||||||
model_path=export_model_dir,
|
|
||||||
parse_fn=feature_config.get_parse_fn(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if opt.run_light_ranking_group_metrics_in_bq:
|
|
||||||
# ----------------------------------------------------------------------------------------
|
|
||||||
# Get Light/Heavy Ranker Predictions for Light Ranking Group Metrics in BigQuery
|
|
||||||
# ----------------------------------------------------------------------------------------
|
|
||||||
trainer_pred = DataRecordTrainer(
|
|
||||||
name=opt.model_trainer_name,
|
|
||||||
params=opt,
|
|
||||||
build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
|
|
||||||
save_dir=opt.save_dir + "/tmp/",
|
|
||||||
run_config=None,
|
|
||||||
feature_config=feature_config,
|
|
||||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
|
||||||
)
|
|
||||||
checkpoint_folder = os.path.join(opt.save_dir, "best_checkpoint")
|
|
||||||
checkpoint = tf.train.latest_checkpoint(checkpoint_folder, latest_filename=None)
|
|
||||||
tf.logging.info("\n\nPrediction from Checkpoint: {:}.\n\n".format(checkpoint))
|
|
||||||
run_group_metrics_light_ranking_in_bq(
|
|
||||||
trainer=trainer_pred, params=opt, checkpoint_path=checkpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
tf.logging.info("Done Training & Prediction.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
_main()
|
|
BIN
pushservice/src/main/python/models/light_ranking/eval_model.docx
Normal file
Binary file not shown.
@ -1,89 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from functools import partial
|
|
||||||
import os
|
|
||||||
|
|
||||||
from ..libs.group_metrics import (
|
|
||||||
run_group_metrics_light_ranking,
|
|
||||||
run_group_metrics_light_ranking_in_bq,
|
|
||||||
)
|
|
||||||
from ..libs.metric_fn_utils import get_metric_fn
|
|
||||||
from ..libs.model_args import get_arg_parser_light_ranking
|
|
||||||
from ..libs.model_utils import read_config
|
|
||||||
from .deep_norm import build_graph, DataRecordTrainer, get_config_func, logging
|
|
||||||
|
|
||||||
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = get_arg_parser_light_ranking()
|
|
||||||
parser.add_argument(
|
|
||||||
"--eval_checkpoint",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="Which checkpoint to use for evaluation",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--saved_model_path",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="Path to saved model for evaluation",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--run_binary_metrics",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Whether to compute the basic binary metrics for Light Ranking.",
|
|
||||||
)
|
|
||||||
|
|
||||||
opt = parser.parse_args()
|
|
||||||
logging.info("parse is: ")
|
|
||||||
logging.info(opt)
|
|
||||||
|
|
||||||
feature_list = read_config(opt.feature_list).items()
|
|
||||||
feature_config = get_config_func(opt.feat_config_type)(
|
|
||||||
data_spec_path=opt.data_spec,
|
|
||||||
feature_list_provided=feature_list,
|
|
||||||
opt=opt,
|
|
||||||
add_gbdt=opt.use_gbdt_features,
|
|
||||||
run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
|
|
||||||
)
|
|
||||||
|
|
||||||
# -----------------------------------------------
|
|
||||||
# Create Trainer
|
|
||||||
# -----------------------------------------------
|
|
||||||
trainer = DataRecordTrainer(
|
|
||||||
name=opt.model_trainer_name,
|
|
||||||
params=opt,
|
|
||||||
build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
|
|
||||||
save_dir=opt.save_dir,
|
|
||||||
run_config=None,
|
|
||||||
feature_config=feature_config,
|
|
||||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# -----------------------------------------------
|
|
||||||
# Model Evaluation
|
|
||||||
# -----------------------------------------------
|
|
||||||
logging.info("Evaluating...")
|
|
||||||
start = datetime.now()
|
|
||||||
|
|
||||||
if opt.run_binary_metrics:
|
|
||||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
|
||||||
eval_steps = None if (opt.eval_steps is not None and opt.eval_steps < 0) else opt.eval_steps
|
|
||||||
trainer.estimator.evaluate(eval_input_fn, steps=eval_steps, checkpoint_path=opt.eval_checkpoint)
|
|
||||||
|
|
||||||
if opt.run_light_ranking_group_metrics_in_bq:
|
|
||||||
run_group_metrics_light_ranking_in_bq(
|
|
||||||
trainer=trainer, params=opt, checkpoint_path=opt.eval_checkpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
if opt.run_light_ranking_group_metrics:
|
|
||||||
run_group_metrics_light_ranking(
|
|
||||||
trainer=trainer,
|
|
||||||
data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
|
|
||||||
model_path=opt.saved_model_path,
|
|
||||||
parse_fn=feature_config.get_parse_fn(),
|
|
||||||
)
|
|
||||||
|
|
||||||
end = datetime.now()
|
|
||||||
logging.info("Evaluating time: " + str(end - start))
|
|
Binary file not shown.
@ -1,187 +0,0 @@
|
|||||||
import warnings
|
|
||||||
|
|
||||||
from twml.contrib.layers import ZscoreNormalization
|
|
||||||
|
|
||||||
from ...libs.customized_full_sparse import FullSparse
|
|
||||||
from ...libs.get_feat_config import FEAT_CONFIG_DEFAULT_VAL as MISSING_VALUE_MARKER
|
|
||||||
from ...libs.model_utils import (
|
|
||||||
_sparse_feature_fixup,
|
|
||||||
adaptive_transformation,
|
|
||||||
filter_nans_and_infs,
|
|
||||||
get_dense_out,
|
|
||||||
tensor_dropout,
|
|
||||||
)
|
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
# checkstyle: noqa
|
|
||||||
|
|
||||||
def light_ranking_mlp_ngbdt(features, is_training, params, label=None):
|
|
||||||
return deepnorm_light_ranking(
|
|
||||||
features,
|
|
||||||
is_training,
|
|
||||||
params,
|
|
||||||
label=label,
|
|
||||||
decay=params.momentum,
|
|
||||||
dense_emb_size=params.dense_embedding_size,
|
|
||||||
base_activation=tf.keras.layers.LeakyReLU(),
|
|
||||||
input_dropout_rate=params.dropout,
|
|
||||||
use_gbdt=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def deepnorm_light_ranking(
|
|
||||||
features,
|
|
||||||
is_training,
|
|
||||||
params,
|
|
||||||
label=None,
|
|
||||||
decay=0.99999,
|
|
||||||
dense_emb_size=128,
|
|
||||||
base_activation=None,
|
|
||||||
input_dropout_rate=None,
|
|
||||||
input_dense_type="self_atten_dense",
|
|
||||||
emb_dense_type="self_atten_dense",
|
|
||||||
mlp_dense_type="self_atten_dense",
|
|
||||||
use_gbdt=False,
|
|
||||||
):
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# Initial Parameter Checking
|
|
||||||
# --------------------------------------------------------
|
|
||||||
if base_activation is None:
|
|
||||||
base_activation = tf.keras.layers.LeakyReLU()
|
|
||||||
|
|
||||||
if label is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"Label is unused in deepnorm_gbdt. Stop using this argument.",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
with tf.variable_scope("helper_layers"):
|
|
||||||
full_sparse_layer = FullSparse(
|
|
||||||
output_size=params.sparse_embedding_size,
|
|
||||||
activation=base_activation,
|
|
||||||
use_sparse_grads=is_training,
|
|
||||||
use_binary_values=False,
|
|
||||||
dtype=tf.float32,
|
|
||||||
)
|
|
||||||
input_normalizing_layer = ZscoreNormalization(decay=decay, name="input_normalizing_layer")
|
|
||||||
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# Feature Selection & Embedding
|
|
||||||
# --------------------------------------------------------
|
|
||||||
if use_gbdt:
|
|
||||||
sparse_gbdt_features = _sparse_feature_fixup(features["gbdt_sparse"], params.input_size_bits)
|
|
||||||
if input_dropout_rate is not None:
|
|
||||||
sparse_gbdt_features = tensor_dropout(
|
|
||||||
sparse_gbdt_features, input_dropout_rate, is_training, sparse_tensor=True
|
|
||||||
)
|
|
||||||
|
|
||||||
total_embed = full_sparse_layer(sparse_gbdt_features, use_binary_values=True)
|
|
||||||
|
|
||||||
if (input_dropout_rate is not None) and is_training:
|
|
||||||
total_embed = total_embed / (1 - input_dropout_rate)
|
|
||||||
|
|
||||||
else:
|
|
||||||
with tf.variable_scope("dense_branch"):
|
|
||||||
dense_continuous_features = filter_nans_and_infs(features["continuous"])
|
|
||||||
|
|
||||||
if params.use_missing_sub_branch:
|
|
||||||
is_missing = tf.equal(dense_continuous_features, MISSING_VALUE_MARKER)
|
|
||||||
continuous_features_filled = tf.where(
|
|
||||||
is_missing,
|
|
||||||
tf.zeros_like(dense_continuous_features),
|
|
||||||
dense_continuous_features,
|
|
||||||
)
|
|
||||||
normalized_features = input_normalizing_layer(
|
|
||||||
continuous_features_filled, is_training, tf.math.logical_not(is_missing)
|
|
||||||
)
|
|
||||||
|
|
||||||
with tf.variable_scope("missing_sub_branch"):
|
|
||||||
missing_feature_embed = get_dense_out(
|
|
||||||
tf.cast(is_missing, tf.float32),
|
|
||||||
dense_emb_size,
|
|
||||||
activation=base_activation,
|
|
||||||
dense_type=input_dense_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
continuous_features_filled = dense_continuous_features
|
|
||||||
normalized_features = input_normalizing_layer(continuous_features_filled, is_training)
|
|
||||||
|
|
||||||
with tf.variable_scope("continuous_sub_branch"):
|
|
||||||
normalized_features = adaptive_transformation(
|
|
||||||
normalized_features, is_training, func_type="tiny"
|
|
||||||
)
|
|
||||||
|
|
||||||
if input_dropout_rate is not None:
|
|
||||||
normalized_features = tensor_dropout(
|
|
||||||
normalized_features,
|
|
||||||
input_dropout_rate,
|
|
||||||
is_training,
|
|
||||||
sparse_tensor=False,
|
|
||||||
)
|
|
||||||
filled_feature_embed = get_dense_out(
|
|
||||||
normalized_features,
|
|
||||||
dense_emb_size,
|
|
||||||
activation=base_activation,
|
|
||||||
dense_type=input_dense_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.use_missing_sub_branch:
|
|
||||||
dense_embed = tf.concat(
|
|
||||||
[filled_feature_embed, missing_feature_embed], axis=1, name="merge_dense_emb"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
dense_embed = filled_feature_embed
|
|
||||||
|
|
||||||
with tf.variable_scope("sparse_branch"):
|
|
||||||
sparse_discrete_features = _sparse_feature_fixup(
|
|
||||||
features["sparse_no_continuous"], params.input_size_bits
|
|
||||||
)
|
|
||||||
if input_dropout_rate is not None:
|
|
||||||
sparse_discrete_features = tensor_dropout(
|
|
||||||
sparse_discrete_features, input_dropout_rate, is_training, sparse_tensor=True
|
|
||||||
)
|
|
||||||
|
|
||||||
discrete_features_embed = full_sparse_layer(sparse_discrete_features, use_binary_values=True)
|
|
||||||
|
|
||||||
if (input_dropout_rate is not None) and is_training:
|
|
||||||
discrete_features_embed = discrete_features_embed / (1 - input_dropout_rate)
|
|
||||||
|
|
||||||
total_embed = tf.concat(
|
|
||||||
[dense_embed, discrete_features_embed],
|
|
||||||
axis=1,
|
|
||||||
name="total_embed",
|
|
||||||
)
|
|
||||||
|
|
||||||
total_embed = tf.layers.batch_normalization(
|
|
||||||
total_embed,
|
|
||||||
training=is_training,
|
|
||||||
renorm_momentum=decay,
|
|
||||||
momentum=decay,
|
|
||||||
renorm=is_training,
|
|
||||||
trainable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# MLP Layers
|
|
||||||
# --------------------------------------------------------
|
|
||||||
with tf.variable_scope("MLP_branch"):
|
|
||||||
|
|
||||||
assert params.num_mlp_layers >= 0
|
|
||||||
embed_list = [total_embed] + [None for _ in range(params.num_mlp_layers)]
|
|
||||||
dense_types = [emb_dense_type] + [mlp_dense_type for _ in range(params.num_mlp_layers - 1)]
|
|
||||||
|
|
||||||
for xl in range(1, params.num_mlp_layers + 1):
|
|
||||||
neurons = params.mlp_neuron_scale ** (params.num_mlp_layers + 1 - xl)
|
|
||||||
embed_list[xl] = get_dense_out(
|
|
||||||
embed_list[xl - 1], neurons, activation=base_activation, dense_type=dense_types[xl - 1]
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.task_name in ["Sent", "HeavyRankPosition", "HeavyRankProbability"]:
|
|
||||||
logits = get_dense_out(embed_list[-1], 1, activation=None, dense_type=mlp_dense_type)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid Task Name !")
|
|
||||||
|
|
||||||
output_dict = {"output": logits}
|
|
||||||
return output_dict
|
|
@ -1,337 +0,0 @@
|
|||||||
scala_library(
|
|
||||||
sources = ["**/*.scala"],
|
|
||||||
compiler_option_sets = ["fatal_warnings"],
|
|
||||||
strict_deps = True,
|
|
||||||
tags = [
|
|
||||||
"bazel-compatible",
|
|
||||||
],
|
|
||||||
dependencies = [
|
|
||||||
"3rdparty/jvm/com/twitter/bijection:scrooge",
|
|
||||||
"3rdparty/jvm/com/twitter/storehaus:core",
|
|
||||||
"abdecider",
|
|
||||||
"abuse/detection/src/main/thrift/com/twitter/abuse/detection/scoring:thrift-scala",
|
|
||||||
"ann/src/main/scala/com/twitter/ann/common",
|
|
||||||
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
|
|
||||||
"audience-rewards/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"communities/thrift/src/main/thrift/com/twitter/communities:thrift-scala",
|
|
||||||
"configapi/configapi-core",
|
|
||||||
"configapi/configapi-decider",
|
|
||||||
"content-mixer/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"content-recommender/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"copyselectionservice/server/src/main/scala/com/twitter/copyselectionservice/algorithms",
|
|
||||||
"copyselectionservice/thrift/src/main/thrift:copyselectionservice-scala",
|
|
||||||
"cortex-deepbird/thrift/src/main/thrift:thrift-java",
|
|
||||||
"cr-mixer/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"cuad/projects/hashspace/thrift:thrift-scala",
|
|
||||||
"cuad/projects/tagspace/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"detopic/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"discovery-common/src/main/scala/com/twitter/discovery/common/configapi",
|
|
||||||
"discovery-common/src/main/scala/com/twitter/discovery/common/ddg",
|
|
||||||
"discovery-common/src/main/scala/com/twitter/discovery/common/environment",
|
|
||||||
"discovery-common/src/main/scala/com/twitter/discovery/common/fatigue",
|
|
||||||
"discovery-common/src/main/scala/com/twitter/discovery/common/nackwarmupfilter",
|
|
||||||
"discovery-common/src/main/scala/com/twitter/discovery/common/server",
|
|
||||||
"discovery-ds/src/main/thrift/com/twitter/dds/scio/searcher_aggregate_history_srp:searcher_aggregate_history_srp-scala",
|
|
||||||
"escherbird/src/scala/com/twitter/escherbird/util/metadatastitch",
|
|
||||||
"escherbird/src/scala/com/twitter/escherbird/util/uttclient",
|
|
||||||
"escherbird/src/thrift/com/twitter/escherbird/utt:strato-columns-scala",
|
|
||||||
"eventbus/client",
|
|
||||||
"eventdetection/event_context/src/main/scala/com/twitter/eventdetection/event_context/util",
|
|
||||||
"events-recos/events-recos-service/src/main/thrift:events-recos-thrift-scala",
|
|
||||||
"explore/explore-ranker/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"featureswitches/featureswitches-core/src/main/scala",
|
|
||||||
"featureswitches/featureswitches-core/src/main/scala:dynmap",
|
|
||||||
"featureswitches/featureswitches-core/src/main/scala:recipient",
|
|
||||||
"featureswitches/featureswitches-core/src/main/scala:useragent",
|
|
||||||
"featureswitches/featureswitches-core/src/main/scala/com/twitter/featureswitches/v2/builder",
|
|
||||||
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/authentication",
|
|
||||||
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/server",
|
|
||||||
"finagle-internal/ostrich-stats",
|
|
||||||
"finagle/finagle-core/src/main",
|
|
||||||
"finagle/finagle-http/src/main/scala",
|
|
||||||
"finagle/finagle-memcached/src/main/scala",
|
|
||||||
"finagle/finagle-stats",
|
|
||||||
"finagle/finagle-thriftmux",
|
|
||||||
"finagle/finagle-tunable/src/main/scala",
|
|
||||||
"finagle/finagle-zipkin-scribe",
|
|
||||||
"finatra-internal/abdecider",
|
|
||||||
"finatra-internal/decider",
|
|
||||||
"finatra-internal/mtls-http/src/main/scala",
|
|
||||||
"finatra-internal/mtls-thriftmux/src/main/scala",
|
|
||||||
"finatra/http-client/src/main/scala",
|
|
||||||
"finatra/http-core/src/main/java/com/twitter/finatra/http",
|
|
||||||
"finatra/http-core/src/main/scala/com/twitter/finatra/http/response",
|
|
||||||
"finatra/http-server/src/main/scala/com/twitter/finatra/http",
|
|
||||||
"finatra/http-server/src/main/scala/com/twitter/finatra/http/filters",
|
|
||||||
"finatra/inject/inject-app/src/main/java/com/twitter/inject/annotations",
|
|
||||||
"finatra/inject/inject-app/src/main/scala",
|
|
||||||
"finatra/inject/inject-core/src/main/scala",
|
|
||||||
"finatra/inject/inject-server/src/main/scala",
|
|
||||||
"finatra/inject/inject-slf4j/src/main/scala/com/twitter/inject",
|
|
||||||
"finatra/inject/inject-thrift-client/src/main/scala",
|
|
||||||
"finatra/inject/inject-utils/src/main/scala",
|
|
||||||
"finatra/utils/src/main/java/com/twitter/finatra/annotations",
|
|
||||||
"fleets/fleets-proxy/thrift/src/main/thrift:fleet-scala",
|
|
||||||
"fleets/fleets-proxy/thrift/src/main/thrift/service:baseservice-scala",
|
|
||||||
"flock-client/src/main/scala",
|
|
||||||
"flock-client/src/main/thrift:thrift-scala",
|
|
||||||
"follow-recommendations-service/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"frigate/frigate-common:base",
|
|
||||||
"frigate/frigate-common:config",
|
|
||||||
"frigate/frigate-common:debug",
|
|
||||||
"frigate/frigate-common:entity_graph_client",
|
|
||||||
"frigate/frigate-common:history",
|
|
||||||
"frigate/frigate-common:logger",
|
|
||||||
"frigate/frigate-common:ml-base",
|
|
||||||
"frigate/frigate-common:ml-feature",
|
|
||||||
"frigate/frigate-common:ml-prediction",
|
|
||||||
"frigate/frigate-common:ntab",
|
|
||||||
"frigate/frigate-common:predicate",
|
|
||||||
"frigate/frigate-common:rec_types",
|
|
||||||
"frigate/frigate-common:score_summary",
|
|
||||||
"frigate/frigate-common:util",
|
|
||||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/candidate",
|
|
||||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/experiments",
|
|
||||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/filter",
|
|
||||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/modules/store:semantic_core_stores",
|
|
||||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store",
|
|
||||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/deviceinfo",
|
|
||||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/interests",
|
|
||||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/strato",
|
|
||||||
"frigate/push-mixer/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"geo/geo-prediction/src/main/thrift:local-viral-tweets-thrift-scala",
|
|
||||||
"geoduck/service/src/main/scala/com/twitter/geoduck/service/common/clientmodules",
|
|
||||||
"geoduck/util/country",
|
|
||||||
"gizmoduck/client/src/main/scala/com/twitter/gizmoduck/testusers/client",
|
|
||||||
"hermit/hermit-core:model-user_state",
|
|
||||||
"hermit/hermit-core:predicate",
|
|
||||||
"hermit/hermit-core:predicate-gizmoduck",
|
|
||||||
"hermit/hermit-core:predicate-scarecrow",
|
|
||||||
"hermit/hermit-core:predicate-socialgraph",
|
|
||||||
"hermit/hermit-core:predicate-tweetypie",
|
|
||||||
"hermit/hermit-core:store-labeled_push_recs",
|
|
||||||
"hermit/hermit-core:store-metastore",
|
|
||||||
"hermit/hermit-core:store-timezone",
|
|
||||||
"hermit/hermit-core:store-tweetypie",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/constants",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/model",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/common",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/gizmoduck",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/scarecrow",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/semantic_core",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/user_htl_session_store",
|
|
||||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/user_interest",
|
|
||||||
"hmli/hss/src/main/thrift/com/twitter/hss:thrift-scala",
|
|
||||||
"ibis2/service/src/main/scala/com/twitter/ibis2/lib",
|
|
||||||
"ibis2/service/src/main/thrift/com/twitter/ibis2/service:ibis2-service-scala",
|
|
||||||
"interests-service/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"interests_discovery/thrift/src/main/thrift:batch-thrift-scala",
|
|
||||||
"interests_discovery/thrift/src/main/thrift:service-thrift-scala",
|
|
||||||
"kujaku/thrift/src/main/thrift:domain-scala",
|
|
||||||
"live-video-timeline/client/src/main/scala/com/twitter/livevideo/timeline/client/v2",
|
|
||||||
"live-video-timeline/domain/src/main/scala/com/twitter/livevideo/timeline/domain",
|
|
||||||
"live-video-timeline/domain/src/main/scala/com/twitter/livevideo/timeline/domain/v2",
|
|
||||||
"live-video-timeline/thrift/src/main/thrift/com/twitter/livevideo/timeline:thrift-scala",
|
|
||||||
"live-video/common/src/main/scala/com/twitter/livevideo/common/domain/v2",
|
|
||||||
"live-video/common/src/main/scala/com/twitter/livevideo/common/ids",
|
|
||||||
"notifications-platform/inbound-notifications/src/main/thrift/com/twitter/inbound_notifications:exception-scala",
|
|
||||||
"notifications-platform/inbound-notifications/src/main/thrift/com/twitter/inbound_notifications:thrift-scala",
|
|
||||||
"notifications-platform/platform-lib/src/main/thrift/com/twitter/notifications/platform:custom-notification-actions-scala",
|
|
||||||
"notifications-platform/platform-lib/src/main/thrift/com/twitter/notifications/platform:thrift-scala",
|
|
||||||
"notifications-relevance/src/scala/com/twitter/nrel/heavyranker",
|
|
||||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/base",
|
|
||||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/frigate",
|
|
||||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/push",
|
|
||||||
"notifications-relevance/src/scala/com/twitter/nrel/lightranker",
|
|
||||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/genericfeedbackstore",
|
|
||||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/model:alias",
|
|
||||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/model/service",
|
|
||||||
"notificationservice/common/src/test/scala/com/twitter/notificationservice/mocks",
|
|
||||||
"notificationservice/scribe/src/main/scala/com/twitter/notificationservice/scribe/manhattan:mh_wrapper",
|
|
||||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/api:thrift-scala",
|
|
||||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/badgecount-api:thrift-scala",
|
|
||||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/generic_notifications:thrift-scala",
|
|
||||||
"notifinfra/ni-lib/src/main/scala/com/twitter/ni/lib/logged_out_transform",
|
|
||||||
"observability/observability-manhattan-client/src/main/scala",
|
|
||||||
"onboarding/service/src/main/scala/com/twitter/onboarding/task/service/models/external",
|
|
||||||
"onboarding/service/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"people-discovery/api/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"periscope/api-proxy-thrift/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/module",
|
|
||||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/module/stringcenter",
|
|
||||||
"product-mixer/core/src/main/thrift/com/twitter/product_mixer/core:thrift-scala",
|
|
||||||
"qig-ranker/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"rux-ds/src/main/thrift/com/twitter/ruxds/jobs/user_past_aggregate:user_past_aggregate-scala",
|
|
||||||
"rux/common/src/main/scala/com/twitter/rux/common/encode",
|
|
||||||
"rux/common/thrift/src/main/thrift/rux-context:rux-context-scala",
|
|
||||||
"rux/common/thrift/src/main/thrift/strato:strato-scala",
|
|
||||||
"scribelib/marshallers/src/main/scala/com/twitter/scribelib/marshallers",
|
|
||||||
"scrooge/scrooge-core",
|
|
||||||
"scrooge/scrooge-serializer/src/main/scala",
|
|
||||||
"sensitive-ds/src/main/thrift/com/twitter/scio/nsfw_user_segmentation:nsfw_user_segmentation-scala",
|
|
||||||
"servo/decider/src/main/scala",
|
|
||||||
"servo/request/src/main/scala",
|
|
||||||
"servo/util/src/main/scala",
|
|
||||||
"src/java/com/twitter/ml/api:api-base",
|
|
||||||
"src/java/com/twitter/ml/prediction/core",
|
|
||||||
"src/scala/com/twitter/frigate/data_pipeline/common",
|
|
||||||
"src/scala/com/twitter/frigate/data_pipeline/embedding_cg:embedding_cg-test-user-ids",
|
|
||||||
"src/scala/com/twitter/frigate/data_pipeline/features_common",
|
|
||||||
"src/scala/com/twitter/frigate/news_article_recs/news_articles_metadata:thrift-scala",
|
|
||||||
"src/scala/com/twitter/frontpage/stream/util",
|
|
||||||
"src/scala/com/twitter/language/normalization",
|
|
||||||
"src/scala/com/twitter/ml/api/embedding",
|
|
||||||
"src/scala/com/twitter/ml/api/util:datarecord",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/catalog/entities/core",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/catalog/entities/magicrecs",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/catalog/features/core:aggregate",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/catalog/features/cuad:aggregate",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/catalog/features/embeddings",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/catalog/features/magicrecs:aggregate",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/catalog/features/topic_signals:aggregate",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/lib",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/lib/data",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/lib/dynamic",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/lib/entity",
|
|
||||||
"src/scala/com/twitter/ml/featurestore/lib/online",
|
|
||||||
"src/scala/com/twitter/recommendation/interests/discovery/core/config",
|
|
||||||
"src/scala/com/twitter/recommendation/interests/discovery/core/deploy",
|
|
||||||
"src/scala/com/twitter/recommendation/interests/discovery/core/model",
|
|
||||||
"src/scala/com/twitter/recommendation/interests/discovery/popgeo/deploy",
|
|
||||||
"src/scala/com/twitter/simclusters_v2/common",
|
|
||||||
"src/scala/com/twitter/storehaus_internal/manhattan",
|
|
||||||
"src/scala/com/twitter/storehaus_internal/manhattan/config",
|
|
||||||
"src/scala/com/twitter/storehaus_internal/memcache",
|
|
||||||
"src/scala/com/twitter/storehaus_internal/memcache/config",
|
|
||||||
"src/scala/com/twitter/storehaus_internal/util",
|
|
||||||
"src/scala/com/twitter/taxi/common",
|
|
||||||
"src/scala/com/twitter/taxi/config",
|
|
||||||
"src/scala/com/twitter/taxi/deploy",
|
|
||||||
"src/scala/com/twitter/taxi/trending/common",
|
|
||||||
"src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
|
|
||||||
"src/thrift/com/twitter/clientapp/gen:clientapp-scala",
|
|
||||||
"src/thrift/com/twitter/core_workflows/user_model:user_model-scala",
|
|
||||||
"src/thrift/com/twitter/escherbird/common:constants-scala",
|
|
||||||
"src/thrift/com/twitter/escherbird/metadata:megadata-scala",
|
|
||||||
"src/thrift/com/twitter/escherbird/metadata:metadata-service-scala",
|
|
||||||
"src/thrift/com/twitter/escherbird/search:search-service-scala",
|
|
||||||
"src/thrift/com/twitter/expandodo:only-scala",
|
|
||||||
"src/thrift/com/twitter/frigate:frigate-common-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate:frigate-ml-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate:frigate-notification-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate:frigate-secondary-accounts-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate:frigate-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate:frigate-user-media-representation-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/data_pipeline:frigate-user-history-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/dau_model:frigate-dau-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/magic_events:frigate-magic-events-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/magic_events/scribe:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/pushcap:frigate-pushcap-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/pushservice:frigate-pushservice-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/scribe:frigate-scribe-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/subscribed_search:frigate-subscribed-search-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/frigate/user_states:frigate-userstates-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/geoduck:geoduck-scala",
|
|
||||||
"src/thrift/com/twitter/gizmoduck:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/gizmoduck:user-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/hermit:hermit-scala",
|
|
||||||
"src/thrift/com/twitter/hermit/pop_geo:hermit-pop-geo-scala",
|
|
||||||
"src/thrift/com/twitter/hermit/stp:hermit-stp-scala",
|
|
||||||
"src/thrift/com/twitter/ibis:service-scala",
|
|
||||||
"src/thrift/com/twitter/manhattan:v1-scala",
|
|
||||||
"src/thrift/com/twitter/manhattan:v2-scala",
|
|
||||||
"src/thrift/com/twitter/ml/api:data-java",
|
|
||||||
"src/thrift/com/twitter/ml/api:data-scala",
|
|
||||||
"src/thrift/com/twitter/ml/featurestore/timelines:ml-features-timelines-scala",
|
|
||||||
"src/thrift/com/twitter/ml/featurestore/timelines:ml-features-timelines-strato",
|
|
||||||
"src/thrift/com/twitter/ml/prediction_service:prediction_service-java",
|
|
||||||
"src/thrift/com/twitter/permissions_storage:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/pink-floyd/thrift:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/recos:recos-common-scala",
|
|
||||||
"src/thrift/com/twitter/recos/user_tweet_entity_graph:user_tweet_entity_graph-scala",
|
|
||||||
"src/thrift/com/twitter/recos/user_user_graph:user_user_graph-scala",
|
|
||||||
"src/thrift/com/twitter/relevance/feature_store:feature_store-scala",
|
|
||||||
"src/thrift/com/twitter/search:earlybird-scala",
|
|
||||||
"src/thrift/com/twitter/search/common:features-scala",
|
|
||||||
"src/thrift/com/twitter/search/query_interaction_graph:query_interaction_graph-scala",
|
|
||||||
"src/thrift/com/twitter/search/query_interaction_graph/service:qig-service-scala",
|
|
||||||
"src/thrift/com/twitter/service/metastore/gen:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/service/scarecrow/gen:scarecrow-scala",
|
|
||||||
"src/thrift/com/twitter/service/scarecrow/gen:tiered-actions-scala",
|
|
||||||
"src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/socialgraph:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/spam/rtf:safety-level-scala",
|
|
||||||
"src/thrift/com/twitter/timelinemixer:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/timelinemixer/server/internal:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/timelines/author_features/user_health:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/timelines/real_graph:real_graph-scala",
|
|
||||||
"src/thrift/com/twitter/timelinescorer:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/timelinescorer/server/internal:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/timelineservice/server/internal:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/timelineservice/server/suggests/logging:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/trends/common:common-scala",
|
|
||||||
"src/thrift/com/twitter/trends/trip_v1:trip-tweets-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/tweetypie:service-scala",
|
|
||||||
"src/thrift/com/twitter/tweetypie:tweet-scala",
|
|
||||||
"src/thrift/com/twitter/user_session_store:thrift-scala",
|
|
||||||
"src/thrift/com/twitter/wtf/candidate:wtf-candidate-scala",
|
|
||||||
"src/thrift/com/twitter/wtf/interest:interest-thrift-scala",
|
|
||||||
"src/thrift/com/twitter/wtf/scalding/common:thrift-scala",
|
|
||||||
"stitch/stitch-core",
|
|
||||||
"stitch/stitch-gizmoduck",
|
|
||||||
"stitch/stitch-socialgraph/src/main/scala",
|
|
||||||
"stitch/stitch-storehaus/src/main/scala",
|
|
||||||
"stitch/stitch-tweetypie/src/main/scala",
|
|
||||||
"storage/clients/manhattan/client/src/main/scala",
|
|
||||||
"strato/config/columns/clients:clients-strato-client",
|
|
||||||
"strato/config/columns/geo/user:user-strato-client",
|
|
||||||
"strato/config/columns/globe/curation:curation-strato-client",
|
|
||||||
"strato/config/columns/interests:interests-strato-client",
|
|
||||||
"strato/config/columns/ml/featureStore:featureStore-strato-client",
|
|
||||||
"strato/config/columns/notifications:notifications-strato-client",
|
|
||||||
"strato/config/columns/notifinfra:notifinfra-strato-client",
|
|
||||||
"strato/config/columns/periscope:periscope-strato-client",
|
|
||||||
"strato/config/columns/rux",
|
|
||||||
"strato/config/columns/rux:rux-strato-client",
|
|
||||||
"strato/config/columns/rux/open-app:open-app-strato-client",
|
|
||||||
"strato/config/columns/socialgraph/graphs:graphs-strato-client",
|
|
||||||
"strato/config/columns/socialgraph/service/soft_users:soft_users-strato-client",
|
|
||||||
"strato/config/columns/translation/service:service-strato-client",
|
|
||||||
"strato/config/columns/translation/service/platform:platform-strato-client",
|
|
||||||
"strato/config/columns/trends/trip:trip-strato-client",
|
|
||||||
"strato/config/src/thrift/com/twitter/strato/columns/frigate:logged-out-web-notifications-scala",
|
|
||||||
"strato/config/src/thrift/com/twitter/strato/columns/notifications:thrift-scala",
|
|
||||||
"strato/src/main/scala/com/twitter/strato/config",
|
|
||||||
"strato/src/main/scala/com/twitter/strato/response",
|
|
||||||
"thrift-web-forms",
|
|
||||||
"timeline-training-service/service/thrift/src/main/thrift:thrift-scala",
|
|
||||||
"timelines/src/main/scala/com/twitter/timelines/features/app",
|
|
||||||
"topic-social-proof/server/src/main/thrift:thrift-scala",
|
|
||||||
"topiclisting/topiclisting-core/src/main/scala/com/twitter/topiclisting",
|
|
||||||
"topiclisting/topiclisting-utt/src/main/scala/com/twitter/topiclisting/utt",
|
|
||||||
"trends/common/src/main/thrift/com/twitter/trends/common:thrift-scala",
|
|
||||||
"tweetypie/src/scala/com/twitter/tweetypie/tweettext",
|
|
||||||
"twitter-context/src/main/scala",
|
|
||||||
"twitter-server-internal",
|
|
||||||
"twitter-server/server/src/main/scala",
|
|
||||||
"twitter-text/lib/java/src/main/java/com/twitter/twittertext",
|
|
||||||
"twml/runtime/src/main/scala/com/twitter/deepbird/runtime/prediction_engine:prediction_engine_mkl",
|
|
||||||
"ubs/common/src/main/thrift/com/twitter/ubs:broadcast-thrift-scala",
|
|
||||||
"ubs/common/src/main/thrift/com/twitter/ubs:seller_application-thrift-scala",
|
|
||||||
"user_session_store/src/main/scala/com/twitter/user_session_store/impl/manhattan/readwrite",
|
|
||||||
"util-internal/scribe",
|
|
||||||
"util-internal/tunable/src/main/scala/com/twitter/util/tunable",
|
|
||||||
"util/util-app",
|
|
||||||
"util/util-hashing/src/main/scala",
|
|
||||||
"util/util-slf4j-api/src/main/scala",
|
|
||||||
"util/util-stats/src/main/scala",
|
|
||||||
"visibility/lib/src/main/scala/com/twitter/visibility/builder",
|
|
||||||
"visibility/lib/src/main/scala/com/twitter/visibility/interfaces/push_service",
|
|
||||||
"visibility/lib/src/main/scala/com/twitter/visibility/interfaces/spaces",
|
|
||||||
"visibility/lib/src/main/scala/com/twitter/visibility/util",
|
|
||||||
],
|
|
||||||
exports = [
|
|
||||||
"strato/config/src/thrift/com/twitter/strato/columns/frigate:logged-out-web-notifications-scala",
|
|
||||||
],
|
|
||||||
)
|
|
Binary file not shown.
Binary file not shown.
@ -1,93 +0,0 @@
|
|||||||
package com.twitter.frigate.pushservice
|
|
||||||
|
|
||||||
import com.google.inject.Inject
|
|
||||||
import com.google.inject.Singleton
|
|
||||||
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
|
|
||||||
import com.twitter.finagle.thrift.ClientId
|
|
||||||
import com.twitter.finatra.thrift.routing.ThriftWarmup
|
|
||||||
import com.twitter.util.logging.Logging
|
|
||||||
import com.twitter.inject.utils.Handler
|
|
||||||
import com.twitter.frigate.pushservice.{thriftscala => t}
|
|
||||||
import com.twitter.frigate.thriftscala.NotificationDisplayLocation
|
|
||||||
import com.twitter.util.Stopwatch
|
|
||||||
import com.twitter.scrooge.Request
|
|
||||||
import com.twitter.scrooge.Response
|
|
||||||
import com.twitter.util.Return
|
|
||||||
import com.twitter.util.Throw
|
|
||||||
import com.twitter.util.Try
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Warms up the refresh request path.
|
|
||||||
* If service is running as pushservice-send then the warmup does nothing.
|
|
||||||
*
|
|
||||||
* When making the warmup refresh requests we
|
|
||||||
* - Set skipFilters to true to execute as much of the request path as possible
|
|
||||||
* - Set darkWrite to true to prevent sending a push
|
|
||||||
*/
|
|
||||||
@Singleton
|
|
||||||
class PushMixerThriftServerWarmupHandler @Inject() (
|
|
||||||
warmup: ThriftWarmup,
|
|
||||||
serviceIdentifier: ServiceIdentifier)
|
|
||||||
extends Handler
|
|
||||||
with Logging {
|
|
||||||
|
|
||||||
private val clientId = ClientId("thrift-warmup-client")
|
|
||||||
|
|
||||||
def handle(): Unit = {
|
|
||||||
val refreshServices = Set(
|
|
||||||
"frigate-pushservice",
|
|
||||||
"frigate-pushservice-canary",
|
|
||||||
"frigate-pushservice-canary-control",
|
|
||||||
"frigate-pushservice-canary-treatment"
|
|
||||||
)
|
|
||||||
val isRefresh = refreshServices.contains(serviceIdentifier.service)
|
|
||||||
if (isRefresh && !serviceIdentifier.isLocal) refreshWarmup()
|
|
||||||
}
|
|
||||||
|
|
||||||
def refreshWarmup(): Unit = {
|
|
||||||
val elapsed = Stopwatch.start()
|
|
||||||
val testIds = Seq(
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
3
|
|
||||||
)
|
|
||||||
try {
|
|
||||||
clientId.asCurrent {
|
|
||||||
testIds.foreach { id =>
|
|
||||||
val warmupReq = warmupQuery(id)
|
|
||||||
info(s"Sending warm-up request to service with query: $warmupReq")
|
|
||||||
warmup.sendRequest(
|
|
||||||
method = t.PushService.Refresh,
|
|
||||||
req = Request(t.PushService.Refresh.Args(warmupReq)))(assertWarmupResponse)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch {
|
|
||||||
case e: Throwable =>
|
|
||||||
error(e.getMessage, e)
|
|
||||||
}
|
|
||||||
info(s"Warm up complete. Time taken: ${elapsed().toString}")
|
|
||||||
}
|
|
||||||
|
|
||||||
private def warmupQuery(userId: Long): t.RefreshRequest = {
|
|
||||||
t.RefreshRequest(
|
|
||||||
userId = userId,
|
|
||||||
notificationDisplayLocation = NotificationDisplayLocation.PushToMobileDevice,
|
|
||||||
context = Some(
|
|
||||||
t.PushContext(
|
|
||||||
skipFilters = Some(true),
|
|
||||||
darkWrite = Some(true)
|
|
||||||
))
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def assertWarmupResponse(
|
|
||||||
result: Try[Response[t.PushService.Refresh.SuccessType]]
|
|
||||||
): Unit = {
|
|
||||||
result match {
|
|
||||||
case Return(_) => // ok
|
|
||||||
case Throw(exception) =>
|
|
||||||
warn("Error performing warm-up request.")
|
|
||||||
error(exception.getMessage, exception)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Binary file not shown.
@ -1,193 +0,0 @@
|
|||||||
package com.twitter.frigate.pushservice
|
|
||||||
|
|
||||||
import com.twitter.discovery.common.environment.modules.EnvironmentModule
|
|
||||||
import com.twitter.finagle.Filter
|
|
||||||
import com.twitter.finatra.annotations.DarkTrafficFilterType
|
|
||||||
import com.twitter.finatra.decider.modules.DeciderModule
|
|
||||||
import com.twitter.finatra.http.HttpServer
|
|
||||||
import com.twitter.finatra.http.filters.CommonFilters
|
|
||||||
import com.twitter.finatra.http.routing.HttpRouter
|
|
||||||
import com.twitter.finatra.mtls.http.{Mtls => HttpMtls}
|
|
||||||
import com.twitter.finatra.mtls.thriftmux.{Mtls => ThriftMtls}
|
|
||||||
import com.twitter.finatra.mtls.thriftmux.filters.MtlsServerSessionTrackerFilter
|
|
||||||
import com.twitter.finatra.thrift.ThriftServer
|
|
||||||
import com.twitter.finatra.thrift.filters.ExceptionMappingFilter
|
|
||||||
import com.twitter.finatra.thrift.filters.LoggingMDCFilter
|
|
||||||
import com.twitter.finatra.thrift.filters.StatsFilter
|
|
||||||
import com.twitter.finatra.thrift.filters.ThriftMDCFilter
|
|
||||||
import com.twitter.finatra.thrift.filters.TraceIdMDCFilter
|
|
||||||
import com.twitter.finatra.thrift.routing.ThriftRouter
|
|
||||||
import com.twitter.frigate.common.logger.MRLoggerGlobalVariables
|
|
||||||
import com.twitter.frigate.pushservice.controller.PushServiceController
|
|
||||||
import com.twitter.frigate.pushservice.module._
|
|
||||||
import com.twitter.inject.TwitterModule
|
|
||||||
import com.twitter.inject.annotations.Flags
|
|
||||||
import com.twitter.inject.thrift.modules.ThriftClientIdModule
|
|
||||||
import com.twitter.logging.BareFormatter
|
|
||||||
import com.twitter.logging.Level
|
|
||||||
import com.twitter.logging.LoggerFactory
|
|
||||||
import com.twitter.logging.{Logging => JLogging}
|
|
||||||
import com.twitter.logging.QueueingHandler
|
|
||||||
import com.twitter.logging.ScribeHandler
|
|
||||||
import com.twitter.product_mixer.core.module.product_mixer_flags.ProductMixerFlagModule
|
|
||||||
import com.twitter.product_mixer.core.module.ABDeciderModule
|
|
||||||
import com.twitter.product_mixer.core.module.FeatureSwitchesModule
|
|
||||||
import com.twitter.product_mixer.core.module.StratoClientModule
|
|
||||||
|
|
||||||
object PushServiceMain extends PushServiceFinatraServer
|
|
||||||
|
|
||||||
class PushServiceFinatraServer
|
|
||||||
extends ThriftServer
|
|
||||||
with ThriftMtls
|
|
||||||
with HttpServer
|
|
||||||
with HttpMtls
|
|
||||||
with JLogging {
|
|
||||||
|
|
||||||
override val name = "PushService"
|
|
||||||
|
|
||||||
override val modules: Seq[TwitterModule] = {
|
|
||||||
Seq(
|
|
||||||
ABDeciderModule,
|
|
||||||
DeciderModule,
|
|
||||||
FeatureSwitchesModule,
|
|
||||||
FilterModule,
|
|
||||||
FlagModule,
|
|
||||||
EnvironmentModule,
|
|
||||||
ThriftClientIdModule,
|
|
||||||
DeployConfigModule,
|
|
||||||
ProductMixerFlagModule,
|
|
||||||
StratoClientModule,
|
|
||||||
PushHandlerModule,
|
|
||||||
PushTargetUserBuilderModule,
|
|
||||||
PushServiceDarkTrafficModule,
|
|
||||||
LoggedOutPushTargetUserBuilderModule,
|
|
||||||
new ThriftWebFormsModule(this),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override def configureThrift(router: ThriftRouter): Unit = {
|
|
||||||
router
|
|
||||||
.filter[ExceptionMappingFilter]
|
|
||||||
.filter[LoggingMDCFilter]
|
|
||||||
.filter[TraceIdMDCFilter]
|
|
||||||
.filter[ThriftMDCFilter]
|
|
||||||
.filter[MtlsServerSessionTrackerFilter]
|
|
||||||
.filter[StatsFilter]
|
|
||||||
.filter[Filter.TypeAgnostic, DarkTrafficFilterType]
|
|
||||||
.add[PushServiceController]
|
|
||||||
}
|
|
||||||
|
|
||||||
override def configureHttp(router: HttpRouter): Unit =
|
|
||||||
router
|
|
||||||
.filter[CommonFilters]
|
|
||||||
|
|
||||||
override protected def start(): Unit = {
|
    MRLoggerGlobalVariables.setRequiredFlags(
      traceLogFlag = injector.instance[Boolean](Flags.named(FlagModule.mrLoggerIsTraceAll.name)),
      nthLogFlag = injector.instance[Boolean](Flags.named(FlagModule.mrLoggerNthLog.name)),
      nthLogValFlag = injector.instance[Long](Flags.named(FlagModule.mrLoggerNthVal.name))
    )
  }

  override protected def warmup(): Unit = {
    handle[PushMixerThriftServerWarmupHandler]()
  }

  override protected def configureLoggerFactories(): Unit = {
    loggerFactories.foreach { _() }
  }

  override def loggerFactories: List[LoggerFactory] = {
    val scribeScope = statsReceiver.scope("scribe")
    List(
      LoggerFactory(
        level = Some(levelFlag()),
        handlers = handlers
      ),
      LoggerFactory(
        node = "request_scribe",
        level = Some(Level.INFO),
        useParents = false,
        handlers = QueueingHandler(
          maxQueueSize = 10000,
          handler = ScribeHandler(
            category = "frigate_pushservice_log",
            formatter = BareFormatter,
            statsReceiver = scribeScope.scope("frigate_pushservice_log")
          )
        ) :: Nil
      ),
      LoggerFactory(
        node = "notification_scribe",
        level = Some(Level.INFO),
        useParents = false,
        handlers = QueueingHandler(
          maxQueueSize = 10000,
          handler = ScribeHandler(
            category = "frigate_notifier",
            formatter = BareFormatter,
            statsReceiver = scribeScope.scope("frigate_notifier")
          )
        ) :: Nil
      ),
      LoggerFactory(
        node = "push_scribe",
        level = Some(Level.INFO),
        useParents = false,
        handlers = QueueingHandler(
          maxQueueSize = 10000,
          handler = ScribeHandler(
            category = "test_frigate_push",
            formatter = BareFormatter,
            statsReceiver = scribeScope.scope("test_frigate_push")
          )
        ) :: Nil
      ),
      LoggerFactory(
        node = "push_subsample_scribe",
        level = Some(Level.INFO),
        useParents = false,
        handlers = QueueingHandler(
          maxQueueSize = 2500,
          handler = ScribeHandler(
            category = "magicrecs_candidates_subsample_scribe",
            maxMessagesPerTransaction = 250,
            maxMessagesToBuffer = 2500,
            formatter = BareFormatter,
            statsReceiver = scribeScope.scope("magicrecs_candidates_subsample_scribe")
          )
        ) :: Nil
      ),
      LoggerFactory(
        node = "mr_request_scribe",
        level = Some(Level.INFO),
        useParents = false,
        handlers = QueueingHandler(
          maxQueueSize = 2500,
          handler = ScribeHandler(
            category = "mr_request_scribe",
            maxMessagesPerTransaction = 250,
            maxMessagesToBuffer = 2500,
            formatter = BareFormatter,
            statsReceiver = scribeScope.scope("mr_request_scribe")
          )
        ) :: Nil
      ),
      LoggerFactory(
        node = "high_quality_candidates_scribe",
        level = Some(Level.INFO),
        useParents = false,
        handlers = QueueingHandler(
          maxQueueSize = 2500,
          handler = ScribeHandler(
            category = "frigate_high_quality_candidates_log",
            maxMessagesPerTransaction = 250,
            maxMessagesToBuffer = 2500,
            formatter = BareFormatter,
            statsReceiver = scribeScope.scope("high_quality_candidates_scribe")
          )
        ) :: Nil
      ),
    )
  }
}
Binary file not shown.
@ -1,323 +0,0 @@
package com.twitter.frigate.pushservice.adaptor

import com.twitter.contentrecommender.thriftscala.MetricTag
import com.twitter.cr_mixer.thriftscala.CrMixerTweetRequest
import com.twitter.cr_mixer.thriftscala.NotificationsContext
import com.twitter.cr_mixer.thriftscala.Product
import com.twitter.cr_mixer.thriftscala.ProductContext
import com.twitter.cr_mixer.thriftscala.{MetricTag => CrMixerMetricTag}
import com.twitter.finagle.stats.Stat
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.AlgorithmScore
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.CrMixerCandidate
import com.twitter.frigate.common.base.TopicCandidate
import com.twitter.frigate.common.base.TopicProofTweetCandidate
import com.twitter.frigate.common.base.TweetCandidate
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutInNetworkTweets
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.store.CrMixerTweetStore
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
import com.twitter.frigate.pushservice.util.AdaptorUtils
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.pushservice.util.TopicsUtil
import com.twitter.frigate.pushservice.util.TweetWithTopicProof
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.hermit.predicate.socialgraph.RelationEdge
import com.twitter.product_mixer.core.thriftscala.ClientContext
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.topiclisting.utt.LocalizedEntity
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
import com.twitter.util.Future
import scala.collection.Map

case class ContentRecommenderMixerAdaptor(
  crMixerTweetStore: CrMixerTweetStore,
  tweetyPieStore: ReadableStore[Long, TweetyPieResult],
  edgeStore: ReadableStore[RelationEdge, Boolean],
  topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
  uttEntityHydrationStore: UttEntityHydrationStore,
  globalStats: StatsReceiver)
    extends CandidateSource[Target, RawCandidate]
    with CandidateSourceEligible[Target, RawCandidate] {

  override val name: String = this.getClass.getSimpleName

  private[this] val stats = globalStats.scope("ContentRecommenderMixerAdaptor")
  private[this] val numOfValidAuthors = stats.stat("num_of_valid_authors")
  private[this] val numOutOfMaximumDropped = stats.stat("dropped_due_out_of_maximum")
  private[this] val totalInputRecs = stats.counter("input_recs")
  private[this] val totalOutputRecs = stats.stat("output_recs")
  private[this] val totalRequests = stats.counter("total_requests")
  private[this] val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
  private[this] val totalOutNetworkRecs = stats.counter("out_network_tweets")
  private[this] val totalInNetworkRecs = stats.counter("in_network_tweets")

  /**
   * Builds OON raw candidates based on input OON Tweets
   */
  def buildOONRawCandidates(
    inputTarget: Target,
    oonTweets: Seq[TweetyPieResult],
    tweetScoreMap: Map[Long, Double],
    tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]],
    maxNumOfCandidates: Int
  ): Option[Seq[RawCandidate]] = {
    val cands = oonTweets.flatMap { tweetResult =>
      val tweetId = tweetResult.tweet.id
      generateOONRawCandidate(
        inputTarget,
        tweetId,
        Some(tweetResult),
        tweetScoreMap,
        tweetIdToTagsMap
      )
    }

    val candidates = restrict(
      maxNumOfCandidates,
      cands,
      numOutOfMaximumDropped,
      totalOutputRecs
    )

    Some(candidates)
  }

  /**
   * Builds a single RawCandidate With TopicProofTweetCandidate
   */
  def buildTopicTweetRawCandidate(
    inputTarget: Target,
    tweetWithTopicProof: TweetWithTopicProof,
    localizedEntity: LocalizedEntity,
    tags: Option[Seq[MetricTag]],
  ): RawCandidate with TopicProofTweetCandidate = {
    new RawCandidate with TopicProofTweetCandidate {
      override def target: Target = inputTarget
      override def topicListingSetting: Option[String] = Some(
        tweetWithTopicProof.topicListingSetting)
      override def tweetId: Long = tweetWithTopicProof.tweetId
      override def tweetyPieResult: Option[TweetyPieResult] = Some(
        tweetWithTopicProof.tweetyPieResult)
      override def semanticCoreEntityId: Option[Long] = Some(tweetWithTopicProof.topicId)
      override def localizedUttEntity: Option[LocalizedEntity] = Some(localizedEntity)
      override def algorithmCR: Option[String] = tweetWithTopicProof.algorithmCR
      override def tagsCR: Option[Seq[MetricTag]] = tags
      override def isOutOfNetwork: Boolean = tweetWithTopicProof.isOON
    }
  }

  /**
   * Takes a group of TopicTweets and transforms them into RawCandidates
   */
  def buildTopicTweetRawCandidates(
    inputTarget: Target,
    topicProofCandidates: Seq[TweetWithTopicProof],
    tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]],
    maxNumberOfCands: Int
  ): Future[Option[Seq[RawCandidate]]] = {
    val semanticCoreEntityIds = topicProofCandidates
      .map(_.topicId)
      .toSet

    TopicsUtil
      .getLocalizedEntityMap(inputTarget, semanticCoreEntityIds, uttEntityHydrationStore)
      .map { localizedEntityMap =>
        val rawCandidates = topicProofCandidates.collect {
          case topicSocialProof: TweetWithTopicProof
              if localizedEntityMap.contains(topicSocialProof.topicId) =>
            // Once we deprecate CR calls, we should replace this code to use the CrMixerMetricTag
            val tags = tweetIdToTagsMap.get(topicSocialProof.tweetId).map {
              _.flatMap { tag => MetricTag.get(tag.value) }
            }
            buildTopicTweetRawCandidate(
              inputTarget,
              topicSocialProof,
              localizedEntityMap(topicSocialProof.topicId),
              tags
            )
        }

        val candResult = restrict(
          maxNumberOfCands,
          rawCandidates,
          numOutOfMaximumDropped,
          totalOutputRecs
        )

        Some(candResult)
      }
  }

  private def generateOONRawCandidate(
    inputTarget: Target,
    id: Long,
    result: Option[TweetyPieResult],
    tweetScoreMap: Map[Long, Double],
    tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]]
  ): Option[RawCandidate with TweetCandidate] = {
    val tagsFromCR = tweetIdToTagsMap.get(id).map { _.flatMap { tag => MetricTag.get(tag.value) } }
    val candidate = new RawCandidate with CrMixerCandidate with TopicCandidate with AlgorithmScore {
      override val tweetId = id
      override val target = inputTarget
      override val tweetyPieResult = result
      override val localizedUttEntity = None
      override val semanticCoreEntityId = None
      override def commonRecType =
        getMediaBasedCRT(
          CommonRecommendationType.TwistlyTweet,
          CommonRecommendationType.TwistlyPhoto,
          CommonRecommendationType.TwistlyVideo)
      override def tagsCR = tagsFromCR
      override def algorithmScore = tweetScoreMap.get(id)
      override def algorithmCR = None
    }
    Some(candidate)
  }

  private def restrict(
    maxNumToReturn: Int,
    candidates: Seq[RawCandidate],
    numOutOfMaximumDropped: Stat,
    totalOutputRecs: Stat
  ): Seq[RawCandidate] = {
    val newCandidates = candidates.take(maxNumToReturn)
    val numDropped = candidates.length - newCandidates.length
    numOutOfMaximumDropped.add(numDropped)
    totalOutputRecs.add(newCandidates.size)
    newCandidates
  }

  private def buildCrMixerRequest(
    target: Target,
    countryCode: Option[String],
    language: Option[String],
    seenTweets: Seq[Long]
  ): CrMixerTweetRequest = {
    CrMixerTweetRequest(
      clientContext = ClientContext(
        userId = Some(target.targetId),
        countryCode = countryCode,
        languageCode = language
      ),
      product = Product.Notifications,
      productContext = Some(ProductContext.NotificationsContext(NotificationsContext())),
      excludedTweetIds = Some(seenTweets)
    )
  }

  private def selectCandidatesToSendBasedOnSettings(
    isRecommendationsEligible: Boolean,
    isTopicsEligible: Boolean,
    oonRawCandidates: Option[Seq[RawCandidate]],
    topicTweetCandidates: Option[Seq[RawCandidate]]
  ): Option[Seq[RawCandidate]] = {
    if (isRecommendationsEligible && isTopicsEligible) {
      Some(topicTweetCandidates.getOrElse(Seq.empty) ++ oonRawCandidates.getOrElse(Seq.empty))
    } else if (isRecommendationsEligible) {
      oonRawCandidates
    } else if (isTopicsEligible) {
      topicTweetCandidates
    } else None
  }

  override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
    Future
      .join(
        target.seenTweetIds,
        target.countryCode,
        target.inferredUserDeviceLanguage,
        PushDeviceUtil.isTopicsEligible(target),
        PushDeviceUtil.isRecommendationsEligible(target)
      ).flatMap {
        case (seenTweets, countryCode, language, isTopicsEligible, isRecommendationsEligible) =>
          val request = buildCrMixerRequest(target, countryCode, language, seenTweets)
          crMixerTweetStore.getTweetRecommendations(request).flatMap {
            case Some(response) =>
              totalInputRecs.incr(response.tweets.size)
              totalRequests.incr()
              AdaptorUtils
                .getTweetyPieResults(
                  response.tweets.map(_.tweetId).toSet,
                  tweetyPieStore).flatMap { tweetyPieResultMap =>
                  filterOutInNetworkTweets(
                    target,
                    filterOutReplyTweet(tweetyPieResultMap.toMap, nonReplyTweetsCounter),
                    edgeStore,
                    numOfValidAuthors).flatMap {
                    outNetworkTweetsWithId: Seq[(Long, TweetyPieResult)] =>
                      totalOutNetworkRecs.incr(outNetworkTweetsWithId.size)
                      totalInNetworkRecs.incr(response.tweets.size - outNetworkTweetsWithId.size)
                      val outNetworkTweets: Seq[TweetyPieResult] = outNetworkTweetsWithId.map {
                        case (_, tweetyPieResult) => tweetyPieResult
                      }

                      val tweetIdToTagsMap = response.tweets.map { tweet =>
                        tweet.tweetId -> tweet.metricTags.getOrElse(Seq.empty)
                      }.toMap

                      val tweetScoreMap = response.tweets.map { tweet =>
                        tweet.tweetId -> tweet.score
                      }.toMap

                      val maxNumOfCandidates =
                        target.params(PushFeatureSwitchParams.NumberOfMaxCrMixerCandidatesParam)

                      val oonRawCandidates =
                        buildOONRawCandidates(
                          target,
                          outNetworkTweets,
                          tweetScoreMap,
                          tweetIdToTagsMap,
                          maxNumOfCandidates)

                      TopicsUtil
                        .getTopicSocialProofs(
                          target,
                          outNetworkTweets,
                          topicSocialProofServiceStore,
                          edgeStore,
                          PushFeatureSwitchParams.TopicProofTweetCandidatesTopicScoreThreshold).flatMap {
                          tweetsWithTopicProof =>
                            buildTopicTweetRawCandidates(
                              target,
                              tweetsWithTopicProof,
                              tweetIdToTagsMap,
                              maxNumOfCandidates)
                        }.map { topicTweetCandidates =>
                          selectCandidatesToSendBasedOnSettings(
                            isRecommendationsEligible,
                            isTopicsEligible,
                            oonRawCandidates,
                            topicTweetCandidates)
                        }
                  }
                }
            case _ => Future.None
          }
      }
  }

  /**
   * The candidate source is available only when the user has notification recommendations
   * or topic notifications enabled, and the adaptor decider param is on.
   */
  override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
    Future
      .join(
        PushDeviceUtil.isRecommendationsEligible(target),
        PushDeviceUtil.isTopicsEligible(target)
      ).map {
        case (isRecommendationsEligible, isTopicsEligible) =>
          (isRecommendationsEligible || isTopicsEligible) &&
            target.params(PushParams.ContentRecommenderMixerAdaptorDecider)
      }
  }
}
Binary file not shown.
@ -1,293 +0,0 @@
package com.twitter.frigate.pushservice.adaptor

import com.twitter.finagle.stats.Stat
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base._
import com.twitter.frigate.common.candidate._
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.hermit.store.tweetypie.UserTweet
import com.twitter.recos.recos_common.thriftscala.SocialProofType
import com.twitter.search.common.features.thriftscala.ThriftSearchResultFeatures
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.timelines.configapi.Param
import com.twitter.util.Future
import com.twitter.util.Time
import scala.collection.Map

case class EarlyBirdFirstDegreeCandidateAdaptor(
  earlyBirdFirstDegreeCandidates: CandidateSource[
    EarlybirdCandidateSource.Query,
    EarlybirdCandidate
  ],
  tweetyPieStore: ReadableStore[Long, TweetyPieResult],
  tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
  userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
  maxResultsParam: Param[Int],
  globalStats: StatsReceiver)
    extends CandidateSource[Target, RawCandidate]
    with CandidateSourceEligible[Target, RawCandidate] {

  type EBCandidate = EarlybirdCandidate with TweetDetails
  private val stats = globalStats.scope("EarlyBirdFirstDegreeAdaptor")
  private val earlyBirdCandsStat: Stat = stats.stat("early_bird_cands_dist")
  private val emptyEarlyBirdCands = stats.counter("empty_early_bird_candidates")
  private val seedSetEmpty = stats.counter("empty_seedset")
  private val seenTweetsStat = stats.stat("filtered_by_seen_tweets")
  private val emptyTweetyPieResult = stats.stat("empty_tweetypie_result")
  private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
  private val enableRetweets = stats.counter("enable_retweets")
  private val f1withoutSocialContexts = stats.counter("f1_without_social_context")
  private val userTweetTweetyPieStoreCounter = stats.counter("user_tweet_tweetypie_store")

  override val name: String = earlyBirdFirstDegreeCandidates.name

  private def getAllSocialContextActions(
    socialProofTypes: Seq[(SocialProofType, Seq[Long])]
  ): Seq[SocialContextAction] = {
    socialProofTypes.flatMap {
      case (SocialProofType.Favorite, scIds) =>
        scIds.map { scId =>
          SocialContextAction(
            scId,
            Time.now.inMilliseconds,
            socialContextActionType = Some(SocialContextActionType.Favorite)
          )
        }
      case (SocialProofType.Retweet, scIds) =>
        scIds.map { scId =>
          SocialContextAction(
            scId,
            Time.now.inMilliseconds,
            socialContextActionType = Some(SocialContextActionType.Retweet)
          )
        }
      case (SocialProofType.Reply, scIds) =>
        scIds.map { scId =>
          SocialContextAction(
            scId,
            Time.now.inMilliseconds,
            socialContextActionType = Some(SocialContextActionType.Reply)
          )
        }
      case (SocialProofType.Tweet, scIds) =>
        scIds.map { scId =>
          SocialContextAction(
            scId,
            Time.now.inMilliseconds,
            socialContextActionType = Some(SocialContextActionType.Tweet)
          )
        }
      case _ => Nil
    }
  }

  private def generateRetweetCandidate(
    inputTarget: Target,
    candidate: EBCandidate,
    scIds: Seq[Long],
    socialProofTypes: Seq[(SocialProofType, Seq[Long])]
  ): RawCandidate = {
    val scActions = scIds.map { scId => SocialContextAction(scId, Time.now.inMilliseconds) }
    new RawCandidate with TweetRetweetCandidate with EarlybirdTweetFeatures {
      override val socialContextActions = scActions
      override val socialContextAllTypeActions = getAllSocialContextActions(socialProofTypes)
      override val tweetId = candidate.tweetId
      override val target = inputTarget
      override val tweetyPieResult = candidate.tweetyPieResult
      override val features = candidate.features
    }
  }

  private def generateF1CandidateWithoutSocialContext(
    inputTarget: Target,
    candidate: EBCandidate
  ): RawCandidate = {
    f1withoutSocialContexts.incr()
    new RawCandidate with F1FirstDegree with EarlybirdTweetFeatures {
      override val tweetId = candidate.tweetId
      override val target = inputTarget
      override val tweetyPieResult = candidate.tweetyPieResult
      override val features = candidate.features
    }
  }

  private def generateEarlyBirdCandidate(
    id: Long,
    result: Option[TweetyPieResult],
    ebFeatures: Option[ThriftSearchResultFeatures]
  ): EBCandidate = {
    new EarlybirdCandidate with TweetDetails {
      override val tweetyPieResult: Option[TweetyPieResult] = result
      override val tweetId: Long = id
      override val features: Option[ThriftSearchResultFeatures] = ebFeatures
    }
  }

  private def filterOutSeenTweets(seenTweetIds: Seq[Long], inputTweetIds: Seq[Long]): Seq[Long] = {
    inputTweetIds.filterNot(seenTweetIds.contains)
  }

  private def filterInvalidTweets(
    tweetIds: Seq[Long],
    target: Target
  ): Future[Seq[(Long, TweetyPieResult)]] = {

    val resMap = {
      if (target.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors)) {
        userTweetTweetyPieStoreCounter.incr()
        val keys = tweetIds.map { tweetId =>
          UserTweet(tweetId, Some(target.targetId))
        }

        userTweetTweetyPieStore
          .multiGet(keys.toSet).map {
            case (userTweet, resultFut) =>
              userTweet.tweetId -> resultFut
          }.toMap
      } else {
        (target.params(PushFeatureSwitchParams.EnableVFInTweetypie) match {
          case true => tweetyPieStore
          case false => tweetyPieStoreNoVF
        }).multiGet(tweetIds.toSet)
      }
    }
    Future.collect(resMap).map { tweetyPieResultMap =>
      val cands = filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
        case (id: Long, Some(result)) =>
          id -> result
      }

      emptyTweetyPieResult.add(tweetyPieResultMap.size - cands.size)
      cands.toSeq
    }
  }

  private def getEBRetweetCandidates(
    inputTarget: Target,
    retweets: Seq[(Long, TweetyPieResult)]
  ): Seq[RawCandidate] = {
    retweets.flatMap {
      case (_, tweetypieResult) =>
        tweetypieResult.tweet.coreData.flatMap { coreData =>
          tweetypieResult.sourceTweet.map { sourceTweet =>
            val tweetId = sourceTweet.id
            val scId = coreData.userId
            val socialProofTypes = Seq((SocialProofType.Retweet, Seq(scId)))
            val candidate = generateEarlyBirdCandidate(
              tweetId,
              Some(TweetyPieResult(sourceTweet, None, None)),
              None
            )
            generateRetweetCandidate(
              inputTarget,
              candidate,
              Seq(scId),
              socialProofTypes
            )
          }
        }
    }
  }

  private def getEBFirstDegreeCands(
    tweets: Seq[(Long, TweetyPieResult)],
    ebTweetIdMap: Map[Long, Option[ThriftSearchResultFeatures]]
  ): Seq[EBCandidate] = {
    tweets.map {
      case (id, tweetypieResult) =>
        val features = ebTweetIdMap.getOrElse(id, None)
        generateEarlyBirdCandidate(id, Some(tweetypieResult), features)
    }
  }

  /**
   * Returns a combination of raw candidates made of: f1 recs, topic social proof recs, sc recs and retweet candidates
   */
  def buildRawCandidates(
    inputTarget: Target,
    firstDegreeCandidates: Seq[EBCandidate],
    retweetCandidates: Seq[RawCandidate]
  ): Seq[RawCandidate] = {
    val hydratedF1Recs =
      firstDegreeCandidates.map(generateF1CandidateWithoutSocialContext(inputTarget, _))
    hydratedF1Recs ++ retweetCandidates
  }

  override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
    inputTarget.seedsWithWeight.flatMap { seedsetOpt =>
      val seedsetMap = seedsetOpt.getOrElse(Map.empty)

      if (seedsetMap.isEmpty) {
        seedSetEmpty.incr()
        Future.None
      } else {
        val maxResultsToReturn = inputTarget.params(maxResultsParam)
        val maxTweetAge = inputTarget.params(PushFeatureSwitchParams.F1CandidateMaxTweetAgeParam)
        val earlybirdQuery = EarlybirdCandidateSource.Query(
          maxNumResultsToReturn = maxResultsToReturn,
          seedset = seedsetMap,
          maxConsecutiveResultsByTheSameUser = Some(1),
          maxTweetAge = maxTweetAge,
          disableTimelinesMLModel = false,
          searcherId = Some(inputTarget.targetId),
          isProtectTweetsEnabled =
            inputTarget.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors),
          followedUserIds = Some(seedsetMap.keySet.toSeq)
        )

        Future
          .join(inputTarget.seenTweetIds, earlyBirdFirstDegreeCandidates.get(earlybirdQuery))
          .flatMap {
            case (seenTweetIds, Some(candidates)) =>
              earlyBirdCandsStat.add(candidates.size)

              val ebTweetIdMap = candidates.map { cand => cand.tweetId -> cand.features }.toMap

              val ebTweetIds = ebTweetIdMap.keys.toSeq

              val tweetIds = filterOutSeenTweets(seenTweetIds, ebTweetIds)
              seenTweetsStat.add(ebTweetIds.size - tweetIds.size)

              filterInvalidTweets(tweetIds, inputTarget)
                .map { validTweets =>
                  val (retweets, tweets) = validTweets.partition {
                    case (_, tweetypieResult) =>
                      tweetypieResult.sourceTweet.isDefined
                  }

                  val firstDegreeCandidates = getEBFirstDegreeCands(tweets, ebTweetIdMap)

                  val retweetCandidates = {
                    if (inputTarget.params(PushParams.EarlyBirdSCBasedCandidatesParam) &&
                      inputTarget.params(PushParams.MRTweetRetweetRecsParam)) {
                      enableRetweets.incr()
                      getEBRetweetCandidates(inputTarget, retweets)
                    } else Nil
                  }

                  Some(
                    buildRawCandidates(
                      inputTarget,
                      firstDegreeCandidates,
                      retweetCandidates
                    ))
                }

            case _ =>
              emptyEarlyBirdCands.incr()
              Future.None
          }
      }
    }
  }

  override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
    PushDeviceUtil.isRecommendationsEligible(target)
  }
}
Binary file not shown.
@ -1,120 +0,0 @@
package com.twitter.frigate.pushservice.adaptor

import com.twitter.explore_ranker.thriftscala.ExploreRankerProductResponse
import com.twitter.explore_ranker.thriftscala.ExploreRankerRequest
import com.twitter.explore_ranker.thriftscala.ExploreRankerResponse
import com.twitter.explore_ranker.thriftscala.ExploreRecommendation
import com.twitter.explore_ranker.thriftscala.ImmersiveRecsResponse
import com.twitter.explore_ranker.thriftscala.ImmersiveRecsResult
import com.twitter.explore_ranker.thriftscala.NotificationsVideoRecs
import com.twitter.explore_ranker.thriftscala.Product
import com.twitter.explore_ranker.thriftscala.ProductContext
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.OutOfNetworkTweetCandidate
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.util.AdaptorUtils
import com.twitter.frigate.pushservice.util.MediaCRT
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.product_mixer.core.thriftscala.ClientContext
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future

case class ExploreVideoTweetCandidateAdaptor(
  exploreRankerStore: ReadableStore[ExploreRankerRequest, ExploreRankerResponse],
  tweetyPieStore: ReadableStore[Long, TweetyPieResult],
  globalStats: StatsReceiver)
    extends CandidateSource[Target, RawCandidate]
    with CandidateSourceEligible[Target, RawCandidate] {

  override def name: String = this.getClass.getSimpleName
  private[this] val stats = globalStats.scope("ExploreVideoTweetCandidateAdaptor")
  private[this] val totalInputRecs = stats.stat("input_recs")
  private[this] val totalRequests = stats.counter("total_requests")
  private[this] val totalEmptyResponse = stats.counter("total_empty_response")

  private def buildExploreRankerRequest(
    target: Target,
    countryCode: Option[String],
    language: Option[String],
  ): ExploreRankerRequest = {
    ExploreRankerRequest(
      clientContext = ClientContext(
        userId = Some(target.targetId),
        countryCode = countryCode,
        languageCode = language,
      ),
      product = Product.NotificationsVideoRecs,
      productContext = Some(ProductContext.NotificationsVideoRecs(NotificationsVideoRecs())),
      maxResults = Some(target.params(PushFeatureSwitchParams.MaxExploreVideoTweets))
    )
  }

  override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
    Future
      .join(
        target.countryCode,
        target.inferredUserDeviceLanguage
      ).flatMap {
        case (countryCode, language) =>
          val request = buildExploreRankerRequest(target, countryCode, language)
          exploreRankerStore.get(request).flatMap {
            case Some(response) =>
              val exploreResponseTweetIds = response match {
                case ExploreRankerResponse(ExploreRankerProductResponse
                      .ImmersiveRecsResponse(ImmersiveRecsResponse(immersiveRecsResult))) =>
                  immersiveRecsResult.collect {
                    case ImmersiveRecsResult(ExploreRecommendation
                          .ExploreTweetRecommendation(exploreTweetRecommendation)) =>
                      exploreTweetRecommendation.tweetId
                  }
                case _ =>
                  Seq.empty
              }

              totalInputRecs.add(exploreResponseTweetIds.size)
              totalRequests.incr()
              AdaptorUtils
                .getTweetyPieResults(exploreResponseTweetIds.toSet, tweetyPieStore).map {
                  tweetyPieResultMap =>
                    val candidates = tweetyPieResultMap.values.flatten
                      .map(buildVideoRawCandidates(target, _))
                    Some(candidates.toSeq)
                }
            case _ =>
              totalEmptyResponse.incr()
              Future.None
          }
        case _ =>
          Future.None
      }
  }

  override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
    PushDeviceUtil.isRecommendationsEligible(target).map { userRecommendationsEligible =>
      userRecommendationsEligible && target.params(PushFeatureSwitchParams.EnableExploreVideoTweets)
    }
  }

  private def buildVideoRawCandidates(
    target: Target,
    tweetyPieResult: TweetyPieResult
  ): RawCandidate with OutOfNetworkTweetCandidate = {
    PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
      inputTarget = target,
      id = tweetyPieResult.tweet.id,
      mediaCRT = MediaCRT(
        CommonRecommendationType.ExploreVideoTweet,
        CommonRecommendationType.ExploreVideoTweet,
        CommonRecommendationType.ExploreVideoTweet
      ),
      result = Some(tweetyPieResult),
      localizedEntity = None
    )
  }
}
Binary file not shown.
@ -1,272 +0,0 @@
package com.twitter.frigate.pushservice.adaptor

import com.twitter.cr_mixer.thriftscala.FrsTweetRequest
import com.twitter.cr_mixer.thriftscala.NotificationsContext
import com.twitter.cr_mixer.thriftscala.Product
import com.twitter.cr_mixer.thriftscala.ProductContext
import com.twitter.finagle.stats.Counter
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base._
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.store.CrMixerTweetStore
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
import com.twitter.frigate.pushservice.util.MediaCRT
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.pushservice.util.TopicsUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.hermit.constants.AlgorithmFeedbackTokens
import com.twitter.hermit.model.Algorithm.Algorithm
import com.twitter.hermit.model.Algorithm.CrowdSearchAccounts
import com.twitter.hermit.model.Algorithm.ForwardEmailBook
import com.twitter.hermit.model.Algorithm.ForwardPhoneBook
import com.twitter.hermit.model.Algorithm.ReverseEmailBookIbis
import com.twitter.hermit.model.Algorithm.ReversePhoneBook
import com.twitter.hermit.store.tweetypie.UserTweet
import com.twitter.product_mixer.core.thriftscala.ClientContext
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
import com.twitter.util.Future

object FRSAlgorithmFeedbackTokenUtil {
  private val crtsByAlgoToken = Map(
    getAlgorithmToken(ReverseEmailBookIbis) -> CommonRecommendationType.ReverseAddressbookTweet,
    getAlgorithmToken(ReversePhoneBook) -> CommonRecommendationType.ReverseAddressbookTweet,
    getAlgorithmToken(ForwardEmailBook) -> CommonRecommendationType.ForwardAddressbookTweet,
    getAlgorithmToken(ForwardPhoneBook) -> CommonRecommendationType.ForwardAddressbookTweet,
    getAlgorithmToken(CrowdSearchAccounts) -> CommonRecommendationType.CrowdSearchTweet
  )

  def getAlgorithmToken(algorithm: Algorithm): Int = {
    AlgorithmFeedbackTokens.AlgorithmToFeedbackTokenMap(algorithm)
  }

  def getCRTForAlgoToken(algorithmToken: Int): Option[CommonRecommendationType] = {
    crtsByAlgoToken.get(algorithmToken)
  }
}

case class FRSTweetCandidateAdaptor(
  crMixerTweetStore: CrMixerTweetStore,
  tweetyPieStore: ReadableStore[Long, TweetyPieResult],
  tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
  userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
  uttEntityHydrationStore: UttEntityHydrationStore,
  topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
  globalStats: StatsReceiver)
    extends CandidateSource[Target, RawCandidate]
    with CandidateSourceEligible[Target, RawCandidate] {

  private val stats = globalStats.scope(this.getClass.getSimpleName)
  private val crtStats = stats.scope("CandidateDistribution")
  private val totalRequests = stats.counter("total_requests")

  // Candidate Distribution stats
  private val reverseAddressbookCounter = crtStats.counter("reverse_addressbook")
  private val forwardAddressbookCounter = crtStats.counter("forward_addressbook")
  private val frsTweetCounter = crtStats.counter("frs_tweet")
  private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
  private val crtToCounterMapping: Map[CommonRecommendationType, Counter] = Map(
    CommonRecommendationType.ReverseAddressbookTweet -> reverseAddressbookCounter,
    CommonRecommendationType.ForwardAddressbookTweet -> forwardAddressbookCounter,
    CommonRecommendationType.FrsTweet -> frsTweetCounter
  )

  private val emptyTweetyPieResult = stats.stat("empty_tweetypie_result")

  private[this] val numberReturnedCandidates = stats.stat("returned_candidates_from_earlybird")
  private[this] val numberCandidateWithTopic: Counter = stats.counter("num_can_with_topic")
  private[this] val numberCandidateWithoutTopic: Counter = stats.counter("num_can_without_topic")

  private val userTweetTweetyPieStoreCounter = stats.counter("user_tweet_tweetypie_store")

  override val name: String = this.getClass.getSimpleName

  private def filterInvalidTweets(
    tweetIds: Seq[Long],
    target: Target
  ): Future[Map[Long, TweetyPieResult]] = {
    val resMap = {
      if (target.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors)) {
        userTweetTweetyPieStoreCounter.incr()
        val keys = tweetIds.map { tweetId =>
          UserTweet(tweetId, Some(target.targetId))
        }
        userTweetTweetyPieStore
          .multiGet(keys.toSet).map {
            case (userTweet, resultFut) =>
              userTweet.tweetId -> resultFut
          }.toMap
      } else {
        (if (target.params(PushFeatureSwitchParams.EnableVFInTweetypie)) {
           tweetyPieStore
         } else {
           tweetyPieStoreNoVF
         }).multiGet(tweetIds.toSet)
      }
    }

    Future.collect(resMap).map { tweetyPieResultMap =>
      // Filter out replies and generate earlybird candidates only for non-empty tweetypie result
      val cands = filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
        case (id: Long, Some(result)) =>
          id -> result
      }

      emptyTweetyPieResult.add(tweetyPieResultMap.size - cands.size)
      cands
    }
  }

  private def buildRawCandidates(
    target: Target,
    ebCandidates: Seq[FRSTweetCandidate]
  ): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {

    val enableTopic = target.params(PushFeatureSwitchParams.EnableFrsTweetCandidatesTopicAnnotation)
    val topicScoreThre =
      target.params(PushFeatureSwitchParams.FrsTweetCandidatesTopicScoreThreshold)

    val ebTweets = ebCandidates.map { ebCandidate =>
      ebCandidate.tweetId -> ebCandidate.tweetyPieResult
    }.toMap

    val tweetIdLocalizedEntityMapFut = TopicsUtil.getTweetIdLocalizedEntityMap(
      target,
      ebTweets,
      uttEntityHydrationStore,
      topicSocialProofServiceStore,
      enableTopic,
      topicScoreThre
    )

    Future.join(target.deviceInfo, tweetIdLocalizedEntityMapFut).map {
      case (Some(deviceInfo), tweetIdLocalizedEntityMap) =>
        val candidates = ebCandidates
          .map { ebCandidate =>
            val crt = ebCandidate.commonRecType
            crtToCounterMapping.get(crt).foreach(_.incr())

            val tweetId = ebCandidate.tweetId
            val localizedEntityOpt = {
              if (tweetIdLocalizedEntityMap.contains(tweetId) && deviceInfo.isTopicsEligible) {
                tweetIdLocalizedEntityMap(tweetId)
              } else {
                None
              }
            }

            PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
              inputTarget = target,
              id = ebCandidate.tweetId,
              mediaCRT = MediaCRT(
                crt,
                crt,
                crt
              ),
              result = ebCandidate.tweetyPieResult,
              localizedEntity = localizedEntityOpt)
          }.filter { candidate =>
            // If user only has the topic setting enabled, filter out all non-topic cands
            deviceInfo.isRecommendationsEligible || (deviceInfo.isTopicsEligible && candidate.semanticCoreEntityId.nonEmpty)
          }

        candidates.map { candidate =>
          if (candidate.semanticCoreEntityId.nonEmpty) {
            numberCandidateWithTopic.incr()
          } else {
            numberCandidateWithoutTopic.incr()
          }
        }

        numberReturnedCandidates.add(candidates.length)
        Some(candidates)
      case _ => Some(Seq.empty)
    }
  }

  def getTweetCandidatesFromCrMixer(
    inputTarget: Target,
    showAllResultsFromFrs: Boolean,
  ): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
    Future
      .join(
        inputTarget.seenTweetIds,
        inputTarget.pushRecItems,
        inputTarget.countryCode,
        inputTarget.targetLanguage).flatMap {
        case (seenTweetIds, pastRecItems, countryCode, language) =>
          val pastUserRecs = pastRecItems.userIds.toSeq
          val request = FrsTweetRequest(
            clientContext = ClientContext(
              userId = Some(inputTarget.targetId),
              countryCode = countryCode,
              languageCode = language
            ),
            product = Product.Notifications,
            productContext = Some(ProductContext.NotificationsContext(NotificationsContext())),
            excludedUserIds = Some(pastUserRecs),
            excludedTweetIds = Some(seenTweetIds)
          )
          crMixerTweetStore.getFRSTweetCandidates(request).flatMap {
            case Some(response) =>
              val tweetIds = response.tweets.map(_.tweetId)
              val validTweets = filterInvalidTweets(tweetIds, inputTarget)
              validTweets.flatMap { tweetypieMap =>
                val ebCandidates = response.tweets
                  .map { frsTweet =>
                    val candidateTweetId = frsTweet.tweetId
                    val resultFromTweetyPie = tweetypieMap.get(candidateTweetId)
                    new FRSTweetCandidate {
                      override val tweetId = candidateTweetId
                      override val features = None
                      override val tweetyPieResult = resultFromTweetyPie
                      override val feedbackToken = frsTweet.frsPrimarySource
                      override val commonRecType: CommonRecommendationType = feedbackToken
                        .flatMap(token =>
                          FRSAlgorithmFeedbackTokenUtil.getCRTForAlgoToken(token)).getOrElse(
                          CommonRecommendationType.FrsTweet)
                    }
                  }.filter { ebCandidate =>
                    showAllResultsFromFrs || ebCandidate.commonRecType == CommonRecommendationType.ReverseAddressbookTweet
                  }

                numberReturnedCandidates.add(ebCandidates.length)
                buildRawCandidates(
                  inputTarget,
                  ebCandidates
                )
              }
            case _ => Future.None
          }
      }
  }

  override def get(inputTarget: Target): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
    totalRequests.incr()
    val enableResultsFromFrs =
      inputTarget.params(PushFeatureSwitchParams.EnableResultFromFrsCandidates)
    getTweetCandidatesFromCrMixer(inputTarget, enableResultsFromFrs)
  }

  override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
    lazy val enableFrsCandidates = target.params(PushFeatureSwitchParams.EnableFrsCandidates)
    PushDeviceUtil.isRecommendationsEligible(target).flatMap { isEnabledForRecosSetting =>
      PushDeviceUtil.isTopicsEligible(target).map { topicSettingEnabled =>
        val isEnabledForTopics =
          topicSettingEnabled && target.params(
            PushFeatureSwitchParams.EnableFrsTweetCandidatesTopicSetting)
        (isEnabledForRecosSetting || isEnabledForTopics) && enableFrsCandidates
      }
    }
  }
}
Binary file not shown.
@ -1,107 +0,0 @@
package com.twitter.frigate.pushservice.adaptor

import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base._
import com.twitter.frigate.common.candidate._
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future

object GenericCandidates {
  type Target =
    TargetUser
      with UserDetails
      with TargetDecider
      with TargetABDecider
      with TweetImpressionHistory
      with HTLVisitHistory
      with MaxTweetAge
      with NewUserDetails
      with FrigateHistory
      with TargetWithSeedUsers
}

case class GenericCandidateAdaptor(
  genericCandidates: CandidateSource[GenericCandidates.Target, Candidate],
  tweetyPieStore: ReadableStore[Long, TweetyPieResult],
  tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
  stats: StatsReceiver)
    extends CandidateSource[Target, RawCandidate]
    with CandidateSourceEligible[Target, RawCandidate] {

  override val name: String = genericCandidates.name

  private def generateTweetFavCandidate(
    _target: Target,
    _tweetId: Long,
    _socialContextActions: Seq[SocialContextAction],
    socialContextActionsAllTypes: Seq[SocialContextAction],
    _tweetyPieResult: Option[TweetyPieResult]
  ): RawCandidate = {
    new RawCandidate with TweetFavoriteCandidate {
      override val socialContextActions = _socialContextActions
      override val socialContextAllTypeActions =
        socialContextActionsAllTypes
      val tweetId = _tweetId
      val target = _target
      val tweetyPieResult = _tweetyPieResult
    }
  }

  private def generateTweetRetweetCandidate(
    _target: Target,
    _tweetId: Long,
    _socialContextActions: Seq[SocialContextAction],
    socialContextActionsAllTypes: Seq[SocialContextAction],
    _tweetyPieResult: Option[TweetyPieResult]
  ): RawCandidate = {
    new RawCandidate with TweetRetweetCandidate {
      override val socialContextActions = _socialContextActions
      override val socialContextAllTypeActions = socialContextActionsAllTypes
      val tweetId = _tweetId
      val target = _target
      val tweetyPieResult = _tweetyPieResult
    }
  }

  override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
    genericCandidates.get(inputTarget).map { candidatesOpt =>
      candidatesOpt
        .map { candidates =>
          val candidatesSeq =
            candidates.collect {
              case tweetRetweet: TweetRetweetCandidate
                  if inputTarget.params(PushParams.MRTweetRetweetRecsParam) =>
                generateTweetRetweetCandidate(
                  inputTarget,
                  tweetRetweet.tweetId,
                  tweetRetweet.socialContextActions,
                  tweetRetweet.socialContextAllTypeActions,
                  tweetRetweet.tweetyPieResult)
              case tweetFavorite: TweetFavoriteCandidate
                  if inputTarget.params(PushParams.MRTweetFavRecsParam) =>
                generateTweetFavCandidate(
                  inputTarget,
                  tweetFavorite.tweetId,
                  tweetFavorite.socialContextActions,
                  tweetFavorite.socialContextAllTypeActions,
                  tweetFavorite.tweetyPieResult)
            }
          candidatesSeq.foreach { candidate =>
            stats.counter(s"${candidate.commonRecType}_count").incr()
          }
          candidatesSeq
        }
    }
  }

  override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
    PushDeviceUtil.isRecommendationsEligible(target).map { isAvailable =>
      isAvailable && target.params(PushParams.GenericCandidateAdaptorDecider)
    }
  }
}
Binary file not shown.
@ -1,280 +0,0 @@
|
|||||||
package com.twitter.frigate.pushservice.adaptor
|
|
||||||
|
|
||||||
import com.twitter.finagle.stats.Stat
|
|
||||||
import com.twitter.finagle.stats.StatsReceiver
|
|
||||||
import com.twitter.frigate.common.base.CandidateSource
|
|
||||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
|
||||||
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
|
|
||||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
|
||||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
|
||||||
import com.twitter.frigate.pushservice.params.HighQualityCandidateGroupEnum
|
|
||||||
import com.twitter.frigate.pushservice.params.HighQualityCandidateGroupEnum._
|
|
||||||
import com.twitter.frigate.pushservice.params.PushConstants.targetUserAgeFeatureName
|
|
||||||
import com.twitter.frigate.pushservice.params.PushConstants.targetUserPreferredLanguage
|
|
||||||
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
|
|
||||||
import com.twitter.frigate.pushservice.predicate.TargetPredicates
|
|
||||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
|
||||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
|
||||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
|
||||||
import com.twitter.frigate.pushservice.util.TopicsUtil
|
|
||||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
|
||||||
import com.twitter.interests.thriftscala.InterestId.SemanticCore
|
|
||||||
import com.twitter.interests.thriftscala.UserInterests
|
|
||||||
import com.twitter.language.normalization.UserDisplayLanguage
|
|
||||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
|
||||||
import com.twitter.storehaus.ReadableStore
|
|
||||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
|
||||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweet
|
|
||||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
|
|
||||||
import com.twitter.util.Future
|
|
||||||
|
|
||||||
object HighQualityTweetsHelper {
|
|
||||||
def getFollowedTopics(
|
|
||||||
target: Target,
|
|
||||||
interestsWithLookupContextStore: ReadableStore[
|
|
||||||
InterestsLookupRequestWithContext,
|
|
||||||
UserInterests
|
|
||||||
],
|
|
||||||
followedTopicsStats: Stat
|
|
||||||
): Future[Seq[Long]] = {
|
|
||||||
TopicsUtil
|
|
||||||
.getTopicsFollowedByUser(target, interestsWithLookupContextStore, followedTopicsStats).map {
|
|
||||||
userInterestsOpt =>
|
|
||||||
val userInterests = userInterestsOpt.getOrElse(Seq.empty)
|
|
||||||
val extractedTopicIds = userInterests.flatMap {
|
|
||||||
_.interestId match {
|
|
||||||
case SemanticCore(semanticCore) => Some(semanticCore.id)
|
|
||||||
case _ => None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
extractedTopicIds
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def getTripQueries(
|
|
||||||
target: Target,
|
|
||||||
enabledGroups: Set[HighQualityCandidateGroupEnum.Value],
|
|
||||||
interestsWithLookupContextStore: ReadableStore[
|
|
||||||
InterestsLookupRequestWithContext,
|
|
||||||
UserInterests
|
|
||||||
],
|
|
||||||
sourceIds: Seq[String],
|
|
||||||
stat: Stat
|
|
||||||
): Future[Set[TripDomain]] = {
|
|
||||||
|
|
||||||
val followedTopicIdsSetFut: Future[Set[Long]] = if (enabledGroups.contains(Topic)) {
|
|
||||||
getFollowedTopics(target, interestsWithLookupContextStore, stat).map(topicIds =>
|
|
||||||
topicIds.toSet)
|
|
||||||
} else {
|
|
||||||
Future.value(Set.empty)
|
|
||||||
}
|
|
||||||
|
|
||||||
Future
|
|
||||||
.join(target.featureMap, target.inferredUserDeviceLanguage, followedTopicIdsSetFut).map {
|
|
||||||
case (
|
|
||||||
featureMap,
|
|
||||||
deviceLanguageOpt,
|
|
||||||
followedTopicIds
|
|
||||||
) =>
|
|
||||||
val ageBucketOpt = if (enabledGroups.contains(AgeBucket)) {
|
|
||||||
featureMap.categoricalFeatures.get(targetUserAgeFeatureName)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
val languageOptions: Set[Option[String]] = if (enabledGroups.contains(Language)) {
|
|
||||||
val userPreferredLanguages = featureMap.sparseBinaryFeatures
|
|
||||||
.getOrElse(targetUserPreferredLanguage, Set.empty[String])
|
|
||||||
if (userPreferredLanguages.nonEmpty) {
|
|
||||||
userPreferredLanguages.map(lang => Some(UserDisplayLanguage.toTweetLanguage(lang)))
|
|
||||||
} else {
|
|
||||||
Set(deviceLanguageOpt.map(UserDisplayLanguage.toTweetLanguage))
|
|
||||||
}
|
|
||||||
} else Set(None)
|
|
||||||
|
|
||||||
val followedTopicOptions: Set[Option[Long]] = if (followedTopicIds.nonEmpty) {
|
|
||||||
followedTopicIds.map(topic => Some(topic))
|
|
||||||
} else Set(None)
|
|
||||||
|
|
||||||
val tripQueries = followedTopicOptions.flatMap { topicOption =>
|
|
||||||
languageOptions.flatMap { languageOption =>
|
|
||||||
sourceIds.map { sourceId =>
|
|
||||||
TripDomain(
|
|
||||||
sourceId = sourceId,
|
|
||||||
language = languageOption,
|
|
||||||
placeId = None,
|
|
||||||
topicId = topicOption,
|
|
||||||
gender = None,
|
|
||||||
ageBucket = ageBucketOpt
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tripQueries
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|

case class HighQualityTweetsAdaptor(
  tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
  interestsWithLookupContextStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
  tweetyPieStore: ReadableStore[Long, TweetyPieResult],
  tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
  globalStats: StatsReceiver)
    extends CandidateSource[Target, RawCandidate]
    with CandidateSourceEligible[Target, RawCandidate] {

  override def name: String = this.getClass.getSimpleName

  private val stats = globalStats.scope("HighQualityCandidateAdaptor")
  private val followedTopicsStats = stats.stat("followed_topics")
  private val missingResponseCounter = stats.counter("missing_respond_counter")
  private val crtFatigueCounter = stats.counter("fatigue_by_crt")
  private val fallbackRequestsCounter = stats.counter("fallback_requests")

  override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
    PushDeviceUtil.isRecommendationsEligible(target).map {
      _ && target.params(FS.HighQualityCandidatesEnableCandidateSource)
    }
  }

  private val highQualityCandidateFrequencyPredicate = {
    TargetPredicates
      .pushRecTypeFatiguePredicate(
        CommonRecommendationType.TripHqTweet,
        FS.HighQualityTweetsPushInterval,
        FS.MaxHighQualityTweetsPushGivenInterval,
        stats
      )
  }

  private def getTripCandidatesStrato(
    target: Target
  ): Future[Map[Long, Set[TripDomain]]] = {
    val tripQueriesF: Future[Set[TripDomain]] = HighQualityTweetsHelper.getTripQueries(
      target = target,
      enabledGroups = target.params(FS.HighQualityCandidatesEnableGroups).toSet,
      interestsWithLookupContextStore = interestsWithLookupContextStore,
      sourceIds = target.params(FS.TripTweetCandidateSourceIds),
      stat = followedTopicsStats
    )

    lazy val fallbackTripQueriesFut: Future[Set[TripDomain]] =
      if (target.params(FS.HighQualityCandidatesEnableFallback))
        HighQualityTweetsHelper.getTripQueries(
          target = target,
          enabledGroups = target.params(FS.HighQualityCandidatesFallbackEnabledGroups).toSet,
          interestsWithLookupContextStore = interestsWithLookupContextStore,
          sourceIds = target.params(FS.HighQualityCandidatesFallbackSourceIds),
          stat = followedTopicsStats
        )
      else Future.value(Set.empty)

    val initialTweetsFut: Future[Map[TripDomain, Seq[TripTweet]]] = tripQueriesF.flatMap {
      tripQueries => getTripTweetsByDomains(tripQueries)
    }

    val tweetsByDomainFut: Future[Map[TripDomain, Seq[TripTweet]]] =
      if (target.params(FS.HighQualityCandidatesEnableFallback)) {
        initialTweetsFut.flatMap { candidates =>
          val minCandidatesForFallback: Int =
            target.params(FS.HighQualityCandidatesMinNumOfCandidatesToFallback)
          val validCandidates = candidates.filter(_._2.size >= minCandidatesForFallback)

          if (validCandidates.nonEmpty) {
            Future.value(validCandidates)
          } else {
            fallbackTripQueriesFut.flatMap { fallbackTripDomains =>
              fallbackRequestsCounter.incr(fallbackTripDomains.size)
              getTripTweetsByDomains(fallbackTripDomains)
            }
          }
        }
      } else {
        initialTweetsFut
      }

    val numOfCandidates: Int = target.params(FS.HighQualityCandidatesNumberOfCandidates)
    tweetsByDomainFut.map(tweetsByDomain => reformatDomainTweetMap(tweetsByDomain, numOfCandidates))
  }

  private def getTripTweetsByDomains(
    tripQueries: Set[TripDomain]
  ): Future[Map[TripDomain, Seq[TripTweet]]] = {
    Future.collect(tripTweetCandidateStore.multiGet(tripQueries)).map { response =>
      response
        .filter(p => p._2.exists(_.tweets.nonEmpty))
        .mapValues(_.map(_.tweets).getOrElse(Seq.empty))
    }
  }

  private def reformatDomainTweetMap(
    tweetsByDomain: Map[TripDomain, Seq[TripTweet]],
    numOfCandidates: Int
  ): Map[Long, Set[TripDomain]] = tweetsByDomain
    .flatMap {
      case (tripDomain, tripTweets) =>
        tripTweets
          .sortBy(_.score)(Ordering[Double].reverse)
          .take(numOfCandidates)
          .map { tweet => (tweet.tweetId, tripDomain) }
    }.groupBy(_._1).mapValues(_.map(_._2).toSet)

  private def buildRawCandidate(
    target: Target,
    tweetyPieResult: TweetyPieResult,
    tripDomain: Option[scala.collection.Set[TripDomain]]
  ): RawCandidate = {
    PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
      inputTarget = target,
      id = tweetyPieResult.tweet.id,
      mediaCRT = MediaCRT(
        CommonRecommendationType.TripHqTweet,
        CommonRecommendationType.TripHqTweet,
        CommonRecommendationType.TripHqTweet
      ),
      result = Some(tweetyPieResult),
      tripTweetDomain = tripDomain
    )
  }

  private def getTweetyPieResults(
    target: Target,
    tweetToTripDomain: Map[Long, Set[TripDomain]]
  ): Future[Map[Long, Option[TweetyPieResult]]] = {
    Future.collect((if (target.params(FS.EnableVFInTweetypie)) {
      tweetyPieStore
    } else {
      tweetyPieStoreNoVF
    }).multiGet(tweetToTripDomain.keySet))
  }

  override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
    for {
      tweetsToTripDomainMap <- getTripCandidatesStrato(target)
      tweetyPieResults <- getTweetyPieResults(target, tweetsToTripDomainMap)
    } yield {
      val candidates = tweetyPieResults.flatMap {
        case (tweetId, tweetyPieResultOpt) =>
          tweetyPieResultOpt.map(buildRawCandidate(target, _, tweetsToTripDomainMap.get(tweetId)))
      }
      if (candidates.nonEmpty) {
        // Note: the fatigue check below builds a Future whose result is discarded;
        // the candidates are returned regardless of its outcome.
        highQualityCandidateFrequencyPredicate(Seq(target))
          .map(_.head)
          .map { isTargetFatigueEligible =>
            if (isTargetFatigueEligible) Some(candidates)
            else {
              crtFatigueCounter.incr()
              None
            }
          }

        Some(candidates.toSeq)
      } else {
        missingResponseCounter.incr()
        None
      }
    }
  }
}
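Likewise, a small self-contained sketch (again not taken from the source) of the re-keying step that reformatDomainTweetMap performs: each domain's tweets are ranked by score, the top N are kept, and the map is inverted so that candidate tweet ids point back at the domains that produced them. Plain strings and (tweetId, score) pairs stand in for the thrift types purely for brevity.

// Illustration only: a String stands in for TripDomain, (tweetId, score) for TripTweet.
def topTweetsByDomain(
  tweetsByDomain: Map[String, Seq[(Long, Double)]],
  numOfCandidates: Int
): Map[Long, Set[String]] =
  tweetsByDomain.toSeq
    .flatMap {
      case (domain, tweets) =>
        tweets
          .sortBy { case (_, score) => score }(Ordering[Double].reverse) // best first
          .take(numOfCandidates)
          .map { case (tweetId, _) => tweetId -> domain }
    }
    .groupBy { case (tweetId, _) => tweetId }
    .map { case (tweetId, pairs) => tweetId -> pairs.map(_._2).toSet }

// Example: with numOfCandidates = 1 only the highest-scored tweet per domain survives.
// topTweetsByDomain(
//   Map("geo" -> Seq((1L, 0.9), (2L, 0.5)), "topic" -> Seq((3L, 0.8))),
//   numOfCandidates = 1
// ) == Map(1L -> Set("geo"), 3L -> Set("topic"))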
Binary file not shown.
@ -1,152 +0,0 @@
package com.twitter.frigate.pushservice.adaptor

import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.ListPushCandidate
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.predicate.TargetPredicates
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.geoduck.service.thriftscala.LocationResponse
import com.twitter.interests_discovery.thriftscala.DisplayLocation
import com.twitter.interests_discovery.thriftscala.NonPersonalizedRecommendedLists
import com.twitter.interests_discovery.thriftscala.RecommendedListsRequest
import com.twitter.interests_discovery.thriftscala.RecommendedListsResponse
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future

case class ListsToRecommendCandidateAdaptor(
  listRecommendationsStore: ReadableStore[String, NonPersonalizedRecommendedLists],
  geoDuckV2Store: ReadableStore[Long, LocationResponse],
  idsStore: ReadableStore[RecommendedListsRequest, RecommendedListsResponse],
  globalStats: StatsReceiver)
    extends CandidateSource[Target, RawCandidate]
    with CandidateSourceEligible[Target, RawCandidate] {

  override val name: String = this.getClass.getSimpleName

  private[this] val stats = globalStats.scope(name)
  private[this] val noLocationCodeCounter = stats.counter("no_location_code")
  private[this] val noCandidatesCounter = stats.counter("no_candidates_for_geo")
  private[this] val disablePopGeoListsCounter = stats.counter("disable_pop_geo_lists")
  private[this] val disableIDSListsCounter = stats.counter("disable_ids_lists")

  private def getListCandidate(
    targetUser: Target,
    _listId: Long
  ): RawCandidate with ListPushCandidate = {
    new RawCandidate with ListPushCandidate {
      override val listId: Long = _listId

      override val commonRecType: CommonRecommendationType = CommonRecommendationType.List

      override val target: Target = targetUser
    }
  }

  private def getListsRecommendedFromHistory(
    target: Target
  ): Future[Seq[Long]] = {
    target.history.map { history =>
      history.sortedHistory.flatMap {
        case (_, notif) if notif.commonRecommendationType == CommonRecommendationType.List =>
          notif.listNotification.map(_.listId)
        case _ => None
      }
    }
  }

  private def getIDSListRecs(
    target: Target,
    historicalListIds: Seq[Long]
  ): Future[Seq[Long]] = {
    val request = RecommendedListsRequest(
      target.targetId,
      DisplayLocation.ListDiscoveryPage,
      Some(historicalListIds)
    )
    if (target.params(PushFeatureSwitchParams.EnableIDSListRecommendations)) {
      idsStore.get(request).map {
        case Some(response) =>
          response.channels.map(_.id)
        case _ => Nil
      }
    } else {
      disableIDSListsCounter.incr()
      Future.Nil
    }
  }

  private def getPopGeoLists(
    target: Target,
    historicalListIds: Seq[Long]
  ): Future[Seq[Long]] = {
    if (target.params(PushFeatureSwitchParams.EnablePopGeoListRecommendations)) {
      geoDuckV2Store.get(target.targetId).flatMap {
        case Some(locationResponse) if locationResponse.geohash.isDefined =>
          val geoHashLength =
            target.params(PushFeatureSwitchParams.ListRecommendationsGeoHashLength)
          val geoHash = locationResponse.geohash.get.take(geoHashLength)
          listRecommendationsStore
            .get(s"geohash_$geoHash")
            .map {
              case Some(recommendedLists) =>
                recommendedLists.recommendedListsByAlgo.flatMap { topLists =>
                  topLists.lists.collect {
                    case list if !historicalListIds.contains(list.listId) => list.listId
                  }
                }
              case _ => Nil
            }
        case _ =>
          noLocationCodeCounter.incr()
          Future.Nil
      }
    } else {
      disablePopGeoListsCounter.incr()
      Future.Nil
    }
  }

  override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
    getListsRecommendedFromHistory(target).flatMap { historicalListIds =>
      Future
        .join(
          getPopGeoLists(target, historicalListIds),
          getIDSListRecs(target, historicalListIds)
        )
        .map {
          case (popGeoListsIds, idsListIds) =>
            val candidates = (idsListIds ++ popGeoListsIds).map(getListCandidate(target, _))
            Some(candidates)
          case _ =>
            noCandidatesCounter.incr()
            None
        }
    }
  }

  private val pushCapFatiguePredicate = TargetPredicates.pushRecTypeFatiguePredicate(
    CommonRecommendationType.List,
    PushFeatureSwitchParams.ListRecommendationsPushInterval,
    PushFeatureSwitchParams.MaxListRecommendationsPushGivenInterval,
    stats,
  )
  override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {

    val isNotFatigued = pushCapFatiguePredicate.apply(Seq(target)).map(_.head)

    Future
      .join(
        PushDeviceUtil.isRecommendationsEligible(target),
        isNotFatigued
      ).map {
        case (userRecommendationsEligible, isUnderCAP) =>
          userRecommendationsEligible && isUnderCAP && target.params(
            PushFeatureSwitchParams.EnableListRecommendations)
      }
  }
}
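A rough sketch of the pop-geo path above, under stated assumptions: a plain in-memory Map stands in for listRecommendationsStore, and popGeoListIds is a name introduced only for this illustration. It shows the two moving parts of getPopGeoLists: keying the lookup by a geohash truncated to the configured precision (the s"geohash_..." key format mirrors the code above), and filtering out lists the user was already notified about.

// Illustration only: an in-memory map stands in for the ReadableStore.
val recommendedListsByGeohashPrefix: Map[String, Seq[Long]] = Map(
  "geohash_dr5r" -> Seq(100L, 101L, 102L)
)

def popGeoListIds(
  geohash: Option[String],
  geoHashLength: Int,
  historicalListIds: Seq[Long]
): Seq[Long] =
  geohash match {
    case Some(hash) =>
      val key = s"geohash_${hash.take(geoHashLength)}" // truncate to the configured precision
      recommendedListsByGeohashPrefix
        .getOrElse(key, Seq.empty)
        .filterNot(historicalListIds.contains) // drop lists the user was already pushed
    case None => Seq.empty
  }

// popGeoListIds(Some("dr5regw3"), 4, Seq(101L)) == Seq(100L, 102L)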
Binary file not shown.
@ -1,54 +0,0 @@
package com.twitter.frigate.pushservice.adaptor

import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.geoduck.service.thriftscala.LocationResponse
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
import com.twitter.content_mixer.thriftscala.ContentMixerRequest
import com.twitter.content_mixer.thriftscala.ContentMixerResponse
import com.twitter.geoduck.common.thriftscala.Location
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain

class LoggedOutPushCandidateSourceGenerator(
  tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
  geoDuckV2Store: ReadableStore[Long, LocationResponse],
  safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
  cachedTweetyPieStoreV2NoVF: ReadableStore[Long, TweetyPieResult],
  cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
  contentMixerStore: ReadableStore[ContentMixerRequest, ContentMixerResponse],
  softUserLocationStore: ReadableStore[Long, Location],
  topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[(Long, Double)]]],
  topTweetsByGeoV2VersionedStore: ReadableStore[String, PopTweetsInPlace],
)(
  implicit val globalStats: StatsReceiver) {
  val sources: Seq[CandidateSource[Target, RawCandidate] with CandidateSourceEligible[
    Target,
    RawCandidate
  ]] = {
    Seq(
      TripGeoCandidatesAdaptor(
        tripTweetCandidateStore,
        contentMixerStore,
        safeCachedTweetyPieStoreV2,
        cachedTweetyPieStoreV2NoVF,
        globalStats
      ),
      TopTweetsByGeoAdaptor(
        geoDuckV2Store,
        softUserLocationStore,
        topTweetsByGeoStore,
        topTweetsByGeoV2VersionedStore,
        cachedTweetyPieStoreV2,
        cachedTweetyPieStoreV2NoVF,
        globalStats
      )
    )
  }
}
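For context, a hypothetical driver loop (not in this file; the real pushservice wires these sources up elsewhere) showing how a Seq of candidate sources like the one above is typically consumed: check availability per target, fetch only from the available sources, and flatten the results. The Source trait is a simplified stand-in for CandidateSource with CandidateSourceEligible.

import com.twitter.util.Future

// Simplified stand-in for CandidateSource / CandidateSourceEligible (illustration only).
trait Source[Target, Candidate] {
  def isCandidateSourceAvailable(target: Target): Future[Boolean]
  def get(target: Target): Future[Option[Seq[Candidate]]]
}

// Check availability per source, fetch from the available ones, and flatten the results.
def collectCandidates[Target, Candidate](
  sources: Seq[Source[Target, Candidate]],
  target: Target
): Future[Seq[Candidate]] =
  Future
    .collect(sources.map { source =>
      source.isCandidateSourceAvailable(target).flatMap {
        case true  => source.get(target).map(_.getOrElse(Seq.empty))
        case false => Future.value(Seq.empty[Candidate])
      }
    })
    .map(_.flatten)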
Some files were not shown because too many files have changed in this diff.