# import subprocess
# import sys

# # -----------------------------
# # Function to install packages
# # -----------------------------
# def install_if_missing(package, index_url=None):
#     try:
#         __import__(package)
#     except ImportError:
#         print(f"⚙️ Installing {package} ...")
#         cmd = [sys.executable, "-m", "pip", "install", package]
#         if index_url:
#             cmd += ["--index-url", index_url]
#         subprocess.check_call(cmd)
#         print(f"✅ Installed {package}")

# # -----------------------------
# # Step 1: Install dependencies
# # -----------------------------
# try:
#     import torch
# except ImportError:
#     print("PyTorch not found. Installing...")
#     # Try the CUDA 12.1 (GPU) build first; fall back to the CPU build if that fails
#     try:
#         subprocess.check_call([sys.executable, "-m", "pip", "install",
#                                "torch", "torchvision", "torchaudio",
#                                "--index-url", "https://download.pytorch.org/whl/cu121"])
#     except Exception:
#         subprocess.check_call([sys.executable, "-m", "pip", "install",
#                                "torch", "torchvision", "torchaudio",
#                                "--index-url", "https://download.pytorch.org/whl/cpu"])

# # Install other dependencies
# for pkg in ["diffusers", "transformers", "accelerate", "matplotlib", "safetensors"]:
#     install_if_missing(pkg)

# -----------------------------
# Step 2: Import libraries
# -----------------------------
from diffusers import StableDiffusionPipeline
import torch
import matplotlib.pyplot as plt

# -----------------------------
# Step 3: Load model
# -----------------------------
model_id = "CompVis/stable-diffusion-v1-4"
# Single source of truth for the save location (used by save + message).
output_path = "generated_image.png"

# Choose the device once and derive the dtype from it so the two can never
# disagree: float16 on CUDA to reduce GPU memory use, float32 on CPU
# (half precision is poorly supported for CPU inference).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)

# -----------------------------
# Step 4: Generate image
# -----------------------------
prompt = input("\n🖋️ Enter a text prompt to generate an image: ")
print("🎨 Generating image... this may take a while.")

# The pipeline returns a list of images; with one prompt, take the first.
image = pipe(prompt).images[0]

# -----------------------------
# Step 5: Save & display
# -----------------------------
# Save BEFORE displaying: plt.show() may block until the window is closed,
# and the (slow) generation result should already be on disk even if the
# display step is interrupted or fails.
image.save(output_path)
print(f"\n✅ Image generated and saved as '{output_path}'.")

plt.imshow(image)
plt.axis("off")
plt.title(f"Generated Image for: '{prompt}'")
plt.show()