#!/usr/bin/env python3
"""
Minimal reproduction for arXiv:2606.26990
"Decision-Aligned Evaluation of Uncertainty Quantification".

Focus: reproduce the paper's central binary-decision claim that standard UQ metrics
(NLL/Brier/ECE) can rank probabilistic predictors differently from the utility of a
downstream decision with an application-specific false-positive cost prior.

No external dependencies; pure Python stdlib.
"""
import csv, json, math, os, random, statistics

# Directory component of this script's path; all output artifacts are joined onto it.
OUT = os.path.split(__file__)[0]
# Lower bound applied to probabilities before taking logarithms.
EPS = 1e-12

# ------------------------- synthetic data and predictors -------------------------

def make_labels(n=4000, seed=0):
    """Draw ``n`` binary labels from a cycle of Beta-distributed risk profiles.

    Sample ``i`` uses profile ``i % 4``: a success probability is drawn from
    that profile's Beta distribution, then the label is 1 with that probability.
    All randomness comes from one ``random.Random(seed)`` stream, so the output
    is reproducible for a given ``seed``.
    """
    rng = random.Random(seed)
    # (alpha, beta) per risk profile, cycled by sample index.
    profiles = (
        (0.8, 5.5),  # mostly low risk
        (2.0, 2.0),  # ambiguous
        (5.5, 0.8),  # mostly high risk
        (1.1, 1.1),  # broad
    )
    out = []
    for idx in range(n):
        alpha, beta = profiles[idx % len(profiles)]
        risk = rng.betavariate(alpha, beta)
        out.append(int(rng.random() < risk))
    return out

def make_predictors(y, seed=1):
    """Simulate three probabilistic predictors for the binary labels ``y``.

    Returns a dict (insertion order A, B, C) mapping a model name to a list of
    positive-class probabilities aligned with ``y``:

    * ``A_calibratedish``: Beta(4.2, 2.3) for positives / Beta(2.3, 4.2) for
      negatives, clipped to [0.01, 0.99]; moderately informative.
    * ``B_decision_oriented``: sharper Beta(6.5, 2.1) / Beta(1.9, 7.2) draws,
      clipped to [0.01, 0.995] / [0.005, 0.99]; about 3.5% of its predictions
      are then mirrored (p -> 1 - p), i.e. confident mistakes, which NLL
      penalises heavily.
    * ``C_underconfident``: 0.5 +/- 0.3 * Beta(2.5, 2.5), so positives fall in
      [0.5, 0.8] and negatives in [0.2, 0.5]. It splits the classes at 0.5 but
      never exceeds 0.8, so at false-positive costs >= 0.8 it never acts.

    All draws come from one ``random.Random(seed)`` stream in a fixed order
    (A, B, C, then the mirroring coin for B, per sample).
    """
    rng = random.Random(seed)

    def clip(value, low, high):
        return min(max(value, low), high)

    col_a, col_b, col_c = [], [], []
    for label in y:
        if not label:
            a_val = clip(rng.betavariate(2.3, 4.2), 0.01, 0.99)
            b_val = clip(rng.betavariate(1.9, 7.2), 0.005, 0.99)
            c_val = 0.50 - 0.30 * rng.betavariate(2.5, 2.5)
        else:
            a_val = clip(rng.betavariate(4.2, 2.3), 0.01, 0.99)
            b_val = clip(rng.betavariate(6.5, 2.1), 0.01, 0.995)
            c_val = 0.50 + 0.30 * rng.betavariate(2.5, 2.5)
        # Rare confident mistakes for B: mirror its probability.
        if rng.random() < 0.035:
            b_val = 1.0 - b_val
        col_a.append(a_val)
        col_b.append(b_val)
        col_c.append(c_val)
    return {
        'A_calibratedish': col_a,
        'B_decision_oriented': col_b,
        'C_underconfident': col_c,
    }

# ------------------------- metrics -------------------------

def nll(p, y):
    """Mean binary negative log-likelihood of probabilities ``p`` against labels ``y``.

    Both ``p`` and ``1 - p`` are floored at ``EPS`` before the logarithm, so a
    fully confident mistake yields a large but finite loss.
    """
    def per_sample(prob, label):
        log_pos = math.log(max(prob, EPS))
        log_neg = math.log(max(1 - prob, EPS))
        return -label * log_pos - (1 - label) * log_neg

    return statistics.mean(per_sample(prob, label) for prob, label in zip(p, y))

def brier(p, y):
    """Brier score: mean squared difference between probabilities ``p`` and labels ``y``."""
    squared_errors = [(prob - label) ** 2 for prob, label in zip(p, y)]
    return statistics.mean(squared_errors)

def accuracy(p, y, threshold=0.5):
    """Fraction of samples whose thresholded prediction matches the label.

    A probability strictly greater than ``threshold`` predicts the positive
    class; labels are compared by truthiness.
    """
    hits = []
    for prob, label in zip(p, y):
        predicted_positive = prob > threshold
        hits.append(1 if predicted_positive == bool(label) else 0)
    return statistics.mean(hits)

def ece(p, y, bins=15):
    """Expected calibration error with ``bins`` equal-width bins on [0, 1].

    For every non-empty bin, the absolute gap between the mean predicted
    probability and the empirical positive rate is weighted by the bin's share
    of ``len(y)`` samples, and the weighted gaps are summed in bin order.
    A probability of exactly 1.0 is counted in the last bin; values outside
    [0, 1] fall into no bin.
    """
    total = len(y)
    bounds = [(k / bins, (k + 1) / bins) for k in range(bins)]
    members = [[] for _ in range(bins)]
    # Single pass over the predictions, recording each index in every bin it matches.
    for j, prob in enumerate(p):
        for k, (left, right) in enumerate(bounds):
            if left <= prob < right or (k == bins - 1 and prob == 1.0):
                members[k].append(j)
    error = 0.0
    for idx in members:
        if not idx:
            continue
        mean_prob = statistics.mean(p[j] for j in idx)
        pos_rate = statistics.mean(y[j] for j in idx)
        error += len(idx) / total * abs(mean_prob - pos_rate)
    return error

def utility_at_cost(p, y, cost_fp):
    """Mean utility of acting when the predicted probability exceeds ``cost_fp``.

    Decision rule: act (predict positive) iff ``prob > cost_fp``. A false
    positive costs ``cost_fp`` and a false negative costs ``1 - cost_fp``.
    Correct decisions cost nothing. Utility is the negated cost, averaged over
    samples.
    """
    threshold = cost_fp

    def outcome(prob, label):
        acts = prob > threshold
        if acts and label == 0:
            return -threshold  # false positive
        if not acts and label == 1:
            return -(1 - threshold)  # false negative
        return 0.0

    return statistics.mean([outcome(prob, label) for prob, label in zip(p, y)])

def prior_weighted_utility(p, y, costs):
    """Average ``utility_at_cost`` over the false-positive costs in ``costs``.

    Every grid point gets equal weight, so the choice of grid is the prior.
    """
    per_cost = [utility_at_cost(p, y, cost) for cost in costs]
    return statistics.mean(per_cost)

# ------------------------- reproduction experiment -------------------------

def rank_metric(metric_values, higher_is_better=False):
    """Return the keys of ``metric_values`` ordered best-first.

    Smaller values rank first unless ``higher_is_better`` is true. The sort is
    stable, so tied keys keep their insertion order.
    """
    ordering = list(metric_values)
    ordering.sort(key=metric_values.get, reverse=higher_is_better)
    return ordering

def spearman_rank_corr(order1, order2):
    """Spearman rank correlation between two best-first orderings of the same items.

    Each argument lists the same distinct items in ranked order; an item's rank
    is its 1-based position. Strict orderings cannot contain ties, so the
    tie-free formula ``1 - 6 * sum(d**2) / (n * (n**2 - 1))`` is exact.

    Returns:
        A float in [-1, 1], or ``math.nan`` when fewer than two items are
        ranked (the correlation is undefined there: the formula is 0/0).

    Raises:
        ValueError: if ``order1`` contains duplicates, or if the two orderings
            do not contain exactly the same items. Without this check a
            mismatch gave a KeyError or a silently out-of-range value.
    """
    names = list(order1)
    other = list(order2)
    n = len(names)
    if len(set(names)) != n:
        raise ValueError('order1 contains duplicate items')
    # Same length + same item set, with order1 distinct, also rules out duplicates in order2.
    if len(other) != n or set(other) != set(names):
        raise ValueError('order1 and order2 must rank exactly the same items')
    if n < 2:
        return math.nan
    r1 = {name: rank for rank, name in enumerate(names, start=1)}
    r2 = {name: rank for rank, name in enumerate(other, start=1)}
    d2 = sum((r1[name] - r2[name]) ** 2 for name in names)
    return 1 - 6 * d2 / (n * (n * n - 1))

def run_once(seed):
    """Score every synthetic predictor for one seed and compare metric rankings.

    Labels are generated with seed ``1000 + seed`` and predictors with
    ``2000 + seed``. Each model is scored with NLL, Brier, ECE, accuracy at 0.5,
    utility averaged over a 61-point uniform grid of false-positive costs in
    [0.70, 0.85], and utility at the single costs 0.75 and 0.80.

    Returns:
        rows: one dict per model holding ``seed``, ``model`` and every score.
        ranks: metric name -> model names ordered best-first.
        align: metric name -> Spearman correlation between that metric's ranking
            and the prior-weighted-utility ranking (the utility entry is its
            correlation with itself).
    """
    y = make_labels(seed=1000 + seed)
    preds = make_predictors(y, seed=2000 + seed)
    # Application prior over the false-positive cost: a uniform grid on
    # [0.70, 0.85], i.e. concentrated on expensive false positives rather than
    # spread over [0, 1] -- the task prior is what the comparison is about.
    cost_grid = [0.70 + 0.15 * i / 60 for i in range(61)]
    # Ranked metrics and whether larger is better; this order fixes the key
    # order of ``ranks`` and ``align``.
    direction = {
        'NLL': False,
        'Brier': False,
        'ECE': False,
        'Accuracy@0.5': True,
        'PriorWeightedUtility': True,
    }
    by_metric = {metric: {} for metric in direction}
    rows = []
    for model, probs in preds.items():
        scores = {
            'NLL': nll(probs, y),
            'Brier': brier(probs, y),
            'ECE': ece(probs, y),
            'Accuracy@0.5': accuracy(probs, y),
            'PriorWeightedUtility': prior_weighted_utility(probs, y, cost_grid),
            'Utility@0.75': utility_at_cost(probs, y, 0.75),
            'Utility@0.80': utility_at_cost(probs, y, 0.80),
        }
        for metric, per_model in by_metric.items():
            per_model[model] = scores[metric]
        rows.append({'seed': seed, 'model': model, **scores})
    ranks = {
        metric: rank_metric(by_metric[metric], higher_is_better=larger_better)
        for metric, larger_better in direction.items()
    }
    utility_rank = ranks['PriorWeightedUtility']
    align = {metric: spearman_rank_corr(ranks[metric], utility_rank) for metric in direction}
    return rows, ranks, align

def aggregate(records):
    """Summarise per-seed score rows by model.

    Args:
        records: dicts with a ``model`` key plus one entry for each metric in
            ``metric_names`` below.

    Returns:
        ``{model: {'<metric>_mean': ..., '<metric>_sd': ..., ...}}`` with models
        in sorted order. ``_sd`` is the sample standard deviation, or 0.0 when
        a model has only one record.
    """
    metric_names = ('NLL', 'Brier', 'ECE', 'Accuracy@0.5', 'PriorWeightedUtility',
                    'Utility@0.75', 'Utility@0.80')
    grouped = {}
    for rec in records:
        grouped.setdefault(rec['model'], []).append(rec)
    summary = {}
    for model in sorted(grouped):
        group = grouped[model]
        stats = {}
        for metric in metric_names:
            series = [rec[metric] for rec in group]
            stats[f'{metric}_mean'] = statistics.mean(series)
            stats[f'{metric}_sd'] = statistics.stdev(series) if len(series) > 1 else 0.0
        summary[model] = stats
    return summary

def write_svg(summary, alignment):
    """Render the aggregate results as a two-panel SVG bar chart.

    Left panel: mean NLL per model (lower is better), scaled against 8%
    headroom above the largest NLL. Right panel: mean prior-weighted utility
    (higher is better). Utilities are negative, so those bars are drawn relative
    to the worst model, which gets an 18 px stub, rather than from zero. Their
    heights show relative gaps, not absolute utility. A footer reports the mean
    Spearman rank alignment of NLL, Brier and ECE with utility.

    Args:
        summary: per-model dict (as built by ``aggregate``) that must contain
            the three model names below with ``NLL_mean`` and
            ``PriorWeightedUtility_mean`` entries.
        alignment: dict with ``NLL_mean``, ``Brier_mean``, ``ECE_mean`` and
            ``PriorWeightedUtility_mean`` entries.

    Returns:
        Path of ``decision_alignment_results.svg`` written inside ``OUT``.
    """
    models = ['A_calibratedish', 'B_decision_oriented', 'C_underconfident']
    labels = {'A_calibratedish': 'A\ncalib-ish', 'B_decision_oriented': 'B\ndecision', 'C_underconfident': 'C\nunderconf'}
    colors = {'A_calibratedish': '#60a5fa', 'B_decision_oriented': '#f97316', 'C_underconfident': '#94a3b8'}
    W, H = 980, 520
    lines = [f'<svg xmlns="http://www.w3.org/2000/svg" width="{W}" height="{H}" viewBox="0 0 {W} {H}">',
             '<rect width="100%" height="100%" fill="#0b1020"/>',
             '<text x="44" y="46" fill="#f8fafc" font-family="Inter,Arial" font-size="24" font-weight="700">Decision-aligned UQ minimal reproduction</text>',
             '<text x="44" y="76" fill="#94a3b8" font-family="Inter,Arial" font-size="14">arXiv:2606.26990 · standard UQ metrics can disagree with downstream decision utility</text>']
    # Shared panel geometry in px: top of plot area, chart height, baseline width, bar width.
    y0, ch = 130, 245
    panel_w = 390
    bw = 76

    def add_panel(x0, title, values, heights, fmt):
        # Append one panel: its title, then per model a bar, a value label and a
        # two-line model label, then the baseline. ``heights`` are bar heights in px.
        lines.append(f'<text x="{x0}" y="112" fill="#e2e8f0" font-family="Arial" font-size="17">{title}</text>')
        for i, m in enumerate(models):
            x = x0 + 35 + i * 115
            bh = heights[i]
            y = y0 + ch - bh
            name_lines = labels[m].split('\n')
            lines.append(f'<rect x="{x}" y="{y:.1f}" width="{bw}" height="{bh:.1f}" rx="8" fill="{colors[m]}"/>')
            lines.append(f'<text x="{x+bw/2}" y="{y-8:.1f}" fill="#f8fafc" text-anchor="middle" font-family="Arial" font-size="13">{values[i]:{fmt}}</text>')
            lines.append(f'<text x="{x+bw/2}" y="{y0+ch+25}" fill="#cbd5e1" text-anchor="middle" font-family="Arial" font-size="12">{name_lines[0]}</text>')
            lines.append(f'<text x="{x+bw/2}" y="{y0+ch+40}" fill="#cbd5e1" text-anchor="middle" font-family="Arial" font-size="12">{name_lines[1]}</text>')
        lines.append(f'<line x1="{x0}" y1="{y0+ch}" x2="{x0+panel_w}" y2="{y0+ch}" stroke="#334155"/>')

    # Panel 1: NLL, lower is better. The `or 1.0` fallback draws empty bars
    # instead of dividing by zero if every NLL mean is exactly 0.
    nlls = [summary[m]['NLL_mean'] for m in models]
    max_nll = max(nlls) * 1.08 or 1.0
    add_panel(70, 'NLL / generic metric, lower is better', nlls,
              [val / max_nll * ch for val in nlls], '.3f')

    # Panel 2: utility, higher is better. Bars are offset from the worst model
    # (18 px stub) and the best reaches 0.82 * ch + 18 px; zero utility is not shown.
    utils = [summary[m]['PriorWeightedUtility_mean'] for m in models]
    umin, umax = min(utils), max(utils)
    span = (umax - umin) or 1
    add_panel(535, 'Prior-weighted utility, higher is better', utils,
              [(val - umin) / span * ch * 0.82 + 18 for val in utils], '.4f')

    lines.append(f'<text x="44" y="470" fill="#e2e8f0" font-family="Arial" font-size="15">Rank alignment with utility, mean Spearman: NLL {alignment["NLL_mean"]:.2f}, Brier {alignment["Brier_mean"]:.2f}, ECE {alignment["ECE_mean"]:.2f}, prior-weighted utility {alignment["PriorWeightedUtility_mean"]:.2f}</text>')
    lines.append('</svg>')
    path = os.path.join(OUT, 'decision_alignment_results.svg')
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))
    return path

def _write_csv(path, rows):
    """Write a non-empty list of same-keyed dicts to ``path`` as CSV (header from the first row)."""
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


def main(n_seeds=30):
    """Run the reproduction over ``n_seeds`` seeds and write all artifacts to ``OUT``.

    Writes ``metrics_by_seed.csv``, ``rankings_by_seed.csv``, ``summary.json``
    and the SVG chart, then prints the JSON summary and the SVG path.

    Args:
        n_seeds: number of seeds (0 .. n_seeds - 1) passed to ``run_once``. The
            same value is recorded under ``'seeds'`` in the JSON, so the reported
            count always matches what actually ran.

    Raises:
        ValueError: if ``n_seeds`` < 1 (there would be no rows to write).
    """
    if n_seeds < 1:
        raise ValueError(f'n_seeds must be >= 1, got {n_seeds}')
    all_rows = []
    align_rows = []
    rank_rows = []
    for seed in range(n_seeds):
        rows, ranks, align = run_once(seed)
        all_rows.extend(rows)
        align_rows.append({'seed': seed, **align})
        for metric, order in ranks.items():
            rank_rows.append({'seed': seed, 'metric': metric, 'rank_order': ' > '.join(order)})
    summary = aggregate(all_rows)
    # Mean and sample SD across seeds of each metric's Spearman alignment with utility.
    alignment_summary = {}
    for m in ['NLL', 'Brier', 'ECE', 'Accuracy@0.5', 'PriorWeightedUtility']:
        vals = [r[m] for r in align_rows]
        alignment_summary[m + '_mean'] = statistics.mean(vals)
        alignment_summary[m + '_sd'] = statistics.stdev(vals) if len(vals) > 1 else 0.0

    def model_means(metric):
        # Per-model aggregate mean of ``metric``, used to rank models on seed-averaged scores.
        return {model: summary[model][metric + '_mean'] for model in summary}

    result = {
        'paper': 'Decision-Aligned Evaluation of Uncertainty Quantification (arXiv:2606.26990v1)',
        'reproduction_scope': 'binary decision toy reproduction of metric-vs-utility ranking disagreement',
        'decision_prior': 'uniform grid over false-positive cost c in [0.70, 0.85]',
        'seeds': n_seeds,
        'models': summary,
        'aggregate_rankings': {
            'NLL_lower_better': rank_metric(model_means('NLL'), False),
            'Brier_lower_better': rank_metric(model_means('Brier'), False),
            'ECE_lower_better': rank_metric(model_means('ECE'), False),
            'PriorWeightedUtility_higher_better': rank_metric(model_means('PriorWeightedUtility'), True),
        },
        # Holds both the *_mean and *_sd entries despite the key name.
        'rank_alignment_spearman_mean': alignment_summary,
    }
    _write_csv(os.path.join(OUT, 'metrics_by_seed.csv'), all_rows)
    _write_csv(os.path.join(OUT, 'rankings_by_seed.csv'), rank_rows)
    with open(os.path.join(OUT, 'summary.json'), 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)
    svg = write_svg(summary, alignment_summary)
    print(json.dumps(result, indent=2))
    print('wrote', svg)

if __name__ == "__main__":
    # Run the full reproduction only when executed as a script, not on import.
    main()
