1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/tf2xla/graph_compiler_util.h"

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/dump_graph.h"
37
38 namespace tensorflow {
39
40 namespace {
41
42 const char* const kFeedIdAttr = "_feed_id";
43 const char* const kFetchIdAttr = "_fetch_id";
44 const char* const kShapeAttr = "_shape";
45 const char* const kDebugNameAttr = "_debug_name";
46
47 typedef std::unordered_map<string, Node*> NodeMap;
48
49 // Each feed id identifies the positional output of some node, which may consist
50 // of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
51 // tensor with a placeholder. For each feed tensor, replaces all edges so they
52 // point from a new _Arg node instead. The newly created _Arg nodes are added to
53 // `arg_nodes`.
AddArgNodes(Graph * graph,const NodeMap & node_map,const protobuf::RepeatedPtrField<tf2xla::Feed> & feeds,const std::unordered_map<string,string> & feed_remapping,std::unordered_set<const Node * > * arg_nodes)54 Status AddArgNodes(Graph* graph, const NodeMap& node_map,
55 const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
56 const std::unordered_map<string, string>& feed_remapping,
57 std::unordered_set<const Node*>* arg_nodes) {
58 for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
59 const tf2xla::Feed& feed = feeds[arg_index];
60 // All feeds have been replaced by placeholders.
61 const int output_index = 0;
62
63 const string key = TensorIdToString(feed.id());
64 const auto remap_it = feed_remapping.find(key);
65 auto node_it = node_map.find(remap_it->second);
66 if (node_it == node_map.end()) {
67 // Strip off the aot_feed_#/ prefix.
68 absl::string_view name(remap_it->second);
69 const auto index = name.find('/');
70 if (index > 0) name.remove_prefix(index + 1);
71 return errors::InvalidArgument(
72 "Node is fed but not needed for fetching: ", name);
73 }
74 const Node* feed_node = node_it->second;
75
76 // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
77 // "_shape" attr if we can determine it. That way the graph will be
78 // initialized with whatever shapes we can infer, while the user can still
79 // explicitly specify or override them.
80 Node* arg_node = nullptr;
81 TF_RETURN_IF_ERROR(
82 NodeBuilder(
83 absl::StrCat("_arg_", arg_index),
84 FunctionLibraryDefinition::FunctionLibraryDefinition::kArgOp)
85 .Attr("T", BaseType(feed_node->output_type(output_index)))
86 .Attr("index", arg_index)
87 .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
88 .Attr(kShapeAttr, TensorShape(feed.shape()))
89 .Attr(kDebugNameAttr, feed.name())
90 .Finalize(graph, &arg_node));
91 arg_nodes->insert(arg_node);
92
93 // Collects out-edges from the feed node that have a matching edge index;
94 // these will be replaced with edges from the arg node instead.
95 //
96 // We must collect the edges first and process them in a second pass, since
97 // removing the edge from the graph invalidates feed_node->out_edges.
98 std::vector<const Edge*> feed_edges;
99 for (const Edge* edge : feed_node->out_edges()) {
100 if (edge->src_output() == output_index) {
101 feed_edges.push_back(edge);
102 }
103 }
104 for (const Edge* edge : feed_edges) {
105 graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
106 graph->RemoveEdge(edge);
107 }
108 }
109 return Status::OK();
110 }
111
112 // Each fetch id identifies the positional output of some node. For each fetch
113 // node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
AddRetvalNodes(Graph * graph,const NodeMap & node_map,const protobuf::RepeatedPtrField<tf2xla::Fetch> & fetches,std::unordered_set<const Node * > * retval_nodes)114 Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
115 const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches,
116 std::unordered_set<const Node*>* retval_nodes) {
117 for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
118 const tf2xla::TensorId& id = fetches[ret_index].id();
119 auto it = node_map.find(id.node_name());
120 if (it == node_map.end()) {
121 return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
122 }
123 Node* fetch_node = it->second;
124 if (id.output_index() >= fetch_node->num_outputs()) {
125 return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
126 ", output index should be < ",
127 fetch_node->num_outputs());
128 }
129 // Connects fetch_node -> retval_node.
130 Node* retval_node = nullptr;
131 TF_RETURN_IF_ERROR(
132 NodeBuilder(absl::StrCat("_retval_", ret_index),
133 FunctionLibraryDefinition::kRetOp)
134 .Input(fetch_node, id.output_index())
135 .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
136 .Attr("index", ret_index)
137 .Attr(kFetchIdAttr, TensorIdToString(id))
138 .Finalize(graph, &retval_node));
139 retval_nodes->insert(retval_node);
140 }
141 return Status::OK();
142 }
143
144 // RewriteAndPruneGraph identifies input and output edges (named by the feed and
145 // fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
146 // nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
147 // execution to know the input and output args for the generated function.
RewriteAndPruneGraph(Graph * graph,const tf2xla::Config & config,const std::unordered_map<string,string> & feed_remapping)148 Status RewriteAndPruneGraph(
149 Graph* graph, const tf2xla::Config& config,
150 const std::unordered_map<string, string>& feed_remapping) {
151 NodeMap node_map;
152 for (Node* n : graph->nodes()) {
153 node_map[n->name()] = n;
154 }
155 std::unordered_set<const Node*> nodes_to_keep;
156 TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed(), feed_remapping,
157 &nodes_to_keep));
158 TF_RETURN_IF_ERROR(
159 AddRetvalNodes(graph, node_map, config.fetch(), &nodes_to_keep));
160 VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
161 PruneForReverseReachability(graph, std::move(nodes_to_keep));
162 FixupSourceAndSinkEdges(graph);
163 VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
164 // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
165 std::set<string> missing_feeds, missing_fetches;
166 for (const tf2xla::Feed& feed : config.feed()) {
167 missing_feeds.insert(TensorIdToString(feed.id()));
168 }
169 for (const tf2xla::Fetch& fetch : config.fetch()) {
170 missing_fetches.insert(TensorIdToString(fetch.id()));
171 }
172 for (const Node* n : graph->op_nodes()) {
173 if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
174 string feed_id;
175 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
176 if (missing_feeds.erase(feed_id) == 0) {
177 return errors::Aborted(FunctionLibraryDefinition::kArgOp,
178 " node found with unknown feed id: ", feed_id);
179 }
180 } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) {
181 string fetch_id;
182 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
183 if (missing_fetches.erase(fetch_id) == 0) {
184 return errors::Aborted(FunctionLibraryDefinition::kRetOp,
185 " node found with unknown fetch id: ", fetch_id);
186 }
187 }
188 }
189 if (!missing_feeds.empty() || !missing_fetches.empty()) {
190 return errors::Aborted(
191 "Post graph-pruning",
192 ", missing feeds: ", absl::StrJoin(missing_feeds, ", "),
193 ", missing fetches: ", absl::StrJoin(missing_fetches, ", "));
194 }
195 return Status::OK();
196 }
197
198 // CollectArgNodes collects _Arg nodes from the graph, and performs basic
199 // sanity-checking to ensure the index and type attributes of each node are
200 // initialized correctly.
CollectArgNodes(const Graph & graph,std::vector<Node * > * arg_nodes)201 Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
202 std::map<int, Node*> indexed_arg_nodes;
203 for (Node* n : graph.nodes()) {
204 if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
205 int index;
206 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
207 auto insert_result = indexed_arg_nodes.insert({index, n});
208 if (!insert_result.second) {
209 const Node* dup = insert_result.first->second;
210 return errors::InvalidArgument(
211 "Multiple ", FunctionLibraryDefinition::kArgOp,
212 " nodes with index ", index, ", ", FormatNodeForError(*n), " and ",
213 FormatNodeForError(*dup));
214 }
215 }
216 }
217 arg_nodes->clear();
218 for (const auto& index_node : indexed_arg_nodes) {
219 const int arg_nodes_size = arg_nodes->size();
220 if (index_node.first != arg_nodes_size) {
221 return errors::InvalidArgument(
222 "Expected ", FunctionLibraryDefinition::kArgOp, " node with index ",
223 arg_nodes->size(), ", but got index ", index_node.first);
224 }
225 arg_nodes->push_back(index_node.second);
226 }
227 return Status::OK();
228 }
229
230 } // namespace
231
CreateXlaArgs(const Graph & graph,std::vector<XlaCompiler::Argument> * xla_args)232 Status CreateXlaArgs(const Graph& graph,
233 std::vector<XlaCompiler::Argument>* xla_args) {
234 std::vector<Node*> arg_nodes;
235 TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
236 for (const Node* node : arg_nodes) {
237 XlaCompiler::Argument arg;
238 arg.kind = XlaCompiler::Argument::kParameter;
239 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
240 TensorShape shape;
241 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
242 arg.shape = shape;
243 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
244 xla_args->push_back(arg);
245 }
246 return Status::OK();
247 }
248
PopulateXlaArgs(const tf2xla::Config & config,std::vector<XlaCompiler::Argument> * xla_args)249 void PopulateXlaArgs(const tf2xla::Config& config,
250 std::vector<XlaCompiler::Argument>* xla_args) {
251 // Populate arguments with resource variables from the config. The variables
252 // get turned into inputs and outputs.
253 for (const tf2xla::Variable& variable : config.variable()) {
254 XlaCompiler::Argument arg;
255 arg.type = variable.type();
256 arg.kind = XlaCompiler::Argument::kResource;
257 arg.shape = variable.shape();
258 arg.name = variable.node_name();
259 arg.resource_kind = XlaResource::kVariable;
260 arg.initialized = true;
261 xla_args->push_back(std::move(arg));
262 }
263 }
264
InitGraph(const GraphDef & graph_def,const tf2xla::Config & config,std::unique_ptr<Graph> * graph)265 Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
266 std::unique_ptr<Graph>* graph) {
267 TF_RETURN_IF_ERROR(ValidateConfig(config));
268
269 FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
270 std::unique_ptr<Graph> g(new Graph(flib_def));
271
272 // Replace references to fed tensors with references to newly added
273 // placeholders.
274 GraphDef first_copy_def = graph_def;
275
276 // Maps from name:port of a feed to the name:port of the placeholder to use.
277 std::unordered_map<string, string> feed_remapping;
278 TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
279 &feed_remapping, &first_copy_def));
280
281 // Prune the GraphDef first so that unknown ops that we aren't compiling get
282 // filtered out.
283 GraphDef second_copy_def;
284 // Add the placeholder nodes as "fetches" in prune_config, such that they will
285 // be preserved in PruneGraphDefInto.
286 auto prune_config = config;
287 for (const auto& entry : feed_remapping) {
288 auto ph = prune_config.add_fetch();
289 *ph->mutable_id()->mutable_node_name() = entry.second;
290 ph->mutable_id()->set_output_index(0);
291 }
292 TF_RETURN_IF_ERROR(
293 PruneGraphDefInto(prune_config, first_copy_def, &second_copy_def));
294
295 TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
296 &second_copy_def, *g->op_registry(), /*node_offset=*/0));
297
298 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
299 GraphConstructorOptions(), std::move(second_copy_def), g.get()));
300 TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
301
302 // Functionalize control flow.
303 TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def));
304 // After control flow functionalization, we might have more FunctionDef's
305 // (then/else branch, loop body). Add them to the graph.
306 TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto()));
307
308 *graph = std::move(g);
309 return Status::OK();
310 }
311
312 } // namespace tensorflow
313