import datetime
import functools
import os
import tempfile
import unittest

import torch

import torch._dynamo

import torch.utils._pytree as pytree
from torch._dynamo.utils import clone_input
from torch._subclasses.schema_check_mode import SchemaCheckMode
from torch.overrides import TorchFunctionMode
from torch.testing._internal.optests import (
    aot_autograd_check,
    autograd_registration_check,
    fake_check,
)


def safe_schema_check(op, args, kwargs):
    args, kwargs = deepcopy_tensors((args, kwargs))
    with SchemaCheckMode():
        result = op(*args, **kwargs)
        return result


def safe_autograd_registration_check(op, args, kwargs):
    # Don't perform autograd_registration_check if none of the inputs require grad.
    if not pytree.tree_any_only(
        torch.Tensor, lambda x: x.requires_grad, (args, kwargs)
    ):
        return
    args, kwargs = deepcopy_tensors((args, kwargs))
    return autograd_registration_check(op, args, kwargs)


def safe_fake_check(op, args, kwargs):
    args, kwargs = deepcopy_tensors((args, kwargs))
    return fake_check(op, args, kwargs, dynamic_only=False)


def safe_aot_autograd_check(op, args, kwargs, dynamic):
    def func(*args, **kwargs):
        args, kwargs = pytree.tree_map_only(torch.Tensor, torch.clone, (args, kwargs))
        return op(*args, **kwargs)

    # aot_autograd_check runs func(*args, **kwargs) multiple times
    # and assumes `func` does not modify its inputs.
    return aot_autograd_check(func, args, kwargs, dynamic, check_gradients="auto")


def deepcopy_tensors(inputs):
    return pytree.tree_map_only(torch.Tensor, clone_input, inputs)


# Test util requirements
# - The test util must have signature (op: OpOverload, args, kwargs)
# - The test util must NOT mutate args, kwargs.
# - The test utils in this list must not be prefixes of each other. For example,
#   having both "test_schema" and "test_schema_is_functional" is NOT OK.
# - The order of items in this dict matters (for opcheck); we'll run them
#   in order.
ALL_TEST_UTILS = {
    "test_schema": safe_schema_check,
    "test_autograd_registration": safe_autograd_registration_check,
    "test_faketensor": safe_fake_check,
    "test_aot_dispatch_static": functools.partial(
        safe_aot_autograd_check,
        dynamic=False,
    ),
    "test_aot_dispatch_dynamic": functools.partial(
        safe_aot_autograd_check,
        dynamic=True,
    ),
}


def generate_opcheck_tests(
    testcase,
    namespaces,
    failures_dict,
    failures_dict_path,
    additional_decorators,
    test_utils,
):
    """Given an existing TestCase, use the existing tests to generate
    additional validation tests for custom operators.

    For {all existing tests in the TestCase} x {all test utils},
    we will generate one new test. The new test runs a TorchFunctionMode
    that intercepts ``op(*args, **kwargs)`` calls and invokes
    ``test_util(op, args, kwargs)``, where ``op`` is an operator.

    The test utils we support are listed in ALL_TEST_UTILS. They are:
    - test_schema: This runs SchemaCheckMode.
    - test_autograd_registration: This runs autograd_registration_check.
    - test_faketensor: This runs CrossRefFakeMode.
    - test_aot_dispatch_static: This runs aot_autograd_check, which:
        checks that the outputs (and gradients, if they are computable)
        are the same under eager-mode PyTorch and using AOTAutograd.
    - test_aot_dispatch_dynamic: Same as test_aot_dispatch_static, but
        runs AOTAutograd using dynamic shapes instead of static shapes.

    The generated test will have name ``{test_util}__{original_name}``.
    For example, if there is a method named ``test_cumsum``, then
    we will generate a ``test_schema__test_cumsum``,
    ``test_faketensor__test_cumsum``, etc.

    For more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit

    Args:
        testcase: The testcase we will modify and generate additional tests for.
        namespaces: We will only intercept calls to custom operators with these
                    namespaces.
        failures_dict: See ``validate_failures_dict`` for more details.
        failures_dict_path: The path to your failures dict. We will mention it in errors.
        additional_decorators: A dict mapping a generated test name to a list of
            decorators to apply to that test.
        test_utils: A list of test_utils to generate. Example: ["test_schema", "test_faketensor"]
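
    Example (an illustrative sketch, not part of this module's API; ``MyOpTests``,
    the ``mylib`` namespace, and the failures-dict path below are placeholders)::

        class MyOpTests(unittest.TestCase):
            def test_foo(self):
                x = torch.randn(3)
                torch.ops.mylib.foo(x)  # some custom operator under test

        generate_opcheck_tests(
            testcase=MyOpTests,
            namespaces=["mylib"],
            failures_dict={},
            failures_dict_path="failures_dict.json",
            additional_decorators={},
            test_utils=["test_schema", "test_faketensor"],
        )
        # MyOpTests now also has test_schema__test_foo and
        # test_faketensor__test_foo.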
    """
    if not (isinstance(testcase, type) and issubclass(testcase, unittest.TestCase)):
        raise ValueError(
            f"Expected testcase to be a subclass of unittest.TestCase, got {testcase}"
        )
    test_methods = [
        m
        for m in dir(testcase)
        if m.startswith("test_") and callable(getattr(testcase, m))
    ]
    validate_failures_dict(failures_dict, test_utils, testcase)

    def construct_method(attr, prefix, tester):
        method = getattr(testcase, attr)
        new_method_name = prefix + "__" + attr

        def new_method(*args, **kwargs):
            with OpCheckMode(
                namespaces,
                prefix,
                tester,
                failures_dict,
                new_method_name,
                failures_dict_path,
            ):
                result = method(*args, **kwargs)
            return result

        if new_method_name in additional_decorators:
            for dec in additional_decorators[new_method_name]:
                new_method = dec(new_method)

        if hasattr(testcase, new_method_name):
            raise RuntimeError(
                f"Tried to autogenerate {new_method_name} but {testcase} already "
                f"has method named {new_method_name}. Please rename the original "
                f"method on the TestCase."
            )
        setattr(testcase, new_method_name, new_method)

    test_utils = {name: ALL_TEST_UTILS[name] for name in test_utils}
    for attr in test_methods:
        for prefix, tester in test_utils.items():
            construct_method(attr, prefix, tester)


TEST_OPTIONS = ("xfail", "skip", "success")


def validate_failures_dict(failure_dict, test_utils, testcase):
    """Validates the failures dict.

    The failures dict looks something like the following.
    It maps an operator name (qualname) to a dict from autogenerated test name
    to a test option. Each autogenerated test checks the operator if the test
    calls it; the option specifies whether to skip that check ("skip"), expect
    it to fail ("xfail"), or expect it to pass ("success").

    {
        "fbgemm::split_lengths": {
            "test_schema__test_split_lengths": "xfail",
            "test_schema__test_split_lengths_empty": "skip",
        },
        "fbgemm::gather_lengths": {
            "test_schema__test_gather_lengths": "xfail",
        }
    }

    We require that all keys are sorted in alphabetical order. This makes
    it easier for us to codemod the failures_dict.
    """
    qualnames = list(failure_dict.keys())
    if qualnames != sorted(qualnames):
        raise RuntimeError("The failures dict must be sorted in alphabetical order")
    for qualname, test_to_option in failure_dict.items():
        test_names = list(test_to_option.keys())
        if test_names != sorted(test_names):
            raise RuntimeError(
                f"failures_dict['{qualname}']'s keys must be sorted in alphabetical order"
            )
        for test_name, test_option in test_to_option.items():
            if test_option not in TEST_OPTIONS:
                raise RuntimeError(
                    f"In failures_dict, got value={test_option} but it needs to be in {TEST_OPTIONS}"
                )
            if not any(test_name.startswith(test) for test in test_utils):
                raise RuntimeError(
                    f"In failures_dict, test name '{test_name}' should begin with one of {test_utils}"
                )
            for test in test_utils:
                if not test_name.startswith(test):
                    continue
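                # Strip "<test>__" (the util name plus the "__" separator), e.g.
                # "test_schema__test_cumsum" -> "test_cumsum".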
                base_test_name = test_name[len(test) + 2 :]
                if hasattr(testcase, base_test_name):
                    continue
                raise RuntimeError(
                    f"In failures dict, got test name '{test_name}'. We parsed this as "
                    f"running test '{test}' on '{base_test_name}', but "
                    f"{base_test_name} does not exist on the TestCase. "
                    f"Maybe you need to change the test name?"
                )


class OpCheckMode(TorchFunctionMode):
    """
    For a given test, OpCheckMode intercepts calls to operators and runs
    test_util(op, args, kwargs) for each intercepted (op, args, kwargs).
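
    OpCheckMode is installed for each generated test by ``generate_opcheck_tests``;
    a hand-written use would look roughly like this sketch (the ``mylib``
    namespace, test name, and path are illustrative)::

        with OpCheckMode(
            namespaces=["mylib"],
            test_util_name="test_schema",
            test_util=safe_schema_check,
            failures_dict={},
            test_name="test_schema__test_foo",
            failures_dict_path="failures_dict.json",
        ):
            run_original_test_body()  # any code that calls torch.ops.mylib.* ops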
    """

    def __init__(
        self,
        namespaces,
        test_util_name,
        test_util,
        failures_dict,
        test_name,
        failures_dict_path,
    ):
        # We will intercept calls to ops with these namespaces
        self.namespaces = namespaces
        # The test utility function. Its signature should be (op, args, kwargs) -> None.
        # Examples of test utilities are: schema_check, make_fx_check
        self.test_util = test_util
        self.test_util_name = test_util_name
        # The name of the test that is running this OpCheckMode.
        self.test_name = test_name
        # Maps qualname -> test_name -> skip/xfail
        # Tells us if we should skip a test or assert that there is a failure.
        self.failures_dict = failures_dict
        # Location of the failures dict, so error messages can point users to it.
        self.failures_dict_path = failures_dict_path

        # OpCheckMode suppresses errors, collects them here, and then raises them on exit.
        # Maps qualname -> List[(exception, op, args, kwargs)]
        self.seen_ops_to_errors = {}

    def maybe_raise_errors_on_exit(self):
        # Check expected failures first
        for qualname in self.seen_ops_to_errors.keys():
            option = retrieve(self.failures_dict, qualname, self.test_name)
            if len(self.seen_ops_to_errors[qualname]) == 0:
                if option == "xfail":
                    raise OpCheckError(
                        f"generate_opcheck_tests: Unexpected success for operator "
                        f"{qualname} on test {self.test_name}. This may mean that "
                        f"you have fixed this test failure. Please remove the "
                        f"expected failure in the failure dict at "
                        f"{self.failures_dict_path}."
                        f"For more details, see "
                        f"https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit"
                    )
                continue
        failed_ops = []
        for qualname in self.seen_ops_to_errors.keys():
            option = retrieve(self.failures_dict, qualname, self.test_name)
            if option != "success":
                continue
            if len(self.seen_ops_to_errors[qualname]) == 0:
                continue
            failed_ops.append(qualname)
        if not failed_ops:
            return
        # Raise from the first error but also report all of them to make
        # recording xfails easier.
        ex, op, args, kwargs = self.seen_ops_to_errors[failed_ops[0]][0]
        if should_print_repro():
            repro_command = generate_repro(self.test_util_name, op, args, kwargs)
            repro_command = (
                f"\n\nFor a minimal repro, run the following: \n\n{repro_command}"
            )
        else:
            repro_command = ""
        raise OpCheckError(
            f"Test generated by `generate_opcheck_tests`, {self.test_name}, "
            f"failed on operators {failed_ops}. This usually means that the "
            f"operators are not implemented correctly and may lead to silently "
            f"incorrect behavior. Set PYTORCH_OPCHECK_PRINT_REPRO=1 for a standalone repro, "
            f"or please see "
            f"https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit "
            f"for more recommendations. "
            f"{repro_command}"
        ) from ex

    def __exit__(self, *args, **kwargs):
        try:
            self.maybe_raise_errors_on_exit()
        finally:
            result = super().__exit__(*args, **kwargs)
        return result

    def run_test_util(self, op, args, kwargs):
        try:
            self.test_util(op, args, kwargs)
        except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
            # We might get here if the input is already a FakeTensor
            # or if we're in a torch.compile block. Just ignore these
            # since we can't handle them and reporting them as failures
            # is too noisy.
            pass

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # Only intercept calls to operators
        if not isinstance(func, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
            return func(*args, **kwargs)
        if (
            torch.jit.is_tracing()
            or torch.jit.is_scripting()
            or torch._dynamo.is_compiling()
        ):
            return func(*args, **kwargs)
        # Pre-existing code may not use the .default overload. If we see an
        # OpOverloadPacket, we attempt to resolve it to its unique overload;
        # if we cannot, we raise and ask the user to clarify.
        if isinstance(func, torch._ops.OpOverloadPacket):
            func = resolve_unique_overload_or_throw(func)
        qualname = func.name()
        ns = qualname.split("::")[0]
        if ns not in self.namespaces:
            return func(*args, **kwargs)

        args_c, kwargs_c = deepcopy_tensors((args, kwargs))
        # Only run test_util(op, args, kwargs) if the real call succeeds.
        result = func(*args, **kwargs)

        option = retrieve(self.failures_dict, qualname, self.test_name)
        if option == "success" or option == "xfail":
            # Suppress all errors during execution. Raise them during __exit__.
            try:
                if qualname not in self.seen_ops_to_errors:
                    self.seen_ops_to_errors[qualname] = []
                self.run_test_util(func, args_c, kwargs_c)
            except Exception as ex:
                if should_print_repro():
                    self.seen_ops_to_errors[qualname].append((ex, func, args, kwargs))
                else:
                    self.seen_ops_to_errors[qualname].append((ex, None, None, None))
        elif option == "skip":
            pass
        return result


def should_print_repro():
    """If set, the tests generated by `generate_opcheck_tests` will print a
    repro command on failure.

    In order to print the repro command, we need to save some tensors to disk.
    These will be saved under the following directory:
    {tempfile.gettempdir()}/pytorch_opcheck_safe_to_delete/.

    Although this is a temp folder, it will usually not automatically get cleaned
    up, so you'll need to manually delete it.
    """
    key = "PYTORCH_OPCHECK_PRINT_REPRO"
    if key not in os.environ:
        return False
    value = os.environ[key]
    return value == "1" or value == 1


def opcheck(op, args, kwargs=None, *, test_utils="ALL", raise_exception=True):
    """Given an operator and some sample arguments, tests if the operator is
    registered correctly.

    We test the following (which are important for correctness in eager-mode
    PyTorch and with torch.compile):
    - test_schema: if the operator's schema is correct.
    - test_autograd_registration: if autograd was registered correctly,
        i.e. to the correct DispatchKey.
    - test_faketensor: If the operator has a FakeTensor implementation
        (and if it is correct).
    - test_aot_dispatch_static: If the operator works with
        AOTAutograd/AOTDispatch, which is one of the parts in the PT2 stack.
        Checks that the outputs (and gradients, if they are computable)
        of the operator are the same under eager-mode PyTorch and torch.compile.
    - test_aot_dispatch_dynamic: Same as test_aot_dispatch_static, but
        tests dynamic shapes instead of static shapes.

    For best results, please call ``opcheck`` multiple times with a
    representative set of inputs. For example, if your operator supports
    autograd, please use ``opcheck`` with inputs that have ``requires_grad=True``.

    Args:
        op: The operator. Must be an OpOverload (e.g. torch.ops.aten.sin.default)
            or an OpOverloadPacket that resolves to a single overload.
        args: The args to the operator
        kwargs: The kwargs to the operator
        test_utils: Tests that we should run. Default: all of them.
            Example: ["test_schema", "test_faketensor"]
        raise_exception: If we should raise an exception on the first
            error. If False, we will return a dict with information
            on whether each test passed or not.

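    Example (illustrative; ``torch.ops.mylib.foo.default`` stands in for your
    own custom operator)::

        x = torch.randn(3, requires_grad=True)
        opcheck(torch.ops.mylib.foo.default, (x,), test_utils="test_schema")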
    """

    if kwargs is None:
        kwargs = {}
    if isinstance(op, torch._ops.OpOverloadPacket):
        op = resolve_unique_overload_or_throw(op)
    if not isinstance(op, torch._ops.OpOverload):
        raise ValueError(
            f"opcheck(op, ...): op must be instance of torch._ops.OpOverload, "
            f"e.g. torch.ops.aten.sin.default, got {type(op)}"
        )
    if test_utils == "ALL":
        test_utils = tuple(ALL_TEST_UTILS.keys())
    if isinstance(test_utils, str):
        test_utils = (test_utils,)
    if not isinstance(test_utils, (tuple, list)) or not set(test_utils).issubset(
        ALL_TEST_UTILS.keys()
    ):
        raise ValueError(
            f"opcheck(op, ..., test_utils={test_utils}), expected test_utils "
            f"to be subset of {tuple(ALL_TEST_UTILS.keys())} but it was not"
        )

    results_dict = {}
    for test_util in test_utils:
        tester = ALL_TEST_UTILS[test_util]
        try:
            tester(op, args, kwargs)
            results_dict[test_util] = "SUCCESS"
        except Exception as ex:
            if raise_exception:
                raise OpCheckError(
                    f"opcheck(op, ...): {test_util} failed with {ex} "
                    f"(scroll up for stack trace)"
                ) from ex
            results_dict[test_util] = ex
    return results_dict


class OpCheckError(Exception):
    pass


def generate_repro(test, op, args, kwargs):
    now = datetime.datetime.now()
    unix_timestamp = int(datetime.datetime.timestamp(now) * 1000)
    path = os.path.join(tempfile.gettempdir(), "pytorch_opcheck_safe_to_delete")
    os.makedirs(path, exist_ok=True)
    filepath = os.path.join(path, f"repro_{unix_timestamp}.pt")

    ns, name = op._schema.name.split("::")
    overload = op._overloadname

    repro_command = (
        f"import torch\n"
        f"from torch.testing._internal.optests import opcheck\n"
        f"# Make sure you have loaded the library that contains the op\n"
        f"# via an import or torch.ops.load_library(...)\n"
        f"op = torch.ops.{ns}.{name}.{overload}\n"
        f'args, kwargs = torch.load("{filepath}")\n'
        f'opcheck(op, args, kwargs, test_utils="{test}")\n'
    )
    torch.save((args, kwargs), filepath)
    return repro_command


def resolve_unique_overload_or_throw(op: torch._ops.OpOverloadPacket):
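    """Resolve an OpOverloadPacket to its unique OpOverload, or raise.

    For example, a (hypothetical) packet torch.ops.mylib.foo with only the
    default overload registered resolves to torch.ops.mylib.foo.default; a
    packet with multiple overloads raises because opcheck cannot tell which
    overload to test.
    """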
    all_schemas = torch._C._jit_get_schemas_for_operator(op._qualified_op_name)
    if len(all_schemas) != 1:
        raise RuntimeError(
            f"opcheck can only test operators without overloads. "
            f"Got the following overloads for {op._qualified_op_name}: "
            f"{[schema.overload_name for schema in all_schemas]}"
        )

    overload_name = all_schemas[0].overload_name
    if overload_name == "":
        return op.default
    return getattr(op, overload_name)


def retrieve(failures_dict, qualname, test_name):
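    """Look up the option ("xfail", "skip", or "success") recorded in
    failures_dict for (qualname, test_name), defaulting to "success"."""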
    if qualname not in failures_dict:
        return "success"
    dct = failures_dict[qualname]
    if test_name not in dct:
        return "success"
    return dct[test_name]
