#!/bin/bash

# general configuration of the job
#SBATCH --job-name=cerfacs-IT
#SBATCH --account=intertwin
#SBATCH --partition=develbooster
#SBATCH --output=job.out
#SBATCH --error=job.err
#SBATCH --time=00:10:00
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=1
#SBATCH --gpus-per-node=1
#SBATCH --exclusive

# Training entry point that the launch section passes to torchrun / python.
EXEC='train.py'

# Start from an empty module environment, then load the software stack.
ml --force purge
ml Stages/2024 \
  GCC OpenMPI CUDA/12 MPI-settings/CUDA \
  Python HDF5 PnetCDF libaio mpi4py

# Report the start time and a short system name derived from the node name.
printf 'DEBUG: TIME: %s\n' "$(date)"
sysN="$(uname -n)"
# Drop the first dot-separated label when the node name contains a dot;
# a dot-less name is kept whole.
case "$sysN" in
  *.*) sysN="${sysN#*.}" ;;
esac
# Cut everything from the first digit onwards.
sysN="${sysN%%[0-9]*}"
printf 'Running on system: %s\n' "$sysN"

# When DEBUG is exactly "true", dump the job context and turn on NCCL logging.
if [[ "$DEBUG" == "true" ]]; then
  printf 'DEBUG: EXECUTE: %s\n' "$EXEC"
  # Print each variable as "DEBUG: <name>: <value>" via indirect expansion.
  for _dbg_name in \
    SLURM_SUBMIT_DIR SLURM_JOB_ID SLURM_JOB_NODELIST SLURM_NNODES \
    SLURM_NTASKS SLURM_TASKS_PER_NODE SLURM_SUBMIT_HOST SLURMD_NODENAME \
    CUDA_VISIBLE_DEVICES; do
    printf 'DEBUG: %s: %s\n' "$_dbg_name" "${!_dbg_name}"
  done
  unset _dbg_name
  printf '%s\n' 'DEBUG: NCCL_DEBUG=INFO'
  export NCCL_DEBUG=INFO
fi

# Blank separator line in the job output.
printf '\n'

# Runtime environment for the training processes.
# NCCL_DEBUG is exported here unconditionally, independent of $DEBUG.
export NCCL_DEBUG=INFO
# srun reads SRUN_CPUS_PER_TASK as its --cpus-per-task value; forward the
# batch allocation's setting so job steps use the same CPU count.
export SRUN_CPUS_PER_TASK="${SLURM_CPUS_PER_TASK}"
# Default to a single OpenMP thread; use the per-task CPU count only when it
# is a positive integer. The test must be numeric (-gt): inside [ ], a bare
# ">" is an output redirection that creates a file named "0" and makes the
# test true for any non-empty value. stderr is silenced so a non-numeric
# value simply keeps the default.
export OMP_NUM_THREADS=1
if [ "${SLURM_CPUS_PER_TASK:-0}" -gt 0 ] 2>/dev/null; then
  export OMP_NUM_THREADS="$SLURM_CPUS_PER_TASK"
fi
# NOTE(review): four device indices are exposed although the job header
# requests --gpus-per-node=1 — confirm this is intended.
export CUDA_VISIBLE_DEVICES="0,1,2,3"

# Environment variables check
# DIST_MODE is mandatory: it selects the launch branch further down.
if [ -z "$DIST_MODE" ]; then
  >&2 echo "ERROR: env variable DIST_MODE is not set. Allowed values are 'horovod', 'ddp' or 'deepspeed'"
  exit 1
fi
# RUN_NAME is optional and falls back to the value of DIST_MODE.
# NOTE(review): the fallback is a plain assignment (no export) and RUN_NAME is
# not referenced again in this script — confirm whether the training program
# expects it in its environment.
if [ -z "$RUN_NAME" ]; then
  >&2 echo "WARNING: env variable RUN_NAME is not set. It's a way to identify some specific run of an experiment."
  RUN_NAME=$DIST_MODE
fi
# PYTHON_VENV is optional: without it the job runs with whatever Python is
# already on PATH.
if [ -z "$PYTHON_VENV" ]; then
  >&2 echo "WARNING: env variable PYTHON_VENV is not set. It's the path to a python virtual environment."
else
  # Activate Python virtual env. The path is quoted so that a location
  # containing whitespace or glob characters is sourced as a single file.
  source "$PYTHON_VENV/bin/activate"
fi

# Launch training with the strategy selected through DIST_MODE.
if [[ "$DIST_MODE" == "ddp" ]]; then
  echo "DDP training"
  # srun starts one bash per node (--ntasks-per-node=1), each running torchrun.
  # The command is one double-quoted string, so there are two expansion
  # stages:
  #   * plain $VAR and $(scontrol ...) are expanded here, by the batch shell,
  #     before srun is invoked;
  #   * the escaped \$( ... ) is evaluated by the bash that srun starts, where
  #     SLURM_NODEID is read: is_host is 1 when SLURM_NODEID is 0 and 0
  #     otherwise. The space after "\$(" keeps the construct from being read
  #     as the start of an arithmetic expansion.
  # The rendezvous endpoint is the first hostname of the allocation with "i"
  # appended, port 29500.
  # NOTE(review): the "i" suffix presumably selects a cluster-specific
  # hostname alias for the node — confirm for the target system.
  srun --cpu-bind=none --ntasks-per-node=1 \
    bash -c "torchrun \
    --log_dir='logs_torchrun' \
    --nnodes=$SLURM_NNODES \
    --nproc_per_node=$SLURM_GPUS_PER_NODE \
    --rdzv_id=$SLURM_JOB_ID \
    --rdzv_conf=is_host=\$( ((SLURM_NODEID)) && echo 0 || echo 1) \
    --rdzv_backend=c10d \
    --rdzv_endpoint='$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)'i:29500 \
    $EXEC"

elif [[ "$DIST_MODE" == "deepspeed" ]]; then
  echo "DEEPSPEED training"
  # First hostname of the allocation with "i" appended (same convention as the
  # ddp endpoint above). SLURM_JOB_NODELIST must be expanded by this shell;
  # escaping the "$" would hand scontrol the literal text instead of the
  # node list.
  MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)i"
  export MASTER_ADDR
  export MASTER_PORT=29500
  # $EXEC is left unquoted on purpose so that it may carry extra arguments,
  # matching how it is interpolated into the ddp command string.
  srun --cpu-bind=none python -u $EXEC

elif [[ "$DIST_MODE" == "horovod" ]]; then
  echo "HOROVOD training"
  # $EXEC is left unquoted on purpose (see the deepspeed branch).
  srun --cpu-bind=none python -u $EXEC
else
  >&2 echo "ERROR: unrecognized \$DIST_MODE env variable"
  exit 1
fi

#eof
