1 //===--- DurationRewriter.cpp - clang-tidy --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "DurationRewriter.h"
10 #include "clang/Tooling/FixIt.h"
11 #include "llvm/ADT/IndexedMap.h"
12
13 using namespace clang::ast_matchers;
14
15 namespace clang {
16 namespace tidy {
17 namespace abseil {
18
19 struct DurationScale2IndexFunctor {
20 using argument_type = DurationScale;
operator ()clang::tidy::abseil::DurationScale2IndexFunctor21 unsigned operator()(DurationScale Scale) const {
22 return static_cast<unsigned>(Scale);
23 }
24 };
25
26 /// Returns an integer if the fractional part of a `FloatingLiteral` is `0`.
27 static llvm::Optional<llvm::APSInt>
truncateIfIntegral(const FloatingLiteral & FloatLiteral)28 truncateIfIntegral(const FloatingLiteral &FloatLiteral) {
29 double Value = FloatLiteral.getValueAsApproximateDouble();
30 if (std::fmod(Value, 1) == 0) {
31 if (Value >= static_cast<double>(1u << 31))
32 return llvm::None;
33
34 return llvm::APSInt::get(static_cast<int64_t>(Value));
35 }
36 return llvm::None;
37 }
38
39 const std::pair<llvm::StringRef, llvm::StringRef> &
getDurationInverseForScale(DurationScale Scale)40 getDurationInverseForScale(DurationScale Scale) {
41 static const llvm::IndexedMap<std::pair<llvm::StringRef, llvm::StringRef>,
42 DurationScale2IndexFunctor>
43 InverseMap = []() {
44 // TODO: Revisit the immediately invoked lambda technique when
45 // IndexedMap gets an initializer list constructor.
46 llvm::IndexedMap<std::pair<llvm::StringRef, llvm::StringRef>,
47 DurationScale2IndexFunctor>
48 InverseMap;
49 InverseMap.resize(6);
50 InverseMap[DurationScale::Hours] =
51 std::make_pair("::absl::ToDoubleHours", "::absl::ToInt64Hours");
52 InverseMap[DurationScale::Minutes] =
53 std::make_pair("::absl::ToDoubleMinutes", "::absl::ToInt64Minutes");
54 InverseMap[DurationScale::Seconds] =
55 std::make_pair("::absl::ToDoubleSeconds", "::absl::ToInt64Seconds");
56 InverseMap[DurationScale::Milliseconds] = std::make_pair(
57 "::absl::ToDoubleMilliseconds", "::absl::ToInt64Milliseconds");
58 InverseMap[DurationScale::Microseconds] = std::make_pair(
59 "::absl::ToDoubleMicroseconds", "::absl::ToInt64Microseconds");
60 InverseMap[DurationScale::Nanoseconds] = std::make_pair(
61 "::absl::ToDoubleNanoseconds", "::absl::ToInt64Nanoseconds");
62 return InverseMap;
63 }();
64
65 return InverseMap[Scale];
66 }
67
68 /// If `Node` is a call to the inverse of `Scale`, return that inverse's
69 /// argument, otherwise None.
70 static llvm::Optional<std::string>
rewriteInverseDurationCall(const MatchFinder::MatchResult & Result,DurationScale Scale,const Expr & Node)71 rewriteInverseDurationCall(const MatchFinder::MatchResult &Result,
72 DurationScale Scale, const Expr &Node) {
73 const std::pair<llvm::StringRef, llvm::StringRef> &InverseFunctions =
74 getDurationInverseForScale(Scale);
75 if (const auto *MaybeCallArg = selectFirst<const Expr>(
76 "e",
77 match(callExpr(callee(functionDecl(hasAnyName(
78 InverseFunctions.first, InverseFunctions.second))),
79 hasArgument(0, expr().bind("e"))),
80 Node, *Result.Context))) {
81 return tooling::fixit::getText(*MaybeCallArg, *Result.Context).str();
82 }
83
84 return llvm::None;
85 }
86
87 /// If `Node` is a call to the inverse of `Scale`, return that inverse's
88 /// argument, otherwise None.
89 static llvm::Optional<std::string>
rewriteInverseTimeCall(const MatchFinder::MatchResult & Result,DurationScale Scale,const Expr & Node)90 rewriteInverseTimeCall(const MatchFinder::MatchResult &Result,
91 DurationScale Scale, const Expr &Node) {
92 llvm::StringRef InverseFunction = getTimeInverseForScale(Scale);
93 if (const auto *MaybeCallArg = selectFirst<const Expr>(
94 "e", match(callExpr(callee(functionDecl(hasName(InverseFunction))),
95 hasArgument(0, expr().bind("e"))),
96 Node, *Result.Context))) {
97 return tooling::fixit::getText(*MaybeCallArg, *Result.Context).str();
98 }
99
100 return llvm::None;
101 }
102
103 /// Returns the factory function name for a given `Scale`.
getDurationFactoryForScale(DurationScale Scale)104 llvm::StringRef getDurationFactoryForScale(DurationScale Scale) {
105 switch (Scale) {
106 case DurationScale::Hours:
107 return "absl::Hours";
108 case DurationScale::Minutes:
109 return "absl::Minutes";
110 case DurationScale::Seconds:
111 return "absl::Seconds";
112 case DurationScale::Milliseconds:
113 return "absl::Milliseconds";
114 case DurationScale::Microseconds:
115 return "absl::Microseconds";
116 case DurationScale::Nanoseconds:
117 return "absl::Nanoseconds";
118 }
119 llvm_unreachable("unknown scaling factor");
120 }
121
getTimeFactoryForScale(DurationScale Scale)122 llvm::StringRef getTimeFactoryForScale(DurationScale Scale) {
123 switch (Scale) {
124 case DurationScale::Hours:
125 return "absl::FromUnixHours";
126 case DurationScale::Minutes:
127 return "absl::FromUnixMinutes";
128 case DurationScale::Seconds:
129 return "absl::FromUnixSeconds";
130 case DurationScale::Milliseconds:
131 return "absl::FromUnixMillis";
132 case DurationScale::Microseconds:
133 return "absl::FromUnixMicros";
134 case DurationScale::Nanoseconds:
135 return "absl::FromUnixNanos";
136 }
137 llvm_unreachable("unknown scaling factor");
138 }
139
140 /// Returns the Time factory function name for a given `Scale`.
getTimeInverseForScale(DurationScale scale)141 llvm::StringRef getTimeInverseForScale(DurationScale scale) {
142 switch (scale) {
143 case DurationScale::Hours:
144 return "absl::ToUnixHours";
145 case DurationScale::Minutes:
146 return "absl::ToUnixMinutes";
147 case DurationScale::Seconds:
148 return "absl::ToUnixSeconds";
149 case DurationScale::Milliseconds:
150 return "absl::ToUnixMillis";
151 case DurationScale::Microseconds:
152 return "absl::ToUnixMicros";
153 case DurationScale::Nanoseconds:
154 return "absl::ToUnixNanos";
155 }
156 llvm_unreachable("unknown scaling factor");
157 }
158
159 /// Returns `true` if `Node` is a value which evaluates to a literal `0`.
IsLiteralZero(const MatchFinder::MatchResult & Result,const Expr & Node)160 bool IsLiteralZero(const MatchFinder::MatchResult &Result, const Expr &Node) {
161 auto ZeroMatcher =
162 anyOf(integerLiteral(equals(0)), floatLiteral(equals(0.0)));
163
164 // Check to see if we're using a zero directly.
165 if (selectFirst<const clang::Expr>(
166 "val", match(expr(ignoringImpCasts(ZeroMatcher)).bind("val"), Node,
167 *Result.Context)) != nullptr)
168 return true;
169
170 // Now check to see if we're using a functional cast with a scalar
171 // initializer expression, e.g. `int{0}`.
172 if (selectFirst<const clang::Expr>(
173 "val", match(cxxFunctionalCastExpr(
174 hasDestinationType(
175 anyOf(isInteger(), realFloatingPointType())),
176 hasSourceExpression(initListExpr(
177 hasInit(0, ignoringParenImpCasts(ZeroMatcher)))))
178 .bind("val"),
179 Node, *Result.Context)) != nullptr)
180 return true;
181
182 return false;
183 }
184
185 llvm::Optional<std::string>
stripFloatCast(const ast_matchers::MatchFinder::MatchResult & Result,const Expr & Node)186 stripFloatCast(const ast_matchers::MatchFinder::MatchResult &Result,
187 const Expr &Node) {
188 if (const Expr *MaybeCastArg = selectFirst<const Expr>(
189 "cast_arg",
190 match(expr(anyOf(cxxStaticCastExpr(
191 hasDestinationType(realFloatingPointType()),
192 hasSourceExpression(expr().bind("cast_arg"))),
193 cStyleCastExpr(
194 hasDestinationType(realFloatingPointType()),
195 hasSourceExpression(expr().bind("cast_arg"))),
196 cxxFunctionalCastExpr(
197 hasDestinationType(realFloatingPointType()),
198 hasSourceExpression(expr().bind("cast_arg"))))),
199 Node, *Result.Context)))
200 return tooling::fixit::getText(*MaybeCastArg, *Result.Context).str();
201
202 return llvm::None;
203 }
204
205 llvm::Optional<std::string>
stripFloatLiteralFraction(const MatchFinder::MatchResult & Result,const Expr & Node)206 stripFloatLiteralFraction(const MatchFinder::MatchResult &Result,
207 const Expr &Node) {
208 if (const auto *LitFloat = llvm::dyn_cast<FloatingLiteral>(&Node))
209 // Attempt to simplify a `Duration` factory call with a literal argument.
210 if (llvm::Optional<llvm::APSInt> IntValue = truncateIfIntegral(*LitFloat))
211 return IntValue->toString(/*radix=*/10);
212
213 return llvm::None;
214 }
215
simplifyDurationFactoryArg(const MatchFinder::MatchResult & Result,const Expr & Node)216 std::string simplifyDurationFactoryArg(const MatchFinder::MatchResult &Result,
217 const Expr &Node) {
218 // Check for an explicit cast to `float` or `double`.
219 if (llvm::Optional<std::string> MaybeArg = stripFloatCast(Result, Node))
220 return *MaybeArg;
221
222 // Check for floats without fractional components.
223 if (llvm::Optional<std::string> MaybeArg =
224 stripFloatLiteralFraction(Result, Node))
225 return *MaybeArg;
226
227 // We couldn't simplify any further, so return the argument text.
228 return tooling::fixit::getText(Node, *Result.Context).str();
229 }
230
getScaleForDurationInverse(llvm::StringRef Name)231 llvm::Optional<DurationScale> getScaleForDurationInverse(llvm::StringRef Name) {
232 static const llvm::StringMap<DurationScale> ScaleMap(
233 {{"ToDoubleHours", DurationScale::Hours},
234 {"ToInt64Hours", DurationScale::Hours},
235 {"ToDoubleMinutes", DurationScale::Minutes},
236 {"ToInt64Minutes", DurationScale::Minutes},
237 {"ToDoubleSeconds", DurationScale::Seconds},
238 {"ToInt64Seconds", DurationScale::Seconds},
239 {"ToDoubleMilliseconds", DurationScale::Milliseconds},
240 {"ToInt64Milliseconds", DurationScale::Milliseconds},
241 {"ToDoubleMicroseconds", DurationScale::Microseconds},
242 {"ToInt64Microseconds", DurationScale::Microseconds},
243 {"ToDoubleNanoseconds", DurationScale::Nanoseconds},
244 {"ToInt64Nanoseconds", DurationScale::Nanoseconds}});
245
246 auto ScaleIter = ScaleMap.find(std::string(Name));
247 if (ScaleIter == ScaleMap.end())
248 return llvm::None;
249
250 return ScaleIter->second;
251 }
252
getScaleForTimeInverse(llvm::StringRef Name)253 llvm::Optional<DurationScale> getScaleForTimeInverse(llvm::StringRef Name) {
254 static const llvm::StringMap<DurationScale> ScaleMap(
255 {{"ToUnixHours", DurationScale::Hours},
256 {"ToUnixMinutes", DurationScale::Minutes},
257 {"ToUnixSeconds", DurationScale::Seconds},
258 {"ToUnixMillis", DurationScale::Milliseconds},
259 {"ToUnixMicros", DurationScale::Microseconds},
260 {"ToUnixNanos", DurationScale::Nanoseconds}});
261
262 auto ScaleIter = ScaleMap.find(std::string(Name));
263 if (ScaleIter == ScaleMap.end())
264 return llvm::None;
265
266 return ScaleIter->second;
267 }
268
rewriteExprFromNumberToDuration(const ast_matchers::MatchFinder::MatchResult & Result,DurationScale Scale,const Expr * Node)269 std::string rewriteExprFromNumberToDuration(
270 const ast_matchers::MatchFinder::MatchResult &Result, DurationScale Scale,
271 const Expr *Node) {
272 const Expr &RootNode = *Node->IgnoreParenImpCasts();
273
274 // First check to see if we can undo a complimentary function call.
275 if (llvm::Optional<std::string> MaybeRewrite =
276 rewriteInverseDurationCall(Result, Scale, RootNode))
277 return *MaybeRewrite;
278
279 if (IsLiteralZero(Result, RootNode))
280 return std::string("absl::ZeroDuration()");
281
282 return (llvm::Twine(getDurationFactoryForScale(Scale)) + "(" +
283 simplifyDurationFactoryArg(Result, RootNode) + ")")
284 .str();
285 }
286
rewriteExprFromNumberToTime(const ast_matchers::MatchFinder::MatchResult & Result,DurationScale Scale,const Expr * Node)287 std::string rewriteExprFromNumberToTime(
288 const ast_matchers::MatchFinder::MatchResult &Result, DurationScale Scale,
289 const Expr *Node) {
290 const Expr &RootNode = *Node->IgnoreParenImpCasts();
291
292 // First check to see if we can undo a complimentary function call.
293 if (llvm::Optional<std::string> MaybeRewrite =
294 rewriteInverseTimeCall(Result, Scale, RootNode))
295 return *MaybeRewrite;
296
297 if (IsLiteralZero(Result, RootNode))
298 return std::string("absl::UnixEpoch()");
299
300 return (llvm::Twine(getTimeFactoryForScale(Scale)) + "(" +
301 tooling::fixit::getText(RootNode, *Result.Context) + ")")
302 .str();
303 }
304
isInMacro(const MatchFinder::MatchResult & Result,const Expr * E)305 bool isInMacro(const MatchFinder::MatchResult &Result, const Expr *E) {
306 if (!E->getBeginLoc().isMacroID())
307 return false;
308
309 SourceLocation Loc = E->getBeginLoc();
310 // We want to get closer towards the initial macro typed into the source only
311 // if the location is being expanded as a macro argument.
312 while (Result.SourceManager->isMacroArgExpansion(Loc)) {
313 // We are calling getImmediateMacroCallerLoc, but note it is essentially
314 // equivalent to calling getImmediateSpellingLoc in this context according
315 // to Clang implementation. We are not calling getImmediateSpellingLoc
316 // because Clang comment says it "should not generally be used by clients."
317 Loc = Result.SourceManager->getImmediateMacroCallerLoc(Loc);
318 }
319 return Loc.isMacroID();
320 }
321
322 } // namespace abseil
323 } // namespace tidy
324 } // namespace clang
325