1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <vector>
18
19 #include "tensorflow/cc/framework/scope_internal.h"
20 #include "tensorflow/core/common_runtime/shape_refiner.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/graph/graph_constructor.h"
23 #include "tensorflow/core/graph/node_builder.h"
24
25 namespace tensorflow {
26
Scope(Impl * impl)27 Scope::Scope(Impl* impl) : impl_(impl) {}
28
Scope(const Scope & other)29 Scope::Scope(const Scope& other) : impl_(new Impl(*other.impl())) {}
30
~Scope()31 Scope::~Scope() {}
32
operator =(const Scope & other)33 Scope& Scope::operator=(const Scope& other) {
34 // We can't copy Impls because of the const members, use copy ctor instead
35 impl_.reset(new Impl(*other.impl_));
36 return *this;
37 }
38
Impl(Graph * graph,Status * status,NameMap * name_map,ShapeRefiner * refiner,bool disable_shape_inference)39 Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
40 ShapeRefiner* refiner, bool disable_shape_inference)
41 : graph_(graph),
42 status_(status),
43 name_map_(name_map),
44 refiner_(refiner),
45 scope_used_(nullptr),
46 colocation_constraints_(),
47 disable_shape_inference_(disable_shape_inference) {}
48
Impl(const std::shared_ptr<Graph> & graph,const std::shared_ptr<Status> & status,const std::shared_ptr<NameMap> & name_map,const std::shared_ptr<ShapeRefiner> & refiner)49 Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
50 const std::shared_ptr<Status>& status,
51 const std::shared_ptr<NameMap>& name_map,
52 const std::shared_ptr<ShapeRefiner>& refiner)
53 : graph_(graph),
54 status_(status),
55 name_map_(name_map),
56 refiner_(refiner),
57 scope_used_(nullptr),
58 colocation_constraints_(),
59 disable_shape_inference_(false) {}
60
NewRootScope()61 Scope Scope::NewRootScope() {
62 Graph* graph = new Graph(OpRegistry::Global());
63 ShapeRefiner* refiner =
64 new ShapeRefiner(graph->versions(), graph->op_registry());
65 return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
66 /* disable_shape_inference */ false));
67 }
68
DisabledShapeInferenceScope()69 Scope Scope::DisabledShapeInferenceScope() {
70 Graph* graph = new Graph(OpRegistry::Global());
71 ShapeRefiner* refiner =
72 new ShapeRefiner(graph->versions(), graph->op_registry());
73 return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
74 /* disable_shape_inference */ true));
75 }
76
Impl(const Scope & other,Tags::ScopeName,const string & name,bool copy_names)77 Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
78 bool copy_names)
79 : graph_(other.impl()->graph_),
80 status_(other.impl()->status_),
81 name_map_(copy_names ? other.impl()->name_map_
82 : std::shared_ptr<NameMap>(new NameMap)),
83 refiner_(other.impl()->refiner_),
84 scope_used_(nullptr),
85 control_deps_(other.impl()->control_deps_),
86 name_(name),
87 op_name_(""),
88 exit_on_error_(other.impl()->exit_on_error_),
89 kernel_label_(other.impl()->kernel_label_),
90 device_(other.impl()->device_),
91 colocation_constraints_(other.impl()->colocation_constraints_),
92 disable_shape_inference_(other.impl()->disable_shape_inference_) {}
93
Impl(const Scope & other,Tags::OpName,const string & name,const string & op_name)94 Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
95 const string& op_name)
96 : graph_(other.impl()->graph_),
97 status_(other.impl()->status_),
98 name_map_(other.impl()->name_map_),
99 refiner_(other.impl()->refiner_),
100 scope_used_(other.impl()->scope_used_),
101 control_deps_(other.impl()->control_deps_),
102 name_(name),
103 op_name_(op_name),
104 exit_on_error_(other.impl()->exit_on_error_),
105 kernel_label_(other.impl()->kernel_label_),
106 device_(other.impl()->device_),
107 colocation_constraints_(other.impl()->colocation_constraints_),
108 disable_shape_inference_(other.impl()->disable_shape_inference_) {}
109
Impl(const Scope & other,Tags::ControlDeps,std::vector<Operation> control_deps,bool clear_control_deps)110 Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
111 std::vector<Operation> control_deps, bool clear_control_deps)
112 : graph_(other.impl()->graph_),
113 status_(other.impl()->status_),
114 name_map_(other.impl()->name_map_),
115 refiner_(other.impl()->refiner_),
116 scope_used_(other.impl()->scope_used_),
117 control_deps_(
118 clear_control_deps
119 ? std::vector<Operation>()
120 : (control_deps.insert(control_deps.begin(),
121 other.impl()->control_deps_.begin(),
122 other.impl()->control_deps_.end()),
123 control_deps)),
124 name_(other.impl()->name_),
125 op_name_(other.impl()->op_name_),
126 exit_on_error_(other.impl()->exit_on_error_),
127 kernel_label_(other.impl()->kernel_label_),
128 device_(other.impl()->device_),
129 colocation_constraints_(other.impl()->colocation_constraints_),
130 disable_shape_inference_(other.impl()->disable_shape_inference_) {}
131
Impl(const Scope & other,Tags::Device,const string & device)132 Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
133 : graph_(other.impl()->graph_),
134 status_(other.impl()->status_),
135 name_map_(other.impl()->name_map_),
136 refiner_(other.impl()->refiner_),
137 scope_used_(other.impl()->scope_used_),
138 control_deps_(other.impl()->control_deps_),
139 name_(other.impl()->name_),
140 op_name_(other.impl()->op_name_),
141 exit_on_error_(other.impl()->exit_on_error_),
142 kernel_label_(other.impl()->kernel_label_),
143 device_(device),
144 colocation_constraints_(other.impl()->colocation_constraints_),
145 disable_shape_inference_(other.impl()->disable_shape_inference_) {}
146
Impl(const Scope & other,Tags::SingleUseScope,const string & op_name)147 Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
148 const string& op_name)
149 : graph_(other.impl()->graph_),
150 status_(other.impl()->status_),
151 name_map_(other.impl()->name_map_),
152 refiner_(other.impl()->refiner_),
153 scope_used_(new bool(false)),
154 control_deps_(other.impl()->control_deps_),
155 name_(other.impl()->name_),
156 op_name_(op_name),
157 exit_on_error_(other.impl()->exit_on_error_),
158 kernel_label_(other.impl()->kernel_label_),
159 device_(other.impl()->device_),
160 colocation_constraints_(other.impl()->colocation_constraints_),
161 disable_shape_inference_(other.impl()->disable_shape_inference_) {}
162
Impl(const Scope & other,Tags::ExitOnError)163 Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
164 : graph_(other.impl()->graph_),
165 status_(other.impl()->status_),
166 name_map_(other.impl()->name_map_),
167 refiner_(other.impl()->refiner_),
168 scope_used_(other.impl()->scope_used_),
169 control_deps_(other.impl()->control_deps_),
170 name_(other.impl()->name_),
171 op_name_(other.impl()->op_name_),
172 exit_on_error_(true),
173 kernel_label_(other.impl()->kernel_label_),
174 device_(other.impl()->device_),
175 colocation_constraints_(other.impl()->colocation_constraints_),
176 disable_shape_inference_(other.impl()->disable_shape_inference_) {}
177
Impl(const Scope & other,Tags::KernelLabel,const string & kernel_label)178 Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
179 const string& kernel_label)
180 : graph_(other.impl()->graph_),
181 status_(other.impl()->status_),
182 name_map_(other.impl()->name_map_),
183 refiner_(other.impl()->refiner_),
184 scope_used_(other.impl()->scope_used_),
185 control_deps_(other.impl()->control_deps_),
186 name_(other.impl()->name_),
187 op_name_(other.impl()->op_name_),
188 exit_on_error_(other.impl()->exit_on_error_),
189 kernel_label_(kernel_label),
190 device_(other.impl()->device_),
191 colocation_constraints_(other.impl()->colocation_constraints_),
192 disable_shape_inference_(other.impl()->disable_shape_inference_) {}
193
Impl(const Scope & other,Tags::Colocate,const Operation & colocate_with_op,bool clear_colocations)194 Scope::Impl::Impl(const Scope& other, Tags::Colocate,
195 const Operation& colocate_with_op, bool clear_colocations)
196 : graph_(other.impl()->graph_),
197 status_(other.impl()->status_),
198 name_map_(other.impl()->name_map_),
199 refiner_(other.impl()->refiner_),
200 scope_used_(other.impl()->scope_used_),
201 control_deps_(other.impl()->control_deps_),
202 name_(other.impl()->name_),
203 op_name_(other.impl()->op_name_),
204 exit_on_error_(other.impl()->exit_on_error_),
205 kernel_label_(other.impl()->kernel_label_),
206 device_(other.impl()->device_),
207 colocation_constraints_(
208 clear_colocations
209 ? std::unordered_set<string>()
210 : other.impl()->GetColocationConstraints(colocate_with_op)),
211 disable_shape_inference_(other.impl()->disable_shape_inference_) {}
212
GetColocationConstraints(const Operation & colocate_with_op) const213 std::unordered_set<string> Scope::Impl::GetColocationConstraints(
214 const Operation& colocate_with_op) const {
215 std::unordered_set<string> current_constraints(colocation_constraints_);
216 const AttrSlice attrs = colocate_with_op.node()->attrs();
217 std::vector<string> node_constraints;
218 if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
219 for (const string& entry : node_constraints) {
220 StringPiece s(entry);
221 if (s.Consume(kColocationGroupPrefix)) {
222 current_constraints.insert(s.ToString());
223 }
224 }
225 } else {
226 current_constraints.insert(colocate_with_op.node()->name());
227 }
228 return current_constraints;
229 }
230
ok() const231 bool Scope::ok() const { return impl()->status_->ok(); }
232
graph() const233 Graph* Scope::graph() const { return impl()->graph_.get(); }
234
graph_as_shared_ptr() const235 std::shared_ptr<Graph> Scope::graph_as_shared_ptr() const {
236 return impl()->graph_;
237 }
238
status() const239 Status Scope::status() const { return *impl()->status_; }
240
control_deps() const241 const std::vector<Operation>& Scope::control_deps() const {
242 return impl()->control_deps_;
243 }
244
UpdateStatus(const Status s) const245 void Scope::UpdateStatus(const Status s) const {
246 impl()->status_->Update(s);
247 if (impl()->exit_on_error_ && !ok()) {
248 LOG(FATAL) << *impl()->status_;
249 }
250 }
251
ToGraphDef(GraphDef * gdef) const252 Status Scope::ToGraphDef(GraphDef* gdef) const {
253 if (!ok()) {
254 return *impl()->status_;
255 }
256 graph()->ToGraphDef(gdef);
257 return Status::OK();
258 }
259
ToGraph(Graph * g) const260 Status Scope::ToGraph(Graph* g) const {
261 if (ok()) {
262 GraphDef graph_def;
263 graph()->ToGraphDef(&graph_def);
264 GraphConstructorOptions opts;
265 UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
266 }
267 return *impl()->status_;
268 }
269
UpdateBuilder(NodeBuilder * builder) const270 void Scope::UpdateBuilder(NodeBuilder* builder) const {
271 std::vector<Node*> control_inputs;
272 for (const auto& op : impl()->control_deps_) {
273 control_inputs.push_back(op.node());
274 }
275 builder->ControlInputs(control_inputs);
276
277 if (!impl()->kernel_label_.empty()) {
278 builder->Attr("_kernel", impl()->kernel_label_);
279 }
280
281 if (!impl()->colocation_constraints_.empty()) {
282 std::vector<string> constraints(impl()->colocation_constraints_.begin(),
283 impl()->colocation_constraints_.end());
284 // Sort the set.
285 std::sort(constraints.begin(), constraints.end());
286 // Add loc:@ prefix
287 std::transform(constraints.begin(), constraints.end(), constraints.begin(),
288 [](const string& s) {
289 return strings::StrCat(kColocationGroupPrefix, s);
290 });
291 builder->Attr(kColocationAttrName, constraints);
292 }
293 if (!impl()->device_.empty()) {
294 builder->Device(impl()->device_);
295 }
296 }
297
// Returns a name derived from `prefix` that is unique within this scope's
// name map.
//
// If `check_single_use` is true and this is a single-use scope, `prefix` is
// returned unchanged the first time. Any later call records an AlreadyExists
// error in the shared status and returns "".
//
// Otherwise:
// - The first request for `prefix` returns `prefix` itself.
// - Later requests return `prefix_N` for the smallest N, above the last N
//   handed out, that is not already in the name map.
// - The chosen name is itself recorded in the map. A later explicit request
//   for e.g. "foo_1" is then suffixed rather than duplicated, and
//   pre-registered names (such as existing node names) are never reissued.
string Scope::Impl::GetUniqueName(const string& prefix,
                                  bool check_single_use) const {
  if (check_single_use && single_use_scope()) {
    if (*scope_used_) {
      *status_ =
          errors::AlreadyExists(prefix, " already exists in the current scope");
      return "";
    }
    *scope_used_ = true;
    return prefix;
  }
  auto entry = name_map_->find(prefix);
  if (entry == name_map_->end()) {
    name_map_->insert({prefix, 0});
    return prefix;
  }
  // Bump the per-prefix counter until the candidate does not collide with a
  // name already recorded in this map, e.g. an op explicitly named "foo_1".
  // Only lookups happen inside the loop, so `entry` stays valid.
  string unique_name;
  do {
    unique_name = strings::StrCat(prefix, "_", ++entry->second);
  } while (name_map_->find(unique_name) != name_map_->end());
  // Reserve the generated name so that a later request for it is suffixed
  // too.
  name_map_->insert({unique_name, 0});
  return unique_name;
}
318
// Returns the full op name for `default_name`: the scope name and a unique
// leaf name, joined by "/". The separator is omitted when either part is
// empty.
string Scope::Impl::GetNameForOp(const string& default_name) const {
  const string leaf =
      GetUniqueName(default_name, true /* check_single_use */);
  if (name_.empty() || leaf.empty()) {
    return strings::StrCat(name_, leaf);
  }
  return strings::StrCat(name_, "/", leaf);
}
325
// Returns the name the next op created in this scope should use.
// - Regular scopes: the scope's op name if set, otherwise `default_name`,
//   made unique and prefixed with the scope name.
// - Single-use scopes: the fixed op name, exactly once. A later call, or a
//   call when no op name is set, records InvalidArgument and returns "".
string Scope::GetUniqueNameForOp(const string& default_name) const {
  const Impl* const scope_impl = impl();
  if (!scope_impl->single_use_scope()) {
    const string& requested = scope_impl->op_name_.empty()
                                  ? default_name
                                  : scope_impl->op_name_;
    return scope_impl->GetNameForOp(requested);
  }
  const bool unusable =
      scope_impl->op_name_.empty() || *scope_impl->scope_used_;
  if (unusable) {
    *scope_impl->status_ =
        errors::InvalidArgument("Cannot get a unique name in this scope");
    return "";
  }
  *scope_impl->scope_used_ = true;
  return scope_impl->op_name_;
}
339
NewSubScope(const string & child_scope_name) const340 Scope Scope::NewSubScope(const string& child_scope_name) const {
341 if (child_scope_name.empty()) {
342 return Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->name_,
343 true /* copy_names */));
344 }
345 const string unique_name =
346 impl()->GetUniqueName(child_scope_name, false /* check_single_use */);
347 const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/";
348 return Scope(new Impl(*this, Impl::Tags::ScopeName(),
349 strings::StrCat(impl()->name_, sep, unique_name),
350 false /* copy_names */));
351 }
352
WithOpName(const string & op_name) const353 Scope Scope::WithOpName(const string& op_name) const {
354 if (impl()->single_use_scope()) {
355 UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
356 " on this scope"));
357 return *this;
358 }
359 return Scope(new Impl(*this, Impl::Tags::OpName(), impl()->name_, op_name));
360 }
361
WithControlDependencies(const gtl::ArraySlice<Operation> & control_deps) const362 Scope Scope::WithControlDependencies(
363 const gtl::ArraySlice<Operation>& control_deps) const {
364 return Scope(
365 new Impl(*this, Impl::Tags::ControlDeps(),
366 std::vector<Operation>(control_deps.begin(), control_deps.end()),
367 /* clear_control_deps */ false));
368 }
369
WithControlDependencies(const Output & control_dep) const370 Scope Scope::WithControlDependencies(const Output& control_dep) const {
371 return Scope(new Impl(*this, Impl::Tags::ControlDeps(),
372 std::vector<Operation>(1, control_dep.op()),
373 /* clear_control_deps */ false));
374 }
375
WithNoControlDependencies() const376 Scope Scope::WithNoControlDependencies() const {
377 return Scope(new Impl(*this, Impl::Tags::ControlDeps(),
378 std::vector<Operation>(),
379 /* clear_control_deps */ true));
380 }
381
WithDevice(const string & device) const382 Scope Scope::WithDevice(const string& device) const {
383 return Scope(new Impl(*this, Impl::Tags::Device(), device));
384 }
385
ColocateWith(const Operation & op) const386 Scope Scope::ColocateWith(const Operation& op) const {
387 return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
388 /* clear_colocations */ false));
389 }
390
ClearColocation() const391 Scope Scope::ClearColocation() const {
392 return Scope(new Impl(*this, Impl::Tags::Colocate(), Operation(),
393 /* clear_colocations */ true));
394 }
395
ExitOnError() const396 Scope Scope::ExitOnError() const {
397 return Scope(new Impl(*this, Impl::Tags::ExitOnError()));
398 }
399
WithKernelLabel(const string & kernel_label) const400 Scope Scope::WithKernelLabel(const string& kernel_label) const {
401 return Scope(new Impl(*this, Impl::Tags::KernelLabel(), kernel_label));
402 }
403
GetCompositeOpScopes(const string & composite_op_name) const404 CompositeOpScopes Scope::GetCompositeOpScopes(
405 const string& composite_op_name) const {
406 if (impl()->op_name_.empty() && composite_op_name.empty()) {
407 UpdateStatus(errors::InvalidArgument(
408 "Cannot create composite op scopes with empty name"));
409 return {*this, *this};
410 }
411 if (!impl()->single_use_scope()) {
412 Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name
413 : impl()->op_name_);
414 const string child_op_sep = impl()->name_.empty() ? "" : "_";
415 const string child_name =
416 strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_);
417 return {child,
418 Scope(new Impl(child, Impl::Tags::SingleUseScope(), child_name))};
419 } else {
420 return {Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->op_name_,
421 true /* copy_names */)),
422 *this};
423 }
424 }
425
DoShapeInference(Node * node) const426 Status Scope::DoShapeInference(Node* node) const {
427 if (impl_->disable_shape_inference_) return Status::OK();
428 return impl_->refiner_->AddNode(node);
429 }
430
431 class InternalScope {
432 public:
433 // NewScope doesn't take ownership of the inputs.
NewScope(Graph * graph,Status * status,ShapeRefiner * refiner)434 static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
435 Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap;
436 for (const Node* node : graph->nodes()) {
437 (*name_map)[node->name()] = 0;
438 }
439 // We provide null destructors for these shared ptrs (except for name_map)
440 // since the caller owns them and doesn't want the scope to destroy them.
441 return Scope(new Scope::Impl(
442 std::shared_ptr<Graph>(graph, [](Graph*) {}),
443 std::shared_ptr<Status>(status, [](Status*) {}),
444 std::shared_ptr<Scope::Impl::NameMap>(name_map),
445 std::shared_ptr<ShapeRefiner>(refiner, [](ShapeRefiner*) {})));
446 }
447 };
448
NewInternalScope(Graph * graph,Status * status,ShapeRefiner * refiner)449 Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
450 return InternalScope::NewScope(graph, status, refiner);
451 }
452
453 } // namespace tensorflow
454