#!/usr/bin/env python3
"""
Minimal reproduction of the core claim in:
Foret et al. (2021), "Sharpness-Aware Minimization for Efficiently Improving Generalization".

This is intentionally dependency-free: pure Python stdlib only.
It compares standard mini-batch SGD vs SAM on a noisy two-moons binary classification task.
The goal is not to reproduce the paper's CIFAR/ImageNet tables, but to reproduce the central
phenomenon: optimizing for neighborhoods of low loss can improve held-out generalization.
"""
import csv, json, math, os, random, statistics
from dataclasses import dataclass

# Directory part of this script's path; the CSV, JSON and SVG outputs are joined onto it.
OUT_DIR = os.path.dirname(__file__)

# ----------------------------- data -----------------------------

def make_moons(n=1200, noise=0.22, seed=0):
    """Generate a shuffled, standardized, noisy two-moons dataset.

    Points alternate between the two classes (label = index parity), are
    placed on two interleaved half circles, and get Gaussian jitter with
    standard deviation ``noise`` on both coordinates. After shuffling, the
    first 65% of the points form the training split. Both splits are
    standardized with the mean and population standard deviation of the
    training split only.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` where each ``X`` is a list of
        ``[x1, x2]`` rows and each ``y`` is a list of 0/1 labels.
    """
    rng = random.Random(seed)
    points, labels = [], []
    for i in range(n):
        label = i % 2
        angle = rng.random() * math.pi
        if label:
            px, py = 1.0 - math.cos(angle), 0.5 - math.sin(angle)
        else:
            px, py = math.cos(angle), math.sin(angle)
        # The jitter draws (x first, then y) are part of the seeded stream.
        px += rng.gauss(0, noise)
        py += rng.gauss(0, noise)
        points.append([px, py])
        labels.append(label)

    order = list(range(n))
    rng.shuffle(order)
    points = [points[i] for i in order]
    labels = [labels[i] for i in order]

    # Standardize every point using statistics of the training rows only.
    split = int(n * 0.65)
    train_points = points[:split]
    means, stds = [], []
    for j in range(2):
        mu = sum(row[j] for row in train_points) / split
        var = sum((row[j] - mu) ** 2 for row in train_points) / split
        means.append(mu)
        stds.append(math.sqrt(var) + 1e-12)
    points = [[(row[j] - means[j]) / stds[j] for j in range(2)] for row in points]
    return points[:split], labels[:split], points[split:], labels[split:]

# ----------------------------- model -----------------------------

@dataclass
class Params:
    """Parameters of the one-hidden-layer network used in this script."""
    # Input-to-hidden weights, indexed as W1[input_index][hidden_index].
    W1: list
    # Hidden-unit biases, one entry per hidden unit.
    b1: list
    # Hidden-to-output weights, one entry per hidden unit.
    W2: list
    # Scalar bias added to the output logit.
    b2: float

def init_params(seed, hidden=24):
    """Create seeded initial parameters for a 2-input network.

    Weights are drawn from a zero-mean Gaussian with standard deviation 0.35;
    all biases start at zero. Draw order is the two rows of ``W1`` followed by
    ``W2``, so a given seed always yields the same parameters.
    """
    rng = random.Random(seed)

    def draw(count):
        return [rng.gauss(0, 0.35) for _ in range(count)]

    input_weights = [draw(hidden), draw(hidden)]
    hidden_biases = [0.0] * hidden
    output_weights = draw(hidden)
    return Params(input_weights, hidden_biases, output_weights, 0.0)

def zeros_like(p):
    """Return a new ``Params`` shaped like ``p`` with every entry set to 0.0."""
    w1 = [[0.0] * len(row) for row in p.W1]
    b1 = [0.0] * len(p.b1)
    w2 = [0.0] * len(p.W2)
    return Params(w1, b1, w2, 0.0)

def copy_params(p):
    """Return a copy of ``p`` whose lists can be mutated without touching ``p``."""
    return Params(
        W1=[list(row) for row in p.W1],
        b1=list(p.b1),
        W2=list(p.W2),
        b2=p.b2,
    )

def sigmoid(z):
    """Return the logistic function of ``z`` without overflowing for large ``|z|``."""
    # Only ever exponentiate a non-positive number, so math.exp cannot overflow.
    decay = math.exp(-abs(z))
    if z >= 0:
        return 1.0 / (1.0 + decay)
    return decay / (1.0 + decay)

def forward_one(p, x):
    """Run the network on a single 2-feature input.

    Returns:
        ``(h, prob)``: the list of tanh hidden activations and the sigmoid of
        the output logit.
    """
    h = [
        math.tanh(x[0] * p.W1[0][k] + x[1] * p.W1[1][k] + p.b1[k])
        for k in range(len(p.b1))
    ]
    # Keep the built-in sum() so the float summation matches exactly.
    weighted = (hk * p.W2[k] for k, hk in enumerate(h))
    logit = sum(weighted) + p.b2
    return h, sigmoid(logit)

def loss_and_grad(p, X, y, indices, wd=1e-4):
    """Compute the regularized mini-batch loss and its gradient.

    The loss is the mean binary cross-entropy over the rows selected by
    ``indices`` plus ``0.5 * wd`` times the squared entries of ``W1`` and
    ``W2``. Biases are not regularized.

    Returns:
        ``(loss, grad)`` where ``grad`` is a ``Params`` holding the gradient
        with respect to each parameter.
    """
    grad = zeros_like(p)
    n = len(indices)
    units = range(len(p.b1))
    data_loss = 0.0
    for sample in indices:
        x, target = X[sample], y[sample]
        h, prob = forward_one(p, x)
        # Clip so the logarithms below stay finite.
        prob = min(max(prob, 1e-8), 1 - 1e-8)
        data_loss += -(target * math.log(prob) + (1 - target) * math.log(1 - prob))
        # Derivative of the cross-entropy with respect to the output logit.
        dlogit = prob - target
        grad.b2 += dlogit
        for k in units:
            grad.W2[k] += dlogit * h[k]
            # Backpropagate through tanh: d tanh(z)/dz = 1 - tanh(z)^2.
            dz = dlogit * p.W2[k] * (1 - h[k] * h[k])
            grad.W1[0][k] += dz * x[0]
            grad.W1[1][k] += dz * x[1]
            grad.b1[k] += dz

    # Average over the batch, then add the weight-decay terms.
    inv = 1.0 / n
    loss = data_loss * inv
    for k in units:
        loss += 0.5 * wd * (p.W1[0][k] ** 2 + p.W1[1][k] ** 2 + p.W2[k] ** 2)
        grad.W1[0][k] = grad.W1[0][k] * inv + wd * p.W1[0][k]
        grad.W1[1][k] = grad.W1[1][k] * inv + wd * p.W1[1][k]
        grad.W2[k] = grad.W2[k] * inv + wd * p.W2[k]
        grad.b1[k] *= inv
    grad.b2 *= inv
    return loss, grad

def grad_norm(g):
    """Return the Euclidean norm over all entries of ``g``, plus 1e-12.

    The small offset keeps the result strictly positive so callers can
    divide by it.
    """
    total = g.b2 * g.b2
    # Accumulate in a fixed order: rows of W1, then b1, then W2.
    for group in (*g.W1, g.b1, g.W2):
        for value in group:
            total += value * value
    return math.sqrt(total) + 1e-12

def add_scaled(p, g, scale):
    """Update ``p`` in place to ``p + scale * g``, entry by entry.

    ``g`` is left untouched and nothing is returned.
    """
    units = range(len(p.b1))
    for i in (0, 1):
        target_row, source_row = p.W1[i], g.W1[i]
        for k in units:
            target_row[k] += scale * source_row[k]
    for k in units:
        p.b1[k] += scale * g.b1[k]
    for k in units:
        p.W2[k] += scale * g.W2[k]
    p.b2 += scale * g.b2

def sgd_step(p, g, lr):
    """Apply one in-place gradient-descent update: ``p <- p - lr * g``."""
    descent_scale = -lr
    add_scaled(p, g, descent_scale)

def evaluate(p, X, y):
    """Return ``(mean cross-entropy, accuracy)`` of ``p`` on the given data.

    A sample counts as correct when its predicted probability is at least
    0.5 and the target is truthy, or below 0.5 and the target is falsy.
    No weight-decay term is included in the loss.
    """
    total_nll = 0.0
    hits = 0
    for features, target in zip(X, y):
        prob = forward_one(p, features)[1]
        # Clip so the logarithms below stay finite.
        prob = min(max(prob, 1e-8), 1 - 1e-8)
        total_nll += -(target * math.log(prob) + (1 - target) * math.log(1 - prob))
        predicted_positive = prob >= 0.5
        if predicted_positive == bool(target):
            hits += 1
    count = len(y)
    return total_nll / count, hits / count

def sharpness_proxy(p, X, y, seed=0, radius=0.05, trials=16):
    """Estimate sharpness as the worst loss increase under random perturbations.

    For each of ``trials`` Gaussian directions, the direction is rescaled to
    length ``radius``, added to a copy of ``p``, and the loss on ``(X, y)`` is
    compared with the unperturbed loss. The largest increase is returned; the
    result is never below 0.0 because that is the starting value.
    """
    rng = random.Random(seed)
    reference_loss = evaluate(p, X, y)[0]
    units = range(len(p.b1))
    worst = 0.0
    for _trial in range(trials):
        direction = zeros_like(p)
        # Draw order is part of the seeded stream: b2, W1 row by row,
        # then (b1[k], W2[k]) pairs.
        direction.b2 = rng.gauss(0, 1)
        for i in (0, 1):
            for k in units:
                direction.W1[i][k] = rng.gauss(0, 1)
        for k in units:
            direction.b1[k] = rng.gauss(0, 1)
            direction.W2[k] = rng.gauss(0, 1)
        candidate = copy_params(p)
        add_scaled(candidate, direction, radius / grad_norm(direction))
        increase = evaluate(candidate, X, y)[0] - reference_loss
        if increase > worst:
            worst = increase
    return worst

# ----------------------------- training -----------------------------

def train(method, seed, epochs=260, batch_size=64, lr=0.055, rho=0.12):
    """Train one network with mini-batch SGD or SAM and report its metrics.

    The dataset, the initial parameters, the batch shuffling and the sharpness
    probe each use their own seed derived from ``seed``, so both methods see
    identical data, initialization and batch order for a given ``seed``.

    Args:
        method: ``'sgd'`` or ``'sam'``.
        seed: Base seed for the run.
        epochs: Number of passes over the training split.
        batch_size: Mini-batch size.
        lr: Learning rate of the descent step.
        rho: Length of the SAM ascent step (unused for ``'sgd'``).

    Returns:
        Dict with the method, seed, final train/test loss and accuracy, the
        sharpness proxy measured on the held-out split, and a ``'history'``
        list recorded at epoch 1 and every 20th epoch.

    Raises:
        ValueError: If ``method`` is neither ``'sgd'`` nor ``'sam'`` (raised
            when the first batch is processed).
    """
    Xtr, ytr, Xte, yte = make_moons(seed=10_000 + seed)
    p = init_params(seed=1_000 + seed)
    shuffler = random.Random(20_000 + seed)
    n = len(ytr)

    def metrics():
        tr_loss, tr_acc = evaluate(p, Xtr, ytr)
        te_loss, te_acc = evaluate(p, Xte, yte)
        return {'train_loss': tr_loss, 'train_acc': tr_acc, 'test_loss': te_loss, 'test_acc': te_acc}

    hist = []
    for ep in range(1, epochs + 1):
        order = list(range(n))
        shuffler.shuffle(order)
        for start in range(0, n, batch_size):
            batch = order[start:start + batch_size]
            if method == 'sam':
                # Ascend to the approximate worst point within radius rho,
                # take the gradient there, step back, then descend with it.
                _, g1 = loss_and_grad(p, Xtr, ytr, batch)
                ascent = rho / grad_norm(g1)
                add_scaled(p, g1, ascent)
                _, g2 = loss_and_grad(p, Xtr, ytr, batch)
                add_scaled(p, g1, -ascent)
                sgd_step(p, g2, lr)
            elif method == 'sgd':
                _, g = loss_and_grad(p, Xtr, ytr, batch)
                sgd_step(p, g, lr)
            else:
                raise ValueError(method)
        if ep == 1 or ep % 20 == 0:
            hist.append({'epoch': ep, **metrics()})

    final = metrics()
    sharp = sharpness_proxy(p, Xte, yte, seed=30_000 + seed)
    # Key order matters: it becomes the column order of the results CSV.
    return {
        'method': method, 'seed': seed,
        **final,
        'sharpness_proxy': sharp,
        'history': hist,
    }

def mean_sd(vals):
    """Return ``(mean, sample standard deviation)`` of ``vals``.

    The deviation is reported as 0.0 when there is only one value.
    """
    center = statistics.mean(vals)
    if len(vals) > 1:
        return center, statistics.stdev(vals)
    return center, 0.0

def write_svg(summary):
    """Render a two-panel SVG bar chart of SGD vs SAM and return its path.

    The left panel shows mean test accuracy on an axis running from 75% to
    100%. The right panel shows the mean sharpness proxy, scaled so the taller
    bar fills 80% of the panel. Bar heights are clamped to the panel, so an
    accuracy below 75% gives a zero-height bar rather than a negative
    ``height`` attribute, which is not valid on an SVG ``<rect>``.

    The subtitle's "mean over 8 random seeds" is fixed text; it is not derived
    from ``summary``.

    Args:
        summary: Mapping with ``'sgd'`` and ``'sam'`` entries, each providing
            ``'test_acc_mean'`` and ``'sharpness_proxy_mean'``.

    Returns:
        Path of the written file, ``sam_minimal_results.svg`` in ``OUT_DIR``.
    """
    # Simple dependency-free bar chart.
    w, h = 760, 420
    methods = ['sgd', 'sam']
    colors = {'sgd': '#7c8aa5', 'sam': '#f97316'}
    acc = {m: summary[m]['test_acc_mean'] for m in methods}
    sharp = {m: summary[m]['sharpness_proxy_mean'] for m in methods}
    max_sharp = max(sharp.values()) * 1.25
    # Shared panel geometry: top edge, panel height, panel width, bar width.
    y0, ph, pw, bw = 120, 230, 260, 70

    def bar_panel(x0, title, heights, value_labels):
        # One panel: title, then bar + value label + method label per method, then baseline.
        out = [f'<text x="{x0}" y="105" fill="#e2e8f0" font-family="Arial" font-size="16">{title}</text>']
        for i, m in enumerate(methods):
            x = x0 + 55 + i * 105
            # Clamp into [0, ph]: a negative height is invalid SVG.
            bh = min(max(heights[m], 0), ph)
            y = y0 + ph - bh
            out.append(f'<rect x="{x}" y="{y:.1f}" width="{bw}" height="{bh:.1f}" rx="8" fill="{colors[m]}"/>')
            out.append(f'<text x="{x+bw/2}" y="{y-8:.1f}" fill="#f8fafc" text-anchor="middle" font-family="Arial" font-size="14">{value_labels[m]}</text>')
            out.append(f'<text x="{x+bw/2}" y="{y0+ph+28}" fill="#cbd5e1" text-anchor="middle" font-family="Arial" font-size="14">{m.upper()}</text>')
        out.append(f'<line x1="{x0}" y1="{y0+ph}" x2="{x0+pw}" y2="{y0+ph}" stroke="#334155"/>')
        return out

    lines = [f'<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}" viewBox="0 0 {w} {h}">',
             '<rect width="100%" height="100%" fill="#0b1020"/>',
             '<text x="40" y="42" fill="#f8fafc" font-family="Inter,Arial" font-size="22" font-weight="700">SAM minimal reproduction: generalization and sharpness</text>',
             '<text x="40" y="70" fill="#94a3b8" font-family="Inter,Arial" font-size="13">Noisy two-moons MLP; mean over 8 random seeds. Higher test accuracy is better; lower sharpness proxy is better.</text>']

    # left panel: accuracy, drawn on an axis that starts at 75%
    acc_x0 = 70
    lines += bar_panel(
        acc_x0,
        'Test accuracy',
        {m: (acc[m] - 0.75) / 0.25 * ph for m in methods},
        {m: f'{acc[m]*100:.2f}%' for m in methods},
    )
    lines.append(f'<text x="{acc_x0}" y="{y0+ph+45}" fill="#64748b" font-family="Arial" font-size="12">axis starts at 75%</text>')

    # right panel: sharpness, scaled relative to the larger of the two means
    lines += bar_panel(
        430,
        'Sharpness proxy',
        {m: (sharp[m] / max_sharp * ph if max_sharp else 0) for m in methods},
        {m: f'{sharp[m]:.4f}' for m in methods},
    )
    lines.append('</svg>')

    path = os.path.join(OUT_DIR, 'sam_minimal_results.svg')
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))
    return path

def main():
    """Run SGD and SAM over eight seeds and write all result files.

    Writes ``results.csv`` (one row per run), ``history.csv`` (recorded
    epochs of every run), ``summary.json`` (per-method mean and standard
    deviation plus SAM-minus-SGD deltas) and the SVG chart into ``OUT_DIR``,
    then prints the summary and the chart path.
    """
    methods = ['sgd', 'sam']
    seeds = list(range(8))
    rows = []
    histories = []
    for seed in seeds:
        for method in methods:
            print(f'train {method} seed={seed}', flush=True)
            result = train(method, seed)
            rows.append({k: v for k, v in result.items() if k != 'history'})
            histories.extend({'method': method, 'seed': seed, **entry} for entry in result['history'])

    summary = {}
    for method in methods:
        runs = [r for r in rows if r['method'] == method]
        stats = {}
        for key in ['train_acc', 'test_acc', 'test_loss', 'sharpness_proxy']:
            stats[key + '_mean'], stats[key + '_sd'] = mean_sd([r[key] for r in runs])
        summary[method] = stats
    sam_stats, sgd_stats = summary['sam'], summary['sgd']
    summary['delta_test_acc_sam_minus_sgd'] = sam_stats['test_acc_mean'] - sgd_stats['test_acc_mean']
    summary['delta_sharpness_sam_minus_sgd'] = sam_stats['sharpness_proxy_mean'] - sgd_stats['sharpness_proxy_mean']

    def dump_csv(filename, records):
        # Column order follows the key order of the first record.
        with open(os.path.join(OUT_DIR, filename), 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=list(records[0].keys()))
            writer.writeheader()
            writer.writerows(records)

    dump_csv('results.csv', rows)
    dump_csv('history.csv', histories)
    with open(os.path.join(OUT_DIR, 'summary.json'), 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    svg = write_svg(summary)
    print(json.dumps(summary, indent=2))
    print('wrote', svg)

if __name__ == '__main__':
    main()
