mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-25 13:21:10 +01:00
229 lines
7.1 KiB
Python
229 lines
7.1 KiB
Python
"""Reader utilities."""
|
|
import itertools
|
|
import time
|
|
from typing import Optional
|
|
|
|
from tml.common.batch import DataclassBatch
|
|
from tml.ml_logging.torch_logging import logging
|
|
|
|
import pyarrow as pa
|
|
import torch
|
|
|
|
|
|
def roundrobin(*iterables):
|
|
"""
|
|
Iterate through provided iterables in a round-robin fashion.
|
|
|
|
This function takes multiple iterables and returns an iterator that yields elements from
|
|
each iterable in a round-robin manner. It continues cycling through the iterables until
|
|
all of them are exhausted.
|
|
|
|
Adapted from https://docs.python.org/3/library/itertools.html.
|
|
|
|
Args:
|
|
*iterables: One or more iterable objects to iterate through.
|
|
|
|
Yields:
|
|
Elements from the provided iterables in a round-robin fashion.
|
|
|
|
Raises:
|
|
StopIteration: If all provided iterables are exhausted.
|
|
|
|
Example:
|
|
```python
|
|
iterable1 = [1, 2, 3]
|
|
iterable2 = ['a', 'b', 'c']
|
|
iterable3 = [0.1, 0.2, 0.3]
|
|
|
|
for item in roundrobin(iterable1, iterable2, iterable3):
|
|
print(item)
|
|
|
|
# Output:
|
|
# 1
|
|
# 'a'
|
|
# 0.1
|
|
# 2
|
|
# 'b'
|
|
# 0.2
|
|
# 3
|
|
# 'c'
|
|
# 0.3
|
|
```
|
|
|
|
Note:
|
|
- If one of the provided iterables is shorter than the others, the function will
|
|
continue iterating through the remaining iterables until all are exhausted.
|
|
- If an iterable raises an exception during iteration, a warning message is logged,
|
|
and the function continues with the next iterable.
|
|
|
|
See Also:
|
|
- `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
|
|
- `itertools.islice`: A function to slice an iterable to limit the number of iterations.
|
|
"""
|
|
num_active = len(iterables)
|
|
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
|
|
while num_active:
|
|
try:
|
|
for _next in nexts:
|
|
result = _next()
|
|
yield result
|
|
except StopIteration:
|
|
# Remove the iterator we just exhausted from the cycle.
|
|
num_active -= 1
|
|
nexts = itertools.cycle(itertools.islice(nexts, num_active))
|
|
logging.warning(f"Iterable exhausted, {num_active} iterables left.")
|
|
except Exception as exc:
|
|
logging.warning(f"Iterable raised exception {exc}, ignoring.")
|
|
# continue
|
|
raise
|
|
|
|
|
|
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
|
|
"""
|
|
Monitor the speed and progress of data loading using a data loader.
|
|
|
|
This function iterates through a data loader for a specified number of steps or until
|
|
the end of the data loader is reached, periodically logging progress information.
|
|
|
|
Args:
|
|
data_loader: The data loader to monitor.
|
|
max_steps: The maximum number of steps to iterate through the data loader.
|
|
frequency: The frequency (in steps) at which to log progress.
|
|
peek (optional): If specified, it indicates the frequency (in steps) at which to log
|
|
batch contents for inspection.
|
|
|
|
Example:
|
|
```python
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
# Create a data loader (replace with your own DataLoader configuration)
|
|
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
|
|
|
|
# Monitor data loading speed and progress
|
|
speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
|
|
```
|
|
|
|
Args:
|
|
data_loader: The data loader to monitor.
|
|
max_steps: The maximum number of steps to iterate through the data loader.
|
|
frequency: The frequency (in steps) at which to log progress.
|
|
peek (optional): If specified, it indicates the frequency (in steps) at which to log
|
|
batch contents for inspection.
|
|
|
|
Note:
|
|
- The function logs information about elapsed time, the number of examples processed,
|
|
and the processing speed in examples per second.
|
|
- If `peek` is provided, batch contents will be logged for inspection at the specified
|
|
frequency.
|
|
|
|
See Also:
|
|
- `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
|
|
iterating through datasets.
|
|
"""
|
|
num_examples = 0
|
|
prev = time.perf_counter()
|
|
for idx, batch in enumerate(data_loader):
|
|
if idx > max_steps:
|
|
break
|
|
if peek and idx % peek == 0:
|
|
logging.info(f"Batch: {batch}")
|
|
num_examples += batch.batch_size
|
|
if idx % frequency == 0:
|
|
now = time.perf_counter()
|
|
elapsed = now - prev
|
|
logging.info(
|
|
f"step: {idx}, "
|
|
f"elapsed(s): {elapsed}, "
|
|
f"examples: {num_examples}, "
|
|
f"ex/s: {num_examples / elapsed}, "
|
|
)
|
|
prev = now
|
|
num_examples = 0
|
|
|
|
|
|
def pa_to_torch(array: pa.array) -> torch.Tensor:
|
|
"""
|
|
Convert a PyArrow Array to a PyTorch Tensor.
|
|
|
|
Args:
|
|
array (pa.array): The PyArrow Array to convert.
|
|
|
|
Returns:
|
|
torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.
|
|
|
|
Example:
|
|
```python
|
|
import pyarrow as pa
|
|
import torch
|
|
|
|
# Create a PyArrow Array
|
|
arrow_array = pa.array([1, 2, 3])
|
|
|
|
# Convert it to a PyTorch Tensor
|
|
torch_tensor = pa_to_torch(arrow_array)
|
|
```
|
|
"""
|
|
return torch.from_numpy(array.to_numpy())
|
|
|
|
|
|
def create_default_pa_to_batch(schema) -> DataclassBatch:
|
|
"""
|
|
Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.
|
|
|
|
Args:
|
|
schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
|
|
|
|
Returns:
|
|
callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
|
|
|
|
Example:
|
|
```python
|
|
import pyarrow as pa
|
|
from dataclass_batch import DataclassBatch
|
|
|
|
# Define a PyArrow schema
|
|
schema = pa.schema([
|
|
("feature1", pa.float64()),
|
|
("feature2", pa.int64()),
|
|
("label", pa.int64()),
|
|
])
|
|
|
|
# Create the conversion function
|
|
pa_to_batch = create_default_pa_to_batch(schema)
|
|
|
|
# Create a PyArrow RecordBatch
|
|
record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
|
|
"feature1": [1.0, 2.0, None],
|
|
"feature2": [10, 20, 30],
|
|
"label": [0, 1, None],
|
|
}))
|
|
|
|
# Convert the RecordBatch to a custom DataclassBatch
|
|
custom_batch = pa_to_batch(record_batch)
|
|
```
|
|
"""
|
|
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
|
|
|
|
def get_imputation_value(pa_type):
|
|
type_map = {
|
|
pa.float64(): pa.scalar(0, type=pa.float64()),
|
|
pa.int64(): pa.scalar(0, type=pa.int64()),
|
|
pa.string(): pa.scalar("", type=pa.string()),
|
|
}
|
|
if pa_type not in type_map:
|
|
raise Exception(f"Imputation for type {pa_type} not supported.")
|
|
return type_map[pa_type]
|
|
|
|
def _impute(array: pa.array) -> pa.array:
|
|
return array.fill_null(get_imputation_value(array.type))
|
|
|
|
def _column_to_tensor(record_batch: pa.RecordBatch):
|
|
tensors = {
|
|
col_name: pa_to_torch(_impute(record_batch.column(col_name)))
|
|
for col_name in record_batch.schema.names
|
|
}
|
|
return _CustomBatch(**tensors)
|
|
|
|
return _column_to_tensor
|