// Unified benchmark — runs every scheduler in the variant table (12 entries)
// against the same input.
//
// Process schedulers   : prolific, collective (log_collective is not in the table)
// Thread schedulers    : static, dynamic, chunk, chunk_steal, guided, adaptive, aimd
// Pipe schedulers      : one_to_one, one_to_many, many_to_many
//
// Each scheduler gets its own SHM copy of the original tensor produced by
// shm_generator, sorts every row of every frame in place via quick_sort,
// and reports elapsed milliseconds. At the end we print the fastest one
// that also passed the sortedness check.

#include <iostream>
#include <iomanip>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <cerrno>
#include <limits>
#include <chrono>

#include <unistd.h>
#include <sys/ipc.h>
#include <sys/shm.h>

#include "shm_keys.hpp"
#include "row.hpp"
#include "sort_algorithms.hpp"

// Thread schedulers (task_fn_t)
#include "scheduler_iface.hpp"
#include "static_scheduler.hpp"
#include "dynamic_scheduler.hpp"
#include "chunk_scheduler.hpp"
#include "chunk_steal_scheduler.hpp"
#include "guided_scheduler.hpp"
#include "adaptive_scheduler.hpp"
#include "aimd.hpp"

// Process schedulers
#include "bounded_prolific_scheduler.hpp"        // scheduler_task_fn
#include "bounded_collective_scheduler.hpp"      // collective_task_fn

// Pipe schedulers (use scheduler_task_fn from bounded_prolific_scheduler.hpp)
#include "one_to_one_pipe_scheduler.hpp"
#include "one_to_many_scheduler.hpp"
#include "many_to_many_scheduler.hpp"

using namespace std;

// Context handed to the task callback through its `void* ctx` argument.
// Describes one F x H x W tensor of bytes stored contiguously, with cell
// (f, i, j) at offset (f * H + i) * W + j.
struct MatrixSortContext
{
    uint8_t* A;   // base address of the tensor buffer (one scheduler's SHM copy)
    int F;        // number of frames; carried for completeness, not read by sort_frame_task
    int H;        // rows per frame
    int W;        // cells (bytes) per row
};

// Returns a reference to cell (f, i, j) of a contiguous tensor whose frames
// hold H rows of W bytes each. The index is widened to size_t before any
// multiplication so large tensors do not overflow int arithmetic.
static inline uint8_t& CELL(uint8_t* A, int f, int i, int j, int H, int W)
{
    const size_t rowIndex = static_cast<size_t>(f) * H + i;
    const size_t cellIndex = rowIndex * W + j;
    return A[cellIndex];
}

// Sort every row of frame `frameId` in place. The signature matches every
// scheduler's task callback type — task_fn_t, scheduler_task_fn,
// collective_task_fn — because they're all identical function-pointer types.
static void sort_frame_task(int frameId, int workerId, int workerCount, void* ctxPtr)
{
    (void)workerId;
    (void)workerCount;

    MatrixSortContext* ctx = static_cast<MatrixSortContext*>(ctxPtr);
    int H = ctx->H;
    int W = ctx->W;
    uint8_t* A = ctx->A;

    for (int i = 0; i < H; ++i)
    {
        uint8_t* rowPtr = &CELL(A, frameId, i, 0, H, W);
        row<uint8_t> r(rowPtr, static_cast<uint32_t>(W), false);
        quick_sort(r);
    }
}

// Computes F * H * W with overflow detection.
//
// On success stores the cell count in `totalCells` and the byte count in
// `totalBytes` (one byte per cell) and returns true. Returns false, leaving
// both outputs untouched, when any dimension is zero or the product does not
// fit in size_t.
static bool checked_total_bytes(size_t F, size_t H, size_t W,
                                size_t& totalCells, size_t& totalBytes)
{
    const size_t limit = std::numeric_limits<size_t>::max();

    const bool anyZero = (F == 0) || (H == 0) || (W == 0);
    if (anyZero)
        return false;

    // F * H overflows exactly when F exceeds limit / H.
    if (limit / H < F)
        return false;
    const size_t rows = F * H;

    // Same test for (F * H) * W.
    if (limit / W < rows)
        return false;
    const size_t cells = rows * W;

    totalCells = cells;
    totalBytes = cells * sizeof(uint8_t);
    return true;
}

// Looks up an already-existing System V shared memory segment by key.
//
// Returns the segment id, or -1 after printing a diagnostic to stderr when
// the key cannot be opened. The requested size of 1 only demands that the
// segment be non-empty; it does not validate the real segment size.
static int get_existing_shmid(key_t key)
{
    const int shmid = shmget(key, 1, 0666);
    if (shmid == -1)
    {
        // Capture errno before any stream I/O: the insertions below run
        // first and are allowed to modify errno even when they succeed.
        const int savedErrno = errno;
        std::cerr << "Could not find SHM key " << key << ": "
                  << std::strerror(savedErrno) << "\n";
    }
    return shmid;
}

// Attaches the segment `shmid` read/write at an address chosen by the kernel.
//
// Returns the mapped base address as a byte pointer, or nullptr after
// printing a diagnostic to stderr when shmat fails.
static uint8_t* attach_shm(int shmid)
{
    void* base = shmat(shmid, nullptr, 0);
    if (base == reinterpret_cast<void*>(-1))   // shmat signals failure with (void*)-1
    {
        // Capture errno before any stream I/O so the reported reason is the
        // one shmat set, not whatever the stream writes left behind.
        const int savedErrno = errno;
        std::cerr << "shmat failed: " << std::strerror(savedErrno) << "\n";
        return nullptr;
    }
    return static_cast<uint8_t*>(base);
}

// Best-effort removal of the segment registered under `key`. Does nothing
// when the key cannot be opened; the result of the removal itself is
// deliberately not checked.
static void remove_shm_if_exists(key_t key)
{
    const int existingId = shmget(key, 1, 0666);
    if (existingId == -1)
        return;

    shmctl(existingId, IPC_RMID, nullptr);
}

// Creates a fresh segment of `totalBytes` bytes under `key`.
//
// Any segment already registered under the key is removed first, then the
// new one is created with IPC_EXCL so a stale segment is never reused.
// Returns the new segment id, or -1 after printing a diagnostic to stderr.
static int create_copy_segment(key_t key, size_t totalBytes)
{
    remove_shm_if_exists(key);

    const int shmid = shmget(key, totalBytes, IPC_CREAT | IPC_EXCL | 0666);
    if (shmid == -1)
    {
        // Capture errno before any stream I/O: the insertions below run
        // first and are allowed to modify errno even when they succeed.
        const int savedErrno = errno;
        std::cerr << "Could not create copy segment " << key << ": "
                  << std::strerror(savedErrno) << "\n";
    }
    return shmid;
}

// Returns true when every row of every frame is in non-decreasing order.
// The tensor is F frames of H rows of W bytes, stored contiguously; rows are
// checked independently, so ordering across row boundaries is irrelevant.
static bool is_sorted_row_major(const uint8_t* A, int F, int H, int W)
{
    for (int frame = 0; frame < F; ++frame)
    {
        for (int r = 0; r < H; ++r)
        {
            const size_t rowOffset = (static_cast<size_t>(frame) * H + r) * W;
            const uint8_t* cells = &A[rowOffset];

            // Compare each adjacent pair (col, col + 1) within the row.
            for (int col = 0; col + 1 < W; ++col)
            {
                if (cells[col] > cells[col + 1])
                    return false;
            }
        }
    }
    return true;
}

// Scheduler entry points come in three families. They take the same
// arguments and differ only in the name of the task-callback type they
// declare. One alias per family, plus a tag enum, lets all of them sit in a
// single table while each call still goes through a correctly typed pointer.
using thread_runner_t = int (*)(int F, int K, task_fn_t task, void* ctx, bool affinity);
using proc_runner_t   = int (*)(int F, int K, scheduler_task_fn task, void* ctx, bool affinity);
using coll_runner_t   = int (*)(int F, int K, collective_task_fn task, void* ctx, bool affinity);

// Tags which of the three pointer types a table entry's runner really is.
enum RunnerKind
{
    KIND_THREAD,
    KIND_PROC,
    KIND_COLL
};

// Adapter that gives the AIMD scheduler the common thread_runner_t shape.
// run_aimd_scheduler takes one additional optional UtilizationMonitor*
// argument; supplying nullptr keeps AIMD's internal monitor behavior. Going
// through this adapter avoids the undefined behavior of invoking a
// 6-argument function through a 5-argument function pointer.
static int run_aimd_scheduler_wrapper(int F, int K, task_fn_t task, void* ctx, bool affinity)
{
    const int status = run_aimd_scheduler(F, K, task, ctx, affinity, nullptr);
    return status;
}

// One row of the benchmark table: a scheduler to run, the shared memory
// copy it works on, and the results recorded for it. Instances are
// aggregate-initialised in main, so the field order here is significant.
struct Variant
{
    const char*  name;       // label printed in the results
    key_t        key;        // SHM key of this scheduler's private copy of the input
    RunnerKind   kind;       // selects which runner pointer type `runner` is cast back to
    void*        runner;   // cast to the matching pointer based on kind
    int          shmid;      // id of the copy segment; -1 until it has been created
    uint8_t*     buf;        // attached address of the copy segment; nullptr until attached
    long long    ms;         // elapsed milliseconds of the run; -1 when the scheduler failed
    bool         sorted_ok;  // result of the sortedness check on buf after the run
};

// Invokes the scheduler stored in `v` with sort_frame_task as the per-frame
// task. The type-erased runner pointer is cast back to the function-pointer
// type named by v.kind before the call. Returns the scheduler's own return
// value, or -1 when v.kind is not one of the known kinds.
static int run_variant(Variant& v, int F, int K, void* ctx, bool affinity)
{
    if (v.kind == KIND_THREAD)
    {
        const thread_runner_t run = reinterpret_cast<thread_runner_t>(v.runner);
        return run(F, K, sort_frame_task, ctx, affinity);
    }

    if (v.kind == KIND_PROC)
    {
        const proc_runner_t run = reinterpret_cast<proc_runner_t>(v.runner);
        return run(F, K, sort_frame_task, ctx, affinity);
    }

    if (v.kind == KIND_COLL)
    {
        const coll_runner_t run = reinterpret_cast<coll_runner_t>(v.runner);
        return run(F, K, sort_frame_task, ctx, affinity);
    }

    return -1;
}

int main(int argc, char** argv)
{
    if (argc < 5 || argc > 6)
    {
        cout << "Usage: ./shm_benchmark F H W K [affinity: 0|1]\n";
        cout << "Example: ./shm_benchmark 1000 640 640 24 1\n";
        return 1;
    }

    int F = atoi(argv[1]);
    int H = atoi(argv[2]);
    int W = atoi(argv[3]);
    int K = atoi(argv[4]);
    bool affinity = (argc == 6) ? (atoi(argv[5]) != 0) : true;

    if (F <= 0 || H <= 0 || W <= 0 || K <= 0)
    {
        cerr << "Invalid arguments.\n";
        return 1;
    }

    size_t totalCells = 0, totalBytes = 0;
    if (!checked_total_bytes(F, H, W, totalCells, totalBytes))
    {
        cerr << "Invalid or too large tensor dimensions.\n";
        return 1;
    }

    int originalId = get_existing_shmid(SHM_KEY_ORIGINAL);
    if (originalId == -1)
    {
        cerr << "Run ./shm_generator F H W first.\n";
        return 1;
    }
    uint8_t* original = attach_shm(originalId);
    if (!original) return 1;

    // Order chosen to match the screenshot: pipes, processes, threads.
    Variant variants[] = {
        // pipes
        { "one_to_one",     SHM_KEY_ONETOONE_COPY,        KIND_PROC, (void*)run_one_to_one_pipe_scheduler,        -1, nullptr, 0, false },
        { "one_to_many",    SHM_KEY_ONETOMANY_COPY,       KIND_PROC, (void*)run_one_to_many_pipe_scheduler,       -1, nullptr, 0, false },
        { "many_to_many",   SHM_KEY_MANYTOMANY_COPY,      KIND_PROC, (void*)run_many_to_many_pipe_scheduler,      -1, nullptr, 0, false },

        // processes
        { "prolific",       SHM_KEY_PROLIFIC_COPY,        KIND_PROC, (void*)run_bounded_prolific_scheduler,       -1, nullptr, 0, false },
        { "collective",     SHM_KEY_COLLECTIVE_COPY,      KIND_COLL, (void*)run_bounded_collective_scheduler,     -1, nullptr, 0, false },

        // threads
        { "static",         SHM_KEY_STATIC_COPY,          KIND_THREAD, (void*)run_static_scheduler,               -1, nullptr, 0, false },
        { "dynamic",        SHM_KEY_DYNAMIC_COPY,         KIND_THREAD, (void*)run_dynamic_scheduler,              -1, nullptr, 0, false },
        { "chunk",          SHM_KEY_CHUNK_COPY,           KIND_THREAD, (void*)run_chunk_scheduler,                -1, nullptr, 0, false },
        { "chunk_steal",    SHM_KEY_CHUNKSTEAL_COPY,      KIND_THREAD, (void*)run_chunk_steal_scheduler,          -1, nullptr, 0, false },
        { "guided",         SHM_KEY_GUIDED_COPY,          KIND_THREAD, (void*)run_guided_scheduler,               -1, nullptr, 0, false },
        { "adaptive",       SHM_KEY_ADAPTIVE_COPY,        KIND_THREAD, (void*)run_adaptive_scheduler,             -1, nullptr, 0, false },
        { "aimd",           SHM_KEY_AIMD_COPY,            KIND_THREAD, (void*)run_aimd_scheduler_wrapper,         -1, nullptr, 0, false },
    };
    const int N = sizeof(variants) / sizeof(variants[0]);

    // Allocate one fresh copy segment per scheduler, then memcpy the
    // original into each so every scheduler sees identical input.
    for (int i = 0; i < N; ++i)
    {
        variants[i].shmid = create_copy_segment(variants[i].key, totalBytes);
        if (variants[i].shmid == -1)
        {
            shmdt(original);
            for (int j = 0; j < i; ++j)
            {
                if (variants[j].buf) shmdt(variants[j].buf);
                shmctl(variants[j].shmid, IPC_RMID, nullptr);
            }
            return 1;
        }
        variants[i].buf = attach_shm(variants[i].shmid);
        if (!variants[i].buf)
        {
            shmdt(original);
            return 1;
        }
        memcpy(variants[i].buf, original, totalBytes);
    }

    cout << "Benchmark input prepared.\n";
    cout << "F=" << F << " H=" << H << " W=" << W
         << " K=" << K << " affinity=" << (affinity ? "on" : "off") << "\n\n";

    for (int i = 0; i < N; ++i)
    {
        MatrixSortContext ctx;
        ctx.A = variants[i].buf;
        ctx.F = F;
        ctx.H = H;
        ctx.W = W;

        auto t1 = chrono::high_resolution_clock::now();
        int rc = run_variant(variants[i], F, K, &ctx, affinity);
        auto t2 = chrono::high_resolution_clock::now();

        if (rc != 0)
        {
            cerr << variants[i].name << " scheduler failed.\n";
            variants[i].ms = -1;
        }
        else
        {
            variants[i].ms = chrono::duration_cast<chrono::milliseconds>(t2 - t1).count();
        }
        variants[i].sorted_ok = is_sorted_row_major(variants[i].buf, F, H, W);

        cout << left << setw(16) << variants[i].name
             << "  " << right << setw(8) << variants[i].ms << " ms"
             << "  sorted=" << (variants[i].sorted_ok ? "yes" : "NO")
             << "\n";
    }

    // Pick the fastest correctly-sorted variant.
    int best = -1;
    for (int i = 0; i < N; ++i)
    {
        if (!variants[i].sorted_ok || variants[i].ms < 0) continue;
        if (best == -1 || variants[i].ms < variants[best].ms) best = i;
    }
    if (best != -1)
    {
        cout << "\nFastest: " << variants[best].name
             << " (" << variants[best].ms << " ms)\n";
    }

    // Tear down all copy segments; leave SHM_KEY_ORIGINAL alone for reuse.
    shmdt(original);
    for (int i = 0; i < N; ++i)
    {
        if (variants[i].buf) shmdt(variants[i].buf);
        if (variants[i].shmid != -1) shmctl(variants[i].shmid, IPC_RMID, nullptr);
    }
    return 0;
}
