#!/usr/bin/env python

# Generate the CDML metadata associated with a dataset.

import sys
import os
import numpy
try:
    import cdmsNode
    from cdmsNode import NumericToCdType
except ImportError:
    from cdms import cdmsNode
    from cdms.cdmsNode import NumericToCdType
import cdms
import getopt
import string
import re
import cdtime

# Module-level state shared by main(), merge() and extend().
_fudgeFactor = 0.0                     # Two axis vectors a,b are 'equal' if abs((b-a)/b)<=fudgeFactor, element-wise
verbose = 0                             # If true, print messages.

# Date/time of the form yyyy-mm-dd hh:mm:ss with trailing fields optional.
# Each optional field is nested inside the previous one, so a field can only
# be present if all the fields before it are present.
dateString = re.compile(r"(?P<year>\d+)(-(?P<month>\d{1,2})(-(?P<day>\d{1,2})(\s+(?P<hour>\d{1,2})(:(?P<minute>\d{1,2})(:(?P<second>\d{1,2}))?)?)?)?)?\Z")

# Maps the calendar names accepted by the -c option to cdtime calendar constants.
calendarMap = {'gregorian': cdtime.GregorianCalendar,
               'julian': cdtime.JulianCalendar,
               'noleap': cdtime.NoLeapCalendar,
               '360': cdtime.Calendar360}


# The errors below used to be plain strings. String exceptions cannot be raised
# on Python 2.6 and later, so they are real exception classes now. The existing
# 'raise Name, detail' statements keep working unchanged: Python instantiates
# the class with 'detail' as its argument.
class CdimportError(Exception):
    """Base class of the errors raised by this script.

    str() of an instance is the class 'prefix' followed by the detail value
    given when the error was raised (if any).
    """
    prefix = ""

    def __init__(self, detail=None):
        Exception.__init__(self, detail)
        self.detail = detail

    def __str__(self):
        if self.detail is None:
            return self.prefix
        return "%s%s"%(self.prefix, self.detail)

class UnsupportedTypecode(CdimportError):
    """An axis or variable has a typecode with no entry in NumericToCdType."""
    prefix = "Unsupported Numeric datatype: "

class CannotRenameAxis(CdimportError):
    """No unused (or matching) renamed axis ID could be found during a merge."""
    prefix = "Cannot rename axis: "

class CannotExtend(CdimportError):
    """Two datasets are not compatible for extension along a dimension."""
    prefix = "Cannot extend a dataset: "

class InvalidTime(CdimportError):
    """The -a argument does not match the dateString pattern."""
    prefix = "Time must be of the form yyyy-mm-dd hh:mi:ss "

class InvalidCalendar(CdimportError):
    """The -c argument is not one of the supported calendar names."""
    prefix = "Calendar must be 'gregorian', 'julian', 'noleap', or '360' "

# Help text for the command line. main() prints it for the -h option, when
# getopt rejects the options, and when the number of positional arguments
# is not three.
usage = """Usage: cdimport [-dghjmrsv] [-a time0] [-c calendar] [-e vectorEpsilon] [-l levelName] [-n variableName] [-t timeName] [-u variableName] [-x xmlFile] <directory> <template> <datasetId>

Arguments:

    <directory> is the root directory of the dataset. <directory> may not contain template specifiers.

    <template> is the file template which specifies the dataset partitioning. It is a path-name, relative to the directory, containing template specifiers. For example, the template '%v/rnl_ncep/%v.rnl_ncep.%Y.ctl' contains the specifiers %v (variable ID), and %Y (year). A specifier of the form %eX refers to the ending timepoint/level in the file. For example, %en represents the last month in the file. Specifiers include:

        %d day number (1 .. 31)
        %eX ending timepoint/level, where X is a specifier character
        %g month, lower case, three characters ('jan', 'feb', ...)
        %G month, upper case, three characters ('JAN', 'FEB', ...)
        %H hour (0 .. 23)
        %L vertical level (integer)
        %m month number, not zero filled (1 .. 12)
        %M minute 0 .. 59
        %n month number, two-digit, zero-filled (01, 02, ..., 12)
        %S second (0 .. 59)
        %v variable ID (string)
        %y year, two-digit, zero-filled (integer)
        %Y year (integer)
        %z Zulu time (ex: '6Z19990201')
        %% percent sign       

    <datasetId> is a string identifier for the dataset.

Options:

   -a: adjust the time dimension values so that the first timepoint is time0.
       time0 has the form "yyyy-mm-dd hh:mi:ss" where trailing fields may be omitted.
       Example: -a '1982-10-1'. All values are adjusted by the same increment.
   -c: specify the calendar, e.g., 'gregorian', 'julian', 'noleap', or '360'
   -d: save output after each variable is scanned (default = false)
   -e: specify equality of axes (default = 0.0). See note below.
   -g: allow gaps in extended, linear time dimensions (implies -s)
   -h: print help
   -j: ingest levels as increasing (default is decreasing)
   -l: name of the extended level dimension (default='level')
   -m: ingest times as decreasing (default is increasing)
   -n: skip a variable (option may be used more than once)
   -r: assert that the extended time dimension is NOT relative time (time is relative by default)
   -s: assert that the extended time dimension is linear (default = false)
   -t: name of the extended time dimension (default='time')
   -u: flag variableName as duplicated in all files (option may be used more than once,
       bounds, weights, and other associated variables are included by default)
       Also, a comma-separated list of variables may be specified, with no spaces.
       Example: -u weights_latitude,bounds_latitude
   -v: verbose (off by default)
   -x: save output as an XML file (write to standard output by default)

Example:

    Ingest the six-hourly NCEP reanalysis data. Treat time as linear,
    with possible gaps in the time dimension. Skip variable rsut.
    Set the dataset ID to ncep_reanalysis_6h.

    cdimport -g -v -n rsut -x ncep_reanalysis_6h.xml /pcmdi/obs/6h/ %v/rnl_ncep/%v.rnl_ncep.%Y.ctl ncep_reanalysis_6h

Notes:

   Note: vectorEpsilon is interpreted as follows: two axis vectors x and y are treated as 'equal'
     iff for each element a(n), b(n), abs(a(n)-b(n))<=abs(a(n)*vectorEpsilon).
     Set vectorEpsilon=0 to require true equality of vectors.

   See also: http://www-pcmdi.llnl.gov/software/cdms/
"""

def main(argv):
    global _fudgeFactor, verbose

    try:
        args, lastargs = getopt.getopt(argv[1:],"a:c:de:ghjl:mn:rst:u:vx:")
    except getopt.error:
        print sys.exc_value
        print usage
        sys.exit(0)

    adjustCalendar = 0
    allowGaps = 0
    calendar = None
    decreasingLevels = 1
    duplicateVars = []
    extendedLevel = 'level'
    extendedTime = 'time'
    increasingTimes = 1
    saveEachIteration = 0
    setCalendar = 0
    skipVars = []
    timeIsLinear = 0
    timeIsRelative = 1
    writeToStdout = 1
    verbose = 0
    cwd = os.getcwd()
    for flag,arg in args:
        if flag=='-a':
            baseTimeString = arg
            adjustCalendar = 1
            baseTimeMatch = dateString.match(arg)
            if baseTimeMatch is None:
                raise InvalidTime, baseTimeString
        elif flag=='-c':
            calendar = string.lower(arg)
            setCalendar = 1
            if calendar not in ['gregorian', 'julian', 'noleap', '360']:
                raise InvalidCalendar, calendar
        elif flag=='-d': saveEachIteration = 1
        elif flag=='-e':
            _fudgeFactor = string.atof(arg)
        elif flag=='-g':
            allowGaps = 1
            timeIsLinear = 1
        elif flag=='-h': print usage; sys.exit(0)
        elif flag=='-j': decreasingLevels = 0
        elif flag=='-l': extendedLevel = arg
        elif flag=='-m': increasingTimes = 0
        elif flag=='-n': skipVars.append(arg)
        elif flag=='-r': timeIsRelative = 0
        elif flag=='-s': timeIsLinear = 1
        elif flag=='-t': extendedTime = arg
        elif flag=='-u': duplicateVars = duplicateVars + string.split(arg,',')
        elif flag=='-v': verbose = 1
        elif flag=='-x':
            writeToStdout = 0
            xmlpath = arg

    if timeIsLinear==1:
        linearDimList = [extendedTime]
    else:
        linearDimList = None

    if len(lastargs)!=3:
        print 'Wrong number of arguments'
        print 'Invalid command line:',argv
        print usage; sys.exit(1)
    direc = lastargs[0]
    template = lastargs[1]
    datasetid = lastargs[2]

    if verbose:
        print argv
        print 'Looking for matching files (this may take a while...)'

    filelist = cdms.matchingFiles(direc,template)

    # Variables, times (time,etime), levels (level,elevel)
    vardict = {}
    timedict = {}
    leveldict = {}
    
    for (path,matchnames) in filelist:
        var,time,etime,level,elevel = matchnames
        vardict[var] = 1
        timedict[str((time,etime))] = (time,etime)
        leveldict[(level,elevel)] = 1
    varlist = vardict.keys(); varlist.sort()
    timelist = timedict.values(); timelist.sort()
    if not increasingTimes: timelist.reverse()
    levellist = leveldict.keys(); levellist.sort()
    if decreasingLevels: levellist.reverse()

    if verbose:
        print '*** Generating dataset',datasetid,'from directory',direc,', template',template,'***'
        print 'List of variables found (in filenames):', varlist
        print 'List of timepoints found (in filenames):', timelist
        print 'List of levels found (in filenames):', levellist

    # The main loop
    # for each variable
    # - for each time (first,last)
    #   - for each level (first,last)
    #     - generate the filename for var,time,level using the template
    #     - scan the file, returning a dataset D4
    #     - extend D3 with D4, in the level dimension
    #   - extend D2 with D3 (all levels for this var, time), in the time dimension
    # - merge D with D2 (all levels, times for this var)
    # resulting in dataset D

    dset = None
    for var in varlist:
        if var in skipVars:
            if verbose:
                print 'Skipping variable',var
            continue
        dset2 = None
        for (time,etime) in timelist:
            dset3 = None
            for (level,elevel) in levellist:
                matchnames = (var,time,etime,level,elevel)
                relpath = cdms.getPathFromTemplate(template, matchnames)
                fullpath = os.path.join(direc,relpath)
                if verbose:
                    print 'Scanning',fullpath
                if os.path.exists(fullpath):
                    dset4 = scan(datasetid,fullpath,template,direc,linearDimList,matchnames,extendedTime,extendedLevel)
                    dset3 = extend(dset3,dset4,extendedLevel)
                elif allowGaps and verbose:
                    print 'File not found, was skipped:',fullpath
                else:
                    raise IOError, 'No such file or directory: %s'%fullpath
            dset2 = extend(dset2,dset3,extendedTime,timeIsRelative,linearDimList,allowGaps,'T')
        dset = merge(dset,dset2,template,duplicateVars)
        if saveEachIteration==1:
            if writeToStdout:
                dset.dump()
            else:
                os.chdir(cwd)               # Restore current directory
                dset.dump(xmlpath)
                if verbose:
                    print xmlpath,'written'

    if dset is None:
        if verbose:
            print 'No data found'
            sys.exit(1)

    # Add the Conventions attribute if not present
    conventions = dset.getExternalAttr('Conventions')
    if conventions is None: dset.setExternalAttr('Conventions','')

    # Set calendar attribute on all time axes and/or adjust time values.
    if setCalendar or adjustCalendar:
        if not setCalendar: calendar = 'gregorian'
        if adjustCalendar:
            year = baseTimeMatch.group('year')
            timeargs = [string.atoi(year)]
            month = baseTimeMatch.group('month')
            if month!=None: timeargs.append(string.atoi(month))
            day = baseTimeMatch.group('day')
            if day!=None: timeargs.append(string.atoi(day))
            hour = baseTimeMatch.group('hour')
            if hour!=None: timeargs.append(string.atoi(hour))
            minute = baseTimeMatch.group('minute')
            if minute!=None: timeargs.append(string.atoi(minute))
            second = baseTimeMatch.group('second')
            if second!=None: timeargs.append(string.atoi(second))
            baseTime = apply(cdtime.comptime,tuple(timeargs))
            
        for node in dset.getIdDict().values():
            if node.tag == 'axis':
                id = node.id
                if id is None: id = ''
                if setCalendar:
                    axisAttr = node.getExternalAttr('axis')
                    if (axisAttr is not None and axisAttr=='T') or (len(id)>=4 and string.lower(id[0:4])=='time'):
                        node.setExternalAttr('calendar',calendar)
                if adjustCalendar and node.id==extendedTime:
                    units = node.getExternalAttr('units')
                    cdcalendar = calendarMap[calendar]
                    rightTime = baseTime.torel(units, cdcalendar)
                    wrongTime = node.getData()[0]
                    offset = rightTime.value - wrongTime
                    if node.dataRepresent==cdmsNode.CdLinear:
                        node.data.start = node.data.start + offset
                    else:
                        node.data = node.data + offset

    # Write the output
    if writeToStdout:
        dset.dump()
    else:
        os.chdir(cwd)               # Restore current directory
        dset.dump(xmlpath)
        if verbose:
            print xmlpath,'written'

#---------------------------------------------------------------------------------------------------------------------------------
# Create a datanode from cdunif file 'path'
# 'id' is the dataset identifier
# 'linearDims' is a list of dimensions to treat as linear,
# (regardless of the dimension values)
# 'matchnames' is a list (varname,starttime,endtime,startlevel,endlevel) for this file
# If specified, all variables with a different name than matchnames[0] will have a .template
# attribute set to the appropriate template.
def scan(id,path,template=None,directory=None,linearDims=None,matchnames=None,extendedTime=None,extendedLevel=None):
    # Open the file, create the dataset node
    cdfile = cdms.openDataset(path)
    dataset = cdmsNode.DatasetNode(id)
    if directory == None:
        head,tail=os.path.split(path)
        dataset.setExternalAttr('directory',head)
    else:
        dataset.setExternalAttr('directory',directory)
    if template == None:
        dataset.setExternalAttr('template',tail)
    else:
        dataset.setExternalAttr('template',template)

    # Copy global attributes
    for attname in cdfile.attributes.keys():
        attval = cdfile.attributes[attname]
        if type(attval)==numpy.ndarray:
            attval = attval[0]
        dataset.setExternalAttr(attname,attval)

    # Create axis nodes
    objlist = []                        # list of (node, cdobj)
    for cdaxis in cdfile.axes.values():
        dimarray = cdaxis[:]
        datatype = NumericToCdType.get(cdaxis.typecode())
        if datatype is None: raise UnsupportedTypecode, cdaxis.typecode()
        # A hack: current version of cdmsNode doesn't support CdInt
        if datatype==cdms.CdInt: datatype=cdms.CdLong
        axis = cdmsNode.AxisNode(cdaxis.id, len(cdaxis), datatype)
        dataset.addId(cdaxis.id, axis)
        units = cdaxis.attributes.get('units')
        if units==None: units=''
        axis.setExternalAttr('units',units)
        dimname = cdaxis.id
        if linearDims!=None and dimname in linearDims:
            start = dimarray[0]
            length = len(dimarray)
            if length>1:
                delta = dimarray[1]-dimarray[0]
            else:
                delta = 0.0
            linearNode = cdmsNode.LinearDataNode(start,delta,length)
            axis.setLinearData(linearNode)
        else:
            axis.setData(dimarray)
        objlist.append((axis, cdaxis))

    # Create variable nodes
    for cdvar in cdfile.variables.values():
        domain = cdmsNode.DomainNode()
        for i in range(len(cdvar.domain)):
            cdaxis = cdvar.getAxis(i)
            domElem = cdmsNode.DomElemNode(cdaxis.id, 0, len(cdaxis))
            domain.add(domElem)
        datatype = NumericToCdType.get(cdvar.typecode())
        if not datatype: raise UnsupportedTypecode, cdvar.typecode()
        # A hack: current version of cdmsNode doesn't support CdInt
        if datatype==cdms.CdInt: datatype=cdms.CdLong
        variable = cdmsNode.VariableNode(cdvar.id,datatype,domain)
        dataset.addId(cdvar.id,variable)
        objlist.append((variable, cdvar))
        
    # Copy attributes
    for (obj,cdobj) in objlist:
        variable = obj
        varname = cdobj.id
        for attname in cdobj.attributes.keys():
            attval = cdobj.attributes[attname]
            if attname=='name_in_file' and attval==cdobj.id:
                continue
            if type(attval)==numpy.ndarray:
                attval = attval[0]
            obj.setExternalAttr(attname, attval)

        # If necessary, create a variable template which overrides
        # the file template. This is the case if this variable name
        # does not match the varname for this file.
        if matchnames != None:
            tempVarName = matchnames[0]
            matchcopy = [matchnames[0],matchnames[1],matchnames[2],matchnames[3],matchnames[4]]
            if tempVarName!=None and varname!=tempVarName and variable.tag=='variable':

                domain = variable.getDomain()
                for domElem in domain.children():
                    elemname = domElem.getName()

                    # If the variable is not a function of the extended time dimension,
                    # set the template times to those of this file.
                    # Same for the extended level dimension
                    if elemname==extendedTime:
                        matchcopy[1] = matchcopy[2] = None
                    elif elemname==extendedLevel:
                        matchcopy[3] = matchcopy[4] = None
                altTemplate = cdms.getPathFromTemplate(template,tuple(matchcopy))
                variable.setExternalAttr('template',altTemplate)

    # Cleanup
    cdfile.close()
    return dataset

#---------------------------------------------------------------------------------------------------------------------------------
# Add variables in datanode 'dset2' to those in
# datanode 'dset'. Merge common axes. 'dset2' object names may be altered.
# duplicateVars is a list of variables which are duplicated across
# datasets, hence are ignored on a merge.
def merge(dset,dset2,template=None,duplicateVars=[]):
    global verbose

    suffices = ['','_a','_b','_c','_d','_e','_f','_g','_h','_j','_k','_m','_n','_p','_q','_r','_s','_t','_u','_v','_w','_x','_y','_z']
    if dset is None:
        if template != None and dset2 != None:
            dset2.setExternalAttr('template',template)
        return dset2

    # Create a list of 'duplicate variables', e.g. bounds
    # and weights which are duplicated across per-variable datasets
    # Also include associated variables by default.
    idlist1 = dset.getIdDict().values()
    dupvar = []
    for obj in idlist1:
        if obj.tag!='axis': continue
        axis = obj
        boundsName = axis.getExternalAttr("bounds")
        if boundsName != None:
            dupvar.append(boundsName)
        weightsName = axis.getExternalAttr("weights")
        if weightsName != None:
            dupvar.append(weightsName)
        assocNames = axis.getExternalAttr("associate")
        if assocNames != None:
            dupvar = dupvar + string.split(assocNames)
    dupvar = dupvar + duplicateVars
    if verbose:
        if dupvar != []:
            print 'Ignoring duplicate variables:',dupvar

    axismap = {}
    idlist = dset2.getIdDict().values()

    # For each axis in dset2:
    for obj in idlist:
        if obj.tag!='axis': continue
        axis = obj
        axisname = axis.id
        axismap[axisname] = axisname

        # If the axis is not in dset, just add it
        axis1 = dset.getChildNamed(axisname)
        if axis1 == None:
            dset.addId(axisname,axis)
        
        # If axisname is in dset and the axes are equal, just continue (it's there)
        elif axis1.isClose(axis,_fudgeFactor):
            continue

        # The axis needs to be renamed:
        #   - find a suffix such that either:
        #     (1) axisname_suffix is in dset and is equal:
        #         just continue
        #     (2) axisname_suffix is not in dset:
        #         add it to dset
        #   - in either case:
        #     - change the name to name_suffix
        #     - save old name in attribute 'name_in_file'
        #     - save the map of old->new axis name
        else:
            foundname = 0               # 1 if the axis can be renamed
            for suffix in suffices:
                newname = axisname+'_'+`len(axis)`+suffix
                axismap[axisname] = newname
                axis.setExternalAttr('name_in_file',axisname)
                axis1 = dset.getChildNamed(newname)

                # If axis is not in dset, just add it
                if axis1 == None:
                    axis.id = newname
                    axis.setExternalAttr('id',newname)
                    dset.addId(newname,axis)
                    foundname = 1
                    break

                # If axis is in dset and is equal, just continue outer loop
                elif axis1.isClose(axis,_fudgeFactor):
                    foundname = 1
                    break
                
            if not foundname:
                dset.dump()
                print 'Cannot match axis:',axis.data
                raise CannotRenameAxis, newname

    # For each variable object in dset2:
    for obj in idlist:
        if obj.tag=='variable' and obj.id not in dupvar:
            var = obj

            # - map any changed axis names
            domain = var.getDomain()
            for domElem in domain.children():
                domElem.setName(axismap[domElem.getName()])
                
            # - add the object and ID to dset
            dset.addId(var.id,var)

    # For all other named objects in dset2:
    # - if the object tag is not in dset, add the object
    for obj in idlist:
        if obj.tag not in ('axis','variable'):
            if dset.getChildNamed(obj.id) != None:
                dset.addId(obj.id,obj)

    # Merge dataset attributes
    attrdict = dset.getExternalDict()
    attrdict2 = dset2.getExternalDict()
    for attname2 in attrdict2.keys():
        if attname2 in ('id','template'): continue
        attval2 = attrdict2[attname2]
        if attrdict.has_key(attname2):
            attval = attrdict[attname2]
            if attval != attval2:
                sys.stderr.write('Warning: dataset attributes are mismatched: %s.%s = %s, %s.%s = %s\n'%(dset.id,attname2,attval,dset2.id,attname2,attval2))
        else:
            dset.setExternalAttr(attname2,attval2[0],attval2[1])
    if template != None:
        dset.setExternalAttr('template',template)
    return dset

#---------------------------------------------------------------------------------------------------------------------------------
# Extend variables in datanode 'dset' with those
# in datanode 'dset2'. Append axes named 'extendDim'
# 'isreltime' is true iff the extended dimension is a relative time dimension
# If 'allowGaps' is true, allow linear time dimensions to have gaps
# axisAttr is the value of the extended dimension 'axis' attribute: 'Z' for levels,
# or 'T' for time.
def extend(dset,dset2,extendDim,isreltime=0,linearDims=None,allowGaps=0,axisAttr='Z'):

    if dset is None:
        return dset2
    if dset2 is None:
        return dset

    # For each axis in dset2, other than extendDim:
    # - if the axis name is not in dset or is unequal: error
    idlist = dset2.getIdDict().values()
    for obj in idlist:
        if obj.tag!='axis' or obj.id==extendDim: continue
        axis = obj
        axisname = axis.id
        axis1 = dset.getChildNamed(axisname)
        if axis1 is None:
            raise CannotExtend, 'Axis %s not found'%axisname
        if not axis1.isClose(axis,_fudgeFactor):
            raise CannotExtend, 'Axis %s does not match'%axisname

    # Extend dset.extendDim by dset2.extendDim, preserving
    # monotonicity and linearity
    exdim1 = dset.getChildNamed(extendDim)
    if exdim1 is None:
        raise CannotExtend, 'Extend dimension %s not found'%extendDim
    exdim2 = dset2.getChildNamed(extendDim)
    if exdim2 is None:
        raise CannotExtend, 'Extend dimension %s not found'%extendDim
    exdim1.extend(exdim2,isreltime,allowGaps)
    exdim1.setExternalAttr('axis',axisAttr)

    # Check the assertion that time is linear
    if linearDims!=None and extendDim in linearDims:
        if exdim1.dataRepresent == cdmsNode.CdVector:
            raise CannotExtend, 'Axis %s is not linear or has the wrong length in the PREVIOUS file'%extendDim

    # For each var object in dset2 with extendDim in the domain:
    # - if the object name is not in dset: error
    # - if the domain element is not all of extendDim: error
    # - extend dset.var
    for obj in idlist:
        if obj.tag!='variable': continue
        var2 = obj
        varname = var2.id

        var1 = dset.getChildNamed(varname)
        if var1 is None:
            raise CannotExtend, 'Variable %s not found'%varname
            
        # Check if this var has the extend dimension
        domain = var2.getDomain()
        for domElem in domain.children():
            elemname = domElem.getName()
            if elemname==extendDim:
                start = domElem.start
                length = domElem.length
                if start!=None and length!=None:

                    # Check that dset.var is a function of the entire extended dimension
                    if start!=0 or length!=len(exdim2):
                        raise CannotExtend, 'Variable %s is not a function of the entire dimension %s'%(varname,extendDim)
                    else:
                        # Extend dset.var
                        domain1 = var1.getDomain()
                        foundExtendDim = 0
                        for domElem1 in domain1.children():
                            if domElem1.getName()==extendDim:
                                domElem1.length = len(exdim1)
                                domElem1.setExternalAttr('length',domElem1.length)
                                if exdim1.partition!=None:
                                    domElem1.partition_length = exdim1.partition_length
                                    domElem1.setExternalAttr('partition_length',domElem1.partition_length)
                                foundExtendDim = 1
                                break

                        # If dset2.var is a function of the extended dimension, so must dset.var be
                        if foundExtendDim==0:
                            raise CannotExtend, 'Variable %s does not have the extended dimension %s'%(varname,extendDim)
    return dset

#---------------------------------------------------------------------------------------------------------------------------------
# Run as a script: pass the complete argument vector (including the
# program name in argv[0]) to main().
if __name__ == '__main__':
    main(sys.argv)
