/*
 *  rsa_prog.c
 *    RSA Encryption/decryption program for:
 *    CMSC 443 Spring 1997, project 9
 *
 *    Michael A. Gurski  
 *
 *  Compile with:
 *     gcc -g -O -o rsa_prog rsa_prog.c -lgmp -lm
 *
 */

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <gmp.h>

/* Candidate public exponents; only ENCRYPTION_EXPONENT is used for keys. */
#define ENCRYPTION_EXPONENT	65537	/* X.509, PKCS#1 */
#define ENCRYPTION_EXPONENT2	17
#define ENCRYPTION_EXPONENT3	3	/* PEM, PKCS#1 */
/* extra room in key-file name buffers for ".public" / ".private" + NUL */
#define EXTENSION_LENGTH	9
/* log10(2): decimal digits contributed by one bit */
#define LOG_2			((double) 0.301029995667)
/* longest filename (in characters) accepted from the user */
#define MAX_FILENAME		500
/* largest number of decimal digits accepted for a prime / start value */
#define MAX_PRIME_DIGITS	666
/* width of the interval scanned for probable primes */
#define MAX_PRIME_RANGE		5000
/* repetition count handed to mpz_probab_prime_p() */
#define MAX_PROB_REPS		30
/* 64; parenthesised so the macro is a single operand in any expression
 * (unparenthesised, "2 * RANDMASK" would expand to 2 * 63 + 1)
 */
#define RANDMASK		(63 + 1)
/* assumed number of bits in one machine word */
#define WORD_BIT_SIZE		32

/* Menu actions.  All take no arguments; "(void)" (rather than an empty
 * parameter list) makes the compiler check that at every call site.
 */
void doProbPrimes(void);
void enCrypt(void);
void deCrypt(void);
void keyGen(void);


/* Start-up banner.  Points at a string literal, hence const. */
const char *banner =
"RSA Encryption/Decryption Program\n"
"Copyright (c) 1997 Michael A. Gurski\n"
"                   CMSC 443 0101 Spring 1997\n"
"  This program probably would violate the RSADSI patent on using RSA\n"
"if it were used for something other than educational purposes.  Grab\n"
"RSAREF or a decent RSA implementation (if outside the draconian USA).\n"
"\n";

/* Main menu and prompt.  Points at a string literal, hence const. */
const char *menuText =
" (1)  Generate probabilistic primes\n"
" (2)  Encrypt a message\n"
" (3)  Decrypt a message\n"
" (4)  Generate public/private keypair\n"
" (0)  Quit\n"
"\n"
"Choice: ";

/**
 *  main()
 *
 *  Seeds the C library PRNG from the clock, prints the banner, then
 *  repeatedly shows the menu and runs the selected action until the
 *  user picks 0 or standard input ends.  Returns 0.
 *
 *  The choice is read as a whole line and parsed with strtol(), so no
 *  stray newline is left behind for the fgets() calls in the actions
 *  and unreadable or non-numeric input cannot spin the loop.
 */
int
main (int argc, char **argv) {
    char line[64], *end;
    long choice;
    int ch;

    (void) argc;
    (void) argv;

    /* seed the PRNG */
    srandom((unsigned int) time(NULL));

    printf("%s\n",banner);

    /* process input selections from menu */
    for(;;) {
	printf("%s",menuText);
	fflush(stdout);

	if(fgets(line,sizeof line,stdin) == NULL)
	    break;			/* EOF or read error: quit */

	/* overlong line: throw the rest away so it is not taken as
	 * the answer to the next prompt
	 */
	if(strchr(line,'\n') == NULL) {
	    while((ch = getchar()) != '\n' && ch != EOF)
		;
	}

	choice = strtol(line,&end,10);
	if(end == line)
	    continue;			/* not a number: ask again */

	switch(choice) {
	case 0:	/* end program */
	    return 0;
	case 1:	/* probabilistic primes */
	    doProbPrimes();
	    break;
	case 2:	/* encrypt a file */
	    enCrypt();
	    break;
	case 3:	/* decrypt a file */
	    deCrypt();
	    break;
	case 4:	/* generate pub/priv keypair */
	    keyGen();
	    break;
	default:	/* unknown selection: show the menu again */
	    break;
	}
    }

    return 0;
}

/**
 * doProbPrimes()
 *
 *  Reads a decimal starting number x (up to MAX_PRIME_DIGITS digits)
 *  from stdin and prints, one per line, every probabilistic prime p in
 *   x <= p <= x + MAX_PRIME_RANGE
 *  (both ends included).  Candidates are judged by
 *  mpz_probab_prime_p() with MAX_PROB_REPS repetitions.
 *
 *  If no line can be read, or the line is not a valid base-10 integer,
 *  the function returns without scanning.
 */
void
doProbPrimes(void) {
    /* digits + the newline fgets() keeps + terminating NUL */
    char num[MAX_PRIME_DIGITS+2];
    mpz_t userIn,x,upper;

    printf("Enter a starting number up to %d digits in length.\n",MAX_PRIME_DIGITS);
    printf("Probabilistic primes p in the range:\n");
    printf("  x <= p <= x + %d\n",MAX_PRIME_RANGE);
    printf("will be found.\n");
    
    printf("Number: ");
    fflush(stdout);

    if(fgets(num,sizeof num,stdin) == NULL)
	return;

    /* mpz_set_str ignores white space, so the trailing newline is
     * harmless; the integer is initialised even when parsing fails
     */
    if(mpz_init_set_str(userIn,num,10) != 0) {
	printf("That is not a valid decimal number.\n");
	mpz_clear(userIn);
	return;
    }

    mpz_init(x);
    mpz_init(upper);
    mpz_add_ui(upper,userIn,MAX_PRIME_RANGE);

    if(mpz_cmp_ui(userIn,2) <= 0) {
	/* 2 is the only even prime, so the odd-only scan below would
	 * miss it; nothing smaller than 2 is prime, so resume at 3
	 */
	if(mpz_cmp_ui(upper,2) >= 0)
	    printf("2\n");
	mpz_set_ui(x,3);
    }
    else if(mpz_even_p(userIn)) {
	/* if it's even, add 1... */
	mpz_add_ui(x,userIn,1);
    }
    else {
	mpz_set(x,userIn);
    }

    /* test the starting candidate itself, then step through the odd
     * numbers up to and including the upper bound
     */
    for(;mpz_cmp(x,upper) <= 0;mpz_add_ui(x,x,2)) {

	/* if gmp thinks it's a probabilistic prime, it works for us... */
	if(mpz_probab_prime_p(x,MAX_PROB_REPS)) {
	    /* print straight from the integer: no temporary string */
	    mpz_out_str(stdout,10,x);
	    putchar('\n');
	}
    }

    mpz_clear(userIn);
    mpz_clear(x);
    mpz_clear(upper);
}

/**
 *  enCrypt()
 *
 *  Encrypt a specified text file with an RSA public key.
 *
 *  Prompts on stdin for the plaintext file, the ciphertext file and the
 *  key name.  The key is read from "<name>.public" as two base-10
 *  integers, e then n.
 *
 *  The plaintext is cut into blocks of blockSize characters, where
 *  blockSize = (number of decimal digits of n - 1) / 3, which keeps
 *  every block's integer below n.  Characters outside 32..127 are
 *  replaced by a space.  The decimal codes of a block's characters
 *  (two or three digits each) are concatenated into one integer m, and
 *  c = m^e mod n is written in decimal, one block per line.
 *
 *  Unreadable input, files that cannot be opened, a malformed key or a
 *  modulus too small for a single character are reported on stderr and
 *  abandon the operation.
 */
void
enCrypt(void) {
    char keyFile[MAX_FILENAME+EXTENSION_LENGTH+1], encryptFile[MAX_FILENAME+1];
    char outFile[MAX_FILENAME+1];
    FILE *keyFp = NULL, *inFp = NULL, *outFp = NULL;
    mpz_t e,n,m,c;
    size_t bufSize, blockSize, count, used;
    char *cBlock = NULL;
    int ch;

    /* initialize our integers */
    mpz_init(e);
    mpz_init(n);
    mpz_init(m);
    mpz_init(c);

    /* grab name of file to encrypt */
    printf("Enter filename to encrypt\n");
    if(fgets(encryptFile,sizeof encryptFile,stdin) == NULL)
	goto done;
    encryptFile[strcspn(encryptFile,"\n")] = '\0';

    /* grab name of ciphertext */
    printf("Enter filename to store ciphertext in\n");
    if(fgets(outFile,sizeof outFile,stdin) == NULL)
	goto done;
    outFile[strcspn(outFile,"\n")] = '\0';

    /* grab name of keyfile; read at most MAX_FILENAME characters so
     * the extension still fits in the buffer
     */
    printf("Enter filename for the public key to encrypt this file with\n");
    printf("WITHOUT the .public extension\n");
    if(fgets(keyFile,MAX_FILENAME+1,stdin) == NULL)
	goto done;
    keyFile[strcspn(keyFile,"\n")] = '\0';

    strcat(keyFile,".public");

    /* read in key */
    keyFp = fopen(keyFile,"r");
    if(keyFp == NULL) {
	perror(keyFile);
	goto done;
    }
    if(mpz_inp_str(e,keyFp,10) == 0 || mpz_inp_str(n,keyFp,10) == 0
       || mpz_sgn(e) <= 0 || mpz_sgn(n) <= 0) {
	fprintf(stderr,"%s: not a valid public key file\n",keyFile);
	goto done;
    }
    fclose(keyFp);
    keyFp = NULL;

    /* One buffer serves two purposes: first it receives the decimal
     * form of n so its exact digit count can be taken (mpz_sizeinbase
     * may overestimate by one), then it holds each block's digits,
     * which need at most 3 * blockSize + 1 <= digits(n) bytes.
     */
    bufSize = mpz_sizeinbase(n,10) + 2;
    cBlock = calloc(bufSize,sizeof(char));
    if(cBlock == NULL) {
	perror("calloc");
	goto done;
    }

    /* compute blocksize */
    mpz_get_str(cBlock,10,n);
    blockSize = (strlen(cBlock) - 1) / 3;
    if(blockSize == 0) {
	fprintf(stderr,"%s: modulus is too small to encrypt with\n",keyFile);
	goto done;
    }

    inFp = fopen(encryptFile,"r");
    if(inFp == NULL) {
	perror(encryptFile);
	goto done;
    }
    outFp = fopen(outFile,"w");
    if(outFp == NULL) {
	perror(outFile);
	goto done;
    }

    /* loop thru plaintext, reading in blocks */
    do {
	count = 0;
	used = 0;

	/* read in enough chars for block, appending each one's
	 * decimal representation as we go
	 */
	while(count < blockSize && (ch = fgetc(inFp)) != EOF) {
	    if((ch < 32) || (ch > 127))
		ch = ' ';

	    used += (size_t) snprintf(cBlock+used,bufSize-used,"%d",ch);
	    count++;
	}

	/* nothing left to read: do not emit an empty block */
	if(count == 0)
	    break;

	/* put string into integer type */
	mpz_set_str(m,cBlock,10);

	/* c = (m^e) % n ... RSA encryption */
	mpz_powm(c,m,e,n);

	/* write out encrypted block */
	mpz_out_str(outFp,10,c);
	fputc('\n',outFp);

	/* a short block means end of file (or a read error) was hit */
    } while(count == blockSize);

done:
    if(keyFp != NULL)
	fclose(keyFp);
    if(inFp != NULL)
	fclose(inFp);
    /* fclose flushes, so this is where a failed write shows up */
    if(outFp != NULL && fclose(outFp) != 0)
	perror(outFile);

    mpz_clear(e);
    mpz_clear(n);
    mpz_clear(m);
    mpz_clear(c);
    free(cBlock);
}

/**
 *  deCrypt()
 *
 *  Decrypt a specified text file with an RSA private key.
 *
 *  Prompts on stdin for the ciphertext file, the plaintext file and the
 *  key name.  The key is read from "<name>.private" as two base-10
 *  integers, d then n.
 *
 *  Each base-10 integer c read from the ciphertext file is decrypted as
 *  m = c^d mod n.  The decimal digits of m are split into character
 *  codes: a group starting with '1' is three digits long, any other
 *  group two.  Each code is written as one character, and a newline is
 *  written after every block.
 *
 *  Unreadable input, files that cannot be opened or a malformed key are
 *  reported on stderr and abandon the operation.  A block whose digits
 *  end in the middle of a group (wrong key or damaged ciphertext)
 *  produces a warning and the incomplete group is dropped.
 */
void
deCrypt(void) {
    char keyFile[MAX_FILENAME+EXTENSION_LENGTH+1], decryptFile[MAX_FILENAME+1];
    char outFile[MAX_FILENAME+1];
    FILE *keyFp = NULL, *inFp = NULL, *outFp = NULL;
    mpz_t d,n,m,c;
    size_t len, width, i, k;
    char *cBlock = NULL;
    int code;

    /* init ints */
    mpz_init(d);
    mpz_init(n);
    mpz_init(m);
    mpz_init(c);
    
    /* grab name of file to decrypt */
    printf("Enter filename to decrypt\n");
    if(fgets(decryptFile,sizeof decryptFile,stdin) == NULL)
	goto done;
    decryptFile[strcspn(decryptFile,"\n")] = '\0';
    
    /* grab name of plaintext file */
    printf("Enter filename to store plaintext in\n");
    if(fgets(outFile,sizeof outFile,stdin) == NULL)
	goto done;
    outFile[strcspn(outFile,"\n")] = '\0';

    /* grab name of keyfile; read at most MAX_FILENAME characters so
     * the extension still fits in the buffer
     */
    printf("Enter filename for the private key to decrypt this file with\n");
    printf("WITHOUT the .private extension\n");
    if(fgets(keyFile,MAX_FILENAME+1,stdin) == NULL)
	goto done;
    keyFile[strcspn(keyFile,"\n")] = '\0';
    
    strcat(keyFile,".private");

    /* read in private key */
    keyFp = fopen(keyFile,"r");
    if(keyFp == NULL) {
	perror(keyFile);
	goto done;
    }
    if(mpz_inp_str(d,keyFp,10) == 0 || mpz_inp_str(n,keyFp,10) == 0
       || mpz_sgn(d) <= 0 || mpz_sgn(n) <= 0) {
	fprintf(stderr,"%s: not a valid private key file\n",keyFile);
	goto done;
    }
    fclose(keyFp);
    keyFp = NULL;

    /* every decrypted m is below n, so a buffer sized for n's decimal
     * form (plus sign and NUL, as mpz_get_str requires) always fits
     */
    cBlock = calloc(mpz_sizeinbase(n,10)+2,sizeof(char));
    if(cBlock == NULL) {
	perror("calloc");
	goto done;
    }
    
    inFp = fopen(decryptFile,"r");
    if(inFp == NULL) {
	perror(decryptFile);
	goto done;
    }
    outFp = fopen(outFile,"w");
    if(outFp == NULL) {
	perror(outFile);
	goto done;
    }

    /* loop through encrypted blocks, decrypting them; mpz_inp_str
     * returns 0 at end of file or on unparsable input
     */
    while(mpz_inp_str(c,inFp,10) != 0) {

	/* m = (c^d) % n ... RSA decryption */
	mpz_powm(m,c,d,n);
	
	mpz_get_str(cBlock,10,m);
	len = strlen(cBlock);

	/* loop through decrypted block, separating out individual
	 * characters
	 */
	for(i=0;i<len;i+=width) {
	    /* codes 100..127 take three digits, 32..99 take two */
	    width = (cBlock[i] == '1') ? 3 : 2;

	    if(len - i < width) {
		fprintf(stderr,"warning: malformed block "
			"(wrong key or damaged ciphertext?)\n");
		break;
	    }

	    code = 0;
	    for(k=0;k<width;k++)
		code = code * 10 + (cBlock[i+k] - '0');

	    fputc(code,outFp);
	}
	
	fputc('\n',outFp);
    }

done:
    if(keyFp != NULL)
	fclose(keyFp);
    if(inFp != NULL)
	fclose(inFp);
    /* fclose flushes, so this is where a failed write shows up */
    if(outFp != NULL && fclose(outFp) != 0)
	perror(outFile);
    
    mpz_clear(d);
    mpz_clear(n);
    mpz_clear(m);
    mpz_clear(c);
    free(cBlock);
}

/**
 *  keyGen()
 *
 *  Generate a public/private RSA keypair
 */
void
keyGen(void) {
    int userDigits,rLimbs,i,r;
    mpz_t randInt1,randInt2,p,q,n,phi_n,e,d,p_1,q_1,foo;
    char fname[MAX_FILENAME+1], pubFname[MAX_FILENAME+EXTENSION_LENGTH+1];
    char privFname[MAX_FILENAME+EXTENSION_LENGTH+1];
    FILE *PUBF, *PRIVF;

    /* get an incredibly rough approximation of the number of digits to
     * use for the primes....  This is coarse because mp uses the concept
     * of "limbs" when generating random numbers....basically machine words,
     * so the # of digits only determines how many limbs...  The primes are
     * always AT LEAST the number of digits the user specifies.
     */
    do {
	printf("Enter the (ROUGH) approximate number of digits you want\n");
	printf("your RSA primes to be (min 12, max %d).\n\n",MAX_PRIME_DIGITS);
	
	printf("Size: ");
	fflush(stdout);
	
	scanf("%d",&userDigits);
	fflush(stdin);
	
    } while((userDigits < 12) || (userDigits > MAX_PRIME_DIGITS));

    /* figure out how many limbs will hold the specified number of digits */
    rLimbs = ((((double) userDigits) / LOG_2) / ((double) WORD_BIT_SIZE)) + 1;

#if 0
    printf("limbs: %d %f\n",rLimbs,((double) userDigits) / LOG_2);
#endif

    mpz_init(randInt1);
    mpz_init(randInt2);
    mpz_init(p);
    mpz_init(q);
    mpz_init(p_1);
    mpz_init(q_1);
    mpz_init(n);
    mpz_init(phi_n);
    mpz_init(e);
    mpz_init(d);
    mpz_init(foo);

    /* find a random starting point for p...since there's no mpz_srandom,
     * find a random number of random numbers first
     */
    do {
	r=random() & RANDMASK;
	for(i=0;i<r;i++)
	    mpz_random(randInt1,rLimbs);
    } while(strlen(mpz_get_str(NULL,10,randInt1)) < userDigits);

    mpz_set(p,randInt1);

    /* keep looking for probabilistic primes */
    while(!mpz_probab_prime_p(p,MAX_PROB_REPS)) {
	mpz_add_ui(p,p,1);
    }

    printf("p = %s\n",mpz_get_str(NULL,10,p));

    /* find a random starting point for q...since there's no mpz_srandom,
     * find a random number of random numbers first
     */
    do {
	r=random() & RANDMASK;
	for(i=0;i<r;i++)
	    mpz_random(randInt2,rLimbs);
    } while(strlen(mpz_get_str(NULL,10,randInt2)) < userDigits);

    mpz_set(q,randInt2);

    /* keep looking for probabilistic primes */
    while(!mpz_probab_prime_p(q,MAX_PROB_REPS)) {
	mpz_add_ui(q,q,1);
    }

    printf("q = %s\n",mpz_get_str(NULL,10,q));

    /* find n */
    mpz_mul(n,p,q);
    printf("n = %s\n",mpz_get_str(NULL,10,n));

    /* find phi(n) */
    mpz_sub_ui(p_1,p,1);
    mpz_sub_ui(q_1,q,1);
    mpz_mul(phi_n,p_1,q_1);
    printf("phi(n) = %s\n",mpz_get_str(NULL,10,phi_n));

    /* According to Applied Cryptography 2nd ed, p 469--470:
     *  "RSA encryption goes much faster if you're smart about
     *   choosing a value of e.  The three most common choices
     *   are 3, 17, and 65537 (2^{16}+1). ... X.509 recommends
     *   65537 [304], PEM recommends 3 [76], and PKCS #1 ...
     *   recommends 3 or 65537 [1345].  There are no security
     *   problems with using any of these three values for e
     *   (assuming you pad messages with random values ...),
     *   even if a whole group of users uses the same value for
     *   e."
     *
     * In the interest of speed, e is just going to be 65537. It's
     * a shame no one recommends 17...it fits in nicely with 5 & 23.
     */
    mpz_set_ui(e,ENCRYPTION_EXPONENT);
    printf("e = %s\n",mpz_get_str(NULL,10,e));

    /* find d */
    mpz_invert(d,e,phi_n);
    printf("d = %s\n",mpz_get_str(NULL,10,d));

    /* just for the hell of it, show that it works... */
    mpz_mul(foo,e,d);
    mpz_mod(foo,foo,phi_n);
    printf("e * d %% n = %s\n",mpz_get_str(NULL,10,foo));

    /* get filename for the pub/priv keys */
    printf("Enter filename to write out to (e.g. \"foo\" would give foo.public, foo.private)\n");
    printf("or just hit return to not save.\n");
    fgets(fname,MAX_FILENAME+1,stdin);

    fname[strlen(fname)-1] = NULL;

    if(fname[0] != NULL) {
	strcpy(pubFname,fname);
	strcat(pubFname,".public");
	
	strcpy(privFname,fname);
	strcat(privFname,".private");

	/* write out public key */
	PUBF=fopen(pubFname,"w+t");
	mpz_out_str(PUBF,10,e);
	fprintf(PUBF,"\n");
	mpz_out_str(PUBF,10,n);
	fprintf(PUBF,"\n");
	fclose(PUBF);

	/* write out private key */
	PRIVF=fopen(privFname,"w+t");
	mpz_out_str(PRIVF,10,d);
	fprintf(PRIVF,"\n");
	mpz_out_str(PRIVF,10,n);
	fprintf(PRIVF,"\n");
	fclose(PRIVF);
    }

    mpz_clear(randInt1);
    mpz_clear(randInt2);
    mpz_clear(p);
    mpz_clear(q);
    mpz_clear(p_1);
    mpz_clear(q_1);
    mpz_clear(n);
    mpz_clear(phi_n);
    mpz_clear(e);
    mpz_clear(d);

}


