1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/framework/ops.h"
17 #include "tensorflow/cc/framework/scope_internal.h"
18 #include "tensorflow/cc/ops/array_ops.h"
19 #include "tensorflow/cc/ops/const_op.h"
20 #include "tensorflow/cc/ops/math_ops.h"
21 #include "tensorflow/core/graph/node_builder.h"
22 #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
23
24 namespace tensorflow {
25 namespace grappler {
26
27 namespace {
28
29 const char* const kExpandDimsPrefix = "vectorized/expanddims/";
30
31 // Reshapes stacked inputs for broadcast. Stacked inputs have an extra leading
32 // dimension, which may cause automatic broadcasting rules to expand the
33 // input dimensions wrongly when the unstacked shapes have different ranks.
34 // To avoid that, we reshape stacked inputs to the maximum rank they need
35 // to be broadcasted to.
36 //
37 // For example, suppose we have inputs A and B, where A is a stacked tensor with
38 // shape [n, 5] (where n is the stack size) and B is an unstacked tensor with
39 // shape [12, 7, 5]. If we added them directly, tensorflow broadcasting rules
40 // would expand the dimensions of A to [1, n, 5], then (incorrectly) check that
41 // the dimensions n and 7 are compatible, and if so, create an output of shape
42 // [12, 7, 5]. However, correct addition of these inputs would create an output
43 // with shape [n, 12, 7, 5]: we need to manually expand the dimensions of A
44 // *after* the leading dimension, i.e. expand A to the shape [n, 1, 1, 5] before
45 // broadcasting.
ExpandDimsForBroadcast(VectorizerInput * inputs,Graph * g)46 Status ExpandDimsForBroadcast(VectorizerInput* inputs, Graph* g) {
47 Status status;
48 Scope parent = NewInternalScope(g, &status, nullptr);
49 Scope scope = parent.NewSubScope(kExpandDimsPrefix);
50
51 // TODO(rachelim): We can potentially get rid of all these ops if shapes are
52 // known statically
53
54 // Get the stacked rank of each input
55 auto get_stacked_rank = [&scope](const WrappedTensor& input) {
56 Output rank = ops::Rank(scope, Output(input.node, input.output_index));
57
58 if (!input.stacked) {
59 // If the input is unstacked, add 1
60 rank = ops::Add(scope, rank, ops::Const(scope, 1));
61 }
62
63 return rank;
64 };
65
66 Output rank_0 = get_stacked_rank(inputs->at(0));
67 Output rank_1 = get_stacked_rank(inputs->at(1));
68
69 Output max_rank = ops::Maximum(scope, rank_0, rank_1);
70
71 // For all inputs that are stacked, expand dimensions after dim 0.
72 auto expand_dims_if_unstacked =
73 [&scope, &max_rank](const WrappedTensor& tensor, const Output& rank) {
74 if (!tensor.stacked)
75 return WrappedTensor(tensor.node, tensor.output_index, false);
76
77 Output input(tensor.node, tensor.output_index);
78
79 Output rank_diff = ops::Sub(scope, max_rank, rank);
80
81 // [1] * rank_diff
82 Output ones = ops::Fill(
83 scope, ops::ExpandDims(scope, rank_diff, ops::Const(scope, 0)),
84 ops::Const(scope, 1));
85
86 Output shape = ops::Shape(scope, input);
87
88 Output const_vec_1 = ops::Const(scope, {1});
89 // shape[:1]
90 Output concat_pre = ops::StridedSlice(
91 scope, shape, const_vec_1, const_vec_1, const_vec_1,
92 ops::StridedSlice::Attrs().BeginMask(1));
93
94 // shape[1:]
95 Output concat_post = ops::StridedSlice(
96 scope, shape, const_vec_1, const_vec_1, const_vec_1,
97 ops::StridedSlice::Attrs().EndMask(1));
98
99 // tf.concat([shape[:1], ones, shape[1:]], 0)
100 Output new_shape = ops::Concat(scope, {concat_pre, ones, concat_post},
101 ops::Const(scope, 0));
102
103 Output reshaped = ops::Reshape(scope, input, new_shape);
104
105 return WrappedTensor(reshaped.node(), 0, true);
106 };
107
108 *inputs = VectorizerInput({expand_dims_if_unstacked(inputs->at(0), rank_0),
109 expand_dims_if_unstacked(inputs->at(1), rank_1)});
110 return Status::OK();
111 }
112
113 // Vectorization helper for component-wise ops. Since these operations act
114 // component-wise, the vectorized op is the same as the original.
CwiseVectorizeHelper(const Node & node,Graph * outer_scope,VectorizerInput && inputs,VectorizerOutput * outputs)115 Status CwiseVectorizeHelper(const Node& node, Graph* outer_scope,
116 VectorizerInput&& inputs,
117 VectorizerOutput* outputs) {
118 // Add new node with the same op type and attrs as the original node
119 Node* new_node;
120 auto node_builder = NodeBuilder(strings::StrCat("vectorized/", node.name()),
121 node.type_string());
122 for (const auto& input : inputs) {
123 node_builder = node_builder.Input(input.node, input.output_index);
124 }
125 for (const auto& attr_slice : node.attrs()) {
126 node_builder = node_builder.Attr(attr_slice.first, attr_slice.second);
127 }
128 TF_RETURN_IF_ERROR(node_builder.Finalize(outer_scope, &new_node));
129
130 // Add output mappings
131 outputs->push_back({new_node, 0, true});
132 return Status::OK();
133 }
134
135 class UnaryCwiseOpVectorizer : public Vectorizer {
136 public:
Vectorize(const Node & node,Graph * outer_scope,VectorizerInput && inputs,VectorizerOutput * outputs)137 Status Vectorize(const Node& node, Graph* outer_scope,
138 VectorizerInput&& inputs,
139 VectorizerOutput* outputs) override {
140 if (inputs.size() != 1) {
141 return errors::Internal("Failed to vectorize ", node.type_string(),
142 ". The op should have 1 input, but has ",
143 inputs.size());
144 }
145
146 return CwiseVectorizeHelper(node, outer_scope, std::move(inputs), outputs);
147 }
148 };
149
150 class BinaryCwiseOpVectorizer : public Vectorizer {
151 public:
Vectorize(const Node & node,Graph * outer_scope,VectorizerInput && inputs,VectorizerOutput * outputs)152 Status Vectorize(const Node& node, Graph* outer_scope,
153 VectorizerInput&& inputs,
154 VectorizerOutput* outputs) override {
155 if (inputs.size() != 2) {
156 return errors::Internal("Failed to vectorize ", node.type_string(),
157 ". The op should have 2 input, but has ",
158 inputs.size());
159 }
160 // Binary ops support broadcasting
161 TF_RETURN_IF_ERROR(ExpandDimsForBroadcast(&inputs, outer_scope));
162
163 return CwiseVectorizeHelper(node, outer_scope, std::move(inputs), outputs);
164 }
165 };
166
167 // Bitwise unary
168 REGISTER_VECTORIZER("Invert", UnaryCwiseOpVectorizer);
169
170 // Logical unary
171 REGISTER_VECTORIZER("LogicalNot", UnaryCwiseOpVectorizer);
172
173 // Complex unary
174 REGISTER_VECTORIZER("Angle", UnaryCwiseOpVectorizer);
175 REGISTER_VECTORIZER("ComplexAbs", UnaryCwiseOpVectorizer);
176 REGISTER_VECTORIZER("Conj", UnaryCwiseOpVectorizer);
177 REGISTER_VECTORIZER("Imag", UnaryCwiseOpVectorizer);
178 REGISTER_VECTORIZER("Real", UnaryCwiseOpVectorizer);
179
180 // Real unary
181 REGISTER_VECTORIZER("Abs", UnaryCwiseOpVectorizer);
182 REGISTER_VECTORIZER("Acos", UnaryCwiseOpVectorizer);
183 REGISTER_VECTORIZER("Acosh", UnaryCwiseOpVectorizer);
184 REGISTER_VECTORIZER("Asin", UnaryCwiseOpVectorizer);
185 REGISTER_VECTORIZER("Asinh", UnaryCwiseOpVectorizer);
186 REGISTER_VECTORIZER("Atan", UnaryCwiseOpVectorizer);
187 REGISTER_VECTORIZER("Atanh", UnaryCwiseOpVectorizer);
188 REGISTER_VECTORIZER("BesselI0e", UnaryCwiseOpVectorizer);
189 REGISTER_VECTORIZER("BesselI1e", UnaryCwiseOpVectorizer);
190 REGISTER_VECTORIZER("Ceil", UnaryCwiseOpVectorizer);
191 REGISTER_VECTORIZER("Cos", UnaryCwiseOpVectorizer);
192 REGISTER_VECTORIZER("Cosh", UnaryCwiseOpVectorizer);
193 REGISTER_VECTORIZER("Digamma", UnaryCwiseOpVectorizer);
194 REGISTER_VECTORIZER("Elu", UnaryCwiseOpVectorizer);
195 REGISTER_VECTORIZER("Erf", UnaryCwiseOpVectorizer);
196 REGISTER_VECTORIZER("Erfc", UnaryCwiseOpVectorizer);
197 REGISTER_VECTORIZER("Exp", UnaryCwiseOpVectorizer);
198 REGISTER_VECTORIZER("Expm1", UnaryCwiseOpVectorizer);
199 REGISTER_VECTORIZER("Floor", UnaryCwiseOpVectorizer);
200 REGISTER_VECTORIZER("Inv", UnaryCwiseOpVectorizer);
201 REGISTER_VECTORIZER("IsFinite", UnaryCwiseOpVectorizer);
202 REGISTER_VECTORIZER("IsInf", UnaryCwiseOpVectorizer);
203 REGISTER_VECTORIZER("Lgamma", UnaryCwiseOpVectorizer);
204 REGISTER_VECTORIZER("Log", UnaryCwiseOpVectorizer);
205 REGISTER_VECTORIZER("Log1p", UnaryCwiseOpVectorizer);
206 REGISTER_VECTORIZER("Neg", UnaryCwiseOpVectorizer);
207 REGISTER_VECTORIZER("Reciprocal", UnaryCwiseOpVectorizer);
208 REGISTER_VECTORIZER("Relu", UnaryCwiseOpVectorizer);
209 REGISTER_VECTORIZER("Relu6", UnaryCwiseOpVectorizer);
210 REGISTER_VECTORIZER("Rint", UnaryCwiseOpVectorizer);
211 REGISTER_VECTORIZER("Round", UnaryCwiseOpVectorizer);
212 REGISTER_VECTORIZER("Rsqrt", UnaryCwiseOpVectorizer);
213 REGISTER_VECTORIZER("Selu", UnaryCwiseOpVectorizer);
214 REGISTER_VECTORIZER("Sigmoid", UnaryCwiseOpVectorizer);
215 REGISTER_VECTORIZER("Sign", UnaryCwiseOpVectorizer);
216 REGISTER_VECTORIZER("Sin", UnaryCwiseOpVectorizer);
217 REGISTER_VECTORIZER("Sinh", UnaryCwiseOpVectorizer);
218 REGISTER_VECTORIZER("Softplus", UnaryCwiseOpVectorizer);
219 REGISTER_VECTORIZER("Softsign", UnaryCwiseOpVectorizer);
220 REGISTER_VECTORIZER("Sqrt", UnaryCwiseOpVectorizer);
221 REGISTER_VECTORIZER("Square", UnaryCwiseOpVectorizer);
222 REGISTER_VECTORIZER("Tanh", UnaryCwiseOpVectorizer);
223 REGISTER_VECTORIZER("Tan", UnaryCwiseOpVectorizer);
224
225 // Miscellaneous unary
226 REGISTER_VECTORIZER("Cast", UnaryCwiseOpVectorizer);
227 REGISTER_VECTORIZER("Identity", UnaryCwiseOpVectorizer);
228
229 // Bitwise binary
230 REGISTER_VECTORIZER("BitwiseAnd", BinaryCwiseOpVectorizer);
231 REGISTER_VECTORIZER("BitwiseOr", BinaryCwiseOpVectorizer);
232 REGISTER_VECTORIZER("BitwiseXor", BinaryCwiseOpVectorizer);
233 REGISTER_VECTORIZER("LeftShift", BinaryCwiseOpVectorizer);
234 REGISTER_VECTORIZER("RightShift", BinaryCwiseOpVectorizer);
235
236 // Logical binary
237 REGISTER_VECTORIZER("LogicalAnd", BinaryCwiseOpVectorizer);
238 REGISTER_VECTORIZER("LogicalOr", BinaryCwiseOpVectorizer);
239
240 // Real binary
241 REGISTER_VECTORIZER("Add", BinaryCwiseOpVectorizer);
242 REGISTER_VECTORIZER("AddV2", BinaryCwiseOpVectorizer);
243 REGISTER_VECTORIZER("Atan2", BinaryCwiseOpVectorizer);
244 REGISTER_VECTORIZER("Complex", BinaryCwiseOpVectorizer);
245 REGISTER_VECTORIZER("Div", BinaryCwiseOpVectorizer);
246 REGISTER_VECTORIZER("DivNoNan", BinaryCwiseOpVectorizer);
247 REGISTER_VECTORIZER("Equal", BinaryCwiseOpVectorizer);
248 REGISTER_VECTORIZER("FloorDiv", BinaryCwiseOpVectorizer);
249 REGISTER_VECTORIZER("FloorMod", BinaryCwiseOpVectorizer);
250 REGISTER_VECTORIZER("Greater", BinaryCwiseOpVectorizer);
251 REGISTER_VECTORIZER("GreaterEqual", BinaryCwiseOpVectorizer);
252 REGISTER_VECTORIZER("Igamma", BinaryCwiseOpVectorizer);
253 REGISTER_VECTORIZER("Igammac", BinaryCwiseOpVectorizer);
254 REGISTER_VECTORIZER("IgammaGradA", BinaryCwiseOpVectorizer);
255 REGISTER_VECTORIZER("Less", BinaryCwiseOpVectorizer);
256 REGISTER_VECTORIZER("LessEqual", BinaryCwiseOpVectorizer);
257 REGISTER_VECTORIZER("Maximum", BinaryCwiseOpVectorizer);
258 REGISTER_VECTORIZER("Minimum", BinaryCwiseOpVectorizer);
259 REGISTER_VECTORIZER("Mod", BinaryCwiseOpVectorizer);
260 REGISTER_VECTORIZER("Mul", BinaryCwiseOpVectorizer);
261 REGISTER_VECTORIZER("NotEqual", BinaryCwiseOpVectorizer);
262 REGISTER_VECTORIZER("Polygamma", BinaryCwiseOpVectorizer);
263 REGISTER_VECTORIZER("Pow", BinaryCwiseOpVectorizer);
264 REGISTER_VECTORIZER("RealDiv", BinaryCwiseOpVectorizer);
265 REGISTER_VECTORIZER("SquaredDifference", BinaryCwiseOpVectorizer);
266 REGISTER_VECTORIZER("Sub", BinaryCwiseOpVectorizer);
267 REGISTER_VECTORIZER("TruncateDiv", BinaryCwiseOpVectorizer);
268 REGISTER_VECTORIZER("TruncateMod", BinaryCwiseOpVectorizer);
269 REGISTER_VECTORIZER("Zeta", BinaryCwiseOpVectorizer);
270 } // namespace
271 } // namespace grappler
272 } // namespace tensorflow
273