import functools
import itertools
import logging
from typing import Callable, Dict, List, Tuple, Union

import sympy
from sympy import Expr

from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._sympy.functions import FloorDiv, ModularIndexing

from .utils import sympy_subs, sympy_symbol, VarRanges
from .virtualized import V

# Module-level logger; used below for debug messages when a size expression
# cannot be statically simplified or converted to an integer hint.
log = logging.getLogger(__name__)


# This class is a little awkward, because ShapeEnv is doing most of the heavy
# lifting and in some cases we should be directly passing through to ShapeEnv,
# but there is some extra inductor logic that needs to be handled here
class SizeVarAllocator:
    """
    Inductor-side companion to a ``ShapeEnv``.

    Most symbolic reasoning is delegated to ``self.shape_env``.  On top of it
    this class provides: cached simplification of indexing expressions given
    loop-variable ranges, merging of loop dimensions, "statically known"
    queries that never install guards, guard/evaluate helpers that do, stride
    extraction from indexing expressions, and bookkeeping for host-precomputed
    size expressions (``ps<n>`` symbols).
    """

    def __init__(self, shape_env=None):
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        # These are the ShapeEnv's own dicts (shared, not copied), so entries
        # added to them later are visible here; the caches built below watch
        # len(self.replacements) to notice that and invalidate themselves.
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0).  Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = dict()
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = dict()
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        """Expand ``expr`` and apply the currently known symbol replacements."""
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results.

        Returns a ``simplify_with_ranges(expr, var_ranges)`` function whose
        cache is cleared whenever ``self.replacements`` changes size.
        """
        cache = dict()
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_loops_impl() can be expensive, cache its results.

        Returns a ``simplify_loops(index_vars, sizes, index_formulas)`` function
        whose cache is cleared whenever ``self.replacements`` changes size.
        """
        cache = dict()
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            # Keep the three argument lists as separate tuples: flattening them
            # into a single tuple would let calls whose arguments are split
            # differently share a key and return each other's results.
            key = (tuple(index_vars), tuple(sizes), tuple(index_formulas))
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges):
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.

        Repeatedly joins dimensions and rewrites ModularIndexing/FloorDiv
        terms until the expression stops changing.
        """

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if self.statically_known_leq(var_ranges[v], divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)
            base_pos = True
            if isinstance(base, ModularIndexing):
                # for modular indexing, biggest values from the ranges don't necessarily result in
                # the biggest result, the biggest result is modulus - 1
                base_s = base.args[2] - 1
            elif not base.has(ModularIndexing):
                # NOTE(review): the two probes below treat base as
                # non-decreasing in every iteration variable (lowest value at
                # all-zeros, highest at all size-1).
                # can't replace with indexing div if base can be negative
                iter_ranges_zero = {k: 0 for k in var_ranges}
                base_lowest = sympy_subs(base, iter_ranges_zero)
                base_pos = self.statically_known_leq(0, base_lowest)
                # actual iteration range is to size-1
                iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
                base_s = sympy_subs(base, iter_ranges)
            else:
                base_s = base
            if self.statically_known_lt(base_s, modulus * divisor) and base_pos:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                    sympy.Wild("modulus"),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                ),
                visit_indexing_div,
            )

        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(self, index_vars, sizes, index_formulas):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop

        Returns ``(new_sizes, reindex, prune)``: the surviving sizes, a function
        mapping an index over the new loops back to one over the original
        loops (removed dims get 0), and a function dropping removed dims from
        an original-length index list.
        """
        sizes = list(map(self.simplify, sizes))

        strides = [self.stride_vars(x, index_vars) for x in index_formulas]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        # remove size==1 dims (None marks a removed dim)
        sizes = [None if size == 1 else size for size in sizes]

        def can_merge_dims(a, b):
            """
            Return True if loop dim ``b`` can be folded into dim ``a``: in every
            index formula, advancing ``va`` by ``sizes[a]`` steps must be
            equivalent to advancing ``vb`` by one step.
            """
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    m1 = sympy_symbol("_merge_tester1")
                    m2 = sympy_symbol("_merge_tester2")
                    # Don't substitute vb=0 here: a va * vb term would vanish
                    # from both sides and the check would pass spuriously.
                    expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: m1 + m2})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by essentially asking a question, which was evaluated on the size-hinted values. If the condition was
    # true, we added a guard and returned True, otherwise, False.
    #
    # def maybe_guard_foo(args):
    #   if not size_hinted_check(args):
    #       return False # No guard, no optim
    #   guard(args) # Make a guard
    #   return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
    # return False.
    #
    # def statically_known_foo(args):
    #   if all_static(args) and check(args):
    #       return True # Safe to apply optimization
    #   else:
    #       return False # No guard, no optim

    # See Note - [On Statically Known]

    def is_expr_static_and_true(self, expr: Union[Expr, int]) -> bool:
        """
        Return True only if ``expr`` can be shown to be true without a guard.

        Literal booleans (Python or sympy) are answered directly.  Anything
        else goes to ``ShapeEnv._maybe_evaluate_static``; if that returns None
        or raises, the answer is False.
        """
        if expr in (True, False):
            # ``expr`` may be sympy.true/false or the int 0/1 here; normalize
            # so callers always receive a real bool.
            return bool(expr)

        try:
            simplified = self.shape_env._maybe_evaluate_static(expr)
            if simplified is not None:
                return bool(simplified)
        except Exception:
            log.debug("Could not simplify %s", expr)

        return False

    def statically_known_equals(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        """
        return self.is_expr_static_and_true(sympy.Eq(left, right))

    # See Note - [On Statically Known]
    def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        """
        return len(left) == len(right) and all(
            self.statically_known_equals(l, r) for l, r in zip(left, right)
        )

    # See Note - [On Statically Known]
    def statically_known_leq(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        """
        expr = left <= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_lt(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        """
        expr = left < right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_multiple_of(self, numerator: Expr, denominator: Expr) -> bool:
        """
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
        """
        expr = sympy.Eq(numerator % denominator, 0)
        return self.is_expr_static_and_true(expr)

    # The guard functions require you to ALREADY KNOW that a particular
    # condition holds.  If you don't know (you want to guard on an expression
    # being a particular value, and then get access to that value), use
    # the evaluate functions.

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        """
        Guard that ``left == right`` and return ``left``.

        Precomputed ``ps*`` symbols are expanded first, since they cannot be
        guarded upon.
        """
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)
        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        """Guard that ``left <= right`` (as ``left < right + 1``; sizes are integers)."""
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        """Guard that ``left < right``."""
        assert self.shape_env.evaluate_expr(sympy.Lt(left, right))

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(sympy.sympify(left))

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        lv = self.size_hint(left)
        rv = self.size_hint(right)
        if lv == rv:
            return self.guard_equals(left, right)
        elif lv < rv:
            self.guard_lt(left, right)
            return left
        else:
            self.guard_lt(right, left)
            return right

    def evaluate_static_shape(self, left: Expr) -> int:
        """Guard that ``left`` equals its current size hint and return that int."""
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: List[Expr]) -> List[int]:
        """Apply ``evaluate_static_shape`` to every element of ``left``."""
        return [self.evaluate_static_shape(x) for x in left]

    def symbolic_hint(self, expr: Expr) -> Expr:
        """
        Substitute size hints (``self.var_to_val``) into ``expr``.

        Precomputed ``ps*`` symbols are first expanded back into the
        expressions they stand for.  Plain ints are returned unchanged and
        symbol-free expressions are converted to int.
        """
        # Substitute all hints into expr, but leave unbacked symints alone
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            return int(expr)
        while any(s.name.startswith("ps") for s in free_symbols):
            expanded = sympy_subs(expr, self.inv_precomputed_replacements)
            if expanded == expr:
                # A "ps" symbol with no known inverse: substituting again
                # would make no progress and loop forever.
                break
            expr = expanded
            free_symbols = expr.free_symbols
        return sympy_subs(expr, self.var_to_val)

    def size_hint(self, expr: Expr) -> int:
        """Return the integer size hint of ``expr``; re-raises if it stays symbolic."""
        out = self.symbolic_hint(expr)
        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(self, exprs: List[Expr]) -> Tuple[int, ...]:
        """Return ``size_hint`` of every element of ``exprs`` as a tuple."""
        return tuple(self.size_hint(x) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        """
        Build a cached ``stride_vars(index, vars, support_vars=None)``.

        Lists are converted to tuples so they can be cache keys; a missing or
        empty ``support_vars`` defaults to ``vars``.
        """
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(
            index: Expr,
            vars: List[sympy.Symbol],
            support_vars: List[sympy.Symbol] = None,
        ) -> List[Expr]:
            if not support_vars:
                support_vars = vars
            return cache(index, tuple(vars), tuple(support_vars))

        return stride_vars

    def _stride_vars(
        self, index: Expr, vars: List[sympy.Symbol], support_vars: List[sympy.Symbol]
    ) -> List[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        Entries of ``vars`` equal to 0 get a stride of 0.
        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(
            index, {v: sympy.Integer(0) for v in support_vars if v != 0}
        )
        for i in range(len(vars)):
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    support_vars[j]: sympy.Integer(0)
                    for j in range(len(support_vars))
                    if vars[i] != support_vars[j] and support_vars[j] != 0
                },
            )
            v = vars[i]
            if v == 0:
                strides.append(sympy.Integer(0))
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.Integer(1)})
                    - sympy_subs(index_dim, {v: sympy.Integer(0)})
                )
        return strides

    def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})

    def stride_hints(
        self,
        index: Expr,
        vars: List[sympy.Symbol],
        support_vars: List[sympy.Symbol] = None,
    ) -> List[int]:
        """
        Return integer size hints of the strides of ``index`` w.r.t. ``vars``.

        Symbols named ``indirect*`` are replaced by 0 first; a stride whose
        hint cannot be converted to int (TypeError) is reported as 0.
        """
        for v in index.free_symbols:
            if v.name.startswith("indirect"):
                index = sympy_subs(index, {v: 0})
        result = []
        for s in self.stride_vars(index, vars, support_vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
        """
        Return dimension positions ordered by increasing absolute stride hint,
        with zero-stride dimensions placed last.
        """
        strides = tuple(map(abs, self.stride_hints(index, vars)))
        order = list(range(len(strides)))
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr):
        """Return the ``ps<n>`` symbol for ``expr``, allocating one on first use."""
        if expr not in self.precomputed_replacements:
            sym = sympy_symbol(f"ps{len(self.precomputed_replacements)}")
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]


def join_dimensions(expr: Expr) -> Expr:
    """
    Fuse adjacent ModularIndexing/FloorDiv "digits" of the same base in ``expr``.

    Only a ``sympy.Add`` that contains ``ModularIndexing`` can be rewritten;
    any other expression is returned unchanged without consulting the cache.
    """
    rewritable = isinstance(expr, sympy.Add) and expr.has(ModularIndexing)
    if rewritable:
        return _join_dimensions_cached(expr)
    return expr  # fast exit path


@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    Fuse one pair of terms of an ``Add`` that index adjacent "digits" of the
    same base, then recurse through ``join_dimensions`` for further pairs.

    Two rewrites are tried, in this order:

        ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
            becomes ModularIndexing(i0, 1, 128)

        ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
            becomes i0

    This type of pattern can come from view operations
    """
    assert isinstance(expr, sympy.Add)

    w_scale = sympy.Wild("scale", exclude=[0])
    w_base = sympy.Wild("base")
    w_div = sympy.Wild("divisor")
    w_mod = sympy.Wild("modulus")
    w_mod_outer = sympy.Wild("modulus2")
    inner_pattern = w_scale * ModularIndexing(w_base, w_div, w_mod)
    terms = expr.args

    # Rewrite 1: two ModularIndexing terms on consecutive digit ranges.
    for inner in terms:
        found = inner.match(inner_pattern)
        if not found:
            continue
        sc, b, d, m = found[w_scale], found[w_base], found[w_div], found[w_mod]
        outer_pattern = sc * m * ModularIndexing(b, d * m, w_mod_outer)
        for outer in terms:
            hit = outer.match(outer_pattern)
            if not hit or inner == outer:
                continue
            fused = sc * ModularIndexing(b, d, m * hit[w_mod_outer])
            return join_dimensions(expr - inner - outer + fused)

    # Rewrite 2: a ModularIndexing term followed by the FloorDiv of the rest.
    for inner in terms:
        found = inner.match(inner_pattern)
        if not found:
            continue
        sc, b, d, m = found[w_scale], found[w_base], found[w_div], found[w_mod]
        outer_pattern = sc * m * FloorDiv(b, d * m)
        for outer in terms:
            # The pattern has no Wilds, so success yields {}: compare to None.
            if outer.match(outer_pattern) is None:
                continue
            fused = sc * FloorDiv(b, d)
            return join_dimensions(expr - inner - outer + fused)

    return expr


class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
    """
    Ops handler wrapping ``inner``: the index passed to load, store,
    store_reduction and index_expr is first simplified with
    ``V.graph.sizevars.simplify_with_ranges`` using the given loop ranges,
    then the call is delegated to ``inner``.
    """

    def __init__(self, inner, var_ranges: VarRanges):
        super().__init__(inner)
        self.name = "SimplifyIndexing"

        def simplify_index(index: Expr) -> Expr:
            # V.graph is resolved on each call, not captured at construction.
            return V.graph.sizevars.simplify_with_ranges(index, var_ranges)

        self._simplify: Callable[[Expr], Expr] = simplify_index

    def load(self, name: str, index: sympy.Expr):
        """Delegate ``load`` with a simplified index."""
        simplified = self._simplify(index)
        return self._inner.load(name, simplified)

    def store(self, name, index, value, mode=None):
        """Delegate ``store`` with a simplified index."""
        simplified = self._simplify(index)
        return self._inner.store(name, simplified, value, mode=mode)

    def store_reduction(self, name, index, value):
        """Delegate ``store_reduction`` with a simplified index."""
        simplified = self._simplify(index)
        return self._inner.store_reduction(name, simplified, value)

    def index_expr(self, index, dtype):
        """Delegate ``index_expr`` with a simplified index."""
        simplified = self._simplify(index)
        return self._inner.index_expr(simplified, dtype)
