/*
    DFT++ is a density functional package developed by the research group
    of Professor Tomas Arias

    Copyright 1996-2003 Sohrab Ismail-Beigi

    This file is part of DFT++.

    DFT++ is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    DFT++ is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with DFT++; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

    Please see the file CREDITS for a list of authors.

    For academic users, we request that publications using results obtained with
    this software reference

    "New algebraic formulation of density functional calculation," by Sohrab Ismail-Beigi
    and T.A. Arias, Computer Physics Communications 128:1-2, 1-45 (June 2000).

    and, if using the wavelet basis, further reference

    "Multiresolution analysis of electronic structure: semicardinal and wavelet bases,"
    T.A. Arias, Reviews of Modern Physics 71:1, 267-311 (January 1999).

    and 

    "Robust ab initio calculation of condensed matter: transparent convergence through
    semicardinal multiresolution analysis,'' I.P. Daykov, T.A. Arias, and
    Torkel D. Engeness, Physical Review Letters, 90:21, 216402 (May 2003).

    For your convenience, preprints of the above articles may be obtained from
    http://arXiv.org/abs/cond-mat/9909130, 9805262, and 0204411, respectively.
*/

/*
 *     Nikolaj Moll                 April 10, 1999
 *
 * Add support for symmetry calculation.
 */

/* $Id: symm.cpp,v 1.16.2.21 2003/05/29 18:54:33 ivan Exp $ */

#include "header.h"

#ifdef _WIN32
/*
 * rint() replacement, compiled only under _WIN32 (presumably because the
 * targeted Windows C runtime did not provide it).
 *
 * Rounds x to the nearest integer; exact halfway cases go to the even
 * neighbour, matching C99 rint() under the default round-to-nearest mode.
 * The previous floor(x + 0.5) rounded halves upwards and also returned 1
 * for 0.49999999999999994, because x + 0.5 itself rounds up to 1.0.
 */
double rint(double x)
{
  double lower = floor(x);
  double frac = x - lower;   /* distance above the lower integer, in [0,1) */

  if (frac > 0.5)
    return lower + 1.0;
  if (frac < 0.5)
    return lower;
  /* exact tie: choose the even neighbour */
  return (fmod(lower, 2.0) == 0.0) ? lower : lower + 1.0;
}
#endif

/*
 * symmetry recognition of the ionic configuration
 */
/*
 * Determine the point-group symmetries of the ionic configuration.
 *
 * The Bravais-lattice operations are found first. Then every atom position,
 * and every midpoint of two atoms of the same species, is tried as the
 * center of symmetry. The center admitting the largest subgroup wins; ties
 * keep the earlier candidate, and the origin is the initial candidate.
 *
 * On return, symm->sym[0 .. symm->nrot-1] holds the operations with the
 * identity in slot 0 (reduce_kpoints checks for this). symm->f_sym holds
 * RTR * sym * invRTR for each operation.
 *
 * With PLANEWAVES, the atoms are shifted so the chosen center sits at the
 * origin. With WAVELETS, the center is stored in symm->Rsym instead.
 */
void symmetries(Ioninfo *ioninfo, Lattice *lattice, Symmetries *symm)
{
  int nlattice = 0, ntrial = 0;
  vector3 center, trial;   // both zeroed by the default constructor
  matrix3 lattice_sym[48], trial_sym[48], unit(1.0, 1.0, 1.0);

  symm->nrot = 0;

  dft_log("\nautomatic search of point group symmetries:\n");

  /* candidate operations: the point group of the Bravais lattice */
  bravais_symmetries(lattice, nlattice, lattice_sym);

  /* baseline: operations leaving the basis invariant about the origin */
  basis_symmetries(ioninfo, nlattice, lattice_sym, center,
                   symm->nrot, symm->sym);

  /* try each atom (a == b) and each same-species midpoint as the center */
  for (int s = 0; s < ioninfo->nspecies; s++) {
    for (int a = 0; a < ioninfo->species[s].natoms; a++) {
      for (int b = 0; b <= a; b++) {
        trial = (ioninfo->species[s].atpos[a]
                 + ioninfo->species[s].atpos[b]) / 2.0;

        basis_symmetries(ioninfo, nlattice, lattice_sym, trial,
                         ntrial, trial_sym);

        if (ntrial <= symm->nrot)
          continue;   // not strictly better than the current best

        center = trial;
        symm->nrot = ntrial;
        for (int k = 0; k < ntrial; k++)
          symm->sym[k] = trial_sym[k];
      }
    }
  }

  /* bring the identity operation to the front of the list */
  for (int k = 0; k < symm->nrot; k++) {
    if (matcmp(symm->sym[k], unit) < MIN_SYMM_TOL) {
      matrix3 swap = symm->sym[0];
      symm->sym[0] = symm->sym[k];
      symm->sym[k] = swap;
    }
  }

  dft_log("reduced to %d total symmetries with basis:\n\n", symm->nrot);
  for (int k = 0; k < symm->nrot; k++) {
    symm->sym[k].print(dft_global_log,"%4.0f ");
    dft_log("\n");
  }
  dft_log_flush();

  /* force-space operations: f_sym = RTR * sym * invRTR */
  matrix3 rtr = lattice->RTR;
  matrix3 inv_rtr = lattice->invRTR;
  for (int k = 0; k < symm->nrot; k++)
    symm->f_sym[k] = rtr * symm->sym[k] * inv_rtr;

  dft_log("symmetries for forces:\n\n");
  for (int k = 0; k < symm->nrot; k++) {
    symm->f_sym[k].print(dft_global_log,"%4.0f ");
    dft_log("\n");
  }
  dft_log_flush();

#ifdef PLANEWAVES
  /* translate all atoms so the chosen center becomes the origin */
  if (abs(center) > MIN_SYMM_TOL) {
    dft_log("moving atoms to new symmetry point: ");
    center.print(dft_global_log,"%g ");
    for (int s = 0; s < ioninfo->nspecies; s++)
      for (int a = 0; a < ioninfo->species[s].natoms; a++)
        ioninfo->species[s].atpos[a] -= center;
  }
#elif defined WAVELETS
  /* keep the atoms in place and remember the center instead */
  symm->Rsym = center;

  dft_log("Center of symmetry for the above symmetries: ");
  symm->Rsym.print(dft_global_log,"%le ");
#endif
}

//
// Finds the point-group operations of the Bravais lattice.
//
// Every 3x3 matrix m with entries in {-1, 0, +1} is tested in a reduced
// basis of the lattice (see minimize_basis). m is accepted when
// |det m| = 1 and it preserves the metric g = ~a * a, i.e. ~m g m == g
// within MIN_SYMM_TOL. Accepted matrices act on lattice coordinates
// (column vectors, as used by basis_symmetries). They are converted back
// to the original basis as t * m * inv(t), where t is the transmission
// matrix produced by the reduction.
//
// INPUT:
//    lattice - lattice->latvec supplies the lattice-vector matrix
// OUTPUT:
//    bnrot - number of the symmetries (rotations) found
//    bsym  - array of the symmetry matrices; must have room for at least
//            48 entries (symmetries() passes a matrix3[48])
//
void 
bravais_symmetries(Lattice *lattice, int &bnrot, matrix3 *bsym)
{
  // 48 = order of the full cubic point group, the largest any 3D lattice has
  const int max_symmetries = 48;
  int i; 
  matrix3 a, g, m, mgm, inv_t, t(1.0, 1.0, 1.0), identity(1.0, 1.0, 1.0);
  
  /* work on a copy of the lattice-vector matrix; minimize_basis reduces it
     in place and accumulates the change of basis in t */
  a = lattice->latvec;

  /* find reduced basis and transmission matrix t */
  minimize_basis(a, t);

  /* print transmission matrix if it is not equal to the identity */
  if (matcmp(t, identity) > MIN_SYMM_TOL) {
    dft_log("\ntransmission matrix:\n\n");
    t.print(dft_global_log,"%4.0f ");
    dft_log("\nwith the corresponding reduced lattice vectors:\n\n");
    a.print(dft_global_log,"%12.6f ");
  }

  /* calculate metric */
  g = (~a)*a;

  /* test all possible symmetries */
  bnrot = 0;
  for (m.m[0][0] = -1.0; m.m[0][0] <= 1.0; m.m[0][0]++)
    for (m.m[0][1] = -1.0; m.m[0][1] <= 1.0; m.m[0][1]++)
      for (m.m[0][2] = -1.0; m.m[0][2] <= 1.0; m.m[0][2]++)
        for (m.m[1][0] = -1.0; m.m[1][0] <= 1.0; m.m[1][0]++)
          for (m.m[1][1] = -1.0; m.m[1][1] <= 1.0; m.m[1][1]++)
            for (m.m[1][2] = -1.0; m.m[1][2] <= 1.0; m.m[1][2]++)
              for (m.m[2][0] = -1.0; m.m[2][0] <= 1.0; m.m[2][0]++)
                for (m.m[2][1] = -1.0; m.m[2][1] <= 1.0; m.m[2][1]++)
                  for (m.m[2][2] = -1.0; m.m[2][2] <= 1.0; m.m[2][2]++) {

                    /* determinant of symmetry has to be +-1 */
                    if (fabs(fabs(det3(m)) - 1.0) < MIN_SYMM_TOL) {

                      /* calculate transformed metric */
                      mgm = (~m)*g*m;

                      /* compare original and transformed metric */
                      if (matcmp(g, mgm) < MIN_SYMM_TOL) {
                        /* Guard the caller's fixed-size array. An exact
                           lattice has at most 48 point operations, so this
                           can only trigger when MIN_SYMM_TOL admits
                           near-symmetries. */
                        if (bnrot >= max_symmetries)
                          die("bravais_symmetries: more than %d lattice symmetries found\n",
                              max_symmetries);
                        bsym[bnrot++] = m;
                      }
                    }
                  }

  dft_log("\n%d symmtries of the bravais lattice\n", bnrot);

  /* transform symmetries from reduced basis back to the original basis;
     t does not change in the loop, so invert it only once */
  inv_t = inv3(t);
  for (i = 0; i < bnrot; i++)
    bsym[i] = t*bsym[i]*inv_t;

  dft_log_flush();
}

//
// Greedy reduction of a lattice basis.
//
// For each column col of a (written a_col), every replacement
// a_col + ci*a_p + cj*a_q with ci, cj in {-1, 0, 1} is tried, where p and q
// are the other two columns. A replacement is kept whenever it lowers
// trace(~a * a), the sum of squared column lengths, by more than
// MIN_SYMM_TOL. Sweeps repeat until a full sweep changes nothing.
//
// INPUT/OUTPUT:
//    a - basis matrix, reduced in place
//    t - right-multiplied by every accepted step, so that
//        a_out = a_in * inv(t_in) * t_out (a_out = a_in * t_out when t
//        starts as the identity, as in bravais_symmetries)
//
void 
minimize_basis(matrix3 &a, matrix3 &t)
{
  matrix3 step(1.0, 1.0, 1.0), candidate;
  int changed;

  do {
    changed = FALSE;
    for (int col = 0; col < 3; col++) {
      const int p = (col + 1) % 3;
      const int q = (col + 2) % 3;
      for (int ci = -1; ci <= 1; ci++) {
        for (int cj = -1; cj <= 1; cj++) {
          // step = identity except column col = e_col + ci*e_p + cj*e_q
          step.m[col][col] = 1.0; step.m[col][p] = 0.0; step.m[col][q] = 0.0;
          step.m[p][col]   = ci;  step.m[p][p]   = 1.0; step.m[p][q]   = 0.0;
          step.m[q][col]   = cj;  step.m[q][p]   = 0.0; step.m[q][q]   = 1.0;

          candidate = a * step;
          if (trace3((~candidate)*candidate) + MIN_SYMM_TOL < trace3((~a)*a)) {
            changed = TRUE;
            a = candidate;
            t = t * step;
          }
        }
      }
    }
  } while (changed);
}

//
// Distance between two 3x3 matrices: the sum of the absolute values of
// their element-wise differences (entrywise L1 norm of a - b). Identical
// matrices give 0; callers compare the result against MIN_SYMM_TOL.
//
real 
matcmp(matrix3 a, matrix3 b)
{
  real total = 0.0;

  for (int row = 0; row < 3; row++)
    for (int col = 0; col < 3; col++)
      total += fabs(a.m[row][col] - b.m[row][col]);

  return total;
}


//
// Finds the subgroup of the input symmetries that map the unit-cell basis
// onto itself when the point of symmetry is placed at tr.
//
// Atom positions are shifted by -tr. An operation is accepted if, for every
// atom of every species, the transformed position coincides with some atom
// of the same species. Coincidence is tested modulo whole-cell translations
// (nearest periodic image), within MIN_SYMM_TOL.
//
// INPUT:
//     ioninfo - information about the ions (unit cell basis)
//     bnrot   - number of symmetries in the input group
//     bsym    - array of symmetry matrices (input symmetry group)
//     tr      - center of symmetry to check, in the same coordinates as
//               atpos (folded modulo 1, so presumably lattice coordinates)
// OUTPUT:
//     tnrot   - number of elements in the subgroup
//     tsym    - symmetry matrices of the subgroup, in input order
//               (writing is safe even if tsym is the same array as bsym)
//
void 
basis_symmetries(Ioninfo *ioninfo, int &bnrot, matrix3 *bsym, 
		      vector3 tr, int &tnrot, matrix3 *tsym)
{
  vector3 **pos, new_pos, diff;
  int sp, natoms, n, n1, irot, i, ok, found;
  int nin = bnrot, nout = 0;   // locals, in case bnrot and tnrot alias

  /* get memory for shifted atom positions */
  pos  = (vector3**)mymalloc(sizeof(vector3*)*ioninfo->nspecies, 
                             "pos", "basis_symmetries");

  for (sp = 0; sp < ioninfo->nspecies; sp++) {
    natoms = ioninfo->species[sp].natoms;
    pos[sp] = (vector3*)mymalloc(sizeof(vector3)*natoms, 
                                 "pos[]", "basis_symmetries");

    for (n = 0; n < natoms; n++) {
      /* move atoms to new symmetry point and fold into the unit cell */
      pos[sp][n] = ioninfo->species[sp].atpos[n] - tr;
      for (i = 0; i < 3; i++) {
        pos[sp][n].v[i] = fmod(pos[sp][n].v[i], 1.0);
        if (pos[sp][n].v[i] < 0.0)
          pos[sp][n].v[i]++;
      }
    }
  }

  /* keep each operation that maps every atom onto an equivalent atom */
  for (irot = 0; irot < nin; irot++) {
    ok = TRUE;
    for (sp = 0; sp < ioninfo->nspecies && ok; sp++) {
      natoms = ioninfo->species[sp].natoms;
      for (n = 0; n < natoms && ok; n++) {
        new_pos = bsym[irot]*pos[sp][n];

        /* look for an equivalent atom of the same species */
        for (found = FALSE, n1 = 0; n1 < natoms && !found; n1++) {
          diff = new_pos - pos[sp][n1];
          /* Nearest periodic image: bring each component into [-0.5, 0.5]
             so that coordinates on either side of a cell face (e.g. 0.0 and
             0.9999999) compare as equal. Plain folding into [0,1) does not
             achieve this. */
          for (i = 0; i < 3; i++)
            diff.v[i] -= floor(diff.v[i] + 0.5);
          if (abs(diff) < MIN_SYMM_TOL)
            found = TRUE;
        }
        ok = found;
      }
    }

    /* nout <= irot, so this never overwrites an unprocessed bsym entry */
    if (ok)
      tsym[nout++] = bsym[irot];
  }
  tnrot = nout;

  for (sp = 0; sp < ioninfo->nspecies; sp++)
    myfree(pos[sp]);
  myfree(pos);
}

//
// Verifies that every operation in symm->sym maps each atom onto an atom of
// the same species, with the rotation taken about the coordinate origin.
// Positions are compared modulo whole-cell translations (nearest periodic
// image), within MIN_SYMM_TOL. On the first failure, logs the offending
// symmetry/species/atom indices and calls die().
//
// The previous version folded only the rotated position into [0,1) and
// compared it with the raw atpos. Atoms with coordinates outside [0,1)
// (e.g. negative after symmetries() shifts them to the symmetry center), or
// sitting on a cell face, were therefore reported as mismatches.
//
void 
check_symmetries(Ioninfo *ioninfo, Symmetries *symm)
{
  int irot, sp, natoms, n, i, n1, found;
  vector3 new_pos, diff;

  for (irot = 0; irot < symm->nrot; irot++) {
    for (sp = 0; sp < ioninfo->nspecies; sp++) {
      natoms = ioninfo->species[sp].natoms;
      for (n = 0; n < natoms; n++) { 
        new_pos = symm->sym[irot]*ioninfo->species[sp].atpos[n];

        /* look if there is an equivalent atom (nearest periodic image) */
        for (found = FALSE, n1 = 0; n1 < natoms && !found; n1++) {
          diff = new_pos - ioninfo->species[sp].atpos[n1];
          for (i = 0; i < 3; i++)
            diff.v[i] -= floor(diff.v[i] + 0.5);
          if (abs(diff) < MIN_SYMM_TOL)
            found = TRUE;
        }

        if (!found) {
          dft_log(DFT_SILENCE,
                  "symmetry: %d  species: %d  atom: %d\n", irot, sp, n);
          die("Symmetries do not agree with atomic positions!\n");
        }
      }
    }
  }
}


/*
 * fold_kpoints
 *
 * Fold the k-points onto a kpt_fold[0] x kpt_fold[1] x kpt_fold[2] mesh.
 * Each input k-point k yields total_fold points with components
 * (k_j + i_j) / kpt_fold[j], i_j = 0 .. kpt_fold[j]-1, each carrying
 * weight w / total_fold.
 *
 * Returns TRUE when folding was performed. In that case it sets new_nkpts
 * and points *new_kvec / *new_w at newly mymalloc'ed arrays; this function
 * does not free them (presumably the caller does). Returns FALSE, leaving
 * all outputs untouched, when every fold factor is 1 or any factor is
 * non-positive.
 *
 * Side effect: when folding is performed, the components of old_kvec are
 * first wrapped into the unit interval in place.
 *
 * Ref:
 *  H.J.Monkhorst, J.D.Pack, PRB 13, 5188, 1976
 *
 */
int
fold_kpoints(vector3 *old_kvec, vector3 **new_kvec, real *old_w, real **new_w,
	     const int *kpt_fold, int nkpts, int &new_nkpts)
{

  int i[3], j, k;
  vector3 *kvec1, *kvec0 = old_kvec;
  real *w1, *w0 = old_w; 

  // Every fold factor must be positive. Checking only the product would
  // accept e.g. (-2,-2,1): its product 4 is positive, but the mesh loops
  // below would never execute and uninitialized arrays would be returned.
  if (kpt_fold[0] <= 0 || kpt_fold[1] <= 0 || kpt_fold[2] <= 0) {
    dft_log(DFT_SILENCE,
            "Why would you want fold to be 0? %d %d %d\n",
            kpt_fold[0], kpt_fold[1], kpt_fold[2]);
    return FALSE;
  }

  int total_fold = kpt_fold[0] * kpt_fold[1] * kpt_fold[2];
  if (total_fold == 1)
    return FALSE;   // nothing to fold

  // move all components of kpoints to between 0 and 1
  for (k = 0; k < nkpts; k++) {
    for (j = 0; j < 3; j++) {
      kvec0[k].v[j] = fmod(kvec0[k].v[j], 1.0);
      if (kvec0[k].v[j] < 0.0)
        kvec0[k].v[j] += 1.0;
    }
  }

  // allocate the folded lists (returned to the caller)
  kvec1 = (vector3 *)mymalloc(sizeof(vector3)*total_fold*nkpts,
                              "kvec","fold_kpoints");
  w1    = (real *)   mymalloc(sizeof(real)*total_fold*nkpts,
                              "w","fold_kpoints");

  int new_n = 0;
  for (i[0] = 0; i[0] < kpt_fold[0]; i[0]++)
    for (i[1] = 0; i[1] < kpt_fold[1]; i[1]++)
      for (i[2] = 0; i[2] < kpt_fold[2]; i[2]++)
        for (k = 0; k < nkpts; k++) {
          for (j = 0; j < 3; j++) {
            if (kpt_fold[j] > 1) {
              kvec1[new_n].v[j] =
                (kvec0[k].v[j] + i[j])/kpt_fold[j];
            } else {
              kvec1[new_n].v[j] = kvec0[k].v[j];
            }
          }
          w1[new_n] = w0[k]/total_fold;
          new_n++;
        }

  // modify the number of kpoints.
  new_nkpts = nkpts * total_fold;

  // get the kvec and w to point to the newly folded lists.
  *new_kvec = kvec1;
  *new_w    = w1;

  // output all folded kpoint coordinates (previously only the first nkpts)
  dft_log("kpoint folding with mesh: %d x %d x %d\n", 
          kpt_fold[0], kpt_fold[1], kpt_fold[2]);
  for (k = 0; k < new_nkpts; k++) {
    dft_log("%5d\t[ %4.2f %4.2f %4.2f ]  %4.2f\n",k,
            kvec1[k].v[0], kvec1[k].v[1], kvec1[k].v[2], w1[k]);
  }
  dft_log("\n");
  
  return TRUE;
}


/*
 * reduce_kpoints
 *
 * Reduce the kpoints according to system symmetries.
 *
 * Requirement: all components of kvec[] must be between 0 and 1.
 */
void
reduce_kpoints(const vector3 *old_kvec, real *old_w, 
	       vector3 **new_kvec, real **new_w, 
	       int nkpts,
	       int &new_nkpts,
	       const Symmetries &symm,
	       Elecinfo &elecinfo, const Lattice &lattice,
	       int reduce_kpts_flag)
{
  const matrix3 &GGT = lattice.GGT;
  matrix3 invGGT = lattice.invGGT;
  const matrix3 identity(1.0,1.0,1.0);
  matrix3 sym[48];
  vector3 k_new, k_tmp;
  int i, ii, j, k, *found = NULL, nrot = symm.nrot;
  real diff1, diff2, total_w;

  // check that first one is identity
  if ( identity != symm.sym[0] ) {
    die("reduce_kpoints: first symmetry operation is not identity");
  }
  
  /*
   * if not reduce_kpts_flag, then set symmetry numbers to 0.
   * (i.e.  not used)
   * so that the rest of the subroutine just allocates and copy
   * over the kpoint information to elecinfo.
   */
  if ( !reduce_kpts_flag )
    nrot = 0;

  // get the symmetry matrix for k space
  /*
   * Sk = (~A A) Sr (G ~G) / (2pi)^2
   *   note:  (~A A / (2pi)^2) = inv ( G ~G )
   * 
   */
  for (i = 0; i < nrot; i++) {
    sym[i] = invGGT * symm.sym[i] * GGT;
  }

  // initialize found array to false.
  found = (int*)mymalloc(sizeof(int)*nkpts,"found","reduce_kpoints");
  for (i = 1; i < nkpts; i++)
    found[i] = FALSE;

  /* for each original k-point,
   * if not found yet, put it in new k-point list,
   *  for each symmetry, if transformed k-point matches another
   *  original k-point, list that one as found, add its weight 
   *  to the present k-point.
   */
  new_nkpts = 0;
  for (i = 0; i < nkpts; i++) { // for each original k-point,
    if ( !found[i] ) {          // if not yet found.
      new_nkpts++;                 // add to new-kpoint list
      for (j = 0; j < nrot; j++) {  // for each symmetry
        k_new = sym[j] * old_kvec[i];   // produce transformed k.
        // get k_new in betweenn 0..1;
        for (ii = 0; ii < 3; ii++) {
          k_new.v[ii] = fmod(k_new.v[ii], 1.0);
          if (k_new.v[ii] < 0.0)
            k_new.v[ii] += 1.0;
        }
        // compare with remaining kpoints
        for (k = i+1; k < nkpts; k++) {
          // get k_tmp in betweenn 0..1 to compare to k_new;
          for (ii = 0; ii < 3; ii++) {
            k_tmp.v[ii] = fmod(old_kvec[k].v[ii], 1.0);
            if (k_tmp.v[ii] < 0.0)
              k_tmp.v[ii] += 1.0;
          }
          diff1 = abs2(k_new - k_tmp);

          // for now, only use inversion symmetry for NOSPIN case.
          if (elecinfo.spintype == NOSPIN) {
            // check also for inversion symmetry.
            for (ii = 0; ii < 3; ii++) {
              if (k_tmp.v[ii] > 0.0)
                k_tmp.v[ii] -= 1.0;
            }
            diff2 = abs2(k_new + k_tmp);
          } else {
            diff2 = MIN_KPT_DISTANCE * 10;
          }

          if ( ((diff1 < MIN_KPT_DISTANCE) || (diff2 < MIN_KPT_DISTANCE))
               && (! found[k]) ) {
            found[k] = TRUE;
            old_w[i] += old_w[k];
          }
        }
      }
    }
  }

  if (new_nkpts == nkpts)
    dft_log("reduce_kpoints: No reducable k-point discovered.\n");
  else 
    dft_log("reduce_kpoints: number of k-points reduced to %d.\n", new_nkpts);

  // calculate weight renormalization in case error has been accumulated.
  for (i = 0, total_w = 0.0; i < nkpts; i++)
    if ( !found[i] )
      total_w += old_w[i];


  *new_kvec = (vector3 *)mymalloc(sizeof(vector3)*new_nkpts,
                                  "new_kvec",
                                  "reduce_kpoints");
  *new_w    = (real *)   mymalloc(sizeof(real)*new_nkpts,
                                  "new_w","reduce_kpoints");
  
  for (i = j = 0; i < nkpts; i++) {
    if ( !found[i] ) {
      (*new_kvec)[j] = old_kvec[i];
      (*new_w)[j] = old_w[i] / total_w;
      j++;
    }
  }
  

  myfree(found);

  // output the reduced kpoint coordinates
  for (k = 0; k < new_nkpts; k++) {
    dft_log("%5d\t[ %4.2f %4.2f %4.2f ] %4.2f \n",k,
            (*new_kvec)[k].v[0], 
            (*new_kvec)[k].v[1], 
            (*new_kvec)[k].v[2],
            (*new_w)[k]);
  }
  dft_log("\n");

  return;
}


/*
 * map_symm_atom
 *
 * Map atoms to symmetry-related ones (needed when forces are calculated).
 *
 * For every species sp, atom nat and symmetry irot, stores in
 * symm.maps[sp][irot][nat] the index of the atom of the same species onto
 * which symm.sym[irot] maps atom nat. The rotation is taken about the
 * coordinate origin. Positions are compared modulo whole-cell translations
 * (each difference component is moved to its nearest image), within
 * MIN_SYMM_TOL. symm is const, but maps is written through its pointer.
 *
 * Calls die() when an image has no matching atom. The resulting map is
 * logged only when there is more than one symmetry. Always returns 1.
 */
int
map_symm_atom(Ioninfo &ioninfo, const Symmetries &symm)
{
  const int nrot = symm.nrot;
  Speciesinfo *species = ioninfo.species;

  for (int sp = 0; sp < ioninfo.nspecies; sp++) {
    const int natoms = species[sp].natoms;
    for (int nat = 0; nat < natoms; nat++) {
      for (int irot = 0; irot < nrot; irot++) {
        vector3 image = symm.sym[irot]*species[sp].atpos[nat];
        int match = -1;

        for (int cand = 0; cand < natoms && match < 0; cand++) {
          vector3 d = image - species[sp].atpos[cand];
          for (int c = 0; c < 3; c++) {
            d.v[c] = fmod(d.v[c], 1.0);
            if (fabs(d.v[c]) > fabs(d.v[c]+1.0))
              d.v[c]++;
            else if (fabs(d.v[c]) > fabs(d.v[c]-1.0))
              d.v[c]--;
          }
          if (abs(d) < MIN_SYMM_TOL)
            match = cand;
        }

        if (match < 0)
          die("Species %d, atom %d, symm %d not found!!\n",
              sp, nat, irot);
        else
          symm.maps[sp][irot][nat] = match;
      }
    }
  }

  if (nrot > 1) {
    dft_log("Mapping of atoms according to symmetries:\n");
    for (int sp = 0; sp < ioninfo.nspecies; sp++) {
      for (int nat = 0; nat < species[sp].natoms; nat++) {
        dft_log("%3d %3d : ",sp,nat);
        for (int irot = 0; irot < nrot; irot++)
          dft_log(" %3d",symm.maps[sp][irot][nat]);
        dft_log("\n");
      }
    }
    dft_log_flush();
  }

  return 1;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */