import collections
import dataclasses
import enum
from typing import Any, Optional, Union

from torch._guards import ChainedSource, GuardSource, Source

from . import utils
from .bytecode_transformation import create_call_function, create_instruction
from .utils import enum_repr

# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
# so those cases are omitted intentionally
#
# Maps a base guard source to the NN_MODULE guard source with the same
# LOCAL/GLOBAL scope.  Looked up by ParamBufferSource and NNModuleSource.
# FSDP guard sources are not keys here, so indexing with one raises KeyError.
_GUARD_SOURCE_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
}

# Maps a base guard source (plain, NN_MODULE or FSDP_MODULE) to the
# FSDP_MODULE guard source with the same LOCAL/GLOBAL scope.  Looked up by
# FSDPNNModuleSource.
_GUARD_SOURCE_FSDP_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}

# Maps a base guard source (plain, NN_MODULE or FSDP_MODULE) back to the
# plain LOCAL/GLOBAL guard source with the same scope.  Looked up by
# NotNNModuleSource.
_GUARD_SOURCE_NOT_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL,
    GuardSource.GLOBAL: GuardSource.GLOBAL,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL,
}


def is_constant_source(source):
    """Return True when ``source`` denotes a constant.

    A source qualifies when it is a ``ConstantSource`` instance, or when its
    ``guard_source()`` compares equal to ``GuardSource.CONSTANT``.  A source
    whose ``guard_source()`` raises ``NotImplementedError`` is reported as
    non-constant rather than propagating the error.
    """
    if not isinstance(source, ConstantSource):
        try:
            return bool(source.guard_source() == GuardSource.CONSTANT)
        except NotImplementedError:
            # No guard source available: treat as "not a constant".
            return False
    return True


def is_input_source(source):
    """Return True when ``source.guard_source()`` is a LOCAL or GLOBAL kind.

    The plain, NN_MODULE and FSDP_MODULE flavours of LOCAL and GLOBAL all
    count; every other guard source yields False.
    """
    kind = source.guard_source()
    input_kinds = (
        GuardSource.LOCAL,
        GuardSource.GLOBAL,
        GuardSource.LOCAL_NN_MODULE,
        GuardSource.GLOBAL_NN_MODULE,
        GuardSource.LOCAL_FSDP_MODULE,
        GuardSource.GLOBAL_FSDP_MODULE,
    )
    return kind in input_kinds


def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    """Build the instructions that load ``source.base`` followed by its index.

    The list returned by ``source.base.reconstruct(codegen)`` is extended in
    place and returned.  The index is appended as:

    * the index's own reconstruction, when the index is itself a ``Source``;
    * a constant load of the rebuilt slice, when ``index_is_slice`` is true
      (only valid for ``GetItemSource``);
    * a constant load of ``source.index`` otherwise.

    The subscript/call instruction itself is left to the caller.
    """
    instructions = source.base.reconstruct(codegen)
    index = source.index

    if isinstance(index, Source):
        instructions.extend(index.reconstruct(codegen))
        return instructions

    if index_is_slice:
        # Slices are stored in reduced (hashable) form and must be rebuilt.
        assert isinstance(source, GetItemSource)
        index = source.unpack_slice()
    instructions.append(codegen.create_load_const(index))
    return instructions


@dataclasses.dataclass(frozen=True)
class LocalSource(Source):
    local_name: str
    cell_or_freevar: bool = False

    def reconstruct(self, codegen):
        return [codegen.create_load(self.local_name)]

    def guard_source(self):
        return GuardSource.LOCAL

    def name(self):
        return f"L[{repr(self.local_name)}]"


@dataclasses.dataclass(frozen=True)
class RandomValueSource(Source):
    random_call_index: int

    def guard_source(self):
        return GuardSource.RANDOM_VALUE

    def reconstruct(self, codegen):
        return [
            codegen.create_load(codegen.tx.output.random_values_var),
            codegen.create_load_const(self.random_call_index),
            create_instruction("BINARY_SUBSCR"),
        ]

    def name(self):
        return f"random_value_{self.random_call_index}"


@dataclasses.dataclass(frozen=True)
class GeneratorStateSource(Source):
    device: str
    initial_seed: int

    def guard_source(self):
        return GuardSource.RANDOM_VALUE

    def reconstruct(self, codegen):
        # generator state is a torch.ByteTensor, so we reuse TensorVariable reconstruction in codegen.py
        raise NotImplementedError()

    def name(self):
        name = f"generator_state_{self.device}_{self.initial_seed}"
        return f"L[{name}]"


@dataclasses.dataclass(frozen=True)
class GlobalSource(Source):
    global_name: str

    def reconstruct(self, codegen):
        return [codegen.create_load_global(self.global_name, False, add=True)]

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        return f"G[{repr(self.global_name)}]"


@dataclasses.dataclass(frozen=True)
class DummyGlobalSource(Source):
    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        return ""


@dataclasses.dataclass(frozen=True)
class GlobalWeakRefSource(Source):
    global_name: str

    def reconstruct(self, codegen):
        return [
            codegen.create_load_global(self.global_name, True, add=True),
            *create_call_function(0, False),
        ]

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        return f"G[{repr(self.global_name)}]()"


@dataclasses.dataclass(frozen=True)
class AttrSource(ChainedSource):
    member: str

    def __post_init__(self):
        assert self.base, "Can't construct an AttrSource without a valid base source"
        if "." in self.member:
            member_parts = self.member.split(".")
            object.__setattr__(
                self, "base", AttrSource(self.base, ".".join(member_parts[:-1]))
            )
            object.__setattr__(self, "member", member_parts[-1])

    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        if not self.member.isidentifier():
            return f"getattr({self.base.name()}, {self.member!r})"
        return f"{self.base.name()}.{self.member}"


@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """Attribute source whose guard source is the NN_MODULE variant of its base."""

    def guard_source(self):
        """Translate the base's guard source through ``_GUARD_SOURCE_NN_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]


class TensorProperty(enum.Enum):
    """Tensor properties that a ``TensorPropertySource`` can refer to."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        """Return the name of the tensor method that yields this property."""
        method_by_property = {
            TensorProperty.SIZE: "size",
            TensorProperty.STRIDE: "stride",
            TensorProperty.STORAGE_OFFSET: "storage_offset",
        }
        return method_by_property.get(self)


@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """Source for a size/stride entry or the storage offset of ``base``."""

    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        # SIZE and STRIDE need an index; STORAGE_OFFSET must not have one.
        if self.prop is not TensorProperty.STORAGE_OFFSET:
            assert self.idx is not None
        else:
            assert self.idx is None

    def name(self):
        """Return the guard expression for the selected tensor property."""
        prop = self.prop
        if prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        if prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        raise AssertionError(f"unhandled {self.prop}")

    def guard_source(self):
        """Delegate to the base source."""
        return self.base.guard_source()

    def reconstruct(self, codegen):
        """Emit a call of the property method on ``base``.

        ``idx`` is passed as the single argument when it is set; otherwise
        the method is called without arguments.
        """
        has_idx = self.idx is not None
        instrs = list(self.base.reconstruct(codegen))
        instrs.append(codegen.create_load_attr(self.prop.method_name()))
        if has_idx:
            instrs.append(codegen.create_load_const(self.idx))
        instrs.extend(create_call_function(1 if has_idx else 0, True))
        return instrs


@dataclasses.dataclass(frozen=True)
class NegateSource(ChainedSource):
    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        # NB: use method call so that function stripping regexes work
        return f"{self.base.name()}.__neg__()"


@dataclasses.dataclass(frozen=True)
class DefaultsSource(ChainedSource):
    idx_key: Union[int, str]
    is_kw: bool = False
    field: str = dataclasses.field(init=False, repr=False, compare=False)
    _name: str = dataclasses.field(init=False, repr=False, compare=False)

    def __post_init__(self):
        assert (
            self.base
        ), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
        if self.is_kw:
            assert isinstance(self.idx_key, str)
            object.__setattr__(self, "field", "__kwdefaults__")
            object.__setattr__(
                self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
            )
        else:
            assert isinstance(self.idx_key, int)
            object.__setattr__(self, "field", "__defaults__")
            object.__setattr__(
                self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
            )

    def reconstruct(self, codegen):
        instrs = self.base.reconstruct(codegen)
        instrs.extend(codegen.create_load_attrs(self.field))
        instrs.extend(
            [
                codegen.create_load_const(self.idx_key),
                create_instruction("BINARY_SUBSCR"),
            ]
        )
        return instrs

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return self._name


@dataclasses.dataclass(frozen=True)
class GetItemSource(ChainedSource):
    index: Any
    index_is_slice: bool = False

    def __post_init__(self):
        assert self.base is not None
        if isinstance(self.index, slice):
            # store the hashable version of the slice so the whole GetItemSource is hashable
            super().__setattr__("index", self.index.__reduce__())
            super().__setattr__("index_is_slice", True)

    def reconstruct(self, codegen):
        return [
            *reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice),
            create_instruction("BINARY_SUBSCR"),
        ]

    def guard_source(self):
        return self.base.guard_source()

    def unpack_slice(self):
        assert self.index_is_slice
        slice_class, slice_args = self.index
        return slice_class(*slice_args)

    def name(self):
        if isinstance(self.index, Source):
            return f"{self.base.name()}[{self.index.name()}]"
        else:
            if self.index_is_slice:
                return f"{self.base.name()}[{self.unpack_slice()!r}]"
            elif isinstance(self.index, enum.Enum):
                return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
            else:
                return f"{self.base.name()}[{self.index!r}]"


@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """Item source that goes through the ``tuple_iterator_getitem`` helper
    instead of a plain subscript."""

    def name(self):
        """Return ``___tuple_iterator_getitem(<base>, <repr of index>)``."""
        container = self.base.name()
        return f"___tuple_iterator_getitem({container}, {self.index!r})"

    def reconstruct(self, codegen):
        """Emit a two-argument call passing the base and the constant index."""
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        instrs = list(self.base.reconstruct(codegen))
        instrs.append(codegen.create_load_const(self.index))
        instrs.extend(create_call_function(2, True))
        return instrs


@dataclasses.dataclass(frozen=True)
class TypeSource(ChainedSource):
    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        codegen.load_import_from("builtins", "type")
        return self.base.reconstruct(codegen) + create_call_function(1, True)

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return f"type({self.base.name()})"


# NB - SuperSource is a weird one.
# It is our only source with 2 bases, so we use the object
# as the base, rather than the type. Since an invocation
# like super(Foo, foo) is represented here, the source's base is more spiritually
# aligned with the instance than with the type.
# This whole construction is questionable though, and we should probably find a way to
# avoid this exception to our otherwise nice source parentage invariant.
@dataclasses.dataclass(frozen=True)
class SuperSource(ChainedSource):
    type: Source

    def __post_init__(self):
        assert self.type is not None
        assert self.base is not None

    def reconstruct(self, codegen):
        codegen.load_import_from("builtins", "super")
        return (
            self.type.reconstruct(codegen)
            + self.base.reconstruct(codegen)
            + create_call_function(2, True)
        )

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return f"super({self.type.name()}, {self.base.name()})"


@dataclasses.dataclass(frozen=True)
class ODictGetItemSource(ChainedSource):
    index: Any

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        return [
            codegen._create_load_const(collections.OrderedDict.__getitem__),
            *reconstruct_getitem(self, codegen, index_is_slice=False),
            *create_call_function(2, True),
        ]

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        if isinstance(self.index, type):
            rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
            return f"___odict_getitem({self.base.name()}, {rep})"
        elif isinstance(self.index, Source):
            return f"___odict_getitem({self.base.name()}, {self.index.name()})"
        else:
            return f"___odict_getitem({self.base.name()}, {self.index!r})"


@dataclasses.dataclass(frozen=True)
class NNModuleSource(ChainedSource):
    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen)

    def guard_source(self):
        return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]

    def name(self):
        return self.base.name()


@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
    """Wrapper that reports the plain LOCAL/GLOBAL variant of the base's
    guard source."""

    def guard_source(self):
        """Translate the base's guard source through ``_GUARD_SOURCE_NOT_NN_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[base_kind]


@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """Wrapper that reports the FSDP_MODULE variant of the base's guard source."""

    def guard_source(self):
        """Translate the base's guard source through ``_GUARD_SOURCE_FSDP_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_FSDP_MODULE[base_kind]


@dataclasses.dataclass(frozen=True)
class GlobalStateSource(Source):
    def name(self):
        return ""

    def guard_source(self):
        return GuardSource.GLOBAL


@dataclasses.dataclass(frozen=True)
class ConstantSource(Source):
    source_name: str

    def reconstruct(self, codegen):
        return [codegen.create_load_global(self.source_name, False, add=False)]

    def guard_source(self):
        return GuardSource.CONSTANT

    def name(self):
        return self.source_name

    def make_guard(self, fn, is_volatile=False):
        raise NotImplementedError()


@dataclasses.dataclass(frozen=True)
class NumpyTensorSource(ChainedSource):
    def name(self) -> str:
        return f"__as_tensor({self.base.name()})"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        codegen.load_import_from("torch", "as_tensor")
        return self.base.reconstruct(codegen) + create_call_function(1, True)


# This is a synthetic source that is associated with the singleton
# shape env guard we always register for all frames.  We get the actual
# guard contents from the ambient ShapeEnv
@dataclasses.dataclass(frozen=True)
class ShapeEnvSource(Source):
    def name(self):
        return ""

    def guard_source(self):
        return GuardSource.SHAPE_ENV


def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    """Return True when the root of ``source``'s base chain is a ``LocalSource``.

    ``ChainedSource`` wrappers are followed through ``.base`` until a
    non-chained source is reached.  When ``allow_cell_or_freevar`` is false,
    a root ``LocalSource`` flagged ``cell_or_freevar`` is rejected as well.
    """
    root = source
    while isinstance(root, ChainedSource):
        root = root.base
    if not isinstance(root, LocalSource):
        return False
    rejected = not allow_cell_or_freevar and root.cell_or_freevar
    return not rejected
