1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 17 #define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 18 19 #include <vector> 20 21 #include "tensorflow/core/framework/function.pb.h" 22 #include "tensorflow/core/framework/graph.pb.h" 23 #include "tensorflow/core/framework/op.h" 24 #include "tensorflow/core/graph/graph.h" 25 #include "tensorflow/core/graph/node_builder.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/lib/core/stringpiece.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 30 namespace tensorflow { 31 32 // Given a function like: 33 // namespace ops { 34 // Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { 35 // if (opts.HaveError()) return nullptr; 36 // static const string kOpName = "Identity"; 37 // NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName, 38 // opts.op_registry()); 39 // node_builder.Input(input); 40 // return opts.FinalizeBuilder(&node_builder); 41 // } 42 // } // namespace ops 43 // 44 // // Or, alternatively: 45 // namespace ops { 46 // Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { 47 // static const string kOpName = "Identity"; 48 // return UnaryOp(kOpName, input, opts); 49 // } 50 // } // namespace ops 51 // 52 // You call it like: 53 // GraphDefBuilder b; 54 // using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 55 // Node* na = Const(7, b.opts()); 56 // // Note: WithName() returns a copy, opts is unchanged. 57 // Node* nb = Const(5, b.opts().WithName("control-input")); 58 // Node* nc = Identity(na, b.opts().WithControlInput(nb)); 59 // GraphDef graph_def; 60 // Status status = b.ToGraphDef(&graph_def); 61 // if (!status.ok()) { /* Handle error */ } 62 // 63 // In tests you can skip the status handling via: 64 // GraphDefBuilder b(GraphDefBuilder::kFailImmediately); 65 // ... 66 // b.ToGraphDef(&graph_def); 67 68 class GraphDefBuilder { 69 public: 70 // Options for adding a Node to a Graph. 71 class Options { 72 public: 73 // Sets the Graph (that Nodes will be added to) and the status. The 74 // status may be set to nullptr, in which case errors cause CHECK 75 // failures. The graph and status must outlive *this. 76 Options(Graph* graph, Status* status); 77 ~Options(); 78 79 // Methods for setting options. These are const methods: they 80 // return a copy of *this with the option set. 81 Options WithName(StringPiece name) const; 82 Options WithDevice(StringPiece device) const; 83 Options WithControlInput(Node* control_input) const; 84 Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const; 85 86 // Override the default value for an optional attr. 87 template <class T> WithAttr(StringPiece attr_name,T && value)88 Options WithAttr(StringPiece attr_name, T&& value) const { 89 return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value)); 90 } 91 // Note: overload needed to allow {...} expressions for value. 92 template <class T> WithAttr(StringPiece attr_name,std::initializer_list<T> value)93 Options WithAttr(StringPiece attr_name, 94 std::initializer_list<T> value) const { 95 return WithAttr<std::initializer_list<T>>(attr_name, std::move(value)); 96 } 97 98 // Methods for using options from a function that creates a Node. 99 100 // Returns true if the status associated with *this has an error. 101 // Use this to skip processing that may depend on prior results. HaveError()102 bool HaveError() const { return status_ != nullptr && !status_->ok(); } 103 104 // Returns a string representation of the status associated with *this. 105 // Returns the string `"OK"` if the status doesn't have any error. StatusToString()106 string StatusToString() const { return status_->ToString(); } 107 108 // Given the Op type name, return a name for a node of that type. 109 // Uses the value set in WithName() if that has been called. Otherwise, 110 // returns a name built out of the Op type name. 111 string GetNameForOp(StringPiece op) const; 112 113 // Sets the device, adds control inputs, adds attrs, and calls Finalize(). 114 // If Finalize returns an error, it is saved and this function returns 115 // nullptr. 116 Node* FinalizeBuilder(NodeBuilder* builder) const; 117 118 // Updates the associated status, if any, or calls TF_CHECK_OK if none. 119 void UpdateStatus(const Status& status) const; 120 121 // Accessor op_registry()122 const OpRegistryInterface* op_registry() const { 123 return graph_->op_registry(); 124 } 125 126 private: 127 Options WithNameImpl(StringPiece name); 128 Options WithDeviceImpl(StringPiece device); 129 Options WithControlInputImpl(Node* control_input); 130 Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs); 131 template <class T> WithAttrImpl(StringPiece name,T && value)132 Options WithAttrImpl(StringPiece name, T&& value) { 133 attrs_.emplace_back(string(name), AttrValue()); 134 SetAttrValue(std::forward<T>(value), &attrs_.back().second); 135 return *this; 136 } 137 138 Graph* const graph_; 139 Status* const status_; 140 string name_; 141 string device_; 142 std::vector<Node*> control_inputs_; 143 std::vector<std::pair<string, AttrValue>> attrs_; 144 }; 145 146 // Start building a new graph. 147 explicit GraphDefBuilder( 148 const OpRegistryInterface* op_registry = OpRegistry::Global()) graph_(op_registry)149 : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, &status_) {} 150 151 // For use in tests, where you want to fail immediately on error instead 152 // of checking the status at the end. 153 enum TestFailImmediatelyType { kFailImmediately }; 154 explicit GraphDefBuilder( 155 TestFailImmediatelyType, 156 const OpRegistryInterface* op_registry = OpRegistry::Global()) graph_(op_registry)157 : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, nullptr) {} 158 159 // Gets the Options with the associated Graph and Status. opts()160 const Options& opts() const { return opts_; } 161 162 // Once all the nodes have been added, call this to get whether it was 163 // successful, and if so fill *graph_def. 164 Status ToGraphDef(GraphDef* graph_def) const; 165 166 // Adds the function and gradient definitions in `fdef_lib` to this graph's op 167 // registry. Ignores duplicate functions, and returns a bad status if an 168 // imported function differs from an existing function or op with the same 169 // name. AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)170 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { 171 return flib_def_.AddLibrary(fdef_lib); 172 } 173 174 // Returns whether a user-defined function with `name` already exists in the 175 // graph. HasFunction(const string & name)176 bool HasFunction(const string& name) { 177 return flib_def_.Find(name) != nullptr; 178 } 179 180 private: 181 Graph graph_; 182 FunctionLibraryDefinition flib_def_; 183 Status status_; 184 Options opts_; 185 }; 186 187 namespace ops { 188 189 // A NodeOut may either be a regular input or back input. Regular 190 // inputs are specified via either a Node* or a Node* and an output 191 // index. Back inputs are specified by a node name, output index, and 192 // output type. 193 typedef NodeBuilder::NodeOut NodeOut; 194 195 // For adding an Op with no inputs to a GraphDefBuilder. 196 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts); 197 198 // For adding an Op with one input to a GraphDefBuilder. 199 Node* UnaryOp(const string& op_name, NodeOut input, 200 const GraphDefBuilder::Options& opts); 201 202 // For adding an Op with two inputs to a GraphDefBuilder. 203 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, 204 const GraphDefBuilder::Options& opts); 205 206 // For adding an Op with three inputs to a GraphDefBuilder. 207 Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c, 208 const GraphDefBuilder::Options& opts); 209 210 } // namespace ops 211 } // namespace tensorflow 212 213 #endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 214