#include <iostream.h>
#include <Integer.h>
#include <stdlib.h>

#include "euclid.h"
// Returns the discrete logarithm of beta to the base alpha modulo p.
Integer shanks (Integer alpha, Integer beta, Integer p);
// Sorts columns left..right (inclusive) of the two-row table by the keys in list[1].
void quicksort (Integer *list[2], int left, int right);
// Exchanges columns i and j in both rows of the two-row table.
void swap (Integer *list[2], int i, int j);

//main reads alpha, beta and p from standard input and passes them to the shanks function.

main (){
    Integer alpha = 3, beta = 525, p = 809, answer;
    cout << "This program will compute the discrete log base alpha of ";
    cout <<"beta in Zp.\nInsert alpha:\n";
    cin >> alpha;
    cout << "Insert Beta:\n";
    cin >> beta;
    cout << "Insert p:\n";
    cin >> p;
    answer =  shanks (alpha, beta, p);
    Integer check = ModularExp(alpha, answer, p);
    cout << "Log base alpha (beta) = "<<answer <<'\n';  
//  cout <<"Check that "<<beta<<" = "<<check<<'\n';
}

//shanks executes Shanks' algorithm as outlined in the book.

Integer shanks (Integer alpha, Integer beta, Integer p){

    Integer a, m = sqrt (p-1) + 1;
    long j;
    Integer alphainv, common;
    Integer *L1[2], *L2[2];
    ExtendedEuclid (alpha, p, &common, &alphainv);
    //form a ^-1 mod p.
    if (common != 1) {
	cout << "(a, p) != 1\n";
	exit (1);
    }

    //declare arrays: m must fit in an int.

    L1[0] = new Integer[m.as_long() + 50];
    L2[0] = new Integer[m.as_long() + 50];
    L1[1] = new Integer[m.as_long() + 50];
    L2[1] = new Integer[m.as_long() + 50];
    //form elements of L1.
    for (j=0; j<=m-1; ++j){
	L1 [0][j] = j;
	L1 [1][j] = ModularExp (alpha, m*j, p);
    }
    //sort L1.
    quicksort (L1, 0,  (m-1).as_long());
    //form elements of L2
    for (j=0; j<=m-1; ++ j){
	L2 [0][j] = j;
	
	L2 [1][j] = (beta * ModularExp (alphainv, j, p) ) % p; 	 
    }
    //sort L2
    quicksort (L2,0,  (m-1).as_long());
    long i;
    for (j=0; j <= m-1; ++j)
      {
        cout << j <<' '<<L1[1][j] << ' ' << L2[1][j]<<'\n';
      }
    cout << '\n';

    i = 0; j = 0;
    // "Mergesearch" lists for a match.
    while (L1 [1][i] != L2[1][j]){
//	cout <<  L1[1][i] << "=?" << L2[0][j] ;
//	cout << ' ' << L2 [1][j] << " i="<<i<<" j="<<j<<'\n';
	if (L1 [1][i] > L2[1][j]) ++j;
	else  ++i;
	if (i > m || j >m ) exit (2);
    }
   cout << "m = "<<m<<" i = "<<L2[0][j]<<" j = "<<L1[0][i]<<'\n';
    //return corresponding value for a.
    return ( (m*L1[0][i]+L2[0][j]) % (p-1) );
}

//This quicksort was taken from K&R's "The C Programming Language"

// Sorts columns left..right (inclusive) of the two-row table `list` into
// ascending order of list[1], carrying the matching list[0] entries along.
// K&R-style partitioning around the middle element. The smaller partition is
// handled recursively and the larger one by looping, which bounds the stack
// depth at O(log n). The partitions are disjoint, so this yields the same
// result as recursing on both.
void quicksort (Integer *list[2],int left, int right){
    while (left < right) {
        // left + (right-left)/2 avoids the overflow (left+right) can hit.
        swap (list, left, left + (right - left) / 2);
        int last = left;
        for (int i = left + 1; i <= right; i++)
            if (list[1][i] < list[1][left])
                swap (list, ++last, i);
        swap (list, left, last);  // pivot goes to its final position
        if (last - left < right - last) {
            quicksort (list, left, last - 1);
            left = last + 1;
        } else {
            quicksort (list, last + 1, right);
            right = last - 1;
        }
    }
}

// Exchanges column i and column j of the two-row table `list`, so that a
// (list[0], list[1]) pair stays together when the table is reordered.
void swap (Integer *list[2], int i, int j){
    for (int row = 0; row < 2; ++row) {
        Integer held = list[row][i];
        list[row][i] = list[row][j];
        list[row][j] = held;
    }
}


