from binascii import unhexlify
import random
import struct
import hmac
import itertools
import math
import secrets
from Crypto.Cipher import AES
import cryptography
import cryptography.exceptions  # find_b catches cryptography.exceptions.InvalidTag
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import numpy as np
import xarray as xr


## Computation in GF(2^128) = GF(2)[x]/(x^128 + x^7 + x^2 + x^1 + 1)
## Elements are represented as an integer n where (n & (1 << i)) is the coefficient for x^i.

class X:
    """A helper class for specifying elements of GF(2^n): ``x**k`` is the integer ``1 << k``."""

    def __pow__(self, y):
        return 1 << y


x = X()
gcm_modulus = x**128 + x**7 + x**2 + x**1 + 1
gcm_modulus_degree = 128


def gf128_deg(x):
    """Compute the degree of x (-1 for the zero polynomial).

    >>> assert gf128_deg(x**20 + x**5 + 1) == 20
    >>> assert gf128_deg(x**90) == 90
    >>> assert gf128_deg(x**1) == 1
    >>> assert gf128_deg(1) == 0
    >>> assert gf128_deg(0) == -1
    """
    return x.bit_length() - 1


def gf128_add(a, b):
    """Compute a + b (coefficient-wise XOR).

    >>> assert gf128_add(0, 0) == 0
    >>> assert gf128_add(x**7 + x**2, 0) == x**7 + x**2
    >>> assert gf128_add(x**7 + x**2, x**2 + x**1) == x**7 + x**1
    >>> assert gf128_add(x**7 + x**2, x**7 + x**2) == 0
    """
    return a ^ b


def gf128_mul(a, b):
    """Compute ab modulo gcm_modulus.

    Operands are expected to be reduced (degree < 128); an unreduced operand
    yields a result that is only congruent to the product, so callers that may
    pass unreduced values wrap the result in gf128_mod.

    >>> assert gf128_mul(x**20 + x**5 + 1, x**13 + x**2) == x**33 + x**22 + x**18 + x**13 + x**7 + x**2
    >>> assert gf128_mul(x**80, x**90) == x**49 + x**44 + x**43 + x**42
    >>> assert gf128_mul(x**23 + x**5, 1) == x**23 + x**5
    """
    if a == 0 or b == 0:
        return 0
    if a > b:
        # Iterate over the smaller operand's bits.
        return gf128_mul(b, a)
    p = 0
    deg_b = gf128_deg(b)
    while a > 0:
        if a & 1:
            p ^= b
        a = a >> 1
        b = b << 1
        deg_b += 1
        if deg_b == gcm_modulus_degree:
            b = b ^ gcm_modulus
            deg_b = gf128_deg(b)
    return p


def gf128_divmod(a, b):
    """Compute q, r such that a = qb + r; gf128_deg(r) < gf128_deg(b).

    Raises ZeroDivisionError if b == 0.

    >>> assert gf128_divmod(x**20 + x**5 + 1, x**13 + x**2) == (x**7, x**9 + x**5 + 1)
    >>> assert gf128_divmod(x**90, x**80) == (x**10, 0)
    """
    if b == 0:
        # Without this guard the loop below never terminates.
        raise ZeroDivisionError('polynomial division by zero')
    q, r = 0, a
    deg_b = gf128_deg(b)
    while gf128_deg(r) >= deg_b:
        d = gf128_deg(r) - deg_b
        q = q ^ (1 << d)
        r = r ^ (b << d)
    return q, r


def gf128_mod(a, m=gcm_modulus):
    """Compute b such that a ≡ b (mod m) and gf128_deg(b) < gf128_deg(m).

    >>> assert gf128_mod(x**220 + x**5 + 1, x**50 + x**4) == x**36 + x**5 + 1
    >>> assert gf128_mod(x**130 + x**64 + x**2) == x**64 + x**9 + x**4 + x**3
    """
    _, r = gf128_divmod(a, m)
    return r


def gf128_egcd(a, b):
    """Compute g, x, y such that g | a, g | b, g = ax + by; for any h such that h | a and h | b, gf128_deg(g) >= gf128_deg(h).

    >>> assert gf128_egcd(x**2 + 1, x**1 + 1) == (x**1 + 1, 0, 1)
    >>> assert gf128_egcd(x**12 + x**4, x**5 + 1) == (x**1 + 1, x**3 + x**2 + 1, x**10 + x**9 + x**7 + x**5 + x**4 + x**1 + 1)
    >>> assert gf128_egcd(x**64 + x**2, x**37 + x**12 + 1) == (1, x**35 + x**34 + x**33 + x**31 + x**30 + x**28 + x**27 + x**25 + x**24 + x**21 + x**20 + x**18 + x**17 + x**15 + x**14 + x**12 + x**11 + x**9 + x**7 + x**6 + x**4 + x**3 + x**1 + 1, x**62 + x**61 + x**60 + x**58 + x**57 + x**55 + x**54 + x**52 + x**51 + x**48 + x**47 + x**45 + x**44 + x**42 + x**41 + x**39 + x**38 + x**37 + x**35 + x**34 + x**32 + x**31 + x**29 + x**28 + x**26 + x**25 + x**24 + x**22 + x**21 + x**19 + x**18 + x**16 + x**15 + x**13 + x**12 + x**11 + x**9 + x**8 + x**6 + x**5 + x**3 + x**2 + 1)
    >>> assert gf128_egcd(1, x**5 + 1) == (1, 1, 0)
    """
    orig_a = a
    orig_b = b
    # Invariants: a == a_x*orig_a + a_y*orig_b and b == b_x*orig_a + b_y*orig_b.
    a_x, a_y = 1, 0
    b_x, b_y = 0, 1
    if a == 0:
        return b, b_x, b_y
    while True:
        q, r = gf128_divmod(b, a)
        if r == 0:
            # gf128_mul does not fully reduce unreduced operands (orig_b may be
            # the modulus itself), hence the explicit gf128_mod.
            assert a == gf128_mod(gf128_add(gf128_mul(orig_a, a_x), gf128_mul(orig_b, a_y)))
            # Return the coefficients tracked for `a`; they are correct even when
            # b is divisible by a on the very first step.
            return a, a_x, a_y
        r_x = gf128_add(b_x, gf128_mul(q, a_x))
        r_y = gf128_add(b_y, gf128_mul(q, a_y))
        a, b = r, a
        b_x, b_y = a_x, a_y
        a_x, a_y = r_x, r_y


def gf128_inv(a):
    """Compute b such that ab = 1.

    >>> assert gf128_inv(x**1) == x**127 + x**6 + x**1 + 1
    >>> assert gf128_inv(x**4) == x**127 + x**125 + x**124 + x**6 + x**4 + x**3 + x**1 + 1
    >>> assert gf128_inv(1) == 1
    """
    g, xx, _ = gf128_egcd(a, gcm_modulus)
    assert g == 1
    return xx


def gf128_exp(a, n):
    """Compute a^n by recursive square-and-multiply.

    >>> assert gf128_exp(x**2 + 1, 0) == 1
    >>> assert gf128_exp(x**2 + 1, 1) == x**2 + 1
    >>> assert gf128_exp(x**2 + 1, 2) == x**4 + 1
    >>> assert gf128_exp(x**2 + 1, 10) == x**20 + x**16 + x**4 + 1
    """
    if n == 0:
        return 1
    rec = gf128_exp(a, n//2)
    rec = gf128_mul(rec, rec)
    if n % 2 == 1:
        rec = gf128_mul(rec, a)
    return rec


def gf128_sqrt(x):
    """Compute y such that y^2 = x: the inverse Frobenius automorphism, y = x^(2^127).

    Reference: https://math.stackexchange.com/questions/943417/square-root-for-galois-fields-gf2m

    >>> assert gf128_sqrt(1) == 1
    >>> assert gf128_sqrt(x**2 + 1) == x**1 + 1
    >>> assert gf128_sqrt(x**3 + 1) == x**126 + x**123 + x**120 + x**117 + x**114 + x**111 + x**108 + x**105 + x**102 + x**99 + x**96 + x**93 + x**90 + x**87 + x**84 + x**81 + x**78 + x**75 + x**72 + x**69 + x**66 + x**63 + x**62 + x**60 + x**59 + x**57 + x**56 + x**54 + x**53 + x**51 + x**50 + x**48 + x**47 + x**45 + x**44 + x**42 + x**41 + x**39 + x**38 + x**36 + x**35 + x**33 + x**32 + x**30 + x**29 + x**27 + x**26 + x**24 + x**23 + x**21 + x**20 + x**18 + x**17 + x**15 + x**14 + x**12 + x**11 + x**9 + x**8 + x**6 + x**3 + 1
    """
    return gf128_exp(x, 2**127)


def gf128_to_vec(x):
    """Return the 128 coefficients of x as a list of 0/1 ints, index i = coefficient of x^i."""
    return [int(n) for n in bin(x)[2:].zfill(128)[::-1]]


def vec_to_gf128(vs):
    """Inverse of gf128_to_vec: build the element whose x^i coefficient is truthy(vs[i])."""
    x = 0
    for i, v in enumerate(vs):
        if v:
            x += (1 << i)
    return x


def reverse_mask(b):
    """Reverse the bit order of a single byte (bit 7 <-> bit 0, bit 6 <-> bit 1, ...)."""
    # from https://stackoverflow.com/a/2602885
    b = (b & 0xF0) >> 4 | (b & 0x0F) << 4
    b = (b & 0xCC) >> 2 | (b & 0x33) << 2
    b = (b & 0xAA) >> 1 | (b & 0x55) << 1
    return b


def bytes_to_gf128(xs, st=None, en=None):
    """Parse the 16-byte GCM block xs[st:en] into a field element.

    GCM's convention is that the most significant bit of byte 0 is the
    coefficient of x^0, so each byte is bit-reversed and the block is read
    little-endian.

    NOTE(review): the body after the `st`/`en` defaults was lost in this copy
    of the file and has been reconstructed to be the inverse of gf128_to_bytes
    and consistent with the MAC-bit parsing in mac_truncation_recover_secrets.
    """
    if st is None and en is None:
        st = 0
        en = 16
    a, b = struct.unpack('<QQ', bytes(reverse_mask(c) for c in xs[st:en]))
    return a | (b << 64)


def gf128_to_bytes(x):
    """Serialize a field element to a 16-byte GCM block (inverse of bytes_to_gf128).

    NOTE(review): the head of this function was lost in this copy of the file
    and has been reconstructed; the padding/assert tail is original.
    """
    s = b''
    while x > 0:
        s += bytes([reverse_mask(x & 0xFF)])
        x >>= 8
    s += b'\x00' * (16 - len(s))
    assert len(s) == 16
    return s


def Ms():
    """Return the 128x128 GF(2) matrix of the squaring map: column i is x^(2i) reduced, as bits."""
    vs = [gf128_to_vec(gf128_mod(1 << (2*i))) for i in range(128)]
    return np.array(vs).transpose()


def Mc(c):
    """Return the 128x128 GF(2) matrix of multiplication by c: column i is c*x^i, as bits."""
    vs = []
    pr = c
    for _ in range(128):
        vs.append(gf128_to_vec(pr))
        pr = gf128_mul(x**1, pr)
    return np.array(vs).transpose()


# Adapted from Wiki
def rref_mod_2(M):
    """Row-reduce a 0/1 matrix over GF(2).

    Returns (reduced matrix, adj) where adj records the row operations applied,
    i.e. adj @ M == reduced (mod 2).
    """
    M = M.copy().astype('int')
    lead = 0
    rowCount, columnCount = M.shape
    adj = np.eye(rowCount, rowCount).astype('int')
    for r in range(rowCount):
        if columnCount <= lead:
            return M, adj
        i = r
        while M[i][lead] == 0:
            i = i + 1
            if rowCount == i:
                i = r
                lead = lead + 1
                if columnCount == lead:
                    return M, adj
        if i != r:
            M[[i, r]] = M[[r, i]]
            adj[[i, r]] = adj[[r, i]]
        for j in range(rowCount):
            if M[j][lead] == 1 and j != r:
                M[j] ^= M[r]
                adj[j] ^= adj[r]
        lead = lead + 1
    return M, adj


def kernel(M, f):
    """Compute a basis of the right kernel of M using the row reducer f.

    Reduces M^T; each all-zero row of the result identifies (via the recorded
    row operations) a combination of columns of M that sums to zero.
    Returns (reduced M^T, adj, basis).
    """
    mt = M.transpose()
    N, adj = f(mt)
    # rref_mod_2 returns int arrays, so an exact zero test suffices.
    basis = [adj[i] for i, row in enumerate(N) if not row.any()]
    return N, adj, basis


## Polynomials in GF(2^128)[X], with coefficients in GF(2^128) as above.
## Elements are represented as lists where the ith element is the coefficient for X^i.

def gf128poly_lead(x):
    """Compute the leading coefficient of x.

    >>> assert gf128poly_lead([x**3, x**1, 1]) == 1
    >>> assert gf128poly_lead([x**3, x**1, x**2, x**4]) == x**4
    >>> assert gf128poly_lead([1]) == 1
    """
    assert len(x) > 0
    return x[-1]


def collapse(x):
    """Normalize a polynomial: reduce every coefficient mod gcm_modulus and strip zero leads."""
    # Reduce before stripping so a lead that is a nonzero multiple of the modulus
    # is also removed; iterate instead of recursing to avoid deep recursion.
    reduced = [gf128_mod(y, gcm_modulus) for y in x]
    while reduced and reduced[-1] == 0:
        reduced.pop()
    return reduced


def gf128poly_deg(x):
    """Compute the degree of x (-1 for the empty/zero polynomial).

    >>> assert gf128poly_deg([x**3, x**1, 1]) == 2
    >>> assert gf128poly_deg([x**2, x**3]) == 1
    >>> assert gf128poly_deg([x**15]) == 0
    >>> assert gf128poly_deg([1]) == 0
    >>> assert gf128poly_deg([]) == -1
    """
    return len(x) - 1


def gf128poly_add(a, b):
    """Compute a + b.

    >>> assert gf128poly_add([x**3, x**1, 1], [x**2 + 1, x**3]) == [x**3 + x**2 + 1, x**1 + x**3, 1]
    >>> assert gf128poly_add([x**2, x**3], [x**5, x**12, x**2, x**4]) == [x**5 + x**2, x**12 + x**3, x**2, x**4]
    >>> assert gf128poly_add([x**15], []) == [x**15]
    >>> assert gf128poly_add([x**15], [x**15]) == []
    """
    return collapse([gf128_add(u, v) for u, v in itertools.zip_longest(a, b, fillvalue=0)])


def gf128poly_mul(a, b):
    """Compute ab (schoolbook convolution).

    >>> assert gf128poly_mul([x**3, x**1, 1], [x**2 + 1, x**3]) == [x**5 + x**3, x**6 + x**3 + x**1, x**4 + x**2 + 1, x**3]
    >>> assert gf128poly_mul([x**2, x**3], [x**5, x**12, x**2, x**4]) == [x**7, x**14 + x**8, x**15 + x**4, x**6 + x**5, x**7]
    >>> assert gf128poly_mul([x**15], []) == []
    >>> assert gf128poly_mul([x**15], [1]) == [x**15]
    >>> assert gf128poly_mul([0, x**15], [0, x**15]) == [0, 0, x**30]
    >>> assert gf128poly_mul([x**70], [1, x**70, 1]) == [x**70, x**19 + x**14 + x**13 + x**12, x**70]
    """
    if not a or not b:
        return []
    p = [0] * (len(a) + len(b) - 1)
    for i, ai in enumerate(a):
        if ai == 0:
            continue
        for j, bj in enumerate(b):
            p[i + j] ^= gf128_mul(ai, bj)
    return collapse(p)


def gf128poly_divmod(a, b):
    """Compute q, r such that a = qb + r; gf128poly_deg(r) < gf128poly_deg(b).

    >>> assert gf128poly_divmod([x**5 + x**3, x**6 + x**3 + x**1, x**4 + x**2 + 1, x**3], [x**2 + 1, x**3]) == ([x**3, x**1, 1], [])
    >>> assert gf128poly_divmod([x**7, x**14 + x**8, x**15 + x**4, x**6 + x**5, x**7], [x**2, x**3]) == ([x**5, x**12, x**2, x**4], [])
    >>> assert gf128poly_divmod([x**15], [1]) == ([x**15], [])
    >>> assert gf128poly_divmod([x**15], [x**15]) == ([1], [])
    >>> assert gf128poly_divmod([0, 0, x**30], [0, x**15]) == ([0, x**15], [])
    >>> assert gf128poly_divmod([x**70, x**19 + x**14 + x**13 + x**12, x**70], [1, x**70, 1]) == ([x**70], [])
    >>> assert gf128poly_divmod([x**18, x**18, x**3], [x**3, x**3]) == ([x**15 + 1, 1], [x**3])
    """
    assert b != [], 'divide by zero'
    # The inverse of b's lead is loop-invariant; 1 is its own inverse.
    highB = gf128poly_lead(b)
    highBinv = 1 if highB == 1 else gf128_inv(highB)
    deg_b = gf128poly_deg(b)
    q = []
    while gf128poly_deg(a) >= deg_b:
        qq = gf128_mul(gf128poly_lead(a), highBinv)
        de = gf128poly_deg(a) - deg_b
        new_term = [0]*de + [qq]
        q = gf128poly_add(q, new_term)
        # Subtracting (== adding) new_term*b cancels a's leading term.
        a = gf128poly_add(a, gf128poly_mul(b, new_term))
    return collapse(q), collapse(a)


def gf128poly_sqrt(p):
    """Compute q such that q^2 = p for p such that p' = 0.

    >>> assert gf128poly_sqrt([1, 0, 1, 0, 1, 0, 1]) == [1, 1, 1, 1]
    >>> assert gf128poly_sqrt([1, 0, 1, 0, 0, 0, 1]) == [1, 1, 0, 1]
    >>> assert gf128poly_sqrt([1, 0, x**4]) == [1, x**2]
    >>> assert gf128poly_sqrt([1, 0, x**3]) == [1, x**126 + x**123 + x**120 + x**117 + x**114 + x**111 + x**108 + x**105 + x**102 + x**99 + x**96 + x**93 + x**90 + x**87 + x**84 + x**81 + x**78 + x**75 + x**72 + x**69 + x**66 + x**63 + x**62 + x**60 + x**59 + x**57 + x**56 + x**54 + x**53 + x**51 + x**50 + x**48 + x**47 + x**45 + x**44 + x**42 + x**41 + x**39 + x**38 + x**36 + x**35 + x**33 + x**32 + x**30 + x**29 + x**27 + x**26 + x**24 + x**23 + x**21 + x**20 + x**18 + x**17 + x**15 + x**14 + x**12 + x**11 + x**9 + x**8 + x**6 + x**3]
    """
    # p' == 0 in characteristic 2 means every odd-degree coefficient is zero.
    assert gf128poly_formal_derivative(p) == []
    # sqrt(sum c_2k X^2k) = sum sqrt(c_2k) X^k; 0 and 1 are their own roots.
    return collapse([c if c in (0, 1) else gf128_sqrt(c) for c in p[::2]])


def gf128poly_monic(p):
    """Compute p/p_lead for p such that gf128poly_deg(p) >= 0.

    >>> assert gf128poly_monic([x**54, x**2]) == [x**52, 1]
    >>> assert gf128poly_monic([x**2, x**3]) == [x**127 + x**6 + x**1 + 1, 1]
    """
    assert len(p) > 0
    lead = gf128poly_lead(p)
    if lead == 1:
        return p
    lead_inv = gf128_inv(lead)
    return [gf128_mul(c, lead_inv) for c in p]


def gf128poly_formal_derivative(p):
    """Compute p'. In characteristic 2 only odd-degree terms survive.

    >>> assert gf128poly_formal_derivative([]) == []
    >>> assert gf128poly_formal_derivative([x**25]) == []
    >>> assert gf128poly_formal_derivative([x**25, x**3 + x**1, x**10, x**7]) == [x**3 + x**1, 0, x**7]
    """
    return collapse([p[i] if i % 2 == 1 else 0 for i in range(1, len(p))])


def gf128poly_modexp(a, n, m):
    """Compute a^n (mod m) by recursive square-and-multiply.

    >>> assert gf128poly_modexp([x**2, x**3, x**4], 0, [x**3, x**2, 1]) == [1]
    >>> assert gf128poly_modexp([x**2, x**3, x**4], 1, [x**3, x**2, 1, x**5]) == [x**2, x**3, x**4]
    >>> assert gf128poly_modexp([x**2, x**3, x**4], 5, [x**3, x**2, 1]) == [x**39 + x**38 + x**37 + x**36 + x**35 + x**33 + x**32 + x**30 + x**27 + x**26 + x**25 + x**24 + x**21 + x**20 + x**15 + x**10, x**38 + x**36 + x**35 + x**33 + x**32 + x**31 + x**26 + x**24 + x**23 + x**22 + x**21 + x**20 + x**14 + x**11]
    """
    def modmul(u, v):
        _, r = gf128poly_divmod(gf128poly_mul(u, v), m)
        return r

    if n == 0:
        return [1]
    rec = gf128poly_modexp(a, n//2, m)
    rec = modmul(rec, rec)
    if n % 2 == 1:
        rec = modmul(rec, a)
    return rec


def gf128poly_monic_gcd(a, b):
    """Compute g such that g | a, g | b, gf128poly_lead(g) = 1; for any h such that h | a and h | b, gf128poly_deg(g) >= gf128poly_deg(h).

    At least one of a, b must be nonzero.

    >>> assert gf128poly_monic_gcd([x**3, x**2, x**4], [x**3, x**2, x**4]) == [x**127 + x**6 + x**1 + 1, x**127 + x**126 + x**6 + x**5 + x**1, 1]
    >>> assert gf128poly_monic_gcd([1], [1]) == [1]
    >>> assert gf128poly_monic_gcd([x**5 + x**2], []) == [1]
    >>> assert gf128poly_monic_gcd([x**3, 0, 0, x**6], [x**3, x**8 + x**4, x**9 + x**7, x**8]) == [x**127 + x**6 + x**1 + 1, 1]
    """
    assert gf128poly_deg(a) >= 0 or gf128poly_deg(b) >= 0
    # Euclid: gcd(a, b) = gcd(b mod a, a). Normalizing at the end also makes the
    # one-argument-zero case monic, as the contract requires.
    while gf128poly_deg(a) >= 0:
        _, r = gf128poly_divmod(b, a)
        a, b = r, a
    return gf128poly_monic(b)


def gf128poly_square_free_factorization(f):
    """Compute the square-free factorization of a monic polynomial f.

    Returns a list of (factor, multiplicity) pairs.

    Reference: https://en.wikipedia.org/wiki/Factorization_of_polynomials_over_finite_fields#Square-free_factorization

    >>> a = [x**1, x**2, 1]
    >>> b = [x**2, x**3+1, 1]
    >>> c = [x**3, x**4, x**5, 1]
    >>> asq = gf128poly_mul(a, a)
    >>> acb = gf128poly_mul(asq, a)
    >>> bsq = gf128poly_mul(b, b)
    >>> acbbsq = gf128poly_mul(acb, bsq)
    >>> p = gf128poly_mul(acbbsq, c)
    >>> assert gf128poly_square_free_factorization([1]) == [([1], 1)]
    >>> assert gf128poly_square_free_factorization(p) == [([x**3, x**4, x**5, 1], 1), ([x**1, x**2, 1], 3), ([x**2, x**3 + 1, 1], 2)]
    """
    R = []
    assert f[-1] == 1, 'monic'
    if f == [1]:
        return [([1], 1)]
    fprime = gf128poly_formal_derivative(f)
    c = gf128poly_monic_gcd(f, fprime)
    w, r = gf128poly_divmod(f, c)
    assert r == []
    i = 1
    while w != [1]:
        y = gf128poly_monic_gcd(w, c)
        fac, r = gf128poly_divmod(w, y)
        assert r == []
        R.append((fac, i))
        w = y
        c, r = gf128poly_divmod(c, y)
        assert r == []
        i = i + 1
    if c != [1]:
        # What remains is a perfect square (c' == 0): recurse on its root and
        # double the multiplicities.
        c = gf128poly_sqrt(c)
        c = gf128poly_monic(c)
        Rrec = gf128poly_square_free_factorization(c)
        R += [(fac, mult*2) for (fac, mult) in Rrec]
    return [(fac, mult) for fac, mult in R if fac != [1]]


def gf128poly_distinct_degree_factorization(f):
    """Compute the distinct-degree factorization of a square-free monic polynomial f.

    Returns a set of (factor-as-tuple, d) pairs where every irreducible factor
    of the given factor has degree d.

    Reference: https://en.wikipedia.org/wiki/Factorization_of_polynomials_over_finite_fields#Distinct-degree_factorization

    >>> a = [x**1, 1]
    >>> b = [x**2, 1]
    >>> ab = gf128poly_mul(a, b)
    >>> assert gf128poly_distinct_degree_factorization(ab) == {((x**3, x**2 + x**1, 1), 1)}
    """
    i = 1
    S = set()
    f_ = f.copy()
    while gf128poly_deg(f_) >= 2*i:
        # gcd(f_, X^(q^i) - X) collects the irreducible factors of degree i, q = 2^128.
        qq = gf128poly_modexp([0, 1], 2**(128*i), f_)
        qq = gf128poly_add(qq, [0, 1])
        g = gf128poly_monic_gcd(f_, qq)
        if g != [1]:
            S.add((tuple(g), i))
            q, r = gf128poly_divmod(f_, g)
            assert r == []
            f_ = q
        i += 1
    if f_ != [1]:
        S.add((tuple(f_), gf128poly_deg(f_)))
    if len(S) == 0:
        return {(tuple(f), 1)}
    return S


def gf128poly_equal_degree_factorization(f, d):
    """Compute the equal-degree factorization of a square-free monic polynomial f whose factors are all of degree d.

    Uses the Cantor-Zassenhaus algorithm with a cube-root-of-unity test
    (3 divides 2^(128d) - 1). Returns a set of factor tuples.

    Reference: https://en.wikipedia.org/wiki/Factorization_of_polynomials_over_finite_fields#Equal-degree_factorization

    >>> a = [x**1, 1]
    >>> b = [x**2, 1]
    >>> ab = gf128poly_mul(a, b)
    >>> assert gf128poly_equal_degree_factorization(ab, 1) == {(x**2, 1), (x**1, 1)}
    """
    def gf128_rand():
        return random.randint(0, 2**128-1)

    def gf128poly_rand(degree):
        return [gf128_rand() for _ in range(degree+1)]

    # The splitting exponent depends on the factor degree: each factor's residue
    # field has 2^(128d) elements. Using (2^128-1)//3 for d > 1 almost never splits.
    exp = (2**(128*d) - 1)//3
    n = gf128poly_deg(f)
    r = n//d
    S = {tuple(f)}
    while len(S) < r:
        h = collapse(gf128poly_rand(n - 1))
        g = gf128poly_monic_gcd(h, f)
        if g == [1]:
            g = gf128poly_add(gf128poly_modexp(h, exp, f), [1])
        for u in list(S):
            u = list(u)
            if gf128poly_deg(u) == d:
                continue
            gugcd = gf128poly_monic_gcd(g, u)
            if gugcd != [1] and gugcd != u:
                S.remove(tuple(u))
                S.add(tuple(gugcd))
                qq, rr = gf128poly_divmod(u, gugcd)
                assert rr == []
                S.add(tuple(qq))
    return S


def gf128poly_factorize(f, degree=None):
    """Compute the factors of a polynomial f.

    Does not return multiplicity.
    If degree is specified, only returns factors of that degree.
    """
    factors = set()
    f = gf128poly_monic(f)
    fs = gf128poly_square_free_factorization(f)
    for p, _ in fs:
        qs = gf128poly_distinct_degree_factorization(p)
        for q, d in qs:
            if degree is not None and d != degree:
                continue
            factors |= gf128poly_equal_degree_factorization(list(q), d)
    return factors


## AES-GCM

def pad(xs):
    """Zero-pad xs to a multiple of 16 bytes."""
    return xs + b'\x00' * ((16 - len(xs)) % 16)


def build_blocks(aad, c):
    """Return the GHASH input: padded AAD || padded ciphertext || 64-bit bit lengths."""
    padded_aad = pad(aad)
    padded_c = pad(c)
    length = struct.pack('>QQ', len(aad)*8, len(c)*8)
    return padded_aad + padded_c + length


def gmac(h, s, aad, c):
    """Compute the 16-byte GCM tag GHASH_h(aad, c) + s."""
    blocks = build_blocks(aad, c)
    g = 0
    for i in range(len(blocks)//16):
        block = bytes_to_gf128(blocks, i*16, (i + 1)*16)
        g = gf128_add(g, block)
        g = gf128_mul(g, h)
    t = gf128_add(g, s)
    return gf128_to_bytes(t)


def ecb_encrypt(key, xs):
    """AES-ECB encrypt xs (length must be a multiple of 16)."""
    c = AES.new(key, AES.MODE_ECB)
    return c.encrypt(xs)


def xor(a, b):
    """XOR two byte strings, truncated to the shorter one."""
    return bytes([(u ^ v) for (u, v) in zip(a, b)])


def gctr(k, nonce, x):
    """GCM counter-mode encryption of x, starting at counter nonce || 2.

    Assumes a 96-bit nonce (the counter block is nonce << 32 plus the counter).
    """
    if not x:
        return b''
    first = (int.from_bytes(nonce, 'big') << 32) + 2
    num_blocks = (len(x) + 15) // 16
    # Encrypt all counter blocks in one ECB call instead of one AES object per block.
    counters = b''.join((first + i).to_bytes(16, 'big') for i in range(num_blocks))
    return xor(x, ecb_encrypt(k, counters))


def authentication_key(k):
    """Return the GHASH key H = E_k(0^128) as a field element."""
    return bytes_to_gf128(ecb_encrypt(k, b'\x00'*16))


def blind(k, nonce):
    """Return the tag mask E_k(nonce || 0^31 || 1) as a field element (96-bit nonce)."""
    return bytes_to_gf128(ecb_encrypt(k, nonce + b'\x00\x00\x00\x01'))


def gcm_encrypt(k, nonce, aad, m, mac_bytes=16):
    """Encrypt m; return (ciphertext, tag truncated to mac_bytes)."""
    c = gctr(k, nonce, m)
    mac = gmac(authentication_key(k), blind(k, nonce), aad, c)
    return c, mac[:mac_bytes]


def gcm_decrypt(k, nonce, aad, c, mac, mac_bytes=16):
    """Verify the tag and decrypt c.

    Raises AssertionError on a tag mismatch. This is an explicit raise rather
    than an `assert` statement so the check survives `python -O`.
    """
    m = gctr(k, nonce, c)
    expected_mac = gmac(authentication_key(k), blind(k, nonce), aad, c)[:mac_bytes]
    if not hmac.compare_digest(mac, expected_mac):
        raise AssertionError('GCM tag mismatch')
    return m


## Forbidden Attack

def compute_forbidden_polynomial(aad1, aad2, c1, c2, mac1, mac2):
    """Build the polynomial in H whose roots include the authentication key.

    For two messages under the same key and nonce, the blinds cancel, so
    sum (B1_i + B2_i) H^(N-i+1) + mac1 + mac2 = 0. Returns coefficients from H^0
    upward. Expects full 16-byte tags.
    """
    bs1 = build_blocks(aad1, c1)
    bs2 = build_blocks(aad2, c2)
    # Left-pad the shorter block list so both end at the length block (same powers of H).
    if len(bs1) < len(bs2):
        bs1 = b'\x00'*(len(bs2)-len(bs1)) + bs1
    else:
        bs2 = b'\x00'*(len(bs1)-len(bs2)) + bs2
    assert len(bs1) == len(bs2)
    f = []
    N = len(bs1)//16
    for i in range(N):
        b1 = bytes_to_gf128(bs1[i*16:(i+1)*16])
        b2 = bytes_to_gf128(bs2[i*16:(i+1)*16])
        f.append(gf128_add(b1, b2))
    f.append(gf128_add(bytes_to_gf128(mac1), bytes_to_gf128(mac2)))
    return collapse(list(reversed(f)))


def nonce_reuse_recover_secrets(nonce, aad1, aad2, c1, c2, mac1, mac2):
    """Return candidate (H, blind) pairs recovered from two messages sharing a nonce."""
    f = compute_forbidden_polynomial(aad1, aad2, c1, c2, mac1, mac2)
    factors = gf128poly_factorize(f, degree=1)
    candidates = []  # not named `secrets`, to avoid shadowing the module
    for factor in factors:
        if gf128poly_deg(factor) == 1:
            # Monic linear factor X + h has root h (characteristic 2).
            h = factor[0]
            zero_tag = gmac(h, 0, aad1, c1)
            s = gf128_add(bytes_to_gf128(mac1), bytes_to_gf128(zero_tag))
            candidates.append((h, s))
    return candidates


## MAC Truncation Attack

def gen_blocks(n, js):
    """Return n zero blocks with the bits at (MSB-first) positions js toggled."""
    blocks = bytearray(16*n)
    for j in js:
        j = int(j)  # js may hold numpy integers
        blocks[j // 8] ^= 1 << (7 - j % 8)
    return bytes(blocks)


squarer = np.array(Ms())
use_numpy = True
if use_numpy:
    with open('square-basis.np', 'rb') as _fh:
        mcsqlookup = np.load(_fh)
else:
    mcsqlookup = xr.open_dataarray('square-basis.nc')


def mc_squared(c, j):
    """Look up the GF(2) matrix for c combined with j squarings from the precomputed table.

    Sums the table entries mcsqlookup[i, j] for each set bit i of c, mod 2.
    NOTE(review): presumably mcsqlookup[i, j] == Mc(x^i) @ Ms()^j (mod 2), matching
    the fallback in gen_ad; confirm against how square-basis.np was generated.
    Raises IndexError when j is beyond the table.
    """
    return sum(mcsqlookup[i, j] if use_numpy else mcsqlookup[i, j].to_numpy()
               for i in range(128) if 1 == (c >> i) & 1) % 2


def _matrix_power_mod_2(m, e):
    """Compute m^e over GF(2) by square-and-multiply, reducing mod 2 at every step to avoid int overflow."""
    result = np.eye(m.shape[0], dtype=int)
    base = m.astype(int) % 2
    while e:
        if e & 1:
            result = (result @ base) % 2
        base = (base @ base) % 2
        e >>= 1
    return result


def gen_ad(blocks):
    """Return the 128x128 GF(2) matrix A_d of the tag-error map for the given error blocks.

    Block i of `blocks` is the error in the ciphertext block multiplied by
    H^(2^(i+1)), i.e. Mc(d) composed with i+1 squarings.
    """
    matret = np.zeros((128, 128))
    for i in range(len(blocks)//16):
        block = blocks[(i*16):(i+1)*16]
        if block == b'\x00'*16:
            continue
        j = i + 1  # first is taken up by length block
        d = bytes_to_gf128(block)
        try:
            mat = mc_squared(d, j)
        except IndexError:
            mat = Mc(d) @ _matrix_power_mod_2(squarer, j)
        matret += mat
    return matret % 2  # Because the elements of Ad are in GF2 we can mod 2


def gen_t(n, macbytes, X=None, minrows=8):
    """Build the dependency matrix T whose kernel gives error patterns zeroing some tag rows.

    Column j is the first `rows` rows of A_d (optionally projected by X) for the
    error pattern with only bit j set.
    """
    T = []
    for j in range(n*128):
        blocks = gen_blocks(n, [j])
        Ad = gen_ad(blocks)
        width = X.shape[1] if X is not None else Ad.shape[1]
        rows = min((n*128)//width - 1, macbytes*8-minrows)
        if X is not None:
            Ad = (Ad[:rows] @ X) % 2
        z = np.concatenate(Ad[:rows])
        T.append(z)
    return (np.array(T)).transpose().astype('int')


def gen_flips(b):
    """Return the indices of the set bits in the 0/1 vector b."""
    return np.nonzero(b)[0]


def compute_n(ct):
    """Return how many power-of-two-positioned blocks are available: floor(log2(#blocks)), or 0 if none."""
    num_blocks = len(ct)//16
    return max(num_blocks.bit_length() - 1, 0)


def chunk(xs, n=16):
    """Split xs into consecutive n-byte pieces, dropping any incomplete tail."""
    return [xs[i*n:(i+1)*n] for i in range(len(xs)//n)]


def find_b(n, basis, ct, mac, nonce, aad, oracle):
    """Query the oracle with random kernel combinations until a forgery verifies; return that error vector."""
    orig_base = bytearray(ct)
    num_blocks = len(orig_base)//16
    while True:
        # random.sample raises if asked for more elements than the basis holds.
        choice = random.sample(basis, random.randint(1, min(10, len(basis))))
        b = sum(choice) % 2
        blocks = gen_blocks(n, gen_flips(b))
        base = bytearray(orig_base)
        for i, block in enumerate(chunk(blocks)):
            # Error block i lands on the block multiplied by H^(2^(i+1)).
            j = num_blocks - (2**(i+1) - 1)
            base[j*16:(j+1)*16] = xor(base[j*16:(j+1)*16], block)
        try:
            oracle(base[len(aad):], base[:len(aad)], mac, nonce)
            return b
        except (cryptography.exceptions.InvalidTag, ValueError):
            # NOTE(review): ValueError also covers oracle misconfiguration, which
            # would make this loop retry forever.
            continue


def mac_truncation_recover_secrets(ct, mac, nonce, mac_bytes, aad, oracle, compute_T_once=False):
    """Recover the authentication key H via Ferguson's truncated-MAC attack.

    Returns (H as a field element, truncated blind bits as a field element).

    NOTE(review): the tail of this function (after the MAC-bit parsing loop) was
    lost in this copy of the file and has been reconstructed; the caller expects
    `h` to compare equal to authentication_key(k).
    """
    orig_ct = ct
    ct = aad + ct
    n = compute_n(ct)
    assert n > (mac_bytes*8//2)
    assert len(ct) % 16 == 0
    assert len(aad) % 16 == 0
    X = None
    K = None
    basisKerK = None
    if compute_T_once:
        T = gen_t(n, mac_bytes, X, minrows=7)
        _, _, basisKerT = kernel(T, rref_mod_2)
    while basisKerK is None or len(basisKerK) > 1:
        if not compute_T_once:
            T = gen_t(n, mac_bytes, X, minrows=7)
            _, _, basisKerT = kernel(T, rref_mod_2)
        assert len(basisKerT[0]) == n*128
        b = find_b(n, basisKerT, ct, mac, nonce, aad, oracle)
        blocks = gen_blocks(n, gen_flips(b))
        Ad = gen_ad(blocks)
        if X is not None:
            AdRelevant = ((Ad @ X) % 2)[:mac_bytes*8]
        else:
            AdRelevant = Ad[:mac_bytes*8]
        # Every nonzero row of a successful forgery's A_d is orthogonal to H.
        incrK = Ad[:mac_bytes*8][np.any(AdRelevant, axis=1)]
        K = incrK if K is None else np.concatenate([K, incrK])
        _, _, basisKerK = kernel(K, rref_mod_2)
        if not compute_T_once:
            X = np.array(basisKerK).transpose()
    kerK = basisKerK
    assert len(kerK) == 1, len(kerK)
    h = kerK[0]
    zero_tag = gf128_to_vec(bytes_to_gf128(gmac(vec_to_gf128(h), 0, aad, orig_ct)))[:mac_bytes*8]
    # MSB of mac byte k is the coefficient of x^(8k), matching bytes_to_gf128.
    gf128_mac = 0
    i = 0
    for byte in mac:
        for j in range(8):
            if byte & (1 << (7-j)):
                gf128_mac += (1 << i)
            i += 1
    s = gf128_add(gf128_mac, vec_to_gf128(zero_tag))
    return vec_to_gf128(h), s


## Key-commitment (multi-key collision) attacks

def encode_lengths(ad_length, ct_length):
    """Return the GCM length block for byte lengths ad_length and ct_length."""
    return struct.pack('>QQ', ad_length*8, ct_length*8)


def collide(k1, k2, nonce, c):
    """Append one block to c so that (c', tag) verifies under both k1 and k2 (no AAD)."""
    h1 = authentication_key(k1)
    h2 = authentication_key(k2)
    p1 = blind(k1, nonce)
    p2 = blind(k2, nonce)
    assert len(c) % 16 == 0
    mlen = len(c)//16+1
    lens = bytes_to_gf128(encode_lengths(0, len(c) + 16))
    acc = gf128_mul(lens, gf128_add(h1, h2))
    acc = gf128_add(acc, gf128_add(p1, p2))
    # The appended block is multiplied by H^2, the last existing block by H^3, etc.
    h1Running = gf128_exp(h1, 3)
    h2Running = gf128_exp(h2, 3)
    for i in reversed(range(mlen-1)):
        hi = gf128_add(h1Running, h2Running)
        h1Running = gf128_mul(h1Running, h1)
        h2Running = gf128_mul(h2Running, h2)
        acc = gf128_add(acc, gf128_mul(bytes_to_gf128(c[i*16:(i+1)*16]), hi))
    inv = gf128_inv(gf128_add(gf128_mul(h1, h1), gf128_mul(h2, h2)))
    c_append = gf128_mul(acc, inv)
    c_ = c + gf128_to_bytes(c_append)
    mac = gmac(h1, p1, b'', c_)
    return (c_, mac)


def collide_penultimate(k1, k2, nonce, c):
    """Overwrite the penultimate block of c so that (c', tag) verifies under both k1 and k2 (no AAD)."""
    h1 = authentication_key(k1)
    h2 = authentication_key(k2)
    p1 = blind(k1, nonce)
    p2 = blind(k2, nonce)
    assert len(c) % 16 == 0
    mlen = len(c)//16
    lens = bytes_to_gf128(encode_lengths(0, len(c)))
    acc = gf128_mul(lens, gf128_add(h1, h2))
    acc = gf128_add(acc, gf128_add(p1, p2))
    # Blocks before the penultimate one start at H^4.
    h1Running = gf128_exp(h1, 4)
    h2Running = gf128_exp(h2, 4)
    for i in reversed(range(0, mlen-2)):
        hi = gf128_add(h1Running, h2Running)
        h1Running = gf128_mul(h1Running, h1)
        h2Running = gf128_mul(h2Running, h2)
        acc = gf128_add(acc, gf128_mul(bytes_to_gf128(c[i*16:(i+1)*16]), hi))
    # The last block is multiplied by H^2.
    hi = gf128_add(gf128_exp(h1, 2), gf128_exp(h2, 2))
    i = mlen-1
    acc = gf128_add(acc, gf128_mul(bytes_to_gf128(c[i*16:(i+1)*16]), hi))
    # Solve for the penultimate block, which is multiplied by H^3.
    inv = gf128_inv(gf128_add(gf128_exp(h1, 3), gf128_exp(h2, 3)))
    c_append = gf128_mul(acc, inv)
    c_ = c[:-32] + gf128_to_bytes(c_append) + c[-16:]
    mac = gmac(h1, p1, b'', c_)
    return (c_, mac)


def gctr_oneblock(k, pt, nonce):
    """Encrypt (up to) the first block of pt in GCM counter mode."""
    enckeyval = nonce + b'\x00\x00\x00\x02'
    stream = ecb_encrypt(k, enckeyval)
    return xor(pt, stream)


def key_search_jpg_bmp(nonce, init_bytes1, init_bytes2):
    """Birthday-search two keys whose first-block ciphertexts of the two headers coincide."""
    seen1 = dict()
    seen2 = dict()
    while True:
        k1 = secrets.token_bytes(16)
        k2 = secrets.token_bytes(16)
        ct1 = gctr_oneblock(k1, init_bytes1, nonce)
        ct2 = gctr_oneblock(k2, init_bytes2, nonce)
        seen1[ct1] = k1
        seen2[ct2] = k2
        if ct1 in seen2:
            return k1, seen2[ct1]
        if ct2 in seen1:
            return seen1[ct2], k2


def att_merge_jpg_bmp(jpg, bmp, aad):
    """Build one ciphertext+tag that decrypts to jpg under k1 and to bmp under k2."""
    # Precomputed with key_search_jpg_bmp; works for any files
    k1 = unhexlify('8007941455b5af579bb12fff92ef31a3')
    k2 = unhexlify('14ef746e8b1792e52b1d22ef124fae97')
    nonce = b'JORGELBORGES'
    total_len = 6 + (0xffff) + len(jpg) + 0xff  # get some extra
    jpgstream, _ = gcm_encrypt(k1, nonce, aad, b'\x00'*total_len)
    bmpstream, _ = gcm_encrypt(k2, nonce, aad, b'\x00'*total_len)
    # 6 bytes
    r = xor(jpgstream, b'\xff\xd8\xff\xfe\xff\xff')
    # len(bmp) bytes
    bmpenc = xor(bmp[6:], bmpstream[6:6+len(bmp)])
    r += bmpenc
    comlen = 0xffff
    # finish comment with padding
    r += b'\x00'*(comlen - len(bmpenc))
    # jpg
    r += xor(jpg[2:-2], jpgstream[6+comlen:])
    # comment; include penultimate block to be overwritten; therefore must be at least 3 blocks long
    # also serves to block-align to 14 bytes so the final ciphertext will be complete blocks
    endcomlen = (28 - (len(r) % 16)) + 16 + 14
    tail = b'\xff\xfe' + struct.pack('>H', endcomlen) + b'\x00'*endcomlen + b'\xff\xd9'
    tailx = xor(tail, jpgstream[6+comlen+len(jpg)-4:])
    r += tailx
    assert len(r) % 16 == 0
    cfin, macfin = collide_penultimate(k1, k2, nonce, r)
    return cfin, macfin


def key_search_pdf_pdf():
    """Search for k2 such that the k1-encryption of the PDF header decrypts under k2 to a single '%' comment line."""
    # NOTE(review): the exact line breaks of this template could not be confirmed;
    # the precomputed prefix in att_merge_pdf_pdf is 40 bytes, so re-check the
    # template before regenerating keys.
    a = '''
%PDF-1.7
%µ¶
0 0 obj
<<>>
stream
'''.strip().encode('utf-8')
    nonce = b'JORGELBORGES'
    k1 = secrets.token_bytes(16)
    m1 = a + b'\x0a'
    c1 = gctr(k1, nonce, m1)
    while True:
        k2 = secrets.token_bytes(16)
        m2 = gctr(k2, nonce, c1)
        if m2[0] == b'%'[0] and b'\x0a' not in m2[:-1] and m2[-1] == b'\x0a'[0]:
            return k1, k2, nonce, c1


def att_merge_pdf_pdf(pdf1, pdf2, aad):
    """Build one ciphertext+tag that decrypts to pdf1 under k1 and to a pdf2-bearing file under k2."""
    # precomputed with key_search_pdf_pdf
    k1 = unhexlify('c94a4dbd95faf02bdc0c39e0c0984299')
    k2 = unhexlify('e4d26cdfbc732473103a5a887a755e19')
    nonce = unhexlify('4a4f5247454c424f52474553')
    r = unhexlify('ade70922bef96292d1b7d39d53140ed2229a6819eebe86f5a536ad7da256679ae12b88a8bbfad501')
    N = len(pdf1) + len(pdf2) + 1000
    pdf1stream = gctr(k1, nonce, b'\x00'*N)
    pdf2stream = gctr(k2, nonce, b'\x00'*N)
    r += xor(pdf2, pdf2stream[len(r):])
    r += xor(b"\x0aendstream\x0aendobj\x0a", pdf1stream[len(r):])
    r += xor(pdf1, pdf1stream[len(r):])
    r += b'\x00' * (16 - (len(r) % 16))
    return collide(k1, k2, nonce, r)


# Demos

def forbidden_attack_demo():
    """Forge a ciphertext/tag after two encryptions reuse a nonce."""
    k = b"tlonorbistertius"
    nonce = b"jorgelborges"
    m1 = b"The universe (which others call the Library)"
    aad1 = b"The Anatomy of Melancholy"
    m2 = b"From any of the hexagons one can see, interminably"
    aad2 = b"Letizia Alvarez de Toledo"
    c1, mac1 = gcm_encrypt(k, nonce, aad1, m1)
    c2, mac2 = gcm_encrypt(k, nonce, aad2, m2)
    # Recover the authentication key and blind from public information
    possible_secrets = nonce_reuse_recover_secrets(nonce, aad1, aad2, c1, c2, mac1, mac2)
    # Forge the ciphertext
    m_forged = b"As was natural, this inordinate hope"
    assert len(m_forged) <= len(m1)
    c_forged = xor(c1, xor(m1, m_forged))
    aad_forged = b"You who read me, are You sure of understanding my language?"
    # Check possible candidates for authentication key
    succeeded = False
    for h, s in possible_secrets:
        mac_forged = gmac(h, s, aad_forged, c_forged)
        try:
            assert gcm_decrypt(k, nonce, aad_forged, c_forged, mac_forged) == m_forged
            succeeded = True
            print(c_forged.hex(), mac_forged.hex())
        except AssertionError:
            pass
    assert succeeded


def mac_truncation_demo():
    """Recover H from a 1-byte truncated tag using a decryption oracle."""
    # Doesn't work with non-block size multiples.
    # Need to modify to consider padding, but we can't mess with the bits in the padding,
    # nor can we extend ad/ct unless we also change length block.
    k = b'tlonorbistertius'
    aad = b''
    mac_bytes = 1
    pt = b'celerypatchworks'*(2**5)
    nonce = b'jorgelborges'
    ct, mac = gcm_encrypt(k, nonce, aad, pt, mac_bytes=mac_bytes)

    def oracle(base, aad, mac, nonce):
        # NOTE(review): recent `cryptography` releases appear to reject
        # min_tag_length < 4 with ValueError, which find_b treats as a failed
        # forgery; confirm against the installed version.
        decryptor = Cipher(
            algorithms.AES(k),
            modes.GCM(nonce, mac, min_tag_length=mac_bytes),
        ).decryptor()
        decryptor.authenticate_additional_data(aad)
        decryptor.update(base) + decryptor.finalize()

    h, s = mac_truncation_recover_secrets(ct, mac, nonce, mac_bytes, aad, oracle, compute_T_once=mac_bytes == 1)
    assert h == authentication_key(k)


if __name__ == "__main__":
    # To run the truncated-MAC demo instead: mac_truncation_demo()
    with open('static/axolotl.jpg', 'rb') as fh:
        jpg = fh.read()
    with open('static/kitten.bmp', 'rb') as fh:
        bmp = fh.read()
    c, mac = att_merge_jpg_bmp(jpg, bmp, aad=b"")
    print(mac.hex())
    with open('c.txt', 'wb') as f:
        f.write(c)
        f.write(mac)