mirror of
				https://github.com/Atmosphere-NX/Atmosphere-libs.git
				synced 2025-10-31 19:45:51 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			270 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			270 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /*
 | |
|  * Copyright (c) Atmosphère-NX
 | |
|  *
 | |
|  * This program is free software; you can redistribute it and/or modify it
 | |
|  * under the terms and conditions of the GNU General Public License,
 | |
|  * version 2, as published by the Free Software Foundation.
 | |
|  *
 | |
|  * This program is distributed in the hope it will be useful, but WITHOUT
 | |
|  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 | |
|  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 | |
|  * more details.
 | |
|  *
 | |
|  * You should have received a copy of the GNU General Public License
 | |
|  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 | |
|  */
 | |
| #include <vapours.hpp>
 | |
| #include "crypto_aes_impl.arch.x64.hpp"
 | |
| 
 | |
| namespace ams::crypto::impl {
 | |
| 
 | |
|     template<> void CtrModeImpl<AesEncryptor128>::ProcessBlocks(u8 *dst, const u8 *src, size_t num_blocks) {
 | |
|         /* Check pre-conditions. */
 | |
|         AMS_ASSERT(src != nullptr);
 | |
|         AMS_ASSERT(dst != nullptr);
 | |
| 
 | |
|         /* If we have aes-ni, use an optimized impl. */
 | |
|         if (IsAesNiAvailable()) {
 | |
|             /* Load all keys into sse2 registers. */
 | |
|             const u8 *raw_round_keys = m_block_cipher->GetRoundKey();
 | |
|             const __m128i round_keys[AesEncryptor128::RoundKeySize / BlockSize] = {
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  0)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  1)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  2)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  3)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  4)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  5)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  6)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  7)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  8)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize *  9)),
 | |
|                 _mm_loadu_si128(reinterpret_cast<const __m128i *>(raw_round_keys + BlockSize * 10)),
 | |
|             };
 | |
|             static_assert(AesEncryptor128::RoundKeySize / BlockSize == 11);
 | |
| 
 | |
|             /* Declare constant for counter math. */
 | |
|             const __m128i One = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1);
 | |
| 
 | |
|             /* Process eight blocks at a time, while we can. */
 | |
|             constexpr const auto UnrolledBlockCount = 8;
 | |
|             constexpr const auto CounterThreshold   = static_cast<u8>(0x100 - UnrolledBlockCount);
 | |
| 
 | |
|             /* Load the counter. */
 | |
|             auto counter = _mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter));
 | |
| 
 | |
|             size_t cur_blocks;
 | |
|             for (cur_blocks = 0; cur_blocks + UnrolledBlockCount <= num_blocks; cur_blocks += UnrolledBlockCount) {
 | |
|                 __m128i b0;
 | |
|                 __m128i b1;
 | |
|                 __m128i b2;
 | |
|                 __m128i b3;
 | |
|                 __m128i b4;
 | |
|                 __m128i b5;
 | |
|                 __m128i b6;
 | |
|                 __m128i b7;
 | |
| 
 | |
|                 __m128i key = round_keys[0];
 | |
| 
 | |
|                 /* Get the last byte of the block. */
 | |
|                 static_assert(util::IsLittleEndian());
 | |
|                 const u8 counter_val = _mm_extract_epi16(counter, 7) >> BITSIZEOF(u8);
 | |
| 
 | |
|                 /* Do initial encryption of each block. */
 | |
|                 if (CounterThreshold <= counter_val) {
 | |
|                     /* We'll overwrap, so take slow path for counter. */
 | |
|                     _mm_storeu_si128(reinterpret_cast<__m128i *>(m_counter), counter);
 | |
| 
 | |
|                     b0 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter)), key);
 | |
|                     this->IncrementCounter();
 | |
|                     b1 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter)), key);
 | |
|                     this->IncrementCounter();
 | |
|                     b2 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter)), key);
 | |
|                     this->IncrementCounter();
 | |
|                     b3 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter)), key);
 | |
|                     this->IncrementCounter();
 | |
|                     b4 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter)), key);
 | |
|                     this->IncrementCounter();
 | |
|                     b5 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter)), key);
 | |
|                     this->IncrementCounter();
 | |
|                     b6 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter)), key);
 | |
|                     this->IncrementCounter();
 | |
|                     b7 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter)), key);
 | |
|                     this->IncrementCounter();
 | |
| 
 | |
|                     counter = _mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter));
 | |
|                 } else {
 | |
|                     /* We can take the fast path for the counter. */
 | |
|                     b0 = _mm_xor_si128(counter, key);
 | |
|                     counter = _mm_add_epi64(counter, One);
 | |
|                     b1 = _mm_xor_si128(counter, key);
 | |
|                     counter = _mm_add_epi64(counter, One);
 | |
|                     b2 = _mm_xor_si128(counter, key);
 | |
|                     counter = _mm_add_epi64(counter, One);
 | |
|                     b3 = _mm_xor_si128(counter, key);
 | |
|                     counter = _mm_add_epi64(counter, One);
 | |
|                     b4 = _mm_xor_si128(counter, key);
 | |
|                     counter = _mm_add_epi64(counter, One);
 | |
|                     b5 = _mm_xor_si128(counter, key);
 | |
|                     counter = _mm_add_epi64(counter, One);
 | |
|                     b6 = _mm_xor_si128(counter, key);
 | |
|                     counter = _mm_add_epi64(counter, One);
 | |
|                     b7 = _mm_xor_si128(counter, key);
 | |
|                     counter = _mm_add_epi64(counter, One);
 | |
|                 }
 | |
| 
 | |
|                 /* Do encryption for all rounds. */
 | |
|                 key = round_keys[1];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[2];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[3];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[4];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[5];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[6];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[7];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[8];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[9];
 | |
|                 b0 = _mm_aesenc_si128(b0, key);
 | |
|                 b1 = _mm_aesenc_si128(b1, key);
 | |
|                 b2 = _mm_aesenc_si128(b2, key);
 | |
|                 b3 = _mm_aesenc_si128(b3, key);
 | |
|                 b4 = _mm_aesenc_si128(b4, key);
 | |
|                 b5 = _mm_aesenc_si128(b5, key);
 | |
|                 b6 = _mm_aesenc_si128(b6, key);
 | |
|                 b7 = _mm_aesenc_si128(b7, key);
 | |
| 
 | |
|                 key = round_keys[10];
 | |
|                 b0 = _mm_aesenclast_si128(b0, key);
 | |
|                 b1 = _mm_aesenclast_si128(b1, key);
 | |
|                 b2 = _mm_aesenclast_si128(b2, key);
 | |
|                 b3 = _mm_aesenclast_si128(b3, key);
 | |
|                 b4 = _mm_aesenclast_si128(b4, key);
 | |
|                 b5 = _mm_aesenclast_si128(b5, key);
 | |
|                 b6 = _mm_aesenclast_si128(b6, key);
 | |
|                 b7 = _mm_aesenclast_si128(b7, key);
 | |
| 
 | |
|                 /* Write the blocks. */
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst + BlockSize * 0), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src + BlockSize * 0)), b0));
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst + BlockSize * 1), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src + BlockSize * 1)), b1));
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst + BlockSize * 2), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src + BlockSize * 2)), b2));
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst + BlockSize * 3), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src + BlockSize * 3)), b3));
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst + BlockSize * 4), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src + BlockSize * 4)), b4));
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst + BlockSize * 5), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src + BlockSize * 5)), b5));
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst + BlockSize * 6), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src + BlockSize * 6)), b6));
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst + BlockSize * 7), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src + BlockSize * 7)), b7));
 | |
| 
 | |
|                 src += BlockSize * UnrolledBlockCount;
 | |
|                 dst += BlockSize * UnrolledBlockCount;
 | |
|             }
 | |
| 
 | |
|             /* Store the updated counter. */
 | |
|             _mm_storeu_si128(reinterpret_cast<__m128i *>(m_counter), counter);
 | |
| 
 | |
|             /* Process blocks one at a time. */
 | |
|             for (/* ... */; cur_blocks < num_blocks; ++cur_blocks) {
 | |
|                 /* Load current counter. */
 | |
|                 __m128i b = _mm_loadu_si128(reinterpret_cast<const __m128i *>(m_counter));
 | |
| 
 | |
|                 /* Do aes rounds. */
 | |
|                 b = _mm_xor_si128(b, round_keys[0]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[1]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[2]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[3]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[4]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[5]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[6]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[7]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[8]);
 | |
|                 b = _mm_aesenc_si128(b, round_keys[9]);
 | |
|                 b = _mm_aesenclast_si128(b, round_keys[10]);
 | |
| 
 | |
|                 /* Write the block. */
 | |
|                 _mm_storeu_si128(reinterpret_cast<__m128i *>(dst), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src)), b));
 | |
| 
 | |
|                 /* Advance. */
 | |
|                 src += BlockSize;
 | |
|                 dst += BlockSize;
 | |
|                 this->IncrementCounter();
 | |
|             }
 | |
|         } else {
 | |
|             /* Fall back to the default implementation. */
 | |
|             while (num_blocks--) {
 | |
|                 this->ProcessBlock(dst, src, BlockSize);
 | |
|                 dst += BlockSize;
 | |
|                 src += BlockSize;
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
| }
 |