#!python
"""HybridQ: A Hybrid Simulator for Quantum Circuits.

    Usage:
        hybridq-dm <circuit_filename> <pauli_string> <output_filename> [options]
        hybridq-dm -s <pauli_string> -o <output_filename> [options]
        hybridq-dm (-h|--help)
        hybridq-dm --version

    Options:
        -h,--help                                                Show this help.
        -p,--params=<params_filename>                            Load parameters from file.
        -c,--circuit-filename=<circuit_filename>                 Circuit filename [default: stdin].
        -o,--output-filename=<output_filename>                   Output filename.
        -s,--pauli-string=<pauli_string>                         Pauli string to use.
        --compress=<compress>                                    Level of gate compression [default: auto].
        --parallel                                               Run some tasks in parallel.
        --return-info                                            Return extra info (if available).
        --use-mpi                                                Use MPI.
        --append                                                 Append to an existing file.
        --verbose                                                Verbose output.
        --version                                                Show version.

    Author: Salvatore Mandra (salvatore.mandra@nasa.gov)

    Copyright © 2021, United States Government as represented by the
    Administrator of the National Aeronautics and Space Administration. All
    rights reserved.
    
    The HybridQ: A Hybrid Simulator for Quantum Circuits framework is licensed
    under the Apache License, Version 2.0 (the "License"); you may not use this
    application except in compliance with the License. You may obtain a copy of
    the License at http://www.apache.org/licenses/LICENSE-2.0. 
    
    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    License for the specific language governing permissions and limitations
    under the License.
"""

from hybridq.gate import Gate
from hybridq.circuit import Circuit
from hybridq.circuit.simulation import clifford
from hybridq.extras.io import qasm
from hybridq.circuit import utils
from warnings import warn
from docopt import docopt
from time import time
import numexpr as ne
import numpy as np
import json
import sys
import os

if __name__ == '__main__':
    # Script entry point. The steps are:
    #   1. parse the command line with docopt (the module docstring is the
    #      CLI specification),
    #   2. merge command-line options with optional JSON parameters,
    #   3. validate the output file, the circuit and the Pauli string,
    #   4. call `clifford.update_pauli_string` with the collected parameters,
    #   5. dump the result and the runtime as JSON.
    # All diagnostic ("# ...") lines are written to stderr so that stdout
    # only ever carries the JSON result when the output filename is '-'.

    # Get arguments from docopt
    args = docopt(__doc__)

    # If --version, dump version and exit
    if args['--version']:
        from importlib.metadata import version
        print(version('hybridq'))
        sys.exit()

    # Fix args: a positional argument, when given, takes precedence over
    # the corresponding option.
    args['--pauli-string'] = args['<pauli_string>'] if args[
        '<pauli_string>'] else args['--pauli-string']
    args['--circuit-filename'] = args['<circuit_filename>'] if args[
        '<circuit_filename>'] else args['--circuit-filename']
    args['--output-filename'] = args['<output_filename>'] if args[
        '<output_filename>'] else args['--output-filename']

    # 'auto' means "do not pass any compression level at all"
    if args['--compress'] == 'auto':
        del (args['--compress'])

    # Initialize params: keep only the '--option' entries that have a value,
    # renamed from '--some-option' to 'some_option'.
    params = {
        k[2:].replace('-', '_'): v
        for k, v in args.items()
        if k[:2] == '--' and v is not None
    }

    # Load params from file
    if args['--params']:

        # Check if args['--params'] is actually an existing filename
        if os.path.exists(args['--params']):
            with open(args['--params']) as file:
                _params = json.loads(file.read())

        # Otherwise, try to convert it to dict
        else:
            _params = json.loads(args['--params'])

        # Update params (values loaded here override the command line)
        params.update({k.replace('-', '_'): v for k, v in _params.items()})

    # NOTE(review): the two conversions below pass user-provided text (from
    # the command line or from the JSON parameters) to `ne.evaluate`, which
    # evaluates it as an expression. Confirm this is acceptable for the
    # inputs this script is expected to receive.

    # Convert to integer
    for k in ['compress']:
        if k in params:
            params[k] = int(ne.evaluate(str(params[k])))

    # Convert to bool
    for k in ['use_mpi', 'return_info']:
        if k in params:
            params[k] = bool(ne.evaluate(str(params[k])))

    ######################## FEW CHECKS ########################

    # If MPI, cannot read from stdin
    if params['use_mpi'] and (params['circuit_filename'] == 'stdin'):
        raise ValueError("Reading from stdin is not allowed when --use-mpi.")

    # Load MPI
    if params['use_mpi']:

        from mpi4py import MPI
        mpi_comm = MPI.COMM_WORLD
        mpi_size = mpi_comm.Get_size()
        mpi_rank = mpi_comm.Get_rank()

    # Check if it is possible to write on file ('-' means stdout)
    if params['output_filename'] != '-':
        if not os.access(params['output_filename'], os.W_OK):

            # Check if output filename can be created
            try:
                with open(params['output_filename'], 'w') as file:
                    file.write('')

            # Only I/O failures are reported as "cannot write"; anything
            # else (e.g. KeyboardInterrupt) propagates untouched.
            except OSError as err:

                print(err, file=sys.stderr)
                raise ValueError(
                    f"Cannot write to file '{params['output_filename']}'."
                ) from err

        else:

            # If file already exists and --append was not requested, warn
            # that its content is going to be overwritten.
            if not params['append']:

                warn(f"File '{params['output_filename']}' already exists and "
                     "it will be overwritten. If this is not the intended "
                     "behavior, use --append instead.")

    ####################### LOAD CIRCUIT #######################

    if params['circuit_filename'] == 'stdin':
        params['circuit'] = qasm.from_qasm(sys.stdin.read())
    else:
        with open(params['circuit_filename']) as file:
            params['circuit'] = qasm.from_qasm(file.read())

    # Check Pauli string
    if set(params['pauli_string']).difference(list('IXYZ')):
        raise ValueError(
            'The only allowed symbols for the Pauli strings are {I,X,Y,Z}.')
    if len(params['pauli_string']) != len(params['circuit'].all_qubits()):
        raise ValueError(
            'Pauli string must be long as the number of qubits in circuit.')

    # Initialize Pauli string: the i-th symbol is applied to the i-th qubit
    # returned by `circuit.all_qubits()`.
    params['pauli_string'] = Circuit(
        Gate(p, [q])
        for p, q in zip(params['pauli_string'], params['circuit'].all_qubits()))

    # Check that circuit and pauli_string have the exact same qubits
    assert (
        params['circuit'].all_qubits() == params['pauli_string'].all_qubits())

    ######################### DUMP INFO ########################

    # Only one process reports (rank 0 when MPI is used)
    if params['verbose'] and (not params['use_mpi'] or mpi_rank == 0):
        for k, v in params.items():
            if k == 'circuit':
                # The circuit itself is not printed
                pass
            elif k == 'output_filename':
                _file = params['output_filename']
                print('# Output Filename:',
                      _file if _file != '-' else 'stdout',
                      file=sys.stderr)
            elif k == 'pauli_string':
                _pauli_string = ''.join(g.name for g in params['pauli_string'])
                print('# Pauli String:', _pauli_string, file=sys.stderr)
            else:
                # 'some_option' is shown as 'Some Option'
                print('#',
                      ' '.join([
                          x.capitalize() for x in k.replace('_', ' ').split()
                      ]),
                      file=sys.stderr,
                      end=': ')
                print(v, file=sys.stderr)
        print(f'# Number of qubits: {len(params["circuit"].all_qubits())}',
              file=sys.stderr)

    ########################## SIMULATE ########################

    # Initialize results
    results = {}

    # Start clock
    _start = time()

    # Simulate (every collected parameter is forwarded as a keyword argument)
    results['simulate'] = clifford.update_pauli_string(**params)

    # Stop clock
    _end = time()

    # Add runtime
    results['runtime (s)'] = _end - _start

    # Print runtime. This goes to stderr like every other verbose line, so
    # that it never gets mixed with the JSON dumped to stdout below.
    if params['verbose'] and (not params['use_mpi'] or mpi_rank == 0):
        print(f'# Runtime (s): {results["runtime (s)"]:1.4f}', file=sys.stderr)

    # Dump results (only rank 0 when MPI is used)
    if not params['use_mpi'] or mpi_rank == 0:

        # Serialize results
        results = json.dumps(results, indent=2)

        if params['output_filename'] == '-':
            print(results, file=sys.stdout)
        else:
            with open(params['output_filename'],
                      'a' if params['append'] else 'w') as file:
                print(results, file=file)
