1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_ 17 #define TENSORFLOW_CC_FRAMEWORK_OPS_H_ 18 19 #include <type_traits> 20 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/tensor.pb.h" 23 #include "tensorflow/core/graph/graph.h" 24 #include "tensorflow/core/lib/hash/hash.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 27 namespace tensorflow { 28 29 /// @defgroup core Core Tensorflow API 30 31 class Output; 32 33 /// @addtogroup core 34 /// @{ 35 36 /// Represents a node in the computation graph. 37 class Operation { 38 public: Operation()39 Operation() : node_(nullptr) {} 40 explicit Operation(Node* n); 41 num_inputs()42 int32 num_inputs() const { return node_->num_inputs(); } input_type(int32 o)43 DataType input_type(int32 o) const { return node_->input_type(o); } 44 Output input(int32 i) const; 45 num_outputs()46 int32 num_outputs() const { return node_->num_outputs(); } output_type(int32 o)47 DataType output_type(int32 o) const { return node_->output_type(o); } 48 Output output(int32 i) const; 49 node()50 Node* node() const { return node_; } 51 52 uint64 hash(int32 index) const; 53 54 bool operator==(const Operation& other) const { return node_ == other.node_; } 55 56 private: 57 typedef std::vector<std::pair<Node*, int32>> Inputs; 58 static Inputs GetInputs(Node* node); 59 60 Inputs inputs_; 61 Node* node_; 62 }; 63 64 /// Represents a tensor value produced by an Operation. 65 class Output { 66 public: 67 Output() = default; Output(Node * n)68 explicit Output(Node* n) : op_(n) {} Output(Node * n,int32 index)69 Output(Node* n, int32 index) : op_(n), index_(index) {} Output(const Operation & op,int32 index)70 Output(const Operation& op, int32 index) : op_(op), index_(index) {} 71 op()72 Operation op() const { return op_; } node()73 Node* node() const { return op().node(); } index()74 int32 index() const { return index_; } type()75 DataType type() const { return op_.output_type(index_); } name()76 string name() const { return strings::StrCat(node()->name(), ":", index()); } 77 bool operator==(const Output& other) const { 78 return op_ == other.op_ && index_ == other.index_; 79 } 80 hash()81 uint64 hash() const { return op_.hash(index_); } 82 83 private: 84 Operation op_ = Operation(nullptr); 85 int32 index_ = 0; 86 }; 87 88 /// Hash class that can be used for e.g. storing Outputs in an unordered_map 89 struct OutputHash { operatorOutputHash90 std::size_t operator()(const Output& output) const { 91 return Hash64Combine(std::hash<Node*>()(output.node()), 92 std::hash<int32>()(output.index())); 93 } 94 }; 95 96 /// Represents a tensor value that can be used as an operand to an Operation. 97 class Input { 98 public: 99 /// Initializer enables constructing an Input object from various kinds of C++ 100 /// constants such as simple primitive constants and nested initializer lists 101 /// representing a multi-dimensional array. Initializer constructors are all 102 /// templates, so the aforementioned kinds of C++ constants can be used to 103 /// construct an Initializer. Initializer stores the value it got constructed 104 /// with in a Tensor object. 105 struct Initializer { 106 /// Construct from a scalar value of an arithmetic type or a type that can 107 /// be converted to a string (eg. a string literal). 108 template <typename T, typename = typename std::enable_if< 109 std::is_arithmetic<T>::value || 110 std::is_convertible<T, string>::value>::type> InitializerInitializer111 Initializer(const T& v) { // NOLINT(runtime/explicit) 112 typedef typename RealType<T>::type RealT; 113 Tensor t(DataTypeToEnum<RealT>::v(), TensorShape()); 114 t.flat<T>()(0) = RealT(v); 115 tensor = t; 116 } 117 InitializerInitializer118 Initializer(const Tensor& t) : tensor(t) {} // NOLINT(runtime/explicit) 119 120 /// Construct from a scalar value and an explicit shape 121 template <typename T, typename = typename std::enable_if< 122 std::is_arithmetic<T>::value || 123 std::is_convertible<T, string>::value>::type> InitializerInitializer124 Initializer(const T& v, const TensorShape& shape) { 125 typedef typename RealType<T>::type RealT; 126 Tensor t(DataTypeToEnum<RealT>::v(), shape); 127 for (int64 i = 0; i < t.NumElements(); ++i) { 128 t.flat<T>()(i) = RealT(v); 129 } 130 tensor = t; 131 } 132 133 /// Construct from a initializer list of scalars (a one-dimensional tensor). 134 template <typename T, typename = typename std::enable_if< 135 std::is_arithmetic<T>::value || 136 std::is_convertible<T, string>::value>::type> InitializerInitializer137 Initializer( 138 const std::initializer_list<T>& v) { // NOLINT(runtime/explicit) 139 typedef typename RealType<T>::type RealT; 140 Tensor t(DataTypeToEnum<RealT>::v(), 141 TensorShape{static_cast<int>(v.size())}); 142 std::copy_n(v.begin(), v.size(), t.flat<RealT>().data()); 143 tensor = t; 144 } 145 146 /// Construct from a initializer list of scalars and an explicit shape. 147 template <typename T, typename = typename std::enable_if< 148 std::is_arithmetic<T>::value || 149 std::is_convertible<T, string>::value>::type> InitializerInitializer150 Initializer(const std::initializer_list<T>& v, const TensorShape& shape) { 151 typedef typename RealType<T>::type RealT; 152 Tensor t(DataTypeToEnum<RealT>::v(), shape); 153 if (t.NumElements() != static_cast<int64>(v.size())) { 154 status = errors::InvalidArgument( 155 "Cannot construct a tensor with ", t.NumElements(), 156 " from an initializer list with ", v.size(), " elements"); 157 return; 158 } 159 std::copy_n(v.begin(), v.size(), t.flat<RealT>().data()); 160 tensor = t; 161 } 162 163 /// Construct a multi-dimensional tensor from a nested initializer 164 /// list. Note that C++ syntax allows nesting of arbitrarily typed 165 /// initializer lists, so such invalid initializers cannot be disallowed at 166 /// compile time. This function performs checks to make sure that the nested 167 /// initializer list is indeed a valid multi-dimensional tensor. 168 Initializer(const std::initializer_list<Initializer>& v); 169 170 // START_SKIP_DOXYGEN 171 template <typename T, bool = std::is_convertible<T, string>::value> 172 struct RealType { 173 typedef string type; 174 }; 175 176 template <typename T> 177 struct RealType<T, false> { 178 typedef T type; 179 }; 180 // END_SKIP_DOXYGEN 181 182 TensorProto AsTensorProto() { 183 TensorProto tensor_proto; 184 if (tensor.NumElements() > 1) { 185 tensor.AsProtoTensorContent(&tensor_proto); 186 } else { 187 tensor.AsProtoField(&tensor_proto); 188 } 189 return tensor_proto; 190 } 191 192 Status status; 193 Tensor tensor; 194 }; 195 196 /// All of Input's constructors are implicit. Input can be implicitly 197 /// constructed from the following objects : 198 /// * Output: This is so that the output of an Operation can be directly used 199 /// as the input to a op wrapper, which takes Inputs. 200 /// * A scalar, or a multi-dimensional tensor specified as a recursive 201 /// initializer list. This enables directly passing constants as 202 /// inputs to op wrappers. 203 /// * A Tensor object. 204 Input(const Output& o) : output_(o) {} // NOLINT(runtime/explicit) 205 206 template <typename T, typename = typename std::enable_if< 207 std::is_arithmetic<T>::value || 208 std::is_convertible<T, string>::value>::type> 209 Input(const T& v) // NOLINT(runtime/explicit) 210 : Input(Initializer(v)) {} 211 212 Input(const Initializer& init) // NOLINT(runtime/explicit) 213 : status_(init.status), 214 tensor_(init.tensor) {} 215 216 Input(const Tensor& t) // NOLINT(runtime/explicit) 217 : status_(Status::OK()), 218 tensor_(t) {} 219 220 Input(const std::initializer_list<Initializer>& 221 init) { // NOLINT(runtime/explicit) 222 for (const auto& i : init) { 223 if (!i.status.ok()) { 224 status_ = i.status; 225 return; 226 } 227 } 228 tensor_ = Initializer(init).tensor; 229 } 230 231 /// Constructor specifying a node name, index and datatype. This should only 232 /// be used for specifying a backward edge, needed by control flow. 233 Input(const string& name, int32 i, DataType dt) 234 : node_name_(name), index_(i), data_type_(dt) {} 235 236 Node* node() const { return output_.node(); } 237 string node_name() const { return node_name_; } 238 int32 index() const { return node_name_.empty() ? output_.index() : index_; } 239 DataType data_type() const { return data_type_; } 240 Status status() const { return status_; } 241 const Tensor& tensor() const { return tensor_; } 242 243 private: 244 Status status_; 245 Output output_ = Output(Operation(nullptr), 0); 246 Tensor tensor_; 247 const string node_name_ = ""; 248 int32 index_ = 0; 249 DataType data_type_ = DT_INVALID; 250 }; 251 252 /// A type for representing the output of ops that produce more than one output, 253 /// or a list of tensors. 254 typedef std::vector<Output> OutputList; 255 256 /// A type for representing the input to ops that require a list of tensors. 257 class InputList { 258 public: 259 /// Implicitly convert a list of outputs to a list of inputs. This is useful 260 /// to write code such as ops::Concat(ops::Split(x, 4)). 261 InputList(const OutputList& out) { // NOLINT(runtime/explicit) 262 for (auto const& x : out) { 263 inputs_.push_back(x); 264 } 265 } 266 267 InputList( 268 const std::initializer_list<Input>& inputs) // NOLINT(runtime/explicit) 269 : inputs_(inputs.begin(), inputs.end()) {} 270 271 InputList(const tensorflow::gtl::ArraySlice<Input>& 272 inputs) // NOLINT(runtime/explicit) 273 : inputs_(inputs.begin(), inputs.end()) {} 274 275 InputList( 276 const std::initializer_list<Output>& out) { // NOLINT(runtime/explicit) 277 for (auto const& x : out) { 278 inputs_.push_back(x); 279 } 280 } 281 282 typename std::vector<Input>::iterator begin() { return inputs_.begin(); } 283 typename std::vector<Input>::iterator end() { return inputs_.end(); } 284 typename std::vector<Input>::const_iterator begin() const { 285 return inputs_.begin(); 286 } 287 typename std::vector<Input>::const_iterator end() const { 288 return inputs_.end(); 289 } 290 291 private: 292 std::vector<Input> inputs_; 293 }; 294 295 /// @} 296 297 } // namespace tensorflow 298 299 #endif // TENSORFLOW_CC_FRAMEWORK_OPS_H_ 300