/*
 * Copyright (c) Atmosphère-NX
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
#ifdef MESOSPHERE_USE_STUBBED_SVC_TABLES
#include 
#endif
#include 
#include 
namespace ams::kern::svc {
    /*
     * Hand-written entry points that the table builders below install directly,
     * overriding whatever the generated wrapper for the same id would be.
     * Only their prototypes appear in this file.
     */

    /* 64-bit ABI variants: light IPC and ReturnFromException. */
    void CallSendSyncRequestLight64();
    void CallReplyAndReceiveLight64();
    void CallReturnFromException64();

    /* 64From32 ABI variants: light IPC, ReturnFromException, and the (unsupported) CallSecureMonitor. */
    void CallSendSyncRequestLight64From32();
    void CallReplyAndReceiveLight64From32();
    void CallReturnFromException64From32();
    void CallCallSecureMonitor64From32();
    namespace {
        /*
         * Generate one class per SVC definition listed by AMS_SVC_FOREACH_KERN_DEFINITION.
         * Each class exposes static Call64() and Call64From32() entry points; these are the
         * functions stored into the dispatch tables below (and used by SvcTablePatcher).
         */
        #ifndef MESOSPHERE_USE_STUBBED_SVC_TABLES
            /* Real tables: forward to KernelSvcWrapper instantiated with ::ams::kern::svc::NAME##64 and NAME##64From32. */
            #define DECLARE_SVC_STRUCT(ID, RETURN_TYPE, NAME, ...)                                                                        \
                class NAME {                                                                                                              \
                    private:                                                                                                              \
                        using Impl = ::ams::svc::codegen::KernelSvcWrapper<::ams::kern::svc::NAME##64, ::ams::kern::svc::NAME##64From32>; \
                    public:                                                                                                               \
                        static NOINLINE void       Call64() { return Impl::Call64(); }                                                    \
                        static NOINLINE void Call64From32() { return Impl::Call64From32(); }                                              \
                };
        #else
            /* Stubbed tables: every entry point panics with the name of the svc that was invoked. */
            #define DECLARE_SVC_STRUCT(ID, RETURN_TYPE, NAME, ...)                                                                        \
                class NAME {                                                                                                              \
                    public:                                                                                                               \
                        static NOINLINE void Call64()       { MESOSPHERE_PANIC("Stubbed Svc"#NAME"64 was called"); }                      \
                        static NOINLINE void Call64From32() { MESOSPHERE_PANIC("Stubbed Svc"#NAME"64From32 was called"); }                \
                };
        #endif

        /* Set omit-frame-pointer to prevent GCC from emitting MOV X29, SP instructions; the wrappers are also built at -O3. */
        #pragma GCC push_options
        #pragma GCC optimize ("-O3")
        #pragma GCC optimize ("omit-frame-pointer")
            AMS_SVC_FOREACH_KERN_DEFINITION(DECLARE_SVC_STRUCT, _)
        #pragma GCC pop_options

        /* DECLARE_SVC_STRUCT is only needed for the expansion above; like AMS_KERN_SVC_SET_TABLE_ENTRY, don't leak it. */
        #undef DECLARE_SVC_STRUCT
        constexpr const std::array SvcTable64From32Impl = [] {
            std::array table = {};
            #define AMS_KERN_SVC_SET_TABLE_ENTRY(ID, RETURN_TYPE, NAME, ...) \
                if (table[ID] == nullptr) { table[ID] = NAME::Call64From32; }
            AMS_SVC_FOREACH_KERN_DEFINITION(AMS_KERN_SVC_SET_TABLE_ENTRY, _)
            #undef AMS_KERN_SVC_SET_TABLE_ENTRY
            table[svc::SvcId_SendSyncRequestLight] = CallSendSyncRequestLight64From32;
            table[svc::SvcId_ReplyAndReceiveLight] = CallReplyAndReceiveLight64From32;
            table[svc::SvcId_ReturnFromException]  = CallReturnFromException64From32;
            table[svc::SvcId_CallSecureMonitor]    = CallCallSecureMonitor64From32;
            return table;
        }();
        constexpr const std::array SvcTable64Impl = [] {
            std::array table = {};
            #define AMS_KERN_SVC_SET_TABLE_ENTRY(ID, RETURN_TYPE, NAME, ...) \
                if (table[ID] == nullptr) { table[ID] = NAME::Call64; }
            AMS_SVC_FOREACH_KERN_DEFINITION(AMS_KERN_SVC_SET_TABLE_ENTRY, _)
            #undef AMS_KERN_SVC_SET_TABLE_ENTRY
            table[svc::SvcId_SendSyncRequestLight] = CallSendSyncRequestLight64;
            table[svc::SvcId_ReplyAndReceiveLight] = CallReplyAndReceiveLight64;
            table[svc::SvcId_ReturnFromException]  = CallReturnFromException64;
            return table;
        }();
        constexpr bool IsValidSvcTable(const std::array &table) {
            for (size_t i = 0; i < NumSupervisorCalls; i++) {
                if (table[i] != nullptr) {
                    return true;
                }
            }
            return false;
        }
        static_assert(IsValidSvcTable(SvcTable64Impl));
        static_assert(IsValidSvcTable(SvcTable64From32Impl));
    }
    /*
     * The dispatch tables themselves, copied from the compile-time builders above.
     * constinit guarantees constant initialization, so they are populated before any dynamic
     * initializer (including SvcTablePatcher::s_instance below) runs.
     */
    constinit const std::array<SvcTableEntry, NumSupervisorCalls> SvcTable64From32 = SvcTable64From32Impl;
    constinit const std::array<SvcTableEntry, NumSupervisorCalls> SvcTable64       = SvcTable64Impl;

    /* Used by SvcTablePatcher to replace table entry `id` with `entry` (its definition is not in this file). */
    void PatchSvcTableEntry(const SvcTableEntry *table, u32 id, SvcTableEntry entry);
    namespace {
        /* NOTE: Although the SVC tables are constants, our global constructor will run before .rodata is protected R--. */
        class SvcTablePatcher {
            private:
                using SvcTable = std::array;
            private:
                static SvcTablePatcher s_instance;
            private:
                ALWAYS_INLINE const SvcTableEntry *GetTableData(const SvcTable *table) {
                    if (table != nullptr) {
                        return table->data();
                    } else {
                        return nullptr;
                    }
                }
                NOINLINE void PatchTables(const SvcTableEntry *table_64, const SvcTableEntry *table_64_from_32) {
                    /* Get the target firmware. */
                    const auto target_fw = kern::GetTargetFirmware();
                    /* 10.0.0 broke the ABI for QueryIoMapping. */
                    if (target_fw < TargetFirmware_10_0_0) {
                        if (table_64)         { ::ams::kern::svc::PatchSvcTableEntry(table_64,         svc::SvcId_QueryIoMapping, LegacyQueryIoMapping::Call64); }
                        if (table_64_from_32) { ::ams::kern::svc::PatchSvcTableEntry(table_64_from_32, svc::SvcId_QueryIoMapping, LegacyQueryIoMapping::Call64From32); }
                    }
                    /* 6.0.0 broke the ABI for GetFutureThreadInfo, and renamed it to GetDebugFutureThreadInfo. */
                    if (target_fw < TargetFirmware_6_0_0) {
                        static_assert(svc::SvcId_GetDebugFutureThreadInfo == svc::SvcId_LegacyGetFutureThreadInfo);
                        if (table_64)         { ::ams::kern::svc::PatchSvcTableEntry(table_64,         svc::SvcId_GetDebugFutureThreadInfo, LegacyGetFutureThreadInfo::Call64); }
                        if (table_64_from_32) { ::ams::kern::svc::PatchSvcTableEntry(table_64_from_32, svc::SvcId_GetDebugFutureThreadInfo, LegacyGetFutureThreadInfo::Call64From32); }
                    }
                    /* 3.0.0 broke the ABI for ContinueDebugEvent. */
                    if (target_fw < TargetFirmware_3_0_0) {
                        if (table_64)         { ::ams::kern::svc::PatchSvcTableEntry(table_64,         svc::SvcId_ContinueDebugEvent, LegacyContinueDebugEvent::Call64); }
                        if (table_64_from_32) { ::ams::kern::svc::PatchSvcTableEntry(table_64_from_32, svc::SvcId_ContinueDebugEvent, LegacyContinueDebugEvent::Call64From32); }
                    }
                }
            public:
                SvcTablePatcher(const SvcTable *table_64, const SvcTable *table_64_from_32) {
                    PatchTables(GetTableData(table_64), GetTableData(table_64_from_32));
                }
        };
        SvcTablePatcher SvcTablePatcher::s_instance(std::addressof(SvcTable64), std::addressof(SvcTable64From32));
    }
}