#!python

"""
  THICK-2D -- Thickness Hierarchy Inference & Calculation Kit for 2D materials

  This program is free software; you can redistribute it and/or modify it under the
  terms of the GNU General Public License as published by the Free Software Foundation
  version 3 of the License.

  This program is distributed in the hope that it will be useful, but WITHOUT ANY
  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
  PARTICULAR PURPOSE.  See the GNU General Public License for more details.
  Email: cekuma1@gmail.com
  
""" 
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib
import logging
import time
import copy
from ase import Atoms
from collections import Counter
from datetime import datetime
import warnings
from read_write import read_options_from_input,load_structure,append_data 
from optimize_struct import optimize_structure_vasp,optimize_structure_qe
from write_inputs import print_line, print_boxed_message, print_banner
from predict_thickness_2D import predict_thickness_2D 

# Resolve the installed package version shown in the banner.
# The fallbacks are nested: an ImportError raised inside an ``except`` handler
# is NOT caught by a sibling ``except`` clause, so a flat try/except/except
# chain would never reach the later fallbacks.
try:
    from importlib.metadata import PackageNotFoundError, version as _dist_version  # Python 3.8+
except ImportError:
    try:
        from importlib_metadata import PackageNotFoundError, version as _dist_version  # Python <3.8
    except ImportError:
        _dist_version = None

if _dist_version is not None:
    try:
        version = _dist_version("SMATool")
    except PackageNotFoundError:
        # Running from a source checkout without installed metadata must not
        # abort the calculation; the version is only cosmetic.
        version = "unknown"
else:
    try:
        import pkg_resources
        version = pkg_resources.get_distribution("SMATool").version
    except Exception:
        # pkg_resources missing or distribution not installed.
        version = "unknown"

    
    
# Suppress every Python warning for the remainder of the run.
warnings.filterwarnings('ignore')
#matplotlib.use('Agg')


# Each run starts from an empty log: drop any log left by a previous run
# before the logging module (re)creates the file.
log_filename = 'thick2d.log'
if os.path.exists(log_filename):
    os.remove(log_filename)

# Root-logger configuration: timestamped INFO-level records written to the log file.
_log_settings = {
    'filename': log_filename,
    'filemode': 'a',
    'format': '%(asctime)s - %(levelname)s - %(message)s',
    'datefmt': '%Y-%m-%d %H:%M:%S',
    'level': logging.INFO,
}
logging.basicConfig(**_log_settings)



# Main code starts here!
# Take a single clock reading so the time and date strings can never disagree
# (two separate now() calls could straddle midnight).
_started_at = datetime.now()
current_time = _started_at.strftime('%H:%M:%S')
current_date = _started_at.strftime('%Y-%m-%d')
start_time = time.time()  # wall-clock reference for the elapsed-time report

logging.info(f"THICK2D calculation started at {current_time} on {current_date}")

# Read the user options and derive the run-wide settings from them.
options = read_options_from_input()
code_type = options.get("code_type", "VASP")
model_type = options.get("model_type", "classic").lower()

# Human-readable description of the selected model, shown in the banner.
_mode_labels = {
    "dnn": "Deep Neural Network",
    "classic": "Boosting ML model",  # or any other generic ML model description
}
if model_type in _mode_labels:
    mode = _mode_labels[model_type]
else:
    # Without this branch `mode` would be unbound and the banner call would
    # fail with a NameError for any other model_type value.
    logging.warning(
        f"Unrecognized model_type '{model_type}'; using it verbatim as the banner label."
    )
    mode = model_type

optimize = options.get("optimize", False)
throughput = options.get("throughput", False)
num_augmented_samples = int(options.get("num_augmented_samples", 50))
box_width = 80  # character width used when centering/boxing the report lines

# The per-run report is opened in append mode later on, so remove any stale copy.
filename_thick2d = 'thick2d.out'
if os.path.exists(filename_thick2d):
    os.remove(filename_thick2d)

if throughput:
    # In throughput mode results accumulate across runs in one shared table,
    # so this file is intentionally NOT deleted.
    filename_thickness = 'structure_thickness.txt'
    filename_struct = options.get('structure_file', None)

    if filename_struct is not None:
        # The structure file name without its extension is passed to
        # append_data as the material id.
        base_filename = os.path.splitext(filename_struct)[0]
    else:
        # Keep `base_filename` bound: it is passed to append_data later and
        # would otherwise raise a NameError after the prediction has run.
        logging.warning(
            "No 'structure_file' option given in throughput mode; "
            "material id recorded as 'unknown'."
        )
        base_filename = "unknown"

    # Write the header only when the table does not exist yet.
    if not os.path.isfile(filename_thickness):
        with open(filename_thickness, 'w') as file:
            file.write("#Material, Thickness (Ang), Material_id\n")
    
# Obtain the `atoms` structure: either relax it with the selected DFT code,
# or load a structure from disk.
if optimize:
    if code_type == "VASP":

        # A missing 'custom_options' section falls back to the defaults below
        # instead of raising a KeyError.
        custom_options = options.get('custom_options') or {}
        base_path = custom_options.get('potential_dir', "./")

        # POTCARs are staged into ./potentials/potpaw_PBE/<dir>/ and
        # VASP_PP_PATH is pointed at that local staging directory.
        potentials_path = "potentials"
        os.environ["VASP_PP_PATH"] = os.path.abspath(potentials_path)

        struct = options.get("structure_file")
        atoms = load_structure(struct)
        symbols = atoms.get_chemical_symbols()
        # Unique elements in order of first appearance.
        unique_symbols = list(dict.fromkeys(symbols))

        # POTCAR directory-name variants, tried in this order of preference.
        potcar_suffixes = ("", "_pv", "_sv", "_GW", "_sv_GW", "_pv_GW")

        for symbol in unique_symbols:
            potential_potcar_paths = [
                os.path.join(base_path, symbol + potcar_suffix, "POTCAR")
                for potcar_suffix in potcar_suffixes
            ]

            pot_file_path = next(
                (path for path in potential_potcar_paths if os.path.exists(path)),
                None,
            )
            if not pot_file_path:
                message = f"POTCAR for {symbol} not found in any of the expected directories!"
                # Log BEFORE raising; a statement placed after `raise` never runs.
                logging.error(message)
                raise FileNotFoundError(message)

            # Name of the directory the POTCAR was found in (e.g. "Mo_pv").
            suffix = os.path.basename(os.path.dirname(pot_file_path))
            potential_dir = os.path.join(potentials_path, "potpaw_PBE", suffix)

            # Recreate the staging directory so no stale POTCAR survives.
            if os.path.exists(potential_dir):
                shutil.rmtree(potential_dir)

            os.makedirs(potential_dir, exist_ok=True)

            shutil.copy(pot_file_path, potential_dir)

        atoms = optimize_structure_vasp()
    elif code_type == "QE":
        atoms = optimize_structure_qe()
    else:
        # Without this branch `atoms` would be unbound and the prediction
        # step would fail with an uninformative NameError.
        message = f"Unsupported code_type '{code_type}' for optimization; expected 'VASP' or 'QE'."
        logging.error(message)
        raise ValueError(message)

else:
    try:
        # Prefer a structure left behind by an earlier optimization run.
        atoms = load_structure("OPT/optimized_structure.cif")
    except (ImportError, OSError):
        # NOTE(review): OSError (incl. FileNotFoundError) is assumed to be what
        # load_structure raises for a missing file -- confirm against read_write.
        struct = options.get("structure_file")
        atoms = load_structure(struct)

    
# Show the run banner on the console before the prediction starts.
print_banner(version,code_type, mode)

thickness_2D = 0.0
# The trained model files are looked up in ./ml_model under the working directory.
model_directory = os.path.join(os.getcwd(), "ml_model")

thickness_2D = predict_thickness_2D(
    atoms,
    model_directory,
    num_augmented_samples=num_augmented_samples,
)

# Reduce an array-valued prediction: take its first entry and, if that entry
# is itself a one-element array, unwrap it to a plain scalar.
if isinstance(thickness_2D, np.ndarray):
    thickness_2D = thickness_2D[0]
    if isinstance(thickness_2D, np.ndarray) and thickness_2D.size == 1:
        thickness_2D = thickness_2D.item()
    


# Build a formula-like name from the element counts, elements in sorted
# order, with a count of 1 left implicit.
atom_counts = Counter(atoms.get_chemical_symbols())
_name_parts = []
for atom, count in sorted(atom_counts.items()):
    _name_parts.append(f"{atom}{count if count > 1 else ''}")
structure_name = ''.join(_name_parts)

# Label/value pair used in the thick2d.out report.
thick_label = f"Thickness of {structure_name}"
thicknessval = f"{thickness_2D:.3f} Å"

# Throughput mode: add this structure to the shared results table.
if throughput:
    append_data(filename_thickness,structure_name,thickness_2D,matid=base_filename)

# Same summary sentence on the console (centered, dash-padded) and in the log.
_summary = f"Predicted thickness for {structure_name} is: {thickness_2D:.2f} Å"
print(_summary.center(box_width, '-'))
logging.info(_summary)

# Rows of the report, as (label, value) pairs.
data = [(thick_label, thicknessval)]

# Write the report: banner, two separator lines, the centered result rows,
# and the closing boxed message.
with open(filename_thick2d, 'a') as th_file:
    print_banner(version,code_type, mode,ec_file=th_file)

    separator = "=" * box_width
    print_line(th_file, separator, border_char="+", filler_char="-")
    print_line(th_file, separator, border_char="+", filler_char="-")

    for label, value in data:
        th_file.write(f"{label} = {value}".center(box_width) + '\n')

    print_boxed_message(ec_file=th_file)


# Closing message on the console.
print_boxed_message()

print("")

end_time = time.time()  # Capture the end time
elapsed_time = end_time - start_time  # Calculate the elapsed time

# Single clock reading so the logged time and date cannot disagree
# (two separate now() calls could straddle midnight).
_finished_at = datetime.now()

logging.info("Well done! GOOD LUCK!")
logging.info(f"THICK2D calculation Ended at {_finished_at.strftime('%H:%M:%S')} on {_finished_at.strftime('%Y-%m-%d')}")

print("Well done! GOOD LUCK!")
print("")

# Record the total wall-clock time of the run; the file is overwritten each run.
with open("calctime.log", 'w') as f:
    f.write(f"Calculation done in {elapsed_time:.2f} s\n")


