#!/usr/bin/env python
"""Contains the main running script of T4ME."""

# pylint: disable=too-many-branches, too-many-statements, useless-import-alias, unused-import, unused-variable, no-member, no-name-in-module, invalid-name

# python specifics
import sys
import logging

# locals
import t4me.utils as utils


def main():  # noqa: MC0001
    """
    Main routine.

    Sets up the calculations and calls necessary sub-routines.

    Notes
    -----
    This routine can be modified for more advanced usage. For
    day-to-day operations, the most used configurations are available
    by setting the parameters in the general configuration file
    (defaults to param.yml).

    The steps are executed in this order: check the input directory,
    clean the output directory, configure the logger, import the
    T4ME modules, read the parameters, build the lattice and the
    band structure and then, depending on the parameters, dump and/or
    pre-interpolate the dispersions, calculate the density of states
    and calculate the transport coefficients.

    If the effective mass is requested the routine calculates it, logs
    an error and terminates with exit status 1, since the writeout of
    the effective mass tensor is not included yet.

    """

    # first of all, check that the user has made an input directory
    utils.check_directory("input")

    # check the output folder and clean it (do this before the logger
    # is configured in case the user wants the log for this job in
    # the output folder)
    utils.clean_directory("output")

    # configure logger, the logger is named after this function
    utils.config_logger()
    logger = logging.getLogger(sys._getframe().f_code.co_name)  # pylint: disable=protected-access

    # these modules are loaded after we configure the logger
    # in case there are problems with third party libraries etc.
    import t4me.bandstructure as bandstructure
    import t4me.lattice as lattice
    import t4me.transport as transport
    import t4me.inputoutput as inputoutput

    # dump a startup message
    inputoutput.start_message()

    # this is not very clean, but in order to avoid code running
    # at the top of our modules (sphinx complains e.g.) we here check
    # if all modules can be imported, if not, we notify the user and
    # later we just load the module in a subroutine and assume the
    # user knows what is going on

    # check spglib and load
    try:
        import t4me.spglib_interface
    except ImportError:
        inputoutput.spglib_error()

    # check skw_interface and load
    try:
        import t4me.skw_interface
    except ImportError:
        inputoutput.skw_warning()

    # check wildmagic and load
    try:
        import t4me.wildmagic
    except ImportError:
        inputoutput.wildmagic_warning()

    # read param file
    param = inputoutput.Param(inputoutput.readparam())

    # no parallelization at this point
    param.parallel = False

    # generate lattice from param and input file
    lat = lattice.Lattice(param)

    # generate or read bandstructure data
    bs = bandstructure.Bandstructure(lat, param)

    # dump the dispersions?
    if param.dispersion_write_preinter:
        _dump_dispersions(inputoutput, bs, param, logger)

    # maybe the user wants to pre-interpolate?
    # but do not do this for tight binding stuff
    if (param.dispersion_interpolate
            and param.dispersion_interpolate_method != "tb"):
        _pre_interpolate(bs, lat, param, logger)

        # dump the dispersions after interpolations?
        if param.dispersion_write_postinter:
            _dump_dispersions(inputoutput, bs, param, logger,
                              suffix="_inter")

    # does the user want effective mass data?
    if param.dispersion_effmass:
        # calculate effective mass
        bs.calc_effective_mass()
        # TODO: the writeout of the effective mass is not available
        # yet, so we stop here. When it is, replace the error and the
        # exit with a call to inputoutput.dump_effmass using
        # filename="effmass".
        logger.error("Writeout of the effective mass tensor "
                     "have not yet been included.")
        sys.exit(1)

    # calculation of the density of states?
    if param.dos_calc:
        # calculate
        bs.calc_density_of_states()
        # write dos
        inputoutput.dump_density_of_states(bs, filename="dos")

    # transport calcs?
    if param.transport_calc:
        # initialize parameters, scattering etc.
        tran = transport.Transport(bs)
        # write relaxation times
        if param.transport_method != "closed":
            inputoutput.dump_relaxation_time(tran)
        # perform the actual calculation of the transport coefficients
        tran.calc_transport_tensors()
        # write transport coefficients
        inputoutput.dump_transport_coefficients(tran)
        # dump dos as well, but only if it was not already written by
        # the density of states step above
        if not param.dos_calc and (
                not param.transport_use_analytic_scattering
                or param.transport_method == "closed"):
            inputoutput.dump_density_of_states(bs, filename="dos")

    # dump message at the end
    inputoutput.end_message()

    # shutdown logger
    logging.shutdown()


def _dump_dispersions(inputoutput, bs, param, logger, suffix=""):
    """
    Write the band energies, and the velocities if present, along a line.

    Parameters
    ----------
    inputoutput : module
        The `t4me.inputoutput` module imported in `main`.
    bs : object
        The band structure object created in `main`.
    param : object
        The parameter object created in `main`. The start and end of
        the line are read from `dispersion_write_start` and
        `dispersion_write_end`.
    logger : logging.Logger
        Logger used to tell the user that the velocities are skipped.
    suffix : str, optional
        Text appended to the "bands" and "velocities" filenames.
        Defaults to no suffix.

    """

    def dump_line(datatype, filename):
        inputoutput.dump_bandstruct_line(bs,
                                         param.dispersion_write_start,
                                         param.dispersion_write_end,
                                         datatype=datatype,
                                         filename=filename + suffix,
                                         k_direct=True,
                                         itype="linearnd",
                                         itype_sub="linear")

    dump_line("e", "bands")
    # same as above just for the velocities, but a set gen_velocities
    # flag is treated as the velocity data being absent, in that case
    # print a message and continue
    if bs.gen_velocities:
        logger.info("Data for the band velocities does not exist. "
                    "Skipping writing the band velocities to file "
                    "and continuing.")
    else:
        dump_line("v", "velocities")


def _pre_interpolate(bs, lat, param, logger):
    """
    Interpolate the dispersions and make sure velocities are extracted.

    The velocities are either extracted during the interpolation or
    afterwards by calling `bs.calc_velocities()`, depending on
    `bs.gen_velocities`, `param.dispersion_velocities_numdiff`,
    `lat.regular` and `param.dispersion_interpolate_method`.

    Parameters
    ----------
    bs : object
        The band structure object created in `main`.
    lat : object
        The lattice object created in `main`.
    param : object
        The parameter object created in `main`.
    logger : logging.Logger
        Logger used for the progress messages.

    """

    logger.info("Pre-interpolating the dispersion data.")
    if bs.gen_velocities and not param.dispersion_velocities_numdiff:
        if lat.regular or param.dispersion_interpolate_method == "skw":
            # for a regular grid it is okay for us to extract velocities
            # on the fly
            logger.info("Interpolating the energies and extracting the "
                        "velocities by interpolation.")
            bs.interpolate(store_inter=True, ivelocities=True)
        else:
            # only interpolate the eigenvalues
            logger.info("Interpolating the energies.")
            bs.interpolate(store_inter=True, ivelocities=False)
            # finite difference is disabled in this branch, so always
            # tell the user that it is used regardless
            logger.warning("The user disabled the extraction of the "
                           "velocities by finite difference. However, "
                           "velocity extraction by interpolation is "
                           "currently only possible for regular "
                           "unitcells. Continuing.")
            # and then extract velocities by finite difference
            logger.info("Extracting the velocities by finite difference.")
            bs.calc_velocities()
    elif param.dispersion_velocities_numdiff:
        # only interpolate the eigenvalues
        logger.info("Interpolating the energies.")
        bs.interpolate(store_inter=True, ivelocities=False)
        # and then extract velocities by finite difference
        logger.info("Extracting the velocities by finite difference.")
        bs.calc_velocities()
    else:
        bs.interpolate(store_inter=True, ivelocities=True)


# only start the calculation when this file is executed as a script
if __name__ == "__main__":
    main()
