#!/usr/bin/env python3
"""
An interactive script for generating custom Hamiltonians and coupling matrices
and saving them as .npy files in the data/ directory.
"""

import os

import numpy as np


def generate_custom_hamiltonian():
    """
    Generates a custom Hamiltonian and stores it in a .npy file.

    The user is prompted for the number of eigenenergies, the eigenenergy
    values (separated by whitespace) and an output file name. The values
    are stored as a 1-D float array of eigenenergies (not a square matrix)
    in ``data/<name>.npy``. An empty or whitespace-only file name selects
    ``custom_hamiltonian``.

    Raises
    ------
    ValueError
        If a number cannot be parsed, or if the number of values entered
        does not match the declared number of eigenenergies.
    OSError
        If the output directory cannot be created or the file cannot be
        written.
    """

    state_num = int(input("Enter the number of eigenenergies: "))
    filename_default = "custom_hamiltonian"

    E_vals = input(
        "Enter the eigenenergy values separated by spaces: ").split()
    # Use an explicit check rather than ``assert``, because assertions are
    # stripped when Python runs with -O.
    if len(E_vals) != state_num:
        raise ValueError(
            "Number of eigenenergies inconsistent with the one specified")
    hamiltonian = np.array([float(val) for val in E_vals], dtype=float)

    print("Hamiltonian imported")
    filename = input(
        "Now specify the file name, default is " + filename_default + ": ")

    name = filename.strip() or filename_default
    # np.save appends ".npy" only when the name lacks it. Do the same here
    # so that the reported path is the file that is actually written.
    if not name.endswith(".npy"):
        name += ".npy"
    path = os.path.join("data", name)
    # np.save does not create missing directories.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    np.save(path, hamiltonian)
    print("Hamiltonian successfully saved at " + path)


def generate_custom_coupling():
    """
    Generates a custom coupling matrix and stores it in a .npy file.

    The user is prompted for the number of lead states, the number of
    central-region states, and then the matrix itself in the form
    ``[x11 x12 ; x21 x22]``. Rows are separated by ';' and entries by
    whitespace, and complex entries such as ``1.+1.j`` are accepted. The
    result is a complex128 array of shape ``(state_num_ld, state_num_ctr)``
    saved in ``data/<name>.npy``. An empty or whitespace-only file name
    selects ``custom_coupling``.

    Raises
    ------
    ValueError
        If a number cannot be parsed, if the matrix is not enclosed in
        brackets, or if its row/column counts do not match the declared
        numbers of states.
    OSError
        If the output directory cannot be created or the file cannot be
        written.
    """

    state_num_ld = int(input("Enter number of states in the lead: "))
    state_num_ctr = int(input("Enter number of states in the "
                              "central region: "))

    cpl_matrix = np.zeros((state_num_ld, state_num_ctr), dtype=np.complex128)
    filename_default = "custom_coupling"

    print("Here is a little tutorial:")
    print("Enter your matrix in the form [x11 x12 x13 ; x21 x22 x23 ; "
          "x31 x32 x33]")
    print("Complex values are allowed and can be typed in as e.g. "
          "1.+1.j, j is the complex unit")
    # Tolerate surrounding whitespace around the bracketed matrix.
    cpl_vals = input("Now enter the coupling matrix: ").strip()

    # Use explicit checks rather than ``assert`` (assertions are stripped
    # under -O). The length check also covers empty input, which previously
    # raised IndexError.
    if len(cpl_vals) < 2 or not (cpl_vals.startswith("[")
                                 and cpl_vals.endswith("]")):
        raise ValueError("Coupling matrix must be enclosed in '[' and ']'")

    # Remove exactly one bracket on each side, so malformed input such as
    # "[[...]]" is not silently accepted.
    cpl_rows = cpl_vals[1:-1].split(";")
    if len(cpl_rows) != state_num_ld:
        raise ValueError(
            "Number of rows inconsistent with number of lead states")

    for i, row in enumerate(cpl_rows):
        # complex() parses literals like "1.+1.j" without evaluating code.
        row_values = [complex(val) for val in row.split()]
        if len(row_values) != state_num_ctr:
            raise ValueError("Number of items in one row inconsistent "
                             "with number of center states")
        cpl_matrix[i, :] = row_values

    print("Matrix imported")
    filename = input(
        "Now specify the file name, default is " + filename_default + ": ")

    name = filename.strip() or filename_default
    # np.save appends ".npy" only when the name lacks it. Do the same here
    # so that the reported path is the file that is actually written.
    if not name.endswith(".npy"):
        name += ".npy"
    path = os.path.join("data", name)
    # np.save does not create missing directories.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    np.save(path, cpl_matrix)
    print("Coupling matrix successfully saved at " + path)


def user_interface():
    """
    Run the interactive menu that dispatches to the generator functions.

    The menu repeats until the user enters 'q' or stdin is exhausted. The
    choice is matched case-insensitively and surrounding whitespace is
    ignored. A ValueError or OSError raised by a generator (for example
    from a non-numeric entry or a failed save) is reported and the menu is
    shown again, instead of ending the session with a traceback.
    """
    print("This script can be used to create custom Hamiltonians and custom")
    print("coupling matrices to be used in "
          "CustomCenter/CustomLead/CustomCoupling")
    print("classes.")
    print()

    while True:
        print("Would you like to generate a Hamiltonian ('h'), "
              "or a coupling matrix ('c')?")
        try:
            usr_input = input("Press 'q' to quit: ").strip().lower()
        except EOFError:
            # stdin was closed (e.g. piped input ran out), so treat it as 'q'.
            print()
            usr_input = 'q'

        if usr_input == 'q':
            print("Goodbye!")
            break

        # Resolve the generator only inside its branch so each name is
        # looked up only when that choice is actually made.
        if usr_input == 'h':
            generator = generate_custom_hamiltonian
        elif usr_input == 'c':
            generator = generate_custom_coupling
        else:
            print("Invalid input! Type 'h' for Hamiltonian, "
                  "'c' for coupling matrix")
            print("and 'q' to quit.")
            continue

        try:
            generator()
        except (ValueError, OSError) as err:
            print("Error: " + str(err))


if __name__ == "__main__":
    # Start the interactive menu only when run as a script, not on import.
    user_interface()
