diff --git a/nx/include/switch/kernel/rwlock.h b/nx/include/switch/kernel/rwlock.h index de594b81..bf23edaa 100644 --- a/nx/include/switch/kernel/rwlock.h +++ b/nx/include/switch/kernel/rwlock.h @@ -6,12 +6,15 @@ */ #pragma once #include "../kernel/mutex.h" +#include "../kernel/condvar.h" /// Read/write lock structure. typedef struct { - RMutex r; - RMutex g; - u64 b; + Mutex mutex; + CondVar condvar_readers; + CondVar condvar_writer; + u32 readers : 31; + bool writer : 1; } RwLock; /** diff --git a/nx/source/kernel/rwlock.c b/nx/source/kernel/rwlock.c index 962dd0ca..1da465ad 100644 --- a/nx/source/kernel/rwlock.c +++ b/nx/source/kernel/rwlock.c @@ -3,33 +3,57 @@ #include "kernel/rwlock.h" void rwlockInit(RwLock* r) { - rmutexInit(&r->r); - rmutexInit(&r->g); - r->b = 0; + mutexInit(&r->mutex); + condvarInit(&r->condvar_readers); + condvarInit(&r->condvar_writer); + + r->readers = 0; + r->writer = false; } void rwlockReadLock(RwLock* r) { - rmutexLock(&r->r); + mutexLock(&r->mutex); - if (r->b++ == 0) - rmutexLock(&r->g); + while (r->writer) { + condvarWait(&r->condvar_writer, &r->mutex); + } - rmutexUnlock(&r->r); + r->readers++; + + mutexUnlock(&r->mutex); } void rwlockReadUnlock(RwLock* r) { - rmutexLock(&r->r); + mutexLock(&r->mutex); - if (r->b-- == 1) - rmutexUnlock(&r->g); + if (--r->readers == 0) { + condvarWakeAll(&r->condvar_readers); + } - rmutexUnlock(&r->r); + mutexUnlock(&r->mutex); } void rwlockWriteLock(RwLock* r) { - rmutexLock(&r->g); + mutexLock(&r->mutex); + + while (r->writer) { + condvarWait(&r->condvar_writer, &r->mutex); + } + + r->writer = true; + + while (r->readers > 0) { + condvarWait(&r->condvar_readers, &r->mutex); + } + + mutexUnlock(&r->mutex); } void rwlockWriteUnlock(RwLock* r) { - rmutexUnlock(&r->g); + mutexLock(&r->mutex); + + r->writer = false; + condvarWakeAll(&r->condvar_writer); + + mutexUnlock(&r->mutex); }