#include "bounded_collective_scheduler_log.hpp"

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

#include <sched.h>
#include <sys/wait.h>
#include <unistd.h>


/*
    Logarithmic bounded collective scheduler

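    Idea:
    workerCount is a power of two, workerCount = 2^levels (other values are
    first clamped to totalTasks and then rounded down to a power of two).
    The collective scheduler creates the workers through a balanced binary
    process tree.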

    Example for workerCount = 8:

        level 0: 0 forks 1
        level 1: 0 forks 2, 1 forks 3
        level 2: 0 forks 4, 1 forks 5, 2 forks 6, 3 forks 7

    After the tree is created, every worker executes:

        taskIndex = workerId;
        taskIndex < totalTasks;
        taskIndex += workerCount;

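    All workers therefore use the same strided task distribution as the
    prolific version, but spawning is collective: every existing process forks
    at each level, so the tree is built in log2(workerCount) rounds instead of
    having a single parent fork every worker.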
*/

/*
    Restricts the calling process (pid 0 means "self") to the single CPU
    `cpuId`. Returns true when sched_setaffinity accepted the mask, false
    otherwise (e.g. the CPU is not available to this process).
*/
static bool bind_process_to_cpu_local(int cpuId)
{
    cpu_set_t mask{};
    CPU_ZERO(&mask);
    CPU_SET(cpuId, &mask);

    const int rc = sched_setaffinity(0, sizeof(mask), &mask);
    return rc == 0;
}

/*
    Blocks until every direct child of the calling process has been collected.
    A wait interrupted by a signal (EINTR) is retried; ECHILD (nothing left to
    wait for) or any other waitpid error ends the loop. Exit statuses are
    collected but not inspected.
*/
static void reap_direct_children()
{
    int childStatus = 0;

    for (;;)
    {
        const pid_t reaped = waitpid(-1, &childStatus, 0);

        if (reaped == -1 && errno != EINTR)
            return;
    }
}

/*
    True when x is a positive power of two. For positive x, (x & -x) isolates
    the lowest set bit, which equals x exactly when it is the only set bit.
*/
static bool is_power_of_two(int x)
{
    if (x < 1)
        return false;

    return (x & -x) == x;
}

/*
    Largest power of two that does not exceed x. Any x below 2 (including
    zero and negatives) yields 1. Works by repeatedly clearing the lowest
    set bit until only the highest one remains.
*/
static int floor_power_of_two(int x)
{
    if (x < 2)
        return 1;

    while ((x & (x - 1)) != 0)
        x &= x - 1;

    return x;
}

/*
    floor(log2(x)) for x >= 1, i.e. the number of times x can be halved
    before reaching 1. Any x <= 1 yields 0.
*/
static int log2_int(int x)
{
    int shifts = 0;

    for (int value = x; value > 1; value >>= 1)
        ++shifts;

    return shifts;
}

/*
    Waits for each PID in `children` (retrying on EINTR).
    Returns true only if every one of them exited normally with status 0.
*/
static bool wait_for_forked_children(const std::vector<pid_t>& children)
{
    bool allSucceeded = true;

    for (const pid_t child : children)
    {
        int status = 0;
        pid_t rc;

        do
        {
            rc = waitpid(child, &status, 0);
        } while (rc == -1 && errno == EINTR);

        if (rc != child || !WIFEXITED(status) || WEXITSTATUS(status) != 0)
            allSucceeded = false;
    }

    return allSucceeded;
}

/*
    Runs one worker of the collective process tree.

    The calling process acts as worker `myWorkerId`. For each level in
    [startLevel, levels) it forks worker `myWorkerId + 2^level` (when that id
    is below workerCount). The forked child adopts the new id and continues
    from the next level, so every process created so far takes part in each
    doubling step.

    After the tree is built, the process runs tasks myWorkerId,
    myWorkerId + workerCount, ... below totalTasks, then waits for the
    children it forked.

    Parameters:
        myWorkerId    id of the calling process within the tree
        startLevel    first level at which this process forks
        levels        total number of levels (workerCount == 2^levels)
        totalTasks    number of task indices to distribute
        workerCount   number of workers in the tree
        taskFn        callback invoked as taskFn(taskIndex, workerId, workerCount, ctx)
        ctx           opaque pointer forwarded to taskFn
        bindAffinity  when true, pin each worker to CPU (workerId % online CPUs)

    Returns 0 when this worker and every child it forked succeeded, and -1
    otherwise: a fork failed here, or a direct child exited non-zero or was
    killed. Every descendant reports the same way, so a failure anywhere in
    the tree propagates up to worker 0.

    NOTE: this function returns in *every* process of the tree; the caller is
    responsible for terminating the process afterwards.
*/
static int collective_log_worker_main(
    int myWorkerId,
    int startLevel,
    int levels,
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity
)
{
    long numCPUs = sysconf(_SC_NPROCESSORS_ONLN);
    if (numCPUs <= 0)
        numCPUs = 1;

    // Best-effort CPU pinning: failures are reported but not fatal.
    auto bindToWorkerCpu = [&](int workerId)
    {
        const int cpuId = workerId % static_cast<int>(numCPUs);

        if (!bind_process_to_cpu_local(cpuId))
        {
            std::cerr << "[PID " << getpid() << "] failed to bind to CPU "
                      << cpuId << "\n";
        }
    };

    if (bindAffinity)
        bindToWorkerCpu(myWorkerId);

    // PIDs of the children forked by *this* process. A freshly forked child
    // inherits a copy of this list and must clear it, since those PIDs are
    // its siblings, not its children.
    std::vector<pid_t> children;
    bool ok = true;

    for (int level = startLevel; level < levels; level++)
    {
        const int childWorkerId = myWorkerId + (1 << level);

        if (childWorkerId >= workerCount)
            continue;

        const pid_t pid = fork();

        if (pid < 0)
        {
            // Snapshot errno before the stream insertions can overwrite it.
            const int forkErrno = errno;
            std::cerr << "[PID " << getpid() << "] fork failed: "
                      << std::strerror(forkErrno) << "\n";

            // Stop building the tree, but still wait for the children that
            // were already created below so none of them is orphaned.
            ok = false;
            break;
        }

        if (pid == 0)
        {
            myWorkerId = childWorkerId;
            children.clear();

            if (bindAffinity)
                bindToWorkerCpu(myWorkerId);

            /*
                The child continues from the next level.
                This is what makes the tree balanced:
                every existing process participates in the next doubling step.
            */
            continue;
        }

        children.push_back(pid);
    }

    if (ok)
    {
        // 64-bit index so `+= workerCount` cannot overflow when totalTasks
        // is close to INT_MAX; every value passed to taskFn is < totalTasks.
        for (long long taskIndex = myWorkerId; taskIndex < totalTasks; taskIndex += workerCount)
        {
            taskFn(static_cast<int>(taskIndex), myWorkerId, workerCount, ctx);
        }
    }

    if (!wait_for_forked_children(children))
        ok = false;

    // Collect any remaining direct children (e.g. processes taskFn may have
    // forked without waiting for them), as the original loop did.
    reap_direct_children();

    return ok ? 0 : -1;
}

/*
    Runs taskFn over task indices [0, totalTasks) using a tree of worker
    processes built by collective_log_worker_main.

    workerCount is clamped to totalTasks and, if it is not a power of two,
    rounded down to one (with a warning on stderr). The calling process forks
    worker 0, which builds the rest of the tree, and then waits only for that
    one child.

    Returns 0 when worker 0 exits with status 0, and -1 on invalid arguments,
    fork/waitpid failure, or a non-zero / signalled worker 0. totalTasks == 0
    is a successful no-op.
*/
int run_log_bounded_collective_scheduler(
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity
)
{
    if (totalTasks < 0 || workerCount <= 0 || taskFn == nullptr)
    {
        std::cerr << "run_log_bounded_collective_scheduler: invalid arguments\n";
        return -1;
    }

    if (totalTasks == 0)
        return 0;

    if (workerCount > totalTasks)
        workerCount = totalTasks;

    if (!is_power_of_two(workerCount))
    {
        int adjusted = floor_power_of_two(workerCount);

        std::cerr << "[Collective log scheduler] workerCount must be a power of 2. "
                  << "Using " << adjusted << " instead of " << workerCount << ".\n";

        workerCount = adjusted;
    }

    int levels = log2_int(workerCount);

    // Flush pending output before forking. Otherwise buffered bytes would be
    // copied into every worker and could be written more than once once the
    // workers flush before exiting (below).
    std::cout.flush();
    std::fflush(nullptr);

    pid_t first = fork();

    if (first < 0)
    {
        std::perror("fork failed");
        return -1;
    }

    if (first == 0)
    {
        // Every process of the tree returns here after its share of work.
        int rc = collective_log_worker_main(
            0,
            0,
            levels,
            totalTasks,
            workerCount,
            taskFn,
            ctx,
            bindAffinity
        );

        // _exit skips stdio flushing, so push out anything taskFn wrote to
        // buffered streams; otherwise that output would be silently lost.
        std::cout.flush();
        std::fflush(nullptr);

        _exit(rc == 0 ? 0 : 1);
    }

    int status = 0;
    while (true)
    {
        pid_t finished = waitpid(first, &status, 0);

        if (finished == first)
            break;

        if (finished == -1)
        {
            if (errno == EINTR)
                continue;

            std::perror("waitpid failed");
            return -1;
        }
    }

    if (WIFEXITED(status) && WEXITSTATUS(status) == 0)
        return 0;

    return -1;
}
