#!/usr/bin/env python3
"""
Dependency-free minimal reproduction for:
David Arthur & Sergei Vassilvitskii (2007),
"k-means++: The Advantages of Careful Seeding".

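Reproduces the paper's core empirical claim on synthetic clustered data:
D^2 (k-means++) seeding yields a lower k-means objective (SSE) than uniform
random seeding. SSE is recorded both right after seeding and after Lloyd's
iterations; results are written to trials.csv, summary.json and
kmeanspp_results.svg next to this script.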
"""
import csv, json, math, os, random, statistics

# Directory that receives every generated artifact (the script's own folder).
OUT = os.path.dirname(__file__)

K = 8             # number of centers each seeding method chooses
TRIALS = 160      # number of seeding trials run per method
LLOYD_ITERS = 60  # upper bound on Lloyd refinement iterations

# Gaussian mixture used by make_data: CENTERS[i], COUNTS[i] and STDS[i]
# describe component i (mean, number of points, isotropic std).
CENTERS = [
    (-7.5, -5.8),
    (-4.5, 3.6),
    (-0.8, -1.2),
    (2.5, 5.5),
    (5.8, -4.6),
    (8.2, 2.6),
    (0.5, 8.4),
    (-8.2, 7.2),
]
COUNTS = [
    220, 90, 160, 70,
    180, 55, 45, 40,
]
STDS = [
    0.95, 0.72, 1.15, 0.65,
    0.9, 0.6, 0.48, 0.5,
]

def make_data(seed=0, centers=None, counts=None, stds=None, n_noise=45):
    """Generate the synthetic 2-D point set used by the experiment.

    For each mixture component i, draws ``counts[i]`` points from an isotropic
    Gaussian with mean ``centers[i]`` and std ``stds[i]``. It then adds
    *n_noise* uniform "bridge" points in the box x in [-7, 7], y in [-4, 5]
    and shuffles everything. All randomness comes from
    ``random.Random(seed)``, so the output is reproducible.

    Args:
        seed: seed for the private RNG.
        centers: component means; defaults (when None) to module CENTERS.
        counts: points per component; defaults to module COUNTS.
        stds: per-component std; defaults to module STDS.
        n_noise: number of uniform background points.

    Returns:
        List of (x, y) tuples of length ``sum(counts) + n_noise``.

    Raises:
        ValueError: if centers, counts and stds differ in length (zip would
            otherwise silently drop the extra components).
    """
    centers = CENTERS if centers is None else centers
    counts = COUNTS if counts is None else counts
    stds = STDS if stds is None else stds
    if not (len(centers) == len(counts) == len(stds)):
        raise ValueError('centers, counts and stds must have the same length')

    rng = random.Random(seed)
    pts = []
    for c, n, std in zip(centers, counts, stds):
        # x is drawn before y for each point, preserving the RNG call order.
        pts.extend((rng.gauss(c[0], std), rng.gauss(c[1], std)) for _ in range(n))
    # Mild bridge/noise points between clusters make initialization matter.
    pts.extend((rng.uniform(-7, 7), rng.uniform(-4, 5)) for _ in range(n_noise))
    rng.shuffle(pts)
    return pts

def dist2(a, b):
    """Return the squared Euclidean distance between the first two coordinates of *a* and *b*."""
    ax, ay = a[0], a[1]
    bx, by = b[0], b[1]
    return (ax - bx) * (ax - bx) + (ay - by) * (ay - by)

def nearest_dist2(x, centers):
    """Return the smallest squared distance from point *x* to any of *centers*.

    Raises ValueError when *centers* is empty (``min`` of an empty sequence).
    """
    return min(map(lambda c: dist2(x, c), centers))

def objective(points, centers):
    """Return the k-means cost: the sum over *points* of each point's squared distance to its nearest center."""
    per_point = (nearest_dist2(p, centers) for p in points)
    return sum(per_point)

def seed_random(points, k, rng):
    """Uniform baseline seeding: pick *k* distinct positions of *points* via ``rng.sample``."""
    chosen = rng.sample(points, k)
    return chosen

def seed_kmeanspp(points, k, rng):
    """Choose *k* initial centers by k-means++ (D^2) seeding.

    The first center is drawn uniformly from *points*. Every later center is
    drawn with probability proportional to its squared distance to the nearest
    center chosen so far. When that total weight is zero (every point already
    coincides with a chosen center), the next center is drawn uniformly instead.

    Args:
        points: sequence of 2-D points.
        k: number of centers to return; ``0`` returns ``[]``.
        rng: a ``random.Random`` (its ``choice`` and ``choices`` are used).

    Returns:
        List of *k* points taken from *points*.

    Raises:
        ValueError: if *k* is negative.
        IndexError: if *points* is empty and *k* > 0 (from ``rng.choice``).
    """
    from itertools import accumulate

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []
    centers = [rng.choice(points)]
    # nearest[i] = squared distance from points[i] to its closest chosen
    # center. Folding in only the newest center each round is O(n) instead of
    # re-taking the min over all centers (O(n * len(centers))); min() is exact,
    # so the weights equal what a full recomputation would give.
    nearest = [float('inf')] * len(points)
    while len(centers) < k:
        latest = centers[-1]
        nearest = [min(d, dist2(x, latest)) for d, x in zip(nearest, points)]
        cum = list(accumulate(nearest))
        if cum[-1] <= 0:
            centers.append(rng.choice(points))
        else:
            # Bisecting the running sums keeps the draw consistent with its own
            # total. A hand-rolled `acc >= r` scan can fall through under float
            # rounding, and on r == 0 it can return a zero-weight point (an
            # existing center).
            centers.append(rng.choices(points, cum_weights=cum)[0])
    return centers

def lloyd(points, centers, max_iter=LLOYD_ITERS):
    """Refine *centers* with Lloyd's algorithm.

    Each iteration assigns every point to its nearest center (the lowest index
    wins ties). It then moves each center to the mean of its assigned points;
    a center that received no points stays where it is. Iteration stops after
    *max_iter* rounds, or once the summed squared center displacement falls
    below 1e-10.

    Returns a new list of centers; the input sequence is not modified.
    """
    current = list(centers)
    for _ in range(max_iter):
        members = [[] for _ in current]
        for p in points:
            gaps = [dist2(p, c) for c in current]
            members[gaps.index(min(gaps))].append(p)
        updated = []
        for old, grp in zip(current, members):
            if grp:
                size = len(grp)
                updated.append((sum([q[0] for q in grp]) / size,
                                sum([q[1] for q in grp]) / size))
            else:
                updated.append(old)
        shift = sum(dist2(a, b) for a, b in zip(current, updated))
        current = updated
        if shift < 1e-10:
            break
    return current

def summarize(vals):
    """Describe *vals* by mean, median, sample standard deviation, min and max.

    The standard deviation is reported as 0.0 for a single value, where the
    sample estimate is undefined. An empty *vals* raises
    ``statistics.StatisticsError``.
    """
    spread = 0.0 if len(vals) < 2 else statistics.stdev(vals)
    return dict(
        mean=statistics.mean(vals),
        median=statistics.median(vals),
        sd=spread,
        min=min(vals),
        max=max(vals),
    )

def write_svg(summary, path=None):
    """Render mean SSE per seeding method (initial and final) as an SVG bar chart.

    Args:
        summary: dict with 'random_initial', 'kmeanspp_initial', 'random_final'
            and 'kmeanspp_final' entries (each a dict with a 'mean') plus a
            'kmeanspp_win_rate' fraction. Optional 'dataset' and 'trials'
            entries are shown in the subtitle.
        path: destination file; defaults to ``kmeanspp_results.svg`` in OUT.

    Returns:
        The path that was written.
    """
    # SVG is XML: text containing '&', '<' or '>' must be escaped, otherwise the
    # file is malformed and viewers refuse to render it.
    from xml.sax.saxutils import escape

    random_mean = summary['random_final']['mean']
    pp_mean = summary['kmeanspp_final']['mean']
    random_init = summary['random_initial']['mean']
    pp_init = summary['kmeanspp_initial']['mean']
    bars = [
        ('Random init', random_init, '#94a3b8'),
        ('k-means++ init', pp_init, '#fb923c'),
        ('Random final', random_mean, '#64748b'),
        ('k-means++ final', pp_mean, '#f97316'),
    ]
    W, H = 880, 460
    x0, y0, chart_w, chart_h = 90, 125, 700, 230
    bw, pitch = 115, 165
    # 8% headroom above the tallest bar. Fall back to 1.0 so an all-zero chart
    # does not divide by zero when scaling bar heights.
    maxv = max(v for _, v, _ in bars) * 1.08 or 1.0

    # Describe the actual run rather than hard-coding the trial count.
    dataset = summary.get('dataset', 'synthetic 8-cluster data')
    trials = summary.get('trials')
    trials_txt = f'{trials} random trials' if trials is not None else 'random trials'
    title = 'k-means++ reproduction: careful seeding lowers objective'
    subtitle = f'Arthur & Vassilvitskii 2007; {dataset}; {trials_txt}; lower SSE is better.'
    reduction = (1 - pp_mean / random_mean) * 100 if random_mean else 0.0
    footer = (f'Final objective reduction: {reduction:.1f}% · k-means++ wins '
              f'{summary["kmeanspp_win_rate"] * 100:.1f}% of paired trials')

    lines = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{W}" height="{H}" viewBox="0 0 {W} {H}">',
        '<rect width="100%" height="100%" fill="#0f172a"/>',
        f'<text x="44" y="48" fill="#f8fafc" font-family="Inter,Arial" font-size="24" font-weight="700">{escape(title)}</text>',
        f'<text x="44" y="78" fill="#94a3b8" font-family="Inter,Arial" font-size="14">{escape(subtitle)}</text>',
    ]
    for i, (name, val, color) in enumerate(bars):
        x = x0 + i * pitch
        bh = val / maxv * chart_h
        y = y0 + chart_h - bh
        lines.append(f'<rect x="{x}" y="{y:.2f}" width="{bw}" height="{bh:.2f}" rx="10" fill="{color}"/>')
        lines.append(f'<text x="{x+bw/2}" y="{y-10:.2f}" fill="#f8fafc" text-anchor="middle" font-family="Arial" font-size="13">{val:,.0f}</text>')
        # Non-breaking spaces keep each two-word label on a single line.
        label = escape(name.replace(' ', '\u00a0'))
        lines.append(f'<text x="{x+bw/2}" y="{y0+chart_h+28}" fill="#cbd5e1" text-anchor="middle" font-family="Arial" font-size="13">{label}</text>')
    lines.append(f'<line x1="{x0-20}" y1="{y0+chart_h}" x2="{x0+chart_w}" y2="{y0+chart_h}" stroke="#334155"/>')
    lines.append(f'<text x="44" y="412" fill="#e2e8f0" font-family="Arial" font-size="15">{escape(footer)}</text>')
    lines.append('</svg>')
    if path is None:
        path = os.path.join(OUT, 'kmeanspp_results.svg')
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))
    return path

def _run_trial(points, t):
    """Seed, score and Lloyd-refine one random / k-means++ pair for trial *t*; return its CSV row.

    Each method gets its own RNG (seeds 1000 + t and 2000 + t), so every trial
    is reproducible on its own.
    """
    rand_seeds = seed_random(points, K, random.Random(1000 + t))
    pp_seeds = seed_kmeanspp(points, K, random.Random(2000 + t))
    rand_init_sse = objective(points, rand_seeds)
    pp_init_sse = objective(points, pp_seeds)
    rand_final_sse = objective(points, lloyd(points, rand_seeds))
    pp_final_sse = objective(points, lloyd(points, pp_seeds))
    return {
        'trial': t,
        'random_initial_sse': rand_init_sse,
        'kmeanspp_initial_sse': pp_init_sse,
        'random_final_sse': rand_final_sse,
        'kmeanspp_final_sse': pp_final_sse,
        'final_ratio_kmeanspp_over_random': pp_final_sse / rand_final_sse,
        'kmeanspp_wins': int(pp_final_sse < rand_final_sse),
    }

def main():
    """Run TRIALS paired trials, write trials.csv, summary.json and the SVG into OUT, then print the summary."""
    points = make_data(42)
    rows = [_run_trial(points, t) for t in range(TRIALS)]

    summary = {
        'paper': 'Arthur & Vassilvitskii (2007), k-means++: The Advantages of Careful Seeding',
        'dataset': f'synthetic Gaussian mixture, n={len(points)}, k={K}',
        'trials': TRIALS,
    }
    # Per-metric stats; each summary key is the CSV column name minus '_sse'.
    for metric in ('random_initial', 'kmeanspp_initial', 'random_final', 'kmeanspp_final'):
        summary[metric] = summarize([row[metric + '_sse'] for row in rows])
    summary['mean_final_ratio_kmeanspp_over_random'] = statistics.mean(
        [row['final_ratio_kmeanspp_over_random'] for row in rows])
    summary['kmeanspp_win_rate'] = statistics.mean([row['kmeanspp_wins'] for row in rows])
    summary['final_objective_reduction_vs_random_mean'] = (
        1 - summary['kmeanspp_final']['mean'] / summary['random_final']['mean'])

    with open(os.path.join(OUT, 'trials.csv'), 'w', newline='', encoding='utf-8') as fh:
        writer = csv.DictWriter(fh, fieldnames=list(rows[0]))
        writer.writeheader()
        writer.writerows(rows)
    with open(os.path.join(OUT, 'summary.json'), 'w', encoding='utf-8') as fh:
        json.dump(summary, fh, indent=2)
    svg_path = write_svg(summary)
    print(json.dumps(summary, indent=2))
    print('wrote', svg_path)

if __name__ == '__main__':
    main()
