# =====================================================================
# Stage 1 ("builder"): toolchains and compilation. Nothing from this
# stage ends up in the final image except the artifacts that the
# runtime stage copies out explicitly with COPY --from=builder.
# =====================================================================
FROM python:3.10-slim AS builder

# Skip writing .pyc files and keep Python's stdout/stderr unbuffered.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# Native build prerequisites:
#   build-essential - gcc/g++/make/ar for the C++ kernels
#   curl            - fetches the rustup installer below
#   git             - NOTE(review): presumably for git-sourced Cargo deps or
#                     version metadata; confirm it is still needed
#   libomp-dev      - LLVM's OpenMP runtime/headers. NOTE(review): g++ -fopenmp
#                     uses GCC's libgomp, so this may be unnecessary; confirm
# The apt lists are removed in the same layer so they never reach the image.
RUN apt-get update \
    && apt-get install --yes --no-install-recommends \
        build-essential \
        curl \
        git \
        libomp-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Rust (rustc + cargo) for maturin.
# With the default /bin/sh, a failed download in `curl ... | sh` is masked:
# sh receives an empty script and exits 0, and the build only breaks later
# with "cargo: not found". pipefail makes the pipeline fail at the real
# cause. SHELL applies to every later RUN in this stage.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
# The "minimal" profile skips rust-docs/clippy/rustfmt, which the build does
# not use, for a faster and smaller toolchain download.
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs \
    | sh -s -- -y --profile minimal
ENV PATH="/root/.cargo/bin:${PATH}"

# Maturin packages the Rust extension as a Python wheel.
# TODO: pin the version (maturin==X.Y.Z) for reproducible builds.
RUN pip install --no-cache-dir maturin

WORKDIR /app

# Bring the whole build context (C++ kernel sources + the maturin project)
# into /app. Keep a .dockerignore (e.g. target/, .git/, *.o, *.so) so stale
# local build outputs neither leak into the build nor bust the layer cache.
COPY . ./

# Compile the C++ kernels ONCE as position-independent, OpenMP-enabled
# objects, then package the same objects two ways:
#   librsr.so       - shared library loaded via ctypes by the benchmarks
#   libtrme_core.a  - static archive linked into the Rust extension (maturin)
#
# RSR_MARCH selects the target ISA. Avoid "-march=native" for images that
# are shipped: it tunes for whichever CPU runs `docker build`, and the image
# can then die with SIGILL (illegal instruction) on any other host.
# x86-64-v3 (AVX2/FMA/BMI2 baseline; needs GCC >= 11) keeps the AVX2 code
# paths while staying portable across AVX2-capable CPUs. Pass
# `--build-arg RSR_MARCH=native` only when the image never leaves the build
# machine. -mavx2 stays explicit so a lower RSR_MARCH still enables AVX2.
ARG RSR_MARCH=x86-64-v3
RUN set -eux; \
    for src in rsr_gemm quantize rsr_fused fmm_octree; do \
        g++ -O3 -march="${RSR_MARCH}" -mavx2 -fPIC -fopenmp \
            -c "${src}.cpp" -o "${src}.o"; \
    done; \
    g++ -shared -fopenmp -o librsr.so \
        rsr_gemm.o quantize.o rsr_fused.o fmm_octree.o; \
    ar rcs libtrme_core.a \
        rsr_gemm.o quantize.o rsr_fused.o fmm_octree.o

# Produce a release-mode, symbol-stripped wheel of the Rust extension.
# The runtime stage copies it from /app/target/wheels/.
RUN maturin build --strip --release

# =====================================================================
# Stage 2 (runtime): slim Python image carrying only the built wheel,
# the ctypes shared library, and the example/benchmark scripts.
# =====================================================================
FROM python:3.10-slim

# Unbuffered output so benchmark results reach `docker logs`/CI as they are
# printed (Python block-buffers stdout when it is not a TTY), and no .pyc
# files written next to the scripts at runtime.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR /app

# Runtime deps: libgomp1, GCC's OpenMP runtime (librsr.so is linked with
# -fopenmp in the builder).
# Also create an unprivileged user with a fixed UID/GID (10001), so that
# orchestrators can verify runAsNonRoot. The user owns /app itself so
# scripts can still write output files into the working directory; the
# installed packages and copied files below stay root-owned.
RUN apt-get update && apt-get install -y --no-install-recommends \
        libgomp1 \
    && rm -rf /var/lib/apt/lists/* \
    && groupadd --gid 10001 app \
    && useradd --uid 10001 --gid app --create-home --shell /usr/sbin/nologin app \
    && chown app:app /app

# Install the maturin-built wheel plus its Python dependencies.
# Note: torch is large; in production prefer a pre-baked torch base image
# or a pip cache mount. TODO: pin versions for reproducible images.
COPY --from=builder /app/target/wheels/*.whl /app/wheels/
RUN pip install --no-cache-dir /app/wheels/*.whl torch numpy scipy

# Shared library used by the ctypes benchmarks.
COPY --from=builder /app/librsr.so /app/librsr.so

# Example/benchmark scripts, plus the RTL sources for reference/co-simulation.
COPY benchmark_rsr.py test_torch_integration.py trme_sim.py energy_estimator.py ./
COPY rtl/ ./rtl/

# Default OpenMP thread count; override with `docker run -e OMP_NUM_THREADS=N`.
ENV OMP_NUM_THREADS=4

# Everything the container runs executes as the unprivileged user.
USER 10001:10001

CMD ["python", "benchmark_rsr.py"]
