diff --git a/include/stratosphere/mitm/mitm_server.hpp b/include/stratosphere/mitm/mitm_server.hpp
index 7224ed5c..3abf9b18 100644
--- a/include/stratosphere/mitm/mitm_server.hpp
+++ b/include/stratosphere/mitm/mitm_server.hpp
@@ -13,47 +13,47 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*/
-
+
#pragma once
#include
#include "mitm_query_service.hpp"
#include "sm_mitm.h"
#include "mitm_session.hpp"
+#include "../utilities.hpp"
void RegisterMitmServerQueryHandle(Handle query_h, ServiceObjectHolder &&service);
template
-class MitmServer : public IWaitable {
+class MitmServer : public IWaitable {
static_assert(std::is_base_of::value, "MitM Service Objects must derive from IMitmServiceObject");
private:
Handle port_handle;
unsigned int max_sessions;
char mitm_name[9];
-
+
public:
MitmServer(const char *service_name, unsigned int max_s) : port_handle(0), max_sessions(max_s) {
Handle query_h = 0;
- Result rc = smMitMInitialize();
- if (R_FAILED(rc)) {
- fatalSimple(rc);
- }
-
- strncpy(mitm_name, service_name, 8);
- mitm_name[8] = '\x00';
- if (R_FAILED((rc = smMitMInstall(&this->port_handle, &query_h, mitm_name)))) {
- fatalSimple(rc);
- }
- RegisterMitmServerQueryHandle(query_h, std::move(ServiceObjectHolder(std::move(std::make_shared>()))));
-
- smMitMExit();
- }
-
- virtual ~MitmServer() override {
- if (this->port_handle) {
- if (R_FAILED(smMitMUninstall(this->mitm_name))) {
+
+ DoWithSmMitmSession([&]() {
+ strncpy(mitm_name, service_name, 8);
+ mitm_name[8] = '\x00';
+ if (R_FAILED(smMitMInstall(&this->port_handle, &query_h, mitm_name))) {
std::abort();
}
+ });
+
+ RegisterMitmServerQueryHandle(query_h, std::move(ServiceObjectHolder(std::move(std::make_shared>()))));
+ }
+
+ virtual ~MitmServer() override {
+ if (this->port_handle) {
+ DoWithSmMitmSession([&]() {
+ if (R_FAILED(smMitMUninstall(this->mitm_name))) {
+ std::abort();
+ }
+ });
svcCloseHandle(port_handle);
}
}
@@ -61,12 +61,12 @@ class MitmServer : public IWaitable {
SessionManagerBase *GetSessionManager() {
return static_cast(this->GetManager());
}
-
- /* IWaitable */
+
+ /* IWaitable */
virtual Handle GetHandle() override {
return this->port_handle;
}
-
+
virtual Result HandleSignaled(u64 timeout) override {
/* If this server's port was signaled, accept a new session. */
Handle session_h;
@@ -74,27 +74,22 @@ class MitmServer : public IWaitable {
if (R_FAILED(rc)) {
return rc;
}
-
+
/* Create a forward service for this instance. */
std::shared_ptr forward_service(new Service(), [](Service *s) {
/* Custom deleter to ensure service is open as long as necessary. */
serviceClose(s);
delete s;
});
-
- rc = smMitMInitialize();
- if (R_FAILED(rc)) {
- fatalSimple(rc);
- }
-
+
u64 client_pid;
-
- if (R_FAILED(smMitMAcknowledgeSession(forward_service.get(), &client_pid, mitm_name))) {
- /* TODO: Panic. */
- }
-
- smMitMExit();
-
+
+ DoWithSmMitmSession([&]() {
+ if (R_FAILED(smMitMAcknowledgeSession(forward_service.get(), &client_pid, mitm_name))) {
+ std::abort();
+ }
+ });
+
this->GetSessionManager()->AddWaitable(new MitmSession(session_h, client_pid, forward_service, MakeShared(forward_service, client_pid)));
return ResultSuccess;
}
diff --git a/include/stratosphere/servers.hpp b/include/stratosphere/servers.hpp
index c37b95b7..1afcfa49 100644
--- a/include/stratosphere/servers.hpp
+++ b/include/stratosphere/servers.hpp
@@ -13,12 +13,13 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*/
-
+
#pragma once
#include
#include "iwaitable.hpp"
#include "ipc.hpp"
+#include "utilities.hpp"
template
class IServer : public IWaitable {
@@ -26,11 +27,11 @@ class IServer : public IWaitable {
protected:
Handle port_handle;
unsigned int max_sessions;
-
- public:
+
+ public:
IServer(unsigned int max_s) : port_handle(0), max_sessions(max_s) { }
-
- virtual ~IServer() {
+
+ virtual ~IServer() {
if (port_handle) {
svcCloseHandle(port_handle);
}
@@ -39,12 +40,12 @@ class IServer : public IWaitable {
SessionManagerBase *GetSessionManager() {
return static_cast(this->GetManager());
}
-
- /* IWaitable */
+
+ /* IWaitable */
virtual Handle GetHandle() override {
return this->port_handle;
}
-
+
virtual Result HandleSignaled(u64 timeout) override {
/* If this server's port was signaled, accept a new session. */
Handle session_h;
@@ -52,24 +53,26 @@ class IServer : public IWaitable {
if (R_FAILED(rc)) {
return rc;
}
-
+
this->GetSessionManager()->AddSession(session_h, std::move(ServiceObjectHolder(std::move(MakeShared()))));
return ResultSuccess;
}
};
template >
-class ServiceServer : public IServer {
+class ServiceServer : public IServer {
public:
- ServiceServer(const char *service_name, unsigned int max_s) : IServer(max_s) {
- if (R_FAILED(smRegisterService(&this->port_handle, service_name, false, this->max_sessions))) {
- /* TODO: Panic. */
- }
+ ServiceServer(const char *service_name, unsigned int max_s) : IServer(max_s) {
+ DoWithSmSession([&]() {
+ if (R_FAILED(smRegisterService(&this->port_handle, service_name, false, this->max_sessions))) {
+ std::abort();
+ }
+ });
}
};
template >
-class ExistingPortServer : public IServer {
+class ExistingPortServer : public IServer {
public:
ExistingPortServer(Handle port_h, unsigned int max_s) : IServer(max_s) {
this->port_handle = port_h;
@@ -77,9 +80,9 @@ class ExistingPortServer : public IServer {
};
template >
-class ManagedPortServer : public IServer {
+class ManagedPortServer : public IServer {
public:
- ManagedPortServer(const char *service_name, unsigned int max_s) : IServer(max_s) {
+ ManagedPortServer(const char *service_name, unsigned int max_s) : IServer(max_s) {
if (R_FAILED(svcManageNamedPort(&this->port_handle, service_name, this->max_sessions))) {
/* TODO: panic */
}
diff --git a/include/stratosphere/utilities.hpp b/include/stratosphere/utilities.hpp
index e67c3353..377dd7fd 100644
--- a/include/stratosphere/utilities.hpp
+++ b/include/stratosphere/utilities.hpp
@@ -18,6 +18,9 @@
#include
#include
+#include "hossynch.hpp"
+#include "mitm/sm_mitm.h"
+
static inline void RebootToRcm() {
SecmonArgs args = {0};
args.X[0] = 0xC3000401; /* smcSetConfig */
@@ -97,3 +100,36 @@ static inline bool IsRcmBugPatched() {
}
return rcm_bug_patched;
}
+
+HosRecursiveMutex &GetSmSessionMutex();
+HosRecursiveMutex &GetSmMitmSessionMutex();
+
+template
+static void DoWithSmSession(F f) {
+ std::scoped_lock lk(GetSmSessionMutex());
+ {
+ Result rc;
+ if (R_SUCCEEDED((rc = smInitialize()))) {
+ f();
+ } else {
+ /* TODO: fatalSimple(rc); ? */
+ std::abort();
+ }
+ smExit();
+ }
+}
+
+template
+static void DoWithSmMitmSession(F f) {
+ std::scoped_lock lk(GetSmMitmSessionMutex());
+ {
+ Result rc;
+ if (R_SUCCEEDED((rc = smMitMInitialize()))) {
+ f();
+ } else {
+ /* TODO: fatalSimple(rc); ? */
+ std::abort();
+ }
+ smMitMExit();
+ }
+}
diff --git a/source/utilities.cpp b/source/utilities.cpp
new file mode 100644
index 00000000..1494dd03
--- /dev/null
+++ b/source/utilities.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2018-2019 Atmosphère-NX
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms and conditions of the GNU General Public License,
+ * version 2, as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+#include
+#include
+
+static HosRecursiveMutex g_sm_session_lock;
+static HosRecursiveMutex g_sm_mitm_session_lock;
+
+
+HosRecursiveMutex &GetSmSessionMutex() {
+ return g_sm_session_lock;
+}
+
+HosRecursiveMutex &GetSmMitmSessionMutex() {
+ return g_sm_mitm_session_lock;
+}