/*
Dynamic contention scheduler with:
- maximum worker pool = get_max_threads/processors
- additive increase
- multiplicative decrease only when EWMA raw utilization exceeds processor count

Behavior:
1. Start with contention_window = 1
2. Give a chunk of slices to worker threads
3. If ewma_utilization_raw > processors:
      contention_window = contention_window / 2
4. Otherwise:
      contention_window++
5. Each worker thread exists for the whole program, but only
   contention_window chunks may be "in flight" at once

While getEwmaRawUtilization() == 0 (no utilization samples yet), keep increasing the
contention window additively, but only up to the number of processors.
*/

#include <boost/thread.hpp>
#include <boost/chrono.hpp>
#include <boost/bind/bind.hpp>

#include <iostream>
#include <cstdlib>
#include <ctime>
#include <cstdint>
#include <vector>
#include <algorithm>
#include <limits>
//#include <new>
#include <cmath>
#include "row.hpp"
#include "sort_algorithms.hpp"
#include "utilization_monitor.hpp"

using namespace std;

// Flat 3D tensor of size slices * n * n bytes; each z-slice holds n rows of n values.
uint8_t* tensor = 0;

size_t n;       // row length (and row count) within a slice
size_t slices;  // number of z-slices in the tensor

// mtx guards all scheduler state below (next_slice, in_flight,
// contention_window, total_time_ms, total_slices_done); cv signals
// waiting workers when in_flight drops or work state changes.
boost::mutex mtx;
boost::condition_variable cv;

size_t next_slice = 0;  // next unclaimed slice index (guarded by mtx)
size_t in_flight = 0;   // chunks currently being processed (guarded by mtx)

size_t processors = 1;         // hardware_concurrency(), set in main()
size_t worker_threads = 1;     // size of the worker pool, set in main()
size_t contention_window = 1;  // max chunks allowed in flight (AIMD-controlled)

// Number of slices handed to a worker per chunk claim.
const size_t CHUNK_SLICES = 100;

// Per-thread stats, indexed by tid (writes guarded by mtx).
vector<double> total_time_ms;
vector<size_t> total_slices_done;

// Background utilization sampler; provides the EWMA used by the AIMD control.
UtilizationMonitor* util_monitor = 0;
// --------------------------------
// Body for probe threads: sleeps forever in short intervals until the
// owning code interrupts it.  Both the explicit interruption_point and
// sleep_for are Boost interruption points, so interrupt() wakes it promptly.
void idle_thread_probe()
{
    try
    {
        for(;;)
        {
            boost::this_thread::interruption_point();
            boost::this_thread::sleep_for(boost::chrono::milliseconds(100));
        }
    }
    catch(const boost::thread_interrupted&)
    {
        // Normal shutdown path: swallow the interruption and return.
    }
}


// --------------------------------
size_t get_max_boost_threads_per_process()
{
    std::vector<boost::thread*> threads;

    size_t count = 0;

    try
    {
        while(true)
        {
            boost::thread* th = new boost::thread(idle_thread_probe);
            threads.push_back(th);
            ++count;
        }
    } catch(...)
    {
        std::cerr << "Thread probe stopped by unknown exception" << std::endl;
    }

    for(size_t i = 0; i < threads.size(); i++)
        threads[i]->interrupt();

    for(size_t i = 0; i < threads.size(); i++)
    {
        threads[i]->join();
        delete threads[i];
    }

    return count;
}


// --------------------------------
// Linear offset of element (z, i, j) in the flat slices*n*n tensor,
// written in factored (Horner) form: (z*n + i)*n + j == z*n*n + i*n + j.
inline size_t idx(size_t z, size_t i, size_t j)
{
    return (z * n + i) * n + j;
}

// --------------------------------
/*void update_contention_window()
{
    double ewma_raw = 0.0;
    if(util_monitor != 0)
        ewma_raw = util_monitor->getEWMAUtilizationRaw();

    if(ewma_raw > static_cast<double>(processors))
    {
        contention_window = max(static_cast<size_t>(1), contention_window / 2);
    }
    else
    {
        ++contention_window;
    }
}*/
/*void update_contention_window()
{
    double ewma_raw = 0.0;
    if(util_monitor != 0)
        ewma_raw = util_monitor->getEwmaRawUtilization();

    if(ewma_raw > (double)(processors))
    {
        contention_window = max(static_cast<size_t>(1), contention_window / 2);
    }
    else
    {
        contention_window = min(contention_window + static_cast<size_t>(1), worker_threads);
    }
}*/
// AIMD control of contention_window, driven by the raw EWMA utilization.
// Caller must hold mtx (all state touched here is mutex-guarded).
//  - EWMA == 0 (no samples yet): ramp additively, capped at processor count.
//  - EWMA above processor count: multiplicative decrease (halve, floor 1).
//  - Otherwise: additive increase, capped at the worker pool size.
void update_contention_window()
{
    const double ewma_raw =
        (util_monitor != 0) ? util_monitor->getEwmaRawUtilization() : 0.0;

    if(ewma_raw == 0.0)
    {
        // Warm-up phase: monitor has produced no signal yet.
        if(contention_window < processors)
            ++contention_window;
    }
    else if(ewma_raw > static_cast<double>(processors))
    {
        // Overshoot: back off multiplicatively, never below 1.
        contention_window = (contention_window > 1) ? contention_window / 2
                                                    : static_cast<size_t>(1);
    }
    else
    {
        // Headroom available: grow by one, never past the pool size.
        if(contention_window < worker_threads)
            ++contention_window;
    }
}

// --------------------------------
void sort_rows_in_slice(size_t z)
{
    for(size_t i = 0; i < n; i++)
    {
        uint8_t* row_ptr = &tensor[idx(z, i, 0)];
        row<uint8_t> r(row_ptr, n);
        quick_sort(r);

        for(size_t j = 0; j < n; j++)
            row_ptr[j] = r[j];
    }
}

// --------------------------------
// Worker thread body. Repeatedly claims a chunk of up to CHUNK_SLICES
// slices under the mutex (subject to the contention window), sorts the
// chunk outside the lock, then records per-thread stats and runs the
// AIMD window update. Returns when all slices have been claimed.
void sort_rows_dynamic_contention(size_t tid)
{
    while(true)
    {
        size_t start_z = 0;
        size_t end_z = 0;

        {
            boost::unique_lock<boost::mutex> lock(mtx);

            // Block while work remains but the window is full. Woken by
            // the notify_all issued after each chunk completes.
            while(next_slice < slices && in_flight >= contention_window)
                cv.wait(lock);

            // All slices claimed (possibly by other threads): we're done.
            if(next_slice >= slices)
                return;

            // Claim [start_z, end_z) — the last chunk may be short.
            start_z = next_slice;
            end_z = min(next_slice + CHUNK_SLICES, slices);
            next_slice = end_z;

            in_flight++;

            // Sampled here only for the progress log line below.
            double ewma_raw = 0.0;
            if(util_monitor != 0)
                ewma_raw = util_monitor->getEwmaRawUtilization();

            cout << "Thread " << tid
                 << " gets slice chunk [" << start_z << ", " << (end_z - 1) << "]"
                 << " | cwnd=" << contention_window
                 << " | in_flight=" << in_flight
                 << " | ewma_raw=" << ewma_raw
                 << endl;
        }

        // Do the actual sorting work with the mutex released.
        boost::chrono::high_resolution_clock::time_point c1 =
            boost::chrono::high_resolution_clock::now();

        for(size_t z = start_z; z < end_z; z++)
            sort_rows_in_slice(z);

        boost::chrono::high_resolution_clock::time_point c2 =
            boost::chrono::high_resolution_clock::now();

        boost::chrono::duration<double, boost::milli> elapsed = c2 - c1;

        {
            boost::unique_lock<boost::mutex> lock(mtx);

            total_time_ms[tid] += elapsed.count();
            total_slices_done[tid] += (end_z - start_z);

            in_flight--;

            // Adjust the window while still holding the lock, so waiters
            // observe a consistent (in_flight, contention_window) pair.
            update_contention_window();
        }

        // Wake every waiter: the window may have grown, so more than one
        // thread might now be allowed to claim a chunk.
        cv.notify_all();
    }
}

// --------------------------------
// Entry point: sizes and fills the tensor, probes the thread limit,
// launches the worker pool, and reports per-thread and overall timing.
int main()
{
    n = 10;
    slices = 500000;

    srandom(time(NULL));

    processors = boost::thread::hardware_concurrency();
    size_t max_threads=get_max_boost_threads_per_process();
    // Worker pool sized as max_threads / processors (see alternatives below).
    //worker_threads = max((size_t)(1), (size_t)(processors*processors));
    cout << "Max threads="<< max_threads << endl;
    //worker_threads = max((size_t)(1), (size_t)(2*processors));
    worker_threads = max((size_t)(1), (size_t)(max_threads)/processors);
    //SKOsize_t slice_threads=(slices/(int) CHUNK_SLICES) +1;
    //worker_threads = (size_t) (max_threads-processors);
    //worker_threads=min((std::size_t) slice_threads, (std::size_t) max_threads-(std::size_t) processors);
    contention_window = 1;

    // Guard the n*n and slices*n*n size computations against overflow
    // before allocating.
    const size_t elements_per_slice = n * n;
    if(n != 0 && elements_per_slice / n != n)
    {
        cerr << "Overflow computing elements_per_slice\n";
        return 1;
    }

    if(slices != 0 && elements_per_slice > numeric_limits<size_t>::max() / slices)
    {
        cerr << "Overflow computing total tensor size\n";
        return 1;
    }

    const size_t total_elements = slices * elements_per_slice;

    try
    {
        tensor = new uint8_t[total_elements];
    }
    catch(const bad_alloc&)
    {
        cerr << "Failed to allocate tensor of " << total_elements
             << " bytes (" << (static_cast<double>(total_elements) / (1024.0 * 1024.0 * 1024.0))
             << " GiB)\n";
        return 1;
    }

    // Fill the tensor with uniformly random bytes.
    for(size_t z = 0; z < slices; z++)
        for(size_t i = 0; i < n; i++)
            for(size_t j = 0; j < n; j++)
                tensor[idx(z, i, j)] = static_cast<uint8_t>(random() % 256);

    total_time_ms.resize(worker_threads, 0.0);
    total_slices_done.resize(worker_threads, 0);

    // NOTE(review): constructor arguments (1, 0.7, 0.8) presumably are the
    // sampling period and EWMA coefficients — confirm against
    // utilization_monitor.hpp.
    util_monitor = new UtilizationMonitor(1, 0.7, 0.8);
    util_monitor->start();

    cout << "Processors: " << processors << endl;
    cout << "Worker threads: " << worker_threads << endl;
    cout << "Initial contention_window: " << contention_window << endl;
    cout << "Chunk size (slices): " << CHUNK_SLICES << endl;
    cout << "Tensor bytes: " << total_elements << endl;

    boost::thread_group workers;

    // Time the full parallel sort from first spawn to last join.
    boost::chrono::high_resolution_clock::time_point t1 =
        boost::chrono::high_resolution_clock::now();

    for(size_t tid = 0; tid < worker_threads; tid++)
        workers.create_thread(boost::bind(sort_rows_dynamic_contention, tid));

    workers.join_all();

    boost::chrono::high_resolution_clock::time_point t2 =
        boost::chrono::high_resolution_clock::now();

    boost::chrono::milliseconds elapsed =
        boost::chrono::duration_cast<boost::chrono::milliseconds>(t2 - t1);

    if(util_monitor != 0)
        util_monitor->stop();

    cout << "\nPer-thread timing stats:\n";
    for(size_t tid = 0; tid < worker_threads; tid++)
    {
        cout << "Thread " << tid
             << ": slices = " << total_slices_done[tid]
             << ", total_time = " << total_time_ms[tid] << " ms";

        if(total_slices_done[tid] > 0)
        {
            cout << ", avg_per_slice = "
                 << (total_time_ms[tid] / static_cast<double>(total_slices_done[tid])) << " ms";
        }

        cout << endl;
    }

    cout << "\nFinal scheduler state:\n";
    cout << "contention_window = " << contention_window << endl;

    if(util_monitor != 0)
    {
        cout << "final_ewma_utilization_raw = "
             << util_monitor->getEwmaRawUtilization() << endl;
    }

    cout << "\nTotal execution time: "
         << elapsed.count() << " ms\n";

    delete util_monitor;
    util_monitor = 0;

    delete[] tensor;
    tensor = 0;

    return 0;
}
