#include "bounded_collective_scheduler.hpp"

#include <sys/mman.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>

// Bookkeeping shared by all processes of one scheduler run. The scheduler
// places it in an anonymous MAP_SHARED mapping, so a write made by any forked
// worker is visible to every other worker and to the original caller.
struct CollectiveSharedState {
    int nextWorkerId;   // next worker id to reserve; bumped with __sync_fetch_and_add
    int creationError;  // set to 1 by a worker whose fork() call failed
};

// Blocks until every direct child of the calling process has terminated,
// retrying waitpid() when it is interrupted by a signal.
//
// Returns true if every child reaped here exited normally with status 0.
// Returns false if any child exited with a non-zero status or was killed by a
// signal, or if waitpid() failed with an error other than EINTR/ECHILD.
// Callers that only need the reaping side effect may ignore the result.
static bool reap_direct_children()
{
    bool allSucceeded = true;

    while (true)
    {
        int status = 0;
        pid_t finished = waitpid(-1, &status, 0);

        if (finished > 0)
        {
            if (!WIFEXITED(status) || WEXITSTATUS(status) != 0)
                allSucceeded = false;
            continue;
        }

        if (finished == -1)
        {
            if (errno == EINTR)
                continue;

            // ECHILD means no children are left: the normal way out.
            // Any other error is unexpected and may leave children unreaped.
            if (errno != ECHILD)
                allSucceeded = false;

            return allSucceeded;
        }
    }
}

// Body of every scheduler worker process.
//
// Spawn phase: the process repeatedly reserves the next worker id from
// shared->nextWorkerId and forks a child for it. The child adopts the reserved
// id and keeps spawning as well, so ids are handed out by whichever processes
// reach the counter first until all workerCount ids are taken. If fork() fails,
// shared->creationError is set and this process stops spawning; the id it had
// just reserved is then never run by anyone.
//
// Work phase: the process runs task indices myWorkerId,
// myWorkerId + workerCount, ... that are below totalTasks, then waits for the
// children it forked.
//
// Returns 0 if every direct child of this process exited with status 0, and 1
// otherwise. run_bounded_collective_scheduler maps this value to the worker's
// exit status, so a failure propagates up through the worker tree.
static int collective_worker_main(
    int myWorkerId,
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    CollectiveSharedState* shared
)
{
    while (true)
    {
        // `shared` points into a MAP_SHARED mapping (set up by the scheduler),
        // so this reservation is atomic across all worker processes.
        int reservedId = __sync_fetch_and_add(&shared->nextWorkerId, 1);

        if (reservedId >= workerCount)
            break;

        pid_t pid = fork();

        if (pid < 0)
        {
            std::cerr << "[PID " << getpid() << "] fork failed: "
                      << std::strerror(errno) << "\n";
            shared->creationError = 1;
            break;
        }

        if (pid == 0)
        {
            // The new child becomes worker `reservedId` and joins spawning.
            myWorkerId = reservedId;
            continue;
        }

        // Parent: go back and try to reserve another id.
    }

    // Strided partition of [0, totalTasks). The index only advances while the
    // next one is still in range, so `taskIndex + workerCount` can never
    // overflow int even when totalTasks is close to INT_MAX.
    for (int taskIndex = myWorkerId; taskIndex < totalTasks; )
    {
        taskFn(taskIndex, myWorkerId, workerCount, ctx);

        if (totalTasks - taskIndex <= workerCount)
            break;

        taskIndex += workerCount;
    }

    return reap_direct_children() ? 0 : 1;
}

// Runs taskFn(taskIndex, workerId, workerCount, ctx) for every taskIndex in
// [0, totalTasks) using workerCount forked worker processes. Worker w handles
// the task indices congruent to w modulo workerCount.
//
// The calling process forks only worker 0; the other workers are spawned by
// the workers themselves (see collective_worker_main). Tasks run in child
// processes, so changes they make to ordinary (non-shared) memory are not
// visible to the caller.
//
// Returns 0 on success. Returns -1 in any of these cases:
//   - the arguments are invalid;
//   - the shared state or the first worker could not be created;
//   - some worker could not be forked;
//   - worker 0 did not exit with status 0 (including being killed by a signal
//     or having an exception escape from taskFn).
int run_bounded_collective_scheduler(
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx
)
{
    if (totalTasks < 0 || workerCount <= 0 || taskFn == nullptr)
    {
        std::cerr << "run_bounded_collective_scheduler: invalid arguments\n";
        return -1;
    }

    CollectiveSharedState* shared =
        static_cast<CollectiveSharedState*>(mmap(
            nullptr,
            sizeof(CollectiveSharedState),
            PROT_READ | PROT_WRITE,
            MAP_SHARED | MAP_ANONYMOUS,
            -1,
            0
        ));

    if (shared == MAP_FAILED)
    {
        perror("mmap failed");
        return -1;
    }

    shared->nextWorkerId = 1;   // id 0 belongs to the worker forked below
    shared->creationError = 0;

    // Flush pending output before forking. Otherwise every worker inherits a
    // copy of the unflushed buffers and would emit it again when it flushes
    // before exiting.
    std::cout.flush();
    std::clog.flush();
    std::fflush(nullptr);

    pid_t first = fork();

    if (first < 0)
    {
        perror("fork failed");
        munmap(shared, sizeof(CollectiveSharedState));
        return -1;
    }

    if (first == 0)
    {
        // Worker 0. Workers forked inside collective_worker_main share this
        // call stack, so every worker process leaves through this branch.
        int workerRc = 1;

        try
        {
            workerRc = collective_worker_main(
                0,
                totalTasks,
                workerCount,
                taskFn,
                ctx,
                shared
            );
        }
        catch (...)
        {
            // An exception must not unwind past this frame: above it is the
            // caller's code, which must not keep running inside a worker.
            // Still wait for any workers this process already spawned.
            std::cerr << "[PID " << getpid()
                      << "] worker terminated by an exception\n";
            reap_direct_children();
        }

        // _exit() does not flush stdio or iostream buffers; flush explicitly
        // so output produced by taskFn is not silently discarded.
        std::cout.flush();
        std::clog.flush();
        std::fflush(nullptr);

        _exit(workerRc == 0 ? 0 : 1);
    }

    // Wait for worker 0 only. A waitpid(-1) here would also block on, and
    // reap, unrelated children of the calling process. Worker 0 returns only
    // after waiting for the workers it spawned, which in turn wait for theirs.
    int status = 0;
    pid_t waited = -1;
    do
    {
        waited = waitpid(first, &status, 0);
    } while (waited == -1 && errno == EINTR);

    bool workerFailed = false;
    if (waited == first)
    {
        workerFailed = !WIFEXITED(status) || WEXITSTATUS(status) != 0;
    }
    else if (errno != ECHILD)
    {
        perror("waitpid failed");
        workerFailed = true;
    }
    // On ECHILD the child was reaped elsewhere (e.g. SIGCHLD is ignored or a
    // handler collected it), so no exit status is available to inspect.

    int rc = (workerFailed || shared->creationError != 0) ? -1 : 0;
    munmap(shared, sizeof(CollectiveSharedState));
    return rc;
}