• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/intrinsic_table.h"
16 
17 #include <algorithm>
18 #include <limits>
19 #include <unordered_map>
20 #include <utility>
21 
22 #include "src/program_builder.h"
23 #include "src/sem/atomic_type.h"
24 #include "src/sem/depth_multisampled_texture_type.h"
25 #include "src/sem/depth_texture_type.h"
26 #include "src/sem/external_texture_type.h"
27 #include "src/sem/multisampled_texture_type.h"
28 #include "src/sem/pipeline_stage_set.h"
29 #include "src/sem/sampled_texture_type.h"
30 #include "src/sem/storage_texture_type.h"
31 #include "src/utils/hash.h"
32 #include "src/utils/map.h"
33 #include "src/utils/math.h"
34 #include "src/utils/scoped_assignment.h"
35 
36 namespace tint {
37 namespace {
38 
// Forward declarations, so that the types below can refer to one another
// before they are defined.
class Matchers;
class NumberMatcher;
class TypeMatcher;
struct OverloadInfo;
44 
45 /// A special type that matches all TypeMatchers
46 class Any : public Castable<Any, sem::Type> {
47  public:
48   Any() = default;
49   ~Any() override = default;
type_name() const50   std::string type_name() const override { return "<any>"; }
FriendlyName(const SymbolTable &) const51   std::string FriendlyName(const SymbolTable&) const override {
52     return "<any>";
53   }
54 };
55 
/// Number is a 32-bit unsigned integer that is always in one of three states:
/// * Invalid - no value has been assigned to the Number
/// * Valid   - the Number holds a fixed integer value
/// * Any     - a wildcard that matches any other non-invalid number
struct Number {
 private:
  /// The state of a Number.
  enum State {
    kInvalid,
    kValid,
    kAny,
  };

 public:
  /// The wildcard number (state Any).
  static const Number any;
  /// The unassigned number (state Invalid).
  static const Number invalid;

  /// Constructs a valid number holding `v`.
  /// @param v the value of the number
  explicit Number(uint32_t v) : value_(v), state_(kValid) {}

  /// @returns the held value. Numbers that were never given a value (such as
  /// `any` and `invalid`) report 0.
  uint32_t Value() const { return value_; }

  /// @returns true if this number holds a fixed value
  bool IsValid() const { return kValid == state_; }

  /// @returns true if this number is the `any` wildcard
  bool IsAny() const { return kAny == state_; }

  /// Assigns `n` to this number, which also makes the number valid.
  /// @param n the new value
  /// @returns this number
  Number& operator=(uint32_t n) {
    *this = Number(n);
    return *this;
  }

 private:
  /// Constructs a number in the given state, with a value of 0. Only used to
  /// build the `any` and `invalid` constants.
  constexpr explicit Number(State state) : state_(state) {}

  uint32_t value_ = 0;
  State state_ = kInvalid;
};

const Number Number::any{Number::kAny};
const Number Number::invalid{Number::kInvalid};
99 
100 /// ClosedState holds the state of the open / closed numbers and types.
101 /// Used by the MatchState.
102 class ClosedState {
103  public:
ClosedState(ProgramBuilder & b)104   explicit ClosedState(ProgramBuilder& b) : builder(b) {}
105 
106   /// If the type with index `idx` is open, then it is closed with type `ty` and
107   /// Type() returns true. If the type is closed, then `Type()` returns true iff
108   /// it is equal to `ty`.
Type(uint32_t idx,const sem::Type * ty)109   bool Type(uint32_t idx, const sem::Type* ty) {
110     auto res = types_.emplace(idx, ty);
111     return res.second || res.first->second == ty;
112   }
113 
114   /// If the number with index `idx` is open, then it is closed with number
115   /// `number` and Num() returns true. If the number is closed, then `Num()`
116   /// returns true iff it is equal to `ty`.
Num(uint32_t idx,Number number)117   bool Num(uint32_t idx, Number number) {
118     auto res = numbers_.emplace(idx, number.Value());
119     return res.second || res.first->second == number.Value();
120   }
121 
122   /// Type returns the closed type with index `idx`.
123   /// An ICE is raised if the type is not closed.
Type(uint32_t idx) const124   const sem::Type* Type(uint32_t idx) const {
125     auto it = types_.find(idx);
126     if (it == types_.end()) {
127       TINT_ICE(Resolver, builder.Diagnostics())
128           << "type with index " << idx << " is not closed";
129       return nullptr;
130     }
131     TINT_ASSERT(Resolver, it != types_.end());
132     return it->second;
133   }
134 
135   /// Type returns the number type with index `idx`.
136   /// An ICE is raised if the number is not closed.
Num(uint32_t idx) const137   Number Num(uint32_t idx) const {
138     auto it = numbers_.find(idx);
139     if (it == numbers_.end()) {
140       TINT_ICE(Resolver, builder.Diagnostics())
141           << "number with index " << idx << " is not closed";
142       return Number::invalid;
143     }
144     return Number(it->second);
145   }
146 
147  private:
148   ProgramBuilder& builder;
149   std::unordered_map<uint32_t, const sem::Type*> types_;
150   std::unordered_map<uint32_t, uint32_t> numbers_;
151 };
152 
/// MatcherIndex is the type of an index into the matcher tables.
using MatcherIndex = uint8_t;

/// kNoMatcher is the MatcherIndex used for open types / numbers that do not
/// have a constraint. It is the largest value a MatcherIndex can hold.
constexpr MatcherIndex kNoMatcher = std::numeric_limits<MatcherIndex>::max();
158 
159 /// MatchState holds the state used to match an overload.
160 class MatchState {
161  public:
MatchState(ProgramBuilder & b,ClosedState & c,const Matchers & m,const OverloadInfo & o,MatcherIndex const * matcher_indices)162   MatchState(ProgramBuilder& b,
163              ClosedState& c,
164              const Matchers& m,
165              const OverloadInfo& o,
166              MatcherIndex const* matcher_indices)
167       : builder(b),
168         closed(c),
169         matchers(m),
170         overload(o),
171         matcher_indices_(matcher_indices) {}
172 
173   /// The program builder
174   ProgramBuilder& builder;
175   /// The open / closed types and numbers
176   ClosedState& closed;
177   /// The type and number matchers
178   Matchers const& matchers;
179   /// The current overload being evaluated
180   OverloadInfo const& overload;
181 
182   /// Type uses the next TypeMatcher from the matcher indices to match the type
183   /// `ty`. If the type matches, the canonical expected type is returned. If the
184   /// type `ty` does not match, then nullptr is returned.
185   /// @note: The matcher indices are progressed on calling.
186   const sem::Type* Type(const sem::Type* ty);
187 
188   /// Num uses the next NumMatcher from the matcher indices to match the number
189   /// `num`. If the number matches, the canonical expected number is returned.
190   /// If the number `num` does not match, then an invalid number is returned.
191   /// @note: The matcher indices are progressed on calling.
192   Number Num(Number num);
193 
194   /// @returns a string representation of the next TypeMatcher from the matcher
195   /// indices.
196   /// @note: The matcher indices are progressed on calling.
197   std::string TypeName();
198 
199   /// @returns a string representation of the next NumberMatcher from the
200   /// matcher indices.
201   /// @note: The matcher indices are progressed on calling.
202   std::string NumName();
203 
204  private:
205   MatcherIndex const* matcher_indices_ = nullptr;
206 };
207 
208 /// A TypeMatcher is the interface used to match an type used as part of an
209 /// overload's parameter or return type.
210 class TypeMatcher {
211  public:
212   /// Destructor
213   virtual ~TypeMatcher() = default;
214 
215   /// Checks whether the given type matches the matcher rules, and returns the
216   /// expected, canonicalized type on success.
217   /// Match may close open types and numbers in state.
218   /// @param type the type to match
219   /// @returns the canonicalized type on match, otherwise nullptr
220   virtual const sem::Type* Match(MatchState& state,
221                                  const sem::Type* type) const = 0;
222 
223   /// @return a string representation of the matcher. Used for printing error
224   /// messages when no overload is found.
225   virtual std::string String(MatchState& state) const = 0;
226 };
227 
228 /// A NumberMatcher is the interface used to match a number or enumerator used
229 /// as part of an overload's parameter or return type.
230 class NumberMatcher {
231  public:
232   /// Destructor
233   virtual ~NumberMatcher() = default;
234 
235   /// Checks whether the given number matches the matcher rules.
236   /// Match may close open numbers in state.
237   /// @param number the number to match
238   /// @returns true if the argument type is as expected.
239   virtual Number Match(MatchState& state, Number number) const = 0;
240 
241   /// @return a string representation of the matcher. Used for printing error
242   /// messages when no overload is found.
243   virtual std::string String(MatchState& state) const = 0;
244 };
245 
246 /// OpenTypeMatcher is a Matcher for an open type.
247 /// The OpenTypeMatcher will match against any type (so long as it is consistent
248 /// across all uses in the overload)
249 class OpenTypeMatcher : public TypeMatcher {
250  public:
251   /// Constructor
OpenTypeMatcher(uint32_t index)252   explicit OpenTypeMatcher(uint32_t index) : index_(index) {}
253 
Match(MatchState & state,const sem::Type * type) const254   const sem::Type* Match(MatchState& state,
255                          const sem::Type* type) const override {
256     if (type->Is<Any>()) {
257       return state.closed.Type(index_);
258     }
259     return state.closed.Type(index_, type) ? type : nullptr;
260   }
261 
262   std::string String(MatchState& state) const override;
263 
264  private:
265   uint32_t index_;
266 };
267 
268 /// OpenNumberMatcher is a Matcher for an open number.
269 /// The OpenNumberMatcher will match against any number (so long as it is
270 /// consistent for the overload)
271 class OpenNumberMatcher : public NumberMatcher {
272  public:
OpenNumberMatcher(uint32_t index)273   explicit OpenNumberMatcher(uint32_t index) : index_(index) {}
274 
Match(MatchState & state,Number number) const275   Number Match(MatchState& state, Number number) const override {
276     if (number.IsAny()) {
277       return state.closed.Num(index_);
278     }
279     return state.closed.Num(index_, number) ? number : Number::invalid;
280   }
281 
282   std::string String(MatchState& state) const override;
283 
284  private:
285   uint32_t index_;
286 };
287 
288 ////////////////////////////////////////////////////////////////////////////////
289 // Binding functions for use in the generated intrinsic_table.inl
290 // TODO(bclayton): See if we can move more of this hand-rolled code to the
291 // template
292 ////////////////////////////////////////////////////////////////////////////////
293 using TexelFormat = ast::ImageFormat;
294 using Access = ast::Access;
295 using StorageClass = ast::StorageClass;
296 using ParameterUsage = sem::ParameterUsage;
297 using PipelineStageSet = sem::PipelineStageSet;
298 using PipelineStage = ast::PipelineStage;
299 
match_bool(const sem::Type * ty)300 bool match_bool(const sem::Type* ty) {
301   return ty->IsAnyOf<Any, sem::Bool>();
302 }
303 
build_bool(MatchState & state)304 const sem::Bool* build_bool(MatchState& state) {
305   return state.builder.create<sem::Bool>();
306 }
307 
match_f32(const sem::Type * ty)308 bool match_f32(const sem::Type* ty) {
309   return ty->IsAnyOf<Any, sem::F32>();
310 }
311 
build_i32(MatchState & state)312 const sem::I32* build_i32(MatchState& state) {
313   return state.builder.create<sem::I32>();
314 }
315 
match_i32(const sem::Type * ty)316 bool match_i32(const sem::Type* ty) {
317   return ty->IsAnyOf<Any, sem::I32>();
318 }
319 
build_u32(MatchState & state)320 const sem::U32* build_u32(MatchState& state) {
321   return state.builder.create<sem::U32>();
322 }
323 
match_u32(const sem::Type * ty)324 bool match_u32(const sem::Type* ty) {
325   return ty->IsAnyOf<Any, sem::U32>();
326 }
327 
build_f32(MatchState & state)328 const sem::F32* build_f32(MatchState& state) {
329   return state.builder.create<sem::F32>();
330 }
331 
match_vec(const sem::Type * ty,Number & N,const sem::Type * & T)332 bool match_vec(const sem::Type* ty, Number& N, const sem::Type*& T) {
333   if (ty->Is<Any>()) {
334     N = Number::any;
335     T = ty;
336     return true;
337   }
338 
339   if (auto* v = ty->As<sem::Vector>()) {
340     N = v->Width();
341     T = v->type();
342     return true;
343   }
344   return false;
345 }
346 
build_vec(MatchState & state,Number N,const sem::Type * el)347 const sem::Vector* build_vec(MatchState& state, Number N, const sem::Type* el) {
348   return state.builder.create<sem::Vector>(el, N.Value());
349 }
350 
351 template <int N>
match_vec(const sem::Type * ty,const sem::Type * & T)352 bool match_vec(const sem::Type* ty, const sem::Type*& T) {
353   if (ty->Is<Any>()) {
354     T = ty;
355     return true;
356   }
357 
358   if (auto* v = ty->As<sem::Vector>()) {
359     if (v->Width() == N) {
360       T = v->type();
361       return true;
362     }
363   }
364   return false;
365 }
366 
match_vec2(const sem::Type * ty,const sem::Type * & T)367 bool match_vec2(const sem::Type* ty, const sem::Type*& T) {
368   return match_vec<2>(ty, T);
369 }
370 
build_vec2(MatchState & state,const sem::Type * T)371 const sem::Vector* build_vec2(MatchState& state, const sem::Type* T) {
372   return build_vec(state, Number(2), T);
373 }
374 
match_vec3(const sem::Type * ty,const sem::Type * & T)375 bool match_vec3(const sem::Type* ty, const sem::Type*& T) {
376   return match_vec<3>(ty, T);
377 }
378 
build_vec3(MatchState & state,const sem::Type * T)379 const sem::Vector* build_vec3(MatchState& state, const sem::Type* T) {
380   return build_vec(state, Number(3), T);
381 }
382 
match_vec4(const sem::Type * ty,const sem::Type * & T)383 bool match_vec4(const sem::Type* ty, const sem::Type*& T) {
384   return match_vec<4>(ty, T);
385 }
386 
build_vec4(MatchState & state,const sem::Type * T)387 const sem::Vector* build_vec4(MatchState& state, const sem::Type* T) {
388   return build_vec(state, Number(4), T);
389 }
390 
match_mat(const sem::Type * ty,Number & M,Number & N,const sem::Type * & T)391 bool match_mat(const sem::Type* ty, Number& M, Number& N, const sem::Type*& T) {
392   if (ty->Is<Any>()) {
393     M = Number::any;
394     N = Number::any;
395     T = ty;
396     return true;
397   }
398   if (auto* m = ty->As<sem::Matrix>()) {
399     M = m->columns();
400     N = m->ColumnType()->Width();
401     T = m->type();
402     return true;
403   }
404   return false;
405 }
406 
build_mat(MatchState & state,Number N,Number M,const sem::Type * T)407 const sem::Matrix* build_mat(MatchState& state,
408                              Number N,
409                              Number M,
410                              const sem::Type* T) {
411   auto* column_type = state.builder.create<sem::Vector>(T, M.Value());
412   return state.builder.create<sem::Matrix>(column_type, N.Value());
413 }
414 
match_array(const sem::Type * ty,const sem::Type * & T)415 bool match_array(const sem::Type* ty, const sem::Type*& T) {
416   if (ty->Is<Any>()) {
417     T = ty;
418     return true;
419   }
420 
421   if (auto* a = ty->As<sem::Array>()) {
422     if (a->Count() == 0) {
423       T = a->ElemType();
424       return true;
425     }
426   }
427   return false;
428 }
429 
build_array(MatchState & state,const sem::Type * el)430 const sem::Array* build_array(MatchState& state, const sem::Type* el) {
431   return state.builder.create<sem::Array>(el,
432                                           /* count */ 0,
433                                           /* align */ 0,
434                                           /* size */ 0,
435                                           /* stride */ 0,
436                                           /* stride_implicit */ 0);
437 }
438 
match_ptr(const sem::Type * ty,Number & S,const sem::Type * & T,Number & A)439 bool match_ptr(const sem::Type* ty, Number& S, const sem::Type*& T, Number& A) {
440   if (ty->Is<Any>()) {
441     S = Number::any;
442     T = ty;
443     A = Number::any;
444     return true;
445   }
446 
447   if (auto* p = ty->As<sem::Pointer>()) {
448     S = Number(static_cast<uint32_t>(p->StorageClass()));
449     T = p->StoreType();
450     A = Number(static_cast<uint32_t>(p->Access()));
451     return true;
452   }
453   return false;
454 }
455 
build_ptr(MatchState & state,Number S,const sem::Type * T,Number & A)456 const sem::Pointer* build_ptr(MatchState& state,
457                               Number S,
458                               const sem::Type* T,
459                               Number& A) {
460   return state.builder.create<sem::Pointer>(
461       T, static_cast<ast::StorageClass>(S.Value()),
462       static_cast<ast::Access>(A.Value()));
463 }
464 
match_atomic(const sem::Type * ty,const sem::Type * & T)465 bool match_atomic(const sem::Type* ty, const sem::Type*& T) {
466   if (ty->Is<Any>()) {
467     T = ty;
468     return true;
469   }
470 
471   if (auto* a = ty->As<sem::Atomic>()) {
472     T = a->Type();
473     return true;
474   }
475   return false;
476 }
477 
build_atomic(MatchState & state,const sem::Type * T)478 const sem::Atomic* build_atomic(MatchState& state, const sem::Type* T) {
479   return state.builder.create<sem::Atomic>(T);
480 }
481 
match_sampler(const sem::Type * ty)482 bool match_sampler(const sem::Type* ty) {
483   if (ty->Is<Any>()) {
484     return true;
485   }
486   return ty->Is([](const sem::Sampler* s) {
487     return s->kind() == ast::SamplerKind::kSampler;
488   });
489 }
490 
build_sampler(MatchState & state)491 const sem::Sampler* build_sampler(MatchState& state) {
492   return state.builder.create<sem::Sampler>(ast::SamplerKind::kSampler);
493 }
494 
match_sampler_comparison(const sem::Type * ty)495 bool match_sampler_comparison(const sem::Type* ty) {
496   if (ty->Is<Any>()) {
497     return true;
498   }
499   return ty->Is([](const sem::Sampler* s) {
500     return s->kind() == ast::SamplerKind::kComparisonSampler;
501   });
502 }
503 
build_sampler_comparison(MatchState & state)504 const sem::Sampler* build_sampler_comparison(MatchState& state) {
505   return state.builder.create<sem::Sampler>(
506       ast::SamplerKind::kComparisonSampler);
507 }
508 
match_texture(const sem::Type * ty,ast::TextureDimension dim,const sem::Type * & T)509 bool match_texture(const sem::Type* ty,
510                    ast::TextureDimension dim,
511                    const sem::Type*& T) {
512   if (ty->Is<Any>()) {
513     T = ty;
514     return true;
515   }
516   if (auto* v = ty->As<sem::SampledTexture>()) {
517     if (v->dim() == dim) {
518       T = v->type();
519       return true;
520     }
521   }
522   return false;
523 }
524 
525 #define JOIN(a, b) a##b
526 
527 #define DECLARE_SAMPLED_TEXTURE(suffix, dim)                  \
528   bool JOIN(match_texture_, suffix)(const sem::Type* ty,      \
529                                     const sem::Type*& T) {    \
530     return match_texture(ty, dim, T);                         \
531   }                                                           \
532   const sem::SampledTexture* JOIN(build_texture_, suffix)(    \
533       MatchState & state, const sem::Type* T) {               \
534     return state.builder.create<sem::SampledTexture>(dim, T); \
535   }
536 
537 DECLARE_SAMPLED_TEXTURE(1d, ast::TextureDimension::k1d)
538 DECLARE_SAMPLED_TEXTURE(2d, ast::TextureDimension::k2d)
539 DECLARE_SAMPLED_TEXTURE(2d_array, ast::TextureDimension::k2dArray)
540 DECLARE_SAMPLED_TEXTURE(3d, ast::TextureDimension::k3d)
DECLARE_SAMPLED_TEXTURE(cube,ast::TextureDimension::kCube)541 DECLARE_SAMPLED_TEXTURE(cube, ast::TextureDimension::kCube)
542 DECLARE_SAMPLED_TEXTURE(cube_array, ast::TextureDimension::kCubeArray)
543 #undef DECLARE_SAMPLED_TEXTURE
544 
545 bool match_texture_multisampled(const sem::Type* ty,
546                                 ast::TextureDimension dim,
547                                 const sem::Type*& T) {
548   if (ty->Is<Any>()) {
549     T = ty;
550     return true;
551   }
552   if (auto* v = ty->As<sem::MultisampledTexture>()) {
553     if (v->dim() == dim) {
554       T = v->type();
555       return true;
556     }
557   }
558   return false;
559 }
560 
561 #define DECLARE_MULTISAMPLED_TEXTURE(suffix, dim)                            \
562   bool JOIN(match_texture_multisampled_, suffix)(const sem::Type* ty,        \
563                                                  const sem::Type*& T) {      \
564     return match_texture_multisampled(ty, dim, T);                           \
565   }                                                                          \
566   const sem::MultisampledTexture* JOIN(build_texture_multisampled_, suffix)( \
567       MatchState & state, const sem::Type* T) {                              \
568     return state.builder.create<sem::MultisampledTexture>(dim, T);           \
569   }
570 
571 DECLARE_MULTISAMPLED_TEXTURE(2d, ast::TextureDimension::k2d)
572 #undef DECLARE_MULTISAMPLED_TEXTURE
573 
match_texture_depth(const sem::Type * ty,ast::TextureDimension dim)574 bool match_texture_depth(const sem::Type* ty, ast::TextureDimension dim) {
575   if (ty->Is<Any>()) {
576     return true;
577   }
578   return ty->Is([&](const sem::DepthTexture* t) { return t->dim() == dim; });
579 }
580 
581 #define DECLARE_DEPTH_TEXTURE(suffix, dim)                       \
582   bool JOIN(match_texture_depth_, suffix)(const sem::Type* ty) { \
583     return match_texture_depth(ty, dim);                         \
584   }                                                              \
585   const sem::DepthTexture* JOIN(build_texture_depth_,            \
586                                 suffix)(MatchState & state) {    \
587     return state.builder.create<sem::DepthTexture>(dim);         \
588   }
589 
590 DECLARE_DEPTH_TEXTURE(2d, ast::TextureDimension::k2d)
591 DECLARE_DEPTH_TEXTURE(2d_array, ast::TextureDimension::k2dArray)
DECLARE_DEPTH_TEXTURE(cube,ast::TextureDimension::kCube)592 DECLARE_DEPTH_TEXTURE(cube, ast::TextureDimension::kCube)
593 DECLARE_DEPTH_TEXTURE(cube_array, ast::TextureDimension::kCubeArray)
594 #undef DECLARE_DEPTH_TEXTURE
595 
596 bool match_texture_depth_multisampled_2d(const sem::Type* ty) {
597   if (ty->Is<Any>()) {
598     return true;
599   }
600   return ty->Is([&](const sem::DepthMultisampledTexture* t) {
601     return t->dim() == ast::TextureDimension::k2d;
602   });
603 }
604 
build_texture_depth_multisampled_2d(MatchState & state)605 sem::DepthMultisampledTexture* build_texture_depth_multisampled_2d(
606     MatchState& state) {
607   return state.builder.create<sem::DepthMultisampledTexture>(
608       ast::TextureDimension::k2d);
609 }
610 
match_texture_storage(const sem::Type * ty,ast::TextureDimension dim,Number & F,Number & A)611 bool match_texture_storage(const sem::Type* ty,
612                            ast::TextureDimension dim,
613                            Number& F,
614                            Number& A) {
615   if (ty->Is<Any>()) {
616     F = Number::any;
617     A = Number::any;
618     return true;
619   }
620   if (auto* v = ty->As<sem::StorageTexture>()) {
621     if (v->dim() == dim) {
622       F = Number(static_cast<uint32_t>(v->image_format()));
623       A = Number(static_cast<uint32_t>(v->access()));
624       return true;
625     }
626   }
627   return false;
628 }
629 
630 #define DECLARE_STORAGE_TEXTURE(suffix, dim)                                  \
631   bool JOIN(match_texture_storage_, suffix)(const sem::Type* ty, Number& F,   \
632                                             Number& A) {                      \
633     return match_texture_storage(ty, dim, F, A);                              \
634   }                                                                           \
635   const sem::StorageTexture* JOIN(build_texture_storage_, suffix)(            \
636       MatchState & state, Number F, Number A) {                               \
637     auto format = static_cast<TexelFormat>(F.Value());                        \
638     auto access = static_cast<Access>(A.Value());                             \
639     auto* T = sem::StorageTexture::SubtypeFor(format, state.builder.Types()); \
640     return state.builder.create<sem::StorageTexture>(dim, format, access, T); \
641   }
642 
643 DECLARE_STORAGE_TEXTURE(1d, ast::TextureDimension::k1d)
644 DECLARE_STORAGE_TEXTURE(2d, ast::TextureDimension::k2d)
645 DECLARE_STORAGE_TEXTURE(2d_array, ast::TextureDimension::k2dArray)
646 DECLARE_STORAGE_TEXTURE(3d, ast::TextureDimension::k3d)
647 #undef DECLARE_STORAGE_TEXTURE
648 
match_texture_external(const sem::Type * ty)649 bool match_texture_external(const sem::Type* ty) {
650   return ty->IsAnyOf<Any, sem::ExternalTexture>();
651 }
652 
build_texture_external(MatchState & state)653 const sem::ExternalTexture* build_texture_external(MatchState& state) {
654   return state.builder.create<sem::ExternalTexture>();
655 }
656 
657 // Builtin types starting with a _ prefix cannot be declared in WGSL, so they
658 // can only be used as return types. Because of this, they must only match Any,
659 // which is used as the return type matcher.
match_modf_result(const sem::Type * ty)660 bool match_modf_result(const sem::Type* ty) {
661   return ty->Is<Any>();
662 }
match_modf_result_vec(const sem::Type * ty,Number & N)663 bool match_modf_result_vec(const sem::Type* ty, Number& N) {
664   if (!ty->Is<Any>()) {
665     return false;
666   }
667   N = Number::any;
668   return true;
669 }
match_frexp_result(const sem::Type * ty)670 bool match_frexp_result(const sem::Type* ty) {
671   return ty->Is<Any>();
672 }
match_frexp_result_vec(const sem::Type * ty,Number & N)673 bool match_frexp_result_vec(const sem::Type* ty, Number& N) {
674   if (!ty->Is<Any>()) {
675     return false;
676   }
677   N = Number::any;
678   return true;
679 }
680 
681 struct NameAndType {
682   std::string name;
683   sem::Type* type;
684 };
build_struct(MatchState & state,std::string name,std::initializer_list<NameAndType> member_names_and_types)685 const sem::Struct* build_struct(
686     MatchState& state,
687     std::string name,
688     std::initializer_list<NameAndType> member_names_and_types) {
689   uint32_t offset = 0;
690   uint32_t max_align = 0;
691   sem::StructMemberList members;
692   for (auto& m : member_names_and_types) {
693     uint32_t align = m.type->Align();
694     uint32_t size = m.type->Size();
695     offset = utils::RoundUp(align, offset);
696     max_align = std::max(max_align, align);
697     members.emplace_back(state.builder.create<sem::StructMember>(
698         /* declaration */ nullptr,
699         /* name */ state.builder.Sym(m.name),
700         /* type */ m.type,
701         /* index */ static_cast<uint32_t>(members.size()),
702         /* offset */ offset,
703         /* align */ align,
704         /* size */ size));
705     offset += size;
706   }
707   uint32_t size_without_padding = offset;
708   uint32_t size_with_padding = utils::RoundUp(max_align, offset);
709   return state.builder.create<sem::Struct>(
710       /* declaration */ nullptr,
711       /* name */ state.builder.Sym(name),
712       /* members */ members,
713       /* align */ max_align,
714       /* size */ size_with_padding,
715       /* size_no_padding */ size_without_padding);
716 }
717 
build_modf_result(MatchState & state)718 const sem::Struct* build_modf_result(MatchState& state) {
719   auto* f32 = state.builder.create<sem::F32>();
720   return build_struct(state, "__modf_result", {{"fract", f32}, {"whole", f32}});
721 }
build_modf_result_vec(MatchState & state,Number & n)722 const sem::Struct* build_modf_result_vec(MatchState& state, Number& n) {
723   auto* vec_f32 = state.builder.create<sem::Vector>(
724       state.builder.create<sem::F32>(), n.Value());
725   return build_struct(state, "__modf_result_vec" + std::to_string(n.Value()),
726                       {{"fract", vec_f32}, {"whole", vec_f32}});
727 }
build_frexp_result(MatchState & state)728 const sem::Struct* build_frexp_result(MatchState& state) {
729   auto* f32 = state.builder.create<sem::F32>();
730   auto* i32 = state.builder.create<sem::I32>();
731   return build_struct(state, "__frexp_result", {{"sig", f32}, {"exp", i32}});
732 }
build_frexp_result_vec(MatchState & state,Number & n)733 const sem::Struct* build_frexp_result_vec(MatchState& state, Number& n) {
734   auto* vec_f32 = state.builder.create<sem::Vector>(
735       state.builder.create<sem::F32>(), n.Value());
736   auto* vec_i32 = state.builder.create<sem::Vector>(
737       state.builder.create<sem::I32>(), n.Value());
738   return build_struct(state, "__frexp_result_vec" + std::to_string(n.Value()),
739                       {{"sig", vec_f32}, {"exp", vec_i32}});
740 }
741 
742 /// ParameterInfo describes a parameter
743 struct ParameterInfo {
744   /// The parameter usage (parameter name in definition file)
745   const ParameterUsage usage;
746 
747   /// Pointer to a list of indices that are used to match the parameter type.
748   /// The matcher indices index on Matchers::type and / or Matchers::number.
749   /// These indices are consumed by the matchers themselves.
750   /// The first index is always a TypeMatcher.
751   MatcherIndex const* const matcher_indices;
752 };
753 
754 /// OpenTypeInfo describes an open type
755 struct OpenTypeInfo {
756   /// Name of the open type (e.g. 'T')
757   const char* name;
758   /// Optional type matcher constraint.
759   /// Either an index in Matchers::type, or kNoMatcher
760   const MatcherIndex matcher_index;
761 };
762 
763 /// OpenNumberInfo describes an open number
764 struct OpenNumberInfo {
765   /// Name of the open number (e.g. 'N')
766   const char* name;
767   /// Optional number matcher constraint.
768   /// Either an index in Matchers::number, or kNoMatcher
769   const MatcherIndex matcher_index;
770 };
771 
772 /// OverloadInfo describes a single function overload
773 struct OverloadInfo {
774   /// Total number of parameters for the overload
775   const uint8_t num_parameters;
776   /// Total number of open types for the overload
777   const uint8_t num_open_types;
778   /// Total number of open numbers for the overload
779   const uint8_t num_open_numbers;
780   /// Pointer to the first open type
781   OpenTypeInfo const* const open_types;
782   /// Pointer to the first open number
783   OpenNumberInfo const* const open_numbers;
784   /// Pointer to the first parameter
785   ParameterInfo const* const parameters;
786   /// Pointer to a list of matcher indices that index on Matchers::type and
787   /// Matchers::number, used to build the return type. If the function has no
788   /// return type then this is null
789   MatcherIndex const* const return_matcher_indices;
790   /// The pipeline stages that this overload can be used in
791   PipelineStageSet supported_stages;
792   /// True if the overload is marked as deprecated
793   bool is_deprecated;
794 };
795 
796 /// IntrinsicInfo describes an intrinsic function
797 struct IntrinsicInfo {
798   /// Number of overloads of the intrinsic function
799   const uint8_t num_overloads;
800   /// Pointer to the start of the overloads for the function
801   OverloadInfo const* const overloads;
802 };
803 
804 #include "intrinsic_table.inl"
805 
806 /// IntrinsicPrototype describes a fully matched intrinsic function, which is
807 /// used as a lookup for building unique sem::Intrinsic instances.
808 struct IntrinsicPrototype {
809   /// Parameter describes a single parameter
810   struct Parameter {
811     /// Parameter type
812     const sem::Type* const type;
813     /// Parameter usage
814     ParameterUsage const usage = ParameterUsage::kNone;
815   };
816 
817   /// Hasher provides a hash function for the IntrinsicPrototype
818   struct Hasher {
819     /// @param i the IntrinsicPrototype to create a hash for
820     /// @return the hash value
operator ()tint::__anon033464eb0111::IntrinsicPrototype::Hasher821     inline std::size_t operator()(const IntrinsicPrototype& i) const {
822       size_t hash = utils::Hash(i.parameters.size());
823       for (auto& p : i.parameters) {
824         utils::HashCombine(&hash, p.type, p.usage);
825       }
826       return utils::Hash(hash, i.type, i.return_type, i.supported_stages,
827                          i.is_deprecated);
828     }
829   };
830 
831   sem::IntrinsicType type = sem::IntrinsicType::kNone;
832   std::vector<Parameter> parameters;
833   sem::Type const* return_type = nullptr;
834   PipelineStageSet supported_stages;
835   bool is_deprecated = false;
836 };
837 
838 /// Equality operator for IntrinsicPrototype
operator ==(const IntrinsicPrototype & a,const IntrinsicPrototype & b)839 bool operator==(const IntrinsicPrototype& a, const IntrinsicPrototype& b) {
840   if (a.type != b.type || a.supported_stages != b.supported_stages ||
841       a.return_type != b.return_type || a.is_deprecated != b.is_deprecated ||
842       a.parameters.size() != b.parameters.size()) {
843     return false;
844   }
845   for (size_t i = 0; i < a.parameters.size(); i++) {
846     auto& pa = a.parameters[i];
847     auto& pb = b.parameters[i];
848     if (pa.type != pb.type || pa.usage != pb.usage) {
849       return false;
850     }
851   }
852   return true;
853 }
854 
855 /// Impl is the private implementation of the IntrinsicTable interface.
856 class Impl : public IntrinsicTable {
857  public:
858   explicit Impl(ProgramBuilder& builder);
859 
860   const sem::Intrinsic* Lookup(sem::IntrinsicType intrinsic_type,
861                                const std::vector<const sem::Type*>& args,
862                                const Source& source) override;
863 
864  private:
865   const sem::Intrinsic* Match(sem::IntrinsicType intrinsic_type,
866                               const OverloadInfo& overload,
867                               const std::vector<const sem::Type*>& args,
868                               int& match_score);
869 
870   MatchState Match(ClosedState& closed,
871                    const OverloadInfo& overload,
872                    MatcherIndex const* matcher_indices) const;
873 
874   void PrintOverload(std::ostream& ss,
875                      const OverloadInfo& overload,
876                      sem::IntrinsicType intrinsic_type) const;
877 
878   ProgramBuilder& builder;
879   Matchers matchers;
880   std::unordered_map<IntrinsicPrototype,
881                      sem::Intrinsic*,
882                      IntrinsicPrototype::Hasher>
883       intrinsics;
884 };
885 
886 /// @return a string representing a call to an intrinsic with the given argument
887 /// types.
CallSignature(ProgramBuilder & builder,sem::IntrinsicType intrinsic_type,const std::vector<const sem::Type * > & args)888 std::string CallSignature(ProgramBuilder& builder,
889                           sem::IntrinsicType intrinsic_type,
890                           const std::vector<const sem::Type*>& args) {
891   std::stringstream ss;
892   ss << sem::str(intrinsic_type) << "(";
893   {
894     bool first = true;
895     for (auto* arg : args) {
896       if (!first) {
897         ss << ", ";
898       }
899       first = false;
900       ss << arg->UnwrapRef()->FriendlyName(builder.Symbols());
901     }
902   }
903   ss << ")";
904 
905   return ss.str();
906 }
907 
String(MatchState & state) const908 std::string OpenTypeMatcher::String(MatchState& state) const {
909   return state.overload.open_types[index_].name;
910 }
911 
String(MatchState & state) const912 std::string OpenNumberMatcher::String(MatchState& state) const {
913   return state.overload.open_numbers[index_].name;
914 }
915 
Impl(ProgramBuilder & b)916 Impl::Impl(ProgramBuilder& b) : builder(b) {}
917 
Lookup(sem::IntrinsicType intrinsic_type,const std::vector<const sem::Type * > & args,const Source & source)918 const sem::Intrinsic* Impl::Lookup(sem::IntrinsicType intrinsic_type,
919                                    const std::vector<const sem::Type*>& args,
920                                    const Source& source) {
921   // Candidate holds information about a mismatched overload that could be what
922   // the user intended to call.
923   struct Candidate {
924     const OverloadInfo* overload;
925     int score;
926   };
927 
928   // The list of failed matches that had promise.
929   std::vector<Candidate> candidates;
930 
931   auto& intrinsic = kIntrinsics[static_cast<uint32_t>(intrinsic_type)];
932   for (uint32_t o = 0; o < intrinsic.num_overloads; o++) {
933     int match_score = 1000;
934     auto& overload = intrinsic.overloads[o];
935     if (auto* match = Match(intrinsic_type, overload, args, match_score)) {
936       return match;
937     }
938     if (match_score > 0) {
939       candidates.emplace_back(Candidate{&overload, match_score});
940     }
941   }
942 
943   // Sort the candidates with the most promising first
944   std::stable_sort(
945       candidates.begin(), candidates.end(),
946       [](const Candidate& a, const Candidate& b) { return a.score > b.score; });
947 
948   // Generate an error message
949   std::stringstream ss;
950   ss << "no matching call to " << CallSignature(builder, intrinsic_type, args)
951      << std::endl;
952   if (!candidates.empty()) {
953     ss << std::endl;
954     ss << candidates.size() << " candidate function"
955        << (candidates.size() > 1 ? "s:" : ":") << std::endl;
956     for (auto& candidate : candidates) {
957       ss << "  ";
958       PrintOverload(ss, *candidate.overload, intrinsic_type);
959       ss << std::endl;
960     }
961   }
962   builder.Diagnostics().add_error(diag::System::Resolver, ss.str(), source);
963   return nullptr;
964 }
965 
Match(sem::IntrinsicType intrinsic_type,const OverloadInfo & overload,const std::vector<const sem::Type * > & args,int & match_score)966 const sem::Intrinsic* Impl::Match(sem::IntrinsicType intrinsic_type,
967                                   const OverloadInfo& overload,
968                                   const std::vector<const sem::Type*>& args,
969                                   int& match_score) {
970   // Score wait for argument <-> parameter count matches / mismatches
971   constexpr int kScorePerParamArgMismatch = -1;
972   constexpr int kScorePerMatchedParam = 2;
973   constexpr int kScorePerMatchedOpenType = 1;
974   constexpr int kScorePerMatchedOpenNumber = 1;
975 
976   auto num_parameters = overload.num_parameters;
977   auto num_arguments = static_cast<decltype(num_parameters)>(args.size());
978 
979   bool overload_matched = true;
980 
981   if (num_parameters != num_arguments) {
982     match_score +=
983         kScorePerParamArgMismatch * (std::max(num_parameters, num_arguments) -
984                                      std::min(num_parameters, num_arguments));
985     overload_matched = false;
986   }
987 
988   ClosedState closed(builder);
989 
990   std::vector<IntrinsicPrototype::Parameter> parameters;
991 
992   auto num_params = std::min(num_parameters, num_arguments);
993   for (uint32_t p = 0; p < num_params; p++) {
994     auto& parameter = overload.parameters[p];
995     auto* indices = parameter.matcher_indices;
996     auto* type = Match(closed, overload, indices).Type(args[p]->UnwrapRef());
997     if (type) {
998       parameters.emplace_back(
999           IntrinsicPrototype::Parameter{type, parameter.usage});
1000       match_score += kScorePerMatchedParam;
1001     } else {
1002       overload_matched = false;
1003     }
1004   }
1005 
1006   if (overload_matched) {
1007     // Check all constrained open types matched
1008     for (uint32_t ot = 0; ot < overload.num_open_types; ot++) {
1009       auto& open_type = overload.open_types[ot];
1010       if (open_type.matcher_index != kNoMatcher) {
1011         auto* index = &open_type.matcher_index;
1012         if (Match(closed, overload, index).Type(closed.Type(ot))) {
1013           match_score += kScorePerMatchedOpenType;
1014         } else {
1015           overload_matched = false;
1016         }
1017       }
1018     }
1019   }
1020 
1021   if (overload_matched) {
1022     // Check all constrained open numbers matched
1023     for (uint32_t on = 0; on < overload.num_open_numbers; on++) {
1024       auto& open_number = overload.open_numbers[on];
1025       if (open_number.matcher_index != kNoMatcher) {
1026         auto* index = &open_number.matcher_index;
1027         if (Match(closed, overload, index).Num(closed.Num(on)).IsValid()) {
1028           match_score += kScorePerMatchedOpenNumber;
1029         } else {
1030           overload_matched = false;
1031         }
1032       }
1033     }
1034   }
1035 
1036   if (!overload_matched) {
1037     return nullptr;
1038   }
1039 
1040   // Build the return type
1041   const sem::Type* return_type = nullptr;
1042   if (auto* indices = overload.return_matcher_indices) {
1043     Any any;
1044     return_type = Match(closed, overload, indices).Type(&any);
1045     if (!return_type) {
1046       std::stringstream ss;
1047       PrintOverload(ss, overload, intrinsic_type);
1048       TINT_ICE(Resolver, builder.Diagnostics())
1049           << "MatchState.Match() returned null for " << ss.str();
1050       return nullptr;
1051     }
1052   } else {
1053     return_type = builder.create<sem::Void>();
1054   }
1055 
1056   IntrinsicPrototype intrinsic;
1057   intrinsic.type = intrinsic_type;
1058   intrinsic.return_type = return_type;
1059   intrinsic.parameters = std::move(parameters);
1060   intrinsic.supported_stages = overload.supported_stages;
1061   intrinsic.is_deprecated = overload.is_deprecated;
1062 
1063   // De-duplicate intrinsics that are identical.
1064   return utils::GetOrCreate(intrinsics, intrinsic, [&] {
1065     std::vector<sem::Parameter*> params;
1066     params.reserve(intrinsic.parameters.size());
1067     for (auto& p : intrinsic.parameters) {
1068       params.emplace_back(builder.create<sem::Parameter>(
1069           nullptr, static_cast<uint32_t>(params.size()), p.type,
1070           ast::StorageClass::kNone, ast::Access::kUndefined, p.usage));
1071     }
1072     return builder.create<sem::Intrinsic>(
1073         intrinsic.type, intrinsic.return_type, std::move(params),
1074         intrinsic.supported_stages, intrinsic.is_deprecated);
1075   });
1076 }
1077 
Match(ClosedState & closed,const OverloadInfo & overload,MatcherIndex const * matcher_indices) const1078 MatchState Impl::Match(ClosedState& closed,
1079                        const OverloadInfo& overload,
1080                        MatcherIndex const* matcher_indices) const {
1081   return MatchState(builder, closed, matchers, overload, matcher_indices);
1082 }
1083 
PrintOverload(std::ostream & ss,const OverloadInfo & overload,sem::IntrinsicType intrinsic_type) const1084 void Impl::PrintOverload(std::ostream& ss,
1085                          const OverloadInfo& overload,
1086                          sem::IntrinsicType intrinsic_type) const {
1087   ClosedState closed(builder);
1088 
1089   ss << intrinsic_type << "(";
1090   for (uint32_t p = 0; p < overload.num_parameters; p++) {
1091     auto& parameter = overload.parameters[p];
1092     if (p > 0) {
1093       ss << ", ";
1094     }
1095     if (parameter.usage != ParameterUsage::kNone) {
1096       ss << sem::str(parameter.usage) << ": ";
1097     }
1098     auto* indices = parameter.matcher_indices;
1099     ss << Match(closed, overload, indices).TypeName();
1100   }
1101   ss << ")";
1102   if (overload.return_matcher_indices) {
1103     ss << " -> ";
1104     auto* indices = overload.return_matcher_indices;
1105     ss << Match(closed, overload, indices).TypeName();
1106   }
1107 
1108   bool first = true;
1109   auto separator = [&] {
1110     ss << (first ? "  where: " : ", ");
1111     first = false;
1112   };
1113   for (uint32_t i = 0; i < overload.num_open_types; i++) {
1114     auto& open_type = overload.open_types[i];
1115     if (open_type.matcher_index != kNoMatcher) {
1116       separator();
1117       ss << open_type.name;
1118       auto* index = &open_type.matcher_index;
1119       ss << " is " << Match(closed, overload, index).TypeName();
1120     }
1121   }
1122   for (uint32_t i = 0; i < overload.num_open_numbers; i++) {
1123     auto& open_number = overload.open_numbers[i];
1124     if (open_number.matcher_index != kNoMatcher) {
1125       separator();
1126       ss << open_number.name;
1127       auto* index = &open_number.matcher_index;
1128       ss << " is " << Match(closed, overload, index).NumName();
1129     }
1130   }
1131 }
1132 
/// Evaluates the type matcher at the current matcher index against `ty`.
/// The index is advanced before the matcher runs, so nested matchers read the
/// following indices.
/// @returns the result of the type matcher's Match()
const sem::Type* MatchState::Type(const sem::Type* ty) {
  auto* matcher = matchers.type[*matcher_indices_];
  ++matcher_indices_;
  return matcher->Match(*this, ty);
}
1138 
Num(Number number)1139 Number MatchState::Num(Number number) {
1140   MatcherIndex matcher_index = *matcher_indices_++;
1141   auto* matcher = matchers.number[matcher_index];
1142   return matcher->Match(*this, number);
1143 }
1144 
TypeName()1145 std::string MatchState::TypeName() {
1146   MatcherIndex matcher_index = *matcher_indices_++;
1147   auto* matcher = matchers.type[matcher_index];
1148   return matcher->String(*this);
1149 }
1150 
NumName()1151 std::string MatchState::NumName() {
1152   MatcherIndex matcher_index = *matcher_indices_++;
1153   auto* matcher = matchers.number[matcher_index];
1154   return matcher->String(*this);
1155 }
1156 
1157 }  // namespace
1158 
Create(ProgramBuilder & builder)1159 std::unique_ptr<IntrinsicTable> IntrinsicTable::Create(
1160     ProgramBuilder& builder) {
1161   return std::make_unique<Impl>(builder);
1162 }
1163 
1164 IntrinsicTable::~IntrinsicTable() = default;
1165 
1166 /// TypeInfo for the Any type declared in the anonymous namespace above
1167 TINT_INSTANTIATE_TYPEINFO(Any);
1168 
1169 }  // namespace tint
1170