

/////////////////////////////////////////////////////////////////////////////
// File: rsa.C
// Contents: See below.
// Written originally by Myron Kennedy, Fall 1995,
// Modified in Spring 1997 by Ron Hannebohn
// Please read the rsa.README file for the list of revisions and other
// comments.

/*****************************************************************************
 *                         originally written by                             *
 *                         Myron Kennedy                                     *
 *                         modified by Ron Hannebohn and B. Stephens         *
 *                         for students in CMSC443 -                         *
 *                         Bignum RSA scheme                                 *
 *                         because of patent restrictions this program       * 
 *                         should only be used for 443 projects              *
 *****************************************************************************
 *                                                                           *
 *  This program implements interactive RSA tools.  It supports a            *
 *  menu structure that allows a user to generate probabilistic primes,      *
 *  compute N as the product of two probabilistic primes; factor a number;   *
 *  compute phi(N); randomly generate an encryption exponent; find an        *
 *  inverse in a mod system; display N, phi(N), the encryption exponent,     *
 *  and the decryption exponent; and perform exponentiation in a mod         *
 *  system.                                                                  *
 *                                                                           *
 *                                                                           *
 *  		To compile:  g++ -o rsa rsa.C -lg++                          *
 *              To run:      rsa                                             *
 *                                                                           *
 *****************************************************************************/

#include <Integer.h>
#include <stdlib.h>
#include <ctype.h>
#include <iostream.h>
#include <time.h>
#include <math.h>
#include "rsa.H"


main()
{
  
  char answer[MAX_ANSWER_LENGTH];
  Integer i, num, factor1 = 1, factor2 = 1, phi, m, e, d, exp, fastexp, 
          n1, n2, number = 0, a, n, jacobi, result, count, temp;
  Integer primes[750];
  int is_continue = 1, factored = 0, choice, invalid, current_prime_count = 0;
  int ii,j;
  int numtests;
  int test;

  n = 0;
  num = 0;
  phi = 0;
  e = 0, d = 0;
  numtests = 12;

  // Generate an initial list of primes between 1 and 5000.  This is for if 
  // the user immediately decides to find phi, or generate an encryption 
  // exponent, or some other operation requiring the prime list, without 
  // first generating the primes. 

  cout << "Please wait while initial list of primes is generated...." << endl;
  n = 3;
  for (i = n; i <= (n + 5000); i = i + 2)
  {
    a = (Integer) (random() % (i - 1));
    if (gcd(a, n) == 1)
    {
      if (current_prime_count >= MAX_SIZE)
	break;
      test = 0;
      for (ii=1;ii<=numtests;ii++)
	{
         jacobi = Jacobi(a, i);
         result = FastExp(a, (i - 1)/2, i);
         if (jacobi == result)
           test = 1;
         else
           test = 0;	   
         a = (Integer) (random() % (i - 1));
        }
      if (test == 1)
         primes[current_prime_count++] = i;
    }
  }

  while(is_continue)
  {
    PrintMenu();
    cout << "Please select one operation: ";
    cin >> answer;
    choice = atoi(answer);
    
    switch (choice)
    {
      case GENERATE_PRIMES:       // generate some probabilistic primes
        cout << endl << endl;
	cout << " Enter a number and probabilistic primes will be " << endl;
	cout << " generated in the interval (number < p < number + 5000)";
	cout << endl << " where p would be a prime:";
	cout << endl << endl << "      Enter number: ";
	n = GetPositiveInt("Enter number: ");
        cout << endl << " Probabilistic primes between " << n << " and ";
	cout << (n + 5000) << endl;
	current_prime_count = 0;
	srandom(time(NULL));
	count = 0;
	if (n % 2 == 0)    /* be sure to start with an odd number */
	  n = n - 1;                  
	if (n == 1)
	  n = 3;

	// To determine if the number is prime, for each candidate generate
	// one random number in the mod of the candidate, and if this random
	// number passes the Jacobi test numtest times, it is assumed to be
	//    prime. numtests now set to 8.
 
	for (i = n; i <= (n + 5000); i = i + 2)
	{
          test = 0;
	  a = (Integer) (random() % (i - 1));
	
	  
	
         if (current_prime_count >= MAX_SIZE)
	break;
       
      for (ii=1;ii<=numtests;ii++)
	{
         jacobi = Jacobi(a, i);
         result = FastExp(a, (i - 1)/2, i);
         if (jacobi == result)
           test = 1;	   
         else
           test = 0;
         a = (Integer) (random() % (i - 1));
        }
      if (test == 1)
        {
         if (current_prime_count >= MAX_SIZE)
		break;
	      primes[current_prime_count++] = i;
	      if (count % 8 == 0)
		cout << endl << "\t";
	      cout << " " << i;
	      count = count + 1;
	}
    }






	cout << endl;
	factored = 1;
	break;

      case DISPLAY_PRIMES:
	DisplayPrimeList(primes, current_prime_count);
	cout << endl;
	break;

      case COMPUTE_N_AS_PPROD:   // compose a number with only 2 prime factors.
	count = 1;
	while(1)
	{
	  cout << endl << " Enter the 1st prime number: ";
	  n1 = GetPositiveInt("Enter the 1st prime number: ");
	  cout << " Enter the 2nd prime number: ";
	  n2 = GetPositiveInt("Enter the 2nd prime number: ");
	  //	  if (n1 == n2 ||
	  //     !InPrimeList(n1, primes, current_prime_count) ||
	  //     !InPrimeList(n2, primes, current_prime_count))
	  //	  {
	  //   cout << "Numbers equal or not in the current primes list." << endl;
	  //	    count = count + 1;
	  //   if (count % 3 == 0)
	  //      DisplayPrimeList(primes, current_prime_count);
	  //    continue;
	  //	  }
	  //  else
	     break;
	}

	number = n1 * n2;
	cout << endl << " The number you computed is: " << endl;
	cout << "     " << number << endl;
	e = d = phi = 0;
	break;
      
      case TRIAL_DIV_FACTORING:       // perform trial division factoring
	cout << endl << "Enter a positive integer to see if it is prime: ";
	num = GetPositiveInt("Enter positive integer to see if it is prime: ");
	cout << " Trial division factors of " << num << ":" << endl << endl;
	Factor(num, factor1, factor2);
	factored = 1;  
	break;
          
      case POLLARD_FACTOR:        /* perform Pollard Rho factoring */
	cout << endl << "Enter a positive integer to see if it is prime: ";
	num = GetPositiveInt("Enter positive integer to see if its prime: ");
	cout << " Pollard's factors of " << num << ":" << endl;
	cout << endl;
	PollardRho(num, factor1, factor2);
	factor1 = factor2 = 1;
	factored = 1;
	break;
        
      case PHI:       /* compute phi(n) */
	phi = GetPhi(factor1, factor2, primes, current_prime_count);
	number = factor1 * factor2;
	factor1 = factor2 = 1;
	e = d = 0;
	break;

      case RANDOM_E:       /* generate encryption exponent */
	if (phi == 0)
	{
	  cout << endl << " Phi(N) has not been computed." << endl;
	  phi = GetPhi(factor1, factor2, primes, current_prime_count);
	  number = factor1 * factor2;
	  factor1 = factor2 = 1;
	}
	invalid = 1;
	srandom(time(NULL));
	while (invalid)
	{
	  e = (Integer)(random() % phi);
	  // Automatically calculate d by taking e's inverse. 
	  d = Inverse(phi, e);
	  if (gcd(e, phi) == 1 && e != d)
	  {
	    cout << endl;
	    cout << "\t Encryption exponent is: " << e << endl;
	    cout << "\t Decryption exponent (e's-inverse) is: " << d << endl;
	    cout << "\t Mod system phi(N) = " << phi << endl;
	    invalid = 0;
	  }
	}	  
	break;

      case MOD_INVERSE:      /* find Inverse */
	cout << endl << "Enter a positive number: ";
	num = GetPositiveInt("Enter positive number to invert: ");
	cout << "Enter a positive modulus: ";
	m = GetPositiveInt("Enter a positive modulus: ");	  
	temp = Inverse(m, num);
	if (temp == -1)
	{
	  cout << endl << "\t" << num << " has no inverse modulo ";
	  cout << m << endl;
	}
	else
	{
	  cout << endl << "\t" << num << "'s inverse modulo ";
	  cout << m << " = " << temp << endl;
	}

	cout << endl << endl;         
	break;
	
      case DISPLAY:       /* display data */
	cout << endl;
	if (number == 0)  
	  cout << "\t n was not computed." << endl;
	else 
	  cout << "\t n      = " << number << endl;
	if (phi == 0)
	  cout << "\t phi(n) was not computed." << endl;
	else
	  cout << "\t phi(n) = " << phi << endl;
	if (e == 0)
	  cout << "\t e was not computed." << endl;
	else
	  cout << "\t e      = " << e << endl;
	if (d == 0)
	  cout << "\t d was not computed." << endl;
	else
	  cout << "\t d      = " << d << endl;
	break;

      case EXPONENTIATE:       /* perform fast exponentiation */
	cout << endl << "Enter a positive number for the base: ";
	num = GetPositiveInt("Enter a positive number for the base: ");
	cout << "Enter a positive exponent: ";
	exp = GetPositiveInt("Enter a positive exponent: ");
	cout << "Enter a positive modulus: ";
	m = GetPositiveInt("Enter a positive modulus: ");
	fastexp = FastExp(num, exp, m);
	cout << endl << num;
	cout << "^" << exp << " mod " << m << " = " << fastexp << endl;   
	break;
	
      case EXIT:     /* exit */
	is_continue = 0;
	break;
      
      default:
	cout << endl << endl << "\t *** Your input was not accepted. ***";
	cout << endl;
	break;
    }
  }

  cout << endl << "\t *** Thank you for using MK's RSA program! ***" << endl;  
}


/*****************************  PrintMenu  ***********************************
 Writes the boxed list of the eleven menu choices to standard output,
 preceded by one blank line and followed by two.
 *****************************************************************************/
void PrintMenu()
{
  // One entry per output row, top border to bottom border.
  static const char *const menu_rows[] =
  {
    "\t *****************************************************",
    "\t *                                                   *",
    "\t *    (1)  Generate probabilistic primes             *",
    "\t *    (2)  Display current list of generated primes  *",
    "\t *    (3)  Compute N as product of 2 primes          *",
    "\t *    (4)  Perform trial division factoring          *",
    "\t *    (5)  Find prime factors of a number (Pollard)  *",
    "\t *    (6)  Compute phi(N)                            *",
    "\t *    (7)  Randomly choose encryption exponent       *",
    "\t *    (8)  Find an inverse in a mod system           *",
    "\t *    (9)  Display N, phi(N), e, and d               *",
    "\t *    (10) Perform exponentiation in a mod system    *",
    "\t *    (11) Exit                                      *",
    "\t *                                                   *",
    "\t *****************************************************"
  };
  const int row_count = sizeof(menu_rows) / sizeof(menu_rows[0]);

  cout << endl;
  for (int row = 0; row < row_count; row++)
    cout << menu_rows[row] << endl;
  cout << endl << endl;
}



 
/***************************** Jacobi **************************************
 Recursively evaluates the Jacobi symbol (rand_num / num) for the
 Solovay-Strassen primality test, without factoring either argument.

 Returns 1 when rand_num is 1 and 0 when rand_num and num share a factor
 greater than 1.  Otherwise one reduction step is applied and the sign of
 the recursive result is flipped when the step's exponent is odd:
   - even rand_num: halve it; the exponent is (num*num - 1)/8
   - odd rand_num:  swap to (num mod rand_num, rand_num); the exponent is
                    (rand_num - 1)*(num - 1)/4
 NOTE(review): the reduction rules presume num is odd and positive -- the
 function itself does not check this.
 ***************************************************************************/
Integer Jacobi (const Integer & rand_num, const Integer & num)
{
  if (rand_num == 1)
    return 1;
  if (gcd(rand_num, num) > 1)
    return 0;

  Integer exponent, next_top, next_bottom, symbol;

  if (rand_num % 2 == 0)
  {
    exponent = (num*num-1)/8;
    next_top = rand_num/2;
    next_bottom = num;
  }
  else
  {
    exponent = (rand_num - 1)*(num - 1)/4;
    next_top = num % rand_num;
    next_bottom = rand_num;
  }

  symbol = Jacobi(next_top, next_bottom);
  if (exponent % 2 == 0)
    return symbol;
  return -1*symbol;
}



/******************************  Factor  ************************************
 This function makes an attempt at brute force factoring.  It only tests
 odd numbers, assuming that the number that you are trying to factor is the
 product of 2 primes.
 ****************************************************************************/

void Factor (const Integer &number, Integer &f1, Integer &f2)
{
  Integer factor_count(atoIntRep("0"));
  int is_true = 1;

  for (Integer i = 2; i <= sqrt(number); i = i + 1)
  {
    while(is_true)
    {
      if (i % 2 == 0) 
	i = i + 1;
      else if (i % 3 == 0)
	i = i + 1;
      else
	is_true = 0;
    }

    is_true = 1;

    if (number % i == 0)  
    {
      f1 = i;       /* 'i' is a factor of 'number' */
      cout << "\t" << f1;
      f2 = (number/f1);
      cout << "\t" << f2 << endl;
      factor_count = 1;
      i = number;
    }

    if (i == MAX_TIME) 
    {
      cout << endl << " Taking too much CPU time......Aborting...." << endl;
      i = sqrt(number) + 2;
    }
    
  }

  if (factor_count == 0)
    cout << endl << "\t" << number << " is prime";
  cout << endl << endl;
}
 



/****************************  PollardRho  *********************************
 This function attempts to factor large integers using the PollardRho 
 factoring algorithm, assuming that the number to be factored is the 
 product of 2 primes.
 ***************************************************************************/ 

void PollardRho (const Integer &number, Integer &f1, Integer &f2)
{
  Integer t(atoIntRep("1"));
  Integer a = t;
  Integer b = t;  
  int count = 0; 
  Integer c;

  while (t == 1) 
  {
    a = ((a * a) + 1) % number;
    b = ((b * b) + 1) % number;
    b = ((b * b) + 1) % number;
    c = abs(a - b) % number;
    count++;
    if (count == MAX_TIME)
    {
      cout << endl << " Taking too much computer time...Aborting..." << endl;
      t = -1;
    }

    if (count < MAX_TIME) 
      t = gcd(c, number);

  }

  if (t != -1)
  {
    f1 = t;
    cout << "\t" << f1;
    f2 = (number/f1);
    cout << "\t" << f2 << endl;
  }
} 
   


/******************************  Inverse  *********************************
 Computes the inverse of num modulo m with the extended Euclidean
 algorithm.

 Parameters:
   m   -- the modulus.
   num -- the number to invert.
 Returns: the inverse reduced modulo m, or -1 when no inverse exists (the
 final gcd is not 1, or either argument is 0).
 **************************************************************************/
Integer Inverse (const Integer &m, const Integer &num)
{
  Integer n0, inv, b0, t, t0, q, r, temp;

  // Zero has no inverse, and a zero argument would otherwise be used as a
  // divisor below.
  if (m == 0 || num == 0)
    return -1;

  n0 = m;
  b0 = num;
  t0 = 0;
  t = 1;
  q = n0/b0;
  r = n0 - q * b0;

  while (r > 0)
  {
    // Keep the running coefficient non-negative while reducing it mod m.
    temp = t0 - q * t;
    if (temp >= 0)
      temp = temp % m;
    else
      temp = m - ((-temp) % m);
    t0 = t;
    t = temp;
    n0 = b0;
    b0 = r;
    q = n0/b0;
    r = n0 - q * b0;
  }

  // b0 now holds gcd(m, num); an inverse exists only when it is 1.
  if (b0 != 1)
    inv = -1;
  else
    inv = t % m;

  return inv;
}




/***************************** FastExp *************************************
 Modular exponentiation by repeated squaring: returns ciph raised to the
 power exp, reduced modulo n.

 While the remaining exponent is even it is halved and the base is squared;
 when it is odd it is decremented and the base is multiplied into the
 result.  The loop ends when the remaining exponent reaches 0.
 ***************************************************************************/
Integer FastExp (const Integer &ciph, const Integer &exp, const Integer &n)
{
  Integer base = ciph;
  Integer remaining = exp;
  Integer product = 1;

  while (remaining != 0)
  {
    if ((remaining % 2) == 0)
    {
      remaining = remaining/2;
      base = (base * base) % n;
    }
    else
    {
      remaining = remaining - 1;
      product = (product * base) % n;
    }
  }
  return product;
}


/////////////////////////////////////////////////////////////////////////////
//
// Function: ReadNumberToken
// Purpose: Reads one whitespace-delimited token from cin into buffer,
//   storing at most size - 1 characters.
// Parameters:
//   char *buffer -- destination array.
//   int size -- number of chars available in buffer.
// Returns: int -- 1 if the whole token fit, -1 if the token was longer than
//   the buffer (the excess characters are discarded), 0 if the input stream
//   has ended or failed.

static int ReadNumberToken(char *buffer, int size)
{
  int next;
  int truncated = 0;

  cin.width(size);              // bounds the extraction into the char array
  if (!(cin >> buffer))
    return 0;

  // Anything left before the next whitespace belongs to the same token.
  while (cin.good())
  {
    next = cin.peek();
    if (next < 0 || isspace(next))
      break;
    cin.get();
    truncated = 1;
  }
  return truncated ? -1 : 1;
}

/////////////////////////////////////////////////////////////////////////////
//
// Function: GetPositiveInt(RH)
// Purpose: Retrieves input until a positive integer is received.
// Parameters:
//   char *prompt -- text printed again before each retry.
// Returns: an Integer, the input.  If the input stream ends before a valid
//   number is read, a message is printed and the program exits.

Integer GetPositiveInt(char *prompt)
{
  char answer[MAX_ANSWER_LENGTH];
  Integer n;
  int status;

  for (;;)
  {
    status = ReadNumberToken(answer, MAX_ANSWER_LENGTH);
    if (status == 0)
    {
      // No more input can arrive, so asking again would loop forever.
      cout << endl << "\t Input ended before a number was entered." << endl;
      exit(1);
    }

    if (status < 0)
      cout << "\t Number is too long!" << endl;
    else
    {
      n = atoI(answer);
      if (n > 0)
        return n;
      cout << "\t Number is not positive!" << endl;
    }
    cout << endl << " " << prompt;
  }
}

//////////////////////////////////////////////////////////////////////////////
//
// Function(RH): InPrimeList
// Purpose: Linear search for a number in the current list of primes.
// Parameters:
//   Integer number -- the value to look for.
//   Integer *primes -- the current array of probabilistic primes.
//   int length -- how many entries of primes are in use.
// Returns: int -- 1 if some primes[k] with 0 <= k < length equals number,
//   0 otherwise.

int InPrimeList(Integer number, Integer *primes, int length)
{
  int hit = 0;
  int slot = 0;

  while (!hit && slot < length)
  {
    if (number == primes[slot])
      hit = 1;
    slot++;
  }
  return hit;
}



///////////////////////////////////////////////////////////////////////////////
//
// Function(RH): GetPhi
// Purpose: Prompts for two positive integers, stores them in factor1 and
//   factor2, and prints and returns (factor1 - 1) * (factor2 - 1).
//   The inputs are only required to be positive; they are not checked for
//   primality, for being distinct, or against the prime list.
// Parameters:
//   Integer &factor1 -- receives the first number entered.
//   Integer &factor2 -- receives the second number entered.
//   Integer *primes -- the current list of primes (not used here).
//   int current_prime_count -- length of that list (not used here).
// Returns: Integer, the value of phi.

Integer GetPhi(Integer &factor1, Integer &factor2, Integer *primes, 
	       int current_prime_count)
{
  cout << endl << " Enter 1st prime number to calculate phi: ";
  factor1 = GetPositiveInt("Enter the 1st prime number: ");
  cout << " Enter the 2nd prime number to calculate phi: ";
  factor2 = GetPositiveInt("Enter the 2nd prime number: ");

  const Integer first_term = factor1 - 1;
  const Integer second_term = factor2 - 1;
  Integer phi = first_term * second_term;

  cout << endl << "\t Phi(N) = " << first_term << " * ";
  cout << second_term << " = ";
  cout << phi << endl;
  return phi;
}



//////////////////////////////////////////////////////////////////////////////
//
// Function: DisplayPrimeList(RH)
// Purpose: Prints a heading followed by the first 'length' entries of the
//   prime list, eight per row, each row starting on a new tab-indented line.
// Parameters:
//   Integer *primes -- the array of primes.
//   int length -- the number of primes in the current list.
// Returns: This function does not return a value.
//

void DisplayPrimeList(Integer *primes, int length)
{
  const int per_row = 8;

  cout << "Displaying prime list..." << endl;

  for (int row_start = 0; row_start < length; row_start += per_row)
  {
    int row_end = row_start + per_row;
    if (row_end > length)
      row_end = length;

    cout << endl << "\t";
    for (int slot = row_start; slot < row_end; slot++)
      cout << " " << primes[slot];
  }
  cout << endl;
}






