• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/framework/ops.h"
17 #include "tensorflow/cc/framework/scope_internal.h"
18 #include "tensorflow/cc/ops/array_ops.h"
19 #include "tensorflow/cc/ops/const_op.h"
20 #include "tensorflow/cc/ops/math_ops.h"
21 #include "tensorflow/core/graph/node_builder.h"
22 #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 
27 namespace {
28 
29 const char* const kExpandDimsPrefix = "vectorized/expanddims/";
30 
31 // Reshapes stacked inputs for broadcast. Stacked inputs have an extra leading
32 // dimension, which may cause automatic broadcasting rules to expand the
33 // input dimensions wrongly when the unstacked shapes have different ranks.
34 // To avoid that, we reshape stacked inputs to the maximum rank they need
35 // to be broadcasted to.
36 //
37 // For example, suppose we have inputs A and B, where A is a stacked tensor with
38 // shape [n, 5] (where n is the stack size) and B is an unstacked tensor with
39 // shape [12, 7, 5]. If we added them directly, tensorflow broadcasting rules
40 // would expand the dimensions of A to [1, n, 5], then (incorrectly) check that
41 // the dimensions n and 7 are compatible, and if so, create an output of shape
42 // [12, 7, 5]. However, correct addition of these inputs would create an output
43 // with shape [n, 12, 7, 5]: we need to manually expand the dimensions of A
44 // *after* the leading dimension, i.e. expand A to the shape [n, 1, 1, 5] before
45 // broadcasting.
ExpandDimsForBroadcast(VectorizerInput * inputs,Graph * g)46 Status ExpandDimsForBroadcast(VectorizerInput* inputs, Graph* g) {
47   Status status;
48   Scope parent = NewInternalScope(g, &status, nullptr);
49   Scope scope = parent.NewSubScope(kExpandDimsPrefix);
50 
51   // TODO(rachelim): We can potentially get rid of all these ops if shapes are
52   // known statically
53 
54   // Get the stacked rank of each input
55   auto get_stacked_rank = [&scope](const WrappedTensor& input) {
56     Output rank = ops::Rank(scope, Output(input.node, input.output_index));
57 
58     if (!input.stacked) {
59       // If the input is unstacked, add 1
60       rank = ops::Add(scope, rank, ops::Const(scope, 1));
61     }
62 
63     return rank;
64   };
65 
66   Output rank_0 = get_stacked_rank(inputs->at(0));
67   Output rank_1 = get_stacked_rank(inputs->at(1));
68 
69   Output max_rank = ops::Maximum(scope, rank_0, rank_1);
70 
71   // For all inputs that are stacked, expand dimensions after dim 0.
72   auto expand_dims_if_unstacked =
73       [&scope, &max_rank](const WrappedTensor& tensor, const Output& rank) {
74         if (!tensor.stacked)
75           return WrappedTensor(tensor.node, tensor.output_index, false);
76 
77         Output input(tensor.node, tensor.output_index);
78 
79         Output rank_diff = ops::Sub(scope, max_rank, rank);
80 
81         // [1] * rank_diff
82         Output ones = ops::Fill(
83             scope, ops::ExpandDims(scope, rank_diff, ops::Const(scope, 0)),
84             ops::Const(scope, 1));
85 
86         Output shape = ops::Shape(scope, input);
87 
88         Output const_vec_1 = ops::Const(scope, {1});
89         // shape[:1]
90         Output concat_pre = ops::StridedSlice(
91             scope, shape, const_vec_1, const_vec_1, const_vec_1,
92             ops::StridedSlice::Attrs().BeginMask(1));
93 
94         // shape[1:]
95         Output concat_post = ops::StridedSlice(
96             scope, shape, const_vec_1, const_vec_1, const_vec_1,
97             ops::StridedSlice::Attrs().EndMask(1));
98 
99         // tf.concat([shape[:1], ones, shape[1:]], 0)
100         Output new_shape = ops::Concat(scope, {concat_pre, ones, concat_post},
101                                        ops::Const(scope, 0));
102 
103         Output reshaped = ops::Reshape(scope, input, new_shape);
104 
105         return WrappedTensor(reshaped.node(), 0, true);
106       };
107 
108   *inputs = VectorizerInput({expand_dims_if_unstacked(inputs->at(0), rank_0),
109                              expand_dims_if_unstacked(inputs->at(1), rank_1)});
110   return Status::OK();
111 }
112 
113 // Vectorization helper for component-wise ops. Since these operations act
114 // component-wise, the vectorized op is the same as the original.
CwiseVectorizeHelper(const Node & node,Graph * outer_scope,VectorizerInput && inputs,VectorizerOutput * outputs)115 Status CwiseVectorizeHelper(const Node& node, Graph* outer_scope,
116                             VectorizerInput&& inputs,
117                             VectorizerOutput* outputs) {
118   // Add new node with the same op type and attrs as the original node
119   Node* new_node;
120   auto node_builder = NodeBuilder(strings::StrCat("vectorized/", node.name()),
121                                   node.type_string());
122   for (const auto& input : inputs) {
123     node_builder = node_builder.Input(input.node, input.output_index);
124   }
125   for (const auto& attr_slice : node.attrs()) {
126     node_builder = node_builder.Attr(attr_slice.first, attr_slice.second);
127   }
128   TF_RETURN_IF_ERROR(node_builder.Finalize(outer_scope, &new_node));
129 
130   // Add output mappings
131   outputs->push_back({new_node, 0, true});
132   return Status::OK();
133 }
134 
135 class UnaryCwiseOpVectorizer : public Vectorizer {
136  public:
Vectorize(const Node & node,Graph * outer_scope,VectorizerInput && inputs,VectorizerOutput * outputs)137   Status Vectorize(const Node& node, Graph* outer_scope,
138                    VectorizerInput&& inputs,
139                    VectorizerOutput* outputs) override {
140     if (inputs.size() != 1) {
141       return errors::Internal("Failed to vectorize ", node.type_string(),
142                               ". The op should have 1 input, but has ",
143                               inputs.size());
144     }
145 
146     return CwiseVectorizeHelper(node, outer_scope, std::move(inputs), outputs);
147   }
148 };
149 
150 class BinaryCwiseOpVectorizer : public Vectorizer {
151  public:
Vectorize(const Node & node,Graph * outer_scope,VectorizerInput && inputs,VectorizerOutput * outputs)152   Status Vectorize(const Node& node, Graph* outer_scope,
153                    VectorizerInput&& inputs,
154                    VectorizerOutput* outputs) override {
155     if (inputs.size() != 2) {
156       return errors::Internal("Failed to vectorize ", node.type_string(),
157                               ". The op should have 2 input, but has ",
158                               inputs.size());
159     }
160     // Binary ops support broadcasting
161     TF_RETURN_IF_ERROR(ExpandDimsForBroadcast(&inputs, outer_scope));
162 
163     return CwiseVectorizeHelper(node, outer_scope, std::move(inputs), outputs);
164   }
165 };
166 
// Registry of component-wise ops that can be vectorized by the classes above.
// Each REGISTER_VECTORIZER call maps a TF op name to the unary or binary
// cwise vectorizer; ops not listed here are not handled by this file.

// Bitwise unary
REGISTER_VECTORIZER("Invert", UnaryCwiseOpVectorizer);

// Logical unary
REGISTER_VECTORIZER("LogicalNot", UnaryCwiseOpVectorizer);

// Complex unary
REGISTER_VECTORIZER("Angle", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("ComplexAbs", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Conj", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Imag", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Real", UnaryCwiseOpVectorizer);

// Real unary
REGISTER_VECTORIZER("Abs", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Acos", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Acosh", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Asin", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Asinh", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Atan", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Atanh", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("BesselI0e", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("BesselI1e", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Ceil", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Cos", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Cosh", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Digamma", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Elu", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Erf", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Erfc", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Exp", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Expm1", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Floor", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Inv", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("IsFinite", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("IsInf", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Lgamma", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Log", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Log1p", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Neg", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Reciprocal", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Relu", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Relu6", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Rint", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Round", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Rsqrt", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Selu", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Sigmoid", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Sign", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Sin", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Sinh", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Softplus", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Softsign", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Sqrt", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Square", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Tanh", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Tan", UnaryCwiseOpVectorizer);

// Miscellaneous unary
REGISTER_VECTORIZER("Cast", UnaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Identity", UnaryCwiseOpVectorizer);

// Bitwise binary
REGISTER_VECTORIZER("BitwiseAnd", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("BitwiseOr", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("BitwiseXor", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("LeftShift", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("RightShift", BinaryCwiseOpVectorizer);

// Logical binary
REGISTER_VECTORIZER("LogicalAnd", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("LogicalOr", BinaryCwiseOpVectorizer);

// Real binary (these support broadcasting, handled via ExpandDimsForBroadcast)
REGISTER_VECTORIZER("Add", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("AddV2", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Atan2", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Complex", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Div", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("DivNoNan", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Equal", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("FloorDiv", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("FloorMod", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Greater", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("GreaterEqual", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Igamma", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Igammac", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("IgammaGradA", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Less", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("LessEqual", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Maximum", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Minimum", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Mod", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Mul", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("NotEqual", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Polygamma", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Pow", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("RealDiv", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("SquaredDifference", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Sub", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("TruncateDiv", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("TruncateMod", BinaryCwiseOpVectorizer);
REGISTER_VECTORIZER("Zeta", BinaryCwiseOpVectorizer);
270 }  // namespace
271 }  // namespace grappler
272 }  // namespace tensorflow
273