#!/usr/bin/env python3

import sys
import os
from dbnn.dbnn import GPUDBNN, DatasetConfig

def main():
    """Command-line entry point: train and/or predict with a DBNN model.

    Usage::

        dbnn <dataset_name> [--train] [--predict] [--train-only]

    ``sys.argv[1]`` must be one of the names returned by
    ``DatasetConfig.get_available_datasets()``. Arguments after it are flags:

    * ``--train`` / ``--train-only`` -- call ``model.fit_predict`` with
      ``save_path="<dataset_name>_train_test_predictions.csv"`` and print
      ``results['test_accuracy']``.
    * ``--predict`` -- call ``model.predict_and_save`` with
      ``save_path="<dataset_name>_predictions.csv"``; skipped when
      ``--train-only`` is also given.

    Unrecognised arguments and a missing action flag produce a warning on
    stderr. Exits with status 1 (message on stderr) when the dataset name
    is missing or unknown, or when any exception is raised while building,
    training or predicting with the model.
    """
    known_flags = ("--train", "--predict", "--train-only")

    # Dataset names DatasetConfig reports as available (presumably those found
    # in the current directory -- not verifiable from this file).
    available_datasets = DatasetConfig.get_available_datasets()

    if len(sys.argv) < 2:
        print("Usage: dbnn <dataset_name> [--train] [--predict] [--train-only]",
              file=sys.stderr)
        print("\nAvailable datasets:", ", ".join(available_datasets),
              file=sys.stderr)
        sys.exit(1)

    dataset_name = sys.argv[1]
    # Only the arguments after the dataset name are treated as flags.
    flags = sys.argv[2:]
    train_flag = "--train" in flags
    predict_flag = "--predict" in flags
    train_only_flag = "--train-only" in flags

    if dataset_name not in available_datasets:
        print(f"Error: Dataset '{dataset_name}' not found.", file=sys.stderr)
        print("\nAvailable datasets:", ", ".join(available_datasets),
              file=sys.stderr)
        sys.exit(1)

    # Typos such as "--trian" used to be ignored silently; surface them.
    unknown = [arg for arg in flags if arg not in known_flags]
    if unknown:
        print(f"Warning: ignoring unrecognised argument(s): {' '.join(unknown)}",
              file=sys.stderr)
    if not (train_flag or predict_flag or train_only_flag):
        print("Warning: no action requested; pass --train, --predict "
              "or --train-only.", file=sys.stderr)

    try:
        model = GPUDBNN(dataset_name=dataset_name)

        if train_flag or train_only_flag:
            results = model.fit_predict(
                save_path=f"{dataset_name}_train_test_predictions.csv")
            print(f"\nTest Accuracy: {results['test_accuracy']:.4f}")

        # --train-only suppresses the separate prediction pass.
        if predict_flag and not train_only_flag:
            model.predict_and_save(save_path=f"{dataset_name}_predictions.csv")

    except Exception as e:
        # Top-level CLI boundary: report the failure and exit non-zero
        # instead of dumping a traceback.
        print(f"Error processing dataset {dataset_name}: {e}", file=sys.stderr)
        sys.exit(1)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
