aces.expand_shards module

aces.expand_shards.expand_shards(*shards: str) → str

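Expands a set of shard prefixes, each followed by a shard count, into the full list of shards (for example, "train/4" becomes train/0 through train/3), or expands a directory into a list of the files within it. In the directory example below, only the Parquet files are returned, as paths relative to the directory with the .parquet suffix removed.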
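The prefix-expansion rule can be illustrated with a minimal, self-contained sketch. It assumes each spec is a prefix ending in "/" or "_" followed by a decimal shard count, a pattern inferred from the documented examples rather than taken from the aces source; `expand_shard_specs` is a hypothetical name, not part of the package.

```python
import re

# Assumed spec format (inferred from the documented examples, not the aces
# source): a prefix ending in "/" or "_", followed by a decimal shard count.
_SHARD_SPEC = re.compile(r"^(?P<prefix>.+[/_])(?P<count>\d+)$")


def expand_shard_specs(*specs: str) -> str:
    """Hypothetical helper: expand "prefixN" into prefix0 ... prefix(N-1), comma-joined."""
    expanded: list[str] = []
    for spec in specs:
        match = _SHARD_SPEC.match(spec)
        if match is None:
            # Mirrors the documented error message for malformed specs.
            raise ValueError(f"Invalid shard format: {spec}")
        prefix = match["prefix"]
        count = int(match["count"])
        expanded.extend(f"{prefix}{i}" for i in range(count))
    return ",".join(expanded)


print(expand_shard_specs("train/4", "val/IID/1", "val/prospective/1"))
# train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0
```

Because the pattern requires the prefix to end in "/" or "_", a spec such as "train.invalid" does not match and raises the same ValueError shown in the examples.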

This is useful in Hydra applications where you want to expand a list of options for the sweeper to sweep over but cannot use an OmegaConf resolver, because resolvers are evaluated only after the sweep has been initialized.
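Since the expansion happens before Hydra starts, the resulting string can be passed directly as a multirun override value; Hydra's basic sweeper launches one job per comma-separated element. A sketch of assembling such a command follows; the script name `my_app.py` and the config key `data.shard` are hypothetical, and the shard string is hard-coded from the documented output of expand_shards("train/4") so the sketch runs without aces installed.

```python
import shlex

# Documented output of expand_shards("train/4"), hard-coded for this sketch.
shards = "train/0,train/1,train/2,train/3"

# Hypothetical application script and config key. Hydra's basic sweeper
# turns a comma-separated override value into one job per element.
argv = ["python", "my_app.py", "--multirun", f"data.shard={shards}"]

print(shlex.join(argv))
# python my_app.py --multirun data.shard=train/0,train/1,train/2,train/3
```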

Parameters:
*shards: str

One or more shard specifications, each a prefix followed by the number of shards to expand (for example "train/4" or "data/data_4"), or a directory whose files should be listed.

Returns: A comma-separated string of all shards, expanded to the specified counts, or of all files in the directory.
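The directory case can be sketched with pathlib. This assumes, based only on the directory example below, that the listing is recursive, keeps only *.parquet files, and reports each path relative to the directory with the suffix dropped; `list_parquet_shards` is a hypothetical name and may not match how aces implements it.

```python
import tempfile
from pathlib import Path


def list_parquet_shards(directory: str) -> str:
    """Hypothetical helper: comma-join *.parquet files under `directory`, relative, suffix removed."""
    root = Path(directory)
    names = [
        # e.g. <root>/evens/0/file_0.parquet -> "evens/0/file_0"
        path.relative_to(root).with_suffix("").as_posix()
        for path in sorted(root.rglob("*.parquet"))
    ]
    return ",".join(names)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "evens/0").mkdir(parents=True)
    for name in ("1.parquet", "3.parquet", "evens/0/file_0.parquet",
                 "evens/0/file_2.parquet", "4.json"):
        (root / name).touch()  # empty files suffice; only names matter here
    print(sorted(list_parquet_shards(tmp).split(",")))
# ['1', '3', 'evens/0/file_0', 'evens/0/file_2']
```

Globbing on the suffix is what excludes the 4.json file, matching the behaviour shown in the doctest.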

Examples

>>> from pathlib import Path
>>> import polars as pl
>>> import tempfile
>>> expand_shards("train/4", "val/IID/1", "val/prospective/1")
'train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0'
>>> expand_shards("data/data_4", "data/test_4")
'data/data_0,data/data_1,data/data_2,data/data_3,data/test_0,data/test_1,data/test_2,data/test_3'
>>> parquet_data = pl.DataFrame({
...     "subject_id": [1, 1, 1, 2, 3],
...     "time": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
...     "code": ['admission', 'discharge', 'discharge', 'admission', "gender"],
... }).with_columns(pl.col("time").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
>>> with tempfile.TemporaryDirectory() as tmpdirname:
...     for i in range(4):
...         if i in (0, 2):
...             data_path = Path(tmpdirname) / f"evens/0/file_{i}.parquet"
...             data_path.parent.mkdir(parents=True, exist_ok=True)
...         else:
...             data_path = Path(tmpdirname) / f"{i}.parquet"
...         parquet_data.write_parquet(data_path)
...     json_fp = Path(tmpdirname) / "4.json"
...     _ = json_fp.write_text('["foo"]')
...     result = expand_shards(tmpdirname)
...     sorted(result.split(","))
['1', '3', 'evens/0/file_0', 'evens/0/file_2']
>>> expand_shards("train.invalid")
Traceback (most recent call last):
    ...
ValueError: Invalid shard format: train.invalid

aces.expand_shards.main() → None