• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h"
17 
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/compiler/xla/service/hlo.pb.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/statusor.h"
30 #include "tensorflow/core/profiler/convert/tool_options.h"
31 #include "tensorflow/core/profiler/utils/hlo_proto_to_module.h"
32 
33 namespace tensorflow {
34 namespace profiler {
35 
36 namespace {
37 
38 using ::tensorflow::StatusOr;
39 using ::tensorflow::errors::InvalidArgument;
40 using ::xla::HloComputation;
41 using ::xla::HloInstruction;
42 using ::xla::HloModule;
43 using ::xla::HloPrintOptions;
44 using ::xla::HloProto;
45 using ::xla::HloRenderOptions;
46 using ::xla::RenderedGraphFormat;
47 
FindInstruction(const HloModule & module,std::string node_name)48 const HloInstruction* FindInstruction(const HloModule& module,
49                                       std::string node_name) {
50   if (absl::StartsWith(node_name, "%")) {
51     node_name.erase(node_name.begin());
52   }
53   for (const HloComputation* computation : module.computations()) {
54     auto instrs = computation->instructions();
55     auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
56       // Try with and without "%" at the beginning of the node name.
57       return absl::EqualsIgnoreCase(instr->name(), node_name) ||
58              absl::EqualsIgnoreCase(instr->name(),
59                                     absl::StrCat("%", node_name));
60     });
61     if (it != instrs.end()) {
62       return *it;
63     }
64   }
65   return nullptr;
66 }
67 
FindComputation(const HloModule & module,const std::string & comp_name)68 const HloComputation* FindComputation(const HloModule& module,
69                                       const std::string& comp_name) {
70   for (const HloComputation* computation : module.computations()) {
71     if (absl::EqualsIgnoreCase(computation->name(), comp_name)) {
72       return computation;
73     }
74   }
75   return nullptr;
76 }
77 
CleanUpHloModuleForGraphviz(HloModule * hlo_module)78 void CleanUpHloModuleForGraphviz(HloModule* hlo_module) {
79   // Infeed config is escaped serialized proto, and graphviz server complains.
80   for (HloComputation* computation : hlo_module->computations()) {
81     for (HloInstruction* inst : computation->instructions()) {
82       if (inst->opcode() == xla::HloOpcode::kInfeed) {
83         inst->set_infeed_config("");
84       } else if (inst->opcode() == xla::HloOpcode::kOutfeed) {
85         inst->set_outfeed_config("");
86       }
87     }
88   }
89 }
90 
Plot(std::unique_ptr<HloModule> module,const std::string & node_name,int graph_width,const HloRenderOptions & render_options,const RenderedGraphFormat & format)91 StatusOr<std::string> Plot(std::unique_ptr<HloModule> module,
92                            const std::string& node_name, int graph_width,
93                            const HloRenderOptions& render_options,
94                            const RenderedGraphFormat& format) {
95   if (node_name.empty()) {
96     // This should not happen.
97     return InvalidArgument("node_name should not be empty");
98   }
99   // Find the node with the given name.
100   const HloInstruction* instr = FindInstruction(*module, node_name);
101   const HloComputation* comp = FindComputation(*module, node_name);
102   if (!instr && !comp) {
103     return InvalidArgument(
104         absl::StrCat("Couldn't find HloInstruction or HloComputation named ",
105                      node_name, "."));
106   }
107   // Generate the graph and print the resulting string.
108   StatusOr<std::string> graph_handle;
109 
110   CleanUpHloModuleForGraphviz(module.get());
111   if (comp) {
112     graph_handle =
113         xla::RenderGraph(*comp, "", comp->parent()->config().debug_options(),
114                          format, nullptr, render_options);
115   } else {
116     graph_handle = xla::RenderNeighborhoodAround(*instr, graph_width, format,
117                                                  render_options);
118   }
119   if (graph_handle.ok()) {
120     LOG(INFO) << graph_handle.ValueOrDie();
121   } else {
122     LOG(INFO) << "Unable to render graph: " << graph_handle.status();
123   }
124 
125   return graph_handle;
126 }
127 
// Recognized values of the graph viewer's "type" option, followed by the
// defaults used for its optional parameters. Namespace-scope constexpr
// variables already have internal linkage, so no `static` is needed.
constexpr char kGraphTypeName[] = "graph";
constexpr char kShortTxtTypeName[] = "short_txt";
constexpr char kLongTxtTypeName[] = "long_txt";
constexpr char kDefaultFormatString[] = "url";
constexpr int kDefaultWidth = 3;
constexpr int kDefaultShowMetadata = 0;
constexpr int kDefaultMergeFusion = 0;
136 
137 }  // namespace
138 
ParseGraphViewerParams(const ToolOptions & options)139 StatusOr<GraphViewerParams> ParseGraphViewerParams(const ToolOptions& options) {
140   GraphViewerParams params;
141   std::optional<std::string> type = GetParam<std::string>(options, "type");
142   if (!type.has_value()) {
143     return errors::InvalidArgument("Graph viewer must provide a type option.");
144   }
145 
146   // For graph type.
147   if (type == kGraphTypeName) {
148     params.type = type.value();
149     if (std::optional<std::string> node_name =
150             GetParam<std::string>(options, "node_name")) {
151       params.node_name = node_name.value();
152     }
153 
154     params.graph_width =
155         GetParamWithDefault<int>(options, "graph_width", kDefaultWidth);
156     params.render_options.show_backend_config = GetParamWithDefault<int>(
157         options, "show_metadata", kDefaultShowMetadata);
158     params.render_options.show_fusion_subcomputations =
159         !GetParamWithDefault<int>(options, "merge_fusion", kDefaultMergeFusion);
160     params.format = GetRenderFormat(GetParamWithDefault<std::string>(
161         options, "format", kDefaultFormatString));
162 
163     return params;
164   }
165 
166   // For txt type.
167   if (type == kShortTxtTypeName || type == kLongTxtTypeName) {
168     params.type = type.value();
169     params.verbose = (type == kLongTxtTypeName);
170     params.show_metadata =
171         GetParamWithDefault(options, "show_metadata", kDefaultShowMetadata);
172     return params;
173   }
174 
175   // Unknown type.
176   return errors::InvalidArgument("Unknown graph viewer type option: ",
177                                  type.value());
178 }
179 
GetRenderFormat(const std::string & format_string)180 xla::RenderedGraphFormat GetRenderFormat(const std::string& format_string) {
181   if (format_string == "html") {
182     return xla::RenderedGraphFormat::kHtml;
183   } else if (format_string == "dot") {
184     return xla::RenderedGraphFormat::kDot;
185   } else if (format_string == "url") {
186     return xla::RenderedGraphFormat::kUrl;
187   } else {
188     LOG(ERROR) << "Invalid graph format argument: " << format_string
189                << ", fallback to default url";
190     return xla::RenderedGraphFormat::kUrl;
191   }
192 }
193 
ConvertHloProtoToGraph(const HloProto & hlo_proto,const std::string & node_name,int graph_width,const HloRenderOptions & render_options,const RenderedGraphFormat & format)194 StatusOr<std::string> ConvertHloProtoToGraph(
195     const HloProto& hlo_proto, const std::string& node_name, int graph_width,
196     const HloRenderOptions& render_options, const RenderedGraphFormat& format) {
197   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
198                       ConvertHloProtoToModule(hlo_proto));
199   return Plot(std::move(hlo_module), node_name, graph_width, render_options,
200               format);
201 }
202 
ConvertHloProtoToStringView(const HloProto & hlo_proto,bool verbose,bool metadata)203 StatusOr<std::string> ConvertHloProtoToStringView(const HloProto& hlo_proto,
204                                                   bool verbose, bool metadata) {
205   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
206                       ConvertHloProtoToModule(hlo_proto));
207   HloPrintOptions options;
208   if (!verbose) {
209     options = HloPrintOptions::ShortParsable();
210   }
211   options.set_print_large_constants(verbose);
212   options.set_print_metadata(metadata);
213   return hlo_module->ToString(options);
214 }
215 }  // namespace profiler
216 }  // namespace tensorflow
217