Representativeness through interoperability: Jupyter notebook
Below is a static version of the notebook.
Please download the interactive version to play around with yourself.
Improving Data Representativeness/Fairness Via Interoperability
This notebook is intended to demonstrate how interoperating systems can help improve the representativeness of subsets in a dataset (e.g., demographics, disease distribution).
To demonstrate this, we'll use a synthetic database that has also been used for other MIDRC tutorials. It can be downloaded at this link: Data Download. Note: you must be signed in to the MIDRC Data Commons at data.midrc.org to access the file; the download button can be found at the bottom of the opened tab. When you download the file, please save it in the current working directory for this notebook.
Now, we will begin the tutorial by loading packages and defining useful functions that will be used to create plots for visualizing representation later. Run the next two cells of code.
import pandas as pd
import urllib.request
import sys
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
np.random.seed(4)
def radar_plot(jsd_dicts, labels, features, title=None, rmax=1.0):
    """
    Create radar/spider charts to visualize JSD comparison of data subsets
    jsd_dicts: list[dict] each dict maps feature->value
    labels: list[str] legend labels for each dict
    features: list[str] axis labels in desired order
    """
    N = len(features)
    angles = np.linspace(0, 2*np.pi, N, endpoint=False).tolist()
    angles += angles[:1]
    fig = plt.figure(figsize=(8, 8))
    ax = plt.subplot(111, polar=True)
    for d, lab in zip(jsd_dicts, labels):
        values = [float(d.get(f, np.nan)) for f in features]
        values += values[:1]
        ax.plot(angles, values, linewidth=2, label=lab)
        ax.fill(angles, values, alpha=0.15)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(features)
    ax.set_ylim(0, rmax)
    ax.grid(True)
    if title:
        ax.set_title(title, y=1.08)
    ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.10), frameon=False)
    plt.tight_layout()
    return fig, ax
def pie_chart_plots(df, columns):
    '''
    Visualize pie charts for specified dataframe columns
    df: pandas dataframe
    columns: dataframe variables to be plotted
    output: len(columns) pie charts
    '''
    n_rows = len(columns)
    fig, axes = plt.subplots(nrows=n_rows, ncols=1, figsize=(12, 4 * n_rows))
    axes = np.atleast_1d(axes)  # ensure indexing works even when only one column is given
    fig.subplots_adjust(right=0.82, hspace=0.7)
    cmap = plt.get_cmap("tab20")
    for i, c in enumerate(columns):
        ax = axes[i]
        counts = df[c].value_counts(dropna=True).sort_values(ascending=False)
        total = counts.sum()
        pct = (counts / total * 100) if total > 0 else (counts * 0)
        labels = counts.index.tolist()
        values = counts.values
        colors = [cmap(k % cmap.N) for k in range(len(values))]
        wedges, _ = ax.pie(
            values,
            labels=None,
            startangle=90,
            colors=colors,
            radius=1.15
        )
        ax.set_title(f"{c}\nALL", pad=10)
        legend_labels = [f"{lab} ({p:.1f}%)" for lab, p in zip(labels, pct.values)]
        ax.legend(
            wedges, legend_labels,
            loc="center left",
            bbox_to_anchor=(1.02, 0.5),
            borderaxespad=0.0,
            frameon=False,
            fontsize=9
        )
def pie_chart_plots2(columns, groups):
    '''
    Creates pie charts to visualize data distributions.
    columns: list of dataframe columns to be visualized
    groups: list of tuples, each tuple is 2 elements formatted (<group name, string>, dataframe)
            each tuple defines a set of pie charts to be created
    output: len(groups) columns of pie charts, each column containing len(columns) rows
    '''
    n_rows = len(columns)
    n_cols = len(groups)
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(18, 4 * n_rows))
    axes = np.atleast_2d(axes)
    plt.subplots_adjust(wspace=0.9, hspace=0.8)
    cmap = plt.get_cmap("tab20")
    for i, c in enumerate(columns):
        for j, (title, df_sub) in enumerate(groups):
            ax = axes[i, j]
            counts = df_sub[c].value_counts(dropna=True).sort_values(ascending=False)
            total = counts.sum()
            pct = (counts / total * 100) if total > 0 else (counts * 0)
            labels = counts.index.tolist()
            values = counts.values
            colors = [cmap(k % cmap.N) for k in range(len(values))]
            wedges, _ = ax.pie(
                values,
                labels=None,
                startangle=90,
                colors=colors
            )
            ax.set_title(f"{c}\n{title}")
            legend_labels = [f"{lab} ({p:.1f}%)" for lab, p in zip(labels, pct.values)]
            ax.legend(
                wedges, legend_labels,
                loc="upper center",
                bbox_to_anchor=(0.5, -0.10),
                ncol=1,
                frameon=False,
                fontsize=9
            )
We will also need to access some functions used in other MIDRC tools, namely the MIDRC Generalized Stratified Sampling and MIDRC-REACT tools. The next cell creates a new directory called "utilities" in your current working directory, retrieves the files we need from the public GitHub repositories for those tools, and saves them in that directory. As the imports at the end of the cell show, these functions are for calculating the Jensen-Shannon distance (JSD) between two sets of data and for performing stratified sampling.
url = "https://raw.githubusercontent.com/MIDRC/Generalized_Stratified_Sampling/refs/heads/main"
# Create needed directories
utils_dir = Path("utilities")
utils_dir.mkdir(exist_ok=True)
core_dir = utils_dir / "midrc_react" / "core"
core_dir.mkdir(parents=True, exist_ok=True)
# Get Stratified Sampling files
files = ["stratified_sampling.py", "CONFIG.py", "data_preprocessing.py"]
for f in files:
    urllib.request.urlretrieve(f"{url}/{f}", utils_dir / f)
# Get MIDRC-REACT files
url = "https://raw.githubusercontent.com/MIDRC/MIDRC-REACT/refs/heads/master/src/midrc_react/core"
files = ["aggregate_jsd_calc.py", "data_preprocessing.py"]
for f in files:
    urllib.request.urlretrieve(f"{url}/{f}", core_dir / f)
sys.path.insert(0, str(utils_dir))
from midrc_react.core.aggregate_jsd_calc import calc_jsd_by_features_2df, calc_jsd_by_features
from CONFIG import SamplingData
from stratified_sampling import stratified_sampling
Now that we have all the functions and packages we'll need, we can start working with the data. Ensure that the downloaded dataset is saved in the current working directory (or modify the file path in the "read_excel" function below to the correct path).
The data needs to be preprocessed in a way that allows for better visualization and for stratification; specifically, numeric variables must be binned into distinct categories. In our dataset, age is the only continuous numeric variable, so we provide ~10-year age bins starting at adulthood (age 18). Additionally, our dataset has information about imaging for each patient in the "modality" column, indicating whether the patient has CR, DX, CT, and/or MRI data available. In this example, we will suppose that we are only interested in whether the patient has DX or CT images, and we provide an additional column in a more functional format, "modality_categorized", that indicates which of these two imaging modalities are present.
all_data = pd.read_excel('MIDRC_Stratified_Sampling_Example_5000_Patient_Subset.xlsx')
all_data = all_data.drop(columns=['dataset', 'batch'])
all_data = all_data.rename(columns={"covid19_positive":"disease_status"})
age_groups = [0, 17, 29, 39, 49, 59, 69, 79, 100]
age_labels = ['0-17', '18-29', '30-39', '40-49', '50-59', '60-69', '70-79', '80+']
all_data['age_group'] = pd.cut(all_data['age_at_index'], bins=age_groups, labels=age_labels)
s1 = "CT"
s2 = "DX"
m1 = all_data["modality"].fillna("").str.contains(s1, case=False, na=False)
m2 = all_data["modality"].fillna("").str.contains(s2, case=False, na=False)
all_data["modality_categorized"] = np.select([~m1&~m2, m1&~m2, ~m1&m2, m1&m2], ["neither", f"only_{s1}", f"only_{s2}", "both"], default="neither")
print(all_data.head())
submitter_id age_at_index disease_status ethnicity \
0 case_0001 27 Yes Not Hispanic or Latino
1 case_0002 84 Yes Not Hispanic or Latino
2 case_0003 65 Yes Not Hispanic or Latino
3 case_0004 69 Yes Not Hispanic or Latino
4 case_0005 37 Yes Not Hispanic or Latino
race sex modality age_group modality_categorized
0 Black or African American Female CR 18-29 neither
1 White Male DX 80+ only_DX
2 White Male CR 60-69 neither
3 White Male DX 60-69 only_DX
4 Black or African American Male DX 30-39 only_DX
First, let's take a look at the distributions of different variables in our subset using the pie chart function we defined before, noting the overall distribution of each variable.
columns = ['age_group', 'race', 'ethnicity', 'sex', 'disease_status', 'modality_categorized']
pie_chart_plots(all_data, columns)
plt.show()
The goal of this tutorial is to evaluate how interoperating datasets can improve representativeness, so we need to simulate interoperating datasets. Thus, we divide our database into 2 subsets, one with demographic features and one with disease and imaging data, simulating 2 interoperating sites. We'll call these df_site1 (holding variables related to demographics, Site 1) and df_site2 (holding the disease and imaging data variables, Site 2).
df_site1 = all_data[["submitter_id", "age_at_index", "ethnicity", "race", "sex", "age_group"]]
df_site2 = all_data[["submitter_id", "disease_status", "modality_categorized"]]
Let's focus on just the demographic data site (Site 1) to begin. Suppose we are interested in dividing a database into training and testing subsets for training a machine learning algorithm. One way to split data is through stratified sampling, in which the distribution of selected variables is matched across data partitions. MIDRC has developed a generalized stratified sampling algorithm and tool to perform such data splitting; we've already pulled some functions from the associated GitHub repository, so we'll use those to split the data from Site 1 into training and testing subsets of 80% and 20%, respectively.
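Before calling the MIDRC tool, it helps to see the core idea in isolation: group the records by the joint value of the stratification variables, then split each stratum proportionally. The sketch below uses hypothetical toy data and a simplified helper (`simple_stratified_split` is illustrative only; the MIDRC tool's actual algorithm may differ in details such as rounding and tie-breaking):

```python
import numpy as np
import pandas as pd

def simple_stratified_split(df, strat_cols, test_frac=0.2, seed=0):
    """Assign each row to Training/Testing so that every joint stratum
    contributes ~test_frac of its rows to the Testing set."""
    rng = np.random.default_rng(seed)
    test_idx = []
    for _, grp in df.groupby(strat_cols, observed=True):
        n_test = int(round(len(grp) * test_frac))
        test_idx.extend(rng.choice(grp.index, size=n_test, replace=False))
    out = df.copy()
    out["dataset"] = np.where(out.index.isin(test_idx), "Testing", "Training")
    return out

# Hypothetical toy data: 100 patients, two stratification variables
toy = pd.DataFrame({
    "sex": ["Male", "Female"] * 50,
    "age_group": ["18-29"] * 40 + ["30-39"] * 60,
})
split = simple_stratified_split(toy, ["sex", "age_group"])
print(split["dataset"].value_counts())  # 80 Training / 20 Testing
```

Because each stratum is split at the same 80/20 ratio, the joint distribution of the stratification variables is (up to rounding) preserved in both subsets.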
sampling_data = SamplingData(
    filename=None,
    dataset_column="dataset",
    features=["age_at_index", "ethnicity", "race", "sex"],
    title="Site 1 Stratified Sampling",
    datasets={"Training":0.8, "Testing":0.2},
    numeric_cols={'age_at_index':{'bins':[0, 18, 30, 40, 50, 60, 70, 80, 1000], 'labels':['0-17','18-29','30-39','40-49','50-59','60-69','70-79','80+']}},
    uid_col="submitter_id")
df_site1_strat = stratified_sampling(df_site1, sampling_data)
print(df_site1_strat.head())
D:\Research\MIDRC\MIDRC-Learn\MIDRC-Learn_JNotebooks\utilities\stratified_sampling.py:99: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_in[uid_col] = data_in[uid_col].astype(str)
D:\Research\MIDRC\MIDRC-Learn\MIDRC-Learn_JNotebooks\utilities\stratified_sampling.py:107: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_in[col_name] = pd.to_numeric(data_in[col_name], errors='coerce')
D:\Research\MIDRC\MIDRC-Learn\MIDRC-Learn_JNotebooks\utilities\stratified_sampling.py:109: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_in[col_name] = data_in[col_name].astype(str)
submitter_id age_at_index ethnicity \
0 case_0001 27 Not Hispanic or Latino
1 case_0002 84 Not Hispanic or Latino
2 case_0003 65 Not Hispanic or Latino
3 case_0004 69 Not Hispanic or Latino
4 case_0005 37 Not Hispanic or Latino
race sex age_group dataset
0 Black or African American Female 18-29 Training
1 White Male 80+ Training
2 White Male 60-69 Training
3 White Male 60-69 Training
4 Black or African American Male 30-39 Training
Now let's plot the demographic variable distributions for the training subset, testing subset, and full dataset so that we can compare distributions.
columns = ['age_group', 'race', 'ethnicity', 'sex']
split_col = "dataset"
v1, v2 = "Training", "Testing"
groups = [
    (f"{split_col}={v1}", df_site1_strat[df_site1_strat[split_col] == v1]),
    (f"{split_col}={v2}", df_site1_strat[df_site1_strat[split_col] == v2]),
    ("ALL", df_site1_strat)
]
pie_chart_plots2(columns, groups)
plt.tight_layout()
plt.show()
Note that, for the most part, the training and testing subsets are very similar to each other and to the overall dataset. However, let's look at the distributions of the non-demographic (Site 2) variables that resulted from sampling over our demographic (Site 1) variables.
df_site2_s1strat = df_site2.copy()
df_site2_s1strat["dataset"] = df_site1_strat["dataset"].values
columns = ["modality_categorized", "disease_status"]
split_col = "dataset"
v1, v2 = "Training", "Testing"
groups = [
    (f"{split_col}={v1}", df_site2_s1strat[df_site2_s1strat[split_col] == v1]),
    (f"{split_col}={v2}", df_site2_s1strat[df_site2_s1strat[split_col] == v2]),
    ("ALL", df_site2_s1strat),
]
pie_chart_plots2(columns, groups)
plt.tight_layout()
plt.show()
Notably, the non-demographic (Site 2) variables do not achieve the same similarity as the demographic (Site 1) variables because they were not considered in the data stratification. It is important to note that these variables are not necessarily matched worse than the demographic variables; rather, excluding them from the stratification increases the likelihood of a mismatch.
We can quantify the similarity between subsets via the Jensen-Shannon distance (JSD). Smaller JSD values indicate better matching across subsets. In this tutorial, the JSD values we show compare the Training and Testing subsets, ignoring the comparison to the overall dataset.
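For intuition, the JSD between two categorical distributions can be computed directly with SciPy's `scipy.spatial.distance.jensenshannon`, which returns the distance (the square root of the Jensen-Shannon divergence). A small sketch with hypothetical category counts (this is an illustration, not the MIDRC implementation, which may normalize or aggregate differently):

```python
import numpy as np
from scipy.spatial.distance import jensenshannon

# Hypothetical category counts for one variable in two subsets
train_counts = np.array([400, 300, 200, 100])
test_counts = np.array([105, 70, 50, 25])

# Normalize to probability distributions over the same category order
p = train_counts / train_counts.sum()
q = test_counts / test_counts.sum()

# With base=2 the distance ranges from 0 (identical) to 1 (maximally different)
jsd = jensenshannon(p, q, base=2)
print(f"JSD = {jsd:.4f}")
```

Here the two subsets have nearly identical category proportions, so the JSD is close to zero; as the proportions diverge, the JSD grows toward 1.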
columns = ["age_group", "race", "ethnicity", "sex", "modality_categorized", "disease_status"]
df_site1_strat["modality_categorized"] = df_site2["modality_categorized"].values
df_site1_strat["disease_status"] = df_site2["disease_status"].values
jsd1 = {}
for c in columns:
    jsd1[c] = calc_jsd_by_features_2df(df_site1_strat[df_site1_strat["dataset"]=="Training"], df_site1_strat[df_site1_strat["dataset"]=="Testing"], cols_to_use=[c])
    print(f"JSD for feature {c}:{jsd1[c]}")
radar_plot(jsd_dicts=[jsd1], labels=["Site 1 Stratification"], features=columns, title="JSD by feature", rmax=max(jsd1.values())*1.25)
plt.show()
JSD for feature age_group:0.010511574286073177
JSD for feature race:0.018233603104896908
JSD for feature ethnicity:0.013489099999595546
JSD for feature sex:0.00916418745943001
JSD for feature modality_categorized:0.022187360048203108
JSD for feature disease_status:0.03013432770823218
As expected, we see larger JSD values for the non-demographic variables (modality_categorized, disease_status) compared to the demographic variables.
Now, we will repeat this process but swap the Site 1 and Site 2 data (i.e., this time performing stratified sampling on the Site 2 variables).
sampling_data = SamplingData(
    filename=None,
    dataset_column="dataset",
    features=["modality_categorized", "disease_status"],
    title="Site 2 Stratified Sampling",
    datasets={"Training":0.8, "Testing":0.2},
    numeric_cols={},
    uid_col="submitter_id")
df_site2_strat = stratified_sampling(df_site2, sampling_data)
print(df_site2_strat.head())
  submitter_id disease_status modality_categorized   dataset
0    case_0001            Yes              neither  Training
1    case_0002            Yes              only_DX   Testing
2    case_0003            Yes              neither  Training
3    case_0004            Yes              only_DX  Training
4    case_0005            Yes              only_DX  Training
D:\Research\MIDRC\MIDRC-Learn\MIDRC-Learn_JNotebooks\utilities\stratified_sampling.py:99: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_in[uid_col] = data_in[uid_col].astype(str)
D:\Research\MIDRC\MIDRC-Learn\MIDRC-Learn_JNotebooks\utilities\stratified_sampling.py:109: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_in[col_name] = data_in[col_name].astype(str)
columns = ["modality_categorized", "disease_status"]
split_col = "dataset"
v1, v2 = "Training", "Testing"
groups = [
    (f"{split_col}={v1}", df_site2_strat[df_site2_strat[split_col] == v1]),
    (f"{split_col}={v2}", df_site2_strat[df_site2_strat[split_col] == v2]),
    ("ALL", df_site2_strat),
]
pie_chart_plots2(columns, groups)
plt.tight_layout()
plt.show()
Qualitatively, comparing these pie charts to those from the demographic-variable stratification, the Site 2 variables are now better matched across subsets. This is as expected. We'll now show the opposite effect for the Site 1 variables.
df_site1_s2strat = df_site1.copy()
df_site1_s2strat["dataset"] = df_site2_strat["dataset"].values
columns = ["age_group", "race", "ethnicity", "sex"]
split_col = "dataset"
v1, v2 = "Training", "Testing"
groups = [
    (f"{split_col}={v1}", df_site1_s2strat[df_site1_s2strat[split_col] == v1]),
    (f"{split_col}={v2}", df_site1_s2strat[df_site1_s2strat[split_col] == v2]),
    ("ALL", df_site1_s2strat),
]
pie_chart_plots2(columns, groups)
plt.tight_layout()
plt.show()
It may be difficult to evaluate all of these qualitatively, but what we should see are larger differences between the demographic variable distributions across subsets than we saw before. This tends to be more evident for variables with more possible categories (e.g., age_group and race). We'll once again use the JSD to quantify distribution similarity.
columns = ["age_group", "race", "ethnicity", "sex", "modality_categorized", "disease_status"]
df_site2_strat["age_group"] = df_site1["age_group"].values
df_site2_strat["race"] = df_site1["race"].values
df_site2_strat["ethnicity"] = df_site1["ethnicity"].values
df_site2_strat["sex"] = df_site1["sex"].values
jsd2 = {}
for c in columns:
    jsd2[c] = calc_jsd_by_features_2df(df_site2_strat[df_site2_strat["dataset"]=="Training"], df_site2_strat[df_site2_strat["dataset"]=="Testing"], cols_to_use=[c])
    print(f"JSD for feature {c}:{jsd2[c]}")
radar_plot(jsd_dicts=[jsd1, jsd2], labels=["Site 1 Stratification", "Site 2 Stratification"], features=columns, title="JSD by feature", rmax=max(max(jsd1.values()),max(jsd2.values()))*1.25)
plt.show()
JSD for feature age_group:0.03063708975700618
JSD for feature race:0.04831477567493617
JSD for feature ethnicity:0.03448631154034046
JSD for feature sex:0.010713813753663474
JSD for feature modality_categorized:0.0013471689678381568
JSD for feature disease_status:0.01123503018755874
Finally, we will show how stratifying over both sites' variables combined improves representativeness in the training/testing split.
sampling_data = SamplingData(
    filename=None,
    dataset_column="dataset",
    features=["age_group", "race", "ethnicity", "sex", "modality_categorized", "disease_status"],
    title="All Data Stratified Sampling",
    datasets={"Training":0.8, "Testing":0.2},
    numeric_cols={},
    uid_col="submitter_id")
df_all_strat = stratified_sampling(all_data, sampling_data)
print(df_all_strat.head())
submitter_id age_at_index disease_status ethnicity \
0 case_0001 27 Yes Not Hispanic or Latino
1 case_0002 84 Yes Not Hispanic or Latino
2 case_0003 65 Yes Not Hispanic or Latino
3 case_0004 69 Yes Not Hispanic or Latino
4 case_0005 37 Yes Not Hispanic or Latino
race sex modality age_group modality_categorized \
0 Black or African American Female CR 18-29 neither
1 White Male DX 80+ only_DX
2 White Male CR 60-69 neither
3 White Male DX 60-69 only_DX
4 Black or African American Male DX 30-39 only_DX
dataset
0 Training
1 Training
2 Testing
3 Training
4 Training
As before, we visualize the variable distributions for both Site 1 and Site 2.
columns = ["age_group", "race", "ethnicity", "sex", "modality_categorized", "disease_status"]
split_col = "dataset"
v1, v2 = "Training", "Testing"
groups = [
    (f"{split_col}={v1}", df_all_strat[df_all_strat[split_col] == v1]),
    (f"{split_col}={v2}", df_all_strat[df_all_strat[split_col] == v2]),
    ("ALL", df_all_strat),
]
pie_chart_plots2(columns, groups)
plt.tight_layout()
plt.show()
And finally, we quantify the comparison with the JSD.
columns = ["age_group", "race", "ethnicity", "sex", "modality_categorized", "disease_status"]
jsd3 = {}
for c in columns:
    jsd3[c] = calc_jsd_by_features_2df(df_all_strat[df_all_strat["dataset"]=="Training"], df_all_strat[df_all_strat["dataset"]=="Testing"], cols_to_use=[c])
    print(f"JSD for feature {c}:{jsd3[c]}")
radar_plot(jsd_dicts=[jsd1, jsd2, jsd3], labels=["Site 1 Stratification", "Site 2 Stratification", "All Stratification"], features=columns, title="JSD by feature", rmax=max(max(jsd1.values()),max(jsd2.values()), max(jsd3.values()))*1.25)
plt.show()
JSD for feature age_group:0.012779933563241419
JSD for feature race:0.014999511283776247
JSD for feature ethnicity:0.012365893326651741
JSD for feature sex:0.02391632106617098
JSD for feature modality_categorized:0.009426707375618812
JSD for feature disease_status:0.01120882649402947
As expected, stratifying across both sites combined achieved better average JSD values than stratifying across either of the Site 1 or Site 2 variables alone. This emphasizes that when we have more information available from multiple sites, we are able to improve the representation of different variables in our training and testing subsets; this allows for more appropriate model development and evaluation as certain populations are not unfairly weighted due to overrepresentation or underrepresentation.
However, there are a few caveats. Improving representation in this way relies on fair collection of data in the first place; if the original collection is biased, this subsequent sampling will also be biased. Additionally, the improved JSD values achieved by stratifying across all variables rather than only a subset are not guaranteed. It is quite possible that, by random chance, a better JSD could be achieved via only Site 1 or Site 2 stratification; stratifying across all site variables simply reduces the likelihood of this happening, but does not eliminate it. A clear example is that the JSD for the sex variable actually increased in this tutorial even though all the other variables decreased. You can experiment with this effect by re-running the tutorial with different random seeds (the seed is set at the top of the second code cell via np.random.seed(4); changing the number will change the seed and thus produce different results).
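The seed-to-seed variability can be illustrated without re-running the whole notebook. The sketch below uses hypothetical data and a plain random split (not the MIDRC stratified sampling tool) to show that the train/test JSD of an unstratified variable fluctuates with the seed; `split_jsd` is an illustrative helper, not part of the MIDRC codebase:

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import jensenshannon

# Hypothetical categorical variable over 1000 patients
rng_data = np.random.default_rng(0)
var = rng_data.choice(["A", "B", "C"], size=1000, p=[0.5, 0.3, 0.2])
cats = ["A", "B", "C"]

def split_jsd(seed, test_frac=0.2):
    """JSD between train/test distributions of `var` for a purely random split."""
    rng = np.random.default_rng(seed)
    mask = rng.random(len(var)) < test_frac  # ~20% of rows go to the test set
    p = pd.Series(var[~mask]).value_counts(normalize=True).reindex(cats, fill_value=0)
    q = pd.Series(var[mask]).value_counts(normalize=True).reindex(cats, fill_value=0)
    return jensenshannon(p, q, base=2)

for seed in range(3):
    print(f"seed={seed}: JSD={split_jsd(seed):.4f}")
```

Because the variable plays no role in the split, its train/test match is left to chance, which is exactly why some seeds happen to produce good matches and others do not.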
This concludes the tutorial.