1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ 17 #define TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ 18 19 #include <functional> 20 #include <vector> 21 #include "tensorflow/core/framework/attr_value_util.h" 22 #include "tensorflow/core/framework/node_def.pb.h" 23 #include "tensorflow/core/framework/node_def_util.h" 24 #include "tensorflow/core/framework/op.h" 25 #include "tensorflow/core/framework/op_def.pb.h" 26 #include "tensorflow/core/framework/types.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 #include "tensorflow/core/lib/strings/strcat.h" 30 31 namespace tensorflow { 32 33 class NodeDefBuilder; 34 typedef std::function<Status(const OpDef&, int, const NodeDef&, 35 NodeDefBuilder*)> 36 FakeInputFunctor; 37 38 // This is a helper for creating a NodeDef. Automatically sets attrs 39 // that can be inferred from the inputs, and uses default values 40 // (where they exist) for unspecified attrs. Example usage: 41 // 42 // NodeDef node_def; 43 // Status status = NodeDefBuilder(node_name, op_name) 44 // .Input(...) 45 // .Attr(...) 46 // .Finalize(&node_def); 47 // if (!status.ok()) return status; 48 // // Use node_def here. 49 class NodeDefBuilder { 50 public: 51 // To specify an output to be consumed by one of the Input() methods below. 52 struct NodeOut { 53 NodeOut(StringPiece n, int i, DataType dt); 54 NodeOut(); // uninitialized, call Reset() before use. 55 void Reset(StringPiece n, int i, DataType dt); 56 string node; 57 int index; 58 DataType data_type; 59 }; 60 61 // Specify the name and the Op (either via an OpDef or the name of 62 // the Op plus a registry) for the NodeDef. Other fields are 63 // specified by calling the methods below. 64 // REQUIRES: The OpDef must satisfy ValidateOpDef(). 65 NodeDefBuilder(StringPiece name, StringPiece op_name, 66 const OpRegistryInterface* op_registry = OpRegistry::Global()); 67 // REQUIRES: in addition, *op_def must outlive *this. 68 NodeDefBuilder(StringPiece name, const OpDef* op_def); 69 70 // You must call one Input() function per input_arg in the Op, 71 // *and in the same order as the input_args appear in the OpDef.* 72 73 // For inputs that take a single tensor. 74 NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt); 75 NodeDefBuilder& Input(const NodeOut& src); 76 77 // For inputs that take a list of tensors. 78 NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list); 79 80 // To create inputs in tests, see fake_input.h. 81 NodeDefBuilder& Input(FakeInputFunctor fake_input); 82 83 // Specify that this node must only run after src_node. 84 NodeDefBuilder& ControlInput(StringPiece src_node); 85 86 // Constrains what devices this node may be scheduled on. 87 NodeDefBuilder& Device(StringPiece device_spec); 88 89 // Sets the attr, if not already set. If already set with a different 90 // value, an error will be returned from Finalize(). 91 NodeDefBuilder& Attr(StringPiece name, const AttrValue& value); 92 NodeDefBuilder& Attr(StringPiece name, StringPiece value); 93 NodeDefBuilder& Attr(StringPiece name, const char* value); 94 NodeDefBuilder& Attr(StringPiece name, int32 value); 95 NodeDefBuilder& Attr(StringPiece name, int64 value); 96 NodeDefBuilder& Attr(StringPiece name, float value); 97 NodeDefBuilder& Attr(StringPiece name, double value); 98 NodeDefBuilder& Attr(StringPiece name, bool value); 99 NodeDefBuilder& Attr(StringPiece name, DataType value); 100 NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value); 101 NodeDefBuilder& Attr(StringPiece name, const Tensor& value); 102 NodeDefBuilder& Attr(StringPiece name, const TensorProto& value); 103 NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value); 104 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value); 105 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value); 106 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value); 107 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value); 108 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value); 109 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value); 110 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<bool> value); 111 NodeDefBuilder& Attr(StringPiece name, const std::vector<bool>& value); 112 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<DataType> value); 113 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<TensorShape> value); 114 NodeDefBuilder& Attr(StringPiece name, 115 gtl::ArraySlice<PartialTensorShape> value); 116 NodeDefBuilder& Attr(StringPiece name, 117 gtl::ArraySlice<TensorShapeProto> value); 118 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<Tensor> value); 119 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<NameAttrList> value); 120 121 template <class T> Attr(StringPiece name,std::initializer_list<T> value)122 NodeDefBuilder& Attr(StringPiece name, std::initializer_list<T> value) { 123 return Attr(name, gtl::ArraySlice<T>(value)); 124 } 125 126 // Finish building the NodeDef, returning any errors or setting 127 // *node_def if none. 128 // WARNING: Not all problems are detected! The resulting NodeDef may 129 // not be valid! Call ValidateNodeDef() from node_def_utils to be sure. 130 Status Finalize(NodeDef* node_def) const; 131 132 // Accessors for the values set in the constructor. node_name()133 const string& node_name() const { return node_def_.name(); } op_def()134 const OpDef& op_def() const { return *op_def_; } 135 136 private: 137 // Called in the constructors. 138 void Initialize(); 139 140 // Get the current ArgDef and advance to the next one. Returns nullptr 141 // if no more inputs are available. 142 const OpDef::ArgDef* NextArgDef(); 143 144 // Returns true if there is still an input_arg available in *op_def_, 145 // otherwise adds to error_ and returns false. 146 bool NextArgAvailable(); 147 148 // These do the main work of the Input() methods. 149 void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node, 150 int src_index, DataType dt); 151 void ListInput(const OpDef::ArgDef* input_arg, 152 gtl::ArraySlice<NodeOut> src_list); 153 154 // Add "src_node:src_index" to the list of inputs in the node_def_. 155 void AddInput(StringPiece src_node, int src_index); 156 157 // Generate an error if you can't pass dt when expected is expected. 158 void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected, 159 DataType dt); 160 161 // If input_arg->is_ref() is true, generate an error if dt is not a ref. 162 void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt); 163 164 // Makes dt a ref type if that is what the input_arg specifies. MaybeAddRef(const OpDef::ArgDef * input_arg,DataType dt)165 DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) { 166 return input_arg->is_ref() ? MakeRefType(dt) : dt; 167 } 168 169 const OpDef* op_def_; 170 NodeDef node_def_; 171 int inputs_specified_; 172 std::vector<string> control_inputs_; 173 std::vector<string> errors_; 174 }; 175 176 } // namespace tensorflow 177 178 #endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ 179