#include "many_to_many_scheduler.hpp"

#include <algorithm>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <vector>

#include <sched.h>
#include <sys/select.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

using namespace std;

namespace
{
    // Upper bound on how many consecutive task indices the scheduler hands to
    // a worker in a single assignment message.
    constexpr int PIPE_SCHEDULER_CHUNK = 105;

    // Reads exactly `bytes` bytes from `fd` into `buffer`, looping over short
    // reads and retrying when read() is interrupted (EINTR).
    // Returns true once everything was read; false if end-of-file is reached
    // first or read() fails with any other error.
    bool read_full(int fd, void* buffer, size_t bytes)
    {
        char* cursor = static_cast<char*>(buffer);

        for (size_t remaining = bytes; remaining > 0;)
        {
            const ssize_t got = read(fd, cursor, remaining);

            if (got > 0)
            {
                cursor += got;
                remaining -= static_cast<size_t>(got);
                continue;
            }

            // got == 0 is end-of-file; a negative result is only retried
            // when the call was merely interrupted.
            if (got == 0 || errno != EINTR)
                return false;
        }

        return true;
    }

    // Writes exactly `bytes` bytes from `buffer` to `fd`, looping over short
    // writes and retrying when write() is interrupted (EINTR).
    // Returns true once everything was written; false on any other error.
    bool write_full(int fd, const void* buffer, size_t bytes)
    {
        const char* cursor = static_cast<const char*>(buffer);

        for (size_t remaining = bytes; remaining > 0;)
        {
            const ssize_t sent = write(fd, cursor, remaining);

            if (sent >= 0)
            {
                cursor += sent;
                remaining -= static_cast<size_t>(sent);
                continue;
            }

            // Only an interrupted call is worth retrying.
            if (errno != EINTR)
                return false;
        }

        return true;
    }

    // Pins the caller to a single CPU derived from `workerIndex`.
    // Nothing happens when `useAffinity` is false or when the online CPU
    // count cannot be obtained. The outcome of sched_setaffinity is not
    // checked, so pinning is best-effort.
    void apply_affinity_if_requested(int workerIndex, bool useAffinity)
    {
        // sysconf is only consulted when pinning was actually requested.
        const long onlineCpus = useAffinity ? sysconf(_SC_NPROCESSORS_ONLN) : 0;

        if (onlineCpus > 0)
        {
            cpu_set_t mask;
            CPU_ZERO(&mask);

            // Worker indices wrap around when they exceed the CPU count.
            CPU_SET(workerIndex % onlineCpus, &mask);

            sched_setaffinity(0, sizeof(mask), &mask);
        }
    }

    // Writes one int token (value 1) to `requestWriteFd`.
    // Returns whether the whole token was written.
    bool send_request(int requestWriteFd)
    {
        const int token = 1;
        const bool delivered = write_full(requestWriteFd, &token, sizeof(token));
        return delivered;
    }

    // Consumes one int-sized token from `requestReadFd`; its value is not
    // inspected. Returns false when the token could not be read in full.
    bool receive_request(int requestReadFd)
    {
        int token = 0;
        const bool received = read_full(requestReadFd, &token, sizeof(token));
        return received;
    }

    // Writes an assignment to `fd` as two consecutive ints:
    // first `startTask`, then `count`. Returns whether both were written.
    bool send_task_message(int fd, int startTask, int count)
    {
        const int payload[2] = {startTask, count};
        const bool delivered = write_full(fd, payload, sizeof(payload));
        return delivered;
    }

    // Closes, in a freshly forked worker, every pipe end it does not use.
    // For each worker slot the assignment write end and the request read end
    // are closed; for every slot other than `workerIndex` the assignment read
    // end and the request write end are closed as well. The vectors
    // themselves are left unmodified.
    void close_child_unused_fds(vector<int>& parentToChild,
                                vector<int>& childToParent,
                                int workerIndex,
                                int workerCount)
    {
        for (int slot = 0; slot < workerCount; ++slot)
        {
            // Each pipe occupies two adjacent entries: read end, then write end.
            const size_t readEnd = static_cast<size_t>(2 * slot);
            const size_t writeEnd = static_cast<size_t>(2 * slot + 1);

            close(parentToChild[writeEnd]);
            close(childToParent[readEnd]);

            if (slot == workerIndex)
                continue;

            close(parentToChild[readEnd]);
            close(childToParent[writeEnd]);
        }
    }

    // Releases the two descriptors the scheduler keeps for `workerIndex`:
    // the request read end and the assignment write end. Each slot is reset
    // to -1 after closing, so calling this again for the same worker is a
    // no-op.
    void close_parent_worker_fds(vector<int>& parentToChild,
                                 vector<int>& childToParent,
                                 int workerIndex)
    {
        const auto retire = [](int& fd)
        {
            if (fd < 0)
                return;

            close(fd);
            fd = -1;
        };

        retire(childToParent[static_cast<size_t>(2 * workerIndex)]);
        retire(parentToChild[static_cast<size_t>(2 * workerIndex + 1)]);
    }

    // Body of a forked worker process; never returns (it ends in _exit).
    //
    // Repeats until told to stop:
    //   1. write a request token on `requestWriteFd`,
    //   2. read a {startTask, count} pair from `assignmentReadFd`,
    //   3. stop when count <= 0, otherwise call `taskFn` once for each task
    //      index in startTask .. startTask + count - 1, passing
    //      `workerIndex`, `workerCount` and `ctx` through unchanged.
    // The loop also stops as soon as either pipe operation fails.
    //
    // Exit status: 0 after a regular stop or a pipe failure, 1 when `taskFn`
    // threw an exception.
    void run_worker_loop(int workerIndex,
                         int workerCount,
                         int assignmentReadFd,
                         int requestWriteFd,
                         ManyToManyPipeTaskFunction taskFn,
                         void* ctx)
    {
        int exitCode = 0;

        // This function runs in a forked copy of the caller. An exception
        // must not unwind past this frame: it would travel up the copied call
        // stack and let the child continue executing the parent's program.
        try
        {
            while (true)
            {
                if (!send_request(requestWriteFd))
                    break;

                int msg[2] = {0, 0};

                if (!read_full(assignmentReadFd, msg, sizeof(msg)))
                    break;

                const int startTask = msg[0];
                const int count = msg[1];

                // A non-positive count is the scheduler's stop signal.
                if (count <= 0)
                    break;

                for (int offset = 0; offset < count; ++offset)
                    taskFn(startTask + offset, workerIndex, workerCount, ctx);
            }
        }
        catch (...)
        {
            cerr << "worker " << workerIndex
                 << " stopped in many-to-many scheduler: task threw an exception\n";
            exitCode = 1;
        }

        close(assignmentReadFd);
        close(requestWriteFd);

        // _exit (not exit/return) so the child never runs the parent's
        // atexit handlers or returns into the parent's code path.
        _exit(exitCode);
    }

    // Reaps every worker whose pid is positive; entries <= 0 are skipped.
    // Returns 0 only when each reaped worker exited normally with status 0.
    // A waitpid failure (other than EINTR, which is retried) is reported on
    // stderr and makes the result 1, as does any abnormal or non-zero exit.
    int wait_for_workers(const vector<pid_t>& pids)
    {
        int result = 0;

        for (size_t worker = 0; worker != pids.size(); ++worker)
        {
            const pid_t pid = pids[worker];

            if (pid <= 0)
                continue;

            int status = 0;

            for (;;)
            {
                if (waitpid(pid, &status, 0) != -1)
                    break;

                if (errno == EINTR)
                    continue;

                cerr << "waitpid failed for worker " << worker << ": "
                     << strerror(errno) << "\n";
                result = 1;
                break;
            }

            const bool cleanExit = WIFEXITED(status) && WEXITSTATUS(status) == 0;

            if (!cleanExit)
                result = 1;
        }

        return result;
    }
}

// Closes every descriptor still recorded (>= 0) in `fds` and resets its slot
// to -1.
static void close_all_scheduler_pipe_fds(vector<int>& fds)
{
    for (int& fd : fds)
    {
        if (fd >= 0)
        {
            close(fd);
            fd = -1;
        }
    }
}

// Runs task indices 0 .. taskCount - 1 on `workerCount` forked worker
// processes.
//
// Every worker gets two private pipes: it writes request tokens on one and
// reads {startTask, count} assignments from the other. This process waits for
// requests with select() and answers each one with a chunk of at most
// PIPE_SCHEDULER_CHUNK consecutive task indices, or with a zero-count message
// once every task has been handed out.
//
// Parameters:
//   taskCount    number of tasks; must be > 0.
//   workerCount  number of worker processes to fork; must be > 0.
//   taskFn       called in a worker as taskFn(taskIndex, workerIndex,
//                workerCount, ctx); must not be null.
//   ctx          opaque pointer forwarded to taskFn.
//   useAffinity  when true, each worker tries to pin itself to one CPU.
//
// Returns 0 when every task was dispatched and every worker exited with
// status 0; returns 1 for invalid arguments, for pipe/fork/select/write
// failures, when a request descriptor does not fit into an fd_set, when
// workers disappeared before all tasks were dispatched, or when a worker
// terminated abnormally or with a non-zero status.
//
// NOTE(review): an assignment write to a worker that died after sending its
// request raises SIGPIPE in this process; with the default disposition that
// terminates the scheduler instead of reaching the write-failure branch.
// Confirm whether callers ignore SIGPIPE.
int run_many_to_many_pipe_scheduler(int taskCount,
                                    int workerCount,
                                    ManyToManyPipeTaskFunction taskFn,
                                    void* ctx,
                                    bool useAffinity)
{
    if (taskCount <= 0 || workerCount <= 0 || taskFn == nullptr || PIPE_SCHEDULER_CHUNK <= 0)
        return 1;

    // Layout of both vectors: entry 2*w is the read end and entry 2*w + 1 the
    // write end of worker w's pipe; -1 marks "not open in this process".
    vector<int> parentToChild(static_cast<size_t>(2 * workerCount), -1);
    vector<int> childToParent(static_cast<size_t>(2 * workerCount), -1);
    vector<pid_t> pids(static_cast<size_t>(workerCount), -1);

    for (int w = 0; w < workerCount; ++w)
    {
        if (pipe(&parentToChild[static_cast<size_t>(2 * w)]) == -1)
        {
            const int savedErrno = errno;
            cerr << "assignment pipe failed in many-to-many scheduler: "
                 << strerror(savedErrno) << "\n";

            // Do not leak the pipes that were created for earlier workers.
            close_all_scheduler_pipe_fds(parentToChild);
            close_all_scheduler_pipe_fds(childToParent);
            return 1;
        }

        if (pipe(&childToParent[static_cast<size_t>(2 * w)]) == -1)
        {
            const int savedErrno = errno;
            cerr << "request pipe failed in many-to-many scheduler: "
                 << strerror(savedErrno) << "\n";
            close_all_scheduler_pipe_fds(parentToChild);
            close_all_scheduler_pipe_fds(childToParent);
            return 1;
        }

        // The request read end is later passed to FD_SET; a descriptor at or
        // above FD_SETSIZE would write outside the fd_set.
        if (childToParent[static_cast<size_t>(2 * w)] >= FD_SETSIZE)
        {
            cerr << "request pipe descriptor exceeds FD_SETSIZE in many-to-many scheduler\n";
            close_all_scheduler_pipe_fds(parentToChild);
            close_all_scheduler_pipe_fds(childToParent);
            return 1;
        }
    }

    for (int w = 0; w < workerCount; ++w)
    {
        pid_t pid = fork();

        if (pid < 0)
        {
            const int savedErrno = errno;
            cerr << "fork failed in many-to-many scheduler: " << strerror(savedErrno) << "\n";

            // Workers forked so far are waiting for an assignment. Closing
            // every descriptor held here makes their pipe operations fail or
            // hit end-of-file so they terminate; then reap them so neither
            // blocked processes nor zombies are left behind.
            close_all_scheduler_pipe_fds(parentToChild);
            close_all_scheduler_pipe_fds(childToParent);
            wait_for_workers(pids);
            return 1;
        }

        if (pid == 0)
        {
            int assignmentReadFd = parentToChild[static_cast<size_t>(2 * w)];
            int requestWriteFd = childToParent[static_cast<size_t>(2 * w + 1)];

            close_child_unused_fds(parentToChild, childToParent, w, workerCount);
            apply_affinity_if_requested(w, useAffinity);

            // Does not return: the worker loop ends the child with _exit.
            run_worker_loop(w, workerCount, assignmentReadFd, requestWriteFd, taskFn, ctx);
        }

        pids[static_cast<size_t>(w)] = pid;
    }

    // Drop the worker-side ends so that a vanished worker shows up as
    // end-of-file on its request pipe.
    for (int w = 0; w < workerCount; ++w)
    {
        close(parentToChild[static_cast<size_t>(2 * w)]);
        parentToChild[static_cast<size_t>(2 * w)] = -1;

        close(childToParent[static_cast<size_t>(2 * w + 1)]);
        childToParent[static_cast<size_t>(2 * w + 1)] = -1;
    }

    int nextTask = 0;
    int activeWorkers = workerCount;
    int schedulerRc = 0;

    while (activeWorkers > 0)
    {
        fd_set readfds;
        FD_ZERO(&readfds);

        int maxFd = -1;

        // Watch the request pipe of every worker that is still active.
        for (int w = 0; w < workerCount; ++w)
        {
            int requestReadFd = childToParent[static_cast<size_t>(2 * w)];

            if (requestReadFd >= 0)
            {
                FD_SET(requestReadFd, &readfds);

                if (requestReadFd > maxFd)
                    maxFd = requestReadFd;
            }
        }

        if (maxFd < 0)
            break;

        int ready = select(maxFd + 1, &readfds, nullptr, nullptr, nullptr);

        if (ready < 0)
        {
            if (errno == EINTR)
                continue;

            cerr << "select failed in many-to-many scheduler: " << strerror(errno) << "\n";
            schedulerRc = 1;
            break;
        }

        for (int w = 0; w < workerCount && ready > 0; ++w)
        {
            int requestReadFd = childToParent[static_cast<size_t>(2 * w)];

            if (requestReadFd < 0 || !FD_ISSET(requestReadFd, &readfds))
                continue;

            --ready;

            // A failed read (including end-of-file) means the worker is gone.
            if (!receive_request(requestReadFd))
            {
                close_parent_worker_fds(parentToChild, childToParent, w);
                --activeWorkers;
                continue;
            }

            // count stays 0 once the tasks are exhausted; the worker treats
            // that as its stop signal.
            int count = 0;
            int startTask = 0;

            if (nextTask < taskCount)
            {
                startTask = nextTask;
                count = min(PIPE_SCHEDULER_CHUNK, taskCount - nextTask);
                nextTask += count;
            }

            int assignmentWriteFd = parentToChild[static_cast<size_t>(2 * w + 1)];

            if (!send_task_message(assignmentWriteFd, startTask, count))
            {
                cerr << "assignment write failed in many-to-many scheduler: "
                     << strerror(errno) << "\n";
                close_parent_worker_fds(parentToChild, childToParent, w);
                --activeWorkers;
                schedulerRc = 1;
                continue;
            }

            if (count == 0)
            {
                close_parent_worker_fds(parentToChild, childToParent, w);
                --activeWorkers;
            }
        }
    }

    // Closing is idempotent per worker, so this also covers the early exits
    // from the loop above.
    for (int w = 0; w < workerCount; ++w)
        close_parent_worker_fds(parentToChild, childToParent, w);

    // Every worker is gone but some tasks were never handed out: the run is
    // incomplete even if each worker happened to exit with status 0.
    if (schedulerRc == 0 && nextTask < taskCount)
    {
        cerr << "many-to-many scheduler stopped with undispatched tasks\n";
        schedulerRc = 1;
    }

    int waitRc = wait_for_workers(pids);

    return schedulerRc == 0 ? waitRc : schedulerRc;
}
