#include <boost/thread.hpp>
#include <boost/chrono.hpp>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <mutex>
#include "processor.hpp"
#include "row.hpp"
#include "sort_algorithms.hpp"
using namespace std;

// Flat row-major tensor laid out as [slices][n][n].
// Allocated with new[] and released with delete[] in main().
int* tensor = nullptr;

// Tensor geometry and worker-thread count.
// slices and k are assigned in main(); NOTE(review): n is not assigned
// anywhere in this file, so it keeps its zero initial value.
int n = 0;
int slices = 0;
int k = 0;

// 1. Video processor and I/O configuration
Processor processor;
// Disabled CNN configuration, kept for reference:
//   params.stride=2;
//   params.mode="valid";
//   CNNParams params = {1, "valid"};
std::string inputPath = "input.mp4";
std::string outputPath = "output.mp4";
VideoMemoryRaw* vm = nullptr;

// Flat offset of element (z, i, j) in the row-major [slices][n][n] tensor.
// The arithmetic is done in std::ptrdiff_t because z*n*n overflows int for
// realistic video sizes (e.g. 1080x1080 frames overflow after ~1841 slices).
std::ptrdiff_t idx(int z, int i, int j)
{
    const std::ptrdiff_t side = n;
    return static_cast<std::ptrdiff_t>(z) * side * side
         + static_cast<std::ptrdiff_t>(i) * side
         + j;
}

// --------------------------------------------------
// Static slice sharing:
// each thread gets slices/k contiguous slices;
// the last thread (tid == k-1) also handles the remainder.
// NOTE: when k > slices, slices/k == 0, so every thread except the last is
// idle and the last thread sorts all slices (see sort_slices_static_b for a
// balanced split).
// --------------------------------------------------
void sort_slices_static(int tid)
{
    // Avoid division by zero and never touch slices outside this thread's range.
    if (k <= 0 || tid < 0 || tid >= k)
        return;

    const int base = slices / k;
    const int rem  = slices % k;

    const int start_slice = tid * base;
    int end_slice = start_slice + base;

    // last thread takes the remainder too
    if (tid == k - 1)
        end_slice += rem;

    // Shared by all worker threads: without it, progress lines from
    // concurrent threads interleave on cout.
    static std::mutex log_mutex;

    for (int z = start_slice; z < end_slice; z++)
    {
        {
            std::lock_guard<std::mutex> lock(log_mutex);
            cout << "Thread " << tid << " handles slice " << z << '\n';
        }
        //processor.processVideoFromMemoryIndex(vm, z);
        for (int i = 0; i < n; i++)
        {
            int* row_ptr = &tensor[idx(z, i, 0)];
            row<int> r(row_ptr, n);
            quick_sort(r);
            // Write the sorted values back into the tensor; this is required
            // if row<int> sorts a private copy rather than a view of row_ptr.
            for (int j = 0; j < n; j++)
                row_ptr[j] = r[j];
        }
    }
}
// --------------------------------------------------
// Static slicing, balanced:
// each thread gets a contiguous block of slices. The first (slices % k)
// threads take one extra slice each, so block sizes differ by at most one
// and the excess load is spread over the first threads.
// Also works when k > slices: the surplus threads get an empty range.
// --------------------------------------------------
void sort_slices_static_b(int tid)
{
    // Avoid division by zero and never touch slices outside this thread's range.
    if (k <= 0 || tid < 0 || tid >= k)
        return;

    const int slices_per_thread = slices / k;
    const int extra = slices % k;
    const bool takes_extra = tid < extra;

    // Threads [0, extra) own (slices_per_thread + 1) slices; the rest own
    // slices_per_thread, starting right after the enlarged blocks.
    const int start_slice = takes_extra
        ? tid * (slices_per_thread + 1)
        : extra * (slices_per_thread + 1) + (tid - extra) * slices_per_thread;
    const int end_slice = start_slice + slices_per_thread + (takes_extra ? 1 : 0);

    for (int z = start_slice; z < end_slice; z++)
    {
        for (int i = 0; i < n; i++)
        {
            int* row_ptr = &tensor[idx(z, i, 0)];
            row<int> r(row_ptr, n);
            quick_sort(r);
            // Write back as sort_slices_static does; without this the result is
            // lost if row<int> sorts a private copy rather than a view.
            for (int j = 0; j < n; j++)
                row_ptr[j] = r[j];
        }
    }
}

void print_tensor()
{
    for (int z = 0; z < slices; z++)
    {
        cout << "\nSlice " << z << ":\n";
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
                cout << tensor[idx(z, i, j)] << "\t";
            cout << endl;
        }
    }
}

int main()
{
    //n = 6;
    //slices = 300;
    k = 719;
    srand((unsigned)time(nullptr));
    //2.Load video to memory
    vm = processor.loadVideoToMemoryRowMajor(inputPath);
    slices=vm->frameCount;
    tensor = new int[slices * n * n];

    /*for (int z = 0; z < slices; z++)
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                tensor[idx(z, i, j)] = rand() % 100;
     */
    //cout << "Initial tensor:\n";
    //print_tensor();

    boost::thread_group workers;
     // -----------------------------
    // start total execution timer
    // -----------------------------
boost::chrono::high_resolution_clock::time_point t1 =
    boost::chrono::high_resolution_clock::now();


    for (int tid = 0; tid < k; tid++)
        workers.create_thread(boost::bind(sort_slices_static, tid));

    workers.join_all();
     // -----------------------------
    // stop total execution timer
    // -----------------------------

boost::chrono::high_resolution_clock::time_point t2 =
    boost::chrono::high_resolution_clock::now();

boost::chrono::milliseconds elapsed =
    boost::chrono::duration_cast<boost::chrono::milliseconds>(t2 - t1);

    //cout << "\nSorted tensor:\n";
    //print_tensor();
    processor.saveVideoFromMemory(vm, outputPath);
    processor.freeVideoMemory(vm);
    cout << "\nTotal thread execution time: "
       << elapsed.count() << " ms\n";
    delete[] tensor;
    return 0;
}
