#include <iostream>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <cerrno>
#include <limits>

#include <unistd.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sched.h>
#include <chrono>

#include "shm_keys.hpp"
#include "row.hpp"
#include "sort_algorithms.hpp"
#include "bounded_prolific_scheduler.hpp"
#include "bounded_collective_scheduler.hpp"
#include "bounded_collective_scheduler_log.hpp"
#include "chunk_stealing_scheduler.hpp"

using namespace std;

// Shared context handed to every frame-sorting task: the base pointer of
// the F x H x W uint8 tensor plus its dimensions. One instance is built per
// scheduler run and passed to the tasks through a void* (see sort_frame_task).
struct MatrixSortContext
{
    uint8_t* A;  // base of the F*H*W tensor (lives in shared memory)
    int F;       // number of frames
    int H;       // rows per frame
    int W;       // columns per row
};

// Returns a mutable reference to cell (f, i, j) of the row-major
// F x H x W tensor rooted at A. The frame index is widened to size_t
// first so the full linear offset is computed without int overflow.
static inline uint8_t& CELL(uint8_t* A, int f, int i, int j, int H, int W)
{
    const size_t rowIndex = static_cast<size_t>(f) * H + i;
    const size_t linearIndex = rowIndex * W + j;
    return A[linearIndex];
}

// Computes F*H*W (cells) and the byte size of the tensor with explicit
// overflow checks. Returns false for any zero dimension or when the
// product would not fit in size_t; outputs are written only on success.
static bool checked_total_bytes(size_t F, size_t H, size_t W, size_t& totalCells, size_t& totalBytes)
{
    constexpr size_t kMax = std::numeric_limits<size_t>::max();

    if (F == 0 || H == 0 || W == 0)
        return false;

    // F*H overflows exactly when H exceeds kMax / F (F is nonzero here).
    if (H > kMax / F)
        return false;

    const size_t framesTimesRows = F * H;

    // Same guard for the second multiplication.
    if (W > kMax / framesTimesRows)
        return false;

    totalCells = framesTimesRows * W;
    totalBytes = totalCells * sizeof(uint8_t);
    return true;
}

// Pins the calling process to the single CPU `cpuId`.
// Returns true when sched_setaffinity succeeds, false otherwise.
static bool bind_process_to_cpu(int cpuId)
{
    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(cpuId, &mask);

    const int rc = sched_setaffinity(0, sizeof(mask), &mask);
    return rc == 0;
}

// Looks up an already-created shared memory segment by key (size 1 is a
// dummy: shmget only matches an existing segment here, it does not create).
// Returns the shm id, or -1 with a diagnostic on stderr when not found.
static int get_existing_shmid(key_t key)
{
    const int shmid = shmget(key, 1, 0666);
    if (shmid != -1)
        return shmid;

    std::cerr << "Could not find shared memory key " << key << ": " << strerror(errno) << "\n";
    return -1;
}

// Attaches the given shm segment into this process's address space.
// Returns the mapped base as uint8_t*, or nullptr (with a diagnostic)
// on failure. shmat signals failure with (void*)-1, not nullptr.
static uint8_t* attach_shm(int shmid)
{
    void* mapped = shmat(shmid, nullptr, 0);
    if (mapped != reinterpret_cast<void*>(-1))
        return static_cast<uint8_t*>(mapped);

    std::cerr << "shmat failed: " << strerror(errno) << "\n";
    return nullptr;
}

// Best-effort removal of a shm segment by key; silently does nothing
// when no segment with that key exists.
static void remove_shm_if_exists(key_t key)
{
    const int existing = shmget(key, 1, 0666);
    if (existing == -1)
        return;

    shmctl(existing, IPC_RMID, nullptr);
}

// Creates a fresh shm segment of `totalBytes` under `key`, first deleting
// any stale segment with the same key so IPC_EXCL cannot fail spuriously.
// Returns the new shm id, or -1 with a diagnostic on stderr.
static int create_copy_segment(key_t key, size_t totalBytes)
{
    remove_shm_if_exists(key);

    const int shmid = shmget(key, totalBytes, IPC_CREAT | IPC_EXCL | 0666);
    if (shmid != -1)
        return shmid;

    std::cerr << "Could not create copy segment " << key << ": " << strerror(errno) << "\n";
    return -1;
}

// Verifies that every row of every frame in the F x H x W tensor is in
// non-decreasing order. Returns true when all rows are sorted.
static bool is_sorted_row_major(const uint8_t* A, int F, int H, int W)
{
    for (int frame = 0; frame < F; ++frame)
    {
        for (int r = 0; r < H; ++r)
        {
            const uint8_t* rowBegin = A + (static_cast<size_t>(frame) * H + r) * W;
            const uint8_t* rowEnd = rowBegin + W;

            // A row is sorted when no element exceeds its successor.
            for (const uint8_t* p = rowBegin + 1; p < rowEnd; ++p)
            {
                if (*(p - 1) > *p)
                    return false;
            }
        }
    }
    return true;
}

// True when the first `totalBytes` bytes of A and B are identical.
static bool same_bytes(const uint8_t* A, const uint8_t* B, size_t totalBytes)
{
    const int diff = std::memcmp(A, B, totalBytes);
    return diff == 0;
}

void sort_frame_task(int frameId, int workerId, int workerCount, void* ctxPtr)
{
    (void)workerId;
    (void)workerCount;

    MatrixSortContext* ctx = static_cast<MatrixSortContext*>(ctxPtr);
    uint8_t* A = ctx->A;
    const int H = ctx->H;
    const int W = ctx->W;

    for (int i = 0; i < H; ++i)
    {
        uint8_t* rowPtr = &CELL(A, frameId, i, 0, H, W);

        // false = do not deep-copy the row. Sort directly inside shared memory.
        row<uint8_t> currentRow(rowPtr, static_cast<uint32_t>(W), false);
        counting_sort_u8(currentRow);
    }
}

// Times the bounded prolific scheduler sorting all F frames with K workers.
// Returns elapsed wall time in milliseconds, or -1 if the scheduler
// reports a non-zero status.
static long long run_prolific(uint8_t* A, int F, int H, int W, int K, bool affinity)
{
    MatrixSortContext ctx{A, F, H, W};

    const auto start = std::chrono::high_resolution_clock::now();
    const int status = run_bounded_prolific_scheduler(F, K, sort_frame_task, &ctx, affinity);
    const auto stop = std::chrono::high_resolution_clock::now();

    if (status != 0)
        return -1;

    const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
    return elapsed.count();
}

// Times the bounded collective scheduler sorting all F frames with K workers.
// Returns elapsed wall time in milliseconds, or -1 on scheduler failure.
static long long run_collective(uint8_t* A, int F, int H, int W, int K, bool affinity)
{
    MatrixSortContext ctx{A, F, H, W};

    const auto start = std::chrono::high_resolution_clock::now();
    const int status = run_bounded_collective_scheduler(F, K, sort_frame_task, &ctx, affinity);
    const auto stop = std::chrono::high_resolution_clock::now();

    return (status == 0)
        ? std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count()
        : -1;
}

// Times the log-variant bounded collective scheduler sorting all F frames
// with K workers. Returns elapsed wall time in milliseconds, or -1 if the
// scheduler reports a non-zero status.
//
// FIX: the previous revision cast K and affinity to (void) as "unused"
// even though both are forwarded to the scheduler call below; those
// misleading dead casts have been removed.
static long long run_log_collective(uint8_t* A, int F, int H, int W, int K, bool affinity)
{
    MatrixSortContext ctx;
    ctx.A = A;
    ctx.F = F;
    ctx.H = H;
    ctx.W = W;

    const auto start = std::chrono::high_resolution_clock::now();
    const int rc = run_log_bounded_collective_scheduler(F, K, sort_frame_task, &ctx, affinity);
    const auto stop = std::chrono::high_resolution_clock::now();

    if (rc != 0)
        return -1;

    return std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
}

// Times the chunk-stealing scheduler over the F x H x W tensor rooted at A.
// Returns elapsed wall time in milliseconds, or -1 on scheduler failure.
//
// NOTE(review): K and affinity are deliberately ignored here — the scheduler
// is invoked with the hard-coded arguments (16, 10). Presumably these are a
// worker count and a chunk size, but that cannot be confirmed from this file;
// verify against chunk_stealing_scheduler.hpp and consider wiring K through
// so this benchmark is comparable with the other schedulers.
static long long run_chunk_stealing(uint8_t* A, int F, int H, int W, int K, bool affinity)
{
    (void)K;
    (void)affinity;

    chrono::high_resolution_clock::time_point t1 =
        chrono::high_resolution_clock::now();

    int rc = run_chunk_stealing_scheduler(A, F, H, W, 16, 10);

    chrono::high_resolution_clock::time_point t2 =
        chrono::high_resolution_clock::now();

    if (rc != 0)
        return -1;

    return chrono::duration_cast<chrono::milliseconds>(t2 - t1).count();
}

/**
 * Benchmark driver: sorts each row of an F x H x W uint8 tensor with four
 * different schedulers and reports wall times plus correctness checks.
 *
 * Expects the input tensor to already exist in shared memory under
 * SHM_KEY_ORIGINAL (created by ./shm_generator). One private work copy is
 * created per scheduler; on exit the work copies are removed and only the
 * original segment is kept alive.
 *
 * Usage: ./shm_benchmark F H W K [affinity: 0|1]
 * Returns 0 on success, 1 on any setup or scheduler failure.
 *
 * FIXES vs. previous revision:
 *  - chunkStealingId was created but never checked for -1, and
 *    chunkStealingCopy was never checked for nullptr, so a failed
 *    create/attach crashed in the memcpy below.
 *  - Failure paths now remove the shm segments they created instead of
 *    leaking them.
 *  - "0 % numCPUs" always evaluated to 0; the dead computation is gone.
 */
int main(int argc, char** argv)
{
    if (argc < 5 || argc > 6)
    {
        cout << "Usage: ./shm_benchmark F H W K [affinity: 0|1]\n";
        cout << "Example: ./shm_benchmark 1000 640 640 16 1\n";
        return 1;
    }

    int F = atoi(argv[1]);
    int H = atoi(argv[2]);
    int W = atoi(argv[3]);
    int K = atoi(argv[4]);
    bool affinity = true;

    if (argc == 6)
        affinity = atoi(argv[5]) != 0;

    // atoi returns 0 on non-numeric input, so garbage is rejected here too.
    if (F <= 0 || H <= 0 || W <= 0 || K <= 0)
    {
        cerr << "Invalid arguments.\n";
        return 1;
    }

    // Pin the parent process to CPU 0 for setup work. Best effort: failure
    // is harmless, so the return value is intentionally ignored.
    bind_process_to_cpu(0);

    size_t totalCells = 0;
    size_t totalBytes = 0;
    if (!checked_total_bytes(static_cast<size_t>(F), static_cast<size_t>(H), static_cast<size_t>(W), totalCells, totalBytes))
    {
        cerr << "Invalid or too large tensor dimensions.\n";
        return 1;
    }

    int originalId = get_existing_shmid(SHM_KEY_ORIGINAL);
    if (originalId == -1)
    {
        cerr << "Run ./shm_generator F H W first.\n";
        return 1;
    }

    uint8_t* original = attach_shm(originalId);
    if (original == nullptr)
        return 1;

    int prolificId = create_copy_segment(SHM_KEY_PROLIFIC_COPY, totalBytes);
    int collectiveId = create_copy_segment(SHM_KEY_COLLECTIVE_COPY, totalBytes);
    int logCollectiveId = create_copy_segment(SHM_KEY_LOG_COLLECTIVE_COPY, totalBytes);
    int chunkStealingId = create_copy_segment(SHM_KEY_LOG_COLLECTIVE_COPY + 1, totalBytes);

    // BUG FIX: chunkStealingId is now validated alongside the others.
    if (prolificId == -1 || collectiveId == -1 || logCollectiveId == -1 || chunkStealingId == -1)
    {
        // Tear down whichever segments were created before the failure.
        if (prolificId != -1) shmctl(prolificId, IPC_RMID, nullptr);
        if (collectiveId != -1) shmctl(collectiveId, IPC_RMID, nullptr);
        if (logCollectiveId != -1) shmctl(logCollectiveId, IPC_RMID, nullptr);
        if (chunkStealingId != -1) shmctl(chunkStealingId, IPC_RMID, nullptr);
        shmdt(original);
        return 1;
    }

    uint8_t* prolificCopy = attach_shm(prolificId);
    uint8_t* collectiveCopy = attach_shm(collectiveId);
    uint8_t* logCollectiveCopy = attach_shm(logCollectiveId);
    uint8_t* chunkStealingCopy = attach_shm(chunkStealingId);

    // BUG FIX: chunkStealingCopy is now checked before any memcpy touches it.
    if (prolificCopy == nullptr || collectiveCopy == nullptr || logCollectiveCopy == nullptr || chunkStealingCopy == nullptr)
    {
        if (prolificCopy != nullptr) shmdt(prolificCopy);
        if (collectiveCopy != nullptr) shmdt(collectiveCopy);
        if (logCollectiveCopy != nullptr) shmdt(logCollectiveCopy);
        if (chunkStealingCopy != nullptr) shmdt(chunkStealingCopy);
        shmctl(prolificId, IPC_RMID, nullptr);
        shmctl(collectiveId, IPC_RMID, nullptr);
        shmctl(logCollectiveId, IPC_RMID, nullptr);
        shmctl(chunkStealingId, IPC_RMID, nullptr);
        shmdt(original);
        return 1;
    }

    // Give every scheduler an identical, independent copy of the input.
    memcpy(prolificCopy, original, totalBytes);
    memcpy(collectiveCopy, original, totalBytes);
    memcpy(logCollectiveCopy, original, totalBytes);
    memcpy(chunkStealingCopy, original, totalBytes);

    cout << "Benchmark input prepared.\n";
    cout << "Original key: " << SHM_KEY_ORIGINAL << "\n";
    cout << "Prolific work-copy key: " << SHM_KEY_PROLIFIC_COPY << "\n";
    cout << "Collective work-copy key: " << SHM_KEY_COLLECTIVE_COPY << "\n";
    cout << "Log Collective work-copy key: " << SHM_KEY_LOG_COLLECTIVE_COPY << "\n";
    cout << "Chunk Stealing work-copy key: " << SHM_KEY_LOG_COLLECTIVE_COPY + 1 << "\n";
    cout << "F=" << F << ", H=" << H << ", W=" << W << ", K=" << K;
    cout << ", affinity=" << (affinity ? "on" : "off") << "\n";
    cout << "Copies identical before sorting: "
         << (same_bytes(prolificCopy, collectiveCopy, totalBytes) ? "yes" : "no") << "\n";

    long long prolificMs = run_prolific(prolificCopy, F, H, W, K, affinity);
    long long collectiveMs = run_collective(collectiveCopy, F, H, W, K, affinity);
    long long logCollectiveMs = run_log_collective(logCollectiveCopy, F, H, W, K, affinity);
    long long chunkStealingMs = run_chunk_stealing(chunkStealingCopy, F, H, W, K, affinity);

    if (prolificMs < 0 || collectiveMs < 0 || logCollectiveMs < 0 || chunkStealingMs < 0)
    {
        cerr << "A scheduler failed.\n";
        shmdt(original);
        shmdt(prolificCopy);
        shmdt(collectiveCopy);
        shmdt(logCollectiveCopy);
        shmdt(chunkStealingCopy);
        // Do not leak the work copies on the failure path either.
        shmctl(prolificId, IPC_RMID, nullptr);
        shmctl(collectiveId, IPC_RMID, nullptr);
        shmctl(logCollectiveId, IPC_RMID, nullptr);
        shmctl(chunkStealingId, IPC_RMID, nullptr);
        return 1;
    }

    cout << "\nResults\n";
    cout << "-------\n";
    cout << "Bounded Prolific sorting time:   " << prolificMs << " ms\n";
    cout << "Bounded Collective sorting time: " << collectiveMs << " ms\n";
    cout << "Log Bounded Collective sorting time: " << logCollectiveMs << " ms\n";
    cout << "Chunk Stealing sorting time: " << chunkStealingMs << " ms\n";
    cout << "Prolific sorted correctly:   " << (is_sorted_row_major(prolificCopy, F, H, W) ? "yes" : "no") << "\n";
    cout << "Collective sorted correctly: " << (is_sorted_row_major(collectiveCopy, F, H, W) ? "yes" : "no") << "\n";
    cout << "Log Bounded Collective sorted correctly: " << (is_sorted_row_major(logCollectiveCopy, F, H, W) ? "yes" : "no") << "\n";
    cout << "Chunk Stealing sorted correctly: " << (is_sorted_row_major(chunkStealingCopy, F, H, W) ? "yes" : "no") << "\n";

    shmdt(original);
    shmdt(prolificCopy);
    shmdt(collectiveCopy);
    shmdt(logCollectiveCopy);
    shmdt(chunkStealingCopy);

    // Keep only the original input segment alive. Work copies are temporary.
    shmctl(prolificId, IPC_RMID, nullptr);
    shmctl(collectiveId, IPC_RMID, nullptr);
    shmctl(logCollectiveId, IPC_RMID, nullptr);
    shmctl(chunkStealingId, IPC_RMID, nullptr);

    return 0;
}
