#!python
"""
Convert a mesh file from one format to another.
"""
from __future__ import print_function

import numpy

import meshio


def _prune(points, cells, point_data, cell_data):
    """Remove lower-order cells and orphaned nodes from a mesh.

    ``cells``, ``point_data`` and ``cell_data`` are modified in place.

    Args:
        points: array of node coordinates, one row per node.
        cells: dict mapping a cell type name to its connectivity array of
            node indices.
        point_data: dict mapping a data name to a per-node array.
        cell_data: dict keyed by cell type name.

    Returns:
        The point array with all rows removed that no remaining cell refers
        to.
    """
    # Vertices and lines are always dropped; triangles only if the mesh has
    # tetrahedra (in which case they are lower-order cells).
    removed_types = ["vertex", "line"]
    if "tetra" in cells:
        removed_types.append("triangle")
    for cell_type in removed_types:
        cells.pop(cell_type, None)
        # Data attached to cells that no longer exist must go as well.
        cell_data.pop(cell_type, None)

    # Mark every node referenced by ANY remaining cell type, not just by
    # tetrahedra; everything left unmarked is an orphan.
    keep = numpy.zeros(len(points), dtype=bool)
    for connectivity in cells.values():
        connectivity = numpy.asarray(connectivity)
        if connectivity.size:
            keep[connectivity.ravel()] = True
    orphaned_nodes = numpy.flatnonzero(~keep)

    points = numpy.delete(points, orphaned_nodes, axis=0)
    # also adapt the point data
    for key in point_data:
        point_data[key] = numpy.delete(point_data[key], orphaned_nodes, axis=0)

    # reset GLOBAL_ID
    if "GLOBAL_ID" in point_data:
        point_data["GLOBAL_ID"] = numpy.arange(1, len(points) + 1)

    # Renumber the connectivity: the new index of a kept node is the number of
    # kept nodes that precede it.
    new_index = numpy.cumsum(keep) - 1
    for cell_type in list(cells):
        connectivity = numpy.asarray(cells[cell_type])
        if connectivity.size:
            cells[cell_type] = new_index[connectivity].astype(connectivity.dtype)

    return points


def _main():
    """Read a mesh file, print a summary, optionally prune it, and write it.

    Input/output paths, the optional explicit file formats and the ``--prune``
    switch are taken from the command line (see ``_parse_options``).
    """
    # Parse command line arguments.
    args = _parse_options()

    # read mesh data
    points, cells, point_data, cell_data, field_data = meshio.read(
        args.infile, file_format=args.input_format
    )

    print("Number of points: {}".format(len(points)))
    print("Elements:")
    for tpe, elems in cells.items():
        print("  Number of {}s: {}".format(tpe, len(elems)))

    if point_data:
        print("Point data: {}".format(", ".join(point_data.keys())))

    # Collect the data names over all cell types; sort them so that the
    # printed summary does not depend on set ordering.
    cell_data_keys = set()
    for cell_type in cell_data:
        cell_data_keys.update(cell_data[cell_type].keys())
    if cell_data_keys:
        print("Cell data: {}".format(", ".join(sorted(cell_data_keys))))

    if args.prune:
        points = _prune(points, cells, point_data, cell_data)

    # Some converters (like VTK) require `points` to be contiguous.
    points = numpy.ascontiguousarray(points)

    # write it out
    meshio.write(
        args.outfile,
        points,
        cells,
        file_format=args.output_format,
        point_data=point_data,
        cell_data=cell_data,
        field_data=field_data,
    )


def _parse_options():
    """Build the command-line interface and return the parsed arguments.

    The returned namespace carries ``infile``, ``outfile``, ``input_format``,
    ``output_format`` (both ``None`` when not given) and the boolean
    ``prune``.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Convert between mesh formats.")

    # First positional: where to read from.
    cli.add_argument("infile", help="mesh file to be read from", type=str)

    # The two format overrides differ only in flags, choices and help text.
    format_options = (
        (
            "--input-format",
            "-i",
            meshio.helpers.input_filetypes,
            "input file format",
        ),
        (
            "--output-format",
            "-o",
            meshio.helpers.output_filetypes,
            "output file format",
        ),
    )
    for long_flag, short_flag, filetypes, help_text in format_options:
        cli.add_argument(
            long_flag,
            short_flag,
            type=str,
            choices=filetypes,
            help=help_text,
            default=None,
        )

    # Second positional: where to write to.
    cli.add_argument("outfile", help="mesh file to be written to", type=str)

    cli.add_argument(
        "--prune",
        "-p",
        help="remove lower order cells, remove orphaned nodes",
        action="store_true",
    )

    # `%(prog)s` is expanded by argparse itself when the version is printed.
    version_text = "%(prog)s (version {})".format(meshio.__version__)
    cli.add_argument("--version", "-v", action="version", version=version_text)

    parsed = cli.parse_args()
    return parsed


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    _main()
