// pipe_scheduler.cpp
// Dynamic scheduling: workers send many requests, the parent sends many assignments.

#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>

#include <unistd.h>
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/mman.h>
#include <sys/wait.h>
#include <sys/select.h>

#include <boost/chrono.hpp>

#include "row.hpp"
#include "sort_algorithms.hpp"

#define SHM_KEY 1234

#ifndef MAP_ANONYMOUS
#define MAP_ANONYMOUS MAP_ANON
#endif

using namespace std;

/*
    Maps the 3-D coordinate tensor[z][i][j] to a flat 1-D offset.

    The buffer holds L matrices, each N x N, stored row-major,
    so the offset is z * N * N + i * N + j.
*/
static inline size_t idx(size_t z, size_t i, size_t j, size_t N)
{
    size_t matrix_offset = z * N * N;
    size_t row_offset = i * N;

    return matrix_offset + row_offset + j;
}

/*
    Reads exactly 'bytes' bytes from file descriptor 'fd',
    looping over short reads and retrying on EINTR.

    Returns:
        1 on success (all bytes read)
        0 on EOF or any read error
*/
int read_full(int fd, void *buf, size_t bytes)
{
    char *p = (char *)buf;
    size_t left = bytes;

    while (left > 0)
    {
        /* read() returns ssize_t; keep the native type instead of
           narrowing through a cast to long. */
        ssize_t r = read(fd, p, left);

        if (r == 0)
        {
            /* EOF before the requested amount arrived. */
            return 0;
        }

        if (r < 0)
        {
            if (errno == EINTR)
            {
                /* Interrupted by a signal: just retry. */
                continue;
            }

            return 0;
        }

        p += (size_t)r;
        left -= (size_t)r;
    }

    return 1;
}

/*
    Writes exactly 'bytes' bytes to file descriptor 'fd',
    looping over short writes and retrying on EINTR.

    Returns:
        1 on success (all bytes written)
        0 on any write error
*/
int write_full(int fd, const void *buf, size_t bytes)
{
    const char *p = (const char *)buf;
    size_t left = bytes;

    while (left > 0)
    {
        /* write() returns ssize_t; keep the native type instead of
           narrowing through a cast to long. */
        ssize_t w = write(fd, p, left);

        if (w < 0)
        {
            if (errno == EINTR)
            {
                /* Interrupted by a signal: just retry. */
                continue;
            }

            return 0;
        }

        p += (size_t)w;
        left -= (size_t)w;
    }

    return 1;
}

/*
    Consumes one ready-for-work request from a worker.

    The worker transmits a single size_t; its value is ignored,
    its arrival alone tells the parent that the worker is ready.

    Returns 1 if a complete request was read, 0 on EOF/error.
*/
int receive_worker_request(int request_fd)
{
    size_t ignored = 0;

    int ok = read_full(request_fd, &ignored, sizeof(ignored));

    return ok ? 1 : 0;
}

/*
    Sends either a real task assignment or a stop message to a worker.

    Message format:
        msg[0] = start slice
        msg[1] = number of slices

    If msg[1] == 0, the worker exits.

    Returns:
        1 if real work was assigned and delivered
        0 if a stop message was sent, or if the write failed.
          On a failed write the reserved slices are returned to the
          pool so another worker can pick them up later.
*/
int send_assignment_or_stop(
    int assignment_fd,
    size_t *next_slice,
    size_t L,
    size_t batch_size)
{
    size_t msg[2];

    if (*next_slice < L)
    {
        size_t count = batch_size;

        /* Clamp the final batch to the remaining slices. */
        if (count > L - *next_slice)
        {
            count = L - *next_slice;
        }

        msg[0] = *next_slice;
        msg[1] = count;

        *next_slice += count;
    }
    else
    {
        msg[0] = 0;
        msg[1] = 0;
    }

    /* The write result used to be ignored: a failed write silently
       dropped msg[1] slices, which then were never sorted. Roll the
       cursor back so those slices get re-assigned, and report the
       worker as having no work (the caller closes its fds). */
    if (!write_full(assignment_fd, msg, sizeof(msg)))
    {
        *next_slice -= msg[1];
        return 0;
    }

    return msg[1] != 0 ? 1 : 0;
}

/*
    Releases the two parent-side pipe descriptors of worker w.

    The parent owns:
        parent_to_child[2*w + 1]  (assignment write end)
        child_to_parent[2*w]      (request read end)

    Each open descriptor is closed and overwritten with -1 so that
    later scans skip this worker.
*/
void close_parent_worker_fds(
    int *parent_to_child,
    int *child_to_parent,
    size_t w)
{
    int *request_end = &child_to_parent[2 * w];
    int *assignment_end = &parent_to_child[2 * w + 1];

    if (*request_end >= 0)
    {
        close(*request_end);
        *request_end = -1;
    }

    if (*assignment_end >= 0)
    {
        close(*assignment_end);
        *assignment_end = -1;
    }
}

/*
    Inside a freshly forked child, closes every pipe descriptor the
    child must not keep.

    Worker w retains only:
        parent_to_child[2*w]      (assignment read end)
        child_to_parent[2*w + 1]  (request write end)
*/
void close_child_unused_fds(
    int *parent_to_child,
    int *child_to_parent,
    size_t k,
    size_t w)
{
    for (size_t j = 0; j < k; j++)
    {
        /* Parent-side ends are never used by any child. */
        close(parent_to_child[2 * j + 1]);
        close(child_to_parent[2 * j]);

        if (j == w)
        {
            continue;
        }

        /* Ends that belong to some other worker. */
        close(parent_to_child[2 * j]);
        close(child_to_parent[2 * j + 1]);
    }
}

/*
    Sorts rows of matrices from start_slice up to end_slice - 1.

    Each slice is one N x N matrix.

    For every row:
        1. Get a pointer to the row.
        2. Wrap the row with row<uint8_t>.
        3. Sort the row using quick_sort.
        4. Copy sorted values back to the mmap tensor.
*/
void sort_slice_range(
    uint8_t *tensor,
    size_t start_slice,
    size_t end_slice,
    size_t N)
{
    for (size_t z = start_slice; z < end_slice; z++)
    {
        for (size_t i = 0; i < N; i++)
        {
            uint8_t *row_ptr = &tensor[idx(z, i, 0, N)];

            row<uint8_t> r(row_ptr, N);

            quick_sort(r);

            for (size_t j = 0; j < N; j++)
            {
                row_ptr[j] = r[j];
            }
        }
    }
}

/*
    Worker process loop.

    Repeats until told to stop:
        1. Announce readiness by writing one size_t to the parent.
        2. Read an assignment message {start, count}.
        3. count == 0 means "no more work": leave the loop.
        4. Otherwise sort the assigned slices and report progress.

    Terminates the process with _exit(0); never returns.
*/
void worker_loop(
    size_t worker_id,
    int read_from_parent,
    int write_to_parent,
    uint8_t *mmap_tensor,
    size_t N)
{
    for (;;)
    {
        size_t ready_token = 1;

        if (!write_full(write_to_parent, &ready_token, sizeof(ready_token)))
        {
            break;
        }

        size_t assignment[2];

        if (!read_full(read_from_parent, assignment, sizeof(assignment)))
        {
            break;
        }

        size_t first = assignment[0];
        size_t count = assignment[1];

        if (count == 0)
        {
            /* Stop message from the parent. */
            break;
        }

        size_t last = first + count;

        sort_slice_range(mmap_tensor, first, last, N);

        cout << "Worker " << worker_id
             << " PID " << getpid()
             << " sorted slices ["
             << first << ", "
             << last << ")"
             << endl;
    }

    close(read_from_parent);
    close(write_to_parent);

    _exit(0);
}

/*
    Checks whether row i of slice z is sorted in non-decreasing order.

    Returns:
        1 if sorted
        0 otherwise
*/
int is_row_sorted(uint8_t *tensor, size_t z, size_t i, size_t N)
{
    /* Row-major offset of tensor[z][i][0] (same formula as idx()). */
    const uint8_t *row_ptr = tensor + (z * N * N + i * N);

    /* std::is_sorted replaces the hand-rolled adjacent-pair loop and
       handles the N <= 1 edge cases identically (trivially sorted). */
    return std::is_sorted(row_ptr, row_ptr + N) ? 1 : 0;
}

/*
    Verifies that every row of every matrix in the mmap copy is sorted.

    Prints the first offending (slice, row) pair when a check fails.

    Returns:
        1 if all rows are sorted
        0 otherwise
*/
int verify_all(uint8_t *tensor, size_t L, size_t N)
{
    for (size_t slice = 0; slice < L; slice++)
    {
        for (size_t r = 0; r < N; r++)
        {
            if (is_row_sorted(tensor, slice, r, N))
            {
                continue;
            }

            cout << "Verification failed at slice "
                 << slice << ", row " << r << endl;

            return 0;
        }
    }

    return 1;
}

/*
    Prints a small preview window of one N x N slice instead of the
    whole matrix (e.g. the full 640 x 640 would flood the terminal).

    Example:
        print_rows = 8, print_cols = 16
    shows only the top-left 8 x 16 corner of the selected slice.

    An out-of-range z falls back to slice 0; the window is clamped
    to the matrix dimensions. Does nothing for an empty tensor.
*/
void print_tensor_slice_preview(
    uint8_t *tensor,
    size_t L,
    size_t N,
    size_t z,
    size_t print_rows,
    size_t print_cols,
    const char *title)
{
    if (L == 0 || N == 0)
    {
        return;
    }

    if (z >= L)
    {
        z = 0;
    }

    size_t rows = (print_rows > N) ? N : print_rows;
    size_t cols = (print_cols > N) ? N : print_cols;

    cout << endl;
    cout << title << endl;
    cout << "Slice z = " << z
         << ", preview "
         << rows << " x " << cols
         << " from full "
         << N << " x " << N
         << " matrix"
         << endl;

    for (size_t i = 0; i < rows; i++)
    {
        for (size_t j = 0; j < cols; j++)
        {
            cout << setw(4)
                 << (unsigned int)tensor[idx(z, i, j, N)];
        }

        /* Ellipsis marks truncated columns. */
        if (cols < N)
        {
            cout << " ...";
        }

        cout << endl;
    }

    /* Ellipsis marks truncated rows. */
    if (rows < N)
    {
        cout << "..." << endl;
    }
}

/*
    Main parent scheduler process.

    Responsibilities:
        1. Attach to SysV shared memory created by generator8.
        2. Allocate mmap shared memory.
        3. Copy SysV memory into mmap memory.
        4. Create k worker processes.
        5. Give every worker one initial batch.
        6. Continue dynamic request-based scheduling.
        7. Measure total sorting/scheduling time with boost::chrono.
        8. Verify sorted mmap copy.
        9. Print sorted mmap slice and original unsorted SysV slice.
*/
int main(int argc, char **argv)
{
    if (argc < 4 || argc > 8)
    {
        cout << "Usage: "
             << argv[0]
             << " L N k [batch_size] [print_slice] [print_rows] [print_cols]"
             << endl;

        cout << "Example: "
             << argv[0]
             << " 100 640 4 5"
             << endl;

        cout << "Example with print options: "
             << argv[0]
             << " 100 640 4 5 0 8 16"
             << endl;

        return 1;
    }

    size_t L = (size_t)atoll(argv[1]);
    size_t N = (size_t)atoll(argv[2]);
    size_t k = (size_t)atoll(argv[3]);

    size_t batch_size = 1;
    size_t print_slice = 0;
    size_t print_rows = 8;
    size_t print_cols = 16;

    if (argc >= 5)
    {
        batch_size = (size_t)atoll(argv[4]);
    }

    if (argc >= 6)
    {
        print_slice = (size_t)atoll(argv[5]);
    }

    if (argc >= 7)
    {
        print_rows = (size_t)atoll(argv[6]);
    }

    if (argc >= 8)
    {
        print_cols = (size_t)atoll(argv[7]);
    }

    if (L == 0 || N == 0 || k == 0 || batch_size == 0)
    {
        cout << "Invalid arguments." << endl;
        return 1;
    }

    if (print_rows == 0)
    {
        print_rows = 8;
    }

    if (print_cols == 0)
    {
        print_cols = 16;
    }

    if (print_slice >= L)
    {
        cout << "Requested print_slice is outside range. Using slice 0." << endl;
        print_slice = 0;
    }

    size_t total_elements = L * N * N;
    size_t total_bytes = total_elements * sizeof(uint8_t);

    int shmid = shmget(SHM_KEY, total_bytes, 0666);

    if (shmid == -1)
    {
        cout << "shmget failed. Did you run generator8 first?"<< endl;
        _exit(1);
    }

    void *sysv_base = shmat(shmid, NULL, 0);

    if (sysv_base == (void *)-1)
    {
        cout << "shmat failed"<< endl;
        _exit(1);
    }

    uint8_t *sysv_tensor = (uint8_t *)sysv_base;

    void *mmap_base = mmap(
        NULL,
        total_bytes,
        PROT_READ | PROT_WRITE,
        MAP_SHARED | MAP_ANONYMOUS,
        -1,
        0);

    if (mmap_base == MAP_FAILED)
    {
        cout << "mmap failed"<<endl;
        shmdt(sysv_base);
        return 1;
    }

    uint8_t *mmap_tensor = (uint8_t *)mmap_base;
    /*Fast robust copy*/
    memcpy(mmap_tensor, sysv_tensor, total_bytes);

    cout << "Copied SysV shared memory to mmap shared memory." << endl;

    int *parent_to_child = new int[2 * k];
    int *child_to_parent = new int[2 * k];
/*
                 assignment pipe
        parent --------------------> worker
        writes                       reads
        parent_to_child[2*w + 1]     parent_to_child[2*w]


                 request pipe
        worker --------------------> parent
        writes                       reads
        child_to_parent[2*w + 1]     child_to_parent[2*w]
*/
    int *pids = new pid_t[k];

    for (size_t i = 0; i < 2 * k; i++)
    {
        parent_to_child[i] = -1;
        child_to_parent[i] = -1;
    }

    for (size_t i = 0; i < k; i++)
    {
        if (pipe(parent_to_child + 2 * i) == -1)
        {
            perror("pipe parent_to_child failed");
            return 1;
        }

        if (pipe(child_to_parent + 2 * i) == -1)
        {
            perror("pipe child_to_parent failed");
            return 1;
        }
    }
    int root_pid=getpid();

    for (size_t w = 0; w < k; w++)
    {
        pid_t pid = fork();

        if (pid < 0)
        {
            perror("fork failed");
            return 1;
        }

        if (pid == 0)
        {
            close_child_unused_fds(
                parent_to_child,
                child_to_parent,
                k,
                w);

            /*
                The child does not need the original SysV shared memory.
                It only sorts the mmap copy.
            */
            shmdt(sysv_base);
	    /*
            Save the descriptors BEFORE delete[].
            */
             int read_fd = parent_to_child[2 * w];
             int write_fd = child_to_parent[2 * w + 1];
             cout << "Worker " << (w + 1)
            << " started. PID = " << getpid()
            << ", read_fd = " << read_fd
            << ", write_fd = " << write_fd
            << endl;

            delete[] parent_to_child;
            delete[] child_to_parent;
            delete[] pids;

            worker_loop(
                w + 1,
		read_fd,
                //parent_to_child[2 * w],
                //child_to_parent[2 * w + 1],
		write_fd,
                mmap_tensor,
                N);
        }

	if (getpid()==root_pid) {
		 cout << "Root parent " << getpid()
         << " created worker " << (w + 1)
         << " with PID " << pid
         << endl;
        	pids[w] = pid;
	} 
        
        if (pid >0 && getpid()!=root_pid)
		_exit(0);
    }

    /*
        Parent closes pipe ends that it does not use.

        Parent writes assignments to parent_to_child[2*w + 1].
        Parent reads requests from child_to_parent[2*w].
    */
    if (getpid()==root_pid) {
      for (size_t w = 0; w < k; w++)
      {
        close(parent_to_child[2 * w]);
        close(child_to_parent[2 * w + 1]);
      }
   }//EOF root getpid
      boost::chrono::high_resolution_clock::time_point sort_start;
      boost::chrono::high_resolution_clock::time_point sort_end;

      sort_start = boost::chrono::high_resolution_clock::now();

      size_t next_slice = 0;
      size_t active_workers = k;

    /*
        Fair initial assignment phase.

        This is the important fix.

        Every worker must first send one request.
        The parent consumes that first request and gives that worker
        one initial batch before the dynamic scheduling loop starts.
    */
    for (size_t w = 0; w < k; w++)
    {
        if (!receive_worker_request(child_to_parent[2 * w]))
        {
            close_parent_worker_fds(
                parent_to_child,
                child_to_parent,
                w);

            active_workers--;
            continue;
        }

        int worker_has_work = send_assignment_or_stop(
            parent_to_child[2 * w + 1],
            &next_slice,
            L,
            batch_size);

        if (!worker_has_work)
        {
            close_parent_worker_fds(
                parent_to_child,
                child_to_parent,
                w);

            active_workers--;
        }
    }

    /*
        Dynamic scheduling phase.

        After the fair initial assignment, whichever worker finishes first
        can request more work.
    */
    while (active_workers > 0)
    {
        /*
    readfds is a set of file descriptors.

    We use it to tell select():

        "Watch these pipe read-ends.
         Wake up when at least one worker sends a request."

    In this program, every worker has one pipe for messages:

        child  --->  parent

    The parent reads from:

        child_to_parent[2 * w]

    The worker writes to:

        child_to_parent[2 * w + 1]
     */
        fd_set readfds;
      /*
    FD_ZERO clears the set.

    Before adding file descriptors with FD_SET, the set must be empty.

    Think of this as:

        readfds = empty set
    */
        FD_ZERO(&readfds);

        int maxfd = -1;

        for (size_t w = 0; w < k; w++)
        {
            int fd = child_to_parent[2 * w];

            if (fd >= 0)
            {
		/*FD_SET(fd, &readfds) adds a file descriptor to the set that select() watches.*/
                FD_SET(fd, &readfds);

                if (fd > maxfd)
                {
                    maxfd = fd;
                }
            }
        }

        if (maxfd < 0)
        {
            break;
        }
	/*
	int select(
    	int nfds,
    	fd_set *readfds,
    	fd_set *writefds,
    	fd_set *exceptfds,
    	struct timeval *timeout NULL, select() waits forever until something happens.
	);
	checks file descriptors from:
	0 up to maxfd
	*/
	struct timeval timeout;
	timeout.tv_sec = 0;
	timeout.tv_usec = 10000;   // 10 ms = 10,000 microseconds
        int ready = select(maxfd + 1, &readfds, NULL, NULL, &timeout);
        //int ready = select(maxfd + 1, &readfds, NULL, NULL, NULL);

        if (ready < 0)
        {
            if (errno == EINTR)
            {
                continue;
            }

            cout << "select failed" << endl ;
	    //_exit(1);
            break;
        }

        for (size_t w = 0; w < k; w++)
        {
            int request_fd = child_to_parent[2 * w];

            if (request_fd < 0)
            {
                continue;
            }

            if (FD_ISSET(request_fd, &readfds))
            {
                if (!receive_worker_request(request_fd))
                {
                    close_parent_worker_fds(
                        parent_to_child,
                        child_to_parent,
                        w);

                    active_workers--;
                    continue;
                }

                int worker_has_work = send_assignment_or_stop(
                    parent_to_child[2 * w + 1],
                    &next_slice,
                    L,
                    batch_size);

                if (!worker_has_work)
                {
                    close_parent_worker_fds(
                        parent_to_child,
                        child_to_parent,
                        w);

                    active_workers--;
                }
            }
        }
    }
/*Wain MINIMA */
  //  while (wait(NULL) > 0)
  // {
  //  }
/*Wait Maxima*/
/*
    Wait for all worker processes using their stored PIDs.

    pids[w] contains the PID returned by fork() for worker w.
*/
if (getpid()==root_pid)
for (size_t w = 0; w < k; w++)
{
    int status = 0;

    if (pids[w] <= 0)
    {
        continue;
    }

    while (waitpid(pids[w], &status, 0) == -1)
    {
        if (errno == EINTR)
        {
            continue;
        }

        perror("waitpid failed");
        break;
    }

    if (WIFEXITED(status))
    {
        cout << "Worker " << (w + 1)
             << " exited with code "
             << WEXITSTATUS(status)
             << endl;
    }
    else if (WIFSIGNALED(status))
    {
        cout << "Worker " << (w + 1)
             << " killed by signal "
             << WTERMSIG(status)
             << endl;
    }
}
/*EOF Wait Maxima*/    
    sort_end = boost::chrono::high_resolution_clock::now();

    boost::chrono::milliseconds sort_ms =
        boost::chrono::duration_cast<boost::chrono::milliseconds>(
            sort_end - sort_start);

    cout << endl;
    cout << "All workers finished." << endl;
    cout << "Total sorting process duration: "
         << sort_ms.count()
         << " ms"
         << endl;

    if (verify_all(mmap_tensor, L, N))
    {
        cout << "Verification success: all rows are sorted." << endl;
    }
    else
    {
        cout << "Verification failed." << endl;
    }

    print_tensor_slice_preview(
        mmap_tensor,
        L,
        N,
        print_slice,
        print_rows,
        print_cols,
        "Sorted mmap copy");

    print_tensor_slice_preview(
        sysv_tensor,
        L,
        N,
        print_slice,
        print_rows,
        print_cols,
        "Original unsorted SysV shared memory");

    munmap(mmap_base, total_bytes);
    shmdt(sysv_base);

    delete[] parent_to_child;
    delete[] child_to_parent;
    delete[] pids;

    return 0;
}
