#include "bounded_collective_scheduler.hpp"

#include <iostream>
#include <vector>
#include <cerrno>
#include <cstdlib>
#include <cstring>

#include <unistd.h>
#include <sys/mman.h>
#include <sys/wait.h>
#include <sched.h>

// Bookkeeping shared by every process taking part in one scheduler run.
// The scheduler places a single instance in a MAP_SHARED | MAP_ANONYMOUS
// mapping before the first fork, so all forked workers read and write the
// same copy.
struct CollectiveSharedState
{
    int nextWorkerId;   // next worker id to hand out; workers claim ids with an atomic fetch-and-add
    int creationError;  // 0 = no error; stored as 1 when a fork or a CPU-affinity bind fails
};

// Restricts the caller's CPU affinity to the single CPU `cpuId`.
//
// Returns true when sched_setaffinity() accepted the one-CPU mask, false
// otherwise (errno is left as set by the failing call).
static bool bind_process_to_cpu_local(int cpuId)
{
    cpu_set_t onlyThisCpu;
    CPU_ZERO(&onlyThisCpu);
    CPU_SET(cpuId, &onlyThisCpu);

    // A pid of 0 addresses the caller itself.
    const bool applied =
        sched_setaffinity(0, sizeof(onlyThisCpu), &onlyThisCpu) == 0;
    return applied;
}

// Blocks until every direct child of the calling process has been waited for.
//
// Keeps calling waitpid(-1, ...) and stops on the first failure that is not
// EINTR: ECHILD (no children left) or any other error. Exit statuses of the
// reaped children are discarded.
static void reap_direct_children()
{
    int exitInfo = 0;
    pid_t reaped;

    do
    {
        reaped = waitpid(-1, &exitInfo, 0);
        // Any result other than -1 means "keep going"; a -1 only keeps the
        // loop alive when the wait was merely interrupted by a signal.
    } while (reaped != -1 || errno == EINTR);
}

// Body executed by every worker process of one scheduler run.
//
// Phases:
//   1. Optionally pin this process to CPU (workerId mod online CPU count).
//   2. Spawn loop: atomically reserve the next worker id from
//      shared->nextWorkerId and fork a child for it. The child adopts the
//      reserved id and stays in the same loop, so parents and children all
//      keep spawning until the reserved id reaches workerCount.
//   3. Run the tasks myWorkerId, myWorkerId + workerCount, ... that are
//      below totalTasks.
//   4. Wait for the children this process forked.
//
// Parameters:
//   myWorkerId   - id of the calling worker; overwritten in forked children
//                  with the id reserved for them.
//   totalTasks   - number of task indices (tasks are 0 .. totalTasks - 1).
//   workerCount  - upper bound for worker ids and stride between the task
//                  indices of one worker.
//   taskFn       - called as taskFn(taskIndex, workerId, workerCount, ctx).
//   ctx          - opaque pointer forwarded to taskFn.
//   bindAffinity - when true, each worker tries to pin itself to one CPU.
//   shared       - state shared between all workers (id counter, error flag).
//
// Failures (affinity bind, fork) are logged to std::cerr and recorded by
// storing 1 in shared->creationError; the function itself always returns 0.
// NOTE: when fork fails, the id reserved for that child is not handed to any
// other process, so the tasks belonging to it are not executed by this
// function.
static int collective_worker_main(
    int myWorkerId,
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity,
    CollectiveSharedState* shared
)
{
    long numCPUs = sysconf(_SC_NPROCESSORS_ONLN);
    if (numCPUs <= 0)
        numCPUs = 1;

    // Pins the calling process to the CPU derived from `workerId`. A failed
    // bind is reported and flagged, but the worker keeps running.
    const auto bindSelfIfRequested = [&](int workerId)
    {
        if (!bindAffinity)
            return;

        const int cpuId = workerId % static_cast<int>(numCPUs);
        if (!bind_process_to_cpu_local(cpuId))
        {
            std::cerr << "[PID " << getpid() << "] failed to bind to CPU "
                      << cpuId << "\n";
            shared->creationError = 1;
        }
    };

    bindSelfIfRequested(myWorkerId);

    while (true)
    {
        // Atomic reservation: no two processes can obtain the same id.
        const int reservedId = __sync_fetch_and_add(&shared->nextWorkerId, 1);

        if (reservedId >= workerCount)
            break;

        const pid_t pid = fork();

        if (pid < 0)
        {
            // Capture errno before any other call gets a chance to change it.
            const int forkErrno = errno;
            std::cerr << "[PID " << getpid() << "] fork failed: "
                      << std::strerror(forkErrno) << "\n";
            shared->creationError = 1;
            break;
        }

        if (pid == 0)
        {
            // Child: become the worker that was just reserved, then go on
            // helping to spawn the remaining workers.
            myWorkerId = reservedId;
            bindSelfIfRequested(myWorkerId);
        }
    }

    int taskIndex = myWorkerId;
    while (taskIndex < totalTasks)
    {
        taskFn(taskIndex, myWorkerId, workerCount, ctx);

        // Stop when the next index would reach or pass totalTasks. Comparing
        // the remaining distance avoids computing taskIndex + workerCount,
        // which could overflow a signed int for totalTasks near INT_MAX.
        if (totalTasks - taskIndex <= workerCount)
            break;

        taskIndex += workerCount;
    }

    reap_direct_children();
    return 0;
}

// Runs task indices 0 .. totalTasks - 1 on up to `workerCount` forked worker
// processes and blocks until worker 0 has exited.
//
// The caller's process only forks worker 0; the remaining workers are forked
// from inside collective_worker_main, coordinated through a small
// MAP_SHARED | MAP_ANONYMOUS mapping that holds the worker-id counter and an
// error flag. Each worker waits for the children it forked before returning,
// and every worker process leaves through the _exit() below.
//
// Parameters:
//   totalTasks   - number of tasks; must be >= 0.
//   workerCount  - maximum number of worker processes; must be > 0.
//   taskFn       - task callback; must not be null.
//   ctx          - opaque pointer forwarded to taskFn.
//   bindAffinity - when true, workers try to pin themselves to one CPU each.
//
// Returns 0 on success and -1 when
//   - the arguments are invalid,
//   - mmap() or the first fork() fails,
//   - any worker recorded a fork/affinity failure in the shared error flag, or
//   - worker 0 was killed by a signal or exited with a non-zero status.
int run_bounded_collective_scheduler(
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity
)
{
    if (totalTasks < 0 || workerCount <= 0 || taskFn == nullptr)
    {
        std::cerr << "run_bounded_collective_scheduler: invalid arguments\n";
        return -1;
    }

    CollectiveSharedState* shared =
        static_cast<CollectiveSharedState*>(mmap(
            nullptr,
            sizeof(CollectiveSharedState),
            PROT_READ | PROT_WRITE,
            MAP_SHARED | MAP_ANONYMOUS,
            -1,
            0
        ));

    if (shared == MAP_FAILED)
    {
        std::perror("mmap failed");
        return -1;
    }

    // Id 0 is given to the first worker directly, so reservations start at 1.
    shared->nextWorkerId = 1;
    shared->creationError = 0;

    pid_t first = fork();
    if (first < 0)
    {
        std::perror("fork failed");
        munmap(shared, sizeof(CollectiveSharedState));
        return -1;
    }

    if (first == 0)
    {
        int rc = collective_worker_main(
            0,
            totalTasks,
            workerCount,
            taskFn,
            ctx,
            bindAffinity,
            shared
        );

        // _exit() terminates the worker immediately instead of returning
        // into the caller's code in the forked copy.
        _exit(rc == 0 ? 0 : 1);
    }

    // Wait for exactly the worker forked above. Waiting on "any child" would
    // also reap (or block on) unrelated children the caller may have started.
    int status = 0;
    pid_t waited;
    do
    {
        waited = waitpid(first, &status, 0);
    } while (waited == -1 && errno == EINTR);

    // Judge the exit status only if it was actually obtained. When waitpid()
    // fails for another reason (e.g. ECHILD because SIGCHLD is ignored), the
    // outcome falls back to the shared error flag alone.
    bool firstWorkerFailed = false;
    if (waited == first)
        firstWorkerFailed = !(WIFEXITED(status) && WEXITSTATUS(status) == 0);

    int rc = (shared->creationError != 0 || firstWorkerFailed) ? -1 : 0;
    munmap(shared, sizeof(CollectiveSharedState));
    return rc;
}