#!/usr/bin/env python3


# This Python script distributes the simulation of Gaussian Random Field
# realisations over multiple processors, using SPAM and Martin Schlather's
# RandomFields package: https://cran.r-project.org/package=RandomFields
# Copyright (C) 2021 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

"""
################################
# CONFIGURATION VARIABLES: START
# paste these variables into a .env file or export them as environment variables
# any commented-out or omitted variable takes its default value

##########################
# Random fields parameters
# Correlation length of the covariance model
# default: 0.1
# c_length = 0.1
# Number of nodes in each dimension
# default: 50
# n_nodes = 100
# Spatial dimension of the field
# default: 3
# dim = 3
# This is the total number of realisations
# default: 1
# n_rea = 5

############################
# Multiprocessing parameters
# The realisations can be distributed over several processes
# (each process repeats the SVD)
# default: 1
# n_proc = 2
# To limit the memory needed per process, a maximum number of realisations
# per process can be set; the realisations are then generated in batches
# (each batch repeats the SVD)
# default: n_rea
# max_rea_per_proc = 5

##################
# Files parameters
# Specify whether the log output is written to the shell (True) or to a log
# file (False).
# default: True
# shell = False
# Possible values: DEBUG / INFO / WARNING / ERROR / CRITICAL
# default: INFO
# log_level = INFO
# Type of the saved arrays. <f8: float64, <f4: float32, <f2: float16
# default: <f4
# dtype = <f4
# Path to where data files are stored. One file per process is created, with
# a random slug appended to its name to avoid overwriting data.
# Takes a path relative to the current directory or an absolute path.
# If not defined, the main folder is used.
# default: None
# data_folder = data
# Merge all process files into a single file (no random slug appended).
# The realisations are appended to a pre-existing file, if any.
# default: True
# merge_file = False

# CONFIGURATION VARIABLES: END
##############################
"""


import logging
import math
import multiprocessing
import os
import random
import string
import sys

import h5py
import numpy
import spam.excursions
from decouple import AutoConfig

# ignore numpy floating-point warnings raised during the simulations
numpy.seterr(all="ignore")

# use the fork start method so that worker processes inherit the parent's
# state; set_start_method raises RuntimeError if the context is already set
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

# decouple's AutoConfig searches for a settings.ini or .env file, starting
# from the first command-line argument if given, otherwise from this
# script's directory
if len(sys.argv) == 1:
    config = AutoConfig()
else:
    config = AutoConfig(search_path=sys.argv[1])

# read the configuration variables into a dictionary
CONF = {}
CONF["c_length"] = config("c_length", cast=float, default=0.1)
CONF["n_nodes"] = config("n_nodes", cast=int, default=50)
CONF["dim"] = config("dim", cast=int, default=3)
CONF["n_rea"] = config("n_rea", cast=int, default=1)
CONF["n_proc"] = config("n_proc", cast=int, default=1)
CONF["max_rea_per_proc"] = config("max_rea_per_proc", cast=int, default=CONF["n_rea"])
CONF["shell"] = config("shell", cast=bool, default=True)
CONF["log_level"] = config("log_level", default="INFO")
CONF["dtype"] = config("dtype", default="<f4")
CONF["merge_file"] = config("merge_file", cast=bool, default=True)
CONF["data_folder"] = config("data_folder", default=None)

# get main folder based on settings folder
CONF["main_folder"] = (
    os.getcwd() if len(sys.argv) == 1 else os.path.abspath(sys.argv[1])
)
if not os.path.isdir(CONF["main_folder"]):
    os.makedirs(CONF["main_folder"])

# create data folder
CONF["data_folder"] = (
    os.path.abspath(CONF["data_folder"]) if CONF["data_folder"] else CONF["main_folder"]
)
if not os.path.isdir(CONF["data_folder"]):
    os.makedirs(CONF["data_folder"])

# define logger
log_format = "[%(asctime)s - %(name)s - %(levelname)s]  %(message)s"
if CONF["shell"]:
    CONF["log_file"] = None
else:
    CONF["log_file"] = os.path.join(
        CONF["main_folder"], f"{os.path.splitext(os.path.basename(__file__))[0]}.log"
    )
logging.basicConfig(
    format=log_format,
    filename=CONF["log_file"],
    filemode="w",
    level=getattr(logging, CONF["log_level"].upper()),
)
log = logging.getLogger("grf")

# create pool distribution
pool_distribution = [CONF["n_rea"] // CONF["n_proc"]] * CONF["n_proc"]
for i in range(CONF["n_rea"] % CONF["n_proc"]):
    pool_distribution[i] += 1
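# e.g. n_rea = 5 over n_proc = 2 gives pool_distribution = [3, 2]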

# dump configurations in the log
log.info(f"Working Directory: {os.getcwd()}")
log.info(f"Script: {__file__}")
for k, v in CONF.items():
    log.info(f"Configurations: {k} = {v}")
    print(f"Configurations: {k} = {v}")

# WORKER FUNCTION
def gen_rf(n):
    """Generate n realisations in batches and save them to one HDF5 file.

    Returns the path of the file, or None if n == 0.
    """
    # don't run if n == 0
    if not n:
        return

    # get process info and slug
    process = multiprocessing.current_process()
    process_readable_id = process.name.split("-")[-1]
    slug = "".join([random.choice(string.ascii_lowercase) for _ in range(8)])
    log = logging.getLogger(f"grf.process.{process_readable_id}")

    # name the HDF5 file used to save the data (add slug to avoid overwriting data)
    h5_file = os.path.join(
        CONF["data_folder"],
        f'grf-lc{CONF["c_length"]}-n{CONF["n_nodes"]}-rea{n}-{slug}.h5',
    )

    # create batch list
    n_batch = math.ceil(n / CONF["max_rea_per_proc"])
    batch = [n // n_batch] * n_batch
    for i in range(n % n_batch):
        batch[i] += 1
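    # e.g. n = 5 with max_rea_per_proc = 2 gives n_batch = 3 and batch = [2, 2, 1]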

    log.info(f"Generate {n:02d} realisations in {n_batch} batch: {h5_file}")

    for i_batch in range(n_batch):

        n_current = batch[i_batch]
        n_past = sum(batch[:i_batch])

        log.debug(
            f"\t - Batch {i_batch + 1:02d}/{n_batch:02d} "
            f"({n_current + n_past:02d}/{n:02d} rea): simulate"
        )

        # define covariance
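        # ("stable" is the powered-exponential covariance family; alpha = 2
        # corresponds to the Gaussian covariance model)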
        covariance = {
            "type": "stable",
            "alpha": 2.0,
            "variance": 1.0,
            "correlation_length": CONF["c_length"],
        }

        # generate realisations
        realisations = spam.excursions.simulateRandomField(
            nNodes=CONF["n_nodes"],
            covariance=covariance,
            dim=CONF["dim"],
            nRea=n_current,
            RprintLevel=0,
            # vtkFile=f'grf-lc{CONF["c_length"]}-{process.pid}'
        )

        # expand axis in case of single realisation for homogeneous input
        if len(realisations.shape) == CONF["dim"]:
            realisations = numpy.expand_dims(realisations, axis=CONF["dim"])
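        # (e.g. a single 3D realisation of shape (50, 50, 50) becomes
        # (50, 50, 50, 1))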

        # dump the realisations into the HDF5 file
        log.debug(
            f"\t - Batch {i_batch + 1:02d}/{n_batch:02d} "
            f"({n_current + n_past:02d}/{n:02d} rea): write"
        )
        with h5py.File(h5_file, "a") as f:
            for i in range(realisations.shape[-1]):
                realisation = realisations[..., i]
                data_set = f.create_dataset(
                    f"{i + n_past:05d}",
                    realisation.shape,
                    dtype=str(CONF["dtype"]),
                    data=realisation,
                )
                for k in ["dim", "c_length", "n_nodes", "dtype"]:
                    data_set.attrs[k] = CONF.get(k)

        log.debug(
            f"\t - Batch {i_batch + 1:02d}/{n_batch:02d} "
            f"({n_current + n_past:02d}/{n:02d} rea): end"
        )

    return h5_file


# distribute the realisations over a pool of processes
with multiprocessing.Pool(CONF["n_proc"]) as p:
    files = p.map(gen_rf, pool_distribution)
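# map blocks until every process is done; gen_rf returns None for processes
# that received no realisations, and those entries are filtered out below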

# merge files
if CONF["merge_file"]:
    h5_file = os.path.join(
        CONF["data_folder"], f'grf-lc{CONF["c_length"]}-n{CONF["n_nodes"]}.h5'
    )
    log.info(f"Merging data into a single file: {h5_file}")
    with h5py.File(h5_file, "a") as f_write:

        # get the last realisation number so the numbering continues from the
        # pre-existing datasets (keys are zero-based, see gen_rf)
        i_rea = -1
        for k in f_write.keys():
            i_rea = max(i_rea, int(k))
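        # e.g. existing keys "00000"..."00004" give i_rea = 4, so the next
        # dataset written below is "00005"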

        # loop over the process files (skip None entries from empty processes)
        for file in [f for f in files if f]:

            # open the process file
            with h5py.File(file, "r") as f_read:
                log.info(f"Merging file: {file} >> {h5_file}")

                # append datasets to the merged file
                for rea in f_read.values():
                    i_rea += 1
                    data = rea[:]
                    data_set = f_write.create_dataset(
                        f"{i_rea:05d}", data.shape, data=data, dtype=data.dtype
                    )
                    for k, v in rea.attrs.items():
                        data_set.attrs[k] = v

            # delete the process file once its content has been merged
            os.remove(file)

    log.info(f"Merging data into a single file: done")
