#include <boost/thread.hpp>
#include <boost/chrono.hpp>
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <cstdint>

#include "row.hpp"
#include "sort_algorithms.hpp"

using namespace std;

// flat row-major tensor: [slices][n][n]
// Allocated and released in main(); addressed through idx().
uint8_t* tensor;

// Problem dimensions, assigned in main() before any worker thread starts.
int n;       // frame size: every slice is n x n
int slices;  // number of slices in the tensor
int k;       // number of worker threads

// Index of the next slice that has not been handed to a worker yet.
// Read and incremented by the workers while holding mtx.
int next_slice = 0;
// Disabled (polling variant only): int finished_slices = 0;

// Protects next_slice.
boost::mutex mtx;

// Maps tensor coordinates to an offset into the flat row-major buffer.
// z selects the slice, i the row inside it and j the column; a slice spans
// n * n elements and a row spans n elements (n is the file-level frame size).
inline int idx(int z, int i, int j)
{
    // Nested form of z*n*n + i*n + j
    const int row_base = (z * n + i) * n;
    return row_base + j;
}

    // --------------------------------
    void sort_slice_rows_dynamic(int tid)
    {
        while(true)
            {
                    int slice;

                            // critical section: get next slice
                                    mtx.lock();

                                            if(next_slice >= slices)
                                                    {
                                                                mtx.unlock();
                                                                            return;
                                                                                    }

                                                                                            slice = next_slice++;

                                                                                                    mtx.unlock();


                                                                                                            cout << "Thread " << tid << " START slice " << slice << endl;

                                                                                                                    for(int i=0; i<n; i++)
                                                                                                                            {
                                                                                                                                        uint8_t* row_ptr = &tensor[idx(slice, i, 0)];
                                                                                                                                                    row<uint8_t> r(row_ptr, n);
                                                                                                                                                                quick_sort(r);
                                                                                                                                                                	     for(int j = 0; j < n; j++)
                                                                                                                                                                	         		row_ptr[j] = r[j];
                                                                                                                                                                	         		        }

                                                                                                                                                                	         		                cout << "Thread " << tid << " FINISH slice " << slice << endl;

                                                                                                                                                                	         		                        // update finished counter
                                                                                                                                                                	         		                                /*mtx.lock();
                                                                                                                                                                	         		                                		    finished_slices++;
                                                                                                                                                                	         		                                		            mtx.unlock();*/
                                                                                                                                                                	         		                                		            		    }
                                                                                                                                                                	         		                                		            		    }

                                                                                                                                                                	         		                                		            		    void print_tensor()
                                                                                                                                                                	         		                                		            		    {
                                                                                                                                                                	         		                                		            		        for(int z=0; z<slices; z++)
                                                                                                                                                                	         		                                		            		            {
                                                                                                                                                                	         		                                		            		                    cout << "\nSlice " << z << ":\n";
                                                                                                                                                                	         		                                		            		                            for(int i=0; i<n; i++)
                                                                                                                                                                	         		                                		            		                                    {
                                                                                                                                                                	         		                                		            		                                                for(int j=0; j<n; j++)
                                                                                                                                                                	         		                                		            		                                                                cout << (int)tensor[idx(z, i, j)] << "\t";
                                                                                                                                                                	         		                                		            		                                                                            cout << endl;
                                                                                                                                                                	         		                                		            		                                                                                    }
                                                                                                                                                                	         		                                		            		                                                                                        }
                                                                                                                                                                	         		                                		            		                                                                                        }


                                                                                                                                                                	         		                                		            		                                                                                        // --------------------------------
                                                                                                                                                                	         		                                		            		                                                                                        int main()
                                                                                                                                                                	         		                                		            		                                                                                        {
                                                                                                                                                                	         		                                		            		                                                                                            n = 6;        // frame size (n x n)
                                                                                                                                                                	         		                                		            		                                                                                                slices = 500;  // number of tensor slices
                                                                                                                                                                	         		                                		            		                                                                                                    k = 4;        // worker threads

                                                                                                                                                                	         		                                		            		                                                                                                        srandom(time(NULL));

                                                                                                                                                                	         		                                		            		                                                                                                            // allocate flat row-major tensor
                                                                                                                                                                	         		                                		            		                                                                                                                tensor = new uint8_t[slices * n * n];


                                                                                                                                                                	         		                                		            		                                                                                                                    // fill tensor with grayscale values
                                                                                                                                                                	         		                                		            		                                                                                                                        for(int z=0; z<slices; z++)
                                                                                                                                                                	         		                                		            		                                                                                                                                for(int i=0; i<n; i++)
                                                                                                                                                                	         		                                		            		                                                                                                                                            for(int j=0; j<n; j++)
                                                                                                                                                                	         		                                		            		                                                                                                                                                            tensor[idx(z, i, j)] = random()%256;

                                                                                                                                                                	         		                                		            		                                                                                                                                                                //cout << "Initial tensor:\n";
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    //print_tensor();

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				    // --------------------------------
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				        // create thread pool
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				            // --------------------------------
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                boost::thread_group workers;

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                    boost::chrono::high_resolution_clock::time_point t1 =
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                            boost::chrono::high_resolution_clock::now();

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                for(int tid=0; tid<k; tid++)
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                        workers.create_thread(boost::bind(sort_slice_rows_dynamic, tid));


                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                            // --------------------------------
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                // poll until all slices processed
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                    // --------------------------------
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                        /*while(true)
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                            {
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                    mtx.lock();
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                            if(finished_slices >= slices)
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                    {
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                mtx.unlock();
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                            break;
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                    }
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                            mtx.unlock();

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                    boost::this_thread::sleep_for(boost::chrono::milliseconds(1));
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                        }*/


                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                            // join all threads
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                workers.join_all();

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                    boost::chrono::high_resolution_clock::time_point t2 =
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                            boost::chrono::high_resolution_clock::now();

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                boost::chrono::milliseconds elapsed =
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                        boost::chrono::duration_cast<boost::chrono::milliseconds>(t2 - t1);


                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                            // --------------------------------
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                // print results
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                    // --------------------------------
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                        cout << "\nSorted tensor:\n";
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                            print_tensor();

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                                cout << "\nTotal thread execution time: "
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                                         << elapsed.count() << " ms\n";

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                                             delete[] tensor;

                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                                                 return 0;
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                                                 }
                                                                                                                                                                	         		                                		            		                                                                                                                                                                    				                                                                                                                                                                                                                 
