/* Chunk-stealing scheduler for sorting the rows of every slice of a tensor.

Chunk stealing combines:
 - chunk scheduling: threads process groups (chunks) of consecutive slices;
 - work stealing: idle threads steal chunks from other threads.
Each thread has a local deque of chunks. It takes work from the front of its
own deque; when that deque is empty, it steals a chunk from the back of
another thread's deque.

This reduces lock contention, synchronization overhead, and load imbalance. */

#include <boost/thread.hpp>
#include <boost/chrono.hpp>
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <vector>
#include <deque>
#include <cstdint>

#include "row.hpp"
#include "sort_algorithms.hpp"

using namespace std;

uint8_t *tensor;

int n;
int slices;
int k;

int CHUNK = 2;

vector<deque<int> > queues;
vector<boost::mutex*> qmutex;

// --------------------------------
inline int idx(int z, int i, int j)
{
    return z * n * n + i * n + j;
}

// --------------------------------
void sort_slice_rows(int tid)
{
    while(true)
    {
        int start = -1;

        // try own queue first
        qmutex[tid]->lock();

        if(!queues[tid].empty())
        {
            start = queues[tid].front();
            queues[tid].pop_front();
        }

        qmutex[tid]->unlock();

        // attempt stealing
        if(start == -1)
        {
            for(int v = 0; v < k; v++)
            {
                if(v == tid)
                    continue;

                qmutex[v]->lock();

                if(!queues[v].empty())
                {
                    start = queues[v].back();
                    queues[v].pop_back();
                    qmutex[v]->unlock();
                    break;
                }

                qmutex[v]->unlock();
            }
        }

        // nothing left anywhere
        if(start == -1)
            return;

        int end = start + CHUNK;

        if(end > slices)
            end = slices;

        for(int slice = start; slice < end; slice++)
        {
            cout << "Thread " << tid
                 << " START slice " << slice << endl;

            for(int i = 0; i < n; i++)
            {
                uint8_t* row_ptr = &tensor[idx(slice, i, 0)];
                row<uint8_t> r(row_ptr, n);
                quick_sort(r);
		for(int j = 0; j < n; j++)
    			row_ptr[j] = r[j];
            }

            cout << "Thread " << tid
                 << " FINISH slice " << slice << endl;
        }
    }
}

// Dumps the whole tensor to stdout: each slice gets a "Slice <z>:" header,
// followed by its n rows printed as tab-separated integer values.
void print_tensor()
{
    for(int s = 0; s < slices; ++s)
    {
        cout << "\nSlice " << s << ":\n";

        for(int y = 0; y < n; ++y)
        {
            // Elements of one row are contiguous (see idx), so walk them by pointer.
            const uint8_t* cells = &tensor[idx(s, y, 0)];
            for(int x = 0; x < n; ++x)
                cout << static_cast<int>(cells[x]) << "\t";
            cout << endl;
        }
    }
}

// --------------------------------
int main()
{
    n = 6;
    slices = 500;
    k = 4;

    srandom(time(NULL));

    // allocate tensor
    tensor = new uint8_t[slices * n * n];

    // fill tensor
    for(int z = 0; z < slices; z++)
        for(int i = 0; i < n; i++)
            for(int j = 0; j < n; j++)
                tensor[idx(z, i, j)] = random() % 256;

    //cout << "Initial tensor:\n";
    //print_tensor();

    // create local chunk queues
    queues.resize(k);
    qmutex.resize(k);

    for(int i = 0; i < k; i++)
        qmutex[i] = new boost::mutex;

    for(int s = 0; s < slices; s += CHUNK)
    {
        int owner = (s / CHUNK) % k;
        queues[owner].push_back(s);
    }

    // create workers
    boost::thread_group workers;

    boost::chrono::high_resolution_clock::time_point t1 =
        boost::chrono::high_resolution_clock::now();

    for(int tid = 0; tid < k; tid++)
        workers.create_thread(boost::bind(sort_slice_rows, tid));

    workers.join_all();

    boost::chrono::high_resolution_clock::time_point t2 =
        boost::chrono::high_resolution_clock::now();

    boost::chrono::milliseconds elapsed =
        boost::chrono::duration_cast<boost::chrono::milliseconds>(t2 - t1);
	
    // print results
    cout << "\nSorted tensor\n";
    print_tensor();

    cout << "\nTotal thread execution time: "
         << elapsed.count() << " ms\n";

    delete[] tensor;

    // free mutexes
    for(int i = 0; i < k; i++)
        delete qmutex[i];

    return 0;
}
