Update speed_check(), add type hints.

- Print batch performance metrics only when idx is greater than zero, instead of computing and logging them for every batch.
- Raise a ValueError instead of a generic Exception when imputation for a column type is not supported.
This commit is contained in:
Bruno Rodrigues Faria 2023-04-02 13:11:15 -03:00 committed by GitHub
parent 78c3235eee
commit 9029e9ee2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,87 +1,78 @@
"""Reader utilities."""
import itertools import itertools
import time import time
from typing import Optional from typing import Optional
from tml.common.batch import DataclassBatch
from tml.ml_logging.torch_logging import logging
import pyarrow as pa import pyarrow as pa
import torch import torch
from tml.common.batch import DataclassBatch
from tml.ml_logging.torch_logging import logging
def roundrobin(*iterables):
  """Round robin through provided iterables, useful for simple load balancing.

  Adapted from https://docs.python.org/3/library/itertools.html.

  Args:
    *iterables: Any number of iterables to interleave.

  Yields:
    Items drawn from each iterable in turn. When one iterable is exhausted it
    is dropped from the rotation and the remaining ones continue to be cycled.

  Raises:
    Exception: Any non-StopIteration exception raised by an underlying
      iterable is logged and re-raised.
  """
  num_active = len(iterables)
  nexts = itertools.cycle(iter(it).__next__ for it in iterables)
  while num_active:
    try:
      for _next in nexts:
        result = _next()
        yield result
    except StopIteration:
      # Remove the iterator we just exhausted from the cycle.
      num_active -= 1
      nexts = itertools.cycle(itertools.islice(nexts, num_active))
      logging.warning(f"Iterable exhausted, {num_active} iterables left.")
    except Exception as exc:
      # Fix: the previous message said "ignoring", but the exception is
      # re-raised below — make the log line match the actual behavior.
      logging.warning(f"Iterable raised exception {exc}, re-raising.")
      raise
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]) -> None:
  """Iterate a data loader and log throughput statistics.

  Args:
    data_loader: Iterable yielding batches. Each batch is expected to expose a
      `batch_size` attribute (only read when a metrics line is logged).
    max_steps: Stop after this many batches.
    frequency: Log metrics every `frequency` batches. Batch 0 is skipped
      because no time window has elapsed yet.
    peek: If truthy, additionally log the full batch every `peek` batches.

  Raises:
    ValueError: If `frequency` is not positive.
  """
  if frequency <= 0:
    # Fail fast with a clear message instead of an opaque ZeroDivisionError
    # from the modulo in the loop below.
    raise ValueError(f"frequency must be positive, got {frequency}.")
  prev = time.perf_counter()
  for idx, batch in enumerate(data_loader):
    if idx > max_steps:
      break
    if peek and idx % peek == 0:
      logging.info(f"Batch: {batch}")
    if idx % frequency == 0 and idx > 0:
      now = time.perf_counter()
      elapsed = now - prev
      # NOTE(review): `elapsed` spans `frequency` batches but only the current
      # batch's examples are counted, so ex/s understates throughput by roughly
      # a factor of `frequency` — confirm this is intended.
      examples_per_second = batch.batch_size / elapsed
      logging.info(
        f"step: {idx}, elapsed(s): {elapsed:.2f}, examples: {batch.batch_size}, "
        f"ex/s: {examples_per_second:.2f}"
      )
      prev = now
def pa_to_torch(array: pa.array) -> torch.Tensor:
  """Convert a pyarrow array to a torch tensor via an intermediate numpy array."""
  numpy_values = array.to_numpy()
  return torch.from_numpy(numpy_values)
def create_default_pa_to_batch(schema) -> DataclassBatch:
  """Build a converter from a pyarrow RecordBatch to a DataclassBatch.

  Null values are imputed with type-appropriate defaults (0 for float64/int64
  columns, "" for string columns) before each column is converted to a torch
  tensor.

  Args:
    schema: pyarrow schema describing the record batches to convert.

  Returns:
    A callable mapping a pa.RecordBatch to a DataclassBatch whose fields are
    torch tensors, one per schema column.

  Raises:
    ValueError: From the returned callable, if a column's type has no
      registered imputation value.
  """
  _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)

  # Built once here instead of being rebuilt on every imputation lookup.
  imputation_map = {
    pa.float64(): pa.scalar(0, type=pa.float64()),
    pa.int64(): pa.scalar(0, type=pa.int64()),
    pa.string(): pa.scalar("", type=pa.string()),
  }

  def get_imputation_value(pa_type):
    # One default fill value per supported arrow type.
    if pa_type not in imputation_map:
      raise ValueError(f"Imputation for type {pa_type} not supported.")
    return imputation_map[pa_type]

  def _impute(array: pa.array) -> pa.array:
    # Replace nulls so to_numpy() inside pa_to_torch yields a dense array.
    return array.fill_null(get_imputation_value(array.type))

  def _column_to_tensor(record_batch: pa.RecordBatch):
    tensors = {
      col_name: pa_to_torch(_impute(record_batch.column(col_name)))
      for col_name in record_batch.schema.names
    }
    return _CustomBatch(**tensors)

  return _column_to_tensor