/* * Copyright (c) 2018-2019 Atmosphère-NX * * This program is free software; you can redistribute it and/or modify it * under the terms and conditions of the GNU General Public License, * version 2, as published by the Free Software Foundation. * * This program is distributed in the hope it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for * more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #pragma once #include #include #include #include "results.hpp" #include "waitable_manager_base.hpp" #include "event.hpp" #include "ipc.hpp" #include "servers.hpp" #include "scope_guard.hpp" static inline Handle GetCurrentThreadHandle() { return threadGetCurHandle(); } struct DefaultManagerOptions { static constexpr size_t PointerBufferSize = 0; static constexpr size_t MaxDomains = 0; static constexpr size_t MaxDomainObjects = 0; }; struct DomainEntry { ServiceObjectHolder obj_holder; IDomainObject *owner = nullptr; }; template class WaitableManager : public SessionManagerBase { private: /* Domain Manager */ HosMutex domain_lock; std::array domain_keys; std::array is_domain_allocated; std::array domain_objects; /* Waitable Manager */ std::vector to_add_waitables; std::vector waitables; std::vector deferred_waitables; u32 num_extra_threads = 0; HosThread *threads = nullptr; HosMutex process_lock; HosMutex signal_lock; HosMutex add_lock; HosMutex cur_thread_lock; HosMutex deferred_lock; bool has_new_waitables = false; std::atomic should_stop = false; IWaitable *next_signaled = nullptr; Handle main_thread_handle = INVALID_HANDLE; Handle cur_thread_handle = INVALID_HANDLE; public: WaitableManager(u32 n, u32 ss = 0x8000) : num_extra_threads(n-1) { u32 prio; if (num_extra_threads) { threads = new HosThread[num_extra_threads]; R_ASSERT(svcGetThreadPriority(&prio, CUR_THREAD_HANDLE)); for (unsigned int i = 0; i < num_extra_threads; i++) { R_ASSERT(threads[i].Initialize(&WaitableManager::ProcessLoop, this, ss, prio)); } } } ~WaitableManager() override { /* This should call the destructor for every waitable. */ std::for_each(to_add_waitables.begin(), to_add_waitables.end(), std::default_delete{}); std::for_each(waitables.begin(), waitables.end(), std::default_delete{}); std::for_each(deferred_waitables.begin(), deferred_waitables.end(), std::default_delete{}); /* If we've reached here, we should already have exited the threads. */ } virtual void AddWaitable(IWaitable *w) override { std::scoped_lock lk{this->add_lock}; this->to_add_waitables.push_back(w); w->SetManager(this); this->has_new_waitables = true; this->CancelSynchronization(); } virtual void RequestStop() { this->should_stop = true; this->CancelSynchronization(); } virtual void CancelSynchronization() { svcCancelSynchronization(GetProcessingThreadHandle()); } virtual void NotifySignaled(IWaitable *w) override { std::scoped_lock lk{this->signal_lock}; if (this->next_signaled == nullptr) { this->next_signaled = w; } this->CancelSynchronization(); } virtual void Process() override { /* Add initial set of waitables. */ AddWaitablesInternal(); /* Set main thread handle. */ this->main_thread_handle = GetCurrentThreadHandle(); for (unsigned int i = 0; i < num_extra_threads; i++) { R_ASSERT(threads[i].Start()); } ProcessLoop(this); } private: void SetProcessingThreadHandle(Handle h) { std::scoped_lock lk{this->cur_thread_lock}; this->cur_thread_handle = h; } Handle GetProcessingThreadHandle() { std::scoped_lock lk{this->cur_thread_lock}; return this->cur_thread_handle; } static void ProcessLoop(void *t) { WaitableManager *this_ptr = (WaitableManager *)t; while (true) { IWaitable *w = this_ptr->GetWaitable(); if (this_ptr->should_stop) { if (GetCurrentThreadHandle() == this_ptr->main_thread_handle) { /* Join all threads but the main one. */ for (unsigned int i = 0; i < this_ptr->num_extra_threads; i++) { this_ptr->threads[i].Join(); } break; } else { /* Return, this will cause thread to exit. */ return; } } if (w) { if (w->HandleSignaled(0) == ResultKernelConnectionClosed) { /* Close! */ delete w; } else { if (w->IsDeferred()) { std::scoped_lock lk{this_ptr->deferred_lock}; this_ptr->deferred_waitables.push_back(w); } else { this_ptr->AddWaitable(w); } } } /* We finished processing, and maybe that means we can stop deferring an object. */ { std::scoped_lock lk{this_ptr->deferred_lock}; bool undeferred_any = true; while (undeferred_any) { undeferred_any = false; for (auto it = this_ptr->deferred_waitables.begin(); it != this_ptr->deferred_waitables.end();) { auto w = *it; const bool closed = (w->HandleDeferred() == ResultKernelConnectionClosed); if (closed || !w->IsDeferred()) { /* Remove from the deferred list, set iterator. */ it = this_ptr->deferred_waitables.erase(it); if (closed) { /* Delete the closed waitable. */ delete w; } else { /* Add to the waitables list. */ this_ptr->AddWaitable(w); undeferred_any = true; } } else { /* Move on to the next deferred waitable. */ it++; } } } } } } IWaitable *GetWaitable() { std::scoped_lock lk{this->process_lock}; /* Set processing thread handle while in scope. */ SetProcessingThreadHandle(GetCurrentThreadHandle()); ON_SCOPE_EXIT { SetProcessingThreadHandle(INVALID_HANDLE); }; /* Prepare variables for result. */ this->next_signaled = nullptr; IWaitable *result = nullptr; if (this->should_stop) { return nullptr; } /* Add new waitables, if any. */ AddWaitablesInternal(); /* First, see if anything's already signaled. */ for (auto &w : this->waitables) { if (w->IsSignaled()) { result = w; } } /* It's possible somebody signaled us while we were iterating. */ { std::scoped_lock lk{this->signal_lock}; if (this->next_signaled != nullptr) result = this->next_signaled; } if (result == nullptr) { std::vector handles; std::vector wait_list; int handle_index = 0; while (result == nullptr) { /* Sort waitables by priority. */ std::sort(this->waitables.begin(), this->waitables.end(), IWaitable::Compare); /* Copy out handles. */ handles.resize(this->waitables.size()); wait_list.resize(this->waitables.size()); unsigned int num_handles = 0; /* Try to add waitables to wait list. */ for (unsigned int i = 0; i < this->waitables.size(); i++) { Handle h = this->waitables[i]->GetHandle(); if (h != INVALID_HANDLE) { wait_list[num_handles] = this->waitables[i]; handles[num_handles++] = h; } } /* Wait forever. */ const Result wait_res = svcWaitSynchronization(&handle_index, handles.data(), num_handles, U64_MAX); if (this->should_stop) { return nullptr; } if (R_SUCCEEDED(wait_res)) { IWaitable *w = wait_list[handle_index]; size_t w_ind = std::distance(this->waitables.begin(), std::find(this->waitables.begin(), this->waitables.end(), w)); std::for_each(waitables.begin(), waitables.begin() + w_ind + 1, std::mem_fn(&IWaitable::UpdatePriority)); result = w; } else if (wait_res == ResultKernelTimedOut) { /* Timeout: Just update priorities. */ std::for_each(waitables.begin(), waitables.end(), std::mem_fn(&IWaitable::UpdatePriority)); } else if (wait_res == ResultKernelCancelled) { /* svcCancelSynchronization was called. */ AddWaitablesInternal(); { std::scoped_lock lk{this->signal_lock}; if (this->next_signaled != nullptr) { result = this->next_signaled; } } } else { /* TODO: Consider the following cases that this covers: */ /* 7601: Thread termination requested. */ /* E401: Handle is dead. */ /* E601: Handle list address invalid. */ /* EE01: Too many handles. */ std::abort(); } } } this->waitables.erase(std::remove_if(this->waitables.begin(), this->waitables.end(), [&](IWaitable *w) { return w == result; }), this->waitables.end()); return result; } void AddWaitablesInternal() { std::scoped_lock lk{this->add_lock}; if (this->has_new_waitables) { this->waitables.insert(this->waitables.end(), this->to_add_waitables.begin(), this->to_add_waitables.end()); this->to_add_waitables.clear(); this->has_new_waitables = false; } } /* Session Manager */ public: virtual void AddSession(Handle server_h, ServiceObjectHolder &&service) override { this->AddWaitable(new ServiceSession(server_h, ManagerOptions::PointerBufferSize, std::move(service))); } /* Domain Manager */ public: virtual std::shared_ptr AllocateDomain() override { std::scoped_lock lk{this->domain_lock}; for (size_t i = 0; i < ManagerOptions::MaxDomains; i++) { if (!this->is_domain_allocated[i]) { auto new_domain = std::make_shared(this); this->domain_keys[i] = reinterpret_cast(new_domain.get()); this->is_domain_allocated[i] = true; return new_domain; } } return nullptr; } void FreeDomain(IDomainObject *domain) override { std::scoped_lock lk{this->domain_lock}; for (size_t i = 0; i < ManagerOptions::MaxDomainObjects; i++) { FreeObject(domain, i+1); } for (size_t i = 0; i < ManagerOptions::MaxDomains; i++) { if (this->domain_keys[i] == reinterpret_cast(domain)) { this->is_domain_allocated[i] = false; break; } } } virtual Result ReserveObject(IDomainObject *domain, u32 *out_object_id) override { std::scoped_lock lk{this->domain_lock}; for (size_t i = 0; i < ManagerOptions::MaxDomainObjects; i++) { if (this->domain_objects[i].owner == nullptr) { this->domain_objects[i].owner = domain; *out_object_id = i+1; return ResultSuccess; } } return ResultServiceFrameworkOutOfDomainEntries; } virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return ResultServiceFrameworkOutOfDomainEntries; } if (this->domain_objects[object_id-1].owner == nullptr) { this->domain_objects[object_id-1].owner = domain; return ResultSuccess; } return ResultServiceFrameworkOutOfDomainEntries; } virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override { std::scoped_lock lk{this->domain_lock}; if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return; } if (this->domain_objects[object_id-1].owner == domain) { this->domain_objects[object_id-1].obj_holder = std::move(holder); } } virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return nullptr; } if (this->domain_objects[object_id-1].owner == domain) { return &this->domain_objects[object_id-1].obj_holder; } return nullptr; } virtual Result FreeObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return ResultHipcDomainObjectNotFound; } if (this->domain_objects[object_id-1].owner == domain) { this->domain_objects[object_id-1].obj_holder.Reset(); this->domain_objects[object_id-1].owner = nullptr; return ResultSuccess; } return ResultHipcDomainObjectNotFound; } virtual Result ForceFreeObject(u32 object_id) override { std::scoped_lock lk{this->domain_lock}; if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return ResultHipcDomainObjectNotFound; } if (this->domain_objects[object_id-1].owner != nullptr) { this->domain_objects[object_id-1].obj_holder.Reset(); this->domain_objects[object_id-1].owner = nullptr; return ResultSuccess; } return ResultHipcDomainObjectNotFound; } };