#include "bounded_prolific_scheduler.hpp"

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>
#include <vector>

#include <sched.h>
#include <sys/wait.h>
#include <unistd.h>

using namespace std;

/// Pins the calling process to a single CPU via sched_setaffinity
/// (pid 0 means "the calling process").
///
/// @param cpuId  Zero-based CPU index. Ids outside [0, CPU_SETSIZE) cannot be
///               represented in a cpu_set_t and are rejected without a syscall.
/// @return true if the kernel accepted the affinity mask, false otherwise.
///         errno is only set when sched_setaffinity itself fails (e.g. EINVAL
///         when the CPU is offline or not permitted for this process).
bool bind_process_to_cpu_scheduler(int cpuId)
{
    // Reject ids a fixed-size cpu_set_t cannot hold instead of relying on the
    // CPU_SET macro's implementation-specific handling of out-of-range bits.
    if (cpuId < 0 || cpuId >= CPU_SETSIZE)
    {
        return false;
    }

    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(cpuId, &cpuset);

    return sched_setaffinity(0, sizeof(cpuset), &cpuset) == 0;
}

// Blocks until every pid in `workers` has terminated and reports whether all of
// them exited normally with status 0. Each pid is waited for explicitly (not via
// waitpid(-1)) so children of the caller that this scheduler did not create are
// never reaped here; EINTR is retried so a signal cannot leave workers unreaped.
static bool reap_workers(const std::vector<pid_t>& workers)
{
    bool allSucceeded = true;

    for (pid_t worker : workers)
    {
        int status = 0;
        pid_t rc = 0;
        do
        {
            rc = waitpid(worker, &status, 0);
        } while (rc < 0 && errno == EINTR);

        if (rc < 0)
        {
            // ECHILD: the status is no longer available (e.g. SIGCHLD is ignored
            // and the kernel reaped the child itself). Not treated as a failure.
            if (errno != ECHILD)
            {
                perror("[Scheduler] waitpid failed");
                allSucceeded = false;
            }
            continue;
        }

        if (WIFEXITED(status) && WEXITSTATUS(status) == 0)
        {
            continue;
        }

        allSucceeded = false;
        if (WIFSIGNALED(status))
        {
            std::cerr << "[Scheduler] worker " << worker
                      << " was killed by signal " << WTERMSIG(status) << ".\n";
        }
        else
        {
            std::cerr << "[Scheduler] worker " << worker
                      << " exited with status " << WEXITSTATUS(status) << ".\n";
        }
    }

    return allSucceeded;
}

/// Runs `totalTasks` tasks across up to `workerCount` forked worker processes
/// and waits for all of them.
///
/// Tasks are distributed round-robin: worker `w` runs task ids
/// w, w + workerCount, w + 2 * workerCount, ... that are below `totalTasks`.
/// Each worker is a separate process created with fork(), so it operates on a
/// copy of the caller's memory; writes made by taskFn are not visible to the
/// caller unless `context` refers to memory shared between processes.
///
/// @param totalTasks     Number of tasks; <= 0 returns 0 without forking.
/// @param workerCount    Requested worker processes; clamped to [1, totalTasks].
/// @param taskFn         Invoked in a worker as
///                       taskFn(taskId, workerId, workerCount, context).
///                       Must not be null.
/// @param context        Opaque pointer forwarded unchanged to taskFn.
/// @param enableAffinity If true, each worker tries to pin itself to CPU
///                       (workerId + 1) % <online CPUs>; pinning is best effort.
/// @return 0 if every worker exited cleanly; -1 if taskFn is null, a fork()
///         failed, or any worker failed (exception escaping taskFn, non-zero
///         exit status, or death by signal).
int run_bounded_prolific_scheduler(
    int totalTasks,
    int workerCount,
    scheduler_task_fn taskFn,
    void* context,
    bool enableAffinity
)
{
    if (totalTasks <= 0)
    {
        return 0;
    }

    if (workerCount <= 0)
    {
        workerCount = 1;
    }

    if (workerCount > totalTasks)
    {
        workerCount = totalTasks;
    }

    if (taskFn == nullptr)
    {
        std::cerr << "[Scheduler] taskFn is null." << '\n';
        return -1;
    }

    long numCPUs = sysconf(_SC_NPROCESSORS_ONLN);
    if (numCPUs <= 0)
    {
        numCPUs = 1;
    }

    // Flush buffered output before forking; otherwise every child inherits a
    // copy of the unflushed buffer and that text may be written several times.
    std::cout.flush();
    fflush(nullptr);

    std::vector<pid_t> children;
    children.reserve(static_cast<size_t>(workerCount));

    for (int workerId = 0; workerId < workerCount; ++workerId)
    {
        pid_t pid = fork();

        if (pid < 0)
        {
            perror("[Scheduler] fork failed");
            // Workers that were already started keep running; wait for them so
            // no zombies are left behind, then report the failure.
            reap_workers(children);
            return -1;
        }

        if (pid == 0)
        {
            // Child: run this worker's share of the tasks and always leave via
            // _exit(), so the child never returns (or unwinds) into the caller's
            // code and never runs the parent's atexit handlers/static destructors.
            int exitCode = EXIT_SUCCESS;
            try
            {
                if (enableAffinity)
                {
                    // Best effort: on failure the worker keeps its inherited affinity.
                    int cpuId = static_cast<int>((workerId + 1) % numCPUs);
                    bind_process_to_cpu_scheduler(cpuId);
                }

                // 64-bit counter: taskId + workerCount could overflow int when
                // totalTasks is close to INT_MAX.
                for (long long taskId = workerId; taskId < totalTasks; taskId += workerCount)
                {
                    taskFn(static_cast<int>(taskId), workerId, workerCount, context);
                }
            }
            catch (const std::exception& e)
            {
                std::cerr << "[Scheduler] worker " << workerId
                          << " threw: " << e.what() << '\n';
                exitCode = EXIT_FAILURE;
            }
            catch (...)
            {
                std::cerr << "[Scheduler] worker " << workerId
                          << " threw an unknown exception." << '\n';
                exitCode = EXIT_FAILURE;
            }

            // _exit() does not flush stdio/iostream buffers; flush explicitly so
            // output produced by taskFn in this worker is not lost.
            std::cout.flush();
            fflush(nullptr);
            _exit(exitCode);
        }

        children.push_back(pid);
    }

    return reap_workers(children) ? 0 : -1;
}