• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/grappler_item.h"
17 
18 #include <unordered_map>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/util/device_name_utils.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 
WithGraph(GraphDef && graph_def) const33 GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
34   GrapplerItem item;
35   item.id = id;
36   item.feed = feed;
37   item.fetch = fetch;
38   item.init_ops = init_ops;
39   item.keep_ops = keep_ops;
40   item.expected_init_time = expected_init_time;
41   item.save_op = save_op;
42   item.restore_op = restore_op;
43   item.save_restore_loc_tensor = save_restore_loc_tensor;
44   item.queue_runners = queue_runners;
45   item.devices_ = devices_;
46   item.optimization_options_ = optimization_options_;
47   item.graph.Swap(&graph_def);
48   return item;
49 }
50 
MainOpsFanin() const51 std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
52   return ComputeTransitiveFanin(graph, fetch);
53 }
54 
EnqueueOpsFanin() const55 std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
56   std::vector<string> enqueue_ops;
57   for (const auto& queue_runner : queue_runners) {
58     for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
59       enqueue_ops.push_back(enqueue_op);
60     }
61   }
62   return ComputeTransitiveFanin(graph, enqueue_ops);
63 }
64 
InitOpsFanin() const65 std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
66   return ComputeTransitiveFanin(graph, init_ops);
67 }
68 
MainVariables() const69 std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
70   std::vector<const NodeDef*> fanin = ComputeTransitiveFanin(graph, init_ops);
71   std::vector<const NodeDef*> vars;
72   for (const NodeDef* node : fanin) {
73     if (IsVariable(*node)) {
74       vars.push_back(node);
75     }
76   }
77   return vars;
78 }
79 
NodesToPreserve() const80 std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
81   std::unordered_set<string> result;
82   for (const string& f : fetch) {
83     VLOG(1) << "Add fetch " << f;
84     result.insert(NodeName(f));
85   }
86   for (const auto& f : feed) {
87     VLOG(1) << "Add feed " << f.first;
88     result.insert(NodeName(f.first));
89   }
90   for (const auto& node : init_ops) {
91     result.insert(NodeName(node));
92   }
93   for (const auto& node : keep_ops) {
94     result.insert(NodeName(node));
95   }
96   if (!save_op.empty()) {
97     result.insert(NodeName(save_op));
98   }
99   if (!restore_op.empty()) {
100     result.insert(NodeName(restore_op));
101   }
102   if (!save_restore_loc_tensor.empty()) {
103     result.insert(NodeName(save_restore_loc_tensor));
104   }
105 
106   for (const auto& queue_runner : queue_runners) {
107     for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
108       result.insert(NodeName(enqueue_op));
109     }
110     if (!queue_runner.close_op_name().empty()) {
111       result.insert(NodeName(queue_runner.close_op_name()));
112     }
113     if (!queue_runner.cancel_op_name().empty()) {
114       result.insert(NodeName(queue_runner.cancel_op_name()));
115     }
116   }
117 
118   // Tensorflow functions do not prune stateful or dataset-output ops from
119   // the function body (see PruneFunctionBody in common_runtime/function.cc).
120   if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
121     FunctionLibraryDefinition fn_library(OpRegistry::Global(), graph.library());
122     for (const NodeDef& node : graph.node()) {
123       if (IsStateful(node, &fn_library) || IsDataset(node)) {
124         result.insert(node.name());
125       }
126     }
127   }
128 
129   return result;
130 }
131 
devices() const132 const std::unordered_set<string>& GrapplerItem::devices() const {
133   return devices_;
134 }
135 
AddDevice(const string & device)136 Status GrapplerItem::AddDevice(const string& device) {
137   DeviceNameUtils::ParsedName name;
138 
139   if (!DeviceNameUtils::ParseFullName(device, &name)) {
140     return errors::InvalidArgument("Invalid device name: device=", device);
141 
142   } else if (!name.has_job || !name.has_replica || !name.has_task ||
143              !name.has_type || !name.has_id) {
144     return errors::InvalidArgument("Not a fully defined device name: device=",
145                                    device);
146   }
147 
148   devices_.insert(DeviceNameUtils::ParsedNameToString(name));
149   return Status::OK();
150 }
151 
AddDevices(const GrapplerItem & other)152 Status GrapplerItem::AddDevices(const GrapplerItem& other) {
153   std::vector<absl::string_view> invalid_devices;
154   for (const string& device : other.devices()) {
155     Status added = AddDevice(device);
156     if (!added.ok()) invalid_devices.emplace_back(device);
157   }
158   return invalid_devices.empty()
159              ? Status::OK()
160              : errors::InvalidArgument("Skipped invalid devices: [",
161                                        absl::StrJoin(invalid_devices, ", "),
162                                        "]");
163 }
164 
InferDevicesFromGraph()165 Status GrapplerItem::InferDevicesFromGraph() {
166   absl::flat_hash_set<absl::string_view> invalid_devices;
167   for (const NodeDef& node : graph.node()) {
168     Status added = AddDevice(node.device());
169     if (!added.ok()) invalid_devices.insert(node.device());
170   }
171   VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
172   return invalid_devices.empty()
173              ? Status::OK()
174              : errors::InvalidArgument("Skipped invalid devices: [",
175                                        absl::StrJoin(invalid_devices, ", "),
176                                        "]");
177 }
178 
ClearDevices()179 void GrapplerItem::ClearDevices() { devices_.clear(); }
180 
optimization_options() const181 const GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options()
182     const {
183   return optimization_options_;
184 }
185 
optimization_options()186 GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
187   return optimization_options_;
188 }
189 
ComputeTransitiveFanin(const GraphDef & graph,const std::vector<string> & terminal_nodes)190 std::vector<const NodeDef*> ComputeTransitiveFanin(
191     const GraphDef& graph, const std::vector<string>& terminal_nodes) {
192   bool ill_formed = false;
193   std::vector<const NodeDef*> result =
194       ComputeTransitiveFanin(graph, terminal_nodes, &ill_formed);
195   CHECK(!ill_formed);
196   return result;
197 }
198 
ComputeTransitiveFanin(const GraphDef & graph,const std::vector<string> & terminal_nodes,bool * ill_formed)199 std::vector<const NodeDef*> ComputeTransitiveFanin(
200     const GraphDef& graph, const std::vector<string>& terminal_nodes,
201     bool* ill_formed) {
202   *ill_formed = false;
203   std::unordered_map<string, const NodeDef*> name_to_node;
204   std::unordered_map<string, const NodeDef*> name_to_send;
205   for (const auto& node : graph.node()) {
206     name_to_node[node.name()] = &node;
207     if (node.op() == "_Send") {
208       const auto& attr = node.attr();
209       name_to_send[attr.at("tensor_name").s()] = &node;
210     }
211   }
212 
213   std::vector<const NodeDef*> queue;
214   for (const string& root : terminal_nodes) {
215     const NodeDef* node = name_to_node[NodeName(root)];
216     if (!node) {
217       *ill_formed = true;
218       VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
219       return {};
220     }
221     queue.push_back(node);
222   }
223 
224   std::vector<const NodeDef*> result;
225   std::unordered_set<const NodeDef*> visited;
226 
227   while (!queue.empty()) {
228     const NodeDef* node = queue.back();
229     queue.pop_back();
230     if (!visited.insert(node).second) {
231       // The node has already been visited.
232       continue;
233     }
234     result.push_back(node);
235     for (const string& input : node->input()) {
236       const NodeDef* in = name_to_node[NodeName(input)];
237       if (!in) {
238         VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
239         *ill_formed = true;
240         return {};
241       }
242       queue.push_back(in);
243     }
244     if (node->op() == "_Recv") {
245       const auto& attr = node->attr();
246       const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
247       if (send) {
248         queue.push_back(send);
249       }
250       // Subgraph after partitioning may have either _Send or _Recv, not both.
251       // So, we do not set ill_formed for missing _Send.
252     }
253   }
254   return result;
255 }
256 
257 }  // end namespace grappler
258 }  // end namespace tensorflow
259