• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/optimization_registry.h"
17 #include "tensorflow/core/graph/control_flow.h"
18 #include "tensorflow/core/graph/node_builder.h"
19 
20 namespace tensorflow {
21 namespace {
22 
// Attribute read from the frame node of an enclosing while loop to decide
// whether that loop may run several iterations concurrently.
constexpr char kParallelIterationsAttrName[] = "parallel_iterations";
25 
make_zeros(const DataType & dtype,const TensorShapeProto & shape)26 Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) {
27   Tensor tensor(dtype, TensorShape(shape));
28 
29   // Conveniently, all numeric data types have 0x0 == zero.  Otherwise we would
30   // need a giant switch statement here.
31   memset(const_cast<char*>(tensor.tensor_data().data()), 0,
32          tensor.tensor_data().size());
33 
34   return tensor;
35 }
36 
37 // Replaces occurrences of the "AccumulateNV2" stub operator with a graph of
38 // lower-level ops. The graph is equivalent (modulo certain corner cases)
39 // to the semantics of the original accumulate_n() Python op in math_ops.py.
40 // Implementing the op with a rewrite allows this new variant of accumulate_n
41 // to be differentiable.
42 //
43 // The binary code that generates AccumulateNV2 stub ops is located in a
44 // dynamic library built out of tensorflow/contrib/framework. Ideally, this
45 // class would also be in contrib, but calls to REGISTER_OPTIMIZATION() from
46 // third-party libraries aren't currently supported.
47 class AccumulateNV2RemovePass : public GraphOptimizationPass {
48  public:
Run(const GraphOptimizationPassOptions & options)49   Status Run(const GraphOptimizationPassOptions& options) override {
50     // TODO(freiss.oss@gmail.com): Substantial shared code with
51     // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes
52     // a third similar rewrite.
53     if (options.graph == nullptr) {
54       // TODO(apassos) returning OK feels weird here as we can't do anything
55       // without a graph, but some tests require this.
56       return Status::OK();
57     }
58 
59     Graph* g = options.graph->get();
60     if (g == nullptr) {
61       return errors::Internal(
62           "AccumulateNV2 removal should happen before partitioning and a "
63           "graph should be available.");
64     }
65 
66     // Build up a todo list of ops to replace, *then* modify the graph
67     gtl::InlinedVector<Node*, 2> matches;
68     for (Node* n : g->op_nodes()) {
69       if (n->type_string() == "AccumulateNV2") {
70         matches.push_back(n);
71       }
72     }
73     if (matches.empty()) return Status::OK();
74 
75     std::vector<ControlFlowInfo> control_flow_info;
76     TF_RETURN_IF_ERROR(BuildControlFlowInfo(g, &control_flow_info));
77 
78     for (Node* n : matches) {
79       // Temporary variables do not work inside while loops with parallel
80       // iterations. If the `AccumulateNV2` node is executed inside a loop, we
81       // rewrite it into 'AddN' node.
82       const Node* frame = control_flow_info[n->id()].frame;
83       bool is_in_while_loop = frame->id() != Graph::kSourceId;
84 
85       // With `parallel_iterations == 1` it's safe to use TemporaryVariable.
86       if (is_in_while_loop) {
87         int parallel_iterations;
88         bool found = TryGetNodeAttr(frame->attrs(), kParallelIterationsAttrName,
89                                     &parallel_iterations);
90         if (found && parallel_iterations == 1) {
91           is_in_while_loop = false;
92         }
93       }
94 
95       if (is_in_while_loop) {
96         TF_RETURN_IF_ERROR(RewriteIntoAddN(n, g));
97       } else {
98         TF_RETURN_IF_ERROR(RewriteIntoTempVariable(n, g));
99       }
100     }
101     return Status::OK();
102   }
103 
RewriteIntoTempVariable(Node * n,Graph * g)104   Status RewriteIntoTempVariable(Node* n, Graph* g) {
105     VLOG(3) << "Rewrite AccumulateNV2 into TemporaryVariable and Assign: "
106             << SummarizeNode(*n);
107 
108     AttrSlice n_attrs = n->attrs();
109     auto base_make_node = [n, &n_attrs](const string& op, const string& name) {
110       NodeDebugInfo debug_info(*n);
111       NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info);
112 
113       // The pieces of AccumulateNV2 should all be on the same node.
114       node_builder.Device(n->requested_device());
115       const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName);
116       if (!colo.empty()) {
117         node_builder.Attr(kColocationAttrName, colo);
118       }
119       return node_builder;
120     };
121     auto make_node = [n, g, &base_make_node](string op) {
122       return base_make_node(
123           op, g->NewName(strings::StrCat(n->name(), "/Internal")));
124     };
125 
126     DataType dtype;
127     TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
128     TensorShapeProto shape;
129     TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape));
130 
131     std::vector<const Edge*> data_edges, control_edges;
132     for (const Edge* input_edge : n->in_edges()) {
133       if (input_edge->IsControlEdge()) {
134         control_edges.push_back(input_edge);
135       } else {
136         data_edges.push_back(input_edge);
137       }
138     }
139 
140     // Create the following ops to replace the AccumulateNV2 placeholder:
141     Node* create_accumulator = nullptr;            // TemporaryVariable op
142     Node* initial_val = nullptr;                   // Const op
143     Node* initialize_accumulator = nullptr;        // Assign op
144     std::vector<Node*> add_values_to_accumulator;  // AssignAdd ops
145     Node* clean_up_accumulator = nullptr;          // DestroyTemporaryVariable
146 
147     const string accumulator_name =
148         strings::StrCat(n->name(), "/Internal/Accumulator");
149     TensorShapeProto variable_shape;
150     variable_shape.add_dim()->set_size(0);
151     TF_RETURN_IF_ERROR(make_node("TemporaryVariable")
152                            .Attr("shape", variable_shape)
153                            .Attr("dtype", dtype)
154                            .Attr("var_name", accumulator_name)
155                            .Finalize(g, &create_accumulator));
156     PartialTensorShape partial_shape(shape);
157     // Make a Fill operation to make a zero tensor with the shape of the first
158     // input.
159     Node* shape_node;
160     TF_RETURN_IF_ERROR(
161         make_node("Shape")
162             .Input(data_edges[0]->src(), data_edges[0]->src_output())
163             .Finalize(g, &shape_node));
164     Node* zero;
165     TF_RETURN_IF_ERROR(make_node("Const")
166                            .Attr("value", make_zeros(dtype, TensorShapeProto()))
167                            .Attr("dtype", dtype)
168                            .Finalize(g, &zero));
169     TF_RETURN_IF_ERROR(make_node("Fill")
170                            .Input(shape_node)
171                            .Input(zero)
172                            .Finalize(g, &initial_val));
173     TF_RETURN_IF_ERROR(make_node("Assign")
174                            .Attr("T", dtype)
175                            .Input(create_accumulator)  // ref: Ref(T)
176                            .Input(initial_val)         // value: T
177                            .Attr("validate_shape", false)
178                            .Finalize(g, &initialize_accumulator));
179     for (int i = 0; i < data_edges.size(); ++i) {
180       Node* assignAdd;
181       TF_RETURN_IF_ERROR(make_node("AssignAdd")
182                              .Attr("T", dtype)
183                              .Attr("use_locking", true)
184                              .Input(initialize_accumulator)  // ref: Ref(T)
185                              .Input(data_edges[i]->src(),
186                                     data_edges[i]->src_output())  // value: T
187                              .Finalize(g, &assignAdd));
188 
189       add_values_to_accumulator.push_back(assignAdd);
190     }
191 
192     // Note that we use the original placeholder op's name here
193     TF_RETURN_IF_ERROR(base_make_node("DestroyTemporaryVariable", n->name())
194                            .Attr("T", dtype)
195                            .Attr("var_name", accumulator_name)
196                            .Input(initialize_accumulator)
197                            .Finalize(g, &clean_up_accumulator));
198 
199     // Add edges to the graph to ensure that operations occur in the right
200     // order:
201     // 1. Do anything that had a control edge to the AccumulateNV2 placeholder
202     // 2. Initialize accumulator
203     // 3. Add input values to accumulator (already handled by data edges
204     //    added above)
205     // 4. Reclaim the buffer that held the accumulator
206     // 5. Do anything that depended on the AccumulateNV2 placeholder
207     for (const Edge* control_edge : control_edges) {
208       g->AddControlEdge(control_edge->src(), initialize_accumulator);
209     }
210 
211     for (Node* assign_add : add_values_to_accumulator) {
212       g->AddControlEdge(assign_add, clean_up_accumulator);
213     }
214 
215     for (const Edge* out_edge : n->out_edges()) {
216       if (out_edge->IsControlEdge()) {
217         g->AddControlEdge(clean_up_accumulator, out_edge->dst());
218       } else {
219         g->AddEdge(clean_up_accumulator, 0, out_edge->dst(),
220                    out_edge->dst_input());
221       }
222     }
223 
224     // Remove the original AccumulateNV2 placeholder op.
225     // This removal modifies the op and must happen after we have finished
226     // using its incoming/outgoing edge sets.
227     g->RemoveNode(n);
228 
229     return Status::OK();
230   }
231 
RewriteIntoAddN(Node * n,Graph * g)232   Status RewriteIntoAddN(Node* n, Graph* g) {
233     VLOG(3) << "Rewrite AccumulateNV2 into AddN: " << SummarizeNode(*n);
234 
235     AttrSlice n_attrs = n->attrs();
236     DataType dtype;
237     TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
238     int num_inputs;
239     TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "N", &num_inputs));
240 
241     Node* add_n_node = nullptr;
242 
243     std::vector<NodeBuilder::NodeOut> data_inputs;
244     std::vector<Node*> control_inputs;
245     data_inputs.reserve(n->num_inputs());
246     control_inputs.reserve(n->in_edges().size() - n->num_inputs());
247     for (const Edge* in_edge : n->in_edges()) {
248       if (in_edge->IsControlEdge()) {
249         control_inputs.push_back(in_edge->src());
250       } else {
251         data_inputs.emplace_back(in_edge->src(), in_edge->src_output());
252       }
253     }
254 
255     // Rewrite `AccumulateNV2` node into `AddN` node.
256     NodeDebugInfo debug_info(*n);
257     NodeBuilder builder =
258         NodeBuilder(n->name(), "AddN", OpRegistry::Global(), &debug_info)
259             .Device(n->requested_device())
260             .Attr("N", num_inputs)
261             .Attr("T", dtype)
262             .Input(data_inputs)
263             .ControlInputs(control_inputs);
264     const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName);
265     if (!colo.empty()) {
266       builder.Attr(kColocationAttrName, colo);
267     }
268     TF_RETURN_IF_ERROR(builder.Finalize(g, &add_n_node));
269 
270     // Forward all consumers to the new node.
271     for (const Edge* out_edge : n->out_edges()) {
272       if (out_edge->IsControlEdge()) {
273         g->AddControlEdge(add_n_node, out_edge->dst());
274       } else {
275         g->AddEdge(add_n_node, 0, out_edge->dst(), out_edge->dst_input());
276       }
277     }
278 
279     // Remove the original AccumulateNV2 placeholder op.
280     // This removal modifies the op and must happen after we have finished
281     // using its incoming/outgoing edge sets.
282     g->RemoveNode(n);
283 
284     return Status::OK();
285   }
286 };
287 REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10,
288                       AccumulateNV2RemovePass);
289 
290 }  // namespace
291 }  // namespace tensorflow
292