mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-09 22:39:22 +01:00
Updates
This commit is contained in:
parent
cc73f5fcb7
commit
f7f26d0c20
29
model.py
29
model.py
@ -54,13 +54,21 @@ def maybe_shard_model(
|
|||||||
model,
|
model,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment.
|
"""
|
||||||
|
Set up and apply DistributedModelParallel to a model if running in a distributed environment.
|
||||||
|
|
||||||
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
|
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
|
||||||
DistributedModelParallel.
|
DistributedModelParallel.
|
||||||
|
|
||||||
If not in a distributed environment, returns model directly.
|
If not in a distributed environment, returns the model directly.
|
||||||
"""
|
|
||||||
|
Args:
|
||||||
|
model: The PyTorch model.
|
||||||
|
device: The target device (e.g., 'cuda').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
|
||||||
|
"""
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
logging.info("***** Wrapping in DistributedModelParallel *****")
|
logging.info("***** Wrapping in DistributedModelParallel *****")
|
||||||
logging.info(f"Model before wrapping: {model}")
|
logging.info(f"Model before wrapping: {model}")
|
||||||
@ -74,14 +82,15 @@ def maybe_shard_model(
|
|||||||
|
|
||||||
|
|
||||||
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
|
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
|
||||||
"""Handy function to log the content of EBC embedding layer.
|
|
||||||
Only works for single GPU machines.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
weight_name: name of tensor, as defined in model
|
|
||||||
table_name: name of the EBC table the weight is taken from
|
|
||||||
weight_tensor: embedding weight tensor
|
|
||||||
"""
|
"""
|
||||||
|
Handy function to log the content of an EBC (EmbeddingBagCollection) embedding layer.
|
||||||
|
Only works for single GPU machines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight_name: Name of the tensor, as defined in the model.
|
||||||
|
table_name: Name of the EBC table the weight is taken from.
|
||||||
|
weight_tensor: Embedding weight tensor.
|
||||||
|
"""
|
||||||
logging.info(f"{weight_name}, {table_name}", rank=-1)
|
logging.info(f"{weight_name}, {table_name}", rank=-1)
|
||||||
logging.info(f"{weight_tensor.metadata()}", rank=-1)
|
logging.info(f"{weight_tensor.metadata()}", rank=-1)
|
||||||
output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0"))
|
output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0"))
|
||||||
|
@ -8,239 +8,250 @@ import pydantic
|
|||||||
|
|
||||||
|
|
||||||
class ExplicitDateInputs(base_config.BaseConfig):
|
class ExplicitDateInputs(base_config.BaseConfig):
|
||||||
"""Arguments to select train/validation data using end_date and days of data."""
|
"""Arguments to select train/validation data using end_date and days of data."""
|
||||||
|
|
||||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||||
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
|
end_date: str = pydantic.Field(...,
|
||||||
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
|
description="Data end date, inclusive.")
|
||||||
num_missing_days_tol: int = pydantic.Field(
|
days: int = pydantic.Field(...,
|
||||||
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
description="Number of days of data for dataset.")
|
||||||
)
|
num_missing_days_tol: int = pydantic.Field(
|
||||||
|
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExplicitDatetimeInputs(base_config.BaseConfig):
|
class ExplicitDatetimeInputs(base_config.BaseConfig):
|
||||||
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
||||||
|
|
||||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||||
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
|
end_datetime: str = pydantic.Field(...,
|
||||||
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
|
description="Data end datetime, inclusive.")
|
||||||
num_missing_hours_tol: int = pydantic.Field(
|
hours: int = pydantic.Field(...,
|
||||||
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
description="Number of hours of data for dataset.")
|
||||||
)
|
num_missing_hours_tol: int = pydantic.Field(
|
||||||
|
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DdsCompressionOption(str, Enum):
|
class DdsCompressionOption(str, Enum):
|
||||||
"""The only valid compression option is 'AUTO'"""
|
"""The only valid compression option is 'AUTO'"""
|
||||||
|
|
||||||
AUTO = "AUTO"
|
AUTO = "AUTO"
|
||||||
|
|
||||||
|
|
||||||
class DatasetConfig(base_config.BaseConfig):
|
class DatasetConfig(base_config.BaseConfig):
|
||||||
inputs: str = pydantic.Field(
|
inputs: str = pydantic.Field(
|
||||||
None, description="A glob for selecting data.", one_of="date_inputs_format"
|
None, description="A glob for selecting data.", one_of="date_inputs_format"
|
||||||
)
|
)
|
||||||
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
||||||
None, one_of="date_inputs_format"
|
None, one_of="date_inputs_format"
|
||||||
)
|
)
|
||||||
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
|
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(
|
||||||
|
None, one_of="date_inputs_format")
|
||||||
|
|
||||||
global_batch_size: pydantic.PositiveInt
|
global_batch_size: pydantic.PositiveInt
|
||||||
|
|
||||||
num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
|
num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
|
||||||
None, description="Number of shards to keep."
|
None, description="Number of shards to keep."
|
||||||
)
|
)
|
||||||
repeat_files: bool = pydantic.Field(
|
repeat_files: bool = pydantic.Field(
|
||||||
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
||||||
)
|
)
|
||||||
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
|
file_batch_size: pydantic.PositiveInt = pydantic.Field(
|
||||||
|
16, description="File batch size")
|
||||||
|
|
||||||
cache: bool = pydantic.Field(
|
cache: bool = pydantic.Field(
|
||||||
False,
|
False,
|
||||||
description="Cache dataset in memory. Careful to only use this when you"
|
description="Cache dataset in memory. Careful to only use this when you"
|
||||||
" have enough memory to fit entire dataset.",
|
" have enough memory to fit entire dataset.",
|
||||||
)
|
)
|
||||||
|
|
||||||
data_service_dispatcher: str = pydantic.Field(None)
|
data_service_dispatcher: str = pydantic.Field(None)
|
||||||
ignore_data_errors: bool = pydantic.Field(
|
ignore_data_errors: bool = pydantic.Field(
|
||||||
False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
|
False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
|
||||||
)
|
)
|
||||||
dataset_service_compression: DdsCompressionOption = pydantic.Field(
|
dataset_service_compression: DdsCompressionOption = pydantic.Field(
|
||||||
None,
|
None,
|
||||||
description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'",
|
description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'",
|
||||||
)
|
)
|
||||||
|
|
||||||
# tf.data.Dataset options
|
# tf.data.Dataset options
|
||||||
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
|
examples_shuffle_buffer_size: int = pydantic.Field(
|
||||||
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
1024, description="Size of shuffle buffers.")
|
||||||
None, description="Number of parallel calls."
|
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||||||
)
|
None, description="Number of parallel calls."
|
||||||
interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
)
|
||||||
None, description="Number of shards to interleave."
|
interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||||||
)
|
None, description="Number of shards to interleave."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TruncateAndSlice(base_config.BaseConfig):
|
class TruncateAndSlice(base_config.BaseConfig):
|
||||||
# Apply truncation and then slice.
|
# Apply truncation and then slice.
|
||||||
continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(
|
continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(
|
||||||
None, description="Experimental. Truncates continuous features to this amount for efficiency."
|
None, description="Experimental. Truncates continuous features to this amount for efficiency."
|
||||||
)
|
)
|
||||||
binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(
|
binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(
|
||||||
None, description="Experimental. Truncates binary features to this amount for efficiency."
|
None, description="Experimental. Truncates binary features to this amount for efficiency."
|
||||||
)
|
)
|
||||||
|
|
||||||
continuous_feature_mask_path: str = pydantic.Field(
|
continuous_feature_mask_path: str = pydantic.Field(
|
||||||
None, description="Path of mask used to slice input continuous features."
|
None, description="Path of mask used to slice input continuous features."
|
||||||
)
|
)
|
||||||
binary_feature_mask_path: str = pydantic.Field(
|
binary_feature_mask_path: str = pydantic.Field(
|
||||||
None, description="Path of mask used to slice input binary features."
|
None, description="Path of mask used to slice input binary features."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DataType(str, Enum):
|
class DataType(str, Enum):
|
||||||
BFLOAT16 = "bfloat16"
|
BFLOAT16 = "bfloat16"
|
||||||
BOOL = "bool"
|
BOOL = "bool"
|
||||||
|
|
||||||
FLOAT32 = "float32"
|
FLOAT32 = "float32"
|
||||||
FLOAT16 = "float16"
|
FLOAT16 = "float16"
|
||||||
|
|
||||||
UINT8 = "uint8"
|
UINT8 = "uint8"
|
||||||
|
|
||||||
|
|
||||||
class DownCast(base_config.BaseConfig):
|
class DownCast(base_config.BaseConfig):
|
||||||
# Apply down casting to selected features.
|
# Apply down casting to selected features.
|
||||||
features: typing.Dict[str, DataType] = pydantic.Field(
|
features: typing.Dict[str, DataType] = pydantic.Field(
|
||||||
None, description="Map features to down cast data types."
|
None, description="Map features to down cast data types."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TaskData(base_config.BaseConfig):
|
class TaskData(base_config.BaseConfig):
|
||||||
pos_downsampling_rate: float = pydantic.Field(
|
pos_downsampling_rate: float = pydantic.Field(
|
||||||
1.0,
|
1.0,
|
||||||
description="Downsampling rate of positives used to generate dataset.",
|
description="Downsampling rate of positives used to generate dataset.",
|
||||||
)
|
)
|
||||||
neg_downsampling_rate: float = pydantic.Field(
|
neg_downsampling_rate: float = pydantic.Field(
|
||||||
1.0,
|
1.0,
|
||||||
description="Downsampling rate of negatives used to generate dataset.",
|
description="Downsampling rate of negatives used to generate dataset.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SegDenseSchema(base_config.BaseConfig):
|
class SegDenseSchema(base_config.BaseConfig):
|
||||||
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
|
schema_path: str = pydantic.Field(...,
|
||||||
features: typing.List[str] = pydantic.Field(
|
description="Path to feature config json.")
|
||||||
[],
|
features: typing.List[str] = pydantic.Field(
|
||||||
description="List of features (in addition to the renamed features) to read from schema path above.",
|
[],
|
||||||
)
|
description="List of features (in addition to the renamed features) to read from schema path above.",
|
||||||
renamed_features: typing.Dict[str, str] = pydantic.Field(
|
)
|
||||||
{}, description="Dictionary of renamed features."
|
renamed_features: typing.Dict[str, str] = pydantic.Field(
|
||||||
)
|
{}, description="Dictionary of renamed features."
|
||||||
mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
|
)
|
||||||
{},
|
mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
|
||||||
description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
|
{},
|
||||||
)
|
description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RectifyLabels(base_config.BaseConfig):
|
class RectifyLabels(base_config.BaseConfig):
|
||||||
label_rectification_window_in_hours: float = pydantic.Field(
|
label_rectification_window_in_hours: float = pydantic.Field(
|
||||||
3.0, description="overlap time in hours for which to flip labels"
|
3.0, description="overlap time in hours for which to flip labels"
|
||||||
)
|
)
|
||||||
served_timestamp_field: str = pydantic.Field(
|
served_timestamp_field: str = pydantic.Field(
|
||||||
..., description="input field corresponding to served time"
|
..., description="input field corresponding to served time"
|
||||||
)
|
)
|
||||||
impressed_timestamp_field: str = pydantic.Field(
|
impressed_timestamp_field: str = pydantic.Field(
|
||||||
..., description="input field corresponding to impressed time"
|
..., description="input field corresponding to impressed time"
|
||||||
)
|
)
|
||||||
label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
|
label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
|
||||||
..., description="label to the input field corresponding to engagement time"
|
..., description="label to the input field corresponding to engagement time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExtractFeaturesRow(base_config.BaseConfig):
|
class ExtractFeaturesRow(base_config.BaseConfig):
|
||||||
name: str = pydantic.Field(
|
name: str = pydantic.Field(
|
||||||
...,
|
...,
|
||||||
description="name of the new field name to be created",
|
description="name of the new field name to be created",
|
||||||
)
|
)
|
||||||
source_tensor: str = pydantic.Field(
|
source_tensor: str = pydantic.Field(
|
||||||
...,
|
...,
|
||||||
description="name of the dense tensor to look for the feature",
|
description="name of the dense tensor to look for the feature",
|
||||||
)
|
)
|
||||||
index: int = pydantic.Field(
|
index: int = pydantic.Field(
|
||||||
...,
|
...,
|
||||||
description="index of the feature in the dense tensor",
|
description="index of the feature in the dense tensor",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExtractFeatures(base_config.BaseConfig):
|
class ExtractFeatures(base_config.BaseConfig):
|
||||||
extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
|
extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
|
||||||
[],
|
[],
|
||||||
description="list of features to be extracted with their name, source tensor and index",
|
description="list of features to be extracted with their name, source tensor and index",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DownsampleNegatives(base_config.BaseConfig):
|
class DownsampleNegatives(base_config.BaseConfig):
|
||||||
batch_multiplier: int = pydantic.Field(
|
batch_multiplier: int = pydantic.Field(
|
||||||
None,
|
None,
|
||||||
description="batch multiplier",
|
description="batch multiplier",
|
||||||
)
|
)
|
||||||
engagements_list: typing.List[str] = pydantic.Field(
|
engagements_list: typing.List[str] = pydantic.Field(
|
||||||
[],
|
[],
|
||||||
description="engagements with kept positives",
|
description="engagements with kept positives",
|
||||||
)
|
)
|
||||||
num_engagements: int = pydantic.Field(
|
num_engagements: int = pydantic.Field(
|
||||||
...,
|
...,
|
||||||
description="number engagements used in the model, including ones excluded in engagements_list",
|
description="number engagements used in the model, including ones excluded in engagements_list",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Preprocess(base_config.BaseConfig):
|
class Preprocess(base_config.BaseConfig):
|
||||||
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
|
truncate_and_slice: TruncateAndSlice = pydantic.Field(
|
||||||
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
|
None, description="Truncation and slicing.")
|
||||||
rectify_labels: RectifyLabels = pydantic.Field(
|
downcast: DownCast = pydantic.Field(
|
||||||
None, description="Rectify labels for a given overlap window"
|
None, description="Down cast to features.")
|
||||||
)
|
rectify_labels: RectifyLabels = pydantic.Field(
|
||||||
extract_features: ExtractFeatures = pydantic.Field(
|
None, description="Rectify labels for a given overlap window"
|
||||||
None, description="Extract features from dense tensors."
|
)
|
||||||
)
|
extract_features: ExtractFeatures = pydantic.Field(
|
||||||
downsample_negatives: DownsampleNegatives = pydantic.Field(
|
None, description="Extract features from dense tensors."
|
||||||
None, description="Downsample negatives."
|
)
|
||||||
)
|
downsample_negatives: DownsampleNegatives = pydantic.Field(
|
||||||
|
None, description="Downsample negatives."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Sampler(base_config.BaseConfig):
|
class Sampler(base_config.BaseConfig):
|
||||||
"""Assumes function is defined in data/samplers.py.
|
"""Assumes function is defined in data/samplers.py.
|
||||||
|
|
||||||
Only use this for quick experimentation.
|
Only use this for quick experimentation.
|
||||||
If samplers are useful, we should sample from upstream data generation.
|
If samplers are useful, we should sample from upstream data generation.
|
||||||
|
|
||||||
DEPRECATED, DO NOT USE.
|
DEPRECATED, DO NOT USE.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
kwargs: typing.Dict
|
kwargs: typing.Dict
|
||||||
|
|
||||||
|
|
||||||
class RecapDataConfig(DatasetConfig):
|
class RecapDataConfig(DatasetConfig):
|
||||||
seg_dense_schema: SegDenseSchema
|
seg_dense_schema: SegDenseSchema
|
||||||
|
|
||||||
tasks: typing.Dict[str, TaskData] = pydantic.Field(
|
tasks: typing.Dict[str, TaskData] = pydantic.Field(
|
||||||
description="Description of individual tasks in this dataset."
|
description="Description of individual tasks in this dataset."
|
||||||
)
|
)
|
||||||
evaluation_tasks: typing.List[str] = pydantic.Field(
|
evaluation_tasks: typing.List[str] = pydantic.Field(
|
||||||
[], description="If specified, lists the tasks we're generating metrics for."
|
[], description="If specified, lists the tasks we're generating metrics for."
|
||||||
)
|
)
|
||||||
|
|
||||||
preprocess: Preprocess = pydantic.Field(
|
preprocess: Preprocess = pydantic.Field(
|
||||||
None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
|
None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
|
||||||
)
|
)
|
||||||
|
|
||||||
sampler: Sampler = pydantic.Field(
|
sampler: Sampler = pydantic.Field(
|
||||||
None,
|
None,
|
||||||
description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""",
|
description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@pydantic.root_validator()
|
@pydantic.root_validator()
|
||||||
def _validate_evaluation_tasks(cls, values):
|
def _validate_evaluation_tasks(cls, values):
|
||||||
if values.get("evaluation_tasks") is not None:
|
if values.get("evaluation_tasks") is not None:
|
||||||
for task in values["evaluation_tasks"]:
|
for task in values["evaluation_tasks"]:
|
||||||
if task not in values["tasks"]:
|
if task not in values["tasks"]:
|
||||||
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
raise KeyError(
|
||||||
return values
|
f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
||||||
|
return values
|
||||||
|
@ -9,9 +9,20 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class TruncateAndSlice(tf.keras.Model):
|
class TruncateAndSlice(tf.keras.Model):
|
||||||
"""Class for truncating and slicing."""
|
"""
|
||||||
|
A class for truncating and slicing input features based on the provided configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, truncate_and_slice_config):
|
def __init__(self, truncate_and_slice_config):
|
||||||
|
"""
|
||||||
|
Initializes the TruncateAndSlice model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._truncate_and_slice_config = truncate_and_slice_config
|
self._truncate_and_slice_config = truncate_and_slice_config
|
||||||
|
|
||||||
@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model):
|
|||||||
self._binary_mask = None
|
self._binary_mask = None
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
"""
|
||||||
|
Applies truncation and slicing to the input features based on the configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A dictionary of input features.
|
||||||
|
training: A boolean indicating whether the model is in training mode.
|
||||||
|
mask: A mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of truncated and sliced input features.
|
||||||
|
"""
|
||||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||||
if self._truncate_and_slice_config.continuous_feature_truncation:
|
if self._truncate_and_slice_config.continuous_feature_truncation:
|
||||||
logging.info("Truncating continuous")
|
logging.info("Truncating continuous")
|
||||||
@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class DownCast(tf.keras.Model):
|
class DownCast(tf.keras.Model):
|
||||||
"""Class for Down casting dataset before serialization and transferring to training host.
|
|
||||||
Depends on the data type and the actual data range, the down casting can be lossless or not.
|
|
||||||
It is strongly recommended to compare the metrics before and after down casting.
|
|
||||||
"""
|
"""
|
||||||
|
A class for downcasting dataset before serialization and transferring to the training host.
|
||||||
|
|
||||||
|
Depending on the data type and the actual data range, the downcasting can be lossless or not.
|
||||||
|
It is strongly recommended to compare the metrics before and after downcasting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
downcast_config: A configuration object specifying the features and their target data types.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, downcast_config):
|
def __init__(self, downcast_config):
|
||||||
|
"""
|
||||||
|
Initializes the DownCast model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
downcast_config: A configuration object specifying the features and their target data types.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = downcast_config
|
self.config = downcast_config
|
||||||
self._type_map = {
|
self._type_map = {
|
||||||
@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
"""
|
||||||
|
Applies downcasting to the input features based on the configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A dictionary of input features.
|
||||||
|
training: A boolean indicating whether the model is in training mode.
|
||||||
|
mask: A mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of downcasted input features.
|
||||||
|
"""
|
||||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||||
for feature, type_str in self.config.features.items():
|
for feature, type_str in self.config.features.items():
|
||||||
assert type_str in self._type_map
|
assert type_str in self._type_map
|
||||||
@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class RectifyLabels(tf.keras.Model):
|
class RectifyLabels(tf.keras.Model):
|
||||||
"""Class for rectifying labels"""
|
"""
|
||||||
|
A class for rectifying labels based on specified conditions.
|
||||||
|
|
||||||
|
This class is used to adjust label values in a dataset based on configured conditions involving timestamps.
|
||||||
|
The rectification window is given in hours by label_rectification_window_in_hours and converted to milliseconds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, rectify_label_config):
|
def __init__(self, rectify_label_config):
|
||||||
|
"""
|
||||||
|
Initializes the RectifyLabels model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._config = rectify_label_config
|
self._config = rectify_label_config
|
||||||
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
"""
|
||||||
|
Rectifies label values based on the specified conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A dictionary of input features including timestamp fields and labels.
|
||||||
|
training: A boolean indicating whether the model is in training mode.
|
||||||
|
mask: A mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of input features with rectified label values.
|
||||||
|
"""
|
||||||
served_ts_field = self._config.served_timestamp_field
|
served_ts_field = self._config.served_timestamp_field
|
||||||
impressed_ts_field = self._config.impressed_timestamp_field
|
impressed_ts_field = self._config.impressed_timestamp_field
|
||||||
|
|
||||||
@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class ExtractFeatures(tf.keras.Model):
|
class ExtractFeatures(tf.keras.Model):
|
||||||
"""Class for extracting individual features from dense tensors by their index."""
|
"""
|
||||||
|
A class for extracting individual features from dense tensors by their index.
|
||||||
|
|
||||||
|
Each configured row names a source dense tensor, an index into it, and the name of the new feature to create.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extract_features_config: A configuration object listing the features to extract (name, source tensor, and index).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, extract_features_config):
|
def __init__(self, extract_features_config):
|
||||||
|
"""
|
||||||
|
Initializes the ExtractFeatures model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extract_features_config: A configuration object listing the features to extract (name, source tensor, and index).
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._config = extract_features_config
|
self._config = extract_features_config
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
"""
|
||||||
|
Extracts the configured features from dense tensors by index and stores them in the inputs under their new names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A dictionary of input features including the source dense tensors.
|
||||||
|
training: A boolean indicating whether the model is in training mode.
|
||||||
|
mask: A mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of input features with the extracted features added.
|
||||||
|
"""
|
||||||
|
|
||||||
for row in self._config.extract_feature_table:
|
for row in self._config.extract_feature_table:
|
||||||
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
||||||
@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
||||||
"""Builds a preprocess model to apply all preprocessing stages."""
|
"""
|
||||||
|
Builds a preprocess model to apply all preprocessing stages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preprocess_config: A configuration object specifying the preprocessing parameters.
|
||||||
|
mode: A mode indicating the current job mode (TRAIN or INFERENCE).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A preprocess model that applies all specified preprocessing stages.
|
||||||
|
"""
|
||||||
if mode == config_mod.JobMode.INFERENCE:
|
if mode == config_mod.JobMode.INFERENCE:
|
||||||
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
||||||
return None
|
return None
|
||||||
|
@ -8,122 +8,129 @@ import tensorflow as tf
|
|||||||
|
|
||||||
|
|
||||||
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
|
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
|
||||||
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
|
DTYPE_MAP = {"int64_list": tf.int64,
|
||||||
|
"float_list": tf.float32, "bytes_list": tf.string}
|
||||||
|
|
||||||
|
|
||||||
def create_tf_example_schema(
|
def create_tf_example_schema(
|
||||||
data_config: recap_data_config.SegDenseSchema,
|
data_config: recap_data_config.SegDenseSchema,
|
||||||
segdense_schema,
|
segdense_schema,
|
||||||
):
|
):
|
||||||
"""Generate schema for deseralizing tf.Example.
|
"""Generate schema for deseralizing tf.Example.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length).
|
segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length).
|
||||||
labels: List of strings denoting labels.
|
labels: List of strings denoting labels.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary schema suitable for deserializing tf.Example.
|
A dictionary schema suitable for deserializing tf.Example.
|
||||||
"""
|
"""
|
||||||
segdense_config = data_config.seg_dense_schema
|
segdense_config = data_config.seg_dense_schema
|
||||||
labels = list(data_config.tasks.keys())
|
labels = list(data_config.tasks.keys())
|
||||||
used_features = (
|
used_features = (
|
||||||
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
|
segdense_config.features +
|
||||||
)
|
list(segdense_config.renamed_features.values()) + labels
|
||||||
logging.info(used_features)
|
)
|
||||||
|
logging.info(used_features)
|
||||||
|
|
||||||
tfe_schema = {}
|
tfe_schema = {}
|
||||||
for entry in segdense_schema:
|
for entry in segdense_schema:
|
||||||
feature_name = entry["feature_name"]
|
feature_name = entry["feature_name"]
|
||||||
|
|
||||||
if feature_name in used_features:
|
if feature_name in used_features:
|
||||||
length = entry["length"]
|
length = entry["length"]
|
||||||
dtype = entry["dtype"]
|
dtype = entry["dtype"]
|
||||||
|
|
||||||
if feature_name in labels:
|
if feature_name in labels:
|
||||||
logging.info(f"Label: feature name is {feature_name} type is {dtype}")
|
logging.info(
|
||||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
f"Label: feature name is {feature_name} type is {dtype}")
|
||||||
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
|
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||||
)
|
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
|
||||||
elif length == -1:
|
)
|
||||||
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
|
elif length == -1:
|
||||||
else:
|
tfe_schema[feature_name] = tf.io.VarLenFeature(
|
||||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
DTYPE_MAP[dtype])
|
||||||
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
|
else:
|
||||||
)
|
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||||
for feature_name in used_features:
|
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
|
||||||
if feature_name not in tfe_schema:
|
)
|
||||||
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
for feature_name in used_features:
|
||||||
return tfe_schema
|
if feature_name not in tfe_schema:
|
||||||
|
raise ValueError(
|
||||||
|
f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
||||||
|
return tfe_schema
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(1)
|
@functools.lru_cache(1)
|
||||||
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
|
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
|
||||||
"""For experimentating with emulating bfloat16 or less precise types."""
|
"""For experimentating with emulating bfloat16 or less precise types."""
|
||||||
return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)
|
return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)
|
||||||
|
|
||||||
|
|
||||||
def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
|
def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
|
||||||
"""For experimentating with emulating bfloat16 or less precise types."""
|
"""For experimentating with emulating bfloat16 or less precise types."""
|
||||||
mask: tf.Tensor = make_mantissa_mask(mask_length)
|
mask: tf.Tensor = make_mantissa_mask(mask_length)
|
||||||
return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)
|
return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)
|
||||||
|
|
||||||
|
|
||||||
def parse_tf_example(
|
def parse_tf_example(
|
||||||
serialized_example,
|
serialized_example,
|
||||||
tfe_schema,
|
tfe_schema,
|
||||||
seg_dense_schema_config,
|
seg_dense_schema_config,
|
||||||
):
|
):
|
||||||
"""Parse serialized tf.Example into dict of tensors.
|
"""Parse serialized tf.Example into dict of tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
serialized_example: Serialized tf.Example to be parsed.
|
serialized_example: Serialized tf.Example to be parsed.
|
||||||
tfe_schema: Dictionary schema suitable for deserializing tf.Example.
|
tfe_schema: Dictionary schema suitable for deserializing tf.Example.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of tensors to be used as model input.
|
Dictionary of tensors to be used as model input.
|
||||||
"""
|
"""
|
||||||
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
|
inputs = tf.io.parse_example(
|
||||||
|
serialized=serialized_example, features=tfe_schema)
|
||||||
|
|
||||||
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
|
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
|
||||||
inputs[new_feature_name] = inputs.pop(old_feature_name)
|
inputs[new_feature_name] = inputs.pop(old_feature_name)
|
||||||
|
|
||||||
# This should not actually be used except for experimentation with low precision floats.
|
# This should not actually be used except for experimentation with low precision floats.
|
||||||
if "mask_mantissa_features" in seg_dense_schema_config:
|
if "mask_mantissa_features" in seg_dense_schema_config:
|
||||||
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
|
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
|
||||||
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)
|
inputs[feature_name] = mask_mantissa(
|
||||||
|
inputs[feature_name], mask_length)
|
||||||
|
|
||||||
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
|
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
|
||||||
# at TF level.
|
# at TF level.
|
||||||
# We should not return empty tensors if we dont use embeddings.
|
# We should not return empty tensors if we dont use embeddings.
|
||||||
# Otherwise, it breaks numpy->pt conversion
|
# Otherwise, it breaks numpy->pt conversion
|
||||||
renamed_keys = list(seg_dense_schema_config.renamed_features.keys())
|
renamed_keys = list(seg_dense_schema_config.renamed_features.keys())
|
||||||
for renamed_key in renamed_keys:
|
for renamed_key in renamed_keys:
|
||||||
if "embedding" in renamed_key and (renamed_key not in inputs):
|
if "embedding" in renamed_key and (renamed_key not in inputs):
|
||||||
inputs[renamed_key] = tf.zeros([], tf.float32)
|
inputs[renamed_key] = tf.zeros([], tf.float32)
|
||||||
|
|
||||||
logging.info(f"parsed example and inputs are {inputs}")
|
logging.info(f"parsed example and inputs are {inputs}")
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
|
def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
|
||||||
"""Placeholder for seg dense.
|
"""Placeholder for seg dense.
|
||||||
|
|
||||||
In the future, when we use more seg dense variations, we can change this.
|
In the future, when we use more seg dense variations, we can change this.
|
||||||
"""
|
"""
|
||||||
with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f:
|
with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f:
|
||||||
seg_dense_schema = json.load(f)["schema"]
|
seg_dense_schema = json.load(f)["schema"]
|
||||||
|
|
||||||
tf_example_schema = create_tf_example_schema(
|
tf_example_schema = create_tf_example_schema(
|
||||||
data_config,
|
data_config,
|
||||||
seg_dense_schema,
|
seg_dense_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("***** TF Example Schema *****")
|
logging.info("***** TF Example Schema *****")
|
||||||
logging.info(tf_example_schema)
|
logging.info(tf_example_schema)
|
||||||
|
|
||||||
parse = functools.partial(
|
parse = functools.partial(
|
||||||
parse_tf_example,
|
parse_tf_example,
|
||||||
tfe_schema=tf_example_schema,
|
tfe_schema=tf_example_schema,
|
||||||
seg_dense_schema_config=data_config.seg_dense_schema,
|
seg_dense_schema_config=data_config.seg_dense_schema,
|
||||||
)
|
)
|
||||||
return parse
|
return parse
|
||||||
|
@ -6,115 +6,160 @@ import tensorflow as tf
|
|||||||
|
|
||||||
|
|
||||||
def keyed_tensor_from_tensors_dict(
|
def keyed_tensor_from_tensors_dict(
|
||||||
tensor_map: Mapping[str, torch.Tensor]
|
tensor_map: Mapping[str, torch.Tensor]
|
||||||
) -> "torchrec.KeyedTensor":
|
) -> "torchrec.KeyedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a dictionary of torch tensor to torchrec keyed tensor
|
Convert a dictionary of torch tensors to a torchrec KeyedTensor.
|
||||||
Args:
|
|
||||||
tensor_map:
|
|
||||||
|
|
||||||
Returns:
|
Args:
|
||||||
|
tensor_map: A mapping of tensor names to torch tensors.
|
||||||
|
|
||||||
"""
|
Returns:
|
||||||
keys = list(tensor_map.keys())
|
A torchrec KeyedTensor.
|
||||||
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
|
"""
|
||||||
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
|
keys = list(tensor_map.keys())
|
||||||
# [Batch_size x 1].
|
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
|
||||||
values = [
|
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
|
||||||
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
|
# [Batch_size x 1].
|
||||||
for key in keys
|
values = [
|
||||||
]
|
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(
|
||||||
return torchrec.KeyedTensor.from_tensor_list(keys, values)
|
tensor_map[key], -1)
|
||||||
|
for key in keys
|
||||||
|
]
|
||||||
|
return torchrec.KeyedTensor.from_tensor_list(keys, values)
|
||||||
|
|
||||||
|
|
||||||
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if tensor.is_sparse:
|
"""
|
||||||
x = tensor.coalesce() # Ensure that the indices are ordered.
|
Compute a jagged tensor from a torch tensor.
|
||||||
lengths = torch.bincount(x.indices()[0])
|
|
||||||
values = x.values()
|
Args:
|
||||||
else:
|
tensor: Input torch tensor.
|
||||||
values = tensor
|
|
||||||
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
Returns:
|
||||||
return values, lengths
|
A tuple containing the values and lengths of the jagged tensor.
|
||||||
|
"""
|
||||||
|
if tensor.is_sparse:
|
||||||
|
x = tensor.coalesce() # Ensure that the indices are ordered.
|
||||||
|
lengths = torch.bincount(x.indices()[0])
|
||||||
|
values = x.values()
|
||||||
|
else:
|
||||||
|
values = tensor
|
||||||
|
lengths = torch.ones(
|
||||||
|
tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
||||||
|
return values, lengths
|
||||||
|
|
||||||
|
|
||||||
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
|
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a torch tensor to torchrec jagged tensor.
|
Convert a torch tensor to a torchrec jagged tensor.
|
||||||
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
|
|
||||||
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
|
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors.
|
||||||
dense_shape of the sparse tensor can be arbitrary.
|
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.
|
||||||
Args:
|
|
||||||
tensor: a torch (sparse) tensor.
|
Args:
|
||||||
Returns:
|
tensor: A torch (sparse) tensor.
|
||||||
"""
|
|
||||||
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
|
Returns:
|
||||||
return torchrec.JaggedTensor(values=values, lengths=lengths)
|
A torchrec JaggedTensor.
|
||||||
|
"""
|
||||||
|
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
|
||||||
|
return torchrec.JaggedTensor(values=values, lengths=lengths)
|
||||||
|
|
||||||
|
|
||||||
def keyed_jagged_tensor_from_tensors_dict(
|
def keyed_jagged_tensor_from_tensors_dict(
|
||||||
tensor_map: Mapping[str, torch.Tensor]
|
tensor_map: Mapping[str, torch.Tensor]
|
||||||
) -> "torchrec.KeyedJaggedTensor":
|
) -> "torchrec.KeyedJaggedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
|
Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor.
|
||||||
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
|
|
||||||
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
|
|
||||||
dense_shape of the sparse tensor can be arbitrary.
|
|
||||||
Args:
|
|
||||||
tensor_map:
|
|
||||||
|
|
||||||
Returns:
|
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors.
|
||||||
|
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.
|
||||||
|
|
||||||
"""
|
Args:
|
||||||
|
tensor_map: A mapping of tensor names to torch tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A torchrec KeyedJaggedTensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not tensor_map:
|
||||||
|
return torchrec.KeyedJaggedTensor(
|
||||||
|
keys=[],
|
||||||
|
values=torch.zeros(0, dtype=torch.int),
|
||||||
|
lengths=torch.zeros(0, dtype=torch.int),
|
||||||
|
)
|
||||||
|
values = []
|
||||||
|
lengths = []
|
||||||
|
for tensor in tensor_map.values():
|
||||||
|
tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
|
||||||
|
values.append(torch.squeeze(tensor_val))
|
||||||
|
lengths.append(tensor_len)
|
||||||
|
|
||||||
|
values = torch.cat(values, axis=0)
|
||||||
|
lengths = torch.cat(lengths, axis=0)
|
||||||
|
|
||||||
if not tensor_map:
|
|
||||||
return torchrec.KeyedJaggedTensor(
|
return torchrec.KeyedJaggedTensor(
|
||||||
keys=[],
|
keys=list(tensor_map.keys()),
|
||||||
values=torch.zeros(0, dtype=torch.int),
|
values=values,
|
||||||
lengths=torch.zeros(0, dtype=torch.int),
|
lengths=lengths,
|
||||||
)
|
)
|
||||||
values = []
|
|
||||||
lengths = []
|
|
||||||
for tensor in tensor_map.values():
|
|
||||||
tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
|
|
||||||
values.append(torch.squeeze(tensor_val))
|
|
||||||
lengths.append(tensor_len)
|
|
||||||
|
|
||||||
values = torch.cat(values, axis=0)
|
|
||||||
lengths = torch.cat(lengths, axis=0)
|
|
||||||
|
|
||||||
return torchrec.KeyedJaggedTensor(
|
|
||||||
keys=list(tensor_map.keys()),
|
|
||||||
values=values,
|
|
||||||
lengths=lengths,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
|
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
|
||||||
return tf_tensor._numpy() # noqa
|
"""
|
||||||
|
Convert a TensorFlow tensor to a NumPy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tf_tensor: TensorFlow tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NumPy array.
|
||||||
|
"""
|
||||||
|
return tf_tensor._numpy() # noqa
|
||||||
|
|
||||||
|
|
||||||
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
||||||
tensor = _tf_to_numpy(tensor)
|
"""
|
||||||
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
|
Convert a dense TensorFlow tensor to a PyTorch tensor.
|
||||||
if tensor.dtype.name == "bfloat16":
|
|
||||||
tensor = tensor.astype(np.float32)
|
|
||||||
|
|
||||||
tensor = torch.from_numpy(tensor)
|
Args:
|
||||||
if pin_memory:
|
tensor: Dense TensorFlow tensor.
|
||||||
tensor = tensor.pin_memory()
|
pin_memory: Whether to pin the tensor in memory (for CUDA).
|
||||||
return tensor
|
|
||||||
|
Returns:
|
||||||
|
PyTorch tensor.
|
||||||
|
"""
|
||||||
|
tensor = _tf_to_numpy(tensor)
|
||||||
|
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
|
||||||
|
if tensor.dtype.name == "bfloat16":
|
||||||
|
tensor = tensor.astype(np.float32)
|
||||||
|
|
||||||
|
tensor = torch.from_numpy(tensor)
|
||||||
|
if pin_memory:
|
||||||
|
tensor = tensor.pin_memory()
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def sparse_or_dense_tf_to_torch(
|
def sparse_or_dense_tf_to_torch(
|
||||||
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
|
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if isinstance(tensor, tf.SparseTensor):
|
"""
|
||||||
tensor = torch.sparse_coo_tensor(
|
Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.
|
||||||
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
|
|
||||||
_dense_tf_to_torch(tensor.values, pin_memory),
|
Args:
|
||||||
torch.Size(_tf_to_numpy(tensor.dense_shape)),
|
tensor: TensorFlow tensor (sparse or dense).
|
||||||
)
|
pin_memory: Whether to pin the tensor in memory (for CUDA).
|
||||||
else:
|
|
||||||
tensor = _dense_tf_to_torch(tensor, pin_memory)
|
Returns:
|
||||||
return tensor
|
PyTorch tensor.
|
||||||
|
"""
|
||||||
|
if isinstance(tensor, tf.SparseTensor):
|
||||||
|
tensor = torch.sparse_coo_tensor(
|
||||||
|
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
|
||||||
|
_dense_tf_to_torch(tensor.values, pin_memory),
|
||||||
|
torch.Size(_tf_to_numpy(tensor.dense_shape)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tensor = _dense_tf_to_torch(tensor, pin_memory)
|
||||||
|
return tensor
|
||||||
|
@ -6,234 +6,243 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def _init_weights(module):
|
def _init_weights(module):
|
||||||
"""Initializes weights
|
"""Initializes weights
|
||||||
|
|
||||||
Example
|
Example
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# Define a simple linear layer
|
# Define a simple linear layer
|
||||||
linear_layer = nn.Linear(64, 32)
|
linear_layer = nn.Linear(64, 32)
|
||||||
|
|
||||||
# Initialize the weights and biases using _init_weights
|
# Initialize the weights and biases using _init_weights
|
||||||
_init_weights(linear_layer)
|
_init_weights(linear_layer)
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if isinstance(module, torch.nn.Linear):
|
if isinstance(module, torch.nn.Linear):
|
||||||
torch.nn.init.xavier_uniform_(module.weight)
|
torch.nn.init.xavier_uniform_(module.weight)
|
||||||
torch.nn.init.constant_(module.bias, 0)
|
torch.nn.init.constant_(module.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
class MaskBlock(torch.nn.Module):
|
class MaskBlock(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
MaskBlock module in a mask-based neural network.
|
MaskBlock module in a mask-based neural network.
|
||||||
|
|
||||||
This module represents a MaskBlock, which applies a masking operation to the input data and then
|
This module represents a MaskBlock, which applies a masking operation to the input data and then
|
||||||
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
|
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
||||||
input_dim (int): Dimensionality of the input data.
|
input_dim (int): Dimensionality of the input data.
|
||||||
mask_input_dim (int): Dimensionality of the mask input.
|
mask_input_dim (int): Dimensionality of the mask input.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
To create and use a MaskBlock within a MaskNet, follow these steps:
|
To create and use a MaskBlock within a MaskNet, follow these steps:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Define the configuration for the MaskBlock
|
# Define the configuration for the MaskBlock
|
||||||
mask_block_config = MaskBlockConfig(
|
mask_block_config = MaskBlockConfig(
|
||||||
input_layer_norm=True, # Apply input layer normalization
|
input_layer_norm=True, # Apply input layer normalization
|
||||||
reduction_factor=0.5 # Reduce input dimensionality by 50%
|
reduction_factor=0.5 # Reduce input dimensionality by 50%
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create an instance of the MaskBlock
|
||||||
|
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
|
||||||
|
|
||||||
|
# Generate input tensors
|
||||||
|
input_data = torch.randn(batch_size, 64)
|
||||||
|
mask_input = torch.randn(batch_size, 32)
|
||||||
|
|
||||||
|
# Perform a forward pass through the MaskBlock
|
||||||
|
output = mask_block(input_data, mask_input)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
|
||||||
|
operation that combines the input and mask input. Then, it passes the result through a hidden layer
|
||||||
|
with optional dimensionality reduction.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
|
directly accessed or modified by external code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initializes the MaskBlock module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
||||||
|
input_dim (int): Dimensionality of the input data.
|
||||||
|
mask_input_dim (int): Dimensionality of the mask input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
super(MaskBlock, self).__init__()
|
||||||
|
self.mask_block_config = mask_block_config
|
||||||
|
output_size = mask_block_config.output_size
|
||||||
|
|
||||||
|
if mask_block_config.input_layer_norm:
|
||||||
|
self._input_layer_norm = torch.nn.LayerNorm(input_dim)
|
||||||
|
else:
|
||||||
|
self._input_layer_norm = None
|
||||||
|
|
||||||
|
if mask_block_config.reduction_factor:
|
||||||
|
aggregation_size = int(
|
||||||
|
mask_input_dim * mask_block_config.reduction_factor)
|
||||||
|
elif mask_block_config.aggregation_size is not None:
|
||||||
|
aggregation_size = mask_block_config.aggregation_size
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Need one of reduction factor or aggregation size.")
|
||||||
|
|
||||||
|
self._mask_layer = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(mask_input_dim, aggregation_size),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(aggregation_size, input_dim),
|
||||||
)
|
)
|
||||||
|
self._mask_layer.apply(_init_weights)
|
||||||
|
self._hidden_layer = torch.nn.Linear(input_dim, output_size)
|
||||||
|
self._hidden_layer.apply(_init_weights)
|
||||||
|
self._layer_norm = torch.nn.LayerNorm(output_size)
|
||||||
|
|
||||||
# Create an instance of the MaskBlock
|
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
|
||||||
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
|
|
||||||
|
|
||||||
# Generate input tensors
|
|
||||||
input_data = torch.randn(batch_size, 64)
|
|
||||||
mask_input = torch.randn(batch_size, 32)
|
|
||||||
|
|
||||||
# Perform a forward pass through the MaskBlock
|
|
||||||
output = mask_block(input_data, mask_input)
|
|
||||||
```
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
|
|
||||||
operation that combines the input and mask input. Then, it passes the result through a hidden layer
|
|
||||||
with optional dimensionality reduction.
|
|
||||||
|
|
||||||
Warning:
|
|
||||||
This class is intended for internal use within neural network architectures and should not be
|
|
||||||
directly accessed or modified by external code.
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Initializes the MaskBlock module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
|
||||||
input_dim (int): Dimensionality of the input data.
|
|
||||||
mask_input_dim (int): Dimensionality of the mask input.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
"""
|
||||||
|
Performs a forward pass through the MaskBlock.
|
||||||
|
|
||||||
super(MaskBlock, self).__init__()
|
Args:
|
||||||
self.mask_block_config = mask_block_config
|
net (torch.Tensor): Input data tensor.
|
||||||
output_size = mask_block_config.output_size
|
mask_input (torch.Tensor): Mask input tensor.
|
||||||
|
|
||||||
if mask_block_config.input_layer_norm:
|
Returns:
|
||||||
self._input_layer_norm = torch.nn.LayerNorm(input_dim)
|
torch.Tensor: Output tensor of the MaskBlock.
|
||||||
else:
|
"""
|
||||||
self._input_layer_norm = None
|
if self._input_layer_norm:
|
||||||
|
net = self._input_layer_norm(net)
|
||||||
if mask_block_config.reduction_factor:
|
hidden_layer_output = self._hidden_layer(
|
||||||
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
|
net * self._mask_layer(mask_input))
|
||||||
elif mask_block_config.aggregation_size is not None:
|
return self._layer_norm(hidden_layer_output)
|
||||||
aggregation_size = mask_block_config.aggregation_size
|
|
||||||
else:
|
|
||||||
raise ValueError("Need one of reduction factor or aggregation size.")
|
|
||||||
|
|
||||||
self._mask_layer = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(mask_input_dim, aggregation_size),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
torch.nn.Linear(aggregation_size, input_dim),
|
|
||||||
)
|
|
||||||
self._mask_layer.apply(_init_weights)
|
|
||||||
self._hidden_layer = torch.nn.Linear(input_dim, output_size)
|
|
||||||
self._hidden_layer.apply(_init_weights)
|
|
||||||
self._layer_norm = torch.nn.LayerNorm(output_size)
|
|
||||||
|
|
||||||
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Performs a forward pass through the MaskBlock.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
net (torch.Tensor): Input data tensor.
|
|
||||||
mask_input (torch.Tensor): Mask input tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Output tensor of the MaskBlock.
|
|
||||||
"""
|
|
||||||
if self._input_layer_norm:
|
|
||||||
net = self._input_layer_norm(net)
|
|
||||||
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
|
|
||||||
return self._layer_norm(hidden_layer_output)
|
|
||||||
|
|
||||||
|
|
||||||
class MaskNet(torch.nn.Module):
|
class MaskNet(torch.nn.Module):
|
||||||
"""
|
|
||||||
MaskNet module in a mask-based neural network.
|
|
||||||
|
|
||||||
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
|
|
||||||
create mask-based neural networks with parallel or stacked MaskBlocks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
|
||||||
in_features (int): Dimensionality of the input data.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
To create and use a MaskNet, you can follow these steps:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Define the configuration for the MaskNet
|
|
||||||
mask_net_config = MaskNetConfig(
|
|
||||||
use_parallel=True, # Use parallel MaskBlocks
|
|
||||||
mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create an instance of the MaskNet
|
|
||||||
mask_net = MaskNet(mask_net_config, in_features=64)
|
|
||||||
|
|
||||||
# Generate input tensors
|
|
||||||
input_data = torch.randn(batch_size, 64)
|
|
||||||
|
|
||||||
# Perform a forward pass through the MaskNet
|
|
||||||
outputs = mask_net(input_data)
|
|
||||||
|
|
||||||
# Access the output and shared layer
|
|
||||||
output = outputs["output"]
|
|
||||||
shared_layer = outputs["shared_layer"]
|
|
||||||
```
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
|
|
||||||
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
|
|
||||||
|
|
||||||
Warning:
|
|
||||||
This class is intended for internal use within neural network architectures and should not be
|
|
||||||
directly accessed or modified by external code.
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
MaskNet module in a mask-based neural network.
|
||||||
"""
|
|
||||||
Initializes the MaskNet module.
|
|
||||||
|
|
||||||
Args:
|
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
|
||||||
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
create mask-based neural networks with parallel or stacked MaskBlocks.
|
||||||
in_features (int): Dimensionality of the input data.
|
|
||||||
|
|
||||||
Returns:
|
Args:
|
||||||
None
|
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
||||||
|
in_features (int): Dimensionality of the input data.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To create and use a MaskNet, you can follow these steps:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Define the configuration for the MaskNet
|
||||||
|
mask_net_config = MaskNetConfig(
|
||||||
|
use_parallel=True, # Use parallel MaskBlocks
|
||||||
|
mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create an instance of the MaskNet
|
||||||
|
mask_net = MaskNet(mask_net_config, in_features=64)
|
||||||
|
|
||||||
|
# Generate input tensors
|
||||||
|
input_data = torch.randn(batch_size, 64)
|
||||||
|
|
||||||
|
# Perform a forward pass through the MaskNet
|
||||||
|
outputs = mask_net(input_data)
|
||||||
|
|
||||||
|
# Access the output and shared layer
|
||||||
|
output = outputs["output"]
|
||||||
|
shared_layer = outputs["shared_layer"]
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
|
||||||
|
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
|
directly accessed or modified by external code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
||||||
"""
|
"""
|
||||||
|
Initializes the MaskNet module.
|
||||||
|
|
||||||
super().__init__()
|
Args:
|
||||||
self.mask_net_config = mask_net_config
|
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
||||||
mask_blocks = []
|
in_features (int): Dimensionality of the input data.
|
||||||
|
|
||||||
if mask_net_config.use_parallel:
|
Returns:
|
||||||
total_output_mask_blocks = 0
|
None
|
||||||
for mask_block_config in mask_net_config.mask_blocks:
|
"""
|
||||||
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
|
|
||||||
total_output_mask_blocks += mask_block_config.output_size
|
|
||||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
|
||||||
else:
|
|
||||||
input_size = in_features
|
|
||||||
for mask_block_config in mask_net_config.mask_blocks:
|
|
||||||
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
|
|
||||||
input_size = mask_block_config.output_size
|
|
||||||
|
|
||||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
super().__init__()
|
||||||
total_output_mask_blocks = mask_block_config.output_size
|
self.mask_net_config = mask_net_config
|
||||||
|
mask_blocks = []
|
||||||
|
|
||||||
if mask_net_config.mlp:
|
if mask_net_config.use_parallel:
|
||||||
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
|
total_output_mask_blocks = 0
|
||||||
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
for mask_block_config in mask_net_config.mask_blocks:
|
||||||
else:
|
mask_blocks.append(
|
||||||
self.out_features = total_output_mask_blocks
|
MaskBlock(mask_block_config, in_features, in_features))
|
||||||
self.shared_size = total_output_mask_blocks
|
total_output_mask_blocks += mask_block_config.output_size
|
||||||
|
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||||
|
else:
|
||||||
|
input_size = in_features
|
||||||
|
for mask_block_config in mask_net_config.mask_blocks:
|
||||||
|
mask_blocks.append(
|
||||||
|
MaskBlock(mask_block_config, input_size, in_features))
|
||||||
|
input_size = mask_block_config.output_size
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor):
|
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||||
"""
|
total_output_mask_blocks = mask_block_config.output_size
|
||||||
Performs a forward pass through the MaskNet.
|
|
||||||
|
|
||||||
Args:
|
if mask_net_config.mlp:
|
||||||
inputs (torch.Tensor): Input data tensor.
|
self._dense_layers = mlp.Mlp(
|
||||||
|
total_output_mask_blocks, mask_net_config.mlp)
|
||||||
|
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
||||||
|
else:
|
||||||
|
self.out_features = total_output_mask_blocks
|
||||||
|
self.shared_size = total_output_mask_blocks
|
||||||
|
|
||||||
Returns:
|
def forward(self, inputs: torch.Tensor):
|
||||||
torch.Tensor: Output tensor of the MaskNet.
|
|
||||||
"""
|
"""
|
||||||
if self.mask_net_config.use_parallel:
|
Performs a forward pass through the MaskNet.
|
||||||
mask_outputs = []
|
|
||||||
for mask_layer in self._mask_blocks:
|
Args:
|
||||||
mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
|
inputs (torch.Tensor): Input data tensor.
|
||||||
# Share the outputs of the MaskBlocks.
|
|
||||||
all_mask_outputs = torch.cat(mask_outputs, dim=1)
|
Returns:
|
||||||
output = (
|
torch.Tensor: Output tensor of the MaskNet.
|
||||||
all_mask_outputs
|
"""
|
||||||
if self.mask_net_config.mlp is None
|
if self.mask_net_config.use_parallel:
|
||||||
else self._dense_layers(all_mask_outputs)["output"]
|
mask_outputs = []
|
||||||
)
|
for mask_layer in self._mask_blocks:
|
||||||
return {"output": output, "shared_layer": all_mask_outputs}
|
mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
|
||||||
else:
|
# Share the outputs of the MaskBlocks.
|
||||||
net = inputs
|
all_mask_outputs = torch.cat(mask_outputs, dim=1)
|
||||||
for mask_layer in self._mask_blocks:
|
output = (
|
||||||
net = mask_layer(net=net, mask_input=inputs)
|
all_mask_outputs
|
||||||
# Share the output of the stacked MaskBlocks.
|
if self.mask_net_config.mlp is None
|
||||||
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
|
else self._dense_layers(all_mask_outputs)["output"]
|
||||||
return {"output": output, "shared_layer": net}
|
)
|
||||||
|
return {"output": output, "shared_layer": all_mask_outputs}
|
||||||
|
else:
|
||||||
|
net = inputs
|
||||||
|
for mask_layer in self._mask_blocks:
|
||||||
|
net = mask_layer(net=net, mask_input=inputs)
|
||||||
|
# Share the output of the stacked MaskBlocks.
|
||||||
|
output = net if self.mask_net_config.mlp is None else self._dense_layers[
|
||||||
|
net]["output"]
|
||||||
|
return {"output": output, "shared_layer": net}
|
||||||
|
@ -5,113 +5,117 @@ from absl import logging
|
|||||||
|
|
||||||
|
|
||||||
class ModelAndLoss(torch.nn.Module):
|
class ModelAndLoss(torch.nn.Module):
|
||||||
"""
|
|
||||||
PyTorch module that combines a neural network model and loss function.
|
|
||||||
|
|
||||||
This module wraps a neural network model and facilitates the forward pass through the model
|
|
||||||
while also calculating the loss based on the model's predictions and provided labels.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The torch module to wrap.
|
|
||||||
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
|
|
||||||
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
|
|
||||||
for metrics stratification. Each stratifier config includes the name and index of discrete features
|
|
||||||
to emit for stratification.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
|
|
||||||
and loss function as arguments:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Create a neural network model
|
|
||||||
model = YourNeuralNetworkModel()
|
|
||||||
|
|
||||||
# Define a loss function
|
|
||||||
loss_fn = torch.nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
# Create an instance of ModelAndLoss
|
|
||||||
model_and_loss = ModelAndLoss(model, loss_fn)
|
|
||||||
|
|
||||||
# Generate a batch of training data (e.g., RecapBatch)
|
|
||||||
batch = generate_training_batch()
|
|
||||||
|
|
||||||
# Perform a forward pass through the model and calculate the loss
|
|
||||||
loss, outputs = model_and_loss(batch)
|
|
||||||
|
|
||||||
# You can now backpropagate and optimize using the computed loss
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
```
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The `ModelAndLoss` class simplifies the process of running forward passes through a model and
|
|
||||||
calculating loss, making it easier to integrate the model into your training loop. Additionally,
|
|
||||||
it supports the addition of stratifiers for metrics stratification, if needed.
|
|
||||||
|
|
||||||
Warning:
|
|
||||||
This class is intended for internal use within neural network architectures and should not be
|
|
||||||
directly accessed or modified by external code.
|
|
||||||
"""
|
"""
|
||||||
def __init__(
|
PyTorch module that combines a neural network model and loss function.
|
||||||
self,
|
|
||||||
model,
|
|
||||||
loss_fn: Callable,
|
|
||||||
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Initializes the ModelAndLoss module.
|
|
||||||
|
|
||||||
Args:
|
This module wraps a neural network model and facilitates the forward pass through the model
|
||||||
model: The torch module to wrap.
|
while also calculating the loss based on the model's predictions and provided labels.
|
||||||
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
|
|
||||||
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
|
|
||||||
for metrics stratification.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.model = model
|
|
||||||
self.loss_fn = loss_fn
|
|
||||||
self.stratifiers = stratifiers
|
|
||||||
|
|
||||||
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
|
Args:
|
||||||
"""Runs model forward and calculates loss according to given loss_fn.
|
model: The torch module to wrap.
|
||||||
|
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
|
||||||
|
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
|
||||||
|
for metrics stratification. Each stratifier config includes the name and index of discrete features
|
||||||
|
to emit for stratification.
|
||||||
|
|
||||||
NOTE: The input signature here needs to be a Pipelineable object for
|
Example:
|
||||||
prefetching purposes during training using torchrec's pipeline. However
|
To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
|
||||||
the underlying model signature needs to be exportable to onnx, requiring
|
and loss function as arguments:
|
||||||
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
|
|
||||||
|
|
||||||
"""
|
```python
|
||||||
outputs = self.model(
|
# Create a neural network model
|
||||||
continuous_features=batch.continuous_features,
|
model = YourNeuralNetworkModel()
|
||||||
binary_features=batch.binary_features,
|
|
||||||
discrete_features=batch.discrete_features,
|
|
||||||
sparse_features=batch.sparse_features,
|
|
||||||
user_embedding=batch.user_embedding,
|
|
||||||
user_eng_embedding=batch.user_eng_embedding,
|
|
||||||
author_embedding=batch.author_embedding,
|
|
||||||
labels=batch.labels,
|
|
||||||
weights=batch.weights,
|
|
||||||
)
|
|
||||||
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
|
|
||||||
|
|
||||||
if self.stratifiers:
|
# Define a loss function
|
||||||
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}")
|
loss_fn = torch.nn.CrossEntropyLoss()
|
||||||
outputs["stratifiers"] = {}
|
|
||||||
for stratifier in self.stratifiers:
|
|
||||||
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]
|
|
||||||
|
|
||||||
# In general, we can have a large number of losses returned by our loss function.
|
# Create an instance of ModelAndLoss
|
||||||
if isinstance(losses, dict):
|
model_and_loss = ModelAndLoss(model, loss_fn)
|
||||||
return losses["loss"], {
|
|
||||||
**outputs,
|
# Generate a batch of training data (e.g., RecapBatch)
|
||||||
**losses,
|
batch = generate_training_batch()
|
||||||
"labels": batch.labels,
|
|
||||||
"weights": batch.weights,
|
# Perform a forward pass through the model and calculate the loss
|
||||||
}
|
loss, outputs = model_and_loss(batch)
|
||||||
else: # Assume that this is a float.
|
|
||||||
return losses, {
|
# You can now backpropagate and optimize using the computed loss
|
||||||
**outputs,
|
loss.backward()
|
||||||
"loss": losses,
|
optimizer.step()
|
||||||
"labels": batch.labels,
|
```
|
||||||
"weights": batch.weights,
|
|
||||||
}
|
Note:
|
||||||
|
The `ModelAndLoss` class simplifies the process of running forward passes through a model and
|
||||||
|
calculating loss, making it easier to integrate the model into your training loop. Additionally,
|
||||||
|
it supports the addition of stratifiers for metrics stratification, if needed.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
|
directly accessed or modified by external code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
loss_fn: Callable,
|
||||||
|
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initializes the ModelAndLoss module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The torch module to wrap.
|
||||||
|
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
|
||||||
|
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
|
||||||
|
for metrics stratification.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.loss_fn = loss_fn
|
||||||
|
self.stratifiers = stratifiers
|
||||||
|
|
||||||
|
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
|
||||||
|
"""Runs model forward and calculates loss according to given loss_fn.
|
||||||
|
|
||||||
|
NOTE: The input signature here needs to be a Pipelineable object for
|
||||||
|
prefetching purposes during training using torchrec's pipeline. However
|
||||||
|
the underlying model signature needs to be exportable to onnx, requiring
|
||||||
|
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
|
||||||
|
|
||||||
|
"""
|
||||||
|
outputs = self.model(
|
||||||
|
continuous_features=batch.continuous_features,
|
||||||
|
binary_features=batch.binary_features,
|
||||||
|
discrete_features=batch.discrete_features,
|
||||||
|
sparse_features=batch.sparse_features,
|
||||||
|
user_embedding=batch.user_embedding,
|
||||||
|
user_eng_embedding=batch.user_eng_embedding,
|
||||||
|
author_embedding=batch.author_embedding,
|
||||||
|
labels=batch.labels,
|
||||||
|
weights=batch.weights,
|
||||||
|
)
|
||||||
|
losses = self.loss_fn(
|
||||||
|
outputs["logits"], batch.labels.float(), batch.weights.float())
|
||||||
|
|
||||||
|
if self.stratifiers:
|
||||||
|
logging.info(
|
||||||
|
f"***** Adding stratifiers *****\n {self.stratifiers}")
|
||||||
|
outputs["stratifiers"] = {}
|
||||||
|
for stratifier in self.stratifiers:
|
||||||
|
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:,
|
||||||
|
stratifier.index]
|
||||||
|
|
||||||
|
# In general, we can have a large number of losses returned by our loss function.
|
||||||
|
if isinstance(losses, dict):
|
||||||
|
return losses["loss"], {
|
||||||
|
**outputs,
|
||||||
|
**losses,
|
||||||
|
"labels": batch.labels,
|
||||||
|
"weights": batch.weights,
|
||||||
|
}
|
||||||
|
else: # Assume that this is a float.
|
||||||
|
return losses, {
|
||||||
|
**outputs,
|
||||||
|
"loss": losses,
|
||||||
|
"labels": batch.labels,
|
||||||
|
"weights": batch.weights,
|
||||||
|
}
|
||||||
|
@ -2,64 +2,65 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
class NumericCalibration(torch.nn.Module):
|
class NumericCalibration(torch.nn.Module):
|
||||||
"""
|
|
||||||
Numeric calibration module for adjusting probability scores.
|
|
||||||
|
|
||||||
This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
|
|
||||||
may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
|
|
||||||
for tasks such as binary classification.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pos_downsampling_rate (float): The downsampling rate for positive samples.
|
|
||||||
neg_downsampling_rate (float): The downsampling rate for negative samples.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
|
|
||||||
scores like this:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Create a NumericCalibration instance with downsampling rates
|
|
||||||
calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
|
|
||||||
|
|
||||||
# Generate probability scores (e.g., from a neural network)
|
|
||||||
raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
|
|
||||||
|
|
||||||
# Apply numeric calibration to adjust the probabilities
|
|
||||||
calibrated_probs = calibration(raw_probs)
|
|
||||||
|
|
||||||
# The `calibrated_probs` now contains the adjusted probability scores
|
|
||||||
```
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The `NumericCalibration` module is used to adjust probability scores to account for differences in
|
|
||||||
the number of positive and negative samples in a dataset. It can help improve the calibration of
|
|
||||||
probability estimates in imbalanced classification problems.
|
|
||||||
|
|
||||||
Warning:
|
|
||||||
This class is intended for internal use within neural network architectures and should not be
|
|
||||||
directly accessed or modified by external code.
|
|
||||||
"""
|
"""
|
||||||
def __init__(
|
Numeric calibration module for adjusting probability scores.
|
||||||
self,
|
|
||||||
pos_downsampling_rate: float,
|
|
||||||
neg_downsampling_rate: float,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Apply numeric calibration to probability scores.
|
|
||||||
|
|
||||||
Args:
|
This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
|
||||||
probs (torch.Tensor): Probability scores to be calibrated.
|
may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
|
||||||
|
for tasks such as binary classification.
|
||||||
|
|
||||||
Returns:
|
Args:
|
||||||
torch.Tensor: Calibrated probability scores.
|
pos_downsampling_rate (float): The downsampling rate for positive samples.
|
||||||
|
neg_downsampling_rate (float): The downsampling rate for negative samples.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
|
||||||
|
scores like this:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create a NumericCalibration instance with downsampling rates
|
||||||
|
calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
|
||||||
|
|
||||||
|
# Generate probability scores (e.g., from a neural network)
|
||||||
|
raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
|
||||||
|
|
||||||
|
# Apply numeric calibration to adjust the probabilities
|
||||||
|
calibrated_probs = calibration(raw_probs)
|
||||||
|
|
||||||
|
# The `calibrated_probs` now contains the adjusted probability scores
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The `NumericCalibration` module is used to adjust probability scores to account for differences in
|
||||||
|
the number of positive and negative samples in a dataset. It can help improve the calibration of
|
||||||
|
probability estimates in imbalanced classification problems.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
|
directly accessed or modified by external code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pos_downsampling_rate: float,
|
||||||
|
neg_downsampling_rate: float,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
Apply numeric calibration to probability scores.
|
||||||
|
|
||||||
# Using buffer to make sure they are on correct device (and not moved every time).
|
Args:
|
||||||
# Will also be part of state_dict.
|
probs (torch.Tensor): Probability scores to be calibrated.
|
||||||
self.register_buffer(
|
|
||||||
"ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, probs: torch.Tensor):
|
Returns:
|
||||||
return probs * self.ratio / (1.0 - probs + (self.ratio * probs))
|
torch.Tensor: Calibrated probability scores.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Using buffer to make sure they are on correct device (and not moved every time).
|
||||||
|
# Will also be part of state_dict.
|
||||||
|
self.register_buffer(
|
||||||
|
"ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, probs: torch.Tensor):
|
||||||
|
return probs * self.ratio / (1.0 - probs + (self.ratio * probs))
|
||||||
|
51
tools/pq.py
51
tools/pq.py
@ -38,6 +38,15 @@ import pyarrow.parquet as pq
|
|||||||
|
|
||||||
|
|
||||||
def _create_dataset(path: str):
|
def _create_dataset(path: str):
|
||||||
|
"""
|
||||||
|
Create a PyArrow dataset from Parquet files located at the specified path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The path to the Parquet files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pyarrow.dataset.Dataset: The PyArrow dataset.
|
||||||
|
"""
|
||||||
fs = infer_fs(path)
|
fs = infer_fs(path)
|
||||||
files = fs.glob(path)
|
files = fs.glob(path)
|
||||||
return pads.dataset(files, format="parquet", filesystem=fs)
|
return pads.dataset(files, format="parquet", filesystem=fs)
|
||||||
@ -47,12 +56,27 @@ class PqReader:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
|
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize a Parquet Reader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The path to the Parquet files.
|
||||||
|
num (int): The maximum number of rows to read.
|
||||||
|
batch_size (int): The batch size for reading data.
|
||||||
|
columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
|
||||||
|
"""
|
||||||
self._ds = _create_dataset(path)
|
self._ds = _create_dataset(path)
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._num = num
|
self._num = num
|
||||||
self._columns = columns
|
self._columns = columns
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
"""
|
||||||
|
Iterate through the Parquet data and yield batches of rows.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
pyarrow.RecordBatch: A batch of rows.
|
||||||
|
"""
|
||||||
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
|
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
|
||||||
rows_seen = 0
|
rows_seen = 0
|
||||||
for count, record in enumerate(batches):
|
for count, record in enumerate(batches):
|
||||||
@ -62,6 +86,12 @@ class PqReader:
|
|||||||
rows_seen += record.data.num_rows
|
rows_seen += record.data.num_rows
|
||||||
|
|
||||||
def _head(self):
|
def _head(self):
|
||||||
|
"""
|
||||||
|
Get the first `num` rows of the Parquet data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pyarrow.RecordBatch: A batch of rows.
|
||||||
|
"""
|
||||||
total_read = self._num * self.bytes_per_row
|
total_read = self._num * self.bytes_per_row
|
||||||
if total_read >= int(500e6):
|
if total_read >= int(500e6):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@ -71,6 +101,12 @@ class PqReader:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def bytes_per_row(self) -> int:
|
def bytes_per_row(self) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the estimated bytes per row in the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The estimated bytes per row.
|
||||||
|
"""
|
||||||
nbits = 0
|
nbits = 0
|
||||||
for t in self._ds.schema.types:
|
for t in self._ds.schema.types:
|
||||||
try:
|
try:
|
||||||
@ -81,18 +117,23 @@ class PqReader:
|
|||||||
return nbits // 8
|
return nbits // 8
|
||||||
|
|
||||||
def schema(self):
|
def schema(self):
|
||||||
|
"""
|
||||||
|
Display the schema of the Parquet dataset.
|
||||||
|
"""
|
||||||
print(f"\n# Schema\n{self._ds.schema}")
|
print(f"\n# Schema\n{self._ds.schema}")
|
||||||
|
|
||||||
def head(self):
|
def head(self):
|
||||||
"""Displays first --num rows."""
|
"""
|
||||||
|
Display the first `num` rows of the Parquet data as a pandas DataFrame.
|
||||||
|
"""
|
||||||
print(self._head().to_pandas())
|
print(self._head().to_pandas())
|
||||||
|
|
||||||
def distinct(self):
|
def distinct(self):
|
||||||
"""Displays unique values seen in specified columns in the first `--num` rows.
|
|
||||||
|
|
||||||
Useful for getting an approximate vocabulary for certain columns.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Display unique values seen in specified columns in the first `num` rows.
|
||||||
|
|
||||||
|
Useful for getting an approximate vocabulary for certain columns.
|
||||||
|
"""
|
||||||
for col_name, column in zip(self._head().column_names, self._head().columns):
|
for col_name, column in zip(self._head().column_names, self._head().columns):
|
||||||
print(col_name)
|
print(col_name)
|
||||||
print("unique:", column.unique().to_pylist())
|
print("unique:", column.unique().to_pylist())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user