Source code for aces.expand_shards
#!/usr/bin/env python
import os
import re
import sys
from pathlib import Path
def expand_shards(*shards: str) -> str:
    """Expand shard specifications and/or directories into a comma-separated list of shard names.

    Each argument is handled independently, in the order given:

    * If it is an existing directory, it is replaced by every ``*.parquet`` path found anywhere beneath
      that directory, expressed relative to the directory and with the final suffix removed. The names
      from one directory are sorted so the output does not depend on filesystem enumeration order.
    * Otherwise it must look like ``<prefix>/<N>`` or ``<prefix>_<N>`` and is replaced by the ``N`` names
      ``<prefix>/0`` ... ``<prefix>/<N-1>`` (or the ``_``-delimited equivalents).

    This can be useful with Hydra applications where you wish to expand a list of options for the sweeper
    to sweep over but can't use an OmegaConf resolver as those are evaluated after the sweep has been
    initialized.

    Args:
        shards: Shard specifications (prefix, delimiter and number of shards) and/or directories to expand.

    Returns:
        A comma-separated string of all expanded shard names; the empty string if nothing was produced.

    Raises:
        ValueError: If an argument is neither an existing directory nor of the form ``<prefix>/<N>`` or
            ``<prefix>_<N>``.

    Examples:
        >>> expand_shards("train/4", "val/IID/1", "val/prospective/1")
        'train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0'
        >>> expand_shards("data/data_4", "data/test_4")
        'data/data_0,data/data_1,data/data_2,data/data_3,data/test_0,data/test_1,data/test_2,data/test_3'
        >>> import tempfile
        >>> with tempfile.TemporaryDirectory() as tmpdirname:
        ...     for name in ("3.parquet", "1.parquet", "evens/0/file_2.parquet", "evens/0/file_0.parquet"):
        ...         data_path = Path(tmpdirname) / name
        ...         data_path.parent.mkdir(parents=True, exist_ok=True)
        ...         data_path.touch()
        ...     _ = (Path(tmpdirname) / "4.json").write_text('["foo"]')
        ...     expand_shards(tmpdirname)
        '1,3,evens/0/file_0,evens/0/file_2'
        >>> expand_shards("train.invalid")
        Traceback (most recent call last):
            ...
        ValueError: Invalid shard format: train.invalid
    """
    result = []
    for arg in shards:
        if os.path.isdir(arg):
            # Directory: collect every parquet file in it or any of its subdirectories. glob() yields
            # entries in a filesystem-dependent order, so sort to keep the output reproducible.
            root = Path(arg)
            result.extend(
                sorted(str(fp.relative_to(root).with_suffix("")) for fp in root.glob("**/*.parquet"))
            )
            continue

        # Otherwise, treat it as "<prefix><delimiter><number of shards>". The greedy prefix means the
        # *last* "/" or "_" that is followed only by digits is taken as the delimiter.
        match = re.match(r"(.+)([/_])(\d+)$", arg)
        if match is None:
            raise ValueError(f"Invalid shard format: {arg}")

        prefix, delimiter, count = match.groups()
        result.extend(f"{prefix}{delimiter}{i}" for i in range(int(count)))

    return ",".join(result)
def main() -> None:  # pragma: no cover
    """Command-line entry point: expand the arguments after the program name and print the result.

    Every element of ``sys.argv`` after the first is passed to ``expand_shards`` as a separate
    positional argument, and the returned comma-separated string is written to standard output.
    """
    cli_args = sys.argv[1:]
    expanded = expand_shards(*cli_args)
    print(expanded)
# Run the command-line entry point only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":  # pragma: no cover
    main()