#include "bounded_collective_scheduler.hpp"

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

#include <unistd.h>
#include <sys/wait.h>
#include <sched.h>

// Pins the calling process (pid argument 0 means "self") to the single CPU
// `cpuId`. Returns true if sched_setaffinity accepted the one-CPU mask and
// false otherwise; on failure errno is left as sched_setaffinity set it.
static bool bind_process_to_cpu_local(int cpuId)
{
    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(cpuId, &mask);

    const int rc = sched_setaffinity(0, sizeof mask, &mask);
    return rc == 0;
}

// Waits for one specific worker pid, retrying if interrupted by a signal.
// Returns true if the worker exited normally with status 0. Also returns true
// on ECHILD: the worker was already reaped elsewhere (for example SIGCHLD set
// to SIG_IGN, or a caller-installed SIGCHLD handler that calls wait), so its
// status is unavailable. Logs and returns false for any other outcome.
static bool wait_for_worker(pid_t pid)
{
    int status = 0;
    pid_t rc = -1;
    do
    {
        rc = waitpid(pid, &status, 0);
    } while (rc == -1 && errno == EINTR);

    if (rc == -1)
    {
        const int err = errno;
        if (err == ECHILD)
            return true;
        std::cerr << "waitpid(" << pid << ") failed: "
                  << std::strerror(err) << "\n";
        return false;
    }

    if (WIFEXITED(status) && WEXITSTATUS(status) == 0)
        return true;

    if (WIFSIGNALED(status))
    {
        std::cerr << "[PID " << pid << "] worker killed by signal "
                  << WTERMSIG(status) << "\n";
    }
    else
    {
        std::cerr << "[PID " << pid << "] worker exited with status "
                  << WEXITSTATUS(status) << "\n";
    }
    return false;
}

// Reaps every pid in `workers` (all of them, even after a failure).
// Returns true only if every worker succeeded per wait_for_worker.
static bool reap_workers(const std::vector<pid_t>& workers)
{
    bool allOk = true;
    for (const pid_t pid : workers)
    {
        if (!wait_for_worker(pid))
            allOk = false;
    }
    return allOk;
}

/// Runs task indices [0, totalTasks) across forked worker processes.
///
/// Tasks are split round-robin: worker w runs tasks w, w + workerCount,
/// w + 2*workerCount, ... (< totalTasks), calling
/// taskFn(task, w, workerCount, ctx) for each one. The calling process only
/// forks and reaps; it never runs tasks itself. A worker that would own no
/// task (w >= totalTasks) is not forked. The stride passed to taskFn is still
/// workerCount.
///
/// Only the worker pids forked here are waited on. Other children of the
/// caller are not touched.
///
/// @param totalTasks   number of task indices; must be > 0.
/// @param workerCount  round-robin stride / maximum number of workers; must be > 0.
/// @param taskFn       callback invoked inside a child process; must be non-null.
/// @param ctx          opaque pointer forwarded to taskFn. Each worker runs in
///                     its own forked address space, so writes through ctx are
///                     not visible to the caller unless ctx points at memory
///                     explicitly shared between processes.
/// @param bindAffinity if true, worker w pins itself to CPU (w % online CPUs).
///                     A failed bind is logged and the worker keeps running.
/// @return 0 if every worker exited normally with status 0. Returns -1 on
///         invalid arguments, on fork failure (workers already started are
///         still reaped), or if any worker failed (non-zero exit, killed by
///         a signal, or taskFn threw).
int run_bounded_collective_scheduler(
    int totalTasks,
    int workerCount,
    collective_task_fn taskFn,
    void* ctx,
    bool bindAffinity)
{
    if (totalTasks <= 0 || workerCount <= 0 || taskFn == nullptr)
        return -1;

    long numCPUs = sysconf(_SC_NPROCESSORS_ONLN);
    if (numCPUs <= 0)
        numCPUs = 1;

    // Worker w owns tasks starting at index w, so workers with
    // w >= totalTasks would have nothing to do.
    const int activeWorkers = std::min(workerCount, totalTasks);

    // Flush pending output now. Otherwise every child inherits a copy of the
    // unflushed buffer, and the text is emitted once per process.
    std::cout.flush();
    std::fflush(nullptr);

    // Reserving up front means push_back below never allocates (or throws)
    // after a child has already been forked.
    std::vector<pid_t> workers;
    workers.reserve(static_cast<std::size_t>(activeWorkers));

    for (int w = 0; w < activeWorkers; ++w)
    {
        const pid_t pid = fork();

        if (pid < 0)
        {
            const int err = errno;  // capture before any stream I/O can clobber it
            std::cerr << "fork failed: " << std::strerror(err) << "\n";
            // Best effort: reap the workers already started so they do not
            // linger as zombies.
            reap_workers(workers);
            return -1;
        }

        if (pid == 0)
        {
            // Child: optionally pin to a CPU, run this worker's share of the
            // tasks, then leave via _exit. That way neither the caller's code
            // after this function nor atexit handlers run in the child.
            int exitCode = 0;
            try
            {
                if (bindAffinity)
                {
                    const int cpuId = w % static_cast<int>(numCPUs);
                    if (!bind_process_to_cpu_local(cpuId))
                    {
                        std::cerr << "[PID " << getpid()
                                  << "] failed to bind to CPU "
                                  << cpuId << "\n";
                    }
                }

                // Advance only when the next index is still in range. Writing
                // `t += workerCount` unconditionally could overflow int when
                // totalTasks is near INT_MAX.
                for (int t = w; t < totalTasks; )
                {
                    taskFn(t, w, workerCount, ctx);
                    if (totalTasks - t <= workerCount)
                        break;
                    t += workerCount;
                }
            }
            catch (...)
            {
                // Letting the exception escape would unwind into the caller's
                // frames inside the child, which would then keep executing the
                // caller's code as a duplicate process.
                std::cerr << "[PID " << getpid()
                          << "] task threw an exception\n";
                exitCode = 1;
            }

            // _exit does not flush stdio/iostream buffers. Flush them so that
            // output written by taskFn is not silently dropped.
            std::cout.flush();
            std::fflush(nullptr);
            _exit(exitCode);
        }

        // Parent: remember this worker and continue forking.
        workers.push_back(pid);
    }

    return reap_workers(workers) ? 0 : -1;
}
