/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/grappler_item.h"

#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
  GrapplerItem::OptimizationOptions optimization_options;
  // TensorFlow 2.0 in eager mode with automatic control dependencies will
  // prune all nodes that are not in the transitive fanin of the fetch nodes.
  // However, because the function will be executed via FunctionLibraryRuntime,
  // and the current function implementation does not prune stateful and
  // dataset ops, we rely on Grappler to do the correct graph pruning.
  optimization_options.allow_pruning_stateful_and_dataset_ops = true;

  optimization_options.is_eager_mode = true;

  // All the nested function calls will be executed and optimized via
  // PartitionedCallOp, so there is no need to optimize functions now.
  optimization_options.optimize_function_library = false;

  return optimization_options;
}
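
// A minimal usage sketch (the id and fetch names below are illustrative, not
// taken from this file): an item built for an eager-mode function would
// typically adopt these options wholesale.
//
//   GrapplerItem item;
//   item.id = "eager_function";
//   item.fetch = {"output:0"};
//   item.optimization_options() = CreateOptOptionsForEager();
//
// With allow_pruning_stateful_and_dataset_ops set, Grappler may prune
// stateful and dataset ops outside the fetch nodes' fanin, which
// FunctionLibraryRuntime itself will not do.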

GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
  GrapplerItem item;
  item.id = id;
  item.feed = feed;
  item.fetch = fetch;
  item.init_ops = init_ops;
  item.keep_ops = keep_ops;
  item.expected_init_time = expected_init_time;
  item.save_op = save_op;
  item.restore_op = restore_op;
  item.save_restore_loc_tensor = save_restore_loc_tensor;
  item.queue_runners = queue_runners;
  item.devices_ = devices_;
  item.optimization_options_ = optimization_options_;
  item.graph.Swap(&graph_def);
  return item;
}
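
// A minimal sketch of the intended call pattern, assuming an existing
// GrapplerItem `item` and a rewritten GraphDef (both names illustrative):
//
//   GraphDef optimized_graph;
//   // ... populate optimized_graph ...
//   GrapplerItem optimized_item = item.WithGraph(std::move(optimized_graph));
//
// The new item shares the original's feeds, fetches, devices and options;
// because the graph is swapped in, the caller's GraphDef is left empty.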

std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
  std::vector<const NodeDef*> fanin_nodes;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes));
  return fanin_nodes;
}

std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
  std::vector<string> enqueue_ops;
  for (const auto& queue_runner : queue_runners) {
    for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
      enqueue_ops.push_back(enqueue_op);
    }
  }
  std::vector<const NodeDef*> fanin_nodes;
  // Compute the transitive fanin of the collected enqueue ops.
  TF_CHECK_OK(ComputeTransitiveFanin(graph, enqueue_ops, &fanin_nodes));
  return fanin_nodes;
}

std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
  std::vector<const NodeDef*> fanin_nodes;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin_nodes));
  return fanin_nodes;
}

std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
  std::vector<const NodeDef*> fanin;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin));
  std::vector<const NodeDef*> vars;
  for (const NodeDef* node : fanin) {
    if (IsVariable(*node)) {
      vars.push_back(node);
    }
  }
  return vars;
}
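
// A sketch of how these fanin helpers are typically consumed, assuming a
// populated GrapplerItem `item` (illustrative only). The returned pointers
// refer to nodes owned by item.graph and stay valid only until that graph is
// mutated.
//
//   for (const NodeDef* node : item.MainOpsFanin()) {
//     VLOG(2) << "Reachable from a fetch node: " << node->name();
//   }
//   // Variables reachable from the init ops:
//   std::vector<const NodeDef*> vars = item.MainVariables();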

std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
  std::unordered_set<string> result;
  for (const string& f : fetch) {
    VLOG(1) << "Add fetch " << f;
    result.insert(NodeName(f));
  }
  for (const auto& f : feed) {
    VLOG(1) << "Add feed " << f.first;
    result.insert(NodeName(f.first));
  }
  for (const auto& node : init_ops) {
    result.insert(NodeName(node));
  }
  for (const auto& node : keep_ops) {
    result.insert(NodeName(node));
  }
  if (!save_op.empty()) {
    result.insert(NodeName(save_op));
  }
  if (!restore_op.empty()) {
    result.insert(NodeName(restore_op));
  }
  if (!save_restore_loc_tensor.empty()) {
    result.insert(NodeName(save_restore_loc_tensor));
  }

  for (const auto& queue_runner : queue_runners) {
    for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
      result.insert(NodeName(enqueue_op));
    }
    if (!queue_runner.close_op_name().empty()) {
      result.insert(NodeName(queue_runner.close_op_name()));
    }
    if (!queue_runner.cancel_op_name().empty()) {
      result.insert(NodeName(queue_runner.cancel_op_name()));
    }
  }

  absl::optional<FunctionLibraryDefinition> fn_library;
  if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
    fn_library.emplace(OpRegistry::Global(), graph.library());
  }
  for (const NodeDef& node : graph.node()) {
    const auto attrs = AttrSlice(&node.attr());

    // TensorFlow functions do not prune stateful or dataset-output ops from
    // the function body (see PruneFunctionBody in common_runtime/function.cc).
    if (!optimization_options_.allow_pruning_stateful_and_dataset_ops &&
        (IsStateful(node, &*fn_library) || IsDataset(node))) {
      result.insert(node.name());
    }

    // Do not remove ops with the attribute _grappler_do_not_remove. This is
    // useful for debugging.
    bool do_not_remove;
    if (TryGetNodeAttr(attrs, "_grappler_do_not_remove", &do_not_remove) &&
        do_not_remove) {
      result.insert(node.name());
    }
  }

  return result;
}
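
// A minimal sketch of how this set is typically used, assuming a populated
// GrapplerItem `item` and a NodeDef `node_def` (both illustrative):
// optimizers consult the set before deleting a node, and a node can be
// pinned into it via the _grappler_do_not_remove attribute checked above.
//
//   const std::unordered_set<string> preserve = item.NodesToPreserve();
//   if (preserve.find(node_def.name()) == preserve.end()) {
//     // Safe to consider removing or rewriting this node.
//   }
//
//   // Pinning a node so it is never removed:
//   (*node_def.mutable_attr())["_grappler_do_not_remove"].set_b(true);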

const std::unordered_set<string>& GrapplerItem::devices() const {
  return devices_;
}

Status GrapplerItem::AddDevice(const string& device) {
  DeviceNameUtils::ParsedName name;

  if (!DeviceNameUtils::ParseFullName(device, &name)) {
    return errors::InvalidArgument("Invalid device name: device=", device);
  } else if (!name.has_job || !name.has_replica || !name.has_task ||
             !name.has_type || !name.has_id) {
    return errors::InvalidArgument("Not a fully defined device name: device=",
                                   device);
  }

  devices_.insert(DeviceNameUtils::ParsedNameToString(name));
  return Status::OK();
}
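
// A minimal sketch, assuming a GrapplerItem `item` (illustrative): AddDevice
// accepts only fully specified device names, i.e. names carrying a job,
// replica, task, device type and id.
//
//   TF_CHECK_OK(
//       item.AddDevice("/job:localhost/replica:0/task:0/device:CPU:0"));
//
// A partially specified name such as "/device:GPU:0" parses but is rejected
// as not fully defined.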

Status GrapplerItem::AddDevices(const GrapplerItem& other) {
  std::vector<absl::string_view> invalid_devices;
  for (const string& device : other.devices()) {
    Status added = AddDevice(device);
    if (!added.ok()) invalid_devices.emplace_back(device);
  }
  return invalid_devices.empty()
             ? Status::OK()
             : errors::InvalidArgument("Skipped invalid devices: [",
                                       absl::StrJoin(invalid_devices, ", "),
                                       "]");
}

Status GrapplerItem::InferDevicesFromGraph() {
  absl::flat_hash_set<absl::string_view> invalid_devices;
  for (const NodeDef& node : graph.node()) {
    Status added = AddDevice(node.device());
    if (!added.ok()) invalid_devices.insert(node.device());
  }
  VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
  return invalid_devices.empty()
             ? Status::OK()
             : errors::InvalidArgument("Skipped invalid devices: [",
                                       absl::StrJoin(invalid_devices, ", "),
                                       "]");
}
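
// A sketch of the inference path, assuming a GrapplerItem `item` whose graph
// has already been placed (illustrative): each node's device string seeds the
// device set, which can then be read back through devices().
//
//   Status inferred = item.InferDevicesFromGraph();
//   for (const string& device : item.devices()) {
//     VLOG(1) << "Known device: " << device;
//   }
//
// Node device strings that are empty or not fully specified are skipped but
// listed in the returned status; the valid ones are still added.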

void GrapplerItem::ClearDevices() { devices_.clear(); }

const GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options()
    const {
  return optimization_options_;
}

GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
  return optimization_options_;
}

}  // end namespace grappler
}  // end namespace tensorflow