# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

from __future__ import annotations

import enum
from typing import Any, Generic, Iterable, Tuple, TypeVar


class UnexpectedStateError(RuntimeError):
  """Raised when a StateMachine is not in one of the expected states.

  Args:
    state: The state the machine was actually in.
    expected: The states that would have been accepted. Any iterable is
      accepted, including single-pass iterators such as generators.
  """

  def __init__(self, state: BaseState, expected: Iterable[BaseState]) -> None:
    self._state = state
    # Materialize exactly once: `expected` may be a single-pass iterator
    # (e.g. a generator), so it must not be iterated a second time to build
    # the message — doing so would yield nothing and print "expected=()".
    self._expected: Tuple[BaseState, ...] = tuple(expected)
    names = ", ".join(s.name for s in self._expected)
    super().__init__(f"Unexpected state got={state.name} expected=({names})")

  @property
  def state(self) -> BaseState:
    """The state the machine was actually in."""
    return self._state

  @property
  def expected(self) -> Tuple[BaseState, ...]:
    """The states that would have been accepted."""
    return self._expected


class BaseState(enum.IntEnum):
  """Base class for StateMachine states.

  Members compare as integers; StateMachine.expect_before and
  StateMachine.expect_at_least rely on that ordering.
  """


@enum.unique
class State(BaseState):
  """Default state implementation."""
  INITIAL = enum.auto()
  SETUP = enum.auto()
  READY = enum.auto()
  RUN = enum.auto()
  DONE = enum.auto()


StateT = TypeVar("StateT", bound="BaseState")


class StateMachine(Generic[StateT]):
  """Holds a single current state and validates expectations/transitions.

  Args:
    default: The initial state.
  """

  def __init__(self, default: StateT) -> None:
    self._state: StateT = default

  @property
  def state(self) -> StateT:
    """The current state member."""
    return self._state

  @property
  def name(self) -> str:
    """The name of the current state member."""
    return self._state.name

  # Defining __eq__ without __hash__ leaves instances unhashable (Python sets
  # __hash__ to None), which suits an object whose state is mutable.
  def __eq__(self, other: Any) -> bool:
    """Compare by current state.

    Equal to another StateMachine in the identical state, or to a state
    member of the same enum type as the current state; otherwise False.
    """
    if self is other:
      return True
    if isinstance(other, StateMachine):
      return self._state is other._state
    if isinstance(other, type(self._state)):
      return self._state is other
    return False

  def transition(self, *args: StateT, to: StateT) -> None:
    """Move to `to` if the current state is one of `args`.

    Raises:
      UnexpectedStateError: if the current state is not in `args`; the
        state is left unchanged in that case.
    """
    self.expect(*args)
    self._state = to

  def expect(self, *args: StateT) -> None:
    """Raise UnexpectedStateError unless the current state is in `args`."""
    if self._state not in args:
      raise UnexpectedStateError(self._state, args)

  def expect_before(self, state: StateT) -> None:
    """Raise UnexpectedStateError unless the current state is < `state`.

    The reported expected states are all members of the current state's
    enum type that are ordered before `state`.
    """
    if self._state >= state:
      valid_states = tuple(s for s in type(self._state) if s < state)
      raise UnexpectedStateError(self._state, valid_states)

  def expect_at_least(self, state: StateT) -> None:
    """Raise UnexpectedStateError unless the current state is >= `state`.

    The reported expected states are all members of the current state's
    enum type that are ordered at or after `state`.
    """
    if self._state < state:
      valid_states = tuple(s for s in type(self._state) if s >= state)
      raise UnexpectedStateError(self._state, valid_states)

  def __str__(self) -> str:
    return f"{self._state}"