Skip to content

Commit

Permalink
Merge pull request #285 from Crypto-TII/feature/differential-linear-s…
Browse files Browse the repository at this point in the history
…at-model

Feature/differential-linear sat model
  • Loading branch information
peacker authored Oct 4, 2024
2 parents 5250c93 + 201917f commit fdf1bba
Show file tree
Hide file tree
Showing 6 changed files with 845 additions and 9 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self, cipher, counter='sequential', compact=False):
super().__init__(cipher, counter, compact)
self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join)

def branch_xor_linear_constraints(self):
@staticmethod
def branch_xor_linear_constraints(bindings):
"""
Return lists of variables and clauses for branch in XOR LINEAR model.
Expand All @@ -52,7 +53,7 @@ def branch_xor_linear_constraints(self):
sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher
sage: speck = SpeckBlockCipher(number_of_rounds=3)
sage: sat = SatXorLinearModel(speck)
sage: sat.branch_xor_linear_constraints()
sage: SatXorLinearModel.branch_xor_linear_constraints(sat.bit_bindings)
['-plaintext_0_o rot_0_0_0_i',
'plaintext_0_o -rot_0_0_0_i',
'-plaintext_1_o rot_0_0_1_i',
Expand All @@ -62,7 +63,7 @@ def branch_xor_linear_constraints(self):
'xor_2_10_15_o -cipher_output_2_12_31_i']
"""
constraints = []
for output_bit, input_bits in self.bit_bindings.items():
for output_bit, input_bits in bindings.items():
constraints.extend(utils.cnf_xor(output_bit, input_bits))

return constraints
Expand Down Expand Up @@ -91,7 +92,7 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]):
self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self._cipher, '_'.join)
if fixed_variables == []:
fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher)
constraints = self.fix_variables_value_xor_linear_constraints(fixed_variables)
constraints = SatXorLinearModel.fix_variables_value_xor_linear_constraints(fixed_variables)
self._model_constraints = constraints
component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION)
operation_types = ("AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB")
Expand All @@ -106,7 +107,7 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]):
self._variables_list.extend(variables)
self._model_constraints.extend(constraints)

constraints = self.branch_xor_linear_constraints()
constraints = SatXorLinearModel.branch_xor_linear_constraints(self.bit_bindings)
self._model_constraints.extend(constraints)

if weight != -1:
Expand Down Expand Up @@ -399,7 +400,8 @@ def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values

return solution

def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]):
@staticmethod
def fix_variables_value_xor_linear_constraints(fixed_variables=[]):
"""
Return lists variables and clauses for fixing variables in XOR LINEAR model.
Expand Down Expand Up @@ -428,7 +430,7 @@ def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]):
....: 'bit_positions': [0, 1, 2, 3],
....: 'bit_values': [1, 1, 1, 0]
....: }]
sage: sat.fix_variables_value_xor_linear_constraints(fixed_variables)
sage: SatXorLinearModel.fix_variables_value_xor_linear_constraints(fixed_variables)
['plaintext_0_o',
'-plaintext_1_o',
'plaintext_2_o',
Expand Down
32 changes: 32 additions & 0 deletions claasp/cipher_modules/models/sat/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,12 @@ def get_cnf_bitwise_truncate_constraints(a, a_0, a_1):
]


def get_cnf_truncated_linear_constraints(a, a_0):
return [
f'-{a} -{a_0}'
]


def modadd_truncated_lsb(result, variable_0, variable_1, next_carry):
return [f'{next_carry[0]} -{next_carry[1]}',
f'{next_carry[0]} -{variable_1[1]}',
Expand Down Expand Up @@ -819,3 +825,29 @@ def run_yices(solver_specs, options, dimacs_input, input_file_name):
os.remove(input_file_name)

return status, time, memory, values


def _generate_component_model_types(speck_cipher):
"""Generates the component model types for a given Speck cipher."""
component_model_types = []
for component in speck_cipher.get_all_components():
component_model_types.append({
"component_id": component.id,
"component_object": component,
"model_type": "sat_xor_differential_propagation_constraints"
})
return component_model_types


def _update_component_model_types_for_truncated_components(component_model_types, truncated_components):
"""Updates the component model types for truncated components."""
for component_model_type in component_model_types:
if component_model_type["component_id"] in truncated_components:
component_model_type["model_type"] = "sat_bitwise_deterministic_truncated_xor_differential_constraints"


def _update_component_model_types_for_linear_components(component_model_types, linear_components):
"""Updates the component model types for linear components."""
for component_model_type in component_model_types:
if component_model_type["component_id"] in linear_components:
component_model_type["model_type"] = "sat_xor_linear_mask_propagation_constraints"
89 changes: 89 additions & 0 deletions claasp/cipher_modules/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import math
from copy import deepcopy

import numpy as np

from claasp.name_mappings import CONSTANT, CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, WORD_OPERATION, LINEAR_LAYER, SBOX, MIX_COLUMN, \
INPUT_KEY, INPUT_PLAINTEXT, INPUT_MESSAGE, INPUT_STATE

Expand Down Expand Up @@ -791,3 +793,90 @@ def get_related_key_scenario_format_for_fixed_values(_cipher):
fixed_variables.append(fixed_variable)

return fixed_variables


def _extract_bits(columns, positions):
"""Extracts bits from columns at specified positions using vectorization."""
bit_size = columns.shape[0] * 8
positions = np.array(positions)
byte_indices = (bit_size - positions - 1) // 8
bit_indices = positions % 8
if np.any(byte_indices < 0) or np.any(byte_indices >= columns.shape[0]):
raise IndexError("Byte index out of range.")
bytes_at_positions = columns[byte_indices][:, :]
bits = (bytes_at_positions >> bit_indices[:, np.newaxis]) & 1

return bits


def _number_to_n_bit_binary_string(number, n_bits):
"""Converts a number to an n-bit binary string with leading zero padding."""
return format(number, f'0{n_bits}b')


def _extract_bit_positions(hex_number, state_size):
binary_str = _number_to_n_bit_binary_string(hex_number, state_size)
binary_str = binary_str[::-1]
positions = [i for i, bit in enumerate(binary_str) if bit == '1']
return positions


def _repeat_input_difference(input_difference, num_samples, num_bytes):
"""Function to repeat the input difference for a large sample size."""
bytes_array = np.frombuffer(input_difference.to_bytes(num_bytes, 'big'), dtype=np.uint8)
repeated_array = np.broadcast_to(bytes_array[:, np.newaxis], (num_bytes, num_samples))
return repeated_array


def differential_linear_checker_for_permutation(
cipher, input_difference, output_mask, number_of_samples, state_size
):
"""
This method helps to verify experimentally differential-linear distinguishers for permutations using the vectorized evaluator
"""
if state_size % 8 != 0:
raise ValueError("State size must be a multiple of 8.")
num_bytes = int(state_size/8)

rng = np.random.default_rng()
input_difference_data = _repeat_input_difference(input_difference, number_of_samples, num_bytes)
plaintext1 = rng.integers(low=0, high=256, size=(num_bytes, number_of_samples), dtype=np.uint8)
plaintext2 = plaintext1 ^ input_difference_data
ciphertext1 = cipher.evaluate_vectorized([plaintext1])
ciphertext2 = cipher.evaluate_vectorized([plaintext2])
ciphertext3 = ciphertext1[0] ^ ciphertext2[0]
bit_positions_ciphertext = _extract_bit_positions(output_mask, state_size)
ccc = _extract_bits(ciphertext3.T, bit_positions_ciphertext)
parities = np.bitwise_xor.reduce(ccc, axis=0)
count = np.count_nonzero(parities == 0)
corr = 2*count/number_of_samples*1.0-1
return corr


def differential_linear_checker_for_block_cipher_single_key(
cipher, input_difference, output_mask, number_of_samples, block_size, key_size, fixed_key
):
"""
This method helps to verify experimentally differential-linear distinguishers for block ciphers using the vectorized evaluator
"""
if block_size % 8 != 0:
raise ValueError("State size must be a multiple of 8.")
if key_size % 8 != 0:
raise ValueError("Key size must be a multiple of 8.")
state_num_bytes = int(block_size / 8)
key_num_bytes = int(key_size / 8)

rng = np.random.default_rng()
fixed_key_data = _repeat_input_difference(fixed_key, number_of_samples, key_num_bytes)
input_difference_data = _repeat_input_difference(input_difference, number_of_samples, state_num_bytes)
plaintext1 = rng.integers(low=0, high=256, size=(state_num_bytes, number_of_samples), dtype=np.uint8)
plaintext2 = plaintext1 ^ input_difference_data
ciphertext1 = cipher.evaluate_vectorized([plaintext1, fixed_key_data])
ciphertext2 = cipher.evaluate_vectorized([plaintext2, fixed_key_data])
ciphertext3 = ciphertext1[0] ^ ciphertext2[0]
bit_positions_ciphertext = _extract_bit_positions(output_mask, block_size)
ccc = _extract_bits(ciphertext3.T, bit_positions_ciphertext)
parities = np.bitwise_xor.reduce(ccc, axis=0)
count = np.count_nonzero(parities == 0)
corr = 2*count/number_of_samples*1.0-1
return corr
Loading

0 comments on commit fdf1bba

Please sign in to comment.