/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"

#include "tensorflow/core/common_runtime/scoped_allocator.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

// Like TF_RETURN_IF_ERROR, but also logs a WARNING.
#define LOG_WARNING_AND_RETURN_IF_ERROR(...)            \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(WARNING) << "error: " << _status;             \
      return _status;                                   \
    }                                                   \
  } while (0)

namespace tensorflow {
namespace grappler {

namespace {
// Node names often have some kind of name_scope prefix, with slashes,
// and a _nn numeric suffix. Returns true if the main part of the node_name
// matches op_name, i.e. it looks from the name like this node is
// of that op type.
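// For example, "a/b/AddN_2" matches op_name "AddN", while "a/b/AddN_x"
// and "a/b/Add" do not.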
bool HasOpName(const string& node_name, const string& op_name) {
  size_t begin = node_name.rfind("/");
  if (begin == string::npos) {
    begin = 0;
  } else {
    ++begin;
  }
  size_t end = node_name.rfind("_");
  if (end != string::npos) {
    size_t p = end + 1;
    while (p < node_name.size()) {
      if (!isdigit(node_name[p])) {
        end = node_name.size();
        break;
      }
      ++p;
    }
  } else {
    end = node_name.size();
  }
  return node_name.substr(begin, end - begin) == op_name;
}

// After shape inference has been done each op should be annotated
// with its output shape(s). This function iterates over a collection
// of ops that are a potential application of a ScopedAllocator. It
// verifies whether they all have the same output type and if so
// gathers a vector of their output shapes. It returns an error if
// any of the ops doesn't have type or shape data, or if it has more
// than one output, or if the output type of all ops is not the same.
// If it returns OK then *type and *shapes should be correctly populated.
Status CheckTypesAndGetShapes(const GraphProperties& graph_properties,
                              const std::vector<NodeDef*>& ops, DataType* type,
                              std::vector<TensorShape>* shapes) {
  VLOG(1) << "CheckTypesAndGetShapes";
  *type = DT_INVALID;
  for (NodeDef* n : ops) {
    AttrSlice n_attrs = AttrSlice(*n);
    DataType dtype;
    LOG_WARNING_AND_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
    VLOG(2) << "op " << n->name() << " has type " << dtype << " shapes.size() "
            << shapes->size();
    if (!graph_properties.HasOutputProperties(n->name())) {
      LOG(ERROR) << "Node " << n->DebugString() << " lacks output shape.";
      return errors::Internal("Node ", n->name(), " lacks output shape.");
    }
    const std::vector<OpInfo::TensorProperties>& prop_list =
        graph_properties.GetOutputProperties(n->name());
    if (prop_list.size() != 1) {
      return errors::Internal("Node ", n->name(),
                              " does not have exactly one output as expected "
                              "by ScopedAllocatorOptimizer");
    }
    const OpInfo::TensorProperties& props = prop_list[0];
    if (shapes->empty()) {
      *type = props.dtype();
    } else if (*type != props.dtype()) {
      return errors::Internal("Group ops don't all have same type");
    }
    // Check shape validity for every op, including the first one in the group.
    if (!TensorShape::IsValid(props.shape())) {
      return errors::Internal("Complete shape not known for ", n->name());
    }
    VLOG(2) << "Adding shape " << props.shape().DebugString();
    shapes->push_back(TensorShape(props.shape()));
  }
  return Status::OK();
}

// Describes an existing input edge in the graph.
struct InputDesc {
  NodeDef* from_node_def;
  int output_slot;
  NodeDef* to_node_def;
  InputDesc(NodeDef* f, int os, NodeDef* t)
      : from_node_def(f), output_slot(os), to_node_def(t) {}
};

// Populates *inputs with all of the non-control inputs of ops.
// Returns error if it fails to find exactly one input for each op,
// or if some input is not of type dtype.
Status GetInputs(NodeMap* node_map, const std::vector<NodeDef*>& ops,
                 DataType dtype, std::vector<InputDesc>* inputs) {
  VLOG(1) << "GetInputs";
  for (NodeDef* n : ops) {
    NodeDef* inode = nullptr;
    int position = 0;
    VLOG(2) << "for node " << n->name();
    for (const auto& input_name : n->input()) {
      if (!IsControlInput(input_name)) {
        if (inode) {
          return errors::Internal("Found more than one input for node ",
                                  n->name());
        }
        ParseNodeName(input_name, &position);
        inode = node_map->GetNode(input_name);
        CHECK(inode) << input_name;
        VLOG(2) << "inode " << inode->DebugString();
      }
    }
    AttrSlice inode_attrs = AttrSlice(*inode);
    DataType inode_dtype;
    LOG_WARNING_AND_RETURN_IF_ERROR(
        GetNodeAttr(inode_attrs, "T", &inode_dtype));
    if (inode_dtype != dtype) {
      return errors::Internal("ScopedAllocatorOptimizer expected input type ",
                              dtype, " but found ", inode_dtype);
    }
    // inputs->push_back(InputDesc(inode, position, n));
    inputs->emplace_back(inode, position, n);
  }
  return Status::OK();
}

// Remove the NodeDef nd from node_map and graph. It must be the case
// that nd no longer has any input or output edges, though that is not
// checked.
void RemoveNode(NodeDef* nd, GraphDef* graph, NodeMap* node_map) {
  node_map->RemoveNode(nd->name());
  // TODO(tucker): The efficiency of this routine is poor.
  // Change to accumulate and do a bulk removal, maybe refactoring
  // some code from dependency_optimizer.
  protobuf::RepeatedPtrField<NodeDef>* nodes = graph->mutable_node();
  for (int i = 0; i < nodes->size(); ++i) {
    if (nd->name() == (*nodes)[i].name()) {
      nodes->SwapElements(i, nodes->size() - 1);
      nodes->RemoveLast();
      return;
    }
  }
  LOG(FATAL) << "Failed to find node " << nd->name() << " in graph";
}

// Removes a named edge from between two nodes.
Status RemoveEdge(const string& input_edge_name, const string& from_node_name,
                  NodeDef* to_node, NodeMap* node_map) {
  if (node_map) {
    node_map->RemoveOutput(from_node_name, to_node->name());
  }
  protobuf::RepeatedPtrField<string>* inputs = to_node->mutable_input();
  int edge_index = -1;
  for (edge_index = 0; edge_index < inputs->size(); ++edge_index) {
    VLOG(2) << " consider edge " << (*inputs)[edge_index];
    if ((*inputs)[edge_index] == input_edge_name) {
      break;
    }
  }
  if (edge_index >= inputs->size()) {
    return errors::Internal("Could not find input name ", input_edge_name,
                            " at node ", to_node->name());
  }
  inputs->DeleteSubrange(edge_index, 1);
  return Status::OK();
}
}  // namespace

void ScopedAllocatorOptimizer::ExtendNodeAttr(StringPiece name,
                                              const std::vector<int32>& values,
                                              NodeDef* node_def) {
  if (HasNodeAttr(*node_def, name)) {
    VLOG(2) << "extending";
    AttrValue* existing = &(*node_def->mutable_attr())[string(name)];
    for (int32 i : values) {
      existing->mutable_list()->add_i(i);
    }
  } else {
    VLOG(2) << "setting new attr value";
    AddNodeAttr(name, values, node_def);
  }
}

class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
 public:
  ~UnaryElementwiseRewriter() override {}

  // Return non-OK if any input is already committed to a ScopedAllocator.
  Status CheckExistingScopedAllocator(const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      VLOG(2) << "get attrs for " << nd.from_node_def->name();
      AttrSlice n_attrs = AttrSlice(*nd.from_node_def);
      int sa_id;
      Status ss = GetNodeAttr(n_attrs, "sa_id", &sa_id);
      if (ss.ok()) {
        LOG(INFO) << "Abandoning PARewriter because input "
                  << nd.from_node_def->name() << " is already assigned "
                  << "to ScopedAllocator " << sa_id;
        return errors::Internal(
            "Abandoning PARewriter because input ", nd.from_node_def->name(),
            " is already assigned to ScopedAllocator ", sa_id);
      }
    }
    return Status::OK();
  }

  // Return non-OK if any input is a member of op_set.
  Status CheckInternalDataDependency(const std::set<string>& op_set,
                                     const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      if (op_set.find(nd.from_node_def->name()) != op_set.end()) {
        if (nd.output_slot != tensorflow::Graph::kControlSlot) {
          return errors::Internal("Data edge exists between ",
                                  nd.from_node_def->name(),
                                  " and another node in the set");
        }
      }
    }
    return Status::OK();
  }

  // Remove all control edges between members of ops.
  void ClearInternalControlInputs(const std::set<string>& op_set,
                                  const std::vector<NodeDef*>& ops,
                                  NodeMap* node_map) {
    for (NodeDef* n : ops) {
      for (const auto& input_name : n->input()) {
        if (IsControlInput(input_name)) {
          int position = 0;
          string input_node_name = ParseNodeName(input_name, &position);
          CHECK_EQ(position, -1);
          if (op_set.find(input_node_name) != op_set.end()) {
            // This is an internal control edge. Remove it.
            VLOG(1) << "Remove control output from " << input_node_name
                    << " via edge " << input_name << " to " << n->name();
            TF_CHECK_OK(RemoveEdge(input_name, input_node_name, n, node_map));
          }
        }
      }
    }
  }

  // Examine the input set of an op set, gathering their shapes and types
  // and checking whether there are any considerations that prevent use
  // of a single ScopedAllocator for all of those inputs.
  Status AnalyzeInputs(ScopedAllocatorOptimizer* sa_opti, NodeMap* node_map,
                       const std::vector<NodeDef*>& ops,
                       const std::set<string>& op_instance_names,
                       string* device_name, DataType* dtype,
                       std::vector<TensorShape>* input_shapes,
                       std::vector<InputDesc>* inputs, TensorShape* sa_shape) {
    CHECK(graph_properties_);
    LOG_WARNING_AND_RETURN_IF_ERROR(
        CheckTypesAndGetShapes(*graph_properties_, ops, dtype, input_shapes));
    LOG_WARNING_AND_RETURN_IF_ERROR(
        GetInputs(sa_opti->node_map(), ops, *dtype, inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(CheckExistingScopedAllocator(*inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(
        CheckInternalDataDependency(op_instance_names, *inputs));
    ClearInternalControlInputs(op_instance_names, ops, node_map);
    *device_name = ops[0]->device();
    CHECK(!device_name->empty());
    CHECK(!input_shapes->empty());
    CHECK_EQ(0, Allocator::kAllocatorAlignment % DataTypeSize(*dtype))
        << "ScopedAllocatorOptimizer only applies to types that evenly "
        << "divide kAllocatorAlignment";
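    // For example, DT_FLOAT (4 bytes) evenly divides the (typically 64-byte)
    // Allocator::kAllocatorAlignment, so float ops satisfy this check.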
    std::vector<ScopedAllocator::Field> sa_fields;
    // Calculate the field embedding boundaries and thereby the
    // required size of the backing tensor.
    int64 num_bytes = ScopedAllocatorMgr::PopulateFields(
        0 /*scope_id*/, *input_shapes, *dtype, &sa_fields);
    int64 num_elts = num_bytes / DataTypeSize(*dtype);
    VLOG(2) << "num_bytes " << num_bytes << " num_elts=" << num_elts;
    *sa_shape = TensorShape({num_elts});
    return Status::OK();
  }

  // Build the ScopedAllocator node that will be assigned to allocate
  // the output tensors of the input node set.
  Status ConstructScopedAllocatorNode(
      ScopedAllocatorOptimizer* sa_opti, GraphDef* graph, NodeMap* node_map,
      const std::vector<NodeDef*>& ops, const string& device_name,
      DataType dtype, int sa_id, const string& sa_name,
      const std::vector<TensorShape>& input_shapes,
      const std::vector<InputDesc>& inputs, const TensorShape& sa_shape) {
    VLOG(2) << "ConstructScopedAllocatorNode " << sa_name;
    NodeDefBuilder sa_builder(sa_name, "_ScopedAllocator");
    sa_builder.Device(device_name);
    sa_builder.Attr("sa_name", sa_name);
    sa_builder.Attr("T", dtype);
    sa_builder.Attr("id", sa_id);
    sa_builder.Attr("shapes", input_shapes);
    sa_builder.Attr("shape", sa_shape);
    sa_builder.Attr("expected_call_count", static_cast<int64>(ops.size()));
    NodeDef* sa_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sa_builder.Finalize(sa_node));
    node_map->AddNode(sa_name, sa_node);

    // Add control edges from the ScopedAllocatorOp to all of the
    // input nodes and mark them for allocation from backing tensor.
    for (int i = 0; i < inputs.size(); ++i) {
      auto& nd = inputs[i];
      VLOG(2) << "To input " << i << ": " << nd.from_node_def->name()
              << " add control input "
              << "^" << sa_name;
      nd.from_node_def->add_input(strings::StrCat("^", sa_name));
      // This attribute says: allocate output_slot from
      // ScopedAllocator instance sa_id + 1 + i.
      ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator",
                                               {nd.output_slot, sa_id + 1 + i},
                                               nd.from_node_def);
      node_map->AddOutput(sa_name, nd.from_node_def->name());
    }
    return Status::OK();
  }

  Status BuildSAConcatNode(GraphDef* graph, NodeMap* node_map,
                           const std::vector<NodeDef*>& ops,
                           const std::set<string>& op_instance_names,
                           const string& device_name, DataType dtype, int sa_id,
                           const string& sa_name, const string& sac_name,
                           const TensorShape& sa_shape,
                           std::vector<NodeDefBuilder::NodeOut>* sac_inputs) {
    VLOG(2) << "BuildSAConcatNode " << sac_name;
    std::set<string> sac_ctl_inputs;
    for (int i = 0; i < ops.size(); ++i) {
      NodeDef* old_op = ops[i];
      for (const string& old_op_input : old_op->input()) {
        int position = 0;
        string input_name = ParseNodeName(old_op_input, &position);
        if (position == -1) {
          // A control input: drop if from another member of the op set.
          if (op_instance_names.find(old_op_input) == op_instance_names.end()) {
            sac_ctl_inputs.insert(old_op_input);
          }
        } else {
          // TODO(tucker): remove redundant check.
          // A data input: illegal if from another member of the op set.
          if (op_instance_names.find(old_op_input) != op_instance_names.end()) {
            LOG(ERROR) << "Data edge between " << old_op_input << " and "
                       << old_op->name() << " cannot build ScopedAllocator.";
            return errors::Internal("Data edge between ", old_op_input, " and ",
                                    old_op->name(),
                                    " cannot build ScopedAllocator.");
          }
          sac_inputs->push_back(
              NodeDefBuilder::NodeOut(old_op_input, 0, dtype));
        }
        VLOG(3) << "from op " << i << ": " << old_op->name()
                << " sac_inputs append " << old_op_input;
      }
    }
    NodeDefBuilder sac_builder(sac_name, "_ScopedAllocatorConcat");
    VLOG(2) << "New sac_name " << sac_name << " shape "
            << sa_shape.DebugString();
    sac_builder.Device(device_name);
    sac_builder.Attr("sa_name", sa_name);
    sac_builder.Attr("id", sa_id);
    sac_builder.Attr("T", dtype);
    sac_builder.Attr("shape", sa_shape);
    sac_builder.Attr("N", static_cast<int>(sac_inputs->size()));
    sac_builder.Input(NodeDefBuilder::NodeOut(sa_name, 0, dtype));
    sac_builder.Input(*sac_inputs);
    NodeDef* sac_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sac_builder.Finalize(sac_node));
    node_map->AddNode(sac_name, sac_node);
    node_map->AddOutput(sa_name, sac_name);

    // Attach the old control inputs to the new sac node.
    for (const string& ctl_input : sac_ctl_inputs) {
      sac_node->add_input(ctl_input);
    }
    return Status::OK();
  }

  Status BuildReplacementOp(GraphDef* graph, NodeMap* node_map,
                            const std::vector<NodeDef*>& ops,
                            const string& device_name, DataType dtype,
                            const string& op_name, const string& sac_name,
                            const string& sa_op_name) {
    VLOG(2) << "BuildReplacementOp " << sa_op_name;
    NodeDefBuilder op_builder(sa_op_name, op_name);
    op_builder.Device(device_name);

    // Transfer the Node Attr from the first replaced Node to the new
    // Node. TODO(tucker): In principle we should verify that
    // the Attr are consistent and compatible across all op instances.
    // Unfortunately that will probably require op-specific tests, so
    // punt on that for the time being.
    AttrSlice first_slice(*ops[0]);
    for (auto& it : first_slice) {
      op_builder.Attr(it.first, it.second);
    }
    op_builder.Attr("_forward_input", {0, 0});
    op_builder.Input(sac_name, 0, dtype);
    NodeDef* sa_op_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(op_builder.Finalize(sa_op_node));
    node_map->AddNode(sa_op_name, sa_op_node);
    node_map->AddOutput(sac_name, sa_op_name);
    return Status::OK();
  }

  Status BuildSplitNode(GraphDef* graph, NodeMap* node_map,
                        const std::vector<NodeDef*>& ops,
                        const std::vector<TensorShape>& input_shapes,
                        const std::vector<NodeDefBuilder::NodeOut>& sac_inputs,
                        const string& device_name, DataType dtype,
                        const string& op_name, int sa_id,
                        const string& sas_name, const string& sa_name,
                        const string& sa_op_name) {
    VLOG(2) << "new ScopedAllocatorSplit " << sas_name;
    NodeDefBuilder sas_builder(sas_name, "_ScopedAllocatorSplit");
    sas_builder.Device(device_name);
    sas_builder.Attr("sa_name", sa_name);
    sas_builder.Attr("id", sa_id);
    sas_builder.Attr("T", dtype);
    sas_builder.Attr("shapes", input_shapes);
    std::vector<NodeDefBuilder::NodeOut> sas_inputs = sac_inputs;
    sas_builder.Attr("N", static_cast<int>(sas_inputs.size()));
    sas_builder.Input(NodeDefBuilder::NodeOut(sa_op_name, 0, dtype));
    sas_builder.Input(sas_inputs);
    NodeDef* sas_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sas_builder.Finalize(sas_node));
    node_map->AddNode(sas_name, sas_node);
    node_map->AddOutput(sa_op_name, sas_name);
    return Status::OK();
  }

  // After the new ScopedAllocator and its corresponding Concat and
  // Split nodes have been built, and a new single Op instance
  // constructed, rewire the graph: Remove input edges to the old Op
  // nodes and replace the old Op node outputs with the corresponding
  // ScopedAllocatorSplit node outputs. After this the old Op nodes
  // should no longer have any input or output edges and they can be
  // removed from the graph.
  Status RewireSubgraph(GraphDef* graph, NodeMap* node_map,
                        const std::vector<NodeDef*>& ops,
                        const std::set<string>& op_instance_names,
                        const string& op_name, const string& sas_name) {
    VLOG(2) << "RewireSubgraph";
    for (int op_idx = 0; op_idx < ops.size(); ++op_idx) {
      NodeDef* old_op = ops[op_idx];
      // Copy the output node set since we'll be modifying the version
      // maintained by NodeMap in the loop.
      std::set<NodeDef*> output_nodes = node_map->GetOutputs(old_op->name());
      VLOG(3) << "old_op " << old_op->name() << " had " << output_nodes.size()
              << " outputs. Moving them to the PASplit node.";
      if (VLOG_IS_ON(2)) {
        for (NodeDef* n : output_nodes) {
          VLOG(3) << " output: " << n->name();
        }
      }
      for (NodeDef* n : output_nodes) {
        VLOG(3) << "really checking old output " << n->name()
                << " for corresponding input.";
        if (op_instance_names.find(n->name()) != op_instance_names.end()) {
          // If this output node is a member of the ops set, it must have
          // been an internal control edge so drop it.
          VLOG(3) << "Dropping control output from " << old_op->name() << " to "
                  << n->name();
          // However, we may already have dropped it at the clear() below,
          // so if we fail to find it, that's okay.
          Status ignore = RemoveEdge(strings::StrCat("^", old_op->name()),
                                     old_op->name(), n, node_map);
          continue;
        }
        bool found = false;
        VLOG(3) << "about to iterate over " << n->input_size() << " inputs";
        for (int i = 0; i < n->input_size(); ++i) {
          VLOG(3) << "input " << n->input(i);
          int position = 0;
          string input_node = ParseNodeName(n->input(i), &position);
          if (input_node == old_op->name()) {
            found = true;
            VLOG(3) << "match pos=" << position;
            if (position == -1) {
              // It was a control edge
              *n->mutable_input(i) = strings::StrCat("^", sas_name);
            } else {
              CHECK_EQ(0, position)
                  << "name " << n->input(i) << " pos " << position;
              *n->mutable_input(i) = strings::StrCat(sas_name, ":", op_idx);
            }
            node_map->RemoveOutput(old_op->name(), n->name());
            node_map->AddOutput(sas_name, n->name());
            VLOG(3) << "breaking on success";
            break;
          } else {
            VLOG(3) << "other input " << n->input(i);
          }
        }
        // In general it's required that we found the output node's old
        // input and replaced it, but one exception is if the output node
        // is of the same type being coalesced and the edge is a control
        // input. In that case it probably got eliminated in an earlier
        // pass.
        VLOG(3) << "before HasOp";
        if (!HasOpName(n->name(), op_name)) {
          CHECK(found) << "old_op " << old_op->name() << " node "
                       << " could not find input edge on " << n->DebugString()
                       << " to replace."
                       << " " << op_name << " not in " << n->name();
        }
        VLOG(3) << "bottom of for output_nodes";
      }
      VLOG(3) << "Clearing all inputs of " << old_op->name();
      node_map->RemoveInputs(old_op->name());
      old_op->clear_input();
      node_map->RemoveOutputs(old_op->name());
      VLOG(3) << "after clear: " << old_op->DebugString();
      // old_op should now be dead, with no remaining input or output edges,
      // so it can be removed from the graph before the optimized GraphDef is
      // generated.
      RemoveNode(old_op, graph, node_map);
    }
    return Status::OK();
  }

  // Given a collection of instances of op_name, presumed to be
  // logically parallel and operating on tensors of the same type,
  // replace them by a single instance. First find the upstream Ops
  // generating their inputs. Create a new ScopedAllocatorOp that
  // outputs a single backing_tensor pre-arranged for sub-allocation
  // of all of those input tensors. Then insert a new
  // ScopedAllocatorConcatOp below the upstream Ops to make explicit
  // the materialization of a concatenation of their outputs. Put the
  // new op_name instance below the new concat op and follow with a
  // ScopedAllocatorSplitOp that restores the correct shape outputs
  // for the consumers of the old op_name instances.
  //
  // There must be no non-control edges between Nodes in 'ops'.
  // Control edges among these nodes will be dropped.
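  //
  // Schematically (a sketch, not a literal graph dump), for two ops A_0 and
  // A_1 with inputs I_0, I_1 and consumers C_0, C_1, the rewrite is roughly:
  //
  //   before:  I_0 -> A_0 -> C_0        I_1 -> A_1 -> C_1
  //
  //   after:   _ScopedAllocator --(ctl)--> I_0, I_1
  //            I_0, I_1 -> _ScopedAllocatorConcat -> A -> _ScopedAllocatorSplit
  //            _ScopedAllocatorSplit:0 -> C_0, _ScopedAllocatorSplit:1 -> C_1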
  Status Rewrite(ScopedAllocatorOptimizer* sa_opti, int64 invocation_count,
                 GraphDef* graph, const string& op_name,
                 const std::vector<NodeDef*>& ops, bool* applied) override {
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Rewrite";
      string op_names;
      for (auto& nd : ops) {
        strings::StrAppend(&op_names, nd->name(), ", ");
      }
      VLOG(1) << "UnaryElementwiseRewriter::Rewrite " << op_name
              << " to: " << op_names;
    }
    NodeMap* node_map = sa_opti->node_map();

    // Make a set of the node names for faster membership testing.
    std::set<string> op_instance_names;
    for (auto& nd : ops) {
      op_instance_names.insert(nd->name());
      VLOG(2) << "op_instance_name " << nd->name();
    }
    DataType dtype;
    std::vector<TensorShape> input_shapes;
    std::vector<InputDesc> inputs;
    TensorShape sa_shape;
    string device_name;

    TF_RETURN_IF_ERROR(AnalyzeInputs(sa_opti, node_map, ops, op_instance_names,
                                     &device_name, &dtype, &input_shapes,
                                     &inputs, &sa_shape));

    int sa_id = sa_opti->NewScopedAllocatorId(input_shapes.size());
    string sa_name =
        strings::StrCat("scoped_allocator_", sa_id, "_", invocation_count);
    TF_RETURN_IF_ERROR(ConstructScopedAllocatorNode(
        sa_opti, graph, node_map, ops, device_name, dtype, sa_id, sa_name,
        input_shapes, inputs, sa_shape));

    // TODO(tucker): Maybe add control edges to delay execution of the
    // ScopedAllocatorOp until just before first use in order to
    // conserve memory. What would be correct? Let I0...In be the
    // input nodes that are all going to alloc from SA. If we make
    // SA wait until all of these are ready, that might be too slow.
    // It should probably wait until at least one is ready, but which
    // one? Maybe just pick the first.
    // {
    //   auto& nd = inputs[0];
    //   std::vector<InputDesc> inputs_to_first;
    //   LOG_WARNING_AND_RETURN_IF_ERROR(GetInputs(sa_opti->node_map(),
    //                                             {nd.from_node_def},
    //                                             dtype, &inputs_to_first));
    //   for (int i = 0; i < inputs_to_first.size(); ++i) {
    //     sa_node->add_input(
    //         strings::StrCat("^", inputs_to_first[i].from_node_def->name()));
    //   }
    // }

    // Build a ScopedAllocatorConcat below all of the input nodes.
    std::vector<NodeDefBuilder::NodeOut> sac_inputs;
    string sac_name = strings::StrCat("scoped_allocator_concat_", sa_id, "_",
                                      invocation_count);
    TF_RETURN_IF_ERROR(BuildSAConcatNode(
        graph, node_map, ops, op_instance_names, device_name, dtype, sa_id,
        sa_name, sac_name, sa_shape, &sac_inputs));

    // Construct a new instance of the parallel op and insert it
    // immediately below the new ScopedAllocatorConcat.
    string sa_op_name = strings::StrCat(sa_name, "_", op_name);
    TF_RETURN_IF_ERROR(BuildReplacementOp(graph, node_map, ops, device_name,
                                          dtype, op_name, sac_name,
                                          sa_op_name));

    // Build a ScopedAllocatorSplit split below the new Op.
    string sas_name = strings::StrCat("scoped_allocator_split_", sa_id, "_",
                                      invocation_count);
    TF_RETURN_IF_ERROR(BuildSplitNode(graph, node_map, ops, input_shapes,
                                      sac_inputs, device_name, dtype, op_name,
                                      sa_id, sas_name, sa_name, sa_op_name));

    // Rewire the graph.
    TF_RETURN_IF_ERROR(RewireSubgraph(graph, node_map, ops, op_instance_names,
                                      op_name, sas_name));

    *applied = true;
    return Status::OK();
  }
};

ScopedAllocatorOptimizer::ScopedAllocatorOptimizer(
    RewriterConfig::Toggle opt_level, const ScopedAllocatorOptions& opts)
    : opt_level_(opt_level) {
  VLOG(1) << "ScopedAllocatorOptimizer::ScopedAllocatorOptimizer";
  Rewriter* r = new UnaryElementwiseRewriter();
  to_delete_.push_back(r);
  if (opts.enable_op_size() == 0) {
    // Opts handled by default:
    for (const auto& op_name : {"CollectiveReduce"}) {
      op_name_set_.insert(op_name);
      rewriters_[op_name] = r;
    }
  } else {
    for (const auto& op_name : opts.enable_op()) {
      op_name_set_.insert(op_name);
      rewriters_[op_name] = r;
    }
  }
}

Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
                                          const GrapplerItem& item,
                                          GraphDef* optimized_graph) {
  *optimized_graph = item.graph;
  // Nodes that cannot be removed from the graph without damaging correctness,
  // typically fetch nodes.
  nodes_to_preserve_ = item.NodesToPreserve();

  GraphProperties graph_properties(item);
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  LOG_WARNING_AND_RETURN_IF_ERROR(
      graph_properties.InferStatically(assume_valid_feeds));
  node_map_.reset(new NodeMap(optimized_graph));

  LOG_WARNING_AND_RETURN_IF_ERROR(ScopedAllocatorOptimizer::ProcessGraphDef(
      optimized_graph, graph_properties));

  VLOG(1) << "ScopedAllocatorOptimizer::Optimize() done";
  return Status::OK();
}

ScopedAllocatorOptimizer::Rewriter* ScopedAllocatorOptimizer::GetRewriter(
    const string& op_name) {
  auto it = rewriters_.find(op_name);
  if (it != rewriters_.end()) {
    return it->second;
  }
  return nullptr;
}

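// Reserves a contiguous block of num_fields + 1 ids: the returned id names
// the _ScopedAllocator node itself, and ids id + 1 .. id + num_fields name
// its fields, one per input tensor (see ConstructScopedAllocatorNode above).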
int ScopedAllocatorOptimizer::NewScopedAllocatorId(int num_fields) {
  CHECK_GT(num_fields, 0);
  int id = next_sa_id_;
  next_sa_id_ += (num_fields + 1);
  CHECK_GT(next_sa_id_, 0);
  return id;
}

ScopedAllocatorOptimizer::~ScopedAllocatorOptimizer() {
  for (auto ptr : to_delete_) {
    delete ptr;
  }
}

void ScopedAllocatorOptimizer::FindOpOccurrences(GraphDef* graph,
                                                 const OpNameSet& op_names,
                                                 GraphOpOccurrences* occs) {
  VLOG(1) << "FindOpOccurrences ";
  for (const auto& it : op_names) {
    VLOG(1) << "search target " << it;
  }
  for (int ni = 0; ni < graph->node_size(); ++ni) {
    NodeDef* node = graph->mutable_node(ni);
    const string& op_name = node->op();
    if (op_names.find(op_name) != op_names.end()) {
      VLOG(1) << "found " << op_name << " on dev " << node->device();
      (*occs)[node->device()][op_name].push_back(node);
    }
  }
}

namespace {
struct OpNameOrder {
  // Use strict less-than so this is a valid strict weak ordering for sorting.
  bool operator()(const NodeDef* a, const NodeDef* b) {
    return a->name() < b->name();
  }
};

class Tree {
 public:
  Tree(const string& edge, int depth) : edge_(edge), depth_(depth) {}
  ~Tree() {
    for (auto it : subtrees_) delete it.second;
  }

  Tree* GetSubTree(const string& edge) {
    auto it = subtrees_.find(edge);
    if (it != subtrees_.end()) {
      return it->second;
    }
    Tree* t = new Tree(edge, depth_ + 1);
    subtrees_[edge] = t;
    return t;
  }

  void InsertNode(NodeDef* n) { nodes_.push_back(n); }

  string edge_;
  int depth_;
  std::vector<NodeDef*> nodes_;
  std::unordered_map<string, Tree*> subtrees_;
};

// Applies a function to every Tree in DFS order. Terminates early
// on any non-OK Status.
Status ApplyToAll(Tree* tree, const std::function<Status(Tree*)>& func) {
  Status s;
  for (auto it : tree->subtrees_) {
    s = ApplyToAll(it.second, func);
    if (!s.ok()) return s;
  }
  s = func(tree);
  return s;
}

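// Builds a Tree rooted at an empty edge from the name_scope prefixes of
// node_vec. For example, nodes named "a/b/x" and "a/b/y" are both inserted
// into the subtree reached via edges "a" then "b"; the final name component
// is not used as an edge.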
Tree* ComputeScopeTree(const string& op_name,
                       const std::vector<NodeDef*>& node_vec) {
  Tree* root = new Tree("", 0);
  for (NodeDef* n : node_vec) {
    std::vector<string> pieces = str_util::Split(n->name(), "/");
    // last piece is node name proper.
    int depth = pieces.size() - 1;
    Tree* subtree = root;
    for (int i = 0; i < depth; ++i) {
      subtree = subtree->GetSubTree(pieces[i]);
    }
    subtree->InsertNode(n);
  }
  return root;
}

void PartitionByLoopStructure(const FrameView& frame_view,
                              std::vector<NodeDef*> nodes,
                              std::vector<std::vector<NodeDef*>>* loop_groups) {
  // It is assumed that two nodes with identical loop containment have
  // identical integer vectors. Represent those by 64 bit hashes.
  std::unordered_map<uint64, std::vector<NodeDef*>> loop_sets;
  for (NodeDef* nd : nodes) {
    uint64 hash = 0;
    const std::vector<int>& loop_ids = frame_view.Frames(*nd);
    for (int id : loop_ids) {
      hash = Hash64Combine(hash, static_cast<uint64>(id));
    }
    loop_sets[hash].push_back(nd);
  }
  for (auto it : loop_sets) {
    loop_groups->push_back(std::move(it.second));
  }
}

}  // namespace

Status ScopedAllocatorOptimizer::ProcessGraphDef(
    GraphDef* graph, const GraphProperties& graph_properties) {
  // Nodes created by this optimizer have the IsStateful() property
  // which means their names must be globally unique within a process,
  // so we include an optimizer invocation count in every generated
  // name.
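  // For example, names take the form "scoped_allocator_<sa_id>_<invocation_count>"
  // (see the sa_name construction in Rewrite above).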
  static std::atomic<int64> invocation_counter(1);
  const int64 invocation_count =
      invocation_counter.fetch_add(1, std::memory_order_seq_cst);
  VLOG(1) << "ProcessGraphDef " << invocation_count;
  Status status;
  GraphOpOccurrences occ;
  FindOpOccurrences(graph, op_name_set_, &occ);
  if (!occ.empty()) {
    FrameView frame_view;
    // TODO(ezhulenev): Pass a GraphView when this optimizer will be migrated
    // from NodeMap.
    LOG_WARNING_AND_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph));

    for (auto& dt : occ) {
      VLOG(2) << "Processing device " << dt.first;
      const DevOpOccurrences& dev_occ = dt.second;
      for (auto& it : dev_occ) {
        string op_name = it.first;
        VLOG(1) << "Processing " << op_name << " set size " << it.second.size();
        Rewriter* rewriter = GetRewriter(op_name);
        if (!rewriter) {
          LOG(ERROR) << "Failed to find PARewriter for op_name " << op_name;
          continue;
        }
        rewriter->SetGraphProperties(graph_properties);
        std::unique_ptr<Tree> root(ComputeScopeTree(it.first, it.second));
        // Nodes with a common depth and root path are now grouped
        // in the same Tree struct. Split those groups into subgroups that
        // share identical loop nesting.
        status = ApplyToAll(root.get(), [this, rewriter, graph, &frame_view,
                                         &op_name, invocation_count](Tree* t) {
          VLOG(2) << "applied to tree node " << t->edge_ << " at depth "
                  << t->depth_ << " of size " << t->nodes_.size();
          if (t->nodes_.size() > 1) {
            std::vector<std::vector<NodeDef*>> loop_groups;
            PartitionByLoopStructure(frame_view, t->nodes_, &loop_groups);
            for (auto& lg : loop_groups) {
              if (lg.size() > 1) {
                bool applied = false;
                Status s = OrderNodeSet(&lg);
                TF_RETURN_IF_ERROR(s);
                VLOG(1) << "Applying Rewriter for " << op_name;
                s = rewriter->Rewrite(this, invocation_count, graph, op_name,
                                      lg, &applied);
                LOG_WARNING_AND_RETURN_IF_ERROR(s);
              }
            }
          }
          return Status::OK();
        });
        if (!status.ok()) {
          break;
        }
      }
      if (!status.ok()) {
        break;
      }
    }
  }
  VLOG(1) << "ScopedAllocatorOptimizer returning " << status;
  if (!status.ok()) {
    LOG(ERROR) << "ScopedAllocatorOptimizer: " << status;
  }
  return status;
}

namespace {
struct InstanceKeyLess {
  bool operator()(const NodeDef* a, const NodeDef* b) const {
    AttrSlice a_attrs = AttrSlice(*a);
    AttrSlice b_attrs = AttrSlice(*b);
    int32 a_key = -1;
    int32 b_key = -1;
    Status s = GetNodeAttr(a_attrs, "instance_key", &a_key);
    CHECK(s.ok());
    s = GetNodeAttr(b_attrs, "instance_key", &b_key);
    CHECK(s.ok());
    return a_key < b_key;
  }
};

struct NameLess {
  bool operator()(const NodeDef* a, const NodeDef* b) const {
    return a->name() < b->name();
  }
};

bool IsCollectiveNode(const NodeDef& n) {
  AttrSlice attrs = AttrSlice(n);
  int key = -1;
  if (!IsCollective(n)) return false;
  Status s = GetNodeAttr(attrs, "instance_key", &key);
  if (s.ok() && key >= 0) {
    return true;
  }
  return false;
}
}  // namespace

Status ScopedAllocatorOptimizer::OrderNodeSet(
    std::vector<NodeDef*>* nodes) const {
  // Nodes should all be of the same type. The default order is by name, but
  // for collectives we order by increasing instance_key so that each group
  // gets the same instance_key.
  if (nodes->size() <= 1) return Status::OK();
  if (IsCollectiveNode(*nodes->at(0))) {
    sort(nodes->begin(), nodes->end(), InstanceKeyLess());
  } else {
    sort(nodes->begin(), nodes->end(), NameLess());
  }
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow

#undef LOG_WARNING_AND_RETURN_IF_ERROR