/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/immutable_executor_state.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {
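// Identifies "initialization" ops: those whose OpDef permits uninitialized
// inputs (for example, ops that assign a variable its initial value).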
bool IsInitializationOp(const Node* node) {
  return node->op_def().allows_uninitialized_input();
}
}  // namespace

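// Deletes the per-node kernels that were created in Initialize().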
ImmutableExecutorState::~ImmutableExecutorState() {
  for (int32 i = 0; i < gview_.num_nodes(); i++) {
    NodeItem* item = gview_.node(i);
    if (item != nullptr) {
      params_.delete_kernel(item->kernel);
    }
  }
}

namespace {
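// Computes the largest pending count and dead-input count that will ever be
// stored for node `n`. Merge nodes use bit 0 of the pending count as a
// "live data input ready" flag, so their count is packed as
// 1 + 2 * num_control_edges.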
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
                         size_t* max_dead_count) {
  const size_t num_in_edges = n->in_edges().size();
  size_t initial_count;
  if (IsMerge(n)) {
    // A merge node waits on all of its control inputs, so we initialize the
    // pending count to the number of control edges.
    int32 num_control_edges = 0;
    for (const Edge* edge : n->in_edges()) {
      if (edge->IsControlEdge()) {
        num_control_edges++;
      }
    }
    // Use bit 0 to indicate if we are waiting for a ready live data input.
    initial_count = 1 + (num_control_edges << 1);
  } else {
    initial_count = num_in_edges;
  }

  *max_pending = initial_count;
  *max_dead_count = num_in_edges;
}
}  // namespace

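// Returns the FrameInfo for the frame named `fname`, creating it on first
// use. The map is keyed by a string_view that points into the FrameInfo's own
// `name` member, so the key stays valid for the lifetime of the entry.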
ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
    const string& fname) {
  auto iter = frame_info_.find(fname);
  if (iter != frame_info_.end()) {
    return iter->second.get();
  } else {
    auto frame_info = absl::make_unique<FrameInfo>(fname);
    absl::string_view fname_view = frame_info->name;
    auto emplace_result =
        frame_info_.emplace(fname_view, std::move(frame_info));
    return emplace_result.first->second.get();
  }
}

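// Builds the immutable, per-node state for `graph`: creates an OpKernel for
// every node, assigns each node's inputs a slot range within its frame,
// records control-flow properties, and sets up pending-count handles.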
Status ImmutableExecutorState::Initialize(const Graph& graph) {
  TF_RETURN_IF_ERROR(gview_.Initialize(&graph));

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes =
        absl::make_unique<std::vector<const NodeItem*>>();
  }
  root_frame_info_ = frame_info_[""].get();

  pending_ids_.resize(gview_.num_nodes());

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  requires_control_flow_ = false;
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
      requires_control_flow_ = true;
    } else if (IsRecv(n)) {
      // A Recv node from a different device may produce dead tensors from
      // non-local control-flow nodes.
      //
      // TODO(mrry): Track whether control flow was present in the
      // pre-partitioned graph, and enable the caller (e.g.
      // `DirectSession`) to relax this constraint.
      string send_device;
      string recv_device;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
      if (send_device != recv_device) {
        requires_control_flow_ = true;
      }
    }

    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    NodeItem* item = gview_.node(id);
    item->node_id = id;

    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->properties(), &item->kernel);
    if (!s.ok()) {
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_any_consumer_merge_or_control_trigger = false;
    for (const Node* consumer : n->out_nodes()) {
      if (IsMerge(consumer) || IsControlTrigger(consumer)) {
        item->is_any_consumer_merge_or_control_trigger = true;
        break;
      }
    }
    const Tensor* const_tensor = item->kernel->const_tensor();
    if (const_tensor) {
      // Hold onto a shallow copy of the constant tensor in `*this` so that the
      // reference count does not drop to 1. This prevents the constant tensor
      // from being forwarded, and its buffer reused.
      const_tensors_.emplace_back(*const_tensor);
    }
    item->const_tensor = const_tensor;
    item->is_noop = (item->kernel->type_string_view() == "NoOp");
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;

      string frame_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
      FrameInfo* frame_info = frame_info_[frame_name].get();

      int parallel_iterations;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));

      if (frame_info->parallel_iterations == -1) {
        frame_info->parallel_iterations = parallel_iterations;
      } else if (frame_info->parallel_iterations != parallel_iterations) {
        LOG(WARNING) << "Loop frame \"" << frame_name
                     << "\" had two different values for parallel_iterations: "
                     << frame_info->parallel_iterations << " vs. "
                     << parallel_iterations << ".";
      }

      if (enter_frame_info_.size() <= id) {
        enter_frame_info_.resize(id + 1);
      }
      enter_frame_info_[id] = frame_info;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_source = IsSource(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    item->is_transfer_node = IsTransferNode(n);
    item->is_initialization_op = IsInitializationOp(n);
    item->is_recv_or_switch = IsRecv(n) || IsSwitch(n);
    item->is_next_iteration = IsNextIteration(n);

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    pending_ids_[id] =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // See if this node is a root node, and if so, add item to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(item);
    }

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(item);
    if (item->is_enter) {
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }

    // Record information about whether each output of the op is used.
    std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]);
    std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false);
    int32 unused_outputs = n->num_outputs();
    for (const Edge* e : n->out_edges()) {
      if (IsSink(e->dst())) continue;
      if (e->src_output() >= 0) {
        if (!outputs_required[e->src_output()]) {
          --unused_outputs;
          outputs_required[e->src_output()] = true;
        }
      }
    }
    if (unused_outputs > 0) {
      for (int i = 0; i < n->num_outputs(); ++i) {
        if (!outputs_required[i]) {
          metrics::RecordUnusedOutput(n->type_string());
        }
      }
      item->outputs_required = std::move(outputs_required);
    }
  }

  // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input
  // location.
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    NodeItem* item = gview_.node(id);

    for (EdgeInfo& e : item->mutable_output_edges()) {
      const int dst_id = e.dst_id;
      NodeItem* dst_item = gview_.node(dst_id);
      e.input_slot += dst_item->input_start;
    }
  }

  // Initialize PendingCounts only after pending_ids_[node.id] is initialized
  // for all nodes.
  InitializePending(&graph, cf_info);
  return gview_.SetAllocAttrs(&graph, params_.device);
}

namespace {
// If a Node has been marked to use a ScopedAllocator x for output i, then
// sc_attr will contain the subsequence (i, x) at an even offset. This function
// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
// only allow one ScopedAllocator use per Node.
bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
                                int output_index,
                                AllocatorAttributes* alloc_attr) {
  DCHECK_LE(2, sc_attr.size());
  for (int i = 0; i < sc_attr.size(); i += 2) {
    if (sc_attr[i] == output_index) {
      CHECK_EQ(alloc_attr->scope_id, 0);
      alloc_attr->scope_id = sc_attr[i + 1];
      return true;
    }
  }
  return false;
}
}  // namespace

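// Assigns every reachable node in `g` to its control-flow frame via a
// breadth-first traversal from the root nodes. An Enter node starts the child
// frame named by its "frame_name" attribute, an Exit node returns to the
// parent frame, and all other nodes inherit the frame of the node that first
// reached them.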
Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g,
                                                    ControlFlowInfo* cf_info) {
  const int num_nodes = g->num_node_ids();
  cf_info->frame_names.resize(num_nodes);
  std::vector<Node*> parent_nodes;
  parent_nodes.resize(num_nodes);
  std::vector<bool> visited;
  visited.resize(num_nodes);

  string frame_name;
  std::deque<Node*> ready;

  // Initialize with the root nodes.
  for (Node* n : g->nodes()) {
    if (n->in_edges().empty()) {
      visited[n->id()] = true;
      cf_info->unique_frame_names.insert(frame_name);
      ready.push_back(n);
    }
  }

  while (!ready.empty()) {
    Node* curr_node = ready.front();
    int curr_id = curr_node->id();
    ready.pop_front();

    Node* parent = nullptr;
    if (IsEnter(curr_node)) {
      // Enter a child frame.
      TF_RETURN_IF_ERROR(
          GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
      parent = curr_node;
    } else if (IsExit(curr_node)) {
      // Exit to the parent frame.
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[parent->id()];
      parent = parent_nodes[parent->id()];
    } else {
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[curr_id];
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      if (IsSink(out)) continue;
      const int out_id = out->id();

      // Add to ready queue if not visited.
      bool is_visited = visited[out_id];
      if (!is_visited) {
        ready.push_back(out);
        visited[out_id] = true;

        // Process the node 'out'.
        cf_info->frame_names[out_id] = frame_name;
        parent_nodes[out_id] = parent;
        cf_info->unique_frame_names.insert(frame_name);
      }
    }
  }

  return Status::OK();
}

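// Allocates each frame's PendingCounts structure and sets every node's
// initial pending count. When the graph requires no control flow, a flat
// array of atomic counters indexed by node id is also filled with the same
// initial counts.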
void ImmutableExecutorState::InitializePending(const Graph* graph,
                                               const ControlFlowInfo& cf_info) {
  for (auto& it : cf_info.unique_frame_names) {
    FrameInfo* finfo = EnsureFrameInfo(it);
    DCHECK_EQ(finfo->pending_counts.get(), nullptr);
    finfo->pending_counts =
        absl::make_unique<PendingCounts>(finfo->pending_counts_layout);
  }

  if (!requires_control_flow_) {
    atomic_pending_counts_.reset(new std::atomic<int32>[gview_.num_nodes()]);
    std::fill(atomic_pending_counts_.get(),
              atomic_pending_counts_.get() + gview_.num_nodes(), 0);
  }

  for (const Node* n : graph->nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    const string& name = cf_info.frame_names[id];
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    auto& counts = EnsureFrameInfo(name)->pending_counts;
    counts->set_initial_count(pending_ids_[id], max_pending);
    if (!requires_control_flow_) {
      atomic_pending_counts_[id] = max_pending;
    }
  }
}
}  // namespace tensorflow