• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# coding=utf-8
#
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Defines the type system."""
17
18from abc import ABCMeta, abstractmethod
19from dataclasses import dataclass
20from enum import Enum
21from typing import TYPE_CHECKING, Protocol, TypeVar
22
23from typing_extensions import override
24
25if TYPE_CHECKING:
26    from taihe.semantics.declarations import (
27        EnumDecl,
28        IfaceDecl,
29        StructDecl,
30        TypeDecl,
31        TypeRefDecl,
32        UnionDecl,
33    )
34    from taihe.semantics.visitor import TypeVisitor
35
36T = TypeVar("T")
37
############################
# Infrastructure for Types #
############################
41
42
class TypeProtocol(Protocol):
    """Structural interface for objects that accept a ``TypeVisitor``."""

    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Accept the visitor ``v`` and return a value of the visitor's result type."""
        ...
46
47
@dataclass(frozen=True, repr=False)
class Type(metaclass=ABCMeta):
    """Abstract root of the type hierarchy.

    Every type carries the ``TypeRefDecl`` it was created with. Concrete
    subclasses must provide ``signature`` and ``_accept``.
    """

    # The type-reference declaration associated with this type instance.
    ty_ref: "TypeRefDecl"

    def __repr__(self) -> str:
        """Return ``<QualifiedClassName signature>`` for debugging."""
        class_name = self.__class__.__qualname__
        return f"<{class_name} {self.signature}>"

    @property
    @abstractmethod
    def signature(self) -> str:
        """Return the textual representation of this type."""

    @abstractmethod
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to the matching ``visit_*`` method of the visitor ``v``."""
65
66
##################
# Built-in Types #
##################
70
71
@dataclass(frozen=True, repr=False)
class BuiltinType(Type, metaclass=ABCMeta):
    """Abstract base shared by all built-in (non-user-declared) types."""
75
76
class ScalarKind(Enum):
    """The closed set of scalar kinds.

    Each member's value is a ``(symbol, width, is_signed, is_float)`` tuple,
    which ``__init__`` exposes as attributes of the same names.
    """

    # (symbol, width, is_signed, is_float)
    BOOL = ("bool", 8, False, False)
    F32 = ("f32", 32, True, True)
    F64 = ("f64", 64, True, True)
    I8 = ("i8", 8, True, False)
    I16 = ("i16", 16, True, False)
    I32 = ("i32", 32, True, False)
    I64 = ("i64", 64, True, False)
    U8 = ("u8", 8, False, False)
    U16 = ("u16", 16, False, False)
    U32 = ("u32", 32, False, False)
    U64 = ("u64", 64, False, False)

    def __init__(
        self,
        symbol: str,
        width: int,
        is_signed: bool,
        is_float: bool,
    ) -> None:
        """Unpack the member's value tuple into named attributes."""
        self.symbol, self.width = symbol, width
        self.is_signed, self.is_float = is_signed, is_float
97
98
@dataclass(frozen=True, repr=False)
class ScalarType(BuiltinType):
    """Built-in type described by a ``ScalarKind``."""

    # Which scalar kind this type instance stands for.
    kind: ScalarKind

    @property
    @override
    def signature(self) -> str:
        """Return the symbol of the underlying scalar kind."""
        scalar_kind = self.kind
        return scalar_kind.symbol

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_scalar_type``."""
        visit = v.visit_scalar_type
        return visit(self)
111
112
@dataclass(frozen=True, repr=False)
class OpaqueType(BuiltinType):
    """Built-in type whose signature is the fixed name ``Opaque``."""

    @property
    @override
    def signature(self) -> str:
        """Return the constant signature of this type."""
        return "Opaque"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_opaque_type``."""
        visit = v.visit_opaque_type
        return visit(self)
123
124
@dataclass(frozen=True, repr=False)
class StringType(BuiltinType):
    """Built-in type whose signature is the fixed name ``String``."""

    @property
    @override
    def signature(self) -> str:
        """Return the constant signature of this type."""
        return "String"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_string_type``."""
        visit = v.visit_string_type
        return visit(self)
135
136
class BuiltinBuilder(Protocol):
    """Callable that turns a type reference into a ``BuiltinType``."""

    def __call__(self, ty_ref: "TypeRefDecl") -> BuiltinType:
        """Build the built-in type for ``ty_ref``."""
        ...
140
141
def _scalar_builder(kind: ScalarKind) -> BuiltinBuilder:
    """Return a builder that creates a ``ScalarType`` of the given kind."""

    def build(ty_ref: "TypeRefDecl") -> BuiltinType:
        return ScalarType(ty_ref, kind)

    return build


# Builtin Types Map: built-in type name -> builder taking the type reference.
# Scalar entries are keyed by each ``ScalarKind`` symbol, in declaration order,
# followed by the two non-scalar built-ins.
BUILTIN_TYPES: dict[str, BuiltinBuilder] = {
    **{kind.symbol: _scalar_builder(kind) for kind in ScalarKind},
    "String": lambda ty_ref: StringType(ty_ref),
    "Opaque": lambda ty_ref: OpaqueType(ty_ref),
}
158
159
####################
# Builtin Generics #
####################
163
164
@dataclass(frozen=True, repr=False)
class CallbackType(Type):
    """Type of a callback, described by its return and parameter types."""

    # Return type of the callback; ``None`` is rendered as ``void``.
    return_ty: Type | None
    # Parameter types, in positional order.
    params_ty: tuple[Type, ...]

    @property
    @override
    def signature(self) -> str:
        """Return the signature in the form ``(P1, P2, ...) => R``.

        ``R`` is ``void`` when the callback has no return type.
        """
        # Compare against None explicitly: the field is ``Type | None`` and the
        # rendering must not depend on the truthiness of a ``Type`` instance.
        if self.return_ty is None:
            return_fmt = "void"
        else:
            return_fmt = self.return_ty.signature
        params_fmt = ", ".join(param.signature for param in self.params_ty)
        return f"({params_fmt}) => {return_fmt}"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_callback_type``."""
        return v.visit_callback_type(self)
180
181
# Decorated like the other intermediate abstract bases (BuiltinType, UserType)
# so that every class in the hierarchy is declared as a frozen dataclass.
# No fields are added here; instances still only carry what ``Type`` defines.
@dataclass(frozen=True, repr=False)
class GenericType(Type, metaclass=ABCMeta):
    """Abstract base shared by the built-in generic types."""
184
185
@dataclass(frozen=True, repr=False)
class ArrayType(GenericType):
    """Generic ``Array<T>`` type parameterised by its item type."""

    # Type of the items.
    item_ty: Type

    @property
    @override
    def signature(self) -> str:
        """Return ``Array<...>`` wrapping the item type's signature."""
        item_sig = self.item_ty.signature
        return f"Array<{item_sig}>"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_array_type``."""
        visit = v.visit_array_type
        return visit(self)
198
199
@dataclass(frozen=True, repr=False)
class OptionalType(GenericType):
    """Generic ``Optional<T>`` type parameterised by its item type."""

    # Type of the wrapped item.
    item_ty: Type

    @property
    @override
    def signature(self) -> str:
        """Return ``Optional<...>`` wrapping the item type's signature."""
        item_sig = self.item_ty.signature
        return f"Optional<{item_sig}>"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_optional_type``."""
        visit = v.visit_optional_type
        return visit(self)
212
213
@dataclass(frozen=True, repr=False)
class VectorType(GenericType):
    """Generic ``Vector<T>`` type parameterised by its value type."""

    # Type of the values.
    val_ty: Type

    @property
    @override
    def signature(self) -> str:
        """Return ``Vector<...>`` wrapping the value type's signature."""
        val_sig = self.val_ty.signature
        return f"Vector<{val_sig}>"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_vector_type``."""
        visit = v.visit_vector_type
        return visit(self)
226
227
@dataclass(frozen=True, repr=False)
class MapType(GenericType):
    """Generic ``Map<K, V>`` type parameterised by key and value types."""

    # Type of the keys.
    key_ty: Type
    # Type of the values.
    val_ty: Type

    @property
    @override
    def signature(self) -> str:
        """Return ``Map<K, V>`` built from the key and value signatures."""
        key_sig = self.key_ty.signature
        val_sig = self.val_ty.signature
        return f"Map<{key_sig}, {val_sig}>"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_map_type``."""
        visit = v.visit_map_type
        return visit(self)
241
242
@dataclass(frozen=True, repr=False)
class SetType(GenericType):
    """Generic ``Set<K>`` type parameterised by its key type."""

    # Type of the keys.
    key_ty: Type

    @property
    @override
    def signature(self) -> str:
        """Return ``Set<...>`` wrapping the key type's signature."""
        key_sig = self.key_ty.signature
        return f"Set<{key_sig}>"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_set_type``."""
        visit = v.visit_set_type
        return visit(self)
255
256
class GenericBuilder(Protocol):
    """Callable that builds a ``GenericType`` from a type reference and type arguments."""

    def __call__(self, ty_ref: "TypeRefDecl", *args: Type) -> GenericType:
        """Build the generic type for ``ty_ref`` using ``args`` as type arguments."""
        ...
260
261
def _generic_builder(generic_cls: type[GenericType]) -> GenericBuilder:
    """Return a builder that forwards its arguments to ``generic_cls``."""

    def build(ty_ref: "TypeRefDecl", *args: Type) -> GenericType:
        # Arity is not checked here; a wrong number of type arguments surfaces
        # as the TypeError raised by the dataclass constructor.
        return generic_cls(ty_ref, *args)

    return build


# Builtin Generics Map: generic name -> builder taking the type reference
# followed by the type arguments.
BUILTIN_GENERICS: dict[str, GenericBuilder] = {
    generic_name: _generic_builder(generic_cls)
    for generic_name, generic_cls in (
        ("Array", ArrayType),
        ("Optional", OptionalType),
        ("Vector", VectorType),
        ("Map", MapType),
        ("Set", SetType),
    )
}
270
271
##############
# User Types #
##############
275
276
@dataclass(frozen=True, repr=False)
class UserType(Type, metaclass=ABCMeta):
    """Abstract base for types that are backed by a ``TypeDecl``."""

    # The declaration this type refers to.
    ty_decl: "TypeDecl"

    @property
    @override
    def signature(self) -> str:
        """Return the ``full_name`` of the backing declaration, formatted as text."""
        decl = self.ty_decl
        return f"{decl.full_name}"
285
286
@dataclass(frozen=True, repr=False)
class EnumType(UserType):
    """User type whose declaration is an ``EnumDecl``."""

    # Narrows the inherited field to the enum declaration type.
    ty_decl: "EnumDecl"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_enum_type``."""
        visit = v.visit_enum_type
        return visit(self)
294
295
@dataclass(frozen=True, repr=False)
class StructType(UserType):
    """User type whose declaration is a ``StructDecl``."""

    # Narrows the inherited field to the struct declaration type.
    ty_decl: "StructDecl"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_struct_type``."""
        visit = v.visit_struct_type
        return visit(self)
303
304
@dataclass(frozen=True, repr=False)
class UnionType(UserType):
    """User type whose declaration is a ``UnionDecl``."""

    # Narrows the inherited field to the union declaration type.
    ty_decl: "UnionDecl"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_union_type``."""
        visit = v.visit_union_type
        return visit(self)
312
313
@dataclass(frozen=True, repr=False)
class IfaceType(UserType):
    """User type whose declaration is an ``IfaceDecl``."""

    # Narrows the inherited field to the interface declaration type.
    ty_decl: "IfaceDecl"

    @override
    def _accept(self, v: "TypeVisitor[T]") -> T:
        """Dispatch to ``v.visit_iface_type``."""
        visit = v.visit_iface_type
        return visit(self)