import contextlib
import functools
import hashlib
import os
import re
import textwrap
import sys
from argparse import Namespace
from dataclasses import (
    fields,
    is_dataclass,
)
from typing import (
    Tuple,
    List,
    Iterable,
    Iterator,
    Callable,
    Sequence,
    TypeVar,
    Optional,
    Dict,
    Any,
    Union,
    Set,
    NoReturn,
)
from enum import Enum

from torchgen.code_template import CodeTemplate

# Safely load the fast C YAML loader/dumper if they are available; otherwise
# fall back to the pure-Python safe implementations under the same names.
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]

try:
    from yaml import CSafeDumper as Dumper
except ImportError:
    from yaml import SafeDumper as Dumper  # type: ignore[misc]
# Alias for whichever safe dumper was imported above.
YamlDumper = Dumper

# A custom loader for YAML that errors on duplicate keys.
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
class YamlLoader(Loader):
    # Safe loader that rejects mappings containing the same key twice.

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        # First pass: build every key and fail on the first repeat.  A list is
        # used for the membership test so the check only relies on equality.
        seen_keys = []
        for key_node, _value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            is_new = key not in seen_keys
            assert (
                is_new
            ), f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
            seen_keys.append(key)
        # Second pass: let the base loader build the actual mapping.
        return super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]


# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
# for your use.
class Target(Enum):
    # Values are numbered from 1 in declaration order.

    # top level namespace (not including at)
    DEFINITION = 1
    DECLARATION = 2
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = 3
    # namespace { ... }
    ANONYMOUS_DEFINITION = 4
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = 5
    NAMESPACED_DECLARATION = 6

# Regex template for whole-identifier matching: matches "foo" in "foo, bar"
# but not in "foobar". Used to search for the occurrence of a parameter in the
# derivative formula. The `{}` placeholder is presumably filled with the
# identifier via str.format before compiling -- usage is not visible here.
IDENT_REGEX = r"(^|\W){}($|\W)"

# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a schema string like ``name.overload(a, b)`` into name and params.

    Returns the base name (any ``.overload`` suffix is dropped) and the text
    between the first ``(`` and the last ``)`` split on ``", "``.  An empty
    parameter list therefore yields ``[""]``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(...)``.
    """
    match = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if match is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    base_name = match.group(1)
    param_text = match.group(3)
    return base_name, param_text.split(", ")


T = TypeVar("T")
S = TypeVar("S")

# These two functions purposely return generators in analogy to map()
# so that you don't mix up when you need to list() them

# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, dropping ``None`` results.

    Only ``None`` is filtered; other falsy results (0, "", False) are kept.
    """
    for result in map(func, xs):
        if result is None:
            continue
        yield result


# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield every element of ``func(x)`` for each ``x`` in ``xs``, in order."""
    for item in xs:
        yield from func(item)


# Conveniently add error context to exceptions raised.  Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Append an indented context message to any Exception raised in the body.

    ``msg_fn`` is only called when an exception actually propagates.  Its
    result is indented by two spaces and appended (after a newline) to the
    first element of the exception's ``args``; the same exception object is
    then re-raised.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError
        note = textwrap.indent(msg_fn(), "  ")
        if err.args:
            first, *others = err.args
            err.args = (f"{first}\n{note}", *others)
        else:
            # Nothing to append to: the context message becomes the only arg.
            err.args = (note,)
        raise


# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise AssertionError naming the runtime type of ``x``."""
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")


# Load a template via CodeTemplate.from_file. Results are memoized without a
# size bound, keyed on the exact path string, so each distinct path is read
# at most once per process.
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)


# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    """Return the SHA-1 digest of ``s`` (latin-1 encoded) as a little-endian int.

    Raises UnicodeEncodeError for characters that latin-1 cannot encode.
    """
    hasher = hashlib.sha1()
    hasher.update(s.encode("latin1"))
    return int.from_bytes(hasher.digest(), "little")


# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    """Write generated files under ``install_dir`` and remember their names.

    Every path passed through ``write_with_template`` (directly, or via
    ``write`` / ``write_sharded``) is recorded in ``filenames`` even when
    ``dry_run`` is set; only the actual rendering and writing is skipped in a
    dry run.
    """

    # Directory that output paths are prefixed with ("{install_dir}/{filename}").
    install_dir: str
    # Directory that template names are resolved against.
    template_dir: str
    # When true, record output names but do not evaluate environments or render.
    dry_run: bool
    # Full output paths recorded so far.
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds it.

        A file that cannot be read (IOError, e.g. it does not exist) is treated
        as having no previous contents.  Missing parent directories are created.
        """
        old_contents: Optional[str]
        try:
            with open(filename, "r") as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist.  A bare filename has
            # an empty dirname and os.makedirs("") raises FileNotFoundError, so
            # only create directories when there is a directory component.
            dirname = os.path.dirname(filename)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Record ``filename`` and, unless in a dry run, generate it.

        Args:
            filename: output path relative to ``install_dir``.
            template_fn: template name relative to ``template_dir``; only read
                when the environment is a dict.
            env_callable: called (only when not in a dry run) to produce either
                the literal file contents (str) or a substitution environment
                (dict) for the template.

        Raises:
            AssertionError: if the same output path was already recorded.
        """
        filename = "{}/{}".format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            env = env_callable()
            if isinstance(env, dict):
                # TODO: Update the comment reference to the correct location
                # Note: this adds the key to the dict returned by env_callable.
                if "generated_comment" not in env:
                    # The marker is assembled from two pieces, presumably so
                    # this source file does not itself contain the full marker.
                    comment = "@" + "generated by torchgen/gen.py"
                    comment += " from {}".format(os.path.basename(template_fn))
                    env["generated_comment"] = comment
                template = _read_template(os.path.join(self.template_dir, template_fn))
                self._write_if_changed(filename, template.substitute(env))
            elif isinstance(env, str):
                self._write_if_changed(filename, env)
            else:
                assert_never(env)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Like ``write_with_template`` with the template named after the output."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        """Distribute ``items`` over ``num_shards`` output files plus one combined file.

        Each item goes to shard ``string_stable_hash(key_fn(item)) % num_shards``.
        The lists in ``env_callable(item)`` are appended to that shard's
        environment and to the combined "Everything" environment.  For
        ``filename`` "Foo.cpp" the outputs are "Foo_0.cpp", "Foo_1.cpp", ...
        and "FooEverything.cpp", all rendered from the template ``filename``.

        Args:
            filename: template name, also used to derive the output names.
            items: objects to distribute; ignored in a dry run.
            key_fn: maps an item to the string used to pick its shard.
            env_callable: maps an item to the lists to append; every key must
                be in ``sharded_keys``.
            num_shards: number of numbered shard files.
            base_env: entries copied into every environment; values under
                sharded keys must be lists (they are copied before appending).
            sharded_keys: the keys whose values are accumulated lists.
        """

        everything: Dict[str, Any] = {"shard_id": "Everything"}
        shards: List[Dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    # Copy so appending below never mutates base_env's list or
                    # a list shared between shards.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        # Split "name.ext" at the last dot so the shard id goes before the extension.
        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            # The lambda is invoked inside write_with_template before the loop
            # advances, so capturing the loop variable is safe here.
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script.

        The content has the form ``set(<variable_name>`` followed by one
        quoted, sorted filename per line and a closing ``)``.  ``filename`` is
        used as given (it is not prefixed with ``install_dir``) and is written
        even in a dry run.
        """
        content = "set({}\n    {})".format(
            variable_name,
            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)


# Helper function to generate file manager
def make_file_manager(
    options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
    """Build a FileManager from parsed command-line options.

    Templates are looked up in ``<options.source_path>/templates``.  A falsy
    ``install_dir`` (None or "") falls back to ``options.install_dir``.
    """
    templates_path = os.path.join(options.source_path, "templates")
    output_dir = install_dir or options.install_dir
    return FileManager(
        install_dir=output_dir,
        template_dir=templates_path,
        dry_run=options.dry_run,
    )


# Helper function to create a pretty representation for dataclasses
def dataclass_repr(
    obj: Any,
    indent: int = 0,
    width: int = 80,
) -> str:
    """Return a pretty, possibly multi-line, string representation of ``obj``.

    On Python 3.10 and later this delegates to ``pprint.pformat``; on older
    versions it uses this module's own ``_pformat``.
    """
    if sys.version_info < (3, 10):
        return _pformat(obj, indent=indent, width=width)

    # The built-in pprint module supports dataclasses from Python 3.10.
    from pprint import pformat

    return pformat(obj, indent, width)


def _pformat(
    obj: Any,
    indent: int,
    width: int,
    curr_indent: int = 0,
) -> str:
    """Format dataclass ``obj`` as ``Name(field=value, ...)``, one field per line.

    Continuation lines are indented so that field names line up under the
    first field.  Fields declared with ``repr=False`` are omitted.  Nested
    dataclasses, dicts and list/set/tuple values are formatted recursively;
    everything else uses ``repr``.

    Raises:
        AssertionError: if ``obj`` is not a dataclass.
    """
    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"

    cls_name = obj.__class__.__name__
    # Fields start right after "Name(", so shift the column by that much.
    curr_indent += len(cls_name) + 1

    rendered = []
    for fld in fields(obj):
        if not fld.repr:
            continue
        value = getattr(obj, fld.name)
        # Nested values start after "field=", which is the column that dict,
        # list, set and tuple formatting also builds on, as done in pprint.
        value_column = curr_indent + len(fld.name) + 1
        if is_dataclass(value):
            text = _pformat(value, indent, width, value_column)
        elif isinstance(value, dict):
            text = _format_dict(value, indent, width, value_column)
        elif isinstance(value, (list, set, tuple)):
            text = _format_list(value, indent, width, value_column)
        else:
            text = repr(value)
        rendered.append(f"{fld.name}={text}")

    separator = ",\n" + " " * curr_indent
    return f"{cls_name}({separator.join(rendered)})"


def _format_dict(
    attr: Dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Format a dict as ``{key: value, ...}`` using ``_format`` for layout.

    Keys use ``repr``; dataclass values are formatted with ``_pformat`` and
    all other values with ``repr``.
    """
    curr_indent += indent + 3
    entries = []
    for key, value in attr.items():
        key_text = repr(key)
        if is_dataclass(value):
            # Nested dataclass columns are offset by the width of the key.
            value_text = _pformat(value, indent, width, curr_indent + len(key_text))
        else:
            value_text = repr(value)
        entries.append(f"{key_text}: {value_text}")

    return _format(entries, indent, width, curr_indent, "{", "}")


def _format_list(
    attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Format a list, set or tuple using ``_format`` for layout.

    Dataclass elements are formatted with ``_pformat``; others with ``repr``.
    Lists are wrapped in ``[...]``; every other container is wrapped in
    ``(...)``.
    """
    curr_indent += indent + 1
    elements = []
    for element in attr:
        if is_dataclass(element):
            elements.append(_pformat(element, indent, width, curr_indent))
        else:
            elements.append(repr(element))

    # NOTE(review): sets are rendered with parentheses just like tuples --
    # confirm whether that is intended before relying on the output.
    if isinstance(attr, list):
        opener, closer = "[", "]"
    else:
        opener, closer = "(", ")"
    return _format(elements, indent, width, curr_indent, opener, closer)


def _format(
    fields_str: List[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    """Join pre-formatted elements between ``start`` and ``end``.

    Elements are separated by ", ".  If the ``repr`` of the element list is at
    least ``width`` characters long, each separator is followed by a newline
    and ``curr_indent`` spaces so that every element sits on its own line.
    ``indent`` spaces are inserted right after ``start``.
    """
    # If it exceeds the max width then we place one element per line.
    one_per_line = len(repr(fields_str)) >= width
    separator = ", "
    if one_per_line:
        separator += "\n" + " " * curr_indent

    leading_pad = " " * indent
    return start + leading_pad + separator.join(fields_str) + end
