/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler_util.h"

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

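// Internal attributes recorded on the generated _Arg/_Retval nodes.  They are
// consumed further down in this file: the feed/fetch ids by the post-pruning
// sanity checks in RewriteAndPruneGraph, and the shape/debug name when
// building XlaCompiler::Arguments in CreateXlaArgs.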
const char* const kFeedIdAttr = "_feed_id";
const char* const kFetchIdAttr = "_fetch_id";
const char* const kShapeAttr = "_shape";
const char* const kDebugNameAttr = "_debug_name";

typedef std::unordered_map<string, Node*> NodeMap;

// Each feed id identifies the positional output of some node, which may consist
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
// tensor with a placeholder.  For each feed tensor, replaces all edges so they
// point from a new _Arg node instead. The newly created _Arg nodes are added to
// `arg_nodes`.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
                   const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
                   const std::unordered_map<string, string>& feed_remapping,
                   std::unordered_set<const Node*>* arg_nodes) {
  for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
    const tf2xla::Feed& feed = feeds[arg_index];
    // All feeds have been replaced by placeholders.
    const int output_index = 0;

    const string key = TensorIdToString(feed.id());
    const auto remap_it = feed_remapping.find(key);
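    // AddPlaceholdersForFeeds is expected to have recorded a remapping entry
    // for every feed in the config, so `remap_it` is assumed to be valid here.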
    auto node_it = node_map.find(remap_it->second);
    if (node_it == node_map.end()) {
      // Strip off the aot_feed_#/ prefix.
      absl::string_view name(remap_it->second);
      const auto index = name.find('/');
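      // If there is no '/', find() returns npos and npos + 1 wraps around to
      // 0, so remove_prefix below is a no-op and the full name is reported.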
      if (index > 0) name.remove_prefix(index + 1);
      return errors::InvalidArgument(
          "Node is fed but not needed for fetching: ", name);
    }
    const Node* feed_node = node_it->second;

    // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
    // "_shape" attr if we can determine it.  That way the graph will be
    // initialized with whatever shapes we can infer, while the user can still
    // explicitly specify or override them.
    Node* arg_node = nullptr;
    TF_RETURN_IF_ERROR(
        NodeBuilder(absl::StrCat("_arg_", arg_index),
                    FunctionLibraryDefinition::kArgOp)
            .Attr("T", BaseType(feed_node->output_type(output_index)))
            .Attr("index", arg_index)
            .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
            .Attr(kShapeAttr, TensorShape(feed.shape()))
            .Attr(kDebugNameAttr, feed.name())
            .Finalize(graph, &arg_node));
    arg_nodes->insert(arg_node);

    // Collects out-edges from the feed node that have a matching edge index;
    // these will be replaced with edges from the arg node instead.
    //
    // We must collect the edges first and process them in a second pass, since
    // removing the edge from the graph invalidates feed_node->out_edges.
    std::vector<const Edge*> feed_edges;
    for (const Edge* edge : feed_node->out_edges()) {
      if (edge->src_output() == output_index) {
        feed_edges.push_back(edge);
      }
    }
    for (const Edge* edge : feed_edges) {
      graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
      graph->RemoveEdge(edge);
    }
  }
  return Status::OK();
}

// Each fetch id identifies the positional output of some node.  For each fetch
// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
                      const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches,
                      std::unordered_set<const Node*>* retval_nodes) {
  for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
    const tf2xla::TensorId& id = fetches[ret_index].id();
    auto it = node_map.find(id.node_name());
    if (it == node_map.end()) {
      return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
    }
    Node* fetch_node = it->second;
    if (id.output_index() >= fetch_node->num_outputs()) {
      return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
                                     ", output index should be < ",
                                     fetch_node->num_outputs());
    }
    // Connects fetch_node -> retval_node.
    Node* retval_node = nullptr;
    TF_RETURN_IF_ERROR(
        NodeBuilder(absl::StrCat("_retval_", ret_index),
                    FunctionLibraryDefinition::kRetOp)
            .Input(fetch_node, id.output_index())
            .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
            .Attr("index", ret_index)
            .Attr(kFetchIdAttr, TensorIdToString(id))
            .Finalize(graph, &retval_node));
    retval_nodes->insert(retval_node);
  }
  return Status::OK();
}

// RewriteAndPruneGraph identifies input and output edges (named by the feed and
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes.  This allows the symbolic graph
// execution to know the input and output args for the generated function.
Status RewriteAndPruneGraph(
    Graph* graph, const tf2xla::Config& config,
    const std::unordered_map<string, string>& feed_remapping) {
  NodeMap node_map;
  for (Node* n : graph->nodes()) {
    node_map[n->name()] = n;
  }
  std::unordered_set<const Node*> nodes_to_keep;
  TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed(), feed_remapping,
                                 &nodes_to_keep));
  TF_RETURN_IF_ERROR(
      AddRetvalNodes(graph, node_map, config.fetch(), &nodes_to_keep));
  VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
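  // Keep only the nodes in `nodes_to_keep` and their ancestors; every node
  // that does not feed into that _Arg/_Retval set is removed from the graph.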
  PruneForReverseReachability(graph, std::move(nodes_to_keep));
  FixupSourceAndSinkEdges(graph);
  VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
  // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
  std::set<string> missing_feeds, missing_fetches;
  for (const tf2xla::Feed& feed : config.feed()) {
    missing_feeds.insert(TensorIdToString(feed.id()));
  }
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    missing_fetches.insert(TensorIdToString(fetch.id()));
  }
  for (const Node* n : graph->op_nodes()) {
    if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
      string feed_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
      if (missing_feeds.erase(feed_id) == 0) {
        return errors::Aborted(FunctionLibraryDefinition::kArgOp,
                               " node found with unknown feed id: ", feed_id);
      }
    } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) {
      string fetch_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
      if (missing_fetches.erase(fetch_id) == 0) {
        return errors::Aborted(FunctionLibraryDefinition::kRetOp,
                               " node found with unknown fetch id: ", fetch_id);
      }
    }
  }
  if (!missing_feeds.empty() || !missing_fetches.empty()) {
    return errors::Aborted(
        "Post graph-pruning",
        ", missing feeds: ", absl::StrJoin(missing_feeds, ", "),
        ", missing fetches: ", absl::StrJoin(missing_fetches, ", "));
  }
  return Status::OK();
}

// CollectArgNodes collects _Arg nodes from the graph, and performs basic
// sanity-checking to ensure the index and type attributes of each node are
// initialized correctly.
Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
  std::map<int, Node*> indexed_arg_nodes;
  for (Node* n : graph.nodes()) {
    if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      auto insert_result = indexed_arg_nodes.insert({index, n});
      if (!insert_result.second) {
        const Node* dup = insert_result.first->second;
        return errors::InvalidArgument(
            "Multiple ", FunctionLibraryDefinition::kArgOp,
            " nodes with index ", index, ", ", FormatNodeForError(*n), " and ",
            FormatNodeForError(*dup));
      }
    }
  }
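  // indexed_arg_nodes is ordered by index, so comparing each index against the
  // running output size verifies that the indices form a dense range 0..N-1,
  // which fixes a well-defined positional order for the returned arg nodes.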
  arg_nodes->clear();
  for (const auto& index_node : indexed_arg_nodes) {
    const int arg_nodes_size = arg_nodes->size();
    if (index_node.first != arg_nodes_size) {
      return errors::InvalidArgument(
          "Expected ", FunctionLibraryDefinition::kArgOp, " node with index ",
          arg_nodes->size(), ", but got index ", index_node.first);
    }
    arg_nodes->push_back(index_node.second);
  }
  return Status::OK();
}

}  // namespace

Status CreateXlaArgs(const Graph& graph,
                     std::vector<XlaCompiler::Argument>* xla_args) {
  std::vector<Node*> arg_nodes;
  TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
  for (const Node* node : arg_nodes) {
    XlaCompiler::Argument arg;
    arg.kind = XlaCompiler::Argument::kParameter;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
    TensorShape shape;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
    arg.shape = shape;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
    xla_args->push_back(arg);
  }
  return Status::OK();
}

void PopulateXlaArgs(const tf2xla::Config& config,
                     std::vector<XlaCompiler::Argument>* xla_args) {
  // Populate arguments with resource variables from the config. The variables
  // get turned into inputs and outputs.
  for (const tf2xla::Variable& variable : config.variable()) {
    XlaCompiler::Argument arg;
    arg.type = variable.type();
    arg.kind = XlaCompiler::Argument::kResource;
    arg.shape = variable.shape();
    arg.name = variable.node_name();
    arg.resource_kind = XlaResource::kVariable;
    arg.initialized = true;
    xla_args->push_back(std::move(arg));
  }
}

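// A minimal usage sketch (hypothetical names): given a GraphDef `graph_def`
// and a tf2xla::Config `config` whose feeds and fetches name tensors in that
// GraphDef,
//
//   std::unique_ptr<Graph> graph;
//   TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
//
// yields a graph whose inputs are _Arg nodes and whose outputs are _Retval
// nodes, ready to hand to XlaCompiler.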
Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                 std::unique_ptr<Graph>* graph) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
  std::unique_ptr<Graph> g(new Graph(flib_def));

  // Replace references to fed tensors with references to newly added
  // placeholders.
  GraphDef first_copy_def = graph_def;

  // Maps from name:port of a feed to the name:port of the placeholder to use.
  std::unordered_map<string, string> feed_remapping;
  TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
                                             &feed_remapping, &first_copy_def));

  // Prune the GraphDef first so that unknown ops that we aren't compiling get
  // filtered out.
  GraphDef second_copy_def;
  // Add the placeholder nodes as "fetches" in prune_config, such that they will
  // be preserved in PruneGraphDefInto.
  auto prune_config = config;
  for (const auto& entry : feed_remapping) {
    auto ph = prune_config.add_fetch();
    *ph->mutable_id()->mutable_node_name() = entry.second;
    ph->mutable_id()->set_output_index(0);
  }
  TF_RETURN_IF_ERROR(
      PruneGraphDefInto(prune_config, first_copy_def, &second_copy_def));

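  // Fill in default values for any attrs the NodeDefs omit, so that graph
  // construction below sees fully specified nodes.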
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
      &second_copy_def, *g->op_registry(), /*node_offset=*/0));

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      GraphConstructorOptions(), std::move(second_copy_def), g.get()));
  TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));

  // Functionalize control flow.
  TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def));
  // After control flow functionalization, we might have more FunctionDef's
  // (then/else branch, loop body). Add them to the graph.
  TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto()));

  *graph = std::move(g);
  return Status::OK();
}

}  // namespace tensorflow