#!/usr/bin/env python3

import numpy as np
import cv2
import os
import argparse
import subprocess
from tqdm import tqdm
from im2dhisteq import vid2dhisteq

# command-line interface section

# Each entry is (flag, keyword arguments for add_argument); the flags are
# registered in this order, which is also the order shown by --help.
_ARGUMENT_SPECS = (
    (
        '--input',
        dict(
            type=str,
            required=True,
            help='the path to the source video file to be processed',
        ),
    ),
    (
        '--output',
        dict(
            type=str,
            required=True,
            help='the path to the result file',
        ),
    ),
    (
        '--w',
        dict(
            type=int,
            required=False,
            help='width of the square window used for calculating the 2d histogram',
        ),
    ),
)

parser = argparse.ArgumentParser(description="Apply im2dhisteq method on videos.")
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Parsed once at start-up; exits with a usage message if --input or --output
# is missing.
args = parser.parse_args()

# input section
# Absolute, symlink-resolved path of the source video.
video_address = os.path.realpath(args.input)
## video without sound
# Path handed to cv2.VideoWriter for the equalized frames.
video_output       = args.output
# Sibling files are named after the output path with its extension removed.
# os.path.splitext only inspects the last path component, so an output path
# without an extension (or with a dot only in a directory name, e.g.
# "./out/video") keeps its full name instead of being truncated.
output_stem        = os.path.splitext(video_output)[0]
# Final result: equalized video with the original audio muxed back in.
video_output_audio = f"{output_stem}-audio.mp4"
# Intermediate file holding the audio extracted from the source video.
audio_output       = f"{output_stem}.mp3"
# Window width passed to vid2dhisteq; falls back to 6 when --w is omitted
# (or given as 0).
width = args.w if args.w else 6


# function section

def extract_audio(video_file, audio_file):
    """Extract the audio of ``video_file`` into ``audio_file`` with ffmpeg.

    ffmpeg is run with its banner hidden and only errors logged. The audio
    stream(s) are selected with ``-map a`` and written using the quality
    setting ``-q:a 0``.

    Args:
        video_file: path of the video to read the audio from.
        audio_file: path of the audio file to create.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    quiet_flags = ['-hide_banner', '-loglevel', 'error']
    audio_flags = ['-q:a', '0', '-map', 'a']
    subprocess.run(
        ['ffmpeg', *quiet_flags, '-i', video_file, *audio_flags, audio_file],
        check=True,
    )

def add_audio(video_file, audio_file, output_file):
    """Combine a video stream and an audio stream into ``output_file``.

    ffmpeg takes the first video stream of ``video_file`` (``0:v:0``) and the
    first audio stream of ``audio_file`` (``1:a:0``) and copies both without
    re-encoding (``-c copy``). The banner is hidden and only errors are logged.

    Args:
        video_file: path of the file providing the video stream.
        audio_file: path of the file providing the audio stream.
        output_file: path of the combined file to create.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    quiet_flags = ['-hide_banner', '-loglevel', 'error']
    sources = ['-i', video_file, '-i', audio_file]
    stream_flags = ['-c', 'copy', '-map', '0:v:0', '-map', '1:a:0']
    subprocess.run(
        ['ffmpeg', *quiet_flags, *sources, *stream_flags, output_file],
        check=True,
    )

# processing section

cap = cv2.VideoCapture(video_address)
if not cap.isOpened():
    # Fail early with a readable message; otherwise the loop below would be
    # skipped and no output writer would ever be created.
    raise SystemExit(f"error: could not open input video: {video_address}")

# Get the total number of frames in the video (only used to size the
# progress bar).
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# The frame rate does not change from frame to frame, so query it once.
fps = cap.get(cv2.CAP_PROP_FPS)

# Circular buffer of the Wout values returned for the 10 most recent frames.
# It starts as zeros and is passed back to vid2dhisteq on every call.
# NOTE(review): the meaning of Wout is defined by im2dhisteq - confirm there.
Wout_list = np.zeros((10))
j = 0  # index of the next Wout_list slot to overwrite
video_out = None  # created lazily once the first frame's size is known

try:
    with tqdm(total=total_frames, unit='frame') as pbar:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # No more frames (or a read failure): stop processing.
                break

            # Equalize only the V (value) channel in HSV space and leave
            # hue and saturation untouched.
            frame_hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
            frame_v = frame_hsv[:, :, 2].copy()
            image_heq, Wout = vid2dhisteq(frame_v, w_neighboring=width, Wout_list=Wout_list)
            Wout_list[j] = Wout
            j = (j + 1) % len(Wout_list)  # wrap around: 0..9, then back to 0
            frame_hsv[:, :, 2] = image_heq
            frame_eq = cv2.cvtColor(frame_hsv, cv2.COLOR_HSV2BGR)

            if video_out is None:
                # The writer needs the frame size, which is only known after
                # the first frame has been processed.
                h, w = frame_eq.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                video_out = cv2.VideoWriter(video_output, fourcc, fps, (w, h))
            video_out.write(frame_eq)

            pbar.update(1)
finally:
    # Release resources even if a frame fails, so the capture is closed and
    # whatever was already written is finalized. No window is ever created
    # by this script, so there is nothing to destroy.
    cap.release()
    if video_out is not None:
        video_out.release()

if video_out is None:
    # Nothing was written, so there is no video to attach the audio to.
    raise SystemExit(f"error: no frames could be read from: {video_address}")

# Copy the original audio track onto the processed (silent) video.
extract_audio(video_address, audio_output)
add_audio(video_output, audio_output, video_output_audio)
