徒然

思いついたら書きます

Pythonで動くエラトステネスの篩をさらに高速化してみた話

はじめに

以前の記事でエラトステネスの篩を作成しました。
しかしこの篩ではLibrary CheckerのEnumerate PrimesでTLEしてしまいます。 そこで本記事では篩のさらなる改良を目指します。

まずエラトステネスの篩の仕組みをおさらいしておくと、調べる上限をNとして

というアルゴリズムです。

ここでよくある高速化の手段として、偶数は2以外合成数であることから、

  • N \ge 2のとき、2素数で確定する
  • N以下の\{3, 5, 7, ...\}について、素数と仮定する
  • \sqrt N以下のi = 3, 5, ...から小さい順に、i素数であるとき
  • \sqrt Nより大きく合成数ではなかった奇数素数

と書き換えることができ、探索するiが半分になって嬉しくなるよねというのがあります。
今回の高速化では最初から調べない数をさらに増やし、2, 3, 5の倍数を最初から落とします。これにより探索数が約8/30になることに加えて、残りの素数候補が
30i + j (j \in \{1, 7, 11, 13, 17, 19, 23, 29\}, (i, j) \neq (0, 1))
で表されることから、jの部分を8bitで管理できて嬉しくなります。

実装の際は下記記事を参照しました。

qiita.com

以下が、高速化されたエラトステネスの篩をPythonで実装したコードです。

def faster_eratosthenes(n: int):
    if n < 30:
        return [x for x in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] if x <= n]
    remains = [1, 7, 11, 13, 17, 19, 23, 29]
    inv_remains = {x: i for i, x in enumerate(remains)}
    msk = 255  # (1 << 8) - 1
    div30 = [i * j // 30 for j in remains for i in remains]
    mod30 = [inv_remains[i * j % 30] for j in remains for i in remains]
    shift = [1 << i for i in range(8)]
    msk8 = [msk - shift[i] for i in range(8)]
    inv_shift = {shift[i]: i for i in range(8)}
    res = [2, 3, 5]
    max_k = n // 30
    import bisect

    max_m = bisect.bisect_right(remains, n % 30) - 1
    sqrtn = int(n**0.5) + 1
    max_sqrt_k = sqrtn // 30
    max_sqrt_m = bisect.bisect_right(remains, sqrtn % 30) - 1
    table = bytearray([msk] * (max_k + 1))
    table[max_k] = (1 << (max_m + 1)) - 1
    table[0] -= 1  # remove 1
    for k in range(max_sqrt_k + 1):
        x = table[k]
        while x:
            u = x & (-x)
            if table[k] & u:
                m = inv_shift[u]
                res.append(k * 30 + remains[m])
                if k == max_sqrt_k and m > max_sqrt_m:
                    break
                # k_before = k
                m_before = m
                i = k * (30 * k + 2 * remains[m]) + div30[(m << 3) + m]
                j = mod30[(m << 3) + m]
                while i < max_k or (i == max_k and j <= max_m):
                    table[i] &= msk8[j]
                    if m_before == 7:
                        i += 2 * k + remains[m] + div30[m << 3] - div30[(m << 3) + 7]
                        j = mod30[m << 3]
                        # k_before += 1
                        m_before = 0
                    else:
                        i += (
                            k * (remains[m_before + 1] - remains[m_before])
                            + div30[(m << 3) + m_before + 1]
                            - div30[(m << 3) + m_before]
                        )
                        j = mod30[(m << 3) + m_before + 1]
                        m_before += 1
            x &= x - 1
    i = table[max_sqrt_k]
    i30 = 30 * max_sqrt_k
    while i:
        j = inv_shift[i & (-i)]
        if i30 + remains[j] > res[-1]:
            res.append(i30 + remains[j])
        i &= i - 1
    i30 += 30
    for i in table[max_sqrt_k + 1 :]:
        while i:
            j = inv_shift[i & (-i)]
            res.append(i30 + remains[j])
            i &= i - 1
        i30 += 30
    return res

高速化が実ってLibrary Checkerでも無事通るようになりました。

judge.yosupo.jp

おわりに

実装の中身の解説については途中で挫折したので参考記事に丸投げしてしまいました。
需要と元気があれば後日更新します。