/*
linear_algebra.c (ANSI C program)
Josh Arenberg CMSC443 project 1
Creation date unknown
Last modified March 29, 2000


linear_algebra.c is the implementation file for linear_algebra
*/

#include "linear_algebra.h"



/* function Menu takes no values and returns no value

   Repeatedly prints the operation menu, reads a numeric selection from
   standard input and acts on it:
     1, 2 or 3 : calls MatrixOp() with the selection, then shows the menu again
     9         : prints "exiting..." and ends the program with status 0
   Any other number, or input scanf cannot read as an integer, prints the
   usage message and ends the program with status 1.  Every way out of the
   loop is an exit() call, so control never returns to the caller. */
void Menu (void) {

  int e=0,menu_item=0;

  while(1) {

    printf("Make a selection by number:\n");
    printf("[1]: Multiply two nxn matrices\n");     /* menu choices */
    printf("[2]: Invert an nxn matrix\n");
    printf("[3]: Solve a matrix equation\n");
    printf("[9]: Exit program\n");
    printf("Selection: ");

    e=scanf("%d", &menu_item);
    if(e!=1) {                        /* not a number (or end of input) */
      printf("USAGE: Selection must be a number in the list\n");
      exit(1);
    }

    if (menu_item==9) {
      printf("exiting...\n");
      exit(0);
    }

    if (menu_item<1 || menu_item>3) { /* a number, but not one on the menu */
      printf("USAGE: Selection must be a number in the list\n");
      exit(1);
    }

    MatrixOp(menu_item);
  }
}



/* helper: reads one integer from standard input.  Prints the usage message
   and ends the program with status 1 if the input is not an integer. */
static int ReadInteger(void) {

  int v;

  if(scanf("%d",&v)!=1) {
    printf("USAGE: integers only\n");
    exit(1);
  }
  return v;
}

/* helper: prints prompt and reads an integer, repeating (after printing
   complaint) until the value is at least 1. */
static int ReadPositive(const char *prompt,const char *complaint) {

  int v;

  while(1) {
    printf("%s",prompt);
    v=ReadInteger();
    if(v>=1)
      return v;
    printf("%s",complaint);
  }
}

/* helper: fills a[0..rows-1][0..cols-1] from standard input, row by row.
   Every value must lie in 0..r-1; otherwise range_msg is printed and the
   program ends with status 1. */
static void ReadEntries(int **a,int rows,int cols,int r,const char *range_msg) {

  int i,j,k;

  for(i=0;i<rows;i++) {
    for(j=0;j<cols;j++) {
      k=ReadInteger();
      if(k<0 || k>=r) {
	printf("%s",range_msg);
	exit(1);
      }
      a[i][j]=k;
    }
  }
}

/* helper: releases an n-row matrix obtained from Mem(); NULL is ignored. */
static void FreeMatrix(int **a,int n) {

  int i;

  if(a==NULL)
    return;
  for(i=0;i<n;i++)
    free(a[i]);
  free(a);
}

/* function MatrixOp() takes an int and returns no value

   s selects the operation: 1 = multiply two matrices, 2 = invert a matrix,
   3 = solve A x = b.  Any other value reads the size and modulus and then
   does nothing further.

   The function asks for the matrix size n and the modulus r (both must be
   at least 1), builds the table of inverses mod r with Modulo(), reads the
   operands from standard input and hands the answer to Results().  For
   operations 2 and 3 nothing is printed by Results() when the matrix is
   reported singular (marked by -1 in element [0][0] of the returned matrix).

   Non-integer input, or a matrix element outside 0..r-1, ends the program
   with status 1.  All memory allocated here, and the result matrix, is
   released before returning, since Menu() calls this function repeatedly. */
void MatrixOp(int s) {

  int n,r;
  int **a,**b,**c=NULL,*m;

  /*  get values for n,r */
  n=ReadPositive("Input matrix size\n",
		 "USAGE: matrix size must be positive integer.  Please try again.\n");
  r=ReadPositive("Input modulus\n",
		 "USAGE: modulus must be positive integer.  Please try again.\n");

  m=calloc(r,sizeof *m);                 /* table of inverses, one slot per residue */
  a=Mem(n);
  b=Mem(n);
  if(m==NULL || a==NULL || b==NULL) {
    printf("Error: out of memory\n");
    exit(1);
  }
  m=Modulo(m,r);

  if (s==1){                             /*  matrix multiplication */
    printf("Input the %dx%d elements of matrix 1 (one row per line)\n", n, n);
    ReadEntries(a,n,n,r,"USAGE: Outside of mod system\n");

    printf("Now input the %dx%d elements of matrix 2 (one row per line)\n", n, n);
    ReadEntries(b,n,n,r,"USAGE: Outside of mod system.\n");

    c=Multiply(a,b,r,n);
    Results(n,c,1);
  }
  else if (s==2){                        /*  matrix inverse */
    printf("Input the %dx%d matrix elements (one row per line)\n", n, n);
    ReadEntries(a,n,n,r,"USAGE: Outside of mod system.\n");

    c=Invert(a,m,r,n);
    if(c!=NULL && c[0][0]!=-1)           /* -1 marks a singular matrix */
      Results(n,c,2);
  }
  else if (s==3){                        /*   solve linear system */
    printf("Input the %dx%d elements of matrix A (one row per line)\n", n, n);
    ReadEntries(a,n,n,r,"USAGE: Outside of mod system.\n");

    /* b is an n x n matrix; only its first column is filled, the rest stays
       zero from Mem() */
    printf("Input the %d elements of vector b (on one line):\n", n);
    ReadEntries(b,n,1,r,"USAGE: Outside of mod system.\n");

    c=Solve(a,b,m,r,n);
    if(c!=NULL && c[0][0]!=-1)           /* valid residues are never negative */
      Results(n,c,3);
  }

  FreeMatrix(c,n);
  FreeMatrix(b,n);
  FreeMatrix(a,n);
  free(m);
}




/* function Mem takes an int and returns a pointer to an int array

   Allocates an n x n matrix of ints with every element set to zero.  The
   matrix is an array of n row pointers, each row being its own block of n
   ints, so an element is addressed as c[row][col].

   Returns NULL when n is not positive.  If memory cannot be obtained an
   error message is printed and the program ends with status 1, in line
   with how the rest of this file reports fatal errors.

   The caller releases the matrix by freeing each of the n rows and then the
   row-pointer array itself. */
int** Mem(int n) { 

  int i;
  int **c;

  if (n<=0)
    return NULL;

  c=calloc(n,sizeof *c);           /*  n row pointers (not n ints) */
  if (c==NULL) {
    printf("Error: out of memory\n");
    exit(1);
  }
  for (i=0;i<n;i++) {
    c[i]=calloc(n,sizeof *c[i]);   /*  one zero-filled row of n ints */
    if (c[i]==NULL) {
      printf("Error: out of memory\n");
      exit(1);
    }
  }
  return c;
}

/* function Modulo takes an int array and an int and returns a pointer to an int

   Fills m, which must have room for at least r ints, with the table of
   multiplicative inverses mod r: for every x in 1..r-1 that has an inverse,
   m[x] is the value y in 1..r-1 with (x*y)%r == 1.  m[0], and m[x] for every
   x without an inverse, is 0.  Any previous contents of m[0..r-1] are
   overwritten.  Returns m. */
int* Modulo(int m[],int r) {

  int i,j;

  for (i=0;i<r;i++) {      /*  clear the whole table first: 0 means "no inverse" */
    m[i]=0;
  }
  for (i=1;i<r;i++) {     /* use array values corresponding to integers in mod system */
    for (j=1;j<r;j++) {
      if ((i*j)%r==1) {   /* i and j are inverses of each other */
	m[i]=j;
	m[j]=i;
	break;
      }
    }
  }
  return m;
} 


/* function Multiply takes two pointers to int arrays and two ints and returns
   a pointer to an int array

   Forms the product of the n x n matrices a and b with arithmetic mod r.
   Each term a[row][t]*b[t][col] is reduced mod r before it is added to the
   running sum, and the sum is reduced mod r once more before it is stored.
   The result is placed in a new matrix obtained from Mem(); a and b are
   only read. */
int** Multiply(int **a,int **b,int r,int n) {

  int row,col,t,sum;
  int *left,*out;
  int **prod;

  prod=Mem(n);                          /*  fresh matrix for the product */

  for(row=0;row<n;row++){
    left=a[row];
    out=prod[row];
    for(col=0;col<n;col++){
      sum=0;                            /*  dot product of row of a, column of b */
      for(t=0;t<n;t++){
	sum+=(left[t]*b[t][col])%r;
      }
      out[col]=sum%r;
    }
  }

  return prod;
}


/* function Invert takes two pointers to integer arrays and two ints and returns 
   a pointer to an integer array

   Computes the inverse of the n x n matrix a with arithmetic mod r.

   a : the matrix to invert.  It is overwritten by the row operations below.
       Its entries are used as indexes into m, so they must lie in 0..r-1
       (MatrixOp checks this before calling).
   m : table indexed by residue; m[x]==0 is treated as "x has no inverse",
       any other value is used as the inverse of x mod r.  In this file the
       table is built by Modulo().
   r : the modulus.
   n : the matrix dimension.

   Returns a matrix z obtained from Mem().  z starts as the identity and
   receives every row operation applied to a.  If the matrix is found to be
   singular a message is printed, z[0][0] is set to -1 and z is returned at
   once; callers test z[0][0] for -1 to detect this. */
int** Invert(int **a,int m[],int r,int n) {

  int i,j,k,q,cnt,min,minj;
  int **z;

  z=Mem(n);                           /*  allocate space for new matrix */

  for(i=0;i<n;i++) {                  /*  create identity matrix */
    for(j=0;j<n;j++) {
      if (i==j)
	z[i][j]=1;
      else
	z[i][j]=0;
    }
  }

  /* Pass 1: for each column i, clear the entries below the diagonal using
     only subtraction of integer multiples of one row from another. */
  for(i=0;i<n;i++) {
    do {
      cnt=0;                           /*  initialize temporary parameters */
      min=r;
      minj=-1;

      for(j=i;j<n;j++) {                   /* among rows i..n-1, find the smallest nonzero entry in column i */
	if((a[j][i]<min)&&(a[j][i]!=0)) {
	  min=a[j][i];
	  minj=j;
	}
      }
      if(minj==-1){                           /* column i is all zero from row i down: no pivot */
	z[0][0]=-1;
	printf("This matrix is singular.\n\n");
	return z;
      }
      for(j=i;j<n;j++){                       /*  reduce every other row i..n-1 by the pivot row */
	if(j!=minj){
	  q=(a[j][i])/min;                    /* integer quotient of the column-i entries */
	  for(k=0;k<n;k++){
	    a[j][k]=(a[j][k]-q*a[minj][k])%r;
	    z[j][k]=(z[j][k]-q*z[minj][k])%r;
	    if(a[j][k]<0)                     /* % can give a negative result; bring it back into 0..r-1 */
	      a[j][k]=a[j][k]+r;
	    if(z[j][k]<0)
	      z[j][k]=z[j][k]+r;
	  } 
	  if(a[j][i]==0)                      /* count rows whose column-i entry is now zero */
	    cnt=cnt+1;
	}
      }
    } while(cnt<(n-i-1));                     /* repeat until all n-i-1 other rows are zero in column i */

    if(minj!=i){                 /* swap the pivot row into position i (in both a and z) */
      for(k=0;k<n;k++){
	q=a[minj][k];
	a[minj][k]=a[i][k];
	a[i][k]=q;  
	q=z[minj][k];
	z[minj][k]=z[i][k];
	z[i][k]=q;
      }
    }
  }  

  /* Pass 2: scale each row so its diagonal entry becomes 1, then clear the
     entries above that diagonal entry. */
  for(i=0;i<n;i++){
    if (m[a[i][i]]==0){                       /* diagonal entry has no inverse in the table */
      printf("This matrix is singular!\n\n");
      z[0][0]=-1;
      return z;  
    }
    else{                       /* multiply row i of a and z by the inverse of the diagonal entry */
      q=m[a[i][i]]; 
      for(j=0;j<n;j++){ 
	a[i][j]=(a[i][j]*q%r);
	z[i][j]=(z[i][j]*q%r); 
      }
    }
    for(j=0;j<i;j++){             /* subtract a[j][i] times row i from each row above it */
      q=a[j][i];
      for(k=0;k<n;k++){
	a[j][k]=(a[j][k]-q*a[i][k])%r; 
	if (a[j][k]<0)
	  a[j][k]=a[j][k]+r;
	z[j][k]=(z[j][k]-q*z[i][k])%r;
	if (z[j][k]<0)
	  z[j][k]=z[j][k]+r; 
      }
    }
  }

  return z;
}
   
/* function Solve takes three pointers to int arrays and two ints and returns a 
   pointer to an int array

   Solves A x = b with arithmetic mod r by inverting a and multiplying the
   inverse by b.

   a : the n x n coefficient matrix; Invert() overwrites it.
   b : an n x n matrix whose first column holds the right-hand side (the
       remaining columns take part in the product but are not needed for
       the answer).
   m : table of inverses mod r, passed through to Invert().
   r : the modulus.      n : the matrix dimension.

   Returns a newly allocated n x n matrix whose first column is the solution.
   If Invert() reports a singular matrix (it has already printed a message),
   no product is formed and the matrix returned has -1 in element [0][0], the
   same marker Invert() uses, so callers can tell that there is no solution. */
int** Solve(int **a,int **b,int m[],int r,int n) {

  int i;
  int **inv,**x;

  inv=Invert(a,m,r,n);     /*  try to invert matrix A */
  if(inv[0][0]==-1)        /*  singular: hand the marker back to the caller */
    return inv;

  x=Multiply(inv,b,r,n);   /*  multiply inverted matrix A and vector b */

  for(i=0;i<n;i++)         /*  the inverse is no longer needed */
    free(inv[i]);
  free(inv);

  return x;
}



/* function Results takes two ints and a pointer to a pointer to an int
   and returns an int

   Asks whether the result should go to the screen (1) or to a file (2) and
   writes it there.  The question is repeated until a valid choice is made.

   n        : dimension of the result.
   c        : the n x n result matrix; for sol_type 3 only column 0 is used.
   sol_type : 1 = product, 2 = inverse, 3 = solution vector.  This selects
              the heading shown on screen and the layout (matrix rows, or
              one "xK = value" line per unknown).

   For file output a file name is read as one whitespace-delimited word that
   must fit in the outfile buffer; a longer name is rejected and the
   question is asked again.  The file is created or truncated.

   Returns 0 when the result was written, 1 when it could not be (end of
   input while prompting, the file could not be opened, or closing the file
   reported an error). */
int Results(int n, int **c,int sol_type) {

  int i,j,ch,got,out,status=0;
  char outfile[MAXCHARS];
  char fmt[32];
  FILE *of;

  while(1) {

    printf("\nWrite to [1] screen or [2] file:");
    got=scanf("%d", &out);
    if(got==EOF) {                    /* no more input: asking again cannot help */
      status=1;
      break;
    }
    if(got!=1) {
      printf("1 for screen, 2 for file\n");
      /* drop the rest of the line, otherwise the same unreadable text would
         be offered to scanf again on every pass */
      while((ch=getchar())!='\n' && ch!=EOF)
	;
      continue;
    }

    if(out==1) {   /* print to screen */

      if(sol_type==3) {                   /* solution to equation */
	printf("Solution vector:\n");
	for(i=0;i<n;i++){
	  printf("%s%d%s%d\n","x",i+1," = ",c[i][0]);
	}
      }
      else {                  /* inverse or product */

	if(sol_type==2) {                   /* inversion */
	  printf("Inverse is:\n");
	}
	else {                                  /* product */
	  printf("Product is:\n");
	}
	for(i=0;i<n;i++){
	  for(j=0;j<n;j++)
	    printf("%d%s",c[i][j]," ");
	  printf("%s\n"," ");
	}
      }
      break;
    }

    if(out!=2) {   /* a number, but not one of the choices: ask again */
      printf("1 for screen, 2 for file\n");
      continue;
    }

    /* print to file */
    printf("Name of file: ");
    /* build "%<width>s" so scanf can never write past the end of outfile */
    sprintf(fmt,"%%%ds",(int)sizeof(outfile)-1);
    if(scanf(fmt, outfile)!=1) {      /* only happens at end of input */
      printf("filename unspecified.  Please try again\n");
      status=1;
      break;
    }
    ch=getchar();
    if(ch!=EOF && ch!=' ' && ch!='\n' && ch!='\t' &&
       ch!='\r' && ch!='\f' && ch!='\v') {      /* filename too large */
      while(ch!='\n' && ch!=EOF)      /* discard the part that did not fit */
	ch=getchar();
      printf("Error: filename must be of size %d chars or fewer\n",
	     (int)sizeof(outfile)-1);
      continue;
    }
    if(ch!=EOF)
      ungetc(ch,stdin);

    if((of=fopen(outfile, "w"))==NULL) {
      printf("error getting output file\n");
      status=1;
      break;
    }

    if(sol_type==3) {                   /* solution to equation */
      for(i=0;i<n;i++){
	fprintf(of, "%s%d%s%d\n","x",i+1," = ",c[i][0]);
      }
    }
    else {            /* product or inversion */
      for(i=0;i<n;i++){
	for(j=0;j<n;j++)
	  fprintf(of, "%d%s",c[i][j]," ");
	fprintf(of, "%s\n"," ");
      }
    }

    if(fclose(of)!=0) {          /* close file; buffered data is written here */
      printf("error writing output file\n");
      status=1;
    }
    break;
  }

  printf("\n");
  return (status);
}





/* function main takes no values and returns an int

   Prints the input instructions and then hands control to Menu(), which
   shows the options and takes the selection.  Returns 0 if Menu() ever
   returns. */
int main(void) {

  /** instructions, one entry per printed line **/
  static const char *const instructions[] = {
    "\nPlease enter nxn matrices one row per line.\n",
    "Integers in a mod r system must be input using their unique\n",
    "representatives in the mod system {0,1,..,r-1}\n",
    "(ex. in mod 5, you'd input a 0 not a 5).\n\n"
  };
  int line;

  for (line=0; line<(int)(sizeof instructions/sizeof instructions[0]); line++) {
    printf("%s", instructions[line]);
  }

  Menu ();        /* print options and take selection */

  return (0);

}   /* end of linear_algebra.c */

