from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._subclasses import FakeTensorMode
from torch.distributed._spmd.data_parallel import (
    DataParallelStyle,
    partition_data_parallel,
)
from torch.distributed._spmd.distribute import _convert_to_distributed, Schema
from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard

from torch.fx import GraphModule


class ParallelMode(ABC):
    """
    Interface that every parallelism pattern implements for the SPMD compiler.

    A mode answers two questions: how to partition a single-device graph into
    a distributed graph (:meth:`partition`), and how to transform/compile that
    distributed graph (:meth:`transform_and_compile`).
    """

    @abstractmethod
    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        """
        Turn the single-device graph ``gm`` into a distributed graph.

        Concrete modes must override this; the base implementation only
        raises ``NotImplementedError``.

        TODO(@wanchaol): not every argument here is needed for partitioning;
        drop the unnecessary ones later.
        """
        raise NotImplementedError

    @abstractmethod
    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Run this mode's graph transformation and optimization passes over the
        distributed graph ``gm``.

        The result is expected to be a compiled graph that is executable in
        the distributed environment. Concrete modes must override this; the
        base implementation only raises ``NotImplementedError``.
        """
        # TODO: add more necessary arguments to this interface.
        raise NotImplementedError


class DataParallel(ParallelMode):
    """
    Data parallel mode: partitions the graph via ``partition_data_parallel``
    into DDP / FSDP / ZeRO-3 style data parallelism.
    """

    def __init__(
        self,
        parallel_style: str = "replicate",
        *,
        input_batch_dim: int = 0,
        custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None,
    ):
        """
        Configure a data parallel mode.

        Three parallel styles are recognized: "replicate", "fully_shard" and
        "default", each mapped to the matching :class:`DataParallelStyle`
        member. See :class:`DataParallelStyle` for what each style means.

        Args:
            parallel_style (str): name of the parallel style to use.

        Keyword args:
            input_batch_dim (int): batch dimension of the input tensors,
                forwarded to ``partition_data_parallel``. default: 0
            custom_passes (Callable[[GraphModule], GraphModule], optional):
                callable that replaces the default graph transformation and
                optimization passes run by :meth:`transform_and_compile`.

        Raises:
            RuntimeError: if ``parallel_style`` is not one of the names above.
        """
        known_styles = (
            ("replicate", DataParallelStyle.REPLICATE),
            ("fully_shard", DataParallelStyle.FULLY_SHARD),
            ("default", DataParallelStyle.DEFAULT),
        )
        # Scan in declaration order; the first matching name wins.
        for style_name, style in known_styles:
            if parallel_style == style_name:
                self.parallel_style = style
                break
        else:
            raise RuntimeError(f"Unknown parallel style: {parallel_style}")

        # TODO: what if user passes in a incorrect `input_batch_dim`, how should we
        # detect that and do proper error handling?
        self.input_batch_dim = input_batch_dim

        # TODO: add a few default passes here; until then the fallback is identity.
        self._gm_passes: Callable[[GraphModule], GraphModule] = (
            custom_passes if custom_passes is not None else (lambda gm: gm)
        )

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        """
        Partition ``gm`` into a data parallel graph over a 1-D "cuda" device
        mesh spanning every rank in the default process group.
        """
        # TODO: figure out a way to avoid explicit "cuda" mesh.
        rank_ids = torch.arange(dist.get_world_size())
        device_mesh = DeviceMesh("cuda", rank_ids)

        return partition_data_parallel(
            gm,
            model,
            optimizer,
            params_and_buffers,
            named_states,
            args,
            kwargs,
            device_mesh,
            self.parallel_style,
            self.input_batch_dim,
        )

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Apply the configured graph passes (custom or identity) to ``gm``.
        """
        # TODO: add more necessary arguments to this interface.
        return self._gm_passes(gm)


class DTensorExpandMode(ParallelMode):
    """
    DTensor expand mode: replicates parameters, buffers and optimizer states
    and shards tensor inputs (``Shard(0)`` by default) to get DDP-like
    behavior. This is a transient mode until the new data parallel expansion
    replaces it.
    """

    def __init__(
        self, custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None
    ):
        # Placement overrides for individual inputs, keyed by ``id()`` of the
        # input tensor; consulted by ``partition`` before the default schema.
        self._placements_override: Dict[int, List[Placement]] = {}
        # TODO: add a few default passes here; until then the fallback is identity.
        self._gm_passes: Callable[[GraphModule], GraphModule] = (
            custom_passes if custom_passes is not None else (lambda gm: gm)
        )

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        """
        Expand ``gm`` with DTensor over a 1-D "cuda" mesh of all ranks.

        Inputs are laid out as: flattened params/buffers (replicated), then
        flattened optimizer states (replicated), then flattened positional
        and keyword args (``Shard(0)`` unless overridden). Every input is
        replaced by a fake tensor of the same shape before conversion.
        """
        flat_args = pytree.tree_flatten(list(args) + list(kwargs.values()))[0]

        rank_ids = torch.arange(dist.get_world_size()).cuda()
        device_mesh = DeviceMesh("cuda", rank_ids)
        # FIXME: allow other sharding schemas
        shard_schema: Schema = Schema(mesh=device_mesh, placements=[Shard(0)])
        replicate_schema: Schema = Schema(mesh=device_mesh, placements=[Replicate()])

        inputs: List[torch.Tensor] = []
        input_schemas: List[Schema] = []

        # Parameters and buffers must all be tensors and are replicated.
        for param in pytree.tree_flatten(params_and_buffers)[0]:
            assert isinstance(
                param, torch.Tensor
            ), f"expecting Tensor but got {type(param)}"
            inputs.append(param)
            input_schemas.append(replicate_schema)

        # Optimizer states are replicated; non-tensor entries get an empty
        # placeholder tensor so positions still line up with placeholders.
        for state in pytree.tree_flatten(named_states)[0]:
            inputs.append(state if isinstance(state, torch.Tensor) else torch.empty(0))
            input_schemas.append(replicate_schema)

        for arg in flat_args:
            if not isinstance(arg, torch.Tensor):
                # Non-tensor inputs are guaranteed unused in dispatcher graphs
                # produced by make_fx, but a dummy tensor and schema are still
                # needed so tensor inputs match with their placeholders.
                inputs.append(torch.empty(0))
                input_schemas.append(shard_schema)
                continue

            inputs.append(arg)
            arg_id = id(arg)
            input_schemas.append(
                Schema(mesh=device_mesh, placements=self._placements_override[arg_id])
                if arg_id in self._placements_override
                else shard_schema
            )

        with FakeTensorMode(allow_non_fake_inputs=True):
            fake_inputs = [torch.empty_like(real) for real in inputs]

        conversion = _convert_to_distributed(
            gm, fake_inputs, input_schemas, default_mesh=device_mesh, _allow_partial=False
        )
        return conversion[0]

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Apply the configured graph passes (custom or identity) to ``gm`` for
        the DTensor fallback parallel mode.
        """
        # TODO: move the transformation passes to this function
        return self._gm_passes(gm)
