#!/usr/bin/env python3
"""
Open-data proxy reproduction for the hypothesis:
PPG dicrotic notch amplitude / notch ratio changes around intraoperative noxious stimulation.

Direct datasets from electrical-stimulation nociception papers are not public. This uses VitalDB
intraoperative PPG and treats operation start (opstart) as a proxy noxious/surgical-stimulus event.

Signals: SNUADC/PLETH, downsampled to 100 Hz via vitaldb interval=0.01.
Event windows: pre [-90,-30] s and post [+30,+90] s around opstart.
Features per pulse:
  - PPGA: peak - foot
  - notch_amp: notch - foot, where notch is the first local minimum after the systolic peak
    (or the lowest point of the post-peak search window when no local minimum is found)
  - notch_ratio: notch_amp / PPGA
  - notch_depth_ratio: (peak - notch) / PPGA
"""
import csv, json, math, os, statistics
from pathlib import Path

import numpy as np
from scipy.signal import butter, filtfilt, find_peaks, savgol_filter
from scipy.stats import wilcoxon
import matplotlib.pyplot as plt
import vitaldb

# Directory that holds this script; the clinical table is read from it and
# every output file is written back into it.
ROOT = Path(__file__).resolve().parent
CLIN = ROOT.joinpath('clinical_data.csv')
OUT = ROOT
# Sampling rate (Hz) the waveform is requested at and all sample arithmetic uses.
FS = 100.0
# Name of the VitalDB track that is analysed.
TRACK = 'SNUADC/PLETH'


def load_clinical_cases(limit=40, path=None):
    """Return the eligible cases listed in the clinical-information CSV.

    A row is kept when all of the following hold:

    * ``caseid``, ``opstart`` and ``caseend`` parse as numbers;
    * ``opstart`` is greater than 180 and ``caseend`` is greater than
      ``opstart + 180``, which leaves room for the pre/post windows;
    * ``ane_type`` is empty/missing or contains "general" (case-insensitive).

    Args:
        limit: Maximum number of cases to return; ``None`` returns all of them.
        path: CSV file to read. Defaults to the module-level ``CLIN`` path.

    Returns:
        A list of dicts with keys ``caseid``, ``opstart``, ``caseend`` and
        ``ane_type``, ordered by ascending ``caseid``.
    """
    src = CLIN if path is None else path
    rows = []
    with open(src, newline='', encoding='utf-8-sig') as f:
        for r in csv.DictReader(f):
            try:
                cid = int(r['caseid'])
                opstart = float(r['opstart'])
                caseend = float(r['caseend'])
            except (KeyError, TypeError, ValueError):
                # Missing column, short row (value is None) or non-numeric text.
                continue
            # DictReader yields None for fields missing from a short row.
            ane = r.get('ane_type') or ''
            # Need enough pre/post recording around opstart.
            enough_margin = opstart > 180 and caseend > opstart + 180
            if enough_margin and (not ane or 'general' in ane.lower()):
                rows.append({'caseid': cid, 'opstart': opstart, 'caseend': caseend, 'ane_type': ane})
    # Prefer smaller case IDs because API/cache tends to be faster. Sorting
    # makes that preference hold regardless of the row order in the CSV.
    rows.sort(key=lambda c: c['caseid'])
    return rows[:limit]


def preprocess(x, fs=None):
    """Clean, band-pass filter and normalise a raw PPG segment.

    Processing steps, in order:

    1. Copy the input to a float array, so the caller's data is never modified.
    2. Treat samples below -100 or above 1000 as missing.
    3. Give up when fewer than 80 % of the samples are usable.
    4. Fill missing samples by linear interpolation.
    5. Apply a 3rd-order Butterworth band-pass (0.4-12 Hz) forwards and
       backwards with ``filtfilt``.
    6. Flip the sign when the negative excursions are the larger ones.
    7. Subtract the median and divide by the standard deviation.

    Args:
        x: 1-D sequence of raw samples.
        fs: Sampling rate in Hz. Defaults to the module-level ``FS``.

    Returns:
        A float ndarray of the same length as ``x``, or ``None`` when the
        segment is too short to filter or has too few usable samples.
    """
    fs = FS if fs is None else float(fs)
    # np.array (unlike np.asarray) always copies; callers commonly pass a view
    # into the full recording, which must not be overwritten below.
    x = np.array(x, dtype=float)
    # Bandpass for pulse morphology. Keep enough harmonic content for notch.
    b, a = butter(3, [0.4 / (fs / 2), 12 / (fs / 2)], btype='band')
    # filtfilt pads each edge with 3 * max(len(a), len(b)) samples by default
    # and raises unless the input is longer than that.
    padlen = 3 * max(len(a), len(b))
    if x.size <= padlen:
        return None
    # Remove pathological sentinel values and interpolate NaNs.
    x[(x < -100) | (x > 1000)] = np.nan
    good = np.isfinite(x)
    if np.mean(good) < 0.8:
        return None
    idx = np.arange(len(x))
    x[~good] = np.interp(idx[~good], idx[good], x[good])
    y = filtfilt(b, a, x)
    # Normalize sign: PPG pulse should have sharp upward systolic peaks.
    # If negative peaks dominate, invert.
    if np.percentile(y, 95) < abs(np.percentile(y, 5)):
        y = -y
    y = (y - np.median(y)) / (np.std(y) + 1e-9)
    return y


def pulse_features(seg):
    """Extract per-pulse amplitude and dicrotic-notch features from a PPG segment.

    The segment is cleaned with ``preprocess``, smoothed with a short
    Savitzky-Golay filter and split into pulses at consecutive systolic peaks.
    All amplitudes are therefore in the units of the ``preprocess`` output,
    not in raw signal units.

    For each pair of neighbouring peaks whose spacing is 0.4-1.5 s:

    * the foot is the lowest sample before the peak, searched no further back
      than 0.5 s and never beyond the previous systolic peak;
    * the notch is the first local minimum found from 0.08 s after the peak up
      to 55 % of the inter-beat interval (and at least 0.05 s before the next
      peak); when there is no local minimum the lowest sample of that window
      is used instead;
    * pulses with amplitude <= 0.25 or with ratios outside loose plausibility
      bounds are dropped.

    Args:
        seg: 1-D sequence of raw PPG samples at ``FS`` Hz.

    Returns:
        A list with one dict per accepted pulse, holding ``ppga``,
        ``notch_amp``, ``notch_ratio``, ``notch_depth_ratio`` and ``ibi``
        (seconds). The list is empty when the segment is shorter than 10 s or
        is rejected by ``preprocess``.
    """
    min_len = 10 * FS
    # Check the raw length before filtering so that a too-short segment is
    # reported as "no features" however preprocess copes with such input.
    if len(seg) < min_len:
        return []
    y = preprocess(seg)
    if y is None or len(y) < min_len:
        return []
    # Smooth just enough for extrema stability (window length must be odd).
    win = int(0.09 * FS) // 2 * 2 + 1
    ys = savgol_filter(y, win, 3)
    # Systolic peaks: plausible HR 40-150 bpm.
    peaks, _ = find_peaks(ys, distance=int(0.40 * FS), prominence=0.25)
    foot_lookback = int(0.5 * FS)
    notch_delay = int(0.08 * FS)
    next_peak_guard = int(0.05 * FS)
    notch_min_sep = int(0.05 * FS)
    feats = []
    for i in range(len(peaks) - 1):
        p0, p1 = int(peaks[i]), int(peaks[i + 1])
        ibi = (p1 - p0) / FS
        if not (0.4 <= ibi <= 1.5):
            continue
        # Foot/trough before current peak: minimum between the previous peak
        # and this peak, looking back at most 0.5 s. Bounding by the previous
        # peak keeps the search inside the current beat at fast heart rates.
        prev_peak = int(peaks[i - 1]) if i > 0 else 0
        left = max(prev_peak, p0 - foot_lookback)
        foot_idx = left + int(np.argmin(ys[left:p0 + 1]))
        foot = ys[foot_idx]
        peak = ys[p0]
        amp = peak - foot
        if amp <= 0.25:
            continue
        # Dicrotic notch: first prominent local minimum on downslope after peak.
        start = p0 + notch_delay
        end = min(p0 + int(0.55 * ibi * FS), p1 - next_peak_guard)
        if end <= start + 3:
            continue
        ds = ys[start:end]
        mins, _ = find_peaks(-ds, distance=notch_min_sep, prominence=0.03)
        if len(mins):
            ni = start + int(mins[0])
        else:
            # Fallback: minimum in the expected catacrotic region.
            ni = start + int(np.argmin(ds))
        notch = ys[ni]
        notch_amp = notch - foot
        notch_ratio = notch_amp / amp
        notch_depth_ratio = (peak - notch) / amp
        # Exclude implausible notches below foot by too much or above peak.
        if -0.5 <= notch_ratio <= 1.2 and -0.2 <= notch_depth_ratio <= 1.5:
            feats.append({
                'ppga': float(amp),
                'notch_amp': float(notch_amp),
                'notch_ratio': float(notch_ratio),
                'notch_depth_ratio': float(notch_depth_ratio),
                'ibi': float(ibi),
            })
    return feats


def med(feats, key):
    """Return the median of ``key`` over ``feats``, ignoring non-finite values.

    Args:
        feats: Iterable of dicts that each contain ``key``.
        key: Name of the numeric field to summarise.

    Returns:
        The median as a float, or NaN when no finite value is available.
    """
    finite = []
    for item in feats:
        value = item[key]
        if math.isfinite(value):
            finite.append(value)
    if not finite:
        return float('nan')
    return float(np.median(finite))


def analyze_case(case):
    """Compare pulse features before and after operation start for one case.

    The PLETH track is loaded at 100 Hz and two 60 s windows are cut relative
    to ``case['opstart']``: [-90, -30] s (pre) and [+30, +90] s (post). Sample
    indices are computed as ``time * FS``.

    Args:
        case: Dict with at least ``caseid`` and ``opstart``.

    Returns:
        A dict with ``caseid``, ``opstart``, ``n_pre``, ``n_post`` and, for
        every feature, the ``pre_``/``post_`` window medians plus the absolute
        (``delta_``) and relative (``rel_delta_``) change. ``None`` when the
        track is empty, a window falls outside the recording, or either window
        yields fewer than 20 pulses.
    """
    case_id = case['caseid']
    t_event = case['opstart']
    # Load whole PLETH at 100 Hz; VitalDB aligns to case timeline.
    data = vitaldb.load_case(case_id, [TRACK], interval=0.01)
    if data is None or len(data) == 0:
        return None
    signal = data[:, 0]

    def sample_at(offset_s):
        # Seconds relative to opstart -> sample index in the recording.
        return int((t_event + offset_s) * FS)

    pre_lo, pre_hi = sample_at(-90), sample_at(-30)
    post_lo, post_hi = sample_at(30), sample_at(90)
    if pre_lo < 0 or post_hi > len(signal):
        return None

    pre_feats = pulse_features(signal[pre_lo:pre_hi])
    post_feats = pulse_features(signal[post_lo:post_hi])
    if min(len(pre_feats), len(post_feats)) < 20:
        return None

    row = {
        'caseid': case_id,
        'opstart': t_event,
        'n_pre': len(pre_feats),
        'n_post': len(post_feats),
    }
    for key in ('ppga', 'notch_amp', 'notch_ratio', 'notch_depth_ratio', 'ibi'):
        before = med(pre_feats, key)
        after = med(post_feats, key)
        change = after - before
        row[f'pre_{key}'] = before
        row[f'post_{key}'] = after
        row[f'delta_{key}'] = change
        # Small epsilon guards against division by zero.
        row[f'rel_delta_{key}'] = change / (abs(before) + 1e-9)
    return row


def auc_from_scores(labels, scores):
    """Compute the Mann-Whitney AUC of ``scores`` for separating the labels.

    Tied scores share the average of the ranks they span (mid-ranks), so a
    tie between a positive and a negative counts as half a win.

    Args:
        labels: Sequence of 0/1 class labels (1 = post, 0 = pre).
        scores: Sequence of numeric scores, aligned with ``labels``.

    Returns:
        The probability that a label-1 score outranks a label-0 score, as a
        float in [0, 1]. NaN when either class is absent or when any score is
        NaN, because NaN has no defined rank.
    """
    labels = list(labels)
    scores = list(scores)
    n1 = sum(labels)
    n0 = len(labels) - n1
    if n1 == 0 or n0 == 0:
        return float('nan')
    # NaN never compares equal to itself, so it can be neither sorted
    # reliably nor grouped with its ties; report the AUC as undefined.
    if any(math.isnan(s) for s in scores):
        return float('nan')

    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0])
    total = len(pairs)
    rank_vals = [0.0] * total
    i = 0
    while i < total:
        # [i, j) is the run of entries tied with pairs[i]; it always contains i.
        j = i + 1
        while j < total and pairs[j][0] == pairs[i][0]:
            j += 1
        # 1-based ranks i+1 .. j average to (i + 1 + j) / 2.
        avg = (i + 1 + j) / 2
        for k in range(i, j):
            rank_vals[k] = avg
        i = j
    sum_r1 = sum(r for r, (_, lab) in zip(rank_vals, pairs) if lab == 1)
    return (sum_r1 - n1 * (n1 + 1) / 2) / (n1 * n0)


def make_plot(rows, summary):
    """Save a box plot of the per-case relative feature changes.

    One box is drawn for each of PPGA, notch amplitude, notch ratio and notch
    depth ratio, using the ``rel_delta_*`` values of every row. The figure is
    written to ``notch_opstart_results.png`` in the ``OUT`` directory.

    Args:
        rows: Per-case result dicts containing the ``rel_delta_*`` keys.
        summary: Not used for drawing; accepted so existing callers keep
            working.
    """
    keys = ['ppga', 'notch_amp', 'notch_ratio', 'notch_depth_ratio']
    labels = ['PPGA', 'Notch amp', 'Notch ratio', 'Notch depth ratio']
    deltas = [[r[f'rel_delta_{k}'] for r in rows] for k in keys]
    fig, ax = plt.subplots(figsize=(10, 5.2), facecolor='#0b1020')
    try:
        ax.set_facecolor('#0b1020')
        bp = ax.boxplot(deltas, patch_artist=True, showfliers=False)
        # Label the boxes through the axis instead of a boxplot keyword: that
        # keyword is `tick_labels` from Matplotlib 3.9 and `labels` before, so
        # either spelling fails on some installed versions. Boxes sit at
        # x = 1..n by default.
        ax.set_xticks(list(range(1, len(labels) + 1)))
        ax.set_xticklabels(labels)
        colors = ['#60a5fa', '#f97316', '#fb923c', '#a78bfa']
        for patch, c in zip(bp['boxes'], colors):
            patch.set_facecolor(c)
            patch.set_alpha(0.8)
        for part in ['whiskers', 'caps', 'medians']:
            for item in bp[part]:
                item.set_color('#e2e8f0')
        ax.axhline(0, color='#94a3b8', lw=1, alpha=0.7)
        ax.set_ylabel('Relative change post-opstart vs pre-opstart')
        ax.tick_params(colors='#e2e8f0')
        ax.yaxis.label.set_color('#e2e8f0')
        ax.set_title('PPG notch features around operation start (VitalDB proxy)', color='#f8fafc', fontsize=15, weight='bold')
        ax.text(0.02, -0.22, f"n={len(rows)} cases · opstart used as proxy noxious event · median pulse features per 60s window", transform=ax.transAxes, color='#94a3b8')
        for spine in ax.spines.values():
            spine.set_color('#334155')
        plt.tight_layout()
        out = OUT / 'notch_opstart_results.png'
        fig.savefig(out, dpi=160)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def _finite_or_none(value):
    """Return ``value`` as a float, or None when it is NaN or infinite."""
    value = float(value)
    return value if math.isfinite(value) else None


def main():
    """Run the pre/post operation-start comparison and write all outputs.

    Up to 30 eligible cases are tried until 12 have been analysed. Results go
    to the ``OUT`` directory:

    * ``case_features.csv`` - one row of window medians and deltas per case;
    * ``summary.json`` - group medians, Wilcoxon signed-rank p-value and a
      direction-invariant AUC per feature, plus any per-case errors;
    * ``notch_opstart_results.png`` - box plot of relative changes.

    Numbers that are not finite (for example the p-value when the Wilcoxon
    test cannot be computed) are stored as JSON ``null``, because bare
    ``NaN`` is not valid JSON.

    Raises:
        SystemExit: If no case could be analysed.
    """
    cases = load_clinical_cases(limit=30)
    rows = []
    errors = []
    for c in cases:
        print(f"case {c['caseid']} opstart={c['opstart']}", flush=True)
        try:
            r = analyze_case(c)
        except Exception as e:
            # Deliberate best-effort boundary: one failing download or signal
            # must not abort the run; the failure is recorded in the summary.
            errors.append({'caseid': c['caseid'], 'error': str(e)})
            print('  error', e, flush=True)
        else:
            if r:
                rows.append(r)
                print('  ok', r['n_pre'], r['n_post'], 'd_notch_ratio', round(r['delta_notch_ratio'], 4), flush=True)
            else:
                print('  skipped', flush=True)
        if len(rows) >= 12:
            break
    if not rows:
        raise SystemExit('no analyzable cases')

    fieldnames = list(rows[0].keys())
    with open(OUT / 'case_features.csv', 'w', newline='', encoding='utf-8') as f:
        wr = csv.DictWriter(f, fieldnames=fieldnames)
        wr.writeheader()
        wr.writerows(rows)

    summary = {'dataset': 'VitalDB v1.0.0 via PhysioNet/API', 'event_proxy': 'operation start (opstart)', 'n_cases': len(rows), 'features': {}, 'errors': errors}
    for key in ['ppga', 'notch_amp', 'notch_ratio', 'notch_depth_ratio', 'ibi']:
        pre = np.array([r[f'pre_{key}'] for r in rows], float)
        post = np.array([r[f'post_{key}'] for r in rows], float)
        delta = post - pre
        try:
            w = wilcoxon(post, pre, zero_method='wilcox', alternative='two-sided')
            p = float(w.pvalue)
        except Exception:
            # The test is undefined for some inputs (e.g. all differences
            # zero); keep the other statistics and report no p-value.
            p = float('nan')
        # Direction-invariant discrimination: if feature decreases, use -feature as post score.
        labels = [0] * len(pre) + [1] * len(post)
        scores = list(pre) + list(post)
        auc = auc_from_scores(labels, scores)
        if auc < 0.5:
            auc_dir = 1 - auc
            direction = 'decrease_post'
        else:
            auc_dir = auc
            direction = 'increase_post'
        summary['features'][key] = {
            'pre_median': _finite_or_none(np.median(pre)),
            'post_median': _finite_or_none(np.median(post)),
            'delta_median': _finite_or_none(np.median(delta)),
            'delta_mean': _finite_or_none(np.mean(delta)),
            'relative_delta_median': _finite_or_none(np.median(delta / (np.abs(pre) + 1e-9))),
            'wilcoxon_p': _finite_or_none(p),
            'directional_auc_pre_vs_post': _finite_or_none(auc_dir),
            'direction': direction,
        }
    with open(OUT / 'summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    make_plot(rows, summary)
    print(json.dumps(summary, indent=2))

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
