# Source code for aces.config
"""This module contains functions for loading and parsing the configuration file and subsequently building a
tree structure from the configuration."""
from __future__ import annotations
import dataclasses
import logging
import re
from dataclasses import field
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any
import networkx as nx
import polars as pl
import ruamel.yaml
from bigtree import Node
from .types import (
ANY_EVENT_COLUMN,
END_OF_RECORD_KEY,
START_OF_RECORD_KEY,
TemporalWindowBounds,
ToEventWindowBounds,
)
from .utils import parse_timedelta
if TYPE_CHECKING:
from collections import OrderedDict
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class PlainPredicateConfig:
    """A predicate defined directly over raw event data, without reference to other predicates.

    Args:
        code: The code to match. Either a plain string, or (for MEDS evaluation only) a single-key dictionary
            ``{"regex": <non-empty pattern>}`` or ``{"any": [<code>, ...]}``. For ESGPT evaluation, a string
            of the form ``"<column>//<value>"`` matches ``<value>`` in column ``<column>``, while a bare
            string names a column.
        value_min: Optional lower bound on the event's numeric value.
        value_max: Optional upper bound on the event's numeric value.
        value_min_inclusive: Whether ``value_min`` is inclusive. Any falsy value (including the default
            ``None``) makes the bound exclusive.
        value_max_inclusive: Whether ``value_max`` is inclusive. Any falsy value (including the default
            ``None``) makes the bound exclusive.
        static: Whether this predicate is static. Stored on the config; not read by this class's methods.
        other_cols: Extra ``{column_name: value}`` equality constraints that must all hold.
    """

    code: str | dict[str, Any]
    value_min: float | None = None
    value_max: float | None = None
    value_min_inclusive: bool | None = None
    value_max_inclusive: bool | None = None
    static: bool = False
    other_cols: dict[str, str] = field(default_factory=dict)

    def MEDS_eval_expr(self) -> pl.Expr:
        """Returns a Polars expression that evaluates this predicate for a MEDS formatted dataset.

        The ``code`` column is matched exactly (string ``code``), by regex (``{"regex": ...}``), or by
        membership (``{"any": [...]}``); any ``numeric_value`` bounds and ``other_cols`` equality constraints
        are then added, and multiple criteria are combined with ``pl.all_horizontal``.

        Returns:
            A boolean Polars expression.

        Raises:
            ValueError: If ``code`` is a dictionary that does not hold exactly one valid ``regex`` (non-empty
                string) or ``any`` (non-empty list of strings) specification.

        Note: The output syntax for the following examples is dependent on the polars version used. The
        expected outputs may vary depending on the installed polars version.

        Examples:
            >>> print(PlainPredicateConfig("BP//systolic", 120, 140, True, False).MEDS_eval_expr())
            [(col("code")) == ("BP//systolic")].all_horizontal([[(col("numeric_value")) >=
            (dyn int: 120)], [(col("numeric_value")) < (dyn int: 140)]])
            >>> cfg = PlainPredicateConfig("BP//systolic", value_min=120, value_min_inclusive=False)
            >>> print(cfg.MEDS_eval_expr())
            [(col("code")) == ("BP//systolic")].all_horizontal([[(col("numeric_value")) >
            (dyn int: 120)]])
            >>> cfg = PlainPredicateConfig("BP//systolic", value_max=140, value_max_inclusive=True)
            >>> print(cfg.MEDS_eval_expr())
            [(col("code")) == ("BP//systolic")].all_horizontal([[(col("numeric_value")) <=
            (dyn int: 140)]])
            >>> print(PlainPredicateConfig("BP//diastolic").MEDS_eval_expr())
            [(col("code")) == ("BP//diastolic")]
            >>> cfg = PlainPredicateConfig("BP//diastolic", other_cols={"chamber": "atrial"})
            >>> print(cfg.MEDS_eval_expr())
            [(col("code")) == ("BP//diastolic")].all_horizontal([[(col("chamber")) ==
            ("atrial")]])
            >>> PlainPredicateConfig(code={'regex': None, 'any': None}).MEDS_eval_expr()
            Traceback (most recent call last):
                ...
            ValueError: Only one of 'regex' or 'any' can be specified in the code field!
            Got: ['regex', 'any'].
            >>> PlainPredicateConfig(code={'foo': None}).MEDS_eval_expr()
            Traceback (most recent call last):
                ...
            ValueError: Invalid specification in the code field! Got: {'foo': None}.
            Expected one of 'regex', 'any'.
            >>> PlainPredicateConfig(code={'regex': ''}).MEDS_eval_expr()
            Traceback (most recent call last):
                ...
            ValueError: Invalid specification in the code field! Got: {'regex': ''}.
            Expected a non-empty string for 'regex'.
            >>> PlainPredicateConfig(code={'any': []}).MEDS_eval_expr()
            Traceback (most recent call last):
                ...
            ValueError: Invalid specification in the code field! Got: {'any': []}.
            Expected a list of strings for 'any'.
            >>> PlainPredicateConfig(code={'any': ['foo', 1]}).MEDS_eval_expr()
            Traceback (most recent call last):
                ...
            ValueError: Invalid specification in the code field! Got: {'any': ['foo', 1]}.
            Expected a list of strings for 'any'.
            >>> print(PlainPredicateConfig(code={'regex': '^foo.*'}).MEDS_eval_expr())
            col("code").str.contains(["^foo.*"])
            >>> print(PlainPredicateConfig(code={'regex': '^foo.*'}, value_min=120).MEDS_eval_expr())
            col("code").str.contains(["^foo.*"]).all_horizontal([[(col("numeric_value")) >
            (dyn int: 120)]])
            >>> print(PlainPredicateConfig(code={'any': ['foo', 'bar']}).MEDS_eval_expr())
            col("code").is_in([["foo", "bar"]])
        """
        criteria = []
        if isinstance(self.code, dict):
            if len(self.code) > 1:
                raise ValueError(
                    "Only one of 'regex' or 'any' can be specified in the code field! "
                    f"Got: {list(self.code.keys())}."
                )
            if "regex" in self.code:
                pattern = self.code["regex"]
                if not pattern or not isinstance(pattern, str):
                    raise ValueError(
                        "Invalid specification in the code field! "
                        f"Got: {self.code}. "
                        "Expected a non-empty string for 'regex'."
                    )
                criteria.append(pl.col("code").str.contains(pattern))
            elif "any" in self.code:
                codes = self.code["any"]
                # The error message promises a list of strings, so enforce element types as well.
                if not codes or not isinstance(codes, list) or not all(isinstance(c, str) for c in codes):
                    raise ValueError(
                        "Invalid specification in the code field! "
                        f"Got: {self.code}. "
                        "Expected a list of strings for 'any'."
                    )
                criteria.append(pl.col("code").is_in(codes))
            else:
                raise ValueError(
                    "Invalid specification in the code field! "
                    f"Got: {self.code}. "
                    "Expected one of 'regex', 'any'."
                )
        else:
            criteria.append(pl.col("code") == self.code)

        criteria.extend(self._bound_and_column_criteria("numeric_value"))
        return criteria[0] if len(criteria) == 1 else pl.all_horizontal(criteria)

    def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr:
        """Returns a Polars expression that evaluates this predicate for a ESGPT formatted dataset.

        A ``code`` of the form ``"<column>//<value>"`` matches ``<value>`` in ``<column>`` (everything after the
        first ``//`` is the value). The ``event_type`` column is split on ``"&"`` and matched on any component.
        A bare ``code`` with no value bounds tests that the named column is non-null; a bare ``code`` with
        bounds uses that column itself as the values column (overriding ``values_column``). Any
        ``other_cols`` constraints are always applied.

        Args:
            values_column: The column holding the values compared against ``value_min``/``value_max`` when
                ``code`` is of the ``"<column>//<value>"`` form.

        Returns:
            A boolean Polars expression.

        Raises:
            ValueError: If ``code`` is a dictionary (regex/any specifications are MEDS-only), or if a value
                bound is set but no values column is available.

        Note: The output syntax for the following examples is dependent on the polars version used. The
        expected outputs may vary depending on the installed polars version.

        Examples:
            >>> cfg = PlainPredicateConfig("HR", value_min=120, value_min_inclusive=False)
            >>> print(cfg.ESGPT_eval_expr("HR"))
            [(col("HR")) > (dyn int: 120)]
            >>> print(PlainPredicateConfig("BP//systolic", 120, 140, True, False).ESGPT_eval_expr("BP_value"))
            [(col("BP")) == ("systolic")].all_horizontal([[(col("BP_value")) >=
            (dyn int: 120)], [(col("BP_value")) < (dyn int: 140)]])
            >>> cfg = PlainPredicateConfig("BP//systolic", value_min=120, value_min_inclusive=False)
            >>> print(cfg.ESGPT_eval_expr("blood_pressure_value"))
            [(col("BP")) == ("systolic")].all_horizontal([[(col("blood_pressure_value")) >
            (dyn int: 120)]])
            >>> cfg = PlainPredicateConfig("BP//systolic", value_max=140, value_max_inclusive=True)
            >>> print(cfg.ESGPT_eval_expr("blood_pressure_value"))
            [(col("BP")) == ("systolic")].all_horizontal([[(col("blood_pressure_value")) <=
            (dyn int: 140)]])
            >>> print(PlainPredicateConfig("BP//diastolic").ESGPT_eval_expr())
            [(col("BP")) == ("diastolic")]
            >>> print(PlainPredicateConfig("event_type//ADMISSION").ESGPT_eval_expr())
            col("event_type").strict_cast(String).str.split(["&"]).list.contains(["ADMISSION"])
            >>> print(PlainPredicateConfig("BP//diastolic//atrial").ESGPT_eval_expr())
            [(col("BP")) == ("diastolic//atrial")]
            >>> print(PlainPredicateConfig("BP//diastolic", None, None).ESGPT_eval_expr())
            [(col("BP")) == ("diastolic")]
            >>> print(PlainPredicateConfig("BP").ESGPT_eval_expr())
            col("BP").is_not_null()
            >>> print(PlainPredicateConfig("BP//systole", other_cols={"chamber": "atrial"}).ESGPT_eval_expr())
            [(col("BP")) == ("systole")].all_horizontal([[(col("chamber")) == ("atrial")]])
            >>> PlainPredicateConfig("BP//systolic", value_min=120).ESGPT_eval_expr()
            Traceback (most recent call last):
                ...
            ValueError: Must specify a values column for ESGPT predicates with a value_min = 120
            >>> PlainPredicateConfig("BP//systolic", value_max=140).ESGPT_eval_expr()
            Traceback (most recent call last):
                ...
            ValueError: Must specify a values column for ESGPT predicates with a value_max = 140
        """
        if isinstance(self.code, dict):
            raise ValueError(
                "ESGPT predicates only support string codes, not 'regex' or 'any' specifications! "
                f"Got: {self.code}."
            )

        if "//" in self.code:
            # Only the first "//" separates the column name; the rest is kept as part of the value.
            measurement_name, code = self.code.split("//", 1)
            if measurement_name.lower() == "event_type":
                criteria = [pl.col("event_type").cast(pl.String).str.split("&").list.contains(code)]
            else:
                criteria = [pl.col(measurement_name) == code]
        elif self.value_min is None and self.value_max is None:
            # Bare column, no bounds: satisfied wherever the column is populated. `other_cols` is still
            # applied below (it was previously dropped by an early return here).
            criteria = [pl.col(self.code).is_not_null()]
        else:
            # Bare column with bounds: the column itself holds the bounded values.
            values_column = self.code
            criteria = []

        if values_column is None:
            if self.value_min is not None:
                raise ValueError(
                    f"Must specify a values column for ESGPT predicates with a value_min = {self.value_min}"
                )
            if self.value_max is not None:
                raise ValueError(
                    f"Must specify a values column for ESGPT predicates with a value_max = {self.value_max}"
                )

        criteria.extend(self._bound_and_column_criteria(values_column))
        return criteria[0] if len(criteria) == 1 else pl.all_horizontal(criteria)

    def _bound_and_column_criteria(self, values_column: str | None) -> list[pl.Expr]:
        """Builds the value-bound and ``other_cols`` criteria shared by the MEDS and ESGPT evaluators.

        Args:
            values_column: The column compared against ``value_min``/``value_max``. Callers must ensure it is
                not ``None`` whenever either bound is set.

        Returns:
            Boolean Polars expressions in the order: lower bound, upper bound, then ``other_cols``.
        """
        criteria = []
        if self.value_min is not None:
            if self.value_min_inclusive:
                criteria.append(pl.col(values_column) >= self.value_min)
            else:
                criteria.append(pl.col(values_column) > self.value_min)
        if self.value_max is not None:
            if self.value_max_inclusive:
                criteria.append(pl.col(values_column) <= self.value_max)
            else:
                criteria.append(pl.col(values_column) < self.value_max)
        if self.other_cols:
            criteria.extend(pl.col(col) == value for col, value in self.other_cols.items())
        return criteria

    @property
    def is_plain(self) -> bool:
        """Always ``True``: plain predicates are evaluated directly over the raw data."""
        return True
@dataclasses.dataclass
class DerivedPredicateConfig:
    """A configuration object for derived predicates, which are composed of multiple input predicates.

    Args:
        expr: The expression defining the derived predicate in terms of other predicates, written as
            ``and(P1, P2, ...)`` or ``or(P1, P2, ...)``.
        static: Whether this predicate is static. Stored on the config; not read by this class's methods.

    Attributes:
        is_and: Whether ``expr`` is an ``and(...)`` expression (set in ``__post_init__``).
        is_or: Whether ``expr`` is an ``or(...)`` expression (set in ``__post_init__``).
        input_predicates: The whitespace-stripped input predicate names, in order (set in
            ``__post_init__``).

    Raises:
        ValueError: If the expression is empty, does not start with 'and(' or 'or(' (and end with ')'), does
            not contain at least two input predicates, or contains an empty input predicate name.

    Examples:
        >>> pred = DerivedPredicateConfig("and(P1, P2, P3)")
        >>> pred.input_predicates
        ['P1', 'P2', 'P3']
        >>> pred = DerivedPredicateConfig("and()")
        Traceback (most recent call last):
            ...
        ValueError: Derived predicate expression must have at least two input predicates (comma separated).
        Got: 'and()'
        >>> pred = DerivedPredicateConfig("or(PA, PB)")
        >>> pred = DerivedPredicateConfig("PA + PB")
        Traceback (most recent call last):
            ...
        ValueError: Derived predicate expression must start with 'and(' or 'or('. Got: 'PA + PB'
        >>> pred = DerivedPredicateConfig("")
        Traceback (most recent call last):
            ...
        ValueError: Derived predicates must have a non-empty expression field.
        >>> pred = DerivedPredicateConfig("and(P1, , P2)")
        Traceback (most recent call last):
            ...
        ValueError: Derived predicate expression has an empty input predicate name. Got: 'and(P1, , P2)'
    """

    expr: str
    static: bool = False

    def __post_init__(self) -> None:
        if not self.expr:
            raise ValueError("Derived predicates must have a non-empty expression field.")

        self.is_and = self.expr.startswith("and(") and self.expr.endswith(")")
        self.is_or = self.expr.startswith("or(") and self.expr.endswith(")")
        if not (self.is_and or self.is_or):
            raise ValueError(
                f"Derived predicate expression must start with 'and(' or 'or('. Got: '{self.expr}'"
            )

        # Drop the operator prefix ("and(" or "or(") and the closing ")", then split on commas.
        prefix_len = len("and(") if self.is_and else len("or(")
        self.input_predicates = [x.strip() for x in self.expr[prefix_len:-1].split(",")]

        if len(self.input_predicates) < 2:
            raise ValueError(
                "Derived predicate expression must have at least two input predicates (comma separated). "
                f"Got: '{self.expr}'"
            )
        # Trailing or doubled commas (e.g. "or(A,)") would otherwise yield an empty predicate name.
        if not all(self.input_predicates):
            raise ValueError(
                "Derived predicate expression has an empty input predicate name. "
                f"Got: '{self.expr}'"
            )

    def eval_expr(self) -> pl.Expr:
        """Returns a Polars expression that evaluates this predicate against necessary dependent predicates.

        Each input predicate is read from the column of the same name and treated as satisfied when its
        value is greater than zero. The per-input results are combined with ``pl.all_horizontal`` for
        ``and(...)`` and ``pl.any_horizontal`` for ``or(...)``.

        Note: The output syntax for the following examples is dependent on the polars version used. The
        expected outputs may vary depending on the installed polars version.

        Examples:
            >>> print(DerivedPredicateConfig("and(P1, P2, P3)").eval_expr())
            [(col("P1")) > (dyn int: 0)].all_horizontal([[(col("P2")) >
            (dyn int: 0)], [(col("P3")) > (dyn int: 0)]])
            >>> print(DerivedPredicateConfig("or(PA, PB)").eval_expr())
            [(col("PA")) > (dyn int: 0)].any_horizontal([[(col("PB")) > (dyn int: 0)]])
        """
        satisfied = [pl.col(pred) > 0 for pred in self.input_predicates]
        if self.is_and:
            return pl.all_horizontal(satisfied)
        elif self.is_or:
            return pl.any_horizontal(satisfied)

    @property
    def is_plain(self) -> bool:
        """Always ``False``: derived predicates are computed from other predicates' columns."""
        return False
@dataclasses.dataclass
class WindowConfig:
    """A configuration object for defining a window in the task extraction process.

    This defines the boundary points and constraints for a window in the patient record in the task extraction
    process.

    Args:
        start: The boundary conditions for the start of the window. This (like ``end``) can either be `None`,
            in which case the window starts at the beginning of the patient record, or is expressed through a
            string language that expresses a relative startpoint to this window either in reference to (a)
            another window's start or end event, (b) this window's `end` event. In case (a), this window's end
            event must either be `None` or reference this window's start event, and in case (b), this window's
            end event must reference a different window's start or end event.

            The string language is as follows:
              - ``None``: The window starts at the beginning of the patient record.
              - ``$REFERENCED <- $PREDICATE`` or ``$REFERENCED -> $PREDICATE``: The window starts at the
                closest event satisfying the predicate ``$PREDICATE`` relative to the ``$REFERENCED`` event.
                Form ``$REFERENCED <- $PREDICATE`` means that the window starts at the closest event _prior
                to_ the ``$REFERENCED`` event that satisfies the predicate ``$PREDICATE``, and the other form
                is analogous but with the closest event _after_ the ``$REFERENCED`` event.
              - ``$REFERENCED +- timedelta``: The window starts at the ``$REFERENCED`` event plus or minus the
                specified timedelta. The timedelta is expressed through the string language specified in the
                `utils.parse_timedelta` function.
              - ``$REFERENCED``: The window starts at the ``$REFERENCED`` event.

            In all cases, the ``$REFERENCED`` event must be either
              - The name of another window's start or end event, as specified by ``$WINDOW_NAME.start`` or
                ``$WINDOW_NAME.end``.
              - This window's end event, as specified by ``end``.

            In the case that ``$REFERENCED`` is this window's end event, the window must be defined such that
            ``start`` would precede ``end`` in the order of the patient record (e.g., ``$PREDICATE -> end`` is
            invalid, and ``end - timedelta`` is invalid).
        end: The name of the event that ends the window. See the documentation for ``start`` for more details
            on the specification language.
        start_inclusive: Whether or not the start event is included in the window. Note that this term can not
            only dictate whether an event's counts are included in the summarization of the window, but also
            whether or not an event satisfying ``$PREDICATE`` can be used as the boundary of an event. E.g.,
            if we have that `start_inclusive=False` and the `end` field is equal to `start -> $PREDICATE`, and
            it so happens that the `start` event itself satisfies `$PREDICATE`, the fact that
            `start_inclusive=False` will mean that we do not consider the `start` event itself to be a valid
            start to any window that ends at the same `start` event, as its timestamp when considered as the
            prospective "window start timestamp" occurs "after" the effective timestamp of itself when
            considered as the `$PREDICATE` event that marks the window end given that `start_inclusive=False`
            and thus we will think of the window as truly starting an iota after the timestamp of the `start`
            event itself.
        end_inclusive: Whether or not the end event is included in the window.
        has: A dictionary of predicates that must be present in the window, mapped to strings of the form
            `"(min_valid, max_valid)"` that define the valid range of the count of observations of the named
            predicate that must be found in a window for it to be considered valid. Either `min_valid` or
            `max_valid` constraints can be `None`, in which case those endpoints are left unconstrained.
            Likewise, unreferenced predicates are also left unconstrained. Note that as predicate counts are
            always integral, this specification does not need an additional inclusive/exclusive endpoint
            field, as one can simply increment the bound by one in the appropriate direction to achieve the
            result. Instead, this bound is always interpreted to be inclusive, so a window would satisfy the
            constraint for predicate `name` with constraint `name: (1, 2)` if the count of observations of
            predicate `name` in a window was either 1 or 2. All constraints in the dictionary must be
            satisfied on a window for it to be included. After initialization, ``self.has`` holds a new
            dictionary of parsed `(min_valid, max_valid)` tuples; the dictionary passed in is not modified,
            and already-parsed 2-tuples are accepted as values. Passing `None` is treated as `{}`.
        label: A string that specifies the name of a predicate to be used as the label for the task. The
            predicate count of the window this field is specified in will be extracted as a column in the
            final result. Hence, there can only be one 'label' per TaskExtractorConfig. If more than one
            'label' is specified, an error is raised. If the specified 'label' is not a defined predicate,
            an error is also raised. If no 'label' is specified, there will not be a 'label' column.
        index_timestamp: A string that is either 'start' or 'end' and is used to index result rows. If it is
            defined, there will be an 'index_timestamp' column in the result with its values equal to the
            'start' or 'end' timestamp of the window in which it was specified. Usually, this will be
            specified to indicate the time of prediction for the task, which is often the 'end' of the input
            window. There can only be one 'index_timestamp' per TaskExtractorConfig. If more than one
            'index_timestamp' is specified, an error is raised. If the specified 'index_timestamp' is not
            'start' or 'end', an error is also raised. If no 'index_timestamp' is defined, there will be no
            'index_timestamp' column.

    Raises:
        ValueError: If the window is misconfigured in any of a variety of ways; see below for examples.

    Examples:
        >>> input_window = WindowConfig(
        ...     start=None,
        ...     end="trigger + 2 days",
        ...     start_inclusive=True,
        ...     end_inclusive=True,
        ...     has={"_ANY_EVENT": "(5, None)"},
        ...     index_timestamp="end",
        ... )
        >>> input_window.referenced_event
        ('trigger',)
        >>> # The only referenced predicate is the implicit `_ANY_EVENT` constraint; the record-start
        >>> # boundary implied by `start=None` is excluded.
        >>> sorted(input_window.referenced_predicates)
        ['_ANY_EVENT']
        >>> input_window.start_endpoint_expr
        ToEventWindowBounds(left_inclusive=True,
                            end_event='-_RECORD_START',
                            right_inclusive=True,
                            offset=datetime.timedelta(0))
        >>> input_window.end_endpoint_expr
        TemporalWindowBounds(left_inclusive=False,
                             window_size=datetime.timedelta(days=2),
                             right_inclusive=False,
                             offset=datetime.timedelta(0))
        >>> input_window.root_node
        'end'
        >>> gap_window = WindowConfig(
        ...     start="input.end",
        ...     end="start + 24h",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        >>> gap_window.referenced_event
        ('input', 'end')
        >>> sorted(gap_window.referenced_predicates)
        ['death', 'discharge']
        >>> gap_window.start_endpoint_expr is None
        True
        >>> gap_window.end_endpoint_expr
        TemporalWindowBounds(left_inclusive=False,
                             window_size=datetime.timedelta(days=1),
                             right_inclusive=True,
                             offset=datetime.timedelta(0))
        >>> gap_window.root_node
        'start'
        >>> gap_window = WindowConfig(
        ...     start="input.end",
        ...     end="start + 0h",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        >>> gap_window.referenced_event
        ('input', 'end')
        >>> sorted(gap_window.referenced_predicates)
        ['death', 'discharge']
        >>> gap_window.start_endpoint_expr is None
        True
        >>> gap_window.end_endpoint_expr is None
        True
        >>> gap_window.root_node
        'start'
        >>> constraints = {"discharge": "(None, 0)"}
        >>> window = WindowConfig(
        ...     start="input.end", end="start + 24h", start_inclusive=False, end_inclusive=True,
        ...     has=constraints,
        ... )
        >>> constraints
        {'discharge': '(None, 0)'}
        >>> window.has
        {'discharge': (None, 0)}
        >>> WindowConfig(
        ...     start="input.end", end="start + 24h", start_inclusive=False, end_inclusive=True,
        ...     has=window.has,
        ... ).has
        {'discharge': (None, 0)}
        >>> target_window = WindowConfig(
        ...     start="gap.end",
        ...     end="start -> discharge_or_death",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={}
        ... )
        >>> target_window.referenced_event
        ('gap', 'end')
        >>> sorted(target_window.referenced_predicates)
        ['discharge_or_death']
        >>> target_window.start_endpoint_expr is None
        True
        >>> target_window.end_endpoint_expr
        ToEventWindowBounds(left_inclusive=False,
                            end_event='discharge_or_death',
                            right_inclusive=True,
                            offset=datetime.timedelta(0))
        >>> target_window.root_node
        'start'
        >>> target_window = WindowConfig(
        ...     start="end",
        ...     end="gap.end <- discharge_or_death",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={}
        ... )
        >>> target_window.referenced_event
        ('gap', 'end')
        >>> sorted(target_window.referenced_predicates)
        ['discharge_or_death']
        >>> target_window.start_endpoint_expr is None
        True
        >>> target_window.end_endpoint_expr
        ToEventWindowBounds(left_inclusive=False,
                            end_event='-discharge_or_death',
                            right_inclusive=False,
                            offset=datetime.timedelta(0))
        >>> target_window.root_node
        'end'
        >>> invalid_window = WindowConfig(
        ...     start="gap.end gap.start",
        ...     end="start -> discharge_or_death",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Window boundary reference must be either a valid alphanumeric/'_' string or a reference to
        another window's start or end event, formatted as a valid alphanumeric/'_' string, followed by
        '.start' or '.end'.
        Got: 'gap.end gap.start'
        >>> invalid_window = WindowConfig(
        ...     start="input",
        ...     end="start window -> discharge_or_death",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Window boundary reference must be either a valid alphanumeric/'_' string or a reference
        to another window's start or end event, formatted as a valid alphanumeric/'_' string, followed by
        '.start' or '.end'. Got: 'start window'
        >>> invalid_window = WindowConfig(
        ...     start="input",
        ...     end="window.foo -> discharge_or_death",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Window boundary reference must be either a valid alphanumeric/'_' string or a reference
        to another window's start or end event, formatted as a valid alphanumeric/'_' string, followed by
        '.start' or '.end'. Got: 'window.foo'
        >>> invalid_window = WindowConfig(
        ...     start=None, end=None, start_inclusive=True, end_inclusive=True, has={}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Window cannot progress from the start of the record to the end of the record.
        >>> invalid_window = WindowConfig(
        ...     start="input.end",
        ...     end="start - 2d",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Window start will not occur before window end! Got: input.end -> start - 2d
        >>> invalid_window = WindowConfig(
        ...     start="end -> predicate",
        ...     end="input.end",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Window start will not occur before window end! Got: end -> predicate -> input.end
        >>> invalid_window = WindowConfig(
        ...     start="end - 24h", end="start + 1d", start_inclusive=True, end_inclusive=True, has={}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Exactly one of the start or end of the window must reference the other.
        Got: end - 24h -> start + 1d
        >>> invalid_window = WindowConfig(
        ...     start="input.end",
        ...     end="input.end + 2d",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Exactly one of the start or end of the window must reference the other.
        Got: input.end -> input.end + 2d
        >>> invalid_window = WindowConfig(
        ...     start="input.end",
        ...     end="start + -24h",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Window boundary cannot contain both '+' and '-' operators.
        >>> invalid_window = WindowConfig(
        ...     start="input.end",
        ...     end="start + invalid time string.",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(None, 0)", "death": "(None, 0)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Failed to parse timedelta from window offset for 'invalid time string.'
        >>> target_window = WindowConfig(
        ...     start="gap.end",
        ...     end="start <-> discharge_or_death",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Window boundary cannot contain both '->' and '<-' operators.
        >>> invalid_window = WindowConfig(
        ...     start="input.end",
        ...     end="input.end + 2d",
        ...     start_inclusive=False,
        ...     end_inclusive=True,
        ...     has={"discharge": "(0)", "death": "(None, 0)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Invalid constraint format: discharge.
        Expected format: '(min, max)'. Got: '(0)'
        >>> invalid_window = WindowConfig(
        ...     start="input.end", end="start + 24h", start_inclusive=False, end_inclusive=True,
        ...     has={"discharge": "(one, None)"}
        ... )
        Traceback (most recent call last):
            ...
        ValueError: Invalid constraint bound for discharge: expected an integer or 'None'. Got: '(one, None)'
    """

    start: str | None
    end: str | None
    start_inclusive: bool
    end_inclusive: bool
    # Raw "(min, max)" strings on input; replaced by parsed (min, max) tuples in __post_init__.
    has: dict[str, str] = field(default_factory=dict)
    label: str | None = None
    index_timestamp: str | None = None

    @classmethod
    def _check_reference(cls, reference: str) -> None:
        """Checks to ensure referenced events are valid.

        A valid reference is either a bare ``\\w+`` name (e.g. ``start``, ``end``, ``trigger``) or
        ``<window>.start`` / ``<window>.end`` with a ``\\w+`` window name.

        Raises:
            ValueError: If ``reference`` has any other form.
        """
        err_str = (
            "Window boundary reference must be either a valid alphanumeric/'_' string "
            "or a reference to another window's start or end event, formatted as a valid "
            f"alphanumeric/'_' string, followed by '.start' or '.end'. Got: '{reference}'"
        )
        if "." in reference:
            if reference.count(".") > 1:
                raise ValueError(err_str)
            window, event = reference.split(".")
            if event not in {"start", "end"} or not re.match(r"^\w+$", window):
                raise ValueError(err_str)
        elif not re.match(r"^\w+$", reference):
            raise ValueError(err_str)

    @classmethod
    def _parse_constraint(cls, name: str, constraint: Any) -> tuple[int | None, int | None]:
        """Parses one ``has`` constraint into a ``(min_valid, max_valid)`` tuple.

        Args:
            name: The constrained predicate's name (used in error messages).
            constraint: Either a string like ``"(1, None)"`` or an already-parsed 2-tuple/list. ``"None"``,
                ``""``, and ``None`` all denote an unconstrained endpoint.

        Returns:
            A 2-tuple whose entries are ``int`` or ``None``.

        Raises:
            ValueError: If the constraint does not have exactly two endpoints, or an endpoint is neither an
                integer nor ``None``.
        """
        if isinstance(constraint, str):
            bounds = [element.strip() for element in constraint.strip("()").split(",")]
        elif isinstance(constraint, (tuple, list)):
            bounds = list(constraint)
        else:
            bounds = None

        if bounds is None or len(bounds) != 2:
            raise ValueError(
                f"Invalid constraint format: {name}. "
                f"Expected format: '(min, max)'. Got: '{constraint}'"
            )

        parsed = []
        for bound in bounds:
            if bound is None or bound in ("None", ""):
                parsed.append(None)
                continue
            try:
                parsed.append(int(bound))
            except (TypeError, ValueError) as e:
                raise ValueError(
                    f"Invalid constraint bound for {name}: expected an integer or 'None'. "
                    f"Got: '{constraint}'"
                ) from e
        return tuple(parsed)

    @classmethod
    def _parse_boundary(cls, boundary: str) -> dict[str, Any]:
        """Parses one window boundary string into its referenced event and relative offset.

        Args:
            boundary: A boundary in the string language described in the class docstring.

        Returns:
            A dictionary with keys ``referenced`` (the referenced event name), ``offset`` (the offset string,
            or ``None``), ``event_bound`` (the bounding predicate, prefixed with ``-`` when it is sought
            before the reference, or ``None``), and ``occurs_before`` (whether the boundary lies before the
            referenced event, or ``None`` when it coincides with it, including a zero offset).

        Raises:
            ValueError: If the boundary mixes ``->``/``<-`` or ``+``/``-`` operators, references an invalid
                event, or has an offset that cannot be parsed as a timedelta.
        """
        if "->" in boundary or "<-" in boundary:
            if "->" in boundary and "<-" in boundary:
                raise ValueError("Window boundary cannot contain both '->' and '<-' operators.")
            elif "->" in boundary:
                ref, predicate = (x.strip() for x in boundary.split("->"))
            else:
                ref, predicate = (x.strip() for x in boundary.split("<-"))
                # A leading "-" marks a predicate sought backwards from the reference.
                predicate = "-" + predicate
            cls._check_reference(ref)
            return {
                "referenced": ref,
                "offset": None,
                "event_bound": predicate,
                "occurs_before": "-" in predicate,
            }
        elif "+" in boundary or "-" in boundary:
            if "+" in boundary and "-" in boundary:
                raise ValueError("Window boundary cannot contain both '+' and '-' operators.")
            elif "+" in boundary:
                ref, offset = (x.strip() for x in boundary.split("+"))
            else:
                ref, offset = (x.strip() for x in boundary.split("-"))
                offset = "-" + offset
            cls._check_reference(ref)
            try:
                parsed_offset = parse_timedelta(offset)
            except (ValueError, TypeError) as e:
                raise ValueError(f"Failed to parse timedelta from window offset for '{offset}'") from e
            if parsed_offset == timedelta(0):
                # A zero offset is treated exactly like a bare reference to the same event.
                logger.warning("Window offset for %s is zero; this may not be intended.", boundary)
                return {"referenced": ref, "offset": None, "event_bound": None, "occurs_before": None}
            return {"referenced": ref, "offset": offset, "event_bound": None, "occurs_before": "-" in offset}
        else:
            ref = boundary.strip()
            cls._check_reference(ref)
            return {"referenced": ref, "offset": None, "event_bound": None, "occurs_before": None}

    def __post_init__(self) -> None:
        # Parse the `has` constraints into (min, max) tuples. A new dict is built so the caller's dictionary
        # is never mutated, and already-parsed tuples are accepted so a config can be re-instantiated from
        # another config's `has` (e.g. via `dataclasses.replace`).
        if self.has is None:
            self.has = {}
        self.has = {name: self._parse_constraint(name, constraint) for name, constraint in self.has.items()}

        if self.start is None and self.end is None:
            raise ValueError("Window cannot progress from the start of the record to the end of the record.")

        if self.start is None:
            # An unset start means "search backwards from this window's end to the record start".
            self._parsed_start = {
                "referenced": "end",
                "offset": None,
                "event_bound": f"-{START_OF_RECORD_KEY}",
                "occurs_before": True,
            }
        else:
            self._parsed_start = self._parse_boundary(self.start)

        if self.end is None:
            # An unset end means "search forwards from this window's start to the record end".
            self._parsed_end = {
                "referenced": "start",
                "offset": None,
                "event_bound": END_OF_RECORD_KEY,
                "occurs_before": False,
            }
        else:
            self._parsed_end = self._parse_boundary(self.end)

        if self._parsed_start["referenced"] == "end" and self._parsed_end["referenced"] == "start":
            raise ValueError(
                "Exactly one of the start or end of the window must reference the other. "
                f"Got: {self.start} -> {self.end}"
            )
        elif self._parsed_start["referenced"] == "end":
            self._start_references_end = True
            # We use `is False` because it may be None, which is distinct from True or False
            if self._parsed_start["occurs_before"] is False:
                raise ValueError(
                    f"Window start will not occur before window end! Got: {self.start} -> {self.end}"
                )
        elif self._parsed_end["referenced"] == "start":
            self._start_references_end = False
            # We use `is True` because it may be None, which is distinct from True or False
            if self._parsed_end["occurs_before"] is True:
                raise ValueError(
                    f"Window start will not occur before window end! Got: {self.start} -> {self.end}"
                )
        else:
            raise ValueError(
                "Exactly one of the start or end of the window must reference the other. "
                f"Got: {self.start} -> {self.end}"
            )

    @property
    def root_node(self) -> str:
        """Returns 'start' if the end of the window is defined relative to the start and 'end' otherwise."""
        return "end" if self._start_references_end else "start"

    @property
    def referenced_event(self) -> tuple[str, ...]:
        """The external event this window is anchored to, split on ``.`` (e.g. ``('input', 'end')``)."""
        if self._start_references_end:
            return tuple(self._parsed_end["referenced"].split("."))
        else:
            return tuple(self._parsed_start["referenced"].split("."))

    @property
    def constraint_predicates(self) -> set[str]:
        """The names of the predicates constrained by ``has``."""
        return set(self.has)

    @property
    def referenced_predicates(self) -> set[str]:
        """All predicates this window depends on: ``has`` keys plus boundary predicates.

        The record start/end sentinel keys are excluded, and any ``-`` direction markers are removed.
        """
        predicates = set(self.has.keys())
        if self._parsed_start["event_bound"]:
            predicates.add(self._parsed_start["event_bound"].replace("-", ""))
        if self._parsed_end["event_bound"]:
            predicates.add(self._parsed_end["event_bound"].replace("-", ""))
        predicates -= {START_OF_RECORD_KEY, END_OF_RECORD_KEY}
        return predicates

    @property
    def start_endpoint_expr(self) -> None | ToEventWindowBounds | TemporalWindowBounds:
        """The bound expression for the window start, or ``None`` if it coincides with its reference."""
        if self._start_references_end:
            # The start is defined relative to this window's own end, so this expression describes the
            # window itself and uses the window's start/end inclusivity.
            left_inclusive = self.start_inclusive
            right_inclusive = self.end_inclusive
        else:
            # The start is anchored to an external event, so this expression does not describe the window
            # itself; its inclusivity flags are set to False.
            left_inclusive = False
            right_inclusive = False

        if self._parsed_start["event_bound"]:
            return ToEventWindowBounds(
                end_event=self._parsed_start["event_bound"],
                left_inclusive=left_inclusive,
                right_inclusive=right_inclusive,
            )
        elif self._parsed_start["offset"]:
            return TemporalWindowBounds(
                window_size=parse_timedelta(self._parsed_start["offset"]),
                left_inclusive=left_inclusive,
                right_inclusive=right_inclusive,
            )
        else:
            return None

    @property
    def end_endpoint_expr(self) -> None | ToEventWindowBounds | TemporalWindowBounds:
        """The bound expression for the window end, or ``None`` if it coincides with its reference."""
        if self._start_references_end:
            # The end is anchored to an external event, so this expression does not describe the window
            # itself; its inclusivity flags are set to False.
            left_inclusive = False
            right_inclusive = False
        else:
            # The end is defined relative to this window's own start, so this expression describes the
            # window itself and uses the window's start/end inclusivity.
            left_inclusive = self.start_inclusive
            right_inclusive = self.end_inclusive

        if self._parsed_end["event_bound"]:
            return ToEventWindowBounds(
                end_event=self._parsed_end["event_bound"],
                left_inclusive=left_inclusive,
                right_inclusive=right_inclusive,
            )
        elif self._parsed_end["offset"]:
            return TemporalWindowBounds(
                window_size=parse_timedelta(self._parsed_end["offset"]),
                left_inclusive=left_inclusive,
                right_inclusive=right_inclusive,
            )
        else:
            return None
@dataclasses.dataclass
class EventConfig:
    """Configuration for the trigger event that anchors task extraction.

    The trigger is identified by a single simple predicate. The events matching it serve as the root of the
    window tree, which must be localizable from them in order to identify valid realizations of the tree.

    Args:
        predicate: The name of the predicate whose matching events trigger task extraction.

    Examples:
        >>> trigger = EventConfig(predicate="event_type//ADMISSION")
        >>> trigger.predicate
        'event_type//ADMISSION'
        >>> trigger == EventConfig("event_type//ADMISSION")
        True
    """

    predicate: str
[docs]@dataclasses.dataclass
class TaskExtractorConfig:
    """A configuration object for parsing the plain-data stored in a task extractor config.

    This class can be serialized to and deserialized from a YAML file, and is largely a collection of
    utilities to parse, validate, and leverage task extraction configuration data in practice. There is no
    state stored in this class that is not present or recoverable from the source YAML file on disk. It also
    can be read from a simplified, "user-friendly" language, which can also be stored on or read from disk,
    which is ultimately parsed into the expansive, full specification contained in the YAML file referenced
    above.

    Construction validates the predicates and the windows and builds the predicate DAG and the window
    tree. This happens in ``__post_init__``.

    Args:
        predicates: A dictionary of predicate configurations, stored as either plain or derived predicate
            configuration objects (which are simple dataclasses with utility functions over plain
            dictionaries).
        trigger: The event configuration that triggers the task extraction process. This is a simple
            dataclass with a single field, the name of the predicate that triggers the task extraction and
            serves as the root of the window tree.
        windows: A dictionary of window configurations. Each window configuration is a simple dataclass
            that can be materialized to/from a simple, POD dictionary.
        label_window: The name of the (at most one) window that defines a ``label``. It is recomputed during
            window validation, so any value passed in is overwritten.
        index_timestamp_window: The name of the (at most one) window that defines an ``index_timestamp``.
            It is recomputed during window validation, so any value passed in is overwritten.

    Raises:
        ValueError: If any window or predicate names are not composed of alphanumeric or "_" characters, or
            if a predicate, label, or index-timestamp configuration is otherwise invalid (see below).
        KeyError: If a derived predicate, a window, or the trigger references something that is not
            defined (see below).

    Examples:
        >>> from bigtree import print_tree
        >>> predicates = {
        ...     "admission": PlainPredicateConfig("admission"),
        ...     "discharge": PlainPredicateConfig("discharge"),
        ...     "death": PlainPredicateConfig("death"),
        ...     "death_or_discharge": DerivedPredicateConfig("or(death, discharge)"),
        ...     "diabetes_icd9": PlainPredicateConfig("ICD9CM//250.02"),
        ...     "diabetes_icd10": PlainPredicateConfig("ICD10CM//E11.65"),
        ...     "diabetes": DerivedPredicateConfig("or(diabetes_icd9, diabetes_icd10)"),
        ...     "diabetes_and_discharge": DerivedPredicateConfig("and(diabetes, discharge)"),
        ... }
        >>> trigger = EventConfig("admission")
        >>> windows = {
        ...     "input": WindowConfig(
        ...         start=None,
        ...         end="trigger + 24h",
        ...         start_inclusive=True,
        ...         end_inclusive=True,
        ...         has={"_ANY_EVENT": "(32, None)"},
        ...         index_timestamp="end",
        ...     ),
        ...     "gap": WindowConfig(
        ...         start="input.end",
        ...         end="start + 24h",
        ...         start_inclusive=False,
        ...         end_inclusive=True,
        ...         has={"death_or_discharge": "(None, 0)", "admission": "(None, 0)"},
        ...     ),
        ...     "target": WindowConfig(
        ...         start="gap.end",
        ...         end="start -> death_or_discharge",
        ...         start_inclusive=False,
        ...         end_inclusive=True,
        ...         has={},
        ...         label="death",
        ...     ),
        ... }
        >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows)
        >>> print(config.plain_predicates)
        {'admission': PlainPredicateConfig(code='admission',
                                           value_min=None,
                                           value_max=None,
                                           value_min_inclusive=None,
                                           value_max_inclusive=None,
                                           static=False,
                                           other_cols={}),
         'discharge': PlainPredicateConfig(code='discharge',
                                           value_min=None,
                                           value_max=None,
                                           value_min_inclusive=None,
                                           value_max_inclusive=None,
                                           static=False,
                                           other_cols={}),
         'death': PlainPredicateConfig(code='death',
                                       value_min=None,
                                       value_max=None,
                                       value_min_inclusive=None,
                                       value_max_inclusive=None,
                                       static=False,
                                       other_cols={}),
         'diabetes_icd9': PlainPredicateConfig(code='ICD9CM//250.02',
                                               value_min=None,
                                               value_max=None,
                                               value_min_inclusive=None,
                                               value_max_inclusive=None,
                                               static=False,
                                               other_cols={}),
         'diabetes_icd10': PlainPredicateConfig(code='ICD10CM//E11.65',
                                                value_min=None,
                                                value_max=None,
                                                value_min_inclusive=None,
                                                value_max_inclusive=None,
                                                static=False,
                                                other_cols={})}
        >>> print(config.label_window)
        target
        >>> print(config.index_timestamp_window)
        input
        >>> print(config.derived_predicates)
        {'death_or_discharge': DerivedPredicateConfig(expr='or(death, discharge)', static=False),
         'diabetes': DerivedPredicateConfig(expr='or(diabetes_icd9, diabetes_icd10)', static=False),
         'diabetes_and_discharge': DerivedPredicateConfig(expr='and(diabetes, discharge)', static=False)}
        >>> print(nx.write_network_text(config.predicates_DAG))
        ╟── death
        ╎   └─╼ death_or_discharge ╾ discharge
        ╟── discharge
        ╎   ├─╼ diabetes_and_discharge ╾ diabetes
        ╎   └─╼ ...
        ╟── diabetes_icd9
        ╎   └─╼ diabetes ╾ diabetes_icd10
        ╎       └─╼ ...
        ╙── diabetes_icd10
            └─╼ ...
        >>> print_tree(config.window_tree)
        trigger
        └── input.end
            ├── input.start
            └── gap.end
                └── target.end

        Configs will error out in various ways when passed inappropriate arguments:

        >>> config_path = "/foo/non_existent_file.yaml"
        >>> cfg = TaskExtractorConfig.load(config_path)
        Traceback (most recent call last):
            ...
        FileNotFoundError: Cannot load missing configuration file /foo/non_existent_file.yaml!
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as f:
        ...     config_path = Path(f.name)
        ...     cfg = TaskExtractorConfig.load(config_path)
        Traceback (most recent call last):
            ...
        ValueError: Only supports reading from '.yaml'. Got: '.txt' in ....txt'.
        >>> predicates_path = "/foo/non_existent_predicates.yaml"
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
        ...     config_path = Path(f.name)
        ...     cfg = TaskExtractorConfig.load(config_path, predicates_path)
        Traceback (most recent call last):
            ...
        FileNotFoundError: Cannot load missing predicates file /foo/non_existent_predicates.yaml!
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as f:
        ...     predicates_path = Path(f.name)
        ...     with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f2:
        ...         config_path = Path(f2.name)
        ...         cfg = TaskExtractorConfig.load(config_path, predicates_path)
        Traceback (most recent call last):
            ...
        ValueError: Only supports reading from '.yaml'. Got: '.txt' in ....txt'.
        >>> data = {
        ...     'predicates': {},
        ...     'trigger': {},
        ...     'foo': {}
        ... }
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
        ...     config_path = Path(f.name)
        ...     yaml.dump(data, f)
        ...     cfg = TaskExtractorConfig.load(config_path)
        Traceback (most recent call last):
            ...
        ValueError: Unrecognized keys in configuration file: 'foo'
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
        ...     predicates_path = Path(f.name)
        ...     yaml.dump(data, f)
        ...     with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f2:
        ...         config_path = Path(f2.name)
        ...         cfg = TaskExtractorConfig.load(config_path, predicates_path)
        Traceback (most recent call last):
            ...
        ValueError: Unrecognized keys in configuration file: 'foo, trigger'
        >>> predicates = {"foo bar": PlainPredicateConfig("foo")}
        >>> trigger = EventConfig("foo")
        >>> TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={})
        Traceback (most recent call last):
            ...
        ValueError: Predicate name 'foo bar' is invalid; must be composed of alphanumeric or '_' characters.
        >>> predicates = {"foo": str("foo")}
        >>> trigger = EventConfig("foo")
        >>> TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={})
        Traceback (most recent call last):
            ...
        ValueError: Invalid predicate configuration for 'foo': foo. Must be either a PlainPredicateConfig or
        DerivedPredicateConfig object. Got: <class 'str'>
        >>> predicates = {
        ...     "foo": PlainPredicateConfig("foo"),
        ...     "foobar": DerivedPredicateConfig("or(foo, bar)"),
        ... }
        >>> trigger = EventConfig("foo")
        >>> TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={})
        Traceback (most recent call last):
            ...
        KeyError: "Missing 1 relationships: Derived predicate 'foobar' references undefined predicate 'bar'"
        >>> predicates = {"foo": PlainPredicateConfig("foo")}
        >>> trigger = EventConfig("foo")
        >>> windows = {"foo bar": WindowConfig("gap.end", "start + 24h", True, True)}
        >>> TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows)
        Traceback (most recent call last):
            ...
        ValueError: Window name 'foo bar' is invalid; must be composed of alphanumeric or '_' characters.
        >>> windows = {"foo": WindowConfig("gap.end", "start + 24h", True, True, {}, "bar")}
        >>> TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows)
        Traceback (most recent call last):
            ...
        ValueError: Label must be one of the defined predicates. Got: bar for window 'foo'
        >>> windows = {"foo": WindowConfig("gap.end", "start + 24h", True, True, {}, "foo", "bar")}
        >>> TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows)
        Traceback (most recent call last):
            ...
        ValueError: Index timestamp must be either 'start' or 'end'. Got: bar for window 'foo'
        >>> windows = {
        ...     "foo": WindowConfig("gap.end", "start + 24h", True, True, {}, "foo"),
        ...     "bar": WindowConfig("gap.end", "start + 24h", True, True, {}, "foo")
        ... }
        >>> TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows)
        Traceback (most recent call last):
            ...
        ValueError: Only one window can be labeled, found 2 labeled windows: foo, bar
        >>> windows = {
        ...     "foo": WindowConfig("gap.end", "start + 24h", True, True, {}, "foo", "start"),
        ...     "bar": WindowConfig("gap.end", "start + 24h", True, True, {}, index_timestamp="start")
        ... }
        >>> TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows)
        Traceback (most recent call last):
            ...
        ValueError: Only the 'start'/'end' of one window can be used as the index timestamp, found
        2 windows with index_timestamp: foo, bar
        >>> predicates = {"foo": PlainPredicateConfig("foo")}
        >>> TaskExtractorConfig(predicates=predicates, trigger=EventConfig("bar"), windows={})
        Traceback (most recent call last):
            ...
        KeyError: "Trigger event predicate 'bar' not found in predicates: foo"
    """

    predicates: dict[str, PlainPredicateConfig | DerivedPredicateConfig]
    trigger: EventConfig
    # NOTE(review): the annotation admits None, but window validation iterates `self.windows` directly, so
    # constructing with None appears to fail; `load` substitutes {} when the file defines no windows.
    windows: dict[str, WindowConfig] | None
    # Both fields below are assigned unconditionally during window validation (from the windows'
    # `label` / `index_timestamp` settings), so values passed to the constructor are overwritten.
    label_window: str | None = None
    index_timestamp_window: str | None = None
@classmethod
def load(
    cls: type[TaskExtractorConfig],
    config_path: str | Path,
    predicates_path: str | Path | None = None,
) -> TaskExtractorConfig:
    """Load a task configuration from a YAML file and return it as a ``TaskExtractorConfig``.

    Predicates and patient demographics defined in ``predicates_path`` (if given) override entries of the
    same name in ``config_path``. Only the predicates referenced by the trigger, by the windows, or by a
    window label are parsed. This includes references made directly and those made through a derived
    predicate's expression. Every patient demographic is kept and parsed as a static predicate.

    Args:
        cls: The TaskExtractorConfig class that is instantiated.
        config_path: The path to which a configuration object will be read from in YAML form.
        predicates_path: The path to which a predicates configuration object will be read from in
            YAML form. Used to override predicates in the original configuration file.

    Returns:
        The validated ``TaskExtractorConfig`` built from the file contents.

    Raises:
        FileNotFoundError: If either file does not exist.
        ValueError: If either file is not a ".yaml" file or is empty / not a top-level mapping. Also raised
            if a file contains unrecognized top-level keys, or if a predicate is not defined as a mapping.
        KeyError: If the trigger, a window, a window label, or a derived predicate references a predicate
            that is not defined.

    Examples:
        >>> yaml = ruamel.yaml.YAML(typ="safe", pure=True)
        >>> config_dict = {
        ...     "metadata": {'description': 'A test configuration file'},
        ...     "description": 'this is a test',
        ...     "predicates": {"admission": {"code": "admission"}},
        ...     "trigger": "admission",
        ...     "windows": {
        ...         "start": {
        ...             "start": None, "end": "trigger + 24h", "start_inclusive": True,
        ...             "end_inclusive": True,
        ...         }
        ...     },
        ... }
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
        ...     config_path = Path(f.name)
        ...     yaml.dump(config_dict, f)
        ...     cfg = TaskExtractorConfig.load(config_path)
        >>> cfg
        TaskExtractorConfig(predicates={'admission': PlainPredicateConfig(code='admission',
                                value_min=None, value_max=None,
                                value_min_inclusive=None, value_max_inclusive=None,
                                static=False, other_cols={})},
                            trigger=EventConfig(predicate='admission'),
                            windows={'start': WindowConfig(start=None, end='trigger + 24h',
                                start_inclusive=True, end_inclusive=True, has={},
                                label=None, index_timestamp=None)},
                            label_window=None, index_timestamp_window=None)
        >>> predicates_dict = {
        ...     "metadata": {'description': 'A test predicates file'},
        ...     "description": 'this is a test',
        ...     "patient_demographics": {"brown_eyes": {"code": "eye_color//BR"}},
        ...     "predicates": {"admission": {"code": "admission"}},
        ... }
        >>> no_predicates_config = {k: v for k, v in config_dict.items() if k != "predicates"}
        >>> with (tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as config_fp,
        ...       tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as pred_fp):
        ...     config_path = Path(config_fp.name)
        ...     pred_path = Path(pred_fp.name)
        ...     yaml.dump(no_predicates_config, config_fp)
        ...     yaml.dump(predicates_dict, pred_fp)
        ...     cfg = TaskExtractorConfig.load(config_path, pred_path)
        >>> cfg
        TaskExtractorConfig(predicates={'admission': PlainPredicateConfig(code='admission',
                                value_min=None, value_max=None,
                                value_min_inclusive=None, value_max_inclusive=None,
                                static=False, other_cols={}),
                                'brown_eyes': PlainPredicateConfig(code='eye_color//BR',
                                value_min=None, value_max=None,
                                value_min_inclusive=None,
                                value_max_inclusive=None, static=True,
                                other_cols={})},
                            trigger=EventConfig(predicate='admission'),
                            windows={'start': WindowConfig(start=None, end='trigger + 24h',
                                start_inclusive=True, end_inclusive=True, has={},
                                label=None, index_timestamp=None)},
                            label_window=None, index_timestamp_window=None)
        >>> config_dict = {
        ...     "metadata": {'description': 'A test configuration file'},
        ...     "description": 'this is a test for joining static and plain predicates',
        ...     "patient_demographics": {"male": {"code": "MALE"}, "female": {"code": "FEMALE"}},
        ...     "predicates": {"normal_male_lab_range": {"code": "LAB", "value_min": 0, "value_max": 100,
        ...                                              "value_min_inclusive": True, "value_max_inclusive": True},
        ...                    "normal_female_lab_range": {"code": "LAB", "value_min": 0, "value_max": 90,
        ...                                                "value_min_inclusive": True, "value_max_inclusive": True},
        ...                    "normal_lab_male": {"expr": "and(normal_male_lab_range, male)"},
        ...                    "normal_lab_female": {"expr": "and(normal_female_lab_range, female)"}},
        ...     "trigger": "_ANY_EVENT",
        ...     "windows": {
        ...         "start": {
        ...             "start": None, "end": "trigger + 24h", "start_inclusive": True,
        ...             "end_inclusive": True, "has": {"normal_lab_male": "(1, None)"},
        ...         }
        ...     },
        ... }
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
        ...     config_path = Path(f.name)
        ...     yaml.dump(config_dict, f)
        ...     cfg = TaskExtractorConfig.load(config_path)
        >>> cfg.predicates.keys()
        dict_keys(['normal_lab_male', 'normal_male_lab_range', 'female', 'male'])
        >>> config_dict = {
        ...     "metadata": {'description': 'A test configuration file'},
        ...     "description": 'this is a test for nested derived predicates',
        ...     "patient_demographics": {"male": {"code": "MALE"}, "female": {"code": "FEMALE"}},
        ...     "predicates": {"abnormally_low_male_lab_range": {"code": "LAB", "value_max": 90,
        ...                                                      "value_max_inclusive": False},
        ...                    "abnormally_low_female_lab_range": {"code": "LAB", "value_max": 80,
        ...                                                        "value_max_inclusive": False},
        ...                    "abnormally_high_lab_range": {"code": "LAB", "value_min": 120,
        ...                                                  "value_min_inclusive": False},
        ...                    "abnormal_lab_male_range": {"expr":
        ...                        "or(abnormally_low_male_lab_range, abnormally_high_lab_range)"},
        ...                    "abnormal_lab_female_range": {"expr":
        ...                        "or(abnormally_low_female_lab_range, abnormally_high_lab_range)"},
        ...                    "abnormal_lab_male": {"expr": "and(abnormal_lab_male_range, male)"},
        ...                    "abnormal_lab_female": {"expr": "and(abnormal_lab_female_range, female)"},
        ...                    "abnormal_labs": {"expr": "or(abnormal_lab_male, abnormal_lab_female)"}},
        ...     "trigger": "_ANY_EVENT",
        ...     "windows": {
        ...         "start": {
        ...             "start": None, "end": "trigger + 24h", "start_inclusive": True,
        ...             "end_inclusive": True, "label": "abnormal_labs",
        ...             "has": {"abnormal_labs": "(1, None)"},
        ...         }
        ...     },
        ... }
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
        ...     config_path = Path(f.name)
        ...     yaml.dump(config_dict, f)
        ...     cfg = TaskExtractorConfig.load(config_path)
        >>> cfg.predicates.keys()
        dict_keys(['abnormal_lab_female', 'abnormal_lab_female_range', 'abnormal_lab_male',
                   'abnormal_lab_male_range', 'abnormal_labs', 'abnormally_high_lab_range',
                   'abnormally_low_female_lab_range', 'abnormally_low_male_lab_range', 'female', 'male'])
        >>> predicates_dict = {
        ...     "metadata": {'description': 'A test predicates file'},
        ...     "description": 'this is a test',
        ...     "patient_demographics": {"brown_eyes": {"code": "eye_color//BR"}},
        ...     "predicates": {'admission': "invalid"},
        ... }
        >>> with (tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as config_fp,
        ...       tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as pred_fp):
        ...     config_path = Path(config_fp.name)
        ...     pred_path = Path(pred_fp.name)
        ...     yaml.dump(no_predicates_config, config_fp)
        ...     yaml.dump(predicates_dict, pred_fp)
        ...     cfg = TaskExtractorConfig.load(config_path, pred_path)
        Traceback (most recent call last):
            ...
        ValueError: Predicate 'admission' is not defined correctly in the configuration file. Currently
        defined as the string: invalid. Please refer to the documentation for the supported formats.
        >>> predicates_dict = {
        ...     "predicates": {'adm': {"code": "admission"}},
        ... }
        >>> with (tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as config_fp,
        ...       tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as pred_fp):
        ...     config_path = Path(config_fp.name)
        ...     pred_path = Path(pred_fp.name)
        ...     yaml.dump(no_predicates_config, config_fp)
        ...     yaml.dump(predicates_dict, pred_fp)
        ...     cfg = TaskExtractorConfig.load(config_path, pred_path)
        Traceback (most recent call last):
            ...
        KeyError: "Something referenced predicate 'admission' that wasn't defined in the configuration."
        >>> config_dict = {
        ...     "predicates": {"A": {"code": "A"}, "B": {"code": "B"}, "A_or_B": {"expr": "or(A, B)"},
        ...                    "A_or_B_and_C": {"expr": "and(A_or_B, C)"}},
        ...     "trigger": "_ANY_EVENT",
        ...     "windows": {"start": {"start": None, "end": "trigger + 24h", "start_inclusive": True,
        ...                           "end_inclusive": True, "has": {"A_or_B_and_C": "(1, None)"}}},
        ... }
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
        ...     config_path = Path(f.name)
        ...     yaml.dump(config_dict, f)
        ...     cfg = TaskExtractorConfig.load(config_path)
        Traceback (most recent call last):
            ...
        KeyError: "Predicate 'C' referenced in 'A_or_B_and_C' is not defined in the configuration."

        Cyclic derived predicates are rejected by the predicate-DAG check instead of looping forever:

        >>> config_dict = {
        ...     "predicates": {"A": {"code": "A"}, "B": {"expr": "and(A, C)"}, "C": {"expr": "or(A, B)"}},
        ...     "trigger": "_ANY_EVENT",
        ...     "windows": {"start": {"start": None, "end": "trigger + 24h", "start_inclusive": True,
        ...                           "end_inclusive": True, "has": {"B": "(1, None)"}}},
        ... }
        >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:  # doctest: +ELLIPSIS
        ...     config_path = Path(f.name)
        ...     yaml.dump(config_dict, f)
        ...     cfg = TaskExtractorConfig.load(config_path)
        Traceback (most recent call last):
            ...
        ValueError: Predicate graph is not a directed acyclic graph!...
    """

    def _is_derived(raw: Any) -> bool:
        # Derived predicates are mappings with an "expr" key. A plain string must not be tested with
        # `"expr" in raw`, because that is a substring check and not a key lookup.
        return isinstance(raw, dict) and "expr" in raw

    if isinstance(config_path, str):
        config_path = Path(config_path)
    if not config_path.is_file():
        raise FileNotFoundError(f"Cannot load missing configuration file {config_path.resolve()!s}!")
    if config_path.suffix != ".yaml":
        raise ValueError(
            f"Only supports reading from '.yaml'. Got: '{config_path.suffix}' in '{config_path.name}'."
        )
    yaml = ruamel.yaml.YAML(typ="safe", pure=True)
    loaded_dict = yaml.load(config_path.read_text())

    overriding_predicates = {}
    overriding_demographics = {}
    if predicates_path:
        if isinstance(predicates_path, str):
            predicates_path = Path(predicates_path)
        if not predicates_path.is_file():
            raise FileNotFoundError(f"Cannot load missing predicates file {predicates_path.resolve()!s}!")
        if predicates_path.suffix != ".yaml":
            raise ValueError(
                f"Only supports reading from '.yaml'. Got: '{predicates_path.suffix}' in "
                f"'{predicates_path.name}'."
            )
        predicates_dict = yaml.load(predicates_path.read_text())
        # An empty YAML document loads as None; report it clearly instead of failing on `.pop`.
        if not isinstance(predicates_dict, dict):
            raise ValueError(
                f"Predicates file '{predicates_path.name}' is empty or does not contain a mapping at the "
                "top level."
            )
        # Remove the description or metadata keys if they exist - currently unused except for readability
        # in the YAML
        predicates_dict.pop("description", None)
        predicates_dict.pop("metadata", None)
        overriding_predicates = predicates_dict.pop("predicates", {})
        overriding_demographics = predicates_dict.pop("patient_demographics", {})
        if predicates_dict:
            raise ValueError(
                f"Unrecognized keys in configuration file: '{', '.join(predicates_dict.keys())}'"
            )

    # Checked only once the predicates file has been validated, so errors about that file are reported
    # first.
    if not isinstance(loaded_dict, dict):
        raise ValueError(
            f"Configuration file '{config_path.name}' is empty or does not contain a mapping at the "
            "top level."
        )
    # Remove the description or metadata keys if they exist - currently unused except for readability
    # in the YAML
    loaded_dict.pop("description", None)
    loaded_dict.pop("metadata", None)
    trigger = loaded_dict.pop("trigger")
    windows = loaded_dict.pop("windows", None)
    predicates = loaded_dict.pop("predicates", {})
    patient_demographics = loaded_dict.pop("patient_demographics", {})
    if loaded_dict:
        raise ValueError(f"Unrecognized keys in configuration file: '{', '.join(loaded_dict.keys())}'")

    # Entries from the predicates file win over same-named entries from the configuration file.
    final_predicates = {**predicates, **overriding_predicates}
    final_demographics = {**patient_demographics, **overriding_demographics}
    all_predicates = {**final_predicates, **final_demographics}

    logger.info("Parsing windows...")
    if windows is None:  # pragma: no cover
        windows = {}
        logger.warning(
            "No windows specified in configuration file. Extracting only matching trigger events."
        )
    else:
        windows = {n: WindowConfig(**w) for n, w in windows.items()}

    logger.info("Parsing trigger event...")
    trigger = EventConfig(trigger)

    # Predicates referenced directly by the windows, the trigger, and any window label.
    referenced_predicates = {pred for w in windows.values() for pred in w.referenced_predicates}
    referenced_predicates.add(trigger.predicate)
    referenced_predicates.update(w.label for w in windows.values() if w.label)

    special_predicates = {ANY_EVENT_COLUMN, START_OF_RECORD_KEY, END_OF_RECORD_KEY}
    # Iterate over a snapshot: the loop adds the transitive inputs of derived predicates to the set.
    for pred in set(referenced_predicates) - special_predicates:
        if pred not in all_predicates:
            raise KeyError(
                f"Something referenced predicate '{pred}' that wasn't defined in the configuration."
            )
        if not _is_derived(all_predicates[pred]):
            continue
        expanded = set()
        stack = list(DerivedPredicateConfig(**all_predicates[pred]).input_predicates)
        while stack:
            nested_pred = stack.pop()
            if nested_pred not in all_predicates:
                raise KeyError(
                    f"Predicate '{nested_pred}' referenced in '{pred}' is not defined in the "
                    "configuration."
                )
            referenced_predicates.add(nested_pred)
            # Expand each predicate at most once. Without this guard a cyclic definition would grow the
            # stack forever. The cycle itself is rejected when the instance is constructed below.
            if nested_pred in expanded:
                continue
            expanded.add(nested_pred)
            if _is_derived(all_predicates[nested_pred]):
                stack.extend(DerivedPredicateConfig(**all_predicates[nested_pred]).input_predicates)

    logger.info("Parsing predicates...")
    predicate_objs = {}
    for n, p in final_predicates.items():
        if n not in referenced_predicates:
            continue
        # Validate the type before probing for "expr", which is only meaningful on a mapping.
        if isinstance(p, str):
            raise ValueError(
                f"Predicate '{n}' is not defined correctly in the configuration file. "
                f"Currently defined as the string: {p}. "
                "Please refer to the documentation for the supported formats."
            )
        if not isinstance(p, dict):
            raise ValueError(
                f"Predicate '{n}' is not defined correctly in the configuration file. "
                f"Expected a mapping but got {type(p).__name__}: {p!r}. "
                "Please refer to the documentation for the supported formats."
            )
        if "expr" in p:
            predicate_objs[n] = DerivedPredicateConfig(**p)
        else:
            # Keys that are not PlainPredicateConfig fields are passed through as `other_cols`.
            config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__}
            other_cols = {k: v for k, v in p.items() if k not in config_data}
            predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols)

    if final_demographics:
        logger.info("Parsing patient demographics...")
        final_demographics = {
            n: PlainPredicateConfig(**p, static=True) for n, p in final_demographics.items()
        }
        predicate_objs.update(final_demographics)

    return cls(predicates=predicate_objs, trigger=trigger, windows=windows)
def _initialize_predicates(self) -> None:
"""Initialize the predicates tree from the configuration object and check validity.
Raises:
ValueError: If the predicate name is not valid.
Examples:
>>> TaskExtractorConfig(
... predicates={
... "A": DerivedPredicateConfig("and(A, B)"), # A depends on B
... "B": DerivedPredicateConfig("and(B, C)"), # B depends on C
... "C": DerivedPredicateConfig("and(A, C)"), # C depends on A (Cyclic dependency)
... },
... trigger=EventConfig("A"),
... windows={},
... )
Traceback (most recent call last):
...
ValueError: Predicate graph is not a directed acyclic graph!
Cycle found: [('A', 'A')]
Graph: None
"""
dag_relationships = []
for name, predicate in self.predicates.items():
if re.match(r"^\w+$", name) is None:
raise ValueError(
f"Predicate name '{name}' is invalid; must be composed of alphanumeric or '_' characters."
)
match predicate:
case PlainPredicateConfig():
pass
case DerivedPredicateConfig():
for pred in predicate.input_predicates:
dag_relationships.append((pred, name))
case _:
raise ValueError(
f"Invalid predicate configuration for '{name}': {predicate}. "
"Must be either a PlainPredicateConfig or DerivedPredicateConfig object. "
f"Got: {type(predicate)}"
)
missing_predicates = []
for parent, child in dag_relationships:
if parent not in self.predicates:
missing_predicates.append(
f"Derived predicate '{child}' references undefined predicate '{parent}'"
)
if missing_predicates:
raise KeyError(
f"Missing {len(missing_predicates)} relationships: " + "; ".join(missing_predicates)
)
self._predicate_dag_graph = nx.DiGraph(dag_relationships)
if not nx.is_directed_acyclic_graph(self._predicate_dag_graph):
raise ValueError(
"Predicate graph is not a directed acyclic graph!\n"
f"Cycle found: {nx.find_cycle(self._predicate_dag_graph)}\n"
f"Graph: {nx.write_network_text(self._predicate_dag_graph)}"
)
def _initialize_windows(self) -> None:
    """Validate the window configurations and build the window tree rooted at the trigger.

    This method performs, in order:

    - checks window names, labels, and index timestamps;
    - records ``self.label_window`` and ``self.index_timestamp_window``;
    - checks that the trigger predicate is defined;
    - checks every window's predicate and event references;
    - builds a ``bigtree`` node for the trigger and for each window's start and end.

    The nodes are stored on ``self.window_nodes``, keyed ``"trigger"``, ``"<window>.start"``, and
    ``"<window>.end"``. Endpoint nodes whose ``endpoint_expr`` is ``None`` are then removed, and their
    children are re-attached to the removed node's parent.

    Raises:
        ValueError: If a window name is not composed of alphanumeric or "_" characters, if a label is not a
            defined predicate, or if an index timestamp is not 'start' or 'end'. Also raised if more than
            one window is labeled or used as the index timestamp.
        KeyError: If the trigger predicate is undefined, or if a window references an undefined predicate,
            an event other than the trigger, or an undefined window.

    Examples:
        >>> TaskExtractorConfig(
        ...     predicates={"A": PlainPredicateConfig("A")},
        ...     windows={
        ...         "win1": WindowConfig(None, "trigger", True, False, has={"B": "(1, 0)"}) # B undefined
        ...     },
        ...     trigger=EventConfig("_ANY_EVENT"),
        ... )
        Traceback (most recent call last):
            ...
        KeyError: "Window 'win1' references undefined predicate 'B'.
        Window predicates: B;
        Defined predicates: A"
        >>> TaskExtractorConfig(
        ...     predicates={"A": PlainPredicateConfig("A")},
        ...     windows={
        ...         "win1": WindowConfig(None, "event_not_trigger", True, False)
        ...     },
        ...     trigger=EventConfig("_ANY_EVENT"),
        ... )
        Traceback (most recent call last):
            ...
        KeyError: "Window 'win1' references undefined trigger event
        'event_not_trigger' -- must be trigger!"
        >>> TaskExtractorConfig(
        ...     predicates={"A": PlainPredicateConfig("A")},
        ...     windows={
        ...         "win1": WindowConfig("win2.end", "start -> A", True, False)
        ...     },
        ...     trigger=EventConfig("_ANY_EVENT"),
        ... )
        Traceback (most recent call last):
            ...
        KeyError: "Window 'win1' references undefined window 'win2' for event 'end'.
        Allowed windows: win1"
    """
    for name in self.windows:
        if re.match(r"^\w+$", name) is None:
            raise ValueError(
                f"Window name '{name}' is invalid; must be composed of alphanumeric or '_' characters."
            )
    # Collect the windows carrying a label / an index timestamp; at most one of each is allowed.
    label_windows = []
    index_timestamp_windows = []
    for name, window in self.windows.items():
        if window.label:
            if window.label not in self.predicates:
                raise ValueError(
                    f"Label must be one of the defined predicates. Got: {window.label} "
                    f"for window '{name}'"
                )
            label_windows.append(name)
        if window.index_timestamp:
            if window.index_timestamp not in ["start", "end"]:
                raise ValueError(
                    f"Index timestamp must be either 'start' or 'end'. Got: {window.index_timestamp} "
                    f"for window '{name}'"
                )
            index_timestamp_windows.append(name)
    if len(label_windows) > 1:
        raise ValueError(
            f"Only one window can be labeled, found {len(label_windows)} labeled windows: "
            f"{', '.join(label_windows)}"
        )
    self.label_window = label_windows[0] if label_windows else None
    if len(index_timestamp_windows) > 1:
        raise ValueError(
            f"Only the 'start'/'end' of one window can be used as the index timestamp, "
            f"found {len(index_timestamp_windows)} windows with index_timestamp: "
            f"{', '.join(index_timestamp_windows)}"
        )
    self.index_timestamp_window = index_timestamp_windows[0] if index_timestamp_windows else None
    # The special record/any-event keys are valid triggers even though they are not user predicates.
    if self.trigger.predicate not in self.predicates and self.trigger.predicate not in [
        ANY_EVENT_COLUMN,
        START_OF_RECORD_KEY,
        END_OF_RECORD_KEY,
    ]:
        raise KeyError(
            f"Trigger event predicate '{self.trigger.predicate}' not found in predicates: "
            f"{', '.join(self.predicates.keys())}"
        )
    trigger_node = Node("trigger")
    window_nodes = {"trigger": trigger_node}
    # First pass: create both endpoint nodes per window and link them to each other; linking each
    # window's root node into the rest of the tree happens in the next loop.
    for name, window in self.windows.items():
        start_node = Node(f"{name}.start", endpoint_expr=window.start_endpoint_expr)
        end_node = Node(f"{name}.end", endpoint_expr=window.end_endpoint_expr)
        if window.root_node == "end":
            # In this case, the end_node will bound an unconstrained window, as it is the window between
            # a prior window and the defined anchor for this window, so it has no constraints. But the
            # start_node will have the constraints corresponding to this window, as it is defined relative
            # to the end node.
            end_node.constraints = {}
            start_node.constraints = window.has
            start_node.parent = end_node
        else:
            # In this case, the start_node will bound an unconstrained window, as it is the window between
            # a prior window and the defined anchor for this window, so it has no constraints. But the
            # end_node will have the constraints corresponding to this window, as it is defined relative
            # to the start node.
            end_node.constraints = window.has
            start_node.constraints = {}
            end_node.parent = start_node
        window_nodes[f"{name}.start"] = start_node
        window_nodes[f"{name}.end"] = end_node
    # Second pass: validate references and attach each window's root node to the node it references.
    for name, window in self.windows.items():
        for predicate in window.referenced_predicates - {ANY_EVENT_COLUMN}:
            if predicate not in self.predicates:
                raise KeyError(
                    f"Window '{name}' references undefined predicate '{predicate}'. "
                    f"Window predicates: {', '.join(window.referenced_predicates)}; "
                    f"Defined predicates: {', '.join(self.predicates.keys())}"
                )
        # A one-element reference can only name the trigger; a two-element reference is
        # (window name, "start" | "end").
        if len(window.referenced_event) == 1:
            event = window.referenced_event[0]
            if event != "trigger":
                raise KeyError(
                    f"Window '{name}' references undefined trigger event '{event}' -- must be trigger!"
                )
            window_nodes[f"{name}.{window.root_node}"].parent = window_nodes[event]
        elif len(window.referenced_event) == 2:
            referenced_window, referenced_event = window.referenced_event
            if referenced_window not in self.windows:
                raise KeyError(
                    f"Window '{name}' references undefined window '{referenced_window}' "
                    f"for event '{referenced_event}'. Allowed windows: {', '.join(self.windows.keys())}"
                )
            # NOTE(review): presumably unreachable, because WindowConfig already validates its event
            # references when it is parsed; confirm before removing.
            if referenced_event not in {"start", "end"}:  # pragma: no cover
                raise KeyError(
                    f"Window '{name}' references undefined event '{referenced_event}' "
                    f"for window '{referenced_window}'. Allowed events: 'start', 'end'"
                )
            parent_node = f"{referenced_window}.{referenced_event}"
            window_nodes[f"{name}.{window.root_node}"].parent = window_nodes[parent_node]
        else:  # pragma: no cover
            # NOTE(review): presumably unreachable for the same reason as above; confirm.
            raise ValueError(
                f"Window '{name}' references invalid event '{window.referenced_event}' "
                "must be of length 1 or 2."
            )
    # Clean up the tree: endpoint nodes without an endpoint expression carry no bound and are dropped.
    nodes_to_remove = []
    # First pass: identify nodes to remove, reassign children's parent
    for n, node in window_nodes.items():
        if n != "trigger" and node.endpoint_expr is None:
            nodes_to_remove.append(n)
            for child in node.children:
                # Reassign
                child.parent = node.parent
                # NOTE(review): bigtree presumably already registers the child with its new parent when
                # `child.parent` is set, which would make this a safeguard only; confirm.
                if node.parent and child not in node.parent.children:
                    node.parent.children += (child,)
    # Second pass: remove nodes from parent's children
    for node_name in nodes_to_remove:
        node = window_nodes[node_name]
        if node.parent:
            # Remove
            node.parent.children = [child for child in node.parent.children if child.name != node_name]
    # Delete nodes_to_remove
    for node_name in nodes_to_remove:
        del window_nodes[node_name]
    self.window_nodes = window_nodes
def __post_init__(self) -> None:
self._initialize_predicates()
self._initialize_windows()
@property
def window_tree(self) -> Node:
return self.window_nodes["trigger"]
@property
def predicates_DAG(self) -> nx.DiGraph:
return self._predicate_dag_graph
@property
def plain_predicates(self) -> dict[str, PlainPredicateConfig]:
"""Returns a dictionary of plain predicates in {name: code} format."""
return {p: cfg for p, cfg in self.predicates.items() if cfg.is_plain}
@property
def derived_predicates(self) -> OrderedDict[str, DerivedPredicateConfig]:
    """Returns the derived predicates, ordered so that each one follows the predicates it depends on.

    The order comes from a topological sort of ``predicates_DAG``. The result is a plain ``dict`` whose
    insertion order carries that ordering.
    """
    ordered = {}
    for name in nx.topological_sort(self.predicates_DAG):
        cfg = self.predicates[name]
        if not cfg.is_plain:
            ordered[name] = cfg
    return ordered