#include "aimd.hpp"
#include "utilization_monitor.hpp"

#include <unistd.h>

#include <algorithm>
#include <memory>

#include <boost/thread.hpp>
#include <boost/chrono.hpp>

namespace {

// Number of consecutive frames a worker claims per trip through the
// scheduler lock. Keeping it small gives the scheduler more decision
// points, so AIMD corrections to the window (halving / incrementing) can
// take effect while the run is still in progress. For comparison,
// CHUNK=8 with F=1000 would yield only ~125 chunks in total.
constexpr int CHUNK = 2;

// Scheduler state shared by all worker threads of one run_aimd_scheduler()
// call. The mutable counters (next_frame, in_flight, contention_window,
// update_calls) are only touched while `mtx` is held; `cv` is signalled
// whenever a chunk completes.
//
// Every member has a default so that a default-constructed state is fully
// determinate: in particular `util` starts as nullptr (update_cwnd checks
// for that) and `contention_window` starts at 1, the smallest window that
// still lets a worker claim a chunk.
struct AimdState {
    int F = 0;                  // total number of frames to process
    int K = 0;                  // number of worker threads; also passed to every task call
    int next_frame = 0;         // first frame that has not been claimed yet
    int in_flight = 0;          // chunks currently being executed by workers
    int contention_window = 1;  // max chunks allowed in flight at once (always >= 1)
    long processors = 1;        // online processor count, used as a growth ceiling
    long update_calls = 0;      // # times update_cwnd has been entered (under mtx)

    boost::mutex mtx;               // guards the mutable counters above
    boost::condition_variable cv;   // workers wait here for window space or end of work

    UtilizationMonitor* util = nullptr;  // non-owning; may be null (treated as "no data")
};

// Per-thread argument bundle copied into each worker. It is filled with a
// positional brace-initializer, so the member order below must not change.
struct AimdArgs {
    int tid;          // worker index; forwarded to the task and to set_thread_affinity
    AimdState* st;    // shared scheduler state (not owned)
    task_fn_t task;   // callback invoked as task(frame, tid, K, ctx) for every frame
    void* ctx;        // opaque pointer handed through to `task` unchanged
    bool affinity;    // when true, the worker calls set_thread_affinity(tid) on start
};

// Adjusts st->contention_window after a chunk has completed.
//
// Must be called with st->mtx held (the worker does so); it mutates
// update_calls and contention_window without further locking.
//
// Policy:
//   * Only every UPDATE_PERIOD-th call reads the utilization EWMA. The
//     original rationale is that the monitor refreshes that value far less
//     often than chunks complete, so reading it on every chunk would add
//     traffic without new information.
//   * On the calls in between, the window is nudged up by one while it is
//     below the growth ceiling and left alone otherwise.
//   * On a probe call with no utilization data (no monitor, or an EWMA of
//     exactly 0.0, which is treated as "no sample yet"), the window is
//     doubled, again bounded by the growth ceiling.
//   * On a probe call with data: above 95% the window is halved (never
//     below 1); otherwise it grows by one, bounded by K.
//
// The growth ceiling is min(processors, K). Bounding by K as well as by the
// processor count matters when fewer workers than processors are used:
// a window larger than K can never be filled, so letting it inflate past K
// would make a subsequent halving a no-op.
//
// NOTE(review): the in-between "+1" nudge also applies right after a
// halving, so a multiplicative decrease can be regained before the next
// probe. Confirm this is the intended recovery rate.
void update_cwnd(AimdState* st) {
    static const int UPDATE_PERIOD = 32;
    const double overload_pct = 95.0;

    // Ceiling for the two probing-free growth paths (nudge and doubling).
    const int growth_cap = std::min(static_cast<int>(st->processors), st->K);

    ++st->update_calls;
    if (st->update_calls % UPDATE_PERIOD != 0) {
        if (st->contention_window < growth_cap)
            ++st->contention_window;
        return;
    }

    const double ewma_raw =
        (st->util ? st->util->getEwmaRawUtilizationLockFree() : 0.0);

    // Exact comparison is intentional: 0.0 is used as the "no data" value.
    if (ewma_raw == 0.0) {
        if (st->contention_window < growth_cap)
            st->contention_window = std::min(growth_cap,
                                             st->contention_window * 2);
        return;
    }

    if (ewma_raw > overload_pct)
        st->contention_window = std::max(1, st->contention_window / 2);
    else
        st->contention_window = std::min(st->contention_window + 1, st->K);
}

// Thread body for one scheduler worker.
//
// Repeatedly claims the next chunk of up to CHUNK frames, runs the task on
// each frame of that chunk, then reports completion so the window can be
// re-evaluated. A chunk may only be claimed while fewer than
// contention_window chunks are in flight; otherwise the worker sleeps on
// the shared condition variable. The worker returns once every frame has
// been claimed.
void worker(AimdArgs a) {
    if (a.affinity)
        set_thread_affinity(a.tid);

    AimdState& state = *a.st;

    for (;;) {
        int first = 0;
        int last = 0;

        {
            boost::unique_lock<boost::mutex> guard(state.mtx);

            // Sleep until either all frames are handed out or the window
            // has room for one more chunk.
            state.cv.wait(guard, [&state] {
                return state.next_frame >= state.F ||
                       state.in_flight < state.contention_window;
            });

            if (state.next_frame >= state.F)
                return;

            first = state.next_frame;
            last = std::min(first + CHUNK, state.F);
            state.next_frame = last;
            ++state.in_flight;
        }

        // Run the claimed frames without holding the lock.
        for (int frame = first; frame < last; ++frame)
            a.task(frame, a.tid, state.K, a.ctx);

        {
            boost::unique_lock<boost::mutex> guard(state.mtx);
            --state.in_flight;
            update_cwnd(&state);
        }
        // Wake every waiter: the window may have grown by more than one
        // slot, and the last chunk's completion must release all sleepers.
        state.cv.notify_all();
    }
}

} // namespace

int run_aimd_scheduler(int F, int K, task_fn_t task, void* ctx, bool affinity,
                       UtilizationMonitor* external_monitor)
{
    if (F <= 0 || K <= 0 || task == nullptr)
        return -1;

    AimdState st;
    st.F = F;
    st.K = K;
    st.next_frame = 0;
    st.in_flight = 0;
    st.update_calls = 0;
    st.processors = sysconf(_SC_NPROCESSORS_ONLN);
    if (st.processors <= 0) st.processors = 1;

    // Start the AIMD window at a sensible bootstrap value instead of 1.
    // With cwnd=1 the scheduler would run almost serially during the
    // first sampling intervals (while EWMA is still zero) because the
    // window only grows when a chunk completes. Starting near processor
    // count lets the scheduler benefit from parallelism immediately;
    // if it really does cause contention, the multiplicative-decrease
    // path will halve it on the next probe.
    st.contention_window = std::min(K, static_cast<int>(st.processors));

    // Use an externally provided, already-running monitor if given.
    // Otherwise, create one locally — but that adds ~1s of start/stop
    // overhead that the caller should be aware of.
    bool owns_monitor = false;
    if (external_monitor != nullptr) {
        st.util = external_monitor;
    } else {
        st.util = new UtilizationMonitor(1, 0.7, 0.8);
        st.util->start();
        owns_monitor = true;
    }

    boost::thread_group workers;
    for (int tid = 0; tid < K; ++tid) {
        AimdArgs a{tid, &st, task, ctx, affinity};
        workers.create_thread(boost::bind(worker, a));
    }
    workers.join_all();

    if (owns_monitor) {
        st.util->stop();
        delete st.util;
    }
    return 0;
}
