#include "bounded_collective_scheduler.hpp"

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <iostream>

#include <unistd.h>
#include <sys/mman.h>
#include <sys/wait.h>
#include <sched.h>

// Bookkeeping shared by all worker processes. run_bounded_collective_scheduler()
// places one instance in a MAP_SHARED | MAP_ANONYMOUS mapping before the first
// fork(), so every forked worker reads and writes the same two integers.
struct CollectiveSharedState
{
    // Next worker ID to hand out. Workers reserve an ID with
    // __sync_fetch_and_add(); the scheduler initializes it to 1 because ID 0
    // is passed directly to the first worker.
    int nextWorkerId;
    // Starts at 0; set to 1 when a worker's fork() or CPU-affinity binding
    // fails. A non-zero value makes the scheduler return -1.
    int creationError;
};

// Restricts the caller to the single CPU `cpuId`.
//
// Builds an affinity mask containing only `cpuId` and applies it with
// sched_setaffinity() using pid 0 (i.e. the caller itself).
//
// Returns true when sched_setaffinity() succeeds, false otherwise.
static bool bind_process_to_cpu_local(int cpuId)
{
    cpu_set_t singleCpuMask;
    CPU_ZERO(&singleCpuMask);
    CPU_SET(cpuId, &singleCpuMask);

    return sched_setaffinity(0, sizeof(singleCpuMask), &singleCpuMask) == 0;
}

// Blocks until the calling process has no more children to wait for.
//
// Repeatedly calls waitpid(-1, ...) in blocking mode. Each reaped child, and
// each EINTR interruption, simply triggers another wait. The loop ends on
// ECHILD (no children left) or on any other waitpid() error. Exit statuses
// of the reaped children are discarded.
static void reap_direct_children()
{
    int exitStatus = 0;

    for (;;)
    {
        const pid_t result = waitpid(-1, &exitStatus, 0);

        // A child was reaped (or waitpid returned something other than the
        // error sentinel): go around again.
        if (result != -1)
            continue;

        // Interrupted by a signal before any child changed state: retry.
        if (errno == EINTR)
            continue;

        // ECHILD or an unexpected error: nothing more to wait for.
        return;
    }
}

// Pins the calling process to CPU (workerId mod numCPUs). On failure a
// message is written to stderr and shared->creationError is set; the worker
// keeps running unpinned.
static void bind_worker_affinity_local(
    int workerId,
    long numCPUs,
    CollectiveSharedState* shared
)
{
    const int cpuId = workerId % static_cast<int>(numCPUs);

    if (!bind_process_to_cpu_local(cpuId))
    {
        std::cerr << "[PID " << getpid() << "] failed to bind to CPU "
                  << cpuId << "\n";
        shared->creationError = 1;
    }
}

// Runs every task owned by `workerId`: indices workerId,
// workerId + workerCount, workerId + 2 * workerCount, ... below totalTasks.
static void run_worker_tasks_local(
    int workerId,
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx
)
{
    int taskIndex = workerId;

    while (taskIndex < totalTasks)
    {
        taskFn(taskIndex, workerId, workerCount, ctx);

        // Stop before forming an index >= totalTasks; computing
        // taskIndex + workerCount unconditionally could overflow a signed
        // int when totalTasks is close to INT_MAX.
        if (totalTasks - taskIndex <= workerCount)
            break;

        taskIndex += workerCount;
    }
}

// Body of one worker process.
//
// Every worker first helps to create the remaining workers: it repeatedly
// reserves up to SPAWN_BURST worker IDs from shared->nextWorkerId and forks
// one child per reserved ID. A freshly forked child adopts the reserved ID
// and joins the same spawning loop, so the set of workers grows
// collectively. Once all IDs are handed out, each worker runs its own tasks
// (taskIndex = id, id + workerCount, ...) and finally waits for the children
// it forked itself.
//
// Parameters:
//   myWorkerId   - ID of the calling worker (replaced by the reserved ID in
//                  a forked child).
//   totalTasks   - number of tasks; valid task indices are [0, totalTasks).
//   workerCount  - number of worker IDs; also the stride between the tasks
//                  of one worker.
//   taskFn, ctx  - task callback and its opaque context, invoked as
//                  taskFn(taskIndex, workerId, workerCount, ctx).
//   bindAffinity - when true, each worker pins itself to CPU
//                  (workerId mod online CPU count).
//   shared       - cross-process state (ID counter and error flag).
//
// Failure handling: if fork() fails, shared->creationError is set and this
// process stops forking. So that no task is silently dropped, it then runs
// the tasks of the ID it had reserved and of every ID that is still
// unreserved, serially, after its own tasks.
//
// Returns 0 in every process that returns from this function.
static int collective_worker_main(
    int myWorkerId,
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity,
    CollectiveSharedState* shared
)
{
    long numCPUs = sysconf(_SC_NPROCESSORS_ONLN);
    if (numCPUs <= 0)
        numCPUs = 1;

    if (bindAffinity)
        bind_worker_affinity_local(myWorkerId, numCPUs, shared);

    const int SPAWN_BURST = 2;

    // ID reserved by this process whose fork() failed, or -1 if none.
    int orphanedId = -1;
    bool keepSpawning = true;

    while (keepSpawning)
    {
        bool progressed = false;

        for (int s = 0; s < SPAWN_BURST; s++)
        {
            const int reservedId =
                __sync_fetch_and_add(&shared->nextWorkerId, 1);

            if (reservedId >= workerCount)
                break;

            const pid_t pid = fork();

            if (pid < 0)
            {
                // Capture errno before getpid()/stream output can change it.
                const int forkErrno = errno;

                std::cerr << "[PID " << getpid() << "] fork failed: "
                          << std::strerror(forkErrno) << "\n";
                shared->creationError = 1;

                // Nobody else will ever run reservedId; remember it and
                // switch this process to the serial fallback below.
                orphanedId = reservedId;
                keepSpawning = false;
                break;
            }

            if (pid == 0)
            {
                // Child: become the reserved worker, then go back to the
                // outer loop to help spawning the remaining workers.
                myWorkerId = reservedId;

                if (bindAffinity)
                    bind_worker_affinity_local(myWorkerId, numCPUs, shared);

                progressed = true;
                break;
            }

            progressed = true;
        }

        if (!progressed)
            break;

        // Atomic read: other processes update the counter concurrently.
        if (__sync_fetch_and_add(&shared->nextWorkerId, 0) >= workerCount)
            break;
    }

    run_worker_tasks_local(myWorkerId, totalTasks, workerCount, taskFn, ctx);

    if (orphanedId >= 0)
    {
        // Serial fallback after a failed fork(): run the orphaned worker's
        // tasks here, then claim and run whatever IDs are still unreserved
        // (other processes may have failed to fork as well).
        run_worker_tasks_local(orphanedId, totalTasks, workerCount, taskFn, ctx);

        for (;;)
        {
            const int leftoverId =
                __sync_fetch_and_add(&shared->nextWorkerId, 1);

            if (leftoverId >= workerCount)
                break;

            run_worker_tasks_local(leftoverId, totalTasks, workerCount, taskFn, ctx);
        }
    }

    reap_direct_children();

    return 0;
}

// Runs `totalTasks` tasks across at most `workerCount` forked worker
// processes and blocks until the worker tree rooted at the first worker has
// finished.
//
// The calling process itself executes no tasks: it forks worker 0, which in
// turn (via collective_worker_main) creates the remaining workers. A small
// MAP_SHARED anonymous mapping carries the worker-ID counter and an error
// flag between the processes.
//
// Parameters:
//   totalTasks   - number of tasks (>= 0); 0 returns immediately with 0.
//   workerCount  - requested number of workers (> 0); clamped to totalTasks.
//   taskFn, ctx  - task callback (non-null) and its opaque context.
//   bindAffinity - forwarded to the workers; when true each worker pins
//                  itself to a CPU.
//
// Returns 0 on success and -1 when the arguments are invalid, mmap() or the
// initial fork() fails, any worker flagged an error in the shared state
// (failed fork() or failed affinity binding), or worker 0 did not exit
// normally with status 0.
int run_bounded_collective_scheduler(
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity
)
{
    if (totalTasks < 0 || workerCount <= 0 || taskFn == nullptr)
    {
        std::cerr << "run_bounded_collective_scheduler: invalid arguments\n";
        return -1;
    }

    if (totalTasks == 0)
        return 0;

    if (workerCount > totalTasks)
        workerCount = totalTasks;

    // Check the raw mmap() result before treating it as a state object.
    void* mapping = mmap(
        nullptr,
        sizeof(CollectiveSharedState),
        PROT_READ | PROT_WRITE,
        MAP_SHARED | MAP_ANONYMOUS,
        -1,
        0
    );

    if (mapping == MAP_FAILED)
    {
        std::perror("mmap failed");
        return -1;
    }

    CollectiveSharedState* shared = static_cast<CollectiveSharedState*>(mapping);

    // ID 0 is given to the first worker directly, so reservations start at 1.
    shared->nextWorkerId = 1;
    shared->creationError = 0;

    const pid_t first = fork();

    if (first < 0)
    {
        std::perror("fork failed");
        munmap(shared, sizeof(CollectiveSharedState));
        return -1;
    }

    if (first == 0)
    {
        // Worker 0, and every descendant worker forked below this frame.
        // Nothing may leave this branch except through _exit(): an exception
        // escaping from taskFn would otherwise unwind into the caller's code
        // inside a forked child process.
        int rc = 1;

        try
        {
            rc = collective_worker_main(
                0,
                totalTasks,
                workerCount,
                taskFn,
                ctx,
                bindAffinity,
                shared
            );
        }
        catch (...)
        {
            rc = 1;
        }

        _exit(rc == 0 ? 0 : 1);
    }

    // Wait for worker 0 only. Using waitpid(-1) here would also reap (and
    // block on) unrelated children that the caller may have started.
    // Worker 0 waits for its own children before exiting, and so on down
    // the tree.
    int status = 0;
    pid_t waited;

    do
    {
        waited = waitpid(first, &status, 0);
    } while (waited == -1 && errno == EINTR);

    int rc = shared->creationError != 0 ? -1 : 0;

    if (waited == first)
    {
        // A crash or non-zero exit of worker 0 is a failure even when the
        // shared error flag was never set.
        if (!(WIFEXITED(status) && WEXITSTATUS(status) == 0))
            rc = -1;
    }
    // Otherwise waitpid() failed for a reason other than EINTR (e.g. ECHILD
    // because the child was already reaped elsewhere); its exit status is
    // unavailable, so only the shared error flag decides the result.

    munmap(shared, sizeof(CollectiveSharedState));

    return rc;
}