/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/immutable_executor_state.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {
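// Returns true if the node's op allows uninitialized inputs, which is how
// initialization ops such as Assign are registered.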
bool IsInitializationOp(const Node* node) {
  return node->op_def().allows_uninitialized_input();
}
}  // namespace

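// Deletes every kernel that was created for a node during Initialize().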
ImmutableExecutorState::~ImmutableExecutorState() {
  for (int32_t i = 0; i < gview_.num_nodes(); i++) {
    NodeItem* item = gview_.node(i);
    if (item != nullptr) {
      params_.delete_kernel(item->kernel);
    }
  }
}

namespace {
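// Computes the maximum pending count and dead-input count that the
// PendingCounts data structure must be able to represent for node `n`. For
// example, a Merge node with two data inputs and three control inputs gets
// *max_pending = 1 + (3 << 1) = 7 and *max_dead_count = 5.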
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
                         size_t* max_dead_count) {
  const size_t num_in_edges = n->in_edges().size();
  size_t initial_count;
  if (IsMerge(n)) {
    // Merge waits for all of its control inputs, so initialize the pending
    // count to the number of control edges.
    int32_t num_control_edges = 0;
    for (const Edge* edge : n->in_edges()) {
      if (edge->IsControlEdge()) {
        num_control_edges++;
      }
    }
    // Use bit 0 to indicate if we are waiting for a ready live data input.
    initial_count = 1 + (num_control_edges << 1);
  } else {
    initial_count = num_in_edges;
  }

  *max_pending = initial_count;
  *max_dead_count = num_in_edges;
}
}  // namespace

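// Returns the FrameInfo for `fname`, creating it on first use. The map key is
// a string_view into the name string owned by the FrameInfo value itself, so
// the key remains valid for as long as the entry exists.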
ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
    const string& fname) {
  auto iter = frame_info_.find(fname);
  if (iter != frame_info_.end()) {
    return iter->second.get();
  } else {
    auto frame_info = absl::make_unique<FrameInfo>(fname);
    absl::string_view fname_view = frame_info->name;
    auto emplace_result =
        frame_info_.emplace(fname_view, std::move(frame_info));
    return emplace_result.first->second.get();
  }
}

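// Builds all immutable per-graph state: the GraphView of NodeItems, one
// OpKernel per node, per-frame bookkeeping, pending-count layouts, and the
// flattened input offsets used at execution time.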
Status ImmutableExecutorState::Initialize(const Graph& graph) {
  TF_RETURN_IF_ERROR(gview_.Initialize(&graph));

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes =
        absl::make_unique<std::vector<const NodeItem*>>();
  }
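  // The root frame is identified by the empty frame name.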
  root_frame_info_ = frame_info_[""].get();

  pending_ids_.resize(gview_.num_nodes());

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
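  // While doing so, detect whether the graph needs control-flow support: it
  // does if the graph contains Switch, Merge, Enter, or Exit nodes, or a Recv
  // whose matching Send lives on a different device, since such a Recv may
  // deliver dead tensors produced by non-local control flow.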
  requires_control_flow_ = false;
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
      requires_control_flow_ = true;
    } else if (IsRecv(n)) {
      // A Recv node from a different device may produce dead tensors from
      // non-local control-flow nodes.
      //
      // TODO(mrry): Track whether control flow was present in the
      // pre-partitioned graph, and enable the caller (e.g.
      // `DirectSession`) to relax this constraint.
      string send_device;
      string recv_device;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
      if (send_device != recv_device) {
        requires_control_flow_ = true;
      }
    }

    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    NodeItem* item = gview_.node(id);
    item->node_id = id;

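    // `input_start` is the offset of this node's inputs within the frame's
    // flattened input array; `total_inputs` accumulates the frame-wide size.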
    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->properties(), &item->kernel);
    if (!s.ok()) {
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_any_consumer_merge_or_control_trigger = false;
    for (const Node* consumer : n->out_nodes()) {
      if (IsMerge(consumer) || IsControlTrigger(consumer)) {
        item->is_any_consumer_merge_or_control_trigger = true;
        break;
      }
    }
    const Tensor* const_tensor = item->kernel->const_tensor();
    if (const_tensor) {
      // Hold onto a shallow copy of the constant tensor in `*this` so that
      // its reference count does not drop to 1. This prevents the constant
      // tensor from being forwarded and its buffer from being reused.
      const_tensors_.emplace_back(*const_tensor);
    }
    item->const_tensor = const_tensor;
    item->is_noop = (item->kernel->type_string_view() == "NoOp");
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;

      string frame_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
      FrameInfo* frame_info = frame_info_[frame_name].get();

      int parallel_iterations;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));

      if (frame_info->parallel_iterations == -1) {
        frame_info->parallel_iterations = parallel_iterations;
      } else if (frame_info->parallel_iterations != parallel_iterations) {
        LOG(WARNING) << "Loop frame \"" << frame_name
                     << "\" had two different values for parallel_iterations: "
                     << frame_info->parallel_iterations << " vs. "
                     << parallel_iterations << ".";
      }

      if (enter_frame_info_.size() <= id) {
        enter_frame_info_.resize(id + 1);
      }
      enter_frame_info_[id] = frame_info;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_source = IsSource(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    item->is_transfer_node = IsTransferNode(n);
    item->is_initialization_op = IsInitializationOp(n);
    item->is_recv_or_switch = IsRecv(n) || IsSwitch(n);
    item->is_next_iteration = IsNextIteration(n);
    item->is_distributed_communication = IsDistributedCommunication(n);

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    pending_ids_[id] =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // See if this node is a root node, and if so, add item to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(item);
    }

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(item);
    if (item->is_enter) {
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }

    // Record information about whether each output of the op is used.
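    // `item->outputs_required` is only populated below when at least one
    // output is unused, so a null pointer means every output is required.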
    std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]);
    std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false);
    int32_t unused_outputs = n->num_outputs();
    for (const Edge* e : n->out_edges()) {
      if (IsSink(e->dst())) continue;
      if (e->src_output() >= 0) {
        if (!outputs_required[e->src_output()]) {
          --unused_outputs;
          outputs_required[e->src_output()] = true;
        }
      }
    }
    if (unused_outputs > 0) {
      for (int i = 0; i < n->num_outputs(); ++i) {
        if (!outputs_required[i]) {
          metrics::RecordUnusedOutput(n->type_string());
        }
      }
      item->outputs_required = std::move(outputs_required);
    }
  }

  // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input
  // location.
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    NodeItem* item = gview_.node(id);

    for (EdgeInfo& e : item->mutable_output_edges()) {
      const int dst_id = e.dst_id;
      NodeItem* dst_item = gview_.node(dst_id);
      e.input_slot += dst_item->input_start;
    }
  }

  // Initialize PendingCounts only after pending_ids_[node.id] is initialized
  // for all nodes.
  InitializePending(&graph, cf_info);
  return gview_.SetAllocAttrs(&graph, params_.device);
}

namespace {
// If a Node has been marked to use a ScopedAllocator x for output i, then
// sc_attr will contain the subsequence (i, x) at an even offset. This function
// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
// only allow one ScopedAllocator use per Node.
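// For example, sc_attr = {0, 42} directs output 0 to the ScopedAllocator with
// scope_id 42.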
bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
                                int output_index,
                                AllocatorAttributes* alloc_attr) {
  DCHECK_LE(2, sc_attr.size());
  for (int i = 0; i < sc_attr.size(); i += 2) {
    if (sc_attr[i] == output_index) {
      CHECK_EQ(alloc_attr->scope_id, 0);
      alloc_attr->scope_id = sc_attr[i + 1];
      return true;
    }
  }
  return false;
}
}  // namespace

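// Performs a breadth-first traversal from the root nodes (nodes with no
// in-edges), assigning each reachable node the name of the control-flow frame
// it belongs to. Enter nodes start a child frame named by their "frame_name"
// attribute; Exit nodes return to the parent frame.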
Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g,
                                                    ControlFlowInfo* cf_info) {
  const int num_nodes = g->num_node_ids();
  cf_info->frame_names.resize(num_nodes);
  std::vector<Node*> parent_nodes;
  parent_nodes.resize(num_nodes);
  std::vector<bool> visited;
  visited.resize(num_nodes);

  string frame_name;
  std::deque<Node*> ready;

  // Initialize with the root nodes.
  for (Node* n : g->nodes()) {
    if (n->in_edges().empty()) {
      visited[n->id()] = true;
      cf_info->unique_frame_names.insert(frame_name);
      ready.push_back(n);
    }
  }

  while (!ready.empty()) {
    Node* curr_node = ready.front();
    int curr_id = curr_node->id();
    ready.pop_front();

    Node* parent = nullptr;
    if (IsEnter(curr_node)) {
      // Enter a child frame.
      TF_RETURN_IF_ERROR(
          GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
      parent = curr_node;
    } else if (IsExit(curr_node)) {
      // Exit to the parent frame.
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[parent->id()];
      parent = parent_nodes[parent->id()];
    } else {
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[curr_id];
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      if (IsSink(out)) continue;
      const int out_id = out->id();

      // Add to ready queue if not visited.
      bool is_visited = visited[out_id];
      if (!is_visited) {
        ready.push_back(out);
        visited[out_id] = true;

        // Process the node 'out'.
        cf_info->frame_names[out_id] = frame_name;
        parent_nodes[out_id] = parent;
        cf_info->unique_frame_names.insert(frame_name);
      }
    }
  }

  return Status::OK();
}

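// Allocates a PendingCounts structure for every frame and sets each node's
// initial pending count to the value computed by GetMaxPendingCounts(). For
// graphs without control flow, a flat array of atomic counters indexed by
// node id is also populated.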
void ImmutableExecutorState::InitializePending(const Graph* graph,
                                               const ControlFlowInfo& cf_info) {
  for (auto& it : cf_info.unique_frame_names) {
    FrameInfo* finfo = EnsureFrameInfo(it);
    DCHECK_EQ(finfo->pending_counts.get(), nullptr);
    finfo->pending_counts =
        absl::make_unique<PendingCounts>(finfo->pending_counts_layout);
  }

  if (!requires_control_flow_) {
    atomic_pending_counts_.reset(new std::atomic<int32>[gview_.num_nodes()]);
    std::fill(atomic_pending_counts_.get(),
              atomic_pending_counts_.get() + gview_.num_nodes(), 0);
  }

  for (const Node* n : graph->nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    const string& name = cf_info.frame_names[id];
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    auto& counts = EnsureFrameInfo(name)->pending_counts;
    counts->set_initial_count(pending_ids_[id], max_pending);
    if (!requires_control_flow_) {
      atomic_pending_counts_[id] = max_pending;
    }
  }
}
}  // namespace tensorflow