{ "cells": [ { "cell_type": "markdown", "id": "c74df924", "metadata": {}, "source": [ "# 🏥 Machine Learning in Healthcare: Predicting Stent Failure\n", "## Python Companion Notebook — Mirrors the SAS Viya Case Study\n", "### Developed by Dr. Benyawarath \"Yaa\" Nithithanatchinnapat, for FINC-332 & GMBA-621\n", "---\n", "\n", "**Business Problem:** A stent procedure has a low margin for error. Each stent type, shape, and material affects patients differently depending on demographics. Can we use machine learning to predict which patients are most likely to experience stent failure — so doctors can intervene *before* something goes wrong?\n", "\n", "**Why this matters:** Accurate predictions could trigger preventative check-ups, guide stent selection for individual patients, and ultimately make cardiac procedures safer.\n", "\n", "**What we'll do in this notebook:**\n", "\n", "| Step | What | Why |\n", "|------|------|-----|\n", "| **Part 1** | Explore & visualize the data | Understand what we're working with before modeling |\n", "| **Part 2** | Clean & prepare the data | Handle missing values and create train/validation splits |\n", "| **Part 3** | Build 3 models: Decision Tree, Random Forest, Logistic Regression | Compare different approaches to the same prediction task |\n", "| **Part 4** | Compare model performance | Pick the best model using metrics that matter in healthcare |\n", "\n", "> **🔗 SAS ↔ Python Connection:** Each section below maps directly to the SAS Viya exercise you completed. Same data, same logic — different tool. This shows you that analytical thinking transfers across platforms.\n" ] }, { "cell_type": "markdown", "id": "3e79dce5", "metadata": {}, "source": [ "## 📦 Setup: Install & Import Libraries\n", "\n", "Before we dive in, let's load the tools we'll need. 
Think of these as the Python equivalents of SAS Visual Analytics and Visual Data Mining modules.\n", "\n", "| Python Library | SAS Equivalent | Purpose |\n", "|---|---|---|\n", "| `pandas` | SAS DATA step | Data manipulation |\n", "| `matplotlib` / `seaborn` | SAS Visual Analytics | Visualization |\n", "| `scikit-learn` | SAS Visual Data Mining & ML | Model building & evaluation |\n", "\n", "> If an import fails, install the missing package first (for example, `pip install pandas matplotlib seaborn scikit-learn openpyxl`). `openpyxl` is the engine `pd.read_excel` uses to read `.xlsx` files.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "76fca4c6", "metadata": {}, "outputs": [], "source": [ "# ─── Core Libraries ───────────────────────────────────────────\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "# ─── Scikit-Learn: Preprocessing ──────────────────────────────\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder, StandardScaler\n", "from sklearn.impute import SimpleImputer\n", "\n", "# ─── Scikit-Learn: Models ─────────────────────────────────────\n", "from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "# ─── Scikit-Learn: Evaluation ─────────────────────────────────\n", "from sklearn.metrics import (\n", "    accuracy_score, classification_report, confusion_matrix,\n", "    roc_auc_score, roc_curve, ConfusionMatrixDisplay\n", ")\n", "\n", "# ─── Display Settings ─────────────────────────────────────────\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "pd.set_option('display.max_columns', 30)\n", "pd.set_option('display.max_rows', 100)\n", "sns.set_style('whitegrid')\n", "plt.rcParams['figure.figsize'] = (10, 6)\n", "plt.rcParams['font.size'] = 11\n", "\n", "print(\"✅ All libraries loaded successfully!\")\n" ] }, { "cell_type": "markdown", "id": "31f9ccbd", "metadata": {}, "source": [ "---\n", "## Part 1: Exploration 🔍\n", "**SAS Equivalent:** *Explore and Visualize 
→ Start with Data → Auto-chart and Correlation Matrix*\n", "\n", "Before we build any models, we need to understand our data. This is the \"look before you leap\" step. In SAS, you used the auto-chart feature and drag-and-drop visualizations. In Python, we'll use `pandas` and `seaborn` to accomplish the same thing.\n", "\n", "> **💡 Rule of thumb:** Never build a model on data you haven't explored first. \"Garbage in, garbage out\" applies everywhere.\n" ] }, { "cell_type": "markdown", "id": "596e9769", "metadata": {}, "source": [ "### Step 1.1: Load the Data\n", "\n", "📌 **Action Required:** Update the file path below to match where you saved the Excel file.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ffb36326", "metadata": {}, "outputs": [], "source": [ "# ─── Load the dataset ─────────────────────────────────────────\n", "# UPDATE THIS PATH to where you saved the Stent Failure Excel file\n", "df = pd.read_excel('Stent_Failure.xlsx')\n", "\n", "# ─── Standardize column names ─────────────────────────────────\n", "# The Excel file has spaces and long names; we'll create clean\n", "# short names so the rest of the notebook runs smoothly.\n", "# Keys must match the Excel headers exactly, spelling included.\n", "rename_map = {\n", "    'Ethnicity': 'Ethnic_group',\n", "    'Stent Cell Design': 'Cell_Design',\n", "    'Stent Cell Type': 'Cell_type',\n", "    'Stent Material': 'Stent_Material',\n", "    'Multiple Stent': 'Multiple_Stent',\n", "    'Geographic Miss': 'Geographic_Miss',\n", "    'Upsteam or Downstream disease in the coronary': 'Coronary_Stream',\n", "    'Stent Failure': 'Stent_Failure',\n", "    'Hours active per week': 'Hours_Active',\n", "    'Stent Thickness (mm)': 'Stent_Thickness',\n", "    'Stent_Length (mm)': 'Stent_Length',\n", "    'Patient ID': 'Patient_ID',\n", "    'Stent Width (mm)': 'Stent_Width',\n", "    'device age (days)': 'Device_age',\n", "    'Smoker flag': 'Smoker_flag',\n", "    'Plaque Prolapse through cells of the stent': 'Plaque_Prolapse',\n", "    'Stent Failure flag': 'Stent_Failure_flag'\n", "}\n", 
"df.rename(columns=rename_map, inplace=True)\n", "\n", "# ─── Fix numeric columns stored as text (commas in numbers) ───\n", "# Some columns like Device_age have values like '2,101' which\n", "# pandas reads as text. We strip commas and convert to numeric;\n", "# errors='coerce' turns anything still unparseable into NaN.\n", "for col in ['Stent_Thickness', 'Stent_Length', 'Stent_Width', 'Device_age']:\n", "    if col in df.columns and df[col].dtype == 'object':\n", "        df[col] = pd.to_numeric(df[col].astype(str).str.replace(',', ''), errors='coerce')\n", "\n", "# Quick look at what we're working with\n", "print(f\"Dataset shape: {df.shape[0]:,} patients × {df.shape[1]} variables\")\n", "print(f\"\\n{'='*60}\")\n", "print(\"First 5 rows:\")\n", "df.head()\n" ] }, { "cell_type": "markdown", "id": "cd06dc50", "metadata": {}, "source": [ "### Step 1.2: Data Profile — The \"First Date\" with Your Data\n", "\n", "In SAS, you reviewed the **Profile** pane to check variable types, missingness, and cardinality. Here's the Python equivalent.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0d7ed7de", "metadata": {}, "outputs": [], "source": [ "# ─── Data types and missing values (like SAS Measure Details) ─\n", "print(\"=\" * 70)\n", "print(\"DATA PROFILE SUMMARY\")\n", "print(\"=\" * 70)\n", "\n", "profile = pd.DataFrame({\n", "    'Data Type': df.dtypes,\n", "    'Non-Null Count': df.count(),\n", "    'Missing Count': df.isnull().sum(),\n", "    'Missing %': (df.isnull().sum() / len(df) * 100).round(2),\n", "    'Unique Values': df.nunique()\n", "})\n", "\n", "print(profile.to_string())\n", "print(f\"\\n📊 Total observations: {len(df):,}\")\n", "print(f\"🔢 Numeric variables: {df.select_dtypes(include='number').shape[1]}\")\n", "print(f\"🏷️ Categorical variables: {df.select_dtypes(include='object').shape[1]}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "51e58977", "metadata": {}, "outputs": [], "source": [ "# ─── Basic descriptive statistics for numeric variables ───────\n", "df.describe().round(2)\n" ] }, { 
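"cell_type": "markdown", "id": "7a41c0e2", "metadata": {}, "source": [ "> **💡 Side note on `errors='coerce'`:** In Step 1.1, any value that still cannot be parsed as a number after the commas are stripped is converted to `NaN`, so it is counted as *missing* in the profile above. The toy values in the next cell are made up for illustration; they are not taken from `Stent_Failure.xlsx`.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7a41c0e3", "metadata": {}, "outputs": [], "source": [ "# Toy illustration (hypothetical values, not from the dataset) of the\n", "# same cleaning rule used in Step 1.1: strip commas, then coerce.\n", "import pandas as pd\n", "\n", "sample = pd.Series(['2,101', '350', 'n/a', None])\n", "cleaned = pd.to_numeric(sample.astype(str).str.replace(',', ''), errors='coerce')\n", "\n", "# '2,101' -> 2101.0 and '350' -> 350.0; 'n/a' and None cannot be parsed, so both become NaN\n", "print(cleaned.tolist())\n", "print(f\"Values converted to NaN: {cleaned.isna().sum()}\")\n" ] }, { 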
"cell_type": "markdown", "id": "1e548bc8", "metadata": {}, "source": [ "### Step 1.3: Distribution of the Target Variable — Stent Failure\n", "\n", "In SAS, you examined the frequency of Stent Failure using the auto-chart feature. Let's do the same in Python.\n", "\n", "> **🤔 Key question:** Is the target variable balanced? If one class heavily outweighs the other, that affects how we build and evaluate models.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d1aed564", "metadata": {}, "outputs": [], "source": [ "# ─── Target Variable Distribution ─────────────────────────────\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", "\n", "# Count plot. sort_index() keeps the order No/Yes (or 0/1), so green\n", "# always marks 'No Failure' and red marks 'Failure', whichever is larger\n", "stent_counts = df['Stent_Failure'].value_counts().sort_index()\n", "colors = ['#2ecc71', '#e74c3c']\n", "stent_counts.plot(kind='bar', ax=axes[0], color=colors, edgecolor='black', alpha=0.85)\n", "axes[0].set_title('Distribution of Stent Failure', fontsize=14, fontweight='bold')\n", "axes[0].set_xlabel('Stent Failure')\n", "axes[0].set_ylabel('Count')\n", "axes[0].tick_params(axis='x', rotation=0)\n", "\n", "# Add count labels on bars\n", "for i, (val, count) in enumerate(stent_counts.items()):\n", "    axes[0].text(i, count + 50, f'{count:,}\\n({count/len(df)*100:.1f}%)',\n", "                 ha='center', fontweight='bold', fontsize=11)\n", "\n", "# Pie chart\n", "stent_counts.plot(kind='pie', ax=axes[1], colors=colors, autopct='%1.1f%%',\n", "                  startangle=90, textprops={'fontsize': 12})\n", "axes[1].set_ylabel('')\n", "axes[1].set_title('Stent Failure Proportion', fontsize=14, fontweight='bold')\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(f\"\\n📌 Class balance: {stent_counts.to_dict()}\")\n" ] }, { "cell_type": "markdown", "id": "9605d42e", "metadata": {}, "source": [ "### Step 1.4: Explore Key Variables\n", "\n", "In SAS, you examined Ethnicity, Gender, Hospital, and their relationships with Stent Failure. 
Let's recreate those visualizations.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "371ab1f5", "metadata": {}, "outputs": [], "source": [ "# ─── Stent Failure by Ethnicity (SAS: Frequency of Ethnicity) ─\n", "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n", "\n", "# Ethnicity distribution\n", "eth_order = df['Ethnic_group'].value_counts().index\n", "sns.countplot(data=df, y='Ethnic_group', order=eth_order, ax=axes[0],\n", "              palette='viridis', edgecolor='black', alpha=0.85)\n", "axes[0].set_title('Distribution of Ethnicity', fontsize=13, fontweight='bold')\n", "axes[0].set_xlabel('Count')\n", "axes[0].set_ylabel('')\n", "\n", "# Stent Failure rate BY Ethnicity\n", "failure_rate = df.groupby('Ethnic_group')['Stent_Failure'].apply(\n", "    lambda x: (x == 'Yes').mean() * 100 if x.dtype == 'object'\n", "    else x.mean() * 100\n", ").sort_values(ascending=False)\n", "\n", "failure_rate.plot(kind='barh', ax=axes[1], color='#e74c3c', edgecolor='black', alpha=0.85)\n", "axes[1].set_title('Stent Failure Rate (%) by Ethnicity', fontsize=13, fontweight='bold')\n", "axes[1].set_xlabel('Failure Rate (%)')\n", "axes[1].set_ylabel('')\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(\"\\n📌 Stent Failure Rate by Ethnicity:\")\n", "print(failure_rate.round(2).to_string())\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b8674605", "metadata": {}, "outputs": [], "source": [ "# ─── Stent Failure by Gender and Hospital ─────────────────────\n", "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "\n", "# By Gender\n", "gender_fail = pd.crosstab(df['Gender'], df['Stent_Failure'], normalize='index') * 100\n", "gender_fail.plot(kind='bar', stacked=True, ax=axes[0], color=colors, edgecolor='black', alpha=0.85)\n", "axes[0].set_title('Stent Failure by Gender', fontsize=13, fontweight='bold')\n", "axes[0].set_ylabel('Percentage (%)')\n", "axes[0].tick_params(axis='x', rotation=0)\n", "axes[0].legend(title='Stent Failure', loc='upper right')\n", 
"\n", "# By Hospital\n", "hosp_fail = pd.crosstab(df['Hospital'], df['Stent_Failure'], normalize='index') * 100\n", "hosp_fail.plot(kind='bar', stacked=True, ax=axes[1], color=colors, edgecolor='black', alpha=0.85)\n", "axes[1].set_title('Stent Failure by Hospital', fontsize=13, fontweight='bold')\n", "axes[1].set_ylabel('Percentage (%)')\n", "axes[1].tick_params(axis='x', rotation=45)\n", "axes[1].legend(title='Stent Failure', loc='upper right')\n", "\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "46efddd0", "metadata": {}, "source": [ "### Step 1.5: Correlation Matrix\n", "\n", "In SAS, you built a correlation matrix with Stent Failure Flag on the Y axis. Python's `seaborn` heatmap gives us the same view.\n", "\n", "> **📌 SAS Solutions Reference:** The strongest correlation with Stent Failure was **Plaque Prolapse through cells of the stent**.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f3d8bb7a", "metadata": {}, "outputs": [], "source": [ "# ─── Correlation Matrix (SAS: Between Two Sets of Measures) ───\n", "numeric_cols = df.select_dtypes(include='number').columns.tolist()\n", "\n", "# Drop ID columns such as Patient_ID (identifiers, not predictors).\n", "# Match whole names: a substring test for 'id' would also drop Stent_Width.\n", "id_cols = [c for c in numeric_cols if c.lower() == 'id' or c.lower().endswith('_id')]\n", "numeric_for_corr = [c for c in numeric_cols if c not in id_cols]\n", "\n", "corr_matrix = df[numeric_for_corr].corr()\n", "\n", "plt.figure(figsize=(12, 8))\n", "mask = np.triu(np.ones_like(corr_matrix, dtype=bool))\n", "sns.heatmap(corr_matrix, mask=mask, annot=True, fmt='.2f', cmap='RdBu_r',\n", "            center=0, vmin=-1, vmax=1, square=True, linewidths=0.5,\n", "            cbar_kws={'label': 'Correlation Coefficient'})\n", "plt.title('Correlation Matrix — Numeric Variables', fontsize=14, fontweight='bold')\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "# Highlight correlations with the target (Stent_Failure_flag)\n", "if 'Stent_Failure_flag' in numeric_for_corr:\n", "    print(\"\\n📌 Correlations with Stent Failure Flag (sorted by strength):\")\n", 
"    target_corr = corr_matrix['Stent_Failure_flag'].drop('Stent_Failure_flag').abs().sort_values(ascending=False)\n", "    for var in target_corr.index:\n", "        direction = corr_matrix.loc[var, 'Stent_Failure_flag']\n", "        arrow = '↑' if direction > 0 else '↓'\n", "        print(f\"  {arrow} {var}: {direction:.3f}\")\n" ] }, { "cell_type": "markdown", "id": "26265d6e", "metadata": {}, "source": [ "### 🔑 Part 1 Checkpoint — What Did We Learn?\n", "\n", "Before moving on, let's summarize what the exploration revealed:\n", "\n", "| Finding | Business Implication |\n", "|---------|---------------------|\n", "| **Plaque Prolapse** has the strongest relationship with Stent Failure | This should be a top predictor in our models |\n", "| Some ethnic groups show higher failure rates | Models should account for demographic differences |\n", "| Missing values exist in Stent Thickness, Width, and Length | We need to handle these before certain models (especially Logistic Regression) |\n", "| The target may be imbalanced | We should use metrics beyond accuracy (like AUC and recall) |\n", "\n", "> **🔗 SAS Connection:** These findings should match what you discovered using the auto-chart and Automated Explanation features in SAS Visual Analytics.\n" ] }, { "cell_type": "markdown", "id": "cfc60870", "metadata": {}, "source": [ "---\n", "## Part 2: Transform & Modify 🛠️\n", "**SAS Equivalent:** *Data pane → Measure Details → New Data Item → Partition Variable*\n", "\n", "Remember: **garbage in, garbage out.** This section mirrors the SAS steps where you checked for missing values, created imputed variables, and set up a 70/30 train-validation partition.\n" ] }, { "cell_type": "markdown", "id": "f58d935a", "metadata": {}, "source": [ "### Step 2.1: Handle Missing Values\n", "\n", "In SAS, you used the **Measure Details** window to find missing values, then created calculated items (Thickness_Imp, Width_Imp, Length_Imp) 
using IF-ELSE logic. Python gives us the same control.\n", "\n", "> **📌 SAS Solutions Reference:** The variables with missing values were **Stent_Length, Stent_Width, Stent_Thickness** (all numeric/measures).\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2f09cc2a", "metadata": {}, "outputs": [], "source": [ "# ─── Identify Missing Values (SAS: View Measure Details) ──────\n", "print(\"=\" * 60)\n", "print(\"MISSING VALUE REPORT\")\n", "print(\"=\" * 60)\n", "\n", "missing = df.isnull().sum()\n", "missing_pct = (missing / len(df) * 100).round(2)\n", "missing_report = pd.DataFrame({\n", "    'Missing Count': missing,\n", "    'Missing %': missing_pct,\n", "    'Data Type': df.dtypes\n", "}).sort_values('Missing Count', ascending=False)\n", "\n", "# Show only variables WITH missing values\n", "has_missing = missing_report[missing_report['Missing Count'] > 0]\n", "\n", "if len(has_missing) > 0:\n", "    print(\"\\n🚨 Variables with missing values:\")\n", "    print(has_missing.to_string())\n", "else:\n", "    print(\"\\n✅ No missing values found!\")\n", "\n", "print(f\"\\n📊 Total rows: {len(df):,}\")\n", "print(f\"📊 Rows with ANY missing value: {df.isnull().any(axis=1).sum():,}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1882f55e", "metadata": {}, "outputs": [], "source": [ "# ─── Impute Missing Values ─────────────────────────────────────\n", "# In SAS, you created Thickness_Imp, Width_Imp, Length_Imp\n", "# replacing missing values with 0 using IF-ELSE logic.\n", "#\n", "# Python equivalent: fillna(0) or SimpleImputer\n", "# We'll use the SAME approach as the SAS exercise (fill with 0)\n", "# to keep results comparable.\n", "\n", "# Store which columns had missing values\n", "cols_with_missing = df.columns[df.isnull().any()].tolist()\n", "print(f\"Columns to impute: {cols_with_missing}\")\n", "\n", "# Create imputed versions (matching SAS approach: fill with 0)\n", "for col in cols_with_missing:\n", " imp_name = f\"{col}_Imp\"\n", "
df[imp_name] = df[col].fillna(0)\n", " print(f\" ✅ Created '{imp_name}' — {df[col].isnull().sum()} missing values replaced with 0\")\n", "\n", "# Verify no missing values remain in imputed columns\n", "imp_cols = [f\"{c}_Imp\" for c in cols_with_missing]\n", "print(f\"\\n✅ Missing values in imputed columns: {df[imp_cols].isnull().sum().sum()}\")\n" ] }, { "cell_type": "markdown", "id": "9b8e03bf", "metadata": {}, "source": [ "### Step 2.2: Prepare Features and Target\n", "\n", "Now we'll set up our predictor (X) and target (y) variables, using the same variable list as the SAS exercise.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c8c4394c", "metadata": {}, "outputs": [], "source": [ "# ─── Define Target and Predictors ──────────────────────────────\n", "# Target: Stent_Failure (binary: Yes/No)\n", "# We'll encode it as 1 (Failure) / 0 (No Failure)\n", "\n", "# Encode the target variable\n", "if df['Stent_Failure'].dtype == 'object':\n", "    df['Target'] = (df['Stent_Failure'] == 'Yes').astype(int)\n", "elif 'Stent_Failure_flag' in df.columns:\n", "    df['Target'] = df['Stent_Failure_flag'].astype(int)\n", "else:\n", "    df['Target'] = df['Stent_Failure'].astype(int)\n", "\n", "print(\"Target encoding: 0 = No Failure, 1 = Failure\")\n", "print(f\"Target distribution: \\n{df['Target'].value_counts().to_string()}\")\n", "\n", "# ─── Define predictor variables (matching the SAS exercise) ───\n", "# These are the SAME predictors listed in the SAS case study\n", "\n", "# Categorical predictors\n", "cat_predictors = [\n", "    'Diabetic', 'Ethnic_group', 'Gender', 'Geographic_Miss',\n", "    'Hospital', 'Multiple_Stent', 'Cell_Design', 'Cell_type',\n", "    'Stent_Material', 'Coronary_Stream'\n", "]\n", "\n", "# Numeric predictors\n", "num_predictors = [\n", "    'Age', 'Device_age', 'Hours_Active', 'Plaque_Prolapse', 'Smoker_flag'\n", "]\n", "\n", "# For tree-based models (Decision Tree, Random Forest):\n", "# use the original 
columns\n", "num_predictors_with_original = num_predictors + [\n", "    'Stent_Thickness', 'Stent_Width', 'Stent_Length'\n", "]\n", "\n", "# For Logistic Regression (scikit-learn's LogisticRegression raises an\n", "# error on missing values): use the imputed columns\n", "num_predictors_with_imputed = num_predictors.copy()\n", "for col in cols_with_missing:\n", "    if col in ['Stent_Thickness', 'Stent_Width', 'Stent_Length']:\n", "        num_predictors_with_imputed.append(f\"{col}_Imp\")\n", "\n", "# Filter to columns that actually exist in the dataframe\n", "cat_predictors = [c for c in cat_predictors if c in df.columns]\n", "num_predictors_with_original = [c for c in num_predictors_with_original if c in df.columns]\n", "num_predictors_with_imputed = [c for c in num_predictors_with_imputed if c in df.columns]\n", "\n", "print(f\"\\n📋 Categorical predictors ({len(cat_predictors)}): {cat_predictors}\")\n", "print(f\"📋 Numeric predictors — original ({len(num_predictors_with_original)}): {num_predictors_with_original}\")\n", "print(f\"📋 Numeric predictors — imputed ({len(num_predictors_with_imputed)}): {num_predictors_with_imputed}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0ed4e3ef", "metadata": {}, "outputs": [], "source": [ "# ─── Encode Categorical Variables ──────────────────────────────\n", "# SAS handles categoricals natively; Python needs encoding.\n", "# We'll use one-hot encoding (pd.get_dummies), which creates a\n", "# binary column for each category level except the first; with\n", "# drop_first=True that first level becomes the reference level.\n", "\n", "# Create a working copy for modeling\n", "df_model = df.copy()\n", "\n", "# One-hot encode categorical predictors\n", "df_encoded = pd.get_dummies(df_model[cat_predictors], drop_first=True, dtype=int)\n", "print(f\"Categorical encoding created {df_encoded.shape[1]} binary columns from {len(cat_predictors)} variables\")\n", "print(\"\\nExample encoded columns (first 10):\")\n", "print(df_encoded.columns[:10].tolist())\n" ] }, { "cell_type": "markdown", "id": "e3fd7d0f", "metadata": {}, "source": [ "### Step 
2.3: Create Train/Validation Split\n", "\n", "In SAS, you created a partition variable with **70% training / 30% validation** using random seed **772020**. We'll use the same ratio and seed here. Python's random number generator is not the same as SAS's, though, so the individual patients in each partition (and therefore the exact metrics) will differ somewhat from your SAS results.\n", "\n", "> **Why partition?** We train the model on 70% of the data, then test it on the remaining 30% it has never seen. This is our \"honest assessment\" — it estimates how the model will perform on new patients.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b2f2c65d", "metadata": {}, "outputs": [], "source": [ "# ─── Train/Validation Split (SAS: 70/30, seed 772020) ─────────\n", "\n", "# ── Build feature matrices ────────────────────────────────────\n", "\n", "# For Decision Tree and Random Forest. Recent scikit-learn versions\n", "# (1.3+ for single trees, 1.4+ for random forests) can split on NaN\n", "# natively, but older versions cannot, so we fill with 0 for safety\n", "# (the same rule as the SAS *_Imp columns).\n", "X_tree = pd.concat([\n", "    df_model[num_predictors_with_original].fillna(0),\n", "    df_encoded\n", "], axis=1)\n", "\n", "# For Logistic Regression (uses imputed columns)\n", "X_logreg = pd.concat([\n", "    df_model[num_predictors_with_imputed],\n", "    df_encoded\n", "], axis=1)\n", "\n", "y = df_model['Target']\n", "\n", "# ── Split: 70% train / 30% validation (same seed value as SAS) ─\n", "X_train, X_val, y_train, y_val = train_test_split(\n", "    X_tree, y,\n", "    test_size=0.30,\n", "    random_state=772020,  # Same seed value; rows will not match SAS exactly\n", "    stratify=y            # Maintain class proportions\n", ")\n", "\n", "# Also split the logistic regression features using the same indices\n", "X_train_lr = X_logreg.loc[X_train.index]\n", "X_val_lr = X_logreg.loc[X_val.index]\n", "\n", "print(f\"{'='*50}\")\n", "print(f\"PARTITION SUMMARY (70/30 split, seed=772020)\")\n", "print(f\"{'='*50}\")\n", "print(f\"Training set: {len(X_train):,} observations ({len(X_train)/len(df)*100:.1f}%)\")\n", "print(f\"Validation set: {len(X_val):,} observations ({len(X_val)/len(df)*100:.1f}%)\")\n", "print(f\"\\nTarget distribution in Training:\")\n", "print(f\" No Failure: {(y_train==0).sum():,} 
({(y_train==0).mean()*100:.1f}%)\")\n", "print(f\" Failure: {(y_train==1).sum():,} ({(y_train==1).mean()*100:.1f}%)\")\n", "print(f\"\\nTarget distribution in Validation:\")\n", "print(f\" No Failure: {(y_val==0).sum():,} ({(y_val==0).mean()*100:.1f}%)\")\n", "print(f\" Failure: {(y_val==1).sum():,} ({(y_val==1).mean()*100:.1f}%)\")\n" ] }, { "cell_type": "markdown", "id": "21b403ed", "metadata": {}, "source": [ "### 🔑 Part 2 Checkpoint\n", "\n", "| What we did | SAS Equivalent | Why it matters |\n", "|-------------|----------------|----------------|\n", "| Identified missing values | Measure Details pane | Know what needs fixing |\n", "| Imputed with 0 | Calculated Item (IF-ELSE) | Keeps all observations for regression |\n", "| One-hot encoded categoricals | Handled automatically in SAS | Python needs explicit encoding |\n", "| 70/30 stratified split (seed 772020) | Partition variable | Honest assessment on unseen data |\n" ] }, { "cell_type": "markdown", "id": "bbac7b5c", "metadata": {}, "source": [ "---\n", "## 🔧 Helper Function: Model Evaluation\n", "\n", "Before building models, let's create a reusable evaluation function. 
This will generate the same metrics you examined in SAS: confusion matrix, misclassification rate, ROC/AUC, and KS Statistic.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "20f899b9", "metadata": {}, "outputs": [], "source": [ "def evaluate_model(model, X_train, y_train, X_val, y_val, model_name,\n", " threshold=0.3):\n", " \"\"\"\n", " Comprehensive model evaluation matching SAS Visual Analytics metrics.\n", " \n", " Uses a prediction cutoff of 0.3 (matching the SAS case study)\n", " instead of the default 0.5 — because in healthcare, we'd rather\n", " flag MORE potential failures than miss them.\n", " \n", " Parameters\n", " ----------\n", " threshold : float\n", " Prediction cutoff (SAS default in this exercise: 0.3)\n", " \"\"\"\n", " results = {'Model': model_name}\n", " \n", " # ── Get probability predictions ──────────────────────────\n", " y_train_prob = model.predict_proba(X_train)[:, 1]\n", " y_val_prob = model.predict_proba(X_val)[:, 1]\n", " \n", " # ── Apply custom threshold (SAS: prediction cutoff = 0.3) ─\n", " y_train_pred = (y_train_prob >= threshold).astype(int)\n", " y_val_pred = (y_val_prob >= threshold).astype(int)\n", " \n", " # ── Confusion Matrix (SAS: Misclassification Chart) ──────\n", " fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", " \n", " # Training confusion matrix\n", " cm_train = confusion_matrix(y_train, y_train_pred)\n", " ConfusionMatrixDisplay(cm_train, display_labels=['No Failure', 'Failure']).plot(\n", " ax=axes[0], cmap='Blues', colorbar=False)\n", " axes[0].set_title(f'{model_name}\\nTraining Confusion Matrix', fontweight='bold')\n", " \n", " # Validation confusion matrix\n", " cm_val = confusion_matrix(y_val, y_val_pred)\n", " ConfusionMatrixDisplay(cm_val, display_labels=['No Failure', 'Failure']).plot(\n", " ax=axes[1], cmap='Oranges', colorbar=False)\n", " axes[1].set_title(f'{model_name}\\nValidation Confusion Matrix', fontweight='bold')\n", " \n", " # ── ROC Curve (SAS: ROC Chart) 
───────────────────────────\n", " fpr_train, tpr_train, _ = roc_curve(y_train, y_train_prob)\n", " fpr_val, tpr_val, _ = roc_curve(y_val, y_val_prob)\n", " auc_train = roc_auc_score(y_train, y_train_prob)\n", " auc_val = roc_auc_score(y_val, y_val_prob)\n", " \n", " axes[2].plot(fpr_train, tpr_train, 'b-', label=f'Training AUC = {auc_train:.4f}', linewidth=2)\n", " axes[2].plot(fpr_val, tpr_val, 'r--', label=f'Validation AUC = {auc_val:.4f}', linewidth=2)\n", " axes[2].plot([0, 1], [0, 1], 'k:', alpha=0.5, label='Random (AUC = 0.5)')\n", " axes[2].set_xlabel('False Positive Rate')\n", " axes[2].set_ylabel('True Positive Rate')\n", " axes[2].set_title(f'{model_name}\\nROC Curve', fontweight='bold')\n", " axes[2].legend(loc='lower right')\n", " axes[2].set_xlim([0, 1])\n", " axes[2].set_ylim([0, 1])\n", " \n", " plt.tight_layout()\n", " plt.show()\n", " \n", " # ── KS Statistic = max(TPR - FPR) over all cutoffs (SAS: KS from ROC) ─\n", " ks_val = max(tpr_val - fpr_val)\n", " ks_train = max(tpr_train - fpr_train)\n", " \n", " # ── Misclassification Rates ──────────────────────────────\n", " misclass_train = 1 - accuracy_score(y_train, y_train_pred)\n", " misclass_val = 1 - accuracy_score(y_val, y_val_pred)\n", " \n", " # ── Validation confusion-matrix cells (at the chosen threshold) ─\n", " tn, fp, fn, tp = cm_val.ravel()\n", " \n", " # ── False Positive Rate at the chosen threshold ─────────\n", " fpr_value = fp / (fp + tn) if (fp + tn) > 0 else 0\n", " \n", " # ── Print Summary ────────────────────────────────────────\n", " print(f\"\\n{'='*60}\")\n", " print(f\"📊 {model_name} — PERFORMANCE SUMMARY (Threshold = {threshold})\")\n", " print(f\"{'='*60}\")\n", " print(f\"\\n Metric Training Validation\")\n", " print(f\" {'─'*55}\")\n", " print(f\" Misclassification Rate {misclass_train:.4f} {misclass_val:.4f}\")\n", " print(f\" Accuracy {1-misclass_train:.4f} {1-misclass_val:.4f}\")\n", " print(f\" AUC (ROC) {auc_train:.4f} {auc_val:.4f}\")\n", " print(f\" KS Statistic {ks_train:.4f} 
{ks_val:.4f}\")\n", " print(f\" False Positive Rate (FPR) — {fpr_value:.4f}\")\n", " print(f\"\\n Validation Confusion Matrix Breakdown:\")\n", " print(f\" True Negatives (correct 'No Failure'): {tn:,}\")\n", " print(f\" True Positives (correct 'Failure'): {tp:,}\")\n", " print(f\" False Positives (false alarm): {fp:,}\")\n", " print(f\" False Negatives (MISSED failures): {fn:,} ⚠️\")\n", " \n", " print(f\"\\n 📋 Classification Report (Validation):\")\n", " print(classification_report(y_val, y_val_pred,\n", " target_names=['No Failure', 'Failure'], digits=4))\n", " \n", " # Store results for comparison\n", " results['Train_Misclass'] = round(misclass_train, 4)\n", " results['Val_Misclass'] = round(misclass_val, 4)\n", " results['Train_AUC'] = round(auc_train, 4)\n", " results['Val_AUC'] = round(auc_val, 4)\n", " results['Train_KS'] = round(ks_train, 4)\n", " results['Val_KS'] = round(ks_val, 4)\n", " results['Val_FPR'] = round(fpr_value, 4)\n", " results['Val_FalseNeg'] = fn\n", " results['Val_Accuracy'] = round(1 - misclass_val, 4)\n", " \n", " return results\n", "\n", "# We'll collect results from each model here\n", "all_results = []\n", "print(\"✅ Evaluation function ready!\")\n" ] }, { "cell_type": "markdown", "id": "5afb7113", "metadata": {}, "source": [ "---\n", "## Part 3: Model Building 🤖\n", "**SAS Equivalent:** *Objects pane → Decision Tree / Logistic Regression / Forest*\n", "\n", "Now for the fun part. We'll build three models — the same ones from the SAS exercise — and evaluate each one. \n", "\n", "> **🩺 Healthcare context:** We're using a prediction cutoff of **0.3** (not the default 0.5). Why? In healthcare, missing a real stent failure (false negative) is far worse than a false alarm. 
A lower threshold means we'll flag more patients for follow-up — better safe than sorry.\n" ] }, { "cell_type": "markdown", "id": "3a2c3677", "metadata": {}, "source": [ "### Model 1: Decision Tree 🌳\n", "**SAS Equivalent:** *Objects pane → Decision Tree → Stent Failure as Response*\n", "\n", "A decision tree is like a flowchart of yes/no questions. It splits the data based on the most informative features until it reaches a prediction at each \"leaf.\"\n", "\n", "**In the SAS exercise, the key findings were:**\n", "- First split variable: **Plaque Prolapse** (< 1 or Missing vs. >= 1)\n", "- Key variables used: Plaque Prolapse, Stent Material, Geographic Miss, Stent Cell Design, Multiple Stent, Device Age\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8f5a53c6", "metadata": {}, "outputs": [], "source": [ "# ─── Model 1: Decision Tree ───────────────────────────────────\n", "# Build the tree (comparable to SAS default settings)\n", "dt_model = DecisionTreeClassifier(\n", "    max_depth=6,             # Limit depth to prevent overfitting\n", "    min_samples_leaf=50,     # Each leaf needs at least 50 patients\n", "    random_state=772020,     # Reproducibility\n", "    class_weight='balanced'  # Account for class imbalance\n", ")\n", "\n", "dt_model.fit(X_train, y_train)\n", "print(f\"✅ Decision Tree trained!\")\n", "print(f\" Tree depth: {dt_model.get_depth()}\")\n", "print(f\" Number of leaves: {dt_model.get_n_leaves()}\")\n", "print(f\" Features used: {(dt_model.feature_importances_ > 0).sum()} of {X_train.shape[1]}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f3ca0ee9", "metadata": {}, "outputs": [], "source": [ "# ─── Visualize the Decision Tree ──────────────────────────────\n", "# This is the Python equivalent of the tree diagram in SAS VA\n", "\n", "plt.figure(figsize=(24, 12))\n", "plot_tree(dt_model,\n", " feature_names=X_train.columns,\n", " class_names=['No Failure', 'Failure'],\n", " filled=True,\n", " rounded=True,\n", " fontsize=8,\n", "
max_depth=3, # Show first 3 levels for readability\n", " proportion=True)\n", "plt.title('Decision Tree — First 3 Levels\\n(Full tree may be deeper)',\n", " fontsize=16, fontweight='bold')\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "# Text representation of the tree rules\n", "print(\"\\n📜 Decision Tree Rules (first 3 levels):\")\n", "tree_rules = export_text(dt_model, feature_names=list(X_train.columns), max_depth=3)\n", "print(tree_rules)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "73a3f2f4", "metadata": {}, "outputs": [], "source": [ "# ─── Feature Importance — What's Driving the Splits? ──────────\n", "dt_importance = pd.Series(\n", " dt_model.feature_importances_, index=X_train.columns\n", ").sort_values(ascending=False)\n", "\n", "# Show top 15 most important features\n", "top_features = dt_importance.head(15)\n", "\n", "plt.figure(figsize=(10, 6))\n", "top_features.plot(kind='barh', color='#3498db', edgecolor='black', alpha=0.85)\n", "plt.xlabel('Feature Importance (Gini)')\n", "plt.title('Decision Tree — Top 15 Most Important Features', fontsize=14, fontweight='bold')\n", "plt.gca().invert_yaxis()\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(\"\\n📌 Top 10 Features:\")\n", "for i, (feat, imp) in enumerate(top_features.head(10).items(), 1):\n", " print(f\" {i}. {feat}: {imp:.4f}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7dca157e", "metadata": {}, "outputs": [], "source": [ "# ─── Evaluate Decision Tree ───────────────────────────────────\n", "dt_results = evaluate_model(dt_model, X_train, y_train, X_val, y_val,\n", " 'Decision Tree', threshold=0.3)\n", "all_results.append(dt_results)\n" ] }, { "cell_type": "markdown", "id": "109f4595", "metadata": {}, "source": [ "### Model 2: Random Forest 🌲🌲🌲\n", "**SAS Equivalent:** *Objects pane → Forest*\n", "\n", "A Random Forest is an ensemble of many decision trees, each trained on a different random (bootstrap) sample of the data and considering only a random subset of features at each split. 
Instead of relying on one tree's opinion, the forest averages the predicted probabilities of all 100 trees. This typically gives more accurate, more stable predictions than a single tree.\n", "\n", "**In the SAS exercise:**\n", "- The Forest used **all variables**\n", "- It had a higher misclassification rate than Gradient Boosting (not rebuilt in this notebook) but a lower FPR\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f2aed835", "metadata": {}, "outputs": [], "source": [ "# ─── Model 2: Random Forest ───────────────────────────────────\n", "rf_model = RandomForestClassifier(\n", " n_estimators=100, # 100 trees (SAS default was also ~100)\n", " max_depth=None, # No depth limit; min_samples_leaf still stops growth\n", " min_samples_leaf=20, # Minimum samples per leaf\n", " random_state=772020, # Reproducibility\n", " class_weight='balanced', # Account for class imbalance\n", " n_jobs=-1 # Use all CPU cores\n", ")\n", "\n", "rf_model.fit(X_train, y_train)\n", "print(\"✅ Random Forest trained!\")\n", "print(f\" Number of trees: {rf_model.n_estimators}\")\n", "print(f\" Features considered per split: {rf_model.max_features}\")\n", "print(f\" Features with importance > 0: {(rf_model.feature_importances_ > 0).sum()}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "61add18c", "metadata": {}, "outputs": [], "source": [ "# ─── Feature Importance — Random Forest ───────────────────────\n", "rf_importance = pd.Series(\n", " rf_model.feature_importances_, index=X_train.columns\n", ").sort_values(ascending=False)\n", "\n", "top_rf = rf_importance.head(15)\n", "\n", "plt.figure(figsize=(10, 6))\n", "top_rf.plot(kind='barh', color='#27ae60', edgecolor='black', alpha=0.85)\n", "plt.xlabel('Feature Importance (Mean Decrease in Impurity)')\n", "plt.title('Random Forest — Top 15 Most Important Features', fontsize=14, fontweight='bold')\n", "plt.gca().invert_yaxis()\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(\"\\n📌 Top 10 Features:\")\n", "for i, (feat, imp) in enumerate(top_rf.head(10).items(), 1):\n", " print(f\" {i}. 
{feat}: {imp:.4f}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4674ad69", "metadata": {}, "outputs": [], "source": [ "# ─── Evaluate Random Forest ───────────────────────────────────\n", "rf_results = evaluate_model(rf_model, X_train, y_train, X_val, y_val,\n", " 'Random Forest', threshold=0.3)\n", "all_results.append(rf_results)\n" ] }, { "cell_type": "markdown", "id": "2057044d", "metadata": {}, "source": [ "### Model 3: Logistic Regression 📈\n", "**SAS Equivalent:** *Objects pane → Logistic Regression → Variable Selection Method*\n", "\n", "Logistic Regression answers a simple question: **\"What's the probability this patient will have a stent failure?\"** It's one of the most interpretable models: each coefficient shows how much a variable pushes the predicted log-odds of failure up or down.\n", "\n", "**In the SAS exercise:**\n", "- Variables were chosen by the variable selection method (not all were significant)\n", "- The Fit Summary ranked variables by importance (based on p-values)\n", "- Key variables included: Plaque Prolapse, Stent Material, Geographic Miss, Stent Cell Design, Multiple Stent, Hours Active, Age, Device Age, Ethnicity, Hospital, Stent Thickness\n", "\n", "> **⚠️ Note:** scikit-learn's `LogisticRegression` does not accept missing values at all (it raises an error). 
That's why we use the **imputed** versions of Stent Thickness, Width, and Length here.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7252cd1c", "metadata": {}, "outputs": [], "source": [ "# ─── Scale numeric features for Logistic Regression ───────────\n", "# Scaling helps the lbfgs solver converge and keeps the default L2 penalty\n", "# from treating large-valued and small-valued features unevenly.\n", "# The SAS exercise did not require this step; in Python we do it explicitly.\n", "\n", "scaler = StandardScaler()\n", "X_train_lr_scaled = X_train_lr.copy()\n", "X_val_lr_scaled = X_val_lr.copy()\n", "\n", "# Scale only the numeric columns (not the one-hot encoded ones).\n", "# Fit the scaler on training data only, then reuse it on validation to avoid leakage.\n", "num_cols_in_lr = [c for c in num_predictors_with_imputed if c in X_train_lr.columns]\n", "X_train_lr_scaled[num_cols_in_lr] = scaler.fit_transform(X_train_lr[num_cols_in_lr])\n", "X_val_lr_scaled[num_cols_in_lr] = scaler.transform(X_val_lr[num_cols_in_lr])\n", "\n", "# ─── Build Logistic Regression Model ─────────────────────────\n", "lr_model = LogisticRegression(\n", " max_iter=1000, # Enough iterations to converge\n", " random_state=772020,\n", " class_weight='balanced', # Account for class imbalance\n", " solver='lbfgs',\n", " C=1.0 # Inverse regularization strength (default L2 penalty)\n", ")\n", "\n", "lr_model.fit(X_train_lr_scaled, y_train)\n", "print(\"✅ Logistic Regression trained!\")\n", "print(f\" Number of features: {X_train_lr_scaled.shape[1]}\")\n", "print(f\" Converged: {lr_model.n_iter_[0] < lr_model.max_iter}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "41f58a7e", "metadata": {}, "outputs": [], "source": [ "# ─── Variable Importance — Logistic Regression Coefficients ───\n", "# In SAS, the Fit Summary ranked variables by p-value.\n", "# In Python, we look at the absolute value of the coefficients instead.\n", "# Numeric features were standardized, so their coefficients are per one standard\n", "# deviation, while one-hot columns are per category; keep that in mind when comparing.\n", "\n", "lr_coefs = pd.Series(\n", " lr_model.coef_[0], index=X_train_lr_scaled.columns\n", ")\n", "\n", "# Sort by absolute value (most impactful predictors)\n", "lr_coefs_sorted = lr_coefs.abs().sort_values(ascending=False).head(20)\n", "\n", "# Plot with direction 
(positive = increases failure risk, negative = decreases it)\n", "top_lr_features = lr_coefs[lr_coefs_sorted.index]\n", "\n", "plt.figure(figsize=(10, 8))\n", "colors_lr = ['#e74c3c' if v > 0 else '#2ecc71' for v in top_lr_features]\n", "top_lr_features.plot(kind='barh', color=colors_lr, edgecolor='black', alpha=0.85)\n", "plt.xlabel('Coefficient Value')\n", "plt.title('Logistic Regression — Top 20 Coefficients\\n(Red = ↑ failure risk, Green = ↓ failure risk)',\n", " fontsize=14, fontweight='bold')\n", "plt.axvline(x=0, color='black', linewidth=0.5)\n", "plt.gca().invert_yaxis()\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(\"\\n📌 Top 10 Predictors (by coefficient magnitude):\")\n", "for i, feat in enumerate(lr_coefs_sorted.index[:10], 1):\n", " direction = \"↑ risk\" if lr_coefs[feat] > 0 else \"↓ risk\"\n", " print(f\" {i}. {feat}: {lr_coefs[feat]:+.4f} ({direction})\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1c61b646", "metadata": {}, "outputs": [], "source": [ "# ─── Evaluate Logistic Regression ──────────────────────────────\n", "lr_results = evaluate_model(lr_model, X_train_lr_scaled, y_train,\n", " X_val_lr_scaled, y_val,\n", " 'Logistic Regression', threshold=0.3)\n", "all_results.append(lr_results)\n" ] }, { "cell_type": "markdown", "id": "c8a47bb3", "metadata": {}, "source": [ "---\n", "## Part 4: Model Comparison 🏆\n", "**SAS Equivalent:** *Objects pane → Model Comparison → Select all models → Prediction cutoff = 0.3*\n", "\n", "Now the big question: **Which model should the hospital actually deploy?**\n", "\n", "In SAS, you used the Model Comparison tool to compare all models side by side. 
Here's the Python equivalent.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "330c18ca", "metadata": {}, "outputs": [], "source": [ "# ─── Side-by-Side Model Comparison Table ──────────────────────\n", "# Keep only the latest result per model, in case an evaluation cell was re-run\n", "comparison = pd.DataFrame(all_results).drop_duplicates(subset='Model', keep='last')\n", "comparison = comparison.set_index('Model')\n", "\n", "# Reorder columns for clarity\n", "display_cols = [\n", " 'Val_AUC', 'Val_KS', 'Val_Misclass', 'Val_Accuracy',\n", " 'Val_FPR', 'Val_FalseNeg',\n", " 'Train_AUC', 'Train_KS', 'Train_Misclass'\n", "]\n", "display_cols = [c for c in display_cols if c in comparison.columns]\n", "\n", "print(\"=\" * 80)\n", "print(\"📊 MODEL COMPARISON DASHBOARD (Prediction Cutoff = 0.3)\")\n", "print(\"=\" * 80)\n", "print()\n", "print(comparison[display_cols].to_string())\n", "print()\n", "\n", "# ── Highlight the winners ─────────────────────────────────────\n", "print(\"\\n🏆 WINNERS BY METRIC:\")\n", "print(f\" Highest Validation AUC: {comparison['Val_AUC'].idxmax()} ({comparison['Val_AUC'].max():.4f})\")\n", "print(f\" Highest Validation KS Statistic: {comparison['Val_KS'].idxmax()} ({comparison['Val_KS'].max():.4f})\")\n", "print(f\" Lowest Validation FPR: {comparison['Val_FPR'].idxmin()} ({comparison['Val_FPR'].min():.4f})\")\n", "print(f\" Fewest False Negatives: {comparison['Val_FalseNeg'].idxmin()} ({comparison['Val_FalseNeg'].min():,})\")\n", "print(f\" Lowest Misclassification: {comparison['Val_Misclass'].idxmin()} ({comparison['Val_Misclass'].min():.4f})\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "42cc4372", "metadata": {}, "outputs": [], "source": [ "# ─── Visual Comparison ────────────────────────────────────────\n", "fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n", "\n", "models = comparison.index.tolist()\n", "colors_bar = ['#3498db', '#27ae60', '#e67e22']\n", "\n", "# 1. 
AUC Comparison\n", "ax = axes[0, 0]\n", "x = range(len(models))\n", "width = 0.35\n", "ax.bar([i - width/2 for i in x], comparison['Train_AUC'], width,\n", " label='Training', color='lightsteelblue', edgecolor='black')\n", "ax.bar([i + width/2 for i in x], comparison['Val_AUC'], width,\n", " label='Validation', color='steelblue', edgecolor='black')\n", "# Draw the reference line before legend() so 'Random' appears in the legend\n", "ax.axhline(y=0.5, color='red', linestyle=':', alpha=0.5, label='Random')\n", "ax.set_xticks(x)\n", "ax.set_xticklabels(models, rotation=15)\n", "ax.set_title('AUC (Higher = Better)', fontweight='bold')\n", "ax.set_ylim(0.45, 1.0) # Start just below 0.5 so the 'Random' line stays visible\n", "ax.legend()\n", "\n", "# 2. KS Statistic\n", "ax = axes[0, 1]\n", "ax.bar([i - width/2 for i in x], comparison['Train_KS'], width,\n", " label='Training', color='#f9e79f', edgecolor='black')\n", "ax.bar([i + width/2 for i in x], comparison['Val_KS'], width,\n", " label='Validation', color='#f39c12', edgecolor='black')\n", "ax.set_xticks(x)\n", "ax.set_xticklabels(models, rotation=15)\n", "ax.set_title('KS Statistic (Higher = Better)', fontweight='bold')\n", "ax.legend()\n", "\n", "# 3. Misclassification Rate\n", "ax = axes[1, 0]\n", "ax.bar([i - width/2 for i in x], comparison['Train_Misclass'], width,\n", " label='Training', color='#fadbd8', edgecolor='black')\n", "ax.bar([i + width/2 for i in x], comparison['Val_Misclass'], width,\n", " label='Validation', color='#e74c3c', edgecolor='black')\n", "ax.set_xticks(x)\n", "ax.set_xticklabels(models, rotation=15)\n", "ax.set_title('Misclassification Rate (Lower = Better)', fontweight='bold')\n", "ax.legend()\n", "\n", "# 4. 
False Negatives (Validation only)\n", "ax = axes[1, 1]\n", "bars = ax.bar(models, comparison['Val_FalseNeg'], color=colors_bar, edgecolor='black', alpha=0.85)\n", "ax.set_title('False Negatives — Validation\\n(Lower = Fewer Missed Failures)', fontweight='bold')\n", "ax.set_ylabel('Count')\n", "for bar, val in zip(bars, comparison['Val_FalseNeg']):\n", " # Place each label 3 points above its bar, which works at any count scale\n", " ax.annotate(f'{val:,}', (bar.get_x() + bar.get_width()/2, bar.get_height()),\n", " xytext=(0, 3), textcoords='offset points', ha='center', fontweight='bold')\n", "\n", "plt.suptitle('Model Comparison Dashboard', fontsize=16, fontweight='bold', y=1.01)\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "89e24fd6", "metadata": {}, "outputs": [], "source": [ "# ─── Overlaid ROC Curves ──────────────────────────────────────\n", "plt.figure(figsize=(8, 8))\n", "\n", "model_data = [\n", " (dt_model, X_val, 'Decision Tree', '#3498db', '-'),\n", " (rf_model, X_val, 'Random Forest', '#27ae60', '--'),\n", " (lr_model, X_val_lr_scaled, 'Logistic Regression', '#e67e22', '-.'),\n", "]\n", "\n", "for model, X, name, color, ls in model_data:\n", " y_prob = model.predict_proba(X)[:, 1]\n", " fpr, tpr, _ = roc_curve(y_val, y_prob)\n", " auc = roc_auc_score(y_val, y_prob)\n", " plt.plot(fpr, tpr, color=color, linestyle=ls, linewidth=2.5,\n", " label=f'{name} (AUC = {auc:.4f})')\n", "\n", "plt.plot([0, 1], [0, 1], 'k:', alpha=0.5, linewidth=1, label='Random (AUC = 0.5)')\n", "plt.xlabel('False Positive Rate', fontsize=12)\n", "plt.ylabel('True Positive Rate', fontsize=12)\n", "plt.title('ROC Curves — All Models Compared', fontsize=14, fontweight='bold')\n", "plt.legend(loc='lower right', fontsize=11)\n", "plt.xlim([0, 1])\n", "plt.ylim([0, 1])\n", "plt.grid(True, alpha=0.3)\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "e84525ae", "metadata": {}, "source": [ "---\n", "## 🧠 Interpretation & Business Recommendations\n", "\n", "### Comparing SAS and Python Results\n", "\n", "The models you built 
in SAS Visual Analytics and the ones here in Python follow the **same overall workflow**, although exact numbers will differ because the two tools use different algorithm defaults:\n", "\n", "| Concept | SAS Viya | Python (scikit-learn) |\n", "|---------|----------|----------------------|\n", "| Decision Tree | Objects → Decision Tree | `DecisionTreeClassifier()` |\n", "| Random Forest | Objects → Forest | `RandomForestClassifier()` |\n", "| Logistic Regression | Objects → Logistic Regression | `LogisticRegression()` |\n", "| Train/Validation Split | Partition Variable (70/30) | `train_test_split(test_size=0.30)` |\n", "| Missing Value Handling | Calculated Item (IF-ELSE) | `fillna(0)` or `SimpleImputer` |\n", "| Model Comparison | Model Comparison Object | Custom comparison DataFrame |\n", "| Prediction Cutoff | Adjust cutoff to 0.3 | `threshold=0.3` in evaluation |\n", "\n", "### Key Takeaways\n", "\n", "**1. Plaque Prolapse should rank among the top predictors across all models**, consistent with the SAS results. Check the importance charts above to confirm this on your run.\n", "\n", "**2. In healthcare, false negatives are the most dangerous mistake.** A false negative means a patient whose stent WILL fail gets sent home without extra monitoring. That's why we use a 0.3 cutoff instead of 0.5: a lower cutoff flags more patients, accepting more false alarms in exchange for fewer missed failures.\n", "\n", "**3. No single model is \"best\" at everything.** The choice depends on what the hospital prioritizes:\n", "- **Fewest missed failures (false negatives)?** → Pick the model with the lowest FN count\n", "- **Best overall discrimination (AUC)?** → Often the ensemble model (Random Forest)\n", "- **Most interpretable for doctors?** → Logistic Regression (clear coefficients) or Decision Tree (visual rules)\n", "\n", "### The Bigger Picture\n", "\n", "This exercise demonstrates a critical skill: **platform independence.** Whether you use SAS, Python, R, or any other tool, the analytical thinking is the same:\n", "1. Understand the business problem\n", "2. Explore and clean the data\n", "3. Build and evaluate models\n", "4. Compare and select the best approach\n", "5. 
Communicate results to stakeholders\n", "\n", "The tool is just the vehicle. **Your analytical judgment is what drives the outcome.**\n" ] }, { "cell_type": "markdown", "id": "e47ffb84", "metadata": {}, "source": [ "---\n", "## 📝 Discussion Questions\n", "\n", "Use these to reflect on your work or discuss with your team:\n", "\n", "1. **Which model would you recommend to the hospital, and why?** Consider not just accuracy but also interpretability — can a doctor understand why the model flagged a patient?\n", "\n", "2. **Why did we use a 0.3 prediction cutoff instead of 0.5?** What would happen if we used 0.5 instead? What would happen with 0.1?\n", "\n", "3. **How did your SAS results compare to the Python results?** Were the same variables important? Were the metrics similar?\n", "\n", "4. **What are the ethical implications of using ML to predict medical outcomes?** What happens when the model is wrong? Who is responsible?\n", "\n", "5. **If you had to explain this model to a cardiologist with no data science background, how would you describe what it does?**\n" ] }, { "cell_type": "markdown", "id": "ee6cd2a3", "metadata": {}, "source": [ "---\n", "*Notebook created as a Python companion to the SAS Machine Learning in Healthcare Case Study.* \n", "*Original SAS case study by Haidar Altaie, University of Kent.* \n", "*Python adaptation for FINC 332: Data Analytics, Data Mining, and Data Visualization & GMBA 621: Predictive Analytics & Data Mining.*\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4563abe2-2841-4887-a90c-3ffaa7e7dcc2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, 
"nbformat_minor": 5 }