1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef SRC_CLONE_CONTEXT_H_
16 #define SRC_CLONE_CONTEXT_H_
17
18 #include <algorithm>
19 #include <functional>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24
25 #include "src/castable.h"
26 #include "src/debug.h"
27 #include "src/program_id.h"
28 #include "src/symbol.h"
29 #include "src/traits.h"
30
31 namespace tint {
32
33 // Forward declarations
34 class CloneContext;
35 class Program;
36 class ProgramBuilder;
37 namespace ast {
38 class FunctionList;
39 class Node;
40 } // namespace ast
41
42 ProgramID ProgramIDOf(const Program*);
43 ProgramID ProgramIDOf(const ProgramBuilder*);
44
45 /// Cloneable is the base class for all objects that can be cloned
46 class Cloneable : public Castable<Cloneable> {
47 public:
48 /// Performs a deep clone of this object using the CloneContext `ctx`.
49 /// @param ctx the clone context
50 /// @return the newly cloned object
51 virtual const Cloneable* Clone(CloneContext* ctx) const = 0;
52 };
53
54 /// @returns an invalid ProgramID
ProgramIDOf(const Cloneable *)55 inline ProgramID ProgramIDOf(const Cloneable*) {
56 return ProgramID();
57 }
58
59 /// CloneContext holds the state used while cloning AST nodes.
60 class CloneContext {
61 /// ParamTypeIsPtrOf<F, T>::value is true iff the first parameter of
62 /// F is a pointer of (or derives from) type T.
63 template <typename F, typename T>
64 using ParamTypeIsPtrOf = traits::IsTypeOrDerived<
65 typename std::remove_pointer<traits::ParameterType<F, 0>>::type,
66 T>;
67
68 public:
69 /// SymbolTransform is a function that takes a symbol and returns a new
70 /// symbol.
71 using SymbolTransform = std::function<Symbol(Symbol)>;
72
73 /// Constructor for cloning objects from `from` into `to`.
74 /// @param to the target ProgramBuilder to clone into
75 /// @param from the source Program to clone from
76 /// @param auto_clone_symbols clone all symbols in `from` before returning
77 CloneContext(ProgramBuilder* to,
78 Program const* from,
79 bool auto_clone_symbols = true);
80
81 /// Constructor for cloning objects from and to the ProgramBuilder `builder`.
82 /// @param builder the ProgramBuilder
83 explicit CloneContext(ProgramBuilder* builder);
84
85 /// Destructor
86 ~CloneContext();
87
88 /// Clones the Node or sem::Type `a` into the ProgramBuilder #dst if `a` is
89 /// not null. If `a` is null, then Clone() returns null.
90 ///
91 /// Clone() may use a function registered with ReplaceAll() to create a
92 /// transformed version of the object. See ReplaceAll() for more information.
93 ///
94 /// If the CloneContext is cloning from a Program to a ProgramBuilder, then
95 /// the Node or sem::Type `a` must be owned by the Program #src.
96 ///
97 /// @param object the type deriving from Cloneable to clone
98 /// @return the cloned node
99 template <typename T>
Clone(const T * object)100 const T* Clone(const T* object) {
101 if (src) {
102 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, object);
103 }
104 if (auto* cloned = CloneCloneable(object)) {
105 auto* out = CheckedCast<T>(cloned);
106 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, out);
107 return out;
108 }
109 return nullptr;
110 }
111
112 /// Clones the Node or sem::Type `a` into the ProgramBuilder #dst if `a` is
113 /// not null. If `a` is null, then Clone() returns null.
114 ///
115 /// Unlike Clone(), this method does not invoke or use any transformations
116 /// registered by ReplaceAll().
117 ///
118 /// If the CloneContext is cloning from a Program to a ProgramBuilder, then
119 /// the Node or sem::Type `a` must be owned by the Program #src.
120 ///
121 /// @param a the type deriving from Cloneable to clone
122 /// @return the cloned node
123 template <typename T>
CloneWithoutTransform(const T * a)124 const T* CloneWithoutTransform(const T* a) {
125 // If the input is nullptr, there's nothing to clone - just return nullptr.
126 if (a == nullptr) {
127 return nullptr;
128 }
129 if (src) {
130 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, a);
131 }
132 auto* c = a->Clone(this);
133 return CheckedCast<T>(c);
134 }
135
136 /// Clones the Source `s` into #dst
137 /// TODO(bclayton) - Currently this 'clone' is a shallow copy. If/when
138 /// `Source.File`s are owned by the Program this should make a copy of the
139 /// file.
140 /// @param s the `Source` to clone
141 /// @return the cloned source
Clone(const Source & s)142 Source Clone(const Source& s) const { return s; }
143
144 /// Clones the Symbol `s` into #dst
145 ///
146 /// The Symbol `s` must be owned by the Program #src.
147 ///
148 /// @param s the Symbol to clone
149 /// @return the cloned source
150 Symbol Clone(Symbol s);
151
152 /// Clones each of the elements of the vector `v` into the ProgramBuilder
153 /// #dst.
154 ///
155 /// All the elements of the vector `v` must be owned by the Program #src.
156 ///
157 /// @param v the vector to clone
158 /// @return the cloned vector
159 template <typename T>
Clone(const std::vector<T> & v)160 std::vector<T> Clone(const std::vector<T>& v) {
161 std::vector<T> out;
162 out.reserve(v.size());
163 for (auto& el : v) {
164 out.emplace_back(Clone(el));
165 }
166 return out;
167 }
168
169 /// Clones each of the elements of the vector `v` using the ProgramBuilder
170 /// #dst, inserting any additional elements into the list that were registered
171 /// with calls to InsertBefore().
172 ///
173 /// All the elements of the vector `v` must be owned by the Program #src.
174 ///
175 /// @param v the vector to clone
176 /// @return the cloned vector
177 template <typename T>
Clone(const std::vector<T * > & v)178 std::vector<T*> Clone(const std::vector<T*>& v) {
179 std::vector<T*> out;
180 Clone(out, v);
181 return out;
182 }
183
184 /// Clones each of the elements of the vector `from` into the vector `to`,
185 /// inserting any additional elements into the list that were registered with
186 /// calls to InsertBefore().
187 ///
188 /// All the elements of the vector `from` must be owned by the Program #src.
189 ///
190 /// @param from the vector to clone
191 /// @param to the cloned result
192 template <typename T>
Clone(std::vector<T * > & to,const std::vector<T * > & from)193 void Clone(std::vector<T*>& to, const std::vector<T*>& from) {
194 to.reserve(from.size());
195
196 auto list_transform_it = list_transforms_.find(&from);
197 if (list_transform_it != list_transforms_.end()) {
198 const auto& transforms = list_transform_it->second;
199 for (auto* o : transforms.insert_front_) {
200 to.emplace_back(CheckedCast<T>(o));
201 }
202 for (auto& el : from) {
203 auto insert_before_it = transforms.insert_before_.find(el);
204 if (insert_before_it != transforms.insert_before_.end()) {
205 for (auto insert : insert_before_it->second) {
206 to.emplace_back(CheckedCast<T>(insert));
207 }
208 }
209 if (transforms.remove_.count(el) == 0) {
210 to.emplace_back(Clone(el));
211 }
212 auto insert_after_it = transforms.insert_after_.find(el);
213 if (insert_after_it != transforms.insert_after_.end()) {
214 for (auto insert : insert_after_it->second) {
215 to.emplace_back(CheckedCast<T>(insert));
216 }
217 }
218 }
219 for (auto* o : transforms.insert_back_) {
220 to.emplace_back(CheckedCast<T>(o));
221 }
222 } else {
223 for (auto& el : from) {
224 to.emplace_back(Clone(el));
225 }
226 }
227 }
228
229 /// Clones each of the elements of the vector `v` into the ProgramBuilder
230 /// #dst.
231 ///
232 /// All the elements of the vector `v` must be owned by the Program #src.
233 ///
234 /// @param v the vector to clone
235 /// @return the cloned vector
236 ast::FunctionList Clone(const ast::FunctionList& v);
237
238 /// ReplaceAll() registers `replacer` to be called whenever the Clone() method
239 /// is called with a Cloneable type that matches (or derives from) the type of
240 /// the single parameter of `replacer`.
241 /// The returned Cloneable of `replacer` will be used as the replacement for
242 /// all references to the object that's being cloned. This returned Cloneable
243 /// must be owned by the Program #dst.
244 ///
245 /// `replacer` must be function-like with the signature: `T* (T*)`
246 /// where `T` is a type deriving from Cloneable.
247 ///
248 /// If `replacer` returns a nullptr then Clone() will call `T::Clone()` to
249 /// clone the object.
250 ///
251 /// Example:
252 ///
253 /// ```
254 /// // Replace all ast::UintLiteralExpressions with the number 42
255 /// CloneCtx ctx(&out, in);
256 /// ctx.ReplaceAll([&] (ast::UintLiteralExpression* l) {
257 /// return ctx->dst->create<ast::UintLiteralExpression>(
258 /// ctx->Clone(l->source),
259 /// ctx->Clone(l->type),
260 /// 42);
261 /// });
262 /// ctx.Clone();
263 /// ```
264 ///
265 /// @warning a single handler can only be registered for any given type.
266 /// Attempting to register two handlers for the same type will result in an
267 /// ICE.
268 /// @warning The replacement object must be of the correct type for all
269 /// references of the original object. A type mismatch will result in an
270 /// assertion in debug builds, and undefined behavior in release builds.
271 /// @param replacer a function or function-like object with the signature
272 /// `T* (T*)`, where `T` derives from Cloneable
273 /// @returns this CloneContext so calls can be chained
274 template <typename F>
275 traits::EnableIf<ParamTypeIsPtrOf<F, Cloneable>::value, CloneContext>&
ReplaceAll(F && replacer)276 ReplaceAll(F&& replacer) {
277 using TPtr = traits::ParameterType<F, 0>;
278 using T = typename std::remove_pointer<TPtr>::type;
279 for (auto& transform : transforms_) {
280 if (transform.typeinfo->Is(TypeInfo::Of<T>()) ||
281 TypeInfo::Of<T>().Is(*transform.typeinfo)) {
282 TINT_ICE(Clone, Diagnostics())
283 << "ReplaceAll() called with a handler for type "
284 << TypeInfo::Of<T>().name
285 << " that is already handled by a handler for type "
286 << transform.typeinfo->name;
287 return *this;
288 }
289 }
290 CloneableTransform transform;
291 transform.typeinfo = &TypeInfo::Of<T>();
292 transform.function = [=](const Cloneable* in) {
293 return replacer(in->As<T>());
294 };
295 transforms_.emplace_back(std::move(transform));
296 return *this;
297 }
298
299 /// ReplaceAll() registers `replacer` to be called whenever the Clone() method
300 /// is called with a Symbol.
301 /// The returned symbol of `replacer` will be used as the replacement for
302 /// all references to the symbol that's being cloned. This returned Symbol
303 /// must be owned by the Program #dst.
304 /// @param replacer a function the signature `Symbol(Symbol)`.
305 /// @warning a SymbolTransform can only be registered once. Attempting to
306 /// register a SymbolTransform more than once will result in an ICE.
307 /// @returns this CloneContext so calls can be chained
ReplaceAll(const SymbolTransform & replacer)308 CloneContext& ReplaceAll(const SymbolTransform& replacer) {
309 if (symbol_transform_) {
310 TINT_ICE(Clone, Diagnostics())
311 << "ReplaceAll(const SymbolTransform&) called "
312 "multiple times on the same CloneContext";
313 return *this;
314 }
315 symbol_transform_ = replacer;
316 return *this;
317 }
318
319 /// Replace replaces all occurrences of `what` in #src with the pointer `with`
320 /// in #dst when calling Clone().
321 /// [DEPRECATED]: This function cannot handle nested replacements. Use the
322 /// overload of Replace() that take a function for the `WITH` argument.
323 /// @param what a pointer to the object in #src that will be replaced with
324 /// `with`
325 /// @param with a pointer to the replacement object owned by #dst that will be
326 /// used as a replacement for `what`
327 /// @warning The replacement object must be of the correct type for all
328 /// references of the original object. A type mismatch will result in an
329 /// assertion in debug builds, and undefined behavior in release builds.
330 /// @returns this CloneContext so calls can be chained
331 template <typename WHAT,
332 typename WITH,
333 typename = traits::EnableIfIsType<WITH, Cloneable>>
Replace(const WHAT * what,const WITH * with)334 CloneContext& Replace(const WHAT* what, const WITH* with) {
335 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
336 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, with);
337 replacements_[what] = [with]() -> const Cloneable* { return with; };
338 return *this;
339 }
340
341 /// Replace replaces all occurrences of `what` in #src with the result of the
342 /// function `with` in #dst when calling Clone(). `with` will be called each
343 /// time `what` is cloned by this context. If `what` is not cloned, then
344 /// `with` may never be called.
345 /// @param what a pointer to the object in #src that will be replaced with
346 /// `with`
347 /// @param with a function that takes no arguments and returns a pointer to
348 /// the replacement object owned by #dst. The returned pointer will be used as
349 /// a replacement for `what`.
350 /// @warning The replacement object must be of the correct type for all
351 /// references of the original object. A type mismatch will result in an
352 /// assertion in debug builds, and undefined behavior in release builds.
353 /// @returns this CloneContext so calls can be chained
354 template <typename WHAT, typename WITH, typename = std::result_of_t<WITH()>>
Replace(const WHAT * what,WITH && with)355 CloneContext& Replace(const WHAT* what, WITH&& with) {
356 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
357 replacements_[what] = with;
358 return *this;
359 }
360
361 /// Removes `object` from the cloned copy of `vector`.
362 /// @param vector the vector in #src
363 /// @param object a pointer to the object in #src that will be omitted from
364 /// the cloned vector.
365 /// @returns this CloneContext so calls can be chained
366 template <typename T, typename OBJECT>
Remove(const std::vector<T> & vector,OBJECT * object)367 CloneContext& Remove(const std::vector<T>& vector, OBJECT* object) {
368 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, object);
369 if (std::find(vector.begin(), vector.end(), object) == vector.end()) {
370 TINT_ICE(Clone, Diagnostics())
371 << "CloneContext::Remove() vector does not contain object";
372 return *this;
373 }
374
375 list_transforms_[&vector].remove_.emplace(object);
376 return *this;
377 }
378
379 /// Inserts `object` before any other objects of `vector`, when it is cloned.
380 /// @param vector the vector in #src
381 /// @param object a pointer to the object in #dst that will be inserted at the
382 /// front of the vector
383 /// @returns this CloneContext so calls can be chained
384 template <typename T, typename OBJECT>
InsertFront(const std::vector<T> & vector,OBJECT * object)385 CloneContext& InsertFront(const std::vector<T>& vector, OBJECT* object) {
386 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
387 auto& transforms = list_transforms_[&vector];
388 auto& list = transforms.insert_front_;
389 list.emplace_back(object);
390 return *this;
391 }
392
393 /// Inserts `object` after any other objects of `vector`, when it is cloned.
394 /// @param vector the vector in #src
395 /// @param object a pointer to the object in #dst that will be inserted at the
396 /// end of the vector
397 /// @returns this CloneContext so calls can be chained
398 template <typename T, typename OBJECT>
InsertBack(const std::vector<T> & vector,OBJECT * object)399 CloneContext& InsertBack(const std::vector<T>& vector, OBJECT* object) {
400 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
401 auto& transforms = list_transforms_[&vector];
402 auto& list = transforms.insert_back_;
403 list.emplace_back(object);
404 return *this;
405 }
406
407 /// Inserts `object` before `before` whenever `vector` is cloned.
408 /// @param vector the vector in #src
409 /// @param before a pointer to the object in #src
410 /// @param object a pointer to the object in #dst that will be inserted before
411 /// any occurrence of the clone of `before`
412 /// @returns this CloneContext so calls can be chained
413 template <typename T, typename BEFORE, typename OBJECT>
InsertBefore(const std::vector<T> & vector,const BEFORE * before,const OBJECT * object)414 CloneContext& InsertBefore(const std::vector<T>& vector,
415 const BEFORE* before,
416 const OBJECT* object) {
417 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, before);
418 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
419 if (std::find(vector.begin(), vector.end(), before) == vector.end()) {
420 TINT_ICE(Clone, Diagnostics())
421 << "CloneContext::InsertBefore() vector does not contain before";
422 return *this;
423 }
424
425 auto& transforms = list_transforms_[&vector];
426 auto& list = transforms.insert_before_[before];
427 list.emplace_back(object);
428 return *this;
429 }
430
431 /// Inserts `object` after `after` whenever `vector` is cloned.
432 /// @param vector the vector in #src
433 /// @param after a pointer to the object in #src
434 /// @param object a pointer to the object in #dst that will be inserted after
435 /// any occurrence of the clone of `after`
436 /// @returns this CloneContext so calls can be chained
437 template <typename T, typename AFTER, typename OBJECT>
InsertAfter(const std::vector<T> & vector,const AFTER * after,const OBJECT * object)438 CloneContext& InsertAfter(const std::vector<T>& vector,
439 const AFTER* after,
440 const OBJECT* object) {
441 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, after);
442 TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
443 if (std::find(vector.begin(), vector.end(), after) == vector.end()) {
444 TINT_ICE(Clone, Diagnostics())
445 << "CloneContext::InsertAfter() vector does not contain after";
446 return *this;
447 }
448
449 auto& transforms = list_transforms_[&vector];
450 auto& list = transforms.insert_after_[after];
451 list.emplace_back(object);
452 return *this;
453 }
454
455 /// Clone performs the clone of the Program's AST nodes, types and symbols
456 /// from #src to #dst. Semantic nodes are not cloned, as these will be rebuilt
457 /// when the ProgramBuilder #dst builds its Program.
458 void Clone();
459
460 /// The target ProgramBuilder to clone into.
461 ProgramBuilder* const dst;
462
463 /// The source Program to clone from.
464 Program const* const src;
465
466 private:
467 struct CloneableTransform {
468 /// Constructor
469 CloneableTransform();
470 /// Copy constructor
471 /// @param other the CloneableTransform to copy
472 CloneableTransform(const CloneableTransform& other);
473 /// Destructor
474 ~CloneableTransform();
475
476 // TypeInfo of the Cloneable that the transform operates on
477 const TypeInfo* typeinfo;
478 std::function<const Cloneable*(const Cloneable*)> function;
479 };
480
481 CloneContext(const CloneContext&) = delete;
482 CloneContext& operator=(const CloneContext&) = delete;
483
484 /// Cast `obj` from type `FROM` to type `TO`, returning the cast object.
485 /// Reports an internal compiler error if the cast failed.
486 template <typename TO, typename FROM>
CheckedCast(const FROM * obj)487 const TO* CheckedCast(const FROM* obj) {
488 if (obj == nullptr) {
489 return nullptr;
490 }
491 if (const TO* cast = obj->template As<TO>()) {
492 return cast;
493 }
494 CheckedCastFailure(obj, TypeInfo::Of<TO>());
495 return nullptr;
496 }
497
498 /// Clones a Cloneable object, using any replacements or transforms that have
499 /// been configured.
500 const Cloneable* CloneCloneable(const Cloneable* object);
501
502 /// Adds an error diagnostic to Diagnostics() that the cloned object was not
503 /// of the expected type.
504 void CheckedCastFailure(const Cloneable* got, const TypeInfo& expected);
505
506 /// @returns the diagnostic list of #dst
507 diag::List& Diagnostics() const;
508
509 /// A vector of const Cloneable*
510 using CloneableList = std::vector<const Cloneable*>;
511
512 /// Transformations to be applied to a list (vector)
513 struct ListTransforms {
514 /// Constructor
515 ListTransforms();
516 /// Destructor
517 ~ListTransforms();
518
519 /// A map of object in #src to omit when cloned into #dst.
520 std::unordered_set<const Cloneable*> remove_;
521
522 /// A list of objects in #dst to insert before any others when the vector is
523 /// cloned.
524 CloneableList insert_front_;
525
526 /// A list of objects in #dst to insert befor after any others when the
527 /// vector is cloned.
528 CloneableList insert_back_;
529
530 /// A map of object in #src to the list of cloned objects in #dst.
531 /// Clone(const std::vector<T*>& v) will use this to insert the map-value
532 /// list into the target vector before cloning and inserting the map-key.
533 std::unordered_map<const Cloneable*, CloneableList> insert_before_;
534
535 /// A map of object in #src to the list of cloned objects in #dst.
536 /// Clone(const std::vector<T*>& v) will use this to insert the map-value
537 /// list into the target vector after cloning and inserting the map-key.
538 std::unordered_map<const Cloneable*, CloneableList> insert_after_;
539 };
540
541 /// A map of object in #src to functions that create their replacement in
542 /// #dst
543 std::unordered_map<const Cloneable*, std::function<const Cloneable*()>>
544 replacements_;
545
546 /// A map of symbol in #src to their cloned equivalent in #dst
547 std::unordered_map<Symbol, Symbol> cloned_symbols_;
548
549 /// Cloneable transform functions registered with ReplaceAll()
550 std::vector<CloneableTransform> transforms_;
551
552 /// Map of std::vector pointer to transforms for that list
553 std::unordered_map<const void*, ListTransforms> list_transforms_;
554
555 /// Symbol transform registered with ReplaceAll()
556 SymbolTransform symbol_transform_;
557 };
558
559 } // namespace tint
560
561 #endif // SRC_CLONE_CONTEXT_H_
562