#include "one_to_many_scheduler.hpp"

#include <algorithm>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <vector>

#include <sched.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

using namespace std;

namespace
{
    // Upper bound on how many consecutive task indices the scheduler packs
    // into a single assignment message sent to a worker.
    constexpr int PIPE_SCHEDULER_CHUNK = 10;

    // Reads exactly `bytes` bytes from `fd` into `buffer`, retrying any read
    // that was interrupted by a signal (EINTR).
    //
    // Returns true once the whole request has been satisfied (trivially so
    // when `bytes` is 0). Returns false if end-of-file is hit first or if
    // read() fails with an error other than EINTR; in that case a prefix of
    // `buffer` may already have been overwritten.
    bool read_full(int fd, void* buffer, size_t bytes)
    {
        char* const begin = static_cast<char*>(buffer);
        size_t done = 0;

        while (done < bytes)
        {
            const ssize_t got = read(fd, begin + done, bytes - done);

            if (got > 0)
            {
                done += static_cast<size_t>(got);
                continue;
            }

            // got == 0 means end-of-file, got < 0 means failure. Only an
            // interrupted call is worth another attempt.
            const bool interrupted = (got < 0) && (errno == EINTR);

            if (!interrupted)
                return false;
        }

        return true;
    }

    // Writes all `bytes` bytes of `buffer` to `fd`, resuming after short
    // writes and retrying calls interrupted by a signal (EINTR).
    //
    // Returns true when everything has been written (trivially so when
    // `bytes` is 0) and false as soon as write() reports any error other than
    // EINTR; some leading bytes may already have been written by then.
    bool write_full(int fd, const void* buffer, size_t bytes)
    {
        const char* const begin = static_cast<const char*>(buffer);
        size_t sent = 0;

        while (sent < bytes)
        {
            const ssize_t put = write(fd, begin + sent, bytes - sent);

            if (put >= 0)
            {
                sent += static_cast<size_t>(put);
                continue;
            }

            if (errno != EINTR)
                return false;
        }

        return true;
    }

    // Optionally restricts the calling process to a single CPU chosen as
    // `workerIndex` modulo the number of online processors.
    //
    // Does nothing when `useAffinity` is false or when the processor count
    // cannot be determined. Pinning is best-effort: the result of
    // sched_setaffinity() is not inspected.
    void apply_affinity_if_requested(int workerIndex, bool useAffinity)
    {
        if (!useAffinity)
            return;

        const long onlineCpus = sysconf(_SC_NPROCESSORS_ONLN);

        if (onlineCpus < 1)
            return;

        const long targetCpu = workerIndex % onlineCpus;

        cpu_set_t mask;
        CPU_ZERO(&mask);
        CPU_SET(targetCpu, &mask);

        // pid 0 addresses the calling process itself.
        (void)sched_setaffinity(0, sizeof(mask), &mask);
    }

    // Sends one assignment message over `fd`: two native ints, the first task
    // index followed by the number of tasks. Returns whatever write_full()
    // reports for the transfer.
    bool send_task_message(int fd, int startTask, int count)
    {
        const int payload[] = {startTask, count};
        const bool delivered = write_full(fd, payload, sizeof(payload));
        return delivered;
    }

    // Closes, in the calling process, every pipe descriptor a worker does not
    // need. `parentToChild` holds one pipe per worker: slot 2*w is the read
    // end and slot 2*w + 1 the write end. All write ends are closed, as are
    // the read ends belonging to other workers; only the read end of worker
    // `workerIndex` stays open. The vector's contents are not modified.
    void close_child_unused_fds(vector<int>& parentToChild, int workerIndex, int workerCount)
    {
        for (int other = 0; other < workerCount; ++other)
        {
            const size_t readSlot = static_cast<size_t>(2 * other);
            const size_t writeSlot = static_cast<size_t>(2 * other + 1);

            const int readEnd = parentToChild[readSlot];
            const int writeEnd = parentToChild[writeSlot];
            const bool isOwnPipe = (other == workerIndex);

            // A worker only ever reads from its own pipe.
            if (!isOwnPipe)
                close(readEnd);

            close(writeEnd);
        }
    }

    // Body of a forked worker process; never returns to the caller.
    //
    // Repeatedly reads an assignment message (two ints: first task index and
    // task count) from `assignmentReadFd` and calls
    // `taskFn(task, workerIndex, workerCount, ctx)` for every task index in
    // [start, start + count). The loop ends when the read fails or hits
    // end-of-file, or when a message with count <= 0 arrives.
    //
    // The process then closes the descriptor and terminates via _exit():
    // status 0 after a normal end of the loop, status 1 if an exception
    // escaped from `taskFn`. Catching here matters because this code runs in a
    // forked child: an exception allowed to propagate would unwind into the
    // child's copy of the caller's stack and the child could carry on running
    // code meant for the parent.
    [[noreturn]] void run_worker_loop(int workerIndex,
                                      int workerCount,
                                      int assignmentReadFd,
                                      OneToManyPipeTaskFunction taskFn,
                                      void* ctx)
    {
        int exitCode = 0;

        try
        {
            while (true)
            {
                int msg[2] = {0, 0};

                if (!read_full(assignmentReadFd, msg, sizeof(msg)))
                    break;

                const int startTask = msg[0];
                const int count = msg[1];

                // A non-positive count is the stop message.
                if (count <= 0)
                    break;

                for (int offset = 0; offset < count; ++offset)
                    taskFn(startTask + offset, workerIndex, workerCount, ctx);
            }
        }
        catch (...)
        {
            // Report the failure through the exit status; the parent treats
            // any non-zero status as an error.
            exitCode = 1;
        }

        close(assignmentReadFd);
        _exit(exitCode);
    }

    // Blocks until every worker listed in `pids` has been reaped.
    //
    // Entries <= 0 are skipped. waitpid() calls interrupted by a signal are
    // retried. Returns 0 when every reaped worker exited normally with status
    // 0, and 1 if any waitpid() call failed (a diagnostic is printed to
    // stderr) or any worker ended abnormally or with a non-zero status.
    int wait_for_workers(const vector<pid_t>& pids)
    {
        bool sawFailure = false;
        size_t slot = 0;

        for (const pid_t child : pids)
        {
            const size_t workerSlot = slot++;

            if (child <= 0)
                continue;

            int status = 0;

            for (;;)
            {
                if (waitpid(child, &status, 0) != -1)
                    break;

                if (errno == EINTR)
                    continue;

                cerr << "waitpid failed for worker " << workerSlot << ": "
                     << strerror(errno) << "\n";
                sawFailure = true;
                break;
            }

            // If waitpid failed, status still holds its initial 0 and the
            // failure has already been recorded above.
            const bool exitedCleanly = WIFEXITED(status) && WEXITSTATUS(status) == 0;

            if (!exitedCleanly)
                sawFailure = true;
        }

        return sawFailure ? 1 : 0;
    }
}

/**
 * Runs tasks 0 .. taskCount-1 on `workerCount` forked worker processes, with
 * the calling process acting as the single dispatcher.
 *
 * One pipe per worker is created up front, then the workers are forked. Each
 * worker closes the descriptors it does not need, optionally pins itself to a
 * CPU, and processes assignments read from its pipe until told to stop. The
 * parent hands out chunks of up to PIPE_SCHEDULER_CHUNK consecutive task
 * indices in round-robin order, sends every worker a stop message (count 0),
 * closes the pipes and waits for all workers.
 *
 * Every failure path closes all pipe descriptors still open in the parent and
 * reaps the workers that were started, so no descriptors or child processes
 * are left behind and the parent cannot block on a worker that was never told
 * to stop.
 *
 * NOTE(review): if a worker has already exited, writing to its pipe raises
 * SIGPIPE in the parent unless the caller ignores or blocks that signal; the
 * write-failure handling below is only reached in that case. Confirm how
 * callers configure SIGPIPE.
 *
 * @param taskCount   Number of tasks; must be > 0.
 * @param workerCount Number of worker processes to fork; must be > 0.
 * @param taskFn      Callback run in the workers for each task index; must not
 *                    be null.
 * @param ctx         Opaque pointer passed through to `taskFn`.
 * @param useAffinity When true, each worker tries to pin itself to one CPU.
 * @return 0 when all tasks were dispatched and every worker exited with
 *         status 0; 1 on invalid arguments, on any pipe/fork/write failure,
 *         or when a worker failed.
 */
int run_one_to_many_pipe_scheduler(int taskCount,
                                   int workerCount,
                                   OneToManyPipeTaskFunction taskFn,
                                   void* ctx,
                                   bool useAffinity)
{
    if (taskCount <= 0 || workerCount <= 0 || taskFn == nullptr || PIPE_SCHEDULER_CHUNK <= 0)
        return 1;

    // Slot 2*w is the read end and slot 2*w + 1 the write end of worker w's
    // pipe; -1 marks a descriptor that is not open in this process.
    vector<int> parentToChild(static_cast<size_t>(2 * workerCount), -1);
    vector<pid_t> pids(static_cast<size_t>(workerCount), -1);

    // Closes every descriptor the parent still holds. Once the write ends are
    // gone, workers blocked in read() see end-of-file and exit.
    auto closeAllParentFds = [&parentToChild]()
    {
        for (int& fd : parentToChild)
        {
            if (fd >= 0)
            {
                close(fd);
                fd = -1;
            }
        }
    };

    for (int w = 0; w < workerCount; ++w)
    {
        if (pipe(&parentToChild[static_cast<size_t>(2 * w)]) == -1)
        {
            // Capture errno before the stream insertions can disturb it.
            const int err = errno;
            cerr << "pipe failed in one-to-many scheduler: " << strerror(err) << "\n";
            closeAllParentFds();
            return 1;
        }
    }

    for (int w = 0; w < workerCount; ++w)
    {
        pid_t pid = fork();

        if (pid < 0)
        {
            const int err = errno;
            cerr << "fork failed in one-to-many scheduler: " << strerror(err) << "\n";
            // Release the workers that were already started, then reap them;
            // entries for workers never forked are still -1 and get skipped.
            closeAllParentFds();
            wait_for_workers(pids);
            return 1;
        }

        if (pid == 0)
        {
            int assignmentReadFd = parentToChild[static_cast<size_t>(2 * w)];

            close_child_unused_fds(parentToChild, w, workerCount);
            apply_affinity_if_requested(w, useAffinity);
            // Does not return: the worker loop ends the child with _exit().
            run_worker_loop(w, workerCount, assignmentReadFd, taskFn, ctx);
        }

        pids[static_cast<size_t>(w)] = pid;
    }

    // The parent only writes; drop its copies of the read ends.
    for (int w = 0; w < workerCount; ++w)
    {
        close(parentToChild[static_cast<size_t>(2 * w)]);
        parentToChild[static_cast<size_t>(2 * w)] = -1;
    }

    int nextTask = 0;
    int assignmentNumber = 0;

    while (nextTask < taskCount)
    {
        const int workerIndex = assignmentNumber % workerCount;
        const int writeFd = parentToChild[static_cast<size_t>(2 * workerIndex + 1)];
        const int count = min(PIPE_SCHEDULER_CHUNK, taskCount - nextTask);

        if (!send_task_message(writeFd, nextTask, count))
        {
            const int err = errno;
            cerr << "write failed in one-to-many scheduler: " << strerror(err) << "\n";
            closeAllParentFds();
            wait_for_workers(pids);
            return 1;
        }

        nextTask += count;
        ++assignmentNumber;
    }

    bool stopFailed = false;

    for (int w = 0; w < workerCount; ++w)
    {
        int& writeFd = parentToChild[static_cast<size_t>(2 * w + 1)];

        if (!send_task_message(writeFd, 0, 0))
        {
            const int err = errno;
            cerr << "stop message failed in one-to-many scheduler: " << strerror(err) << "\n";
            // Keep going: the remaining workers must still be stopped and
            // their pipes closed, otherwise waiting for them would never end.
            stopFailed = true;
        }

        close(writeFd);
        writeFd = -1;
    }

    const int waitRc = wait_for_workers(pids);
    return stopFailed ? 1 : waitRc;
}
