• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
17 #define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/core/framework/function.pb.h"
22 #include "tensorflow/core/framework/graph.pb.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/graph/graph.h"
25 #include "tensorflow/core/graph/node_builder.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/core/stringpiece.h"
28 #include "tensorflow/core/lib/gtl/array_slice.h"
29 
30 namespace tensorflow {
31 
32 // Given a function like:
33 //   namespace ops {
34 //   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
35 //     if (opts.HaveError()) return nullptr;
36 //     static const string kOpName = "Identity";
37 //     NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
38 //                              opts.op_registry());
39 //     node_builder.Input(input);
40 //     return opts.FinalizeBuilder(&node_builder);
41 //   }
42 //   }  // namespace ops
43 //
44 //   // Or, alternatively:
45 //   namespace ops {
46 //   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
47 //     static const string kOpName = "Identity";
48 //     return UnaryOp(kOpName, input, opts);
49 //   }
50 //   }  // namespace ops
51 //
52 // You call it like:
53 //   GraphDefBuilder b;
54 //   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
55 //   Node* na = Const(7, b.opts());
56 //   // Note: WithName() returns a copy, opts is unchanged.
57 //   Node* nb = Const(5, b.opts().WithName("control-input"));
58 //   Node* nc = Identity(na, b.opts().WithControlInput(nb));
59 //   GraphDef graph_def;
60 //   Status status = b.ToGraphDef(&graph_def);
61 //   if (!status.ok()) { /* Handle error */ }
62 //
63 // In tests you can skip the status handling via:
64 //   GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
65 //   ...
66 //   b.ToGraphDef(&graph_def);
67 
68 class GraphDefBuilder {
69  public:
70   // Options for adding a Node to a Graph.
71   class Options {
72    public:
73     // Sets the Graph (that Nodes will be added to) and the status.  The
74     // status may be set to nullptr, in which case errors cause CHECK
75     // failures.  The graph and status must outlive *this.
76     Options(Graph* graph, Status* status);
77     ~Options();
78 
79     // Methods for setting options.  These are const methods: they
80     // return a copy of *this with the option set.
81     Options WithName(StringPiece name) const;
82     Options WithDevice(StringPiece device) const;
83     Options WithControlInput(Node* control_input) const;
84     Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
85 
86     // Override the default value for an optional attr.
87     template <class T>
WithAttr(StringPiece attr_name,T && value)88     Options WithAttr(StringPiece attr_name, T&& value) const {
89       return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
90     }
91     // Note: overload needed to allow {...} expressions for value.
92     template <class T>
WithAttr(StringPiece attr_name,std::initializer_list<T> value)93     Options WithAttr(StringPiece attr_name,
94                      std::initializer_list<T> value) const {
95       return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
96     }
97 
98     // Methods for using options from a function that creates a Node.
99 
100     // Returns true if the status associated with *this has an error.
101     // Use this to skip processing that may depend on prior results.
HaveError()102     bool HaveError() const { return status_ != nullptr && !status_->ok(); }
103 
104     // Returns a string representation of the status associated with *this.
105     // Returns the string `"OK"` if the status doesn't have any error.
StatusToString()106     string StatusToString() const { return status_->ToString(); }
107 
108     // Given the Op type name, return a name for a node of that type.
109     // Uses the value set in WithName() if that has been called.  Otherwise,
110     // returns a name built out of the Op type name.
111     string GetNameForOp(StringPiece op) const;
112 
113     // Sets the device, adds control inputs, adds attrs, and calls Finalize().
114     // If Finalize returns an error, it is saved and this function returns
115     // nullptr.
116     Node* FinalizeBuilder(NodeBuilder* builder) const;
117 
118     // Updates the associated status, if any, or calls TF_CHECK_OK if none.
119     void UpdateStatus(const Status& status) const;
120 
121     // Accessor
op_registry()122     const OpRegistryInterface* op_registry() const {
123       return graph_->op_registry();
124     }
125 
126    private:
127     Options WithNameImpl(StringPiece name);
128     Options WithDeviceImpl(StringPiece device);
129     Options WithControlInputImpl(Node* control_input);
130     Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
131     template <class T>
WithAttrImpl(StringPiece name,T && value)132     Options WithAttrImpl(StringPiece name, T&& value) {
133       attrs_.emplace_back(string(name), AttrValue());
134       SetAttrValue(std::forward<T>(value), &attrs_.back().second);
135       return *this;
136     }
137 
138     Graph* const graph_;
139     Status* const status_;
140     string name_;
141     string device_;
142     std::vector<Node*> control_inputs_;
143     std::vector<std::pair<string, AttrValue>> attrs_;
144   };
145 
146   // Start building a new graph.
147   explicit GraphDefBuilder(
148       const OpRegistryInterface* op_registry = OpRegistry::Global())
graph_(op_registry)149       : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, &status_) {}
150 
151   // For use in tests, where you want to fail immediately on error instead
152   // of checking the status at the end.
153   enum TestFailImmediatelyType { kFailImmediately };
154   explicit GraphDefBuilder(
155       TestFailImmediatelyType,
156       const OpRegistryInterface* op_registry = OpRegistry::Global())
graph_(op_registry)157       : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, nullptr) {}
158 
159   // Gets the Options with the associated Graph and Status.
opts()160   const Options& opts() const { return opts_; }
161 
162   // Once all the nodes have been added, call this to get whether it was
163   // successful, and if so fill *graph_def.
164   Status ToGraphDef(GraphDef* graph_def) const;
165 
166   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
167   // registry. Ignores duplicate functions, and returns a bad status if an
168   // imported function differs from an existing function or op with the same
169   // name.
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)170   Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
171     return flib_def_.AddLibrary(fdef_lib);
172   }
173 
174   // Returns whether a user-defined function with `name` already exists in the
175   // graph.
HasFunction(const string & name)176   bool HasFunction(const string& name) {
177     return flib_def_.Find(name) != nullptr;
178   }
179 
180  private:
181   Graph graph_;
182   FunctionLibraryDefinition flib_def_;
183   Status status_;
184   Options opts_;
185 };
186 
187 namespace ops {
188 
189 // A NodeOut may either be a regular input or back input.  Regular
190 // inputs are specified via either a Node* or a Node* and an output
191 // index.  Back inputs are specified by a node name, output index, and
192 // output type.
193 typedef NodeBuilder::NodeOut NodeOut;
194 
195 // For adding an Op with no inputs to a GraphDefBuilder.
196 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);
197 
198 // For adding an Op with one input to a GraphDefBuilder.
199 Node* UnaryOp(const string& op_name, NodeOut input,
200               const GraphDefBuilder::Options& opts);
201 
202 // For adding an Op with two inputs to a GraphDefBuilder.
203 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
204                const GraphDefBuilder::Options& opts);
205 
206 // For adding an Op with three inputs to a GraphDefBuilder.
207 Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c,
208                 const GraphDefBuilder::Options& opts);
209 
210 }  // namespace ops
211 }  // namespace tensorflow
212 
213 #endif  // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
214