/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/grappler_item.h"

#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
  GrapplerItem::OptimizationOptions optimization_options;
  // TensorFlow 2.0 in eager mode with automatic control dependencies will
  // prune all nodes that are not in the transitive fanin of the fetch nodes.
  // However, because the function will be executed via FunctionLibraryRuntime,
  // and the current function implementation does not prune stateful and
  // dataset ops, we rely on Grappler to do the correct graph pruning.
  optimization_options.allow_pruning_stateful_and_dataset_ops = true;

  optimization_options.is_eager_mode = true;

  // All the nested function calls will be executed and optimized via
  // PartitionedCallOp, so there is no need to optimize functions now.
  optimization_options.optimize_function_library = false;

  return optimization_options;
}

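// Returns a copy of this item with `graph_def` swapped in as its graph; all
// other fields (feeds, fetches, init ops, devices, options, etc.) are copied.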
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
  GrapplerItem item;
  item.id = id;
  item.feed = feed;
  item.fetch = fetch;
  item.init_ops = init_ops;
  item.keep_ops = keep_ops;
  item.expected_init_time = expected_init_time;
  item.save_op = save_op;
  item.restore_op = restore_op;
  item.save_restore_loc_tensor = save_restore_loc_tensor;
  item.queue_runners = queue_runners;
  item.devices_ = devices_;
  item.optimization_options_ = optimization_options_;
  item.graph.Swap(&graph_def);
  return item;
}

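// Returns the transitive fanin of the fetch nodes, i.e. every node the main
// computation depends on.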
std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
  std::vector<const NodeDef*> fanin_nodes;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes));
  return fanin_nodes;
}

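// Returns the transitive fanin of the enqueue ops registered by the queue
// runners.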
std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
  std::vector<string> enqueue_ops;
  for (const auto& queue_runner : queue_runners) {
    for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
      enqueue_ops.push_back(enqueue_op);
    }
  }
  std::vector<const NodeDef*> fanin_nodes;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, enqueue_ops, &fanin_nodes));
  return fanin_nodes;
}

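// Returns the transitive fanin of the initialization ops.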
std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
  std::vector<const NodeDef*> fanin_nodes;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin_nodes));
  return fanin_nodes;
}

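// Returns the variable nodes reachable from the initialization ops.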
std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
  std::vector<const NodeDef*> fanin;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin));
  std::vector<const NodeDef*> vars;
  for (const NodeDef* node : fanin) {
    if (IsVariable(*node)) {
      vars.push_back(node);
    }
  }
  return vars;
}

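// Returns the names of all nodes that optimizers must keep: fetches, feeds,
// init/keep ops, save/restore ops, queue runner ops, and (when pruning of
// stateful and dataset ops is disallowed) any stateful or dataset node, plus
// nodes explicitly tagged with the _grappler_do_not_remove attribute.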
std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
  std::unordered_set<string> result;
  for (const string& f : fetch) {
    VLOG(1) << "Add fetch " << f;
    result.insert(NodeName(f));
  }
  for (const auto& f : feed) {
    VLOG(1) << "Add feed " << f.first;
    result.insert(NodeName(f.first));
  }
  for (const auto& node : init_ops) {
    result.insert(NodeName(node));
  }
  for (const auto& node : keep_ops) {
    result.insert(NodeName(node));
  }
  if (!save_op.empty()) {
    result.insert(NodeName(save_op));
  }
  if (!restore_op.empty()) {
    result.insert(NodeName(restore_op));
  }
  if (!save_restore_loc_tensor.empty()) {
    result.insert(NodeName(save_restore_loc_tensor));
  }

  for (const auto& queue_runner : queue_runners) {
    for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
      result.insert(NodeName(enqueue_op));
    }
    if (!queue_runner.close_op_name().empty()) {
      result.insert(NodeName(queue_runner.close_op_name()));
    }
    if (!queue_runner.cancel_op_name().empty()) {
      result.insert(NodeName(queue_runner.cancel_op_name()));
    }
  }

  absl::optional<FunctionLibraryDefinition> fn_library;
  if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
    fn_library.emplace(OpRegistry::Global(), graph.library());
  }
  for (const NodeDef& node : graph.node()) {
    const auto attrs = AttrSlice(&node.attr());

    // TensorFlow functions do not prune stateful or dataset-output ops from
    // the function body (see PruneFunctionBody in common_runtime/function.cc).
    if (!optimization_options_.allow_pruning_stateful_and_dataset_ops &&
        (IsStateful(node, &*fn_library) || IsDataset(node))) {
      result.insert(node.name());
    }

    // Do not remove ops with the attribute _grappler_do_not_remove. This is
    // useful for debugging.
    bool do_not_remove;
    if (TryGetNodeAttr(attrs, "_grappler_do_not_remove", &do_not_remove) &&
        do_not_remove) {
      result.insert(node.name());
    }
  }

  return result;
}

const std::unordered_set<string>& GrapplerItem::devices() const {
  return devices_;
}

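// Registers `device` with the item. The name must parse and be fully defined
// (job, replica, task, type and id); otherwise InvalidArgument is returned.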
Status GrapplerItem::AddDevice(const string& device) {
  DeviceNameUtils::ParsedName name;

  if (!DeviceNameUtils::ParseFullName(device, &name)) {
    return errors::InvalidArgument("Invalid device name: device=", device);

  } else if (!name.has_job || !name.has_replica || !name.has_task ||
             !name.has_type || !name.has_id) {
    return errors::InvalidArgument("Not a fully defined device name: device=",
                                   device);
  }

  devices_.insert(DeviceNameUtils::ParsedNameToString(name));
  return Status::OK();
}

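// Copies all devices from `other` into this item, reporting any device names
// that fail validation in a single InvalidArgument status.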
Status GrapplerItem::AddDevices(const GrapplerItem& other) {
  std::vector<absl::string_view> invalid_devices;
  for (const string& device : other.devices()) {
    Status added = AddDevice(device);
    if (!added.ok()) invalid_devices.emplace_back(device);
  }
  return invalid_devices.empty()
             ? Status::OK()
             : errors::InvalidArgument("Skipped invalid devices: [",
                                       absl::StrJoin(invalid_devices, ", "),
                                       "]");
}

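// Populates the device set from the device strings assigned to graph nodes,
// skipping (and reporting) names that are not fully defined.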
Status GrapplerItem::InferDevicesFromGraph() {
  absl::flat_hash_set<absl::string_view> invalid_devices;
  for (const NodeDef& node : graph.node()) {
    Status added = AddDevice(node.device());
    if (!added.ok()) invalid_devices.insert(node.device());
  }
  VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
  return invalid_devices.empty()
             ? Status::OK()
             : errors::InvalidArgument("Skipped invalid devices: [",
                                       absl::StrJoin(invalid_devices, ", "),
                                       "]");
}

void GrapplerItem::ClearDevices() { devices_.clear(); }

const GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options()
    const {
  return optimization_options_;
}

GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
  return optimization_options_;
}

}  // end namespace grappler
}  // end namespace tensorflow