Deep Learning on OMOP Data in EHRData with PyPOTS

This tutorial demonstrates how to quickly apply machine learning to OMOP data using PyPOTS, a toolkit for machine learning on partially observed time series [Du23].

Prerequisites: Complete the OMOP Introduction tutorial first to understand how to load OMOP data into EHRData.

Use Case: ICU Mortality Prediction

We’ll predict in-hospital mortality for ICU patients using the MIMIC-IV demo dataset in OMOP format [RJJ+20] [GAG+00].

Note

This is a demonstration example. Real clinical prediction requires more sophisticated preprocessing, validation, and careful consideration of clinical context.

What is PyPOTS?

PyPOTS provides state-of-the-art neural network models for time series tasks:

  • Imputation - Fill missing values in incomplete time series

  • Classification - Predict outcomes from time series

  • Forecasting - Predict future values

  • Clustering - Group similar patients

PyPOTS models can be applied directly to the 3D arrays stored in EHRData layers.
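PyPOTS models take input arrays of shape (n_samples, n_steps, n_features) and expect missing values to be encoded as NaN. EHRData stores each layer as (n_obs, n_vars, n_t), as the edata summary later in this tutorial shows. The two layouts differ only in axis order. The minimal NumPy sketch below uses a random toy array as a stand-in for a real layer and shows that a transpose is the only conversion needed.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for an EHRData layer: 5 visits x 3 variables x 24 hourly steps
n_obs, n_vars, n_t = 5, 3, 24
layer = rng.normal(size=(n_obs, n_vars, n_t))
layer[rng.random(layer.shape) < 0.7] = np.nan  # mark ~70% of entries as missing

# Reorder axes to (n_samples, n_steps, n_features), the layout PyPOTS expects
X = layer.transpose(0, 2, 1)

print(X.shape)  # (5, 24, 3)
print(f"Missing rate: {np.isnan(X).mean():.2f}")
```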

Setup and Installation

%pip install pypots
# PyPOTS requires SciPy's array API support; set this before SciPy is imported
import os

os.environ["SCIPY_ARRAY_API"] = "1"

import ehrdata as ed
import duckdb
import pandas as pd
import torch
from pypots.classification import BRITS

Set Up Database and Download Data

# Create database connection
con = duckdb.connect(":memory:")

# Download MIMIC-IV OMOP demo data
ed.dt.mimic_iv_omop(backend_handle=con)

Define the Cohort

We’ll focus on ICU patients by selecting visits whose visit_detail records carry one of these ICU concept IDs (visit_detail_concept_id):

  • 4305366: Surgical ICU

  • 40481392: Medical ICU

  • 32037: Intensive Care

  • 763903: Trauma ICU

  • 4149943: Cardiac ICU

We apply two key filters:

  1. Duration: Only ICU-containing visits lasting more than 24 hours (to ensure sufficient data for the 24-hour analysis)

  2. First visit: If a patient had multiple qualifying ICU visits, we keep only the earliest one

We do this in SQL, which works on this and any other OMOP CDM database; SQL generated by tools such as OHDSI’s ATLAS can also be used here.

Alternatively, the EHRData object can be filtered afterwards entirely in Python, though with less control over the “raw” data than SQL gives you.

# Filter for first ICU visit per patient (>24 hours only)
con.execute("""
    WITH RankedVisits AS (
        SELECT
            v.*,
            vd.*,
            ROW_NUMBER() OVER (PARTITION BY v.person_id ORDER BY v.visit_start_date) AS rn
        FROM visit_occurrence v
        JOIN visit_detail vd USING (visit_occurrence_id)
        WHERE vd.visit_detail_concept_id IN (4305366, 40481392, 32037, 763903, 4149943)
            AND date_diff('hour', v.visit_start_date, v.visit_end_date) > 24
    ),
    first_icu_visit_occurrence_id AS (
        SELECT visit_occurrence_id
        FROM RankedVisits
        WHERE rn = 1
    )
    DELETE FROM visit_occurrence
    WHERE visit_occurrence_id NOT IN (SELECT visit_occurrence_id FROM first_icu_visit_occurrence_id)
""")

# Check how many ICU visits remain
n_visits = con.execute("SELECT COUNT(*) FROM visit_occurrence").fetchone()[0]
print(f"ICU cohort: {n_visits} patients (first ICU visit >24h only)")
ICU cohort: 99 patients (first ICU visit >24h only)
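If you prefer the Python route mentioned above, you can sketch the same selection logic in pandas. The example below runs on hypothetical toy versions of visit_occurrence and visit_detail with invented values. It joins the two tables, keeps ICU concepts and visits longer than 24 hours, and then takes the earliest visit per person. That last step is what ROW_NUMBER() ... WHERE rn = 1 does in the SQL. On real data you would load the tables from the DuckDB connection instead, for example with con.execute("SELECT * FROM visit_occurrence").df().

```python
import pandas as pd

# Hypothetical toy tables standing in for visit_occurrence and visit_detail
visits = pd.DataFrame({
    "visit_occurrence_id": [1, 2, 3, 4],
    "person_id": [10, 10, 20, 30],
    "visit_start_date": pd.to_datetime(["2110-01-01", "2110-06-01", "2111-03-01", "2112-05-01"]),
    "visit_end_date": pd.to_datetime(["2110-01-05", "2110-06-03", "2111-03-01", "2112-05-04"]),
})
details = pd.DataFrame({
    "visit_occurrence_id": [1, 2, 3, 4],
    "visit_detail_concept_id": [32037, 4305366, 40481392, 9201],  # 9201 is not in the ICU list
})

ICU_CONCEPTS = {4305366, 40481392, 32037, 763903, 4149943}

# Keep visits with an ICU detail record and a duration above 24 hours
icu = visits.merge(details, on="visit_occurrence_id")
icu = icu[icu["visit_detail_concept_id"].isin(ICU_CONCEPTS)]
duration = icu["visit_end_date"] - icu["visit_start_date"]
icu = icu[duration > pd.Timedelta(hours=24)]

# Equivalent of ROW_NUMBER() ... rn = 1: the earliest qualifying visit per person
first_icu = icu.sort_values("visit_start_date").drop_duplicates("person_id", keep="first")
print(sorted(first_icu["visit_occurrence_id"].tolist()))  # [1]
```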

Build EHRData from OMOP

Now we construct the EHRData object using ICU visit start as the time reference (t=0) for each patient:

# Step 1: Set up observations from person + visit_occurrence
edata = ed.io.omop.setup_obs(
    backend_handle=con,
    observation_table="person_visit_occurrence",  # Each row = one ICU visit
    death_table=True,
)

print(f"Created EHRData with {edata.n_obs} ICU visits")
edata.obs.head()
Created EHRData with 99 ICU visits
person_id gender_concept_id year_of_birth month_of_birth day_of_birth birth_datetime race_concept_id ethnicity_concept_id location_id provider_id ... admitting_source_value discharge_to_concept_id discharge_to_source_value preceding_visit_occurrence_id death_date death_datetime death_type_concept_id cause_concept_id cause_source_value cause_source_concept_id
0 4239478333578644568 8507 2111 None None NaT 8527 0 None None ... PHYSICIAN REFERRAL 581476 HOME <NA> NaT NaT NaN NaN NaN NaN
1 -8090189584974691216 8507 2118 None None NaT 8527 0 None None ... EMERGENCY ROOM 581476 HOME <NA> NaT NaT NaN NaN NaN NaN
2 2161418207209636934 8507 2060 None None NaT 2000001401 0 None None ... TRANSFER FROM HOSPITAL 8863 SKILLED NURSING FACILITY <NA> NaT NaT NaN NaN NaN NaN
3 1532249960797525190 8532 2106 None None NaT 2000001405 0 None None ... EMERGENCY ROOM 581476 HOME HEALTH CARE <NA> NaT NaT NaN NaN NaN NaN
4 2288881942133868955 8532 2102 None None NaT 8527 0 None None ... EMERGENCY ROOM 581476 HOME HEALTH CARE <NA> NaT NaT NaN NaN NaN NaN

5 rows × 41 columns

# Step 2: Extract measurements from the first 24 hours
edata = ed.io.omop.setup_variables(
    edata=edata,
    backend_handle=con,
    layer="measurements",
    data_tables=["measurement"],
    data_field_to_keep={"measurement": "value_as_number"},
    interval_length_number=1,
    interval_length_unit="h",  # Hourly intervals
    num_intervals=24,  # First 24 hours
    aggregation_strategy="last",
    enrich_var_with_feature_info=True,
    instantiate_tensor=True,
)

edata
EHRData object with n_obs × n_vars × n_t = 99 × 450 × 24
    obs: 'person_id', 'gender_concept_id', 'year_of_birth', 'month_of_birth', 'day_of_birth', 'birth_datetime', 'race_concept_id', 'ethnicity_concept_id', 'location_id', 'provider_id', 'care_site_id', 'person_source_value', 'gender_source_value', 'gender_source_concept_id', 'race_source_value', 'race_source_concept_id', 'ethnicity_source_value', 'ethnicity_source_concept_id', 'visit_occurrence_id', 'person_id_1', 'visit_concept_id', 'visit_start_date', 'visit_start_datetime', 'visit_end_date', 'visit_end_datetime', 'visit_type_concept_id', 'provider_id_1', 'care_site_id_1', 'visit_source_value', 'visit_source_concept_id', 'admitting_source_concept_id', 'admitting_source_value', 'discharge_to_concept_id', 'discharge_to_source_value', 'preceding_visit_occurrence_id', 'death_date', 'death_datetime', 'death_type_concept_id', 'cause_concept_id', 'cause_source_value', 'cause_source_concept_id'
    var: 'data_table_concept_id', 'data_table_concept_id_mapped', 'concept_id', 'concept_name', 'domain_id', 'vocabulary_id', 'concept_class_id', 'standard_concept', 'concept_code', 'valid_start_date', 'valid_end_date', 'invalid_reason'
    tem: '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23'
    uns: 'omop_io_observation_table', 'unit_report_measurement'
    layers: 'measurements'
    shape of .measurements: (99, 450, 24)
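The setup_variables call above requests 24 one-hour intervals with aggregation_strategy="last". Our reading of these settings is as follows. Each measurement is assigned to an hourly bin relative to the visit start, and the last value in each bin is kept. Bins without any measurement stay NaN, which is why layers like this are typically very sparse. The pandas/NumPy sketch below illustrates that idea for one hypothetical visit. It reflects our assumption about the semantics, not ehrdata's actual implementation, and the variable names and values are invented.

```python
import numpy as np
import pandas as pd

# Hypothetical measurements for one visit: hours since visit start, variable, value
meas = pd.DataFrame({
    "hours_since_start": [0.2, 0.8, 1.5, 5.0, 30.0],
    "variable": ["heart_rate", "heart_rate", "heart_rate", "lactate", "heart_rate"],
    "value": [88.0, 92.0, 95.0, 2.1, 80.0],
})
variables = ["heart_rate", "lactate"]
num_intervals = 24  # 24 one-hour bins

# Assign each measurement to a 1-hour bin and drop those outside the 24h window
meas["bin"] = np.floor(meas["hours_since_start"]).astype(int)
meas = meas[(meas["bin"] >= 0) & (meas["bin"] < num_intervals)]

# "last" aggregation: keep the latest value per (variable, bin)
last = meas.sort_values("hours_since_start").groupby(["variable", "bin"])["value"].last()

# Fill a (n_vars, n_t) grid; bins without measurements stay NaN
grid = np.full((len(variables), num_intervals), np.nan)
for (var, b), v in last.items():
    grid[variables.index(var), b] = v

print(grid[:, :6])
```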

Mortality Prediction with BRITS

Now let’s predict in-hospital mortality using BRITS (Bidirectional Recurrent Imputation for Time Series). BRITS imputes missing values jointly with the classification task, so the sparse hourly measurements can be used without a separate imputation step.

First, we prepare labels from the death information extracted from the OMOP death table (death_table=True above).

For a simple cohort design, we include only patients who survived the first 24 hours of their ICU visit.

The prediction target is death occurring after these first 24 hours and no later than 7 days after the end of the visit.
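You can check both conditions on a small hypothetical obs table before applying them to edata. The pandas sketch below uses invented dates and the same filter and label expressions as the next cell. It relies on the fact that comparisons with a missing death_datetime (NaT) evaluate to False. Patients without a recorded death therefore pass the survival filter only through the explicit null check, and they are labeled as survivors.

```python
import pandas as pd

# Hypothetical obs table with four visits (invented dates)
obs = pd.DataFrame({
    "visit_start_date": pd.to_datetime(["2110-01-01"] * 4),
    "visit_end_date": pd.to_datetime(["2110-01-05"] * 4),
    "death_datetime": pd.to_datetime(
        [None, "2110-01-10 00:00", "2110-01-20 00:00", "2110-01-01 12:00"]
    ),
})

# Cohort filter: alive 24h after visit start (no recorded death counts as alive)
alive_24h = obs["death_datetime"].isna() | (
    obs["death_datetime"] > obs["visit_start_date"] + pd.Timedelta(hours=24)
)
obs = obs[alive_24h].copy()  # the last visit (death after 12h) is excluded

# Label: death no later than 7 days after visit end; NaT compares as False
obs["death"] = obs["death_datetime"] <= obs["visit_end_date"] + pd.Timedelta(days=7)
print(obs["death"].tolist())  # [False, True, False]
```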

# Filter for patients surviving the first 24h
edata = edata[
    pd.isnull(edata.obs["death_datetime"])
    | (edata.obs["death_datetime"] > edata.obs["visit_start_date"] + pd.Timedelta(hours=24))
].copy()
print(f"Patients surviving the first 24h: {len(edata)}")
Patients surviving the first 24h: 99
# Create binary labels for the prediction task
edata.obs["death"] = edata.obs["death_datetime"] <= edata.obs["visit_end_date"] + pd.Timedelta(days=7)
print(f"Patients dying within 7 days after ICU stay end: {edata.obs['death'].sum()} patients")
Patients dying within 7 days after ICU stay end: 10 patients

We split the data into a training and a test set. Note how small the dataset is and how few positive labels it contains. This is merely a demonstration with publicly available data; there is not enough data to derive clinically meaningful results.

# Split into train/test by position (simple, unshuffled split for demonstration)
n_train = int(0.5 * len(edata))
n_test = len(edata) - n_train

edata_train = edata[:n_train]
edata_test = edata[n_train:]

print(f"Training set: {len(edata_train)} patients ({edata_train.obs['death'].mean() * 100:.1f}% mortality)")
print(f"Test set: {len(edata_test)} patients ({edata_test.obs['death'].mean() * 100:.1f}% mortality)")
Training set: 49 patients (18.4% mortality)
Test set: 50 patients (2.0% mortality)
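Because the split above is positional, 9 of the 10 deaths end up in the training set and only 1 in the test set. A stratified split keeps the mortality rate similar in both sets. scikit-learn's train_test_split(..., stratify=y) does this. The NumPy sketch below shows the idea with a hypothetical helper, stratified_split, and toy labels that mirror this cohort's 10 deaths among 99 visits. Subsetting edata with the returned index arrays assumes that EHRData supports AnnData-style integer indexing.

```python
import numpy as np

def stratified_split(labels, test_fraction=0.5, seed=0):
    """Return (train_idx, test_idx) so both sets keep roughly the same class ratio."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        # Shuffle the indices of this class, then send a fixed fraction to the test set
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_test = int(test_fraction * len(idx))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return np.sort(train_idx), np.sort(test_idx)

# Toy labels mirroring this cohort: 10 deaths among 99 visits
y = np.zeros(99, dtype=bool)
y[:10] = True
train_idx, test_idx = stratified_split(y)
print(len(train_idx), y[train_idx].sum(), len(test_idx), y[test_idx].sum())  # 50 5 49 5
```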

With a few lines of code, we can now train a model such as BRITS for this prediction task.

# Initialize BRITS classifier
torch.manual_seed(42)
brits = BRITS(
    n_steps=edata_train.shape[2],
    n_features=edata_train.shape[1],
    rnn_hidden_size=32,
    n_classes=2,
    epochs=10,
    batch_size=16,
)

# Train the model
print("Training BRITS...")
brits.fit({"X": edata_train.layers["measurements"].transpose(0, 2, 1), "y": edata_train.obs["death"].values})

# Make predictions
predictions = brits.predict({"X": edata_test.layers["measurements"].transpose(0, 2, 1)})
pred_labels = predictions["classification"]

# Calculate accuracy
accuracy = (pred_labels == edata_test.obs["death"]).mean()
print(f"\nTest Accuracy: {accuracy * 100:.1f}%")
print(
    f"Baseline (predict majority class): {max(edata_test.obs['death'].mean(), 1 - edata_test.obs['death'].mean()) * 100:.1f}%"
)
Training BRITS...

Test Accuracy: 98.0%
Baseline (predict majority class): 98.0%
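With a single death among the 50 test patients, a model that never predicts death already reaches 98% accuracy. Accuracy therefore cannot distinguish BRITS from the majority-class baseline here. Balanced accuracy is the mean recall over classes, and it scores such a trivial predictor at 50%. scikit-learn provides it as sklearn.metrics.balanced_accuracy_score. The self-contained NumPy sketch below computes it on toy labels that mirror the test set.

```python
import numpy as np

def balanced_accuracy(y_true, y_pred):
    """Mean recall over the classes present in y_true."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))

# Toy labels mirroring the test set: 1 death among 50 patients
y_true = np.zeros(50, dtype=int)
y_true[0] = 1
y_pred = np.zeros(50, dtype=int)  # a model that never predicts death

print(f"Accuracy: {np.mean(y_true == y_pred):.2f}")                    # Accuracy: 0.98
print(f"Balanced accuracy: {balanced_accuracy(y_true, y_pred):.2f}")  # Balanced accuracy: 0.50
```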

A quick look at the predictions shows what is happening on this small dataset:

print(f"Predicted deaths in test set: {pred_labels.sum()}/{pred_labels.shape[0]}")
Predicted deaths in test set: 0/50

Without any weighting of the rare “death” class, and with so little data, the model simply learns to predict the majority class, “no death”, for every patient.
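One common remedy, not used in this tutorial, is to rebalance the training set before calling fit. Random oversampling of the minority class is one way to do this. The NumPy sketch below applies a hypothetical helper, oversample_minority, to toy arrays shaped like the BRITS input. The resulting arrays could be passed as {"X": X_bal, "y": y_bal} to brits.fit, while evaluation stays on the untouched test set. With only 9 training deaths, the duplicated samples may simply be memorized. Whether this improves results here is untested.

```python
import numpy as np

def oversample_minority(X, y, seed=0):
    """Randomly duplicate minority-class samples until both classes are equally frequent."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    minority = classes[np.argmin(counts)]
    minority_idx = np.flatnonzero(y == minority)
    # Draw extra minority samples with replacement to close the gap
    n_extra = counts.max() - counts.min()
    extra = rng.choice(minority_idx, size=n_extra, replace=True)
    # Keep all original samples plus the duplicates, in shuffled order
    idx = rng.permutation(np.concatenate([np.arange(len(y)), extra]))
    return X[idx], y[idx]

# Toy training set shaped like the BRITS input: (n_samples, n_steps, n_features)
X_train = np.random.default_rng(1).normal(size=(49, 24, 3))
y_train = np.array([1] * 9 + [0] * 40)
X_bal, y_bal = oversample_minority(X_train, y_train)
print(X_bal.shape, np.bincount(y_bal))  # (80, 24, 3) [40 40]
```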

Important caveats for this demo:

Warning

This demonstration uses only 99 ICU visits from the MIMIC-IV demo dataset. Real clinical prediction models require:

  • Much larger datasets (thousands of patients)

  • Careful feature engineering and clinical domain knowledge

  • Proper validation (cross-validation, external validation)

  • Clinical evaluation and prospective testing

The model performance shown here is not clinically meaningful due to the small sample size and simplified preprocessing. This tutorial demonstrates the technical workflow, not a production-ready model.

Next Tutorial

Continue with the ehrapy tutorial Longitudinal Data Analysis with ehrapy and ehrdata: SAITS on the PhysioNet Challenge Dataset for a larger example of ehrdata, ehrapy, and PyPOTS working together.

Continue with Interactive Visualization of EHRData with Vitessce to explore your data interactively.

Further Resources