• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_
17 #define TENSORFLOW_CC_FRAMEWORK_OPS_H_
18 
19 #include <type_traits>
20 
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/lib/hash/hash.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 
27 namespace tensorflow {
28 
29 /// @defgroup core Core Tensorflow API
30 
31 class Output;
32 
33 /// @addtogroup core
34 /// @{
35 
36 /// Represents a node in the computation graph.
37 class Operation {
38  public:
Operation()39   Operation() : node_(nullptr) {}
40   explicit Operation(Node* n);
41 
num_inputs()42   int32 num_inputs() const { return node_->num_inputs(); }
input_type(int32 o)43   DataType input_type(int32 o) const { return node_->input_type(o); }
44   Output input(int32 i) const;
45 
num_outputs()46   int32 num_outputs() const { return node_->num_outputs(); }
output_type(int32 o)47   DataType output_type(int32 o) const { return node_->output_type(o); }
48   Output output(int32 i) const;
49 
node()50   Node* node() const { return node_; }
51 
52   uint64 hash(int32 index) const;
53 
54   bool operator==(const Operation& other) const { return node_ == other.node_; }
55 
56  private:
57   typedef std::vector<std::pair<Node*, int32>> Inputs;
58   static Inputs GetInputs(Node* node);
59 
60   Inputs inputs_;
61   Node* node_;
62 };
63 
64 /// Represents a tensor value produced by an Operation.
65 class Output {
66  public:
67   Output() = default;
Output(Node * n)68   explicit Output(Node* n) : op_(n) {}
Output(Node * n,int32 index)69   Output(Node* n, int32 index) : op_(n), index_(index) {}
Output(const Operation & op,int32 index)70   Output(const Operation& op, int32 index) : op_(op), index_(index) {}
71 
op()72   Operation op() const { return op_; }
node()73   Node* node() const { return op().node(); }
index()74   int32 index() const { return index_; }
type()75   DataType type() const { return op_.output_type(index_); }
name()76   string name() const { return strings::StrCat(node()->name(), ":", index()); }
77   bool operator==(const Output& other) const {
78     return op_ == other.op_ && index_ == other.index_;
79   }
80 
hash()81   uint64 hash() const { return op_.hash(index_); }
82 
83  private:
84   Operation op_ = Operation(nullptr);
85   int32 index_ = 0;
86 };
87 
88 /// Hash class that can be used for e.g. storing Outputs in an unordered_map
89 struct OutputHash {
operatorOutputHash90   std::size_t operator()(const Output& output) const {
91     return Hash64Combine(std::hash<Node*>()(output.node()),
92                          std::hash<int32>()(output.index()));
93   }
94 };
95 
96 /// Represents a tensor value that can be used as an operand to an Operation.
97 class Input {
98  public:
99   /// Initializer enables constructing an Input object from various kinds of C++
100   /// constants such as simple primitive constants and nested initializer lists
101   /// representing a multi-dimensional array. Initializer constructors are all
102   /// templates, so the aforementioned kinds of C++ constants can be used to
103   /// construct an Initializer. Initializer stores the value it got constructed
104   /// with in a Tensor object.
105   struct Initializer {
106     /// Construct from a scalar value of an arithmetic type or a type that can
107     /// be converted to a string (eg. a string literal).
108     template <typename T, typename = typename std::enable_if<
109                               std::is_arithmetic<T>::value ||
110                               std::is_convertible<T, string>::value>::type>
InitializerInitializer111     Initializer(const T& v) {  // NOLINT(runtime/explicit)
112       typedef typename RealType<T>::type RealT;
113       Tensor t(DataTypeToEnum<RealT>::v(), TensorShape());
114       t.flat<T>()(0) = RealT(v);
115       tensor = t;
116     }
117 
InitializerInitializer118     Initializer(const Tensor& t) : tensor(t) {}  // NOLINT(runtime/explicit)
119 
120     /// Construct from a scalar value and an explicit shape
121     template <typename T, typename = typename std::enable_if<
122                               std::is_arithmetic<T>::value ||
123                               std::is_convertible<T, string>::value>::type>
InitializerInitializer124     Initializer(const T& v, const TensorShape& shape) {
125       typedef typename RealType<T>::type RealT;
126       Tensor t(DataTypeToEnum<RealT>::v(), shape);
127       for (int64 i = 0; i < t.NumElements(); ++i) {
128         t.flat<T>()(i) = RealT(v);
129       }
130       tensor = t;
131     }
132 
133     /// Construct from a initializer list of scalars (a one-dimensional tensor).
134     template <typename T, typename = typename std::enable_if<
135                               std::is_arithmetic<T>::value ||
136                               std::is_convertible<T, string>::value>::type>
InitializerInitializer137     Initializer(
138         const std::initializer_list<T>& v) {  // NOLINT(runtime/explicit)
139       typedef typename RealType<T>::type RealT;
140       Tensor t(DataTypeToEnum<RealT>::v(),
141                TensorShape{static_cast<int>(v.size())});
142       std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
143       tensor = t;
144     }
145 
146     /// Construct from a initializer list of scalars and an explicit shape.
147     template <typename T, typename = typename std::enable_if<
148                               std::is_arithmetic<T>::value ||
149                               std::is_convertible<T, string>::value>::type>
InitializerInitializer150     Initializer(const std::initializer_list<T>& v, const TensorShape& shape) {
151       typedef typename RealType<T>::type RealT;
152       Tensor t(DataTypeToEnum<RealT>::v(), shape);
153       if (t.NumElements() != static_cast<int64>(v.size())) {
154         status = errors::InvalidArgument(
155             "Cannot construct a tensor with ", t.NumElements(),
156             " from an initializer list with ", v.size(), " elements");
157         return;
158       }
159       std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
160       tensor = t;
161     }
162 
163     /// Construct a multi-dimensional tensor from a nested initializer
164     /// list. Note that C++ syntax allows nesting of arbitrarily typed
165     /// initializer lists, so such invalid initializers cannot be disallowed at
166     /// compile time. This function performs checks to make sure that the nested
167     /// initializer list is indeed a valid multi-dimensional tensor.
168     Initializer(const std::initializer_list<Initializer>& v);
169 
170     // START_SKIP_DOXYGEN
171     template <typename T, bool = std::is_convertible<T, string>::value>
172     struct RealType {
173       typedef string type;
174     };
175 
176     template <typename T>
177     struct RealType<T, false> {
178       typedef T type;
179     };
180     // END_SKIP_DOXYGEN
181 
182     TensorProto AsTensorProto() {
183       TensorProto tensor_proto;
184       if (tensor.NumElements() > 1) {
185         tensor.AsProtoTensorContent(&tensor_proto);
186       } else {
187         tensor.AsProtoField(&tensor_proto);
188       }
189       return tensor_proto;
190     }
191 
192     Status status;
193     Tensor tensor;
194   };
195 
196   /// All of Input's constructors are implicit. Input can be implicitly
197   /// constructed from the following objects :
198   /// * Output: This is so that the output of an Operation can be directly used
199   ///   as the input to a op wrapper, which takes Inputs.
200   /// * A scalar, or a multi-dimensional tensor specified as a recursive
201   ///   initializer list. This enables directly passing constants as
202   ///   inputs to op wrappers.
203   /// * A Tensor object.
204   Input(const Output& o) : output_(o) {}  // NOLINT(runtime/explicit)
205 
206   template <typename T, typename = typename std::enable_if<
207                             std::is_arithmetic<T>::value ||
208                             std::is_convertible<T, string>::value>::type>
209   Input(const T& v)  // NOLINT(runtime/explicit)
210       : Input(Initializer(v)) {}
211 
212   Input(const Initializer& init)  // NOLINT(runtime/explicit)
213       : status_(init.status),
214         tensor_(init.tensor) {}
215 
216   Input(const Tensor& t)  // NOLINT(runtime/explicit)
217       : status_(Status::OK()),
218         tensor_(t) {}
219 
220   Input(const std::initializer_list<Initializer>&
221             init) {  // NOLINT(runtime/explicit)
222     for (const auto& i : init) {
223       if (!i.status.ok()) {
224         status_ = i.status;
225         return;
226       }
227     }
228     tensor_ = Initializer(init).tensor;
229   }
230 
231   /// Constructor specifying a node name, index and datatype. This should only
232   /// be used for specifying a backward edge, needed by control flow.
233   Input(const string& name, int32 i, DataType dt)
234       : node_name_(name), index_(i), data_type_(dt) {}
235 
236   Node* node() const { return output_.node(); }
237   string node_name() const { return node_name_; }
238   int32 index() const { return node_name_.empty() ? output_.index() : index_; }
239   DataType data_type() const { return data_type_; }
240   Status status() const { return status_; }
241   const Tensor& tensor() const { return tensor_; }
242 
243  private:
244   Status status_;
245   Output output_ = Output(Operation(nullptr), 0);
246   Tensor tensor_;
247   const string node_name_ = "";
248   int32 index_ = 0;
249   DataType data_type_ = DT_INVALID;
250 };
251 
252 /// A type for representing the output of ops that produce more than one output,
253 /// or a list of tensors.
254 typedef std::vector<Output> OutputList;
255 
256 /// A type for representing the input to ops that require a list of tensors.
257 class InputList {
258  public:
259   /// Implicitly convert a list of outputs to a list of inputs. This is useful
260   /// to write code such as ops::Concat(ops::Split(x, 4)).
261   InputList(const OutputList& out) {  // NOLINT(runtime/explicit)
262     for (auto const& x : out) {
263       inputs_.push_back(x);
264     }
265   }
266 
267   InputList(
268       const std::initializer_list<Input>& inputs)  // NOLINT(runtime/explicit)
269       : inputs_(inputs.begin(), inputs.end()) {}
270 
271   InputList(const tensorflow::gtl::ArraySlice<Input>&
272                 inputs)  // NOLINT(runtime/explicit)
273       : inputs_(inputs.begin(), inputs.end()) {}
274 
275   InputList(
276       const std::initializer_list<Output>& out) {  // NOLINT(runtime/explicit)
277     for (auto const& x : out) {
278       inputs_.push_back(x);
279     }
280   }
281 
282   typename std::vector<Input>::iterator begin() { return inputs_.begin(); }
283   typename std::vector<Input>::iterator end() { return inputs_.end(); }
284   typename std::vector<Input>::const_iterator begin() const {
285     return inputs_.begin();
286   }
287   typename std::vector<Input>::const_iterator end() const {
288     return inputs_.end();
289   }
290 
291  private:
292   std::vector<Input> inputs_;
293 };
294 
295 /// @}
296 
297 }  // namespace tensorflow
298 
299 #endif  // TENSORFLOW_CC_FRAMEWORK_OPS_H_
300