1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 17 #define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 18 19 #include <vector> 20 #include "tensorflow/core/framework/graph.pb.h" 21 #include "tensorflow/core/framework/op.h" 22 #include "tensorflow/core/graph/graph.h" 23 #include "tensorflow/core/graph/node_builder.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/core/stringpiece.h" 26 #include "tensorflow/core/lib/gtl/array_slice.h" 27 28 namespace tensorflow { 29 30 // Given a function like: 31 // namespace ops { 32 // Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { 33 // if (opts.HaveError()) return nullptr; 34 // static const string kOpName = "Identity"; 35 // NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName, 36 // opts.op_registry()); 37 // node_builder.Input(input); 38 // return opts.FinalizeBuilder(&node_builder); 39 // } 40 // } // namespace ops 41 // 42 // // Or, alternatively: 43 // namespace ops { 44 // Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { 45 // static const string kOpName = "Identity"; 46 // return UnaryOp(kOpName, input, opts); 47 // } 48 // } // namespace ops 49 // 50 // You call it like: 51 // GraphDefBuilder b; 52 // using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 53 // Node* na = Const(7, b.opts()); 54 // // Note: WithName() returns a copy, opts is unchanged. 55 // Node* nb = Const(5, b.opts().WithName("control-input")); 56 // Node* nc = Identity(na, b.opts().WithControlInput(nb)); 57 // GraphDef graph_def; 58 // Status status = b.ToGraphDef(&graph_def); 59 // if (!status.ok()) { /* Handle error */ } 60 // 61 // In tests you can skip the status handling via: 62 // GraphDefBuilder b(GraphDefBuilder::kFailImmediately); 63 // ... 64 // b.ToGraphDef(&graph_def); 65 66 class GraphDefBuilder { 67 public: 68 // Options for adding a Node to a Graph. 69 class Options { 70 public: 71 // Sets the Graph (that Nodes will be added to) and the status. The 72 // status may be set to nullptr, in which case errors cause CHECK 73 // failures. The graph and status must outlive *this. 74 Options(Graph* graph, Status* status); 75 ~Options(); 76 77 // Methods for setting options. These are const methods: they 78 // return a copy of *this with the option set. 79 Options WithName(StringPiece name) const; 80 Options WithDevice(StringPiece device) const; 81 Options WithControlInput(Node* control_input) const; 82 Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const; 83 84 // Override the default value for an optional attr. 85 template <class T> WithAttr(StringPiece attr_name,T && value)86 Options WithAttr(StringPiece attr_name, T&& value) const { 87 return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value)); 88 } 89 // Note: overload needed to allow {...} expressions for value. 90 template <class T> WithAttr(StringPiece attr_name,std::initializer_list<T> value)91 Options WithAttr(StringPiece attr_name, 92 std::initializer_list<T> value) const { 93 return WithAttr<std::initializer_list<T>>(attr_name, std::move(value)); 94 } 95 96 // Methods for using options from a function that creates a Node. 97 98 // Returns true if the status associated with *this has an error. 99 // Use this to skip processing that may depend on prior results. HaveError()100 bool HaveError() const { return status_ != nullptr && !status_->ok(); } 101 102 // Returns a string representation of the status associated with *this. 103 // Returns the string `"OK"` if the status doesn't have any error. StatusToString()104 string StatusToString() const { return status_->ToString(); } 105 106 // Given the Op type name, return a name for a node of that type. 107 // Uses the value set in WithName() if that has been called. Otherwise, 108 // returns a name built out of the Op type name. 109 string GetNameForOp(StringPiece op) const; 110 111 // Sets the device, adds control inputs, adds attrs, and calls Finalize(). 112 // If Finalize returns an error, it is saved and this function returns 113 // nullptr. 114 Node* FinalizeBuilder(NodeBuilder* builder) const; 115 116 // Updates the associated status, if any, or calls TF_CHECK_OK if none. 117 void UpdateStatus(const Status& status) const; 118 119 // Accessor op_registry()120 const OpRegistryInterface* op_registry() const { 121 return graph_->op_registry(); 122 } 123 124 private: 125 Options WithNameImpl(StringPiece name); 126 Options WithDeviceImpl(StringPiece device); 127 Options WithControlInputImpl(Node* control_input); 128 Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs); 129 template <class T> WithAttrImpl(StringPiece name,T && value)130 Options WithAttrImpl(StringPiece name, T&& value) { 131 attrs_.emplace_back(string(name), AttrValue()); 132 SetAttrValue(std::forward<T>(value), &attrs_.back().second); 133 return *this; 134 } 135 136 Graph* const graph_; 137 Status* const status_; 138 string name_; 139 string device_; 140 std::vector<Node*> control_inputs_; 141 std::vector<std::pair<string, AttrValue>> attrs_; 142 }; 143 144 // Start building a new graph. 145 explicit GraphDefBuilder( 146 const OpRegistryInterface* op_registry = OpRegistry::Global()) graph_(op_registry)147 : graph_(op_registry), opts_(&graph_, &status_) {} 148 149 // For use in tests, where you want to fail immediately on error instead 150 // of checking the status at the end. 151 enum TestFailImmediatelyType { kFailImmediately }; 152 explicit GraphDefBuilder( 153 TestFailImmediatelyType, 154 const OpRegistryInterface* op_registry = OpRegistry::Global()) graph_(op_registry)155 : graph_(op_registry), opts_(&graph_, nullptr) {} 156 157 // Gets the Options with the associated Graph and Status. opts()158 const Options& opts() const { return opts_; } 159 160 // Once all the nodes have been added, call this to get whether it was 161 // successful, and if so fill *graph_def. 162 Status ToGraphDef(GraphDef* graph_def) const; 163 164 // Adds the function and gradient definitions in `fdef_lib` to this graph's op 165 // registry. Ignores duplicate functions, and returns a bad status if an 166 // imported function differs from an existing function or op with the same 167 // name. AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)168 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { 169 return graph_.AddFunctionLibrary(fdef_lib); 170 } 171 172 // Returns whether a user-defined function with `name` already exists in the 173 // graph. HasFunction(const string & name)174 bool HasFunction(const string& name) { 175 return graph_.flib_def().Find(name) != nullptr; 176 } 177 178 private: 179 Graph graph_; 180 Status status_; 181 Options opts_; 182 }; 183 184 namespace ops { 185 186 // A NodeOut may either be a regular input or back input. Regular 187 // inputs are specified via either a Node* or a Node* and an output 188 // index. Back inputs are specified by a node name, output index, and 189 // output type. 190 typedef NodeBuilder::NodeOut NodeOut; 191 192 // For adding an Op with no inputs to a GraphDefBuilder. 193 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts); 194 195 // For adding an Op with one input to a GraphDefBuilder. 196 Node* UnaryOp(const string& op_name, NodeOut input, 197 const GraphDefBuilder::Options& opts); 198 199 // For adding an Op with two inputs to a GraphDefBuilder. 200 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, 201 const GraphDefBuilder::Options& opts); 202 203 } // namespace ops 204 } // namespace tensorflow 205 206 #endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 207