• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Marshallers.h - Generic matcher function marshallers -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Functions templates and classes to wrap matcher construct functions.
11 ///
12 /// A collection of template function and classes that provide a generic
13 /// marshalling layer on top of matcher construct functions.
14 /// These are used by the registry to export all marshaller constructors with
15 /// the same generic interface.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
20 #define LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
21 
22 #include "clang/AST/ASTTypeTraits.h"
23 #include "clang/AST/OperationKinds.h"
24 #include "clang/ASTMatchers/ASTMatchersInternal.h"
25 #include "clang/ASTMatchers/Dynamic/Diagnostics.h"
26 #include "clang/ASTMatchers/Dynamic/VariantValue.h"
27 #include "clang/Basic/AttrKinds.h"
28 #include "clang/Basic/LLVM.h"
29 #include "clang/Basic/OpenMPKinds.h"
30 #include "clang/Basic/TypeTraits.h"
31 #include "llvm/ADT/ArrayRef.h"
32 #include "llvm/ADT/None.h"
33 #include "llvm/ADT/Optional.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Support/Regex.h"
39 #include <cassert>
40 #include <cstddef>
41 #include <iterator>
42 #include <limits>
43 #include <memory>
44 #include <string>
45 #include <utility>
46 #include <vector>
47 
48 namespace clang {
49 namespace ast_matchers {
50 namespace dynamic {
51 namespace internal {
52 
/// Maps a marshalled argument type to the VariantValue accessors used to
/// check and extract it.
///
/// Every specialization below provides hasCorrectType(), hasCorrectValue(),
/// get(), getKind() and getBestGuess(). The primary template is deliberately
/// left undefined, so an argument type without a specialization fails to
/// compile.
template <class T> struct ArgTypeTraits;

/// A `const T &` argument is handled by the traits of `T`.
template <class T>
struct ArgTypeTraits<const T &> : ArgTypeTraits<T> {};
59 
60 template <> struct ArgTypeTraits<std::string> {
61   static bool hasCorrectType(const VariantValue &Value) {
62     return Value.isString();
63   }
64   static bool hasCorrectValue(const VariantValue &Value) { return true; }
65 
66   static const std::string &get(const VariantValue &Value) {
67     return Value.getString();
68   }
69 
70   static ArgKind getKind() {
71     return ArgKind(ArgKind::AK_String);
72   }
73 
74   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
75     return llvm::None;
76   }
77 };
78 
79 template <>
80 struct ArgTypeTraits<StringRef> : public ArgTypeTraits<std::string> {
81 };
82 
83 template <class T> struct ArgTypeTraits<ast_matchers::internal::Matcher<T>> {
84   static bool hasCorrectType(const VariantValue& Value) {
85     return Value.isMatcher();
86   }
87   static bool hasCorrectValue(const VariantValue &Value) {
88     return Value.getMatcher().hasTypedMatcher<T>();
89   }
90 
91   static ast_matchers::internal::Matcher<T> get(const VariantValue &Value) {
92     return Value.getMatcher().getTypedMatcher<T>();
93   }
94 
95   static ArgKind getKind() {
96     return ArgKind(ASTNodeKind::getFromNodeKind<T>());
97   }
98 
99   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
100     return llvm::None;
101   }
102 };
103 
104 template <> struct ArgTypeTraits<bool> {
105   static bool hasCorrectType(const VariantValue &Value) {
106     return Value.isBoolean();
107   }
108   static bool hasCorrectValue(const VariantValue &Value) { return true; }
109 
110   static bool get(const VariantValue &Value) {
111     return Value.getBoolean();
112   }
113 
114   static ArgKind getKind() {
115     return ArgKind(ArgKind::AK_Boolean);
116   }
117 
118   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
119     return llvm::None;
120   }
121 };
122 
123 template <> struct ArgTypeTraits<double> {
124   static bool hasCorrectType(const VariantValue &Value) {
125     return Value.isDouble();
126   }
127   static bool hasCorrectValue(const VariantValue &Value) { return true; }
128 
129   static double get(const VariantValue &Value) {
130     return Value.getDouble();
131   }
132 
133   static ArgKind getKind() {
134     return ArgKind(ArgKind::AK_Double);
135   }
136 
137   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
138     return llvm::None;
139   }
140 };
141 
142 template <> struct ArgTypeTraits<unsigned> {
143   static bool hasCorrectType(const VariantValue &Value) {
144     return Value.isUnsigned();
145   }
146   static bool hasCorrectValue(const VariantValue &Value) { return true; }
147 
148   static unsigned get(const VariantValue &Value) {
149     return Value.getUnsigned();
150   }
151 
152   static ArgKind getKind() {
153     return ArgKind(ArgKind::AK_Unsigned);
154   }
155 
156   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
157     return llvm::None;
158   }
159 };
160 
161 template <> struct ArgTypeTraits<attr::Kind> {
162 private:
163   static Optional<attr::Kind> getAttrKind(llvm::StringRef AttrKind) {
164     return llvm::StringSwitch<Optional<attr::Kind>>(AttrKind)
165 #define ATTR(X) .Case("attr::" #X, attr:: X)
166 #include "clang/Basic/AttrList.inc"
167         .Default(llvm::None);
168   }
169 
170 public:
171   static bool hasCorrectType(const VariantValue &Value) {
172     return Value.isString();
173   }
174   static bool hasCorrectValue(const VariantValue& Value) {
175     return getAttrKind(Value.getString()).hasValue();
176   }
177 
178   static attr::Kind get(const VariantValue &Value) {
179     return *getAttrKind(Value.getString());
180   }
181 
182   static ArgKind getKind() {
183     return ArgKind(ArgKind::AK_String);
184   }
185 
186   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
187 };
188 
189 template <> struct ArgTypeTraits<CastKind> {
190 private:
191   static Optional<CastKind> getCastKind(llvm::StringRef AttrKind) {
192     return llvm::StringSwitch<Optional<CastKind>>(AttrKind)
193 #define CAST_OPERATION(Name) .Case("CK_" #Name, CK_##Name)
194 #include "clang/AST/OperationKinds.def"
195         .Default(llvm::None);
196   }
197 
198 public:
199   static bool hasCorrectType(const VariantValue &Value) {
200     return Value.isString();
201   }
202   static bool hasCorrectValue(const VariantValue& Value) {
203     return getCastKind(Value.getString()).hasValue();
204   }
205 
206   static CastKind get(const VariantValue &Value) {
207     return *getCastKind(Value.getString());
208   }
209 
210   static ArgKind getKind() {
211     return ArgKind(ArgKind::AK_String);
212   }
213 
214   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
215 };
216 
217 template <> struct ArgTypeTraits<llvm::Regex::RegexFlags> {
218 private:
219   static Optional<llvm::Regex::RegexFlags> getFlags(llvm::StringRef Flags);
220 
221 public:
222   static bool hasCorrectType(const VariantValue &Value) {
223     return Value.isString();
224   }
225   static bool hasCorrectValue(const VariantValue& Value) {
226     return getFlags(Value.getString()).hasValue();
227   }
228 
229   static llvm::Regex::RegexFlags get(const VariantValue &Value) {
230     return *getFlags(Value.getString());
231   }
232 
233   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
234 
235   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
236 };
237 
238 template <> struct ArgTypeTraits<OpenMPClauseKind> {
239 private:
240   static Optional<OpenMPClauseKind> getClauseKind(llvm::StringRef ClauseKind) {
241     return llvm::StringSwitch<Optional<OpenMPClauseKind>>(ClauseKind)
242 #define OMP_CLAUSE_CLASS(Enum, Str, Class) .Case(#Enum, llvm::omp::Clause::Enum)
243 #include "llvm/Frontend/OpenMP/OMPKinds.def"
244         .Default(llvm::None);
245   }
246 
247 public:
248   static bool hasCorrectType(const VariantValue &Value) {
249     return Value.isString();
250   }
251   static bool hasCorrectValue(const VariantValue& Value) {
252     return getClauseKind(Value.getString()).hasValue();
253   }
254 
255   static OpenMPClauseKind get(const VariantValue &Value) {
256     return *getClauseKind(Value.getString());
257   }
258 
259   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
260 
261   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
262 };
263 
264 template <> struct ArgTypeTraits<UnaryExprOrTypeTrait> {
265 private:
266   static Optional<UnaryExprOrTypeTrait>
267   getUnaryOrTypeTraitKind(llvm::StringRef ClauseKind) {
268     return llvm::StringSwitch<Optional<UnaryExprOrTypeTrait>>(ClauseKind)
269 #define UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key)                          \
270   .Case("UETT_" #Name, UETT_##Name)
271 #define CXX11_UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key)                    \
272   .Case("UETT_" #Name, UETT_##Name)
273 #include "clang/Basic/TokenKinds.def"
274         .Default(llvm::None);
275   }
276 
277 public:
278   static bool hasCorrectType(const VariantValue &Value) {
279     return Value.isString();
280   }
281   static bool hasCorrectValue(const VariantValue& Value) {
282     return getUnaryOrTypeTraitKind(Value.getString()).hasValue();
283   }
284 
285   static UnaryExprOrTypeTrait get(const VariantValue &Value) {
286     return *getUnaryOrTypeTraitKind(Value.getString());
287   }
288 
289   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
290 
291   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
292 };
293 
294 /// Matcher descriptor interface.
295 ///
296 /// Provides a \c create() method that constructs the matcher from the provided
297 /// arguments, and various other methods for type introspection.
298 class MatcherDescriptor {
299 public:
300   virtual ~MatcherDescriptor() = default;
301 
302   virtual VariantMatcher create(SourceRange NameRange,
303                                 ArrayRef<ParserValue> Args,
304                                 Diagnostics *Error) const = 0;
305 
306   /// Returns whether the matcher is variadic. Variadic matchers can take any
307   /// number of arguments, but they must be of the same type.
308   virtual bool isVariadic() const = 0;
309 
310   /// Returns the number of arguments accepted by the matcher if not variadic.
311   virtual unsigned getNumArgs() const = 0;
312 
313   /// Given that the matcher is being converted to type \p ThisKind, append the
314   /// set of argument types accepted for argument \p ArgNo to \p ArgKinds.
315   // FIXME: We should provide the ability to constrain the output of this
316   // function based on the types of other matcher arguments.
317   virtual void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
318                            std::vector<ArgKind> &ArgKinds) const = 0;
319 
320   /// Returns whether this matcher is convertible to the given type.  If it is
321   /// so convertible, store in *Specificity a value corresponding to the
322   /// "specificity" of the converted matcher to the given context, and in
323   /// *LeastDerivedKind the least derived matcher kind which would result in the
324   /// same matcher overload.  Zero specificity indicates that this conversion
325   /// would produce a trivial matcher that will either always or never match.
326   /// Such matchers are excluded from code completion results.
327   virtual bool
328   isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity = nullptr,
329                   ASTNodeKind *LeastDerivedKind = nullptr) const = 0;
330 
331   /// Returns whether the matcher will, given a matcher of any type T, yield a
332   /// matcher of type T.
333   virtual bool isPolymorphic() const { return false; }
334 };
335 
336 inline bool isRetKindConvertibleTo(ArrayRef<ASTNodeKind> RetKinds,
337                                    ASTNodeKind Kind, unsigned *Specificity,
338                                    ASTNodeKind *LeastDerivedKind) {
339   for (const ASTNodeKind &NodeKind : RetKinds) {
340     if (ArgKind(NodeKind).isConvertibleTo(Kind, Specificity)) {
341       if (LeastDerivedKind)
342         *LeastDerivedKind = NodeKind;
343       return true;
344     }
345   }
346   return false;
347 }
348 
349 /// Simple callback implementation. Marshaller and function are provided.
350 ///
351 /// This class wraps a function of arbitrary signature and a marshaller
352 /// function into a MatcherDescriptor.
353 /// The marshaller is in charge of taking the VariantValue arguments, checking
354 /// their types, unpacking them and calling the underlying function.
355 class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
356 public:
357   using MarshallerType = VariantMatcher (*)(void (*Func)(),
358                                             StringRef MatcherName,
359                                             SourceRange NameRange,
360                                             ArrayRef<ParserValue> Args,
361                                             Diagnostics *Error);
362 
363   /// \param Marshaller Function to unpack the arguments and call \c Func
364   /// \param Func Matcher construct function. This is the function that
365   ///   compile-time matcher expressions would use to create the matcher.
366   /// \param RetKinds The list of matcher types to which the matcher is
367   ///   convertible.
368   /// \param ArgKinds The types of the arguments this matcher takes.
369   FixedArgCountMatcherDescriptor(MarshallerType Marshaller, void (*Func)(),
370                                  StringRef MatcherName,
371                                  ArrayRef<ASTNodeKind> RetKinds,
372                                  ArrayRef<ArgKind> ArgKinds)
373       : Marshaller(Marshaller), Func(Func), MatcherName(MatcherName),
374         RetKinds(RetKinds.begin(), RetKinds.end()),
375         ArgKinds(ArgKinds.begin(), ArgKinds.end()) {}
376 
377   VariantMatcher create(SourceRange NameRange,
378                         ArrayRef<ParserValue> Args,
379                         Diagnostics *Error) const override {
380     return Marshaller(Func, MatcherName, NameRange, Args, Error);
381   }
382 
383   bool isVariadic() const override { return false; }
384   unsigned getNumArgs() const override { return ArgKinds.size(); }
385 
386   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
387                    std::vector<ArgKind> &Kinds) const override {
388     Kinds.push_back(ArgKinds[ArgNo]);
389   }
390 
391   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
392                        ASTNodeKind *LeastDerivedKind) const override {
393     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
394                                   LeastDerivedKind);
395   }
396 
397 private:
398   const MarshallerType Marshaller;
399   void (* const Func)();
400   const std::string MatcherName;
401   const std::vector<ASTNodeKind> RetKinds;
402   const std::vector<ArgKind> ArgKinds;
403 };
404 
405 /// Helper methods to extract and merge all possible typed matchers
406 /// out of the polymorphic object.
407 template <class PolyMatcher>
408 static void mergePolyMatchers(const PolyMatcher &Poly,
409                               std::vector<DynTypedMatcher> &Out,
410                               ast_matchers::internal::EmptyTypeList) {}
411 
412 template <class PolyMatcher, class TypeList>
413 static void mergePolyMatchers(const PolyMatcher &Poly,
414                               std::vector<DynTypedMatcher> &Out, TypeList) {
415   Out.push_back(ast_matchers::internal::Matcher<typename TypeList::head>(Poly));
416   mergePolyMatchers(Poly, Out, typename TypeList::tail());
417 }
418 
419 /// Convert the return values of the functions into a VariantMatcher.
420 ///
421 /// There are 2 cases right now: The return value is a Matcher<T> or is a
422 /// polymorphic matcher. For the former, we just construct the VariantMatcher.
423 /// For the latter, we instantiate all the possible Matcher<T> of the poly
424 /// matcher.
425 inline VariantMatcher outvalueToVariantMatcher(const DynTypedMatcher &Matcher) {
426   return VariantMatcher::SingleMatcher(Matcher);
427 }
428 
429 template <typename T>
430 static VariantMatcher outvalueToVariantMatcher(const T &PolyMatcher,
431                                                typename T::ReturnTypes * =
432                                                    nullptr) {
433   std::vector<DynTypedMatcher> Matchers;
434   mergePolyMatchers(PolyMatcher, Matchers, typename T::ReturnTypes());
435   VariantMatcher Out = VariantMatcher::PolymorphicMatcher(std::move(Matchers));
436   return Out;
437 }
438 
439 template <typename T>
440 inline void
441 buildReturnTypeVectorFromTypeList(std::vector<ASTNodeKind> &RetTypes) {
442   RetTypes.push_back(ASTNodeKind::getFromNodeKind<typename T::head>());
443   buildReturnTypeVectorFromTypeList<typename T::tail>(RetTypes);
444 }
445 
446 template <>
447 inline void
448 buildReturnTypeVectorFromTypeList<ast_matchers::internal::EmptyTypeList>(
449     std::vector<ASTNodeKind> &RetTypes) {}
450 
451 template <typename T>
452 struct BuildReturnTypeVector {
453   static void build(std::vector<ASTNodeKind> &RetTypes) {
454     buildReturnTypeVectorFromTypeList<typename T::ReturnTypes>(RetTypes);
455   }
456 };
457 
458 template <typename T>
459 struct BuildReturnTypeVector<ast_matchers::internal::Matcher<T>> {
460   static void build(std::vector<ASTNodeKind> &RetTypes) {
461     RetTypes.push_back(ASTNodeKind::getFromNodeKind<T>());
462   }
463 };
464 
465 template <typename T>
466 struct BuildReturnTypeVector<ast_matchers::internal::BindableMatcher<T>> {
467   static void build(std::vector<ASTNodeKind> &RetTypes) {
468     RetTypes.push_back(ASTNodeKind::getFromNodeKind<T>());
469   }
470 };
471 
472 /// Variadic marshaller function.
473 template <typename ResultT, typename ArgT,
474           ResultT (*Func)(ArrayRef<const ArgT *>)>
475 VariantMatcher
476 variadicMatcherDescriptor(StringRef MatcherName, SourceRange NameRange,
477                           ArrayRef<ParserValue> Args, Diagnostics *Error) {
478   ArgT **InnerArgs = new ArgT *[Args.size()]();
479 
480   bool HasError = false;
481   for (size_t i = 0, e = Args.size(); i != e; ++i) {
482     using ArgTraits = ArgTypeTraits<ArgT>;
483 
484     const ParserValue &Arg = Args[i];
485     const VariantValue &Value = Arg.Value;
486     if (!ArgTraits::hasCorrectType(Value)) {
487       Error->addError(Arg.Range, Error->ET_RegistryWrongArgType)
488           << (i + 1) << ArgTraits::getKind().asString() << Value.getTypeAsString();
489       HasError = true;
490       break;
491     }
492     if (!ArgTraits::hasCorrectValue(Value)) {
493       if (llvm::Optional<std::string> BestGuess =
494               ArgTraits::getBestGuess(Value)) {
495         Error->addError(Arg.Range, Error->ET_RegistryUnknownEnumWithReplace)
496             << i + 1 << Value.getString() << *BestGuess;
497       } else if (Value.isString()) {
498         Error->addError(Arg.Range, Error->ET_RegistryValueNotFound)
499             << Value.getString();
500       } else {
501         // This isn't ideal, but it's better than reporting an empty string as
502         // the error in this case.
503         Error->addError(Arg.Range, Error->ET_RegistryWrongArgType)
504             << (i + 1) << ArgTraits::getKind().asString()
505             << Value.getTypeAsString();
506       }
507       HasError = true;
508       break;
509     }
510 
511     InnerArgs[i] = new ArgT(ArgTraits::get(Value));
512   }
513 
514   VariantMatcher Out;
515   if (!HasError) {
516     Out = outvalueToVariantMatcher(Func(llvm::makeArrayRef(InnerArgs,
517                                                            Args.size())));
518   }
519 
520   for (size_t i = 0, e = Args.size(); i != e; ++i) {
521     delete InnerArgs[i];
522   }
523   delete[] InnerArgs;
524   return Out;
525 }
526 
527 /// Matcher descriptor for variadic functions.
528 ///
529 /// This class simply wraps a VariadicFunction with the right signature to export
530 /// it as a MatcherDescriptor.
531 /// This allows us to have one implementation of the interface for as many free
532 /// functions as we want, reducing the number of symbols and size of the
533 /// object file.
534 class VariadicFuncMatcherDescriptor : public MatcherDescriptor {
535 public:
536   using RunFunc = VariantMatcher (*)(StringRef MatcherName,
537                                      SourceRange NameRange,
538                                      ArrayRef<ParserValue> Args,
539                                      Diagnostics *Error);
540 
541   template <typename ResultT, typename ArgT,
542             ResultT (*F)(ArrayRef<const ArgT *>)>
543   VariadicFuncMatcherDescriptor(
544       ast_matchers::internal::VariadicFunction<ResultT, ArgT, F> Func,
545       StringRef MatcherName)
546       : Func(&variadicMatcherDescriptor<ResultT, ArgT, F>),
547         MatcherName(MatcherName.str()),
548         ArgsKind(ArgTypeTraits<ArgT>::getKind()) {
549     BuildReturnTypeVector<ResultT>::build(RetKinds);
550   }
551 
552   VariantMatcher create(SourceRange NameRange,
553                         ArrayRef<ParserValue> Args,
554                         Diagnostics *Error) const override {
555     return Func(MatcherName, NameRange, Args, Error);
556   }
557 
558   bool isVariadic() const override { return true; }
559   unsigned getNumArgs() const override { return 0; }
560 
561   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
562                    std::vector<ArgKind> &Kinds) const override {
563     Kinds.push_back(ArgsKind);
564   }
565 
566   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
567                        ASTNodeKind *LeastDerivedKind) const override {
568     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
569                                   LeastDerivedKind);
570   }
571 
572 private:
573   const RunFunc Func;
574   const std::string MatcherName;
575   std::vector<ASTNodeKind> RetKinds;
576   const ArgKind ArgsKind;
577 };
578 
579 /// Return CK_Trivial when appropriate for VariadicDynCastAllOfMatchers.
580 class DynCastAllOfMatcherDescriptor : public VariadicFuncMatcherDescriptor {
581 public:
582   template <typename BaseT, typename DerivedT>
583   DynCastAllOfMatcherDescriptor(
584       ast_matchers::internal::VariadicDynCastAllOfMatcher<BaseT, DerivedT> Func,
585       StringRef MatcherName)
586       : VariadicFuncMatcherDescriptor(Func, MatcherName),
587         DerivedKind(ASTNodeKind::getFromNodeKind<DerivedT>()) {}
588 
589   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
590                        ASTNodeKind *LeastDerivedKind) const override {
591     // If Kind is not a base of DerivedKind, either DerivedKind is a base of
592     // Kind (in which case the match will always succeed) or Kind and
593     // DerivedKind are unrelated (in which case it will always fail), so set
594     // Specificity to 0.
595     if (VariadicFuncMatcherDescriptor::isConvertibleTo(Kind, Specificity,
596                                                  LeastDerivedKind)) {
597       if (Kind.isSame(DerivedKind) || !Kind.isBaseOf(DerivedKind)) {
598         if (Specificity)
599           *Specificity = 0;
600       }
601       return true;
602     } else {
603       return false;
604     }
605   }
606 
607 private:
608   const ASTNodeKind DerivedKind;
609 };
610 
/// Helper macros to check the arguments on all marshaller functions.
///
/// They expect `Args` (ArrayRef<ParserValue>), `NameRange` (SourceRange) and
/// `Error` (Diagnostics *) to be in scope. When a check fails, they report a
/// diagnostic and make the enclosing function return a null VariantMatcher.
/// Each one expands to a single `do { ... } while (false)` statement, so it
/// behaves like one statement inside if/else and requires a trailing `;`.
#define CHECK_ARG_COUNT(count)                                                 \
  do {                                                                         \
    if (Args.size() != (count)) {                                              \
      Error->addError(NameRange, Error->ET_RegistryWrongArgCount)              \
          << (count) << Args.size();                                           \
      return VariantMatcher();                                                 \
    }                                                                          \
  } while (false)

/// Validates that Args[index] has the type and value expected by `type`.
/// A rejected value that is not a string (e.g. a matcher that cannot be typed
/// as the requested Matcher<T>) is reported as a wrong-type error, as in
/// variadicMatcherDescriptor, instead of failing with no diagnostic at all.
#define CHECK_ARG_TYPE(index, type)                                            \
  do {                                                                         \
    if (!ArgTypeTraits<type>::hasCorrectType(Args[index].Value)) {             \
      Error->addError(Args[index].Range, Error->ET_RegistryWrongArgType)       \
          << ((index) + 1) << ArgTypeTraits<type>::getKind().asString()        \
          << Args[index].Value.getTypeAsString();                              \
      return VariantMatcher();                                                 \
    }                                                                          \
    if (!ArgTypeTraits<type>::hasCorrectValue(Args[index].Value)) {            \
      if (llvm::Optional<std::string> BestGuess =                              \
              ArgTypeTraits<type>::getBestGuess(Args[index].Value)) {          \
        Error->addError(Args[index].Range,                                     \
                        Error->ET_RegistryUnknownEnumWithReplace)              \
            << ((index) + 1) << Args[index].Value.getString() << *BestGuess;   \
      } else if (Args[index].Value.isString()) {                               \
        Error->addError(Args[index].Range, Error->ET_RegistryValueNotFound)    \
            << Args[index].Value.getString();                                  \
      } else {                                                                 \
        Error->addError(Args[index].Range, Error->ET_RegistryWrongArgType)     \
            << ((index) + 1) << ArgTypeTraits<type>::getKind().asString()      \
            << Args[index].Value.getTypeAsString();                            \
      }                                                                        \
      return VariantMatcher();                                                 \
    }                                                                          \
  } while (false)
638 
639 /// 0-arg marshaller function.
640 template <typename ReturnType>
641 static VariantMatcher matcherMarshall0(void (*Func)(), StringRef MatcherName,
642                                        SourceRange NameRange,
643                                        ArrayRef<ParserValue> Args,
644                                        Diagnostics *Error) {
645   using FuncType = ReturnType (*)();
646   CHECK_ARG_COUNT(0);
647   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)());
648 }
649 
650 /// 1-arg marshaller function.
651 template <typename ReturnType, typename ArgType1>
652 static VariantMatcher matcherMarshall1(void (*Func)(), StringRef MatcherName,
653                                        SourceRange NameRange,
654                                        ArrayRef<ParserValue> Args,
655                                        Diagnostics *Error) {
656   using FuncType = ReturnType (*)(ArgType1);
657   CHECK_ARG_COUNT(1);
658   CHECK_ARG_TYPE(0, ArgType1);
659   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)(
660       ArgTypeTraits<ArgType1>::get(Args[0].Value)));
661 }
662 
663 /// 2-arg marshaller function.
664 template <typename ReturnType, typename ArgType1, typename ArgType2>
665 static VariantMatcher matcherMarshall2(void (*Func)(), StringRef MatcherName,
666                                        SourceRange NameRange,
667                                        ArrayRef<ParserValue> Args,
668                                        Diagnostics *Error) {
669   using FuncType = ReturnType (*)(ArgType1, ArgType2);
670   CHECK_ARG_COUNT(2);
671   CHECK_ARG_TYPE(0, ArgType1);
672   CHECK_ARG_TYPE(1, ArgType2);
673   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)(
674       ArgTypeTraits<ArgType1>::get(Args[0].Value),
675       ArgTypeTraits<ArgType2>::get(Args[1].Value)));
676 }
677 
678 #undef CHECK_ARG_COUNT
679 #undef CHECK_ARG_TYPE
680 
681 /// Helper class used to collect all the possible overloads of an
682 ///   argument adaptative matcher function.
683 template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
684           typename FromTypes, typename ToTypes>
685 class AdaptativeOverloadCollector {
686 public:
687   AdaptativeOverloadCollector(
688       StringRef Name, std::vector<std::unique_ptr<MatcherDescriptor>> &Out)
689       : Name(Name), Out(Out) {
690     collect(FromTypes());
691   }
692 
693 private:
694   using AdaptativeFunc = ast_matchers::internal::ArgumentAdaptingMatcherFunc<
695       ArgumentAdapterT, FromTypes, ToTypes>;
696 
697   /// End case for the recursion
698   static void collect(ast_matchers::internal::EmptyTypeList) {}
699 
700   /// Recursive case. Get the overload for the head of the list, and
701   ///   recurse to the tail.
702   template <typename FromTypeList>
703   inline void collect(FromTypeList);
704 
705   StringRef Name;
706   std::vector<std::unique_ptr<MatcherDescriptor>> &Out;
707 };
708 
709 /// MatcherDescriptor that wraps multiple "overloads" of the same
710 ///   matcher.
711 ///
712 /// It will try every overload and generate appropriate errors for when none or
713 /// more than one overloads match the arguments.
714 class OverloadedMatcherDescriptor : public MatcherDescriptor {
715 public:
716   OverloadedMatcherDescriptor(
717       MutableArrayRef<std::unique_ptr<MatcherDescriptor>> Callbacks)
718       : Overloads(std::make_move_iterator(Callbacks.begin()),
719                   std::make_move_iterator(Callbacks.end())) {}
720 
721   ~OverloadedMatcherDescriptor() override = default;
722 
723   VariantMatcher create(SourceRange NameRange,
724                         ArrayRef<ParserValue> Args,
725                         Diagnostics *Error) const override {
726     std::vector<VariantMatcher> Constructed;
727     Diagnostics::OverloadContext Ctx(Error);
728     for (const auto &O : Overloads) {
729       VariantMatcher SubMatcher = O->create(NameRange, Args, Error);
730       if (!SubMatcher.isNull()) {
731         Constructed.push_back(SubMatcher);
732       }
733     }
734 
735     if (Constructed.empty()) return VariantMatcher(); // No overload matched.
736     // We ignore the errors if any matcher succeeded.
737     Ctx.revertErrors();
738     if (Constructed.size() > 1) {
739       // More than one constructed. It is ambiguous.
740       Error->addError(NameRange, Error->ET_RegistryAmbiguousOverload);
741       return VariantMatcher();
742     }
743     return Constructed[0];
744   }
745 
746   bool isVariadic() const override {
747     bool Overload0Variadic = Overloads[0]->isVariadic();
748 #ifndef NDEBUG
749     for (const auto &O : Overloads) {
750       assert(Overload0Variadic == O->isVariadic());
751     }
752 #endif
753     return Overload0Variadic;
754   }
755 
756   unsigned getNumArgs() const override {
757     unsigned Overload0NumArgs = Overloads[0]->getNumArgs();
758 #ifndef NDEBUG
759     for (const auto &O : Overloads) {
760       assert(Overload0NumArgs == O->getNumArgs());
761     }
762 #endif
763     return Overload0NumArgs;
764   }
765 
766   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
767                    std::vector<ArgKind> &Kinds) const override {
768     for (const auto &O : Overloads) {
769       if (O->isConvertibleTo(ThisKind))
770         O->getArgKinds(ThisKind, ArgNo, Kinds);
771     }
772   }
773 
774   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
775                        ASTNodeKind *LeastDerivedKind) const override {
776     for (const auto &O : Overloads) {
777       if (O->isConvertibleTo(Kind, Specificity, LeastDerivedKind))
778         return true;
779     }
780     return false;
781   }
782 
783 private:
784   std::vector<std::unique_ptr<MatcherDescriptor>> Overloads;
785 };
786 
787 template <typename ReturnType>
788 class RegexMatcherDescriptor : public MatcherDescriptor {
789 public:
790   RegexMatcherDescriptor(ReturnType (*WithFlags)(StringRef,
791                                                  llvm::Regex::RegexFlags),
792                          ReturnType (*NoFlags)(StringRef),
793                          ArrayRef<ASTNodeKind> RetKinds)
794       : WithFlags(WithFlags), NoFlags(NoFlags),
795         RetKinds(RetKinds.begin(), RetKinds.end()) {}
796   bool isVariadic() const override { return true; }
797   unsigned getNumArgs() const override { return 0; }
798 
799   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
800                    std::vector<ArgKind> &Kinds) const override {
801     assert(ArgNo < 2);
802     Kinds.push_back(ArgKind::AK_String);
803   }
804 
805   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
806                        ASTNodeKind *LeastDerivedKind) const override {
807     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
808                                   LeastDerivedKind);
809   }
810 
811   VariantMatcher create(SourceRange NameRange, ArrayRef<ParserValue> Args,
812                         Diagnostics *Error) const override {
813     if (Args.size() < 1 || Args.size() > 2) {
814       Error->addError(NameRange, Diagnostics::ET_RegistryWrongArgCount)
815           << "1 or 2" << Args.size();
816       return VariantMatcher();
817     }
818     if (!ArgTypeTraits<StringRef>::hasCorrectType(Args[0].Value)) {
819       Error->addError(Args[0].Range, Error->ET_RegistryWrongArgType)
820           << 1 << ArgTypeTraits<StringRef>::getKind().asString()
821           << Args[0].Value.getTypeAsString();
822       return VariantMatcher();
823     }
824     if (Args.size() == 1) {
825       return outvalueToVariantMatcher(
826           NoFlags(ArgTypeTraits<StringRef>::get(Args[0].Value)));
827     }
828     if (!ArgTypeTraits<llvm::Regex::RegexFlags>::hasCorrectType(
829             Args[1].Value)) {
830       Error->addError(Args[1].Range, Error->ET_RegistryWrongArgType)
831           << 2 << ArgTypeTraits<llvm::Regex::RegexFlags>::getKind().asString()
832           << Args[1].Value.getTypeAsString();
833       return VariantMatcher();
834     }
835     if (!ArgTypeTraits<llvm::Regex::RegexFlags>::hasCorrectValue(
836             Args[1].Value)) {
837       if (llvm::Optional<std::string> BestGuess =
838               ArgTypeTraits<llvm::Regex::RegexFlags>::getBestGuess(
839                   Args[1].Value)) {
840         Error->addError(Args[1].Range, Error->ET_RegistryUnknownEnumWithReplace)
841             << 2 << Args[1].Value.getString() << *BestGuess;
842       } else {
843         Error->addError(Args[1].Range, Error->ET_RegistryValueNotFound)
844             << Args[1].Value.getString();
845       }
846       return VariantMatcher();
847     }
848     return outvalueToVariantMatcher(
849         WithFlags(ArgTypeTraits<StringRef>::get(Args[0].Value),
850                   ArgTypeTraits<llvm::Regex::RegexFlags>::get(Args[1].Value)));
851   }
852 
853 private:
854   ReturnType (*const WithFlags)(StringRef, llvm::Regex::RegexFlags);
855   ReturnType (*const NoFlags)(StringRef);
856   const std::vector<ASTNodeKind> RetKinds;
857 };
858 
859 /// Variadic operator marshaller function.
860 class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
861 public:
862   using VarOp = DynTypedMatcher::VariadicOperator;
863 
864   VariadicOperatorMatcherDescriptor(unsigned MinCount, unsigned MaxCount,
865                                     VarOp Op, StringRef MatcherName)
866       : MinCount(MinCount), MaxCount(MaxCount), Op(Op),
867         MatcherName(MatcherName) {}
868 
869   VariantMatcher create(SourceRange NameRange,
870                         ArrayRef<ParserValue> Args,
871                         Diagnostics *Error) const override {
872     if (Args.size() < MinCount || MaxCount < Args.size()) {
873       const std::string MaxStr =
874           (MaxCount == std::numeric_limits<unsigned>::max() ? ""
875                                                             : Twine(MaxCount))
876               .str();
877       Error->addError(NameRange, Error->ET_RegistryWrongArgCount)
878           << ("(" + Twine(MinCount) + ", " + MaxStr + ")") << Args.size();
879       return VariantMatcher();
880     }
881 
882     std::vector<VariantMatcher> InnerArgs;
883     for (size_t i = 0, e = Args.size(); i != e; ++i) {
884       const ParserValue &Arg = Args[i];
885       const VariantValue &Value = Arg.Value;
886       if (!Value.isMatcher()) {
887         Error->addError(Arg.Range, Error->ET_RegistryWrongArgType)
888             << (i + 1) << "Matcher<>" << Value.getTypeAsString();
889         return VariantMatcher();
890       }
891       InnerArgs.push_back(Value.getMatcher());
892     }
893     return VariantMatcher::VariadicOperatorMatcher(Op, std::move(InnerArgs));
894   }
895 
896   bool isVariadic() const override { return true; }
897   unsigned getNumArgs() const override { return 0; }
898 
899   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
900                    std::vector<ArgKind> &Kinds) const override {
901     Kinds.push_back(ThisKind);
902   }
903 
904   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
905                        ASTNodeKind *LeastDerivedKind) const override {
906     if (Specificity)
907       *Specificity = 1;
908     if (LeastDerivedKind)
909       *LeastDerivedKind = Kind;
910     return true;
911   }
912 
913   bool isPolymorphic() const override { return true; }
914 
915 private:
916   const unsigned MinCount;
917   const unsigned MaxCount;
918   const VarOp Op;
919   const StringRef MatcherName;
920 };
921 
/// Helper functions to select the appropriate marshaller functions.
/// They detect the number of arguments, the argument types and the return
/// type.
924 
925 /// 0-arg overload
926 template <typename ReturnType>
927 std::unique_ptr<MatcherDescriptor>
928 makeMatcherAutoMarshall(ReturnType (*Func)(), StringRef MatcherName) {
929   std::vector<ASTNodeKind> RetTypes;
930   BuildReturnTypeVector<ReturnType>::build(RetTypes);
931   return std::make_unique<FixedArgCountMatcherDescriptor>(
932       matcherMarshall0<ReturnType>, reinterpret_cast<void (*)()>(Func),
933       MatcherName, RetTypes, None);
934 }
935 
936 /// 1-arg overload
937 template <typename ReturnType, typename ArgType1>
938 std::unique_ptr<MatcherDescriptor>
939 makeMatcherAutoMarshall(ReturnType (*Func)(ArgType1), StringRef MatcherName) {
940   std::vector<ASTNodeKind> RetTypes;
941   BuildReturnTypeVector<ReturnType>::build(RetTypes);
942   ArgKind AK = ArgTypeTraits<ArgType1>::getKind();
943   return std::make_unique<FixedArgCountMatcherDescriptor>(
944       matcherMarshall1<ReturnType, ArgType1>,
945       reinterpret_cast<void (*)()>(Func), MatcherName, RetTypes, AK);
946 }
947 
948 /// 2-arg overload
949 template <typename ReturnType, typename ArgType1, typename ArgType2>
950 std::unique_ptr<MatcherDescriptor>
951 makeMatcherAutoMarshall(ReturnType (*Func)(ArgType1, ArgType2),
952                         StringRef MatcherName) {
953   std::vector<ASTNodeKind> RetTypes;
954   BuildReturnTypeVector<ReturnType>::build(RetTypes);
955   ArgKind AKs[] = { ArgTypeTraits<ArgType1>::getKind(),
956                     ArgTypeTraits<ArgType2>::getKind() };
957   return std::make_unique<FixedArgCountMatcherDescriptor>(
958       matcherMarshall2<ReturnType, ArgType1, ArgType2>,
959       reinterpret_cast<void (*)()>(Func), MatcherName, RetTypes, AKs);
960 }
961 
962 template <typename ReturnType>
963 std::unique_ptr<MatcherDescriptor> makeMatcherRegexMarshall(
964     ReturnType (*FuncFlags)(llvm::StringRef, llvm::Regex::RegexFlags),
965     ReturnType (*Func)(llvm::StringRef)) {
966   std::vector<ASTNodeKind> RetTypes;
967   BuildReturnTypeVector<ReturnType>::build(RetTypes);
968   return std::make_unique<RegexMatcherDescriptor<ReturnType>>(FuncFlags, Func,
969                                                               RetTypes);
970 }
971 
972 /// Variadic overload.
973 template <typename ResultT, typename ArgT,
974           ResultT (*Func)(ArrayRef<const ArgT *>)>
975 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
976     ast_matchers::internal::VariadicFunction<ResultT, ArgT, Func> VarFunc,
977     StringRef MatcherName) {
978   return std::make_unique<VariadicFuncMatcherDescriptor>(VarFunc, MatcherName);
979 }
980 
981 /// Overload for VariadicDynCastAllOfMatchers.
982 ///
983 /// Not strictly necessary, but DynCastAllOfMatcherDescriptor gives us better
984 /// completion results for that type of matcher.
985 template <typename BaseT, typename DerivedT>
986 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
987     ast_matchers::internal::VariadicDynCastAllOfMatcher<BaseT, DerivedT>
988         VarFunc,
989     StringRef MatcherName) {
990   return std::make_unique<DynCastAllOfMatcherDescriptor>(VarFunc, MatcherName);
991 }
992 
993 /// Argument adaptative overload.
994 template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
995           typename FromTypes, typename ToTypes>
996 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
997     ast_matchers::internal::ArgumentAdaptingMatcherFunc<ArgumentAdapterT,
998                                                         FromTypes, ToTypes>,
999     StringRef MatcherName) {
1000   std::vector<std::unique_ptr<MatcherDescriptor>> Overloads;
1001   AdaptativeOverloadCollector<ArgumentAdapterT, FromTypes, ToTypes>(MatcherName,
1002                                                                     Overloads);
1003   return std::make_unique<OverloadedMatcherDescriptor>(Overloads);
1004 }
1005 
1006 template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
1007           typename FromTypes, typename ToTypes>
1008 template <typename FromTypeList>
1009 inline void AdaptativeOverloadCollector<ArgumentAdapterT, FromTypes,
1010                                         ToTypes>::collect(FromTypeList) {
1011   Out.push_back(makeMatcherAutoMarshall(
1012       &AdaptativeFunc::template create<typename FromTypeList::head>, Name));
1013   collect(typename FromTypeList::tail());
1014 }
1015 
1016 /// Variadic operator overload.
1017 template <unsigned MinCount, unsigned MaxCount>
1018 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
1019     ast_matchers::internal::VariadicOperatorMatcherFunc<MinCount, MaxCount>
1020         Func,
1021     StringRef MatcherName) {
1022   return std::make_unique<VariadicOperatorMatcherDescriptor>(
1023       MinCount, MaxCount, Func.Op, MatcherName);
1024 }
1025 
1026 } // namespace internal
1027 } // namespace dynamic
1028 } // namespace ast_matchers
1029 } // namespace clang
1030 
#endif // LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
1032