diff --git a/nx/include/switch/kernel/barrier.h b/nx/include/switch/kernel/barrier.h index befa6e93..b5f3acd6 100644 --- a/nx/include/switch/kernel/barrier.h +++ b/nx/include/switch/kernel/barrier.h @@ -5,15 +5,15 @@ * @copyright libnx Authors */ #pragma once -#include "semaphore.h" +#include "mutex.h" +#include "condvar.h" /// Barrier structure. typedef struct Barrier { - u64 count; ///< Number of threads to reach the barrier. - u64 thread_total; ///< Number of threads to wait on. - Semaphore throttle; ///< Semaphore to make sure threads release to scheduler one at a time. - Semaphore lock; ///< Semaphore to lock barrier to prevent multiple operations by threads at once. - Semaphore thread_wait; ///< Semaphore to force a thread to wait if count < thread_total. + u64 count; ///< Number of threads to reach the barrier. + u64 total; ///< Number of threads to wait on. + Mutex mutex; + CondVar condvar; } Barrier; /** diff --git a/nx/source/kernel/barrier.c b/nx/source/kernel/barrier.c index d0c9af1b..f11f4f1b 100644 --- a/nx/source/kernel/barrier.c +++ b/nx/source/kernel/barrier.c @@ -1,29 +1,22 @@ #include "kernel/barrier.h" -void barrierInit(Barrier *b, u64 thread_count) { +void barrierInit(Barrier *b, u64 total) { b->count = 0; - b->thread_total = thread_count; - semaphoreInit(&b->throttle, 0); - semaphoreInit(&b->lock, 1); - semaphoreInit(&b->thread_wait, 0); + b->total = total - 1; + mutexInit(&b->mutex); + condvarInit(&b->condvar); } void barrierWait(Barrier *b) { - semaphoreWait(&b->lock); - if(b->count < b->thread_total) { - b->count++; + mutexLock(&b->mutex); + + if (b->count++ == b->total) { + b->count = 0; + condvarWake(&b->condvar, b->total); } - if(b->count < b->thread_total) { - semaphoreSignal(&b->lock); - semaphoreWait(&b->thread_wait); - semaphoreSignal(&b->throttle); - } - else if(b->count == b->thread_total) { - for(int i = 0; i < b->thread_total-1; i++) { - semaphoreSignal(&b->thread_wait); - semaphoreWait(&b->throttle); - } - b->count = 0; - semaphoreSignal(&b->lock); + else { + condvarWait(&b->condvar, &b->mutex); } + + mutexUnlock(&b->mutex); }