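/*
Chunked dynamic scheduling:
each worker thread takes CHUNK slices at a time instead of one,
processes (row-sorts) them,
and repeats until no slices are left.
*/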
#include <boost/thread.hpp>
#include <boost/chrono.hpp>
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <cstdint>

#include "row.hpp"
#include "sort_algorithms.hpp"

using namespace std;

// flat row-major tensor: [slices][n][n]
// Allocated and released in main(); workers sort its rows in place.
uint8_t *tensor;

// Side length of each square slice (rows == columns == n).
int n;
// Number of slices in the tensor.
int slices;
// Number of worker threads created in main().
int k;

// Index of the first slice that has not yet been handed to a worker.
// Read and advanced only while holding mtx.
int next_slice = 0;
// Number of consecutive slices a worker claims per trip through the
// critical section.
int CHUNK = 3;

// Serializes chunk assignment (protects next_slice).
boost::mutex mtx;

// Flat offset of element (z, i, j) in the row-major [slices][n][n] tensor:
// z selects the slice, i the row inside it, j the column inside that row.
inline int idx(int z, int i, int j)
{
    const int slice_offset = z * n * n;  // start of slice z
    const int row_offset = i * n;        // start of row i within the slice
    return slice_offset + row_offset + j;
}


// --------------------------------
void sort_slice_rows_chunk(int tid)
{
    while(true)
    {
        int start;
        int end;

        // critical section: assign chunk
        mtx.lock();

        if(next_slice >= slices)
        {
            mtx.unlock();
            return;
        }

        start = next_slice;
        next_slice += CHUNK;

        mtx.unlock();

        end = start + CHUNK;

        if(end > slices)
            end = slices;

        for(int slice = start; slice < end; slice++)
        {
            cout << "Thread " << tid << " START slice " << slice << endl;

            for(int i=0; i<n; i++)
            {
                uint8_t* row_ptr = &tensor[idx(slice, i, 0)];
		//cout << "Before: ";
		//for(int j=0; j<n; j++) cout << (int)row_ptr[j] << " ";
		//cout << endl;

                row<uint8_t> r(row_ptr, n);
                quick_sort(r);
		for(int j = 0; j < n; j++)
    			row_ptr[j] = r[j];
		//cout << "After:  ";
		//for(int j=0; j<n; j++) cout << (int)row_ptr[j] << " ";
		//cout << endl;
            }

            cout << "Thread " << tid << " FINISH slice " << slice << endl;
        }
    }
}

// Writes the whole global tensor to stdout: for each slice a "Slice z:"
// header, then n lines each holding n tab-terminated values.
void print_tensor()
{
    for(int z = 0; z < slices; ++z)
    {
        cout << "\nSlice " << z << ":\n";

        for(int r = 0; r < n; ++r)
        {
            // Rows are contiguous, so walk the row with a pointer instead of
            // recomputing the flat index for every column.
            const uint8_t *cell = &tensor[idx(z, r, 0)];
            const uint8_t *const row_end = cell + n;

            for(; cell != row_end; ++cell)
                cout << static_cast<int>(*cell) << "\t";  // print as number, not char

            cout << endl;
        }
    }
}


// --------------------------------
int main()
{
    n = 6;
    slices = 500;
    k = 4;

    srandom(time(NULL));

    // allocate flat row-major tensor
    tensor = new uint8_t[slices * n * n];

    // fill tensor
    for(int z=0; z<slices; z++)
        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                tensor[idx(z, i, j)] = random()%256;

    //cout << "Initial tensor:\n";
    //print_tensor();

    // --------------------------------
    // create thread pool
    // --------------------------------
    boost::thread_group workers;

    boost::chrono::high_resolution_clock::time_point t1 =
        boost::chrono::high_resolution_clock::now();

    for(int tid=0; tid<k; tid++)
        workers.create_thread(boost::bind(sort_slice_rows_chunk, tid));

    // wait for all workers
    workers.join_all();

    boost::chrono::high_resolution_clock::time_point t2 =
        boost::chrono::high_resolution_clock::now();

    boost::chrono::milliseconds elapsed =
        boost::chrono::duration_cast<boost::chrono::milliseconds>(t2 - t1);

    // --------------------------------
    // print results
    // --------------------------------
    cout << "\nSorted tensor:\n";
    print_tensor();

    cout << "\nTotal thread execution time: "
         << elapsed.count() << " ms\n";

    delete[] tensor;

    return 0;
}
