Skip to content

Commit

Permalink
Merge pull request #291 from Crypto-TII/feat/division_property
Browse files Browse the repository at this point in the history
Feat/division property
  • Loading branch information
peacker authored Dec 6, 2024
2 parents 28d11dd + 97936a8 commit cb2c6ab
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 80 deletions.
173 changes: 106 additions & 67 deletions claasp/cipher_modules/division_trail_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
from collections import Counter
from sage.rings.polynomial.pbori.pbori import BooleanPolynomialRing
from claasp.cipher_modules.graph_generator import create_networkx_graph_from_input_ids, _get_predecessors_subgraph
from claasp.cipher_modules.component_analysis_tests import binary_matrix_of_linear_component
from gurobipy import Model, GRB
import os

verbosity = False

class MilpDivisionTrailModel():
"""
Expand All @@ -36,6 +39,7 @@ class MilpDivisionTrailModel():
This module can only be used if the user possesses a Gurobi license.
"""

def __init__(self, cipher):
self._cipher = cipher
self._variables = None
Expand All @@ -44,13 +48,20 @@ def __init__(self, cipher):
self._used_variables = []
self._variables_as_list = []
self._unused_variables = []
self._used_predecessors_sorted = None
self._output_id = None
self._output_bit_index_previous_comp = None
self._block_needed = None
self._input_id_link_needed = None

def get_all_variables_as_list(self):
for component_id in list(self._variables.keys())[:-1]:
for bit_position in self._variables[component_id].keys():
for value in self._variables[component_id][bit_position].keys():
if value != "current":
self._variables_as_list.append(self._variables[component_id][bit_position][value].VarName)
varname = self._variables[component_id][bit_position][value].VarName
if varname not in self._variables_as_list: # rot and intermediate has the same name than original
self._variables_as_list.append(varname)

def get_unused_variables(self):
self.get_all_variables_as_list()
Expand Down Expand Up @@ -81,9 +92,7 @@ def build_gurobi_model(self):
model = Model("basic_model", env=env)
# model = Model()
model.Params.LogToConsole = 0
model.Params.Threads = 16 # best found experimentaly on ascon_sbox_2rounds
model.setParam("PoolSolutions", 1234) # 200000000
model.setParam(GRB.Param.PoolSearchMode, 2)
# model.Params.Threads = 16
self._model = model

def get_anfs_from_sbox(self, component):
Expand Down Expand Up @@ -183,7 +192,6 @@ def add_sbox_constraints(self, component):
x = B.variable_names()
anfs = self.get_anfs_from_sbox(component)
anfs = [B(anfs[i]) for i in range(component.input_bit_size)]
# print(anfs)

copy_monomials_deg = self.create_gurobi_vars_sbox(component, input_vars_concat)

Expand Down Expand Up @@ -215,6 +223,52 @@ def add_sbox_constraints(self, component):
self._model.addConstr(output_vars[index] >= constr)
self._model.update()

def create_copies_for_linear_layer(self, binary_matrix, input_vars_concat):
copies = {}
for index, var in enumerate(input_vars_concat):
column = [row[index] for row in binary_matrix]
number_of_1s = list(column).count(1)
if number_of_1s > 1:
current = 1
else:
current = 0
copies[index] = {}
copies[index][0] = var
copies[index]["current"] = current
self.set_as_used_variables([var])
new_vars = self._model.addVars(list(range(number_of_1s)), vtype=GRB.BINARY,
name="copy_" + var.VarName)
self._model.update()
for i in range(number_of_1s):
self._model.addConstr(var >= new_vars[i])
self._model.addConstr(
sum(new_vars[i] for i in range(number_of_1s)) >= var)
self._model.update()
for i in range(1, number_of_1s + 1):
copies[index][i] = new_vars[i - 1]
return copies

def add_linear_layer_constraints(self, component):
output_vars = self.get_output_vars(component)
input_vars_concat = self.get_input_vars(component)

if component.type == "linear_layer":
binary_matrix = component.description
else:
binary_matrix = binary_matrix_of_linear_component(component)

copies = self.create_copies_for_linear_layer(binary_matrix, input_vars_concat)
for index_row, row in enumerate(binary_matrix):
constr = 0
for index_bit, bit in enumerate(row):
if bit:
current = copies[index_bit]["current"]
constr += copies[index_bit][current]
copies[index_bit]["current"] += 1
self.set_as_used_variables([copies[index_bit][current]])
self._model.addConstr(output_vars[index_row] == constr)
self._model.update()

def add_xor_constraints(self, component):
output_vars = self.get_output_vars(component)

Expand All @@ -230,20 +284,15 @@ def add_xor_constraints(self, component):
else:
input_vars_concat.append(self._variables[input_name][pos][current])
self._variables[input_name][pos]["current"] += 1
# print(input_vars_concat)

block_size = component.output_bit_size
nb_blocks = component.description[1]
if constant_flag != []:
nb_blocks -= 1
# print(self._occurences[component.id])
# print(list(self._occurences[component.id].keys()))
# print(len(list(self._occurences[component.id].keys())))
for index, bit_pos in enumerate(list(self._occurences[component.id].keys())):
constr = 0
for j in range(nb_blocks):
constr += input_vars_concat[index + block_size * j]
# print(input_vars_concat[index + block_size * j])
self.set_as_used_variables([input_vars_concat[index + block_size * j]])
if (constant_flag != []) and (constant_flag[index]):
self._model.addConstr(output_vars[index] >= constr)
Expand Down Expand Up @@ -372,12 +421,14 @@ def add_constraints(self, predecessors, input_id_link_needed, block_needed):
self.create_gurobi_vars_from_all_components(predecessors, input_id_link_needed, block_needed)

used_predecessors_sorted = self.order_predecessors(list(self._occurences.keys()))
self._used_predecessors_sorted = used_predecessors_sorted
for component_id in used_predecessors_sorted:
if component_id not in self._cipher.inputs:
component = self._cipher.get_component_from_id(component_id)
print(f"---------> {component.id}")
if component.type == "sbox":
self.add_sbox_constraints(component)
elif component.type in ["linear_layer", "mix_column"]:
self.add_linear_layer_constraints(component)
elif component.type in ["cipher_output", "constant", "intermediate_output"]:
continue
elif component.type == "word_operation":
Expand Down Expand Up @@ -414,13 +465,9 @@ def get_where_component_is_used(self, predecessors, input_id_link_needed, block_
component = self._cipher.get_component_from_id(input_id_link_needed)
occurences[input_id_link_needed] = [[i for i in range(component.output_bit_size)]]

# print("occurences")
# print(occurences)
occurences_final = {}
for component_id in occurences.keys():
occurences_final[component_id] = self.find_copy_indexes(occurences[component_id])
# print("occurences_final")
# print(occurences_final)

self._occurences = occurences_final
return occurences_final
Expand Down Expand Up @@ -462,13 +509,10 @@ def create_gurobi_vars_from_all_components(self, predecessors, input_id_link_nee
occurences = self.get_where_component_is_used(predecessors, input_id_link_needed, block_needed)
all_vars = {}
used_predecessors_sorted = self.order_predecessors(list(occurences.keys()))
print("used_predecessors_sorted")
print(used_predecessors_sorted)
for component_id in used_predecessors_sorted:
all_vars[component_id] = {}
# We need the inputs vars to be the first ones defined by gurobi in order to find their values with X.values method.
# That's why we split the following loop: we first created the original vars, and then the copies vars when necessary.
# print(f"###### {component_id}")
if component_id[:3] == "rot":
component = self._cipher.get_component_from_id(component_id)
rotate_offset = component.description[1]
Expand Down Expand Up @@ -571,20 +615,15 @@ def get_output_bit_index_previous_component(self, output_bit_index_ciphertext, c
block_needed = comp.input_bit_positions[index]
input_id_link_needed = chosen_cipher_output
output_bit_index_previous_comp = output_bit_index_ciphertext
print(output_id)
print(block_needed)
print(input_id_link_needed)
print(output_bit_index_previous_comp)
return output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot
else:
output_id = self.get_cipher_output_component_id()
# output_id = "xor_1_69"
component = self._cipher.get_component_from_id(output_id)
pivot = 0
output_bit_index_previous_comp = output_bit_index_ciphertext
for index, block in enumerate(component.input_bit_positions):
if pivot <= output_bit_index_ciphertext < pivot + len(block):
output_bit_index_previous_comp = output_bit_index_ciphertext - pivot
output_bit_index_previous_comp = block[output_bit_index_ciphertext - pivot]
block_needed = block
input_id_link_needed = component.input_id_links[index]
break
Expand All @@ -609,31 +648,28 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex
output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot = self.get_output_bit_index_previous_component(
output_bit_index_ciphertext, chosen_cipher_output)

self._output_id = output_id
self._output_bit_index_previous_comp = output_bit_index_previous_comp
self._block_needed = block_needed
self._input_id_link_needed = input_id_link_needed

G = create_networkx_graph_from_input_ids(self._cipher)
predecessors = list(_get_predecessors_subgraph(G, [input_id_link_needed]))
for input_id in self._cipher.inputs + ['']:
if input_id in predecessors:
predecessors.remove(input_id)

# print("input_id_link_needed")
# print(input_id_link_needed)
# print("predecessors")
# print(predecessors)
self.add_constraints(predecessors, input_id_link_needed, block_needed)

var_from_block_needed = []
for i in block_needed:
var_from_block_needed.append(self._variables[input_id_link_needed][i][0])
# print("var_from_block_needed")
# print(var_from_block_needed)

output_vars = self._model.addVars(list(range(pivot, pivot + len(block_needed))), vtype=GRB.BINARY,
name=output_id)
self._variables[output_id] = output_vars
output_vars = list(output_vars.values())
self._model.update()
# print("output_vars")
# print(output_vars)

for i in range(len(block_needed)):
self._model.addConstr(output_vars[i] == var_from_block_needed[i])
Expand All @@ -654,10 +690,10 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex

self.set_unused_variables_to_zero()
self._model.update()
self._model.write("division_trail_model.lp")
end = time.time()
building_time = end - start
print(f"########## building_time : {building_time}")
if verbosity:
print(f"########## building_time : {building_time}")
self._model.update()

def get_solutions(self):
Expand All @@ -676,7 +712,6 @@ def get_solutions(self):
first_input_bit_positions = list(self._occurences[self._cipher.inputs[0]].keys())

solCount = self._model.SolCount
print('Number of solutions (might cancel each other) found: ' + str(solCount))
monomials = []
for sol in range(solCount):
self._model.setParam(GRB.Param.SolutionNumber, sol)
Expand All @@ -695,67 +730,80 @@ def get_solutions(self):
else:
if index < len(list(self._occurences[self._cipher.inputs[0]].keys())):
tmp += self._cipher.inputs[0][0] + str(first_input_bit_positions[index])
if 1 not in values[:max_input_bit_pos]:
tmp += str(1)
else:
if nb_inputs_used == 1:
input1_prefix = self._cipher.inputs[0][0]
l = tmp.split(input1_prefix)[1:]
sorted_l = sorted(l, key=lambda x: (x == '', int(x) if x else 0))
l = [''] + sorted_l
tmp = input1_prefix.join(l)

if tmp in monomials:
monomials.remove(tmp)
else:
monomials.append(tmp)

end = time.time()
printing_time = end - start
print(f"########## printing_time : {printing_time}")
print(monomials)
print(f'Number of monomials found: {len(monomials)}')
if verbosity:
print('Number of solutions (might cancel each other) found: ' + str(solCount))
print(f"########## printing_time : {printing_time}")
print(f'Number of monomials found: {len(monomials)}')
return monomials

def optimize_model(self):
print(self._model)
start = time.time()
self._model.optimize()
end = time.time()
solving_time = end - start
print(f"########## solving_time : {solving_time}")
if verbosity:
print(self._model)
print(f"########## solving_time : {solving_time}")

def find_anf_of_specific_output_bit(self, output_bit_index, fixed_degree=None, chosen_cipher_output=None):
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)

# # Specific to Aradi analysis:
# for i in range(96):
# v = self._model.getVarByName(f"plaintext[{i}]")
# self._model.addConstr(v == 0)
# self._model.update()
# self._model.write("division_trail_model.lp")
# ########################
self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large
self._model.setParam(GRB.Param.PoolSearchMode, 2)
self._model.write("division_trail_model.lp")

self.optimize_model()
self.get_solutions()
return self.get_solutions()

def check_presence_of_particular_monomial_in_specific_anf(self, monomial, output_bit_index, fixed_degree=None,
chosen_cipher_output=None):
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)
self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large
self._model.setParam(GRB.Param.PoolSearchMode, 2)

for term in monomial:
var_term = self._model.getVarByName(f"{term[0]}[{term[1]}]")
self._model.addConstr(var_term == 1)
self._model.update()
self._model.write("division_trail_model.lp")

self.optimize_model()
self.get_solutions()
return self.get_solutions()

def check_presence_of_particular_monomial_in_all_anf(self, monomial, fixed_degree=None, chosen_cipher_output=None):
def check_presence_of_particular_monomial_in_all_anf(self, monomial, fixed_degree=None,
chosen_cipher_output=None):
s = ""
for term in monomial:
s += term[0][0] + str(term[1])
for i in range(self._cipher.output_bit_size):
print(f"\nSearch of {s} in anf {i} :")
self.check_presence_of_particular_monomial_in_specific_anf(monomial, i, fixed_degree, chosen_cipher_output)
self.check_presence_of_particular_monomial_in_specific_anf(monomial, i, fixed_degree,
chosen_cipher_output)

def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_output=None):
fixed_degree = None
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)
self._model.setParam(GRB.Param.PoolSearchMode, 1)
self._model.setParam('Presolve', 2)
self._model.setParam('MIPFocus', 3)
# self._model.setParam('Cuts', 2)
self._model.setParam('NodefileStart', 2.0)
self._model.setParam("MIPFocus", 2)
self._model.setParam("MIPGap", 0) # when set to 0, best solution = optimal solution
self._model.setParam('Cuts', 2)

index_plaintext = self._cipher.inputs.index("plaintext")
plaintext_bit_size = self._cipher.inputs_bit_size[index_plaintext]
Expand All @@ -765,19 +813,10 @@ def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_out
p.append(self._model.getVarByName(f"plaintext[{i}]"))
self._model.setObjective(sum(p[i] for i in range(nb_plaintext_bits_used)), GRB.MAXIMIZE)

## Specific to Aradi analysis:
# for i in range(128):
# v = self._model.getVarByName(f"plaintext[{i}]")
# if 0 <= i < 128: # free vars
# self._model.addConstr(v >= 0)
# else:
# self._model.addConstr(v == 0)
# self._model.update()
# self._model.write("division_trail_model.lp")
#######################

self._model.update()
self._model.write("division_trail_model.lp")
self.optimize_model()
# get degree

degree = self._model.getObjective().getValue()
return degree

Expand Down
Loading

0 comments on commit cb2c6ab

Please sign in to comment.