# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from typing import Any, Generic, List, Sequence, TypeVar, Union, cast, overload

from typing_extensions import Literal, Protocol, TypeAlias, runtime_checkable

from streamlit import util
from streamlit.elements.heading import HeadingProtoTag
from streamlit.elements.select_slider import SelectSliderSerde
from streamlit.elements.slider import SliderScalar, SliderScalarT, SliderSerde, Step
from streamlit.elements.time_widgets import (
    DateInputSerde,
    DateWidgetReturn,
    TimeInputSerde,
    _parse_date_value,
)
from streamlit.proto.Block_pb2 import Block as BlockProto
from streamlit.proto.Button_pb2 import Button as ButtonProto
from streamlit.proto.Checkbox_pb2 import Checkbox as CheckboxProto
from streamlit.proto.Code_pb2 import Code as CodeProto
from streamlit.proto.ColorPicker_pb2 import ColorPicker as ColorPickerProto
from streamlit.proto.DateInput_pb2 import DateInput as DateInputProto
from streamlit.proto.Element_pb2 import Element as ElementProto
from streamlit.proto.Exception_pb2 import Exception as ExceptionProto
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.Heading_pb2 import Heading as HeadingProto
from streamlit.proto.Markdown_pb2 import Markdown as MarkdownProto
from streamlit.proto.MultiSelect_pb2 import MultiSelect as MultiSelectProto
from streamlit.proto.NumberInput_pb2 import NumberInput as NumberInputProto
from streamlit.proto.Radio_pb2 import Radio as RadioProto
from streamlit.proto.Selectbox_pb2 import Selectbox as SelectboxProto
from streamlit.proto.Slider_pb2 import Slider as SliderProto
from streamlit.proto.Text_pb2 import Text as TextProto
from streamlit.proto.TextArea_pb2 import TextArea as TextAreaProto
from streamlit.proto.TextInput_pb2 import TextInput as TextInputProto
from streamlit.proto.TimeInput_pb2 import TimeInput as TimeInputProto
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
from streamlit.runtime.state.common import user_key_from_widget_id
from streamlit.runtime.state.session_state import SessionState


# TODO This class serves as a fallback option for elements that have not
# been implemented yet, as well as providing implementations of some
# trivial methods. It may have significantly reduced scope, or be removed
# entirely, once all elements have been implemented.
# This class will not be sufficient implementation for most elements.
# Widgets need their own classes to translate interactions into the appropriate
# WidgetState and provide higher level interaction interfaces, and other elements
# have enough variation in how to get their values that most will need their
# own classes too.
@dataclass(init=False)
class Element:
    type: str
    proto: Any = field(repr=False)
    root: ElementTree = field(repr=False)
    key: str | None

    def __init__(self, proto: ElementProto, root: ElementTree):
        self.proto = proto
        self.root = root
        ty = proto.WhichOneof("type")
        assert ty is not None
        self.type = ty
        self.key = None

    def __iter__(self):
        yield self

    @property
    def value(self) -> Any:
        p = getattr(self.proto, self.type)
        try:
            state = self.root.session_state
            assert state
            return state[p.id]
        except ValueError:
            # No id field, not a widget
            return p.value

    def widget_state(self) -> WidgetState | None:
        return None

    def run(self) -> ElementTree:
        return self.root.run()

    def __repr__(self):
        return util.repr_(self)


@dataclass(init=False, repr=False)
class Text(Element):
    proto: TextProto

    type: str
    root: ElementTree = field(repr=False)
    key: None = None

    def __init__(self, proto: TextProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self.type = "text"

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(init=False, repr=False)
class HeadingBase(Element, ABC):
    proto: HeadingProto

    type: str
    tag: str
    anchor: str | None
    hide_anchor: bool
    root: ElementTree = field(repr=False)
    key: None

    def __init__(self, proto: HeadingProto, root: ElementTree, type_: str):
        self.proto = proto
        self.key = None
        self.tag = proto.tag
        self.anchor = proto.anchor
        self.hide_anchor = proto.hide_anchor
        self.root = root
        self.type = type_

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(init=False, repr=False)
class Title(HeadingBase):
    def __init__(self, proto: HeadingProto, root: ElementTree):
        super().__init__(proto, root, "title")


@dataclass(init=False, repr=False)
class Header(HeadingBase):
    def __init__(self, proto: HeadingProto, root: ElementTree):
        super().__init__(proto, root, "header")


@dataclass(init=False, repr=False)
class Subheader(HeadingBase):
    def __init__(self, proto: HeadingProto, root: ElementTree):
        super().__init__(proto, root, "subheader")


@dataclass(init=False, repr=False)
class Markdown(Element):
    proto: MarkdownProto

    type: str
    is_caption: bool
    allow_html: bool
    root: ElementTree = field(repr=False)
    key: None

    def __init__(self, proto: MarkdownProto, root: ElementTree):
        self.proto = proto
        self.key = None
        self.is_caption = proto.is_caption
        self.allow_html = proto.allow_html
        self.root = root
        self.type = "markdown"

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(init=False, repr=False)
class Caption(Markdown):
    def __init__(self, proto: MarkdownProto, root: ElementTree):
        super().__init__(proto, root)
        self.type = "caption"


@dataclass(init=False, repr=False)
class Latex(Markdown):
    def __init__(self, proto: MarkdownProto, root: ElementTree):
        super().__init__(proto, root)
        self.type = "latex"


@dataclass(init=False, repr=False)
class Code(Element):
    proto: CodeProto

    type: str
    language: str
    show_line_numbers: bool
    root: ElementTree = field(repr=False)
    key: None

    def __init__(self, proto: CodeProto, root: ElementTree):
        self.proto = proto
        self.key = None
        self.language = proto.language
        self.show_line_numbers = proto.show_line_numbers
        self.root = root
        self.type = "code"

    @property
    def value(self) -> str:
        return self.proto.code_text


@dataclass(repr=False)
class Exception(Element):
    type: str
    message: str
    is_markdown: bool
    stack_trace: list[str]
    is_warning: bool

    def __init__(self, proto: ExceptionProto, root: ElementTree):
        self.key = None
        self.root = root
        self.proto = proto
        self.type = "exception"

        self.message = proto.message
        self.is_markdown = proto.message_is_markdown
        self.stack_trace = list(proto.stack_trace)
        self.is_warning = proto.is_warning

    @property
    def value(self) -> str:
        return self.message


@dataclass(init=False, repr=False)
class Divider(Markdown):
    def __init__(self, proto: MarkdownProto, root: ElementTree):
        super().__init__(proto, root)
        self.type = "divider"


@runtime_checkable
class Widget(Protocol):
    id: str
    key: str | None

    def set_value(self, v: Any):
        ...


T = TypeVar("T")


@dataclass(repr=False)
class ColorPicker(Element, Widget):
    _value: str | None

    proto: ColorPickerProto
    type: str
    id: str
    label: str
    help: str
    form_id: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: ColorPickerProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "color_picker"
        self.id = proto.id
        self.label = proto.label
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    @property
    def value(self) -> str:
        """The currently selected value from the options."""
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            return cast(str, state[self.id])

    def widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        ws.string_value = self.value
        return ws

    def set_value(self, v: str) -> ColorPicker:
        self._value = v
        return self

    def pick(self, v: str) -> ColorPicker:
        if not v.startswith("#"):
            v = f"#{v}"
        return self.set_value(v)


@dataclass(init=False, repr=False)
class Radio(Element, Widget, Generic[T]):
    _value: T | None

    proto: RadioProto
    type: str
    id: str
    label: str
    options: list[str]
    help: str
    form_id: str
    disabled: bool
    horizontal: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: RadioProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "radio"
        self.id = proto.id
        self.label = proto.label
        self.options = list(proto.options)
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.horizontal = proto.horizontal
        self.key = user_key_from_widget_id(self.id)

    @property
    def index(self) -> int:
        return self.options.index(str(self.value))

    @property
    def value(self) -> T:
        """The currently selected value from the options."""
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            return cast(T, state[self.id])

    def set_value(self, v: T) -> Radio[T]:
        self._value = v
        return self

    def widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        ws.int_value = self.index
        return ws


@dataclass(init=False, repr=False)
class Checkbox(Element, Widget):
    _value: bool | None

    proto: CheckboxProto
    type: str
    id: str
    label: str
    help: str
    form_id: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: CheckboxProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "checkbox"
        self.id = proto.id
        self.label = proto.label
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        ws.bool_value = self.value
        return ws

    @property
    def value(self) -> bool:
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            return cast(bool, state[self.id])

    def set_value(self, v: bool) -> Checkbox:
        self._value = v
        return self

    def check(self) -> Checkbox:
        return self.set_value(True)

    def uncheck(self) -> Checkbox:
        return self.set_value(False)


@dataclass(init=False, repr=False)
class Multiselect(Element, Widget, Generic[T]):
    _value: list[T] | None

    proto: MultiSelectProto
    type: str
    id: str
    label: str
    options: list[str]
    help: str
    form_id: str
    disabled: bool
    max_selections: int
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: MultiSelectProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "multiselect"
        self.id = proto.id
        self.label = proto.label
        self.options = list(proto.options)
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.max_selections = proto.max_selections
        self.key = user_key_from_widget_id(self.id)

    def widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        ws.int_array_value.data[:] = self.indices
        return ws

    @property
    def value(self) -> list[T]:
        """The currently selected values from the options."""
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            return cast(List[T], state[self.id])

    @property
    def indices(self) -> Sequence[int]:
        return [self.options.index(str(v)) for v in self.value]

    def set_value(self, v: list[T]) -> Multiselect[T]:
        """
        Set the value of the multiselect widget.
        Implementation note: set_value not work correctly if `format_func` is also
        passed to the multiselect. This is because we send options via proto with
        applied `format_func`, but keep original values in session state
        as widget value.
        """
        self._value = v
        return self

    def select(self, v: T) -> Multiselect[T]:
        current = self.value
        if v in current:
            return self
        else:
            new = current.copy()
            new.append(v)
            self.set_value(new)
            return self

    def unselect(self, v: T) -> Multiselect[T]:
        current = self.value
        if v not in current:
            return self
        else:
            new = current.copy()
            while v in new:
                new.remove(v)
            self.set_value(new)
            return self


@dataclass(init=False, repr=False)
class Selectbox(Element, Widget, Generic[T]):
    _value: T | None

    proto: SelectboxProto = field(repr=False)
    type: str
    id: str
    label: str
    options: list[str]
    help: str
    form_id: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: SelectboxProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "selectbox"
        self.id = proto.id
        self.label = proto.label
        self.options = list(proto.options)
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    @property
    def index(self) -> int:
        if len(self.options) == 0:
            return 0
        return self.options.index(str(self.value))

    @property
    def value(self) -> T:
        """The currently selected value from the options."""
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            return cast(T, state[self.id])

    def set_value(self, v: T) -> Selectbox[T]:
        """
        Set the value of the selectbox.
        Implementation note: set_value not work correctly if `format_func` is also
        passed to the selectbox. This is because we send options via proto with applied
        `format_func`, but keep original values in session state as widget value.
        """
        self._value = v
        return self

    def select(self, v: T) -> Selectbox[T]:
        return self.set_value(v)

    def select_index(self, index: int) -> Selectbox[T]:
        return self.set_value(cast(T, self.options[index]))

    def widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        ws.int_value = self.index
        return ws


@dataclass(init=False, repr=False)
class Button(Element, Widget):
    _value: bool

    proto: ButtonProto
    type: str
    id: str
    label: str
    help: str
    form_id: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: ButtonProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = False

        self.type = "button"
        self.id = proto.id
        self.label = proto.label
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        ws.trigger_value = self._value
        return ws

    @property
    def value(self) -> bool:
        if self._value:
            return self._value
        else:
            state = self.root.session_state
            assert state
            return cast(bool, state[self.id])

    def set_value(self, v: bool) -> Button:
        self._value = v
        return self

    def click(self) -> Button:
        return self.set_value(True)


@dataclass(init=False, repr=False)
class Slider(Element, Widget, Generic[SliderScalarT]):
    _value: SliderScalarT | Sequence[SliderScalarT] | None

    proto: SliderProto
    type: str
    data_type: SliderProto.DataType.ValueType
    id: str
    label: str
    min_value: SliderScalar
    max_value: SliderScalar
    step: Step
    help: str
    form_id: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: SliderProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "slider"
        self.data_type = proto.data_type
        self.id = proto.id
        self.label = proto.label
        self.min_value = proto.min
        self.max_value = proto.max
        self.step = proto.step
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def set_value(
        self, v: SliderScalarT | Sequence[SliderScalarT]
    ) -> Slider[SliderScalarT]:
        self._value = v
        return self

    def widget_state(self) -> WidgetState:
        data_type = self.proto.data_type
        serde = SliderSerde([], data_type, True, None)
        v = serde.serialize(self.value)

        ws = WidgetState()
        ws.id = self.id
        ws.double_array_value.data[:] = v
        return ws

    @property
    def value(self) -> SliderScalarT | Sequence[SliderScalarT]:
        """The currently selected value or range."""
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            # Awkward to do this with `cast`
            return state[self.id]  # type: ignore

    def set_range(
        self, lower: SliderScalarT, upper: SliderScalarT
    ) -> Slider[SliderScalarT]:
        return self.set_value([lower, upper])


@dataclass(init=False, repr=False)
class SelectSlider(Element, Widget, Generic[T]):
    _value: T | Sequence[T] | None

    proto: SliderProto
    type: str
    data_type: SliderProto.DataType.ValueType
    id: str
    label: str
    options: list[str]
    help: str
    form_id: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: SliderProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "select_slider"
        self.data_type = proto.data_type
        self.id = proto.id
        self.label = proto.label
        self.options = list(proto.options)
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def set_value(self, v: T | Sequence[T]) -> SelectSlider[T]:
        self._value = v
        return self

    def widget_state(self) -> WidgetState:
        serde = SelectSliderSerde(self.options, [], False)
        v = serde.serialize(self.value)

        ws = WidgetState()
        ws.id = self.id
        ws.double_array_value.data[:] = v
        return ws

    @property
    def value(self) -> T | Sequence[T]:
        """The currently selected value or range."""
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            # Awkward to do this with `cast`
            return state[self.id]  # type: ignore

    def set_range(self, lower: T, upper: T) -> SelectSlider[T]:
        return self.set_value([lower, upper])


@dataclass(repr=False)
class TextInput(Element, Widget):
    _value: str | None
    proto: TextInputProto
    type: str
    id: str
    label: str
    max_chars: int
    help: str
    form_id: str
    autocomplete: str
    placeholder: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: TextInputProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "text_input"
        self.id = proto.id
        self.label = proto.label
        self.max_chars = proto.max_chars
        self.help = proto.help
        self.form_id = proto.form_id
        self.autocomplete = proto.autocomplete
        self.placeholder = proto.placeholder
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def set_value(self, v: str) -> TextInput:
        self._value = v
        return self

    def widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        ws.string_value = self.value
        return ws

    @property
    def value(self) -> str:
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            # Awkward to do this with `cast`
            return state[self.id]  # type: ignore

    def input(self, v: str) -> TextInput:
        # TODO should input be setting or appending?
        if self.max_chars and len(v) > self.max_chars:
            return self
        return self.set_value(v)


@dataclass(repr=False)
class TextArea(Element, Widget):
    _value: str | None

    proto: TextAreaProto
    type: str
    id: str
    label: str
    max_chars: int
    help: str
    form_id: str
    placeholder: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: TextAreaProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "text_area"
        self.id = proto.id
        self.label = proto.label
        self.max_chars = proto.max_chars
        self.help = proto.help
        self.form_id = proto.form_id
        self.placeholder = proto.placeholder
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def set_value(self, v: str) -> TextArea:
        self._value = v
        return self

    def widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        ws.string_value = self.value
        return ws

    @property
    def value(self) -> str:
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            # Awkward to do this with `cast`
            return state[self.id]  # type: ignore

    def input(self, v: str) -> TextArea:
        # TODO should input be setting or appending?
        if self.max_chars and len(v) > self.max_chars:
            return self
        return self.set_value(v)


Number = Union[int, float]


@dataclass(repr=False)
class NumberInput(Element, Widget):
    _value: Number | None
    proto: NumberInputProto
    type: str
    id: str
    label: str
    min_value: Number
    max_value: Number
    step: Number
    help: str
    form_id: str
    placeholder: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: NumberInputProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "number_input"
        self.id = proto.id
        self.label = proto.label
        self.min_value = proto.min
        self.max_value = proto.max
        self.step = proto.step
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def set_value(self, v: Number) -> NumberInput:
        self._value = v
        return self

    def widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        ws.double_value = self.value
        return ws

    @property
    def value(self) -> Number:
        if self._value is not None:
            return self._value
        else:
            state = self.root.session_state
            assert state
            # Awkward to do this with `cast`
            return state[self.id]  # type: ignore

    def increment(self) -> NumberInput:
        v = min(self.value + self.step, self.max_value)
        return self.set_value(v)

    def decrement(self) -> NumberInput:
        v = max(self.value - self.step, self.min_value)
        return self.set_value(v)


SingleDateValue: TypeAlias = Union[date, datetime]
DateValue: TypeAlias = Union[SingleDateValue, Sequence[SingleDateValue]]


@dataclass(repr=False)
class DateInput(Element, Widget):
    _value: DateValue | None
    proto: DateInputProto
    type: str
    id: str
    label: str
    min: date
    max: date
    is_range: bool
    help: str
    form_id: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: DateInputProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "date_input"
        self.id = proto.id
        self.label = proto.label
        self.min = datetime.strptime(proto.min, "%Y/%m/%d").date()
        self.max = datetime.strptime(proto.max, "%Y/%m/%d").date()
        self.is_range = proto.is_range
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def set_value(self, v: DateValue) -> DateInput:
        self._value = v
        return self

    def widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id

        serde = DateInputSerde(None)  # type: ignore
        ws.string_array_value.data[:] = serde.serialize(self.value)
        return ws

    @property
    def value(self) -> DateWidgetReturn:
        if self._value is not None:
            parsed, _ = _parse_date_value(self._value)
            return tuple(parsed)  # type: ignore
        else:
            state = self.root.session_state
            assert state
            return state[self.id]  # type: ignore


TimeValue: TypeAlias = Union[time, datetime]


@dataclass(repr=False)
class TimeInput(Element, Widget):
    _value: TimeValue | None
    proto: TimeInputProto
    type: str
    id: str
    label: str
    step: int
    help: str
    form_id: str
    disabled: bool
    key: str | None

    root: ElementTree = field(repr=False)

    def __init__(self, proto: TimeInputProto, root: ElementTree):
        self.proto = proto
        self.root = root
        self._value = None

        self.type = "time_input"
        self.id = proto.id
        self.label = proto.label
        self.step = proto.step
        self.help = proto.help
        self.form_id = proto.form_id
        self.disabled = proto.disabled
        self.key = user_key_from_widget_id(self.id)

    def set_value(self, v: TimeValue) -> TimeInput:
        self._value = v
        return self

    def widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id

        serde = TimeInputSerde(None)  # type: ignore
        ws.string_value = serde.serialize(self.value)
        return ws

    @property
    def value(self) -> time:
        if self._value is not None:
            v = self._value
            v = v.time() if isinstance(v, datetime) else v
            return v
        else:
            state = self.root.session_state
            assert state
            return state[self.id]  # type: ignore

    def increment(self) -> TimeInput:
        """Select the next available time."""
        dt = datetime.combine(date.today(), self.value) + timedelta(seconds=self.step)
        return self.set_value(dt.time())

    def decrement(self) -> TimeInput:
        """Select the previous available time."""
        dt = datetime.combine(date.today(), self.value) - timedelta(seconds=self.step)
        return self.set_value(dt.time())


@dataclass(init=False, repr=False)
class Block:
    type: str
    children: dict[int, Node]
    proto: BlockProto | None = field(repr=False)
    root: ElementTree = field(repr=False)

    def __init__(
        self,
        root: ElementTree,
        proto: BlockProto | None = None,
        type: str | None = None,
    ):
        self.children = {}
        self.proto = proto
        if proto:
            ty = proto.WhichOneof("type")
            # TODO does not work for `st.container` which has no block proto
            assert ty is not None
            self.type = ty
        elif type is not None:
            self.type = type
        else:
            self.type = ""
        self.root = root

    def __len__(self) -> int:
        return len(self.children)

    def __iter__(self):
        yield self
        for child_idx in self.children:
            for c in self.children[child_idx]:
                yield c

    def __getitem__(self, k: int) -> Node:
        return self.children[k]

    @property
    def key(self) -> str | None:
        return None

    @overload
    def get(self, element_type: Literal["text"]) -> Sequence[Text]:
        ...

    @overload
    def get(self, element_type: Literal["markdown"]) -> Sequence[Markdown]:
        ...

    @overload
    def get(self, element_type: Literal["caption"]) -> Sequence[Caption]:
        ...

    @overload
    def get(self, element_type: Literal["latex"]) -> Sequence[Latex]:
        ...

    @overload
    def get(self, element_type: Literal["code"]) -> Sequence[Code]:
        ...

    @overload
    def get(self, element_type: Literal["divider"]) -> Sequence[Divider]:
        ...

    @overload
    def get(self, element_type: Literal["title"]) -> Sequence[Title]:
        ...

    @overload
    def get(self, element_type: Literal["header"]) -> Sequence[Header]:
        ...

    @overload
    def get(self, element_type: Literal["subheader"]) -> Sequence[Subheader]:
        ...

    @overload
    def get(self, element_type: Literal["exception"]) -> Sequence[Exception]:
        ...

    @overload
    def get(self, element_type: Literal["radio"]) -> Sequence[Radio[Any]]:
        ...

    @overload
    def get(self, element_type: Literal["checkbox"]) -> Sequence[Checkbox]:
        ...

    @overload
    def get(self, element_type: Literal["multiselect"]) -> Sequence[Multiselect[Any]]:
        ...

    @overload
    def get(self, element_type: Literal["selectbox"]) -> Sequence[Selectbox[Any]]:
        ...

    @overload
    def get(self, element_type: Literal["slider"]) -> Sequence[Slider[Any]]:
        ...

    @overload
    def get(
        self, element_type: Literal["select_slider"]
    ) -> Sequence[SelectSlider[Any]]:
        ...

    @overload
    def get(self, element_type: Literal["button"]) -> Sequence[Button]:
        ...

    @overload
    def get(self, element_type: Literal["text_input"]) -> Sequence[TextInput]:
        ...

    @overload
    def get(self, element_type: Literal["text_area"]) -> Sequence[TextArea]:
        ...

    @overload
    def get(self, element_type: Literal["color_picker"]) -> Sequence[ColorPicker]:
        ...

    @overload
    def get(self, element_type: Literal["number_input"]) -> Sequence[NumberInput]:
        ...

    @overload
    def get(self, element_type: Literal["date_input"]) -> Sequence[DateInput]:
        ...

    @overload
    def get(self, element_type: Literal["time_input"]) -> Sequence[TimeInput]:
        ...

    def get(self, element_type: str) -> Sequence[Node]:
        return [e for e in self if e.type == element_type]

    def get_widget(self, key: str) -> Widget | None:
        for e in self:
            if e.key == key:
                assert isinstance(e, Widget)
                return e
        return None

    def widget_state(self) -> WidgetState | None:
        return None

    def run(self) -> ElementTree:
        return self.root.run()

    def __repr__(self):
        return util.repr_(self)


Node: TypeAlias = Union[Element, Block]


@dataclass(init=False, repr=False)
class ElementTree(Block):
    """A tree of the elements produced by running a streamlit script.

    This acts as the initial entrypoint for querying the produced elements,
    and interacting with widgets.

    Elements can be queried in three ways:
    - By element type, using `.get(...)` to get a list of all of that element,
    in the order they appear in the app
    - By user key, for widgets, using `.get_widget(...)` to get that widget node
    - Positionally, using list indexing syntax (`[...]`) to access a child of a
    block element. Not recommended because the exact tree structure can be surprising.

    Element queries made on a block will return only the elements descending
    from that block.

    Returned elements have methods for accessing whatever attributes are relevant.
    For very simple elements this may be only its value, while complex elements
    like widgets have many.

    Widgets provide a fluent API for faking frontend interaction and rerunning
    the script with the new widget values. All widgets provide a low level `set_value`
    method, along with higher level methods specific to that type of widget.
    After an interaction, calling `.run()` will return the ElementTree for
    the rerun.
    """

    type: str

    script_path: str | None = field(repr=False, default=None)
    _session_state: SessionState | None = field(repr=False, default=None)

    def __init__(self):
        # Expect script_path and session_state to be filled in afterwards
        self.children = {}
        self.root = self
        self.type = "root"

    @property
    def session_state(self) -> SessionState:
        assert self._session_state is not None
        return self._session_state

    def get_widget_states(self) -> WidgetStates:
        ws = WidgetStates()
        for node in self:
            w = node.widget_state()
            if w is not None:
                ws.widgets.append(w)

        return ws

    def run(self) -> ElementTree:
        assert self.script_path is not None
        from streamlit.testing.local_script_runner import LocalScriptRunner

        widget_states = self.get_widget_states()
        runner = LocalScriptRunner(self.script_path, self.session_state)
        return runner.run(widget_states)


def parse_tree_from_messages(messages: list[ForwardMsg]) -> ElementTree:
    """Transform a list of `ForwardMsg` into a tree matching the implicit
    tree structure of blocks and elements in a streamlit app.

    Returns the root of the tree, which acts as the entrypoint for the query
    and interaction API.
    """
    root = ElementTree()
    root.children = {
        0: Block(type="main", root=root),
        1: Block(type="sidebar", root=root),
    }

    for msg in messages:
        if not msg.HasField("delta"):
            continue
        delta_path = msg.metadata.delta_path
        delta = msg.delta
        if delta.WhichOneof("type") == "new_element":
            elt = delta.new_element
            new_node: Node
            if elt.WhichOneof("type") == "text":
                new_node = Text(elt.text, root=root)
            elif elt.WhichOneof("type") == "markdown":
                if elt.markdown.element_type == MarkdownProto.Type.NATIVE:
                    new_node = Markdown(elt.markdown, root=root)
                elif elt.markdown.element_type == MarkdownProto.Type.CAPTION:
                    new_node = Caption(elt.markdown, root=root)
                elif elt.markdown.element_type == MarkdownProto.Type.LATEX:
                    new_node = Latex(elt.markdown, root=root)
                elif elt.markdown.element_type == MarkdownProto.Type.DIVIDER:
                    new_node = Divider(elt.markdown, root=root)
                else:
                    raise ValueError(
                        f"Unknown markdown type {elt.markdown.element_type}"
                    )
            elif elt.WhichOneof("type") == "heading":
                if elt.heading.tag == HeadingProtoTag.TITLE_TAG.value:
                    new_node = Title(elt.heading, root=root)
                elif elt.heading.tag == HeadingProtoTag.HEADER_TAG.value:
                    new_node = Header(elt.heading, root=root)
                elif elt.heading.tag == HeadingProtoTag.SUBHEADER_TAG.value:
                    new_node = Subheader(elt.heading, root=root)
                else:
                    raise ValueError(f"Unknown heading type with tag {elt.heading.tag}")
            elif elt.WhichOneof("type") == "exception":
                new_node = Exception(elt.exception, root=root)
            elif elt.WhichOneof("type") == "radio":
                new_node = Radio(elt.radio, root=root)
            elif elt.WhichOneof("type") == "checkbox":
                new_node = Checkbox(elt.checkbox, root=root)
            elif elt.WhichOneof("type") == "multiselect":
                new_node = Multiselect(elt.multiselect, root=root)
            elif elt.WhichOneof("type") == "selectbox":
                new_node = Selectbox(elt.selectbox, root=root)
            elif elt.WhichOneof("type") == "slider":
                if elt.slider.type == SliderProto.Type.SLIDER:
                    new_node = Slider(elt.slider, root=root)
                elif elt.slider.type == SliderProto.Type.SELECT_SLIDER:
                    new_node = SelectSlider(elt.slider, root=root)
                else:
                    raise ValueError(f"Slider with unknown type {elt.slider}")
            elif elt.WhichOneof("type") == "button":
                new_node = Button(elt.button, root=root)
            elif elt.WhichOneof("type") == "text_input":
                new_node = TextInput(elt.text_input, root=root)
            elif elt.WhichOneof("type") == "text_area":
                new_node = TextArea(elt.text_area, root=root)
            elif elt.WhichOneof("type") == "number_input":
                new_node = NumberInput(elt.number_input, root=root)
            elif elt.WhichOneof("type") == "code":
                new_node = Code(elt.code, root=root)
            elif elt.WhichOneof("type") == "color_picker":
                new_node = ColorPicker(elt.color_picker, root=root)
            elif elt.WhichOneof("type") == "date_input":
                new_node = DateInput(elt.date_input, root=root)
            elif elt.WhichOneof("type") == "time_input":
                new_node = TimeInput(elt.time_input, root=root)
            else:
                new_node = Element(elt, root=root)
        elif delta.WhichOneof("type") == "add_block":
            new_node = Block(proto=delta.add_block, root=root)
        else:
            # add_rows
            continue

        current_node: Block = root
        # Every node up to the end is a Block
        for idx in delta_path[:-1]:
            children = current_node.children
            child = children.get(idx)
            if child is None:
                child = Block(root=root)
                children[idx] = child
            assert isinstance(child, Block)
            current_node = child
        current_node.children[delta_path[-1]] = new_node

    return root
