• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/dnn.h"
17 
18 #include "absl/hash/hash.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21 
22 namespace stream_executor {
23 namespace dnn {
24 
// Out-of-line definitions for the static constexpr `value` members declared
// in dnn.h. Required for ODR-uses (e.g. taking the address or binding a
// reference) under pre-C++17 rules.
constexpr DataType ToDataType<float>::value;
constexpr DataType ToDataType<double>::value;
constexpr DataType ToDataType<Eigen::half>::value;
constexpr DataType ToDataType<int8>::value;
constexpr DataType ToDataType<int32>::value;
30 
hash() const31 uint64 AlgorithmDesc::hash() const {
32   auto p = std::make_pair(algo_id(), tensor_ops_enabled());
33   return absl::Hash<decltype(p)>()(p);
34 }
35 
ToString() const36 string AlgorithmDesc::ToString() const {
37   if (tensor_ops_enabled()) {
38     return absl::StrCat(algo_id(), "#TC");
39   } else {
40     return absl::StrCat(algo_id());
41   }
42 }
43 
// Default implementation: this backend enumerates no forward-convolution
// algorithms. Concrete DnnSupport subclasses (e.g. cuDNN) override this.
// Returns false to signal "not supported".
bool DnnSupport::GetConvolveAlgorithms(
    bool with_winograd_nonfused, int cc_major, int cc_minor,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
49 
// Default implementation: no MIOpen (ROCm) convolution algorithms are
// available. The ROCm backend overrides this; all parameters are unused here.
// Returns false to signal "not supported".
bool DnnSupport::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind /*kind*/, Stream* /*stream*/,
    dnn::DataType /*element_type*/,
    const dnn::BatchDescriptor& /*input_descriptor*/,
    const dnn::FilterDescriptor& /*filter_descriptor*/,
    const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
    const dnn::BatchDescriptor& /*output_descriptor*/,
    std::vector<ProfileResult>* /*out_algorithms*/) {
  return false;
}
60 
// Default implementation: no RNN algorithms are available; overridden by
// concrete backends. Returns false to signal "not supported".
bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
64 
// Default implementation: no backward-data convolution algorithms are
// available; overridden by concrete backends. Returns false ("unsupported").
bool DnnSupport::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused, int cc_major, int cc_minor,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
70 
// Default implementation: no backward-filter convolution algorithms are
// available; overridden by concrete backends. Returns false ("unsupported").
bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused, int cc_major, int cc_minor,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
76 
QuantizedActivationModeString(QuantizedActivationMode mode)77 string QuantizedActivationModeString(QuantizedActivationMode mode) {
78   switch (mode) {
79     case dnn::QuantizedActivationMode::k8Bit:
80       return "uint8";
81     case dnn::QuantizedActivationMode::k16Bit:
82       return "uint16";
83     case dnn::QuantizedActivationMode::k32Bit:
84       return "int32";
85     default:
86       LOG(FATAL) << "Unknown quantized_activation_mode "
87                  << static_cast<int32>(mode);
88   }
89   return "unknown quantized_activation_mode";
90 }
91 
// Human-readable name for an ActivationMode; aborts via LOG(FATAL) on an
// unrecognized value.
string ActivationModeString(ActivationMode mode) {
  switch (mode) {
    case ActivationMode::kSigmoid:
      return "sigmoid";
    case ActivationMode::kRelu:
      return "relu";
    case ActivationMode::kRelu6:
      return "relu6";
    case ActivationMode::kReluX:
      return "reluX";
    case ActivationMode::kTanh:
      return "tanh";
    case ActivationMode::kBandPass:
      return "bandpass";
    default:
      LOG(FATAL) << "Unknown activation_mode " << static_cast<int32>(mode);
  }
  // Unreachable; present to satisfy the compiler's return-path analysis.
  return "unknown activation_mode";
}
111 
ElementwiseOperationString(ElementwiseOperation op)112 string ElementwiseOperationString(ElementwiseOperation op) {
113   switch (op) {
114     case ElementwiseOperation::kAdd:
115       return "add";
116     case ElementwiseOperation::kMultiply:
117       return "multiply";
118     default:
119       LOG(FATAL) << "Unknown elementwise op " << static_cast<int32>(op);
120   }
121   return "unknown element wise op";
122 }
123 
// Human-readable name for a DataLayout; aborts via LOG(FATAL) on an
// unrecognized value.
string DataLayoutString(DataLayout layout) {
  switch (layout) {
    case DataLayout::kYXDepthBatch:
      return "YXDepthBatch";
    case DataLayout::kYXBatchDepth:
      return "YXBatchDepth";
    case DataLayout::kBatchYXDepth:
      return "BatchYXDepth";
    case DataLayout::kBatchDepthYX:
      return "BatchDepthYX";
    case DataLayout::kBatchDepthYX4:
      return "BatchDepthYX4";
    default:
      LOG(FATAL) << "Unknown data layout " << static_cast<int32>(layout);
  }
  // Unreachable; present to satisfy the compiler's return-path analysis.
  return "unknown data layout";
}
141 
// Human-readable name for a FilterLayout; aborts via LOG(FATAL) on an
// unrecognized value.
string FilterLayoutString(FilterLayout layout) {
  switch (layout) {
    case FilterLayout::kOutputInputYX:
      return "OutputInputYX";
    case FilterLayout::kOutputYXInput:
      return "OutputYXInput";
    case FilterLayout::kOutputInputYX4:
      return "OutputInputYX4";
    case FilterLayout::kInputYXOutput:
      return "InputYXOutput";
    case FilterLayout::kYXInputOutput:
      return "YXInputOutput";
    default:
      LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(layout);
  }
  // Unreachable; present to satisfy the compiler's return-path analysis.
  return "unknown filter layout";
}
159 
// Human-readable name for a PadAlignment. The switch intentionally has no
// default so the compiler can warn when a new enumerator is added; the
// trailing return covers out-of-range values.
string PadAlignmentString(PadAlignment alignment) {
  switch (alignment) {
    case PadAlignment::kDefault:
      return "default";
    case PadAlignment::kCudnnPadding:
      return "cuDNN padding";
    case PadAlignment::kTensorFlowPadding:
      return "TensorFlow padding";
  }
  return "unknown pad alignment";
}
171 
operator <<(std::ostream & str,dnn::PadAlignment alignment)172 std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
173   return str << PadAlignmentString(alignment);
174 }
175 
ShortPoolingModeString(PoolingMode mode)176 string ShortPoolingModeString(PoolingMode mode) {
177   switch (mode) {
178     case PoolingMode::kMaximum:
179       return "Max";
180     case PoolingMode::kAverage:
181       return "Avg";
182     default:
183       LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(mode);
184   }
185   return "unknown filter layout";
186 }
187 
// Maps `layout` to the positions of the (depth, batch, first-spatial)
// dimensions within a rank-`data_dims` dimension vector stored in that
// layout. Returned as a (depth_idx, batch_idx, spatial_idx) tuple, where
// spatial_idx is the index of the FIRST spatial dimension; the remaining
// spatial dimensions follow contiguously. Aborts on an unknown layout.
std::tuple<int, int, int> GetDimIndices(const DataLayout& layout,
                                        const int data_dims) {
  int depth_idx, batch_idx, spatial_idx;
  switch (layout) {
    case DataLayout::kYXBatchDepth:
      // Spatial dims first, then batch, then depth.
      depth_idx = data_dims - 1;
      batch_idx = data_dims - 2;
      spatial_idx = 0;
      break;

    case DataLayout::kYXDepthBatch:
      // Spatial dims first, then depth, then batch.
      depth_idx = data_dims - 2;
      batch_idx = data_dims - 1;
      spatial_idx = 0;
      break;

    case DataLayout::kBatchYXDepth:
      // Batch first, spatial dims, depth last.
      depth_idx = data_dims - 1;
      batch_idx = 0;
      spatial_idx = 1;
      break;

    case DataLayout::kBatchDepthYX:
    case DataLayout::kBatchDepthYX4:
      // NCHW-style: batch, depth, then spatial dims (the YX4 variant shares
      // the same logical ordering).
      depth_idx = 1;
      batch_idx = 0;
      spatial_idx = 2;
      break;

    default:
      LOG(FATAL) << "Unknown layout " << layout;
  }

  return std::make_tuple(depth_idx, batch_idx, spatial_idx);
}
223 
// Permutes `input` (a full dimension vector, batch + depth + spatial dims)
// from layout `from` into layout `to`. Returns `input` unchanged when the
// layouts already match. Assumes input.size() >= 2 (batch and depth plus
// zero or more spatial dims).
std::vector<int64> ReorderDims(const std::vector<int64>& input,
                               const DataLayout& from, const DataLayout& to) {
  if (from == to) return input;

  int d_idx_from, b_idx_from, spatial_idx_from;
  int d_idx_to, b_idx_to, spatial_idx_to;

  // Locate batch/depth/first-spatial positions in each layout.
  std::tie(d_idx_from, b_idx_from, spatial_idx_from) =
      GetDimIndices(from, input.size());
  std::tie(d_idx_to, b_idx_to, spatial_idx_to) =
      GetDimIndices(to, input.size());

  std::vector<int64> reordered(input.size());
  reordered[b_idx_to] = input[b_idx_from];
  reordered[d_idx_to] = input[d_idx_from];

  // Copy the (input.size() - 2) spatial dims; they are contiguous in both
  // layouts, starting at the respective spatial_idx.
  for (size_t i = 0; i < input.size() - 2;
       i++, spatial_idx_from++, spatial_idx_to++) {
    reordered[spatial_idx_to] = input[spatial_idx_from];
  }

  return reordered;
}
247 
248 // -- AlgorithmConfig
249 
ToString() const250 string AlgorithmConfig::ToString() const {
251   string algo = "none";
252   if (algorithm().has_value()) {
253     algo = algorithm()->ToString();
254   }
255   string algo_no_scratch = "none";
256   if (algorithm_no_scratch().has_value()) {
257     algo_no_scratch = algorithm_no_scratch()->ToString();
258   }
259   return absl::StrCat(algo, ", ", algo_no_scratch);
260 }
261 
262 // -- BatchDescriptor
263 
// Constructs a batch descriptor with `ndims` spatial dimensions (the tensor
// rank is ndims + 2 for the batch and depth dims), all dimensions zeroed,
// no quantization range, and the kYXDepthBatch layout as the default.
BatchDescriptor::BatchDescriptor(int ndims)
    : value_max_(0.0),
      value_min_(0.0),
      quantized_activation_mode_(QuantizedActivationMode::k8Bit) {
  tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
  set_layout(DataLayout::kYXDepthBatch);
}

// Default: a 2-D (e.g. HxW image) batch descriptor.
BatchDescriptor::BatchDescriptor() : BatchDescriptor(/*ndims=*/2) {}
273 
// Returns the full dimension vector (batch, depth, spatial...) permuted into
// `layout`. Built in kBatchDepthYX order first, then reordered.
std::vector<int64> BatchDescriptor::full_dims(const DataLayout& layout) const {
  std::vector<int64> bdyx_dims(ndims() + 2);
  bdyx_dims[0] = count();              // batch
  bdyx_dims[1] = feature_map_count();  // depth
  std::copy(spatial_size().begin(), spatial_size().end(),
            bdyx_dims.begin() + 2);
  return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout);
}
282 
// Returns the element strides for each dimension, expressed in `layout`
// order, assuming a dense tensor stored in this descriptor's own layout.
// Not meaningful for vectorized kBatchDepthYX4 storage, hence the fatal
// check below.
std::vector<int64> BatchDescriptor::full_strides(
    const DataLayout& layout) const {
  if (this->layout() == DataLayout::kBatchDepthYX4) {
    LOG(FATAL)
        << "Cannot compute full strides for batch descriptor " << ToString()
        << ", because its layout is kBatchDepthYX4. In fact, "
           "cudnnSetTensorNdDescriptor doesn't work for kBatchDepthYX4 at all. "
           "Use cudnnSetTensor4DDescriptor to set cudnnTensorDescriptor_t "
           "instead.";
  }
  std::vector<int64> phys_dims = full_dims(this->layout());
  std::vector<int64> phys_strides(phys_dims.size());
  // Innermost dimension has stride 1; each outer stride is the product of
  // all inner dimension sizes.
  phys_strides[ndims() + 1] = 1;
  for (int i = ndims(); i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}
301 
// Copies all state (tensor proto, quantization range and mode) from `other`.
void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
  tensor_ = other.tensor_;
  value_max_ = other.value_max_;
  value_min_ = other.value_min_;
  quantized_activation_mode_ = other.quantized_activation_mode_;
}
308 
// Verbose debug representation, e.g.
// "{count: N feature_map_count: D spatial: H W  value_min: ... layout: ...}".
string BatchDescriptor::ToString() const {
  string spatial;
  for (int i = 0; i < ndims(); i++) {
    // Each spatial dim is followed by a space (trailing space intended).
    absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
  }
  return absl::StrFormat(
      "{count: %d feature_map_count: %d spatial: %s "
      "value_min: %f value_max: %f layout: %s}",
      count(), feature_map_count(), spatial, value_min_, value_max_,
      DataLayoutString(layout()));
}
320 
// Compact representation whose field order mirrors the layout, e.g.
// "b32d64s7 7 " for kBatchDepthYX. Suffixes encode the quantization range
// ("[min;max]"), 16-bit activation mode ("_16bit"), and vectorized layout
// ("(VECT_C)").
string BatchDescriptor::ToShortString() const {
  // All the constituent strings are less than 15 characters, so the
  // small string optimization ensures that there will be at most one
  // heap memory allocation.
  string depth = absl::StrCat("d", feature_map_count());
  string batch = absl::StrCat("b", count());

  string spatial = "s";
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
  }

  string suffix;
  if (value_min() != value_max()) {
    // Only quantized descriptors carry a non-degenerate value range.
    absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
  }
  if (quantized_activation_mode() == QuantizedActivationMode::k16Bit) {
    suffix += "_16bit";
  }

  // Emit the components in the same order the layout stores them.
  switch (layout()) {
    case DataLayout::kYXDepthBatch:
      return absl::StrCat(spatial, depth, batch, suffix);
    case DataLayout::kYXBatchDepth:
      return absl::StrCat(spatial, batch, depth, suffix);
    case DataLayout::kBatchYXDepth:
      return absl::StrCat(batch, spatial, depth, suffix);
    case DataLayout::kBatchDepthYX:
      return absl::StrCat(batch, depth, spatial, suffix);
    case DataLayout::kBatchDepthYX4:
      return absl::StrCat(batch, depth, spatial, suffix, "(VECT_C)");
    default:
      LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
      return "";  // Avoid return warning (unreachable)
  }
}
357 
NodesPerFeatureMap() const358 int64 BatchDescriptor::NodesPerFeatureMap() const {
359   int64 ret = 1;
360   for (int i = 0; i < ndims(); i++) {
361     ret *= spatial_size()[i];
362   }
363   return ret;
364 }
365 
// Element count of one batch item: spatial size times feature-map count.
int64 BatchDescriptor::NodesAcrossFeatureMaps() const {
  return NodesPerFeatureMap() * feature_map_count();
}

// Total element count of the whole batch.
int64 BatchDescriptor::ElementCount() const {
  return count() * feature_map_count() * NodesPerFeatureMap();
}
373 
// Weight count of a dense layer mapping `input` to `output`: every input
// node connects to every output node (per batch item).
int64 BatchDescriptor::FullyConnectedWeightCount(
    const BatchDescriptor& input, const BatchDescriptor& output) {
  return input.NodesAcrossFeatureMaps() * output.NodesAcrossFeatureMaps();
}

// Bias count of a dense layer: one bias per output node.
int64 BatchDescriptor::FullyConnectedBiasCount(const BatchDescriptor& output) {
  return output.NodesAcrossFeatureMaps();
}
382 
DepthConcatenateOutputDescriptor(port::ArraySlice<dnn::BatchDescriptor> inputs)383 BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
384     port::ArraySlice<dnn::BatchDescriptor> inputs) {
385   if (inputs.empty()) {
386     return BatchDescriptor();
387   }
388   int feature_map_count = 0;
389   for (const auto& dimensions : inputs) {
390     feature_map_count += dimensions.feature_map_count();
391   }
392   BatchDescriptor output = inputs[0];
393   output.set_feature_map_count(feature_map_count);
394   return output;
395 }
396 
// Serializes the descriptor to a TensorDescriptorProto tagged with
// `data_type`. Quantization state (value range, non-default activation
// mode) is not representable in the proto, so its absence is CHECKed.
TensorDescriptorProto BatchDescriptor::ToProto(DataType data_type) const {
  CHECK_EQ(0.0, value_max_);
  CHECK_EQ(0.0, value_min_);
  CHECK(quantized_activation_mode_ == QuantizedActivationMode::k8Bit);

  TensorDescriptorProto ret = tensor_;
  ret.set_data_type(data_type);
  return ret;
}
406 
407 // -- FilterDescriptor
408 
// Constructs a filter descriptor with `ndims` spatial dimensions (rank is
// ndims + 2 for the output/input feature-map dims), zeroed dimensions, and
// kOutputInputYX as the default layout.
FilterDescriptor::FilterDescriptor(int ndims) {
  tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
  set_layout(FilterLayout::kOutputInputYX);
}

// Default: a 2-D filter.
FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {}

FilterDescriptor::~FilterDescriptor() {}
417 
// Copies all state (the tensor proto) from `other`.
void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
  tensor_ = other.tensor_;
}
421 
ToString() const422 string FilterDescriptor::ToString() const {
423   string desc = absl::StrFormat(
424       "{output_feature_map_count: %d input_feature_map_count: %d "
425       "layout: %s shape: ",
426       output_feature_map_count(), input_feature_map_count(),
427       FilterLayoutString(layout()));
428   for (int i = 0; i < ndims(); i++) {
429     absl::StrAppendFormat(&desc, "%d ", input_filter_dims()[i]);
430   }
431   absl::StrAppend(&desc, "}");
432 
433   return desc;
434 }
435 
// Compact representation whose field order mirrors the layout, e.g.
// "od64id32s3 3 " for kOutputInputYX; "(VECT_C)" marks the vectorized
// layout.
string FilterDescriptor::ToShortString() const {
  // All the constituent strings are less than 15 characters, so the
  // small string optimization ensures that there will be at most one
  // heap memory allocation.
  string od = absl::StrCat("od", output_feature_map_count());
  string id = absl::StrCat("id", input_feature_map_count());

  string spatial = "s";
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]);
  }

  // Emit the components in the same order the layout stores them.
  switch (layout()) {
    case FilterLayout::kOutputInputYX:
      return absl::StrCat(od, id, spatial);
    case FilterLayout::kOutputYXInput:
      return absl::StrCat(od, spatial, id);
    case FilterLayout::kOutputInputYX4:
      return absl::StrCat(od, id, spatial, "(VECT_C)");
    case FilterLayout::kInputYXOutput:
      return absl::StrCat(id, spatial, od);
    case FilterLayout::kYXInputOutput:
      return absl::StrCat(spatial, id, od);
    default:
      LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
      return "";  // Avoid return warning (unreachable)
  }
}
464 
// Total weight count: output maps * input maps * product of the spatial
// filter dimensions.
int64 FilterDescriptor::ComputeWeightCount() const {
  int64 ret = output_feature_map_count() * input_feature_map_count();
  for (int i = 0; i < ndims(); i++) {
    ret *= input_filter_dims()[i];
  }
  return ret;
}
472 
// Serializes the descriptor to a TensorDescriptorProto tagged with
// `data_type`.
TensorDescriptorProto FilterDescriptor::ToProto(DataType data_type) const {
  TensorDescriptorProto ret = tensor_;
  ret.set_data_type(data_type);
  return ret;
}
478 
479 // -- ConvolutionDescriptor
480 
// Constructs a convolution descriptor for `ndims` spatial dimensions with
// zero padding, unit strides/dilations, a single group, and
// cross-correlation mode (the conventional DNN "convolution").
ConvolutionDescriptor::ConvolutionDescriptor(int ndims) {
  proto_.mutable_paddings()->Resize(ndims, 0);
  proto_.mutable_strides()->Resize(ndims, 1);
  proto_.mutable_dilations()->Resize(ndims, 1);
  proto_.set_group_count(1);
  proto_.set_convolution_mode(ConvolutionMode::CROSS_CORRELATION);
}

// Default: a 2-D convolution.
ConvolutionDescriptor::ConvolutionDescriptor()
    : ConvolutionDescriptor(/*ndims=*/2) {}

ConvolutionDescriptor::~ConvolutionDescriptor() {}
493 
// Verbose debug representation: per-dimension padding, strides, and
// dilation rates (each value followed by a space), plus pad alignment.
string ConvolutionDescriptor::ToString() const {
  string padding;
  string strides;
  string dilations;
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&padding, "%d ", this->padding()[i]);
    absl::StrAppendFormat(&strides, "%d ", this->strides()[i]);
    absl::StrAppendFormat(&dilations, "%d ", this->dilations()[i]);
  }

  return absl::StrFormat(
      "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: "
      "%s}",
      padding, PadAlignmentString(pad_alignment()), strides, dilations);
}
509 
// Compact representation: "_"-separated per-dimension entries, padding
// first ("p<i>:<v>"), then strides ("_s<i>:<v>"), then dilations
// ("_d<i>:<v>").
string ConvolutionDescriptor::ToShortString() const {
  string desc;
  for (int i = 0; i < ndims(); i++) {
    if (i > 0) absl::StrAppend(&desc, "_");
    absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]);
  }
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&desc, "_s%d:%d", i, strides()[i]);
  }
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&desc, "_d%d:%d", i, dilations()[i]);
  }
  return desc;
}
524 
525 // -- PoolingDescriptor
526 
// Constructs a pooling descriptor for `ndims` spatial dimensions:
// max-pooling by default, zero-sized windows, zero padding, unit strides,
// and NaN propagation disabled.
PoolingDescriptor::PoolingDescriptor(int ndims)
    : mode_(dnn::PoolingMode::kMaximum),
      ndims_(ndims),
      propagate_nans_(false),
      window_(ndims, 0),
      padding_(ndims, 0),
      strides_(ndims, 1) {}

// Default: 2-D pooling.
PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {}
536 
// Copies all pooling state (mode, rank, window, padding, strides, NaN
// propagation) from `other`.
void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
  mode_ = other.mode_;
  ndims_ = other.ndims_;
  window_ = other.window_;
  padding_ = other.padding_;
  strides_ = other.strides_;
  propagate_nans_ = other.propagate_nans_;
}
545 
// Verbose debug representation of the pooling configuration.
string PoolingDescriptor::ToString() const {
  const char* mode_string =
      mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";

  string window, strides, padding;
  for (int i = 0; i < ndims_; i++) {
    absl::StrAppendFormat(&window, "%d ", window_[i]);
    absl::StrAppendFormat(&strides, "%d ", strides_[i]);
    // NOTE(review): padding uses "%d" (no trailing space) unlike window and
    // strides — looks like an inadvertent inconsistency, but preserved
    // since debug-string output may be matched elsewhere.
    absl::StrAppendFormat(&padding, "%d", padding_[i]);
  }

  const char* propagate_string = propagate_nans_ ? "Yes" : "No";

  return absl::StrFormat(
      "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}",
      mode_string, window, strides, padding, propagate_string);
}
563 
// Compact representation: mode ("max"/"avg"), then per-dimension window
// ("_w<i>:<v>"), strides ("_s<i>:<v>"), padding ("_p<i>:<v>"), and the NaN
// propagation setting.
string PoolingDescriptor::ToShortString() const {
  string window, strides, padding;
  for (int i = 0; i < ndims_; i++) {
    absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]);
    absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]);
    absl::StrAppendFormat(&padding, "_p%d:%d", i, padding_[i]);
  }
  return absl::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
                      window, strides, padding,
                      propagate_nans_ ? "propagate_nans" : "ignore_nans");
}
575 
576 // -- NormalizeDescriptor
577 
// Constructs a normalization (LRN-style) descriptor with all parameters
// zeroed and wrap-around disabled; callers configure via setters.
NormalizeDescriptor::NormalizeDescriptor()
    : bias_(0.0),
      range_(0),
      alpha_(0.0),
      beta_(0.0),
      wrap_around_(false),
      segment_size_(0) {}
585 
// Copies all normalization parameters from `other`.
void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
  bias_ = other.bias_;
  range_ = other.range_;
  alpha_ = other.alpha_;
  beta_ = other.beta_;
  wrap_around_ = other.wrap_around_;
  segment_size_ = other.segment_size_;
}
594 
// Verbose debug representation of the normalization parameters
// (wrap_around is rendered as 0/1 via %d).
string NormalizeDescriptor::ToString() const {
  return absl::StrFormat(
      "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
      "segment_size: %d}",
      bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
}
601 
// Compact "_"-separated key:value representation of the parameters.
string NormalizeDescriptor::ToShortString() const {
  return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
                      "_beta:", beta_, "_wrap:", wrap_around_,
                      "_size:", segment_size_);
}
607 
IsStatusOk(const port::Status & status,bool report_error)608 bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
609   if (status.ok()) {
610     return true;
611   }
612   if (report_error) {
613     LOG(ERROR) << status.error_message();
614   }
615   return false;
616 }
617 
// Default implementation: CTC loss is not supported by this backend;
// concrete DnnSupport subclasses override it. Always returns an
// UnimplementedError without touching any of the buffers.
port::Status DnnSupport::DoCtcLoss(Stream* stream, dnn::DataType element_type,
                                   const RnnStateTensorDescriptor& probs_desc,
                                   const DeviceMemoryBase probs_data,
                                   absl::Span<const int> labels_data,
                                   absl::Span<const int> labels_lengths_data,
                                   absl::Span<const int> input_lengths_data,
                                   DeviceMemoryBase costs_data,
                                   const RnnStateTensorDescriptor& grads_desc,
                                   DeviceMemoryBase grads_data,
                                   DeviceMemory<uint8> scratch_memory) {
  return port::UnimplementedError("CtcLoss not implemented");
}
630 
631 }  // namespace dnn
632 }  // namespace stream_executor
633