#include <iostream>
#include <cstdlib>
#include <cstdint>
#include <ctime>
#include <cerrno>
#include <cstring>
#include <limits>

#include <unistd.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <thread>
#include <vector>

#include "shm_keys.hpp"

using namespace std;

/// Returns one pseudo-random byte and advances the caller-owned rand_r state.
/// The state lives with the caller, so each thread can drive its own generator.
static uint8_t random_uint8(unsigned int* seed)
{
    const int draw = rand_r(seed);
    // rand_r yields a value in [0, RAND_MAX]; keep only its low 8 bits.
    return static_cast<uint8_t>(draw & 0xFF);
}

/// Computes the cell count and byte size of an F x H x W tensor of uint8_t.
///
/// @param F, H, W    Tensor extents; each must be non-zero.
/// @param totalCells Out: F * H * W (written only on success).
/// @param totalBytes Out: totalCells * sizeof(uint8_t) (written only on success).
/// @return false if any extent is zero or a product would overflow size_t.
static bool checked_total_bytes(size_t F, size_t H, size_t W, size_t& totalCells, size_t& totalBytes)
{
    constexpr size_t kLimit = numeric_limits<size_t>::max();

    const bool hasEmptyExtent = !(F && H && W);
    if (hasEmptyExtent)
        return false;

    // Each product is guarded by dividing the limit first, so nothing ever wraps.
    if (F > kLimit / H)
        return false;
    const size_t rows = F * H;

    if (rows > kLimit / W)
        return false;
    const size_t cells = rows * W;

    if (cells > kLimit / sizeof(uint8_t))
        return false;

    totalCells = cells;
    totalBytes = cells * sizeof(uint8_t);
    return true;
}

/// Best-effort removal of the System V shared-memory segment identified by `key`.
/// If no segment is reachable for `key` this is a no-op. Errors from the lookup
/// and from IPC_RMID are deliberately ignored.
static void remove_shm_if_exists(key_t key)
{
    // Size 1 is the smallest possible request, so any existing segment satisfies it.
    const int existing = shmget(key, 1, 0666);
    if (existing == -1)
        return;

    shmctl(existing, IPC_RMID, nullptr);
}

/// Worker body: writes pseudo-random bytes into this worker's slice of A.
///
/// [0, totalCells) is split into threadCount contiguous slices of
/// totalCells / threadCount cells each; the worker with the highest id also
/// absorbs the remainder. threadCount must be non-zero (it is used as a divisor).
/// Each worker keeps its own rand_r state, seeded from the current time, the
/// process id and its thread id, so output is not meant to be reproducible.
static void fill_random_range(uint8_t* A,
                              size_t totalCells,
                              unsigned int threadId,
                              unsigned int threadCount)
{
    const size_t sliceLen = totalCells / threadCount;
    const size_t first = static_cast<size_t>(threadId) * sliceLen;
    const bool takesRemainder = (threadId == threadCount - 1);
    const size_t last = takesRemainder ? totalCells : first + sliceLen;

    // NOTE(review): 2654435761 looks like the usual multiplicative-hash constant,
    // presumably chosen to spread neighbouring thread ids apart in seed space.
    unsigned int state = static_cast<unsigned int>(time(nullptr))
                       ^ static_cast<unsigned int>(getpid())
                       ^ (threadId * 2654435761u);

    size_t i = first;
    while (i < last)
    {
        A[i] = random_uint8(&state);
        ++i;
    }
}

/// Fills A[0, totalCells) with pseudo-random bytes using one worker thread per
/// hardware thread. The worker count is capped at totalCells, so totalCells == 0
/// starts no threads. Returns only after every worker has been joined.
static void generate_tensor(uint8_t* A, size_t totalCells)
{
    // hardware_concurrency() may report 0 when unknown; fall back to one worker.
    const unsigned int hinted = std::thread::hardware_concurrency();
    unsigned int workers = (hinted == 0) ? 1u : hinted;

    // Never start more workers than there are cells to fill.
    if (static_cast<size_t>(workers) > totalCells)
        workers = static_cast<unsigned int>(totalCells);

    vector<thread> pool;
    pool.reserve(workers);
    for (unsigned int id = 0; id < workers; ++id)
        pool.emplace_back(fill_random_range, A, totalCells, id, workers);

    for (thread& worker : pool)
        worker.join();
}

/// Creates the original tensor segment (key SHM_KEY_ORIGINAL) holding F*H*W
/// random bytes, then prints a short summary.
///
/// Steps:
///  1. Validate the dimensions.
///  2. Remove any previous segment under the key.
///  3. Create a fresh segment exclusively with mode 0666.
///  4. Attach, fill, and detach.
/// On success the segment is left in place (presumably for the other project
/// programs to attach). It is removed only if attaching fails.
/// @return 0 on success, 1 on any failure (a message is written to cerr).
static int create_and_generate(size_t F, size_t H, size_t W)
{
    size_t cells = 0;
    size_t bytes = 0;
    if (!checked_total_bytes(F, H, W, cells, bytes))
    {
        cerr << "Invalid or too large tensor dimensions.\n";
        return 1;
    }

    // Drop any stale segment so the exclusive create below can succeed.
    remove_shm_if_exists(SHM_KEY_ORIGINAL);

    const int segment = shmget(SHM_KEY_ORIGINAL, bytes, IPC_CREAT | IPC_EXCL | 0666);
    if (segment == -1)
    {
        cerr << "shmget create failed: " << strerror(errno) << "\n";
        return 1;
    }

    void* const mapping = shmat(segment, nullptr, 0);
    if (mapping == reinterpret_cast<void*>(-1))
    {
        cerr << "shmat failed: " << strerror(errno) << "\n";
        // Do not leave behind a segment we could never populate.
        shmctl(segment, IPC_RMID, nullptr);
        return 1;
    }

    generate_tensor(static_cast<uint8_t*>(mapping), cells);
    shmdt(mapping);

    cout << "Original shared tensor created.\n"
         << "Key: " << SHM_KEY_ORIGINAL << "\n"
         << "Dimensions: F=" << F << ", H=" << H << ", W=" << W << "\n"
         << "Bytes: " << bytes << "\n";
    return 0;
}

/// Removes every shared-memory segment this project uses: the original tensor
/// plus each copy variant. Keys with no existing segment are skipped silently,
/// so running this repeatedly is harmless. Always returns 0.
static int delete_all_project_shm()
{
    const key_t projectKeys[] = {
        static_cast<key_t>(SHM_KEY_ORIGINAL),
        static_cast<key_t>(SHM_KEY_PROLIFIC_COPY),
        static_cast<key_t>(SHM_KEY_COLLECTIVE_COPY),
        static_cast<key_t>(SHM_KEY_LOG_COLLECTIVE_COPY),
        static_cast<key_t>(SHM_KEY_CHUNK_STEALING_COPY),
    };

    for (const key_t key : projectKeys)
        remove_shm_if_exists(key);

    cout << "Project shared-memory segments deleted.\n";
    return 0;
}

/// Parses one tensor extent from the command line.
///
/// Accepts only a plain decimal integer: no sign, no surrounding whitespace,
/// no trailing characters, and a value that fits in size_t. (atoll accepted
/// "12abc" as 12 and "abc" as 0, and a negative input such as "-1" wrapped to
/// SIZE_MAX once cast to size_t.)
/// @param text Argument text to parse.
/// @param out  Receives the parsed value on success; untouched on failure.
/// @return true if `text` is a valid extent.
static bool parse_dimension(const char* text, size_t& out)
{
    // strtoull itself would skip leading whitespace and accept '+'/'-' (negating
    // the value), so require the very first character to be a digit.
    if (text == nullptr || *text < '0' || *text > '9')
        return false;

    errno = 0;
    char* end = nullptr;
    const unsigned long long value = strtoull(text, &end, 10);
    if (errno == ERANGE || *end != '\0')
        return false;
    if (value > numeric_limits<size_t>::max())
        return false;

    out = static_cast<size_t>(value);
    return true;
}

/// Entry point.
///   shm_generator F H W   -> create and fill the original tensor segment
///   shm_generator delete  -> remove all project segments
/// Returns 1 on usage or parse errors; otherwise the subcommand's result.
int main(int argc, char** argv)
{
    if (argc == 2 && strcmp(argv[1], "delete") == 0)
    {
        return delete_all_project_shm();
    }

    if (argc != 4)
    {
        cout << "Usage:\n";
        cout << "  ./shm_generator F H W\n";
        cout << "  ./shm_generator delete\n";
        return 1;
    }

    size_t dims[3] = {0, 0, 0};
    for (int i = 0; i < 3; ++i)
    {
        if (!parse_dimension(argv[i + 1], dims[i]))
        {
            cerr << "Invalid dimension '" << argv[i + 1]
                 << "': expected a non-negative decimal integer.\n";
            return 1;
        }
    }

    // Zero extents are rejected downstream by create_and_generate.
    return create_and_generate(dims[0], dims[1], dims[2]);
}
