From 79bc9bf8d87dddcfc2d080626eb8c817c7339fd0 Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Mon, 22 Apr 2019 12:40:35 -0700 Subject: [PATCH] libstrat: only hold sm sessions open when needed --- include/stratosphere/mitm/mitm_server.hpp | 71 +++++++++++------------ include/stratosphere/servers.hpp | 37 ++++++------ include/stratosphere/utilities.hpp | 36 ++++++++++++ source/utilities.cpp | 30 ++++++++++ 4 files changed, 119 insertions(+), 55 deletions(-) create mode 100644 source/utilities.cpp diff --git a/include/stratosphere/mitm/mitm_server.hpp b/include/stratosphere/mitm/mitm_server.hpp index 7224ed5c..3abf9b18 100644 --- a/include/stratosphere/mitm/mitm_server.hpp +++ b/include/stratosphere/mitm/mitm_server.hpp @@ -13,47 +13,47 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ - + #pragma once #include #include "mitm_query_service.hpp" #include "sm_mitm.h" #include "mitm_session.hpp" +#include "../utilities.hpp" void RegisterMitmServerQueryHandle(Handle query_h, ServiceObjectHolder &&service); template -class MitmServer : public IWaitable { +class MitmServer : public IWaitable { static_assert(std::is_base_of::value, "MitM Service Objects must derive from IMitmServiceObject"); private: Handle port_handle; unsigned int max_sessions; char mitm_name[9]; - + public: MitmServer(const char *service_name, unsigned int max_s) : port_handle(0), max_sessions(max_s) { Handle query_h = 0; - Result rc = smMitMInitialize(); - if (R_FAILED(rc)) { - fatalSimple(rc); - } - - strncpy(mitm_name, service_name, 8); - mitm_name[8] = '\x00'; - if (R_FAILED((rc = smMitMInstall(&this->port_handle, &query_h, mitm_name)))) { - fatalSimple(rc); - } - RegisterMitmServerQueryHandle(query_h, std::move(ServiceObjectHolder(std::move(std::make_shared>())))); - - smMitMExit(); - } - - virtual ~MitmServer() override { - if (this->port_handle) { - if (R_FAILED(smMitMUninstall(this->mitm_name))) { + + DoWithSmMitmSession([&]() { + strncpy(mitm_name, service_name, 8); + mitm_name[8] = '\x00'; + if (R_FAILED(smMitMInstall(&this->port_handle, &query_h, mitm_name))) { std::abort(); } + }); + + RegisterMitmServerQueryHandle(query_h, std::move(ServiceObjectHolder(std::move(std::make_shared>())))); + } + + virtual ~MitmServer() override { + if (this->port_handle) { + DoWithSmMitmSession([&]() { + if (R_FAILED(smMitMUninstall(this->mitm_name))) { + std::abort(); + } + }); svcCloseHandle(port_handle); } } @@ -61,12 +61,12 @@ class MitmServer : public IWaitable { SessionManagerBase *GetSessionManager() { return static_cast(this->GetManager()); } - - /* IWaitable */ + + /* IWaitable */ virtual Handle GetHandle() override { return this->port_handle; } - + virtual Result HandleSignaled(u64 timeout) override { /* If this server's port was signaled, accept a new session. */ Handle session_h; @@ -74,27 +74,22 @@ class MitmServer : public IWaitable { if (R_FAILED(rc)) { return rc; } - + /* Create a forward service for this instance. */ std::shared_ptr forward_service(new Service(), [](Service *s) { /* Custom deleter to ensure service is open as long as necessary. */ serviceClose(s); delete s; }); - - rc = smMitMInitialize(); - if (R_FAILED(rc)) { - fatalSimple(rc); - } - + u64 client_pid; - - if (R_FAILED(smMitMAcknowledgeSession(forward_service.get(), &client_pid, mitm_name))) { - /* TODO: Panic. */ - } - - smMitMExit(); - + + DoWithSmMitmSession([&]() { + if (R_FAILED(smMitMAcknowledgeSession(forward_service.get(), &client_pid, mitm_name))) { + std::abort(); + } + }); + this->GetSessionManager()->AddWaitable(new MitmSession(session_h, client_pid, forward_service, MakeShared(forward_service, client_pid))); return ResultSuccess; } diff --git a/include/stratosphere/servers.hpp b/include/stratosphere/servers.hpp index c37b95b7..1afcfa49 100644 --- a/include/stratosphere/servers.hpp +++ b/include/stratosphere/servers.hpp @@ -13,12 +13,13 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ - + #pragma once #include #include "iwaitable.hpp" #include "ipc.hpp" +#include "utilities.hpp" template class IServer : public IWaitable { @@ -26,11 +27,11 @@ class IServer : public IWaitable { protected: Handle port_handle; unsigned int max_sessions; - - public: + + public: IServer(unsigned int max_s) : port_handle(0), max_sessions(max_s) { } - - virtual ~IServer() { + + virtual ~IServer() { if (port_handle) { svcCloseHandle(port_handle); } @@ -39,12 +40,12 @@ class IServer : public IWaitable { SessionManagerBase *GetSessionManager() { return static_cast(this->GetManager()); } - - /* IWaitable */ + + /* IWaitable */ virtual Handle GetHandle() override { return this->port_handle; } - + virtual Result HandleSignaled(u64 timeout) override { /* If this server's port was signaled, accept a new session. */ Handle session_h; @@ -52,24 +53,26 @@ class IServer : public IWaitable { if (R_FAILED(rc)) { return rc; } - + this->GetSessionManager()->AddSession(session_h, std::move(ServiceObjectHolder(std::move(MakeShared())))); return ResultSuccess; } }; template > -class ServiceServer : public IServer { +class ServiceServer : public IServer { public: - ServiceServer(const char *service_name, unsigned int max_s) : IServer(max_s) { - if (R_FAILED(smRegisterService(&this->port_handle, service_name, false, this->max_sessions))) { - /* TODO: Panic. */ - } + ServiceServer(const char *service_name, unsigned int max_s) : IServer(max_s) { + DoWithSmSession([&]() { + if (R_FAILED(smRegisterService(&this->port_handle, service_name, false, this->max_sessions))) { + std::abort(); + } + }); } }; template > -class ExistingPortServer : public IServer { +class ExistingPortServer : public IServer { public: ExistingPortServer(Handle port_h, unsigned int max_s) : IServer(max_s) { this->port_handle = port_h; @@ -77,9 +80,9 @@ class ExistingPortServer : public IServer { }; template > -class ManagedPortServer : public IServer { +class ManagedPortServer : public IServer { public: - ManagedPortServer(const char *service_name, unsigned int max_s) : IServer(max_s) { + ManagedPortServer(const char *service_name, unsigned int max_s) : IServer(max_s) { if (R_FAILED(svcManageNamedPort(&this->port_handle, service_name, this->max_sessions))) { /* TODO: panic */ } diff --git a/include/stratosphere/utilities.hpp b/include/stratosphere/utilities.hpp index e67c3353..377dd7fd 100644 --- a/include/stratosphere/utilities.hpp +++ b/include/stratosphere/utilities.hpp @@ -18,6 +18,9 @@ #include #include +#include "hossynch.hpp" +#include "mitm/sm_mitm.h" + static inline void RebootToRcm() { SecmonArgs args = {0}; args.X[0] = 0xC3000401; /* smcSetConfig */ @@ -97,3 +100,36 @@ static inline bool IsRcmBugPatched() { } return rcm_bug_patched; } + +HosRecursiveMutex &GetSmSessionMutex(); +HosRecursiveMutex &GetSmMitmSessionMutex(); + +template +static void DoWithSmSession(F f) { + std::scoped_lock lk(GetSmSessionMutex()); + { + Result rc; + if (R_SUCCEEDED((rc = smInitialize()))) { + f(); + } else { + /* TODO: fatalSimple(rc); ? */ + std::abort(); + } + smExit(); + } +} + +template +static void DoWithSmMitmSession(F f) { + std::scoped_lock lk(GetSmMitmSessionMutex()); + { + Result rc; + if (R_SUCCEEDED((rc = smMitMInitialize()))) { + f(); + } else { + /* TODO: fatalSimple(rc); ? */ + std::abort(); + } + smMitMExit(); + } +} diff --git a/source/utilities.cpp b/source/utilities.cpp new file mode 100644 index 00000000..1494dd03 --- /dev/null +++ b/source/utilities.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2018-2019 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include +#include + +static HosRecursiveMutex g_sm_session_lock; +static HosRecursiveMutex g_sm_mitm_session_lock; + + +HosRecursiveMutex &GetSmSessionMutex() { + return g_sm_session_lock; +} + +HosRecursiveMutex &GetSmMitmSessionMutex() { + return g_sm_mitm_session_lock; +}