# Classify sampled video frames with an ImageNet-pretrained ResNet-50 and summarize the predicted labels

# Install required packages
# !pip install torch torchvision opencv-python matplotlib requests

import cv2
import torch
import torchvision.transforms as transforms
from torchvision import models
from collections import Counter
import requests
import matplotlib.pyplot as plt

# ImageNet-pretrained ResNet-50, put in eval mode for inference.
try:
    # torchvision >= 0.13 deprecates `pretrained=True` in favour of the
    # `weights` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True`
    # selects, so predictions are unchanged.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
except AttributeError:
    # Older torchvision releases have no `ResNet50_Weights`.
    model = models.resnet50(pretrained=True)
model.eval()

labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"


def _load_labels(url, timeout=30):
    """Download and return the decoded JSON label list at *url*.

    The timeout keeps a stalled connection from hanging the script forever,
    and ``raise_for_status`` surfaces HTTP errors (404, rate limiting, ...)
    instead of trying to parse an error page as the label list.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.json()


# Class names, looked up by the model's predicted class index.
labels = _load_labels(labels_url)

# Per-frame preprocessing: RGB image array -> PIL image -> 224x224 ->
# float CHW tensor -> normalized with the standard ImageNet channel mean/std.
# NOTE(review): torchvision's reference ResNet preprocessing is Resize(256) +
# CenterCrop(224); resizing straight to (224, 224) squashes non-square frames.
# Kept as-is to preserve current predictions.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def classify_frame(frame):
    """Predict the label for a single OpenCV frame.

    Args:
        frame: Image in OpenCV's BGR channel order (as produced by
            ``cv2.VideoCapture.read``); it is converted to RGB before
            preprocessing.

    Returns:
        The entry of ``labels`` at the index of the highest-scoring class.
    """
    rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # `transform` yields one CHW tensor; add a leading batch dimension of 1.
    batch = transform(rgb_image).unsqueeze(0)
    with torch.no_grad():
        scores = model(batch)
        top_index = scores.max(1)[1].item()
    return labels[top_index]

def classify_video(video_path, frame_interval=30):
    """Classify every ``frame_interval``-th frame of a video.

    Frames are read sequentially. Frame 0 and every ``frame_interval``-th
    frame after it are passed to ``classify_frame``. Each predicted label is
    printed as it is produced.

    Args:
        video_path: Video source passed to ``cv2.VideoCapture`` (e.g. a file
            path).
        frame_interval: Sampling stride in frames; must be >= 1.

    Returns:
        List of predicted labels, one per sampled frame, in frame order.

    Raises:
        ValueError: If ``frame_interval`` is less than 1.
        OSError: If the video source cannot be opened.
    """
    if frame_interval < 1:
        # 0 would divide by zero in the modulo below; negatives are meaningless.
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Previously a missing/unreadable file silently produced an empty
        # result; fail loudly so callers can tell "bad path" from "no frames".
        cap.release()
        raise OSError(f"Could not open video: {video_path}")

    classifications = []
    frame_count = 0
    try:
        while cap.isOpened():
            if frame_count % frame_interval == 0:
                ret, frame = cap.read()
                if not ret:
                    break
                label = classify_frame(frame)
                classifications.append(label)
                print(f"Frame {frame_count}: {label}")
            elif not cap.grab():
                # Skipped frames only need to be advanced past; grab() skips
                # the retrieve/convert step that read() performs.
                break
            frame_count += 1
    finally:
        # Release the capture even if classification raises mid-stream.
        cap.release()

    return classifications

def analyze_video_classification(video_path, frame_interval=30):
    """Classify sampled frames of a video and print a label frequency report.

    Args:
        video_path: Video source forwarded to ``classify_video``.
        frame_interval: Sampling stride forwarded to ``classify_video``.

    Returns:
        Tuple ``(classifications, class_counts)``: the per-sample labels and
        a ``Counter`` of them. Both are empty when no frame was classified.
    """
    print(f"Analyzing video: {video_path}")
    print(f"Sampling every {frame_interval} frames")
    print("-" * 50)

    classifications = classify_video(video_path, frame_interval)
    if not classifications:
        print("No frames were processed.")
        return [], Counter()

    class_counts = Counter(classifications)
    total = len(classifications)

    print(f"\nTotal frames analyzed: {total}")
    print("\nClassification results:")
    print("-" * 30)

    # Labels ordered from most to least frequent (ties keep first-seen order).
    ranked = class_counts.most_common()
    for class_name, count in ranked:
        share = count / total * 100
        print(f"{class_name}: {count} frames ({share:.1f}%)")

    top_name, top_count = ranked[0]
    print(f"\nMost common classification: {top_name} ({top_count} frames)")

    return classifications, class_counts

if __name__ == "__main__":
    # Default input video; edit this path to analyze a different file.
    video_path = "sample_video.mp4"

    try:
        classifications, counts = analyze_video_classification(
            video_path,
            frame_interval=30,
        )
    except Exception as err:
        # Top-level boundary: report the failure instead of a traceback.
        print(f"Error processing video: {err}")
        print("Please ensure the video file exists and is accessible.")

