the-algorithm/simclusters-ann/server/src/main/scala/com/twitter/simclustersann/filters/SimClustersAnnVariantFilter.scala
twitter-team ef4c5eb65e Twitter Recommendation Algorithm
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
2023-03-31 17:36:31 -05:00

54 lines
1.9 KiB
Scala

package com.twitter.simclustersann.filters
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
import com.twitter.finagle.Service
import com.twitter.finagle.SimpleFilter
import com.twitter.relevance_platform.simclustersann.multicluster.ServiceNameMapper
import com.twitter.scrooge.Request
import com.twitter.scrooge.Response
import com.twitter.simclustersann.exceptions.InvalidRequestForSimClustersAnnVariantException
import com.twitter.simclustersann.thriftscala.SimClustersANNService
import com.twitter.util.Future
import javax.inject.Inject
import javax.inject.Singleton
/**
 * Finagle filter that rejects `GetTweetCandidates` requests not intended for this SimClusters
 * ANN variant.
 *
 * For each request, the source-embedding model version and the candidate embedding type are
 * mapped to an expected service name via [[ServiceNameMapper]]. The request is forwarded
 * downstream only when that expected name is present and equals `serviceIdentifier.service`.
 * Otherwise the request fails with an [[InvalidRequestForSimClustersAnnVariantException]].
 *
 * @param serviceNameMapper maps (model version, candidate embedding type) to the name of the
 *                          service expected to handle that combination
 * @param serviceIdentifier identity of this service; its `service` field is compared against
 *                          the mapped name
 */
@Singleton
class SimClustersAnnVariantFilter @Inject() (
  serviceNameMapper: ServiceNameMapper,
  serviceIdentifier: ServiceIdentifier,
) extends SimpleFilter[Request[SimClustersANNService.GetTweetCandidates.Args], Response[
      SimClustersANNService.GetTweetCandidates.SuccessType
    ]] {

  /**
   * Validates `request` and forwards it to `service` only if it targets this variant.
   *
   * @param request the incoming `GetTweetCandidates` request
   * @param service the downstream service
   * @return the downstream response, or a failed `Future` carrying an
   *         [[InvalidRequestForSimClustersAnnVariantException]] when validation fails. The
   *         failure is returned inside the `Future` rather than thrown, so that callers composing
   *         on the result observe it.
   */
  override def apply(
    request: Request[SimClustersANNService.GetTweetCandidates.Args],
    service: Service[Request[SimClustersANNService.GetTweetCandidates.Args], Response[
      SimClustersANNService.GetTweetCandidates.SuccessType
    ]]
  ): Future[Response[SimClustersANNService.GetTweetCandidates.SuccessType]] =
    validateRequest(request) match {
      case Some(invalid) => Future.exception(invalid)
      case None => service(request)
    }

  /**
   * Checks whether `request` should be served by this service.
   *
   * @param request the incoming `GetTweetCandidates` request
   * @return `None` when the name mapped from the request's model version and candidate embedding
   *         type equals `serviceIdentifier.service`; otherwise `Some` exception describing the
   *         mismatch. The no-mapping case (`None` from the mapper) also produces an exception.
   */
  private def validateRequest(
    request: Request[SimClustersANNService.GetTweetCandidates.Args]
  ): Option[Throwable] = {
    val modelVersion = request.args.query.sourceEmbeddingId.modelVersion
    val embeddingType = request.args.query.config.candidateEmbeddingType
    val actualServiceName = serviceIdentifier.service
    val expectedServiceName = serviceNameMapper.getServiceName(modelVersion, embeddingType)

    // Accept only an exact match; an absent mapping is rejected just like a mismatch.
    if (expectedServiceName.contains(actualServiceName)) None
    else
      Some(
        InvalidRequestForSimClustersAnnVariantException(
          modelVersion,
          embeddingType,
          actualServiceName,
          expectedServiceName))
  }
}