fix: typos & weird extra lines

Matias Fanger 2023-03-31 20:22:09 -03:00
parent 78c3235eee
commit 89f541e988
22 changed files with 3 additions and 33 deletions

@@ -1,5 +1,4 @@
-"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
-"""
+"""Extension of torchrec.dataset.utils.Batch to cover any dataset."""
 # flake8: noqa
 from __future__ import annotations
 from typing import Dict

@@ -16,7 +16,6 @@ class Snapshot:
   """Checkpoints using torchsnapshot.

   Also saves step to be updated by the training loop.
-
   """

   def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
@@ -146,7 +145,6 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
   Args:
     seconds_to_sleep: time between polling calls.
     timeout: how long to wait for a new checkpoint.
-
   """

   def _poll(last_checkpoint: Optional[str] = None):
@@ -184,7 +182,6 @@ def get_checkpoint(
   Args:
     global_step: Finds this checkpoint if specified.
     missing_ok: if True and checkpoints do not exist, returns without restoration.
-
   """
   checkpoints = get_checkpoints(save_dir)
   if not checkpoints:
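
A rough sketch of the poll-and-sleep contract that checkpoints_iterator documents above (illustrative only: the directory layout and the helper name are assumptions, not the repo's code):

import os
import time
from typing import Optional

def wait_for_new_checkpoint(
  save_dir: str,
  last_checkpoint: Optional[str] = None,
  seconds_to_sleep: int = 30,
  timeout: int = 1800,
) -> Optional[str]:
  deadline = time.time() + timeout
  while time.time() < deadline:
    # Assume each checkpoint is a subdirectory of save_dir, named so that
    # lexicographic order matches recency.
    entries = []
    if os.path.isdir(save_dir):
      entries = sorted(e.path for e in os.scandir(save_dir) if e.is_dir())
    if entries and entries[-1] != last_checkpoint:
      return entries[-1]  # a checkpoint we have not returned before
    time.sleep(seconds_to_sleep)
  return None  # timed out waiting for a new checkpoint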

@@ -18,7 +18,6 @@ def weights_to_log(
     model: model to traverse.
     how_to_log: if a function, then applies this to every parameter, if a dict
       then only applies and logs specified parameters.
-
   """
   if not how_to_log:
     return
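
For context, the how_to_log dispatch described in that docstring amounts to the following sketch (hedged: the real function logs values rather than returning them):

from typing import Callable, Dict, Optional, Union

import torch

def weights_to_log_sketch(
  model: torch.nn.Module,
  how_to_log: Union[Callable, Dict[str, Callable], None],
) -> Optional[Dict[str, torch.Tensor]]:
  if not how_to_log:
    return None
  if callable(how_to_log):
    # Function mode: apply the same summary to every parameter.
    return {name: how_to_log(p) for name, p in model.named_parameters()}
  # Dict mode: {parameter_name: summary_fn}, applied only to the listed names.
  params = dict(model.named_parameters())
  return {name: fn(params[name]) for name, fn in how_to_log.items()}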

@@ -6,7 +6,6 @@ Features (go/100_enablements):
 - On/off-chief evaluation
 - Warmstart/checkpoint management
 - go/dataset-service 0-copy integration
-
 """
 import datetime
 import os

@@ -22,7 +22,6 @@ for its update method.
 2. Redefine an existing metric class.
   SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})
-
 """

 from abc import abstractmethod
 from typing import Callable, Dict, List
@@ -88,7 +87,6 @@ def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):

   Functionally the same using inheritance, just saves some lines of code
   if no need for class attributes.
-
   """

   def transform_method(_self, *args, **kwargs):
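
The docstring's SumMetric example is the clearest summary of what prepend_transform does. A self-contained sketch of that pattern with torchmetrics (not the repo's implementation):

import torchmetrics

def prepend_transform_sketch(base_metric, transform):
  class WithTransform(base_metric):
    def update(self, outputs):
      # Map raw model outputs onto the kwargs the base metric expects.
      super().update(**transform(outputs))

  WithTransform.__name__ = base_metric.__name__
  return WithTransform

# As in the docstring: turn SumMetric into a per-example counter.
CountMetric = prepend_transform_sketch(torchmetrics.SumMetric, lambda outputs: {"value": 1})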

@@ -1,7 +1,6 @@
 """Common metrics that also support multi task.

 We assume multi task models will output [task_idx, ...] predictions
-
 """

 from typing import Any, Dict

@@ -1,7 +1,6 @@
 """
 Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
 with TrainPipelineSparseDist.progress modified to support gradient accumulation.
-
 """
 import abc
 from dataclasses import dataclass, field
@@ -529,7 +528,6 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     optimizer update/reset is skipped for `_grad_accum` calls of `progress`
     (congruent to training steps), and then update/reset on every `_grad_accum`th
     step.
-
     """
     should_step_optimizer = (
       self._grad_accum is not None
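
The `_grad_accum` logic described above is the standard gradient-accumulation recipe. A generic sketch of that recipe (not the pipeline's actual progress method):

import torch

def train_with_grad_accum(model, loss_fn, batches, grad_accum: int = 4):
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  for call, (inputs, labels) in enumerate(batches, start=1):
    loss = loss_fn(model(inputs), labels)
    # Scale so the accumulated gradient matches one large-batch step.
    (loss / grad_accum).backward()
    if call % grad_accum == 0:  # update/reset only on every grad_accum-th call
      optimizer.step()
      optimizer.zero_grad()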

@@ -15,7 +15,6 @@ Args:

 Typical usage would be in a bash script, e.g.:
   LIST_LEN=$(python list_ops.py --input_list=$INPUT --op=len)
-
 """

 import tml.machines.environment as env

@@ -24,7 +24,6 @@ def update_mean(
     weight: The weights for the new value.

   Returns: The updated mean and updated weighted sum.
-
   """
   weight = torch.broadcast_to(weight, value.shape)
@@ -42,7 +41,6 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
     state: A tensor with the first dimension indicating workers.

   Returns: The accumulated mean from all workers.
-
   """
   mean, weight_sum = update_mean(
     current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
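
From the visible signatures, update_mean maintains a running weighted mean plus its weight sum. A sketch of the arithmetic (inferred from the docstring, not copied from the file):

import torch

def update_mean_sketch(current_mean, current_weight_sum, value, weight):
  weight = torch.broadcast_to(weight, value.shape)
  new_weight_sum = current_weight_sum + weight.sum()
  # Incremental form: mean += sum(w * (x - mean)) / W_new. Algebraically equal
  # to (W_old * mean + sum(w * x)) / W_new, but avoids large intermediate sums.
  new_mean = current_mean + (weight * (value - current_mean)).sum() / new_weight_sum
  return new_mean, new_weight_sum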

@@ -62,7 +62,6 @@ class AUROCWithMWU(torchmetrics.Metric):

   def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
     """
-
     Args:
       label_threshold: Labels strictly above this threshold are considered positive labels,
         otherwise, they are considered negative.

@@ -41,7 +41,6 @@ def _binary_cross_entropy_with_clipping(
     reduction: The reduction method to use.

   Returns: Binary cross entropy on the clipped predictions.
-
   """
   predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
   bce = -target * torch.log(predictions + epsilon)
@@ -220,14 +219,12 @@ class NRCE(RCE):
     perform if it was well calibrated.

   .. note:: A big gap between NRCE and RCE might indicate a badly calibrated model,
-
   """

   def __init__(
     self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs
   ):
     """
-
     Args:
       from_logits: whether or not predictions are logits or probabilities.
       label_smoothing: label smoothing constant.
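
Piecing together the visible lines of `_binary_cross_entropy_with_clipping`, the computation is plausibly as follows (the negative-class term and the reduction handling are assumptions):

import torch

def bce_with_clipping_sketch(predictions, target, epsilon: float = 1e-7,
                             reduction: str = "mean"):
  # Clip predictions away from 0 and 1 so the logs stay finite.
  predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
  bce = -target * torch.log(predictions + epsilon)
  bce = bce - (1.0 - target) * torch.log(1.0 - predictions + epsilon)
  return bce.mean() if reduction == "mean" else bce.sum()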

@@ -5,7 +5,6 @@
 Usage:
   >>> from twitter.ml.logging.absl_logging import logging
   >>> logging.info(f"Properly logged as INFO level in GCP Stackdriver.")
-
 """
 import logging as py_logging
 import sys

@@ -6,7 +6,6 @@
   >>> from ml.logging.torch_logging import logging
   >>> logging.info(f"This only prints on rank 0 if distributed, otherwise prints normally.")
   >>> logging.info(f"This prints on all ranks if distributed, otherwise prints normally.", rank=-1)
-
 """
 import functools
 from typing import Optional
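
The doctest above pins down the contract: rank 0 only by default, all ranks with rank=-1. A hedged sketch of that behavior (the module's real wrapper is built differently):

import logging

import torch.distributed as dist

def rank_scoped_info(msg: str, rank: int = 0) -> None:
  current_rank = dist.get_rank() if dist.is_initialized() else 0
  # rank=-1 logs on every rank; otherwise only on the requested rank.
  if rank == -1 or current_rank == rank:
    logging.info(msg)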

@@ -33,7 +33,6 @@ class ModelAndLoss(torch.nn.Module):
       prefetching purposes during training using torchrec's pipeline. However
       the underlying model signature needs to be exportable to onnx, requiring
       generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
-
     """
     outputs = self.model(batch)
     losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
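
The three ModelAndLoss docstrings touched by this commit (here and in two later hunks) describe the same wrapper pattern. A generic sketch, with the return value assumed rather than taken from the file:

import torch

class ModelAndLossSketch(torch.nn.Module):
  """Wraps a model and its loss so the train pipeline can drive one module."""

  def __init__(self, model: torch.nn.Module, loss_fn):
    super().__init__()
    self.model = model
    self.loss_fn = loss_fn

  def forward(self, batch):
    # Mirrors the visible lines: run the model, then apply the weighted loss.
    outputs = self.model(batch)
    losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
    return losses, outputs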

@@ -155,7 +155,6 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
   Returns:
     data_files: Sorted list of files to read corresponding to data at the desired datetimes
     num_hours_missing: Number of hours that we are missing data
-
   """
   datetimes = get_datetimes(explicit_datetime_inputs)
   folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
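
A hypothetical sketch of how an explicit datetime range could expand into the folder list built above (the hourly granularity and the %Y/%m/%d/%H layout are assumptions, not the repo's format):

import os
from datetime import datetime, timedelta

def hourly_folders_sketch(data_root: str, start: datetime, end: datetime):
  num_hours = int((end - start).total_seconds() // 3600)
  stamps = [(start + timedelta(hours=h)).strftime("%Y/%m/%d/%H") for h in range(num_hours + 1)]
  return [os.path.join(data_root, stamp) for stamp in stamps]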

@@ -14,7 +14,6 @@ def keyed_tensor_from_tensors_dict(
     tensor_map:

   Returns:
-
   """
   keys = list(tensor_map.keys())
   # We expect batch size to be first dim. However, if we get a shape [Batch_size],
@@ -64,7 +63,6 @@ def keyed_jagged_tensor_from_tensors_dict(
     tensor_map:

   Returns:
-
   """

   if not tensor_map:
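
Going by its name and the batch-size comment above, keyed_tensor_from_tensors_dict plausibly packs a dict of dense tensors into one torchrec KeyedTensor; a sketch under that assumption:

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

def keyed_tensor_from_tensors_dict_sketch(tensor_map):
  keys = list(tensor_map.keys())
  # Batch size is expected first; promote [batch_size] to [batch_size, 1].
  tensors = [t.unsqueeze(-1) if t.dim() == 1 else t for t in tensor_map.values()]
  return KeyedTensor.from_tensor_list(keys, tensors)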

@@ -26,7 +26,7 @@ def unsanitize(sanitized_task_name):
 def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
-  """ "Builds a model for a single task"""
+  """"Builds a model for a single task"""
   if task.mlp_config:
     return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
   elif task.dcn_config:

@@ -29,7 +29,6 @@ class ModelAndLoss(torch.nn.Module):
       prefetching purposes during training using torchrec's pipeline. However
       the underlying model signature needs to be exportable to onnx, requiring
       generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
-
     """
     outputs = self.model(
       continuous_features=batch.continuous_features,

@@ -27,7 +27,6 @@ class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):

   This adheres to the torch.optim scheduler API and can be plugged anywhere that
   e.g. exponential decay can be used.
-
   """

   def __init__(
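
RecapLRShim's docstring leans on the torch.optim scheduler API it references. For orientation, a minimal scheduler with that interface (a warmup toy, not the shim itself):

import torch

class WarmupSketch(torch.optim.lr_scheduler._LRScheduler):
  """Linear warmup; anything with this interface plugs in like exponential decay."""

  def __init__(self, optimizer, warmup_steps: int, last_epoch: int = -1):
    self.warmup_steps = warmup_steps
    super().__init__(optimizer, last_epoch)

  def get_lr(self):
    scale = min(1.0, (self.last_epoch + 1) / self.warmup_steps)
    return [base_lr * scale for base_lr in self.base_lrs]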

@@ -142,7 +142,6 @@ class TwhinModelAndLoss(torch.nn.Module):
       prefetching purposes during training using torchrec's pipeline. However
       the underlying model signature needs to be exportable to onnx, requiring
       generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
-
     """
     outputs = self.model(batch)
     logits = outputs["logits"]

@@ -2,7 +2,6 @@
 - Override `pa_to_batch` for dataset specific imputation, negative sampling, or coercion to Batch.
 - Readers can be colocated or off trainer machines.
-
 """

 import abc
 import functools

@@ -1,5 +1,4 @@
-"""Dataset service orchestrated by a TFJob
-"""
+"""Dataset service orchestrated by a TFJob"""

 from typing import Optional
 import uuid
@@ -104,7 +103,6 @@ def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
   - all rank processes will consume from the same job/dataset.

   Without this, dataset workers will try to serve 1 job per rank process and OOM.
-
   """
   if not env.has_readers():
     return dataset
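
The maybe_distribute_dataset docstring above hinges on every rank sharing one tf.data service job. A sketch of that setup (the dispatcher address and job name are illustrative parameters, not the repo's):

import tensorflow as tf

def distribute_dataset_sketch(dataset: tf.data.Dataset,
                              dispatcher_address: str,
                              job_name: str = "shared_reader_job") -> tf.data.Dataset:
  return dataset.apply(
    tf.data.experimental.service.distribute(
      processing_mode="parallel_epochs",
      service=dispatcher_address,
      # The same job_name on every rank means one shared job, so dataset
      # workers do not end up serving one job per rank process and OOM.
      job_name=job_name,
    )
  )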