the-algorithm/pushservice/src/main/python/models/heavy_ranking/model_pools.py

35 lines
977 B
Python

"""
Candidate architectures for each task's.
"""
from __future__ import annotations
from typing import Dict
from .features import get_features
from .graph import Graph
from .lib.model import ClemNet
from .params import ModelTypeEnum
import tensorflow as tf
class MagicRecsClemNet(Graph):
def get_logits(self, features: Dict[str, tf.Tensor], training: bool) -> tf.Tensor:
with tf.name_scope("logits"):
inputs = get_features(features=features, training=training, params=self.params.model.features)
with tf.name_scope("OONC_logits"):
model = ClemNet(params=self.params.model.architecture)
oonc_logit = model(inputs=inputs, training=training)
with tf.name_scope("EngagementGivenOONC_logits"):
model = ClemNet(params=self.params.model.architecture)
eng_logits = model(inputs=inputs, training=training)
return tf.concat([oonc_logit, eng_logits], axis=1)
ALL_MODELS = {ModelTypeEnum.clemnet: MagicRecsClemNet}