/* * Copyright (c) 2018-2020 Atmosphère-NX * * This program is free software; you can redistribute it and/or modify it * under the terms and conditions of the GNU General Public License, * version 2, as published by the Free Software Foundation. * * This program is distributed in the hope it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for * more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #pragma once #include #include #include #include #include namespace ams::crypto::impl { template class CtrModeImpl { NON_COPYABLE(CtrModeImpl); NON_MOVEABLE(CtrModeImpl); public: static constexpr size_t KeySize = BlockCipher::KeySize; static constexpr size_t BlockSize = BlockCipher::BlockSize; static constexpr size_t IvSize = BlockCipher::BlockSize; private: enum State { State_None, State_Initialized, }; private: const BlockCipher *block_cipher; u8 counter[IvSize]; u8 encrypted_counter[BlockSize]; size_t buffer_offset; State state; public: CtrModeImpl() : state(State_None) { /* ... */ } ~CtrModeImpl() { ClearMemory(this, sizeof(*this)); } void Initialize(const BlockCipher *block_cipher, const void *iv, size_t iv_size) { this->Initialize(block_cipher, iv, iv_size, 0); } void Initialize(const BlockCipher *block_cipher, const void *iv, size_t iv_size, s64 offset) { AMS_ASSERT(iv_size == IvSize); AMS_ASSERT(offset >= 0); this->block_cipher = block_cipher; this->state = State_Initialized; this->SwitchMessage(iv, iv_size); if (offset >= 0) { u64 ctr_offset = offset / BlockSize; if (ctr_offset > 0) { this->IncrementCounter(ctr_offset); } if (size_t remaining = static_cast(offset % BlockSize); remaining != 0) { this->block_cipher->EncryptBlock(this->encrypted_counter, sizeof(this->encrypted_counter), this->counter, sizeof(this->counter)); this->IncrementCounter(); this->buffer_offset = remaining; } } } void SwitchMessage(const void *iv, size_t iv_size) { AMS_ASSERT(this->state == State_Initialized); AMS_ASSERT(iv_size == IvSize); std::memcpy(this->counter, iv, iv_size); this->buffer_offset = 0; } void IncrementCounter() { for (s32 i = IvSize - 1; i >= 0; --i) { if (++this->counter[i] != 0) { break; } } } size_t Update(void *_dst, size_t dst_size, const void *_src, size_t src_size) { AMS_ASSERT(this->state == State_Initialized); AMS_ASSERT(dst_size >= src_size); AMS_UNUSED(dst_size); u8 *dst = static_cast(_dst); const u8 *src = static_cast(_src); size_t remaining = src_size; if (this->buffer_offset > 0) { const size_t xor_size = std::min(BlockSize - this->buffer_offset, remaining); const u8 *ctr = this->encrypted_counter + this->buffer_offset; for (size_t i = 0; i < xor_size; i++) { dst[i] = src[i] ^ ctr[i]; } src += xor_size; dst += xor_size; remaining -= xor_size; this->buffer_offset += xor_size; if (this->buffer_offset == BlockSize) { this->buffer_offset = 0; } } if (remaining >= BlockSize) { const size_t num_blocks = remaining / BlockSize; this->ProcessBlocks(dst, src, num_blocks); const size_t processed_size = num_blocks * BlockSize; dst += processed_size; src += processed_size; remaining -= processed_size; } if (remaining > 0) { this->ProcessBlock(dst, src, remaining); this->buffer_offset = remaining; } return src_size; } private: void IncrementCounter(u64 count) { u64 _block[IvSize / sizeof(u64)] = {}; util::StoreBigEndian(std::addressof(_block[(IvSize / sizeof(u64)) - 1]), count); u16 acc = 0; const u8 *block = reinterpret_cast(_block); for (s32 i = IvSize - 1; i >= 0; --i) { acc += (this->counter[i] + block[i]); this->counter[i] = acc & 0xFF; acc >>= 8; } } void ProcessBlock(u8 *dst, const u8 *src, size_t src_size) { this->block_cipher->EncryptBlock(this->encrypted_counter, BlockSize, this->counter, IvSize); this->IncrementCounter(); for (size_t i = 0; i < src_size; i++) { dst[i] = src[i] ^ this->encrypted_counter[i]; } } void ProcessBlocks(u8 *dst, const u8 *src, size_t num_blocks); }; template inline void CtrModeImpl::ProcessBlocks(u8 *dst, const u8 *src, size_t num_blocks) { while (num_blocks--) { this->ProcessBlock(dst, src, BlockSize); dst += BlockSize; src += BlockSize; } } template<> void CtrModeImpl::ProcessBlocks(u8 *dst, const u8 *src, size_t num_blocks); template<> void CtrModeImpl::ProcessBlocks(u8 *dst, const u8 *src, size_t num_blocks); template<> void CtrModeImpl::ProcessBlocks(u8 *dst, const u8 *src, size_t num_blocks); }