Fast Modular Exponentiation Using Binary Exponentiation in Python

Learn the binary exponentiation algorithm for efficient modular exponentiation in Python. Essential for RSA and other cryptographic applications.

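When computing $a^k \pmod{p}$ for large numbers, naively multiplying $a$ by itself $k$ times and taking the remainder only at the end is extremely inefficient. It takes $k - 1$ multiplications, and the intermediate value $a^k$ grows to roughly $k \log_2 a$ bits. Reducing after every multiplication keeps the numbers small, but it still needs $k - 1$ multiplications, which is hopeless when $k$ has hundreds of digits.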
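For contrast, here is a minimal sketch of the naive approach. The helper name `naive_power_mod` is hypothetical. It starts from 1, multiplies by `a` once per unit of `k` (so `k` multiplications, the first one trivial), and reduces only once at the end.

```python
def naive_power_mod(a: int, k: int, p: int) -> int:
    """Naive reference: build a**k one multiplication at a time, then reduce once."""
    if k < 0 or p < 1:
        raise ValueError("expects k >= 0 and p >= 1")
    value = 1
    for _ in range(k):
        value *= a  # value grows to roughly k * log2(a) bits
    return value % p


print((5 ** 21).bit_length())     # 49: bits needed before the single reduction
print(naive_power_mod(5, 21, 99))  # 71
```

This is fine for tiny exponents, but both the loop count and the operand size grow linearly with `k`.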

Binary exponentiation (also known as exponentiation by squaring) is an algorithm that performs this computation efficiently. It is an essential technique for calculating modular exponentiation of large numbers in applications such as RSA encryption.
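The following toy sketch makes the RSA connection concrete. It uses the well-known textbook parameters (primes 61 and 53, public exponent 17, private exponent 2753) and Python's built-in three-argument `pow`, which performs modular exponentiation. These numbers are far too small for real use. The variable names are illustrative.

```python
# Toy RSA with textbook-sized primes; real RSA primes are hundreds of digits long.
prime_p, prime_q = 61, 53
n = prime_p * prime_q                # public modulus: 3233
phi = (prime_p - 1) * (prime_q - 1)  # 3120
e = 17                               # public exponent
d = 2753                             # private exponent, chosen so e * d ≡ 1 (mod phi)
assert (e * d) % phi == 1

message = 65
ciphertext = pow(message, e, n)      # encryption: message^e mod n
recovered = pow(ciphertext, d, n)    # decryption: ciphertext^d mod n
print(ciphertext, recovered)         # 2790 65
```

Both encryption and decryption are single modular exponentiations. With realistic key sizes, the exponent `d` has hundreds of digits, which is exactly where a logarithmic-time algorithm is required.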

The key insight is to write the exponent $k$ in binary, which reduces the number of modular multiplications from $k - 1$ to $O(\log k)$.
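One way to see the $O(\log k)$ bound is to count operations. The sketch below (`count_operations` is a hypothetical name) follows the same right-to-left scan used later in this article but skips the arithmetic on the base. It counts one squaring per bit of $k$ and one extra multiplication per 1-bit.

```python
from typing import Tuple


def count_operations(k: int) -> Tuple[int, int]:
    """Return (squarings, multiplications) done by right-to-left binary exponentiation."""
    squarings = multiplications = 0
    while k > 0:
        if k & 1:
            multiplications += 1  # this bit is 1: multiply the current power into the result
        squarings += 1            # square the base to prepare the next bit
        k >>= 1
    return squarings, multiplications


print(count_operations(21))          # (5, 3): 8 operations instead of 20
print(count_operations(10**18)[0])   # 60, because 10**18 has 60 bits
```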

Example

Consider computing $5^{21} \pmod{p}$.

First, express the exponent $21$ in binary: $$ 21 = 16 + 4 + 1 = 1 \cdot 2^4 + 0 \cdot 2^3 + 1 \cdot 2^2 + 0 \cdot 2^1 + 1 \cdot 2^0. $$ The binary representation of $21$ is therefore $(10101)_2$.
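Python exposes these binary digits directly. This short sketch recovers the powers of two that make up $21$:

```python
k = 21
print(bin(k))  # 0b10101

# Keep 2**i for every bit position i where k has a 1 (least significant first).
powers_of_two = [1 << i for i in range(k.bit_length()) if (k >> i) & 1]
print(powers_of_two)       # [1, 4, 16]
print(sum(powers_of_two))  # 21
```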

Using this, $5^{21}$ can be decomposed as: $$ 5^{21} = 5^{16+4+1} = 5^{16} \cdot 5^4 \cdot 5^1. $$ Each power of the form $a^{2^i}$ is simply the square of the previous one, so these powers can be computed in sequence. Only the terms whose binary digit is 1 then need to be multiplied together. For $5^{21}$, this means four squarings ($5^2, 5^4, 5^8, 5^{16}$) plus two multiplications to combine $5^{16}$, $5^4$ and $5^1$: six multiplications instead of twenty.
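As a worked instance, take $p = 99$ (the modulus used in the program below). The successive squares and the final product work out as follows:

$$
\begin{aligned}
5^{1}  &\equiv 5 \pmod{99} \\
5^{2}  &\equiv 25 \pmod{99} \\
5^{4}  &\equiv 25^2 = 625 \equiv 31 \pmod{99} \\
5^{8}  &\equiv 31^2 = 961 \equiv 70 \pmod{99} \\
5^{16} &\equiv 70^2 = 4900 \equiv 49 \pmod{99} \\
5^{21} &= 5^{16} \cdot 5^{4} \cdot 5^{1} \equiv 49 \cdot 31 \cdot 5 = 7595 \equiv 71 \pmod{99}
\end{aligned}
$$

No intermediate value here exceeds $98^2$, even though $5^{21}$ itself has fifteen digits.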

Implementation

Binary exponentiation can be implemented by scanning the bits of the exponent either from the least significant bit to the most significant (right-to-left) or in the opposite direction (left-to-right). Both perform $O(\log k)$ multiplications. Here, we present the right-to-left approach, which is the simpler of the two.
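For comparison, here is a minimal sketch of the left-to-right variant. The function name `power_left_to_right` is hypothetical. It reads the bits of `exp` from the most significant end, squares the running result for every bit, and multiplies by the base when the bit is 1.

```python
def power_left_to_right(base: int, exp: int, mod: int) -> int:
    """Left-to-right binary exponentiation: scan exp from its most significant bit."""
    if exp < 0 or mod < 1:
        raise ValueError("expects exp >= 0 and mod >= 1")
    res = 1 % mod  # 1 % mod makes the mod == 1 case return 0
    base %= mod
    for bit in bin(exp)[2:]:          # bits from most to least significant
        res = (res * res) % mod       # doubles the exponent accumulated so far
        if bit == "1":
            res = (res * base) % mod  # adds 1 to the accumulated exponent
    return res


print(power_left_to_right(5, 21, 99))  # 71
```

One design difference is that this variant always multiplies by the same fixed base, which is convenient when that base is small or precomputed. The right-to-left version instead squares the base as it goes and does not need to know the exponent's bit length in advance.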

Python Program

```python
def power(base, exp, mod):
    """
    Efficiently calculates (base^exp) % mod using binary exponentiation.

    :param base: The base number.
    :param exp: The exponent (must be non-negative).
    :param mod: The modulus (must be positive).
    :return: The result of (base^exp) % mod.
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")
    if mod < 1:
        raise ValueError("mod must be positive")
    if mod == 1:
        return 0  # every integer is congruent to 0 modulo 1

    res = 1
    base %= mod  # reduce the base first so every product starts below mod
    while exp > 0:
        # If the least significant bit of exp is 1, multiply res by base
        if exp % 2 == 1:
            res = (res * base) % mod

        # Square the base and shift the exponent right by 1 bit
        base = (base * base) % mod
        exp //= 2

    return res

# --- Examples ---
# Compute 5^21 mod 99
k = 21
g = 5
p = 99
result = power(g, k, p)
print(f"{g}^{k} mod {p} = {result}")  # -> 5^21 mod 99 = 71

# Example with very large numbers
k = 12345678901234567890
g = 987654321987654321
p = 1000000007
result = power(g, k, p)
print(f"Result for large numbers: {result}")
```
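In everyday Python code, the built-in three-argument `pow(base, exp, mod)` already performs modular exponentiation on integers. That makes it a convenient reference for checking a hand-written version such as `power`. The sketch below simply repeats the examples with the built-in.

```python
# Python's built-in pow accepts an optional third argument: the modulus.
print(pow(5, 21, 99))  # 71

k = 12345678901234567890
g = 987654321987654321
p = 1000000007
print(pow(g, k, p))    # should match what power(g, k, p) prints
```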

How the Program Works

  1. res = 1 initializes the result to 1.
  2. base %= mod reduces the base modulo mod before the loop starts.
  3. while exp > 0 repeats the loop until the exponent exp reaches 0.
  4. if exp % 2 == 1 checks whether the least significant bit of exp is 1.
  5. res = (res * base) % mod multiplies the current base into res when that bit is 1.
  6. base = (base * base) % mod squares base, so on successive iterations it holds $a, a^2, a^4, a^8, \dots$, each reduced modulo mod.
  7. exp //= 2 integer-divides exp by 2, discarding the bit just processed so the next bit becomes the least significant one.
  8. When the loop ends, res holds (base^exp) % mod for the original arguments.
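To follow these steps on the running example, the sketch below repeats the same loop and records `exp`, `base` and `res` at the start of each iteration for $5^{21} \bmod 99$. The function name `power_trace` is hypothetical.

```python
def power_trace(base: int, exp: int, mod: int):
    """Same loop as power(), but also returns (exp, base, res) at the start of each iteration."""
    steps = []
    res = 1
    base %= mod
    while exp > 0:
        steps.append((exp, base, res))
        if exp % 2 == 1:
            res = (res * base) % mod
        base = (base * base) % mod
        exp //= 2
    return res, steps


result, steps = power_trace(5, 21, 99)
for exp, base, res in steps:
    print(f"exp={exp:>2} (bit {exp % 2})  base={base:>2}  res={res:>2}")
print("result:", result)  # result: 71
```

Reading the printed bits from top to bottom gives 1, 0, 1, 0, 1, which is $(10101)_2$ from least to most significant. `res` only changes on the rows where the bit is 1.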

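Because every product is reduced modulo mod immediately, no intermediate value ever exceeds $(\text{mod} - 1)^2$, so the numbers stay small no matter how large exp is. Python integers have arbitrary precision and never overflow, but small operands keep each multiplication fast. In languages with fixed-width integers, this reduction (combined with a type wide enough to hold $(\text{mod} - 1)^2$) is what prevents overflow.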
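The bound can be checked empirically. This sketch runs the same right-to-left loop on the large example and records the largest product formed before any reduction. The helper name `max_intermediate` is hypothetical.

```python
def max_intermediate(base: int, exp: int, mod: int) -> int:
    """Run right-to-left binary exponentiation, returning the largest unreduced product."""
    largest = 0
    res = 1
    base %= mod
    while exp > 0:
        if exp % 2 == 1:
            product = res * base
            largest = max(largest, product)
            res = product % mod
        product = base * base
        largest = max(largest, product)
        base = product % mod
        exp //= 2
    return largest


p = 1000000007
peak = max_intermediate(987654321987654321, 12345678901234567890, p)
print(peak.bit_length(), "bits at most")
print(peak <= (p - 1) ** 2)  # True
```

By contrast, the unreduced power $g^k$ for these inputs would need roughly $k \log_2 g$ bits, which is on the order of $10^{20}$ bits and far beyond any machine's memory.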