1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/dnn.h"
17
18 #include "absl/hash/hash.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21
22 namespace stream_executor {
23 namespace dnn {
24
// Out-of-class definitions for the static constexpr members of the
// ToDataType<T> traits.  Required so the members can be ODR-used (e.g. taking
// their address) under pre-C++17 rules, where in-class constexpr declarations
// are not implicitly inline definitions.
constexpr DataType ToDataType<float>::value;
constexpr DataType ToDataType<double>::value;
constexpr DataType ToDataType<Eigen::half>::value;
constexpr DataType ToDataType<int8>::value;
constexpr DataType ToDataType<int32>::value;
30
hash() const31 uint64 AlgorithmDesc::hash() const {
32 auto p = std::make_pair(algo_id(), tensor_ops_enabled());
33 return absl::Hash<decltype(p)>()(p);
34 }
35
ToString() const36 string AlgorithmDesc::ToString() const {
37 if (tensor_ops_enabled()) {
38 return absl::StrCat(algo_id(), "#TC");
39 } else {
40 return absl::StrCat(algo_id());
41 }
42 }
43
// Default implementation for backends that do not enumerate forward
// convolution algorithms: reports failure and leaves `out_algorithms`
// untouched.  Backends (cuDNN, MIOpen, ...) override this.
bool DnnSupport::GetConvolveAlgorithms(
    bool with_winograd_nonfused, int cc_major, int cc_minor,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
49
// Default implementation of the MIOpen-specific algorithm enumeration:
// reports failure.  Only the ROCm backend overrides this.
bool DnnSupport::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind /*kind*/, Stream* /*stream*/,
    dnn::DataType /*element_type*/,
    const dnn::BatchDescriptor& /*input_descriptor*/,
    const dnn::FilterDescriptor& /*filter_descriptor*/,
    const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
    const dnn::BatchDescriptor& /*output_descriptor*/,
    std::vector<ProfileResult>* /*out_algorithms*/) {
  return false;
}
60
// Default implementation: no RNN algorithms are available unless a backend
// overrides this.  `out_algorithms` is left untouched.
bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
64
// Default implementation for backward-data convolution algorithm
// enumeration: reports failure; backends override.
bool DnnSupport::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused, int cc_major, int cc_minor,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
70
// Default implementation for backward-filter convolution algorithm
// enumeration: reports failure; backends override.
bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused, int cc_major, int cc_minor,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
76
QuantizedActivationModeString(QuantizedActivationMode mode)77 string QuantizedActivationModeString(QuantizedActivationMode mode) {
78 switch (mode) {
79 case dnn::QuantizedActivationMode::k8Bit:
80 return "uint8";
81 case dnn::QuantizedActivationMode::k16Bit:
82 return "uint16";
83 case dnn::QuantizedActivationMode::k32Bit:
84 return "int32";
85 default:
86 LOG(FATAL) << "Unknown quantized_activation_mode "
87 << static_cast<int32>(mode);
88 }
89 return "unknown quantized_activation_mode";
90 }
91
ActivationModeString(ActivationMode mode)92 string ActivationModeString(ActivationMode mode) {
93 switch (mode) {
94 case ActivationMode::kSigmoid:
95 return "sigmoid";
96 case ActivationMode::kRelu:
97 return "relu";
98 case ActivationMode::kRelu6:
99 return "relu6";
100 case ActivationMode::kReluX:
101 return "reluX";
102 case ActivationMode::kTanh:
103 return "tanh";
104 case ActivationMode::kBandPass:
105 return "bandpass";
106 default:
107 LOG(FATAL) << "Unknown activation_mode " << static_cast<int32>(mode);
108 }
109 return "unknown activation_mode";
110 }
111
ElementwiseOperationString(ElementwiseOperation op)112 string ElementwiseOperationString(ElementwiseOperation op) {
113 switch (op) {
114 case ElementwiseOperation::kAdd:
115 return "add";
116 case ElementwiseOperation::kMultiply:
117 return "multiply";
118 default:
119 LOG(FATAL) << "Unknown elementwise op " << static_cast<int32>(op);
120 }
121 return "unknown element wise op";
122 }
123
DataLayoutString(DataLayout layout)124 string DataLayoutString(DataLayout layout) {
125 switch (layout) {
126 case DataLayout::kYXDepthBatch:
127 return "YXDepthBatch";
128 case DataLayout::kYXBatchDepth:
129 return "YXBatchDepth";
130 case DataLayout::kBatchYXDepth:
131 return "BatchYXDepth";
132 case DataLayout::kBatchDepthYX:
133 return "BatchDepthYX";
134 case DataLayout::kBatchDepthYX4:
135 return "BatchDepthYX4";
136 default:
137 LOG(FATAL) << "Unknown data layout " << static_cast<int32>(layout);
138 }
139 return "unknown data layout";
140 }
141
FilterLayoutString(FilterLayout layout)142 string FilterLayoutString(FilterLayout layout) {
143 switch (layout) {
144 case FilterLayout::kOutputInputYX:
145 return "OutputInputYX";
146 case FilterLayout::kOutputYXInput:
147 return "OutputYXInput";
148 case FilterLayout::kOutputInputYX4:
149 return "OutputInputYX4";
150 case FilterLayout::kInputYXOutput:
151 return "InputYXOutput";
152 case FilterLayout::kYXInputOutput:
153 return "YXInputOutput";
154 default:
155 LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(layout);
156 }
157 return "unknown filter layout";
158 }
159
PadAlignmentString(PadAlignment alignment)160 string PadAlignmentString(PadAlignment alignment) {
161 switch (alignment) {
162 case PadAlignment::kDefault:
163 return "default";
164 case PadAlignment::kCudnnPadding:
165 return "cuDNN padding";
166 case PadAlignment::kTensorFlowPadding:
167 return "TensorFlow padding";
168 }
169 return "unknown pad alignment";
170 }
171
// Streams the human-readable pad-alignment name (see PadAlignmentString).
std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
  return str << PadAlignmentString(alignment);
}
175
ShortPoolingModeString(PoolingMode mode)176 string ShortPoolingModeString(PoolingMode mode) {
177 switch (mode) {
178 case PoolingMode::kMaximum:
179 return "Max";
180 case PoolingMode::kAverage:
181 return "Avg";
182 default:
183 LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(mode);
184 }
185 return "unknown filter layout";
186 }
187
GetDimIndices(const DataLayout & layout,const int data_dims)188 std::tuple<int, int, int> GetDimIndices(const DataLayout& layout,
189 const int data_dims) {
190 int depth_idx, batch_idx, spatial_idx;
191 switch (layout) {
192 case DataLayout::kYXBatchDepth:
193 depth_idx = data_dims - 1;
194 batch_idx = data_dims - 2;
195 spatial_idx = 0;
196 break;
197
198 case DataLayout::kYXDepthBatch:
199 depth_idx = data_dims - 2;
200 batch_idx = data_dims - 1;
201 spatial_idx = 0;
202 break;
203
204 case DataLayout::kBatchYXDepth:
205 depth_idx = data_dims - 1;
206 batch_idx = 0;
207 spatial_idx = 1;
208 break;
209
210 case DataLayout::kBatchDepthYX:
211 case DataLayout::kBatchDepthYX4:
212 depth_idx = 1;
213 batch_idx = 0;
214 spatial_idx = 2;
215 break;
216
217 default:
218 LOG(FATAL) << "Unknown layout " << layout;
219 }
220
221 return std::make_tuple(depth_idx, batch_idx, spatial_idx);
222 }
223
ReorderDims(const std::vector<int64> & input,const DataLayout & from,const DataLayout & to)224 std::vector<int64> ReorderDims(const std::vector<int64>& input,
225 const DataLayout& from, const DataLayout& to) {
226 if (from == to) return input;
227
228 int d_idx_from, b_idx_from, spatial_idx_from;
229 int d_idx_to, b_idx_to, spatial_idx_to;
230
231 std::tie(d_idx_from, b_idx_from, spatial_idx_from) =
232 GetDimIndices(from, input.size());
233 std::tie(d_idx_to, b_idx_to, spatial_idx_to) =
234 GetDimIndices(to, input.size());
235
236 std::vector<int64> reordered(input.size());
237 reordered[b_idx_to] = input[b_idx_from];
238 reordered[d_idx_to] = input[d_idx_from];
239
240 for (size_t i = 0; i < input.size() - 2;
241 i++, spatial_idx_from++, spatial_idx_to++) {
242 reordered[spatial_idx_to] = input[spatial_idx_from];
243 }
244
245 return reordered;
246 }
247
248 // -- AlgorithmConfig
249
ToString() const250 string AlgorithmConfig::ToString() const {
251 string algo = "none";
252 if (algorithm().has_value()) {
253 algo = algorithm()->ToString();
254 }
255 string algo_no_scratch = "none";
256 if (algorithm_no_scratch().has_value()) {
257 algo_no_scratch = algorithm_no_scratch()->ToString();
258 }
259 return absl::StrCat(algo, ", ", algo_no_scratch);
260 }
261
262 // -- BatchDescriptor
263
// Constructs a batch descriptor with `ndims` spatial dimensions (so the
// underlying tensor has ndims + 2 dimensions: batch and depth plus the
// spatial ones), all dimensions zero-initialized, in YXDepthBatch layout
// with no quantization range.
BatchDescriptor::BatchDescriptor(int ndims)
    : value_max_(0.0),
      value_min_(0.0),
      quantized_activation_mode_(QuantizedActivationMode::k8Bit) {
  tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
  set_layout(DataLayout::kYXDepthBatch);
}

// Default: the common 2-D (height x width) spatial case.
BatchDescriptor::BatchDescriptor() : BatchDescriptor(/*ndims=*/2) {}
273
full_dims(const DataLayout & layout) const274 std::vector<int64> BatchDescriptor::full_dims(const DataLayout& layout) const {
275 std::vector<int64> bdyx_dims(ndims() + 2);
276 bdyx_dims[0] = count();
277 bdyx_dims[1] = feature_map_count();
278 std::copy(spatial_size().begin(), spatial_size().end(),
279 bdyx_dims.begin() + 2);
280 return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout);
281 }
282
// Computes dense (fully-packed) strides for this descriptor's physical
// layout, then permutes them into the requested layout.  Not supported for
// the vectorized kBatchDepthYX4 layout, whose elements are interleaved.
std::vector<int64> BatchDescriptor::full_strides(
    const DataLayout& layout) const {
  if (this->layout() == DataLayout::kBatchDepthYX4) {
    LOG(FATAL)
        << "Cannot compute full strides for batch descriptor " << ToString()
        << ", because its layout is kBatchDepthYX4. In fact, "
        "cudnnSetTensorNdDescriptor doesn't work for kBatchDepthYX4 at all. "
        "Use cudnnSetTensor4DDescriptor to set cudnnTensorDescriptor_t "
        "instead.";
  }
  std::vector<int64> phys_dims = full_dims(this->layout());
  std::vector<int64> phys_strides(phys_dims.size());
  // Innermost dimension has stride 1; each outer stride is the product of
  // all inner dimension sizes.
  phys_strides[ndims() + 1] = 1;
  for (int i = ndims(); i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}
301
// Copies all state (tensor proto, quantization range, activation mode) from
// `other` into this descriptor.
void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
  tensor_ = other.tensor_;
  value_max_ = other.value_max_;
  value_min_ = other.value_min_;
  quantized_activation_mode_ = other.quantized_activation_mode_;
}
308
ToString() const309 string BatchDescriptor::ToString() const {
310 string spatial;
311 for (int i = 0; i < ndims(); i++) {
312 absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
313 }
314 return absl::StrFormat(
315 "{count: %d feature_map_count: %d spatial: %s "
316 "value_min: %f value_max: %f layout: %s}",
317 count(), feature_map_count(), spatial, value_min_, value_max_,
318 DataLayoutString(layout()));
319 }
320
// Returns a compact description whose component order mirrors the layout,
// e.g. "b32d64s5 5 " for kBatchDepthYX.  Used for log lines and autotune
// cache keys.
string BatchDescriptor::ToShortString() const {
  // All the constituent strings are less than 15 characters, so the
  // small string optimization ensures that there will be at most one
  // heap memory allocation.
  string depth = absl::StrCat("d", feature_map_count());
  string batch = absl::StrCat("b", count());

  string spatial = "s";
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
  }

  string suffix;
  // A non-degenerate [min; max] range indicates quantized values.
  if (value_min() != value_max()) {
    absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
  }
  if (quantized_activation_mode() == QuantizedActivationMode::k16Bit) {
    suffix += "_16bit";
  }

  // Concatenate the pieces in the same order as the layout's dimensions.
  switch (layout()) {
    case DataLayout::kYXDepthBatch:
      return absl::StrCat(spatial, depth, batch, suffix);
    case DataLayout::kYXBatchDepth:
      return absl::StrCat(spatial, batch, depth, suffix);
    case DataLayout::kBatchYXDepth:
      return absl::StrCat(batch, spatial, depth, suffix);
    case DataLayout::kBatchDepthYX:
      return absl::StrCat(batch, depth, spatial, suffix);
    case DataLayout::kBatchDepthYX4:
      return absl::StrCat(batch, depth, spatial, suffix, "(VECT_C)");
    default:
      LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
      return "";  // Avoid return warning (unreachable)
  }
}
357
NodesPerFeatureMap() const358 int64 BatchDescriptor::NodesPerFeatureMap() const {
359 int64 ret = 1;
360 for (int i = 0; i < ndims(); i++) {
361 ret *= spatial_size()[i];
362 }
363 return ret;
364 }
365
// Returns the number of elements in one batch item: spatial size times
// feature-map count.
int64 BatchDescriptor::NodesAcrossFeatureMaps() const {
  return NodesPerFeatureMap() * feature_map_count();
}
369
// Returns the total number of elements in the whole tensor:
// batch x depth x spatial.
int64 BatchDescriptor::ElementCount() const {
  return count() * feature_map_count() * NodesPerFeatureMap();
}
373
// Returns the number of weights in a dense layer mapping `input` to
// `output`: every input node connects to every output node.
int64 BatchDescriptor::FullyConnectedWeightCount(
    const BatchDescriptor& input, const BatchDescriptor& output) {
  return input.NodesAcrossFeatureMaps() * output.NodesAcrossFeatureMaps();
}
378
// Returns the number of biases in a dense layer producing `output`: one per
// output node.
int64 BatchDescriptor::FullyConnectedBiasCount(const BatchDescriptor& output) {
  return output.NodesAcrossFeatureMaps();
}
382
DepthConcatenateOutputDescriptor(port::ArraySlice<dnn::BatchDescriptor> inputs)383 BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
384 port::ArraySlice<dnn::BatchDescriptor> inputs) {
385 if (inputs.empty()) {
386 return BatchDescriptor();
387 }
388 int feature_map_count = 0;
389 for (const auto& dimensions : inputs) {
390 feature_map_count += dimensions.feature_map_count();
391 }
392 BatchDescriptor output = inputs[0];
393 output.set_feature_map_count(feature_map_count);
394 return output;
395 }
396
// Serializes this descriptor to a TensorDescriptorProto with the given
// element type.  Only supported for non-quantized descriptors: the proto
// cannot carry the value range or a non-default activation mode, so we
// CHECK they are at their defaults rather than silently dropping them.
TensorDescriptorProto BatchDescriptor::ToProto(DataType data_type) const {
  CHECK_EQ(0.0, value_max_);
  CHECK_EQ(0.0, value_min_);
  CHECK(quantized_activation_mode_ == QuantizedActivationMode::k8Bit);

  TensorDescriptorProto ret = tensor_;
  ret.set_data_type(data_type);
  return ret;
}
406
407 // -- FilterDescriptor
408
// Constructs a filter descriptor with `ndims` spatial dimensions (plus
// output-depth and input-depth, i.e. ndims + 2 total), zero-initialized, in
// the default OutputInputYX layout.
FilterDescriptor::FilterDescriptor(int ndims) {
  tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
  set_layout(FilterLayout::kOutputInputYX);
}

// Default: the common 2-D (height x width) filter case.
FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {}

FilterDescriptor::~FilterDescriptor() {}
417
// Copies all state (the tensor proto) from `other` into this descriptor.
void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
  tensor_ = other.tensor_;
}
421
ToString() const422 string FilterDescriptor::ToString() const {
423 string desc = absl::StrFormat(
424 "{output_feature_map_count: %d input_feature_map_count: %d "
425 "layout: %s shape: ",
426 output_feature_map_count(), input_feature_map_count(),
427 FilterLayoutString(layout()));
428 for (int i = 0; i < ndims(); i++) {
429 absl::StrAppendFormat(&desc, "%d ", input_filter_dims()[i]);
430 }
431 absl::StrAppend(&desc, "}");
432
433 return desc;
434 }
435
// Returns a compact description whose component order mirrors the layout,
// e.g. "od64id32s3 3 " for kOutputInputYX.
string FilterDescriptor::ToShortString() const {
  // All the constituent strings are less than 15 characters, so the
  // small string optimization ensures that there will be at most one
  // heap memory allocation.
  string od = absl::StrCat("od", output_feature_map_count());
  string id = absl::StrCat("id", input_feature_map_count());

  string spatial = "s";
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]);
  }

  // Concatenate the pieces in the same order as the layout's dimensions.
  switch (layout()) {
    case FilterLayout::kOutputInputYX:
      return absl::StrCat(od, id, spatial);
    case FilterLayout::kOutputYXInput:
      return absl::StrCat(od, spatial, id);
    case FilterLayout::kOutputInputYX4:
      return absl::StrCat(od, id, spatial, "(VECT_C)");
    case FilterLayout::kInputYXOutput:
      return absl::StrCat(id, spatial, od);
    case FilterLayout::kYXInputOutput:
      return absl::StrCat(spatial, id, od);
    default:
      LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
      return "";  // Avoid return warning (unreachable)
  }
}
464
ComputeWeightCount() const465 int64 FilterDescriptor::ComputeWeightCount() const {
466 int64 ret = output_feature_map_count() * input_feature_map_count();
467 for (int i = 0; i < ndims(); i++) {
468 ret *= input_filter_dims()[i];
469 }
470 return ret;
471 }
472
// Serializes this descriptor to a TensorDescriptorProto with the given
// element type.
TensorDescriptorProto FilterDescriptor::ToProto(DataType data_type) const {
  TensorDescriptorProto ret = tensor_;
  ret.set_data_type(data_type);
  return ret;
}
478
479 // -- ConvolutionDescriptor
480
// Constructs a convolution descriptor for `ndims` spatial dimensions with
// identity defaults: zero padding, unit strides and dilations, one filter
// group, and cross-correlation mode (the usual "convolution" in ML
// frameworks, i.e. no kernel flipping).
ConvolutionDescriptor::ConvolutionDescriptor(int ndims) {
  proto_.mutable_paddings()->Resize(ndims, 0);
  proto_.mutable_strides()->Resize(ndims, 1);
  proto_.mutable_dilations()->Resize(ndims, 1);
  proto_.set_group_count(1);
  proto_.set_convolution_mode(ConvolutionMode::CROSS_CORRELATION);
}

// Default: the common 2-D convolution case.
ConvolutionDescriptor::ConvolutionDescriptor()
    : ConvolutionDescriptor(/*ndims=*/2) {}

ConvolutionDescriptor::~ConvolutionDescriptor() {}
493
ToString() const494 string ConvolutionDescriptor::ToString() const {
495 string padding;
496 string strides;
497 string dilations;
498 for (int i = 0; i < ndims(); i++) {
499 absl::StrAppendFormat(&padding, "%d ", this->padding()[i]);
500 absl::StrAppendFormat(&strides, "%d ", this->strides()[i]);
501 absl::StrAppendFormat(&dilations, "%d ", this->dilations()[i]);
502 }
503
504 return absl::StrFormat(
505 "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: "
506 "%s}",
507 padding, PadAlignmentString(pad_alignment()), strides, dilations);
508 }
509
ToShortString() const510 string ConvolutionDescriptor::ToShortString() const {
511 string desc;
512 for (int i = 0; i < ndims(); i++) {
513 if (i > 0) absl::StrAppend(&desc, "_");
514 absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]);
515 }
516 for (int i = 0; i < ndims(); i++) {
517 absl::StrAppendFormat(&desc, "_s%d:%d", i, strides()[i]);
518 }
519 for (int i = 0; i < ndims(); i++) {
520 absl::StrAppendFormat(&desc, "_d%d:%d", i, dilations()[i]);
521 }
522 return desc;
523 }
524
525 // -- PoolingDescriptor
526
// Constructs a pooling descriptor for `ndims` spatial dimensions with
// defaults: max pooling, zero-sized windows, zero padding, unit strides, and
// NaN propagation disabled.
PoolingDescriptor::PoolingDescriptor(int ndims)
    : mode_(dnn::PoolingMode::kMaximum),
      ndims_(ndims),
      propagate_nans_(false),
      window_(ndims, 0),
      padding_(ndims, 0),
      strides_(ndims, 1) {}

// Default: the common 2-D pooling case.
PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {}
536
// Copies all state (mode, rank, window/padding/strides, NaN handling) from
// `other` into this descriptor.
void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
  mode_ = other.mode_;
  ndims_ = other.ndims_;
  window_ = other.window_;
  padding_ = other.padding_;
  strides_ = other.strides_;
  propagate_nans_ = other.propagate_nans_;
}
545
ToString() const546 string PoolingDescriptor::ToString() const {
547 const char* mode_string =
548 mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
549
550 string window, strides, padding;
551 for (int i = 0; i < ndims_; i++) {
552 absl::StrAppendFormat(&window, "%d ", window_[i]);
553 absl::StrAppendFormat(&strides, "%d ", strides_[i]);
554 absl::StrAppendFormat(&padding, "%d", padding_[i]);
555 }
556
557 const char* propagate_string = propagate_nans_ ? "Yes" : "No";
558
559 return absl::StrFormat(
560 "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}",
561 mode_string, window, strides, padding, propagate_string);
562 }
563
ToShortString() const564 string PoolingDescriptor::ToShortString() const {
565 string window, strides, padding;
566 for (int i = 0; i < ndims_; i++) {
567 absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]);
568 absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]);
569 absl::StrAppendFormat(&padding, "_p%d:%d", i, padding_[i]);
570 }
571 return absl::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
572 window, strides, padding,
573 propagate_nans_ ? "propagate_nans" : "ignore_nans");
574 }
575
576 // -- NormalizeDescriptor
577
// Constructs a local-response-normalization descriptor with all parameters
// zeroed; callers configure bias/range/alpha/beta via the setters.
NormalizeDescriptor::NormalizeDescriptor()
    : bias_(0.0),
      range_(0),
      alpha_(0.0),
      beta_(0.0),
      wrap_around_(false),
      segment_size_(0) {}
585
// Copies all normalization parameters from `other` into this descriptor.
void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
  bias_ = other.bias_;
  range_ = other.range_;
  alpha_ = other.alpha_;
  beta_ = other.beta_;
  wrap_around_ = other.wrap_around_;
  segment_size_ = other.segment_size_;
}
594
// Returns a verbose human-readable description of this normalization
// configuration.
string NormalizeDescriptor::ToString() const {
  return absl::StrFormat(
      "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
      "segment_size: %d}",
      bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
}
601
// Returns a compact underscore-separated description of this normalization
// configuration.
string NormalizeDescriptor::ToShortString() const {
  return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
                      "_beta:", beta_, "_wrap:", wrap_around_,
                      "_size:", segment_size_);
}
607
IsStatusOk(const port::Status & status,bool report_error)608 bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
609 if (status.ok()) {
610 return true;
611 }
612 if (report_error) {
613 LOG(ERROR) << status.error_message();
614 }
615 return false;
616 }
617
// Default implementation of the CTC (Connectionist Temporal Classification)
// loss entry point: unconditionally reports the operation as unimplemented.
// Backends that support CTC loss (e.g. cuDNN) override this.
port::Status DnnSupport::DoCtcLoss(Stream* stream, dnn::DataType element_type,
                                   const RnnStateTensorDescriptor& probs_desc,
                                   const DeviceMemoryBase probs_data,
                                   absl::Span<const int> labels_data,
                                   absl::Span<const int> labels_lengths_data,
                                   absl::Span<const int> input_lengths_data,
                                   DeviceMemoryBase costs_data,
                                   const RnnStateTensorDescriptor& grads_desc,
                                   DeviceMemoryBase grads_data,
                                   DeviceMemory<uint8> scratch_memory) {
  return port::UnimplementedError("CtcLoss not implemented");
}
630
631 } // namespace dnn
632 } // namespace stream_executor
633