/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"

#include "absl/strings/match.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace grappler {

// TODO(dyoon): update op to Predict method map for TF ops with V2 or V3 suffix.
constexpr int kOpsPerMac = 2;
constexpr char kGuaranteeConst[] = "GuaranteeConst";
constexpr char kAddN[] = "AddN";
constexpr char kBitCast[] = "BitCast";
constexpr char kConcatV2[] = "ConcatV2";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
constexpr char kDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kDepthToSpace[] = "DepthToSpace";
constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
    "DepthwiseConv2dNativeBackpropFilter";
constexpr char kDepthwiseConv2dNativeBackpropInput[] =
    "DepthwiseConv2dNativeBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kXlaEinsum[] = "XlaEinsum";
constexpr char kEinsum[] = "Einsum";
constexpr char kExpandDims[] = "ExpandDims";
constexpr char kFill[] = "Fill";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
constexpr char kIdentity[] = "Identity";
constexpr char kIdentityN[] = "IdentityN";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
constexpr char kSplit[] = "Split";
constexpr char kSqueeze[] = "Squeeze";
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
constexpr char kBatchMatMulV2[] = "BatchMatMulV2";
constexpr char kOneHot[] = "OneHot";
constexpr char kPack[] = "Pack";
constexpr char kRank[] = "Rank";
constexpr char kRange[] = "Range";
constexpr char kShape[] = "Shape";
constexpr char kShapeN[] = "ShapeN";
constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather";
constexpr char kGatherNd[] = "GatherNd";
constexpr char kGatherV2[] = "GatherV2";
constexpr char kScatterAdd[] = "ScatterAdd";
constexpr char kScatterDiv[] = "ScatterDiv";
constexpr char kScatterMax[] = "ScatterMax";
constexpr char kScatterMin[] = "ScatterMin";
constexpr char kScatterMul[] = "ScatterMul";
constexpr char kScatterSub[] = "ScatterSub";
constexpr char kScatterUpdate[] = "ScatterUpdate";
constexpr char kSlice[] = "Slice";
constexpr char kStridedSlice[] = "StridedSlice";
constexpr char kSpaceToDepth[] = "SpaceToDepth";
constexpr char kTranspose[] = "Transpose";
constexpr char kTile[] = "Tile";
constexpr char kMaxPool[] = "MaxPool";
constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
constexpr char kAvgPool[] = "AvgPool";
constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
constexpr char kQuantizedMatMul[] = "QuantizedMatMul";
constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
constexpr char kUnpack[] = "Unpack";
constexpr char kSoftmax[] = "Softmax";
constexpr char kResizeBilinear[] = "ResizeBilinear";
constexpr char kCropAndResize[] = "CropAndResize";
// Dynamic control flow ops.
constexpr char kSwitch[] = "Switch";
constexpr char kMerge[] = "Merge";
constexpr char kEnter[] = "Enter";
constexpr char kExit[] = "Exit";
constexpr char kNextIteration[] = "NextIteration";
// Persistent ops.
constexpr char kConst[] = "Const";
constexpr char kVariable[] = "Variable";
constexpr char kVariableV2[] = "VariableV2";
constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
constexpr char kVarHandleOp[] = "VarHandleOp";
constexpr char kVarHandlesOp[] = "_VarHandlesOp";
constexpr char kReadVariableOp[] = "ReadVariableOp";
constexpr char kReadVariablesOp[] = "_ReadVariablesOp";
constexpr char kAssignVariableOp[] = "AssignVariableOp";
constexpr char kAssignAddVariableOp[] = "AssignAddVariableOp";
constexpr char kAssignSubVariableOp[] = "AssignSubVariableOp";

static const Costs::Duration kMinComputeTime(1);
static const int64 kMinComputeOp = 1;

namespace {

std::string GetDataFormat(const OpInfo& op_info) {
  std::string data_format = "NHWC";  // Default format.
  if (op_info.attr().find("data_format") != op_info.attr().end()) {
    data_format = op_info.attr().at("data_format").s();
  }
  return data_format;
}

std::string GetFilterFormat(const OpInfo& op_info) {
  std::string filter_format = "HWIO";  // Default format.
  if (op_info.attr().find("filter_format") != op_info.attr().end()) {
    filter_format = op_info.attr().at("filter_format").s();
  }
  return filter_format;
}

Padding GetPadding(const OpInfo& op_info) {
  if (op_info.attr().find("padding") != op_info.attr().end() &&
      op_info.attr().at("padding").s() == "VALID") {
    return Padding::VALID;
  }
  return Padding::SAME;  // Default padding.
}

bool IsTraining(const OpInfo& op_info) {
  if (op_info.attr().find("is_training") != op_info.attr().end() &&
      op_info.attr().at("is_training").b()) {
    return true;
  }
  return false;
}

// TODO(dyoon): support non-4D tensors in the cost functions of convolution
// related ops (Conv, Pool, BatchNorm, and their backprops) and the related
// helper functions.
std::vector<int64> GetStrides(const OpInfo& op_info) {
  if (op_info.attr().find("strides") != op_info.attr().end()) {
    const auto strides = op_info.attr().at("strides").list().i();
    DCHECK(strides.size() == 4)
        << "Attr strides is not a length-4 vector: " << op_info.DebugString();
    if (strides.size() != 4) return {1, 1, 1, 1};
    return {strides[0], strides[1], strides[2], strides[3]};
  }
  return {1, 1, 1, 1};
}

std::vector<int64> GetKernelSize(const OpInfo& op_info) {
  if (op_info.attr().find("ksize") != op_info.attr().end()) {
    const auto ksize = op_info.attr().at("ksize").list().i();
    DCHECK(ksize.size() == 4)
        << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
    if (ksize.size() != 4) return {1, 1, 1, 1};
    return {ksize[0], ksize[1], ksize[2], ksize[3]};
  }
  // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
  // {1, 1, 1, 1} in that case.
  return {1, 1, 1, 1};
}

int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
                    const Padding& padding) {
  // Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
  // function in third_party/tensorflow/core/framework/common_shape_fns.cc.
  if (padding == Padding::VALID) {
    return (input - filter + stride) / stride;
  } else {  // SAME.
    return (input + stride - 1) / stride;
  }
}
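// Quick sanity check of the two formulas above (illustrative values, not from
// any particular model): with input = 224, filter = 7, and stride = 2, VALID
// padding gives (224 - 7 + 2) / 2 = 109 output elements, while SAME padding
// gives (224 + 2 - 1) / 2 = 112.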

// Return the output element count of a multi-input element-wise op considering
// broadcasting.
int64 CwiseOutputElementCount(const OpInfo& op_info) {
  int max_rank = 1;
  for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
    max_rank = std::max(max_rank, input_properties.shape().dim_size());
  }

  TensorShapeProto output_shape;
  output_shape.mutable_dim()->Reserve(max_rank);
  for (int i = 0; i < max_rank; ++i) {
    output_shape.add_dim();
  }

  // Expand the shape of the output to follow the numpy-style broadcast rule
  // which matches each input starting with the trailing dimensions and working
  // its way forward. To do this, iterate through each input shape's dimensions
  // in reverse order, and potentially increase the corresponding output
  // dimension.
  for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
    const TensorShapeProto& input_shape = input_properties.shape();
    for (int i = input_shape.dim_size() - 1; i >= 0; --i) {
      int output_shape_dim_index =
          i + output_shape.dim_size() - input_shape.dim_size();
      output_shape.mutable_dim(output_shape_dim_index)
          ->set_size(std::max(output_shape.dim(output_shape_dim_index).size(),
                              input_shape.dim(i).size()));
    }
  }

  int64 count = 1;
  for (int i = 0; i < output_shape.dim_size(); i++) {
    count *= output_shape.dim(i).size();
  }
  return count;
}
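// Illustrative example (shapes chosen here, not taken from the code above):
// for inputs of shape [2, 1, 4] and [3, 1], numpy-style broadcasting produces
// an output shape of [2, 3, 4], so CwiseOutputElementCount returns 24.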

// Helper function for determining whether there are repeated indices in the
// input Einsum equation.
bool CheckRepeatedDimensions(const absl::string_view dim_str) {
  int str_size = dim_str.size();
  for (int idx = 0; idx < str_size - 1; idx++) {
    if (dim_str.find(dim_str[idx], idx + 1) != std::string::npos) {
      return true;
    }
  }
  return false;
}
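// For example, CheckRepeatedDimensions("abc") is false, while
// CheckRepeatedDimensions("aba") is true because 'a' appears twice.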

// Auxiliary function for determining whether OpLevelCostEstimator is compatible
// with a given Einsum.
bool IsEinsumCorrectlyFormed(const OpContext& einsum_context) {
  const auto& op_info = einsum_context.op_info;

  auto it = op_info.attr().find("equation");
  if (it == op_info.attr().end()) return false;
  const absl::string_view equation = it->second.s();
  std::vector<std::string> equation_split = absl::StrSplit(equation, "->");

  if (equation_split.empty()) {
    LOG(WARNING) << "Einsum with malformed equation";
    return false;
  }
  std::vector<absl::string_view> input_split =
      absl::StrSplit(equation_split[0], ',');

  // The current model covers Einsum operations with two operands and a RHS.
  if (op_info.inputs_size() != 2 || equation_split.size() != 2) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
    return false;
  }
  const auto& a_input = op_info.inputs(0);
  const auto& b_input = op_info.inputs(1);
  absl::string_view rhs_str = equation_split[1];
  absl::string_view a_input_str = input_split[0];
  absl::string_view b_input_str = input_split[1];

  // Ellipses are not currently supported.
  if (absl::StrContains(a_input_str, "...") ||
      absl::StrContains(b_input_str, "...")) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", ellipsis not supported";
    return false;
  }

  constexpr int kMatrixRank = 2;

  bool a_input_shape_unknown = false;
  bool b_input_shape_unknown = false;

  TensorShapeProto a_input_shape = MaybeGetMinimumShape(
      a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
      &a_input_shape_unknown);
  TensorShapeProto b_input_shape = MaybeGetMinimumShape(
      b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
      &b_input_shape_unknown);

  if (a_input_str.size() != static_cast<size_t>(a_input_shape.dim_size()) ||
      b_input_str.size() != static_cast<size_t>(b_input_shape.dim_size())) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", equation subscripts don't match tensor rank.";
    return false;
  }

  // Subscripts where an axis appears more than once for a single input are not
  // yet supported.
  if (CheckRepeatedDimensions(a_input_str) ||
      CheckRepeatedDimensions(b_input_str) ||
      CheckRepeatedDimensions(rhs_str)) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", subscripts where an axis appears more than once for a single "
               "input are not yet supported";
    return false;
  }

  return true;
}
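// Example (illustrative equations, not from the code above): "ik,kj->ij" with
// two rank-2 inputs is accepted, whereas "...ij,jk->...ik" (ellipsis) or
// "iij,jk->ik" (repeated subscript within one operand) is rejected by the
// checks above.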

}  // namespace

// Return a minimum shape if the shape is unknown. If known, return the original
// shape.
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
                                      int rank, bool* found_unknown_shapes) {
  auto shape = original_shape;
  bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;

  if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
    *found_unknown_shapes = true;
    VLOG(2) << "Use minimum shape because the rank is unknown.";
    // The size of each dimension is at least 1, if unknown.
    for (int i = shape.dim_size(); i < rank; i++) {
      shape.add_dim()->set_size(1);
    }
  } else if (is_scalar) {
    for (int i = 0; i < rank; i++) {
      shape.add_dim()->set_size(1);
    }
  } else if (shape.dim_size() > rank) {
    *found_unknown_shapes = true;
    shape.clear_dim();
    for (int i = 0; i < rank; i++) {
      shape.add_dim()->set_size(original_shape.dim(i).size());
    }
  } else {
    for (int i = 0; i < shape.dim_size(); i++) {
      if (shape.dim(i).size() < 0) {
        *found_unknown_shapes = true;
        VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
        // The size of each dimension is at least 1, if unknown.
        shape.mutable_dim(i)->set_size(1);
      }
    }
  }
  return shape;
}
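// Illustrative behavior (values chosen here): with rank == 4, a shape of
// unknown rank becomes [1, 1, 1, 1]; with rank == 2, a partially known shape
// [-1, 128] becomes [1, 128], and a fully known [32, 128] is returned
// unchanged.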

OpLevelCostEstimator::OpLevelCostEstimator() {
  // Syntactic sugar to build and return a lambda that takes an OpInfo and
  // returns a cost.
  typedef Status (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context,
                                                   NodeCosts*) const;
  auto wrap = [this](CostImpl impl)
      -> std::function<Status(const OpContext&, NodeCosts*)> {
    return [this, impl](const OpContext& op_context, NodeCosts* node_costs) {
      return (this->*impl)(op_context, node_costs);
    };
  };

  device_cost_impl_.emplace(kConv2d,
                            wrap(&OpLevelCostEstimator::PredictConv2D));
  device_cost_impl_.emplace(
      kConv2dBackpropFilter,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
  device_cost_impl_.emplace(
      kConv2dBackpropInput,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
  device_cost_impl_.emplace(
      kFusedConv2dBiasActivation,
      wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation));
  // Reuse Conv2D for DepthwiseConv2dNative because the calculation is the
  // same, although the actual meaning of the parameters is different. See
  // comments in PredictConv2D and related functions.
  device_cost_impl_.emplace(kDepthwiseConv2dNative,
                            wrap(&OpLevelCostEstimator::PredictConv2D));
  device_cost_impl_.emplace(
      kDepthwiseConv2dNativeBackpropFilter,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
  device_cost_impl_.emplace(
      kDepthwiseConv2dNativeBackpropInput,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
  device_cost_impl_.emplace(kMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kSparseMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(
      kSparseTensorDenseMatMul,
      wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul));
  device_cost_impl_.emplace(kBatchMatMul,
                            wrap(&OpLevelCostEstimator::PredictBatchMatMul));
  device_cost_impl_.emplace(kBatchMatMulV2,
                            wrap(&OpLevelCostEstimator::PredictBatchMatMul));
  device_cost_impl_.emplace(kQuantizedMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kQuantizedMatMulV2,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kXlaEinsum,
                            wrap(&OpLevelCostEstimator::PredictEinsum));
  device_cost_impl_.emplace(kEinsum,
                            wrap(&OpLevelCostEstimator::PredictEinsum));

  device_cost_impl_.emplace(kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp));
  device_cost_impl_.emplace(kGuaranteeConst,
                            wrap(&OpLevelCostEstimator::PredictNoOp));

  device_cost_impl_.emplace(kGather,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kGatherNd,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kGatherV2,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kScatterAdd,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterDiv,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMax,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMin,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMul,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterSub,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterUpdate,
                            wrap(&OpLevelCostEstimator::PredictScatter));

  device_cost_impl_.emplace(kSlice,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kStridedSlice,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));

  device_cost_impl_.emplace(kPlaceholder,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kIdentity,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kIdentityN,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kRefIdentity,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kStopGradient,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kPreventGradient,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kReshape,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kRecv,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kSend,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kSwitch,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kMerge,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kEnter,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kExit,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kNextIteration,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kBitCast,
                            wrap(&OpLevelCostEstimator::PredictIdentity));

  device_cost_impl_.emplace(kConcatV2,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kDataFormatVecPermute,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kDepthToSpace,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kExpandDims,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kFill,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kOneHot,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kPack,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kRange,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSpaceToDepth,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSplit,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSqueeze,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kTranspose,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kTile,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kUnpack,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));

  device_cost_impl_.emplace(kRank,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kShape,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kShapeN,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kSize,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kMaxPool,
                            wrap(&OpLevelCostEstimator::PredictMaxPool));
  device_cost_impl_.emplace(kMaxPoolGrad,
                            wrap(&OpLevelCostEstimator::PredictMaxPoolGrad));
  device_cost_impl_.emplace(kAvgPool,
                            wrap(&OpLevelCostEstimator::PredictAvgPool));
  device_cost_impl_.emplace(kAvgPoolGrad,
                            wrap(&OpLevelCostEstimator::PredictAvgPoolGrad));
  device_cost_impl_.emplace(kFusedBatchNorm,
                            wrap(&OpLevelCostEstimator::PredictFusedBatchNorm));
  device_cost_impl_.emplace(
      kFusedBatchNormGrad,
      wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad));
  device_cost_impl_.emplace(kSoftmax,
                            wrap(&OpLevelCostEstimator::PredictSoftmax));
  device_cost_impl_.emplace(kResizeBilinear,
                            wrap(&OpLevelCostEstimator::PredictResizeBilinear));
  device_cost_impl_.emplace(kCropAndResize,
                            wrap(&OpLevelCostEstimator::PredictCropAndResize));
  device_cost_impl_.emplace(
      kAssignVariableOp, wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(
      kAssignAddVariableOp,
      wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(
      kAssignSubVariableOp,
      wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(kAddN, wrap(&OpLevelCostEstimator::PredictNaryOp));

  persistent_ops_ = {
      kConst,       kVariable,       kVariableV2,   kAutoReloadVariable,
      kVarHandleOp, kReadVariableOp, kVarHandlesOp, kReadVariablesOp};

#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost

  // Quantize = apply min and max bounds, multiply by scale factor and round.
  const int quantize_v2_cost =
      EIGEN_COST(scalar_product_op<float>) + EIGEN_COST(scalar_max_op<float>) +
      EIGEN_COST(scalar_min_op<float>) + EIGEN_COST(scalar_round_op<float>);
  const int quantize_and_dequantize_v2_cost =
      quantize_v2_cost + EIGEN_COST(scalar_product_op<float>);

  // Unary ops, alphabetically sorted.
  elementwise_ops_.emplace("Acos", EIGEN_COST(scalar_acos_op<float>));
  elementwise_ops_.emplace("All", EIGEN_COST(scalar_boolean_and_op));
  elementwise_ops_.emplace("ArgMax", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Asin", EIGEN_COST(scalar_asin_op<float>));
  elementwise_ops_.emplace("Atan", EIGEN_COST(scalar_atan_op<float>));
  elementwise_ops_.emplace("Atan2", EIGEN_COST(scalar_quotient_op<float>) +
                                        EIGEN_COST(scalar_atan_op<float>));
  // For now, we use the Eigen cost model for a float-to-int16 cast as an
  // example case; the Eigen cost is zero when src and dst types are identical,
  // and it uses AddCost (1) when they differ. We may implement separate cost
  // functions for cast ops, using the actual input and output types.
  elementwise_ops_.emplace(
      "Cast", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_cast_op<float, int16>>::Cost);
  elementwise_ops_.emplace("Ceil", EIGEN_COST(scalar_ceil_op<float>));
  elementwise_ops_.emplace("Cos", EIGEN_COST(scalar_cos_op<float>));
  elementwise_ops_.emplace("Dequantize", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("Erf", 1);
  elementwise_ops_.emplace("Erfc", 1);
  elementwise_ops_.emplace("Exp", EIGEN_COST(scalar_exp_op<float>));
  elementwise_ops_.emplace("Expm1", EIGEN_COST(scalar_expm1_op<float>));
  elementwise_ops_.emplace("Floor", EIGEN_COST(scalar_floor_op<float>));
  elementwise_ops_.emplace("Inv", EIGEN_COST(scalar_inverse_op<float>));
  elementwise_ops_.emplace("InvGrad", 1);
  elementwise_ops_.emplace("Lgamma", 1);
  elementwise_ops_.emplace("Log", EIGEN_COST(scalar_log_op<float>));
  elementwise_ops_.emplace("Log1p", EIGEN_COST(scalar_log1p_op<float>));
  elementwise_ops_.emplace("Max", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Min", EIGEN_COST(scalar_min_op<float>));
  elementwise_ops_.emplace("Neg", EIGEN_COST(scalar_opposite_op<float>));
  elementwise_ops_.emplace("Prod", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("QuantizeAndDequantizeV2",
                           quantize_and_dequantize_v2_cost);
  elementwise_ops_.emplace("QuantizedSigmoid",
                           EIGEN_COST(scalar_logistic_op<float>));
  elementwise_ops_.emplace("QuantizeV2", quantize_v2_cost);
  elementwise_ops_.emplace("Reciprocal", EIGEN_COST(scalar_inverse_op<float>));
  elementwise_ops_.emplace("Relu", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Relu6", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Rint", 1);
  elementwise_ops_.emplace("Round", EIGEN_COST(scalar_round_op<float>));
  elementwise_ops_.emplace("Rsqrt", EIGEN_COST(scalar_rsqrt_op<float>));
  elementwise_ops_.emplace("Sigmoid", EIGEN_COST(scalar_logistic_op<float>));
  elementwise_ops_.emplace("Sign", EIGEN_COST(scalar_sign_op<float>));
  elementwise_ops_.emplace("Sin", EIGEN_COST(scalar_sin_op<float>));
  elementwise_ops_.emplace("Sqrt", EIGEN_COST(scalar_sqrt_op<float>));
  elementwise_ops_.emplace("Square", EIGEN_COST(scalar_square_op<float>));
  elementwise_ops_.emplace("Sum", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("Tan", EIGEN_COST(scalar_tan_op<float>));
  elementwise_ops_.emplace("Tanh", EIGEN_COST(scalar_tanh_op<float>));
  elementwise_ops_.emplace("TopKV2", EIGEN_COST(scalar_max_op<float>));
  // Binary ops, alphabetically sorted.
  elementwise_ops_.emplace("Add", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("AddV2", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("ApproximateEqual", 1);
  elementwise_ops_.emplace("BiasAdd", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("QuantizedBiasAdd",
                           EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("Div", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("Equal", 1);
  elementwise_ops_.emplace("FloorDiv", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("FloorMod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Greater", 1);
  elementwise_ops_.emplace("GreaterEqual", 1);
  elementwise_ops_.emplace("Less", 1);
  elementwise_ops_.emplace("LessEqual", 1);
  elementwise_ops_.emplace("LogicalAnd", EIGEN_COST(scalar_boolean_and_op));
  elementwise_ops_.emplace("LogicalNot", 1);
  elementwise_ops_.emplace("LogicalOr", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("Maximum", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Minimum", EIGEN_COST(scalar_min_op<float>));
  elementwise_ops_.emplace("Mod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Mul", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("NotEqual", 1);
  elementwise_ops_.emplace("QuantizedAdd", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("QuantizedMul",
                           EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("RealDiv", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("ReluGrad", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Select", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("SelectV2", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("SquaredDifference",
                           EIGEN_COST(scalar_square_op<float>) +
                               EIGEN_COST(scalar_difference_op<float>));
  elementwise_ops_.emplace("Sub", EIGEN_COST(scalar_difference_op<float>));
  elementwise_ops_.emplace("TruncateDiv",
                           EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("TruncateMod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Where", 1);

#undef EIGEN_COST

  // By default, use the sum of memory_time and compute_time for
  // execution_time.
  compute_memory_overlap_ = false;
}

Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
  Costs costs;
  NodeCosts node_costs;
  if (PredictNodeCosts(op_context, &node_costs).ok()) {
    if (node_costs.has_costs) {
      return node_costs.costs;
    }
    // Convert NodeCosts to Costs.
    if (node_costs.minimum_cost_op) {
      // Override to the minimum cost. Note that some ops with minimum cost may
      // have a non-typical device (e.g., channel for _Send), which may cause
      // GetDeviceInfo(), called from PredictOpCountBasedCost(), to fail. Make
      // sure we set the minimum values directly on Costs here instead of
      // calling PredictOpCountBasedCost().
      costs.compute_time = kMinComputeTime;
      costs.execution_time = kMinComputeTime;
      costs.memory_time = 0;
      costs.intermediate_memory_time = 0;
      costs.intermediate_memory_read_time = 0;
      costs.intermediate_memory_write_time = 0;
    } else {
      // Convert NodeCosts to Costs.
      costs = PredictOpCountBasedCost(
          node_costs.num_compute_ops, node_costs.num_total_read_bytes(),
          node_costs.num_total_write_bytes(), op_context.op_info);
    }
    VLOG(1) << "Operation " << op_context.op_info.op() << " takes "
            << costs.execution_time.count() << " ns.";
    // Copy additional stats from NodeCosts to Costs.
    costs.max_memory = node_costs.max_memory;
    costs.persistent_memory = node_costs.persistent_memory;
    costs.temporary_memory = node_costs.temporary_memory;
    costs.inaccurate = node_costs.inaccurate;
    costs.num_ops_with_unknown_shapes =
        node_costs.num_nodes_with_unknown_shapes;
    costs.num_ops_total = node_costs.num_nodes;
    return costs;
  }
  // Errors during node cost estimate.
  LOG(WARNING) << "Error in PredictCost() for the op: "
               << op_context.op_info.ShortDebugString();
  costs = Costs::ZeroCosts(/*inaccurate=*/true);
  costs.num_ops_with_unknown_shapes = node_costs.num_nodes_with_unknown_shapes;
  return costs;
}

Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context,
                                              NodeCosts* node_costs) const {
  const auto& op_info = op_context.op_info;
  auto it = device_cost_impl_.find(op_info.op());
  if (it != device_cost_impl_.end()) {
    std::function<Status(const OpContext&, NodeCosts*)> estimator = it->second;
    return estimator(op_context, node_costs);
  }

  if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
    return PredictVariable(op_context, node_costs);
  }

  if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
    return PredictCwiseOp(op_context, node_costs);
  }

  VLOG(1) << "Missing accurate estimator for op: " << op_info.op();

  node_costs->num_nodes_with_unknown_op_type = 1;
  return PredictCostOfAnUnknownOp(op_context, node_costs);
}

// This method assumes a typical system composed of CPUs and GPUs, connected
// through PCIe. To define device info more precisely, override this method.
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
    const DeviceProperties& device) const {
  double gflops = -1;
  double gb_per_sec = -1;

  if (device.type() == "CPU") {
    // Check if vector instructions are available, and refine performance
    // prediction based on this.
    // Frequencies are stored in MHz in the DeviceProperties.
    gflops = device.num_cores() * device.frequency() * 1e-3;
    if (gb_per_sec < 0) {
      if (device.bandwidth() > 0) {
        gb_per_sec = device.bandwidth() / 1e6;
      } else {
        gb_per_sec = 32;
      }
    }
  } else if (device.type() == "GPU") {
    const std::string architecture = device.environment().at("architecture");
    int cores_per_multiprocessor;
    if (architecture < "3") {
      // Fermi
      cores_per_multiprocessor = 32;
    } else if (architecture < "4") {
      // Kepler
      cores_per_multiprocessor = 192;
    } else if (architecture < "6") {
      // Maxwell
      cores_per_multiprocessor = 128;
    } else {
      // Pascal (compute capability version 6) and Volta (compute capability
      // version 7)
      cores_per_multiprocessor = 64;
    }
    gflops = device.num_cores() * device.frequency() * 1e-3 *
             cores_per_multiprocessor * kOpsPerMac;
    if (device.bandwidth() > 0) {
      gb_per_sec = device.bandwidth() / 1e6;
    } else {
      gb_per_sec = 100;
    }
  } else {
    LOG_EVERY_N(WARNING, 1000) << "Unknown device type: " << device.type()
                               << ", assuming PCIe between CPU and GPU.";
    gflops = 1;  // Dummy value; data transfer ops would not have compute ops.
    gb_per_sec = 12;  // default PCIe x16 gen3.
  }
  VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
          << " gb_per_sec: " << gb_per_sec;

  return DeviceInfo(gflops, gb_per_sec);
}
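// Illustrative numbers (hypothetical devices, not taken from the code above):
// a CPU reporting 8 cores at 2000 MHz yields gflops = 8 * 2000 * 1e-3 = 16,
// while a GPU reporting 80 multiprocessors at 1500 MHz with architecture "7"
// yields gflops = 80 * 1500 * 1e-3 * 64 * 2 = 15360 (roughly 15.4 TFLOPS).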

Status OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context,
                                            NodeCosts* node_costs) const {
  const auto& op_info = op_context.op_info;
  bool found_unknown_shapes = false;
  // For element-wise operations, op count is the element count of any input.
  // We use the count of the largest input here to be more robust in case the
  // shape is unknown or partially known for the other inputs.
  int64 op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
  // If the output shape is available, try to use the element count calculated
  // from that.
  if (op_info.outputs_size() > 0) {
    op_count = std::max(
        op_count,
        CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
  }
  // Calculate the output shape possibly resulting from broadcasting.
  if (op_info.inputs_size() >= 2) {
    op_count = std::max(op_count, CwiseOutputElementCount(op_info));
  }

  int op_cost = 1;
  auto it = elementwise_ops_.find(op_info.op());
  if (it != elementwise_ops_.end()) {
    op_cost = it->second;
  } else {
    return errors::InvalidArgument("Not a cwise op: ", op_info.op());
  }

  return PredictDefaultNodeCosts(op_count * op_cost, op_context,
                                 &found_unknown_shapes, node_costs);
}

Status OpLevelCostEstimator::PredictCostOfAnUnknownOp(
    const OpContext& op_context, NodeCosts* node_costs) const {
  // Don't assume the operation is cwise; return a cost based on input/output
  // size and admit that it is inaccurate.
  bool found_unknown_shapes = false;
  node_costs->inaccurate = true;
  return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
                                 node_costs);
}

Costs OpLevelCostEstimator::PredictOpCountBasedCost(
    double operations, const OpInfo& op_info) const {
  bool unknown_shapes = false;
  const double input_size = CalculateInputSize(op_info, &unknown_shapes);
  const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
  Costs costs =
      PredictOpCountBasedCost(operations, input_size, output_size, op_info);
  costs.inaccurate = unknown_shapes;
  costs.num_ops_with_unknown_shapes = unknown_shapes;
  costs.max_memory = output_size;
  return costs;
}

Costs OpLevelCostEstimator::PredictOpCountBasedCost(
    double operations, double input_io_bytes, double output_io_bytes,
    const OpInfo& op_info) const {
  double total_io_bytes = input_io_bytes + output_io_bytes;
  const DeviceInfo device_info = GetDeviceInfo(op_info.device());
  if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
      device_info.intermediate_read_gb_per_sec <= 0 ||
      device_info.intermediate_write_gb_per_sec <= 0) {
    VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
            << " device type:" << op_info.device().type()
            << " device model:" << op_info.device().model();
  }

  Costs::NanoSeconds compute_cost(std::ceil(operations / device_info.gigaops));
  VLOG(1) << "Op:" << op_info.op() << " GOps:" << operations / 1e9
          << " Compute Time (ns):" << compute_cost.count();

  Costs::NanoSeconds memory_cost(
      std::ceil(total_io_bytes / device_info.gb_per_sec));
  VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
          << " Memory Time (ns):" << memory_cost.count();

  // Check if bytes > 0. If not, and the bandwidth is set to infinity, the
  // result would be undefined.
  double intermediate_read_time =
      (input_io_bytes > 0)
          ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
          : 0;

  double intermediate_write_time =
      (output_io_bytes > 0)
          ? std::ceil(output_io_bytes /
                      device_info.intermediate_write_gb_per_sec)
          : 0;

  Costs::NanoSeconds intermediate_memory_cost =
      compute_memory_overlap_
          ? std::max(intermediate_read_time, intermediate_write_time)
          : (intermediate_read_time + intermediate_write_time);
  VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
          << " Intermediate Memory Time (ns):"
          << intermediate_memory_cost.count();

  Costs costs = Costs::ZeroCosts();
  costs.compute_time = compute_cost;
  costs.memory_time = memory_cost;
  costs.intermediate_memory_time = intermediate_memory_cost;
  costs.intermediate_memory_read_time =
      Costs::NanoSeconds(intermediate_read_time);
  costs.intermediate_memory_write_time =
      Costs::NanoSeconds(intermediate_write_time);
  CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &costs);
  return costs;
}
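// Illustrative arithmetic (hypothetical values): on a device with
// gigaops = 1000 and gb_per_sec = 100, an op performing 2e9 operations over
// 4e6 total I/O bytes gets compute_time = ceil(2e9 / 1000) = 2e6 ns (2 ms)
// and memory_time = ceil(4e6 / 100) = 40000 ns (40 us).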

int64 OpLevelCostEstimator::CountConv2DOperations(const OpInfo& op_info,
                                                  bool* found_unknown_shapes) {
  return CountConv2DOperations(op_info, nullptr, found_unknown_shapes);
}

// Helper to translate the positional arguments into named fields.
/* static */
OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
    const TensorShapeProto& original_image_shape,
    const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
    bool* found_unknown_shapes) {
  VLOG(2) << "op features: " << op_info.DebugString();
  VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
  VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();

  int x_index, y_index, major_channel_index, minor_channel_index = -1;
  const std::string& data_format = GetDataFormat(op_info);
  if (data_format == "NCHW") {
    major_channel_index = 1;
    y_index = 2;
    x_index = 3;
  } else if (data_format == "NCHW_VECT_C") {
    // Use NCHW_VECT_C.
    minor_channel_index = 1;
    y_index = 2;
    x_index = 3;
    major_channel_index = 4;
  } else {
    // Use NHWC.
    y_index = 1;
    x_index = 2;
    major_channel_index = 3;
  }
  const std::string& filter_format = GetFilterFormat(op_info);
  int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index,
      in_minor_channel_index = -1;
  if (filter_format == "HWIO") {
    filter_y_index = 0;
    filter_x_index = 1;
    in_major_channel_index = 2;
    out_channel_index = 3;
  } else if (filter_format == "OIHW_VECT_I") {
    out_channel_index = 0;
    in_minor_channel_index = 1;
    filter_y_index = 2;
    filter_x_index = 3;
    in_major_channel_index = 4;
  } else {
    // Use OIHW.
    out_channel_index = 0;
    in_major_channel_index = 1;
    filter_y_index = 2;
    filter_x_index = 3;
  }

  auto image_shape = MaybeGetMinimumShape(original_image_shape,
                                          minor_channel_index >= 0 ? 5 : 4,
                                          found_unknown_shapes);
  auto filter_shape = MaybeGetMinimumShape(original_filter_shape,
                                           in_minor_channel_index >= 0 ? 5 : 4,
                                           found_unknown_shapes);
  VLOG(2) << "Image shape: " << image_shape.DebugString();
  VLOG(2) << "Filter shape: " << filter_shape.DebugString();

  int64 batch = image_shape.dim(0).size();
  int64 ix = image_shape.dim(x_index).size();
  int64 iy = image_shape.dim(y_index).size();
  int64 iz = minor_channel_index >= 0
                 ? image_shape.dim(minor_channel_index).size() *
                       image_shape.dim(major_channel_index).size()
                 : image_shape.dim(major_channel_index).size();
  int64 kx = filter_shape.dim(filter_x_index).size();
  int64 ky = filter_shape.dim(filter_y_index).size();
  int64 kz = in_minor_channel_index >= 0
                 ? filter_shape.dim(in_major_channel_index).size() *
                       filter_shape.dim(in_minor_channel_index).size()
                 : filter_shape.dim(in_major_channel_index).size();
  std::vector<int64> strides = GetStrides(op_info);
  const auto padding = GetPadding(op_info);
  int64 sx = strides[x_index];
  int64 sy = strides[y_index];
  int64 ox = GetOutputSize(ix, kx, sx, padding);
  int64 oy = GetOutputSize(iy, ky, sy, padding);
  int64 oz = filter_shape.dim(out_channel_index).size();
  // Only check equality when both sizes are known (in other words, when
  // neither is set to a minimum dimension size of 1).
  if (iz != 1 && kz != 1) {
    DCHECK_EQ(iz % kz, 0) << "Input channel " << iz
                          << " is not a multiple of filter channel " << kz
                          << ".";
    if (iz % kz) {
      *found_unknown_shapes = true;
    }
  } else {
    iz = kz = std::max<int64>(iz, kz);
  }
  OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
      batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};

  VLOG(1) << "Batch Size:" << batch;
  VLOG(1) << "Image Dims:" << ix << "," << iy;
  VLOG(1) << "Input Depth:" << iz;
  VLOG(1) << "Kernel Dims:" << kx << "," << ky;
  VLOG(1) << "Kernel Depth:" << kz;
  VLOG(1) << "Output Dims:" << ox << "," << oy;
  VLOG(1) << "Output Depth:" << oz;
  VLOG(1) << "Strides:" << sx << "," << sy;
  VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
  return conv_dims;
}
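// Illustrative example (shapes chosen here): an NHWC image of shape
// [32, 224, 224, 3] with an HWIO filter of shape [7, 7, 3, 64], strides
// [1, 2, 2, 1], and SAME padding yields batch = 32, ix = iy = 224, iz = 3,
// kx = ky = 7, kz = 3, oz = 64, and ox = oy = (224 + 2 - 1) / 2 = 112.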

int64 OpLevelCostEstimator::CountConv2DOperations(
    const OpInfo& op_info, ConvolutionDimensions* conv_info,
    bool* found_unknown_shapes) {
  DCHECK(op_info.op() == kConv2d || op_info.op() == kDepthwiseConv2dNative)
      << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";

  if (op_info.inputs_size() < 2) {  // Unexpected inputs.
    *found_unknown_shapes = true;
    return 0;
  }

  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
      op_info.inputs(0).shape(), op_info.inputs(1).shape(), op_info,
      found_unknown_shapes);

  // In DepthwiseConv2dNative, conv_dims.oz is actually the channel depth
  // multiplier; the effective output channel depth oz_effective is
  // conv_dims.iz * conv_dims.oz, thus # ops = N x H x W x oz_effective x 2RS.
  // Compare to Conv2D, where # ops = N x H x W x kz x oz x 2RS; with
  // oz = oz_effective, Conv2D_ops / DepthwiseConv2dNative_ops = kz.
  int64 ops = conv_dims.batch;
  ops *= conv_dims.ox * conv_dims.oy;
  ops *= conv_dims.kx * conv_dims.ky;
  if (op_info.op() == kConv2d) {
    ops *= conv_dims.kz * conv_dims.oz;
  } else {
    // Ensure the output tensor dims are correct for DepthwiseConv2dNative,
    // even though the op count matches Conv2D.
    conv_dims.oz *= conv_dims.iz;
    ops *= conv_dims.oz;
  }
  ops *= kOpsPerMac;

  if (conv_info != nullptr) {
    *conv_info = conv_dims;
  }
  return ops;
}
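// Illustrative count (continuing the hypothetical example above): for Conv2D
// with batch = 32, ox = oy = 112, kx = ky = 7, kz = 3, and oz = 64, the
// estimate is 32 * 112 * 112 * 7 * 7 * 3 * 64 * 2, roughly 7.55e9 ops.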

int64 OpLevelCostEstimator::CountMatMulOperations(const OpInfo& op_info,
                                                  bool* found_unknown_shapes) {
  return CountMatMulOperations(op_info, nullptr, found_unknown_shapes);
}

// TODO(nishantpatil): Create separate estimator for Sparse Matmul
int64 OpLevelCostEstimator::CountMatMulOperations(const OpInfo& op_info,
                                                  MatMulDimensions* mat_mul,
                                                  bool* found_unknown_shapes) {
  double ops = 0;

  if (op_info.inputs_size() < 2) {
    LOG(ERROR) << "Need 2 inputs but got " << op_info.inputs_size();
    // TODO(pcma): Try to separate invalid inputs from unknown shapes
    *found_unknown_shapes = true;
    return 0;
  }

  auto& a_matrix = op_info.inputs(0);
  auto& b_matrix = op_info.inputs(1);

  bool transpose_a = false;
  bool transpose_b = false;

  double m_dim, n_dim, k_dim, k_dim_b = 0;

  for (const auto& item : op_info.attr()) {
    VLOG(1) << "Key:" << item.first
            << " Value:" << SummarizeAttrValue(item.second);
    if (item.first == "transpose_a" && item.second.b() == true)
      transpose_a = true;
    if (item.first == "transpose_b" && item.second.b() == true)
      transpose_b = true;
  }
  VLOG(1) << "transpose_a:" << transpose_a;
  VLOG(1) << "transpose_b:" << transpose_b;
  auto a_matrix_shape =
      MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
  auto b_matrix_shape =
      MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
  if (transpose_a) {
    m_dim = a_matrix_shape.dim(1).size();
    k_dim = a_matrix_shape.dim(0).size();
  } else {
    m_dim = a_matrix_shape.dim(0).size();
    k_dim = a_matrix_shape.dim(1).size();
  }
  if (transpose_b) {
    k_dim_b = b_matrix_shape.dim(1).size();
    n_dim = b_matrix_shape.dim(0).size();
  } else {
    k_dim_b = b_matrix_shape.dim(0).size();
    n_dim = b_matrix_shape.dim(1).size();
  }

  VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
  // Only check equality when both sizes are known (in other words, when
  // neither is set to a minimum dimension size of 1).
  if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
    LOG(ERROR) << "Incompatible Matrix dimensions";
    return ops;
  } else {
    // One of k_dim and k_dim_b might be 1 (minimum dimension size).
    k_dim = std::max(k_dim, k_dim_b);
  }

  ops = m_dim * n_dim * k_dim * 2;
  VLOG(1) << "Operations for Matmul: " << ops;

  if (mat_mul != nullptr) {
    mat_mul->m = m_dim;
    mat_mul->n = n_dim;
    mat_mul->k = k_dim;
  }
  return ops;
}
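// Illustrative count (values chosen here): multiplying a [128, 256] matrix by
// a [256, 512] matrix gives m = 128, k = 256, n = 512, so the estimate is
// 128 * 512 * 256 * 2 = 33,554,432 ops (one multiply and one add per MAC).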

bool OpLevelCostEstimator::GenerateBatchMatmulContextFromEinsum(
    const OpContext& einsum_context, OpContext* batch_matmul_context,
    bool* found_unknown_shapes) const {
  // This auxiliary function transforms an Einsum OpContext into its equivalent
  // BatchMatMul OpContext. It returns a boolean indicating whether it
  // succeeded in generating the output OpContext.

  // Einsum computes a generalized contraction between tensors of arbitrary
  // dimension as defined by the equation written in the Einstein summation
  // convention. The number of tensors in the computation and the number of
  // contractions can be arbitrarily long. The current model only contemplates
  // Einsum equations that can be translated into a single BatchMatMul
  // operation. Einsum operations with more than two operands are not currently
  // supported. Subscripts where an axis appears more than once for a single
  // input and ellipses are currently also excluded. See:
  // https://www.tensorflow.org/api_docs/python/tf/einsum
  // We distinguish four kinds of dimensions, depending on their placement in
  // the equation:
  // + B: Batch dimensions: Dimensions which appear in both operands and RHS.
  // + K: Contracting dimensions: These appear in both inputs but not RHS.
  // + M: Operand A dimensions: These appear in the first operand and the RHS.
  // + N: Operand B dimensions: These appear in the second operand and the RHS.
  // Then, the operation to estimate is BatchMatMul([B,M,K],[B,K,N]).

  if (batch_matmul_context == nullptr) {
    VLOG(1) << "Output context should not be a nullptr.";
    return false;
  }
  if (!IsEinsumCorrectlyFormed(einsum_context)) return false;
  const auto& op_info = einsum_context.op_info;
  std::vector<std::string> equation_split =
      absl::StrSplit(op_info.attr().find("equation")->second.s(), "->");
  std::vector<absl::string_view> input_split =
      absl::StrSplit(equation_split[0], ',');
  const auto& a_input = op_info.inputs(0);
  const auto& b_input = op_info.inputs(1);
  absl::string_view rhs_str = equation_split[1];
  absl::string_view a_input_str = input_split[0];
  absl::string_view b_input_str = input_split[1];

  constexpr int kMatrixRank = 2;

  bool a_input_shape_unknown = false;
  bool b_input_shape_unknown = false;

  TensorShapeProto a_input_shape = MaybeGetMinimumShape(
      a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
      &a_input_shape_unknown);
  TensorShapeProto b_input_shape = MaybeGetMinimumShape(
      b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
      &b_input_shape_unknown);

  *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
                          (a_input.shape().dim_size() < kMatrixRank) ||
                          (b_input.shape().dim_size() < kMatrixRank);

  OpInfo batch_matmul_op_info = op_info;
  batch_matmul_op_info.mutable_inputs()->Clear();
  batch_matmul_op_info.set_op("BatchMatMul");

  AttrValue transpose_attribute;
  transpose_attribute.set_b(false);
  (*batch_matmul_op_info.mutable_attr())["transpose_a"] = transpose_attribute;
  (*batch_matmul_op_info.mutable_attr())["transpose_b"] = transpose_attribute;

  OpInfo::TensorProperties* a_matrix = batch_matmul_op_info.add_inputs();
  TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
  a_matrix->set_dtype(a_input.dtype());

  OpInfo::TensorProperties* b_matrix = batch_matmul_op_info.add_inputs();
  b_matrix->set_dtype(b_input.dtype());
  TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();

  TensorShapeProto_Dim m_dim;
  TensorShapeProto_Dim n_dim;
  TensorShapeProto_Dim k_dim;

  m_dim.set_size(1);
  n_dim.set_size(1);
  k_dim.set_size(1);

  for (int i_idx = 0, a_input_str_size = a_input_str.size();
       i_idx < a_input_str_size; ++i_idx) {
    if (b_input_str.find(a_input_str[i_idx]) == std::string::npos) {
      if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
        VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
        return false;
      }

      m_dim.set_size(m_dim.size() * a_input_shape.dim(i_idx).size());
      continue;
    } else if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
      // The dimension does not appear in the RHS, so it is a contracting
      // dimension.
      k_dim.set_size(k_dim.size() * a_input_shape.dim(i_idx).size());
      continue;
    }
    // It appears in both input operands, so we place it as an outer dimension
    // for the BatchMatMul.
    *(a_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
    *(b_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
  }
  for (int i_idx = 0, b_input_str_size = b_input_str.size();
       i_idx < b_input_str_size; ++i_idx) {
    if (a_input_str.find(b_input_str[i_idx]) == std::string::npos) {
      if (rhs_str.find(b_input_str[i_idx]) == std::string::npos) {
        VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
        return false;
      }
      n_dim.set_size(n_dim.size() * b_input_shape.dim(i_idx).size());
    }
  }

  // The two innermost dimensions of the BatchMatMul are added.
  *(a_matrix_shape->add_dim()) = m_dim;
  *(a_matrix_shape->add_dim()) = k_dim;
  *(b_matrix_shape->add_dim()) = k_dim;
  *(b_matrix_shape->add_dim()) = n_dim;

  *batch_matmul_context = einsum_context;
  batch_matmul_context->op_info = batch_matmul_op_info;
  return true;
}
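// Illustrative mapping (equation and shapes chosen here): for the equation
// "bij,bjk->bik" with inputs of shape [32, 8, 4] and [32, 4, 16], 'b' is a
// batch dimension, 'i' contributes to M, 'j' is contracting (K), and 'k'
// contributes to N, so the generated BatchMatMul multiplies [32, 8, 4] by
// [32, 4, 16].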
1233
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes)1234 int64 OpLevelCostEstimator::CountBatchMatMulOperations(
1235 const OpInfo& op_info, bool* found_unknown_shapes) {
1236 return CountBatchMatMulOperations(op_info, nullptr, found_unknown_shapes);
1237 }
1238
CountBatchMatMulOperations(const OpInfo & op_info,BatchMatMulDimensions * batch_mat_mul,bool * found_unknown_shapes)1239 int64 OpLevelCostEstimator::CountBatchMatMulOperations(
1240 const OpInfo& op_info, BatchMatMulDimensions* batch_mat_mul,
1241 bool* found_unknown_shapes) {
1242 if (op_info.op() != kBatchMatMul && op_info.op() != kBatchMatMulV2) {
1243 LOG(ERROR) << "Invalid Operation: " << op_info.op();
1244 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1245 *found_unknown_shapes = true;
1246 return 0;
1247 }
1248 if (op_info.inputs_size() != 2) {
1249 LOG(ERROR) << "Expected 2 inputs but got " << op_info.inputs_size();
1250 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1251 *found_unknown_shapes = true;
1252 return 0;
1253 }
1254
1255 double ops = 0;
1256 const auto& a_input = op_info.inputs(0);
1257 const auto& b_input = op_info.inputs(1);
1258
1259 // BatchMatMul requires inputs of at least matrix shape (rank 2).
1260 // The two most minor dimensions of each input are matrices that
1261 // need to be multiplied together. The other dimensions determine
1262 // the number of such MatMuls. For example, if the BatchMatMul has
1263 // inputs of shape:
1264 // a_input_shape = [2, 3, 4, 5]
1265 // b_input_shape = [2, 3, 5, 6]
1266 // then there are 2*3 = 6 MatMuls of dimensions m = 4, k = 5, n = 6
1267 // in this BatchMatMul.
1268 const int matrix_rank = 2;
1269
1270 bool a_input_shape_unknown = false;
1271 bool b_input_shape_unknown = false;
1272
1273 TensorShapeProto a_input_shape = MaybeGetMinimumShape(
1274 a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
1275 &a_input_shape_unknown);
1276 TensorShapeProto b_input_shape = MaybeGetMinimumShape(
1277 b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
1278 &b_input_shape_unknown);
1279
1280 *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
1281 (a_input.shape().dim_size() < matrix_rank) ||
1282 (b_input.shape().dim_size() < matrix_rank);
1283
1284 // Compute the number of matmuls as the max indicated at each dimension
1285 // by either input. Note that the shapes do not have to have
1286 // the same rank due to incompleteness.
1287 TensorShapeProto* bigger_rank_shape = &a_input_shape;
1288 TensorShapeProto* smaller_rank_shape = &b_input_shape;
1289 if (b_input_shape.dim_size() > a_input_shape.dim_size()) {
1290 bigger_rank_shape = &b_input_shape;
1291 smaller_rank_shape = &a_input_shape;
1292 }
1293 int num_matmuls = 1;
1294 for (int b_i = 0,
1295 s_i = smaller_rank_shape->dim_size() - bigger_rank_shape->dim_size();
1296 b_i < bigger_rank_shape->dim_size() - matrix_rank; ++b_i, ++s_i) {
1297 int b_dim = bigger_rank_shape->dim(b_i).size();
1298 int s_dim = 1;
1299 if (s_i >= 0) {
1300 s_dim = smaller_rank_shape->dim(s_i).size();
1301 }
1302 if (batch_mat_mul != nullptr) {
1303 batch_mat_mul->batch_dims.push_back(s_dim);
1304 }
1305 num_matmuls *= std::max(b_dim, s_dim);
1306 }
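// Illustrative example: with a_input_shape = [2, 3, 4, 5] and
// b_input_shape = [3, 5, 6], the batch dims are paired right-aligned, so
// num_matmuls = max(2, 1) * max(3, 3) = 6.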
1307
1308 // Build the MatMul. Note that values are ignored here since we are just
1309 // counting ops (i.e., only shapes matter).
1310 OpInfo matmul_op_info;
1311 matmul_op_info.set_op("MatMul");
1312
1313 AttrValue transpose_a;
1314 transpose_a.set_b(false);
1315 if (op_info.attr().find("adj_x") != op_info.attr().end()) {
1316 transpose_a.set_b(op_info.attr().at("adj_x").b());
1317 }
1318 (*matmul_op_info.mutable_attr())["transpose_a"] = transpose_a;
1319
1320 AttrValue transpose_b;
1321 transpose_b.set_b(false);
1322 if (op_info.attr().find("adj_y") != op_info.attr().end()) {
1323 transpose_b.set_b(op_info.attr().at("adj_y").b());
1324 }
1325 (*matmul_op_info.mutable_attr())["transpose_b"] = transpose_b;
1326
1327 OpInfo::TensorProperties* a_matrix = matmul_op_info.add_inputs();
1328 a_matrix->set_dtype(a_input.dtype());
1329 TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
1330 for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
1331 i < a_input_shape.dim_size(); ++i) {
1332 *(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
1333 }
1334
1335 OpInfo::TensorProperties* b_matrix = matmul_op_info.add_inputs();
1336 b_matrix->set_dtype(b_input.dtype());
1337 TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
1338 for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
1339 i < b_input_shape.dim_size(); ++i) {
1340 *(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
1341 }
1342 if (batch_mat_mul != nullptr) {
1343 batch_mat_mul->matmul_dims.m = (transpose_a.b())
1344 ? a_matrix_shape->dim(1).size()
1345 : a_matrix_shape->dim(0).size();
1346 batch_mat_mul->matmul_dims.k = (transpose_a.b())
1347 ? a_matrix_shape->dim(0).size()
1348 : a_matrix_shape->dim(1).size();
1349 batch_mat_mul->matmul_dims.n = (transpose_b.b())
1350 ? b_matrix_shape->dim(0).size()
1351 : b_matrix_shape->dim(1).size();
1352 }
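// For example, if adj_x is set and the trailing dims of a are [5, 4], then
// transpose_a is true, so m = 4 (dim 1) and k = 5 (dim 0); without adj_x,
// m = 5 and k = 4.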
1353
1354 for (int i = 0; i < num_matmuls; ++i) {
1355 bool matmul_unknown_shapes = false;
1356 ops += CountMatMulOperations(matmul_op_info, &matmul_unknown_shapes);
1357 *found_unknown_shapes |= matmul_unknown_shapes;
1358 }
1359 return ops;
1360 }
1361
1362 bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
1363 TensorShapeProto* tensor_shape_proto) {
1364 tensor_shape_proto->Clear();
1365 // First convert TensorProto into Tensor class so that it correctly parses
1366 // data values within TensorProto (whether they are in int_val, int64_val,
1367 // tensor_content, or any other field).
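// For example, a rank-1 DT_INT32 tensor holding the values {32, 7, 7, 64}
// (such as the "shape" input of a backprop op) becomes a TensorShapeProto
// with dims [32, 7, 7, 64]; higher-rank tensors and unsupported dtypes make
// this function return false.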
1368 Tensor tensor(tensor_proto.dtype());
1369 if (!tensor.FromProto(tensor_proto)) {
1370 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1371 << "failed to parse TensorProto: "
1372 << tensor_proto.DebugString();
1373 return false;
1374 }
1375 if (tensor.dims() != 1) {
1376 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1377 << "tensor is not 1D: " << tensor.dims();
1378 return false;
1379 }
1380 // Then, convert it back to TensorProto using AsProtoField, which makes sure
1381 // the data is in int_val, int64_val, or such repeated data fields, not in
1382 // tensor_content.
1383 TensorProto temp_tensor;
1384 tensor.AsProtoField(&temp_tensor);
1385
1386 #define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type) \
1387 do { \
1388 for (const auto& value : temp_tensor.type##_val()) { \
1389 tensor_shape_proto->add_dim()->set_size(value); \
1390 } \
1391 } while (0)
1392
1393 if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
1394 tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
1395 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
1396 } else if (tensor.dtype() == DT_INT64) {
1397 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
1398 } else if (tensor.dtype() == DT_UINT32) {
1399 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
1400 } else if (tensor.dtype() == DT_UINT64) {
1401 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
1402 } else {
1403 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1404 << "Unsupported dtype: " << tensor.dtype();
1405 return false;
1406 }
1407 #undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
1408
1409 return true;
1410 }
1411
1412 // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
1413 int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
1414 const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1415 bool* found_unknown_shapes) {
1416 int64 ops = 0;
1417
1418 DCHECK(op_info.op() == kConv2dBackpropInput ||
1419 op_info.op() == kDepthwiseConv2dNativeBackpropInput)
1420 << "Invalid Operation: not kConv2dBackpropInput nor"
1421 "kDepthwiseConv2dNativeBackpropInput";
1422
1423 if (op_info.inputs_size() < 2) {
1424 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1425 *found_unknown_shapes = true;
1426 return ops;
1427 }
1428
1429 TensorShapeProto input_shape;
1430 bool shape_found = false;
1431 if (op_info.inputs(0).has_value()) {
1432 const TensorProto& value = op_info.inputs(0).value();
1433 shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
1434 }
1435 if (!shape_found && op_info.outputs_size() == 1) {
1436 input_shape = op_info.outputs(0).shape();
1437 shape_found = true;
1438 }
1439 if (!shape_found) {
1440 // Set the minimum input size that's feasible.
1441 input_shape.Clear();
1442 for (int i = 0; i < 4; ++i) {
1443 input_shape.add_dim()->set_size(1);
1444 }
1445 *found_unknown_shapes = true;
1446 }
1447
1448 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1449 input_shape, op_info.inputs(1).shape(), op_info, found_unknown_shapes);
1450
1451 ops = conv_dims.batch;
1452 ops *= conv_dims.ox * conv_dims.oy;
1453 ops *= conv_dims.kx * conv_dims.ky;
1454 if (op_info.op() == kConv2dBackpropInput) {
1455 ops *= conv_dims.kz * conv_dims.oz;
1456 } else {
1457 // conv_dims always uses the forward-path definition, regardless of op type.
1458 conv_dims.oz *= conv_dims.iz;
1459 ops *= conv_dims.oz;
1460 }
1461 ops *= kOpsPerMac;
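// Rough worked example for the count above: with batch = 32, ox = oy = 28,
// kx = ky = 3, kz = 64, oz = 128 (Conv2DBackpropInput case),
// ops = 32 * (28 * 28) * (3 * 3) * (64 * 128) * 2 = 3,699,376,128.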
1462
1463 VLOG(1) << "Operations for" << op_info.op() << " " << ops;
1464
1465 if (returned_conv_dims != nullptr) {
1466 *returned_conv_dims = conv_dims;
1467 }
1468 return ops;
1469 }
1470
1471 int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
1472 const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1473 bool* found_unknown_shapes) {
1474 int64 ops = 0;
1475
1476 DCHECK(op_info.op() == kConv2dBackpropFilter ||
1477 op_info.op() == kDepthwiseConv2dNativeBackpropFilter)
1478 << "Invalid Operation: not kConv2dBackpropFilter nor"
1479 "kDepthwiseConv2dNativeBackpropFilter";
1480
1481 TensorShapeProto filter_shape;
1482 bool shape_found = false;
1483 if (op_info.inputs_size() >= 2 && op_info.inputs(1).has_value()) {
1484 const TensorProto& value = op_info.inputs(1).value();
1485 shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
1486 }
1487 if (!shape_found && op_info.outputs_size() == 1) {
1488 filter_shape = op_info.outputs(0).shape();
1489 shape_found = true;
1490 }
1491 if (!shape_found) {
1492 // Set the minimum filter size that's feasible.
1493 filter_shape.Clear();
1494 for (int i = 0; i < 4; ++i) {
1495 filter_shape.add_dim()->set_size(1);
1496 }
1497 *found_unknown_shapes = true;
1498 }
1499
1500 if (op_info.inputs_size() < 1) {
1501 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1502 *found_unknown_shapes = true;
1503 return ops;
1504 }
1505 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1506 op_info.inputs(0).shape(), filter_shape, op_info, found_unknown_shapes);
1507
1508 ops = conv_dims.batch;
1509 ops *= conv_dims.ox * conv_dims.oy;
1510 ops *= conv_dims.kx * conv_dims.ky;
1511 if (op_info.op() == kConv2dBackpropFilter) {
1512 ops *= conv_dims.kz * conv_dims.oz;
1513 } else {
1514 // conv_dims always uses the forward-path definition, regardless of op type.
1515 conv_dims.oz *= conv_dims.iz;
1516 ops *= conv_dims.oz;
1517 }
1518 ops *= kOpsPerMac;
1519 VLOG(1) << "Operations for" << op_info.op() << " " << ops;
1520
1521 if (returned_conv_dims != nullptr) {
1522 *returned_conv_dims = conv_dims;
1523 }
1524 return ops;
1525 }
1526
1527 int64 OpLevelCostEstimator::CalculateTensorElementCount(
1528 const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1529 VLOG(2) << " with " << DataTypeString(tensor.dtype()) << " tensor of shape "
1530 << tensor.shape().DebugString();
1531 int64 tensor_size = 1;
1532 int num_dims = std::max(1, tensor.shape().dim_size());
1533 auto tensor_shape =
1534 MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes);
1535 for (const auto& dim : tensor_shape.dim()) {
1536 tensor_size *= dim.size();
1537 }
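// For example, a tensor of shape [32, 224, 224, 3] yields
// 32 * 224 * 224 * 3 = 4,816,896 elements.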
1538 return tensor_size;
1539 }
1540
1541 int64 OpLevelCostEstimator::CalculateTensorSize(
1542 const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1543 int64 count = CalculateTensorElementCount(tensor, found_unknown_shapes);
1544 int size = DataTypeSize(BaseType(tensor.dtype()));
1545 VLOG(2) << "Count: " << count << " DataTypeSize: " << size;
1546 return count * size;
1547 }
1548
1549 int64 OpLevelCostEstimator::CalculateInputSize(const OpInfo& op_info,
1550 bool* found_unknown_shapes) {
1551 int64 total_input_size = 0;
1552 for (auto& input : op_info.inputs()) {
1553 int64 input_size = CalculateTensorSize(input, found_unknown_shapes);
1554 total_input_size += input_size;
1555 VLOG(1) << "Input Size: " << input_size
1556 << " Total Input Size:" << total_input_size;
1557 }
1558 return total_input_size;
1559 }
1560
1561 std::vector<int64> OpLevelCostEstimator::CalculateInputTensorSize(
1562 const OpInfo& op_info, bool* found_unknown_shapes) {
1563 std::vector<int64> input_tensor_size;
1564 input_tensor_size.reserve(op_info.inputs().size());
1565 for (auto& input : op_info.inputs()) {
1566 input_tensor_size.push_back(
1567 CalculateTensorSize(input, found_unknown_shapes));
1568 }
1569 return input_tensor_size;
1570 }
1571
1572 int64 OpLevelCostEstimator::CalculateLargestInputCount(
1573 const OpInfo& op_info, bool* found_unknown_shapes) {
1574 int64 largest_input_count = 0;
1575 for (auto& input : op_info.inputs()) {
1576 int64 input_count =
1577 CalculateTensorElementCount(input, found_unknown_shapes);
1578 if (input_count > largest_input_count) {
1579 largest_input_count = input_count;
1580 }
1581 VLOG(1) << "Input Count: " << input_count
1582 << " Largest Input Count:" << largest_input_count;
1583 }
1584 return largest_input_count;
1585 }
1586
1587 int64 OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info,
1588 bool* found_unknown_shapes) {
1589 int64 total_output_size = 0;
1590 // Use float as default for calculations.
1591 for (const auto& output : op_info.outputs()) {
1592 DataType dt = output.dtype();
1593 const auto& original_output_shape = output.shape();
1594 int64 output_size = DataTypeSize(BaseType(dt));
1595 int num_dims = std::max(1, original_output_shape.dim_size());
1596 auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1597 found_unknown_shapes);
1598 for (const auto& dim : output_shape.dim()) {
1599 output_size *= dim.size();
1600 }
1601 total_output_size += output_size;
1602 VLOG(1) << "Output Size: " << output_size
1603 << " Total Output Size:" << total_output_size;
1604 }
1605 return total_output_size;
1606 }
1607
1608 std::vector<int64> OpLevelCostEstimator::CalculateOutputTensorSize(
1609 const OpInfo& op_info, bool* found_unknown_shapes) {
1610 std::vector<int64> output_tensor_size;
1611 output_tensor_size.reserve(op_info.outputs().size());
1612 // Use float as default for calculations.
1613 for (const auto& output : op_info.outputs()) {
1614 DataType dt = output.dtype();
1615 const auto& original_output_shape = output.shape();
1616 int64 output_size = DataTypeSize(BaseType(dt));
1617 int num_dims = std::max(1, original_output_shape.dim_size());
1618 auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1619 found_unknown_shapes);
1620 for (const auto& dim : output_shape.dim()) {
1621 output_size *= dim.size();
1622 }
1623 output_tensor_size.push_back(output_size);
1624 }
1625 return output_tensor_size;
1626 }
1627
1628 Status OpLevelCostEstimator::PredictDefaultNodeCosts(
1629 const int64 num_compute_ops, const OpContext& op_context,
1630 bool* found_unknown_shapes, NodeCosts* node_costs) {
1631 const auto& op_info = op_context.op_info;
1632 node_costs->num_compute_ops = num_compute_ops;
1633 node_costs->num_input_bytes_accessed =
1634 CalculateInputTensorSize(op_info, found_unknown_shapes);
1635 node_costs->num_output_bytes_accessed =
1636 CalculateOutputTensorSize(op_info, found_unknown_shapes);
1637 node_costs->max_memory = node_costs->num_total_output_bytes();
1638 if (*found_unknown_shapes) {
1639 node_costs->inaccurate = true;
1640 node_costs->num_nodes_with_unknown_shapes = 1;
1641 }
1642 return Status::OK();
1643 }
1644
1645 bool HasZeroDim(const OpInfo& op_info) {
1646 for (int i = 0; i < op_info.inputs_size(); ++i) {
1647 const auto& input = op_info.inputs(i);
1648 for (int j = 0; j < input.shape().dim_size(); ++j) {
1649 const auto& dim = input.shape().dim(j);
1650 if (dim.size() == 0) {
1651 VLOG(1) << "Convolution config has zero dim "
1652 << op_info.ShortDebugString();
1653 return true;
1654 }
1655 }
1656 }
1657 return false;
1658 }
1659
1660 Status OpLevelCostEstimator::PredictConv2D(const OpContext& op_context,
1661 NodeCosts* node_costs) const {
1662 const auto& op_info = op_context.op_info;
1663 if (HasZeroDim(op_info)) {
1664 node_costs->num_nodes_with_unknown_shapes = 1;
1665 return errors::InvalidArgument("Conv2D op includes zero dimension: ",
1666 op_info.ShortDebugString());
1667 }
1668 bool found_unknown_shapes = false;
1669 int64 num_compute_ops = CountConv2DOperations(op_info, &found_unknown_shapes);
1670 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1671 &found_unknown_shapes, node_costs);
1672 }
1673
1674 Status OpLevelCostEstimator::PredictConv2DBackpropInput(
1675 const OpContext& op_context, NodeCosts* node_costs) const {
1676 const auto& op_info = op_context.op_info;
1677 if (HasZeroDim(op_info)) {
1678 node_costs->num_nodes_with_unknown_shapes = 1;
1679 return errors::InvalidArgument(
1680 "Conv2DBackpropInput op includes zero dimension",
1681 op_info.ShortDebugString());
1682 }
1683 bool found_unknown_shapes = false;
1684 int64 num_compute_ops = CountConv2DBackpropInputOperations(
1685 op_info, nullptr, &found_unknown_shapes);
1686 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1687 &found_unknown_shapes, node_costs);
1688 }
1689
1690 Status OpLevelCostEstimator::PredictConv2DBackpropFilter(
1691 const OpContext& op_context, NodeCosts* node_costs) const {
1692 const auto& op_info = op_context.op_info;
1693 if (HasZeroDim(op_info)) {
1694 node_costs->num_nodes_with_unknown_shapes = 1;
1695 return errors::InvalidArgument(
1696 "Conv2DBackpropFilter op includes zero dimension",
1697 op_info.ShortDebugString());
1698 }
1699 bool found_unknown_shapes = false;
1700 int64 num_compute_ops = CountConv2DBackpropFilterOperations(
1701 op_info, nullptr, &found_unknown_shapes);
1702 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1703 &found_unknown_shapes, node_costs);
1704 }
1705
1706 Status OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
1707 const OpContext& op_context, NodeCosts* node_costs) const {
1708 // FusedConv2DBiasActivation computes a fused kernel which implements:
1709 // 2D convolution, adds side input with separate scaling on convolution and
1710 // side inputs, then adds bias, and finally applies the ReLU activation
1711 // function to the result:
1712 //
1713 // Input -> Conv2D -> Add -> BiasAdd -> ReLU
1714 // ^ ^ ^
1715 // Filter Side Input Bias
1716 //
1717 // Note that when adding the side input, the operation multiplies the output
1718 // of Conv2D by conv_input_scale, confusingly, and the side_input by
1719 // side_input_scale.
1720 //
1721 // Note that in the special case that side_input_scale is 0, which we infer
1722 // from side_input having dimensions [], we skip that addition operation.
1723 //
1724 // For more information, see
1725 // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
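// Rough sketch of the decomposition used below: the fused op is costed as
// Conv2D + Mul (conv_input_scale) + BiasAdd + Relu, plus an extra Mul and Add
// when a non-empty side input is present.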
1726
1727 // TODO(yaozhang): Support NHWC_VECT_W.
1728 std::string data_format = GetDataFormat(op_context.op_info);
1729 if (data_format != "NCHW" && data_format != "NHWC" &&
1730 data_format != "NCHW_VECT_C") {
1731 return errors::InvalidArgument(
1732 "Unsupported data format (", data_format,
1733 ") for op: ", op_context.op_info.ShortDebugString());
1734 }
1735 std::string filter_format = GetFilterFormat(op_context.op_info);
1736 if (filter_format != "HWIO" && filter_format != "OIHW" &&
1737 filter_format != "OIHW_VECT_I") {
1738 return errors::InvalidArgument(
1739 "Unsupported filter format (", filter_format,
1740 ") for op: ", op_context.op_info.ShortDebugString());
1741 }
1742
1743 auto& conv_input = op_context.op_info.inputs(0);
1744 auto& filter = op_context.op_info.inputs(1);
1745 auto& side_input = op_context.op_info.inputs(3);
1746 auto& conv_input_scale = op_context.op_info.inputs(4);
1747 auto& side_input_scale = op_context.op_info.inputs(5);
1748
1749 // Manually compute our convolution dimensions.
1750 bool found_unknown_shapes = false;
1751 auto dims = ConvolutionDimensionsFromInputs(
1752 conv_input.shape(), filter.shape(), op_context.op_info,
1753 &found_unknown_shapes);
1754 OpInfo::TensorProperties output;
1755 if (data_format == "NCHW" || data_format == "NCHW_VECT_C") {
1756 output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox});
1757 } else if (data_format == "NHWC") {
1758 output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz});
1759 }
1760
1761 // Add the operations the fused op always computes.
1762 std::vector<OpContext> component_ops = {
1763 FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
1764 FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
1765 FusedChildContext(
1766 op_context, "BiasAdd", output,
1767 {output, output}), // Note we're no longer using bias at all
1768 FusedChildContext(op_context, "Relu", output, {output})};
1769
1770 // Add our side_input iff it's non-empty.
1771 if (side_input.shape().dim_size() > 0) {
1772 component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
1773 {side_input, side_input_scale}));
1774 component_ops.push_back(FusedChildContext(
1775 op_context, "Add", output,
1776 {output, output})); // Note that we're not using side_input here
1777 }
1778
1779 // Construct an op_context which definitely has our output shape.
1780 auto op_context_with_output = op_context;
1781 op_context_with_output.op_info.mutable_outputs()->Clear();
1782 *op_context_with_output.op_info.mutable_outputs()->Add() = output;
1783
1784 // Construct component operations and run the cost computation.
1785 if (found_unknown_shapes) {
1786 node_costs->inaccurate = true;
1787 node_costs->num_nodes_with_unknown_shapes = 1;
1788 }
1789 return PredictFusedOp(op_context_with_output, component_ops, node_costs);
1790 }
1791
1792 Status OpLevelCostEstimator::PredictMatMul(const OpContext& op_context,
1793 NodeCosts* node_costs) const {
1794 const auto& op_info = op_context.op_info;
1795 bool found_unknown_shapes = false;
1796 int64 num_compute_ops = CountMatMulOperations(op_info, &found_unknown_shapes);
1797 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1798 &found_unknown_shapes, node_costs);
1799 }
1800
1801 Status OpLevelCostEstimator::PredictEinsum(const OpContext& op_context,
1802 NodeCosts* node_costs) const {
1803 const auto& op_info = op_context.op_info;
1804
1805 auto it = op_info.attr().find("equation");
1806 if (it == op_info.attr().end()) {
1807 return errors::InvalidArgument("Einsum op doesn't have equation attr: ",
1808 op_info.ShortDebugString());
1809 }
1810
1811 OpContext batch_matmul_op_context;
1812 bool found_unknown_shapes = false;
1813 bool success = GenerateBatchMatmulContextFromEinsum(
1814 op_context, &batch_matmul_op_context, &found_unknown_shapes);
1815 if (found_unknown_shapes) {
1816 node_costs->inaccurate = true;
1817 node_costs->num_nodes_with_unknown_shapes = 1;
1818 }
1819 if (!success) {
1820 return PredictCostOfAnUnknownOp(op_context, node_costs);
1821 }
1822 return PredictNodeCosts(batch_matmul_op_context, node_costs);
1823 }
1824
1825 Status OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
1826 const OpContext& op_context, NodeCosts* node_costs) const {
1827 const auto& op_info = op_context.op_info;
1828 bool found_unknown_shapes = false;
1829 // input[0]: indices in sparse matrix a
1830 // input[1]: values in sparse matrix a
1831 // input[2]: shape of matrix a
1832 // input[3]: matrix b
1833 // See
1834 // https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
1835 int64 num_elems_in_a =
1836 CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
1837 auto b_matrix = op_info.inputs(3);
1838 auto b_matrix_shape =
1839 MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
1840 int64 n_dim = b_matrix_shape.dim(1).size();
1841
1842 // Each element in A is multiplied and added with an element from each column
1843 // in b.
1844 const int64 op_count = kOpsPerMac * num_elems_in_a * n_dim;
1845
1846 int64 a_indices_input_size =
1847 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1848 int64 a_values_input_size =
1849 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1850 int64 a_shape_input_size =
1851 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1852 int64 b_input_size =
1853 num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
1854 int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
1855
1856 node_costs->num_compute_ops = op_count;
1857 node_costs->num_input_bytes_accessed = {a_indices_input_size,
1858 a_values_input_size,
1859 a_shape_input_size, b_input_size};
1860 node_costs->num_output_bytes_accessed = {output_size};
1861 if (found_unknown_shapes) {
1862 node_costs->inaccurate = true;
1863 node_costs->num_nodes_with_unknown_shapes = 1;
1864 }
1865 return Status::OK();
1866 }
1867
1868 Status OpLevelCostEstimator::PredictNoOp(const OpContext& op_context,
1869 NodeCosts* node_costs) const {
1870 const auto& op_info = op_context.op_info;
1871 VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1872 // By default, NodeCosts is initialized to zero ops and bytes.
1873 return Status::OK();
1874 }
1875
1876 Status OpLevelCostEstimator::PredictPureMemoryOp(const OpContext& op_context,
1877 NodeCosts* node_costs) const {
1878 // Each output element is a copy of some element from input, with no required
1879 // computation, so just compute memory costs.
1880 bool found_unknown_shapes = false;
1881 node_costs->num_nodes_with_pure_memory_op = 1;
1882 return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
1883 node_costs);
1884 }
1885
1886 Status OpLevelCostEstimator::PredictIdentity(const OpContext& op_context,
1887 NodeCosts* node_costs) const {
1888 const auto& op_info = op_context.op_info;
1889 VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Identity";
1890 node_costs->minimum_cost_op = true;
1891 node_costs->num_compute_ops = kMinComputeOp;
1892 // The Identity op internally passes the input tensor buffer's pointer to the
1893 // output tensor buffer; there is no actual memory operation.
1894 node_costs->num_input_bytes_accessed = {0};
1895 node_costs->num_output_bytes_accessed = {0};
1896 bool inaccurate = false;
1897 node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1898 if (inaccurate) {
1899 node_costs->inaccurate = true;
1900 node_costs->num_nodes_with_unknown_shapes = 1;
1901 }
1902 return Status::OK();
1903 }
1904
1905 Status OpLevelCostEstimator::PredictVariable(const OpContext& op_context,
1906 NodeCosts* node_costs) const {
1907 const auto& op_info = op_context.op_info;
1908 VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Variable";
1909 node_costs->minimum_cost_op = true;
1910 node_costs->num_compute_ops = kMinComputeOp;
1911 // Variables are persistent ops; initialized before step; hence, no memory
1912 // cost.
1913 node_costs->num_input_bytes_accessed = {0};
1914 node_costs->num_output_bytes_accessed = {0};
1915 bool inaccurate = false;
1916 node_costs->persistent_memory = CalculateOutputSize(op_info, &inaccurate);
1917 if (inaccurate) {
1918 node_costs->inaccurate = true;
1919 node_costs->num_nodes_with_unknown_shapes = 1;
1920 }
1921 return Status::OK();
1922 }
1923
1924 Status OpLevelCostEstimator::PredictBatchMatMul(const OpContext& op_context,
1925 NodeCosts* node_costs) const {
1926 const auto& op_info = op_context.op_info;
1927 bool found_unknown_shapes = false;
1928 int64 num_compute_ops =
1929 CountBatchMatMulOperations(op_info, &found_unknown_shapes);
1930 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1931 &found_unknown_shapes, node_costs);
1932 }
1933
1934 Status OpLevelCostEstimator::PredictMetadata(const OpContext& op_context,
1935 NodeCosts* node_costs) const {
1936 const auto& op_info = op_context.op_info;
1937 node_costs->minimum_cost_op = true;
1938 node_costs->num_compute_ops = kMinComputeOp;
1939 node_costs->num_input_bytes_accessed = {0};
1940 node_costs->num_output_bytes_accessed = {0};
1941 bool inaccurate = false;
1942 node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1943 if (inaccurate) {
1944 node_costs->inaccurate = true;
1945 node_costs->num_nodes_with_unknown_shapes = 1;
1946 }
1947 return Status::OK();
1948 }
1949
1950 Status OpLevelCostEstimator::PredictGatherOrSlice(const OpContext& op_context,
1951 NodeCosts* node_costs) const {
1952 // Gather & Slice ops can have a very large input, but only access a small
1953 // part of it. For these ops, the size of the output determines the memory cost.
1954 const auto& op_info = op_context.op_info;
1955
1956 const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
1957 if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
1958 return errors::InvalidArgument(
1959 op_info.op(),
1960 " Op doesn't have valid input / output: ", op_info.ShortDebugString());
1961 }
1962
1963 bool unknown_shapes = false;
1964
1965 // Each output element is a copy of some element from input.
1966 // For roofline estimate we assume each copy has a unit cost.
1967 const int64 op_count =
1968 CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
1969 node_costs->num_compute_ops = op_count;
1970
1971 const int64 output_size = CalculateOutputSize(op_info, &unknown_shapes);
1972 node_costs->num_output_bytes_accessed = {output_size};
1973
1974 node_costs->num_input_bytes_accessed.reserve(op_info.inputs().size());
1975 int64 input_size = output_size;
1976 // Note that the bytes accessed from input(0) are not equal to the input(0)
1977 // tensor size; they equal the output size, since the input access is an
1978 // indexed gather or slice (ignoring duplicate indices).
1979 node_costs->num_input_bytes_accessed.push_back(input_size);
1980 int begin_input_index = 1;
1981 int end_input_index;
1982 if (op_info.op() == "Slice") {
1983 // Slice: 'input' (omitted), 'begin', 'size'
1984 end_input_index = 3;
1985 } else if (op_info.op() == "StridedSlice") {
1986 // StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
1987 end_input_index = 4;
1988 } else {
1989 // Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
1990 end_input_index = 2;
1991 }
1992 for (int i = begin_input_index; i < end_input_index; ++i) {
1993 node_costs->num_input_bytes_accessed.push_back(
1994 CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes));
1995 }
1996 if (unknown_shapes) {
1997 node_costs->inaccurate = true;
1998 node_costs->num_nodes_with_unknown_shapes = 1;
1999 }
2000 return Status::OK();
2001 }
2002
2003 Status OpLevelCostEstimator::PredictScatter(const OpContext& op_context,
2004 NodeCosts* node_costs) const {
2005 // Scatter ops sparsely access a reference input and output tensor.
2006 const auto& op_info = op_context.op_info;
2007 bool found_unknown_shapes = false;
2008
2009 // input[0]: ref tensor that will be sparsely accessed
2010 // input[1]: indices - A tensor of indices into the first dimension of ref.
2011 // input[2]: updates where updates.shape = indices.shape + ref.shape[1:]
2012 // See
2013 // https://www.tensorflow.org/api_docs/python/tf/scatter_add and
2014 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/state_ops.cc#L146
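// Worked example: with ref of shape [1000, 64] and indices of shape [32],
// num_indices = 32 and num_elems_in_ref_per_index = 64, so
// op_count = 32 * 64 = 2048 and ref_input_size = 2048 * 4 = 8192 bytes for
// DT_FLOAT.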
2015
2016 const int64 num_indices =
2017 CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
2018
2019 int64 num_elems_in_ref_per_index = 1;
2020 auto ref_tensor_shape = MaybeGetMinimumShape(
2021 op_info.inputs(0).shape(), op_info.inputs(0).shape().dim_size(),
2022 &found_unknown_shapes);
2023 for (int i = 1; i < ref_tensor_shape.dim().size(); ++i) {
2024 num_elems_in_ref_per_index *= ref_tensor_shape.dim(i).size();
2025 }
2026 const int64 op_count = num_indices * num_elems_in_ref_per_index;
2027 node_costs->num_compute_ops = op_count;
2028
2029 // Sparsely access ref so input size depends on the number of operations
2030 int64 ref_input_size =
2031 op_count * DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2032 int64 indices_input_size =
2033 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2034 int64 updates_input_size =
2035 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2036 node_costs->num_input_bytes_accessed = {ref_input_size, indices_input_size,
2037 updates_input_size};
2038
2039 // Sparsely access ref so output size depends on the number of operations
2040 int64 output_size =
2041 op_count * DataTypeSize(BaseType(op_info.outputs(0).dtype()));
2042 node_costs->num_output_bytes_accessed = {output_size};
2043
2044 if (found_unknown_shapes) {
2045 node_costs->inaccurate = true;
2046 node_costs->num_nodes_with_unknown_shapes = 1;
2047 }
2048 return Status::OK();
2049 }
2050
2051 Status OpLevelCostEstimator::PredictFusedOp(
2052 const OpContext& op_context,
2053 const std::vector<OpContext>& fused_op_contexts,
2054 NodeCosts* node_costs) const {
2055 // Note that PredictDefaultNodeCosts will get the correct memory costs from
2056 // the node's inputs and outputs; but we don't want to have to re-implement
2057 // the logic for computing the operation count of each of our component
2058 // operations here; so we simply add the compute times of each component
2059 // operation, then update the cost.
2060 bool found_unknown_shapes = false;
2061 Status s =
2062 PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, node_costs);
2063
2064 for (auto& fused_op : fused_op_contexts) {
2065 NodeCosts fused_node_costs;
2066 s.Update(PredictNodeCosts(fused_op, &fused_node_costs));
2067 node_costs->num_compute_ops += fused_node_costs.num_compute_ops;
2068 node_costs->inaccurate |= fused_node_costs.inaccurate;
2069 // Set, not increment. Note that we are predicting the cost of one fused
2070 // node, not a function node composed of many nodes.
2071 node_costs->num_nodes_with_unknown_shapes |=
2072 fused_node_costs.num_nodes_with_unknown_shapes;
2073 node_costs->num_nodes_with_unknown_op_type |=
2074 fused_node_costs.num_nodes_with_unknown_op_type;
2075 node_costs->num_nodes_with_pure_memory_op |=
2076 fused_node_costs.num_nodes_with_pure_memory_op;
2077 }
2078
2079 return Status::OK();
2080 }
2081
2082 /* static */
2083 OpContext OpLevelCostEstimator::FusedChildContext(
2084 const OpContext& parent, const std::string& op_name,
2085 const OpInfo::TensorProperties& output,
2086 const std::vector<OpInfo::TensorProperties>& inputs) {
2087 // Setup the base parameters of our new context.
2088 OpContext new_context;
2089 new_context.name = op_name;
2090 new_context.device_name = parent.device_name;
2091 new_context.op_info = parent.op_info;
2092 new_context.op_info.set_op(op_name);
2093
2094 // Setup the inputs of our new context.
2095 new_context.op_info.mutable_inputs()->Clear();
2096 for (const auto& input : inputs) {
2097 *new_context.op_info.mutable_inputs()->Add() = input;
2098 }
2099
2100 // Setup the output of our new context.
2101 new_context.op_info.mutable_outputs()->Clear();
2102 *new_context.op_info.mutable_outputs()->Add() = output;
2103
2104 return new_context;
2105 }
2106
2107 /* static */
2108 OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
2109 DataType type, const std::vector<int64>& dims) {
2110 OpInfo::TensorProperties ret;
2111 ret.set_dtype(type);
2112
2113 auto shape = ret.mutable_shape();
2114 for (const int dim : dims) {
2115 shape->add_dim()->set_size(dim);
2116 }
2117
2118 return ret;
2119 }
2120
2121 /* static */
2122 OpLevelCostEstimator::ConvolutionDimensions
2123 OpLevelCostEstimator::OpDimensionsFromInputs(
2124 const TensorShapeProto& original_image_shape, const OpInfo& op_info,
2125 bool* found_unknown_shapes) {
2126 VLOG(2) << "op features: " << op_info.DebugString();
2127 VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
2128 auto image_shape =
2129 MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
2130 VLOG(2) << "Image shape: " << image_shape.DebugString();
2131
2132 int x_index, y_index, channel_index;
2133 const std::string& data_format = GetDataFormat(op_info);
2134 if (data_format == "NCHW") {
2135 channel_index = 1;
2136 y_index = 2;
2137 x_index = 3;
2138 } else {
2139 y_index = 1;
2140 x_index = 2;
2141 channel_index = 3;
2142 }
2143 int64 batch = image_shape.dim(0).size();
2144 int64 ix = image_shape.dim(x_index).size();
2145 int64 iy = image_shape.dim(y_index).size();
2146 int64 iz = image_shape.dim(channel_index).size();
2147
2148 // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
2149 // {1, 1, 1, 1} in that case.
2150 std::vector<int64> ksize = GetKernelSize(op_info);
2151 int64 kx = ksize[x_index];
2152 int64 ky = ksize[y_index];
2153 // These ops don't support groupwise operation, therefore kz == iz.
2154 int64 kz = iz;
2155
2156 std::vector<int64> strides = GetStrides(op_info);
2157 int64 sx = strides[x_index];
2158 int64 sy = strides[y_index];
2159 const auto padding = GetPadding(op_info);
2160
2161 int64 ox = GetOutputSize(ix, kx, sx, padding);
2162 int64 oy = GetOutputSize(iy, ky, sy, padding);
2163 int64 oz = iz;
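// Illustrative values: an NHWC image of shape [32, 224, 224, 3] with
// ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], and VALID padding gives
// batch = 32, ix = iy = 224, iz = kz = oz = 3, kx = ky = 2, sx = sy = 2,
// and ox = oy = 112.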
2164
2165 OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
2166 batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};
2167 return conv_dims;
2168 }
2169
2170 Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context,
2171 NodeCosts* node_costs) const {
2172 bool found_unknown_shapes = false;
2173 const auto& op_info = op_context.op_info;
2174 // x: op_info.inputs(0)
2175 ConvolutionDimensions dims = OpDimensionsFromInputs(
2176 op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2177 // kx * ky - 1 comparisons per output (when kx * ky > 1),
2178 // or 1 copy per output (when kx * ky == 1).
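// For example, a 3x3 window costs 3 * 3 - 1 = 8 comparisons per output
// element, while a 1x1 window costs a single copy per output element.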
2179 int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
2180 int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
2181 node_costs->num_compute_ops = ops;
2182
2183 int64 input_size = 0;
2184 if (dims.ky >= dims.sy) {
2185 input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2186 } else { // dims.ky < dims.sy
2187 // Vertical stride is larger than vertical kernel; assuming row-major
2188 // format, skip unnecessary rows (or read every ky rows per sy rows, as the
2189 // others are not used for output).
2190 const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2191 input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2192 }
2193 node_costs->num_input_bytes_accessed = {input_size};
2194 const int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
2195 node_costs->num_output_bytes_accessed = {output_size};
2196 node_costs->max_memory = output_size;
2197 if (found_unknown_shapes) {
2198 node_costs->inaccurate = true;
2199 node_costs->num_nodes_with_unknown_shapes = 1;
2200 }
2201 return Status::OK();
2202 }
2203
2204 Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context,
2205 NodeCosts* node_costs) const {
2206 bool found_unknown_shapes = false;
2207 const auto& op_info = op_context.op_info;
2208 // x: op_info.inputs(0)
2209 // y: op_info.inputs(1)
2210 // y_grad: op_info.inputs(2)
2211 if (op_info.inputs_size() < 3) {
2212 return errors::InvalidArgument("MaxPoolGrad op has invalid inputs: ",
2213 op_info.ShortDebugString());
2214 }
2215
2216 ConvolutionDimensions dims = OpDimensionsFromInputs(
2217 op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2218
2219 int64 ops = 0;
2220 if (dims.kx == 1 && dims.ky == 1) {
2221 // 1x1 window. No need to know which input was max.
2222 ops = dims.batch * dims.ix * dims.iy * dims.iz;
2223 } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2224 // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
2225 ops = dims.batch * dims.iz *
2226 (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
2227 } else {
2228 // Overlapping window: initialize with zeros, re-run maxpool, then
2229 // accumulate y_grad to the proper x_grad locations.
2230 ops = dims.batch * dims.iz *
2231 (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
2232 }
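// Worked example for the non-overlapping case: a 2x2 window with stride 2 on
// a [32, 112, 112, 64] input (ox = oy = 56) gives
// ops = 32 * 64 * (56 * 56 * (2 * 2 - 1) + 112 * 112) = 44,957,696.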
2233 node_costs->num_compute_ops = ops;
2234
2235 // Just read x and y_grad; no need to read y as we assume MaxPoolGrad re-runs
2236 // MaxPool internally.
2237 const int64 input0_size =
2238 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2239 const int64 input2_size =
2240 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2241 node_costs->num_input_bytes_accessed = {input0_size, 0, input2_size};
2242 // Write x_grad; size equal to x.
2243 const int64 output_size =
2244 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2245 node_costs->num_output_bytes_accessed = {output_size};
2246 node_costs->max_memory = output_size;
2247
2248 if (found_unknown_shapes) {
2249 node_costs->inaccurate = true;
2250 node_costs->num_nodes_with_unknown_shapes = 1;
2251 }
2252 return Status::OK();
2253 }
2254
2255 /* This predict function handles three types of TensorFlow ops:
2256 * AssignVariableOp/AssignAddVariableOp/AssignSubVariableOp. Broadcasting is
2257 * not supported by these ops, so the input tensor's shape is sufficient to
2258 * compute the cost. */
2259 Status OpLevelCostEstimator::PredictAssignVariableOps(
2260 const OpContext& op_context, NodeCosts* node_costs) const {
2261 bool found_unknown_shapes = false;
2262 const auto& op_info = op_context.op_info;
2263 /* The first input of these ops is a reference to the assignee. */
2264 if (op_info.inputs_size() != 2) {
2265 return errors::InvalidArgument("AssignVariable op has invalid input: ",
2266 op_info.ShortDebugString());
2267 }
2268
2269 const int64 ops = op_info.op() == kAssignVariableOp
2270 ? 0
2271 : CalculateTensorElementCount(op_info.inputs(1),
2272 &found_unknown_shapes);
2273 node_costs->num_compute_ops = ops;
2274 const int64 input_size = CalculateInputSize(op_info, &found_unknown_shapes);
2275 node_costs->num_input_bytes_accessed = {input_size};
2276 // TODO(dyoon): check whether these ops actually write data;
2277 // the op itself doesn't have an output tensor, but it may modify the input
2278 // (ref or resource). Maybe use node_costs->internal_write_bytes.
2279 node_costs->num_output_bytes_accessed = {0};
2280 if (found_unknown_shapes) {
2281 node_costs->inaccurate = true;
2282 node_costs->num_nodes_with_unknown_shapes = 1;
2283 }
2284 return Status::OK();
2285 }
2286
2287 Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context,
2288 NodeCosts* node_costs) const {
2289 bool found_unknown_shapes = false;
2290 const auto& op_info = op_context.op_info;
2291 // x: op_info.inputs(0)
2292 ConvolutionDimensions dims = OpDimensionsFromInputs(
2293 op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2294
2295 // kx * ky - 1 additions and 1 multiplication per output.
2296 int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
2297 node_costs->num_compute_ops = ops;
2298
2299 int64 input_size;
2300 if (dims.ky >= dims.sy) {
2301 input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2302 } else { // dims.ky < dims.sy
2303 // Vertical stride is larger than vertical kernel; assuming row-major
2304 // format, skip unnecessary rows (or read every ky rows per sy rows, as the
2305 // others are not used for output).
2306 const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2307 input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2308 }
2309 node_costs->num_input_bytes_accessed = {input_size};
2310
2311 const int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
2312 node_costs->num_output_bytes_accessed = {output_size};
2313 node_costs->max_memory = output_size;
2314
2315 if (found_unknown_shapes) {
2316 node_costs->inaccurate = true;
2317 node_costs->num_nodes_with_unknown_shapes = 1;
2318 }
2319 return Status::OK();
2320 }
2321
2322 Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context,
2323 NodeCosts* node_costs) const {
2324 bool found_unknown_shapes = false;
2325 const auto& op_info = op_context.op_info;
2326 // x's shape: op_info.inputs(0)
2327 // y_grad: op_info.inputs(1)
2328
2329 // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
2330 bool shape_found = false;
2331 TensorShapeProto x_shape;
2332 if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
2333 const TensorProto& value = op_info.inputs(0).value();
2334 shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
2335 }
2336 if (!shape_found && op_info.outputs_size() > 0) {
2337 x_shape = op_info.outputs(0).shape();
2338 shape_found = true;
2339 }
2340 if (!shape_found) {
2341 // Set the minimum shape that's feasible.
2342 x_shape.Clear();
2343 for (int i = 0; i < 4; ++i) {
2344 x_shape.add_dim()->set_size(1);
2345 }
2346 found_unknown_shapes = true;
2347 }
2348
2349 ConvolutionDimensions dims =
2350 OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);
2351
2352 int64 ops = 0;
2353 if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2354 // Non-overlapping window.
2355 ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
2356 } else {
2357 // Overlapping window.
2358 ops = dims.batch * dims.iz *
2359 (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
2360 }
2361 auto s = PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2362 node_costs);
2363 node_costs->max_memory = node_costs->num_total_output_bytes();
2364 return s;
2365 }
2366
2367 Status OpLevelCostEstimator::PredictFusedBatchNorm(
2368 const OpContext& op_context, NodeCosts* node_costs) const {
2369 bool found_unknown_shapes = false;
2370 const auto& op_info = op_context.op_info;
2371 // x: op_info.inputs(0)
2372 // scale: op_info.inputs(1)
2373 // offset: op_info.inputs(2)
2374 // mean: op_info.inputs(3) --> only for inference
2375 // variance: op_info.inputs(4) --> only for inference
2376 ConvolutionDimensions dims = OpDimensionsFromInputs(
2377 op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2378 const bool is_training = IsTraining(op_info);
2379
2380 int64 ops = 0;
2381 const auto rsqrt_cost = Eigen::internal::functor_traits<
2382 Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2383 if (is_training) {
2384 ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
2385 } else {
2386 ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
2387 }
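// For example, in inference mode on a [32, 112, 112, 64] input,
// ops = 32 * 112 * 112 * 64 * 2 = 51,380,224.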
2388 node_costs->num_compute_ops = ops;
2389
2390 const int64 size_nhwc =
2391 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2392 const int64 size_c =
2393 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2394 if (is_training) {
2395 node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c};
2396 node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2397 size_c};
2398 // FusedBatchNorm in training mode internally re-reads the input tensor:
2399 // once for mean/variance, and a second internal read for the actual scaling.
2400 // Assume small intermediate data such as mean / variance (size_c) can be
2401 // cached on-chip.
2402 node_costs->internal_read_bytes = size_nhwc;
2403 } else {
2404 node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2405 size_c};
2406 node_costs->num_output_bytes_accessed = {size_nhwc};
2407 }
2408 node_costs->max_memory = node_costs->num_total_output_bytes();
2409
2410 if (found_unknown_shapes) {
2411 node_costs->inaccurate = true;
2412 node_costs->num_nodes_with_unknown_shapes = 1;
2413 }
2414 return Status::OK();
2415 }
2416
2417 Status OpLevelCostEstimator::PredictFusedBatchNormGrad(
2418 const OpContext& op_context, NodeCosts* node_costs) const {
2419 bool found_unknown_shapes = false;
2420 const auto& op_info = op_context.op_info;
2421 // y_backprop: op_info.inputs(0)
2422 // x: op_info.inputs(1)
2423 // scale: op_info.inputs(2)
2424 // mean: op_info.inputs(3)
2425 // variance or inverse of variance: op_info.inputs(4)
2426 ConvolutionDimensions dims = OpDimensionsFromInputs(
2427 op_info.inputs(1).shape(), op_info, &found_unknown_shapes);
2428
2429 int64 ops = 0;
2430 const auto rsqrt_cost = Eigen::internal::functor_traits<
2431 Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2432 ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
2433 node_costs->num_compute_ops = ops;
2434
2435 const int64 size_nhwc =
2436 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2437 const int64 size_c =
2438 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2439 // TODO(dyoon): fix missing memory cost for variance input (size_c) and
2440 // yet another read of y_backprop (size_nhwc) internally.
2441 node_costs->num_input_bytes_accessed = {size_nhwc, size_nhwc, size_c, size_c};
2442 node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c};
2443 // FusedBatchNormGrad has to read y_backprop internally.
2444 node_costs->internal_read_bytes = size_nhwc;
2445 node_costs->max_memory = node_costs->num_total_output_bytes();
2446
2447 if (found_unknown_shapes) {
2448 node_costs->inaccurate = true;
2449 node_costs->num_nodes_with_unknown_shapes = 1;
2450 }
2451 return Status::OK();
2452 }
2453
2454 Status OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context,
2455 NodeCosts* node_costs) const {
2456 const auto& op_info = op_context.op_info;
2457 bool found_unknown_shapes = false;
2458 // Calculate the largest known tensor size across all inputs and output.
2459 int64 op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
2460 // If output shape is available, try to use the element count calculated from
2461 // that.
2462 if (op_info.outputs_size() > 0) {
2463 op_count = std::max(
2464 op_count,
2465 CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
2466 }
2467 // Also calculate the output shape possibly resulting from broadcasting.
2468 // Note that some Nary ops (such as AddN) do not support broadcasting,
2469 // but we're including this here for completeness.
2470 if (op_info.inputs_size() >= 2) {
2471 op_count = std::max(op_count, CwiseOutputElementCount(op_info));
2472 }
2473
2474 // Nary ops perform one operation for every element in every input tensor.
2475 op_count *= op_info.inputs_size() - 1;
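// For example, an AddN over 4 input tensors of N elements each performs
// (4 - 1) * N elementwise additions.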
2476
2477 const auto sum_cost = Eigen::internal::functor_traits<
2478 Eigen::internal::scalar_sum_op<float>>::Cost;
2479 return PredictDefaultNodeCosts(op_count * sum_cost, op_context,
2480 &found_unknown_shapes, node_costs);
2481 }
2482
2483 // softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
2484 Status OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context,
2485 NodeCosts* node_costs) const {
2486 bool found_unknown_shapes = false;
2487 const int64 logits_size = CalculateTensorElementCount(
2488 op_context.op_info.inputs(0), &found_unknown_shapes);
2489 // Softmax input rank should be >=1.
2490 TensorShapeProto logits_shape = op_context.op_info.inputs(0).shape();
2491 if (logits_shape.unknown_rank() || logits_shape.dim_size() == 0) {
2492 return errors::InvalidArgument("Softmax op has invalid input: ",
2493 op_context.op_info.ShortDebugString());
2494 }
2495
2496 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2497
2498 // Every element of <logits> will be exponentiated, have that result included
2499 // in a sum across j, and also have that result multiplied by the reciprocal
2500 // of the sum_j. In addition, we'll compute 1/sum_j for every i.
2501 auto ops =
2502 (EIGEN_COST(scalar_exp_op<float>) + EIGEN_COST(scalar_sum_op<float>) +
2503 EIGEN_COST(scalar_product_op<float>)) *
2504 logits_size +
2505 EIGEN_COST(scalar_inverse_op<float>) * logits_shape.dim(0).size();
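// For logits of shape [B, C] this is (exp + sum + product costs) * B * C plus
// inverse cost * B, i.e. three per-element ops and one reciprocal per row.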
2506
2507 #undef EIGEN_COST
2508 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2509 node_costs);
2510 }
2511
2512 Status OpLevelCostEstimator::PredictResizeBilinear(
2513 const OpContext& op_context, NodeCosts* node_costs) const {
2514 bool found_unknown_shapes = false;
2515
2516 if (op_context.op_info.outputs().empty() ||
2517 op_context.op_info.inputs().empty()) {
2518 return errors::InvalidArgument(
2519 "ResizeBilinear op has invalid input / output ",
2520 op_context.op_info.ShortDebugString());
2521 }
2522
2523 const int64 output_elements = CalculateTensorElementCount(
2524 op_context.op_info.outputs(0), &found_unknown_shapes);
2525
2526 const auto half_pixel_centers =
2527 op_context.op_info.attr().find("half_pixel_centers");
2528 bool use_half_pixel_centers = false;
2529 if (half_pixel_centers == op_context.op_info.attr().end()) {
2530 LOG(WARNING) << "half_pixel_centers attr not set for ResizeBilinear.";
2531 return PredictCostOfAnUnknownOp(op_context, node_costs);
2532 } else {
2533 use_half_pixel_centers = half_pixel_centers->second.b();
2534 }
2535
2536 // Compose cost of bilinear interpolation.
2537 int64 ops = 0;
2538
2539 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2540 const auto sub_cost_float = EIGEN_COST(scalar_difference_op<float>);
2541 const auto sub_cost_int = EIGEN_COST(scalar_difference_op<int64>);
2542 const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2543 const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2544 const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2545 const auto max_cost = EIGEN_COST(scalar_max_op<int64>);
2546 const auto min_cost = EIGEN_COST(scalar_min_op<int64>);
2547 const auto cast_to_int_cost = Eigen::internal::functor_traits<
2548 Eigen::internal::scalar_cast_op<float, int64>>::Cost;
2549 const auto cast_to_float_cost = Eigen::internal::functor_traits<
2550 Eigen::internal::scalar_cast_op<int64, float>>::Cost;
2551 const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2552 #undef EIGEN_COST
2553
2554 // Ops calculated from tensorflow/core/kernels/image/resize_bilinear_op.cc.
2555
2556 // Op counts taken from resize_bilinear implementation on 07/21/2020.
2557 // Computed op counts may become inaccurate if resize_bilinear implementation
2558 // changes.
2559
2560 // resize_bilinear has an optimization where the interpolation weights are
2561 // precomputed and cached. Given input tensors of size [B,H1,W1,C] and output
2562 // tensors of size [B,H2,W2,C], the input positions that need to be accessed
2563 // for interpolation are identical at every point along the last dimension C
2564 // of the output. These values are cached in the compute_interpolation_weights
2565 // function. For a particular y in [0...H2-1], the rows to be accessed in the
2566 // input are the same. Likewise, for a particular x in [0...W2-1], the columns
2567 // to be accessed are the same. So the precomputation only needs to be done
2568 // for H2 + W2 values.
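// For example, for an output of shape [B, 224, 224, C] the interpolation
// weights are precomputed for only 224 + 224 = 448 positions rather than for
// all 224 * 224 = 50,176 output pixels.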
2569 const auto output_shape = MaybeGetMinimumShape(
2570 op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2571 // Assume H is dim 1 and W is dim 2 to match logic in resize_bilinear, which
2572 // also makes this assumption.
2573 const int64 output_height = output_shape.dim(1).size();
2574 const int64 output_width = output_shape.dim(2).size();
2575 // Add the ops done outside of the scaler function in
2576 // compute_interpolation_weights.
2577 int64 interp_weight_cost = floor_cost + max_cost + min_cost + sub_cost_float +
2578 sub_cost_int + ceil_cost + cast_to_int_cost * 2;
2579 // There are two options for computing the weight of each pixel in the
2580 // interpolation. Algorithm can use pixel centers, or corners, for the
2581 // weight. Ops depend on the scaler function passed into
2582 // compute_interpolation_weights.
2583 if (use_half_pixel_centers) {
2584 // Ops for HalfPixelScaler.
2585 interp_weight_cost +=
2586 add_cost + mul_cost + sub_cost_float + cast_to_float_cost;
2587 } else {
2588 // Ops for LegacyScaler.
2589 interp_weight_cost += cast_to_float_cost + mul_cost;
2590 }
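// For reference, the per-coordinate scaler math these branch costs are meant
// to mirror (a sketch; the kernel may have changed since the date above):
//   HalfPixelScaler: in = (static_cast<float>(out) + 0.5f) * scale - 0.5f
//   LegacyScaler:    in = static_cast<float>(out) * scale
// i.e. a cast, add, mul, and sub per coordinate with half-pixel centers versus
// only a cast and mul on the legacy path.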
2591 // Cost for the interpolation is multiplied by (H2 + W2), as mentioned above.
2592 ops += interp_weight_cost * (output_height + output_width);
2593
2594 // Ops for computing the new values, done for every element. Logic is from
2595 // compute_lerp in the inner loop of resize_image which consists of:
2596 // const float top = top_left + (top_right - top_left) * x_lerp;
2597 // const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
2598 // return top + (bottom - top) * y_lerp;
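// Each such evaluation performs 3 additions, 3 subtractions, and 3
// multiplications, which is what the per-element count below charges.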
2599 ops += (add_cost * 3 + sub_cost_float * 3 + mul_cost * 3) * output_elements;
2600
2601 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2602 node_costs);
2603 }
2604
2605 Status OpLevelCostEstimator::PredictCropAndResize(const OpContext& op_context,
2606 NodeCosts* node_costs) const {
2607 bool found_unknown_shapes = false;
2608
2609 const auto method = op_context.op_info.attr().find("method");
2610 bool use_bilinear_interp;
2611 if (method == op_context.op_info.attr().end() ||
2612 method->second.s() == "bilinear") {
2613 use_bilinear_interp = true;
2614 } else if (method->second.s() == "nearest") {
2615 use_bilinear_interp = false;
2616 } else {
2617 LOG(WARNING) << "method attr in CropAndResize invalid; expected bilinear "
2618 "or nearest.";
2619 return PredictCostOfAnUnknownOp(op_context, node_costs);
2620 }
2621
2622 const int64 num_boxes = op_context.op_info.inputs(1).shape().dim(0).size();
2623 const auto crop_shape = MaybeGetMinimumShape(
2624 op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2625 const int64 crop_height = crop_shape.dim(1).size();
2626 const int64 crop_width = crop_shape.dim(2).size();
2627 const int64 output_elements = CalculateTensorElementCount(
2628 op_context.op_info.outputs(0), &found_unknown_shapes);
2629
2630 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2631 const auto sub_cost = EIGEN_COST(scalar_difference_op<float>);
2632 const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2633 const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2634 const auto div_cost = EIGEN_COST(scalar_div_cost<float>);
2635 const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2636 const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2637 const auto round_cost = EIGEN_COST(scalar_round_op<float>);
2638 const auto cast_to_float_cost = Eigen::internal::functor_traits<
2639 Eigen::internal::scalar_cast_op<int64, float>>::Cost;
2640 #undef EIGEN_COST
2641
2642 // Computing ops following
2643 // tensorflow/core/kernels/image/crop_and_resize_op.cc at 08/25/2020. Op
2644 // calculation here differs from the rough estimate in the kernel, as it
2645 // separates out the cost per box from the cost per pixel and per element.
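// Schematically, the total charged below is
//   ops = per_box_ops * num_boxes
//       + per_row_ops * num_boxes * crop_height
//       + per_pixel_ops * num_boxes * crop_height * crop_width
//       + per_element_ops * output_elements
// where output_elements already includes the depth dimension.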
2646
2647 // Ops for variables height_scale and width_scale.
2648 int64 ops = (sub_cost * 6 + mul_cost * 2 + div_cost * 2) * num_boxes;
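// This count is meant to match per-box scale computations of roughly the form
// (a sketch of the kernel; it may drift from the current implementation):
//   height_scale = (y2 - y1) * (image_height - 1) / (crop_height - 1)
//   width_scale  = (x2 - x1) * (image_width - 1) / (crop_width - 1)
// i.e. three subtractions, one multiplication, and one division per scale.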
2649 // Ops for variable in_y.
2650 ops += (mul_cost * 2 + sub_cost + add_cost) * crop_height * num_boxes;
2651 // Ops for variable in_x (same computation across both branches).
2652 ops += (mul_cost * 2 + sub_cost + add_cost) * crop_height * crop_width *
2653 num_boxes;
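// Both counts correspond to coordinate computations of roughly the form
//   in_y = y1 * (image_height - 1) + y * height_scale
// (and analogously for in_x): two multiplications, one subtraction, and one
// addition each. This is a sketch and may not match the kernel exactly.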
2654 // Specify op_cost based on the method.
2655 if (use_bilinear_interp) {
2656 // Ops for variables top_y_index, bottom_y_index, y_lerp.
2657 ops += (floor_cost + ceil_cost + sub_cost) * crop_height * num_boxes;
2658 // Ops for variables left_x, right_x, x_lerp;
2659 ops += (floor_cost + ceil_cost + sub_cost) * crop_height * crop_width *
2660 num_boxes;
2661 // Ops for innermost loop across depth.
2662 ops +=
2663 (cast_to_float_cost * 4 + add_cost * 3 + sub_cost * 3 + mul_cost * 3) *
2664 output_elements;
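// The four casts account for reading the four corner pixel values as floats;
// the remaining terms cover the two horizontal lerps and one vertical lerp
// (three additions, subtractions, and multiplications), a rough mapping that
// may drift from the kernel's bilinear branch.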
2665 } else /* method == "nearest" */ {
2666 // Ops for variables closest_x_index and closest_y_index.
2667 ops += round_cost * 2 * crop_height * crop_width * num_boxes;
2668 // Ops for innermost loop across depth.
2669 ops += cast_to_float_cost * output_elements;
2670 }
2671 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2672 node_costs);
2673 }
2674
2675 } // end namespace grappler
2676 } // end namespace tensorflow
2677