This commit is contained in:
Karthik Nair 2023-07-16 13:52:04 +03:00 committed by GitHub
commit 012195cb61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,5 @@
package com.twitter.follow_recommendations.common.rankers.weighted_candidate_source_ranker
import com.twitter.follow_recommendations.common.base.Ranker
import com.twitter.follow_recommendations.common.models.CandidateUser
import com.twitter.follow_recommendations.common.rankers.common.DedupCandidates
@ -19,65 +20,42 @@ import com.twitter.timelines.configapi.HasParams
* @param shuffleFn the shuffle function that will be used to shuffle each algorithm's sorted candidate list.
* @param dedup whether to remove duplicated candidates from the final output.
*/
class WeightedCandidateSourceRanker[Target <: HasParams](
basedRanker: WeightedCandidateSourceBaseRanker[
CandidateSourceIdentifier,
CandidateUser
],
shuffleFn: Seq[CandidateUser] => Seq[CandidateUser],
dedup: Boolean)
extends Ranker[Target, CandidateUser] {
val name: String = this.getClass.getSimpleName
class WeightedCandidateSourceRanker(
basedRanker: WeightedCandidateSourceBaseRanker[CandidateSourceIdentifier, CandidateUser],
shuffleFn: Seq[CandidateUser] => Seq[CandidateUser],
dedup: Boolean
) extends Ranker[Target, CandidateUser] {
override def rank(target: Target, candidates: Seq[CandidateUser]): Stitch[Seq[CandidateUser]] = {
val scribeRankingInfo: Boolean =
target.params(WeightedCandidateSourceRankerParams.ScribeRankingInfoInWeightedRanker)
val rankedCands = rankCandidates(group(candidates))
Stitch.value(if (scribeRankingInfo) Utils.addRankingInfo(rankedCands, name) else rankedCands)
val scribeRankingInfo = target.params(WeightedCandidateSourceRankerParams.ScribeRankingInfoInWeightedRanker)
val rankedCandidates = rankCandidates(group(candidates))
Stitch.value(if (scribeRankingInfo) Utils.addRankingInfo(rankedCandidates, name) else rankedCandidates)
}
private def group(
candidates: Seq[CandidateUser]
): Map[CandidateSourceIdentifier, Seq[CandidateUser]] = {
val flattened = for {
candidate <- candidates
identifier <- candidate.getPrimaryCandidateSource
} yield (identifier, candidate)
flattened.groupBy(_._1).mapValues(_.map(_._2))
private def group(candidates: Seq[CandidateUser]): Map[CandidateSourceIdentifier, Seq[CandidateUser]] = {
candidates.flatMap(_.getPrimaryCandidateSource.map(identifier => (identifier, Seq(_)))).toMap
}
private def rankCandidates(
input: Map[CandidateSourceIdentifier, Seq[CandidateUser]]
): Seq[CandidateUser] = {
// Sort and shuffle candidates per candidate source.
// Note 1: Using map instead mapValue here since mapValue somehow caused infinite loop when used as part of Stream.
val sortAndShuffledCandidates = input.map {
case (source, candidates) =>
// Note 2: toList is required here since candidates is a view, and it will result in infinit loop when used as part of Stream.
// Note 3: there is no real sorting logic here, it assumes the input is already sorted by candidate sources
val sortedCandidates = candidates.toList
source -> shuffleFn(sortedCandidates).iterator
}
private def rankCandidates(input: Map[CandidateSourceIdentifier, Seq[CandidateUser]]): Seq[CandidateUser] = {
val sortAndShuffledCandidates = input.mapValues(shuffleFn compose (_.toList)).toSeq
val rankedCandidates = basedRanker(sortAndShuffledCandidates)
if (dedup) DedupCandidates(rankedCandidates) else rankedCandidates
}
val name: String = getClass.getSimpleName
}
object WeightedCandidateSourceRanker {
def build[Target <: HasParams](
def build(
candidateSourceWeight: Map[CandidateSourceIdentifier, Double],
shuffleFn: Seq[CandidateUser] => Seq[CandidateUser] = identity,
dedup: Boolean = false,
randomSeed: Option[Long] = None
): WeightedCandidateSourceRanker[Target] = {
): WeightedCandidateSourceRanker = {
new WeightedCandidateSourceRanker(
new WeightedCandidateSourceBaseRanker(
candidateSourceWeight,
WeightMethod.WeightedRandomSampling,
randomSeed = randomSeed),
new WeightedCandidateSourceBaseRanker(candidateSourceWeight, WeightMethod.WeightedRandomSampling, randomSeed),
shuffleFn,
dedup
)
@ -85,16 +63,11 @@ object WeightedCandidateSourceRanker {
}
object WeightedCandidateSourceRankerWithoutRandomSampling {
def build[Target <: HasParams](
candidateSourceWeight: Map[CandidateSourceIdentifier, Double]
): WeightedCandidateSourceRanker[Target] = {
def build(candidateSourceWeight: Map[CandidateSourceIdentifier, Double]): WeightedCandidateSourceRanker = {
new WeightedCandidateSourceRanker(
new WeightedCandidateSourceBaseRanker(
candidateSourceWeight,
WeightMethod.WeightedRoundRobin,
randomSeed = None),
new WeightedCandidateSourceBaseRanker(candidateSourceWeight, WeightMethod.WeightedRoundRobin, None),
identity,
false,
false
)
}
}