/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"

#include "tensorflow/core/common_runtime/scoped_allocator.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

// Like TF_RETURN_IF_ERROR, but also logs a WARNING.
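// Example: LOG_WARNING_AND_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));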
#define LOG_WARNING_AND_RETURN_IF_ERROR(...)            \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(WARNING) << "error: " << _status;             \
      return _status;                                   \
    }                                                   \
  } while (0)

namespace tensorflow {
namespace grappler {

namespace {
// Node names often have some kind of name_scope prefix, with slashes,
// and a _nn numeric suffix.  Returns true if the main part of the node_name
// matches op_name, i.e. it looks from the name like this node is
// of that op type.
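// For example, HasOpName("scope_1/mul_3", "mul") returns true, while
// HasOpName("scope_1/mul3x", "mul") returns false.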
bool HasOpName(const string& node_name, const string& op_name) {
  size_t begin = node_name.rfind("/");
  if (begin == string::npos) {
    begin = 0;
  } else {
    ++begin;
  }
  size_t end = node_name.rfind("_");
  if (end != string::npos) {
    size_t p = end + 1;
    while (p < node_name.size()) {
      if (!isdigit(node_name[p])) {
        end = node_name.size();
        break;
      }
      ++p;
    }
  } else {
    end = node_name.size();
  }
  return node_name.substr(begin, end - begin) == op_name;
}

// After shape inference has been done each op should be annotated
// with its output shape(s).  This function iterates over a collection
// of ops that are a potential application of a ScopedAllocator.  It
// verifies whether they all have the same output type and if so
// gathers a vector of their output shapes.  It returns an error if
// any of the ops doesn't have type or shape data, or if it has more
// than one output, or if the output type of all ops is not the same.
// If it returns OK then *type and *shapes should be correctly populated.
Status CheckTypesAndGetShapes(const GraphProperties& graph_properties,
                              const std::vector<NodeDef*>& ops, DataType* type,
                              std::vector<TensorShape>* shapes) {
  VLOG(1) << "CheckTypesAndGetShapes";
  *type = DT_INVALID;
  for (NodeDef* n : ops) {
    AttrSlice n_attrs = AttrSlice(*n);
    DataType dtype;
    LOG_WARNING_AND_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
    VLOG(2) << "op " << n->name() << " has type " << dtype << " shapes.size() "
            << shapes->size();
    if (!graph_properties.HasOutputProperties(n->name())) {
      LOG(ERROR) << "Node " << n->DebugString() << " lacks output shape.";
      return errors::Internal("Node ", n->name(), " lacks output shape.");
    }
    const std::vector<OpInfo::TensorProperties>& prop_list =
        graph_properties.GetOutputProperties(n->name());
    if (prop_list.size() != 1) {
      return errors::Internal("Node ", n->name(),
                              " does not have exactly one output as expected "
                              "by ScopedAllocatorOptimizer");
    }
    const OpInfo::TensorProperties& props = prop_list[0];
    if (shapes->empty()) {
      *type = props.dtype();
    } else if (*type != props.dtype()) {
      return errors::Internal("Group ops don't all have same type");
    }
    // Validate the shape of every op, including the first.
    if (!TensorShape::IsValid(props.shape())) {
      return errors::Internal("Complete shape not known for ", n->name());
    }
    VLOG(2) << "Adding shape " << props.shape().DebugString();
    shapes->push_back(TensorShape(props.shape()));
  }
  return Status::OK();
}

// Describes an existing input edge in the graph.
struct InputDesc {
  NodeDef* from_node_def;
  int output_slot;
  NodeDef* to_node_def;
  InputDesc(NodeDef* f, int os, NodeDef* t)
      : from_node_def(f), output_slot(os), to_node_def(t) {}
};

// Populates *inputs with all of the non-control inputs of ops.
// Returns error if it fails to find exactly one input for each op,
// or if some input is not of type dtype.
Status GetInputs(NodeMap* node_map, const std::vector<NodeDef*>& ops,
                 DataType dtype, std::vector<InputDesc>* inputs) {
  VLOG(1) << "GetInputs";
  for (NodeDef* n : ops) {
    NodeDef* inode = nullptr;
    int position = 0;
    VLOG(2) << "for node " << n->name();
    for (const auto& input_name : n->input()) {
      if (!IsControlInput(input_name)) {
        if (inode) {
          return errors::Internal("Found more than one input for node ",
                                  n->name());
        }
        ParseNodeName(input_name, &position);
        inode = node_map->GetNode(input_name);
        CHECK(inode) << input_name;
        VLOG(2) << "inode " << inode->DebugString();
      }
    }
    // Guard against an op with no non-control input, which would
    // otherwise be dereferenced below.
    if (inode == nullptr) {
      return errors::Internal("Found no data input for node ", n->name());
    }
    AttrSlice inode_attrs = AttrSlice(*inode);
    DataType inode_dtype;
    LOG_WARNING_AND_RETURN_IF_ERROR(
        GetNodeAttr(inode_attrs, "T", &inode_dtype));
    if (inode_dtype != dtype) {
      return errors::Internal("ScopedAllocatorOptimizer expected input type ",
                              dtype, " but found ", inode_dtype);
    }
    inputs->emplace_back(inode, position, n);
  }
  return Status::OK();
}

// Remove the NodeDef nd from node_map and graph.  It must be the case
// that nd no longer has any input or output edges, though that is not
// checked.
void RemoveNode(NodeDef* nd, GraphDef* graph, NodeMap* node_map) {
  node_map->RemoveNode(nd->name());
  // TODO(tucker): The efficiency of this routine is poor.
  // Change to accumulate and do a bulk removal, maybe refactoring
  // some code from dependency_optimizer.
  protobuf::RepeatedPtrField<NodeDef>* nodes = graph->mutable_node();
  for (int i = 0; i < nodes->size(); ++i) {
    if (nd->name() == (*nodes)[i].name()) {
      nodes->SwapElements(i, nodes->size() - 1);
      nodes->RemoveLast();
      return;
    }
  }
  LOG(FATAL) << "Failed to find node " << nd->name() << " in graph";
}

// Removes a named edge from between two nodes.
Status RemoveEdge(const string& input_edge_name, const string& from_node_name,
                  NodeDef* to_node, NodeMap* node_map) {
  if (node_map) {
    node_map->RemoveOutput(from_node_name, to_node->name());
  }
  protobuf::RepeatedPtrField<string>* inputs = to_node->mutable_input();
  int edge_index = -1;
  for (edge_index = 0; edge_index < inputs->size(); ++edge_index) {
    VLOG(2) << " consider edge " << (*inputs)[edge_index];
    if ((*inputs)[edge_index] == input_edge_name) {
      break;
    }
  }
  if (edge_index >= inputs->size()) {
    return errors::Internal("Could not find input name ", input_edge_name,
                            " at node ", to_node->name());
  }
  inputs->DeleteSubrange(edge_index, 1);
  return Status::OK();
}
}  // namespace

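// Extends the "name" list(int) attr of node_def with values, creating the
// attr if it is not already present.  For example, if node_def already has
// "_scoped_allocator" = [0, 5], extending with {1, 6} yields [0, 5, 1, 6].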
void ScopedAllocatorOptimizer::ExtendNodeAttr(StringPiece name,
                                              const std::vector<int32>& values,
                                              NodeDef* node_def) {
  if (HasNodeAttr(*node_def, name)) {
    VLOG(2) << "extending";
    AttrValue* existing = &(*node_def->mutable_attr())[string(name)];
    for (int32 i : values) {
      existing->mutable_list()->add_i(i);
    }
  } else {
    VLOG(2) << "setting new attr value";
    AddNodeAttr(name, values, node_def);
  }
}

class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
 public:
  ~UnaryElementwiseRewriter() override {}

  // Return non-OK if any input is already committed to a ScopedAllocator.
  Status CheckExistingScopedAllocator(const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      VLOG(2) << "get attrs for " << nd.from_node_def->name();
      AttrSlice n_attrs = AttrSlice(*nd.from_node_def);
      int sa_id;
      Status ss = GetNodeAttr(n_attrs, "sa_id", &sa_id);
      if (ss.ok()) {
        LOG(INFO) << "Abandoning PARewriter because input "
                  << nd.from_node_def->name() << " is already assigned "
                  << "to ScopedAllocator " << sa_id;
        return errors::Internal(
            "Abandoning PARewriter because input ", nd.from_node_def->name(),
            " is already assigned to ScopedAllocator ", sa_id);
      }
    }
    return Status::OK();
  }

  // Return non-OK if any input is a member of op_set.
  Status CheckInternalDataDependency(const std::set<string>& op_set,
                                     const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      if (op_set.find(nd.from_node_def->name()) != op_set.end()) {
        if (nd.output_slot != tensorflow::Graph::kControlSlot) {
          return errors::Internal("Data edge exists between ",
                                  nd.from_node_def->name(),
                                  " and another "
                                  "node in the set");
        }
      }
    }
    return Status::OK();
  }

  // Remove all control edges between members of ops.
  void ClearInternalControlInputs(const std::set<string>& op_set,
                                  const std::vector<NodeDef*>& ops,
                                  NodeMap* node_map) {
    for (NodeDef* n : ops) {
      for (const auto& input_name : n->input()) {
        if (IsControlInput(input_name)) {
          int position = 0;
          string input_node_name = ParseNodeName(input_name, &position);
          CHECK_EQ(position, -1);
          if (op_set.find(input_node_name) != op_set.end()) {
            // This is an internal control edge.  Remove it.
            VLOG(1) << "Remove control output from " << input_node_name
                    << " via edge " << input_name << " to " << n->name();
            TF_CHECK_OK(RemoveEdge(input_name, input_node_name, n, node_map));
          }
        }
      }
    }
  }

  // Examine the input set of an op set, gathering their shapes and types
  // and checking whether there are any considerations that prevent use
  // of a single ScopedAllocator for all of those inputs.
  Status AnalyzeInputs(ScopedAllocatorOptimizer* sa_opti, NodeMap* node_map,
                       const std::vector<NodeDef*>& ops,
                       const std::set<string>& op_instance_names,
                       string* device_name, DataType* dtype,
                       std::vector<TensorShape>* input_shapes,
                       std::vector<InputDesc>* inputs, TensorShape* sa_shape) {
    CHECK(graph_properties_);
    LOG_WARNING_AND_RETURN_IF_ERROR(
        CheckTypesAndGetShapes(*graph_properties_, ops, dtype, input_shapes));
    LOG_WARNING_AND_RETURN_IF_ERROR(
        GetInputs(sa_opti->node_map(), ops, *dtype, inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(CheckExistingScopedAllocator(*inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(
        CheckInternalDataDependency(op_instance_names, *inputs));
    ClearInternalControlInputs(op_instance_names, ops, node_map);
    *device_name = ops[0]->device();
    CHECK(!device_name->empty());
    CHECK(!input_shapes->empty());
    CHECK_EQ(0, Allocator::kAllocatorAlignment % DataTypeSize(*dtype))
        << "ScopedAllocatorOptimizer only applies to types that evenly "
        << "divide kAllocatorAlignment";
    std::vector<ScopedAllocator::Field> sa_fields;
    // Calculate the field embedding boundaries and thereby the
    // required size of the backing tensor.
    int64 num_bytes = ScopedAllocatorMgr::PopulateFields(
        0 /*scope_id*/, *input_shapes, *dtype, &sa_fields);
    int64 num_elts = num_bytes / DataTypeSize(*dtype);
    VLOG(2) << "num_bytes " << num_bytes << " num_elts=" << num_elts;
    *sa_shape = TensorShape({num_elts});
    return Status::OK();
  }

  // Build the ScopedAllocator node that will be assigned to allocate
  // the output tensors of the input node set.
  Status ConstructScopedAllocatorNode(
      ScopedAllocatorOptimizer* sa_opti, GraphDef* graph, NodeMap* node_map,
      const std::vector<NodeDef*>& ops, const string& device_name,
      DataType dtype, int sa_id, const string& sa_name,
      const std::vector<TensorShape>& input_shapes,
      const std::vector<InputDesc>& inputs, const TensorShape& sa_shape) {
    VLOG(2) << "ConstructScopedAllocatorNode " << sa_name;
    NodeDefBuilder sa_builder(sa_name, "_ScopedAllocator");
    sa_builder.Device(device_name);
    sa_builder.Attr("sa_name", sa_name);
    sa_builder.Attr("T", dtype);
    sa_builder.Attr("id", sa_id);
    sa_builder.Attr("shapes", input_shapes);
    sa_builder.Attr("shape", sa_shape);
    sa_builder.Attr("expected_call_count", static_cast<int64>(ops.size()));
    NodeDef* sa_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sa_builder.Finalize(sa_node));
    node_map->AddNode(sa_name, sa_node);

    // Add control edges from the ScopedAllocatorOp to all of the
    // input nodes and mark them for allocation from backing tensor.
    for (int i = 0; i < inputs.size(); ++i) {
      auto& nd = inputs[i];
      VLOG(2) << "To input " << i << ": " << nd.from_node_def->name()
              << " add control input "
              << "^" << sa_name;
      nd.from_node_def->add_input(strings::StrCat("^", sa_name));
      // This attribute says: allocate output_slot from
      // ScopedAllocator instance sa_id + 1 + i.
      ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator",
                                               {nd.output_slot, sa_id + 1 + i},
                                               nd.from_node_def);
      node_map->AddOutput(sa_name, nd.from_node_def->name());
    }
    return Status::OK();
  }

  Status BuildSAConcatNode(GraphDef* graph, NodeMap* node_map,
                           const std::vector<NodeDef*>& ops,
                           const std::set<string>& op_instance_names,
                           const string& device_name, DataType dtype, int sa_id,
                           const string& sa_name, const string& sac_name,
                           const TensorShape& sa_shape,
                           std::vector<NodeDefBuilder::NodeOut>* sac_inputs) {
    VLOG(2) << "BuildSAConcatNode " << sac_name;
    std::set<string> sac_ctl_inputs;
    for (int i = 0; i < ops.size(); ++i) {
      NodeDef* old_op = ops[i];
      for (const string& old_op_input : old_op->input()) {
        int position = 0;
        string input_name = ParseNodeName(old_op_input, &position);
        if (position == -1) {
          // A control input: drop if from another member of the op set.
          if (op_instance_names.find(old_op_input) == op_instance_names.end()) {
            sac_ctl_inputs.insert(old_op_input);
          }
        } else {
          // TODO(tucker): remove redundant check.
          // A data input: illegal if from another member of the op set.
          if (op_instance_names.find(old_op_input) != op_instance_names.end()) {
            LOG(ERROR) << "Data edge between " << old_op_input << " and "
                       << old_op->name() << " cannot build ScopedAllocator.";
            return errors::Internal("Data edge between ", old_op_input, " and ",
                                    old_op->name(),
                                    " cannot build ScopedAllocator.");
          }
          sac_inputs->push_back(
              NodeDefBuilder::NodeOut(old_op_input, 0, dtype));
        }
        VLOG(3) << "from op " << i << ": " << old_op->name()
                << " sac_inputs append " << old_op_input;
      }
    }
    NodeDefBuilder sac_builder(sac_name, "_ScopedAllocatorConcat");
    VLOG(2) << "New sac_name " << sac_name << " shape "
            << sa_shape.DebugString();
    sac_builder.Device(device_name);
    sac_builder.Attr("sa_name", sa_name);
    sac_builder.Attr("id", sa_id);
    sac_builder.Attr("T", dtype);
    sac_builder.Attr("shape", sa_shape);
    sac_builder.Attr("N", static_cast<int>(sac_inputs->size()));
    sac_builder.Input(NodeDefBuilder::NodeOut(sa_name, 0, dtype));
    sac_builder.Input(*sac_inputs);
    NodeDef* sac_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sac_builder.Finalize(sac_node));
    node_map->AddNode(sac_name, sac_node);
    node_map->AddOutput(sa_name, sac_name);

    // Attach the old control inputs to the new sac node.
    for (const string& ctl_input : sac_ctl_inputs) {
      sac_node->add_input(ctl_input);
    }
    return Status::OK();
  }

  Status BuildReplacementOp(GraphDef* graph, NodeMap* node_map,
                            const std::vector<NodeDef*>& ops,
                            const string& device_name, DataType dtype,
                            const string& op_name, const string& sac_name,
                            const string& sa_op_name) {
    VLOG(2) << "BuildReplacementOp " << sa_op_name;
    NodeDefBuilder op_builder(sa_op_name, op_name);
    op_builder.Device(device_name);

    // Transfer the Node Attr from the first replaced Node to the new
    // Node.  TODO(tucker): In principle we should verify that
    // the Attr are consistent and compatible across all op instances.
    // Unfortunately that will probably require op-specific tests, so
    // punt on that for the time being.
    AttrSlice first_slice(*ops[0]);
    for (auto& it : first_slice) {
      op_builder.Attr(it.first, it.second);
    }
    op_builder.Attr("_forward_input", {0, 0});
    op_builder.Input(sac_name, 0, dtype);
    NodeDef* sa_op_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(op_builder.Finalize(sa_op_node));
    node_map->AddNode(sa_op_name, sa_op_node);
    node_map->AddOutput(sac_name, sa_op_name);
    return Status::OK();
  }

  Status BuildSplitNode(GraphDef* graph, NodeMap* node_map,
                        const std::vector<NodeDef*>& ops,
                        const std::vector<TensorShape>& input_shapes,
                        const std::vector<NodeDefBuilder::NodeOut>& sac_inputs,
                        const string& device_name, DataType dtype,
                        const string& op_name, int sa_id,
                        const string& sas_name, const string& sa_name,
                        const string& sa_op_name) {
    VLOG(2) << "new ScopedAllocatorSplit " << sas_name;
    NodeDefBuilder sas_builder(sas_name, "_ScopedAllocatorSplit");
    sas_builder.Device(device_name);
    sas_builder.Attr("sa_name", sa_name);
    sas_builder.Attr("id", sa_id);
    sas_builder.Attr("T", dtype);
    sas_builder.Attr("shapes", input_shapes);
    std::vector<NodeDefBuilder::NodeOut> sas_inputs = sac_inputs;
    sas_builder.Attr("N", static_cast<int>(sas_inputs.size()));
    sas_builder.Input(NodeDefBuilder::NodeOut(sa_op_name, 0, dtype));
    sas_builder.Input(sas_inputs);
    NodeDef* sas_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sas_builder.Finalize(sas_node));
    node_map->AddNode(sas_name, sas_node);
    node_map->AddOutput(sa_op_name, sas_name);
    return Status::OK();
  }

  // After the new ScopedAllocator and its corresponding Concat and
  // Split nodes have been built, and a new single Op instance
  // constructed, rewire the graph: Remove input edges to the old Op
  // nodes and replace the old Op node outputs with the corresponding
  // ScopedAllocatorSplit node outputs.  After this the old Op nodes
  // should no longer have any input or output edges and they can be
  // removed from the graph.
  Status RewireSubgraph(GraphDef* graph, NodeMap* node_map,
                        const std::vector<NodeDef*>& ops,
                        const std::set<string>& op_instance_names,
                        const string& op_name, const string& sas_name) {
    VLOG(2) << "RewireSubgraph";
    for (int op_idx = 0; op_idx < ops.size(); ++op_idx) {
      NodeDef* old_op = ops[op_idx];
      // Copy the output node set since we'll be modifying the version
      // maintained by NodeMap in the loop.
      std::set<NodeDef*> output_nodes = node_map->GetOutputs(old_op->name());
      VLOG(3) << "old_op " << old_op->name() << " had " << output_nodes.size()
              << " outputs.  Moving them to the PASplit node.";
      if (VLOG_IS_ON(2)) {
        for (NodeDef* n : output_nodes) {
          VLOG(3) << "    output: " << n->name();
        }
      }
      for (NodeDef* n : output_nodes) {
        VLOG(3) << "really checking old output " << n->name()
                << " for corresponding input.";
        if (op_instance_names.find(n->name()) != op_instance_names.end()) {
          // If this output node is a member of the ops set, it must have
          // been an internal control edge so drop it.
          VLOG(3) << "Dropping control output from " << old_op->name() << " to "
                  << n->name();
          // However, we may already have dropped it at the clear() below,
          // so if we fail to find it, that's okay.
          Status ignore = RemoveEdge(strings::StrCat("^", old_op->name()),
                                     old_op->name(), n, node_map);
          continue;
        }
        bool found = false;
        VLOG(3) << "about to iterate over " << n->input_size() << " inputs";
        for (int i = 0; i < n->input_size(); ++i) {
          VLOG(3) << "input " << n->input(i);
          int position = 0;
          string input_node = ParseNodeName(n->input(i), &position);
          if (input_node == old_op->name()) {
            found = true;
            VLOG(3) << "match pos=" << position;
            if (position == -1) {
              // It was a control edge
              *n->mutable_input(i) = strings::StrCat("^", sas_name);
            } else {
              CHECK_EQ(0, position)
                  << "name " << n->input(i) << " pos " << position;
              *n->mutable_input(i) = strings::StrCat(sas_name, ":", op_idx);
            }
            node_map->RemoveOutput(old_op->name(), n->name());
            node_map->AddOutput(sas_name, n->name());
            VLOG(3) << "breaking on success";
            break;
          } else {
            VLOG(3) << "other input " << n->input(i);
          }
        }
        // In general it's required that we found the output node's old
        // input and replaced it, but one exception is if the output node
        // is of the same type being coalesced and the edge is a control
        // input.  In that case it probably got eliminated in an earlier
        // pass.
        VLOG(3) << "before HasOp";
        if (!HasOpName(n->name(), op_name)) {
          CHECK(found) << "old_op " << old_op->name()
                       << " could not find input edge on " << n->DebugString()
                       << " to replace."
                       << " " << op_name << " not in " << n->name();
        }
        VLOG(3) << "bottom of for output_nodes";
      }
      VLOG(3) << "Clearing all inputs of " << old_op->name();
      node_map->RemoveInputs(old_op->name());
      old_op->clear_input();
      node_map->RemoveOutputs(old_op->name());
      VLOG(3) << "after clear: " << old_op->DebugString();
      // old_op should now be dead, with no remaining input or output
      // edges, so it can be removed from the graph altogether.
      RemoveNode(old_op, graph, node_map);
    }
    return Status::OK();
  }

  // Given a collection of instances of op_name, presumed to be
  // logically parallel and operating on tensors of the same type,
  // replace them by a single instance.  First find the upstream Ops
  // generating their inputs. Create a new ScopedAllocatorOp that
  // outputs a single backing_tensor pre-arranged for sub-allocation
  // of all of those input tensors.  Then insert a new
  // ScopedAllocatorConcatOp below the upstream Ops to make explicit
  // the materialization of a concatenation of their outputs.  Put the
  // new op_name instance below the new concat op and follow with a
  // ScopedAllocatorSplitOp that restores the correct shape outputs
  // for the consumers of the old op_name instances.
  //
  // There must be no non-control edges between Nodes in 'ops'.
  // Control edges among these nodes will be dropped.
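  //
  // Schematically, the rewritten subgraph looks like this, where I_0..I_k
  // are the upstream input-producing ops, each of which now allocates its
  // output from the backing tensor and receives a control edge from the
  // new _ScopedAllocator node:
  //
  //        I_0    I_1   ...   I_k
  //           \    |    /
  //      _ScopedAllocatorConcat
  //               |
  //       new op_name instance
  //               |
  //      _ScopedAllocatorSplit
  //         |    |    |
  //   consumers of the old Op outputs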
  Status Rewrite(ScopedAllocatorOptimizer* sa_opti, int64 invocation_count,
                 GraphDef* graph, const string& op_name,
                 const std::vector<NodeDef*>& ops, bool* applied) override {
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Rewrite";
      string op_names;
      for (auto& nd : ops) {
        strings::StrAppend(&op_names, nd->name(), ", ");
      }
      VLOG(1) << "UnaryElementwiseRewriter::Rewrite " << op_name
              << " to: " << op_names;
    }
    NodeMap* node_map = sa_opti->node_map();

    // Make a set of the node names for faster membership testing.
    std::set<string> op_instance_names;
    for (auto& nd : ops) {
      op_instance_names.insert(nd->name());
      VLOG(2) << "op_instance_name " << nd->name();
    }
    DataType dtype;
    std::vector<TensorShape> input_shapes;
    std::vector<InputDesc> inputs;
    TensorShape sa_shape;
    string device_name;

    TF_RETURN_IF_ERROR(AnalyzeInputs(sa_opti, node_map, ops, op_instance_names,
                                     &device_name, &dtype, &input_shapes,
                                     &inputs, &sa_shape));

    int sa_id = sa_opti->NewScopedAllocatorId(input_shapes.size());
    string sa_name =
        strings::StrCat("scoped_allocator_", sa_id, "_", invocation_count);
    TF_RETURN_IF_ERROR(ConstructScopedAllocatorNode(
        sa_opti, graph, node_map, ops, device_name, dtype, sa_id, sa_name,
        input_shapes, inputs, sa_shape));

    // TODO(tucker): Maybe add control edges to delay execution of the
    // ScopedAllocatorOp until just before first use in order to
    // conserve memory.  What would be correct?  Let I0...In be the
    // input nodes that are all going to alloc from SA.  If we make
    // SA wait until all of these are ready, that might be too slow.
    // It should probably wait until at least one is ready, but which
    // one?  Maybe just pick the first.
    // {
    //   auto& nd = inputs[0];
    //   std::vector<InputDesc> inputs_to_first;
    //   LOG_WARNING_AND_RETURN_IF_ERROR(GetInputs(sa_opti->node_map(),
    //   {nd.from_node_def},
    //                                dtype, &inputs_to_first));
    //   for (int i = 0; i < inputs_to_first.size(); ++i) {
    //     sa_node->add_input(
    //         strings::StrCat("^", inputs_to_first[i].from_node_def->name()));
    //   }
    // }

    // Build a ScopedAllocatorConcat below all of the input nodes.
    std::vector<NodeDefBuilder::NodeOut> sac_inputs;
    string sac_name = strings::StrCat("scoped_allocator_concat_", sa_id, "_",
                                      invocation_count);
    TF_RETURN_IF_ERROR(BuildSAConcatNode(
        graph, node_map, ops, op_instance_names, device_name, dtype, sa_id,
        sa_name, sac_name, sa_shape, &sac_inputs));

    // Construct a new instance of the parallel op and insert it
    // immediately below the new ScopedAllocatorConcat.
    string sa_op_name = strings::StrCat(sa_name, "_", op_name);
    TF_RETURN_IF_ERROR(BuildReplacementOp(graph, node_map, ops, device_name,
                                          dtype, op_name, sac_name,
                                          sa_op_name));

    // Build a ScopedAllocatorSplit below the new Op.
    string sas_name = strings::StrCat("scoped_allocator_split_", sa_id, "_",
                                      invocation_count);
    TF_RETURN_IF_ERROR(BuildSplitNode(graph, node_map, ops, input_shapes,
                                      sac_inputs, device_name, dtype, op_name,
                                      sa_id, sas_name, sa_name, sa_op_name));

    // Rewire the graph.
    TF_RETURN_IF_ERROR(RewireSubgraph(graph, node_map, ops, op_instance_names,
                                      op_name, sas_name));

    *applied = true;
    return Status::OK();
  }
};

ScopedAllocatorOptimizer::ScopedAllocatorOptimizer(
    RewriterConfig::Toggle opt_level, const ScopedAllocatorOptions& opts)
    : opt_level_(opt_level) {
  VLOG(1) << "ScopedAllocatorOptimizer::ScopedAllocatorOptimizer";
  Rewriter* r = new UnaryElementwiseRewriter();
  to_delete_.push_back(r);
  if (opts.enable_op_size() == 0) {
    // Opts handled by default:
    for (const auto& op_name : {"CollectiveReduce"}) {
      op_name_set_.insert(op_name);
      rewriters_[op_name] = r;
    }
  } else {
    for (const auto& op_name : opts.enable_op()) {
      op_name_set_.insert(op_name);
      rewriters_[op_name] = r;
    }
  }
}

Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
                                          const GrapplerItem& item,
                                          GraphDef* optimized_graph) {
  *optimized_graph = item.graph;
  // Nodes that cannot be removed from the graph without damaging correctness,
  // typically fetch nodes.
  nodes_to_preserve_ = item.NodesToPreserve();

  GraphProperties graph_properties(item);
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  LOG_WARNING_AND_RETURN_IF_ERROR(
      graph_properties.InferStatically(assume_valid_feeds));
  node_map_.reset(new NodeMap(optimized_graph));

  LOG_WARNING_AND_RETURN_IF_ERROR(ScopedAllocatorOptimizer::ProcessGraphDef(
      optimized_graph, graph_properties));

  VLOG(1) << "ScopedAllocatorOptimizer::Optimize() done";
  return Status::OK();
}

ScopedAllocatorOptimizer::Rewriter* ScopedAllocatorOptimizer::GetRewriter(
    const string& op_name) {
  auto it = rewriters_.find(op_name);
  if (it != rewriters_.end()) {
    return it->second;
  }
  return nullptr;
}

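// Returns a fresh scoped allocator id.  Reserves num_fields + 1 consecutive
// ids: one for the backing _ScopedAllocator node itself and one per field,
// matching the sa_id + 1 + i ids handed out in ConstructScopedAllocatorNode.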
int ScopedAllocatorOptimizer::NewScopedAllocatorId(int num_fields) {
  CHECK_GT(num_fields, 0);
  int id = next_sa_id_;
  next_sa_id_ += (num_fields + 1);
  CHECK_GT(next_sa_id_, 0);
  return id;
}

ScopedAllocatorOptimizer::~ScopedAllocatorOptimizer() {
  for (auto ptr : to_delete_) {
    delete ptr;
  }
}

void ScopedAllocatorOptimizer::FindOpOccurrences(GraphDef* graph,
                                                 const OpNameSet& op_names,
                                                 GraphOpOccurrences* occs) {
  VLOG(1) << "FindOpOccurrences ";
  for (const auto& it : op_names) {
    VLOG(1) << "search target " << it;
  }
  for (int ni = 0; ni < graph->node_size(); ++ni) {
    NodeDef* node = graph->mutable_node(ni);
    const string& op_name = node->op();
    if (op_names.find(op_name) != op_names.end()) {
      VLOG(1) << "found " << op_name << " on dev " << node->device();
      (*occs)[node->device()][op_name].push_back(node);
    }
  }
}

namespace {
struct OpNameOrder {
  bool operator()(const NodeDef* a, const NodeDef* b) {
    // Use strict "<" so this is a valid strict weak ordering for sorting.
    return a->name() < b->name();
  }
};

class Tree {
 public:
  Tree(const string& edge, int depth) : edge_(edge), depth_(depth) {}
  ~Tree() {
    for (auto it : subtrees_) delete it.second;
  }

  Tree* GetSubTree(const string& edge) {
    auto it = subtrees_.find(edge);
    if (it != subtrees_.end()) {
      return it->second;
    }
    Tree* t = new Tree(edge, depth_ + 1);
    subtrees_[edge] = t;
    return t;
  }

  void InsertNode(NodeDef* n) { nodes_.push_back(n); }

  string edge_;
  int depth_;
  std::vector<NodeDef*> nodes_;
  std::unordered_map<string, Tree*> subtrees_;
};

// Applies a function to every Tree in DFS order.  Terminates early
// on any non-OK Status.
Status ApplyToAll(Tree* tree, const std::function<Status(Tree*)>& func) {
  Status s;
  for (auto it : tree->subtrees_) {
    s = ApplyToAll(it.second, func);
    if (!s.ok()) return s;
  }
  s = func(tree);
  return s;
}

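// Organizes the nodes in node_vec into a Tree keyed on name_scope path.
// For example, nodes named "a/b/x" and "a/c/y" produce root "" -> "a" ->
// {"b", "c"}, with x inserted at subtree "a/b" and y at subtree "a/c".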
Tree* ComputeScopeTree(const string& op_name,
                       const std::vector<NodeDef*>& node_vec) {
  Tree* root = new Tree("", 0);
  for (NodeDef* n : node_vec) {
    std::vector<string> pieces = str_util::Split(n->name(), "/");
    // last piece is node name proper.
    int depth = pieces.size() - 1;
    Tree* subtree = root;
    for (int i = 0; i < depth; ++i) {
      subtree = subtree->GetSubTree(pieces[i]);
    }
    subtree->InsertNode(n);
  }
  return root;
}

void PartitionByLoopStructure(const FrameView& frame_view,
                              std::vector<NodeDef*> nodes,
                              std::vector<std::vector<NodeDef*>>* loop_groups) {
  // It is assumed that two nodes with identical loop containment have
  // identical integer vectors. Represent those by 64 bit hashes.
  std::unordered_map<uint64, std::vector<NodeDef*>> loop_sets;
  for (NodeDef* nd : nodes) {
    uint64 hash = 0;
    const std::vector<int>& loop_ids = frame_view.Frames(*nd);
    for (int id : loop_ids) {
      hash = Hash64Combine(hash, static_cast<uint64>(id));
    }
    loop_sets[hash].push_back(nd);
  }
  for (auto it : loop_sets) {
    loop_groups->push_back(std::move(it.second));
  }
}

}  // namespace

Status ScopedAllocatorOptimizer::ProcessGraphDef(
    GraphDef* graph, const GraphProperties& graph_properties) {
  // Nodes created by this optimizer have the IsStateful() property
  // which means their names must be globally unique within a process,
  // so we include an optimizer invocation count in every generated
  // name.
  static std::atomic<int64> invocation_counter(1);
  const int64 invocation_count =
      invocation_counter.fetch_add(1, std::memory_order_seq_cst);
  VLOG(1) << "ProcessGraphDef " << invocation_count;
  Status status;
  GraphOpOccurrences occ;
  FindOpOccurrences(graph, op_name_set_, &occ);
  if (!occ.empty()) {
    FrameView frame_view;
    // TODO(ezhulenev): Pass a GraphView when this optimizer will be migrated
    // from NodeMap.
    LOG_WARNING_AND_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph));

    for (auto& dt : occ) {
      VLOG(2) << "Processing device " << dt.first;
      const DevOpOccurrences& dev_occ = dt.second;
      for (auto& it : dev_occ) {
        string op_name = it.first;
        VLOG(1) << "Processing " << op_name << " set size " << it.second.size();
        Rewriter* rewriter = GetRewriter(op_name);
        if (!rewriter) {
          LOG(ERROR) << "Failed to find PARewriter for op_name " << op_name;
          continue;
        }
        rewriter->SetGraphProperties(graph_properties);
        std::unique_ptr<Tree> root(ComputeScopeTree(it.first, it.second));
        // Nodes with a common depth and root path are now grouped
        // in the same Tree struct.  Split those groups into subgroups that
        // share identical loop nesting.
        status = ApplyToAll(root.get(), [this, rewriter, graph, &frame_view,
                                         &op_name, invocation_count](Tree* t) {
          VLOG(2) << "applied to tree node " << t->edge_ << " at depth "
                  << t->depth_ << " of size " << t->nodes_.size();
          if (t->nodes_.size() > 1) {
            std::vector<std::vector<NodeDef*>> loop_groups;
            PartitionByLoopStructure(frame_view, t->nodes_, &loop_groups);
            for (auto& lg : loop_groups) {
              if (lg.size() > 1) {
                bool applied = false;
                Status s = OrderNodeSet(&lg);
                TF_RETURN_IF_ERROR(s);
                VLOG(1) << "Applying Rewriter for " << op_name;
                s = rewriter->Rewrite(this, invocation_count, graph, op_name,
                                      lg, &applied);
                LOG_WARNING_AND_RETURN_IF_ERROR(s);
              }
            }
          }
          return Status::OK();
        });
        if (!status.ok()) {
          break;
        }
      }
      if (!status.ok()) {
        break;
      }
    }
  }
  VLOG(1) << "ScopedAllocatorOptimizer returning " << status;
  if (!status.ok()) {
    LOG(ERROR) << "ScopedAllocatorOptimizer: " << status;
  }
  return status;
}

namespace {
struct InstanceKeyLess {
  bool operator()(const NodeDef* a, const NodeDef* b) const {
    AttrSlice a_attrs = AttrSlice(*a);
    AttrSlice b_attrs = AttrSlice(*b);
    int32 a_key = -1;
    int32 b_key = -1;
    Status s = GetNodeAttr(a_attrs, "instance_key", &a_key);
    CHECK(s.ok());
    s = GetNodeAttr(b_attrs, "instance_key", &b_key);
    CHECK(s.ok());
    return a_key < b_key;
  }
};

struct NameLess {
  bool operator()(const NodeDef* a, const NodeDef* b) const {
    return a->name() < b->name();
  }
};

bool IsCollectiveNode(const NodeDef& n) {
  AttrSlice attrs = AttrSlice(n);
  int key = -1;
  if (!IsCollective(n)) return false;
  Status s = GetNodeAttr(attrs, "instance_key", &key);
  if (s.ok() && key >= 0) {
    return true;
  }
  return false;
}
}  // namespace

Status ScopedAllocatorOptimizer::OrderNodeSet(
    std::vector<NodeDef*>* nodes) const {
  // Nodes should all be of identical type.  The default order is by name,
  // but for collectives we order by increasing instance_key so each group
  // gets the same instance_key.
  if (nodes->size() <= 1) return Status::OK();
  if (IsCollectiveNode(*nodes->at(0))) {
    std::sort(nodes->begin(), nodes->end(), InstanceKeyLess());
  } else {
    std::sort(nodes->begin(), nodes->end(), NameLess());
  }
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow

#undef LOG_WARNING_AND_RETURN_IF_ERROR