Source code for aces.extract_subtree
"""This module contains the functions for extracting constraint hierarchy subtrees."""
import dataclasses
import logging
from datetime import timedelta
import polars as pl
from bigtree import Node
from .aggregate import aggregate_event_bound_window, aggregate_temporal_window
from .constraints import check_constraints
logger = logging.getLogger(__name__)
[docs]def extract_subtree(
subtree: Node,
subtree_anchor_realizations: pl.DataFrame,
predicates_df: pl.DataFrame,
subtree_root_offset: timedelta = timedelta(0),
) -> pl.DataFrame:
"""The main algorithmic recursive call to identify valid realizations of a subtree.
This function takes in a global ``predicates_df``, a subtree of constraints, and the temporal offset that
any realization the root timestamp of the subtree would have relative to the corresponding subtree anchor.
It will use this information to recurse through the subtree and identify any valid realizations of this
subtree, returning them in a dataframe keyed by the subtree anchor event timestamps and with a series of
columns containing subtree edge start and end timestamps and contained predicate counts.
Args:
subtree: The subtree to extract realizations from. This is specified through a `BigTree.Node` object.
This ``Node`` object can have zero or more children, each of which must have the following:
- ``name``: The name of the subtree root.
- ``constraints``: The constraints associated with the subtree root, structured as a dictionary
from predicate column name to a tuple containing the valid (inclusive) minimum and maximum
values the predicate counts can take on (use `None` for no constraint).
- ``endpoint_expr``: A tuple containing the endpoint expression for the subtree root. This
should be either a `ToEventWindowBounds` or a `TemporalWindowBounds` formatted tuple object,
less the offset parameter, as that is something determined by the structure of the tree, not
pre-set in the configuration.
subtree_anchor_realizations: The dataframe containing the anchor to subtree root mapping. This
dataframe will have the following columns:
- ``"subject_id"``: The ID of the subject. All analyses will be performed within ``subject_id``
groups.
- ``subtree_anchor_timestamp``: The timestamp of all possible prospective subtree anchor
realizations. These will all correspond to extant events (``subject_id``, ``timestamp`` pairs
in ``predicates_df``).
predicates_df: The dataframe containing the predicates to summarize. This dataframe will have the
following mandatory columns:
subtree_root_offset: The temporal offset of the subtree root relative to the subtree anchor.
Returns:
pl.DataFrame: The result of the subtree extraction, containing subjects who satisfy the conditions
defined in the subtree. Timestamps for the start/end boundaries of each window specified in the
subtree configuration, as well as predicate counts for each window, are provided.
Examples:
>>> from .types import ToEventWindowBounds, TemporalWindowBounds
>>> # We'll use an example for in-hospital mortality prediction. Our root event of the tree will be
>>> # an admission event.
>>> root = Node("admission")
>>> #
>>> #### BRANCH 1 ####
>>> # Our first branch off of admission will be checking a gap window, then our target window.
>>> # Node 1 will represent our gap window. We say that in the 24 hours after the admission, there
>>> # should be no discharges, deaths, or covid events.
>>> gap_node = Node("gap") # This sets the node's name.
>>> gap_node.endpoint_expr = TemporalWindowBounds(True, timedelta(days=2), True)
>>> gap_node.constraints = {
... "is_discharge": (None, 0), "is_death": (None, 0), "is_covid_dx": (None, 0)
... }
>>> gap_node.parent = root
>>> # Node 2 will start our target window and span until the next discharge or death event.
>>> # There should be no covid events.
>>> target_node = Node("target") # This sets the node's name.
>>> target_node.endpoint_expr = ToEventWindowBounds(True, "is_discharge", True)
>>> target_node.constraints = {"is_covid_dx": (None, 0)}
>>> target_node.parent = gap_node
>>> #
>>> #### BRANCH 2 ####
>>> # Finally, for our second branch, we will impose no constraints but track the input time range,
>>> # which will span from the beginning of the record to 24 hours after admission.
>>> input_end_node = Node("input_end")
>>> input_end_node.endpoint_expr = TemporalWindowBounds(True, timedelta(days=1), True)
>>> input_end_node.constraints = {}
>>> input_end_node.parent = root
>>> input_start_node = Node("input_start")
>>> input_start_node.endpoint_expr = ToEventWindowBounds(True, "-_RECORD_START", True)
>>> input_start_node.constraints = {}
>>> input_start_node.parent = root
>>> #
>>> #### BRANCH 3 ####
>>> # For our last branch, we will validate that the patient has sufficient historical data, asserting
>>> # that they should have at least 1 event of any kind at least 1 year prior to the trigger event.
>>> # This will be expressed through two windows, one spanning back a year, and the other looking
>>> # prior to that year.
>>> pre_node_1yr = Node("pre_node_1yr")
>>> pre_node_1yr.endpoint_expr = TemporalWindowBounds(False, timedelta(days=-365), False)
>>> pre_node_1yr.constraints = {}
>>> pre_node_1yr.parent = root
>>> pre_node_total = Node("pre_node_total")
>>> pre_node_total.endpoint_expr = ToEventWindowBounds(False, "-_RECORD_START", False)
>>> pre_node_total.constraints = {"*": (1, None)}
>>> pre_node_total.parent = pre_node_1yr
>>> #
>>> #### PREDICATES_DF ####
>>> # We'll have the following patient data:
>>> # - subject 1 will have an admission that won't count because they'll have a covid diagnosis,
>>> # then an admission that won't count because there will be no associated discharge.
>>> # - subject 2 will have an admission that won't count because they'll have too little data before
>>> # it, then a second admission that will count.
>>> # - subject 3 will have an admission that will be too short.
>>> #
>>> predicates_df = pl.DataFrame({
... "subject_id": [
... 1, 1, 1, 1, 1, # Pre-event, Admission, Covid, Discharge, Admission.
... 2, 2, 2, 2, 2, # Pre-event-too-close, Admission, Discharge, Admission, Death & Discharge.
... 3, 3, 3, # Pre-event, Admission, Death
... ],
... "timestamp": [
... # Subject 1
... datetime(year=1980, month=12, day=1, hour=12, minute=3), # Pre-event
... datetime(year=1989, month=12, day=3, hour=13, minute=14), # Admission
... datetime(year=1989, month=12, day=5, hour=15, minute=17), # Covid
... datetime(year=1989, month=12, day=7, hour=11, minute=4), # Discharge
... datetime(year=1989, month=12, day=23, hour=3, minute=12), # Admission
... # Subject 2
... datetime(year=1983, month=12, day=1, hour=22, minute=2), # Pre-event-too-close
... datetime(year=1983, month=12, day=2, hour=12, minute=3), # Admission
... datetime(year=1983, month=12, day=8, hour=13, minute=14), # Discharge
... datetime(year=1989, month=12, day=6, hour=15, minute=17), # Valid Admission
... datetime(year=1989, month=12, day=10, hour=16, minute=22), # Death & Discharge
... # Subject 3
... datetime(year=1982, month=2, day=13, hour=10, minute=44), # Pre-event
... datetime(year=1999, month=12, day=6, hour=15, minute=17), # Admission
... datetime(year=1999, month=12, day=6, hour=16, minute=22), # Discharge
... ],
... "is_admission": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0],
... "is_discharge": [0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1],
... "is_death": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
... "is_covid_dx": [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
... "_ANY_EVENT": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
... })
>>> subtreee_anchor_realizations = (
... predicates_df
... .filter(pl.col("is_admission") > 0)
... .rename({"timestamp": "subtree_anchor_timestamp"})
... .select("subject_id", "subtree_anchor_timestamp")
... )
>>> print(subtreee_anchor_realizations)
shape: (5, 2)
┌────────────┬──────────────────────────┐
│ subject_id ┆ subtree_anchor_timestamp │
│ --- ┆ --- │
│ i64 ┆ datetime[μs] │
╞════════════╪══════════════════════════╡
│ 1 ┆ 1989-12-03 13:14:00 │
│ 1 ┆ 1989-12-23 03:12:00 │
│ 2 ┆ 1983-12-02 12:03:00 │
│ 2 ┆ 1989-12-06 15:17:00 │
│ 3 ┆ 1999-12-06 15:17:00 │
└────────────┴──────────────────────────┘
>>> out = extract_subtree(root, subtreee_anchor_realizations, predicates_df, timedelta(0))
>>> out.select("subject_id", "subtree_anchor_timestamp")
shape: (1, 2)
┌────────────┬──────────────────────────┐
│ subject_id ┆ subtree_anchor_timestamp │
│ --- ┆ --- │
│ i64 ┆ datetime[μs] │
╞════════════╪══════════════════════════╡
│ 2 ┆ 1989-12-06 15:17:00 │
└────────────┴──────────────────────────┘
>>> out.columns
['subject_id',
'target_summary',
'subtree_anchor_timestamp',
'gap_summary',
'input_end_summary',
'input_start_summary',
'pre_node_total_summary',
'pre_node_1yr_summary']
>>> def print_window(name: str, do_drop_any_events: bool = True):
... drop_cols = ["window_name", "subject_id", "subtree_anchor_timestamp"]
... if do_drop_any_events:
... drop_cols.append("_ANY_EVENT")
... return (
... out.select("subject_id", "subtree_anchor_timestamp", name)
... .unnest(name)
... .drop(*drop_cols)
... )
>>> print_window("gap_summary")
shape: (1, 6)
┌─────────────────────┬─────────────────────┬──────────────┬──────────────┬──────────┬─────────────┐
│ timestamp_at_start ┆ timestamp_at_end ┆ is_admission ┆ is_discharge ┆ is_death ┆ is_covid_dx │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════════════════════╪═════════════════════╪══════════════╪══════════════╪══════════╪═════════════╡
│ 1989-12-06 15:17:00 ┆ 1989-12-08 15:17:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 │
└─────────────────────┴─────────────────────┴──────────────┴──────────────┴──────────┴─────────────┘
>>> print_window("target_summary")
shape: (1, 6)
┌─────────────────────┬─────────────────────┬──────────────┬──────────────┬──────────┬─────────────┐
│ timestamp_at_start ┆ timestamp_at_end ┆ is_admission ┆ is_discharge ┆ is_death ┆ is_covid_dx │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════════════════════╪═════════════════════╪══════════════╪══════════════╪══════════╪═════════════╡
│ 1989-12-08 15:17:00 ┆ 1989-12-10 16:22:00 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │
└─────────────────────┴─────────────────────┴──────────────┴──────────────┴──────────┴─────────────┘
>>> print_window("input_start_summary")
shape: (1, 6)
┌─────────────────────┬─────────────────────┬──────────────┬──────────────┬──────────┬─────────────┐
│ timestamp_at_start ┆ timestamp_at_end ┆ is_admission ┆ is_discharge ┆ is_death ┆ is_covid_dx │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════════════════════╪═════════════════════╪══════════════╪══════════════╪══════════╪═════════════╡
│ 1983-12-01 22:02:00 ┆ 1989-12-06 15:17:00 ┆ 2 ┆ 1 ┆ 0 ┆ 0 │
└─────────────────────┴─────────────────────┴──────────────┴──────────────┴──────────┴─────────────┘
>>> print_window("input_end_summary")
shape: (1, 6)
┌─────────────────────┬─────────────────────┬──────────────┬──────────────┬──────────┬─────────────┐
│ timestamp_at_start ┆ timestamp_at_end ┆ is_admission ┆ is_discharge ┆ is_death ┆ is_covid_dx │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════════════════════╪═════════════════════╪══════════════╪══════════════╪══════════╪═════════════╡
│ 1989-12-06 15:17:00 ┆ 1989-12-07 15:17:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 │
└─────────────────────┴─────────────────────┴──────────────┴──────────────┴──────────┴─────────────┘
>>> print_window("pre_node_1yr_summary")
shape: (1, 6)
┌─────────────────────┬─────────────────────┬──────────────┬──────────────┬──────────┬─────────────┐
│ timestamp_at_start ┆ timestamp_at_end ┆ is_admission ┆ is_discharge ┆ is_death ┆ is_covid_dx │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════════════════════╪═════════════════════╪══════════════╪══════════════╪══════════╪═════════════╡
│ 1989-12-06 15:17:00 ┆ 1988-12-06 15:17:00 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │
└─────────────────────┴─────────────────────┴──────────────┴──────────────┴──────────┴─────────────┘
>>> print_window("pre_node_total_summary")
shape: (1, 6)
┌─────────────────────┬─────────────────────┬──────────────┬──────────────┬──────────┬─────────────┐
│ timestamp_at_start ┆ timestamp_at_end ┆ is_admission ┆ is_discharge ┆ is_death ┆ is_covid_dx │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════════════════════╪═════════════════════╪══════════════╪══════════════╪══════════╪═════════════╡
│ 1983-12-01 22:02:00 ┆ 1988-12-06 15:17:00 ┆ 1 ┆ 1 ┆ 0 ┆ 0 │
└─────────────────────┴─────────────────────┴──────────────┴──────────────┴──────────┴─────────────┘
>>> root = Node("root")
>>> child = Node("child")
>>> child.endpoint_expr = (True, timedelta(days=3))
>>> child.constraints = {}
>>> child.parent = root
>>> predicates_df = pl.DataFrame({
... "subject_id": [1],
... "timestamp": [datetime(2020, 1, 1)]
... })
>>> subtree_anchor_realizations = pl.DataFrame({
... "subject_id": [1],
... "subtree_anchor_timestamp": [datetime(2020, 1, 1)]
... })
>>> print(child.endpoint_expr)
(True, datetime.timedelta(days=3))
>>> extract_subtree(root, subtree_anchor_realizations, predicates_df, timedelta(0))
shape: (1, 3)
┌────────────┬──────────────────────────┬─────────────────────────────────┐
│ subject_id ┆ subtree_anchor_timestamp ┆ child_summary │
│ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ struct[3] │
╞════════════╪══════════════════════════╪═════════════════════════════════╡
│ 1 ┆ 2020-01-01 00:00:00 ┆ {"child",2020-01-01 00:00:00,2… │
└────────────┴──────────────────────────┴─────────────────────────────────┘
>>> print(child.endpoint_expr)
(True, datetime.timedelta(days=3))
>>> child.endpoint_expr = (True, 42)
>>> extract_subtree(root, subtree_anchor_realizations, predicates_df, timedelta(0))
Traceback (most recent call last):
...
ValueError: Invalid endpoint expression: '(True, 42, datetime.timedelta(0))'
"""
recursive_results = []
predicate_cols = [c for c in predicates_df.columns if c not in {"subject_id", "timestamp"}]
if not subtree.children:
return subtree_anchor_realizations
for child in subtree.children:
logger.info(f"Summarizing subtree rooted at '{child.name}'...")
# Step 1: Summarize the window from the subtree.root to child.
# Construct a new endpoint_expr with the accumulated offset instead of mutating
# child.endpoint_expr in place — the child node is reused across calls, so mutating
# it would cause subtree_root_offset to compound on every invocation.
endpoint_expr = child.endpoint_expr
if type(endpoint_expr) is tuple:
endpoint_expr = (*endpoint_expr, subtree_root_offset)
else:
endpoint_expr = dataclasses.replace(
endpoint_expr, offset=endpoint_expr.offset + subtree_root_offset
)
match endpoint_expr[1]:
case timedelta():
child_root_offset = subtree_root_offset + endpoint_expr[1]
window_summary_df = (
aggregate_temporal_window(predicates_df, endpoint_expr)
.with_columns(
pl.col("timestamp").alias("subtree_anchor_timestamp"),
pl.col("timestamp").alias("child_anchor_timestamp"),
)
.drop("timestamp")
)
case str():
# In an event bound case, the child root will be a proper extant event, so it will be the
# anchor as well, and thus the child root offset should be zero.
child_root_offset = timedelta(days=0)
if endpoint_expr.end_event.startswith("-"):
child_anchor_time = "timestamp_at_start"
else:
child_anchor_time = "timestamp_at_end"
window_summary_df = (
aggregate_event_bound_window(predicates_df, endpoint_expr)
.with_columns(
pl.col("timestamp").alias("subtree_anchor_timestamp"),
pl.col(child_anchor_time).alias("child_anchor_timestamp"),
)
.drop("timestamp")
)
case _:
raise ValueError(f"Invalid endpoint expression: '{endpoint_expr}'")
# Step 2: Filter to valid subtree anchors
window_summary_df = window_summary_df.join(
subtree_anchor_realizations, on=["subject_id", "subtree_anchor_timestamp"], how="inner"
)
# Step 3: Filter to where constraints are valid
window_summary_df = check_constraints(child.constraints, window_summary_df)
# Step 4: Produce child anchor realizations
child_anchor_realizations = window_summary_df.select(
"subject_id",
pl.col("child_anchor_timestamp").alias("subtree_anchor_timestamp"),
).unique(maintain_order=True)
# Step 5: Recurse
recursive_result = extract_subtree(
child,
child_anchor_realizations,
predicates_df,
child_root_offset,
)
# Step 6: Join summaries and timestamps
# Step 6.1: Convert recursive_result up to subtree anchor space.
recursive_result = (
recursive_result.rename({"subtree_anchor_timestamp": "child_anchor_timestamp"})
.join(
window_summary_df.select("subject_id", "subtree_anchor_timestamp", "child_anchor_timestamp"),
on=["subject_id", "child_anchor_timestamp"],
how="left",
)
.drop("child_anchor_timestamp")
)
# Step 6.2: Summarize the observed window statistics and timestamps for eventual return.
for_return = window_summary_df.select(
"subject_id",
"subtree_anchor_timestamp",
pl.struct(
pl.lit(child.name).alias("window_name"),
"timestamp_at_start",
"timestamp_at_end",
*predicate_cols,
).alias(f"{child.name}_summary"),
)
recursive_results.append(
recursive_result.join(for_return, on=["subject_id", "subtree_anchor_timestamp"], how="left")
)
# Step 7: Join children recursive results where all children find a valid realization
all_children = recursive_results[0]
for df in recursive_results[1:]:
all_children = all_children.join(df, on=["subject_id", "subtree_anchor_timestamp"], how="inner")
return all_children