#ifndef SCHED_PROL_BOUND_H
#define SCHED_PROL_BOUND_H

#include <opencv2/opencv.hpp>
#include <iostream>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <vector>
#include <unistd.h>
#include <sys/wait.h>
#include <sys/mman.h>
#include <sys/resource.h>
#include <sched.h>
#include <cerrno>
#include <boost/chrono.hpp>

#include "row.hpp"
#include "sort_algorithms.hpp"

using namespace std;

// =========================
// CPU affinity helper
//
// Restricts the caller (pid argument 0 of sched_setaffinity) to the single
// CPU `cpuId`. Logs the binding on stdout when it succeeds and reports the
// system error through perror() when it does not.
//
// Returns true on success, false otherwise.
// =========================
inline bool bind_thread_to_cpu(int cpuId)
{
    // Affinity mask containing only the requested CPU.
    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(cpuId, &mask);

    // Failure path first: nothing may run between the call and perror(),
    // otherwise errno could be overwritten.
    if (sched_setaffinity(0, sizeof(cpu_set_t), &mask) != 0)
    {
        perror("sched_setaffinity failed");
        return false;
    }

    std::cout << "[PID " << getpid() << "] bound to CPU " << cpuId << std::endl;
    return true;
}

// =========================
// Shared video access helper
//
// Returns a writable reference to sample (frame f, row i, column j,
// channel k) of a flat buffer laid out frame by frame, then row by row,
// then column by column, with the C channel values of a pixel adjacent.
// No bounds checking is performed.
// =========================
inline uint8_t& PIXEL(uint8_t* video, int f, int i, int j, int k,
                      int H, int W, int C)
{
    // Widen to size_t before the first multiplication so the offset is
    // computed in size_t arithmetic throughout.
    const size_t rowIndex    = static_cast<size_t>(f) * H + i;
    const size_t pixelIndex  = rowIndex * W + j;
    const size_t sampleIndex = pixelIndex * C + k;

    return video[sampleIndex];
}

// =========================
// Sorting algorithm selector
//
// Enumerators are numbered consecutively starting at 0.
// =========================
enum SortAlgorithm
{
    USE_QUICK = 0,   // 0: quick sort
    USE_BUBBLE,      // 1: bubble sort
    USE_INSERTION,   // 2: insertion sort
    USE_SELECTION    // 3: selection sort
};

// Returns a printable name for `algo`, or "unknown" when the value does not
// match any SortAlgorithm enumerator. The returned pointer refers to a
// string literal and must not be freed.
inline const char* sort_algorithm_name(SortAlgorithm algo)
{
    struct NameEntry
    {
        SortAlgorithm id;
        const char*   label;
    };

    static const NameEntry kKnownAlgorithms[] = {
        { USE_QUICK,     "quick_sort"     },
        { USE_BUBBLE,    "bubble_sort"    },
        { USE_INSERTION, "insertion_sort" },
        { USE_SELECTION, "selection_sort" },
    };

    for (const NameEntry& entry : kKnownAlgorithms)
    {
        if (entry.id == algo)
            return entry.label;
    }

    return "unknown";
}

// Sorts `r` with the routine selected by `algo`.
// USE_QUICK and any value that is not a known enumerator both use quick_sort.
inline void sort_row_by_algorithm(row<uint8_t>& r, SortAlgorithm algo)
{
    if (algo == USE_BUBBLE)
    {
        bubble_sort(r);
        return;
    }

    if (algo == USE_INSERTION)
    {
        insertion_sort(r);
        return;
    }

    if (algo == USE_SELECTION)
    {
        selection_sort(r);
        return;
    }

    // USE_QUICK, or an unrecognised selector.
    quick_sort(r);
}

// =========================
// Worker task
// Each worker processes frames:
// worker_id, worker_id + K, worker_id + 2K, ...
// =========================
inline void process_video_rows(uint8_t* video,
                               int worker_id,
                               int actualFrames,
                               int H,
                               int W,
                               int C,
                               int K,
                               long numCPUs,
                               SortAlgorithm algo)
{
    pid_t mypid = getpid();
    pid_t parentpid = getppid();

    int cpuId = (worker_id + 1) % ((numCPUs > 0) ? numCPUs : 1);
    bind_thread_to_cpu(cpuId);

    cout << "[Worker " << worker_id << "] STARTED"
         << " | PID = " << mypid
         << " | Parent PID = " << parentpid
         << " | nice = " << getpriority(PRIO_PROCESS, 0)
         << " | CPU = " << cpuId
         << " | algorithm = " << sort_algorithm_name(algo)
         << endl;

    boost::chrono::high_resolution_clock::time_point worker_t1 =
        boost::chrono::high_resolution_clock::now();

    cout << "[Worker " << worker_id << "] Assigned frames: ";
    bool first = true;
    for (int ff = worker_id; ff < actualFrames; ff += K)
    {
        if (!first) cout << ", ";
        cout << ff;
        first = false;
    }
    cout << endl;

    int framesProcessed = 0;
    int rowsProcessed = 0;

    for (int ff = worker_id; ff < actualFrames; ff += K)
    {
        cout << "[Worker " << worker_id << " | PID " << mypid
             << "] Processing frame " << ff << "..." << endl;

        framesProcessed++;

        for (int i = 0; i < H; i++)
        {
            uint8_t* grayRow = new uint8_t[W];

            for (int j = 0; j < W; j++)
            {
                grayRow[j] = PIXEL(video, ff, i, j, 0, H, W, C);
            }

            row<uint8_t> gray(grayRow, W);
            sort_row_by_algorithm(gray, algo);

            for (int j = 0; j < W; j++)
            {
                PIXEL(video, ff, i, j, 0, H, W, C) = gray[j];
                PIXEL(video, ff, i, j, 1, H, W, C) = gray[j];
                PIXEL(video, ff, i, j, 2, H, W, C) = gray[j];
            }

            rowsProcessed++;
        }

        cout << "[Worker " << worker_id << " | PID " << mypid
             << "] Finished frame " << ff << "." << endl;
    }

    boost::chrono::high_resolution_clock::time_point worker_t2 =
        boost::chrono::high_resolution_clock::now();

    boost::chrono::microseconds worker_us =
        boost::chrono::duration_cast<boost::chrono::microseconds>(worker_t2 - worker_t1);

    cout << "[Worker " << worker_id << "] FINISHED"
         << " | PID = " << mypid
         << " | Frames processed = " << framesProcessed
         << " | Rows processed = " << rowsProcessed
         << " | Processing time = " << worker_us.count() << " microseconds"
         << endl;
}

// =========================
// Prolific scheduler
//
// K = fork depth
// total workers = 2^K
//
// Every process forks at every level.
// Each final process gets a unique worker_id: at each level the child
// appends bit 1 and the parent appends bit 0 to the ID built so far.
// The original root returns to main.cpp after waiting,
// while non-root processes exit.
//
// Parameters:
//   video        - pixel buffer handed to process_video_rows(). Workers
//                  run in separate processes, so the buffer presumably
//                  lives in memory shared across fork() (e.g. a MAP_SHARED
//                  mapping) - confirm against the caller.
//   actualFrames - number of frames stored in `video`
//   H, W, C      - frame height, width and channel count
//   K            - fork depth, must be in [0, 30]; other values are
//                  reported on stderr and nothing is processed
//   numCPUs      - CPU count forwarded to the workers for affinity
//   algo         - sorting algorithm applied to every row
//
// If fork() fails, the affected process stops forking and processes, one
// after another, every worker ID its missing descendants would have
// handled, so no frame is skipped. Such a non-root process still exits
// with status 1 so the failure shows up in its parent's log.
// =========================
inline void run_video_scheduler(uint8_t* video,
                                int actualFrames,
                                int H,
                                int W,
                                int C,
                                int K,
                                long numCPUs,
                                SortAlgorithm algo = USE_QUICK)
{
    // 1 << K must be a valid, positive int: shifting by a negative amount
    // or into the sign bit is not.
    const int kMaxForkDepth = 30;
    if (K < 0 || K > kMaxForkDepth)
    {
        std::cerr << "run_video_scheduler: fork depth K = " << K
                  << " is outside the supported range [0, "
                  << kMaxForkDepth << "]" << std::endl;
        return;
    }

    const pid_t original_root = getpid();

    int worker_id = 0;         // ID bits collected so far, most significant first
    int leaf_count = 1;        // number of final worker IDs this process covers
    bool fork_failed = false;
    std::vector<pid_t> my_children;

    for (int level = 0; level < K; ++level)
    {
        // Anything still buffered would be copied into the child and
        // printed twice.
        std::cout.flush();

        const pid_t pid = fork();

        if (pid < 0)
        {
            perror("fork failed");

            // This process cannot split any further, so it keeps every
            // worker ID of the subtree it was supposed to create.
            fork_failed = true;
            leaf_count = 1 << (K - level);
            break;
        }

        if (pid == 0)
        {
            // The inherited list holds the parent's children (this
            // process's siblings), which this process must not wait for.
            my_children.clear();

            // Child branch adds bit 1
            worker_id = worker_id * 2 + 1;

            std::cout << "[PID " << getpid()
                      << "] created at level " << level
                      << " with worker_id = " << worker_id
                      << std::endl;
        }
        else
        {
            // Parent branch adds bit 0
            my_children.push_back(pid);
            worker_id = worker_id * 2;
        }
    }

    const int totalWorkers = 1 << K;

    // Normal case: leaf_count == 1 and exactly worker_id is processed.
    // After a fork failure: all IDs sharing the prefix built so far.
    const int first_leaf = worker_id * leaf_count;
    for (int offset = 0; offset < leaf_count; ++offset)
    {
        const int id = first_leaf + offset;

        // IDs at or beyond the frame count own no frames; skip the rest
        // when covering for missing workers.
        if (fork_failed && id >= actualFrames)
            break;

        process_video_rows(video, id, actualFrames, H, W, C,
                           totalWorkers, numCPUs, algo);
    }

    for (pid_t child : my_children)
    {
        int status = 0;
        pid_t finished;

        // A signal may interrupt the wait; retry so the child is reaped
        // and its work is complete before this process moves on.
        do
        {
            finished = waitpid(child, &status, 0);
        } while (finished < 0 && errno == EINTR);

        if (finished < 0)
        {
            perror("waitpid failed");
            continue;
        }

        std::cout << "[PID " << getpid() << "] Child PID "
                  << finished << " finished";

        if (WIFEXITED(status))
            std::cout << " with exit code " << WEXITSTATUS(status);
        else if (WIFSIGNALED(status))
            std::cout << " due to signal " << WTERMSIG(status);

        std::cout << "." << std::endl;
    }

    // Only children exit. Root returns to main.cpp.
    // _exit() skips stdio flushing, which is why all output above ends in
    // std::endl.
    if (getpid() != original_root)
        _exit(fork_failed ? 1 : 0);
}

#endif