#include "collective_scheduler.hpp"
#include <iostream>
#include <algorithm>
#include <cstring>
#include <sys/wait.h>
#include <errno.h>

/// Build a scheduler that owns `workers` worker processes.
///
/// For every worker slot this allocates two pipes:
///   - to_worker[w]   carries requests from the parent to worker w;
///   - from_worker[w] carries replies from worker w back to the parent.
/// If either pipe() call fails, the error is reported with perror and the
/// whole process exits with status 1. Once all pipes exist, create_workers()
/// forks the worker processes.
collective_scheduler::collective_scheduler(int workers) : num_workers(workers) {
    worker_pids = new pid_t[num_workers];
    to_worker = new int*[num_workers];
    from_worker = new int*[num_workers];

    for (int slot = 0; slot < num_workers; ++slot) {
        int* request = new int[2];
        to_worker[slot] = request;
        int* reply = new int[2];
        from_worker[slot] = reply;

        // Short-circuit: the reply pipe is only created if the request pipe succeeded.
        const bool pipes_ok = pipe(request) != -1 && pipe(reply) != -1;
        if (!pipes_ok) {
            perror("pipe");
            exit(1);
        }
    }

    create_workers();
}

/// Fork one worker process per slot and wire up its pipes.
///
/// Worker i (the child) keeps only the read end of to_worker[i] and the write
/// end of from_worker[i]. It serves requests of the form
///     [int row index][int cols][cols ints]
/// by replying with the same two-int header followed by the row sorted
/// ascending. It stops when the request pipe reaches EOF (the parent closed
/// its write end) and leaves via _exit. The parent records each child's PID
/// and closes the pipe ends that belong to that child. A fork() failure is
/// fatal (perror + exit(1)).
///
/// Pipe I/O in the worker loops until each message is fully transferred. A
/// single read()/write() on a pipe may move fewer bytes than requested, for
/// example when a row is larger than the pipe buffer or a signal interrupts
/// the call.
void collective_scheduler::create_workers() {
    for (int i = 0; i < num_workers; ++i) {
        pid_t pid = fork();
        if (pid == -1) {
            perror("fork");
            exit(1);
        }
        if (pid == 0) {
            // Child process: worker i.
            close(to_worker[i][1]);
            close(from_worker[i][0]);

            for (int j = 0; j < num_workers; ++j) {
                if (j == i) continue;
                // The parent still holds these ends of every other worker's pipes.
                close(to_worker[j][1]);
                close(from_worker[j][0]);
                // For j < i the parent already closed the remaining two ends
                // before this fork; only later workers' pipes still have them open.
                if (j > i) {
                    close(to_worker[j][0]);
                    close(from_worker[j][1]);
                }
            }

            const int in_fd = to_worker[i][0];
            const int out_fd = from_worker[i][1];

            // Read up to `len` bytes, retrying on short reads and EINTR.
            // Returns the number of bytes read; it is < len only on EOF or error.
            auto read_fully = [](int fd, void* buf, std::size_t len) -> std::size_t {
                char* p = static_cast<char*>(buf);
                std::size_t done = 0;
                while (done < len) {
                    const ssize_t n = read(fd, p + done, len - done);
                    if (n < 0 && errno == EINTR) continue;
                    if (n <= 0) break;
                    done += static_cast<std::size_t>(n);
                }
                return done;
            };

            // Write exactly `len` bytes, retrying on short writes and EINTR.
            auto write_fully = [](int fd, const void* buf, std::size_t len) -> bool {
                const char* p = static_cast<const char*>(buf);
                while (len > 0) {
                    const ssize_t n = write(fd, p, len);
                    if (n < 0 && errno == EINTR) continue;
                    if (n <= 0) return false;
                    p += n;
                    len -= static_cast<std::size_t>(n);
                }
                return true;
            };

            int status = 0;
            // No exception may unwind out of this branch: in the child that
            // would resume the parent's code path (e.g. return from create_workers).
            try {
                std::vector<int> row;  // reused across requests
                while (true) {
                    int idx = 0;
                    const std::size_t got = read_fully(in_fd, &idx, sizeof idx);
                    if (got == 0) break;  // EOF between messages: normal shutdown
                    if (got != sizeof idx) { status = 1; break; }

                    int cols = 0;
                    if (read_fully(in_fd, &cols, sizeof cols) != sizeof cols || cols < 0) {
                        status = 1;
                        break;
                    }

                    row.resize(static_cast<std::size_t>(cols));
                    const std::size_t bytes = row.size() * sizeof(int);
                    if (read_fully(in_fd, row.data(), bytes) != bytes) { status = 1; break; }

                    std::sort(row.begin(), row.end());

                    if (!write_fully(out_fd, &idx, sizeof idx) ||
                        !write_fully(out_fd, &cols, sizeof cols) ||
                        !write_fully(out_fd, row.data(), bytes)) {
                        status = 1;
                        break;
                    }
                }
            } catch (...) {
                status = 1;
            }

            close(in_fd);
            close(out_fd);
            // _exit (not exit) so the child does not run the parent's atexit
            // handlers or flush stdio buffers duplicated by fork().
            _exit(status);
        } else {
            worker_pids[i] = pid;
            // These ends now belong to the child.
            close(to_worker[i][0]);
            close(from_worker[i][1]);
        }
    }
}

/// Tear down the worker pool.
///
/// First closes the parent's write end of each request pipe and its read end
/// of each reply pipe. Once a worker's request pipe has no writers left, its
/// read() returns 0 and the worker loop exits. The destructor then reaps the
/// children via wait_for_children() and frees the pipe and PID arrays
/// allocated by the constructor.
///
/// NOTE(review): the class owns raw arrays and file descriptors, so copying an
/// instance would double-close and double-free. Confirm the header deletes
/// the copy operations.
collective_scheduler::~collective_scheduler() {
    for (int slot = 0; slot < num_workers; ++slot) {
        int* request = to_worker[slot];
        int* reply = from_worker[slot];
        close(request[1]);
        close(reply[0]);
    }

    wait_for_children();

    for (int slot = 0; slot < num_workers; ++slot) {
        delete[] to_worker[slot];
        delete[] from_worker[slot];
    }
    delete[] worker_pids;
    delete[] from_worker;
    delete[] to_worker;
}

void collective_scheduler::wait_for_children() {
    while (true) {
        int status;
        pid_t done = waitpid(-1, &status, 0);
        if (done == -1) {
            if (errno == ECHILD) break;
            perror("waitpid");
            break;
        }
    }
}

std::vector<int> collective_scheduler::sort_matrix_rows(const std::vector<int>& matrix, int rows, int cols) {
    std::vector<int> result(rows * cols);

    for (int i = 0; i < rows; ++i) {
        int worker = i % num_workers;
        write(to_worker[worker][1], &i, sizeof(int));
        write(to_worker[worker][1], &cols, sizeof(int));
        write(to_worker[worker][1], matrix.data() + i * cols, cols * sizeof(int));
    }

    for (int i = 0; i < rows; ++i) {
        int idx;
        ssize_t bytes = read(from_worker[i % num_workers][0], &idx, sizeof(int));
        if (bytes != sizeof(int)) break;

        int c;
        bytes = read(from_worker[i % num_workers][0], &c, sizeof(int));
        if (bytes != sizeof(int)) break;

        std::vector<int> sorted_row(c);
        bytes = read(from_worker[i % num_workers][0], sorted_row.data(), c * sizeof(int));
        if (bytes != c * sizeof(int)) break;

        std::copy(sorted_row.begin(), sorted_row.end(), result.begin() + idx * c);
    }
    return result;
}
