#!python
"""HybridQ: A Hybrid Simulator for Quantum Circuits.

    Usage:
        hybridq <circuit_filename> <output_filename> [options]
        hybridq <output_filename> [options]
        hybridq -o <output_filename> [options]
        hybridq (-h|--help)
        hybridq --version

    Options:
        -h,--help                                                Show this help.
        -p,--params=<params_filename>                            Load parameters from file.
        -c,--circuit-filename=<circuit_filename>                 Circuit filename [default: stdin].
        -o,--output-filename=<output_filename>                   Output filename.
        --initial-state=<initial_state>                          Initial state [default: 0].
        --final-state=<final_state>                              Final state [default: .].
        --optimize=<optimize>                                    Choose the optimization method [default: evolution].
        --backend=<backend>                                      Choose backend [default: numpy].
        --parallel                                               Run some tasks in parallel.
        --compress=<compress>                                    Level of gate compression [default: auto].
        --max-iterations=<max_iterations>                        Max iterations for optimal contraction (tensor contraction only) [default: 2].
        --max-repeats=<max_repeats>                              Max repeats for optimal contraction (tensor contraction only) [default: 32].
        --max-largest-intermediate=<max_largest_intermediate>    Max largest intermediate (tensor contraction only) [default: 67108864].
        --max-n-slices=<max_n_slices>                            Max number of slices to use (tensor contraction only).
        --tensor-only                                            Just return tensor + contraction (tensor contraction only).
        --complex-type=<complex_type>                            Complex type to be used [default: complex64].
        --return-info                                            Return extra info (if available).
        --use-mpi                                                Use MPI.
        --atol=<atol>                                            Absolute tolerance [default: 1e-8].
        --append                                                 Append to an existing file.
        --verbose                                                Verbose output.
        --version                                                Show version.

    Author: Salvatore Mandra (salvatore.mandra@nasa.gov)

    Copyright © 2021, United States Government as represented by the
    Administrator of the National Aeronautics and Space Administration. All
    rights reserved.
    
    The HybridQ: A Hybrid Simulator for Quantum Circuits framework is licensed
    under the Apache License, Version 2.0 (the "License"); you may not use this
    application except in compliance with the License. You may obtain a copy of
    the License at http://www.apache.org/licenses/LICENSE-2.0. 
    
    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    License for the specific language governing permissions and limitations
    under the License.
"""

from hybridq.gate import Gate
from hybridq.circuit import Circuit
from hybridq.circuit.simulation import simulate
from hybridq.extras.io import qasm
from hybridq.circuit import utils
from warnings import warn
from docopt import docopt
from time import time
import numexpr as ne
import numpy as np
import pickle
import json
import sys
import os


def _get_state(state: str, n_qubits: int) -> str:
    """Expand and validate a state specification.

    A single character among ``0``, ``1``, ``+``, ``-`` and ``.`` is repeated
    ``n_qubits`` times. Any other string is returned unchanged, provided that
    it has exactly one character per qubit.

    Args:
        state: State specification, either one of the single characters above
            or a string with one character per qubit.
        n_qubits: Expected number of qubits.

    Returns:
        A string of length ``n_qubits``.

    Raises:
        ValueError: If ``state`` is not a string, or if its length (after the
            single-character expansion) differs from ``n_qubits``.
    """

    # Check state is a valid string
    if not isinstance(state, str):
        raise ValueError("'state' must be a string.")

    # If single char, extend to the number of qubits
    if len(state) == 1 and state in '01+-.':
        state = state * n_qubits

    # Check number of qubits
    if len(state) != n_qubits:
        raise ValueError('Wrong number of qubits '
                         f'(expected {n_qubits}, got {len(state)}).')

    # Return state
    return state


def _convert_to(params: dict, kw: list, t: type):
    """Convert a selection of parameters to a given type.

    Every key of ``kw`` that is present in ``params`` has its value turned
    into a string, evaluated with ``numexpr`` and finally passed to ``t``.
    Keys of ``kw`` that are missing from ``params`` are ignored, and
    ``params`` itself is left untouched.

    Args:
        params: Mapping holding the raw values.
        kw: Keys to convert.
        t: Callable applied to each evaluated value.

    Returns:
        A new dict containing only the converted entries.
    """

    # Evaluate and convert only the requested keys that are actually present
    return {
        key: t(ne.evaluate(str(params[key]))) for key in kw if key in params
    }


def _leading_basis_states(n_qubits: int, n_leading: int = 3):
    """Enumerate the basis states spanned by the first few bitstring positions.

    Every combination of the first ``n_leading`` characters of an
    ``n_qubits``-long bitstring is generated, while all the remaining
    characters are kept to ``0``. If ``n_qubits`` is smaller than
    ``n_leading``, only ``2**n_qubits`` states are generated.

    Args:
        n_qubits: Length of the bitstrings.
        n_leading: Number of leading positions to enumerate.

    Returns:
        A list of ``(bitstring, index)`` pairs, with ``index`` being
        ``int(bitstring, 2)``.
    """

    # Never enumerate more positions than available
    n_leading = min(n_leading, max(n_qubits, 0))

    # Initialize
    states = []

    for value in range(2**n_leading):

        # Write 'value' with its least significant bit first, so that it
        # occupies the leading positions of the bitstring
        bits = bin(value)[2:].zfill(n_qubits)[::-1]
        states.append((bits, int(bits, 2)))

    # Return
    return states


if __name__ == '__main__':

    # Get arguments from docopt
    args = docopt(__doc__)

    # If --version, dump version and exit
    if args['--version']:
        from importlib.metadata import version
        print(version('hybridq'))
        sys.exit()

    # Positional arguments, if provided, take precedence over the options
    args['--circuit-filename'] = args['<circuit_filename>'] if args[
        '<circuit_filename>'] else args['--circuit-filename']
    args['--output-filename'] = args['<output_filename>'] if args[
        '<output_filename>'] else args['--output-filename']

    # Initialize params (only '--' options which have been given a value)
    params = {
        k[2:].replace('-', '_'): v
        for k, v in args.items()
        if k[:2] == '--' and v is not None
    }

    # Load params from file
    if args['--params']:

        # Check if args['--params'] is actually an existing filename
        if os.path.exists(args['--params']):
            with open(args['--params']) as file:
                _params = json.loads(file.read())

        # Otherwise, try to convert it to dict
        else:
            _params = json.loads(args['--params'])

        # Update params (values from --params override the command line)
        params.update({k.replace('-', '_'): v for k, v in _params.items()})

    # 'auto' means that no compression level is provided. This is checked
    # after merging --params, so that 'auto' is accepted from there as well
    # instead of failing the conversion to int below
    if params.get('compress') == 'auto':
        del params['compress']

    # Convert to the given type
    params.update(
        _convert_to(params, [
            'max_largest_intermediate', 'max_repeats', 'max_iterations',
            'compress', 'max_n_slices'
        ], int))
    params.update(_convert_to(params, ['atol'], float))
    params.update(
        _convert_to(params, ['use_mpi', 'tensor_only', 'return_info'], bool))

    ######################## FEW CHECKS ########################

    # If MPI, cannot read from stdin
    if params['use_mpi'] and (params['circuit_filename'] == 'stdin'):
        raise ValueError("Reading from stdin is not allowed when --use-mpi.")

    # Without MPI, the only process acts as rank 0
    mpi_rank = 0

    # Load MPI
    if params['use_mpi']:

        from mpi4py import MPI
        mpi_comm = MPI.COMM_WORLD
        mpi_size = mpi_comm.Get_size()
        mpi_rank = mpi_comm.Get_rank()

    # Only rank 0 prints info and dumps results
    is_root = mpi_rank == 0

    # Check if it is possible to write on file
    if not os.access(params['output_filename'], os.W_OK):

        # Check if output filename can be created. The file is opened in
        # append mode so that the probe never truncates existing content
        try:
            with open(params['output_filename'], 'a'):
                pass
        except OSError as err:

            print(err, file=sys.stderr)
            raise ValueError(
                f"Cannot write to file '{params['output_filename']}'."
            ) from err

    # If file already exists and --append is not used, warn that it is going
    # to be overwritten
    elif not params['append']:

        warn(f"File '{params['output_filename']}' already exists and "
             "it will be overwritten. If this is not the intended "
             "behavior, use --append instead.")

    ####################### LOAD CIRCUIT #######################

    if params['circuit_filename'] == 'stdin':
        params['circuit'] = qasm.from_qasm(sys.stdin.read())
    else:
        with open(params['circuit_filename']) as file:
            params['circuit'] = qasm.from_qasm(file.read())

    ##################### INITIALIZE STATES ####################

    params['initial_state'] = _get_state(params['initial_state'],
                                         len(params['circuit'].all_qubits()))
    params['final_state'] = _get_state(params['final_state'],
                                       len(params['circuit'].all_qubits()))

    ######################### DUMP INFO ########################

    if params['verbose'] and is_root:
        for k, v in params.items():
            if k != 'circuit':
                print('#',
                      ' '.join([
                          x.capitalize() for x in k.replace('_', ' ').split()
                      ]),
                      file=sys.stderr,
                      end=': ')
                print(v, file=sys.stderr)
        print(f'# Number of qubits: {len(params["circuit"].all_qubits())}',
              file=sys.stderr)

    ########################## SIMULATE ########################

    # Initialize results
    results = {}

    # Start clock
    _start = time()

    # Simulate (all the collected params are forwarded as keywords)
    results['simulate'] = simulate(**params)

    # Stop clock
    _end = time()

    # Add runtime
    results['runtime (s)'] = _end - _start

    # Print few amplitudes and runtime
    if params['verbose'] and is_root:
        if 'evolution' in params['optimize']:
            psi = results['simulate'][0] if params['return_info'] else results[
                'simulate'].ravel()
            _nq = len(params['initial_state'])
            for bits, x in _leading_basis_states(_nq):

                # Long bitstrings are shortened to the first three characters
                # (all the remaining ones are zeros)
                label = (bits[:3] + '0...0') if _nq > 5 else bits
                print(
                    f'{label}: {psi[x]:+1.5e} '
                    f'(norm^2={np.linalg.norm(psi[x])**2:1.5e})',
                    file=sys.stderr)
        print(f'# Runtime (s): {results["runtime (s)"]:1.4f}', file=sys.stderr)

    # If tensor_only==True, remove opt._pool to allow serialization
    if params['tensor_only'] and params['parallel']:
        del (results['simulate'][1][1]._pool)

    # Dump results
    if is_root:

        # Dump results to file
        with open(params['output_filename'],
                  'ab' if params['append'] else 'wb') as file:
            pickle.dump(results, file)
