import math
from typing import List, Optional, Sequence, Tuple, Union

import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    NativeFunctionsGroup,
    OptionalType,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.static_runtime import config


def has_alias(
    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
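    """Return True if any argument carries an alias annotation (e.g. `Tensor(a!)`)."""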
    for arg in arguments:
        annotation = getattr(arg, "annotation", None)
        if not annotation:
            continue
        alias_set = getattr(annotation, "alias_set", ())
        if alias_set:
            return True
    return False


BLOCKED_OPS = frozenset(
    (
        # non-CPU ops
        "sparse_sampled_addmm",
        "hspmm",
        # sparse ops
        "sspaddmm",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "_compute_linear_combination",
    )
)


def is_supported(g: NativeFunctionsGroup) -> bool:
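    """
    Return True if out-variant dispatch code and tests can be auto-generated for
    this op group: the op must not be blocked or hand-written, must have a single
    out-variant Tensor output, must not alias its inputs, and all of its argument
    types must be convertible from IValues.
    """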
    base_op_name = g.out.func.name.name.base
    if base_op_name in BLOCKED_OPS:
        return False
    if config.is_hand_written(g):
        return False
    if not g.structured:
        # For an unstructured op, check whether it has an out-variant implementation
        # that meets the minimum requirement: the output tensor is the last parameter.
        if (
            not hasattr(g, "out")
            or not str(g.out.func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(g.out.func.name).endswith(".out")
        ):
            return False
    if has_alias(g.out.func.arguments.non_out):
        # This op may create an alias of its inputs.
        return False
    if len(g.out.func.arguments.out) > 1:
        # More than one output value.
        return False
    if "at::Tensor &" != cpp.returns_type(g.out.func.returns).cpp_type():
        # Returns a non-Tensor value.
        return False
    for arg in g.out.func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # Converting this argument type from an IValue is not supported yet.
            return False
    return True


def ivalue_type_conversion_method(
    arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
    """
    Return the method call expression of `c10::ivalue' to convert its contained value to
    the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
    this function returns ".toTensor()", so that it can be appended to the ivalue's
    variable name to get the value of the expected type.
    """
    type_conversion_methods = {
        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
        BaseTy.ScalarType: (
            (False, "toScalarType()"),
            (False, "toOptional<at::ScalarType>()"),
        ),
        BaseTy.str: (
            (False, "toStringView()"),
            (False, "toOptional<c10::string_view>()"),
        ),
    }

    base_ty_object = None
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    elif isinstance(arg_type, OptionalType):
        if not isinstance(arg_type.elem, BaseType):
            # ListType is currently unsupported.
            return None
        base_ty_object = arg_type.elem.name
    else:
        return None

    if base_ty_object not in type_conversion_methods:
        return None
    methods = type_conversion_methods[base_ty_object]
    if isinstance(arg_type, BaseType):
        return methods[0]
    return methods[1]


should_use_int_tensor_ops_ = frozenset(
    (
        "bitwise_not",
        "bitwise_and",
        "bitwise_or",
        "bitwise_xor",
        "gcd",
        "lcm",
        "scatter",
        "gather",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
    )
)


def should_use_int_tensor(op_name: str) -> bool:
    return op_name in should_use_int_tensor_ops_


test_tensor_dim_ops_1_ = frozenset(
    (
        "addmv",
        "index_add",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
        "nll_loss_backward",
        "dot",
        "vdot",
        "outer",
        "ger",
    )
)
test_tensor_dim_ops_2_ = frozenset(
    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation")
)


def test_tensor_dim(op_name: str) -> int:
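    """Return the number of dimensions to use for generated test tensors (3 by default)."""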
    if op_name in test_tensor_dim_ops_1_:
        return 1
    if op_name in test_tensor_dim_ops_2_:
        return 2
    return 3


def test_value_expression(
    arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
) -> str:
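    """
    Return a C++ expression producing a test value for an argument of type
    `arg_type`, e.g. `at::rand({6,6,6})` for a Tensor argument of a 3-D op when
    `index` == 0, or `"false"` for a bool argument.
    """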
    # The second value set (index == 1) uses larger tensors than the first, so the
    # generated test runs the op with two different input shapes.
    num_tensors = 16 if index == 0 else 64
    num_dim = test_tensor_dim(op_name)
    size_per_dim = math.ceil(num_tensors / float(num_dim))
    # Round the per-dimension size up to an even number.
    size_per_dim += size_per_dim % 2
    tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim))
    if should_use_int_tensor(op_name):
        tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
    else:
        tensor_expression = f"at::rand({tensor_size_ex})"

    value_expressions = {
        BaseTy.Tensor: tensor_expression,
        BaseTy.int: "1",
        BaseTy.bool: "false",
        BaseTy.Scalar: "2",
        BaseTy.ScalarType: "at::ScalarType::Float",
        BaseTy.str: '"floor"',
    }

    base_ty_object = None
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    else:
        assert isinstance(arg_type, OptionalType) and isinstance(
            arg_type.elem, BaseType
        )
        base_ty_object = arg_type.elem.name
    assert base_ty_object in value_expressions, "unexpected type"
    value_expression = value_expressions[base_ty_object]
    return value_expression


def generate_test_value_definitions(g: NativeFunctionsGroup, index: int) -> str:
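    """
    Return C++ statements defining one test value per argument. For a hypothetical
    binary op with arguments `self` and `other` and `index` == 0, this produces
    roughly:

        auto self0 = at::rand({6,6,6});
        auto other0 = at::rand({6,6,6});

    Individual values may be overridden via `config.override_test_values`.
    """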
    schema = g.functional.func
    assert not schema.is_out_fn()
    schema_name = schema.name.name.base
    arg_map = {}
    for arg in schema.schema_order_arguments():
        test_value_exp = test_value_expression(arg.type, index, schema_name)
        arg_map[arg.name] = test_value_exp
    config.override_test_values(arg_map, schema_name, index)
    arg_populations = []
    for arg_name, arg_value in arg_map.items():
        arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
    return ";\n    ".join(arg_populations) + ";"


def generate_test_value_names(g: NativeFunctionsGroup, index: int) -> str:
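    """Return the comma-separated names of the generated test values, e.g. "self0,other0" for `index` == 0."""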
    schema = g.functional.func
    assert not schema.is_out_fn()
    return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())


generate_test_ir_arguments_base_ty_to_type_str_ = {
    BaseTy.Tensor: "Tensor",
    BaseTy.int: "int",
    BaseTy.float: "float",
    BaseTy.str: "str",
    BaseTy.Scalar: "int",
    BaseTy.ScalarType: "int",
    BaseTy.bool: "bool",
}


def generate_test_ir_arguments(
    g: NativeFunctionsGroup,
) -> List[Tuple[str, Optional[str]]]:
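    """
    Return (IR argument name, IR type) pairs for the TorchScript graph header,
    e.g. ("%self", "Tensor"), or ("%n", "int?") for an optional int argument.
    """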
    def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
        t = arg.type
        add_optional = False
        if isinstance(t, OptionalType):
            t = t.elem
            add_optional = True
        assert isinstance(t, BaseType)
        type_str = None
        if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
            type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
        if type_str and add_optional:
            type_str = f"{type_str}?"
        return ("%" + arg.name, type_str)

    schema = g.functional.func
    assert not schema.is_out_fn()
    return [ir_argument(arg) for arg in schema.schema_order_arguments()]


def generate_arg_extraction(g: NativeFunctionsGroup) -> str:
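    """
    Return C++ statements that pull the op's arguments out of the ProcessedNode's
    inputs, e.g. (for a hypothetical op taking a Tensor `self` and a Scalar `alpha`):

        const auto& self = p_node->Input(0).toTensor();
        const auto alpha = p_node->Input(1).toScalar();
    """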
    schema = g.functional.func
    assert not schema.is_out_fn()
    arg_populations = []
    for i, arg in enumerate(schema.schema_order_arguments()):
        maybe_method = ivalue_type_conversion_method(arg.type)
        assert maybe_method
        is_reference, type_conversion_method = maybe_method
        reference = "&" if is_reference else ""
        arg_populations.append(
            f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
        )
    return ";\n    ".join(arg_populations) + ";"


def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
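    """
    Return the C++ kernel name for the functional variant: the kernel registered in
    `backend_index` for unstructured ops, otherwise the default C++ name.
    """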
    kernel = backend_index.get_kernel(g.functional)
    if g.structured or kernel is None:
        return cpp.name(g.functional.func)
    return kernel.kernel


def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
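    """Return the C++ kernel name for the out variant, analogous to `get_kernel_name`."""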
    kernel = backend_index.get_kernel(g.out)
    if g.structured or kernel is None:
        return cpp.name(g.out.func)
    return kernel.kernel


def generate_non_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
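    """
    Return a call to the functional kernel, e.g. roughly `at::cpu::foo(self,other)`
    for a hypothetical structured op `foo`, or `at::native::...` for an
    unstructured one.
    """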
    schema = g.functional.func
    assert not schema.is_out_fn()
    kernel_name = get_kernel_name(g, backend_index)
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "cpu" if g.structured else "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'


def generate_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
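    """
    Return a call to the out-variant kernel. For a structured op the output
    tensor(s) come first, e.g. roughly `at::cpu::foo_out(out,self,other)` for a
    hypothetical op `foo`; for an unstructured op the single output is appended
    last, e.g. `at::native::foo_out(self,other,out)`.
    """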
    schema = g.out.func
    assert schema.is_out_fn()
    kernel_name = get_out_kernel_name(g, backend_index)
    if g.structured:
        # A structured op's out-variant kernel takes the output tensor(s) first.
        arg_names = [out_arg.name for out_arg in schema.arguments.out]
    else:
        arg_names = []
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)
    if not g.structured:
        assert len(schema.arguments.out) == 1
        arg_names.append(schema.arguments.out[0].name)
    cpp_arg_names = ",".join(arg_names)
    namespace_name = "cpu" if g.structured else "native"
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"


no_memory_resize_ops = frozenset(
    (
        "isin.Scalar_Tensor",
        "index_add",
        "dot",
        "vdot",
        "nuclear_norm",
        "histc",
        "l1_loss",
        "multi_margin_loss",
        "multilabel_margin_loss",
        "nll_loss",
        "nll_loss2d",
    )
)


def should_check_resize(schema: FunctionSchema) -> bool:
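    """
    Return True unless the op is in `no_memory_resize_ops`, in which case the
    generated test disables the output-resize check.
    """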
    schema_str = str(schema)
    type_variant_op_name = schema_str[: schema_str.find("(")]
    return type_variant_op_name not in no_memory_resize_ops


def op_name_from_group(g: NativeFunctionsGroup) -> str:
    return g.functional.func.name.name.base


class GenOutVariantDispatcher:
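    """
    Given the overload groups of one op, emits the C++ `REGISTER_OPERATOR_FUNCTOR(...)`
    registration mapping each supported schema to a lambda that calls the functional
    kernel when no output tensor is present, and the out-variant kernel otherwise.
    """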
    def __call__(
        self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
    ) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.op_generator(g, backend_index)
                generated_type_variants.append(generated_type_variant)
        op_name = op_name_from_group(groups[0])
        body = "\n".join(generated_type_variants)
        generated = f"""
REGISTER_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
        return generated

    def op_generator(self, g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
        functional = g.functional
        schema = str(functional.func)
        op_name = op_name_from_group(g)
        populated_argument = generate_arg_extraction(g)
        functional_variant_call = generate_non_out_variant_call(g, backend_index)
        assert len(g.out.func.arguments.out) == 1
        out_variable_name = str(g.out.func.arguments.out[0].name)
        out_variant_call = generate_out_variant_call(g, backend_index)
        generated = f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
          if (p_node->Output(0).isNone()) {{
            p_node->Output(0) = {functional_variant_call};
            return;
          }}
          auto& {out_variable_name} = p_node->Output(0).toTensor();
          fastResizeToZero({out_variable_name});
          {out_variant_call};
        }};
      }}"""
        return generated


class GenOutVariantDispatcherTestCase:
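    """
    Emits one C++ `TEST(StaticRuntime, autogen_<op>)` case per supported schema.
    Each test wraps the op in a small TorchScript graph and runs it through
    `testStaticRuntime` with two differently sized value sets, optionally
    checking output resizing when `check_resize` is enabled.
    """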
    def __call__(self, groups: Sequence[NativeFunctionsGroup]) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.test_case_generator(g)
                generated_type_variants.append(generated_type_variant)
        return "\n".join(generated_type_variants)

    def test_case_generator(self, g: NativeFunctionsGroup) -> str:
        functional = g.functional
        schema = str(functional.func)
        assert schema.find("(") > 0
        type_variant_op_name = schema[: schema.find("(")].replace(".", "_")
        op_name = op_name_from_group(g)
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(g)
        arg_declarations = ", ".join(
            (
                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
                for arg_name, arg_type in arg_types
            )
        )
        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
        assert (
            len(functional.func.returns) == 1
            and isinstance(functional.func.returns[0].type, BaseType)
            and functional.func.returns[0].type.name is BaseTy.Tensor
        )
        test_value_definitions = generate_test_value_definitions(g, 0)
        test_value_names = generate_test_value_names(g, 0)
        test_value_definitions2 = generate_test_value_definitions(g, 1)
        test_value_names2 = generate_test_value_names(g, 1)
        check_resize = "true" if should_check_resize(functional.func) else "false"
        generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});

  {test_value_definitions2}
  std::vector<IValue> args2{{{test_value_names2}}};
  testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});

}}
"""
        return generated
