1 // Copyright 2019 The Amber Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef SRC_TYPE_H_ 16 #define SRC_TYPE_H_ 17 18 #include <cassert> 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "src/format_data.h" 24 #include "src/make_unique.h" 25 26 namespace amber { 27 namespace type { 28 29 class List; 30 class Number; 31 class Struct; 32 33 class Type { 34 public: 35 Type(); 36 virtual ~Type(); 37 IsSignedInt(FormatMode mode)38 static bool IsSignedInt(FormatMode mode) { 39 return mode == FormatMode::kSInt || mode == FormatMode::kSNorm || 40 mode == FormatMode::kSScaled; 41 } 42 IsUnsignedInt(FormatMode mode)43 static bool IsUnsignedInt(FormatMode mode) { 44 return mode == FormatMode::kUInt || mode == FormatMode::kUNorm || 45 mode == FormatMode::kUScaled || mode == FormatMode::kSRGB; 46 } 47 IsInt(FormatMode mode)48 static bool IsInt(FormatMode mode) { 49 return IsSignedInt(mode) || IsUnsignedInt(mode); 50 } 51 IsFloat(FormatMode mode)52 static bool IsFloat(FormatMode mode) { 53 return mode == FormatMode::kSFloat || mode == FormatMode::kUFloat; 54 } 55 IsInt8(FormatMode mode,uint32_t num_bits)56 static bool IsInt8(FormatMode mode, uint32_t num_bits) { 57 return IsSignedInt(mode) && num_bits == 8; 58 } IsInt16(FormatMode mode,uint32_t num_bits)59 static bool IsInt16(FormatMode mode, uint32_t num_bits) { 60 return IsSignedInt(mode) && num_bits == 16; 61 } IsInt32(FormatMode mode,uint32_t num_bits)62 static bool IsInt32(FormatMode mode, uint32_t num_bits) { 63 return IsSignedInt(mode) && num_bits == 32; 64 } IsInt64(FormatMode mode,uint32_t num_bits)65 static bool IsInt64(FormatMode mode, uint32_t num_bits) { 66 return IsSignedInt(mode) && num_bits == 64; 67 } 68 IsUint8(FormatMode mode,uint32_t num_bits)69 static bool IsUint8(FormatMode mode, uint32_t num_bits) { 70 return IsUnsignedInt(mode) && num_bits == 8; 71 } IsUint16(FormatMode mode,uint32_t num_bits)72 static bool IsUint16(FormatMode mode, uint32_t num_bits) { 73 return IsUnsignedInt(mode) && num_bits == 16; 74 } IsUint32(FormatMode mode,uint32_t num_bits)75 static bool IsUint32(FormatMode mode, uint32_t num_bits) { 76 return IsUnsignedInt(mode) && num_bits == 32; 77 } IsUint64(FormatMode mode,uint32_t num_bits)78 static bool IsUint64(FormatMode mode, uint32_t num_bits) { 79 return IsUnsignedInt(mode) && num_bits == 64; 80 } 81 IsFloat16(FormatMode mode,uint32_t num_bits)82 static bool IsFloat16(FormatMode mode, uint32_t num_bits) { 83 return IsFloat(mode) && num_bits == 16; 84 } IsFloat32(FormatMode mode,uint32_t num_bits)85 static bool IsFloat32(FormatMode mode, uint32_t num_bits) { 86 return IsFloat(mode) && num_bits == 32; 87 } IsFloat64(FormatMode mode,uint32_t num_bits)88 static bool IsFloat64(FormatMode mode, uint32_t num_bits) { 89 return IsFloat(mode) && num_bits == 64; 90 } 91 92 // Returns the size in bytes of a single element of the type. This does not 93 // include space for arrays, vectors, etc. 94 virtual uint32_t SizeInBytes() const = 0; 95 96 virtual bool Equal(const Type* b) const = 0; 97 IsList()98 virtual bool IsList() const { return false; } IsNumber()99 virtual bool IsNumber() const { return false; } IsStruct()100 virtual bool IsStruct() const { return false; } 101 102 List* AsList(); 103 Number* AsNumber(); 104 Struct* AsStruct(); 105 106 const List* AsList() const; 107 const Number* AsNumber() const; 108 const Struct* AsStruct() const; 109 SetRowCount(uint32_t size)110 void SetRowCount(uint32_t size) { row_count_ = size; } RowCount()111 uint32_t RowCount() const { return row_count_; } 112 SetColumnCount(uint32_t size)113 void SetColumnCount(uint32_t size) { column_count_ = size; } ColumnCount()114 uint32_t ColumnCount() const { return column_count_; } 115 SetIsRuntimeArray()116 void SetIsRuntimeArray() { is_array_ = true; } SetIsSizedArray(uint32_t size)117 void SetIsSizedArray(uint32_t size) { 118 is_array_ = true; 119 array_size_ = size; 120 } IsArray()121 bool IsArray() const { return is_array_; } IsSizedArray()122 bool IsSizedArray() const { return is_array_ && array_size_ > 0; } IsRuntimeArray()123 bool IsRuntimeArray() const { return is_array_ && array_size_ == 0; } ArraySize()124 uint32_t ArraySize() const { return array_size_; } 125 IsVec()126 bool IsVec() const { return column_count_ == 1 && row_count_ > 1; } 127 128 // Returns true if this type holds a vec3. IsVec3()129 bool IsVec3() const { return column_count_ == 1 && row_count_ == 3; } 130 131 // Returns true if this type holds a matrix. IsMatrix()132 bool IsMatrix() const { return column_count_ > 1 && row_count_ > 1; } 133 134 private: 135 uint32_t row_count_ = 1; 136 uint32_t column_count_ = 1; 137 uint32_t array_size_ = 0; 138 bool is_array_ = false; 139 }; 140 141 class Number : public Type { 142 public: 143 explicit Number(FormatMode mode); 144 Number(FormatMode mode, uint32_t bits); 145 ~Number() override; 146 147 static std::unique_ptr<Number> Int(uint32_t bits); 148 static std::unique_ptr<Number> Uint(uint32_t bits); 149 static std::unique_ptr<Number> Float(uint32_t bits); 150 IsNumber()151 bool IsNumber() const override { return true; } 152 NumBits()153 uint32_t NumBits() const { return bits_; } SizeInBytes()154 uint32_t SizeInBytes() const override { return bits_ / 8; } 155 Equal(const Type * b)156 bool Equal(const Type* b) const override { 157 if (!b->IsNumber()) 158 return false; 159 160 auto n = b->AsNumber(); 161 return format_mode_ == n->format_mode_ && bits_ == n->bits_; 162 } 163 GetFormatMode()164 FormatMode GetFormatMode() const { return format_mode_; } 165 166 private: 167 FormatMode format_mode_ = FormatMode::kSInt; 168 uint32_t bits_ = 32; 169 }; 170 171 // The list type only holds lists of scalar float and int values. 172 class List : public Type { 173 public: 174 struct Member { MemberMember175 Member(FormatComponentType t, FormatMode m, uint32_t b) 176 : name(t), mode(m), num_bits(b) {} 177 SizeInBytesMember178 uint32_t SizeInBytes() const { return num_bits / 8; } 179 180 FormatComponentType name = FormatComponentType::kR; 181 FormatMode mode = FormatMode::kSInt; 182 uint32_t num_bits = 0; 183 }; 184 185 List(); 186 ~List() override; 187 IsList()188 bool IsList() const override { return true; } 189 Equal(const Type * b)190 bool Equal(const Type* b) const override { 191 if (!b->IsList()) 192 return false; 193 194 auto l = b->AsList(); 195 if (pack_size_in_bits_ != l->pack_size_in_bits_) 196 return false; 197 if (members_.size() != l->members_.size()) 198 return false; 199 200 auto& lm = l->Members(); 201 for (size_t i = 0; i < members_.size(); ++i) { 202 if (members_[i].name != lm[i].name) 203 return false; 204 if (members_[i].mode != lm[i].mode) 205 return false; 206 if (members_[i].num_bits != lm[i].num_bits) 207 return false; 208 } 209 return true; 210 } 211 SetPackSizeInBits(uint32_t size)212 void SetPackSizeInBits(uint32_t size) { pack_size_in_bits_ = size; } PackSizeInBits()213 uint32_t PackSizeInBits() const { return pack_size_in_bits_; } IsPacked()214 bool IsPacked() const { return pack_size_in_bits_ > 0; } 215 AddMember(FormatComponentType name,FormatMode mode,uint32_t num_bits)216 void AddMember(FormatComponentType name, FormatMode mode, uint32_t num_bits) { 217 members_.push_back({name, mode, num_bits}); 218 } 219 Members()220 const std::vector<Member>& Members() const { return members_; } Members()221 std::vector<Member>& Members() { return members_; } 222 223 uint32_t SizeInBytes() const override; 224 225 private: 226 std::vector<Member> members_; 227 uint32_t pack_size_in_bits_ = 0; 228 }; 229 230 class Struct : public Type { 231 public: 232 struct Member { 233 std::string name; 234 Type* type; 235 int32_t offset_in_bytes = -1; 236 int32_t array_stride_in_bytes = -1; 237 int32_t matrix_stride_in_bytes = -1; 238 HasOffsetMember239 bool HasOffset() const { return offset_in_bytes >= 0; } HasArrayStrideMember240 bool HasArrayStride() const { return array_stride_in_bytes > 0; } HasMatrixStrideMember241 bool HasMatrixStride() const { return matrix_stride_in_bytes > 0; } 242 }; 243 244 Struct(); 245 ~Struct() override; 246 247 uint32_t SizeInBytes() const override; IsStruct()248 bool IsStruct() const override { return true; } 249 Equal(const Type * b)250 bool Equal(const Type* b) const override { 251 if (!b->IsStruct()) 252 return false; 253 254 auto s = b->AsStruct(); 255 if (is_stride_specified_ != s->is_stride_specified_) 256 return false; 257 if (stride_in_bytes_ != s->stride_in_bytes_) 258 return false; 259 if (members_.size() != s->members_.size()) 260 return false; 261 262 auto& sm = s->Members(); 263 for (size_t i = 0; i < members_.size(); ++i) { 264 if (members_[i].offset_in_bytes != sm[i].offset_in_bytes) 265 return false; 266 if (members_[i].array_stride_in_bytes != sm[i].array_stride_in_bytes) 267 return false; 268 if (members_[i].matrix_stride_in_bytes != sm[i].matrix_stride_in_bytes) 269 return false; 270 if (!members_[i].type->Equal(sm[i].type)) 271 return false; 272 } 273 return true; 274 } 275 HasStride()276 bool HasStride() const { return is_stride_specified_; } StrideInBytes()277 uint32_t StrideInBytes() const { return stride_in_bytes_; } SetStrideInBytes(uint32_t stride)278 void SetStrideInBytes(uint32_t stride) { 279 is_stride_specified_ = true; 280 stride_in_bytes_ = stride; 281 } 282 AddMember(Type * type)283 Member* AddMember(Type* type) { 284 members_.push_back({}); 285 members_.back().type = type; 286 return &members_.back(); 287 } 288 Members()289 const std::vector<Member>& Members() const { return members_; } 290 291 private: 292 std::vector<Member> members_; 293 bool is_stride_specified_ = false; 294 uint32_t stride_in_bytes_ = 0; 295 }; 296 297 } // namespace type 298 } // namespace amber 299 300 #endif // SRC_TYPE_H_ 301