/*
Guided scheduler:

Threads do not receive a fixed number of slices in advance.

Instead, whenever a thread becomes free, it enters a critical section
and checks how many slices remain unassigned.

It then takes a chunk whose size is based on the remaining work:

    chunk = remaining / k

where:
- remaining = slices - next_slice
- k = number of threads

This means:

- at the beginning, when many slices remain, threads take larger chunks
- later, when fewer slices remain, chunk sizes become smaller
- near the end, each thread may take just 1 slice

This reduces scheduling overhead compared to fully dynamic scheduling
(one slice at a time), while still improving load balance compared to
pure static scheduling.

Each claimed slice is processed row by row:
- get pointer to the row in row-major memory
- wrap it in row<uint8_t>
- apply quick_sort on that row
*/

#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

#include <boost/chrono.hpp>
#include <boost/thread.hpp>

#include "row.hpp"
#include "sort_algorithms.hpp"

using namespace std;

uint8_t *tensor;

int n;
int slices;
int k;

int next_slice = 0;
boost::mutex mtx;

// --------------------------------
inline int idx(int z, int i, int j)
{
    return z * n * n + i * n + j;
}

// --------------------------------
void print_tensor()
{
    for(int z = 0; z < slices; z++)
    {
        cout << "\nSlice " << z << ":\n";
        for(int i = 0; i < n; i++)
        {
            for(int j = 0; j < n; j++)
                cout << (int)tensor[idx(z, i, j)] << "\t";
            cout << endl;
        }
    }
}

// --------------------------------
// Guided scheduling
// --------------------------------
void sort_rows_guided(int tid)
{
    while(true)
    {
        int start, end;

        mtx.lock();

        if(next_slice >= slices)
        {
            mtx.unlock();
            return;
        }

        int remaining = slices - next_slice;
        int chunk = remaining / k;

        if(chunk < 1)
            chunk = 1;

        start = next_slice;
        next_slice += chunk;
        end = next_slice;

        if(end > slices)
            end = slices;

        mtx.unlock();

        for(int z = start; z < end; z++)
        {
            cout << "Thread " << tid
                 << " GUIDED slice " << z << endl;

            for(int i = 0; i < n; i++)
            {
                uint8_t* row_ptr = &tensor[idx(z, i, 0)];
                row<uint8_t> r(row_ptr, n);
                quick_sort(r);
		 for(int j = 0; j < n; j++)
                   row_ptr[j] = r[j];
            }
        }
    }
}

// --------------------------------
int main()
{
    n = 6;
    slices = 500;
    k = 4;

    srandom(time(NULL));

    tensor = new uint8_t[slices * n * n];

    for(int z = 0; z < slices; z++)
        for(int i = 0; i < n; i++)
            for(int j = 0; j < n; j++)
                tensor[idx(z, i, j)] = random() % 256;

    //cout << "Initial tensor:\n";
    //print_tensor();

    boost::thread_group workers;

    boost::chrono::high_resolution_clock::time_point t1 =
        boost::chrono::high_resolution_clock::now();

    for(int tid = 0; tid < k; tid++)
        workers.create_thread(boost::bind(sort_rows_guided, tid));

    workers.join_all();

    boost::chrono::high_resolution_clock::time_point t2 =
        boost::chrono::high_resolution_clock::now();

    boost::chrono::milliseconds elapsed =
        boost::chrono::duration_cast<boost::chrono::milliseconds>(t2 - t1);

    cout << "\nSorted tensor:\n";
    print_tensor();

    cout << "\nTotal execution time: "
         << elapsed.count() << " ms\n";

    delete[] tensor;

    return 0;
}
