import os
import xml.etree.ElementTree as ET
import argparse
import shutil
from pathlib import Path
import random
import torch
from ultralytics import YOLO

import csv
from datetime import datetime

# Destination of the per-epoch metrics table; kept at module level so the
# training callback can append to the same file.
metrics_csv_path = "epoch_metrics.csv"

# Column names of the metrics table: epoch index, the three training
# losses, the three validation losses, then the detection quality metrics.
_metrics_header = (
    "epoch",
    "train_box_loss",
    "train_cls_loss",
    "train_dfl_loss",
    "val_box_loss",
    "val_cls_loss",
    "val_dfl_loss",
    "precision",
    "recall",
    "mAP50",
    "mAP50-95",
)

# Opening in "w" mode truncates any previous file, so the header row is
# written exactly once per import of this module.
with open(metrics_csv_path, mode="w", newline="") as header_file:
    csv.writer(header_file).writerow(_metrics_header)




# -------------------------------------------------
# VOC to YOLO Detection Conversion
# -------------------------------------------------
def voc_xml_to_yolo_det(xml_path, label_path, class_map):
    """Convert one Pascal VOC XML annotation into a YOLO detection label file.

    Each kept object becomes one line ``<cls_id> <x_center> <y_center> <w> <h>``
    with all four box values normalized by the image size read from the
    ``<size>`` element.

    Objects are skipped (not treated as errors) when:
      * their class name is not a key of ``class_map``;
      * they have no ``<bndbox>`` or a coordinate is missing / non-numeric;
      * the box has zero or negative area after clamping to the image bounds.

    If the declared width or height is not positive, every box clamps to zero
    area and is skipped, so an empty label file is written.

    Args:
        xml_path: Path of the VOC XML annotation to read.
        label_path: Path of the YOLO ``.txt`` label file to write.
        class_map: Mapping from class name to integer class id.

    Raises:
        ValueError: If the annotation has no usable ``<size>`` width/height.
    """
    root = ET.parse(xml_path).getroot()

    try:
        # Go through float() so dimensions written as "640.0" are accepted too.
        w = int(float(root.findtext("size/width")))
        h = int(float(root.findtext("size/height")))
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"{xml_path}: missing or invalid <size> width/height"
        ) from err

    # Collect all lines first so a failure never leaves a half-written file.
    lines = []
    for obj in root.iter("object"):
        cls_name = (obj.findtext("name") or "").strip()
        if cls_name not in class_map:
            continue

        cls_id = class_map[cls_name]
        xmlbox = obj.find("bndbox")
        if xmlbox is None:
            continue

        try:
            xmin = float(xmlbox.findtext("xmin"))
            ymin = float(xmlbox.findtext("ymin"))
            xmax = float(xmlbox.findtext("xmax"))
            ymax = float(xmlbox.findtext("ymax"))
        except (TypeError, ValueError):
            # Missing tag (None) or non-numeric text: malformed box, skip it.
            continue

        # Clamp to image bounds
        xmin = max(0, min(xmin, w))
        xmax = max(0, min(xmax, w))
        ymin = max(0, min(ymin, h))
        ymax = max(0, min(ymax, h))

        if xmax <= xmin or ymax <= ymin:
            continue

        # Convert to YOLO format (center + size, normalized to [0, 1])
        x_center = ((xmin + xmax) / 2) / w
        y_center = ((ymin + ymax) / 2) / h
        box_w = (xmax - xmin) / w
        box_h = (ymax - ymin) / h

        lines.append(f"{cls_id} {x_center} {y_center} {box_w} {box_h}\n")

    with open(label_path, "w") as out_file:
        out_file.writelines(lines)


# -------------------------------------------------
# Dataset Builder
# -------------------------------------------------
def build_voc_dataset(voc_root, output_root, class_names,
                      train_ratio=0.7, val_ratio=0.10, seed=None):
    """Build a YOLO-format train/val/test dataset from a VOC-style folder.

    Reads ``<voc_root>/images/*.png`` and the matching
    ``<voc_root>/annotations/<stem>.xml`` files, copies each annotated image
    into ``<output_root>/images/<split>/``, writes its converted label into
    ``<output_root>/labels/<split>/`` and finally writes
    ``<output_root>/data.yaml``.

    Images without an annotation file are dropped *before* splitting, so the
    ratios apply to the samples that actually end up in the dataset.

    Existing files under ``output_root`` are not removed; re-running into the
    same directory with a different shuffle can therefore leave the same image
    in more than one split.

    Args:
        voc_root: Folder containing the ``images`` and ``annotations`` dirs.
        output_root: Folder the converted dataset is written to.
        class_names: Ordered class names; the index is the YOLO class id.
        train_ratio: Fraction of samples assigned to the train split.
        val_ratio: Fraction of samples assigned to the val split; the
            remainder goes to the test split.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` state is used.

    Returns:
        Path of the written ``data.yaml`` (``Path(output_root) / "data.yaml"``).
    """
    images_src = Path(voc_root) / "images"
    ann_src = Path(voc_root) / "annotations"

    out_root = Path(output_root)
    images_out = out_root / "images"
    labels_out = out_root / "labels"

    for split in ("train", "val", "test"):
        (images_out / split).mkdir(parents=True, exist_ok=True)
        (labels_out / split).mkdir(parents=True, exist_ok=True)

    class_map = {name: i for i, name in enumerate(class_names)}

    # Keep only annotated images, and sort so the shuffle result depends on
    # the file names and the RNG state only, not on directory listing order.
    image_files = sorted(
        img for img in images_src.glob("*.png")
        if (ann_src / (img.stem + ".xml")).exists()
    )
    if seed is None:
        random.shuffle(image_files)
    else:
        random.Random(seed).shuffle(image_files)

    n = len(image_files)
    train_end = int(n * train_ratio)
    val_end = train_end + int(n * val_ratio)

    splits = {
        "train": image_files[:train_end],
        "val": image_files[train_end:val_end],
        "test": image_files[val_end:],
    }

    for split, files in splits.items():
        for img_path in files:
            xml_path = ann_src / (img_path.stem + ".xml")
            img_dst = images_out / split / img_path.name
            label_dst = labels_out / split / (img_path.stem + ".txt")

            shutil.copy(img_path, img_dst)
            voc_xml_to_yolo_det(xml_path, label_dst, class_map)

    yaml_path = out_root / "data.yaml"
    with open(yaml_path, "w") as f:
        # Write an absolute dataset root: a relative "path" is not resolved
        # against the current working directory by the YOLO data loader.
        f.write(f"path: {out_root.resolve()}\n")
        f.write("train: images/train\n")
        f.write("val: images/val\n")
        f.write("test: images/test\n\n")
        f.write("names:\n")
        for i, name in enumerate(class_names):
            f.write(f"  {i}: {name}\n")

    return yaml_path



# -------------------------------------------------
# Custom Callback to Save Full Metrics Per Epoch
# -------------------------------------------------
def log_metrics(trainer):
    """Append one row of train/validation metrics to the metrics CSV.

    Meant to be registered as an ``on_fit_epoch_end`` callback. The row holds,
    in order: the 1-based epoch, the three training losses, the three
    validation losses, then precision, recall, mAP50 and mAP50-95. Any value
    that is not available is written as 0.

    Args:
        trainer: The trainer object passed to the callback. Read attributes:
            ``epoch``, ``tloss``/``loss_items``, ``validator`` and ``metrics``.
    """
    epoch = trainer.epoch + 1

    # --- Train losses ---
    # Prefer the running mean over the epoch (``tloss``); ``loss_items`` only
    # holds the values of the last processed batch and is used as a fallback.
    losses = getattr(trainer, "tloss", None)
    if losses is None:
        losses = trainer.loss_items
    if losses is not None:
        train_box, train_cls, train_dfl = losses[0], losses[1], losses[2]
    else:
        train_box = train_cls = train_dfl = 0

    # --- Validation metrics ---
    val_results = trainer.validator.metrics if trainer.validator else None
    box = getattr(val_results, "box", None)

    if box is not None:
        precision = box.mp
        recall = box.mr
        map50 = box.map50
        map5095 = box.map
    else:
        precision = recall = map50 = map5095 = 0

    # --- Validation losses ---
    metrics = trainer.metrics or {}
    val_box = metrics.get("val/box_loss", 0)
    val_cls = metrics.get("val/cls_loss", 0)
    val_dfl = metrics.get("val/dfl_loss", 0)

    # float() turns tensor / numpy scalars into plain numbers for the CSV.
    row = [
        epoch,
        float(train_box),
        float(train_cls),
        float(train_dfl),
        float(val_box),
        float(val_cls),
        float(val_dfl),
        float(precision),
        float(recall),
        float(map50),
        float(map5095),
    ]

    with open(metrics_csv_path, "a", newline="") as f:
        csv.writer(f).writerow(row)
# -------------------------------------------------
# Main
# -------------------------------------------------
def main(args):
    """Prepare the dataset (converting from VOC if requested) and train YOLO.

    Args:
        args: Parsed CLI namespace. Read attributes: ``dataset_type``,
            ``dataset_path``, ``classes``, ``model``, ``epochs``, ``imgsz``
            and ``batch``.
    """
    # Dataset conversion
    if args.dataset_type == "voc":
        print("Converting VOC -> YOLO format...")
        yaml_path = build_voc_dataset(
            voc_root=args.dataset_path,
            output_root="converted_dataset",
            class_names=args.classes
        )
    else:
        # A YOLO-format dataset is expected to ship its own data.yaml.
        yaml_path = Path(args.dataset_path) / "data.yaml"

    # First CUDA device when available, otherwise CPU.
    device = 0 if torch.cuda.is_available() else "cpu"

    model = YOLO(args.model)
    # Log the full metric set to CSV after every fit epoch.
    model.add_callback("on_fit_epoch_end", log_metrics)

    model.train(
        data=str(yaml_path),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=device,
        workers=4,
        project="runs",
        name="geom_detect",
        exist_ok=True,
        amp=True,
        rect=False,
        # Learning-rate schedule: cosine schedule disabled.
        cos_lr=False,
        optimizer="SGD",
        lr0=0.01,  # typical for SGD
        # lrf is a fraction of lr0: final lr = lr0 * lrf = 0.01 * 0.0001 = 1e-6.
        # NOTE(review): the previous comment claimed "1% of initial", which
        # would be lrf=0.01 -- confirm which value is intended.
        lrf=0.0001,
        momentum=0.937,
        weight_decay=5e-6,
        # Native early stopping
        patience=20
    )

    print("Training completed.")


# -------------------------------------------------
# CLI
# -------------------------------------------------
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Train a YOLO detector on a VOC or YOLO-format dataset."
    )

    parser.add_argument("--dataset_type", choices=["voc", "yolo"], required=True,
                        help="format of the dataset found at --dataset_path")
    parser.add_argument("--dataset_path", required=True,
                        help="dataset root folder")
    parser.add_argument("--classes", nargs="+",
                        help="class names in class-id order (required for voc)")
    parser.add_argument("--model", default="yolo11m.pt",
                        help="model weights or config to start from")
    parser.add_argument("--epochs", type=int, default=150)
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument("--batch", type=int, default=8)

    args = parser.parse_args()

    # The VOC converter needs the class list to assign class ids; report the
    # problem as a normal CLI usage error (usage text + exit status 2).
    if args.dataset_type == "voc" and not args.classes:
        parser.error("--classes required when dataset_type=voc")

    main(args)

