1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/grappler_item.h"
17
18 #include <unordered_map>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/util/device_name_utils.h"
29
30 namespace tensorflow {
31 namespace grappler {
32
WithGraph(GraphDef && graph_def) const33 GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
34 GrapplerItem item;
35 item.id = id;
36 item.feed = feed;
37 item.fetch = fetch;
38 item.init_ops = init_ops;
39 item.keep_ops = keep_ops;
40 item.expected_init_time = expected_init_time;
41 item.save_op = save_op;
42 item.restore_op = restore_op;
43 item.save_restore_loc_tensor = save_restore_loc_tensor;
44 item.queue_runners = queue_runners;
45 item.devices_ = devices_;
46 item.optimization_options_ = optimization_options_;
47 item.graph.Swap(&graph_def);
48 return item;
49 }
50
MainOpsFanin() const51 std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
52 return ComputeTransitiveFanin(graph, fetch);
53 }
54
EnqueueOpsFanin() const55 std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
56 std::vector<string> enqueue_ops;
57 for (const auto& queue_runner : queue_runners) {
58 for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
59 enqueue_ops.push_back(enqueue_op);
60 }
61 }
62 return ComputeTransitiveFanin(graph, enqueue_ops);
63 }
64
InitOpsFanin() const65 std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
66 return ComputeTransitiveFanin(graph, init_ops);
67 }
68
MainVariables() const69 std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
70 std::vector<const NodeDef*> fanin = ComputeTransitiveFanin(graph, init_ops);
71 std::vector<const NodeDef*> vars;
72 for (const NodeDef* node : fanin) {
73 if (IsVariable(*node)) {
74 vars.push_back(node);
75 }
76 }
77 return vars;
78 }
79
NodesToPreserve() const80 std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
81 std::unordered_set<string> result;
82 for (const string& f : fetch) {
83 VLOG(1) << "Add fetch " << f;
84 result.insert(NodeName(f));
85 }
86 for (const auto& f : feed) {
87 VLOG(1) << "Add feed " << f.first;
88 result.insert(NodeName(f.first));
89 }
90 for (const auto& node : init_ops) {
91 result.insert(NodeName(node));
92 }
93 for (const auto& node : keep_ops) {
94 result.insert(NodeName(node));
95 }
96 if (!save_op.empty()) {
97 result.insert(NodeName(save_op));
98 }
99 if (!restore_op.empty()) {
100 result.insert(NodeName(restore_op));
101 }
102 if (!save_restore_loc_tensor.empty()) {
103 result.insert(NodeName(save_restore_loc_tensor));
104 }
105
106 for (const auto& queue_runner : queue_runners) {
107 for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
108 result.insert(NodeName(enqueue_op));
109 }
110 if (!queue_runner.close_op_name().empty()) {
111 result.insert(NodeName(queue_runner.close_op_name()));
112 }
113 if (!queue_runner.cancel_op_name().empty()) {
114 result.insert(NodeName(queue_runner.cancel_op_name()));
115 }
116 }
117
118 // Tensorflow functions do not prune stateful or dataset-output ops from
119 // the function body (see PruneFunctionBody in common_runtime/function.cc).
120 if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
121 FunctionLibraryDefinition fn_library(OpRegistry::Global(), graph.library());
122 for (const NodeDef& node : graph.node()) {
123 if (IsStateful(node, &fn_library) || IsDataset(node)) {
124 result.insert(node.name());
125 }
126 }
127 }
128
129 return result;
130 }
131
devices() const132 const std::unordered_set<string>& GrapplerItem::devices() const {
133 return devices_;
134 }
135
AddDevice(const string & device)136 Status GrapplerItem::AddDevice(const string& device) {
137 DeviceNameUtils::ParsedName name;
138
139 if (!DeviceNameUtils::ParseFullName(device, &name)) {
140 return errors::InvalidArgument("Invalid device name: device=", device);
141
142 } else if (!name.has_job || !name.has_replica || !name.has_task ||
143 !name.has_type || !name.has_id) {
144 return errors::InvalidArgument("Not a fully defined device name: device=",
145 device);
146 }
147
148 devices_.insert(DeviceNameUtils::ParsedNameToString(name));
149 return Status::OK();
150 }
151
AddDevices(const GrapplerItem & other)152 Status GrapplerItem::AddDevices(const GrapplerItem& other) {
153 std::vector<absl::string_view> invalid_devices;
154 for (const string& device : other.devices()) {
155 Status added = AddDevice(device);
156 if (!added.ok()) invalid_devices.emplace_back(device);
157 }
158 return invalid_devices.empty()
159 ? Status::OK()
160 : errors::InvalidArgument("Skipped invalid devices: [",
161 absl::StrJoin(invalid_devices, ", "),
162 "]");
163 }
164
InferDevicesFromGraph()165 Status GrapplerItem::InferDevicesFromGraph() {
166 absl::flat_hash_set<absl::string_view> invalid_devices;
167 for (const NodeDef& node : graph.node()) {
168 Status added = AddDevice(node.device());
169 if (!added.ok()) invalid_devices.insert(node.device());
170 }
171 VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
172 return invalid_devices.empty()
173 ? Status::OK()
174 : errors::InvalidArgument("Skipped invalid devices: [",
175 absl::StrJoin(invalid_devices, ", "),
176 "]");
177 }
178
ClearDevices()179 void GrapplerItem::ClearDevices() { devices_.clear(); }
180
optimization_options() const181 const GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options()
182 const {
183 return optimization_options_;
184 }
185
optimization_options()186 GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
187 return optimization_options_;
188 }
189
ComputeTransitiveFanin(const GraphDef & graph,const std::vector<string> & terminal_nodes)190 std::vector<const NodeDef*> ComputeTransitiveFanin(
191 const GraphDef& graph, const std::vector<string>& terminal_nodes) {
192 bool ill_formed = false;
193 std::vector<const NodeDef*> result =
194 ComputeTransitiveFanin(graph, terminal_nodes, &ill_formed);
195 CHECK(!ill_formed);
196 return result;
197 }
198
ComputeTransitiveFanin(const GraphDef & graph,const std::vector<string> & terminal_nodes,bool * ill_formed)199 std::vector<const NodeDef*> ComputeTransitiveFanin(
200 const GraphDef& graph, const std::vector<string>& terminal_nodes,
201 bool* ill_formed) {
202 *ill_formed = false;
203 std::unordered_map<string, const NodeDef*> name_to_node;
204 std::unordered_map<string, const NodeDef*> name_to_send;
205 for (const auto& node : graph.node()) {
206 name_to_node[node.name()] = &node;
207 if (node.op() == "_Send") {
208 const auto& attr = node.attr();
209 name_to_send[attr.at("tensor_name").s()] = &node;
210 }
211 }
212
213 std::vector<const NodeDef*> queue;
214 for (const string& root : terminal_nodes) {
215 const NodeDef* node = name_to_node[NodeName(root)];
216 if (!node) {
217 *ill_formed = true;
218 VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
219 return {};
220 }
221 queue.push_back(node);
222 }
223
224 std::vector<const NodeDef*> result;
225 std::unordered_set<const NodeDef*> visited;
226
227 while (!queue.empty()) {
228 const NodeDef* node = queue.back();
229 queue.pop_back();
230 if (!visited.insert(node).second) {
231 // The node has already been visited.
232 continue;
233 }
234 result.push_back(node);
235 for (const string& input : node->input()) {
236 const NodeDef* in = name_to_node[NodeName(input)];
237 if (!in) {
238 VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
239 *ill_formed = true;
240 return {};
241 }
242 queue.push_back(in);
243 }
244 if (node->op() == "_Recv") {
245 const auto& attr = node->attr();
246 const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
247 if (send) {
248 queue.push_back(send);
249 }
250 // Subgraph after partitioning may have either _Send or _Recv, not both.
251 // So, we do not set ill_formed for missing _Send.
252 }
253 }
254 return result;
255 }
256
257 } // end namespace grappler
258 } // end namespace tensorflow
259