diff --git a/include/stratosphere/ipc/ipc_serialization.hpp b/include/stratosphere/ipc/ipc_serialization.hpp index 7ce9bd39..3d026ae7 100644 --- a/include/stratosphere/ipc/ipc_serialization.hpp +++ b/include/stratosphere/ipc/ipc_serialization.hpp @@ -537,7 +537,7 @@ struct Encoder> { u64 result; } *raw; if (is_domain) { - raw = (decltype(raw))ipcPrepareHeaderForDomain(&ctx->reply, sizeof(*raw) + MetaInfo::OutRawArgSize, 0); + raw = (decltype(raw))ipcPrepareHeaderForDomain(&ctx->reply, sizeof(*raw) + MetaInfo::OutRawArgSize + sizeof(*ctx->out_object_ids) * MetaInfo::NumOutSessions, 0); auto resp_header = (DomainResponseHeader *)((uintptr_t)raw - sizeof(DomainResponseHeader)); *resp_header = {0}; resp_header->NumObjectIds = MetaInfo::NumOutSessions; diff --git a/include/stratosphere/mitm/mitm_session.hpp b/include/stratosphere/mitm/mitm_session.hpp index 27105dba..f8eeb1e5 100644 --- a/include/stratosphere/mitm/mitm_session.hpp +++ b/include/stratosphere/mitm/mitm_session.hpp @@ -23,13 +23,17 @@ #define RESULT_FORWARD_TO_SESSION 0xCAFEFC -class MitmSession final : public ServiceSession { +class MitmSession final : public ServiceSession { private: /* This will be for the actual session. */ std::shared_ptr forward_service; + struct PostProcessHandlerContext { + bool closed; + void (*handler)(IMitmServiceObject *, IpcResponseContext *); + }; /* Store a handler for the service. */ - void (*service_post_process_handler)(IMitmServiceObject *, IpcResponseContext *); + std::shared_ptr service_post_process_ctx; /* For cleanup usage. */ u64 client_pid; @@ -41,7 +45,9 @@ class MitmSession final : public ServiceSession { this->forward_service = std::move(fs); this->obj_holder = std::move(ServiceObjectHolder(std::move(srv))); - this->service_post_process_handler = T::PostProcess; + this->service_post_process_ctx = std::make_shared(); + this->service_post_process_ctx->closed = false; + this->service_post_process_ctx->handler = T::PostProcess; size_t pbs; if (R_FAILED(ipcQueryPointerBufferSize(forward_service->handle, &pbs))) { @@ -52,12 +58,12 @@ class MitmSession final : public ServiceSession { this->control_holder = std::move(ServiceObjectHolder(std::move(std::make_shared(this)))); } - MitmSession(Handle s_h, u64 pid, std::shared_ptr fs, ServiceObjectHolder &&h, void (*pph)(IMitmServiceObject *, IpcResponseContext *)) : ServiceSession(s_h), client_pid(pid) { + MitmSession(Handle s_h, u64 pid, std::shared_ptr fs, ServiceObjectHolder &&h, std::shared_ptr ppc) : ServiceSession(s_h), client_pid(pid) { this->session_handle = s_h; this->forward_service = std::move(fs); this->obj_holder = std::move(h); - this->service_post_process_handler = pph; + this->service_post_process_ctx = ppc; size_t pbs; if (R_FAILED(ipcQueryPointerBufferSize(forward_service->handle, &pbs))) { @@ -104,14 +110,14 @@ class MitmSession final : public ServiceSession { } return rc; } - + virtual Result GetResponse(IpcResponseContext *ctx) { Result rc = 0xF601; FirmwareVersion fw = GetRuntimeFirmwareVersion(); const ServiceCommandMeta *dispatch_table = ctx->obj_holder->GetDispatchTable(); size_t entry_count = ctx->obj_holder->GetDispatchTableEntryCount(); - + if (IsDomainObject(ctx->obj_holder)) { switch (ctx->request.InMessageType) { case DomainMessageType_Invalid: @@ -121,12 +127,9 @@ class MitmSession final : public ServiceSession { if (R_SUCCEEDED(rc)) { ctx->obj_holder->GetServiceObject()->FreeObject(ctx->request.InThisObjectId); } - if (R_SUCCEEDED(rc) && ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get())) { - /* If we're not longer MitMing anything, we don't need a mitm session. */ - this->Reply(); - this->GetSessionManager()->AddSession(this->session_handle, std::move(this->obj_holder)); - this->session_handle = 0; - return 0xF601; + if (R_SUCCEEDED(rc) && ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get()) && !this->service_post_process_ctx->closed) { + /* If we're not longer MitMing anything, we'll no longer do any postprocessing. */ + this->service_post_process_ctx->closed = true; } return rc; case DomainMessageType_SendMessage: @@ -162,14 +165,23 @@ class MitmSession final : public ServiceSession { virtual void PostProcessResponse(IpcResponseContext *ctx) override { if ((ctx->cmd_type == IpcCommandType_Request || ctx->cmd_type == IpcCommandType_RequestWithContext) && R_SUCCEEDED(ctx->rc)) { + if (this->service_post_process_ctx->closed) { + return; + } + if (!IsDomainObject(ctx->obj_holder) || ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get())) { - IMitmServiceObject *obj; + IMitmServiceObject *obj = nullptr; if (!IsDomainObject(ctx->obj_holder)) { obj = ctx->obj_holder->GetServiceObjectUnsafe(); } else { - obj = ctx->obj_holder->GetServiceObject()->GetObject(ctx->request.InThisObjectId)->GetServiceObjectUnsafe(); + const auto sub_obj = ctx->obj_holder->GetServiceObject()->GetObject(ctx->request.InThisObjectId); + if (sub_obj != nullptr) { + obj = sub_obj->GetServiceObjectUnsafe(); + } + } + if (obj != nullptr) { + this->service_post_process_ctx->handler(obj, ctx); } - this->service_post_process_handler(obj, ctx); } } } @@ -278,7 +290,7 @@ class MitmSession final : public ServiceSession { out_h.SetValue(client_h); if (id == serviceGetObjectId(this->session->forward_service.get())) { - this->session->GetSessionManager()->AddWaitable(new MitmSession(server_h, this->session->client_pid, this->session->forward_service, std::move(object->Clone()), this->session->service_post_process_handler)); + this->session->GetSessionManager()->AddWaitable(new MitmSession(server_h, this->session->client_pid, this->session->forward_service, std::move(object->Clone()), this->session->service_post_process_ctx)); } else { this->session->GetSessionManager()->AddSession(server_h, std::move(object->Clone())); } @@ -292,7 +304,7 @@ class MitmSession final : public ServiceSession { std::abort(); } - this->session->GetSessionManager()->AddWaitable(new MitmSession(server_h, this->session->client_pid, this->session->forward_service, std::move(this->session->obj_holder.Clone()), this->session->service_post_process_handler)); + this->session->GetSessionManager()->AddWaitable(new MitmSession(server_h, this->session->client_pid, this->session->forward_service, std::move(this->session->obj_holder.Clone()), this->session->service_post_process_ctx)); out_h.SetValue(client_h); } diff --git a/include/stratosphere/waitable_manager.hpp b/include/stratosphere/waitable_manager.hpp index f7fcb810..3bbd6bae 100644 --- a/include/stratosphere/waitable_manager.hpp +++ b/include/stratosphere/waitable_manager.hpp @@ -358,6 +358,9 @@ class WaitableManager : public SessionManagerBase { virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; + if (object_id > ManagerOptions::MaxDomainObjects) { + return 0x25A0A; + } if (this->domain_objects[object_id-1].owner == nullptr) { this->domain_objects[object_id-1].owner = domain; return 0; @@ -367,6 +370,9 @@ class WaitableManager : public SessionManagerBase { virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override { std::scoped_lock lk{this->domain_lock}; + if (object_id > ManagerOptions::MaxDomainObjects) { + return; + } if (this->domain_objects[object_id-1].owner == domain) { this->domain_objects[object_id-1].obj_holder = std::move(holder); } @@ -374,6 +380,9 @@ class WaitableManager : public SessionManagerBase { virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; + if (object_id > ManagerOptions::MaxDomainObjects) { + return nullptr; + } if (this->domain_objects[object_id-1].owner == domain) { return &this->domain_objects[object_id-1].obj_holder; } @@ -382,6 +391,9 @@ class WaitableManager : public SessionManagerBase { virtual Result FreeObject(IDomainObject *domain, u32 object_id) override { std::scoped_lock lk{this->domain_lock}; + if (object_id > ManagerOptions::MaxDomainObjects) { + return 0x3D80B; + } if (this->domain_objects[object_id-1].owner == domain) { this->domain_objects[object_id-1].obj_holder.Reset(); this->domain_objects[object_id-1].owner = nullptr; @@ -392,6 +404,9 @@ class WaitableManager : public SessionManagerBase { virtual Result ForceFreeObject(u32 object_id) override { std::scoped_lock lk{this->domain_lock}; + if (object_id > ManagerOptions::MaxDomainObjects) { + return 0x3D80B; + } if (this->domain_objects[object_id-1].owner != nullptr) { this->domain_objects[object_id-1].obj_holder.Reset(); this->domain_objects[object_id-1].owner = nullptr;