"""Local reader of parquet files.

1. Make sure you are initialized locally:

```
./images/init_venv_macos.sh
```

2. Activate the virtual environment:

```
source ~/tml_venv/bin/activate
```

3. Use the tool, e.g.:

`head` prints the first `--num` rows of the dataset.

```
python3 tools/pq.py \
  --num 5 --path "tweet_eng/small/edges/all/*" \
  head
```

`distinct` prints the observed values in the first `--num` rows for the specified columns.

```
python3 tools/pq.py \
  --num 1000000000 --columns '["rel"]' \
  --path "tweet_eng/small/edges/all/*" \
  distinct
```
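
`schema` prints the schema of the dataset.

```
python3 tools/pq.py \
  --path "tweet_eng/small/edges/all/*" \
  schema
```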

"""
from typing import List, Optional

import fire
import pandas as pd
import pyarrow.dataset as pads

from tml.common.filesystem import infer_fs


def _create_dataset(path: str):
  """Creates a pyarrow dataset from all parquet files matching `path`."""
  fs = infer_fs(path)
  files = fs.glob(path)
  return pads.dataset(files, format="parquet", filesystem=fs)


class PqReader:
  """Utility for inspecting parquet datasets from the command line."""

  def __init__(
    self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
  ):
    self._ds = _create_dataset(path)
    self._batch_size = batch_size
    self._num = num
    self._columns = columns

  def __iter__(self):
    batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
    rows_seen = 0
    for record in batches:
      # Stop after ~`--num` rows; the cutoff is checked at batch granularity,
      # so up to one extra batch may be yielded.
      if self._num and rows_seen >= self._num:
        break
      yield record
      rows_seen += record.num_rows
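
  # Example (sketch, reusing the sample path from the module docstring):
  # PqReader is also iterable from Python, yielding pyarrow RecordBatches:
  #
  #   reader = PqReader(path="tweet_eng/small/edges/all/*", num=100)
  #   for batch in reader:
  #     print(batch.num_rows)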

  def _head(self):
    total_read = self._num * self.bytes_per_row
    # Guard against accidentally loading a huge dataset into memory.
    if total_read >= int(500e6):
      raise Exception(
        f"Sorry, you're trying to read more than 500 MB into memory ({total_read} bytes)."
      )
    return self._ds.head(self._num, columns=self._columns)

  @property
  def bytes_per_row(self) -> int:
    """Rough estimate of the size of one row, based on the schema."""
    nbits = 0
    for t in self._ds.schema.types:
      try:
        nbits += t.bit_width
      except ValueError:
        # Variable-width type (e.g. string): just estimate one byte per value.
        nbits += 8
    return nbits // 8
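
  # For example (sketch): a schema with an int64, a float32, and one string
  # column is estimated as (64 + 32 + 8) // 8 = 13 bytes per row.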

  def schema(self):
    """Displays the schema of the dataset."""
    print(f"\n# Schema\n{self._ds.schema}")

  def head(self):
    """Displays the first `--num` rows."""
    print(self._head().to_pandas())

  def distinct(self):
    """Displays unique values seen in specified columns in the first `--num` rows.

    Useful for getting an approximate vocabulary for certain columns.
    """
    # Read the head once rather than once per accessor.
    table = self._head()
    for col_name, column in zip(table.column_names, table.columns):
      print(col_name)
      print("unique:", column.unique().to_pylist())


if __name__ == "__main__":
  # Print full (untruncated) dataframes.
  pd.set_option("display.max_columns", None)
  pd.set_option("display.max_rows", None)
  fire.Fire(PqReader)