• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_
17 #define TENSORFLOW_CC_FRAMEWORK_OPS_H_
18 
19 #include <type_traits>
20 
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/lib/hash/hash.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 
27 namespace tensorflow {
28 
29 /// @defgroup core Core Tensorflow API
30 
31 class Output;
32 
33 /// @addtogroup core
34 /// @{
35 
36 /// Represents a node in the computation graph.
37 class Operation {
38  public:
Operation()39   Operation() : node_(nullptr) {}
40   explicit Operation(Node* n);
41 
num_inputs()42   int32 num_inputs() const { return node_->num_inputs(); }
input_type(int32_t o)43   DataType input_type(int32_t o) const { return node_->input_type(o); }
44   Output input(int32_t i) const;
45 
num_outputs()46   int32 num_outputs() const { return node_->num_outputs(); }
output_type(int32_t o)47   DataType output_type(int32_t o) const { return node_->output_type(o); }
48   Output output(int32_t i) const;
49 
node()50   Node* node() const { return node_; }
51 
52   uint64 hash(int32_t index) const;
53 
54   bool operator==(const Operation& other) const { return node_ == other.node_; }
55 
56  private:
57   typedef std::vector<std::pair<Node*, int32>> Inputs;
58   static Inputs GetInputs(Node* node);
59 
60   Inputs inputs_;
61   Node* node_;
62 };
63 
64 /// Represents a tensor value produced by an Operation.
65 class Output {
66  public:
67   Output() = default;
Output(Node * n)68   explicit Output(Node* n) : op_(n) {}
Output(Node * n,int32_t index)69   Output(Node* n, int32_t index) : op_(n), index_(index) {}
Output(const Operation & op,int32_t index)70   Output(const Operation& op, int32_t index) : op_(op), index_(index) {}
71 
op()72   Operation op() const { return op_; }
node()73   Node* node() const { return op().node(); }
index()74   int32 index() const { return index_; }
type()75   DataType type() const { return op_.output_type(index_); }
name()76   std::string name() const {
77     return strings::StrCat(node()->name(), ":", index());
78   }
79   bool operator==(const Output& other) const {
80     return op_ == other.op_ && index_ == other.index_;
81   }
82 
hash()83   uint64 hash() const { return op_.hash(index_); }
84 
85  private:
86   Operation op_ = Operation(nullptr);
87   int32 index_ = 0;
88 };
89 
90 /// Hash class that can be used for e.g. storing Outputs in an unordered_map
91 struct OutputHash {
operatorOutputHash92   std::size_t operator()(const Output& output) const {
93     return Hash64Combine(std::hash<Node*>()(output.node()),
94                          std::hash<int32>()(output.index()));
95   }
96 };
97 
98 /// Represents a tensor value that can be used as an operand to an Operation.
99 class Input {
100  public:
101   /// Initializer enables constructing an Input object from various kinds of C++
102   /// constants such as simple primitive constants and nested initializer lists
103   /// representing a multi-dimensional array. Initializer constructors are all
104   /// templates, so the aforementioned kinds of C++ constants can be used to
105   /// construct an Initializer. Initializer stores the value it got constructed
106   /// with in a Tensor object.
107   struct Initializer {
108     /// Construct from a scalar value of an arithmetic type or a type that can
109     /// be converted to a string (eg. a string literal).
110     template <typename T, typename = typename std::enable_if<
111                               std::is_arithmetic<T>::value ||
112                               std::is_convertible<T, std::string>::value>::type>
InitializerInitializer113     Initializer(const T& v) {  // NOLINT(runtime/explicit)
114       typedef typename RealType<T>::type RealT;
115       Tensor t(DataTypeToEnum<RealT>::v(), TensorShape());
116       t.flat<RealT>()(0) = RealT(v);
117       tensor = t;
118     }
119 
InitializerInitializer120     Initializer(const Tensor& t) : tensor(t) {}  // NOLINT(runtime/explicit)
121 
122     /// Construct from a scalar value and an explicit shape
123     template <typename T, typename = typename std::enable_if<
124                               std::is_arithmetic<T>::value ||
125                               std::is_convertible<T, std::string>::value>::type>
InitializerInitializer126     Initializer(const T& v, const TensorShape& shape) {
127       typedef typename RealType<T>::type RealT;
128       Tensor t(DataTypeToEnum<RealT>::v(), shape);
129       for (int64_t i = 0; i < t.NumElements(); ++i) {
130         t.flat<RealT>()(i) = RealT(v);
131       }
132       tensor = t;
133     }
134 
135     /// Construct from a initializer list of scalars (a one-dimensional tensor).
136     template <typename T, typename = typename std::enable_if<
137                               std::is_arithmetic<T>::value ||
138                               std::is_convertible<T, std::string>::value>::type>
InitializerInitializer139     Initializer(
140         const std::initializer_list<T>& v) {  // NOLINT(runtime/explicit)
141       typedef typename RealType<T>::type RealT;
142       Tensor t(DataTypeToEnum<RealT>::v(),
143                TensorShape{static_cast<int>(v.size())});
144       std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
145       tensor = t;
146     }
147 
148     /// Construct from a initializer list of scalars and an explicit shape.
149     template <typename T, typename = typename std::enable_if<
150                               std::is_arithmetic<T>::value ||
151                               std::is_convertible<T, std::string>::value>::type>
InitializerInitializer152     Initializer(const std::initializer_list<T>& v, const TensorShape& shape) {
153       typedef typename RealType<T>::type RealT;
154       Tensor t(DataTypeToEnum<RealT>::v(), shape);
155       if (t.NumElements() != static_cast<int64>(v.size())) {
156         status = errors::InvalidArgument(
157             "Cannot construct a tensor with ", t.NumElements(),
158             " from an initializer list with ", v.size(), " elements");
159         return;
160       }
161       std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
162       tensor = t;
163     }
164 
165     /// Construct a multi-dimensional tensor from a nested initializer
166     /// list. Note that C++ syntax allows nesting of arbitrarily typed
167     /// initializer lists, so such invalid initializers cannot be disallowed at
168     /// compile time. This function performs checks to make sure that the nested
169     /// initializer list is indeed a valid multi-dimensional tensor.
170     Initializer(const std::initializer_list<Initializer>& v);
171 
172     // START_SKIP_DOXYGEN
173     template <typename T, bool = std::is_convertible<T, std::string>::value>
174     struct RealType {
175       typedef tstring type;
176     };
177 
178     template <typename T>
179     struct RealType<T, false> {
180       typedef T type;
181     };
182     // END_SKIP_DOXYGEN
183 
184     TensorProto AsTensorProto() {
185       TensorProto tensor_proto;
186       if (tensor.NumElements() > 1) {
187         tensor.AsProtoTensorContent(&tensor_proto);
188       } else {
189         tensor.AsProtoField(&tensor_proto);
190       }
191       return tensor_proto;
192     }
193 
194     Status status;
195     Tensor tensor;
196   };
197 
198   /// All of Input's constructors are implicit. Input can be implicitly
199   /// constructed from the following objects :
200   /// * Output: This is so that the output of an Operation can be directly used
201   ///   as the input to a op wrapper, which takes Inputs.
202   /// * A scalar, or a multi-dimensional tensor specified as a recursive
203   ///   initializer list. This enables directly passing constants as
204   ///   inputs to op wrappers.
205   /// * A Tensor object.
206   Input(const Output& o) : output_(o) {}  // NOLINT(runtime/explicit)
207 
208   template <typename T, typename = typename std::enable_if<
209                             std::is_arithmetic<T>::value ||
210                             std::is_convertible<T, std::string>::value>::type>
211   Input(const T& v)  // NOLINT(runtime/explicit)
212       : Input(Initializer(v)) {}
213 
214   Input(const Initializer& init)  // NOLINT(runtime/explicit)
215       : status_(init.status),
216         tensor_(init.tensor) {}
217 
218   Input(const Tensor& t)  // NOLINT(runtime/explicit)
219       : status_(Status::OK()),
220         tensor_(t) {}
221 
222   Input(const std::initializer_list<Initializer>&
223             init) {  // NOLINT(runtime/explicit)
224     for (const auto& i : init) {
225       if (!i.status.ok()) {
226         status_ = i.status;
227         return;
228       }
229     }
230     tensor_ = Initializer(init).tensor;
231   }
232 
233   /// Constructor specifying a node name, index and datatype. This should only
234   /// be used for specifying a backward edge, needed by control flow.
235   Input(const std::string& name, int32_t i, DataType dt)
236       : node_name_(name), index_(i), data_type_(dt) {}
237 
238   Node* node() const { return output_.node(); }
239   std::string node_name() const { return node_name_; }
240   int32 index() const { return node_name_.empty() ? output_.index() : index_; }
241   DataType data_type() const { return data_type_; }
242   Status status() const { return status_; }
243   const Tensor& tensor() const { return tensor_; }
244 
245  private:
246   Status status_;
247   Output output_ = Output(Operation(nullptr), 0);
248   Tensor tensor_;
249   const std::string node_name_ = "";
250   int32 index_ = 0;
251   DataType data_type_ = DT_INVALID;
252 };
253 
254 /// A type for representing the output of ops that produce more than one output,
255 /// or a list of tensors.
256 typedef std::vector<Output> OutputList;
257 
258 /// A type for representing the input to ops that require a list of tensors.
259 class InputList {
260  public:
261   /// Implicitly convert a list of outputs to a list of inputs. This is useful
262   /// to write code such as ops::Concat(ops::Split(x, 4)).
263   InputList(const OutputList& out) {  // NOLINT(runtime/explicit)
264     for (auto const& x : out) {
265       inputs_.push_back(x);
266     }
267   }
268 
269   InputList(
270       const std::initializer_list<Input>& inputs)  // NOLINT(runtime/explicit)
271       : inputs_(inputs.begin(), inputs.end()) {}
272 
273   InputList(const tensorflow::gtl::ArraySlice<Input>&
274                 inputs)  // NOLINT(runtime/explicit)
275       : inputs_(inputs.begin(), inputs.end()) {}
276 
277   InputList(
278       const std::initializer_list<Output>& out) {  // NOLINT(runtime/explicit)
279     for (auto const& x : out) {
280       inputs_.push_back(x);
281     }
282   }
283 
284   typename std::vector<Input>::iterator begin() { return inputs_.begin(); }
285   typename std::vector<Input>::iterator end() { return inputs_.end(); }
286   typename std::vector<Input>::const_iterator begin() const {
287     return inputs_.begin();
288   }
289   typename std::vector<Input>::const_iterator end() const {
290     return inputs_.end();
291   }
292 
293  private:
294   std::vector<Input> inputs_;
295 };
296 
297 /// @}
298 
299 }  // namespace tensorflow
300 
301 #endif  // TENSORFLOW_CC_FRAMEWORK_OPS_H_
302