#include "aimd.hpp"
#include "utilization_monitor.hpp"

#include <unistd.h>

#include <algorithm>
#include <memory>

#include <boost/chrono.hpp>
#include <boost/thread.hpp>

namespace {

// Number of consecutive frames a worker claims from the shared counter each
// time it acquires the scheduler lock.
constexpr int CHUNK{8};

// Shared scheduler state for a single run_aimd_scheduler() call.
// next_frame, in_flight and contention_window are only read/written while
// holding `mtx`; `cv` is notified after every chunk completes so that workers
// blocked on the contention window can re-check it. F and K are set once
// before any worker starts and are treated as read-only afterwards.
// Every member has a safe default so a default-constructed state is never
// left with indeterminate values.
struct AimdState {
    int F = 0;                  // total number of frames to process: [0, F)
    int K = 0;                  // number of worker threads; window cap when utilization is measured
    int next_frame = 0;         // first frame index not yet handed to a worker
    int in_flight = 0;          // chunks currently being processed by workers
    int contention_window = 1;  // max chunks allowed in flight simultaneously (>= 1)
    long processors = 1;        // online processor count, clamped to >= 1

    boost::mutex mtx;
    boost::condition_variable cv;

    UtilizationMonitor* util = nullptr;  // non-owning; may be null (update_cwnd checks)
};

// Per-thread arguments handed by value to worker().
// run_aimd_scheduler builds these with positional brace-initialization, so
// the member order below is part of the contract — do not reorder, and keep
// this an aggregate.
struct AimdArgs {
    int tid;          // worker index in [0, K); forwarded to task and to set_thread_affinity
    AimdState* st;    // shared state; a local of run_aimd_scheduler that outlives all workers
    task_fn_t task;   // per-frame callback, invoked as task(frame, tid, K, ctx)
    void* ctx;        // opaque caller pointer forwarded unchanged to task
    bool affinity;    // when true, worker calls set_thread_affinity(tid) before doing any work
};

// Adjusts st->contention_window after a chunk completes. Must be called with
// st->mtx held.
//
// - With a non-zero EWMA utilization reading:
//     * reading above the processor count -> halve the window (floor 1);
//     * otherwise                         -> grow by one, capped at K.
// - With a zero reading (or no monitor attached): grow by one, capped at the
//   processor count.
void update_cwnd(AimdState* st) {
    const double load = st->util ? st->util->getEwmaRawUtilization() : 0.0;
    const double cores = static_cast<double>(st->processors);

    if (load != 0.0) {
        // Multiplicative decrease on overload, additive increase otherwise.
        if (load > cores) {
            st->contention_window = std::max(1, st->contention_window / 2);
        } else {
            st->contention_window = std::min(st->contention_window + 1, st->K);
        }
        return;
    }

    // No usable reading yet: probe upward one step, never past the CPU count.
    if (st->contention_window < static_cast<int>(st->processors))
        st->contention_window += 1;
}

// Worker thread body. Repeatedly claims up to CHUNK consecutive frames from
// the shared counter, but only while the number of in-flight chunks is below
// the current contention window. It runs the task on each claimed frame
// outside the lock, then releases its in-flight slot, adapts the window, and
// wakes the other workers. Returns once every frame has been handed out.
void worker(AimdArgs a) {
    if (a.affinity)
        set_thread_affinity(a.tid);

    AimdState& s = *a.st;

    for (;;) {
        int first = 0;
        int last = 0;

        {
            boost::unique_lock<boost::mutex> guard(s.mtx);

            // Block until either all frames are claimed or a window slot frees up.
            s.cv.wait(guard, [&s] {
                return s.next_frame >= s.F || s.in_flight < s.contention_window;
            });

            if (s.next_frame >= s.F)
                return;

            first = s.next_frame;
            last = std::min(first + CHUNK, s.F);
            s.next_frame = last;
            ++s.in_flight;
        }

        // Heavy work happens without holding the lock.
        for (int frame = first; frame < last; ++frame)
            a.task(frame, a.tid, s.K, a.ctx);

        {
            boost::unique_lock<boost::mutex> guard(s.mtx);
            --s.in_flight;
            update_cwnd(&s);
        }
        s.cv.notify_all();
    }
}

} // namespace

int run_aimd_scheduler(int F, int K, task_fn_t task, void* ctx, bool affinity)
{
    if (F <= 0 || K <= 0 || task == nullptr)
        return -1;

    AimdState st;
    st.F = F;
    st.K = K;
    st.next_frame = 0;
    st.in_flight = 0;
    st.contention_window = 1;
    st.processors = sysconf(_SC_NPROCESSORS_ONLN);
    if (st.processors <= 0) st.processors = 1;

    st.util = new UtilizationMonitor(1, 0.7, 0.8);
    st.util->start();

    boost::thread_group workers;
    for (int tid = 0; tid < K; ++tid) {
        AimdArgs a{tid, &st, task, ctx, affinity};
        workers.create_thread(boost::bind(worker, a));
    }
    workers.join_all();

    st.util->stop();
    delete st.util;
    return 0;
}
