1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ 17 #define TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <unordered_set> 23 #include <vector> 24 25 #include "absl/strings/str_cat.h" 26 #include "tensorflow/cc/framework/ops.h" 27 #include "tensorflow/core/graph/graph_constructor.h" 28 #include "tensorflow/core/lib/core/status.h" 29 #include "tensorflow/core/lib/gtl/array_slice.h" 30 31 namespace tensorflow { 32 33 class Graph; 34 class GraphDef; 35 class NodeBuilder; 36 struct CompositeOpScopes; 37 38 /// @addtogroup core 39 /// @{ 40 41 /// A `Scope` object represents a set of related TensorFlow ops that have the 42 /// same properties such as a common name prefix. 43 /// 44 /// A Scope object is a container for TensorFlow Op properties. Op constructors 45 /// get a Scope object as a mandatory first argument and the constructed op 46 /// acquires the properties in the object. 47 /// 48 /// A simple example: 49 /// 50 /// using namespace ops; 51 /// Scope root = Scope::NewRootScope(); 52 /// auto c1 = Const(root, { {1, 1} }); 53 /// auto m = MatMul(root, c1, { {41}, {1} }); 54 /// GraphDef gdef; 55 /// Status s = root.ToGraphDef(&gdef); 56 /// if (!s.ok()) { ... } 57 /// 58 /// Scope hierarchy: 59 /// 60 /// The Scope class provides various With<> functions that create a new scope. 61 /// The new scope typically has one property changed while other properties are 62 /// inherited from the parent scope. 63 /// NewSubScope(name) method appends `name` to the prefix of names for ops 64 /// created within the scope, and WithOpName() changes the suffix which 65 /// otherwise defaults to the type of the op. 66 /// 67 /// Name examples: 68 /// 69 /// Scope root = Scope::NewRootScope(); 70 /// Scope linear = root.NewSubScope("linear"); 71 /// // W will be named "linear/W" 72 /// auto W = Variable(linear.WithOpName("W"), 73 /// {2, 2}, DT_FLOAT); 74 /// // b will be named "linear/b_3" 75 /// int idx = 3; 76 /// auto b = Variable(linear.WithOpName("b_", idx), 77 /// {2}, DT_FLOAT); 78 /// auto x = Const(linear, {...}); // name: "linear/Const" 79 /// auto m = MatMul(linear, x, W); // name: "linear/MatMul" 80 /// auto r = BiasAdd(linear, m, b); // name: "linear/BiasAdd" 81 /// 82 /// Scope lifetime: 83 /// 84 /// A new scope is created by calling Scope::NewRootScope. This creates some 85 /// resources that are shared by all the child scopes that inherit from this 86 /// scope, directly or transitively. For instance, a new scope creates a new 87 /// Graph object to which operations are added when the new scope or its 88 /// children are used by an Op constructor. The new scope also has a Status 89 /// object which will be used to indicate errors by Op-constructor functions 90 /// called on any child scope. The Op-constructor functions have to check the 91 /// scope's status by calling the ok() method before proceeding to construct the 92 /// op. 93 /// 94 /// Thread safety: 95 /// 96 /// A `Scope` object is NOT thread-safe. Threads cannot concurrently call 97 /// op-constructor functions on the same `Scope` object. 98 class Scope { 99 public: 100 Scope(const Scope& other); 101 ~Scope(); 102 Scope& operator=(const Scope& other); 103 104 // The following functions are for users making graphs. They return brand new 105 // scopes, or scopes derived from an existing scope object. 106 107 /// Return a new scope. 108 /// This creates a new graph and all operations constructed in this graph 109 /// should use the returned object as the "root" scope. 110 static Scope NewRootScope(); 111 112 /// Return a new scope. Ops created with this scope will have 113 /// `name/child_scope_name` as the prefix. The actual name will be unique 114 /// in the current scope. All other properties are inherited from the current 115 /// scope. If `child_scope_name` is empty, the `/` is elided. 116 Scope NewSubScope(const string& child_scope_name) const; 117 118 /// Return a new scope. All ops created within the returned scope will have 119 /// names of the form `name/StrCat(fragments...)[_suffix]` 120 template <typename... Ty> WithOpName(Ty...fragments)121 Scope WithOpName(Ty... fragments) const { 122 return WithOpNameImpl(absl::StrCat(fragments...)); 123 } 124 125 /// Return a new scope. All ops created within the returned scope will have as 126 /// control dependencies the union of operations in the control_deps vector 127 /// and the control dependencies of the current scope. 128 Scope WithControlDependencies( 129 const gtl::ArraySlice<Operation>& control_deps) const; 130 /// Same as above, but convenient to add control dependency on the operation 131 /// producing the control_dep output. 132 Scope WithControlDependencies(const Output& control_dep) const; 133 134 /// Return a new scope. All ops created within the returned scope will have no 135 /// control dependencies on other operations. 136 Scope WithNoControlDependencies() const; 137 138 /// Return a new scope. All ops created within the returned scope will have 139 /// the device field set to 'device'. 140 Scope WithDevice(const string& device) const; 141 142 /// Returns a new scope. All ops created within the returned scope will have 143 /// their assigned device set to `assigned_device`. 144 Scope WithAssignedDevice(const string& assigned_device) const; 145 146 /// Returns a new scope. All ops created within the returned scope will have 147 /// their _XlaCluster attribute set to `xla_cluster`. 148 Scope WithXlaCluster(const string& xla_cluster) const; 149 150 /// Return a new scope. All ops created within the returned scope will be 151 /// co-located on the device where op is placed. 152 /// NOTE: This function is intended to be use internal libraries only for 153 /// controlling placement of ops on to devices. Public use is not encouraged 154 /// because the implementation of device placement is subject to change. 155 Scope ColocateWith(const Operation& op) const; 156 /// Convenience function for above. ColocateWith(const Output & out)157 Scope ColocateWith(const Output& out) const { return ColocateWith(out.op()); } 158 /// Clear all colocation constraints. 159 Scope ClearColocation() const; 160 161 /// Return a new scope. The op-constructor functions taking the returned scope 162 /// as the scope argument will exit as soon as an error is detected, instead 163 /// of setting the status on the scope. 164 Scope ExitOnError() const; 165 166 /// Return a new scope. All ops created with the new scope will have 167 /// kernel_label as the value for their '_kernel' attribute; 168 Scope WithKernelLabel(const string& kernel_label) const; 169 170 // The following functions are for scope object consumers. 171 172 /// Return a unique name, using default_name if an op name has not been 173 /// specified. 174 string GetUniqueNameForOp(const string& default_name) const; 175 176 /// Update the status on this scope. 177 /// Note: The status object is shared between all children of this scope. 178 /// If the resulting status is not Status::OK() and exit_on_error_ is set on 179 /// this scope, this function exits by calling LOG(FATAL). 180 void UpdateStatus(const Status s) const; 181 182 // START_SKIP_DOXYGEN 183 184 /// Update the builder with properties accumulated in this scope. Does not set 185 /// status(). 186 // TODO(skyewm): NodeBuilder is not part of public API 187 void UpdateBuilder(NodeBuilder* builder) const; 188 // END_SKIP_DOXYGEN 189 190 CompositeOpScopes GetCompositeOpScopes(const string& composite_op_name) const; 191 192 bool ok() const; 193 194 // TODO(skyewm): Graph is not part of public API 195 Graph* graph() const; 196 197 // TODO(skyewm): Graph is not part of public API 198 std::shared_ptr<Graph> graph_as_shared_ptr() const; 199 200 Status status() const; 201 202 /// If status() is Status::OK(), convert the Graph object stored in this scope 203 /// to a GraphDef proto and return Status::OK(). Otherwise, return the error 204 /// status as is without performing GraphDef conversion. 205 Status ToGraphDef(GraphDef* gdef) const; 206 207 // START_SKIP_DOXYGEN 208 209 /// If status() is Status::OK(), construct a Graph object using `opts` as the 210 /// GraphConstructorOptions, and return Status::OK if graph construction was 211 /// successful. Otherwise, return the error status. 212 // TODO(josh11b, keveman): Make this faster; right now it converts 213 // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds 214 // edges from the source and to the sink node, resolves back edges 215 // by name), and makes sure the resulting graph is valid. 216 Status ToGraph( 217 Graph* g, GraphConstructorOptions opts = GraphConstructorOptions{}) const; 218 219 // Calls AddNode() using this scope's ShapeRefiner. This exists in the public 220 // API to prevent custom op wrappers from needing access to shape_refiner.h or 221 // scope_internal.h. 222 // TODO(skyewm): remove this from public API 223 Status DoShapeInference(Node* node) const; 224 225 // Creates a new root scope that causes all DoShapeInference() calls to return 226 // Status::OK() (on the returned scope and any subscopes). Used for testing. 227 // TODO(skyewm): fix tests that still require this and eventually remove, or 228 // at least remove from public API 229 static Scope DisabledShapeInferenceScope(); 230 // END_SKIP_DOXYGEN 231 232 const std::vector<Operation>& control_deps() const; 233 234 // START_SKIP_DOXYGEN 235 class Impl; impl()236 Impl* impl() { return impl_.get(); } impl()237 const Impl* impl() const { return impl_.get(); } 238 // END_SKIP_DOXYGEN 239 240 private: 241 Scope WithOpNameImpl(const string& op_name) const; 242 243 friend class InternalScope; 244 std::unique_ptr<Impl> impl_; 245 explicit Scope(Impl*); 246 }; 247 248 /// A helper struct to hold the scopes that would be used by a function 249 /// constructing a composite op. 250 struct CompositeOpScopes { 251 /// Scope to be used for creating the local ops (primitive or other composite 252 /// ops). 253 Scope child; 254 /// Scope to be used for creating the last op. 255 Scope last; 256 }; 257 258 /// @} 259 260 } // namespace tensorflow 261 262 #endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ 263