#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

#include <boost/chrono.hpp>
#include <boost/thread.hpp>

#include "row.hpp"
#include "sort_algorithms.hpp"

using namespace std;

// Shared state for the slice-sorting workers.
//
// The tensor is one contiguous buffer laid out [slices][n][n] in row-major
// order; it is set up and filled by main() before any worker starts.
uint8_t* tensor = nullptr;

int n = 0;       // side length of every square slice (n x n)
int slices = 0;  // number of slices in the tensor
int k = 0;       // number of worker threads to launch

// Index of the next slice a worker may claim; workers read and advance it
// only while holding mtx.
int next_slice = 0;

// Protects next_slice so each slice is handed to exactly one worker.
boost::mutex mtx;

// row-major index
inline int idx(int z, int i, int j)
{
    return z * n * n + i * n + j;
}

// --------------------------------
void sort_slice_rows_dynamic(int tid)
{
    while(true)
    {
        int slice;

        // critical section: get next slice
        mtx.lock();

        if(next_slice >= slices)
        {
            mtx.unlock();
            return;
        }

        slice = next_slice++;

        mtx.unlock();


        cout << "Thread " << tid << " START slice " << slice << endl;

        for(int i=0; i<n; i++)
        {
            uint8_t* row_ptr = &tensor[idx(slice, i, 0)];
            row<uint8_t> r(row_ptr, n);
            quick_sort(r);
	     for(int j = 0; j < n; j++)
    		row_ptr[j] = r[j];
        }

        cout << "Thread " << tid << " FINISH slice " << slice << endl;

        // update finished counter
        /*mtx.lock();
		    finished_slices++;
        mtx.unlock();*/
		    }
}

void print_tensor()
{
    for(int z=0; z<slices; z++)
    {
        cout << "\nSlice " << z << ":\n";
        for(int i=0; i<n; i++)
        {
            for(int j=0; j<n; j++)
                cout << (int)tensor[idx(z, i, j)] << "\t";
            cout << endl;
        }
    }
}


// --------------------------------
int main()
{
    n = 6;        // frame size (n x n)
    slices = 500;  // number of tensor slices
    k = 4;        // worker threads

    srandom(time(NULL));

    // allocate flat row-major tensor
    tensor = new uint8_t[slices * n * n];


    // fill tensor with grayscale values
    for(int z=0; z<slices; z++)
        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                tensor[idx(z, i, j)] = random()%256;

    //cout << "Initial tensor:\n";
    //print_tensor();

				    // --------------------------------
    // create thread pool
    // --------------------------------
    boost::thread_group workers;

    boost::chrono::high_resolution_clock::time_point t1 =
        boost::chrono::high_resolution_clock::now();

    for(int tid=0; tid<k; tid++)
        workers.create_thread(boost::bind(sort_slice_rows_dynamic, tid));


    // --------------------------------
    // poll until all slices processed
    // --------------------------------
    /*while(true)
    {
        mtx.lock();
        if(finished_slices >= slices)
        {
            mtx.unlock();
            break;
        }
        mtx.unlock();

        boost::this_thread::sleep_for(boost::chrono::milliseconds(1));
    }*/


    // join all threads
    workers.join_all();

    boost::chrono::high_resolution_clock::time_point t2 =
        boost::chrono::high_resolution_clock::now();

    boost::chrono::milliseconds elapsed =
        boost::chrono::duration_cast<boost::chrono::milliseconds>(t2 - t1);


    // --------------------------------
    // print results
    // --------------------------------
    cout << "\nSorted tensor:\n";
    print_tensor();

    cout << "\nTotal thread execution time: "
         << elapsed.count() << " ms\n";

    delete[] tensor;

    return 0;
}
