import warnings

from typing import List, Optional, Tuple

import torch
from torch import _VF, Tensor  # noqa: F401
from torch.nn.utils.rnn import PackedSequence


class QuantizedLinear(torch.jit.ScriptModule):
    """Deprecated TorchScript linear layer with FBGEMM int8-quantized weights.

    Constructed from an existing linear module ``other`` exposing
    ``in_features``, ``out_features``, ``weight`` and a non-``None`` ``bias``.
    The weight is quantized with ``torch.fbgemm_linear_quantize_weight`` and
    packed with ``torch.fbgemm_pack_quantized_matrix``. ``forward`` runs
    ``torch.fbgemm_linear_int8_weight_fp32_activation`` on ``input.float()``
    and casts the result back to ``input.dtype``.

    Prefer ``torch.ao.nn.quantized.dynamic.Linear`` for new code.
    """

    # TorchScript treats these attributes as module constants.
    __constants__ = ["scale", "zero_point"]

    def __init__(self, other):
        super().__init__()
        warnings.warn(
            "torch.jit.QuantizedLinear is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead."
        )

        self.in_features = other.in_features
        self.out_features = other.out_features
        # Quantize weight and discard the original
        (
            self.weight,
            self.col_offsets,
            self.scale,
            self.zero_point,
        ) = torch.fbgemm_linear_quantize_weight(
            other.weight.clone(memory_format=torch.contiguous_format).float()
        )
        # Keep the quantized weight and column offsets as frozen parameters.
        self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
        self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
        assert other.bias is not None, "QuantizedLinear requires a bias"
        self.bias = torch.nn.Parameter(
            other.bias.clone(memory_format=torch.contiguous_format).float(),
            requires_grad=False,
        )

        # Packed FBGEMM form of the quantized weight, consumed by forward().
        self.register_buffer(
            "packed_tensor_ptr",
            torch.fbgemm_pack_quantized_matrix(
                self.weight.clone(memory_format=torch.contiguous_format)
            ),
        )

    # Rebuild the packed buffer in place from the quantized weight.
    @torch.jit.script_method
    def _unpack(self):
        self.packed_tensor_ptr.set_(torch.fbgemm_pack_quantized_matrix(self.weight))

    # Replace the packed buffer in place with an empty uint8 tensor.
    # NOTE(review): presumably used around serialization so the packed data
    # is not saved; confirm with callers.
    @torch.jit.script_method
    def _pack(self):
        self.packed_tensor_ptr.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
        )

    @torch.jit.script_method
    def forward(self, input):
        out = torch.fbgemm_linear_int8_weight_fp32_activation(
            input.float(),
            self.weight,
            self.packed_tensor_ptr,
            self.col_offsets,
            self.scale,
            self.zero_point,
            self.bias,
        )
        return out.to(input.dtype)

    def extra_repr(self):
        """Return ``in_features``, ``out_features``, ``scale`` and ``zero_point``."""
        # Use attribute access rather than ``self.__dict__``. The legacy
        # ScriptModule constructor can move non-constant attributes such as
        # ``in_features`` onto the compiled module, so they may no longer be
        # in the instance ``__dict__``. Attribute access works either way.
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"scale={self.scale}, zero_point={self.zero_point}"
        )


# Deprecated linear layer with FBGEMM-packed FP16 weights and FP32 activations
class QuantizedLinearFP16(torch.jit.ScriptModule):
    """Deprecated TorchScript linear layer with FBGEMM-packed FP16 weights.

    Constructed from an existing linear module ``other`` exposing
    ``in_features``, ``out_features``, ``weight`` and a non-``None`` ``bias``.
    The weight is packed with ``torch.fbgemm_pack_gemm_matrix_fp16``.
    ``forward`` runs ``torch.fbgemm_linear_fp16_weight_fp32_activation`` on
    ``input.float()`` and returns that result without casting it back to
    ``input.dtype``.

    Prefer ``torch.ao.nn.quantized.dynamic.Linear`` for new code.
    """

    def __init__(self, other):
        super().__init__()
        warnings.warn(
            "torch.jit.QuantizedLinearFP16 is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead."
        )
        self.in_features = other.in_features
        self.out_features = other.out_features
        # Keep the unpacked weight so _unpack() can re-pack it later.
        self.original_weight = other.weight
        self.weight = torch.fbgemm_pack_gemm_matrix_fp16(
            other.weight.clone(memory_format=torch.contiguous_format).float()
        )
        assert other.bias is not None, "QuantizedLinearFP16 requires a bias"
        self.bias = torch.nn.Parameter(
            other.bias.clone(memory_format=torch.contiguous_format).float(),
            requires_grad=False,
        )
        # The same packed tensor is also registered as a buffer for forward().
        self.register_buffer("packed_weight", self.weight)

    # Re-pack ``original_weight`` into the packed buffer in place.
    @torch.jit.script_method
    def _unpack(self):
        self.packed_weight.set_(
            torch.fbgemm_pack_gemm_matrix_fp16(self.original_weight)
        )

    # Replace the packed buffer in place with an empty uint8 tensor.
    # NOTE(review): presumably used around serialization; confirm with callers.
    @torch.jit.script_method
    def _pack(self):
        self.packed_weight.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
        )

    @torch.jit.script_method
    def forward(self, input):
        out = torch.fbgemm_linear_fp16_weight_fp32_activation(
            input.float(), self.packed_weight, self.bias
        )
        return out

    def extra_repr(self):
        """Return ``in_features`` and ``out_features`` (no trailing separator)."""
        # Attribute access works even when the legacy ScriptModule constructor
        # has moved these non-constant attributes out of the instance __dict__.
        return f"in_features={self.in_features}, out_features={self.out_features}"


# Quantized RNN cell implementations
class QuantizedRNNCellBase(torch.jit.ScriptModule):
    """Deprecated base class for TorchScript RNN cells with FBGEMM int8 weights.

    Constructed from an existing RNN cell ``other``. The constructor:

    * copies ``input_size``, ``hidden_size`` and ``bias``;
    * quantizes ``weight_ih`` and ``weight_hh`` with
      ``torch.fbgemm_linear_quantize_weight``, registering the quantized
      weights and column offsets as buffers;
    * packs both weights with ``torch.fbgemm_pack_quantized_matrix`` into the
      ``packed_ih`` / ``packed_hh`` buffers;
    * keeps float copies of ``bias_ih`` / ``bias_hh`` as non-trainable
      parameters.

    Subclasses implement ``forward``.

    Raises:
        ValueError: if ``other.bias`` is falsy.
    """

    # Attributes TorchScript treats as module constants.
    __constants__ = [
        "input_size",
        "hidden_size",
        "bias",
        "scale_hh",
        "scale_ih",
        "zero_point_ih",
        "zero_point_hh",
    ]

    def __init__(self, other):
        super().__init__()
        warnings.warn(
            "torch.jit.QuantizedRNNCellBase is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
        )

        self.input_size = other.input_size
        self.hidden_size = other.hidden_size
        self.bias = other.bias
        # bias_ih / bias_hh are read unconditionally below, so bias is mandatory.
        if not self.bias:
            raise ValueError("Quantized RNN cells require bias terms")

        # Quantize the input-hidden weight; scale/zero point become constants.
        (
            weight_ih,
            col_offsets_ih,
            self.scale_ih,
            self.zero_point_ih,
        ) = torch.fbgemm_linear_quantize_weight(
            other.weight_ih.clone(memory_format=torch.contiguous_format).float()
        )
        self.register_buffer("weight_ih", weight_ih)
        self.register_buffer("col_offsets_ih", col_offsets_ih)
        # Same for the hidden-hidden weight.
        (
            weight_hh,
            col_offsets_hh,
            self.scale_hh,
            self.zero_point_hh,
        ) = torch.fbgemm_linear_quantize_weight(
            other.weight_hh.clone(memory_format=torch.contiguous_format).float()
        )
        self.register_buffer("weight_hh", weight_hh)
        self.register_buffer("col_offsets_hh", col_offsets_hh)

        # Packed FBGEMM forms of both quantized weights.
        packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih)
        self.register_buffer("packed_ih", packed_ih)
        packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh)
        self.register_buffer("packed_hh", packed_hh)

        self.bias_ih = torch.nn.Parameter(
            other.bias_ih.clone(memory_format=torch.contiguous_format).float(),
            requires_grad=False,
        )
        self.bias_hh = torch.nn.Parameter(
            other.bias_hh.clone(memory_format=torch.contiguous_format).float(),
            requires_grad=False,
        )

    def extra_repr(self):
        """Summarize sizes; ``bias``/``nonlinearity`` only when non-default."""
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        # ``nonlinearity`` only exists on subclasses that define it.
        if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
            s += ", nonlinearity={nonlinearity}"
        return s.format(**self.__dict__)

    # Raise RuntimeError if input.size(1) differs from input_size.
    @torch.jit.script_method
    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}"
            )

    # Raise RuntimeError if hx's batch size (dim 0) differs from input's, or
    # hx.size(1) differs from hidden_size. hidden_label tags the message
    # (e.g. "[0]").
    @torch.jit.script_method
    def check_forward_hidden(
        self, input: Tensor, hx: Tensor, hidden_label: str = ""
    ) -> None:
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}"
            )

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}"
            )

    # Re-pack both quantized weights into packed_ih / packed_hh in place.
    # TODO: for some reason weak_script_method causes a destruction of the
    # module to occur, which in turn frees the packed_ih object via its DataPtr
    # deleter. This is bizarre and should probably get fixed.
    # @torch._jit_internal.weak_script_method
    @torch.jit.script_method
    def _unpack(self):
        self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih))
        self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh))

    # Replace both packed buffers in place with empty uint8 tensors.
    # NOTE(review): presumably used around serialization; confirm with callers.
    # @torch._jit_internal.weak_script_method
    @torch.jit.script_method
    def _pack(self):
        self.packed_ih.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
        )
        self.packed_hh.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
        )


class QuantizedRNNCell(QuantizedRNNCellBase):
    """Deprecated quantized Elman RNN cell (``tanh`` or ``relu`` nonlinearity).

    ``forward(input, hx=None)`` returns the next hidden state. When ``hx`` is
    omitted, a zero state of shape ``(input.size(0), hidden_size)`` with the
    input's dtype and device is used.

    Raises:
        RuntimeError: from ``forward`` if ``nonlinearity`` is not
            ``"tanh"`` or ``"relu"``.
    """

    # Base-class constants plus ``nonlinearity``.
    __constants__ = [
        "input_size",
        "hidden_size",
        "bias",
        "scale_hh",
        "scale_ih",
        "zero_point_ih",
        "zero_point_hh",
        "nonlinearity",
    ]

    def __init__(self, other):
        super().__init__(other)
        warnings.warn(
            "torch.jit.QuantizedRNNCell is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
        )
        self.nonlinearity = other.nonlinearity

    @torch.jit.script_method
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        # Dispatch to the quantized cell op for the configured nonlinearity.
        if self.nonlinearity == "tanh":
            ret = _VF.quantized_rnn_tanh_cell(
                input,
                hx,
                self.weight_ih,
                self.weight_hh,
                self.bias_ih,
                self.bias_hh,
                self.packed_ih,
                self.packed_hh,
                self.col_offsets_ih,
                self.col_offsets_hh,
                self.scale_ih,
                self.scale_hh,
                self.zero_point_ih,
                self.zero_point_hh,
            )
        elif self.nonlinearity == "relu":
            ret = _VF.quantized_rnn_relu_cell(
                input,
                hx,
                self.weight_ih,
                self.weight_hh,
                self.bias_ih,
                self.bias_hh,
                self.packed_ih,
                self.packed_hh,
                self.col_offsets_ih,
                self.col_offsets_hh,
                self.scale_ih,
                self.scale_hh,
                self.zero_point_ih,
                self.zero_point_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
        return ret


class QuantizedLSTMCell(QuantizedRNNCellBase):
    """Deprecated quantized LSTM cell.

    ``forward(input, hx=None)`` takes and returns a ``(h, c)`` tuple. When
    ``hx`` is omitted, both entries are the same zero tensor of shape
    ``(input.size(0), hidden_size)`` with the input's dtype and device.
    """

    def __init__(self, other):
        super().__init__(other)
        warnings.warn(
            "torch.jit.QuantizedLSTMCell is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead."
        )

    @torch.jit.script_method
    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        self.check_forward_input(input)
        if hx is None:
            zeros = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zeros, zeros)
        # Validate both hidden-state tensors against the input batch.
        self.check_forward_hidden(input, hx[0], "[0]")
        self.check_forward_hidden(input, hx[1], "[1]")
        return _VF.quantized_lstm_cell(
            input,
            hx,
            self.weight_ih,
            self.weight_hh,
            self.bias_ih,
            self.bias_hh,
            self.packed_ih,
            self.packed_hh,
            self.col_offsets_ih,
            self.col_offsets_hh,
            self.scale_ih,
            self.scale_hh,
            self.zero_point_ih,
            self.zero_point_hh,
        )


class QuantizedGRUCell(QuantizedRNNCellBase):
    """Deprecated quantized GRU cell.

    ``forward(input, hx=None)`` returns the next hidden state. When ``hx`` is
    omitted, a zero state of shape ``(input.size(0), hidden_size)`` with the
    input's dtype and device is used.
    """

    def __init__(self, other):
        super().__init__(other)
        warnings.warn(
            "torch.jit.QuantizedGRUCell is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRUCell instead."
        )

    @torch.jit.script_method
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        return _VF.quantized_gru_cell(
            input,
            hx,
            self.weight_ih,
            self.weight_hh,
            self.bias_ih,
            self.bias_hh,
            self.packed_ih,
            self.packed_hh,
            self.col_offsets_ih,
            self.col_offsets_hh,
            self.scale_ih,
            self.scale_hh,
            self.zero_point_ih,
            self.zero_point_hh,
        )


def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    """Reorder ``tensor`` along ``dim`` using the indices in ``permutation``.

    Used by the RNN wrappers to reorder hidden states (dimension 1 by default)
    to match sorted or unsorted packed-sequence batches.
    """
    return torch.index_select(tensor, dim, permutation)


class QuantizedRNNBase(torch.jit.ScriptModule):
    """Deprecated base class for multi-layer quantized LSTM/GRU modules.

    Constructed from an existing RNN module ``other``. ``other`` must expose
    ``mode``, ``input_size``, ``hidden_size``, ``num_layers``, ``bias``,
    ``batch_first``, ``dropout`` and ``bidirectional``, plus per-layer
    ``weight_{ih,hh}_l{k}[_reverse]`` and ``bias_{ih,hh}_l{k}[_reverse]``
    attributes.

    One set of cell params is built per layer and direction:

    * ``dtype=torch.int8``: via ``torch.ops.quantized.make_quantized_cell_params``;
    * ``dtype=torch.float16``: via ``linear_prepack_fp16`` followed by
      ``make_quantized_cell_params_fp16``.

    Each set is appended to ``all_weights`` and also stored as an attribute.

    Raises:
        RuntimeError: if ``mode`` is not "LSTM"/"GRU" or ``dtype`` is not
            ``torch.int8``/``torch.float16``.
        AssertionError: if ``batch_first`` is set for a non-GRU mode, or
            ``bias`` is falsy.
    """

    # Attributes TorchScript treats as module constants.
    __constants__ = [
        "mode",
        "input_size",
        "hidden_size",
        "num_layers",
        "bias",
        "batch_first",
        "dropout",
        "bidirectional",
        "dtype",
    ]

    def __init__(self, other, dtype=torch.int8):
        super().__init__()
        warnings.warn(
            "torch.jit.QuantizedRNNBase is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic instead."
        )
        self.mode = other.mode
        self.input_size = other.input_size
        self.hidden_size = other.hidden_size
        self.num_layers = other.num_layers
        self.bias = other.bias
        self.batch_first = other.batch_first
        # batch_first is only accepted for GRU.
        if self.mode != "GRU":
            assert not self.batch_first
        self.dropout = other.dropout
        self.bidirectional = other.bidirectional
        num_directions = 2 if self.bidirectional else 1
        self.dtype = dtype

        assert self.bias

        # TODO: support more than just LSTM
        if self.mode != "LSTM" and self.mode != "GRU":
            raise RuntimeError("Only LSTM or GRU is supported for QuantizedRNN")

        if dtype != torch.int8 and dtype != torch.float16:
            raise RuntimeError(f"Unsupported dtype: {dtype}")

        # Flat list of per-(layer, direction) cell params, in layer-major order.
        self.all_weights = []
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                # NOTE(review): layer_input_size is computed but not used below.
                layer_input_size = (
                    self.input_size if layer == 0 else self.hidden_size * num_directions
                )

                suffix = "_reverse" if direction == 1 else ""

                # Fetch the source weight/bias for this layer and direction;
                # ihhh is "ih" or "hh".
                def get_weight_bias(ihhh):
                    weight_name = f"weight_{ihhh}_l{layer}{suffix}"
                    bias_name = f"bias_{ihhh}_l{layer}{suffix}"

                    weight = getattr(other, weight_name)
                    bias = getattr(other, bias_name)
                    return weight, bias

                weight_ih, bias_ih = get_weight_bias("ih")
                weight_hh, bias_hh = get_weight_bias("hh")

                if dtype == torch.int8:
                    cell_params = torch.ops.quantized.make_quantized_cell_params(
                        weight_ih, weight_hh, bias_ih, bias_hh
                    )
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
                        weight_ih.float(), bias_ih
                    )
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
                        weight_hh.float(), bias_hh
                    )

                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )

                # Attribute names come out as e.g. "cell_params_0_" (forward)
                # and "cell_params_0__reverse" (reverse direction).
                setattr(self, f"cell_params_{layer}_{suffix}", cell_params)
                self.all_weights.append(cell_params)

    # Require 2-D input for packed data (batch_sizes given), else 3-D, and a
    # last dimension equal to input_size; raise RuntimeError otherwise.
    @torch.jit.script_method
    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
            )
        if self.input_size != input.size(-1):
            raise RuntimeError(
                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
            )

    # Return (num_layers * num_directions, mini_batch, hidden_size). mini_batch
    # is batch_sizes[0] for packed input; otherwise it is input dim 0
    # (batch_first) or dim 1.
    @torch.jit.script_method
    def get_expected_hidden_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    # Raise RuntimeError (msg formatted with expected and actual sizes) when
    # hx's size differs from expected_hidden_size.
    @torch.jit.script_method
    def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: Tuple[int, int, int],
        msg: str = "Expected hidden size {}, got {}",
    ) -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    # Validate the input and a single hidden-state tensor.
    @torch.jit.script_method
    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(
            hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
        )

    # Reorder hx along dim 1 by permutation; return hx unchanged when None.
    @torch.jit.script_method
    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)


class QuantizedLSTM(QuantizedRNNBase):
    """Deprecated multi-layer quantized LSTM built on ``torch.quantized_lstm``.

    ``forward(input, hx=None)`` accepts a tensor or a ``PackedSequence``. It
    returns ``(output, (h, c))``, with ``output`` re-wrapped as a
    ``PackedSequence`` for packed input. When ``hx`` is omitted, both states
    are zeros of shape ``(num_layers * num_directions, max_batch_size,
    hidden_size)``.

    NOTE(review): ``forward_impl`` asserts ``batch_sizes is None``, so packed
    input (which supplies ``batch_sizes``) appears to be unsupported.
    """

    # TorchScript overload mapping; the Python ``forward`` below dispatches on
    # isinstance(input, PackedSequence).
    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, other, dtype):
        super().__init__(other, dtype)
        warnings.warn(
            "torch.jit.QuantizedLSTM is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
        )

    # Shared implementation: default/permute hx, validate, run quantized_lstm.
    @torch.jit.script_method
    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tuple[Tensor, Tensor]],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        assert batch_sizes is None
        result = torch.quantized_lstm(
            input,
            hx,
            self.all_weights,
            self.bias,
            self.num_layers,
            float(self.dropout),
            self.training,
            self.bidirectional,
            self.batch_first,
            dtype=self.dtype,
            use_dynamic=False,
        )
        # result[0] is the output; the remaining entries form the hidden tuple
        # (presumably (h_n, c_n)).
        output = result[0]
        hidden = result[1:]

        return output, hidden

    # Plain-tensor entry point: batch size from dim 0 (batch_first) or dim 1.
    @torch.jit.script_method
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )

        return output, self.permute_hidden(hidden, unsorted_indices)

    # PackedSequence entry point: batch size from batch_sizes[0].
    @torch.jit.script_method
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        output, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    # Tuple-aware override: permute both h and c along dim 1.
    @torch.jit.script_method
    def permute_hidden(
        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(
            hx[1], permutation
        )

    # Tuple-aware override: validate input and both hidden-state tensors.
    @torch.jit.script_method
    def check_forward_args(
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        batch_sizes: Optional[Tensor],
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(
            hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
        )
        self.check_hidden_size(
            hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
        )

    def forward(self, input, hx=None):
        """Dispatch to ``forward_packed`` or ``forward_tensor`` by input type."""
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)


class QuantizedGRU(QuantizedRNNBase):
    """Deprecated multi-layer quantized GRU built on ``torch.quantized_gru``.

    ``forward(input, hx=None)`` accepts a tensor or a ``PackedSequence`` and
    returns ``(output, h_n)``, with ``output`` re-wrapped as a
    ``PackedSequence`` for packed input. When ``hx`` is omitted, a zero state
    of shape ``(num_layers * num_directions, max_batch_size, hidden_size)``
    is used. Constructor arguments are forwarded unchanged to
    ``QuantizedRNNBase``.
    """

    # TorchScript overload mapping; the Python ``forward`` below dispatches on
    # isinstance(input, PackedSequence).
    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torch.jit.QuantizedGRU is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRU instead."
        )

    # Shared implementation: default/permute hx, validate, then call the
    # padded (batch_sizes is None) or packed overload of quantized_gru.
    @torch.jit.script_method
    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tensor],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = torch.quantized_gru(
                input,
                hx,
                self.all_weights,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            # Packed overload: takes batch_sizes and no batch_first flag.
            result = torch.quantized_gru(
                input,
                batch_sizes,
                hx,
                self.all_weights,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
            )

        output = result[0]
        hidden = result[1]

        return output, hidden

    # Plain-tensor entry point: batch size from dim 0 (batch_first) or dim 1.
    @torch.jit.script_method
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )
        return output, self.permute_hidden(hidden, unsorted_indices)

    # PackedSequence entry point: batch size from batch_sizes[0].
    @torch.jit.script_method
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        output, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def forward(self, input, hx=None):
        """Dispatch to ``forward_packed`` or ``forward_tensor`` by input type."""
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)


def quantize_rnn_cell_modules(module):
    """Recursively replace RNN cells with the deprecated quantized cells.

    ``torch.nn.LSTMCell``, ``torch.nn.GRUCell`` and ``torch.nn.RNNCell``
    instances anywhere in ``module`` are replaced in place on their parent
    with ``QuantizedLSTMCell``, ``QuantizedGRUCell`` and ``QuantizedRNNCell``
    respectively.

    Args:
        module: root module to convert.

    Returns:
        The quantized replacement if ``module`` is itself a supported cell,
        otherwise ``module`` with its descendants converted.
    """
    warnings.warn(
        "quantize_rnn_cell_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
    return _quantize_rnn_cell_modules(module)


def _quantize_rnn_cell_modules(module):
    """Recursive worker for quantize_rnn_cell_modules (emits no warning)."""
    # Visit direct children only; each recursive call converts its own subtree.
    # named_modules() would also yield every deeper descendant under a dotted
    # name, re-walking each subtree once per ancestor.
    reassign = {}
    for name, mod in module.named_children():
        new_mod = _quantize_rnn_cell_modules(mod)
        if new_mod is not mod:
            reassign[name] = new_mod
    # Apply replacements after iterating so _modules is not mutated mid-loop.
    for name, mod in reassign.items():
        setattr(module, name, mod)
    if isinstance(module, torch.nn.LSTMCell):
        return QuantizedLSTMCell(module)
    if isinstance(module, torch.nn.GRUCell):
        return QuantizedGRUCell(module)
    if isinstance(module, torch.nn.RNNCell):
        return QuantizedRNNCell(module)
    return module


def quantize_linear_modules(module, dtype=torch.int8):
    """Recursively replace ``torch.nn.Linear`` with deprecated quantized layers.

    Args:
        module: root module to convert; children are replaced in place.
        dtype: ``torch.int8`` selects ``QuantizedLinear``; ``torch.float16``
            selects ``QuantizedLinearFP16``.

    Returns:
        The quantized replacement if ``module`` is itself a ``Linear``,
        otherwise ``module`` with its descendants converted.

    Raises:
        RuntimeError: if a ``Linear`` is found and ``dtype`` is unsupported.
    """
    warnings.warn(
        "quantize_linear_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
    return _quantize_linear_modules(module, dtype)


def _quantize_linear_modules(module, dtype):
    """Recursive worker for quantize_linear_modules (emits no warning)."""
    # Visit direct children only; each recursive call converts its own subtree.
    reassign = {}
    for name, mod in module.named_children():
        new_mod = _quantize_linear_modules(mod, dtype)
        if new_mod is not mod:
            reassign[name] = new_mod

    for name, mod in reassign.items():
        setattr(module, name, mod)
    if isinstance(module, torch.nn.Linear):
        if dtype == torch.int8:
            return QuantizedLinear(module)
        elif dtype == torch.float16:
            return QuantizedLinearFP16(module)
        else:
            raise RuntimeError(f"Unsupported dtype: {dtype}")
    return module


def quantize_rnn_modules(module, dtype=torch.int8):
    """Recursively replace ``torch.nn.LSTM``/``torch.nn.GRU`` with quantized versions.

    Args:
        module: root module to convert; children are replaced in place.
        dtype: ``torch.int8`` or ``torch.float16``. It is forwarded to both
            ``QuantizedLSTM`` and ``QuantizedGRU``. Previously GRU silently
            ignored it and always used int8.

    Returns:
        The quantized replacement if ``module`` is itself an LSTM/GRU,
        otherwise ``module`` with its descendants converted.

    Raises:
        RuntimeError: if an LSTM/GRU is found and ``dtype`` is unsupported.
    """
    warnings.warn(
        "quantize_rnn_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
    return _quantize_rnn_modules(module, dtype)


def _quantize_rnn_modules(module, dtype):
    """Recursive worker for quantize_rnn_modules (emits no warning)."""
    # Visit direct children only; each recursive call converts its own subtree.
    reassign = {}
    for name, mod in module.named_children():
        new_mod = _quantize_rnn_modules(mod, dtype)
        if new_mod is not mod:
            reassign[name] = new_mod

    for name, mod in reassign.items():
        setattr(module, name, mod)
    if isinstance(module, torch.nn.LSTM):
        if dtype != torch.int8 and dtype != torch.float16:
            raise RuntimeError(f"Unsupported dtype: {dtype}")
        return QuantizedLSTM(module, dtype)
    if isinstance(module, torch.nn.GRU):
        # QuantizedRNNBase validates dtype and builds int8 or fp16 cell params.
        return QuantizedGRU(module, dtype)
    return module
