From 3109b2c480efe0e542b3dedc5d047efe86e2a2e8 Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Fri, 18 Oct 2019 18:21:31 -0700 Subject: [PATCH] threads: support using existing mem as stack --- nx/include/switch/kernel/thread.h | 16 ++-- nx/source/kernel/thread.c | 144 ++++++++++++++++++------------ nx/source/runtime/newlib.c | 4 +- 3 files changed, 98 insertions(+), 66 deletions(-) diff --git a/nx/include/switch/kernel/thread.h b/nx/include/switch/kernel/thread.h index 6c533fd7..041acade 100644 --- a/nx/include/switch/kernel/thread.h +++ b/nx/include/switch/kernel/thread.h @@ -11,10 +11,11 @@ /// Thread information structure. typedef struct Thread { - Handle handle; ///< Thread handle. - void* stack_mem; ///< Pointer to stack memory. - void* stack_mirror; ///< Pointer to stack memory mirror. - size_t stack_sz; ///< Stack size. + Handle handle; ///< Thread handle. + bool owns_stack_mem; ///< Whether the stack memory is automatically allocated. + void* stack_mem; ///< Pointer to stack memory. + void* stack_mirror; ///< Pointer to stack memory mirror. + size_t stack_sz; ///< Stack size. void** tls_array; struct Thread* next; struct Thread** prev_next; @@ -31,14 +32,15 @@ static inline Waiter waiterForThread(Thread* t) * @param t Thread information structure which will be filled in. * @param entry Entrypoint of the thread. * @param arg Argument to pass to the entrypoint. - * @param stack_sz Stack size (rounded up to page alignment). + * @param stack_mem Memory to use as backing for stack/tls/reent. Must be page-aligned. NULL argument means to allocate new memory. + * @param stack_sz Stack size (rounded up to page alignment if stack_mem is NULL). * @param prio Thread priority (0x00~0x3F); 0x2C is the usual priority of the main thread, 0x3B is a special priority on cores 0..2 that enables preemptive multithreading (0x3F on core 3). * @param cpuid ID of the core on which to create the thread (0~3); or -2 to use the default core for the current process. * @return Result code. */ Result threadCreate( - Thread* t, ThreadFunc entry, void* arg, size_t stack_sz, int prio, - int cpuid); + Thread* t, ThreadFunc entry, void* arg, void *stack_mem, size_t stack_sz, + int prio, int cpuid); /** * @brief Starts the execution of a thread. diff --git a/nx/source/kernel/thread.c b/nx/source/kernel/thread.c index 5ac1ea82..21faf233 100644 --- a/nx/source/kernel/thread.c +++ b/nx/source/kernel/thread.c @@ -61,73 +61,98 @@ static void _EntryWrap(ThreadEntryArgs* args) { } Result threadCreate( - Thread* t, ThreadFunc entry, void* arg, size_t stack_sz, int prio, - int cpuid) + Thread* t, ThreadFunc entry, void* arg, void* stack_mem, size_t stack_sz, + int prio, int cpuid) { - stack_sz = (stack_sz+0xFFF) &~ 0xFFF; - Result rc = 0; - size_t reent_sz = (sizeof(struct _reent)+0xF) &~ 0xF; - size_t tls_sz = (__tls_end-__tls_start+0xF) &~ 0xF; - void* stack = memalign(0x1000, stack_sz + reent_sz + tls_sz); + void* tls; + const size_t tls_sz = (__tls_end-__tls_start+0xF) &~ 0xF; + void* reent; + const size_t reent_sz = (sizeof(struct _reent)+0xF) &~ 0xF; - if (stack == NULL) { - rc = MAKERESULT(Module_Libnx, LibnxError_OutOfMemory); + if (stack_mem == NULL) { + // Allocate new memory, stack then reent then tls. + stack_sz = (stack_sz+0xFFF) & ~0xFFF; + stack_mem = memalign(0x1000, stack_sz + reent_sz + tls_sz); + reent = (void*)((uintptr_t)stack_mem + stack_sz); + tls = (void*)((uintptr_t)reent + reent_sz); + + t->owns_stack_mem = true; + } else { + // Use provided memory for stack, reent, and tls. + if (((uintptr_t)stack_mem & 0xFFF) || (stack_sz & 0xFFF)) { + return MAKERESULT(Module_Libnx, LibnxError_BadInput); + } + + tls = (void*)((uintptr_t)stack_mem + stack_sz - tls_sz); + reent = (void*)((uintptr_t)tls - reent_sz); + + // Ensure we don't go out of bounds. + if ((uintptr_t)reent <= (uintptr_t)stack_mem) { + return MAKERESULT(Module_Libnx, LibnxError_OutOfMemory); + } + + t->owns_stack_mem = false; } - else { - void* stack_mirror = virtmemReserveStack(stack_sz); - rc = svcMapMemory(stack_mirror, stack, stack_sz); + + if (stack_mem == NULL) { + return MAKERESULT(Module_Libnx, LibnxError_OutOfMemory); + } + + void* stack_mirror = virtmemReserveStack(stack_sz); + Result rc = svcMapMemory(stack_mirror, stack_mem, stack_sz); + + if (R_SUCCEEDED(rc)) + { + uintptr_t stack_top = ((uintptr_t)stack_mirror) + stack_sz - sizeof(ThreadEntryArgs); + ThreadEntryArgs* args = (ThreadEntryArgs*) stack_top; + Handle handle; + + rc = svcCreateThread( + &handle, (ThreadFunc) &_EntryWrap, args, (void*)stack_top, + prio, cpuid); if (R_SUCCEEDED(rc)) { - u64 stack_top = ((u64)stack_mirror) + stack_sz - sizeof(ThreadEntryArgs); - ThreadEntryArgs* args = (ThreadEntryArgs*) stack_top; - Handle handle; + t->handle = handle; + t->stack_mem = stack_mem; + t->stack_mirror = stack_mirror; + t->stack_sz = stack_sz - sizeof(ThreadEntryArgs); + t->tls_array = NULL; + t->next = NULL; + t->prev_next = NULL; - rc = svcCreateThread( - &handle, (ThreadFunc) &_EntryWrap, args, (void*)stack_top, - prio, cpuid); + args->t = t; + args->entry = entry; + args->arg = arg; + args->reent = reent; + args->tls = tls; - if (R_SUCCEEDED(rc)) - { - t->handle = handle; - t->stack_mem = stack; - t->stack_mirror = stack_mirror; - t->stack_sz = stack_sz; - t->tls_array = NULL; - t->next = NULL; - t->prev_next = NULL; + // Set up child thread's reent struct, inheriting standard file handles + _REENT_INIT_PTR(args->reent); + struct _reent* cur = getThreadVars()->reent; + args->reent->_stdin = cur->_stdin; + args->reent->_stdout = cur->_stdout; + args->reent->_stderr = cur->_stderr; - args->t = t; - args->entry = entry; - args->arg = arg; - args->reent = (struct _reent*)((u8*)stack + stack_sz); - args->tls = (u8*)stack + stack_sz + reent_sz; - - // Set up child thread's reent struct, inheriting standard file handles - _REENT_INIT_PTR(args->reent); - struct _reent* cur = getThreadVars()->reent; - args->reent->_stdin = cur->_stdin; - args->reent->_stdout = cur->_stdout; - args->reent->_stderr = cur->_stderr; - - // Set up child thread's TLS segment - size_t tls_load_sz = __tdata_lma_end - __tdata_lma; - size_t tls_bss_sz = tls_sz - tls_load_sz; - if (tls_load_sz) - memcpy(args->tls, __tdata_lma, tls_load_sz); - if (tls_bss_sz) - memset(args->tls+tls_load_sz, 0, tls_bss_sz); - } - - if (R_FAILED(rc)) { - svcUnmapMemory(stack_mirror, stack, stack_sz); - } + // Set up child thread's TLS segment + size_t tls_load_sz = __tdata_lma_end - __tdata_lma; + size_t tls_bss_sz = tls_sz - tls_load_sz; + if (tls_load_sz) + memcpy(args->tls, __tdata_lma, tls_load_sz); + if (tls_bss_sz) + memset(args->tls+tls_load_sz, 0, tls_bss_sz); } if (R_FAILED(rc)) { - virtmemFreeStack(stack_mirror, stack_sz); - free(stack); + svcUnmapMemory(stack_mirror, stack_mem, stack_sz); + } + } + + if (R_FAILED(rc)) { + virtmemFreeStack(stack_mirror, stack_sz); + if (t->owns_stack_mem) { + free(stack_mem); } } @@ -178,9 +203,14 @@ Result threadClose(Thread* t) { return MAKERESULT(Module_Libnx, LibnxError_BadInput); rc = svcUnmapMemory(t->stack_mirror, t->stack_mem, t->stack_sz); - virtmemFreeStack(t->stack_mirror, t->stack_sz); - free(t->stack_mem); - svcCloseHandle(t->handle); + + if (R_SUCCEEDED(rc)) { + virtmemFreeStack(t->stack_mirror, t->stack_sz); + if (t->owns_stack_mem) { + free(t->stack_mem); + } + svcCloseHandle(t->handle); + } return rc; } diff --git a/nx/source/runtime/newlib.c b/nx/source/runtime/newlib.c index 7a693845..d3dea4e0 100644 --- a/nx/source/runtime/newlib.c +++ b/nx/source/runtime/newlib.c @@ -170,7 +170,7 @@ static void __thread_entry(void* __arg) int __syscall_thread_create(struct __pthread_t **thread, void* (*func)(void*), void *arg, void *stack_addr, size_t stack_size) { - if (stack_addr || (stack_size & 0xFFF)) + if (((uintptr_t)stack_addr & 0xFFF) || (stack_size & 0xFFF)) return EINVAL; if (!stack_size) stack_size = 128*1024; @@ -194,7 +194,7 @@ int __syscall_thread_create(struct __pthread_t **thread, void* (*func)(void*), v mutexInit(&info.mutex); condvarInit(&info.cond); - rc = threadCreate(&t->thr, __thread_entry, &info, stack_size, 0x3B, -2); + rc = threadCreate(&t->thr, __thread_entry, &info, stack_addr, stack_size, 0x3B, -2); if (R_FAILED(rc)) goto _error1;