#!python
"""
Grid handler for different formats
"""
from __future__ import print_function, division

import sys, os, os.path as osp
import argparse as arg

# Flag which is True when the interpreter is Python 3 (any 3.x release),
# used for python3 compliance checks
is_3 = sys.version_info >= (3,0)

import numpy as np

import sisl

# For transforming a string to a range
import re
# Matches one entry (digits and dashes) with optional surrounding commas
re_rng = re.compile(r'[,]?([0-9-]+)[,]?')

# Function to change a string to a range of atoms
def parse_rng(s):
    """ Parses a string into a list of zero-based indices

    The string holds one-based, comma separated entries, each being
    either a single number or an inclusive range ``first-last``:
      1,2,3-6
    in which case it will return the zero-based indices:
      [0,1,2,3,4,5]

    Parameters
    ----------
    s : str
       the range specification to parse

    Returns
    -------
    list of int
       zero-based indices, in the order the entries occur in `s`.
       Entries with more than one ``-`` are reported on stdout and skipped.

    Raises
    ------
    ValueError
       if an entry has an empty bound (e.g. ``-3`` or ``3-``), as the
       empty string cannot be converted to an integer.
    """
    rng = []
    for la in re_rng.findall(s):
        # We accumulate a list of integers
        tmp = la.split('-')
        if len(tmp) > 2:
            print('Error in parsing: "'+s+'".')
        elif len(tmp) > 1:
            bo, eo = map(int, tmp)
            # one-based inclusive [bo, eo] -> zero-based [bo-1, eo-1]
            rng.extend(range(bo - 1, eo))
        else:
            # Convert to integer *before* shifting to zero-based
            rng.append(int(tmp[0]) - 1)
    return rng

def file_idx(s):
    """ Returns a tuple of file, variable name and index of the variable in the file.

    The specification is ``<file>[:<var>[,<idx>]]`` or ``<file>:<idx>``.

    Example:

       SIESTA.nc:Rho,1 returns the `Rho` variable at the spin-index 1.
       SIESTA.nc:Rho returns the `Rho` variable at the spin-index 0 (default).
       SIESTA.nc:1 returns no variable (``None``) at the spin-index 1.
       SIESTA.nc returns no variable (``None``) at the spin-index 0; per the
       original note the subsequent read errors unless only one grid exists
       in the file (not verifiable from this function).

    Parameters
    ----------
    s : str
       the file specification

    Returns
    -------
    tuple
       ``(file, var, idx)`` where `var` is a ``str`` or ``None`` and `idx`
       is always an ``int``.

    Raises
    ------
    ValueError
       if the index in ``<var>,<idx>`` is not an integer.
    """
    var = None
    idx = 0
    if ':' not in s:
        return (s, var, idx)

    # Only the last ':' separates the variable specification, this
    # allows file names which themselves contain a ':'
    f, var_idx = s.rsplit(':', 1)

    if ',' in var_idx:
        var, idx = var_idx.split(',', 1)
        # Keep the index an integer as in all other branches
        idx = int(idx)
    else:
        # If var_idx is integer, then idx, else var
        try:
            idx = int(var_idx)
        except ValueError:
            var = var_idx

    # Return tuple
    return (f, var, idx)


def run():
    """ Command-line entry point.

    Builds the argument parser, parses ``sys.argv`` and calls the
    function stored as the ``func`` default (`grid_action`) with the
    parsed arguments.
    """

    # The first positional argument of ArgumentParser is `prog`, so the
    # text has to be passed explicitly as the description.
    p = arg.ArgumentParser(description='Manipulates grid functions in commonly encounterd files.')

    def dir_type(val, check_type):
        """ Convert `val` to a direction index (0, 1, 2) or via `check_type` """
        # Only a *single* character names a direction; a plain substring
        # test would also accept '' and e.g. 'xy'.
        if len(val) == 1 and val in 'xyzabc':
            # x/a -> 0, y/b -> 1, z/c -> 2
            return 'xaybzc'.index(val) // 2
        try:
            return check_type(val)
        except (TypeError, ValueError):
            raise arg.ArgumentTypeError("Type: "+val+" not valid type "+str(check_type.__name__))

    def dir_int(val): return dir_type(val, int)
    def dir_float(val): return dir_type(val, float)
    def dir_angle(val):
        """ Convert `val` to a direction index or an angle in radians

        A value containing ``r`` is taken as radians, otherwise the
        value is taken as degrees (an optional ``d``/``a`` is stripped).
        """
        if len(val) == 1 and val in 'xyzabc':
            return 'xaybzc'.index(val) // 2
        try:
            # str.replace requires a string replacement
            val = val.replace('pi', repr(np.pi))
            # NOTE(review): eval on command-line input executes arbitrary
            # expressions; acceptable only because the user runs this on
            # their own behalf. Consider a restricted expression parser.
            if 'r' in val:
                # Radians
                tmp = eval(val.replace('r',''))
            else:
                tmp = eval(val.replace('d','').replace('a','')) / 180 * np.pi
        except Exception:
            # eval may raise any kind of exception on malformed input
            raise arg.ArgumentTypeError("Type: "+val+" not valid type float")
        return tmp

    def sub_check(sub):
        """ Checks that the `sub` is either above(+)/below(-) """
        valid_actions = ('below', 'above', '-', '+')
        if sub not in valid_actions:
            raise ValueError('Argument not one of '+' '.join(valid_actions))
        # Returns -1,+1 for below/above
        return (valid_actions.index(sub) % 2) * 2 - 1


    # Associated geometry to the grid file
    p.add_argument('-g','--geometry', type=str, default=None,
                   help='Geometry associated with this grid.')

    # Other files
    p.add_argument('-d','--diff', type=str, metavar='file', default=None,
                   help='Difference grid file.')

    # Take the mean value in a direction
    # (argparse calls `type` with a single argument, hence dir_int and
    #  not the two-argument dir_type)
    p.add_argument('-m','--mean','--average', type=dir_int, metavar='dir', default=None,
                   help='Take the average value in the directions specified.')
    p.add_argument('--sum', type=dir_int, metavar='dir', default=None,
                   help='Sum grid in the directions specified.')


    class ValidateGridPart(arg.Action):
        """ Validates a (PART, dir, val) triple and appends it to the destination list """
        def __call__(self, parser, args, values, option_string=None):
            action, direction, value = values
            try:
                action = sub_check(action)
            except ValueError:
                print("error: argument "+'/'.join(self.option_strings)+": invalid sub-part choice: {e!r}".format(e=action))
                sys.exit(1)
            try:
                axis = dir_int(direction)
            except (ValueError, arg.ArgumentTypeError):
                print("error: argument "+'/'.join(self.option_strings)+": invalid direction choice: {e!r}".format(e=direction))
                sys.exit(1)

            # The default is None, create the list on first use; `value`
            # is kept as a string and interpreted when the grid is known.
            parts = getattr(args, self.dest) or []
            parts.append((action, axis, value))
            setattr(args, self.dest, parts)

    # We need a way to cut the cube file to a more manageable size...
    p.add_argument("-gp","--grid-part",nargs=3,metavar=('PART','dir','val'),
                   dest="grid_part",action=ValidateGridPart,
                   default=None,
                   help="Reduce the size of the grid by retaining the part specified. Multiple consecutive options are allowed (applied in order). (above/+|below/-,xa|yb|zc,<pos>|<frac>f)")

    p.add_argument("-gr","--grid-remove",nargs=3,metavar=('PART','dir','val'),
                   dest="grid_remove",action=ValidateGridPart,
                   default=None,
                   help="Removes parts of the grid. Multiple consecutive options are allowed (applied in order). (above/+|below/-,xa|yb|zc,<pos>|<frac>f)")


    p.set_defaults(func=grid_action)

    p.add_argument('in_file',metavar='infile',type=str,
                   help='Read grid from file, if NetCDF file INFILE[var,idx] denotes variable and index.')
    p.add_argument('out_file',metavar='outfile',type=str,default='out.cube',
                   help='Write grid to file, append for more out files.')

    args = p.parse_args()
    args.func(args)

def _grid_cut_index(grid, axis, val):
    """ Returns the grid index along `axis` for the user supplied position `val`

    `val` is a string; if it contains an ``f`` it is a fraction which is
    scaled by ``np.sum(grid.cell[:,axis])``, otherwise it is used as
    the coordinate directly.
    """
    if 'f' in val:
        coord = np.sum(grid.cell[:,axis]) * float(val.replace('f',''))
    else:
        coord = float(val)
    return grid.index(coord, axis=axis)


def grid_reducing(grid,args):
    """ Process all grid-reductions on the grid.

    This is done as several grids can be read in and processed
    and a standard routine to handle all of them is efficient.

    The reductions are applied in this order: ``args.grid_part``,
    ``args.grid_remove``, ``args.mean`` and ``args.sum``.

    Parameters
    ----------
    grid : object
       the grid to reduce; must provide ``cell``, ``index``, ``sub_part``,
       ``remove_part``, ``mean`` and ``sum`` (presumably a `sisl.Grid`)
    args : object
       parsed arguments with the attributes ``grid_part``, ``grid_remove``
       (``None`` or lists of ``(above/below sign, axis, value)``),
       ``mean`` and ``sum`` (``None`` or a direction index)

    Returns
    -------
    the reduced grid
    """

    # First reduce size
    if args.grid_part: # list of tuples
        for ab, axis, val in args.grid_part:
            # The index is computed on the current (already reduced) grid
            idx = _grid_cut_index(grid, axis, val)
            # Reduce grid
            grid = grid.sub_part(idx,axis=axis,above=ab>0)

    if args.grid_remove: # list of tuples
        for ab, axis, val in args.grid_remove:
            idx = _grid_cut_index(grid, axis, val)
            # Reduce grid
            grid = grid.remove_part(idx,axis=axis,above=ab>0)

    # Direction 0 (x/a) is a valid choice, hence compare against None
    # instead of relying on the truth value.
    if args.mean is not None:
        grid = grid.mean(args.mean)

    if args.sum is not None:
        grid = grid.sum(args.sum)

    return grid


def grid_action(args):
    """ Performs the actions requested on the command line (``sgrid $@``)

    Reads the input grid, optionally replaces its geometry, applies the
    grid reductions, optionally subtracts a (equally reduced) difference
    grid and finally writes the result to ``args.out_file``.
    """
    # `file_idx` splits the specification into (file, variable, index)
    grid = sisl.Grid.read(*file_idx(args.in_file))

    # A user supplied geometry must be attached before the grid
    # size is changed by any reduction or averaging.
    if args.geometry:
        grid.set_geom(sisl.Geometry.read(args.geometry))

    grid = grid_reducing(grid, args)

    if args.diff:
        # The difference grid undergoes the same reductions; as it is
        # never bound to a name it is released right after subtraction.
        grid -= grid_reducing(sisl.Grid.read(*file_idx(args.diff)), args)

    # Store the resulting grid
    grid.write(args.out_file)


# Only parse the command line when executed as a script, not on import
if __name__ == "__main__":
    run()
