1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
17
18 #include "third_party/eigen3/Eigen/Core"
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/attr_value_util.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/grappler/clusters/utils.h"
25
26 namespace tensorflow {
27 namespace grappler {
28
29 constexpr int kOpsPerMac = 2;
30 constexpr char kGuaranteeConst[] = "GuaranteeConst";
31 constexpr char kConv2d[] = "Conv2D";
32 constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
33 constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
34 constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
35 constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
36 constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
37 "DepthwiseConv2dNativeBackpropFilter";
38 constexpr char kDepthwiseConv2dNativeBackpropInput[] =
39 "DepthwiseConv2dNativeBackpropInput";
40 constexpr char kMatMul[] = "MatMul";
41 constexpr char kSparseMatMul[] = "SparseMatMul";
42 constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
43 constexpr char kPlaceholder[] = "Placeholder";
44 constexpr char kIdentity[] = "Identity";
45 constexpr char kIdentityN[] = "IdentityN";
46 constexpr char kRefIdentity[] = "RefIdentity";
47 constexpr char kNoOp[] = "NoOp";
48 constexpr char kReshape[] = "Reshape";
49 constexpr char kSqueeze[] = "Squeeze";
50 constexpr char kRecv[] = "_Recv";
51 constexpr char kSend[] = "_Send";
52 constexpr char kBatchMatMul[] = "BatchMatMul";
53 constexpr char kRank[] = "Rank";
54 constexpr char kShape[] = "Shape";
55 constexpr char kShapeN[] = "ShapeN";
56 constexpr char kSize[] = "Size";
57 constexpr char kStopGradient[] = "StopGradient";
58 constexpr char kPreventGradient[] = "PreventGradient";
59 constexpr char kGather[] = "Gather";
60 constexpr char kGatherV2[] = "GatherV2";
61 constexpr char kSlice[] = "Slice";
62 constexpr char kMaxPool[] = "MaxPool";
63 constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
64 constexpr char kAvgPool[] = "AvgPool";
65 constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
66 constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
67 constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
68 constexpr char kQuantizedMatMul[] = "QuantizedMatMul";
69 constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
70 // Persistent ops.
71 constexpr char kConst[] = "Const";
72 constexpr char kVariable[] = "Variable";
73 constexpr char kVariableV2[] = "VariableV2";
74 constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
75 constexpr char kVarHandleOp[] = "VarHandleOp";
76 constexpr char kReadVariableOp[] = "ReadVariableOp";
77
78 static const Costs::Duration kMinComputeTime(1);
79
80 namespace {
81
82 string GetDataFormat(const OpInfo& op_info) {
83 string data_format = "NHWC"; // Default format.
84 if (op_info.attr().find("data_format") != op_info.attr().end()) {
85 data_format = op_info.attr().at("data_format").s();
86 }
87 return data_format;
88 }
89
90 string GetFilterFormat(const OpInfo& op_info) {
91 string filter_format = "HWIO"; // Default format.
92 if (op_info.attr().find("filter_format") != op_info.attr().end()) {
93 filter_format = op_info.attr().at("filter_format").s();
94 }
95 return filter_format;
96 }
97
98 Padding GetPadding(const OpInfo& op_info) {
99 if (op_info.attr().find("padding") != op_info.attr().end() &&
100 op_info.attr().at("padding").s() == "VALID") {
101 return Padding::VALID;
102 }
103 return Padding::SAME; // Default padding.
104 }
105
106 bool IsTraining(const OpInfo& op_info) {
107 if (op_info.attr().find("is_training") != op_info.attr().end() &&
108 op_info.attr().at("is_training").b()) {
109 return true;
110 }
111 return false;
112 }
113
114 // TODO(dyoon): support non-4D tensors in the cost functions of convolution
115 // related ops (Conv, Pool, BatchNorm, and their backprops) and the related
116 // helper functions.
117 std::vector<int64> GetStrides(const OpInfo& op_info) {
118 if (op_info.attr().find("strides") != op_info.attr().end()) {
119 const auto strides = op_info.attr().at("strides").list().i();
120 CHECK(strides.size() == 4)
121 << "Attr strides is not a length-4 vector: " << op_info.DebugString();
122 return {strides[0], strides[1], strides[2], strides[3]};
123 }
124 return {1, 1, 1, 1};
125 }
126
127 std::vector<int64> GetKernelSize(const OpInfo& op_info) {
128 if (op_info.attr().find("ksize") != op_info.attr().end()) {
129 const auto ksize = op_info.attr().at("ksize").list().i();
130 CHECK(ksize.size() == 4)
131 << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
132 return {ksize[0], ksize[1], ksize[2], ksize[3]};
133 }
134 // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
135 // {1, 1, 1, 1} in that case.
136 return {1, 1, 1, 1};
137 }
138
139 int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
140 const Padding& padding) {
141 // Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
142 // function in third_party/tensorflow/core/framework/common_shape_fns.cc.
143 if (padding == Padding::VALID) {
144 return (input - filter + stride) / stride;
145 } else { // SAME.
146 return (input + stride - 1) / stride;
147 }
148 }
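// A rough worked example of the two cases above (illustrative numbers only):
// with input = 224, filter = 7, stride = 2, VALID padding gives
// (224 - 7 + 2) / 2 = 109, while SAME padding gives (224 + 2 - 1) / 2 = 112.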
149
150 // Return the output element count of a binary element-wise op considering
151 // broadcasting.
152 int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
153 const TensorShapeProto& input_shape_2) {
154 bool found_unknown_shapes;
155 int rank = std::max(1, input_shape_1.dim_size());
156 TensorShapeProto output_shape =
157 MaybeGetMinimumShape(input_shape_1, rank, &found_unknown_shapes);
158
159 if (input_shape_1.dim_size() == input_shape_2.dim_size()) {
160 auto shape_1 =
161 MaybeGetMinimumShape(input_shape_1, rank, &found_unknown_shapes);
162 auto shape_2 =
163 MaybeGetMinimumShape(input_shape_2, rank, &found_unknown_shapes);
164 if (shape_1.dim_size() == shape_2.dim_size()) {
165 for (int i = 0; i < shape_1.dim_size(); i++) {
166 output_shape.mutable_dim(i)->set_size(
167 std::max(shape_1.dim(i).size(), shape_2.dim(i).size()));
168 }
169 }
170 }
171
172 int64 count = 1;
173 for (int i = 0; i < output_shape.dim_size(); i++) {
174 count *= output_shape.dim(i).size();
175 }
176 return count;
177 }
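// Illustrative example of the broadcasting logic above: for input shapes
// [2, 1, 4] and [2, 3, 1] (equal rank), the per-dimension max gives an output
// shape of [2, 3, 4], i.e. 24 elements. If the ranks differ, the element
// count of the first input's (minimum) shape is used instead.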
178
179 } // namespace
180
181 // Return a minimum shape if the shape is unknown. If known, return the original
182 // shape.
183 TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
184 int rank, bool* found_unknown_shapes) {
185 auto shape = original_shape;
186 bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
187
188 if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
189 *found_unknown_shapes = true;
190 VLOG(2) << "Use minimum shape because the rank is unknown.";
191 // The size of each dimension is at least 1, if unknown.
192 for (int i = shape.dim_size(); i < rank; i++) {
193 shape.add_dim()->set_size(1);
194 }
195 } else if (is_scalar) {
196 for (int i = 0; i < rank; i++) {
197 shape.add_dim()->set_size(1);
198 }
199 } else if (shape.dim_size() > rank) {
200 *found_unknown_shapes = true;
201 shape.clear_dim();
202 for (int i = 0; i < rank; i++) {
203 shape.add_dim()->set_size(original_shape.dim(i).size());
204 }
205 } else {
206 for (int i = 0; i < shape.dim_size(); i++) {
207 if (shape.dim(i).size() < 0) {
208 *found_unknown_shapes = true;
209 VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
210 // The size of each dimension is at least 1, if unknown.
211 shape.mutable_dim(i)->set_size(1);
212 }
213 }
214 }
215 return shape;
216 }
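// Illustrative examples of the cases above: with rank = 4, an unknown-rank
// shape becomes [1, 1, 1, 1]; a partially known shape such as [-1, 128]
// (with rank = 2) becomes [1, 128]; a fully known shape of matching rank is
// returned unchanged. In the unknown cases, *found_unknown_shapes is set.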
217
218 OpLevelCostEstimator::OpLevelCostEstimator() {
219 // Syntactic sugar to build and return a lambda that takes an OpContext and
220 // returns a cost.
221 typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context)
222 const;
223 auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpContext&)> {
224 return [this, impl](const OpContext& op_context) {
225 return (this->*impl)(op_context);
226 };
227 };
228
229 device_cost_impl_ = {
230 {kConv2d, wrap(&OpLevelCostEstimator::PredictConv2D)},
231 {kConv2dBackpropFilter,
232 wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
233 {kConv2dBackpropInput,
234 wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
235 {kFusedConv2dBiasActivation,
236 wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
237 // reuse Conv2D for DepthwiseConv2dNative because the calculation is the
238 // same although the actual meanings of the parameters are different. See
239 // comments in PredictConv2D and related functions.
240 {kDepthwiseConv2dNative, wrap(&OpLevelCostEstimator::PredictConv2D)},
241 {kDepthwiseConv2dNativeBackpropFilter,
242 wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
243 {kDepthwiseConv2dNativeBackpropInput,
244 wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
245 {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
246 {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
247 {kSparseTensorDenseMatMul,
248 wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul)},
249 {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
250 {kQuantizedMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
251 {kQuantizedMatMulV2, wrap(&OpLevelCostEstimator::PredictMatMul)},
252
253 {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
254 {kGuaranteeConst, wrap(&OpLevelCostEstimator::PredictNoOp)},
255
256 {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
257 {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
258 {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
259
260 {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
261 {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
262 {kIdentityN, wrap(&OpLevelCostEstimator::PredictIdentity)},
263 {kRefIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
264 {kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
265 {kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
266 {kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)},
267 {kSqueeze, wrap(&OpLevelCostEstimator::PredictIdentity)},
268 {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
269 {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
270
271 {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
272 {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
273 {kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)},
274 {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)},
275 {kMaxPool, wrap(&OpLevelCostEstimator::PredictMaxPool)},
276 {kMaxPoolGrad, wrap(&OpLevelCostEstimator::PredictMaxPoolGrad)},
277 {kAvgPool, wrap(&OpLevelCostEstimator::PredictAvgPool)},
278 {kAvgPoolGrad, wrap(&OpLevelCostEstimator::PredictAvgPoolGrad)},
279 {kFusedBatchNorm, wrap(&OpLevelCostEstimator::PredictFusedBatchNorm)},
280 {kFusedBatchNormGrad,
281 wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
282 };
283
284 persistent_ops_ = {
285 kConst, kVariable, kVariableV2, kAutoReloadVariable,
286 kVarHandleOp, kReadVariableOp,
287 };
288
289 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
290
291 // Quantize = apply min and max bounds, multiply by scale factor and round.
292 const int quantize_v2_cost =
293 EIGEN_COST(scalar_product_op<float>) + EIGEN_COST(scalar_max_op<float>) +
294 EIGEN_COST(scalar_min_op<float>) + EIGEN_COST(scalar_round_op<float>);
295
296 elementwise_ops_ = {
297 // Unary ops alphabetically sorted
298 {"Acos", EIGEN_COST(scalar_acos_op<float>)},
299 {"Asin", EIGEN_COST(scalar_asin_op<float>)},
300 {"Atan", EIGEN_COST(scalar_atan_op<float>)},
301 {"Atan2", EIGEN_COST(scalar_quotient_op<float>) +
302 EIGEN_COST(scalar_atan_op<float>)},
303 // For now, we use Eigen cost model for float to int16 cast as an example
304 // case; Eigen cost model is zero when src and dst types are identical,
305 // and it uses AddCost (1) when different. We may implement separate
306 // cost functions for cast ops, using the actual input and output types.
307 {"Cast", Eigen::internal::functor_traits<
308 Eigen::internal::scalar_cast_op<float, int16>>::Cost},
309 {"Ceil", EIGEN_COST(scalar_ceil_op<float>)},
310 {"Cos", EIGEN_COST(scalar_cos_op<float>)},
311 {"Dequantize", EIGEN_COST(scalar_product_op<float>)},
312 {"Erf", 1},
313 {"Erfc", 1},
314 {"Exp", EIGEN_COST(scalar_exp_op<float>)},
315 {"Expm1", EIGEN_COST(scalar_expm1_op<float>)},
316 {"Floor", EIGEN_COST(scalar_floor_op<float>)},
317 {"Inv", EIGEN_COST(scalar_inverse_op<float>)},
318 {"InvGrad", 1},
319 {"Lgamma", 1},
320 {"Log", EIGEN_COST(scalar_log_op<float>)},
321 {"Log1p", EIGEN_COST(scalar_log1p_op<float>)},
322 {"Neg", EIGEN_COST(scalar_opposite_op<float>)},
323 {"QuantizeV2", quantize_v2_cost},
324 {"Reciprocal", EIGEN_COST(scalar_inverse_op<float>)},
325 {"Rint", 1},
326 {"Round", EIGEN_COST(scalar_round_op<float>)},
327 {"Rsqrt", EIGEN_COST(scalar_rsqrt_op<float>)},
328 {"Sqrt", EIGEN_COST(scalar_sqrt_op<float>)},
329 {"Square", EIGEN_COST(scalar_square_op<float>)},
330 {"Tanh", EIGEN_COST(scalar_tanh_op<float>)},
331 {"Relu", EIGEN_COST(scalar_max_op<float>)},
332 {"Sigmoid", EIGEN_COST(scalar_logistic_op<float>)},
333 {"QuantizedSigmoid", EIGEN_COST(scalar_logistic_op<float>)},
334 {"Sign", EIGEN_COST(scalar_sign_op<float>)},
335 {"Sin", EIGEN_COST(scalar_sin_op<float>)},
336 {"Tan", EIGEN_COST(scalar_tan_op<float>)},
337 // Binary ops alphabetically sorted
338 {"Add", EIGEN_COST(scalar_sum_op<float>)},
339 {"ApproximateEqual", 1},
340 {"BiasAdd", EIGEN_COST(scalar_sum_op<float>)},
341 {"QuantizedBiasAdd", EIGEN_COST(scalar_sum_op<float>)},
342 {"Div", EIGEN_COST(scalar_quotient_op<float>)},
343 {"Equal", 1},
344 {"FloorDiv", EIGEN_COST(scalar_quotient_op<float>)},
345 {"FloorMod", EIGEN_COST(scalar_mod_op<float>)},
346 {"Greater", 1},
347 {"GreaterEqual", 1},
348 {"Less", 1},
349 {"LessEqual", 1},
350 {"LogicalAnd", EIGEN_COST(scalar_boolean_and_op)},
351 {"LogicalNot", 1},
352 {"LogicalOr", EIGEN_COST(scalar_boolean_or_op)},
353 {"Maximum", EIGEN_COST(scalar_max_op<float>)},
354 {"Minimum", EIGEN_COST(scalar_min_op<float>)},
355 {"Mod", EIGEN_COST(scalar_mod_op<float>)},
356 {"Mul", EIGEN_COST(scalar_product_op<float>)},
357 {"NotEqual", 1},
358 {"QuantizedAdd", EIGEN_COST(scalar_sum_op<float>)},
359 {"QuantizedMul", EIGEN_COST(scalar_product_op<float>)},
360 {"RealDiv", EIGEN_COST(scalar_quotient_op<float>)},
361 {"ReluGrad", EIGEN_COST(scalar_max_op<float>)},
362 {"SquareDifference", 1},
363 {"Sub", EIGEN_COST(scalar_difference_op<float>)},
364 {"TruncateDiv", EIGEN_COST(scalar_quotient_op<float>)},
365 {"TruncateMod", EIGEN_COST(scalar_mod_op<float>)}};
366
367 #undef EIGEN_COST
368
369 // By default, use sum of memory_time and compute_time for execution_time.
370 compute_memory_overlap_ = false;
371 }
372
373 Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
374 const auto& op_info = op_context.op_info;
375 auto it = device_cost_impl_.find(op_info.op());
376 if (it != device_cost_impl_.end()) {
377 std::function<Costs(const OpContext&)> estimator = it->second;
378 Costs costs = estimator(op_context);
379 VLOG(1) << "Operation " << op_info.op() << " takes "
380 << costs.execution_time.count() << " ns.";
381 return costs;
382 }
383
384 if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
385 return PredictVariable(op_context);
386 }
387
388 if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
389 return PredictCwiseOp(op_context);
390 }
391
392 VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
393
394 return PredictCostOfAnUnknownOp(op_context);
395 }
396
397 DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
398 const DeviceProperties& device) const {
399 double gflops = -1;
400 double gb_per_sec = -1;
401
402 if (device.type() == "CPU") {
403 // Check if vector instructions are available, and refine performance
404 // prediction based on this.
405 // Frequencies are stored in MHz in the DeviceProperties.
406 gflops = device.num_cores() * device.frequency() * 1e-3;
407 if (gb_per_sec < 0) {
408 if (device.bandwidth() > 0) {
409 gb_per_sec = device.bandwidth() / 1e6;
410 } else {
411 gb_per_sec = 32;
412 }
413 }
414 } else if (device.type() == "GPU") {
415 const string architecture = device.environment().at("architecture");
416 int cores_per_multiprocessor;
417 if (architecture < "3") {
418 // Fermi
419 cores_per_multiprocessor = 32;
420 } else if (architecture < "4") {
421 // Kepler
422 cores_per_multiprocessor = 192;
423 } else if (architecture < "6") {
424 // Maxwell
425 cores_per_multiprocessor = 128;
426 } else {
427 // Pascal (compute capability version 6) and Volta (compute capability
428 // version 7)
429 cores_per_multiprocessor = 64;
430 }
431 gflops = device.num_cores() * device.frequency() * 1e-3 *
432 cores_per_multiprocessor * kOpsPerMac;
433 if (device.bandwidth() > 0) {
434 gb_per_sec = device.bandwidth() / 1e6;
435 } else {
436 gb_per_sec = 100;
437 }
438 }
439 VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
440 << " gb_per_sec: " << gb_per_sec;
441
442 DCHECK_LT(0, gflops) << device.DebugString();
443 DCHECK_LT(0, gb_per_sec) << device.DebugString();
444
445 return DeviceInfo(gflops, gb_per_sec);
446 }
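// A rough worked example of the GPU branch above (hypothetical device): with
// 80 multiprocessors at 1530 MHz and architecture "7" (64 cores per SM),
// gflops = 80 * 1530 * 1e-3 * 64 * 2 = 15667.2; with no bandwidth reported,
// gb_per_sec defaults to 100.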
447
448 Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
449 const auto& op_info = op_context.op_info;
450 bool found_unknown_shapes = false;
451 // For unary or binary element-wise operations, op count is the element count
452 // of any input. We use the count for the largest input here to be more robust
453 // in case the shape is unknown or partially known for the other input.
454 int64 op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
455 // If the output shape is available, try to use the element count calculated
456 // from it.
457 if (op_info.outputs_size() > 0) {
458 op_count = std::max(
459 op_count,
460 CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
461 }
462 // For binary ops, calculate the output shape possibly resulting from
463 // broadcasting.
464 if (op_info.inputs_size() >= 2) {
465 op_count =
466 std::max(op_count, CwiseOutputElementCount(op_info.inputs(0).shape(),
467 op_info.inputs(1).shape()));
468 }
469
470 int op_cost = 1;
471 bool is_known_elementwise_op = false;
472 auto it = elementwise_ops_.find(op_info.op());
473 if (it != elementwise_ops_.end()) {
474 op_cost = it->second;
475 is_known_elementwise_op = true;
476 } else {
477 LOG(WARNING) << "Not a cwise op: " << op_info.op();
478 }
479
480 Costs costs = PredictOpCountBasedCost(op_count * op_cost, op_info);
481 if (found_unknown_shapes || !is_known_elementwise_op) {
482 costs.inaccurate = true;
483 }
484 costs.num_ops_with_unknown_shapes = found_unknown_shapes;
485 return costs;
486 }
487
488 Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp(
489 const OpContext& op_context) const {
490 // Don't assume the operation is cwise; return a cost based on input/output
491 // size and admit that it is inaccurate...
492 auto costs = PredictOpCountBasedCost(0, op_context.op_info);
493 costs.inaccurate = true;
494 return costs;
495 }
496
497 Costs OpLevelCostEstimator::PredictOpCountBasedCost(
498 double operations, const OpInfo& op_info) const {
499 bool unknown_shapes = false;
500 const double input_size = CalculateInputSize(op_info, &unknown_shapes);
501 const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
502 Costs costs =
503 PredictOpCountBasedCost(operations, input_size, output_size, op_info);
504 costs.inaccurate = unknown_shapes;
505 costs.num_ops_with_unknown_shapes = unknown_shapes;
506 costs.max_memory = output_size;
507 return costs;
508 }
509
510 Costs OpLevelCostEstimator::PredictOpCountBasedCost(
511 double operations, double input_io_bytes, double output_io_bytes,
512 const OpInfo& op_info) const {
513 double total_io_bytes = input_io_bytes + output_io_bytes;
514 const DeviceInfo device_info = GetDeviceInfo(op_info.device());
515 if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
516 device_info.intermediate_read_gb_per_sec <= 0 ||
517 device_info.intermediate_write_gb_per_sec <= 0) {
518 VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
519 << " device type:" << op_info.device().type()
520 << " device model:" << op_info.device().model();
521 }
522
523 Costs::NanoSeconds compute_cost(std::ceil(operations / device_info.gigaops));
524 VLOG(1) << "Op:" << op_info.op() << " GOps:" << operations / 1e9
525 << " Compute Time (ns):" << compute_cost.count();
526
527 Costs::NanoSeconds memory_cost(
528 std::ceil(total_io_bytes / device_info.gb_per_sec));
529 VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
530 << " Memory Time (ns):" << memory_cost.count();
531
532 // Check if bytes > 0. If they're not and the bandwidth is set to infinity,
533 // then the result would be undefined.
534 double intermediate_read_time =
535 (input_io_bytes > 0)
536 ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
537 : 0;
538
539 double intermediate_write_time =
540 (output_io_bytes > 0)
541 ? std::ceil(output_io_bytes /
542 device_info.intermediate_write_gb_per_sec)
543 : 0;
544
545 Costs::NanoSeconds intermediate_memory_cost =
546 compute_memory_overlap_
547 ? std::max(intermediate_read_time, intermediate_write_time)
548 : (intermediate_read_time + intermediate_write_time);
549 VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
550 << " Intermediate Memory Time (ns):"
551 << intermediate_memory_cost.count();
552
553 Costs costs;
554 costs.compute_time = compute_cost;
555 costs.memory_time = memory_cost;
556 costs.intermediate_memory_time = intermediate_memory_cost;
557 costs.intermediate_memory_read_time =
558 Costs::NanoSeconds(intermediate_read_time);
559 costs.intermediate_memory_write_time =
560 Costs::NanoSeconds(intermediate_write_time);
561 CombineCostsAndUpdateExecutionTime(&costs);
562 return costs;
563 }
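// A rough worked example of the roofline math above: since gigaops is in
// ops/ns and gb_per_sec is in bytes/ns, a kernel with 2e9 operations on a
// 100-gigaops device costs ceil(2e9 / 100) = 2e7 ns of compute, and 1e7 bytes
// of total I/O at 100 GB/s costs ceil(1e7 / 100) = 1e5 ns of memory time; by
// default (no compute/memory overlap) execution_time is roughly their sum.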
564
565 int64 OpLevelCostEstimator::CountConv2DOperations(
566 const OpInfo& op_info, bool* found_unknown_shapes) const {
567 return CountConv2DOperations(op_info, nullptr, found_unknown_shapes);
568 }
569
570 // Helper to translate the positional arguments into named fields.
571 OpLevelCostEstimator::ConvolutionDimensions
572 OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
573 const TensorShapeProto& original_image_shape,
574 const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
575 bool* found_unknown_shapes) {
576 VLOG(2) << "op features: " << op_info.DebugString();
577 VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
578 VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();
579 auto image_shape =
580 MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
581 auto filter_shape =
582 MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes);
583 VLOG(2) << "Image shape: " << image_shape.DebugString();
584 VLOG(2) << "Filter shape: " << filter_shape.DebugString();
585
586 int x_index, y_index, channel_index;
587 const string& data_format = GetDataFormat(op_info);
588 if (data_format == "NCHW") {
589 x_index = 2;
590 y_index = 3;
591 channel_index = 1;
592 } else {
593 // Use NHWC.
594 x_index = 1;
595 y_index = 2;
596 channel_index = 3;
597 }
598 const string& filter_format = GetFilterFormat(op_info);
599 int filter_x_index, filter_y_index, in_channel_index, out_channel_index;
600 if (filter_format == "HWIO") {
601 filter_x_index = 0;
602 filter_y_index = 1;
603 in_channel_index = 2;
604 out_channel_index = 3;
605 } else {
606 // Use OIHW
607 filter_x_index = 2;
608 filter_y_index = 3;
609 in_channel_index = 1;
610 out_channel_index = 0;
611 }
612 int64 batch = image_shape.dim(0).size();
613 int64 ix = image_shape.dim(x_index).size();
614 int64 iy = image_shape.dim(y_index).size();
615 int64 iz = image_shape.dim(channel_index).size();
616 int64 kx = filter_shape.dim(filter_x_index).size();
617 int64 ky = filter_shape.dim(filter_y_index).size();
618 std::vector<int64> strides = GetStrides(op_info);
619 const auto padding = GetPadding(op_info);
620 int64 sx = strides[x_index];
621 int64 sy = strides[y_index];
622 int64 ox = GetOutputSize(ix, kx, sx, padding);
623 int64 oy = GetOutputSize(iy, ky, sy, padding);
624 int64 oz = filter_shape.dim(out_channel_index).size();
625 // Only check equality when both sizes are known (in other words, when
626 // neither is set to a minimum dimension size of 1).
627 if (iz != 1 && filter_shape.dim(in_channel_index).size() != 1) {
628 CHECK_EQ(iz, filter_shape.dim(in_channel_index).size());
629 } else {
630 iz = std::max<int64>(iz, filter_shape.dim(in_channel_index).size());
631 }
632 OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
633 batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
634
635 VLOG(1) << "Batch Size:" << batch;
636 VLOG(1) << "Image Dims:" << ix << "," << iy;
637 VLOG(1) << "Input Features:" << iz;
638 VLOG(1) << "Kernel Dims:" << kx << "," << ky;
639 VLOG(1) << "Output Features:" << oz;
640 VLOG(1) << "Output Dims:" << ox << "," << oy;
641 VLOG(1) << "Strides:" << sx << "," << sy;
642 VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
643 return conv_dims;
644 }
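// Illustrative example of the dimension extraction above: an NHWC image of
// shape [32, 224, 224, 3] with an HWIO filter of shape [7, 7, 3, 64],
// strides [1, 2, 2, 1], and SAME padding yields batch = 32, ix = iy = 224,
// iz = 3, kx = ky = 7, sx = sy = 2, ox = oy = 112, oz = 64.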
645
646 int64 OpLevelCostEstimator::CountConv2DOperations(
647 const OpInfo& op_info, ConvolutionDimensions* conv_info,
648 bool* found_unknown_shapes) const {
649 DCHECK(op_info.op() == kConv2d || op_info.op() == kDepthwiseConv2dNative)
650 << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";
651
652 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
653 op_info.inputs(0).shape(), op_info.inputs(1).shape(), op_info,
654 found_unknown_shapes);
655
656 // In DepthwiseConv2dNative, conv_dims.oz is actually the channel depth
657 // multiplier; the effective output channel depth oz_effective is
658 // conv_dims.iz * conv_dims.oz. Thus # ops = N x H x W x oz_effective x 2RS.
659 // Compare to Conv2D, where # ops = N x H x W x iz x oz x 2RS: when
660 // oz = oz_effective, Conv2D_ops / Depthwise_conv2d_native_ops = iz.
661 int64 ops = conv_dims.batch;
662 ops *= conv_dims.ox * conv_dims.oy;
663 ops *= conv_dims.kx * conv_dims.ky;
664 if (op_info.op() == kConv2d) {
665 ops *= conv_dims.iz * conv_dims.oz;
666 } else {
667 // Ensure the output tensor dims are correct for DepthwiseConv2dNative,
668 // even though the op count is the same as for Conv2D.
669 conv_dims.oz *= conv_dims.iz;
670 ops *= conv_dims.oz;
671 }
672 ops *= kOpsPerMac;
673
674 if (conv_info != nullptr) {
675 *conv_info = conv_dims;
676 }
677 return ops;
678 }
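// Rough worked example of the op count above, using the illustrative Conv2D
// dimensions from ConvolutionDimensionsFromInputs (batch 32, 112x112 output,
// 7x7 kernel, iz = 3, oz = 64):
// 32 * 112 * 112 * 7 * 7 * 3 * 64 * 2 (ops per MAC) ~= 7.55e9 ops.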
679
680 int64 OpLevelCostEstimator::CountMatMulOperations(
681 const OpInfo& op_info, bool* found_unknown_shapes) const {
682 return CountMatMulOperations(op_info, nullptr, found_unknown_shapes);
683 }
684
685 // TODO(nishantpatil): Create separate estimator for Sparse Matmul
686 int64 OpLevelCostEstimator::CountMatMulOperations(
687 const OpInfo& op_info, MatMulDimensions* mat_mul,
688 bool* found_unknown_shapes) const {
689 double ops = 0;
690
691 if (op_info.inputs_size() < 2) {
692 LOG(ERROR) << "Need 2 inputs but got " << op_info.inputs_size();
693 // TODO(pcma): Try to separate invalid inputs from unknown shapes
694 *found_unknown_shapes = true;
695 return 0;
696 }
697
698 auto& a_matrix = op_info.inputs(0);
699 auto& b_matrix = op_info.inputs(1);
700
701 bool transpose_a = false;
702 bool transpose_b = false;
703
704 double m_dim, n_dim, k_dim, k_dim_b = 0;
705
706 for (const auto& item : op_info.attr()) {
707 VLOG(1) << "Key:" << item.first
708 << " Value:" << SummarizeAttrValue(item.second);
709 if (item.first == "transpose_a" && item.second.b() == true)
710 transpose_a = true;
711 if (item.first == "transpose_b" && item.second.b() == true)
712 transpose_b = true;
713 }
714 VLOG(1) << "transpose_a:" << transpose_a;
715 VLOG(1) << "transpose_b:" << transpose_b;
716 auto a_matrix_shape =
717 MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
718 auto b_matrix_shape =
719 MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
720 if (transpose_a) {
721 m_dim = a_matrix_shape.dim(1).size();
722 k_dim = a_matrix_shape.dim(0).size();
723 } else {
724 m_dim = a_matrix_shape.dim(0).size();
725 k_dim = a_matrix_shape.dim(1).size();
726 }
727 if (transpose_b) {
728 k_dim_b = b_matrix_shape.dim(1).size();
729 n_dim = b_matrix_shape.dim(0).size();
730 } else {
731 k_dim_b = b_matrix_shape.dim(0).size();
732 n_dim = b_matrix_shape.dim(1).size();
733 }
734
735 VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
736 // Only check equality when both sizes are known (in other words, when
737 // neither is set to a minimum dimension size of 1).
738 if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
739 LOG(ERROR) << "Incompatible Matrix dimensions";
740 return ops;
741 } else {
742 // One of k_dim and k_dim_b might be 1 (minimum dimension size).
743 k_dim = std::max(k_dim, k_dim_b);
744 }
745
746 ops = m_dim * n_dim * k_dim * 2;
747 VLOG(1) << "Operations for Matmul: " << ops;
748
749 if (mat_mul != nullptr) {
750 mat_mul->m = m_dim;
751 mat_mul->n = n_dim;
752 mat_mul->k = k_dim;
753 }
754 return ops;
755 }
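// Rough worked example of the MatMul count above: for A of shape [128, 256]
// and B of shape [256, 512] (no transposes), m = 128, k = 256, n = 512, and
// ops = 128 * 512 * 256 * 2 = 33,554,432.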
756
757 int64 OpLevelCostEstimator::CountBatchMatMulOperations(
758 const OpInfo& op_info, bool* found_unknown_shapes) const {
759 if (op_info.op() != kBatchMatMul) {
760 LOG(ERROR) << "Invalid Operation: " << op_info.op();
761 // TODO(pcma): Try to separate invalid inputs from unknown shapes
762 *found_unknown_shapes = true;
763 return 0;
764 }
765 if (op_info.inputs_size() != 2) {
766 LOG(ERROR) << "Expected 2 inputs but got " << op_info.inputs_size();
767 // TODO(pcma): Try to separate invalid inputs from unknown shapes
768 *found_unknown_shapes = true;
769 return 0;
770 }
771
772 double ops = 0;
773 const auto& a_input = op_info.inputs(0);
774 const auto& b_input = op_info.inputs(1);
775
776 // BatchMatMul requires inputs of at least matrix shape (rank 2).
777 // The two most minor dimensions of each input are matrices that
778 // need to be multiplied together. The other dimensions determine
779 // the number of such MatMuls. For example, if the BatchMatMul has
780 // inputs of shape:
781 // a_input_shape = [2, 3, 4, 5]
782 // b_input_shape = [2, 3, 5, 6]
783 // then there are 2*3 = 6 MatMuls of dimensions m = 4, k = 5, n = 6
784 // in this BatchMatMul.
785 const int matrix_rank = 2;
786
787 bool a_input_shape_unknown = false;
788 bool b_input_shape_unknown = false;
789
790 TensorShapeProto a_input_shape = MaybeGetMinimumShape(
791 a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
792 &a_input_shape_unknown);
793 TensorShapeProto b_input_shape = MaybeGetMinimumShape(
794 b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
795 &b_input_shape_unknown);
796
797 *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
798 (a_input.shape().dim_size() < matrix_rank) ||
799 (b_input.shape().dim_size() < matrix_rank);
800
801 // Compute the number of matmuls as the max indicated at each dimension
802 // by either input. Note that the shapes do not have to have
803 // the same rank due to incompleteness.
804 TensorShapeProto* bigger_rank_shape = &a_input_shape;
805 TensorShapeProto* smaller_rank_shape = &b_input_shape;
806 if (b_input_shape.dim_size() > a_input_shape.dim_size()) {
807 bigger_rank_shape = &b_input_shape;
808 smaller_rank_shape = &a_input_shape;
809 }
810 int num_matmuls = 1;
811 for (int b_i = 0,
812 s_i = smaller_rank_shape->dim_size() - bigger_rank_shape->dim_size();
813 b_i < bigger_rank_shape->dim_size() - matrix_rank; ++b_i, ++s_i) {
814 int b_dim = bigger_rank_shape->dim(b_i).size();
815 int s_dim = 1;
816 if (s_i >= 0) {
817 s_dim = smaller_rank_shape->dim(s_i).size();
818 }
819 num_matmuls *= std::max(b_dim, s_dim);
820 }
821
822 // Build the MatMul. Note that values are ignored here since we are just
823 // counting ops (i.e., only shapes matter).
824 OpInfo matmul_op_info;
825 matmul_op_info.set_op("MatMul");
826
827 AttrValue transpose_a;
828 transpose_a.set_b(false);
829 if (op_info.attr().find("adj_x") != op_info.attr().end()) {
830 transpose_a.set_b(op_info.attr().at("adj_x").b());
831 }
832 (*matmul_op_info.mutable_attr())["transpose_a"] = transpose_a;
833
834 AttrValue transpose_b;
835 transpose_b.set_b(false);
836 if (op_info.attr().find("adj_y") != op_info.attr().end()) {
837 transpose_b.set_b(op_info.attr().at("adj_y").b());
838 }
839 (*matmul_op_info.mutable_attr())["transpose_b"] = transpose_b;
840
841 OpInfo::TensorProperties* a_matrix = matmul_op_info.add_inputs();
842 a_matrix->set_dtype(a_input.dtype());
843 TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
844 for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
845 i < a_input_shape.dim_size(); ++i) {
846 *(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
847 }
848
849 OpInfo::TensorProperties* b_matrix = matmul_op_info.add_inputs();
850 b_matrix->set_dtype(b_input.dtype());
851 TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
852 for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
853 i < b_input_shape.dim_size(); ++i) {
854 *(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
855 }
856
857 for (int i = 0; i < num_matmuls; ++i) {
858 bool matmul_unknown_shapes = false;
859 ops += CountMatMulOperations(matmul_op_info, &matmul_unknown_shapes);
860 *found_unknown_shapes |= matmul_unknown_shapes;
861 }
862 return ops;
863 }
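// Rough worked example, matching the shapes in the comment above: for
// a_input_shape = [2, 3, 4, 5] and b_input_shape = [2, 3, 5, 6] there are
// 2 * 3 = 6 MatMuls, each costing 4 * 6 * 5 * 2 = 240 ops, or 1440 ops total.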
864
865 bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
866 TensorShapeProto* tensor_shape_proto) {
867 tensor_shape_proto->Clear();
868 // First convert TensorProto into Tensor class so that it correctly parses
869 // data values within TensorProto (whether they're in int_val, int64_val,
870 // tensor_content, or anything else).
871 Tensor tensor(tensor_proto.dtype());
872 if (!tensor.FromProto(tensor_proto)) {
873 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
874 << "failed to parse TensorProto: "
875 << tensor_proto.DebugString();
876 return false;
877 }
878 if (tensor.dims() != 1) {
879 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
880 << "tensor is not 1D: " << tensor.dims();
881 return false;
882 }
883 // Then, convert it back to TensorProto using AsProtoField, which makes sure
884 // the data is in int_val, int64_val, or such repeated data fields, not in
885 // tensor_content.
886 TensorProto temp_tensor;
887 tensor.AsProtoField(&temp_tensor);
888
889 #define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type) \
890 do { \
891 for (const auto& value : temp_tensor.type##_val()) { \
892 tensor_shape_proto->add_dim()->set_size(value); \
893 } \
894 } while (0)
895
896 if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
897 tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
898 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
899 } else if (tensor.dtype() == DT_INT64) {
900 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
901 } else if (tensor.dtype() == DT_UINT32) {
902 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
903 } else if (tensor.dtype() == DT_UINT64) {
904 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
905 } else {
906 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
907 << "Unsupported dtype: " << tensor.dtype();
908 return false;
909 }
910 #undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
911
912 return true;
913 }
914
915 // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
916 int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
917 const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
918 bool* found_unknown_shapes) const {
919 int64 ops = 0;
920
921 DCHECK(op_info.op() == kConv2dBackpropInput ||
922 op_info.op() == kDepthwiseConv2dNativeBackpropInput)
923 << "Invalid Operation: not kConv2dBackpropInput nor"
924 "kDepthwiseConv2dNativeBackpropInput";
925
926 if (op_info.inputs_size() < 2) {
927 // TODO(pcma): Try to separate invalid inputs from unknown shapes
928 *found_unknown_shapes = true;
929 return ops;
930 }
931
932 TensorShapeProto input_shape;
933 bool shape_found = false;
934 if (op_info.inputs(0).has_value()) {
935 const TensorProto& value = op_info.inputs(0).value();
936 shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
937 }
938 if (!shape_found && op_info.outputs_size() == 1) {
939 input_shape = op_info.outputs(0).shape();
940 shape_found = true;
941 }
942 if (!shape_found) {
943 // Set the minimum input size that's feasible.
944 input_shape.Clear();
945 for (int i = 0; i < 4; ++i) {
946 input_shape.add_dim()->set_size(1);
947 }
948 *found_unknown_shapes = true;
949 }
950
951 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
952 input_shape, op_info.inputs(1).shape(), op_info, found_unknown_shapes);
953
954 ops = conv_dims.batch;
955 ops *= conv_dims.ox * conv_dims.oy;
956 ops *= conv_dims.kx * conv_dims.ky;
957 if (op_info.op() == kConv2dBackpropInput) {
958 ops *= conv_dims.iz * conv_dims.oz;
959 } else {
960 // conv_dims always uses the forward-path definition regardless.
961 conv_dims.oz *= conv_dims.iz;
962 ops *= conv_dims.oz;
963 }
964 ops *= kOpsPerMac;
965
966 VLOG(1) << "Operations for " << op_info.op() << " " << ops;
967
968 if (returned_conv_dims != nullptr) {
969 *returned_conv_dims = conv_dims;
970 }
971 return ops;
972 }
973
974 int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
975 const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
976 bool* found_unknown_shapes) const {
977 int64 ops = 0;
978
979 DCHECK(op_info.op() == kConv2dBackpropFilter ||
980 op_info.op() == kDepthwiseConv2dNativeBackpropFilter)
981 << "Invalid Operation: not kConv2dBackpropFilter nor"
982 "kDepthwiseConv2dNativeBackpropFilter";
983
984 TensorShapeProto filter_shape;
985 bool shape_found = false;
986 if (op_info.inputs_size() >= 2 && op_info.inputs(1).has_value()) {
987 const TensorProto& value = op_info.inputs(1).value();
988 shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
989 }
990 if (!shape_found && op_info.outputs_size() == 1) {
991 filter_shape = op_info.outputs(0).shape();
992 shape_found = true;
993 }
994 if (!shape_found) {
995 // Set the minimum filter size that's feasible.
996 filter_shape.Clear();
997 for (int i = 0; i < 4; ++i) {
998 filter_shape.add_dim()->set_size(1);
999 }
1000 *found_unknown_shapes = true;
1001 }
1002
1003 if (op_info.inputs_size() < 1) {
1004 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1005 *found_unknown_shapes = true;
1006 return ops;
1007 }
1008 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1009 op_info.inputs(0).shape(), filter_shape, op_info, found_unknown_shapes);
1010
1011 ops = conv_dims.batch;
1012 ops *= conv_dims.ox * conv_dims.oy;
1013 ops *= conv_dims.kx * conv_dims.ky;
1014 if (op_info.op() == kConv2dBackpropFilter) {
1015 ops *= conv_dims.iz * conv_dims.oz;
1016 } else {
1017 // conv_dims always uses the forward-path definition regardless.
1018 conv_dims.oz *= conv_dims.iz;
1019 ops *= conv_dims.oz;
1020 }
1021 ops *= kOpsPerMac;
1022 VLOG(1) << "Operations for " << op_info.op() << " " << ops;
1023
1024 if (returned_conv_dims != nullptr) {
1025 *returned_conv_dims = conv_dims;
1026 }
1027 return ops;
1028 }
1029
1030 int64 OpLevelCostEstimator::CalculateTensorElementCount(
1031 const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const {
1032 VLOG(2) << " with " << DataTypeString(tensor.dtype()) << " tensor of shape "
1033 << tensor.shape().DebugString();
1034 int64 tensor_size = 1;
1035 int num_dims = std::max(1, tensor.shape().dim_size());
1036 auto tensor_shape =
1037 MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes);
1038 for (const auto& dim : tensor_shape.dim()) {
1039 tensor_size *= dim.size();
1040 }
1041 return tensor_size;
1042 }
1043
1044 int64 OpLevelCostEstimator::CalculateTensorSize(
1045 const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const {
1046 int64 count = CalculateTensorElementCount(tensor, found_unknown_shapes);
1047 int size = DataTypeSize(BaseType(tensor.dtype()));
1048 VLOG(2) << "Count: " << count << " DataTypeSize: " << size;
1049 return count * size;
1050 }
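// Illustrative example: a DT_FLOAT tensor of shape [32, 128] has
// 32 * 128 = 4096 elements, giving 4096 * 4 = 16384 bytes.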
1051
1052 int64 OpLevelCostEstimator::CalculateInputSize(
1053 const OpInfo& op_info, bool* found_unknown_shapes) const {
1054 int64 total_input_size = 0;
1055 for (auto& input : op_info.inputs()) {
1056 int64 input_size = CalculateTensorSize(input, found_unknown_shapes);
1057 total_input_size += input_size;
1058 VLOG(1) << "Input Size: " << input_size
1059 << " Total Input Size:" << total_input_size;
1060 }
1061 return total_input_size;
1062 }
1063
1064 int64 OpLevelCostEstimator::CalculateLargestInputCount(
1065 const OpInfo& op_info, bool* found_unknown_shapes) const {
1066 int64 largest_input_count = 0;
1067 for (auto& input : op_info.inputs()) {
1068 int64 input_count =
1069 CalculateTensorElementCount(input, found_unknown_shapes);
1070 if (input_count > largest_input_count) {
1071 largest_input_count = input_count;
1072 }
1073 VLOG(1) << "Input Count: " << input_count
1074 << " Largest Input Count:" << largest_input_count;
1075 }
1076 return largest_input_count;
1077 }
1078
1079 int64 OpLevelCostEstimator::CalculateOutputSize(
1080 const OpInfo& op_info, bool* found_unknown_shapes) const {
1081 int64 total_output_size = 0;
1082 // use float as default for calculations
1083 for (const auto& output : op_info.outputs()) {
1084 DataType dt = output.dtype();
1085 const auto& original_output_shape = output.shape();
1086 int64 output_size = DataTypeSize(BaseType(dt));
1087 int num_dims = std::max(1, original_output_shape.dim_size());
1088 auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1089 found_unknown_shapes);
1090 for (const auto& dim : output_shape.dim()) {
1091 output_size *= dim.size();
1092 }
1093 total_output_size += output_size;
1094 VLOG(1) << "Output Size: " << output_size
1095 << " Total Output Size:" << total_output_size;
1096 }
1097 return total_output_size;
1098 }
1099
1100 Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const {
1101 const auto& op_info = op_context.op_info;
1102 bool found_unknown_shapes = false;
1103 auto costs = PredictOpCountBasedCost(
1104 CountConv2DOperations(op_info, &found_unknown_shapes), op_info);
1105 costs.inaccurate = found_unknown_shapes;
1106 costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1107 return costs;
1108 }
1109
1110 Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
1111 const OpContext& op_context) const {
1112 const auto& op_info = op_context.op_info;
1113 bool found_unknown_shapes = false;
1114 auto costs =
1115 PredictOpCountBasedCost(CountConv2DBackpropInputOperations(
1116 op_info, nullptr, &found_unknown_shapes),
1117 op_info);
1118 costs.inaccurate = found_unknown_shapes;
1119 costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1120 return costs;
1121 }
1122
1123 Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
1124 const OpContext& op_context) const {
1125 const auto& op_info = op_context.op_info;
1126 bool found_unknown_shapes = false;
1127 auto costs =
1128 PredictOpCountBasedCost(CountConv2DBackpropFilterOperations(
1129 op_info, nullptr, &found_unknown_shapes),
1130 op_info);
1131 costs.inaccurate = found_unknown_shapes;
1132 costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1133 return costs;
1134 }
1135
1136 Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
1137 const OpContext& op_context) const {
1138 // FusedConv2DBiasActivation computes a fused kernel which implements:
1139 // 2D convolution, adds side input with separate scaling on convolution and
1140 // side inputs, then adds bias, and finally applies the ReLU activation
1141 // function to the result:
1142 //
1143 // Input -> Conv2D -> Add -> BiasAdd -> ReLU
1144 // ^ ^ ^
1145 // Filter Side Input Bias
1146 //
1147 // Note that when adding the side input, the operation multiplies the output
1148 // of Conv2D by conv_input_scale, confusingly, and the side_input by
1149 // side_input_scale.
1150 //
1151 // Note that in the special case that side_input_scale is 0, which we infer
1152 // from side_input having dimensions [], we skip that addition operation.
1153 //
1154 // For more information, see
1155 // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
1156
1157 // TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and
1158 // filter formats (OIHW_VECT_I).
1159 string data_format = GetDataFormat(op_context.op_info);
1160 if (data_format != "NCHW" && data_format != "NHWC") {
1161 LOG(WARNING) << "unsupported data format: " << data_format;
1162 Costs cost = Costs::ZeroCosts();
1163 cost.inaccurate = true;
1164 return cost;
1165 }
1166 string filter_format = GetFilterFormat(op_context.op_info);
1167 if (filter_format != "HWIO" && filter_format != "OIHW") {
1168 LOG(WARNING) << "unsupported filter format: " << filter_format;
1169 Costs cost = Costs::ZeroCosts();
1170 cost.inaccurate = true;
1171 return cost;
1172 }
1173
1174 auto& conv_input = op_context.op_info.inputs(0);
1175 auto& filter = op_context.op_info.inputs(1);
1176 auto& bias = op_context.op_info.inputs(2);
1177 auto& side_input = op_context.op_info.inputs(3);
1178 auto& conv_input_scale = op_context.op_info.inputs(4);
1179 auto& side_input_scale = op_context.op_info.inputs(5);
1180
1181 // Manually compute our convolution dimensions.
1182 bool found_unknown_shapes = false;
1183 auto dims = ConvolutionDimensionsFromInputs(
1184 conv_input.shape(), filter.shape(), op_context.op_info,
1185 &found_unknown_shapes);
1186
1187 // Construct the shape of our output tensor from our convolution dimensions
1188 // and format, as it may not be available yet.
1189 // TODO(varomodt): should we centralize the Conv2D input/output shapes?
1190 OpInfo::TensorProperties output;
1191 if (data_format == "NCHW") {
1192 output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
1193 } else if (data_format == "NHWC") {
1194 output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
1195 }
1196
1197 // Add the operations the fused op always computes.
1198 std::vector<OpContext> component_ops = {
1199 FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
1200 FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
1201 FusedChildContext(op_context, "BiasAdd", output, {output, bias}),
1202 FusedChildContext(op_context, "Relu", output, {output})};
1203
1204 // Add our side_input iff it's non-empty.
1205 if (side_input.shape().dim_size() > 0) {
1206 component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
1207 {side_input, side_input_scale}));
1208 component_ops.push_back(
1209 FusedChildContext(op_context, "Add", output, {side_input, output}));
1210 }
1211
1212 // Construct an op_context which definitely has our output shape.
1213 auto op_context_with_output = op_context;
1214 op_context_with_output.op_info.mutable_outputs()->Clear();
1215 *op_context_with_output.op_info.mutable_outputs()->Add() = output;
1216
1217 // Construct component operations and run the cost computation.
1218 auto costs = PredictFusedOp(op_context_with_output, component_ops);
1219 costs.inaccurate |= found_unknown_shapes;
1220 costs.num_ops_with_unknown_shapes = costs.inaccurate;
1221 return costs;
1222 }
1223
1224 Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
1225 const auto& op_info = op_context.op_info;
1226 bool found_unknown_shapes = false;
1227 auto costs = PredictOpCountBasedCost(
1228 CountMatMulOperations(op_info, &found_unknown_shapes), op_info);
1229 costs.inaccurate = found_unknown_shapes;
1230 costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1231 return costs;
1232 }
1233
1234 Costs OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
1235 const OpContext& op_context) const {
1236 const auto& op_info = op_context.op_info;
1237 bool found_unknown_shapes = false;
1238 // input[0]: indices in sparse matrix a
1239 // input[1]: values in sparse matrix a
1240 // input[2]: shape of matrix a
1241 // input[3]: matrix b
1242 // See
1243 // https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
1244 int64 num_elems_in_a =
1245 CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
1246 auto b_matrix = op_info.inputs(3);
1247 auto b_matrix_shape =
1248 MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
1249 int64 n_dim = b_matrix_shape.dim(1).size();
1250
1251 // Each element in A is multiplied and added with an element from each column
1252 // in b.
1253 const int64 op_count = kOpsPerMac * num_elems_in_a * n_dim;
1254
1255 int64 a_indices_input_size =
1256 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1257 int64 a_values_input_size =
1258 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1259 int64 a_shape_input_size =
1260 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1261 int64 b_input_size =
1262 num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
1263 double input_size = a_indices_input_size + a_values_input_size +
1264 a_shape_input_size + b_input_size;
1265
1266 double output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
1267
1268 auto costs =
1269 PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
1270 costs.inaccurate = found_unknown_shapes;
1271 costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1272 costs.max_memory = output_size;
1273
1274 return costs;
1275 }
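// Rough worked example of the count above (illustrative numbers): a sparse
// matrix A with 1000 non-zero values multiplied by a dense B with n = 64
// columns gives op_count = 2 * 1000 * 64 = 128,000.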
1276
1277 Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
1278 const auto& op_info = op_context.op_info;
1279 VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1280 return Costs::ZeroCosts();
1281 }
1282
1283 Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const {
1284 const auto& op_info = op_context.op_info;
1285 VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1286 Costs result = Costs::ZeroCosts();
1287 result.max_memory = CalculateOutputSize(op_info, &result.inaccurate);
1288 result.num_ops_with_unknown_shapes = result.inaccurate;
1289 // Assign the minimum amount of time we can represent to the identity op since
1290 // it tends to be really cheap.
1291 result.compute_time = kMinComputeTime;
1292 result.execution_time = result.compute_time;
1293 return result;
1294 }
1295
1296 Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
1297 const auto& op_info = op_context.op_info;
1298 VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1299 Costs result = Costs::ZeroCosts();
1300 result.persistent_memory = CalculateOutputSize(op_info, &result.inaccurate);
1301 result.num_ops_with_unknown_shapes = result.inaccurate;
1302
1303 result.compute_time = kMinComputeTime;
1304 result.execution_time = result.compute_time;
1305 return result;
1306 }
1307
1308 Costs OpLevelCostEstimator::PredictBatchMatMul(
1309 const OpContext& op_context) const {
1310 const auto& op_info = op_context.op_info;
1311 bool found_unknown_shapes = false;
1312 Costs costs = PredictOpCountBasedCost(
1313 CountBatchMatMulOperations(op_info, &found_unknown_shapes), op_info);
1314 costs.inaccurate = found_unknown_shapes;
1315 costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1316 return costs;
1317 }
1318
1319 Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
1320 const auto& op_info = op_context.op_info;
1321 Costs costs = Costs::ZeroCosts();
1322 costs.max_memory = CalculateOutputSize(op_info, &costs.inaccurate);
1323 costs.num_ops_with_unknown_shapes = costs.inaccurate;
1324 // Metadata operations are so cheap we assume they take the minimum amount of
1325 // time we can represent (1 ns).
1326 costs.compute_time = kMinComputeTime;
1327 costs.execution_time = costs.compute_time;
1328
1329 return costs;
1330 }
1331
1332 Costs OpLevelCostEstimator::PredictGatherOrSlice(
1333 const OpContext& op_context) const {
1334 // Gather & Slice ops can have a very large input, but only access a small
1335 // part of it. For these ops the size of the output determines the memory cost.
1336 const auto& op_info = op_context.op_info;
1337
1338 const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
1339 if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
1340 Costs costs = Costs::ZeroCosts();
1341 costs.inaccurate = true;
1342 return costs;
1343 }
1344
1345 bool unknown_shapes = false;
1346
1347 // Each output element is a copy of some element from input.
1348 // For roofline estimate we assume each copy has a unit cost.
  const int64 op_count =
      CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);

  const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
  double input_size = output_size;
  if (op_info.op() == "Slice") {
    // Add the sizes of the 'begin' and 'size' tensors.
    input_size +=
        CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) +
        CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes);
  } else {
    // Assuming this is a "Gather" or "GatherV2" op, add the 'indices' size.
    input_size +=
        CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
  }

  Costs costs =
      PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
  costs.inaccurate = unknown_shapes;
  costs.num_ops_with_unknown_shapes = unknown_shapes;
  costs.max_memory = output_size;

  return costs;
}

Costs OpLevelCostEstimator::PredictFusedOp(
    const OpContext& op_context,
    const std::vector<OpContext>& fused_op_contexts) const {
  // Note that PredictOpCountBasedCost will get the correct memory_time from
  // the node's inputs and outputs; but we don't want to have to re-implement
  // the logic for computing the operation count of each of our component
  // operations here; so we simply add the compute times of each component
  // operation, then update the execution time.
  Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info);

  fused_cost.compute_time = 0;
  fused_cost.inaccurate = false;
  for (auto& fused_op : fused_op_contexts) {
    auto op_cost = PredictCosts(fused_op);

    fused_cost.compute_time += op_cost.compute_time;
    fused_cost.inaccurate |= op_cost.inaccurate;
    fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time;
  }

  CombineCostsAndUpdateExecutionTime(&fused_cost);
  return fused_cost;
}

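// Builds the OpContext for a single component of a fused op (e.g. when a
// fused op such as FusedConv2DBiasActivation is decomposed into its
// convolution, bias, and activation children), reusing the parent's device
// and op attributes while substituting the given op name, inputs, and output.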
/* static */
OpContext OpLevelCostEstimator::FusedChildContext(
    const OpContext& parent, const string& op_name,
    const OpInfo::TensorProperties& output,
    const std::vector<OpInfo::TensorProperties>& inputs) {
  // Set up the base parameters of our new context.
  OpContext new_context;
  new_context.name = op_name;
  new_context.device_name = parent.device_name;
  new_context.op_info = parent.op_info;
  new_context.op_info.set_op(op_name);

  // Set up the inputs of our new context.
  new_context.op_info.mutable_inputs()->Clear();
  for (const auto& input : inputs) {
    *new_context.op_info.mutable_inputs()->Add() = input;
  }

  // Set up the output of our new context.
  new_context.op_info.mutable_outputs()->Clear();
  *new_context.op_info.mutable_outputs()->Add() = output;

  return new_context;
}

/* static */
OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
    DataType type, const std::vector<int64>& dims) {
  OpInfo::TensorProperties ret;
  ret.set_dtype(type);

  auto shape = ret.mutable_shape();
  for (const int dim : dims) {
    shape->add_dim()->set_size(dim);
  }

  return ret;
}

/* static */
OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::OpDimensionsFromInputs(
    const TensorShapeProto& original_image_shape, const OpInfo& op_info,
    bool* found_unknown_shapes) {
  VLOG(2) << "op features: " << op_info.DebugString();
  VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
  auto image_shape =
      MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
  VLOG(2) << "Image shape: " << image_shape.DebugString();

  int x_index, y_index, channel_index;
  const string& data_format = GetDataFormat(op_info);
  if (data_format == "NCHW") {
    x_index = 2;
    y_index = 3;
    channel_index = 1;
  } else {
    x_index = 1;
    y_index = 2;
    channel_index = 3;
  }
  int64 batch = image_shape.dim(0).size();
  int64 ix = image_shape.dim(x_index).size();
  int64 iy = image_shape.dim(y_index).size();
  int64 iz = image_shape.dim(channel_index).size();

  // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize
  // returns {1, 1, 1, 1} in that case.
  std::vector<int64> ksize = GetKernelSize(op_info);
  int64 kx = ksize[x_index];
  int64 ky = ksize[y_index];

  std::vector<int64> strides = GetStrides(op_info);
  int64 sx = strides[x_index];
  int64 sy = strides[y_index];
  const auto padding = GetPadding(op_info);

  int64 ox = GetOutputSize(ix, kx, sx, padding);
  int64 oy = GetOutputSize(iy, ky, sy, padding);
  int64 oz = iz;

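  // Illustrative example (assuming GetOutputSize follows the standard
  // TensorFlow output-size formulas): for ix = 224, kx = 3, sx = 2, SAME
  // padding gives ox = ceil(224 / 2) = 112, while VALID padding gives
  // ox = ceil((224 - 3 + 1) / 2) = 111.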
  OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
      batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
  return conv_dims;
}

Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const {
  bool found_unknown_shapes = false;
  const auto& op_info = op_context.op_info;
  // x: op_info.inputs(0)
  ConvolutionDimensions dims = OpDimensionsFromInputs(
      op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
  // kx * ky - 1 comparisons per output (when kx * ky > 1),
  // or 1 copy per output (when kx * ky == 1).
  int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
  int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
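  // For example (illustrative only): a 2x2 max pool with stride 2 over a
  // 1x32x32x64 NHWC input produces 1x16x16x64 outputs, each needing
  // 2 * 2 - 1 = 3 comparisons, so ops = 16 * 16 * 64 * 3 = 49152.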

  double total_input_size = 0;
  if (dims.ky >= dims.sy) {
    total_input_size =
        CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
  } else {  // dims.ky < dims.sy
    // Vertical stride is larger than vertical kernel; assuming row-major
    // format, skip unnecessary rows (or read every ky rows per sy rows, as
    // the others are not used for output).
    const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
    total_input_size =
        data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
  }
  const double total_output_size =
      CalculateOutputSize(op_info, &found_unknown_shapes);

  Costs costs = PredictOpCountBasedCost(ops, total_input_size,
                                        total_output_size, op_info);
  costs.inaccurate = found_unknown_shapes;
  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
  costs.max_memory = total_output_size;
  return costs;
}

Costs OpLevelCostEstimator::PredictMaxPoolGrad(
    const OpContext& op_context) const {
  bool found_unknown_shapes = false;
  const auto& op_info = op_context.op_info;
  // x: op_info.inputs(0)
  // y: op_info.inputs(1)
  // y_grad: op_info.inputs(2)
  ConvolutionDimensions dims = OpDimensionsFromInputs(
      op_info.inputs(0).shape(), op_info, &found_unknown_shapes);

  int64 ops = 0;
  if (dims.kx == 1 && dims.ky == 1) {
    // 1x1 window. No need to know which input was max.
    ops = dims.batch * dims.ix * dims.iy * dims.iz;
  } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
    // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
    ops = dims.batch * dims.iz *
          (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
  } else {
    // Overlapping window: initialize with zeros, re-run maxpool, then
    // accumulate y_grad to the proper x_grad locations.
    ops = dims.batch * dims.iz *
          (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
  }
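  // For example (illustrative only): a non-overlapping 2x2/2 window on a
  // 1x32x32x64 input re-runs the pool (16 * 16 * 3 comparisons per channel)
  // and writes each of the 32 * 32 x_grad elements once, so
  // ops = 64 * (16 * 16 * 3 + 32 * 32) = 114688.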

  // Just read x and y_grad; no need to read y as we assume MaxPoolGrad
  // re-runs MaxPool internally.
  double total_input_size =
      CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
  total_input_size +=
      CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
  // Write x_grad; size equal to x.
  const double total_output_size =
      CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);

  Costs costs = PredictOpCountBasedCost(ops, total_input_size,
                                        total_output_size, op_info);
  costs.inaccurate = found_unknown_shapes;
  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
  costs.max_memory = total_output_size;
  return costs;
}

Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
  bool found_unknown_shapes = false;
  const auto& op_info = op_context.op_info;
  // x: op_info.inputs(0)
  ConvolutionDimensions dims = OpDimensionsFromInputs(
      op_info.inputs(0).shape(), op_info, &found_unknown_shapes);

  // kx * ky - 1 additions and 1 multiplication per output.
  int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
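  // For example (illustrative only): a 2x2/2 average pool over a 1x32x32x64
  // input performs 2 * 2 = 4 ops per output element (3 adds + 1 multiply),
  // so ops = 16 * 16 * 64 * 4 = 65536.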

  double total_input_size = 0;
  if (dims.ky >= dims.sy) {
    total_input_size =
        CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
  } else {  // dims.ky < dims.sy
    // Vertical stride is larger than vertical kernel; assuming row-major
    // format, skip unnecessary rows (or read every ky rows per sy rows, as
    // the others are not used for output).
    const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
    total_input_size =
        data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
  }
  const double total_output_size =
      CalculateOutputSize(op_info, &found_unknown_shapes);

  Costs costs = PredictOpCountBasedCost(ops, total_input_size,
                                        total_output_size, op_info);
  costs.inaccurate = found_unknown_shapes;
  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
  costs.max_memory = total_output_size;
  return costs;
}

Costs OpLevelCostEstimator::PredictAvgPoolGrad(
    const OpContext& op_context) const {
  bool found_unknown_shapes = false;
  const auto& op_info = op_context.op_info;
  // x's shape: op_info.inputs(0)
  // y_grad: op_info.inputs(1)

  // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
  bool shape_found = false;
  TensorShapeProto x_shape;
  if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
    const TensorProto& value = op_info.inputs(0).value();
    shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
  }
  if (!shape_found && op_info.outputs_size() > 0) {
    x_shape = op_info.outputs(0).shape();
    shape_found = true;
  }
  if (!shape_found) {
    // Set the minimum shape that's feasible.
    x_shape.Clear();
    for (int i = 0; i < 4; ++i) {
      x_shape.add_dim()->set_size(1);
    }
    found_unknown_shapes = true;
  }

  ConvolutionDimensions dims =
      OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);

  int64 ops = 0;
  if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
    // Non-overlapping window.
    ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
  } else {
    // Overlapping window.
    ops = dims.batch * dims.iz *
          (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
  }
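  // For example (illustrative only): a non-overlapping 2x2/2 window on a
  // 1x32x32x64 input gives ops = 64 * (32 * 32 + 16 * 16) = 81920.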

  const double total_input_size =
      CalculateInputSize(op_info, &found_unknown_shapes);
  const double total_output_size =
      CalculateOutputSize(op_info, &found_unknown_shapes);

  Costs costs = PredictOpCountBasedCost(ops, total_input_size,
                                        total_output_size, op_info);
  costs.inaccurate = found_unknown_shapes;
  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
  costs.max_memory = total_output_size;
  return costs;
}

Costs OpLevelCostEstimator::PredictFusedBatchNorm(
    const OpContext& op_context) const {
  bool found_unknown_shapes = false;
  const auto& op_info = op_context.op_info;
  // x: op_info.inputs(0)
  // scale: op_info.inputs(1)
  // offset: op_info.inputs(2)
  // mean: op_info.inputs(3) --> only for inference
  // variance: op_info.inputs(4) --> only for inference
  ConvolutionDimensions dims = OpDimensionsFromInputs(
      op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
  const bool is_training = IsTraining(op_info);

  int64 ops = 0;
  const auto rsqrt_cost = Eigen::internal::functor_traits<
      Eigen::internal::scalar_rsqrt_op<float>>::Cost;
  if (is_training) {
    ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
  } else {
    ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
  }
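  // Interpretation (illustrative): in training mode this counts roughly four
  // passes over the per-channel spatial elements (mean, variance, normalize,
  // scale/shift) plus a few per-channel scalar ops including one rsqrt; in
  // inference mode, two ops per element using precomputed statistics.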

  const double size_nhwc =
      CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
  const double size_c =
      CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
  double total_input_size = 0.0;
  double total_internal_read_size = 0.0;
  double total_output_size = 0.0;
  if (is_training) {
    total_input_size = size_nhwc + size_c * 2;
    total_output_size = size_nhwc + size_c * 4;
    total_internal_read_size = size_nhwc;
  } else {
    total_input_size = size_nhwc + size_c * 4;
    total_output_size = size_nhwc;
  }

  Costs costs =
      PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
                              total_output_size, op_info);
  costs.inaccurate = found_unknown_shapes;
  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
  costs.max_memory = total_output_size;
  return costs;
}

Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
    const OpContext& op_context) const {
  bool found_unknown_shapes = false;
  const auto& op_info = op_context.op_info;
  // y_backprop: op_info.inputs(0)
  // x: op_info.inputs(1)
  // scale: op_info.inputs(2)
  // mean: op_info.inputs(3)
  // variance or inverse of variance: op_info.inputs(4)
  ConvolutionDimensions dims = OpDimensionsFromInputs(
      op_info.inputs(1).shape(), op_info, &found_unknown_shapes);

  int64 ops = 0;
  const auto rsqrt_cost = Eigen::internal::functor_traits<
      Eigen::internal::scalar_rsqrt_op<float>>::Cost;
  ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);

  const double size_nhwc =
      CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
  const double size_c =
      CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
  double total_input_size = size_nhwc * 2 + size_c * 2;
  double total_internal_read_size = size_nhwc;
  double total_output_size = size_nhwc * 1 + size_c * 2;
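  // Memory accounting (illustrative): reads y_backprop and x (size_nhwc * 2)
  // plus two of the per-channel input vectors (size_c * 2); x is assumed to
  // be re-read internally (total_internal_read_size); outputs are x_backprop
  // (size_nhwc) plus the scale and offset gradients (size_c * 2).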

  Costs costs =
      PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
                              total_output_size, op_info);
  costs.inaccurate = found_unknown_shapes;
  costs.num_ops_with_unknown_shapes = found_unknown_shapes;
  costs.max_memory = total_output_size;
  return costs;
}

void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
    Costs* costs) const {
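  // For example (illustrative only): with compute_time = 10us,
  // memory_time = 8us and intermediate_memory_time = 2us, the overlapped
  // estimate is max(10, 8, 2) = 10us, while the serialized estimate is
  // 10 + 8 + 2 = 20us.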
  if (compute_memory_overlap_) {
    costs->execution_time =
        std::max(costs->intermediate_memory_time,
                 std::max(costs->compute_time, costs->memory_time));
  } else {
    costs->execution_time = costs->compute_time + costs->memory_time +
                            costs->intermediate_memory_time;
  }
}
}  // end namespace grappler
}  // end namespace tensorflow