/*One shared pipe: task_pipe
Parent writes all tasks into the pipe.
All workers read tasks from the same pipe.
No assignment pipe per worker*/

// pipe_scheduler.cpp

#include <iostream>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <cerrno>
#include <iomanip>

#include <unistd.h>
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/mman.h>
#include <sys/wait.h>

#include <boost/chrono.hpp>

#include "row.hpp"
#include "sort_algorithms.hpp"

#define SHM_KEY 1234

/*#ifndef MAP_ANONYMOUS
#define MAP_ANONYMOUS MAP_ANON
#endif*/

using namespace std;

/*
    Flattens tensor[z][i][j] into a row-major 1D offset.

    The tensor holds L slices; every slice is an N x N matrix, so:

        offset = (z * N + i) * N + j
               =  z * N * N + i * N + j
*/
static inline size_t idx(size_t z, size_t i, size_t j, size_t N)
{
    return (z * N + i) * N + j;
}

/*
    Reads exactly 'bytes' bytes from file descriptor 'fd',
    retrying on EINTR and on short reads.

    Returns:
        1 on success
        0 on EOF or error

    Workers use this to pull fixed-size task messages
    (msg[0] = start slice, msg[1] = slice count) off the shared pipe.
*/
int read_full(int fd, void *buf, size_t bytes)
{
    char *cursor = (char *)buf;
    char *end = cursor + bytes;

    while (cursor != end)
    {
        ssize_t got = read(fd, cursor, (size_t)(end - cursor));

        if (got > 0)
        {
            cursor += got;
            continue;
        }

        // Interrupted by a signal: simply retry the read.
        if (got < 0 && errno == EINTR)
        {
            continue;
        }

        // got == 0 is EOF; any other negative value is a hard error.
        return 0;
    }

    return 1;
}

/*
    Writes exactly 'bytes' bytes to file descriptor 'fd',
    retrying on EINTR and on short writes.

    Returns:
        1 on success
        0 on error

    The parent uses this to push fixed-size task messages
    into the shared pipe.
*/
int write_full(int fd, const void *buf, size_t bytes)
{
    const char *cursor = (const char *)buf;
    size_t remaining = bytes;

    while (remaining != 0)
    {
        ssize_t sent = write(fd, cursor, remaining);

        if (sent >= 0)
        {
            cursor += sent;
            remaining -= (size_t)sent;
        }
        else if (errno != EINTR)
        {
            // Any error other than an interrupted call is fatal.
            return 0;
        }
    }

    return 1;
}

/*
    Sorts rows of matrices from start_slice up to end_slice - 1.

    Each slice is one N x N matrix.

    For every row:
        1. Get a pointer to the row.
        2. Wrap the row with row<uint8_t>.
        3. Sort the row using quick_sort.
        4. Copy sorted values back to the mmap tensor.
*/
void sort_slice_range(
    uint8_t *tensor,
    size_t start_slice,
    size_t end_slice,
    size_t N)
{
    for (size_t z = start_slice; z < end_slice; z++)
    {
        for (size_t i = 0; i < N; i++)
        {
            uint8_t *row_ptr = &tensor[idx(z, i, 0, N)];

            row<uint8_t> r(row_ptr, N);

            quick_sort(r);

            for (size_t j = 0; j < N; j++)
            {
                row_ptr[j] = r[j];
            }
        }
    }
}

/*
    Worker process loop for Option A.

    There is a single shared task pipe: the parent writes task
    messages, and every worker competes to read them from the
    same read end.

    Message layout:

        msg[0] = start slice
        msg[1] = slice count

    A message with slice count 0 is the stop signal; the worker
    also exits if the pipe reports EOF or an error.
*/
void worker_loop(
    size_t worker_id,
    int task_read_fd,
    uint8_t *mmap_tensor,
    size_t N)
{
    for (;;)
    {
        size_t task[2];

        // Blocks until a full task message arrives or the pipe closes.
        if (read_full(task_read_fd, task, sizeof(task)) == 0)
        {
            break;
        }

        const size_t first = task[0];
        const size_t count = task[1];

        // count == 0 is the explicit stop message.
        if (count == 0)
        {
            break;
        }

        const size_t last = first + count;

        sort_slice_range(mmap_tensor, first, last, N);

        std::cout << "Worker " << worker_id
                  << " PID " << getpid()
                  << " sorted slices ["
                  << first << ", "
                  << last << ")"
                  << std::endl;
    }

    close(task_read_fd);

    _exit(0);
}

/*
    Checks whether row i of slice z is in non-decreasing order.

    Returns:
        1 if sorted
        0 otherwise
*/
int is_row_sorted(uint8_t *tensor, size_t z, size_t i, size_t N)
{
    // Row-major offset of tensor[z][i][0].
    const uint8_t *vals = tensor + (z * N + i) * N;

    for (size_t j = 0; j + 1 < N; ++j)
    {
        if (vals[j] > vals[j + 1])
        {
            return 0;
        }
    }

    return 1;
}

/*
    Verifies that every row of every matrix in the mmap copy is
    sorted, reporting the first offending row if one exists.

    Returns:
        1 if all rows are sorted
        0 otherwise
*/
int verify_all(uint8_t *tensor, size_t L, size_t N)
{
    for (size_t slice = 0; slice < L; ++slice)
    {
        for (size_t r_idx = 0; r_idx < N; ++r_idx)
        {
            if (is_row_sorted(tensor, slice, r_idx, N))
            {
                continue;
            }

            std::cout << "Verification failed at slice "
                      << slice << ", row " << r_idx << std::endl;

            return 0;
        }
    }

    return 1;
}

/*
    Prints a small preview window (print_rows x print_cols) from one
    N x N slice instead of dumping the whole matrix.

    Out-of-range z falls back to slice 0; the window is clamped to
    the matrix size. Does nothing for an empty tensor.
*/
void print_tensor_slice_preview(
    uint8_t *tensor,
    size_t L,
    size_t N,
    size_t z,
    size_t print_rows,
    size_t print_cols,
    const char *title)
{
    // Nothing to preview in an empty tensor.
    if (L == 0 || N == 0)
    {
        return;
    }

    if (z >= L)
    {
        z = 0;
    }

    const size_t rows = (print_rows > N) ? N : print_rows;
    const size_t cols = (print_cols > N) ? N : print_cols;

    std::cout << std::endl;
    std::cout << title << std::endl;
    std::cout << "Slice z = " << z
              << ", preview "
              << rows << " x " << cols
              << " from full "
              << N << " x " << N
              << " matrix"
              << std::endl;

    // Row-major base of slice z: tensor[z][0][0].
    const uint8_t *slice_base = tensor + z * N * N;

    for (size_t r = 0; r < rows; ++r)
    {
        const uint8_t *row_start = slice_base + r * N;

        for (size_t c = 0; c < cols; ++c)
        {
            std::cout << std::setw(4)
                      << (unsigned int)row_start[c];
        }

        // Mark truncated columns.
        if (cols < N)
        {
            std::cout << " ...";
        }

        std::cout << std::endl;
    }

    // Mark truncated rows.
    if (rows < N)
    {
        std::cout << "..." << std::endl;
    }
}

/*
    Writes every sorting task into the shared task pipe, which acts
    as the work queue:

        [0, batch_size), [batch_size, 2 * batch_size), ...

    The final batch is truncated to the remaining slice count.

    After all real tasks, one stop message (count == 0) is written
    per worker; with k workers, exactly k stop messages are required
    or some workers would stay blocked in read().

    Exits the process with status 1 if a pipe write fails.
*/
void write_all_tasks(
    int task_write_fd,
    size_t L,
    size_t batch_size,
    size_t k)
{
    for (size_t first = 0; first < L; )
    {
        const size_t remaining = L - first;
        const size_t count =
            (batch_size < remaining) ? batch_size : remaining;

        size_t msg[2] = { first, count };

        if (!write_full(task_write_fd, msg, sizeof(msg)))
        {
            std::cout << "write_full failed while writing task." << std::endl;
            _exit(1);
        }

        first += count;
    }

    // One stop message per worker.
    for (size_t w = 0; w < k; ++w)
    {
        size_t stop_msg[2] = { 0, 0 };

        if (!write_full(task_write_fd, stop_msg, sizeof(stop_msg)))
        {
            std::cout << "write_full failed while writing stop message." << std::endl;
            _exit(1);
        }
    }
}

/*
    Waits for all worker processes using waitpid().

    pids[w] contains the PID returned by fork() for worker w.
    Entries <= 0 (e.g. from failed forks) are skipped.

    If waitpid() fails with a non-EINTR error, the failure is
    reported and no exit status is printed for that worker:
    without this guard, 'status' would still be 0 and WIFEXITED(0)
    would misreport "exited with code 0".
*/
void wait_all_workers(pid_t *pids, size_t k)
{
    for (size_t w = 0; w < k; w++)
    {
        if (pids[w] <= 0)
        {
            continue;
        }

        int status = 0;
        bool reaped = true;

        // Retry on EINTR; give up on any other waitpid error.
        while (waitpid(pids[w], &status, 0) == -1)
        {
            if (errno == EINTR)
            {
                continue;
            }

            std::cout << "waitpid failed for worker "
                      << (w + 1)
                      << std::endl;

            reaped = false;
            break;
        }

        // Only report a status that was actually obtained.
        if (!reaped)
        {
            continue;
        }

        if (WIFEXITED(status))
        {
            std::cout << "Worker " << (w + 1)
                      << " PID " << pids[w]
                      << " exited with code "
                      << WEXITSTATUS(status)
                      << std::endl;
        }
        else if (WIFSIGNALED(status))
        {
            std::cout << "Worker " << (w + 1)
                      << " PID " << pids[w]
                      << " killed by signal "
                      << WTERMSIG(status)
                      << std::endl;
        }
    }
}

/*
    Main parent scheduler process.

    Option A design:

        1. Parent creates one shared task pipe.
        2. Parent forks k workers.
        3. Every worker reads from the same pipe.
        4. Parent writes all tasks to the pipe.
        5. Parent writes k stop messages.
        6. Parent waits for all workers.

    No request pipe is needed.
    No select() is needed.
    No per-worker assignment pipe is needed.

    NOTE(review): each task message is 2 * sizeof(size_t) bytes
    (16 on LP64), well under PIPE_BUF, so pipe writes are atomic.
    With several workers reading the same pipe, each small message
    is normally consumed whole by a single read() -- confirm on the
    target platform if sizeof(size_t) differs.
*/
int main(int argc, char **argv)
{
    if (argc < 4 || argc > 8)
    {
        cout << "Usage: "
             << argv[0]
             << " L N k [batch_size] [print_slice] [print_rows] [print_cols]"
             << endl;

        cout << "Example: "
             << argv[0]
             << " 100 640 4 5"
             << endl;

        cout << "Example with print options: "
             << argv[0]
             << " 100 640 4 5 0 8 16"
             << endl;

        _exit(1);
    }

    /*
        atoll() yields 0 for non-numeric input; the "== 0" validation
        below rejects that case for the mandatory arguments.
    */
    size_t L = (size_t)atoll(argv[1]);
    size_t N = (size_t)atoll(argv[2]);
    size_t k = (size_t)atoll(argv[3]);

    /* Defaults for the optional arguments. */
    size_t batch_size = 1;
    size_t print_slice = 0;
    size_t print_rows = 8;
    size_t print_cols = 16;

    if (argc >= 5)
    {
        batch_size = (size_t)atoll(argv[4]);
    }

    if (argc >= 6)
    {
        print_slice = (size_t)atoll(argv[5]);
    }

    if (argc >= 7)
    {
        print_rows = (size_t)atoll(argv[6]);
    }

    if (argc >= 8)
    {
        print_cols = (size_t)atoll(argv[7]);
    }

    if (L == 0 || N == 0 || k == 0 || batch_size == 0)
    {
        cout << "Invalid arguments." << endl;
        _exit(1);
    }

    /* A preview window of 0 rows/cols would print nothing useful;
       fall back to the defaults. */
    if (print_rows == 0)
    {
        print_rows = 8;
    }

    if (print_cols == 0)
    {
        print_cols = 16;
    }

    if (print_slice >= L)
    {
        cout << "Requested print_slice is outside range. Using slice 0." << endl;
        print_slice = 0;
    }

    size_t total_elements = L * N * N;
    size_t total_bytes = total_elements * sizeof(uint8_t);

    /*
        Attach to the original SysV shared memory created by generator8.
        No IPC_CREAT flag: the segment must already exist.
    */
    int shmid = shmget(SHM_KEY, total_bytes, 0666);

    if (shmid == -1)
    {
        cout << "shmget failed. Did you run generator8 first?" << endl;
        _exit(1);
    }

    void *sysv_base = shmat(shmid, NULL, 0);

    if (sysv_base == (void *)-1)
    {
        cout << "shmat failed" << endl;
        _exit(1);
    }

    uint8_t *sysv_tensor = (uint8_t *)sysv_base;

    /*
        Allocate anonymous mmap shared memory.

        MAP_SHARED | MAP_ANON makes the region shared between the
        parent and all forked workers. Workers sort this copy, never
        the original SysV memory.
    */
    void *mmap_base = mmap(
        NULL,
        total_bytes,
        PROT_READ | PROT_WRITE,
        MAP_SHARED | MAP_ANON,
        -1,
        0);

    if (mmap_base == MAP_FAILED)
    {
        cout << "mmap failed" << endl;
        shmdt(sysv_base);
        _exit(1);
    }

    uint8_t *mmap_tensor = (uint8_t *)mmap_base;

    /*
        Bulk copy from the original unsorted SysV tensor to the
        mmap tensor.
    */
    memcpy(mmap_tensor, sysv_tensor, total_bytes);

    cout << "Copied SysV shared memory to mmap shared memory." << endl;

    /*
        Create the single shared task pipe.

        task_pipe[0] = read end  (inherited by every worker)
        task_pipe[1] = write end (used only by the parent)
    */
    int task_pipe[2];

    if (pipe(task_pipe) == -1)
    {
        cout << "pipe failed" << endl;
        munmap(mmap_base, total_bytes);
        shmdt(sysv_base);
        _exit(1);
    }

    /* Remembered so post-fork code can tell parent from child. */
    pid_t root_pid = getpid();

    pid_t *pids = new pid_t[k];

    for (size_t w = 0; w < k; w++)
    {
        pids[w] = -1;
    }

    /*
        Create k worker processes.

        Every worker inherits the same task_pipe[0], so all workers
        compete for tasks from the same pipe.
    */
    for (size_t w = 0; w < k; w++)
    {
        pid_t pid = fork();

        if (pid < 0)
        {
            /* NOTE(review): exiting here leaves already-forked
               workers running; they terminate on pipe EOF once the
               parent's write end closes at process exit. */
            cout << "fork failed" << endl;
            _exit(1);
        }

        if (pid == 0)
        {
            /*
                Child process.

                The child only reads tasks, so it closes the write
                end of the pipe immediately.
            */
            close(task_pipe[1]);

            /*
                The child does not touch the original SysV memory;
                it sorts only the mmap copy.
            */
            shmdt(sysv_base);

            int task_read_fd = task_pipe[0];

            /* Free the child's copy of the PID table before the
               worker loop (which never returns). */
            delete[] pids;

            worker_loop(
                w + 1,
                task_read_fd,
                mmap_tensor,
                N);

            /* Unreachable: worker_loop calls _exit(0). */
            _exit(0);
        }

        /*
            Only the root parent reaches here: children never return
            from worker_loop. The getpid() check is purely defensive.
        */
        if (getpid() == root_pid)
        {
            pids[w] = pid;

            cout << "Root parent " << getpid()
                 << " created worker " << (w + 1)
                 << " with PID " << pid
                 << endl;
        }
    }

    /*
        Only the original parent writes tasks and waits for workers.
    */
    if (getpid() == root_pid)
    {
        /*
            The parent only writes tasks, so it closes the read end
            of the task pipe.
        */
        close(task_pipe[0]);

        boost::chrono::high_resolution_clock::time_point sort_start;
        boost::chrono::high_resolution_clock::time_point sort_end;

        sort_start = boost::chrono::high_resolution_clock::now();

        /*
            The parent writes all tasks into the shared pipe.

            This replaces:
                - request pipes
                - assignment pipes
                - select()
                - FD_SET()
                - FD_ISSET()

            The pipe itself behaves like a work queue.
        */
        write_all_tasks(
            task_pipe[1],
            L,
            batch_size,
            k);

        /*
            Close the write end.

            Not strictly required for shutdown: the explicit stop
            messages already terminate every worker, and closing the
            last write end would additionally deliver EOF to any
            worker still blocked in read(). It is still correct
            cleanup of the descriptor.
        */
        close(task_pipe[1]);

        /*
            Wait for every worker by PID. The timing below therefore
            includes the full sorting work of all workers.
        */
        wait_all_workers(pids, k);

        sort_end = boost::chrono::high_resolution_clock::now();

        boost::chrono::milliseconds sort_ms =
            boost::chrono::duration_cast<boost::chrono::milliseconds>(
                sort_end - sort_start);

        cout << endl;
        cout << "All workers finished." << endl;
        cout << "Total sorting process duration: "
             << sort_ms.count()
             << " ms"
             << endl;

        /* Check the shared mmap copy, which the workers mutated. */
        if (verify_all(mmap_tensor, L, N))
        {
            cout << "Verification success: all rows are sorted." << endl;
        }
        else
        {
            cout << "Verification failed." << endl;
        }

        print_tensor_slice_preview(
            mmap_tensor,
            L,
            N,
            print_slice,
            print_rows,
            print_cols,
            "Sorted mmap copy");

        print_tensor_slice_preview(
            sysv_tensor,
            L,
            N,
            print_slice,
            print_rows,
            print_cols,
            "Original unsorted SysV shared memory");

        munmap(mmap_base, total_bytes);
        shmdt(sysv_base);

        delete[] pids;
    }

    return 0;
}
