// Unified benchmark — THREAD schedulers only.
//
// Schedulers: static, dynamic, chunk, chunk_steal, guided, adaptive, aimd
//
// Each scheduler gets its own SHM copy of the original tensor produced by
// shm_generator, sorts every row of every frame in place via quick_sort,
// and reports elapsed milliseconds. The copy segment for a given scheduler
// is attached just before that scheduler runs and detached right after,
// so the next scheduler starts with a clean address space.

#include <iostream>
#include <iomanip>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <cerrno>
#include <limits>
#include <chrono>

#include <unistd.h>
#include <sys/ipc.h>
#include <sys/shm.h>

#include "shm_keys.hpp"
#include "row.hpp"
#include "sort_algorithms.hpp"

#include "scheduler_iface.hpp"
#include "utilization_monitor.hpp"
#include "static_scheduler.hpp"
#include "dynamic_scheduler.hpp"
#include "chunk_scheduler.hpp"
#include "chunk_steal_scheduler.hpp"
#include "guided_scheduler.hpp"
#include "adaptive_scheduler.hpp"
#include "aimd.hpp"

using namespace std;

// Context handed to sort_frame_task through the scheduler's opaque void*
// argument. Describes an F x H x W tensor of bytes stored contiguously,
// frame-major then row-major (see CELL for the exact index formula).
//
// Every member has a default so a default-constructed context is in a
// well-defined (empty) state instead of holding indeterminate values.
struct MatrixSortContext
{
    uint8_t* A = nullptr;  // base address of the tensor; not owned by this struct
    int F = 0;             // number of frames
    int H = 0;             // rows per frame
    int W = 0;             // bytes per row
};

// Returns a reference to element (f, i, j) of a contiguous tensor whose
// frames hold H rows of W bytes each. The arithmetic is done in size_t so
// large tensors do not overflow int.
static inline uint8_t& CELL(uint8_t* A, int f, int i, int j, int H, int W)
{
    const size_t rowIndex = static_cast<size_t>(f) * H + i;  // global row number
    const size_t offset   = rowIndex * W + j;                // byte offset from A
    return *(A + offset);
}

static void sort_frame_task(int frameId, int workerId, int workerCount, void* ctxPtr)
{
    (void)workerId;
    (void)workerCount;

    MatrixSortContext* ctx = static_cast<MatrixSortContext*>(ctxPtr);
    int H = ctx->H;
    int W = ctx->W;
    uint8_t* A = ctx->A;

    for (int i = 0; i < H; ++i)
    {
        uint8_t* rowPtr = &CELL(A, frameId, i, 0, H, W);
        row<uint8_t> r(rowPtr, static_cast<uint32_t>(W), false);
        quick_sort(r);
    }
}

// Computes F*H*W with overflow checking.
//
// On success stores the element count in totalCells and the byte count in
// totalBytes, and returns true. Returns false — leaving both outputs
// untouched — if any dimension is zero or the product does not fit in
// size_t.
static bool checked_total_bytes(size_t F, size_t H, size_t W,
                                size_t& totalCells, size_t& totalBytes)
{
    const size_t limit = std::numeric_limits<size_t>::max();

    const bool anyZero = (F == 0) || (H == 0) || (W == 0);
    if (anyZero)
        return false;

    // a * b fits in size_t exactly when a <= limit / b (b != 0).
    if (limit / H < F)
        return false;
    const size_t rowsTotal = F * H;

    if (limit / W < rowsTotal)
        return false;
    const size_t cells = rowsTotal * W;

    totalCells = cells;
    totalBytes = cells * sizeof(uint8_t);
    return true;
}

// Looks up an already-existing System V shared-memory segment by key.
//
// Returns the segment id, or -1 after printing a diagnostic to stderr if
// the lookup fails. The segment is never created here (no IPC_CREAT).
static int get_existing_shmid(key_t key)
{
    const int shmid = shmget(key, 1, 0666);
    if (shmid == -1)
    {
        // Capture errno before any stream output: the insertions that run
        // ahead of strerror() are allowed to modify it.
        const int savedErrno = errno;
        std::cerr << "Could not find SHM key " << key << ": "
                  << std::strerror(savedErrno) << "\n";
    }
    return shmid;
}

// Attaches the segment `shmid` read/write at a kernel-chosen address.
//
// Returns the mapped base as a byte pointer, or nullptr after printing a
// diagnostic to stderr if shmat fails. The caller is responsible for the
// matching shmdt.
static uint8_t* attach_shm(int shmid)
{
    void* const base = shmat(shmid, nullptr, 0);
    if (base == reinterpret_cast<void*>(-1))  // shmat signals failure with (void*)-1
    {
        // Capture errno before any stream output: the insertion that runs
        // ahead of strerror() is allowed to modify it.
        const int savedErrno = errno;
        std::cerr << "shmat failed: " << std::strerror(savedErrno) << "\n";
        return nullptr;
    }
    return static_cast<uint8_t*>(base);
}

// Marks the segment registered under `key` for removal, if one exists.
// Best-effort: a missing segment is not an error and the shmctl result is
// deliberately not inspected.
static void remove_shm_if_exists(key_t key)
{
    const int existingId = shmget(key, 1, 0666);
    if (existingId == -1)
        return;  // nothing registered under this key

    shmctl(existingId, IPC_RMID, nullptr);
}

// Creates a fresh segment of `totalBytes` bytes under `key`.
//
// Any segment already registered under the key is removed first, and the
// new one is created with IPC_EXCL so a stale segment is never reused.
// Returns the new segment id, or -1 after printing a diagnostic to stderr.
static int create_copy_segment(key_t key, size_t totalBytes)
{
    remove_shm_if_exists(key);

    const int shmid = shmget(key, totalBytes, IPC_CREAT | IPC_EXCL | 0666);
    if (shmid == -1)
    {
        // Capture errno before any stream output: the insertions that run
        // ahead of strerror() are allowed to modify it.
        const int savedErrno = errno;
        std::cerr << "Could not create copy segment " << key << ": "
                  << std::strerror(savedErrno) << "\n";
    }
    return shmid;
}

// Verifies that every row of an F x H x W byte tensor is in non-decreasing
// order. Rows are checked independently; the last byte of one row is never
// compared with the first byte of the next. Returns true when there is
// nothing to compare (no frames, no rows, or rows of width <= 1).
static bool is_sorted_row_major(const uint8_t* A, int F, int H, int W)
{
    for (int frame = 0; frame < F; ++frame)
    {
        for (int r = 0; r < H; ++r)
        {
            const size_t rowIndex = static_cast<size_t>(frame) * H + r;
            const uint8_t* line = A + rowIndex * W;

            // Compare each element with its predecessor.
            for (int next = 1; next < W; ++next)
            {
                if (line[next] < line[next - 1])
                    return false;
            }
        }
    }
    return true;
}

// Pointer to a scheduler entry point: runs `task` for F work items on K
// workers, passing `ctx` through unchanged. main() treats a zero return as
// success.
using thread_runner_t = int (*)(int F, int K, task_fn_t task, void* ctx, bool affinity);

// One benchmarked scheduler: how to run it, which copy segment it sorts,
// and the results main() records for it.
struct Variant
{
    const char*      name;       // scheduler name; matched against --only and printed in the results
    key_t            key;        // System V SHM key of this scheduler's private copy segment
    thread_runner_t  runner;     // five-argument scheduler entry point
    int              shmid;      // id of the copy segment; -1 until the segment has been created
    long long        ms;         // elapsed wall-clock milliseconds; -1 if attach or the scheduler failed
    bool             sorted_ok;  // outcome of the post-run is_sorted_row_major check
};

int main(int argc, char** argv)
{
    if (argc < 5)
    {
        cout << "Usage: ./shm_benchmark F H W K [affinity: 0|1] [--only NAME]\n";
        cout << "Examples:\n";
        cout << "  ./shm_benchmark 1000 640 640 24 1                  # run all threads\n";
        cout << "  ./shm_benchmark 1000 640 640 24 1 --only guided\n";
        cout << "Valid NAMEs: static, dynamic, chunk, chunk_steal, guided, adaptive, aimd\n";
        return 1;
    }

    int F = atoi(argv[1]);
    int H = atoi(argv[2]);
    int W = atoi(argv[3]);
    int K = atoi(argv[4]);
    bool affinity = true;
    const char* only_name = nullptr;

    for (int i = 5; i < argc; ++i)
    {
        if (strcmp(argv[i], "--only") == 0 && i + 1 < argc)
        {
            only_name = argv[i + 1];
            ++i;
        }
        else if (argv[i][0] == '0' || argv[i][0] == '1')
        {
            affinity = (atoi(argv[i]) != 0);
        }
    }

    if (F <= 0 || H <= 0 || W <= 0 || K <= 0)
    {
        cerr << "Invalid arguments.\n";
        return 1;
    }

    size_t totalCells = 0, totalBytes = 0;
    if (!checked_total_bytes(F, H, W, totalCells, totalBytes))
    {
        cerr << "Invalid or too large tensor dimensions.\n";
        return 1;
    }

    Variant all_variants[] = {
        { "static",      SHM_KEY_STATIC_COPY,       run_static_scheduler,       -1, 0, false },
        { "dynamic",     SHM_KEY_DYNAMIC_COPY,      run_dynamic_scheduler,      -1, 0, false },
        { "chunk",       SHM_KEY_CHUNK_COPY,        run_chunk_scheduler,        -1, 0, false },
        { "chunk_steal", SHM_KEY_CHUNKSTEAL_COPY,   run_chunk_steal_scheduler,  -1, 0, false },
        { "guided",      SHM_KEY_GUIDED_COPY,       run_guided_scheduler,       -1, 0, false },
        { "adaptive",    SHM_KEY_ADAPTIVE_COPY,     run_adaptive_scheduler,     -1, 0, false },
        { "aimd",        SHM_KEY_AIMD_COPY,         run_aimd_scheduler,         -1, 0, false },
    };
    const int ALL_N = sizeof(all_variants) / sizeof(all_variants[0]);

    Variant variants[ALL_N];
    int N = 0;
    if (only_name)
    {
        for (int i = 0; i < ALL_N; ++i)
        {
            if (strcmp(all_variants[i].name, only_name) == 0)
            {
                variants[N++] = all_variants[i];
                break;
            }
        }
        if (N == 0)
        {
            cerr << "Unknown scheduler name: " << only_name << "\n";
            cerr << "Valid names: static, dynamic, chunk, chunk_steal, guided, adaptive, aimd\n";
            return 1;
        }
    }
    else
    {
        for (int i = 0; i < ALL_N; ++i)
            variants[N++] = all_variants[i];
    }

    int originalId = get_existing_shmid(SHM_KEY_ORIGINAL);
    if (originalId == -1)
    {
        cerr << "Run ./shm_generator F H W first.\n";
        return 1;
    }

    // Create every copy segment up front, seed each by attach+memcpy+detach,
    // so the original is only co-mapped briefly.
    for (int i = 0; i < N; ++i)
    {
        variants[i].shmid = create_copy_segment(variants[i].key, totalBytes);
        if (variants[i].shmid == -1)
        {
            for (int j = 0; j < i; ++j)
                if (variants[j].shmid != -1)
                    shmctl(variants[j].shmid, IPC_RMID, nullptr);
            return 1;
        }

        uint8_t* original = attach_shm(originalId);
        uint8_t* dst      = attach_shm(variants[i].shmid);
        if (!original || !dst)
        {
            if (original) shmdt(original);
            if (dst)      shmdt(dst);
            return 1;
        }
        memcpy(dst, original, totalBytes);
        shmdt(dst);
        shmdt(original);
    }

    cout << "Benchmark input prepared.\n";
    cout << "F=" << F << " H=" << H << " W=" << W
         << " K=" << K << " affinity=" << (affinity ? "on" : "off") << "\n\n";

    // Spin up the CPU-utilization monitor ONCE here, outside any timing
    // window. AIMD will read it instead of creating its own — that way
    // the ~1s start/stop latency of the monitor's background sampling
    // thread is not charged to the AIMD wall-clock measurement.
    UtilizationMonitor* shared_monitor = new UtilizationMonitor(1, 0.7, 0.8);
    shared_monitor->start();

    for (int i = 0; i < N; ++i)
    {
        uint8_t* buf = attach_shm(variants[i].shmid);
        if (!buf)
        {
            cerr << variants[i].name << " attach failed.\n";
            variants[i].ms = -1;
            variants[i].sorted_ok = false;
            continue;
        }

        MatrixSortContext ctx;
        ctx.A = buf;
        ctx.F = F;
        ctx.H = H;
        ctx.W = W;

        auto t1 = chrono::high_resolution_clock::now();
        int rc;
        if (strcmp(variants[i].name, "aimd") == 0) {
            // Reuse the already-started monitor instead of letting AIMD
            // build one itself.
            rc = run_aimd_scheduler(F, K, sort_frame_task, &ctx, affinity,
                                    shared_monitor);
        } else {
            rc = variants[i].runner(F, K, sort_frame_task, &ctx, affinity);
        }
        auto t2 = chrono::high_resolution_clock::now();

        if (rc != 0)
        {
            cerr << variants[i].name << " scheduler failed.\n";
            variants[i].ms = -1;
        }
        else
        {
            variants[i].ms = chrono::duration_cast<chrono::milliseconds>(t2 - t1).count();
        }
        variants[i].sorted_ok = is_sorted_row_major(buf, F, H, W);

        shmdt(buf);

        cout << left << setw(16) << variants[i].name
             << "  " << right << setw(8) << variants[i].ms << " ms"
             << "  sorted=" << (variants[i].sorted_ok ? "yes" : "NO")
             << "\n";
    }

    int best = -1;
    for (int i = 0; i < N; ++i)
    {
        if (!variants[i].sorted_ok || variants[i].ms < 0) continue;
        if (best == -1 || variants[i].ms < variants[best].ms) best = i;
    }
    if (best != -1)
    {
        cout << "\nFastest: " << variants[best].name
             << " (" << variants[best].ms << " ms)\n";
    }

    for (int i = 0; i < N; ++i)
        if (variants[i].shmid != -1)
            shmctl(variants[i].shmid, IPC_RMID, nullptr);

    // Tear down the shared monitor. This is the slow stop (up to one
    // sampling interval), but it happens AFTER all timing is done.
    shared_monitor->stop();
    delete shared_monitor;

    return 0;
}
