• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5from __future__ import annotations
6
7import enum
8from typing import Any, Generic, Iterable, Tuple, TypeVar
9
10
class UnexpectedStateError(RuntimeError):
  """Raised when a state is not among the states that were expected.

  The error message lists the name of the actual state and the names of all
  expected states.
  """

  def __init__(self, state: BaseState, expected: Iterable[BaseState]) -> None:
    """Creates the error.

    Args:
      state: The state that was actually encountered.
      expected: The states that would have been accepted. Any iterable is
        allowed, including one-shot iterators such as generator expressions.
    """
    self._state = state
    # Materialize exactly once. `expected` may be a generator, which would be
    # exhausted after this line, so everything below must use the tuple.
    self._expected = tuple(expected)
    names = ", ".join(s.name for s in self._expected)
    super().__init__(f"Unexpected state got={state.name} expected=({names})")

  @property
  def state(self) -> BaseState:
    """The state that was actually encountered."""
    return self._state

  @property
  def expected(self) -> Tuple[BaseState, ...]:
    """The accepted states, in the order they were supplied."""
    return self._expected
26
27
class BaseState(enum.IntEnum):
  """Common parent type for the state enums used with StateMachine.

  As an IntEnum it defines no members itself; subclasses declare the states,
  and their members compare by integer value.
  """
30
31
@enum.unique
class State(BaseState):
  """Ready-made set of states for use with StateMachine.

  Values come from enum.auto() in declaration order, so every member compares
  greater than the members declared above it.
  """
  INITIAL = enum.auto()
  SETUP = enum.auto()
  READY = enum.auto()
  RUN = enum.auto()
  DONE = enum.auto()
40
41
# Type variable for the concrete state enum a StateMachine instance holds.
StateT = TypeVar("StateT", bound="BaseState")


class StateMachine(Generic[StateT]):
  """Holds one current state and validates it before changing it.

  The machine does not define which transitions are legal; callers state the
  accepted source states on every `transition()` / `expect*()` call.

  Because `__eq__` is defined without `__hash__`, instances are unhashable.
  """

  def __init__(self, default: StateT) -> None:
    """Creates a machine whose current state is `default`."""
    self._state: StateT = default

  @property
  def state(self) -> StateT:
    """The current state."""
    return self._state

  @property
  def name(self) -> str:
    """The name of the current state."""
    return self._state.name

  def __eq__(self, other: Any) -> bool:
    """Compares against another StateMachine or against a bare state.

    Returns:
      True if `other` is this machine, a StateMachine holding the identical
      state object, or the identical state object itself; False otherwise.
    """
    if self is other:
      return True
    if isinstance(other, StateMachine):
      return self._state is other._state
    if isinstance(other, type(self._state)):
      return self._state is other
    return False

  def transition(self, *args: StateT, to: StateT) -> None:
    """Moves to state `to` if the current state is one of `args`.

    Args:
      *args: The accepted source states. With no source states given, the
        check can never pass and the call always raises.
      to: The state to switch to.

    Raises:
      UnexpectedStateError: If the current state is not in `args`; the
        current state is left unchanged in that case.
    """
    self.expect(*args)
    self._state = to

  def expect(self, *args: StateT) -> None:
    """Checks that the current state is one of `args`.

    Raises:
      UnexpectedStateError: If the current state is not in `args`.
    """
    if self._state not in args:
      raise UnexpectedStateError(self._state, args)

  def expect_before(self, state: StateT) -> None:
    """Checks that the current state compares strictly less than `state`.

    Raises:
      UnexpectedStateError: If the current state is `state` or greater. The
        error lists every member of the current state's enum below `state`.
    """
    if self._state >= state:
      # Pass a concrete tuple rather than a one-shot generator so the
      # expected states survive repeated iteration by the error type.
      valid_states = tuple(s for s in type(self._state) if s < state)
      raise UnexpectedStateError(self._state, valid_states)

  def expect_at_least(self, state: StateT) -> None:
    """Checks that the current state compares greater or equal to `state`.

    Raises:
      UnexpectedStateError: If the current state is less than `state`. The
        error lists every member of the current state's enum from `state` up.
    """
    if self._state < state:
      # Pass a concrete tuple rather than a one-shot generator so the
      # expected states survive repeated iteration by the error type.
      valid_states = tuple(s for s in type(self._state) if s >= state)
      raise UnexpectedStateError(self._state, valid_states)

  def __str__(self) -> str:
    """Returns the current state formatted via an f-string."""
    return f"{self._state}"
87