mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-06-01 08:48:46 +02:00
35 lines
977 B
Python
35 lines
977 B
Python
|
"""
|
||
|
Candidate architectures for each task.
|
||
|
"""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from typing import Dict
|
||
|
|
||
|
from .features import get_features
|
||
|
from .graph import Graph
|
||
|
from .lib.model import ClemNet
|
||
|
from .params import ModelTypeEnum
|
||
|
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
class MagicRecsClemNet(Graph):
    """Graph with two ClemNet heads: OONC and engagement-given-OONC."""

    def get_logits(self, features: Dict[str, tf.Tensor], training: bool) -> tf.Tensor:
        """Build both task heads on shared inputs and concatenate their outputs.

        Args:
            features: Mapping from feature name to tensor; handed to
                ``get_features`` together with ``self.params.model.features``.
            training: Training-mode flag, forwarded to ``get_features`` and to
                both ClemNet heads.

        Returns:
            The OONC head output followed by the engagement-given-OONC head
            output, joined with ``tf.concat`` along axis 1.
        """
        with tf.name_scope("logits"):
            shared_inputs = get_features(
                features=features,
                training=training,
                params=self.params.model.features,
            )
            arch_params = self.params.model.architecture

            # Each task constructs its own ClemNet instance from the same
            # architecture params; both consume the same shared inputs.
            with tf.name_scope("OONC_logits"):
                oonc_head = ClemNet(params=arch_params)
                oonc_out = oonc_head(inputs=shared_inputs, training=training)

            with tf.name_scope("EngagementGivenOONC_logits"):
                engagement_head = ClemNet(params=arch_params)
                engagement_out = engagement_head(inputs=shared_inputs, training=training)

            # Column order matters downstream: OONC first, engagement second.
            head_outputs = [oonc_out, engagement_out]
            return tf.concat(head_outputs, axis=1)
|
||
|
|
||
|
|
||
|
# Registry mapping each ModelTypeEnum value to the Graph subclass that builds it.
ALL_MODELS = {
    ModelTypeEnum.clemnet: MagicRecsClemNet,
}
|