import logging
import os
import tempfile
from enum import Enum
from typing import Callable, cast, Dict, Iterable, List, Optional, Set

import torch.fx as fx
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._pytree import tree_flatten, tree_unflatten


logger: logging.Logger = logging.getLogger("graph_utils")


class OP(str, Enum):
    CALL_FUNCTION = "call_function"
    CALL_MODULE = "call_module"
    CALL_METHOD = "call_method"
    GET_ATTR = "get_attr"
    OUTPUT = "output"
    PLACEHOLDER = "placeholder"


class CommType(str, Enum):
    ALLREDUCE = "allreduce_"
    ALLGATHER = "allgather_"
    BROADCAST = "broadcast_"
    REDUCESCATTER = "reduce_scatter_"
    SCATTER = "scatter_"


def get_node_tensor_metadata(
    node: fx.Node, is_required: bool = True
) -> Optional[TensorMetadata]:
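    """Return the ``tensor_meta`` stored in ``node.meta``.

    If ``is_required`` is True and the metadata is missing, a ``RuntimeError``
    is raised; otherwise None is returned when the metadata is absent.
    """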
    metadata = node.meta.get("tensor_meta", None)
    if is_required and metadata is None:
        raise RuntimeError(
            f"Callsite expects that ``tensor_meta`` exists in ``{node.name}``, "
            f"but got None instead. Node: {node.op} {node.name} {node.target}"
        )
    return metadata


def get_output(graph: fx.Graph) -> fx.Node:
    """
    Take a graph and return its output node. We traverse in reverse since the
    output node is expected to be the last node in the graph.
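
    Example (illustrative; assumes ``gm`` is an ``fx.GraphModule`` obtained
    elsewhere, e.g. via ``fx.symbolic_trace``)::

        output_node = get_output(gm.graph)
        assert output_node.op == OP.OUTPUT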
    """
    for node in reversed(graph.nodes):
        if node.op == OP.OUTPUT:
            return node
    raise RuntimeError(f"Cannot find the output node in {graph}")


def find_node(
    graph: fx.Graph, predicate: Callable, reverse_order: bool = False
) -> List[fx.Node]:
    """
    Take a predicate and return all the nodes in ``graph`` where the predicate
    holds. If ``reverse_order`` is True, the nodes are visited in reverse order.
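
    Example (illustrative; assumes ``gm`` is an ``fx.GraphModule``)::

        # Find all placeholder (graph input) nodes.
        placeholders = find_node(gm.graph, lambda n: n.op == OP.PLACEHOLDER)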
    """
    nodes = cast(Iterable[fx.Node], graph.nodes)
    if reverse_order:
        nodes = cast(Iterable[fx.Node], iter(reversed(nodes)))  # type: ignore[call-overload]
    return [node for node in nodes if predicate(node)]


def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
    """
    Check that every node in ``subgraph`` satisfies one of the following rules:
    1. The user of the node is in ``subgraph``.
    2. The user of the node is the output node.
    3. The node has no users -- it is a side-effect node.
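
    Example (illustrative; assumes ``gm`` is an ``fx.GraphModule`` and
    ``nodes`` is a list of nodes from ``gm.graph``)::

        if is_leaf_subgraph(gm.graph, nodes):
            # No node outside ``nodes`` (other than the output node)
            # depends on the subgraph, so it is safe to move or clone.
            ...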
    """
    all_nodes: Set[fx.Node] = set(subgraph)
    output = get_output(graph)
    for node in subgraph:
        for user in node.users:
            if not isinstance(user, fx.Node):
                continue
            if user not in all_nodes and user != output:
                return False
    return True


def clone_subgraph(
    graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
) -> List[fx.Node]:
    """
    Clone the given subgraph and insert it before ``target``. ``subgraph`` is
    expected to be topologically sorted and to contain only ``call_function``
    nodes. This API currently does not support inserting after ``target``.
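
    Example (illustrative; assumes ``gm`` is an ``fx.GraphModule`` and
    ``nodes`` is a topologically sorted list of ``call_function`` nodes)::

        target = get_output(gm.graph)
        cloned_nodes = clone_subgraph(gm.graph, nodes, target)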
    """

    all_nodes = set(subgraph)
    mapping: Dict[fx.Node, fx.Node] = dict()
    cloned_subgraph = []
    with graph.inserting_before(target):
        for node in subgraph:
            cloned_node = graph.call_function(
                node.target, node.args, node.kwargs, node.type
            )
            # TODO: there are many flatten/unflatten in IterGraph that
            # can be simplified with tree_map. Will simplify this in
            # a follow-up PR.
            original_input, _ = tree_flatten((node.args, node.kwargs))
            cloned_input, spec = tree_flatten((cloned_node.args, cloned_node.kwargs))
            mapped_cloned_input = []
            for original_input_node, cloned_input_node in zip(
                original_input, cloned_input
            ):
                if (
                    isinstance(original_input_node, fx.Node)
                    and original_input_node in all_nodes
                ):
                    # ``subgraph`` is assumed to be topologically sorted, so any
                    # subgraph-internal input has already been cloned.
                    assert original_input_node in mapping
                    mapped_cloned_input.append(mapping[original_input_node])
                else:
                    mapped_cloned_input.append(cloned_input_node)
            cloned_node.args, cloned_node.kwargs = tree_unflatten(
                mapped_cloned_input, spec
            )
            mapping[node] = cloned_node
            cloned_subgraph.append(cloned_node)

    return cloned_subgraph


def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
    """
    Run the required steps to ensure a production-ready graph.
    Note: per the fx docs, ``eliminate_dead_code`` is not very precise, hence
    the flag to make this step optional.
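
    Example (illustrative; ``gm`` is an ``fx.GraphModule`` whose ``gm.graph``
    was mutated by one of the helpers above)::

        rebuild_graph(gm)
        # Or, to keep nodes that ``eliminate_dead_code`` might drop:
        rebuild_graph(gm, remove_dead_code=False)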
    """

    gm.graph.lint()
    if remove_dead_code:
        gm.graph.eliminate_dead_code()
    gm.recompile()


def dump_graphs_to_files(graphs: Dict[str, fx.GraphModule], folder: str = "") -> str:
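    """
    Write the string representation of each ``fx.GraphModule`` in ``graphs``
    to ``<folder>/<prefix>.graph`` and return the folder used. If ``folder``
    is empty, a temporary directory is created.

    Example (illustrative; assumes ``gm`` is an ``fx.GraphModule``)::

        folder = dump_graphs_to_files({"before_transform": gm})
    """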
    if not folder:
        folder = tempfile.mkdtemp()

    for prefix, gm in graphs.items():
        with open(os.path.join(folder, f"{prefix}.graph"), "w") as fp:
            fp.write(str(gm))

    logger.warning("Dumped graphs to %s", folder)

    return folder
