fs.mitm: WIP LayeredFS impl (NOTE: UNUSABLE ATM)

Also greatly refactors libstratosphere, and does a lot of other things.
There is a lot of code in this one.
This commit is contained in:
Michael Scire 2018-06-14 17:50:01 -06:00
parent cd5da88405
commit 26e676424d
11 changed files with 206 additions and 50 deletions

View File

@ -16,3 +16,4 @@
#include "stratosphere/hossynch.hpp" #include "stratosphere/hossynch.hpp"
#include "stratosphere/waitablemanager.hpp" #include "stratosphere/waitablemanager.hpp"
#include "stratosphere/multithreadedwaitablemanager.hpp"

View File

@ -1,37 +1,36 @@
#pragma once #pragma once
#include <switch.h> #include <switch.h>
#include <memory>
#include <type_traits> #include <type_traits>
#include "iserviceobject.hpp" #include "iserviceobject.hpp"
#define DOMAIN_ID_MAX 0x200 #define DOMAIN_ID_MAX 0x1000
class IServiceObject; class IServiceObject;
class DomainOwner { class DomainOwner {
private: private:
IServiceObject *domain_objects[DOMAIN_ID_MAX]; std::shared_ptr<IServiceObject> domain_objects[DOMAIN_ID_MAX];
public: public:
DomainOwner() { DomainOwner() {
for (unsigned int i = 0; i < DOMAIN_ID_MAX; i++) { for (unsigned int i = 0; i < DOMAIN_ID_MAX; i++) {
domain_objects[i] = NULL; domain_objects[i].reset();
} }
} }
virtual ~DomainOwner() { virtual ~DomainOwner() {
for (unsigned int i = 0; i < DOMAIN_ID_MAX; i++) { /* Shared ptrs should auto delete here. */
this->delete_object(i);
}
} }
IServiceObject *get_domain_object(unsigned int i) { std::shared_ptr<IServiceObject> get_domain_object(unsigned int i) {
if (i < DOMAIN_ID_MAX) { if (i < DOMAIN_ID_MAX) {
return domain_objects[i]; return domain_objects[i];
} }
return NULL; return nullptr;
} }
Result reserve_object(IServiceObject *object, unsigned int *out_i) { Result reserve_object(std::shared_ptr<IServiceObject> object, unsigned int *out_i) {
for (unsigned int i = 4; i < DOMAIN_ID_MAX; i++) { for (unsigned int i = 4; i < DOMAIN_ID_MAX; i++) {
if (domain_objects[i] == NULL) { if (domain_objects[i] == NULL) {
domain_objects[i] = object; domain_objects[i] = object;
@ -43,7 +42,7 @@ class DomainOwner {
return 0x1900B; return 0x1900B;
} }
Result set_object(IServiceObject *object, unsigned int i) { Result set_object(std::shared_ptr<IServiceObject> object, unsigned int i) {
if (domain_objects[i] == NULL) { if (domain_objects[i] == NULL) {
domain_objects[i] = object; domain_objects[i] = object;
object->set_owner(this); object->set_owner(this);
@ -52,7 +51,7 @@ class DomainOwner {
return 0x1900B; return 0x1900B;
} }
unsigned int get_object_id(IServiceObject *object) { unsigned int get_object_id(std::shared_ptr<IServiceObject> object) {
for (unsigned int i = 0; i < DOMAIN_ID_MAX; i++) { for (unsigned int i = 0; i < DOMAIN_ID_MAX; i++) {
if (domain_objects[i] == object) { if (domain_objects[i] == object) {
return i; return i;
@ -63,16 +62,14 @@ class DomainOwner {
void delete_object(unsigned int i) { void delete_object(unsigned int i) {
if (domain_objects[i]) { if (domain_objects[i]) {
delete domain_objects[i]; domain_objects[i].reset();
domain_objects[i] = NULL;
} }
} }
void delete_object(IServiceObject *object) { void delete_object(std::shared_ptr<IServiceObject> object) {
for (unsigned int i = 0; i < DOMAIN_ID_MAX; i++) { for (unsigned int i = 0; i < DOMAIN_ID_MAX; i++) {
if (domain_objects[i] == object) { if (domain_objects[i] == object) {
delete domain_objects[i]; domain_objects[i].reset();
domain_objects[i] = NULL;
break; break;
} }
} }

View File

@ -4,18 +4,20 @@
#include "iwaitable.hpp" #include "iwaitable.hpp"
typedef Result (*EventCallback)(Handle *handles, size_t num_handles, u64 timeout); typedef Result (*EventCallback)(void *arg, Handle *handles, size_t num_handles, u64 timeout);
class IEvent : public IWaitable { class IEvent : public IWaitable {
protected: protected:
std::vector<Handle> handles; std::vector<Handle> handles;
EventCallback callback; EventCallback callback;
void *arg;
public: public:
IEvent(Handle wait_h, EventCallback callback) { IEvent(Handle wait_h, void *a, EventCallback callback) {
if (wait_h) { if (wait_h) {
this->handles.push_back(wait_h); this->handles.push_back(wait_h);
} }
this->arg = a;
this->callback = callback; this->callback = callback;
} }
@ -41,7 +43,7 @@ class IEvent : public IWaitable {
} }
virtual Result handle_signaled(u64 timeout) { virtual Result handle_signaled(u64 timeout) {
return this->callback(this->handles.data(), this->handles.size(), timeout); return this->callback(this->arg, this->handles.data(), this->handles.size(), timeout);
} }
static Result PanicCallback(Handle *handles, size_t num_handles, u64 timeout) { static Result PanicCallback(Handle *handles, size_t num_handles, u64 timeout) {

View File

@ -17,13 +17,13 @@ class IPCSession final : public ISession<T> {
if (R_FAILED((rc = svcCreateSession(&this->server_handle, &this->client_handle, 0, 0)))) { if (R_FAILED((rc = svcCreateSession(&this->server_handle, &this->client_handle, 0, 0)))) {
fatalSimple(rc); fatalSimple(rc);
} }
this->service_object = new T(); this->service_object = std::make_shared<T>();
this->pointer_buffer_size = pbs; this->pointer_buffer_size = pbs;
this->pointer_buffer = new char[this->pointer_buffer_size]; this->pointer_buffer = new char[this->pointer_buffer_size];
this->is_domain = false; this->is_domain = false;
} }
IPCSession<T>(T *so, size_t pbs = 0x400) : ISession<T>(NULL, 0, 0, so, 0) { IPCSession<T>(std::shared_ptr<T> so, size_t pbs = 0x400) : ISession<T>(NULL, 0, 0, so, 0) {
Result rc; Result rc;
if (R_FAILED((rc = svcCreateSession(&this->server_handle, &this->client_handle, 0, 0)))) { if (R_FAILED((rc = svcCreateSession(&this->server_handle, &this->client_handle, 0, 0)))) {
fatalSimple(rc); fatalSimple(rc);

View File

@ -18,7 +18,6 @@ class IServiceObject {
DomainOwner *get_owner() { return this->owner; } DomainOwner *get_owner() { return this->owner; }
void set_owner(DomainOwner *owner) { this->owner = owner; } void set_owner(DomainOwner *owner) { this->owner = owner; }
virtual Result dispatch(IpcParsedCommand &r, IpcCommand &out_c, u64 cmd_id, u8 *pointer_buffer, size_t pointer_buffer_size) = 0; virtual Result dispatch(IpcParsedCommand &r, IpcCommand &out_c, u64 cmd_id, u8 *pointer_buffer, size_t pointer_buffer_size) = 0;
protected:
virtual Result handle_deferred() = 0; virtual Result handle_deferred() = 0;
}; };

View File

@ -27,10 +27,10 @@ class IServer;
class IServiceObject; class IServiceObject;
template <typename T> template <typename T>
class ISession : public IWaitable, public DomainOwner { class ISession : public IWaitable {
static_assert(std::is_base_of<IServiceObject, T>::value, "Service Objects must derive from IServiceObject"); static_assert(std::is_base_of<IServiceObject, T>::value, "Service Objects must derive from IServiceObject");
protected: protected:
T *service_object; std::shared_ptr<T> service_object;
IServer<T> *server; IServer<T> *server;
Handle server_handle; Handle server_handle;
Handle client_handle; Handle client_handle;
@ -38,35 +38,35 @@ class ISession : public IWaitable, public DomainOwner {
size_t pointer_buffer_size; size_t pointer_buffer_size;
bool is_domain; bool is_domain;
std::shared_ptr<DomainOwner> domain;
IServiceObject *active_object; std::shared_ptr<IServiceObject> active_object;
static_assert(sizeof(pointer_buffer) <= POINTER_BUFFER_SIZE_MAX, "Incorrect Size for PointerBuffer!"); static_assert(sizeof(pointer_buffer) <= POINTER_BUFFER_SIZE_MAX, "Incorrect Size for PointerBuffer!");
public: public:
ISession<T>(IServer<T> *s, Handle s_h, Handle c_h, size_t pbs = 0x400) : server(s), server_handle(s_h), client_handle(c_h), pointer_buffer_size(pbs) { ISession<T>(IServer<T> *s, Handle s_h, Handle c_h, size_t pbs = 0x400) : server(s), server_handle(s_h), client_handle(c_h), pointer_buffer_size(pbs) {
this->service_object = new T(); this->service_object = std::make_shared<T>();
if (this->pointer_buffer_size) { if (this->pointer_buffer_size) {
this->pointer_buffer = new char[this->pointer_buffer_size]; this->pointer_buffer = new char[this->pointer_buffer_size];
} }
this->is_domain = false; this->is_domain = false;
this->active_object = NULL; this->domain.reset();
this->active_object.reset();
} }
ISession<T>(IServer<T> *s, Handle s_h, Handle c_h, T *so, size_t pbs = 0x400) : service_object(so), server(s), server_handle(s_h), client_handle(c_h), pointer_buffer_size(pbs) { ISession<T>(IServer<T> *s, Handle s_h, Handle c_h, std::shared_ptr<T> so, size_t pbs = 0x400) : service_object(so), server(s), server_handle(s_h), client_handle(c_h), pointer_buffer_size(pbs) {
if (this->pointer_buffer_size) { if (this->pointer_buffer_size) {
this->pointer_buffer = new char[this->pointer_buffer_size]; this->pointer_buffer = new char[this->pointer_buffer_size];
} }
this->is_domain = false; this->is_domain = false;
this->active_object = NULL; this->domain.reset();
this->active_object.reset();
} }
~ISession() override { ~ISession() override {
delete this->pointer_buffer; delete this->pointer_buffer;
if (this->service_object && !this->is_domain) {
//delete this->service_object;
}
if (server_handle) { if (server_handle) {
svcCloseHandle(server_handle); svcCloseHandle(server_handle);
} }
@ -86,12 +86,12 @@ class ISession : public IWaitable, public DomainOwner {
} }
} }
T *get_service_object() { return this->service_object; } std::shared_ptr<T> get_service_object() { return this->service_object; }
Handle get_server_handle() { return this->server_handle; } Handle get_server_handle() { return this->server_handle; }
Handle get_client_handle() { return this->client_handle; } Handle get_client_handle() { return this->client_handle; }
DomainOwner *get_owner() { return is_domain ? this : NULL; } DomainOwner *get_owner() { return this->is_domain ? this->domain.get() : NULL; }
/* IWaitable */ /* IWaitable */
Handle get_handle() override { Handle get_handle() override {
@ -125,7 +125,7 @@ class ISession : public IWaitable, public DomainOwner {
if (r.IsDomainMessage && r.MessageType == DomainMessageType_Close) { if (r.IsDomainMessage && r.MessageType == DomainMessageType_Close) {
this->delete_object(this->active_object); this->domain->delete_object(this->active_object);
this->active_object = NULL; this->active_object = NULL;
struct { struct {
u64 magic; u64 magic;
@ -188,7 +188,7 @@ class ISession : public IWaitable, public DomainOwner {
ipcAddRecvStatic(&c_for_reply, this->pointer_buffer, this->pointer_buffer_size, 0); ipcAddRecvStatic(&c_for_reply, this->pointer_buffer, this->pointer_buffer_size, 0);
ipcPrepareHeader(&c_for_reply, 0); ipcPrepareHeader(&c_for_reply, 0);
if (R_SUCCEEDED(rc = svcReplyAndReceive(&handle_index, &this->server_handle, 1, 0, timeout))) { if (R_SUCCEEDED(rc = svcReplyAndReceive(&handle_index, &this->server_handle, 1, 0, U64_MAX))) {
if (handle_index != 0) { if (handle_index != 0) {
/* TODO: Panic? */ /* TODO: Panic? */
} }
@ -203,7 +203,7 @@ class ISession : public IWaitable, public DomainOwner {
if (!r.IsDomainMessage || r.ThisObjectId >= DOMAIN_ID_MAX) { if (!r.IsDomainMessage || r.ThisObjectId >= DOMAIN_ID_MAX) {
retval = 0xF601; retval = 0xF601;
} else { } else {
this->active_object = this->get_domain_object(r.ThisObjectId); this->active_object = this->domain->get_domain_object(r.ThisObjectId);
} }
} else { } else {
this->active_object = this->service_object; this->active_object = this->service_object;
@ -218,19 +218,22 @@ class ISession : public IWaitable, public DomainOwner {
if (retval == RESULT_DEFER_SESSION) { if (retval == RESULT_DEFER_SESSION) {
/* Session defer. */ /* Session defer. */
this->active_object = NULL; this->active_object.reset();
this->set_deferred(true); this->set_deferred(true);
rc = retval; rc = retval;
} else if (retval == 0xF601) { } else if (retval == 0xF601) {
/* Session close. */ /* Session close. */
this->active_object = NULL; this->active_object.reset();
rc = retval; rc = retval;
} else { } else {
if (R_SUCCEEDED(retval)) { if (R_SUCCEEDED(retval)) {
this->postprocess(r, cmd_id); this->postprocess(r, cmd_id);
} }
this->active_object = NULL; this->active_object.reset();
rc = svcReplyAndReceive(&handle_index, &this->server_handle, 0, this->server_handle, 0); rc = svcReplyAndReceive(&handle_index, &this->server_handle, 0, this->server_handle, 0);
if (rc == 0xEA01) {
rc = 0x0;
}
this->cleanup(); this->cleanup();
} }
} }

View File

@ -0,0 +1,42 @@
#pragma once
#include <switch.h>
#include <vector>
#include "waitablemanager.hpp"
#include "systemevent.hpp"
class MultiThreadedWaitableManager : public WaitableManager {
protected:
u32 num_threads;
Thread *threads;
HosMutex get_waitable_lock;
SystemEvent *new_waitable_event;
public:
MultiThreadedWaitableManager(u32 n, u64 t, u32 ss = 0x8000) : WaitableManager(t), num_threads(n-1) {
u32 prio;
u32 cpuid = svcGetCurrentProcessorNumber();
Result rc;
threads = new Thread[num_threads];
if (R_FAILED((rc = svcGetThreadPriority(&prio, CUR_THREAD_HANDLE)))) {
fatalSimple(rc);
}
for (unsigned int i = 0; i < num_threads; i++) {
threads[i] = {0};
threadCreate(&threads[i], &MultiThreadedWaitableManager::thread_func, this, ss, prio, cpuid);
}
new_waitable_event = new SystemEvent(this, &MultiThreadedWaitableManager::add_waitable_callback);
this->waitables.push_back(new_waitable_event);
}
~MultiThreadedWaitableManager() override {
/* TODO: Exit the threads? */
}
IWaitable *get_waitable();
void add_waitable(IWaitable *waitable) override;
void process() override;
void process_until_timeout() override;
static Result add_waitable_callback(void *this_ptr, Handle *handles, size_t num_handles, u64 timeout);
static void thread_func(void *this_ptr);
};

View File

@ -9,7 +9,7 @@
class SystemEvent final : public IEvent { class SystemEvent final : public IEvent {
public: public:
SystemEvent(EventCallback callback) : IEvent(0, callback) { SystemEvent(void *a, EventCallback callback) : IEvent(0, a, callback) {
Handle wait_h; Handle wait_h;
Handle sig_h; Handle sig_h;
if (R_FAILED(svcCreateEvent(&sig_h, &wait_h))) { if (R_FAILED(svcCreateEvent(&sig_h, &wait_h))) {

View File

@ -9,16 +9,17 @@
class IWaitable; class IWaitable;
class WaitableManager : public WaitableManagerBase { class WaitableManager : public WaitableManagerBase {
std::vector<IWaitable *> to_add_waitables; protected:
std::vector<IWaitable *> waitables; std::vector<IWaitable *> to_add_waitables;
u64 timeout; std::vector<IWaitable *> waitables;
HosMutex lock; u64 timeout;
std::atomic_bool has_new_items; HosMutex lock;
std::atomic_bool has_new_items;
private: private:
void process_internal(bool break_on_timeout); void process_internal(bool break_on_timeout);
public: public:
WaitableManager(u64 t) : waitables(0), timeout(t), has_new_items(false) { } WaitableManager(u64 t) : waitables(0), timeout(t), has_new_items(false) { }
~WaitableManager() { ~WaitableManager() override {
/* This should call the destructor for every waitable. */ /* This should call the destructor for every waitable. */
for (auto & waitable : waitables) { for (auto & waitable : waitables) {
delete waitable; delete waitable;
@ -26,7 +27,7 @@ class WaitableManager : public WaitableManagerBase {
waitables.clear(); waitables.clear();
} }
void add_waitable(IWaitable *waitable); virtual void add_waitable(IWaitable *waitable);
void process(); virtual void process();
void process_until_timeout(); virtual void process_until_timeout();
}; };

View File

@ -7,6 +7,7 @@ class WaitableManagerBase {
std::atomic<u64> cur_priority; std::atomic<u64> cur_priority;
public: public:
WaitableManagerBase() : cur_priority(0) { } WaitableManagerBase() : cur_priority(0) { }
virtual ~WaitableManagerBase() { }
u64 get_priority() { u64 get_priority() {
return std::atomic_fetch_add(&cur_priority, (u64)1); return std::atomic_fetch_add(&cur_priority, (u64)1);

View File

@ -0,0 +1,110 @@
#include <switch.h>
#include <algorithm>
#include <stratosphere/multithreadedwaitablemanager.hpp>
void MultiThreadedWaitableManager::process() {
Result rc;
for (unsigned int i = 0; i < num_threads; i++) {
if (R_FAILED((rc = threadStart(&threads[i])))) {
fatalSimple(rc);
}
}
MultiThreadedWaitableManager::thread_func(this);
}
void MultiThreadedWaitableManager::process_until_timeout() {
/* TODO: Panic. */
}
void MultiThreadedWaitableManager::add_waitable(IWaitable *waitable) {
this->lock.Lock();
this->to_add_waitables.push_back(waitable);
waitable->set_manager(this);
this->new_waitable_event->signal_event();
this->lock.Unlock();
}
IWaitable *MultiThreadedWaitableManager::get_waitable() {
std::vector<Handle> handles;
int handle_index = 0;
Result rc;
this->get_waitable_lock.Lock();
while (1) {
/* Sort waitables by priority. */
std::sort(this->waitables.begin(), this->waitables.end(), IWaitable::compare);
/* Copy out handles. */
handles.resize(this->waitables.size());
std::transform(this->waitables.begin(), this->waitables.end(), handles.begin(), [](IWaitable *w) { return w->get_handle(); });
rc = svcWaitSynchronization(&handle_index, handles.data(), this->waitables.size(), this->timeout);
IWaitable *w = this->waitables[handle_index];
if (R_SUCCEEDED(rc)) {
for (int i = 0; i < handle_index; i++) {
this->waitables[i]->update_priority();
}
this->waitables.erase(this->waitables.begin() + handle_index);
} else if (rc == 0xEA01) {
/* Timeout. */
for (auto & waitable : this->waitables) {
waitable->update_priority();
}
} else if (rc != 0xF601) {
/* TODO: Panic. When can this happen? */
} else {
for (int i = 0; i < handle_index; i++) {
this->waitables[i]->update_priority();
}
this->waitables.erase(this->waitables.begin() + handle_index);
delete w;
}
/* Do deferred callback for each waitable. */
for (auto & waitable : this->waitables) {
if (waitable->get_deferred()) {
waitable->handle_deferred();
}
}
/* Return waitable. */
if (R_SUCCEEDED(rc)) {
if (w == this->new_waitable_event) {
w->handle_signaled(0);
this->waitables.push_back(w);
} else {
this->get_waitable_lock.Unlock();
return w;
}
}
}
}
Result MultiThreadedWaitableManager::add_waitable_callback(void *arg, Handle *handles, size_t num_handles, u64 timeout) {
MultiThreadedWaitableManager *this_ptr = (MultiThreadedWaitableManager *)arg;
svcClearEvent(handles[0]);
this_ptr->lock.Lock();
this_ptr->waitables.insert(this_ptr->waitables.end(), this_ptr->to_add_waitables.begin(), this_ptr->to_add_waitables.end());
this_ptr->to_add_waitables.clear();
this_ptr->lock.Unlock();
return 0;
}
void MultiThreadedWaitableManager::thread_func(void *t) {
MultiThreadedWaitableManager *this_ptr = (MultiThreadedWaitableManager *)t;
while (1) {
IWaitable *w = this_ptr->get_waitable();
if (w) {
Result rc = w->handle_signaled(0);
if (rc == 0xF601) {
/* Close! */
delete w;
} else {
this_ptr->add_waitable(w);
}
}
}
}