"""AES-GCM analysis toolkit.

Contents, in order:

* arithmetic in GF(2^128) modulo the GCM polynomial;
* polynomials over GF(2^128) and their factorisation;
* a reference AES-GCM built on AES-ECB;
* the "forbidden" (nonce reuse) attack;
* the MAC truncation attack;
* two-key ciphertext collisions (JPG/BMP polyglot);
* demos.

NOTE(review): the copy of this file that was reviewed had lost its line
structure and two spans of text (inside ``bytes_to_gf128``/``gf128_to_bytes``
and between ``mac_truncation_recover_secrets`` and ``encode_lengths``).  Those
spans were reconstructed from the surrounding code and are marked with
``NOTE(review)`` below; confirm them against the pristine source.
"""
from binascii import unhexlify
import hmac
import itertools
import math
import random
import secrets
import struct

from Crypto.Cipher import AES
import cryptography
import cryptography.exceptions
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import numpy as np
import xarray as xr


## Computation in GF(2^128)/(x^128 + x^7 + x^2 + x^1 + 1)
## Elements are represented as an integer n where (n & (1 << i)) is the coefficient for x^i.

class X:
    """A helper class for specifying elements of GF(2^n)."""

    def __pow__(self, y):
        return 1 << y


x = X()
gcm_modulus = x**128 + x**7 + x**2 + x**1 + 1
gcm_modulus_degree = 128


def gf128_deg(x):
    """Compute the degree of x.

    >>> assert gf128_deg(x**20 + x**5 + 1) == 20
    >>> assert gf128_deg(x**90) == 90
    >>> assert gf128_deg(x**1) == 1
    >>> assert gf128_deg(1) == 0
    >>> assert gf128_deg(0) == -1
    """
    return x.bit_length() - 1


def gf128_add(a, b):
    """Compute a + b.

    >>> assert gf128_add(0, 0) == 0
    >>> assert gf128_add(x**7 + x**2, 0) == x**7 + x**2
    >>> assert gf128_add(x**7 + x**2, x**2 + x**1) == x**7 + x**1
    >>> assert gf128_add(x**7 + x**2, x**7 + x**2) == 0
    """
    return a ^ b


def gf128_mul(a, b):
    """Compute ab.

    The running multiple of ``b`` is reduced by ``gcm_modulus`` only at the
    moment its degree reaches 128, so operands are expected to have degree
    below 128.

    >>> assert gf128_mul(x**20 + x**5 + 1, x**13 + x**2) == x**33 + x**22 + x**18 + x**13 + x**7 + x**2
    >>> assert gf128_mul(x**80, x**90) == x**49 + x**44 + x**43 + x**42
    >>> assert gf128_mul(x**23 + x**5, 1) == x**23 + x**5
    """
    if a == 0 or b == 0:
        return 0
    if a > b:
        # Iterate over the bits of the smaller operand.
        return gf128_mul(b, a)
    p = 0
    deg_b = gf128_deg(b)
    while a > 0:
        if a & 1:
            p ^= b
        a = a >> 1
        b = b << 1
        deg_b += 1
        if deg_b == gcm_modulus_degree:
            b = b ^ gcm_modulus
            deg_b = gf128_deg(b)
    return p


def gf128_divmod(a, b):
    """Compute q, r such that a = qb + r; gf128_deg(r) < gf128_deg(b).

    Raises ZeroDivisionError when b is 0 (the loop below would otherwise
    never terminate).

    >>> assert gf128_divmod(x**20 + x**5 + 1, x**13 + x**2) == (x**7, x**9 + x**5 + 1)
    >>> assert gf128_divmod(x**90, x**80) == (x**10, 0)
    """
    if b == 0:
        raise ZeroDivisionError('polynomial division by zero')
    deg_b = gf128_deg(b)
    q, r = 0, a
    while gf128_deg(r) >= deg_b:
        d = gf128_deg(r) - deg_b
        q = q ^ (1 << d)
        r = r ^ (b << d)
    return q, r


def gf128_mod(a, m=gcm_modulus):
    """Compute b such that a = b (mod m); gf128_deg(b) < gf128_deg(m).

    >>> assert gf128_mod(x**220 + x**5 + 1, x**50 + x**4) == x**36 + x**5 + 1
    >>> assert gf128_mod(x**130 + x**64 + x**2) == x**64 + x**9 + x**4 + x**3
    """
    q, r = gf128_divmod(a, m)
    return r


def gf128_egcd(a, b):
    """Compute g, x, y such that g | a, g | b, g = ax + by; for any h such that h | a and h | b, gf128_deg(g) >= gf128_deg(h).

    >>> assert gf128_egcd(x**2 + 1, x**1 + 1) == (x**1 + 1, 0, 1)
    >>> assert gf128_egcd(x**12 + x**4, x**5 + 1) == (x**1 + 1, x**3 + x**2 + 1, x**10 + x**9 + x**7 + x**5 + x**4 + x**1 + 1)
    >>> assert gf128_egcd(x**64 + x**2, x**37 + x**12 + 1) == (1, x**35 + x**34 + x**33 + x**31 + x**30 + x**28 + x**27 + x**25 + x**24 + x**21 + x**20 + x**18 + x**17 + x**15 + x**14 + x**12 + x**11 + x**9 + x**7 + x**6 + x**4 + x**3 + x**1 + 1, x**62 + x**61 + x**60 + x**58 + x**57 + x**55 + x**54 + x**52 + x**51 + x**48 + x**47 + x**45 + x**44 + x**42 + x**41 + x**39 + x**38 + x**37 + x**35 + x**34 + x**32 + x**31 + x**29 + x**28 + x**26 + x**25 + x**24 + x**22 + x**21 + x**19 + x**18 + x**16 + x**15 + x**13 + x**12 + x**11 + x**9 + x**8 + x**6 + x**5 + x**3 + x**2 + 1)
    """
    orig_a = a
    orig_b = b
    a_x = 1
    a_y = 0
    b_x = 0
    b_y = 1
    # Invariant = a_x*x + a_y*y = a, b_x*x + b_y*y = b, r_x*r + r_y*y = r
    if a == 0:
        # gcd(0, b) = b = 0*a + 1*b
        return b, 0, 1
    # If a already divides b the loop exits on its first iteration, so the
    # cofactors to report are those of a itself: a = 1*a + 0*b.  (Previously
    # (0, 1) was returned in that case, which made e.g. gf128_inv(1) fail.)
    r_x = a_x
    r_y = a_y
    while True:
        q, r = gf128_divmod(b, a)
        if r == 0:
            # FIX mul so mod isn't necessary; with zeros
            assert a == gf128_mod(gf128_add(gf128_mul(orig_a, r_x), gf128_mul(orig_b, r_y)))
            return a, r_x, r_y
        r_x = gf128_add(b_x, gf128_mul(q, a_x))
        r_y = gf128_add(b_y, gf128_mul(q, a_y))
        a, b = r, a
        b_x, b_y = a_x, a_y
        a_x, a_y = r_x, r_y


def gf128_inv(a):
    """Compute b such that ab = 1.

    >>> assert gf128_inv(x**1) == x**127 + x**6 + x**1 + 1
    >>> assert gf128_inv(x**4) == x**127 + x**125 + x**124 + x**6 + x**4 + x**3 + x**1 + 1
    >>> assert gf128_inv(1) == 1
    """
    g, xx, yy = gf128_egcd(a, gcm_modulus)
    assert g == 1
    return xx


def gf128_exp(a, n):
    """Compute a^n.

    >>> assert gf128_exp(x**2 + 1, 0) == 1
    >>> assert gf128_exp(x**2 + 1, 1) == x**2 + 1
    >>> assert gf128_exp(x**2 + 1, 2) == x**4 + 1
    >>> assert gf128_exp(x**2 + 1, 10) == x**20 + x**16 + x**4 + 1
    """
    if n == 0:
        return 1
    # Square-and-multiply, recursing on the high bits of n.
    rec = gf128_exp(a, n//2)
    rec = gf128_mul(rec, rec)
    if n % 2 == 1:
        rec = gf128_mul(rec, a)
    return rec


def gf128_sqrt(x):
    """Compute y such that y^2 = x: the inverse Frobenius automorphism.

    Reference: https://math.stackexchange.com/questions/943417/square-root-for-galois-fields-gf2m

    >>> assert gf128_sqrt(1) == 1
    >>> assert gf128_sqrt(x**2 + 1) == x**1 + 1
    >>> assert gf128_sqrt(x**3 + 1) == x**126 + x**123 + x**120 + x**117 + x**114 + x**111 + x**108 + x**105 + x**102 + x**99 + x**96 + x**93 + x**90 + x**87 + x**84 + x**81 + x**78 + x**75 + x**72 + x**69 + x**66 + x**63 + x**62 + x**60 + x**59 + x**57 + x**56 + x**54 + x**53 + x**51 + x**50 + x**48 + x**47 + x**45 + x**44 + x**42 + x**41 + x**39 + x**38 + x**36 + x**35 + x**33 + x**32 + x**30 + x**29 + x**27 + x**26 + x**24 + x**23 + x**21 + x**20 + x**18 + x**17 + x**15 + x**14 + x**12 + x**11 + x**9 + x**8 + x**6 + x**3 + 1
    """
    return gf128_exp(x, 2**127)


def gf128_to_vec(x):
    """Return the 128 coefficients of x as a list of 0/1 ints, x^0 first."""
    return [int(n) for n in bin(x)[2:].zfill(128)[::-1]]


def vec_to_gf128(vs):
    """Inverse of gf128_to_vec: entry i of vs (if truthy) sets the x^i coefficient."""
    x = 0
    for i, v in enumerate(vs):
        if v:
            x += (1 << i)
    return x


def reverse_mask(b):
    """Reverse the bit order of the byte b."""
    # from https://stackoverflow.com/a/2602885
    b = (b & 0xF0) >> 4 | (b & 0x0F) << 4
    b = (b & 0xCC) >> 2 | (b & 0x33) << 2
    b = (b & 0xAA) >> 1 | (b & 0x55) << 1
    return b


def bytes_to_gf128(xs, st=None, en=None):
    """Convert the 16-byte block xs[st:en] to a field element.

    With neither bound given, the first 16 bytes are used.

    NOTE(review): everything after ``struct.unpack('`` was missing from the
    reviewed copy.  Reconstructed so that the most significant bit of byte 0
    is the coefficient of x^0, which is the convention gen_blocks and the tag
    decoding in mac_truncation_recover_secrets rely on.  Confirm against the
    pristine source.
    """
    if st is None and en is None:
        st = 0
        en = 16
    # Reverse the bits inside every byte, then read two little-endian words.
    a, b = struct.unpack('<QQ', bytes(reverse_mask(v) for v in xs[st:en]))
    return a | (b << 64)


def gf128_to_bytes(x):
    """Convert a field element to its 16-byte block (inverse of bytes_to_gf128).

    NOTE(review): the signature and the loop head were missing from the
    reviewed copy; only the tail from ``>= 8`` onwards survived.  The
    parameter name is a guess (all visible callers pass it positionally).
    """
    s = b''
    while x:
        s += bytes([reverse_mask(x & 0xFF)])
        x >>= 8
    s += b'\x00' * (16 - len(s))
    assert len(s) == 16
    return s


def Ms():
    """Return the 128x128 0/1 matrix whose column i is x^(2i) reduced mod gcm_modulus.

    Column i is the square of the basis element x^i, so this is the matrix of
    the squaring map on coefficient vectors.
    """
    vs = []
    for i in range(0, 128):
        v = gf128_to_vec(gf128_mod(1 << (2*i)))
        vs.append(v)
    return np.array(vs).transpose()


def Mc(c):
    """Return the 128x128 0/1 matrix whose column i is c * x^i (multiplication by c)."""
    vs = []
    pr = c
    for i in range(0, 128):
        vs.append(gf128_to_vec(pr))
        pr = gf128_mul(x**1, pr)
    return np.array(vs).transpose()


# Adapted from Wiki
def rref_mod_2(M):
    """Row-reduce a copy of M over GF(2).

    Returns (R, adj): the reduced matrix and the matrix of row operations
    applied (starting from the identity), so each row of adj records which
    original rows were XORed to produce the corresponding row of R.
    """
    M = M.copy().astype('int')
    lead = 0
    rowCount, columnCount = M.shape
    adj = np.eye(rowCount, rowCount).astype('int')
    for r in range(rowCount):
        if columnCount <= lead:
            return M, adj
        i = r
        # Find a row at or below r with a 1 in the current lead column,
        # moving to the next column when none exists.
        while M[i][lead] == 0:
            i = i + 1
            if rowCount == i:
                i = r
                lead = lead + 1
                if columnCount == lead:
                    return M, adj
        if i != r:
            M[[i, r]] = M[[r, i]]
            adj[[i, r]] = adj[[r, i]]
        for j in range(rowCount):
            if M[j][lead] == 1 and j != r:
                M[j] ^= M[r]
                adj[j] ^= adj[r]
        lead = lead + 1
    return M, adj


def kernel(M, f):
    """Compute a basis of the kernel of M using the row-reduction routine f.

    f is applied to M transposed; every all-zero row of the result
    contributes the matching row of the operation matrix to the basis.
    Returns (reduced, adj, basis).
    """
    mt = M.transpose()
    N, adj = f(mt)
    basis = []
    for i, row in enumerate(N):
        if np.all(np.isclose(row, 0)):
            basis.append(adj[i])
    return N, adj, basis


## Computation in GF(2^128)[X]/(x^128 + x^7 + x^2 + x^1 + 1)
## Elements are represented as arrays where the ith element is the coefficient for x^i.

def gf128poly_lead(x):
    """Compute the leading coefficient of x.

    >>> assert gf128poly_lead([x**3, x**1, 1]) == 1
    >>> assert gf128poly_lead([x**3, x**1, x**2, x**4]) == x**4
    >>> assert gf128poly_lead([1]) == 1
    """
    assert len(x) > 0
    return x[-1]


def collapse(x):
    """Strip zero leading coefficients and reduce every coefficient mod gcm_modulus."""
    if len(x) == 0:
        return []
    if gf128poly_lead(x) == 0:
        return collapse(x[:-1])
    return [gf128_mod(y, gcm_modulus) for y in x]


def gf128poly_deg(x):
    """Compute the degree of x.

    >>> assert gf128poly_deg([x**3, x**1, 1]) == 2
    >>> assert gf128poly_deg([x**2, x**3]) == 1
    >>> assert gf128poly_deg([x**15]) == 0
    >>> assert gf128poly_deg([1]) == 0
    >>> assert gf128poly_deg([]) == -1
    """
    return len(x) - 1


def gf128poly_add(a, b):
    """Compute a + b.

    >>> assert gf128poly_add([x**3, x**1, 1], [x**2 + 1, x**3]) == [x**3 + x**2 + 1, x**1 + x**3, 1]
    >>> assert gf128poly_add([x**2, x**3], [x**5, x**12, x**2, x**4]) == [x**5 + x**2, x**12 + x**3, x**2, x**4]
    >>> assert gf128poly_add([x**15], []) == [x**15]
    >>> assert gf128poly_add([x**15], [x**15]) == []
    """
    return collapse([gf128_add(u, v) for u, v in itertools.zip_longest(a, b, fillvalue=0)])


def gf128poly_mul(a, b):
    """Compute ab.

    >>> assert gf128poly_mul([x**3, x**1, 1], [x**2 + 1, x**3]) == [x**5 + x**3, x**6 + x**3 + x**1, x**4 + x**2 + 1, x**3]
    >>> assert gf128poly_mul([x**2, x**3], [x**5, x**12, x**2, x**4]) == [x**7, x**14 + x**8, x**15 + x**4, x**6 + x**5, x**7]
    >>> assert gf128poly_mul([x**15], []) == []
    >>> assert gf128poly_mul([x**15], [1]) == [x**15]
    >>> assert gf128poly_mul([0, x**15], [0, x**15]) == [0, 0, x**30]
    >>> assert gf128poly_mul([x**70], [1, x**70, 1]) == [x**70, x**19 + x**14 + x**13 + x**12, x**70]
    """
    p = []
    # Schoolbook multiplication: consume a from the constant term upwards
    # while b is shifted up by one power of X each round.
    while gf128poly_deg(a) >= 0:
        p = gf128poly_add(p, [gf128_mul(coeff, a[0]) for coeff in b])
        a = a[1:]
        b = [0] + b
    return collapse(p)


def gf128poly_divmod(a, b):
    """Compute q, r such that a = qb + r; gf128poly_deg(r) < gf128poly_deg(b).

    >>> assert gf128poly_divmod([x**5 + x**3, x**6 + x**3 + x**1, x**4 + x**2 + 1, x**3], [x**2 + 1, x**3]) == ([x**3, x**1, 1], [])
    >>> assert gf128poly_divmod([x**7, x**14 + x**8, x**15 + x**4, x**6 + x**5, x**7], [x**2, x**3]) == ([x**5, x**12, x**2, x**4], [])
    >>> assert gf128poly_divmod([x**15], [1]) == ([x**15], [])
    >>> assert gf128poly_divmod([x**15], [x**15]) == ([1], [])
    >>> assert gf128poly_divmod([0, 0, x**30], [0, x**15]) == ([0, x**15], [])
    >>> assert gf128poly_divmod([x**70, x**19 + x**14 + x**13 + x**12, x**70], [1, x**70, 1]) == ([x**70], [])
    >>> assert gf128poly_divmod([x**18, x**18, x**3], [x**3, x**3]) == ([x**15 + 1, 1], [x**3])
    """
    assert b != [], 'divide by zero'
    q = []
    while gf128poly_deg(a) >= gf128poly_deg(b):
        highA = gf128poly_lead(a)
        highB = gf128poly_lead(b)
        highBinv = 1 if highB == 1 else gf128_inv(highB)
        qq = gf128_mul(highA, highBinv)
        de = gf128poly_deg(a) - gf128poly_deg(b)
        # new_term = qq * X^de cancels the current leading term of a.
        new_term = [0]*(de+1)
        new_term[de] = qq
        q = gf128poly_add(q, new_term)
        fact = gf128poly_mul(b, new_term)
        a = gf128poly_add(a, fact)
    return collapse(q), collapse(a)


def gf128poly_sqrt(p):
    """Compute q such that q^2 = p for p such that p' = 0.

    >>> assert gf128poly_sqrt([1, 0, 1, 0, 1, 0, 1]) == [1, 1, 1, 1]
    >>> assert gf128poly_sqrt([1, 0, 1, 0, 0, 0, 1]) == [1, 1, 0, 1]
    >>> assert gf128poly_sqrt([1, 0, x**4]) == [1, x**2]
    >>> assert gf128poly_sqrt([1, 0, x**3]) == [1, x**126 + x**123 + x**120 + x**117 + x**114 + x**111 + x**108 + x**105 + x**102 + x**99 + x**96 + x**93 + x**90 + x**87 + x**84 + x**81 + x**78 + x**75 + x**72 + x**69 + x**66 + x**63 + x**62 + x**60 + x**59 + x**57 + x**56 + x**54 + x**53 + x**51 + x**50 + x**48 + x**47 + x**45 + x**44 + x**42 + x**41 + x**39 + x**38 + x**36 + x**35 + x**33 + x**32 + x**30 + x**29 + x**27 + x**26 + x**24 + x**23 + x**21 + x**20 + x**18 + x**17 + x**15 + x**14 + x**12 + x**11 + x**9 + x**8 + x**6 + x**3]
    """
    assert gf128poly_formal_derivative(p) == []
    # Take the field square root of every coefficient (0 and 1 are fixed).
    q = [c if c == 0 or c == 1 else gf128_sqrt(c) for c in p]
    # Only even powers are present; move the coefficient of X^i to X^(i/2).
    for i in range(2, len(q), 2):
        q[i//2] = q[i]
        q[i] = 0
    return collapse(q)


def gf128poly_monic(p):
    """Compute p/p_lead for p such that gf128poly_deg(p) >= 0.

    >>> assert gf128poly_monic([x**54, x**2]) == [x**52, 1]
    >>> assert gf128poly_monic([x**2, x**3]) == [x**127 + x**6 + x**1 + 1, 1]
    """
    assert len(p) > 0
    lead = gf128poly_lead(p)
    if lead == 1:
        return p
    lead_inv = gf128_inv(lead)
    return [gf128_mul(coeff, lead_inv) for coeff in p]


def gf128poly_formal_derivative(p):
    """Compute p'.

    >>> assert gf128poly_formal_derivative([]) == []
    >>> assert gf128poly_formal_derivative([x**25]) == []
    >>> assert gf128poly_formal_derivative([x**25, x**3 + x**1, x**10, x**7]) == [x**3 + x**1, 0, x**7]
    """
    # In characteristic 2 the factor i vanishes for even i and is 1 for odd i.
    q = [coeff if i % 2 == 1 else 0 for i, coeff in enumerate(p) if i != 0]
    return collapse(q)


def gf128poly_modexp(a, n, m):
    """Compute a^n (mod m).

    >>> assert gf128poly_modexp([x**2, x**3, x**4], 0, [x**3, x**2, 1]) == [1]
    >>> assert gf128poly_modexp([x**2, x**3, x**4], 1, [x**3, x**2, 1, x**5]) == [x**2, x**3, x**4]
    >>> assert gf128poly_modexp([x**2, x**3, x**4], 5, [x**3, x**2, 1]) == [x**39 + x**38 + x**37 + x**36 + x**35 + x**33 + x**32 + x**30 + x**27 + x**26 + x**25 + x**24 + x**21 + x**20 + x**15 + x**10, x**38 + x**36 + x**35 + x**33 + x**32 + x**31 + x**26 + x**24 + x**23 + x**22 + x**21 + x**20 + x**14 + x**11]
    """
    def modmul(a, b, m):
        q, r = gf128poly_divmod(gf128poly_mul(a, b), m)
        return r

    if n == 0:
        return [1]
    rec = gf128poly_modexp(a, n//2, m)
    rec = modmul(rec, rec, m)
    if n % 2 == 1:
        rec = modmul(rec, a, m)
    return rec


def gf128poly_monic_gcd(a, b):
    """Compute g such that g | a, g | b, gf128poly_lead(g) = 1; for any h such that h | a and h | b, gf128poly_deg(g) >= gf128poly_deg(h).

    When one argument is the zero polynomial the other is returned as given.

    >>> assert gf128poly_monic_gcd([x**3, x**2, x**4], [x**3, x**2, x**4]) == [x**127 + x**6 + x**1 + 1, x**127 + x**126 + x**6 + x**5 + x**1, 1]
    >>> assert gf128poly_monic_gcd([1], [1]) == [1]
    >>> assert gf128poly_monic_gcd([x**5 + x**2], []) == [x**5 + x**2]
    >>> assert gf128poly_monic_gcd([x**3, 0, 0, x**6], [x**3, x**8 + x**4, x**9 + x**7, x**8]) == [x**127 + x**6 + x**1 + 1, 1]
    """
    assert gf128poly_deg(a) >= 0 or gf128poly_deg(b) >= 0
    if gf128poly_deg(a) < 0:
        return b
    if gf128poly_deg(b) < 0:
        return a
    q, r = gf128poly_divmod(b, a)
    ret = gf128poly_monic_gcd(r, a)
    return gf128poly_monic(ret)


def gf128poly_square_free_factorization(f):
    """Compute the square-free factorization of a monic polynomial f.

    Returns a list of (factor, multiplicity) pairs.

    Reference: https://en.wikipedia.org/wiki/Factorization_of_polynomials_over_finite_fields#Square-free_factorization

    >>> a = [x**1, x**2, 1]
    >>> b = [x**2, x**3+1, 1]
    >>> c = [x**3, x**4, x**5, 1]
    >>> asq = gf128poly_mul(a, a)
    >>> acb = gf128poly_mul(asq, a)
    >>> bsq = gf128poly_mul(b, b)
    >>> acbbsq = gf128poly_mul(acb, bsq)
    >>> p = gf128poly_mul(acbbsq, c)
    >>> assert gf128poly_square_free_factorization([1]) == [([1], 1)]
    >>> assert gf128poly_square_free_factorization(p) == [([x**3, x**4, x**5, 1], 1), ([x**1, x**2, 1], 3), ([x**2, x**3 + 1, 1], 2)]
    """
    R = []
    assert f[-1] == 1, 'monic'
    if f == [1]:
        return [([1], 1)]
    fprime = gf128poly_formal_derivative(f)
    c = gf128poly_monic_gcd(f, fprime)
    w, r = gf128poly_divmod(f, c)
    assert r == []
    i = 1
    while w != [1]:
        y = gf128poly_monic_gcd(w, c)
        fac, r = gf128poly_divmod(w, y)
        assert r == []
        R.append((fac, i))
        w = y
        c, r = gf128poly_divmod(c, y)
        assert r == []
        i = i + 1
    if c != [1]:
        # What is left is a perfect square: recurse on its root and double
        # the multiplicities.
        c = gf128poly_sqrt(c)
        c = gf128poly_monic(c)
        Rrec = gf128poly_square_free_factorization(c)
        R += [(fac, mult*2) for (fac, mult) in Rrec]
    return [(fac, mult) for fac, mult in R if fac != [1]]


def gf128poly_distinct_degree_factorization(f):
    """Compute the distinct-degree factorization of a square-free monic polynomial f.

    Returns a set of (factor as tuple, degree of its irreducible factors).

    Reference: https://en.wikipedia.org/wiki/Factorization_of_polynomials_over_finite_fields#Distinct-degree_factorization

    >>> a = [x**1, 1]
    >>> b = [x**2, 1]
    >>> ab = gf128poly_mul(a, b)
    >>> assert gf128poly_distinct_degree_factorization(ab) == {((x**3, x**2 + x**1, 1), 1)}
    """
    i = 1
    S = set()
    f_ = f.copy()
    while gf128poly_deg(f_) >= 2*i:
        # qq = X^(q^i) - X (mod f_), with q = 2^128.
        qq = gf128poly_modexp([0, 1], 2**(128*i), f_)
        qq = gf128poly_add(qq, [0, 1])
        g = gf128poly_monic_gcd(f_, qq)
        if g != [1]:
            S.add((tuple(g), i))
            q, r = gf128poly_divmod(f_, g)
            assert r == []
            f_ = q
        i += 1
    if f_ != [1]:
        S.add((tuple(f_), gf128poly_deg(f_)))
    if len(S) == 0:
        return {(tuple(f), 1)}
    else:
        return S


def gf128poly_equal_degree_factorization(f, d):
    """Compute the equal-degree factorization of a square-free monic polynomial f whose factors are all of degree d.

    Uses the Cantor-Zassenhaus algorithm.  Returns a set of tuples.

    Reference: https://en.wikipedia.org/wiki/Factorization_of_polynomials_over_finite_fields#Equal-degree_factorization

    >>> a = [x**1, 1]
    >>> b = [x**2, 1]
    >>> ab = gf128poly_mul(a, b)
    >>> assert gf128poly_equal_degree_factorization(ab, 1) == {(x**2, 1), (x**1, 1)}
    """
    def gf128_rand():
        return random.randint(0, 2**128-1)

    def gf128poly_rand(degree):
        return [gf128_rand() for _ in range(degree+1)]

    # 3 divides 2^(128*d) - 1, so h^exp is a cube root of unity modulo every
    # degree-d irreducible factor, and h^exp + 1 vanishes modulo roughly a
    # third of them.  (Previously fixed at (2^128 - 1)/3, which is the same
    # value for d == 1 but does not give cube roots of unity for d > 1.)
    exp = (2**(128*d)-1)//3
    n = gf128poly_deg(f)
    r = n//d
    S = {tuple(f)}
    while len(S) < r:
        h = gf128poly_rand(n - 1)
        g = gf128poly_monic_gcd(h, f)
        if g == [1]:
            g = gf128poly_add(gf128poly_modexp(h, exp, f), [1])
        for u in list(S):
            u = list(u)
            if gf128poly_deg(u) == d:
                continue
            gugcd = gf128poly_monic_gcd(g, u)
            if gugcd != [1] and gugcd != u:
                S.remove(tuple(u))
                S.add(tuple(gugcd))
                qq, rr = gf128poly_divmod(u, gugcd)
                assert rr == []
                S.add(tuple(qq))
    return S


def gf128poly_factorize(f, degree=None):
    """Compute the factors of a polynomial f. Does not return multiplicity.

    If degree is specified, only returns factors of that degree.
    """
    factors = set()
    f = gf128poly_monic(f)
    fs = gf128poly_square_free_factorization(f)
    for p, _ in fs:
        qs = gf128poly_distinct_degree_factorization(p)
        for q, d in qs:
            if degree is not None and d != degree:
                continue
            rs = gf128poly_equal_degree_factorization(list(q), d)
            factors |= rs
    return factors


## AES-GCM

def pad(xs):
    """Zero-pad xs to a multiple of 16 bytes."""
    return xs + b'\x00' * ((16 - len(xs)) % 16)


def build_blocks(aad, c):
    """Return padded aad, padded c and the 16-byte bit-length block, concatenated."""
    padded_aad = pad(aad)
    padded_c = pad(c)
    length = struct.pack('>QQ', len(aad)*8, len(c)*8)
    return padded_aad + padded_c + length


def gmac(h, s, aad, c):
    """Compute the 16-byte tag of (aad, c) under authentication key h and blind s.

    h and s are field elements (ints); the result is bytes.
    """
    blocks = build_blocks(aad, c)
    g = 0
    for i in range(len(blocks)//16):
        block = bytes_to_gf128(blocks, i*16, (i + 1)*16)
        g = gf128_add(g, block)
        g = gf128_mul(g, h)
    t = gf128_add(g, s)
    return gf128_to_bytes(t)


def ecb_encrypt(key, xs):
    """AES-ECB encrypt xs under key."""
    c = AES.new(key, AES.MODE_ECB)
    return c.encrypt(xs)


def xor(a, b):
    """XOR two byte strings; the result has the length of the shorter one."""
    return bytes([(u ^ v) for (u, v) in zip(a, b)])


def gctr(k, nonce, x):
    """Encrypt/decrypt x in counter mode; the counter block is nonce || 32-bit counter starting at 2."""
    ctr = (int.from_bytes(nonce, 'big') << 32) + 2
    out = []
    for off in range(0, len(x), 16):
        out.append(xor(x[off:off + 16], ecb_encrypt(k, ctr.to_bytes(16, 'big'))))
        ctr += 1
    return b''.join(out)


def authentication_key(k):
    """Return the encryption of the all-zero block under k, as a field element."""
    return bytes_to_gf128(ecb_encrypt(k, b'\x00'*16))


def blind(k, nonce):
    """Return the encryption of nonce || 00000001 under k, as a field element."""
    return bytes_to_gf128(ecb_encrypt(k, nonce + b'\x00\x00\x00\x01'))


def gcm_encrypt(k, nonce, aad, m, mac_bytes=16):
    """Encrypt m; return (ciphertext, tag truncated to mac_bytes)."""
    c = gctr(k, nonce, m)
    mac = gmac(authentication_key(k), blind(k, nonce), aad, c)
    return c, mac[:mac_bytes]


def gcm_decrypt(k, nonce, aad, c, mac, mac_bytes=16):
    """Verify mac and decrypt c.

    Raises AssertionError when the tag does not match.
    """
    m = gctr(k, nonce, c)
    expected_mac = gmac(authentication_key(k), blind(k, nonce), aad, c)[:mac_bytes]
    # Raised explicitly (rather than via an assert statement) so the check is
    # not removed under ``python -O``; the exception type is kept because
    # callers catch AssertionError.
    if not hmac.compare_digest(mac, expected_mac):
        raise AssertionError('MAC verification failed')
    return m


## Forbidden Attack

def compute_forbidden_polynomial(aad1, aad2, c1, c2, mac1, mac2):
    """Build the polynomial whose roots include the authentication key.

    The coefficients are the XOR of corresponding blocks of the two messages
    (the shorter block string is left-padded with zero blocks), with the XOR
    of the two tags as constant term.
    """
    bs1 = build_blocks(aad1, c1)
    bs2 = build_blocks(aad2, c2)
    if len(bs1) < len(bs2):
        bs1 = b'\x00'*(len(bs2)-len(bs1)) + bs1
    else:
        bs2 = b'\x00'*(len(bs1)-len(bs2)) + bs2
    assert len(bs1) == len(bs2)
    f = []
    N = len(bs1)//16
    for i in range(N):
        b1 = bytes_to_gf128(bs1[i*16:(i+1)*16])
        b2 = bytes_to_gf128(bs2[i*16:(i+1)*16])
        f.append(gf128_add(b1, b2))
    f.append(gf128_add(bytes_to_gf128(mac1), bytes_to_gf128(mac2)))
    # f was built highest power first; polynomials store the constant term first.
    return collapse(list(reversed(f)))


def nonce_reuse_recover_secrets(nonce, aad1, aad2, c1, c2, mac1, mac2):
    """Return candidate (authentication key, blind) pairs from two messages sharing a nonce.

    nonce itself is not used by the computation.
    """
    f = compute_forbidden_polynomial(aad1, aad2, c1, c2, mac1, mac2)
    factors = gf128poly_factorize(f, degree=1)
    candidates = []
    for factor in factors:
        if gf128poly_deg(factor) == 1:
            # A monic linear factor X + h has root h.
            h = factor[0]
            zero_tag = gmac(h, 0, aad1, c1)
            s = gf128_add(bytes_to_gf128(mac1), bytes_to_gf128(zero_tag))
            candidates.append((h, s))
    return candidates


## MAC Truncation Attack

def gen_blocks(n, js):
    """Return n zero blocks with bit j toggled for every j in js.

    Bit j lives in byte j // 8, counted from the most significant bit.
    """
    buf = bytearray(16*n)
    for j in js:
        j = int(j)  # js may hold numpy integers
        buf[j // 8] ^= 1 << (7 - j % 8)
    return bytes(buf)


squarer = np.array(Ms())
use_numpy = True
if use_numpy:
    with open('square-basis.np', 'rb') as _lookup_file:
        mcsqlookup = np.load(_lookup_file)
else:
    mcsqlookup = xr.open_dataarray('square-basis.nc')


def mc_squared(c, j):
    """Sum, mod 2, the lookup entries mcsqlookup[i, j] over the set bits i of c.

    Raises IndexError when j is outside the precomputed table (numpy path).
    """
    return sum(mcsqlookup[i, j] if use_numpy else mcsqlookup[i, j].to_numpy()
               for i in range(128) if 1 == (c >> i) & 1) % 2


def gen_ad(blocks):
    """Return the 128x128 matrix over GF(2) accumulated from the non-zero blocks.

    Block i contributes Mc(block) @ squarer^(i+1), taken from the lookup
    table when available and computed directly otherwise.
    """
    matret = np.zeros((128, 128))
    for i in range(len(blocks)//16):
        block = blocks[(i*16):(i+1)*16]
        if block == b'\x00'*16:
            continue
        j = i + 1  # first is taken up by length block
        d = bytes_to_gf128(block)
        try:
            mat = mc_squared(d, j)
        except IndexError:
            matsq = np.linalg.matrix_power(squarer, j) % 2
            mat = Mc(d) @ matsq
        matret += mat
    return matret % 2  # Because the elements of Ad are in GF2 we can mod 2


def gen_t(n, macbytes, X=None, minrows=8):
    """Build the dependency matrix T: column j holds the first rows of Ad for a flip of bit j alone.

    When X is given, those rows are multiplied by X (mod 2) first.
    """
    T = []
    for j in range(n*128):
        blocks = gen_blocks(n, [j])
        Ad = gen_ad(blocks)
        if X is not None:
            rows = min((n*128)//(X.shape[1]) - 1, macbytes*8-minrows)
        else:
            rows = min((n*128)//(Ad.shape[1]) - 1, macbytes*8-minrows)
        if X is not None:
            Ad = (Ad[:rows] @ X) % 2
        z = np.concatenate(Ad[:rows])
        T.append(z)
    return (np.array(T)).transpose().astype('int')


def gen_flips(b):
    """Return the indices of the non-zero entries of b."""
    return np.nonzero(b)[0]


def compute_n(ct):
    """Return the largest n with 2^n <= number of 16-byte blocks in ct (0 if there are none)."""
    num_blocks = len(ct)//16
    n = 0
    power = 1
    while power*2 <= num_blocks:
        power *= 2
        n += 1
    return n


def chunk(xs, n=16):
    """Split xs into consecutive chunks of n items, dropping any trailing partial chunk."""
    return [xs[i*n:(i+1)*n] for i in range(len(xs)//n)]


def find_b(n, basis, ct, mac, nonce, aad, oracle):
    """Try random combinations of basis vectors until the oracle accepts the modified ciphertext.

    Returns the accepted bit-flip vector.  The oracle signals rejection by
    raising cryptography.exceptions.InvalidTag or ValueError.
    """
    orig_base = bytearray(ct)
    base = bytearray(ct)
    # random.sample raises ValueError when asked for more items than exist.
    max_choice = min(10, len(basis))
    while True:
        choice = random.sample(basis, random.randint(1, max_choice))
        b = sum(choice) % 2
        flips = gen_flips(b)
        blocks = gen_blocks(n, flips)
        for i, block in enumerate(chunk(blocks)):
            # Generated block i is applied 2^(i+1) - 1 blocks from the end.
            j = (len(base)//16)-(2**(i+1)-1)
            base[j*16:(j+1)*16] = xor(base[j*16:(j+1)*16], block)
        try:
            oracle(base[len(aad):], base[:len(aad)], mac, nonce)
            return b
        except (cryptography.exceptions.InvalidTag, ValueError):
            base = orig_base.copy()


def mac_truncation_recover_secrets(ct, mac, nonce, mac_bytes, aad, oracle, compute_T_once=False):
    """Recover the authentication key (and blind) given a decryption oracle and a truncated tag.

    NOTE(review): the end of this function (from the tag-decoding loop
    onwards) was missing from the reviewed copy and has been reconstructed;
    see the note near the bottom.
    """
    orig_ct = ct
    ct = aad + ct
    n = compute_n(ct)
    assert n > (mac_bytes*8//2)
    assert len(ct) % 16 == 0
    assert len(aad) % 16 == 0
    X = None
    K = None
    basisKerK = None
    if compute_T_once:
        T = gen_t(n, mac_bytes, X, minrows=7)
        _, _, basisKerT = kernel(T, rref_mod_2)
    while K is None or (basisKerK is None or len(basisKerK) > 1):
        if not compute_T_once:
            T = gen_t(n, mac_bytes, X, minrows=7)
            _, _, basisKerT = kernel(T, rref_mod_2)
        assert len(basisKerT[0]) == n*128
        b = find_b(n, basisKerT, ct, mac, nonce, aad, oracle)
        flips = gen_flips(b)
        blocks = gen_blocks(n, flips)
        Ad = gen_ad(blocks)
        if X is not None:
            AdRelevant = ((Ad @ X) % 2)[:mac_bytes*8]
        else:
            AdRelevant = Ad[:mac_bytes*8]
        # Keep only the rows that are non-zero after applying X.
        incrK = Ad[:mac_bytes*8][np.any(AdRelevant, axis=1)]
        if K is None:
            K = incrK
        else:
            K = np.concatenate([K, incrK])
        _, _, basisKerK = kernel(K, rref_mod_2)
        if not compute_T_once:
            X = np.array(basisKerK).transpose()
    _, _, kerK = kernel(K, rref_mod_2)
    assert len(kerK) == 1, len(kerK)
    h = kerK[0]
    zero_tag = gf128_to_vec(bytes_to_gf128(gmac(vec_to_gf128(h), 0, aad, orig_ct)))[:mac_bytes*8]
    gf128_mac = 0
    i = 0
    for b in mac:
        for j in range(8):
            if b & (1 << (7-j)):
                gf128_mac += (1 << i)
            # NOTE(review): everything from here to the end of the function
            # is reconstructed.  The caller unpacks two values and compares
            # the first against authentication_key(k); the blind below only
            # covers the mac_bytes*8 known tag bits.  Confirm.
            i += 1
    s = gf128_add(gf128_mac, vec_to_gf128(zero_tag))
    return vec_to_gf128(h), s


## Two-key collisions
# NOTE(review): gmac_key, gmac_blind and the head of encode_lengths were
# missing from the reviewed copy.  collide() feeds their results to gmac() as
# the key and blind field elements, and only the tail
# ``QQ', ad_length*8, ct_length*8)`` of encode_lengths survived, so they are
# reconstructed as below.  Confirm against the pristine source.

def gmac_key(k):
    """Return the authentication key for k as a field element."""
    return authentication_key(k)


def gmac_blind(k, nonce):
    """Return the tag blind for (k, nonce) as a field element."""
    return blind(k, nonce)


def encode_lengths(ad_length, ct_length):
    """Return the 16-byte length block for the given byte lengths."""
    return struct.pack('>QQ', ad_length*8, ct_length*8)


def collide(k1, k2, nonce, c):
    """Append one block to c so that the result has the same tag under k1 and k2.

    c must be a whole number of blocks; the AAD is empty.  Returns
    (extended ciphertext, tag).
    """
    h1 = gmac_key(k1)
    h2 = gmac_key(k2)
    p1 = gmac_blind(k1, nonce)
    p2 = gmac_blind(k2, nonce)
    assert len(c) % 16 == 0
    mlen = len(c)//16+1
    lens = bytes_to_gf128(encode_lengths(0, len(c) + 16))
    acc = gf128_mul(lens, gf128_add(h1, h2))
    acc = gf128_add(acc, gf128_add(p1, p2))
    for i in range(1, mlen):
        hi = gf128_add(gf128_exp(h1, mlen+2-i), gf128_exp(h2, mlen+2-i))
        acc = gf128_add(acc, gf128_mul(bytes_to_gf128(c[(i-1)*16:((i-1)+1)*16]), hi))
    # The appended block is multiplied by h^2 in each tag.
    inv = gf128_inv(gf128_add(gf128_mul(h1, h1), gf128_mul(h2, h2)))
    c_append = gf128_mul(acc, inv)
    c_ = c + gf128_to_bytes(c_append)
    mac = gmac(h1, p1, b'', c_)
    return (c_, mac)


def collide_penultimate(k1, k2, nonce, c):
    """Overwrite the penultimate block of c so that the tag is the same under k1 and k2.

    c must be a whole number of blocks; the AAD is empty.  Returns
    (modified ciphertext, tag).
    """
    h1 = authentication_key(k1)
    h2 = authentication_key(k2)
    p1 = gmac_blind(k1, nonce)
    p2 = gmac_blind(k2, nonce)
    assert len(c) % 16 == 0
    mlen = len(c)//16
    lens = bytes_to_gf128(encode_lengths(0, len(c)))
    acc = gf128_mul(lens, gf128_add(h1, h2))
    acc = gf128_add(acc, gf128_add(p1, p2))
    # Walk the blocks before the penultimate one from last to first while
    # keeping running powers of h1 and h2 (block i is weighted by
    # h^(mlen+1-i), so the first block visited gets h^4).
    h1Running = gf128_exp(h1, 4)
    h2Running = gf128_exp(h2, 4)
    for i in reversed(range(0, mlen-2)):
        hi = gf128_add(h1Running, h2Running)
        h1Running = gf128_mul(h1Running, h1)
        h2Running = gf128_mul(h2Running, h2)
        acc = gf128_add(acc, gf128_mul(bytes_to_gf128(c[i*16:(i+1)*16]), hi))
    # Last block: weight h^2.
    hi = gf128_add(gf128_exp(h1, 2), gf128_exp(h2, 2))
    i = mlen-1
    acc = gf128_add(acc, gf128_mul(bytes_to_gf128(c[i*16:(i+1)*16]), hi))
    # Penultimate block: weight h^3; solve for the block that balances acc.
    inv = gf128_inv(gf128_add(gf128_exp(h1, 3), gf128_exp(h2, 3)))
    c_append = gf128_mul(acc, inv)
    c_ = c[:-32] + gf128_to_bytes(c_append) + c[-16:]
    mac = gmac(h1, p1, b'', c_)
    return (c_, mac)


def gctr_oneblock(k, pt, nonce):
    """XOR pt with the keystream block for counter value 2."""
    enckeyval = nonce + b'\x00\x00\x00\x02'
    stream = ecb_encrypt(k, enckeyval)
    return xor(pt, stream)


def key_search(nonce, init_bytes1, init_bytes2):
    """Search random 16-byte key pairs until init_bytes1 under k1 and init_bytes2 under k2 encrypt alike.

    Returns (k1, k2).
    """
    seen1 = dict()
    seen2 = dict()
    while True:
        k1 = secrets.token_bytes(16)
        k2 = secrets.token_bytes(16)
        ct1 = gctr_oneblock(k1, init_bytes1, nonce)
        ct2 = gctr_oneblock(k2, init_bytes2, nonce)
        seen1[ct1] = k1
        seen2[ct2] = k2
        if ct1 in seen2:
            return k1, seen2[ct1]
        if ct2 in seen1:
            return seen1[ct2], k2


def att_merge_jpg_bmp(jpg, bmp, aad):
    """Build one ciphertext and tag from a JPG and a BMP using two fixed keys.

    Returns (ciphertext, tag) as produced by collide_penultimate.
    """
    # Precomputed with key_search; works for any files
    k1 = unhexlify('8007941455b5af579bb12fff92ef31a3')
    k2 = unhexlify('14ef746e8b1792e52b1d22ef124fae97')
    nonce = b'JORGELBORGES'
    total_len = 6 + (0xffff) + len(jpg) + 0xff  # get some extra
    # Encrypting zeros yields the raw keystream of each key.
    jpgstream, _ = gcm_encrypt(k1, nonce, aad, b'\x00'*total_len)
    bmpstream, _ = gcm_encrypt(k2, nonce, aad, b'\x00'*total_len)
    # 6 bytes
    r = xor(jpgstream, b'\xff\xd8\xff\xfe\xff\xff')
    # len(bmp) bytes
    bmpenc = xor(bmp[6:], bmpstream[6:6+len(bmp)])
    r += bmpenc
    comlen = 0xffff
    # finish comment with padding
    r += b'\x00'*(comlen - len(bmpenc))
    # jpg
    r += xor(jpg[2:-2], jpgstream[6+comlen:])
    # comment; include penultimate block to be overwritten; therefore must be at least 3 blocks long
    # also serves to block-align to 14 bytes so the final ciphertext will be complete blocks
    endcomlen = (28 - (len(r) % 16)) + 16 + 14
    tail = b'\xff\xfe' + struct.pack('>H', endcomlen) + b'\x00'*endcomlen + b'\xff\xd9'
    tailx = xor(tail, jpgstream[6+comlen+len(jpg)-4:])
    r += tailx
    assert len(r) % 16 == 0
    cfin, macfin = collide_penultimate(k1, k2, nonce, r)
    return cfin, macfin


# Demos

def forbidden_attack_demo():
    """Forge a message after two encryptions under the same key and nonce."""
    k = b"tlonorbistertius"
    nonce = b"jorgelborges"
    m1 = b"The universe (which others call the Library)"
    aad1 = b"The Anatomy of Melancholy"
    m2 = b"From any of the hexagons one can see, interminably"
    aad2 = b"Letizia Alvarez de Toledo"
    c1, mac1 = gcm_encrypt(k, nonce, aad1, m1)
    c2, mac2 = gcm_encrypt(k, nonce, aad2, m2)

    # Recover the authentication key and blind from public information
    possible_secrets = nonce_reuse_recover_secrets(nonce, aad1, aad2, c1, c2, mac1, mac2)

    # Forge the ciphertext
    m_forged = b"As was natural, this inordinate hope"
    assert len(m_forged) <= len(m1)
    c_forged = xor(c1, xor(m1, m_forged))
    aad_forged = b"You who read me, are You sure of understanding my language?"

    # Check possible candidates for authentication key
    succeeded = False
    for h, s in possible_secrets:
        mac_forged = gmac(h, s, aad_forged, c_forged)
        try:
            recovered = gcm_decrypt(k, nonce, aad_forged, c_forged, mac_forged)
        except AssertionError:
            continue
        if recovered == m_forged:
            succeeded = True
            print(c_forged.hex(), mac_forged.hex())
    assert succeeded


def mac_truncation_demo():
    """Recover the authentication key from a one-byte tag using a decryption oracle."""
    # Doesn't work with non-block size multiples.
    # Need to modify to consider padding, but we can't mess with the bits in the padding,
    # nor can we extend ad/ct unless we also change length block.
    k = b'tlonorbistertius'
    aad = b''
    mac_bytes = 1
    pt = b'celerypatchworks'*(2**5)
    nonce = b'jorgelborges'
    ct, mac = gcm_encrypt(k, nonce, aad, pt, mac_bytes=mac_bytes)

    def oracle(base, aad, mac, nonce):
        decryptor = Cipher(
            algorithms.AES(k),
            modes.GCM(nonce, mac, min_tag_length=mac_bytes),
        ).decryptor()
        decryptor.authenticate_additional_data(aad)
        decryptor.update(base) + decryptor.finalize()

    h, s = mac_truncation_recover_secrets(ct, mac, nonce, mac_bytes, aad, oracle, compute_T_once=mac_bytes==1)
    assert h == authentication_key(k)


if __name__ == "__main__":
    # mac_truncation_demo()
    with open('static/axolotl.jpg', 'rb') as jpg_file:
        jpg = jpg_file.read()
    with open('static/kitten.bmp', 'rb') as bmp_file:
        bmp = bmp_file.read()
    c, mac = att_merge_jpg_bmp(jpg, bmp, aad=b"")
    print(mac.hex())
    with open('c.txt', 'wb') as f:
        f.write(c)
        f.write(mac)