• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "absl/algorithm/container.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/platform_util.h"
30 #include "tensorflow/compiler/xla/service/stream_pool.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/regexp.h"
37 #include "tensorflow/core/platform/test.h"
38 
39 namespace xla {
40 namespace {
41 
// Fixture for the HLO profile tests in this file. It adds no members of its
// own on top of ClientLibraryTestBase.
class HloProfileTest : public ClientLibraryTestBase {};
43 
// One row of the textual HLO execution profile, split into its columns by the
// regular expression built in ParseOneProfileOutputLine.
//
// The numeric fields carry default initializers so that a default-constructed
// value never holds indeterminate data.
struct ParsedProfileOutputLine {
  int64_t cycles = 0;             // Integer preceding " cycles".
  std::string cycles_percentage;  // Text inside the parentheses after it.
  double usec = 0.0;              // Number preceding " usec".
  std::string flops;              // Raw text of the flops column.
  std::string trops;              // Raw text of the trops column.
  std::string bytes_per_sec;      // Rate without its optional "B/s" suffix.
  std::string bytes_per_cycle;    // Rate without its optional "B/cycle" suffix.
  std::string opcode;             // HLO opcode, or "[total]" for the summary row.
};
54 
HasFlops(const ParsedProfileOutputLine & parsed_line)55 ::testing::AssertionResult HasFlops(
56     const ParsedProfileOutputLine& parsed_line) {
57   if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) {
58     return ::testing::AssertionSuccess()
59            << "'flops' field present in  " << parsed_line.opcode << ": '"
60            << parsed_line.flops << "'";
61   }
62 
63   return ::testing::AssertionFailure()
64          << "'flops' field absent in  " << parsed_line.opcode << ": '"
65          << parsed_line.flops << "'";
66 }
67 
HasTrops(const ParsedProfileOutputLine & parsed_line)68 ::testing::AssertionResult HasTrops(
69     const ParsedProfileOutputLine& parsed_line) {
70   if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) {
71     return ::testing::AssertionSuccess()
72            << "'trops' field present in  " << parsed_line.opcode << ": '"
73            << parsed_line.trops << "'";
74   }
75 
76   return ::testing::AssertionFailure()
77          << "'trops' field absent in  " << parsed_line.opcode << ": '"
78          << parsed_line.trops << "'";
79 }
80 
ParseOneProfileOutputLine(const std::string & line,bool expect_hlo,absl::flat_hash_map<std::string,ParsedProfileOutputLine> * parsed_results,absl::Span<const absl::string_view> opcodes_to_ignore={})81 Status ParseOneProfileOutputLine(
82     const std::string& line, bool expect_hlo,
83     absl::flat_hash_map<std::string, ParsedProfileOutputLine>* parsed_results,
84     absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
85   std::string separator = "[^:]*:: +";
86   std::string match_percentage = R"(\d+\.\d*% +\d+Σ)";
87   std::string match_cycles =
88       R"((\d+) cycles +\( *()" + match_percentage + R"()\))";
89   std::string match_usecs = "([0-9.]+) usec";
90   std::string match_flops = "([^ ]*)";
91   std::string match_trops = "([^ ]*)";
92   std::string match_bytes_per_sec = "([0-9.TGMKi]*)(?:B/s)?";
93   std::string match_bytes_per_cycle = "([0-9.TGMKi]*)(?:B/cycle)?";
94 
95   // The underlined part is what we're trying to match with match_opcode:
96   //
97   //   %dot33 = f32[256,256]{1,0} dot(...)
98   //                              ^^^
99 
100   std::string match_opcode = expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*"
101                                         : "(\\[total\\])( \\[entry\\])?";
102   std::string regexp_pattern = absl::StrCat(
103       " +", match_cycles, separator, match_usecs, separator, match_flops,
104       separator, match_trops, separator, match_bytes_per_sec, separator,
105       match_bytes_per_cycle, separator, match_opcode);
106 
107   ParsedProfileOutputLine parsed_line;
108   bool matched = RE2::FullMatch(
109       line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
110       &parsed_line.usec, &parsed_line.flops, &parsed_line.trops,
111       &parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle,
112       &parsed_line.opcode);
113   if (!matched) {
114     return tensorflow::errors::InvalidArgument(
115         "Input did not match regexp.  Input: ", line,
116         ", Regexp: ", regexp_pattern);
117   }
118 
119   if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
120     InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
121   }
122 
123   return OkStatus();
124 }
125 
IsExtraMetricProfileOutputLine(const std::string & line)126 bool IsExtraMetricProfileOutputLine(const std::string& line) {
127   return RE2::FullMatch(line, "Extra metric \\S+: \\d+");
128 }
129 
130 // Returns void so that we can ASSERT.
ExecuteAndFetchProfile(std::string * profile_output,LocalClient * client,const XlaComputation & computation,const Shape & lhs_arg_shape,const Shape & rhs_arg_shape)131 void ExecuteAndFetchProfile(std::string* profile_output, LocalClient* client,
132                             const XlaComputation& computation,
133                             const Shape& lhs_arg_shape,
134                             const Shape& rhs_arg_shape) {
135   LocalService* service = ClientLibrary::GetXlaService(client->platform());
136   Backend* backend = service->mutable_backend();
137   se::StreamExecutor* executor = backend->default_stream_executor();
138   se::DeviceMemoryAllocator* allocator = backend->memory_allocator();
139   auto* transfer_manager = backend->transfer_manager();
140   TF_ASSERT_OK_AND_ASSIGN(
141       StreamPool::Ptr stream_ptr,
142       backend->BorrowStream(backend->default_device_ordinal()));
143 
144   TF_ASSERT_OK_AND_ASSIGN(
145       ScopedShapedBuffer lhs_arg,
146       transfer_manager->AllocateScopedShapedBuffer(
147           lhs_arg_shape, allocator, backend->default_device_ordinal()));
148   TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
149       stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
150 
151   TF_ASSERT_OK_AND_ASSIGN(
152       ScopedShapedBuffer rhs_arg,
153       transfer_manager->AllocateScopedShapedBuffer(
154           rhs_arg_shape, allocator, backend->default_device_ordinal()));
155   TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
156       stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
157 
158   ExecutableBuildOptions build_options;
159   build_options.mutable_debug_options()->set_xla_hlo_profile(true);
160   TF_ASSERT_OK_AND_ASSIGN(
161       auto local_executables,
162       client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape},
163                       build_options));
164 
165   Executable* executable = local_executables[0]->executable();
166   HloExecutionProfile hlo_execution_profile(
167       &executable->hlo_profile_printer_data(),
168       &executable->hlo_profile_index_map());
169 
170   ExecutableRunOptions exec_run_options;
171   exec_run_options.set_stream(stream_ptr.get());
172   exec_run_options.set_allocator(backend->memory_allocator());
173   exec_run_options.set_intra_op_thread_pool(
174       backend->eigen_intra_op_thread_pool_device());
175   ServiceExecutableRunOptions run_options(exec_run_options,
176                                           /*borrow_stream=*/nullptr);
177   std::vector<const ShapedBuffer*> args = {&lhs_arg, &rhs_arg};
178   TF_ASSERT_OK_AND_ASSIGN(
179       auto execution_result,
180       executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile));
181   TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
182   (void)execution_result;
183 
184   *profile_output = hlo_execution_profile.ToString(
185       executor->GetDeviceDescription().clock_rate_ghz());
186 
187   XLA_VLOG_LINES(4, *profile_output);
188 }
189 
XLA_TEST_F(HloProfileTest,DISABLED_ON_GPU (ProfileSingleComputation))190 XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileSingleComputation)) {
191   const int64_t m = 256, k = 256, n = 256;
192   Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
193   Shape rhs_shape = ShapeUtil::MakeShape(F32, {m, k});
194 
195   TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
196                           PlatformUtil::GetDefaultPlatform());
197   TF_ASSERT_OK_AND_ASSIGN(LocalClient * client,
198                           ClientLibrary::GetOrCreateLocalClient(platform));
199 
200   XlaBuilder builder(TestName());
201   Tanh(Add(
202       Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"),
203       Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs")));
204 
205   TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
206 
207   std::string profile_output;
208   ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape,
209                          rhs_shape);
210   VLOG(4) << "Profile Output:\n" << profile_output;
211   std::vector<std::string> profile_output_lines =
212       absl::StrSplit(profile_output, '\n');
213 
214   absl::flat_hash_map<std::string, ParsedProfileOutputLine>
215       parsed_profile_lines;
216 
217   int line_no = 0;
218 
219   // Skip extra metrics.
220   while (IsExtraMetricProfileOutputLine(profile_output_lines[line_no])) {
221     line_no++;
222   }
223 
224   line_no++;  // Skip 'Execution profile for ....'
225 
226   ASSERT_LT(line_no, profile_output_lines.size());
227   TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
228                                          /*expect_hlo=*/false,
229                                          &parsed_profile_lines));
230 
231   ASSERT_LT(line_no, profile_output_lines.size());
232   TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
233                                          /*expect_hlo=*/true,
234                                          &parsed_profile_lines));
235 
236   ASSERT_LT(line_no, profile_output_lines.size());
237   TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
238                                          /*expect_hlo=*/true,
239                                          &parsed_profile_lines));
240 
241   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile,
242                           MaybeFind(parsed_profile_lines, "[total]"));
243   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
244                           MaybeFind(parsed_profile_lines, "add"));
245   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile,
246                           MaybeFind(parsed_profile_lines, "tanh"));
247 
248   EXPECT_GT(total_profile.cycles, 0);
249   EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ");
250 
251   EXPECT_TRUE(HasFlops(total_profile));
252   EXPECT_TRUE(HasTrops(total_profile));
253 
254   EXPECT_GT(total_profile.cycles, dot_profile.cycles);
255   EXPECT_NE(dot_profile.cycles_percentage, "0.00%");
256   EXPECT_NE(dot_profile.cycles_percentage, "100.00%");
257 
258   EXPECT_TRUE(HasFlops(dot_profile));
259   EXPECT_FALSE(HasTrops(dot_profile));
260 
261   EXPECT_GT(total_profile.cycles, tanh_profile.cycles);
262   EXPECT_NE(tanh_profile.cycles_percentage, "0.00%");
263   EXPECT_NE(tanh_profile.cycles_percentage, "100.00%");
264 
265   EXPECT_FALSE(HasFlops(tanh_profile));
266   EXPECT_TRUE(HasTrops(tanh_profile));
267 }
268 
XLA_TEST_F(HloProfileTest,DISABLED_ON_GPU (ProfileWhileComputation))269 XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
270   const int64_t size = 256;
271   Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size});
272   Shape while_result_shape =
273       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), matrix_shape});
274 
275   TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
276                           PlatformUtil::GetDefaultPlatform());
277   TF_ASSERT_OK_AND_ASSIGN(LocalClient * client,
278                           ClientLibrary::GetOrCreateLocalClient(platform));
279 
280   XlaComputation condition;
281   {
282     XlaBuilder builder("condition");
283     auto state = Parameter(&builder, 0, while_result_shape, "state");
284     auto iteration = GetTupleElement(state, 0);
285     Gt(ConstantR0<int32_t>(&builder, 5), iteration);
286     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
287   }
288 
289   XlaComputation body;
290   {
291     XlaBuilder builder("body");
292     auto state = Parameter(&builder, 0, while_result_shape, "state");
293     auto matrix = GetTupleElement(state, 1);
294     auto next_iteration =
295         Add(GetTupleElement(state, 0), ConstantR0<int32_t>(&builder, 1));
296     Tuple(&builder, {next_iteration, Mul(matrix, matrix)});
297     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
298   }
299 
300   XlaBuilder builder(TestName());
301   auto initial_while_state =
302       Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
303                        Parameter(&builder, 0, matrix_shape, "initial_value")});
304   auto while_result = While(condition, body, initial_while_state);
305   Add(GetTupleElement(while_result, 1),
306       Parameter(&builder, 1, matrix_shape, "other_value"));
307 
308   TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
309 
310   std::string profile_output;
311   ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape,
312                          matrix_shape);
313   SCOPED_TRACE(profile_output);
314 
315   std::vector<std::string> profile_output_lines =
316       absl::StrSplit(profile_output, '\n');
317 
318   auto while_body_profile_start =
319       absl::c_find_if(profile_output_lines, [](absl::string_view s) {
320         return absl::StartsWith(s, "Execution profile for body");
321       });
322 
323   ASSERT_NE(while_body_profile_start, profile_output_lines.cend());
324 
325   auto while_body_profile_end =
326       std::find_if(while_body_profile_start, profile_output_lines.end(),
327                    [](absl::string_view s) {
328                      return absl::StartsWith(s, "********** microseconds ");
329                    });
330 
331   // We emit a blank line before the "microseconds report" line.
332   while_body_profile_end--;
333 
334   ASSERT_NE(while_body_profile_end, profile_output_lines.end());
335 
336   absl::flat_hash_map<std::string, ParsedProfileOutputLine>
337       parsed_profile_lines;
338 
339   for (auto while_body_profile_i = while_body_profile_start + 1;
340        while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
341     // There are multiple "get-tuple-element" instructions in the while body so
342     // we ignore them -- we don't want parsed_profile_lines to be a multi-map.
343     TF_ASSERT_OK(ParseOneProfileOutputLine(
344         *while_body_profile_i,
345         /*expect_hlo=*/while_body_profile_i != (while_body_profile_start + 1),
346         &parsed_profile_lines, {"get-tuple-element"}));
347   }
348 
349   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
350                           MaybeFind(parsed_profile_lines, "[total]"));
351   TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile,
352                           MaybeFind(parsed_profile_lines, "multiply"));
353 
354   EXPECT_GT(total_while_body_profile.cycles, 0);
355   EXPECT_EQ(total_while_body_profile.opcode, "[total]");
356   EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ");
357 
358   EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles);
359   EXPECT_NE(multiply_profile.cycles_percentage, "0.00%");
360   EXPECT_NE(multiply_profile.cycles_percentage, "100.00%");
361 }
362 }  // namespace
363 }  // namespace xla
364 
// Returns a copy of the argument vector with two flags appended:
// "--xla_hlo_profile" and an "--xla_disable_hlo_passes=..." list.
//
//   argc, argv  the original command line; the first `argc` pointers are
//               copied as-is (the strings themselves are not duplicated).
//
// Returns {argc + 2, new_argv}. Like a regular argv, the returned array is
// terminated by a null pointer at index argc + 2. The array and the two
// appended strings are never freed.
static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) {
  // Intentional "leak". One slot beyond the two new flags holds the
  // terminating null pointer.
  char** new_argv = new char*[argc + 3];
  for (int i = 0; i < argc; i++) {
    new_argv[i] = argv[i];
  }

  // We do it this way (as opposed to piping in a modified DebugOptions
  // instance) for better end-to-end integration testing.
  new_argv[argc] = strdup("--xla_hlo_profile");

  // Fusion can change the Hlo instructions that show up in the final Hlo
  // executable, so block it here. Also block the WhileLoopInvariantCodeMotion
  // pass, otherwise a while loop is transformed and we could not match the
  // original name in the ProfileWhileComputation test.
  new_argv[argc + 1] = strdup(
      "--xla_disable_hlo_passes=fusion,fusion_merger,multi_output_fusion,"
      "while-loop-invariant-code-motion");

  // Keep the argv[argc] == nullptr convention for any consumer that walks the
  // array until the terminator instead of using the count.
  new_argv[argc + 2] = nullptr;
  return {argc + 2, new_argv};
}
385 
main(int argc,char ** argv)386 GTEST_API_ int main(int argc, char** argv) {
387   std::vector<tensorflow::Flag> flag_list;
388   xla::AppendDebugOptionsFlags(&flag_list);
389   std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv);
390 
391   auto usage = tensorflow::Flags::Usage(argv[0], flag_list);
392   if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
393     LOG(ERROR) << "\n" << usage;
394     return 2;
395   }
396 
397   testing::InitGoogleTest(&argc, argv);
398   if (argc > 1) {
399     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
400     return 2;
401   }
402   return RUN_ALL_TESTS();
403 }
404