aces.expand_shards module
- aces.expand_shards.expand_shards(*shards: str) -> str
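This function expands a set of shard specifications, each a prefix followed by a number of shards, into a list of all individual shards. Given a directory instead, it expands the directory into a list of the files within it (in the example below, the Parquet files, listed relative to the directory and without their extension).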
This can be useful with Hydra applications where you wish to expand a list of options for the sweeper to sweep over but can’t use an OmegaConf resolver as those are evaluated after the sweep has been initialized.
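- Parameters: shards (str): One or more strings, each either a shard prefix followed by "/" or "_" and the number of shards (for example "train/4" or "data/data_4"), or a path to a directory.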
- Returns: A comma-separated list of all shards, expanded to the specified number, or all files in the directory.
Examples
>>> import polars as pl
>>> import tempfile
>>> from pathlib import Path

>>> expand_shards("train/4", "val/IID/1", "val/prospective/1")
'train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0'
>>> expand_shards("data/data_4", "data/test_4")
'data/data_0,data/data_1,data/data_2,data/data_3,data/test_0,data/test_1,data/test_2,data/test_3'

>>> parquet_data = pl.DataFrame({
...     "subject_id": [1, 1, 1, 2, 3],
...     "time": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
...     "code": ['admission', 'discharge', 'discharge', 'admission', "gender"],
... }).with_columns(pl.col("time").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))

>>> with tempfile.TemporaryDirectory() as tmpdirname:
...     for i in range(4):
...         if i in (0, 2):
...             data_path = Path(tmpdirname) / f"evens/0/file_{i}.parquet"
...             data_path.parent.mkdir(parents=True, exist_ok=True)
...         else:
...             data_path = Path(tmpdirname) / f"{i}.parquet"
...         parquet_data.write_parquet(data_path)
...     json_fp = Path(tmpdirname) / "4.json"
...     _ = json_fp.write_text('["foo"]')
...     result = expand_shards(tmpdirname)
...     sorted(result.split(","))
['1', '3', 'evens/0/file_0', 'evens/0/file_2']

>>> expand_shards("train.invalid")
Traceback (most recent call last):
    ...
ValueError: Invalid shard format: train.invalid
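The behaviour shown in the examples above can be approximated with a short sketch. This is not the library's implementation: the name `expand_shards_sketch`, the regular expression, and the choice to list only `*.parquet` files are assumptions inferred from the documented inputs and outputs.

```python
import re
from pathlib import Path

# Assumed format: a prefix, then "/" or "_", then the number of shards.
_SHARD_RE = re.compile(r"^(?P<prefix>.+)(?P<sep>[/_])(?P<count>\d+)$")


def expand_shards_sketch(*shards: str) -> str:
    """Hypothetical re-creation of the documented behaviour of expand_shards."""
    expanded = []
    for shard in shards:
        path = Path(shard)
        if path.is_dir():
            # Directory case: collect Parquet files recursively and report them
            # relative to the directory, without the ".parquet" suffix.
            files = sorted(path.rglob("*.parquet"))
            expanded.extend(
                f.relative_to(path).with_suffix("").as_posix() for f in files
            )
            continue

        match = _SHARD_RE.match(shard)
        if match is None:
            raise ValueError(f"Invalid shard format: {shard}")

        # Prefix case: emit prefix + separator + index for each shard index.
        prefix, sep = match["prefix"], match["sep"]
        expanded.extend(f"{prefix}{sep}{i}" for i in range(int(match["count"])))

    return ",".join(expanded)
```

The result is a single comma-separated string because that is the form Hydra's multirun mode accepts for a swept value (`key=a,b,c`), so the output can be pasted or substituted directly into an override on the command line.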