216 Investigating the primality of numbers of the form 2*n^2-1

http://projecteuler.net/index.php?section=problems&id=216

ミラー・ラビンを使うのも、ひとつの方法でしょう。

しかし、それは問題の主旨に反するような気もする。

では、どうするか…

2*n^2-1が、ある素数pで割り切れるならば、

n^{2}¥equiv¥frac{p+1}{2}¥, (¥mathrm{mod}¥, p)

が成り立つことに注目。

この等式の解は mod p で考えれば、0個か2個。

mod p での √ が必要になるが、それは前にやった。

4n+1素数の平方和分解 – 落書き、時々落学

で、つまり、

n¥equiv¥sqrt{¥frac{p+1}{2}}¥,(¥mathrm{mod}¥, p)

なる、n に対しては, 2*n^2-1 は p で割り切れる。

あとは、これを利用して、ふるいにかければ良い。

{-# OPTIONS_GHC -fbang-patterns #-}
import Control.Monad (when)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newArray, readArray, writeArray)
import Mod (jacobi, sqrtMod)
primesUpTo :: Integer -> [Integer]
primesUpTo n = runST $ do
p <- newArray (1,div n 2) True :: ST s (STUArray s Integer Bool)
let u = floor.sqrt.fromIntegral $ div n 2
loop !i = when (i <= u) $ sieve i (2*i*(i+1)) >> loop (i+1)
sieve !i !j = when (j <= div n 2) $ writeArray p j False >> sieve i (j+2*i+1)
primes !i !ps | i > div n 2 = return.reverse $ ps
| otherwise   = do q <- readArray p i
if q then primes (i+1) (2*i+1:ps)
else primes (i+1) ps
loop 1
primes 1 [2]
p216 :: Integer -> Integer
p216 n = runST $ do
q <- newArray (2,n) True :: ST s (STUArray s Integer Bool)
let u = floor $ sqrt 2 * fromIntegral n
ps = tail $ primesUpTo u
loop [] = count 2 
loop !(p:ps) | jacobi (div (p+1) 2) p == -1 = loop ps
| otherwise = sieve p s >> sieve p (p-s) >> check p >> loop ps
where s = sqrtMod (div (p+1) 2) p
sieve !i !j = when (j<=n) $ writeArray q j False >> sieve i (j+i)
check !i = when (div (i+1) 2 == j*j) $ writeArray q j True
where j = floor.sqrt.fromIntegral $ div (i+1) 2
count !i !c | i > n     = return c
| otherwise =  do p <- readArray q i
if p then count (i+1) (c+1)
else count (i+1) c
loop ps
main :: IO ()
main = print $ p216 (5*10^7)

ミラー・ラビンよりは速いけど、やっぱり、まだ遅い。

追記

C++で実装したら、10倍くらい速くなった。

#include <iostream>
#include <cmath>
#define SIGN(p) ((p) ? 1 : -1)
#define N 50000000
#define U ( (int) ( sqrt( 2 ) * N ))
#define S sqrt( 0.5 *  U )
using namespace std;
// input : a, b
// output : x, y  s.t. ax + by = gcd(a, b)
int extGcd( int a, int b, int& x, int& y ) {
if ( b ==  ) {
x = 1; y = ; return a;
}
int g = extGcd( b, a % b, y, x );
y -= a / b * x;
return g;
}
// xn = 1 (mod p)
int invMod(int n, int p) {
int x, y;
return extGcd ( n, p, x, y ) == 1 ? ( x + p ) % p : ;
}
// jacobi symbol ( a / n )
int jacobi(int a, int n) {
if (a < ) return SIGN(n % 4 == 1) * jacobi(-a, n);
if (a < 2) return a;
if (a % 2 == ) return SIGN(n % 8 == 1 || n % 8 == 7) * jacobi(a / 2, n);
return SIGN(a % 4 == 1 || n % 4 == 1) * jacobi(n % a, a);
}
// a^n (mod p)
int powMod(int a, int n, int p) {
if (n == ) return 1;
if (n % 2 == ) return powMod( (long long) a * a % p, n / 2, p);
return (long long) a * powMod( a, n - 1, p ) % p;
}
// sqrt n (mod p)
int sqrtMod(int n, int p) {
int w = 2, s = , q = p - 1, m = invMod( n, p );
while ( jacobi ( w, p ) != -1 ) w++;
while ( q % 2 ==  ) q /= 2, s++;
int v = powMod( w, q, p );
int r = powMod( n, (q + 1) / 2, p );
do {
int i = , u = (long long) r * r * m % p;
while ( u % p != 1 ) u = (long long) u * u % p, i++;
if ( i ==  ) return r;
r = (long long) r * powMod (v, 1 << (s - i - 1), p) % p;
} while ( true );
}
int main() {
bool *p = new bool [ U / 2 + 1], *q = new bool [ N + 1 ];
for ( int i = 1; i <= U / 2; i++ ) p[i] = true;
for ( int i = 1; i <= N; i++ ) q[i] = true;
for ( int i = 1; i <= S; i++ ) if ( p[i] ) // sieve prime
for ( int j = 2 * i * ( i + 1 ); j <= U / 2; j += 2 * i + 1 ) p[j] = false;
for ( int i = 3; i <= U ; i += 2 ) // sieve 2*n^2-1 is prime
if ( p[i / 2] ) {
if ( jacobi ( i / 2 + 1, i ) != 1 ) continue;
int s = sqrtMod( i / 2 + 1, i );
for ( int j = s; j <= N ; j += i ) q[j] = false;
for ( int j = i - s; j <= N ; j += i ) q[j] = false;
int t = sqrt(  i / 2 + 1 );
if (  i / 2 + 1 == t * t ) q[t] = true;
}
int c = ;
for (int i = 2; i <= N; i++ ) if ( q[i] ) c++;
cout << c << endl;
}