• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SRC_CLONE_CONTEXT_H_
16 #define SRC_CLONE_CONTEXT_H_
17 
18 #include <algorithm>
19 #include <functional>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "src/castable.h"
26 #include "src/debug.h"
27 #include "src/program_id.h"
28 #include "src/symbol.h"
29 #include "src/traits.h"
30 
31 namespace tint {
32 
33 // Forward declarations
34 class CloneContext;
35 class Program;
36 class ProgramBuilder;
37 namespace ast {
38 class FunctionList;
39 class Node;
40 }  // namespace ast
41 
42 ProgramID ProgramIDOf(const Program*);
43 ProgramID ProgramIDOf(const ProgramBuilder*);
44 
45 /// Cloneable is the base class for all objects that can be cloned
46 class Cloneable : public Castable<Cloneable> {
47  public:
48   /// Performs a deep clone of this object using the CloneContext `ctx`.
49   /// @param ctx the clone context
50   /// @return the newly cloned object
51   virtual const Cloneable* Clone(CloneContext* ctx) const = 0;
52 };
53 
54 /// @returns an invalid ProgramID
ProgramIDOf(const Cloneable *)55 inline ProgramID ProgramIDOf(const Cloneable*) {
56   return ProgramID();
57 }
58 
59 /// CloneContext holds the state used while cloning AST nodes.
60 class CloneContext {
61   /// ParamTypeIsPtrOf<F, T>::value is true iff the first parameter of
62   /// F is a pointer of (or derives from) type T.
63   template <typename F, typename T>
64   using ParamTypeIsPtrOf = traits::IsTypeOrDerived<
65       typename std::remove_pointer<traits::ParameterType<F, 0>>::type,
66       T>;
67 
68  public:
69   /// SymbolTransform is a function that takes a symbol and returns a new
70   /// symbol.
71   using SymbolTransform = std::function<Symbol(Symbol)>;
72 
73   /// Constructor for cloning objects from `from` into `to`.
74   /// @param to the target ProgramBuilder to clone into
75   /// @param from the source Program to clone from
76   /// @param auto_clone_symbols clone all symbols in `from` before returning
77   CloneContext(ProgramBuilder* to,
78                Program const* from,
79                bool auto_clone_symbols = true);
80 
81   /// Constructor for cloning objects from and to the ProgramBuilder `builder`.
82   /// @param builder the ProgramBuilder
83   explicit CloneContext(ProgramBuilder* builder);
84 
85   /// Destructor
86   ~CloneContext();
87 
88   /// Clones the Node or sem::Type `a` into the ProgramBuilder #dst if `a` is
89   /// not null. If `a` is null, then Clone() returns null.
90   ///
91   /// Clone() may use a function registered with ReplaceAll() to create a
92   /// transformed version of the object. See ReplaceAll() for more information.
93   ///
94   /// If the CloneContext is cloning from a Program to a ProgramBuilder, then
95   /// the Node or sem::Type `a` must be owned by the Program #src.
96   ///
97   /// @param object the type deriving from Cloneable to clone
98   /// @return the cloned node
99   template <typename T>
Clone(const T * object)100   const T* Clone(const T* object) {
101     if (src) {
102       TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, object);
103     }
104     if (auto* cloned = CloneCloneable(object)) {
105       auto* out = CheckedCast<T>(cloned);
106       TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, out);
107       return out;
108     }
109     return nullptr;
110   }
111 
112   /// Clones the Node or sem::Type `a` into the ProgramBuilder #dst if `a` is
113   /// not null. If `a` is null, then Clone() returns null.
114   ///
115   /// Unlike Clone(), this method does not invoke or use any transformations
116   /// registered by ReplaceAll().
117   ///
118   /// If the CloneContext is cloning from a Program to a ProgramBuilder, then
119   /// the Node or sem::Type `a` must be owned by the Program #src.
120   ///
121   /// @param a the type deriving from Cloneable to clone
122   /// @return the cloned node
123   template <typename T>
CloneWithoutTransform(const T * a)124   const T* CloneWithoutTransform(const T* a) {
125     // If the input is nullptr, there's nothing to clone - just return nullptr.
126     if (a == nullptr) {
127       return nullptr;
128     }
129     if (src) {
130       TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, a);
131     }
132     auto* c = a->Clone(this);
133     return CheckedCast<T>(c);
134   }
135 
136   /// Clones the Source `s` into #dst
137   /// TODO(bclayton) - Currently this 'clone' is a shallow copy. If/when
138   /// `Source.File`s are owned by the Program this should make a copy of the
139   /// file.
140   /// @param s the `Source` to clone
141   /// @return the cloned source
Clone(const Source & s)142   Source Clone(const Source& s) const { return s; }
143 
144   /// Clones the Symbol `s` into #dst
145   ///
146   /// The Symbol `s` must be owned by the Program #src.
147   ///
148   /// @param s the Symbol to clone
149   /// @return the cloned source
150   Symbol Clone(Symbol s);
151 
152   /// Clones each of the elements of the vector `v` into the ProgramBuilder
153   /// #dst.
154   ///
155   /// All the elements of the vector `v` must be owned by the Program #src.
156   ///
157   /// @param v the vector to clone
158   /// @return the cloned vector
159   template <typename T>
Clone(const std::vector<T> & v)160   std::vector<T> Clone(const std::vector<T>& v) {
161     std::vector<T> out;
162     out.reserve(v.size());
163     for (auto& el : v) {
164       out.emplace_back(Clone(el));
165     }
166     return out;
167   }
168 
169   /// Clones each of the elements of the vector `v` using the ProgramBuilder
170   /// #dst, inserting any additional elements into the list that were registered
171   /// with calls to InsertBefore().
172   ///
173   /// All the elements of the vector `v` must be owned by the Program #src.
174   ///
175   /// @param v the vector to clone
176   /// @return the cloned vector
177   template <typename T>
Clone(const std::vector<T * > & v)178   std::vector<T*> Clone(const std::vector<T*>& v) {
179     std::vector<T*> out;
180     Clone(out, v);
181     return out;
182   }
183 
184   /// Clones each of the elements of the vector `from` into the vector `to`,
185   /// inserting any additional elements into the list that were registered with
186   /// calls to InsertBefore().
187   ///
188   /// All the elements of the vector `from` must be owned by the Program #src.
189   ///
190   /// @param from the vector to clone
191   /// @param to the cloned result
192   template <typename T>
Clone(std::vector<T * > & to,const std::vector<T * > & from)193   void Clone(std::vector<T*>& to, const std::vector<T*>& from) {
194     to.reserve(from.size());
195 
196     auto list_transform_it = list_transforms_.find(&from);
197     if (list_transform_it != list_transforms_.end()) {
198       const auto& transforms = list_transform_it->second;
199       for (auto* o : transforms.insert_front_) {
200         to.emplace_back(CheckedCast<T>(o));
201       }
202       for (auto& el : from) {
203         auto insert_before_it = transforms.insert_before_.find(el);
204         if (insert_before_it != transforms.insert_before_.end()) {
205           for (auto insert : insert_before_it->second) {
206             to.emplace_back(CheckedCast<T>(insert));
207           }
208         }
209         if (transforms.remove_.count(el) == 0) {
210           to.emplace_back(Clone(el));
211         }
212         auto insert_after_it = transforms.insert_after_.find(el);
213         if (insert_after_it != transforms.insert_after_.end()) {
214           for (auto insert : insert_after_it->second) {
215             to.emplace_back(CheckedCast<T>(insert));
216           }
217         }
218       }
219       for (auto* o : transforms.insert_back_) {
220         to.emplace_back(CheckedCast<T>(o));
221       }
222     } else {
223       for (auto& el : from) {
224         to.emplace_back(Clone(el));
225       }
226     }
227   }
228 
229   /// Clones each of the elements of the vector `v` into the ProgramBuilder
230   /// #dst.
231   ///
232   /// All the elements of the vector `v` must be owned by the Program #src.
233   ///
234   /// @param v the vector to clone
235   /// @return the cloned vector
236   ast::FunctionList Clone(const ast::FunctionList& v);
237 
238   /// ReplaceAll() registers `replacer` to be called whenever the Clone() method
239   /// is called with a Cloneable type that matches (or derives from) the type of
240   /// the single parameter of `replacer`.
241   /// The returned Cloneable of `replacer` will be used as the replacement for
242   /// all references to the object that's being cloned. This returned Cloneable
243   /// must be owned by the Program #dst.
244   ///
245   /// `replacer` must be function-like with the signature: `T* (T*)`
246   ///  where `T` is a type deriving from Cloneable.
247   ///
248   /// If `replacer` returns a nullptr then Clone() will call `T::Clone()` to
249   /// clone the object.
250   ///
251   /// Example:
252   ///
253   /// ```
254   ///   // Replace all ast::UintLiteralExpressions with the number 42
255   ///   CloneCtx ctx(&out, in);
256   ///   ctx.ReplaceAll([&] (ast::UintLiteralExpression* l) {
257   ///       return ctx->dst->create<ast::UintLiteralExpression>(
258   ///           ctx->Clone(l->source),
259   ///           ctx->Clone(l->type),
260   ///           42);
261   ///     });
262   ///   ctx.Clone();
263   /// ```
264   ///
265   /// @warning a single handler can only be registered for any given type.
266   /// Attempting to register two handlers for the same type will result in an
267   /// ICE.
268   /// @warning The replacement object must be of the correct type for all
269   /// references of the original object. A type mismatch will result in an
270   /// assertion in debug builds, and undefined behavior in release builds.
271   /// @param replacer a function or function-like object with the signature
272   ///        `T* (T*)`, where `T` derives from Cloneable
273   /// @returns this CloneContext so calls can be chained
274   template <typename F>
275   traits::EnableIf<ParamTypeIsPtrOf<F, Cloneable>::value, CloneContext>&
ReplaceAll(F && replacer)276   ReplaceAll(F&& replacer) {
277     using TPtr = traits::ParameterType<F, 0>;
278     using T = typename std::remove_pointer<TPtr>::type;
279     for (auto& transform : transforms_) {
280       if (transform.typeinfo->Is(TypeInfo::Of<T>()) ||
281           TypeInfo::Of<T>().Is(*transform.typeinfo)) {
282         TINT_ICE(Clone, Diagnostics())
283             << "ReplaceAll() called with a handler for type "
284             << TypeInfo::Of<T>().name
285             << " that is already handled by a handler for type "
286             << transform.typeinfo->name;
287         return *this;
288       }
289     }
290     CloneableTransform transform;
291     transform.typeinfo = &TypeInfo::Of<T>();
292     transform.function = [=](const Cloneable* in) {
293       return replacer(in->As<T>());
294     };
295     transforms_.emplace_back(std::move(transform));
296     return *this;
297   }
298 
299   /// ReplaceAll() registers `replacer` to be called whenever the Clone() method
300   /// is called with a Symbol.
301   /// The returned symbol of `replacer` will be used as the replacement for
302   /// all references to the symbol that's being cloned. This returned Symbol
303   /// must be owned by the Program #dst.
304   /// @param replacer a function the signature `Symbol(Symbol)`.
305   /// @warning a SymbolTransform can only be registered once. Attempting to
306   /// register a SymbolTransform more than once will result in an ICE.
307   /// @returns this CloneContext so calls can be chained
ReplaceAll(const SymbolTransform & replacer)308   CloneContext& ReplaceAll(const SymbolTransform& replacer) {
309     if (symbol_transform_) {
310       TINT_ICE(Clone, Diagnostics())
311           << "ReplaceAll(const SymbolTransform&) called "
312              "multiple times on the same CloneContext";
313       return *this;
314     }
315     symbol_transform_ = replacer;
316     return *this;
317   }
318 
319   /// Replace replaces all occurrences of `what` in #src with the pointer `with`
320   /// in #dst when calling Clone().
321   /// [DEPRECATED]: This function cannot handle nested replacements. Use the
322   /// overload of Replace() that take a function for the `WITH` argument.
323   /// @param what a pointer to the object in #src that will be replaced with
324   /// `with`
325   /// @param with a pointer to the replacement object owned by #dst that will be
326   /// used as a replacement for `what`
327   /// @warning The replacement object must be of the correct type for all
328   /// references of the original object. A type mismatch will result in an
329   /// assertion in debug builds, and undefined behavior in release builds.
330   /// @returns this CloneContext so calls can be chained
331   template <typename WHAT,
332             typename WITH,
333             typename = traits::EnableIfIsType<WITH, Cloneable>>
Replace(const WHAT * what,const WITH * with)334   CloneContext& Replace(const WHAT* what, const WITH* with) {
335     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
336     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, with);
337     replacements_[what] = [with]() -> const Cloneable* { return with; };
338     return *this;
339   }
340 
341   /// Replace replaces all occurrences of `what` in #src with the result of the
342   /// function `with` in #dst when calling Clone(). `with` will be called each
343   /// time `what` is cloned by this context. If `what` is not cloned, then
344   /// `with` may never be called.
345   /// @param what a pointer to the object in #src that will be replaced with
346   /// `with`
347   /// @param with a function that takes no arguments and returns a pointer to
348   /// the replacement object owned by #dst. The returned pointer will be used as
349   /// a replacement for `what`.
350   /// @warning The replacement object must be of the correct type for all
351   /// references of the original object. A type mismatch will result in an
352   /// assertion in debug builds, and undefined behavior in release builds.
353   /// @returns this CloneContext so calls can be chained
354   template <typename WHAT, typename WITH, typename = std::result_of_t<WITH()>>
Replace(const WHAT * what,WITH && with)355   CloneContext& Replace(const WHAT* what, WITH&& with) {
356     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
357     replacements_[what] = with;
358     return *this;
359   }
360 
361   /// Removes `object` from the cloned copy of `vector`.
362   /// @param vector the vector in #src
363   /// @param object a pointer to the object in #src that will be omitted from
364   /// the cloned vector.
365   /// @returns this CloneContext so calls can be chained
366   template <typename T, typename OBJECT>
Remove(const std::vector<T> & vector,OBJECT * object)367   CloneContext& Remove(const std::vector<T>& vector, OBJECT* object) {
368     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, object);
369     if (std::find(vector.begin(), vector.end(), object) == vector.end()) {
370       TINT_ICE(Clone, Diagnostics())
371           << "CloneContext::Remove() vector does not contain object";
372       return *this;
373     }
374 
375     list_transforms_[&vector].remove_.emplace(object);
376     return *this;
377   }
378 
379   /// Inserts `object` before any other objects of `vector`, when it is cloned.
380   /// @param vector the vector in #src
381   /// @param object a pointer to the object in #dst that will be inserted at the
382   /// front of the vector
383   /// @returns this CloneContext so calls can be chained
384   template <typename T, typename OBJECT>
InsertFront(const std::vector<T> & vector,OBJECT * object)385   CloneContext& InsertFront(const std::vector<T>& vector, OBJECT* object) {
386     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
387     auto& transforms = list_transforms_[&vector];
388     auto& list = transforms.insert_front_;
389     list.emplace_back(object);
390     return *this;
391   }
392 
393   /// Inserts `object` after any other objects of `vector`, when it is cloned.
394   /// @param vector the vector in #src
395   /// @param object a pointer to the object in #dst that will be inserted at the
396   /// end of the vector
397   /// @returns this CloneContext so calls can be chained
398   template <typename T, typename OBJECT>
InsertBack(const std::vector<T> & vector,OBJECT * object)399   CloneContext& InsertBack(const std::vector<T>& vector, OBJECT* object) {
400     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
401     auto& transforms = list_transforms_[&vector];
402     auto& list = transforms.insert_back_;
403     list.emplace_back(object);
404     return *this;
405   }
406 
407   /// Inserts `object` before `before` whenever `vector` is cloned.
408   /// @param vector the vector in #src
409   /// @param before a pointer to the object in #src
410   /// @param object a pointer to the object in #dst that will be inserted before
411   /// any occurrence of the clone of `before`
412   /// @returns this CloneContext so calls can be chained
413   template <typename T, typename BEFORE, typename OBJECT>
InsertBefore(const std::vector<T> & vector,const BEFORE * before,const OBJECT * object)414   CloneContext& InsertBefore(const std::vector<T>& vector,
415                              const BEFORE* before,
416                              const OBJECT* object) {
417     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, before);
418     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
419     if (std::find(vector.begin(), vector.end(), before) == vector.end()) {
420       TINT_ICE(Clone, Diagnostics())
421           << "CloneContext::InsertBefore() vector does not contain before";
422       return *this;
423     }
424 
425     auto& transforms = list_transforms_[&vector];
426     auto& list = transforms.insert_before_[before];
427     list.emplace_back(object);
428     return *this;
429   }
430 
431   /// Inserts `object` after `after` whenever `vector` is cloned.
432   /// @param vector the vector in #src
433   /// @param after a pointer to the object in #src
434   /// @param object a pointer to the object in #dst that will be inserted after
435   /// any occurrence of the clone of `after`
436   /// @returns this CloneContext so calls can be chained
437   template <typename T, typename AFTER, typename OBJECT>
InsertAfter(const std::vector<T> & vector,const AFTER * after,const OBJECT * object)438   CloneContext& InsertAfter(const std::vector<T>& vector,
439                             const AFTER* after,
440                             const OBJECT* object) {
441     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, after);
442     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
443     if (std::find(vector.begin(), vector.end(), after) == vector.end()) {
444       TINT_ICE(Clone, Diagnostics())
445           << "CloneContext::InsertAfter() vector does not contain after";
446       return *this;
447     }
448 
449     auto& transforms = list_transforms_[&vector];
450     auto& list = transforms.insert_after_[after];
451     list.emplace_back(object);
452     return *this;
453   }
454 
455   /// Clone performs the clone of the Program's AST nodes, types and symbols
456   /// from #src to #dst. Semantic nodes are not cloned, as these will be rebuilt
457   /// when the ProgramBuilder #dst builds its Program.
458   void Clone();
459 
460   /// The target ProgramBuilder to clone into.
461   ProgramBuilder* const dst;
462 
463   /// The source Program to clone from.
464   Program const* const src;
465 
466  private:
467   struct CloneableTransform {
468     /// Constructor
469     CloneableTransform();
470     /// Copy constructor
471     /// @param other the CloneableTransform to copy
472     CloneableTransform(const CloneableTransform& other);
473     /// Destructor
474     ~CloneableTransform();
475 
476     // TypeInfo of the Cloneable that the transform operates on
477     const TypeInfo* typeinfo;
478     std::function<const Cloneable*(const Cloneable*)> function;
479   };
480 
481   CloneContext(const CloneContext&) = delete;
482   CloneContext& operator=(const CloneContext&) = delete;
483 
484   /// Cast `obj` from type `FROM` to type `TO`, returning the cast object.
485   /// Reports an internal compiler error if the cast failed.
486   template <typename TO, typename FROM>
CheckedCast(const FROM * obj)487   const TO* CheckedCast(const FROM* obj) {
488     if (obj == nullptr) {
489       return nullptr;
490     }
491     if (const TO* cast = obj->template As<TO>()) {
492       return cast;
493     }
494     CheckedCastFailure(obj, TypeInfo::Of<TO>());
495     return nullptr;
496   }
497 
498   /// Clones a Cloneable object, using any replacements or transforms that have
499   /// been configured.
500   const Cloneable* CloneCloneable(const Cloneable* object);
501 
502   /// Adds an error diagnostic to Diagnostics() that the cloned object was not
503   /// of the expected type.
504   void CheckedCastFailure(const Cloneable* got, const TypeInfo& expected);
505 
506   /// @returns the diagnostic list of #dst
507   diag::List& Diagnostics() const;
508 
509   /// A vector of const Cloneable*
510   using CloneableList = std::vector<const Cloneable*>;
511 
512   /// Transformations to be applied to a list (vector)
513   struct ListTransforms {
514     /// Constructor
515     ListTransforms();
516     /// Destructor
517     ~ListTransforms();
518 
519     /// A map of object in #src to omit when cloned into #dst.
520     std::unordered_set<const Cloneable*> remove_;
521 
522     /// A list of objects in #dst to insert before any others when the vector is
523     /// cloned.
524     CloneableList insert_front_;
525 
526     /// A list of objects in #dst to insert befor after any others when the
527     /// vector is cloned.
528     CloneableList insert_back_;
529 
530     /// A map of object in #src to the list of cloned objects in #dst.
531     /// Clone(const std::vector<T*>& v) will use this to insert the map-value
532     /// list into the target vector before cloning and inserting the map-key.
533     std::unordered_map<const Cloneable*, CloneableList> insert_before_;
534 
535     /// A map of object in #src to the list of cloned objects in #dst.
536     /// Clone(const std::vector<T*>& v) will use this to insert the map-value
537     /// list into the target vector after cloning and inserting the map-key.
538     std::unordered_map<const Cloneable*, CloneableList> insert_after_;
539   };
540 
541   /// A map of object in #src to functions that create their replacement in
542   /// #dst
543   std::unordered_map<const Cloneable*, std::function<const Cloneable*()>>
544       replacements_;
545 
546   /// A map of symbol in #src to their cloned equivalent in #dst
547   std::unordered_map<Symbol, Symbol> cloned_symbols_;
548 
549   /// Cloneable transform functions registered with ReplaceAll()
550   std::vector<CloneableTransform> transforms_;
551 
552   /// Map of std::vector pointer to transforms for that list
553   std::unordered_map<const void*, ListTransforms> list_transforms_;
554 
555   /// Symbol transform registered with ReplaceAll()
556   SymbolTransform symbol_transform_;
557 };
558 
559 }  // namespace tint
560 
561 #endif  // SRC_CLONE_CONTEXT_H_
562