#include "many_to_many_scheduler.hpp"

#include <iostream>
#include <vector>
#include <algorithm>
#include <cerrno>
#include <cstring>

#include <unistd.h>
#include <signal.h>
#include <sys/wait.h>
#include <sched.h>

using namespace std;

/**
 * Pins the calling process to a single CPU.
 *
 * @param cpuId  Zero-based CPU index to pin to.
 * @return true if the affinity mask was applied; false if cpuId cannot be
 *         represented in a cpu_set_t, or if sched_setaffinity() rejected it
 *         (for example, the CPU is not in the set the process may use).
 */
static bool bind_process_to_cpu_pipe(int cpuId)
{
    // CPU_SET with an index outside [0, CPU_SETSIZE) is undefined behavior.
    if (cpuId < 0 || cpuId >= CPU_SETSIZE)
    {
        return false;
    }

    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(cpuId, &cpuset);

    // pid 0 selects the calling process.
    return sched_setaffinity(0, sizeof(cpuset), &cpuset) == 0;
}

namespace
{

/// Pipe descriptors for one worker. Each pipe is {read end, write end};
/// -1 marks a descriptor that is closed (or was never opened).
struct PipeSchedChannels
{
    int toChild[2] = {-1, -1};   // parent writes task ids, worker reads them
    int toParent[2] = {-1, -1};  // worker writes its executed count, parent reads it
};

/// Closes `fd` if it is open and marks it as closed.
void pipe_sched_close(int& fd)
{
    if (fd >= 0)
    {
        close(fd);
        fd = -1;
    }
}

/// Closes every still-open descriptor held in `channels`.
void pipe_sched_close_all(std::vector<PipeSchedChannels>& channels)
{
    for (PipeSchedChannels& ch : channels)
    {
        pipe_sched_close(ch.toChild[0]);
        pipe_sched_close(ch.toChild[1]);
        pipe_sched_close(ch.toParent[0]);
        pipe_sched_close(ch.toParent[1]);
    }
}

/// Writes one int to a pipe, retrying on EINTR. Returns true on success.
/// A pipe write of sizeof(int) (<= PIPE_BUF) bytes is atomic, so it either
/// transfers the whole value or fails.
bool pipe_sched_write_int(int fd, int value)
{
    for (;;)
    {
        const ssize_t n = write(fd, &value, sizeof(value));

        if (n == static_cast<ssize_t>(sizeof(value)))
        {
            return true;
        }

        if (n < 0 && errno == EINTR)
        {
            continue;
        }

        return false;
    }
}

/// Reads one int from a pipe, retrying on EINTR. Returns false on EOF, on
/// error, or on a short read (not expected here, since every writer sends
/// whole ints atomically).
bool pipe_sched_read_int(int fd, int& value)
{
    for (;;)
    {
        const ssize_t n = read(fd, &value, sizeof(value));

        if (n == static_cast<ssize_t>(sizeof(value)))
        {
            return true;
        }

        if (n < 0 && errno == EINTR)
        {
            continue;
        }

        return false;
    }
}

/// Waits for exactly the given worker pids, retrying on EINTR. It never calls
/// waitpid(-1), which could reap unrelated children of the caller. ECHILD is
/// treated as "already reaped" (e.g. when SIGCHLD is set to SIG_IGN).
/// Returns false if waitpid failed for any other reason; it still waits for
/// the remaining pids in that case.
bool pipe_sched_reap(const std::vector<pid_t>& children)
{
    bool ok = true;

    for (const pid_t pid : children)
    {
        int status = 0;
        pid_t result = -1;

        do
        {
            result = waitpid(pid, &status, 0);
        } while (result < 0 && errno == EINTR);

        if (result < 0 && errno != ECHILD)
        {
            perror("[Many-to-Many Pipe Scheduler] waitpid failed");
            ok = false;
        }
    }

    return ok;
}

/// While alive, ignores SIGPIPE. Without this, writing to a worker whose
/// process is gone would terminate the scheduler; with it, the write fails
/// with EPIPE. The previous disposition is restored on destruction.
/// Note: signal dispositions are process-wide, so other threads also see
/// SIGPIPE ignored while a guard is alive.
class PipeSchedSigpipeGuard
{
public:
    PipeSchedSigpipeGuard()
    {
        struct sigaction ignoreAction{};
        ignoreAction.sa_handler = SIG_IGN;
        sigemptyset(&ignoreAction.sa_mask);
        installed_ = sigaction(SIGPIPE, &ignoreAction, &previous_) == 0;
    }

    ~PipeSchedSigpipeGuard()
    {
        if (installed_)
        {
            sigaction(SIGPIPE, &previous_, nullptr);
        }
    }

    PipeSchedSigpipeGuard(const PipeSchedSigpipeGuard&) = delete;
    PipeSchedSigpipeGuard& operator=(const PipeSchedSigpipeGuard&) = delete;

private:
    struct sigaction previous_{};
    bool installed_ = false;
};

/// Entry point of a forked worker; never returns.
///
/// It runs taskFn for every task id read from `taskFd` until it reads the -1
/// stop sentinel or reaches EOF. It then writes the number of completed tasks
/// to `resultFd` and terminates with _exit(). Using _exit() means the child
/// never returns into the caller's code and never flushes stdio buffers
/// inherited from the parent.
[[noreturn]] void pipe_sched_worker_main(
    int taskFd,
    int resultFd,
    int workerId,
    int workerCount,
    scheduler_task_fn taskFn,
    void* context
)
{
    int executed = 0;
    int exitCode = 0;

    try
    {
        int taskId = 0;

        while (pipe_sched_read_int(taskFd, taskId) && taskId != -1)
        {
            taskFn(taskId, workerId, workerCount, context);
            executed++;
        }
    }
    catch (...)
    {
        // An escaping exception would resume the caller's code inside this
        // child process. Stop here and report what was completed.
        exitCode = 1;
    }

    pipe_sched_write_int(resultFd, executed);

    close(taskFd);
    close(resultFd);

    _exit(exitCode);
}

} // namespace

/**
 * Runs task ids [0, totalTasks) on forked worker processes connected by pipes.
 *
 * workerCount is clamped to [1, totalTasks]. Task t is sent to worker
 * t % workerCount, and that worker calls taskFn(t, workerId, workerCount,
 * context) in its own process. Because workers are separate processes,
 * changes taskFn makes to ordinary (non-shared) memory are not visible to the
 * caller. When enableAffinity is set, worker w is pinned to CPU
 * w % (online CPU count) on a best-effort basis; failures are ignored.
 * The call blocks until every worker it started has exited.
 *
 * @return 0 when totalTasks <= 0 or after all workers finish. A warning is
 *         printed to stderr if the executed-task count differs from
 *         totalTasks. Returns -1 if taskFn is null, if creating a pipe or
 *         forking fails, or if waiting for a worker fails.
 */
int run_many_to_many_pipe_scheduler(
    int totalTasks,
    int workerCount,
    scheduler_task_fn taskFn,
    void* context,
    bool enableAffinity
)
{
    if (totalTasks <= 0)
    {
        return 0;
    }

    if (workerCount <= 0)
    {
        workerCount = 1;
    }

    if (workerCount > totalTasks)
    {
        workerCount = totalTasks;
    }

    if (taskFn == nullptr)
    {
        std::cerr << "[Many-to-Many Pipe Scheduler] taskFn is null." << std::endl;
        return -1;
    }

    long numCPUs = sysconf(_SC_NPROCESSORS_ONLN);

    if (numCPUs <= 0)
    {
        numCPUs = 1;
    }

    // Heap storage: workerCount is only known at run time, and standard C++
    // has no variable-length arrays.
    std::vector<PipeSchedChannels> channels(static_cast<size_t>(workerCount));

    for (PipeSchedChannels& ch : channels)
    {
        if (pipe(ch.toChild) == -1)
        {
            perror("[Many-to-Many Pipe Scheduler] parentToChild pipe failed");
            pipe_sched_close_all(channels);
            return -1;
        }

        if (pipe(ch.toParent) == -1)
        {
            perror("[Many-to-Many Pipe Scheduler] childToParent pipe failed");
            pipe_sched_close_all(channels);
            return -1;
        }
    }

    std::vector<pid_t> children;
    children.reserve(static_cast<size_t>(workerCount));

    for (int workerId = 0; workerId < workerCount; workerId++)
    {
        const pid_t pid = fork();

        if (pid < 0)
        {
            perror("[Many-to-Many Pipe Scheduler] fork failed");

            // Closing the parent's write ends gives already-started workers
            // EOF on their task pipes, so they finish and can be reaped.
            pipe_sched_close_all(channels);
            pipe_sched_reap(children);

            return -1;
        }

        if (pid == 0)
        {
            PipeSchedChannels& own = channels[workerId];
            const int taskFd = own.toChild[0];
            const int resultFd = own.toParent[1];
            own.toChild[0] = -1;
            own.toParent[1] = -1;

            // Keep only this worker's two ends. If another worker's task-pipe
            // write end stayed open here, that worker's EOF would be delayed.
            pipe_sched_close_all(channels);

            if (enableAffinity)
            {
                bind_process_to_cpu_pipe(workerId % static_cast<int>(numCPUs));
            }

            pipe_sched_worker_main(taskFd, resultFd, workerId, workerCount, taskFn, context);
        }

        children.push_back(pid);
    }

    // The parent only writes tasks and reads results; drop the workers' ends.
    for (PipeSchedChannels& ch : channels)
    {
        pipe_sched_close(ch.toChild[0]);
        pipe_sched_close(ch.toParent[1]);
    }

    {
        // Limit the SIGPIPE override to the dispatch phase.
        PipeSchedSigpipeGuard sigpipeGuard;

        for (int taskId = 0; taskId < totalTasks; taskId++)
        {
            int& taskWriteFd = channels[taskId % workerCount].toChild[1];

            // A closed descriptor marks a worker whose pipe already failed;
            // skip it rather than report the same failure for every task.
            if (taskWriteFd < 0)
            {
                continue;
            }

            if (!pipe_sched_write_int(taskWriteFd, taskId))
            {
                perror("[Many-to-Many Pipe Scheduler] write task failed");
                pipe_sched_close(taskWriteFd);
            }
        }

        const int stop = -1;

        for (PipeSchedChannels& ch : channels)
        {
            if (ch.toChild[1] >= 0)
            {
                pipe_sched_write_int(ch.toChild[1], stop);
            }

            pipe_sched_close(ch.toChild[1]);
        }
    }

    int totalExecuted = 0;

    for (PipeSchedChannels& ch : channels)
    {
        int executed = 0;

        if (pipe_sched_read_int(ch.toParent[0], executed))
        {
            totalExecuted += executed;
        }

        pipe_sched_close(ch.toParent[0]);
    }

    const bool reapedAll = pipe_sched_reap(children);

    if (totalExecuted != totalTasks)
    {
        std::cerr << "[Many-to-Many Pipe Scheduler] Warning: executed "
                  << totalExecuted
                  << " tasks, expected "
                  << totalTasks
                  << std::endl;
    }

    return reapedAll ? 0 : -1;
}