• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/immutable_executor_state.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/core/framework/function.h"
20 #include "tensorflow/core/framework/metrics.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/graph/edgeset.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/graph_node_util.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/core/platform/logging.h"
27 
28 namespace tensorflow {
29 
30 namespace {
IsInitializationOp(const Node * node)31 bool IsInitializationOp(const Node* node) {
32   return node->op_def().allows_uninitialized_input();
33 }
34 }  // namespace
35 
~ImmutableExecutorState()36 ImmutableExecutorState::~ImmutableExecutorState() {
37   for (int32_t i = 0; i < gview_.num_nodes(); i++) {
38     NodeItem* item = gview_.node(i);
39     if (item != nullptr) {
40       params_.delete_kernel(item->kernel);
41     }
42   }
43 }
44 
45 namespace {
GetMaxPendingCounts(const Node * n,size_t * max_pending,size_t * max_dead_count)46 void GetMaxPendingCounts(const Node* n, size_t* max_pending,
47                          size_t* max_dead_count) {
48   const size_t num_in_edges = n->in_edges().size();
49   size_t initial_count;
50   if (IsMerge(n)) {
51     // merge waits all control inputs so we initialize the pending
52     // count to be the number of control edges.
53     int32_t num_control_edges = 0;
54     for (const Edge* edge : n->in_edges()) {
55       if (edge->IsControlEdge()) {
56         num_control_edges++;
57       }
58     }
59     // Use bit 0 to indicate if we are waiting for a ready live data input.
60     initial_count = 1 + (num_control_edges << 1);
61   } else {
62     initial_count = num_in_edges;
63   }
64 
65   *max_pending = initial_count;
66   *max_dead_count = num_in_edges;
67 }
68 }  // namespace
69 
EnsureFrameInfo(const string & fname)70 ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
71     const string& fname) {
72   auto iter = frame_info_.find(fname);
73   if (iter != frame_info_.end()) {
74     return iter->second.get();
75   } else {
76     auto frame_info = absl::make_unique<FrameInfo>(fname);
77     absl::string_view fname_view = frame_info->name;
78     auto emplace_result =
79         frame_info_.emplace(fname_view, std::move(frame_info));
80     return emplace_result.first->second.get();
81   }
82 }
83 
// Builds all immutable per-node state for `graph`: the graph view, control
// flow (frame) information, one op kernel per node, per-frame input layout,
// pending-count handles, and the set of root nodes. Returns the first error
// encountered (e.g. kernel creation or attribute lookup failure).
Status ImmutableExecutorState::Initialize(const Graph& graph) {
  TF_RETURN_IF_ERROR(gview_.Initialize(&graph));

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

  // Register every discovered frame and give it an (initially empty) node
  // list, so all frame names used below have a FrameInfo entry.
  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes =
        absl::make_unique<std::vector<const NodeItem*>>();
  }
  // The root frame has the empty name; it was registered by the loop above.
  root_frame_info_ = frame_info_[""].get();

  pending_ids_.resize(gview_.num_nodes());

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  requires_control_flow_ = false;
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    // Any explicit control-flow op forces the frame-aware propagation path.
    if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
      requires_control_flow_ = true;
    } else if (IsRecv(n)) {
      // A Recv node from a different device may produce dead tensors from
      // non-local control-flow nodes.
      //
      // TODO(mrry): Track whether control flow was present in the
      // pre-partitioned graph, and enable the caller (e.g.
      // `DirectSession`) to relax this constraint.
      string send_device;
      string recv_device;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
      if (send_device != recv_device) {
        requires_control_flow_ = true;
      }
    }

    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    NodeItem* item = gview_.node(id);
    item->node_id = id;

    // Inputs of all nodes in a frame are laid out contiguously; record this
    // node's offset into that layout and grow the frame's running total.
    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->properties(), &item->kernel);
    if (!s.ok()) {
      // Null out the kernel pointer so the destructor does not try to delete
      // a stale value, and attach the failing node's definition to the error.
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    // Record whether any consumer is a Merge or ControlTrigger, since those
    // consumers accept dead inputs and affect output propagation.
    item->is_any_consumer_merge_or_control_trigger = false;
    for (const Node* consumer : n->out_nodes()) {
      if (IsMerge(consumer) || IsControlTrigger(consumer)) {
        item->is_any_consumer_merge_or_control_trigger = true;
        break;
      }
    }
    const Tensor* const_tensor = item->kernel->const_tensor();
    if (const_tensor) {
      // Hold onto a shallow copy of the constant tensor in `*this` so that the
      // reference count does not drop to 1. This prevents the constant tensor
      // from being forwarded, and its buffer reused.
      const_tensors_.emplace_back(*const_tensor);
    }
    item->const_tensor = const_tensor;
    item->is_noop = (item->kernel->type_string_view() == "NoOp");
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;

      // Note: this inner `frame_name` is the frame the Enter node *creates*,
      // shadowing the outer `frame_name` (the frame the node lives in).
      string frame_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
      // NOTE(review): operator[] default-inserts a null unique_ptr if this
      // frame name was never registered, which would make `frame_info` null
      // below. This relies on BuildControlFlowInfo having recorded every
      // Enter's target frame — confirm, or use EnsureFrameInfo here.
      FrameInfo* frame_info = frame_info_[frame_name].get();

      int parallel_iterations;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));

      // All Enter nodes of the same frame should agree on
      // parallel_iterations; the first value seen wins and disagreements are
      // only warned about.
      if (frame_info->parallel_iterations == -1) {
        frame_info->parallel_iterations = parallel_iterations;
      } else if (frame_info->parallel_iterations != parallel_iterations) {
        LOG(WARNING) << "Loop frame \"" << frame_name
                     << "\" had two different values for parallel_iterations: "
                     << frame_info->parallel_iterations << " vs. "
                     << parallel_iterations << ".";
      }

      // enter_frame_info_ is indexed by node id; grow it lazily.
      if (enter_frame_info_.size() <= id) {
        enter_frame_info_.resize(id + 1);
      }
      enter_frame_info_[id] = frame_info;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_source = IsSource(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    item->is_transfer_node = IsTransferNode(n);
    item->is_initialization_op = IsInitializationOp(n);
    item->is_recv_or_switch = IsRecv(n) || IsSwitch(n);
    item->is_next_iteration = IsNextIteration(n);
    item->is_distributed_communication = IsDistributedCommunication(n);

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    pending_ids_[id] =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // See if this node is a root node, and if so, add item to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(item);
    }

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(item);
    if (item->is_enter) {
      // Count how many Enter nodes feed each frame.
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }

    // Record information about whether each output of the op is used.
    std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]);
    std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false);
    int32_t unused_outputs = n->num_outputs();
    // An output is required iff some non-sink consumer reads it. Control
    // edges (src_output < 0) do not mark any data output as required.
    for (const Edge* e : n->out_edges()) {
      if (IsSink(e->dst())) continue;
      if (e->src_output() >= 0) {
        if (!outputs_required[e->src_output()]) {
          --unused_outputs;
          outputs_required[e->src_output()] = true;
        }
      }
    }
    // Only store the bitmap when at least one output is unused; a null
    // `outputs_required` means "all outputs required".
    if (unused_outputs > 0) {
      for (int i = 0; i < n->num_outputs(); ++i) {
        if (!outputs_required[i]) {
          metrics::RecordUnusedOutput(n->type_string());
        }
      }
      item->outputs_required = std::move(outputs_required);
    }
  }

  // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input
  // location.
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    NodeItem* item = gview_.node(id);

    for (EdgeInfo& e : item->mutable_output_edges()) {
      const int dst_id = e.dst_id;
      NodeItem* dst_item = gview_.node(dst_id);
      e.input_slot += dst_item->input_start;
    }
  }

  // Initialize PendingCounts only after pending_ids_[node.id] is initialized
  // for all nodes.
  InitializePending(&graph, cf_info);
  return gview_.SetAllocAttrs(&graph, params_.device);
}
263 
264 namespace {
265 // If a Node has been marked to use a ScopedAllocator x for output i, then
266 // sc_attr will contain the subsequence (i, x) at an even offset.  This function
267 // extracts and transfers that ScopedAllocator id to alloc_attr.  For now, we
268 // only allow one ScopedAllocator use per Node.
ExtractScopedAllocatorAttr(const std::vector<int> & sc_attr,int output_index,AllocatorAttributes * alloc_attr)269 bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
270                                 int output_index,
271                                 AllocatorAttributes* alloc_attr) {
272   DCHECK_LE(2, sc_attr.size());
273   for (int i = 0; i < sc_attr.size(); i += 2) {
274     if (sc_attr[i] == output_index) {
275       CHECK_EQ(alloc_attr->scope_id, 0);
276       alloc_attr->scope_id = sc_attr[i + 1];
277       return true;
278     }
279   }
280   return false;
281 }
282 }  // namespace
283 
// Performs a breadth-first traversal of `g` from its root nodes (nodes with
// no in-edges) to assign every node its control-flow frame name, recorded in
// `cf_info->frame_names` (indexed by node id). Frame nesting is tracked via
// `parent_nodes`: entering a frame at an Enter node pushes the Enter as the
// parent, and an Exit node pops back to the parent's frame. All distinct
// frame names are collected in `cf_info->unique_frame_names`.
Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g,
                                                    ControlFlowInfo* cf_info) {
  const int num_nodes = g->num_node_ids();
  cf_info->frame_names.resize(num_nodes);
  // parent_nodes[id] is the Enter node of the frame that *encloses* node id's
  // frame (null for nodes in the root frame).
  std::vector<Node*> parent_nodes;
  parent_nodes.resize(num_nodes);
  std::vector<bool> visited;
  visited.resize(num_nodes);

  // `frame_name` starts empty, which is the root frame's name.
  string frame_name;
  std::deque<Node*> ready;

  // Initialize with the root nodes.
  for (Node* n : g->nodes()) {
    if (n->in_edges().empty()) {
      visited[n->id()] = true;
      cf_info->unique_frame_names.insert(frame_name);
      ready.push_back(n);
    }
  }

  while (!ready.empty()) {
    Node* curr_node = ready.front();
    int curr_id = curr_node->id();
    ready.pop_front();

    Node* parent = nullptr;
    if (IsEnter(curr_node)) {
      // Enter a child frame.
      TF_RETURN_IF_ERROR(
          GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
      parent = curr_node;
    } else if (IsExit(curr_node)) {
      // Exit to the parent frame.
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[parent->id()];
      parent = parent_nodes[parent->id()];
    } else {
      // Ordinary node: successors inherit this node's frame and parent.
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[curr_id];
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      if (IsSink(out)) continue;
      const int out_id = out->id();

      // Add to ready queue if not visited.
      bool is_visited = visited[out_id];
      if (!is_visited) {
        ready.push_back(out);
        visited[out_id] = true;

        // Process the node 'out'.
        cf_info->frame_names[out_id] = frame_name;
        parent_nodes[out_id] = parent;
        cf_info->unique_frame_names.insert(frame_name);
      }
    }
  }

  return Status::OK();
}
347 
InitializePending(const Graph * graph,const ControlFlowInfo & cf_info)348 void ImmutableExecutorState::InitializePending(const Graph* graph,
349                                                const ControlFlowInfo& cf_info) {
350   for (auto& it : cf_info.unique_frame_names) {
351     FrameInfo* finfo = EnsureFrameInfo(it);
352     DCHECK_EQ(finfo->pending_counts.get(), nullptr);
353     finfo->pending_counts =
354         absl::make_unique<PendingCounts>(finfo->pending_counts_layout);
355   }
356 
357   if (!requires_control_flow_) {
358     atomic_pending_counts_.reset(new std::atomic<int32>[gview_.num_nodes()]);
359     std::fill(atomic_pending_counts_.get(),
360               atomic_pending_counts_.get() + gview_.num_nodes(), 0);
361   }
362 
363   for (const Node* n : graph->nodes()) {
364     if (IsSink(n)) continue;
365     const int id = n->id();
366     const string& name = cf_info.frame_names[id];
367     size_t max_pending, max_dead;
368     GetMaxPendingCounts(n, &max_pending, &max_dead);
369     auto& counts = EnsureFrameInfo(name)->pending_counts;
370     counts->set_initial_count(pending_ids_[id], max_pending);
371     if (!requires_control_flow_) {
372       atomic_pending_counts_[id] = max_pending;
373     }
374   }
375 }
376 }  // namespace tensorflow
377