/*
    DFT++ is a density functional package developed by the research group
    of Professor Tomas Arias

    Copyright 1996-2003 Sohrab Ismail-Beigi

    This file is part of DFT++.

    DFT++ is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    DFT++ is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with DFT++; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

    Please see the file CREDITS for a list of authors.

    For academic users, we request that publications using results obtained with
    this software reference

    "New algebraic formulation of density functional calculation," by Sohrab Ismail-Beigi
    and T.A. Arias, Computer Physics Communications 128:1-2, 1-45 (June 2000).

    and, if using the wavelet basis, further reference

    "Multiresolution analysis of electronic structure: semicardinal and wavelet bases,"
    T.A. Arias, Reviews of Modern Physics 71:1, 267-311 (January 1999).

    and 

    "Robust ab initio calculation of condensed matter: transparent convergence through
    semicardinal multiresolution analysis,'' I.P. Daykov, T.A. Arias, and
    Torkel D. Engeness, Physical Review Letters, 90:21, 216402 (May 2003).

    For your convenience, preprints of the above articles may be obtained from
    http://arXiv.org/abs/cond-mat/9909130, 9805262, and 0204411, respectively.
*/

/*
 * Gabor Csanyi 8/1/2001
 *
 * the ComplexArray class holds a simple array of complex numbers
 *
 */

#include "header.h"


ComplexArray::ComplexArray(int length)
{
  /* Allocate storage for `length` complex entries.  init() records the
   * length in ndata and dies unless it is positive. */
  init(length);
}

/* Construct an empty array: zero entries and no storage attached (d is NULL). */
ComplexArray::ComplexArray()
{
  d = NULL;
  ndata = 0;
}


/*
 * Allocate storage (through mymalloc) for `length` complex entries and
 * record the length in ndata.  Any storage previously held in d is NOT
 * released here.  A length that is not strictly positive is fatal.
 */
void
ComplexArray::init(int length)
{
  ndata = length;
  if (ndata > 0)
    d = (complex *)mymalloc(ndata*sizeof(complex), "data", "ComplexArray");
  else
    // zero is rejected as well as negative lengths, so say "<= 0"
    die("Can't initialize ComplexArray.init() with length <= 0\n");
}

/*
 * Release the storage held in d via myfree, then clear d so the object no
 * longer holds a dangling pointer (a repeated free() would otherwise hand
 * the same block to myfree twice).  ndata is left unchanged.
 */
void
ComplexArray::free()
{
  myfree(d);
  d = NULL;
}


/*
 * Attach scratch storage for `length` complex entries obtained from
 * System::gimme_scratch (instead of mymalloc) and record the length in
 * ndata.  The counterpart release_temp() hands it back through
 * System::release_scratch.  Any storage previously held in d is NOT
 * released here.  A length that is not strictly positive is fatal.
 */
void
ComplexArray::get_temp(int length)
{
  ndata = length;
  if (ndata > 0)
    d = (complex *)System::gimme_scratch(ndata,sizeof(complex));
  else
    // report the actual function name and the real (<= 0) condition
    die("Can't initialize ComplexArray.get_temp() with length <= 0\n");
}


/*
 * Return scratch storage obtained by get_temp() via System::release_scratch
 * and clear d, so that a second release_temp() is caught by the NULL check
 * below instead of releasing the same block twice.  Calling it with no
 * storage attached is fatal.
 */
void ComplexArray::release_temp()
{
  if(d) {
    System::release_scratch(d);
    d = NULL;
  }
  else
    die("Trouble??? Trying to release nothing!");
}

//////////////////////////////////////

//                                  //

// Operator functions               //

//                                  //

//////////////////////////////////////


/*
 * Element-wise copy of cd into this array.  Both arrays must already have
 * the same length (a mismatch is fatal); no storage is (re)allocated.
 * Nonstandard in that it returns void; to make it standard, return
 * ComplexArray& and end with `return *this;`.
 */
void ComplexArray::operator=(const ComplexArray &cd)
{
#ifdef DFT_PROFILING
  timerOn(50);  // start the assignment timer
#endif

  if (cd.ndata != ndata)
    die("Size mismatch in ComplexArray::operator=()\n");

  const complex *src = cd.d;
  complex *dst = d;
  for (int left = ndata; left > 0; --left)
    *dst++ = *src++;

#ifdef DFT_PROFILING
  timerOff(50);  // stop the assignment timer
#endif
}

// Assignment of scalar: every entry of the array is set to c.
// Deliberately not marked `inline`: an inline definition that exists only
// in this .cpp file is not available to other translation units, so
// callers elsewhere could fail to link.
void
ComplexArray::operator=(complex c)
{
  for (int i=0; i < ndata; i++)
    d[i] = c;
}




/*
 * Return a new array holding the element-wise sum (*this) + cd.
 * The lengths must match (a mismatch is fatal).  The result starts as a
 * copy of *this (via the copy constructor) and cd is added into it.
 */
ComplexArray
ComplexArray::operator+(const ComplexArray &cd)
{
#ifdef DFT_PROFILING
  timerOn(51);  // start the + timer
#endif

  if (cd.ndata != ndata)
    die("Size mismatch in ComplexArray::operator+()\n");

  ComplexArray sum(*this);
  for (int k = 0; k < ndata; ++k)
    sum.d[k] += cd.d[k];

#ifdef DFT_PROFILING
  timerOff(51);  // stop the + timer
#endif

  return sum;
}


/*
 * In-place element-wise sum: (*this)[i] += cd[i].
 * The lengths must match (a mismatch is fatal).
 */
void
ComplexArray::operator+=(const ComplexArray &cd)
{
  if (cd.ndata != ndata)
    die("Size mismatch in ComplexArray::operator+=()\n");

  const complex *src = cd.d;
  complex *dst = d;
  for (int left = ndata; left > 0; --left, ++src, ++dst)
    *dst += *src;
}


/*
 * Return a new array holding the element-wise difference (*this) - cd.
 * The lengths must match (a mismatch is fatal).  The result starts as a
 * copy of *this (via the copy constructor) and cd is subtracted from it.
 */
ComplexArray
ComplexArray::operator-(const ComplexArray &cd)
{
#ifdef DFT_PROFILING
  timerOn(52);  // start the - timer
#endif

  if (cd.ndata != ndata)
    die("Size mismatch in ComplexArray::operator-()\n");

  ComplexArray diff(*this);
  for (int k = 0; k < ndata; ++k)
    diff.d[k] -= cd.d[k];

#ifdef DFT_PROFILING
  timerOff(52);  // stop the - timer
#endif

  return diff;
}

/*
 * In-place element-wise difference: (*this)[i] -= cd[i].
 * The lengths must match (a mismatch is fatal).
 */
void
ComplexArray::operator-=(const ComplexArray &cd)
{
  if (cd.ndata != ndata)
    die("Size mismatch in ComplexArray::operator-=()\n");

  const complex *src = cd.d;
  complex *dst = d;
  for (int left = ndata; left > 0; --left, ++src, ++dst)
    *dst -= *src;
}

/*
 * In-place element-wise (pointwise) product: (*this)[i] *= cd[i].
 * The lengths must match (a mismatch is fatal).
 */
void
ComplexArray::operator*=(const ComplexArray &cd)
{
  if (cd.ndata != ndata)
    die("Size mismatch in ComplexArray::operator*=()\n");

  const complex *src = cd.d;
  complex *dst = d;
  for (int left = ndata; left > 0; --left, ++src, ++dst)
    *dst *= *src;
}


/*
 * Return r*cd: a copy of cd (made with the copy constructor) in which
 * every entry has been scaled by the real factor r.
 */
ComplexArray
operator*(real r,const ComplexArray &cd)
{
  ComplexArray scaled(cd);

  const int count = cd.ndata;
  for (int k = 0; k < count; ++k)
    scaled.d[k] *= r;

  return scaled;
}

/*
 * Return (*this)*r: a copy of this array (made with the copy constructor)
 * in which every entry has been scaled by the real factor r.
 */
ComplexArray
ComplexArray::operator*(real r)
{
  ComplexArray scaled(*this);

  complex *p = scaled.d;
  for (int left = ndata; left > 0; --left, ++p)
    *p *= r;

  return scaled;
}

/*
 * Return c*cd: a freshly allocated array (length cd.ndata) whose entries
 * are c*cd.d[i].  cd itself is not modified.
 */
ComplexArray
operator*(complex c,const ComplexArray &cd)
{
  ComplexArray cd2(cd.ndata);

  for(int i=0; i < cd.ndata; i++)
    cd2.d[i] = c*cd.d[i];

  // return the scaled result, not the unscaled input cd
  return cd2;
}

/*
 * Return (*this)*c: a freshly allocated array (length ndata) whose
 * entries are d[i]*c.  This array is not modified.
 */
ComplexArray
ComplexArray::operator*(complex c)
{
  ComplexArray product(ndata);

  const complex *src = d;
  complex *dst = product.d;
  for (int left = ndata; left > 0; --left, ++src, ++dst)
    *dst = *src * c;

  return product;
}


/* Scale every entry of this array in place by the complex factor c. */
void
ComplexArray::operator*=(complex c)
{
  complex *p = d;
  for (int left = ndata; left > 0; --left, ++p)
    *p *= c;
}

/* Scale every entry of this array in place by the real factor r. */
void
ComplexArray::operator*=(real r)
{
  complex *p = d;
  for (int left = ndata; left > 0; --left, ++p)
    *p *= r;
}


/////////////////////////////////////////////

//                                         //

//  MEMBER FUNCTIONS:                      //

//                                         //

/////////////////////////////////////////////

// zeroes out the whole array

void ComplexArray::zero_out(void)
{
  for(int i=0;i < ndata; i++)
      d[i] = (real)0.0;
}

/* Negates all the entries in the array */
void ComplexArray::negate(void)
{
  //dft_log("ComplexArray::negate()\n");


  for(int i=0;i < ndata; i++){
    d[i].x = -d[i].x;
    d[i].y = -d[i].y;
  }
}

/* Randomizes w/ uniform distribution in [-.5, .5]*/
void ComplexArray::randomize(void)
{
  //dft_log("ComplexArray::negate()\n");

  for(int i=0;i < ndata; i++){
    d[i].x = rand()/(RAND_MAX+1.) -.5;
    d[i].y = rand()/(RAND_MAX+1.) -.5;
  }
}


/* Write ComplexArray in binary form to the file fname */
void ComplexArray::write(char *fname)
{
  FILE *fp;

  fp = dft_fopen(fname,"w");
  dft_fwrite(d,sizeof(scalar),ndata,fp);
  dft_fclose(fp);
}


void ComplexArray::write(FILE *fp)
{
  dft_fwrite(d, sizeof(scalar), ndata, fp);
}

void ComplexArray::writea(char *fname)
{
  FILE *fp;

  fp = dft_fopen(fname,"a");
  dft_fwrite(d,sizeof(scalar),ndata,fp);
  dft_fclose(fp);
}

void ComplexArray::read(char *fname)
{
  FILE *fp;
  fp = dft_fopen(fname,"r");
  dft_fread(d,sizeof(scalar),ndata,fp);
  dft_fclose(fp);


}

void ComplexArray::print()
{
  for(int i = 0; i < ndata; i++)
    dft_log("%4.2f+%4.2fi\n", d[i].x, d[i].y);
}


// Sum over all entries of cd of the squared magnitude x*x + y*y,
// accumulated in index order.
real abs2(const ComplexArray &cd)
{
  real total = 0;

  const complex *p = cd.d;
  for (int k = 0; k < cd.ndata; ++k, ++p)
    total += p->x*p->x+p->y*p->y;

  return total;
}


void
add_scale_abs2(const complex &c, const ComplexArray & in,
               ComplexArray &out)
{
  if(in.ndata!=out.ndata)
    die("different lengths %d != %d in ComplexArray:add_scale_abs2\n",
        in.ndata, out.ndata);

  int i;
  for(i=0; i<in.ndata; i++){
    real z2;
    z2 = in.d[i].x*in.d[i].x + in.d[i].y*in.d[i].y;
    out.d[i].x += c.x*z2;
    out.d[i].y += c.y*z2;
  }
      
}

/*
 * "Dot product" of two arrays: returns sum_i conj(cd1.d[i]) * cd2.d[i]
 * (the sum of the diagonal of cd1^cd2), accumulated in index order.
 * The lengths must match (a mismatch is fatal).
 */
complex
dot(const ComplexArray &cd1,const ComplexArray &cd2)
{
  complex acc;

  if (cd1.ndata != cd2.ndata)
    die("cd1.ndata != cd2.ndata in dot_ComplexArray()\n");

  acc = 0.0;
  for (int k = 0; k < cd1.ndata; ++k) {
    const complex &a = cd1.d[k];
    const complex &b = cd2.d[k];
    // conj(a)*b = (a.x*b.x + a.y*b.y) + i(a.x*b.y - a.y*b.x)
    acc.x += a.x*b.x + a.y*b.y;
    acc.y += a.x*b.y - a.y*b.x;
  }
  return acc;
}


/*
 * Does cd2 += r * cd1, element by element.
 * The lengths must match (a mismatch is fatal); without this check a
 * longer cd1 would write past the end of cd2.
 */
void
scale_accumulate(real r,const ComplexArray &cd1, ComplexArray &cd2)
{
  if (cd1.ndata != cd2.ndata)
    die("Incompatible sizes in scale_accumulate(r,cd1,cd2)\n");

  for (int i=0; i < cd1.ndata; i++)
    cd2.d[i] += r * cd1.d[i];
}

/*
 * Does cd2 += c * cd1, element by element.
 * The lengths must match (a mismatch is fatal); without this check a
 * longer cd1 would write past the end of cd2.
 */
void
scale_accumulate(complex c,const ComplexArray &cd1, ComplexArray &cd2)
{
  if (cd1.ndata != cd2.ndata)
    die("Incompatible sizes in scale_accumulate(c,cd1,cd2)\n");

  for (int i=0; i < cd1.ndata; i++)
    cd2.d[i] += c * cd1.d[i];
}


/*
 * Does cd3 = r1*cd1 + r2*cd2, element by element.  All three arrays must
 * have the same length (a mismatch is fatal).  Each output entry is formed
 * from the inputs at the same index, so cd3 may be the same array as cd1
 * or cd2.
 */
void
scaled_sum(real r1,const ComplexArray &cd1,
	   real r2,const ComplexArray &cd2,
	   ComplexArray &cd3)
{
  const bool same_size = (cd2.ndata == cd1.ndata) && (cd3.ndata == cd1.ndata);
  if (!same_size)
    die("Incompatible sizes in scaled_sum(r1,cd1,r2,cd2,cd3)");

  for (int k = 0; k < cd1.ndata; ++k)
    cd3.d[k] = r1*cd1.d[k] + r2*cd2.d[k];
}

/*
 * Does cd3 = c1*cd1 + c2*cd2, element by element.  All three arrays must
 * have the same length (a mismatch is fatal).  Each output entry is formed
 * from the inputs at the same index, so cd3 may be the same array as cd1
 * or cd2.
 */
void
scaled_sum(complex c1,const ComplexArray &cd1,
	   complex c2,const ComplexArray &cd2,
	   ComplexArray &cd3)
{
  const bool same_size = (cd2.ndata == cd1.ndata) && (cd3.ndata == cd1.ndata);
  if (!same_size)
    die("Incompatible sizes in scaled_sum(c1,cd1,c2,cd2,cd3)");

  for (int k = 0; k < cd1.ndata; ++k)
    cd3.d[k] = c1*cd1.d[k] + c2*cd2.d[k];
}


/*
 * Element-wise (pointwise) complex product: out.d[i] = in1.d[i] * in2.d[i].
 * All three arrays must have the same length (a mismatch is fatal).
 * Both factors are copied before out.d[i] is written, so out may be the
 * same array as in1 and/or in2 (storing .x first would otherwise corrupt
 * the operand used to compute .y).
 */
void
point_mult(ComplexArray &in1, ComplexArray &in2, ComplexArray &out)
{
  if (in1.ndata != in2.ndata || in1.ndata != out.ndata)
    die("Incompatible sizes in point_multiply(in1, in2, out)");

  for (int i=0; i < in1.ndata; i++){
    const complex a = in1.d[i];
    const complex b = in2.d[i];
    out.d[i].x = a.x*b.x - a.y*b.y;
    out.d[i].y = a.x*b.y + a.y*b.x;
  }
}


// syntax highlighted by Code2HTML, v. 0.9.1