Fix a number of bugs in mitm.

This commit is contained in:
Michael Scire 2018-12-12 03:07:26 -08:00
parent fa37b70b0e
commit 8b9dd81fa6
3 changed files with 46 additions and 19 deletions

View File

@ -537,7 +537,7 @@ struct Encoder<MetaInfo, std::tuple<Args...>> {
u64 result; u64 result;
} *raw; } *raw;
if (is_domain) { if (is_domain) {
raw = (decltype(raw))ipcPrepareHeaderForDomain(&ctx->reply, sizeof(*raw) + MetaInfo::OutRawArgSize, 0); raw = (decltype(raw))ipcPrepareHeaderForDomain(&ctx->reply, sizeof(*raw) + MetaInfo::OutRawArgSize + sizeof(*ctx->out_object_ids) * MetaInfo::NumOutSessions, 0);
auto resp_header = (DomainResponseHeader *)((uintptr_t)raw - sizeof(DomainResponseHeader)); auto resp_header = (DomainResponseHeader *)((uintptr_t)raw - sizeof(DomainResponseHeader));
*resp_header = {0}; *resp_header = {0};
resp_header->NumObjectIds = MetaInfo::NumOutSessions; resp_header->NumObjectIds = MetaInfo::NumOutSessions;

View File

@ -28,8 +28,12 @@ class MitmSession final : public ServiceSession {
/* This will be for the actual session. */ /* This will be for the actual session. */
std::shared_ptr<Service> forward_service; std::shared_ptr<Service> forward_service;
struct PostProcessHandlerContext {
bool closed;
void (*handler)(IMitmServiceObject *, IpcResponseContext *);
};
/* Store a handler for the service. */ /* Store a handler for the service. */
void (*service_post_process_handler)(IMitmServiceObject *, IpcResponseContext *); std::shared_ptr<PostProcessHandlerContext> service_post_process_ctx;
/* For cleanup usage. */ /* For cleanup usage. */
u64 client_pid; u64 client_pid;
@ -41,7 +45,9 @@ class MitmSession final : public ServiceSession {
this->forward_service = std::move(fs); this->forward_service = std::move(fs);
this->obj_holder = std::move(ServiceObjectHolder(std::move(srv))); this->obj_holder = std::move(ServiceObjectHolder(std::move(srv)));
this->service_post_process_handler = T::PostProcess; this->service_post_process_ctx = std::make_shared<PostProcessHandlerContext>();
this->service_post_process_ctx->closed = false;
this->service_post_process_ctx->handler = T::PostProcess;
size_t pbs; size_t pbs;
if (R_FAILED(ipcQueryPointerBufferSize(forward_service->handle, &pbs))) { if (R_FAILED(ipcQueryPointerBufferSize(forward_service->handle, &pbs))) {
@ -52,12 +58,12 @@ class MitmSession final : public ServiceSession {
this->control_holder = std::move(ServiceObjectHolder(std::move(std::make_shared<IMitmHipcControlService>(this)))); this->control_holder = std::move(ServiceObjectHolder(std::move(std::make_shared<IMitmHipcControlService>(this))));
} }
MitmSession(Handle s_h, u64 pid, std::shared_ptr<Service> fs, ServiceObjectHolder &&h, void (*pph)(IMitmServiceObject *, IpcResponseContext *)) : ServiceSession(s_h), client_pid(pid) { MitmSession(Handle s_h, u64 pid, std::shared_ptr<Service> fs, ServiceObjectHolder &&h, std::shared_ptr<PostProcessHandlerContext> ppc) : ServiceSession(s_h), client_pid(pid) {
this->session_handle = s_h; this->session_handle = s_h;
this->forward_service = std::move(fs); this->forward_service = std::move(fs);
this->obj_holder = std::move(h); this->obj_holder = std::move(h);
this->service_post_process_handler = pph; this->service_post_process_ctx = ppc;
size_t pbs; size_t pbs;
if (R_FAILED(ipcQueryPointerBufferSize(forward_service->handle, &pbs))) { if (R_FAILED(ipcQueryPointerBufferSize(forward_service->handle, &pbs))) {
@ -121,12 +127,9 @@ class MitmSession final : public ServiceSession {
if (R_SUCCEEDED(rc)) { if (R_SUCCEEDED(rc)) {
ctx->obj_holder->GetServiceObject<IDomainObject>()->FreeObject(ctx->request.InThisObjectId); ctx->obj_holder->GetServiceObject<IDomainObject>()->FreeObject(ctx->request.InThisObjectId);
} }
if (R_SUCCEEDED(rc) && ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get())) { if (R_SUCCEEDED(rc) && ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get()) && !this->service_post_process_ctx->closed) {
/* If we're not longer MitMing anything, we don't need a mitm session. */ /* If we're not longer MitMing anything, we'll no longer do any postprocessing. */
this->Reply(); this->service_post_process_ctx->closed = true;
this->GetSessionManager()->AddSession(this->session_handle, std::move(this->obj_holder));
this->session_handle = 0;
return 0xF601;
} }
return rc; return rc;
case DomainMessageType_SendMessage: case DomainMessageType_SendMessage:
@ -162,14 +165,23 @@ class MitmSession final : public ServiceSession {
virtual void PostProcessResponse(IpcResponseContext *ctx) override { virtual void PostProcessResponse(IpcResponseContext *ctx) override {
if ((ctx->cmd_type == IpcCommandType_Request || ctx->cmd_type == IpcCommandType_RequestWithContext) && R_SUCCEEDED(ctx->rc)) { if ((ctx->cmd_type == IpcCommandType_Request || ctx->cmd_type == IpcCommandType_RequestWithContext) && R_SUCCEEDED(ctx->rc)) {
if (this->service_post_process_ctx->closed) {
return;
}
if (!IsDomainObject(ctx->obj_holder) || ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get())) { if (!IsDomainObject(ctx->obj_holder) || ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get())) {
IMitmServiceObject *obj; IMitmServiceObject *obj = nullptr;
if (!IsDomainObject(ctx->obj_holder)) { if (!IsDomainObject(ctx->obj_holder)) {
obj = ctx->obj_holder->GetServiceObjectUnsafe<IMitmServiceObject>(); obj = ctx->obj_holder->GetServiceObjectUnsafe<IMitmServiceObject>();
} else { } else {
obj = ctx->obj_holder->GetServiceObject<IDomainObject>()->GetObject(ctx->request.InThisObjectId)->GetServiceObjectUnsafe<IMitmServiceObject>(); const auto sub_obj = ctx->obj_holder->GetServiceObject<IDomainObject>()->GetObject(ctx->request.InThisObjectId);
if (sub_obj != nullptr) {
obj = sub_obj->GetServiceObjectUnsafe<IMitmServiceObject>();
}
}
if (obj != nullptr) {
this->service_post_process_ctx->handler(obj, ctx);
} }
this->service_post_process_handler(obj, ctx);
} }
} }
} }
@ -278,7 +290,7 @@ class MitmSession final : public ServiceSession {
out_h.SetValue(client_h); out_h.SetValue(client_h);
if (id == serviceGetObjectId(this->session->forward_service.get())) { if (id == serviceGetObjectId(this->session->forward_service.get())) {
this->session->GetSessionManager()->AddWaitable(new MitmSession(server_h, this->session->client_pid, this->session->forward_service, std::move(object->Clone()), this->session->service_post_process_handler)); this->session->GetSessionManager()->AddWaitable(new MitmSession(server_h, this->session->client_pid, this->session->forward_service, std::move(object->Clone()), this->session->service_post_process_ctx));
} else { } else {
this->session->GetSessionManager()->AddSession(server_h, std::move(object->Clone())); this->session->GetSessionManager()->AddSession(server_h, std::move(object->Clone()));
} }
@ -292,7 +304,7 @@ class MitmSession final : public ServiceSession {
std::abort(); std::abort();
} }
this->session->GetSessionManager()->AddWaitable(new MitmSession(server_h, this->session->client_pid, this->session->forward_service, std::move(this->session->obj_holder.Clone()), this->session->service_post_process_handler)); this->session->GetSessionManager()->AddWaitable(new MitmSession(server_h, this->session->client_pid, this->session->forward_service, std::move(this->session->obj_holder.Clone()), this->session->service_post_process_ctx));
out_h.SetValue(client_h); out_h.SetValue(client_h);
} }

View File

@ -358,6 +358,9 @@ class WaitableManager : public SessionManagerBase {
virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override { virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
return 0x25A0A;
}
if (this->domain_objects[object_id-1].owner == nullptr) { if (this->domain_objects[object_id-1].owner == nullptr) {
this->domain_objects[object_id-1].owner = domain; this->domain_objects[object_id-1].owner = domain;
return 0; return 0;
@ -367,6 +370,9 @@ class WaitableManager : public SessionManagerBase {
virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override { virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
return;
}
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
this->domain_objects[object_id-1].obj_holder = std::move(holder); this->domain_objects[object_id-1].obj_holder = std::move(holder);
} }
@ -374,6 +380,9 @@ class WaitableManager : public SessionManagerBase {
virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override { virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
return nullptr;
}
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
return &this->domain_objects[object_id-1].obj_holder; return &this->domain_objects[object_id-1].obj_holder;
} }
@ -382,6 +391,9 @@ class WaitableManager : public SessionManagerBase {
virtual Result FreeObject(IDomainObject *domain, u32 object_id) override { virtual Result FreeObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
return 0x3D80B;
}
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
this->domain_objects[object_id-1].obj_holder.Reset(); this->domain_objects[object_id-1].obj_holder.Reset();
this->domain_objects[object_id-1].owner = nullptr; this->domain_objects[object_id-1].owner = nullptr;
@ -392,6 +404,9 @@ class WaitableManager : public SessionManagerBase {
virtual Result ForceFreeObject(u32 object_id) override { virtual Result ForceFreeObject(u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
return 0x3D80B;
}
if (this->domain_objects[object_id-1].owner != nullptr) { if (this->domain_objects[object_id-1].owner != nullptr) {
this->domain_objects[object_id-1].obj_holder.Reset(); this->domain_objects[object_id-1].obj_holder.Reset();
this->domain_objects[object_id-1].owner = nullptr; this->domain_objects[object_id-1].owner = nullptr;