Skip to content

Commit

Permalink
Merge pull request #150 from dorssel/fix_ctr_iv
Browse files Browse the repository at this point in the history
Fix IV handling for AesCtr
  • Loading branch information
dorssel authored Jan 10, 2025
2 parents 3f2f7a9 + 1f94d3c commit 87ea27f
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 82 deletions.
66 changes: 39 additions & 27 deletions AesExtra/AesCtr.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ public sealed class AesCtr
const CipherMode FixedModeValue = CipherMode.CTS; // DevSkim: ignore DS187371
const PaddingMode FixedPaddingValue = PaddingMode.None;
const int FixedFeedbackSizeValue = BLOCKSIZE * BitsPerByte;
static readonly byte[] BlockOfZeros = new byte[BLOCKSIZE];

/// <inheritdoc cref="Aes.Create()" />
public static new AesCtr Create()
Expand Down Expand Up @@ -97,18 +96,23 @@ public override int FeedbackSize
}
}

// CTR.Encrypt === CTR.Decrypt; the transform is entirely symmetric.
static AesCtrTransform CreateTransform(byte[] rgbKey, byte[]? rgbIV)
{
return rgbIV is not null ? new(rgbKey, rgbIV)
: throw new CryptographicException("The cipher mode specified requires that an initialization vector(IV) be used.");
}

/// <inheritdoc cref="AesManaged.CreateDecryptor(byte[], byte[])" />
public override ICryptoTransform CreateDecryptor(byte[] rgbKey, byte[]? rgbIV)
{
// CTR.Encrypt === CTR.Decrypt; the transform is entirely symmetric.
return new AesCtrTransform(rgbKey, rgbIV ?? BlockOfZeros);
return CreateTransform(rgbKey, rgbIV);
}

/// <inheritdoc cref="AesManaged.CreateEncryptor(byte[], byte[])" />
public override ICryptoTransform CreateEncryptor(byte[] rgbKey, byte[]? rgbIV)
{
// CTR.Encrypt === CTR.Decrypt; the transform is entirely symmetric.
return new AesCtrTransform(rgbKey, rgbIV ?? BlockOfZeros);
return CreateTransform(rgbKey, rgbIV);
}

/// <inheritdoc cref="AesManaged.GenerateIV" />
Expand All @@ -130,14 +134,14 @@ public override void GenerateKey()
}

#region Modern_SymmetricAlgorithm
bool TryTransformCtr(ReadOnlySpan<byte> input, Span<byte> destination, out int bytesWritten)
bool TryTransformCtr(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv, Span<byte> destination, out int bytesWritten)
{
if (destination.Length < input.Length)
{
bytesWritten = 0;
return false;
}
using var transform = new AesCtrTransform(Key, IVValue ?? BlockOfZeros);
using var transform = new AesCtrTransform(Key, iv);
var inputSlice = input;
var destinationSlice = destination;
while (inputSlice.Length >= BLOCKSIZE)
Expand All @@ -160,103 +164,111 @@ bool TryTransformCtr(ReadOnlySpan<byte> input, Span<byte> destination, out int b
return true;
}

byte[] TransformCtr(ReadOnlySpan<byte> input)
byte[] TransformCtr(ReadOnlySpan<byte> input, ReadOnlySpan<byte> iv)
{
var output = new byte[input.Length];
_ = TryTransformCtr(input, output, out _);
_ = TryTransformCtr(input, iv, output, out _);
return output;
}

int TransformCtr(ReadOnlySpan<byte> plaintext, Span<byte> destination)
int TransformCtr(ReadOnlySpan<byte> plaintext, ReadOnlySpan<byte> iv, Span<byte> destination)
{
return TryTransformCtr(plaintext, destination, out var bytesWritten) ? bytesWritten
return TryTransformCtr(plaintext, iv, destination, out var bytesWritten) ? bytesWritten
: throw new ArgumentException("Destination is too short.");
}

/// <summary>
/// TODO
/// </summary>
/// <param name="plaintext">TODO</param>
/// <param name="iv">TODO</param>
/// <returns>TODO</returns>
public byte[] EncryptCtr(byte[] plaintext)
public byte[] EncryptCtr(byte[] plaintext, byte[] iv)
{
return TransformCtr(plaintext);
return TransformCtr(plaintext, iv);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="plaintext">TODO</param>
/// <param name="iv">TODO</param>
/// <returns>TODO</returns>
public byte[] EncryptCtr(ReadOnlySpan<byte> plaintext)
public byte[] EncryptCtr(ReadOnlySpan<byte> plaintext, ReadOnlySpan<byte> iv)
{
return TransformCtr(plaintext);
return TransformCtr(plaintext, iv);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="plaintext">TODO</param>
/// <param name="iv">TODO</param>
/// <param name="destination">TODO</param>
/// <returns>TODO</returns>
public int EncryptCtr(ReadOnlySpan<byte> plaintext, Span<byte> destination)
public int EncryptCtr(ReadOnlySpan<byte> plaintext, ReadOnlySpan<byte> iv, Span<byte> destination)
{
return TransformCtr(plaintext, destination);
return TransformCtr(plaintext, iv, destination);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="plaintext">TODO</param>
/// <param name="iv">TODO</param>
/// <param name="destination">TODO</param>
/// <param name="bytesWritten">TODO</param>
/// <returns>TODO</returns>
public bool TryEncryptCtr(ReadOnlySpan<byte> plaintext, Span<byte> destination, out int bytesWritten)
public bool TryEncryptCtr(ReadOnlySpan<byte> plaintext, ReadOnlySpan<byte> iv, Span<byte> destination, out int bytesWritten)
{
return TryTransformCtr(plaintext, destination, out bytesWritten);
return TryTransformCtr(plaintext, iv, destination, out bytesWritten);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="ciphertext">TODO</param>
/// <param name="iv">TODO</param>
/// <returns>TODO</returns>
public byte[] DecryptCtr(byte[] ciphertext)
public byte[] DecryptCtr(byte[] ciphertext, byte[] iv)
{
return TransformCtr(ciphertext);
return TransformCtr(ciphertext, iv);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="ciphertext">TODO</param>
/// <param name="iv">TODO</param>
/// <returns>TODO</returns>
public byte[] DecryptCtr(ReadOnlySpan<byte> ciphertext)
public byte[] DecryptCtr(ReadOnlySpan<byte> ciphertext, ReadOnlySpan<byte> iv)
{
return TransformCtr(ciphertext);
return TransformCtr(ciphertext, iv);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="ciphertext">TODO</param>
/// <param name="iv">TODO</param>
/// <param name="destination">TODO</param>
/// <returns>TODO</returns>
public int DecryptCtr(ReadOnlySpan<byte> ciphertext, Span<byte> destination)
public int DecryptCtr(ReadOnlySpan<byte> ciphertext, ReadOnlySpan<byte> iv, Span<byte> destination)
{
return TransformCtr(ciphertext, destination);
return TransformCtr(ciphertext, iv, destination);
}

/// <summary>
/// TODO
/// </summary>
/// <param name="ciphertext">TODO</param>
/// <param name="iv">TODO</param>
/// <param name="destination">TODO</param>
/// <param name="bytesWritten">TODO</param>
/// <returns>TODO</returns>
public bool TryDecryptCtr(ReadOnlySpan<byte> ciphertext, Span<byte> destination, out int bytesWritten)
public bool TryDecryptCtr(ReadOnlySpan<byte> ciphertext, ReadOnlySpan<byte> iv, Span<byte> destination, out int bytesWritten)
{
return TryTransformCtr(ciphertext, destination, out bytesWritten);
return TryTransformCtr(ciphertext, iv, destination, out bytesWritten);
}
#endregion
}
2 changes: 1 addition & 1 deletion AesExtra/AesCtrTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal AesCtrTransform(byte[] key, ReadOnlySpan<byte> initialCounter)
using var aes = Aes.Create();
aes.Mode = CipherMode.ECB; // DevSkim: ignore DS187371
aes.BlockSize = BLOCKSIZE * BitsPerByte;
AesEcbTransform = aes.CreateEncryptor(key, null!);
AesEcbTransform = aes.CreateEncryptor(key, null);
Counter = initialCounter.ToArray();
}

Expand Down
44 changes: 27 additions & 17 deletions UnitTests/AesCtrTransform_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,23 @@ sealed class AesCtrTransform_Tests
{
const int BLOCKSIZE = 16; // bytes

readonly byte[] TestKey = new byte[128 / 8];
readonly byte[] InitialCounter = new byte[BLOCKSIZE];
static readonly byte[] TestKey =
[
31, 32, 33, 34, 35, 36, 37, 38,
41, 42, 43, 44, 45, 46, 47, 48,
51, 52, 53, 54, 55, 56, 57, 58
];

static readonly byte[] TestInitialCounter =
[
61, 62, 63, 64, 65, 66, 67, 68,
71, 72, 73, 74, 75, 76, 77, 78
];

[TestMethod]
public void Constructor()
{
using var transform = new AesCtrTransform(TestKey, InitialCounter);
using var transform = new AesCtrTransform(TestKey, TestInitialCounter);
}

[TestMethod]
Expand All @@ -38,14 +48,14 @@ public void Constructor_InvalidIVSize(int ivSize)
[TestMethod]
public void Dispose()
{
var transform = new AesCtrTransform(TestKey, InitialCounter);
var transform = new AesCtrTransform(TestKey, TestInitialCounter);
transform.Dispose();
}

[TestMethod]
public void Dispose_Double()
{
var transform = new AesCtrTransform(TestKey, InitialCounter);
var transform = new AesCtrTransform(TestKey, TestInitialCounter);
transform.Dispose();
transform.Dispose();
}
Expand All @@ -54,28 +64,28 @@ public void Dispose_Double()
[TestMethod]
public void CanReuseTransform_Get()
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
Assert.IsFalse(transform.CanReuseTransform);
}

[TestMethod]
public void CanTransformMultipleBlocks_Get()
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
Assert.IsTrue(transform.CanTransformMultipleBlocks);
}

[TestMethod]
public void InputBlockSize_Get()
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
Assert.AreEqual(BLOCKSIZE, transform.InputBlockSize);
}

[TestMethod]
public void OutputBlockSize_Get()
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
Assert.AreEqual(BLOCKSIZE, transform.InputBlockSize);
}

Expand All @@ -86,7 +96,7 @@ public void OutputBlockSize_Get()
[DataRow(10 * BLOCKSIZE)]
public void TransformBlock_ValidSize(int size)
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
var result = transform.TransformBlock(new byte[size], 0, size, new byte[size], 0);
Assert.AreEqual(size, result);
}
Expand All @@ -99,15 +109,15 @@ public void TransformBlock_InvalidSizeFails(int size)
{
Assert.ThrowsException<ArgumentException>(() =>
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
transform.TransformBlock(new byte[size], 0, size, new byte[size], 0);
});
}

[TestMethod]
public void TransformBlock_AfterFinalFails()
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
transform.TransformFinalBlock([], 0, 0);
Assert.ThrowsException<InvalidOperationException>(() =>
{
Expand All @@ -118,7 +128,7 @@ public void TransformBlock_AfterFinalFails()
[TestMethod]
public void TransformBlock_AfterDisposeFails()
{
ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
transform.TransformBlock(new byte[BLOCKSIZE], 0, BLOCKSIZE, new byte[BLOCKSIZE], 0);
transform.Dispose();
Assert.ThrowsException<ObjectDisposedException>(() =>
Expand All @@ -134,7 +144,7 @@ public void TransformBlock_AfterDisposeFails()
[DataRow(BLOCKSIZE)]
public void TransformFinalBlock_ValidSize(int size)
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
var result = transform.TransformFinalBlock(new byte[size], 0, size);
Assert.AreEqual(size, result.Length);
}
Expand All @@ -146,15 +156,15 @@ public void TransformFinalBlock_InvalidSizeFails(int size)
{
Assert.ThrowsException<ArgumentOutOfRangeException>(() =>
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
transform.TransformFinalBlock(new byte[size], 0, size);
});
}

[TestMethod]
public void TransformFinalBlock_AfterFinalFails()
{
using ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
using ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
transform.TransformFinalBlock([], 0, 0);
Assert.ThrowsException<InvalidOperationException>(() =>
{
Expand All @@ -165,7 +175,7 @@ public void TransformFinalBlock_AfterFinalFails()
[TestMethod]
public void TransformFinalBlock_AfterDisposeFails()
{
ICryptoTransform transform = new AesCtrTransform(TestKey, InitialCounter);
ICryptoTransform transform = new AesCtrTransform(TestKey, TestInitialCounter);
transform.TransformBlock(new byte[BLOCKSIZE], 0, BLOCKSIZE, new byte[BLOCKSIZE], 0);
transform.Dispose();
Assert.ThrowsException<ObjectDisposedException>(() =>
Expand Down
Loading

0 comments on commit 87ea27f

Please sign in to comment.