cmake_minimum_required(VERSION 3.18)

# CPU-only build option
option(RFX_CPU_ONLY "Build without CUDA support" OFF)

# Declare the project with C++ first. project() loads the platform/toolchain
# information that enable_language() and find_package() rely on, so both must
# come after it. CUDA is added below only when a working compiler exists.
project(rfx LANGUAGES CXX)

if(RFX_CPU_ONLY)
    message(STATUS "Building CPU-only version (GPU disabled)")
    set(CUDA_AVAILABLE OFF)
else()
    # check_language() probes for a usable CUDA compiler without failing the
    # configure step, so machines without nvcc fall back to a CPU-only build.
    include(CheckLanguage)
    check_language(CUDA)
    if(CMAKE_CUDA_COMPILER)
        enable_language(CUDA)
        find_package(CUDAToolkit)
    endif()

    if(CMAKE_CUDA_COMPILER AND CUDAToolkit_FOUND)
        set(CUDA_AVAILABLE ON)
        add_compile_definitions(CUDA_FOUND)
        message(STATUS "Building with CUDA support")
    else()
        set(CUDA_AVAILABLE OFF)
        message(WARNING "CUDA not found - building CPU-only version")
    endif()
endif()

# CMP0148 concerns the removed FindPythonInterp/FindPythonLibs modules; opt in
# to its NEW behavior (when this CMake version knows the policy) so the
# policy warning is not emitted.
if(POLICY CMP0148)
    cmake_policy(SET CMP0148 NEW)
endif()

# RPATH handling (quiets RPATH warnings common with conda/anaconda):
# do not append linked-library directories to the install RPATH, and give
# build-tree binaries the install RPATH right away.
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)

# Host code requires C++17.
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_STANDARD 17)

# Device code uses the same standard, but only when CUDA is enabled.
if(CUDA_AVAILABLE)
    set(CMAKE_CUDA_STANDARD_REQUIRED ON)
    set(CMAKE_CUDA_STANDARD 17)
endif()

# Find required packages
find_package(pybind11 REQUIRED)
# Only the C++ OpenMP support is consumed (OpenMP::OpenMP_CXX), so require just
# that component instead of OpenMP support for every enabled language.
find_package(OpenMP REQUIRED COMPONENTS CXX)
# Try to find the NUMA library (optional; linked later only if found)
find_library(NUMA_LIB NAMES numa libnuma.so.2 libnuma.so)
# Try to find LAPACK (optional - LAPACK/BLAS libraries are also searched for
# manually further down)
find_package(LAPACK)

# -march=native tunes host code for the build machine's CPU, so the resulting
# module may not run on other CPUs. ON keeps the historical behavior; turn it
# OFF for portable/distributable builds.
option(RFX_NATIVE_ARCH "Compile host C++ code with -march=native" ON)

# Compiler flags - add debug symbols and -O3 even in release mode
string(APPEND CMAKE_CXX_FLAGS " -g -O3")
if(RFX_NATIVE_ARCH)
    string(APPEND CMAKE_CXX_FLAGS " -march=native")
endif()

if(CUDA_AVAILABLE)
    # nvcc flags: debug info, -O3, fast-math intrinsics, and suppression of
    # diagnostic #177 (presumably the unused-variable warning - confirm).
    string(APPEND CMAKE_CUDA_FLAGS " -g -O3 --use_fast_math -diag-suppress=177")
    # Set CUDA architectures (CUDA 12.8 supported architectures)
    # Supported: 75 (Turing), 80 (Ampere), 86 (Ampere), 87, 89 (Ada), 90 (Hopper)
    set(CMAKE_CUDA_ARCHITECTURES 75 80 86 87 89 90)
endif()

# Source files compiled in every configuration (host C++ plus the pybind11
# bindings).
set(SOURCES
    src/rf_config.cpp
    src/rf_arrays.cpp
    src/rf_utils.cpp
    src/rf_utilities.cpp
    src/rf_random_forest.cpp
    src/rf_cuda_config.cpp
    src/rf_parallel_tree_growing.cpp
    src/rf_vectorized_ops.cpp
    src/rf_memory_pool.cpp
    src/rf_growtree_wrapper.cpp
    src/rf_growtree.cpp
    src/rf_bootstrap.cpp
    src/rf_getamat.cpp
    src/rf_testreebag.cpp
    src/rf_varimp.cpp
    src/rf_proximity.cpp
    src/rf_finishprox.cpp
    src/rf_predict.cpp
    src/rf_mds_cpu.cpp
    python/randomforest_py.cpp
)

# GPU-only sources (excluded from CPU-only builds): one host-side C++ file plus
# all CUDA kernel sources.
if(CUDA_AVAILABLE)
    list(APPEND SOURCES src/rf_proximity_optimized.cpp)
    set(CUDA_SOURCES
        cuda/rf_config_cuda.cu
        cuda/rf_memory.cu
        cuda/rf_cuda_memory.cu
        cuda/rf_bootstrap.cu
        cuda/rf_varimp.cu
        cuda/rf_proximity.cu
        cuda/rf_testreebag.cu
        cuda/rf_finishprox.cu
        cuda/rf_getamat.cu
        cuda/rf_predict.cu
        cuda/rf_growtree.cu
        cuda/rf_quantization_kernels.cu
        cuda/rf_proximity_lowrank.cu
        cuda/rf_proximity_upper_triangle.cu
        cuda/rf_lowrank_helpers.cu
        cuda/rf_mds_gpu.cu
    )
else()
    set(CUDA_SOURCES "")
endif()

# Create Python module with pybind11
pybind11_add_module(rfx ${SOURCES} ${CUDA_SOURCES})

# Include directories, scoped to the rfx target rather than the whole
# directory so they do not leak into unrelated targets.
target_include_directories(rfx PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/include
    ${CMAKE_CURRENT_SOURCE_DIR}/cuda
)
if(CUDA_AVAILABLE)
    target_include_directories(rfx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
endif()

# Link pybind11 and OpenMP (needed by both CPU-only and CUDA builds)
target_link_libraries(rfx PRIVATE
    pybind11::module
    OpenMP::OpenMP_CXX
)

# GPU libraries: the CUDA runtime plus cuRAND, cuSOLVER and cuBLAS. These are
# linked only when CUDA support was detected.
if(CUDA_AVAILABLE)
    set(rfx_cuda_link_libs
        CUDA::cudart
        CUDA::curand
        CUDA::cusolver
        CUDA::cublas
    )
    target_link_libraries(rfx PRIVATE ${rfx_cuda_link_libs})
endif()

# NUMA is optional: warn when find_library() did not locate it, otherwise link it.
if(NOT NUMA_LIB)
    message(WARNING "NUMA library not found - may cause linking issues")
else()
    target_link_libraries(rfx PRIVATE ${NUMA_LIB})
    message(STATUS "Found NUMA: ${NUMA_LIB}")
endif()

# Link LAPACK and BLAS (needed for the CPU MDS computation).
# Prefer the result of find_package(LAPACK): its LAPACK::LAPACK imported target
# honours BLA_VENDOR and its link interface includes the BLAS libraries that
# FindLAPACK resolved alongside LAPACK. Only when FindLAPACK did not succeed
# fall back to locating liblapack/libblas directly with find_library().
if(LAPACK_FOUND AND TARGET LAPACK::LAPACK)
    target_link_libraries(rfx PRIVATE LAPACK::LAPACK)
    message(STATUS "Found LAPACK (FindLAPACK): ${LAPACK_LIBRARIES}")
else()
    # First search common system library directories only...
    find_library(LAPACK_LIB
        NAMES lapack liblapack.so.3 liblapack.so
        PATHS /usr/lib/x86_64-linux-gnu /usr/lib /lib/x86_64-linux-gnu /lib
        NO_DEFAULT_PATH
    )
    find_library(BLAS_LIB
        NAMES blas libblas.so.3 libblas.so
        PATHS /usr/lib/x86_64-linux-gnu /usr/lib /lib/x86_64-linux-gnu /lib
        NO_DEFAULT_PATH
    )

    # ...then retry with CMake's default search paths for whatever is missing.
    if(NOT LAPACK_LIB)
        find_library(LAPACK_LIB NAMES lapack liblapack.so.3 liblapack.so)
    endif()
    if(NOT BLAS_LIB)
        find_library(BLAS_LIB NAMES blas libblas.so.3 libblas.so)
    endif()

    if(LAPACK_LIB AND BLAS_LIB)
        target_link_libraries(rfx PRIVATE ${LAPACK_LIB} ${BLAS_LIB})
        message(STATUS "Found LAPACK: ${LAPACK_LIB}")
        message(STATUS "Found BLAS: ${BLAS_LIB}")
    else()
        message(WARNING "LAPACK or BLAS not found - CPU MDS may not work")
        if(NOT LAPACK_LIB)
            message(WARNING "  LAPACK_LIB not found")
        endif()
        if(NOT BLAS_LIB)
            message(WARNING "  BLAS_LIB not found")
        endif()
    endif()
endif()

# Inside an active conda environment: add $CONDA_PREFIX/lib to the link search
# path, and link libgcc/libstdc++ statically so the module does not depend on
# whichever libstdc++.so the loader resolves when Python imports it.
# An empty CONDA_PREFIX is ignored (it would otherwise add "/lib").
if(DEFINED ENV{CONDA_PREFIX} AND NOT "$ENV{CONDA_PREFIX}" STREQUAL "")
    set(CONDA_PREFIX "$ENV{CONDA_PREFIX}")
    target_link_directories(rfx PRIVATE "${CONDA_PREFIX}/lib")
    # These are linker flags, not libraries, so they belong in
    # target_link_options. HOST_LINK keeps them out of any CUDA device-link step.
    target_link_options(rfx PRIVATE
        "$<HOST_LINK:-static-libgcc>"
        "$<HOST_LINK:-static-libstdc++>"
    )
endif()

# Set library properties.
# Position-independent code is needed for the Python extension module in both
# the CUDA and the CPU-only configuration.
set_target_properties(rfx PROPERTIES
    POSITION_INDEPENDENT_CODE ON
)

if(CUDA_AVAILABLE)
    # Whole-program CUDA compilation (no separable/relocatable device code),
    # resolve device symbols when linking rfx, and use the shared CUDA runtime.
    # CUDA_ARCHITECTURES is intentionally not repeated here: rfx is created
    # after CMAKE_CUDA_ARCHITECTURES is assigned, so the target property is
    # initialized from that list.
    set_target_properties(rfx PROPERTIES
        CUDA_SEPARABLE_COMPILATION OFF
        CUDA_RESOLVE_DEVICE_SYMBOLS ON
        CUDA_RUNTIME_LIBRARY Shared
    )

    # Add the CUDA device runtime library if it exists in the toolkit's
    # library directory (the search is restricted to that directory).
    find_library(CUDA_DEVICE_RUNTIME_LIBRARY
        NAMES cudadevrt
        PATHS ${CUDAToolkit_LIBRARY_DIR}
        NO_DEFAULT_PATH
    )
    if(CUDA_DEVICE_RUNTIME_LIBRARY)
        target_link_libraries(rfx PRIVATE ${CUDA_DEVICE_RUNTIME_LIBRARY})
    endif()
endif()

# Installation: the compiled rfx module goes under lib/ (bin/ for any runtime
# artifact), public headers under include/randomforest, and only the *.cuh
# headers from cuda/ under include/randomforest/cuda.
install(TARGETS rfx
    RUNTIME DESTINATION bin
    LIBRARY DESTINATION lib
    ARCHIVE DESTINATION lib
)

install(DIRECTORY include/
    DESTINATION include/randomforest
)
install(DIRECTORY cuda/
    DESTINATION include/randomforest/cuda
    FILES_MATCHING PATTERN "*.cuh"
)

# Post-build development convenience: copy the built module into this project's
# python/ directory so Python can import it without a manual copy.
# This writes into the source tree; pass -DRFX_COPY_MODULE_TO_SOURCE=OFF for
# packaging or read-only checkouts.
option(RFX_COPY_MODULE_TO_SOURCE "Copy the built rfx module into python/ after each build" ON)
if(RFX_COPY_MODULE_TO_SOURCE)
    # CMAKE_CURRENT_SOURCE_DIR (not CMAKE_SOURCE_DIR) keeps the destination
    # correct when this project is built as a subproject.
    add_custom_command(TARGET rfx POST_BUILD
        COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_SOURCE_DIR}/python"
        COMMAND ${CMAKE_COMMAND} -E copy_if_different
            "$<TARGET_FILE:rfx>"
            "${CMAKE_CURRENT_SOURCE_DIR}/python/"
        COMMENT "Copying rfx module to python/ directory"
        VERBATIM
    )
endif()

