#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <vector>
#include "prolific_scheduler.hpp"
#include "collective_scheduler.hpp"

int main(int argc, char* argv[]) {
    if (argc != 4) {
        std::cerr << "Usage: " << argv[0] << " n k {prolific|collective}\n";
        return 1;
    }

    int n = std::atoi(argv[1]);
    int k = std::atoi(argv[2]);
    std::string mode = argv[3];

    //generates random tensor B(n,n,k)
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> dist(-1000, 1000);

    std::vector<int> B(n * n * k);
    for (int s = 0; s < k; ++s) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                B[s * n * n + i * n + j] = dist(gen);
            }
        }
    }

    std::vector<int> sorted_B(n * n * k);

    double total_init_time = 0.0;
    double total_task_time = 0.0;

    if (mode == "prolific") {
        for (int s = 0; s < k; ++s) {
            //measure initialization
            auto init_start = std::chrono::high_resolution_clock::now();
            prolific_scheduler slice_scheduler(n);
            auto init_end = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> init_elapsed = init_end - init_start;
            total_init_time += init_elapsed.count();

            //extracts slice
            std::vector<int> slice(n * n);
            std::copy(B.begin() + s * n * n, B.begin() + (s + 1) * n * n, slice.begin());

            //measures sorting task
            auto task_start = std::chrono::high_resolution_clock::now();
            std::vector<int> sorted_slice = slice_scheduler.sort_matrix_rows(slice, n, n);
            auto task_end = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> task_elapsed = task_end - task_start;
            total_task_time += task_elapsed.count();

            std::copy(sorted_slice.begin(), sorted_slice.end(), sorted_B.begin() + s * n * n);
        }
    } else if (mode == "collective") {
        for (int s = 0; s < k; ++s) {
            auto init_start = std::chrono::high_resolution_clock::now();
            collective_scheduler slice_scheduler(n);
            auto init_end = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> init_elapsed = init_end - init_start;
            total_init_time += init_elapsed.count();

            std::vector<int> slice(n * n);
            std::copy(B.begin() + s * n * n, B.begin() + (s + 1) * n * n, slice.begin());

            auto task_start = std::chrono::high_resolution_clock::now();
            std::vector<int> sorted_slice = slice_scheduler.sort_matrix_rows(slice, n, n);
            auto task_end = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> task_elapsed = task_end - task_start;
            total_task_time += task_elapsed.count();

            std::copy(sorted_slice.begin(), sorted_slice.end(), sorted_B.begin() + s * n * n);
        }
    } else {
        std::cerr << "Invalid mode. Use 'prolific' or 'collective'.\n";
        return 1;
    }

    std::cout << "Mode: " << mode << "\n";
    std::cout << "n = " << n << ", k = " << k << "\n";
    std::cout << "Total initialization time (all slices): " << total_init_time << " seconds\n";
    std::cout << "Total task execution time (all slices): " << total_task_time << " seconds\n";

    
    if (n <= 5 && k <= 3) {
        std::cout << "\nSorted tensor B (each row of each slice sorted):\n";
        for (int s = 0; s < k; ++s) {
            std::cout << "Slice " << s << ":\n";
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    std::cout << sorted_B[s * n * n + i * n + j] << "\t";
                }
                std::cout << "\n";
            }
            std::cout << "\n";
        }
    }

    return 0;
}
