diff --git a/include/stratosphere/ipc/ipc_domain_object.hpp b/include/stratosphere/ipc/ipc_domain_object.hpp index 9a7244a4..b0c01252 100644 --- a/include/stratosphere/ipc/ipc_domain_object.hpp +++ b/include/stratosphere/ipc/ipc_domain_object.hpp @@ -25,6 +25,8 @@ class IDomainObject; class DomainManager { + public: + static constexpr u32 MinimumDomainId = 1; public: virtual std::shared_ptr AllocateDomain() = 0; virtual void FreeDomain(IDomainObject *domain) = 0; diff --git a/include/stratosphere/ipc/ipc_service_object.hpp b/include/stratosphere/ipc/ipc_service_object.hpp index 3cc313ac..01040513 100644 --- a/include/stratosphere/ipc/ipc_service_object.hpp +++ b/include/stratosphere/ipc/ipc_service_object.hpp @@ -35,6 +35,7 @@ struct ServiceCommandMeta { class IServiceObject { public: virtual ~IServiceObject() { } + virtual bool IsMitmObject() const { return false; } }; #define SERVICE_DISPATCH_TABLE_NAME s_DispatchTable @@ -96,6 +97,10 @@ class ServiceObjectHolder { return reinterpret_cast(this->dispatch_table); } + bool IsMitmObject() const { + return this->srv->IsMitmObject(); + } + /* Default constructor, move constructor, move assignment operator. */ ServiceObjectHolder() : srv(nullptr), dispatch_table(nullptr) { } diff --git a/include/stratosphere/mitm/imitmserviceobject.hpp b/include/stratosphere/mitm/imitmserviceobject.hpp index 69d64352..db33550d 100644 --- a/include/stratosphere/mitm/imitmserviceobject.hpp +++ b/include/stratosphere/mitm/imitmserviceobject.hpp @@ -39,6 +39,8 @@ class IMitmServiceObject : public IServiceObject { return this->process_id; } + virtual bool IsMitmObject() const override { return true; } + static bool ShouldMitm(u64 pid, u64 tid); protected: diff --git a/include/stratosphere/mitm/mitm_session.hpp b/include/stratosphere/mitm/mitm_session.hpp index ac6600a4..e4d069f6 100644 --- a/include/stratosphere/mitm/mitm_session.hpp +++ b/include/stratosphere/mitm/mitm_session.hpp @@ -122,15 +122,28 @@ class MitmSession final : public ServiceSession { case DomainMessageType_Invalid: return ResultKernelConnectionClosed; case DomainMessageType_Close: - rc = ForwardRequest(ctx); - if (R_SUCCEEDED(rc)) { - ctx->obj_holder->GetServiceObject()->FreeObject(ctx->request.InThisObjectId); + { + auto sub_obj = ctx->obj_holder->GetServiceObject()->GetObject(ctx->request.InThisObjectId); + if (sub_obj == nullptr || (!sub_obj)) { + rc = ForwardRequest(ctx); + return rc; } + + if (sub_obj->IsMitmObject()) { + rc = ForwardRequest(ctx); + if (R_SUCCEEDED(rc)) { + ctx->obj_holder->GetServiceObject()->FreeObject(ctx->request.InThisObjectId); + } + } else { + rc = ctx->obj_holder->GetServiceObject()->FreeObject(ctx->request.InThisObjectId); + } + if (R_SUCCEEDED(rc) && ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get()) && !this->service_post_process_ctx->closed) { /* If we're not longer MitMing anything, we'll no longer do any postprocessing. */ this->service_post_process_ctx->closed = true; } return rc; + } case DomainMessageType_SendMessage: { auto sub_obj = ctx->obj_holder->GetServiceObject()->GetObject(ctx->request.InThisObjectId); diff --git a/include/stratosphere/waitable_manager.hpp b/include/stratosphere/waitable_manager.hpp index 5072a542..d0ae8c2f 100644 --- a/include/stratosphere/waitable_manager.hpp +++ b/include/stratosphere/waitable_manager.hpp @@ -364,7 +364,7 @@ class WaitableManager : public SessionManagerBase { virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; - if (object_id > ManagerOptions::MaxDomainObjects) { + if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return ResultServiceFrameworkOutOfDomainEntries; } if (this->domain_objects[object_id-1].owner == nullptr) { @@ -376,7 +376,7 @@ class WaitableManager : public SessionManagerBase { virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override { std::scoped_lock lk{this->domain_lock}; - if (object_id > ManagerOptions::MaxDomainObjects) { + if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return; } if (this->domain_objects[object_id-1].owner == domain) { @@ -386,7 +386,7 @@ class WaitableManager : public SessionManagerBase { virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; - if (object_id > ManagerOptions::MaxDomainObjects) { + if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return nullptr; } if (this->domain_objects[object_id-1].owner == domain) { @@ -397,7 +397,7 @@ class WaitableManager : public SessionManagerBase { virtual Result FreeObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; - if (object_id > ManagerOptions::MaxDomainObjects) { + if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return ResultHipcDomainObjectNotFound; } if (this->domain_objects[object_id-1].owner == domain) { @@ -410,7 +410,7 @@ class WaitableManager : public SessionManagerBase { virtual Result ForceFreeObject(u32 object_id) override { std::scoped_lock lk{this->domain_lock}; - if (object_id > ManagerOptions::MaxDomainObjects) { + if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) { return ResultHipcDomainObjectNotFound; } if (this->domain_objects[object_id-1].owner != nullptr) {