• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_
17 
18 #include <atomic>
19 #include <deque>
20 #include <memory>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "tensorflow/core/common_runtime/graph_view.h"
25 #include "tensorflow/core/common_runtime/local_executor_params.h"
26 #include "tensorflow/core/common_runtime/pending_counts.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/gtl/flatmap.h"
30 #include "tensorflow/core/lib/gtl/flatset.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
35 
36 class Graph;
37 
38 // Represents the state of an executor (graph and control flow information)
39 // that is immutable throughout execution.
40 //
41 // TODO(b/152651962): Add independent unit tests for this class.
42 class ImmutableExecutorState {
43  public:
44   struct FrameInfo {
FrameInfoFrameInfo45     explicit FrameInfo(string name)
46         : name(std::move(name)),
47           input_count(0),
48           total_inputs(0),
49           pending_counts(nullptr),
50           nodes(nullptr),
51           parallel_iterations(-1) {}
52 
53     // The name of the frame.
54     string name;
55 
56     // The total number of inputs to a frame.
57     int input_count;
58 
59     // The total number of input tensors of a frame.
60     // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
61     int total_inputs;
62 
63     // Used to determine the next place to allocate space in the
64     // pending_counts data structure we'll eventually construct
65     PendingCounts::Layout pending_counts_layout;
66 
67     // Each frame has its own PendingCounts only for the nodes in the frame.
68     std::unique_ptr<PendingCounts> pending_counts;
69 
70     // The nodes in a frame. Used only for debugging.
71     std::unique_ptr<std::vector<const NodeItem*>> nodes;
72 
73     // The number of iterations of this frame that can execute concurrently.
74     int32 parallel_iterations;
75   };
76 
ImmutableExecutorState(const LocalExecutorParams & p)77   explicit ImmutableExecutorState(const LocalExecutorParams& p)
78       : params_(p), gview_() {}
79   ~ImmutableExecutorState();
80 
81   Status Initialize(const Graph& graph);
82 
83   // Process all Nodes in the current graph, attempting to infer the
84   // memory allocation attributes to be used wherever they may allocate
85   // a tensor buffer.
86   Status SetAllocAttrs();
87 
params()88   const LocalExecutorParams& params() const { return params_; }
graph_view()89   const GraphView& graph_view() const { return gview_; }
pending_ids()90   const std::vector<PendingCounts::Handle>& pending_ids() const {
91     return pending_ids_;
92   }
root_nodes()93   const std::vector<const NodeItem*>& root_nodes() const { return root_nodes_; }
94 
get_root_frame_info()95   const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }
96 
get_enter_frame_info(const NodeItem & node_item)97   const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const {
98     DCHECK(node_item.is_enter);
99     return *enter_frame_info_[node_item.node_id];
100   }
101 
requires_control_flow_support()102   bool requires_control_flow_support() const { return requires_control_flow_; }
103 
104   // Copies the pending counts for nodes in this graph to the given array.
105   //
106   // This method provides a more efficient way of initializing
107   // `SimplePropagatorState` than individually accessing the pending counts from
108   // `get_root_frame_info().counts`.
109   //
110   // REQUIRES: `!requires_control_flow_support && len(dest) ==
111   // graph_view().num_nodes()`.
copy_pending_counts(std::atomic<int32> * dest)112   void copy_pending_counts(std::atomic<int32>* dest) const {
113     DCHECK(!requires_control_flow_);
114     memcpy(dest, atomic_pending_counts_.get(),
115            graph_view().num_nodes() * sizeof(std::atomic<int32>));
116     std::atomic_thread_fence(std::memory_order_release);
117   }
118 
119  private:
120   struct ControlFlowInfo {
121     gtl::FlatSet<string> unique_frame_names;
122     std::vector<string> frame_names;
123   };
124 
125   static Status BuildControlFlowInfo(const Graph* graph,
126                                      ControlFlowInfo* cf_info);
127   void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);
128 
129   FrameInfo* EnsureFrameInfo(const string& fname);
130 
131   // Owned.
132   LocalExecutorParams params_;
133   GraphView gview_;
134   bool requires_control_flow_;
135   std::vector<PendingCounts::Handle> pending_ids_;
136 
137   // Root nodes (with no in edges) that should form the initial ready queue
138   std::vector<const NodeItem*> root_nodes_;
139 
140   // Mapping from frame name to static information about the frame.
141   // TODO(yuanbyu): We could cache it along with the graph so to avoid
142   // the overhead of constructing it for each executor instance.
143   absl::flat_hash_map<absl::string_view, std::unique_ptr<FrameInfo>>
144       frame_info_;
145   const FrameInfo* root_frame_info_;  // Not owned.
146 
147   // If the graph contains any "Enter" or "RefEnter" nodes, this vector maps
148   // dense node IDs to the corresponding FrameInfo.
149   std::vector<FrameInfo*> enter_frame_info_;
150 
151   // If `requires_control_flow_` is false, this points to an array of initial
152   // pending counts for the nodes in the graph, indexed by node ID.
153   std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_;
154 
155   // Shallow copies of the constant tensors used in the graph.
156   std::vector<Tensor> const_tensors_;
157 
158   TF_DISALLOW_COPY_AND_ASSIGN(ImmutableExecutorState);
159 };
160 
161 }  // namespace tensorflow
162 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_
163