#include "chunk_steal_scheduler.hpp"

#include <algorithm>
#include <deque>
#include <vector>

#include <boost/thread.hpp>

namespace {

// Number of consecutive task indices grouped into one schedulable chunk.
// Queue entries are chunk start indices, which are multiples of this value.
constexpr int CHUNK = 8;

struct StealState {
    int F;
    int K;
    std::vector<std::deque<int>> queues;
    std::vector<boost::mutex*>   qmutex;
};

// Per-thread argument bundle, copied by value into each worker.
// NOTE: the scheduler fills this with positional brace-initialization, so
// the member order below must not change.
struct StealArgs {
    int tid;         // worker index: selects this worker's own queue and is passed to `task`
    StealState* st;  // shared scheduler state (not owned)
    task_fn_t task;  // callback invoked as task(index, tid, K, ctx) for each index
    void* ctx;       // opaque pointer forwarded unchanged to `task`
    bool affinity;   // when true, the worker calls set_thread_affinity(tid) before running
};

// Worker thread body.
//
// Repeatedly obtains a chunk start index and runs `a.task` on every index in
// [start, min(start + CHUNK, F)). Chunks are taken from the front of this
// worker's own queue first; once that queue is empty the worker tries to
// steal from the back of the other workers' queues. The worker returns only
// when its own queue is empty and every other queue was observed empty under
// its lock.
//
// a: per-thread arguments (worker index, shared state, task callback and
//    context, affinity flag). Taken by value so each thread has its own copy.
void worker(StealArgs a) {
    if (a.affinity)
        set_thread_affinity(a.tid);

    StealState& st = *a.st;

    while (true) {
        int start = -1;

        // Fast path: pop from the front of our own queue.
        {
            boost::mutex::scoped_lock lk(*st.qmutex[a.tid]);
            std::deque<int>& own = st.queues[a.tid];
            if (!own.empty()) {
                start = own.front();
                own.pop_front();
            }
        }

        // Slow path: steal from the back of another worker's queue.
        if (start == -1) {
            // A failed try_lock only means the victim's queue was busy, not
            // that it was empty. Keep scanning while any victim was merely
            // contended, so this worker does not retire with work remaining.
            // Queues are never refilled, so the loop ends once every victim
            // has been seen empty.
            bool contended = true;
            while (start == -1 && contended) {
                contended = false;
                for (int v = 0; v < st.K; ++v) {
                    if (v == a.tid) continue;

                    // RAII try-lock: released automatically on every path.
                    boost::unique_lock<boost::mutex> lk(*st.qmutex[v], boost::try_to_lock);
                    if (!lk.owns_lock()) {
                        contended = true;
                        continue;
                    }

                    std::deque<int>& victim = st.queues[v];
                    if (!victim.empty()) {
                        start = victim.back();
                        victim.pop_back();
                        break;
                    }
                }
                if (start == -1 && contended)
                    boost::this_thread::yield();
            }
        }

        // Own queue empty and all other queues confirmed empty: done.
        if (start == -1)
            return;

        // start < F always holds for queued entries, so `remaining` is
        // positive. Computing the end from it avoids the signed overflow that
        // `start + CHUNK` would hit when F is close to INT_MAX.
        const int remaining = st.F - start;
        const int end = start + std::min(CHUNK, remaining);
        for (int f = start; f < end; ++f)
            a.task(f, a.tid, st.K, a.ctx);
    }
}

} // namespace

int run_chunk_steal_scheduler(int F, int K, task_fn_t task, void* ctx, bool affinity)
{
    if (F <= 0 || K <= 0 || task == nullptr)
        return -1;

    StealState st;
    st.F = F;
    st.K = K;
    st.queues.resize(K);
    st.qmutex.resize(K);
    for (int i = 0; i < K; ++i)
        st.qmutex[i] = new boost::mutex;

    // Distribute CHUNK-aligned starts round-robin across the deques.
    for (int s = 0; s < F; s += CHUNK)
        st.queues[(s / CHUNK) % K].push_back(s);

    boost::thread_group workers;
    for (int tid = 0; tid < K; ++tid) {
        StealArgs a{tid, &st, task, ctx, affinity};
        workers.create_thread(boost::bind(worker, a));
    }
    workers.join_all();

    for (int i = 0; i < K; ++i)
        delete st.qmutex[i];

    return 0;
}
