#include <iostream>
#include <iomanip>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <cerrno>
#include <limits>
#include <chrono>

#include <unistd.h>
#include <sys/ipc.h>
#include <sys/shm.h>

#include "shm_keys.hpp"
#include "row.hpp"
#include "sort_algorithms.hpp"
#include "scheduler_iface.hpp"
#include "aimd.hpp"
#include "aimd_fixed.hpp"

// Fallback shared-memory key for the "aimd_fixed" working copy.
// Keeps this file usable even if shm_keys.hpp has not been edited to define
// SHM_KEY_AIMD_FIXED_COPY: in that case the key is derived from
// SHM_KEY_AIMD_COPY by adding 1000.
// NOTE(review): assumes SHM_KEY_AIMD_COPY + 1000 does not collide with any
// other key in shm_keys.hpp — confirm.
#ifndef SHM_KEY_AIMD_FIXED_COPY
#define SHM_KEY_AIMD_FIXED_COPY (SHM_KEY_AIMD_COPY + 1000)
#endif

using namespace std;

// Context handed to sort_frame_task through its opaque void* argument.
// Describes a tensor of F frames, each with H rows of W bytes, stored
// contiguously (frame-major, then row, then column).
//
// Every member has a default so that a default-constructed context is in a
// well-defined "empty" state instead of holding indeterminate values.
struct MatrixSortContext {
    uint8_t* A = nullptr;  // base address of the tensor data
    int F = 0;             // number of frames
    int H = 0;             // rows per frame
    int W = 0;             // bytes (columns) per row
};

// Returns a reference to element (f, i, j) of a tensor stored contiguously
// with H rows and W columns per frame. The index is computed in size_t so
// that large tensors do not overflow int arithmetic.
static inline uint8_t& CELL(uint8_t* A, int f, int i, int j, int H, int W) {
    const size_t rowOrdinal = static_cast<size_t>(f) * H + i;  // global row number
    const size_t flatIndex = rowOrdinal * W + j;               // byte offset from A
    return A[flatIndex];
}

// Per-frame task: sort each of the H rows of frame frameId in place.
static void sort_frame_task(int frameId, int workerId, int workerCount, void* ctxPtr) {
    (void)workerId;
    (void)workerCount;

    MatrixSortContext* ctx = static_cast<MatrixSortContext*>(ctxPtr);
    int H = ctx->H;
    int W = ctx->W;
    uint8_t* A = ctx->A;

    for (int i = 0; i < H; ++i) {
        uint8_t* rowPtr = &CELL(A, frameId, i, 0, H, W);
        row<uint8_t> r(rowPtr, static_cast<uint32_t>(W), false);
        quick_sort(r);
    }
}

// Computes F*H*W with overflow detection.
// On success stores the element count in totalCells and the byte count in
// totalBytes, then returns true. Returns false, leaving both outputs
// untouched, when any dimension is zero or the product does not fit in size_t.
static bool checked_total_bytes(size_t F, size_t H, size_t W,
                                size_t& totalCells, size_t& totalBytes) {
    const size_t limit = std::numeric_limits<size_t>::max();

    // Short-circuiting guarantees H is non-zero before it is used as a divisor.
    if (F == 0 || H == 0 || W == 0 || F > limit / H) {
        return false;
    }

    const size_t rowCount = F * H;
    if (rowCount > limit / W) {
        return false;
    }

    const size_t cellCount = rowCount * W;
    totalCells = cellCount;
    totalBytes = cellCount * sizeof(uint8_t);
    return true;
}

// Looks up an already-existing System V shared memory segment by key.
// No IPC_CREAT flag is passed, so a missing segment is an error rather than
// being created. Returns the segment id, or -1 after printing a diagnostic
// to stderr.
static int get_existing_shmid(key_t key) {
    const int shmid = shmget(key, 1, 0666);
    if (shmid == -1) {
        // Capture errno before touching the stream: the insertions that
        // precede strerror() may themselves overwrite errno.
        const int savedErrno = errno;
        std::cerr << "Could not find shared memory key " << key << ": "
                  << std::strerror(savedErrno) << "\n";
    }
    return shmid;
}

// Attaches the shared memory segment `shmid` at a system-chosen address with
// default (read/write) flags. Returns the mapped base address as a byte
// pointer, or nullptr after printing a diagnostic to stderr.
static uint8_t* attach_shm(int shmid) {
    void* base = shmat(shmid, nullptr, 0);
    // shmat signals failure with (void*)-1, not with a null pointer.
    if (base == reinterpret_cast<void*>(-1)) {
        // Capture errno before streaming so the output cannot overwrite it.
        const int savedErrno = errno;
        std::cerr << "shmat failed: " << std::strerror(savedErrno) << "\n";
        return nullptr;
    }
    return static_cast<uint8_t*>(base);
}

// Best-effort cleanup: if a segment with this key already exists, mark it for
// removal. A missing segment is not an error, and the result of the removal
// request itself is deliberately ignored.
static void remove_shm_if_exists(key_t key) {
    const int existingId = shmget(key, 1, 0666);
    if (existingId == -1) {
        return;  // nothing to remove
    }
    shmctl(existingId, IPC_RMID, nullptr);
}

// Creates a fresh shared memory segment of totalBytes bytes under `key`.
// Any segment already registered under the key is removed first, and
// IPC_EXCL makes the creation fail rather than silently reuse a segment that
// survived the removal. Returns the new segment id, or -1 after printing a
// diagnostic to stderr.
static int create_copy_segment(key_t key, size_t totalBytes) {
    remove_shm_if_exists(key);
    const int shmid = shmget(key, totalBytes, IPC_CREAT | IPC_EXCL | 0666);
    if (shmid == -1) {
        // Capture errno before streaming so the output cannot overwrite it.
        const int savedErrno = errno;
        std::cerr << "Could not create copy segment " << key << ": "
                  << std::strerror(savedErrno) << "\n";
    }
    return shmid;
}

// Verifies that every row of an F x H x W byte tensor is in non-decreasing
// order. Rows are checked independently; no ordering is required across row
// or frame boundaries. Returns true when there is nothing to compare
// (no frames, no rows, or rows of at most one byte).
static bool is_sorted_row_major(const uint8_t* A, int F, int H, int W) {
    // True when no byte in the W-byte row is smaller than its predecessor.
    const auto rowAscending = [W](const uint8_t* rowStart) {
        for (int col = 1; col < W; ++col) {
            if (rowStart[col] < rowStart[col - 1]) {
                return false;
            }
        }
        return true;
    };

    for (int frame = 0; frame < F; ++frame) {
        const size_t firstRow = static_cast<size_t>(frame) * H;  // global index of the frame's row 0
        for (int line = 0; line < H; ++line) {
            if (!rowAscending(A + (firstRow + line) * W)) {
                return false;
            }
        }
    }
    return true;
}

// One scheduler implementation under comparison, together with the private
// shared-memory copy it sorts and the results recorded for it.
struct Variant {
    // Scheduler entry point; main invokes it as
    // run(F, K, sort_frame_task, &ctx, affinity) and treats 0 as success.
    using run_fn = int (*)(int, int, task_fn_t, void*, bool);

    const char* name;   // label printed in the results line
    key_t key;          // System V key of this variant's copy segment
    run_fn run;         // scheduler under test
    int shmid;          // id of the copy segment; -1 until created
    uint8_t* buf;       // attached address of the copy; nullptr until attached
    long long ms;       // elapsed milliseconds; -1 when the scheduler failed
    bool sorted_ok;     // whether every row was sorted after the run
};

int main(int argc, char** argv) {
    if (argc < 5 || argc > 6) {
        cout << "Usage: ./shm_benchmark F H W K [affinity: 0|1]\n";
        cout << "Example: ./shm_benchmark 1000 640 640 4 1\n";
        return 1;
    }

    int F = atoi(argv[1]);
    int H = atoi(argv[2]);
    int W = atoi(argv[3]);
    int K = atoi(argv[4]);
    bool affinity = (argc == 6) ? (atoi(argv[5]) != 0) : true;

    if (F <= 0 || H <= 0 || W <= 0 || K <= 0) {
        cerr << "Invalid arguments.\n";
        return 1;
    }

    size_t totalCells = 0, totalBytes = 0;
    if (!checked_total_bytes(F, H, W, totalCells, totalBytes)) {
        cerr << "Invalid or too large tensor dimensions.\n";
        return 1;
    }

    int originalId = get_existing_shmid(SHM_KEY_ORIGINAL);
    if (originalId == -1) {
        cerr << "Run ./shm_generator F H W first.\n";
        return 1;
    }

    uint8_t* original = attach_shm(originalId);
    if (!original) return 1;

    Variant variants[] = {
        { "aimd_old",   SHM_KEY_AIMD_COPY,       run_aimd_scheduler,       -1, nullptr, 0, false },
        { "aimd_fixed", SHM_KEY_AIMD_FIXED_COPY, run_aimd_fixed_scheduler, -1, nullptr, 0, false }
    };
    const int N = sizeof(variants) / sizeof(variants[0]);

    for (int i = 0; i < N; ++i) {
        variants[i].shmid = create_copy_segment(variants[i].key, totalBytes);
        if (variants[i].shmid == -1) {
            shmdt(original);
            return 1;
        }

        variants[i].buf = attach_shm(variants[i].shmid);
        if (!variants[i].buf) {
            shmdt(original);
            return 1;
        }

        memcpy(variants[i].buf, original, totalBytes);
    }

    cout << "AIMD comparison input prepared.\n";
    cout << "F=" << F << " H=" << H << " W=" << W
         << " K=" << K << " affinity=" << (affinity ? "on" : "off") << "\n\n";

    for (int i = 0; i < N; ++i) {
        MatrixSortContext ctx;
        ctx.A = variants[i].buf;
        ctx.F = F;
        ctx.H = H;
        ctx.W = W;

        auto t1 = chrono::high_resolution_clock::now();
        int rc = variants[i].run(F, K, sort_frame_task, &ctx, affinity);
        auto t2 = chrono::high_resolution_clock::now();

        if (rc != 0) {
            cerr << variants[i].name << " scheduler failed.\n";
            variants[i].ms = -1;
        } else {
            variants[i].ms = chrono::duration_cast<chrono::milliseconds>(t2 - t1).count();
        }

        variants[i].sorted_ok = is_sorted_row_major(variants[i].buf, F, H, W);

        cout << left << setw(14) << variants[i].name
             << "  " << right << setw(8) << variants[i].ms << " ms"
             << "  sorted=" << (variants[i].sorted_ok ? "yes" : "NO")
             << "\n";
    }

    shmdt(original);
    for (int i = 0; i < N; ++i) {
        if (variants[i].buf) shmdt(variants[i].buf);
        if (variants[i].shmid != -1) shmctl(variants[i].shmid, IPC_RMID, nullptr);
    }

    return 0;
}
