/*
Adaptive chunk scheduler:

This scheduler is similar to guided scheduling, but the chunk size is not
based only on the remaining work. It also adapts to the observed speed of
each thread.

How it starts:

1. At the beginning, no thread has timing history yet.
2. So every thread starts with:
       speed_factor = 1.0
3. The first chunk is therefore based only on the remaining work:
       base_chunk = remaining / k
4. This means the scheduler begins like a guided scheduler.

How it adapts and changes chunk size:

1. After a thread finishes a chunk, its execution time is measured using
   boost::chrono::high_resolution_clock.
2. The scheduler stores for each thread:
       total_time_ms[tid]
       total_slices_done[tid]
3. From these, it computes the average time per slice for that thread:
       my_avg = total_time_ms[tid] / total_slices_done[tid]
4. It also computes the average slice time across all threads with history:
       global_avg
5. Then it computes:
       speed_factor = global_avg / my_avg

Interpretation:

- if a thread is faster than average, then my_avg is smaller,
  so speed_factor becomes greater than 1
  -> that thread receives a larger chunk

- if a thread is slower than average, then my_avg is larger,
  so speed_factor becomes less than 1
  -> that thread receives a smaller chunk

So the chunk size is:

    chunk = base_chunk * speed_factor

How chunk size evolves over time:

- initially, chunks are moderate and guided by remaining/k
- after timing history is collected, faster threads may receive larger chunks
- slower threads may receive smaller chunks
- near the end, remaining work becomes small, so chunk sizes shrink again
- chunk is always clamped so that:
      chunk >= 1
      chunk <= remaining

So this scheduler:
- starts like guided scheduling
- increases chunk size for faster threads
- decreases chunk size for slower threads
- naturally shrinks chunks again near the end as work runs out

This keeps scheduling overhead lower than fully dynamic scheduling,
while adapting better than plain guided scheduling when thread speeds differ.

Each slice is processed row by row:
- get pointer to the row in row-major memory
- wrap it in row<uint8_t>
- apply quick_sort on that row
*/

#include <boost/thread.hpp>
#include <boost/chrono.hpp>
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <cstdint>
#include <vector>

#include "row.hpp"
#include "sort_algorithms.hpp"

using namespace std;

// Flattened 3-D tensor: `slices` slabs of n x n uint8_t values,
// stored row-major (see idx()). Allocated and freed in main().
uint8_t *tensor;

int n;      // edge length of each square slice
int slices; // number of n x n slices in the tensor
int k;      // number of worker threads

// Index of the next unclaimed slice; guarded by mtx.
int next_slice = 0;

// Protects next_slice and the per-thread statistics below.
boost::mutex mtx;

// per-thread adaptive statistics (indexed by thread id, guarded by mtx)
vector<double> total_time_ms;   // accumulated measured work time per thread (ms)
vector<int> total_slices_done;  // number of slices each thread has completed

// --------------------------------
// Map (slice z, row i, column j) to the flat offset of that element
// in the row-major tensor buffer.
inline int idx(int z, int i, int j)
{
    // Horner form of z*n*n + i*n + j.
    return (z * n + i) * n + j;
}

// --------------------------------
// Dump the whole tensor to stdout, one tab-separated n x n grid per slice.
void print_tensor()
{
    for(int s = 0; s < slices; s++)
    {
        cout << "\nSlice " << s << ":\n";
        for(int r = 0; r < n; r++)
        {
            for(int c = 0; c < n; c++)
            {
                cout << (int)tensor[idx(s, r, c)] << "\t";
            }
            cout << endl;
        }
    }
}

// --------------------------------
void sort_rows_adaptive(int tid)
{
    while(true)
    {
        int start, end, chunk;

        mtx.lock();

        if(next_slice >= slices)
        {
            mtx.unlock();
            return;
        }

        int remaining = slices - next_slice;

        int base_chunk = remaining / k;
        if(base_chunk < 1)
            base_chunk = 1;

        // default factor = 1.0 until enough history exists
        double speed_factor = 1.0;

        if(total_slices_done[tid] > 0)
        {
            double my_avg = total_time_ms[tid] / total_slices_done[tid];

            double global_avg = 0.0;
            int active_threads = 0;

            for(int t = 0; t < k; t++)
            {
                if(total_slices_done[t] > 0)
                {
                    global_avg += total_time_ms[t] / total_slices_done[t];
                    active_threads++;
                }
            }

            if(active_threads > 0)
            {
                global_avg /= active_threads;

                if(my_avg > 0.0)
                    speed_factor = global_avg / my_avg;
            }
        }

        chunk = (int)(base_chunk * speed_factor);

        if(chunk < 1)
            chunk = 1;
        if(chunk > remaining)
            chunk = remaining;

        start = next_slice;
        next_slice += chunk;
        end = next_slice;

        mtx.unlock();

        boost::chrono::high_resolution_clock::time_point c1 =
            boost::chrono::high_resolution_clock::now();

        for(int z = start; z < end; z++)
        {
            cout << "Thread " << tid
                 << " ADAPTIVE slice " << z
                 << " chunk=" << chunk << endl;

            for(int i = 0; i < n; i++)
            {
                uint8_t* row_ptr = &tensor[idx(z, i, 0)];
                row<uint8_t> r(row_ptr, n);
                quick_sort(r);
		  for(int j = 0; j < n; j++)
               		row_ptr[j] = r[j];
            }
        }

        boost::chrono::high_resolution_clock::time_point c2 =
            boost::chrono::high_resolution_clock::now();

        boost::chrono::duration<double, boost::milli> elapsed =
            c2 - c1;

        mtx.lock();
        total_time_ms[tid] += elapsed.count();
        total_slices_done[tid] += (end - start);
        mtx.unlock();
    }
}

// --------------------------------
int main()
{
    n = 6;
    slices = 500;
    k = 4;

    srandom(time(NULL));

    tensor = new uint8_t[slices * n * n];

    for(int z = 0; z < slices; z++)
        for(int i = 0; i < n; i++)
            for(int j = 0; j < n; j++)
                tensor[idx(z, i, j)] = random() % 256;

    total_time_ms.resize(k, 0.0);
    total_slices_done.resize(k, 0);

    //cout << "Initial tensor:\n";
    //print_tensor();

    boost::thread_group workers;

    boost::chrono::high_resolution_clock::time_point t1 =
        boost::chrono::high_resolution_clock::now();

    for(int tid = 0; tid < k; tid++)
        workers.create_thread(boost::bind(sort_rows_adaptive, tid));

    workers.join_all();

    boost::chrono::high_resolution_clock::time_point t2 =
        boost::chrono::high_resolution_clock::now();

    boost::chrono::milliseconds elapsed =
        boost::chrono::duration_cast<boost::chrono::milliseconds>(t2 - t1);

    cout << "\nSorted tensor:\n";
    print_tensor();

    cout << "\nPer-thread timing stats:\n";
    for(int tid = 0; tid < k; tid++)
    {
        cout << "Thread " << tid
             << ": slices = " << total_slices_done[tid]
             << ", total_time = " << total_time_ms[tid] << " ms";

        if(total_slices_done[tid] > 0)
            cout << ", avg_per_slice = "
                 << (total_time_ms[tid] / total_slices_done[tid]) << " ms";

        cout << endl;
    }

    cout << "\nTotal execution time: "
         << elapsed.count() << " ms\n";

    delete[] tensor;

    return 0;
}
