• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
17 #define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
18 
19 #include <vector>
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/graph/graph.h"
23 #include "tensorflow/core/graph/node_builder.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/core/stringpiece.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 
28 namespace tensorflow {
29 
30 // Given a function like:
31 //   namespace ops {
32 //   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
33 //     if (opts.HaveError()) return nullptr;
34 //     static const string kOpName = "Identity";
35 //     NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
36 //                              opts.op_registry());
37 //     node_builder.Input(input);
38 //     return opts.FinalizeBuilder(&node_builder);
39 //   }
40 //   }  // namespace ops
41 //
42 //   // Or, alternatively:
43 //   namespace ops {
44 //   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
45 //     static const string kOpName = "Identity";
46 //     return UnaryOp(kOpName, input, opts);
47 //   }
48 //   }  // namespace ops
49 //
50 // You call it like:
51 //   GraphDefBuilder b;
52 //   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
53 //   Node* na = Const(7, b.opts());
54 //   // Note: WithName() returns a copy, opts is unchanged.
55 //   Node* nb = Const(5, b.opts().WithName("control-input"));
56 //   Node* nc = Identity(na, b.opts().WithControlInput(nb));
57 //   GraphDef graph_def;
58 //   Status status = b.ToGraphDef(&graph_def);
59 //   if (!status.ok()) { /* Handle error */ }
60 //
61 // In tests you can skip the status handling via:
62 //   GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
63 //   ...
64 //   b.ToGraphDef(&graph_def);
65 
66 class GraphDefBuilder {
67  public:
68   // Options for adding a Node to a Graph.
69   class Options {
70    public:
71     // Sets the Graph (that Nodes will be added to) and the status.  The
72     // status may be set to nullptr, in which case errors cause CHECK
73     // failures.  The graph and status must outlive *this.
74     Options(Graph* graph, Status* status);
75     ~Options();
76 
77     // Methods for setting options.  These are const methods: they
78     // return a copy of *this with the option set.
79     Options WithName(StringPiece name) const;
80     Options WithDevice(StringPiece device) const;
81     Options WithControlInput(Node* control_input) const;
82     Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
83 
84     // Override the default value for an optional attr.
85     template <class T>
WithAttr(StringPiece attr_name,T && value)86     Options WithAttr(StringPiece attr_name, T&& value) const {
87       return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
88     }
89     // Note: overload needed to allow {...} expressions for value.
90     template <class T>
WithAttr(StringPiece attr_name,std::initializer_list<T> value)91     Options WithAttr(StringPiece attr_name,
92                      std::initializer_list<T> value) const {
93       return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
94     }
95 
96     // Methods for using options from a function that creates a Node.
97 
98     // Returns true if the status associated with *this has an error.
99     // Use this to skip processing that may depend on prior results.
HaveError()100     bool HaveError() const { return status_ != nullptr && !status_->ok(); }
101 
102     // Returns a string representation of the status associated with *this.
103     // Returns the string `"OK"` if the status doesn't have any error.
StatusToString()104     string StatusToString() const { return status_->ToString(); }
105 
106     // Given the Op type name, return a name for a node of that type.
107     // Uses the value set in WithName() if that has been called.  Otherwise,
108     // returns a name built out of the Op type name.
109     string GetNameForOp(StringPiece op) const;
110 
111     // Sets the device, adds control inputs, adds attrs, and calls Finalize().
112     // If Finalize returns an error, it is saved and this function returns
113     // nullptr.
114     Node* FinalizeBuilder(NodeBuilder* builder) const;
115 
116     // Updates the associated status, if any, or calls TF_CHECK_OK if none.
117     void UpdateStatus(const Status& status) const;
118 
119     // Accessor
op_registry()120     const OpRegistryInterface* op_registry() const {
121       return graph_->op_registry();
122     }
123 
124    private:
125     Options WithNameImpl(StringPiece name);
126     Options WithDeviceImpl(StringPiece device);
127     Options WithControlInputImpl(Node* control_input);
128     Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
129     template <class T>
WithAttrImpl(StringPiece name,T && value)130     Options WithAttrImpl(StringPiece name, T&& value) {
131       attrs_.emplace_back(string(name), AttrValue());
132       SetAttrValue(std::forward<T>(value), &attrs_.back().second);
133       return *this;
134     }
135 
136     Graph* const graph_;
137     Status* const status_;
138     string name_;
139     string device_;
140     std::vector<Node*> control_inputs_;
141     std::vector<std::pair<string, AttrValue>> attrs_;
142   };
143 
144   // Start building a new graph.
145   explicit GraphDefBuilder(
146       const OpRegistryInterface* op_registry = OpRegistry::Global())
graph_(op_registry)147       : graph_(op_registry), opts_(&graph_, &status_) {}
148 
149   // For use in tests, where you want to fail immediately on error instead
150   // of checking the status at the end.
151   enum TestFailImmediatelyType { kFailImmediately };
152   explicit GraphDefBuilder(
153       TestFailImmediatelyType,
154       const OpRegistryInterface* op_registry = OpRegistry::Global())
graph_(op_registry)155       : graph_(op_registry), opts_(&graph_, nullptr) {}
156 
157   // Gets the Options with the associated Graph and Status.
opts()158   const Options& opts() const { return opts_; }
159 
160   // Once all the nodes have been added, call this to get whether it was
161   // successful, and if so fill *graph_def.
162   Status ToGraphDef(GraphDef* graph_def) const;
163 
164   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
165   // registry. Ignores duplicate functions, and returns a bad status if an
166   // imported function differs from an existing function or op with the same
167   // name.
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)168   Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
169     return graph_.AddFunctionLibrary(fdef_lib);
170   }
171 
172   // Returns whether a user-defined function with `name` already exists in the
173   // graph.
HasFunction(const string & name)174   bool HasFunction(const string& name) {
175     return graph_.flib_def().Find(name) != nullptr;
176   }
177 
178  private:
179   Graph graph_;
180   Status status_;
181   Options opts_;
182 };
183 
184 namespace ops {
185 
186 // A NodeOut may either be a regular input or back input.  Regular
187 // inputs are specified via either a Node* or a Node* and an output
188 // index.  Back inputs are specified by a node name, output index, and
189 // output type.
190 typedef NodeBuilder::NodeOut NodeOut;
191 
192 // For adding an Op with no inputs to a GraphDefBuilder.
193 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);
194 
195 // For adding an Op with one input to a GraphDefBuilder.
196 Node* UnaryOp(const string& op_name, NodeOut input,
197               const GraphDefBuilder::Options& opts);
198 
199 // For adding an Op with two inputs to a GraphDefBuilder.
200 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
201                const GraphDefBuilder::Options& opts);
202 
203 }  // namespace ops
204 }  // namespace tensorflow
205 
206 #endif  // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
207