RSA: implementation and proofs

14 Jun 2015

what is RSA?

RSA is a public-key, or asymmetric, encryption algorithm. In contrast to symmetric algorithms, like DES and AES, which use the same key for both encryption and decryption, RSA employs two distinct keys: a public key used to encrypt data, and a private key used to decrypt whatever was encrypted with the public one. The beauty of public-key encryption is that the parties involved never need to exchange a master key, meaning that communications can be securely encrypted without any prior contact.

Public-key encryption was proposed by Whitfield Diffie and Martin Hellman in ‘76, while RSA itself was patented in ‘77 by Ron Rivest, Adi Shamir, and Leonard Adleman, who then went on to found a cybersecurity company of the same name – confusing, but great PR!

[Image: Rivest, Shamir, and Adleman]

Clifford Cocks, an English cryptographer, arrived at a similar algorithm in ‘73 while working for British intelligence at GCHQ, but his work wasn’t declassified until 1998 due to its sensitivity. Forty years later, RSA underpins SSL certification, SSH handshakes, and lots more.

In this post, we’ll implement RSA, but we’ll very much take the long way around while doing so. The algorithm introduces a number of interesting problems, like finding greatest common divisors, performing modular exponentiation, computing modular inverses, and generating random prime numbers, each of which we’ll thoroughly explore and derive solutions to (many of these won’t be immediately clear, so we’ll formally prove them as we go). Note that we won’t prove RSA itself – I might add that as an extension to the article at some point in the future.

math precursor

The only thing we need to know before diving into RSA is some modular arithmetic, which is simply arithmetic with the property that numbers have a maximum value (called the modulus) and wrap around to 0 when they exceed it. When we take a number $x \bmod n$, we’re basically taking the remainder of $x / n$; most programming languages provide this in the form of a mod function or % operator. We’ll see lots of expressions in the form of:

$$a \equiv b \pmod{n}$$

Here, the $\equiv$ symbol implies congruence, or that $a \bmod n$ equals $b \bmod n$. An important gotcha is that the $\pmod{n}$ applies to both sides of the expression, which isn’t immediately obvious to anyone used to the modulo operator in the programming sense. Many sources choose to omit the parentheses, simply writing $a \equiv b \mod n$, which just compounds the confusion; the clearest notation would probably be something like $a \bmod n = b \bmod n$. This is extremely important to remember because otherwise, expressions like $a^{p-1} \equiv 1 \pmod{p}$ won’t make any sense at all (“but if $a^{p-1}$ is equal to 1 for all $a$ not equal to 1, why not just write $a^{p-1} = 1$?!”).
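To make the “both sides” point concrete, here’s a tiny sketch of my own (not from the original math) of what congruence means in code:

def congruent(a, b, n):
    # a ≡ b (mod n) just means both sides leave the same remainder when divided by n.
    return a % n == b % n

print(congruent(17, 5, 12))  # True: 17 % 12 == 5 % 12 == 5
print(congruent(17, 5, 10))  # False: 17 % 10 == 7, while 5 % 10 == 5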

Some notes about miscellaneous notation:

  1. $a \mid b$ means that $a$ divides, or is a factor of, $b$
  2. range notation is used here and there: $[a, b]$ represents all of the numbers between $a$ and $b$ inclusive, $[a, b)$ includes $a$ but excludes $b$, $(a, b)$ excludes both $a$ and $b$, etc.

how RSA works

RSA revolves around a numeric key-pair, or a mathematically related public and private key. The public key is made known to the world, which can then use it to encrypt a message, while the private key can be used to decrypt anything encrypted with the public key. Encrypting and decrypting a message is fairly straightforward, while generating a key-pair is a more substantial process.

generate a key-pair

To generate a public/private key-pair:

  1. generate two (large) random primes, $p$ and $q$
  2. let $n = pq$
  3. find $\varphi(n)$ (Euler’s totient), or the number of integers in the range $[1, n]$ that are coprime with $n$ – that is, have a Greatest Common Divisor (GCD) of 1 with it.
  4. find a value $e$ such that $1 < e < \varphi(n)$ and $e$ is coprime with $\varphi(n)$; this is your public key.
  5. find a value $d$ such that $de \equiv 1 \pmod{\varphi(n)}$ – in other words, find the multiplicative modular inverse of $e$ modulo $\varphi(n)$; this is your private key.

Though short and concise, the above steps present several complex problems:

  1. generate a large, random prime number; this is probably the most involved, so we’ll save it for last (step 1)
  2. find $\varphi(n)$, where $n$ is the product of two primes (step 3)
  3. find the GCD of two numbers, which will allow us to find $e$ (step 4)
  4. find the multiplicative modular inverse of a value, to find $d$ (step 5)

example

Before we dive into solving those, let’s walk through the process of generating a key-pair using some small sample numbers.

  1. let $p = 3$ and $q = 5$, making $n = pq = 15$
  2. $\varphi(15) = 8$ (coprime values are 1, 2, 4, 7, 8, 11, 13, and 14)
  3. $e = 3$, because 3 is both less than and coprime with 8
  4. $d = 3$, because $3 \cdot 3 = 9$ and $9 \equiv 1 \pmod{8}$

Easy! Except, of course, we weren’t dealing with numbers with hundreds of digits – that’s the hard part. :)

finding $\varphi(n)$

To compute $\varphi(n)$, we can take advantage of the fact that it’s composed of two prime factors: $p$ and $q$. Thus, the only values with which it shares GCDs that aren’t 1 must be multiples of either $p$ or $q$ (for instance, $\gcd(n, 2p) = p$ and $\gcd(n, 3q) = q$). There are only $q$ multiples of $p$ ($p, 2p, \ldots, qp$) and $p$ multiples of $q$ ($q, 2q, \ldots, pq$) that are less than or equal to $n$. Thus, there are $p + q$ values in the range $[1, n]$ that have a GCD with $n$ not equal to 1. Note, however, that we double counted $pq$ in our list of multiples of $p$ and $q$, so in reality it’s $p + q - 1$. Thus, $\varphi(n) = T - (p + q - 1)$, where $T$ is the total number of values in the range $[1, n]$ – that is, $n$. Simplifying:

$$\varphi(n) = pq - p - q + 1 = (p - 1)(q - 1)$$
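If you’d like to convince yourself of that result numerically, here’s a quick brute-force check of my own (it leans on the gcd() function we’ll define in the next section):

def naive_totient(n):
    # Count the values in [1, n] that have a GCD of 1 with n.
    return sum(1 for k in range(1, n + 1) if gcd(k, n) == 1)

p, q = 11, 13
print(naive_totient(p * q))  # 120
print((p - 1) * (q - 1))     # 120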

computing GCDs

To find the GCD of two numbers, we’ll employ the Euclidean algorithm:

  1. the GCD of any number and 0 is the absolute value of that number
  2. the GCD of numbers $a$ and $b$ is the GCD of $b$ and $a \bmod b$

or:

def gcd(a, b):
    return abs(a) if b == 0 else gcd(b, a % b)

Let’s prove it. Case 1 should be self-explanatory: 0 is divisible by any number (the quotient is simply 0), so the GCD of 0 and any other number should be that number. We need to be careful and take its absolute value, however, to account for negative values; the greatest divisor of -5 is 5, after all, not -5, so the GCD of 0 and -5 must also be 5. Thus, we take the absolute value of the non-zero argument to arrive at the greatest divisor.

Case 2 is less intuitive (at least for me), and requires proving that $\gcd(a, b) = \gcd(b, a \bmod b)$. Let’s begin by creating another variable $r$:

$$r = a - b$$

prove $\gcd(a, b) \mid r$

We first want to prove that the GCD of $a$ and $b$ divides $r$ (or $\gcd(a, b) \mid r$). Begin by rewriting $a$ and $b$ as products of their GCD.

$$a = \gcd(a, b)x \qquad b = \gcd(a, b)y$$

$x$ and $y$ are just placeholders: we don’t want to know or care what they equal. Now, plug those into the definition of $r$:

$$r = a - b = \gcd(a, b)x - \gcd(a, b)y = \gcd(a, b)(x - y)$$

Since we’ve shown that $r$ is the product of $\gcd(a, b)$ and another value, it is by definition divisible by $\gcd(a, b)$.

prove $\gcd(b, r) \mid a$

Apply the same logic here:

$$b = \gcd(b, r)x \qquad r = \gcd(b, r)y$$

$$a = b + r = \gcd(b, r)x + \gcd(b, r)y = \gcd(b, r)(x + y)$$

prove $\gcd(a, b) = \gcd(b, r)$

We know that, by definition, $\gcd(a, b) \mid b$, and we’ve proven that $\gcd(a, b) \mid r$. Thus, $\gcd(a, b)$ is a common divisor of both $b$ and $r$. That doesn’t imply that it’s the least common divisor, greatest, or anything else: all we know is that it divides both numbers. We do know that there exists a greatest common divisor of $b$ and $r$, $\gcd(b, r)$, so we can conclude that:

$$\gcd(a, b) \le \gcd(b, r)$$

We now re-apply that same reasoning. We know that $\gcd(b, r) \mid b$ and $\gcd(b, r) \mid a$. Thus, $\gcd(b, r)$ is a common divisor of $a$ and $b$. Since we know that the greatest common divisor of $a$ and $b$ is $\gcd(a, b)$, we can conclude that:

$$\gcd(b, r) \le \gcd(a, b)$$

But now we have two almost contradictory conclusions:

$$\gcd(a, b) \le \gcd(b, r) \qquad \gcd(b, r) \le \gcd(a, b)$$

The only way these can both be true is if:

$$\gcd(a, b) = \gcd(b, r)$$

So we’ve proven that $\gcd(a, b) = \gcd(b, a - b)$ (remember, $r = a - b$).

prove $\gcd(a, b) = \gcd(b, a \bmod b)$

First, let’s assume that $r = a \bmod b$, and rewrite it as: $r = a - qb$, where $q = \lfloor a / b \rfloor$ (or $r = a - b - b - \cdots - b$, with $b$ subtracted $q$ times)

Now, we already know that $\gcd(a, b) = \gcd(b, a - b)$. Since order doesn’t matter, we can rewrite $\gcd(b, a - b)$ as $\gcd(a - b, b)$. Now, we apply the rule again:

$$\gcd(a - b, b) = \gcd(b, (a - b) - b) = \gcd(b, a - 2b)$$

or, after $q$ applications in total:

$$\gcd(a, b) = \gcd(b, a - b) = \gcd(b, a - 2b) = \cdots = \gcd(b, a - qb) = \gcd(b, a \bmod b)$$

Bingo. We’ve proven Case 2, and completed our proof of the Euclidean algorithm. Before we move on, we’ll also define a convenience wrapper for gcd() that determines whether two numbers are coprime:

def coprime(a, b):
    return gcd(a, b) == 1
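A few quick spot checks (example values mine):

print(gcd(252, 105))  # 21: gcd(252, 105) -> gcd(105, 42) -> gcd(42, 21) -> gcd(21, 0)
print(gcd(0, -5))     # 5, thanks to the abs() in the base case
print(coprime(3, 8))  # True, which is why e = 3 worked in the earlier example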

finding modular inverses

Given a value $a$ and modulus $n$, the modular multiplicative inverse of $a$ is a value $x$ that satisfies:

$$ax \equiv 1 \pmod{n}$$

This implies that there exists some integer $k$ for which:

$$ax + nk = 1$$

This turns out to be in the form of Bézout’s identity, which states that for values $a$ and $b$, there exist values $x$ and $y$ that satisfy:

$$ax + by = \gcd(a, b)$$

$x$ and $y$, called Bézout coefficients, can be solved for using the Extended Euclidean algorithm (EEA). $x$ corresponds to $a^{-1} \bmod n$, or the modular inverse that we were looking for, while $y$ can be thrown out once computed. The EEA will also give us the GCD of $a$ and $n$ – it is, after all, an extension of the Euclidean algorithm, which we use to find the GCD of two values. We need to verify that it equals 1, since we assumed that $\gcd(a, n) = 1$ when we wrote $ax + nk = 1$; if it doesn’t, $a$ has no modular inverse. Since modular_inverse() is just a wrapper for the EEA – to be implemented in a function called bezout_coefficients() – its definition is simple:

def modular_inverse(num, modulus):
    coef1, _, gcd = bezout_coefficients(num, modulus)
    # coef1 can come back negative, so normalize it into the range [0, modulus).
    return coef1 % modulus if gcd == 1 else None

bezout_coefficients() is a bit trickier:

def bezout_coefficients(a, b):
    if b == 0:
        return -1 if a < 0 else 1, 0, abs(a)
    else:
        quotient, remainder = divmod(a, b)
        coef1, coef2, gcd = bezout_coefficients(b, remainder)
        return coef2, coef1 - quotient * coef2, gcd

Let’s see why it works.

the Extended Euclidean algorithm

How to solve for $x$ and $y$? Bézout’s identity states:

$$ax + by = \gcd(a, b)$$

or, for $b$ and $a \bmod b$:

$$bx' + (a \bmod b)y' = \gcd(b, a \bmod b)$$

Let’s simplify:

$$bx' + (a - \lfloor a / b \rfloor b)y' = ay' + b(x' - \lfloor a / b \rfloor y')$$

Here, $\lfloor \cdot \rfloor$ represents the floor function, which floors the result of $a / b$ to an integer.

Since we know, by the already proven Euclidean algorithm, that $\gcd(b, a \bmod b) = \gcd(a, b)$, we can write:

$$ay' + b(x' - \lfloor a / b \rfloor y') = \gcd(a, b) = ax + by$$

So, $x = y'$ and $y = x' - \lfloor a / b \rfloor y'$. But what are $x'$ and $y'$? They’re the results of running the EEA on $b$ and $a \bmod b$! Classic recursion. In sum:

def bezout_coefficients(a, b):
    quotient, remainder = divmod(a, b)
    coef1, coef2 = bezout_coefficients(b, remainder)
    return coef2, coef1 - quotient * coef2

Of course, we need a base case, or we’ll end up recursing ad infinitum. Let’s take the case of $b = 0$:

$$ax + 0 \cdot y = \gcd(a, 0) = |a|$$

So, if $b = 0$, we set the coefficient $x$ to 1 if $a$ is positive and -1 if it is negative, and set $y$ to… what? If $b$ is 0, then $y$ can take on any value. For simplicity’s sake we’ll choose 0. Our revised definition looks like:

def bezout_coefficients(a, b):
    if b == 0:
        return -1 if a < 0 else 1, 0
    else:
        quotient, remainder = divmod(a, b)
        coef1, coef2 = bezout_coefficients(b, remainder)
        return coef2, coef1 - quotient * coef2

Also note that, since this is simply a more involved version of the Euclidean algorithm (we’re making recursive calls to bezout_coefficients(b, remainder) and have a base case of b == 0), when we hit the base case, abs(a) is the GCD of a and b. Since modular_inverse() needs to check that the GCD of its two arguments equals 1, we should return it in addition to the coefficients themselves. Hence, we’ll let it trickle up from our base case into the final return value:

def bezout_coefficients(a, b):
    if b == 0:
        return -1 if a < 0 else 1, 0, abs(a)
    else:
        quotient, remainder = divmod(a, b)
        coef1, coef2, gcd = bezout_coefficients(b, remainder)
        return coef2, coef1 - quotient * coef2, gcd
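To sanity-check both functions (example values mine):

x, y, g = bezout_coefficients(240, 46)
print(x, y, g)                # -9 47 2; indeed, 240 * -9 + 46 * 47 == 2 == gcd(240, 46)
print(modular_inverse(3, 8))  # 3, since 3 * 3 == 9 ≡ 1 (mod 8)
print(modular_inverse(4, 8))  # None: gcd(4, 8) == 2, so 4 has no inverse modulo 8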

generating large, random primes

Here’s the idea:

  1. generate a large, random, odd number
  2. check for primality
    1. if prime, return it
    2. otherwise, increment it by 2, and return to step 2

Easy enough, except for the bit about testing primality. How to do so efficiently? We’ll turn to the Rabin-Miller algorithm, a probabilistic primality test which either tells us with absolute certainty that a number is composite, or with high likelihood that it’s prime. We’re fine with a merely probabilistic solution because it’s fast, since speed is a non-negligible issue due to the size of the numbers that we’re dealing with, and also because the chances of a false positive (ie indicating that a number is prime when it’s actually composite) are astronomically low after even only a few iterations of the test.

Rabin-Miller primality test

The Rabin-Miller test relies on the below two assumptions (just accept that they’re true for now, and we’ll prove them later on). If $n$ is a prime number:

  1. $a^{n-1} \equiv 1 \pmod{n}$ for any $a$ not divisible by $n$
  2. for any $x$ that satisfies $x^2 \equiv 1 \pmod{n}$, $x$ must be equivalent to $\pm 1 \pmod{n}$

Using these, you can test a value $n$ for compositeness like so (note that we return true/false to indicate definite compositeness/probable primality respectively):

  1. pick a random value $a$ in the range $[2, n - 1]$
  2. use assumption 1 to assert that $a^{n-1} \equiv 1 \pmod{n}$; if it’s not, return true
  3. if $a^{n-1}$ has an integer square root, let $x = a^{(n-1)/2}$; otherwise, return false
  4. since $x^2 = a^{n-1} \equiv 1 \pmod{n}$, we can use assumption 2 to assert that $x \equiv \pm 1 \pmod{n}$; if not, return true
  5. otherwise, repeat steps 3-4, taking the square root of $x$, and the square root of that, and so on, until you hit a value that doesn’t have an integer square root.
  6. if you haven’t already returned anything, you’ve satisfied assumptions 1 and 2 for all testable cases and can return false.

In sum, we return true if we’ve confirmed that $a$ is a witness to the compositeness of $n$, and false if $a$ does not prove that $n$ is composite – transitively, there is a high chance that $n$ is prime, but we can only be more sure by running more such tests. While the above steps serve as a good verbal description of the algorithm, we’ll have to slightly modify them to convert the algorithm into real code.

We need to implement a function is_witness(), which checks whether a random value $a$ is a witness to the compositeness of our prime candidate, $n$.

  1. write $n - 1$ in the form $2^s d$, where $d$ is odd. $n = 13$, for instance, would yield $s = 2$ and $d = 3$, since $13 - 1 = 12 = 2^2 \cdot 3$.
  2. pick a random value $a$ in the range $[2, n - 1]$. We’ll check whether this $a$ is a witness for $n$.
  3. let $x = a^d \bmod n$
  4. if $x = 1$ or $x = n - 1$, then return false
  5. repeat $s - 1$ times:
    1. let $x = x^2 \bmod n$
    2. if $x = 1$, return true
    3. if $x = n - 1$, return false
  6. if we haven’t returned yet, return true

These steps seem quite a bit different from before, but in reality, they’re exactly the same and just operating in reverse. We start with a value that doesn’t have an integer square root, $a^d$ (remember, $d$ is odd), and square it until we hit $a^{n-1}$. Why did we bother decomposing $n - 1$ into the form of $2^s d$? Well, it allows us to rewrite $a^{n-1}$ as $(a^d)^{2^s}$, and now we know exactly how many times we can take square roots before we hit a value that isn’t reducible any further – in this case, $s$ times.

So, if we start with $a^d$ and square it, we’ll get $a^{2d}$, then $a^{4d}$, then $a^{8d}$, and ultimately $a^{2^s d}$, or $a^{n-1}$. What’s the advantage of starting from the non-reducible value and squaring it, rather than the reducible value and taking its square roots? It sometimes allows us to short-circuit the process. For instance, as we iterate through the squares of $a^d$, if we find an occurrence of -1, we know that we’ll get 1 when we square it, and 1 when we square that, and keep on getting 1s until we stop iterating. As a consequence, we know that we won’t find any failing conditions, and can exit early by returning false (step 5.3). The same goes for step 4: if $x \equiv \pm 1 \pmod{n}$, we know that each of the following squares will equal 1, so we immediately return false.
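To make this concrete, here’s a small worked example of my own. Take $n = 13$: $n - 1 = 12 = 2^2 \cdot 3$, so $s = 2$ and $d = 3$. Picking $a = 5$, we get $x = 5^3 \bmod 13 = 125 \bmod 13 = 8$, which is neither 1 nor 12, so we keep going; squaring gives $8^2 \bmod 13 = 64 \bmod 13 = 12$, or $-1 \pmod{13}$, so we return false – 5 is not a witness, and 13 is probably prime (in fact, it is). Now take the composite $n = 9$: $9 - 1 = 8 = 2^3 \cdot 1$, so $s = 3$ and $d = 1$. With $a = 2$, the initial value is $x = 2^1 \bmod 9 = 2$, and the subsequent squares are 4 and 7 – we never hit 1 or 8, so we fall through to step 6 and return true: 2 is a witness to 9’s compositeness.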

The failing conditions – ie those that cause the algorithm to return true – might not be immediately clear. In 5.2, we know that, if $x = 1$, we’ve violated assumption 2, because that implies that the previous value of $x$ was not equivalent to $\pm 1 \pmod{n}$. Wait, why? Because if it were equal to -1, we would’ve already returned via 5.3 in the previous iteration, and if it were 1, then we would’ve returned either from 5.2 in an earlier iteration still or 4 at the very beginning. We also return true when we hit 6, because we know that by that point, if assumption 1 is:

  1. true, and $a^{n-1} \equiv 1 \pmod{n}$, then the previous value of $x$ (ie $a^{(n-1)/2}$, whose square is $a^{n-1}$) can’t be either 1 or -1, because we would already have returned via 4, 5.2, or 5.3 – a square root of 1 that isn’t $\pm 1$ violates assumption 2.
  2. false, then by definition $n$ can’t be prime, since the assumption must hold true for prime $n$

Finally, we simply repeat the is_witness() test $k$ times. Here’s the final implementation:

import random

def is_prime(n, k=5):
    if n == 2:
        return True

    if n <= 1 or n % 2 == 0:
        return False

    s, d = decompose_to_factors_of_2(n - 1)

    def is_witness(a):
        x = modular_power(a, d, n)
        if x in [1, n - 1]:
            return False

        for _ in range(s - 1):
            x = modular_power(x, 2, n)
            if x == 1:
                return True

            if x == n - 1:
                return False

        return True

    for _ in range(k):
        if is_witness(random.randint(2, n - 1)):
            return False

    return True

def decompose_to_factors_of_2(num):
    s = 0
    d = num

    while d % 2 == 0:
        d //= 2
        s += 1

    return s, d
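Some hypothetical spot checks (they assume modular_power(), which we define later in this article, is already in scope):

print(decompose_to_factors_of_2(12))  # (2, 3), since 12 = 2^2 * 3
print(is_prime(13))                   # True
print(is_prime(561))                  # False (almost surely): 561 = 3 * 11 * 17, a Carmichael number
print(is_prime(2 ** 127 - 1))         # True: the Mersenne prime M127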

Note that we’ve introduced a currently undefined function, modular_power(). The problem with computing $a^d \bmod n$ and $x^2 \bmod n$ is that $a$, $d$, $x$, and $n$ are HUGE. Simply running (a ** d) % n would be asking for trouble. Fortunately, there are efficient ways of performing modular exponentiation, and we’ll implement one such method in the modular_power() function later in this article. Now, we need to actually prove the two assumptions that we base Rabin-Miller on.

Euclid’s lemma

…but before we do so, we need to prove Euclid’s Lemma, since both of the following proofs depend on it. It states that if $n$ is relatively prime to $a$ and $n \mid ab$, then $n \mid b$. We’ll prove it using Bézout’s identity. The GCD of $n$ and $a$ is 1, so there must exist $x$ and $y$ that satisfy:

$$nx + ay = 1$$

Multiply both sides by $b$:

$$nxb + ayb = b$$

$ayb$ is divisible by $n$ (because it’s divisible by $ab$, which is divisible by $n$ according to the lemma’s requisite), and $nxb$ is by definition divisible by $n$, so $b$ must be divisible by $n$ too.

proof of assumption 1

Our first assumption was that for a prime $p$, $a^{p-1} \equiv 1 \pmod{p}$ for any $a$ not divisible by $p$. This is better known as Fermat’s Little Theorem. To prove it, begin by multiplying all of the numbers in the range $[1, p - 1]$ by $a$:

$$a, 2a, 3a, \ldots, (p-1)a$$

We make two observations:

  1. given two values $xa$ and $ya$, $xa \equiv ya \pmod{p}$ is equivalent to $x \equiv y \pmod{p}$ (we effectively divide out $a$). We can prove this by rewriting $xa \equiv ya \pmod{p}$ as $xa - ya \equiv 0 \pmod{p}$, which implies that $p \mid (xa - ya)$, or $p \mid a(x - y)$. By Euclid’s Lemma, since $p$ and $a$ are coprime (reminder: this is a criterion of Fermat’s Little Theorem), $p \mid (x - y)$, which means we can write $x - y \equiv 0 \pmod{p}$, or $x \equiv y \pmod{p}$.

  2. when each of its elements is simplified $\bmod\ p$, the above sequence is simply a rearrangement of $1, 2, \ldots, p - 1$. This is true because, firstly, its values all lie in the range $[1, p - 1]$ – none can equal 0 since $a$ shares no factors other than 1 with either $p$ or any value in $[1, p - 1]$ due to $p$’s primeness. The trick now is to realize that, if we have two distinct values $xa$ and $ya$, and know that $xa \equiv ya \pmod{p}$, then by the previous observation we can “divide out $a$” and have $x \equiv y \pmod{p}$. If $x$ and $y$ were two values chosen from the sequence $1, 2, \ldots, p - 1$, we’d know that they’re all less than $p$, and can thus remove the $\pmod{p}$ from the expression, leaving us with: $x = y$. In conclusion, the only way to satisfy $xa \equiv ya \pmod{p}$ is to have $x$ be the same item as $y$, and that means that the distinct values in $1, 2, \ldots, p - 1$ map to distinct values in $a, 2a, \ldots, (p-1)a$.

By observation 2:

$$a \cdot 2a \cdot 3a \cdots (p-1)a \equiv 1 \cdot 2 \cdot 3 \cdots (p-1) \pmod{p}$$

or, grouping the $a$s:

$$a^{p-1}(p-1)! \equiv (p-1)! \pmod{p}$$

By observation 1, we can cancel out each of the factors of $(p-1)!$ from both sides of the expression (after all, $p$ is prime and all of the factors of $(p-1)!$ are less than it, so it’s coprime with all of them), which leaves us with:

$$a^{p-1} \equiv 1 \pmod{p}$$

QED.
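To see the proof in action with small numbers (an example of mine): take $p = 5$ and $a = 3$. The sequence $3, 6, 9, 12$ reduces to $3, 1, 4, 2 \pmod{5}$ – a rearrangement of $1, 2, 3, 4$, just as observation 2 predicts. Multiplying everything out gives $3^4 \cdot 4! \equiv 4! \pmod{5}$, and cancelling $4!$ leaves $3^4 = 81 \equiv 1 \pmod{5}$.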

proof of assumption 2

We now prove assumption 2: if $p$ is prime and $x^2 \equiv 1 \pmod{p}$, $x$ must be equivalent to $\pm 1 \pmod{p}$. First, for greater clarity later on, we can rewrite our conclusion as: $p$ must divide either $x - 1$ or $x + 1$. Now, if $x^2 \equiv 1 \pmod{p}$, then:

$$x^2 - 1 \equiv 0 \pmod{p} \qquad \text{or} \qquad p \mid (x - 1)(x + 1)$$

If $p$ divides $x - 1$, then:

$$x \equiv 1 \pmod{p}$$

and we’ve proven our conclusion. What if $p$ doesn’t divide $x - 1$? We can then leverage Euclid’s Lemma: if $n$ is relatively prime to $a$ and $n \mid ab$, then $n \mid b$. We know that $p$ is prime and doesn’t divide $x - 1$, so it’s relatively prime to $x - 1$, and we know that it divides $(x - 1)(x + 1)$. As a result, it has to divide $x + 1$, which implies that: $x \equiv -1 \pmod{p}$. Again, we’ve proven our conclusion, and thus proven assumption 2.

applying Rabin-Miller

Now that we’ve implemented Rabin-Miller, creating a large, random prime is almost trivial:

def get_random_prime(num_bits):
    lower_bound = 2 ** (num_bits - 2)
    upper_bound = 2 ** (num_bits - 1) - 1
    guess = random.randint(lower_bound, upper_bound)

    if guess % 2 == 0:
        guess += 1

    while not is_prime(guess):
        guess += 2

    return guess

The num_bits parameter is a bit of a weird way of specifying the desired size of the prime, but it’ll make sense since we usually want to create RSA keys of a specific bit-length (more on this later on).

wrapping it all up

At long last, we can define our create_key_pair() function.

def create_key_pair(bit_length):
    prime_bit_length = bit_length // 2
    p = get_random_prime(prime_bit_length)
    q = get_random_prime(prime_bit_length)
    n = p * q
    totient = (p - 1) * (q - 1)

    while True:
        e_candidate = random.randint(3, totient - 1)
        if e_candidate % 2 == 0:
            e_candidate += 1

        if coprime(e_candidate, totient):
            e = e_candidate
            break

    d = modular_inverse(e, totient)
    return e, d, n

The only thing that requires explanation is this bit_length business. The idea here is that we generally want to create RSA keys of a certain bit-length (1024 and 2048 are common values), so we pass in a parameter specifying the length. To make sure that $n$ has a bit-length approximately equal to bit_length, we need to make sure that the primes $p$ and $q$ that we use to create it have a bit length of bit_length / 2, since multiplying two $k$-bit numbers yields an approximately $2k$-bit value. How come? The number of bits in a positive integer $x$ is $\lfloor \log_2 x \rfloor + 1$, so the number of bits in $x^2$ is $\lfloor \log_2 x^2 \rfloor + 1$. According to the logarithm power rule, we can rewrite $\log_2 x^2$ as $2 \log_2 x$, so the bit length of $x^2$ equals $\lfloor 2 \log_2 x \rfloor + 1$. In other words, $x^2$ has roughly twice as many bits as $x$.
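You can watch this happen with Python’s int.bit_length() (a quick illustration of my own):

x = 2 ** 511 + 1             # a 512-bit number (not prime; just a stand-in)
print(x.bit_length())        # 512
print((x * x).bit_length())  # 1023: squaring roughly doubles the bit count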

encrypt/decrypt messages

In comparison to generating keys, encrypting and decrypting data with them is mercifully simple.

  1. encrypt a message $m$ with public key $e$ and modulus $n$: $c = m^e \bmod n$
  2. decrypt a message $c$ with private key $d$ and modulus $n$: $m = c^d \bmod n$

def encrypt(e, n, m):
    return modular_power(m, e, n)

def decrypt(d, n, c):
    return modular_power(c, d, n)
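Putting it all together, a hypothetical round trip might look like this (assuming all of the functions from this article – including the modular_power() explained below – live in one file):

e, d, n = create_key_pair(64)
message = 42  # plaintexts must be integers in the range [0, n)
ciphertext = encrypt(e, n, message)
assert decrypt(d, n, ciphertext) == message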

So, what’s modular_power()? The problem with the encryption and decryption operations, which look deceptively trivial, is that all of the values involved are big. Really, really big. As a result, naively solving $m^e \bmod n$ by first resolving $m^e$ and then simplifying that modulo $n$ is a no-go. Fortunately, there are more efficient ways of performing modular exponentiation, like exponentiation by squaring.

exponentiation by squaring

When trying to solve $b^e \bmod m$, begin by representing $e$ in binary form:

$$e = e_{k-1} 2^{k-1} + \cdots + e_1 2^1 + e_0 2^0$$

where $k$ is the total number of bits in $e$, and $e_i$ represents the value of each bit – either 0 or 1. Now, rewrite the original expression:

$$b^e = b^{e_{k-1} 2^{k-1} + \cdots + e_1 2^1 + e_0 2^0} = b^{e_0 2^0} \cdot b^{e_1 2^1} \cdots b^{e_{k-1} 2^{k-1}}$$

For illustrative purposes, let’s temporarily remove the $e_i$ factor from each exponent, which leaves us with:

$$b^{2^0} \cdot b^{2^1} \cdot b^{2^2} \cdots b^{2^{k-1}}$$

It’s now obvious that each factor is a square of the one that precedes it: $b^{2^1}$ is the square of $b^{2^0}$, $b^{2^2}$ is the square of $b^{2^1}$, etc. If we were to programmatically solve the expression, we could maintain a variable, say accumulator, that we’d initialize to $b^{2^0}$ (or simply $b$), and square from factor to factor to avoid recomputing every time. Now, let’s reintroduce $e_i$:

$$b^{e_0 2^0} \cdot b^{e_1 2^1} \cdot b^{e_2 2^2} \cdots b^{e_{k-1} 2^{k-1}}$$

The good thing is that $e_i$ has a limited set of possible values: just 0 and 1! Any value in the form $b^{e_i 2^i}$ – that is, all of the above factors – evaluates to $b^{2^i}$ when $e_i = 1$, and $b^0$, or 1, when $e_i = 0$. In other words, the value of $e_i$ only controls whether or not we multiply one of the factors into the accumulator that’ll become our ultimate result (since if $e_i = 0$, we’ll just end up multiplying in 1, which means we shouldn’t even bother). Thus, modular_power() might look something like this:

def modular_power(base, exp, modulus):
    result = 1

    while exp:
        if exp % 2 == 1:
            result = result * base
        exp >>= 1
        base = base ** 2

    return result % modulus

But we still haven’t addressed the issue of multiplying huge numbers by huge numbers, and this version of modular_power() doesn’t perform much better than (base ** exp) % modulus (in fact, after some spot checking, it appears to be much slower!). We can address that by taking advantage of the following property of modular multiplication:

$$xy \bmod m = ((x \bmod m)(y \bmod m)) \bmod m$$

We can prove it by rewriting $x$ and $y$ in terms of $m$:

$$x = q_x m + r_x \qquad y = q_y m + r_y$$

and substituting that into the original expression:

$$xy \bmod m = (q_x m + r_x)(q_y m + r_y) \bmod m = (m(q_x q_y m + q_x r_y + q_y r_x) + r_x r_y) \bmod m$$

We’re able to remove the entire chunk of the expression that gets multiplied by $m$ because it’s by definition divisible by $m$, meaning that, taken $\bmod\ m$, it would equal 0, and wouldn’t contribute anything to the sum. Thus, $xy \bmod m$ equals $r_x r_y \bmod m$, or $((x \bmod m)(y \bmod m)) \bmod m$.

Using that, we can make the following adjustment to our initial implementation:

def modular_power(base, exp, modulus):
    result = 1
    base %= modulus

    while exp:
        if exp % 2 == 1:
            result = (result * base) % modulus
        exp >>= 1
        base = (base ** 2) % modulus

    return result

We’re now taking % modulus in a bunch of places, which is valid due to the above property and prevents the value of both result and base from growing out of control.
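As a final sanity check (my own), we can compare modular_power() against Python’s built-in three-argument pow(), which performs the same computation:

import random

for _ in range(1000):
    base, exp, mod = (random.randint(1, 2 ** 64) for _ in range(3))
    assert modular_power(base, exp, mod) == pow(base, exp, mod)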

That tops off our implementation of RSA. Here’s the entire source file.

acknowledgements

I wouldn’t have been able to present most of the proofs in this article without help from the following sources. One of the key motivations for gathering them all in one post is that, as I tried to understand all of the moving parts of RSA, I needed to sift through a lot of material to find accessible and satisfactory explanations: