#!/usr/bin/env python3
"""
Sample script that exports Cepton LiDAR points to files, either one file per
frame or (with --combine) one file per sensor. It can be modified to export to
any file format.
"""

import argparse
import collections
import os.path
import shutil

import numpy

import cepton_sdk
import cepton_sdk.capture_replay
import cepton_sdk.export
from cepton_util.common import *


def process_points(points):
    """Hook for custom processing of points before they are exported.

    Currently a pass-through: the input object is returned unchanged.
    Edit this function to filter or transform points prior to saving.
    """
    # TODO: process points before exporting
    processed = points
    return processed


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        usage="%(prog)s [OPTIONS] output_dir",
        description="Exports Cepton LiDAR points from a sensor or capture file. By default, exports frames to individual files.")
    parser.add_argument("output_dir", help="Output directory.")
    parser.add_argument("--capture_path", help="Path to PCAP capture file.")
    parser.add_argument("--capture_seek", type=float,
                        help="Capture file start position [seconds].")
    all_file_types = [x.name for x in cepton_sdk.export.PointsFileType]
    # argparse applies `type` before checking `choices`, so upper-casing here
    # makes the option case-insensitive ("las" and "LAS" both work).
    parser.add_argument("--format", default="LAS", type=str.upper,
                        choices=all_file_types, help="Output file format.")
    parser.add_argument("--combine", action="store_true",
                        help="Combine points into single file per sensor.")
    parser.add_argument("--t_length", default="1", help="Maximum export time.")
    return parser.parse_args()


def _save_image_points(image_points, path, file_type, max_distance=1e4):
    """Filter, convert, post-process and save one set of image points.

    Args:
        image_points: SDK image-points object (exposes ``distances``,
            boolean-mask indexing and ``to_points()``).
        path: Output path passed to ``cepton_sdk.export.save_points``.
        file_type: ``cepton_sdk.export.PointsFileType`` member.
        max_distance: Points whose distance is not below this value are
            dropped before export.
    """
    # NOTE(review): 1e4 appears to be a threshold that rejects invalid /
    # out-of-range returns; units presumably match `distances` — confirm.
    is_valid = image_points.distances < max_distance
    points = image_points[is_valid].to_points()
    points = process_points(points)
    cepton_sdk.export.save_points(points, path, file_type=file_type)


def _export_combined(listener, output_dir, file_type, t_length):
    """Wait ``t_length``, then write all collected points as one file per sensor."""
    cepton_sdk.wait(t_length)
    image_points_dict = listener.get_points()
    for serial_number, image_points_list in image_points_dict.items():
        image_points = cepton_sdk.combine_points(image_points_list)
        path = os.path.join(output_dir, str(serial_number))
        _save_image_points(image_points, path, file_type)


def _export_frames(listener, output_dir, file_type, t_length):
    """Write each frame to ``<output_dir>/<serial_number>/<frame_index>``.

    Runs until ``t_length`` has elapsed (only when ``t_length > 0``) or until
    ``listener.get_points()`` raises an exception.
    """
    t_0 = cepton_sdk.get_time()
    # Per-sensor running frame index, used as the output file name.
    frame_counts = collections.defaultdict(int)
    while t_length <= 0 or (cepton_sdk.get_time() - t_0) <= t_length:
        try:
            image_points_dict = listener.get_points()
        except Exception:
            # Any SDK error ends the export (presumably this includes reaching
            # the end of a capture replay — confirm). Catching Exception rather
            # than everything lets KeyboardInterrupt / SystemExit propagate.
            break
        for serial_number, image_points_list in image_points_dict.items():
            sensor_dir = os.path.join(output_dir, str(serial_number))
            os.makedirs(sensor_dir, exist_ok=True)
            for image_points in image_points_list:
                path = os.path.join(
                    sensor_dir, str(frame_counts[serial_number]))
                _save_image_points(image_points, path, file_type)
                frame_counts[serial_number] += 1


def main():
    """Parse arguments, initialize the SDK and export points to ``output_dir``."""
    args = _parse_args()

    file_type = cepton_sdk.export.PointsFileType[args.format]
    t_length = parse_time_hms(args.t_length)

    # WARNING: any existing directory at output_dir is deleted first.
    output_dir = fix_path(remove_extension(args.output_dir))
    shutil.rmtree(output_dir, ignore_errors=True)
    os.makedirs(output_dir)

    # Initialize
    options = {}
    if args.capture_path is not None:
        options["capture_path"] = fix_path(args.capture_path)
    cepton_sdk.initialize(**options)
    if args.capture_seek is not None:
        cepton_sdk.capture_replay.seek(args.capture_seek)

    listener = cepton_sdk.ImageFramesListener()
    if args.combine:
        _export_combined(listener, output_dir, file_type, t_length)
    else:
        _export_frames(listener, output_dir, file_type, t_length)


if __name__ == "__main__":
    # Run the exporter only when executed as a script, not when imported.
    main()
