我们平时接触的长乘法, 按位相乘, 是一种时间复杂度为 O(n ^ 2) 的算法. 今天, 我们来介绍一种, 时间复杂度为 O (n ^ log 3) 的大整数乘法 (log 表示以 2 为底的对数).
介绍原理
karatsuba 算法要求乘数与被乘数要满足以下几个条件, 第一, 乘数与被乘数的位数相同; 第二, 乘数与被乘数的位数应为 2 次幂, 即为 2 ^ 2, 2 ^ 3, 2 ^ 4, 2 ^ n 等数值.
下面我们先来看几个简单的例子, 并以此来了解 karatsuba 算法的使用方法.
两位数相乘
我们设被乘数 A = 85, 乘数 B = 41. 下面来看我们的操作步骤:
将 A, B 一分为二, 令 p = A 的前半部分 = 8,q = A 的后半部分 = 5 , r = B 的前半部分 = 4 ,s = B 的后半部分 = 1,n = 2. 通过简单的数学运算:
A * B = pq * rs = (p * 10 + q) * (r * 10 + s) = p * r * 10 ^ 2 + (p * s + q * r ) * 10 + q * s. 令 u = p * r,v =
(p - q) * (s - r),w = q * s. 所以 A * B = u * 10 ^ 2 + (u + v + w) * 10 + w.
换成数值求解的过程如下:
A * B = 85 * 41 = (8 * 10 + 5) * ( 4 * 10 + 1) = 8 * 4 * 10 * 10 + (8 * 1 + 5 * 4) * 10 + 5 * 1. 其中 u = 8 * 4 = 32,v = (8 - 5) (1 - 4) = -9,w = 5 * 1 = 5. 所以, A * B = 32 * 100 + (32 - 9 + 5) * 10 + 5 = 3485. 与长乘法所得结果一致.
四位数相乘
我们设被乘数 A = 8537, 乘数 B = 4123. 下面来看我们的操作步骤:
将 A, B 一分为二, 令 p = A 的前半部分 = 85,q = A 的后半部分 = 37 , r = B 的前半部分 = 41 ,s = B 的后半部分 = 23,n = 4.
==> 其中, u = 85 * 41, v = (85 - 37) * (23 - 41), w = 37 * 23.
==> A * B = 8537 * 4123 = u * 10 ^ 4 + (u + v + w) * 10 ^ 2 + w = 3485_0000 +34_7200 + 851 = 35198051.
在我们计算 u, v, w 的过程中又会涉及两位数的乘法, 我们继续使用 Karatsuba 算法得出两位数相乘的结果.
N 位数相乘
我们令 n 为 乘数与被乘数的位数, 令 p = A 的前半部分, q = A 的后半部分, r = B 的前半部分 ,s = B 的后半部分.
==> 其中, u = p * r,v = (p - q) * (s - r),w = q * s. 所以 A * B = u * 10 ^ n + (u + v + w) * 10 ^ (n / 2) + w.
而 u, v, w 则是两个 n / 2 位的乘法运算. 我们继续调用 Karatsuba 算法 j 计算 u, v, w 的数值. 接着, 我们在计算 n / 2 乘法的过程中又会遇到 n / 4 位的乘法运算...... 以此类推, 直到我们遇到两个个位数的乘法, 我们就直接返回这两个个位数乘法的结果. 层层返回, 最终得到 N 位数的乘法结果.
时间复杂度
我们平常使用的长乘法, 是 O (n ^ 2) 的时间复杂度. 比如两个 N 位数相乘, 我们需要将每一位按规则相乘, 所以需要计算 N * N 次乘法. 而使用 Karatsuba 算法每层需要计算三次乘法, 两次加法, 以及若干次加法, 每使用一次 karatsuba 算法, 乘法规模就下降一半. 所以, 对于两个 n = 2 ^ K 位数乘法运算, 我们需要计算 3 ^ k 次乘法运算. 而 K = log n(底数为 2), 3 ^ K = 3 ^ log n = 2 ^ (log 3 * log n) = 2 ^ (log n * log 3) = n ^ log 3 (底数为 2).
代码实现
- # 关注微信公众号: Python 高效编程
- from math import log2, ceil
- def pad(string: str, real_len: int, max_len: int) -> str:
- pad_len: int = max_len - real_len
- return f"{'0'* pad_len}{string}"
- def kara(n1: int, n2: int) -> int:
- if n1 <10 or n2 < 10:
- return n1 * n2
- n1_str: str = str(n1)
- n2_str: str = str(n2)
- n1_len: int = len(n1_str)
- n2_len: int = len(n2_str)
- real_len: int = max(n1_len, n2_len)
- max_len: int = 2 ** ceil(log2(real_len))
- mid_len: int = max_len>> 1
- n1_pad: str = pad(n1_str, n1_len, max_len)
- n2_pad: str = pad(n2_str, n2_len, max_len)
- p: int = int(n1_pad[:mid_len])
- q: int = int(n1_pad[mid_len:])
- r: int = int(n2_pad[:mid_len])
- s: int = int(n2_pad[mid_len:])
- u: int = kara(p, r)
- v: int = kara(q-p, r-s)
- w: int = kara(q, s)
- return u * 10 ** max_len + (u+v+w) * 10 ** mid_len + w
输出结果:
- ==> kara(123456, 9734) == 123456 * 9734
- ==> kara(1234233456756, 32459734) == 1234233456756 * 32459734
来源: http://www.jianshu.com/p/d6f54454a3ea