• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "absl/algorithm/container.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/platform_util.h"
30 #include "tensorflow/compiler/xla/service/stream_pool.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/regexp.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace xla {
41 namespace {
42 
43 class HloProfileTest : public ClientLibraryTestBase {};
44 
45 struct ParsedProfileOutputLine {
46   int64 cycles;
47   string cycles_percentage;
48   double usec;
49   string flops;
50   string trops;
51   string bytes_per_sec;
52   string bytes_per_cycle;
53   string opcode;
54 };
55 
HasFlops(const ParsedProfileOutputLine & parsed_line)56 ::testing::AssertionResult HasFlops(
57     const ParsedProfileOutputLine& parsed_line) {
58   if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) {
59     return ::testing::AssertionSuccess()
60            << "'flops' field present in  " << parsed_line.opcode << ": '"
61            << parsed_line.flops << "'";
62   }
63 
64   return ::testing::AssertionFailure()
65          << "'flops' field absent in  " << parsed_line.opcode << ": '"
66          << parsed_line.flops << "'";
67 }
68 
HasTrops(const ParsedProfileOutputLine & parsed_line)69 ::testing::AssertionResult HasTrops(
70     const ParsedProfileOutputLine& parsed_line) {
71   if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) {
72     return ::testing::AssertionSuccess()
73            << "'trops' field present in  " << parsed_line.opcode << ": '"
74            << parsed_line.trops << "'";
75   }
76 
77   return ::testing::AssertionFailure()
78          << "'trops' field absent in  " << parsed_line.opcode << ": '"
79          << parsed_line.trops << "'";
80 }
81 
ParseOneProfileOutputLine(const string & line,bool expect_hlo,absl::flat_hash_map<string,ParsedProfileOutputLine> * parsed_results,absl::Span<const absl::string_view> opcodes_to_ignore={})82 Status ParseOneProfileOutputLine(
83     const string& line, bool expect_hlo,
84     absl::flat_hash_map<string, ParsedProfileOutputLine>* parsed_results,
85     absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
86   string separator = "[^:]*:: +";
87   string match_percentage = R"(\d+\.\d*% +\d+Σ)";
88   string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))";
89   string match_usecs = "([0-9.]+) usec";
90   string match_flops = "([^ ]*)";
91   string match_trops = "([^ ]*)";
92   string match_bytes_per_sec = "([0-9.TGMKi]*)(?:B/s)?";
93   string match_bytes_per_cycle = "([0-9.TGMKi]*)(?:B/cycle)?";
94 
95   // The underlined part is what we're trying to match with match_opcode:
96   //
97   //   %dot33 = f32[256,256]{1,0} dot(...)
98   //                              ^^^
99 
100   string match_opcode = expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*"
101                                    : "(\\[total\\])( \\[entry\\])?";
102   string regexp_pattern = absl::StrCat(
103       " +", match_cycles, separator, match_usecs, separator, match_flops,
104       separator, match_trops, separator, match_bytes_per_sec, separator,
105       match_bytes_per_cycle, separator, match_opcode);
106 
107   ParsedProfileOutputLine parsed_line;
108   bool matched = RE2::FullMatch(
109       line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
110       &parsed_line.usec, &parsed_line.flops, &parsed_line.trops,
111       &parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle,
112       &parsed_line.opcode);
113   if (!matched) {
114     return tensorflow::errors::InvalidArgument(
115         "Input did not match regexp.  Input: ", line,
116         ", Regexp: ", regexp_pattern);
117   }
118 
119   if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
120     InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
121   }
122 
123   return Status::OK();
124 }
125 
IsExtraMetricProfileOutputLine(const string & line)126 bool IsExtraMetricProfileOutputLine(const string& line) {
127   return RE2::FullMatch(line, "Extra metric \\S+: \\d+");
128 }
129 
130 // Returns void so that we can ASSERT.
ExecuteAndFetchProfile(string * profile_output,LocalClient * client,const XlaComputation & computation,const Shape & lhs_arg_shape,const Shape & rhs_arg_shape)131 void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
132                             const XlaComputation& computation,
133                             const Shape& lhs_arg_shape,
134                             const Shape& rhs_arg_shape) {
135   LocalService* service = ClientLibrary::GetXlaService(client->platform());
136   Backend* backend = service->mutable_backend();
137   se::StreamExecutor* executor = backend->default_stream_executor();
138   DeviceMemoryAllocator* allocator = backend->memory_allocator();
139   auto* transfer_manager = backend->transfer_manager();
140   TF_ASSERT_OK_AND_ASSIGN(
141       StreamPool::Ptr stream_ptr,
142       backend->BorrowStream(backend->default_device_ordinal()));
143 
144   TF_ASSERT_OK_AND_ASSIGN(
145       ScopedShapedBuffer lhs_arg,
146       transfer_manager->AllocateScopedShapedBuffer(
147           lhs_arg_shape, allocator, backend->default_device_ordinal()));
148   TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
149       stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
150 
151   TF_ASSERT_OK_AND_ASSIGN(
152       ScopedShapedBuffer rhs_arg,
153       transfer_manager->AllocateScopedShapedBuffer(
154           rhs_arg_shape, allocator, backend->default_device_ordinal()));
155   TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
156       stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
157 
158   ExecutableBuildOptions build_options;
159   build_options.mutable_debug_options()->set_xla_hlo_profile(true);
160   TF_ASSERT_OK_AND_ASSIGN(
161       std::unique_ptr<LocalExecutable> local_executable,
162       client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape},
163                       build_options));
164 
165   Executable* executable = local_executable->executable();
166   HloExecutionProfile hlo_execution_profile(
167       &executable->hlo_profile_printer_data(),
168       &executable->hlo_profile_index_map());
169 
170   ExecutableRunOptions exec_run_options;
171   exec_run_options.set_stream(stream_ptr.get());
172   exec_run_options.set_allocator(backend->memory_allocator());
173   exec_run_options.set_intra_op_thread_pool(
174       backend->eigen_intra_op_thread_pool_device());
175   ServiceExecutableRunOptions run_options(exec_run_options,
176                                           /*borrow_stream=*/nullptr);
177   std::vector<const ShapedBuffer*> args = {&lhs_arg, &rhs_arg};
178   TF_ASSERT_OK_AND_ASSIGN(
179       auto execution_result,
180       executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile));
181   TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
182   (void)execution_result;
183 
184   *profile_output =
185       hlo_execution_profile.ToString(executor->GetDeviceDescription());
186 
187   XLA_VLOG_LINES(4, *profile_output);
188 }
189 
XLA_TEST_F(HloProfileTest,ProfileSingleComputation)190 XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
191   const int64 m = 256, k = 256, n = 256;
192   Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
193   Shape rhs_shape = ShapeUtil::MakeShape(F32, {m, k});
194 
195   TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
196                           PlatformUtil::GetDefaultPlatform());
197   TF_ASSERT_OK_AND_ASSIGN(LocalClient * client,
198                           ClientLibrary::GetOrCreateLocalClient(platform));
199 
200   XlaBuilder builder(TestName());
201   Tanh(Add(
202       Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"),
203       Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs")));
204 
205   TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
206 
207   string profile_output;
208   ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape,
209                          rhs_shape);
210   VLOG(4) << "Profile Output:\n" << profile_output;
211   std::vector<string> profile_output_lines =
212       absl::StrSplit(profile_output, '\n');
213 
214   absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;
215 
216   int line_no = 0;
217 
218   // Skip extra metrics.
219   while (IsExtraMetricProfileOutputLine(profile_output_lines[line_no])) {
220     line_no++;
221   }
222 
223   line_no++;  // Skip 'Execution profile for ....'
224 
225   ASSERT_LT(line_no, profile_output_lines.size());
226   TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
227                                          /*expect_hlo=*/false,
228                                          &parsed_profile_lines));
229 
230   ASSERT_LT(line_no, profile_output_lines.size());
231   TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
232                                          /*expect_hlo=*/true,
233                                          &parsed_profile_lines));
234 
235   ASSERT_LT(line_no, profile_output_lines.size());
236   TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
237                                          /*expect_hlo=*/true,
238                                          &parsed_profile_lines));
239 
240   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile,
241                           MaybeFind(parsed_profile_lines, "[total]"));
242   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
243                           MaybeFind(parsed_profile_lines, "add"));
244   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile,
245                           MaybeFind(parsed_profile_lines, "tanh"));
246 
247   EXPECT_GT(total_profile.cycles, 0);
248   EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ");
249 
250   EXPECT_TRUE(HasFlops(total_profile));
251   EXPECT_TRUE(HasTrops(total_profile));
252 
253   EXPECT_GT(total_profile.cycles, dot_profile.cycles);
254   EXPECT_NE(dot_profile.cycles_percentage, "0.00%");
255   EXPECT_NE(dot_profile.cycles_percentage, "100.00%");
256 
257   EXPECT_TRUE(HasFlops(dot_profile));
258   EXPECT_FALSE(HasTrops(dot_profile));
259 
260   EXPECT_GT(total_profile.cycles, tanh_profile.cycles);
261   EXPECT_NE(tanh_profile.cycles_percentage, "0.00%");
262   EXPECT_NE(tanh_profile.cycles_percentage, "100.00%");
263 
264   EXPECT_FALSE(HasFlops(tanh_profile));
265   EXPECT_TRUE(HasTrops(tanh_profile));
266 }
267 
XLA_TEST_F(HloProfileTest,ProfileWhileComputation)268 XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
269   const int64 size = 256;
270   Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size});
271   Shape while_result_shape =
272       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), matrix_shape});
273 
274   TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
275                           PlatformUtil::GetDefaultPlatform());
276   TF_ASSERT_OK_AND_ASSIGN(LocalClient * client,
277                           ClientLibrary::GetOrCreateLocalClient(platform));
278 
279   XlaComputation condition;
280   {
281     XlaBuilder builder("condition");
282     auto state = Parameter(&builder, 0, while_result_shape, "state");
283     auto iteration = GetTupleElement(state, 0);
284     Gt(ConstantR0<int32>(&builder, 5), iteration);
285     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
286   }
287 
288   XlaComputation body;
289   {
290     XlaBuilder builder("body");
291     auto state = Parameter(&builder, 0, while_result_shape, "state");
292     auto matrix = GetTupleElement(state, 1);
293     auto next_iteration =
294         Add(GetTupleElement(state, 0), ConstantR0<int32>(&builder, 1));
295     Tuple(&builder, {next_iteration, Mul(matrix, matrix)});
296     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
297   }
298 
299   XlaBuilder builder(TestName());
300   auto initial_while_state =
301       Tuple(&builder, {ConstantR0<int32>(&builder, 0),
302                        Parameter(&builder, 0, matrix_shape, "initial_value")});
303   auto while_result = While(condition, body, initial_while_state);
304   Add(GetTupleElement(while_result, 1),
305       Parameter(&builder, 1, matrix_shape, "other_value"));
306 
307   TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
308 
309   string profile_output;
310   ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape,
311                          matrix_shape);
312   SCOPED_TRACE(profile_output);
313 
314   std::vector<string> profile_output_lines =
315       absl::StrSplit(profile_output, '\n');
316 
317   auto while_body_profile_start =
318       absl::c_find_if(profile_output_lines, [](absl::string_view s) {
319         return absl::StartsWith(s, "Execution profile for body");
320       });
321 
322   ASSERT_NE(while_body_profile_start, profile_output_lines.cend());
323 
324   auto while_body_profile_end =
325       std::find_if(while_body_profile_start, profile_output_lines.end(),
326                    [](absl::string_view s) {
327                      return absl::StartsWith(s, "********** microseconds ");
328                    });
329 
330   // We emit a blank line before the "microseconds report" line.
331   while_body_profile_end--;
332 
333   ASSERT_NE(while_body_profile_end, profile_output_lines.end());
334 
335   absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;
336 
337   for (auto while_body_profile_i = while_body_profile_start + 1;
338        while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
339     // There are multiple "get-tuple-element" instructions in the while body so
340     // we ignore them -- we don't want parsed_profile_lines to be a multi-map.
341     TF_ASSERT_OK(ParseOneProfileOutputLine(
342         *while_body_profile_i,
343         /*expect_hlo=*/while_body_profile_i != (while_body_profile_start + 1),
344         &parsed_profile_lines, {"get-tuple-element"}));
345   }
346 
347   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
348                           MaybeFind(parsed_profile_lines, "[total]"));
349   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile,
350                           MaybeFind(parsed_profile_lines, "multiply"));
351 
352   EXPECT_GT(total_while_body_profile.cycles, 0);
353   EXPECT_EQ(total_while_body_profile.opcode, "[total]");
354   EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ");
355 
356   EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles);
357   EXPECT_NE(multiply_profile.cycles_percentage, "0.00%");
358   EXPECT_NE(multiply_profile.cycles_percentage, "100.00%");
359 }
360 }  // namespace
361 }  // namespace xla
362 
// Returns a copy of (argc, argv) with two XLA debug flags appended.
//
// The returned array keeps main()'s convention that argv[argc] is a null
// pointer, since code that consumes argv may read or write that slot.
static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) {
  const int new_argc = argc + 2;

  // Intentional "leak": the array and the strdup'd strings must stay alive
  // for the rest of the process (flag parsing and the whole test run).
  // One extra slot holds the terminating null pointer.
  char** new_argv = new char*[new_argc + 1];
  for (int i = 0; i < argc; i++) {
    new_argv[i] = argv[i];
  }

  // We do it this way (as opposed to piping in a modified DebugOptions
  // instance) for better end-to-end integration testing.
  new_argv[argc] = strdup("--xla_hlo_profile");

  // Fusion can change the Hlo instructions that show up in the final Hlo
  // executable, so block it here. Also block the WhileLoopInvariantCodeMotion
  // pass, otherwise a while loop is transformed and we could not match the
  // original name in the ProfileWhileComputation test.
  new_argv[argc + 1] = strdup(
      "--xla_disable_hlo_passes=fusion,while-loop-invariant-code-motion");

  new_argv[new_argc] = nullptr;
  return {new_argc, new_argv};
}
382 
main(int argc,char ** argv)383 GTEST_API_ int main(int argc, char** argv) {
384   std::vector<tensorflow::Flag> flag_list;
385   xla::AppendDebugOptionsFlags(&flag_list);
386   std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv);
387 
388   auto usage = tensorflow::Flags::Usage(argv[0], flag_list);
389   if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
390     LOG(ERROR) << "\n" << usage;
391     return 2;
392   }
393 
394   testing::InitGoogleTest(&argc, argv);
395   if (argc > 1) {
396     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
397     return 2;
398   }
399   return RUN_ALL_TESTS();
400 }
401