#include <pthread.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <vector>

#include "row.hpp"
#include "sort_algorithms.hpp"

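// We use the MatrixSortContext that is defined in main.cpp.
// NOTE(review): it is declared again below; presumably both definitions must stay identical — confirm.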
// Describes the byte matrix whose rows are sorted by the scheduler in this file.
// The code here treats it as F * H rows of W bytes each, stored back to back
// starting at A (row i begins at A + i * W).
struct MatrixSortContext {
    uint8_t* A;  // Base address of the matrix data.
    int F;       // First row-count factor; total rows = F * H.
    int H;       // Second row-count factor; total rows = F * H.
    int W;       // Number of bytes in one row.
};

// Argument bundle handed to aimd_worker through pthread_create's void* parameter.
struct AIMDArgs {
    // Matrix to sort; the worker reads its base pointer A and row width W.
    MatrixSortContext* ctx;
    // Shared index of the next unclaimed row; workers advance it with fetch_add
    // to claim a chunk of consecutive rows.
    std::atomic<int>* next_row;
    // Number of rows in the matrix; a worker stops once its claimed start index
    // reaches this value.
    int total_rows;
};

void* aimd_worker(void* arg) {
    AIMDArgs* aa = (AIMDArgs*)arg;
    int current_chunk = 2; // Αρχικό chunk size
    const long long target_usec = 1000; // Στόχος: 1ms ανά chunk

    while (true) {
        int start = aa->next_row->fetch_add(current_chunk);
        if (start >= aa->total_rows) break;

        int end = std::min(start + current_chunk, aa->total_rows);
        
        auto t1 = std::chrono::steady_clock::now();
        for (int i = start; i < end; ++i) {
            uint8_t* row_ptr = aa->ctx->A + (static_cast<size_t>(i) * aa->ctx->W);
            row<uint8_t> r(row_ptr, (uint32_t)aa->ctx->W, false);
            quick_sort(r);
        }
        auto t2 = std::chrono::steady_clock::now();
        
        long long duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();

        // AIMD Logic
        if (duration < target_usec) {
            current_chunk += 1; // Additive Increase
        } else {
            current_chunk = std::max(2, current_chunk / 2); // Multiplicative Decrease
        }
    }
    return nullptr;
}

/**
 * Sorts every row of the matrix described by `ctx` with up to `K` pthreads
 * running aimd_worker, and reports how long that took.
 *
 * All workers pull rows from one shared atomic counter, so the work is still
 * completed when fewer than `K` threads could be started. If `K` is positive
 * but no thread could be created at all, the calling thread runs the worker
 * itself. For `K <= 0` no worker is run.
 *
 * @param ctx Matrix to sort; F * H rows are processed.
 * @param K   Requested number of worker threads.
 * @return Elapsed time in milliseconds, measured on a monotonic clock from
 *         just before the first thread is created until all started threads
 *         have been joined.
 */
long long run_aimd_scheduler(MatrixSortContext* ctx, int K) {
    const int total_rows = ctx->F * ctx->H;
    std::atomic<int> next_row(0);

    // The workers only read their argument, so a single instance serves all
    // of them. It outlives the threads because they are joined below.
    AIMDArgs shared_args{ctx, &next_row, total_rows};

    // Heap-backed storage instead of a variable-length array, which is not
    // standard C++ and is undefined for K <= 0. Reserved up front so no
    // allocation happens inside the timed region.
    std::vector<pthread_t> threads;
    if (K > 0) threads.reserve(static_cast<size_t>(K));

    const auto start_bench = std::chrono::steady_clock::now();
    for (int i = 0; i < K; ++i) {
        pthread_t handle;
        // Stop at the first failure: only handles that pthread_create
        // actually initialised may be passed to pthread_join.
        if (pthread_create(&handle, nullptr, aimd_worker, &shared_args) != 0) break;
        threads.push_back(handle);
    }

    // Threads were requested but none could be started: do the work here so
    // the caller does not silently get unsorted data back.
    if (K > 0 && threads.empty()) {
        aimd_worker(&shared_args);
    }

    for (pthread_t handle : threads) pthread_join(handle, nullptr);
    const auto end_bench = std::chrono::steady_clock::now();

    return std::chrono::duration_cast<std::chrono::milliseconds>(end_bench - start_bench).count();
}
  
