mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-11 19:59:10 +01:00
[docx] split commit for file 3400
Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>
This commit is contained in:
parent
f3c5ff35cb
commit
9ebf058832
Binary file not shown.
@ -1,281 +0,0 @@
|
||||
package com.twitter.product_mixer.shared_library.observer
|
||||
|
||||
import com.twitter.finagle.stats.Counter
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.ArrowObserver
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.FunctionObserver
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.FutureObserver
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.Observer
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.StitchObserver
|
||||
import com.twitter.stitch.Arrow
|
||||
import com.twitter.stitch.Stitch
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Try
|
||||
|
||||
/**
|
||||
* Helper functions to observe requests, successes, failures, cancellations, exceptions, latency,
|
||||
* and result counts. Supports native functions and asynchronous operations.
|
||||
*/
|
||||
object ResultsObserver {
|
||||
val Total = "total"
|
||||
val Found = "found"
|
||||
val NotFound = "not_found"
|
||||
|
||||
/**
|
||||
* Helper function to observe a stitch and result counts
|
||||
*
|
||||
* @see [[StitchResultsObserver]]
|
||||
*/
|
||||
def stitchResults[T](
|
||||
size: T => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): StitchResultsObserver[T] = {
|
||||
new StitchResultsObserver[T](size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe a stitch and traversable (e.g. Seq, Set) result counts
|
||||
*
|
||||
* @see [[StitchResultsObserver]]
|
||||
*/
|
||||
def stitchResults[T <: TraversableOnce[_]](
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): StitchResultsObserver[T] = {
|
||||
new StitchResultsObserver[T](_.size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe an arrow and result counts
|
||||
*
|
||||
* @see [[ArrowResultsObserver]]
|
||||
*/
|
||||
def arrowResults[In, Out](
|
||||
size: Out => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): ArrowResultsObserver[In, Out] = {
|
||||
new ArrowResultsObserver[In, Out](size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe an arrow and traversable (e.g. Seq, Set) result counts
|
||||
*
|
||||
* @see [[ArrowResultsObserver]]
|
||||
*/
|
||||
def arrowResults[In, Out <: TraversableOnce[_]](
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): ArrowResultsObserver[In, Out] = {
|
||||
new ArrowResultsObserver[In, Out](_.size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe an arrow and result counts
|
||||
*
|
||||
* @see [[TransformingArrowResultsObserver]]
|
||||
*/
|
||||
def transformingArrowResults[In, Out, Transformed](
|
||||
transformer: Out => Try[Transformed],
|
||||
size: Transformed => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): TransformingArrowResultsObserver[In, Out, Transformed] = {
|
||||
new TransformingArrowResultsObserver[In, Out, Transformed](
|
||||
transformer,
|
||||
size,
|
||||
statsReceiver,
|
||||
scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe an arrow and traversable (e.g. Seq, Set) result counts
|
||||
*
|
||||
* @see [[TransformingArrowResultsObserver]]
|
||||
*/
|
||||
def transformingArrowResults[In, Out, Transformed <: TraversableOnce[_]](
|
||||
transformer: Out => Try[Transformed],
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): TransformingArrowResultsObserver[In, Out, Transformed] = {
|
||||
new TransformingArrowResultsObserver[In, Out, Transformed](
|
||||
transformer,
|
||||
_.size,
|
||||
statsReceiver,
|
||||
scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe a future and result counts
|
||||
*
|
||||
* @see [[FutureResultsObserver]]
|
||||
*/
|
||||
def futureResults[T](
|
||||
size: T => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): FutureResultsObserver[T] = {
|
||||
new FutureResultsObserver[T](size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe a future and traversable (e.g. Seq, Set) result counts
|
||||
*
|
||||
* @see [[FutureResultsObserver]]
|
||||
*/
|
||||
def futureResults[T <: TraversableOnce[_]](
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): FutureResultsObserver[T] = {
|
||||
new FutureResultsObserver[T](_.size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe a function and result counts
|
||||
*
|
||||
* @see [[FunctionResultsObserver]]
|
||||
*/
|
||||
def functionResults[T](
|
||||
size: T => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): FunctionResultsObserver[T] = {
|
||||
new FunctionResultsObserver[T](size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe a function and traversable (e.g. Seq, Set) result counts
|
||||
*
|
||||
* @see [[FunctionResultsObserver]]
|
||||
*/
|
||||
def functionResults[T <: TraversableOnce[_]](
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): FunctionResultsObserver[T] = {
|
||||
new FunctionResultsObserver[T](_.size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/** [[StitchObserver]] that also records result size */
|
||||
class StitchResultsObserver[T](
|
||||
override val size: T => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends StitchObserver[T](statsReceiver, scopes)
|
||||
with ResultsObserver[T] {
|
||||
|
||||
override def apply(stitch: => Stitch[T]): Stitch[T] =
|
||||
super
|
||||
.apply(stitch)
|
||||
.onSuccess(observeResults)
|
||||
}
|
||||
|
||||
/** [[ArrowObserver]] that also records result size */
|
||||
class ArrowResultsObserver[In, Out](
|
||||
override val size: Out => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends ArrowObserver[In, Out](statsReceiver, scopes)
|
||||
with ResultsObserver[Out] {
|
||||
|
||||
override def apply(arrow: Arrow[In, Out]): Arrow[In, Out] =
|
||||
super
|
||||
.apply(arrow)
|
||||
.onSuccess(observeResults)
|
||||
}
|
||||
|
||||
/**
|
||||
* [[TransformingArrowResultsObserver]] functions like an [[ArrowObserver]] except
|
||||
* that it transforms the result using [[transformer]] before recording stats.
|
||||
*
|
||||
* The original non-transformed result is then returned.
|
||||
*/
|
||||
class TransformingArrowResultsObserver[In, Out, Transformed](
|
||||
val transformer: Out => Try[Transformed],
|
||||
override val size: Transformed => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends Observer[Transformed]
|
||||
with ResultsObserver[Transformed] {
|
||||
|
||||
/**
|
||||
* Returns a new Arrow that records stats on the result after applying [[transformer]] when it's run.
|
||||
* The original, non-transformed, result of the Arrow is passed through.
|
||||
*
|
||||
* @note the provided Arrow must contain the parts that need to be timed.
|
||||
* Using this on just the result of the computation the latency stat
|
||||
* will be incorrect.
|
||||
*/
|
||||
def apply(arrow: Arrow[In, Out]): Arrow[In, Out] = {
|
||||
Arrow
|
||||
.time(arrow)
|
||||
.map {
|
||||
case (response, stitchRunDuration) =>
|
||||
observe(response.flatMap(transformer), stitchRunDuration)
|
||||
.onSuccess(observeResults)
|
||||
response
|
||||
}.lowerFromTry
|
||||
}
|
||||
}
|
||||
|
||||
/** [[FutureObserver]] that also records result size */
|
||||
class FutureResultsObserver[T](
|
||||
override val size: T => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends FutureObserver[T](statsReceiver, scopes)
|
||||
with ResultsObserver[T] {
|
||||
|
||||
override def apply(future: => Future[T]): Future[T] =
|
||||
super
|
||||
.apply(future)
|
||||
.onSuccess(observeResults)
|
||||
}
|
||||
|
||||
/** [[FunctionObserver]] that also records result size */
|
||||
class FunctionResultsObserver[T](
|
||||
override val size: T => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends FunctionObserver[T](statsReceiver, scopes)
|
||||
with ResultsObserver[T] {
|
||||
|
||||
override def apply(f: => T): T = observeResults(super.apply(f))
|
||||
}
|
||||
|
||||
/** [[ResultsObserver]] provides methods for recording stats for the result size */
|
||||
trait ResultsObserver[T] {
|
||||
protected val statsReceiver: StatsReceiver
|
||||
|
||||
/** Scopes that prefix all stats */
|
||||
protected val scopes: Seq[String]
|
||||
|
||||
protected val totalCounter: Counter = statsReceiver.counter(scopes :+ Total: _*)
|
||||
protected val foundCounter: Counter = statsReceiver.counter(scopes :+ Found: _*)
|
||||
protected val notFoundCounter: Counter = statsReceiver.counter(scopes :+ NotFound: _*)
|
||||
|
||||
/** given a [[T]] returns it's size. */
|
||||
protected val size: T => Int
|
||||
|
||||
/** Records the size of the `results` using [[size]] and return the original value. */
|
||||
protected def observeResults(results: T): T = {
|
||||
val resultsSize = size(results)
|
||||
observeResultsWithSize(results, resultsSize)
|
||||
}
|
||||
|
||||
/**
|
||||
* Records the `resultsSize` and returns the `results`
|
||||
*
|
||||
* This is useful if the size is already available and is expensive to calculate.
|
||||
*/
|
||||
protected def observeResultsWithSize(results: T, resultsSize: Int): T = {
|
||||
if (resultsSize > 0) {
|
||||
totalCounter.incr(resultsSize)
|
||||
foundCounter.incr()
|
||||
} else {
|
||||
notFoundCounter.incr()
|
||||
}
|
||||
results
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,243 +0,0 @@
|
||||
package com.twitter.product_mixer.shared_library.observer
|
||||
|
||||
import com.twitter.finagle.stats.Stat
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.ArrowObserver
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.FunctionObserver
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.FutureObserver
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.Observer
|
||||
import com.twitter.product_mixer.shared_library.observer.Observer.StitchObserver
|
||||
import com.twitter.product_mixer.shared_library.observer.ResultsObserver.ResultsObserver
|
||||
import com.twitter.stitch.Arrow
|
||||
import com.twitter.stitch.Stitch
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Try
|
||||
|
||||
/**
|
||||
* Helper functions to observe requests, successes, failures, cancellations, exceptions, latency,
|
||||
* and result counts and time-series stats. Supports native functions and asynchronous operations.
|
||||
*
|
||||
* Note that since time-series stats are expensive to compute (relative to counters), prefer
|
||||
* [[ResultsObserver]] unless a time-series stat is needed.
|
||||
*/
|
||||
object ResultsStatsObserver {
|
||||
val Size = "size"
|
||||
|
||||
/**
|
||||
* Helper function to observe a stitch and result counts and time-series stats
|
||||
*/
|
||||
def stitchResultsStats[T](
|
||||
size: T => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): StitchResultsStatsObserver[T] = {
|
||||
new StitchResultsStatsObserver[T](size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe a stitch and traversable (e.g. Seq, Set) result counts and
|
||||
* time-series stats
|
||||
*/
|
||||
def stitchResultsStats[T <: TraversableOnce[_]](
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): StitchResultsStatsObserver[T] = {
|
||||
new StitchResultsStatsObserver[T](_.size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe an arrow and result counts and time-series stats
|
||||
*/
|
||||
def arrowResultsStats[T, U](
|
||||
size: U => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): ArrowResultsStatsObserver[T, U] = {
|
||||
new ArrowResultsStatsObserver[T, U](size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe an arrow and traversable (e.g. Seq, Set) result counts and
|
||||
* * time-series stats
|
||||
*/
|
||||
def arrowResultsStats[T, U <: TraversableOnce[_]](
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): ArrowResultsStatsObserver[T, U] = {
|
||||
new ArrowResultsStatsObserver[T, U](_.size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe an arrow and result counts
|
||||
*
|
||||
* @see [[TransformingArrowResultsStatsObserver]]
|
||||
*/
|
||||
def transformingArrowResultsStats[In, Out, Transformed](
|
||||
transformer: Out => Try[Transformed],
|
||||
size: Transformed => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): TransformingArrowResultsStatsObserver[In, Out, Transformed] = {
|
||||
new TransformingArrowResultsStatsObserver[In, Out, Transformed](
|
||||
transformer,
|
||||
size,
|
||||
statsReceiver,
|
||||
scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe an arrow and traversable (e.g. Seq, Set) result counts
|
||||
*
|
||||
* @see [[TransformingArrowResultsStatsObserver]]
|
||||
*/
|
||||
def transformingArrowResultsStats[In, Out, Transformed <: TraversableOnce[_]](
|
||||
transformer: Out => Try[Transformed],
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): TransformingArrowResultsStatsObserver[In, Out, Transformed] = {
|
||||
new TransformingArrowResultsStatsObserver[In, Out, Transformed](
|
||||
transformer,
|
||||
_.size,
|
||||
statsReceiver,
|
||||
scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe a future and result counts and time-series stats
|
||||
*/
|
||||
def futureResultsStats[T](
|
||||
size: T => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): FutureResultsStatsObserver[T] = {
|
||||
new FutureResultsStatsObserver[T](size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to observe a future and traversable (e.g. Seq, Set) result counts and
|
||||
* time-series stats
|
||||
*/
|
||||
def futureResultsStats[T <: TraversableOnce[_]](
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): FutureResultsStatsObserver[T] = {
|
||||
new FutureResultsStatsObserver[T](_.size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function observe a function and result counts and time-series stats
|
||||
*/
|
||||
def functionResultsStats[T](
|
||||
size: T => Int,
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): FunctionResultsStatsObserver[T] = {
|
||||
new FunctionResultsStatsObserver[T](size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function observe a function and traversable (e.g. Seq, Set) result counts and
|
||||
* time-series stats
|
||||
*/
|
||||
def functionResultsStats[T <: TraversableOnce[_]](
|
||||
statsReceiver: StatsReceiver,
|
||||
scopes: String*
|
||||
): FunctionResultsStatsObserver[T] = {
|
||||
new FunctionResultsStatsObserver[T](_.size, statsReceiver, scopes)
|
||||
}
|
||||
|
||||
class StitchResultsStatsObserver[T](
|
||||
override val size: T => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends StitchObserver[T](statsReceiver, scopes)
|
||||
with ResultsStatsObserver[T] {
|
||||
|
||||
override def apply(stitch: => Stitch[T]): Stitch[T] =
|
||||
super
|
||||
.apply(stitch)
|
||||
.onSuccess(observeResults)
|
||||
}
|
||||
|
||||
class ArrowResultsStatsObserver[T, U](
|
||||
override val size: U => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends ArrowObserver[T, U](statsReceiver, scopes)
|
||||
with ResultsStatsObserver[U] {
|
||||
|
||||
override def apply(arrow: Arrow[T, U]): Arrow[T, U] =
|
||||
super
|
||||
.apply(arrow)
|
||||
.onSuccess(observeResults)
|
||||
}
|
||||
|
||||
/**
|
||||
* [[TransformingArrowResultsStatsObserver]] functions like an [[ArrowObserver]] except
|
||||
* that it transforms the result using [[transformer]] before recording stats.
|
||||
*
|
||||
* The original non-transformed result is then returned.
|
||||
*/
|
||||
class TransformingArrowResultsStatsObserver[In, Out, Transformed](
|
||||
val transformer: Out => Try[Transformed],
|
||||
override val size: Transformed => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends Observer[Transformed]
|
||||
with ResultsStatsObserver[Transformed] {
|
||||
|
||||
/**
|
||||
* Returns a new Arrow that records stats on the result after applying [[transformer]] when it's run.
|
||||
* The original, non-transformed, result of the Arrow is passed through.
|
||||
*
|
||||
* @note the provided Arrow must contain the parts that need to be timed.
|
||||
* Using this on just the result of the computation the latency stat
|
||||
* will be incorrect.
|
||||
*/
|
||||
def apply(arrow: Arrow[In, Out]): Arrow[In, Out] = {
|
||||
Arrow
|
||||
.time(arrow)
|
||||
.map {
|
||||
case (response, stitchRunDuration) =>
|
||||
observe(response.flatMap(transformer), stitchRunDuration)
|
||||
.onSuccess(observeResults)
|
||||
response
|
||||
}.lowerFromTry
|
||||
}
|
||||
}
|
||||
|
||||
class FutureResultsStatsObserver[T](
|
||||
override val size: T => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends FutureObserver[T](statsReceiver, scopes)
|
||||
with ResultsStatsObserver[T] {
|
||||
|
||||
override def apply(future: => Future[T]): Future[T] =
|
||||
super
|
||||
.apply(future)
|
||||
.onSuccess(observeResults)
|
||||
}
|
||||
|
||||
class FunctionResultsStatsObserver[T](
|
||||
override val size: T => Int,
|
||||
override val statsReceiver: StatsReceiver,
|
||||
override val scopes: Seq[String])
|
||||
extends FunctionObserver[T](statsReceiver, scopes)
|
||||
with ResultsStatsObserver[T] {
|
||||
|
||||
override def apply(f: => T): T = {
|
||||
observeResults(super.apply(f))
|
||||
}
|
||||
}
|
||||
|
||||
trait ResultsStatsObserver[T] extends ResultsObserver[T] {
|
||||
private val sizeStat: Stat = statsReceiver.stat(scopes :+ Size: _*)
|
||||
|
||||
protected override def observeResults(results: T): T = {
|
||||
val resultsSize = size(results)
|
||||
sizeStat.add(resultsSize)
|
||||
observeResultsWithSize(results, resultsSize)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,20 +0,0 @@
|
||||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"finagle/finagle-core/src/main",
|
||||
"finagle/finagle-thriftmux/src/main/scala",
|
||||
"finatra-internal/mtls-http/src/main/scala",
|
||||
"finatra-internal/mtls-thriftmux/src/main/scala",
|
||||
"util/util-core",
|
||||
],
|
||||
exports = [
|
||||
"finagle/finagle-core/src/main",
|
||||
"finagle/finagle-thriftmux/src/main/scala",
|
||||
"finatra-internal/mtls-http/src/main/scala",
|
||||
"finatra-internal/mtls-thriftmux/src/main/scala",
|
||||
"util/util-core",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
@ -1,198 +0,0 @@
|
||||
package com.twitter.product_mixer.shared_library.thrift_client
|
||||
|
||||
import com.twitter.conversions.DurationOps._
|
||||
import com.twitter.finagle.ThriftMux
|
||||
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
|
||||
import com.twitter.finagle.mtls.client.MtlsStackClient._
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.finagle.thrift.ClientId
|
||||
import com.twitter.finagle.thrift.service.Filterable
|
||||
import com.twitter.finagle.thrift.service.MethodPerEndpointBuilder
|
||||
import com.twitter.finagle.thrift.service.ServicePerEndpointBuilder
|
||||
import com.twitter.finagle.thriftmux.MethodBuilder
|
||||
import com.twitter.util.Duration
|
||||
import org.apache.thrift.protocol.TProtocolFactory
|
||||
|
||||
sealed trait Idempotency
|
||||
case object NonIdempotent extends Idempotency
|
||||
case class Idempotent(maxExtraLoadPercent: Double) extends Idempotency
|
||||
|
||||
object FinagleThriftClientBuilder {
|
||||
|
||||
/**
|
||||
* Library to build a Finagle Thrift method per endpoint client is a less error-prone way when
|
||||
* compared to the builders in Finagle. This is achieved by requiring values for fields that should
|
||||
* always be set in practice. For example, request timeouts in Finagle are unbounded when not
|
||||
* explicitly set, and this method requires that timeout durations are passed into the method and
|
||||
* set on the Finagle builder.
|
||||
*
|
||||
* Usage of
|
||||
* [[com.twitter.inject.thrift.modules.ThriftMethodBuilderClientModule]] is almost always preferred,
|
||||
* and the Product Mixer component library [[com.twitter.product_mixer.component_library.module]]
|
||||
* package contains numerous examples. However, if multiple versions of a client are needed e.g.
|
||||
* for different timeout settings, this method is useful to easily provide multiple variants.
|
||||
*
|
||||
* @example
|
||||
* {{{
|
||||
* final val SampleServiceClientName = "SampleServiceClient"
|
||||
* @Provides
|
||||
* @Singleton
|
||||
* @Named(SampleServiceClientName)
|
||||
* def provideSampleServiceClient(
|
||||
* serviceIdentifier: ServiceIdentifier,
|
||||
* clientId: ClientId,
|
||||
* statsReceiver: StatsReceiver,
|
||||
* ): SampleService.MethodPerEndpoint =
|
||||
* buildFinagleMethodPerEndpoint[SampleService.ServicePerEndpoint, SampleService.MethodPerEndpoint](
|
||||
* serviceIdentifier = serviceIdentifier,
|
||||
* clientId = clientId,
|
||||
* dest = "/s/sample/sample",
|
||||
* label = "sample",
|
||||
* statsReceiver = statsReceiver,
|
||||
* idempotency = Idempotent(5.percent),
|
||||
* timeoutPerRequest = 200.milliseconds,
|
||||
* timeoutTotal = 400.milliseconds
|
||||
* )
|
||||
* }}}
|
||||
* @param serviceIdentifier Service ID used to S2S Auth
|
||||
* @param clientId Client ID
|
||||
* @param dest Destination as a Wily path e.g. "/s/sample/sample"
|
||||
* @param label Label of the client
|
||||
* @param statsReceiver Stats
|
||||
* @param idempotency Idempotency semantics of the client
|
||||
* @param timeoutPerRequest Thrift client timeout per request. The Finagle default is
|
||||
* unbounded which is almost never optimal.
|
||||
* @param timeoutTotal Thrift client total timeout. The Finagle default is
|
||||
* unbounded which is almost never optimal.
|
||||
* If the client is set as idempotent, which adds a
|
||||
* [[com.twitter.finagle.client.BackupRequestFilter]],
|
||||
* be sure to leave enough room for the backup request. A
|
||||
* reasonable (albeit usually too large) starting point is to
|
||||
* make the total timeout 2x relative to the per request timeout.
|
||||
* If the client is set as non-idempotent, the total timeout and
|
||||
* the per request timeout should be the same, as there will be
|
||||
* no backup requests.
|
||||
* @param connectTimeout Thrift client transport connect timeout. The Finagle default
|
||||
* of one second is reasonable but we lower this to match
|
||||
* acquisitionTimeout for consistency.
|
||||
* @param acquisitionTimeout Thrift client session acquisition timeout. The Finagle default
|
||||
* is unbounded which is almost never optimal.
|
||||
* @param protocolFactoryOverride Override the default protocol factory
|
||||
* e.g. [[org.apache.thrift.protocol.TCompactProtocol.Factory]]
|
||||
* @param servicePerEndpointBuilder implicit service per endpoint builder
|
||||
* @param methodPerEndpointBuilder implicit method per endpoint builder
|
||||
*
|
||||
* @see [[https://twitter.github.io/finagle/guide/MethodBuilder.html user guide]]
|
||||
* @see [[https://twitter.github.io/finagle/guide/MethodBuilder.html#idempotency user guide]]
|
||||
* @return method per endpoint Finagle Thrift Client
|
||||
*/
|
||||
def buildFinagleMethodPerEndpoint[
|
||||
ServicePerEndpoint <: Filterable[ServicePerEndpoint],
|
||||
MethodPerEndpoint
|
||||
](
|
||||
serviceIdentifier: ServiceIdentifier,
|
||||
clientId: ClientId,
|
||||
dest: String,
|
||||
label: String,
|
||||
statsReceiver: StatsReceiver,
|
||||
idempotency: Idempotency,
|
||||
timeoutPerRequest: Duration,
|
||||
timeoutTotal: Duration,
|
||||
connectTimeout: Duration = 500.milliseconds,
|
||||
acquisitionTimeout: Duration = 500.milliseconds,
|
||||
protocolFactoryOverride: Option[TProtocolFactory] = None,
|
||||
)(
|
||||
implicit servicePerEndpointBuilder: ServicePerEndpointBuilder[ServicePerEndpoint],
|
||||
methodPerEndpointBuilder: MethodPerEndpointBuilder[ServicePerEndpoint, MethodPerEndpoint]
|
||||
): MethodPerEndpoint = {
|
||||
val service: ServicePerEndpoint = buildFinagleServicePerEndpoint(
|
||||
serviceIdentifier = serviceIdentifier,
|
||||
clientId = clientId,
|
||||
dest = dest,
|
||||
label = label,
|
||||
statsReceiver = statsReceiver,
|
||||
idempotency = idempotency,
|
||||
timeoutPerRequest = timeoutPerRequest,
|
||||
timeoutTotal = timeoutTotal,
|
||||
connectTimeout = connectTimeout,
|
||||
acquisitionTimeout = acquisitionTimeout,
|
||||
protocolFactoryOverride = protocolFactoryOverride
|
||||
)
|
||||
|
||||
ThriftMux.Client.methodPerEndpoint(service)
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a Finagle Thrift service per endpoint client.
|
||||
*
|
||||
* @note [[buildFinagleMethodPerEndpoint]] should be preferred over the service per endpoint variant
|
||||
*
|
||||
* @param serviceIdentifier Service ID used to S2S Auth
|
||||
* @param clientId Client ID
|
||||
* @param dest Destination as a Wily path e.g. "/s/sample/sample"
|
||||
* @param label Label of the client
|
||||
* @param statsReceiver Stats
|
||||
* @param idempotency Idempotency semantics of the client
|
||||
* @param timeoutPerRequest Thrift client timeout per request. The Finagle default is
|
||||
* unbounded which is almost never optimal.
|
||||
* @param timeoutTotal Thrift client total timeout. The Finagle default is
|
||||
* unbounded which is almost never optimal.
|
||||
* If the client is set as idempotent, which adds a
|
||||
* [[com.twitter.finagle.client.BackupRequestFilter]],
|
||||
* be sure to leave enough room for the backup request. A
|
||||
* reasonable (albeit usually too large) starting point is to
|
||||
* make the total timeout 2x relative to the per request timeout.
|
||||
* If the client is set as non-idempotent, the total timeout and
|
||||
* the per request timeout should be the same, as there will be
|
||||
* no backup requests.
|
||||
* @param connectTimeout Thrift client transport connect timeout. The Finagle default
|
||||
* of one second is reasonable but we lower this to match
|
||||
* acquisitionTimeout for consistency.
|
||||
* @param acquisitionTimeout Thrift client session acquisition timeout. The Finagle default
|
||||
* is unbounded which is almost never optimal.
|
||||
* @param protocolFactoryOverride Override the default protocol factory
|
||||
* e.g. [[org.apache.thrift.protocol.TCompactProtocol.Factory]]
|
||||
*
|
||||
* @return service per endpoint Finagle Thrift Client
|
||||
*/
|
||||
def buildFinagleServicePerEndpoint[ServicePerEndpoint <: Filterable[ServicePerEndpoint]](
|
||||
serviceIdentifier: ServiceIdentifier,
|
||||
clientId: ClientId,
|
||||
dest: String,
|
||||
label: String,
|
||||
statsReceiver: StatsReceiver,
|
||||
idempotency: Idempotency,
|
||||
timeoutPerRequest: Duration,
|
||||
timeoutTotal: Duration,
|
||||
connectTimeout: Duration = 500.milliseconds,
|
||||
acquisitionTimeout: Duration = 500.milliseconds,
|
||||
protocolFactoryOverride: Option[TProtocolFactory] = None,
|
||||
)(
|
||||
implicit servicePerEndpointBuilder: ServicePerEndpointBuilder[ServicePerEndpoint]
|
||||
): ServicePerEndpoint = {
|
||||
val thriftMux: ThriftMux.Client = ThriftMux.client
|
||||
.withMutualTls(serviceIdentifier)
|
||||
.withClientId(clientId)
|
||||
.withLabel(label)
|
||||
.withStatsReceiver(statsReceiver)
|
||||
.withTransport.connectTimeout(connectTimeout)
|
||||
.withSession.acquisitionTimeout(acquisitionTimeout)
|
||||
|
||||
val protocolThriftMux: ThriftMux.Client = protocolFactoryOverride
|
||||
.map { protocolFactory =>
|
||||
thriftMux.withProtocolFactory(protocolFactory)
|
||||
}.getOrElse(thriftMux)
|
||||
|
||||
val methodBuilder: MethodBuilder = protocolThriftMux
|
||||
.methodBuilder(dest)
|
||||
.withTimeoutPerRequest(timeoutPerRequest)
|
||||
.withTimeoutTotal(timeoutTotal)
|
||||
|
||||
val idempotencyMethodBuilder: MethodBuilder = idempotency match {
|
||||
case NonIdempotent => methodBuilder.nonIdempotent
|
||||
case Idempotent(maxExtraLoad) => methodBuilder.idempotent(maxExtraLoad = maxExtraLoad)
|
||||
}
|
||||
|
||||
idempotencyMethodBuilder.servicePerEndpoint[ServicePerEndpoint]
|
||||
}
|
||||
}
|
@ -1,48 +0,0 @@
|
||||
alias(
|
||||
name = "frigate-pushservice",
|
||||
target = ":frigate-pushservice_lib",
|
||||
)
|
||||
|
||||
target(
|
||||
name = "frigate-pushservice_lib",
|
||||
dependencies = [
|
||||
"frigate/frigate-pushservice-opensource/src/main/scala/com/twitter/frigate/pushservice",
|
||||
],
|
||||
)
|
||||
|
||||
jvm_binary(
|
||||
name = "bin",
|
||||
basename = "frigate-pushservice",
|
||||
main = "com.twitter.frigate.pushservice.PushServiceMain",
|
||||
runtime_platform = "java11",
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/ch/qos/logback:logback-classic",
|
||||
"finatra/inject/inject-logback/src/main/scala",
|
||||
"frigate/frigate-pushservice-opensource/src/main/scala/com/twitter/frigate/pushservice",
|
||||
"loglens/loglens-logback/src/main/scala/com/twitter/loglens/logback",
|
||||
"twitter-server/logback-classic/src/main/scala",
|
||||
],
|
||||
excludes = [
|
||||
exclude("com.twitter.translations", "translations-twitter"),
|
||||
exclude("org.apache.hadoop", "hadoop-aws"),
|
||||
exclude("org.tensorflow"),
|
||||
scala_exclude("com.twitter", "ckoia-scala"),
|
||||
],
|
||||
)
|
||||
|
||||
jvm_app(
|
||||
name = "bundle",
|
||||
basename = "frigate-pushservice-package-dist",
|
||||
archive = "zip",
|
||||
binary = ":bin",
|
||||
tags = ["bazel-compatible"],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "mr_model_constants",
|
||||
sources = [
|
||||
"config/deepbird/constants.py",
|
||||
],
|
||||
tags = ["bazel-compatible"],
|
||||
)
|
BIN
pushservice/BUILD.docx
Normal file
BIN
pushservice/BUILD.docx
Normal file
Binary file not shown.
BIN
pushservice/README.docx
Normal file
BIN
pushservice/README.docx
Normal file
Binary file not shown.
@ -1,45 +0,0 @@
|
||||
# Pushservice
|
||||
|
||||
Pushservice is the main push recommendation service at Twitter used to generate recommendation-based notifications for users. It currently powers two functionalities:
|
||||
|
||||
- RefreshForPushHandler: This handler determines whether to send a recommendation push to a user based on their ID. It generates the best push recommendation item and coordinates with downstream services to deliver it
|
||||
- SendHandler: This handler determines and manage whether send the push to users based on the given target user details and the provided push recommendation item
|
||||
|
||||
## Overview
|
||||
|
||||
### RefreshForPushHandler
|
||||
|
||||
RefreshForPushHandler follows these steps:
|
||||
|
||||
- Building Target and checking eligibility
|
||||
- Builds a target user object based on the given user ID
|
||||
- Performs target-level filterings to determine if the target is eligible for a recommendation push
|
||||
- Fetch Candidates
|
||||
- Retrieves a list of potential candidates for the push by querying various candidate sources using the target
|
||||
- Candidate Hydration
|
||||
- Hydrates the candidate details with batch calls to different downstream services
|
||||
- Pre-rank Filtering, also called Light Filtering
|
||||
- Filters the hydrated candidates with lightweight RPC calls
|
||||
- Rank
|
||||
- Perform feature hydration for candidates and target user
|
||||
- Performs light ranking on candidates
|
||||
- Performs heavy ranking on candidates
|
||||
- Take Step, also called Heavy Filtering
|
||||
- Takes the top-ranked candidates one by one and applies heavy filtering until one candidate passes all filter steps
|
||||
- Send
|
||||
- Calls the appropriate downstream service to deliver the eligible candidate as a push and in-app notification to the target user
|
||||
|
||||
### SendHandler
|
||||
|
||||
SendHandler follows these steps:
|
||||
|
||||
- Building Target
|
||||
- Builds a target user object based on the given user ID
|
||||
- Candidate Hydration
|
||||
- Hydrates the candidate details with batch calls to different downstream services
|
||||
- Feature Hydration
|
||||
- Perform feature hydration for candidates and target user
|
||||
- Take Step, also called Heavy Filtering
|
||||
- Perform filterings and validation checking for the given candidate
|
||||
- Send
|
||||
- Calls the appropriate downstream service to deliver the given candidate as a push and/or in-app notification to the target user
|
@ -1,169 +0,0 @@
|
||||
python37_binary(
|
||||
name = "update_warm_start_checkpoint",
|
||||
source = "update_warm_start_checkpoint.py",
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:update_warm_start_checkpoint",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "params_lib",
|
||||
sources = ["params.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
"3rdparty/python/pydantic:default",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/v11/lib:params_lib",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "features_lib",
|
||||
sources = ["features.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":params_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "model_pools_lib",
|
||||
sources = ["model_pools.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":features_lib",
|
||||
":params_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/v11/lib:model_lib",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "graph_lib",
|
||||
sources = ["graph.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":params_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "run_args_lib",
|
||||
sources = ["run_args.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":features_lib",
|
||||
":params_lib",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "deep_norm_lib",
|
||||
sources = ["deep_norm.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":features_lib",
|
||||
":graph_lib",
|
||||
":model_pools_lib",
|
||||
":params_lib",
|
||||
":run_args_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"src/python/twitter/deepbird/util/data",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "eval_lib",
|
||||
sources = ["eval.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":features_lib",
|
||||
":graph_lib",
|
||||
":model_pools_lib",
|
||||
":params_lib",
|
||||
":run_args_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "deep_norm",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:deep_norm",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "eval",
|
||||
source = "eval.py",
|
||||
dependencies = [
|
||||
":eval_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "mlwf_libs",
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "train_model",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:train_model",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "train_model_local",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:train_model_local",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "eval_model_local",
|
||||
source = "eval.py",
|
||||
dependencies = [
|
||||
":eval_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval_model_local",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "eval_model",
|
||||
source = "eval.py",
|
||||
dependencies = [
|
||||
":eval_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval_model",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "mlwf_model",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":mlwf_libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:mlwf_model",
|
||||
],
|
||||
)
|
BIN
pushservice/src/main/python/models/heavy_ranking/BUILD.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/BUILD.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/heavy_ranking/README.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/README.docx
Normal file
Binary file not shown.
@ -1,20 +0,0 @@
|
||||
# Notification Heavy Ranker Model
|
||||
|
||||
## Model Context
|
||||
There are 4 major components of Twitter notifications recommendation system: 1) candidate generation 2) light ranking 3) heavy ranking & 4) quality control. This notification heavy ranker model is the core ranking model for the personalised notifications recommendation. It's a multi-task learning model to predict the probabilities that the target users will open and engage with the sent notifications.
|
||||
|
||||
|
||||
## Directory Structure
|
||||
- BUILD: this file defines python library dependencies
|
||||
- deep_norm.py: this file contains how to set up continuous training, model evaluation and model exporting for the notification heavy ranker model
|
||||
- eval.py: the main python entry file to set up the overall model evaluation pipeline
|
||||
- features.py: this file contains importing feature list and support functions for feature engineering
|
||||
- graph.py: this file defines how to build the tensorflow graph with specified model architecture, loss function and training configuration
|
||||
- model_pools.py: this file defines the available model types for the heavy ranker
|
||||
- params.py: this file defines hyper-parameters used in the notification heavy ranker
|
||||
- run_args.py: this file defines command line parameters to run model training & evaluation
|
||||
- update_warm_start_checkpoint.py: this file contains the support to modify checkpoints of the given saved heavy ranker model
|
||||
- lib/BUILD: this file defines python library dependencies for tensorflow model architecture
|
||||
- lib/layers.py: this file defines different type of convolution layers to be used in the heavy ranker model
|
||||
- lib/model.py: this file defines the module containing ClemNet, the heavy ranker model type
|
||||
- lib/params.py: this file defines parameters used in the heavy ranker model
|
BIN
pushservice/src/main/python/models/heavy_ranking/__init__.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/__init__.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/heavy_ranking/deep_norm.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/deep_norm.docx
Normal file
Binary file not shown.
@ -1,136 +0,0 @@
|
||||
"""
|
||||
Training job for the heavy ranker of the push notification service.
|
||||
"""
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
import twml
|
||||
|
||||
from ..libs.metric_fn_utils import flip_disliked_labels, get_metric_fn
|
||||
from ..libs.model_utils import read_config
|
||||
from ..libs.warm_start_utils import get_feature_list_for_heavy_ranking, warm_start_checkpoint
|
||||
from .features import get_feature_config
|
||||
from .model_pools import ALL_MODELS
|
||||
from .params import load_graph_params
|
||||
from .run_args import get_training_arg_parser
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args, _ = get_training_arg_parser().parse_known_args()
|
||||
logging.info(f"Parsed args: {args}")
|
||||
|
||||
params = load_graph_params(args)
|
||||
logging.info(f"Loaded graph params: {params}")
|
||||
|
||||
param_file = os.path.join(args.save_dir, "params.json")
|
||||
logging.info(f"Saving graph params to: {param_file}")
|
||||
with tf.io.gfile.GFile(param_file, mode="w") as file:
|
||||
json.dump(params.json(), file, ensure_ascii=False, indent=4)
|
||||
|
||||
logging.info(f"Get Feature Config: {args.feature_list}")
|
||||
feature_list = read_config(args.feature_list).items()
|
||||
feature_config = get_feature_config(
|
||||
data_spec_path=args.data_spec,
|
||||
params=params,
|
||||
feature_list_provided=feature_list,
|
||||
)
|
||||
feature_list_path = args.feature_list
|
||||
|
||||
warm_start_from = args.warm_start_from
|
||||
if args.warm_start_base_dir:
|
||||
logging.info(f"Get warm started model from: {args.warm_start_base_dir}.")
|
||||
|
||||
continuous_binary_feat_list_save_path = os.path.join(
|
||||
args.warm_start_base_dir, "continuous_binary_feat_list.json"
|
||||
)
|
||||
warm_start_folder = os.path.join(args.warm_start_base_dir, "best_checkpoint")
|
||||
job_name = os.path.basename(args.save_dir)
|
||||
ws_output_ckpt_folder = os.path.join(args.warm_start_base_dir, f"warm_start_for_{job_name}")
|
||||
if tf.io.gfile.exists(ws_output_ckpt_folder):
|
||||
tf.io.gfile.rmtree(ws_output_ckpt_folder)
|
||||
|
||||
tf.io.gfile.mkdir(ws_output_ckpt_folder)
|
||||
|
||||
warm_start_from = warm_start_checkpoint(
|
||||
warm_start_folder,
|
||||
continuous_binary_feat_list_save_path,
|
||||
feature_list_path,
|
||||
args.data_spec,
|
||||
ws_output_ckpt_folder,
|
||||
)
|
||||
logging.info(f"Created warm_start_from_ckpt {warm_start_from}.")
|
||||
|
||||
logging.info("Build Trainer.")
|
||||
metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)
|
||||
|
||||
trainer = twml.trainers.DataRecordTrainer(
|
||||
name="magic_recs",
|
||||
params=args,
|
||||
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
|
||||
save_dir=args.save_dir,
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=flip_disliked_labels(metric_fn),
|
||||
warm_start_from=warm_start_from,
|
||||
)
|
||||
|
||||
logging.info("Build train and eval input functions.")
|
||||
train_input_fn = trainer.get_train_input_fn(shuffle=True)
|
||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
||||
|
||||
learn = trainer.learn
|
||||
if args.distributed or args.num_workers is not None:
|
||||
learn = trainer.train_and_evaluate
|
||||
|
||||
if not args.directly_export_best:
|
||||
logging.info("Starting training")
|
||||
start = datetime.now()
|
||||
learn(
|
||||
early_stop_minimize=False,
|
||||
early_stop_metric="pr_auc_unweighted_OONC",
|
||||
early_stop_patience=args.early_stop_patience,
|
||||
early_stop_tolerance=args.early_stop_tolerance,
|
||||
eval_input_fn=eval_input_fn,
|
||||
train_input_fn=train_input_fn,
|
||||
)
|
||||
logging.info(f"Total training time: {datetime.now() - start}")
|
||||
else:
|
||||
logging.info("Directly exporting the model")
|
||||
|
||||
if not args.export_dir:
|
||||
args.export_dir = os.path.join(args.save_dir, "exported_models")
|
||||
|
||||
logging.info(f"Exporting the model to {args.export_dir}.")
|
||||
start = datetime.now()
|
||||
twml.contrib.export.export_fn.export_all_models(
|
||||
trainer=trainer,
|
||||
export_dir=args.export_dir,
|
||||
parse_fn=feature_config.get_parse_fn(),
|
||||
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
|
||||
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
|
||||
)
|
||||
|
||||
logging.info(f"Total model export time: {datetime.now() - start}")
|
||||
logging.info(f"The MLP directory is: {args.save_dir}")
|
||||
|
||||
continuous_binary_feat_list_save_path = os.path.join(
|
||||
args.save_dir, "continuous_binary_feat_list.json"
|
||||
)
|
||||
logging.info(
|
||||
f"Saving the list of continuous and binary features to {continuous_binary_feat_list_save_path}."
|
||||
)
|
||||
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
|
||||
feature_list_path, args.data_spec
|
||||
)
|
||||
twml.util.write_file(
|
||||
continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
logging.info("Done.")
|
BIN
pushservice/src/main/python/models/heavy_ranking/eval.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/eval.docx
Normal file
Binary file not shown.
@ -1,59 +0,0 @@
|
||||
"""
|
||||
Evaluation job for the heavy ranker of the push notification service.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import twml
|
||||
|
||||
from ..libs.metric_fn_utils import get_metric_fn
|
||||
from ..libs.model_utils import read_config
|
||||
from .features import get_feature_config
|
||||
from .model_pools import ALL_MODELS
|
||||
from .params import load_graph_params
|
||||
from .run_args import get_eval_arg_parser
|
||||
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
def main():
|
||||
args, _ = get_eval_arg_parser().parse_known_args()
|
||||
logging.info(f"Parsed args: {args}")
|
||||
|
||||
params = load_graph_params(args)
|
||||
logging.info(f"Loaded graph params: {params}")
|
||||
|
||||
logging.info(f"Get Feature Config: {args.feature_list}")
|
||||
feature_list = read_config(args.feature_list).items()
|
||||
feature_config = get_feature_config(
|
||||
data_spec_path=args.data_spec,
|
||||
params=params,
|
||||
feature_list_provided=feature_list,
|
||||
)
|
||||
|
||||
logging.info("Build DataRecordTrainer.")
|
||||
metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)
|
||||
|
||||
trainer = twml.trainers.DataRecordTrainer(
|
||||
name="magic_recs",
|
||||
params=args,
|
||||
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
|
||||
save_dir=args.save_dir,
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=metric_fn,
|
||||
)
|
||||
|
||||
logging.info("Run the evaluation.")
|
||||
start = datetime.now()
|
||||
trainer._estimator.evaluate(
|
||||
input_fn=trainer.get_eval_input_fn(repeat=False, shuffle=False),
|
||||
steps=None if (args.eval_steps is not None and args.eval_steps < 0) else args.eval_steps,
|
||||
checkpoint_path=args.eval_checkpoint,
|
||||
)
|
||||
logging.info(f"Evaluating time: {datetime.now() - start}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
logging.info("Job done.")
|
BIN
pushservice/src/main/python/models/heavy_ranking/features.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/features.docx
Normal file
Binary file not shown.
@ -1,138 +0,0 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from twitter.deepbird.projects.magic_recs.libs.model_utils import filter_nans_and_infs
|
||||
import twml
|
||||
from twml.layers import full_sparse, sparse_max_norm
|
||||
|
||||
from .params import FeaturesParams, GraphParams, SparseFeaturesParams
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import Tensor
|
||||
import tensorflow.compat.v1 as tf1
|
||||
|
||||
|
||||
FEAT_CONFIG_DEFAULT_VAL = 0
|
||||
DEFAULT_FEATURE_LIST_PATH = "./feature_list_default.yaml"
|
||||
FEATURE_LIST_DEFAULT_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_PATH
|
||||
)
|
||||
|
||||
|
||||
def get_feature_config(data_spec_path=None, feature_list_provided=[], params: GraphParams = None):
|
||||
|
||||
a_string_feat_list = [feat for feat, feat_type in feature_list_provided if feat_type != "S"]
|
||||
|
||||
builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
||||
data_spec_path=data_spec_path, debug=False
|
||||
)
|
||||
|
||||
builder = builder.extract_feature_group(
|
||||
feature_regexes=a_string_feat_list,
|
||||
group_name="continuous_features",
|
||||
default_value=FEAT_CONFIG_DEFAULT_VAL,
|
||||
type_filter=["CONTINUOUS"],
|
||||
)
|
||||
|
||||
builder = builder.extract_feature_group(
|
||||
feature_regexes=a_string_feat_list,
|
||||
group_name="binary_features",
|
||||
type_filter=["BINARY"],
|
||||
)
|
||||
|
||||
if params.model.features.sparse_features:
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=a_string_feat_list,
|
||||
hash_space_size_bits=params.model.features.sparse_features.bits,
|
||||
type_filter=["DISCRETE", "STRING", "SPARSE_BINARY"],
|
||||
output_tensor_name="sparse_not_continuous",
|
||||
)
|
||||
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=[feat for feat, feat_type in feature_list_provided if feat_type == "S"],
|
||||
hash_space_size_bits=params.model.features.sparse_features.bits,
|
||||
type_filter=["SPARSE_CONTINUOUS"],
|
||||
output_tensor_name="sparse_continuous",
|
||||
)
|
||||
|
||||
builder = builder.add_labels([task.label for task in params.tasks] + ["label.ntabDislike"])
|
||||
|
||||
if params.weight:
|
||||
builder = builder.define_weight(params.weight)
|
||||
|
||||
return builder.build()
|
||||
|
||||
|
||||
def dense_features(features: Dict[str, Tensor], training: bool) -> Tensor:
|
||||
"""
|
||||
Performs feature transformations on the raw dense features (continuous and binary).
|
||||
"""
|
||||
with tf.name_scope("dense_features"):
|
||||
x = filter_nans_and_infs(features["continuous_features"])
|
||||
|
||||
x = tf.sign(x) * tf.math.log(tf.abs(x) + 1)
|
||||
x = tf1.layers.batch_normalization(
|
||||
x, momentum=0.9999, training=training, renorm=training, axis=1
|
||||
)
|
||||
x = tf.clip_by_value(x, -5, 5)
|
||||
|
||||
transformed_continous_features = tf.where(tf.math.is_nan(x), tf.zeros_like(x), x)
|
||||
|
||||
binary_features = filter_nans_and_infs(features["binary_features"])
|
||||
binary_features = tf.dtypes.cast(binary_features, tf.float32)
|
||||
|
||||
output = tf.concat([transformed_continous_features, binary_features], axis=1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def sparse_features(
|
||||
features: Dict[str, Tensor], training: bool, params: SparseFeaturesParams
|
||||
) -> Tensor:
|
||||
"""
|
||||
Performs feature transformations on the raw sparse features.
|
||||
"""
|
||||
|
||||
with tf.name_scope("sparse_features"):
|
||||
with tf.name_scope("sparse_not_continuous"):
|
||||
sparse_not_continuous = full_sparse(
|
||||
inputs=features["sparse_not_continuous"],
|
||||
output_size=params.embedding_size,
|
||||
use_sparse_grads=training,
|
||||
use_binary_values=False,
|
||||
)
|
||||
|
||||
with tf.name_scope("sparse_continuous"):
|
||||
shape_enforced_input = twml.util.limit_sparse_tensor_size(
|
||||
sparse_tf=features["sparse_continuous"], input_size_bits=params.bits, mask_indices=False
|
||||
)
|
||||
|
||||
normalized_continuous_sparse = sparse_max_norm(
|
||||
inputs=shape_enforced_input, is_training=training
|
||||
)
|
||||
|
||||
sparse_continuous = full_sparse(
|
||||
inputs=normalized_continuous_sparse,
|
||||
output_size=params.embedding_size,
|
||||
use_sparse_grads=training,
|
||||
use_binary_values=False,
|
||||
)
|
||||
|
||||
output = tf.concat([sparse_not_continuous, sparse_continuous], axis=1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_features(features: Dict[str, Tensor], training: bool, params: FeaturesParams) -> Tensor:
|
||||
"""
|
||||
Performs feature transformations on the dense and sparse features and combine the resulting
|
||||
tensors into a single one.
|
||||
"""
|
||||
with tf.name_scope("features"):
|
||||
x = dense_features(features, training)
|
||||
tf1.logging.info(f"Dense features: {x.shape}")
|
||||
|
||||
if params.sparse_features:
|
||||
x = tf.concat([x, sparse_features(features, training, params.sparse_features)], axis=1)
|
||||
|
||||
return x
|
BIN
pushservice/src/main/python/models/heavy_ranking/graph.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/graph.docx
Normal file
Binary file not shown.
@ -1,129 +0,0 @@
|
||||
"""
|
||||
Graph class defining methods to obtain key quantities such as:
|
||||
* the logits
|
||||
* the probabilities
|
||||
* the final score
|
||||
* the loss function
|
||||
* the training operator
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
|
||||
from twitter.deepbird.hparam import HParams
|
||||
import twml
|
||||
|
||||
from ..libs.model_utils import generate_disliked_mask
|
||||
from .params import GraphParams
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tf1
|
||||
|
||||
|
||||
class Graph(ABC):
|
||||
def __init__(self, params: GraphParams):
|
||||
self.params = params
|
||||
|
||||
@abstractmethod
|
||||
def get_logits(self, features: Dict[str, tf.Tensor], mode: tf.estimator.ModeKeys) -> tf.Tensor:
|
||||
pass
|
||||
|
||||
def get_probabilities(self, logits: tf.Tensor) -> tf.Tensor:
|
||||
return tf.math.cumprod(tf.nn.sigmoid(logits), axis=1, name="probabilities")
|
||||
|
||||
def get_task_weights(self, labels: tf.Tensor) -> tf.Tensor:
|
||||
oonc_label = tf.reshape(labels[:, 0], shape=(-1, 1))
|
||||
task_weights = tf.concat([tf.ones_like(oonc_label), oonc_label], axis=1)
|
||||
|
||||
n_labels = len(self.params.tasks)
|
||||
task_weights = tf.reshape(task_weights[:, 0:n_labels], shape=(-1, n_labels))
|
||||
|
||||
return task_weights
|
||||
|
||||
def get_loss(self, labels: tf.Tensor, logits: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
||||
with tf.name_scope("weights"):
|
||||
disliked_mask = generate_disliked_mask(labels)
|
||||
|
||||
labels = tf.reshape(labels[:, 0:2], shape=[-1, 2])
|
||||
|
||||
labels = labels * tf.cast(tf.logical_not(disliked_mask), dtype=labels.dtype)
|
||||
|
||||
with tf.name_scope("task_weight"):
|
||||
task_weights = self.get_task_weights(labels)
|
||||
|
||||
with tf.name_scope("batch_size"):
|
||||
batch_size = tf.cast(tf.shape(labels)[0], dtype=tf.float32, name="batch_size")
|
||||
|
||||
weights = task_weights / batch_size
|
||||
|
||||
with tf.name_scope("loss"):
|
||||
loss = tf.reduce_sum(
|
||||
tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) * weights,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def get_score(self, probabilities: tf.Tensor) -> tf.Tensor:
|
||||
with tf.name_scope("score_weight"):
|
||||
score_weights = tf.constant([task.score_weight for task in self.params.tasks])
|
||||
score_weights = score_weights / tf.reduce_sum(score_weights, axis=0)
|
||||
|
||||
with tf.name_scope("score"):
|
||||
score = tf.reshape(tf.reduce_sum(probabilities * score_weights, axis=1), shape=[-1, 1])
|
||||
|
||||
return score
|
||||
|
||||
def get_train_op(self, loss: tf.Tensor, twml_params) -> Any:
|
||||
with tf.name_scope("optimizer"):
|
||||
learning_rate = twml_params.learning_rate
|
||||
optimizer = tf1.train.GradientDescentOptimizer(learning_rate=learning_rate)
|
||||
|
||||
update_ops = set(tf1.get_collection(tf1.GraphKeys.UPDATE_OPS))
|
||||
with tf.control_dependencies(update_ops):
|
||||
train_op = twml.optimizers.optimize_loss(
|
||||
loss=loss,
|
||||
variables=tf1.trainable_variables(),
|
||||
global_step=tf1.train.get_global_step(),
|
||||
optimizer=optimizer,
|
||||
learning_rate=None,
|
||||
)
|
||||
|
||||
return train_op
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
features: Dict[str, tf.Tensor],
|
||||
labels: tf.Tensor,
|
||||
mode: tf.estimator.ModeKeys,
|
||||
params: HParams,
|
||||
config=None,
|
||||
) -> Dict[str, tf.Tensor]:
|
||||
training = mode == tf.estimator.ModeKeys.TRAIN
|
||||
logits = self.get_logits(features=features, training=training)
|
||||
probabilities = self.get_probabilities(logits=logits)
|
||||
score = None
|
||||
loss = None
|
||||
train_op = None
|
||||
|
||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
||||
score = self.get_score(probabilities=probabilities)
|
||||
output = {"loss": loss, "train_op": train_op, "prediction": score}
|
||||
|
||||
elif mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
|
||||
loss = self.get_loss(labels=labels, logits=logits)
|
||||
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
train_op = self.get_train_op(loss=loss, twml_params=params)
|
||||
|
||||
output = {"loss": loss, "train_op": train_op, "output": probabilities}
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"""
|
||||
Invalid mode. Possible values are: {tf.estimator.ModeKeys.PREDICT}, {tf.estimator.ModeKeys.TRAIN}, and {tf.estimator.ModeKeys.EVAL}
|
||||
. Passed: {mode}
|
||||
"""
|
||||
)
|
||||
|
||||
return output
|
@ -1,42 +0,0 @@
|
||||
python3_library(
|
||||
name = "params_lib",
|
||||
sources = [
|
||||
"params.py",
|
||||
],
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
"no-mypy",
|
||||
],
|
||||
dependencies = [
|
||||
"3rdparty/python/pydantic:default",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "layers_lib",
|
||||
sources = [
|
||||
"layers.py",
|
||||
],
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
"no-mypy",
|
||||
],
|
||||
dependencies = [
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "model_lib",
|
||||
sources = [
|
||||
"model.py",
|
||||
],
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
"no-mypy",
|
||||
],
|
||||
dependencies = [
|
||||
":layers_lib",
|
||||
":params_lib",
|
||||
"3rdparty/python/absl-py:default",
|
||||
],
|
||||
)
|
BIN
pushservice/src/main/python/models/heavy_ranking/lib/BUILD.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/lib/BUILD.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/heavy_ranking/lib/layers.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/lib/layers.docx
Normal file
Binary file not shown.
@ -1,128 +0,0 @@
|
||||
"""
|
||||
Different type of convolution layers to be used in the ClemNet.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class KerasConv1D(tf.keras.layers.Layer):
|
||||
"""
|
||||
Basic Conv1D layer in a wrapper to be compatible with ClemNet.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: int,
|
||||
filters: int,
|
||||
strides: int,
|
||||
padding: str,
|
||||
use_bias: bool = True,
|
||||
kernel_initializer: str = "glorot_uniform",
|
||||
bias_initializer: str = "zeros",
|
||||
**kwargs: Any,
|
||||
):
|
||||
super(KerasConv1D, self).__init__(**kwargs)
|
||||
self.kernel_size = kernel_size
|
||||
self.filters = filters
|
||||
self.use_bias = use_bias
|
||||
self.kernel_initializer = kernel_initializer
|
||||
self.bias_initializer = bias_initializer
|
||||
self.strides = strides
|
||||
self.padding = padding
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert (
|
||||
len(input_shape) == 3
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
self.features = input_shape[1]
|
||||
|
||||
self.w = tf.keras.layers.Conv1D(
|
||||
kernel_size=self.kernel_size,
|
||||
filters=self.filters,
|
||||
strides=self.strides,
|
||||
padding=self.padding,
|
||||
use_bias=self.use_bias,
|
||||
kernel_initializer=self.kernel_initializer,
|
||||
bias_initializer=self.bias_initializer,
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
||||
return self.w(inputs)
|
||||
|
||||
|
||||
class ChannelWiseDense(tf.keras.layers.Layer):
|
||||
"""
|
||||
Dense layer is applied to each channel separately. This is more memory and computationally
|
||||
efficient than flattening the channels and performing single dense layers over it which is the
|
||||
default behavior in tf1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_size: int,
|
||||
use_bias: bool,
|
||||
kernel_initializer: str = "uniform_glorot",
|
||||
bias_initializer: str = "zeros",
|
||||
**kwargs: Any,
|
||||
):
|
||||
super(ChannelWiseDense, self).__init__(**kwargs)
|
||||
self.output_size = output_size
|
||||
self.use_bias = use_bias
|
||||
self.kernel_initializer = kernel_initializer
|
||||
self.bias_initializer = bias_initializer
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert (
|
||||
len(input_shape) == 3
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
input_size = input_shape[1]
|
||||
channels = input_shape[2]
|
||||
|
||||
self.kernel = self.add_weight(
|
||||
name="kernel",
|
||||
shape=(channels, input_size, self.output_size),
|
||||
initializer=self.kernel_initializer,
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
self.bias = self.add_weight(
|
||||
name="bias",
|
||||
shape=(channels, self.output_size),
|
||||
initializer=self.bias_initializer,
|
||||
trainable=self.use_bias,
|
||||
)
|
||||
|
||||
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
||||
x = inputs
|
||||
|
||||
transposed_x = tf.transpose(x, perm=[2, 0, 1])
|
||||
transposed_residual = (
|
||||
tf.transpose(tf.matmul(transposed_x, self.kernel), perm=[1, 0, 2]) + self.bias
|
||||
)
|
||||
output = tf.transpose(transposed_residual, perm=[0, 2, 1])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResidualLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
Layer implementing a 3D-residual connection.
|
||||
"""
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert (
|
||||
len(input_shape) == 3
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
def call(self, inputs: tf.Tensor, residual: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
||||
shortcut = tf.keras.layers.Conv1D(
|
||||
filters=int(residual.shape[2]), strides=1, kernel_size=1, padding="SAME", use_bias=False
|
||||
)(inputs)
|
||||
|
||||
output = tf.add(shortcut, residual)
|
||||
|
||||
return output
|
BIN
pushservice/src/main/python/models/heavy_ranking/lib/model.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/lib/model.docx
Normal file
Binary file not shown.
@ -1,76 +0,0 @@
|
||||
"""
|
||||
Module containing ClemNet.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from .layers import ChannelWiseDense, KerasConv1D, ResidualLayer
|
||||
from .params import BlockParams, ClemNetParams
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tf1
|
||||
|
||||
|
||||
class Block2(tf.keras.layers.Layer):
|
||||
"""
|
||||
Possible ClemNet block. Architecture is as follow:
|
||||
Optional(DenseLayer + BN + Act)
|
||||
Optional(ConvLayer + BN + Act)
|
||||
Optional(Residual Layer)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, params: BlockParams, **kwargs: Any):
|
||||
super(Block2, self).__init__(**kwargs)
|
||||
self.params = params
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert (
|
||||
len(input_shape) == 3
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
|
||||
x = inputs
|
||||
if self.params.dense:
|
||||
x = ChannelWiseDense(**self.params.dense.dict())(inputs=x, training=training)
|
||||
x = tf1.layers.batch_normalization(x, momentum=0.9999, training=training, axis=1)
|
||||
x = tf.keras.layers.Activation(self.params.activation)(x)
|
||||
|
||||
if self.params.conv:
|
||||
x = KerasConv1D(**self.params.conv.dict())(inputs=x, training=training)
|
||||
x = tf1.layers.batch_normalization(x, momentum=0.9999, training=training, axis=1)
|
||||
x = tf.keras.layers.Activation(self.params.activation)(x)
|
||||
|
||||
if self.params.residual:
|
||||
x = ResidualLayer()(inputs=inputs, residual=x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ClemNet(tf.keras.layers.Layer):
|
||||
"""
|
||||
A residual network stacking residual blocks composed of dense layers and convolutions.
|
||||
"""
|
||||
|
||||
def __init__(self, params: ClemNetParams, **kwargs: Any):
|
||||
super(ClemNet, self).__init__(**kwargs)
|
||||
self.params = params
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert len(input_shape) in (
|
||||
2,
|
||||
3,
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
|
||||
if len(inputs.shape) < 3:
|
||||
inputs = tf.expand_dims(inputs, axis=-1)
|
||||
|
||||
x = inputs
|
||||
for block_params in self.params.blocks:
|
||||
x = Block2(block_params)(inputs=x, training=training)
|
||||
|
||||
x = tf.keras.layers.Flatten(name="flattened")(x)
|
||||
if self.params.top:
|
||||
x = tf.keras.layers.Dense(units=self.params.top.n_labels, name="logits")(x)
|
||||
|
||||
return x
|
BIN
pushservice/src/main/python/models/heavy_ranking/lib/params.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/lib/params.docx
Normal file
Binary file not shown.
@ -1,49 +0,0 @@
|
||||
"""
|
||||
Parameters used in ClemNet.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, PositiveInt
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
class ExtendedBaseModel(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class DenseParams(ExtendedBaseModel):
|
||||
name: Optional[str]
|
||||
bias_initializer: str = "zeros"
|
||||
kernel_initializer: str = "glorot_uniform"
|
||||
output_size: PositiveInt
|
||||
use_bias: bool = Field(True)
|
||||
|
||||
|
||||
class ConvParams(ExtendedBaseModel):
|
||||
name: Optional[str]
|
||||
bias_initializer: str = "zeros"
|
||||
filters: PositiveInt
|
||||
kernel_initializer: str = "glorot_uniform"
|
||||
kernel_size: PositiveInt
|
||||
padding: str = "SAME"
|
||||
strides: PositiveInt = 1
|
||||
use_bias: bool = Field(True)
|
||||
|
||||
|
||||
class BlockParams(ExtendedBaseModel):
|
||||
activation: Optional[str]
|
||||
conv: Optional[ConvParams]
|
||||
dense: Optional[DenseParams]
|
||||
residual: Optional[bool]
|
||||
|
||||
|
||||
class TopLayerParams(ExtendedBaseModel):
|
||||
n_labels: PositiveInt
|
||||
|
||||
|
||||
class ClemNetParams(ExtendedBaseModel):
|
||||
blocks: List[BlockParams] = []
|
||||
top: Optional[TopLayerParams]
|
Binary file not shown.
@ -1,34 +0,0 @@
|
||||
"""
|
||||
Candidate architectures for each task's.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .features import get_features
|
||||
from .graph import Graph
|
||||
from .lib.model import ClemNet
|
||||
from .params import ModelTypeEnum
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class MagicRecsClemNet(Graph):
|
||||
def get_logits(self, features: Dict[str, tf.Tensor], training: bool) -> tf.Tensor:
|
||||
|
||||
with tf.name_scope("logits"):
|
||||
inputs = get_features(features=features, training=training, params=self.params.model.features)
|
||||
|
||||
with tf.name_scope("OONC_logits"):
|
||||
model = ClemNet(params=self.params.model.architecture)
|
||||
oonc_logit = model(inputs=inputs, training=training)
|
||||
|
||||
with tf.name_scope("EngagementGivenOONC_logits"):
|
||||
model = ClemNet(params=self.params.model.architecture)
|
||||
eng_logits = model(inputs=inputs, training=training)
|
||||
|
||||
return tf.concat([oonc_logit, eng_logits], axis=1)
|
||||
|
||||
|
||||
ALL_MODELS = {ModelTypeEnum.clemnet: MagicRecsClemNet}
|
BIN
pushservice/src/main/python/models/heavy_ranking/params.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/params.docx
Normal file
Binary file not shown.
@ -1,89 +0,0 @@
|
||||
import enum
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from .lib.params import BlockParams, ClemNetParams, ConvParams, DenseParams, TopLayerParams
|
||||
|
||||
from pydantic import BaseModel, Extra, NonNegativeFloat
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
class ExtendedBaseModel(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class SparseFeaturesParams(ExtendedBaseModel):
|
||||
bits: int
|
||||
embedding_size: int
|
||||
|
||||
|
||||
class FeaturesParams(ExtendedBaseModel):
|
||||
sparse_features: Optional[SparseFeaturesParams]
|
||||
|
||||
|
||||
class ModelTypeEnum(str, enum.Enum):
|
||||
clemnet: str = "clemnet"
|
||||
|
||||
|
||||
class ModelParams(ExtendedBaseModel):
|
||||
name: ModelTypeEnum
|
||||
features: FeaturesParams
|
||||
architecture: ClemNetParams
|
||||
|
||||
|
||||
class TaskNameEnum(str, enum.Enum):
|
||||
oonc: str = "OONC"
|
||||
engagement: str = "Engagement"
|
||||
|
||||
|
||||
class Task(ExtendedBaseModel):
|
||||
name: TaskNameEnum
|
||||
label: str
|
||||
score_weight: NonNegativeFloat
|
||||
|
||||
|
||||
DEFAULT_TASKS = [
|
||||
Task(name=TaskNameEnum.oonc, label="label", score_weight=0.9),
|
||||
Task(name=TaskNameEnum.engagement, label="label.engagement", score_weight=0.1),
|
||||
]
|
||||
|
||||
|
||||
class GraphParams(ExtendedBaseModel):
|
||||
tasks: List[Task] = DEFAULT_TASKS
|
||||
model: ModelParams
|
||||
weight: Optional[str]
|
||||
|
||||
|
||||
DEFAULT_ARCHITECTURE_PARAMS = ClemNetParams(
|
||||
blocks=[
|
||||
BlockParams(
|
||||
activation="relu",
|
||||
conv=ConvParams(kernel_size=3, filters=5),
|
||||
dense=DenseParams(output_size=output_size),
|
||||
residual=False,
|
||||
)
|
||||
for output_size in [1024, 512, 256, 128]
|
||||
],
|
||||
top=TopLayerParams(n_labels=1),
|
||||
)
|
||||
|
||||
DEFAULT_GRAPH_PARAMS = GraphParams(
|
||||
model=ModelParams(
|
||||
name=ModelTypeEnum.clemnet,
|
||||
architecture=DEFAULT_ARCHITECTURE_PARAMS,
|
||||
features=FeaturesParams(sparse_features=SparseFeaturesParams(bits=18, embedding_size=50)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def load_graph_params(args) -> GraphParams:
|
||||
params = DEFAULT_GRAPH_PARAMS
|
||||
if args.param_file:
|
||||
with tf.io.gfile.GFile(args.param_file, mode="r+") as file:
|
||||
params = GraphParams.parse_obj(json.load(file))
|
||||
|
||||
return params
|
BIN
pushservice/src/main/python/models/heavy_ranking/run_args.docx
Normal file
BIN
pushservice/src/main/python/models/heavy_ranking/run_args.docx
Normal file
Binary file not shown.
@ -1,59 +0,0 @@
|
||||
from twml.trainers import DataRecordTrainer
|
||||
|
||||
from .features import FEATURE_LIST_DEFAULT_PATH
|
||||
|
||||
|
||||
def get_training_arg_parser():
|
||||
parser = DataRecordTrainer.add_parser_arguments()
|
||||
|
||||
parser.add_argument(
|
||||
"--feature_list",
|
||||
default=FEATURE_LIST_DEFAULT_PATH,
|
||||
type=str,
|
||||
help="Which features to use for training",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--param_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to JSON file containing the graph parameters. If None, model will load default parameters.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--directly_export_best",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to directly_export best_checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_start_base_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="latest ckpt in this folder will be used to ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Which type of model to train.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_eval_arg_parser():
|
||||
parser = get_training_arg_parser()
|
||||
parser.add_argument(
|
||||
"--eval_checkpoint",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Which checkpoint to use for evaluation",
|
||||
)
|
||||
|
||||
return parser
|
Binary file not shown.
@ -1,146 +0,0 @@
|
||||
"""
|
||||
Model for modifying the checkpoints of the magic recs cnn Model with addition, deletion, and reordering
|
||||
of continuous and binary features.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from twitter.deepbird.projects.magic_recs.libs.get_feat_config import FEATURE_LIST_DEFAULT_PATH
|
||||
from twitter.deepbird.projects.magic_recs.libs.warm_start_utils_v11 import (
|
||||
get_feature_list_for_heavy_ranking,
|
||||
mkdirp,
|
||||
rename_dir,
|
||||
rmdir,
|
||||
warm_start_checkpoint,
|
||||
)
|
||||
import twml
|
||||
from twml.trainers import DataRecordTrainer
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
def get_arg_parser():
|
||||
parser = DataRecordTrainer.add_parser_arguments()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="deepnorm_gbdt_inputdrop2_rescale",
|
||||
type=str,
|
||||
help="specify the model type to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_trainer_name",
|
||||
default="None",
|
||||
type=str,
|
||||
help="deprecated, added here just for api compatibility.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_start_base_dir",
|
||||
default="none",
|
||||
type=str,
|
||||
help="latest ckpt in this folder will be used.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_checkpoint_dir",
|
||||
default="none",
|
||||
type=str,
|
||||
help="Output folder for warm started ckpt. If none, it will move warm_start_base_dir to backup, and overwrite it",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature_list",
|
||||
default="none",
|
||||
type=str,
|
||||
help="Which features to use for training",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--old_feature_list",
|
||||
default="none",
|
||||
type=str,
|
||||
help="Which features to use for training",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params(args=None):
|
||||
parser = get_arg_parser()
|
||||
if args is None:
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def _main():
|
||||
opt = get_params()
|
||||
logging.info("parse is: ")
|
||||
logging.info(opt)
|
||||
|
||||
if opt.feature_list == "none":
|
||||
feature_list_path = FEATURE_LIST_DEFAULT_PATH
|
||||
else:
|
||||
feature_list_path = opt.feature_list
|
||||
|
||||
if opt.warm_start_base_dir != "none" and tf.io.gfile.exists(opt.warm_start_base_dir):
|
||||
if opt.output_checkpoint_dir == "none" or opt.output_checkpoint_dir == opt.warm_start_base_dir:
|
||||
_warm_start_base_dir = os.path.normpath(opt.warm_start_base_dir) + "_backup_warm_start"
|
||||
_output_folder_dir = opt.warm_start_base_dir
|
||||
|
||||
rename_dir(opt.warm_start_base_dir, _warm_start_base_dir)
|
||||
tf.logging.info(f"moved {opt.warm_start_base_dir} to {_warm_start_base_dir}")
|
||||
else:
|
||||
_warm_start_base_dir = opt.warm_start_base_dir
|
||||
_output_folder_dir = opt.output_checkpoint_dir
|
||||
|
||||
continuous_binary_feat_list_save_path = os.path.join(
|
||||
_warm_start_base_dir, "continuous_binary_feat_list.json"
|
||||
)
|
||||
|
||||
if opt.old_feature_list != "none":
|
||||
tf.logging.info("getting old continuous_binary_feat_list")
|
||||
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
|
||||
opt.old_feature_list, opt.data_spec
|
||||
)
|
||||
rmdir(continuous_binary_feat_list_save_path)
|
||||
twml.util.write_file(
|
||||
continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
|
||||
)
|
||||
tf.logging.info(f"Finish writting files to {continuous_binary_feat_list_save_path}")
|
||||
|
||||
warm_start_folder = os.path.join(_warm_start_base_dir, "best_checkpoint")
|
||||
if not tf.io.gfile.exists(warm_start_folder):
|
||||
warm_start_folder = _warm_start_base_dir
|
||||
|
||||
rmdir(_output_folder_dir)
|
||||
mkdirp(_output_folder_dir)
|
||||
|
||||
new_ckpt = warm_start_checkpoint(
|
||||
warm_start_folder,
|
||||
continuous_binary_feat_list_save_path,
|
||||
feature_list_path,
|
||||
opt.data_spec,
|
||||
_output_folder_dir,
|
||||
opt.model_type,
|
||||
)
|
||||
logging.info(f"Created new ckpt {new_ckpt} from {warm_start_folder}")
|
||||
|
||||
tf.logging.info("getting new continuous_binary_feat_list")
|
||||
new_continuous_binary_feat_list_save_path = os.path.join(
|
||||
_output_folder_dir, "continuous_binary_feat_list.json"
|
||||
)
|
||||
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
|
||||
feature_list_path, opt.data_spec
|
||||
)
|
||||
rmdir(new_continuous_binary_feat_list_save_path)
|
||||
twml.util.write_file(
|
||||
new_continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
|
||||
)
|
||||
tf.logging.info(f"Finish writting files to {new_continuous_binary_feat_list_save_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
@ -1,16 +0,0 @@
|
||||
python3_library(
|
||||
name = "libs",
|
||||
sources = ["*.py"],
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
"no-mypy",
|
||||
],
|
||||
dependencies = [
|
||||
"cortex/recsys/src/python/twitter/cortex/recsys/utils",
|
||||
"magicpony/common/file_access/src/python/twitter/magicpony/common/file_access",
|
||||
"src/python/twitter/cortex/ml/embeddings/deepbird",
|
||||
"src/python/twitter/cortex/ml/embeddings/deepbird/grouped_metrics",
|
||||
"src/python/twitter/deepbird/util/data",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
BIN
pushservice/src/main/python/models/libs/BUILD.docx
Normal file
BIN
pushservice/src/main/python/models/libs/BUILD.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/libs/__init__.docx
Normal file
BIN
pushservice/src/main/python/models/libs/__init__.docx
Normal file
Binary file not shown.
Binary file not shown.
@ -1,56 +0,0 @@
|
||||
# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument
|
||||
"""
|
||||
Implementing Full Sparse Layer, allow specify use_binary_value in call() to
|
||||
overide default action.
|
||||
"""
|
||||
|
||||
from twml.layers import FullSparse as defaultFullSparse
|
||||
from twml.layers.full_sparse import sparse_dense_matmul
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class FullSparse(defaultFullSparse):
|
||||
def call(self, inputs, use_binary_values=None, **kwargs): # pylint: disable=unused-argument
|
||||
"""The logic of the layer lives here.
|
||||
|
||||
Arguments:
|
||||
inputs:
|
||||
A SparseTensor or a list of SparseTensors.
|
||||
If `inputs` is a list, all tensors must have same `dense_shape`.
|
||||
|
||||
Returns:
|
||||
- If `inputs` is `SparseTensor`, then returns `bias + inputs * dense_b`.
|
||||
- If `inputs` is a `list[SparseTensor`, then returns
|
||||
`bias + add_n([sp_a * dense_b for sp_a in inputs])`.
|
||||
"""
|
||||
|
||||
if use_binary_values is not None:
|
||||
default_use_binary_values = use_binary_values
|
||||
else:
|
||||
default_use_binary_values = self.use_binary_values
|
||||
|
||||
if isinstance(default_use_binary_values, (list, tuple)):
|
||||
raise ValueError(
|
||||
"use_binary_values can not be %s when inputs is %s"
|
||||
% (type(default_use_binary_values), type(inputs))
|
||||
)
|
||||
|
||||
outputs = sparse_dense_matmul(
|
||||
inputs,
|
||||
self.weight,
|
||||
self.use_sparse_grads,
|
||||
default_use_binary_values,
|
||||
name="sparse_mm",
|
||||
partition_axis=self.partition_axis,
|
||||
num_partitions=self.num_partitions,
|
||||
compress_ids=self._use_compression,
|
||||
cast_indices_dtype=self._cast_indices_dtype,
|
||||
)
|
||||
|
||||
if self.bias is not None:
|
||||
outputs = tf.nn.bias_add(outputs, self.bias)
|
||||
|
||||
if self.activation is not None:
|
||||
return self.activation(outputs) # pylint: disable=not-callable
|
||||
return outputs
|
BIN
pushservice/src/main/python/models/libs/get_feat_config.docx
Normal file
BIN
pushservice/src/main/python/models/libs/get_feat_config.docx
Normal file
Binary file not shown.
@ -1,176 +0,0 @@
|
||||
import os
|
||||
|
||||
from twitter.deepbird.projects.magic_recs.libs.metric_fn_utils import USER_AGE_FEATURE_NAME
|
||||
from twitter.deepbird.projects.magic_recs.libs.model_utils import read_config
|
||||
from twml.contrib import feature_config as contrib_feature_config
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
FEAT_CONFIG_DEFAULT_VAL = -1.23456789
|
||||
|
||||
DEFAULT_INPUT_SIZE_BITS = 18
|
||||
|
||||
DEFAULT_FEATURE_LIST_PATH = "./feature_list_default.yaml"
|
||||
FEATURE_LIST_DEFAULT_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_PATH
|
||||
)
|
||||
|
||||
DEFAULT_FEATURE_LIST_LIGHT_RANKING_PATH = "./feature_list_light_ranking.yaml"
|
||||
FEATURE_LIST_DEFAULT_LIGHT_RANKING_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_LIGHT_RANKING_PATH
|
||||
)
|
||||
|
||||
FEATURE_LIST_DEFAULT = read_config(FEATURE_LIST_DEFAULT_PATH).items()
|
||||
FEATURE_LIST_LIGHT_RANKING_DEFAULT = read_config(FEATURE_LIST_DEFAULT_LIGHT_RANKING_PATH).items()
|
||||
|
||||
|
||||
LABELS = ["label"]
|
||||
LABELS_MTL = {"OONC": ["label"], "OONC_Engagement": ["label", "label.engagement"]}
|
||||
LABELS_LR = {
|
||||
"Sent": ["label.sent"],
|
||||
"HeavyRankPosition": ["meta.ranking.is_top3"],
|
||||
"HeavyRankProbability": ["meta.ranking.weighted_oonc_model_score"],
|
||||
}
|
||||
|
||||
|
||||
def _get_new_feature_config_base(
|
||||
data_spec_path,
|
||||
labels,
|
||||
add_sparse_continous=True,
|
||||
add_gbdt=True,
|
||||
add_user_id=False,
|
||||
add_timestamp=False,
|
||||
add_user_age=False,
|
||||
feature_list_provided=[],
|
||||
opt=None,
|
||||
run_light_ranking_group_metrics_in_bq=False,
|
||||
):
|
||||
"""
|
||||
Getter of the feature config based on specification.
|
||||
|
||||
Args:
|
||||
data_spec_path: A string indicating the path of the data_spec.json file, which could be
|
||||
either a local path or a hdfs path.
|
||||
labels: A list of strings indicating the name of the label in the data spec.
|
||||
add_sparse_continous: A bool indicating if sparse_continuous feature needs to be included.
|
||||
add_gbdt: A bool indicating if gbdt feature needs to be included.
|
||||
add_user_id: A bool indicating if user_id feature needs to be included.
|
||||
add_timestamp: A bool indicating if timestamp feature needs to be included. This will be useful
|
||||
for sequential models and meta learning models.
|
||||
add_user_age: A bool indicating if the user age feature needs to be included.
|
||||
feature_list_provided: A list of features thats need to be included. If not specified, will use
|
||||
FEATURE_LIST_DEFAULT by default.
|
||||
opt: A namespace of arguments indicating the hyparameters.
|
||||
run_light_ranking_group_metrics_in_bq: A bool indicating if heavy ranker score info needs to be included to compute group metrics in BigQuery.
|
||||
|
||||
Returns:
|
||||
A twml feature config object.
|
||||
"""
|
||||
|
||||
input_size_bits = DEFAULT_INPUT_SIZE_BITS if opt is None else opt.input_size_bits
|
||||
|
||||
feature_list = feature_list_provided if feature_list_provided != [] else FEATURE_LIST_DEFAULT
|
||||
a_string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
||||
|
||||
builder = contrib_feature_config.FeatureConfigBuilder(data_spec_path=data_spec_path)
|
||||
|
||||
builder = builder.extract_feature_group(
|
||||
feature_regexes=a_string_feat_list,
|
||||
group_name="continuous",
|
||||
default_value=FEAT_CONFIG_DEFAULT_VAL,
|
||||
type_filter=["CONTINUOUS"],
|
||||
)
|
||||
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=a_string_feat_list,
|
||||
output_tensor_name="sparse_no_continuous",
|
||||
hash_space_size_bits=input_size_bits,
|
||||
type_filter=["BINARY", "DISCRETE", "STRING", "SPARSE_BINARY"],
|
||||
)
|
||||
|
||||
if add_gbdt:
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=["ads\..*"],
|
||||
output_tensor_name="gbdt_sparse",
|
||||
hash_space_size_bits=input_size_bits,
|
||||
)
|
||||
|
||||
if add_sparse_continous:
|
||||
s_string_feat_list = [f[0] for f in feature_list if f[1] == "S"]
|
||||
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=s_string_feat_list,
|
||||
output_tensor_name="sparse_continuous",
|
||||
hash_space_size_bits=input_size_bits,
|
||||
type_filter=["SPARSE_CONTINUOUS"],
|
||||
)
|
||||
|
||||
if add_user_id:
|
||||
builder = builder.extract_feature("meta.user_id")
|
||||
if add_timestamp:
|
||||
builder = builder.extract_feature("meta.timestamp")
|
||||
if add_user_age:
|
||||
builder = builder.extract_feature(USER_AGE_FEATURE_NAME)
|
||||
|
||||
if run_light_ranking_group_metrics_in_bq:
|
||||
builder = builder.extract_feature("meta.trace_id")
|
||||
builder = builder.extract_feature("meta.ranking.weighted_oonc_model_score")
|
||||
|
||||
builder = builder.add_labels(labels).define_weight("meta.weight")
|
||||
|
||||
return builder.build()
|
||||
|
||||
|
||||
def get_feature_config_with_sparse_continuous(
|
||||
data_spec_path,
|
||||
feature_list_provided=[],
|
||||
opt=None,
|
||||
add_user_id=False,
|
||||
add_timestamp=False,
|
||||
add_user_age=False,
|
||||
):
|
||||
task_name = opt.task_name if getattr(opt, "task_name", None) is not None else "OONC"
|
||||
if task_name not in LABELS_MTL:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
|
||||
return _get_new_feature_config_base(
|
||||
data_spec_path=data_spec_path,
|
||||
labels=LABELS_MTL[task_name],
|
||||
add_sparse_continous=True,
|
||||
add_user_id=add_user_id,
|
||||
add_timestamp=add_timestamp,
|
||||
add_user_age=add_user_age,
|
||||
feature_list_provided=feature_list_provided,
|
||||
opt=opt,
|
||||
)
|
||||
|
||||
|
||||
def get_feature_config_light_ranking(
|
||||
data_spec_path,
|
||||
feature_list_provided=[],
|
||||
opt=None,
|
||||
add_user_id=True,
|
||||
add_timestamp=False,
|
||||
add_user_age=False,
|
||||
add_gbdt=False,
|
||||
run_light_ranking_group_metrics_in_bq=False,
|
||||
):
|
||||
task_name = opt.task_name if getattr(opt, "task_name", None) is not None else "HeavyRankPosition"
|
||||
if task_name not in LABELS_LR:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
if not feature_list_provided:
|
||||
feature_list_provided = FEATURE_LIST_LIGHT_RANKING_DEFAULT
|
||||
|
||||
return _get_new_feature_config_base(
|
||||
data_spec_path=data_spec_path,
|
||||
labels=LABELS_LR[task_name],
|
||||
add_sparse_continous=False,
|
||||
add_gbdt=add_gbdt,
|
||||
add_user_id=add_user_id,
|
||||
add_timestamp=add_timestamp,
|
||||
add_user_age=add_user_age,
|
||||
feature_list_provided=feature_list_provided,
|
||||
opt=opt,
|
||||
run_light_ranking_group_metrics_in_bq=run_light_ranking_group_metrics_in_bq,
|
||||
)
|
BIN
pushservice/src/main/python/models/libs/graph_utils.docx
Normal file
BIN
pushservice/src/main/python/models/libs/graph_utils.docx
Normal file
Binary file not shown.
@ -1,42 +0,0 @@
|
||||
"""
|
||||
Utilties that aid in building the magic recs graph.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
def get_trainable_variables(all_trainable_variables, trainable_regexes):
|
||||
"""Returns a subset of trainable variables for training.
|
||||
|
||||
Given a collection of trainable variables, this will return all those that match the given regexes.
|
||||
Will also log those variables.
|
||||
|
||||
Args:
|
||||
all_trainable_variables (a collection of trainable tf.Variable): The variables to search through.
|
||||
trainable_regexes (a collection of regexes): Variables that match any regex will be included.
|
||||
|
||||
Returns a list of tf.Variable
|
||||
"""
|
||||
if trainable_regexes is None or len(trainable_regexes) == 0:
|
||||
tf.logging.info("No trainable regexes found. Not using get_trainable_variables behavior.")
|
||||
return None
|
||||
|
||||
assert any(
|
||||
tf.is_tensor(var) for var in all_trainable_variables
|
||||
), f"Non TF variable found: {all_trainable_variables}"
|
||||
trainable_variables = list(
|
||||
filter(
|
||||
lambda var: any(re.match(regex, var.name, re.IGNORECASE) for regex in trainable_regexes),
|
||||
all_trainable_variables,
|
||||
)
|
||||
)
|
||||
tf.logging.info(f"Using filtered trainable variables: {trainable_variables}")
|
||||
|
||||
assert (
|
||||
trainable_variables
|
||||
), "Did not find trainable variables after filtering after filtering from {} number of vars originaly. All vars: {} and train regexes: {}".format(
|
||||
len(all_trainable_variables), all_trainable_variables, trainable_regexes
|
||||
)
|
||||
return trainable_variables
|
BIN
pushservice/src/main/python/models/libs/group_metrics.docx
Normal file
BIN
pushservice/src/main/python/models/libs/group_metrics.docx
Normal file
Binary file not shown.
@ -1,114 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.computation import (
|
||||
write_grouped_metrics_to_mldash,
|
||||
)
|
||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.configuration import (
|
||||
ClassificationGroupedMetricsConfiguration,
|
||||
NDCGGroupedMetricsConfiguration,
|
||||
)
|
||||
import twml
|
||||
|
||||
from .light_ranking_metrics import (
|
||||
CGRGroupedMetricsConfiguration,
|
||||
ExpectedLossGroupedMetricsConfiguration,
|
||||
RecallGroupedMetricsConfiguration,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def run_group_metrics(trainer, data_dir, model_path, parse_fn, group_feature_name="meta.user_id"):
|
||||
|
||||
start_time = time.time()
|
||||
logging.info("Evaluating with group metrics.")
|
||||
|
||||
metrics = write_grouped_metrics_to_mldash(
|
||||
trainer=trainer,
|
||||
data_dir=data_dir,
|
||||
model_path=model_path,
|
||||
group_fn=lambda datarecord: str(
|
||||
datarecord.discreteFeatures[twml.feature_id(group_feature_name)[0]]
|
||||
),
|
||||
parse_fn=parse_fn,
|
||||
metric_configurations=[
|
||||
ClassificationGroupedMetricsConfiguration(),
|
||||
NDCGGroupedMetricsConfiguration(k=[5, 10, 20]),
|
||||
],
|
||||
total_records_to_read=1000000000,
|
||||
shuffle=False,
|
||||
mldash_metrics_name="grouped_metrics",
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
logging.info(f"Evaluated Group Metics: {metrics}.")
|
||||
logging.info(f"Group metrics evaluation time {end_time - start_time}.")
|
||||
|
||||
|
||||
def run_group_metrics_light_ranking(
|
||||
trainer, data_dir, model_path, parse_fn, group_feature_name="meta.trace_id"
|
||||
):
|
||||
|
||||
start_time = time.time()
|
||||
logging.info("Evaluating with group metrics.")
|
||||
|
||||
metrics = write_grouped_metrics_to_mldash(
|
||||
trainer=trainer,
|
||||
data_dir=data_dir,
|
||||
model_path=model_path,
|
||||
group_fn=lambda datarecord: str(
|
||||
datarecord.discreteFeatures[twml.feature_id(group_feature_name)[0]]
|
||||
),
|
||||
parse_fn=parse_fn,
|
||||
metric_configurations=[
|
||||
CGRGroupedMetricsConfiguration(lightNs=[50, 100, 200], heavyKs=[1, 3, 10, 20, 50]),
|
||||
RecallGroupedMetricsConfiguration(n=[50, 100, 200], k=[1, 3, 10, 20, 50]),
|
||||
ExpectedLossGroupedMetricsConfiguration(lightNs=[50, 100, 200]),
|
||||
],
|
||||
total_records_to_read=10000000,
|
||||
num_batches_to_load=50,
|
||||
batch_size=1024,
|
||||
shuffle=False,
|
||||
mldash_metrics_name="grouped_metrics_for_light_ranking",
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
logging.info(f"Evaluated Group Metics for Light Ranking: {metrics}.")
|
||||
logging.info(f"Group metrics evaluation time {end_time - start_time}.")
|
||||
|
||||
|
||||
def run_group_metrics_light_ranking_in_bq(trainer, params, checkpoint_path):
|
||||
logging.info("getting Test Predictions for Light Ranking Group Metrics in BigQuery !!!")
|
||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
||||
info_pool = []
|
||||
|
||||
for result in trainer.estimator.predict(
|
||||
eval_input_fn, checkpoint_path=checkpoint_path, yield_single_examples=False
|
||||
):
|
||||
traceID = result["trace_id"]
|
||||
pred = result["prediction"]
|
||||
label = result["target"]
|
||||
info = np.concatenate([traceID, pred, label], axis=1)
|
||||
info_pool.append(info)
|
||||
|
||||
info_pool = np.concatenate(info_pool)
|
||||
|
||||
locname = "/tmp/000/"
|
||||
if not os.path.exists(locname):
|
||||
os.makedirs(locname)
|
||||
|
||||
locfile = locname + params.pred_file_name
|
||||
columns = ["trace_id", "model_prediction", "meta__ranking__weighted_oonc_model_score"]
|
||||
np.savetxt(locfile, info_pool, delimiter=",", header=",".join(columns))
|
||||
tf.io.gfile.copy(locfile, params.pred_file_path + params.pred_file_name, overwrite=True)
|
||||
|
||||
if os.path.isfile(locfile):
|
||||
os.remove(locfile)
|
||||
|
||||
logging.info("Done Prediction for Light Ranking Group Metrics in BigQuery.")
|
BIN
pushservice/src/main/python/models/libs/initializer.docx
Normal file
BIN
pushservice/src/main/python/models/libs/initializer.docx
Normal file
Binary file not shown.
@ -1,118 +0,0 @@
|
||||
import numpy as np
|
||||
from tensorflow.keras import backend as K
|
||||
|
||||
|
||||
class VarianceScaling(object):
|
||||
"""Initializer capable of adapting its scale to the shape of weights.
|
||||
With `distribution="normal"`, samples are drawn from a truncated normal
|
||||
distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
|
||||
- number of input units in the weight tensor, if mode = "fan_in"
|
||||
- number of output units, if mode = "fan_out"
|
||||
- average of the numbers of input and output units, if mode = "fan_avg"
|
||||
With `distribution="uniform"`,
|
||||
samples are drawn from a uniform distribution
|
||||
within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
|
||||
# Arguments
|
||||
scale: Scaling factor (positive float).
|
||||
mode: One of "fan_in", "fan_out", "fan_avg".
|
||||
distribution: Random distribution to use. One of "normal", "uniform".
|
||||
seed: A Python integer. Used to seed the random generator.
|
||||
# Raises
|
||||
ValueError: In case of an invalid value for the "scale", mode" or
|
||||
"distribution" arguments."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale=1.0,
|
||||
mode="fan_in",
|
||||
distribution="normal",
|
||||
seed=None,
|
||||
fan_in=None,
|
||||
fan_out=None,
|
||||
):
|
||||
self.fan_in = fan_in
|
||||
self.fan_out = fan_out
|
||||
if scale <= 0.0:
|
||||
raise ValueError("`scale` must be a positive float. Got:", scale)
|
||||
mode = mode.lower()
|
||||
if mode not in {"fan_in", "fan_out", "fan_avg"}:
|
||||
raise ValueError(
|
||||
"Invalid `mode` argument: " 'expected on of {"fan_in", "fan_out", "fan_avg"} ' "but got",
|
||||
mode,
|
||||
)
|
||||
distribution = distribution.lower()
|
||||
if distribution not in {"normal", "uniform"}:
|
||||
raise ValueError(
|
||||
"Invalid `distribution` argument: " 'expected one of {"normal", "uniform"} ' "but got",
|
||||
distribution,
|
||||
)
|
||||
self.scale = scale
|
||||
self.mode = mode
|
||||
self.distribution = distribution
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, shape, dtype=None, partition_info=None):
|
||||
fan_in = shape[-2] if self.fan_in is None else self.fan_in
|
||||
fan_out = shape[-1] if self.fan_out is None else self.fan_out
|
||||
|
||||
scale = self.scale
|
||||
if self.mode == "fan_in":
|
||||
scale /= max(1.0, fan_in)
|
||||
elif self.mode == "fan_out":
|
||||
scale /= max(1.0, fan_out)
|
||||
else:
|
||||
scale /= max(1.0, float(fan_in + fan_out) / 2)
|
||||
if self.distribution == "normal":
|
||||
stddev = np.sqrt(scale) / 0.87962566103423978
|
||||
return K.truncated_normal(shape, 0.0, stddev, dtype=dtype, seed=self.seed)
|
||||
else:
|
||||
limit = np.sqrt(3.0 * scale)
|
||||
return K.random_uniform(shape, -limit, limit, dtype=dtype, seed=self.seed)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"scale": self.scale,
|
||||
"mode": self.mode,
|
||||
"distribution": self.distribution,
|
||||
"seed": self.seed,
|
||||
}
|
||||
|
||||
|
||||
def customized_glorot_uniform(seed=None, fan_in=None, fan_out=None):
|
||||
"""Glorot uniform initializer, also called Xavier uniform initializer.
|
||||
It draws samples from a uniform distribution within [-limit, limit]
|
||||
where `limit` is `sqrt(6 / (fan_in + fan_out))`
|
||||
where `fan_in` is the number of input units in the weight tensor
|
||||
and `fan_out` is the number of output units in the weight tensor.
|
||||
# Arguments
|
||||
seed: A Python integer. Used to seed the random generator.
|
||||
# Returns
|
||||
An initializer."""
|
||||
return VarianceScaling(
|
||||
scale=1.0,
|
||||
mode="fan_avg",
|
||||
distribution="uniform",
|
||||
seed=seed,
|
||||
fan_in=fan_in,
|
||||
fan_out=fan_out,
|
||||
)
|
||||
|
||||
|
||||
def customized_glorot_norm(seed=None, fan_in=None, fan_out=None):
|
||||
"""Glorot norm initializer, also called Xavier uniform initializer.
|
||||
It draws samples from a uniform distribution within [-limit, limit]
|
||||
where `limit` is `sqrt(6 / (fan_in + fan_out))`
|
||||
where `fan_in` is the number of input units in the weight tensor
|
||||
and `fan_out` is the number of output units in the weight tensor.
|
||||
# Arguments
|
||||
seed: A Python integer. Used to seed the random generator.
|
||||
# Returns
|
||||
An initializer."""
|
||||
return VarianceScaling(
|
||||
scale=1.0,
|
||||
mode="fan_avg",
|
||||
distribution="normal",
|
||||
seed=seed,
|
||||
fan_in=fan_in,
|
||||
fan_out=fan_out,
|
||||
)
|
Binary file not shown.
@ -1,255 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.configuration import (
|
||||
GroupedMetricsConfiguration,
|
||||
)
|
||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.helpers import (
|
||||
extract_prediction_from_prediction_record,
|
||||
)
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def score_loss_at_n(labels, predictions, lightN):
|
||||
"""
|
||||
Compute the absolute ScoreLoss ranking metric
|
||||
Args:
|
||||
labels (list) : A list of label values (HeavyRanking Reference)
|
||||
predictions (list): A list of prediction values (LightRanking Predictions)
|
||||
lightN (int): size of the list at which of Initial candidates to compute ScoreLoss. (LightRanking)
|
||||
"""
|
||||
assert len(labels) == len(predictions)
|
||||
|
||||
if lightN <= 0:
|
||||
return None
|
||||
|
||||
labels_with_predictions = zip(labels, predictions)
|
||||
labels_with_sorted_predictions = sorted(
|
||||
labels_with_predictions, key=lambda x: x[1], reverse=True
|
||||
)[:lightN]
|
||||
labels_top1_light = max([label for label, _ in labels_with_sorted_predictions])
|
||||
labels_top1_heavy = max(labels)
|
||||
|
||||
return labels_top1_heavy - labels_top1_light
|
||||
|
||||
|
||||
def cgr_at_nk(labels, predictions, lightN, heavyK):
|
||||
"""
|
||||
Compute Cumulative Gain Ratio (CGR) ranking metric
|
||||
Args:
|
||||
labels (list) : A list of label values (HeavyRanking Reference)
|
||||
predictions (list): A list of prediction values (LightRanking Predictions)
|
||||
lightN (int): size of the list at which of Initial candidates to compute CGR. (LightRanking)
|
||||
heavyK (int): size of the list at which of Refined candidates to compute CGR. (HeavyRanking)
|
||||
"""
|
||||
assert len(labels) == len(predictions)
|
||||
|
||||
if (not lightN) or (not heavyK):
|
||||
out = None
|
||||
elif lightN <= 0 or heavyK <= 0:
|
||||
out = None
|
||||
else:
|
||||
|
||||
labels_with_predictions = zip(labels, predictions)
|
||||
labels_with_sorted_predictions = sorted(
|
||||
labels_with_predictions, key=lambda x: x[1], reverse=True
|
||||
)[:lightN]
|
||||
labels_topN_light = [label for label, _ in labels_with_sorted_predictions]
|
||||
|
||||
if lightN <= heavyK:
|
||||
cg_light = sum(labels_topN_light)
|
||||
else:
|
||||
labels_topK_heavy_from_light = sorted(labels_topN_light, reverse=True)[:heavyK]
|
||||
cg_light = sum(labels_topK_heavy_from_light)
|
||||
|
||||
ideal_ordering = sorted(labels, reverse=True)
|
||||
cg_heavy = sum(ideal_ordering[: min(lightN, heavyK)])
|
||||
|
||||
out = 0.0
|
||||
if cg_heavy != 0:
|
||||
out = max(cg_light / cg_heavy, 0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _get_weight(w, atK):
|
||||
if not w:
|
||||
return 1.0
|
||||
elif len(w) <= atK:
|
||||
return 0.0
|
||||
else:
|
||||
return w[atK]
|
||||
|
||||
|
||||
def recall_at_nk(labels, predictions, n=None, k=None, w=None):
|
||||
"""
|
||||
Recall at N-K ranking metric
|
||||
Args:
|
||||
labels (list): A list of label values
|
||||
predictions (list): A list of prediction values
|
||||
n (int): size of the list at which of predictions to compute recall. (Light Ranking Predictions)
|
||||
The default is None in which case the length of the provided predictions is used as L
|
||||
k (int): size of the list at which of labels to compute recall. (Heavy Ranking Predictions)
|
||||
The default is None in which case the length of the provided labels is used as L
|
||||
w (list): weight vector sorted by labels
|
||||
"""
|
||||
assert len(labels) == len(predictions)
|
||||
|
||||
if not any(labels):
|
||||
out = None
|
||||
else:
|
||||
|
||||
safe_n = len(predictions) if not n else min(len(predictions), n)
|
||||
safe_k = len(labels) if not k else min(len(labels), k)
|
||||
|
||||
labels_with_predictions = zip(labels, predictions)
|
||||
sorted_labels_with_predictions = sorted(
|
||||
labels_with_predictions, key=lambda x: x[0], reverse=True
|
||||
)
|
||||
|
||||
order_sorted_labels_predictions = zip(range(len(labels)), *zip(*sorted_labels_with_predictions))
|
||||
|
||||
order_with_predictions = [
|
||||
(order, pred) for order, label, pred in order_sorted_labels_predictions
|
||||
]
|
||||
order_with_sorted_predictions = sorted(order_with_predictions, key=lambda x: x[1], reverse=True)
|
||||
|
||||
pred_sorted_order_at_n = [order for order, _ in order_with_sorted_predictions][:safe_n]
|
||||
|
||||
intersection_weight = [
|
||||
_get_weight(w, order) if order < safe_k else 0 for order in pred_sorted_order_at_n
|
||||
]
|
||||
|
||||
intersection_score = sum(intersection_weight)
|
||||
full_score = sum(w) if w else float(safe_k)
|
||||
|
||||
out = 0.0
|
||||
if full_score != 0:
|
||||
out = intersection_score / full_score
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ExpectedLossGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
||||
"""
|
||||
This is the Expected Loss Grouped metric computation configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, lightNs=[]):
|
||||
"""
|
||||
Args:
|
||||
lightNs (list): size of the list at which of Initial candidates to compute Expected Loss. (LightRanking)
|
||||
"""
|
||||
self.lightNs = lightNs
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "ExpectedLoss"
|
||||
|
||||
@property
|
||||
def metrics_dict(self):
|
||||
metrics_to_compute = {}
|
||||
for lightN in self.lightNs:
|
||||
metric_name = "ExpectedLoss_atLight_" + str(lightN)
|
||||
metrics_to_compute[metric_name] = partial(score_loss_at_n, lightN=lightN)
|
||||
return metrics_to_compute
|
||||
|
||||
def extract_label(self, prec, drec, drec_label):
|
||||
return drec_label
|
||||
|
||||
def extract_prediction(self, prec, drec, drec_label):
|
||||
return extract_prediction_from_prediction_record(prec)
|
||||
|
||||
|
||||
class CGRGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
||||
"""
|
||||
This is the Cumulative Gain Ratio (CGR) Grouped metric computation configuration.
|
||||
CGR at the max length of each session is the default.
|
||||
CGR at additional positions can be computed by specifying a list of 'n's and 'k's
|
||||
"""
|
||||
|
||||
def __init__(self, lightNs=[], heavyKs=[]):
|
||||
"""
|
||||
Args:
|
||||
lightNs (list): size of the list at which of Initial candidates to compute CGR. (LightRanking)
|
||||
heavyK (int): size of the list at which of Refined candidates to compute CGR. (HeavyRanking)
|
||||
"""
|
||||
self.lightNs = lightNs
|
||||
self.heavyKs = heavyKs
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "cgr"
|
||||
|
||||
@property
|
||||
def metrics_dict(self):
|
||||
metrics_to_compute = {}
|
||||
for lightN in self.lightNs:
|
||||
for heavyK in self.heavyKs:
|
||||
metric_name = "cgr_atLight_" + str(lightN) + "_atHeavy_" + str(heavyK)
|
||||
metrics_to_compute[metric_name] = partial(cgr_at_nk, lightN=lightN, heavyK=heavyK)
|
||||
return metrics_to_compute
|
||||
|
||||
def extract_label(self, prec, drec, drec_label):
|
||||
return drec_label
|
||||
|
||||
def extract_prediction(self, prec, drec, drec_label):
|
||||
return extract_prediction_from_prediction_record(prec)
|
||||
|
||||
|
||||
class RecallGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
||||
"""
|
||||
This is the Recall Grouped metric computation configuration.
|
||||
Recall at the max length of each session is the default.
|
||||
Recall at additional positions can be computed by specifying a list of 'n's and 'k's
|
||||
"""
|
||||
|
||||
def __init__(self, n=[], k=[], w=[]):
|
||||
"""
|
||||
Args:
|
||||
n (list): A list of ints. List of prediction rank thresholds (for light)
|
||||
k (list): A list of ints. List of label rank thresholds (for heavy)
|
||||
"""
|
||||
self.predN = n
|
||||
self.labelK = k
|
||||
self.weight = w
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "group_recall"
|
||||
|
||||
@property
|
||||
def metrics_dict(self):
|
||||
metrics_to_compute = {"group_recall_unweighted": recall_at_nk}
|
||||
if not self.weight:
|
||||
metrics_to_compute["group_recall_weighted"] = partial(recall_at_nk, w=self.weight)
|
||||
|
||||
if self.predN and self.labelK:
|
||||
for n in self.predN:
|
||||
for k in self.labelK:
|
||||
if n >= k:
|
||||
metrics_to_compute[
|
||||
"group_recall_unweighted_at_L" + str(n) + "_at_H" + str(k)
|
||||
] = partial(recall_at_nk, n=n, k=k)
|
||||
if self.weight:
|
||||
metrics_to_compute[
|
||||
"group_recall_weighted_at_L" + str(n) + "_at_H" + str(k)
|
||||
] = partial(recall_at_nk, n=n, k=k, w=self.weight)
|
||||
|
||||
if self.labelK and not self.predN:
|
||||
for k in self.labelK:
|
||||
metrics_to_compute["group_recall_unweighted_at_full_at_H" + str(k)] = partial(
|
||||
recall_at_nk, k=k
|
||||
)
|
||||
if self.weight:
|
||||
metrics_to_compute["group_recall_weighted_at_full_at_H" + str(k)] = partial(
|
||||
recall_at_nk, k=k, w=self.weight
|
||||
)
|
||||
return metrics_to_compute
|
||||
|
||||
def extract_label(self, prec, drec, drec_label):
|
||||
return drec_label
|
||||
|
||||
def extract_prediction(self, prec, drec, drec_label):
|
||||
return extract_prediction_from_prediction_record(prec)
|
BIN
pushservice/src/main/python/models/libs/metric_fn_utils.docx
Normal file
BIN
pushservice/src/main/python/models/libs/metric_fn_utils.docx
Normal file
Binary file not shown.
@ -1,294 +0,0 @@
|
||||
"""
|
||||
Utilties for constructing a metric_fn for magic recs.
|
||||
"""
|
||||
|
||||
from twml.contrib.metrics.metrics import (
|
||||
get_dual_binary_tasks_metric_fn,
|
||||
get_numeric_metric_fn,
|
||||
get_partial_multi_binary_class_metric_fn,
|
||||
get_single_binary_task_metric_fn,
|
||||
)
|
||||
|
||||
from .model_utils import generate_disliked_mask
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
METRIC_BOOK = {
|
||||
"OONC": ["OONC"],
|
||||
"OONC_Engagement": ["OONC", "Engagement"],
|
||||
"Sent": ["Sent"],
|
||||
"HeavyRankPosition": ["HeavyRankPosition"],
|
||||
"HeavyRankProbability": ["HeavyRankProbability"],
|
||||
}
|
||||
|
||||
USER_AGE_FEATURE_NAME = "accountAge"
|
||||
NEW_USER_AGE_CUTOFF = 0
|
||||
|
||||
|
||||
def remove_padding_and_flatten(tensor, valid_batch_size):
|
||||
"""Remove the padding of the input padded tensor given the valid batch size tensor,
|
||||
then flatten the output with respect to the first dimension.
|
||||
Args:
|
||||
tensor: A tensor of size [META_BATCH_SIZE, BATCH_SIZE, FEATURE_DIM].
|
||||
valid_batch_size: A tensor of size [META_BATCH_SIZE], with each element indicating
|
||||
the effective batch size of the BATCH_SIZE dimension.
|
||||
|
||||
Returns:
|
||||
A tesnor of size [tf.reduce_sum(valid_batch_size), FEATURE_DIM].
|
||||
"""
|
||||
unpadded_ragged_tensor = tf.RaggedTensor.from_tensor(tensor=tensor, lengths=valid_batch_size)
|
||||
|
||||
return unpadded_ragged_tensor.flat_values
|
||||
|
||||
|
||||
def safe_mask(values, mask):
|
||||
"""Mask values if possible.
|
||||
|
||||
Boolean mask inputed values if and only if values is a tensor of the same dimension as mask (or can be broadcasted to that dimension).
|
||||
|
||||
Args:
|
||||
values (Any or Tensor): Input tensor to mask. Dim 0 should be size N.
|
||||
mask (boolean tensor): A boolean tensor of size N.
|
||||
|
||||
Returns Values or Values masked.
|
||||
"""
|
||||
if values is None:
|
||||
return values
|
||||
if not tf.is_tensor(values):
|
||||
return values
|
||||
values_shape = values.get_shape()
|
||||
if not values_shape or len(values_shape) == 0:
|
||||
return values
|
||||
if not mask.get_shape().is_compatible_with(values_shape[0]):
|
||||
return values
|
||||
return tf.boolean_mask(values, mask)
|
||||
|
||||
|
||||
def add_new_user_metrics(metric_fn):
|
||||
"""Will stratify the metric_fn by adding new user metrics.
|
||||
|
||||
Given an input metric_fn, double every metric: One will be the orignal and the other will only include those for new users.
|
||||
|
||||
Args:
|
||||
metric_fn (python function): Base twml metric_fn.
|
||||
|
||||
Returns a metric_fn with new user metrics included.
|
||||
"""
|
||||
|
||||
def metric_fn_with_new_users(graph_output, labels, weights):
|
||||
if USER_AGE_FEATURE_NAME not in graph_output:
|
||||
raise ValueError(
|
||||
"In order to get metrics stratified by user age, {name} feature should be added to model graph output. However, only the following output keys were found: {keys}.".format(
|
||||
name=USER_AGE_FEATURE_NAME, keys=graph_output.keys()
|
||||
)
|
||||
)
|
||||
|
||||
metric_ops = metric_fn(graph_output, labels, weights)
|
||||
|
||||
is_new = tf.reshape(
|
||||
tf.math.less_equal(
|
||||
tf.cast(graph_output[USER_AGE_FEATURE_NAME], tf.int64),
|
||||
tf.cast(NEW_USER_AGE_CUTOFF, tf.int64),
|
||||
),
|
||||
[-1],
|
||||
)
|
||||
|
||||
labels = safe_mask(labels, is_new)
|
||||
weights = safe_mask(weights, is_new)
|
||||
graph_output = {key: safe_mask(values, is_new) for key, values in graph_output.items()}
|
||||
|
||||
new_user_metric_ops = metric_fn(graph_output, labels, weights)
|
||||
new_user_metric_ops = {name + "_new_users": ops for name, ops in new_user_metric_ops.items()}
|
||||
metric_ops.update(new_user_metric_ops)
|
||||
return metric_ops
|
||||
|
||||
return metric_fn_with_new_users
|
||||
|
||||
|
||||
def get_meta_learn_single_binary_task_metric_fn(
|
||||
metrics, classnames, top_k=(5, 5, 5), use_top_k=False
|
||||
):
|
||||
"""Wrapper function to use the metric_fn with meta learning evaluation scheme.
|
||||
|
||||
Args:
|
||||
metrics: A list of string representing metric names.
|
||||
classnames: A list of string repsenting class names, In case of multiple binary class models,
|
||||
the names for each class or label.
|
||||
top_k: A tuple of int to specify top K metrics.
|
||||
use_top_k: A boolean value indicating of top K of metrics is used.
|
||||
|
||||
Returns:
|
||||
A customized metric_fn function.
|
||||
"""
|
||||
|
||||
def get_eval_metric_ops(graph_output, labels, weights):
|
||||
"""The op func of the eval_metrics. Comparing with normal version,
|
||||
the difference is we flatten the output, label, and weights.
|
||||
|
||||
Args:
|
||||
graph_output: A dict of tensors.
|
||||
labels: A tensor of int32 be the value of either 0 or 1.
|
||||
weights: A tensor of float32 to indicate the per record weight.
|
||||
|
||||
Returns:
|
||||
A dict of metric names and values.
|
||||
"""
|
||||
metric_op_weighted = get_partial_multi_binary_class_metric_fn(
|
||||
metrics, predcols=0, classes=classnames
|
||||
)
|
||||
classnames_unweighted = ["unweighted_" + classname for classname in classnames]
|
||||
metric_op_unweighted = get_partial_multi_binary_class_metric_fn(
|
||||
metrics, predcols=0, classes=classnames_unweighted
|
||||
)
|
||||
|
||||
valid_batch_size = graph_output["valid_batch_size"]
|
||||
graph_output["output"] = remove_padding_and_flatten(graph_output["output"], valid_batch_size)
|
||||
labels = remove_padding_and_flatten(labels, valid_batch_size)
|
||||
weights = remove_padding_and_flatten(weights, valid_batch_size)
|
||||
|
||||
tf.ensure_shape(graph_output["output"], [None, 1])
|
||||
tf.ensure_shape(labels, [None, 1])
|
||||
tf.ensure_shape(weights, [None, 1])
|
||||
|
||||
metrics_weighted = metric_op_weighted(graph_output, labels, weights)
|
||||
metrics_unweighted = metric_op_unweighted(graph_output, labels, None)
|
||||
metrics_weighted.update(metrics_unweighted)
|
||||
|
||||
if use_top_k:
|
||||
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=top_k, predcol=0, labelcol=1)
|
||||
metrics_numeric = metric_op_numeric(graph_output, labels, weights)
|
||||
metrics_weighted.update(metrics_numeric)
|
||||
return metrics_weighted
|
||||
|
||||
return get_eval_metric_ops
|
||||
|
||||
|
||||
def get_meta_learn_dual_binary_tasks_metric_fn(
|
||||
metrics, classnames, top_k=(5, 5, 5), use_top_k=False
|
||||
):
|
||||
"""Wrapper function to use the metric_fn with meta learning evaluation scheme.
|
||||
|
||||
Args:
|
||||
metrics: A list of string representing metric names.
|
||||
classnames: A list of string repsenting class names, In case of multiple binary class models,
|
||||
the names for each class or label.
|
||||
top_k: A tuple of int to specify top K metrics.
|
||||
use_top_k: A boolean value indicating of top K of metrics is used.
|
||||
|
||||
Returns:
|
||||
A customized metric_fn function.
|
||||
"""
|
||||
|
||||
def get_eval_metric_ops(graph_output, labels, weights):
|
||||
"""The op func of the eval_metrics. Comparing with normal version,
|
||||
the difference is we flatten the output, label, and weights.
|
||||
|
||||
Args:
|
||||
graph_output: A dict of tensors.
|
||||
labels: A tensor of int32 be the value of either 0 or 1.
|
||||
weights: A tensor of float32 to indicate the per record weight.
|
||||
|
||||
Returns:
|
||||
A dict of metric names and values.
|
||||
"""
|
||||
metric_op_weighted = get_partial_multi_binary_class_metric_fn(
|
||||
metrics, predcols=[0, 1], classes=classnames
|
||||
)
|
||||
classnames_unweighted = ["unweighted_" + classname for classname in classnames]
|
||||
metric_op_unweighted = get_partial_multi_binary_class_metric_fn(
|
||||
metrics, predcols=[0, 1], classes=classnames_unweighted
|
||||
)
|
||||
|
||||
valid_batch_size = graph_output["valid_batch_size"]
|
||||
graph_output["output"] = remove_padding_and_flatten(graph_output["output"], valid_batch_size)
|
||||
labels = remove_padding_and_flatten(labels, valid_batch_size)
|
||||
weights = remove_padding_and_flatten(weights, valid_batch_size)
|
||||
|
||||
tf.ensure_shape(graph_output["output"], [None, 2])
|
||||
tf.ensure_shape(labels, [None, 2])
|
||||
tf.ensure_shape(weights, [None, 1])
|
||||
|
||||
metrics_weighted = metric_op_weighted(graph_output, labels, weights)
|
||||
metrics_unweighted = metric_op_unweighted(graph_output, labels, None)
|
||||
metrics_weighted.update(metrics_unweighted)
|
||||
|
||||
if use_top_k:
|
||||
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=top_k, predcol=2, labelcol=2)
|
||||
metrics_numeric = metric_op_numeric(graph_output, labels, weights)
|
||||
metrics_weighted.update(metrics_numeric)
|
||||
return metrics_weighted
|
||||
|
||||
return get_eval_metric_ops
|
||||
|
||||
|
||||
def get_metric_fn(task_name, use_stratify_metrics, use_meta_batch=False):
|
||||
"""Will retrieve the metric_fn for magic recs.
|
||||
|
||||
Args:
|
||||
task_name (string): Which task is being used for this model.
|
||||
use_stratify_metrics (boolean): Should we add stratified metrics (new user metrics).
|
||||
use_meta_batch (boolean): If the output/label/weights are passed in 3D shape instead of
|
||||
2D shape.
|
||||
|
||||
Returns:
|
||||
A metric_fn function to pass in twml Trainer.
|
||||
"""
|
||||
if task_name not in METRIC_BOOK:
|
||||
raise ValueError(
|
||||
"Task name of {task_name} not recognized. Unable to retrieve metrics.".format(
|
||||
task_name=task_name
|
||||
)
|
||||
)
|
||||
class_names = METRIC_BOOK[task_name]
|
||||
if use_meta_batch:
|
||||
get_n_binary_task_metric_fn = (
|
||||
get_meta_learn_single_binary_task_metric_fn
|
||||
if len(class_names) == 1
|
||||
else get_meta_learn_dual_binary_tasks_metric_fn
|
||||
)
|
||||
else:
|
||||
get_n_binary_task_metric_fn = (
|
||||
get_single_binary_task_metric_fn if len(class_names) == 1 else get_dual_binary_tasks_metric_fn
|
||||
)
|
||||
|
||||
metric_fn = get_n_binary_task_metric_fn(metrics=None, classnames=METRIC_BOOK[task_name])
|
||||
|
||||
if use_stratify_metrics:
|
||||
metric_fn = add_new_user_metrics(metric_fn)
|
||||
|
||||
return metric_fn
|
||||
|
||||
|
||||
def flip_disliked_labels(metric_fn):
|
||||
"""This function returns an adapted metric_fn which flips the labels of the OONCed evaluation data to 0 if it is disliked.
|
||||
Args:
|
||||
metric_fn: A metric_fn function to pass in twml Trainer.
|
||||
|
||||
Returns:
|
||||
_adapted_metric_fn: A customized metric_fn function with disliked OONC labels flipped.
|
||||
"""
|
||||
|
||||
def _adapted_metric_fn(graph_output, labels, weights):
|
||||
"""A customized metric_fn function with disliked OONC labels flipped.
|
||||
|
||||
Args:
|
||||
graph_output: A dict of tensors.
|
||||
labels: labels of training samples, which is a 2D tensor of shape batch_size x 3: [OONCs, engagements, dislikes]
|
||||
weights: A tensor of float32 to indicate the per record weight.
|
||||
|
||||
Returns:
|
||||
A dict of metric names and values.
|
||||
"""
|
||||
# We want to multiply the label of the observation by 0 only when it is disliked
|
||||
disliked_mask = generate_disliked_mask(labels)
|
||||
|
||||
# Extract OONC and engagement labels only.
|
||||
labels = tf.reshape(labels[:, 0:2], shape=[-1, 2])
|
||||
|
||||
# Labels will be set to 0 if it is disliked.
|
||||
adapted_labels = labels * tf.cast(tf.logical_not(disliked_mask), dtype=labels.dtype)
|
||||
|
||||
return metric_fn(graph_output, adapted_labels, weights)
|
||||
|
||||
return _adapted_metric_fn
|
BIN
pushservice/src/main/python/models/libs/model_args.docx
Normal file
BIN
pushservice/src/main/python/models/libs/model_args.docx
Normal file
Binary file not shown.
@ -1,231 +0,0 @@
|
||||
from twml.trainers import DataRecordTrainer
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def get_arg_parser():
|
||||
parser = DataRecordTrainer.add_parser_arguments()
|
||||
|
||||
parser.add_argument(
|
||||
"--input_size_bits",
|
||||
type=int,
|
||||
default=18,
|
||||
help="number of bits allocated to the input size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_trainer_name",
|
||||
default="magic_recs_mlp_calibration_MTL_OONC_Engagement",
|
||||
type=str,
|
||||
help="specify the model trainer name.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="deepnorm_gbdt_inputdrop2_rescale",
|
||||
type=str,
|
||||
help="specify the model type to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--feat_config_type",
|
||||
default="get_feature_config_with_sparse_continuous",
|
||||
type=str,
|
||||
help="specify the feature configure function to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--directly_export_best",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to directly_export best_checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_start_base_dir",
|
||||
default="none",
|
||||
type=str,
|
||||
help="latest ckpt in this folder will be used to ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature_list",
|
||||
default="none",
|
||||
type=str,
|
||||
help="Which features to use for training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--momentum", default=0.99999, type=float, help="Momentum term for batch normalization"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dropout",
|
||||
default=0.2,
|
||||
type=float,
|
||||
help="input_dropout_rate to rescale output by (1 - input_dropout_rate)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_layer_1_size", default=256, type=int, help="Size of MLP_branch layer 1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_layer_2_size", default=128, type=int, help="Size of MLP_branch layer 2"
|
||||
)
|
||||
parser.add_argument("--out_layer_3_size", default=64, type=int, help="Size of MLP_branch layer 3")
|
||||
parser.add_argument(
|
||||
"--sparse_embedding_size", default=50, type=int, help="Dimensionality of sparse embedding layer"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dense_embedding_size", default=128, type=int, help="Dimensionality of dense embedding layer"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_uam_label",
|
||||
default=False,
|
||||
type=str,
|
||||
help="Whether to use uam_label or not",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default="OONC_Engagement",
|
||||
type=str,
|
||||
help="specify the task name to use: OONC or OONC_Engagement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_weight",
|
||||
default=0.9,
|
||||
type=float,
|
||||
help="Initial OONC Task Weight MTL: OONC+Engagement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_engagement_weight",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to use engagement weight for base model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mtl_num_extra_layers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of Hidden Layers for each TaskBranch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mtl_neuron_scale", type=int, default=4, help="Scaling Factor of Neurons in MTL Extra Layers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_oonc_score",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to use oonc score only or combined score.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_stratified_metrics",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use stratified metrics: Break out new-user metrics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_group_metrics",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Will run evaluation metrics grouped by user.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_full_scope",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Will add extra scope and naming to graph.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainable_regexes",
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="The union of variables specified by the list of regexes will be considered trainable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fine_tuning.ckpt_to_initialize_from",
|
||||
dest="fine_tuning_ckpt_to_initialize_from",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint path from which to warm start. Indicates the pre-trained model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fine_tuning.warm_start_scope_regex",
|
||||
dest="fine_tuning_warm_start_scope_regex",
|
||||
type=str,
|
||||
default=None,
|
||||
help="All variables matching this will be restored.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params(args=None):
|
||||
parser = get_arg_parser()
|
||||
if args is None:
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def get_arg_parser_light_ranking():
|
||||
parser = get_arg_parser()
|
||||
|
||||
parser.add_argument(
|
||||
"--use_record_weight",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to use record weight for base model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_record_weight", default=0.0, type=float, help="Minimum record weight to use."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--smooth_weight", default=0.0, type=float, help="Factor to smooth Rank Position Weight."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_mlp_layers", type=int, default=3, help="Number of Hidden Layers for MLP model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlp_neuron_scale", type=int, default=4, help="Scaling Factor of Neurons in MLP Layers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_light_ranking_group_metrics",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Will run evaluation metrics grouped by user for Light Ranking.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_missing_sub_branch",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use missing value sub-branch for Light Ranking.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gbdt_features",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use GBDT features for Light Ranking.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_light_ranking_group_metrics_in_bq",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to get_predictions for Light Ranking to compute group metrics in BigQuery.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pred_file_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pred_file_name",
|
||||
default=None,
|
||||
type=str,
|
||||
help="path",
|
||||
)
|
||||
return parser
|
BIN
pushservice/src/main/python/models/libs/model_utils.docx
Normal file
BIN
pushservice/src/main/python/models/libs/model_utils.docx
Normal file
Binary file not shown.
@ -1,339 +0,0 @@
|
||||
import sys
|
||||
|
||||
import twml
|
||||
|
||||
from .initializer import customized_glorot_uniform
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
import yaml
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def read_config(whitelist_yaml_file):
|
||||
with tf.gfile.FastGFile(whitelist_yaml_file) as f:
|
||||
try:
|
||||
return yaml.safe_load(f)
|
||||
except yaml.YAMLError as exc:
|
||||
print(exc)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _sparse_feature_fixup(features, input_size_bits):
|
||||
"""Rebuild a sparse tensor feature so that its dense shape attribute is present.
|
||||
|
||||
Arguments:
|
||||
features (SparseTensor): Sparse feature tensor of shape ``(B, sparse_feature_dim)``.
|
||||
input_size_bits (int): Number of columns in ``log2`` scale. Must be positive.
|
||||
|
||||
Returns:
|
||||
SparseTensor: Rebuilt and non-faulty version of `features`."""
|
||||
sparse_feature_dim = tf.constant(2**input_size_bits, dtype=tf.int64)
|
||||
sparse_shape = tf.stack([features.dense_shape[0], sparse_feature_dim])
|
||||
sparse_tf = tf.SparseTensor(features.indices, features.values, sparse_shape)
|
||||
return sparse_tf
|
||||
|
||||
|
||||
def self_atten_dense(input, out_dim, activation=None, use_bias=True, name=None):
|
||||
def safe_concat(base, suffix):
|
||||
"""Concats variables name components if base is given."""
|
||||
if not base:
|
||||
return base
|
||||
return f"{base}:{suffix}"
|
||||
|
||||
input_dim = input.shape.as_list()[1]
|
||||
|
||||
sigmoid_out = twml.layers.FullDense(
|
||||
input_dim, dtype=tf.float32, activation=tf.nn.sigmoid, name=safe_concat(name, "sigmoid_out")
|
||||
)(input)
|
||||
atten_input = sigmoid_out * input
|
||||
mlp_out = twml.layers.FullDense(
|
||||
out_dim,
|
||||
dtype=tf.float32,
|
||||
activation=activation,
|
||||
use_bias=use_bias,
|
||||
name=safe_concat(name, "mlp_out"),
|
||||
)(atten_input)
|
||||
return mlp_out
|
||||
|
||||
|
||||
def get_dense_out(input, out_dim, activation, dense_type):
|
||||
if dense_type == "full_dense":
|
||||
out = twml.layers.FullDense(out_dim, dtype=tf.float32, activation=activation)(input)
|
||||
elif dense_type == "self_atten_dense":
|
||||
out = self_atten_dense(input, out_dim, activation=activation)
|
||||
return out
|
||||
|
||||
|
||||
def get_input_trans_func(bn_normalized_dense, is_training):
|
||||
gw_normalized_dense = tf.expand_dims(bn_normalized_dense, -1)
|
||||
group_num = bn_normalized_dense.shape.as_list()[1]
|
||||
|
||||
gw_normalized_dense = GroupWiseTrans(group_num, 1, 8, name="groupwise_1", activation=tf.tanh)(
|
||||
gw_normalized_dense
|
||||
)
|
||||
gw_normalized_dense = GroupWiseTrans(group_num, 8, 4, name="groupwise_2", activation=tf.tanh)(
|
||||
gw_normalized_dense
|
||||
)
|
||||
gw_normalized_dense = GroupWiseTrans(group_num, 4, 1, name="groupwise_3", activation=tf.tanh)(
|
||||
gw_normalized_dense
|
||||
)
|
||||
|
||||
gw_normalized_dense = tf.squeeze(gw_normalized_dense, [-1])
|
||||
|
||||
bn_gw_normalized_dense = tf.layers.batch_normalization(
|
||||
gw_normalized_dense,
|
||||
training=is_training,
|
||||
renorm_momentum=0.9999,
|
||||
momentum=0.9999,
|
||||
renorm=is_training,
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
return bn_gw_normalized_dense
|
||||
|
||||
|
||||
def tensor_dropout(
|
||||
input_tensor,
|
||||
rate,
|
||||
is_training,
|
||||
sparse_tensor=None,
|
||||
):
|
||||
"""
|
||||
Implements dropout layer for both dense and sparse input_tensor
|
||||
|
||||
Arguments:
|
||||
input_tensor:
|
||||
B x D dense tensor, or a sparse tensor
|
||||
rate (float32):
|
||||
dropout rate
|
||||
is_training (bool):
|
||||
training stage or not.
|
||||
sparse_tensor (bool):
|
||||
whether the input_tensor is sparse tensor or not. Default to be None, this value has to be passed explicitly.
|
||||
rescale_sparse_dropout (bool):
|
||||
Do we need to do rescaling or not.
|
||||
Returns:
|
||||
tensor dropped out"""
|
||||
if sparse_tensor == True:
|
||||
if is_training:
|
||||
with tf.variable_scope("sparse_dropout"):
|
||||
values = input_tensor.values
|
||||
keep_mask = tf.keras.backend.random_binomial(
|
||||
tf.shape(values), p=1 - rate, dtype=tf.float32, seed=None
|
||||
)
|
||||
keep_mask.set_shape([None])
|
||||
keep_mask = tf.cast(keep_mask, tf.bool)
|
||||
|
||||
keep_indices = tf.boolean_mask(input_tensor.indices, keep_mask, axis=0)
|
||||
keep_values = tf.boolean_mask(values, keep_mask, axis=0)
|
||||
|
||||
dropped_tensor = tf.SparseTensor(keep_indices, keep_values, input_tensor.dense_shape)
|
||||
return dropped_tensor
|
||||
else:
|
||||
return input_tensor
|
||||
elif sparse_tensor == False:
|
||||
return tf.layers.dropout(input_tensor, rate=rate, training=is_training)
|
||||
|
||||
|
||||
def adaptive_transformation(bn_normalized_dense, is_training, func_type="default"):
|
||||
assert func_type in [
|
||||
"default",
|
||||
"tiny",
|
||||
], f"fun_type can only be one of default and tiny, but get {func_type}"
|
||||
|
||||
gw_normalized_dense = tf.expand_dims(bn_normalized_dense, -1)
|
||||
group_num = bn_normalized_dense.shape.as_list()[1]
|
||||
|
||||
if func_type == "default":
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 1, 8, name="groupwise_1", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 8, 4, name="groupwise_2", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 4, 1, name="groupwise_3", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
elif func_type == "tiny":
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 1, 2, name="groupwise_1", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 2, 1, name="groupwise_2", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 1, 1, name="groupwise_3", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = tf.squeeze(gw_normalized_dense, [-1])
|
||||
bn_gw_normalized_dense = tf.layers.batch_normalization(
|
||||
gw_normalized_dense,
|
||||
training=is_training,
|
||||
renorm_momentum=0.9999,
|
||||
momentum=0.9999,
|
||||
renorm=is_training,
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
return bn_gw_normalized_dense
|
||||
|
||||
|
||||
class FastGroupWiseTrans(object):
|
||||
"""
|
||||
used to apply group-wise fully connected layers to the input.
|
||||
it applies a tiny, unique MLP to each individual feature."""
|
||||
|
||||
def __init__(self, group_num, input_dim, out_dim, name, activation=None, init_multiplier=1):
|
||||
self.group_num = group_num
|
||||
self.input_dim = input_dim
|
||||
self.out_dim = out_dim
|
||||
self.activation = activation
|
||||
self.init_multiplier = init_multiplier
|
||||
|
||||
self.w = tf.get_variable(
|
||||
name + "_group_weight",
|
||||
[1, group_num, input_dim, out_dim],
|
||||
initializer=customized_glorot_uniform(
|
||||
fan_in=input_dim * init_multiplier, fan_out=out_dim * init_multiplier
|
||||
),
|
||||
trainable=True,
|
||||
)
|
||||
self.b = tf.get_variable(
|
||||
name + "_group_bias",
|
||||
[1, group_num, out_dim],
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
def __call__(self, input_tensor):
|
||||
"""
|
||||
input_tensor: batch_size x group_num x input_dim
|
||||
output_tensor: batch_size x group_num x out_dim"""
|
||||
input_tensor_expand = tf.expand_dims(input_tensor, axis=-1)
|
||||
|
||||
output_tensor = tf.add(
|
||||
tf.reduce_sum(tf.multiply(input_tensor_expand, self.w), axis=-2, keepdims=False),
|
||||
self.b,
|
||||
)
|
||||
|
||||
if self.activation is not None:
|
||||
output_tensor = self.activation(output_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class GroupWiseTrans(object):
|
||||
"""
|
||||
Used to apply group fully connected layers to the input.
|
||||
"""
|
||||
|
||||
def __init__(self, group_num, input_dim, out_dim, name, activation=None):
|
||||
self.group_num = group_num
|
||||
self.input_dim = input_dim
|
||||
self.out_dim = out_dim
|
||||
self.activation = activation
|
||||
|
||||
w_list, b_list = [], []
|
||||
for idx in range(out_dim):
|
||||
this_w = tf.get_variable(
|
||||
name + f"_group_weight_{idx}",
|
||||
[1, group_num, input_dim],
|
||||
initializer=tf.keras.initializers.glorot_uniform(),
|
||||
trainable=True,
|
||||
)
|
||||
this_b = tf.get_variable(
|
||||
name + f"_group_bias_{idx}",
|
||||
[1, group_num, 1],
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
trainable=True,
|
||||
)
|
||||
w_list.append(this_w)
|
||||
b_list.append(this_b)
|
||||
self.w_list = w_list
|
||||
self.b_list = b_list
|
||||
|
||||
def __call__(self, input_tensor):
|
||||
"""
|
||||
input_tensor: batch_size x group_num x input_dim
|
||||
output_tensor: batch_size x group_num x out_dim
|
||||
"""
|
||||
out_tensor_list = []
|
||||
for idx in range(self.out_dim):
|
||||
this_res = (
|
||||
tf.reduce_sum(input_tensor * self.w_list[idx], axis=-1, keepdims=True) + self.b_list[idx]
|
||||
)
|
||||
out_tensor_list.append(this_res)
|
||||
output_tensor = tf.concat(out_tensor_list, axis=-1)
|
||||
|
||||
if self.activation is not None:
|
||||
output_tensor = self.activation(output_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def add_scalar_summary(var, name, name_scope="hist_dense_feature/"):
|
||||
with tf.name_scope("summaries/"):
|
||||
with tf.name_scope(name_scope):
|
||||
tf.summary.scalar(name, var)
|
||||
|
||||
|
||||
def add_histogram_summary(var, name, name_scope="hist_dense_feature/"):
|
||||
with tf.name_scope("summaries/"):
|
||||
with tf.name_scope(name_scope):
|
||||
tf.summary.histogram(name, tf.reshape(var, [-1]))
|
||||
|
||||
|
||||
def sparse_clip_by_value(sparse_tf, min_val, max_val):
|
||||
new_vals = tf.clip_by_value(sparse_tf.values, min_val, max_val)
|
||||
return tf.SparseTensor(sparse_tf.indices, new_vals, sparse_tf.dense_shape)
|
||||
|
||||
|
||||
def check_numerics_with_msg(tensor, message="", sparse_tensor=False):
|
||||
if sparse_tensor:
|
||||
values = tf.debugging.check_numerics(tensor.values, message=message)
|
||||
return tf.SparseTensor(tensor.indices, values, tensor.dense_shape)
|
||||
else:
|
||||
return tf.debugging.check_numerics(tensor, message=message)
|
||||
|
||||
|
||||
def pad_empty_sparse_tensor(tensor):
|
||||
dummy_tensor = tf.SparseTensor(
|
||||
indices=[[0, 0]],
|
||||
values=[0.00001],
|
||||
dense_shape=tensor.dense_shape,
|
||||
)
|
||||
result = tf.cond(
|
||||
tf.equal(tf.size(tensor.values), 0),
|
||||
lambda: dummy_tensor,
|
||||
lambda: tensor,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def filter_nans_and_infs(tensor, sparse_tensor=False):
|
||||
if sparse_tensor:
|
||||
sparse_values = tensor.values
|
||||
filtered_val = tf.where(
|
||||
tf.logical_or(tf.is_nan(sparse_values), tf.is_inf(sparse_values)),
|
||||
tf.zeros_like(sparse_values),
|
||||
sparse_values,
|
||||
)
|
||||
return tf.SparseTensor(tensor.indices, filtered_val, tensor.dense_shape)
|
||||
else:
|
||||
return tf.where(
|
||||
tf.logical_or(tf.is_nan(tensor), tf.is_inf(tensor)), tf.zeros_like(tensor), tensor
|
||||
)
|
||||
|
||||
|
||||
def generate_disliked_mask(labels):
|
||||
"""Generate a disliked mask where only samples with dislike labels are set to 1 otherwise set to 0.
|
||||
Args:
|
||||
labels: labels of training samples, which is a 2D tensor of shape batch_size x 3: [OONCs, engagements, dislikes]
|
||||
Returns:
|
||||
1D tensor of shape batch_size x 1: [dislikes (booleans)]
|
||||
"""
|
||||
return tf.equal(tf.reshape(labels[:, 2], shape=[-1, 1]), 1)
|
BIN
pushservice/src/main/python/models/libs/warm_start_utils.docx
Normal file
BIN
pushservice/src/main/python/models/libs/warm_start_utils.docx
Normal file
Binary file not shown.
@ -1,309 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
from twitter.magicpony.common import file_access
|
||||
import twml
|
||||
|
||||
from .model_utils import read_config
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def get_model_type_to_tensors_to_change_axis():
|
||||
model_type_to_tensors_to_change_axis = {
|
||||
"magic_recs/model/batch_normalization/beta": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/gamma": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/moving_mean": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/moving_stddev": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/moving_variance": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/renorm_mean": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/renorm_stddev": ([0], "continuous"),
|
||||
"magic_recs/model/logits/EngagementGivenOONC_logits/clem_net_1/block2_4/channel_wise_dense_4/kernel": (
|
||||
[1],
|
||||
"all",
|
||||
),
|
||||
"magic_recs/model/logits/OONC_logits/clem_net/block2/channel_wise_dense/kernel": ([1], "all"),
|
||||
}
|
||||
|
||||
return model_type_to_tensors_to_change_axis
|
||||
|
||||
|
||||
def mkdirp(dirname):
|
||||
if not tf.io.gfile.exists(dirname):
|
||||
tf.io.gfile.makedirs(dirname)
|
||||
|
||||
|
||||
def rename_dir(dirname, dst):
|
||||
file_access.hdfs.mv(dirname, dst)
|
||||
|
||||
|
||||
def rmdir(dirname):
|
||||
if tf.io.gfile.exists(dirname):
|
||||
if tf.io.gfile.isdir(dirname):
|
||||
tf.io.gfile.rmtree(dirname)
|
||||
else:
|
||||
tf.io.gfile.remove(dirname)
|
||||
|
||||
|
||||
def get_var_dict(checkpoint_path):
|
||||
checkpoint = tf.train.get_checkpoint_state(checkpoint_path)
|
||||
var_dict = OrderedDict()
|
||||
with tf.Session() as sess:
|
||||
all_var_list = tf.train.list_variables(checkpoint_path)
|
||||
for var_name, _ in all_var_list:
|
||||
# Load the variable
|
||||
var = tf.train.load_variable(checkpoint_path, var_name)
|
||||
var_dict[var_name] = var
|
||||
return var_dict
|
||||
|
||||
|
||||
def get_continunous_mapping_from_feat_list(old_feature_list, new_feature_list):
|
||||
"""
|
||||
get var_ind for old_feature and corresponding var_ind for new_feature
|
||||
"""
|
||||
new_var_ind, old_var_ind = [], []
|
||||
for this_new_id, this_new_name in enumerate(new_feature_list):
|
||||
if this_new_name in old_feature_list:
|
||||
this_old_id = old_feature_list.index(this_new_name)
|
||||
new_var_ind.append(this_new_id)
|
||||
old_var_ind.append(this_old_id)
|
||||
return np.asarray(old_var_ind), np.asarray(new_var_ind)
|
||||
|
||||
|
||||
def get_continuous_mapping_from_feat_dict(old_feature_dict, new_feature_dict):
|
||||
"""
|
||||
get var_ind for old_feature and corresponding var_ind for new_feature
|
||||
"""
|
||||
old_cont = old_feature_dict["continuous"]
|
||||
old_bin = old_feature_dict["binary"]
|
||||
|
||||
new_cont = new_feature_dict["continuous"]
|
||||
new_bin = new_feature_dict["binary"]
|
||||
|
||||
_dummy_sparse_feat = [f"sparse_feature_{_idx}" for _idx in range(100)]
|
||||
|
||||
cont_old_var_ind, cont_new_var_ind = get_continunous_mapping_from_feat_list(old_cont, new_cont)
|
||||
|
||||
all_old_var_ind, all_new_var_ind = get_continunous_mapping_from_feat_list(
|
||||
old_cont + old_bin + _dummy_sparse_feat, new_cont + new_bin + _dummy_sparse_feat
|
||||
)
|
||||
|
||||
_res = {
|
||||
"continuous": (cont_old_var_ind, cont_new_var_ind),
|
||||
"all": (all_old_var_ind, all_new_var_ind),
|
||||
}
|
||||
|
||||
return _res
|
||||
|
||||
|
||||
def warm_start_from_var_dict(
|
||||
old_ckpt_path,
|
||||
var_ind_dict,
|
||||
output_dir,
|
||||
new_len_var,
|
||||
var_to_change_dict_fn=get_model_type_to_tensors_to_change_axis,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
old_ckpt_path (str): path to the old checkpoint path
|
||||
new_var_ind (array of int): index to overlapping features in new var between old and new feature list.
|
||||
old_var_ind (array of int): index to overlapping features in old var between old and new feature list.
|
||||
|
||||
output_dir (str): dir that used to write modified checkpoint
|
||||
new_len_var ({str:int}): number of feature in the new feature list.
|
||||
var_to_change_dict_fn (dict): A function to get the dictionary of format {var_name: dim_to_change}
|
||||
"""
|
||||
old_var_dict = get_var_dict(old_ckpt_path)
|
||||
|
||||
ckpt_file_name = os.path.basename(old_ckpt_path)
|
||||
mkdirp(output_dir)
|
||||
output_path = join(output_dir, ckpt_file_name)
|
||||
|
||||
tensors_to_change = var_to_change_dict_fn()
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
with tf.Session() as sess:
|
||||
var_name_shape_list = tf.train.list_variables(old_ckpt_path)
|
||||
count = 0
|
||||
|
||||
for var_name, var_shape in var_name_shape_list:
|
||||
old_var = old_var_dict[var_name]
|
||||
if var_name in tensors_to_change.keys():
|
||||
_info_tuple = tensors_to_change[var_name]
|
||||
dims_to_remove_from, var_type = _info_tuple
|
||||
|
||||
new_var_ind, old_var_ind = var_ind_dict[var_type]
|
||||
|
||||
this_shape = list(old_var.shape)
|
||||
for this_dim in dims_to_remove_from:
|
||||
this_shape[this_dim] = new_len_var[var_type]
|
||||
|
||||
stddev = np.std(old_var)
|
||||
truncated_norm_generator = stats.truncnorm(-0.5, 0.5, loc=0, scale=stddev)
|
||||
size = np.prod(this_shape)
|
||||
new_var = truncated_norm_generator.rvs(size).reshape(this_shape)
|
||||
new_var = new_var.astype(old_var.dtype)
|
||||
|
||||
new_var = copy_feat_based_on_mapping(
|
||||
new_var, old_var, dims_to_remove_from, new_var_ind, old_var_ind
|
||||
)
|
||||
count = count + 1
|
||||
else:
|
||||
new_var = old_var
|
||||
var = tf.Variable(new_var, name=var_name)
|
||||
assert count == len(tensors_to_change.keys()), "not all variables are exchanged.\n"
|
||||
saver = tf.train.Saver()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
saver.save(sess, output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
def copy_feat_based_on_mapping(new_array, old_array, dims_to_remove_from, new_var_ind, old_var_ind):
|
||||
if dims_to_remove_from == [0, 1]:
|
||||
for this_new_ind, this_old_ind in zip(new_var_ind, old_var_ind):
|
||||
new_array[this_new_ind, new_var_ind] = old_array[this_old_ind, old_var_ind]
|
||||
elif dims_to_remove_from == [0]:
|
||||
new_array[new_var_ind] = old_array[old_var_ind]
|
||||
elif dims_to_remove_from == [1]:
|
||||
new_array[:, new_var_ind] = old_array[:, old_var_ind]
|
||||
else:
|
||||
raise RuntimeError(f"undefined dims_to_remove_from pattern: ({dims_to_remove_from})")
|
||||
return new_array
|
||||
|
||||
|
||||
def read_file(filename, decode=False):
|
||||
"""
|
||||
Reads contents from a file and optionally decodes it.
|
||||
|
||||
Arguments:
|
||||
filename:
|
||||
path to file where the contents will be loaded from.
|
||||
Accepts HDFS and local paths.
|
||||
decode:
|
||||
False or 'json'. When decode='json', contents is decoded
|
||||
with json.loads. When False, contents is returned as is.
|
||||
"""
|
||||
graph = tf.Graph()
|
||||
with graph.as_default():
|
||||
read = tf.read_file(filename)
|
||||
|
||||
with tf.Session(graph=graph) as sess:
|
||||
contents = sess.run(read)
|
||||
if not isinstance(contents, str):
|
||||
contents = contents.decode()
|
||||
|
||||
if decode == "json":
|
||||
contents = json.loads(contents)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def read_feat_list_from_disk(file_path):
|
||||
return read_file(file_path, decode="json")
|
||||
|
||||
|
||||
def get_feature_list_for_light_ranking(feature_list_path, data_spec_path):
|
||||
feature_list = read_config(feature_list_path).items()
|
||||
string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
||||
|
||||
feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
||||
data_spec_path=data_spec_path
|
||||
)
|
||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
||||
feature_regexes=string_feat_list,
|
||||
group_name="continuous",
|
||||
default_value=-1,
|
||||
type_filter=["CONTINUOUS"],
|
||||
)
|
||||
feature_config = feature_config_builder.build()
|
||||
feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[
|
||||
"CONTINUOUS"
|
||||
]
|
||||
return feature_list
|
||||
|
||||
|
||||
def get_feature_list_for_heavy_ranking(feature_list_path, data_spec_path):
|
||||
feature_list = read_config(feature_list_path).items()
|
||||
string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
||||
|
||||
feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
||||
data_spec_path=data_spec_path
|
||||
)
|
||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
||||
feature_regexes=string_feat_list,
|
||||
group_name="continuous",
|
||||
default_value=-1,
|
||||
type_filter=["CONTINUOUS"],
|
||||
)
|
||||
|
||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
||||
feature_regexes=string_feat_list,
|
||||
group_name="binary",
|
||||
default_value=False,
|
||||
type_filter=["BINARY"],
|
||||
)
|
||||
|
||||
feature_config_builder = feature_config_builder.build()
|
||||
|
||||
continuous_feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[
|
||||
"CONTINUOUS"
|
||||
]
|
||||
|
||||
binary_feature_list = feature_config_builder._feature_group_extraction_configs[1].feature_map[
|
||||
"BINARY"
|
||||
]
|
||||
return {"continuous": continuous_feature_list, "binary": binary_feature_list}
|
||||
|
||||
|
||||
def warm_start_checkpoint(
|
||||
old_best_ckpt_folder,
|
||||
old_feature_list_path,
|
||||
feature_allow_list_path,
|
||||
data_spec_path,
|
||||
output_ckpt_folder,
|
||||
*args,
|
||||
):
|
||||
"""
|
||||
Reads old checkpoint and the old feature list, and create a new ckpt warm started from old ckpt using new features .
|
||||
|
||||
Arguments:
|
||||
old_best_ckpt_folder:
|
||||
path to the best_checkpoint_folder for old model
|
||||
old_feature_list_path:
|
||||
path to the json file that stores the list of continuous features used in old models.
|
||||
feature_allow_list_path:
|
||||
yaml file that contain the feature allow list.
|
||||
data_spec_path:
|
||||
path to the data_spec file
|
||||
output_ckpt_folder:
|
||||
folder that contains the modified ckpt.
|
||||
|
||||
Returns:
|
||||
path to the modified ckpt."""
|
||||
old_ckpt_path = tf.train.latest_checkpoint(old_best_ckpt_folder, latest_filename=None)
|
||||
|
||||
new_feature_dict = get_feature_list(feature_allow_list_path, data_spec_path)
|
||||
old_feature_dict = read_feat_list_from_disk(old_feature_list_path)
|
||||
|
||||
var_ind_dict = get_continuous_mapping_from_feat_dict(new_feature_dict, old_feature_dict)
|
||||
|
||||
new_len_var = {
|
||||
"continuous": len(new_feature_dict["continuous"]),
|
||||
"all": len(new_feature_dict["continuous"] + new_feature_dict["binary"]) + 100,
|
||||
}
|
||||
|
||||
warm_started_ckpt_path = warm_start_from_var_dict(
|
||||
old_ckpt_path,
|
||||
var_ind_dict,
|
||||
output_dir=output_ckpt_folder,
|
||||
new_len_var=new_len_var,
|
||||
)
|
||||
|
||||
return warm_started_ckpt_path
|
@ -1,69 +0,0 @@
|
||||
#":mlwf_libs",
|
||||
|
||||
python37_binary(
|
||||
name = "eval_model",
|
||||
source = "eval_model.py",
|
||||
dependencies = [
|
||||
":libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:eval_model",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "train_model",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:train_model",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "train_model_local",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:train_model_local",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "eval_model_local",
|
||||
source = "eval_model.py",
|
||||
dependencies = [
|
||||
":libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:eval_model_local",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "mlwf_model",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":mlwf_libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:mlwf_model",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "libs",
|
||||
sources = ["**/*.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"src/python/twitter/deepbird/util/data",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "mlwf_libs",
|
||||
sources = ["**/*.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"twml",
|
||||
],
|
||||
)
|
BIN
pushservice/src/main/python/models/light_ranking/BUILD.docx
Normal file
BIN
pushservice/src/main/python/models/light_ranking/BUILD.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/light_ranking/README.docx
Normal file
BIN
pushservice/src/main/python/models/light_ranking/README.docx
Normal file
Binary file not shown.
@ -1,14 +0,0 @@
|
||||
# Notification Light Ranker Model
|
||||
|
||||
## Model Context
|
||||
There are 4 major components of Twitter notifications recommendation system: 1) candidate generation 2) light ranking 3) heavy ranking & 4) quality control. This notification light ranker model bridges candidate generation and heavy ranking by pre-selecting highly-relevant candidates from the initial huge candidate pool. It’s a light-weight model to reduce system cost during heavy ranking without hurting user experience.
|
||||
|
||||
## Directory Structure
|
||||
- BUILD: this file defines python library dependencies
|
||||
- model_pools_mlp.py: this file defines tensorflow model architecture for the notification light ranker model
|
||||
- deep_norm.py: this file contains 1) how to build the tensorflow graph with specified model architecture, loss function and training configuration. 2) how to set up the overall model training & evaluation pipeline
|
||||
- eval_model.py: the main python entry file to set up the overall model evaluation pipeline
|
||||
|
||||
|
||||
|
||||
|
BIN
pushservice/src/main/python/models/light_ranking/__init__.docx
Normal file
BIN
pushservice/src/main/python/models/light_ranking/__init__.docx
Normal file
Binary file not shown.
BIN
pushservice/src/main/python/models/light_ranking/deep_norm.docx
Normal file
BIN
pushservice/src/main/python/models/light_ranking/deep_norm.docx
Normal file
Binary file not shown.
@ -1,226 +0,0 @@
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
import os
|
||||
|
||||
from twitter.cortex.ml.embeddings.common.helpers import decode_str_or_unicode
|
||||
import twml
|
||||
from twml.trainers import DataRecordTrainer
|
||||
|
||||
from ..libs.get_feat_config import get_feature_config_light_ranking, LABELS_LR
|
||||
from ..libs.graph_utils import get_trainable_variables
|
||||
from ..libs.group_metrics import (
|
||||
run_group_metrics_light_ranking,
|
||||
run_group_metrics_light_ranking_in_bq,
|
||||
)
|
||||
from ..libs.metric_fn_utils import get_metric_fn
|
||||
from ..libs.model_args import get_arg_parser_light_ranking
|
||||
from ..libs.model_utils import read_config
|
||||
from ..libs.warm_start_utils import get_feature_list_for_light_ranking
|
||||
from .model_pools_mlp import light_ranking_mlp_ngbdt
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def build_graph(
|
||||
features, label, mode, params, config=None, run_light_ranking_group_metrics_in_bq=False
|
||||
):
|
||||
is_training = mode == tf.estimator.ModeKeys.TRAIN
|
||||
this_model_func = light_ranking_mlp_ngbdt
|
||||
model_output = this_model_func(features, is_training, params, label)
|
||||
|
||||
logits = model_output["output"]
|
||||
graph_output = {}
|
||||
# --------------------------------------------------------
|
||||
# define graph output dict
|
||||
# --------------------------------------------------------
|
||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
||||
loss = None
|
||||
output_label = "prediction"
|
||||
if params.task_name in LABELS_LR:
|
||||
output = tf.nn.sigmoid(logits)
|
||||
output = tf.clip_by_value(output, 0, 1)
|
||||
|
||||
if run_light_ranking_group_metrics_in_bq:
|
||||
graph_output["trace_id"] = features["meta.trace_id"]
|
||||
graph_output["target"] = features["meta.ranking.weighted_oonc_model_score"]
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
|
||||
else:
|
||||
output_label = "output"
|
||||
weights = tf.cast(features["weights"], dtype=tf.float32, name="RecordWeights")
|
||||
|
||||
if params.task_name in LABELS_LR:
|
||||
if params.use_record_weight:
|
||||
weights = tf.clip_by_value(
|
||||
1.0 / (1.0 + weights + params.smooth_weight), params.min_record_weight, 1.0
|
||||
)
|
||||
|
||||
loss = tf.reduce_sum(
|
||||
tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits) * weights
|
||||
) / (tf.reduce_sum(weights))
|
||||
else:
|
||||
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits))
|
||||
output = tf.nn.sigmoid(logits)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
|
||||
train_op = None
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
# --------------------------------------------------------
|
||||
# get train_op
|
||||
# --------------------------------------------------------
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=params.learning_rate)
|
||||
update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
variables = get_trainable_variables(
|
||||
all_trainable_variables=tf.trainable_variables(), trainable_regexes=params.trainable_regexes
|
||||
)
|
||||
with tf.control_dependencies(update_ops):
|
||||
train_op = twml.optimizers.optimize_loss(
|
||||
loss=loss,
|
||||
variables=variables,
|
||||
global_step=tf.train.get_global_step(),
|
||||
optimizer=optimizer,
|
||||
learning_rate=params.learning_rate,
|
||||
learning_rate_decay_fn=twml.learning_rate_decay.get_learning_rate_decay_fn(params),
|
||||
)
|
||||
|
||||
graph_output[output_label] = output
|
||||
graph_output["loss"] = loss
|
||||
graph_output["train_op"] = train_op
|
||||
return graph_output
|
||||
|
||||
|
||||
def get_params(args=None):
|
||||
parser = get_arg_parser_light_ranking()
|
||||
if args is None:
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def _main():
|
||||
opt = get_params()
|
||||
logging.info("parse is: ")
|
||||
logging.info(opt)
|
||||
|
||||
feature_list = read_config(opt.feature_list).items()
|
||||
feature_config = get_feature_config_light_ranking(
|
||||
data_spec_path=opt.data_spec,
|
||||
feature_list_provided=feature_list,
|
||||
opt=opt,
|
||||
add_gbdt=opt.use_gbdt_features,
|
||||
run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
|
||||
)
|
||||
feature_list_path = opt.feature_list
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Create Trainer
|
||||
# --------------------------------------------------------
|
||||
trainer = DataRecordTrainer(
|
||||
name=opt.model_trainer_name,
|
||||
params=opt,
|
||||
build_graph_fn=build_graph,
|
||||
save_dir=opt.save_dir,
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
||||
)
|
||||
if opt.directly_export_best:
|
||||
logging.info("Directly exporting the model without training")
|
||||
else:
|
||||
# ----------------------------------------------------
|
||||
# Model Training & Evaluation
|
||||
# ----------------------------------------------------
|
||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
||||
train_input_fn = trainer.get_train_input_fn(shuffle=True)
|
||||
|
||||
if opt.distributed or opt.num_workers is not None:
|
||||
learn = trainer.train_and_evaluate
|
||||
else:
|
||||
learn = trainer.learn
|
||||
logging.info("Training...")
|
||||
start = datetime.now()
|
||||
|
||||
early_stop_metric = "rce_unweighted_" + opt.task_name
|
||||
learn(
|
||||
early_stop_minimize=False,
|
||||
early_stop_metric=early_stop_metric,
|
||||
early_stop_patience=opt.early_stop_patience,
|
||||
early_stop_tolerance=opt.early_stop_tolerance,
|
||||
eval_input_fn=eval_input_fn,
|
||||
train_input_fn=train_input_fn,
|
||||
)
|
||||
|
||||
end = datetime.now()
|
||||
logging.info("Training time: " + str(end - start))
|
||||
|
||||
logging.info("Exporting the models...")
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Do the model exporting
|
||||
# --------------------------------------------------------
|
||||
start = datetime.now()
|
||||
if not opt.export_dir:
|
||||
opt.export_dir = os.path.join(opt.save_dir, "exported_models")
|
||||
|
||||
raw_model_path = twml.contrib.export.export_fn.export_all_models(
|
||||
trainer=trainer,
|
||||
export_dir=opt.export_dir,
|
||||
parse_fn=feature_config.get_parse_fn(),
|
||||
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
|
||||
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
|
||||
)
|
||||
export_model_dir = decode_str_or_unicode(raw_model_path)
|
||||
|
||||
logging.info("Model export time: " + str(datetime.now() - start))
|
||||
logging.info("The saved model directory is: " + opt.save_dir)
|
||||
|
||||
tf.logging.info("getting default continuous_feature_list")
|
||||
continuous_feature_list = get_feature_list_for_light_ranking(feature_list_path, opt.data_spec)
|
||||
continous_feature_list_save_path = os.path.join(opt.save_dir, "continuous_feature_list.json")
|
||||
twml.util.write_file(continous_feature_list_save_path, continuous_feature_list, encode="json")
|
||||
tf.logging.info(f"Finish writting files to {continous_feature_list_save_path}")
|
||||
|
||||
if opt.run_light_ranking_group_metrics:
|
||||
# --------------------------------------------
|
||||
# Run Light Ranking Group Metrics
|
||||
# --------------------------------------------
|
||||
run_group_metrics_light_ranking(
|
||||
trainer=trainer,
|
||||
data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
|
||||
model_path=export_model_dir,
|
||||
parse_fn=feature_config.get_parse_fn(),
|
||||
)
|
||||
|
||||
if opt.run_light_ranking_group_metrics_in_bq:
|
||||
# ----------------------------------------------------------------------------------------
|
||||
# Get Light/Heavy Ranker Predictions for Light Ranking Group Metrics in BigQuery
|
||||
# ----------------------------------------------------------------------------------------
|
||||
trainer_pred = DataRecordTrainer(
|
||||
name=opt.model_trainer_name,
|
||||
params=opt,
|
||||
build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
|
||||
save_dir=opt.save_dir + "/tmp/",
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
||||
)
|
||||
checkpoint_folder = os.path.join(opt.save_dir, "best_checkpoint")
|
||||
checkpoint = tf.train.latest_checkpoint(checkpoint_folder, latest_filename=None)
|
||||
tf.logging.info("\n\nPrediction from Checkpoint: {:}.\n\n".format(checkpoint))
|
||||
run_group_metrics_light_ranking_in_bq(
|
||||
trainer=trainer_pred, params=opt, checkpoint_path=checkpoint
|
||||
)
|
||||
|
||||
tf.logging.info("Done Training & Prediction.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
BIN
pushservice/src/main/python/models/light_ranking/eval_model.docx
Normal file
BIN
pushservice/src/main/python/models/light_ranking/eval_model.docx
Normal file
Binary file not shown.
@ -1,89 +0,0 @@
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
import os
|
||||
|
||||
from ..libs.group_metrics import (
|
||||
run_group_metrics_light_ranking,
|
||||
run_group_metrics_light_ranking_in_bq,
|
||||
)
|
||||
from ..libs.metric_fn_utils import get_metric_fn
|
||||
from ..libs.model_args import get_arg_parser_light_ranking
|
||||
from ..libs.model_utils import read_config
|
||||
from .deep_norm import build_graph, DataRecordTrainer, get_config_func, logging
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_arg_parser_light_ranking()
|
||||
parser.add_argument(
|
||||
"--eval_checkpoint",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Which checkpoint to use for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--saved_model_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to saved model for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_binary_metrics",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to compute the basic binary metrics for Light Ranking.",
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
logging.info("parse is: ")
|
||||
logging.info(opt)
|
||||
|
||||
feature_list = read_config(opt.feature_list).items()
|
||||
feature_config = get_config_func(opt.feat_config_type)(
|
||||
data_spec_path=opt.data_spec,
|
||||
feature_list_provided=feature_list,
|
||||
opt=opt,
|
||||
add_gbdt=opt.use_gbdt_features,
|
||||
run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
|
||||
)
|
||||
|
||||
# -----------------------------------------------
|
||||
# Create Trainer
|
||||
# -----------------------------------------------
|
||||
trainer = DataRecordTrainer(
|
||||
name=opt.model_trainer_name,
|
||||
params=opt,
|
||||
build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
|
||||
save_dir=opt.save_dir,
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
||||
)
|
||||
|
||||
# -----------------------------------------------
|
||||
# Model Evaluation
|
||||
# -----------------------------------------------
|
||||
logging.info("Evaluating...")
|
||||
start = datetime.now()
|
||||
|
||||
if opt.run_binary_metrics:
|
||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
||||
eval_steps = None if (opt.eval_steps is not None and opt.eval_steps < 0) else opt.eval_steps
|
||||
trainer.estimator.evaluate(eval_input_fn, steps=eval_steps, checkpoint_path=opt.eval_checkpoint)
|
||||
|
||||
if opt.run_light_ranking_group_metrics_in_bq:
|
||||
run_group_metrics_light_ranking_in_bq(
|
||||
trainer=trainer, params=opt, checkpoint_path=opt.eval_checkpoint
|
||||
)
|
||||
|
||||
if opt.run_light_ranking_group_metrics:
|
||||
run_group_metrics_light_ranking(
|
||||
trainer=trainer,
|
||||
data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
|
||||
model_path=opt.saved_model_path,
|
||||
parse_fn=feature_config.get_parse_fn(),
|
||||
)
|
||||
|
||||
end = datetime.now()
|
||||
logging.info("Evaluating time: " + str(end - start))
|
Binary file not shown.
@ -1,187 +0,0 @@
|
||||
import warnings
|
||||
|
||||
from twml.contrib.layers import ZscoreNormalization
|
||||
|
||||
from ...libs.customized_full_sparse import FullSparse
|
||||
from ...libs.get_feat_config import FEAT_CONFIG_DEFAULT_VAL as MISSING_VALUE_MARKER
|
||||
from ...libs.model_utils import (
|
||||
_sparse_feature_fixup,
|
||||
adaptive_transformation,
|
||||
filter_nans_and_infs,
|
||||
get_dense_out,
|
||||
tensor_dropout,
|
||||
)
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
# checkstyle: noqa
|
||||
|
||||
def light_ranking_mlp_ngbdt(features, is_training, params, label=None):
|
||||
return deepnorm_light_ranking(
|
||||
features,
|
||||
is_training,
|
||||
params,
|
||||
label=label,
|
||||
decay=params.momentum,
|
||||
dense_emb_size=params.dense_embedding_size,
|
||||
base_activation=tf.keras.layers.LeakyReLU(),
|
||||
input_dropout_rate=params.dropout,
|
||||
use_gbdt=False,
|
||||
)
|
||||
|
||||
|
||||
def deepnorm_light_ranking(
|
||||
features,
|
||||
is_training,
|
||||
params,
|
||||
label=None,
|
||||
decay=0.99999,
|
||||
dense_emb_size=128,
|
||||
base_activation=None,
|
||||
input_dropout_rate=None,
|
||||
input_dense_type="self_atten_dense",
|
||||
emb_dense_type="self_atten_dense",
|
||||
mlp_dense_type="self_atten_dense",
|
||||
use_gbdt=False,
|
||||
):
|
||||
# --------------------------------------------------------
|
||||
# Initial Parameter Checking
|
||||
# --------------------------------------------------------
|
||||
if base_activation is None:
|
||||
base_activation = tf.keras.layers.LeakyReLU()
|
||||
|
||||
if label is not None:
|
||||
warnings.warn(
|
||||
"Label is unused in deepnorm_gbdt. Stop using this argument.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
with tf.variable_scope("helper_layers"):
|
||||
full_sparse_layer = FullSparse(
|
||||
output_size=params.sparse_embedding_size,
|
||||
activation=base_activation,
|
||||
use_sparse_grads=is_training,
|
||||
use_binary_values=False,
|
||||
dtype=tf.float32,
|
||||
)
|
||||
input_normalizing_layer = ZscoreNormalization(decay=decay, name="input_normalizing_layer")
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Feature Selection & Embedding
|
||||
# --------------------------------------------------------
|
||||
if use_gbdt:
|
||||
sparse_gbdt_features = _sparse_feature_fixup(features["gbdt_sparse"], params.input_size_bits)
|
||||
if input_dropout_rate is not None:
|
||||
sparse_gbdt_features = tensor_dropout(
|
||||
sparse_gbdt_features, input_dropout_rate, is_training, sparse_tensor=True
|
||||
)
|
||||
|
||||
total_embed = full_sparse_layer(sparse_gbdt_features, use_binary_values=True)
|
||||
|
||||
if (input_dropout_rate is not None) and is_training:
|
||||
total_embed = total_embed / (1 - input_dropout_rate)
|
||||
|
||||
else:
|
||||
with tf.variable_scope("dense_branch"):
|
||||
dense_continuous_features = filter_nans_and_infs(features["continuous"])
|
||||
|
||||
if params.use_missing_sub_branch:
|
||||
is_missing = tf.equal(dense_continuous_features, MISSING_VALUE_MARKER)
|
||||
continuous_features_filled = tf.where(
|
||||
is_missing,
|
||||
tf.zeros_like(dense_continuous_features),
|
||||
dense_continuous_features,
|
||||
)
|
||||
normalized_features = input_normalizing_layer(
|
||||
continuous_features_filled, is_training, tf.math.logical_not(is_missing)
|
||||
)
|
||||
|
||||
with tf.variable_scope("missing_sub_branch"):
|
||||
missing_feature_embed = get_dense_out(
|
||||
tf.cast(is_missing, tf.float32),
|
||||
dense_emb_size,
|
||||
activation=base_activation,
|
||||
dense_type=input_dense_type,
|
||||
)
|
||||
|
||||
else:
|
||||
continuous_features_filled = dense_continuous_features
|
||||
normalized_features = input_normalizing_layer(continuous_features_filled, is_training)
|
||||
|
||||
with tf.variable_scope("continuous_sub_branch"):
|
||||
normalized_features = adaptive_transformation(
|
||||
normalized_features, is_training, func_type="tiny"
|
||||
)
|
||||
|
||||
if input_dropout_rate is not None:
|
||||
normalized_features = tensor_dropout(
|
||||
normalized_features,
|
||||
input_dropout_rate,
|
||||
is_training,
|
||||
sparse_tensor=False,
|
||||
)
|
||||
filled_feature_embed = get_dense_out(
|
||||
normalized_features,
|
||||
dense_emb_size,
|
||||
activation=base_activation,
|
||||
dense_type=input_dense_type,
|
||||
)
|
||||
|
||||
if params.use_missing_sub_branch:
|
||||
dense_embed = tf.concat(
|
||||
[filled_feature_embed, missing_feature_embed], axis=1, name="merge_dense_emb"
|
||||
)
|
||||
else:
|
||||
dense_embed = filled_feature_embed
|
||||
|
||||
with tf.variable_scope("sparse_branch"):
|
||||
sparse_discrete_features = _sparse_feature_fixup(
|
||||
features["sparse_no_continuous"], params.input_size_bits
|
||||
)
|
||||
if input_dropout_rate is not None:
|
||||
sparse_discrete_features = tensor_dropout(
|
||||
sparse_discrete_features, input_dropout_rate, is_training, sparse_tensor=True
|
||||
)
|
||||
|
||||
discrete_features_embed = full_sparse_layer(sparse_discrete_features, use_binary_values=True)
|
||||
|
||||
if (input_dropout_rate is not None) and is_training:
|
||||
discrete_features_embed = discrete_features_embed / (1 - input_dropout_rate)
|
||||
|
||||
total_embed = tf.concat(
|
||||
[dense_embed, discrete_features_embed],
|
||||
axis=1,
|
||||
name="total_embed",
|
||||
)
|
||||
|
||||
total_embed = tf.layers.batch_normalization(
|
||||
total_embed,
|
||||
training=is_training,
|
||||
renorm_momentum=decay,
|
||||
momentum=decay,
|
||||
renorm=is_training,
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# MLP Layers
|
||||
# --------------------------------------------------------
|
||||
with tf.variable_scope("MLP_branch"):
|
||||
|
||||
assert params.num_mlp_layers >= 0
|
||||
embed_list = [total_embed] + [None for _ in range(params.num_mlp_layers)]
|
||||
dense_types = [emb_dense_type] + [mlp_dense_type for _ in range(params.num_mlp_layers - 1)]
|
||||
|
||||
for xl in range(1, params.num_mlp_layers + 1):
|
||||
neurons = params.mlp_neuron_scale ** (params.num_mlp_layers + 1 - xl)
|
||||
embed_list[xl] = get_dense_out(
|
||||
embed_list[xl - 1], neurons, activation=base_activation, dense_type=dense_types[xl - 1]
|
||||
)
|
||||
|
||||
if params.task_name in ["Sent", "HeavyRankPosition", "HeavyRankProbability"]:
|
||||
logits = get_dense_out(embed_list[-1], 1, activation=None, dense_type=mlp_dense_type)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
|
||||
output_dict = {"output": logits}
|
||||
return output_dict
|
@ -1,337 +0,0 @@
|
||||
scala_library(
|
||||
sources = ["**/*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/com/twitter/bijection:scrooge",
|
||||
"3rdparty/jvm/com/twitter/storehaus:core",
|
||||
"abdecider",
|
||||
"abuse/detection/src/main/thrift/com/twitter/abuse/detection/scoring:thrift-scala",
|
||||
"ann/src/main/scala/com/twitter/ann/common",
|
||||
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
|
||||
"audience-rewards/thrift/src/main/thrift:thrift-scala",
|
||||
"communities/thrift/src/main/thrift/com/twitter/communities:thrift-scala",
|
||||
"configapi/configapi-core",
|
||||
"configapi/configapi-decider",
|
||||
"content-mixer/thrift/src/main/thrift:thrift-scala",
|
||||
"content-recommender/thrift/src/main/thrift:thrift-scala",
|
||||
"copyselectionservice/server/src/main/scala/com/twitter/copyselectionservice/algorithms",
|
||||
"copyselectionservice/thrift/src/main/thrift:copyselectionservice-scala",
|
||||
"cortex-deepbird/thrift/src/main/thrift:thrift-java",
|
||||
"cr-mixer/thrift/src/main/thrift:thrift-scala",
|
||||
"cuad/projects/hashspace/thrift:thrift-scala",
|
||||
"cuad/projects/tagspace/thrift/src/main/thrift:thrift-scala",
|
||||
"detopic/thrift/src/main/thrift:thrift-scala",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/configapi",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/ddg",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/environment",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/fatigue",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/nackwarmupfilter",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/server",
|
||||
"discovery-ds/src/main/thrift/com/twitter/dds/scio/searcher_aggregate_history_srp:searcher_aggregate_history_srp-scala",
|
||||
"escherbird/src/scala/com/twitter/escherbird/util/metadatastitch",
|
||||
"escherbird/src/scala/com/twitter/escherbird/util/uttclient",
|
||||
"escherbird/src/thrift/com/twitter/escherbird/utt:strato-columns-scala",
|
||||
"eventbus/client",
|
||||
"eventdetection/event_context/src/main/scala/com/twitter/eventdetection/event_context/util",
|
||||
"events-recos/events-recos-service/src/main/thrift:events-recos-thrift-scala",
|
||||
"explore/explore-ranker/thrift/src/main/thrift:thrift-scala",
|
||||
"featureswitches/featureswitches-core/src/main/scala",
|
||||
"featureswitches/featureswitches-core/src/main/scala:dynmap",
|
||||
"featureswitches/featureswitches-core/src/main/scala:recipient",
|
||||
"featureswitches/featureswitches-core/src/main/scala:useragent",
|
||||
"featureswitches/featureswitches-core/src/main/scala/com/twitter/featureswitches/v2/builder",
|
||||
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/authentication",
|
||||
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/server",
|
||||
"finagle-internal/ostrich-stats",
|
||||
"finagle/finagle-core/src/main",
|
||||
"finagle/finagle-http/src/main/scala",
|
||||
"finagle/finagle-memcached/src/main/scala",
|
||||
"finagle/finagle-stats",
|
||||
"finagle/finagle-thriftmux",
|
||||
"finagle/finagle-tunable/src/main/scala",
|
||||
"finagle/finagle-zipkin-scribe",
|
||||
"finatra-internal/abdecider",
|
||||
"finatra-internal/decider",
|
||||
"finatra-internal/mtls-http/src/main/scala",
|
||||
"finatra-internal/mtls-thriftmux/src/main/scala",
|
||||
"finatra/http-client/src/main/scala",
|
||||
"finatra/http-core/src/main/java/com/twitter/finatra/http",
|
||||
"finatra/http-core/src/main/scala/com/twitter/finatra/http/response",
|
||||
"finatra/http-server/src/main/scala/com/twitter/finatra/http",
|
||||
"finatra/http-server/src/main/scala/com/twitter/finatra/http/filters",
|
||||
"finatra/inject/inject-app/src/main/java/com/twitter/inject/annotations",
|
||||
"finatra/inject/inject-app/src/main/scala",
|
||||
"finatra/inject/inject-core/src/main/scala",
|
||||
"finatra/inject/inject-server/src/main/scala",
|
||||
"finatra/inject/inject-slf4j/src/main/scala/com/twitter/inject",
|
||||
"finatra/inject/inject-thrift-client/src/main/scala",
|
||||
"finatra/inject/inject-utils/src/main/scala",
|
||||
"finatra/utils/src/main/java/com/twitter/finatra/annotations",
|
||||
"fleets/fleets-proxy/thrift/src/main/thrift:fleet-scala",
|
||||
"fleets/fleets-proxy/thrift/src/main/thrift/service:baseservice-scala",
|
||||
"flock-client/src/main/scala",
|
||||
"flock-client/src/main/thrift:thrift-scala",
|
||||
"follow-recommendations-service/thrift/src/main/thrift:thrift-scala",
|
||||
"frigate/frigate-common:base",
|
||||
"frigate/frigate-common:config",
|
||||
"frigate/frigate-common:debug",
|
||||
"frigate/frigate-common:entity_graph_client",
|
||||
"frigate/frigate-common:history",
|
||||
"frigate/frigate-common:logger",
|
||||
"frigate/frigate-common:ml-base",
|
||||
"frigate/frigate-common:ml-feature",
|
||||
"frigate/frigate-common:ml-prediction",
|
||||
"frigate/frigate-common:ntab",
|
||||
"frigate/frigate-common:predicate",
|
||||
"frigate/frigate-common:rec_types",
|
||||
"frigate/frigate-common:score_summary",
|
||||
"frigate/frigate-common:util",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/candidate",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/experiments",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/filter",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/modules/store:semantic_core_stores",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/deviceinfo",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/interests",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/strato",
|
||||
"frigate/push-mixer/thrift/src/main/thrift:thrift-scala",
|
||||
"geo/geo-prediction/src/main/thrift:local-viral-tweets-thrift-scala",
|
||||
"geoduck/service/src/main/scala/com/twitter/geoduck/service/common/clientmodules",
|
||||
"geoduck/util/country",
|
||||
"gizmoduck/client/src/main/scala/com/twitter/gizmoduck/testusers/client",
|
||||
"hermit/hermit-core:model-user_state",
|
||||
"hermit/hermit-core:predicate",
|
||||
"hermit/hermit-core:predicate-gizmoduck",
|
||||
"hermit/hermit-core:predicate-scarecrow",
|
||||
"hermit/hermit-core:predicate-socialgraph",
|
||||
"hermit/hermit-core:predicate-tweetypie",
|
||||
"hermit/hermit-core:store-labeled_push_recs",
|
||||
"hermit/hermit-core:store-metastore",
|
||||
"hermit/hermit-core:store-timezone",
|
||||
"hermit/hermit-core:store-tweetypie",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/constants",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/model",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/common",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/gizmoduck",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/scarecrow",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/semantic_core",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/user_htl_session_store",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/user_interest",
|
||||
"hmli/hss/src/main/thrift/com/twitter/hss:thrift-scala",
|
||||
"ibis2/service/src/main/scala/com/twitter/ibis2/lib",
|
||||
"ibis2/service/src/main/thrift/com/twitter/ibis2/service:ibis2-service-scala",
|
||||
"interests-service/thrift/src/main/thrift:thrift-scala",
|
||||
"interests_discovery/thrift/src/main/thrift:batch-thrift-scala",
|
||||
"interests_discovery/thrift/src/main/thrift:service-thrift-scala",
|
||||
"kujaku/thrift/src/main/thrift:domain-scala",
|
||||
"live-video-timeline/client/src/main/scala/com/twitter/livevideo/timeline/client/v2",
|
||||
"live-video-timeline/domain/src/main/scala/com/twitter/livevideo/timeline/domain",
|
||||
"live-video-timeline/domain/src/main/scala/com/twitter/livevideo/timeline/domain/v2",
|
||||
"live-video-timeline/thrift/src/main/thrift/com/twitter/livevideo/timeline:thrift-scala",
|
||||
"live-video/common/src/main/scala/com/twitter/livevideo/common/domain/v2",
|
||||
"live-video/common/src/main/scala/com/twitter/livevideo/common/ids",
|
||||
"notifications-platform/inbound-notifications/src/main/thrift/com/twitter/inbound_notifications:exception-scala",
|
||||
"notifications-platform/inbound-notifications/src/main/thrift/com/twitter/inbound_notifications:thrift-scala",
|
||||
"notifications-platform/platform-lib/src/main/thrift/com/twitter/notifications/platform:custom-notification-actions-scala",
|
||||
"notifications-platform/platform-lib/src/main/thrift/com/twitter/notifications/platform:thrift-scala",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/heavyranker",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/base",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/frigate",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/push",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/lightranker",
|
||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/genericfeedbackstore",
|
||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/model:alias",
|
||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/model/service",
|
||||
"notificationservice/common/src/test/scala/com/twitter/notificationservice/mocks",
|
||||
"notificationservice/scribe/src/main/scala/com/twitter/notificationservice/scribe/manhattan:mh_wrapper",
|
||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/api:thrift-scala",
|
||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/badgecount-api:thrift-scala",
|
||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/generic_notifications:thrift-scala",
|
||||
"notifinfra/ni-lib/src/main/scala/com/twitter/ni/lib/logged_out_transform",
|
||||
"observability/observability-manhattan-client/src/main/scala",
|
||||
"onboarding/service/src/main/scala/com/twitter/onboarding/task/service/models/external",
|
||||
"onboarding/service/thrift/src/main/thrift:thrift-scala",
|
||||
"people-discovery/api/thrift/src/main/thrift:thrift-scala",
|
||||
"periscope/api-proxy-thrift/thrift/src/main/thrift:thrift-scala",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/module",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/module/stringcenter",
|
||||
"product-mixer/core/src/main/thrift/com/twitter/product_mixer/core:thrift-scala",
|
||||
"qig-ranker/thrift/src/main/thrift:thrift-scala",
|
||||
"rux-ds/src/main/thrift/com/twitter/ruxds/jobs/user_past_aggregate:user_past_aggregate-scala",
|
||||
"rux/common/src/main/scala/com/twitter/rux/common/encode",
|
||||
"rux/common/thrift/src/main/thrift/rux-context:rux-context-scala",
|
||||
"rux/common/thrift/src/main/thrift/strato:strato-scala",
|
||||
"scribelib/marshallers/src/main/scala/com/twitter/scribelib/marshallers",
|
||||
"scrooge/scrooge-core",
|
||||
"scrooge/scrooge-serializer/src/main/scala",
|
||||
"sensitive-ds/src/main/thrift/com/twitter/scio/nsfw_user_segmentation:nsfw_user_segmentation-scala",
|
||||
"servo/decider/src/main/scala",
|
||||
"servo/request/src/main/scala",
|
||||
"servo/util/src/main/scala",
|
||||
"src/java/com/twitter/ml/api:api-base",
|
||||
"src/java/com/twitter/ml/prediction/core",
|
||||
"src/scala/com/twitter/frigate/data_pipeline/common",
|
||||
"src/scala/com/twitter/frigate/data_pipeline/embedding_cg:embedding_cg-test-user-ids",
|
||||
"src/scala/com/twitter/frigate/data_pipeline/features_common",
|
||||
"src/scala/com/twitter/frigate/news_article_recs/news_articles_metadata:thrift-scala",
|
||||
"src/scala/com/twitter/frontpage/stream/util",
|
||||
"src/scala/com/twitter/language/normalization",
|
||||
"src/scala/com/twitter/ml/api/embedding",
|
||||
"src/scala/com/twitter/ml/api/util:datarecord",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/entities/core",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/entities/magicrecs",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/core:aggregate",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/cuad:aggregate",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/embeddings",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/magicrecs:aggregate",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/topic_signals:aggregate",
|
||||
"src/scala/com/twitter/ml/featurestore/lib",
|
||||
"src/scala/com/twitter/ml/featurestore/lib/data",
|
||||
"src/scala/com/twitter/ml/featurestore/lib/dynamic",
|
||||
"src/scala/com/twitter/ml/featurestore/lib/entity",
|
||||
"src/scala/com/twitter/ml/featurestore/lib/online",
|
||||
"src/scala/com/twitter/recommendation/interests/discovery/core/config",
|
||||
"src/scala/com/twitter/recommendation/interests/discovery/core/deploy",
|
||||
"src/scala/com/twitter/recommendation/interests/discovery/core/model",
|
||||
"src/scala/com/twitter/recommendation/interests/discovery/popgeo/deploy",
|
||||
"src/scala/com/twitter/simclusters_v2/common",
|
||||
"src/scala/com/twitter/storehaus_internal/manhattan",
|
||||
"src/scala/com/twitter/storehaus_internal/manhattan/config",
|
||||
"src/scala/com/twitter/storehaus_internal/memcache",
|
||||
"src/scala/com/twitter/storehaus_internal/memcache/config",
|
||||
"src/scala/com/twitter/storehaus_internal/util",
|
||||
"src/scala/com/twitter/taxi/common",
|
||||
"src/scala/com/twitter/taxi/config",
|
||||
"src/scala/com/twitter/taxi/deploy",
|
||||
"src/scala/com/twitter/taxi/trending/common",
|
||||
"src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
|
||||
"src/thrift/com/twitter/clientapp/gen:clientapp-scala",
|
||||
"src/thrift/com/twitter/core_workflows/user_model:user_model-scala",
|
||||
"src/thrift/com/twitter/escherbird/common:constants-scala",
|
||||
"src/thrift/com/twitter/escherbird/metadata:megadata-scala",
|
||||
"src/thrift/com/twitter/escherbird/metadata:metadata-service-scala",
|
||||
"src/thrift/com/twitter/escherbird/search:search-service-scala",
|
||||
"src/thrift/com/twitter/expandodo:only-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-common-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-ml-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-notification-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-secondary-accounts-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-user-media-representation-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/data_pipeline:frigate-user-history-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/dau_model:frigate-dau-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/magic_events:frigate-magic-events-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/magic_events/scribe:thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/pushcap:frigate-pushcap-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/pushservice:frigate-pushservice-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/scribe:frigate-scribe-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/subscribed_search:frigate-subscribed-search-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/user_states:frigate-userstates-thrift-scala",
|
||||
"src/thrift/com/twitter/geoduck:geoduck-scala",
|
||||
"src/thrift/com/twitter/gizmoduck:thrift-scala",
|
||||
"src/thrift/com/twitter/gizmoduck:user-thrift-scala",
|
||||
"src/thrift/com/twitter/hermit:hermit-scala",
|
||||
"src/thrift/com/twitter/hermit/pop_geo:hermit-pop-geo-scala",
|
||||
"src/thrift/com/twitter/hermit/stp:hermit-stp-scala",
|
||||
"src/thrift/com/twitter/ibis:service-scala",
|
||||
"src/thrift/com/twitter/manhattan:v1-scala",
|
||||
"src/thrift/com/twitter/manhattan:v2-scala",
|
||||
"src/thrift/com/twitter/ml/api:data-java",
|
||||
"src/thrift/com/twitter/ml/api:data-scala",
|
||||
"src/thrift/com/twitter/ml/featurestore/timelines:ml-features-timelines-scala",
|
||||
"src/thrift/com/twitter/ml/featurestore/timelines:ml-features-timelines-strato",
|
||||
"src/thrift/com/twitter/ml/prediction_service:prediction_service-java",
|
||||
"src/thrift/com/twitter/permissions_storage:thrift-scala",
|
||||
"src/thrift/com/twitter/pink-floyd/thrift:thrift-scala",
|
||||
"src/thrift/com/twitter/recos:recos-common-scala",
|
||||
"src/thrift/com/twitter/recos/user_tweet_entity_graph:user_tweet_entity_graph-scala",
|
||||
"src/thrift/com/twitter/recos/user_user_graph:user_user_graph-scala",
|
||||
"src/thrift/com/twitter/relevance/feature_store:feature_store-scala",
|
||||
"src/thrift/com/twitter/search:earlybird-scala",
|
||||
"src/thrift/com/twitter/search/common:features-scala",
|
||||
"src/thrift/com/twitter/search/query_interaction_graph:query_interaction_graph-scala",
|
||||
"src/thrift/com/twitter/search/query_interaction_graph/service:qig-service-scala",
|
||||
"src/thrift/com/twitter/service/metastore/gen:thrift-scala",
|
||||
"src/thrift/com/twitter/service/scarecrow/gen:scarecrow-scala",
|
||||
"src/thrift/com/twitter/service/scarecrow/gen:tiered-actions-scala",
|
||||
"src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift-scala",
|
||||
"src/thrift/com/twitter/socialgraph:thrift-scala",
|
||||
"src/thrift/com/twitter/spam/rtf:safety-level-scala",
|
||||
"src/thrift/com/twitter/timelinemixer:thrift-scala",
|
||||
"src/thrift/com/twitter/timelinemixer/server/internal:thrift-scala",
|
||||
"src/thrift/com/twitter/timelines/author_features/user_health:thrift-scala",
|
||||
"src/thrift/com/twitter/timelines/real_graph:real_graph-scala",
|
||||
"src/thrift/com/twitter/timelinescorer:thrift-scala",
|
||||
"src/thrift/com/twitter/timelinescorer/server/internal:thrift-scala",
|
||||
"src/thrift/com/twitter/timelineservice/server/internal:thrift-scala",
|
||||
"src/thrift/com/twitter/timelineservice/server/suggests/logging:thrift-scala",
|
||||
"src/thrift/com/twitter/trends/common:common-scala",
|
||||
"src/thrift/com/twitter/trends/trip_v1:trip-tweets-thrift-scala",
|
||||
"src/thrift/com/twitter/tweetypie:service-scala",
|
||||
"src/thrift/com/twitter/tweetypie:tweet-scala",
|
||||
"src/thrift/com/twitter/user_session_store:thrift-scala",
|
||||
"src/thrift/com/twitter/wtf/candidate:wtf-candidate-scala",
|
||||
"src/thrift/com/twitter/wtf/interest:interest-thrift-scala",
|
||||
"src/thrift/com/twitter/wtf/scalding/common:thrift-scala",
|
||||
"stitch/stitch-core",
|
||||
"stitch/stitch-gizmoduck",
|
||||
"stitch/stitch-socialgraph/src/main/scala",
|
||||
"stitch/stitch-storehaus/src/main/scala",
|
||||
"stitch/stitch-tweetypie/src/main/scala",
|
||||
"storage/clients/manhattan/client/src/main/scala",
|
||||
"strato/config/columns/clients:clients-strato-client",
|
||||
"strato/config/columns/geo/user:user-strato-client",
|
||||
"strato/config/columns/globe/curation:curation-strato-client",
|
||||
"strato/config/columns/interests:interests-strato-client",
|
||||
"strato/config/columns/ml/featureStore:featureStore-strato-client",
|
||||
"strato/config/columns/notifications:notifications-strato-client",
|
||||
"strato/config/columns/notifinfra:notifinfra-strato-client",
|
||||
"strato/config/columns/periscope:periscope-strato-client",
|
||||
"strato/config/columns/rux",
|
||||
"strato/config/columns/rux:rux-strato-client",
|
||||
"strato/config/columns/rux/open-app:open-app-strato-client",
|
||||
"strato/config/columns/socialgraph/graphs:graphs-strato-client",
|
||||
"strato/config/columns/socialgraph/service/soft_users:soft_users-strato-client",
|
||||
"strato/config/columns/translation/service:service-strato-client",
|
||||
"strato/config/columns/translation/service/platform:platform-strato-client",
|
||||
"strato/config/columns/trends/trip:trip-strato-client",
|
||||
"strato/config/src/thrift/com/twitter/strato/columns/frigate:logged-out-web-notifications-scala",
|
||||
"strato/config/src/thrift/com/twitter/strato/columns/notifications:thrift-scala",
|
||||
"strato/src/main/scala/com/twitter/strato/config",
|
||||
"strato/src/main/scala/com/twitter/strato/response",
|
||||
"thrift-web-forms",
|
||||
"timeline-training-service/service/thrift/src/main/thrift:thrift-scala",
|
||||
"timelines/src/main/scala/com/twitter/timelines/features/app",
|
||||
"topic-social-proof/server/src/main/thrift:thrift-scala",
|
||||
"topiclisting/topiclisting-core/src/main/scala/com/twitter/topiclisting",
|
||||
"topiclisting/topiclisting-utt/src/main/scala/com/twitter/topiclisting/utt",
|
||||
"trends/common/src/main/thrift/com/twitter/trends/common:thrift-scala",
|
||||
"tweetypie/src/scala/com/twitter/tweetypie/tweettext",
|
||||
"twitter-context/src/main/scala",
|
||||
"twitter-server-internal",
|
||||
"twitter-server/server/src/main/scala",
|
||||
"twitter-text/lib/java/src/main/java/com/twitter/twittertext",
|
||||
"twml/runtime/src/main/scala/com/twitter/deepbird/runtime/prediction_engine:prediction_engine_mkl",
|
||||
"ubs/common/src/main/thrift/com/twitter/ubs:broadcast-thrift-scala",
|
||||
"ubs/common/src/main/thrift/com/twitter/ubs:seller_application-thrift-scala",
|
||||
"user_session_store/src/main/scala/com/twitter/user_session_store/impl/manhattan/readwrite",
|
||||
"util-internal/scribe",
|
||||
"util-internal/tunable/src/main/scala/com/twitter/util/tunable",
|
||||
"util/util-app",
|
||||
"util/util-hashing/src/main/scala",
|
||||
"util/util-slf4j-api/src/main/scala",
|
||||
"util/util-stats/src/main/scala",
|
||||
"visibility/lib/src/main/scala/com/twitter/visibility/builder",
|
||||
"visibility/lib/src/main/scala/com/twitter/visibility/interfaces/push_service",
|
||||
"visibility/lib/src/main/scala/com/twitter/visibility/interfaces/spaces",
|
||||
"visibility/lib/src/main/scala/com/twitter/visibility/util",
|
||||
],
|
||||
exports = [
|
||||
"strato/config/src/thrift/com/twitter/strato/columns/frigate:logged-out-web-notifications-scala",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
@ -1,93 +0,0 @@
|
||||
package com.twitter.frigate.pushservice
|
||||
|
||||
import com.google.inject.Inject
|
||||
import com.google.inject.Singleton
|
||||
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
|
||||
import com.twitter.finagle.thrift.ClientId
|
||||
import com.twitter.finatra.thrift.routing.ThriftWarmup
|
||||
import com.twitter.util.logging.Logging
|
||||
import com.twitter.inject.utils.Handler
|
||||
import com.twitter.frigate.pushservice.{thriftscala => t}
|
||||
import com.twitter.frigate.thriftscala.NotificationDisplayLocation
|
||||
import com.twitter.util.Stopwatch
|
||||
import com.twitter.scrooge.Request
|
||||
import com.twitter.scrooge.Response
|
||||
import com.twitter.util.Return
|
||||
import com.twitter.util.Throw
|
||||
import com.twitter.util.Try
|
||||
|
||||
/**
|
||||
* Warms up the refresh request path.
|
||||
* If service is running as pushservice-send then the warmup does nothing.
|
||||
*
|
||||
* When making the warmup refresh requests we
|
||||
* - Set skipFilters to true to execute as much of the request path as possible
|
||||
* - Set darkWrite to true to prevent sending a push
|
||||
*/
|
||||
@Singleton
|
||||
class PushMixerThriftServerWarmupHandler @Inject() (
|
||||
warmup: ThriftWarmup,
|
||||
serviceIdentifier: ServiceIdentifier)
|
||||
extends Handler
|
||||
with Logging {
|
||||
|
||||
private val clientId = ClientId("thrift-warmup-client")
|
||||
|
||||
def handle(): Unit = {
|
||||
val refreshServices = Set(
|
||||
"frigate-pushservice",
|
||||
"frigate-pushservice-canary",
|
||||
"frigate-pushservice-canary-control",
|
||||
"frigate-pushservice-canary-treatment"
|
||||
)
|
||||
val isRefresh = refreshServices.contains(serviceIdentifier.service)
|
||||
if (isRefresh && !serviceIdentifier.isLocal) refreshWarmup()
|
||||
}
|
||||
|
||||
def refreshWarmup(): Unit = {
|
||||
val elapsed = Stopwatch.start()
|
||||
val testIds = Seq(
|
||||
1,
|
||||
2,
|
||||
3
|
||||
)
|
||||
try {
|
||||
clientId.asCurrent {
|
||||
testIds.foreach { id =>
|
||||
val warmupReq = warmupQuery(id)
|
||||
info(s"Sending warm-up request to service with query: $warmupReq")
|
||||
warmup.sendRequest(
|
||||
method = t.PushService.Refresh,
|
||||
req = Request(t.PushService.Refresh.Args(warmupReq)))(assertWarmupResponse)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
error(e.getMessage, e)
|
||||
}
|
||||
info(s"Warm up complete. Time taken: ${elapsed().toString}")
|
||||
}
|
||||
|
||||
private def warmupQuery(userId: Long): t.RefreshRequest = {
|
||||
t.RefreshRequest(
|
||||
userId = userId,
|
||||
notificationDisplayLocation = NotificationDisplayLocation.PushToMobileDevice,
|
||||
context = Some(
|
||||
t.PushContext(
|
||||
skipFilters = Some(true),
|
||||
darkWrite = Some(true)
|
||||
))
|
||||
)
|
||||
}
|
||||
|
||||
private def assertWarmupResponse(
|
||||
result: Try[Response[t.PushService.Refresh.SuccessType]]
|
||||
): Unit = {
|
||||
result match {
|
||||
case Return(_) => // ok
|
||||
case Throw(exception) =>
|
||||
warn("Error performing warm-up request.")
|
||||
error(exception.getMessage, exception)
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,193 +0,0 @@
|
||||
package com.twitter.frigate.pushservice
|
||||
|
||||
import com.twitter.discovery.common.environment.modules.EnvironmentModule
|
||||
import com.twitter.finagle.Filter
|
||||
import com.twitter.finatra.annotations.DarkTrafficFilterType
|
||||
import com.twitter.finatra.decider.modules.DeciderModule
|
||||
import com.twitter.finatra.http.HttpServer
|
||||
import com.twitter.finatra.http.filters.CommonFilters
|
||||
import com.twitter.finatra.http.routing.HttpRouter
|
||||
import com.twitter.finatra.mtls.http.{Mtls => HttpMtls}
|
||||
import com.twitter.finatra.mtls.thriftmux.{Mtls => ThriftMtls}
|
||||
import com.twitter.finatra.mtls.thriftmux.filters.MtlsServerSessionTrackerFilter
|
||||
import com.twitter.finatra.thrift.ThriftServer
|
||||
import com.twitter.finatra.thrift.filters.ExceptionMappingFilter
|
||||
import com.twitter.finatra.thrift.filters.LoggingMDCFilter
|
||||
import com.twitter.finatra.thrift.filters.StatsFilter
|
||||
import com.twitter.finatra.thrift.filters.ThriftMDCFilter
|
||||
import com.twitter.finatra.thrift.filters.TraceIdMDCFilter
|
||||
import com.twitter.finatra.thrift.routing.ThriftRouter
|
||||
import com.twitter.frigate.common.logger.MRLoggerGlobalVariables
|
||||
import com.twitter.frigate.pushservice.controller.PushServiceController
|
||||
import com.twitter.frigate.pushservice.module._
|
||||
import com.twitter.inject.TwitterModule
|
||||
import com.twitter.inject.annotations.Flags
|
||||
import com.twitter.inject.thrift.modules.ThriftClientIdModule
|
||||
import com.twitter.logging.BareFormatter
|
||||
import com.twitter.logging.Level
|
||||
import com.twitter.logging.LoggerFactory
|
||||
import com.twitter.logging.{Logging => JLogging}
|
||||
import com.twitter.logging.QueueingHandler
|
||||
import com.twitter.logging.ScribeHandler
|
||||
import com.twitter.product_mixer.core.module.product_mixer_flags.ProductMixerFlagModule
|
||||
import com.twitter.product_mixer.core.module.ABDeciderModule
|
||||
import com.twitter.product_mixer.core.module.FeatureSwitchesModule
|
||||
import com.twitter.product_mixer.core.module.StratoClientModule
|
||||
|
||||
object PushServiceMain extends PushServiceFinatraServer
|
||||
|
||||
class PushServiceFinatraServer
|
||||
extends ThriftServer
|
||||
with ThriftMtls
|
||||
with HttpServer
|
||||
with HttpMtls
|
||||
with JLogging {
|
||||
|
||||
override val name = "PushService"
|
||||
|
||||
override val modules: Seq[TwitterModule] = {
|
||||
Seq(
|
||||
ABDeciderModule,
|
||||
DeciderModule,
|
||||
FeatureSwitchesModule,
|
||||
FilterModule,
|
||||
FlagModule,
|
||||
EnvironmentModule,
|
||||
ThriftClientIdModule,
|
||||
DeployConfigModule,
|
||||
ProductMixerFlagModule,
|
||||
StratoClientModule,
|
||||
PushHandlerModule,
|
||||
PushTargetUserBuilderModule,
|
||||
PushServiceDarkTrafficModule,
|
||||
LoggedOutPushTargetUserBuilderModule,
|
||||
new ThriftWebFormsModule(this),
|
||||
)
|
||||
}
|
||||
|
||||
override def configureThrift(router: ThriftRouter): Unit = {
|
||||
router
|
||||
.filter[ExceptionMappingFilter]
|
||||
.filter[LoggingMDCFilter]
|
||||
.filter[TraceIdMDCFilter]
|
||||
.filter[ThriftMDCFilter]
|
||||
.filter[MtlsServerSessionTrackerFilter]
|
||||
.filter[StatsFilter]
|
||||
.filter[Filter.TypeAgnostic, DarkTrafficFilterType]
|
||||
.add[PushServiceController]
|
||||
}
|
||||
|
||||
override def configureHttp(router: HttpRouter): Unit =
|
||||
router
|
||||
.filter[CommonFilters]
|
||||
|
||||
override protected def start(): Unit = {
|
||||
MRLoggerGlobalVariables.setRequiredFlags(
|
||||
traceLogFlag = injector.instance[Boolean](Flags.named(FlagModule.mrLoggerIsTraceAll.name)),
|
||||
nthLogFlag = injector.instance[Boolean](Flags.named(FlagModule.mrLoggerNthLog.name)),
|
||||
nthLogValFlag = injector.instance[Long](Flags.named(FlagModule.mrLoggerNthVal.name))
|
||||
)
|
||||
}
|
||||
|
||||
override protected def warmup(): Unit = {
|
||||
handle[PushMixerThriftServerWarmupHandler]()
|
||||
}
|
||||
|
||||
override protected def configureLoggerFactories(): Unit = {
|
||||
loggerFactories.foreach { _() }
|
||||
}
|
||||
|
||||
override def loggerFactories: List[LoggerFactory] = {
|
||||
val scribeScope = statsReceiver.scope("scribe")
|
||||
List(
|
||||
LoggerFactory(
|
||||
level = Some(levelFlag()),
|
||||
handlers = handlers
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "request_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 10000,
|
||||
handler = ScribeHandler(
|
||||
category = "frigate_pushservice_log",
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("frigate_pushservice_log")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "notification_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 10000,
|
||||
handler = ScribeHandler(
|
||||
category = "frigate_notifier",
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("frigate_notifier")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "push_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 10000,
|
||||
handler = ScribeHandler(
|
||||
category = "test_frigate_push",
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("test_frigate_push")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "push_subsample_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 2500,
|
||||
handler = ScribeHandler(
|
||||
category = "magicrecs_candidates_subsample_scribe",
|
||||
maxMessagesPerTransaction = 250,
|
||||
maxMessagesToBuffer = 2500,
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("magicrecs_candidates_subsample_scribe")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "mr_request_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 2500,
|
||||
handler = ScribeHandler(
|
||||
category = "mr_request_scribe",
|
||||
maxMessagesPerTransaction = 250,
|
||||
maxMessagesToBuffer = 2500,
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("mr_request_scribe")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "high_quality_candidates_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 2500,
|
||||
handler = ScribeHandler(
|
||||
category = "frigate_high_quality_candidates_log",
|
||||
maxMessagesPerTransaction = 250,
|
||||
maxMessagesToBuffer = 2500,
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("high_quality_candidates_scribe")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,323 +0,0 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.contentrecommender.thriftscala.MetricTag
|
||||
import com.twitter.cr_mixer.thriftscala.CrMixerTweetRequest
|
||||
import com.twitter.cr_mixer.thriftscala.NotificationsContext
|
||||
import com.twitter.cr_mixer.thriftscala.Product
|
||||
import com.twitter.cr_mixer.thriftscala.ProductContext
|
||||
import com.twitter.cr_mixer.thriftscala.{MetricTag => CrMixerMetricTag}
|
||||
import com.twitter.finagle.stats.Stat
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.AlgorithmScore
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.CrMixerCandidate
|
||||
import com.twitter.frigate.common.base.TopicCandidate
|
||||
import com.twitter.frigate.common.base.TopicProofTweetCandidate
|
||||
import com.twitter.frigate.common.base.TweetCandidate
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutInNetworkTweets
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.store.CrMixerTweetStore
|
||||
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
|
||||
import com.twitter.frigate.pushservice.util.AdaptorUtils
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.pushservice.util.TopicsUtil
|
||||
import com.twitter.frigate.pushservice.util.TweetWithTopicProof
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.hermit.predicate.socialgraph.RelationEdge
|
||||
import com.twitter.product_mixer.core.thriftscala.ClientContext
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.topiclisting.utt.LocalizedEntity
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
|
||||
import com.twitter.util.Future
|
||||
import scala.collection.Map
|
||||
|
||||
case class ContentRecommenderMixerAdaptor(
|
||||
crMixerTweetStore: CrMixerTweetStore,
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
edgeStore: ReadableStore[RelationEdge, Boolean],
|
||||
topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
|
||||
uttEntityHydrationStore: UttEntityHydrationStore,
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override val name: String = this.getClass.getSimpleName
|
||||
|
||||
private[this] val stats = globalStats.scope("ContentRecommenderMixerAdaptor")
|
||||
private[this] val numOfValidAuthors = stats.stat("num_of_valid_authors")
|
||||
private[this] val numOutOfMaximumDropped = stats.stat("dropped_due_out_of_maximum")
|
||||
private[this] val totalInputRecs = stats.counter("input_recs")
|
||||
private[this] val totalOutputRecs = stats.stat("output_recs")
|
||||
private[this] val totalRequests = stats.counter("total_requests")
|
||||
private[this] val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
|
||||
private[this] val totalOutNetworkRecs = stats.counter("out_network_tweets")
|
||||
private[this] val totalInNetworkRecs = stats.counter("in_network_tweets")
|
||||
|
||||
/**
|
||||
* Builds OON raw candidates based on input OON Tweets
|
||||
*/
|
||||
def buildOONRawCandidates(
|
||||
inputTarget: Target,
|
||||
oonTweets: Seq[TweetyPieResult],
|
||||
tweetScoreMap: Map[Long, Double],
|
||||
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]],
|
||||
maxNumOfCandidates: Int
|
||||
): Option[Seq[RawCandidate]] = {
|
||||
val cands = oonTweets.flatMap { tweetResult =>
|
||||
val tweetId = tweetResult.tweet.id
|
||||
generateOONRawCandidate(
|
||||
inputTarget,
|
||||
tweetId,
|
||||
Some(tweetResult),
|
||||
tweetScoreMap,
|
||||
tweetIdToTagsMap
|
||||
)
|
||||
}
|
||||
|
||||
val candidates = restrict(
|
||||
maxNumOfCandidates,
|
||||
cands,
|
||||
numOutOfMaximumDropped,
|
||||
totalOutputRecs
|
||||
)
|
||||
|
||||
Some(candidates)
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a single RawCandidate With TopicProofTweetCandidate
|
||||
*/
|
||||
def buildTopicTweetRawCandidate(
|
||||
inputTarget: Target,
|
||||
tweetWithTopicProof: TweetWithTopicProof,
|
||||
localizedEntity: LocalizedEntity,
|
||||
tags: Option[Seq[MetricTag]],
|
||||
): RawCandidate with TopicProofTweetCandidate = {
|
||||
new RawCandidate with TopicProofTweetCandidate {
|
||||
override def target: Target = inputTarget
|
||||
override def topicListingSetting: Option[String] = Some(
|
||||
tweetWithTopicProof.topicListingSetting)
|
||||
override def tweetId: Long = tweetWithTopicProof.tweetId
|
||||
override def tweetyPieResult: Option[TweetyPieResult] = Some(
|
||||
tweetWithTopicProof.tweetyPieResult)
|
||||
override def semanticCoreEntityId: Option[Long] = Some(tweetWithTopicProof.topicId)
|
||||
override def localizedUttEntity: Option[LocalizedEntity] = Some(localizedEntity)
|
||||
override def algorithmCR: Option[String] = tweetWithTopicProof.algorithmCR
|
||||
override def tagsCR: Option[Seq[MetricTag]] = tags
|
||||
override def isOutOfNetwork: Boolean = tweetWithTopicProof.isOON
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes a group of TopicTweets and transforms them into RawCandidates
|
||||
*/
|
||||
def buildTopicTweetRawCandidates(
|
||||
inputTarget: Target,
|
||||
topicProofCandidates: Seq[TweetWithTopicProof],
|
||||
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]],
|
||||
maxNumberOfCands: Int
|
||||
): Future[Option[Seq[RawCandidate]]] = {
|
||||
val semanticCoreEntityIds = topicProofCandidates
|
||||
.map(_.topicId)
|
||||
.toSet
|
||||
|
||||
TopicsUtil
|
||||
.getLocalizedEntityMap(inputTarget, semanticCoreEntityIds, uttEntityHydrationStore)
|
||||
.map { localizedEntityMap =>
|
||||
val rawCandidates = topicProofCandidates.collect {
|
||||
case topicSocialProof: TweetWithTopicProof
|
||||
if localizedEntityMap.contains(topicSocialProof.topicId) =>
|
||||
// Once we deprecate CR calls, we should replace this code to use the CrMixerMetricTag
|
||||
val tags = tweetIdToTagsMap.get(topicSocialProof.tweetId).map {
|
||||
_.flatMap { tag => MetricTag.get(tag.value) }
|
||||
}
|
||||
buildTopicTweetRawCandidate(
|
||||
inputTarget,
|
||||
topicSocialProof,
|
||||
localizedEntityMap(topicSocialProof.topicId),
|
||||
tags
|
||||
)
|
||||
}
|
||||
|
||||
val candResult = restrict(
|
||||
maxNumberOfCands,
|
||||
rawCandidates,
|
||||
numOutOfMaximumDropped,
|
||||
totalOutputRecs
|
||||
)
|
||||
|
||||
Some(candResult)
|
||||
}
|
||||
}
|
||||
|
||||
private def generateOONRawCandidate(
|
||||
inputTarget: Target,
|
||||
id: Long,
|
||||
result: Option[TweetyPieResult],
|
||||
tweetScoreMap: Map[Long, Double],
|
||||
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]]
|
||||
): Option[RawCandidate with TweetCandidate] = {
|
||||
val tagsFromCR = tweetIdToTagsMap.get(id).map { _.flatMap { tag => MetricTag.get(tag.value) } }
|
||||
val candidate = new RawCandidate with CrMixerCandidate with TopicCandidate with AlgorithmScore {
|
||||
override val tweetId = id
|
||||
override val target = inputTarget
|
||||
override val tweetyPieResult = result
|
||||
override val localizedUttEntity = None
|
||||
override val semanticCoreEntityId = None
|
||||
override def commonRecType =
|
||||
getMediaBasedCRT(
|
||||
CommonRecommendationType.TwistlyTweet,
|
||||
CommonRecommendationType.TwistlyPhoto,
|
||||
CommonRecommendationType.TwistlyVideo)
|
||||
override def tagsCR = tagsFromCR
|
||||
override def algorithmScore = tweetScoreMap.get(id)
|
||||
override def algorithmCR = None
|
||||
}
|
||||
Some(candidate)
|
||||
}
|
||||
|
||||
private def restrict(
|
||||
maxNumToReturn: Int,
|
||||
candidates: Seq[RawCandidate],
|
||||
numOutOfMaximumDropped: Stat,
|
||||
totalOutputRecs: Stat
|
||||
): Seq[RawCandidate] = {
|
||||
val newCandidates = candidates.take(maxNumToReturn)
|
||||
val numDropped = candidates.length - newCandidates.length
|
||||
numOutOfMaximumDropped.add(numDropped)
|
||||
totalOutputRecs.add(newCandidates.size)
|
||||
newCandidates
|
||||
}
|
||||
|
||||
private def buildCrMixerRequest(
|
||||
target: Target,
|
||||
countryCode: Option[String],
|
||||
language: Option[String],
|
||||
seenTweets: Seq[Long]
|
||||
): CrMixerTweetRequest = {
|
||||
CrMixerTweetRequest(
|
||||
clientContext = ClientContext(
|
||||
userId = Some(target.targetId),
|
||||
countryCode = countryCode,
|
||||
languageCode = language
|
||||
),
|
||||
product = Product.Notifications,
|
||||
productContext = Some(ProductContext.NotificationsContext(NotificationsContext())),
|
||||
excludedTweetIds = Some(seenTweets)
|
||||
)
|
||||
}
|
||||
|
||||
private def selectCandidatesToSendBasedOnSettings(
|
||||
isRecommendationsEligible: Boolean,
|
||||
isTopicsEligible: Boolean,
|
||||
oonRawCandidates: Option[Seq[RawCandidate]],
|
||||
topicTweetCandidates: Option[Seq[RawCandidate]]
|
||||
): Option[Seq[RawCandidate]] = {
|
||||
if (isRecommendationsEligible && isTopicsEligible) {
|
||||
Some(topicTweetCandidates.getOrElse(Seq.empty) ++ oonRawCandidates.getOrElse(Seq.empty))
|
||||
} else if (isRecommendationsEligible) {
|
||||
oonRawCandidates
|
||||
} else if (isTopicsEligible) {
|
||||
topicTweetCandidates
|
||||
} else None
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
Future
|
||||
.join(
|
||||
target.seenTweetIds,
|
||||
target.countryCode,
|
||||
target.inferredUserDeviceLanguage,
|
||||
PushDeviceUtil.isTopicsEligible(target),
|
||||
PushDeviceUtil.isRecommendationsEligible(target)
|
||||
).flatMap {
|
||||
case (seenTweets, countryCode, language, isTopicsEligible, isRecommendationsEligible) =>
|
||||
val request = buildCrMixerRequest(target, countryCode, language, seenTweets)
|
||||
crMixerTweetStore.getTweetRecommendations(request).flatMap {
|
||||
case Some(response) =>
|
||||
totalInputRecs.incr(response.tweets.size)
|
||||
totalRequests.incr()
|
||||
AdaptorUtils
|
||||
.getTweetyPieResults(
|
||||
response.tweets.map(_.tweetId).toSet,
|
||||
tweetyPieStore).flatMap { tweetyPieResultMap =>
|
||||
filterOutInNetworkTweets(
|
||||
target,
|
||||
filterOutReplyTweet(tweetyPieResultMap.toMap, nonReplyTweetsCounter),
|
||||
edgeStore,
|
||||
numOfValidAuthors).flatMap {
|
||||
outNetworkTweetsWithId: Seq[(Long, TweetyPieResult)] =>
|
||||
totalOutNetworkRecs.incr(outNetworkTweetsWithId.size)
|
||||
totalInNetworkRecs.incr(response.tweets.size - outNetworkTweetsWithId.size)
|
||||
val outNetworkTweets: Seq[TweetyPieResult] = outNetworkTweetsWithId.map {
|
||||
case (_, tweetyPieResult) => tweetyPieResult
|
||||
}
|
||||
|
||||
val tweetIdToTagsMap = response.tweets.map { tweet =>
|
||||
tweet.tweetId -> tweet.metricTags.getOrElse(Seq.empty)
|
||||
}.toMap
|
||||
|
||||
val tweetScoreMap = response.tweets.map { tweet =>
|
||||
tweet.tweetId -> tweet.score
|
||||
}.toMap
|
||||
|
||||
val maxNumOfCandidates =
|
||||
target.params(PushFeatureSwitchParams.NumberOfMaxCrMixerCandidatesParam)
|
||||
|
||||
val oonRawCandidates =
|
||||
buildOONRawCandidates(
|
||||
target,
|
||||
outNetworkTweets,
|
||||
tweetScoreMap,
|
||||
tweetIdToTagsMap,
|
||||
maxNumOfCandidates)
|
||||
|
||||
TopicsUtil
|
||||
.getTopicSocialProofs(
|
||||
target,
|
||||
outNetworkTweets,
|
||||
topicSocialProofServiceStore,
|
||||
edgeStore,
|
||||
PushFeatureSwitchParams.TopicProofTweetCandidatesTopicScoreThreshold).flatMap {
|
||||
tweetsWithTopicProof =>
|
||||
buildTopicTweetRawCandidates(
|
||||
target,
|
||||
tweetsWithTopicProof,
|
||||
tweetIdToTagsMap,
|
||||
maxNumOfCandidates)
|
||||
}.map { topicTweetCandidates =>
|
||||
selectCandidatesToSendBasedOnSettings(
|
||||
isRecommendationsEligible,
|
||||
isTopicsEligible,
|
||||
oonRawCandidates,
|
||||
topicTweetCandidates)
|
||||
}
|
||||
}
|
||||
}
|
||||
case _ => Future.None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* For a user to be available the following news to happen
|
||||
*/
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
Future
|
||||
.join(
|
||||
PushDeviceUtil.isRecommendationsEligible(target),
|
||||
PushDeviceUtil.isTopicsEligible(target)
|
||||
).map {
|
||||
case (isRecommendationsEligible, isTopicsEligible) =>
|
||||
(isRecommendationsEligible || isTopicsEligible) &&
|
||||
target.params(PushParams.ContentRecommenderMixerAdaptorDecider)
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,293 +0,0 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.Stat
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.candidate._
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.hermit.store.tweetypie.UserTweet
|
||||
import com.twitter.recos.recos_common.thriftscala.SocialProofType
|
||||
import com.twitter.search.common.features.thriftscala.ThriftSearchResultFeatures
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.timelines.configapi.Param
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Time
|
||||
import scala.collection.Map
|
||||
|
||||
case class EarlyBirdFirstDegreeCandidateAdaptor(
|
||||
earlyBirdFirstDegreeCandidates: CandidateSource[
|
||||
EarlybirdCandidateSource.Query,
|
||||
EarlybirdCandidate
|
||||
],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
|
||||
maxResultsParam: Param[Int],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
type EBCandidate = EarlybirdCandidate with TweetDetails
|
||||
private val stats = globalStats.scope("EarlyBirdFirstDegreeAdaptor")
|
||||
private val earlyBirdCandsStat: Stat = stats.stat("early_bird_cands_dist")
|
||||
private val emptyEarlyBirdCands = stats.counter("empty_early_bird_candidates")
|
||||
private val seedSetEmpty = stats.counter("empty_seedset")
|
||||
private val seenTweetsStat = stats.stat("filtered_by_seen_tweets")
|
||||
private val emptyTweetyPieResult = stats.stat("empty_tweetypie_result")
|
||||
private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
|
||||
private val enableRetweets = stats.counter("enable_retweets")
|
||||
private val f1withoutSocialContexts = stats.counter("f1_without_social_context")
|
||||
private val userTweetTweetyPieStoreCounter = stats.counter("user_tweet_tweetypie_store")
|
||||
|
||||
override val name: String = earlyBirdFirstDegreeCandidates.name
|
||||
|
||||
private def getAllSocialContextActions(
|
||||
socialProofTypes: Seq[(SocialProofType, Seq[Long])]
|
||||
): Seq[SocialContextAction] = {
|
||||
socialProofTypes.flatMap {
|
||||
case (SocialProofType.Favorite, scIds) =>
|
||||
scIds.map { scId =>
|
||||
SocialContextAction(
|
||||
scId,
|
||||
Time.now.inMilliseconds,
|
||||
socialContextActionType = Some(SocialContextActionType.Favorite)
|
||||
)
|
||||
}
|
||||
case (SocialProofType.Retweet, scIds) =>
|
||||
scIds.map { scId =>
|
||||
SocialContextAction(
|
||||
scId,
|
||||
Time.now.inMilliseconds,
|
||||
socialContextActionType = Some(SocialContextActionType.Retweet)
|
||||
)
|
||||
}
|
||||
case (SocialProofType.Reply, scIds) =>
|
||||
scIds.map { scId =>
|
||||
SocialContextAction(
|
||||
scId,
|
||||
Time.now.inMilliseconds,
|
||||
socialContextActionType = Some(SocialContextActionType.Reply)
|
||||
)
|
||||
}
|
||||
case (SocialProofType.Tweet, scIds) =>
|
||||
scIds.map { scId =>
|
||||
SocialContextAction(
|
||||
scId,
|
||||
Time.now.inMilliseconds,
|
||||
socialContextActionType = Some(SocialContextActionType.Tweet)
|
||||
)
|
||||
}
|
||||
case _ => Nil
|
||||
}
|
||||
}
|
||||
|
||||
private def generateRetweetCandidate(
|
||||
inputTarget: Target,
|
||||
candidate: EBCandidate,
|
||||
scIds: Seq[Long],
|
||||
socialProofTypes: Seq[(SocialProofType, Seq[Long])]
|
||||
): RawCandidate = {
|
||||
val scActions = scIds.map { scId => SocialContextAction(scId, Time.now.inMilliseconds) }
|
||||
new RawCandidate with TweetRetweetCandidate with EarlybirdTweetFeatures {
|
||||
override val socialContextActions = scActions
|
||||
override val socialContextAllTypeActions = getAllSocialContextActions(socialProofTypes)
|
||||
override val tweetId = candidate.tweetId
|
||||
override val target = inputTarget
|
||||
override val tweetyPieResult = candidate.tweetyPieResult
|
||||
override val features = candidate.features
|
||||
}
|
||||
}
|
||||
|
||||
private def generateF1CandidateWithoutSocialContext(
|
||||
inputTarget: Target,
|
||||
candidate: EBCandidate
|
||||
): RawCandidate = {
|
||||
f1withoutSocialContexts.incr()
|
||||
new RawCandidate with F1FirstDegree with EarlybirdTweetFeatures {
|
||||
override val tweetId = candidate.tweetId
|
||||
override val target = inputTarget
|
||||
override val tweetyPieResult = candidate.tweetyPieResult
|
||||
override val features = candidate.features
|
||||
}
|
||||
}
|
||||
|
||||
private def generateEarlyBirdCandidate(
|
||||
id: Long,
|
||||
result: Option[TweetyPieResult],
|
||||
ebFeatures: Option[ThriftSearchResultFeatures]
|
||||
): EBCandidate = {
|
||||
new EarlybirdCandidate with TweetDetails {
|
||||
override val tweetyPieResult: Option[TweetyPieResult] = result
|
||||
override val tweetId: Long = id
|
||||
override val features: Option[ThriftSearchResultFeatures] = ebFeatures
|
||||
}
|
||||
}
|
||||
|
||||
private def filterOutSeenTweets(seenTweetIds: Seq[Long], inputTweetIds: Seq[Long]): Seq[Long] = {
|
||||
inputTweetIds.filterNot(seenTweetIds.contains)
|
||||
}
|
||||
|
||||
private def filterInvalidTweets(
|
||||
tweetIds: Seq[Long],
|
||||
target: Target
|
||||
): Future[Seq[(Long, TweetyPieResult)]] = {
|
||||
|
||||
val resMap = {
|
||||
if (target.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors)) {
|
||||
userTweetTweetyPieStoreCounter.incr()
|
||||
val keys = tweetIds.map { tweetId =>
|
||||
UserTweet(tweetId, Some(target.targetId))
|
||||
}
|
||||
|
||||
userTweetTweetyPieStore
|
||||
.multiGet(keys.toSet).map {
|
||||
case (userTweet, resultFut) =>
|
||||
userTweet.tweetId -> resultFut
|
||||
}.toMap
|
||||
} else {
|
||||
(target.params(PushFeatureSwitchParams.EnableVFInTweetypie) match {
|
||||
case true => tweetyPieStore
|
||||
case false => tweetyPieStoreNoVF
|
||||
}).multiGet(tweetIds.toSet)
|
||||
}
|
||||
}
|
||||
Future.collect(resMap).map { tweetyPieResultMap =>
|
||||
val cands = filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
|
||||
case (id: Long, Some(result)) =>
|
||||
id -> result
|
||||
}
|
||||
|
||||
emptyTweetyPieResult.add(tweetyPieResultMap.size - cands.size)
|
||||
cands.toSeq
|
||||
}
|
||||
}
|
||||
|
||||
private def getEBRetweetCandidates(
|
||||
inputTarget: Target,
|
||||
retweets: Seq[(Long, TweetyPieResult)]
|
||||
): Seq[RawCandidate] = {
|
||||
retweets.flatMap {
|
||||
case (_, tweetypieResult) =>
|
||||
tweetypieResult.tweet.coreData.flatMap { coreData =>
|
||||
tweetypieResult.sourceTweet.map { sourceTweet =>
|
||||
val tweetId = sourceTweet.id
|
||||
val scId = coreData.userId
|
||||
val socialProofTypes = Seq((SocialProofType.Retweet, Seq(scId)))
|
||||
val candidate = generateEarlyBirdCandidate(
|
||||
tweetId,
|
||||
Some(TweetyPieResult(sourceTweet, None, None)),
|
||||
None
|
||||
)
|
||||
generateRetweetCandidate(
|
||||
inputTarget,
|
||||
candidate,
|
||||
Seq(scId),
|
||||
socialProofTypes
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getEBFirstDegreeCands(
|
||||
tweets: Seq[(Long, TweetyPieResult)],
|
||||
ebTweetIdMap: Map[Long, Option[ThriftSearchResultFeatures]]
|
||||
): Seq[EBCandidate] = {
|
||||
tweets.map {
|
||||
case (id, tweetypieResult) =>
|
||||
val features = ebTweetIdMap.getOrElse(id, None)
|
||||
generateEarlyBirdCandidate(id, Some(tweetypieResult), features)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a combination of raw candidates made of: f1 recs, topic social proof recs, sc recs and retweet candidates
|
||||
*/
|
||||
def buildRawCandidates(
|
||||
inputTarget: Target,
|
||||
firstDegreeCandidates: Seq[EBCandidate],
|
||||
retweetCandidates: Seq[RawCandidate]
|
||||
): Seq[RawCandidate] = {
|
||||
val hydratedF1Recs =
|
||||
firstDegreeCandidates.map(generateF1CandidateWithoutSocialContext(inputTarget, _))
|
||||
hydratedF1Recs ++ retweetCandidates
|
||||
}
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
inputTarget.seedsWithWeight.flatMap { seedsetOpt =>
|
||||
val seedsetMap = seedsetOpt.getOrElse(Map.empty)
|
||||
|
||||
if (seedsetMap.isEmpty) {
|
||||
seedSetEmpty.incr()
|
||||
Future.None
|
||||
} else {
|
||||
val maxResultsToReturn = inputTarget.params(maxResultsParam)
|
||||
val maxTweetAge = inputTarget.params(PushFeatureSwitchParams.F1CandidateMaxTweetAgeParam)
|
||||
val earlybirdQuery = EarlybirdCandidateSource.Query(
|
||||
maxNumResultsToReturn = maxResultsToReturn,
|
||||
seedset = seedsetMap,
|
||||
maxConsecutiveResultsByTheSameUser = Some(1),
|
||||
maxTweetAge = maxTweetAge,
|
||||
disableTimelinesMLModel = false,
|
||||
searcherId = Some(inputTarget.targetId),
|
||||
isProtectTweetsEnabled =
|
||||
inputTarget.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors),
|
||||
followedUserIds = Some(seedsetMap.keySet.toSeq)
|
||||
)
|
||||
|
||||
Future
|
||||
.join(inputTarget.seenTweetIds, earlyBirdFirstDegreeCandidates.get(earlybirdQuery))
|
||||
.flatMap {
|
||||
case (seenTweetIds, Some(candidates)) =>
|
||||
earlyBirdCandsStat.add(candidates.size)
|
||||
|
||||
val ebTweetIdMap = candidates.map { cand => cand.tweetId -> cand.features }.toMap
|
||||
|
||||
val ebTweetIds = ebTweetIdMap.keys.toSeq
|
||||
|
||||
val tweetIds = filterOutSeenTweets(seenTweetIds, ebTweetIds)
|
||||
seenTweetsStat.add(ebTweetIds.size - tweetIds.size)
|
||||
|
||||
filterInvalidTweets(tweetIds, inputTarget)
|
||||
.map { validTweets =>
|
||||
val (retweets, tweets) = validTweets.partition {
|
||||
case (_, tweetypieResult) =>
|
||||
tweetypieResult.sourceTweet.isDefined
|
||||
}
|
||||
|
||||
val firstDegreeCandidates = getEBFirstDegreeCands(tweets, ebTweetIdMap)
|
||||
|
||||
val retweetCandidates = {
|
||||
if (inputTarget.params(PushParams.EarlyBirdSCBasedCandidatesParam) &&
|
||||
inputTarget.params(PushParams.MRTweetRetweetRecsParam)) {
|
||||
enableRetweets.incr()
|
||||
getEBRetweetCandidates(inputTarget, retweets)
|
||||
} else Nil
|
||||
}
|
||||
|
||||
Some(
|
||||
buildRawCandidates(
|
||||
inputTarget,
|
||||
firstDegreeCandidates,
|
||||
retweetCandidates
|
||||
))
|
||||
}
|
||||
|
||||
case _ =>
|
||||
emptyEarlyBirdCands.incr()
|
||||
Future.None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil.isRecommendationsEligible(target)
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,120 +0,0 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRankerProductResponse
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRankerRequest
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRankerResponse
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRecommendation
|
||||
import com.twitter.explore_ranker.thriftscala.ImmersiveRecsResponse
|
||||
import com.twitter.explore_ranker.thriftscala.ImmersiveRecsResult
|
||||
import com.twitter.explore_ranker.thriftscala.NotificationsVideoRecs
|
||||
import com.twitter.explore_ranker.thriftscala.Product
|
||||
import com.twitter.explore_ranker.thriftscala.ProductContext
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.OutOfNetworkTweetCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.util.AdaptorUtils
|
||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.product_mixer.core.thriftscala.ClientContext
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
case class ExploreVideoTweetCandidateAdaptor(
|
||||
exploreRankerStore: ReadableStore[ExploreRankerRequest, ExploreRankerResponse],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override def name: String = this.getClass.getSimpleName
|
||||
private[this] val stats = globalStats.scope("ExploreVideoTweetCandidateAdaptor")
|
||||
private[this] val totalInputRecs = stats.stat("input_recs")
|
||||
private[this] val totalRequests = stats.counter("total_requests")
|
||||
private[this] val totalEmptyResponse = stats.counter("total_empty_response")
|
||||
|
||||
private def buildExploreRankerRequest(
|
||||
target: Target,
|
||||
countryCode: Option[String],
|
||||
language: Option[String],
|
||||
): ExploreRankerRequest = {
|
||||
ExploreRankerRequest(
|
||||
clientContext = ClientContext(
|
||||
userId = Some(target.targetId),
|
||||
countryCode = countryCode,
|
||||
languageCode = language,
|
||||
),
|
||||
product = Product.NotificationsVideoRecs,
|
||||
productContext = Some(ProductContext.NotificationsVideoRecs(NotificationsVideoRecs())),
|
||||
maxResults = Some(target.params(PushFeatureSwitchParams.MaxExploreVideoTweets))
|
||||
)
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
Future
|
||||
.join(
|
||||
target.countryCode,
|
||||
target.inferredUserDeviceLanguage
|
||||
).flatMap {
|
||||
case (countryCode, language) =>
|
||||
val request = buildExploreRankerRequest(target, countryCode, language)
|
||||
exploreRankerStore.get(request).flatMap {
|
||||
case Some(response) =>
|
||||
val exploreResonseTweetIds = response match {
|
||||
case ExploreRankerResponse(ExploreRankerProductResponse
|
||||
.ImmersiveRecsResponse(ImmersiveRecsResponse(immersiveRecsResult))) =>
|
||||
immersiveRecsResult.collect {
|
||||
case ImmersiveRecsResult(ExploreRecommendation
|
||||
.ExploreTweetRecommendation(exploreTweetRecommendation)) =>
|
||||
exploreTweetRecommendation.tweetId
|
||||
}
|
||||
case _ =>
|
||||
Seq.empty
|
||||
}
|
||||
|
||||
totalInputRecs.add(exploreResonseTweetIds.size)
|
||||
totalRequests.incr()
|
||||
AdaptorUtils
|
||||
.getTweetyPieResults(exploreResonseTweetIds.toSet, tweetyPieStore).map {
|
||||
tweetyPieResultMap =>
|
||||
val candidates = tweetyPieResultMap.values.flatten
|
||||
.map(buildVideoRawCandidates(target, _))
|
||||
Some(candidates.toSeq)
|
||||
}
|
||||
case _ =>
|
||||
totalEmptyResponse.incr()
|
||||
Future.None
|
||||
}
|
||||
case _ =>
|
||||
Future.None
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil.isRecommendationsEligible(target).map { userRecommendationsEligible =>
|
||||
userRecommendationsEligible && target.params(PushFeatureSwitchParams.EnableExploreVideoTweets)
|
||||
}
|
||||
}
|
||||
private def buildVideoRawCandidates(
|
||||
target: Target,
|
||||
tweetyPieResult: TweetyPieResult
|
||||
): RawCandidate with OutOfNetworkTweetCandidate = {
|
||||
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
|
||||
inputTarget = target,
|
||||
id = tweetyPieResult.tweet.id,
|
||||
mediaCRT = MediaCRT(
|
||||
CommonRecommendationType.ExploreVideoTweet,
|
||||
CommonRecommendationType.ExploreVideoTweet,
|
||||
CommonRecommendationType.ExploreVideoTweet
|
||||
),
|
||||
result = Some(tweetyPieResult),
|
||||
localizedEntity = None
|
||||
)
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,272 +0,0 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.cr_mixer.thriftscala.FrsTweetRequest
|
||||
import com.twitter.cr_mixer.thriftscala.NotificationsContext
|
||||
import com.twitter.cr_mixer.thriftscala.Product
|
||||
import com.twitter.cr_mixer.thriftscala.ProductContext
|
||||
import com.twitter.finagle.stats.Counter
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.store.CrMixerTweetStore
|
||||
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
|
||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.pushservice.util.TopicsUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.hermit.constants.AlgorithmFeedbackTokens
|
||||
import com.twitter.hermit.model.Algorithm.Algorithm
|
||||
import com.twitter.hermit.model.Algorithm.CrowdSearchAccounts
|
||||
import com.twitter.hermit.model.Algorithm.ForwardEmailBook
|
||||
import com.twitter.hermit.model.Algorithm.ForwardPhoneBook
|
||||
import com.twitter.hermit.model.Algorithm.ReverseEmailBookIbis
|
||||
import com.twitter.hermit.model.Algorithm.ReversePhoneBook
|
||||
import com.twitter.hermit.store.tweetypie.UserTweet
|
||||
import com.twitter.product_mixer.core.thriftscala.ClientContext
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
|
||||
import com.twitter.util.Future
|
||||
|
||||
object FRSAlgorithmFeedbackTokenUtil {
|
||||
private val crtsByAlgoToken = Map(
|
||||
getAlgorithmToken(ReverseEmailBookIbis) -> CommonRecommendationType.ReverseAddressbookTweet,
|
||||
getAlgorithmToken(ReversePhoneBook) -> CommonRecommendationType.ReverseAddressbookTweet,
|
||||
getAlgorithmToken(ForwardEmailBook) -> CommonRecommendationType.ForwardAddressbookTweet,
|
||||
getAlgorithmToken(ForwardPhoneBook) -> CommonRecommendationType.ForwardAddressbookTweet,
|
||||
getAlgorithmToken(CrowdSearchAccounts) -> CommonRecommendationType.CrowdSearchTweet
|
||||
)
|
||||
|
||||
def getAlgorithmToken(algorithm: Algorithm): Int = {
|
||||
AlgorithmFeedbackTokens.AlgorithmToFeedbackTokenMap(algorithm)
|
||||
}
|
||||
|
||||
def getCRTForAlgoToken(algorithmToken: Int): Option[CommonRecommendationType] = {
|
||||
crtsByAlgoToken.get(algorithmToken)
|
||||
}
|
||||
}
|
||||
|
||||
case class FRSTweetCandidateAdaptor(
|
||||
crMixerTweetStore: CrMixerTweetStore,
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
|
||||
uttEntityHydrationStore: UttEntityHydrationStore,
|
||||
topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
private val stats = globalStats.scope(this.getClass.getSimpleName)
|
||||
private val crtStats = stats.scope("CandidateDistribution")
|
||||
private val totalRequests = stats.counter("total_requests")
|
||||
|
||||
// Candidate Distribution stats
|
||||
private val reverseAddressbookCounter = crtStats.counter("reverse_addressbook")
|
||||
private val forwardAddressbookCounter = crtStats.counter("forward_addressbook")
|
||||
private val frsTweetCounter = crtStats.counter("frs_tweet")
|
||||
private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
|
||||
private val crtToCounterMapping: Map[CommonRecommendationType, Counter] = Map(
|
||||
CommonRecommendationType.ReverseAddressbookTweet -> reverseAddressbookCounter,
|
||||
CommonRecommendationType.ForwardAddressbookTweet -> forwardAddressbookCounter,
|
||||
CommonRecommendationType.FrsTweet -> frsTweetCounter
|
||||
)
|
||||
|
||||
private val emptyTweetyPieResult = stats.stat("empty_tweetypie_result")
|
||||
|
||||
private[this] val numberReturnedCandidates = stats.stat("returned_candidates_from_earlybird")
|
||||
private[this] val numberCandidateWithTopic: Counter = stats.counter("num_can_with_topic")
|
||||
private[this] val numberCandidateWithoutTopic: Counter = stats.counter("num_can_without_topic")
|
||||
|
||||
private val userTweetTweetyPieStoreCounter = stats.counter("user_tweet_tweetypie_store")
|
||||
|
||||
override val name: String = this.getClass.getSimpleName
|
||||
|
||||
private def filterInvalidTweets(
|
||||
tweetIds: Seq[Long],
|
||||
target: Target
|
||||
): Future[Map[Long, TweetyPieResult]] = {
|
||||
val resMap = {
|
||||
if (target.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors)) {
|
||||
userTweetTweetyPieStoreCounter.incr()
|
||||
val keys = tweetIds.map { tweetId =>
|
||||
UserTweet(tweetId, Some(target.targetId))
|
||||
}
|
||||
userTweetTweetyPieStore
|
||||
.multiGet(keys.toSet).map {
|
||||
case (userTweet, resultFut) =>
|
||||
userTweet.tweetId -> resultFut
|
||||
}.toMap
|
||||
} else {
|
||||
(if (target.params(PushFeatureSwitchParams.EnableVFInTweetypie)) {
|
||||
tweetyPieStore
|
||||
} else {
|
||||
tweetyPieStoreNoVF
|
||||
}).multiGet(tweetIds.toSet)
|
||||
}
|
||||
}
|
||||
|
||||
Future.collect(resMap).map { tweetyPieResultMap =>
|
||||
// Filter out replies and generate earlybird candidates only for non-empty tweetypie result
|
||||
val cands = filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
|
||||
case (id: Long, Some(result)) =>
|
||||
id -> result
|
||||
}
|
||||
|
||||
emptyTweetyPieResult.add(tweetyPieResultMap.size - cands.size)
|
||||
cands
|
||||
}
|
||||
}
|
||||
|
||||
private def buildRawCandidates(
|
||||
target: Target,
|
||||
ebCandidates: Seq[FRSTweetCandidate]
|
||||
): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
|
||||
|
||||
val enableTopic = target.params(PushFeatureSwitchParams.EnableFrsTweetCandidatesTopicAnnotation)
|
||||
val topicScoreThre =
|
||||
target.params(PushFeatureSwitchParams.FrsTweetCandidatesTopicScoreThreshold)
|
||||
|
||||
val ebTweets = ebCandidates.map { ebCandidate =>
|
||||
ebCandidate.tweetId -> ebCandidate.tweetyPieResult
|
||||
}.toMap
|
||||
|
||||
val tweetIdLocalizedEntityMapFut = TopicsUtil.getTweetIdLocalizedEntityMap(
|
||||
target,
|
||||
ebTweets,
|
||||
uttEntityHydrationStore,
|
||||
topicSocialProofServiceStore,
|
||||
enableTopic,
|
||||
topicScoreThre
|
||||
)
|
||||
|
||||
Future.join(target.deviceInfo, tweetIdLocalizedEntityMapFut).map {
|
||||
case (Some(deviceInfo), tweetIdLocalizedEntityMap) =>
|
||||
val candidates = ebCandidates
|
||||
.map { ebCandidate =>
|
||||
val crt = ebCandidate.commonRecType
|
||||
crtToCounterMapping.get(crt).foreach(_.incr())
|
||||
|
||||
val tweetId = ebCandidate.tweetId
|
||||
val localizedEntityOpt = {
|
||||
if (tweetIdLocalizedEntityMap
|
||||
.contains(tweetId) && tweetIdLocalizedEntityMap.contains(
|
||||
tweetId) && deviceInfo.isTopicsEligible) {
|
||||
tweetIdLocalizedEntityMap(tweetId)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
|
||||
inputTarget = target,
|
||||
id = ebCandidate.tweetId,
|
||||
mediaCRT = MediaCRT(
|
||||
crt,
|
||||
crt,
|
||||
crt
|
||||
),
|
||||
result = ebCandidate.tweetyPieResult,
|
||||
localizedEntity = localizedEntityOpt)
|
||||
}.filter { candidate =>
|
||||
// If user only has the topic setting enabled, filter out all non-topic cands
|
||||
deviceInfo.isRecommendationsEligible || (deviceInfo.isTopicsEligible && candidate.semanticCoreEntityId.nonEmpty)
|
||||
}
|
||||
|
||||
candidates.map { candidate =>
|
||||
if (candidate.semanticCoreEntityId.nonEmpty) {
|
||||
numberCandidateWithTopic.incr()
|
||||
} else {
|
||||
numberCandidateWithoutTopic.incr()
|
||||
}
|
||||
}
|
||||
|
||||
numberReturnedCandidates.add(candidates.length)
|
||||
Some(candidates)
|
||||
case _ => Some(Seq.empty)
|
||||
}
|
||||
}
|
||||
|
||||
def getTweetCandidatesFromCrMixer(
|
||||
inputTarget: Target,
|
||||
showAllResultsFromFrs: Boolean,
|
||||
): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
|
||||
Future
|
||||
.join(
|
||||
inputTarget.seenTweetIds,
|
||||
inputTarget.pushRecItems,
|
||||
inputTarget.countryCode,
|
||||
inputTarget.targetLanguage).flatMap {
|
||||
case (seenTweetIds, pastRecItems, countryCode, language) =>
|
||||
val pastUserRecs = pastRecItems.userIds.toSeq
|
||||
val request = FrsTweetRequest(
|
||||
clientContext = ClientContext(
|
||||
userId = Some(inputTarget.targetId),
|
||||
countryCode = countryCode,
|
||||
languageCode = language
|
||||
),
|
||||
product = Product.Notifications,
|
||||
productContext = Some(ProductContext.NotificationsContext(NotificationsContext())),
|
||||
excludedUserIds = Some(pastUserRecs),
|
||||
excludedTweetIds = Some(seenTweetIds)
|
||||
)
|
||||
crMixerTweetStore.getFRSTweetCandidates(request).flatMap {
|
||||
case Some(response) =>
|
||||
val tweetIds = response.tweets.map(_.tweetId)
|
||||
val validTweets = filterInvalidTweets(tweetIds, inputTarget)
|
||||
validTweets.flatMap { tweetypieMap =>
|
||||
val ebCandidates = response.tweets
|
||||
.map { frsTweet =>
|
||||
val candidateTweetId = frsTweet.tweetId
|
||||
val resultFromTweetyPie = tweetypieMap.get(candidateTweetId)
|
||||
new FRSTweetCandidate {
|
||||
override val tweetId = candidateTweetId
|
||||
override val features = None
|
||||
override val tweetyPieResult = resultFromTweetyPie
|
||||
override val feedbackToken = frsTweet.frsPrimarySource
|
||||
override val commonRecType: CommonRecommendationType = feedbackToken
|
||||
.flatMap(token =>
|
||||
FRSAlgorithmFeedbackTokenUtil.getCRTForAlgoToken(token)).getOrElse(
|
||||
CommonRecommendationType.FrsTweet)
|
||||
}
|
||||
}.filter { ebCandidate =>
|
||||
showAllResultsFromFrs || ebCandidate.commonRecType == CommonRecommendationType.ReverseAddressbookTweet
|
||||
}
|
||||
|
||||
numberReturnedCandidates.add(ebCandidates.length)
|
||||
buildRawCandidates(
|
||||
inputTarget,
|
||||
ebCandidates
|
||||
)
|
||||
}
|
||||
case _ => Future.None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
|
||||
totalRequests.incr()
|
||||
val enableResultsFromFrs =
|
||||
inputTarget.params(PushFeatureSwitchParams.EnableResultFromFrsCandidates)
|
||||
getTweetCandidatesFromCrMixer(inputTarget, enableResultsFromFrs)
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
lazy val enableFrsCandidates = target.params(PushFeatureSwitchParams.EnableFrsCandidates)
|
||||
PushDeviceUtil.isRecommendationsEligible(target).flatMap { isEnabledForRecosSetting =>
|
||||
PushDeviceUtil.isTopicsEligible(target).map { topicSettingEnabled =>
|
||||
val isEnabledForTopics =
|
||||
topicSettingEnabled && target.params(
|
||||
PushFeatureSwitchParams.EnableFrsTweetCandidatesTopicSetting)
|
||||
(isEnabledForRecosSetting || isEnabledForTopics) && enableFrsCandidates
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,107 +0,0 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.candidate._
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
object GenericCandidates {
|
||||
type Target =
|
||||
TargetUser
|
||||
with UserDetails
|
||||
with TargetDecider
|
||||
with TargetABDecider
|
||||
with TweetImpressionHistory
|
||||
with HTLVisitHistory
|
||||
with MaxTweetAge
|
||||
with NewUserDetails
|
||||
with FrigateHistory
|
||||
with TargetWithSeedUsers
|
||||
}
|
||||
|
||||
case class GenericCandidateAdaptor(
|
||||
genericCandidates: CandidateSource[GenericCandidates.Target, Candidate],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
stats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override val name: String = genericCandidates.name
|
||||
|
||||
private def generateTweetFavCandidate(
|
||||
_target: Target,
|
||||
_tweetId: Long,
|
||||
_socialContextActions: Seq[SocialContextAction],
|
||||
socialContextActionsAllTypes: Seq[SocialContextAction],
|
||||
_tweetyPieResult: Option[TweetyPieResult]
|
||||
): RawCandidate = {
|
||||
new RawCandidate with TweetFavoriteCandidate {
|
||||
override val socialContextActions = _socialContextActions
|
||||
override val socialContextAllTypeActions =
|
||||
socialContextActionsAllTypes
|
||||
val tweetId = _tweetId
|
||||
val target = _target
|
||||
val tweetyPieResult = _tweetyPieResult
|
||||
}
|
||||
}
|
||||
|
||||
private def generateTweetRetweetCandidate(
|
||||
_target: Target,
|
||||
_tweetId: Long,
|
||||
_socialContextActions: Seq[SocialContextAction],
|
||||
socialContextActionsAllTypes: Seq[SocialContextAction],
|
||||
_tweetyPieResult: Option[TweetyPieResult]
|
||||
): RawCandidate = {
|
||||
new RawCandidate with TweetRetweetCandidate {
|
||||
override val socialContextActions = _socialContextActions
|
||||
override val socialContextAllTypeActions = socialContextActionsAllTypes
|
||||
val tweetId = _tweetId
|
||||
val target = _target
|
||||
val tweetyPieResult = _tweetyPieResult
|
||||
}
|
||||
}
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
genericCandidates.get(inputTarget).map { candidatesOpt =>
|
||||
candidatesOpt
|
||||
.map { candidates =>
|
||||
val candidatesSeq =
|
||||
candidates.collect {
|
||||
case tweetRetweet: TweetRetweetCandidate
|
||||
if inputTarget.params(PushParams.MRTweetRetweetRecsParam) =>
|
||||
generateTweetRetweetCandidate(
|
||||
inputTarget,
|
||||
tweetRetweet.tweetId,
|
||||
tweetRetweet.socialContextActions,
|
||||
tweetRetweet.socialContextAllTypeActions,
|
||||
tweetRetweet.tweetyPieResult)
|
||||
case tweetFavorite: TweetFavoriteCandidate
|
||||
if inputTarget.params(PushParams.MRTweetFavRecsParam) =>
|
||||
generateTweetFavCandidate(
|
||||
inputTarget,
|
||||
tweetFavorite.tweetId,
|
||||
tweetFavorite.socialContextActions,
|
||||
tweetFavorite.socialContextAllTypeActions,
|
||||
tweetFavorite.tweetyPieResult)
|
||||
}
|
||||
candidatesSeq.foreach { candidate =>
|
||||
stats.counter(s"${candidate.commonRecType}_count").incr()
|
||||
}
|
||||
candidatesSeq
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil.isRecommendationsEligible(target).map { isAvailable =>
|
||||
isAvailable && target.params(PushParams.GenericCandidateAdaptorDecider)
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,280 +0,0 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.Stat
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.HighQualityCandidateGroupEnum
|
||||
import com.twitter.frigate.pushservice.params.HighQualityCandidateGroupEnum._
|
||||
import com.twitter.frigate.pushservice.params.PushConstants.targetUserAgeFeatureName
|
||||
import com.twitter.frigate.pushservice.params.PushConstants.targetUserPreferredLanguage
|
||||
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
|
||||
import com.twitter.frigate.pushservice.predicate.TargetPredicates
|
||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.pushservice.util.TopicsUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.interests.thriftscala.InterestId.SemanticCore
|
||||
import com.twitter.interests.thriftscala.UserInterests
|
||||
import com.twitter.language.normalization.UserDisplayLanguage
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweet
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
|
||||
import com.twitter.util.Future
|
||||
|
||||
object HighQualityTweetsHelper {
|
||||
def getFollowedTopics(
|
||||
target: Target,
|
||||
interestsWithLookupContextStore: ReadableStore[
|
||||
InterestsLookupRequestWithContext,
|
||||
UserInterests
|
||||
],
|
||||
followedTopicsStats: Stat
|
||||
): Future[Seq[Long]] = {
|
||||
TopicsUtil
|
||||
.getTopicsFollowedByUser(target, interestsWithLookupContextStore, followedTopicsStats).map {
|
||||
userInterestsOpt =>
|
||||
val userInterests = userInterestsOpt.getOrElse(Seq.empty)
|
||||
val extractedTopicIds = userInterests.flatMap {
|
||||
_.interestId match {
|
||||
case SemanticCore(semanticCore) => Some(semanticCore.id)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
extractedTopicIds
|
||||
}
|
||||
}
|
||||
|
||||
def getTripQueries(
|
||||
target: Target,
|
||||
enabledGroups: Set[HighQualityCandidateGroupEnum.Value],
|
||||
interestsWithLookupContextStore: ReadableStore[
|
||||
InterestsLookupRequestWithContext,
|
||||
UserInterests
|
||||
],
|
||||
sourceIds: Seq[String],
|
||||
stat: Stat
|
||||
): Future[Set[TripDomain]] = {
|
||||
|
||||
val followedTopicIdsSetFut: Future[Set[Long]] = if (enabledGroups.contains(Topic)) {
|
||||
getFollowedTopics(target, interestsWithLookupContextStore, stat).map(topicIds =>
|
||||
topicIds.toSet)
|
||||
} else {
|
||||
Future.value(Set.empty)
|
||||
}
|
||||
|
||||
Future
|
||||
.join(target.featureMap, target.inferredUserDeviceLanguage, followedTopicIdsSetFut).map {
|
||||
case (
|
||||
featureMap,
|
||||
deviceLanguageOpt,
|
||||
followedTopicIds
|
||||
) =>
|
||||
val ageBucketOpt = if (enabledGroups.contains(AgeBucket)) {
|
||||
featureMap.categoricalFeatures.get(targetUserAgeFeatureName)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
val languageOptions: Set[Option[String]] = if (enabledGroups.contains(Language)) {
|
||||
val userPreferredLanguages = featureMap.sparseBinaryFeatures
|
||||
.getOrElse(targetUserPreferredLanguage, Set.empty[String])
|
||||
if (userPreferredLanguages.nonEmpty) {
|
||||
userPreferredLanguages.map(lang => Some(UserDisplayLanguage.toTweetLanguage(lang)))
|
||||
} else {
|
||||
Set(deviceLanguageOpt.map(UserDisplayLanguage.toTweetLanguage))
|
||||
}
|
||||
} else Set(None)
|
||||
|
||||
val followedTopicOptions: Set[Option[Long]] = if (followedTopicIds.nonEmpty) {
|
||||
followedTopicIds.map(topic => Some(topic))
|
||||
} else Set(None)
|
||||
|
||||
val tripQueries = followedTopicOptions.flatMap { topicOption =>
|
||||
languageOptions.flatMap { languageOption =>
|
||||
sourceIds.map { sourceId =>
|
||||
TripDomain(
|
||||
sourceId = sourceId,
|
||||
language = languageOption,
|
||||
placeId = None,
|
||||
topicId = topicOption,
|
||||
gender = None,
|
||||
ageBucket = ageBucketOpt
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tripQueries
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class HighQualityTweetsAdaptor(
|
||||
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
|
||||
interestsWithLookupContextStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override def name: String = this.getClass.getSimpleName
|
||||
|
||||
private val stats = globalStats.scope("HighQualityCandidateAdaptor")
|
||||
private val followedTopicsStats = stats.stat("followed_topics")
|
||||
private val missingResponseCounter = stats.counter("missing_respond_counter")
|
||||
private val crtFatigueCounter = stats.counter("fatigue_by_crt")
|
||||
private val fallbackRequestsCounter = stats.counter("fallback_requests")
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil.isRecommendationsEligible(target).map {
|
||||
_ && target.params(FS.HighQualityCandidatesEnableCandidateSource)
|
||||
}
|
||||
}
|
||||
|
||||
private val highQualityCandidateFrequencyPredicate = {
|
||||
TargetPredicates
|
||||
.pushRecTypeFatiguePredicate(
|
||||
CommonRecommendationType.TripHqTweet,
|
||||
FS.HighQualityTweetsPushInterval,
|
||||
FS.MaxHighQualityTweetsPushGivenInterval,
|
||||
stats
|
||||
)
|
||||
}
|
||||
|
||||
private def getTripCandidatesStrato(
|
||||
target: Target
|
||||
): Future[Map[Long, Set[TripDomain]]] = {
|
||||
val tripQueriesF: Future[Set[TripDomain]] = HighQualityTweetsHelper.getTripQueries(
|
||||
target = target,
|
||||
enabledGroups = target.params(FS.HighQualityCandidatesEnableGroups).toSet,
|
||||
interestsWithLookupContextStore = interestsWithLookupContextStore,
|
||||
sourceIds = target.params(FS.TripTweetCandidateSourceIds),
|
||||
stat = followedTopicsStats
|
||||
)
|
||||
|
||||
lazy val fallbackTripQueriesFut: Future[Set[TripDomain]] =
|
||||
if (target.params(FS.HighQualityCandidatesEnableFallback))
|
||||
HighQualityTweetsHelper.getTripQueries(
|
||||
target = target,
|
||||
enabledGroups = target.params(FS.HighQualityCandidatesFallbackEnabledGroups).toSet,
|
||||
interestsWithLookupContextStore = interestsWithLookupContextStore,
|
||||
sourceIds = target.params(FS.HighQualityCandidatesFallbackSourceIds),
|
||||
stat = followedTopicsStats
|
||||
)
|
||||
else Future.value(Set.empty)
|
||||
|
||||
val initialTweetsFut: Future[Map[TripDomain, Seq[TripTweet]]] = tripQueriesF.flatMap {
|
||||
tripQueries => getTripTweetsByDomains(tripQueries)
|
||||
}
|
||||
|
||||
val tweetsByDomainFut: Future[Map[TripDomain, Seq[TripTweet]]] =
|
||||
if (target.params(FS.HighQualityCandidatesEnableFallback)) {
|
||||
initialTweetsFut.flatMap { candidates =>
|
||||
val minCandidatesForFallback: Int =
|
||||
target.params(FS.HighQualityCandidatesMinNumOfCandidatesToFallback)
|
||||
val validCandidates = candidates.filter(_._2.size >= minCandidatesForFallback)
|
||||
|
||||
if (validCandidates.nonEmpty) {
|
||||
Future.value(validCandidates)
|
||||
} else {
|
||||
fallbackTripQueriesFut.flatMap { fallbackTripDomains =>
|
||||
fallbackRequestsCounter.incr(fallbackTripDomains.size)
|
||||
getTripTweetsByDomains(fallbackTripDomains)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
initialTweetsFut
|
||||
}
|
||||
|
||||
val numOfCandidates: Int = target.params(FS.HighQualityCandidatesNumberOfCandidates)
|
||||
tweetsByDomainFut.map(tweetsByDomain => reformatDomainTweetMap(tweetsByDomain, numOfCandidates))
|
||||
}
|
||||
|
||||
private def getTripTweetsByDomains(
|
||||
tripQueries: Set[TripDomain]
|
||||
): Future[Map[TripDomain, Seq[TripTweet]]] = {
|
||||
Future.collect(tripTweetCandidateStore.multiGet(tripQueries)).map { response =>
|
||||
response
|
||||
.filter(p => p._2.exists(_.tweets.nonEmpty))
|
||||
.mapValues(_.map(_.tweets).getOrElse(Seq.empty))
|
||||
}
|
||||
}
|
||||
|
||||
private def reformatDomainTweetMap(
|
||||
tweetsByDomain: Map[TripDomain, Seq[TripTweet]],
|
||||
numOfCandidates: Int
|
||||
): Map[Long, Set[TripDomain]] = tweetsByDomain
|
||||
.flatMap {
|
||||
case (tripDomain, tripTweets) =>
|
||||
tripTweets
|
||||
.sortBy(_.score)(Ordering[Double].reverse)
|
||||
.take(numOfCandidates)
|
||||
.map { tweet => (tweet.tweetId, tripDomain) }
|
||||
}.groupBy(_._1).mapValues(_.map(_._2).toSet)
|
||||
|
||||
private def buildRawCandidate(
|
||||
target: Target,
|
||||
tweetyPieResult: TweetyPieResult,
|
||||
tripDomain: Option[scala.collection.Set[TripDomain]]
|
||||
): RawCandidate = {
|
||||
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
|
||||
inputTarget = target,
|
||||
id = tweetyPieResult.tweet.id,
|
||||
mediaCRT = MediaCRT(
|
||||
CommonRecommendationType.TripHqTweet,
|
||||
CommonRecommendationType.TripHqTweet,
|
||||
CommonRecommendationType.TripHqTweet
|
||||
),
|
||||
result = Some(tweetyPieResult),
|
||||
tripTweetDomain = tripDomain
|
||||
)
|
||||
}
|
||||
|
||||
private def getTweetyPieResults(
|
||||
target: Target,
|
||||
tweetToTripDomain: Map[Long, Set[TripDomain]]
|
||||
): Future[Map[Long, Option[TweetyPieResult]]] = {
|
||||
Future.collect((if (target.params(FS.EnableVFInTweetypie)) {
|
||||
tweetyPieStore
|
||||
} else {
|
||||
tweetyPieStoreNoVF
|
||||
}).multiGet(tweetToTripDomain.keySet))
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
for {
|
||||
tweetsToTripDomainMap <- getTripCandidatesStrato(target)
|
||||
tweetyPieResults <- getTweetyPieResults(target, tweetsToTripDomainMap)
|
||||
} yield {
|
||||
val candidates = tweetyPieResults.flatMap {
|
||||
case (tweetId, tweetyPieResultOpt) =>
|
||||
tweetyPieResultOpt.map(buildRawCandidate(target, _, tweetsToTripDomainMap.get(tweetId)))
|
||||
}
|
||||
if (candidates.nonEmpty) {
|
||||
highQualityCandidateFrequencyPredicate(Seq(target))
|
||||
.map(_.head)
|
||||
.map { isTargetFatigueEligible =>
|
||||
if (isTargetFatigueEligible) Some(candidates)
|
||||
else {
|
||||
crtFatigueCounter.incr()
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Some(candidates.toSeq)
|
||||
} else {
|
||||
missingResponseCounter.incr()
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,152 +0,0 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.ListPushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.TargetPredicates
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.geoduck.service.thriftscala.LocationResponse
|
||||
import com.twitter.interests_discovery.thriftscala.DisplayLocation
|
||||
import com.twitter.interests_discovery.thriftscala.NonPersonalizedRecommendedLists
|
||||
import com.twitter.interests_discovery.thriftscala.RecommendedListsRequest
|
||||
import com.twitter.interests_discovery.thriftscala.RecommendedListsResponse
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
case class ListsToRecommendCandidateAdaptor(
|
||||
listRecommendationsStore: ReadableStore[String, NonPersonalizedRecommendedLists],
|
||||
geoDuckV2Store: ReadableStore[Long, LocationResponse],
|
||||
idsStore: ReadableStore[RecommendedListsRequest, RecommendedListsResponse],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override val name: String = this.getClass.getSimpleName
|
||||
|
||||
private[this] val stats = globalStats.scope(name)
|
||||
private[this] val noLocationCodeCounter = stats.counter("no_location_code")
|
||||
private[this] val noCandidatesCounter = stats.counter("no_candidates_for_geo")
|
||||
private[this] val disablePopGeoListsCounter = stats.counter("disable_pop_geo_lists")
|
||||
private[this] val disableIDSListsCounter = stats.counter("disable_ids_lists")
|
||||
|
||||
private def getListCandidate(
|
||||
targetUser: Target,
|
||||
_listId: Long
|
||||
): RawCandidate with ListPushCandidate = {
|
||||
new RawCandidate with ListPushCandidate {
|
||||
override val listId: Long = _listId
|
||||
|
||||
override val commonRecType: CommonRecommendationType = CommonRecommendationType.List
|
||||
|
||||
override val target: Target = targetUser
|
||||
}
|
||||
}
|
||||
|
||||
private def getListsRecommendedFromHistory(
|
||||
target: Target
|
||||
): Future[Seq[Long]] = {
|
||||
target.history.map { history =>
|
||||
history.sortedHistory.flatMap {
|
||||
case (_, notif) if notif.commonRecommendationType == List =>
|
||||
notif.listNotification.map(_.listId)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getIDSListRecs(
|
||||
target: Target,
|
||||
historicalListIds: Seq[Long]
|
||||
): Future[Seq[Long]] = {
|
||||
val request = RecommendedListsRequest(
|
||||
target.targetId,
|
||||
DisplayLocation.ListDiscoveryPage,
|
||||
Some(historicalListIds)
|
||||
)
|
||||
if (target.params(PushFeatureSwitchParams.EnableIDSListRecommendations)) {
|
||||
idsStore.get(request).map {
|
||||
case Some(response) =>
|
||||
response.channels.map(_.id)
|
||||
case _ => Nil
|
||||
}
|
||||
} else {
|
||||
disableIDSListsCounter.incr()
|
||||
Future.Nil
|
||||
}
|
||||
}
|
||||
|
||||
private def getPopGeoLists(
|
||||
target: Target,
|
||||
historicalListIds: Seq[Long]
|
||||
): Future[Seq[Long]] = {
|
||||
if (target.params(PushFeatureSwitchParams.EnablePopGeoListRecommendations)) {
|
||||
geoDuckV2Store.get(target.targetId).flatMap {
|
||||
case Some(locationResponse) if locationResponse.geohash.isDefined =>
|
||||
val geoHashLength =
|
||||
target.params(PushFeatureSwitchParams.ListRecommendationsGeoHashLength)
|
||||
val geoHash = locationResponse.geohash.get.take(geoHashLength)
|
||||
listRecommendationsStore
|
||||
.get(s"geohash_$geoHash")
|
||||
.map {
|
||||
case Some(recommendedLists) =>
|
||||
recommendedLists.recommendedListsByAlgo.flatMap { topLists =>
|
||||
topLists.lists.collect {
|
||||
case list if !historicalListIds.contains(list.listId) => list.listId
|
||||
}
|
||||
}
|
||||
case _ => Nil
|
||||
}
|
||||
case _ =>
|
||||
noLocationCodeCounter.incr()
|
||||
Future.Nil
|
||||
}
|
||||
} else {
|
||||
disablePopGeoListsCounter.incr()
|
||||
Future.Nil
|
||||
}
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
getListsRecommendedFromHistory(target).flatMap { historicalListIds =>
|
||||
Future
|
||||
.join(
|
||||
getPopGeoLists(target, historicalListIds),
|
||||
getIDSListRecs(target, historicalListIds)
|
||||
)
|
||||
.map {
|
||||
case (popGeoListsIds, idsListIds) =>
|
||||
val candidates = (idsListIds ++ popGeoListsIds).map(getListCandidate(target, _))
|
||||
Some(candidates)
|
||||
case _ =>
|
||||
noCandidatesCounter.incr()
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val pushCapFatiguePredicate = TargetPredicates.pushRecTypeFatiguePredicate(
|
||||
CommonRecommendationType.List,
|
||||
PushFeatureSwitchParams.ListRecommendationsPushInterval,
|
||||
PushFeatureSwitchParams.MaxListRecommendationsPushGivenInterval,
|
||||
stats,
|
||||
)
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
|
||||
val isNotFatigued = pushCapFatiguePredicate.apply(Seq(target)).map(_.head)
|
||||
|
||||
Future
|
||||
.join(
|
||||
PushDeviceUtil.isRecommendationsEligible(target),
|
||||
isNotFatigued
|
||||
).map {
|
||||
case (userRecommendationsEligible, isUnderCAP) =>
|
||||
userRecommendationsEligible && isUnderCAP && target.params(
|
||||
PushFeatureSwitchParams.EnableListRecommendations)
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,54 +0,0 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.geoduck.service.thriftscala.LocationResponse
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerRequest
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerResponse
|
||||
import com.twitter.geoduck.common.thriftscala.Location
|
||||
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
|
||||
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
|
||||
|
||||
class LoggedOutPushCandidateSourceGenerator(
|
||||
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
|
||||
geoDuckV2Store: ReadableStore[Long, LocationResponse],
|
||||
safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
|
||||
cachedTweetyPieStoreV2NoVF: ReadableStore[Long, TweetyPieResult],
|
||||
cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
|
||||
contentMixerStore: ReadableStore[ContentMixerRequest, ContentMixerResponse],
|
||||
softUserLocationStore: ReadableStore[Long, Location],
|
||||
topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[(Long, Double)]]],
|
||||
topTweetsByGeoV2VersionedStore: ReadableStore[String, PopTweetsInPlace],
|
||||
)(
|
||||
implicit val globalStats: StatsReceiver) {
|
||||
val sources: Seq[CandidateSource[Target, RawCandidate] with CandidateSourceEligible[
|
||||
Target,
|
||||
RawCandidate
|
||||
]] = {
|
||||
Seq(
|
||||
TripGeoCandidatesAdaptor(
|
||||
tripTweetCandidateStore,
|
||||
contentMixerStore,
|
||||
safeCachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
globalStats
|
||||
),
|
||||
TopTweetsByGeoAdaptor(
|
||||
geoDuckV2Store,
|
||||
softUserLocationStore,
|
||||
topTweetsByGeoStore,
|
||||
topTweetsByGeoV2VersionedStore,
|
||||
cachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
globalStats
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user