"""This is a template for psi4 input format."""
import numpy
import sys

from fermilib.config import *
from fermilib.ops._interaction_tensor import (one_body_basis_change,
                                              two_body_basis_change)
from fermilib.utils import MolecularData

sys.path.append('&THIS_DIRECTORY')
from _psi4_conversion_functions import *


# Set the amount of memory the job may use, in megabytes.
# This line is a Psi4 input directive rather than Python; the numeric value
# is filled in when the template is rendered.
memory &memory mb

# Initialize molecular data.
_description = '&description'
if _description == 'None':
    _description = None
molecule = MolecularData(&geometry,
                         '&basis',
                         &multiplicity,
                         &charge,
                         _description)

# Set molecular geometry and symmetry.
# The braces below form a Psi4 molecule directive (not Python) that defines
# the object named mol; the geometry lines are filled in when the template is
# rendered. "symmetry c1" requests the C1 point group, i.e. no spatial
# symmetry (presumably so orbital and integral matrices are not split into
# symmetry blocks -- confirm).
molecule mol {
&geo_string
symmetry c1
}
# Give the Psi4 molecule the same spin multiplicity and charge that were
# passed to MolecularData.
mol.set_multiplicity(&multiplicity)
mol.set_molecular_charge(&charge)

# Set reference and guess.
if molecule.multiplicity == 1:
    set reference rhf
    set guess sad
else:
    set reference rohf
    set guess gwh

# Set global parameters of calculation.
# This is a Psi4 "set globals" directive, not Python. It selects the basis
# set, leaves core orbitals unfrozen (freeze_core false), turns
# non-convergence into an error (fail_on_maxiter true), selects
# density-fitted SCF (scf_type df), enables the opdm and tpdm options that
# the CI sections rely on for density matrices, and sets the r/d/e
# convergence thresholds to 1e-6.
# NOTE(review): ints_tolerance is given the bare word EQUALITY_TOLERANCE
# rather than a number. Presumably this is meant to be the constant
# star-imported from fermilib.config; confirm that Psi4 resolves it to that
# value and does not receive it as a literal string.
set globals {
    basis &basis
    freeze_core false
    fail_on_maxiter true
    df_scf_guess true
    opdm true
    tpdm true
    soscf false
    scf_type df
    maxiter 1e6
    num_amps_print 1e6
    r_convergence 1e-6
    d_convergence 1e-6
    e_convergence 1e-6
    ints_tolerance EQUALITY_TOLERANCE
    damping_percentage 0
}

# Run self-consistent field (SCF) calculation.
if &run_scf:
    try:
        hf_energy, hf_wavefunction = energy('scf', return_wfn=True)
        if &verbose:
            print('Hartree-Fock energy for {} ({} electrons) is {}.'.format(
                molecule.name, molecule.n_electrons, hf_energy))
    except:
        if &tolerate_error:
            print('WARNING: SCF calculation failed.')
        else:
            raise
    else:
        # Get orbitals and Fock matrix.
        molecule.hf_energy = hf_energy
        molecule.nuclear_repulsion = mol.nuclear_repulsion_energy()
        molecule.canonical_orbitals = numpy.asarray(hf_wavefunction.Ca())
        molecule.n_orbitals = molecule.canonical_orbitals.shape[0]
        molecule.n_qubits = 2 * molecule.n_orbitals
        molecule.orbital_energies = numpy.asarray(hf_wavefunction.epsilon_a())
        molecule.fock_matrix = numpy.asarray(hf_wavefunction.Fa())

        # Get integrals using MintsHelper.
        mints = MintsHelper(hf_wavefunction.basisset())
        molecule.one_body_integrals = one_body_basis_change(
            numpy.asarray(mints.ao_kinetic()), molecule.canonical_orbitals)
        molecule.one_body_integrals += one_body_basis_change(
            numpy.asarray(mints.ao_potential()), molecule.canonical_orbitals)
        two_body_integrals = numpy.asarray(mints.ao_eri())
        two_body_integrals.reshape((molecule.n_orbitals, molecule.n_orbitals,
                                    molecule.n_orbitals, molecule.n_orbitals))
        two_body_integrals = numpy.einsum('psqr', two_body_integrals)
        two_body_integrals = two_body_basis_change(
            two_body_integrals, molecule.canonical_orbitals)
        integrals_name = molecule.filename + '_eri'
        numpy.save(integrals_name, two_body_integrals)
        molecule.save()


# Perform MP2 energy calculation if there are at least two electrons.
if &run_mp2:
    try:
        assert molecule.n_electrons > 1
        mp2_energy = energy('mp2')
        if &verbose:
            print('MP2 energy for {} ({} electrons) is {}.'.format(
                molecule.name, molecule.n_electrons, mp2_energy))
    except:
        if &tolerate_error:
            print('WARNING: MP2 calculation failed.')
        else:
            raise
    else:
        molecule.mp2_energy = mp2_energy
        molecule.save()


# Perform configuration interaction singles and doubles (CISD) calculation.
if &run_cisd:
    set qc_module detci
    try:
        cisd_energy, cisd_wavefunction = energy('cisd', return_wfn=True)
        if &verbose:
            print('CISD energy for {} ({} electrons) is {}.'.format(
                molecule.name, molecule.n_electrons, cisd_energy))
    except:
        if &tolerate_error:
            print('WARNING: CISD calculation failed.')
        else:
            raise
    else:
        # For the functions below, "a" and "b" refer to "up and "down" spins.
        molecule.cisd_energy = cisd_energy

        # Get 1-RDM from CISD calculation.
        cisd_one_rdm_a = numpy.array(cisd_wavefunction.get_opdm(
            0, 0, 'A', True)).reshape(molecule.n_orbitals, molecule.n_orbitals)
        cisd_one_rdm_b = numpy.array(cisd_wavefunction.get_opdm(
            0, 0, 'B', True)).reshape(molecule.n_orbitals, molecule.n_orbitals)

        # Get 2-RDM from CISD calculation.
        cisd_two_rdm_aa = numpy.array(cisd_wavefunction.get_tpdm(
            'AA', False)).reshape(molecule.n_orbitals, molecule.n_orbitals,
                                  molecule.n_orbitals, molecule.n_orbitals)
        cisd_two_rdm_ab = numpy.array(cisd_wavefunction.get_tpdm(
            'AB', False)).reshape(molecule.n_orbitals, molecule.n_orbitals,
                                  molecule.n_orbitals, molecule.n_orbitals)
        cisd_two_rdm_bb = numpy.array(cisd_wavefunction.get_tpdm(
            'BB', False)).reshape(molecule.n_orbitals, molecule.n_orbitals,
                                  molecule.n_orbitals, molecule.n_orbitals)

        # Get overall RDMs.
        cisd_one_rdm, cisd_two_rdm = unpack_spatial_rdm(
            cisd_one_rdm_a, cisd_one_rdm_b, cisd_two_rdm_aa,
            cisd_two_rdm_ab, cisd_two_rdm_bb)

        # Store 1-RDM in molecule file, 2-RDM separately in other file.
        molecule.cisd_one_rdm = cisd_one_rdm
        cisd_rdm_name = molecule.filename + '_cisd_rdm'
        numpy.save(cisd_rdm_name, cisd_two_rdm)
        molecule.save()


# Perform exact diagonalization.
if &run_fci:
    set qc_module detci
    try:
        fci_energy, fci_wavefunction = energy('fci', return_wfn=True)
        if &verbose:
            print('FCI energy for {} ({} electrons) is {}.'.format(
                molecule.name, molecule.n_electrons, fci_energy))
    except:
        if &tolerate_error:
            print('WARNING: FCI calculation failed.')
        else:
            raise
    else:
        # For the functions below, "a" and "b" refer to "up and "down" spins.
        molecule.fci_energy = fci_energy

        # Get 1-RDM from FCI calculation.
        fci_one_rdm_a = numpy.array(fci_wavefunction.get_opdm(
            0, 0, 'A', True)).reshape(molecule.n_orbitals, molecule.n_orbitals)
        fci_one_rdm_b = numpy.array(fci_wavefunction.get_opdm(
            0, 0, 'B', True)).reshape(molecule.n_orbitals, molecule.n_orbitals)

        # Get 2-RDM from FCI calculation.
        fci_two_rdm_aa = numpy.array(fci_wavefunction.get_tpdm(
            'AA', False)).reshape(molecule.n_orbitals, molecule.n_orbitals,
                                  molecule.n_orbitals, molecule.n_orbitals)
        fci_two_rdm_ab = numpy.array(fci_wavefunction.get_tpdm(
            'AB', False)).reshape(molecule.n_orbitals, molecule.n_orbitals,
                                  molecule.n_orbitals, molecule.n_orbitals)
        fci_two_rdm_bb = numpy.array(fci_wavefunction.get_tpdm(
            'BB', False)).reshape(molecule.n_orbitals, molecule.n_orbitals,
                                  molecule.n_orbitals, molecule.n_orbitals)

        # Get overall RDMs.
        fci_one_rdm, fci_two_rdm = unpack_spatial_rdm(
            fci_one_rdm_a, fci_one_rdm_b,
            fci_two_rdm_aa, fci_two_rdm_ab, fci_two_rdm_bb)

        # Store 1-RDM in molecule file, 2-RDM separately in other file.
        molecule.fci_one_rdm = fci_one_rdm
        fci_rdm_name = molecule.filename + '_fci_rdm'
        numpy.save(fci_rdm_name, fci_two_rdm)
        molecule.save()


# Perform coupled cluster singles and doubles (CCSD) calculation.
if &run_ccsd:
    set qc_module ccenergy
    try:
        ccsd_energy = energy('ccsd')
        if &verbose:
            print('CCSD energy for {} ({} electrons) is {}.'.format(
                molecule.name, molecule.n_electrons, ccsd_energy))
    except:
        if &tolerate_error:
            print('WARNING: CCSD calculation failed.')
        else:
            raise
    else:
        molecule.ccsd_energy = ccsd_energy

        # Merge CC amplitudes into molecule by parsing
        psi_filename = outfile_name()
        ccsd_amplitudes = parse_psi4_ccsd_amplitudes(
            2 * molecule.n_orbitals,
            molecule.get_n_alpha_electrons(),
            molecule.get_n_beta_electrons(),
            psi_filename)
        molecule.ccsd_amplitudes = ccsd_amplitudes
        molecule.save()
