/*
 *  rsa_prog.c
 *    RSA Encryption/decryption 
 *    
 *
 *    Michael A. Gurski  
 *
 *  Compile with:
 *     gcc -g -O -o rsa_prog rsa_prog.c -lgmp -lm
 *
 *  Note: compiled and tested under Red Hat Linux 7.3.
 */
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/* gmp.h must follow stdio.h so its FILE-based I/O prototypes are declared */
#include <gmp.h>

#define ENCRYPTION_EXPONENT	19	/* X.509, PKCS#1 */
#define ENCRYPTION_EXPONENT2	17
#define ENCRYPTION_EXPONENT3	3	/* PEM, PKCS#1 */
/* Bytes reserved after a user-supplied key name for its extension; must
 * hold the longest one appended, ".private" (8 characters). */
#define EXTENSION_LENGTH	8
#define LOG_2			((double) 0.301029995667)	/* log10(2) */
/* Longest filename accepted from the user. */
#define MAX_FILENAME		500
/* Longest starting number accepted by doProbPrimes(). */
#define MAX_PRIME_DIGITS	666
/* Width of the range searched by doProbPrimes(). */
#define MAX_PRIME_RANGE		5000
/* Repetition count passed to mpz_probab_prime_p(). */
#define MAX_PROB_REPS		30
/* Parenthesized so the expansion stays one operand in any expression. */
#define RANDMASK		(63 + 1)
#define WORD_BIT_SIZE		32

/* Forward declarations for the menu actions dispatched from main().
 * Every one takes (void) so it is a real prototype and argument mistakes
 * are diagnosed; keyGen()'s menu entry is currently commented out. */
void doProbPrimes(void);
void enCrypt(void);
void deCrypt(void);
void keyGen(void);


/* Start-up banner; main() prints it once before the first menu. */
char *banner =
    "RSA Encryption/Decryption Program\n"
    " 1997 Michael A. Gurski\n"
    " modified by B. Stephens\n"
    " for use in an educational environment\n";

/* Menu text shown by main() before every prompt for a choice. */
char *menuText =
    " (1)  Generate probabilistic primes\n"
    " (2)  Encrypt a message\n"
    " (3)  Decrypt a message\n"
    " (0)  Quit\n"
    "\n"
    "Choice: ";

/**
 * main()
 *
 *  Seeds the PRNG, prints the banner, then repeatedly shows the menu and
 *  dispatches the chosen action. Exits on choice 0 or when stdin reaches
 *  end-of-file / a read error.
 */
int
main (int argc, char **argv) {
    char tmp[10];
    char *end;
    long choice;
    int ch;

    (void) argc;
    (void) argv;

    /* seed the PRNG */
    srandom(time(NULL));

    printf("%s\n", banner);

    /* process input selections from menu */
    while (1) {
	printf("%s", menuText);
	fflush(stdout);

	/* EOF or read error: quit instead of re-parsing a stale (or never
	 * written) buffer in an endless loop. */
	if (fgets(tmp, sizeof tmp, stdin) == NULL) {
	    exit(0);
	}

	/* A line longer than tmp leaves its tail in stdin; discard it so it
	 * is not misread as the next menu choice. */
	if (strchr(tmp, '\n') == NULL) {
	    while ((ch = getchar()) != '\n' && ch != EOF)
		;
	}

	/* strtol (unlike atoi) reports whether any digits were read, so a
	 * blank or non-numeric line re-prompts instead of meaning "quit". */
	choice = strtol(tmp, &end, 10);
	if (end == tmp) {
	    printf("Please enter a number from the menu.\n\n");
	    continue;
	}

	switch (choice) {
	case 0:	/* end program */
	    exit(0);
	case 1:	/* probabilistic primes */
	    doProbPrimes();
	    break;
	case 2:	/* encrypt a file */
	    enCrypt();
	    break;
	case 3:	/* decrypt a file */
	    deCrypt();
	    break;
	/* Menu entry 4 (keyGen(): generate a pub/priv keypair) is disabled. */
	default:
	    printf("Unknown choice %ld.\n\n", choice);
	    break;
	}
    }
}

/**
 * doProbPrimes()
 *
 *  Reads a decimal starting number x (up to MAX_PRIME_DIGITS characters)
 *  from stdin and prints, one per line, every probable prime p in
 *   x <= p <= x + MAX_PRIME_RANGE
 *  as judged by mpz_probab_prime_p() with MAX_PROB_REPS repetitions.
 *  Overlong or non-numeric input prints a message and returns.
 */
void
doProbPrimes(void) {
    char num[MAX_PRIME_DIGITS+2];	/* digits + '\n' + NUL */
    char *nl;
    int ch;
    mpz_t userIn, upper, x;

    printf("Enter a starting number up to %d digits in length.\n",MAX_PRIME_DIGITS);
    printf("Probabilistic primes p in the range:\n");
    printf("  x <= p <= x + %d\n",MAX_PRIME_RANGE);
    printf("will be found.\n");

    printf("Number: ");
    fflush(stdout);

    if (fgets(num, sizeof num, stdin) == NULL) {
	return;
    }

    nl = strchr(num, '\n');
    if (nl != NULL) {
	*nl = '\0';
    }
    else if (!feof(stdin)) {
	/* Overlong line: drain the rest so it is not read as the next menu
	 * choice, and refuse to work on a silently truncated number. */
	while ((ch = getchar()) != '\n' && ch != EOF)
	    ;
	printf("Input is longer than %d characters.\n", MAX_PRIME_DIGITS);
	return;
    }

    /* mpz_set_str returns -1 when the string is not a valid number */
    mpz_init(userIn);
    if (mpz_set_str(userIn, num, 10) != 0) {
	printf("Not a decimal number: %s\n", num);
	mpz_clear(userIn);
	return;
    }

    mpz_init(upper);
    mpz_add_ui(upper, userIn, MAX_PRIME_RANGE);
    mpz_init(x);

    /* 2 is the only even prime; the odd-only scan below cannot reach it. */
    if (mpz_cmp_ui(userIn, 2) <= 0 && mpz_cmp_ui(upper, 2) >= 0) {
	printf("2\n");
    }

    /* first odd candidate >= max(x, 3) */
    if (mpz_cmp_ui(userIn, 3) < 0) {
	mpz_set_ui(x, 3);
    }
    else if (mpz_even_p(userIn)) {
	mpz_add_ui(x, userIn, 1);
    }
    else {
	mpz_set(x, userIn);
    }

    /* Test before stepping so the starting number itself is included;
     * evens above 2 are skipped since they're not prime. */
    for (; mpz_cmp(x, upper) <= 0; mpz_add_ui(x, x, 2)) {
	/* if gmp thinks it's a probabilistic prime, it works for us... */
	if (mpz_probab_prime_p(x, MAX_PROB_REPS)) {
	    /* mpz_out_str writes directly; mpz_get_str(NULL, ...) leaked */
	    mpz_out_str(stdout, 10, x);
	    putchar('\n');
	}
    }

    mpz_clear(userIn);
    mpz_clear(upper);
    mpz_clear(x);
}

/* Read one filename line for enCrypt() from stdin into buf (capacity
 * size) and strip the trailing newline. Returns 0 on success; -1 on
 * EOF/read error, an empty line, or a line too long for buf (the rest
 * of which is discarded). */
static int
enCryptReadName(char *buf, size_t size) {
    char *nl;
    int ch;

    if (fgets(buf, (int) size, stdin) == NULL)
	return -1;

    nl = strchr(buf, '\n');
    if (nl != NULL) {
	*nl = '\0';
    }
    else if (!feof(stdin)) {
	while ((ch = getchar()) != '\n' && ch != EOF)
	    ;
	return -1;
    }

    return (buf[0] == '\0') ? -1 : 0;
}

/**
 *  enCrypt()
 *
 *  Encrypt a specified text file with an RSA public key.
 *
 *  Prompts on stdin for the plaintext file, the ciphertext file to create,
 *  and the key file name WITHOUT its ".public" extension. The key file
 *  holds e then n as decimal integers.
 *
 *  The plaintext is read in blocks of blockSize = (digits(n) - 1) / 3
 *  characters. Characters outside 32..127 become a space, and each one is
 *  written as its decimal code (2 or 3 digits). The codes are concatenated
 *  into m, which therefore has fewer digits than n. Each block is written
 *  as c = m^e mod n in decimal, one per line.
 */
void
enCrypt(void) {
    char eFile[MAX_FILENAME+2];		/* name + '\n' + NUL */
    char outFile[MAX_FILENAME+2];
    char keyBase[MAX_FILENAME+2];
    /* sized for the longest accepted name plus ".public" and its NUL */
    char keyFile[MAX_FILENAME + sizeof ".public"];
    FILE *PK = NULL, *OF = NULL;
    mpz_t e, n, m, c;
    char *nDigits = NULL, *cBlock = NULL;
    size_t blockSize, used, i;
    int ch;

    /* initialize our integers */
    mpz_init(e);
    mpz_init(n);
    mpz_init(m);
    mpz_init(c);

    /* grab name of file to encrypt */
    printf("Enter filename to encrypt\n");
    if (enCryptReadName(eFile, sizeof eFile) != 0)
	goto bad_name;

    /* grab name of ciphertext */
    printf("Enter filename to store ciphertext in\n");
    if (enCryptReadName(outFile, sizeof outFile) != 0)
	goto bad_name;

    /* grab name of keyfile */
    printf("Enter filename for the public key to encrypt this file with\n");
    printf("WITHOUT the .public extension\n");
    if (enCryptReadName(keyBase, sizeof keyBase) != 0)
	goto bad_name;
    snprintf(keyFile, sizeof keyFile, "%s.public", keyBase);

    /* read in key: e then n */
    PK = fopen(keyFile, "r");
    if (PK == NULL) {
	perror(keyFile);
	goto cleanup;
    }
    if (mpz_inp_str(e, PK, 10) == 0 || mpz_inp_str(n, PK, 10) == 0) {
	fprintf(stderr, "%s: could not read e and n\n", keyFile);
	goto cleanup;
    }
    fclose(PK);
    PK = NULL;

    if (mpz_sgn(n) <= 0) {
	fprintf(stderr, "%s: modulus n must be positive\n", keyFile);
	goto cleanup;
    }

    /* The block size needs n's exact digit count (mpz_sizeinbase may be one
     * too big), so render n into a buffer this function owns and frees. */
    nDigits = malloc(mpz_sizeinbase(n, 10) + 2);
    if (nDigits == NULL) {
	fprintf(stderr, "out of memory\n");
	goto cleanup;
    }
    mpz_get_str(nDigits, 10, n);

    /* compute blocksize: up to 3 digits per char, fewer digits than n */
    blockSize = (strlen(nDigits) - 1) / 3;
    if (blockSize == 0) {
	/* a zero-length block would never consume input and loop forever */
	fprintf(stderr, "%s: modulus n is too small\n", keyFile);
	goto cleanup;
    }
    cBlock = calloc(3 * blockSize + 1, sizeof(char));
    if (cBlock == NULL) {
	fprintf(stderr, "out of memory\n");
	goto cleanup;
    }

    PK = fopen(eFile, "r");
    if (PK == NULL) {
	perror(eFile);
	goto cleanup;
    }
    OF = fopen(outFile, "w+");
    if (OF == NULL) {
	perror(outFile);
	goto cleanup;
    }

    /* loop thru plaintext, reading in blocks */
    for (;;) {
	/* read up to blockSize chars, appending each decimal code */
	used = 0;
	for (i = 0; i < blockSize && (ch = fgetc(PK)) != EOF; i++) {
	    if (ch < 32 || ch > 127)
		ch = ' ';
	    used += (size_t) sprintf(cBlock + used, "%d", ch);
	}

	/* Nothing read: done. A feof()-driven loop would run once more on an
	 * empty block and emit a spurious ciphertext line. */
	if (i == 0)
	    break;

	/* put string into integer type */
	mpz_set_str(m, cBlock, 10);

	/* c = (m^e) % n ... RSA encryption */
	mpz_powm(c, m, e, n);

	/* write out encrypted block */
	mpz_out_str(OF, 10, c);
	fprintf(OF, "\n");
    }

    if (ferror(PK))
	perror(eFile);
    goto cleanup;

bad_name:
    printf("Missing or overlong filename (limit %d characters).\n", MAX_FILENAME);

cleanup:
    if (PK != NULL)
	fclose(PK);
    if (OF != NULL && fclose(OF) != 0)
	perror(outFile);
    mpz_clear(e);
    mpz_clear(n);
    mpz_clear(m);
    mpz_clear(c);
    free(cBlock);
    free(nDigits);
}

/* Read one filename line for deCrypt() from stdin into buf (capacity
 * size) and strip the trailing newline. Returns 0 on success; -1 on
 * EOF/read error, an empty line, or a line too long for buf (the rest
 * of which is discarded). */
static int
deCryptReadName(char *buf, size_t size) {
    char *nl;
    int ch;

    if (fgets(buf, (int) size, stdin) == NULL)
	return -1;

    nl = strchr(buf, '\n');
    if (nl != NULL) {
	*nl = '\0';
    }
    else if (!feof(stdin)) {
	while ((ch = getchar()) != '\n' && ch != EOF)
	    ;
	return -1;
    }

    return (buf[0] == '\0') ? -1 : 0;
}

/**
 *  deCrypt()
 *
 *  Decrypt a specified text file with an RSA private key.
 *
 *  Prompts on stdin for the ciphertext file, the plaintext file to create,
 *  and the key file name WITHOUT its ".private" extension. The key file
 *  holds d then n as decimal integers.
 *
 *  Each decimal integer c in the ciphertext is decrypted to m = c^d mod n.
 *  The digits of m are then split back into character codes: a leading
 *  '1' starts a 3-digit code and any other digit a 2-digit code, matching
 *  enCrypt(). The characters are written, followed by a newline per block.
 */
void
deCrypt(void) {
    char decryptFile[MAX_FILENAME+2];	/* name + '\n' + NUL */
    char outFile[MAX_FILENAME+2];
    char keyBase[MAX_FILENAME+2];
    /* sized for the longest accepted name plus ".private" and its NUL */
    char keyFile[MAX_FILENAME + sizeof ".private"];
    FILE *PK = NULL, *OF = NULL;
    mpz_t d, n, m, c;
    char *cBlock = NULL;
    size_t len, i, k, width;
    int code;

    /* init ints */
    mpz_init(d);
    mpz_init(n);
    mpz_init(m);
    mpz_init(c);

    /* grab name of file to decrypt */
    printf("Enter filename to decrypt\n");
    if (deCryptReadName(decryptFile, sizeof decryptFile) != 0)
	goto bad_name;

    /* grab name of plaintext file */
    printf("Enter filename to store plaintext in\n");
    if (deCryptReadName(outFile, sizeof outFile) != 0)
	goto bad_name;

    /* grab name of keyfile */
    printf("Enter filename for the private key to decrypt this file with\n");
    printf("WITHOUT the .private extension\n");
    if (deCryptReadName(keyBase, sizeof keyBase) != 0)
	goto bad_name;
    snprintf(keyFile, sizeof keyFile, "%s.private", keyBase);

    /* read in private key: d then n */
    PK = fopen(keyFile, "r");
    if (PK == NULL) {
	perror(keyFile);
	goto cleanup;
    }
    if (mpz_inp_str(d, PK, 10) == 0 || mpz_inp_str(n, PK, 10) == 0) {
	fprintf(stderr, "%s: could not read d and n\n", keyFile);
	goto cleanup;
    }
    fclose(PK);
    PK = NULL;

    if (mpz_sgn(n) <= 0) {
	/* mpz_powm cannot reduce modulo zero */
	fprintf(stderr, "%s: modulus n must be positive\n", keyFile);
	goto cleanup;
    }

    /* m < n, so n's digit bound (+2 for sign/NUL, per mpz_get_str) fits m */
    cBlock = calloc(mpz_sizeinbase(n, 10) + 2, sizeof(char));
    if (cBlock == NULL) {
	fprintf(stderr, "out of memory\n");
	goto cleanup;
    }

    PK = fopen(decryptFile, "r");
    if (PK == NULL) {
	perror(decryptFile);
	goto cleanup;
    }
    OF = fopen(outFile, "w+");
    if (OF == NULL) {
	perror(outFile);
	goto cleanup;
    }

    /* loop through encrypted blocks, decrypting them */
    while (mpz_inp_str(c, PK, 10) != 0) {

	/* m = (c^d) % n ... RSA decryption */
	mpz_powm(m, c, d, n);
	mpz_get_str(cBlock, 10, m);
	len = strlen(cBlock);

	/* Separate out the individual character codes. A short final group
	 * (only possible for malformed input) is decoded from the digits
	 * that remain instead of reading past the end of the string. */
	for (i = 0; i < len; i += width) {
	    width = (cBlock[i] == '1') ? 3 : 2;
	    if (width > len - i)
		width = len - i;
	    code = 0;
	    for (k = 0; k < width; k++)
		code = code * 10 + (cBlock[i + k] - '0');
	    fputc(code, OF);
	}

	fputc('\n', OF);
    }

    if (ferror(PK))
	perror(decryptFile);
    goto cleanup;

bad_name:
    printf("Missing or overlong filename (limit %d characters).\n", MAX_FILENAME);

cleanup:
    if (PK != NULL)
	fclose(PK);
    if (OF != NULL && fclose(OF) != 0)
	perror(outFile);
    mpz_clear(d);
    mpz_clear(n);
    mpz_clear(m);
    mpz_clear(c);
    free(cBlock);
}

/**
 *  keyGen()
 *
 *  Generate a public/private RSA keypair
 */



