• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/dnn.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/core/lib/hash/hash.h"
20 #include "tensorflow/stream_executor/lib/stringprintf.h"
21 
22 namespace stream_executor {
23 namespace dnn {
24 
hash() const25 uint64 AlgorithmDesc::hash() const {
26   return ::tensorflow::Hash64Combine(algo_id(), tensor_ops_enabled());
27 }
28 
GetConvolveAlgorithms(bool with_winograd_nonfused,int cc_major,int cc_minor,std::vector<AlgorithmDesc> * out_algorithms)29 bool DnnSupport::GetConvolveAlgorithms(
30     bool with_winograd_nonfused, int cc_major, int cc_minor,
31     std::vector<AlgorithmDesc>* out_algorithms) {
32   return false;
33 }
34 
GetRnnAlgorithms(std::vector<AlgorithmDesc> * out_algorithms)35 bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
36   return false;
37 }
38 
GetConvolveBackwardDataAlgorithms(bool with_winograd_nonfused,int cc_major,int cc_minor,std::vector<AlgorithmDesc> * out_algorithms)39 bool DnnSupport::GetConvolveBackwardDataAlgorithms(
40     bool with_winograd_nonfused, int cc_major, int cc_minor,
41     std::vector<AlgorithmDesc>* out_algorithms) {
42   return false;
43 }
44 
GetConvolveBackwardFilterAlgorithms(bool with_winograd_nonfused,int cc_major,int cc_minor,std::vector<AlgorithmDesc> * out_algorithms)45 bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
46     bool with_winograd_nonfused, int cc_major, int cc_minor,
47     std::vector<AlgorithmDesc>* out_algorithms) {
48   return false;
49 }
50 
QuantizedActivationModeString(QuantizedActivationMode mode)51 string QuantizedActivationModeString(QuantizedActivationMode mode) {
52   switch (mode) {
53     case dnn::QuantizedActivationMode::k8Bit:
54       return "uint8";
55     case dnn::QuantizedActivationMode::k16Bit:
56       return "uint16";
57     case dnn::QuantizedActivationMode::k32Bit:
58       return "int32";
59     default:
60       LOG(FATAL) << "Unknown quantized_activation_mode "
61                  << static_cast<int32>(mode);
62   }
63   return "unknown quantized_activation_mode";
64 }
65 
ActivationModeString(ActivationMode mode)66 string ActivationModeString(ActivationMode mode) {
67   switch (mode) {
68     case ActivationMode::kSigmoid:
69       return "sigmoid";
70     case ActivationMode::kRelu:
71       return "relu";
72     case ActivationMode::kRelu6:
73       return "relu6";
74     case ActivationMode::kReluX:
75       return "reluX";
76     case ActivationMode::kTanh:
77       return "tanh";
78     case ActivationMode::kBandPass:
79       return "bandpass";
80     default:
81       LOG(FATAL) << "Unknown activation_mode " << static_cast<int32>(mode);
82   }
83   return "unknown activation_mode";
84 }
85 
ElementwiseOperationString(ElementwiseOperation op)86 string ElementwiseOperationString(ElementwiseOperation op) {
87   switch (op) {
88     case ElementwiseOperation::kAdd:
89       return "add";
90     case ElementwiseOperation::kMultiply:
91       return "multiply";
92     default:
93       LOG(FATAL) << "Unknown elementwise op " << static_cast<int32>(op);
94   }
95   return "unknown element wise op";
96 }
97 
DataLayoutString(DataLayout layout)98 string DataLayoutString(DataLayout layout) {
99   switch (layout) {
100     case DataLayout::kYXDepthBatch:
101       return "YXDepthBatch";
102     case DataLayout::kYXBatchDepth:
103       return "YXBatchDepth";
104     case DataLayout::kBatchYXDepth:
105       return "BatchYXDepth";
106     case DataLayout::kBatchDepthYX:
107       return "BatchDepthYX";
108     case DataLayout::kBatchDepthYX4:
109       return "BatchDepthYX4";
110     default:
111       LOG(FATAL) << "Unknown data layout " << static_cast<int32>(layout);
112   }
113   return "unknown data layout";
114 }
115 
FilterLayoutString(FilterLayout layout)116 string FilterLayoutString(FilterLayout layout) {
117   switch (layout) {
118     case FilterLayout::kOutputInputYX:
119       return "OutputInputYX";
120     case FilterLayout::kOutputYXInput:
121       return "OutputYXInput";
122     case FilterLayout::kOutputInputYX4:
123       return "OutputInputYX4";
124     case FilterLayout::kInputYXOutput:
125       return "InputYXOutput";
126     case FilterLayout::kYXInputOutput:
127       return "YXInputOutput";
128     default:
129       LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(layout);
130   }
131   return "unknown filter layout";
132 }
133 
PadAlignmentString(PadAlignment alignment)134 string PadAlignmentString(PadAlignment alignment) {
135   switch (alignment) {
136     case PadAlignment::kDefault:
137       return "default";
138     case PadAlignment::kCudnnPadding:
139       return "cuDNN padding";
140     case PadAlignment::kTensorFlowPadding:
141       return "TensorFlow padding";
142   }
143   return "unknown pad alignment";
144 }
145 
operator <<(std::ostream & str,dnn::PadAlignment alignment)146 std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
147   return str << PadAlignmentString(alignment);
148 }
149 
ShortPoolingModeString(PoolingMode mode)150 string ShortPoolingModeString(PoolingMode mode) {
151   switch (mode) {
152     case PoolingMode::kMaximum:
153       return "Max";
154     case PoolingMode::kAverage:
155       return "Avg";
156     default:
157       LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(mode);
158   }
159   return "unknown filter layout";
160 }
161 
GetDimIndices(const DataLayout & layout,const int data_dims)162 std::tuple<int, int, int> GetDimIndices(const DataLayout& layout,
163                                         const int data_dims) {
164   int depth_idx, batch_idx, spatial_idx;
165   switch (layout) {
166     case DataLayout::kYXBatchDepth:
167       depth_idx = data_dims - 1;
168       batch_idx = data_dims - 2;
169       spatial_idx = 0;
170       break;
171 
172     case DataLayout::kYXDepthBatch:
173       depth_idx = data_dims - 2;
174       batch_idx = data_dims - 1;
175       spatial_idx = 0;
176       break;
177 
178     case DataLayout::kBatchYXDepth:
179       depth_idx = data_dims - 1;
180       batch_idx = 0;
181       spatial_idx = 1;
182       break;
183 
184     case DataLayout::kBatchDepthYX:
185     case DataLayout::kBatchDepthYX4:
186       depth_idx = 1;
187       batch_idx = 0;
188       spatial_idx = 2;
189       break;
190 
191     default:
192       LOG(FATAL) << "Unknown layout " << layout;
193   }
194 
195   return std::make_tuple(depth_idx, batch_idx, spatial_idx);
196 }
197 
ReorderDims(const std::vector<int64> & input,const DataLayout & from,const DataLayout & to)198 std::vector<int64> ReorderDims(const std::vector<int64>& input,
199                                const DataLayout& from, const DataLayout& to) {
200   if (from == to) return input;
201 
202   int d_idx_from, b_idx_from, spatial_idx_from;
203   int d_idx_to, b_idx_to, spatial_idx_to;
204 
205   std::tie(d_idx_from, b_idx_from, spatial_idx_from) =
206       GetDimIndices(from, input.size());
207   std::tie(d_idx_to, b_idx_to, spatial_idx_to) =
208       GetDimIndices(to, input.size());
209 
210   std::vector<int64> reordered(input.size());
211   reordered[b_idx_to] = input[b_idx_from];
212   reordered[d_idx_to] = input[d_idx_from];
213 
214   for (size_t i = 0; i < input.size() - 2;
215        i++, spatial_idx_from++, spatial_idx_to++) {
216     reordered[spatial_idx_to] = input[spatial_idx_from];
217   }
218 
219   return reordered;
220 }
221 
222 // -- AlgorithmConfig
223 
ToString() const224 string AlgorithmConfig::ToString() const {
225   AlgorithmDesc::Index algo_id = -1;
226   if (algorithm().has_value()) {
227     algo_id = algorithm()->algo_id();
228   }
229   AlgorithmDesc::Index algo_id_no_scratch = -1;
230   if (algorithm_no_scratch().has_value()) {
231     algo_id_no_scratch = algorithm_no_scratch()->algo_id();
232   }
233   return absl::StrCat(algo_id, ", ", algo_id_no_scratch);
234 }
235 
236 // -- BatchDescriptor
237 
BatchDescriptor(int ndims)238 BatchDescriptor::BatchDescriptor(int ndims)
239     : value_max_(0.0),
240       value_min_(0.0),
241       quantized_activation_mode_(QuantizedActivationMode::k8Bit) {
242   tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
243   set_layout(DataLayout::kYXDepthBatch);
244 }
245 
BatchDescriptor()246 BatchDescriptor::BatchDescriptor() : BatchDescriptor(/*ndims=*/2) {}
247 
full_dims(const DataLayout & layout) const248 std::vector<int64> BatchDescriptor::full_dims(const DataLayout& layout) const {
249   std::vector<int64> bdyx_dims(ndims() + 2);
250   bdyx_dims[0] = count();
251   bdyx_dims[1] = feature_map_count();
252   std::copy(spatial_size().begin(), spatial_size().end(),
253             bdyx_dims.begin() + 2);
254   return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout);
255 }
256 
full_strides(const DataLayout & layout) const257 std::vector<int64> BatchDescriptor::full_strides(
258     const DataLayout& layout) const {
259   if (this->layout() == DataLayout::kBatchDepthYX4) {
260     LOG(FATAL)
261         << "Cannot compute full strides for batch descriptor " << ToString()
262         << ", because its layout is kBatchDepthYX4. In fact, "
263            "cudnnSetTensorNdDescriptor doesn't work for kBatchDepthYX4 at all. "
264            "Use cudnnSetTensor4DDescriptor to set cudnnTensorDescriptor_t "
265            "instead.";
266   }
267   std::vector<int64> phys_dims = full_dims(this->layout());
268   std::vector<int64> phys_strides(phys_dims.size());
269   phys_strides[ndims() + 1] = 1;
270   for (int i = ndims(); i >= 0; i--) {
271     phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
272   }
273   return ReorderDims(phys_strides, this->layout(), layout);
274 }
275 
CloneFrom(const BatchDescriptor & other)276 void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
277   tensor_ = other.tensor_;
278   value_max_ = other.value_max_;
279   value_min_ = other.value_min_;
280   quantized_activation_mode_ = other.quantized_activation_mode_;
281 }
282 
ToString() const283 string BatchDescriptor::ToString() const {
284   string spatial;
285   for (int i = 0; i < ndims(); i++) {
286     port::Appendf(&spatial, "%lld ", spatial_size()[i]);
287   }
288   return port::Printf(
289       "{count: %lld feature_map_count: %lld spatial: %s "
290       "value_min: %f value_max: %f layout: %s}",
291       count(), feature_map_count(), spatial.c_str(), value_min_, value_max_,
292       DataLayoutString(layout()).c_str());
293 }
294 
ToShortString() const295 string BatchDescriptor::ToShortString() const {
296   // All the constituent strings are less than 15 characters, so the
297   // small string optimization ensures that there will be at most one
298   // heap memory allocation.
299   string depth = absl::StrCat("d", feature_map_count());
300   string batch = absl::StrCat("b", count());
301 
302   string spatial = "s";
303   for (int i = 0; i < ndims(); i++) {
304     port::Appendf(&spatial, "%lld ", spatial_size()[i]);
305   }
306 
307   string suffix;
308   if (value_min() != value_max()) {
309     absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
310   }
311   if (quantized_activation_mode() == QuantizedActivationMode::k16Bit) {
312     suffix += "_16bit";
313   }
314 
315   switch (layout()) {
316     case DataLayout::kYXDepthBatch:
317       return absl::StrCat(spatial, depth, batch, suffix);
318     case DataLayout::kYXBatchDepth:
319       return absl::StrCat(spatial, batch, depth, suffix);
320     case DataLayout::kBatchYXDepth:
321       return absl::StrCat(batch, spatial, depth, suffix);
322     case DataLayout::kBatchDepthYX:
323       return absl::StrCat(batch, depth, spatial, suffix);
324     case DataLayout::kBatchDepthYX4:
325       return absl::StrCat(batch, depth, spatial, suffix, "(VECT_C)");
326     default:
327       LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
328       return "";  // Avoid return warning (unreachable)
329   }
330 }
331 
NodesPerFeatureMap() const332 int64 BatchDescriptor::NodesPerFeatureMap() const {
333   int64 ret = 1;
334   for (int i = 0; i < ndims(); i++) {
335     ret *= spatial_size()[i];
336   }
337   return ret;
338 }
339 
NodesAcrossFeatureMaps() const340 int64 BatchDescriptor::NodesAcrossFeatureMaps() const {
341   return NodesPerFeatureMap() * feature_map_count();
342 }
343 
ElementCount() const344 int64 BatchDescriptor::ElementCount() const {
345   return count() * feature_map_count() * NodesPerFeatureMap();
346 }
347 
FullyConnectedWeightCount(const BatchDescriptor & input,const BatchDescriptor & output)348 int64 BatchDescriptor::FullyConnectedWeightCount(
349     const BatchDescriptor& input, const BatchDescriptor& output) {
350   return input.NodesAcrossFeatureMaps() * output.NodesAcrossFeatureMaps();
351 }
352 
FullyConnectedBiasCount(const BatchDescriptor & output)353 int64 BatchDescriptor::FullyConnectedBiasCount(const BatchDescriptor& output) {
354   return output.NodesAcrossFeatureMaps();
355 }
356 
DepthConcatenateOutputDescriptor(port::ArraySlice<dnn::BatchDescriptor> inputs)357 BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
358     port::ArraySlice<dnn::BatchDescriptor> inputs) {
359   if (inputs.empty()) {
360     return BatchDescriptor();
361   }
362   int feature_map_count = 0;
363   for (const auto& dimensions : inputs) {
364     feature_map_count += dimensions.feature_map_count();
365   }
366   BatchDescriptor output = inputs[0];
367   output.set_feature_map_count(feature_map_count);
368   return output;
369 }
370 
ToProto(DataType data_type) const371 TensorDescriptorProto BatchDescriptor::ToProto(DataType data_type) const {
372   CHECK_EQ(0.0, value_max_);
373   CHECK_EQ(0.0, value_min_);
374   CHECK(quantized_activation_mode_ == QuantizedActivationMode::k8Bit);
375 
376   TensorDescriptorProto ret = tensor_;
377   ret.set_data_type(data_type);
378   return ret;
379 }
380 
381 // -- FilterDescriptor
382 
FilterDescriptor(int ndims)383 FilterDescriptor::FilterDescriptor(int ndims) {
384   tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
385   set_layout(FilterLayout::kOutputInputYX);
386 }
387 
FilterDescriptor()388 FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {}
389 
~FilterDescriptor()390 FilterDescriptor::~FilterDescriptor() {}
391 
CloneFrom(const FilterDescriptor & other)392 void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
393   tensor_ = other.tensor_;
394 }
395 
ToString() const396 string FilterDescriptor::ToString() const {
397   string desc = port::Printf(
398       "{output_feature_map_count: %lld input_feature_map_count: %lld "
399       "layout: %s shape: ",
400       output_feature_map_count(), input_feature_map_count(),
401       FilterLayoutString(layout()).c_str());
402   for (int i = 0; i < ndims(); i++) {
403     port::Appendf(&desc, "%lld ", input_filter_dims()[i]);
404   }
405   absl::StrAppend(&desc, "}");
406 
407   return desc;
408 }
409 
ToShortString() const410 string FilterDescriptor::ToShortString() const {
411   // All the constituent strings are less than 15 characters, so the
412   // small string optimization ensures that there will be at most one
413   // heap memory allocation.
414   string od = absl::StrCat("od", output_feature_map_count());
415   string id = absl::StrCat("id", input_feature_map_count());
416 
417   string spatial = "s";
418   for (int i = 0; i < ndims(); i++) {
419     port::Appendf(&spatial, "%lld ", input_filter_dims()[i]);
420   }
421 
422   switch (layout()) {
423     case FilterLayout::kOutputInputYX:
424       return absl::StrCat(od, id, spatial);
425     case FilterLayout::kOutputYXInput:
426       return absl::StrCat(od, spatial, id);
427     case FilterLayout::kOutputInputYX4:
428       return absl::StrCat(od, id, spatial, "(VECT_C)");
429     case FilterLayout::kInputYXOutput:
430       return absl::StrCat(id, spatial, od);
431     case FilterLayout::kYXInputOutput:
432       return absl::StrCat(spatial, id, od);
433     default:
434       LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
435       return "";  // Avoid return warning (unreachable)
436   }
437 }
438 
ComputeWeightCount() const439 int64 FilterDescriptor::ComputeWeightCount() const {
440   int64 ret = output_feature_map_count() * input_feature_map_count();
441   for (int i = 0; i < ndims(); i++) {
442     ret *= input_filter_dims()[i];
443   }
444   return ret;
445 }
446 
ToProto(DataType data_type) const447 TensorDescriptorProto FilterDescriptor::ToProto(DataType data_type) const {
448   TensorDescriptorProto ret = tensor_;
449   ret.set_data_type(data_type);
450   return ret;
451 }
452 
453 // -- ConvolutionDescriptor
454 
ConvolutionDescriptor(int ndims)455 ConvolutionDescriptor::ConvolutionDescriptor(int ndims) {
456   proto_.mutable_paddings()->Resize(ndims, 0);
457   proto_.mutable_strides()->Resize(ndims, 1);
458   proto_.mutable_dilations()->Resize(ndims, 1);
459   proto_.set_group_count(1);
460   proto_.set_convolution_mode(ConvolutionMode::CROSS_CORRELATION);
461 }
462 
ConvolutionDescriptor()463 ConvolutionDescriptor::ConvolutionDescriptor()
464     : ConvolutionDescriptor(/*ndims=*/2) {}
465 
~ConvolutionDescriptor()466 ConvolutionDescriptor::~ConvolutionDescriptor() {}
467 
ToString() const468 string ConvolutionDescriptor::ToString() const {
469   string padding;
470   string strides;
471   string dilations;
472   for (int i = 0; i < ndims(); i++) {
473     port::Appendf(&padding, "%lld ", this->padding()[i]);
474     port::Appendf(&strides, "%lld ", this->strides()[i]);
475     port::Appendf(&dilations, "%lld ", this->dilations()[i]);
476   }
477 
478   return port::Printf(
479       "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: "
480       "%s}",
481       padding.c_str(), PadAlignmentString(pad_alignment()).c_str(),
482       strides.c_str(), dilations.c_str());
483 }
484 
ToShortString() const485 string ConvolutionDescriptor::ToShortString() const {
486   string desc;
487   for (int i = 0; i < ndims(); i++) {
488     if (i > 0) port::Appendf(&desc, "_");
489     port::Appendf(&desc, "p%d:%lld", i, padding()[i]);
490   }
491   for (int i = 0; i < ndims(); i++) {
492     port::Appendf(&desc, "_s%d:%lld", i, strides()[i]);
493   }
494   for (int i = 0; i < ndims(); i++) {
495     port::Appendf(&desc, "_d%d:%lld", i, dilations()[i]);
496   }
497   return desc;
498 }
499 
500 // -- PoolingDescriptor
501 
PoolingDescriptor(int ndims)502 PoolingDescriptor::PoolingDescriptor(int ndims)
503     : mode_(dnn::PoolingMode::kMaximum),
504       ndims_(ndims),
505       propagate_nans_(false),
506       window_(ndims, 0),
507       padding_(ndims, 0),
508       strides_(ndims, 1) {}
509 
PoolingDescriptor()510 PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {}
511 
CloneFrom(const PoolingDescriptor & other)512 void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
513   mode_ = other.mode_;
514   ndims_ = other.ndims_;
515   window_ = other.window_;
516   padding_ = other.padding_;
517   strides_ = other.strides_;
518   propagate_nans_ = other.propagate_nans_;
519 }
520 
ToString() const521 string PoolingDescriptor::ToString() const {
522   const char* mode_string =
523       mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
524 
525   string window, strides, padding;
526   for (int i = 0; i < ndims_; i++) {
527     port::Appendf(&window, "%lld ", window_[i]);
528     port::Appendf(&strides, "%lld ", strides_[i]);
529     port::Appendf(&padding, "%lld", padding_[i]);
530   }
531 
532   const char* propagate_string = propagate_nans_ ? "Yes" : "No";
533 
534   return port::Printf(
535       "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}",
536       mode_string, window.c_str(), strides.c_str(), padding.c_str(),
537       propagate_string);
538 }
539 
ToShortString() const540 string PoolingDescriptor::ToShortString() const {
541   string window, strides, padding;
542   for (int i = 0; i < ndims_; i++) {
543     port::Appendf(&window, "_w%d:%lld", i, window_[i]);
544     port::Appendf(&strides, "_s%d:%lld", i, strides_[i]);
545     port::Appendf(&padding, "_p%d:%lld", i, padding_[i]);
546   }
547   return absl::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
548                       window, strides, padding,
549                       propagate_nans_ ? "propagate_nans" : "ignore_nans");
550 }
551 
552 // -- NormalizeDescriptor
553 
NormalizeDescriptor()554 NormalizeDescriptor::NormalizeDescriptor()
555     : bias_(0.0),
556       range_(0),
557       alpha_(0.0),
558       beta_(0.0),
559       wrap_around_(false),
560       segment_size_(0) {}
561 
CloneFrom(const NormalizeDescriptor & other)562 void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
563   bias_ = other.bias_;
564   range_ = other.range_;
565   alpha_ = other.alpha_;
566   beta_ = other.beta_;
567   wrap_around_ = other.wrap_around_;
568   segment_size_ = other.segment_size_;
569 }
570 
ToString() const571 string NormalizeDescriptor::ToString() const {
572   return port::Printf(
573       "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
574       "segment_size: %d}",
575       bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
576 }
577 
ToShortString() const578 string NormalizeDescriptor::ToShortString() const {
579   return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
580                       "_beta:", beta_, "_wrap:", wrap_around_,
581                       "_size:", segment_size_);
582 }
583 
IsStatusOk(const port::Status & status,bool report_error)584 bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
585   if (status.ok()) {
586     return true;
587   }
588   if (report_error) {
589     LOG(ERROR) << status.error_message();
590   }
591   return false;
592 }
593 
594 }  // namespace dnn
595 }  // namespace stream_executor
596