diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_differential_linear_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_differential_linear_model.py new file mode 100644 index 00000000..8e23ec63 --- /dev/null +++ b/claasp/cipher_modules/models/sat/sat_models/sat_differential_linear_model.py @@ -0,0 +1,421 @@ +import time +from claasp.cipher_modules.models.sat import solvers +from claasp.cipher_modules.models.sat.sat_model import SatModel +from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_deterministic_truncated_xor_differential_model import ( + SatBitwiseDeterministicTruncatedXorDifferentialModel +) +from claasp.cipher_modules.models.sat.sat_models.sat_xor_linear_model import SatXorLinearModel +from claasp.cipher_modules.models.utils import set_component_solution, get_bit_bindings +from claasp.cipher_modules.models.sat.utils import utils as sat_utils, constants + + +class SatDifferentialLinearModel(SatModel): + """ + Model that combines concrete XOR differential model with bitwise deterministic truncated differential model + and linear model to create a differential-linear model. + """ + + def __init__(self, cipher, dict_of_components): + """ + Initializes the model with cipher and components. + + INPUT: + - ``cipher`` -- **object**; The cipher model used in the SAT-based differential trail search. + - ``dict_of_components`` -- **dict**; Dictionary mapping component IDs to their respective models and types. + """ + self.dict_of_components = dict_of_components + self.regular_components = self._get_components_by_type('sat_xor_differential_propagation_constraints') + self.truncated_components = self._get_components_by_type( + 'sat_bitwise_deterministic_truncated_xor_differential_constraints') + self.linear_components = self._get_components_by_type( + 'sat_xor_linear_mask_propagation_constraints') + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join) + super().__init__(cipher, "sequential", False) + + def _get_components_by_type(self, model_type): + """ + Retrieves components based on their model type. + + INPUT: + - ``model_type`` -- **str**; The model type to filter components. + + RETURN: + - **list**; A list of components of the specified type. + """ + return [component for component in self.dict_of_components if component['model_type'] == model_type] + + def _get_regular_xor_differential_components_in_border(self): + """ + Retrieves regular components that are connected to truncated components (border components). + + RETURN: + - **list**; A list of regular components at the border. + """ + regular_component_ids = {item['component_id'] for item in self.regular_components} + border_components = [] + + for truncated_component in self.truncated_components: + component_obj = self.cipher.get_component_from_id(truncated_component['component_id']) + for input_id in component_obj.input_id_links: + if input_id in regular_component_ids: + border_components.append(input_id) + + return list(set(border_components)) + + def _get_truncated_xor_differential_components_in_border(self): + """ + Retrieves truncated components that are connected to linear components (border components). + + RETURN: + - **list**; A list of truncated components at the border. + """ + truncated_component_ids = {item['component_id'] for item in self.truncated_components} + border_components = [] + + for linear_component in self.linear_components: + component_obj = self.cipher.get_component_from_id(linear_component['component_id']) + for input_id in component_obj.input_id_links: + if input_id in truncated_component_ids: + border_components.append(input_id) + + return list(set(border_components)) + + def _get_connecting_constraints(self): + """ + Adds constraints for connecting regular, truncated, and linear components. + """ + def is_any_string_in_list_substring_of_string(string, string_list): + # Check if any string in the list is a substring of the given string + return any(s in string for s in string_list) + + border_components = self._get_regular_xor_differential_components_in_border() + for component_id in border_components: + component = self.cipher.get_component_from_id(component_id) + for idx in range(component.output_bit_size): + constraints = sat_utils.get_cnf_bitwise_truncate_constraints( + f'{component_id}_{idx}', f'{component_id}_{idx}_0', f'{component_id}_{idx}_1' + ) + self._model_constraints.extend(constraints) + self._variables_list.extend([ + f'{component_id}_{idx}', f'{component_id}_{idx}_0', f'{component_id}_{idx}_1' + ]) + + border_components = self._get_truncated_xor_differential_components_in_border() + + linear_component_ids = [item['component_id'] for item in self.linear_components] + + for component_id in border_components: + component = self.cipher.get_component_from_id(component_id) + for idx in range(component.output_bit_size): + truncated_component = f'{component_id}_{idx}_o' + component_successors = self.bit_bindings[truncated_component] + for component_successor in component_successors: + length_component_successor = len(component_successor) + component_successor_id = component_successor[:length_component_successor-2] + + if is_any_string_in_list_substring_of_string(component_successor_id, linear_component_ids): + constraints = sat_utils.get_cnf_truncated_linear_constraints( + component_successor, f'{component_id}_{idx}_0' + ) + self._model_constraints.extend(constraints) + self._variables_list.extend([component_successor, f'{component_id}_{idx}_0']) + + def _build_weight_constraints(self, weight): + """ + Builds weight constraints for the model based on the specified weight. + + INPUT: + - ``weight`` -- **int**; The weight to constrain the search. If set to 0, the hardware variables are negated. + + RETURN: + - **tuple**; A tuple containing a list of variables and a list of constraints. + """ + + hw_variables = [var_id for var_id in self._variables_list if var_id.startswith('hw_')] + + linear_component_ids = [linear_component["component_id"] for linear_component in self.linear_components] + hw_linear_variables = [] + for linear_component_id in linear_component_ids: + for hw_variable in hw_variables: + if linear_component_id in hw_variable: + hw_linear_variables.append(hw_variable) + hw_variables.extend(hw_linear_variables) + if weight == 0: + return [], [f'-{var}' for var in hw_variables] + + return self._counter(hw_variables, weight) + + def _build_unknown_variable_constraints(self, num_unknowns): + """ + Adds constraints for limiting the number of unknown variables. + + INPUT: + - ``num_unknowns`` -- **int**; The number of unknown variables allowed. + + RETURN: + - **tuple**; A tuple containing a list of variables and a list of constraints. + """ + border_components = self._get_truncated_xor_differential_components_in_border() + minimize_vars = [] + for border_component in border_components: + output_id = border_component + minimize_vars.extend( + [bit_id for bit_id in self._variables_list if bit_id.startswith(output_id) and bit_id.endswith("_0")] + ) + return self._sequential_counter(minimize_vars, num_unknowns, "dummy_id_unknown") + + def build_xor_differential_linear_model(self, weight=-1, num_unknown_vars=None): + """ + Constructs a model to search for differential-linear trails. + This model is a combination of the concrete XOR differential model, the bitwise truncated deterministic model, + and the linear XOR differential model. + + INPUT: + - ``weight`` -- **integer** (default: `-1`); specifies the maximum probability weight. If set to a non-negative + integer, it constrains the search to trails with the fixed probability weight. + - ``number_of_unknown_variables`` -- **int** (default: None); specifies the upper limit on the number of unknown + variables allowed in the differential trail. + + EXAMPLES:: + + sage: from claasp.cipher_modules.models.sat.sat_models.sat_differential_linear_model import SatDifferentialLinearModel + sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher + sage: speck = SpeckBlockCipher(number_of_rounds=4) + sage: component_model_types = [] + sage: for component in speck.get_all_components(): + ....: component_model_type = { + ....: "component_id": component.id, + ....: "component_object": component, + ....: "model_type": "sat_xor_differential_propagation_constraints" + ....: } + ....: component_model_types.append(component_model_type) + sage: sat = SatDifferentialLinearModel(speck, component_model_types) + sage: sat.build_xor_differential_linear_model() + ... + """ + self.build_generic_sat_model_from_dictionary(self.dict_of_components) + constraints = SatXorLinearModel.branch_xor_linear_constraints(self.bit_bindings) + self._model_constraints.extend(constraints) + + if num_unknown_vars is not None: + variables, constraints = self._build_unknown_variable_constraints(num_unknown_vars) + self._variables_list.extend(variables) + self._model_constraints.extend(constraints) + + if weight != -1: + variables, constraints = self._build_weight_constraints(weight) + self._variables_list.extend(variables) + self._model_constraints.extend(constraints) + + self._get_connecting_constraints() + + @staticmethod + def fix_variables_value_constraints( + fixed_variables, regular_components=None, truncated_components=None, linear_components=None): + """ + Imposes fixed value constraints on variables within differential, truncated, and linear components. + + INPUT: + - ``fixed_variables`` -- **list** (default: `[]`); specifies a list of variables that should be fixed to specific values. Each entry in the list should be a dictionary representing constraints for specific components, written in the CLAASP constraining syntax. + - ``regular_components`` -- **list** (default: None); list of regular components. + - ``truncated_components`` -- **list** (default: None); list of truncated components. + - ``linear_components`` -- **list** (default: None); list of linear components. + + RETURN: + - **list**; A list of constraints for the model. + """ + truncated_vars = [] + regular_vars = [] + linear_vars = [] + + for var in fixed_variables: + component_id = var["component_id"] + + if component_id in [comp["component_id"] for comp in regular_components] and 2 in var['bit_values']: + raise ValueError("The fixed value in a regular XOR differential component cannot be 2") + + if component_id in [comp["component_id"] for comp in truncated_components]: + truncated_vars.append(var) + elif component_id in [comp["component_id"] for comp in linear_components]: + linear_vars.append(var) + elif component_id in [comp["component_id"] for comp in regular_components]: + regular_vars.append(var) + else: + regular_vars.append(var) + + regular_constraints = SatModel.fix_variables_value_constraints(regular_vars) + truncated_constraints = SatBitwiseDeterministicTruncatedXorDifferentialModel.fix_variables_value_constraints( + truncated_vars) + linear_constraints = SatXorLinearModel.fix_variables_value_xor_linear_constraints(linear_vars) + + return regular_constraints + truncated_constraints + linear_constraints + + def _parse_solver_output(self, variable2value): + """ + Parses the solver's output and returns component solutions and total weight. The total weight is the sum of the + probability weight of the top part (differential part) and the correlation weight of the bottom part (linear part). + Note that the weight of the middle part is deterministic. + + INPUT: + - ``variable2value`` -- **dict**; mapping of solver's variables to their values. + + RETURN: + - **tuple**; a tuple containing the dictionary of component solutions and the total weight. + """ + components_solutions = self._get_cipher_inputs_components_solutions('', variable2value) + total_weight_diff = 0 + total_weight_lin = 0 + + for component in self._cipher.get_all_components(): + if component.id in [d['component_id'] for d in self.regular_components]: + hex_value = self._get_component_hex_value(component, '', variable2value) + weight = self.calculate_component_weight(component, '', variable2value) + components_solutions[component.id] = set_component_solution(hex_value, weight) + total_weight_diff += weight + + elif component.id in [d['component_id'] for d in self.truncated_components]: + value = self._get_component_value_double_ids(component, variable2value) + components_solutions[component.id] = set_component_solution(value) + + elif component.id in [d['component_id'] for d in self.linear_components]: + hex_value = self._get_component_hex_value(component, constants.OUTPUT_BIT_ID_SUFFIX, variable2value) + weight = self.calculate_component_weight(component, constants.OUTPUT_BIT_ID_SUFFIX, variable2value) + total_weight_lin += weight + components_solutions[component.id] = set_component_solution(hex_value, weight) + + return components_solutions, total_weight_diff + 2 * total_weight_lin + + def find_one_differential_linear_trail_with_fixed_weight( + self, weight, num_unknown_vars=None, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + """ + Finds one XOR differential-linear trail with a fixed weight. The weight must be the sum of the probability weight + of the top part (differential part) and the correlation weight of the bottom part (linear part). + + INPUT: + - ``weight`` -- **int**; Maximum probability weight for the regular XOR differential part. + - ``num_unknown_vars`` -- **int** (default: None); Upper limit on the number of unknown variables allowed. + - ``fixed_values`` -- **list** (default: `[]`); specifies a list of variables that should be fixed to specific values. Each entry in the list should be a dictionary representing constraints for specific components, written in the CLAASP constraining syntax. + - ``solver_name`` -- **str** (default: ``solvers.SOLVER_DEFAULT``); The name of the SAT solver to use. + + RETURN: + - **dict**; Solution returned by the solver, including the trail and additional information. + + EXAMPLES:: + + sage: from claasp.cipher_modules.models.sat.sat_models.sat_differential_linear_model import SatDifferentialLinearModel + sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher + sage: from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list + sage: from claasp.cipher_modules.models.sat.utils.utils import _generate_component_model_types, \ + ....: _update_component_model_types_for_truncated_components, _update_component_model_types_for_linear_components + sage: import itertools + sage: speck = SpeckBlockCipher(number_of_rounds=6) + sage: middle_part_components = [] + sage: bottom_part_components = [] + sage: for round_number in range(2, 4): + ....: middle_part_components.append(speck.get_components_in_round(round_number)) + sage: for round_number in range(4, 6): + ....: bottom_part_components.append(speck.get_components_in_round(round_number)) + sage: middle_part_components = list(itertools.chain(*middle_part_components)) + sage: bottom_part_components = list(itertools.chain(*bottom_part_components)) + sage: middle_part_components = [component.id for component in middle_part_components] + sage: bottom_part_components = [component.id for component in bottom_part_components] + sage: plaintext = set_fixed_variables( + ....: component_id='plaintext', + ....: constraint_type='equal', + ....: bit_positions=range(32), + ....: bit_values=integer_to_bit_list(0x02110a04, 32, 'big') + ....: ) + sage: key = set_fixed_variables( + ....: component_id='key', + ....: constraint_type='equal', + ....: bit_positions=range(64), + ....: bit_values=(0,) * 64 + ....: ) + sage: modadd_2_7 = set_fixed_variables( + ....: component_id='modadd_4_7', + ....: constraint_type='not_equal', + ....: bit_positions=range(4), + ....: bit_values=[0] * 4 + ....: ) + sage: ciphertext_difference = set_fixed_variables( + ....: component_id='cipher_output_5_12', + ....: constraint_type='equal', + ....: bit_positions=range(32), + ....: bit_values=integer_to_bit_list(0x02000201, 32, 'big') + ....: ) + sage: component_model_types = _generate_component_model_types(speck) + sage: _update_component_model_types_for_truncated_components(component_model_types, middle_part_components) + sage: _update_component_model_types_for_linear_components(component_model_types, bottom_part_components) + sage: sat_heterogeneous_model = SatDifferentialLinearModel(speck, component_model_types) + sage: trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( + ....: weight=8, fixed_values=[key, plaintext, modadd_2_7, ciphertext_difference], solver_name="CADICAL_EXT", num_unknown_vars=31 + ....: ) + sage: trail["status"] == 'SATISFIABLE' + True + + """ + start_time = time.time() + + self.build_xor_differential_linear_model(weight, num_unknown_vars) + constraints = self.fix_variables_value_constraints( + fixed_values, + self.regular_components, + self.truncated_components, + self.linear_components + ) + self.model_constraints.extend(constraints) + + solution = self.solve("XOR_REGULAR_DETERMINISTIC_DIFFERENTIAL", solver_name=solver_name) + solution['building_time_seconds'] = time.time() - start_time + solution['test_name'] = "find_one_regular_truncated_xor_differential_trail" + + return solution + + def find_lowest_weight_xor_differential_linear_trail( + self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + ): + """ + Finds the XOR regular truncated differential trail with the lowest weight. + + INPUT: + - ``fixed_values`` -- **list** (default: `[]`); specifies a list of variables that should be fixed to specific values. Each entry in the list should be a dictionary representing constraints for specific components, written in the CLAASP constraining syntax. + - ``solver_name`` -- **str** (default: ``solvers.SOLVER_DEFAULT``); The SAT solver to use. + + RETURN: + - **dict**; Solution with the trail and metadata (weight, time, memory usage). + """ + current_weight = 0 + start_building_time = time.time() + self.build_xor_regular_and_deterministic_truncated_differential_model(current_weight) + constraints = self.fix_variables_value_constraints( + fixed_values, self.regular_components, self.truncated_components + ) + self.model_constraints.extend(constraints) + end_building_time = time.time() + solution = self.solve("XOR_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name) + solution['building_time_seconds'] = end_building_time - start_building_time + total_time = solution['solving_time_seconds'] + max_memory = solution['memory_megabytes'] + while solution['total_weight'] is None: + current_weight += 1 + self.build_xor_regular_and_deterministic_truncated_differential_model(current_weight) + self.model_constraints.extend(constraints) + solution = self.solve("XOR_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name) + total_time += solution['solving_time_seconds'] + max_memory = max(max_memory, solution['memory_megabytes']) + + solution['solving_time_seconds'] = total_time + solution['memory_megabytes'] = max_memory + solution['test_name'] = "find_lowest_weight_xor_regular_truncated_differential_trail" + + return solution + + @property + def cipher(self): + """ + Returns the cipher instance associated with the model. + + RETURN: + - **object**; The cipher object being used in this model. + """ + return self._cipher diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_xor_linear_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_xor_linear_model.py index 872c0181..7268637a 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_xor_linear_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_xor_linear_model.py @@ -34,7 +34,8 @@ def __init__(self, cipher, counter='sequential', compact=False): super().__init__(cipher, counter, compact) self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join) - def branch_xor_linear_constraints(self): + @staticmethod + def branch_xor_linear_constraints(bindings): """ Return lists of variables and clauses for branch in XOR LINEAR model. @@ -52,7 +53,7 @@ def branch_xor_linear_constraints(self): sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: speck = SpeckBlockCipher(number_of_rounds=3) sage: sat = SatXorLinearModel(speck) - sage: sat.branch_xor_linear_constraints() + sage: SatXorLinearModel.branch_xor_linear_constraints(sat.bit_bindings) ['-plaintext_0_o rot_0_0_0_i', 'plaintext_0_o -rot_0_0_0_i', '-plaintext_1_o rot_0_0_1_i', @@ -62,7 +63,7 @@ def branch_xor_linear_constraints(self): 'xor_2_10_15_o -cipher_output_2_12_31_i'] """ constraints = [] - for output_bit, input_bits in self.bit_bindings.items(): + for output_bit, input_bits in bindings.items(): constraints.extend(utils.cnf_xor(output_bit, input_bits)) return constraints @@ -91,7 +92,7 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self._cipher, '_'.join) if fixed_variables == []: fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher) - constraints = self.fix_variables_value_xor_linear_constraints(fixed_variables) + constraints = SatXorLinearModel.fix_variables_value_xor_linear_constraints(fixed_variables) self._model_constraints = constraints component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION) operation_types = ("AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB") @@ -106,7 +107,7 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): self._variables_list.extend(variables) self._model_constraints.extend(constraints) - constraints = self.branch_xor_linear_constraints() + constraints = SatXorLinearModel.branch_xor_linear_constraints(self.bit_bindings) self._model_constraints.extend(constraints) if weight != -1: @@ -399,7 +400,8 @@ def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values return solution - def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]): + @staticmethod + def fix_variables_value_xor_linear_constraints(fixed_variables=[]): """ Return lists variables and clauses for fixing variables in XOR LINEAR model. @@ -428,7 +430,7 @@ def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]): ....: 'bit_positions': [0, 1, 2, 3], ....: 'bit_values': [1, 1, 1, 0] ....: }] - sage: sat.fix_variables_value_xor_linear_constraints(fixed_variables) + sage: SatXorLinearModel.fix_variables_value_xor_linear_constraints(fixed_variables) ['plaintext_0_o', '-plaintext_1_o', 'plaintext_2_o', diff --git a/claasp/cipher_modules/models/sat/utils/utils.py b/claasp/cipher_modules/models/sat/utils/utils.py index 18d0ca26..89c6ad49 100644 --- a/claasp/cipher_modules/models/sat/utils/utils.py +++ b/claasp/cipher_modules/models/sat/utils/utils.py @@ -659,6 +659,12 @@ def get_cnf_bitwise_truncate_constraints(a, a_0, a_1): ] +def get_cnf_truncated_linear_constraints(a, a_0): + return [ + f'-{a} -{a_0}' + ] + + def modadd_truncated_lsb(result, variable_0, variable_1, next_carry): return [f'{next_carry[0]} -{next_carry[1]}', f'{next_carry[0]} -{variable_1[1]}', @@ -819,3 +825,29 @@ def run_yices(solver_specs, options, dimacs_input, input_file_name): os.remove(input_file_name) return status, time, memory, values + + +def _generate_component_model_types(speck_cipher): + """Generates the component model types for a given Speck cipher.""" + component_model_types = [] + for component in speck_cipher.get_all_components(): + component_model_types.append({ + "component_id": component.id, + "component_object": component, + "model_type": "sat_xor_differential_propagation_constraints" + }) + return component_model_types + + +def _update_component_model_types_for_truncated_components(component_model_types, truncated_components): + """Updates the component model types for truncated components.""" + for component_model_type in component_model_types: + if component_model_type["component_id"] in truncated_components: + component_model_type["model_type"] = "sat_bitwise_deterministic_truncated_xor_differential_constraints" + + +def _update_component_model_types_for_linear_components(component_model_types, linear_components): + """Updates the component model types for linear components.""" + for component_model_type in component_model_types: + if component_model_type["component_id"] in linear_components: + component_model_type["model_type"] = "sat_xor_linear_mask_propagation_constraints" diff --git a/claasp/cipher_modules/models/utils.py b/claasp/cipher_modules/models/utils.py index 3ead6046..d15a87ac 100644 --- a/claasp/cipher_modules/models/utils.py +++ b/claasp/cipher_modules/models/utils.py @@ -22,6 +22,8 @@ import math from copy import deepcopy +import numpy as np + from claasp.name_mappings import CONSTANT, CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, WORD_OPERATION, LINEAR_LAYER, SBOX, MIX_COLUMN, \ INPUT_KEY, INPUT_PLAINTEXT, INPUT_MESSAGE, INPUT_STATE @@ -791,3 +793,90 @@ def get_related_key_scenario_format_for_fixed_values(_cipher): fixed_variables.append(fixed_variable) return fixed_variables + + +def _extract_bits(columns, positions): + """Extracts bits from columns at specified positions using vectorization.""" + bit_size = columns.shape[0] * 8 + positions = np.array(positions) + byte_indices = (bit_size - positions - 1) // 8 + bit_indices = positions % 8 + if np.any(byte_indices < 0) or np.any(byte_indices >= columns.shape[0]): + raise IndexError("Byte index out of range.") + bytes_at_positions = columns[byte_indices][:, :] + bits = (bytes_at_positions >> bit_indices[:, np.newaxis]) & 1 + + return bits + + +def _number_to_n_bit_binary_string(number, n_bits): + """Converts a number to an n-bit binary string with leading zero padding.""" + return format(number, f'0{n_bits}b') + + +def _extract_bit_positions(hex_number, state_size): + binary_str = _number_to_n_bit_binary_string(hex_number, state_size) + binary_str = binary_str[::-1] + positions = [i for i, bit in enumerate(binary_str) if bit == '1'] + return positions + + +def _repeat_input_difference(input_difference, num_samples, num_bytes): + """Function to repeat the input difference for a large sample size.""" + bytes_array = np.frombuffer(input_difference.to_bytes(num_bytes, 'big'), dtype=np.uint8) + repeated_array = np.broadcast_to(bytes_array[:, np.newaxis], (num_bytes, num_samples)) + return repeated_array + + +def differential_linear_checker_for_permutation( + cipher, input_difference, output_mask, number_of_samples, state_size +): + """ + This method helps to verify experimentally differential-linear distinguishers for permutations using the vectorized evaluator + """ + if state_size % 8 != 0: + raise ValueError("State size must be a multiple of 8.") + num_bytes = int(state_size/8) + + rng = np.random.default_rng() + input_difference_data = _repeat_input_difference(input_difference, number_of_samples, num_bytes) + plaintext1 = rng.integers(low=0, high=256, size=(num_bytes, number_of_samples), dtype=np.uint8) + plaintext2 = plaintext1 ^ input_difference_data + ciphertext1 = cipher.evaluate_vectorized([plaintext1]) + ciphertext2 = cipher.evaluate_vectorized([plaintext2]) + ciphertext3 = ciphertext1[0] ^ ciphertext2[0] + bit_positions_ciphertext = _extract_bit_positions(output_mask, state_size) + ccc = _extract_bits(ciphertext3.T, bit_positions_ciphertext) + parities = np.bitwise_xor.reduce(ccc, axis=0) + count = np.count_nonzero(parities == 0) + corr = 2*count/number_of_samples*1.0-1 + return corr + + +def differential_linear_checker_for_block_cipher_single_key( + cipher, input_difference, output_mask, number_of_samples, block_size, key_size, fixed_key +): + """ + This method helps to verify experimentally differential-linear distinguishers for block ciphers using the vectorized evaluator + """ + if block_size % 8 != 0: + raise ValueError("State size must be a multiple of 8.") + if key_size % 8 != 0: + raise ValueError("Key size must be a multiple of 8.") + state_num_bytes = int(block_size / 8) + key_num_bytes = int(key_size / 8) + + rng = np.random.default_rng() + fixed_key_data = _repeat_input_difference(fixed_key, number_of_samples, key_num_bytes) + input_difference_data = _repeat_input_difference(input_difference, number_of_samples, state_num_bytes) + plaintext1 = rng.integers(low=0, high=256, size=(state_num_bytes, number_of_samples), dtype=np.uint8) + plaintext2 = plaintext1 ^ input_difference_data + ciphertext1 = cipher.evaluate_vectorized([plaintext1, fixed_key_data]) + ciphertext2 = cipher.evaluate_vectorized([plaintext2, fixed_key_data]) + ciphertext3 = ciphertext1[0] ^ ciphertext2[0] + bit_positions_ciphertext = _extract_bit_positions(output_mask, block_size) + ccc = _extract_bits(ciphertext3.T, bit_positions_ciphertext) + parities = np.bitwise_xor.reduce(ccc, axis=0) + count = np.count_nonzero(parities == 0) + corr = 2*count/number_of_samples*1.0-1 + return corr diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_differential_linear_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_differential_linear_test.py new file mode 100644 index 00000000..db884772 --- /dev/null +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_differential_linear_test.py @@ -0,0 +1,291 @@ +from claasp.cipher_modules.models.sat.sat_models.sat_differential_linear_model import SatDifferentialLinearModel +from claasp.cipher_modules.models.sat.utils.utils import _generate_component_model_types, \ + _update_component_model_types_for_truncated_components, _update_component_model_types_for_linear_components +from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list, \ + differential_linear_checker_for_permutation, differential_linear_checker_for_block_cipher_single_key +from claasp.ciphers.block_ciphers.aradi_block_cipher_sbox import AradiBlockCipherSBox +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation +import itertools + + +def test_differential_linear_trail_with_fixed_weight_6_rounds_speck(): + """Test for finding a differential-linear trail with fixed weight for 6 rounds of Speck.""" + speck = SpeckBlockCipher(number_of_rounds=6) + middle_part_components = [] + bottom_part_components = [] + for round_number in range(2, 4): + middle_part_components.append(speck.get_components_in_round(round_number)) + for round_number in range(4, 6): + bottom_part_components.append(speck.get_components_in_round(round_number)) + + middle_part_components = list(itertools.chain(*middle_part_components)) + bottom_part_components = list(itertools.chain(*bottom_part_components)) + + middle_part_components = [component.id for component in middle_part_components] + bottom_part_components = [component.id for component in bottom_part_components] + + plaintext = set_fixed_variables( + component_id='plaintext', + constraint_type='equal', + bit_positions=range(32), + bit_values=integer_to_bit_list(0x02110a04, 32, 'big') + ) + + key = set_fixed_variables( + component_id='key', + constraint_type='equal', + bit_positions=range(64), + bit_values=(0,) * 64 + ) + + modadd_2_7 = set_fixed_variables( + component_id='modadd_4_7', + constraint_type='not_equal', + bit_positions=range(4), + bit_values=[0] * 4 + ) + + ciphertext_difference = set_fixed_variables( + component_id='cipher_output_5_12', + constraint_type='equal', + bit_positions=range(32), + bit_values=integer_to_bit_list(0x02000201, 32, 'big') + ) + + component_model_types = _generate_component_model_types(speck) + _update_component_model_types_for_truncated_components(component_model_types, middle_part_components) + _update_component_model_types_for_linear_components(component_model_types, bottom_part_components) + + sat_heterogeneous_model = SatDifferentialLinearModel(speck, component_model_types) + + trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( + weight=8, fixed_values=[key, plaintext, modadd_2_7, ciphertext_difference], solver_name="CADICAL_EXT", num_unknown_vars=31 + ) + assert trail["status"] == 'SATISFIABLE' + + +def test_differential_linear_trail_with_fixed_weight_3_rounds_chacha(): + """Test for finding a differential-linear trail with fixed weight for 3 rounds of ChaCha permutation.""" + chacha = ChachaPermutation(number_of_rounds=6) + import itertools + top_part_components = [] + middle_part_components = [] + bottom_part_components = [] + for round_number in range(1): + top_part_components.append(chacha.get_components_in_round(round_number)) + for round_number in range(1, 3): + middle_part_components.append(chacha.get_components_in_round(round_number)) + for round_number in range(3, 6): + bottom_part_components.append(chacha.get_components_in_round(round_number)) + + middle_part_components = list(itertools.chain(*middle_part_components)) + bottom_part_components = list(itertools.chain(*bottom_part_components)) + + middle_part_components = [component.id for component in middle_part_components] + bottom_part_components = [component.id for component in bottom_part_components] + + plaintext = set_fixed_variables( + component_id='plaintext', + constraint_type='equal', + bit_positions=range(512), + bit_values=integer_to_bit_list(0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000, 512, 'big') + ) + + cipher_output_5_24 = set_fixed_variables( + component_id='cipher_output_5_24', + constraint_type='equal', + bit_positions=range(512), + bit_values=integer_to_bit_list(0x00010000000100010000000100030003000000800000008000000000000001800000000000000001000000010000000201000101010000000000010103000101, 512, 'big') + ) + + modadd_3_15 = set_fixed_variables( + component_id=f'modadd_3_15', + constraint_type='not_equal', + bit_positions=range(32), + bit_values=[0] * 32 + ) + + component_model_types = _generate_component_model_types(chacha) + _update_component_model_types_for_truncated_components(component_model_types, middle_part_components) + _update_component_model_types_for_linear_components(component_model_types, bottom_part_components) + + sat_heterogeneous_model = SatDifferentialLinearModel(chacha, component_model_types) + + trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( + weight=5, fixed_values=[plaintext, modadd_3_15, cipher_output_5_24], solver_name="CADICAL_EXT", num_unknown_vars=511 + ) + assert trail["status"] == 'SATISFIABLE' + assert trail["total_weight"] <= 5 + + +def test_differential_linear_trail_with_fixed_weight_4_rounds_aradi(): + """Test for finding a differential-linear trail with fixed weight for 4 rounds of Aradi block cipher.""" + aradi = AradiBlockCipherSBox(number_of_rounds=4) + import itertools + top_part_components = [] + middle_part_components = [] + bottom_part_components = [] + for round_number in range(1): + top_part_components.append(aradi.get_components_in_round(round_number)) + for round_number in range(1, 3): + middle_part_components.append(aradi.get_components_in_round(round_number)) + for round_number in range(3, 4): + bottom_part_components.append(aradi.get_components_in_round(round_number)) + middle_part_components = list(itertools.chain(*middle_part_components)) + bottom_part_components = list(itertools.chain(*bottom_part_components)) + + middle_part_components = [component.id for component in middle_part_components] + bottom_part_components = [component.id for component in bottom_part_components] + + plaintext = set_fixed_variables( + component_id='plaintext', + constraint_type='equal', + bit_positions=range(128), + bit_values=integer_to_bit_list(0x00000000000080000000000000008000, 128, 'big') + ) + + cipher_output_3_86 = set_fixed_variables( + component_id='cipher_output_3_86', + constraint_type='equal', + bit_positions=range(128), + bit_values=integer_to_bit_list(0x90900120800000011010002000000000, 128, 'big') + ) + + key = set_fixed_variables( + component_id='key', + constraint_type='equal', + bit_positions=range(256), + bit_values=[0] * 256 + ) + + sbox_4_8 = set_fixed_variables( + component_id=f'sbox_3_8', + constraint_type='not_equal', + bit_positions=range(4), + bit_values=[0] * 4 + ) + + component_model_types = _generate_component_model_types(aradi) + _update_component_model_types_for_truncated_components(component_model_types, middle_part_components) + _update_component_model_types_for_linear_components(component_model_types, bottom_part_components) + + sat_heterogeneous_model = SatDifferentialLinearModel(aradi, component_model_types) + + trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( + weight=10, fixed_values=[key, plaintext, sbox_4_8, cipher_output_3_86], solver_name="CADICAL_EXT", num_unknown_vars=128-1 + ) + assert trail["status"] == 'SATISFIABLE' + assert trail["total_weight"] <= 10 + + +def test_differential_linear_trail_with_fixed_weight_4_rounds_chacha(): + """Test for finding a differential-linear trail with fixed weight for 4 rounds of ChaCha permutation.""" + chacha = ChachaPermutation(number_of_rounds=8) + + import itertools + + top_part_components = [] + middle_part_components = [] + bottom_part_components = [] + for round_number in range(2): + top_part_components.append(chacha.get_components_in_round(round_number)) + for round_number in range(2, 4): + middle_part_components.append(chacha.get_components_in_round(round_number)) + for round_number in range(4, 8): + bottom_part_components.append(chacha.get_components_in_round(round_number)) + + middle_part_components = list(itertools.chain(*middle_part_components)) + bottom_part_components = list(itertools.chain(*bottom_part_components)) + + middle_part_components = [component.id for component in middle_part_components] + bottom_part_components = [component.id for component in bottom_part_components] + + plaintext = set_fixed_variables( + component_id='plaintext', + constraint_type='equal', + bit_positions=range(512), + bit_values=integer_to_bit_list( + 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000088088780, + 512, + 'big' + ) + ) + + modadd_4_15 = set_fixed_variables( + component_id=f'modadd_4_15', + constraint_type='not_equal', + bit_positions=range(32), + bit_values=[0] * 32 + ) + + component_model_types = _generate_component_model_types(chacha) + _update_component_model_types_for_truncated_components(component_model_types, middle_part_components) + _update_component_model_types_for_linear_components(component_model_types, bottom_part_components) + + sat_heterogeneous_model = SatDifferentialLinearModel(chacha, component_model_types) + + trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( + weight=32, fixed_values=[plaintext, modadd_4_15], solver_name="CADICAL_EXT", num_unknown_vars=511 + ) + assert trail["status"] == 'SATISFIABLE' + assert trail["total_weight"] <= 32 + + +def test_diff_lin_chacha(): + """ + This test is verifying experimentally the test test_differential_linear_trail_with_fixed_weight_3_rounds_chacha + """ + input_difference = 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000 + output_mask = 0x00010000000100010000000100030003000000800000008000000000000001800000000000000001000000010000000201000101010000000000010103000101 + number_of_samples = 2 ** 12 + number_of_rounds = 6 + state_size = 512 + chacha = ChachaPermutation(number_of_rounds=number_of_rounds) + corr = differential_linear_checker_for_permutation( + chacha, input_difference, output_mask, number_of_samples, state_size + ) + import math + abs_corr = abs(corr) + assert abs(math.log(abs_corr, 2)) < 3 + + +def test_diff_lin_speck(): + """ + This test is verifying experimentally the test test_differential_linear_trail_with_fixed_weight_6_rounds_speck + """ + input_difference = 0x02110a04 + output_mask = 0x02000201 + number_of_samples = 2 ** 15 + number_of_rounds = 6 + fixed_key = 0x0 + speck = SpeckBlockCipher(number_of_rounds=number_of_rounds) + block_size = speck.inputs_bit_size[0] + key_size = speck.inputs_bit_size[1] + corr = differential_linear_checker_for_block_cipher_single_key( + speck, input_difference, output_mask, number_of_samples, block_size, key_size, fixed_key + ) + import math + abs_corr = abs(corr) + assert abs(math.log(abs_corr, 2)) <= 8 + + +def test_diff_lin_aradi(): + """ + This test is verifying experimentally the test test_differential_linear_trail_with_fixed_weight_4_rounds_aradi + """ + input_difference = 0x00000000000080000000000000008000 + output_mask = 0x90900120800000011010002000000000 + number_of_samples = 2 ** 12 + number_of_rounds = 4 + fixed_key = 0x90900120800000011010002000000000 + speck = AradiBlockCipherSBox(number_of_rounds=number_of_rounds) + block_size = speck.inputs_bit_size[0] + key_size = speck.inputs_bit_size[1] + corr = differential_linear_checker_for_block_cipher_single_key( + speck, input_difference, output_mask, number_of_samples, block_size, key_size, fixed_key + ) + import math + abs_corr = abs(corr) + print(corr, abs_corr, abs(math.log(abs_corr, 2))) + assert abs(math.log(abs_corr, 2)) < 8 diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_linear_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_linear_model_test.py index 21eeebef..d769ada1 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_linear_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_linear_model_test.py @@ -6,7 +6,8 @@ def test_branch_xor_linear_constraints(): speck = SpeckBlockCipher(number_of_rounds=3) sat = SatXorLinearModel(speck) - constraints = sat.branch_xor_linear_constraints() + + constraints = SatXorLinearModel.branch_xor_linear_constraints(sat.bit_bindings) assert constraints[0] == '-plaintext_0_o rot_0_0_0_i' assert constraints[1] == 'plaintext_0_o -rot_0_0_0_i' @@ -62,7 +63,7 @@ def test_fix_variables_value_xor_linear_constraints(): 'constraint_type': 'not_equal', 'bit_positions': [0, 1, 2, 3], 'bit_values': [1, 1, 1, 0]}] - constraints = sat.fix_variables_value_xor_linear_constraints(fixed_variables) + constraints = SatXorLinearModel.fix_variables_value_xor_linear_constraints(fixed_variables) assert constraints == ['plaintext_0_o', '-plaintext_1_o', 'plaintext_2_o', 'plaintext_3_o', '-ciphertext_0_o -ciphertext_1_o -ciphertext_2_o ciphertext_3_o']