1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <vector>
18
19 #include "absl/algorithm/container.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/platform_util.h"
30 #include "tensorflow/compiler/xla/service/stream_pool.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/regexp.h"
37 #include "tensorflow/core/platform/test.h"
38
39 namespace xla {
40 namespace {
41
42 class HloProfileTest : public ClientLibraryTestBase {};
43
// The columns extracted from one line of the human-readable HLO profile.
//
// The numeric fields are value-initialized so that a default-constructed
// instance never exposes indeterminate values, even if a parse fails before
// they are filled in.
struct ParsedProfileOutputLine {
  // Cycle count reported for the line (the number preceding " cycles").
  int64_t cycles = 0;
  // Raw text found inside the parentheses after the cycle count, e.g.
  // "100.% 100Σ" for the total line.
  std::string cycles_percentage;
  // Time reported for the line (the number preceding " usec").
  double usec = 0.0;
  // Raw text of the FLOP/s column; see HasFlops for the "present" format.
  std::string flops;
  // Raw text of the TROP/s column; see HasTrops for the "present" format.
  std::string trops;
  // Numeric part of the bytes-per-second column (any "B/s" suffix is dropped).
  std::string bytes_per_sec;
  // Numeric part of the bytes-per-cycle column (any "B/cycle" suffix is
  // dropped).
  std::string bytes_per_cycle;
  // HLO opcode of the instruction on the line, or "[total]" for the summary.
  std::string opcode;
};
54
HasFlops(const ParsedProfileOutputLine & parsed_line)55 ::testing::AssertionResult HasFlops(
56 const ParsedProfileOutputLine& parsed_line) {
57 if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) {
58 return ::testing::AssertionSuccess()
59 << "'flops' field present in " << parsed_line.opcode << ": '"
60 << parsed_line.flops << "'";
61 }
62
63 return ::testing::AssertionFailure()
64 << "'flops' field absent in " << parsed_line.opcode << ": '"
65 << parsed_line.flops << "'";
66 }
67
HasTrops(const ParsedProfileOutputLine & parsed_line)68 ::testing::AssertionResult HasTrops(
69 const ParsedProfileOutputLine& parsed_line) {
70 if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) {
71 return ::testing::AssertionSuccess()
72 << "'trops' field present in " << parsed_line.opcode << ": '"
73 << parsed_line.trops << "'";
74 }
75
76 return ::testing::AssertionFailure()
77 << "'trops' field absent in " << parsed_line.opcode << ": '"
78 << parsed_line.trops << "'";
79 }
80
ParseOneProfileOutputLine(const std::string & line,bool expect_hlo,absl::flat_hash_map<std::string,ParsedProfileOutputLine> * parsed_results,absl::Span<const absl::string_view> opcodes_to_ignore={})81 Status ParseOneProfileOutputLine(
82 const std::string& line, bool expect_hlo,
83 absl::flat_hash_map<std::string, ParsedProfileOutputLine>* parsed_results,
84 absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
85 std::string separator = "[^:]*:: +";
86 std::string match_percentage = R"(\d+\.\d*% +\d+Σ)";
87 std::string match_cycles =
88 R"((\d+) cycles +\( *()" + match_percentage + R"()\))";
89 std::string match_usecs = "([0-9.]+) usec";
90 std::string match_flops = "([^ ]*)";
91 std::string match_trops = "([^ ]*)";
92 std::string match_bytes_per_sec = "([0-9.TGMKi]*)(?:B/s)?";
93 std::string match_bytes_per_cycle = "([0-9.TGMKi]*)(?:B/cycle)?";
94
95 // The underlined part is what we're trying to match with match_opcode:
96 //
97 // %dot33 = f32[256,256]{1,0} dot(...)
98 // ^^^
99
100 std::string match_opcode = expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*"
101 : "(\\[total\\])( \\[entry\\])?";
102 std::string regexp_pattern = absl::StrCat(
103 " +", match_cycles, separator, match_usecs, separator, match_flops,
104 separator, match_trops, separator, match_bytes_per_sec, separator,
105 match_bytes_per_cycle, separator, match_opcode);
106
107 ParsedProfileOutputLine parsed_line;
108 bool matched = RE2::FullMatch(
109 line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
110 &parsed_line.usec, &parsed_line.flops, &parsed_line.trops,
111 &parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle,
112 &parsed_line.opcode);
113 if (!matched) {
114 return tensorflow::errors::InvalidArgument(
115 "Input did not match regexp. Input: ", line,
116 ", Regexp: ", regexp_pattern);
117 }
118
119 if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
120 InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
121 }
122
123 return OkStatus();
124 }
125
IsExtraMetricProfileOutputLine(const std::string & line)126 bool IsExtraMetricProfileOutputLine(const std::string& line) {
127 return RE2::FullMatch(line, "Extra metric \\S+: \\d+");
128 }
129
130 // Returns void so that we can ASSERT.
ExecuteAndFetchProfile(std::string * profile_output,LocalClient * client,const XlaComputation & computation,const Shape & lhs_arg_shape,const Shape & rhs_arg_shape)131 void ExecuteAndFetchProfile(std::string* profile_output, LocalClient* client,
132 const XlaComputation& computation,
133 const Shape& lhs_arg_shape,
134 const Shape& rhs_arg_shape) {
135 LocalService* service = ClientLibrary::GetXlaService(client->platform());
136 Backend* backend = service->mutable_backend();
137 se::StreamExecutor* executor = backend->default_stream_executor();
138 se::DeviceMemoryAllocator* allocator = backend->memory_allocator();
139 auto* transfer_manager = backend->transfer_manager();
140 TF_ASSERT_OK_AND_ASSIGN(
141 StreamPool::Ptr stream_ptr,
142 backend->BorrowStream(backend->default_device_ordinal()));
143
144 TF_ASSERT_OK_AND_ASSIGN(
145 ScopedShapedBuffer lhs_arg,
146 transfer_manager->AllocateScopedShapedBuffer(
147 lhs_arg_shape, allocator, backend->default_device_ordinal()));
148 TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
149 stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
150
151 TF_ASSERT_OK_AND_ASSIGN(
152 ScopedShapedBuffer rhs_arg,
153 transfer_manager->AllocateScopedShapedBuffer(
154 rhs_arg_shape, allocator, backend->default_device_ordinal()));
155 TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
156 stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
157
158 ExecutableBuildOptions build_options;
159 build_options.mutable_debug_options()->set_xla_hlo_profile(true);
160 TF_ASSERT_OK_AND_ASSIGN(
161 auto local_executables,
162 client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape},
163 build_options));
164
165 Executable* executable = local_executables[0]->executable();
166 HloExecutionProfile hlo_execution_profile(
167 &executable->hlo_profile_printer_data(),
168 &executable->hlo_profile_index_map());
169
170 ExecutableRunOptions exec_run_options;
171 exec_run_options.set_stream(stream_ptr.get());
172 exec_run_options.set_allocator(backend->memory_allocator());
173 exec_run_options.set_intra_op_thread_pool(
174 backend->eigen_intra_op_thread_pool_device());
175 ServiceExecutableRunOptions run_options(exec_run_options,
176 /*borrow_stream=*/nullptr);
177 std::vector<const ShapedBuffer*> args = {&lhs_arg, &rhs_arg};
178 TF_ASSERT_OK_AND_ASSIGN(
179 auto execution_result,
180 executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile));
181 TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
182 (void)execution_result;
183
184 *profile_output = hlo_execution_profile.ToString(
185 executor->GetDeviceDescription().clock_rate_ghz());
186
187 XLA_VLOG_LINES(4, *profile_output);
188 }
189
// Profiles tanh(lhs + rhs) on F32 matrices and checks the "[total]", "add" and
// "tanh" lines of the resulting profile text.
XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileSingleComputation)) {
  const int64_t m = 256, k = 256, n = 256;
  // The same two shapes describe both the builder parameters and the buffers
  // fed to the executable, so they cannot drift apart.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});

  TF_ASSERT_OK_AND_ASSIGN(se::Platform* platform,
                          PlatformUtil::GetDefaultPlatform());
  TF_ASSERT_OK_AND_ASSIGN(LocalClient* client,
                          ClientLibrary::GetOrCreateLocalClient(platform));

  XlaBuilder builder(TestName());
  Tanh(Add(Parameter(&builder, 0, lhs_shape, "dot_lhs"),
           Parameter(&builder, 1, rhs_shape, "dot_rhs")));

  TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());

  std::string profile_output;
  ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape,
                         rhs_shape);
  VLOG(4) << "Profile Output:\n" << profile_output;
  std::vector<std::string> profile_output_lines =
      absl::StrSplit(profile_output, '\n');

  absl::flat_hash_map<std::string, ParsedProfileOutputLine>
      parsed_profile_lines;

  size_t line_no = 0;

  // Skip extra metrics, without running past the end of the output.
  while (line_no < profile_output_lines.size() &&
         IsExtraMetricProfileOutputLine(profile_output_lines[line_no])) {
    ++line_no;
  }

  ++line_no;  // Skip 'Execution profile for ....'

  // The next three lines are the total followed by the two HLO instructions.
  ASSERT_LT(line_no, profile_output_lines.size());
  TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
                                         /*expect_hlo=*/false,
                                         &parsed_profile_lines));

  ASSERT_LT(line_no, profile_output_lines.size());
  TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
                                         /*expect_hlo=*/true,
                                         &parsed_profile_lines));

  ASSERT_LT(line_no, profile_output_lines.size());
  TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
                                         /*expect_hlo=*/true,
                                         &parsed_profile_lines));

  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile,
                          MaybeFind(parsed_profile_lines, "[total]"));
  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine add_profile,
                          MaybeFind(parsed_profile_lines, "add"));
  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile,
                          MaybeFind(parsed_profile_lines, "tanh"));

  EXPECT_GT(total_profile.cycles, 0);
  EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ");

  EXPECT_TRUE(HasFlops(total_profile));
  EXPECT_TRUE(HasTrops(total_profile));

  // NOTE(review): cycles_percentage is captured together with the cumulative
  // "Σ" column (see the "100.% 100Σ" expectation above), so the EXPECT_NE
  // comparisons against "0.00%" / "100.00%" below look like they can never
  // fail -- confirm the intended format before relying on them.
  EXPECT_GT(total_profile.cycles, add_profile.cycles);
  EXPECT_NE(add_profile.cycles_percentage, "0.00%");
  EXPECT_NE(add_profile.cycles_percentage, "100.00%");

  EXPECT_TRUE(HasFlops(add_profile));
  EXPECT_FALSE(HasTrops(add_profile));

  EXPECT_GT(total_profile.cycles, tanh_profile.cycles);
  EXPECT_NE(tanh_profile.cycles_percentage, "0.00%");
  EXPECT_NE(tanh_profile.cycles_percentage, "100.00%");

  EXPECT_FALSE(HasFlops(tanh_profile));
  EXPECT_TRUE(HasTrops(tanh_profile));
}
268
// Profiles a computation containing a while loop whose body squares a matrix
// element-wise, then checks the "[total]" and "multiply" lines of the profile
// section emitted for the loop body.
XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
  const int64_t size = 256;
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size});
  // Loop state: (iteration counter, matrix).
  Shape while_result_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), matrix_shape});

  TF_ASSERT_OK_AND_ASSIGN(se::Platform* platform,
                          PlatformUtil::GetDefaultPlatform());
  TF_ASSERT_OK_AND_ASSIGN(LocalClient* client,
                          ClientLibrary::GetOrCreateLocalClient(platform));

  // Loop condition: keep going while 5 > iteration.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto state = Parameter(&builder, 0, while_result_shape, "state");
    auto iteration = GetTupleElement(state, 0);
    Gt(ConstantR0<int32_t>(&builder, 5), iteration);
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Loop body: (iteration + 1, matrix * matrix).
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto state = Parameter(&builder, 0, while_result_shape, "state");
    auto matrix = GetTupleElement(state, 1);
    auto next_iteration =
        Add(GetTupleElement(state, 0), ConstantR0<int32_t>(&builder, 1));
    Tuple(&builder, {next_iteration, Mul(matrix, matrix)});
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  XlaBuilder builder(TestName());
  auto initial_while_state =
      Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
                       Parameter(&builder, 0, matrix_shape, "initial_value")});
  auto while_result = While(condition, body, initial_while_state);
  Add(GetTupleElement(while_result, 1),
      Parameter(&builder, 1, matrix_shape, "other_value"));

  TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());

  std::string profile_output;
  ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape,
                         matrix_shape);
  SCOPED_TRACE(profile_output);

  std::vector<std::string> profile_output_lines =
      absl::StrSplit(profile_output, '\n');

  auto while_body_profile_start =
      absl::c_find_if(profile_output_lines, [](absl::string_view s) {
        return absl::StartsWith(s, "Execution profile for body");
      });

  ASSERT_NE(while_body_profile_start, profile_output_lines.cend());

  auto while_body_profile_end =
      std::find_if(while_body_profile_start, profile_output_lines.end(),
                   [](absl::string_view s) {
                     return absl::StartsWith(s, "********** microseconds ");
                   });

  // Verify that the "microseconds report" line was found *before* stepping
  // back from it; once decremented, the iterator can no longer equal end().
  ASSERT_NE(while_body_profile_end, profile_output_lines.end());

  // We emit a blank line before the "microseconds report" line.
  --while_body_profile_end;

  // The loop below walks (start, end) with a != test, so the range must hold at
  // least the header line for it to terminate.
  ASSERT_TRUE(while_body_profile_start < while_body_profile_end)
      << "while body profile section is empty";

  absl::flat_hash_map<std::string, ParsedProfileOutputLine>
      parsed_profile_lines;

  // The first line after the header is the "[total]" summary; every following
  // line describes one HLO instruction.
  const auto total_line = while_body_profile_start + 1;
  for (auto line_it = total_line; line_it != while_body_profile_end;
       ++line_it) {
    // There are multiple "get-tuple-element" instructions in the while body so
    // we ignore them -- we don't want parsed_profile_lines to be a multi-map.
    TF_ASSERT_OK(ParseOneProfileOutputLine(
        *line_it,
        /*expect_hlo=*/line_it != total_line, &parsed_profile_lines,
        {"get-tuple-element"}));
  }

  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
                          MaybeFind(parsed_profile_lines, "[total]"));
  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile,
                          MaybeFind(parsed_profile_lines, "multiply"));

  EXPECT_GT(total_while_body_profile.cycles, 0);
  EXPECT_EQ(total_while_body_profile.opcode, "[total]");
  EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ");

  // NOTE(review): cycles_percentage includes the cumulative "Σ" column (see the
  // "100.% 100Σ" expectation above), so the two EXPECT_NE comparisons below
  // look like they can never fail -- confirm the intended format.
  EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles);
  EXPECT_NE(multiply_profile.cycles_percentage, "0.00%");
  EXPECT_NE(multiply_profile.cycles_percentage, "100.00%");
}
362 } // namespace
363 } // namespace xla
364
// Returns a copy of the command line with the flags these tests depend on
// appended: "--xla_hlo_profile" and an "--xla_disable_hlo_passes=..." list.
//
// The returned pair is {new argc, new argv}. The first `argc` entries alias
// the caller's strings; the appended entries are freshly duplicated. As with a
// real argv, new_argv[new_argc] is a null pointer, which code that shifts argv
// entries around (such as flag parsers) may rely on.
//
// The array and the duplicated strings are never freed; they have to outlive
// flag parsing and the whole test run.
static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) {
  constexpr int kNumAddedFlags = 2;
  const int new_argc = argc + kNumAddedFlags;

  // Intentional "leak". The extra slot holds the terminating null pointer.
  char** new_argv = new char*[new_argc + 1];
  for (int i = 0; i < argc; i++) {
    new_argv[i] = argv[i];
  }

  // We do it this way (as opposed to piping in a modified DebugOptions
  // instance) for better end-to-end integration testing.
  new_argv[argc] = strdup("--xla_hlo_profile");

  // Fusion can change the Hlo instructions that show up in the final Hlo
  // executable, so block it here. Also block the WhileLoopInvariantCodeMotion
  // pass, otherwise a while loop is transformed and we could not match the
  // original name in the ProfileWhileComputation test.
  new_argv[argc + 1] = strdup(
      "--xla_disable_hlo_passes=fusion,fusion_merger,multi_output_fusion,"
      "while-loop-invariant-code-motion");

  new_argv[new_argc] = nullptr;
  return {new_argc, new_argv};
}
385
main(int argc,char ** argv)386 GTEST_API_ int main(int argc, char** argv) {
387 std::vector<tensorflow::Flag> flag_list;
388 xla::AppendDebugOptionsFlags(&flag_list);
389 std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv);
390
391 auto usage = tensorflow::Flags::Usage(argv[0], flag_list);
392 if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
393 LOG(ERROR) << "\n" << usage;
394 return 2;
395 }
396
397 testing::InitGoogleTest(&argc, argv);
398 if (argc > 1) {
399 LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
400 return 2;
401 }
402 return RUN_ALL_TESTS();
403 }
404