diff --git a/nx/include/switch/kernel/thread.h b/nx/include/switch/kernel/thread.h index e3e4ea15..6c533fd7 100644 --- a/nx/include/switch/kernel/thread.h +++ b/nx/include/switch/kernel/thread.h @@ -10,11 +10,14 @@ #include "wait.h" /// Thread information structure. -typedef struct { +typedef struct Thread { Handle handle; ///< Thread handle. void* stack_mem; ///< Pointer to stack memory. void* stack_mirror; ///< Pointer to stack memory mirror. size_t stack_sz; ///< Stack size. + void** tls_array; + struct Thread* next; + struct Thread** prev_next; } Thread; /// Creates a \ref Waiter for a \ref Thread. @@ -44,6 +47,11 @@ Result threadCreate( */ Result threadStart(Thread* t); +/** + * @brief Exits the current thread immediately. + */ +void NORETURN threadExit(void); + /** * @brief Waits for a thread to finish executing. * @param t Thread information structure. @@ -86,3 +94,30 @@ Result threadDumpContext(ThreadContext* ctx, Thread* t); * @return The current thread's handle. */ Handle threadGetCurHandle(void); + +/** + * @brief Allocates a TLS slot. + * @param destructor Function to run automatically when a thread exits. + * @return TLS slot ID on success, or a negative value on failure. + */ +s32 threadTlsAlloc(void (* destructor)(void*)); + +/** + * @brief Retrieves the value stored in a TLS slot. + * @param slot_id TLS slot ID. + * @return Value. + */ +void* threadTlsGet(s32 slot_id); + +/** + * @brief Stores the specified value into a TLS slot. + * @param slot_id TLS slot ID. + * @param value Value. + */ +void threadTlsSet(s32 slot_id, void* value); + +/** + * @brief Frees a TLS slot. + * @param slot_id TLS slot ID. + */ +void threadTlsFree(s32 slot_id); diff --git a/nx/source/internal.h b/nx/source/internal.h index f86ec6db..422fe7de 100644 --- a/nx/source/internal.h +++ b/nx/source/internal.h @@ -5,7 +5,7 @@ #define THREADVARS_MAGIC 0x21545624 // !TV$ -// This structure is exactly 0x20 bytes, if more is needed modify getThreadVars() below +// This structure is exactly 0x20 bytes typedef struct { // Magic value used to check if the struct is initialized u32 magic; @@ -24,5 +24,5 @@ typedef struct { } ThreadVars; static inline ThreadVars* getThreadVars(void) { - return (ThreadVars*)((u8*)armGetTls() + 0x1E0); + return (ThreadVars*)((u8*)armGetTls() + 0x200 - sizeof(ThreadVars)); } diff --git a/nx/source/kernel/thread.c b/nx/source/kernel/thread.c index f12eda8a..c78dec27 100644 --- a/nx/source/kernel/thread.c +++ b/nx/source/kernel/thread.c @@ -5,15 +5,25 @@ #include "result.h" #include "kernel/svc.h" #include "kernel/virtmem.h" +#include "kernel/mutex.h" #include "kernel/thread.h" #include "kernel/wait.h" +#include "services/fatal.h" #include "../internal.h" +#define NUM_TLS_SLOTS ((0x100 - sizeof(ThreadVars)) / sizeof(void*)) + extern const u8 __tdata_lma[]; extern const u8 __tdata_lma_end[]; extern u8 __tls_start[]; extern u8 __tls_end[]; +static Mutex g_threadMutex; +static Thread* g_threadList; + +static u64 g_tlsUsageMask; +static void (* g_tlsDestructors[NUM_TLS_SLOTS])(void*); + // Thread creation args; keep this struct's size 16-byte aligned typedef struct { Thread* t; @@ -33,9 +43,19 @@ static void _EntryWrap(ThreadEntryArgs* args) { tv->tls_tp = (u8*)args->tls-2*sizeof(void*); // subtract size of Thread Control Block (TCB) tv->handle = args->t->handle; + // Initialize thread info + mutexLock(&g_threadMutex); + args->t->tls_array = (void**)((u8*)armGetTls() + 0x100); + args->t->prev_next = &g_threadList; + args->t->next = g_threadList; + if (g_threadList) + g_threadList->prev_next = &args->t->next; + g_threadList = args->t; + mutexUnlock(&g_threadMutex); + // Launch thread entrypoint args->entry(args->arg); - svcExitThread(); + threadExit(); } Result threadCreate( @@ -72,6 +92,9 @@ Result threadCreate( t->stack_mem = stack; t->stack_mirror = stack_mirror; t->stack_sz = stack_sz; + t->tls_array = NULL; + t->next = NULL; + t->prev_next = NULL; args->t = t; args->entry = entry; @@ -109,6 +132,34 @@ Result threadCreate( return rc; } +void threadExit(void) { + Thread* t = getThreadVars()->thread_ptr; + if (!t) + fatalSimple(MAKERESULT(Module_Libnx, LibnxError_NotInitialized)); + + u64 tls_mask = __atomic_load_n(&g_tlsUsageMask, __ATOMIC_SEQ_CST); + for (s32 i = 0; i < NUM_TLS_SLOTS; i ++) { + if (!(tls_mask & ((UINT64_C(1) << i)))) + continue; + if (t->tls_array[i]) { + if (g_tlsDestructors[i]) + g_tlsDestructors[i](t->tls_array[i]); + t->tls_array[i] = NULL; + } + } + + mutexLock(&g_threadMutex); + *t->prev_next = t->next; + if (t->next) + t->next->prev_next = t->prev_next; + t->tls_array = NULL; + t->next = NULL; + t->prev_next = NULL; + mutexUnlock(&g_threadMutex); + + svcExitThread(); +} + Result threadStart(Thread* t) { return svcStartThread(t->handle); } @@ -120,6 +171,9 @@ Result threadWaitForExit(Thread* t) { Result threadClose(Thread* t) { Result rc; + if (t->tls_array) + return MAKERESULT(Module_Libnx, LibnxError_BadInput); + rc = svcUnmapMemory(t->stack_mirror, t->stack_mem, t->stack_sz); virtmemFreeStack(t->stack_mirror, t->stack_sz); free(t->stack_mem); @@ -143,3 +197,43 @@ Result threadDumpContext(ThreadContext* ctx, Thread* t) { Handle threadGetCurHandle(void) { return getThreadVars()->handle; } + +s32 threadTlsAlloc(void (* destructor)(void*)) { + s32 slot_id; + u64 new_mask; + u64 cur_mask = __atomic_load_n(&g_tlsUsageMask, __ATOMIC_SEQ_CST); + do { + slot_id = __builtin_ffs(~cur_mask)-1; + if (slot_id < 0 || slot_id >= NUM_TLS_SLOTS) return -1; + new_mask = cur_mask | (UINT64_C(1) << slot_id); + } while (!__atomic_compare_exchange_n(&g_tlsUsageMask, &cur_mask, new_mask, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)); + + threadTlsSet(slot_id, NULL); + mutexLock(&g_threadMutex); + for (Thread *t = g_threadList; t; t = t->next) + t->tls_array[slot_id] = NULL; + mutexUnlock(&g_threadMutex); + + g_tlsDestructors[slot_id] = destructor; + return slot_id; +} + +void* threadTlsGet(s32 slot_id) { + void** tls_array = (void**)((u8*)armGetTls() + 0x100); + return tls_array[slot_id]; +} + +void threadTlsSet(s32 slot_id, void* value) { + void** tls_array = (void**)((u8*)armGetTls() + 0x100); + tls_array[slot_id] = value; +} + +void threadTlsFree(s32 slot_id) { + g_tlsDestructors[slot_id] = NULL; + + u64 new_mask; + u64 cur_mask = __atomic_load_n(&g_tlsUsageMask, __ATOMIC_SEQ_CST); + do + new_mask = cur_mask &~ (UINT64_C(1) << slot_id); + while (!__atomic_compare_exchange_n(&g_tlsUsageMask, &cur_mask, new_mask, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)); +} diff --git a/nx/source/runtime/c11-threads.c b/nx/source/runtime/c11-threads.c index 43c3667d..28645ff5 100644 --- a/nx/source/runtime/c11-threads.c +++ b/nx/source/runtime/c11-threads.c @@ -291,7 +291,7 @@ void thrd_exit(int res) { thrd_t t = thrd_current(); t->rc = res; - svcExitThread(); + threadExit(); } int thrd_join(thrd_t thr, int *res) @@ -332,20 +332,32 @@ void thrd_yield(void) svcSleepThread(-1); } -/* int tss_create(tss_t *key, tss_dtor_t dtor) { + if (!key) + return thrd_error; + + s32 slot_id = threadTlsAlloc(dtor); + if (slot_id >= 0) { + *key = slot_id; + return thrd_success; + } + + return thrd_error; } void tss_delete(tss_t key) { + threadTlsFree(key); } void * tss_get(tss_t key) { + return threadTlsGet(key); } int tss_set(tss_t key, void *val) { + threadTlsSet(key, val); + return thrd_success; } -*/