Code Tutorial with Synthetic MEDS Data

Set-up

Imports

First, let’s import ACES! Three modules - config, predicates, and query - are required to execute an end-to-end cohort extraction. omegaconf is also required to express our data config parameters in order to load our MEDS dataset. Other imports are only needed for visualization!

[1]:
import json

import pandas as pd
import yaml
from bigtree import print_tree
from IPython.display import display
from omegaconf import DictConfig

from aces import config, predicates, query

Directories

Next, let’s specify our paths and directories. In this tutorial, we will extract a cohort for a typical in-hospital mortality prediction task from the MEDS synthetic sample dataset. The task configuration file and sample data are both shipped with the repository in sample_configs/ and sample_data/ folders in the project root, respectively.

[2]:
config_path = "../../../sample_configs/inhospital_mortality.yaml"
data_path = "../../../sample_data/meds_sample/"

Configuration File

The task configuration file is the core configuration language that ACES uses to extract cohorts. Details about this configuration language are available in Configuration Language. In brief, the configuration file contains four sections: predicates, patient_demographics, trigger, and windows.

The predicates section is used to define dataset-specific concepts that are needed for the task. In our case of binary mortality prediction, we are interested in extracting a cohort of patients who have been admitted to the hospital and who were subsequently discharged or died. As such, admission, discharge, death, and discharge_or_death would be handy predicates.
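To make the idea of a predicate concrete, the sketch below evaluates exact-code, regex, and derived or(...) predicates against a few event codes. It is a minimal illustration built on Python's re module, not ACES's internal implementation: the helpers matches and evaluate are hypothetical, anchoring the regex with re.fullmatch is an assumption, and the code DISCHARGE//HOME is made up for the example.

```python
import re

# Plain predicates: either an exact code or a regular expression over codes.
PREDICATES = {
    "admission": {"regex": r"ADMISSION//.*"},
    "discharge": {"regex": r"DISCHARGE//.*"},
    "death": {"code": "DEATH"},
}


def matches(predicate: dict, code: str) -> bool:
    """Return True if an event code satisfies a plain predicate (hypothetical helper)."""
    if "regex" in predicate:
        # Assumption: the regex has to match the whole code string.
        return re.fullmatch(predicate["regex"], code) is not None
    return code == predicate["code"]


def evaluate(code: str) -> dict:
    """Evaluate every predicate for one event code as a 0/1 indicator."""
    row = {name: int(matches(spec, code)) for name, spec in PREDICATES.items()}
    # Derived predicate, equivalent to or(discharge, death).
    row["discharge_or_death"] = int(row["discharge"] or row["death"])
    return row


# "DISCHARGE//HOME" is a made-up code used only for illustration.
for code in ["ADMISSION//CARDIAC", "DISCHARGE//HOME", "DEATH", "HR"]:
    print(code, evaluate(code))
```

A derived predicate is computed from the plain ones rather than from the raw codes, which is why discharge_or_death needs no pattern of its own.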

The patient_demographics section is used to define static concepts that remain constant for subjects over time. For instance, sex is a common static variable. Should we want to filter our cohort to patients with a specific sex, we can do so here in the same way as defining predicates. For more information on predicates, please refer to this guide. In this example, let’s say we are only interested in male patients.
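As a rough picture of what a static filter does, the sketch below keeps only subjects that have a timestamp-less SEX//male row. It is an illustration under stated assumptions rather than ACES's own logic: the rows, subject IDs, and the SEX//female code are invented, and subjects_with_static_code is a hypothetical helper.

```python
# Toy MEDS-like rows: (subject_id, time, code). All values are made up for
# illustration; static measurements carry no timestamp (time is None).
rows = [
    (100, None, "SEX//male"),
    (100, "2020-01-01 08:00:00", "ADMISSION//CARDIAC"),
    (101, None, "SEX//female"),  # hypothetical code, not taken from the dataset
    (101, "2020-01-02 09:30:00", "ADMISSION//CARDIAC"),
]


def subjects_with_static_code(rows, code):
    """Collect subject IDs that have a timestamp-less row with the given code."""
    return {
        subject_id
        for subject_id, time, row_code in rows
        if time is None and row_code == code
    }


# Keep every row (static or timestamped) that belongs to a male subject.
male_subjects = subjects_with_static_code(rows, "SEX//male")
male_rows = [row for row in rows if row[0] in male_subjects]

print(sorted(male_subjects))
print(len(male_rows))
```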

We’d also like to make a prediction of mortality for each admission. Hence, a reasonable trigger event would be an admission predicate.

Suppose in our task, we’d like to set a constraint that the admission must have been more than 48 hours long. Additionally, for our prediction inputs, we’d like to use all information in the patient record up until 24 hours after admission, which must contain at least 5 event records (as we’d want to ensure there is sufficient input data). These clauses are captured in the windows section where each window is defined relative to another.
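In the configuration file, count constraints like these are written as strings such as "(5, None)" (at least 5 occurrences) or "(None, 0)" (no occurrences). The sketch below shows one way to read such a string as a pair of bounds and test a count against it. The helpers parse_bounds and satisfies are hypothetical, and treating both bounds as inclusive is an assumption based on the "at least 5" reading above; this is not the parser ACES uses.

```python
import ast


def parse_bounds(text: str):
    """Parse a constraint string such as "(5, None)" into a (min, max) tuple."""
    lower, upper = ast.literal_eval(text)
    return lower, upper


def satisfies(count: int, bounds) -> bool:
    """Check a predicate count against (min, max); None means unbounded.

    Assumption: both bounds are inclusive.
    """
    lower, upper = bounds
    if lower is not None and count < lower:
        return False
    if upper is not None and count > upper:
        return False
    return True


at_least_five = parse_bounds("(5, None)")
none_allowed = parse_bounds("(None, 0)")

print(satisfies(4, at_least_five), satisfies(5, at_least_five))
print(satisfies(0, none_allowed), satisfies(1, none_allowed))
```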

[3]:
with open(config_path) as stream:
    data_loaded = yaml.safe_load(stream)
    print(json.dumps(data_loaded, indent=4))
{
    "predicates": {
        "admission": {
            "code": {
                "regex": "ADMISSION//.*"
            }
        },
        "discharge": {
            "code": {
                "regex": "DISCHARGE//.*"
            }
        },
        "death": {
            "code": "DEATH"
        },
        "discharge_or_death": {
            "expr": "or(discharge, death)"
        }
    },
    "patient_demographics": {
        "male": {
            "code": "SEX//male"
        }
    },
    "trigger": "admission",
    "windows": {
        "input": {
            "start": null,
            "end": "trigger + 24h",
            "start_inclusive": true,
            "end_inclusive": true,
            "has": {
                "_ANY_EVENT": "(5, None)"
            },
            "index_timestamp": "end"
        },
        "gap": {
            "start": "trigger",
            "end": "start + 48h",
            "start_inclusive": false,
            "end_inclusive": true,
            "has": {
                "admission": "(None, 0)",
                "discharge": "(None, 0)",
                "death": "(None, 0)"
            }
        },
        "target": {
            "start": "gap.end",
            "end": "start -> discharge_or_death",
            "start_inclusive": false,
            "end_inclusive": true,
            "label": "death"
        }
    }
}

We can see that the input window begins at null (the start of the patient record) and ends 24 hours after the trigger (admission). A gap window is defined from the trigger to 48 hours after it, so it ends 24 hours after the input window does. Because the gap window may not contain any further admission, discharge, or death events, the admission is constrained to be longer than 48 hours. Finally, a target window is specified from the end of the gap window to either the next discharge or death event (i.e., discharge_or_death). This allows us to extract a binary label for each patient in our cohort to be used in the prediction task (i.e., the label field in the target window, which will extract 0: discharged, 1: died). Additionally, an index_timestamp field is set to the end of the input window to denote when a prediction is made (i.e., at the end of the input window, when all input data is fed into the model), and can be used to index extraction results.
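The time-based boundaries in this configuration are plain offsets from the trigger, which the following sketch computes with the standard library. The function window_boundaries is a hypothetical helper written for this tutorial, not part of ACES; the trigger timestamp is the one from the first row of the query output shown later. The end of the target window is bound to an event (the next discharge or death), so it cannot be derived from arithmetic alone and is left out.

```python
from datetime import datetime, timedelta


def window_boundaries(trigger: datetime) -> dict:
    """Compute the fixed-offset window boundaries of the task (hypothetical helper)."""
    input_end = trigger + timedelta(hours=24)  # "trigger + 24h"
    gap_end = trigger + timedelta(hours=48)  # gap starts at trigger, "start + 48h"
    return {
        "input.end": input_end,  # also the index_timestamp
        "gap.start": trigger,
        "gap.end": gap_end,
        "target.start": gap_end,  # target starts where the gap ends
    }


bounds = window_boundaries(datetime(2010, 10, 4, 17, 23))
for name, timestamp in bounds.items():
    print(name, timestamp)
```

For this trigger, input.end falls on 2010-10-05 17:23:00 and gap.end on 2010-10-06 17:23:00, which agrees with the corresponding row of the extraction result.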

We now load our configuration file by passing its path (str) into config.TaskExtractorConfig.load(). This parses each of the key sections described above and prepares ACES for extraction based on our defined constraints (inclusion/exclusion criteria for each window).

[4]:
cfg = config.TaskExtractorConfig.load(config_path=config_path)

Task Tree

With the configuration file loaded and parsed, we can access a visualization of a tree structure that is representative of our task of interest. As seen, the tree nodes are start and end time points of the windows that were defined in the configuration file, and the tree edges express the relationships between these windows. ACES will traverse this tree and recursively compute aggregated predicate counts for each subtree. This would allow us to filter our dataset to valid realizations of this task tree, which would make up our task cohort.

[5]:
tree = cfg.window_tree
print_tree(tree)
trigger
├── input.end
│   └── input.start
└── gap.end
    └── target.end
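To illustrate the traversal order, the sketch below rebuilds the printed tree as a plain dictionary and walks it in pre-order (each node before its children). The dictionary is a stand-in for the bigtree object returned by cfg.window_tree, and preorder is a hypothetical helper; neither is how ACES performs its recursive aggregation.

```python
# Children of each node, copied from the printed task tree.
TREE = {
    "trigger": ["input.end", "gap.end"],
    "input.end": ["input.start"],
    "input.start": [],
    "gap.end": ["target.end"],
    "target.end": [],
}


def preorder(node: str, tree: dict) -> list:
    """Visit a node first, then each of its subtrees from left to right."""
    order = [node]
    for child in tree[node]:
        order.extend(preorder(child, tree))
    return order


print(preorder("trigger", TREE))
# ['trigger', 'input.end', 'input.start', 'gap.end', 'target.end']
```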

Data

This tutorial uses synthetic data of 100 patients stored in the MEDS standard. For more information about this data, please refer to the generation of this synthetic data in the ESGPT Documentation (separately converted to MEDS). Here is what the data looks like:

[6]:
pd.read_parquet(f"{data_path}/train/0.parquet").head()
[6]:
   subject_id                 time                code  numeric_value
0           0                  NaT           SEX//male            NaN
1           0  2010-06-24 13:23:00  ADMISSION//CARDIAC            NaN
2           0  2010-06-24 13:23:00                  HR      -0.266073
3           0  2010-06-24 13:23:00           LAB//SpO2       0.283409
4           0  2010-06-24 13:23:00                TEMP       0.618533

Predicate Columns

The next step in our cohort extraction is the generation of predicate columns. Our defined dataset-agnostic windows (i.e., complex task logic) are linked to dataset-specific predicates (i.e., dataset observations and concepts), which facilitates the sharing of tasks across datasets. As such, the predicates dataframe is the foundational unit on which ACES acts.

A predicate column is simply a column containing numerical counts (often just 0’s and 1’s), representing the number of times a given predicate (concept) occurs at a given timestamp for a given patient.
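The sketch below shows the idea on a handful of rows: events are grouped by (subject_id, time), and each predicate is counted within the group. It uses only the standard library and is not how ACES computes these columns. The helper predicate_columns is hypothetical, the row at 14:23:00 is invented for illustration, and setting _ANY_EVENT to 1 per timestamp (rather than counting rows) is an assumption chosen to match the 0/1 values in this tutorial's output.

```python
from collections import Counter, defaultdict

# (subject_id, time, code) rows in the spirit of the MEDS sample above.
events = [
    (0, "2010-06-24 13:23:00", "ADMISSION//CARDIAC"),
    (0, "2010-06-24 13:23:00", "HR"),
    (0, "2010-06-24 13:23:00", "LAB//SpO2"),
    (0, "2010-06-24 13:23:00", "TEMP"),
    (0, "2010-06-24 14:23:00", "HR"),  # hypothetical row for illustration
]


def predicate_columns(events):
    """Count predicate occurrences per (subject_id, time) pair (hypothetical helper)."""
    counts = defaultdict(Counter)
    for subject_id, time, code in events:
        key = (subject_id, time)
        # Number of admission codes recorded at this timestamp.
        counts[key]["admission"] += int(code.startswith("ADMISSION//"))
        # Assumption: _ANY_EVENT flags that the timestamp exists at all.
        counts[key]["_ANY_EVENT"] = 1
    return {key: dict(columns) for key, columns in counts.items()}


for key, columns in predicate_columns(events).items():
    print(key, columns)
```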

In the case of MEDS (and ESGPT), ACES supports the automatic generation of these predicate columns from the configuration file. However, some fields need to be provided via a DictConfig object. These include the path to the directory of the MEDS dataset (str) and the data standard (which is meds in this case).

Given this data configuration, we then call predicates.get_predicates_df() to generate the relevant predicate columns for our task. Due to the nature of the specified predicates, the resulting dataframe simply contains the unique (subject_id, timestamp) pairs and binary columns for each predicate. An additional predicate _ANY_EVENT is also generated - this will be used to enforce our constraint of the number of events in the input window.

[7]:
data_config = DictConfig({"path": data_path, "standard": "meds"})

predicates_df = predicates.get_predicates_df(cfg=cfg, data_config=data_config)
display(predicates_df)
Expand shards is not enabled but your data path is a directory. If you are working with sharded datasets or large-scale queries, using `expand_shards` and`data=sharded` will improve efficiency and completeness.
shape: (31_025, 8)
subject_id  timestamp            admission  discharge  death  male  discharge_or_death  _ANY_EVENT
i64         datetime[μs]         i64        i64        i64    i64   i64                 i64
0           null                 0          0          0      1     0                   null
0           2010-06-24 13:23:00  1          0          0      0     0                   1
0           2010-06-24 14:23:00  0          0          0      0     0                   1
0           2010-06-24 15:23:00  0          0          0      0     0                   1
0           2010-06-24 16:23:00  0          0          0      0     0                   1
…           …                    …          …          …      …     …                   …
99          2010-11-20 08:20:06  0          0          0      0     0                   1
99          2010-11-20 09:20:06  0          0          0      0     0                   1
99          2010-11-20 10:20:06  0          0          0      0     0                   1
99          2010-11-20 11:20:06  0          0          0      0     0                   1
99          2010-11-20 12:20:06  0          1          0      0     1                   1

End-to-End Query

Finally, with our task configuration object and the computed predicates dataframe, we can call query.query() to execute the extraction of our cohort.

Each row of the resulting dataframe is a valid realization of our task tree. Hence, each instance can be included in our cohort used for the prediction of in-hospital mortality as defined in our task configuration file. The output contains:

  • subject_id: subject IDs of our cohort (since we’d like to treat individual admissions as separate samples, there will be duplicate subject IDs)

  • index_timestamp: timestamp of when a prediction is made, which coincides with the end timestamp of the input window (as specified in our task configuration)

  • label: binary label of mortality, which is derived from the death predicate of the target window (as specified in our task configuration)

  • trigger: timestamp of the trigger event, which is the admission predicate (as specified in our task configuration)

Additionally, the output includes a column for each node of our task tree, in pre-order traversal order. Each column contains a pl.Struct object containing the name of the node, the start and end times of the window it represents, and the counts of all defined predicates in that window.

[8]:
df_result = query.query(cfg=cfg, predicates_df=predicates_df)
display(df_result)
All labels in the extracted cohort are the same: '0'. This may indicate an issue with the task logic. Please double-check your configuration file if this is not expected.
shape: (87, 8)
subject_id  index_timestamp      label  trigger              input.end_summary                                                 input.start_summary                                                  gap.end_summary                                                 target.end_summary
i64         datetime[μs]         i64    datetime[μs]         struct[8]                                                         struct[8]                                                            struct[8]                                                       struct[8]
0           2010-10-05 17:23:00  0      2010-10-04 17:23:00  {"input.end",2010-10-04 17:23:00,2010-10-05 17:23:00,0,0,0,0,23}  {"input.start",2010-06-24 13:23:00,2010-10-05 17:23:00,2,1,0,1,44}   {"gap.end",2010-10-04 17:23:00,2010-10-06 17:23:00,0,0,0,0,48}  {"target.end",2010-10-06 17:23:00,2010-10-16 00:23:00,0,1,0,1,223}
1           2010-02-13 20:16:13  0      2010-02-12 20:16:13  {"input.end",2010-02-12 20:16:13,2010-02-13 20:16:13,0,0,0,0,22}  {"input.start",2010-02-12 20:16:13,2010-02-13 20:16:13,1,0,0,0,24}   {"gap.end",2010-02-12 20:16:13,2010-02-14 20:16:13,0,0,0,0,47}  {"target.end",2010-02-14 20:16:13,2010-02-15 17:16:13,0,1,0,1,21}
2           2010-01-19 23:07:07  0      2010-01-18 23:07:07  {"input.end",2010-01-18 23:07:07,2010-01-19 23:07:07,0,0,0,0,19}  {"input.start",2010-01-18 23:07:07,2010-01-19 23:07:07,1,0,0,0,21}   {"gap.end",2010-01-18 23:07:07,2010-01-20 23:07:07,0,0,0,0,40}  {"target.end",2010-01-20 23:07:07,2010-01-30 19:07:07,0,1,0,1,210}
4           2010-06-30 07:20:14  0      2010-06-29 07:20:14  {"input.end",2010-06-29 07:20:14,2010-06-30 07:20:14,0,0,0,0,22}  {"input.start",2010-06-29 07:20:14,2010-06-30 07:20:14,1,0,0,0,24}   {"gap.end",2010-06-29 07:20:14,2010-07-01 07:20:14,0,0,0,0,46}  {"target.end",2010-07-01 07:20:14,2010-07-05 10:20:14,0,1,0,1,93}
4           2010-08-03 14:20:14  0      2010-08-02 14:20:14  {"input.end",2010-08-02 14:20:14,2010-08-03 14:20:14,0,0,0,0,22}  {"input.start",2010-06-29 07:20:14,2010-08-03 14:20:14,2,1,0,1,164}  {"gap.end",2010-08-02 14:20:14,2010-08-04 14:20:14,0,0,0,0,46}  {"target.end",2010-08-04 14:20:14,2010-08-07 01:20:14,0,1,0,1,56}
…           …                    …      …                    …                                                                 …                                                                    …                                                               …
98          2010-06-29 22:25:52  0      2010-06-28 22:25:52  {"input.end",2010-06-28 22:25:52,2010-06-29 22:25:52,0,0,0,0,21}  {"input.start",2010-04-05 19:25:52,2010-06-29 22:25:52,2,1,0,1,135}  {"gap.end",2010-06-28 22:25:52,2010-06-30 22:25:52,0,0,0,0,42}  {"target.end",2010-06-30 22:25:52,2010-07-12 13:25:52,0,1,0,1,242}
98          2010-08-29 00:25:52  0      2010-08-28 00:25:52  {"input.end",2010-08-28 00:25:52,2010-08-29 00:25:52,0,0,0,0,19}  {"input.start",2010-04-05 19:25:52,2010-08-29 00:25:52,3,2,0,2,419}  {"gap.end",2010-08-28 00:25:52,2010-08-30 00:25:52,0,0,0,0,41}  {"target.end",2010-08-30 00:25:52,2010-09-01 19:25:52,0,1,0,1,58}
99          2010-04-16 18:20:06  0      2010-04-15 18:20:06  {"input.end",2010-04-15 18:20:06,2010-04-16 18:20:06,0,0,0,0,21}  {"input.start",2010-04-15 18:20:06,2010-04-16 18:20:06,1,0,0,0,23}   {"gap.end",2010-04-15 18:20:06,2010-04-17 18:20:06,0,0,0,0,44}  {"target.end",2010-04-17 18:20:06,2010-04-23 19:20:06,0,1,0,1,131}
99          2010-10-13 22:20:06  0      2010-10-12 22:20:06  {"input.end",2010-10-12 22:20:06,2010-10-13 22:20:06,0,0,0,0,20}  {"input.start",2010-04-15 18:20:06,2010-10-13 22:20:06,2,1,0,1,198}  {"gap.end",2010-10-12 22:20:06,2010-10-14 22:20:06,0,0,0,0,44}  {"target.end",2010-10-14 22:20:06,2010-10-21 03:20:06,0,1,0,1,130}
99          2010-11-15 08:20:06  0      2010-11-14 08:20:06  {"input.end",2010-11-14 08:20:06,2010-11-15 08:20:06,0,0,0,0,21}  {"input.start",2010-04-15 18:20:06,2010-11-15 08:20:06,3,2,0,2,374}  {"gap.end",2010-11-14 08:20:06,2010-11-16 08:20:06,0,0,0,0,44}  {"target.end",2010-11-16 08:20:06,2010-11-20 12:20:06,0,1,0,1,89}
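To show how a label relates to a window summary, the sketch below copies the target.end_summary values of the first output row into a named tuple and derives the label from its death count. The class WindowSummary, its field names, and the assumed order of the five counts (admission, discharge, death, discharge_or_death, _ANY_EVENT) are hypothetical choices for illustration; the actual struct field names are not shown in this tutorial, and label_from_target is not an ACES function.

```python
from typing import NamedTuple


class WindowSummary(NamedTuple):
    """Stand-in for one summary struct; field names and order are assumptions."""

    name: str
    start: str
    end: str
    admission: int
    discharge: int
    death: int
    discharge_or_death: int
    any_event: int


# Values copied from target.end_summary in the first row of the output.
target = WindowSummary(
    "target.end", "2010-10-06 17:23:00", "2010-10-16 00:23:00", 0, 1, 0, 1, 223
)


def label_from_target(summary: WindowSummary) -> int:
    """Binary mortality label: 1 if a death occurred in the target window, else 0."""
    return int(summary.death > 0)


label = label_from_target(target)
print(label)
```

Under these assumptions the derived label is 0, the same as the label column for that row: the target window ended with a discharge rather than a death.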

… and that’s a wrap! We have used ACES to perform an end-to-end extraction on a MEDS dataset for a cohort that can be used to predict in-hospital mortality. Similar pipelines can be made for other tasks, as well as using the ESGPT data standard. You may also pre-compute predicate columns and use the direct flag when loading in .csv or .parquet data files. More information about this is available in Predicates DataFrame.

As always, please don’t hesitate to reach out should you have any questions about ACES!