Skip to content

Commit

Permalink
Merge pull request #155 from dorssel/ctr_oneshot
Browse files Browse the repository at this point in the history
Refactor one-shot AesCtr
  • Loading branch information
dorssel authored Jan 12, 2025
2 parents bc75544 + e8f6b4b commit 2aa198e
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 136 deletions.
138 changes: 60 additions & 78 deletions AesExtra/AesCtr.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: MIT

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;

Expand Down Expand Up @@ -138,13 +139,11 @@ public override void GenerateKey()
}

#region Modern_SymmetricAlgorithm
bool TryTransformCtr(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv, Span<byte> destination, out int bytesWritten)
void OneShot(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv, Span<byte> destination)
{
if (destination.Length < input.Length)
{
bytesWritten = 0;
return false;
}
Debug.Assert(iv.Length == BLOCKSIZE);
Debug.Assert(destination.Length >= input.Length);

using var transform = new AesCtrTransform(Key, iv);
var inputSlice = input;
var destinationSlice = destination;
Expand All @@ -164,115 +163,98 @@ bool TryTransformCtr(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv, Span<byte>
block[0..inputSlice.Length].CopyTo(destinationSlice);
CryptographicOperations.ZeroMemory(block);
}
bytesWritten = input.Length;
return true;
}

byte[] TransformCtr(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv)
{
var output = new byte[input.Length];
_ = TryTransformCtr(input, iv, output, out _);
return output;
}

int TransformCtr(ReadOnlySpan<byte> plaintext, ReadOnlySpan<byte> iv, Span<byte> destination)
{
return TryTransformCtr(plaintext, iv, destination, out var bytesWritten) ? bytesWritten
: throw new ArgumentException("Destination is too short.", nameof(destination));
}

/// <summary>
/// TODO
/// </summary>
/// <param name="plaintext">TODO</param>
/// <param name="input">TODO</param>
/// <param name="iv">TODO</param>
/// <returns>TODO</returns>
public byte[] EncryptCtr(byte[] plaintext, byte[] iv)
public byte[] TransformCtr(byte[] input, byte[] iv)
{
return TransformCtr(plaintext, iv);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="plaintext">TODO</param>
/// <param name="iv">TODO</param>
/// <returns>TODO</returns>
public byte[] EncryptCtr(ReadOnlySpan<byte> plaintext, ReadOnlySpan<byte> iv)
{
return TransformCtr(plaintext, iv);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="plaintext">TODO</param>
/// <param name="iv">TODO</param>
/// <param name="destination">TODO</param>
/// <returns>TODO</returns>
public int EncryptCtr(ReadOnlySpan<byte> plaintext, ReadOnlySpan<byte> iv, Span<byte> destination)
{
return TransformCtr(plaintext, iv, destination);
}
if (input is null)
{
throw new ArgumentNullException(nameof(input));
}
if (iv is null)
{
throw new ArgumentNullException(nameof(input));
}
if (iv.Length != BLOCKSIZE)
{
throw new ArgumentException("Specified initial counter (IV) does not match the block size for this algorithm.", nameof(iv));
}

/// <summary>
/// TODO
/// </summary>
/// <param name="plaintext">TODO</param>
/// <param name="iv">TODO</param>
/// <param name="destination">TODO</param>
/// <param name="bytesWritten">TODO</param>
/// <returns>TODO</returns>
public bool TryEncryptCtr(ReadOnlySpan<byte> plaintext, ReadOnlySpan<byte> iv, Span<byte> destination, out int bytesWritten)
{
return TryTransformCtr(plaintext, iv, destination, out bytesWritten);
var output = new byte[input.Length];
OneShot(input, iv, output);
return output;
}

/// <summary>
/// TODO
/// </summary>
/// <param name="ciphertext">TODO</param>
/// <param name="input">TODO</param>
/// <param name="iv">TODO</param>
/// <returns>TODO</returns>
public byte[] DecryptCtr(byte[] ciphertext, byte[] iv)
public byte[] TransformCtr(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv)
{
return TransformCtr(ciphertext, iv);
}
if (iv.Length != BLOCKSIZE)
{
throw new ArgumentException("Specified initial counter (IV) does not match the block size for this algorithm.", nameof(iv));
}

/// <summary>
/// TODO
/// </summary>
/// <param name="ciphertext">TODO</param>
/// <param name="iv">TODO</param>
/// <returns>TODO</returns>
public byte[] DecryptCtr(ReadOnlySpan<byte> ciphertext, ReadOnlySpan<byte> iv)
{
return TransformCtr(ciphertext, iv);
var output = new byte[input.Length];
OneShot(input, iv, output);
return output;
}

/// <summary>
/// TODO
/// </summary>
/// <param name="ciphertext">TODO</param>
/// <param name="input">TODO</param>
/// <param name="iv">TODO</param>
/// <param name="destination">TODO</param>
/// <returns>TODO</returns>
public int DecryptCtr(ReadOnlySpan<byte> ciphertext, ReadOnlySpan<byte> iv, Span<byte> destination)
public int TransformCtr(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv, Span<byte> destination)
{
return TransformCtr(ciphertext, iv, destination);
if (iv.Length != BLOCKSIZE)
{
throw new ArgumentException("Specified initial counter (IV) does not match the block size for this algorithm.", nameof(iv));
}
if (destination.Length < input.Length)
{
throw new ArgumentException("Destination is too short.", nameof(destination));
}

OneShot(input, iv, destination);
return input.Length;
}

/// <summary>
/// TODO
/// </summary>
/// <param name="ciphertext">TODO</param>
/// <param name="input">TODO</param>
/// <param name="iv">TODO</param>
/// <param name="destination">TODO</param>
/// <param name="bytesWritten">TODO</param>
/// <returns>TODO</returns>
public bool TryDecryptCtr(ReadOnlySpan<byte> ciphertext, ReadOnlySpan<byte> iv, Span<byte> destination, out int bytesWritten)
public bool TryTransformCtr(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv, Span<byte> destination, out int bytesWritten)
{
return TryTransformCtr(ciphertext, iv, destination, out bytesWritten);
if (iv.Length != BLOCKSIZE)
{
throw new ArgumentException("Specified initial counter (IV) does not match the block size for this algorithm.", nameof(iv));
}

if (destination.Length < input.Length)
{
bytesWritten = 0;
return false;
}

OneShot(input, iv, destination);
bytesWritten = input.Length;
return true;
}
#endregion
}
128 changes: 86 additions & 42 deletions UnitTests/AesCtr_KAT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,106 +46,150 @@ public void Encrypt_Read(NistAesCtrSampleTestVector testVector)
[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void EncryptCtr_Bytes(NistAesCtrSampleTestVector testVector)
public void Decrypt_Write(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
var ciphertext = aes.EncryptCtr(testVector.Plaintext.ToArray(), testVector.InitialCounter.ToArray());
CollectionAssert.AreEqual(testVector.Ciphertext.ToArray(), ciphertext);
aes.IV = testVector.InitialCounter.ToArray();
using var ciphertextStream = new MemoryStream(testVector.Ciphertext.ToArray());
using var plaintextStream = new MemoryStream();
{
using var decryptor = aes.CreateDecryptor();
using var decryptorStream = new CryptoStream(plaintextStream, decryptor, CryptoStreamMode.Write);
ciphertextStream.CopyTo(decryptorStream);
}
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), plaintextStream.ToArray());
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void EncryptCtr_Span(NistAesCtrSampleTestVector testVector)
public void Decrypt_Read(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
var ciphertext = aes.EncryptCtr(testVector.Plaintext.Span, testVector.InitialCounter.Span);
CollectionAssert.AreEqual(testVector.Ciphertext.ToArray(), ciphertext);
aes.IV = testVector.InitialCounter.ToArray();
using var ciphertextStream = new MemoryStream(testVector.Ciphertext.ToArray());
using var plaintextStream = new MemoryStream();
{
using var decryptor = aes.CreateDecryptor();
using var decryptorStream = new CryptoStream(ciphertextStream, decryptor, CryptoStreamMode.Read);
decryptorStream.CopyTo(plaintextStream);
}
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), plaintextStream.ToArray());
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void EncryptCtr_Destination(NistAesCtrSampleTestVector testVector)
public void Encrypt_TransformCtr_Array_Array(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
var ciphertext = new byte[testVector.Plaintext.Length];
var count = aes.EncryptCtr(testVector.Plaintext.Span, testVector.InitialCounter.Span, ciphertext);
CollectionAssert.AreEqual(testVector.Ciphertext.ToArray(), ciphertext.ToArray());
Assert.AreEqual(testVector.Plaintext.Length, count);

var destination = aes.TransformCtr(testVector.Plaintext.ToArray(), testVector.InitialCounter.ToArray());

CollectionAssert.AreEqual(testVector.Ciphertext.ToArray(), destination);
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void Decrypt_Write(NistAesCtrSampleTestVector testVector)
public void Encrypt_TransformCtr_ReadOnlySpan_ReadOnlySpan(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
aes.IV = testVector.InitialCounter.ToArray();
using var ciphertextStream = new MemoryStream(testVector.Ciphertext.ToArray());
using var plaintextStream = new MemoryStream();
{
using var decryptor = aes.CreateDecryptor();
using var decryptorStream = new CryptoStream(plaintextStream, decryptor, CryptoStreamMode.Write);
ciphertextStream.CopyTo(decryptorStream);
}
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), plaintextStream.ToArray());

var destination = aes.TransformCtr(testVector.Plaintext.Span, testVector.InitialCounter.Span);

CollectionAssert.AreEqual(testVector.Ciphertext.ToArray(), destination);
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void Decrypt_Read(NistAesCtrSampleTestVector testVector)
public void Encrypt_TransformCtr_ReadOnlySpan_ReadOnlySpan_Span(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
aes.IV = testVector.InitialCounter.ToArray();
using var ciphertextStream = new MemoryStream(testVector.Ciphertext.ToArray());
using var plaintextStream = new MemoryStream();
{
using var decryptor = aes.CreateDecryptor();
using var decryptorStream = new CryptoStream(ciphertextStream, decryptor, CryptoStreamMode.Read);
decryptorStream.CopyTo(plaintextStream);
}
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), plaintextStream.ToArray());
var destination = new byte[testVector.Ciphertext.Length];

var count = aes.TransformCtr(testVector.Plaintext.Span, testVector.InitialCounter.Span, destination);

Assert.AreEqual(testVector.Ciphertext.Length, count);
CollectionAssert.AreEqual(testVector.Ciphertext.ToArray(), destination);
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void DecryptCtr_Bytes(NistAesCtrSampleTestVector testVector)
public void Encrypt_TryTransformCtr(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
var plaintext = aes.DecryptCtr(testVector.Ciphertext.ToArray(), testVector.InitialCounter.ToArray());
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), plaintext);
var destination = new byte[testVector.Ciphertext.Length];

var success = aes.TryTransformCtr(testVector.Plaintext.Span, testVector.InitialCounter.Span, destination, out var bytesWritten);

Assert.IsTrue(success);
Assert.AreEqual(testVector.Ciphertext.Length, bytesWritten);
CollectionAssert.AreEqual(testVector.Ciphertext.ToArray(), destination);
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void DecryptCtr_Span(NistAesCtrSampleTestVector testVector)
public void Decrypt_TransformCtr_Array_Array(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
var plaintext = aes.DecryptCtr(testVector.Ciphertext.Span, testVector.InitialCounter.Span);
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), plaintext);

var destination = aes.TransformCtr(testVector.Ciphertext.ToArray(), testVector.InitialCounter.ToArray());

CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), destination);
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void DecryptCtr_Destination(NistAesCtrSampleTestVector testVector)
public void Decrypt_TransformCtr_ReadOnlySpan_ReadOnlySpan(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
var plaintext = new byte[testVector.Ciphertext.Length];
var count = aes.DecryptCtr(testVector.Ciphertext.Span, testVector.InitialCounter.Span, plaintext);
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), plaintext.ToArray());
Assert.AreEqual(testVector.Ciphertext.Length, count);

var destination = aes.TransformCtr(testVector.Ciphertext.Span, testVector.InitialCounter.Span);

CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), destination);
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void Decrypt_TransformCtr_ReadOnlySpan_ReadOnlySpan_Span(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
var destination = new byte[testVector.Plaintext.Length];

var count = aes.TransformCtr(testVector.Ciphertext.Span, testVector.InitialCounter.Span, destination);

Assert.AreEqual(testVector.Plaintext.Length, count);
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), destination);
}

[TestMethod]
[TestCategory("NIST")]
[NistAesCtrSampleDataSource]
public void Decrypt_TryTransformCtr(NistAesCtrSampleTestVector testVector)
{
using var aes = AesCtr.Create();
aes.Key = testVector.Key.ToArray();
var destination = new byte[testVector.Plaintext.Length];

var success = aes.TryTransformCtr(testVector.Ciphertext.Span, testVector.InitialCounter.Span, destination, out var bytesWritten);

Assert.IsTrue(success);
Assert.AreEqual(testVector.Plaintext.Length, bytesWritten);
CollectionAssert.AreEqual(testVector.Plaintext.ToArray(), destination);
}
}
Loading

0 comments on commit 2aa198e

Please sign in to comment.