/*Option B as a separate full scheduler: one shared request pipe from all workers to the parent, plus one private assignment pipe from the parent to each worker. This keeps targeted assignments while reducing the number of pipes versus the original 2k request/assignment design.
Workers write requests to the same request pipe.
Each request contains the worker index.
Parent reads the request, knows which worker asked, and writes the assignment to that worker’s private pipe.
No select().
No socketpair().
*/
// pipe_scheduler_option_b.cpp

#include <iostream>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <cerrno>
#include <iomanip>

#include <unistd.h>
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/mman.h>
#include <sys/wait.h>

#include <boost/chrono.hpp>

#include "row.hpp"
#include "sort_algorithms.hpp"

#define SHM_KEY 1234

#ifndef MAP_ANONYMOUS
#define MAP_ANONYMOUS MAP_ANON
#endif

using namespace std;

/*
    Computes the row-major 1D index for tensor[z][i][j].

    Row-major layout:

        tensor[z][i][j] -> z * N * N + i * N + j
*/
static inline size_t idx(size_t z, size_t i, size_t j, size_t N)
{
    return z * N * N + i * N + j;
}

/*
    Reads exactly 'bytes' bytes from file descriptor 'fd'.

    Returns:
        1 on success
        0 on EOF or error
*/
int read_full(int fd, void *buf, size_t bytes)
{
    char *p = (char *)buf;
    size_t left = bytes;

    while (left > 0)
    {
        long r = (long)read(fd, p, left);

        if (r == 0)
        {
            return 0;
        }

        if (r < 0)
        {
            if (errno == EINTR)
            {
                continue;
            }

            return 0;
        }

        p += (size_t)r;
        left -= (size_t)r;
    }

    return 1;
}

/*
    Writes exactly 'bytes' bytes to file descriptor 'fd'.

    Returns:
        1 on success
        0 on error
*/
int write_full(int fd, const void *buf, size_t bytes)
{
    const char *p = (const char *)buf;
    size_t left = bytes;

    while (left > 0)
    {
        long w = (long)write(fd, p, left);

        if (w < 0)
        {
            if (errno == EINTR)
            {
                continue;
            }

            return 0;
        }

        p += (size_t)w;
        left -= (size_t)w;
    }

    return 1;
}

/*
    Sorts rows of matrices from start_slice up to end_slice - 1.

    Each slice is one N x N matrix.
*/
void sort_slice_range(
    uint8_t *tensor,
    size_t start_slice,
    size_t end_slice,
    size_t N)
{
    for (size_t z = start_slice; z < end_slice; z++)
    {
        for (size_t i = 0; i < N; i++)
        {
            uint8_t *row_ptr = &tensor[idx(z, i, 0, N)];

            row<uint8_t> r(row_ptr, N);

            quick_sort(r);

            for (size_t j = 0; j < N; j++)
            {
                row_ptr[j] = r[j];
            }
        }
    }
}

/*
    Worker sends one request to the shared request pipe.

    The request contains the zero-based worker index.

    Example:
        worker 1 sends 0
        worker 2 sends 1
        worker 3 sends 2

    This allows the parent to know which private assignment pipe
    should receive the next task.
*/
int send_worker_request(int request_write_fd, size_t worker_index)
{
    if (!write_full(request_write_fd, &worker_index, sizeof(worker_index)))
    {
        return 0;
    }

    return 1;
}

/*
    Parent receives one request from the shared request pipe.

    The value read is the zero-based worker index.
*/
int receive_worker_request(int request_read_fd, size_t *worker_index)
{
    if (!read_full(request_read_fd, worker_index, sizeof(size_t)))
    {
        return 0;
    }

    return 1;
}

/*
    Parent sends either a real task or a stop message to one worker.

    Message format:
        msg[0] = start slice
        msg[1] = number of slices

    Stop message:
        msg[0] = 0
        msg[1] = 0

    Returns:
        1 if real work was sent
        0 if stop message was sent
*/
int send_assignment_or_stop(
    int assignment_write_fd,
    size_t *next_slice,
    size_t L,
    size_t batch_size)
{
    size_t msg[2];

    if (*next_slice < L)
    {
        size_t count = batch_size;

        if (count > L - *next_slice)
        {
            count = L - *next_slice;
        }

        msg[0] = *next_slice;
        msg[1] = count;

        *next_slice += count;
    }
    else
    {
        msg[0] = 0;
        msg[1] = 0;
    }

    if (!write_full(assignment_write_fd, msg, sizeof(msg)))
    {
        cout << "write_full failed while sending assignment." << endl;
        _exit(1);
    }

    if (msg[1] == 0)
    {
        return 0;
    }

    return 1;
}

/*
    Worker process loop for Option B.

    Each worker has:
        1. Shared request pipe write end.
        2. Private assignment pipe read end.

    Worker protocol:
        1. Send worker index to parent through shared request pipe.
        2. Read assignment from private assignment pipe.
        3. Sort assigned slices.
        4. Repeat.

    If slice_count == 0, worker exits.
*/
void worker_loop(
    size_t worker_id,
    size_t worker_index,
    int request_write_fd,
    int assignment_read_fd,
    uint8_t *mmap_tensor,
    size_t N)
{
    while (true)
    {
        if (!send_worker_request(request_write_fd, worker_index))
        {
            break;
        }

        size_t msg[2];

        if (!read_full(assignment_read_fd, msg, sizeof(msg)))
        {
            break;
        }

        size_t start_slice = msg[0];
        size_t slice_count = msg[1];

        if (slice_count == 0)
        {
            break;
        }

        size_t end_slice = start_slice + slice_count;

        sort_slice_range(
            mmap_tensor,
            start_slice,
            end_slice,
            N);

        cout << "Worker " << worker_id
             << " PID " << getpid()
             << " sorted slices ["
             << start_slice << ", "
             << end_slice << ")"
             << endl;
    }

    close(request_write_fd);
    close(assignment_read_fd);

    _exit(0);
}

/*
    Closes assignment pipe descriptors not needed by one child.

    Each worker keeps only:

        parent_to_child[2*w]

    because this is its private assignment read end.

    The child closes:
        - its own assignment write end
        - all descriptors for other workers
*/
void close_child_unused_assignment_fds(
    int *parent_to_child,
    size_t k,
    size_t w)
{
    for (size_t j = 0; j < k; j++)
    {
        if (j == w)
        {
            close(parent_to_child[2 * j + 1]);
        }
        else
        {
            close(parent_to_child[2 * j]);
            close(parent_to_child[2 * j + 1]);
        }
    }
}

/*
    Closes one parent-side assignment pipe.

    Parent keeps only:
        parent_to_child[2*w + 1]

    This is the write end used to send assignments to worker w.
*/
void close_parent_assignment_fd(
    int *parent_to_child,
    size_t w)
{
    if (parent_to_child[2 * w + 1] >= 0)
    {
        close(parent_to_child[2 * w + 1]);
        parent_to_child[2 * w + 1] = -1;
    }
}

/*
    Checks whether one row of one slice is sorted.

    Returns:
        1 if sorted
        0 otherwise
*/
int is_row_sorted(uint8_t *tensor, size_t z, size_t i, size_t N)
{
    uint8_t *row_ptr = &tensor[idx(z, i, 0, N)];

    for (size_t j = 1; j < N; j++)
    {
        if (row_ptr[j - 1] > row_ptr[j])
        {
            return 0;
        }
    }

    return 1;
}

/*
    Verifies all rows of all slices.

    Returns:
        1 if all rows are sorted
        0 otherwise
*/
int verify_all(uint8_t *tensor, size_t L, size_t N)
{
    for (size_t z = 0; z < L; z++)
    {
        for (size_t i = 0; i < N; i++)
        {
            if (!is_row_sorted(tensor, z, i, N))
            {
                cout << "Verification failed at slice "
                     << z << ", row " << i << endl;

                return 0;
            }
        }
    }

    return 1;
}

/*
    Prints a small preview of one slice.

    This avoids printing the whole N x N matrix.
*/
void print_tensor_slice_preview(
    uint8_t *tensor,
    size_t L,
    size_t N,
    size_t z,
    size_t print_rows,
    size_t print_cols,
    const char *title)
{
    if (L == 0 || N == 0)
    {
        return;
    }

    if (z >= L)
    {
        z = 0;
    }

    if (print_rows > N)
    {
        print_rows = N;
    }

    if (print_cols > N)
    {
        print_cols = N;
    }

    cout << endl;
    cout << title << endl;
    cout << "Slice z = " << z
         << ", preview "
         << print_rows << " x " << print_cols
         << " from full "
         << N << " x " << N
         << " matrix"
         << endl;

    for (size_t i = 0; i < print_rows; i++)
    {
        for (size_t j = 0; j < print_cols; j++)
        {
            cout << setw(4)
                 << (unsigned int)tensor[idx(z, i, j, N)];
        }

        if (print_cols < N)
        {
            cout << " ...";
        }

        cout << endl;
    }

    if (print_rows < N)
    {
        cout << "..." << endl;
    }
}

/*
    Waits for all worker processes using waitpid().
*/
void wait_all_workers(pid_t *pids, size_t k)
{
    for (size_t w = 0; w < k; w++)
    {
        int status = 0;

        if (pids[w] <= 0)
        {
            continue;
        }

        while (waitpid(pids[w], &status, 0) == -1)
        {
            if (errno == EINTR)
            {
                continue;
            }

            cout << "waitpid failed for worker "
                 << (w + 1)
                 << endl;
            break;
        }

        if (WIFEXITED(status))
        {
            cout << "Worker " << (w + 1)
                 << " PID " << pids[w]
                 << " exited with code "
                 << WEXITSTATUS(status)
                 << endl;
        }
        else if (WIFSIGNALED(status))
        {
            cout << "Worker " << (w + 1)
                 << " PID " << pids[w]
                 << " killed by signal "
                 << WTERMSIG(status)
                 << endl;
        }
    }
}

/*
    Main scheduler process.

    Option B design:

        shared request pipe:
            workers write requests
            parent reads requests

        private assignment pipes:
            parent writes assignment to the specific worker
            worker reads assignment from its own pipe

    Number of pipes:
        1 shared request pipe
        k private assignment pipes
*/
int main(int argc, char **argv)
{
    if (argc < 4 || argc > 8)
    {
        cout << "Usage: "
             << argv[0]
             << " L N k [batch_size] [print_slice] [print_rows] [print_cols]"
             << endl;

        cout << "Example: "
             << argv[0]
             << " 100 640 4 5"
             << endl;

        cout << "Example with print options: "
             << argv[0]
             << " 100 640 4 5 0 8 16"
             << endl;

        _exit(1);
    }

    size_t L = (size_t)atoll(argv[1]);
    size_t N = (size_t)atoll(argv[2]);
    size_t k = (size_t)atoll(argv[3]);

    size_t batch_size = 1;
    size_t print_slice = 0;
    size_t print_rows = 8;
    size_t print_cols = 16;

    if (argc >= 5)
    {
        batch_size = (size_t)atoll(argv[4]);
    }

    if (argc >= 6)
    {
        print_slice = (size_t)atoll(argv[5]);
    }

    if (argc >= 7)
    {
        print_rows = (size_t)atoll(argv[6]);
    }

    if (argc >= 8)
    {
        print_cols = (size_t)atoll(argv[7]);
    }

    if (L == 0 || N == 0 || k == 0 || batch_size == 0)
    {
        cout << "Invalid arguments." << endl;
        _exit(1);
    }

    if (print_rows == 0)
    {
        print_rows = 8;
    }

    if (print_cols == 0)
    {
        print_cols = 16;
    }

    if (print_slice >= L)
    {
        cout << "Requested print_slice is outside range. Using slice 0." << endl;
        print_slice = 0;
    }

    size_t total_elements = L * N * N;
    size_t total_bytes = total_elements * sizeof(uint8_t);

    /*
        Attach to original SysV shared memory created by generator8.
    */
    int shmid = shmget(SHM_KEY, total_bytes, 0666);

    if (shmid == -1)
    {
        cout << "shmget failed. Did you run generator8 first?" << endl;
        _exit(1);
    }

    void *sysv_base = shmat(shmid, NULL, 0);

    if (sysv_base == (void *)-1)
    {
        cout << "shmat failed" << endl;
        _exit(1);
    }

    uint8_t *sysv_tensor = (uint8_t *)sysv_base;

    /*
        Allocate mmap shared memory.

        Workers sort this mmap copy.
        The original SysV shared memory remains unsorted.
    */
    void *mmap_base = mmap(
        NULL,
        total_bytes,
        PROT_READ | PROT_WRITE,
        MAP_SHARED | MAP_ANONYMOUS,
        -1,
        0);

    if (mmap_base == MAP_FAILED)
    {
        cout << "mmap failed" << endl;
        shmdt(sysv_base);
        _exit(1);
    }

    uint8_t *mmap_tensor = (uint8_t *)mmap_base;

    memcpy(mmap_tensor, sysv_tensor, total_bytes);

    cout << "Copied SysV shared memory to mmap shared memory." << endl;

    /*
        Shared request pipe:

            request_pipe[0] = parent reads requests
            request_pipe[1] = workers write requests
    */
    int request_pipe[2];

    if (pipe(request_pipe) == -1)
    {
        cout << "request pipe failed" << endl;
        munmap(mmap_base, total_bytes);
        shmdt(sysv_base);
        _exit(1);
    }

    /*
        Private assignment pipes.

        For worker w:

            parent_to_child[2*w]     = worker reads assignment
            parent_to_child[2*w + 1] = parent writes assignment
    */
    int *parent_to_child = new int[2 * k];
    pid_t *pids = new pid_t[k];

    for (size_t i = 0; i < 2 * k; i++)
    {
        parent_to_child[i] = -1;
    }

    for (size_t w = 0; w < k; w++)
    {
        pids[w] = -1;
    }

    for (size_t w = 0; w < k; w++)
    {
        if (pipe(parent_to_child + 2 * w) == -1)
        {
            cout << "assignment pipe failed" << endl;
            _exit(1);
        }
    }

    pid_t root_pid = getpid();

    /*
        Create k workers.
    */
    for (size_t w = 0; w < k; w++)
    {
        pid_t pid = fork();

        if (pid < 0)
        {
            cout << "fork failed" << endl;
            _exit(1);
        }

        if (pid == 0)
        {
            /*
                Child writes to shared request pipe.
                Child does not read from request pipe.
            */
            close(request_pipe[0]);

            /*
                Child keeps only its private assignment read end:
                    parent_to_child[2*w]
            */
            close_child_unused_assignment_fds(
                parent_to_child,
                k,
                w);

            /*
                Child does not need original SysV memory.
                It sorts only mmap_tensor.
            */
            shmdt(sysv_base);

            int request_write_fd = request_pipe[1];
            int assignment_read_fd = parent_to_child[2 * w];

            delete[] parent_to_child;
            delete[] pids;

            worker_loop(
                w + 1,
                w,
                request_write_fd,
                assignment_read_fd,
                mmap_tensor,
                N);

            _exit(0);
        }

        if (getpid() == root_pid)
        {
            pids[w] = pid;

            cout << "Root parent " << getpid()
                 << " created worker " << (w + 1)
                 << " with PID " << pid
                 << endl;
        }
    }

    /*
        Parent reads from shared request pipe.
        Parent does not write requests.
    */
    close(request_pipe[1]);

    /*
        Parent writes assignments only.
        Parent closes all assignment read ends.
    */
    for (size_t w = 0; w < k; w++)
    {
        close(parent_to_child[2 * w]);
        parent_to_child[2 * w] = -1;
    }

    size_t next_slice = 0;
    size_t active_workers = k;

    boost::chrono::high_resolution_clock::time_point sort_start;
    boost::chrono::high_resolution_clock::time_point sort_end;

    sort_start = boost::chrono::high_resolution_clock::now();

    /*
        Main Option B scheduler loop.

        Parent blocks on the shared request pipe.

        When a worker finishes and requests more work:
            1. Parent reads worker_index.
            2. Parent sends next assignment to that worker's private pipe.

        When no work remains:
            Parent sends stop message to that worker and closes its pipe.
    */
    while (active_workers > 0)
    {
        size_t worker_index = 0;

        if (!receive_worker_request(request_pipe[0], &worker_index))
        {
            cout << "Parent failed to read worker request." << endl;
            break;
        }

        if (worker_index >= k)
        {
            cout << "Parent received invalid worker index." << endl;
            continue;
        }

        if (parent_to_child[2 * worker_index + 1] < 0)
        {
            continue;
        }

        int worker_has_work = send_assignment_or_stop(
            parent_to_child[2 * worker_index + 1],
            &next_slice,
            L,
            batch_size);

        if (!worker_has_work)
        {
            close_parent_assignment_fd(
                parent_to_child,
                worker_index);

            active_workers--;
        }
    }

    close(request_pipe[0]);

    /*
        Wait for all workers.
    */
    wait_all_workers(pids, k);

    sort_end = boost::chrono::high_resolution_clock::now();

    boost::chrono::milliseconds sort_ms =
        boost::chrono::duration_cast<boost::chrono::milliseconds>(
            sort_end - sort_start);

    cout << endl;
    cout << "All workers finished." << endl;
    cout << "Total sorting process duration: "
         << sort_ms.count()
         << " ms"
         << endl;

    if (verify_all(mmap_tensor, L, N))
    {
        cout << "Verification success: all rows are sorted." << endl;
    }
    else
    {
        cout << "Verification failed." << endl;
    }

    print_tensor_slice_preview(
        mmap_tensor,
        L,
        N,
        print_slice,
        print_rows,
        print_cols,
        "Sorted mmap copy");

    print_tensor_slice_preview(
        sysv_tensor,
        L,
        N,
        print_slice,
        print_rows,
        print_cols,
        "Original unsorted SysV shared memory");

    munmap(mmap_base, total_bytes);
    shmdt(sysv_base);

    delete[] parent_to_child;
    delete[] pids;

    return 0;
}
