/*
 * Copyright (c) Atmosphère-NX
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
#pragma once
#include 
#include 
#include 
#include 
#include 
namespace ams::crypto::impl {
    class XtsModeImpl {
        NON_COPYABLE(XtsModeImpl);
        NON_MOVEABLE(XtsModeImpl);
        public:
            /* TODO: More generic support. */
            static constexpr size_t BlockSize = 16;
            static constexpr size_t IvSize    = 16;
        private:
            enum State {
                State_None,
                State_Initialized,
                State_Processing,
                State_Done
            };
        private:
            u8 buffer[BlockSize];
            u8 tweak[BlockSize];
            u8 last_block[BlockSize];
            size_t num_buffered;
            const void *cipher_ctx;
            void (*cipher_func)(void *dst_block, const void *src_block, const void *cipher_ctx);
            State state;
        public:
            XtsModeImpl() : num_buffered(0), state(State_None) { /* ... */ }
            ~XtsModeImpl() {
                ClearMemory(this, sizeof(*this));
            }
        private:
            template
            static void EncryptBlockCallback(void *dst_block, const void *src_block, const void *cipher) {
                return static_cast(cipher)->EncryptBlock(dst_block, BlockCipher::BlockSize, src_block, BlockCipher::BlockSize);
            }
            template
            static void DecryptBlockCallback(void *dst_block, const void *src_block, const void *cipher) {
                return static_cast(cipher)->DecryptBlock(dst_block, BlockCipher::BlockSize, src_block, BlockCipher::BlockSize);
            }
            template
            void Initialize(const BlockCipher *cipher, const void *tweak, size_t tweak_size) {
                AMS_ASSERT(tweak_size == IvSize);
                AMS_UNUSED(tweak_size);
                cipher->EncryptBlock(this->tweak, IvSize, tweak, IvSize);
                this->num_buffered = 0;
                this->state = State_Initialized;
            }
            void ProcessBlock(u8 *dst, const u8 *src);
        public:
            template
            void InitializeEncryption(const BlockCipher1 *cipher1, const BlockCipher2 *cipher2, const void *tweak, size_t tweak_size) {
                static_assert(BlockCipher1::BlockSize == BlockSize);
                static_assert(BlockCipher2::BlockSize == BlockSize);
                this->cipher_ctx  = cipher1;
                this->cipher_func = EncryptBlockCallback;
                this->Initialize(cipher2, tweak, tweak_size);
            }
            template
            void InitializeDecryption(const BlockCipher1 *cipher1, const BlockCipher2 *cipher2, const void *tweak, size_t tweak_size) {
                static_assert(BlockCipher1::BlockSize == BlockSize);
                static_assert(BlockCipher2::BlockSize == BlockSize);
                this->cipher_ctx  = cipher1;
                this->cipher_func = DecryptBlockCallback;
                this->Initialize(cipher2, tweak, tweak_size);
            }
            template
            size_t Update(void *dst, size_t dst_size, const void *src, size_t src_size) {
                return this->UpdateGeneric(dst, dst_size, src, src_size);
            }
            template
            size_t ProcessBlocks(u8 *dst, const u8 *src, size_t num_blocks) {
                return this->ProcessBlocksGeneric(dst, src, num_blocks);
            }
            size_t GetBufferedDataSize() const {
                return this->num_buffered;
            }
            constexpr size_t GetBlockSize() const {
                return BlockSize;
            }
            size_t FinalizeEncryption(void *dst, size_t dst_size);
            size_t FinalizeDecryption(void *dst, size_t dst_size);
            size_t UpdateGeneric(void *dst, size_t dst_size, const void *src, size_t src_size);
            size_t ProcessBlocksGeneric(u8 *dst, const u8 *src, size_t num_blocks);
            size_t ProcessPartialData(u8 *dst, const u8 *src, size_t size);
            size_t ProcessRemainingData(u8 *dst, const u8 *src, size_t size);
    };
    /* Explicit specialization declarations of XtsModeImpl::Update; no definitions are provided in this header. */
    /* NOTE(review): all six lines read identically and carry no explicit template argument, which looks like */
    /* lost angle-bracket text. Presumably each names a distinct cipher type; confirm against the implementation file. */
    template<> size_t XtsModeImpl::Update(void *dst, size_t dst_size, const void *src, size_t src_size);
    template<> size_t XtsModeImpl::Update(void *dst, size_t dst_size, const void *src, size_t src_size);
    template<> size_t XtsModeImpl::Update(void *dst, size_t dst_size, const void *src, size_t src_size);
    template<> size_t XtsModeImpl::Update(void *dst, size_t dst_size, const void *src, size_t src_size);
    template<> size_t XtsModeImpl::Update(void *dst, size_t dst_size, const void *src, size_t src_size);
    template<> size_t XtsModeImpl::Update(void *dst, size_t dst_size, const void *src, size_t src_size);
}