/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/dnn.h"

#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
namespace stream_executor {
namespace dnn {

25 constexpr DataType ToDataType<float>::value;
26 constexpr DataType ToDataType<double>::value;
27 constexpr DataType ToDataType<Eigen::half>::value;
28 constexpr DataType ToDataType<int8>::value;
29 constexpr DataType ToDataType<int32>::value;
30 
hash() const31 uint64 AlgorithmDesc::hash() const {
32   if (IsExecutionPlan()) {
33     auto p = exec_plan_id();
34     return absl::Hash<decltype(p)>()(p);
35   }
36   auto p = std::make_pair(algo_id(), tensor_ops_enabled());
37   return absl::Hash<decltype(p)>()(p);
38 }
39 
ToString() const40 std::string AlgorithmDesc::ToString() const {
41   if (IsExecutionPlan()) {
42     return absl::StrCat(exec_plan_id());
43   }
44   if (tensor_ops_enabled()) {
45     return absl::StrCat(algo_id(), "#TC");
46   } else {
47     return absl::StrCat(algo_id());
48   }
49 }
50 
GetConvolveAlgorithms(CudaComputeCapability cuda_compute_capability,std::vector<AlgorithmDesc> * out_algorithms)51 bool DnnSupport::GetConvolveAlgorithms(
52     CudaComputeCapability cuda_compute_capability,
53     std::vector<AlgorithmDesc>* out_algorithms) {
54   return false;
55 }
56 
GetConvolveExecutionPlans(dnn::ConvolutionKind,dnn::DataType,Stream *,const dnn::BatchDescriptor &,const dnn::FilterDescriptor &,const dnn::BatchDescriptor &,const dnn::ConvolutionDescriptor &,std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>> *)57 bool DnnSupport::GetConvolveExecutionPlans(
58     dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/,
59     Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/,
60     const dnn::FilterDescriptor& /*filter_descriptor*/,
61     const dnn::BatchDescriptor& /*output_descriptor*/,
62     const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
63     std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>* /*exec_plans*/) {
64   return false;
65 }
66 
GetMIOpenConvolveAlgorithms(dnn::ConvolutionKind,dnn::DataType,Stream *,const dnn::BatchDescriptor &,DeviceMemoryBase input_data,const dnn::FilterDescriptor &,DeviceMemoryBase filter_data,const dnn::BatchDescriptor &,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor &,ScratchAllocator * scratch_allocator,std::vector<ProfileResult> *)67 bool DnnSupport::GetMIOpenConvolveAlgorithms(
68     dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/,
69     Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/,
70     DeviceMemoryBase input_data,
71     const dnn::FilterDescriptor& /*filter_descriptor*/,
72     DeviceMemoryBase filter_data,
73     const dnn::BatchDescriptor& /*output_descriptor*/,
74     DeviceMemoryBase output_data,
75     const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
76     ScratchAllocator* scratch_allocator,
77     std::vector<ProfileResult>* /*out_algorithms*/) {
78   return false;
79 }
80 
GetRnnAlgorithms(std::vector<AlgorithmDesc> * out_algorithms)81 bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
82   return false;
83 }
84 
GetConvolveBackwardDataAlgorithms(CudaComputeCapability cuda_compute_capability,std::vector<AlgorithmDesc> * out_algorithms)85 bool DnnSupport::GetConvolveBackwardDataAlgorithms(
86     CudaComputeCapability cuda_compute_capability,
87     std::vector<AlgorithmDesc>* out_algorithms) {
88   return false;
89 }
90 
GetConvolveBackwardFilterAlgorithms(CudaComputeCapability cuda_compute_capability,std::vector<AlgorithmDesc> * out_algorithms)91 bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
92     CudaComputeCapability cuda_compute_capability,
93     std::vector<AlgorithmDesc>* out_algorithms) {
94   return false;
95 }
96 
QuantizedActivationModeString(QuantizedActivationMode mode)97 std::string QuantizedActivationModeString(QuantizedActivationMode mode) {
98   switch (mode) {
99     case dnn::QuantizedActivationMode::k8Bit:
100       return "uint8";
101     case dnn::QuantizedActivationMode::k16Bit:
102       return "uint16";
103     case dnn::QuantizedActivationMode::k32Bit:
104       return "int32";
105     default:
106       LOG(FATAL) << "Unknown quantized_activation_mode "
107                  << static_cast<int32>(mode);
108   }
109   return "unknown quantized_activation_mode";
110 }
111 
ActivationModeString(ActivationMode mode)112 std::string ActivationModeString(ActivationMode mode) {
113   switch (mode) {
114     case ActivationMode::kNone:
115       return "none";
116     case ActivationMode::kSigmoid:
117       return "sigmoid";
118     case ActivationMode::kRelu:
119       return "relu";
120     case ActivationMode::kRelu6:
121       return "relu6";
122     case ActivationMode::kReluX:
123       return "reluX";
124     case ActivationMode::kTanh:
125       return "tanh";
126     case ActivationMode::kBandPass:
127       return "bandpass";
128     default:
129       LOG(FATAL) << "Unknown activation_mode " << static_cast<int32>(mode);
130   }
131   return "unknown activation_mode";
132 }
133 
ElementwiseOperationString(ElementwiseOperation op)134 std::string ElementwiseOperationString(ElementwiseOperation op) {
135   switch (op) {
136     case ElementwiseOperation::kAdd:
137       return "add";
138     case ElementwiseOperation::kMultiply:
139       return "multiply";
140     default:
141       LOG(FATAL) << "Unknown elementwise op " << static_cast<int32>(op);
142   }
143   return "unknown element wise op";
144 }
145 
DataLayoutString(DataLayout layout)146 std::string DataLayoutString(DataLayout layout) {
147   switch (layout) {
148     case DataLayout::kYXDepthBatch:
149       return "YXDepthBatch";
150     case DataLayout::kYXBatchDepth:
151       return "YXBatchDepth";
152     case DataLayout::kBatchYXDepth:
153       return "BatchYXDepth";
154     case DataLayout::kBatchDepthYX:
155       return "BatchDepthYX";
156     case DataLayout::kBatchDepthYX4:
157       return "BatchDepthYX4";
158     default:
159       LOG(FATAL) << "Unknown data layout " << static_cast<int32>(layout);
160   }
161   return "unknown data layout";
162 }
163 
FilterLayoutString(FilterLayout layout)164 std::string FilterLayoutString(FilterLayout layout) {
165   switch (layout) {
166     case FilterLayout::kOutputInputYX:
167       return "OutputInputYX";
168     case FilterLayout::kOutputYXInput:
169       return "OutputYXInput";
170     case FilterLayout::kOutputInputYX4:
171       return "OutputInputYX4";
172     case FilterLayout::kInputYXOutput:
173       return "InputYXOutput";
174     case FilterLayout::kYXInputOutput:
175       return "YXInputOutput";
176     default:
177       LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(layout);
178   }
179   return "unknown filter layout";
180 }
181 
PadAlignmentString(PadAlignment alignment)182 std::string PadAlignmentString(PadAlignment alignment) {
183   switch (alignment) {
184     case PadAlignment::kDefault:
185       return "default";
186     case PadAlignment::kCudnnPadding:
187       return "cuDNN padding";
188     case PadAlignment::kTensorFlowPadding:
189       return "TensorFlow padding";
190   }
191   return "unknown pad alignment";
192 }
193 
operator <<(std::ostream & str,dnn::PadAlignment alignment)194 std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
195   return str << PadAlignmentString(alignment);
196 }
197 
ShortPoolingModeString(PoolingMode mode)198 std::string ShortPoolingModeString(PoolingMode mode) {
199   switch (mode) {
200     case PoolingMode::kMaximum:
201       return "Max";
202     case PoolingMode::kAverage:
203       return "Avg";
204     default:
205       LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(mode);
206   }
207   return "unknown filter layout";
208 }
209 
210 struct ConvDimIndices {
211   union {
212     struct {
213       int depth_idx;
214       int batch_idx;
215       int spatial_idx;
216     } data;
217     struct {
218       int output_idx;
219       int input_idx;
220       int spatial_idx;
221     } filter;
222   };
223 };
224 
GetDimIndices(const DataLayout & layout,const int data_dims)225 ConvDimIndices GetDimIndices(const DataLayout& layout, const int data_dims) {
226   ConvDimIndices dim_indices;
227   switch (layout) {
228     case DataLayout::kYXBatchDepth:
229       dim_indices.data.depth_idx = data_dims - 1;
230       dim_indices.data.batch_idx = data_dims - 2;
231       dim_indices.data.spatial_idx = 0;
232       break;
233 
234     case DataLayout::kYXDepthBatch:
235       dim_indices.data.depth_idx = data_dims - 2;
236       dim_indices.data.batch_idx = data_dims - 1;
237       dim_indices.data.spatial_idx = 0;
238       break;
239 
240     case DataLayout::kBatchYXDepth:
241       dim_indices.data.depth_idx = data_dims - 1;
242       dim_indices.data.batch_idx = 0;
243       dim_indices.data.spatial_idx = 1;
244       break;
245 
246     case DataLayout::kBatchDepthYX:
247     case DataLayout::kBatchDepthYX4:
248     case DataLayout::kBatchDepthYX32:
249       dim_indices.data.depth_idx = 1;
250       dim_indices.data.batch_idx = 0;
251       dim_indices.data.spatial_idx = 2;
252       break;
253 
254     default:
255       LOG(FATAL) << "Unknown layout " << layout;
256   }
257 
258   return dim_indices;
259 }
260 
GetDimIndices(const FilterLayout & layout,const int data_dims)261 ConvDimIndices GetDimIndices(const FilterLayout& layout, const int data_dims) {
262   ConvDimIndices dim_indices;
263   switch (layout) {
264     case FilterLayout::kOutputInputYX:
265     case FilterLayout::kOutputInputYX4:
266     case FilterLayout::kOutputInputYX32:
267       dim_indices.filter.input_idx = 1;
268       dim_indices.filter.output_idx = 0;
269       dim_indices.filter.spatial_idx = 2;
270       break;
271 
272     case FilterLayout::kOutputYXInput:
273       dim_indices.filter.input_idx = data_dims - 1;
274       dim_indices.filter.output_idx = 0;
275       dim_indices.filter.spatial_idx = 1;
276       break;
277 
278     case FilterLayout::kInputYXOutput:
279       dim_indices.filter.input_idx = 0;
280       dim_indices.filter.output_idx = data_dims - 1;
281       dim_indices.filter.spatial_idx = 1;
282       break;
283 
284     case FilterLayout::kYXInputOutput:
285       dim_indices.filter.input_idx = data_dims - 2;
286       dim_indices.filter.output_idx = data_dims - 1;
287       dim_indices.filter.spatial_idx = 0;
288       break;
289 
290     default:
291       LOG(FATAL) << "Unknown layout " << layout;
292   }
293 
294   return dim_indices;
295 }
296 
ReorderDims(const std::vector<int64> & input,const DataLayout & from,const DataLayout & to)297 std::vector<int64> ReorderDims(const std::vector<int64>& input,
298                                const DataLayout& from, const DataLayout& to) {
299   if (from == to) return input;
300 
301   ConvDimIndices from_indices = GetDimIndices(from, input.size());
302   ConvDimIndices to_indices = GetDimIndices(to, input.size());
303 
304   std::vector<int64> reordered(input.size());
305   reordered[to_indices.data.batch_idx] = input[from_indices.data.batch_idx];
306   reordered[to_indices.data.depth_idx] = input[from_indices.data.depth_idx];
307 
308   int spatial_idx_from = from_indices.data.spatial_idx;
309   int spatial_idx_to = to_indices.data.spatial_idx;
310   for (size_t i = 0; i < input.size() - 2;
311        i++, spatial_idx_from++, spatial_idx_to++) {
312     reordered[spatial_idx_to] = input[spatial_idx_from];
313   }
314 
315   return reordered;
316 }
317 
ReorderDims(const std::vector<int64> & input,const FilterLayout & from,const FilterLayout & to)318 std::vector<int64> ReorderDims(const std::vector<int64>& input,
319                                const FilterLayout& from,
320                                const FilterLayout& to) {
321   if (from == to) return input;
322 
323   ConvDimIndices from_indices = GetDimIndices(from, input.size());
324   ConvDimIndices to_indices = GetDimIndices(to, input.size());
325 
326   std::vector<int64> reordered(input.size());
327   reordered[to_indices.filter.output_idx] =
328       input[from_indices.filter.output_idx];
329   reordered[to_indices.filter.input_idx] = input[from_indices.filter.input_idx];
330 
331   int spatial_idx_from = from_indices.filter.spatial_idx;
332   int spatial_idx_to = to_indices.filter.spatial_idx;
333   for (size_t i = 0; i < input.size() - 2;
334        i++, spatial_idx_from++, spatial_idx_to++) {
335     reordered[spatial_idx_to] = input[spatial_idx_from];
336   }
337 
338   return reordered;
339 }
340 
341 // -- AlgorithmConfig
342 
ToString() const343 std::string AlgorithmConfig::ToString() const {
344   std::string algo = "none";
345   if (algorithm().has_value()) {
346     algo = algorithm()->ToString();
347   }
348   std::string algo_no_scratch = "none";
349   if (algorithm_no_scratch().has_value()) {
350     algo_no_scratch = algorithm_no_scratch()->ToString();
351   }
352   return absl::StrCat(algo, ", ", algo_no_scratch);
353 }
354 
355 // -- BatchDescriptor
356 
BatchDescriptor(int ndims)357 BatchDescriptor::BatchDescriptor(int ndims)
358     : value_max_(0.0),
359       value_min_(0.0),
360       quantized_activation_mode_(QuantizedActivationMode::k8Bit) {
361   tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
362   set_layout(DataLayout::kYXDepthBatch);
363 }
364 
BatchDescriptor()365 BatchDescriptor::BatchDescriptor() : BatchDescriptor(/*ndims=*/2) {}
366 
full_dims(const DataLayout & layout) const367 std::vector<int64> BatchDescriptor::full_dims(const DataLayout& layout) const {
368   std::vector<int64> bdyx_dims(ndims() + 2);
369   bdyx_dims[0] = count();
370   bdyx_dims[1] = feature_map_count();
371   std::copy(spatial_size().begin(), spatial_size().end(),
372             bdyx_dims.begin() + 2);
373   return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout);
374 }
375 
full_strides(const DataLayout & layout) const376 std::vector<int64> BatchDescriptor::full_strides(
377     const DataLayout& layout) const {
378   std::vector<int64> phys_dims = full_dims(this->layout());
379   std::vector<int64> phys_strides(phys_dims.size());
380   phys_strides[ndims() + 1] = 1;
381   for (int i = ndims(); i >= 0; i--) {
382     phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
383   }
384   return ReorderDims(phys_strides, this->layout(), layout);
385 }
386 
vectorized_dims(const DataLayout & layout,int vector_size,int vector_dim) const387 std::vector<int64> BatchDescriptor::vectorized_dims(const DataLayout& layout,
388                                                     int vector_size,
389                                                     int vector_dim) const {
390   std::vector<int64> bdyx_dims = full_dims(dnn::DataLayout::kBatchDepthYX);
391   if (vector_dim != -1) {
392     bdyx_dims[vector_dim] /= vector_size;
393   }
394   return dnn::ReorderDims(bdyx_dims, dnn::DataLayout::kBatchDepthYX, layout);
395 }
396 
vectorized_strides(const DataLayout & layout,int vector_size,int vector_dim) const397 std::vector<int64> BatchDescriptor::vectorized_strides(const DataLayout& layout,
398                                                        int vector_size,
399                                                        int vector_dim) const {
400   std::vector<int64> phys_dims =
401       vectorized_dims(this->layout(), vector_size, vector_dim);
402   std::vector<int64> phys_strides(phys_dims.size());
403   phys_strides[phys_dims.size() - 1] = 1;
404   for (int i = phys_dims.size() - 2; i >= 0; i--) {
405     phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
406   }
407   return ReorderDims(phys_strides, this->layout(), layout);
408 }
409 
CloneFrom(const BatchDescriptor & other)410 void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
411   tensor_ = other.tensor_;
412   value_max_ = other.value_max_;
413   value_min_ = other.value_min_;
414   quantized_activation_mode_ = other.quantized_activation_mode_;
415 }
416 
ToString() const417 std::string BatchDescriptor::ToString() const {
418   std::string spatial;
419   for (int i = 0; i < ndims(); i++) {
420     absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
421   }
422   return absl::StrFormat(
423       "{count: %d feature_map_count: %d spatial: %s "
424       "value_min: %f value_max: %f layout: %s}",
425       count(), feature_map_count(), spatial, value_min_, value_max_,
426       DataLayoutString(layout()));
427 }
428 
ToShortString() const429 std::string BatchDescriptor::ToShortString() const {
430   // All the constituent strings are less than 15 characters, so the
431   // small string optimization ensures that there will be at most one
432   // heap memory allocation.
433   std::string depth = absl::StrCat("d", feature_map_count());
434   std::string batch = absl::StrCat("b", count());
435 
436   std::string spatial = "s";
437   for (int i = 0; i < ndims(); i++) {
438     absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
439   }
440 
441   std::string suffix;
442   if (value_min() != value_max()) {
443     absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
444   }
445   if (quantized_activation_mode() == QuantizedActivationMode::k16Bit) {
446     suffix += "_16bit";
447   }
448 
449   switch (layout()) {
450     case DataLayout::kYXDepthBatch:
451       return absl::StrCat(spatial, depth, batch, suffix);
452     case DataLayout::kYXBatchDepth:
453       return absl::StrCat(spatial, batch, depth, suffix);
454     case DataLayout::kBatchYXDepth:
455       return absl::StrCat(batch, spatial, depth, suffix);
456     case DataLayout::kBatchDepthYX:
457       return absl::StrCat(batch, depth, spatial, suffix);
458     case DataLayout::kBatchDepthYX4:
459       return absl::StrCat(batch, depth, spatial, suffix, "(VECT_C)");
460     default:
461       LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
462       return "";  // Avoid return warning (unreachable)
463   }
464 }
465 
NodesPerFeatureMap() const466 int64 BatchDescriptor::NodesPerFeatureMap() const {
467   int64_t ret = 1;
468   for (int i = 0; i < ndims(); i++) {
469     ret *= spatial_size()[i];
470   }
471   return ret;
472 }
473 
NodesAcrossFeatureMaps() const474 int64 BatchDescriptor::NodesAcrossFeatureMaps() const {
475   return NodesPerFeatureMap() * feature_map_count();
476 }
477 
ElementCount() const478 int64 BatchDescriptor::ElementCount() const {
479   return count() * feature_map_count() * NodesPerFeatureMap();
480 }
481 
FullyConnectedWeightCount(const BatchDescriptor & input,const BatchDescriptor & output)482 int64 BatchDescriptor::FullyConnectedWeightCount(
483     const BatchDescriptor& input, const BatchDescriptor& output) {
484   return input.NodesAcrossFeatureMaps() * output.NodesAcrossFeatureMaps();
485 }
486 
FullyConnectedBiasCount(const BatchDescriptor & output)487 int64 BatchDescriptor::FullyConnectedBiasCount(const BatchDescriptor& output) {
488   return output.NodesAcrossFeatureMaps();
489 }
490 
DepthConcatenateOutputDescriptor(port::ArraySlice<dnn::BatchDescriptor> inputs)491 BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
492     port::ArraySlice<dnn::BatchDescriptor> inputs) {
493   if (inputs.empty()) {
494     return BatchDescriptor();
495   }
496   int feature_map_count = 0;
497   for (const auto& dimensions : inputs) {
498     feature_map_count += dimensions.feature_map_count();
499   }
500   BatchDescriptor output = inputs[0];
501   output.set_feature_map_count(feature_map_count);
502   return output;
503 }
504 
ToProto(DataType data_type) const505 TensorDescriptorProto BatchDescriptor::ToProto(DataType data_type) const {
506   CHECK_EQ(0.0, value_max_);
507   CHECK_EQ(0.0, value_min_);
508   CHECK(quantized_activation_mode_ == QuantizedActivationMode::k8Bit);
509 
510   TensorDescriptorProto ret = tensor_;
511   ret.set_data_type(data_type);
512   return ret;
513 }
514 
515 // -- FilterDescriptor
516 
FilterDescriptor(int ndims)517 FilterDescriptor::FilterDescriptor(int ndims) {
518   tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
519   set_layout(FilterLayout::kOutputInputYX);
520 }
521 
FilterDescriptor()522 FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {}
523 
~FilterDescriptor()524 FilterDescriptor::~FilterDescriptor() {}
525 
CloneFrom(const FilterDescriptor & other)526 void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
527   tensor_ = other.tensor_;
528 }
529 
ToString() const530 std::string FilterDescriptor::ToString() const {
531   std::string desc = absl::StrFormat(
532       "{output_feature_map_count: %d input_feature_map_count: %d "
533       "layout: %s shape: ",
534       output_feature_map_count(), input_feature_map_count(),
535       FilterLayoutString(layout()));
536   for (int i = 0; i < ndims(); i++) {
537     absl::StrAppendFormat(&desc, "%d ", input_filter_dims()[i]);
538   }
539   absl::StrAppend(&desc, "}");
540 
541   return desc;
542 }
543 
ToShortString() const544 std::string FilterDescriptor::ToShortString() const {
545   // All the constituent strings are less than 15 characters, so the
546   // small string optimization ensures that there will be at most one
547   // heap memory allocation.
548   std::string od = absl::StrCat("od", output_feature_map_count());
549   std::string id = absl::StrCat("id", input_feature_map_count());
550 
551   std::string spatial = "s";
552   for (int i = 0; i < ndims(); i++) {
553     absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]);
554   }
555 
556   switch (layout()) {
557     case FilterLayout::kOutputInputYX:
558       return absl::StrCat(od, id, spatial);
559     case FilterLayout::kOutputYXInput:
560       return absl::StrCat(od, spatial, id);
561     case FilterLayout::kOutputInputYX4:
562       return absl::StrCat(od, id, spatial, "(VECT_C)");
563     case FilterLayout::kInputYXOutput:
564       return absl::StrCat(id, spatial, od);
565     case FilterLayout::kYXInputOutput:
566       return absl::StrCat(spatial, id, od);
567     default:
568       LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
569       return "";  // Avoid return warning (unreachable)
570   }
571 }
572 
ComputeWeightCount() const573 int64 FilterDescriptor::ComputeWeightCount() const {
574   int64_t ret = output_feature_map_count() * input_feature_map_count();
575   for (int i = 0; i < ndims(); i++) {
576     ret *= input_filter_dims()[i];
577   }
578   return ret;
579 }
580 
full_dims(const FilterLayout & layout) const581 std::vector<int64> FilterDescriptor::full_dims(
582     const FilterLayout& layout) const {
583   std::vector<int64> oiyx_dims(ndims() + 2);
584   oiyx_dims[0] = output_feature_map_count();
585   oiyx_dims[1] = input_feature_map_count();
586   std::copy(input_filter_dims().begin(), input_filter_dims().end(),
587             oiyx_dims.begin() + 2);
588   return ReorderDims(oiyx_dims, FilterLayout::kOutputInputYX, layout);
589 }
590 
full_strides(const FilterLayout & layout) const591 std::vector<int64> FilterDescriptor::full_strides(
592     const FilterLayout& layout) const {
593   std::vector<int64> phys_dims = full_dims(this->layout());
594   std::vector<int64> phys_strides(phys_dims.size());
595   phys_strides[ndims() + 1] = 1;
596   for (int i = ndims(); i >= 0; i--) {
597     phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
598   }
599   return ReorderDims(phys_strides, this->layout(), layout);
600 }
601 
vectorized_dims(const FilterLayout & layout,int vector_size,int vector_dim) const602 std::vector<int64> FilterDescriptor::vectorized_dims(const FilterLayout& layout,
603                                                      int vector_size,
604                                                      int vector_dim) const {
605   std::vector<int64> oiyx_dims = full_dims(dnn::FilterLayout::kOutputInputYX);
606   if (vector_dim != -1) {
607     oiyx_dims[vector_dim] /= vector_size;
608   }
609   return ReorderDims(oiyx_dims, FilterLayout::kOutputInputYX, layout);
610 }
611 
vectorized_strides(const FilterLayout & layout,int vector_size,int vector_dim) const612 std::vector<int64> FilterDescriptor::vectorized_strides(
613     const FilterLayout& layout, int vector_size, int vector_dim) const {
614   std::vector<int64> phys_dims =
615       vectorized_dims(this->layout(), vector_size, vector_dim);
616   std::vector<int64> phys_strides(phys_dims.size());
617   phys_strides[phys_dims.size() - 1] = 1;
618   for (int i = phys_dims.size() - 2; i >= 0; i--) {
619     phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
620   }
621   return ReorderDims(phys_strides, this->layout(), layout);
622 }
623 
ToProto(DataType data_type) const624 TensorDescriptorProto FilterDescriptor::ToProto(DataType data_type) const {
625   TensorDescriptorProto ret = tensor_;
626   ret.set_data_type(data_type);
627   return ret;
628 }
629 
630 // -- ConvolutionDescriptor
631 
ConvolutionDescriptor(int ndims)632 ConvolutionDescriptor::ConvolutionDescriptor(int ndims) {
633   proto_.mutable_paddings()->Resize(ndims, 0);
634   proto_.mutable_strides()->Resize(ndims, 1);
635   proto_.mutable_dilations()->Resize(ndims, 1);
636   proto_.set_group_count(1);
637   proto_.set_convolution_mode(ConvolutionMode::CROSS_CORRELATION);
638 }
639 
ConvolutionDescriptor()640 ConvolutionDescriptor::ConvolutionDescriptor()
641     : ConvolutionDescriptor(/*ndims=*/2) {}
642 
~ConvolutionDescriptor()643 ConvolutionDescriptor::~ConvolutionDescriptor() {}
644 
ToString() const645 std::string ConvolutionDescriptor::ToString() const {
646   std::string padding;
647   std::string strides;
648   std::string dilations;
649   for (int i = 0; i < ndims(); i++) {
650     absl::StrAppendFormat(&padding, "%d ", this->padding()[i]);
651     absl::StrAppendFormat(&strides, "%d ", this->strides()[i]);
652     absl::StrAppendFormat(&dilations, "%d ", this->dilations()[i]);
653   }
654 
655   return absl::StrFormat(
656       "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: "
657       "%s}",
658       padding, PadAlignmentString(pad_alignment()), strides, dilations);
659 }
660 
ToShortString() const661 std::string ConvolutionDescriptor::ToShortString() const {
662   std::string desc;
663   for (int i = 0; i < ndims(); i++) {
664     if (i > 0) absl::StrAppend(&desc, "_");
665     absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]);
666   }
667   for (int i = 0; i < ndims(); i++) {
668     absl::StrAppendFormat(&desc, "_s%d:%d", i, strides()[i]);
669   }
670   for (int i = 0; i < ndims(); i++) {
671     absl::StrAppendFormat(&desc, "_d%d:%d", i, dilations()[i]);
672   }
673   return desc;
674 }
675 
676 // -- PoolingDescriptor
677 
PoolingDescriptor(int ndims)678 PoolingDescriptor::PoolingDescriptor(int ndims)
679     : mode_(dnn::PoolingMode::kMaximum),
680       ndims_(ndims),
681       propagate_nans_(false),
682       window_(ndims, 0),
683       padding_(ndims, 0),
684       strides_(ndims, 1) {}
685 
PoolingDescriptor()686 PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {}
687 
CloneFrom(const PoolingDescriptor & other)688 void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
689   mode_ = other.mode_;
690   ndims_ = other.ndims_;
691   window_ = other.window_;
692   padding_ = other.padding_;
693   strides_ = other.strides_;
694   propagate_nans_ = other.propagate_nans_;
695 }
696 
ToString() const697 std::string PoolingDescriptor::ToString() const {
698   const char* mode_string =
699       mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
700 
701   std::string window, strides, padding;
702   for (int i = 0; i < ndims_; i++) {
703     absl::StrAppendFormat(&window, "%d ", window_[i]);
704     absl::StrAppendFormat(&strides, "%d ", strides_[i]);
705     absl::StrAppendFormat(&padding, "%d", padding_[i]);
706   }
707 
708   const char* propagate_string = propagate_nans_ ? "Yes" : "No";
709 
710   return absl::StrFormat(
711       "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}",
712       mode_string, window, strides, padding, propagate_string);
713 }
714 
ToShortString() const715 std::string PoolingDescriptor::ToShortString() const {
716   std::string window, strides, padding;
717   for (int i = 0; i < ndims_; i++) {
718     absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]);
719     absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]);
720     absl::StrAppendFormat(&padding, "_p%d:%d", i, padding_[i]);
721   }
722   return absl::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
723                       window, strides, padding,
724                       propagate_nans_ ? "propagate_nans" : "ignore_nans");
725 }
726 
727 // -- NormalizeDescriptor
728 
NormalizeDescriptor()729 NormalizeDescriptor::NormalizeDescriptor()
730     : bias_(0.0),
731       range_(0),
732       alpha_(0.0),
733       beta_(0.0),
734       wrap_around_(false),
735       segment_size_(0) {}
736 
CloneFrom(const NormalizeDescriptor & other)737 void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
738   bias_ = other.bias_;
739   range_ = other.range_;
740   alpha_ = other.alpha_;
741   beta_ = other.beta_;
742   wrap_around_ = other.wrap_around_;
743   segment_size_ = other.segment_size_;
744 }
745 
ToString() const746 std::string NormalizeDescriptor::ToString() const {
747   return absl::StrFormat(
748       "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
749       "segment_size: %d}",
750       bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
751 }
752 
ToShortString() const753 std::string NormalizeDescriptor::ToShortString() const {
754   return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
755                       "_beta:", beta_, "_wrap:", wrap_around_,
756                       "_size:", segment_size_);
757 }
758 
IsStatusOk(const port::Status & status,bool report_error)759 bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
760   if (status.ok()) {
761     return true;
762   }
763   if (report_error) {
764     LOG(ERROR) << status.error_message();
765   }
766   return false;
767 }
768 
DoCtcLoss(Stream * stream,dnn::DataType element_type,const RnnStateTensorDescriptor & probs_desc,const DeviceMemoryBase probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemoryBase costs_data,const RnnStateTensorDescriptor & grads_desc,DeviceMemoryBase grads_data,DeviceMemory<uint8> scratch_memory,int ctc_loss_algo_id)769 port::Status DnnSupport::DoCtcLoss(
770     Stream* stream, dnn::DataType element_type,
771     const RnnStateTensorDescriptor& probs_desc,
772     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
773     absl::Span<const int> labels_lengths_data,
774     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
775     const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
776     DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
777   return port::UnimplementedError("CtcLoss not implemented");
778 }
779 
}  // namespace dnn
}  // namespace stream_executor