aces.query module

This module contains the main function for querying a task.

It accepts a task configuration and a predicates dataframe, builds the tree, and queries it recursively.
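As a rough illustration of the "build the tree, then query it recursively" idea, the sketch below inverts a child-to-parent mapping and walks the result depth-first, so that each node is handled before its descendants. It is not the ACES implementation: the `parents` mapping, the node names, and the helpers `build_children` and `visit` are all hypothetical, loosely modelled on the "pre" and "post" windows used in the Examples.

```python
from collections import defaultdict


def build_children(parents):
    """Invert a child -> parent mapping into parent -> [children]."""
    children = defaultdict(list)
    for child, parent in parents.items():
        children[parent].append(child)
    return children


def visit(node, children, order=None):
    """Depth-first walk: record a node before any of its descendants."""
    if order is None:
        order = []
    order.append(node)
    for child in children.get(node, []):
        visit(child, children, order)
    return order


# Hypothetical dependencies: "pre" is anchored on the trigger, "post" on "pre".
parents = {"pre": "trigger", "post": "pre"}
order = visit("trigger", build_children(parents))
print(order)
```

Visiting parents first matters because a child can only be resolved once the node it is anchored to is known.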

aces.query.query(cfg: TaskExtractorConfig, predicates_df: DataFrame) → DataFrame

Query a task using the provided configuration file and predicates dataframe.

Parameters:
  • cfg (TaskExtractorConfig) – TaskExtractorConfig object of the configuration file.

  • predicates_df (DataFrame) – Polars predicates dataframe.

Returns:

The result of the task query, containing subjects who satisfy the conditions defined in cfg. Timestamps for the start/end boundaries of each window specified in the task configuration, as well as predicate counts for each window, are provided.

Return type:

polars.DataFrame

Raises:
  • TypeError – If predicates_df is not a polars.DataFrame.

  • ValueError – If the (subject_id, timestamp) columns are not unique.
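When the ValueError above is raised, it can help to see which (subject_id, timestamp) pairs are repeated. The following is a minimal sketch using only the standard library; `duplicated_keys` is a hypothetical helper, not part of ACES, and it assumes the two columns are available as plain Python lists.

```python
from collections import Counter
from datetime import datetime


def duplicated_keys(subject_ids, timestamps):
    """Return the (subject_id, timestamp) pairs that occur more than once."""
    counts = Counter(zip(subject_ids, timestamps))
    return [key for key, n in counts.items() if n > 1]


# Illustrative data: subject 1 has two rows with the same timestamp.
subject_ids = [1, 1, 3]
timestamps = [datetime(2010, 6, 20), datetime(2010, 6, 20), datetime(2010, 5, 11)]

dupes = duplicated_keys(subject_ids, timestamps)
print(dupes)
```

With a Polars dataframe, the two lists could be obtained via `predicates_df["subject_id"].to_list()` and `predicates_df["timestamp"].to_list()`. An empty result means the pairs are unique.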

Examples

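>>> import logging
>>> from datetime import datetime
>>> import polars as pl
>>> from aces.config import EventConfig, PlainPredicateConfig, TaskExtractorConfig, WindowConfig
>>> # caplog (used below) is pytest's log-capture fixture; it is assumed to be provided to the doctest namespace by the test setup.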
>>> cfg = None # This is obviously invalid, but we're just testing the error case.
>>> predicates_df = {"subject_id": [1, 1], "timestamp": [1, 1]}
>>> query(cfg, predicates_df)
Traceback (most recent call last):
    ...
TypeError: Predicates dataframe type must be a polars.DataFrame. Got: <class 'dict'>.
>>> query(cfg, pl.DataFrame(predicates_df))
Traceback (most recent call last):
    ...
ValueError: The (subject_id, timestamp) columns must be unique.
>>> cfg = TaskExtractorConfig(
...     predicates={"A": PlainPredicateConfig("A")},
...     trigger=EventConfig("_ANY_EVENT"),
...     windows={
...         "pre": WindowConfig(None, "trigger", True, False, index_timestamp="start"),
...         "post": WindowConfig("pre.end", None, True, True, label="A"),
...     },
...     index_timestamp_window="pre",
...     label_window="post",
... )
>>> predicates_df = pl.DataFrame({
...     "subject_id": [1, 1, 3],
...     "timestamp": [datetime(1980, 12, 28), datetime(2010, 6, 20), datetime(2010, 5, 11)],
...     "A": [False, False, False],
...     "_ANY_EVENT": [True, True, True],
... })
>>> with caplog.at_level(logging.INFO):
...     result = query(cfg, predicates_df)
>>> result.select("subject_id", "trigger")
shape: (3, 2)
┌────────────┬─────────────────────┐
│ subject_id ┆ trigger             │
│ ---        ┆ ---                 │
│ i64        ┆ datetime[μs]        │
╞════════════╪═════════════════════╡
│ 1          ┆ 1980-12-28 00:00:00 │
│ 1          ┆ 2010-06-20 00:00:00 │
│ 3          ┆ 2010-05-11 00:00:00 │
└────────────┴─────────────────────┘
>>> "index_timestamp" in result.columns
True
>>> "label" in result.columns
True
>>> cfg = TaskExtractorConfig(
...     predicates={"A": PlainPredicateConfig("A", static=True)},
...     trigger=EventConfig("_ANY_EVENT"),
...     windows={},
... )
>>> with caplog.at_level(logging.INFO):
...     query(cfg, predicates_df)
shape: (0, 0)
┌┐
╞╡
└┘
>>> "Static variable criteria specified, filtering patient demographics..." in caplog.text
True
>>> "No static variable criteria specified, removing all rows with null timestamps..." in caplog.text
True
>>> predicates_df = pl.DataFrame({
...     "subject_id": [1, 1, 3],
...     "timestamp": [None, datetime(2010, 6, 20), datetime(2010, 5, 11)],
...     "A": [True, False, False],
...     "_ANY_EVENT": [False, False, False],
... })
>>> with caplog.at_level(logging.INFO):
...     result = query(cfg, predicates_df)
>>> "No valid rows found for the trigger event" in caplog.text
True