From 62d6a6167e4121a536b813c883ac73773fef3ad7 Mon Sep 17 00:00:00 2001
From: cyfraeviolae <cyfraeviolae>
Date: Thu, 25 Aug 2022 02:15:55 -0400
Subject: nonce truncation

---
 aesgcmanalysis.py | 275 +++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 242 insertions(+), 33 deletions(-)

diff --git a/aesgcmanalysis.py b/aesgcmanalysis.py
index 73a9ef8..abdfdb3 100644
--- a/aesgcmanalysis.py
+++ b/aesgcmanalysis.py
@@ -1,4 +1,4 @@
-import random, struct, hmac, itertools
+import random, struct, hmac, itertools, math
 from Crypto.Cipher import AES
 import numpy as np
 
@@ -77,9 +77,9 @@ def gf128_mod(a, m=gcm_modulus):
 
 def gf128_egcd(a, b):
     """Compute g, x, y such that g | a, g | b, g = ax + by;
-	for any h such that h | a and h | b, gf128_deg(g) >= gf128_deg(h).
+    for any h such that h | a and h | b, gf128_deg(g) >= gf128_deg(h).
     >>> assert gf128_egcd(x**2 + 1, x**1 + 1) == (x**1 + 1, 0, 1)
-	>>> assert gf128_egcd(x**12 + x**4, x**5 + 1) == (x**1 + 1, x**3 + x**2 + 1, x**10 + x**9 + x**7 + x**5 + x**4 + x**1 + 1)
+    >>> assert gf128_egcd(x**12 + x**4, x**5 + 1) == (x**1 + 1, x**3 + x**2 + 1, x**10 + x**9 + x**7 + x**5 + x**4 + x**1 + 1)
     >>> assert gf128_egcd(x**64 + x**2, x**37 + x**12 + 1) == (1, x**35 + x**34 + x**33 + x**31 + x**30 + x**28 + x**27 + x**25 + x**24 + x**21 + x**20 + x**18 + x**17 + x**15 + x**14 + x**12 + x**11 + x**9 + x**7 + x**6 + x**4 + x**3 + x**1 + 1, x**62 + x**61 + x**60 + x**58 + x**57 + x**55 + x**54 + x**52 + x**51 + x**48 + x**47 + x**45 + x**44 + x**42 + x**41 + x**39 + x**38 + x**37 + x**35 + x**34 + x**32 + x**31 + x**29 + x**28 + x**26 + x**25 + x**24 + x**22 + x**21 + x**19 + x**18 + x**16 + x**15 + x**13 + x**12 + x**11 + x**9 + x**8 + x**6 + x**5 + x**3 + x**2 + 1)
     """
     orig_a = a
@@ -138,6 +138,97 @@ def gf128_sqrt(x):
     """
     return gf128_exp(x, 2**127)
 
+def gf128_to_vec(x):
+    return [int(n) for n in bin(x)[2:].zfill(128)[::-1]]
+
+def vec_to_gf128(vs):
+    x = 0
+    for i, v in enumerate(vs):
+        if v:
+            x += (1 << i)
+    return x
+
+def reverse_mask(b):
+    # from https://stackoverflow.com/a/2602885
+    b = (b & 0xF0) >> 4 | (b & 0x0F) << 4;
+    b = (b & 0xCC) >> 2 | (b & 0x33) << 2;
+    b = (b & 0xAA) >> 1 | (b & 0x55) << 1;
+    return b;
+
+def bytes_to_gf128(xs, st=None, en=None):
+    if st is None and en is None:
+        st = 0
+        en = 16
+    a, b = struct.unpack('<QQ', bytes(reverse_mask(xs[i]) for i in range(st, en)))
+    return a + (b << 64)
+
+def gf128_to_bytes(n):
+    s = b''
+    while n != 0:
+        c = n & 0xff
+        f = 0
+        for i in range(8):
+            if c & (1 << i):
+                f += (1 << (7-i))
+        s += int.to_bytes(f, 1, 'big')
+        n >>= 8
+    s += b'\x00' * (16 - len(s))
+    assert len(s) == 16
+    return s
+
+def Ms():
+    vs = []
+    for i in range(0, 128):
+        v = gf128_to_vec(gf128_mod(1 << (2*i)))
+        vs.append(v)
+    return np.array(vs).transpose()
+
+def Mc(c):
+    vs = []
+    pr = c
+    for i in range(0, 128):
+        vs.append(gf128_to_vec(pr))
+        pr = gf128_mul(x**1, pr)
+    return np.array(vs).transpose()
+
+# Adapted from Wiki
+def rref_mod_2(M):
+    M = M.copy().astype('int')
+
+    lead = 0
+    rowCount, columnCount = M.shape
+    adj = np.eye(rowCount, rowCount).astype('int')
+
+    for r in range(rowCount):
+        if columnCount <= lead:
+            return M, adj
+        i = r
+        while M[i][lead] == 0:
+            i = i + 1
+            if rowCount == i:
+                i = r
+                lead = lead + 1
+                if columnCount == lead:
+                    return M, adj
+        if i != r:
+            M[[i, r]] = M[[r, i]]
+            adj[[i, r]] = adj[[r, i]]
+        for j in range(rowCount):
+            if M[j][lead] == 1 and j != r:
+                M[j] ^= M[r]
+                adj[j] ^= adj[r]
+        lead = lead + 1
+    return M, adj
+
+def kernel(M, f):
+    mt = M.transpose()
+    N, adj = f(mt)
+    basis = []
+    for i, row in enumerate(N):
+        if np.all(np.isclose(row, 0)):
+            basis.append(adj[i])
+    return N, adj, basis
+
 ## Computation in GF(2^128)[X]/(x^128 + x^7 + x^2 + x^1 + 1)
 ## Elements are represented as arrays where the ith element is the coefficient for x^i.
 
@@ -285,7 +376,7 @@ def gf128poly_modexp(a, n, m):
 
 def gf128poly_monic_gcd(a, b):
     """Compute g such that g | a, g | b, gf128poly_lead(g) = 1;
-	for any h such that h | a and h | b, gf128poly_deg(g) >= gf128poly_deg(h).
+    for any h such that h | a and h | b, gf128poly_deg(g) >= gf128poly_deg(h).
     >>> assert gf128poly_monic_gcd([x**3, x**2, x**4], [x**3, x**2, x**4]) == [x**127 + x**6 + x**1 + 1, x**127 + x**126 + x**6 + x**5 + x**1, 1]
     >>> assert gf128poly_monic_gcd([1], [1]) == [1]
     >>> assert gf128poly_monic_gcd([x**5 + x**2], []) == [x**5 + x**2]
@@ -422,34 +513,6 @@ def gf128poly_factorize(f, degree=None):
 
 ## AES-GCM
 
-def reverse_mask(b):
-    # from https://stackoverflow.com/a/2602885
-    b = (b & 0xF0) >> 4 | (b & 0x0F) << 4;
-    b = (b & 0xCC) >> 2 | (b & 0x33) << 2;
-    b = (b & 0xAA) >> 1 | (b & 0x55) << 1;
-    return b;
-
-def bytes_to_gf128(xs, st=None, en=None):
-    if st is None and en is None:
-        st = 0
-        en = 16
-    a, b = struct.unpack('<QQ', bytes(reverse_mask(xs[i]) for i in range(st, en)))
-    return a + (b << 64)
-
-def gf128_to_bytes(n):
-    s = b''
-    while n != 0:
-        c = n & 0xff
-        f = 0
-        for i in range(8):
-            if c & (1 << i):
-                f += (1 << (7-i))
-        s += int.to_bytes(f, 1, 'big')
-        n >>= 8
-    s += b'\x00' * (16 - len(s))
-    assert len(s) == 16
-    return s
-
 def pad(xs):
     return xs + b'\x00' * ((16 - len(xs)) % 16)
 
@@ -534,7 +597,134 @@ def nonce_reuse_recover_secrets(nonce, aad1, aad2, c1, c2, mac1, mac2):
             secrets.append((h, s))
     return secrets
 
-def demo():
+## Nonce Truncation Attack
+
+def gen_blocks(n, js):
+    blocks = b'\x00'*16*n
+    for j in js:
+        whichbyte = j // 8
+        whichbit = j % 8
+        newbyte = blocks[whichbyte] ^ (1 << (7-whichbit))
+        blocks = blocks[:whichbyte] + bytes([newbyte]) + blocks[whichbyte+1:]
+    return blocks
+
+squarer = np.array(Ms())
+matsqlookup = np.load(open('squares.np', 'rb'))
+adlookup = np.load(open('ad.np', 'rb'))
+
+def gen_ad(blocks):
+    matret = np.zeros((128, 128))
+    for i in range(len(blocks)//16):
+        block = blocks[(i*16):(i+1)*16]
+        if block == b'\x00'*16:
+            continue
+        j = i + 1 # first is taken up by length block
+        mat = None
+        d = bytes_to_gf128(block)
+        matd = Mc(d)
+        if len(matsqlookup) > j:
+            matsq = matsqlookup[j]
+        else:
+            matsq = np.linalg.matrix_power(squarer, j) % 2
+        mat = matd @ matsq
+        matret += mat
+    return matret % 2 # Because the elements of Ad are in GF2 we can mod 2
+
+def gen_t(n, macbytes, X=None, minrows=8):
+    T = []
+    for j in range(n*128):
+        if j < len(adlookup):
+            Ad = adlookup[j]
+        else:
+            blocks = gen_blocks(n, [j])
+            Ad = gen_ad(blocks)
+
+        if X is not None:
+            rows = min((n*128)//(X.shape[1]) - 1, macbytes*8-minrows)
+        else:
+            rows = min((n*128)//(Ad.shape[1]) - 1, macbytes*8-minrows)
+
+        if X is not None:
+            Ad = (Ad[:rows] @ X) % 2
+        z = np.concatenate(Ad[:rows])
+        T.append(z)
+    return (np.array(T)).transpose().astype('int')
+
+def gen_flips(b):
+    return np.nonzero(b)[0]
+
+def compute_n(ct):
+    num_blocks = len(ct)//16
+    ns = [1] # 1-indexed
+    while (ns[-1]*2)<=(num_blocks):
+        ns.append(2*ns[-1])
+    ns = [n-1 for n in ns] # 0-indexed into c
+    n = len(ns) - 1
+    return n
+
+def chunk(xs, n=16):
+    return [xs[i*n:(i+1)*n] for i in range(len(xs)//16)]
+
+def find_b(n, basis, ct, mac, nonce, aad, oracle):
+    base = bytearray(ct)
+    idx = 0
+    while True:
+        choice = random.sample(basis, random.randint(1, 12))
+        b = sum(choice) % 2
+        flips = gen_flips(b)
+        blocks = gen_blocks(n, flips)
+        for i, block in enumerate(chunk(blocks)):
+            j = (len(base)//16)-(2**(i+1)-1)
+            base[j*16:(j+1)*16] = xor(base[j*16:(j+1)*16], block)
+        try:
+            oracle(base[len(aad):], base[:len(aad)], mac, nonce)
+            return b
+        except ValueError as e:
+            assert str(e) == 'MAC check failed'
+            for i, block in enumerate(chunk(blocks)):
+                j = (len(base)//16)-(2**(i+1)-1)
+                base[j*16:(j+1)*16] = xor(base[j*16:(j+1)*16], block)
+        idx += 1
+
+def compute_auth_key(ct, mac, nonce, mac_bytes, aad, oracle):
+    ct = aad + ct
+    n = compute_n(ct)
+    assert n > (mac_bytes*8//2)
+    assert len(ct) % 16 == 0
+    assert len(aad) % 16 == 0
+    X = None
+    K = None
+    basisKerK = None
+    while K is None or (basisKerK is None or len(basisKerK) > 1):
+        T = gen_t(n, mac_bytes, X, minrows=7)
+        _, _, basisKerT = kernel(T, rref_mod_2)
+        assert len(basisKerT[0]) == n*128
+
+        b = find_b(n, basisKerT, ct, mac, nonce, aad, oracle)
+        flips = gen_flips(b)
+        blocks = gen_blocks(n, flips)
+        Ad = gen_ad(blocks)
+
+        if X is not None:
+            AdRelevant = ((Ad @ X) % 2)[:mac_bytes*8]
+        else:
+            AdRelevant = Ad[:mac_bytes*8]
+        incrK = Ad[:mac_bytes*8][np.any(AdRelevant, axis=1)]
+        if K is None:
+            K = incrK
+        else:
+            K = np.concatenate([K, incrK])
+        _, _, basisKerK = kernel(K, rref_mod_2)
+        X = np.array(basisKerK).transpose()
+        print(len(basisKerK))
+    _, _, kerK = kernel(K, rref_mod_2)
+    assert len(kerK) == 1
+    h = kerK[0]
+    return gf128_to_bytes(vec_to_gf128(h))
+
+# Demos
+
+def forbidden_attack_demo():
     k = b"tlonorbistertius"
     nonce = b"jorgelborges"
     m1 = b"The universe (which others call the Library)"
@@ -565,3 +755,22 @@ def demo():
         except AssertionError:
             pass
     assert succeeded
+
+def nonce_truncation_demo():
+    k = b'YELLOW_SUBMARINE'
+    aad = b'YELLOW_SUBMAFINERELLOWPUBMARINF_'
+    MACBYTES=1
+    pt = b'CELERYPATCHWORKS'*(2**9)
+    nonce = b'JORGELBORGES'
+    ct, mac = gcm_encrypt(k, nonce, aad, pt, mac_bytes=MACBYTES)
+    def oracle(base, aad, mac, nonce):
+        cipher = AES.new(k, mode=AES.MODE_GCM, nonce=nonce, mac_len=MACBYTES)
+        cipher.update(aad)
+        pt = cipher.decrypt_and_verify(base, mac)
+    h = compute_auth_key(ct, mac, nonce, MACBYTES, aad, oracle)
+    assert h == gf128_to_bytes(authentication_key(k))
+nonce_truncation_demo()
+
+# Try with different lengths
+# Make it so we chosoe to generate gen_t on the fly if needed
+# PRofiling
-- 
cgit v1.2.3