// C / POSIX system headers
#include <pthread.h>
#include <sched.h>
#include <sys/mman.h>
#include <sys/resource.h>
#include <sys/wait.h>
#include <unistd.h>

// C++ standard library
#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>

// Third-party libraries
#include <boost/chrono.hpp>
#include <opencv2/opencv.hpp>

// Project headers
#include "row.hpp"
#include "sort_algorithms.hpp"
#include "bounded_collective.hpp"

using namespace std;

/**
 * Restrict the caller to a single CPU.
 *
 * Builds an affinity mask containing only `cpuId` and applies it with
 * sched_setaffinity() using pid 0 (i.e. the caller). The outcome is logged:
 * success on stdout, failure on stderr together with the errno description.
 *
 * @param cpuId Zero-based index of the CPU to bind to.
 * @return true if the affinity mask was applied; false if `cpuId` lies
 *         outside [0, CPU_SETSIZE) or sched_setaffinity() reported an error.
 */
bool bind_process_to_cpu(int cpuId)
{
	// An index outside the fixed-size cpu_set_t cannot be represented in the
	// mask, so reject it up front with a clear message instead of handing the
	// kernel a mask that does not contain the requested CPU.
	if (cpuId < 0 || cpuId >= CPU_SETSIZE)
	{
		std::cerr << "Failed to bind to CPU " << cpuId
		          << ": index out of range [0, " << CPU_SETSIZE << ")" << std::endl;
		return false;
	}

	cpu_set_t cpuset;
	CPU_ZERO(&cpuset);
	CPU_SET(cpuId, &cpuset);

	if (sched_setaffinity(0, sizeof(cpu_set_t), &cpuset) != 0)
	{
		// Capture errno immediately: the stream insertions below may clobber it.
		const int savedErrno = errno;
		std::cerr << "Failed to bind to CPU " << cpuId
		          << ": " << std::strerror(savedErrno) << std::endl;
		return false;
	}

	std::cout << "Bound to CPU " << cpuId << std::endl;
	return true;
}

/**
 * Access one channel value of one pixel inside a flat video buffer.
 *
 * The buffer is addressed as a dense row-major array indexed
 * [frame][row][column][channel], so the element offset is
 *     ((f * H + i) * W + j) * C + k.
 *
 * @param video Base pointer of the flat buffer.
 * @param f     Frame index.
 * @param i     Row index within the frame (0 .. H-1).
 * @param j     Column index within the row (0 .. W-1).
 * @param k     Channel index within the pixel (0 .. C-1).
 * @param H     Frame height in rows.
 * @param W     Frame width in columns.
 * @param C     Number of channels per pixel.
 * @return Mutable reference to the addressed byte. No bounds checking is done.
 */
uint8_t& PIXEL(uint8_t* video , int f , int i, int j, int k , int H , int W , int C)
{
	// Do all arithmetic in size_t so large videos do not overflow int.
	const std::size_t rowIndex   = static_cast<std::size_t>(f) * H + i;   // row across all frames
	const std::size_t pixelIndex = rowIndex * W + j;                      // pixel across all rows
	return video[pixelIndex * C + k];
}

int main(int argc, char** argv)
{
	if(argc < 3)
	{
		cout << "Usage ./sort_blue_video input.mp4 output.mp4 \n";
		return -1;
	}

	long numCPUs = sysconf(_SC_NPROCESSORS_ONLY);
	if (numCPUs <=0)
	{
		cout << "could not determine number of online processors . Using 1 "<< endl;
		numCPUs=1;
	}

	cout << "Number of online processors: " << numCPUs << endl;

	int parentCpu = 0%numCPUs;
	bind_process_to_cpu(parentCpu);

	cv::VideoCapture cap(argv[1]);
	if (!cap.isOpened())
	      {
	       cout << "Could not open input video.\n";
	       return -1;
	      }

	int F = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_COUNT));
	int W = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_WIDTH));
	int H = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_HEIGHT));
	int C = 3;

	double fps = cap.get(cv::CAP_PROP_FPS);
	if (fps <= 0.0) fps = 30.0;

	if (F <= 0 || W <= 0 || H <= 0)
	{
		cout << "Invalid video properties.\n";
	        return -1;
	}
	cout << "Frames: " << F << ", Width: " << W << ", Height: " << H
	         << ", FPS: " << fps << endl;

	// Shared memory allocation for all video data
	size_t totalBytes = (size_t)F * H * W * C * sizeof(uint8_t);

	uint8_t* video = (uint8_t*)mmap(
		nullptr,
	        totalBytes,
	        PROT_READ | PROT_WRITE,
	        MAP_SHARED | MAP_ANONYMOUS,
	        -1,
	        0
	       );

	if (video == MAP_FAILED)
	    {
	           perror("mmap failed");
	           return -1;
             }
	
	cv::Mat frame;
	int f = 0;
	while (cap.read(frame) && f < F)
	{
	     if (frame.empty())
	     break;
	                                
	     if (frame.channels() != 3)
	     {
	          cout << "Expected 3-channel BGR video.\n";
	          munmap(video, totalBytes);
	          return -1;
	     }	
	     
	     for (int i = 0; i < H; i++)
	     {
	     	for (int j = 0; j < W; j++)
	        {
	        	cv::Vec3b pixel = frame.at<cv::Vec3b>(i, j);
	                for (int k = 0; k < C; k++)
	                PIXEL(video, f, i, j, k, H, W, C) = pixel[k];
	        }
	    }
	    f++;
    }
    cap.release();
                                
    int actualFrames = f;
    if (actualFrames == 0)
                                                                                                                                                                                                                             {
	                                                                                                                                                                                                                                                cout << "No frames read from input video.\n";
	                                                                                                                                                                                                                                                        munmap(video, totalBytes);
	                                                                                                                                                                                                                                                                return -1;
	                                                                                                                                                                                                                                                                    }
	                                                                                                                                                                                                                                                                                            

// Use K forked child processes; child p handles frames p, p+K, p+2K, ...
    const int K = 23;
        //const int CHANNEL=2;
            vector<pid_t> children;
                cout << "\n=== Parallel row-sorting section begins ===" << endl;
                    cout << "Parent PID: " << getpid() << endl;
                        cout << "Creating " << K << " child processes..." << endl;
                             boost::chrono::high_resolution_clock::time_point t1 =
                                 boost::chrono::high_resolution_clock::now();

                                     for (int p = 0; p < K; p++)
                                         {
                                                 // cout << "[Parent " << getpid() << "] Requesting creation of child process "
                                                         // << p << "..." << endl;
                                                                 pid_t pid = fork();

                                                                         if (pid < 0)
                                                                                 {
                                                                                             perror("fork failed");
                                                                                                         munmap(video, totalBytes);
                                                                                                                     return -1;
                                                                                                                             }

                                                                                                                                     if (pid == 0)
                                                                                                                                             {
                                                                                                                                             	     //COW
                                                                                                                                             	          pid_t mypid = getpid();
                                                                                                                                             	               pid_t parentpid = getppid();
                                                                                                                                             	                    //Child affinity
                                                                                                                                             	                          int childCpu = (p + 1) % numCPUs;
                                                                                                                                             	                                  bind_process_to_cpu(childCpu);
                                                                                                                                             	                                       // Try to raise child priority: lower nice value = higher priority
                                                                                                                                             	                                               cout << "[Child index " << p << "] STARTED"
                                                                                                                                             	                                                            << " | PID = " << mypid
                                                                                                                                             	                                                                         << " | Parent PID = " << parentpid
                                                                                                                                             	                                                                                      << " | Priority set to high (nice = "
                                                                                                                                             	                                                                                                   << getpriority(PRIO_PROCESS, 0)
                                                                                                                                             	                                                                                                   	     << " CPU="<< getpriority(PRIO_PROCESS, 0) << ")"
                                                                                                                                             	                                                                                                   	                  << endl;
                                                                                                                                             	                                                                                                   	                              // Child process:
                                                                                                                                             	                                                                                                   	                                          // process frames p, p+K, p+2K, ...
                                                                                                                                             	                                                                                                   	                                          	      cout << "[Child index " << p << "] Assigned frames: ";
                                                                                                                                             	                                                                                                   	                                          	                   bool first = true;
                                                                                                                                             	                                                                                                   	                                          	                   	    for (int ff = p; ff < actualFrames; ff += K)
                                                                                                                                             	                                                                                                   	                                          	                   	                {	
                                                                                                                                             	                                                                                                   	                                          	                   	                            	if (!first) cout << ", ";
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		cout << ff;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		first = false;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		            }
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    	cout << endl;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		   //COW
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		              int rowsProcessed = 0;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                         int framesProcessed = 0;

                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                     for (int ff = p; ff < actualFrames; ff += K)
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 {
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		cout << "[Child index " << p << " | PID " << mypid
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                 << "] Processing frame " << ff << "..." << endl;

                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	framesProcessed++;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                for (int i = 0; i < H; i++)
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                {
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		     uint8_t* blueRow  = new uint8_t[W];
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	     uint8_t* greenRow = new uint8_t[W];
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                          uint8_t* redRow   = new uint8_t[W];
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              for (int j = 0; j < W; j++) {
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		        blueRow[j]  = PIXEL(video, ff, i, j, 0, H, W, C);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	greenRow[j] = PIXEL(video, ff, i, j, 1, H, W, C);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	redRow[j]   = PIXEL(video, ff, i, j, 2, H, W, C);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                     }
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                         // Wrap with row<uint8_t> and sort
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                             //row<uint8_t> r(blueRow, W);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 //quick_sort(r);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			row<uint8_t> b(blueRow, W);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		row<uint8_t> g(greenRow, W);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            		row<uint8_t> r(redRow, W);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					quick_sort(b);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	        quick_sort(g);
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                quick_sort(r);

                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                    // Write sorted blue channel back
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                        for (int j = 0; j < W; j++) {
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                //PIXEL(video, ff, i, j, CHANNEL, H, W, C) = r[j];
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                        PIXEL(video, ff, i, j, 0, H, W, C) = b[j];
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                PIXEL(video, ff, i, j, 1, H, W, C) = g[j];
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        PIXEL(video, ff, i, j, 2, H, W, C) = r[j];
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        		    }

                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        		                        delete[] blueRow;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        		                                            delete[] greenRow;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        		                                                                delete[] redRow;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        		                                                                		    //COW
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        		                                                                		    		    rowsProcessed++;
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        		                                                                		    		                    }
                                                                                                                                             	                                                                                                   	                                          	                   	                            	            		            		                    		                                                 		                             	                                		                 	                                              		                        	                	                                                                                 			            		            					            	                                                                                                                                                                        		                                                                		    		                    
