diff --git a/nx/include/switch.h b/nx/include/switch.h index c3779377..162cb080 100644 --- a/nx/include/switch.h +++ b/nx/include/switch.h @@ -28,6 +28,7 @@ extern "C" { #include "switch/kernel/condvar.h" #include "switch/kernel/thread.h" #include "switch/kernel/semaphore.h" +#include "switch/kernel/barrier.h" #include "switch/kernel/virtmem.h" #include "switch/kernel/detect.h" #include "switch/kernel/random.h" @@ -75,6 +76,7 @@ extern "C" { #include "switch/runtime/nxlink.h" #include "switch/runtime/util/utf.h" +#include "switch/runtime/util/list.h" #include "switch/runtime/devices/console.h" #include "switch/runtime/devices/usb_comms.h" diff --git a/nx/include/switch/kernel/barrier.h b/nx/include/switch/kernel/barrier.h index acd23b9c..4e9d7f6e 100644 --- a/nx/include/switch/kernel/barrier.h +++ b/nx/include/switch/kernel/barrier.h @@ -13,7 +13,7 @@ typedef struct barrier { List threads_registered; List threads_waiting; - RwLock mutex; + Mutex mutex; bool isInited; } Barrier; diff --git a/nx/source/kernel/barrier.c b/nx/source/kernel/barrier.c index 88603abe..26115ad5 100644 --- a/nx/source/kernel/barrier.c +++ b/nx/source/kernel/barrier.c @@ -1,17 +1,27 @@ #include "kernel/barrier.h" void barrierInit(Barrier* b) { - rwlockWriteLock(b->mutex); + mutexInit(&b->mutex); + mutexLock(&b->mutex); if(b->isInited) { return; } listInit(b->threads_registered); listInit(b->threads_waiting); b->isInited = true; - rwlockWriteUnlock(b->mutex); + mutexUnlock(&b->mutex); } -void barrierFree(Barrier* b); +void barrierFree(Barrier* b) { + mutexLock(&b->mutex); + if(!b->isInited) { + return; + } + listFree(b->threads_registered); + listFree(b->threads_waiting); + b->isInited = false; + mutexUnlock(&b->mutex); +} void barrierRegister(Barrier* b, Thread* thread) { if(listIsInserted(b->threads_registered, (void*)thread)) { @@ -28,12 +38,19 @@ void barrierWait(Barrier* b, Thread* thread) { if(!listIsInserted(b->threads_registered)) { return; } - threadPause((void*)thread); - listInsertLast(b->threads_waiting, thread); - if(listGetNumNodes(b->threads_registered) == listGetNumNodes(b->threads_waiting)) { + mutexLock(&b->mutex); + if(listGetNumNodes(b->threads_registered) == listGetNumNodes(b->threads_waiting)+1) { while(listGetNumNodes(b->threads_waiting) > 0) { - threadResume(listGetItem(b->threads_waiting, 0)); + Thread* current_thread = listGetItem(b->threads_waiting, 0); + threadResume(current_thread); + listDelete(b->threads_waiting, current_thread); } + mutexUnlock(&b->mutex); + } + else { + listInsertLast(b->threads_waiting, thread); + mutexUnlock(&b->mutex); + threadPause((void*)thread); } } \ No newline at end of file diff --git a/nx/source/runtime/util/list.c b/nx/source/runtime/util/list.c index e634f841..2caed1b1 100644 --- a/nx/source/runtime/util/list.c +++ b/nx/source/runtime/util/list.c @@ -2,7 +2,7 @@ #include void listInit(List* l) { - rwlockWriteLock(l->mutex); + rwlockWriteLock(&l->mutex); if(l->isInited) { return; } @@ -14,11 +14,11 @@ void listInit(List* l) { l->last = header; l->num_nodes = 0; l->isInited = true; - rwlockWriteUnlock(l->mutex); + rwlockWriteUnlock(&l->mutex); } void listFree(List* l) { - rwlockWriteLock(l->mutex); + rwlockWriteLock(&l->mutex); if(!l->isInited) { return; } @@ -31,20 +31,20 @@ void listFree(List* l) { l->last = NULL; l->num_nodes = 0; l->isInited = false; - rwlockWriteUnlock(l->mutex); + rwlockWriteUnlock(&l->mutex); } void listInsert(List* l, void* item, u32 pos) { - rwlockReadLock(l->mutex); + rwlockReadLock(&l->mutex); if(!l->isInited) { return; } if(pos > l->num_nodes || pos < 0) { return; } - rwlockReadUnlock(l->mutex); + rwlockReadUnlock(&l->mutex); - rwlockWriteLock(l->mutex); + rwlockWriteLock(&l->mutex); Node* aux = l->header; for(u32 i = pos; i > 0; i--) { aux = aux->next; @@ -56,11 +56,11 @@ void listInsert(List* l, void* item, u32 pos) { aux->next = new; l->num_nodes++; - rwlockWriteUnlock(l->mutex); + rwlockWriteUnlock(&l->mutex); } void listInsertLast(List* l, void* item) { - rwlockWriteLock(l->mutex); + rwlockWriteLock(&l->mutex); if(!l->isInited) { return; } @@ -70,11 +70,11 @@ void listInsertLast(List* l, void* item) { l->last->next = new; l->last = new; - rwlockWriteUnlock(l->mutex); + rwlockWriteUnlock(&l->mutex); } void listDelete(List* l, void* item) { - rwlockWriteLock(l->mutex); + rwlockWriteLock(&l->mutex); if(!l->isInited) { return; } @@ -91,11 +91,11 @@ void listDelete(List* l, void* item) { l->num_nodes--; } - rwlockWriteUnlock(l->mutex); + rwlockWriteUnlock(&l->mutex); } bool listIsInserted(List* l, void* item) { - rwlockReadLock(l->mutex); + rwlockReadLock(&l->mutex); if(!l->isInited) { return; } @@ -104,22 +104,22 @@ bool listIsInserted(List* l, void* item) { aux = aux->next; } bool result = aux == NULL ? false : true; - rwlockReadUnlock(l->mutex); + rwlockReadUnlock(&l->mutex); return result; } u32 listGetNumNodes(List* l) { - rwlockReadLock(l->mutex); + rwlockReadLock(&l->mutex); if(!l->isInited) { return; } u32 result = l->num_nodes; - rwlockReadUnlock(l->mutex); + rwlockReadUnlock(&l->mutex); return result; } void* listGetItem(List* l, u32 pos) { - rwlockReadLock(l->mutex); + rwlockReadLock(&l->mutex); if(!l->isInited) { return; } @@ -131,6 +131,6 @@ void* listGetItem(List* l, u32 pos) { aux = aux->next; } void* result = aux->item; - rwlockReadUnlock(l->mutex); + rwlockReadUnlock(&l->mutex); return result; } \ No newline at end of file