import torch.distributed as dist

from torch._C._distributed_c10d import (
    _create_work_from_future,
    AllgatherOptions,
    AllreduceOptions,
    BarrierOptions,
    ReduceScatterOptions,
    BroadcastOptions,
    ScatterOptions,
    AllToAllOptions
)
from torch.futures import Future

from typing import List
from torch import Tensor


def ret_work(ret):
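    """Wrap a result in an already-completed Work so fake collectives can return immediately."""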
    fut = Future()
    fut.set_result(ret)
    return _create_work_from_future(fut)


class FakeProcessGroup(dist.ProcessGroup):
    """
    A fake process group (not related to FakeTensor) is a process group that
    doesn't actually do any communication; it just hallucinates communication.
    You can run a single rank with a fake process group without needing
    multiple processes (it simulates per-rank behavior).

    NOTE: This is not a real process group, and it produces wrong results for
    every collective. It should be used as a convenient tool when playing with
    distributed code without caring about the actual data.
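
    Example (a minimal sketch; assumes this module has been imported so that
    the ``fake`` backend registered at the bottom of this file is available
    and ``FakeStore`` is in scope)::

        import torch
        import torch.distributed as dist

        dist.init_process_group(
            backend="fake", rank=0, world_size=4, store=FakeStore()
        )
        t = torch.ones(2, 2)
        dist.all_reduce(t)  # no communication happens; t is left unchanged
        dist.destroy_process_group()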
    """
    def __init__(self, rank, world_size):
        super().__init__(rank, world_size)
        self._rank = rank
        self._world_size = world_size

    def allreduce(self, tensor_list, opts=AllreduceOptions()):
        return ret_work(tensor_list)

    def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()):
        # NOTE: in general it's not good form to try to make FakePG work with "real data",
        # but the reasoning here is that we want FakePG to work with DeviceMesh's init
        # code, which performs data validation, so the tradeoff is worth it.
        # In general, users should use MTPG or a normal PG when they care about the
        # real data produced by collectives.
        for chunk in output_tensors[0]:
            chunk.copy_(input_tensor[0])
        return ret_work(output_tensors)

    def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()):
        return ret_work(output_tensor)

    def _allgather_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
        # Assume each rank has the same input tensor, so we just copy it into each
        # chunk of the output. Since this is not a real allgather, the copying only
        # exists to let simple validation work (i.e. calling allgather to check
        # whether each rank has the same tensor or not).
        # NOTE: in general it's not good form to try to make FakePG work with "real data",
        # but the reasoning here is that we want FakePG to work with DeviceMesh's init
        # code, which performs data validation, so the tradeoff is worth it.
        # In general, users should use MTPG or a normal PG when they care about the
        # real data produced by collectives.
        chunks = output_tensor.chunk(self._world_size)
        for chunk in chunks:
            chunk.copy_(input_tensor)
        return ret_work(output_tensor)

    def _reduce_scatter_base(self, output_tensor, input_tensor, opts=ReduceScatterOptions()):
        return ret_work(output_tensor)

    def barrier(self, opts=BarrierOptions()):
        # Barrier is a no-op for the fake PG, but return an already-completed work
        # object so that callers (e.g. dist.barrier()) can still wait() on it.
        return ret_work(None)

    def broadcast(self, tensors: List[Tensor], opts=BroadcastOptions()):
        return ret_work(tensors)

    def scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts=ScatterOptions(),
    ):
        return ret_work(output_tensors)

    def alltoall(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[Tensor],
        opts=AllToAllOptions(),
    ):
        return ret_work(output_tensors)

    def alltoall_base(
        self,
        output_tensor: Tensor,
        input_tensor: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts=AllToAllOptions(),
    ):
        return ret_work(output_tensor)

    def send(
        self,
        tensors: List[Tensor],
        dstRank: int,
        tag: int,
    ):
        return ret_work(None)

    def recv(
        self,
        tensors: List[Tensor],
        srcRank: int,
        tag: int,
    ):
        return ret_work(tensors)

    def getBackendName(self):
        return "fake"

    def __repr__(self):
        return f"FakePG world_size:{self._world_size} rank:{self._rank}"


class FakeStore(dist.Store):
    """
    A fake store is a fake key-value store used only to satisfy the
    initialization of the fake process group; one can use either FakeStore
    or HashStore.
    """
    pass

def _create_fake_pg(prefix_store, rank, world_size, timeout):
    return FakeProcessGroup(rank, world_size)

dist.Backend.register_backend("fake", _create_fake_pg, devices=['cpu', 'cuda'])
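
# A minimal usage sketch of the "validation" use case mentioned in the allgather
# comments above (assumes this module has been imported so the registration above
# has run; only standard torch.distributed APIs are used):
#
#   import torch
#   import torch.distributed as dist
#
#   dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())
#   inp = torch.arange(4)
#   out = torch.empty(8, dtype=inp.dtype)
#   dist.all_gather_into_tensor(out, inp)
#   # FakePG copies the local input into every chunk, so `out` now holds
#   # world_size copies of `inp`, and simple single-process consistency checks
#   # (such as the ones DeviceMesh runs during init) still pass.
#   dist.destroy_process_group()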
