fix: typos & weird extra lines

Matias Fanger 2023-03-31 20:22:09 -03:00
parent 78c3235eee
commit 89f541e988
22 changed files with 3 additions and 33 deletions

View File

@@ -1,5 +1,4 @@
-"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
-"""
+"""Extension of torchrec.dataset.utils.Batch to cover any dataset."""
# flake8: noqa
from __future__ import annotations
from typing import Dict

View File

@@ -16,7 +16,6 @@ class Snapshot:
"""Checkpoints using torchsnapshot.
Also saves step to be updated by the training loop.
"""
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
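The checkpointing wrapper above is built on torchsnapshot. A minimal sketch of that pattern, assuming a hypothetical TinySnapshot class rather than the repository's actual implementation:

from typing import Any, Dict

import torchsnapshot


class TinySnapshot:
  def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
    self.save_dir = save_dir
    # StateDict makes plain values (like the step counter) snapshot-able.
    self.state = {**state, "extra_state": torchsnapshot.StateDict(step=0)}

  def save(self, global_step: int) -> None:
    self.state["extra_state"]["step"] = global_step
    torchsnapshot.Snapshot.take(path=f"{self.save_dir}/{global_step}", app_state=self.state)

  def restore(self, path: str) -> None:
    torchsnapshot.Snapshot(path=path).restore(app_state=self.state)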
@@ -146,7 +145,6 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
Args:
seconds_to_sleep: time between polling calls.
timeout: how long to wait for a new checkpoint.
"""
def _poll(last_checkpoint: Optional[str] = None):
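checkpoints_iterator polls the save directory and yields new checkpoints as they appear, sleeping seconds_to_sleep between polls and giving up after timeout. A self-contained sketch of that polling loop, with poll_checkpoints as an assumed stand-in rather than the function's real body:

import os
import time
from typing import Iterator, Optional


def poll_checkpoints(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800) -> Iterator[str]:
  last_seen: Optional[str] = None
  waited = 0
  while waited < timeout:
    names = sorted(os.listdir(save_dir)) if os.path.isdir(save_dir) else []
    new = [n for n in names if last_seen is None or n > last_seen]
    if new:
      waited = 0  # reset the timeout whenever progress is made
      for name in new:
        last_seen = name
        yield os.path.join(save_dir, name)
    else:
      time.sleep(seconds_to_sleep)
      waited += seconds_to_sleep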
@@ -184,7 +182,6 @@ def get_checkpoint(
Args:
global_step: Finds this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration.
"""
checkpoints = get_checkpoints(save_dir)
if not checkpoints:

View File

@@ -18,7 +18,6 @@ def weights_to_log(
model: model to traverse.
how_to_log: if a function, then applies this to every parameter, if a dict
then only applies and logs specified parameters.
"""
if not how_to_log:
return

View File

@@ -6,7 +6,6 @@ Features (go/100_enablements):
- On/off-chief evaluation
- Warmstart/checkpoint management
- go/dataset-service 0-copy integration
"""
import datetime
import os

View File

@@ -22,7 +22,6 @@ for its update method.
2. Redefine an existing metric class.
SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})
"""
from abc import abstractmethod
from typing import Callable, Dict, List
@@ -88,7 +87,6 @@ def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
Functionally the same using inheritance, just saves some lines of code
if no need for class attributes.
"""
def transform_method(_self, *args, **kwargs):
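The SumMetric example above maps any update arguments to {'value': 1} before calling the base update. The docstring notes this is functionally the same as subclassing; a small sketch of that equivalent subclass (CountMetric is an assumed name):

import torchmetrics


class CountMetric(torchmetrics.SumMetric):
  def update(self, outputs):
    # Same role as the lambda above: ignore the outputs and add 1 per update call.
    return super().update(value=1)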

View File

@@ -1,7 +1,6 @@
"""Common metrics that also support multi task.
We assume multi task models will output [task_idx, ...] predictions
"""
from typing import Any, Dict

View File

@@ -1,7 +1,6 @@
"""
Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
with TrainPipelineSparseDist.progress modified to support gradient accumulation.
"""
import abc
from dataclasses import dataclass, field
@@ -529,7 +528,6 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.
"""
should_step_optimizer = (
self._grad_accum is not None
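The accumulation rule described above (skip the optimizer update/reset for _grad_accum calls, then update and reset on every _grad_accum-th step) is the same one used in a plain training loop. A runnable toy sketch with an assumed model and data, outside the pipeline:

import torch

grad_accum = 4  # assumed accumulation factor
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(2, 8) for _ in range(8)]

for step, batch in enumerate(batches, start=1):
  loss = model(batch).pow(2).mean() / grad_accum  # scale so accumulated grads match one large batch
  loss.backward()
  if step % grad_accum == 0:  # update/zero the optimizer only on every grad_accum-th call
    optimizer.step()
    optimizer.zero_grad()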

View File

@@ -15,7 +15,6 @@ Args:
Typical usage would be in a bash script, e.g.:
LIST_LEN=$(python list_ops.py --input_list=$INPUT --op=len)
"""
import tml.machines.environment as env

View File

@@ -24,7 +24,6 @@ def update_mean(
weight: The weights for the new value.
Returns: The updated mean and updated weighted sum.
"""
weight = torch.broadcast_to(weight, value.shape)
@@ -42,7 +41,6 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
state: A tensor with the first dimension indicating workers.
Returns: The accumulated mean from all workers.
"""
mean, weight_sum = update_mean(
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
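update_mean follows the usual numerically stable weighted-mean recurrence: the weight sum grows by the new weights, and the mean shifts toward the new values in proportion to their share of the total weight. A sketch under that assumption (update_mean_sketch is illustrative, not the file's code):

import torch


def update_mean_sketch(current_mean, current_weight_sum, value, weight):
  weight = torch.broadcast_to(weight, value.shape)
  new_weight_sum = current_weight_sum + weight.sum()
  # Shift the mean toward the new values in proportion to their share of the weight.
  new_mean = current_mean + (weight * (value - current_mean)).sum() / new_weight_sum
  return new_mean, new_weight_sum


mean, weight_sum = update_mean_sketch(
    torch.tensor(0.0), torch.tensor(0.0), torch.tensor([1.0, 3.0]), torch.tensor(1.0)
)
print(mean, weight_sum)  # tensor(2.) tensor(2.)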

View File

@@ -62,7 +62,6 @@ class AUROCWithMWU(torchmetrics.Metric):
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
"""
Args:
label_threshold: Labels strictly above this threshold are considered positive labels,
otherwise, they are considered negative.
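For context on the class name: AUROC equals the Mann-Whitney U statistic of the positive-class scores divided by n_pos * n_neg, i.e. the probability that a randomly chosen positive outscores a randomly chosen negative. A tiny worked illustration with made-up scores (not the metric's code path):

pos_scores = [0.9, 0.8, 0.4]
neg_scores = [0.7, 0.3, 0.2, 0.1]
# U counts, over all positive/negative pairs, wins as 1 and ties as 0.5.
u = sum(
    1.0 if p > n else 0.5 if p == n else 0.0
    for p in pos_scores
    for n in neg_scores
)
auroc = u / (len(pos_scores) * len(neg_scores))
print(auroc)  # 11 / 12 ≈ 0.9167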

View File

@@ -41,7 +41,6 @@ def _binary_cross_entropy_with_clipping(
reduction: The reduction method to use.
Returns: Binary cross entropy on the clipped predictions.
"""
predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
bce = -target * torch.log(predictions + epsilon)
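The clamp above guards the logs against exact 0 and 1. A self-contained sketch of the full clipped binary cross entropy (bce_with_clipping is an assumed name; the real helper also takes a reduction argument):

import torch


def bce_with_clipping(predictions, target, epsilon=1e-7):
  predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
  bce = -target * torch.log(predictions + epsilon) - (1.0 - target) * torch.log(
      1.0 - predictions + epsilon
  )
  return bce.mean()  # "mean" reduction for simplicity


print(bce_with_clipping(torch.tensor([0.9, 0.1]), torch.tensor([1.0, 0.0])))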
@@ -220,14 +219,12 @@ class NRCE(RCE):
perform if it was well calibrated.
.. note:: A big gap between NRCE and RCE might indicate a badly calibrated model,
"""
def __init__(
self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs
):
"""
Args:
from_logits: whether predictions are logits or probabilities.
label_smoothing: label smoothing constant.
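As commonly defined for this metric family, RCE compares the model's cross entropy with that of a baseline that always predicts the mean label rate, and NRCE first rescales predictions so their mean matches the label mean, i.e. how the model would score if it were perfectly calibrated. A hedged sketch under those assumed definitions (not copied from the file):

import torch


def rce_sketch(predictions, targets, epsilon=1e-7):
  p = torch.clamp(predictions, epsilon, 1.0 - epsilon)
  ce = -(targets * p.log() + (1 - targets) * (1 - p).log()).mean()
  base = torch.clamp(targets.mean(), epsilon, 1.0 - epsilon)
  ce_baseline = -(targets * base.log() + (1 - targets) * (1 - base).log()).mean()
  return 100.0 * (1.0 - ce / ce_baseline)  # > 0 means better than the baseline


def nrce_sketch(predictions, targets, epsilon=1e-7):
  # Rescale predictions so their mean matches the label mean, then score as RCE.
  calibrated = predictions * (targets.mean() / predictions.mean())
  return rce_sketch(calibrated, targets, epsilon)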

View File

@@ -5,7 +5,6 @@
Usage:
>>> from twitter.ml.logging.absl_logging import logging
>>> logging.info(f"Properly logged as INFO level in GCP Stackdriver.")
"""
import logging as py_logging
import sys

View File

@@ -6,7 +6,6 @@
>>> from ml.logging.torch_logging import logging
>>> logging.info(f"This only prints on rank 0 if distributed, otherwise prints normally.")
>>> logging.info(f"This prints on all ranks if distributed, otherwise prints normally.", rank=-1)
"""
import functools
from typing import Optional
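A minimal sketch of the rank-gating idea behind these examples, assuming a hypothetical rank_specific wrapper rather than the module's actual implementation: only emit on the chosen rank when torch.distributed is initialized, and treat rank=-1 as "all ranks".

import functools
import logging
from typing import Optional

import torch.distributed as dist


def rank_specific(log_fn, default_rank: int = 0):
  @functools.wraps(log_fn)
  def wrapper(msg, *args, rank: Optional[int] = default_rank, **kwargs):
    if not dist.is_available() or not dist.is_initialized():
      return log_fn(msg, *args, **kwargs)  # not distributed: log normally
    if rank == -1 or dist.get_rank() == rank:  # rank=-1 means "all ranks"
      return log_fn(msg, *args, **kwargs)
  return wrapper


info = rank_specific(logging.getLogger(__name__).info)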

View File

@@ -33,7 +33,6 @@ class ModelAndLoss(torch.nn.Module):
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(batch)
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
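The pattern recurring in several of these files: model and loss wrapped in a single nn.Module so the torchrec pipeline can drive one forward() call. A sketch with an assumed class name:

import torch


class ModelAndLossSketch(torch.nn.Module):
  def __init__(self, model, loss_fn):
    super().__init__()
    self.model = model
    self.loss_fn = loss_fn

  def forward(self, batch):
    outputs = self.model(batch)
    losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
    return losses, outputs  # the pipeline backpropagates the returned losses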

View File

@@ -155,7 +155,6 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
Returns:
data_files: Sorted list of files to read corresponding to data at the desired datetimes
num_hours_missing: Number of hours that we are missing data
"""
datetimes = get_datetimes(explicit_datetime_inputs)
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]

View File

@@ -14,7 +14,6 @@ def keyed_tensor_from_tensors_dict(
tensor_map:
Returns:
"""
keys = list(tensor_map.keys())
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
@@ -64,7 +63,6 @@ def keyed_jagged_tensor_from_tensors_dict(
tensor_map:
Returns:
"""
if not tensor_map:
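For orientation, the structures these helpers target can also be built directly with torchrec's public constructors; the feature names and values below are assumed examples:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

# Dense case: one row per example, one key per feature.
kt = KeyedTensor.from_tensor_list(
    keys=["user_age", "user_score"],
    tensors=[torch.tensor([[30.0], [40.0]]), torch.tensor([[0.1], [0.7]])],
)

# Jagged case: lengths say how many values each example contributes per key.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["follows"],
    values=torch.tensor([10, 11, 12]),
    lengths=torch.tensor([2, 1]),  # example 0 has two ids, example 1 has one
)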

View File

@@ -26,7 +26,7 @@ def unsanitize(sanitized_task_name):
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
-""" "Builds a model for a single task"""
+""""Builds a model for a single task"""
if task.mlp_config:
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
elif task.dcn_config:

View File

@@ -29,7 +29,6 @@ class ModelAndLoss(torch.nn.Module):
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(
continuous_features=batch.continuous_features,

View File

@@ -27,7 +27,6 @@ class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
def __init__(
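The scheduler-API point above (anything subclassing _LRScheduler and defining get_lr() can be plugged in wherever e.g. exponential decay fits) in a minimal, assumed toy scheduler rather than RecapLRShim itself:

import torch


class LinearWarmup(torch.optim.lr_scheduler._LRScheduler):
  def __init__(self, optimizer, warmup_steps: int, last_epoch: int = -1):
    self.warmup_steps = warmup_steps
    super().__init__(optimizer, last_epoch)

  def get_lr(self):
    scale = min(1.0, (self.last_epoch + 1) / self.warmup_steps)
    return [base_lr * scale for base_lr in self.base_lrs]


model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = LinearWarmup(opt, warmup_steps=100)  # drop-in wherever e.g. ExponentialLR fits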

View File

@@ -142,7 +142,6 @@ class TwhinModelAndLoss(torch.nn.Module):
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(batch)
logits = outputs["logits"]

View File

@@ -2,7 +2,6 @@
- Override `pa_to_batch` for dataset specific imputation, negative sampling, or coercion to Batch.
- Readers can be colocated or off trainer machines.
"""
import abc
import functools

View File

@@ -1,5 +1,4 @@
-"""Dataset service orchestrated by a TFJob
-"""
+"""Dataset service orchestrated by a TFJob"""
from typing import Optional
import uuid
@@ -104,7 +103,6 @@ def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
- all rank processes will consume from the same job/dataset.
Without this, dataset workers will try to serve 1 job per rank process and OOM.
"""
if not env.has_readers():
return dataset
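A hedged sketch of the shared-job idea described above: every rank registers against the same named tf.data service job instead of creating one job per process. The dispatcher address and job name below are placeholders:

import tensorflow as tf


def distribute_shared(dataset: tf.data.Dataset, dispatcher_address: str) -> tf.data.Dataset:
  return dataset.apply(
      tf.data.experimental.service.distribute(
          processing_mode="parallel_epochs",
          service=dispatcher_address,  # e.g. "grpc://dispatcher:5050"
          job_name="shared_training_job",  # same name on every rank -> one shared job
      )
  )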