libstrat: improve Domain Object Close semantics

This commit is contained in:
Michael Scire 2019-04-24 02:56:05 -07:00
parent 9e34dbe7e2
commit b5a7963a81
5 changed files with 30 additions and 8 deletions

View File

@ -25,6 +25,8 @@
class IDomainObject;
class DomainManager {
public:
static constexpr u32 MinimumDomainId = 1;
public:
virtual std::shared_ptr<IDomainObject> AllocateDomain() = 0;
virtual void FreeDomain(IDomainObject *domain) = 0;

View File

@ -35,6 +35,7 @@ struct ServiceCommandMeta {
class IServiceObject {
public:
virtual ~IServiceObject() { }
virtual bool IsMitmObject() const { return false; }
};
#define SERVICE_DISPATCH_TABLE_NAME s_DispatchTable
@ -96,6 +97,10 @@ class ServiceObjectHolder {
return reinterpret_cast<uintptr_t>(this->dispatch_table);
}
bool IsMitmObject() const {
return this->srv->IsMitmObject();
}
/* Default constructor, move constructor, move assignment operator. */
ServiceObjectHolder() : srv(nullptr), dispatch_table(nullptr) { }

View File

@ -39,6 +39,8 @@ class IMitmServiceObject : public IServiceObject {
return this->process_id;
}
virtual bool IsMitmObject() const override { return true; }
static bool ShouldMitm(u64 pid, u64 tid);
protected:

View File

@ -122,15 +122,28 @@ class MitmSession final : public ServiceSession {
case DomainMessageType_Invalid:
return ResultKernelConnectionClosed;
case DomainMessageType_Close:
rc = ForwardRequest(ctx);
if (R_SUCCEEDED(rc)) {
ctx->obj_holder->GetServiceObject<IDomainObject>()->FreeObject(ctx->request.InThisObjectId);
{
auto sub_obj = ctx->obj_holder->GetServiceObject<IDomainObject>()->GetObject(ctx->request.InThisObjectId);
if (sub_obj == nullptr || (!sub_obj)) {
rc = ForwardRequest(ctx);
return rc;
}
if (sub_obj->IsMitmObject()) {
rc = ForwardRequest(ctx);
if (R_SUCCEEDED(rc)) {
ctx->obj_holder->GetServiceObject<IDomainObject>()->FreeObject(ctx->request.InThisObjectId);
}
} else {
rc = ctx->obj_holder->GetServiceObject<IDomainObject>()->FreeObject(ctx->request.InThisObjectId);
}
if (R_SUCCEEDED(rc) && ctx->request.InThisObjectId == serviceGetObjectId(this->forward_service.get()) && !this->service_post_process_ctx->closed) {
/* If we're not longer MitMing anything, we'll no longer do any postprocessing. */
this->service_post_process_ctx->closed = true;
}
return rc;
}
case DomainMessageType_SendMessage:
{
auto sub_obj = ctx->obj_holder->GetServiceObject<IDomainObject>()->GetObject(ctx->request.InThisObjectId);

View File

@ -364,7 +364,7 @@ class WaitableManager : public SessionManagerBase {
virtual Result ReserveSpecificObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return ResultServiceFrameworkOutOfDomainEntries;
}
if (this->domain_objects[object_id-1].owner == nullptr) {
@ -376,7 +376,7 @@ class WaitableManager : public SessionManagerBase {
virtual void SetObject(IDomainObject *domain, u32 object_id, ServiceObjectHolder&& holder) override {
std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return;
}
if (this->domain_objects[object_id-1].owner == domain) {
@ -386,7 +386,7 @@ class WaitableManager : public SessionManagerBase {
virtual ServiceObjectHolder *GetObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return nullptr;
}
if (this->domain_objects[object_id-1].owner == domain) {
@ -397,7 +397,7 @@ class WaitableManager : public SessionManagerBase {
virtual Result FreeObject(IDomainObject *domain, u32 object_id) override {
std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return ResultHipcDomainObjectNotFound;
}
if (this->domain_objects[object_id-1].owner == domain) {
@ -410,7 +410,7 @@ class WaitableManager : public SessionManagerBase {
virtual Result ForceFreeObject(u32 object_id) override {
std::scoped_lock lk{this->domain_lock};
if (object_id > ManagerOptions::MaxDomainObjects) {
if (object_id > ManagerOptions::MaxDomainObjects || object_id < MinimumDomainId) {
return ResultHipcDomainObjectNotFound;
}
if (this->domain_objects[object_id-1].owner != nullptr) {