"""Configuration for the main Recap model.""" import enum from typing import List, Optional, Dict import tml.core.config as base_config from tml.projects.home.recap.embedding import config as embedding_config import pydantic class DropoutConfig(base_config.BaseConfig): """Configuration for the dropout layer.""" rate: pydantic.PositiveFloat = pydantic.Field( 0.1, description="Fraction of inputs to be dropped." ) class LayerNormConfig(base_config.BaseConfig): """Configruation for the layer normalization.""" epsilon: float = pydantic.Field( 1e-3, description="Small float added to variance to avoid dividing by zero." ) axis: int = pydantic.Field(-1, description="Axis or axes to normalize across.") center: bool = pydantic.Field(True, description="Whether to add learnable center.") scale: bool = pydantic.Field(True, description="Whether to add learnable scale.") class BatchNormConfig(base_config.BaseConfig): """Configuration of the batch normalization layer.""" epsilon: pydantic.PositiveFloat = 1e-5 momentum: pydantic.PositiveFloat = 0.9 training_mode_at_inference_time: bool = False use_renorm: bool = False center: bool = pydantic.Field(True, description="Whether to add learnable center.") scale: bool = pydantic.Field(True, description="Whether to add learnable scale.") class DenseLayerConfig(base_config.BaseConfig): """Configuration for the dense layer.""" layer_size: pydantic.PositiveInt dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.") class MlpConfig(base_config.BaseConfig): """Configuration for MLP model.""" layer_sizes: List[pydantic.PositiveInt] = pydantic.Field(None, one_of="mlp_layer_definition") layers: List[DenseLayerConfig] = pydantic.Field(None, one_of="mlp_layer_definition") class BatchNormConfig(base_config.BaseConfig): """Configuration for the batch norm layer.""" affine: bool = pydantic.Field(True, description="Use affine transformation.") momentum: pydantic.PositiveFloat = pydantic.Field( 0.1, description="Forgetting parameter in moving average." ) class DoubleNormLogConfig(base_config.BaseConfig): """Configuration for the double norm log transform.""" batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None) clip_magnitude: float = pydantic.Field( 5.0, description="Threshold to clip the normalized input values." ) layer_norm_config: Optional[LayerNormConfig] = pydantic.Field(None) class Log1pAbsConfig(base_config.BaseConfig): """Simple configuration where only the log transform is performed.""" class ClipLog1pAbsConfig(base_config.BaseConfig): """Configuration for the clip log transform.""" clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field( 3e38, description="Threshold to clip the input values." ) class ZScoreLogConfig(base_config.BaseConfig): """Configuration for the z-score log transform.""" analysis_path: str schema_path: str = pydantic.Field( None, description="Schema path which feaure statistics are generated with. Can be different from scehma in data config.", ) clip_magnitude: float = pydantic.Field( 5.0, description="Threshold to clip the normalized input values." ) use_batch_norm: bool = pydantic.Field( False, description="Option to use batch normalization on the inputs." ) use_renorm: bool = pydantic.Field( False, description="Option to use batch renormalization for trainig and serving consistency." ) use_bq_stats: bool = pydantic.Field( False, description="Option to load the partitioned json files from BQ as statistics." 
class FeaturizationConfig(base_config.BaseConfig):
  """Configuration for featurization."""

  log1p_abs_config: Log1pAbsConfig = pydantic.Field(None, one_of="featurization")
  clip_log1p_abs_config: ClipLog1pAbsConfig = pydantic.Field(None, one_of="featurization")
  z_score_log_config: ZScoreLogConfig = pydantic.Field(None, one_of="featurization")
  double_norm_log_config: DoubleNormLogConfig = pydantic.Field(None, one_of="featurization")
  feature_names_to_concat: List[str] = pydantic.Field(
    ["binary"], description="Feature names to concatenate as raw values with continuous features."
  )


# NOTE: This second MlpConfig shadows the definition above; every later reference
# (DcnConfig, MaskNetConfig, DLRMConfig, TaskModel) resolves to this class.
class MlpConfig(base_config.BaseConfig):
  """Configuration for MLP model."""

  layer_sizes: List[pydantic.PositiveInt]
  batch_norm: BatchNormConfig = pydantic.Field(
    None, description="Optional batch norm configuration."
  )
  dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout configuration.")
  final_layer_activation: bool = pydantic.Field(
    False, description="Whether to include activation on final layer."
  )


class DcnConfig(base_config.BaseConfig):
  """Config for DCN model."""

  poly_degree: pydantic.PositiveInt
  projection_dim: pydantic.PositiveInt = pydantic.Field(
    None, description="Factorizes main DCN matmul with projection."
  )
  parallel_mlp: Optional[MlpConfig] = pydantic.Field(
    None, description="Config for the mlp if used. If None, only the cross layers are used."
  )
  use_parallel: bool = pydantic.Field(True, description="Whether to use parallel DCN.")
  output_mlp: Optional[MlpConfig] = pydantic.Field(None, description="Config for the output mlp.")


class MaskBlockConfig(base_config.BaseConfig):
  """Config for MaskNet block."""

  output_size: int
  reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
    None, one_of="aggregation_size"
  )
  aggregation_size: Optional[pydantic.PositiveInt] = pydantic.Field(
    None, description="Specify the aggregation size directly.", one_of="aggregation_size"
  )
  input_layer_norm: bool


class MaskNetConfig(base_config.BaseConfig):
  """Config for MaskNet model."""

  mask_blocks: List[MaskBlockConfig]
  mlp: Optional[MlpConfig] = pydantic.Field(
    None, description="MLP configuration for the parallel branch."
  )
  use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")


class PositionDebiasConfig(base_config.BaseConfig):
  """Configuration for Position Debias."""

  max_position: int = pydantic.Field(256, description="Bucket all later positions.")
  num_dims: pydantic.PositiveInt = pydantic.Field(
    64, description="Number of dimensions in embedding."
  )
  drop_probability: float = pydantic.Field(0.5, description="Probability of dropping position.")
  # Currently it should be 51 based on the dataset being tested at the time of writing this model.
  # However, no default is provided here to make sure the user of the model is aware of its importance.
  position_feature_index: int = pydantic.Field(
    description="The index of the position feature in the discrete features."
  )
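
# Illustrative sketch (not part of the original module): builds a small MaskNet architecture
# config and a featurization config. The helper names, layer sizes, block sizes, and the
# choice of the double norm log transform are hypothetical, and this assumes BaseConfig's
# `one_of` groups accept exactly one populated field per group.
def _example_mask_net_config() -> MaskNetConfig:
  """Builds a two-block parallel MaskNet with a small readout MLP."""
  return MaskNetConfig(
    mask_blocks=[
      # Each block sets exactly one of `reduction_factor` or `aggregation_size`.
      MaskBlockConfig(output_size=256, reduction_factor=2.0, input_layer_norm=True),
      MaskBlockConfig(output_size=256, aggregation_size=512, input_layer_norm=False),
    ],
    mlp=MlpConfig(layer_sizes=[512, 128], final_layer_activation=True),
    use_parallel=True,
  )


def _example_featurization_config() -> FeaturizationConfig:
  """Selects the double norm log transform as the single featurization option."""
  return FeaturizationConfig(
    double_norm_log_config=DoubleNormLogConfig(clip_magnitude=5.0),
  )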
class AffineMap(base_config.BaseConfig):
  """An affine map that scales the logits into the appropriate range."""

  scale: float = pydantic.Field(1.0)
  bias: float = pydantic.Field(0.0)


class DLRMConfig(base_config.BaseConfig):
  """Config for DLRM model."""

  bottom_mlp: MlpConfig = pydantic.Field(
    ...,
    description="Bottom mlp, the output to be combined with sparse features and fed to interaction.",
  )
  top_mlp: MlpConfig = pydantic.Field(..., description="Top mlp, generates the final output.")


class TaskModel(base_config.BaseConfig):
  """Configuration for a single task."""

  mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
  dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
  dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
  mask_net_config: MaskNetConfig = pydantic.Field(None, one_of="architecture")

  affine_map: AffineMap = pydantic.Field(
    None,
    description="Affine map applied to logits so we can represent a broader range of probabilities.",
  )
  # DANGER DANGER: not implemented yet.
  # loss_weight: float = pydantic.Field(1.0, description="Weight for task in loss.")
  pos_weight: float = pydantic.Field(1.0, description="Weight of positive in loss.")


class MultiTaskType(str, enum.Enum):
  """Type of multi task architecture."""

  SHARE_NONE = "share_none"  # Tasks are separate.
  SHARE_ALL = "share_all"  # Tasks share the same backbone.
  SHARE_PARTIAL = "share_partial"  # Tasks share some backbone, but have their own portions.


class ModelConfig(base_config.BaseConfig):
  """Specify model architecture."""

  tasks: Dict[str, TaskModel] = pydantic.Field(
    description="Specification of architecture per task."
  )

  large_embeddings: embedding_config.LargeEmbeddingsConfig = pydantic.Field(None)
  small_embeddings: embedding_config.SmallEmbeddingsConfig = pydantic.Field(None)
  # Not implemented yet.
  # multi_task_loss_reduction_fn: str = "mean"
  position_debias_config: PositionDebiasConfig = pydantic.Field(
    default=None, description="Position debias model configuration."
  )
  featurization_config: FeaturizationConfig = pydantic.Field(None)
  multi_task_type: MultiTaskType = pydantic.Field(
    MultiTaskType.SHARE_NONE, description="Multi task architecture."
  )
  backbone: TaskModel = pydantic.Field(None, description="Type of architecture for the backbone.")
  stratifiers: List[embedding_config.StratifierConfig] = pydantic.Field(
    default=None, description="Discrete features and values to stratify metrics by."
  )

  @pydantic.root_validator()
  def _validate_mtl(cls, values):
    """Validate the multi task architecture."""
    if values.get("multi_task_type", None) is None:
      return values
    elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:
      if values.get("backbone", None) is None:
        raise ValueError("Require `backbone` for SHARE_ALL and SHARE_PARTIAL.")
    elif values["multi_task_type"] in [MultiTaskType.SHARE_NONE]:
      if values.get("backbone", None) is not None:
        raise ValueError("Cannot have `backbone` if the share type is SHARE_NONE.")
    return values
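
# Illustrative sketch (not part of the original module): assembles a minimal ModelConfig for a
# single hypothetical "engagement" task. The task name, layer sizes, and affine map values are
# invented for demonstration; the embedding and position debias configs are left unset since
# their fields default to None. Assumes the `one_of="architecture"` group takes exactly one option.
def _example_model_config() -> ModelConfig:
  """Builds a SHARE_NONE model with one MLP task, so no backbone is required."""
  return ModelConfig(
    tasks={
      "engagement": TaskModel(
        # Exactly one architecture option is set per task.
        mlp_config=MlpConfig(layer_sizes=[1024, 512, 1]),
        affine_map=AffineMap(scale=1.0, bias=0.0),
        pos_weight=1.0,
      ),
    },
    featurization_config=FeaturizationConfig(log1p_abs_config=Log1pAbsConfig()),
    # SHARE_NONE must not set `backbone`; the root validator enforces this.
    multi_task_type=MultiTaskType.SHARE_NONE,
  )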