#include "adaptive_scheduler.hpp"

#include <vector>

#include <boost/thread.hpp>
#include <boost/chrono.hpp>

namespace {

/// State shared by every worker thread of a single scheduler run.
///
/// The scalar members carry in-class initializers so that a default-constructed
/// instance never holds indeterminate values, even if a caller forgets to
/// assign one of them before starting the workers.
struct AdaptState {
    int F = 0;           ///< Total number of frames; frames are indexed [0, F).
    int K = 0;           ///< Number of worker threads; also the expected length of the per-thread vectors.
    int next_frame = 0;  ///< Index of the first frame that has not been handed out yet.
    boost::mutex mtx;    ///< Guards next_frame and both per-thread vectors below.
    std::vector<double> total_time_ms;     ///< Per-thread accumulated processing time, in milliseconds.
    std::vector<int>    total_tasks_done;  ///< Per-thread count of frames processed so far.
};

// Per-thread argument bundle, passed by value to worker().
// NOTE: this struct is brace-initialized positionally by the scheduler entry
// point, so the member order below is part of its contract.
struct AdaptArgs {
    int tid;          // index of this worker; used to index the per-thread vectors in AdaptState
    AdaptState* st;   // shared scheduler state; not owned, must outlive the worker
    task_fn_t task;   // callback invoked as task(frame, tid, K, ctx) for each frame
    void* ctx;        // opaque pointer forwarded unchanged to task
    bool affinity;    // when true, the worker calls set_thread_affinity(tid) before starting
};

/// Converts a speed-scaled chunk estimate into a frame count in [1, remaining].
///
/// The clamp is applied in floating point, *before* the conversion to int:
/// when one thread is much faster than the average, the scaled estimate can
/// exceed INT_MAX, and converting an out-of-range double to int is undefined
/// behaviour. A NaN estimate falls back to the minimum chunk of 1.
///
/// @param scaled     base chunk multiplied by the thread's speed factor
/// @param remaining  number of frames still unassigned (expected >= 1)
/// @return chunk size, truncated toward zero and clamped to [1, remaining]
int clamp_chunk(double scaled, int remaining) {
    if (!(scaled >= 1.0))  // written this way so NaN also takes this branch
        return 1;
    if (scaled >= static_cast<double>(remaining))
        return remaining;
    return static_cast<int>(scaled);
}

/// Worker thread body: repeatedly claims a contiguous chunk of frames from the
/// shared state and runs the task on each frame until no frames are left.
///
/// The chunk size starts from remaining / K and is scaled by how this thread's
/// average time per frame compares with the mean of the per-thread averages
/// (threads that have not finished any frame yet are excluded from the mean).
/// A faster-than-average thread therefore claims more frames per round.
///
/// @param a  per-thread arguments; a.st must stay valid until this returns
void worker(AdaptArgs a) {
    if (a.affinity)
        set_thread_affinity(a.tid);

    while (true) {
        int start, end;

        {
            // Chunk selection reads and updates shared state, so it all
            // happens under the mutex.
            boost::mutex::scoped_lock lk(a.st->mtx);
            if (a.st->next_frame >= a.st->F)
                return;

            const int remaining = a.st->F - a.st->next_frame;
            int base_chunk = remaining / a.st->K;
            if (base_chunk < 1) base_chunk = 1;

            // Stays at 1.0 until this thread has recorded a non-zero time.
            double speed_factor = 1.0;
            if (a.st->total_tasks_done[a.tid] > 0) {
                const double my_avg = a.st->total_time_ms[a.tid] /
                                      a.st->total_tasks_done[a.tid];
                double global_avg = 0.0;
                int active = 0;
                for (int t = 0; t < a.st->K; ++t) {
                    if (a.st->total_tasks_done[t] > 0) {
                        global_avg += a.st->total_time_ms[t] /
                                      a.st->total_tasks_done[t];
                        ++active;
                    }
                }
                if (active > 0) {
                    global_avg /= active;
                    // Guard against division by zero when the measured time
                    // was below the clock's resolution.
                    if (my_avg > 0.0)
                        speed_factor = global_avg / my_avg;
                }
            }

            const int chunk = clamp_chunk(base_chunk * speed_factor, remaining);

            start = a.st->next_frame;
            a.st->next_frame += chunk;
            end = a.st->next_frame;
        }

        // Time only the task execution, outside the lock.
        boost::chrono::high_resolution_clock::time_point c1 =
            boost::chrono::high_resolution_clock::now();

        for (int f = start; f < end; ++f)
            a.task(f, a.tid, a.st->K, a.ctx);

        boost::chrono::high_resolution_clock::time_point c2 =
            boost::chrono::high_resolution_clock::now();
        boost::chrono::duration<double, boost::milli> elapsed = c2 - c1;

        {
            // Publish this round's statistics for future chunk sizing.
            boost::mutex::scoped_lock lk(a.st->mtx);
            a.st->total_time_ms[a.tid] += elapsed.count();
            a.st->total_tasks_done[a.tid] += (end - start);
        }
    }
}

} // namespace

int run_adaptive_scheduler(int F, int K, task_fn_t task, void* ctx, bool affinity)
{
    if (F <= 0 || K <= 0 || task == nullptr)
        return -1;

    AdaptState st;
    st.F = F;
    st.K = K;
    st.next_frame = 0;
    st.total_time_ms.assign(K, 0.0);
    st.total_tasks_done.assign(K, 0);

    boost::thread_group workers;
    for (int tid = 0; tid < K; ++tid) {
        AdaptArgs a{tid, &st, task, ctx, affinity};
        workers.create_thread(boost::bind(worker, a));
    }
    workers.join_all();
    return 0;
}
