#!/usr/bin/env python3
"""
Python runner for the MATLAB/Octave reproduction scaffold of:
Charlton et al. / Pimentel et al. respiratory-rate estimation from PPG.

The companion MATLAB-compatible script is `reproduce_ppg_rr_matlab.m`.
This Python script implements the same pipeline so that the numerical results can be generated
and verified in environments where MATLAB/Octave is not installed.

Dataset: BIDMC PPG and Respiration Dataset, PhysioNet.
Inputs expected under ./data:
  bidmc_##_Signals.csv, bidmc_##_Breaths.csv
"""
import csv, json, math, os, statistics
from collections import deque
from pathlib import Path

# Filesystem layout: inputs are read from DATA, outputs are written to ROOT.
ROOT = Path(__file__).resolve().parent
DATA = ROOT.joinpath('data')

# Sampling rate (Hz) used to convert between sample indices and seconds.
FS = 125.0

# Sliding analysis window: WIN_SEC long, advanced by STEP_SEC (50% overlap).
WIN_SEC, STEP_SEC = 64.0, 32.0

# Respiratory search band in breaths/min.
# NOTE(review): not referenced anywhere in this script; dft_peak_bpm hard-codes
# the same 4-60 bpm values as its defaults.
RESP_BAND_BPM = (4.0, 60.0)

# ---------------- basic DSP without external dependencies ----------------

def moving_average(x, n):
    """Return the trailing (causal) moving average of ``x`` over ``n`` samples.

    Output sample ``i`` is the mean of ``x[max(0, i - n + 1) : i + 1]``. The
    first ``n - 1`` outputs therefore average over a shorter, still-filling
    window, and the output has the same length as the input.

    Args:
        x: sequence of numbers.
        n: window length in samples; ``n <= 1`` returns an unchanged copy.

    Returns:
        A new list of the same length as ``x``.
    """
    if n <= 1:
        return list(x)
    out = []
    total = 0.0
    window = deque()
    for v in x:
        window.append(v)
        total += v
        if len(window) > n:
            # deque.popleft() is O(1); the former list.pop(0) shifted the whole
            # window on every sample, making this O(len(x) * n).
            total -= window.popleft()
        out.append(total / len(window))
    return out

def detrend_ma(x, fs, sec=2.0):
    """Remove a slowly varying baseline from ``x`` by moving-average subtraction.

    The baseline is ``moving_average`` over ``int(fs * sec)`` samples, with a
    floor of 3 samples. Returns a new list holding ``x[i] - baseline[i]``.
    """
    width = int(fs * sec)
    if width < 3:
        width = 3
    baseline = moving_average(x, width)
    residual = []
    for sample, level in zip(x, baseline):
        residual.append(sample - level)
    return residual

def zscore(x):
    """Standardise ``x`` to zero mean and unit population standard deviation.

    A constant input (standard deviation exactly 0) is only mean-centred. An
    empty input returns an empty list instead of raising
    ``statistics.StatisticsError``.
    """
    if not x:
        return []
    mu = statistics.mean(x)
    # Flat data has pstdev 0; divide by 1 so it maps to zeros rather than NaN/inf.
    sd = statistics.pstdev(x) or 1.0
    return [(v - mu) / sd for v in x]

def resample_linear(times, values, fs_new, start, stop):
    """Linearly interpolate an irregularly sampled series onto a uniform grid.

    Args:
        times: sample times, assumed sorted ascending (not checked here).
        values: sample values, same length as ``times``.
        fs_new: grid rate in samples per unit of ``times``.
        start, stop: grid span. The grid is ``start + i / fs_new`` and includes
            ``stop`` when the span is a whole number of grid steps.

    Returns:
        ``(grid, out)`` lists. Grid points before ``times[0]`` take
        ``values[0]``, and points after ``times[-1]`` take ``values[-1]``
        (constant hold at both ends). Returns ``([], [])`` for fewer than two
        samples or when ``stop < start``.
    """
    if len(times) < 2:
        return [], []
    # The epsilon stops float error (e.g. 0.29 * 100 == 28.999...) from
    # dropping the final grid point; max() keeps a reversed span empty.
    n = max(0, int(math.floor((stop - start) * fs_new + 1e-9)) + 1)
    grid = [start + i / fs_new for i in range(n)]
    out = []
    last = len(times) - 1
    j = 0
    for t in grid:
        if t <= times[0]:
            # Hold the first value instead of extrapolating the first segment
            # backwards, mirroring the hold-last behaviour at the right edge.
            out.append(values[0])
            continue
        while j < last and times[j + 1] < t:
            j += 1
        if j >= last:
            out.append(values[-1])
            continue
        t0, t1 = times[j], times[j + 1]
        v0, v1 = values[j], values[j + 1]
        if t1 == t0:
            out.append(v0)
        else:
            out.append(v0 + (t - t0) / (t1 - t0) * (v1 - v0))
    return grid, out

def dft_peak_bpm(x, fs, lo_bpm=4.0, hi_bpm=60.0):
    """Find the dominant frequency of ``x`` inside a rate band, in cycles/min.

    The mean-removed signal is multiplied by a symmetric Hann window. A direct
    DFT is then evaluated only at bins ``k`` whose frequency ``k * fs / n`` lies
    in ``[lo_bpm, hi_bpm]`` and at or below Nyquist.

    Args:
        x: uniformly sampled signal.
        fs: sampling rate of ``x`` in Hz.
        lo_bpm, hi_bpm: search band in cycles per minute.

    Returns:
        ``(rate_bpm, quality)``, where ``quality`` is the peak bin power divided
        by the median in-band power. Returns ``(nan, 0.0)`` in three cases:
        ``x`` has fewer than 8 samples, the band contains no DFT bin, or there
        is no in-band power at all (e.g. a constant input).
    """
    n = len(x)
    if n < 8:
        return float('nan'), 0.0
    mu = statistics.mean(x)
    sig = [(v - mu) * (0.5 - 0.5 * math.cos(2 * math.pi * i / (n - 1))) for i, v in enumerate(x)]
    lo_hz, hi_hz = lo_bpm / 60.0, hi_bpm / 60.0
    k0 = max(1, int(math.ceil(lo_hz * n / fs)))
    k1 = min(n // 2, int(math.floor(hi_hz * n / fs)))
    best_k, best_p = None, -1.0
    powers = []
    for k in range(k0, k1 + 1):
        re = im = 0.0
        for i, v in enumerate(sig):
            ang = -2 * math.pi * k * i / n
            re += v * math.cos(ang)
            im += v * math.sin(ang)
        p = re * re + im * im
        powers.append(p)
        if p > best_p:
            best_p, best_k = p, k
    # Without the power check, a flat input (every bin 0) "peaks" at the first
    # bin and reports the lower band edge as if it were a measured rate.
    if best_k is None or best_p <= 0.0:
        return float('nan'), 0.0
    # Crude spectral quality: peak power relative to the median band power.
    q = best_p / (statistics.median(powers) + 1e-12)
    return best_k * fs / n * 60.0, q

# ---------------- signal feature extraction ----------------

def read_signal_file(path):
    """Load time, PPG and respiration columns from a ``*_Signals.csv`` file.

    The first header column is used as time, ``PLETH`` as PPG and ``RESP`` as
    respiration. Leading spaces after delimiters are ignored
    (``skipinitialspace=True``), so headers such as ``" PLETH"`` still match.

    Returns:
        ``(t, ppg, resp)``: three lists of floats, one entry per data row.
    """
    t, ppg, resp = [], [], []
    with open(path, newline='') as handle:
        reader = csv.DictReader(handle, skipinitialspace=True)
        time_key = reader.fieldnames[0]
        for record in reader:
            t.append(float(record[time_key]))
            ppg.append(float(record['PLETH']))
            resp.append(float(record['RESP']))
    return t, ppg, resp

def read_breaths(path):
    """Load breath annotation times, in seconds, from a ``*_Breaths.csv`` file.

    The header row is skipped. For every remaining row, the non-blank cells
    among the first two columns are parsed as sample indices and converted to
    seconds by dividing by ``FS``. Those values are averaged to give one time
    per row; rows with no usable cell are dropped.

    NOTE(review): averaging column 1 and column 2 row by row assumes both
    columns list the *same* breath on each row. If they are independent
    annotators' lists that can drift out of step (e.g. one marks an extra
    breath), later rows would pair different breaths. Confirm against the
    dataset documentation.
    """
    times = []
    with open(path, newline='') as handle:
        reader = csv.reader(handle)
        next(reader, None)  # header row
        for row in reader:
            cells = (cell.strip() for cell in row[:2])
            secs = [float(cell) / FS for cell in cells if cell]
            if not secs:
                continue
            times.append(sum(secs) / len(secs))
    return times

def find_peaks(x, fs):
    """Locate pulse peaks in ``x`` sampled at ``fs`` Hz.

    Steps:

    1. Smooth with a moving average of ``0.08 * fs`` samples.
    2. Take local maxima of the smoothed trace (strict rise, non-strict fall).
    3. Keep the maxima at or above ``mean + 0.10 * pstdev`` of the smoothed trace.
    4. Enforce a minimum spacing of ``0.30 * fs`` samples; when two candidates
       are closer than that, keep the taller one.

    Returns:
        Ascending list of sample indices. They are maxima of the *smoothed*
        trace, not necessarily of ``x`` itself.
    """
    smooth = moving_average(x, max(1, int(0.08 * fs)))
    gap = int(0.30 * fs)  # 0.30 s between beats, i.e. at most 200 per minute
    maxima = [
        i for i in range(1, len(smooth) - 1)
        if smooth[i - 1] < smooth[i] >= smooth[i + 1]
    ]
    if len(maxima) == 0:
        return []
    level = statistics.mean(smooth) + 0.10 * statistics.pstdev(smooth)
    picked = []
    for i in maxima:
        if smooth[i] < level:
            continue
        if picked and i - picked[-1] < gap:
            # Too close to the last kept peak: keep whichever is taller.
            if smooth[i] > smooth[picked[-1]]:
                picked[-1] = i
        else:
            picked.append(i)
    return picked

def ppg_resp_features(t, ppg):
    """Extract respiratory-modulation feature streams from a PPG segment.

    Args:
        t: sample times of ``ppg``. Currently unused: every returned time is
            ``sample_index / FS``, counted from the first sample of ``ppg``,
            i.e. relative to the start of the segment rather than absolute.
        ppg: PPG samples at ``FS``.

    Returns:
        Dict mapping a stream name to ``(times_s, values)``. A stream is
        present only when enough beats were found:

        * ``'BW_trough_baseline'``: value of each inter-beat trough (baseline
          wander); requires more than 5 troughs.
        * ``'AM_pulse_amplitude'``: each peak minus the latest trough before it
          (amplitude modulation); requires more than 5 such pairs.
        * ``'FM_inter_pulse_interval'``: seconds between consecutive peaks,
          stamped at the later peak (frequency modulation); requires more than
          6 peaks.

        BW and AM values are in units of the z-scored, detrended signal.
    """
    x = zscore(detrend_ma(ppg, FS, 1.5))
    peaks = find_peaks(x, FS)
    # One trough per pair of consecutive peaks: the lowest sample between them.
    # Pairs less than 3 samples apart are skipped.
    troughs = []
    for a, b in zip(peaks[:-1], peaks[1:]):
        if b > a + 2:
            troughs.append(min(range(a, b), key=lambda idx: x[idx]))
    streams = {}
    if len(troughs) > 5:
        times_bw = [idx / FS for idx in troughs]
        vals_bw = [x[idx] for idx in troughs]
        streams['BW_trough_baseline'] = (times_bw, vals_bw)
    amp_times, amp_vals = [], []
    n_before = 0  # number of troughs strictly earlier than the current peak
    for pk in peaks:
        # Peaks and troughs are both ascending, so one forward-moving cursor
        # finds the latest earlier trough. Rescanning the whole trough list for
        # every peak made this O(peaks * troughs).
        while n_before < len(troughs) and troughs[n_before] < pk:
            n_before += 1
        if n_before:
            tr = troughs[n_before - 1]
            amp_times.append(pk / FS)
            amp_vals.append(x[pk] - x[tr])
    if len(amp_times) > 5:
        streams['AM_pulse_amplitude'] = (amp_times, amp_vals)
    if len(peaks) > 6:
        times_fm = [pk / FS for pk in peaks[1:]]
        vals_fm = [(b - a) / FS for a, b in zip(peaks[:-1], peaks[1:])]
        streams['FM_inter_pulse_interval'] = (times_fm, vals_fm)
    return streams

def reference_rr(breaths, start, stop):
    """Reference respiratory rate (breaths/min) from annotated breath times.

    Considers only the breaths in ``[start, stop]`` (seconds, inclusive) and
    divides the number of breath-to-breath intervals by the time they span.

    Returns:
        The rate, or NaN when fewer than two breaths fall in the window or when
        the first and last of them do not span positive time (e.g. duplicate
        timestamps).
    """
    inside = [b for b in breaths if start <= b <= stop]
    if len(inside) < 2:
        return float('nan')
    span = inside[-1] - inside[0]
    if span <= 0:
        # Guards a ZeroDivisionError (identical timestamps) and a negative rate.
        return float('nan')
    return 60.0 * (len(inside) - 1) / span

def estimate_window_rr(t, ppg, resp, breaths, start, stop):
    """Estimate respiratory rate (breaths/min) for one analysis window.

    The window covers samples ``int(start * FS)`` up to ``int(stop * FS)``
    (exclusive) of the record; ``start`` and ``stop`` are seconds from the
    first sample. For that window it computes:

    * ``ref_rr``: reference rate from the annotated ``breaths``.
    * ``ip_rr``: spectral peak (4-60 bpm) of the detrended ``resp`` channel.
    * ``ppg_raw_bw_rr``: spectral peak (4-30 bpm) of the detrended PPG.
    * ``ppg_<stream>``: spectral peak of each ``ppg_resp_features`` stream,
      after linear resampling to 5 Hz.
    * ``ppg_rr_fused``: the mean of the stream estimates when there are at
      least two and their population SD is <= 4 bpm. Otherwise, the estimate
      with the highest spectral quality, or NaN when no stream gave a rate.

    Returns:
        Those values plus ``start_s``/``stop_s`` in a dict. Returns ``None``
        when the window holds fewer than 20 s of samples or has no finite
        reference rate.
    """
    i0 = max(0, int(start * FS))
    i1 = min(len(ppg), int(stop * FS))
    if i1 - i0 < int(20 * FS):
        return None
    ppg_w = ppg[i0:i1]
    resp_w = resp[i0:i1]
    ref = reference_rr(breaths, start, stop)
    if not math.isfinite(ref):
        return None
    # Clinical-standard impedance channel, for context.
    resp_clean = zscore(detrend_ma(resp_w, FS, 8.0))
    ip_rr, _ = dft_peak_bpm(resp_clean, FS)
    # Filter-based PPG baseline-wander method: estimate RR directly from the slow
    # component of PPG. A 4-30 bpm search band is used here as a pragmatic adult
    # ICU setting to suppress cardiac leakage; the paper's toolbox evaluates many
    # variants, commonly within a broader 4-60 bpm physiological band.
    ppg_raw_bw = zscore(detrend_ma(ppg_w, FS, 8.0))
    ppg_raw_bw_rr, _ = dft_peak_bpm(ppg_raw_bw, FS, 4.0, 30.0)
    estimates = {}
    qualities = {}
    # ppg_resp_features() stamps its streams as sample_index / FS, counted from
    # the first sample of ppg_w, so the 5 Hz grid must cover the window-relative
    # span [0, stop - start]. Resampling over the absolute [start, stop] only
    # lined up for the window at t = 0. Every later window resampled at most the
    # tail of each stream and filled the rest with a flat hold of its last value.
    span = stop - start
    for name, (ft, fv) in ppg_resp_features(t[i0:i1], ppg_w).items():
        if len(ft) < 4:
            continue
        _, vals = resample_linear(ft, fv, 5.0, 0.0, span)
        vals = zscore(detrend_ma(vals, 5.0, 8.0))
        rr, q = dft_peak_bpm(vals, 5.0)
        if math.isfinite(rr):
            estimates[name] = rr
            qualities[name] = q
    if not estimates:
        fused = float('nan')
    else:
        # Smart-fusion style: if modulation estimates agree, average; otherwise choose best spectral quality.
        vals = list(estimates.values())
        if len(vals) >= 2 and statistics.pstdev(vals) <= 4.0:
            fused = statistics.mean(vals)
        else:
            fused = estimates[max(qualities, key=qualities.get)]
    return {
        'start_s': start, 'stop_s': stop, 'ref_rr': ref, 'ip_rr': ip_rr, 'ppg_raw_bw_rr': ppg_raw_bw_rr, 'ppg_rr_fused': fused,
        **{f'ppg_{k}': v for k, v in estimates.items()}
    }

def loa(errors):
    """Return ``(bias, lower, upper)`` 95% limits of agreement for ``errors``.

    ``bias`` is the mean error. The limits are ``bias -/+ 1.96`` sample standard
    deviations; a single error yields zero-width limits.
    """
    centre = statistics.mean(errors)
    spread = 0.0
    if len(errors) > 1:
        spread = statistics.stdev(errors)
    half_width = 1.96 * spread
    return centre, centre - half_width, centre + half_width

def mae(errors):
    """Mean absolute value of ``errors``, computed with ``statistics.mean``."""
    magnitudes = map(abs, errors)
    return statistics.mean(magnitudes)

def main(subjects=range(1, 11)):
    """Run the pipeline over BIDMC records and write the result files.

    For each id in ``subjects`` whose ``bidmc_##_Signals.csv`` and
    ``bidmc_##_Breaths.csv`` both exist under ``DATA``, the function:

    1. Slides a ``WIN_SEC`` window in ``STEP_SEC`` steps over the record.
    2. Estimates rates in each window.
    3. Writes ``window_results.csv``, ``summary.json`` and
       ``ppg_rr_results.svg`` under ``ROOT``.
    4. Prints the summary as JSON.

    Args:
        subjects: iterable of integer record ids (default 1-10, as before).
    """
    rows = []
    n_subjects = 0
    for subj in subjects:
        sig = DATA / f'bidmc_{subj:02d}_Signals.csv'
        br = DATA / f'bidmc_{subj:02d}_Breaths.csv'
        if not sig.exists() or not br.exists():
            continue
        t, ppg, resp = read_signal_file(sig)
        if not t:
            continue  # header-only file: there is no duration to window
        breaths = read_breaths(br)
        n_subjects += 1
        duration = t[-1]
        start = 0.0
        while start + WIN_SEC <= duration:
            r = estimate_window_rr(t, ppg, resp, breaths, start, start + WIN_SEC)
            if r:
                r['subject'] = subj
                rows.append(r)
            start += STEP_SEC
    out_csv = ROOT / 'window_results.csv'
    keys = sorted({k for r in rows for k in r})
    with open(out_csv, 'w', newline='') as f:
        wr = csv.DictWriter(f, fieldnames=keys)
        wr.writeheader()
        wr.writerows(rows)
    summary = {
        'paper': 'Pimentel et al. / Charlton et al. RR estimation from PPG, BIDMC PhysioNet reproduction scaffold',
        # Report the records actually loaded rather than assuming all 10 exist.
        'subjects': n_subjects,
        'windows': len(rows),
        'methods': {},
    }
    for method in ['ppg_raw_bw_rr', 'ppg_rr_fused', 'ip_rr', 'ppg_BW_trough_baseline', 'ppg_AM_pulse_amplitude', 'ppg_FM_inter_pulse_interval']:
        errs = [r[method] - r['ref_rr'] for r in rows if method in r and math.isfinite(r[method]) and math.isfinite(r['ref_rr'])]
        if errs:
            bias, lo, hi = loa(errs)
            summary['methods'][method] = {'n': len(errs), 'bias_bpm': bias, 'mae_bpm': mae(errs), 'loa95_low_bpm': lo, 'loa95_high_bpm': hi}
    with open(ROOT / 'summary.json', 'w') as f:
        json.dump(summary, f, indent=2)
    write_svg(summary)
    print(json.dumps(summary, indent=2))

def write_svg(summary):
    """Draw a bar chart of per-method MAE to ``ROOT / 'ppg_rr_results.svg'``.

    Args:
        summary: dict shaped like the one ``main`` builds.
            ``summary['methods']`` maps method keys to dicts with ``mae_bpm``,
            ``bias_bpm``, ``loa95_low_bpm``, ``loa95_high_bpm`` and ``n``.
            ``summary['subjects']``, if present, is shown in the subtitle.
    """
    order = ['ppg_raw_bw_rr', 'ppg_rr_fused', 'ppg_BW_trough_baseline', 'ppg_AM_pulse_amplitude', 'ppg_FM_inter_pulse_interval', 'ip_rr']
    methods = [m for m in order if m in summary['methods']]
    labels = {'ppg_raw_bw_rr': 'PPG raw BW', 'ppg_rr_fused': 'PPG fused', 'ppg_BW_trough_baseline': 'PPG feat BW', 'ppg_AM_pulse_amplitude': 'PPG AM', 'ppg_FM_inter_pulse_interval': 'PPG FM', 'ip_rr': 'IP resp'}
    colors = {'ppg_raw_bw_rr': '#f97316', 'ppg_rr_fused': '#fb923c', 'ppg_BW_trough_baseline': '#60a5fa', 'ppg_AM_pulse_amplitude': '#38bdf8', 'ppg_FM_inter_pulse_interval': '#a78bfa', 'ip_rr': '#22c55e'}
    W, H = 980, 520
    x0, y0, ch = 80, 135, 250
    x_right = W - 40
    # Share the available plot width between the bars actually drawn. The old
    # fixed 170 px pitch put the sixth bar at x=930..1040, off the 980 px canvas.
    pitch = (x_right - x0) / max(1, len(methods))
    bw = min(110.0, 0.75 * pitch)
    vals = [summary['methods'][m]['mae_bpm'] for m in methods]
    # `or 1.0` also covers an all-zero MAE set, which would divide by zero below.
    maxv = max(vals, default=0.0) * 1.2 or 1.0
    n_subjects = summary.get('subjects', '?')
    lines = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{W}" height="{H}" viewBox="0 0 {W} {H}">',
        '<rect width="100%" height="100%" fill="#0b1020"/>',
        '<text x="44" y="48" fill="#f8fafc" font-family="Inter,Arial" font-size="24" font-weight="700">PPG respiratory-rate reproduction scaffold</text>',
        f'<text x="44" y="78" fill="#94a3b8" font-family="Arial" font-size="14">BIDMC PhysioNet · {n_subjects} subjects · PPG modulation extraction + spectral RR + smart fusion · lower MAE is better</text>',
    ]
    for i, m in enumerate(methods):
        x = x0 + i * pitch
        cx = x + bw / 2
        val = summary['methods'][m]['mae_bpm']
        bh = val / maxv * ch
        y = y0 + ch - bh
        lines.append(f'<rect x="{x:.1f}" y="{y:.1f}" width="{bw:.1f}" height="{bh:.1f}" rx="10" fill="{colors[m]}"/>')
        lines.append(f'<text x="{cx:.1f}" y="{y - 9:.1f}" text-anchor="middle" fill="#f8fafc" font-family="Arial" font-size="14">{val:.2f} bpm</text>')
        lines.append(f'<text x="{cx:.1f}" y="{y0 + ch + 28}" text-anchor="middle" fill="#cbd5e1" font-family="Arial" font-size="13">{labels[m]}</text>')
    lines.append(f'<line x1="{x0 - 25}" y1="{y0 + ch}" x2="{x_right}" y2="{y0 + ch}" stroke="#334155"/>')
    # Summarise the PPG method that actually has the lowest MAE. Previously the
    # raw-BW method was always labelled "best" regardless of the numbers.
    ppg_methods = [m for m in methods if m.startswith('ppg_')]
    if ppg_methods:
        best = min(ppg_methods, key=lambda m: summary['methods'][m]['mae_bpm'])
        s = summary['methods'][best]
        lines.append(
            f'<text x="44" y="465" fill="#e2e8f0" font-family="Arial" font-size="15">'
            f'Lowest-MAE PPG method ({labels[best]}): bias {s["bias_bpm"]:.2f} bpm, '
            f'95% LoA [{s["loa95_low_bpm"]:.2f}, {s["loa95_high_bpm"]:.2f}], n={s["n"]}</text>'
        )
    lines.append('</svg>')
    # Write explicit UTF-8. The subtitle contains a non-ASCII '·', and an SVG
    # without an encoding declaration is parsed as UTF-8 whatever the platform
    # default encoding is.
    with open(ROOT / 'ppg_rr_results.svg', 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))

if __name__ == '__main__':
    # Run the full pipeline only when executed as a script, not on import.
    main()
