1# coding=utf-8 2# 3# Copyright (c) 2025 Huawei Device Co., Ltd. 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15 16"""Defines the type system.""" 17 18from abc import ABCMeta, abstractmethod 19from dataclasses import dataclass 20from enum import Enum 21from typing import TYPE_CHECKING, Protocol, TypeVar 22 23from typing_extensions import override 24 25if TYPE_CHECKING: 26 from taihe.semantics.declarations import ( 27 EnumDecl, 28 IfaceDecl, 29 StructDecl, 30 TypeDecl, 31 TypeRefDecl, 32 UnionDecl, 33 ) 34 from taihe.semantics.visitor import TypeVisitor 35 36T = TypeVar("T") 37 38############################ 39# Infrastructure for Types # 40############################ 41 42 43class TypeProtocol(Protocol): 44 def _accept(self, v: "TypeVisitor[T]") -> T: 45 ... 46 47 48@dataclass(frozen=True, repr=False) 49class Type(metaclass=ABCMeta): 50 """Base class for all types.""" 51 52 ty_ref: "TypeRefDecl" 53 54 def __repr__(self) -> str: 55 return f"<{self.__class__.__qualname__} {self.signature}>" 56 57 @property 58 @abstractmethod 59 def signature(self) -> str: 60 """Return the representation of the type.""" 61 62 @abstractmethod 63 def _accept(self, v: "TypeVisitor[T]") -> T: 64 """Accept a visitor.""" 65 66 67################## 68# Built-in Types # 69################## 70 71 72@dataclass(frozen=True, repr=False) 73class BuiltinType(Type, metaclass=ABCMeta): 74 """Represents a built-in type.""" 75 76 77class ScalarKind(Enum): 78 """Enumeration of scalar types.""" 79 80 BOOL = ("bool", 8, False, False) 81 F32 = ("f32", 32, True, True) 82 F64 = ("f64", 64, True, True) 83 I8 = ("i8", 8, True, False) 84 I16 = ("i16", 16, True, False) 85 I32 = ("i32", 32, True, False) 86 I64 = ("i64", 64, True, False) 87 U8 = ("u8", 8, False, False) 88 U16 = ("u16", 16, False, False) 89 U32 = ("u32", 32, False, False) 90 U64 = ("u64", 64, False, False) 91 92 def __init__(self, symbol: str, width: int, is_signed: bool, is_float: bool): 93 self.symbol = symbol 94 self.width = width 95 self.is_signed = is_signed 96 self.is_float = is_float 97 98 99@dataclass(frozen=True, repr=False) 100class ScalarType(BuiltinType): 101 kind: ScalarKind 102 103 @property 104 @override 105 def signature(self): 106 return self.kind.symbol 107 108 @override 109 def _accept(self, v: "TypeVisitor[T]") -> T: 110 return v.visit_scalar_type(self) 111 112 113@dataclass(frozen=True, repr=False) 114class OpaqueType(BuiltinType): 115 @property 116 @override 117 def signature(self): 118 return "Opaque" 119 120 @override 121 def _accept(self, v: "TypeVisitor[T]") -> T: 122 return v.visit_opaque_type(self) 123 124 125@dataclass(frozen=True, repr=False) 126class StringType(BuiltinType): 127 @property 128 @override 129 def signature(self): 130 return "String" 131 132 @override 133 def _accept(self, v: "TypeVisitor[T]") -> T: 134 return v.visit_string_type(self) 135 136 137class BuiltinBuilder(Protocol): 138 def __call__(self, ty_ref: "TypeRefDecl") -> BuiltinType: 139 ... 140 141 142# Builtin Types Map 143BUILTIN_TYPES: dict[str, BuiltinBuilder] = { 144 "bool": lambda ty_ref: ScalarType(ty_ref, ScalarKind.BOOL), 145 "f32": lambda ty_ref: ScalarType(ty_ref, ScalarKind.F32), 146 "f64": lambda ty_ref: ScalarType(ty_ref, ScalarKind.F64), 147 "i8": lambda ty_ref: ScalarType(ty_ref, ScalarKind.I8), 148 "i16": lambda ty_ref: ScalarType(ty_ref, ScalarKind.I16), 149 "i32": lambda ty_ref: ScalarType(ty_ref, ScalarKind.I32), 150 "i64": lambda ty_ref: ScalarType(ty_ref, ScalarKind.I64), 151 "u8": lambda ty_ref: ScalarType(ty_ref, ScalarKind.U8), 152 "u16": lambda ty_ref: ScalarType(ty_ref, ScalarKind.U16), 153 "u32": lambda ty_ref: ScalarType(ty_ref, ScalarKind.U32), 154 "u64": lambda ty_ref: ScalarType(ty_ref, ScalarKind.U64), 155 "String": lambda ty_ref: StringType(ty_ref), 156 "Opaque": lambda ty_ref: OpaqueType(ty_ref), 157} 158 159 160#################### 161# Builtin Generics # 162#################### 163 164 165@dataclass(frozen=True, repr=False) 166class CallbackType(Type): 167 return_ty: Type | None 168 params_ty: tuple[Type, ...] 169 170 @property 171 @override 172 def signature(self): 173 return_fmt = ty.signature if (ty := self.return_ty) else "void" 174 params_fmt = ", ".join(ty.signature for ty in self.params_ty) 175 return f"({params_fmt}) => {return_fmt}" 176 177 @override 178 def _accept(self, v: "TypeVisitor[T]") -> T: 179 return v.visit_callback_type(self) 180 181 182class GenericType(Type, metaclass=ABCMeta): 183 pass 184 185 186@dataclass(frozen=True, repr=False) 187class ArrayType(GenericType): 188 item_ty: Type 189 190 @property 191 @override 192 def signature(self): 193 return f"Array<{self.item_ty.signature}>" 194 195 @override 196 def _accept(self, v: "TypeVisitor[T]") -> T: 197 return v.visit_array_type(self) 198 199 200@dataclass(frozen=True, repr=False) 201class OptionalType(GenericType): 202 item_ty: Type 203 204 @property 205 @override 206 def signature(self): 207 return f"Optional<{self.item_ty.signature}>" 208 209 @override 210 def _accept(self, v: "TypeVisitor[T]") -> T: 211 return v.visit_optional_type(self) 212 213 214@dataclass(frozen=True, repr=False) 215class VectorType(GenericType): 216 val_ty: Type 217 218 @property 219 @override 220 def signature(self): 221 return f"Vector<{self.val_ty.signature}>" 222 223 @override 224 def _accept(self, v: "TypeVisitor[T]") -> T: 225 return v.visit_vector_type(self) 226 227 228@dataclass(frozen=True, repr=False) 229class MapType(GenericType): 230 key_ty: Type 231 val_ty: Type 232 233 @property 234 @override 235 def signature(self): 236 return f"Map<{self.key_ty.signature}, {self.val_ty.signature}>" 237 238 @override 239 def _accept(self, v: "TypeVisitor[T]") -> T: 240 return v.visit_map_type(self) 241 242 243@dataclass(frozen=True, repr=False) 244class SetType(GenericType): 245 key_ty: Type 246 247 @property 248 @override 249 def signature(self): 250 return f"Set<{self.key_ty.signature}>" 251 252 @override 253 def _accept(self, v: "TypeVisitor[T]") -> T: 254 return v.visit_set_type(self) 255 256 257class GenericBuilder(Protocol): 258 def __call__(self, ty_ref: "TypeRefDecl", *args: Type) -> GenericType: 259 ... 260 261 262# Builtin Generics Map 263BUILTIN_GENERICS: dict[str, GenericBuilder] = { 264 "Array": lambda ty_ref, *args: ArrayType(ty_ref, *args), 265 "Optional": lambda ty_ref, *args: OptionalType(ty_ref, *args), 266 "Vector": lambda ty_ref, *args: VectorType(ty_ref, *args), 267 "Map": lambda ty_ref, *args: MapType(ty_ref, *args), 268 "Set": lambda ty_ref, *args: SetType(ty_ref, *args), 269} 270 271 272############## 273# User Types # 274############## 275 276 277@dataclass(frozen=True, repr=False) 278class UserType(Type, metaclass=ABCMeta): 279 ty_decl: "TypeDecl" 280 281 @property 282 @override 283 def signature(self): 284 return f"{self.ty_decl.full_name}" 285 286 287@dataclass(frozen=True, repr=False) 288class EnumType(UserType): 289 ty_decl: "EnumDecl" 290 291 @override 292 def _accept(self, v: "TypeVisitor[T]") -> T: 293 return v.visit_enum_type(self) 294 295 296@dataclass(frozen=True, repr=False) 297class StructType(UserType): 298 ty_decl: "StructDecl" 299 300 @override 301 def _accept(self, v: "TypeVisitor[T]") -> T: 302 return v.visit_struct_type(self) 303 304 305@dataclass(frozen=True, repr=False) 306class UnionType(UserType): 307 ty_decl: "UnionDecl" 308 309 @override 310 def _accept(self, v: "TypeVisitor[T]") -> T: 311 return v.visit_union_type(self) 312 313 314@dataclass(frozen=True, repr=False) 315class IfaceType(UserType): 316 ty_decl: "IfaceDecl" 317 318 @override 319 def _accept(self, v: "TypeVisitor[T]") -> T: 320 return v.visit_iface_type(self)