#!/usr/bin/env python
import argparse
import logging
import os
from pathlib import Path
from cxas import CXAS

# Set up root-logger output: INFO and above, each record prefixed with a
# timestamp and its severity level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

def parse_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line arguments of the segmentation script.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default), argparse reads ``sys.argv[1:]``, which is
            the behaviour callers get when invoking this with no arguments.

    Returns:
        Namespace with the attributes ``input``, ``output``, ``output_type``,
        ``gpus`` and ``model``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or a
            value is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser(
        description="Segment 159 anatomical structures in X-Ray images.",
        epilog="Written by Constantin Seibold. If you use this tool, please cite accordingly.",
    )

    # Required input/output locations (parsed as plain strings).
    parser.add_argument(
        "-i", "--input",
        metavar="filepath",
        required=True,
        help="Path to file or directory to be processed.",
    )
    parser.add_argument(
        "-o", "--output",
        metavar="directory",
        required=True,
        help="Output directory for segmentation masks.",
    )

    # Optional settings; each falls back to the default shown.
    parser.add_argument(
        "-ot", "--output_type",
        choices=["json", "npy", "npz", "jpg", "png", "dicom-seg"],
        default="png",
        help="Storage type of segmentations if they are stored.",
    )
    parser.add_argument(
        "-g", "--gpus",
        default="0",
        help="Select specific GPU/CPU to process the input.",
    )
    parser.add_argument(
        "-m", "--model",
        choices=["UNet_ResNet50_default"],
        default="UNet_ResNet50_default",
        help="Model used for inference.",
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)

def main() -> None:
    """Segment the file or directory named on the command line.

    Reads the command-line options, checks that the input path exists, builds
    the ``CXAS`` model and then calls ``process_folder`` for a directory or
    ``process_file`` for a single file, passing the requested output
    directory and storage type.

    Raises:
        FileNotFoundError: If the input path is neither an existing file nor
            an existing directory.
    """
    args = parse_arguments()

    input_path = Path(args.input)
    output_directory = Path(args.output)

    # Validate the input before constructing the model so that a mistyped
    # path is reported immediately rather than after model construction.
    is_directory = input_path.is_dir()
    if not is_directory and not input_path.is_file():
        message = f"{input_path} is neither a file nor a directory."
        logging.error(message)
        raise FileNotFoundError(message)

    model = CXAS(model_name=args.model, gpus=args.gpus)

    if is_directory:
        logging.info(f"Processing directory: {input_path}...")
        model.process_folder(
            input_directory_name=str(input_path),
            output_directory=str(output_directory),
            create=True,
            storage_type=args.output_type,
        )
    else:
        logging.info(f"Processing file: {input_path}...")
        model.process_file(
            filename=str(input_path),
            output_directory=str(output_directory),
            create=True,
            do_store=True,
            storage_type=args.output_type,
        )

    logging.info(f"Segmentation completed. Results stored in: {output_directory}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
