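/* ohmask.c
 *
 * Removes the parts of an image affected by OH, using an OH mask.
 * Working one image row at a time, masked pixels are replaced with a
 * cubic-spline or polynomial fit to the unaffected pixels of that row
 * (in polynomial mode the pixels in the grown margin, see below, are
 * replaced as well). The mask has 1's where there is OH and 0's
 * otherwise.
 *
 * When determining the fit to the original data, the mask can be grown
 * along each row by a given number of pixels, to make sure that only
 * good pixels are used for the fit.
 *
 * A. J. Dean, 5th November, 2002
 *
 */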

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "fitswrap.h"
#include "mathutils.h"
#include "nrroutines.h"

/*#define DIAGNOSTIC_OUTPUT*/
# if defined(DIAGNOSTIC_OUTPUT)
#include "cpgplot.h"  /* for the graphics output used in testing */
# endif /* DIAGNOSTIC_OUTPUT */

static void poly(float x, float* afunc, int ma);
/* ma is the order of the fit plus 1 */

/*
 * Parse `text` as a non-negative base-10 int.
 * Returns 1 and stores the value in *value on success; returns 0 (leaving
 * *value untouched) if `text` is empty, has trailing characters, is
 * negative, or does not fit in an int.
 */
static int parse_nonneg_int(const char *text, int *value)
{
  char *end = NULL;
  long parsed = 0;

  errno = 0;
  parsed = strtol(text, &end, 10);
  if (end == text || *end != '\0' || errno == ERANGE ||
      parsed < 0 || parsed > INT_MAX) {
    return 0;
  }
  *value = (int)parsed;
  return 1;
}

/*
 * Build the grown mask. `grown` is first cleared to 0.0; then every pixel
 * of `mask` that is exactly 1.0 sets itself and up to `grow` pixels on
 * either side *in the same row* (clipped at the row ends) to 1.0 in
 * `grown`. Both arrays are row-major with ncols*nrows elements.
 */
static void grow_mask(const float *mask, float *grown,
                      long ncols, long nrows, int grow)
{
  long row = 0;
  long col = 0;
  long k = 0;

  for (k = 0; k < ncols * nrows; k++) {
    grown[k] = 0.0f;
  }

  for (row = 0; row < nrows; row++) {
    const float *mask_row = mask + row * ncols;
    float *grown_row = grown + row * ncols;

    for (col = 0; col < ncols; col++) {
      long start = col - grow;
      long end = col + grow;

      if (mask_row[col] != 1.0f) {
        continue; /* only exact 1.0 values mark OH */
      }
      if (start < 0) {
        start = 0;
      }
      if (end >= ncols) {
        end = ncols - 1;
      }
      for (k = start; k <= end; k++) {
        grown_row[k] = 1.0f;
      }
    }
  }
}

/*
 * Copy the pixels of one row whose mask value is exactly 0.0 into the
 * 1-based Numerical-Recipes arrays: nr_x[i] is the 1-based pixel number,
 * nr_y[i] its data value. Returns the number of points stored, so the
 * arrays are filled at indices 1..return value.
 */
static int collect_good_points(const float *data_row, const float *mask_row,
                               long ncols, float *nr_x, float *nr_y)
{
  int num = 0;
  long col = 0;

  for (col = 0; col < ncols; col++) {
    if (mask_row[col] == 0.0f) {
      num++;
      nr_x[num] = (float)(col + 1); /* nrroutines is not 0 based! */
      nr_y[num] = data_row[col];
    }
  }
  return num;
}

/*
 * Abort the program (exit status 1) if row `row` (0-based) has fewer than
 * `needed` good points to fit. A row with no good points keeps the
 * original error message.
 */
static void require_points(long row, int num, int needed)
{
  if (num >= needed) {
    return;
  }
  if (num == 0) {
    printf("ohmask: There was a spectrum containing no good points!!\n");
  } else {
    printf("ohmask: Spectrum in row %ld has only %i good point(s); the fit needs at least %i!!\n",
           row + 1, num, needed);
  }
  printf("ohmask: Giving up!\n");
  exit(1);
}

/*
 * Evaluate the fitted polynomial sum_{k=1..ma} a[k]*x^(k-1) at `x`.
 * It uses poly() for the basis, i.e. the same basis lfit() fitted against.
 * `basis` is caller-supplied scratch with room for indices 1..ma.
 */
static float eval_poly_fit(const float *a, float *basis, int ma, float x)
{
  float y = 0.0f;
  int k = 0;

  poly(x, basis, ma);
  for (k = 1; k <= ma; k++) {
    y += a[k] * basis[k];
  }
  return y;
}

/*
 * Usage: ohmask image.fits mask.fits out.fits grow_width spline|poly [order]
 *
 * Reads the image and the mask, which must have identical dimensions, and
 * grows the mask by grow_width pixels along each row. It then fits each
 * row's remaining good pixels with a cubic spline or with a polynomial
 * (order 3 by default), replaces the bad pixels from the fit, and writes
 * the result to out.fits. The output is written via
 * writefits_float_header_clever, which is also given the input image's
 * filename (presumably so the header is carried over). Exits with status
 * 1 on any error.
 */
int main(int argc, char* argv[])
{
  float* imdata=NULL;
  float* maskdata=NULL;
  float* growmask=NULL;
  float* outdata=NULL;
  float* nr_x=NULL;   /* 1-based x (pixel number) of the good points */
  float* nr_y=NULL;   /* 1-based data values of the good points */
  float* nr_y2=NULL;  /* second derivatives filled in by spline() */
  float* sigma=NULL;  /* per-point uncertainties for lfit(); all 1.0 */
  float yp1=0;        /* end-point first derivatives passed to spline() */
  float ypn=0;
  float y_out=0;

  /* Polynomial fit state; NR arrays are 1-based, so index 0 is unused */
  float* a=NULL;      /* fitted coefficients a[1..ma] */
  float* basis=NULL;  /* scratch for eval_poly_fit() */
  float** covar=NULL;
  float chisq=0;
  int* ia=NULL;       /* lfit() flags: 1 = fit this coefficient */
  int ma=0;           /* number of coefficients = order + 1 */

  long naxesdata[2];
  long naxesmask[2];
  long ncols=0;
  long nrows=0;
  long row=0;
  long col=0;

  int grow=0;
  int order=0;
  int num=0;
  int k=0;
  int fit_spline=0;

# if defined(DIAGNOSTIC_OUTPUT)
  float* fitx=NULL;
  float* fity=NULL;
  long m=0;
  float ymin=0;
  float ymax=0;
  char title[45];
# endif /* DIAGNOSTIC_OUTPUT */

  /* Check for appropriate input */
  if (argc < 6) {
    fprintf(stderr,
            "Usage: %s image_to_correct.fits mask_image.fits out_file.fits grow_width spline||poly [poly order]\n",
            argv[0]);
    exit(1);
  }

  /* A negative width would leave the grown mask empty, so reject it */
  if (!parse_nonneg_int(argv[4], &grow)) {
    printf("ohmask: ERROR! grow_width must be a non-negative integer, got '%s'\n",
           argv[4]);
    exit(1);
  }

  printf("\n******************************************************************\n");
  printf("CODE UPDATE 17 Dec 2002 - wavelength solution maintained in header\n");
  printf("Email ajd if there are any problems\n");
  printf("******************************************************************\n\n");

  if (strcmp(argv[5],"spline")==0) {

    printf("ohmask: fitting a spline\n");
    fit_spline=1;

  } else if (strcmp(argv[5],"poly")==0) {

    printf("ohmask: fitting a polynomial\n");
    fit_spline=0;
    if (argc<7) {
      printf("ohmask: polynomial order not specified, using 3\n");
      order=3;
    } else if (parse_nonneg_int(argv[6], &order)) {
      printf("ohmask: polynomial order is %i\n",order);
    } else {
      printf("ohmask: ERROR! polynomial order must be a non-negative integer, got '%s'\n",
             argv[6]);
      exit(1);
    }

  } else {

    printf("ohmask: ERROR! didn't understand the fit type\n");
    exit(1);

  }

  imdata=readfits_float_clever(argv[1], naxesdata);
  maskdata=readfits_float_clever(argv[2], naxesmask);
  /* Defensive: the reader's failure behaviour is not visible here */
  if (imdata == NULL || maskdata == NULL) {
    printf("ohmask: Failed to read the input image or mask!\n");
    exit(1);
  }

  /* Check images have the same dimensions */
  if ( (naxesdata[0]!=naxesmask[0]) || (naxesdata[1]!=naxesmask[1]) ) {
    printf("ohmask: Image has dimensions %ld x %ld\n", naxesdata[0], naxesdata[1]);
    printf("ohmask: Mask has dimensions %ld x %ld\n", naxesmask[0], naxesmask[1]);
    printf("ohmask: These are not the same!!\n");
    exit(1);
  }

  ncols=naxesdata[0];
  nrows=naxesdata[1];

  /* An order+1 coefficient fit needs at least order+1 points per row */
  if (!fit_spline && (long)order >= ncols) {
    printf("ohmask: ERROR! polynomial order %i is too high for rows of %ld pixels\n",
           order, ncols);
    exit(1);
  }

  /* Allocate memory for the output array and the grown mask array */
  if ( (outdata=malloc(ncols*nrows*sizeof *outdata)) == NULL ||
       (growmask=malloc(ncols*nrows*sizeof *growmask)) == NULL ) {
    printf("ohmask: Memory allocation for output data array failed!\n");
    exit(1);
  }

  /* Allocate memory for the nrroutines fitting arrays (these start at 1) */
  if ( (nr_x=malloc((ncols+1)*sizeof *nr_x))   == NULL ||
       (nr_y=malloc((ncols+1)*sizeof *nr_y))   == NULL ||
       (nr_y2=malloc((ncols+1)*sizeof *nr_y2)) == NULL ) {
    printf("ohmask: Memory allocation for nrroutine arrays failed!\n");
    exit(1);
  }

  if (!fit_spline) {
    if ( (sigma=malloc((ncols+1)*sizeof *sigma)) == NULL ) {
      printf("ohmask: Memory allocation for nrroutine sigma arrays failed!\n");
      exit(1);
    }
  }

  grow_mask(maskdata, growmask, ncols, nrows, grow);

  if (fit_spline) {

    /* !!!!!!!! Cubic spline fitting !!!!!!!! */
    for (row=0; row<nrows; row++) {
      const float* data_row = imdata + row*ncols;
      const float* mask_row = maskdata + row*ncols;
      float* out_row = outdata + row*ncols;

      /* Fit only the pixels outside the grown mask */
      num = collect_good_points(data_row, growmask + row*ncols, ncols,
                                nr_x, nr_y);
      /* spline() reads x[2] and splint() divides by x[khi]-x[klo], so a
       * single good point is as unusable as none */
      require_points(row, num, 2);

      spline(nr_x, nr_y, num, yp1, ypn, nr_y2);

      /* Replace only the pixels flagged in the ORIGINAL mask */
      for (col=0; col<ncols; col++) {
        if (mask_row[col] == 0.0f) {
          out_row[col] = data_row[col];
        } else {
          splint(nr_x, nr_y, nr_y2, num, (float)(col+1), &y_out);
          out_row[col] = y_out;
        }
      }
    }

  } else {

    /* !!!!!!!!!!!!! Polynomial fitting !!!!!!!!!! */
    ma=order+1; /* A 3rd order polynomial still has an x^0 term */
    covar=matrix(1, (long)ma, 1, (long)ma);

    /* 1-based NR arrays: indices 1..ma are used, hence ma+1 elements */
    if ( (a=malloc((ma+1)*sizeof *a))         == NULL ||
         (ia=malloc((ma+1)*sizeof *ia))       == NULL ||
         (basis=malloc((ma+1)*sizeof *basis)) == NULL ) {
      printf("ohmask: Memory allocation for polynomial arrays failed!\n");
      exit(1);
    }

    for (k=0; k<=ma; k++) {
      ia[k]=1; /* Fit for all parameters */
    }

    /* Sigma array for least squares fitting: the sigmas are unknown */
    for (col=0; col<=ncols; col++) {
      sigma[col]=1.0f;
    }

    for (row=0; row<nrows; row++) {
      const float* data_row = imdata + row*ncols;
      const float* grown_row = growmask + row*ncols;
      float* out_row = outdata + row*ncols;

      num = collect_good_points(data_row, grown_row, ncols, nr_x, nr_y);
      /* Fewer points than coefficients makes lfit()'s system singular */
      require_points(row, num, ma);

      lfit(nr_x, nr_y, sigma, num, a, ia, ma, covar, &chisq, poly);

# if defined(DIAGNOSTIC_OUTPUT)

      /* Plot the fit against the raw row for diagnostic purposes */
      if ( (fitx=malloc(ncols*sizeof *fitx)) == NULL ||
           (fity=malloc(ncols*sizeof *fity)) == NULL ) {
        printf("ohmask: Memory allocation for fitted arrays failed!\n");
        exit(1);
      }

      for (m=0; m<ncols; m++) {
        fitx[m] = (float)(m+1); /* x is not 0 based for the fit */
        fity[m] = eval_poly_fit(a, basis, ma, (float)(m+1));
      }

      minmax((imdata + row*ncols), naxesdata[0], &ymin, &ymax);

      if (cpgbeg(0, "/xwindow", 1, 1) != 1)
        return 1;
      cpgenv(400, 500, ymin, ymax, 0, 0);

      if (order<100) {
        snprintf(title, sizeof title,
                 "Continuum fit using a polynomial of order %i", order);
      } else {
        snprintf(title, sizeof title, "%s", "That's a daft order polynomial!");
      }

      cpglab("X pixel", "Spectral intensity", title);

      cpgsci(5); /* Blue lines */
      cpgline(naxesdata[0], fitx, fity);

      cpgsci(1); /* White lines */
      cpgline(naxesdata[0], fitx, (imdata + row*ncols));
      cpgend();

# endif /* DIAGNOSTIC_OUTPUT */

      /* Fill out the relevant row of the output array.
       * NOTE(review): unlike the spline path (which keeps every pixel
       * whose ORIGINAL mask value is 0), this path tests the GROWN mask,
       * so pixels in the grow margin are replaced by the fit as well.
       * Kept as-is; confirm which behaviour is intended. */
      for (col=0; col<ncols; col++) {
        if (grown_row[col] == 0.0f) { /* A good point */
          out_row[col] = data_row[col];
        } else { /* Fill in the point from the polynomial fit */
          out_row[col] = eval_poly_fit(a, basis, ma, (float)(col+1));
        }
      }

# if defined(DIAGNOSTIC_OUTPUT)

      /* Plot the corrected row and the mask for diagnostic purposes */
      if (cpgbeg(0, "/xwindow", 1, 1) != 1)
        return 1;
      cpgenv(400, 500, ymin, ymax, 0, 0);

      cpglab("X pixel", "Spectral intensity", title);

      cpgsci(5); /* Blue lines */
      cpgline(naxesdata[0], fitx, fity);

      cpgsci(1); /* White lines */
      cpgline(naxesdata[0], fitx, (outdata + row*ncols));

      cpgsci(2); /* Red lines */
      cpgbin(naxesdata[0], fitx, (maskdata + row*ncols), TRUE);
      cpgend();

      free(fitx);
      free(fity);
      fitx=NULL;
      fity=NULL;

# endif /* DIAGNOSTIC_OUTPUT */

    }

    free(a);
    free(ia);
    free(basis);

  } /* End of spline or poly fit */

  writefits_float_header_clever(argv[3], argv[1], naxesdata, 2, outdata);

  free(imdata);
  free(maskdata);
  free(growmask);
  free(outdata);
  free(nr_x);
  free(nr_y);
  free(nr_y2);
  free(sigma); /* NULL (a no-op) in spline mode */
  if (!fit_spline) {
    free_matrix(covar, 1, (long)ma, 1, (long)ma);
  }

  return 0;

}

/*
 * Basis-function callback for lfit(): fills afunc[1..ma] with the powers
 * 1, x, x^2, ..., x^(ma-1) of the (1-based) pixel coordinate x.
 * afunc[0] is never touched, and nothing is written when ma < 1.
 */
static void poly(float x, float afunc[], int ma){

  float power = 1.0f;
  int term = 0;

  for (term = 1; term <= ma; term++) {
    afunc[term] = power; /* x^(term-1) */
    power *= x;
  }

  return;

}




/* Code that didn't make it through to the final version! */

/*
  float* temparray=NULL;
  int width=0;
  int start=0;
  int k=0;
*/

  /* Allocate memory for the temporary medianing array */
  /*  if ( (temparray=(float *)malloc(num*sizeof(float))) == NULL ){

    printf("ohmask: Memory allocation for temp array failed!");
    exit(1);
   }
  */

      /* Start at pixel in question, if it's O.K. do nothing
       * Otherwise find the median of the pixels around it that are O.K.
       * Raster out either side until we have the require number
       * of good pixels 
       */
      /*
      if( *(maskdata + j + i*naxesdata[0]) == 0.0 ){

	*(outdata + j + i*naxesdata[0]) = *(imdata + j + i*naxesdata[0]);

	} else { *//* Do the medianing */
	

	/*	!!!!!!!!!!!!!!!!!!!!!!!!!! */
      /*
	while(count<=size){

	  if(!reachedminend){

	    position=j-step;
	    if(position < 0){
	      reachedminend==1;
	      continue;
	    }

	    if(!maskpoint){

	      data=point;
	      count++;
	    }

	  }
 

	  if(!reachedmaxend){

	    position=j+step;
	    if(position > naxesdata[0]){
	      reachedmaxend==1;
	      continue;
	    }

	    if(!maskpoint){

	      data=point;
	      count++;
	    }

	  }
      */  
	  /*	!!!!!!!!!!!!!!!!!!!!!!!!!! */
      /*		      
	}
      */

      /* for(k=0; k<num; k++){ *//* Median around the pixel in question */
      /*
	start=j-width;
	*/
	/* Check that the ends of the ranges are dealt with properly */
	/*if(start<0){
	  start=0;
	}
	
	if( (start+size) > naxesdata[0] ){
	start = naxesdata[0] - size - 1; */ /*As the loop variable is 0 based */
	/*}

	  }

	*/
