#include "chunk_stealing_scheduler.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <vector>

#include <boost/bind.hpp>
#include <boost/thread.hpp>

#include "row.hpp"
#include "sort_algorithms.hpp"

using namespace std;

// Shared state for the chunk-stealing scheduler: the tensor being sorted,
// its dimensions, and one per-worker work queue with a dedicated mutex.
//
// NOTE(review): qmutex holds raw owning pointers (allocated and freed by
// run_chunk_stealing_scheduler) because boost::mutex is non-copyable and
// cannot be stored by value in these vectors.
struct ChunkStealingContext
{
    uint8_t* A;        // base of the F x H x W tensor; not owned here
    int F;             // number of slices in the tensor
    int H;             // rows per slice
    int W;             // elements per row
    int workerCount;   // number of worker threads (== queues.size())
    int chunkSize;     // consecutive slices per work item

    // queues[t] holds the starting slice index of each chunk owned by
    // worker t; all access to queues[t] is guarded by *qmutex[t].
    vector<deque<int> > queues;
    vector<boost::mutex*> qmutex;
};

// Flattened offset of element (f, i, j) in a row-major F x H x W tensor.
// All factors are widened to size_t before multiplying so the product
// cannot overflow int for large tensors.
static inline size_t tensor_index(int f, int i, int j, int H, int W)
{
    const size_t slice  = static_cast<size_t>(f);
    const size_t row    = static_cast<size_t>(i);
    const size_t col    = static_cast<size_t>(j);
    const size_t height = static_cast<size_t>(H);
    const size_t width  = static_cast<size_t>(W);

    return (slice * height + row) * width + col;
}

static void sort_slice_rows(uint8_t* A, int slice, int H, int W)
{
    for (int i = 0; i < H; ++i)
    {
        uint8_t* rowPtr = &A[tensor_index(slice, i, 0, H, W)];

        // copy_data = false: the row object directly wraps the tensor row.
        // quick_sort modifies the tensor row in place.
        row<uint8_t> currentRow(rowPtr, static_cast<uint32_t>(W), false);
        quick_sort(currentRow);
    }
}

// Try to take the next chunk from worker tid's own queue (FIFO order).
// On success writes the chunk's starting slice to `start` and returns true;
// returns false when the queue is empty.
static bool pop_own_chunk(ChunkStealingContext* ctx, int tid, int& start)
{
    boost::mutex::scoped_lock guard(*ctx->qmutex[tid]);

    deque<int>& ownQueue = ctx->queues[tid];

    const bool gotWork = !ownQueue.empty();
    if (gotWork)
    {
        start = ownQueue.front();
        ownQueue.pop_front();
    }

    return gotWork;
}

// Try to steal one chunk from another worker's queue. Victims are probed
// round-robin starting just after tid; the thief takes from the BACK of the
// victim's queue while the owner pops from the front. Returns true and sets
// `start` on success, false once every other queue has been found empty.
static bool steal_chunk(ChunkStealingContext* ctx, int tid, int& start)
{
    const int n = ctx->workerCount;

    for (int step = 1; step < n; ++step)
    {
        const int victim = (tid + step) % n;

        boost::mutex::scoped_lock guard(*ctx->qmutex[victim]);

        deque<int>& victimQueue = ctx->queues[victim];
        if (victimQueue.empty())
            continue;

        start = victimQueue.back();
        victimQueue.pop_back();
        return true;
    }

    return false;
}

// Worker thread body: repeatedly acquire a chunk (own queue first, then by
// stealing from others) and sort every row of every slice in that chunk.
// Exits when neither the own queue nor any victim queue has work left.
static void chunk_stealing_worker(ChunkStealingContext* ctx, int tid)
{
    int start = -1;

    while (pop_own_chunk(ctx, tid, start) || steal_chunk(ctx, tid, start))
    {
        // The final chunk may be short; clamp to the slice count.
        const int stop = min(start + ctx->chunkSize, ctx->F);

        for (int slice = start; slice < stop; ++slice)
            sort_slice_rows(ctx->A, slice, ctx->H, ctx->W);
    }
}

// Sort every row of every slice of the F x H x W tensor `A` in parallel,
// distributing slices to `workerCount` threads in chunks of `chunkSize`
// slices, with idle workers stealing chunks from busy ones.
//
// Returns 0 on success, -1 on invalid arguments or if worker threads could
// not be started. workerCount and chunkSize are clamped to sane values
// rather than rejected.
int run_chunk_stealing_scheduler(
    uint8_t* A,
    int F,
    int H,
    int W,
    int workerCount,
    int chunkSize
)
{
    if (A == nullptr || F <= 0 || H <= 0 || W <= 0)
        return -1;

    // Clamp: at least one worker, and never more workers than slices.
    if (workerCount <= 0)
        workerCount = 1;

    if (workerCount > F)
        workerCount = F;

    if (chunkSize <= 0)
        chunkSize = 1;

    ChunkStealingContext ctx;
    ctx.A = A;
    ctx.F = F;
    ctx.H = H;
    ctx.W = W;
    ctx.workerCount = workerCount;
    ctx.chunkSize = chunkSize;

    ctx.queues.resize(static_cast<size_t>(workerCount));
    ctx.qmutex.resize(static_cast<size_t>(workerCount));

    for (int i = 0; i < workerCount; ++i)
        ctx.qmutex[i] = new boost::mutex;

    // Seed the queues round-robin with the starting slice of each chunk.
    for (int s = 0; s < F; s += chunkSize)
    {
        int owner = (s / chunkSize) % workerCount;
        ctx.queues[owner].push_back(s);
    }

    boost::thread_group workers;
    int rc = 0;

    try
    {
        for (int tid = 0; tid < workerCount; ++tid)
            workers.create_thread(boost::bind(chunk_stealing_worker, &ctx, tid));
    }
    catch (...)
    {
        // Thread creation failed partway (e.g. resource exhaustion). The
        // workers already started still reference the stack-local ctx, so
        // we must wait for them before unwinding; they will drain all the
        // queues (including those of the never-started workers) via
        // stealing, then exit.
        rc = -1;
    }

    workers.join_all();

    for (int i = 0; i < workerCount; ++i)
        delete ctx.qmutex[i];

    return rc;
}
