/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"

#include "absl/strings/match.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace grappler {

// TODO(dyoon): update op to Predict method map for TF ops with V2 or V3 suffix.
constexpr int kOpsPerMac = 2;
constexpr char kGuaranteeConst[] = "GuaranteeConst";
constexpr char kAddN[] = "AddN";
constexpr char kBitCast[] = "BitCast";
constexpr char kConcatV2[] = "ConcatV2";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
constexpr char kDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kDepthToSpace[] = "DepthToSpace";
constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
    "DepthwiseConv2dNativeBackpropFilter";
constexpr char kDepthwiseConv2dNativeBackpropInput[] =
    "DepthwiseConv2dNativeBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kXlaEinsum[] = "XlaEinsum";
constexpr char kEinsum[] = "Einsum";
constexpr char kExpandDims[] = "ExpandDims";
constexpr char kFill[] = "Fill";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
constexpr char kIdentity[] = "Identity";
constexpr char kIdentityN[] = "IdentityN";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
constexpr char kSplit[] = "Split";
constexpr char kSqueeze[] = "Squeeze";
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
constexpr char kBatchMatMulV2[] = "BatchMatMulV2";
constexpr char kOneHot[] = "OneHot";
constexpr char kPack[] = "Pack";
constexpr char kRank[] = "Rank";
constexpr char kRange[] = "Range";
constexpr char kShape[] = "Shape";
constexpr char kShapeN[] = "ShapeN";
constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather";
constexpr char kGatherNd[] = "GatherNd";
constexpr char kGatherV2[] = "GatherV2";
constexpr char kScatterAdd[] = "ScatterAdd";
constexpr char kScatterDiv[] = "ScatterDiv";
constexpr char kScatterMax[] = "ScatterMax";
constexpr char kScatterMin[] = "ScatterMin";
constexpr char kScatterMul[] = "ScatterMul";
constexpr char kScatterSub[] = "ScatterSub";
constexpr char kScatterUpdate[] = "ScatterUpdate";
constexpr char kSlice[] = "Slice";
constexpr char kStridedSlice[] = "StridedSlice";
constexpr char kSpaceToDepth[] = "SpaceToDepth";
constexpr char kTranspose[] = "Transpose";
constexpr char kTile[] = "Tile";
constexpr char kMaxPool[] = "MaxPool";
constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
constexpr char kAvgPool[] = "AvgPool";
constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
constexpr char kQuantizedMatMul[] = "QuantizedMatMul";
constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
constexpr char kUnpack[] = "Unpack";
constexpr char kSoftmax[] = "Softmax";
constexpr char kResizeBilinear[] = "ResizeBilinear";
constexpr char kCropAndResize[] = "CropAndResize";
// Dynamic control flow ops.
constexpr char kSwitch[] = "Switch";
constexpr char kMerge[] = "Merge";
constexpr char kEnter[] = "Enter";
constexpr char kExit[] = "Exit";
constexpr char kNextIteration[] = "NextIteration";
// Persistent ops.
constexpr char kConst[] = "Const";
constexpr char kVariable[] = "Variable";
constexpr char kVariableV2[] = "VariableV2";
constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
constexpr char kVarHandleOp[] = "VarHandleOp";
constexpr char kVarHandlesOp[] = "_VarHandlesOp";
constexpr char kReadVariableOp[] = "ReadVariableOp";
constexpr char kReadVariablesOp[] = "_ReadVariablesOp";
constexpr char kAssignVariableOp[] = "AssignVariableOp";
constexpr char kAssignAddVariableOp[] = "AssignAddVariableOp";
constexpr char kAssignSubVariableOp[] = "AssignSubVariableOp";

static const Costs::Duration kMinComputeTime(1);
static const int64_t kMinComputeOp = 1;

namespace {

std::string GetDataFormat(const OpInfo& op_info) {
  std::string data_format = "NHWC";  // Default format.
  if (op_info.attr().find("data_format") != op_info.attr().end()) {
    data_format = op_info.attr().at("data_format").s();
  }
  return data_format;
}

std::string GetFilterFormat(const OpInfo& op_info) {
  std::string filter_format = "HWIO";  // Default format.
  if (op_info.attr().find("filter_format") != op_info.attr().end()) {
    filter_format = op_info.attr().at("filter_format").s();
  }
  return filter_format;
}

Padding GetPadding(const OpInfo& op_info) {
  if (op_info.attr().find("padding") != op_info.attr().end() &&
      op_info.attr().at("padding").s() == "VALID") {
    return Padding::VALID;
  }
  return Padding::SAME;  // Default padding.
}

bool IsTraining(const OpInfo& op_info) {
  if (op_info.attr().find("is_training") != op_info.attr().end() &&
      op_info.attr().at("is_training").b()) {
    return true;
  }
  return false;
}

// TODO(dyoon): support non-4D tensors in the cost functions of convolution
// related ops (Conv, Pool, BatchNorm, and their backprops) and the related
// helper functions.
std::vector<int64> GetStrides(const OpInfo& op_info) {
  if (op_info.attr().find("strides") != op_info.attr().end()) {
    const auto strides = op_info.attr().at("strides").list().i();
    DCHECK(strides.size() == 4)
        << "Attr strides is not a length-4 vector: " << op_info.DebugString();
    if (strides.size() != 4) return {1, 1, 1, 1};
    return {strides[0], strides[1], strides[2], strides[3]};
  }
  return {1, 1, 1, 1};
}

std::vector<int64> GetKernelSize(const OpInfo& op_info) {
  if (op_info.attr().find("ksize") != op_info.attr().end()) {
    const auto ksize = op_info.attr().at("ksize").list().i();
    DCHECK(ksize.size() == 4)
        << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
    if (ksize.size() != 4) return {1, 1, 1, 1};
    return {ksize[0], ksize[1], ksize[2], ksize[3]};
  }
  // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
  // {1, 1, 1, 1} in that case.
  return {1, 1, 1, 1};
}

int64 GetOutputSize(const int64_t input, const int64_t filter,
                    const int64_t stride, const Padding& padding) {
  // Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
  // function in third_party/tensorflow/core/framework/common_shape_fns.cc.
  if (padding == Padding::VALID) {
    return (input - filter + stride) / stride;
  } else {  // SAME.
    return (input + stride - 1) / stride;
  }
}
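
// Example for GetOutputSize() above (illustrative): with input = 10,
// filter = 3, stride = 2, VALID padding yields (10 - 3 + 2) / 2 = 4 output
// elements, while SAME padding yields (10 + 2 - 1) / 2 = 5.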

// Return the output element count of a multi-input element-wise op considering
// broadcasting.
int64 CwiseOutputElementCount(const OpInfo& op_info) {
  int max_rank = 1;
  for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
    max_rank = std::max(max_rank, input_properties.shape().dim_size());
  }

  TensorShapeProto output_shape;
  output_shape.mutable_dim()->Reserve(max_rank);
  for (int i = 0; i < max_rank; ++i) {
    output_shape.add_dim();
  }

  // Expand the shape of the output to follow the numpy-style broadcast rule
  // which matches each input starting with the trailing dimensions and working
  // its way forward. To do this, iterate through each input shape's dimensions
  // in reverse order, and potentially increase the corresponding output
  // dimension.
  for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
    const TensorShapeProto& input_shape = input_properties.shape();
    for (int i = input_shape.dim_size() - 1; i >= 0; --i) {
      int output_shape_dim_index =
          i + output_shape.dim_size() - input_shape.dim_size();
      output_shape.mutable_dim(output_shape_dim_index)
          ->set_size(std::max(output_shape.dim(output_shape_dim_index).size(),
                              input_shape.dim(i).size()));
    }
  }

  int64_t count = 1;
  for (int i = 0; i < output_shape.dim_size(); i++) {
    count *= output_shape.dim(i).size();
  }
  return count;
}
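
// Example for CwiseOutputElementCount() above (illustrative): inputs with
// shapes [2, 3, 1] and [3, 4] broadcast to an output shape of [2, 3, 4], so
// the returned element count is 24.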

// Helper function for determining whether there are repeated indices in the
// input Einsum equation.
bool CheckRepeatedDimensions(const absl::string_view dim_str) {
  int str_size = dim_str.size();
  for (int idx = 0; idx < str_size - 1; idx++) {
    if (dim_str.find(dim_str[idx], idx + 1) != std::string::npos) {
      return true;
    }
  }
  return false;
}
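
// Example (illustrative): CheckRepeatedDimensions("iij") returns true because
// 'i' appears twice, while CheckRepeatedDimensions("ijk") returns false.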

// Auxiliary function for determining whether OpLevelCostEstimator is compatible
// with a given Einsum.
bool IsEinsumCorrectlyFormed(const OpContext& einsum_context) {
  const auto& op_info = einsum_context.op_info;

  auto it = op_info.attr().find("equation");
  if (it == op_info.attr().end()) return false;
  const absl::string_view equation = it->second.s();
  std::vector<std::string> equation_split = absl::StrSplit(equation, "->");

  if (equation_split.empty()) {
    LOG(WARNING) << "Einsum with malformed equation";
    return false;
  }
  std::vector<absl::string_view> input_split =
      absl::StrSplit(equation_split[0], ',');

  // The current model covers Einsum operations with two operands and an RHS.
  if (op_info.inputs_size() != 2 || equation_split.size() != 2) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
    return false;
  }
  const auto& a_input = op_info.inputs(0);
  const auto& b_input = op_info.inputs(1);
  absl::string_view rhs_str = equation_split[1];
  absl::string_view a_input_str = input_split[0];
  absl::string_view b_input_str = input_split[1];

  // Ellipses are not currently supported.
  if (absl::StrContains(a_input_str, "...") ||
      absl::StrContains(b_input_str, "...")) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", ellipsis not supported";
    return false;
  }

  constexpr int kMatrixRank = 2;

  bool a_input_shape_unknown = false;
  bool b_input_shape_unknown = false;

  TensorShapeProto a_input_shape = MaybeGetMinimumShape(
      a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
      &a_input_shape_unknown);
  TensorShapeProto b_input_shape = MaybeGetMinimumShape(
      b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
      &b_input_shape_unknown);

  if (a_input_str.size() != static_cast<size_t>(a_input_shape.dim_size()) ||
      b_input_str.size() != static_cast<size_t>(b_input_shape.dim_size())) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", equation subscripts don't match tensor rank.";
    return false;
  }

  // Subscripts where an axis appears more than once for a single input are not
  // yet supported.
  if (CheckRepeatedDimensions(a_input_str) ||
      CheckRepeatedDimensions(b_input_str) ||
      CheckRepeatedDimensions(rhs_str)) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", Subscripts where axis appears more than once for a single "
               "input are not yet supported";
    return false;
  }

  return true;
}
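
// Examples for IsEinsumCorrectlyFormed() above (illustrative): an Einsum with
// equation "ij,jk->ik" over two rank-2 inputs is accepted; equations with more
// than two operands, ellipses (e.g. "...ij,jk->...ik"), repeated subscripts
// within one operand (e.g. "iij,jk->ik"), or subscript counts that do not
// match the input ranks are rejected.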

}  // namespace

// Return a minimum shape if the shape is unknown. If known, return the original
// shape.
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
                                      int rank, bool* found_unknown_shapes) {
  auto shape = original_shape;
  bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;

  if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
    *found_unknown_shapes = true;
    VLOG(2) << "Use minimum shape because the rank is unknown.";
    // The size of each dimension is at least 1, if unknown.
    for (int i = shape.dim_size(); i < rank; i++) {
      shape.add_dim()->set_size(1);
    }
  } else if (is_scalar) {
    for (int i = 0; i < rank; i++) {
      shape.add_dim()->set_size(1);
    }
  } else if (shape.dim_size() > rank) {
    *found_unknown_shapes = true;
    shape.clear_dim();
    for (int i = 0; i < rank; i++) {
      shape.add_dim()->set_size(original_shape.dim(i).size());
    }
  } else {
    for (int i = 0; i < shape.dim_size(); i++) {
      if (shape.dim(i).size() < 0) {
        *found_unknown_shapes = true;
        VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
        // The size of each dimension is at least 1, if unknown.
        shape.mutable_dim(i)->set_size(1);
      }
    }
  }
  return shape;
}
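
// Examples for MaybeGetMinimumShape() above (illustrative): a shape with
// unknown rank requested at rank 4 becomes [1, 1, 1, 1]; [5, -1, 3] requested
// at rank 3 becomes [5, 1, 3]; and [2, 3, 4, 5] requested at rank 2 is
// truncated to [2, 3]. All three cases set *found_unknown_shapes, whereas a
// fully known shape of the requested rank is returned unchanged.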

OpLevelCostEstimator::OpLevelCostEstimator() {
  // Syntactic sugar to build and return a lambda that takes an OpInfo and
  // returns a cost.
  typedef Status (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context,
                                                   NodeCosts*) const;
  auto wrap = [this](CostImpl impl)
      -> std::function<Status(const OpContext&, NodeCosts*)> {
    return [this, impl](const OpContext& op_context, NodeCosts* node_costs) {
      return (this->*impl)(op_context, node_costs);
    };
  };

  device_cost_impl_.emplace(kConv2d,
                            wrap(&OpLevelCostEstimator::PredictConv2D));
  device_cost_impl_.emplace(
      kConv2dBackpropFilter,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
  device_cost_impl_.emplace(
      kConv2dBackpropInput,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
  device_cost_impl_.emplace(
      kFusedConv2dBiasActivation,
      wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation));
  // Reuse Conv2D for DepthwiseConv2dNative because the calculation is the
  // same, although the actual meanings of the parameters are different. See
  // comments in PredictConv2D and related functions.
  device_cost_impl_.emplace(kDepthwiseConv2dNative,
                            wrap(&OpLevelCostEstimator::PredictConv2D));
  device_cost_impl_.emplace(
      kDepthwiseConv2dNativeBackpropFilter,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
  device_cost_impl_.emplace(
      kDepthwiseConv2dNativeBackpropInput,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
  device_cost_impl_.emplace(kMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kSparseMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(
      kSparseTensorDenseMatMul,
      wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul));
  device_cost_impl_.emplace(kBatchMatMul,
                            wrap(&OpLevelCostEstimator::PredictBatchMatMul));
  device_cost_impl_.emplace(kBatchMatMulV2,
                            wrap(&OpLevelCostEstimator::PredictBatchMatMul));
  device_cost_impl_.emplace(kQuantizedMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kQuantizedMatMulV2,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kXlaEinsum,
                            wrap(&OpLevelCostEstimator::PredictEinsum));
  device_cost_impl_.emplace(kEinsum,
                            wrap(&OpLevelCostEstimator::PredictEinsum));

  device_cost_impl_.emplace(kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp));
  device_cost_impl_.emplace(kGuaranteeConst,
                            wrap(&OpLevelCostEstimator::PredictNoOp));

  device_cost_impl_.emplace(kGather,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kGatherNd,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kGatherV2,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kScatterAdd,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterDiv,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMax,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMin,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMul,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterSub,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterUpdate,
                            wrap(&OpLevelCostEstimator::PredictScatter));

  device_cost_impl_.emplace(kSlice,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kStridedSlice,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));

  device_cost_impl_.emplace(kPlaceholder,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kIdentity,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kIdentityN,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kRefIdentity,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kStopGradient,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kPreventGradient,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kReshape,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kRecv,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kSend,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kSwitch,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kMerge,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kEnter,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kExit,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kNextIteration,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kBitCast,
                            wrap(&OpLevelCostEstimator::PredictIdentity));

  device_cost_impl_.emplace(kConcatV2,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kDataFormatVecPermute,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kDepthToSpace,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kExpandDims,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kFill,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kOneHot,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kPack,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kRange,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSpaceToDepth,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSplit,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSqueeze,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kTranspose,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kTile,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kUnpack,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));

  device_cost_impl_.emplace(kRank,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kShape,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kShapeN,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kSize,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kMaxPool,
                            wrap(&OpLevelCostEstimator::PredictMaxPool));
  device_cost_impl_.emplace(kMaxPoolGrad,
                            wrap(&OpLevelCostEstimator::PredictMaxPoolGrad));
  device_cost_impl_.emplace(kAvgPool,
                            wrap(&OpLevelCostEstimator::PredictAvgPool));
  device_cost_impl_.emplace(kAvgPoolGrad,
                            wrap(&OpLevelCostEstimator::PredictAvgPoolGrad));
  device_cost_impl_.emplace(kFusedBatchNorm,
                            wrap(&OpLevelCostEstimator::PredictFusedBatchNorm));
  device_cost_impl_.emplace(
      kFusedBatchNormGrad,
      wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad));
  device_cost_impl_.emplace(kSoftmax,
                            wrap(&OpLevelCostEstimator::PredictSoftmax));
  device_cost_impl_.emplace(kResizeBilinear,
                            wrap(&OpLevelCostEstimator::PredictResizeBilinear));
  device_cost_impl_.emplace(kCropAndResize,
                            wrap(&OpLevelCostEstimator::PredictCropAndResize));
  device_cost_impl_.emplace(
      kAssignVariableOp, wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(
      kAssignAddVariableOp,
      wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(
      kAssignSubVariableOp,
      wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(kAddN, wrap(&OpLevelCostEstimator::PredictNaryOp));

  persistent_ops_ = {
      kConst,       kVariable,       kVariableV2,   kAutoReloadVariable,
      kVarHandleOp, kReadVariableOp, kVarHandlesOp, kReadVariablesOp};

#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost

  // Quantize = apply min and max bounds, multiply by scale factor and round.
  const int quantize_v2_cost =
      EIGEN_COST(scalar_product_op<float>) + EIGEN_COST(scalar_max_op<float>) +
      EIGEN_COST(scalar_min_op<float>) + EIGEN_COST(scalar_round_op<float>);
  const int quantize_and_dequantize_v2_cost =
      quantize_v2_cost + EIGEN_COST(scalar_product_op<float>);

  // Unary ops alphabetically sorted
  elementwise_ops_.emplace("Acos", EIGEN_COST(scalar_acos_op<float>));
  elementwise_ops_.emplace("All", EIGEN_COST(scalar_boolean_and_op));
  elementwise_ops_.emplace("ArgMax", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Asin", EIGEN_COST(scalar_asin_op<float>));
  elementwise_ops_.emplace("Atan", EIGEN_COST(scalar_atan_op<float>));
  elementwise_ops_.emplace("Atan2", EIGEN_COST(scalar_quotient_op<float>) +
                                        EIGEN_COST(scalar_atan_op<float>));
  // For now, we use the Eigen cost model for a float-to-int16 cast as an
  // example case; the Eigen cost model is zero when src and dst types are
  // identical, and it uses AddCost (1) when they differ. We may implement
  // separate cost functions for cast ops, using the actual input and output
  // types.
  elementwise_ops_.emplace(
      "Cast", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_cast_op<float, int16>>::Cost);
  elementwise_ops_.emplace("Ceil", EIGEN_COST(scalar_ceil_op<float>));
  elementwise_ops_.emplace("Cos", EIGEN_COST(scalar_cos_op<float>));
  elementwise_ops_.emplace("Dequantize", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("Erf", 1);
  elementwise_ops_.emplace("Erfc", 1);
  elementwise_ops_.emplace("Exp", EIGEN_COST(scalar_exp_op<float>));
  elementwise_ops_.emplace("Expm1", EIGEN_COST(scalar_expm1_op<float>));
  elementwise_ops_.emplace("Floor", EIGEN_COST(scalar_floor_op<float>));
  elementwise_ops_.emplace("Inv", EIGEN_COST(scalar_inverse_op<float>));
  elementwise_ops_.emplace("InvGrad", 1);
  elementwise_ops_.emplace("Lgamma", 1);
  elementwise_ops_.emplace("Log", EIGEN_COST(scalar_log_op<float>));
  elementwise_ops_.emplace("Log1p", EIGEN_COST(scalar_log1p_op<float>));
  elementwise_ops_.emplace("Max", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Min", EIGEN_COST(scalar_min_op<float>));
  elementwise_ops_.emplace("Neg", EIGEN_COST(scalar_opposite_op<float>));
  elementwise_ops_.emplace("Prod", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("QuantizeAndDequantizeV2",
                           quantize_and_dequantize_v2_cost);
  elementwise_ops_.emplace("QuantizeAndDequantizeV4",
                           quantize_and_dequantize_v2_cost);
  elementwise_ops_.emplace("QuantizedSigmoid",
                           EIGEN_COST(scalar_logistic_op<float>));
  elementwise_ops_.emplace("QuantizeV2", quantize_v2_cost);
  elementwise_ops_.emplace("Reciprocal", EIGEN_COST(scalar_inverse_op<float>));
  elementwise_ops_.emplace("Relu", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Relu6", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Rint", 1);
  elementwise_ops_.emplace("Round", EIGEN_COST(scalar_round_op<float>));
  elementwise_ops_.emplace("Rsqrt", EIGEN_COST(scalar_rsqrt_op<float>));
  elementwise_ops_.emplace("Sigmoid", EIGEN_COST(scalar_logistic_op<float>));
  elementwise_ops_.emplace("Sign", EIGEN_COST(scalar_sign_op<float>));
  elementwise_ops_.emplace("Sin", EIGEN_COST(scalar_sin_op<float>));
  elementwise_ops_.emplace("Sqrt", EIGEN_COST(scalar_sqrt_op<float>));
  elementwise_ops_.emplace("Square", EIGEN_COST(scalar_square_op<float>));
  elementwise_ops_.emplace("Sum", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("Tan", EIGEN_COST(scalar_tan_op<float>));
  elementwise_ops_.emplace("Tanh", EIGEN_COST(scalar_tanh_op<float>));
  elementwise_ops_.emplace("TopKV2", EIGEN_COST(scalar_max_op<float>));
  // Binary ops alphabetically sorted
  elementwise_ops_.emplace("Add", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("AddV2", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("ApproximateEqual", 1);
  elementwise_ops_.emplace("BiasAdd", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("QuantizedBiasAdd",
                           EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("Div", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("Equal", 1);
  elementwise_ops_.emplace("FloorDiv", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("FloorMod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Greater", 1);
  elementwise_ops_.emplace("GreaterEqual", 1);
  elementwise_ops_.emplace("Less", 1);
  elementwise_ops_.emplace("LessEqual", 1);
  elementwise_ops_.emplace("LogicalAnd", EIGEN_COST(scalar_boolean_and_op));
  elementwise_ops_.emplace("LogicalNot", 1);
  elementwise_ops_.emplace("LogicalOr", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("Maximum", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Minimum", EIGEN_COST(scalar_min_op<float>));
  elementwise_ops_.emplace("Mod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Mul", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("NotEqual", 1);
  elementwise_ops_.emplace("QuantizedAdd", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("QuantizedMul",
                           EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("RealDiv", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("ReluGrad", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Select", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("SelectV2", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("SquaredDifference",
                           EIGEN_COST(scalar_square_op<float>) +
                               EIGEN_COST(scalar_difference_op<float>));
  elementwise_ops_.emplace("Sub", EIGEN_COST(scalar_difference_op<float>));
  elementwise_ops_.emplace("TruncateDiv",
                           EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("TruncateMod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Where", 1);

#undef EIGEN_COST

  // By default, use sum of memory_time and compute_time for execution_time.
  compute_memory_overlap_ = false;
}

Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
  Costs costs;
  NodeCosts node_costs;
  if (PredictNodeCosts(op_context, &node_costs).ok()) {
    if (node_costs.has_costs) {
      return node_costs.costs;
    }
    // Convert NodeCosts to Costs.
    if (node_costs.minimum_cost_op) {
      // Override to minimum cost; note that some ops with minimum cost may have
      // a non-typical device (e.g., channel for _Send), which may fail in
      // GetDeviceInfo(), called from PredictOpCountBasedCost(). Make sure we
      // directly set minimum values to Costs here, not calling
      // PredictOpCountBasedCost().
      costs.compute_time = kMinComputeTime;
      costs.execution_time = kMinComputeTime;
      costs.memory_time = 0;
      costs.intermediate_memory_time = 0;
      costs.intermediate_memory_read_time = 0;
      costs.intermediate_memory_write_time = 0;
    } else {
      // Convert NodeCosts to Costs.
      costs = PredictOpCountBasedCost(
          node_costs.num_compute_ops, node_costs.num_total_read_bytes(),
          node_costs.num_total_write_bytes(), op_context.op_info);
    }
    VLOG(1) << "Operation " << op_context.op_info.op() << " takes "
            << costs.execution_time.count() << " ns.";
    // Copy additional stats from NodeCosts to Costs.
    costs.max_memory = node_costs.max_memory;
    costs.persistent_memory = node_costs.persistent_memory;
    costs.temporary_memory = node_costs.temporary_memory;
    costs.inaccurate = node_costs.inaccurate;
    costs.num_ops_with_unknown_shapes =
        node_costs.num_nodes_with_unknown_shapes;
    costs.num_ops_total = node_costs.num_nodes;
    return costs;
  }
  // Errors during node cost estimate.
  LOG(WARNING) << "Error in PredictCost() for the op: "
               << op_context.op_info.ShortDebugString();
  costs = Costs::ZeroCosts(/*inaccurate=*/true);
  costs.num_ops_with_unknown_shapes = node_costs.num_nodes_with_unknown_shapes;
  return costs;
}

Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context,
                                              NodeCosts* node_costs) const {
  const auto& op_info = op_context.op_info;
  auto it = device_cost_impl_.find(op_info.op());
  if (it != device_cost_impl_.end()) {
    std::function<Status(const OpContext&, NodeCosts*)> estimator = it->second;
    return estimator(op_context, node_costs);
  }

  if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
    return PredictVariable(op_context, node_costs);
  }

  if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
    return PredictCwiseOp(op_context, node_costs);
  }

  VLOG(1) << "Missing accurate estimator for op: " << op_info.op();

  node_costs->num_nodes_with_unknown_op_type = 1;
  return PredictCostOfAnUnknownOp(op_context, node_costs);
}

// This method assumes a typical system composed of CPUs and GPUs, connected
// through PCIe. To define device info more precisely, override this method.
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
    const DeviceProperties& device) const {
  double gflops = -1;
  double gb_per_sec = -1;

  if (device.type() == "CPU") {
    // Check if vector instructions are available, and refine performance
    // prediction based on this.
    // Frequencies are stored in MHz in the DeviceProperties.
    gflops = device.num_cores() * device.frequency() * 1e-3;
    if (gb_per_sec < 0) {
      if (device.bandwidth() > 0) {
        gb_per_sec = device.bandwidth() / 1e6;
      } else {
        gb_per_sec = 32;
      }
    }
  } else if (device.type() == "GPU") {
    const auto& device_env = device.environment();
    auto it = device_env.find("architecture");
    if (it != device_env.end()) {
      const std::string architecture = device_env.at("architecture");
      int cores_per_multiprocessor;
      if (architecture < "3") {
        // Fermi
        cores_per_multiprocessor = 32;
      } else if (architecture < "4") {
        // Kepler
        cores_per_multiprocessor = 192;
      } else if (architecture < "6") {
        // Maxwell
        cores_per_multiprocessor = 128;
      } else {
        // Pascal (compute capability version 6) and Volta (compute capability
        // version 7)
        cores_per_multiprocessor = 64;
      }
      gflops = device.num_cores() * device.frequency() * 1e-3 *
               cores_per_multiprocessor * kOpsPerMac;
      if (device.bandwidth() > 0) {
        gb_per_sec = device.bandwidth() / 1e6;
      } else {
        gb_per_sec = 100;
      }
    } else {
      // Architecture is not available (e.g., a pluggable device); return
      // default values.
      gflops = 100;     // Dummy value.
      gb_per_sec = 12;  // Default PCIe x16 gen3.
    }
  } else {
    LOG_EVERY_N(WARNING, 1000) << "Unknown device type: " << device.type()
                               << ", assuming PCIe between CPU and GPU.";
    gflops = 1;  // Dummy value; data transfer ops would not have compute ops.
    gb_per_sec = 12;  // Default PCIe x16 gen3.
  }
  VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
          << " gb_per_sec: " << gb_per_sec;

  return DeviceInfo(gflops, gb_per_sec);
}
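
// Example for GetDeviceInfo() above (illustrative): a CPU reporting 8 cores at
// 2000 MHz yields gflops = 8 * 2000 * 1e-3 = 16; a Pascal GPU reporting 56 SMs
// at 1480 MHz yields gflops = 56 * 1480 * 1e-3 * 64 * 2 ~= 10609, on the order
// of a P100's FP32 peak.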

Status OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context,
                                            NodeCosts* node_costs) const {
  const auto& op_info = op_context.op_info;
  bool found_unknown_shapes = false;
  // For element-wise operations, op count is the element count of any input.
  // We use the count of the largest input here to be more robust in case the
  // shape is unknown or partially known for the other inputs.
  int64_t op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
  // If output shape is available, try to use the element count calculated from
  // that.
  if (op_info.outputs_size() > 0) {
    op_count = std::max(
        op_count,
        CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
  }
  // Calculate the output shape possibly resulting from broadcasting.
  if (op_info.inputs_size() >= 2) {
    op_count = std::max(op_count, CwiseOutputElementCount(op_info));
  }

  int op_cost = 1;
  auto it = elementwise_ops_.find(op_info.op());
  if (it != elementwise_ops_.end()) {
    op_cost = it->second;
  } else {
    return errors::InvalidArgument("Not a cwise op: ", op_info.op());
  }

  return PredictDefaultNodeCosts(op_count * op_cost, op_context,
                                 &found_unknown_shapes, node_costs);
}

Status OpLevelCostEstimator::PredictCostOfAnUnknownOp(
    const OpContext& op_context, NodeCosts* node_costs) const {
  // Don't assume the operation is cwise; return a cost based on input/output
  // size and admit that it is inaccurate.
  bool found_unknown_shapes = false;
  node_costs->inaccurate = true;
  return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
                                 node_costs);
}

Costs OpLevelCostEstimator::PredictOpCountBasedCost(
    double operations, const OpInfo& op_info) const {
  bool unknown_shapes = false;
  const double input_size = CalculateInputSize(op_info, &unknown_shapes);
  const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
  Costs costs =
      PredictOpCountBasedCost(operations, input_size, output_size, op_info);
  costs.inaccurate = unknown_shapes;
  costs.num_ops_with_unknown_shapes = unknown_shapes;
  costs.max_memory = output_size;
  return costs;
}

Costs OpLevelCostEstimator::PredictOpCountBasedCost(
    double operations, double input_io_bytes, double output_io_bytes,
    const OpInfo& op_info) const {
  double total_io_bytes = input_io_bytes + output_io_bytes;
  const DeviceInfo device_info = GetDeviceInfo(op_info.device());
  if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
      device_info.intermediate_read_gb_per_sec <= 0 ||
      device_info.intermediate_write_gb_per_sec <= 0) {
    VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
            << " device type:" << op_info.device().type()
            << " device model:" << op_info.device().model();
  }

  Costs::NanoSeconds compute_cost(std::ceil(operations / device_info.gigaops));
  VLOG(1) << "Op:" << op_info.op() << " GOps:" << operations / 1e9
          << " Compute Time (ns):" << compute_cost.count();

  Costs::NanoSeconds memory_cost(
      std::ceil(total_io_bytes / device_info.gb_per_sec));
  VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
          << " Memory Time (ns):" << memory_cost.count();

  // Check if bytes > 0. If not, and the bandwidth is set to infinity,
  // the result would be undefined.
  double intermediate_read_time =
      (input_io_bytes > 0)
          ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
          : 0;

  double intermediate_write_time =
      (output_io_bytes > 0)
          ? std::ceil(output_io_bytes /
                      device_info.intermediate_write_gb_per_sec)
          : 0;

  Costs::NanoSeconds intermediate_memory_cost =
      compute_memory_overlap_
          ? std::max(intermediate_read_time, intermediate_write_time)
          : (intermediate_read_time + intermediate_write_time);
  VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
          << " Intermediate Memory Time (ns):"
          << intermediate_memory_cost.count();

  Costs costs = Costs::ZeroCosts();
  costs.compute_time = compute_cost;
  costs.memory_time = memory_cost;
  costs.intermediate_memory_time = intermediate_memory_cost;
  costs.intermediate_memory_read_time =
      Costs::NanoSeconds(intermediate_read_time);
  costs.intermediate_memory_write_time =
      Costs::NanoSeconds(intermediate_write_time);
  CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &costs);
  return costs;
}
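
// Example for PredictOpCountBasedCost() above (illustrative): on a device
// reporting 100 gigaops and 32 GB/s, an op with 2e9 operations and 4e6 total
// I/O bytes gets compute_time = ceil(2e9 / 100) = 2e7 ns (20 ms) and
// memory_time = ceil(4e6 / 32) = 125000 ns (125 us).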

int64 OpLevelCostEstimator::CountConv2DOperations(const OpInfo& op_info,
                                                  bool* found_unknown_shapes) {
  return CountConv2DOperations(op_info, nullptr, found_unknown_shapes);
}

// Helper to translate the positional arguments into named fields.
/* static */
OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
    const TensorShapeProto& original_image_shape,
    const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
    bool* found_unknown_shapes) {
  VLOG(2) << "op features: " << op_info.DebugString();
  VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
  VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();

  int x_index, y_index, major_channel_index, minor_channel_index = -1;
  const std::string& data_format = GetDataFormat(op_info);
  if (data_format == "NCHW") {
    major_channel_index = 1;
    y_index = 2;
    x_index = 3;
  } else if (data_format == "NCHW_VECT_C") {
    // Use NCHW_VECT_C.
    minor_channel_index = 1;
    y_index = 2;
    x_index = 3;
    major_channel_index = 4;
  } else {
    // Use NHWC.
    y_index = 1;
    x_index = 2;
    major_channel_index = 3;
  }
  const std::string& filter_format = GetFilterFormat(op_info);
  int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index,
      in_minor_channel_index = -1;
  if (filter_format == "HWIO") {
    filter_y_index = 0;
    filter_x_index = 1;
    in_major_channel_index = 2;
    out_channel_index = 3;
  } else if (filter_format == "OIHW_VECT_I") {
    out_channel_index = 0;
    in_minor_channel_index = 1;
    filter_y_index = 2;
    filter_x_index = 3;
    in_major_channel_index = 4;
  } else {
    // Use OIHW.
    out_channel_index = 0;
    in_major_channel_index = 1;
    filter_y_index = 2;
    filter_x_index = 3;
  }

  auto image_shape = MaybeGetMinimumShape(original_image_shape,
                                          minor_channel_index >= 0 ? 5 : 4,
                                          found_unknown_shapes);
  auto filter_shape = MaybeGetMinimumShape(original_filter_shape,
                                           in_minor_channel_index >= 0 ? 5 : 4,
                                           found_unknown_shapes);
  VLOG(2) << "Image shape: " << image_shape.DebugString();
  VLOG(2) << "Filter shape: " << filter_shape.DebugString();

  int64_t batch = image_shape.dim(0).size();
  int64_t ix = image_shape.dim(x_index).size();
  int64_t iy = image_shape.dim(y_index).size();
  int64_t iz = minor_channel_index >= 0
                   ? image_shape.dim(minor_channel_index).size() *
                         image_shape.dim(major_channel_index).size()
                   : image_shape.dim(major_channel_index).size();
  int64_t kx = filter_shape.dim(filter_x_index).size();
  int64_t ky = filter_shape.dim(filter_y_index).size();
  int64_t kz = in_minor_channel_index >= 0
                   ? filter_shape.dim(in_major_channel_index).size() *
                         filter_shape.dim(in_minor_channel_index).size()
                   : filter_shape.dim(in_major_channel_index).size();
  std::vector<int64> strides = GetStrides(op_info);
  const auto padding = GetPadding(op_info);
  int64_t sx = strides[x_index];
  int64_t sy = strides[y_index];
  int64_t ox = GetOutputSize(ix, kx, sx, padding);
  int64_t oy = GetOutputSize(iy, ky, sy, padding);
  int64_t oz = filter_shape.dim(out_channel_index).size();
  // Only check equality when both sizes are known (in other words, when
  // neither is set to a minimum dimension size of 1).
  if (iz != 1 && kz != 1) {
    DCHECK_EQ(iz % kz, 0) << "Input channel " << iz
                          << " is not a multiple of filter channel " << kz
                          << ".";
    if (iz % kz) {
      *found_unknown_shapes = true;
    }
  } else {
    iz = kz = std::max<int64>(iz, kz);
  }
  OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
      batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};

  VLOG(1) << "Batch Size:" << batch;
  VLOG(1) << "Image Dims:" << ix << "," << iy;
  VLOG(1) << "Input Depth:" << iz;
  VLOG(1) << "Kernel Dims:" << kx << "," << ky;
  VLOG(1) << "Kernel Depth:" << kz;
  VLOG(1) << "Output Dims:" << ox << "," << oy;
  VLOG(1) << "Output Depth:" << oz;
  VLOG(1) << "Strides:" << sx << "," << sy;
  VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
  return conv_dims;
}

int64 OpLevelCostEstimator::CountConv2DOperations(
    const OpInfo& op_info, ConvolutionDimensions* conv_info,
    bool* found_unknown_shapes) {
  DCHECK(op_info.op() == kConv2d || op_info.op() == kDepthwiseConv2dNative)
      << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";

  if (op_info.inputs_size() < 2) {  // Unexpected inputs.
    *found_unknown_shapes = true;
    return 0;
  }

  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
      op_info.inputs(0).shape(), op_info.inputs(1).shape(), op_info,
      found_unknown_shapes);

  // In DepthwiseConv2dNative, conv_dims.oz is actually the channel depth
  // multiplier; the effective output channel depth oz_effective is
  // conv_dims.iz * conv_dims.oz, thus # ops = N x H x W x oz_effective x 2RS.
  // Compare to Conv2D, where # ops = N x H x W x kz x oz x 2RS; with
  // oz = oz_effective, Conv2D_ops / Depthwise_conv2d_native_ops = kz.
  int64_t ops = conv_dims.batch;
  ops *= conv_dims.ox * conv_dims.oy;
  ops *= conv_dims.kx * conv_dims.ky;
  if (op_info.op() == kConv2d) {
    ops *= conv_dims.kz * conv_dims.oz;
  } else {
    // Ensure the output tensor dims are correct for DepthwiseConv2dNative,
    // even though the op count is computed as for Conv2D.
    conv_dims.oz *= conv_dims.iz;
    ops *= conv_dims.oz;
  }
  ops *= kOpsPerMac;

  if (conv_info != nullptr) {
    *conv_info = conv_dims;
  }
  return ops;
}
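
// Example for CountConv2DOperations() above (illustrative): a Conv2D with
// batch 32, 112x112 output, a 3x3 kernel, 3 input channels, and 64 output
// channels counts 32 * 112 * 112 * 3 * 3 * 3 * 64 * 2 ~= 1.39e9 ops (the
// trailing factor of 2 is kOpsPerMac).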

int64 OpLevelCostEstimator::CountMatMulOperations(const OpInfo& op_info,
                                                  bool* found_unknown_shapes) {
  return CountMatMulOperations(op_info, nullptr, found_unknown_shapes);
}

// TODO(nishantpatil): Create separate estimator for Sparse Matmul
int64 OpLevelCostEstimator::CountMatMulOperations(const OpInfo& op_info,
                                                  MatMulDimensions* mat_mul,
                                                  bool* found_unknown_shapes) {
  double ops = 0;

  if (op_info.inputs_size() < 2) {
    LOG(ERROR) << "Need 2 inputs but got " << op_info.inputs_size();
    // TODO(pcma): Try to separate invalid inputs from unknown shapes
    *found_unknown_shapes = true;
    return 0;
  }

  auto& a_matrix = op_info.inputs(0);
  auto& b_matrix = op_info.inputs(1);

  bool transpose_a = false;
  bool transpose_b = false;

  double m_dim, n_dim, k_dim, k_dim_b = 0;

  for (const auto& item : op_info.attr()) {
    VLOG(1) << "Key:" << item.first
            << " Value:" << SummarizeAttrValue(item.second);
    if (item.first == "transpose_a" && item.second.b() == true)
      transpose_a = true;
    if (item.first == "transpose_b" && item.second.b() == true)
      transpose_b = true;
  }
  VLOG(1) << "transpose_a:" << transpose_a;
  VLOG(1) << "transpose_b:" << transpose_b;
  auto a_matrix_shape =
      MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
  auto b_matrix_shape =
      MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
  if (transpose_a) {
    m_dim = a_matrix_shape.dim(1).size();
    k_dim = a_matrix_shape.dim(0).size();
  } else {
    m_dim = a_matrix_shape.dim(0).size();
    k_dim = a_matrix_shape.dim(1).size();
  }
  if (transpose_b) {
    k_dim_b = b_matrix_shape.dim(1).size();
    n_dim = b_matrix_shape.dim(0).size();
  } else {
    k_dim_b = b_matrix_shape.dim(0).size();
    n_dim = b_matrix_shape.dim(1).size();
  }

  VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
  // Only check equality when both sizes are known (in other words, when
  // neither is set to a minimum dimension size of 1).
  if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
    LOG(ERROR) << "Incompatible Matrix dimensions";
    return ops;
  } else {
    // One of k_dim and k_dim_b might be 1 (minimum dimension size).
    k_dim = std::max(k_dim, k_dim_b);
  }

  ops = m_dim * n_dim * k_dim * 2;
  VLOG(1) << "Operations for Matmul: " << ops;

  if (mat_mul != nullptr) {
    mat_mul->m = m_dim;
    mat_mul->n = n_dim;
    mat_mul->k = k_dim;
  }
  return ops;
}
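
// Example for CountMatMulOperations() above (illustrative): multiplying a
// [128, 256] matrix by a [256, 512] matrix counts 2 * 128 * 256 * 512 =
// 33,554,432 ops (one multiply and one add per MAC).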

bool OpLevelCostEstimator::GenerateBatchMatmulContextFromEinsum(
    const OpContext& einsum_context, OpContext* batch_matmul_context,
    bool* found_unknown_shapes) const {
  // This auxiliary function transforms an einsum OpContext into its equivalent
  // Batch Matmul OpContext. The function returns a boolean, which determines
  // whether it was successful in generating the output OpContext or not.

  // Einsum computes a generalized contraction between tensors of arbitrary
  // dimension as defined by the equation written in the Einstein summation
  // convention. The number of tensors in the computation and the number of
  // contractions can be arbitrarily large. The current model only contemplates
  // Einsum equations which can be translated into a single BatchMatMul
  // operation. Einsum operations with more than two operands are not currently
  // supported. Subscripts where an axis appears more than once for a single
  // input and ellipses are currently also excluded. See:
  // https://www.tensorflow.org/api_docs/python/tf/einsum
  // We distinguish four kinds of dimensions, depending on their placement in
  // the equation:
  // + B: Batch dimensions: Dimensions which appear in both operands and RHS.
  // + K: Contracting dimensions: These appear in both inputs but not RHS.
  // + M: Operand A dimensions: These appear in the first operand and the RHS.
  // + N: Operand B dimensions: These appear in the second operand and the RHS.
  // Then, the operation to estimate is BatchMatMul([B,M,K],[B,K,N]).

  if (batch_matmul_context == nullptr) {
    VLOG(1) << "Output context should not be a nullptr.";
    return false;
  }
  if (!IsEinsumCorrectlyFormed(einsum_context)) return false;
  const auto& op_info = einsum_context.op_info;
  std::vector<std::string> equation_split =
      absl::StrSplit(op_info.attr().find("equation")->second.s(), "->");
  std::vector<absl::string_view> input_split =
      absl::StrSplit(equation_split[0], ',');
  const auto& a_input = op_info.inputs(0);
  const auto& b_input = op_info.inputs(1);
  absl::string_view rhs_str = equation_split[1];
  absl::string_view a_input_str = input_split[0];
  absl::string_view b_input_str = input_split[1];

  constexpr int kMatrixRank = 2;

  bool a_input_shape_unknown = false;
  bool b_input_shape_unknown = false;

  TensorShapeProto a_input_shape = MaybeGetMinimumShape(
      a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
      &a_input_shape_unknown);
  TensorShapeProto b_input_shape = MaybeGetMinimumShape(
      b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
      &b_input_shape_unknown);

  *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
                          (a_input.shape().dim_size() < kMatrixRank) ||
                          (b_input.shape().dim_size() < kMatrixRank);

  OpInfo batch_matmul_op_info = op_info;
  batch_matmul_op_info.mutable_inputs()->Clear();
  batch_matmul_op_info.set_op("BatchMatMul");

  AttrValue transpose_attribute;
  transpose_attribute.set_b(false);
  (*batch_matmul_op_info.mutable_attr())["transpose_a"] = transpose_attribute;
  (*batch_matmul_op_info.mutable_attr())["transpose_b"] = transpose_attribute;

  OpInfo::TensorProperties* a_matrix = batch_matmul_op_info.add_inputs();
  TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
  a_matrix->set_dtype(a_input.dtype());

  OpInfo::TensorProperties* b_matrix = batch_matmul_op_info.add_inputs();
  b_matrix->set_dtype(b_input.dtype());
  TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();

  TensorShapeProto_Dim m_dim;
  TensorShapeProto_Dim n_dim;
  TensorShapeProto_Dim k_dim;

  m_dim.set_size(1);
  n_dim.set_size(1);
  k_dim.set_size(1);

  for (int i_idx = 0, a_input_str_size = a_input_str.size();
       i_idx < a_input_str_size; ++i_idx) {
    if (b_input_str.find(a_input_str[i_idx]) == std::string::npos) {
      if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
        VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
        return false;
      }

      m_dim.set_size(m_dim.size() * a_input_shape.dim(i_idx).size());
      continue;
    } else if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
      // The dimension does not appear in the RHS, therefore it is a contracting
      // dimension.
      k_dim.set_size(k_dim.size() * a_input_shape.dim(i_idx).size());
      continue;
    }
    // It appears in both input operands, therefore we place it as an outer
    // dimension for the Batch Matmul.
    *(a_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
    *(b_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
  }
  for (int i_idx = 0, b_input_str_size = b_input_str.size();
       i_idx < b_input_str_size; ++i_idx) {
    if (a_input_str.find(b_input_str[i_idx]) == std::string::npos) {
      if (rhs_str.find(b_input_str[i_idx]) == std::string::npos) {
        VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
        return false;
      }
      n_dim.set_size(n_dim.size() * b_input_shape.dim(i_idx).size());
    }
  }

  // The two inner-most dimensions of the Batch Matmul are added.
  *(a_matrix_shape->add_dim()) = m_dim;
  *(a_matrix_shape->add_dim()) = k_dim;
  *(b_matrix_shape->add_dim()) = k_dim;
  *(b_matrix_shape->add_dim()) = n_dim;

  *batch_matmul_context = einsum_context;
  batch_matmul_context->op_info = batch_matmul_op_info;
  return true;
}
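
// Example for GenerateBatchMatmulContextFromEinsum() above (illustrative): for
// the equation "bij,bjk->bik", 'b' is a batch (B) dimension, 'i' maps to M,
// 'j' is a contracting (K) dimension, and 'k' maps to N, so the Einsum is
// modeled as BatchMatMul([B, M, K], [B, K, N]).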
1244
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes)1245 int64 OpLevelCostEstimator::CountBatchMatMulOperations(
1246 const OpInfo& op_info, bool* found_unknown_shapes) {
1247 return CountBatchMatMulOperations(op_info, nullptr, found_unknown_shapes);
1248 }
1249
CountBatchMatMulOperations(const OpInfo & op_info,BatchMatMulDimensions * batch_mat_mul,bool * found_unknown_shapes)1250 int64 OpLevelCostEstimator::CountBatchMatMulOperations(
1251 const OpInfo& op_info, BatchMatMulDimensions* batch_mat_mul,
1252 bool* found_unknown_shapes) {
1253 if (op_info.op() != kBatchMatMul && op_info.op() != kBatchMatMulV2) {
1254 LOG(ERROR) << "Invalid Operation: " << op_info.op();
1255 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1256 *found_unknown_shapes = true;
1257 return 0;
1258 }
1259 if (op_info.inputs_size() != 2) {
1260 LOG(ERROR) << "Expected 2 inputs but got " << op_info.inputs_size();
1261 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1262 *found_unknown_shapes = true;
1263 return 0;
1264 }
1265
1266 double ops = 0;
1267 const auto& a_input = op_info.inputs(0);
1268 const auto& b_input = op_info.inputs(1);
1269
1270 // BatchMatMul requires inputs of at least matrix shape (rank 2).
1271 // The two most minor dimensions of each input are matrices that
1272 // need to be multiplied together. The other dimensions determine
1273 // the number of such MatMuls. For example, if the BatchMatMul has
1274 // inputs of shape:
1275 // a_input_shape = [2, 3, 4, 5]
1276 // b_input_shape = [2, 3, 5, 6]
1277 // then there are 2*3 = 6 MatMuls of dimensions m = 4, k = 5, n = 6
1278 // in this BatchMatMul.
1279 const int matrix_rank = 2;
1280
1281 bool a_input_shape_unknown = false;
1282 bool b_input_shape_unknown = false;
1283
1284 TensorShapeProto a_input_shape = MaybeGetMinimumShape(
1285 a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
1286 &a_input_shape_unknown);
1287 TensorShapeProto b_input_shape = MaybeGetMinimumShape(
1288 b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
1289 &b_input_shape_unknown);
1290
1291 *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
1292 (a_input.shape().dim_size() < matrix_rank) ||
1293 (b_input.shape().dim_size() < matrix_rank);
1294
1295 // Compute the number of matmuls as the max indicated at each dimension
1296 // by either input. Note that the shapes do not have to have
1297 // the same rank due to incompleteness.
1298 TensorShapeProto* bigger_rank_shape = &a_input_shape;
1299 TensorShapeProto* smaller_rank_shape = &b_input_shape;
1300 if (b_input_shape.dim_size() > a_input_shape.dim_size()) {
1301 bigger_rank_shape = &b_input_shape;
1302 smaller_rank_shape = &a_input_shape;
1303 }
1304 int num_matmuls = 1;
1305 for (int b_i = 0,
1306 s_i = smaller_rank_shape->dim_size() - bigger_rank_shape->dim_size();
1307 b_i < bigger_rank_shape->dim_size() - matrix_rank; ++b_i, ++s_i) {
1308 int b_dim = bigger_rank_shape->dim(b_i).size();
1309 int s_dim = 1;
1310 if (s_i >= 0) {
1311 s_dim = smaller_rank_shape->dim(s_i).size();
1312 }
1313 if (batch_mat_mul != nullptr) {
1314 batch_mat_mul->batch_dims.push_back(s_dim);
1315 }
1316 num_matmuls *= std::max(b_dim, s_dim);
1317 }
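// Illustrative example (not from the original source): with
// a_input_shape = [2, 3, 4, 5] and b_input_shape = [3, 5, 6], the bigger-rank
// shape contributes batch dims {2, 3} and the smaller-rank shape is aligned
// from the right, so num_matmuls = max(2, 1) * max(3, 3) = 6.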
1318
1319 // Build the MatMul. Note that values are ignored here since we are just
1320 // counting ops (i.e., only shapes matter).
1321 OpInfo matmul_op_info;
1322 matmul_op_info.set_op("MatMul");
1323
1324 AttrValue transpose_a;
1325 transpose_a.set_b(false);
1326 if (op_info.attr().find("adj_x") != op_info.attr().end()) {
1327 transpose_a.set_b(op_info.attr().at("adj_x").b());
1328 }
1329 (*matmul_op_info.mutable_attr())["transpose_a"] = transpose_a;
1330
1331 AttrValue transpose_b;
1332 transpose_b.set_b(false);
1333 if (op_info.attr().find("adj_y") != op_info.attr().end()) {
1334 transpose_b.set_b(op_info.attr().at("adj_y").b());
1335 }
1336 (*matmul_op_info.mutable_attr())["transpose_b"] = transpose_b;
1337
1338 OpInfo::TensorProperties* a_matrix = matmul_op_info.add_inputs();
1339 a_matrix->set_dtype(a_input.dtype());
1340 TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
1341 for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
1342 i < a_input_shape.dim_size(); ++i) {
1343 *(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
1344 }
1345
1346 OpInfo::TensorProperties* b_matrix = matmul_op_info.add_inputs();
1347 b_matrix->set_dtype(b_input.dtype());
1348 TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
1349 for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
1350 i < b_input_shape.dim_size(); ++i) {
1351 *(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
1352 }
1353 if (batch_mat_mul != nullptr) {
1354 batch_mat_mul->matmul_dims.m = (transpose_a.b())
1355 ? a_matrix_shape->dim(1).size()
1356 : a_matrix_shape->dim(0).size();
1357 batch_mat_mul->matmul_dims.k = (transpose_a.b())
1358 ? a_matrix_shape->dim(0).size()
1359 : a_matrix_shape->dim(1).size();
1360 batch_mat_mul->matmul_dims.n = (transpose_b.b())
1361 ? b_matrix_shape->dim(0).size()
1362 : b_matrix_shape->dim(1).size();
1363 }
1364
1365 for (int i = 0; i < num_matmuls; ++i) {
1366 bool matmul_unknown_shapes = false;
1367 ops += CountMatMulOperations(matmul_op_info, &matmul_unknown_shapes);
1368 *found_unknown_shapes |= matmul_unknown_shapes;
1369 }
1370 return ops;
1371 }
1372
1373 bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
1374 TensorShapeProto* tensor_shape_proto) {
1375 tensor_shape_proto->Clear();
1376 // First convert TensorProto into Tensor class so that it correctly parses
1377 // data values within TensorProto (whether it's in int_val, int64_val,
1378 // tensor_content, or any other field).
1379 Tensor tensor(tensor_proto.dtype());
1380 if (!tensor.FromProto(tensor_proto)) {
1381 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1382 << "failed to parse TensorProto: "
1383 << tensor_proto.DebugString();
1384 return false;
1385 }
1386 if (tensor.dims() != 1) {
1387 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1388 << "tensor is not 1D: " << tensor.dims();
1389 return false;
1390 }
1391 // Then, convert it back to TensorProto using AsProtoField, which makes sure
1392 // the data is in int_val, int64_val, or such repeated data fields, not in
1393 // tensor_content.
1394 TensorProto temp_tensor;
1395 tensor.AsProtoField(&temp_tensor);
1396
1397 #define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type) \
1398 do { \
1399 for (const auto& value : temp_tensor.type##_val()) { \
1400 tensor_shape_proto->add_dim()->set_size(value); \
1401 } \
1402 } while (0)
1403
1404 if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
1405 tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
1406 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
1407 } else if (tensor.dtype() == DT_INT64) {
1408 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
1409 } else if (tensor.dtype() == DT_UINT32) {
1410 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
1411 } else if (tensor.dtype() == DT_UINT64) {
1412 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
1413 } else {
1414 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1415 << "Unsupported dtype: " << tensor.dtype();
1416 return false;
1417 }
1418 #undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
1419
1420 return true;
1421 }
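// Usage sketch (illustrative only): a rank-1 DT_INT32 tensor holding the
// values {2, 3, 5} -- e.g. a shape-valued input such as the input-sizes
// argument of Conv2DBackpropInput -- is converted here into a
// TensorShapeProto with dims [2, 3, 5].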
1422
1423 // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
1424 int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
1425 const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1426 bool* found_unknown_shapes) {
1427 int64_t ops = 0;
1428
1429 DCHECK(op_info.op() == kConv2dBackpropInput ||
1430 op_info.op() == kDepthwiseConv2dNativeBackpropInput)
1431 << "Invalid Operation: not kConv2dBackpropInput nor"
1432 "kDepthwiseConv2dNativeBackpropInput";
1433
1434 if (op_info.inputs_size() < 2) {
1435 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1436 *found_unknown_shapes = true;
1437 return ops;
1438 }
1439
1440 TensorShapeProto input_shape;
1441 bool shape_found = false;
1442 if (op_info.inputs(0).has_value()) {
1443 const TensorProto& value = op_info.inputs(0).value();
1444 shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
1445 }
1446 if (!shape_found && op_info.outputs_size() == 1) {
1447 input_shape = op_info.outputs(0).shape();
1448 shape_found = true;
1449 }
1450 if (!shape_found) {
1451 // Set the minimum input size that's feasible.
1452 input_shape.Clear();
1453 for (int i = 0; i < 4; ++i) {
1454 input_shape.add_dim()->set_size(1);
1455 }
1456 *found_unknown_shapes = true;
1457 }
1458
1459 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1460 input_shape, op_info.inputs(1).shape(), op_info, found_unknown_shapes);
1461
1462 ops = conv_dims.batch;
1463 ops *= conv_dims.ox * conv_dims.oy;
1464 ops *= conv_dims.kx * conv_dims.ky;
1465 if (op_info.op() == kConv2dBackpropInput) {
1466 ops *= conv_dims.kz * conv_dims.oz;
1467 } else {
1468 // conv_dims always uses the forward-path definition regardless.
1469 conv_dims.oz *= conv_dims.iz;
1470 ops *= conv_dims.oz;
1471 }
1472 ops *= kOpsPerMac;
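// Worked example (illustrative only): for batch = 1, a 16x16 output
// (ox = oy = 16), a 3x3 kernel (kx = ky = 3), kz = 64 input channels and
// oz = 128 output channels, the count above is
// 1 * 16*16 * 3*3 * 64*128 * kOpsPerMac ~= 37.7 million ops.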
1473
1474 VLOG(1) << "Operations for " << op_info.op() << " " << ops;
1475
1476 if (returned_conv_dims != nullptr) {
1477 *returned_conv_dims = conv_dims;
1478 }
1479 return ops;
1480 }
1481
1482 int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
1483 const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1484 bool* found_unknown_shapes) {
1485 int64_t ops = 0;
1486
1487 DCHECK(op_info.op() == kConv2dBackpropFilter ||
1488 op_info.op() == kDepthwiseConv2dNativeBackpropFilter)
1489 << "Invalid Operation: not kConv2dBackpropFilter nor"
1490 "kDepthwiseConv2dNativeBackpropFilter";
1491
1492 TensorShapeProto filter_shape;
1493 bool shape_found = false;
1494 if (op_info.inputs_size() >= 2 && op_info.inputs(1).has_value()) {
1495 const TensorProto& value = op_info.inputs(1).value();
1496 shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
1497 }
1498 if (!shape_found && op_info.outputs_size() == 1) {
1499 filter_shape = op_info.outputs(0).shape();
1500 shape_found = true;
1501 }
1502 if (!shape_found) {
1503 // Set the minimum filter size that's feasible.
1504 filter_shape.Clear();
1505 for (int i = 0; i < 4; ++i) {
1506 filter_shape.add_dim()->set_size(1);
1507 }
1508 *found_unknown_shapes = true;
1509 }
1510
1511 if (op_info.inputs_size() < 1) {
1512 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1513 *found_unknown_shapes = true;
1514 return ops;
1515 }
1516 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1517 op_info.inputs(0).shape(), filter_shape, op_info, found_unknown_shapes);
1518
1519 ops = conv_dims.batch;
1520 ops *= conv_dims.ox * conv_dims.oy;
1521 ops *= conv_dims.kx * conv_dims.ky;
1522 if (op_info.op() == kConv2dBackpropFilter) {
1523 ops *= conv_dims.kz * conv_dims.oz;
1524 } else {
1525 // conv_dims always uses the forward-path definition regardless.
1526 conv_dims.oz *= conv_dims.iz;
1527 ops *= conv_dims.oz;
1528 }
1529 ops *= kOpsPerMac;
1530 VLOG(1) << "Operations for " << op_info.op() << " " << ops;
1531
1532 if (returned_conv_dims != nullptr) {
1533 *returned_conv_dims = conv_dims;
1534 }
1535 return ops;
1536 }
1537
1538 int64 OpLevelCostEstimator::CalculateTensorElementCount(
1539 const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1540 VLOG(2) << " with " << DataTypeString(tensor.dtype()) << " tensor of shape "
1541 << tensor.shape().DebugString();
1542 int64_t tensor_size = 1;
1543 int num_dims = std::max(1, tensor.shape().dim_size());
1544 auto tensor_shape =
1545 MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes);
1546 for (const auto& dim : tensor_shape.dim()) {
1547 tensor_size *= dim.size();
1548 }
1549 return tensor_size;
1550 }
1551
1552 int64 OpLevelCostEstimator::CalculateTensorSize(
1553 const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1554 int64_t count = CalculateTensorElementCount(tensor, found_unknown_shapes);
1555 int size = DataTypeSize(BaseType(tensor.dtype()));
1556 VLOG(2) << "Count: " << count << " DataTypeSize: " << size;
1557 return count * size;
1558 }
1559
1560 int64 OpLevelCostEstimator::CalculateInputSize(const OpInfo& op_info,
1561 bool* found_unknown_shapes) {
1562 int64_t total_input_size = 0;
1563 for (auto& input : op_info.inputs()) {
1564 int64_t input_size = CalculateTensorSize(input, found_unknown_shapes);
1565 total_input_size += input_size;
1566 VLOG(1) << "Input Size: " << input_size
1567 << " Total Input Size:" << total_input_size;
1568 }
1569 return total_input_size;
1570 }
1571
1572 std::vector<int64> OpLevelCostEstimator::CalculateInputTensorSize(
1573 const OpInfo& op_info, bool* found_unknown_shapes) {
1574 std::vector<int64> input_tensor_size;
1575 input_tensor_size.reserve(op_info.inputs().size());
1576 for (auto& input : op_info.inputs()) {
1577 input_tensor_size.push_back(
1578 CalculateTensorSize(input, found_unknown_shapes));
1579 }
1580 return input_tensor_size;
1581 }
1582
1583 int64 OpLevelCostEstimator::CalculateLargestInputCount(
1584 const OpInfo& op_info, bool* found_unknown_shapes) {
1585 int64_t largest_input_count = 0;
1586 for (auto& input : op_info.inputs()) {
1587 int64_t input_count =
1588 CalculateTensorElementCount(input, found_unknown_shapes);
1589 if (input_count > largest_input_count) {
1590 largest_input_count = input_count;
1591 }
1592 VLOG(1) << "Input Count: " << input_count
1593 << " Largest Input Count:" << largest_input_count;
1594 }
1595 return largest_input_count;
1596 }
1597
1598 int64 OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info,
1599 bool* found_unknown_shapes) {
1600 int64_t total_output_size = 0;
1601 // Use float as default for calculations.
1602 for (const auto& output : op_info.outputs()) {
1603 DataType dt = output.dtype();
1604 const auto& original_output_shape = output.shape();
1605 int64_t output_size = DataTypeSize(BaseType(dt));
1606 int num_dims = std::max(1, original_output_shape.dim_size());
1607 auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1608 found_unknown_shapes);
1609 for (const auto& dim : output_shape.dim()) {
1610 output_size *= dim.size();
1611 }
1612 total_output_size += output_size;
1613 VLOG(1) << "Output Size: " << output_size
1614 << " Total Output Size:" << total_output_size;
1615 }
1616 return total_output_size;
1617 }
1618
1619 std::vector<int64> OpLevelCostEstimator::CalculateOutputTensorSize(
1620 const OpInfo& op_info, bool* found_unknown_shapes) {
1621 std::vector<int64> output_tensor_size;
1622 output_tensor_size.reserve(op_info.outputs().size());
1623 // Use float as default for calculations.
1624 for (const auto& output : op_info.outputs()) {
1625 DataType dt = output.dtype();
1626 const auto& original_output_shape = output.shape();
1627 int64_t output_size = DataTypeSize(BaseType(dt));
1628 int num_dims = std::max(1, original_output_shape.dim_size());
1629 auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1630 found_unknown_shapes);
1631 for (const auto& dim : output_shape.dim()) {
1632 output_size *= dim.size();
1633 }
1634 output_tensor_size.push_back(output_size);
1635 }
1636 return output_tensor_size;
1637 }
1638
1639 Status OpLevelCostEstimator::PredictDefaultNodeCosts(
1640 const int64_t num_compute_ops, const OpContext& op_context,
1641 bool* found_unknown_shapes, NodeCosts* node_costs) {
1642 const auto& op_info = op_context.op_info;
1643 node_costs->num_compute_ops = num_compute_ops;
1644 node_costs->num_input_bytes_accessed =
1645 CalculateInputTensorSize(op_info, found_unknown_shapes);
1646 node_costs->num_output_bytes_accessed =
1647 CalculateOutputTensorSize(op_info, found_unknown_shapes);
1648 node_costs->max_memory = node_costs->num_total_output_bytes();
1649 if (*found_unknown_shapes) {
1650 node_costs->inaccurate = true;
1651 node_costs->num_nodes_with_unknown_shapes = 1;
1652 }
1653 return Status::OK();
1654 }
1655
1656 bool HasZeroDim(const OpInfo& op_info) {
1657 for (int i = 0; i < op_info.inputs_size(); ++i) {
1658 const auto& input = op_info.inputs(i);
1659 for (int j = 0; j < input.shape().dim_size(); ++j) {
1660 const auto& dim = input.shape().dim(j);
1661 if (dim.size() == 0) {
1662 VLOG(1) << "Convolution config has zero dim "
1663 << op_info.ShortDebugString();
1664 return true;
1665 }
1666 }
1667 }
1668 return false;
1669 }
1670
1671 Status OpLevelCostEstimator::PredictConv2D(const OpContext& op_context,
1672 NodeCosts* node_costs) const {
1673 const auto& op_info = op_context.op_info;
1674 if (HasZeroDim(op_info)) {
1675 node_costs->num_nodes_with_unknown_shapes = 1;
1676 return errors::InvalidArgument("Conv2D op includes zero dimension: ",
1677 op_info.ShortDebugString());
1678 }
1679 bool found_unknown_shapes = false;
1680 int64_t num_compute_ops =
1681 CountConv2DOperations(op_info, &found_unknown_shapes);
1682 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1683 &found_unknown_shapes, node_costs);
1684 }
1685
1686 Status OpLevelCostEstimator::PredictConv2DBackpropInput(
1687 const OpContext& op_context, NodeCosts* node_costs) const {
1688 const auto& op_info = op_context.op_info;
1689 if (HasZeroDim(op_info)) {
1690 node_costs->num_nodes_with_unknown_shapes = 1;
1691 return errors::InvalidArgument(
1692 "Conv2DBackpropInput op includes zero dimension",
1693 op_info.ShortDebugString());
1694 }
1695 bool found_unknown_shapes = false;
1696 int64_t num_compute_ops = CountConv2DBackpropInputOperations(
1697 op_info, nullptr, &found_unknown_shapes);
1698 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1699 &found_unknown_shapes, node_costs);
1700 }
1701
1702 Status OpLevelCostEstimator::PredictConv2DBackpropFilter(
1703 const OpContext& op_context, NodeCosts* node_costs) const {
1704 const auto& op_info = op_context.op_info;
1705 if (HasZeroDim(op_info)) {
1706 node_costs->num_nodes_with_unknown_shapes = 1;
1707 return errors::InvalidArgument(
1708 "Conv2DBackpropFilter op includes zero dimension",
1709 op_info.ShortDebugString());
1710 }
1711 bool found_unknown_shapes = false;
1712 int64_t num_compute_ops = CountConv2DBackpropFilterOperations(
1713 op_info, nullptr, &found_unknown_shapes);
1714 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1715 &found_unknown_shapes, node_costs);
1716 }
1717
1718 Status OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
1719 const OpContext& op_context, NodeCosts* node_costs) const {
1720 // FusedConv2DBiasActivation computes a fused kernel which implements:
1721 // 2D convolution, adds side input with separate scaling on convolution and
1722 // side inputs, then adds bias, and finally applies the ReLU activation
1723 // function to the result:
1724 //
1725 // Input -> Conv2D -> Add -> BiasAdd -> ReLU
1726 // ^ ^ ^
1727 // Filter Side Input Bias
1728 //
1729 // Note that when adding the side input, the operation multiplies the output
1730 // of Conv2D by conv_input_scale, confusingly, and the side_input by
1731 // side_input_scale.
1732 //
1733 // Note that in the special case that side_input_scale is 0, which we infer
1734 // from side_input having dimensions [], we skip that addition operation.
1735 //
1736 // For more information, see
1737 // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
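// Put differently (a sketch of the fused computation described above, not an
// excerpt from the kernel):
//   output = Relu(conv_input_scale * Conv2D(input, filter) +
//                 side_input_scale * side_input + bias)
// with the side-input term skipped when side_input has shape [].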
1738
1739 // TODO(yaozhang): Support NHWC_VECT_W.
1740 std::string data_format = GetDataFormat(op_context.op_info);
1741 if (data_format != "NCHW" && data_format != "NHWC" &&
1742 data_format != "NCHW_VECT_C") {
1743 return errors::InvalidArgument(
1744 "Unsupported data format (", data_format,
1745 ") for op: ", op_context.op_info.ShortDebugString());
1746 }
1747 std::string filter_format = GetFilterFormat(op_context.op_info);
1748 if (filter_format != "HWIO" && filter_format != "OIHW" &&
1749 filter_format != "OIHW_VECT_I") {
1750 return errors::InvalidArgument(
1751 "Unsupported filter format (", filter_format,
1752 ") for op: ", op_context.op_info.ShortDebugString());
1753 }
1754
1755 auto& conv_input = op_context.op_info.inputs(0);
1756 auto& filter = op_context.op_info.inputs(1);
1757 auto& side_input = op_context.op_info.inputs(3);
1758 auto& conv_input_scale = op_context.op_info.inputs(4);
1759 auto& side_input_scale = op_context.op_info.inputs(5);
1760
1761 // Manually compute our convolution dimensions.
1762 bool found_unknown_shapes = false;
1763 auto dims = ConvolutionDimensionsFromInputs(
1764 conv_input.shape(), filter.shape(), op_context.op_info,
1765 &found_unknown_shapes);
1766 OpInfo::TensorProperties output;
1767 if (data_format == "NCHW" || data_format == "NCHW_VECT_C") {
1768 output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox});
1769 } else if (data_format == "NHWC") {
1770 output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz});
1771 }
1772
1773 // Add the operations the fused op always computes.
1774 std::vector<OpContext> component_ops = {
1775 FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
1776 FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
1777 FusedChildContext(
1778 op_context, "BiasAdd", output,
1779 {output, output}), // Note we're no longer using bias at all
1780 FusedChildContext(op_context, "Relu", output, {output})};
1781
1782 // Add our side_input iff it's non-empty.
1783 if (side_input.shape().dim_size() > 0) {
1784 component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
1785 {side_input, side_input_scale}));
1786 component_ops.push_back(FusedChildContext(
1787 op_context, "Add", output,
1788 {output, output})); // Note that we're not using side_input here
1789 }
1790
1791 // Construct an op_context which definitely has our output shape.
1792 auto op_context_with_output = op_context;
1793 op_context_with_output.op_info.mutable_outputs()->Clear();
1794 *op_context_with_output.op_info.mutable_outputs()->Add() = output;
1795
1796 // Construct component operations and run the cost computation.
1797 if (found_unknown_shapes) {
1798 node_costs->inaccurate = true;
1799 node_costs->num_nodes_with_unknown_shapes = 1;
1800 }
1801 return PredictFusedOp(op_context_with_output, component_ops, node_costs);
1802 }
1803
1804 Status OpLevelCostEstimator::PredictMatMul(const OpContext& op_context,
1805 NodeCosts* node_costs) const {
1806 const auto& op_info = op_context.op_info;
1807 bool found_unknown_shapes = false;
1808 int64_t num_compute_ops =
1809 CountMatMulOperations(op_info, &found_unknown_shapes);
1810 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1811 &found_unknown_shapes, node_costs);
1812 }
1813
1814 Status OpLevelCostEstimator::PredictEinsum(const OpContext& op_context,
1815 NodeCosts* node_costs) const {
1816 const auto& op_info = op_context.op_info;
1817
1818 auto it = op_info.attr().find("equation");
1819 if (it == op_info.attr().end()) {
1820 return errors::InvalidArgument("Einsum op doesn't have equation attr: ",
1821 op_info.ShortDebugString());
1822 }
1823
1824 OpContext batch_matmul_op_context;
1825 bool found_unknown_shapes = false;
1826 bool success = GenerateBatchMatmulContextFromEinsum(
1827 op_context, &batch_matmul_op_context, &found_unknown_shapes);
1828 if (found_unknown_shapes) {
1829 node_costs->inaccurate = true;
1830 node_costs->num_nodes_with_unknown_shapes = 1;
1831 }
1832 if (!success) {
1833 return PredictCostOfAnUnknownOp(op_context, node_costs);
1834 }
1835 return PredictNodeCosts(batch_matmul_op_context, node_costs);
1836 }
1837
1838 Status OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
1839 const OpContext& op_context, NodeCosts* node_costs) const {
1840 const auto& op_info = op_context.op_info;
1841 bool found_unknown_shapes = false;
1842 // input[0]: indices in sparse matrix a
1843 // input[1]: values in sparse matrix a
1844 // input[2]: shape of matrix a
1845 // input[3]: matrix b
1846 // See
1847 // https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
1848 int64_t num_elems_in_a =
1849 CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
1850 auto b_matrix = op_info.inputs(3);
1851 auto b_matrix_shape =
1852 MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
1853 int64_t n_dim = b_matrix_shape.dim(1).size();
1854
1855 // Each element in A is multiplied and added with an element from each column
1856 // in b.
1857 const int64_t op_count = kOpsPerMac * num_elems_in_a * n_dim;
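// Worked example (illustrative only): if the sparse matrix a holds 10
// non-zero values and b has n_dim = 4 columns, op_count = 2 * 10 * 4 = 80.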
1858
1859 int64_t a_indices_input_size =
1860 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1861 int64_t a_values_input_size =
1862 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1863 int64_t a_shape_input_size =
1864 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1865 int64_t b_input_size =
1866 num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
1867 int64_t output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
1868
1869 node_costs->num_compute_ops = op_count;
1870 node_costs->num_input_bytes_accessed = {a_indices_input_size,
1871 a_values_input_size,
1872 a_shape_input_size, b_input_size};
1873 node_costs->num_output_bytes_accessed = {output_size};
1874 if (found_unknown_shapes) {
1875 node_costs->inaccurate = true;
1876 node_costs->num_nodes_with_unknown_shapes = 1;
1877 }
1878 return Status::OK();
1879 }
1880
1881 Status OpLevelCostEstimator::PredictNoOp(const OpContext& op_context,
1882 NodeCosts* node_costs) const {
1883 const auto& op_info = op_context.op_info;
1884 VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1885 // By default, NodeCosts is initialized to zero ops and bytes.
1886 return Status::OK();
1887 }
1888
1889 Status OpLevelCostEstimator::PredictPureMemoryOp(const OpContext& op_context,
1890 NodeCosts* node_costs) const {
1891 // Each output element is a copy of some element from input, with no required
1892 // computation, so just compute memory costs.
1893 bool found_unknown_shapes = false;
1894 node_costs->num_nodes_with_pure_memory_op = 1;
1895 return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
1896 node_costs);
1897 }
1898
1899 Status OpLevelCostEstimator::PredictIdentity(const OpContext& op_context,
1900 NodeCosts* node_costs) const {
1901 const auto& op_info = op_context.op_info;
1902 VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Identity";
1903 node_costs->minimum_cost_op = true;
1904 node_costs->num_compute_ops = kMinComputeOp;
1905 // The Identity op internally passes the input tensor buffer's pointer to the
1906 // output tensor buffer; there is no actual memory operation.
1907 node_costs->num_input_bytes_accessed = {0};
1908 node_costs->num_output_bytes_accessed = {0};
1909 bool inaccurate = false;
1910 node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1911 if (inaccurate) {
1912 node_costs->inaccurate = true;
1913 node_costs->num_nodes_with_unknown_shapes = 1;
1914 }
1915 return Status::OK();
1916 }
1917
1918 Status OpLevelCostEstimator::PredictVariable(const OpContext& op_context,
1919 NodeCosts* node_costs) const {
1920 const auto& op_info = op_context.op_info;
1921 VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Variable";
1922 node_costs->minimum_cost_op = true;
1923 node_costs->num_compute_ops = kMinComputeOp;
1924 // Variables are persistent ops; initialized before step; hence, no memory
1925 // cost.
1926 node_costs->num_input_bytes_accessed = {0};
1927 node_costs->num_output_bytes_accessed = {0};
1928 bool inaccurate = false;
1929 node_costs->persistent_memory = CalculateOutputSize(op_info, &inaccurate);
1930 if (inaccurate) {
1931 node_costs->inaccurate = true;
1932 node_costs->num_nodes_with_unknown_shapes = 1;
1933 }
1934 return Status::OK();
1935 }
1936
1937 Status OpLevelCostEstimator::PredictBatchMatMul(const OpContext& op_context,
1938 NodeCosts* node_costs) const {
1939 const auto& op_info = op_context.op_info;
1940 bool found_unknown_shapes = false;
1941 int64_t num_compute_ops =
1942 CountBatchMatMulOperations(op_info, &found_unknown_shapes);
1943 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1944 &found_unknown_shapes, node_costs);
1945 }
1946
1947 Status OpLevelCostEstimator::PredictMetadata(const OpContext& op_context,
1948 NodeCosts* node_costs) const {
1949 const auto& op_info = op_context.op_info;
1950 node_costs->minimum_cost_op = true;
1951 node_costs->num_compute_ops = kMinComputeOp;
1952 node_costs->num_input_bytes_accessed = {0};
1953 node_costs->num_output_bytes_accessed = {0};
1954 bool inaccurate = false;
1955 node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1956 if (inaccurate) {
1957 node_costs->inaccurate = true;
1958 node_costs->num_nodes_with_unknown_shapes = 1;
1959 }
1960 return Status::OK();
1961 }
1962
1963 Status OpLevelCostEstimator::PredictGatherOrSlice(const OpContext& op_context,
1964 NodeCosts* node_costs) const {
1965 // Gather & Slice ops can have a very large input, but only access a small
1966 // part of it. For these ops, the size of the output determines the memory cost.
1967 const auto& op_info = op_context.op_info;
1968
1969 const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
1970 if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
1971 return errors::InvalidArgument(
1972 op_info.op(),
1973 " Op doesn't have valid input / output: ", op_info.ShortDebugString());
1974 }
1975
1976 bool unknown_shapes = false;
1977
1978 // Each output element is a copy of some element from input.
1979 // For roofline estimate we assume each copy has a unit cost.
1980 const int64_t op_count =
1981 CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
1982 node_costs->num_compute_ops = op_count;
1983
1984 const int64_t output_size = CalculateOutputSize(op_info, &unknown_shapes);
1985 node_costs->num_output_bytes_accessed = {output_size};
1986
1987 node_costs->num_input_bytes_accessed.reserve(op_info.inputs().size());
1988 int64_t input_size = output_size;
1989 // Note that the bytes accessed from input(0) are not equal to the input(0)
1990 // tensor size; they are equal to the output size, since the input is read via
1991 // an indexed gather or slice (ignoring duplicate indices).
1992 node_costs->num_input_bytes_accessed.push_back(input_size);
1993 int begin_input_index = 1;
1994 int end_input_index;
1995 if (op_info.op() == "Slice") {
1996 // Slice: 'input' (omitted), 'begin', 'size'
1997 end_input_index = 3;
1998 } else if (op_info.op() == "StridedSlice") {
1999 // StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
2000 end_input_index = 4;
2001 } else {
2002 // Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
2003 end_input_index = 2;
2004 }
2005 for (int i = begin_input_index; i < end_input_index; ++i) {
2006 node_costs->num_input_bytes_accessed.push_back(
2007 CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes));
2008 }
2009 if (unknown_shapes) {
2010 node_costs->inaccurate = true;
2011 node_costs->num_nodes_with_unknown_shapes = 1;
2012 }
2013 return Status::OK();
2014 }
2015
2016 Status OpLevelCostEstimator::PredictScatter(const OpContext& op_context,
2017 NodeCosts* node_costs) const {
2018 // Scatter ops sparsely access a reference input and output tensor.
2019 const auto& op_info = op_context.op_info;
2020 bool found_unknown_shapes = false;
2021
2022 // input[0]: ref tensor that will be sparsely accessed
2023 // input[1]: indices - A tensor of indices into the first dimension of ref.
2024 // input[2]: updates where updates.shape = indices.shape + ref.shape[1:]
2025 // See
2026 // https://www.tensorflow.org/api_docs/python/tf/scatter_add and
2027 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/state_ops.cc#L146
2028
2029 const int64_t num_indices =
2030 CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
2031
2032 int64_t num_elems_in_ref_per_index = 1;
2033 auto ref_tensor_shape = MaybeGetMinimumShape(
2034 op_info.inputs(0).shape(), op_info.inputs(0).shape().dim_size(),
2035 &found_unknown_shapes);
2036 for (int i = 1; i < ref_tensor_shape.dim().size(); ++i) {
2037 num_elems_in_ref_per_index *= ref_tensor_shape.dim(i).size();
2038 }
2039 const int64_t op_count = num_indices * num_elems_in_ref_per_index;
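// Worked example (illustrative only): scattering 5 indices into a ref tensor
// of shape [N, 8] touches num_elems_in_ref_per_index = 8 elements per index,
// so op_count = 5 * 8 = 40.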
2040 node_costs->num_compute_ops = op_count;
2041
2042 // Sparsely access ref so input size depends on the number of operations
2043 int64_t ref_input_size =
2044 op_count * DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2045 int64_t indices_input_size =
2046 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2047 int64_t updates_input_size =
2048 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2049 node_costs->num_input_bytes_accessed = {ref_input_size, indices_input_size,
2050 updates_input_size};
2051
2052 // Sparsely access ref so output size depends on the number of operations
2053 int64_t output_size =
2054 op_count * DataTypeSize(BaseType(op_info.outputs(0).dtype()));
2055 node_costs->num_output_bytes_accessed = {output_size};
2056
2057 if (found_unknown_shapes) {
2058 node_costs->inaccurate = true;
2059 node_costs->num_nodes_with_unknown_shapes = 1;
2060 }
2061 return Status::OK();
2062 }
2063
2064 Status OpLevelCostEstimator::PredictFusedOp(
2065 const OpContext& op_context,
2066 const std::vector<OpContext>& fused_op_contexts,
2067 NodeCosts* node_costs) const {
2068 // Note that PredictDefaultNodeCosts will get the correct memory costs from
2069 // the node's inputs and outputs; but we don't want to have to re-implement
2070 // the logic for computing the operation count of each of our component
2071 // operations here; so we simply add the compute op counts of each component
2072 // operation, then update the cost.
2073 bool found_unknown_shapes = false;
2074 Status s =
2075 PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, node_costs);
2076
2077 for (auto& fused_op : fused_op_contexts) {
2078 NodeCosts fused_node_costs;
2079 s.Update(PredictNodeCosts(fused_op, &fused_node_costs));
2080 node_costs->num_compute_ops += fused_node_costs.num_compute_ops;
2081 node_costs->inaccurate |= fused_node_costs.inaccurate;
2082 // Set, not increment. Note that we are predicting the cost of one fused
2083 // node, not a function node composed of many nodes.
2084 node_costs->num_nodes_with_unknown_shapes |=
2085 fused_node_costs.num_nodes_with_unknown_shapes;
2086 node_costs->num_nodes_with_unknown_op_type |=
2087 fused_node_costs.num_nodes_with_unknown_op_type;
2088 node_costs->num_nodes_with_pure_memory_op |=
2089 fused_node_costs.num_nodes_with_pure_memory_op;
2090 }
2091
2092 return Status::OK();
2093 }
2094
2095 /* static */
2096 OpContext OpLevelCostEstimator::FusedChildContext(
2097 const OpContext& parent, const std::string& op_name,
2098 const OpInfo::TensorProperties& output,
2099 const std::vector<OpInfo::TensorProperties>& inputs) {
2100 // Setup the base parameters of our new context.
2101 OpContext new_context;
2102 new_context.name = op_name;
2103 new_context.device_name = parent.device_name;
2104 new_context.op_info = parent.op_info;
2105 new_context.op_info.set_op(op_name);
2106
2107 // Setup the inputs of our new context.
2108 new_context.op_info.mutable_inputs()->Clear();
2109 for (const auto& input : inputs) {
2110 *new_context.op_info.mutable_inputs()->Add() = input;
2111 }
2112
2113 // Setup the output of our new context.
2114 new_context.op_info.mutable_outputs()->Clear();
2115 *new_context.op_info.mutable_outputs()->Add() = output;
2116
2117 return new_context;
2118 }
2119
2120 /* static */
2121 OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
2122 DataType type, const std::vector<int64>& dims) {
2123 OpInfo::TensorProperties ret;
2124 ret.set_dtype(type);
2125
2126 auto shape = ret.mutable_shape();
2127 for (const int dim : dims) {
2128 shape->add_dim()->set_size(dim);
2129 }
2130
2131 return ret;
2132 }
2133
2134 /* static */
2135 OpLevelCostEstimator::ConvolutionDimensions
2136 OpLevelCostEstimator::OpDimensionsFromInputs(
2137 const TensorShapeProto& original_image_shape, const OpInfo& op_info,
2138 bool* found_unknown_shapes) {
2139 VLOG(2) << "op features: " << op_info.DebugString();
2140 VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
2141 auto image_shape =
2142 MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
2143 VLOG(2) << "Image shape: " << image_shape.DebugString();
2144
2145 int x_index, y_index, channel_index;
2146 const std::string& data_format = GetDataFormat(op_info);
2147 if (data_format == "NCHW") {
2148 channel_index = 1;
2149 y_index = 2;
2150 x_index = 3;
2151 } else {
2152 y_index = 1;
2153 x_index = 2;
2154 channel_index = 3;
2155 }
2156 int64_t batch = image_shape.dim(0).size();
2157 int64_t ix = image_shape.dim(x_index).size();
2158 int64_t iy = image_shape.dim(y_index).size();
2159 int64_t iz = image_shape.dim(channel_index).size();
2160
2161 // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
2162 // {1, 1, 1, 1} in that case.
2163 std::vector<int64> ksize = GetKernelSize(op_info);
2164 int64_t kx = ksize[x_index];
2165 int64_t ky = ksize[y_index];
2166 // These ops don't support groupwise operation, therefore kz == iz.
2167 int64_t kz = iz;
2168
2169 std::vector<int64> strides = GetStrides(op_info);
2170 int64_t sx = strides[x_index];
2171 int64_t sy = strides[y_index];
2172 const auto padding = GetPadding(op_info);
2173
2174 int64_t ox = GetOutputSize(ix, kx, sx, padding);
2175 int64_t oy = GetOutputSize(iy, ky, sy, padding);
2176 int64_t oz = iz;
2177
2178 OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
2179 batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};
2180 return conv_dims;
2181 }
2182
2183 Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context,
2184 NodeCosts* node_costs) const {
2185 bool found_unknown_shapes = false;
2186 const auto& op_info = op_context.op_info;
2187 // x: op_info.inputs(0)
2188 ConvolutionDimensions dims = OpDimensionsFromInputs(
2189 op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2190 // kx * ky - 1 comparisons per output (kx * ky > 1)
2191 // or 1 copy per output (kx * ky = 1).
2192 int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
2193 int64_t ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
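// Worked example (illustrative only): a 3x3 max-pool window needs
// per_output_ops = 3*3 - 1 = 8 comparisons per output element, so an output
// of shape [1, 16, 16, 32] costs 1 * 16 * 16 * 32 * 8 = 65536 ops.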
2194 node_costs->num_compute_ops = ops;
2195
2196 int64_t input_size = 0;
2197 if (dims.ky >= dims.sy) {
2198 input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2199 } else { // dims.ky < dims.sy
2200 // Vertical stride is larger than vertical kernel; assuming row-major
2201 // format, skip unnecessary rows (or read every ky rows per sy rows, as the
2202 // others are not used for output).
2203 const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2204 input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2205 }
2206 node_costs->num_input_bytes_accessed = {input_size};
2207 const int64_t output_size =
2208 CalculateOutputSize(op_info, &found_unknown_shapes);
2209 node_costs->num_output_bytes_accessed = {output_size};
2210 node_costs->max_memory = output_size;
2211 if (found_unknown_shapes) {
2212 node_costs->inaccurate = true;
2213 node_costs->num_nodes_with_unknown_shapes = 1;
2214 }
2215 return Status::OK();
2216 }
2217
2218 Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context,
2219 NodeCosts* node_costs) const {
2220 bool found_unknown_shapes = false;
2221 const auto& op_info = op_context.op_info;
2222 // x: op_info.inputs(0)
2223 // y: op_info.inputs(1)
2224 // y_grad: op_info.inputs(2)
2225 if (op_info.inputs_size() < 3) {
2226 return errors::InvalidArgument("MaxPoolGrad op has invalid inputs: ",
2227 op_info.ShortDebugString());
2228 }
2229
2230 ConvolutionDimensions dims = OpDimensionsFromInputs(
2231 op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2232
2233 int64_t ops = 0;
2234 if (dims.kx == 1 && dims.ky == 1) {
2235 // 1x1 window. No need to know which input was max.
2236 ops = dims.batch * dims.ix * dims.iy * dims.iz;
2237 } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2238 // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
2239 ops = dims.batch * dims.iz *
2240 (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
2241 } else {
2242 // Overlapping window: initialize with zeros, re-run maxpool, then
2243 // accumulate y_grad to the proper x_grad locations.
2244 ops = dims.batch * dims.iz *
2245 (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
2246 }
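// Worked example (illustrative only): a 2x2 window with stride 2 on a
// [1, 4, 4, 1] input (ox = oy = 2) hits the non-overlapping branch:
// ops = 1 * 1 * (2*2 * (2*2 - 1) + 4*4) = 12 + 16 = 28.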
2247 node_costs->num_compute_ops = ops;
2248
2249 // Just read x and y_grad; no need to read y as we assume MaxPoolGrad re-runs
2250 // MaxPool internally.
2251 const int64_t input0_size =
2252 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2253 const int64_t input2_size =
2254 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2255 node_costs->num_input_bytes_accessed = {input0_size, 0, input2_size};
2256 // Write x_grad; size equal to x.
2257 const int64_t output_size =
2258 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2259 node_costs->num_output_bytes_accessed = {output_size};
2260 node_costs->max_memory = output_size;
2261
2262 if (found_unknown_shapes) {
2263 node_costs->inaccurate = true;
2264 node_costs->num_nodes_with_unknown_shapes = 1;
2265 }
2266 return Status::OK();
2267 }
2268
2269 /* This predict function handles three types of TensorFlow ops:
2270 * AssignVariableOp, AssignAddVariableOp, and AssignSubVariableOp. Broadcasting
2271 * is not possible for these ops, therefore the input tensor's shape is
2272 * enough to compute the cost. */
2273 Status OpLevelCostEstimator::PredictAssignVariableOps(
2274 const OpContext& op_context, NodeCosts* node_costs) const {
2275 bool found_unknown_shapes = false;
2276 const auto& op_info = op_context.op_info;
2277 /* The first input of these ops is a reference to the assignee. */
2278 if (op_info.inputs_size() != 2) {
2279 return errors::InvalidArgument("AssignVariable op has invalid input: ",
2280 op_info.ShortDebugString());
2281 }
2282
2283 const int64_t ops = op_info.op() == kAssignVariableOp
2284 ? 0
2285 : CalculateTensorElementCount(op_info.inputs(1),
2286 &found_unknown_shapes);
2287 node_costs->num_compute_ops = ops;
2288 const int64_t input_size = CalculateInputSize(op_info, &found_unknown_shapes);
2289 node_costs->num_input_bytes_accessed = {input_size};
2290 // TODO(dyoon): check these ops' behavior whether it writes data;
2291 // Op itself doesn't have output tensor, but it may modify the input (ref or
2292 // resource). Maybe use node_costs->internal_write_bytes.
2293 node_costs->num_output_bytes_accessed = {0};
2294 if (found_unknown_shapes) {
2295 node_costs->inaccurate = true;
2296 node_costs->num_nodes_with_unknown_shapes = 1;
2297 }
2298 return Status::OK();
2299 }
2300
2301 Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context,
2302 NodeCosts* node_costs) const {
2303 bool found_unknown_shapes = false;
2304 const auto& op_info = op_context.op_info;
2305 // x: op_info.inputs(0)
2306 ConvolutionDimensions dims = OpDimensionsFromInputs(
2307 op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2308
2309 // kx * ky - 1 additions and 1 multiplication per output.
2310 int64_t ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
2311 node_costs->num_compute_ops = ops;
2312
2313 int64_t input_size;
2314 if (dims.ky >= dims.sy) {
2315 input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2316 } else { // dims.ky < dims.sy
2317 // Vertical stride is larger than vertical kernel; assuming row-major
2318 // format, skip unnecessary rows (or read every ky rows per sy rows, as the
2319 // others are not used for output).
2320 const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2321 input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2322 }
2323 node_costs->num_input_bytes_accessed = {input_size};
2324
2325 const int64_t output_size =
2326 CalculateOutputSize(op_info, &found_unknown_shapes);
2327 node_costs->num_output_bytes_accessed = {output_size};
2328 node_costs->max_memory = output_size;
2329
2330 if (found_unknown_shapes) {
2331 node_costs->inaccurate = true;
2332 node_costs->num_nodes_with_unknown_shapes = 1;
2333 }
2334 return Status::OK();
2335 }
2336
2337 Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context,
2338 NodeCosts* node_costs) const {
2339 bool found_unknown_shapes = false;
2340 const auto& op_info = op_context.op_info;
2341 // x's shape: op_info.inputs(0)
2342 // y_grad: op_info.inputs(1)
2343
2344 // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
2345 bool shape_found = false;
2346 TensorShapeProto x_shape;
2347 if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
2348 const TensorProto& value = op_info.inputs(0).value();
2349 shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
2350 }
2351 if (!shape_found && op_info.outputs_size() > 0) {
2352 x_shape = op_info.outputs(0).shape();
2353 shape_found = true;
2354 }
2355 if (!shape_found) {
2356 // Set the minimum shape that's feasible.
2357 x_shape.Clear();
2358 for (int i = 0; i < 4; ++i) {
2359 x_shape.add_dim()->set_size(1);
2360 }
2361 found_unknown_shapes = true;
2362 }
2363
2364 ConvolutionDimensions dims =
2365 OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);
2366
2367 int64_t ops = 0;
2368 if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2369 // Non-overlapping window.
2370 ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
2371 } else {
2372 // Overlapping window.
2373 ops = dims.batch * dims.iz *
2374 (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
2375 }
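// Worked example (illustrative only): a 2x2 window with stride 2 on a
// [1, 4, 4, 1] input is non-overlapping, so
// ops = 1 * 1 * (4*4 + 2*2) = 20.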
2376 auto s = PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2377 node_costs);
2378 node_costs->max_memory = node_costs->num_total_output_bytes();
2379 return s;
2380 }
2381
2382 Status OpLevelCostEstimator::PredictFusedBatchNorm(
2383 const OpContext& op_context, NodeCosts* node_costs) const {
2384 bool found_unknown_shapes = false;
2385 const auto& op_info = op_context.op_info;
2386 // x: op_info.inputs(0)
2387 // scale: op_info.inputs(1)
2388 // offset: op_info.inputs(2)
2389 // mean: op_info.inputs(3) --> only for inference
2390 // variance: op_info.inputs(4) --> only for inference
2391 ConvolutionDimensions dims = OpDimensionsFromInputs(
2392 op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2393 const bool is_training = IsTraining(op_info);
2394
2395 int64_t ops = 0;
2396 const auto rsqrt_cost = Eigen::internal::functor_traits<
2397 Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2398 if (is_training) {
2399 ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
2400 } else {
2401 ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
2402 }
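// Worked example (illustrative only): for an NHWC input of shape
// [8, 32, 32, 16], inference costs 8 * 32 * 32 * 16 * 2 = 262144 ops, while
// training costs 16 * (8 * 32 * 32 * 4 + 6 + rsqrt_cost) ops.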
2403 node_costs->num_compute_ops = ops;
2404
2405 const int64_t size_nhwc =
2406 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2407 const int64_t size_c =
2408 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2409 if (is_training) {
2410 node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c};
2411 node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2412 size_c};
2413 // FusedBatchNorm in training mode internally re-reads the input tensor:
2414 // once for mean/variance, and a second time for the actual scaling.
2415 // Assume small intermediate data such as mean / variance (size_c) can be
2416 // cached on-chip.
2417 node_costs->internal_read_bytes = size_nhwc;
2418 } else {
2419 node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2420 size_c};
2421 node_costs->num_output_bytes_accessed = {size_nhwc};
2422 }
2423 node_costs->max_memory = node_costs->num_total_output_bytes();
2424
2425 if (found_unknown_shapes) {
2426 node_costs->inaccurate = true;
2427 node_costs->num_nodes_with_unknown_shapes = 1;
2428 }
2429 return Status::OK();
2430 }
2431
2432 Status OpLevelCostEstimator::PredictFusedBatchNormGrad(
2433 const OpContext& op_context, NodeCosts* node_costs) const {
2434 bool found_unknown_shapes = false;
2435 const auto& op_info = op_context.op_info;
2436 // y_backprop: op_info.inputs(0)
2437 // x: op_info.inputs(1)
2438 // scale: op_info.inputs(2)
2439 // mean: op_info.inputs(3)
2440 // variance or inverse of variance: op_info.inputs(4)
2441 ConvolutionDimensions dims = OpDimensionsFromInputs(
2442 op_info.inputs(1).shape(), op_info, &found_unknown_shapes);
2443
2444 int64_t ops = 0;
2445 const auto rsqrt_cost = Eigen::internal::functor_traits<
2446 Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2447 ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
2448 node_costs->num_compute_ops = ops;
2449
2450 const int64_t size_nhwc =
2451 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2452 const int64_t size_c =
2453 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2454 // TODO(dyoon): fix missing memory cost for variance input (size_c) and
2455 // yet another read of y_backprop (size_nhwc) internally.
2456 node_costs->num_input_bytes_accessed = {size_nhwc, size_nhwc, size_c, size_c};
2457 node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c};
2458 // FusedBatchNormGrad has to read y_backprop internally.
2459 node_costs->internal_read_bytes = size_nhwc;
2460 node_costs->max_memory = node_costs->num_total_output_bytes();
2461
2462 if (found_unknown_shapes) {
2463 node_costs->inaccurate = true;
2464 node_costs->num_nodes_with_unknown_shapes = 1;
2465 }
2466 return Status::OK();
2467 }
2468
2469 Status OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context,
2470 NodeCosts* node_costs) const {
2471 const auto& op_info = op_context.op_info;
2472 bool found_unknown_shapes = false;
2473 // Calculate the largest known tensor size across all inputs and output.
2474 int64_t op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
2475 // If output shape is available, try to use the element count calculated from
2476 // that.
2477 if (op_info.outputs_size() > 0) {
2478 op_count = std::max(
2479 op_count,
2480 CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
2481 }
2482 // Also calculate the output shape possibly resulting from broadcasting.
2483 // Note that some Nary ops (such as AddN) do not support broadcasting,
2484 // but we're including this here for completeness.
2485 if (op_info.inputs_size() >= 2) {
2486 op_count = std::max(op_count, CwiseOutputElementCount(op_info));
2487 }
2488
2489 // Nary ops perform one operation for every element in every input tensor.
2490 op_count *= op_info.inputs_size() - 1;
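// Worked example (illustrative only): AddN with 3 inputs of 1024 elements
// each gives op_count = 1024 * (3 - 1) = 2048, which is then scaled by the
// Eigen scalar_sum_op cost below.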
2491
2492 const auto sum_cost = Eigen::internal::functor_traits<
2493 Eigen::internal::scalar_sum_op<float>>::Cost;
2494 return PredictDefaultNodeCosts(op_count * sum_cost, op_context,
2495 &found_unknown_shapes, node_costs);
2496 }
2497
2498 // softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
2499 Status OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context,
2500 NodeCosts* node_costs) const {
2501 bool found_unknown_shapes = false;
2502 const int64_t logits_size = CalculateTensorElementCount(
2503 op_context.op_info.inputs(0), &found_unknown_shapes);
2504 // Softmax input rank should be >=1.
2505 TensorShapeProto logits_shape = op_context.op_info.inputs(0).shape();
2506 if (logits_shape.unknown_rank() || logits_shape.dim_size() == 0) {
2507 return errors::InvalidArgument("Softmax op has invalid input: ",
2508 op_context.op_info.ShortDebugString());
2509 }
2510
2511 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2512
2513 // Every element of <logits> will be exponentiated, have that result included
2514 // in a sum across j, and also have that result multiplied by the reciprocal
2515 // of the sum_j. In addition, we'll compute 1/sum_j for every i.
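// As a rough sketch of the count below (illustrative only): for logits of
// shape [B, C], ops ~= (exp_cost + add_cost + mul_cost) * B * C +
// inverse_cost * B, where each cost is the corresponding Eigen functor cost.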
2516 auto ops =
2517 (EIGEN_COST(scalar_exp_op<float>) + EIGEN_COST(scalar_sum_op<float>) +
2518 EIGEN_COST(scalar_product_op<float>)) *
2519 logits_size +
2520 EIGEN_COST(scalar_inverse_op<float>) * logits_shape.dim(0).size();
2521
2522 #undef EIGEN_COST
2523 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2524 node_costs);
2525 }
2526
2527 Status OpLevelCostEstimator::PredictResizeBilinear(
2528 const OpContext& op_context, NodeCosts* node_costs) const {
2529 bool found_unknown_shapes = false;
2530
2531 if (op_context.op_info.outputs().empty() ||
2532 op_context.op_info.inputs().empty()) {
2533 return errors::InvalidArgument(
2534 "ResizeBilinear op has invalid input / output ",
2535 op_context.op_info.ShortDebugString());
2536 }
2537
2538 const int64_t output_elements = CalculateTensorElementCount(
2539 op_context.op_info.outputs(0), &found_unknown_shapes);
2540
2541 const auto half_pixel_centers =
2542 op_context.op_info.attr().find("half_pixel_centers");
2543 bool use_half_pixel_centers = false;
2544 if (half_pixel_centers == op_context.op_info.attr().end()) {
2545 LOG(WARNING) << "half_pixel_centers attr not set for ResizeBilinear.";
2546 return PredictCostOfAnUnknownOp(op_context, node_costs);
2547 } else {
2548 use_half_pixel_centers = half_pixel_centers->second.b();
2549 }
2550
2551 // Compose cost of bilinear interpolation.
2552 int64_t ops = 0;
2553
2554 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2555 const auto sub_cost_float = EIGEN_COST(scalar_difference_op<float>);
2556 const auto sub_cost_int = EIGEN_COST(scalar_difference_op<int64>);
2557 const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2558 const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2559 const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2560 const auto max_cost = EIGEN_COST(scalar_max_op<int64>);
2561 const auto min_cost = EIGEN_COST(scalar_min_op<int64>);
2562 const auto cast_to_int_cost = Eigen::internal::functor_traits<
2563 Eigen::internal::scalar_cast_op<float, int64>>::Cost;
2564 const auto cast_to_float_cost = Eigen::internal::functor_traits<
2565 Eigen::internal::scalar_cast_op<int64, float>>::Cost;
2566 const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2567 #undef EIGEN_COST
2568
2569 // Ops calculated from tensorflow/core/kernels/image/resize_bilinear_op.cc.
2570
2571 // Op counts taken from resize_bilinear implementation on 07/21/2020.
2572 // Computed op counts may become inaccurate if resize_bilinear implementation
2573 // changes.
2574
2575 // resize_bilinear has an optimization where the interpolation weights are
2576 // precomputed and cached. Given input tensors of size [B,H1,W1,C] and output
2577 // tensors of size [B,H2,W2,C], the input rows and columns that need to be
2578 // accessed for interpolation are identical across the B and C dimensions at
2579 // every point in the output; they are cached in compute_interpolation_weights.
2580 // For a particular y in [0...H2-1], the rows to be accessed in the input are
2581 // the same. Likewise, for a particular x in [0...W2-1], the columns to be
2582 // accessed are the same. So the precomputation only needs to be done for H2 + W2
2583 // values.
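// Worked example (illustrative only): resizing to an output of shape
// [B, 8, 8, C] precomputes interpolation weights for only H2 + W2 = 16
// positions rather than for all 64 output (y, x) pairs.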
2584 const auto output_shape = MaybeGetMinimumShape(
2585 op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2586 // Assume H is dim 1 and W is dim 2 to match logic in resize_bilinear, which
2587 // also makes this assumption.
2588 const int64_t output_height = output_shape.dim(1).size();
2589 const int64_t output_width = output_shape.dim(2).size();
2590 // Add the ops done outside of the scaler function in
2591 // compute_interpolation_weights.
2592 int64_t interp_weight_cost = floor_cost + max_cost + min_cost +
2593 sub_cost_float + sub_cost_int + ceil_cost +
2594 cast_to_int_cost * 2;
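// These terms approximately correspond to the per-cached-weight work in
// compute_interpolation_weights: the floor/ceil of the scaled input
// coordinate, the max/min clamping of the lower/upper indices (with an integer
// subtraction for the upper bound), two casts to int64, and the lerp
// subtraction.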
2595 // There are two options for computing the weight of each pixel in the
2596 // interpolation: the algorithm can use either pixel centers or corners for
2597 // the weight. The op count depends on the scaler function passed into
2598 // compute_interpolation_weights.
2599 if (use_half_pixel_centers) {
2600 // Ops for HalfPixelScaler.
2601 interp_weight_cost +=
2602 add_cost + mul_cost + sub_cost_float + cast_to_float_cost;
2603 } else {
2604 // Ops for LegacyScaler.
2605 interp_weight_cost += cast_to_float_cost + mul_cost;
2606 }
2607 // Cost for the interpolation is multiplied by (H2 + W2), as mentioned above.
2608 ops += interp_weight_cost * (output_height + output_width);
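// For example (shape assumed for illustration): a [B, 224, 224, C] output
// charges interp_weight_cost only 224 + 224 = 448 times, independent of the
// batch and channel sizes.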
2609
2610 // Ops for computing the new values, done for every element. Logic is from
2611 // compute_lerp in the inner loop of resize_image which consists of:
2612 // const float top = top_left + (top_right - top_left) * x_lerp;
2613 // const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
2614 // return top + (bottom - top) * y_lerp;
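// That is, 3 additions, 3 subtractions, and 3 multiplications per output
// element, matching the coefficients below.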
2615 ops += (add_cost * 3 + sub_cost_float * 3 + mul_cost * 3) * output_elements;
2616
2617 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2618 node_costs);
2619 }
2620
2621 Status OpLevelCostEstimator::PredictCropAndResize(const OpContext& op_context,
2622 NodeCosts* node_costs) const {
2623 bool found_unknown_shapes = false;
2624
2625 const auto method = op_context.op_info.attr().find("method");
2626 bool use_bilinear_interp;
2627 if (method == op_context.op_info.attr().end() ||
2628 method->second.s() == "bilinear") {
2629 use_bilinear_interp = true;
2630 } else if (method->second.s() == "nearest") {
2631 use_bilinear_interp = false;
2632 } else {
2633 LOG(WARNING) << "method attr in CropAndResize invalid; expected bilinear "
2634 "or nearest.";
2635 return PredictCostOfAnUnknownOp(op_context, node_costs);
2636 }
2637
2638 const int64_t num_boxes = op_context.op_info.inputs(1).shape().dim(0).size();
2639 const auto crop_shape = MaybeGetMinimumShape(
2640 op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2641 const int64_t crop_height = crop_shape.dim(1).size();
2642 const int64_t crop_width = crop_shape.dim(2).size();
2643 const int64_t output_elements = CalculateTensorElementCount(
2644 op_context.op_info.outputs(0), &found_unknown_shapes);
2645
2646 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2647 const auto sub_cost = EIGEN_COST(scalar_difference_op<float>);
2648 const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2649 const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2650 const auto div_cost = EIGEN_COST(scalar_div_cost<float>);
2651 const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2652 const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2653 const auto round_cost = EIGEN_COST(scalar_round_op<float>);
2654 const auto cast_to_float_cost = Eigen::internal::functor_traits<
2655 Eigen::internal::scalar_cast_op<int64, float>>::Cost;
2656 #undef EIGEN_COST
2657
2658 // Op counts follow
2659 // tensorflow/core/kernels/image/crop_and_resize_op.cc as of 08/25/2020. The
2660 // calculation here differs from the rough estimate in the implementation, as
2661 // it separates out the per-box cost from the per-pixel and per-element costs.
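// Roughly, the total accumulated below is: a per-box term (num_boxes), a
// per-row term (num_boxes * crop_height), a per-pixel term
// (num_boxes * crop_height * crop_width), and a per-element term
// (output_elements).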
2662
2663 // Ops for variables height_scale and width_scale.
2664 int64_t ops = (sub_cost * 6 + mul_cost * 2 + div_cost * 2) * num_boxes;
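// (For illustration: each scale is of the form
// (v2 - v1) * (image_dim - 1) / (crop_dim - 1), i.e. roughly three
// subtractions, one multiplication, and one division per spatial dimension.)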
2665 // Ops for variable in_y.
2666 ops += (mul_cost * 2 + sub_cost + add_cost) * crop_height * num_boxes;
2667 // Ops for variable in_x (same computation across both branches).
2668 ops += (mul_cost * 2 + sub_cost + add_cost) * crop_height * crop_width *
2669 num_boxes;
2670 // Specify op_cost based on the method.
2671 if (use_bilinear_interp) {
2672 // Ops for variables top_y_index, bottom_y_index, y_lerp.
2673 ops += (floor_cost + ceil_cost + sub_cost) * crop_height * num_boxes;
2674 // Ops for variables left_x, right_x, x_lerp.
2675 ops += (floor_cost + ceil_cost + sub_cost) * crop_height * crop_width *
2676 num_boxes;
2677 // Ops for innermost loop across depth.
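// (The four corner pixel values are cast to float and combined with the same
// three-lerp computation described in PredictResizeBilinear, hence 4 casts
// plus 3 adds, subs, and muls per output element.)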
2678 ops +=
2679 (cast_to_float_cost * 4 + add_cost * 3 + sub_cost * 3 + mul_cost * 3) *
2680 output_elements;
2681 } else /* method == "nearest" */ {
2682 // Ops for variables closest_x_index and closest_y_index.
2683 ops += round_cost * 2 * crop_height * crop_width * num_boxes;
2684 // Ops for innermost loop across depth.
2685 ops += cast_to_float_cost * output_elements;
2686 }
2687 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2688 node_costs);
2689 }
2690
2691 } // end namespace grappler
2692 } // end namespace tensorflow
2693