diff --git a/nx/source/runtime/c11-threads.c b/nx/source/runtime/c11-threads.c index b88f597d..d256750c 100644 --- a/nx/source/runtime/c11-threads.c +++ b/nx/source/runtime/c11-threads.c @@ -1,13 +1,47 @@ #include +#include #include #include "kernel/svc.h" #include "../internal.h" +#define THRD_MAIN_HANDLE ((thrd_t)~(uintptr_t)0) + +static bool timespec_subtract(struct timespec x, struct timespec y, struct timespec *__restrict result) +{ + // Perform the carry for the later subtraction by updating y + if (x.tv_nsec < y.tv_nsec) { + int seconds = (y.tv_nsec - x.tv_nsec) / 1000000000 + 1; + y.tv_nsec -= 1000000000 * seconds; + y.tv_sec += seconds; + } + if (x.tv_nsec - y.tv_nsec > 1000000000) { + int seconds = (x.tv_nsec - y.tv_nsec) / 1000000000; + y.tv_nsec += 1000000000 * seconds; + y.tv_sec -= seconds; + } + + // Compute the time remaining to wait + result->tv_sec = x.tv_sec - y.tv_sec; + result->tv_nsec = x.tv_nsec - y.tv_nsec; + + // Return true if result is negative + return x.tv_sec < y.tv_sec; +} + static inline u64 impl_timespec2nsec(const struct timespec *__restrict ts) { return (u64)ts->tv_sec * 1000000000 + ts->tv_nsec; } +static u64 impl_abstimespec2nsec(const struct timespec *__restrict ts) +{ + struct timespec now, diff; + clock_gettime(CLOCK_REALTIME, &now); + if (timespec_subtract(*ts, now, &diff)) + return 0; + return impl_timespec2nsec(&diff); +} + void call_once(once_flag *flag, void (*func)(void)) { mtx_lock(&flag->mutex); @@ -61,11 +95,25 @@ int cnd_signal(cnd_t *cond) static int __cnd_timedwait(cnd_t *__restrict cond, mtx_t *__restrict mtx, u64 timeout) { - if (!cond || !mtx || mtx->type != mtx_plain) + if (!cond || !mtx) return thrd_error; + uint32_t thread_tag_backup = 0; + if (mtx->type & mtx_recursive) { + if (mtx->rmutex.counter != 1) + return thrd_error; + thread_tag_backup = mtx->rmutex.thread_tag; + mtx->rmutex.thread_tag = 0; + mtx->rmutex.counter = 0; + } + Result rc = condvarWaitTimeout(cond, &mtx->mutex, timeout); + if (mtx->type & mtx_recursive) { + mtx->rmutex.thread_tag = thread_tag_backup; + mtx->rmutex.counter = 1; + } + return R_SUCCEEDED(rc) ? thrd_success : thrd_error; } @@ -74,7 +122,7 @@ int cnd_timedwait(cnd_t *__restrict cond, mtx_t *__restrict mtx, const struct ti if (!abs_time) return thrd_error; - return __cnd_timedwait(cond, mtx, impl_timespec2nsec(abs_time)); + return __cnd_timedwait(cond, mtx, impl_abstimespec2nsec(abs_time)); } int cnd_wait(cnd_t *cond, mtx_t *mtx) @@ -89,18 +137,14 @@ void mtx_destroy(mtx_t *mtx) int mtx_init(mtx_t *mtx, int type) { - if (!mtx || (type != mtx_plain && type != mtx_recursive)) + if (!mtx || (type & mtx_timed) || !(type & mtx_plain)) return thrd_error; mtx->type = type; - switch (type) { - case mtx_plain: - mutexInit(&mtx->mutex); - break; - case mtx_recursive: - rmutexInit(&mtx->rmutex); - break; - } + if (mtx->type & mtx_recursive) + rmutexInit(&mtx->rmutex); + else + mutexInit(&mtx->mutex); return thrd_success; } @@ -109,14 +153,10 @@ int mtx_lock(mtx_t *mtx) if (!mtx) return thrd_error; - switch (mtx->type) { - case mtx_plain: - mutexLock(&mtx->mutex); - break; - case mtx_recursive: - rmutexLock(&mtx->rmutex); - break; - } + if (mtx->type & mtx_recursive) + rmutexLock(&mtx->rmutex); + else + mutexLock(&mtx->mutex); return thrd_success; } @@ -132,14 +172,10 @@ int mtx_trylock(mtx_t *mtx) return thrd_error; bool res = false; - switch (mtx->type) { - case mtx_plain: - res = mutexTryLock(&mtx->mutex); - break; - case mtx_recursive: - res = rmutexTryLock(&mtx->rmutex); - break; - } + if (mtx->type & mtx_recursive) + res = rmutexTryLock(&mtx->rmutex); + else + res = mutexTryLock(&mtx->mutex); return res ? thrd_success : thrd_error; } @@ -148,14 +184,10 @@ int mtx_unlock(mtx_t *mtx) if (!mtx) return thrd_error; - switch (mtx->type) { - case mtx_plain: - mutexUnlock(&mtx->mutex); - break; - case mtx_recursive: - rmutexUnlock(&mtx->rmutex); - break; - } + if (mtx->type & mtx_recursive) + rmutexUnlock(&mtx->rmutex); + else + mutexUnlock(&mtx->mutex); return thrd_success; } @@ -237,7 +269,8 @@ _error1: thrd_t thrd_current(void) { - return (thrd_t)getThreadVars()->thread_ptr; + Thread* t = getThreadVars()->thread_ptr; + return t ? (thrd_t)t : THRD_MAIN_HANDLE; } /* @@ -248,7 +281,7 @@ int thrd_detach(thrd_t thr) int thrd_equal(thrd_t thr1, thrd_t thr2) { - return thr1 && thr2 && thr1->thr.handle == thr2->thr.handle; + return thr1 == thr2; } void thrd_exit(int res) @@ -261,6 +294,8 @@ void thrd_exit(int res) int thrd_join(thrd_t thr, int *res) { Result rc; + if (!thr || thr == THRD_MAIN_HANDLE) + return thrd_error; rc = threadWaitForExit(&thr->thr); if (R_FAILED(rc))