/* * Copyright (c) 2018-2020 Atmosphère-NX * * This program is free software; you can redistribute it and/or modify it * under the terms and conditions of the GNU General Public License, * version 2, as published by the Free Software Foundation. * * This program is distributed in the hope it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for * more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #pragma once #include #include #include #include namespace ams::crypto::impl { template requires HashFunction class RsaPssImpl { NON_COPYABLE(RsaPssImpl); NON_MOVEABLE(RsaPssImpl); public: static constexpr size_t HashSize = Hash::HashSize; private: static constexpr u8 TailMagic = 0xBC; private: static void ComputeHashWithPadding(void *dst, const u8 *user_hash, size_t user_hash_size, const void *salt, size_t salt_size) { AMS_ASSERT(user_hash_size == HashSize); AMS_UNUSED(user_hash_size); /* Initialize our buffer. */ u8 buf[8 + HashSize]; std::memset(buf, 0, 8); std::memcpy(buf + 8, user_hash, HashSize); ON_SCOPE_EXIT { ClearMemory(buf, sizeof(buf)); }; /* Calculate our hash. */ Hash hash; hash.Initialize(); hash.Update(buf, sizeof(buf)); hash.Update(salt, salt_size); hash.GetHash(dst, HashSize); } static void ApplyMGF1(u8 *dst, size_t dst_size, const void *src, size_t src_size) { u8 buf[HashSize]; ON_SCOPE_EXIT { ClearMemory(buf, sizeof(buf)); }; const size_t required_iters = (dst_size + HashSize - 1) / HashSize; for (size_t i = 0; i < required_iters; i++) { Hash hash; hash.Initialize(); hash.Update(src, src_size); const u32 tmp = util::ConvertToBigEndian(static_cast(i)); hash.Update(std::addressof(tmp), sizeof(tmp)); hash.GetHash(buf, HashSize); const size_t start = HashSize * i; const size_t end = std::min(dst_size, start + HashSize); for (size_t j = start; j < end; j++) { dst[j] ^= buf[j - start]; } } } public: RsaPssImpl() { /* ... */ } bool Verify(u8 *buf, size_t size, const u8 *hash, size_t hash_size) { /* Validate sanity byte. */ bool is_valid = buf[size - 1] == TailMagic; /* Decrypt maskedDB */ const size_t db_len = size - HashSize - 1; u8 *db = buf; u8 *h = db + db_len; ApplyMGF1(db, db_len, h, HashSize); /* Apply lmask. */ db[0] &= 0x7F; /* Verify that DB is of the form 0000...0001 */ s32 salt_ofs = 0; { int looking_for_one = 1; int invalid_db_padding = 0; int is_zero; int is_one; for (size_t i = 0; i < db_len; /* ... */) { is_zero = (db[i] == 0); is_one = (db[i] == 1); salt_ofs += (looking_for_one & is_one) * (static_cast(++i)); looking_for_one &= ~is_one; invalid_db_padding |= (looking_for_one & ~is_zero); } is_valid &= (invalid_db_padding == 0); } /* Verify salt. */ const u8 *salt = db + salt_ofs; const size_t salt_size = db_len - salt_ofs; is_valid &= (salt_size != 0); is_valid &= (salt_size != db_len); /* Verify hash. */ u8 cmp_hash[HashSize]; ON_SCOPE_EXIT { ClearMemory(cmp_hash, sizeof(cmp_hash)); }; ComputeHashWithPadding(cmp_hash, hash, hash_size, salt, salt_size); is_valid &= IsSameBytes(cmp_hash, h, HashSize); /* Succeed if all our checks succeeded. */ return is_valid; } }; }