#include <pthread.h>

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

#include "row.hpp"
#include "sort_algorithms.hpp"


/// Per-thread work description handed to static_worker through pthread_create.
///
/// shared_data is addressed as consecutive rows, each W bytes wide; a worker
/// processes the rows in the half-open range [start_row, end_row). Every field
/// has a default initializer so a default-constructed instance is never
/// indeterminate (the struct remains an aggregate in C++14 and later).
struct StaticWorkerArgs {
    uint8_t* shared_data{nullptr}; ///< Base of the buffer shared by all workers (not owned here).
    int F{0};                      ///< Copied from the scheduler; static_worker in this file does not read it.
    int H{0};                      ///< Copied from the scheduler; static_worker in this file does not read it.
    int W{0};                      ///< Row width in bytes, i.e. the stride between consecutive rows.
    int start_row{0};              ///< First row to process (inclusive).
    int end_row{0};                ///< One past the last row to process (exclusive).
};


void* static_worker(void* arg) {
    StaticWorkerArgs* wa = static_cast<StaticWorkerArgs*>(arg);
    
    for (int i = wa->start_row; i < wa->end_row; ++i) {
        
        uint8_t* row_ptr = wa->shared_data + (static_cast<size_t>(i) * wa->W);
        
       
        row<uint8_t> currentRow(row_ptr, static_cast<uint32_t>(wa->W), false);
        
       
        quick_sort(currentRow);
    }
    return nullptr;
}


/// Sorts every row of a buffer holding F * H rows of W bytes, using a static
/// partition of rows across pthreads.
///
/// The rows are split into contiguous chunks, one per thread. Chunk sizes
/// differ by at most one row, so no single thread absorbs the whole remainder.
/// Each chunk is processed by static_worker.
///
/// @param data  Buffer of at least F * H * W bytes. Row i starts at data + i * W.
/// @param F     Together with H, determines the row count F * H.
/// @param H     Together with F, determines the row count F * H.
/// @param W     Row width in bytes.
/// @param K     Requested thread count. It is clamped to [1, F * H], so K <= 0
///              no longer divides by zero, and surplus threads are not spawned.
/// @return      Elapsed milliseconds (monotonic clock) from before the first
///              launch until the last join; 0 when there is nothing to sort.
///
/// If pthread_create fails for a chunk, that chunk is processed on the calling
/// thread instead. This means every row is still handled, and pthread_join is
/// never called on a thread that was never created.
long long run_static_scheduler(uint8_t* data, int F, int H, int W, int K) {
    const int total_rows = F * H;
    if (total_rows <= 0 || W <= 0) {
        return 0;  // Nothing to sort; avoid building zero-width or negative ranges.
    }

    const int num_threads = std::max(1, std::min(K, total_rows));
    const int base_rows = total_rows / num_threads;
    const int extra_rows = total_rows % num_threads;  // First `extra_rows` threads get one more row.

    std::vector<pthread_t> threads(num_threads);
    std::vector<StaticWorkerArgs> args(num_threads);
    std::vector<char> launched(num_threads, 0);  // char, not bool: avoids the vector<bool> proxy.

    const auto t1 = std::chrono::steady_clock::now();

    int next_row = 0;
    for (int i = 0; i < num_threads; ++i) {
        const int chunk = base_rows + (i < extra_rows ? 1 : 0);

        args[i].shared_data = data;
        args[i].F = F;
        args[i].H = H;
        args[i].W = W;
        args[i].start_row = next_row;
        args[i].end_row = next_row + chunk;
        next_row += chunk;

        // `args` is never resized after this point, so &args[i] stays valid
        // until the matching join below.
        if (pthread_create(&threads[i], nullptr, static_worker, &args[i]) == 0) {
            launched[i] = 1;
        } else {
            // On failure the thread ID is unspecified; do the work here instead.
            static_worker(&args[i]);
        }
    }

    for (int i = 0; i < num_threads; ++i) {
        if (launched[i]) {
            pthread_join(threads[i], nullptr);
        }
    }

    const auto t2 = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();
}