#include <assert.h>
#include <errno.h>
#include <fcntl.h> // For O_* constants
#include <pthread.h>
#include <semaphore.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/stat.h> // For mode constants
#include <unistd.h> // For getpid

#ifdef __APPLE__
#include "pthread_barrier.h"
#endif

#include "levenshtein.h"


// Pthreads rows Levenshtein ======================================================================

// State shared by all the workers launched by levenshtein_distance_pthreads_byrows.
typedef struct byrows_shared
{
	const void* col_str;   // shorter string; its characters index the DP columns
	const void* row_str;   // longer string; its characters index the DP rows

	size_t col_len;        // characters in col_str
	size_t row_len;        // characters in row_str

	size_t worker_count;   // number of workers launched

	size_t** distances;    // worker_count rows of (col_len + 1) cells, reused as a ring of DP rows

	size_t current_row;    // row-progress marker; starts at 1 because row 0 is filled before the workers start

	sem_t** semaphores;    // one per worker: a worker waits on its own and posts to the next one's

	size_t unicode;        // nonzero: strings hold wchar_t; zero: char (a bool kept as size_t for padding)
} byrows_shared_t;

// Argument handed to each worker thread.
typedef struct byrows_worker_data
{
	size_t id;                 // 0-based worker index; selects this worker's semaphore
	byrows_shared_t* shared;   // state common to every worker
} byrows_worker_data_t;

// Worker thread entry point; `data` points to a byrows_worker_data_t.
void* calculate_distance_byrows(void* data);

// Computes the Levenshtein distance between s1 and s2 with a team of pthreads.
// The longer string indexes the DP rows and the shorter one the columns. The
// workers fill the rows in a ring of worker_count buffers (DP row r lives in
// distances[r % worker_count]) and synchronise through one semaphore per worker.
//
// s1, s2   strings of len1 / len2 characters: char when `unicode` is false,
//          wchar_t when it is true.
// workers  in: requested thread count; out: thread count actually used,
//          clamped to [1, longest length].
// Returns the distance, or (size_t)-1 after printing a message to stderr when a
// resource cannot be acquired. Every resource is released on every path.
size_t levenshtein_distance_pthreads_byrows(const void* s1, size_t len1, const void* s2, size_t len2, bool unicode, size_t* workers)
{
	assert(s1);
	assert(s2);
	assert(len1);
	assert(len2);
	assert(workers);

	// Create the shared data, zeroed so that cleanup can tell what was allocated
	byrows_shared_t* shared = (byrows_shared_t*) calloc( 1, sizeof(byrows_shared_t) );
	if ( shared == NULL )
		return (void)fprintf(stderr, "levdist: error: no memory for shared data\n"), (size_t)-1;

	size_t result = (size_t)-1;
	size_t semaphore_count = 0;
	size_t thread_count = 0;
	byrows_worker_data_t* thread_data = NULL;
	pthread_t* threads = NULL;

	shared->unicode = unicode;

	// We use the longest string for rows and shortest for columns
	shared->row_str = len1 > len2 ? s1 : s2; // max
	shared->col_str = len1 > len2 ? s2 : s1; // min
	shared->row_len = len1 > len2 ? len1 : len2; // max
	shared->col_len = len1 > len2 ? len2 : len1; // min

	// Do not use more workers than rows to compute, but at least one:
	// a zero count would make every "% worker_count" divide by zero
	*workers = min(*workers, shared->row_len);
	if ( *workers == 0 )
		*workers = 1;
	shared->worker_count = *workers;

	// Ring of DP rows shared by all workers, one buffer per worker
	const size_t cols = shared->col_len + 1;
	shared->distances = create_distance_matrix(*workers, cols);
	if ( shared->distances == NULL )
	{
		fprintf(stderr, "levdist: error: no memory for distances\n");
		goto cleanup;
	}

	// DP row 0: distance from the empty prefix of the row string
	for ( size_t col = 0; col < cols; ++col )
		shared->distances[0][col] = col;

	// Next pending row is 1
	shared->current_row = 1;

	shared->semaphores = (sem_t**) malloc( *workers * sizeof(sem_t*) );
	if ( shared->semaphores == NULL )
	{
		fprintf(stderr, "levdist: error: no memory for semaphores\n");
		goto cleanup;
	}

	// One semaphore per worker: the first starts with one token, the rest with none.
	// The pid in the name keeps concurrent processes from sharing semaphores.
	for ( ; semaphore_count < *workers; ++semaphore_count )
	{
		char name[64];
		snprintf(name, sizeof name, "/levdist%ld.%zu", (long) getpid(), semaphore_count);

		// sem_open ignores the initial value when the name already exists, so drop
		// any stale semaphore (e.g. from a crashed run) and insist on a fresh one
		sem_unlink(name);
		sem_t* semaphore = sem_open(name, O_CREAT | O_EXCL, 0644, (unsigned) (semaphore_count == 0));
		if ( semaphore == SEM_FAILED )
		{
			fprintf(stderr, "levdist: error: could not create semaphore\n");
			goto cleanup;
		}

		// Remove the name right away: the semaphore stays usable until sem_close
		// and nothing is left behind for the next call or process
		sem_unlink(name);
		shared->semaphores[semaphore_count] = semaphore;
	}

	// Create the exclusive data of each thread
	thread_data = (byrows_worker_data_t*) malloc( *workers * sizeof(byrows_worker_data_t) );
	if ( thread_data == NULL )
	{
		fprintf(stderr, "levdist: error: no memory for thread data\n");
		goto cleanup;
	}

	threads = (pthread_t*) malloc( *workers * sizeof(pthread_t) );
	if ( threads == NULL )
	{
		fprintf(stderr, "levdist: error: no memory for threads\n");
		goto cleanup;
	}

	// Start each thread
	for ( ; thread_count < *workers; ++thread_count )
	{
		thread_data[thread_count].id = thread_count;
		thread_data[thread_count].shared = shared;

		if ( pthread_create( threads + thread_count, NULL, calculate_distance_byrows, thread_data + thread_count ) )
		{
			fprintf(stderr, "levdist: error: could not create thread %zu\n", thread_count);
			// The running workers may block forever waiting for the missing one;
			// cancel them (sem_wait is a cancellation point) so they can be joined
			for ( size_t index = 0; index < thread_count; ++index )
				pthread_cancel(threads[index]);
			break;
		}
	}

	// Wait for every started thread to finish
	for ( size_t index = 0; index < thread_count; ++index )
		pthread_join(threads[index], NULL);

	// The last cell of DP row row_len is the distance; it is only valid if all workers ran
	if ( thread_count == *workers )
		result = shared->distances[shared->row_len % *workers][shared->col_len];

cleanup:
	for ( size_t index = 0; index < semaphore_count; ++index )
		sem_close(shared->semaphores[index]);
	free(shared->semaphores);

	if ( shared->distances != NULL )
		destroy_distance_matrix(shared->distances, *workers);

	free(thread_data);
	free(threads);
	free(shared);

	// The Levenshtein distance, or (size_t)-1 on failure
	return result;
}

// Waits on `semaphore`, retrying when a signal handler interrupts the wait, so
// that an interrupted wait is never mistaken for a granted token.
static void byrows_sem_wait(sem_t* semaphore)
{
	while ( sem_wait(semaphore) == -1 && errno == EINTR )
		continue;
}

// Computes cells 1..col_len of DP row `my_row` (1-based) for char strings.
// `row` receives the new row and `prev` holds DP row my_row - 1. They are the
// same buffer when only one worker runs.
// Each cell waits for one token meaning "the cell above is ready". Row 1 skips
// the wait because row 0 is filled before the workers start. After computing a
// cell, the worker posts one token to the worker that owns the next row.
static inline void calculate_distance_byrows_ascii(byrows_worker_data_t* worker, size_t my_row, size_t* row, size_t* prev)
{
	const byrows_shared_t* shared = worker->shared;
	const char* col_str = (const char*) shared->col_str;
	const char row_char = ((const char*) shared->row_str)[my_row - 1];
	sem_t* mine = shared->semaphores[worker->id];
	sem_t* next = shared->semaphores[(worker->id + 1) % shared->worker_count];

	// D[my_row - 1][0] is my_row - 1 by definition; prev[0] itself may already be
	// overwritten by the next row that reuses that buffer
	size_t diagonal = my_row - 1;
	size_t left = my_row;
	for ( size_t col = 1; col <= shared->col_len; ++col )
	{
		if ( my_row > 1 )
			byrows_sem_wait(mine);

		const size_t above = prev[col];
		left = min3(above + 1, left + 1, diagonal + (size_t)(row_char != col_str[col - 1]));
		row[col] = left;

		// Once this cell is published, the next writer of prev's buffer may overwrite
		// prev[col], so keep it locally as the next diagonal
		diagonal = above;

		// I have produced the value that the next worker needs
		sem_post(next);
	}
}

// Same as calculate_distance_byrows_ascii, for wchar_t strings.
static inline void calculate_distance_byrows_unicode(byrows_worker_data_t* worker, size_t my_row, size_t* row, size_t* prev)
{
	const byrows_shared_t* shared = worker->shared;
	const wchar_t* col_str = (const wchar_t*) shared->col_str;
	const wchar_t row_char = ((const wchar_t*) shared->row_str)[my_row - 1];
	sem_t* mine = shared->semaphores[worker->id];
	sem_t* next = shared->semaphores[(worker->id + 1) % shared->worker_count];

	// D[my_row - 1][0] is my_row - 1 by definition; prev[0] itself may already be
	// overwritten by the next row that reuses that buffer
	size_t diagonal = my_row - 1;
	size_t left = my_row;
	for ( size_t col = 1; col <= shared->col_len; ++col )
	{
		if ( my_row > 1 )
			byrows_sem_wait(mine);

		const size_t above = prev[col];
		left = min3(above + 1, left + 1, diagonal + (size_t)(row_char != col_str[col - 1]));
		row[col] = left;

		// Once this cell is published, the next writer of prev's buffer may overwrite
		// prev[col], so keep it locally as the next diagonal
		diagonal = above;

		// I have produced the value that the next worker needs
		sem_post(next);
	}
}

// Worker thread entry point. Worker `id` computes DP rows id + 1,
// id + 1 + worker_count, and so on. It writes them into distances[(id + 1) % worker_count]
// and reads the previous row from distances[id], so DP row r always lives in
// distances[r % worker_count].
// Relies on the launcher having filled distances[0] with DP row 0 and having
// created semaphore 0 with one token and every other semaphore with none.
// The worker that computes the last row sets current_row to row_len + 1.
void* calculate_distance_byrows(void* data)
{
	// Extract the given data from the void* pointers
	byrows_worker_data_t* worker = (byrows_worker_data_t*)data;
	byrows_shared_t* shared = worker->shared;
	const size_t worker_count = shared->worker_count;

	size_t* row  = shared->distances[(worker->id + 1) % worker_count];
	size_t* prev = shared->distances[worker->id];

	// Row 1 does not wait per cell, so worker 0 consumes the single initial token
	// here. A leftover token would let row worker_count + 1 start before row
	// worker_count has published its first cell.
	if ( worker->id == 0 )
		byrows_sem_wait(shared->semaphores[0]);

	for ( size_t my_row = worker->id + 1; my_row <= shared->row_len; my_row += worker_count )
	{
		// First column: distance from the empty prefix of the column string
		row[0] = my_row;

		// Update the distances of this row using the Levenshtein algorithm
		if ( shared->unicode )
			calculate_distance_byrows_unicode(worker, my_row, row, prev);
		else
			calculate_distance_byrows_ascii(worker, my_row, row, prev);

		// Only the owner of the last row writes this; the launcher reads it after joining
		if ( my_row == shared->row_len )
			shared->current_row = my_row + 1;
	}

	return NULL;
}