Implementing SHA-1 in Python

Earlier this year, researchers at Google generated two different PDFs with the same SHA-1 digest, and the world became reasonably worried about the security of the hashing algorithm.

So even though I’ll likely never use SHA-1 in the future (and, more importantly, would never use my own implementation in a real-world project), I thought I’d sit down with the spec and see if I could implement it in Python, which I haven’t been using as much as I’d like lately.

Thankfully, NIST also provides a short example case to check against.

So let’s begin!

We have to define some helper functions. The first is ROTL, the spec’s abbreviation for the rotate-left operation. It shifts the bits to the left, and any bits that fall off the left end wrap around to the right end, as if the bits sat in a circle. In a very small four-bit example:

ROTL(1011, 2) = 1110
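To make the picture concrete, here is a throwaway string-based version of the rotation. The name rotl_bits is hypothetical and only for illustration; it treats the bits as text, where rotating left just means moving the first n characters to the back.

```python
def rotl_bits(bits: str, n: int) -> str:
    """Rotate a string of '0'/'1' characters left by n positions (illustration only)."""
    n %= len(bits)               # rotating by the full width is a no-op
    return bits[n:] + bits[:n]   # the first n bits wrap around to the end


print(rotl_bits('1011', 2))  # 1110
```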

Seems pretty easy, right? The spec even gives a shortcut for a w-bit word x (w = 32 for SHA-1):

ROTL^n(x) = (x << n) | (x >> (w - n))

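But I didn’t know before that Python’s bitwise left shift doesn’t keep a fixed width. The spec defines x << n on a w-bit word, so bits shifted past the left end are discarded; Python integers have arbitrary precision, so the shift just adds zeros on the right. So 10011 << 2 = 1001100, instead of 01100.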

To keep the original width, we need to bitwise AND (&) the result with a mask of w ones, 2 ** w - 1 (written with ** because ^ means XOR in Python). In the example above, (10011 << 2) & (2 ** 5 - 1) = 1001100 & 11111 = 01100.
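A quick interpreter check of that shift-and-mask behavior, first on the five-bit example and then on the four-bit ROTL(1011, 2) example from above. The variable names are only illustrative.

```python
x = 0b10011
w = 5

shifted = x << 2                 # Python keeps every bit: 0b1001100
mask = 2 ** w - 1                # w one-bits: 0b11111
print(format(shifted, 'b'))           # 1001100
print(format(shifted & mask, '05b'))  # 01100

# The same idea applied to the four-bit rotation ROTL(1011, 2)
y, n, w4 = 0b1011, 2, 4
rotated = ((y << n) & (2 ** w4 - 1)) | (y >> (w4 - n))
print(format(rotated, '04b'))         # 1110
```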

Here’s the final result:

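def ROTL(x, n, w):
  # Shift left, mask back down to w bits, then OR in the n bits
  # that wrapped around from the left end.
  return ((x << n) & (2 ** w - 1)) | (x >> (w - n))
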
Next we need to define some logical functions (see page 10, section 4.1.1): Ch, Parity, and Maj. These are pretty straightforward, since Python has built-in operators for everything these functions need: bitwise AND (&), bitwise XOR (^), and bitwise complement (~).

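def Ch(x, y, z):
  # "Choose": each bit of x picks the matching bit of y (if 1) or z (if 0).
  # Python ints are unbounded, so ~x is negative (-x - 1), but ANDing it
  # with the nonnegative 32-bit z keeps the result within 32 bits.
  return (x & y) ^ (~x & z)

def Parity(x, y, z):
  # XOR of all three inputs, bit by bit.
  return x ^ y ^ z

def Maj(x, y, z):
  # "Majority": each output bit is 1 if at least two of the input bits are 1.
  return (x & y) ^ (x & z) ^ (y & z)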
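Because these functions work bit by bit, checking every combination of single input bits is enough to confirm what they compute. This sketch re-declares Ch and Maj with the same formulas as above, enumerates the eight cases with itertools.product, and also shows that Ch still returns a nonnegative 32-bit value even though ~x is negative in Python.

```python
import itertools


def Ch(x, y, z):
    return (x & y) ^ (~x & z)


def Maj(x, y, z):
    return (x & y) ^ (x & z) ^ (y & z)


for x, y, z in itertools.product((0, 1), repeat=3):
    # Ch picks y where x is 1 and z where x is 0
    assert Ch(x, y, z) == (y if x else z)
    # Maj is 1 when at least two of the three bits are 1
    assert Maj(x, y, z) == (1 if x + y + z >= 2 else 0)

print(~0x0000000F)  # -16: the complement of an unbounded int is negative
print(hex(Ch(0xFFFF0000, 0x12345678, 0x9ABCDEF0)))  # 0x1234def0
```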

With the helper functions defined, let’s actually start on a sha1 function that takes a single argument x, the message as an ASCII string.

The first bit is easy: the spec defines four constants in hexadecimal notation that fill a list K of length 80. The first constant is assigned to indices 0-19, a second to indices 20-39, and so on.

def sha1(x):
  K = []

  for t in range(80):
    if   t <= 19:
      K.append(0x5a827999)
    elif t <= 39:
      K.append(0x6ed9eba1)
    elif t <= 59:
      K.append(0x8f1bbcdc)
    else:
      K.append(0xca62c1d6)
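Since each constant simply repeats 20 times, the same list can also be built with list multiplication. The constants aren’t arbitrary either: each is the integer part of 2^30 times the square root of 2, 3, 5, or 10, which this sketch confirms exactly with math.isqrt (Python 3.8+). It is only an alternative; the loop above works fine.

```python
import math

# Four constants, each repeated for a block of 20 rounds
K = [0x5a827999] * 20 + [0x6ed9eba1] * 20 + [0x8f1bbcdc] * 20 + [0xca62c1d6] * 20

assert len(K) == 80

# floor(2**30 * sqrt(k)) == isqrt(k * 2**60), computed with integers only
for first_index, k in zip((0, 20, 40, 60), (2, 3, 5, 10)):
    assert K[first_index] == math.isqrt(k << 60)
```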

Next, we need to convert the input message into bits and pad it to a multiple of 512 bits. (In our example the result will be exactly 512 bits, so we won’t worry about processing multiple 512-bit blocks.)

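The padding consists of a single 1 bit, followed by enough 0 bits to bring the total length to 448 bits, and then the original message length in bits written as a 64-bit binary number. With a single block, that leaves room for a message of at most 55 bytes.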

Note the check at the end of this section that makes sure x_padded is 512 characters long (one character per bit). If this were adapted to multiple 512-bit blocks, the check would instead make sure the length is evenly divisible by 512.

  # Encode the string as bytes, then write each byte as 8 binary digits.
  x_bytes = bytearray(x, 'ascii')

  x_bits = [format(byte, '08b') for byte in x_bytes]
  print('x_bits:', x_bits)

  x_bits_string = ''.join(x_bits)
  print('x_bits_string:', x_bits_string)

  # Padding: a 1 bit, 0 bits up to 448 bits total, then the length as 64 bits.
  msg_len = 8 * len(x_bytes)
  pad_bits = '1' + '0' * (448 - (msg_len + 1)) + format(msg_len, '064b')

  x_padded = x_bits_string + pad_bits
  print('x_padded:', x_padded)
  print('len(x_padded):', len(x_padded))
  assert len(x_padded) == 512
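The bit-string padding above only covers messages of up to 55 bytes. Here is a minimal sketch of the general rule working on bytes instead of a bit string (pad_message is a hypothetical name): append the byte 0x80 (a 1 bit followed by seven 0 bits), then zero bytes until the length is 56 mod 64, then the original bit length as an 8-byte big-endian integer. The result is always a whole number of 64-byte (512-bit) blocks.

```python
def pad_message(message: bytes) -> bytes:
    """Pad a message to a multiple of 64 bytes (512 bits), SHA-1 style."""
    bit_length = 8 * len(message)
    padded = message + b'\x80'                     # the single 1 bit
    padded += b'\x00' * ((56 - len(padded)) % 64)  # zeros up to 448 mod 512 bits
    padded += bit_length.to_bytes(8, 'big')        # 64-bit message length
    return padded


short = pad_message(b'abc')
print(len(short))                   # 64
print(len(pad_message(b'a' * 56)))  # 128: too long for a single block
bits = ''.join(format(byte, '08b') for byte in short)
print(len(bits))                    # 512
```

For 'abc', bits is the same 512-character string that the code above builds as x_padded.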

Next, some initial values. With multiple 512-bit blocks we would have M(1), M(2), … up to M(N), where N = len(x_padded) / 512.

We also define the initial hash value: a list of five 32-bit words given in hexadecimal. The algorithm updates these words as it processes each block, and at the end they are concatenated into the final digest.

  M1 = x_padded
  H = [0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0]
  N = 1
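For longer messages, the padded bit string would be cut into consecutive 512-character pieces, giving M(1) through M(N). A minimal sketch with a hypothetical split_blocks helper:

```python
def split_blocks(padded_bits: str, block_size: int = 512) -> list:
    """Cut a padded bit string into consecutive blocks M(1) ... M(N)."""
    if len(padded_bits) % block_size != 0:
        raise ValueError('padded length must be a multiple of the block size')
    return [padded_bits[i:i + block_size]
            for i in range(0, len(padded_bits), block_size)]


blocks = split_blocks('0' * 512 + '1' * 512)
N = len(blocks)
print(N)  # 2
```

Inside a loop over i from 1 to N, M(i) would then be blocks[i - 1].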

Next begins a loop that in our case will only run once, but with multiple 512-bit blocks would iterate N times. Inside that loop, we initialize the message schedule W, a list of 80 32-bit words. Indices 0-15 hold the sixteen 32-bit words of the current block M(i), each parsed from a 32-character slice of the bit string. For every later index t (16 to 79), W[t] is the XOR of W[t - 3], W[t - 8], W[t - 14], and W[t - 16], rotated left by one bit.

After that, the temporary variables a, b, c, d, and e are set to the five values in the list H. I printed them using Python’s built-in hex function to confirm.

  for i in range(1, N + 1):
    print('------' * 2)
    print('i = ', i)

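    W = []

    for t in range(80):
      if t <= 15:
        # Words 0-15: consecutive 32-bit slices of the block, parsed as binary.
        W.append(int(M1[32 * t : 32 * (t + 1)], 2))
      else:
        # Words 16-79: XOR of four earlier words, rotated left by 1 bit.
        W.append(ROTL(W[t - 3] ^ W[t - 8] ^ W[t - 14] ^ W[t - 16], n=1, w=32))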

    print('W:', W[0:16])

    a = H[0]
    b = H[1]
    c = H[2]
    d = H[3]
    e = H[4]
 
    print('hex(a):', hex(a))
    print('hex(b):', hex(b))
    print('hex(c):', hex(c))
    print('hex(d):', hex(d))
    print('hex(e):', hex(e))
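The same message schedule can be built from raw bytes instead of a bit string: struct.unpack with the format '>16I' reads sixteen big-endian unsigned 32-bit words from a 64-byte block. This is a sketch with hypothetical message_schedule and rotl32 helpers, tried on the padded 'abc' block.

```python
import struct


def rotl32(x: int, n: int) -> int:
    """Rotate a 32-bit word left by n bits."""
    return ((x << n) & 0xFFFFFFFF) | (x >> (32 - n))


def message_schedule(block: bytes) -> list:
    """Expand one 64-byte block into the 80-word schedule W."""
    if len(block) != 64:
        raise ValueError('a block must be exactly 64 bytes')
    W = list(struct.unpack('>16I', block))  # W[0..15]: the block's words
    for t in range(16, 80):
        W.append(rotl32(W[t - 3] ^ W[t - 8] ^ W[t - 14] ^ W[t - 16], 1))
    return W


abc_block = b'abc\x80' + b'\x00' * 52 + (24).to_bytes(8, 'big')
W = message_schedule(abc_block)
print(hex(W[0]), hex(W[15]))  # 0x61626380 0x18
```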

Now for some fun: we loop from 0 to 79, as when creating the list K. This time we choose a logical function (Ch, Parity, or Maj, defined above) based on the loop variable t. Then we need to mess up those a-e variables.

First, a long sum goes into a temporary variable T: a rotated left by 5 bits, plus the logical function applied to b, c, and d, plus e, K[t], and W[t]. The sum is then taken modulo 2 ** 32 to keep it at 32 bits.

Then most of the variables simply shift down one place (e takes d’s value, d takes c’s, b takes a’s, and a takes T), while c is recalculated as b rotated left by 30 bits.

    for t in range(80):
      print('------')
      print('t =', t)

      if t <= 19:
        f = Ch
      elif t <= 39:
        f = Parity
      elif t <= 59:
        f = Maj
      else:
        f = Parity

      T = (ROTL(a, n=5, w=32) + f(b, c, d) + e + K[t] + W[t]) % (2 ** 32)
      e = d
      d = c
      c = ROTL(b, n=30, w=32)
      b = a
      a = T
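The shuffle of a through e can also be written as one tuple assignment, so no value is overwritten before it is used. Here is a sketch of a single round as a standalone function (sha1_round and rotl32 are hypothetical names); it chooses the logical function and constant using the same t ranges as the loop above, and runs round t = 0 starting from the initial H values with W[0] = 0x61626380 from the 'abc' block.

```python
def rotl32(x: int, n: int) -> int:
    """Rotate a 32-bit word left by n bits."""
    return ((x << n) & 0xFFFFFFFF) | (x >> (32 - n))


def sha1_round(state: tuple, t: int, w_t: int) -> tuple:
    """Apply SHA-1 round t to the working variables (a, b, c, d, e)."""
    a, b, c, d, e = state
    if t <= 19:
        f, k = (b & c) ^ (~b & d), 0x5a827999           # Ch
    elif t <= 39:
        f, k = b ^ c ^ d, 0x6ed9eba1                    # Parity
    elif t <= 59:
        f, k = (b & c) ^ (b & d) ^ (c & d), 0x8f1bbcdc  # Maj
    else:
        f, k = b ^ c ^ d, 0xca62c1d6                    # Parity
    T = (rotl32(a, 5) + f + e + k + w_t) % (2 ** 32)
    # e <- d, d <- c, c <- ROTL30(b), b <- a, a <- T, all at once
    return (T, a, rotl32(b, 30), c, d)


H0 = (0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0)
state = sha1_round(H0, 0, 0x61626380)
print([hex(v) for v in state])
```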

Inside the round loop we print the current values of these variables after every round. Once all 80 rounds are done, we add them to the corresponding elements of H, again modulo 2 ** 32:

      print('hex(a):', hex(a))
      print('hex(b):', hex(b))
      print('hex(c):', hex(c))
      print('hex(d):', hex(d))
      print('hex(e):', hex(e))

    H[0] = (a + H[0]) % (2 ** 32)
    H[1] = (b + H[1]) % (2 ** 32)
    H[2] = (c + H[2]) % (2 ** 32)
    H[3] = (d + H[3]) % (2 ** 32)
    H[4] = (e + H[4]) % (2 ** 32)
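The five additions can be folded into one list comprehension with zip. This sketch (add_to_hash is a hypothetical helper) uses & 0xFFFFFFFF, which for nonnegative integers gives the same result as % (2 ** 32) by keeping only the low 32 bits.

```python
MASK32 = 0xFFFFFFFF  # 2 ** 32 - 1


def add_to_hash(H: list, working: tuple) -> list:
    """Add each working variable to the matching hash word, modulo 2 ** 32."""
    return [(h + v) & MASK32 for h, v in zip(H, working)]


H = [0xFFFFFFFF, 0x00000001, 0x80000000, 0x12345678, 0x00000000]
print([hex(v) for v in add_to_hash(H, (1, 2, 0x80000000, 0, 5))])
```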

The last step is to format each element of H as an eight-digit, zero-padded hexadecimal string, and then join the pieces together to form a single 40-character digest:

  print(H)
  H = [format(h, '08x') for h in H]

  return ''.join(H)
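The '08x' format matters: plain hex() or the 'x' format drops leading zeros, which would shorten the digest whenever a word starts with a zero nibble. The standard library also offers an equivalent route, struct.pack with '>5I' followed by bytes.hex(); this sketch checks that both agree on the initial H values.

```python
import struct

H = [0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0]

via_format = ''.join(format(h, '08x') for h in H)
via_struct = struct.pack('>5I', *H).hex()  # five big-endian 32-bit words

print(via_format == via_struct)  # True
print(len(via_format))           # 40
print(format(0x0116fc33, 'x'))   # 116fc33: only 7 digits without zero padding
```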

The final test: can we replicate the result of the example in the NIST document?

>>> print(sha1('abc'))
a9993e364706816aba3e25717850c26c9cd0d89d
>>> assert sha1('abc') == 'a9993e364706816aba3e25717850c26c9cd0d89d'
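One way to extend the check beyond the single-block 'abc' case is a compact sketch that combines the same steps for messages of any length and compares the result with Python’s built-in hashlib.sha1. The name sha1_bytes is hypothetical, it takes bytes rather than an ASCII string, and like the code above it is for learning only.

```python
import hashlib
import struct


def rotl32(x: int, n: int) -> int:
    return ((x << n) & 0xFFFFFFFF) | (x >> (32 - n))


def sha1_bytes(message: bytes) -> str:
    """Educational SHA-1 over any number of 512-bit blocks."""
    # Padding: 0x80, zeros up to 56 mod 64 bytes, then the 64-bit bit length
    bit_length = 8 * len(message)
    padded = message + b'\x80'
    padded += b'\x00' * ((56 - len(padded)) % 64)
    padded += bit_length.to_bytes(8, 'big')

    H = [0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0]

    for offset in range(0, len(padded), 64):  # one pass per block M(i)
        W = list(struct.unpack('>16I', padded[offset:offset + 64]))
        for t in range(16, 80):
            W.append(rotl32(W[t - 3] ^ W[t - 8] ^ W[t - 14] ^ W[t - 16], 1))

        a, b, c, d, e = H
        for t in range(80):
            if t <= 19:
                f, k = (b & c) ^ (~b & d), 0x5a827999
            elif t <= 39:
                f, k = b ^ c ^ d, 0x6ed9eba1
            elif t <= 59:
                f, k = (b & c) ^ (b & d) ^ (c & d), 0x8f1bbcdc
            else:
                f, k = b ^ c ^ d, 0xca62c1d6
            T = (rotl32(a, 5) + f + e + k + W[t]) & 0xFFFFFFFF
            a, b, c, d, e = T, a, rotl32(b, 30), c, d

        H = [(h + v) & 0xFFFFFFFF for h, v in zip(H, (a, b, c, d, e))]

    return struct.pack('>5I', *H).hex()


for sample in (b'', b'abc', b'a' * 56, b'The quick brown fox' * 20):
    assert sha1_bytes(sample) == hashlib.sha1(sample).hexdigest()
print(sha1_bytes(b'abc'))  # a9993e364706816aba3e25717850c26c9cd0d89d
```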

Pretty cool that it actually worked, though I promise it was not right on the first try! This is why you don’t write your own cryptography library… though you should definitely take a stab at reproducing the algorithms to understand them more thoroughly.
