import functools
import inspect
import itertools
import logging
import math
import operator
import types
from typing import Dict, List

import torch
from torch import sym_float, sym_int

from .. import config, variables
from ..allowed_functions import is_allowed
from ..exc import (
    AttributeMutationError,
    unimplemented,
    Unsupported,
    UserError,
    UserErrorType,
)
from ..guards import GuardBuilder
from ..replay_record import DummyModule
from ..source import AttrSource, is_constant_source, SuperSource, TypeSource
from ..utils import (
    build_checkpoint_variable,
    check_constant_args,
    check_numpy_ndarray_args,
    check_unspec_python_args,
    get_fake_value,
    guard_if_dyn,
    is_utils_checkpoint,
    istype,
    numpy_operator_wrapper,
    proxy_args_kwargs,
    specialize_args_kwargs,
)
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable, EnumVariable
from .dicts import ConstDictVariable
from .lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    SetVariable,
    SizeVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .tensor import FakeItemVariable, SymNodeVariable, UnspecializedPythonVariable
from .user_defined import UserDefinedVariable

log = logging.getLogger(__name__)


class BuiltinVariable(VariableTracker):
    @staticmethod
    @functools.lru_cache(None)
    def _constant_fold_functions():
        fns = {
            abs,
            all,
            any,
            bool,
            callable,
            chr,
            divmod,
            float,
            int,
            len,
            max,
            min,
            ord,
            pow,
            repr,
            round,
            str,
            str.format,
            sum,
            type,
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
            operator.index,
        }
        fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
        return fns

    def can_constant_fold_through(self):
        return self.fn in self._constant_fold_functions()

    @staticmethod
    @functools.lru_cache(None)
    def _fx_graph_functions():
        fns = {
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _binops():
        # function -> ([forward name, reverse name, in-place name], in-place op)
        fns = {
            operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd),
            operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub),
            operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul),
            operator.truediv: (
                ["__truediv__", "__rtruediv__", "__itruediv__"],
                operator.itruediv,
            ),
            operator.floordiv: (
                ["__floordiv__", "__rfloordiv__", "__ifloordiv__"],
                operator.ifloordiv,
            ),
            operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod),
            pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
            operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
            operator.lshift: (
                ["__lshift__", "__rlshift__", "__ilshift__"],
                operator.ilshift,
            ),
            operator.rshift: (
                ["__rshift__", "__rrshift__", "__irshift__"],
                operator.irshift,
            ),
            # NB: The follow binary operators are not supported for now, since the
            # corresponding magic methods aren't defined on SymInt / SymFloat:
            # operator.matmul
            # divmod
            # operator.and_
            # operator.or_
            # operator.xor
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _binop_handlers():
        """Build the dispatch table used to pick a handler for a binary op.

        Returns a dict ``op_fn -> [((type_of_arg0, type_of_arg1), handler), ...]``.
        For every entry of ``_binops()`` both the operator and its in-place
        counterpart get a list, populated first with the user-defined handlers
        and then with the dynamic-shape (``SymNodeVariable``) handlers. After
        that loop, list-like special cases are appended for ``operator.add``,
        ``operator.iadd`` and ``operator.mul``.

        A handler may return ``None`` to signal that it does not apply (see
        ``list_iadd_handler``).
        """
        # Multiple dispatch mechanism defining custom binop behavior for certain type
        # combinations. Handlers are attempted in order, and will be used if the type checks
        # match. They are expected to have the signature:
        # fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker

        # Override table contains: op_fn -> [list of handlers]
        op_handlers = {}
        for (
            op,
            (magic_method_names, in_place_op),
        ) in BuiltinVariable._binops().items():
            # Start both lists empty. Note that builtin `pow` and `operator.pow`
            # share `operator.ipow`, so its list is reset by the later of the two.
            op_handlers[op] = []
            op_handlers[in_place_op] = []

            forward_name, reverse_name, inplace_name = magic_method_names

            # User-defined args (highest precedence)
            # The names are bound as default arguments so each closure keeps the
            # values of its own loop iteration.
            def user_defined_handler(
                tx,
                a,
                b,
                options,
                forward_name=forward_name,
                reverse_name=reverse_name,
            ):
                # Manually handle reversing logic if needed (e.g. call __radd__)

                # TODO: If we expand this to handle tensor args, we need to manually
                # handle cases like this:
                #
                # class A(int):
                #     def __radd__(self, other):
                #         print("woof")
                # torch.randn(3) + A(3)
                #
                # In this example, A.__radd__() is not called -> nothing is printed, because
                # Tensor.__add__ only does a subtype test against int, ignoring the subclass.
                # To be fully correct, we should not call A.__radd__() here, and there may be
                # other cases to reason about and add exceptions for.
                if isinstance(a, UserDefinedVariable):
                    return a.call_method(tx, forward_name, [b], {})
                else:
                    return b.call_method(tx, reverse_name, [a], {})

            op_handlers[op].append(
                ((UserDefinedVariable, VariableTracker), user_defined_handler)
            )
            op_handlers[op].append(
                ((VariableTracker, UserDefinedVariable), user_defined_handler)
            )

            # In-place variant: always dispatches to the in-place magic method of
            # the left operand `a`, for both registered type combinations.
            def user_defined_inplace_handler(
                tx, a, b, options, forward_name=inplace_name
            ):
                return a.call_method(tx, forward_name, [b], {})

            op_handlers[in_place_op].append(
                ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler)
            )
            op_handlers[in_place_op].append(
                ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler)
            )

            # Dynamic shape args
            # Emits a `call_function` node for the (out-of-place) op `fn` and wraps
            # the resulting proxy.
            def dynamic_handler(tx, a, b, options, fn=op):
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function", fn, *proxy_args_kwargs([a, b], {})
                    ),
                    **options,
                )

            op_handlers[op].append(
                ((SymNodeVariable, VariableTracker), dynamic_handler)
            )
            op_handlers[op].append(
                ((VariableTracker, SymNodeVariable), dynamic_handler)
            )

            # NB: Prefer out-of-place op when calling in-place op to generate valid graph
            op_handlers[in_place_op].append(
                ((SymNodeVariable, VariableTracker), dynamic_handler)
            )
            op_handlers[in_place_op].append(
                ((VariableTracker, SymNodeVariable), dynamic_handler)
            )

        # Special cases - lower precedence but still prefer these over constant folding

        # List-like addition (e.g. [1, 2] + [3, 4])
        def tuple_add_handler(tx, a, b, options):
            return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)

        def size_add_handler(tx, a, b, options):
            return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options)

        list_like_addition_handlers = [
            # NB: Prefer the tuple-specific logic over base logic because of
            # some SizeVariable weirdness. Specifically, the tuple-specific logic
            # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
            (
                (SizeVariable, SizeVariable),
                size_add_handler,
            ),
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b, options: TupleVariable(
                    list(a.unpack_var_sequence(tx)) + b.items, **options
                ),
            ),
            (
                # Generic fallback: the result has the type of the left operand.
                (BaseListVariable, BaseListVariable),
                lambda tx, a, b, options: type(a)(a.items + b.items, **options),
            ),
        ]
        op_handlers[operator.add].extend(list_like_addition_handlers)

        # In-place list extension (e.g. lst += other): replaces `a` everywhere via
        # tx.replace_all with a new ListVariable holding the concatenated items.
        def list_iadd_handler(tx, a, b, options):
            if not a.mutable_local or not b.has_unpack_var_sequence(tx):
                # Handler doesn't apply
                return None

            return tx.replace_all(
                a,
                ListVariable(
                    list(a.items) + list(b.unpack_var_sequence(tx)),
                    regen_guards=False,
                    **options,
                ),
            )

        list_like_iadd_handlers = [
            (
                (ListVariable, VariableTracker),
                list_iadd_handler,
            ),
            (
                # For tuples, += reuses the out-of-place handler and yields a
                # new TupleVariable.
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
        ]
        op_handlers[operator.iadd].extend(list_like_iadd_handlers)

        # List-like expansion (e.g. [1, 2, 3] * 3)
        # The result is a new variable of the same class as `lst` with a fresh
        # MutableLocal.
        def expand_list_like(tx, lst, const, options):
            return lst.__class__(
                items=lst.items * const.as_python_constant(),
                mutable_local=MutableLocal(),
                **options,
            )

        list_like_expansion_handlers = [
            ((ListVariable, ConstantVariable), expand_list_like),
            ((TupleVariable, ConstantVariable), expand_list_like),
            (
                # Constant on the left (e.g. 3 * [1, 2, 3]): swap the operands.
                (ConstantVariable, ListVariable),
                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
            ),
        ]
        op_handlers[operator.mul].extend(list_like_expansion_handlers)

        return op_handlers

    @staticmethod
    def _find_binop_handler(op, a, b):
        """Look up the first registered binop handler matching ``(a, b)``.

        Returns ``None`` when ``op`` has no handlers at all or when none of its
        handlers' type pairs match the two operands.
        """
        candidates = BuiltinVariable._binop_handlers().get(op)
        if candidates is None:
            return None

        # Handlers are ordered by precedence, so the first match wins.
        return next(
            (
                handler
                for (lhs_type, rhs_type), handler in candidates
                if isinstance(a, lhs_type) and isinstance(b, rhs_type)
            ),
            None,
        )

    def can_insert_in_graph(self):
        return self.fn in self._fx_graph_functions()

    def __init__(self, fn, **kwargs):
        """Wrap the builtin ``fn``; remaining keyword arguments go to the base class."""
        super().__init__(**kwargs)
        # The wrapped callable; handlers are looked up from it (e.g. `call_<name>`).
        self.fn = fn

    def __str__(self):
        if self.fn is None:
            name = "None"
        else:
            name = self.fn.__name__

        return f"{self.__class__.__name__}({name})"

    def python_type(self):
        return type(self.fn)

    def as_python_constant(self):
        return self.fn

    def as_proxy(self):
        DTYPE = {
            bool: torch.bool,
            int: torch.int64,
            float: torch.float64,
        }
        if self.fn in DTYPE:
            return DTYPE[self.fn]
        return super().as_proxy()

    def reconstruct(self, codegen):
        name = self.fn.__name__
        assert self.fn.__module__ == "builtins"
        assert name not in codegen.tx.f_globals, "shadowed global"
        return [codegen.create_load_global(name, False, add=True)]

    def constant_args(self, *args, **kwargs):
        """Return the result of ``check_constant_args`` for the given arguments."""
        all_constant = check_constant_args(args, kwargs)
        return all_constant

    def tensor_args(self, *args, **kwargs):
        """Tell whether the arguments should be treated as tensor arguments.

        True when at least one positional or keyword argument is a
        ``TensorVariable`` and none of them is a ``GetAttrVariable``.
        """
        operands = list(itertools.chain(args, kwargs.values()))
        if not any(isinstance(v, variables.TensorVariable) for v in operands):
            return False
        return not any(isinstance(v, variables.GetAttrVariable) for v in operands)

    def unspec_python_args(self, *args, **kwargs):
        """Return the result of ``check_unspec_python_args`` for the given arguments."""
        has_unspec = check_unspec_python_args(args, kwargs)
        return has_unspec

    @staticmethod
    def unwrap_unspec_args_kwargs(args, kwargs):
        unwrapped_args = []
        unwrapped_kwargs = {}
        for x in args:
            if isinstance(
                x,
                (variables.UnspecializedPythonVariable,),
            ):
                unwrapped_args.append(x.raw_value)
            else:
                unwrapped_args.append(x.as_python_constant())
        for k, v in kwargs:
            if isinstance(
                x,
                (variables.UnspecializedPythonVariable,),
            ):
                unwrapped_kwargs.update({k: v.raw_value})
            else:
                unwrapped_kwargs.update({k: v.as_python_constant()})
        return unwrapped_args, unwrapped_kwargs

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Trace a call of the wrapped builtin with the given argument trackers.

        Strategies are tried in this order, and the first one producing a
        result wins:

        1. Tensor arguments with a graph-insertable operator: emit a
           ``call_function`` node and wrap the proxy.
        2. ``int`` / ``float`` applied to a ``SymNodeVariable``: emit
           ``sym_int`` / ``sym_float``.
        3. Two positional arguments and no keywords: a matching handler from
           ``_find_binop_handler``.
        4. A ``call_<builtin name>`` method on this class, if its signature
           accepts the arguments.
        5. Constant folding, when the builtin is foldable and all arguments
           are constant or unspecialized Python values.
        6. Otherwise defer to the base class (``round`` on a symbolic value
           raises ``UserError`` first).

        Args:
            tx: the translator driving the trace.
            args: positional argument trackers (list or tuple).
            kwargs: keyword argument trackers.

        Returns:
            The ``VariableTracker`` representing the call result.
        """
        from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

        constant_args = check_constant_args(args, kwargs)
        tensor_args = self.tensor_args(*args, **kwargs)
        unspec_python_args = self.unspec_python_args(*args, **kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())
        has_constant_handler = self.can_constant_fold_through() and (
            constant_args or unspec_python_args
        )
        assert isinstance(args, (list, tuple))
        assert isinstance(kwargs, dict)

        # args[0] is list and args[1] is unspec
        if self.fn is operator.getitem and not isinstance(
            args[0], variables.TensorVariable
        ):
            tensor_args = False
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)

        if (
            self.can_insert_in_graph()
            and tensor_args
            and not (
                self.fn is operator.getitem
                and isinstance(args[0], ConstDictVariable)
                and isinstance(args[1], variables.TensorVariable)
            )
        ):
            try:
                fn = self.fn
                if self.fn is operator.iadd and isinstance(
                    args[0], variables.ConstantVariable
                ):
                    # Work around weird bug in hf_T5
                    fn, args = operator.add, [args[1], args[0]]

                if self.fn is operator.getitem and isinstance(args[1], SymNodeVariable):
                    # Standard indexing will force specialization due to
                    # __index__.  Rewrite as a regular torch op which will
                    # trace fine
                    fn, args = torch.select, [
                        args[0],
                        variables.ConstantVariable(0),
                        args[1],
                    ]

                # Interaction between ndarray and tensors:
                #   We prefer the tensor op whenever there are tensors involved
                if check_numpy_ndarray_args(args, kwargs) and not any(
                    type(arg) == variables.TensorVariable for arg in args
                ):
                    proxy = tx.output.create_proxy(
                        "call_function",
                        numpy_operator_wrapper(self.fn),
                        *proxy_args_kwargs(args, kwargs),
                    )

                    return wrap_fx_proxy_cls(
                        variables.NumpyNdarrayVariable, tx, proxy, **options
                    )

                proxy = tx.output.create_proxy(
                    "call_function",
                    fn,
                    *proxy_args_kwargs(args, kwargs),
                )
                # Pick the wrapper class for the result from the kinds of inputs.
                if any(isinstance(arg, FakeItemVariable) for arg in args):
                    return wrap_fx_proxy_cls(
                        FakeItemVariable,
                        tx,
                        proxy,
                        **options,
                    )
                elif self.unspec_python_args(*args, **kwargs):
                    # Compute the concrete Python value alongside the proxy so
                    # the result can carry it as `raw_value`.
                    _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
                    raw_value = self.fn(*_args, **_kwargs)

                    need_unwrap = any(
                        x.need_unwrap
                        for x in itertools.chain(args, kwargs.values())
                        if isinstance(x, variables.UnspecializedPythonVariable)
                    )

                    return wrap_fx_proxy_cls(
                        UnspecializedPythonVariable,
                        tx,
                        proxy,
                        raw_value=raw_value,
                        need_unwrap=need_unwrap,
                        **options,
                    )
                elif all(isinstance(x, SymNodeVariable) for x in args):
                    return SymNodeVariable.create(tx, proxy, None, **options)
                else:
                    # Work around for vision_maskrcnn due to precision difference
                    # specialize the dividend when float divide by tensor
                    # NOTE(review): this runs after `proxy` was created, so it
                    # presumably matters only through side effects of
                    # convert_to_constant - confirm.
                    if self.fn is operator.truediv and isinstance(
                        args[0], variables.UnspecializedPythonVariable
                    ):
                        args[0] = args[0].convert_to_constant(tx)
                    return wrap_fx_proxy(tx, proxy, **options)

            except NotImplementedError:
                unimplemented(f"partial tensor op: {self} {args} {kwargs}")

        # Handle cases like int(torch.seed())
        # Also handle sym_float to sym_int cases
        # The length check keeps zero-argument calls such as int() / float()
        # from raising IndexError here, so they reach constant folding below.
        if (
            self.fn in (int, float)
            and len(args) > 0
            and isinstance(args[0], SymNodeVariable)
        ):
            fn_ = sym_int if self.fn is int else sym_float
            out = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    (args[0].as_proxy(),),
                    {},
                ),
                **options,
            )
            return out

        # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.)
        # NB: Tensor args are handled above and not here
        if len(kwargs) == 0 and len(args) == 2:
            # Try to find a handler for the arg types; otherwise, fall through to constant handler
            binop_handler = BuiltinVariable._find_binop_handler(
                self.fn, args[0], args[1]
            )
            if binop_handler:
                res = binop_handler(tx, args[0], args[1], options)
                # A handler returns None when it turns out not to apply.
                if res is not None:
                    return res

        # Per-builtin handler, e.g. call_len for len().
        handler = getattr(self, f"call_{self.fn.__name__}", None)
        if handler:
            try:
                # Check up front that the handler's signature accepts these args.
                inspect.signature(handler).bind(tx, *args, **kwargs)
            except TypeError as exc:
                if not has_constant_handler:
                    log.warning(
                        "incorrect arg count %s %s and no constant handler",
                        handler,
                        exc,
                    )
                handler = None

        if handler:
            try:
                result = handler(tx, *args, **kwargs)
                if result is not None:
                    return result.add_options(options)
            except Unsupported as exc:
                if not has_constant_handler:
                    raise
                # Actually, we will handle this just fine
                exc.remove_from_stats()

        if has_constant_handler:
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
            # constant fold
            return variables.ConstantVariable(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
                **options,
            )

        if self.fn is round:
            if len(args) > 0 and isinstance(args[0], SymNodeVariable):
                raise UserError(
                    UserErrorType.STANDARD_LIBRARY,
                    "Calling round() on symbolic value is not supported. "
                    "You can use floor() to implement this functionality",
                )
        return super().call_function(tx, args, kwargs)

    def _call_min_max(self, tx, *args):
        if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
            # expand iterable
            items = args[0].unpack_var_sequence(tx)
            return self._call_min_max_seq(tx, items)
        elif len(args) == 2:
            return self._call_min_max_binary(tx, args[0], args[1])
        elif len(args) > 2:
            return self._call_min_max_seq(tx, args)

    def _call_min_max_seq(self, tx, items):
        assert len(items) > 0
        if len(items) == 1:
            return items[0]

        return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)

    def _call_min_max_binary(self, tx, a, b):
        """Compute ``min(a, b)`` or ``max(a, b)`` (chosen by ``self.fn``) for two trackers.

        Three situations are handled:

        * at least one tensor argument (per ``self.tensor_args``): the call is
          lowered to a clamp/clip or maximum/minimum op;
        * two ``ConstantVariable`` arguments: the result is computed directly;
        * a ``SymNodeVariable`` on either side: a ``call_function`` node for
          the builtin is emitted.

        Any other combination is reported through ``unimplemented``.
        """
        if self.tensor_args(a, b):
            # Normalize so that the tensor operand is `a`.
            if not isinstance(a, variables.TensorVariable):
                a, b = b, a
            assert isinstance(a, variables.TensorVariable)

            # result of an item call is a scalar convert to a tensor
            if isinstance(a, FakeItemVariable):
                a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})

            # Dynamic input does not get resolved, rather, gets stored as call_function
            if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
                from .builder import wrap_fx_proxy_cls

                return wrap_fx_proxy_cls(
                    type(a),
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        self.fn,
                        *proxy_args_kwargs([a, b], {}),
                    ),
                    **VariableTracker.propagate(self, [a, b]),
                )

            # convert min/max to torch ops
            if b.is_python_constant():
                # Constant bound: max(a, b) clamps from below (min=b),
                # min(a, b) clamps from above (max=b).
                if isinstance(a, variables.NumpyNdarrayVariable):
                    import numpy as np

                    fn = variables.NumpyVariable(np.clip)
                else:
                    fn = variables.TorchVariable(torch.clamp)
                kwargs = {"min": b} if (self.fn is max) else {"max": b}
                result = fn.call_function(tx, [a], kwargs)
            else:
                # Non-constant bound: use the elementwise maximum / minimum op.
                if isinstance(a, variables.NumpyNdarrayVariable):
                    import numpy as np

                    fn = {max: np.maximum, min: np.minimum}[self.fn]
                    fn = variables.NumpyVariable(fn)
                else:
                    fn = {max: torch.maximum, min: torch.minimum}[self.fn]
                    fn = variables.TorchVariable(fn)
                result = fn.call_function(tx, [a, b], {})

            # return unspec if both a, b are unspec or const
            if all(
                isinstance(
                    i,
                    (
                        variables.UnspecializedPythonVariable,
                        variables.ConstantVariable,
                    ),
                )
                for i in [a, b]
            ):
                if any(isinstance(val, FakeItemVariable) for val in [a, b]):
                    return variables.FakeItemVariable.from_tensor_variable(result)

                # Recompute the concrete Python value so the result can carry it.
                if b.is_python_constant():
                    raw_b = b.as_python_constant()
                else:
                    raw_b = b.raw_value
                if self.fn is max:
                    raw_res = max(a.raw_value, raw_b)
                else:
                    raw_res = min(a.raw_value, raw_b)

                need_unwrap = any(
                    x.need_unwrap
                    for x in [a, b]
                    if isinstance(x, variables.UnspecializedPythonVariable)
                )
                return variables.UnspecializedPythonVariable.from_tensor_variable(
                    result, raw_res, need_unwrap
                )
            # otherwise return tensor
            else:
                return result
        elif isinstance(a, variables.ConstantVariable) and isinstance(
            b, variables.ConstantVariable
        ):
            # Both constant: evaluate immediately.
            if self.fn is max:
                return variables.ConstantVariable(max(a.value, b.value))
            else:
                return variables.ConstantVariable(min(a.value, b.value))
        elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
            # Symbolic operand without tensors: record the builtin call in the graph.
            proxy = tx.output.create_proxy(
                "call_function", self.fn, *proxy_args_kwargs([a, b], {})
            )
            return SymNodeVariable.create(tx, proxy, None)
        else:
            unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")

    # min() and max() share one handler; the operation is chosen from `self.fn`
    # inside `_call_min_max_binary`.
    call_min = _call_min_max
    call_max = _call_min_max

    def call_range(self, tx, *args):
        """Build a ``RangeVariable`` for ``range(...)`` when the arguments allow it.

        Constant / unspecialized arguments are specialized first; symbolic
        arguments are passed through ``guard_if_dyn`` and wrapped as constants.
        Returns ``None`` for anything else, which lets the caller continue with
        its other strategies.
        """
        if self.unspec_python_args(*args) or self.constant_args(*args):
            specialized, _ = specialize_args_kwargs(tx, args, {})
            return variables.RangeVariable(specialized)

        if self._dynamic_args(*args):
            guarded = [variables.ConstantVariable(guard_if_dyn(arg)) for arg in args]
            return variables.RangeVariable(guarded)

        # None no-ops this handler and lets the driving function proceed
        return None

    def _dynamic_args(self, *args, **kwargs):
        return any(isinstance(x, SymNodeVariable) for x in args) or any(
            isinstance(x, SymNodeVariable) for x in kwargs.values()
        )

    def call_slice(self, tx, *args):
        """Wrap the arguments of ``slice(...)`` in a ``SliceVariable``."""
        slice_var = variables.SliceVariable(args)
        return slice_var

    def _dyn_proxy(self, tx, *args, **kwargs):
        """Emit a ``call_function`` node for the builtin and wrap its proxy."""
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self, args, kwargs.values())
        node_inputs = proxy_args_kwargs(args, kwargs)
        proxy = tx.output.create_proxy("call_function", self.fn, *node_inputs)
        return wrap_fx_proxy(tx, proxy, **options)

    def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
        """Shared handler for ``iter()``, ``tuple()``, ``list()`` and ``set()``.

        With no ``obj`` an empty container is produced. With an unpackable
        ``obj`` its items are copied into a new container and, when the source
        is not constant, a length guard is attached. If ``obj`` cannot be
        unpacked the handler returns ``None``.
        """
        # NOTE(review): only the extra args/kwargs are tested for symbolic values
        # here, not `obj` itself - confirm that is intended.
        if self._dynamic_args(*args, **kwargs):
            return self._dyn_proxy(tx, *args, **kwargs)

        cls = variables.BaseListVariable.cls_for(self.fn)
        # SetVariable is constructed with `tx` as an extra leading argument.
        leading = (tx,) if cls is SetVariable else ()

        if obj is None:
            return cls(*leading, [], mutable_local=MutableLocal())

        if not obj.has_unpack_var_sequence(tx):
            return None

        guards = set()
        if obj.source and not is_constant_source(obj.source):
            guard_kind = (
                GuardBuilder.TUPLE_ITERATOR_LEN
                if isinstance(obj, TupleIteratorVariable)
                else GuardBuilder.LIST_LENGTH
            )
            guards.add(obj.source.make_guard(guard_kind))

        container = cls(
            *leading,
            list(obj.unpack_var_sequence(tx)),
            mutable_local=MutableLocal(),
            guards=guards,
        )
        return container.add_options(self, obj)

    # iter(), tuple(), list() and set() are served by the same handler, which
    # picks the result class from `self.fn` via `BaseListVariable.cls_for`.
    call_iter = _call_iter_tuple_list
    call_tuple = _call_iter_tuple_list
    call_list = _call_iter_tuple_list
    call_set = _call_iter_tuple_list

    @staticmethod
    def is_supported_call_dict_arg(tx, arg):
        return (
            arg is None
            or isinstance(arg, ConstDictVariable)
            or (
                isinstance(
                    arg,
                    (
                        ListVariable,
                        TupleVariable,
                        ListIteratorVariable,
                    ),
                )
                and all(
                    isinstance(x, (ListVariable, TupleVariable))
                    and isinstance(
                        x.unpack_var_sequence(tx)[0], (ConstantVariable, EnumVariable)
                    )
                    for x in arg.unpack_var_sequence(tx)
                )
            )
        )

    def call_callable(self, tx, arg):
        """Answer ``callable(arg)`` with a constant ``True`` for known callables.

        Only user-defined classes and user function variables are recognized;
        for anything else ``None`` is returned.
        """
        from .functions import BaseUserFunctionVariable

        callable_kinds = (variables.UserDefinedClassVariable, BaseUserFunctionVariable)
        if not isinstance(arg, callable_kinds):
            return None
        return variables.ConstantVariable(True).add_options(arg)

    @staticmethod
    def call_dict_helper(tx, user_cls, arg, **options):
        """Create a ``ConstDictVariable`` of class ``user_cls`` from ``arg``.

        ``arg`` may be ``None`` (empty dict), a plain ``dict``, an existing
        ``ConstDictVariable`` (cloned), or a list / tuple / list-iterator
        variable of key-value pairs. Anything else raises ``AssertionError``.
        """
        if arg is None or isinstance(arg, dict):
            initial = {} if arg is None else arg
            fresh = ConstDictVariable(initial, user_cls, mutable_local=MutableLocal())
            return fresh.add_options(options)

        if isinstance(arg, variables.ConstDictVariable):
            cloned = arg.clone(user_cls=user_cls, mutable_local=MutableLocal())
            return cloned.add_options(options)

        if not isinstance(arg, (ListVariable, TupleVariable, ListIteratorVariable)):
            raise AssertionError("call_dict_helper with illegal arg")

        items = user_cls()
        for pair_var in arg.unpack_var_sequence(tx):
            # Element 0 is the key (as a Python constant), element 1 the value.
            key = pair_var.unpack_var_sequence(tx)[0].as_python_constant()
            value = pair_var.unpack_var_sequence(tx)[1]
            items.update({key: value})
        built = ConstDictVariable(items, user_cls, mutable_local=MutableLocal())
        return built.add_options(options)

    def call_cast(self, _, *args, **kwargs):
        """Handle ``cast(a, b)``: with exactly two positional args, return the second.

        Any other argument count is reported through ``unimplemented``.
        """
        if len(args) == 2:
            _target_type, value = args
            return value

        unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}")

    def call_dict(self, tx, *args, **kwargs):
        """Trace ``dict(...)`` for the call shapes this handler supports.

        Supported shapes:
          * ``dict()`` -- empty dict,
          * ``dict(x)`` where ``is_supported_call_dict_arg`` accepts ``x``,
          * ``dict(k=v, ...)`` with keyword arguments only.
        Every other combination is reported through ``unimplemented``.
        """
        if kwargs:
            if not args:
                # keyword-only construction
                return variables.ConstDictVariable(
                    dict(kwargs), user_cls=dict, mutable_local=MutableLocal()
                )
        elif not args:
            return self.call_dict_helper(tx, dict, None)
        elif len(args) == 1 and self.is_supported_call_dict_arg(tx, args[0]):
            return self.call_dict_helper(tx, dict, args[0])

        unimplemented(f"dict(): {args} {kwargs}")

    def call_zip(self, tx, *args):
        """Trace ``zip(*args)`` into a ``TupleVariable`` of ``TupleVariable`` rows.

        Returns ``None`` when any argument cannot be unpacked as a sequence.
        """
        options = VariableTracker.propagate(self, args)
        if not all(x.has_unpack_var_sequence(tx) for x in args):
            return None

        columns = [arg.unpack_var_sequence(tx) for arg in args]
        rows = []
        for group in zip(*columns):
            rows.append(variables.TupleVariable(list(group), **options))
        return variables.TupleVariable(rows, **options)

    def call_enumerate(self, tx, *args):
        """Trace ``enumerate(seq[, start])`` into a tuple of ``(index, item)`` tuples.

        ``start`` must be a ``ConstantVariable`` when supplied. Returns
        ``None`` when the first argument cannot be unpacked as a sequence.
        """
        options = VariableTracker.propagate(self, args)

        first_index = 0
        if len(args) != 1:
            assert len(args) == 2
            start_var = args[1]
            assert isinstance(start_var, variables.ConstantVariable)
            first_index = start_var.as_python_constant()

        sequence = args[0]
        if not sequence.has_unpack_var_sequence(tx):
            return None

        pairs = []
        for idx, var in enumerate(sequence.unpack_var_sequence(tx), first_index):
            index_var = variables.ConstantVariable(idx, **options)
            pairs.append(variables.TupleVariable([index_var, var], **options))
        return variables.TupleVariable(pairs, **options)

    def call_len(self, tx, *args, **kwargs):
        """Delegate ``len(x)`` to the ``__len__`` method of the first argument."""
        target, remaining = args[0], args[1:]
        return target.call_method(tx, "__len__", remaining, kwargs)

    def call_getitem(self, tx, *args, **kwargs):
        """Delegate ``operator.getitem``-style calls to ``__getitem__`` of the first arg.

        When ``unspec_python_args`` reports true for the arguments, they are
        first passed through ``specialize_args_kwargs``.
        """
        needs_specialization = self.unspec_python_args(*args, **kwargs)
        if needs_specialization:
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
        container, index_args = args[0], args[1:]
        return container.call_method(tx, "__getitem__", index_args, kwargs)

    def call_isinstance(self, tx, arg, isinstance_type):
        """Fold ``isinstance(arg, isinstance_type)`` into a ``ConstantVariable``.

        ``isinstance_type`` must be convertible with ``as_python_constant()``.
        Tensors with a known dtype answer through their own
        ``call_isinstance``; user-defined objects honour an
        ``__instancecheck__`` defined on the type's class; everything else is
        decided by ``issubclass`` on ``arg.python_type()``.
        """
        arg_type = arg.python_type()
        target_type = isinstance_type.as_python_constant()

        if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
            return variables.ConstantVariable(arg.call_isinstance(target_type))

        if isinstance(arg, variables.UserDefinedObjectVariable):
            # UserDefinedObject with C extensions can have torch.Tensor
            # attributes, so break graph.
            if isinstance(arg.value, types.MemberDescriptorType):
                unimplemented(
                    f"isinstance called on UserDefinedClass {arg} {target_type}"
                )
            # handle __instancecheck__ defined in user class
            type_cls = target_type.__class__
            if "__instancecheck__" in type_cls.__dict__:
                return variables.ConstantVariable(
                    type_cls.__instancecheck__(target_type, arg.value)
                )

        try:
            result = issubclass(arg_type, target_type)
        except TypeError:
            # issubclass rejected the operands; fall back to an identity check
            result = arg_type is target_type
        return variables.ConstantVariable(result)

    def call_super(self, tx, a, b):
        """Trace ``super(a, b)`` into a ``SuperVariable``.

        A ``SuperSource`` is attached only when both operands carry a source.
        """
        if a.source is None or b.source is None:
            source = None
        else:
            source = SuperSource(type=a.source, base=b.source)
        return variables.SuperVariable(a, b, source=source)

    def call_next(self, tx, arg):
        """Trace ``next(arg)`` for list iterators and list-like variables.

        For a ``ListIteratorVariable`` the advanced iterator replaces ``arg``
        via ``tx.replace_all``. For any other ``BaseListVariable`` the first
        item is returned. Other argument kinds yield ``None``.
        """
        if isinstance(arg, variables.ListIteratorVariable):
            current, advanced = arg.next_variables()
            tx.replace_all(arg, advanced)
            return current
        if isinstance(arg, variables.BaseListVariable):
            head = arg.items[0]
            return head.add_options(self, arg)
        return None

    def call_hasattr(self, tx, obj, attr):
        """Delegate ``hasattr(obj, attr)`` to ``obj.call_hasattr`` for constant names.

        Returns ``None`` when ``attr`` is not a Python constant.
        """
        if not attr.is_python_constant():
            return None
        name = attr.as_python_constant()
        answer = obj.call_hasattr(tx, name)
        return answer.add_options(self, obj, attr)

    def call_map(self, tx, fn, seq):
        """Trace ``map(fn, seq)`` by calling ``fn`` on every unpacked element.

        The results are collected into a ``TupleVariable``. Returns ``None``
        when ``seq`` cannot be unpacked as a sequence.
        """
        if not seq.has_unpack_var_sequence(tx):
            return None
        mapped = []
        for element in seq.unpack_var_sequence(tx):
            mapped.append(fn.call_function(tx, [element], {}))
        return variables.TupleVariable(mapped).add_options(self, fn, seq)

    def call_sum(self, tx, seq, **kwargs):
        """Trace ``sum(seq, start=...)``.

        Args:
            tx: the current translator.
            seq: the variable being summed.
            **kwargs: only ``start`` is accepted; it is the initial value of
                the accumulation, matching the builtin ``sum``.

        Returns:
            A ``ConstantVariable`` when ``seq`` is a list/tuple of int/float
            constants and no ``start`` is given; otherwise the result of
            reducing the unpacked items with ``operator.add``. ``None`` when
            ``seq`` cannot be unpacked as a sequence.
        """
        # Special case for sum on tuple of floats and ints
        if (
            not kwargs
            and isinstance(seq, (variables.ListVariable, variables.TupleVariable))
            and all(
                isinstance(x, variables.ConstantVariable)
                and isinstance(x.value, (int, float))
                for x in seq.items
            )
        ):
            return variables.ConstantVariable(sum(x.value for x in seq.items))
        if seq.has_unpack_var_sequence(tx):
            start = kwargs.pop("start", None)
            assert not kwargs
            if start is None:
                start = variables.ConstantVariable(0)
            # `start` is the initial accumulator value, not an offset into the
            # sequence: every item takes part in the reduction.
            items = seq.unpack_var_sequence(tx)
            return BuiltinVariable(functools.reduce).call_function(
                tx,
                [
                    BuiltinVariable(operator.add),
                    variables.TupleVariable(items),
                    start.add_options(self, seq),
                ],
                {},
            )

    def call_reduce(self, tx, function, iterable, initializer=None):
        """Trace ``functools.reduce(function, iterable[, initializer])``.

        The unpacked items are folded left-to-right by calling ``function``
        on ``[accumulator, element]``. Without an initializer the first item
        seeds the accumulator. Returns ``None`` when ``iterable`` cannot be
        unpacked as a sequence.
        """
        if not iterable.has_unpack_var_sequence(tx):
            return None

        pending = iterable.unpack_var_sequence(tx)
        if initializer is None:
            accumulator = pending[0]
            pending = pending[1:]
        else:
            accumulator = initializer

        for element in pending:
            accumulator = function.call_function(tx, [accumulator, element], {})
        return accumulator

    def call_getattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, default=None
    ):
        """Trace ``getattr(obj, name[, default])``.

        Args:
            tx: the current translator.
            obj: variable whose attribute is read.
            name_var: variable holding the attribute name; it must be a
                Python constant, otherwise ``unimplemented`` is called.
            default: optional variable returned (with the collected guards)
                when ``call_hasattr`` reports the attribute is absent.

        Returns:
            A variable for the attribute. Depending on the kind of ``obj``
            this comes from a pending side effect, ``obj.var_getattr``, a
            ``VariableBuilder`` over the real member, or a lazy
            ``GetAttrVariable`` when ``var_getattr`` raises
            ``NotImplementedError``.
        """
        from . import (
            ConstantVariable,
            GetAttrVariable,
            PythonModuleVariable,
            TorchVariable,
            UserFunctionVariable,
        )
        from .builder import VariableBuilder

        options = VariableTracker.propagate(self, obj, name_var)
        guards = options["guards"]

        # Check constness *before* extracting the constant, so a dynamic name
        # is reported through `unimplemented` instead of reaching
        # `as_python_constant()` on a non-constant variable.
        if not name_var.is_python_constant():
            unimplemented("non-const getattr() name")
        name = name_var.as_python_constant()

        if tx.output.side_effects.is_attribute_mutation(obj):
            try:
                # re-read a pending side effect?
                return tx.output.side_effects.load_attr(obj, name).add_options(options)
            except KeyError:
                pass

        if default is not None:
            hasattr_var = self.call_hasattr(tx, obj, name_var)
            guards.update(hasattr_var.guards)
            assert hasattr_var.as_python_constant() in (True, False)
            if not hasattr_var.as_python_constant():
                return default.add_guards(guards)

        if obj.source:
            source = AttrSource(obj.source, name)
            options["source"] = source
        else:
            source = None

        if isinstance(obj, variables.NNModuleVariable):
            return obj.var_getattr(tx, name).add_options(options)
        elif isinstance(obj, variables.TensorVariable) and name == "grad":
            if source:
                # We are going to be raising this tensor as grapharg. So, ensure
                # that we have real grad value instead of fake tensor value.
                # Walk through the inputs of the subgraph and find if we already
                # have the original tensor stored in the graphargs.
                for grapharg in tx.output.graphargs:
                    if grapharg.source == source.base:
                        example_value = grapharg.example.grad
                        return VariableBuilder(tx, source)(example_value).add_options(
                            options
                        )
            # No source, or no matching grapharg: not supported.
            unimplemented("tensor grad")
        elif isinstance(
            obj,
            (
                variables.TensorVariable,
                variables.NamedTupleVariable,
                variables.ConstantVariable,
                variables.UserDefinedClassVariable,
                variables.UserDefinedObjectVariable,
            ),
        ):
            try:
                return (
                    obj.var_getattr(tx, name).clone(source=source).add_options(options)
                )
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)
        elif isinstance(obj, TorchVariable):
            member = getattr(obj.value, name)
            if is_utils_checkpoint(member):
                options["source"] = source
                return build_checkpoint_variable(**options)
            elif is_allowed(member):
                return TorchVariable(member, **options)
            elif ConstantVariable.is_literal(member):
                return ConstantVariable(member, **options)
            else:
                return VariableBuilder(tx, source)(member).add_guards(guards)
        elif isinstance(obj, (PythonModuleVariable, DummyModule)):
            # Read straight from the module dict rather than via getattr().
            member = obj.value.__dict__[name]

            if config.replay_record_enabled:
                tx.exec_recorder.record_module_access(obj.value, name, member)

            return VariableBuilder(tx, source)(member).add_guards(guards)
        elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
            return ConstantVariable(
                getattr(obj.fn, name), **VariableTracker.propagate(obj)
            )
        else:
            # Generic fallback: same strategy as the explicit type list above.
            try:
                return (
                    obj.var_getattr(tx, name).clone(source=source).add_options(options)
                )
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)

    def call_setattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
    ):
        """Trace ``setattr(obj, name, val)``.

        Dispatch, in order:
          * ``DataClassVariable`` / ``PlacementVariable``: delegate to the
            object's own ``__setattr__`` via ``call_method``.
          * objects tracked for attribute mutation, with a constant name:
            record the store in ``tx.output.side_effects`` and return ``val``.
          * other ``UserDefinedObjectVariable``: ``unimplemented``.
          * ``NNModuleVariable``: see the inline comments below.
        Any other kind of ``obj`` falls through without an explicit return.

        Raises:
            AttributeMutationError: when an ``NNModuleVariable`` is assigned
                to while ``tx.output.is_root_tracer()`` is false.
        """
        from .distributed import PlacementVariable

        if isinstance(obj, (variables.DataClassVariable, PlacementVariable)):
            return obj.call_method(tx, "__setattr__", [name_var, val], {})
        elif (
            tx.output.side_effects.is_attribute_mutation(obj)
            and name_var.is_python_constant()
        ):
            tx.output.side_effects.store_attr(obj, name_var.as_python_constant(), val)
            return val.add_options(self, obj, name_var)
        elif isinstance(obj, variables.UserDefinedObjectVariable):
            unimplemented(
                f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}"
            )
        elif isinstance(obj, variables.NNModuleVariable):
            if not tx.output.is_root_tracer():
                raise AttributeMutationError(
                    "Can't inplace modify module params/buffers inside HigherOrderOp"
                )
            if name_var.is_python_constant() and isinstance(
                val, variables.TensorVariable
            ):
                # Value being assigned, as seen by get_fake_value.
                assigning_fake_val = get_fake_value(val.as_proxy().node, tx)

                # The attribute may not exist yet; treat that as "nothing to compare".
                try:
                    getattr_var = obj.var_getattr(tx, name_var.as_python_constant())
                except AttributeError:
                    getattr_var = None

                if isinstance(getattr_var, variables.TensorVariable):
                    # get_fake_val will return a real tensor here because it's an attribute on the module (get_attr node)
                    existing_attr = get_fake_value(getattr_var.as_proxy().node, tx)
                    existing_fake_attr = (
                        variables.builder.wrap_to_fake_tensor_and_record(
                            existing_attr, tx, source=getattr_var.source, is_tensor=True
                        )
                    )

                    # same tensor identity, setattr is a no-op
                    # (only when the module type uses the stock nn.Module.__setattr__)
                    mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
                    if (
                        existing_fake_attr is assigning_fake_val
                        and mod_setattr is torch.nn.Module.__setattr__
                    ):
                        return getattr_var

            # Every other module assignment is handed to convert_to_unspecialized.
            obj.convert_to_unspecialized(tx)

    def call_delattr(self, tx, obj: VariableTracker, name_var: VariableTracker):
        """Model ``delattr(obj, name)`` as a ``setattr`` storing a ``DeletedVariable``."""
        deletion_marker = variables.DeletedVariable()
        return self.call_setattr(tx, obj, name_var, deletion_marker)

    def call_type(self, tx, obj: VariableTracker):
        """Trace ``type(obj)``.

        Exact ``TupleVariable`` instances produce a ``BuiltinVariable`` of
        their Python type. Other variables need both a known Python type and
        a source, from which the type is built with a ``TypeSource``.

        Raises:
            UserError: when neither of the above applies.
        """
        from .builder import VariableBuilder

        try:
            py_type = obj.python_type()
        except NotImplementedError:
            py_type = None

        if istype(obj, variables.TupleVariable):
            return BuiltinVariable(py_type).add_options(self, obj)

        if py_type is None or not obj.source:
            raise UserError(
                UserErrorType.ANTI_PATTERN,
                "Can't call type() on generated custom object. "
                "Please use __class__ instead",
            )

        builder = VariableBuilder(tx, TypeSource(obj.source))
        return builder(py_type).add_options(self, obj)

    def call_reversed(self, tx, obj: VariableTracker):
        """Trace ``reversed(obj)`` into a ``TupleVariable`` of the items in reverse.

        Returns ``None`` when ``obj`` cannot be unpacked as a sequence.
        """
        if not obj.has_unpack_var_sequence(tx):
            return None
        backwards = list(reversed(obj.unpack_var_sequence(tx)))
        options = VariableTracker.propagate(self, obj)
        return variables.TupleVariable(backwards, **options)

    def call_sorted(self, tx, obj: VariableTracker, **kwargs):
        """Trace ``sorted(obj, key=..., reverse=...)`` over constant items.

        Applies only when ``obj`` unpacks as a sequence, is not a tensor, and
        every item is a Python constant; otherwise returns ``None``. Items
        are ordered by their constant value, or by the constant result of
        calling ``key`` on them. The result is a ``ListVariable``.
        """
        sortable = (
            obj.has_unpack_var_sequence(tx)
            and not isinstance(obj, variables.TensorVariable)
            and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx))
        )
        if not sortable:
            return None

        function = kwargs.pop("key", None)
        reverse_var = kwargs.pop("reverse", ConstantVariable(False))
        reverse = reverse_var.as_python_constant()
        assert len(kwargs) == 0

        if function:

            def sort_key(x):
                return function.call_function(tx, [x], {}).as_python_constant()

        else:

            def sort_key(x):
                return x.as_python_constant()

        items = sorted(obj.unpack_var_sequence(tx), key=sort_key, reverse=reverse)
        return variables.ListVariable(items, **VariableTracker.propagate(self, obj))

    def call_chain(self, tx, *args):
        """Trace ``itertools.chain(*args)`` by concatenating the unpacked items.

        Returns ``None`` when any argument cannot be unpacked as a sequence.
        """
        if not all(obj.has_unpack_var_sequence(tx) for obj in args):
            return None
        flattened = [
            element for obj in args for element in obj.unpack_var_sequence(tx)
        ]
        options = VariableTracker.propagate(self, *args)
        return variables.TupleVariable(flattened, **options)

    def call_islice(self, tx, iterable, *args):
        """Trace ``itertools.islice(iterable, *args)`` with constant slice arguments.

        Returns ``None`` when ``iterable`` cannot be unpacked as a sequence
        or when any slice argument is not a Python constant.
        """
        if not iterable.has_unpack_var_sequence(tx):
            return None
        if not all(x.is_python_constant() for x in args):
            return None

        slice_params = [x.as_python_constant() for x in args]
        unpacked = iterable.unpack_var_sequence(tx)
        selected = list(itertools.islice(unpacked, *slice_params))
        options = VariableTracker.propagate(self, iterable, *args)
        return variables.TupleVariable(selected, **options)

    # neg is a constant fold function, so we only get here if constant fold is not valid
    def call_neg(self, tx, a):
        """Negate a ``SymNodeVariable`` by applying ``operator.neg`` to its proxy."""
        if not isinstance(a, SymNodeVariable):
            # None no-ops this handler and lets the driving function proceed
            return None
        negated_proxy = operator.neg(a.as_proxy())
        return SymNodeVariable.create(tx, negated_proxy, sym_num=None)

    def call_id(self, tx, *args):
        """Fold ``id(module)`` to a constant for ``NNModuleVariable`` arguments.

        The id is taken from the submodule looked up by ``module_key``. Any
        other argument is reported through ``unimplemented``.
        """
        if args and isinstance(args[0], variables.NNModuleVariable):
            module_var = args[0]
            submodule = tx.output.get_submodule(module_var.module_key)
            return variables.ConstantVariable(id(submodule))

        unimplemented(f"call_id with args {args}")

    def _comparison(self, tx, left, right):
        """
        Used to implement comparison operators for different types.
        For example, list1 < list2 is implemented differently from tensor1 < tensor2

        The operator applied is ``self.fn``. The branches below are tried in
        order and the first one matching the operand types decides the
        result; combinations that match none of them are reported through
        ``unimplemented``.
        """
        from . import (
            BaseListVariable,
            ConstantVariable,
            NNModuleVariable,
            TensorVariable,
            UserDefinedObjectVariable,
            UserFunctionVariable,
        )
        from .lists import SizeVariable
        from .tensor import (
            supported_const_comparison_ops,
            supported_tensor_comparison_ops,
        )

        op = self.fn

        def _unimplemented():
            unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")

        # Modules and constants: compare the real submodule / Python constant.
        if (
            all(
                isinstance(x, (NNModuleVariable, ConstantVariable))
                for x in [left, right]
            )
            and op in supported_const_comparison_ops.values()
        ):
            left = (
                tx.output.get_submodule(left.module_key)
                if isinstance(left, NNModuleVariable)
                else left.as_python_constant()
            )
            right = (
                tx.output.get_submodule(right.module_key)
                if isinstance(right, NNModuleVariable)
                else right.as_python_constant()
            )
            return ConstantVariable(op(left, right))

        # User functions: compare the underlying function objects.
        if isinstance(left, UserFunctionVariable):
            if op not in supported_const_comparison_ops.values():
                _unimplemented()
            if not isinstance(right, UserFunctionVariable):
                _unimplemented()
            return ConstantVariable(op(left.fn, right.fn))

        # Note, we have a rare BaseListVariable subtype mismatch with valid comparison
        # x = torch.randn([3, 3])
        # x.size() == (3, 3) # True
        # (3, 3) == x.size() # True
        if isinstance(left, (SizeVariable, TupleVariable)) and isinstance(
            right, (TupleVariable, SizeVariable)
        ):
            return BaseListVariable.list_compare(tx, op, left, right)

        if isinstance(left, BaseListVariable):
            if not type(left) == type(right):  # Mismatch in BaseListVariable subclasses
                _unimplemented()
            return BaseListVariable.list_compare(tx, op, left, right)

        # Sets: compare the underlying item collections directly.
        if isinstance(left, SetVariable):
            if not type(left) == type(right):  # Right operand is not the same set variable type
                _unimplemented()
            return ConstantVariable(op(left._underlying_items, right._underlying_items))

        # Tensors: apply the operator to the proxies.
        if isinstance(left, TensorVariable):
            from .builder import wrap_fx_proxy_cls

            if op not in supported_tensor_comparison_ops.values():
                _unimplemented()
            return wrap_fx_proxy_cls(
                type(left),  # handle Ndarrays and Tensors
                tx,
                op(left.as_proxy(), right.as_proxy()),
            )

        # Symbolic values on either side: apply the operator to the proxies.
        if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
            if op not in supported_tensor_comparison_ops.values():
                _unimplemented()

            return SymNodeVariable.create(
                tx,
                op(left.as_proxy(), right.as_proxy()),
                sym_num=None,
            )

        # Two constants that did not take the first branch above
        # (i.e. op is not in supported_const_comparison_ops).
        if isinstance(left, ConstantVariable) and isinstance(right, ConstantVariable):
            return ConstantVariable(op(left.value, right.value))

        # Two user-defined objects: apply the operator to the wrapped values.
        if isinstance(left, UserDefinedObjectVariable) and isinstance(
            right, UserDefinedObjectVariable
        ):
            return ConstantVariable(op(left.value, right.value))

        if op.__name__ == "is_":
            # If the two objects are of different type, we can safely return False
            if type(left) is not type(right):
                return ConstantVariable(False)

        _unimplemented()

    # and_ is a constant fold function, so we only get here if constant fold is not valid
    def call_and_(self, tx, a, b):
        """Emit ``operator.and_`` as a graph call for symbolic/constant operands.

        Applies when both operands are ``SymNodeVariable`` or
        ``ConstantVariable`` instances; the result is a ``SymNodeVariable``.
        """
        operand_kinds = (SymNodeVariable, ConstantVariable)
        if not (isinstance(a, operand_kinds) and isinstance(b, operand_kinds)):
            # None no-ops this handler and lets the driving function proceed
            return None

        proxy = tx.output.create_proxy(
            "call_function", operator.and_, *proxy_args_kwargs([a, b], {})
        )
        return SymNodeVariable.create(tx, proxy, sym_num=None)

    # or_ is a constant fold function, so we only get here if constant fold is not valid
    def call_or_(self, tx, a, b):
        """Emit ``operator.or_`` as a graph call for symbolic/constant operands.

        Applies when both operands are ``SymNodeVariable`` or
        ``ConstantVariable`` instances; the result is a ``SymNodeVariable``.
        """
        operand_kinds = (SymNodeVariable, ConstantVariable)
        if not (isinstance(a, operand_kinds) and isinstance(b, operand_kinds)):
            # None no-ops this handler and lets the driving function proceed
            return None

        proxy = tx.output.create_proxy(
            "call_function", operator.or_, *proxy_args_kwargs([a, b], {})
        )
        return SymNodeVariable.create(tx, proxy, sym_num=None)

    def call_not_(self, tx, a):
        """Handle ``operator.not_`` for symbolic values and lists.

        A ``SymNodeVariable`` becomes an ``operator.not_`` graph call; a
        ``ListVariable`` folds to whether it has no items. Any other operand
        yields ``None``.
        """
        if isinstance(a, SymNodeVariable):
            proxy = tx.output.create_proxy(
                "call_function", operator.not_, *proxy_args_kwargs([a], {})
            )
            return SymNodeVariable.create(tx, proxy, sym_num=None)

        if isinstance(a, ListVariable):
            is_empty = len(a.items) == 0
            return ConstantVariable(is_empty).add_options(self, a)

        return None

    # All rich-comparison and identity operators are served by `_comparison`,
    # which selects the operator from `self.fn`.
    call_eq = call_ne = _comparison
    call_lt = call_le = _comparison
    call_gt = call_ge = _comparison
    call_is_ = call_is_not = _comparison
