Skip to content

Commit

Permalink
Optimizations: Reduced heap allocations (#68)
Browse files Browse the repository at this point in the history
* move state matrix on stack

* move tmp array in ShiftRow to stack

* move KeyExpansions temp and rcon array on stack

* move block and encryptedBlock arrays on stack

* fixed signed to unsigned comparison warnings
  • Loading branch information
mrdcvlsc authored Sep 16, 2022
1 parent f694d56 commit e795922
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 73 deletions.
90 changes: 31 additions & 59 deletions src/AES.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "AES.h"

AES::AES(const AESKeyLength keyLength) {
this->Nb = 4;
switch (keyLength) {
case AESKeyLength::AES_128:
this->Nk = 4;
Expand All @@ -16,8 +15,6 @@ AES::AES(const AESKeyLength keyLength) {
this->Nr = 14;
break;
}

blockBytesLen = 4 * this->Nb * sizeof(unsigned char);
}

unsigned char *AES::EncryptECB(const unsigned char in[], unsigned int inLen,
Expand Down Expand Up @@ -55,7 +52,7 @@ unsigned char *AES::EncryptCBC(const unsigned char in[], unsigned int inLen,
const unsigned char *iv) {
CheckLength(inLen);
unsigned char *out = new unsigned char[inLen];
unsigned char *block = new unsigned char[blockBytesLen];
unsigned char block[blockBytesLen];
unsigned char *roundKeys = new unsigned char[4 * Nb * (Nr + 1)];
KeyExpansion(key, roundKeys);
memcpy(block, iv, blockBytesLen);
Expand All @@ -65,7 +62,6 @@ unsigned char *AES::EncryptCBC(const unsigned char in[], unsigned int inLen,
memcpy(block, out + i, blockBytesLen);
}

delete[] block;
delete[] roundKeys;

return out;
Expand All @@ -76,7 +72,7 @@ unsigned char *AES::DecryptCBC(const unsigned char in[], unsigned int inLen,
const unsigned char *iv) {
CheckLength(inLen);
unsigned char *out = new unsigned char[inLen];
unsigned char *block = new unsigned char[blockBytesLen];
unsigned char block[blockBytesLen];
unsigned char *roundKeys = new unsigned char[4 * Nb * (Nr + 1)];
KeyExpansion(key, roundKeys);
memcpy(block, iv, blockBytesLen);
Expand All @@ -86,7 +82,6 @@ unsigned char *AES::DecryptCBC(const unsigned char in[], unsigned int inLen,
memcpy(block, in + i, blockBytesLen);
}

delete[] block;
delete[] roundKeys;

return out;
Expand All @@ -97,8 +92,8 @@ unsigned char *AES::EncryptCFB(const unsigned char in[], unsigned int inLen,
const unsigned char *iv) {
CheckLength(inLen);
unsigned char *out = new unsigned char[inLen];
unsigned char *block = new unsigned char[blockBytesLen];
unsigned char *encryptedBlock = new unsigned char[blockBytesLen];
unsigned char block[blockBytesLen];
unsigned char encryptedBlock[blockBytesLen];
unsigned char *roundKeys = new unsigned char[4 * Nb * (Nr + 1)];
KeyExpansion(key, roundKeys);
memcpy(block, iv, blockBytesLen);
Expand All @@ -108,8 +103,6 @@ unsigned char *AES::EncryptCFB(const unsigned char in[], unsigned int inLen,
memcpy(block, out + i, blockBytesLen);
}

delete[] block;
delete[] encryptedBlock;
delete[] roundKeys;

return out;
Expand All @@ -120,8 +113,8 @@ unsigned char *AES::DecryptCFB(const unsigned char in[], unsigned int inLen,
const unsigned char *iv) {
CheckLength(inLen);
unsigned char *out = new unsigned char[inLen];
unsigned char *block = new unsigned char[blockBytesLen];
unsigned char *encryptedBlock = new unsigned char[blockBytesLen];
unsigned char block[blockBytesLen];
unsigned char encryptedBlock[blockBytesLen];
unsigned char *roundKeys = new unsigned char[4 * Nb * (Nr + 1)];
KeyExpansion(key, roundKeys);
memcpy(block, iv, blockBytesLen);
Expand All @@ -131,8 +124,6 @@ unsigned char *AES::DecryptCFB(const unsigned char in[], unsigned int inLen,
memcpy(block, in + i, blockBytesLen);
}

delete[] block;
delete[] encryptedBlock;
delete[] roundKeys;

return out;
Expand All @@ -147,12 +138,8 @@ void AES::CheckLength(unsigned int len) {

void AES::EncryptBlock(const unsigned char in[], unsigned char out[],
unsigned char *roundKeys) {
unsigned char **state = new unsigned char *[4];
state[0] = new unsigned char[4 * Nb];
int i, j, round;
for (i = 0; i < 4; i++) {
state[i] = state[0] + Nb * i;
}
unsigned char state[4][Nb];
unsigned int i, j, round;

for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
Expand All @@ -178,19 +165,12 @@ void AES::EncryptBlock(const unsigned char in[], unsigned char out[],
out[i + 4 * j] = state[i][j];
}
}

delete[] state[0];
delete[] state;
}

void AES::DecryptBlock(const unsigned char in[], unsigned char out[],
unsigned char *roundKeys) {
unsigned char **state = new unsigned char *[4];
state[0] = new unsigned char[4 * Nb];
int i, j, round;
for (i = 0; i < 4; i++) {
state[i] = state[0] + Nb * i;
}
unsigned char state[4][Nb];
unsigned int i, j, round;

for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
Expand All @@ -216,13 +196,10 @@ void AES::DecryptBlock(const unsigned char in[], unsigned char out[],
out[i + 4 * j] = state[i][j];
}
}

delete[] state[0];
delete[] state;
}

void AES::SubBytes(unsigned char **state) {
int i, j;
void AES::SubBytes(unsigned char state[4][Nb]) {
unsigned int i, j;
unsigned char t;
for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
Expand All @@ -232,19 +209,17 @@ void AES::SubBytes(unsigned char **state) {
}
}

void AES::ShiftRow(unsigned char **state, int i,
int n) // shift row i on n positions
void AES::ShiftRow(unsigned char state[4][Nb], unsigned int i,
unsigned int n) // shift row i on n positions
{
unsigned char *tmp = new unsigned char[Nb];
for (int j = 0; j < Nb; j++) {
unsigned char tmp[Nb];
for (unsigned int j = 0; j < Nb; j++) {
tmp[j] = state[i][(j + n) % Nb];
}
memcpy(state[i], tmp, Nb * sizeof(unsigned char));

delete[] tmp;
}

void AES::ShiftRows(unsigned char **state) {
void AES::ShiftRows(unsigned char state[4][Nb]) {
ShiftRow(state, 1, 1);
ShiftRow(state, 2, 2);
ShiftRow(state, 3, 3);
Expand All @@ -255,8 +230,8 @@ unsigned char AES::xtime(unsigned char b) // multiply on x
return (b << 1) ^ (((b >> 7) & 1) * 0x1b);
}

void AES::MixColumns(unsigned char **state) {
unsigned char temp_state[4][4];
void AES::MixColumns(unsigned char state[4][Nb]) {
unsigned char temp_state[4][Nb];

for (size_t i = 0; i < 4; ++i) {
memset(temp_state[i], 0, 4);
Expand All @@ -278,8 +253,8 @@ void AES::MixColumns(unsigned char **state) {
}
}

void AES::AddRoundKey(unsigned char **state, unsigned char *key) {
int i, j;
void AES::AddRoundKey(unsigned char state[4][Nb], unsigned char *key) {
unsigned int i, j;
for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
state[i][j] = state[i][j] ^ key[i + 4 * j];
Expand Down Expand Up @@ -309,8 +284,8 @@ void AES::XorWords(unsigned char *a, unsigned char *b, unsigned char *c) {
}
}

void AES::Rcon(unsigned char *a, int n) {
int i;
void AES::Rcon(unsigned char *a, unsigned int n) {
unsigned int i;
unsigned char c = 1;
for (i = 0; i < n - 1; i++) {
c = xtime(c);
Expand All @@ -321,10 +296,10 @@ void AES::Rcon(unsigned char *a, int n) {
}

void AES::KeyExpansion(const unsigned char key[], unsigned char w[]) {
unsigned char *temp = new unsigned char[4];
unsigned char *rcon = new unsigned char[4];
unsigned char temp[4];
unsigned char rcon[4];

int i = 0;
unsigned int i = 0;
while (i < 4 * Nk) {
w[i] = key[i];
i++;
Expand Down Expand Up @@ -352,13 +327,10 @@ void AES::KeyExpansion(const unsigned char key[], unsigned char w[]) {
w[i + 3] = w[i + 3 - 4 * Nk] ^ temp[3];
i += 4;
}

delete[] rcon;
delete[] temp;
}

void AES::InvSubBytes(unsigned char **state) {
int i, j;
void AES::InvSubBytes(unsigned char state[4][Nb]) {
unsigned int i, j;
unsigned char t;
for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
Expand All @@ -368,8 +340,8 @@ void AES::InvSubBytes(unsigned char **state) {
}
}

void AES::InvMixColumns(unsigned char **state) {
unsigned char temp_state[4][4];
void AES::InvMixColumns(unsigned char state[4][Nb]) {
unsigned char temp_state[4][Nb];

for (size_t i = 0; i < 4; ++i) {
memset(temp_state[i], 0, 4);
Expand All @@ -388,7 +360,7 @@ void AES::InvMixColumns(unsigned char **state) {
}
}

void AES::InvShiftRows(unsigned char **state) {
void AES::InvShiftRows(unsigned char state[4][Nb]) {
ShiftRow(state, 1, Nb - 1);
ShiftRow(state, 2, Nb - 2);
ShiftRow(state, 3, Nb - 3);
Expand Down
28 changes: 14 additions & 14 deletions src/AES.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,38 @@ enum class AESKeyLength { AES_128, AES_192, AES_256 };

class AES {
private:
int Nb;
int Nk;
int Nr;
static constexpr unsigned int Nb = 4;
static constexpr unsigned int blockBytesLen = 4 * Nb * sizeof(unsigned char);

unsigned int blockBytesLen;
unsigned int Nk;
unsigned int Nr;

void SubBytes(unsigned char **state);
void SubBytes(unsigned char state[4][Nb]);

void ShiftRow(unsigned char **state, int i,
int n); // shift row i on n positions
void ShiftRow(unsigned char state[4][Nb], unsigned int i,
unsigned int n); // shift row i on n positions

void ShiftRows(unsigned char **state);
void ShiftRows(unsigned char state[4][Nb]);

unsigned char xtime(unsigned char b); // multiply on x

void MixColumns(unsigned char **state);
void MixColumns(unsigned char state[4][Nb]);

void AddRoundKey(unsigned char **state, unsigned char *key);
void AddRoundKey(unsigned char state[4][Nb], unsigned char *key);

void SubWord(unsigned char *a);

void RotWord(unsigned char *a);

void XorWords(unsigned char *a, unsigned char *b, unsigned char *c);

void Rcon(unsigned char *a, int n);
void Rcon(unsigned char *a, unsigned int n);

void InvSubBytes(unsigned char **state);
void InvSubBytes(unsigned char state[4][Nb]);

void InvMixColumns(unsigned char **state);
void InvMixColumns(unsigned char state[4][Nb]);

void InvShiftRows(unsigned char **state);
void InvShiftRows(unsigned char state[4][Nb]);

void CheckLength(unsigned int len);

Expand Down

0 comments on commit e795922

Please sign in to comment.