mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 06:41:49 +01:00
Update speed_check(), add type hints.
- Log batch performance metrics only when idx is greater than zero, instead of computing and logging them for every window starting at the first batch. - Raise a ValueError instead of a generic Exception when imputation is requested for an unsupported type.
This commit is contained in:
parent
78c3235eee
commit
9029e9ee2e
121
reader/utils.py
121
reader/utils.py
@ -1,87 +1,78 @@
|
|||||||
"""Reader utilities."""
|
|
||||||
import itertools
|
import itertools
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from tml.common.batch import DataclassBatch
|
|
||||||
from tml.ml_logging.torch_logging import logging
|
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tml.common.batch import DataclassBatch
|
||||||
|
from tml.ml_logging.torch_logging import logging
|
||||||
|
|
||||||
|
|
||||||
def roundrobin(*iterables):
  """Round robin through provided iterables, useful for simple load balancing.

  Adapted from https://docs.python.org/3/library/itertools.html.

  Args:
    *iterables: Any number of iterables to interleave.

  Yields:
    One item from each iterable in turn, skipping iterables as they become
    exhausted, until all of them are exhausted.
  """
  num_active = len(iterables)
  nexts = itertools.cycle(iter(it).__next__ for it in iterables)
  while num_active:
    try:
      for _next in nexts:
        result = _next()
        yield result
    except StopIteration:
      # Remove the iterator we just exhausted from the cycle.
      num_active -= 1
      nexts = itertools.cycle(itertools.islice(nexts, num_active))
      logging.warning(f"Iterable exhausted, {num_active} iterables left.")
    except Exception as exc:
      # Fix: the original message said the exception was being "ignored",
      # but the bare `raise` below re-raises it to the caller. The message
      # now matches the actual behavior.
      logging.warning(f"Iterable raised exception {exc}, re-raising.")
      raise
||||||
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]) -> None:
  """Iterate a data loader and periodically log its throughput.

  Args:
    data_loader: Iterable of batches; each batch is expected to expose a
      `batch_size` attribute (assumed to be an int -- TODO confirm against
      the batch type used by callers).
    max_steps: Stop once idx exceeds this value (the inclusive boundary of
      the original is preserved: up to max_steps + 1 batches are consumed).
    frequency: Log elapsed time and examples/second every `frequency` steps.
    peek: If truthy, additionally log the full batch contents every `peek`
      steps.
  """
  num_examples = 0
  prev = time.perf_counter()
  for idx, batch in enumerate(data_loader):
    if idx > max_steps:
      break
    if peek and idx % peek == 0:
      logging.info(f"Batch: {batch}")
    num_examples += batch.batch_size
    # Fix: require idx > 0 -- the window at idx == 0 covers a single batch
    # plus loader startup time and reports a misleading ex/s figure.
    if idx % frequency == 0 and idx > 0:
      now = time.perf_counter()
      elapsed = now - prev
      logging.info(
        f"step: {idx}, "
        f"elapsed(s): {elapsed}, "
        f"examples: {num_examples}, "
        f"ex/s: {num_examples / elapsed}, "
      )
      # Reset the measurement window.
      prev = now
      num_examples = 0
||||||
def pa_to_torch(array: pa.array) -> torch.Tensor:
  """Convert a pyarrow array to a torch tensor via an intermediate numpy array."""
  as_numpy = array.to_numpy()
  return torch.from_numpy(as_numpy)
||||||
def create_default_pa_to_batch(schema) -> "Callable[[pa.RecordBatch], DataclassBatch]":
  """Build a converter from a pyarrow RecordBatch to a DataclassBatch.

  A batch dataclass named "DefaultBatch" is derived from `schema`; nulls in
  each column are imputed with a type-appropriate default value before the
  column is converted to a tensor.

  Args:
    schema: pyarrow schema whose fields define the batch dataclass.

  Returns:
    A function mapping a pa.RecordBatch to a populated DataclassBatch.
    (Fix: the original annotation claimed `DataclassBatch`, but the
    callable `_column_to_tensor` is what is actually returned.)
  """
  _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)

  def get_imputation_value(pa_type):
    """Return the null-imputation scalar for a supported pyarrow type."""
    type_map = {
      pa.float64(): pa.scalar(0, type=pa.float64()),
      pa.int64(): pa.scalar(0, type=pa.int64()),
      pa.string(): pa.scalar("", type=pa.string()),
    }
    if pa_type not in type_map:
      # Fix: ValueError is more precise than a bare Exception and remains
      # backward compatible with callers handling Exception.
      raise ValueError(f"Imputation for type {pa_type} not supported.")
    return type_map[pa_type]

  def _impute(array: pa.array) -> pa.array:
    """Fill nulls in `array` with the default value for its type."""
    return array.fill_null(get_imputation_value(array.type))

  def _column_to_tensor(record_batch: pa.RecordBatch):
    """Convert every column of `record_batch` into a tensor batch field."""
    tensors = {
      col_name: pa_to_torch(_impute(record_batch.column(col_name)))
      for col_name in record_batch.schema.names
    }
    return _CustomBatch(**tensors)

  return _column_to_tensor
||||||
|
Loading…
Reference in New Issue
Block a user