// Copyright 2023 ByteDance Ltd. and/or its affiliates.
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
// Selects libstdc++'s pre-C++11 ("old") ABI for the generated translation unit.
// The define has to precede every standard-library include to take effect.
// NOTE(review): presumably chosen to match how the hercules runtime is built - confirm.
#define _GLIBCXX_USE_CXX11_ABI 0
// Template macro debug_helper(var): when the template variable `debug` is set, each use
// expands to a statement that streams the name and the value of `var` to std::cout;
// otherwise it expands to nothing.
// NOTE(review): the emitted statement needs <iostream>, which this file does not include
// directly - presumably provided transitively by the hercules headers; confirm.
{%- macro debug_helper(var) %}
  {%- if debug %}
  std::cout<<"[MLIR c interface]{{var}} = " << {{var}} << std::endl;
  {%- endif %}
{%- endmacro %}


#include "hercules/runtime/mlir/convert_memref.h"
#include "hercules/runtime/mlir/func_loader.h"
#include "hercules/runtime/runtime_value.h"
#include "hercules/runtime/registry.h"

// Provide a 2-byte type named `__fp16` for the generated code.
// clang, or GCC 10 and newer: probe the compiler with __has_builtin.
// NOTE(review): __has_builtin is documented for builtin functions; whether it reports
// type keywords such as __fp16 / _Float16 should be confirmed for each supported compiler.
#if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 10 )
#if __has_builtin(__fp16)
static_assert(sizeof (__fp16) == 2, "__fp16(a.k.a float16) size must be 2 bytes, your machine/gcc may not support __fp16 type");
#else
#include <float.h>
// Second attempt: alias `__fp16` to `_Float16` when that probe succeeds.
// NOTE(review): if this probe fails as well, no `__fp16` alias is declared on this path.
#if __has_builtin(_Float16)
static_assert(sizeof (_Float16) == 2, "_Float16(a.k.a float16) size must be 2 bytes, your machine/gcc may not support _Float16 type");
using __fp16 = _Float16;
#endif
#endif
#else
// Any other compiler: `__fp16` becomes an alias of the integer type uint16_t.
// NOTE(review): uint16_t needs <cstdint>, which is not included directly here -
// presumably provided transitively by the hercules headers; confirm.
using __fp16 = uint16_t;
#endif


// Compile-time guards: `float` must occupy 4 bytes and `double` 8 bytes.
static_assert(
    sizeof(float) == 4,
    "float(a.k.a float32) size must be 4 bytes, your machine/gcc may not support float type");
static_assert(
    sizeof(double) == 8,
    "double(a.k.a float64) size must be 8 bytes, your machine/gcc may not support double type");

// Names derived by the template for this function:
//   func_prefix           = "_" + unique_id + "_" + func_name + "_"
//   func_type             = func_prefix + "_mlir_func_type"  (function-pointer alias)
//   c_interface_func_name = func_prefix + "_hvm_c_api_"      (generated entry point)
{%- set func_prefix = "_" ~ unique_id ~ "_" ~ func_name ~ "_" %}
{%- set func_type = func_prefix ~ "_mlir_func_type" %}
{%- set c_interface_func_name = func_prefix ~ "_hvm_c_api_" %}
// Allow names from the runtime and its MLIR helper namespace to be used unqualified.
using namespace hercules::runtime;
using namespace hercules::runtime::mlir;

// define func type
{%- if func_return_kind.is_void() %}
using {{func_type}} = void(*)({{input_types |join(', ') }});

{%- elif func_return_kind.is_scalar() %}
using {{func_type}} = {{return_type}}(*)({{input_types |join(', ') }});

{%- elif func_return_kind.is_static_tensor() %}
static_assert(false, "function_return_kind({{func_return_kind}}) is not supported");

{%- elif func_return_kind.is_dynamic_tensor() %}
using {{func_type}} = void(*)(void *, {{input_types |join(', ') }});

{%- else %}
static_assert(false, "function_return_kind({{func_return_kind}}) is not supported");

{%- endif %}



RTValue {{ c_interface_func_name }}(PyArgs args){
  static void * func_ptr = load_func("_mlir_ciface_{{func_name}}", "{{lib_path}}");
  {{debug_helper("func_ptr")}}
  static auto casted_func_ptr = reinterpret_cast<{{func_type}}>(func_ptr);
  {{debug_helper("casted_func_ptr")}}

  //convert input types
  {%- set func_args_list = [] %}
  {%- set ndarray_ptr_list = [] %}
  {%- for t in input_types %}
  {%- set i = loop.index-1 %}
  {%- if t == "void *" %}
  auto && _memref_data_{{i}} = args[{{i}}].As<NDArray>();
  {{debug_helper("_memref_data_"~i)}}
  auto && _memref_ptr_{{i}} = convert_from_ndarray(_memref_data_{{i}});
  {{debug_helper("_memref_ptr_"~i)}}
  auto _cast_arg_{{i}}_ = _memref_ptr_{{i}}.get();
  {% set _ = ndarray_ptr_list.append("_cast_arg_" ~ i ~ "_") %}
  {%- else %}
  auto _cast_arg_{{i}}_ = args[{{i}}].As<{{t}}>();
  {%- endif %}
  {% set _ = func_args_list.append("_cast_arg_" ~ i ~ "_") %}
  {{debug_helper( "_cast_arg_" ~ i ~ "_")}}
  {%- endfor %}
  {%- set func_args = func_args_list|join(', ') %}

  // call funtion
  {%- if func_return_kind.is_void() %}
  casted_func_ptr({{func_args}});
  return None;

  {%- elif func_return_kind.is_scalar() %}
  return casted_func_ptr({{func_args}});

  {%- elif func_return_kind.is_static_tensor() %}
  static_assert(false, "function_return_kind({{func_return_kind}}) is not supported");
  return None;

  {%- elif func_return_kind.is_dynamic_tensor() %}
  auto && _mlir_return_31905_shared_ptr_571 = alloc_memref_descriptor_ptr({{return_ndim}});
  void * _mlir_return_31905_ptr_571 = _mlir_return_31905_shared_ptr_571.get();
  casted_func_ptr(_mlir_return_31905_ptr_571, {{func_args}});
  return convert_to_ndarray(_mlir_return_31905_shared_ptr_571, {{return_ndim}}, cvt_str_to_dl_dtype("{{return_dtype}}"));
  {%- else %}
  static_assert(false, "function_return_kind({{func_return_kind}}) is not supported");
  return None;

  {%- endif %}
}

// Hand the generated entry point to HVM_REGISTER_NATIVE_FUNC.
// NOTE(review): the macro is presumably declared in hercules/runtime/registry.h - confirm.
HVM_REGISTER_NATIVE_FUNC({{c_interface_func_name}});