#include "bounded_collective_scheduler_log.hpp"

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <iostream>

#include <unistd.h>
#include <sys/wait.h>
#include <sched.h>

/// Pins the calling process (pid 0 = self) to the single CPU `cpuId`.
///
/// @param cpuId  Zero-based CPU number. A cpu_set_t can only represent CPUs
///               0 .. CPU_SETSIZE-1, so ids outside that range are rejected
///               up front (errno = EINVAL) instead of being handed to CPU_SET().
/// @return true if sched_setaffinity() succeeded; false otherwise, with errno
///         describing the failure.
static bool bind_process_to_cpu_local(int cpuId)
{
    if (cpuId < 0 || cpuId >= CPU_SETSIZE)
    {
        errno = EINVAL;
        return false;
    }

    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(cpuId, &cpuset);

    return sched_setaffinity(0, sizeof(cpuset), &cpuset) == 0;
}

/// Waits for every child of the calling process to terminate.
///
/// Retries waitpid() on EINTR and stops once it reports ECHILD (no children
/// left). Any other waitpid() error stops the loop early.
///
/// @return true if every reaped child exited normally with status 0 and the
///         loop ended with ECHILD; false if any child exited non-zero, was
///         terminated by a signal, or waitpid() failed unexpectedly. Callers
///         that only need the children reaped may ignore the result.
static bool reap_direct_children()
{
    bool allSucceeded = true;

    while (true)
    {
        int status = 0;
        const pid_t finished = waitpid(-1, &status, 0);

        if (finished > 0)
        {
            if (!WIFEXITED(status) || WEXITSTATUS(status) != 0)
                allSucceeded = false;
            continue;
        }

        if (finished == -1 && errno == EINTR)
            continue;

        // ECHILD means no children remain: the normal way out. Any other
        // outcome may leave children unreaped, so report failure.
        if (finished == -1 && errno == ECHILD)
            return allSucceeded;

        return false;
    }
}

/// Returns the smallest L >= 0 with 2^L >= x (0 for any x <= 1).
static int ceil_log2_int(int x)
{
    int levels = 0;
    // 64-bit accumulator: doubling past INT_MAX (x > 2^30) must not overflow.
    long long span = 1;

    while (span < x)
    {
        span *= 2;
        ++levels;
    }

    return levels;
}

/// Pins the calling process to CPU (workerId % numCPUs), logging on failure.
/// Binding failure is not fatal: the worker keeps running unpinned.
static void bind_worker_to_cpu_or_warn(int workerId, long numCPUs)
{
    const int cpuId = workerId % static_cast<int>(numCPUs);

    if (!bind_process_to_cpu_local(cpuId))
    {
        std::cerr << "[PID " << getpid()
                  << "] failed to bind to CPU "
                  << cpuId << "\n";
    }
}

/// Body of one worker process in the fork tree.
///
/// Starting at `startLevel`, worker `myWorkerId` forks child worker
/// `myWorkerId + 2^level` for each level below `levels` whose id is
/// < workerCount. A forked child adopts that id and continues the same loop
/// from level + 1, so it forks the remainder of its own subtree. After
/// forking, the worker runs every task index in [0, totalTasks) that is
/// congruent to its id modulo workerCount, then waits for its direct children.
///
/// Every forked child also returns from this function, so the caller is
/// responsible for terminating the process with the returned status.
///
/// @return 0 on success; -1 if a fork() failed in this worker (its own tasks
///         are then skipped) or if any direct child, and therefore some part of
///         its subtree, did not exit with status 0.
static int collective_log_worker_main(
    int myWorkerId,
    int startLevel,
    int levels,
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity
)
{
    long numCPUs = sysconf(_SC_NPROCESSORS_ONLN);
    if (numCPUs <= 0)
        numCPUs = 1;

    if (bindAffinity)
        bind_worker_to_cpu_or_warn(myWorkerId, numCPUs);

    for (int level = startLevel; level < levels; level++)
    {
        const int childWorkerId = myWorkerId + (1 << level);

        // Child ids grow with the level, so no later level can be in range.
        if (childWorkerId >= workerCount)
            break;

        const pid_t pid = fork();

        if (pid < 0)
        {
            const int forkErrno = errno;
            std::cerr << "[PID " << getpid() << "] fork failed: "
                      << std::strerror(forkErrno) << "\n";
            // Don't orphan subtrees that were already forked: wait for them
            // before reporting failure to our parent.
            reap_direct_children();
            return -1;
        }

        if (pid == 0)
        {
            // Child process: take over the new id and keep descending from
            // level + 1 in this same loop.
            myWorkerId = childWorkerId;

            if (bindAffinity)
                bind_worker_to_cpu_or_warn(myWorkerId, numCPUs);
        }
    }

    for (int taskIndex = myWorkerId; taskIndex < totalTasks; )
    {
        taskFn(taskIndex, myWorkerId, workerCount, ctx);

        // Stop before `taskIndex += workerCount` could overflow int.
        if (totalTasks - taskIndex <= workerCount)
            break;

        taskIndex += workerCount;
    }

    return reap_direct_children() ? 0 : -1;
}

/// Flushes the iostream and C stdio output buffers of the calling process.
///
/// Needed around fork()/_exit(): a forked child inherits any unflushed
/// buffers, which would be written twice if both processes later flush them,
/// and _exit() discards whatever is still buffered.
static void flush_output_streams()
{
    std::cout.flush();
    std::clog.flush();
    std::fflush(nullptr);
}

/// Runs taskFn for every task index in [0, totalTasks) across up to
/// workerCount forked worker processes, and blocks until they finish.
///
/// workerCount is clamped to totalTasks. The workers are forked as a tree of
/// depth ceil(log2(workerCount)) rooted at a single child of the caller.
/// taskFn runs in those child processes: memory it modifies through ctx is
/// not visible to the caller unless that memory is shared between processes.
///
/// @param totalTasks    number of task indices; must be >= 0 (0 is a no-op).
/// @param workerCount   requested number of workers; must be > 0.
/// @param taskFn        callback invoked as taskFn(taskIndex, workerId,
///                      workerCount, ctx); must not be null.
/// @param ctx           opaque pointer forwarded unchanged to taskFn.
/// @param bindAffinity  if true, each worker tries to pin itself to one CPU.
/// @return 0 if the root worker process exited normally with status 0;
///         -1 on invalid arguments, fork()/waitpid() failure, or worker failure.
int run_log_bounded_collective_scheduler(
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity
)
{
    if (totalTasks < 0 || workerCount <= 0 || taskFn == nullptr)
    {
        std::cerr << "run_log_bounded_collective_scheduler: invalid arguments\n";
        return -1;
    }

    if (totalTasks == 0)
        return 0;

    if (workerCount > totalTasks)
        workerCount = totalTasks;

    const int levels = ceil_log2_int(workerCount);

    // Workers flush before _exit() (below); empty our buffers first so they
    // do not inherit, and then re-emit, output the caller had pending.
    flush_output_streams();

    const pid_t first = fork();

    if (first < 0)
    {
        std::perror("fork failed");
        return -1;
    }

    if (first == 0)
    {
        // Every worker in the tree (this process and its descendants) returns
        // here, since they are all forked from inside this call.
        const int rc = collective_log_worker_main(
            0,
            0,
            levels,
            totalTasks,
            workerCount,
            taskFn,
            ctx,
            bindAffinity
        );

        // _exit() skips stdio/iostream flushing; without this, buffered
        // output written by taskFn in this worker would be lost.
        flush_output_streams();
        _exit(rc == 0 ? 0 : 1);
    }

    int status = 0;

    while (true)
    {
        const pid_t finished = waitpid(first, &status, 0);

        if (finished == first)
            break;

        if (finished == -1)
        {
            if (errno == EINTR)
                continue;

            std::perror("waitpid failed");
            return -1;
        }
    }

    return (WIFEXITED(status) && WEXITSTATUS(status) == 0) ? 0 : -1;
}