/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/immutable_executor_state.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {
bool IsInitializationOp(const Node* node) {
  return node->op_def().allows_uninitialized_input();
}
}  // namespace

ImmutableExecutorState::~ImmutableExecutorState() {
  for (int32 i = 0; i < gview_.num_nodes(); i++) {
    NodeItem* item = gview_.node(i);
    if (item != nullptr) {
      params_.delete_kernel(item->kernel);
    }
  }
}

namespace {
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
                         size_t* max_dead_count) {
  const size_t num_in_edges = n->in_edges().size();
  size_t initial_count;
  if (IsMerge(n)) {
    // A Merge waits for all of its control inputs, so we initialize the
    // pending count to the number of control edges.
    int32 num_control_edges = 0;
    for (const Edge* edge : n->in_edges()) {
      if (edge->IsControlEdge()) {
        num_control_edges++;
      }
    }
    // Use bit 0 to indicate if we are waiting for a ready live data input.
    initial_count = 1 + (num_control_edges << 1);
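    // Illustrative example: a Merge with two control inputs starts at
    // initial_count = 1 + (2 << 1) = 5, i.e. the control-edge count packed
    // above the "live data input" bit.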
  } else {
    initial_count = num_in_edges;
  }

  *max_pending = initial_count;
  *max_dead_count = num_in_edges;
}
}  // namespace

ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
    const string& fname) {
  auto iter = frame_info_.find(fname);
  if (iter != frame_info_.end()) {
    return iter->second.get();
  } else {
    auto frame_info = absl::make_unique<FrameInfo>(fname);
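    // Key the map with a string_view aliasing the FrameInfo's own `name`
    // member; the FrameInfo is heap-allocated, so the view remains valid
    // after the unique_ptr is moved into the map.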
    absl::string_view fname_view = frame_info->name;
    auto emplace_result =
        frame_info_.emplace(fname_view, std::move(frame_info));
    return emplace_result.first->second.get();
  }
}

Status ImmutableExecutorState::Initialize(const Graph& graph) {
  TF_RETURN_IF_ERROR(gview_.Initialize(&graph));

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes =
        absl::make_unique<std::vector<const NodeItem*>>();
  }
  root_frame_info_ = frame_info_[""].get();

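  // pending_ids_[id] will hold each node's handle into its frame's
  // PendingCounts layout; the handles are created in the loop below.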
  pending_ids_.resize(gview_.num_nodes());

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  requires_control_flow_ = false;
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
      requires_control_flow_ = true;
    } else if (IsRecv(n)) {
      // A Recv node from a different device may produce dead tensors from
      // non-local control-flow nodes.
      //
      // TODO(mrry): Track whether control flow was present in the
      // pre-partitioned graph, and enable the caller (e.g.
      // `DirectSession`) to relax this constraint.
      string send_device;
      string recv_device;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
      if (send_device != recv_device) {
        requires_control_flow_ = true;
      }
    }

    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    NodeItem* item = gview_.node(id);
    item->node_id = id;

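    // All nodes in a frame share one contiguous input vector; `input_start`
    // is this node's offset into it (the EdgeInfo rewrite below folds this
    // offset into each incoming edge's input_slot).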
    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->properties(), &item->kernel);
    if (!s.ok()) {
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_any_consumer_merge_or_control_trigger = false;
    for (const Node* consumer : n->out_nodes()) {
      if (IsMerge(consumer) || IsControlTrigger(consumer)) {
        item->is_any_consumer_merge_or_control_trigger = true;
        break;
      }
    }
    const Tensor* const_tensor = item->kernel->const_tensor();
    if (const_tensor) {
      // Hold onto a shallow copy of the constant tensor in `*this` so that
      // its reference count does not drop to 1. This prevents the constant
      // tensor from being forwarded and its buffer from being reused.
      const_tensors_.emplace_back(*const_tensor);
    }
    item->const_tensor = const_tensor;
    item->is_noop = (item->kernel->type_string_view() == "NoOp");
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;

      string frame_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
      FrameInfo* frame_info = frame_info_[frame_name].get();

      int parallel_iterations;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));

      if (frame_info->parallel_iterations == -1) {
        frame_info->parallel_iterations = parallel_iterations;
      } else if (frame_info->parallel_iterations != parallel_iterations) {
        LOG(WARNING) << "Loop frame \"" << frame_name
                     << "\" had two different values for parallel_iterations: "
                     << frame_info->parallel_iterations << " vs. "
                     << parallel_iterations << ".";
      }

      if (enter_frame_info_.size() <= id) {
        enter_frame_info_.resize(id + 1);
      }
      enter_frame_info_[id] = frame_info;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_source = IsSource(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    item->is_transfer_node = IsTransferNode(n);
    item->is_initialization_op = IsInitializationOp(n);
    item->is_recv_or_switch = IsRecv(n) || IsSwitch(n);
    item->is_next_iteration = IsNextIteration(n);

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    pending_ids_[id] =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // See if this node is a root node, and if so, add item to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(item);
    }

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(item);
    if (item->is_enter) {
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }

    // Record information about whether each output of the op is used.
    std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]);
    std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false);
    int32 unused_outputs = n->num_outputs();
    for (const Edge* e : n->out_edges()) {
      if (IsSink(e->dst())) continue;
      if (e->src_output() >= 0) {
        if (!outputs_required[e->src_output()]) {
          --unused_outputs;
          outputs_required[e->src_output()] = true;
        }
      }
    }
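    // Only materialize the bitmap when some output is unused; a null
    // `outputs_required` is treated as "all outputs are required".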
    if (unused_outputs > 0) {
      for (int i = 0; i < n->num_outputs(); ++i) {
        if (!outputs_required[i]) {
          metrics::RecordUnusedOutput(n->type_string());
        }
      }
      item->outputs_required = std::move(outputs_required);
    }
  }

  // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input
  // location.
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    NodeItem* item = gview_.node(id);

    for (EdgeInfo& e : item->mutable_output_edges()) {
      const int dst_id = e.dst_id;
      NodeItem* dst_item = gview_.node(dst_id);
      e.input_slot += dst_item->input_start;
    }
  }

  // Initialize PendingCounts only after pending_ids_[node.id] is initialized
  // for all nodes.
  InitializePending(&graph, cf_info);
  return gview_.SetAllocAttrs(&graph, params_.device);
}

namespace {
// If a Node has been marked to use a ScopedAllocator x for output i, then
// sc_attr will contain the subsequence (i, x) at an even offset.  This function
// extracts and transfers that ScopedAllocator id to alloc_attr.  For now, we
// only allow one ScopedAllocator use per Node.
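// Illustrative example: sc_attr == {2, 45} maps output 2 to the
// ScopedAllocator with id 45.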
bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
                                int output_index,
                                AllocatorAttributes* alloc_attr) {
  DCHECK_LE(2, sc_attr.size());
  for (int i = 0; i < sc_attr.size(); i += 2) {
    if (sc_attr[i] == output_index) {
      CHECK_EQ(alloc_attr->scope_id, 0);
      alloc_attr->scope_id = sc_attr[i + 1];
      return true;
    }
  }
  return false;
}
}  // namespace

Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g,
                                                    ControlFlowInfo* cf_info) {
  const int num_nodes = g->num_node_ids();
  cf_info->frame_names.resize(num_nodes);
  std::vector<Node*> parent_nodes;
  parent_nodes.resize(num_nodes);
  std::vector<bool> visited;
  visited.resize(num_nodes);

  string frame_name;
  std::deque<Node*> ready;

  // Initialize with the root nodes.
  for (Node* n : g->nodes()) {
    if (n->in_edges().empty()) {
      visited[n->id()] = true;
      cf_info->unique_frame_names.insert(frame_name);
      ready.push_back(n);
    }
  }

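  // Breadth-first traversal that propagates each node's frame name and
  // parent frame to its successors: an Enter node starts a child frame,
  // and an Exit node restores its parent's frame.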
  while (!ready.empty()) {
    Node* curr_node = ready.front();
    int curr_id = curr_node->id();
    ready.pop_front();

    Node* parent = nullptr;
    if (IsEnter(curr_node)) {
      // Enter a child frame.
      TF_RETURN_IF_ERROR(
          GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
      parent = curr_node;
    } else if (IsExit(curr_node)) {
      // Exit to the parent frame.
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[parent->id()];
      parent = parent_nodes[parent->id()];
    } else {
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[curr_id];
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      if (IsSink(out)) continue;
      const int out_id = out->id();

      // Add to ready queue if not visited.
      bool is_visited = visited[out_id];
      if (!is_visited) {
        ready.push_back(out);
        visited[out_id] = true;

        // Process the node 'out'.
        cf_info->frame_names[out_id] = frame_name;
        parent_nodes[out_id] = parent;
        cf_info->unique_frame_names.insert(frame_name);
      }
    }
  }

  return Status::OK();
}

void ImmutableExecutorState::InitializePending(const Graph* graph,
                                               const ControlFlowInfo& cf_info) {
  for (auto& it : cf_info.unique_frame_names) {
    FrameInfo* finfo = EnsureFrameInfo(it);
    DCHECK_EQ(finfo->pending_counts.get(), nullptr);
    finfo->pending_counts =
        absl::make_unique<PendingCounts>(finfo->pending_counts_layout);
  }

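  // When the graph has no control flow, a flat array of atomic counters is
  // enough; each entry is set to the node's max_pending value below.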
  if (!requires_control_flow_) {
    atomic_pending_counts_.reset(new std::atomic<int32>[gview_.num_nodes()]);
    std::fill(atomic_pending_counts_.get(),
              atomic_pending_counts_.get() + gview_.num_nodes(), 0);
  }

  for (const Node* n : graph->nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    const string& name = cf_info.frame_names[id];
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    auto& counts = EnsureFrameInfo(name)->pending_counts;
    counts->set_initial_count(pending_ids_[id], max_pending);
    if (!requires_control_flow_) {
      atomic_pending_counts_[id] = max_pending;
    }
  }
}
}  // namespace tensorflow