-
Notifications
You must be signed in to change notification settings - Fork 1
/
ntt.py
61 lines (51 loc) · 1.84 KB
/
ntt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def transform_fast(vector, root, mod):
    """Compute the forward number-theoretic transform (NTT) of *vector*.

    Produces ``out[k] = sum(vector[j] * root**(j*k) for j in range(n)) % mod``
    using an iterative radix-2 Cooley-Tukey butterfly network.

    Args:
        vector: Sequence of integers whose length ``n`` is a power of 2.
            It is copied; the caller's sequence is left unmodified.
        root: Primitive ``n``-th root of unity modulo *mod* (not checked).
        mod: The modulus.

    Returns:
        A new list of ``n`` transformed values, each reduced modulo *mod*.

    Raises:
        ValueError: If ``len(vector)`` is not a power of 2 (including 0).
    """
    n = len(vector)
    # n & (n - 1) clears the lowest set bit, so it is 0 only for powers of 2.
    if n == 0 or n & (n - 1):
        raise ValueError("Length is not a power of 2")
    levels = n.bit_length() - 1

    # Work on a reduced copy: the caller's data is never mutated, and even the
    # n == 1 case (which skips every butterfly) returns values reduced mod *mod*.
    transform_vec = [x % mod for x in vector]

    # Twiddle factors root**0 .. root**(n/2 - 1), reduced modulo mod.
    powtable = []
    power = 1
    for _ in range(n // 2):
        powtable.append(power)
        power = power * root % mod

    # Permute into bit-reversed index order so the butterflies can run in place.
    for i in range(n):
        j = _reverse_bits(i, levels)
        if j > i:
            transform_vec[i], transform_vec[j] = transform_vec[j], transform_vec[i]

    size = 2
    while size <= n:
        halfsize = size // 2
        # Stage `size` uses powers of root**(n / size): stride through powtable.
        tablestep = n // size
        for start in range(0, n, size):
            for offset in range(halfsize):
                lo = start + offset
                hi = lo + halfsize
                left = transform_vec[lo]
                right = transform_vec[hi] * powtable[offset * tablestep]
                transform_vec[lo] = (left + right) % mod
                transform_vec[hi] = (left - right) % mod
        size *= 2
    return transform_vec


def _reverse_bits(x, bits):
    """Return the lowest *bits* bits of *x* in reversed order."""
    result = 0
    for _ in range(bits):
        result = (result << 1) | (x & 1)
        x >>= 1
    return result
def inverse_transform(vector, inv_root, inv_L, mod):
    """Compute the inverse number-theoretic transform of *vector*.

    Runs the forward transform with *inv_root* and then scales every entry
    by *inv_L* modulo *mod*.

    Args:
        vector: Sequence of integers whose length is a power of 2.
        inv_root: Root passed to the forward transform; for a true inverse it
            should be the modular inverse of the forward root.
        inv_L: Scale factor applied to every output; for a true inverse it
            should be the inverse of ``len(vector)`` modulo *mod*.
        mod: The modulus.

    Returns:
        A new list holding the scaled, inverse-transformed values.
    """
    spectrum = transform_fast(vector, inv_root, mod)
    return [entry * inv_L % mod for entry in spectrum]
def convolve_ntt(vec1, vec2, root, inv_root, inv_L, mod):
    """Compute the cyclic convolution of *vec1* and *vec2* via the NTT.

    Both inputs are transformed, multiplied pointwise modulo *mod*, and
    transformed back. No zero-padding is done, so the result is the
    length-``n`` *cyclic* convolution; pad both inputs to a power-of-2 length
    of at least ``len(a) + len(b) - 1`` beforehand to get a linear one.

    Args:
        vec1: First sequence of integers; its length must be a power of 2.
        vec2: Second sequence; must have the same length as *vec1*.
        root: Primitive ``n``-th root of unity modulo *mod*.
        inv_root: Modular inverse of *root*.
        inv_L: Modular inverse of ``n`` (the common length).
        mod: The modulus.

    Returns:
        A list of ``n`` convolution values reduced modulo *mod*.

    Raises:
        ValueError: If the vectors differ in length, or (from the transform)
            if the length is not a power of 2.
    """
    # zip() would silently truncate to the shorter spectrum, yielding garbage.
    if len(vec1) != len(vec2):
        raise ValueError("Vectors must have the same length")
    spectrum1 = transform_fast(vec1, root, mod)
    spectrum2 = transform_fast(vec2, root, mod)
    product = [a * b % mod for a, b in zip(spectrum1, spectrum2)]
    return inverse_transform(product, inv_root, inv_L, mod)