#!/bin/bash

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# At least one argument is required; print a hint and fail otherwise.
if [ "$#" -lt 1 ]; then
  echo "Not enough arguments! Run raydp-submit help to see instructions"
  exit 1
fi

# Print usage and exit successfully when the first argument is the literal
# word "help". "$1" is quoted so that an argument containing whitespace or
# glob characters is compared as one string instead of breaking the test.
if [ "$1" == "help" ]; then
  echo "raydp-submit [--ray-conf file] spark-arguments"
  echo "Submit an spark-app to the ray cluster: raydp-submit [--ray-conf file]\\"
  echo "  --conf spark.example.config ... --class org.foo.bar example.jar"
  echo "It is recommended to use a ray config file to connect to an existing ray cluster"
  echo "spark-arguments are handed to SparkSubmit"
  exit 0
fi

# Resolve install locations and the driver IP. Each value is taken from the
# environment when it is already non-empty; otherwise it is derived by asking
# python3 where the corresponding package is installed.

# set SPARK_HOME (directory of the pyspark package)
if [ -z "${SPARK_HOME}" ]; then
  SPARK_HOME=$(python3 -c "import os, pyspark; print(os.path.dirname(pyspark.__file__))")
  # Quoted so that a SPARK_HOME containing spaces is still one path.
  # NOTE(review): the script is executed, not sourced, so variables it
  # exports do not reach this shell - confirm that this is intended.
  "$SPARK_HOME/sbin/spark-config.sh"
fi
# set RAYDP_HOME (directory of the raydp package)
if [ -z "${RAYDP_HOME}" ]; then
  RAYDP_HOME=$(python3 -c "import os, raydp; print(os.path.dirname(raydp.__file__))")
fi
# set RAY_HOME (directory of the ray package)
if [ -z "${RAY_HOME}" ]; then
  RAY_HOME=$(python3 -c "import os, ray; print(os.path.dirname(ray.__file__))")
fi
# set RAY_DRIVER_NODE_IP (value of ray.util.get_node_ip_address())
if [ -z "${RAY_DRIVER_NODE_IP}" ]; then
  RAY_DRIVER_NODE_IP=$(python3 -c "import ray; print(ray.util.get_node_ip_address())")
fi


# set log4j versions for spark driver and executors inside ray worker
# read from versions.py.
# Python prints six values, one per line, in the order they are consumed
# below. The two file names can be overridden through the environment
# variables SPARK_LOG4J_CONFIG_FILE_NAME and RAY_LOG4J_CONFIG_FILE_NAME.
#
# Each line is stored verbatim as exactly one array element (IFS= / -r /
# quoted append). Without that, a value containing whitespace or glob
# characters would be split or expanded, and an empty value would be dropped,
# shifting every following index.
log4j_config=()
while IFS= read -r line; do
  log4j_config+=("$line")
done < <(python3 -c "import os; from raydp import versions; print(versions.SPARK_LOG4J_VERSION);
print(versions.SPARK_LOG4J_CONFIG_FILE_NAME_KEY);
print(os.getenv(\"SPARK_LOG4J_CONFIG_FILE_NAME\", versions.SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT));
print(versions.RAY_LOG4J_VERSION);
print(versions.RAY_LOG4J_CONFIG_FILE_NAME_KEY);
print(os.getenv(\"RAY_LOG4J_CONFIG_FILE_NAME\", versions.RAY_LOG4J_CONFIG_FILE_NAME_DEFAULT));
")


SPARK_LOG4J_VERSION=${log4j_config[0]}
SPARK_LOG4J_CONFIG_FILE_NAME_KEY=${log4j_config[1]}
SPARK_LOG4J_CONFIG_FILE_NAME=${log4j_config[2]}
RAY_LOG4J_VERSION=${log4j_config[3]}
RAY_LOG4J_CONFIG_FILE_NAME_KEY=${log4j_config[4]}
RAY_LOG4J_CONFIG_FILE_NAME=${log4j_config[5]}
# get agent jar: first match of jars/raydp-agent*.jar inside the raydp package
raydp_agent_jar=$(python3 -c "import os, glob, raydp; raydp_agent_path=os.path.abspath(os.path.join(os.path.dirname(raydp.__file__), './jars/raydp-agent*.jar'));
print(glob.glob(raydp_agent_path)[0])
")


# Pick up the user's own log4j / classpath settings from --conf values.
SPARK_PREF_CP=

#######################################
# Inspect one "key=value" Spark conf string and remember the settings this
# script cares about.
# Globals:
#   SPARK_PREF_CP                (written for spark.preferClassPath)
#   SPARK_LOG4J_CONFIG_FILE_NAME (written for spark.log4j.config.file.name)
#   RAY_LOG4J_CONFIG_FILE_NAME   (written for spark.ray.log4j.config.file.name)
# Arguments:
#   $1 - the value that followed --conf, expected as key=value
# Returns:
#   0 always; strings without '=', with an empty value, or with any other
#   key are ignored.
#######################################
search_target()
{
  local conf_key conf_value

  # Not a key=value pair at all.
  case "$1" in
    *=*) ;;
    *) return 0 ;;
  esac

  # Split on the FIRST '=' only, so a value that itself contains '='
  # (for example a path or file name) is kept whole instead of being ignored.
  conf_key=${1%%=*}
  conf_value=${1#*=}

  # An empty value never overrides a default.
  if [ -z "$conf_value" ]; then
    return 0
  fi

  case "$conf_key" in
    "spark.preferClassPath")            SPARK_PREF_CP=$conf_value ;;
    "spark.log4j.config.file.name")     SPARK_LOG4J_CONFIG_FILE_NAME=$conf_value ;;
    "spark.ray.log4j.config.file.name") RAY_LOG4J_CONFIG_FILE_NAME=$conf_value ;;
    *)                                  ;;
  esac
  return 0
}

# Copy every command-line argument into args, unchanged and in order.
# The value following each --conf is additionally handed to search_target so
# user-supplied settings can override the defaults.
args=()
while [ $# -gt 0 ]; do
  case "$1" in
    --conf)
      # `shift 2` fails WITHOUT shifting when only one argument is left, which
      # would make this loop spin forever. Reject a dangling --conf instead.
      if [ $# -lt 2 ]; then
        echo "Missing value for --conf! Run raydp-submit help to see instructions" >&2
        exit 1
      fi
      args+=("--conf")
      args+=("$2")
      search_target "$2"
      shift 2
      ;;
    *)
      args+=("$1")
      shift 1
      ;;
  esac
done

# Build the two generated option lists from the user's choices or the defaults:
#   added_confs - options for the JVM that runs SparkSubmit
#   added_args  - extra "--conf key=value" pairs handed to SparkSubmit itself
added_confs=(
  "-javaagent:$raydp_agent_jar"
  "-Dspark.ray.log4j.factory.class=$SPARK_LOG4J_VERSION"
  "-D$SPARK_LOG4J_CONFIG_FILE_NAME_KEY=$SPARK_LOG4J_CONFIG_FILE_NAME"
  "-Dspark.javaagent=$raydp_agent_jar"
)

# JDK 17+ module access flags for the SparkSubmit JVM.
# -XX:+IgnoreUnrecognizedVMOptions comes first for compatibility with
# Java 8/11, which are expected to ignore the --add-opens/--add-exports
# flags silently instead of refusing to start.
added_confs+=("-XX:+IgnoreUnrecognizedVMOptions")
for opened_pkg in \
  java.lang \
  java.lang.invoke \
  java.io \
  java.net \
  java.nio \
  java.math \
  java.text \
  java.util \
  java.util.concurrent \
  java.util.concurrent.atomic \
  sun.nio.ch \
  sun.nio.cs \
  sun.security.action \
  sun.util.calendar
do
  added_confs+=("--add-opens=java.base/${opened_pkg}=ALL-UNNAMED")
done
added_confs+=("--add-exports=java.base/sun.nio.ch=ALL-UNNAMED")

# Spark settings: ray-side log4j config file plus the driver host/bind address.
added_args=(
  "--conf" "spark.ray.log4j.config.file.name=$RAY_LOG4J_CONFIG_FILE_NAME"
  "--conf" "spark.driver.host=$RAY_DRIVER_NODE_IP"
  "--conf" "spark.driver.bindAddress=$RAY_DRIVER_NODE_IP"
)

# Choose the java launcher: JAVA_HOME's bin/java when JAVA_HOME is non-empty,
# otherwise whatever "java" resolves to; give up if neither is available.
if [ -n "${JAVA_HOME}" ]; then
  RUNNER="${JAVA_HOME}/bin/java"
elif [ "$(command -v java)" ]; then
  RUNNER="java"
else
  echo "JAVA_HOME is not set" >&2
  exit 1
fi

# A user-supplied preferred classpath goes first; give it the ':' separator
# it needs before the remaining entries are appended.
if [ -n "$SPARK_PREF_CP" ]; then
  SPARK_PREF_CP+=":"
fi
# Driver classpath, kept as one "-cp <entries>" string: optional preferred
# entries, then the raydp jars, the Spark conf dir, the Spark jars and the ray jars.
# Ray's own preferred classpath can be set with 'spark.ray.preferClassPath',
# which is passed down to ray in the spark conf.
RAYDP_CLASS_PATH="-cp ${SPARK_PREF_CP}${RAYDP_HOME}/jars/*:${SPARK_HOME}/conf:${SPARK_HOME}/jars/*:${RAY_HOME}/jars/*"

# merge all: every user argument except the last, then the generated --conf
# pairs, then the last user argument.
# NOTE(review): the generated pairs are always placed directly before the
# FINAL argument. If application arguments follow the primary resource, the
# pairs end up among those application arguments - confirm this is intended.
last_index=$(( ${#args[@]} - 1 ))
new_args=(
  "${args[@]:0:last_index}"
  "${added_args[@]}"
  "${args[last_index]}"
)

# set arguments
set -- "${new_args[@]}"

echo "$@"

# RAYDP_CLASS_PATH is stored as the single string "-cp <classpath>". Strip the
# flag so the classpath can be passed as ONE quoted word: expanding the string
# unquoted would word-split paths containing spaces and expose the "jars/*"
# entries to shell globbing.
class_path=${RAYDP_CLASS_PATH#-cp }

# An optional leading "--ray-conf <file>" is consumed here and turned into the
# ray.config-file system property; everything else goes to SparkSubmit.
# "$1", "$RUNNER" and the property are quoted so that values containing
# whitespace stay single words.
if [ "$1" == "--ray-conf" ]; then
  RAY_CONF=$2
  shift 2
  "$RUNNER" -cp "$class_path" "-Dray.config-file=$RAY_CONF" "${added_confs[@]}" org.apache.spark.deploy.SparkSubmit --master ray "$@"
else
  "$RUNNER" -cp "$class_path" "${added_confs[@]}" org.apache.spark.deploy.SparkSubmit --master ray "$@"
fi
