#include <iostream>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <cerrno>
#include <limits>
#include <chrono>

#include <unistd.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sched.h>

#include "shm_keys.hpp"
#include "row.hpp"
#include "sort_algorithms.hpp"

#include "one_to_one_pipe_scheduler.hpp"
#include "one_to_many_scheduler.hpp"
#include "many_to_many_scheduler.hpp"

using namespace std;

/// Context handed to scheduler tasks through an opaque `void*`.
///
/// `A` is the base of a byte tensor addressed as `(f * H + i) * W + j`
/// (see CELL). Every member has a default so that a default-constructed
/// context never carries indeterminate values.
struct MatrixSortContext
{
    uint8_t* A = nullptr; ///< Base address of the tensor data.
    int F = 0;            ///< Number of frames.
    int H = 0;            ///< Rows per frame.
    int W = 0;            ///< Bytes (cells) per row.
};

/// Returns a reference to cell (f, i, j) of a tensor stored contiguously with
/// H rows per frame and W cells per row.
///
/// The offset is computed in `size_t` so that large tensors do not overflow
/// `int` arithmetic.
static inline uint8_t& CELL(uint8_t* A, int f, int i, int j, int H, int W)
{
    const size_t rowIndex = static_cast<size_t>(f) * H + i; // row number across all frames
    const size_t offset = rowIndex * W + j;                 // position inside the flat buffer
    return A[offset];
}

/// Computes F * H * W with overflow detection.
///
/// @param F,H,W       Tensor dimensions; each must be non-zero.
/// @param totalCells  Receives F * H * W on success.
/// @param totalBytes  Receives the byte size of that many `uint8_t` cells on success.
/// @return `true` on success; `false` if any dimension is zero or the product
///         does not fit in `size_t`. The outputs are left untouched on failure.
static bool checked_total_bytes(size_t F, size_t H, size_t W, size_t& totalCells, size_t& totalBytes)
{
    const bool anyZero = (F == 0) || (H == 0) || (W == 0);
    if (anyZero)
    {
        return false;
    }

    const size_t limit = std::numeric_limits<size_t>::max();

    // F * H must be representable before it is formed.
    if (F > limit / H)
    {
        return false;
    }
    const size_t rowsTotal = F * H;

    // Likewise for (F * H) * W.
    if (rowsTotal > limit / W)
    {
        return false;
    }

    const size_t cells = rowsTotal * W;
    totalCells = cells;
    totalBytes = cells * sizeof(uint8_t);
    return true;
}

/// Looks up an already existing System V shared memory segment.
///
/// @param key  IPC key of the segment.
/// @return The segment identifier, or -1 if `shmget` failed (a diagnostic is
///         written to `cerr` in that case).
static int get_existing_shmid(key_t key)
{
    const int shmid = shmget(key, 1, 0666);

    if (shmid == -1)
    {
        // Save errno before touching the stream: the insertions that precede
        // strerror() in the chain may themselves modify errno.
        const int savedErrno = errno;

        std::cerr << "Could not find shared memory key " << key << ": "
                  << std::strerror(savedErrno) << "\n";
    }

    return shmid;
}

/// Attaches the segment `shmid` at an address chosen by the system.
///
/// @param shmid  Identifier of the segment to attach.
/// @return Pointer to the mapped bytes, or `nullptr` if `shmat` failed (a
///         diagnostic is written to `cerr` in that case).
static uint8_t* attach_shm(int shmid)
{
    void* const base = shmat(shmid, nullptr, 0);

    // shmat signals failure with (void*)-1 rather than a null pointer.
    if (base == reinterpret_cast<void*>(-1))
    {
        // Save errno before any stream output can overwrite it.
        const int savedErrno = errno;

        std::cerr << "shmat failed: " << std::strerror(savedErrno) << "\n";
        return nullptr;
    }

    return static_cast<uint8_t*>(base);
}

/// Best-effort removal of the segment registered under `key`.
///
/// Does nothing when no segment can be found; the result of the removal
/// request itself is not inspected.
static void remove_shm_if_exists(key_t key)
{
    const int existing = shmget(key, 1, 0666);

    if (existing == -1)
    {
        return; // nothing registered under this key (or it is not accessible)
    }

    shmctl(existing, IPC_RMID, nullptr);
}

/// Creates a fresh segment of `totalBytes` bytes under `key`.
///
/// Any segment currently registered under `key` is removed first, then the
/// new one is requested with IPC_CREAT | IPC_EXCL and mode 0666.
///
/// @param key         IPC key for the new segment.
/// @param totalBytes  Requested segment size in bytes.
/// @return The new segment identifier, or -1 on failure (a diagnostic is
///         written to `cerr` in that case).
static int create_copy_segment(key_t key, size_t totalBytes)
{
    remove_shm_if_exists(key);

    const int shmid = shmget(key, totalBytes, IPC_CREAT | IPC_EXCL | 0666);

    if (shmid == -1)
    {
        // Save errno before any stream output can overwrite it.
        const int savedErrno = errno;

        std::cerr << "Could not create copy segment " << key << ": "
                  << std::strerror(savedErrno) << "\n";
        return -1;
    }

    return shmid;
}

/// Reports whether the first `totalBytes` bytes of `A` and `B` are identical.
static bool same_bytes(const uint8_t* A, const uint8_t* B, size_t totalBytes)
{
    const int difference = std::memcmp(A, B, totalBytes);
    return !difference;
}

/// Checks that every row of an F x H x W byte tensor is in non-decreasing order.
///
/// Rows are examined independently; nothing is required of the ordering
/// between different rows or frames.
///
/// @return `false` as soon as an adjacent pair inside a row is out of order,
///         `true` otherwise (including when there is nothing to compare).
static bool is_sorted_row_major(const uint8_t* A, int F, int H, int W)
{
    for (int frame = 0; frame < F; ++frame)
    {
        for (int r = 0; r < H; ++r)
        {
            const size_t rowNumber = static_cast<size_t>(frame) * H + r;
            const uint8_t* cells = A + rowNumber * W;

            // Compare each cell with its right-hand neighbour.
            for (int left = 0; left + 1 < W; ++left)
            {
                const bool outOfOrder = cells[left] > cells[left + 1];
                if (outOfOrder)
                {
                    return false;
                }
            }
        }
    }

    return true;
}

void sort_frame_task(int frameId, int workerId, int workerCount, void* ctxPtr)
{
    (void)workerId;
    (void)workerCount;

    MatrixSortContext* ctx = static_cast<MatrixSortContext*>(ctxPtr);

    uint8_t* A = ctx->A;
    int H = ctx->H;
    int W = ctx->W;

    for (int i = 0; i < H; ++i)
    {
        uint8_t* rowPtr = &CELL(A, frameId, i, 0, H, W);

        row<uint8_t> currentRow(rowPtr, static_cast<uint32_t>(W), false);
        quick_sort(currentRow);
    }
}

/// Runs the one-to-one pipe scheduler over the tensor at `A` and times it.
///
/// @param A         Base of the F x H x W byte tensor to sort in place.
/// @param F,H,W     Tensor dimensions.
/// @param K         Forwarded to the scheduler as its second argument.
/// @param affinity  Forwarded to the scheduler.
/// @return Elapsed wall time in milliseconds, or -1 if the scheduler returned
///         a non-zero status.
static long long run_onetoone(uint8_t* A, int F, int H, int W, int K, bool affinity)
{
    MatrixSortContext ctx;
    ctx.A = A;
    ctx.F = F;
    ctx.H = H;
    ctx.W = W;

    // steady_clock is monotonic. high_resolution_clock may alias the system
    // clock, where a backwards adjustment would yield a negative duration
    // that callers would mistake for the -1 failure sentinel.
    const auto started = std::chrono::steady_clock::now();

    const int rc = run_one_to_one_pipe_scheduler(F, K, sort_frame_task, &ctx, affinity);

    const auto finished = std::chrono::steady_clock::now();

    if (rc != 0)
        return -1;

    return std::chrono::duration_cast<std::chrono::milliseconds>(finished - started).count();
}

/// Runs the one-to-many pipe scheduler over the tensor at `A` and times it.
///
/// @param A         Base of the F x H x W byte tensor to sort in place.
/// @param F,H,W     Tensor dimensions.
/// @param K         Forwarded to the scheduler as its second argument.
/// @param affinity  Forwarded to the scheduler.
/// @return Elapsed wall time in milliseconds, or -1 if the scheduler returned
///         a non-zero status.
static long long run_onetomany(uint8_t* A, int F, int H, int W, int K, bool affinity)
{
    MatrixSortContext ctx;
    ctx.A = A;
    ctx.F = F;
    ctx.H = H;
    ctx.W = W;

    // steady_clock is monotonic. high_resolution_clock may alias the system
    // clock, where a backwards adjustment would yield a negative duration
    // that callers would mistake for the -1 failure sentinel.
    const auto started = std::chrono::steady_clock::now();

    const int rc = run_one_to_many_pipe_scheduler(F, K, sort_frame_task, &ctx, affinity);

    const auto finished = std::chrono::steady_clock::now();

    if (rc != 0)
        return -1;

    return std::chrono::duration_cast<std::chrono::milliseconds>(finished - started).count();
}

/// Runs the many-to-many pipe scheduler over the tensor at `A` and times it.
///
/// @param A         Base of the F x H x W byte tensor to sort in place.
/// @param F,H,W     Tensor dimensions.
/// @param K         Forwarded to the scheduler as its second argument.
/// @param affinity  Forwarded to the scheduler.
/// @return Elapsed wall time in milliseconds, or -1 if the scheduler returned
///         a non-zero status.
static long long run_manytomany(uint8_t* A, int F, int H, int W, int K, bool affinity)
{
    MatrixSortContext ctx;
    ctx.A = A;
    ctx.F = F;
    ctx.H = H;
    ctx.W = W;

    // steady_clock is monotonic. high_resolution_clock may alias the system
    // clock, where a backwards adjustment would yield a negative duration
    // that callers would mistake for the -1 failure sentinel.
    const auto started = std::chrono::steady_clock::now();

    const int rc = run_many_to_many_pipe_scheduler(F, K, sort_frame_task, &ctx, affinity);

    const auto finished = std::chrono::steady_clock::now();

    if (rc != 0)
        return -1;

    return std::chrono::duration_cast<std::chrono::milliseconds>(finished - started).count();
}

/// Parses `text` as a base-10 integer.
/// Rejects null/empty input, trailing characters and values outside `long`.
static bool parse_long_arg(const char* text, long& out)
{
    if (text == nullptr || *text == '\0')
        return false;

    char* end = nullptr;
    errno = 0;
    const long value = std::strtol(text, &end, 10);

    if (errno != 0 || end == text || *end != '\0')
        return false;

    out = value;
    return true;
}

/// Parses `text` as a strictly positive integer that fits in `int`.
static bool parse_positive_int_arg(const char* text, int& out)
{
    long value = 0;

    if (!parse_long_arg(text, value))
        return false;

    if (value <= 0 || value > std::numeric_limits<int>::max())
        return false;

    out = static_cast<int>(value);
    return true;
}

/// Benchmark driver.
///
/// Usage: ./shm_benchmark F H W K [affinity: 0|1]
///
/// Attaches the tensor published under SHM_KEY_ORIGINAL, copies it into one
/// fresh segment per scheduler, runs the three schedulers on their own copy,
/// then prints timings and whether every row ended up sorted.
///
/// @return 0 when all schedulers ran, 1 on bad arguments or any failure.
///         Copy segments are detached and removed on every exit path taken
///         after they were created, because System V segments otherwise
///         outlive the process.
int main(int argc, char** argv)
{
    if (argc < 5 || argc > 6)
    {
        std::cout << "Usage: ./shm_benchmark F H W K [affinity: 0|1]\n";
        std::cout << "Example: ./shm_benchmark 1000 640 640 24 1\n";
        return 1;
    }

    int F = 0;
    int H = 0;
    int W = 0;
    int K = 0;

    // Unlike atoi, this rejects trailing garbage and out-of-range values.
    const bool numbersOk = parse_positive_int_arg(argv[1], F) &&
                           parse_positive_int_arg(argv[2], H) &&
                           parse_positive_int_arg(argv[3], W) &&
                           parse_positive_int_arg(argv[4], K);

    bool affinity = true;
    bool affinityOk = true;

    if (argc == 6)
    {
        long flag = 0;
        affinityOk = parse_long_arg(argv[5], flag);
        affinity = flag != 0; // any non-zero integer enables affinity
    }

    if (!numbersOk || !affinityOk)
    {
        std::cerr << "Invalid arguments.\n";
        return 1;
    }

    size_t totalCells = 0;
    size_t totalBytes = 0;

    if (!checked_total_bytes(
            static_cast<size_t>(F),
            static_cast<size_t>(H),
            static_cast<size_t>(W),
            totalCells,
            totalBytes))
    {
        std::cerr << "Invalid or too large tensor dimensions.\n";
        return 1;
    }

    const int originalId = get_existing_shmid(SHM_KEY_ORIGINAL);

    if (originalId == -1)
    {
        std::cerr << "Run ./shm_generator F H W first.\n";
        return 1;
    }

    // The copies below read totalBytes from the original mapping, so make sure
    // the published segment is at least that large before touching it.
    struct shmid_ds originalInfo;

    if (shmctl(originalId, IPC_STAT, &originalInfo) == -1)
    {
        const int savedErrno = errno;
        std::cerr << "Could not query original segment: "
                  << std::strerror(savedErrno) << "\n";
        return 1;
    }

    if (static_cast<size_t>(originalInfo.shm_segsz) < totalBytes)
    {
        std::cerr << "Original segment holds " << originalInfo.shm_segsz
                  << " bytes but F*H*W needs " << totalBytes << ".\n";
        std::cerr << "Run ./shm_generator F H W with matching dimensions.\n";
        return 1;
    }

    uint8_t* original = nullptr;

    int onetooneId = -1;
    int onetomanyId = -1;
    int manytomanyId = -1;

    uint8_t* onetooneCopy = nullptr;
    uint8_t* onetomanyCopy = nullptr;
    uint8_t* manytomanyCopy = nullptr;

    // Detaches whatever is attached and removes whatever copy segment exists.
    // Safe to call at any point after the declarations above.
    auto release_all = [&]()
    {
        if (original != nullptr)
            shmdt(original);
        if (onetooneCopy != nullptr)
            shmdt(onetooneCopy);
        if (onetomanyCopy != nullptr)
            shmdt(onetomanyCopy);
        if (manytomanyCopy != nullptr)
            shmdt(manytomanyCopy);

        if (onetooneId != -1)
            shmctl(onetooneId, IPC_RMID, nullptr);
        if (onetomanyId != -1)
            shmctl(onetomanyId, IPC_RMID, nullptr);
        if (manytomanyId != -1)
            shmctl(manytomanyId, IPC_RMID, nullptr);
    };

    original = attach_shm(originalId);

    if (original == nullptr)
        return 1;

    onetooneId = create_copy_segment(SHM_KEY_ONETOONE_COPY, totalBytes);
    onetomanyId = create_copy_segment(SHM_KEY_ONETOMANY_COPY, totalBytes);
    manytomanyId = create_copy_segment(SHM_KEY_MANYTOMANY_COPY, totalBytes);

    if (onetooneId == -1 || onetomanyId == -1 || manytomanyId == -1)
    {
        release_all();
        return 1;
    }

    onetooneCopy = attach_shm(onetooneId);
    onetomanyCopy = attach_shm(onetomanyId);
    manytomanyCopy = attach_shm(manytomanyId);

    if (onetooneCopy == nullptr || onetomanyCopy == nullptr || manytomanyCopy == nullptr)
    {
        release_all();
        return 1;
    }

    std::memcpy(onetooneCopy, original, totalBytes);
    std::memcpy(onetomanyCopy, original, totalBytes);
    std::memcpy(manytomanyCopy, original, totalBytes);

    std::cout << "Benchmark input prepared.\n";
    std::cout << "F=" << F << ", H=" << H << ", W=" << W << ", K=" << K;
    std::cout << ", affinity=" << (affinity ? "on" : "off") << "\n";

    std::cout << "Copies identical before sorting: "
              << (same_bytes(onetooneCopy, onetomanyCopy, totalBytes) &&
                  same_bytes(onetooneCopy, manytomanyCopy, totalBytes)
                      ? "yes"
                      : "no")
              << "\n";

    const long long onetooneMs = run_onetoone(onetooneCopy, F, H, W, K, affinity);
    const long long onetomanyMs = run_onetomany(onetomanyCopy, F, H, W, K, affinity);
    const long long manytomanyMs = run_manytomany(manytomanyCopy, F, H, W, K, affinity);

    if (onetooneMs < 0 || onetomanyMs < 0 || manytomanyMs < 0)
    {
        std::cerr << "A scheduler failed.\n";
        release_all();
        return 1;
    }

    std::cout << "\nResults\n";
    std::cout << "-------\n";
    std::cout << "One-to-One Pipe sorting time:    " << onetooneMs << " ms\n";
    std::cout << "One-to-Many Pipe sorting time:   " << onetomanyMs << " ms\n";
    std::cout << "Many-to-Many Pipe sorting time:  " << manytomanyMs << " ms\n";

    std::cout << "\nCorrectness\n";
    std::cout << "-----------\n";
    std::cout << "One-to-One sorted correctly:     "
              << (is_sorted_row_major(onetooneCopy, F, H, W) ? "yes" : "no") << "\n";

    std::cout << "One-to-Many sorted correctly:    "
              << (is_sorted_row_major(onetomanyCopy, F, H, W) ? "yes" : "no") << "\n";

    std::cout << "Many-to-Many sorted correctly:   "
              << (is_sorted_row_major(manytomanyCopy, F, H, W) ? "yes" : "no") << "\n";

    release_all();

    return 0;
}