#include <iostream>
#include <cstdlib>
#include <ctime>
#include <sys/mman.h>
#include <sys/wait.h>
#include <unistd.h>
#include <chrono>
#include <cmath>
#include "ProlificScheduler.hpp"
#include "CollectiveScheduler.hpp"

using namespace std;

// Printing helper for question (3)
/// Prints a tensor stored as k consecutive n x n row-major sub-tensors.
///
/// Element (i, r, c) is read from B[i*n*n + r*n + c]. Each sub-tensor is
/// preceded by a "--- Sub-tensor B(i) ---" header line; every value is
/// followed by a tab, every row by a newline, and every sub-tensor by one
/// extra blank line. Nothing is printed when k <= 0.
///
/// @param B  pointer to at least n*n*k doubles; only read, never modified.
/// @param n  side length of each square sub-tensor.
/// @param k  number of sub-tensors.
void print_tensor(const double* B, int n, int k) {
    for (int i = 0; i < k; ++i) {
        std::cout << "--- Sub-tensor B(" << i << ") ---\n";
        for (int r = 0; r < n; ++r) {
            // Offset computed in size_t: i*n*n + r*n would overflow int
            // for large tensors. (n >= 1 is guaranteed inside this loop.)
            const std::size_t row_start =
                (static_cast<std::size_t>(i) * n + r) * n;
            for (int c = 0; c < n; ++c) {
                std::cout << B[row_start + c] << '\t';
            }
            std::cout << '\n';
        }
        std::cout << '\n';
    }
    // Flush once on return (instead of on every line) so no output is left
    // buffered if the caller forks afterwards, which would duplicate it.
    std::cout.flush();
}

// Benchmark driver: builds a random n x n x k tensor in shared anonymous
// memory, then times ProlificScheduler and CollectiveScheduler on the SAME
// input data. For small sizes (n <= 5 and k <= 5) the input tensor and the
// final (Collective) result are printed.
int main() {

    int n = 10;  // side length of each square sub-tensor
    int k = 10;  // number of sub-tensors
    // Element count computed in size_t so n*n*k cannot overflow int.
    const std::size_t count = static_cast<std::size_t>(n) * n * k;
    const std::size_t size = count * sizeof(double);

    // MAP_SHARED | MAP_ANONYMOUS: memory not backed by a file whose writes are
    // shared with forked children. NOTE(review): presumably the schedulers fork
    // workers that write into B; confirm in the scheduler sources.
    void* mem = mmap(nullptr, size, PROT_READ | PROT_WRITE,
                     MAP_SHARED | MAP_ANONYMOUS, -1, 0);
    if (mem == MAP_FAILED) {
        std::cerr << "mmap failed" << std::endl;
        return 1;
    }
    double* B = static_cast<double*>(mem);

    // Fills B with whole numbers in [-1000, 1000] drawn from rand() seeded
    // with `seed`; reusing the same seed reproduces exactly the same tensor.
    auto fill_random = [B, count](unsigned int seed) {
        std::srand(seed);
        for (std::size_t i = 0; i < count; i++) {
            B[i] = (std::rand() % 2001) - 1000.0;
        }
    };

    const unsigned int seed = static_cast<unsigned int>(std::time(nullptr));
    fill_random(seed);

    const bool print_small = (n <= 5 && k <= 5);
    if (print_small) {
        std::cout << "=== ORIGINAL TENSOR B ===" << std::endl;
        print_tensor(B, n, k);
    }

    ProlificScheduler prolific;
    CollectiveScheduler collective;

    std::cout << "--- Performance Test (n=" << n << ", k=" << k << ") ---" << std::endl;

    // 1. Run Prolific. steady_clock is monotonic; high_resolution_clock may
    //    alias system_clock and jump if the wall clock is adjusted mid-run.
    auto s1 = std::chrono::steady_clock::now();
    prolific.execute(k, n * k, B, n);
    auto e1 = std::chrono::steady_clock::now();
    auto dur1 = std::chrono::duration_cast<std::chrono::milliseconds>(e1 - s1).count();
    std::cout << "Prolific Total Time: " << dur1 << " ms" << std::endl;

    // Prolific works on B in place (presumably row-sorting it, per the
    // "ROW-SORTED" label below). Restore the identical original data so that
    // Collective gets the same workload and the "ORIGINAL" tensor printed
    // above really is the input of the "Result" printed below.
    fill_random(seed);

    // 2. Run Collective with ceil(log2(k + 1)) levels, so the worker count
    //    scales with k. NOTE(review): presumably the depth of a binary tree of
    //    workers covering the k sub-tensors; confirm in CollectiveScheduler.
    int levels = static_cast<int>(std::ceil(std::log2(k + 1)));

    auto s2 = std::chrono::steady_clock::now();
    collective.execute(levels, n * k, B, n);
    auto e2 = std::chrono::steady_clock::now();
    auto dur2 = std::chrono::duration_cast<std::chrono::milliseconds>(e2 - s2).count();
    std::cout << "Collective Total Time: " << dur2 << " ms" << std::endl;

    // Print the final result (only for small inputs).
    if (print_small) {
        std::cout << "=== ROW-SORTED TENSOR B (Result) ===" << std::endl;
        print_tensor(B, n, k);
    }

    munmap(B, size);
    return 0;
}
