mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-08 22:20:40 +01:00
fix: typos & weird extra lines
This commit is contained in:
parent
78c3235eee
commit
89f541e988
@@ -1,5 +1,4 @@
-"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
-"""
+"""Extension of torchrec.dataset.utils.Batch to cover any dataset."""
 # flake8: noqa
 from __future__ import annotations
 from typing import Dict
@@ -16,7 +16,6 @@ class Snapshot:
   """Checkpoints using torchsnapshot.

   Also saves step to be updated by the training loop.
-
   """

   def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
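A minimal sketch of the torchsnapshot pattern this class wraps, assuming a toy model and an illustrative save path; the app_state keys and directory layout here are not the repo's.

import torch
import torchsnapshot

# Register everything that should be checkpointed as app_state; StateDict makes
# plain values (like the training step) checkpointable alongside the model.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
progress = torchsnapshot.StateDict(step=0)
app_state = {"model": model, "optim": optimizer, "progress": progress}

# Save: one snapshot per step directory (illustrative layout).
snapshot = torchsnapshot.Snapshot.take(path="/tmp/save_dir/0", app_state=app_state)

# Restore: the training loop can then resume from progress["step"].
torchsnapshot.Snapshot(path="/tmp/save_dir/0").restore(app_state=app_state)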
@@ -146,7 +145,6 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
   Args:
     seconds_to_sleep: time between polling calls.
     timeout: how long to wait for a new checkpoint.
-
   """

   def _poll(last_checkpoint: Optional[str] = None):
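A hedged sketch of the polling behavior this docstring describes, reusing the get_checkpoints helper that appears in the next hunk; the return values and stopping behavior are assumptions, not the repo's code.

import time
from typing import Iterator, Optional

def checkpoints_iterator_sketch(
  save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
) -> Iterator[str]:
  """Yield each new checkpoint path under save_dir until polling times out."""

  def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
    waited = 0
    while True:
      checkpoints = get_checkpoints(save_dir)  # helper shown in the following hunk
      latest = checkpoints[-1] if checkpoints else None
      if latest and latest != last_checkpoint:
        return latest
      if waited >= timeout:
        return None
      time.sleep(seconds_to_sleep)
      waited += seconds_to_sleep

  last = None
  while True:
    checkpoint = _poll(last)
    if checkpoint is None:
      return
    yield checkpoint
    last = checkpoint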
@@ -184,7 +182,6 @@ def get_checkpoint(
   Args:
     global_step: Finds this checkpoint if specified.
     missing_ok: if True and checkpoints do not exist, returns without restoration.
-
   """
   checkpoints = get_checkpoints(save_dir)
   if not checkpoints:
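A hedged sketch of the lookup these Args describe, assuming checkpoint directories are named by their global step (an assumption the diff does not confirm):

from typing import Optional

def get_checkpoint_sketch(
  save_dir: str, global_step: Optional[int] = None, missing_ok: bool = True
) -> Optional[str]:
  checkpoints = get_checkpoints(save_dir)  # same helper as in the hunk above
  if not checkpoints:
    if missing_ok:
      return None  # caller starts from scratch instead of restoring
    raise FileNotFoundError(f"No checkpoints found in {save_dir}")
  if global_step is None:
    return checkpoints[-1]  # default to the most recent checkpoint
  for path in checkpoints:
    if path.rstrip("/").split("/")[-1] == str(global_step):
      return path
  raise ValueError(f"No checkpoint for step {global_step} under {save_dir}")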
@@ -18,7 +18,6 @@ def weights_to_log(
     model: model to traverse.
     how_to_log: if a function, then applies this to every parameter, if a dict
       then only applies and logs specified parameters.
-
   """
   if not how_to_log:
     return
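A minimal sketch of the traversal this docstring describes; returning a dict and iterating named_parameters are assumptions about the implementation.

from typing import Callable, Dict, Optional, Union

import torch

def weights_to_log_sketch(
  model: torch.nn.Module,
  how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
) -> Optional[Dict[str, torch.Tensor]]:
  if not how_to_log:
    return None
  to_log = {}
  for name, param in model.named_parameters():
    if callable(how_to_log):
      to_log[name] = how_to_log(param)          # apply to every parameter
    elif name in how_to_log:
      to_log[name] = how_to_log[name](param)    # only the listed parameters
  return to_log

# e.g. log the L2 norm of every weight tensor:
# weights_to_log_sketch(model, how_to_log=lambda p: p.detach().norm())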
@@ -6,7 +6,6 @@ Features (go/100_enablements):
 - On/off-chief evaluation
 - Warmstart/checkpoint management
 - go/dataset-service 0-copy integration
-
 """
 import datetime
 import os
@@ -22,7 +22,6 @@ for its update method.
 2. Redefine an existing metric class.

   SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})
-
 """
 from abc import abstractmethod
 from typing import Callable, Dict, List
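A hedged usage sketch of the second pattern above, using the prepend_transform helper defined later in this module: after the transform is prepended, the metric's update takes the model's output dictionary directly. The output keys here are illustrative.

import torch
import torchmetrics

# Count how many update calls were seen: every output dict contributes value=1.
ExampleCount = prepend_transform(torchmetrics.SumMetric, lambda outputs: {"value": 1})

metric = ExampleCount()
metric.update({"logits": torch.randn(8), "labels": torch.randint(0, 2, (8,))})
metric.update({"logits": torch.randn(8), "labels": torch.randint(0, 2, (8,))})
print(metric.compute())  # tensor(2.)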
@@ -88,7 +87,6 @@ def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):

   Functionally the same using inheritance, just saves some lines of code
   if no need for class attributes.
-
   """

   def transform_method(_self, *args, **kwargs):
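For contrast, a hedged sketch of the inheritance form this docstring says prepend_transform replaces, assuming the mixin base in this module is named MetricMixin and expects a transform(outputs) -> kwargs method:

import torchmetrics

class Count(MetricMixin, torchmetrics.SumMetric):
  def transform(self, outputs):
    # Same munging as the lambda above, expressed as a method so extra class
    # attributes or helper methods could be added if needed.
    return {"value": 1}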
@@ -1,7 +1,6 @@
 """Common metrics that also support multi task.

 We assume multi task models will output [task_idx, ...] predictions
-
 """
 from typing import Any, Dict

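A hedged sketch of the convention stated above: stacked [task_idx, ...] predictions and labels are sliced per task and fed to an ordinary single-task metric. The tensor shapes and the AUROC constructor arguments (recent torchmetrics) are assumptions.

import torch
import torchmetrics

num_tasks, batch_size = 3, 8
probabilities = torch.rand(num_tasks, batch_size)       # [task_idx, ...] predictions
labels = torch.randint(0, 2, (num_tasks, batch_size))   # [task_idx, ...] labels

task_idx = 1
auroc = torchmetrics.AUROC(task="binary")
auroc.update(probabilities[task_idx], labels[task_idx])
print(auroc.compute())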
@@ -1,7 +1,6 @@
 """
 Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
 with TrainPipelineSparseDist.progress modified to support gradient accumulation.
-
 """
 import abc
 from dataclasses import dataclass, field
@@ -529,7 +528,6 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     optimizer update/reset is skipped for `_grad_accum` calls of `progress`
     (congruent to training steps), and then update/reset on every `_grad_accum`th
     step.
-
     """
     should_step_optimizer = (
       self._grad_accum is not None
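A minimal sketch (not the pipeline's actual code) of the accumulation rule this docstring describes: gradients accumulate on every call to progress, and the optimizer is stepped and reset only on every _grad_accum-th call.

from typing import Optional

import torch

def progress_step_sketch(
  loss: torch.Tensor,
  optimizer: torch.optim.Optimizer,
  step: int,
  grad_accum: Optional[int],
) -> None:
  loss.backward()  # gradients accumulate across calls to progress
  should_step_optimizer = grad_accum is None or (step + 1) % grad_accum == 0
  if should_step_optimizer:
    optimizer.step()       # apply the accumulated gradients
    optimizer.zero_grad()  # reset for the next accumulation window

Scaling the loss by 1/grad_accum before backward is a common companion to this rule so the effective step size stays comparable to unaccumulated training.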
@@ -15,7 +15,6 @@ Args:
 Typical usage would be in a bash script, e.g.:

   LIST_LEN=$(python list_ops.py --input_list=$INPUT --op=len)
-
 """
 import tml.machines.environment as env

@@ -24,7 +24,6 @@ def update_mean(
     weight: The weights for the new value.

   Returns: The updated mean and updated weighted sum.
-
   """
   weight = torch.broadcast_to(weight, value.shape)

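A hedged sketch of the weighted incremental update this docstring returns; the parameter names mirror the keyword arguments visible in the next hunk, but the exact formula is an assumption.

import torch

def update_mean_sketch(
  current_mean: torch.Tensor,
  current_weight_sum: torch.Tensor,
  value: torch.Tensor,
  weight: torch.Tensor,
):
  """Return the updated mean and updated weighted sum."""
  weight = torch.broadcast_to(weight, value.shape)
  # Numerically stable form: grow the weight sum first, then pull the mean
  # toward the new values by their share of the total weight.
  current_weight_sum = current_weight_sum + torch.sum(weight)
  current_mean = current_mean + torch.sum((weight / current_weight_sum) * (value - current_mean))
  return current_mean, current_weight_sum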
@@ -42,7 +41,6 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
     state: A tensor with the first dimension indicating workers.

   Returns: The accumulated mean from all workers.
-
   """
   mean, weight_sum = update_mean(
     current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
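A hedged sketch of the reduction this docstring describes, assuming each worker contributes a (mean, weight_sum) pair so state has shape [num_workers, 2]; it reuses the update sketched above.

import torch

def stable_mean_dist_reduce_fn_sketch(state: torch.Tensor) -> torch.Tensor:
  """Merge per-worker [mean, weight_sum] rows into a single global pair."""
  mean, weight_sum = update_mean_sketch(
    current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
    current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
    value=state[:, 0],    # per-worker means
    weight=state[:, 1],   # per-worker weight sums
  )
  return torch.stack([mean, weight_sum])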
@@ -62,7 +62,6 @@ class AUROCWithMWU(torchmetrics.Metric):

   def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
     """
-
     Args:
       label_threshold: Labels strictly above this threshold are considered positive labels,
         otherwise, they are considered negative.
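A hedged sketch of the label handling these Args describe, and of the rank-based (Mann-Whitney U) view of AUROC that the class name refers to; this is illustrative, not the class's implementation.

import torch

def auroc_mwu_sketch(preds: torch.Tensor, labels: torch.Tensor, label_threshold: float = 0.5) -> torch.Tensor:
  positive = labels > label_threshold    # strictly above the threshold -> positive
  pos, neg = preds[positive], preds[~positive]
  # Assumes both classes are present (cf. raise_missing_class above).
  # U statistic: correctly ordered (positive, negative) pairs, ties count half.
  greater = (pos.unsqueeze(1) > neg.unsqueeze(0)).float().sum()
  ties = (pos.unsqueeze(1) == neg.unsqueeze(0)).float().sum()
  return (greater + 0.5 * ties) / (pos.numel() * neg.numel())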
@@ -41,7 +41,6 @@ def _binary_cross_entropy_with_clipping(
     reduction: The reduction method to use.

   Returns: Binary cross entropy on the clipped predictions.
-
   """
   predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
   bce = -target * torch.log(predictions + epsilon)
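A hedged completion of the clipped binary cross entropy whose first lines appear in this hunk; the second log term and the reduction handling are assumptions.

import torch

def binary_cross_entropy_with_clipping_sketch(
  predictions: torch.Tensor,
  target: torch.Tensor,
  epsilon: float = 1e-7,
  reduction: str = "mean",
) -> torch.Tensor:
  # Clamp probabilities away from 0 and 1 so the logs stay finite.
  predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
  bce = -target * torch.log(predictions + epsilon)
  bce = bce - (1.0 - target) * torch.log(1.0 - predictions + epsilon)
  if reduction == "mean":
    return bce.mean()
  if reduction == "sum":
    return bce.sum()
  return bce  # "none": elementwise losses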
@@ -220,14 +219,12 @@ class NRCE(RCE):
   perform if it was well calibrated.

   .. note:: A big gap between NRCE and RCE might indicate a badly calibrated model,
-
   """

   def __init__(
     self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs
   ):
     """
-
     Args:
       from_logits: whether or not predictions are logits or probabilities.
       label_smoothing: label smoothing constant.
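A hedged sketch of the relationship the note above hints at, under the usual reading of these metrics: RCE measures cross entropy relative to a baseline that always predicts the average label rate, and NRCE computes RCE after rescaling predictions so their mean matches the label mean, i.e. as if the model were calibrated in aggregate. The formulas below are assumptions, not this class's code.

import torch

def rce_sketch(predictions: torch.Tensor, labels: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:
  p = torch.clamp(predictions, epsilon, 1.0 - epsilon)
  ce = -(labels * torch.log(p) + (1 - labels) * torch.log(1 - p)).mean()
  base = torch.clamp(labels.mean(), epsilon, 1.0 - epsilon)
  baseline_ce = -(labels * torch.log(base) + (1 - labels) * torch.log(1 - base)).mean()
  return (1.0 - ce / baseline_ce) * 100.0  # higher is better, 0 == baseline

def nrce_sketch(predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
  calibrated = predictions * (labels.mean() / predictions.mean())
  return rce_sketch(calibrated, labels)  # a large RCE-NRCE gap suggests miscalibration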
@@ -5,7 +5,6 @@
 Usage:
   >>> from twitter.ml.logging.absl_logging import logging
   >>> logging.info(f"Properly logged as INFO level in GCP Stackdriver.")
-
 """
 import logging as py_logging
 import sys
@@ -6,7 +6,6 @@
   >>> from ml.logging.torch_logging import logging
   >>> logging.info(f"This only prints on rank 0 if distributed, otherwise prints normally.")
   >>> logging.info(f"This prints on all ranks if distributed, otherwise prints normally.", rank=-1)
-
 """
 import functools
 from typing import Optional
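A hedged sketch of the rank gating those usage lines imply; the real module wraps absl logging, but the core behavior is a rank check before emitting.

import logging

import torch.distributed as dist

def rank_specific_info(message: str, rank: int = 0) -> None:
  """Log only on `rank` (rank=-1 logs on every rank)."""
  current = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
  if rank == -1 or current == rank:
    logging.info("[rank %d] %s", current, message)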
model.py
@@ -33,7 +33,6 @@ class ModelAndLoss(torch.nn.Module):
     prefetching purposes during training using torchrec's pipeline. However
     the underlying model signature needs to be exportable to onnx, requiring
     generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
-
     """
     outputs = self.model(batch)
     losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
@@ -155,7 +155,6 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
   Returns:
     data_files: Sorted list of files to read corresponding to data at the desired datetimes
     num_hours_missing: Number of hours that we are missing data
-
   """
   datetimes = get_datetimes(explicit_datetime_inputs)
   folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
@@ -14,7 +14,6 @@ def keyed_tensor_from_tensors_dict(
     tensor_map:

   Returns:
-
   """
   keys = list(tensor_map.keys())
   # We expect batch size to be first dim. However, if we get a shape [Batch_size],
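A hedged sketch of what a conversion like this does with torchrec's KeyedTensor: dense [batch, dim] tensors are concatenated along the feature dimension and tracked by key, and a bare [batch] tensor gets a trailing dimension first (matching the comment in the hunk). The constructor usage is an assumption about torchrec's API, not this function's code.

from typing import Dict

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

def keyed_tensor_from_tensors_dict_sketch(tensor_map: Dict[str, torch.Tensor]) -> KeyedTensor:
  keys = list(tensor_map.keys())
  # Promote [batch_size] tensors to [batch_size, 1] so every value is 2-D.
  tensors = [t.unsqueeze(-1) if t.dim() == 1 else t for t in tensor_map.values()]
  return KeyedTensor(
    keys=keys,
    length_per_key=[t.shape[1] for t in tensors],
    values=torch.cat(tensors, dim=1),
  )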
@@ -64,7 +63,6 @@ def keyed_jagged_tensor_from_tensors_dict(
     tensor_map:

   Returns:
-
   """

   if not tensor_map:
@@ -26,7 +26,7 @@ def unsanitize(sanitized_task_name):


 def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
-  """ "Builds a model for a single task"""
+  """"Builds a model for a single task"""
   if task.mlp_config:
     return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
   elif task.dcn_config:
@@ -29,7 +29,6 @@ class ModelAndLoss(torch.nn.Module):
     prefetching purposes during training using torchrec's pipeline. However
     the underlying model signature needs to be exportable to onnx, requiring
     generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
-
     """
     outputs = self.model(
       continuous_features=batch.continuous_features,
@@ -27,7 +27,6 @@ class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):

   This adheres to the torch.optim scheduler API and can be plugged anywhere that
   e.g. exponential decay can be used.
-
   """

   def __init__(
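A hedged sketch of the "shim" idea this docstring describes: any step-dependent schedule can be exposed through torch.optim's scheduler API by subclassing _LRScheduler and overriding get_lr. The warmup schedule below is illustrative, not RecapLRShim's.

import torch

class WarmupShim(torch.optim.lr_scheduler._LRScheduler):
  def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int, last_epoch: int = -1):
    self.warmup_steps = warmup_steps
    super().__init__(optimizer, last_epoch)  # calls get_lr once, so set attrs first

  def get_lr(self):
    # Linearly scale each param group's base LR during warmup, then hold.
    scale = min(1.0, (self.last_epoch + 1) / self.warmup_steps)
    return [base_lr * scale for base_lr in self.base_lrs]

# model = torch.nn.Linear(4, 1)
# opt = torch.optim.SGD(model.parameters(), lr=0.1)
# sched = WarmupShim(opt, warmup_steps=100)  # plugs in anywhere a scheduler is expected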
@@ -142,7 +142,6 @@ class TwhinModelAndLoss(torch.nn.Module):
     prefetching purposes during training using torchrec's pipeline. However
     the underlying model signature needs to be exportable to onnx, requiring
     generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
-
     """
     outputs = self.model(batch)
     logits = outputs["logits"]
@@ -2,7 +2,6 @@

 - Override `pa_to_batch` for dataset specific imputation, negative sampling, or coercion to Batch.
 - Readers can be colocated or off trainer machines.
-
 """
 import abc
 import functools
@@ -1,5 +1,4 @@
-"""Dataset service orchestrated by a TFJob
-"""
+"""Dataset service orchestrated by a TFJob"""
 from typing import Optional
 import uuid

@@ -104,7 +103,6 @@ def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
   - all rank processes will consume from the same job/dataset.

   Without this, dataset workers will try to serve 1 job per rank process and OOM.
-
   """
   if not env.has_readers():
     return dataset
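A hedged sketch of the shared-job behavior described above, using tf.data service's distribute transform with a fixed job_name so every rank reads from the same job; the dispatcher address, processing mode, and job name are illustrative.

from typing import Optional

import tensorflow as tf

def maybe_distribute_dataset_sketch(
  dataset: tf.data.Dataset, dispatcher_address: Optional[str]
) -> tf.data.Dataset:
  if not dispatcher_address:  # no reader processes: consume the dataset locally
    return dataset
  return dataset.apply(
    tf.data.experimental.service.distribute(
      processing_mode="parallel_epochs",
      service=f"grpc://{dispatcher_address}",
      job_name="shared_dataset_job",  # shared name -> one job for all ranks
    )
  )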