Source code for aces.constraints

"""Contains utilities for validating that windows satisfy a set of constraints."""

import logging

import polars as pl

from .types import ANY_EVENT_COLUMN

# Module logger; used to report how many rows each window constraint excludes.
logger = logging.getLogger(__name__)


def check_constraints(
    window_constraints: dict[str, tuple[int | None, int | None]], summary_df: pl.DataFrame
) -> pl.DataFrame:
    """Checks the constraints on the counts of predicates in the summary dataframe.

    Args:
        window_constraints: constraints on counts of predicates that must be satisfied, organized as a
            dictionary from predicate column name to the (inclusive) lower bound and upper bound range
            required for that constraint to be satisfied. Either bound, but not both, may be ``None`` to
            leave that side unbounded. The key ``"*"`` is treated as an alias for the any-event column
            (``ANY_EVENT_COLUMN``).
        summary_df: A dataframe containing a row for every possible prospective window to be analyzed. Every
            column named in ``window_constraints`` must be present; other columns are carried through
            unchanged.

    Returns:
        A filtered dataframe containing only the rows that satisfy the constraints. Rows holding a null in
        any constrained column are excluded as well.

    Raises:
        ValueError: If a constraint has neither a lower nor an upper bound, or if its upper bound is smaller
            than its lower bound.

    Examples:
        >>> df = pl.DataFrame({
        ...     "subject_id": [1, 1, 1, 1, 2, 2],
        ...     "timestamp": [
        ...         # Subject 1
        ...         datetime(year=1989, month=12, day=1, hour=12, minute=3),
        ...         datetime(year=1989, month=12, day=2, hour=5, minute=17),
        ...         datetime(year=1989, month=12, day=2, hour=12, minute=3),
        ...         datetime(year=1989, month=12, day=6, hour=11, minute=0),
        ...         # Subject 2
        ...         datetime(year=1989, month=12, day=1, hour=13, minute=14),
        ...         datetime(year=1989, month=12, day=3, hour=15, minute=17),
        ...     ],
        ...     "is_A": [1, 4, 1, 3, 3, 3],
        ...     "is_B": [0, 2, 0, 2, 10, 2],
        ...     "is_C": [1, 1, 1, 0, 1, 1],
        ... })
        >>> check_constraints({"is_A": (None, None), "is_B": (2, 6), "is_C": (1, 1)}, df)
        Traceback (most recent call last):
            ...
        ValueError: Invalid constraint for 'is_A': None - None
        >>> check_constraints({"is_A": (2, 1), "is_B": (2, 6), "is_C": (1, 1)}, df)
        Traceback (most recent call last):
            ...
        ValueError: Invalid constraint for 'is_A': 2 - 1
        >>> check_constraints({"is_A": (3, 4), "is_B": (2, 6), "is_C": (1, 1)}, df)
        shape: (2, 5)
        ┌────────────┬─────────────────────┬──────┬──────┬──────┐
        │ subject_id ┆ timestamp           ┆ is_A ┆ is_B ┆ is_C │
        │ ---        ┆ ---                 ┆ ---  ┆ ---  ┆ ---  │
        │ i64        ┆ datetime[μs]        ┆ i64  ┆ i64  ┆ i64  │
        ╞════════════╪═════════════════════╪══════╪══════╪══════╡
        │ 1          ┆ 1989-12-02 05:17:00 ┆ 4    ┆ 2    ┆ 1    │
        │ 2          ┆ 1989-12-03 15:17:00 ┆ 3    ┆ 2    ┆ 1    │
        └────────────┴─────────────────────┴──────┴──────┴──────┘
        >>> check_constraints({"is_A": (3, 4), "is_B": (2, None), "is_C": (None, 1)}, df)
        shape: (4, 5)
        ┌────────────┬─────────────────────┬──────┬──────┬──────┐
        │ subject_id ┆ timestamp           ┆ is_A ┆ is_B ┆ is_C │
        │ ---        ┆ ---                 ┆ ---  ┆ ---  ┆ ---  │
        │ i64        ┆ datetime[μs]        ┆ i64  ┆ i64  ┆ i64  │
        ╞════════════╪═════════════════════╪══════╪══════╪══════╡
        │ 1          ┆ 1989-12-02 05:17:00 ┆ 4    ┆ 2    ┆ 1    │
        │ 1          ┆ 1989-12-06 11:00:00 ┆ 3    ┆ 2    ┆ 0    │
        │ 2          ┆ 1989-12-01 13:14:00 ┆ 3    ┆ 10   ┆ 1    │
        │ 2          ┆ 1989-12-03 15:17:00 ┆ 3    ┆ 2    ┆ 1    │
        └────────────┴─────────────────────┴──────┴──────┴──────┘
        >>> predicates_df = pl.DataFrame({
        ...     "subject_id": [1, 1, 3],
        ...     "timestamp": [datetime(1980, 12, 28), datetime(2010, 6, 20), datetime(2010, 5, 11)],
        ...     "A": [False, False, False],
        ...     "_ANY_EVENT": [True, True, True],
        ... })
        >>> check_constraints({"_ANY_EVENT": (1, None)}, predicates_df)
        shape: (3, 4)
        ┌────────────┬─────────────────────┬───────┬────────────┐
        │ subject_id ┆ timestamp           ┆ A     ┆ _ANY_EVENT │
        │ ---        ┆ ---                 ┆ ---   ┆ ---        │
        │ i64        ┆ datetime[μs]        ┆ bool  ┆ bool       │
        ╞════════════╪═════════════════════╪═══════╪════════════╡
        │ 1          ┆ 1980-12-28 00:00:00 ┆ false ┆ true       │
        │ 1          ┆ 2010-06-20 00:00:00 ┆ false ┆ true       │
        │ 3          ┆ 2010-05-11 00:00:00 ┆ false ┆ true       │
        └────────────┴─────────────────────┴───────┴────────────┘
    """
    should_drop = pl.lit(False)
    for col, (valid_min_inc, valid_max_inc) in window_constraints.items():
        if (valid_min_inc is None and valid_max_inc is None) or (
            valid_min_inc is not None and valid_max_inc is not None and valid_max_inc < valid_min_inc
        ):
            raise ValueError(f"Invalid constraint for '{col}': {valid_min_inc} - {valid_max_inc}")

        if col == "*":
            col = ANY_EVENT_COLUMN

        drop_expr = pl.lit(False)
        if valid_min_inc is not None:
            drop_expr = drop_expr | (pl.col(col) < valid_min_inc)
        if valid_max_inc is not None:
            drop_expr = drop_expr | (pl.col(col) > valid_max_inc)
        # A null count makes the comparison null, and ``filter`` discards null-predicate rows anyway. Mark
        # them as dropped explicitly so the logged exclusion count matches what is actually removed.
        drop_expr = drop_expr.fill_null(True)

        # Counting excluded rows is a full pass over ``summary_df``; only pay for it when it will be logged.
        if logger.isEnabledFor(logging.INFO):
            n_excluded = summary_df.select(drop_expr.sum()).item()
            logger.info(
                f"Excluding {n_excluded:,} rows "
                f"as they failed to satisfy '{valid_min_inc} <= {col} <= {valid_max_inc}'."
            )

        should_drop = should_drop | drop_expr

    return summary_df.filter(~should_drop)
def check_static_variables(patient_demographics: list[str], predicates_df: pl.DataFrame) -> pl.DataFrame:
    """Keeps only the subjects that satisfy every static demographic predicate.

    Static variables are stored on rows whose ``timestamp`` is null. A subject is retained only if, for each
    column in ``patient_demographics``, at least one of that subject's null-timestamp rows holds a value
    greater than zero in that column. The null-timestamp rows and the demographic columns are removed from the
    result.

    Args:
        patient_demographics: List of columns representing static patient demographics. When empty, no
            subject is filtered out; only the null-timestamp rows are removed.
        predicates_df: Dataframe containing a row for each event with patient demographics and timestamps.
            Must contain ``subject_id`` and ``timestamp`` columns.

    Returns:
        A filtered dataframe containing only the timestamped rows of subjects that satisfy the patient
        demographics, without the demographic columns.

    Raises:
        ValueError: If the static predicate used by constraint is not in the predicates dataframe.

    Examples:
        >>> predicates_df = pl.DataFrame({
        ...     "subject_id": [1, 1, 1, 1, 1, 2, 2, 2],
        ...     "timestamp": [
        ...         # Subject 1
        ...         None,
        ...         datetime(year=1989, month=12, day=1, hour=12, minute=3),
        ...         datetime(year=1989, month=12, day=2, hour=5, minute=17),
        ...         datetime(year=1989, month=12, day=2, hour=12, minute=3),
        ...         datetime(year=1989, month=12, day=6, hour=11, minute=0),
        ...         # Subject 2
        ...         None,
        ...         datetime(year=1989, month=12, day=1, hour=13, minute=14),
        ...         datetime(year=1989, month=12, day=3, hour=15, minute=17),
        ...     ],
        ...     "is_A": [0, 1, 4, 1, 0, 3, 3, 3],
        ...     "is_B": [0, 0, 2, 0, 0, 2, 10, 2],
        ...     "is_C": [0, 1, 1, 1, 0, 0, 1, 1],
        ...     "male": [1, 0, 0, 0, 0, 0, 0, 0]
        ... })
        >>> check_static_variables(['male'], predicates_df)
        shape: (4, 5)
        ┌────────────┬─────────────────────┬──────┬──────┬──────┐
        │ subject_id ┆ timestamp           ┆ is_A ┆ is_B ┆ is_C │
        │ ---        ┆ ---                 ┆ ---  ┆ ---  ┆ ---  │
        │ i64        ┆ datetime[μs]        ┆ i64  ┆ i64  ┆ i64  │
        ╞════════════╪═════════════════════╪══════╪══════╪══════╡
        │ 1          ┆ 1989-12-01 12:03:00 ┆ 1    ┆ 0    ┆ 1    │
        │ 1          ┆ 1989-12-02 05:17:00 ┆ 4    ┆ 2    ┆ 1    │
        │ 1          ┆ 1989-12-02 12:03:00 ┆ 1    ┆ 0    ┆ 1    │
        │ 1          ┆ 1989-12-06 11:00:00 ┆ 0    ┆ 0    ┆ 0    │
        └────────────┴─────────────────────┴──────┴──────┴──────┘
        >>> check_static_variables([], predicates_df).shape
        (6, 6)
        >>> check_static_variables(['female'], predicates_df)
        Traceback (most recent call last):
            ...
        ValueError: Static predicate 'female' not found in the predicates dataframe.
    """
    available_columns = set(predicates_df.columns)
    for demographic in patient_demographics:
        if demographic not in available_columns:
            raise ValueError(f"Static predicate '{demographic}' not found in the predicates dataframe.")

    # ``pl.all_horizontal`` raises on an empty list of expressions, so only filter when there is something
    # to check.
    if patient_demographics:
        # A subject passes a demographic if any of its static (null-timestamp) rows is positive for it;
        # ``.over("subject_id")`` spreads that per-subject verdict onto every row of the subject.
        subject_matches = [
            (pl.col("timestamp").is_null() & (pl.col(demographic) > 0)).any().over("subject_id")
            for demographic in patient_demographics
        ]
        predicates_df = predicates_df.filter(pl.all_horizontal(subject_matches))

    return predicates_df.drop_nulls(subset=["timestamp"]).drop(patient_demographics)