/*******************************
christian wagner

This program implements the Pohlig-Hellman algorithm
for computing discrete logarithms.
Algorithm description:
http://alum.wpi.edu/~sorrodp/crypto/node5.html


It uses the miracl large integer package located at
http://indigo.ie/~mscott/

To compile, run:
g++ -I . pohlighellman.cpp miracl.a

Note: I added extra files to miracl.a to get it
to compile.

If you get a "mess up in shanks" message,
read the comment inside shanks().

*******************************/

#include <iostream>
#include <list>
#include <map>
#include <stdlib.h>
#include <string>
#include "big.h"
#include "crt.h"
#include "monty.h"

using namespace std;
// Global MIRACL system object. brute() reaches the underlying miracl
// state through &precision.
// NOTE(review): (5000,10) presumably means 5000 digits in base 10 --
// confirm against the MIRACL manual.
Miracl precision(5000,10);

// upper bound passed to gprime() when brute() builds its trial-division
// prime table (the limit on the brute force attack for factoring)
#define LIMIT 10000
// two-argument minimum used by brent(); evaluates the selected argument twice
#define min( a, b ) ((a) < (b)? (a) : (b))

// forward declarations; the definitions follow main()
Big ph( Big alpha, Big beta, Big p );
Big shanks( Big alpha, Big beta, Big p, Big m );
int findPrimes( Big p, Big **primes, int **exp );
Big brute( Big n );
Big brent( Big n );

int main( int argc, char *argv[] ) {
    if( argc != 4 ) {
        cerr << "usage: " << argv[ 0 ] << "<alpha> <beta> <p> < input\n";
        exit( -1 );
    }

    // get alpha beta and p
    Big alpha( argv[ 1 ] );
    Big beta( argv[ 2 ] );
    Big p( argv[ 3 ] );

    // find a
    Big a = ph( alpha, beta, p );

    // check if it is right, try shanks if it isn't
    // if this is a large p, and it fails, save yourself
    // the trouble and quit now
    if( pow( alpha, a, p ) != beta ) {
        cerr << "ph failed\nTrying shanks\n";
        a = shanks( alpha, beta, p, sqrt( p - 1 ) + 1 );
        if( pow( alpha, a, p ) != beta ) {
            cerr << "no good\n";
            exit( -1 );
        }
    }

    cout << "a = " << a << endl
         << "calculating values..\n";

    string str1, str2;

    // read in the two numbers, and spit out the
    // decrypted numbers
    while( cin >> str1 >> str2 ) {
        Big t1( (char *) str1.c_str( ) );
        Big t2( (char *) str2.c_str( ) );

        t1 = inverse( t1, p );
        t1 = pow( t1, a, p );

        cout << ( t1 * t2 ) % p << endl;
    }

    // now you have to take the decrypted numbers and convert them
    // to ascii

    return( 0 );
}

// Pohlig-Hellman discrete logarithm: returns a such that
// alpha^a == beta (mod p).  The method is only guaranteed to work when
// alpha is a primitive root mod p; main() checks the result.
//
// p-1 is factored as the product of q^e over its primes q. For each prime
// power, a mod q^e is recovered one base-q digit at a time: each digit is a
// small discrete log in a subgroup of size q, solved with shanks(). The
// residues a mod q^e are then combined with the Chinese remainder theorem.
Big ph( Big alpha, Big beta, Big p ) {
    Big *primes;
    int *exp;
    Big p1 = p - 1;

    // prime factors of p-1 and their exponents (new[]'d by findPrimes)
    int numPrimes = findPrimes( p1, &primes, &exp );

    Big *mods = new Big[ numPrimes ];   // q^e for each prime q
    Big *eva = new Big[ numPrimes ];    // a mod q^e for each prime q

    for( int i = 0; i < numPrimes; i++ ) {
        Big q = primes[ i ];
        // gamma = alpha^((p-1)/q); has order q when alpha is a primitive root
        Big gamma = pow( alpha, p1 / q, p );
        Big Z = beta;        // beta with the digits found so far removed
        Big qPrev = 1;       // q^(j-1)
        Big qCur = q;        // q^j

        eva[ i ] = 0;
        for( int j = 1; j <= exp[ i ]; j++ ) {
            // delta = Z^((p-1)/q^j); solve gamma^be == delta for digit j-1
            Big delta = pow( Z, p1 / qCur, p );
            Big be = shanks( gamma, delta, p, sqrt( q ) + 1 );

            // Z = Z * alpha^(-be*q^(j-1)) mod p. Reducing mod p keeps Z
            // from growing by another factor of p on every digit.
            Z = ( Z * inverse( pow( alpha, be * qPrev, p ), p ) ) % p;

            // accumulate the residue that the CRT will combine
            eva[ i ] += be * qPrev;

            qPrev = qCur;
            qCur *= q;
        }

        // after the loop qPrev == q^e: the modulus for this residue
        mods[ i ] = qPrev;
    }

    // combine the residues a mod q^e into a mod (p-1)
    Crt crt( numPrimes, mods );
    Big tmp = crt.eval( eva );
    delete [] primes;
    delete [] exp;
    delete [] mods;
    delete [] eva;

    return( tmp );
}

// (index, value) record for a baby-step/giant-step table:
// v1 holds an exponent index and v2 the corresponding power mod p.
// operator< orders records by v2 only, so a sorted container groups equal
// powers. It takes its argument by reference to avoid copying two Bigs on
// every comparison.
class Pair {
    public:
        Big v1, v2;
        bool operator<( const Pair &p2 ) const { return( v2 < p2.v2 ); }
};

// Shanks' baby-step giant-step discrete logarithm.
// Returns x = m*i + k, with 0 <= i, k < m, such that
// alpha^x == beta (mod p).  A solution is only guaranteed to exist in that
// range when m*m exceeds the order of alpha, which is why callers pass
// m = sqrt( bound ) + 1.
//
// The giant steps alpha^(m*i) mod p are stored in a map keyed by value, so
// each baby step beta * alpha^(-k) mod p costs one O(log m) lookup instead
// of a scan of the whole table. Both sequences are built by repeated
// multiplication rather than a fresh modular exponentiation per index.
//
// "mess up in shanks" is printed, and the program exits, when beta is not
// among alpha^0 .. alpha^(m*m - 1) mod p. For example, beta may not be a
// power of alpha at all (possible when alpha is not a primitive root mod p),
// or m may be too small.
Big shanks( Big alpha, Big beta, Big p, Big m ) {
    // giant steps: alpha^(m*i) mod p  ->  smallest i producing that value
    map< Big, Big > giant;
    Big alphaM = pow( alpha, m, p );
    Big value = 1;
    for( Big i = 0; i < m; ++i ) {
        // insert() keeps the existing (smaller) i if a value repeats
        giant.insert( map< Big, Big >::value_type( value, i ) );
        value = ( value * alphaM ) % p;
    }

    // baby steps: beta * alpha^(-k) mod p for k = 0, 1, ..., m-1
    Big alphaInv = inverse( alpha, p );
    Big target = beta % p;
    for( Big k = 0; k < m; ++k ) {
        map< Big, Big >::iterator hit = giant.find( target );
        if( hit != giant.end( ) ) {
            return( m * hit->second + k );
        }
        target = ( target * alphaInv ) % p;
    }

    cerr << "mess up in shanks\n";
    exit( -1 );
}

// Factors p (here always p-1 of the modulus) into distinct primes and their
// exponents. On return, *primes and *exp point to new[]'d arrays, sorted by
// ascending prime, whose length is the return value; the caller delete[]s
// them. Small factors are found by trial division (brute()) and whatever
// remains is split with Pollard-Brent (brent()).
int findPrimes( Big p, Big **primes, int **exp ) {
    // Degenerate case p == 1 (modulus 2): report the single "factor" 1^1.
    if( p == 1 ) {
        *primes = new Big[ 1 ];
        *exp = new int[ 1 ];
        (*primes)[ 0 ] = 1;   // index 0: these arrays have one element
        (*exp)[ 0 ] = 1;
        return( 1 );
    }

    Big n;
    list< Big > pri;

    // strip small prime factors by trial division while possible
    while( p != 1 && ( n = brute( p ) ) != -1 ) {
        p /= n;
        pri.push_back( n );
    }

    // Split what is left with brent(). It prints "composite factor" when it
    // could only return a composite piece, meaning another factoring
    // method is needed.
    while( p != 1 ) {
        n = brent( p );
        p /= n;
        pri.push_back( n );
    }

    // brent() returns whichever factor its gcd hits first, so equal primes
    // are not necessarily adjacent. Sort so each prime forms one run.
    pri.sort( );

    // collapse runs into (prime, exponent) pairs
    list< Big > uniq;
    list< int > counts;
    for( list< Big >::iterator it = pri.begin( ); it != pri.end( ); ++it ) {
        if( uniq.empty( ) || uniq.back( ) != *it ) {
            uniq.push_back( *it );
            counts.push_back( 0 );
        }
        ++counts.back( );
    }

    // copy into the caller's arrays
    int size = static_cast< int >( uniq.size( ) );
    *primes = new Big[ size ];
    *exp = new int[ size ];

    list< Big >::iterator pit = uniq.begin( );
    list< int >::iterator cit = counts.begin( );
    for( int i = 0; i < size; ++i, ++pit, ++cit ) {
        (*primes)[ i ] = *pit;
        (*exp)[ i ] = *cit;
    }

    return( size );
}

// Trial-division factor finder (adapted from the MIRACL package).
// Returns n itself when n is prime. Otherwise it returns the first entry of
// MIRACL's small-prime table (primes below LIMIT) that divides n, or -1 if
// none does. NOTE(review): the table is presumably ascending, which would
// make this the smallest such prime -- confirm.
Big brute( Big n ) {
    if( prime( n ) )
        return( n );   // a prime has no smaller factor to find

    // the prime table lives on the global MIRACL instance
    miracl *mip = &precision;
    gprime( LIMIT );

    // walk the table up to its terminating 0 entry
    int idx = 0;
    int candidate = mip->PRIMES[ idx ];
    do {
        if( n % candidate == 0 ) {
            gprime( 0 );   // NOTE(review): presumably releases the table
            return( candidate );
        }
        candidate = mip->PRIMES[ ++idx ];
    } while( candidate != 0 );

    return( -1 );
}

// Number of polynomials y^2 + c (c = 3, 4, ...) that brent() tries before
// giving up on splitting n.
static const int BRENT_TRIES = 20;

// One run of Brent's variant of Pollard's rho with f(y) = y^2 + c (mod n),
// starting from y = 0. The ZZn modulus must already be set to n.
// Returns a gcd greater than 1: either a proper factor of n, or n itself
// when this choice of c failed to separate the factors.
static Big brentRound( Big n, int c ) {
    const long m = 10L;   // steps batched into one gcd computation
    long r = 1L;          // current power-of-two cycle length
    long i, k;

    Big z = 0;
    ZZn x, y, q, ys;

    y = z;
    q = 1;

    do {
        x = y;
        for( i = 1L; i <= r; i++ )
            y = y * y + c;
        k = 0;
        do {
            ys = y;   // start of this batch, kept for backtracking
            for( i = 1; i <= min( m, r - k ); i++ ) {
                y = y * y + c;
                q = ( y - x ) * q;   // product of differences in the batch
            }
            z = gcd( q, n );
            k += m;
        } while( k < r && z == 1 );
        r *= 2;
    } while( z == 1 );

    // The batched product was 0 mod n. Replay the last batch one step at a
    // time to find the first difference that has a nontrivial gcd with n.
    if( z == n ) {
        do {
            ys = ys * ys + c;
            z = gcd( ys - x, n );
        } while( z == 1 );
    }
    return( z );
}

// Pollard-Brent factor finder (adapted from the MIRACL package).
// Returns n when n is prime. Otherwise it returns a nontrivial factor of n,
// which is not guaranteed to be prime (a warning is printed to cerr then).
// If every attempted polynomial fails, n itself is returned, also with the
// warning.
// Side effect: sets the global ZZn modulus to n via modulo().
Big brent( Big n ) {
    if( prime( n ) ) {
        return( n );
    }

    modulo( n );

    // A round that only finds gcd == n has failed for that c; retry with a
    // different constant instead of handing n back as a "factor".
    Big z = n;
    for( int c = 3; c < 3 + BRENT_TRIES && z == n; c++ ) {
        z = brentRound( n, c );
    }

    if( !prime( z ) ) {
        cerr << "\ncomposite factor " << z << "\n";
    }
    return( z );
}

