#!env python

import sys

try:
    import pykitti
except ImportError as e:
    print('Could not load module \'pykitti\'. Please run `pip install pykitti`')
    sys.exit(1)

import tf
import os
import cv2
import rospy
import rosbag
from tf2_msgs.msg import TFMessage
from datetime import datetime
from sensor_msgs.msg import Image, Imu
from geometry_msgs.msg import TransformStamped
from cv_bridge import CvBridge, CvBridgeError


def save_imu_data(bag, kitti, imu_frame_id, topic):
    """Write one Imu message per OXTS sample into the bag.

    ``kitti.timestamps`` and ``kitti.oxts`` are walked in lockstep; every
    pair becomes a single message stamped with the sample's timestamp.

    Args:
        bag: Open bag object; only its ``write`` method is used.
        kitti: Dataset object exposing ``timestamps`` and ``oxts``.
        imu_frame_id: Value stored in each message's ``header.frame_id``.
        topic: Topic name the messages are written under.
    """
    print("Exporting IMU")
    for stamp_dt, sample in zip(kitti.timestamps, kitti.oxts):
        packet = sample.packet
        quat = tf.transformations.quaternion_from_euler(packet.roll, packet.pitch, packet.yaw)

        msg = Imu()
        msg.header.frame_id = imu_frame_id
        # NOTE(review): "%s" (seconds since the epoch) is not among the
        # strftime directives documented by Python; it presumably comes from
        # the platform C library - confirm on the target platform.
        msg.header.stamp = rospy.Time.from_sec(float(stamp_dt.strftime("%s.%f")))

        # Orientation: quaternion components in x, y, z, w order.
        for index, axis in enumerate('xyzw'):
            setattr(msg.orientation, axis, quat[index])

        # Acceleration comes from the af/al/au fields, angular rate from
        # wf/wl/wu, mapped onto x/y/z in that order.
        accelerations = (packet.af, packet.al, packet.au)
        rates = (packet.wf, packet.wl, packet.wu)
        for axis, accel, rate in zip('xyz', accelerations, rates):
            setattr(msg.linear_acceleration, axis, accel)
            setattr(msg.angular_velocity, axis, rate)

        bag.write(topic, msg, t=msg.header.stamp)


def save_camera_data(bag, kitti, bridge, camera, camera_frame_id, topic):
    """Write all images of one camera into the bag as image messages.

    Images are read from ``<kitti.data_path>/image_NN/data`` (file names in
    sorted order) and paired, in order, with the lines of
    ``<kitti.data_path>/image_NN/timestamps.txt``. Cameras 0 and 1 are
    converted to single-channel images and written as ``mono8``; all other
    cameras are written as ``bgr8``.

    Args:
        bag: Open bag object; only its ``write`` method is used.
        kitti: Dataset object exposing ``data_path``.
        bridge: Object providing ``cv2_to_imgmsg`` (a CvBridge in this file).
        camera: Integer camera index, used for the ``image_NN`` directory.
        camera_frame_id: Value stored in each message's ``header.frame_id``.
        topic: Topic name the messages are written under.

    Raises:
        IOError: If OpenCV cannot read one of the image files.
        ValueError: If a non-blank timestamp line does not match
            ``%Y-%m-%d %H:%M:%S.%f``.
    """
    print("Exporting camera {}".format(camera))
    camera_pad = '{0:02d}'.format(camera)
    image_path = os.path.join(kitti.data_path, 'image_{}'.format(camera_pad))
    img_data_dir = os.path.join(image_path, 'data')
    img_filenames = sorted(os.listdir(img_data_dir))

    # %f accepts at most six fractional digits, so anything beyond
    # microsecond precision is cut off. Stripping first makes this work
    # regardless of the line ending (LF, CRLF or none on the last line);
    # blank lines are skipped instead of aborting the export.
    stamp_length = len('YYYY-MM-DD HH:MM:SS.ffffff')
    with open(os.path.join(image_path, 'timestamps.txt')) as f:
        img_datetimes = [
            datetime.strptime(line.strip()[:stamp_length], '%Y-%m-%d %H:%M:%S.%f')
            for line in f
            if line.strip()
        ]

    # The encoding depends only on the camera index, so decide once.
    is_gray = camera in (0, 1)
    encoding = "mono8" if is_gray else "bgr8"

    for dt, filename in zip(img_datetimes, img_filenames):
        img_filename = os.path.join(img_data_dir, filename)
        cv_image = cv2.imread(img_filename)
        if cv_image is None:
            # imread signals failure by returning None rather than raising;
            # fail here with the file name instead of deeper inside
            # cvtColor / cv_bridge.
            raise IOError('Could not read image {}'.format(img_filename))
        if is_gray:
            cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        image_message = bridge.cv2_to_imgmsg(cv_image, encoding=encoding)
        image_message.header.frame_id = camera_frame_id
        image_message.header.stamp = rospy.Time.from_sec(float(dt.strftime("%s.%f")))
        bag.write(topic, image_message, t=image_message.header.stamp)


def get_static_transform(bag, time, from_frame_id, to_frame_id, transform):
    """Build a TransformStamped message from a transformation matrix.

    Args:
        bag: Not used by this function; kept in the signature for callers.
        time: Stamp stored in the message header.
        from_frame_id: Stored as ``header.frame_id``.
        to_frame_id: Stored as ``child_frame_id``.
        transform: Matrix handed to ``quaternion_from_matrix`` for the
            rotation; rows 0-2 of its column 3 supply the translation.

    Returns:
        The populated TransformStamped message.
    """
    quat = tf.transformations.quaternion_from_matrix(transform)
    offset = transform[0:3, 3]

    stamped = TransformStamped()
    stamped.header.stamp = time
    stamped.header.frame_id = from_frame_id
    stamped.child_frame_id = to_frame_id

    # Values are cast to plain Python floats before being stored.
    for index, axis in enumerate('xyz'):
        setattr(stamped.transform.translation, axis, float(offset[index]))
    for index, axis in enumerate('xyzw'):
        setattr(stamped.transform.rotation, axis, float(quat[index]))

    return stamped


def save_static_transforms(bag, transforms, min_time):
    """Write all static transforms as a single message on ``/tf_static``.

    Args:
        bag: Open bag object; only its ``write`` method is used here.
        transforms: Iterable of indexable entries laid out as
            ``(from_frame_id, to_frame_id, matrix)``.
        min_time: Time in seconds; used both as the stamp of every transform
            and as the bag time of the message.
    """
    print("Exporting transformations")
    stamp = rospy.Time.from_sec(min_time)

    message = TFMessage()
    message.transforms.extend(
        get_static_transform(
            bag,
            stamp,
            from_frame_id=entry[0],
            to_frame_id=entry[1],
            transform=entry[2],
        )
        for entry in transforms
    )
    bag.write('/tf_static', message, t=stamp)


def main():
    """Convert one synced KITTI raw drive into a ROS bag.

    Command line: ``kitti2bag [base_dir] date drive``. When ``base_dir`` is
    omitted the current working directory is used. The output file is named
    ``kitti_<date>_drive_<drive>_sync.bag`` (relative to the current working
    directory) and receives the static transforms, the IMU messages and the
    images of the four cameras.

    On a wrong argument count the usage text is printed and the function
    returns. If the dataset directory does not exist or the dataset has no
    timestamps, a message is printed and the process exits with status 1;
    in those cases no bag file is created.
    """
    if len(sys.argv) not in (3, 4):
        print("Usage: kitti2bag base_dir date drive")
        print("   or: kitti2bag date drive (base_dir is then current dir)")
        return
    basedir = sys.argv[1] if len(sys.argv) == 4 else os.getcwd()
    date, drive = tuple(sys.argv[-2:])

    # Validate the dataset before opening the bag, so the early exits below
    # neither leave an unclosed bag handle nor an empty bag file behind.
    kitti = pykitti.raw(basedir, date, drive)
    if not os.path.exists(kitti.data_path):
        print('Path {} does not exists. Exiting.'.format(kitti.data_path))
        sys.exit(1)

    kitti.load_calib()
    kitti.load_oxts()
    kitti.load_timestamps()

    if len(kitti.timestamps) == 0:
        print('Dataset is empty? Exiting.')
        sys.exit(1)

    bridge = CvBridge()
    # Alternatives: rosbag.Compression.BZ2, rosbag.Compression.LZ4
    compression = rosbag.Compression.NONE
    bag = rosbag.Bag("kitti_{}_drive_{}_sync.bag".format(date, drive), 'w', compression=compression)

    try:
        # IMU
        imu_frame_id = 'imu_link'
        imu_topic = '/kitti/oxts/imu'

        # Prepared for velo. Not implemented yet.
        velo_frame_id = 'velo_link'
        velo_topic = '/kitti/velo'

        # CAMERAS: (camera index, frame id, topic).
        # Each camera needs its own frame id: the frame id is used both as
        # the image header frame and as the child frame of its static
        # transform below.
        cameras = [
            (0, 'camera_gray_left', '/kitti/camera/gray/left'),
            (1, 'camera_gray_right', '/kitti/camera/gray/right'),
            (2, 'camera_color_left', '/kitti/camera/color/left'),
            (3, 'camera_color_right', '/kitti/camera/color/right')
        ]

        # tf_static: (parent frame, child frame, matrix).
        # NOTE(review): TF expects the pose of the child frame expressed in
        # the parent frame. If pykitti's T_<x>_imu maps IMU coordinates into
        # frame <x> (as the naming suggests), these matrices may need to be
        # inverted - confirm against the pykitti documentation.
        transforms = [
            (imu_frame_id, velo_frame_id, kitti.calib.T_velo_imu),
            (imu_frame_id, cameras[0][1], kitti.calib.T_cam0_imu),
            (imu_frame_id, cameras[1][1], kitti.calib.T_cam1_imu),
            (imu_frame_id, cameras[2][1], kitti.calib.T_cam2_imu),
            (imu_frame_id, cameras[3][1], kitti.calib.T_cam3_imu)
        ]

        # The static transforms are stamped with the first dataset timestamp.
        min_time = float(kitti.timestamps[0].strftime("%s.%f"))

        # Export
        save_static_transforms(bag, transforms, min_time)
        save_imu_data(bag, kitti, imu_frame_id, imu_topic)
        for camera in cameras:
            save_camera_data(bag, kitti, bridge, camera=camera[0], camera_frame_id=camera[1], topic=camera[2])
    finally:
        # Always print the summary and close the bag, even if export failed.
        print("## OVERVIEW ##")
        print(bag)
        bag.close()


# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
