the-algorithm/cr-mixer/server/src/main/scala/com/twitter/cr_mixer/module/similarity_engine/TweetBasedTwHINSimlarityEngineModule.scala
2023-04-03 20:06:44 -04:00

71 lines
3.2 KiB
Scala

package com.twitter.cr_mixer.module.similarity_engine
import com.google.inject.Provides
import com.twitter.ann.common.thriftscala.AnnQueryService
import com.twitter.cr_mixer.model.ModelConfig
import com.twitter.cr_mixer.module.EmbeddingStoreModule
import com.twitter.cr_mixer.module.thrift_client.AnnQueryServiceClientModule
import com.twitter.cr_mixer.similarity_engine.HnswANNSimilarityEngine
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.inject.TwitterModule
import com.twitter.simclusters_v2.thriftscala.InternalId
import com.twitter.storehaus.ReadableStore
import javax.inject.Named
import com.twitter.ml.api.{thriftscala => api}
import com.twitter.conversions.DurationOps._
import com.twitter.cr_mixer.model.ModuleNames
import com.twitter.cr_mixer.config.TimeoutConfig
import com.twitter.cr_mixer.similarity_engine.HnswANNEngineQuery
import com.twitter.cr_mixer.similarity_engine.SimilarityEngine
import com.twitter.cr_mixer.similarity_engine.SimilarityEngine.GatingConfig
import com.twitter.cr_mixer.similarity_engine.SimilarityEngine.SimilarityEngineConfig
import com.twitter.cr_mixer.thriftscala.SimilarityEngineType
import com.twitter.finagle.memcached.{Client => MemcachedClient}
/**
 * Injection module exposing the tweet-based TwHIN ANN similarity engine.
 *
 * The single provider below assembles an `HnswANNSimilarityEngine` from the embedding stores,
 * ANN query clients, cache client, timeout configuration and stats receiver that are injected
 * into it.
 */
object TweetBasedTwHINSimilarityEngineModule extends TwitterModule {

  /**
   * Builds the `HnswANNSimilarityEngine` bound under
   * `ModuleNames.TweetBasedTwHINANNSimilarityEngine`.
   *
   * Two model ids are registered, each mapped to both an embedding store and an ANN service:
   * `ModelConfig.TweetBasedTwHINRegularUpdateAll20221024` and `ModelConfig.DebuggerDemo`.
   *
   * @param twHINEmbeddingRegularUpdateMhStore embedding store registered for the regular-update
   *                                           TwHIN model id
   * @param debuggerDemoTweetEmbeddingMhStore  embedding store registered for the debugger-demo
   *                                           model id
   * @param twHINRegularUpdateAnnService       ANN query client registered for the regular-update
   *                                           TwHIN model id
   * @param debuggerDemoAnnService             ANN query client registered for the debugger-demo
   *                                           model id
   * @param crMixerUnifiedCacheClient          memcached client handed to the engine's cache config
   * @param timeoutConfig                      supplies `similarityEngineTimeout` for the engine
   * @param statsReceiver                      passed to the engine as its global stats receiver
   * @return the configured similarity engine, identified as
   *         `SimilarityEngineType.TweetBasedTwHINANN`
   */
  @Provides
  @Named(ModuleNames.TweetBasedTwHINANNSimilarityEngine)
  def providesTweetBasedTwHINANNSimilarityEngine(
    // Embedding (MH) stores
    @Named(EmbeddingStoreModule.TwHINEmbeddingRegularUpdateMhStoreName)
    twHINEmbeddingRegularUpdateMhStore: ReadableStore[InternalId, api.Embedding],
    @Named(EmbeddingStoreModule.DebuggerDemoTweetEmbeddingMhStoreName)
    debuggerDemoTweetEmbeddingMhStore: ReadableStore[InternalId, api.Embedding],
    // ANN query service clients
    @Named(AnnQueryServiceClientModule.TwHINRegularUpdateAnnServiceClientName)
    twHINRegularUpdateAnnService: AnnQueryService.MethodPerEndpoint,
    @Named(AnnQueryServiceClientModule.DebuggerDemoAnnServiceClientName)
    debuggerDemoAnnService: AnnQueryService.MethodPerEndpoint,
    // Remaining dependencies
    @Named(ModuleNames.UnifiedCache) crMixerUnifiedCacheClient: MemcachedClient,
    timeoutConfig: TimeoutConfig,
    statsReceiver: StatsReceiver
  ): HnswANNSimilarityEngine = {
    // Model id -> store used to look up embeddings for that model.
    val storesByModelId = Map(
      ModelConfig.TweetBasedTwHINRegularUpdateAll20221024 -> twHINEmbeddingRegularUpdateMhStore,
      ModelConfig.DebuggerDemo -> debuggerDemoTweetEmbeddingMhStore
    )

    // Model id -> ANN service queried for that model. Keys mirror `storesByModelId`.
    val annServicesByModelId = Map(
      ModelConfig.TweetBasedTwHINRegularUpdateAll20221024 -> twHINRegularUpdateAnnService,
      ModelConfig.DebuggerDemo -> debuggerDemoAnnService
    )

    // Neither a decider nor a feature switch is configured for this engine.
    val gating = GatingConfig(deciderConfig = None, enableFeatureSwitch = None)
    val config = SimilarityEngineConfig(
      timeout = timeoutConfig.similarityEngineTimeout,
      gatingConfig = gating
    )

    // Cache keys are the query's `cacheKey` bytes run through the shared key hasher.
    val hashedCacheKey: HnswANNEngineQuery => String =
      annQuery => SimilarityEngine.keyHasher.hashKey(annQuery.cacheKey.getBytes).toString

    val cacheConfig = Some(
      SimilarityEngine.MemCacheConfig[HnswANNEngineQuery](
        cacheClient = crMixerUnifiedCacheClient,
        ttl = 30.minutes,
        keyToString = hashedCacheKey
      )
    )

    new HnswANNSimilarityEngine(
      embeddingStoreLookUpMap = storesByModelId,
      annServiceLookUpMap = annServicesByModelId,
      globalStats = statsReceiver,
      identifier = SimilarityEngineType.TweetBasedTwHINANN,
      engineConfig = config,
      memCacheConfigOpt = cacheConfig
    )
  }
}