• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
17 
18 #include <unordered_set>
19 
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/protobuf/device_properties.pb.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 
33 namespace {
34 
35 // TODO(dyoon): Consider to use this Test class for all the test cases, and then
36 // remove friend in the OpLevelCostEstimator class header.
37 class TestOpLevelCostEstimator : public OpLevelCostEstimator {
38  public:
TestOpLevelCostEstimator()39   TestOpLevelCostEstimator() {
40     compute_memory_overlap_ = true;
41     device_info_ = DeviceInfo();
42   }
~TestOpLevelCostEstimator()43   ~TestOpLevelCostEstimator() override {}
44 
SetDeviceInfo(const DeviceInfo & device_info)45   void SetDeviceInfo(const DeviceInfo& device_info) {
46     device_info_ = device_info;
47   }
48 
SetComputeMemoryOverlap(bool value)49   void SetComputeMemoryOverlap(bool value) { compute_memory_overlap_ = value; }
50 
51  protected:
GetDeviceInfo(const DeviceProperties & device) const52   DeviceInfo GetDeviceInfo(const DeviceProperties& device) const override {
53     return device_info_;
54   }
55 
56   DeviceInfo device_info_;
57 };
58 
ExpectZeroCost(const Costs & cost)59 void ExpectZeroCost(const Costs& cost) {
60   EXPECT_TRUE(cost.inaccurate);
61   EXPECT_EQ(cost.compute_time, Costs::Duration::zero());
62   EXPECT_EQ(cost.execution_time, Costs::Duration::zero());
63   EXPECT_EQ(cost.memory_time, Costs::Duration::zero());
64 }
65 
66 // Wrangles the minimum number of proto fields to set up a matrix.
DescribeMatrix(int rows,int columns,OpInfo * op_info)67 void DescribeMatrix(int rows, int columns, OpInfo* op_info) {
68   auto input = op_info->add_inputs();
69   auto shape = input->mutable_shape();
70   auto shape_rows = shape->add_dim();
71   shape_rows->set_size(rows);
72   auto shape_columns = shape->add_dim();
73   shape_columns->set_size(columns);
74   input->set_dtype(DT_FLOAT);
75 }
76 
SetCpuDevice(OpInfo * op_info)77 void SetCpuDevice(OpInfo* op_info) {
78   auto device = op_info->mutable_device();
79   device->set_type("CPU");
80   device->set_num_cores(10);
81   device->set_bandwidth(10000000);  // 10000000 KB/s = 10 GB/s
82   device->set_frequency(1000);      // 1000 Mhz = 1 GHz
83 }
84 
85 // Returns an OpInfo for MatMul with the minimum set of fields set up.
DescribeMatMul(int m,int n,int l,int k)86 OpContext DescribeMatMul(int m, int n, int l, int k) {
87   OpContext op_context;
88   SetCpuDevice(&op_context.op_info);
89   op_context.op_info.set_op("MatMul");
90 
91   DescribeMatrix(m, l, &op_context.op_info);
92   DescribeMatrix(k, n, &op_context.op_info);
93   return op_context;
94 }
95 
96 // Wrangles the minimum number of proto fields to set up an input of
97 // arbitrary rank and type.
DescribeArbitraryRankInput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)98 void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
99                                 OpInfo* op_info) {
100   auto input = op_info->add_inputs();
101   input->set_dtype(dtype);
102   auto shape = input->mutable_shape();
103   for (auto d : dims) {
104     shape->add_dim()->set_size(d);
105   }
106 }
107 
108 // Wrangles the minimum number of proto fields to set up an output of
109 // arbitrary rank and type.
DescribeArbitraryRankOutput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)110 void DescribeArbitraryRankOutput(const std::vector<int>& dims, DataType dtype,
111                                  OpInfo* op_info) {
112   auto output = op_info->add_outputs();
113   output->set_dtype(dtype);
114   auto shape = output->mutable_shape();
115   for (auto d : dims) {
116     shape->add_dim()->set_size(d);
117   }
118 }
119 
120 // Returns an OpInfo for a SparseTensorDenseMatMul
DescribeSparseTensorDenseMatMul(const int nnz_a,const std::vector<int> & dims_b,const std::vector<int> & dims_out)121 OpContext DescribeSparseTensorDenseMatMul(const int nnz_a,
122                                           const std::vector<int>& dims_b,
123                                           const std::vector<int>& dims_out) {
124   OpContext op_context;
125   SetCpuDevice(&op_context.op_info);
126   op_context.op_info.set_op("SparseTensorDenseMatMul");
127 
128   DescribeArbitraryRankInput({nnz_a, 2}, DT_INT64, &op_context.op_info);
129   DescribeArbitraryRankInput({nnz_a}, DT_FLOAT, &op_context.op_info);
130   DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
131   DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
132   DescribeArbitraryRankOutput(dims_out, DT_FLOAT, &op_context.op_info);
133   return op_context;
134 }
135 
136 // Returns an OpInfo for an XlaEinsum
DescribeXlaEinsum(const std::vector<int> & dims_a,const std::vector<int> & dims_b,const string & equation)137 OpContext DescribeXlaEinsum(const std::vector<int>& dims_a,
138                             const std::vector<int>& dims_b,
139                             const string& equation) {
140   OpContext op_context;
141   SetCpuDevice(&op_context.op_info);
142   op_context.op_info.set_op("XlaEinsum");
143   AttrValue equation_attribute;
144   equation_attribute.set_s(equation);
145   (*op_context.op_info.mutable_attr())["equation"] = equation_attribute;
146   if (!dims_a.empty())
147     DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
148   if (!dims_b.empty())
149     DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
150   return op_context;
151 }
152 
153 // Returns an OpInfo for an Einsum
DescribeEinsum(const std::vector<int> & dims_a,const std::vector<int> & dims_b,const string & equation)154 OpContext DescribeEinsum(const std::vector<int>& dims_a,
155                          const std::vector<int>& dims_b,
156                          const string& equation) {
157   OpContext op_context = DescribeXlaEinsum(dims_a, dims_b, equation);
158   op_context.op_info.set_op("Einsum");
159   return op_context;
160 }
161 
DescribeDummyTensor(OpInfo::TensorProperties * tensor)162 void DescribeDummyTensor(OpInfo::TensorProperties* tensor) {
163   // Intentionally leave the tensor shape and type information missing.
164 }
165 
166 // Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
167 // estimation purposes.
DescribeTensor1D(int dim0,OpInfo::TensorProperties * tensor)168 void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
169   auto shape = tensor->mutable_shape();
170   shape->add_dim()->set_size(dim0);
171   tensor->set_dtype(DT_FLOAT);
172 }
173 
174 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
175 // estimation purposes.
DescribeTensor4D(int dim0,int dim1,int dim2,int dim3,OpInfo::TensorProperties * tensor)176 void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
177                       OpInfo::TensorProperties* tensor) {
178   auto shape = tensor->mutable_shape();
179   shape->add_dim()->set_size(dim0);
180   shape->add_dim()->set_size(dim1);
181   shape->add_dim()->set_size(dim2);
182   shape->add_dim()->set_size(dim3);
183   tensor->set_dtype(DT_FLOAT);
184 }
185 
186 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
187 // estimation purposes.
DescribeTensor5D(int dim0,int dim1,int dim2,int dim3,int dim4,OpInfo::TensorProperties * tensor)188 void DescribeTensor5D(int dim0, int dim1, int dim2, int dim3, int dim4,
189                       OpInfo::TensorProperties* tensor) {
190   auto shape = tensor->mutable_shape();
191   shape->add_dim()->set_size(dim0);
192   shape->add_dim()->set_size(dim1);
193   shape->add_dim()->set_size(dim2);
194   shape->add_dim()->set_size(dim3);
195   shape->add_dim()->set_size(dim4);
196   tensor->set_dtype(DT_FLOAT);
197 }
198 
199 // DescribeConvolution constructs an OpContext for a Conv2D applied to an input
200 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
201 // (kx, ky, iz2, oz).
DescribeConvolution(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int oz)202 OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
203                               int kx, int ky, int oz) {
204   OpContext op_context;
205   SetCpuDevice(&op_context.op_info);
206   op_context.op_info.set_op("Conv2D");
207 
208   DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
209   DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
210 
211   return op_context;
212 }
213 
214 // Describe DepthwiseConvolution constructs an OpContext for a
215 // DepthwiseConv2dNative applied to an input
216 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
217 // (kx, ky, iz2, cm). cm is channel multiplier
218 
DescribeDepthwiseConv2dNative(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int cm)219 OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
220                                         int iz2, int kx, int ky, int cm) {
221   OpContext op_context;
222   SetCpuDevice(&op_context.op_info);
223   op_context.op_info.set_op("DepthwiseConv2dNative");
224 
225   DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
226   DescribeTensor4D(kx, ky, iz2, cm, op_context.op_info.add_inputs());
227 
228   return op_context;
229 }
230 
231 // DescribeFusedConv2DBiasActivation constructs an OpContext for a
232 // FusedConv2DBiasActivation applied to a convolution input tensor with shape
233 // (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
234 // bias tensor with shape (oz), a side input tensor with shape
235 // (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
236 // shape (1). If a vectorized channel format is chosen (NCHW_VECT_C, e.g.) we'll
237 // default to 4 (the vector size most often used with this format on NVIDIA
238 // platforms) for the major channel size, and divide the input channel size by
239 // that amount.
240 //
241 // Note that this assumes the NHWC data format.
DescribeFusedConv2DBiasActivation(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int ox,int oy,int oz,bool has_side_input,const string & data_format,const string & filter_format)242 OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
243                                             int iz2, int kx, int ky, int ox,
244                                             int oy, int oz, bool has_side_input,
245                                             const string& data_format,
246                                             const string& filter_format) {
247   const int kVecWidth = 4;
248   OpContext op_context;
249   SetCpuDevice(&op_context.op_info);
250   op_context.op_info.set_op("FusedConv2DBiasActivation");
251   auto* attr_data_format = op_context.op_info.mutable_attr();
252   SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
253   auto* attr_filter_format = op_context.op_info.mutable_attr();
254   SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
255   if (data_format == "NHWC") {
256     DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
257   } else if (data_format == "NCHW") {
258     DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
259   } else {
260     // Use the NCHW_VECT_C format.
261     EXPECT_EQ(data_format, "NCHW_VECT_C");
262     EXPECT_EQ(iz1 % kVecWidth, 0);
263     DescribeTensor5D(batch, iz1 / kVecWidth, ix, iy, kVecWidth,
264                      op_context.op_info.add_inputs());
265   }
266   if (filter_format == "HWIO") {
267     DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
268   } else if (filter_format == "OIHW") {
269     DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
270   } else {
271     EXPECT_EQ(filter_format, "OIHW_VECT_I");
272     EXPECT_EQ(iz2 % kVecWidth, 0);
273     // Use the OIHW_VECT_I format.
274     DescribeTensor5D(oz, iz2 / kVecWidth, kx, ky, kVecWidth,
275                      op_context.op_info.add_inputs());
276   }
277   DescribeTensor1D(oz, op_context.op_info.add_inputs());
278 
279   // Add the side_input, if any.
280   auto side_input = op_context.op_info.add_inputs();
281   if (has_side_input) {
282     if (data_format == "NHWC") {
283       DescribeTensor4D(batch, ox, oy, oz, side_input);
284     } else if (data_format == "NCHW") {
285       DescribeTensor4D(batch, oz, ox, oy, side_input);
286     } else {
287       // Use the NCHW_VECT_C format.
288       EXPECT_EQ(data_format, "NCHW_VECT_C");
289       EXPECT_EQ(oz % kVecWidth, 0);
290       DescribeTensor5D(batch, oz / kVecWidth, ox, oy, kVecWidth, side_input);
291     }
292   }
293 
294   // Add the scaling tensors.
295   DescribeTensor1D(1, op_context.op_info.add_inputs());
296   DescribeTensor1D(1, op_context.op_info.add_inputs());
297 
298   return op_context;
299 }
300 
301 // DescribeUnaryOp constructs an OpContext for the given operation applied to
302 // a 4-tensor with shape (size1, 1, 1, 1).
DescribeUnaryOp(const string & op,int size1)303 OpContext DescribeUnaryOp(const string& op, int size1) {
304   OpContext op_context;
305   SetCpuDevice(&op_context.op_info);
306   op_context.op_info.set_op(op);
307 
308   DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
309   DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_outputs());
310 
311   return op_context;
312 }
313 
314 // DescribeBinaryOp constructs an OpContext for the given operation applied to
315 // a 4-tensor with dimensions (size1, 1, 1, 1) and a 4-tensor with dimensions
316 // (2 * size1, size2, 1, 1).
317 //
318 // The choice of dimension here is arbitrary, and is used strictly to test the
319 // cost model for applying elementwise operations to tensors with unequal
320 // dimension values.
DescribeBinaryOp(const string & op,int size1,int size2)321 OpContext DescribeBinaryOp(const string& op, int size1, int size2) {
322   OpContext op_context;
323   SetCpuDevice(&op_context.op_info);
324   op_context.op_info.set_op(op);
325 
326   DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
327   DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_inputs());
328   DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_outputs());
329 
330   return op_context;
331 }
332 
333 // DescribeBiasAdd constructs an OpContext for a BiasAdd applied to a 4-tensor
334 // with dimensions (1, 1, size2, size1) and a bias with dimension (size1),
335 // according to the constraint that the bias must be 1D with size equal to that
336 // of the last dimension of the input value.
DescribeBiasAdd(int size1,int size2)337 OpContext DescribeBiasAdd(int size1, int size2) {
338   OpContext op_context;
339   SetCpuDevice(&op_context.op_info);
340   op_context.op_info.set_op("BiasAdd");
341 
342   DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
343   DescribeTensor1D(size1, op_context.op_info.add_inputs());
344   DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());
345 
346   return op_context;
347 }
348 
// Returns the output extent of a sliding-window op along one spatial axis.
// `x` is the input extent, `k` the window size, and `s` the stride.
//   - "SAME" padding yields (x + s - 1) / s, i.e. ceil(x / s).
//   - Any other padding string (i.e. "VALID") yields (x - k + s) / s, i.e.
//     ceil((x - k + 1) / s) for x >= k - 1.
int GetOutputSize(const int x, const int k, const int s,
                  const string& padding) {
  if (padding != "SAME") {
    return (x - k + s) / s;
  }
  return (x + s - 1) / s;
}
357 
GetPoolingOutputSize(const std::vector<int> & input,const std::vector<int> & ksize,const std::vector<int> & strides,const string & data_format,const string & padding)358 std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
359                                       const std::vector<int>& ksize,
360                                       const std::vector<int>& strides,
361                                       const string& data_format,
362                                       const string& padding) {
363   // h, w, and c indices: default with NHWC.
364   int h_index = 1;
365   int w_index = 2;
366   int c_index = 3;
367   if (data_format == "NCHW") {
368     h_index = 2;
369     w_index = 3;
370     c_index = 1;
371   }
372   // Extract parameters.
373   int n = input[0];
374   int h = input[h_index];
375   int w = input[w_index];
376   int c = input[c_index];
377   int sx = strides[h_index];
378   int sy = strides[w_index];
379   int kx = ksize[h_index];
380   int ky = ksize[w_index];
381 
382   // Output activation size: default with VALID padding.
383   int ho = GetOutputSize(h, kx, sx, padding);
384   int wo = GetOutputSize(w, ky, sy, padding);
385 
386   std::vector<int> output;
387   if (data_format == "NHWC") {
388     output = {n, ho, wo, c};
389   } else {
390     output = {n, c, ho, wo};
391   }
392   return output;
393 }
394 
395 // Helper functions for testing GetTensorShapeProtoFromTensorProto().
GetTensorProto(const DataType dtype,const std::vector<int64> & shape,const std::vector<int64> values,const bool tensor_content,TensorProto * tensor_proto)396 void GetTensorProto(const DataType dtype, const std::vector<int64>& shape,
397                     const std::vector<int64> values, const bool tensor_content,
398                     TensorProto* tensor_proto) {
399   tensor_proto->Clear();
400   TensorProto temp_tensor_proto;
401   temp_tensor_proto.set_dtype(dtype);
402   for (const auto& x : shape) {
403     temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
404   }
405   for (const auto& x : values) {
406     if (dtype == DT_INT64) {
407       temp_tensor_proto.add_int64_val(x);
408     } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
409                dtype == DT_UINT8) {
410       temp_tensor_proto.add_int_val(x);
411     } else if (dtype == DT_UINT32) {
412       temp_tensor_proto.add_uint32_val(x);
413     } else if (dtype == DT_UINT64) {
414       temp_tensor_proto.add_uint64_val(x);
415     } else {
416       CHECK(false) << "Unsupported dtype: " << dtype;
417     }
418   }
419   Tensor tensor(dtype);
420   CHECK(tensor.FromProto(temp_tensor_proto));
421   if (tensor_content) {
422     tensor.AsProtoTensorContent(tensor_proto);
423   } else {
424     tensor.AsProtoField(tensor_proto);
425   }
426 }
427 
DescribePoolingOp(const string & op_name,const std::vector<int> & x,const std::vector<int> & ksize,const std::vector<int> & strides,const string & data_format,const string & padding)428 OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
429                             const std::vector<int>& ksize,
430                             const std::vector<int>& strides,
431                             const string& data_format, const string& padding) {
432   OpContext op_context;
433   auto& op_info = op_context.op_info;
434   SetCpuDevice(&op_info);
435   op_info.set_op(op_name);
436 
437   const std::vector<int> y =
438       GetPoolingOutputSize(x, ksize, strides, data_format, padding);
439   if (op_name == "AvgPool" || op_name == "MaxPool") {
440     // input: x, output: y.
441     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
442     DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
443   } else if (op_name == "AvgPoolGrad") {
444     // input: x's shape, y_grad, output: x_grad.
445     DescribeArbitraryRankInput({4}, DT_INT32, &op_info);
446     auto* tensor_proto = op_info.mutable_inputs(0)->mutable_value();
447     GetTensorProto(DT_INT32, {4}, {x[0], x[1], x[2], x[3]},
448                    /*tensor_content=*/false, tensor_proto);
449     DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
450     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
451   } else if (op_name == "MaxPoolGrad") {
452     // input: x, y, y_grad, output: x_grad.
453     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
454     DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
455     DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
456     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
457   }
458   auto* attr = op_info.mutable_attr();
459   SetAttrValue(data_format, &(*attr)["data_format"]);
460   SetAttrValue(padding, &(*attr)["padding"]);
461   SetAttrValue(strides, &(*attr)["strides"]);
462   SetAttrValue(ksize, &(*attr)["ksize"]);
463   return op_context;
464 }
465 
DescribeFusedBatchNorm(const bool is_training,const bool is_grad,const std::vector<int> & x,const string & data_format)466 OpContext DescribeFusedBatchNorm(const bool is_training, const bool is_grad,
467                                  const std::vector<int>& x,
468                                  const string& data_format) {
469   // First, get MaxPool op info with unit stride and unit window.
470   OpContext op_context = DescribePoolingOp("MaxPool", x, {1, 1, 1, 1},
471                                            {1, 1, 1, 1}, data_format, "SAME");
472   auto& op_info = op_context.op_info;
473   // Override op name.
474   if (is_grad) {
475     op_info.set_op("FusedBatchNormGrad");
476   } else {
477     op_info.set_op("FusedBatchNorm");
478   }
479 
480   // Add additional input output tensors.
481   if (is_grad) {
482     DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
483   }
484   int num_1d_inputs = is_grad ? 3 : 4;
485   for (int i = 0; i < num_1d_inputs; i++) {
486     auto* tensor = op_info.add_inputs();
487     auto* shape = tensor->mutable_shape();
488     shape->add_dim()->set_size(x[3]);
489     tensor->set_dtype(DT_FLOAT);
490   }
491   for (int i = 0; i < 4; i++) {
492     auto* tensor = op_info.add_outputs();
493     auto* shape = tensor->mutable_shape();
494     shape->add_dim()->set_size(x[3]);
495     tensor->set_dtype(DT_FLOAT);
496   }
497 
498   // Delete unnecessary attr.
499   auto* attr = op_context.op_info.mutable_attr();
500   attr->erase("ksize");
501   attr->erase("strides");
502   attr->erase("padding");
503 
504   // Additional attrs for FusedBatchNorm.
505   SetAttrValue(is_training, &(*attr)["is_training"]);
506 
507   return op_context;
508 }
509 }  // namespace
510 
511 class OpLevelCostEstimatorTest : public ::testing::Test {
512  protected:
513   using BatchMatMulDimensions = OpLevelCostEstimator::BatchMatMulDimensions;
514 
PredictCosts(const OpContext & op_context) const515   Costs PredictCosts(const OpContext& op_context) const {
516     return estimator_.PredictCosts(op_context);
517   }
518 
CountMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const519   int64 CountMatMulOperations(const OpInfo& op_info,
520                               bool* found_unknown_shapes) const {
521     return estimator_.CountMatMulOperations(op_info, found_unknown_shapes);
522   }
523 
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const524   int64 CountBatchMatMulOperations(const OpInfo& op_info,
525                                    bool* found_unknown_shapes) const {
526     return estimator_.CountBatchMatMulOperations(op_info, found_unknown_shapes);
527   }
528 
CountBatchMatMulOperations(const OpInfo & op_info,BatchMatMulDimensions * batch_mat_mul,bool * found_unknown_shapes) const529   int64 CountBatchMatMulOperations(const OpInfo& op_info,
530                                    BatchMatMulDimensions* batch_mat_mul,
531                                    bool* found_unknown_shapes) const {
532     return estimator_.CountBatchMatMulOperations(op_info, batch_mat_mul,
533                                                  found_unknown_shapes);
534   }
535 
SetComputeMemoryOverlap(bool value)536   void SetComputeMemoryOverlap(bool value) {
537     estimator_.compute_memory_overlap_ = value;
538   }
539 
ValidateOpDimensionsFromInputs(const int n,const int h,const int w,const int c,const int kx,const int ky,const int sx,const int sy,const string & data_format,const string & padding)540   void ValidateOpDimensionsFromInputs(const int n, const int h, const int w,
541                                       const int c, const int kx, const int ky,
542                                       const int sx, const int sy,
543                                       const string& data_format,
544                                       const string& padding) {
545     OpContext op_context;
546     int ho;
547     int wo;
548     if (data_format == "NHWC") {
549       op_context = DescribePoolingOp("MaxPool", {n, h, w, c}, {1, kx, ky, 1},
550                                      {1, sx, sy, 1}, "NHWC", padding);
551       ho = op_context.op_info.outputs(0).shape().dim(1).size();
552       wo = op_context.op_info.outputs(0).shape().dim(2).size();
553     } else {
554       op_context = DescribePoolingOp("MaxPool", {n, c, h, w}, {1, 1, kx, ky},
555                                      {1, 1, sx, sy}, "NCHW", padding);
556       ho = op_context.op_info.outputs(0).shape().dim(2).size();
557       wo = op_context.op_info.outputs(0).shape().dim(3).size();
558     }
559 
560     bool found_unknown_shapes;
561     auto dims = OpLevelCostEstimator::OpDimensionsFromInputs(
562         op_context.op_info.inputs(0).shape(), op_context.op_info,
563         &found_unknown_shapes);
564     Padding padding_enum;
565     if (padding == "VALID") {
566       padding_enum = Padding::VALID;
567     } else {
568       padding_enum = Padding::SAME;
569     }
570     EXPECT_EQ(n, dims.batch);
571     EXPECT_EQ(h, dims.ix);
572     EXPECT_EQ(w, dims.iy);
573     EXPECT_EQ(c, dims.iz);
574     EXPECT_EQ(kx, dims.kx);
575     EXPECT_EQ(ky, dims.ky);
576     EXPECT_EQ(sx, dims.sx);
577     EXPECT_EQ(sy, dims.sy);
578     EXPECT_EQ(ho, dims.ox);
579     EXPECT_EQ(wo, dims.oy);
580     EXPECT_EQ(c, dims.oz);
581     EXPECT_EQ(padding_enum, dims.padding);
582   }
583 
584   OpLevelCostEstimator estimator_;
585 };
586 
587 class OpLevelBatchMatMulCostEstimatorTest
588     : public OpLevelCostEstimatorTest,
589       public ::testing::WithParamInterface<const char*> {
590  protected:
591   // Returns an OpInfo for a BatchMatMul
DescribeBatchMatMul(const std::vector<int> & dims_a,const std::vector<int> & dims_b)592   OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
593                                 const std::vector<int>& dims_b) {
594     OpContext op_context;
595     SetCpuDevice(&op_context.op_info);
596     op_context.op_info.set_op(GetParam());
597 
598     DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
599     DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
600     return op_context;
601   }
602 
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const603   int64 CountBatchMatMulOperations(const OpInfo& op_info,
604                                    bool* found_unknown_shapes) const {
605     return OpLevelCostEstimatorTest::CountBatchMatMulOperations(
606         op_info, found_unknown_shapes);
607   }
608 
CountBatchMatMulDimProduct(const OpInfo & op_info,bool * found_unknown_shapes) const609   int64 CountBatchMatMulDimProduct(const OpInfo& op_info,
610                                    bool* found_unknown_shapes) const {
611     BatchMatMulDimensions batch_mat_mul;
612 
613     batch_mat_mul.matmul_dims.n = 0;
614     batch_mat_mul.matmul_dims.m = 0;
615     batch_mat_mul.matmul_dims.k = 0;
616 
617     OpLevelCostEstimatorTest::CountBatchMatMulOperations(
618         op_info, &batch_mat_mul, found_unknown_shapes);
619     int dimension_product = 1;
620     for (auto dim : batch_mat_mul.batch_dims) dimension_product *= dim;
621 
622     dimension_product *= batch_mat_mul.matmul_dims.n;
623     dimension_product *= batch_mat_mul.matmul_dims.m;
624     dimension_product *= batch_mat_mul.matmul_dims.k;
625 
626     return dimension_product;
627   }
628 };
629 
TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
  // Every persistent op is expected to get the same minimal cost:
  //   - compute and execution time of Duration(1);
  //   - zero memory time;
  //   - no temporary or persistent memory.
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  const std::unordered_set<string> persistent_op_names = {
      "Const",       "Variable",       "VariableV2", "AutoReloadVariable",
      "VarHandleOp", "ReadVariableOp",
  };
  for (const string& op_name : persistent_op_names) {
    op_context.op_info.set_op(op_name);
    const Costs cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(0), cost.memory_time);
    EXPECT_EQ(Costs::Duration(1), cost.compute_time);
    EXPECT_EQ(Costs::Duration(1), cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
651 
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
  for (const char* op_name : {"Gather", "GatherNd", "GatherV2"}) {
    OpContext op_context;
    OpInfo& info = op_context.op_info;
    SetCpuDevice(&info);
    info.set_op(op_name);

    // Huge first input shouldn't affect Gather execution and memory costs.
    DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &info);  // Params.
    DescribeArbitraryRankInput({16}, DT_INT64, &info);            // Indices.
    DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &info);

    const Costs cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(130), cost.memory_time);
    EXPECT_EQ(Costs::Duration(16), cost.compute_time);
    EXPECT_EQ(Costs::Duration(146), cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
676 
TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
  // With no output described, the estimate is expected to be all-zero and
  // flagged inaccurate.
  OpContext op_context;
  OpInfo& info = op_context.op_info;
  SetCpuDevice(&info);
  info.set_op("Gather");

  // Huge first input shouldn't affect Gather execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &info);
  DescribeArbitraryRankInput({16}, DT_INT64, &info);

  const Costs cost = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(0), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(0), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}
696 
TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Slice");

  // The huge first input shouldn't affect Slice execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  // Followed by two 2-element int64 inputs.
  for (int extra_input = 0; extra_input < 2; ++extra_input) {
    DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  }
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);

  constexpr int kMemoryTime = 81;
  constexpr int kComputeTime = 10;
  const auto costs = estimator_.PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(kMemoryTime), costs.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), costs.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, costs.temporary_memory);
  EXPECT_EQ(0, costs.persistent_memory);
}
718 
TEST_F(OpLevelCostEstimatorTest, TestStridedSliceCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("StridedSlice");

  // The huge first input shouldn't affect StridedSlice execution and memory
  // costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  // Followed by three 2-element int64 inputs.
  for (int extra_input = 0; extra_input < 3; ++extra_input) {
    DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  }
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);

  constexpr int kMemoryTime = 81;
  constexpr int kComputeTime = 10;
  const auto costs = estimator_.PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(kMemoryTime), costs.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), costs.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, costs.temporary_memory);
  EXPECT_EQ(0, costs.persistent_memory);
}
741 
TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
  const std::vector<std::string> scatter_ops = {
      "ScatterAdd", "ScatterDiv", "ScatterMax",   "ScatterMin",
      "ScatterMul", "ScatterSub", "ScatterUpdate"};
  for (const auto& op : scatter_ops) {
    // Name the op in every failure message so iterations can be told apart.
    SCOPED_TRACE(op);

    // Test updates.shape = indices.shape + ref.shape[1:]
    {
      OpContext op_context;
      SetCpuDevice(&op_context.op_info);
      op_context.op_info.set_op(op);
      // Huge first dimension in input shouldn't affect Scatter execution and
      // memory costs.
      DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
      DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
      DescribeArbitraryRankInput({16, 10}, DT_FLOAT, &op_context.op_info);
      DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT,
                                  &op_context.op_info);

      auto cost = estimator_.PredictCosts(op_context);
      EXPECT_EQ(Costs::Duration(205), cost.memory_time);
      EXPECT_EQ(Costs::Duration(16), cost.compute_time);
      EXPECT_EQ(Costs::Duration(221), cost.execution_time);
      EXPECT_EQ(cost.num_ops_total, 1);
      EXPECT_FALSE(cost.inaccurate);
      EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
      EXPECT_EQ(cost.temporary_memory, 0);
      EXPECT_EQ(cost.persistent_memory, 0);
    }

    // Test updates.shape = [] and INT32 indices. Checks the same fields as
    // the case above, including the memory fields.
    {
      OpContext op_context;
      SetCpuDevice(&op_context.op_info);
      op_context.op_info.set_op(op);
      // Huge first dimension in input shouldn't affect Scatter execution and
      // memory costs.
      DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
      DescribeArbitraryRankInput({16}, DT_INT32, &op_context.op_info);
      DescribeArbitraryRankInput({}, DT_FLOAT, &op_context.op_info);
      DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT,
                                  &op_context.op_info);

      auto cost = estimator_.PredictCosts(op_context);
      EXPECT_EQ(Costs::Duration(135), cost.memory_time);
      EXPECT_EQ(Costs::Duration(16), cost.compute_time);
      EXPECT_EQ(Costs::Duration(151), cost.execution_time);
      EXPECT_EQ(1, cost.num_ops_total);
      EXPECT_FALSE(cost.inaccurate);
      EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
      EXPECT_EQ(0, cost.temporary_memory);
      EXPECT_EQ(0, cost.persistent_memory);
    }
  }
}
794 
TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
  constexpr int kMemoryTime = 8400;
  constexpr int kComputeTime = 1000;
  const auto cost = PredictCosts(DescribeBiasAdd(1000, 10));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
806 
TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) {
  constexpr int kMemoryTime = 233780;
  constexpr int kComputeTime = 354877440;
  const auto cost =
      PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
818 
TEST_F(OpLevelCostEstimatorTest, InvalidConv2DConfig) {
  // Convolution ops.
  const std::vector<std::string> conv_ops = {
      "Conv2D",
      "Conv2DBackpropFilter",
      "Conv2DBackpropInput",
      "DepthwiseConv2dNative",
      "DepthwiseConv2dNativeBackpropFilter",
      "DepthwiseConv2dNativeBackpropInput",
  };
  // A valid Conv2D config.
  const std::vector<int> valid_conv_config = {16, 19, 19, 48, 48, 5, 5, 256};
  for (const auto& op : conv_ops) {
    SCOPED_TRACE(op);
    // Set one value in the conv config to zero at a time.
    // PredictCosts() should return zero costs and report unknown shapes.
    // size_t index: avoids a signed/unsigned comparison against size().
    for (size_t i = 0; i < valid_conv_config.size(); ++i) {
      SCOPED_TRACE("conv_config[" + std::to_string(i) + "] = 0");
      std::vector<int> conv_config(valid_conv_config);
      conv_config[i] = 0;
      auto op_context = DescribeConvolution(
          conv_config[0], conv_config[1], conv_config[2], conv_config[3],
          conv_config[4], conv_config[5], conv_config[6], conv_config[7]);
      op_context.op_info.set_op(op);
      auto cost = PredictCosts(op_context);
      EXPECT_EQ(Costs::Duration(0), cost.memory_time);
      EXPECT_EQ(Costs::Duration(0), cost.compute_time);
      EXPECT_EQ(Costs::Duration(0), cost.execution_time);
      EXPECT_EQ(1, cost.num_ops_total);
      EXPECT_TRUE(cost.inaccurate);
      EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
    }
  }
}
851 
TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) {
  constexpr int kMemoryTime = 112340;
  constexpr int kComputeTime = 4158720;
  const auto cost =
      PredictCosts(DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
864 
TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
  // "Dummy" is charged for memory traffic only (zero compute), and the
  // estimate is flagged as inaccurate.
  constexpr int kMemoryTime = 2000;
  const auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
876 
TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
  SetComputeMemoryOverlap(true);

  // "Dummy" has zero compute time, so max and sum coincide here.
  auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);  // max(2000, 0)
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);

  // "Mul" has non-zero compute time, so with overlap enabled the execution
  // time must be the max (2000), not the sum (2200, see MulExecutionTime).
  auto mul_cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), mul_cost.memory_time);
  EXPECT_EQ(Costs::Duration(200), mul_cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), mul_cost.execution_time);  // max(2000, 200)
  EXPECT_EQ(mul_cost.num_ops_total, 1);
  EXPECT_FALSE(mul_cost.inaccurate);

  SetComputeMemoryOverlap(false);  // Set it back to default.
}
890 
TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
  // NCHW data, HWIO filter, and no side input.
  constexpr int kMemoryTime = 825345;
  constexpr int kComputeTime = 355321037;
  const auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
      "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
905 
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
  // NCHW data, HWIO filter, with a side input.
  constexpr int kMemoryTime = 1416808;
  constexpr int kComputeTime = 355616768;
  const auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
919 
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
  // NCHW data, OIHW filter, with a side input.
  constexpr int kMemoryTime = 1416808;
  constexpr int kComputeTime = 355616768;
  const auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW"));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
933 
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
  // NHWC data, HWIO filter, with a side input.
  constexpr int kMemoryTime = 1416808;
  constexpr int kComputeTime = 355616768;
  const auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "HWIO"));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
947 
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
  // NHWC data, OIHW filter, with a side input.
  constexpr int kMemoryTime = 1416808;
  constexpr int kComputeTime = 355616768;
  const auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "OIHW"));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
961 
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
  // NCHW_VECT_C data, OIHW filter, with a side input.
  constexpr int kMemoryTime = 1416808;
  constexpr int kComputeTime = 355616768;
  const auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW_VECT_C", "OIHW"));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
975 
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
  // NCHW data, OIHW_VECT_I filter, with a side input.
  constexpr int kMemoryTime = 1416808;
  constexpr int kComputeTime = 355616768;
  const auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW_VECT_I"));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
989 
TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_VECT_C_OIHW_VECT_I) {
  // NCHW_VECT_C data, OIHW_VECT_I filter, with a side input.
  constexpr int kMemoryTime = 1416808;
  constexpr int kComputeTime = 355616768;
  const auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW_VECT_C", "OIHW_VECT_I"));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1004 
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
  constexpr int kMemoryTime = 2000;
  constexpr int kComputeTime = 200;
  const auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1016 
TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) {
  constexpr int kMemoryTime = 3600;
  constexpr int kComputeTime = 400;
  const auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 2));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1028 
TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) {
  constexpr int kMemoryTime = 2000;
  constexpr int kComputeTime = 1600;
  const auto cost = PredictCosts(DescribeBinaryOp("Mod", 1000, 1));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1040 
TEST_F(OpLevelCostEstimatorTest, SquaredDifferenceExecutionTime) {
  constexpr int kMemoryTime = 3600;
  constexpr int kComputeTime = 800;
  const auto cost =
      PredictCosts(DescribeBinaryOp("SquaredDifference", 1000, 2));
  EXPECT_EQ(Costs::Duration(kMemoryTime), cost.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), cost.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
1052 
TEST_F(OpLevelCostEstimatorTest, UnaryOpExecutionTime) {
  // Each op is paired with the per-element operation count that the expected
  // compute time below is scaled by (e.g. 43 for Softmax).
  const std::vector<std::pair<std::string, int>> unary_ops = {
      {"All", 1},      {"ArgMax", 1}, {"Cast", 1},  {"Max", 1},
      {"Min", 1},      {"Prod", 1},   {"Relu", 1},  {"Relu6", 1},
      {"Softmax", 43}, {"Sum", 1},    {"TopKV2", 1}};

  const int kTensorSize = 1000;
  // Bind by const reference: avoids copying each pair<string, int>.
  for (const auto& unary_op : unary_ops) {
    // Attach the op name to every failure in this iteration.
    SCOPED_TRACE(unary_op.first);
    OpContext op_context = DescribeUnaryOp(unary_op.first, kTensorSize);

    const int kExpectedMemoryTime = 800;
    // Ops divided by the device's gigaops, rounded up with std::ceil; the
    // double result is converted to int explicitly.
    const int expected_compute_time = static_cast<int>(std::ceil(
        unary_op.second * kTensorSize /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops));

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(expected_compute_time + kExpectedMemoryTime));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
1081 
TEST_F(OpLevelCostEstimatorTest, BinaryOpExecutionTime) {
  // Each op is paired with the multiplier applied to its expected compute
  // time below.
  const std::vector<std::pair<std::string, int>> binary_ops = {
      {"Select", 1},
      {"SelectV2", 1},
      {"SquaredDifference", 2},
      {"Where", 1},
  };

  const int kTensorSize1 = 1000;
  const int kTensorSize2 = 2;
  // Bind by const reference: avoids copying each pair<string, int>.
  for (const auto& binary_op : binary_ops) {
    // Attach the op name to every failure in this iteration.
    SCOPED_TRACE(binary_op.first);
    OpContext op_context =
        DescribeBinaryOp(binary_op.first, kTensorSize1, kTensorSize2);

    const int kExpectedMemoryTime = 3600;
    // Rounded up with std::ceil; the double result is converted explicitly.
    const int expected_compute_time = static_cast<int>(std::ceil(
        binary_op.second * kTensorSize1 * kTensorSize2 * 2 /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops));

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
    EXPECT_EQ(Costs::Duration(expected_compute_time), cost.compute_time);
    EXPECT_EQ(Costs::Duration(expected_compute_time + kExpectedMemoryTime),
              cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
1116 
TEST_F(OpLevelCostEstimatorTest, BroadcastAddExecutionTime) {
  // "Add" of a 1-D tensor of 100 elements and a 1x10x1x1 tensor.
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Add");
  DescribeTensor1D(100, op_context.op_info.add_inputs());
  DescribeTensor4D(1, 10, 1, 1, op_context.op_info.add_inputs());

  constexpr int kMemoryTime = 44;
  constexpr int kComputeTime = 100;
  const auto costs = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(kMemoryTime), costs.memory_time);
  EXPECT_EQ(Costs::Duration(kComputeTime), costs.compute_time);
  // Compute and memory do not overlap by default, so the two add up.
  EXPECT_EQ(Costs::Duration(kMemoryTime + kComputeTime), costs.execution_time);
  EXPECT_EQ(1, costs.num_ops_total);
  EXPECT_FALSE(costs.inaccurate);
  EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, costs.temporary_memory);
  EXPECT_EQ(0, costs.persistent_memory);
}
1135 
TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
  // Fully specified dimensions: one op, accurate, no unknown shapes.
  const auto expect_known_shapes = [](const Costs& cost) {
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  };
  // A -1 dimension: one op, inaccurate, counted as having unknown shapes.
  const auto expect_unknown_shapes = [](const Costs& cost) {
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  };

  expect_known_shapes(PredictCosts(DescribeMatMul(2, 4, 7, 7)));
  expect_unknown_shapes(PredictCosts(DescribeMatMul(-1, 4, 7, 7)));
  expect_unknown_shapes(PredictCosts(DescribeMatMul(2, 4, -1, 7)));
  expect_known_shapes(
      PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256)));
  expect_unknown_shapes(
      PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256)));
}
1170 
TEST_P(OpLevelBatchMatMulCostEstimatorTest,TestBatchMatMul)1171 TEST_P(OpLevelBatchMatMulCostEstimatorTest, TestBatchMatMul) {
1172   {
1173     auto cost = PredictCosts(DescribeBatchMatMul({}, {}));
1174     EXPECT_EQ(1, cost.num_ops_total);
1175     EXPECT_TRUE(cost.inaccurate);
1176     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1177   }
1178   {
1179     auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {}));
1180     EXPECT_EQ(1, cost.num_ops_total);
1181     EXPECT_TRUE(cost.inaccurate);
1182     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1183   }
1184   {
1185     auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {4, 2}));
1186     EXPECT_EQ(1, cost.num_ops_total);
1187     EXPECT_FALSE(cost.inaccurate);
1188     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1189   }
1190   {
1191     auto cost = PredictCosts(DescribeBatchMatMul({1, 2, 4}, {1, 4, 2}));
1192     EXPECT_EQ(1, cost.num_ops_total);
1193     EXPECT_FALSE(cost.inaccurate);
1194     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1195   }
1196   {
1197     auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}));
1198     EXPECT_EQ(1, cost.num_ops_total);
1199     EXPECT_FALSE(cost.inaccurate);
1200     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1201   }
1202   bool matmul_inaccurate = false;
1203   bool batch_matmul_inaccurate = false;
1204   EXPECT_EQ(
1205       CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
1206                             &matmul_inaccurate),
1207       CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}).op_info,
1208                                  &batch_matmul_inaccurate));
1209   EXPECT_EQ(matmul_inaccurate, batch_matmul_inaccurate);
1210   EXPECT_EQ(10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
1211                                        &matmul_inaccurate),
1212             CountBatchMatMulOperations(
1213                 DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
1214                 &batch_matmul_inaccurate));
1215   EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
1216   EXPECT_EQ(20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
1217                                        &matmul_inaccurate),
1218             CountBatchMatMulOperations(
1219                 DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
1220                 &batch_matmul_inaccurate));
1221   EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
1222 
1223   // Test the count to make sure that they extracted the dimensions correctly
1224   int prod = CountBatchMatMulDimProduct(
1225       DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}).op_info,
1226       &batch_matmul_inaccurate);
1227   EXPECT_EQ(prod, 16);
1228   EXPECT_FALSE(batch_matmul_inaccurate);
1229 
1230   // Exercise the bad cases of a batchMatMul.
1231   OpContext bad_batch = DescribeBatchMatMul({2, 4}, {4, 2});
1232   bad_batch.op_info.set_op("notBatchMatMul");
1233   prod =
1234       CountBatchMatMulDimProduct(bad_batch.op_info, &batch_matmul_inaccurate);
1235 
1236   EXPECT_EQ(prod, 0);
1237   EXPECT_TRUE(batch_matmul_inaccurate);
1238 
1239   // Exercise a transpose case of a batchMatMul
1240   OpContext transpose_batch = DescribeBatchMatMul({2, 4, 3, 1}, {4, 2});
1241   auto attr = transpose_batch.op_info.mutable_attr();
1242   (*attr)["adj_x"].set_b(true);
1243   (*attr)["adj_y"].set_b(true);
1244 
1245   prod = CountBatchMatMulDimProduct(transpose_batch.op_info,
1246                                     &batch_matmul_inaccurate);
1247   EXPECT_EQ(prod, 12);
1248 }
1249 INSTANTIATE_TEST_SUITE_P(TestBatchMatMul, OpLevelBatchMatMulCostEstimatorTest,
1250                          ::testing::Values("BatchMatMul", "BatchMatMulV2"));
1251 
TEST_F(OpLevelCostEstimatorTest, SparseTensorDenseMatMul) {
  // Bookkeeping expected whenever any input dimension is -1.
  const auto expect_unknown_shapes = [](const Costs& cost) {
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
  };
  // Accurate costs expected for the fully known cases below.
  const auto expect_known_costs = [](const Costs& cost) {
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    EXPECT_EQ(Costs::Duration(200), cost.compute_time);
    EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
  };

  // Unknown shape cases.
  expect_unknown_shapes(
      PredictCosts(DescribeSparseTensorDenseMatMul(-1, {1, 1}, {1, 1})));
  expect_unknown_shapes(
      PredictCosts(DescribeSparseTensorDenseMatMul(1, {-1, 1}, {1, 1})));
  expect_unknown_shapes(
      PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, -1}, {1, -1})));
  expect_unknown_shapes(
      PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, 1}, {-1, 1})));

  // Known shape cases. The second differs only in k_dim, which the cost does
  // not depend on, so both expect the same costs.
  expect_known_costs(PredictCosts(
      DescribeSparseTensorDenseMatMul(10, {1000, 100}, {50, 100})));
  expect_known_costs(PredictCosts(
      DescribeSparseTensorDenseMatMul(10, {100000, 100}, {50, 100})));
}
1303 
ExpectTensorShape(const std::vector<int64> & expected,const TensorShapeProto & tensor_shape_proto)1304 void ExpectTensorShape(const std::vector<int64>& expected,
1305                        const TensorShapeProto& tensor_shape_proto) {
1306   TensorShape tensor_shape_expected(expected);
1307   TensorShape tensor_shape(tensor_shape_proto);
1308 
1309   EXPECT_EQ(tensor_shape_expected, tensor_shape);
1310 }
1311 
TEST_F(OpLevelCostEstimatorTest,GetTensorShapeProtoFromTensorProto)1312 TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
1313   TensorProto tensor_proto;
1314   TensorShapeProto tensor_shape_proto;
1315 
1316   // Dimension larger than max value; should fail while converting to
1317   // Tensor class.
1318   tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
1319   EXPECT_FALSE(
1320       GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1321 
1322   tensor_proto.Clear();
1323   // Expect only 1D shape.
1324   tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
1325   tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
1326   EXPECT_FALSE(
1327       GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1328 
1329   // Expect only handle integer data types.
1330   GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto);
1331   EXPECT_FALSE(
1332       GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1333 
1334   // Check GetTensorShapeProtoFromTensorProto() returns correct values.
1335   {
1336     std::vector<int64> shape_expected = {10, 20, 30, 40};
1337     GetTensorProto(DT_INT32, {4}, shape_expected,
1338                    /*tensor_content=*/false, &tensor_proto);
1339     EXPECT_TRUE(
1340         GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1341     ExpectTensorShape(shape_expected, tensor_shape_proto);
1342   }
1343 
1344   {
1345     std::vector<int64> shape_expected = {40, 20, 90, 40};
1346     GetTensorProto(DT_INT64, {4}, shape_expected,
1347                    /*tensor_content=*/false, &tensor_proto);
1348     EXPECT_TRUE(
1349         GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1350     ExpectTensorShape(shape_expected, tensor_shape_proto);
1351   }
1352 
1353   {
1354     std::vector<int64> shape_expected = {10, 20, 30, 40};
1355     GetTensorProto(DT_INT32, {4}, shape_expected,
1356                    /*tensor_content=*/true, &tensor_proto);
1357     EXPECT_TRUE(
1358         GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1359     ExpectTensorShape(shape_expected, tensor_shape_proto);
1360   }
1361 
1362   {
1363     std::vector<int64> shape_expected = {40, 20, 90, 40};
1364     GetTensorProto(DT_INT64, {4}, shape_expected,
1365                    /*tensor_content=*/true, &tensor_proto);
1366     EXPECT_TRUE(
1367         GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1368     ExpectTensorShape(shape_expected, tensor_shape_proto);
1369   }
1370 }
1371 
TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
  // Every padding mode is crossed with every data format.
  const std::vector<string> kPaddings = {"VALID", "SAME"};
  const std::vector<string> kFormats = {"NHWC", "NCHW"};
  for (const string& padding : kPaddings) {
    for (const string& format : kFormats) {
      // Arguments: n, h, w, c, kx, ky, sx, sy, data_format, padding.
      ValidateOpDimensionsFromInputs(10, 20, 20, 100, 3, 3, 2, 2, format,
                                     padding);
      ValidateOpDimensionsFromInputs(10, 20, 20, 100, 1, 1, 3, 3, format,
                                     padding);
      ValidateOpDimensionsFromInputs(10, 200, 200, 100, 5, 5, 3, 3, format,
                                     padding);
      ValidateOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 2, 2, format,
                                     padding);
    }
  }
}
1385 
// Checks MaxPool cost predictions for square NHWC inputs across several
// window / stride / padding combinations.
TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
  // Describes an NHWC MaxPool over an n x in x in x c input with a k x k
  // window and s x s stride, and returns the estimator's predicted costs.
  auto predict_max_pool = [this](const int n, const int in, const int c,
                                 const int k, const int s,
                                 const string& padding) -> Costs {
    OpContext op_context = DescribePoolingOp(
        "MaxPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    auto costs = predict_max_pool(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1075200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(307200), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
    // Expected value first, matching gtest's (expected, actual) convention
    // used by the rest of this test.
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    auto costs = predict_max_pool(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
    EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
    EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    auto costs = predict_max_pool(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(561792), costs.execution_time);
    EXPECT_EQ(Costs::Duration(56448), costs.compute_time);
    EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1428 
// Checks MaxPoolGrad cost predictions for square NHWC inputs.
TEST_F(OpLevelCostEstimatorTest, PredictMaxPoolGrad) {
  // Costs a MaxPoolGrad over a batch x size x size x depth NHWC input with a
  // window x window pooling window and stride x stride strides.
  auto max_pool_grad_costs = [this](const int batch, const int size,
                                    const int depth, const int window,
                                    const int stride,
                                    const string& pad) -> Costs {
    const OpContext context = DescribePoolingOp(
        "MaxPoolGrad", {batch, size, size, depth}, {1, window, window, 1},
        {1, stride, stride, 1}, "NHWC", pad);
    return estimator_.PredictCosts(context);
  };
  // Compares the timing breakdown of `costs` with the expected durations and
  // checks it describes a single op with no unknown shapes.
  auto check_costs = [](const Costs& costs, const int64 execution,
                        const int64 compute, const int64 memory) {
    EXPECT_EQ(Costs::Duration(execution), costs.execution_time);
    EXPECT_EQ(Costs::Duration(compute), costs.compute_time);
    EXPECT_EQ(Costs::Duration(memory), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const Costs costs = max_pool_grad_costs(10, 20, 384, 3, 2, "SAME");
    check_costs(costs, 1996800, 614400, 1382400);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
  check_costs(max_pool_grad_costs(10, 20, 384, 1, 2, "SAME"), 1536000, 153600,
              1382400);
  // 2x2 window with 3x3 stride.
  check_costs(max_pool_grad_costs(10, 20, 384, 2, 3, "VALID"), 1514112, 210048,
              1304064);
}
1472 
// Checks AvgPool cost predictions for square NHWC inputs.
TEST_F(OpLevelCostEstimatorTest, PredictAvgPool) {
  // Costs an AvgPool over a batch x size x size x depth NHWC input with a
  // window x window pooling window and stride x stride strides.
  auto avg_pool_costs = [this](const int batch, const int size,
                               const int depth, const int window,
                               const int stride, const string& pad) -> Costs {
    const OpContext context = DescribePoolingOp(
        "AvgPool", {batch, size, size, depth}, {1, window, window, 1},
        {1, stride, stride, 1}, "NHWC", pad);
    return estimator_.PredictCosts(context);
  };
  // Compares execution / compute / memory time with the expected durations
  // and checks the result describes one op with no unknown shapes.
  auto verify = [](const Costs& costs, const int64 execution,
                   const int64 compute, const int64 memory) {
    EXPECT_EQ(Costs::Duration(execution), costs.execution_time);
    EXPECT_EQ(Costs::Duration(compute), costs.compute_time);
    EXPECT_EQ(Costs::Duration(memory), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const Costs costs = avg_pool_costs(10, 20, 384, 3, 2, "SAME");
    verify(costs, 1113600, 345600, 768000);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
  verify(avg_pool_costs(10, 20, 384, 1, 2, "SAME"), 499200, 38400, 460800);
  // 2x2 window with 3x3 stride.
  verify(avg_pool_costs(10, 20, 384, 2, 3, "VALID"), 580608, 75264, 505344);
}
1515 
// Checks AvgPoolGrad cost predictions for square NHWC inputs across several
// window / stride / padding combinations.
TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
  // Describes an NHWC AvgPoolGrad over an n x in x in x c input with a k x k
  // window and s x s stride, and returns the estimator's predicted costs.
  auto predict_avg_pool_grad = [this](const int n, const int in, const int c,
                                      const int k, const int s,
                                      const string& padding) -> Costs {
    OpContext op_context =
        DescribePoolingOp("AvgPoolGrad", {n, in, in, c}, {1, k, k, 1},
                          {1, s, s, 1}, "NHWC", padding);
    return estimator_.PredictCosts(op_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    auto costs = predict_avg_pool_grad(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(Costs::Duration(1305602), costs.execution_time);
    EXPECT_EQ(Costs::Duration(537600), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
    // Expected value first, matching gtest's (expected, actual) convention
    // used by the rest of this test.
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    auto costs = predict_avg_pool_grad(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(Costs::Duration(960002), costs.execution_time);
    EXPECT_EQ(Costs::Duration(192000), costs.compute_time);
    EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
  {
    // 2x2 window with 3x3 stride.
    auto costs = predict_avg_pool_grad(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(Costs::Duration(862082), costs.execution_time);
    EXPECT_EQ(Costs::Duration(172416), costs.compute_time);
    EXPECT_EQ(Costs::Duration(689666), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
1559 
// Checks forward FusedBatchNorm cost predictions in training and inference
// mode for two channel depths.
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNorm) {
  // Costs a forward (non-gradient) FusedBatchNorm over a
  // batch x size x size x depth NHWC input.
  auto bn_costs = [this](const int batch, const int size, const int depth,
                         const bool training) -> Costs {
    const OpContext context = DescribeFusedBatchNorm(
        training, /*is_grad=*/false, {batch, size, size, depth}, "NHWC");
    return estimator_.PredictCosts(context);
  };
  // Compares the timing breakdown with the expected durations and checks the
  // result describes one op with no unknown shapes.
  auto check_costs = [](const Costs& costs, const int64 execution,
                        const int64 compute, const int64 memory) {
    EXPECT_EQ(Costs::Duration(execution), costs.execution_time);
    EXPECT_EQ(Costs::Duration(compute), costs.compute_time);
    EXPECT_EQ(Costs::Duration(memory), costs.memory_time);
    EXPECT_EQ(1, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  };

  {
    const Costs costs = bn_costs(10, 20, 96, /*training=*/true);
    check_costs(costs, 614737, 153706, 461031);
    EXPECT_EQ(0, costs.temporary_memory);
    EXPECT_EQ(0, costs.persistent_memory);
  }
  check_costs(bn_costs(10, 20, 32, /*training=*/true), 204913, 51236, 153677);
  check_costs(bn_costs(10, 20, 96, /*training=*/false), 384154, 76800, 307354);
  check_costs(bn_costs(10, 20, 32, /*training=*/false), 128052, 25600, 102452);
}
1610 
// Checks FusedBatchNorm gradient cost predictions for two input geometries.
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
  // Costs a FusedBatchNorm gradient (is_grad=true, is_training=false) over a
  // batch x size x size x depth NHWC input.
  auto bn_grad_costs = [this](const int batch, const int size,
                              const int depth) -> Costs {
    const OpContext context = DescribeFusedBatchNorm(
        /*is_training=*/false, /*is_grad=*/true, {batch, size, size, depth},
        "NHWC");
    return estimator_.PredictCosts(context);
  };

  {
    const Costs small = bn_grad_costs(10, 20, 96);
    EXPECT_EQ(Costs::Duration(1037050), small.execution_time);
    EXPECT_EQ(Costs::Duration(422496), small.compute_time);
    EXPECT_EQ(Costs::Duration(614554), small.memory_time);
    EXPECT_EQ(1, small.num_ops_total);
    EXPECT_FALSE(small.inaccurate);
    EXPECT_EQ(0, small.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, small.temporary_memory);
    EXPECT_EQ(0, small.persistent_memory);
  }

  {
    const Costs large = bn_grad_costs(128, 7, 384);
    EXPECT_EQ(Costs::Duration(6503809), large.execution_time);
    EXPECT_EQ(Costs::Duration(2649677), large.compute_time);
    EXPECT_EQ(Costs::Duration(3854132), large.memory_time);
    EXPECT_EQ(1, large.num_ops_total);
    EXPECT_FALSE(large.inaccurate);
    EXPECT_EQ(0, large.num_ops_with_unknown_shapes);
  }
}
1641 
TEST_F(OpLevelCostEstimatorTest,MaybeGetMinimumShape)1642 TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
1643   {
1644     TensorShapeProto x;
1645     x.set_unknown_rank(true);
1646     bool unknown_shapes = false;
1647     TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
1648     EXPECT_TRUE(unknown_shapes);
1649     ExpectTensorShape({1, 1, 1, 1}, y);
1650   }
1651 
1652   {
1653     TensorShapeProto x;
1654     x.set_unknown_rank(false);
1655     bool unknown_shapes = false;
1656     TensorShapeProto y = MaybeGetMinimumShape(x, 1, &unknown_shapes);
1657     EXPECT_FALSE(unknown_shapes);
1658     ExpectTensorShape({1}, y);
1659   }
1660 
1661   {
1662     TensorShapeProto x;
1663     x.set_unknown_rank(false);
1664     bool unknown_shapes = false;
1665     TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
1666     EXPECT_FALSE(unknown_shapes);
1667     ExpectTensorShape({1, 1}, y);
1668   }
1669 
1670   {
1671     TensorShapeProto x;
1672     x.set_unknown_rank(false);
1673     x.add_dim()->set_size(10);
1674     x.add_dim()->set_size(20);
1675     bool unknown_shapes = false;
1676     TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
1677     EXPECT_FALSE(unknown_shapes);
1678     ExpectTensorShape({10, 20}, y);
1679 
1680     unknown_shapes = false;
1681     TensorShapeProto z = MaybeGetMinimumShape(x, 4, &unknown_shapes);
1682     EXPECT_TRUE(unknown_shapes);
1683     EXPECT_EQ(4, z.dim_size());
1684     ExpectTensorShape({10, 20, 1, 1}, z);
1685   }
1686 
1687   {
1688     TensorShapeProto x;
1689     x.set_unknown_rank(false);
1690     x.add_dim()->set_size(10);
1691     x.add_dim()->set_size(20);
1692     x.add_dim()->set_size(-1);
1693     x.add_dim()->set_size(20);
1694     bool unknown_shapes = false;
1695     TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
1696     EXPECT_TRUE(unknown_shapes);
1697     ExpectTensorShape({10, 20, 1, 20}, y);
1698   }
1699 
1700   {
1701     TensorShapeProto x;
1702     x.set_unknown_rank(false);
1703     x.add_dim()->set_size(10);
1704     x.add_dim()->set_size(20);
1705     x.add_dim()->set_size(30);
1706     x.add_dim()->set_size(20);
1707     bool unknown_shapes = false;
1708     TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
1709     EXPECT_TRUE(unknown_shapes);
1710     ExpectTensorShape({10, 20}, y);
1711   }
1712 }
1713 
// Costs one fixed convolution under compute-limited, memory-limited and
// intermediate-bandwidth-limited device settings, each with and without
// compute/memory overlap.
TEST_F(OpLevelCostEstimatorTest, IntermediateRdWrBandwidth) {
  TestOpLevelCostEstimator estimator;
  // The same convolution is costed under every configuration below.
  const OpContext conv = DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256);
  // With overlap disabled, execution time is expected to be the plain sum of
  // compute, memory and intermediate-memory time.
  auto expect_additive = [](const Costs& c) {
    EXPECT_EQ(c.execution_time,
              c.compute_time + c.memory_time + c.intermediate_memory_time);
  };

  // Compute limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  const Costs compute_overlapped = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(3548774400), compute_overlapped.execution_time);
  EXPECT_EQ(compute_overlapped.execution_time,
            compute_overlapped.compute_time);

  estimator.SetComputeMemoryOverlap(false);
  const Costs compute_serial = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(3551112192), compute_serial.execution_time);
  expect_additive(compute_serial);

  // Memory limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  const Costs memory_overlapped = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(2337792), memory_overlapped.execution_time);
  EXPECT_EQ(memory_overlapped.execution_time, memory_overlapped.memory_time);

  estimator.SetComputeMemoryOverlap(false);
  const Costs memory_serial = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(2373281), memory_serial.execution_time);
  expect_additive(memory_serial);

  // Intermediate memory bandwidth limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/9999,
                                     /*intermediate_read_gb_per_sec=*/1,
                                     /*intermediate_write_gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  const Costs intermediate_overlapped = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(2337792), intermediate_overlapped.execution_time);
  EXPECT_EQ(intermediate_overlapped.execution_time,
            intermediate_overlapped.intermediate_memory_time);

  estimator.SetComputeMemoryOverlap(false);
  const Costs intermediate_serial = estimator.PredictCosts(conv);
  EXPECT_EQ(Costs::Duration(2373515), intermediate_serial.execution_time);
  expect_additive(intermediate_serial);
}
1767 
TEST_F(OpLevelCostEstimatorTest,Einsum)1768 TEST_F(OpLevelCostEstimatorTest, Einsum) {
1769   {  // Test a simple matrix multiplication.
1770     auto cost = PredictCosts(DescribeEinsum({100, 50}, {100, 50}, "ik,jk->ij"));
1771     EXPECT_EQ(Costs::Duration(104000), cost.execution_time);
1772     EXPECT_EQ(Costs::Duration(100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1773               cost.compute_time);
1774     EXPECT_EQ(Costs::Duration(4000), cost.memory_time);
1775     EXPECT_EQ(cost.num_ops_total, 1);
1776     EXPECT_FALSE(cost.inaccurate);
1777     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1778     EXPECT_EQ(cost.temporary_memory, 0);
1779     EXPECT_EQ(cost.persistent_memory, 0);
1780 
1781     // Einsums and XlaEinsums should be estimated similarly.
1782     EXPECT_EQ(PredictCosts(DescribeEinsum({100, 50}, {100, 50}, "ik,jk->ij"))
1783                   .execution_time,
1784               PredictCosts(DescribeXlaEinsum({100, 50}, {100, 50}, "ik,jk->ij"))
1785                   .execution_time);
1786   }
1787   {  // Test a simple batch matrix multiplication.
1788     auto cost = PredictCosts(
1789         DescribeEinsum({25, 100, 50}, {100, 50, 25}, "Bik,jkB->Bij"));
1790     EXPECT_EQ(Costs::Duration(25 * 104000), cost.execution_time);
1791     EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1792               cost.compute_time);
1793     EXPECT_EQ(Costs::Duration(25 * 4000), cost.memory_time);
1794     EXPECT_EQ(1, cost.num_ops_total);
1795     EXPECT_FALSE(cost.inaccurate);
1796     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1797 
1798     // Einsums and XlaEinsums should be estimated similarly.
1799     EXPECT_EQ(PredictCosts(
1800                   DescribeEinsum({25, 100, 50}, {100, 50, 25}, "Bik,jkB->Bij"))
1801                   .execution_time,
1802               PredictCosts(DescribeXlaEinsum({25, 100, 50}, {100, 50, 25},
1803                                              "Bik,jkB->Bij"))
1804                   .execution_time);
1805   }
1806   {  // Test multiple batch dimensions.
1807     auto cost = PredictCosts(DescribeEinsum(
1808         {25, 16, 100, 50}, {16, 100, 50, 25}, "BNik,NjkB->BNij"));
1809     EXPECT_EQ(Costs::Duration(16 * 25 * 104000), cost.execution_time);
1810     EXPECT_EQ(
1811         Costs::Duration(16 * 25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1812         cost.compute_time);
1813     EXPECT_EQ(Costs::Duration(16 * 25 * 4000), cost.memory_time);
1814     EXPECT_EQ(1, cost.num_ops_total);
1815     EXPECT_FALSE(cost.inaccurate);
1816     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1817 
1818     // Einsums and XlaEinsums should be estimated similarly.
1819     EXPECT_EQ(
1820         PredictCosts(DescribeEinsum({25, 16, 100, 50}, {16, 100, 50, 25},
1821                                     "BNik,NjkB->BNij"))
1822             .execution_time,
1823         PredictCosts(DescribeXlaEinsum({25, 16, 100, 50}, {16, 100, 50, 25},
1824                                        "BNik,NjkB->BNij"))
1825             .execution_time);
1826   }
1827   {  // Test multiple M dimensions.
1828     auto cost =
1829         PredictCosts(DescribeEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"));
1830     EXPECT_EQ(Costs::Duration(2552000), cost.execution_time);
1831     EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1832               cost.compute_time);
1833     EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
1834     EXPECT_EQ(1, cost.num_ops_total);
1835     EXPECT_FALSE(cost.inaccurate);
1836     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1837 
1838     // Einsums and XlaEinsums should be estimated similarly.
1839     EXPECT_EQ(
1840         PredictCosts(DescribeEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"))
1841             .execution_time,
1842         PredictCosts(DescribeXlaEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"))
1843             .execution_time);
1844   }
1845   {  // Test multiple N dimensions.
1846     auto cost =
1847         PredictCosts(DescribeEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"));
1848     EXPECT_EQ(Costs::Duration(2552000), cost.execution_time);
1849     EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1850               cost.compute_time);
1851     EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
1852     EXPECT_EQ(1, cost.num_ops_total);
1853     EXPECT_FALSE(cost.inaccurate);
1854     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1855 
1856     // Einsums and XlaEinsums should be estimated similarly.
1857     EXPECT_EQ(
1858         PredictCosts(DescribeEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"))
1859             .execution_time,
1860         PredictCosts(DescribeXlaEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"))
1861             .execution_time);
1862   }
1863   {  // Test multiple contracting dimensions.
1864     auto cost = PredictCosts(
1865         DescribeEinsum({100, 50, 25}, {100, 50, 25}, "ikl,jkl->ij"));
1866     EXPECT_EQ(Costs::Duration(2600000), cost.execution_time);
1867     EXPECT_EQ(Costs::Duration(100 * 50 * 25 * 100 * 2 / (1000 * 10 * 1e-3)),
1868               cost.compute_time);
1869     EXPECT_EQ(Costs::Duration(100000), cost.memory_time);
1870     EXPECT_EQ(1, cost.num_ops_total);
1871     EXPECT_FALSE(cost.inaccurate);
1872     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1873 
1874     // Einsums and XlaEinsums should be estimated similarly.
1875     EXPECT_EQ(PredictCosts(
1876                   DescribeEinsum({100, 50, 25}, {100, 50, 25}, "ikl,jkl->ij"))
1877                   .execution_time,
1878               PredictCosts(DescribeXlaEinsum({100, 50, 25}, {100, 50, 25},
1879                                              "ikl,jkl->ij"))
1880                   .execution_time);
1881   }
1882   {  // Test a simple matrix transpose.
1883     auto cost = PredictCosts(DescribeEinsum({100, 50}, {}, "ij->ji"));
1884     EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
1885     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1886     EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
1887     EXPECT_EQ(1, cost.num_ops_total);
1888     EXPECT_TRUE(cost.inaccurate);
1889     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1890 
1891     // Einsums and XlaEinsums should be estimated similarly.
1892     EXPECT_EQ(
1893         PredictCosts(DescribeEinsum({100, 50}, {}, "ij->ji")).execution_time,
1894         PredictCosts(DescribeXlaEinsum({100, 50}, {}, "ij->ji"))
1895             .execution_time);
1896   }
1897   {  // Test a malformed Einsum equation: Mismatch between shapes and equation.
1898     auto cost =
1899         PredictCosts(DescribeEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"));
1900     EXPECT_EQ(Costs::Duration(52000), cost.execution_time);
1901     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1902     EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
1903     EXPECT_EQ(1, cost.num_ops_total);
1904     EXPECT_TRUE(cost.inaccurate);
1905     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1906 
1907     // Einsums and XlaEinsums should be estimated similarly.
1908     EXPECT_EQ(
1909         PredictCosts(DescribeEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"))
1910             .execution_time,
1911         PredictCosts(DescribeXlaEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"))
1912             .execution_time);
1913 
1914     cost = PredictCosts(DescribeEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"));
1915     EXPECT_EQ(Costs::Duration(52000), cost.execution_time);
1916     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1917     EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
1918     EXPECT_EQ(1, cost.num_ops_total);
1919     EXPECT_TRUE(cost.inaccurate);
1920     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1921 
1922     // Einsums and XlaEinsums should be estimated similarly.
1923     EXPECT_EQ(
1924         PredictCosts(DescribeEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"))
1925             .execution_time,
1926         PredictCosts(DescribeXlaEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"))
1927             .execution_time);
1928   }
1929   {  // Test an unsupported Einsum: ellipsis
1930     auto cost = PredictCosts(DescribeEinsum(
1931         {100, 50, 25, 16}, {50, 100, 32, 12}, "ik...,kl...->il..."));
1932     EXPECT_EQ(Costs::Duration(1568000), cost.execution_time);
1933     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1934     EXPECT_EQ(Costs::Duration(1568000), cost.memory_time);
1935     EXPECT_EQ(1, cost.num_ops_total);
1936     EXPECT_TRUE(cost.inaccurate);
1937     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1938 
1939     // Einsums and XlaEinsums should be estimated similarly.
1940     EXPECT_EQ(
1941         PredictCosts(DescribeEinsum({100, 50, 25, 16}, {50, 100, 32, 12},
1942                                     "ik...,kl...->il..."))
1943             .execution_time,
1944         PredictCosts(DescribeXlaEinsum({100, 50, 25, 16}, {50, 100, 32, 12},
1945                                        "ik...,kl...->il..."))
1946             .execution_time);
1947   }
1948   {  // Test a malformed/unsupported Einsum: repeated indices
1949     auto cost =
1950         PredictCosts(DescribeEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"));
1951     EXPECT_EQ(Costs::Duration(202000), cost.execution_time);
1952     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1953     EXPECT_EQ(Costs::Duration(202000), cost.memory_time);
1954     EXPECT_EQ(1, cost.num_ops_total);
1955     EXPECT_TRUE(cost.inaccurate);
1956     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1957 
1958     // Einsums and XlaEinsums should be estimated similarly.
1959     EXPECT_EQ(
1960         PredictCosts(DescribeEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"))
1961             .execution_time,
1962         PredictCosts(DescribeXlaEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"))
1963             .execution_time);
1964   }
1965   {  // Test missing shapes.
1966     auto cost = PredictCosts(DescribeEinsum({-1, 50}, {100, 50}, "ik,jk->ij"));
1967     EXPECT_EQ(Costs::Duration(3020), cost.execution_time);
1968     EXPECT_EQ(Costs::Duration(1 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1969               cost.compute_time);
1970     EXPECT_EQ(Costs::Duration(2020), cost.memory_time);
1971     EXPECT_EQ(1, cost.num_ops_total);
1972     EXPECT_TRUE(cost.inaccurate);
1973     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1974 
1975     // Einsums and XlaEinsums should be estimated similarly.
1976     EXPECT_EQ(PredictCosts(DescribeEinsum({-1, 50}, {100, 50}, "ik,jk->ij"))
1977                   .execution_time,
1978               PredictCosts(DescribeXlaEinsum({-1, 50}, {100, 50}, "ik,jk->ij"))
1979                   .execution_time);
1980   }
1981 }
1982 
// Checks costs of resource-variable assignment ops on a 1 GOPS / 1 GB/s
// device.
TEST_F(OpLevelCostEstimatorTest, PredictResourceVariableOps) {
  TestOpLevelCostEstimator estimator;
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1, /*gb_per_sec=*/1));

  // Builds an op context for `op_name` with a dummy first input (presumably
  // standing in for the resource handle) and a 100-element 1-D second input.
  auto describe_variable_op = [](const string& op_name) {
    OpContext context;
    context.op_info.set_op(op_name);
    DescribeDummyTensor(context.op_info.add_inputs());
    DescribeTensor1D(100, context.op_info.add_inputs());
    return context;
  };

  {
    const Costs assign =
        estimator.PredictCosts(describe_variable_op("AssignVariableOp"));
    EXPECT_EQ(Costs::Duration(400), assign.memory_time);
    EXPECT_EQ(Costs::Duration(0), assign.compute_time);
    EXPECT_EQ(Costs::Duration(400), assign.execution_time);
    EXPECT_FALSE(assign.inaccurate);
    EXPECT_EQ(0, assign.temporary_memory);
    EXPECT_EQ(0, assign.persistent_memory);
  }

  {
    const Costs assign_sub =
        estimator.PredictCosts(describe_variable_op("AssignSubVariableOp"));
    EXPECT_EQ(Costs::Duration(400), assign_sub.memory_time);
    EXPECT_EQ(Costs::Duration(100), assign_sub.compute_time);
    EXPECT_EQ(Costs::Duration(400), assign_sub.execution_time);
    EXPECT_FALSE(assign_sub.inaccurate);
  }
}
2013 
// Checks the predicted cost of a CPU AddN over three equally-shaped inputs.
TEST_F(OpLevelCostEstimatorTest, AddNExecutionTime) {
  OpContext add_n;
  SetCpuDevice(&add_n.op_info);
  add_n.op_info.set_op("AddN");

  // Three identical 1x10x10x10 inputs.
  for (int i = 0; i < 3; ++i) {
    DescribeTensor4D(1, 10, 10, 10, add_n.op_info.add_inputs());
  }

  const Costs cost = PredictCosts(add_n);
  EXPECT_EQ(Costs::Duration(1200), cost.memory_time);
  EXPECT_EQ(Costs::Duration(200), cost.compute_time);
  EXPECT_EQ(Costs::Duration(1400), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, cost.temporary_memory);
  EXPECT_EQ(0, cost.persistent_memory);
}
2033 
// Identity-like ops are expected to cost a nominal 1 unit of compute time and
// no memory time, while still reporting the tensor's size as max_memory.
TEST_F(OpLevelCostEstimatorTest, IdentityOpExecutionTime) {
  const std::vector<std::string> identity_ops = {
      "_Recv",         "_Send",        "BitCast",         "Identity",
      "Enter",         "Exit",         "IdentityN",       "Merge",
      "NextIteration", "Placeholder",  "PreventGradient", "RefIdentity",
      "Reshape",       "StopGradient", "Switch"};

  // Loop-invariant expectations, hoisted out of the per-op loop.
  const int kTensorSize = 1000;
  const int kExpectedMemoryTime = 0;
  const int kExpectedComputeTime = 1;

  // Iterate by const reference: the previous `auto` loop copied every op name.
  for (const std::string& identity_op : identity_ops) {
    // Tag any failure below with the op under test; otherwise a failing
    // expectation inside this loop cannot be attributed to a specific op.
    SCOPED_TRACE(identity_op);
    OpContext op_context = DescribeUnaryOp(identity_op, kTensorSize);

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime), cost.compute_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime + kExpectedMemoryTime),
              cost.execution_time);
    // 4 bytes per element; presumably DescribeUnaryOp builds a float tensor.
    EXPECT_EQ(kTensorSize * 4, cost.max_memory);
    EXPECT_EQ(1, cost.num_ops_total);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    EXPECT_EQ(0, cost.temporary_memory);
    EXPECT_EQ(0, cost.persistent_memory);
  }
}
2061 
TEST_F(OpLevelCostEstimatorTest, PureMemoryOpExecutionTime) {
  // Ops that only move or rearrange data, so they should be costed as pure
  // memory traffic with no compute time.
  const std::vector<std::string> pure_memory_ops = {
      "ConcatV2",     "DataFormatVecPermute",
      "DepthToSpace", "ExpandDims",
      "Fill",         "OneHot",
      "Pack",         "Range",
      "SpaceToDepth", "Split",
      "Squeeze",      "Transpose",
      "Tile",         "Unpack"};

  const int kTensorSize = 1000;
  // Same expectation for every op. Presumably DescribeUnaryOp describes one
  // input and one output of kTensorSize floats: 2 * 1000 * 4 bytes at the
  // 10 GB/s CPU bandwidth = 800 ns. TODO confirm against DescribeUnaryOp.
  const int kExpectedMemoryTime = 800;
  const int kExpectedComputeTime = 0;

  // Iterate by const reference to avoid copying each std::string.
  for (const auto& op_name : pure_memory_ops) {
    // Tag every failure in this iteration with the op under test.
    SCOPED_TRACE(op_name);
    OpContext op_context = DescribeUnaryOp(op_name, kTensorSize);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime), cost.compute_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime + kExpectedMemoryTime),
              cost.execution_time);
    // Peak memory is one float tensor of kTensorSize elements.
    EXPECT_EQ(cost.max_memory, kTensorSize * 4);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
2092 
TEST_F(OpLevelCostEstimatorTest,ResizeBilinearExecutionTime)2093 TEST_F(OpLevelCostEstimatorTest, ResizeBilinearExecutionTime) {
2094   const int kImageDim = 255;
2095   const int kChannelSize = 10;
2096   const int kComputeLerpCost = 9;
2097   {
2098     OpContext op_context;
2099     SetCpuDevice(&op_context.op_info);
2100     op_context.op_info.set_op("ResizeBilinear");
2101     DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
2102                      op_context.op_info.add_inputs());
2103     // Test with no output.
2104     auto cost = PredictCosts(op_context);
2105     ExpectZeroCost(cost);
2106     op_context.op_info.clear_inputs();
2107 
2108     DescribeTensor4D(0, 0, 0, 0, op_context.op_info.add_outputs());
2109     // Test with no input.
2110     cost = PredictCosts(op_context);
2111     ExpectZeroCost(cost);
2112   }
2113   {
2114     // Test with size 0 output.
2115     OpContext op_context;
2116     SetCpuDevice(&op_context.op_info);
2117     op_context.op_info.set_op("ResizeBilinear");
2118 
2119     DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
2120                      op_context.op_info.add_inputs());
2121     const int kExpectedMemoryTime = kImageDim * kImageDim * 4;
2122     DescribeTensor4D(0, 0, 0, 0, op_context.op_info.add_outputs());
2123 
2124     // As the half_pixel_centers attr was not set, cost should be inaccurate
2125     // with 0 compute time.
2126     auto cost = PredictCosts(op_context);
2127     EXPECT_EQ(cost.compute_time, Costs::Duration(0));
2128     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2129     EXPECT_EQ(cost.execution_time, Costs::Duration(kExpectedMemoryTime));
2130     EXPECT_TRUE(cost.inaccurate);
2131     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2132     EXPECT_EQ(cost.temporary_memory, 0);
2133     EXPECT_EQ(cost.persistent_memory, 0);
2134 
2135     AttrValue half_pixel_centers;
2136     half_pixel_centers.set_b(false);
2137     (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
2138         half_pixel_centers;
2139     cost = PredictCosts(op_context);
2140     // Compute time depends only on output size, so compute time is 0.
2141     EXPECT_EQ(cost.compute_time, Costs::Duration(0));
2142     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2143     EXPECT_EQ(cost.execution_time, Costs::Duration(kExpectedMemoryTime));
2144     EXPECT_FALSE(cost.inaccurate);
2145     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2146   }
2147 
2148   // Test with non-zero output size.
2149   const int kOutputImageDim = 100;
2150   OpContext op_context;
2151   SetCpuDevice(&op_context.op_info);
2152   op_context.op_info.set_op("ResizeBilinear");
2153   DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
2154                    op_context.op_info.add_inputs());
2155   DescribeTensor4D(1, kOutputImageDim, kOutputImageDim, kChannelSize,
2156                    op_context.op_info.add_outputs());
2157   const int kExpectedMemoryTime =
2158       (kImageDim * kImageDim + kOutputImageDim * kOutputImageDim) * 4;
2159 
2160   {
2161     // Cost of calculating weights without using half_pixel_centers.
2162     AttrValue half_pixel_centers;
2163     half_pixel_centers.set_b(false);
2164     (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
2165         half_pixel_centers;
2166     const int kInterpWeightCost = 10;
2167     const int num_ops =
2168         kInterpWeightCost * (kOutputImageDim * 2) +
2169         kComputeLerpCost * (kOutputImageDim * kOutputImageDim * kChannelSize);
2170     const int expected_compute_time = std::ceil(
2171         num_ops /
2172         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2173 
2174     const auto cost = PredictCosts(op_context);
2175     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2176     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2177     EXPECT_EQ(cost.execution_time,
2178               Costs::Duration(kExpectedMemoryTime + expected_compute_time));
2179     EXPECT_FALSE(cost.inaccurate);
2180     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2181   }
2182 
2183   {
2184     // Cost of calculating weights using half_pixel_centers.
2185     AttrValue half_pixel_centers;
2186     half_pixel_centers.set_b(true);
2187     (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
2188         half_pixel_centers;
2189     const int kInterpWeightCost = 12;
2190     const int num_ops =
2191         kInterpWeightCost * (kOutputImageDim * 2) +
2192         kComputeLerpCost * (kOutputImageDim * kOutputImageDim * kChannelSize);
2193     const int expected_compute_time = std::ceil(
2194         num_ops /
2195         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2196 
2197     const auto cost = PredictCosts(op_context);
2198     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2199     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2200     EXPECT_EQ(cost.execution_time,
2201               Costs::Duration(kExpectedMemoryTime + expected_compute_time));
2202     EXPECT_FALSE(cost.inaccurate);
2203     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2204   }
2205 
2206   {
2207     // Cost with very large tensor.
2208     op_context.op_info.clear_outputs();
2209     // Number of elements in tensor exceeds 2^32.
2210     constexpr int64_t kLargeOutputImageDim = 40000;
2211     DescribeTensor4D(1, kLargeOutputImageDim, kLargeOutputImageDim,
2212                      kChannelSize, op_context.op_info.add_outputs());
2213     const int64_t kInterpWeightCost = 12;
2214     // Using half_pixel_centers.
2215     AttrValue half_pixel_centers;
2216     half_pixel_centers.set_b(true);
2217     (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
2218         half_pixel_centers;
2219 
2220     const int64_t num_ops =
2221         kInterpWeightCost * (kLargeOutputImageDim * 2) +
2222         kComputeLerpCost *
2223             (kLargeOutputImageDim * kLargeOutputImageDim * kChannelSize);
2224     const int64_t expected_compute_time = std::ceil(
2225         num_ops /
2226         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2227 
2228     const int64_t expected_memory_time =
2229         (kImageDim * kImageDim + kLargeOutputImageDim * kLargeOutputImageDim) *
2230         4;
2231 
2232     const auto cost = PredictCosts(op_context);
2233     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2234     EXPECT_EQ(cost.memory_time, Costs::Duration(expected_memory_time));
2235     EXPECT_EQ(cost.execution_time,
2236               Costs::Duration(expected_memory_time + expected_compute_time));
2237     EXPECT_FALSE(cost.inaccurate);
2238     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2239   }
2240 }
2241 
TEST_F(OpLevelCostEstimatorTest,CropAndResizeExecutionTime)2242 TEST_F(OpLevelCostEstimatorTest, CropAndResizeExecutionTime) {
2243   const int kImageDim = 255;
2244   const int kChannelSize = 10;
2245   const int kOutputImageDim = 100;
2246   const int kNumBoxes = 10;
2247   const int kOutputElements =
2248       kNumBoxes * kOutputImageDim * kOutputImageDim * kChannelSize;
2249   OpContext op_context;
2250   SetCpuDevice(&op_context.op_info);
2251   op_context.op_info.set_op("CropAndResize");
2252   DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
2253                    op_context.op_info.add_inputs());
2254   DescribeArbitraryRankInput({kNumBoxes, 4}, DT_INT64, &op_context.op_info);
2255   DescribeTensor4D(kNumBoxes, kOutputImageDim, kOutputImageDim, kChannelSize,
2256                    op_context.op_info.add_outputs());
2257 
2258   // Note this is time [ns, default in Duration in Costs], not bytes;
2259   // whereas memory bandwidth from SetCpuDevice() is 10GB/s.
2260   const int kExpectedMemoryTime =
2261       (kImageDim * kImageDim * 4 +  // input image in float.
2262        kNumBoxes * 4 * 8 / 10 +     // boxes (kNumBoxes x 4) in int64.
2263        kNumBoxes * kOutputImageDim * kOutputImageDim * 4);  // output in float.
2264   // Note that input image and output image has kChannelSize dim, which is 10,
2265   // hence, no need to divide it by 10 (bandwidth).
2266 
2267   {
2268     // Cost of CropAndResize with bilinear interpolation.
2269     AttrValue method;
2270     method.set_s("bilinear");
2271     (*op_context.op_info.mutable_attr())["method"] = method;
2272     int num_ops = 28 * kNumBoxes + 4 * kNumBoxes * kOutputImageDim +
2273                   4 * kNumBoxes * kOutputImageDim * kOutputImageDim +
2274                   3 * kNumBoxes * kOutputImageDim +
2275                   3 * kNumBoxes * kOutputImageDim * kOutputImageDim +
2276                   13 * kOutputElements;
2277     const int expected_compute_time = std::ceil(
2278         num_ops /
2279         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2280 
2281     const auto cost = PredictCosts(op_context);
2282     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2283     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2284     EXPECT_EQ(cost.execution_time,
2285               Costs::Duration(kExpectedMemoryTime + expected_compute_time));
2286     EXPECT_FALSE(cost.inaccurate);
2287     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2288   }
2289 
2290   {
2291     // Cost of CropAndResize when nearest pixel is taken.
2292     AttrValue method;
2293     method.set_s("nearest");
2294     (*op_context.op_info.mutable_attr())["method"] = method;
2295     int num_ops = 28 * kNumBoxes + 4 * kNumBoxes * kOutputImageDim +
2296                   4 * kNumBoxes * kOutputImageDim * kOutputImageDim +
2297                   2 * kNumBoxes * kOutputImageDim * kOutputImageDim +
2298                   kOutputElements;
2299     const int expected_compute_time = std::ceil(
2300         num_ops /
2301         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2302 
2303     const auto cost = PredictCosts(op_context);
2304     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2305     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2306     EXPECT_EQ(cost.execution_time,
2307               Costs::Duration(kExpectedMemoryTime + expected_compute_time));
2308     EXPECT_FALSE(cost.inaccurate);
2309     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2310   }
2311 }
2312 
2313 }  // end namespace grappler
2314 }  // end namespace tensorflow
2315