RSA算法原理

rsa作为常用的非堆成加密算法,以前并不知道原理,以及为什么是安全的。直到看了李永乐老师科普视频 。所以就动手试了下。


基本原理

  1. 找到两个质数p,q
  2. 算出这两个质数的乘积n,即n=p*q
  3. 算出欧拉函数f(x) = (p-1)(q-1) = m
  4. 这时我们找到一个数 e,1< e < m, 并且 e与m互质。e即为私钥
  5. 找到一个整数 d, 使e乘以d 除以m的余数为1。即(e * d) % m = 1。

代码实现

  • 质数可以通过筛选法获得。这里有解释。prime_up_to函数可以返回所有小于整数n的质数
def prime_up_to(n):
    primes = [True] * (n + 1)

    for i in range(2, int(math.ceil(math.sqrt(n)))):
        if primes[i]:
            for o in range(i*2, n + 1, i):
                primes[o] = False
    return [p for p in range(2, n + 1) if primes[p]]

  • 然后使用random.choice选取两个质数。
    prime_1 = random.choice(prime)  # p
    prime.remove(prime_1)
    prime_2 = random.choice(prime)  # q

  • 算出两个质数的乘积n
n = prime_1 * prime_2
  • 欧拉函数的值即为
euler_fx = (prime_1-1)(prime_2-1)     # m
  • 寻找公钥 e

公钥的主要条件是 e与欧拉函数的值必须互质,即最大公约数为1。我们可以通过辗转相除法(欧几里得算法)来判断两个数是否互质。定义判断互质的函数

def is_co_prime(m, n):
    while temp != 0:
        temp = n % m
        n = m
        m = temp
    if n == 1:
        return True
    else:
        return False

  • 寻找私钥 d

根据(ed) % m = 1 可得,ed = 1 - xm。即ed + xm = 1。因为e,m已知我们可以根据扩展欧几里得算法求出d,以及x。定义扩展欧几里得算法:

def ext_euclid(a, b):
    old_s, s = 1, 0
    old_t, t = 0, 1
    old_r, r = a, b
    if b == 0:
        return 1, 0, a
    else:
        while r != 0:
            q = old_r // r
            old_r, r = r, old_r - q * r
            old_s, s = s, old_s - q * s
            old_t, t = t, old_t - q * t
    return old_r, old_s, old_t      # 最小公约数  s为a的系数  t为b的系数

到此为止我们就能获取公钥以及私钥了。质数的乘积n也很重要,在接下来的加密以及解密中会用到。

def get_rsa_key(prime_1, prime_2):
    n = prime_1 * prime_2   # n
    euler_fx = (prime_1 - 1) * (prime_2 -1)     # m

    while True:
        e = random.randint(1, euler_fx)     # e 公钥
        if is_co_prime(e, euler_fx):
            break
    a = ext_euclid(e, euler_fx)
    print(prime_1, prime_2)
    print(a)

    d = a[1]    # d 私钥
    return e, d, n      # 公钥 私钥 质数乘积

加密解密过程

  • 加密 对于明文A, 取A的e次幂,并除以生成密钥时的两个质数的乘积n得到余数C,C即为加密后的密文。

  • 解密 对于密文C,取C的d次幂 除以 n 得到的余数 D,D即为加密之前的明文A。

加密解密的过程可以写为:

while True:
    temp = get_rsa_key(num_1, num_2)
    if temp[1] > 1:
        break
public_key = temp[0]
private_key = temp[1]
big_int = temp[2]

result = (num_2 ** public_key) % big_int

decode = (result ** private_key) % big_int

print(result)
print(decode)