/*
 * Copyright (c) Atmosphère-NX
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see .
 */
#include 
#include "se_execute.hpp"
namespace ams::se {
    /* NOTE: This implementation is mostly copy/pasted from crypto::impl::RsaOaepImpl. */
    namespace {
        constexpr inline size_t HashSize = sizeof(Sha256Hash);
        constexpr inline u8 HeadMagic = 0x00;
        void ApplyMGF1(u8 *dst, size_t dst_size, const void *src, size_t src_size) {
            /* Check our pre-conditions. */
            AMS_ABORT_UNLESS(src_size <= RsaSize - (1 + HashSize));
            /* Create a buffer. */
            util::AlignedBuffer buf;
            u32 counter = 0;
            while (dst_size > 0) {
                /* Setup the current hash buffer. */
                const size_t cur_size = std::min(HashSize, dst_size);
                std::memcpy(static_cast(buf), src, src_size);
                {
                    u32 counter_be;
                    util::StoreBigEndian(std::addressof(counter_be), counter++);
                    std::memcpy(static_cast(buf) + src_size, std::addressof(counter_be), sizeof(counter_be));
                }
                /* Ensure se sees correct data. */
                hw::FlushDataCache(buf, src_size + sizeof(u32));
                hw::DataSynchronizationBarrierInnerShareable();
                /* Calculate the hash. */
                Sha256Hash hash;
                se::CalculateSha256(std::addressof(hash), buf, src_size + sizeof(u32));
                /* Mask the current output. */
                const u8 *mask = hash.bytes;
                for (size_t i = 0; i < cur_size; ++i) {
                    *(dst++) ^= *(mask++);
                }
                /* Advance. */
                dst_size -= cur_size;
            }
        }
    }
    size_t DecodeRsaOaepSha256(void *dst, size_t dst_size, void *src, size_t src_size, const void *label_digest, size_t label_digest_size) {
        /* Check our preconditions. */
        AMS_ABORT_UNLESS(src_size == RsaSize);
        AMS_ABORT_UNLESS(label_digest_size == HashSize);
        /* Get a byte-readable copy of the input. */
        u8 *buf = static_cast(src);
        /* Validate sanity byte. */
        bool is_valid = buf[0] == HeadMagic;
        /* Decrypt seed and masked db. */
        size_t db_len = src_size - HashSize - 1;
        u8 *seed = buf + 1;
        u8 *db   = seed + HashSize;
        ApplyMGF1(seed, HashSize, db, db_len);
        ApplyMGF1(db, db_len, seed, HashSize);
        /* Check the label digest. */
        is_valid &= crypto::IsSameBytes(label_digest, db, HashSize);
        /* Skip past the label digest. */
        db     += HashSize;
        db_len -= HashSize;
        /* Verify that DB is of the form 0000...0001 < message > */
        s32 msg_ofs = 0;
        {
            int looking_for_one = 1;
            int invalid_db_padding = 0;
            int is_zero;
            int is_one;
            for (size_t i = 0; i < db_len; /* ... */) {
                is_zero = (db[i] == 0);
                is_one  = (db[i] == 1);
                msg_ofs += (looking_for_one & is_one) * (static_cast(++i));
                looking_for_one &= ~is_one;
                invalid_db_padding |= (looking_for_one & ~is_zero);
            }
            is_valid &= (invalid_db_padding == 0);
        }
        /* If we're invalid, return zero size. */
        const size_t valid_msg_size = db_len - msg_ofs;
        const size_t msg_size       = std::min(dst_size, static_cast(is_valid) * valid_msg_size);
        /* Copy to output. */
        std::memcpy(dst, db + msg_ofs, msg_size);
        /* Return copied size. */
        return msg_size;
    }
}