/*
    DFT++ is a density functional package developed by the research group
    of Professor Tomas Arias

    Copyright 1996-2003 Sohrab Ismail-Beigi

    This file is part of DFT++.

    DFT++ is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    DFT++ is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with DFT++; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

    Please see the file CREDITS for a list of authors.

    For academic users, we request that publications using results obtained with
    this software reference

    "New algebraic formulation of density functional calculation," by Sohrab Ismail-Beigi
    and T.A. Arias, Computer Physics Communications 128:1-2, 1-45 (June 2000).

    and, if using the wavelet basis, further reference

    "Multiresolution analysis of electronic structure: semicardinal and wavelet bases,"
    T.A. Arias, Reviews of Modern Physics 71:1, 267-311 (January 1999).

    and 

    "Robust ab initio calculation of condensed matter: transparent convergence through
    semicardinal multiresolution analysis," I.P. Daykov, T.A. Arias, and
    Torkel D. Engeness, Physical Review Letters, 90:21, 216402 (May 2003).

    For your convenience, preprints of the above articles may be obtained from
    http://arXiv.org/abs/cond-mat/9909130, 9805262, and 0204411, respectively.
*/

#include "header.h"

// Default constructor: starts with no states; the BlochState array is
// allocated later by Elecvars::setup().
Elecvars::Elecvars()
  : nstates(0),
    states(NULL)
{
}

// Allocate the electronic variables, and read
// wavefunctions or randomize as needed.
//
// In order, this:
//   1. folds and reduces the input k-points (writing the folded set back
//      into e.elecinfo) and builds the BlochState array `states'
//      (twice as many states as k-points for ZSPIN);
//   2. determines nelectrons / nbands and the initial fillings;
//   3. sets up the basis of every state;
//   4. allocates the scalar fields, per-state matrices and ColumnBundles;
//   5. randomizes + orthonormalizes Y (or reads it from file), and
//      optionally reads the charge density n and the potential Vscloc.
// Calls die() on an unsupported spin type, nstates < 1, too few bands for
// the electrons, or automatic band count combined with given fillings.
void Elecvars::setup(Everything &e)
{
  dft_log("----- Elecvars::setup() -----\n");
  dft_log("Setting up and allocating electronic variables.\n");
  dft_log_flush();

  // Shorthand for the electronic input parameters held by Everything.
  Elecinfo &einfo = e.elecinfo;


  /**********************************************************************/
  /* k point folding and reduction                                      */
  /**********************************************************************/

  dft_log("Folding and reducing kpoints\n");

  // fold_kpoints() returns the folded k-point set through new_kvec /
  // new_weight, and a flag telling whether any folding took place.
  // NOTE: at this point, `nstates' contains the number of INPUT
  // k-points, irrespective of spin.
  // NOTE: although the new kpoint set is returned in new_kvec,
  // the old set is also altered, unsuitable for further consumption.
  vector3 *new_kvec = NULL;
  real *new_weight = NULL;
  int new_nstates;
  int folded_some;
  folded_some = fold_kpoints(einfo.input_kvec,
                             &new_kvec,
                             einfo.input_weight,
                             &new_weight,
                             e.basis_spec.kpt_fold,
                             einfo.nstates,
                             new_nstates);

  int q, b;

  if (folded_some) {
    // Copy the folded set back into elecinfo's own storage.
    einfo.input_kvec = (vector3 *) myrealloc(einfo.input_kvec,
                                             new_nstates*sizeof(vector3),
                                             "input_kvec",
                                             "Elecvars::setup()");
    // input_weight holds `real' values, so it must be sized with
    // sizeof(real): sizeof(int) under-allocates whenever real is wider
    // than int, and the copy loop below would then overrun the buffer.
    einfo.input_weight = (real *) myrealloc(einfo.input_weight,
                                            new_nstates*sizeof(real),
                                            "input_weight",
                                            "Elecvars::setup()");
    for (q = 0; q < new_nstates; q++) {
      einfo.input_kvec[q] = new_kvec[q];
      einfo.input_weight[q] = new_weight[q];
    }
    einfo.nstates = new_nstates;
    myfree(new_kvec);
    myfree(new_weight);
    // Don't hand freed pointers to reduce_kpoints() below.
    new_kvec = NULL;
    new_weight = NULL;
  }


  // Now play the same game for reduction; the reduced set is again
  // returned through new_kvec / new_weight.
  reduce_kpoints(einfo.input_kvec,
                 einfo.input_weight,
                 &new_kvec,
                 &new_weight,
                 einfo.nstates,
                 new_nstates,
                 e.symm, einfo, e.lattice, e.symm.calc_symmetries_flag);


  // We have the final number of k-points, so create the BlochState
  // array and store the new k points. It is at this point that
  // the spin gets taken into account.
  if (einfo.spintype == NOSPIN) {
    nstates = einfo.nstates = new_nstates;
    states = (BlochState *) mymalloc(sizeof(BlochState)*nstates,
                                     "states",
                                     "Elecvars::setup()");
    for (q = 0; q < nstates; q++) {
      states[q].qnum.kvec = new_kvec[q];
      states[q].qnum.spin = 0;
      states[q].w = new_weight[q];
    }
  } else if (einfo.spintype == ZSPIN) {
    // Twice as many states as k-points: spin +1 in the first half and
    // spin -1 in the second half, with identical k-vectors and weights.
    nstates = einfo.nstates = new_nstates*2;
    states = (BlochState *) mymalloc(sizeof(BlochState)*nstates,
                                     "states",
                                     "Elecvars::setup()");
    for (q = 0; q < new_nstates; q++) {
      states[q].qnum.kvec = new_kvec[q];
      states[q].qnum.spin = 1;
      states[q].w = new_weight[q];
      states[q+new_nstates].qnum.kvec = new_kvec[q];
      states[q+new_nstates].qnum.spin = -1;
      states[q+new_nstates].w = new_weight[q];
    }
  } else {
    die("Spin type not supported within Elecvars::setup()\n");
  }

  // k-point reduction finished; the temporary arrays are no longer needed.
  myfree(new_kvec);
  myfree(new_weight);


  // This really should not happen, but just to be paranoid...
  if (nstates < 1)
    die("Error: nstates < 1 inside Elecvars::setup()\n");



  /**********************************************************************/
  /* Bands and electrons                                                */
  /**********************************************************************/

  dft_log("Computing the number of bands and number of electrons\n");
  if (einfo.initial_fillings_flag == FALSE) {
    // No initial fillings: the ions determine nelectrons (sum of
    // Z * natoms over species), and nelectrons then determines the fillings.
    einfo.nelectrons = 0;
    for (int sp = 0; sp < e.ioninfo.nspecies; sp++) {
      einfo.nelectrons += e.ioninfo.species[sp].Z *
        e.ioninfo.species[sp].natoms;
    }
    dft_log("For a neutral system, the number of electrons is %.0f\n", einfo.nelectrons);

    // Either compute the band count (ceil(nelectrons/2)), or check that the
    // user-supplied einfo.nbands is large enough. nelectrons is a real, so
    // the error message must use a floating-point conversion, not %d.
    if (e.cntrl.auto_n_bands == TRUE)
      einfo.nbands = (int)ceil(einfo.nelectrons / 2.0);
    else if (einfo.nbands < (int)ceil(einfo.nelectrons / 2.0))
      die("You specified too few bands to hold %g electrons!\n", einfo.nelectrons);

    dft_log("The number of bands is %d\n", einfo.nbands);

    // Fill in the fillings for the bands.
    dft_log("Setting up band fillings automatically\n");
    setup_initial_fillings(einfo);
  } else {
    // There are initial fillings. We trust the user to put in the correct
    // total charge, so the fillings read from file determine nelectrons.
    if (e.cntrl.auto_n_bands == TRUE)
      die("You must specify the number of bands explicitly when\nyou give initial fillings.\n");

    dft_log("Setting up fillings for %d bands.\n", einfo.nbands);
    setup_initial_fillings(einfo);

    // nelectrons = sum over states and bands of Re(filling) * state weight.
    einfo.nelectrons = 0.;
    for (q = 0; q < einfo.nstates; q++)
      for (b = 0; b < einfo.nbands; b++)
        einfo.nelectrons += REAL(states[q].F.c[b])*states[q].w;
  }

  // Print out the current status of Elecvars and Elecinfo.
  print_status(einfo);
  dft_log("\n");


  /**********************************************************************/
  /* set up the basis                                                   */
  /**********************************************************************/
  BasisSpec &basis_spec = e.basis_spec;

  // TODO: this belongs in Basis or Column.
  dft_log("----- Basis::setup() -----\n");

  for (q = 0; q < nstates; q++) {
    if (basis_spec.basis_flag == 1) {
      // k-point-dependent basis (e.g. planewaves): set up the G-vectors
      // of each state using its own k-vector.
      states[q].basis = basis;
      states[q].basis.qnum = &(states[q].qnum);
      states[q].basis.setup(states[q].qnum.kvec);
    } else if (q == 0) {
      // k-point-independent basis: every basis is centered at k=0, so it
      // is computed once, for state 0 ...
      states[0].basis = basis;
      states[0].basis.qnum = &(states[0].qnum);
      vector3 zero(0.0, 0.0, 0.0);
      states[0].basis.setup(zero);
    } else {
      // ... and copied to every other state, each pointing at its own qnum.
      states[q].basis = states[0].basis;
      states[q].basis.qnum = &(states[q].qnum);
    }
  }

  dft_log("\n");


  /**********************************************************************/
  /* Allocate space for the electron density, electrostatic potential,  */
  /* the local (pseudo)potential, wavefunctions, and the local part of  */
  /* the self-consistent potential                                      */
  /**********************************************************************/

  // ipd: n, d and Vlocpot are allocated only when they will be used:
  // electronic minimization, the finite-difference test, or a band
  // structure run that reads n as input. (This condition mirrors the
  // current command-dependency structure.)
  dft_log("Allocating space for scalar fields\n");
  if ( (e.cntrl.electronic_minimization_flag == TRUE) ||
       (e.cntrl.finite_diff_test_flag == TRUE) ||
       ( (e.cntrl.band_structure_flag == TRUE) &&
         (e.elecinfo.read_n_flag == TRUE) ) ) {
    n.init_scalarfield(basis, REALSPACE);
    d.init_scalarfield(basis, COEFFSPACE);
    Vlocpot.init_scalarfield(basis, COEFFSPACE);
  }

  // ipd: Probably Vscloc should be initialized just for no-spin and
  //      only z-spin should initialize the up/down potentials.
  switch (einfo.spintype) {
    case NOSPIN:
      Vscloc.init_scalarfield(basis, REALSPACE);
      break;
    case ZSPIN:
      // ipd: looks like the code needs Vscloc for Z-spin too; check it out.
      Vscloc.init_scalarfield(basis, REALSPACE);
      n_dn.init_scalarfield(basis, REALSPACE);
      n_up.init_scalarfield(basis, REALSPACE);
      Vscloc_up.init_scalarfield(basis, REALSPACE);
      Vscloc_dn.init_scalarfield(basis, REALSPACE);
      break;
    default:
      die("Spin type not supported in Elecvars::setup().\n");
  }

  int nbands = einfo.nbands;

  /* allocate the pointer arrays to Y, C and B in elecvars */
  Y = (ColumnBundle **)mymalloc(nstates*sizeof(ColumnBundle*),
                                "elecvars.setup", "Y-array");
  C = (ColumnBundle **)mymalloc(nstates*sizeof(ColumnBundle*),
                                "elecvars.setup", "C-array");
  B = (Matrix **)mymalloc(nstates*sizeof(Matrix*),
                          "elecvars.setup", "B-array");

  /* Allocate space for all the matrices and eigenvalues */
  for (q = 0; q < nstates; q++) {
    // The Elecvars pointer arrays alias the per-state objects.
    B[q] = &states[q].B;
    C[q] = &states[q].C;
    Y[q] = &states[q].Y;

    // nbands x nbands subspace matrices and per-band arrays.
    states[q].U.init(nbands,nbands);
    states[q].Umhalf.init(nbands,nbands);
    states[q].W.init(nbands,nbands);
    states[q].Hsub.init(nbands,nbands);
    states[q].Hsub_evecs.init(nbands,nbands);
    states[q].mu = (real *)mymalloc(sizeof(real)*einfo.nbands,
                                    "Elecvars::setup()","mu");
    states[q].Hsub_eigs = (real *)mymalloc(sizeof(real)*einfo.nbands,
                                           "Elecvars::setup()","Hsub_eigs");
    states[q].U.hermetian = 1;
    states[q].Umhalf.hermetian = 1;
    states[q].Hsub.hermetian = 1;

    if (einfo.subspace_rotation) {
      states[q].B.init(nbands,nbands);
      states[q].V.init(nbands,nbands);
      states[q].Z.init(nbands,nbands);
      states[q].beta = (real *)mymalloc(sizeof(real)*einfo.nbands,
                                        "Elecvars::setup()","beta");
      states[q].B.hermetian = 1;
      states[q].V.hermetian = 0;
      states[q].Z.hermetian = 0;
      /* Set B to identity */
      states[q].B.zero_out();
      for (int i = 0; i < einfo.nbands; i++)
        states[q].B(i,i) = 1.0;
    }

    // Point each state at the self-consistent local potential of its spin
    // channel: +1 -> Vscloc_up, -1 -> Vscloc_dn, anything else -> Vscloc.
    switch (states[q].qnum.spin) {
    case 1:
      states[q].Vscloc = &(Vscloc_up);
      break;
    case -1:
      states[q].Vscloc = &(Vscloc_dn);
      break;
    default: // qnum.spin == 0
      states[q].Vscloc = &(Vscloc);
    }

    dft_log("Initializing ColumnBundles\n");

    // Create this state's Y bundle (nbands columns on its own basis), then
    // use Y as a template for the structure of C.
    states[q].Y.init(einfo.nbands, &(states[q].basis), "distributed");
    states[q].Y.copy_structure_to(states[q].C);
  }


  dft_log("Initializing wave functions:  ");

  if (einfo.read_Y_flag == 0) {
    // Random start: randomize every Y, then orthonormalize Y to C.
    dft_log("setting Y to random values\n"); dft_log_flush();
    System::seed_with_time();
    dft_log("Elecvars.c: please set mpi dependent seed\n");
    for (q = 0; q < nstates; q++)
      states[q].Y.randomize();
    dft_log("Orthonormalizing Y to C\n"); dft_log_flush();
    calc_UVC(einfo, *this);
  } else {
    read_bloch_states_array(einfo);
  }

  // Read in the charge density if requested.
  if (einfo.read_n_flag == TRUE) {
    dft_log("Reading charge density from file '%s'\n",
            einfo.init_n_filename);
    dft_log_flush();
    n.data.read(einfo.init_n_filename);
  }

  // Read in the self-consistent local potential if requested.
  if (einfo.read_vscloc_flag == TRUE) {
    dft_log("Reading the self-consistent local potential from file '%s'\n",
            einfo.vscloc_filename);
    dft_log_flush();
    Vscloc.data.read(einfo.vscloc_filename);
  }
  dft_log("\n");
}

// Reads the wavefunction bundle Y of every state, in state order, from the
// single file named by einfo.init_Y_filename.
void Elecvars::read_bloch_states_array(Elecinfo &einfo)
{
  dft_log("reading Y from '%s'\n", einfo.init_Y_filename);
  dft_log_flush();

  // No NULL check needed: dft_fopen() dies by itself if the open fails.
  FILE *in = dft_fopen(einfo.init_Y_filename, "r");

  for (int state = 0; state < nstates; ++state)
    states[state].Y.read_stream(in);

  dft_fclose(in);
}

// Writes the wavefunction bundle Y of every state, in state order, to the
// file named by einfo.init_Y_filename (the same file and layout that
// read_bloch_states_array() reads back).
// NOTE(review): the log message says "C", but it is states[].Y that gets
// written. Confirm which one is intended.
void Elecvars::write_bloch_states_array(Elecinfo &einfo)
{
  dft_log("writing C to  '%s'\n", einfo.init_Y_filename);
  dft_log_flush();

  // No NULL check needed: dft_fopen() dies by itself if the open fails.
  FILE *out = dft_fopen(einfo.init_Y_filename, "w");

  for (int state = 0; state < nstates; ++state)
    states[state].Y.write_stream(out);

  dft_fclose(out);
}

// Initializes the fillings F (nbands entries) of every state, depending on
// einfo.initial_fillings_flag:
//   TRUE      -> read nbands values per state, state by state, from
//                einfo.initial_fillings_filename;
//   otherwise -> fill automatically. Bands are filled in order with
//                2.0 (NOSPIN) or 1.0 (ZSPIN) electrons until the per-state
//                count (nelectrons, or nelectrons/2 for ZSPIN) runs out.
//                The next band gets the remainder and the rest get 0.0.
// Dies if the fillings file cannot be opened or the spin type is unknown.
void Elecvars::setup_initial_fillings(Elecinfo &einfo)
{
  const int nbands = einfo.nbands;
  real fnk;

  if (einfo.initial_fillings_flag == TRUE) {
    dft_log("Reading initial fillings from file %s.\n",
            einfo.initial_fillings_filename);
    dft_text_FILE *fillings_FILE = dft_text_fopen(einfo.initial_fillings_filename, "r");
    if (fillings_FILE == NULL)
      die("Can't open file %s to read initial fillings!\n",
          einfo.initial_fillings_filename);

    for (int state = 0; state < nstates; ++state) {
      states[state].F.init(nbands);
      for (int band = 0; band < nbands; ++band) {
        dft_text_fscanf(fillings_FILE, "%lg", &fnk);
        states[state].F.c[band] = fnk;
      }
    }
    dft_text_fclose(fillings_FILE);
    return;
  }

  dft_log("Computing initial fillings automatically.\n");

  for (int state = 0; state < nstates; ++state) {
    // remaining: electrons still to place in this state;
    // capacity: maximum occupancy of a single band.
    real remaining = 0.0;
    real capacity = 0.0;
    switch (einfo.spintype) {
      case NOSPIN:
        remaining = einfo.nelectrons;
        capacity = 2.0;
        break;
      case ZSPIN:
        remaining = einfo.nelectrons/2.0;
        capacity = 1.0;
        break;
      default:
        die("Spintype not suported in Elecinfo::setup_initial_fillings.\n");
    }

    states[state].F.init(nbands);
    for (int band = 0; band < nbands; ++band) {
      // Full band while enough electrons remain, then the leftover, then 0.
      fnk = (remaining > capacity) ? capacity
          : (remaining > 0.0)      ? remaining
          : 0.0;
      remaining -= fnk;
      states[state].F.c[band] = fnk;
    }
  }
}

/*
 * Elecvars::print_fillings
 *
 * Writes the fillings of every state to `out': one line per state, with
 * the real part of each band's filling in "%16.10le " format.
 */
void Elecvars::print_fillings(Output *out)
{
  dft_log("Dumping latest fillings to '%s'\n", out->filename);

  for (int state = 0; state < nstates; ++state) {
    // Printed by hand rather than via the fillings' own print() so that
    // only the REAL part of each entry appears.
    for (int band = 0; band < states[state].F.n; ++band)
      out->printf("%16.10le ", REAL(states[state].F.c[band]));
    out->printf("\n");
  }
}

// Logs nelectrons, nbands and nstates, then one line per state giving its
// index, k-vector, weight and spin, followed by that state's fillings.
//
// Taking Elecinfo as an argument is not so nice, I know. But Elecinfo
// cannot have its own print_status() because the nontrivial bits only get
// calculated in Elecvars::setup(). Maybe this should change? The original
// rationale for keeping nbands and nelectrons in Elecinfo is that they do
// not change throughout the calculation, so they are more like parameters
// in this respect.

void Elecvars::print_status(Elecinfo &einfo)
{
  dft_log("Displaying some Elecvars & Elecinfo internals:\n");
  dft_log("nelectrons = %f\n", einfo.nelectrons);
  dft_log("nbands = %d\tnstates = %d\n", einfo.nbands, nstates);
  dft_log("states and fillings follow:\n");

  for (int state = 0; state < nstates; ++state) {
    const vector3 &k = states[state].qnum.kvec;
    dft_log("%d [ %f %f %f ] %f  spin %2d\n", state,
            k.v[0], k.v[1], k.v[2],
            states[state].w,
            states[state].qnum.spin);
    dft_log(">> ");

    // TODO: consider printing only the REAL part here.
    states[state].F.print(dft_global_log);
  }
}


// syntax highlighted by Code2HTML, v. 0.9.1