the-algorithm/product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/selector/InsertAppendRatioResults.scala
twitter-team ef4c5eb65e Twitter Recommendation Algorithm
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
2023-03-31 17:36:31 -05:00

172 lines
7.3 KiB
Scala

package com.twitter.product_mixer.component_library.selector
import com.twitter.product_mixer.core.functional_component.common.CandidateScope
import com.twitter.product_mixer.core.functional_component.common.SpecificPipelines
import com.twitter.product_mixer.core.functional_component.selector.Selector
import com.twitter.product_mixer.core.functional_component.selector.SelectorResult
import com.twitter.product_mixer.core.model.common.identifier.CandidatePipelineIdentifier
import com.twitter.product_mixer.core.model.common.presentation.CandidateWithDetails
import com.twitter.product_mixer.core.pipeline.PipelineQuery
import com.twitter.timelines.configapi.Param
import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.Random
/**
* Select candidates and add them according to the ratio assigned for each [[Bucket]]
* For instance, if given `Set((A, 0.8), (B, 0.2))` then candidates will randomly be added to the
* results with an 80% chance of any candidate being from `A` and 20% from`B`. If there are no more
* candidates from a given `CandidatePipeline` then it's simply skipped, so if we run out of `A`
* candidates the rest will be `B`. The end result is all candidates from all [[candidatePipelines]]s
* provided will end up in the result.
*
* For example, an output may look like `Seq(A, A, B, A, A)`, `Seq(A, A, A, A, B)`. If we eventually
* run out of `A` candidates then we would end up with the remaining candidates at the end,
* `Seq(A, A, B, A, A, A, B, A, A, A [run out of A], B, B, B, B, B, B)`
*
* @note the ratios provided are proportional to the sum of all ratios, so if you give 0.3 and 0.7,
* they will be function as to 30% and 70%, and the same for if you provided 3000 and 7000 for
* ratios.
*
* @note Its important to be sure to update all [[Param]]s when changing the ratio for 1 of them
* otherwise you may get unexpected results. For instance, of you have 0.3 and 0.7 which
* correspond to 30% and 70%, and you change `0.7 -> 0.9`, then the total sum of the ratios is
* now 1.2, so you have 25% and 75% when you intended to have 10% and 90%. To prevent this,
* be sure to update all [[Param]]s together, so `0.3 -> 0.1` and `0.7 -> 0.9` so the total
* remains the same.
*/
case class InsertAppendRatioResults[-Query <: PipelineQuery, Bucket](
candidatePipelines: Set[CandidatePipelineIdentifier],
bucketer: Bucketer[Bucket],
ratios: Map[Bucket, Param[Double]],
random: Random = new Random(0))
extends Selector[Query] {
require(ratios.nonEmpty, "bucketRatios must be non-empty")
override val pipelineScope: CandidateScope = SpecificPipelines(candidatePipelines)
private sealed trait PatternResult
private case object NotASelectedCandidatePipeline extends PatternResult
private case object NotABucketInThePattern extends PatternResult
private case class Bucketed(bucket: Bucket) extends PatternResult
override def apply(
query: Query,
remainingCandidates: Seq[CandidateWithDetails],
result: Seq[CandidateWithDetails]
): SelectorResult = {
val groupedCandidates: Map[PatternResult, Seq[CandidateWithDetails]] =
remainingCandidates.groupBy { candidateWithDetails =>
if (pipelineScope.contains(candidateWithDetails)) {
// if a candidate's Bucket doesnt appear in the pattern it's backfilled at the end
val bucket = bucketer(candidateWithDetails)
if (ratios.contains(bucket)) {
Bucketed(bucket)
} else {
NotABucketInThePattern
}
} else {
NotASelectedCandidatePipeline
}
}
val otherCandidates =
groupedCandidates.getOrElse(NotASelectedCandidatePipeline, Seq.empty)
val notABucketInThePattern =
groupedCandidates.getOrElse(NotABucketInThePattern, Seq.empty)
val groupedCandidatesIterators = groupedCandidates.collect {
case (Bucketed(bucket), candidatesWithDetails) => (bucket, candidatesWithDetails.iterator)
}
// using a LinkedHashMap and sorting by descending ratio
// the highest ratios will always be checked first when iterating
// mutable so we can remove finished ratios when they are finished to optimize looping for large numbers of ratios
val currentBucketRatios: mutable.Map[Bucket, Double] = {
val bucketsAndRatiosSortedByRatio =
ratios.iterator
.map {
case (bucket, param) =>
val ratio = query.params(param)
require(
ratio >= 0,
"The ratio for an InsertAppendRatioResults selector can not be negative")
(bucket, ratio)
}.toSeq
.sortBy { case (_, ratio) => ratio }(Ordering.Double.reverse)
mutable.LinkedHashMap(bucketsAndRatiosSortedByRatio: _*)
}
// keep track of the sum of all ratios so we can look only at random values between 0 and that
var ratioSum = currentBucketRatios.valuesIterator.sum
// add candidates to `newResults` until all remaining candidates are for a single bucket
val newResult = new mutable.ArrayBuffer[CandidateWithDetails]()
while (currentBucketRatios.size > 1) {
// random number between 0 and the sum of the ratios of all params
val randomValue = random.nextDouble() * ratioSum
val currentBucketRatiosIterator: Iterator[(Bucket, Double)] =
currentBucketRatios.iterator
val (currentBucket, ratio) = currentBucketRatiosIterator.next()
val componentToTakeFrom = findBucketToTakeFrom(
randomValue = randomValue,
cumulativeSumOfRatios = ratio,
bucket = currentBucket,
bucketRatiosIterator = currentBucketRatiosIterator
)
groupedCandidatesIterators.get(componentToTakeFrom) match {
case Some(iteratorForBucket) if iteratorForBucket.nonEmpty =>
newResult += iteratorForBucket.next()
case _ =>
ratioSum -= currentBucketRatios(componentToTakeFrom)
currentBucketRatios.remove(componentToTakeFrom)
}
}
// with only have 1 source remaining, we can skip all the above work and insert them in bulk
val remainingBucketInRatio =
currentBucketRatios.keysIterator.flatMap(groupedCandidatesIterators.get).flatten
SelectorResult(
remainingCandidates = otherCandidates,
result = result ++ newResult ++ remainingBucketInRatio ++ notABucketInThePattern)
}
/**
* iterates through the `bucketRatiosIterator` until it finds a the
* [[Bucket]] that corresponds with the current `randomValue`.
*
* This method expects that `0 <= randomValue <= sum of all ratios`
*
* @example If the given ratios are `Seq(A -> 0.2, B -> 0.35, C -> 0.45)`
* check if the given `randomValue` is
* - `< 0.45`, if not then check
* - `< 0.8` (0.45 + 0.35), if not then check
* - `< 1.0` (0.45 + 0.35 + 0.2)
*
* and return the corresponding [[Bucket]]
*/
@tailrec private def findBucketToTakeFrom(
randomValue: Double,
cumulativeSumOfRatios: Double,
bucket: Bucket,
bucketRatiosIterator: Iterator[(Bucket, Double)]
): Bucket = {
if (randomValue < cumulativeSumOfRatios || bucketRatiosIterator.isEmpty) {
bucket
} else {
val (nextBucket, ratio) = bucketRatiosIterator.next()
findBucketToTakeFrom(
randomValue,
cumulativeSumOfRatios + ratio,
nextBucket,
bucketRatiosIterator)
}
}
}