#include "aimd_fixed.hpp"
#include "utilization_monitor_fixed.hpp"

#include <algorithm>
#include <unistd.h>

#include <boost/thread.hpp>

namespace {

// Number of consecutive frames a worker claims per scheduling decision.
const int CHUNK = 16;

// The contention window is re-evaluated once every AIMD_UPDATE_PERIOD
// completed chunks, and on every completion once all frames are assigned.
const int AIMD_UPDATE_PERIOD = 16;

// UtilizationMonitorFixed returns a percentage-like utilization value.
// Therefore, compare it with a percentage threshold, not with CPU count.
const double AIMD_TARGET_UTILIZATION_PERCENT = 90.0;

// Scheduler state shared by all workers. The mutable bookkeeping fields are
// read and written only while holding `mtx`; `cv` is used to wake workers
// blocked on a full contention window. Defaults are placeholders: the
// scheduler entry point assigns every field before starting workers.
struct AimdFixedState {
    int F = 0;                  // total number of frames [0, F) to process
    int K = 0;                  // worker count; also the upper bound of the window
    int next_frame = 0;         // first frame not yet handed to a worker
    int in_flight = 0;          // chunks currently being executed
    int contention_window = 1;  // max number of chunks allowed in flight
    long processors = 1;        // online CPU count (filled in by the caller)
    int completed_chunks = 0;   // chunks finished so far

    boost::mutex mtx;
    boost::condition_variable cv;

    // Utilization source; may be null, in which case the window only grows.
    UtilizationMonitorFixed* util = nullptr;
};

// Per-thread arguments, passed by value to worker().
// Kept as a plain aggregate because callers brace-initialize it.
struct AimdFixedArgs {
    int tid;
    AimdFixedState* st;
    task_fn_t task;
    void* ctx;
    bool affinity;
};

// AIMD update of the contention window. Must be called with st->mtx held.
//  - Multiplicative decrease: halve the window (never below 1) when the
//    smoothed utilization exceeds AIMD_TARGET_UTILIZATION_PERCENT.
//  - Additive increase: grow by 1 (capped at K) otherwise. This includes the
//    cases where there is no monitor or it has not reported a positive value
//    yet (ewma <= 0), which can never exceed the target.
void update_cwnd(AimdFixedState* st) {
    const double ewma_raw =
        (st->util != nullptr) ? st->util->getEwmaRawUtilization() : 0.0;

    if (ewma_raw > AIMD_TARGET_UTILIZATION_PERCENT) {
        st->contention_window = std::max(1, st->contention_window / 2);
    } else {
        st->contention_window = std::min(st->contention_window + 1, st->K);
    }
}

// Worker thread body. Optionally applies set_thread_affinity(tid), then
// repeatedly:
//   1. waits until fewer than `contention_window` chunks are in flight,
//   2. claims the next chunk of up to CHUNK frames,
//   3. runs `task` on each frame of the chunk without holding the lock,
//   4. records completion, periodically updates the window, and wakes
//      other workers.
// Returns once every frame has been handed out.
void worker(AimdFixedArgs a) {
    if (a.affinity) {
        set_thread_affinity(a.tid);
    }

    AimdFixedState* const st = a.st;

    while (true) {
        int start = 0;
        int end = 0;
        {
            boost::unique_lock<boost::mutex> lk(st->mtx);

            // Block while work remains but the window is full.
            while (st->next_frame < st->F &&
                   st->in_flight >= st->contention_window) {
                st->cv.wait(lk);
            }

            if (st->next_frame >= st->F) {
                return;
            }

            start = st->next_frame;
            // Written as start + min(CHUNK, remaining) instead of
            // min(start + CHUNK, F) so it cannot overflow a signed int when
            // F is close to INT_MAX (start < F here, so F - start > 0).
            end = start + std::min(CHUNK, st->F - start);

            st->next_frame = end;
            ++st->in_flight;
        }

        // K is never modified after workers start, so reading it unlocked is safe.
        for (int f = start; f < end; ++f) {
            a.task(f, a.tid, st->K, a.ctx);
        }

        bool no_more_work_to_assign = false;
        {
            boost::unique_lock<boost::mutex> lk(st->mtx);

            --st->in_flight;
            ++st->completed_chunks;

            no_more_work_to_assign = (st->next_frame >= st->F);

            // Updating the AIMD window after every chunk is expensive.
            // Update periodically, and always update near the end.
            if (st->completed_chunks % AIMD_UPDATE_PERIOD == 0 ||
                no_more_work_to_assign) {
                update_cwnd(st);
            }
        }

        // If all frames have been assigned, wake everyone so waiting workers
        // can notice that next_frame >= F and exit cleanly.
        // Otherwise, wake one worker to avoid unnecessary wake storms.
        if (no_more_work_to_assign) {
            st->cv.notify_all();
        } else {
            st->cv.notify_one();
        }
    }
}

} // namespace

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           int run_aimd_fixed_scheduler(int F, int K, task_fn_t task, void* ctx, bool affinity)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           {
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               if (F <= 0 || K <= 0 || task == nullptr) {
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       return -1;
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           }

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               AimdFixedState st;

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   st.F = F;
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       st.K = K;
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           st.next_frame = 0;
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               st.in_flight = 0;
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   st.completed_chunks = 0;

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       st.processors = sysconf(_SC_NPROCESSORS_ONLN);
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           if (st.processors <= 0) {
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   st.processors = 1;
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       }

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           // Start fully open for this benchmark.
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               // The old AIMD started at 1, which caused near-serial warm-up.
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   st.contention_window = K;

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       st.util = new UtilizationMonitorFixed(1, 0.7, 0.8);
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           st.util->start();

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               boost::thread_group workers;

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   for (int tid = 0; tid < K; ++tid) {
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           AimdFixedArgs a{tid, &st, task, ctx, affinity};
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   workers.create_thread(boost::bind(worker, a));
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       }

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           workers.join_all();

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               st.util->stop();
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   delete st.util;

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       return 0;
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       }
