• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
18 
19 #include "tensorflow/core/framework/function.pb.h"
20 #include "tensorflow/core/graph/graph.h"
21 #include "tensorflow/core/graph/node_builder.h"
22 #include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 
26 namespace tensorflow {
27 namespace grappler {
28 
29 // Represents the outputs of a vectorized op. Currently, a simple type alias
30 // provided for symmetry with `VectorizerInput`.
31 using VectorizerOutput = std::vector<WrappedTensor>;
32 
33 // Represents the inputs of a vectorized op. Supports iteration, random access,
34 // and retrieval of stacked and unstacked tensor inputs.
35 class VectorizerInput {
36  public:
VectorizerInput(std::vector<WrappedTensor> && inputs)37   VectorizerInput(std::vector<WrappedTensor>&& inputs)
38       : inputs_(std::move(inputs)) {}
39 
40   // Gets the stacked tensor input at position index. Returns an error if
41   // the tensor at index is unstacked. The type T must have a (Node*, int)
42   // constructor.
43   template <class T>
stacked(int index,T * result)44   Status stacked(int index, T* result) const {
45     DCHECK_GE(index, 0);
46     DCHECK_LT(index, size());
47 
48     if (!inputs_[index].stacked) {
49       return errors::InvalidArgument("Expecting input ", index,
50                                      " to be stacked.");
51     }
52     *result = {inputs_[index].node, inputs_[index].output_index};
53     return Status::OK();
54   }
55 
56   // Gets the unstacked tensor input at position index. Returns an error if
57   // the tensor at index is stacked. The type T must have a (Node*, int)
58   // constructor.
59   template <class T>
unstacked(int index,T * result)60   Status unstacked(int index, T* result) const {
61     DCHECK_GE(index, 0);
62     DCHECK_LT(index, size());
63 
64     if (inputs_[index].stacked) {
65       return errors::InvalidArgument("Expecting input ", index,
66                                      " to be unstacked.");
67     }
68     *result = {inputs_[index].node, inputs_[index].output_index};
69     return Status::OK();
70   }
71 
72   // Returns a const reference to the element at specified location index.
at(int index)73   const WrappedTensor& at(int index) const {
74     DCHECK_GE(index, 0);
75     DCHECK_LT(index, size());
76     return inputs_.at(index);
77   }
78 
79   // Returns a const iterator pointing to the first wrapped tensor input.
begin()80   std::vector<WrappedTensor>::const_iterator begin() const {
81     return inputs_.begin();
82   }
83   // Returns a const iterator pointing to the past-the-end wrapped tensor input.
end()84   std::vector<WrappedTensor>::const_iterator end() const {
85     return inputs_.end();
86   }
87 
88   // Returns the number of input tensors.
size()89   size_t size() const { return inputs_.size(); }
90 
91  private:
92   std::vector<WrappedTensor> inputs_;
93 };
94 
95 // Interface for vectorization of TensorFlow operations. See `CastVectorizer`
96 // for an example.
97 class Vectorizer {
98  public:
~Vectorizer()99   virtual ~Vectorizer() {}
100 
101   // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
102   // that produce the same vector output(s) as executing `node`'s op
103   // on elements of `inputs`. The new Node(s) collectively have the
104   // same number of input and output ports as the node being converted.
105   // Adds edges between the newly created nodes and nodes in `inputs`, and adds
106   // mappings to the new nodes' output ports to `outputs`, where the i'th
107   // value in `outputs` corresponds to the i'th output port of the node
108   // to be converted.
109   virtual Status Vectorize(const Node& node, Graph* outer_scope,
110                            VectorizerInput&& inputs,
111                            VectorizerOutput* outputs) = 0;
112 };
113 
114 }  // namespace grappler
115 }  // namespace tensorflow
116 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
117