1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h"
17
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "tensorflow/compiler/xla/service/hlo.pb.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/statusor.h"
30 #include "tensorflow/core/profiler/convert/tool_options.h"
31 #include "tensorflow/core/profiler/utils/hlo_proto_to_module.h"
32
33 namespace tensorflow {
34 namespace profiler {
35
36 namespace {
37
38 using ::tensorflow::StatusOr;
39 using ::tensorflow::errors::InvalidArgument;
40 using ::xla::HloComputation;
41 using ::xla::HloInstruction;
42 using ::xla::HloModule;
43 using ::xla::HloPrintOptions;
44 using ::xla::HloProto;
45 using ::xla::HloRenderOptions;
46 using ::xla::RenderedGraphFormat;
47
FindInstruction(const HloModule & module,std::string node_name)48 const HloInstruction* FindInstruction(const HloModule& module,
49 std::string node_name) {
50 if (absl::StartsWith(node_name, "%")) {
51 node_name.erase(node_name.begin());
52 }
53 for (const HloComputation* computation : module.computations()) {
54 auto instrs = computation->instructions();
55 auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
56 // Try with and without "%" at the beginning of the node name.
57 return absl::EqualsIgnoreCase(instr->name(), node_name) ||
58 absl::EqualsIgnoreCase(instr->name(),
59 absl::StrCat("%", node_name));
60 });
61 if (it != instrs.end()) {
62 return *it;
63 }
64 }
65 return nullptr;
66 }
67
FindComputation(const HloModule & module,const std::string & comp_name)68 const HloComputation* FindComputation(const HloModule& module,
69 const std::string& comp_name) {
70 for (const HloComputation* computation : module.computations()) {
71 if (absl::EqualsIgnoreCase(computation->name(), comp_name)) {
72 return computation;
73 }
74 }
75 return nullptr;
76 }
77
CleanUpHloModuleForGraphviz(HloModule * hlo_module)78 void CleanUpHloModuleForGraphviz(HloModule* hlo_module) {
79 // Infeed config is escaped serialized proto, and graphviz server complains.
80 for (HloComputation* computation : hlo_module->computations()) {
81 for (HloInstruction* inst : computation->instructions()) {
82 if (inst->opcode() == xla::HloOpcode::kInfeed) {
83 inst->set_infeed_config("");
84 } else if (inst->opcode() == xla::HloOpcode::kOutfeed) {
85 inst->set_outfeed_config("");
86 }
87 }
88 }
89 }
90
Plot(std::unique_ptr<HloModule> module,const std::string & node_name,int graph_width,const HloRenderOptions & render_options,const RenderedGraphFormat & format)91 StatusOr<std::string> Plot(std::unique_ptr<HloModule> module,
92 const std::string& node_name, int graph_width,
93 const HloRenderOptions& render_options,
94 const RenderedGraphFormat& format) {
95 if (node_name.empty()) {
96 // This should not happen.
97 return InvalidArgument("node_name should not be empty");
98 }
99 // Find the node with the given name.
100 const HloInstruction* instr = FindInstruction(*module, node_name);
101 const HloComputation* comp = FindComputation(*module, node_name);
102 if (!instr && !comp) {
103 return InvalidArgument(
104 absl::StrCat("Couldn't find HloInstruction or HloComputation named ",
105 node_name, "."));
106 }
107 // Generate the graph and print the resulting string.
108 StatusOr<std::string> graph_handle;
109
110 CleanUpHloModuleForGraphviz(module.get());
111 if (comp) {
112 graph_handle =
113 xla::RenderGraph(*comp, "", comp->parent()->config().debug_options(),
114 format, nullptr, render_options);
115 } else {
116 graph_handle = xla::RenderNeighborhoodAround(*instr, graph_width, format,
117 render_options);
118 }
119 if (graph_handle.ok()) {
120 LOG(INFO) << graph_handle.ValueOrDie();
121 } else {
122 LOG(INFO) << "Unable to render graph: " << graph_handle.status();
123 }
124
125 return graph_handle;
126 }
127
// Option values and defaults used by ParseGraphViewerParams. `static` is
// unnecessary because the enclosing anonymous namespace already gives these
// internal linkage.

// Values accepted for the "type" option.
constexpr char kGraphTypeName[] = "graph";
constexpr char kShortTxtTypeName[] = "short_txt";
constexpr char kLongTxtTypeName[] = "long_txt";

// Default for the "format" option of the graph view.
constexpr char kDefaultFormatString[] = "url";
// Default for the "graph_width" option of the graph view.
constexpr int kDefaultWidth = 3;
// Default for the "show_metadata" option (0 = off).
constexpr int kDefaultShowMetadata = 0;
// Default for the "merge_fusion" option. 0 keeps fusion subcomputations shown.
constexpr int kDefaultMergeFusion = 0;
136
137 } // namespace
138
ParseGraphViewerParams(const ToolOptions & options)139 StatusOr<GraphViewerParams> ParseGraphViewerParams(const ToolOptions& options) {
140 GraphViewerParams params;
141 std::optional<std::string> type = GetParam<std::string>(options, "type");
142 if (!type.has_value()) {
143 return errors::InvalidArgument("Graph viewer must provide a type option.");
144 }
145
146 // For graph type.
147 if (type == kGraphTypeName) {
148 params.type = type.value();
149 if (std::optional<std::string> node_name =
150 GetParam<std::string>(options, "node_name")) {
151 params.node_name = node_name.value();
152 }
153
154 params.graph_width =
155 GetParamWithDefault<int>(options, "graph_width", kDefaultWidth);
156 params.render_options.show_backend_config = GetParamWithDefault<int>(
157 options, "show_metadata", kDefaultShowMetadata);
158 params.render_options.show_fusion_subcomputations =
159 !GetParamWithDefault<int>(options, "merge_fusion", kDefaultMergeFusion);
160 params.format = GetRenderFormat(GetParamWithDefault<std::string>(
161 options, "format", kDefaultFormatString));
162
163 return params;
164 }
165
166 // For txt type.
167 if (type == kShortTxtTypeName || type == kLongTxtTypeName) {
168 params.type = type.value();
169 params.verbose = (type == kLongTxtTypeName);
170 params.show_metadata =
171 GetParamWithDefault(options, "show_metadata", kDefaultShowMetadata);
172 return params;
173 }
174
175 // Unknown type.
176 return errors::InvalidArgument("Unknown graph viewer type option: ",
177 type.value());
178 }
179
GetRenderFormat(const std::string & format_string)180 xla::RenderedGraphFormat GetRenderFormat(const std::string& format_string) {
181 if (format_string == "html") {
182 return xla::RenderedGraphFormat::kHtml;
183 } else if (format_string == "dot") {
184 return xla::RenderedGraphFormat::kDot;
185 } else if (format_string == "url") {
186 return xla::RenderedGraphFormat::kUrl;
187 } else {
188 LOG(ERROR) << "Invalid graph format argument: " << format_string
189 << ", fallback to default url";
190 return xla::RenderedGraphFormat::kUrl;
191 }
192 }
193
ConvertHloProtoToGraph(const HloProto & hlo_proto,const std::string & node_name,int graph_width,const HloRenderOptions & render_options,const RenderedGraphFormat & format)194 StatusOr<std::string> ConvertHloProtoToGraph(
195 const HloProto& hlo_proto, const std::string& node_name, int graph_width,
196 const HloRenderOptions& render_options, const RenderedGraphFormat& format) {
197 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
198 ConvertHloProtoToModule(hlo_proto));
199 return Plot(std::move(hlo_module), node_name, graph_width, render_options,
200 format);
201 }
202
ConvertHloProtoToStringView(const HloProto & hlo_proto,bool verbose,bool metadata)203 StatusOr<std::string> ConvertHloProtoToStringView(const HloProto& hlo_proto,
204 bool verbose, bool metadata) {
205 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
206 ConvertHloProtoToModule(hlo_proto));
207 HloPrintOptions options;
208 if (!verbose) {
209 options = HloPrintOptions::ShortParsable();
210 }
211 options.set_print_large_constants(verbose);
212 options.set_print_metadata(metadata);
213 return hlo_module->ToString(options);
214 }
215 } // namespace profiler
216 } // namespace tensorflow
217