libstrat: improve Domain Object Close semantics

This commit is contained in:
Michael Scire 2019-04-24 02:56:05 -07:00
parent 9e34dbe7e2
commit b5a7963a81
5 changed files with 30 additions and 8 deletions

View File

@ -25,6 +25,8 @@
class IDomainObject; class IDomainObject;
class DomainManager { class DomainManager {
public:
static constexpr u32 MinimumDomainId = 1;
public: public:
virtual std::shared_ptr<IDomainObject> AllocateDomain() = 0; virtual std::shared_ptr<IDomainObject> AllocateDomain() = 0;
virtual void FreeDomain(IDomainObject *domain) = 0; virtual void FreeDomain(IDomainObject *domain) = 0;

View File

@ -35,6 +35,7 @@ struct ServiceCommandMeta {
class IServiceObject { class IServiceObject {
public: public:
virtual ~IServiceObject() { } virtual ~IServiceObject() { }
virtual bool IsMitmObject() const { return false; }
}; };
#define SERVICE_DISPATCH_TABLE_NAME s_DispatchTable #define SERVICE_DISPATCH_TABLE_NAME s_DispatchTable
@ -96,6 +97,10 @@ class ServiceObjectHolder {
return reinterpret_cast<uintptr_t>(this->dispatch_table); return reinterpret_cast<uintptr_t>(this->dispatch_table);
} }
bool IsMitmObject() const {
return this->srv->IsMitmObject();
}
/* Default constructor, move constructor, move assignment operator. */ /* Default constructor, move constructor, move assignment operator. */
ServiceObjectHolder() : srv(nullptr), dispatch_table(nullptr) { } ServiceObjectHolder() : srv(nullptr), dispatch_table(nullptr) { }

View File

@ -39,6 +39,8 @@ class IMitmServiceObject : public IServiceObject {
return this->process_id; return this->process_id;
} }
virtual bool IsMitmObject() const override { return true; }
static bool ShouldMitm(u64 pid, u64 tid); static bool ShouldMitm(u64 pid, u64 tid);
protected: protected:

View File

@ -122,15 +122,28 @@ class MitmSession final : public ServiceSession {
case DomainMessageType_Invalid: case DomainMessageType_Invalid:
return ResultKernelConnectionClosed; return ResultKernelConnectionClosed;
case DomainMessageType_Close: case DomainMessageType_Close:
rc = ForwardRequest(ctx); {
if (R_SUCCEEDED(rc)) { auto sub_obj = ctx->obj_holder->GetServiceObject<IDomainObject>()->GetObject(ctx->request.InThisObjectId);
ctx->obj_holder->GetServiceObject<IDomainObject>()->FreeObject(ctx->request.InThisObjectId); if (sub_obj == nullptr || (!sub_obj)) {
rc = ForwardRequest(ctx);
return rc;
} }
if (sub_obj->IsMitmObject()) {
rc = ForwardRequest(ctx);
if (R_SUCCEEDED(rc)) {
ctx->obj_holder->GetServiceObject<IDomainObject>()->FreeObject(ctx->request.InThisObjectId);
}
} else {
rc = ctx->obj_holder->GetServiceObject<IDomainObject>()->FreeObject(ctx->request.InThisObjectId);
}
if (R_SUCCEEDED(rc) && ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get()) && !this->service_post_process_ctx->closed) { if (R_SUCCEEDED(rc) && ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get()) && !this->service_post_process_ctx->closed) {
/* If we're not longer MitMing anything, we'll no longer do any postprocessing. */ /* If we're not longer MitMing anything, we'll no longer do any postprocessing. */
this->service_post_process_ctx->closed = true; this->service_post_process_ctx->closed = true;
} }
return rc; return rc;
}
case DomainMessageType_SendMessage: case DomainMessageType_SendMessage:
{ {
auto sub_obj = ctx->obj_holder->GetServiceObject<IDomainObject>()->GetObject(ctx->request.InThisObjectId); auto sub_obj = ctx->obj_holder->GetServiceObject<IDomainObject>()->GetObject(ctx->request.InThisObjectId);

View File

@ -364,7 +364,7 @@ class WaitableManager : public SessionManagerBase {
virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override { virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) { if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return ResultServiceFrameworkOutOfDomainEntries; return ResultServiceFrameworkOutOfDomainEntries;
} }
if (this->domain_objects[object_id-1].owner == nullptr) { if (this->domain_objects[object_id-1].owner == nullptr) {
@ -376,7 +376,7 @@ class WaitableManager : public SessionManagerBase {
virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override { virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) { if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return; return;
} }
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
@ -386,7 +386,7 @@ class WaitableManager : public SessionManagerBase {
virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override { virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) { if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return nullptr; return nullptr;
} }
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
@ -397,7 +397,7 @@ class WaitableManager : public SessionManagerBase {
virtual Result FreeObject(IDomainObject *domain, u32 object_id) override { virtual Result FreeObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) { if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return ResultHipcDomainObjectNotFound; return ResultHipcDomainObjectNotFound;
} }
if (this->domain_objects[object_id-1].owner == domain) { if (this->domain_objects[object_id-1].owner == domain) {
@ -410,7 +410,7 @@ class WaitableManager : public SessionManagerBase {
virtual Result ForceFreeObject(u32 object_id) override { virtual Result ForceFreeObject(u32 object_id) override {
std::scoped_lock lk{this->domain_lock}; std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) { if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return ResultHipcDomainObjectNotFound; return ResultHipcDomainObjectNotFound;
} }
if (this->domain_objects[object_id-1].owner != nullptr) { if (this->domain_objects[object_id-1].owner != nullptr) {