• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/stream_executor/dnn.h"
17 
18 #include "absl/hash/hash.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/core/lib/strings/proto_serialization.h"
23 
24 namespace stream_executor {
25 namespace dnn {
26 
27 namespace {
28 
ProtoMapIsSubset(const google::protobuf::Map<int64_t,int64_t> & x,const google::protobuf::Map<int64_t,int64_t> & y)29 bool ProtoMapIsSubset(const google::protobuf::Map<int64_t, int64_t>& x,
30                       const google::protobuf::Map<int64_t, int64_t>& y) {
31   for (const auto& ypair : y) {
32     const auto it = x.find(ypair.first);
33     if (it == x.end() || it->second != ypair.second) return false;
34   }
35   return true;
36 }
37 
ProtoMapsEqual(const google::protobuf::Map<int64_t,int64_t> & x,const google::protobuf::Map<int64_t,int64_t> & y)38 bool ProtoMapsEqual(const google::protobuf::Map<int64_t, int64_t>& x,
39                     const google::protobuf::Map<int64_t, int64_t>& y) {
40   return ProtoMapIsSubset(x, y) && ProtoMapIsSubset(y, x);
41 }
42 
43 }  // namespace
44 
// Out-of-line definitions for the static constexpr `value` members of the
// ToDataType<T> specializations declared in dnn.h.  Needed so the members
// can be ODR-used (taken by reference) pre-C++17, where static constexpr
// data members were not implicitly inline.
constexpr DataType ToDataType<float>::value;
constexpr DataType ToDataType<double>::value;
constexpr DataType ToDataType<Eigen::half>::value;
constexpr DataType ToDataType<Eigen::bfloat16>::value;
constexpr DataType ToDataType<int8>::value;
constexpr DataType ToDataType<int32>::value;
constexpr DataType ToDataType<std::complex<float>>::value;
constexpr DataType ToDataType<std::complex<double>>::value;
53 
AlgorithmDesc(int64_t engine_id,const std::vector<std::pair<int64_t,int64_t>> & tuning_knobs,std::optional<uint64_t> workspace_size)54 AlgorithmDesc::AlgorithmDesc(
55     int64_t engine_id,
56     const std::vector<std::pair<int64_t, int64_t>>& tuning_knobs,
57     std::optional<uint64_t> workspace_size) {
58   proto_.set_is_cudnn_frontend(true);
59   proto_.set_algo_id(engine_id);
60   if (workspace_size) {
61     proto_.mutable_workspace_size()->set_value(*workspace_size);
62   }
63   for (const auto& pair : tuning_knobs) {
64     (*proto_.mutable_tuning_knobs())[pair.first] = pair.second;
65   }
66 }
67 
// 64-bit hash of the underlying AlgorithmProto.  Uses the deterministic
// proto hash (stable across runs/processes, unlike default proto hashing).
uint64_t AlgorithmDesc::hash() const {
  return tensorflow::DeterministicProtoHash64(proto_);
}
71 
operator ==(const AlgorithmDesc & other) const72 bool AlgorithmDesc::operator==(const AlgorithmDesc& other) const {
73   if (is_cudnn_frontend()) {
74     return other.is_cudnn_frontend() && algo_id() == other.algo_id() &&
75            ProtoMapsEqual(proto_.tuning_knobs(), other.proto_.tuning_knobs());
76   }
77   return !other.is_cudnn_frontend() && algo_id() == other.algo_id() &&
78          tensor_ops_enabled() == other.tensor_ops_enabled();
79 }
80 
ToString() const81 std::string AlgorithmDesc::ToString() const {
82   if (is_cudnn_frontend()) {
83     // Format similarly to cudnn_frontend::ExecutionPlan::getTag(), e.g.
84     // "eng2{k1=2,k3=4}".
85     return absl::StrFormat(
86         "eng%d{%s}", proto_.algo_id(),
87         absl::StrJoin(
88             proto_.tuning_knobs(), ",",
89             [](std::string* out,
90                const google::protobuf::Map<int64_t, int64_t>::value_type& pair) {
91               absl::StrAppendFormat(out, "k%d=%d", pair.first, pair.second);
92             }));
93   }
94   if (tensor_ops_enabled()) {
95     return absl::StrCat(algo_id(), "#TC");
96   } else {
97     return absl::StrCat(algo_id());
98   }
99 }
100 
TuningKnobs() const101 std::vector<std::pair<int64_t, int64_t>> AlgorithmDesc::TuningKnobs() const {
102   std::vector<std::pair<int64_t, int64_t>> result;
103   result.reserve(proto_.tuning_knobs().size());
104   for (const auto& pair : proto_.tuning_knobs()) {
105     result.emplace_back(pair.first, pair.second);
106   }
107   return result;
108 }
109 
// Default stub: reports that no forward-convolution algorithms are
// available.  Concrete DnnSupport backends (e.g. cuDNN) override this.
bool DnnSupport::GetConvolveAlgorithms(
    CudaComputeCapability cuda_compute_capability,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
115 
// Default stub: convolution-runner enumeration is unimplemented in the base
// class; backends that support it override this method.
port::Status DnnSupport::GetConvolveRunners(
    bool /* use_cudnn_frontend */, dnn::ConvolutionKind /*kind*/,
    dnn::DataType /*input_type*/, dnn::DataType /*output_type*/,
    Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/,
    DeviceMemoryBase /*input_data*/,
    const dnn::FilterDescriptor& /*filter_descriptor*/,
    DeviceMemoryBase /*filter_data*/,
    const dnn::BatchDescriptor& /*output_descriptor*/,
    DeviceMemoryBase /*output_data*/,
    const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
    bool /*use_fallback*/, ScratchAllocator* /*scratch_allocator*/,
    std::vector<std::unique_ptr<const dnn::ConvRunner>>* /*exec_plans*/) {
  return port::UnimplementedError("GetConvolveRunners not implemented.");
}
130 
// Default stub: building a convolution runner from an AlgorithmDesc is
// unimplemented in the base class; backends override this.
port::StatusOr<std::unique_ptr<const dnn::ConvRunner>>
DnnSupport::ConvolveRunnerFromDesc(
    Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
    dnn::ConvolutionKind kind, dnn::DataType element_type,
    dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor) {
  return port::UnimplementedError("ConvolveRunnerFromDesc not implemented.");
}
141 
// Default stub: fused conv+bias+activation runner enumeration is
// unimplemented in the base class; backends override this.
port::Status DnnSupport::GetFusedConvolveRunners(
    bool use_cudnn_frontend, dnn::ConvolutionKind kind,
    dnn::DataType element_type, dnn::DataType bias_type,
    dnn::DataType output_type, double conv_input_scale, double side_input_scale,
    double leakyrelu_alpha, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& bias_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback,
    dnn::ActivationMode activation_mode,
    std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans) {
  return port::UnimplementedError("GetFusedConvolveRunners not implemented.");
}
156 
// Default stub: building a fused-convolution runner from an AlgorithmDesc
// is unimplemented in the base class; backends override this.
port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
DnnSupport::FusedConvolveRunnerFromDesc(
    Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
    dnn::ConvolutionKind kind, dnn::DataType element_type,
    dnn::DataType bias_type, dnn::DataType output_type, double conv_scale,
    double side_input_scale, double leakyrelu_alpha,
    const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& bias_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    dnn::ActivationMode activation_mode) {
  return port::UnimplementedError(
      "FusedConvolveRunnerFromDesc not implemented.");
}
172 
// Default stub: reports no MIOpen (ROCm) convolution algorithms.  The
// MIOpen backend overrides this.
bool DnnSupport::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/,
    Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/,
    DeviceMemoryBase input_data,
    const dnn::FilterDescriptor& /*filter_descriptor*/,
    DeviceMemoryBase filter_data,
    const dnn::BatchDescriptor& /*output_descriptor*/,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
    ScratchAllocator* scratch_allocator,
    std::vector<ProfileResult>* /*out_algorithms*/) {
  return false;
}
186 
// Default stub: reports that no RNN algorithms are available.
bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
190 
// Default stub: reports no backward-data convolution algorithms.
bool DnnSupport::GetConvolveBackwardDataAlgorithms(
    CudaComputeCapability cuda_compute_capability,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
196 
// Default stub: reports no backward-filter convolution algorithms.
bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
    CudaComputeCapability cuda_compute_capability,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}
202 
// Human-readable name for a quantized-activation storage width; falls back
// to "unknown: <n>" for unrecognized enum values.
std::string QuantizedActivationModeString(QuantizedActivationMode mode) {
  switch (mode) {
    case dnn::QuantizedActivationMode::k8Bit:
      return "uint8";
    case dnn::QuantizedActivationMode::k16Bit:
      return "uint16";
    case dnn::QuantizedActivationMode::k32Bit:
      return "int32";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(mode));
  }
}
215 
// Human-readable name for an activation function; falls back to
// "unknown: <n>" for unrecognized enum values.
std::string ActivationModeString(ActivationMode mode) {
  switch (mode) {
    case ActivationMode::kNone:
      return "none";
    case ActivationMode::kSigmoid:
      return "sigmoid";
    case ActivationMode::kRelu:
      return "relu";
    case ActivationMode::kRelu6:
      return "relu6";
    case ActivationMode::kReluX:
      return "reluX";
    case ActivationMode::kTanh:
      return "tanh";
    case ActivationMode::kBandPass:
      return "bandpass";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(mode));
  }
}
236 
// Human-readable name for an elementwise operation; falls back to
// "unknown: <n>" for unrecognized enum values.
std::string ElementwiseOperationString(ElementwiseOperation op) {
  switch (op) {
    case ElementwiseOperation::kAdd:
      return "add";
    case ElementwiseOperation::kMultiply:
      return "multiply";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(op));
  }
}
247 
// Human-readable name for a data (activation) layout; falls back to
// "unknown: <n>" for unrecognized enum values.
std::string DataLayoutString(DataLayout layout) {
  switch (layout) {
    case DataLayout::kYXDepthBatch:
      return "YXDepthBatch";
    case DataLayout::kYXBatchDepth:
      return "YXBatchDepth";
    case DataLayout::kBatchYXDepth:
      return "BatchYXDepth";
    case DataLayout::kBatchDepthYX:
      return "BatchDepthYX";
    case DataLayout::kBatchDepthYX4:
      return "BatchDepthYX4";
    case DataLayout::kBatchDepthYX32:
      return "BatchDepthYX32";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(layout));
  }
}
266 
// Human-readable name for a filter (weight) layout; falls back to
// "unknown: <n>" for unrecognized enum values.
std::string FilterLayoutString(FilterLayout layout) {
  switch (layout) {
    case FilterLayout::kOutputInputYX:
      return "OutputInputYX";
    case FilterLayout::kOutputYXInput:
      return "OutputYXInput";
    case FilterLayout::kOutputInputYX4:
      return "OutputInputYX4";
    case FilterLayout::kOutputInputYX32:
      return "OutputInputYX32";
    case FilterLayout::kInputYXOutput:
      return "InputYXOutput";
    case FilterLayout::kYXInputOutput:
      return "YXInputOutput";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(layout));
  }
}
285 
// Human-readable name for a padding-alignment convention; falls back to
// "unknown: <n>" for unrecognized enum values.
std::string PadAlignmentString(PadAlignment alignment) {
  switch (alignment) {
    case PadAlignment::kDefault:
      return "default";
    case PadAlignment::kCudnnPadding:
      return "cuDNN padding";
    case PadAlignment::kTensorFlowPadding:
      return "TensorFlow padding";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(alignment));
  }
}
298 
// Stream-insertion support for PadAlignment, delegating to
// PadAlignmentString() (used e.g. by LOG statements).
std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
  return str << PadAlignmentString(alignment);
}
302 
// Abbreviated name for a pooling mode ("Max"/"Avg"); falls back to
// "unknown: <n>" for unrecognized enum values.
std::string ShortPoolingModeString(PoolingMode mode) {
  switch (mode) {
    case PoolingMode::kMaximum:
      return "Max";
    case PoolingMode::kAverage:
      return "Avg";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(mode));
  }
}
313 
// Positions of the semantic dimensions within a layout's dimension array.
// Exactly one of the union's two views is meaningful, depending on whether
// the indices were produced from a DataLayout (`data`) or a FilterLayout
// (`filter`); both views overlay the same three ints.
struct ConvDimIndices {
  union {
    struct {
      int depth_idx;    // index of the feature-map (depth) dimension
      int batch_idx;    // index of the batch dimension
      int spatial_idx;  // index of the first spatial dimension
    } data;
    struct {
      int output_idx;   // index of the output-feature-map dimension
      int input_idx;    // index of the input-feature-map dimension
      int spatial_idx;  // index of the first spatial dimension
    } filter;
  };
};
328 
GetDimIndices(const DataLayout & layout,const int data_dims)329 ConvDimIndices GetDimIndices(const DataLayout& layout, const int data_dims) {
330   ConvDimIndices dim_indices;
331   switch (layout) {
332     case DataLayout::kYXBatchDepth:
333       dim_indices.data.depth_idx = data_dims - 1;
334       dim_indices.data.batch_idx = data_dims - 2;
335       dim_indices.data.spatial_idx = 0;
336       break;
337 
338     case DataLayout::kYXDepthBatch:
339       dim_indices.data.depth_idx = data_dims - 2;
340       dim_indices.data.batch_idx = data_dims - 1;
341       dim_indices.data.spatial_idx = 0;
342       break;
343 
344     case DataLayout::kBatchYXDepth:
345       dim_indices.data.depth_idx = data_dims - 1;
346       dim_indices.data.batch_idx = 0;
347       dim_indices.data.spatial_idx = 1;
348       break;
349 
350     case DataLayout::kBatchDepthYX:
351     case DataLayout::kBatchDepthYX4:
352     case DataLayout::kBatchDepthYX32:
353       dim_indices.data.depth_idx = 1;
354       dim_indices.data.batch_idx = 0;
355       dim_indices.data.spatial_idx = 2;
356       break;
357 
358     default:
359       LOG(FATAL) << "Unknown layout " << layout;
360   }
361 
362   return dim_indices;
363 }
364 
GetDimIndices(const FilterLayout & layout,const int data_dims)365 ConvDimIndices GetDimIndices(const FilterLayout& layout, const int data_dims) {
366   ConvDimIndices dim_indices;
367   switch (layout) {
368     case FilterLayout::kOutputInputYX:
369     case FilterLayout::kOutputInputYX4:
370     case FilterLayout::kOutputInputYX32:
371       dim_indices.filter.input_idx = 1;
372       dim_indices.filter.output_idx = 0;
373       dim_indices.filter.spatial_idx = 2;
374       break;
375 
376     case FilterLayout::kOutputYXInput:
377       dim_indices.filter.input_idx = data_dims - 1;
378       dim_indices.filter.output_idx = 0;
379       dim_indices.filter.spatial_idx = 1;
380       break;
381 
382     case FilterLayout::kInputYXOutput:
383       dim_indices.filter.input_idx = 0;
384       dim_indices.filter.output_idx = data_dims - 1;
385       dim_indices.filter.spatial_idx = 1;
386       break;
387 
388     case FilterLayout::kYXInputOutput:
389       dim_indices.filter.input_idx = data_dims - 2;
390       dim_indices.filter.output_idx = data_dims - 1;
391       dim_indices.filter.spatial_idx = 0;
392       break;
393 
394     default:
395       LOG(FATAL) << "Unknown layout " << layout;
396   }
397 
398   return dim_indices;
399 }
400 
ReorderDims(const std::vector<int64_t> & input,const DataLayout & from,const DataLayout & to)401 std::vector<int64_t> ReorderDims(const std::vector<int64_t>& input,
402                                  const DataLayout& from, const DataLayout& to) {
403   if (from == to) return input;
404 
405   ConvDimIndices from_indices = GetDimIndices(from, input.size());
406   ConvDimIndices to_indices = GetDimIndices(to, input.size());
407 
408   std::vector<int64_t> reordered(input.size());
409   reordered[to_indices.data.batch_idx] = input[from_indices.data.batch_idx];
410   reordered[to_indices.data.depth_idx] = input[from_indices.data.depth_idx];
411 
412   int spatial_idx_from = from_indices.data.spatial_idx;
413   int spatial_idx_to = to_indices.data.spatial_idx;
414   for (size_t i = 0; i < input.size() - 2;
415        i++, spatial_idx_from++, spatial_idx_to++) {
416     reordered[spatial_idx_to] = input[spatial_idx_from];
417   }
418 
419   return reordered;
420 }
421 
ReorderDims(const std::vector<int64_t> & input,const FilterLayout & from,const FilterLayout & to)422 std::vector<int64_t> ReorderDims(const std::vector<int64_t>& input,
423                                  const FilterLayout& from,
424                                  const FilterLayout& to) {
425   if (from == to) return input;
426 
427   ConvDimIndices from_indices = GetDimIndices(from, input.size());
428   ConvDimIndices to_indices = GetDimIndices(to, input.size());
429 
430   std::vector<int64_t> reordered(input.size());
431   reordered[to_indices.filter.output_idx] =
432       input[from_indices.filter.output_idx];
433   reordered[to_indices.filter.input_idx] = input[from_indices.filter.input_idx];
434 
435   int spatial_idx_from = from_indices.filter.spatial_idx;
436   int spatial_idx_to = to_indices.filter.spatial_idx;
437   for (size_t i = 0; i < input.size() - 2;
438        i++, spatial_idx_from++, spatial_idx_to++) {
439     reordered[spatial_idx_to] = input[spatial_idx_from];
440   }
441 
442   return reordered;
443 }
444 
445 // -- AlgorithmConfig
446 
ToString() const447 std::string AlgorithmConfig::ToString() const {
448   std::string algo = "none";
449   if (algorithm().has_value()) {
450     algo = algorithm()->ToString();
451   }
452   std::string algo_no_scratch = "none";
453   if (algorithm_no_scratch().has_value()) {
454     algo_no_scratch = algorithm_no_scratch()->ToString();
455   }
456   return absl::StrCat(algo, ", ", algo_no_scratch);
457 }
458 
459 // -- BatchDescriptor
460 
// Creates a descriptor with `ndims` spatial dimensions (plus batch and
// feature-map dimensions), all zero-initialized, with quantization bounds
// at their defaults.
BatchDescriptor::BatchDescriptor(int ndims)
    : value_max_(0.0),
      value_min_(0.0),
      quantized_activation_mode_(QuantizedActivationMode::k8Bit) {
  // ndims spatial dims + batch + feature_map, all initially 0.
  tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
  set_layout(DataLayout::kYXDepthBatch);
}

// Default: two spatial dimensions (the common 2-D convolution case).
BatchDescriptor::BatchDescriptor() : BatchDescriptor(/*ndims=*/2) {}
470 
full_dims(const DataLayout & layout) const471 std::vector<int64_t> BatchDescriptor::full_dims(
472     const DataLayout& layout) const {
473   std::vector<int64_t> bdyx_dims(ndims() + 2);
474   bdyx_dims[0] = count();
475   bdyx_dims[1] = feature_map_count();
476   std::copy(spatial_size().begin(), spatial_size().end(),
477             bdyx_dims.begin() + 2);
478   return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout);
479 }
480 
// Returns per-dimension strides (in elements) in `layout` order, assuming a
// dense packing in this descriptor's own physical layout.
std::vector<int64_t> BatchDescriptor::full_strides(
    const DataLayout& layout) const {
  std::vector<int64_t> phys_dims = full_dims(this->layout());
  std::vector<int64_t> phys_strides(phys_dims.size());
  // Innermost dimension has stride 1; each outer stride is the inner
  // stride times the inner dimension.
  phys_strides[ndims() + 1] = 1;
  for (int i = ndims(); i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}
491 
// Like full_dims(), but divides the dimension at `vector_dim` by
// `vector_size` (used for vectorized layouts such as NCHW_VECT_C).
// A `vector_dim` of -1 means no dimension is vectorized.
std::vector<int64_t> BatchDescriptor::vectorized_dims(const DataLayout& layout,
                                                      int vector_size,
                                                      int vector_dim) const {
  std::vector<int64_t> bdyx_dims = full_dims(dnn::DataLayout::kBatchDepthYX);
  if (vector_dim != -1) {
    bdyx_dims[vector_dim] /= vector_size;
  }
  return dnn::ReorderDims(bdyx_dims, dnn::DataLayout::kBatchDepthYX, layout);
}
501 
// Strides corresponding to vectorized_dims(): computed densely over the
// vector-adjusted dimensions in the physical layout, then permuted.
std::vector<int64_t> BatchDescriptor::vectorized_strides(
    const DataLayout& layout, int vector_size, int vector_dim) const {
  std::vector<int64_t> phys_dims =
      vectorized_dims(this->layout(), vector_size, vector_dim);
  std::vector<int64_t> phys_strides(phys_dims.size());
  // Innermost stride is 1; work outward multiplying by inner dims.
  phys_strides[phys_dims.size() - 1] = 1;
  for (int i = phys_dims.size() - 2; i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}
513 
// Copies every field of `other` into this descriptor (tensor proto plus
// quantization bounds/mode).
void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
  tensor_ = other.tensor_;
  value_max_ = other.value_max_;
  value_min_ = other.value_min_;
  quantized_activation_mode_ = other.quantized_activation_mode_;
}
520 
// Verbose, human-readable dump of the descriptor's fields.
std::string BatchDescriptor::ToString() const {
  std::string spatial;
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
  }
  return absl::StrFormat(
      "{count: %d feature_map_count: %d spatial: %s "
      "value_min: %f value_max: %f layout: %s}",
      count(), feature_map_count(), spatial, value_min_, value_max_,
      DataLayoutString(layout()));
}
532 
// Compact, human-readable form: batch ("b"), depth ("d") and spatial ("s")
// components ordered according to the layout, with optional suffixes for
// non-trivial quantization bounds, 16-bit activations, and VECT_C layouts.
std::string BatchDescriptor::ToShortString() const {
  // All the constituent strings are less than 15 characters, so the
  // small string optimization ensures that there will be at most one
  // heap memory allocation.
  std::string depth = absl::StrCat("d", feature_map_count());
  std::string batch = absl::StrCat("b", count());

  std::string spatial = "s";
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
  }

  std::string suffix;
  // Bounds are only printed when they carry information (min != max).
  if (value_min() != value_max()) {
    absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
  }
  if (quantized_activation_mode() == QuantizedActivationMode::k16Bit) {
    suffix += "_16bit";
  }

  switch (layout()) {
    case DataLayout::kYXDepthBatch:
      return absl::StrCat(spatial, depth, batch, suffix);
    case DataLayout::kYXBatchDepth:
      return absl::StrCat(spatial, batch, depth, suffix);
    case DataLayout::kBatchYXDepth:
      return absl::StrCat(batch, spatial, depth, suffix);
    case DataLayout::kBatchDepthYX:
      return absl::StrCat(batch, depth, spatial, suffix);
    case DataLayout::kBatchDepthYX4:
    case DataLayout::kBatchDepthYX32:
      return absl::StrCat(batch, depth, spatial, suffix, "(VECT_C)");
    default:
      LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
      return "";  // Avoid return warning (unreachable)
  }
}
570 
NodesPerFeatureMap() const571 int64_t BatchDescriptor::NodesPerFeatureMap() const {
572   int64_t ret = 1;
573   for (int i = 0; i < ndims(); i++) {
574     ret *= spatial_size()[i];
575   }
576   return ret;
577 }
578 
// Number of elements in one batch item: spatial volume times feature maps.
int64_t BatchDescriptor::NodesAcrossFeatureMaps() const {
  return NodesPerFeatureMap() * feature_map_count();
}

// Total number of elements described: batch * feature maps * spatial volume.
int64_t BatchDescriptor::ElementCount() const {
  return count() * feature_map_count() * NodesPerFeatureMap();
}

// Weight count for a dense (fully connected) layer between the two
// descriptors: every input node connects to every output node.
int64_t BatchDescriptor::FullyConnectedWeightCount(
    const BatchDescriptor& input, const BatchDescriptor& output) {
  return input.NodesAcrossFeatureMaps() * output.NodesAcrossFeatureMaps();
}

// Bias count for a dense layer: one bias per output node.
int64_t BatchDescriptor::FullyConnectedBiasCount(
    const BatchDescriptor& output) {
  return output.NodesAcrossFeatureMaps();
}
596 
DepthConcatenateOutputDescriptor(port::ArraySlice<dnn::BatchDescriptor> inputs)597 BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
598     port::ArraySlice<dnn::BatchDescriptor> inputs) {  // non-absl ok
599   if (inputs.empty()) {
600     return BatchDescriptor();
601   }
602   int feature_map_count = 0;
603   for (const auto& dimensions : inputs) {
604     feature_map_count += dimensions.feature_map_count();
605   }
606   BatchDescriptor output = inputs[0];
607   output.set_feature_map_count(feature_map_count);
608   return output;
609 }
610 
// Converts to a TensorDescriptorProto carrying `data_type`.  Only valid for
// non-quantized descriptors: the CHECKs enforce that the quantization
// fields still hold their constructor defaults, since those fields are not
// representable in the proto.
TensorDescriptorProto BatchDescriptor::ToProto(DataType data_type) const {
  CHECK_EQ(0.0, value_max_);
  CHECK_EQ(0.0, value_min_);
  CHECK(quantized_activation_mode_ == QuantizedActivationMode::k8Bit);

  TensorDescriptorProto ret = tensor_;
  ret.set_data_type(data_type);
  return ret;
}
620 
621 // -- FilterDescriptor
622 
// Creates a filter descriptor with `ndims` spatial dimensions (plus output
// and input feature-map dimensions), all zero-initialized.
FilterDescriptor::FilterDescriptor(int ndims) {
  // ndims spatial dims + output feature maps + input feature maps.
  tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
  set_layout(FilterLayout::kOutputInputYX);
}

// Default: two spatial dimensions.
FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {}

FilterDescriptor::~FilterDescriptor() {}
631 
// Copies `other`'s tensor proto (dimensions + layout) into this descriptor.
void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
  tensor_ = other.tensor_;
}
635 
// Verbose, human-readable dump of the filter's feature-map counts, layout,
// and spatial shape.
std::string FilterDescriptor::ToString() const {
  std::string desc = absl::StrFormat(
      "{output_feature_map_count: %d input_feature_map_count: %d "
      "layout: %s shape: ",
      output_feature_map_count(), input_feature_map_count(),
      FilterLayoutString(layout()));
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&desc, "%d ", input_filter_dims()[i]);
  }
  absl::StrAppend(&desc, "}");

  return desc;
}
649 
// Compact form: output-depth ("od"), input-depth ("id") and spatial ("s")
// components ordered according to the layout; VECT_C layouts are tagged.
std::string FilterDescriptor::ToShortString() const {
  // All the constituent strings are less than 15 characters, so the
  // small string optimization ensures that there will be at most one
  // heap memory allocation.
  std::string od = absl::StrCat("od", output_feature_map_count());
  std::string id = absl::StrCat("id", input_feature_map_count());

  std::string spatial = "s";
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]);
  }

  switch (layout()) {
    case FilterLayout::kOutputInputYX:
      return absl::StrCat(od, id, spatial);
    case FilterLayout::kOutputYXInput:
      return absl::StrCat(od, spatial, id);
    case FilterLayout::kOutputInputYX4:
    case FilterLayout::kOutputInputYX32:
      return absl::StrCat(od, id, spatial, "(VECT_C)");
    case FilterLayout::kInputYXOutput:
      return absl::StrCat(id, spatial, od);
    case FilterLayout::kYXInputOutput:
      return absl::StrCat(spatial, id, od);
    default:
      LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
      return "";  // Avoid return warning (unreachable)
  }
}
679 
ComputeWeightCount() const680 int64_t FilterDescriptor::ComputeWeightCount() const {
681   int64_t ret = output_feature_map_count() * input_feature_map_count();
682   for (int i = 0; i < ndims(); i++) {
683     ret *= input_filter_dims()[i];
684   }
685   return ret;
686 }
687 
full_dims(const FilterLayout & layout) const688 std::vector<int64_t> FilterDescriptor::full_dims(
689     const FilterLayout& layout) const {
690   std::vector<int64_t> oiyx_dims(ndims() + 2);
691   oiyx_dims[0] = output_feature_map_count();
692   oiyx_dims[1] = input_feature_map_count();
693   std::copy(input_filter_dims().begin(), input_filter_dims().end(),
694             oiyx_dims.begin() + 2);
695   return ReorderDims(oiyx_dims, FilterLayout::kOutputInputYX, layout);
696 }
697 
// Returns per-dimension strides (in elements) in `layout` order, assuming a
// dense packing in this descriptor's own physical layout.
std::vector<int64_t> FilterDescriptor::full_strides(
    const FilterLayout& layout) const {
  std::vector<int64_t> phys_dims = full_dims(this->layout());
  std::vector<int64_t> phys_strides(phys_dims.size());
  // Innermost dimension has stride 1; each outer stride is the inner
  // stride times the inner dimension.
  phys_strides[ndims() + 1] = 1;
  for (int i = ndims(); i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}
708 
// Like full_dims(), but divides the dimension at `vector_dim` by
// `vector_size` (for vectorized layouts such as OIHW_VECT_C).  A
// `vector_dim` of -1 means no dimension is vectorized.
std::vector<int64_t> FilterDescriptor::vectorized_dims(
    const FilterLayout& layout, int vector_size, int vector_dim) const {
  std::vector<int64_t> oiyx_dims = full_dims(dnn::FilterLayout::kOutputInputYX);
  if (vector_dim != -1) {
    oiyx_dims[vector_dim] /= vector_size;
  }
  return ReorderDims(oiyx_dims, FilterLayout::kOutputInputYX, layout);
}
717 
// Strides corresponding to vectorized_dims(): computed densely over the
// vector-adjusted dimensions in the physical layout, then permuted.
std::vector<int64_t> FilterDescriptor::vectorized_strides(
    const FilterLayout& layout, int vector_size, int vector_dim) const {
  std::vector<int64_t> phys_dims =
      vectorized_dims(this->layout(), vector_size, vector_dim);
  std::vector<int64_t> phys_strides(phys_dims.size());
  // Innermost stride is 1; work outward multiplying by inner dims.
  phys_strides[phys_dims.size() - 1] = 1;
  for (int i = phys_dims.size() - 2; i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}
729 
ToProto(DataType data_type) const730 TensorDescriptorProto FilterDescriptor::ToProto(DataType data_type) const {
731   TensorDescriptorProto ret = tensor_;
732   ret.set_data_type(data_type);
733   return ret;
734 }
735 
736 // -- ConvolutionDescriptor
737 
ConvolutionDescriptor(int ndims)738 ConvolutionDescriptor::ConvolutionDescriptor(int ndims) {
739   proto_.mutable_paddings()->Resize(ndims, 0);
740   proto_.mutable_strides()->Resize(ndims, 1);
741   proto_.mutable_dilations()->Resize(ndims, 1);
742   proto_.set_group_count(1);
743   proto_.set_convolution_mode(ConvolutionMode::CROSS_CORRELATION);
744 }
745 
// Defaults to a convolution with two spatial dimensions.
ConvolutionDescriptor::ConvolutionDescriptor()
    : ConvolutionDescriptor(/*ndims=*/2) {}
748 
~ConvolutionDescriptor()749 ConvolutionDescriptor::~ConvolutionDescriptor() {}
750 
ToString() const751 std::string ConvolutionDescriptor::ToString() const {
752   std::string padding;
753   std::string strides;
754   std::string dilations;
755   for (int i = 0; i < ndims(); i++) {
756     absl::StrAppendFormat(&padding, "%d ", this->padding()[i]);
757     absl::StrAppendFormat(&strides, "%d ", this->strides()[i]);
758     absl::StrAppendFormat(&dilations, "%d ", this->dilations()[i]);
759   }
760 
761   return absl::StrFormat(
762       "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: "
763       "%s}",
764       padding, PadAlignmentString(pad_alignment()), strides, dilations);
765 }
766 
ToShortString() const767 std::string ConvolutionDescriptor::ToShortString() const {
768   std::string desc;
769   for (int i = 0; i < ndims(); i++) {
770     if (i > 0) absl::StrAppend(&desc, "_");
771     absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]);
772   }
773   for (int i = 0; i < ndims(); i++) {
774     absl::StrAppendFormat(&desc, "_s%d:%d", i, strides()[i]);
775   }
776   for (int i = 0; i < ndims(); i++) {
777     absl::StrAppendFormat(&desc, "_d%d:%d", i, dilations()[i]);
778   }
779   return desc;
780 }
781 
782 // -- PoolingDescriptor
783 
// Constructs an `ndims`-dimensional pooling descriptor with defaults:
// max-pooling, zero-sized windows, zero padding, unit strides, and NaN
// propagation disabled. Callers populate real window/stride values via the
// setters afterwards.
PoolingDescriptor::PoolingDescriptor(int ndims)
    : mode_(dnn::PoolingMode::kMaximum),
      ndims_(ndims),
      propagate_nans_(false),
      window_(ndims, 0),
      padding_(ndims, 0),
      strides_(ndims, 1) {}
791 
PoolingDescriptor()792 PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {}
793 
CloneFrom(const PoolingDescriptor & other)794 void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
795   mode_ = other.mode_;
796   ndims_ = other.ndims_;
797   window_ = other.window_;
798   padding_ = other.padding_;
799   strides_ = other.strides_;
800   propagate_nans_ = other.propagate_nans_;
801 }
802 
ToString() const803 std::string PoolingDescriptor::ToString() const {
804   const char* mode_string =
805       mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
806 
807   std::string window, strides, padding;
808   for (int i = 0; i < ndims_; i++) {
809     absl::StrAppendFormat(&window, "%d ", window_[i]);
810     absl::StrAppendFormat(&strides, "%d ", strides_[i]);
811     absl::StrAppendFormat(&padding, "%d", padding_[i]);
812   }
813 
814   const char* propagate_string = propagate_nans_ ? "Yes" : "No";
815 
816   return absl::StrFormat(
817       "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}",
818       mode_string, window, strides, padding, propagate_string);
819 }
820 
ToShortString() const821 std::string PoolingDescriptor::ToShortString() const {
822   std::string window, strides, padding;
823   for (int i = 0; i < ndims_; i++) {
824     absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]);
825     absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]);
826     absl::StrAppendFormat(&padding, "_p%d:%d", i, padding_[i]);
827   }
828   return absl::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
829                       window, strides, padding,
830                       propagate_nans_ ? "propagate_nans" : "ignore_nans");
831 }
832 
833 // -- NormalizeDescriptor
834 
// Zero-initializes all normalization parameters; callers populate them via
// the set_* builder methods.
NormalizeDescriptor::NormalizeDescriptor()
    : bias_(0.0),
      range_(0),
      alpha_(0.0),
      beta_(0.0),
      wrap_around_(false),
      segment_size_(0) {}
842 
CloneFrom(const NormalizeDescriptor & other)843 void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
844   bias_ = other.bias_;
845   range_ = other.range_;
846   alpha_ = other.alpha_;
847   beta_ = other.beta_;
848   wrap_around_ = other.wrap_around_;
849   segment_size_ = other.segment_size_;
850 }
851 
// Returns a verbose, human-readable rendering of all normalization fields.
std::string NormalizeDescriptor::ToString() const {
  return absl::StrFormat(
      "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
      "segment_size: %d}",
      bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
}
858 
// Returns a compact underscore-separated rendering of all normalization
// fields.
std::string NormalizeDescriptor::ToShortString() const {
  return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
                      "_beta:", beta_, "_wrap:", wrap_around_,
                      "_size:", segment_size_);
}
864 
IsStatusOk(const port::Status & status,bool report_error)865 bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
866   if (status.ok()) {
867     return true;
868   }
869   if (report_error) {
870     LOG(ERROR) << status.error_message();
871   }
872   return false;
873 }
874 
// Default implementation of the CTC-loss entry point. Backends that support
// CTC loss override this; the base class always fails with an
// Unimplemented status. All tensor/memory parameters are ignored here.
port::Status DnnSupport::DoCtcLoss(
    Stream* stream, dnn::DataType element_type,
    const RnnStateTensorDescriptor& probs_desc,
    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
    absl::Span<const int> labels_lengths_data,
    absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
    const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
    DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
  return port::UnimplementedError("CtcLoss not implemented");
}
885 
886 }  // namespace dnn
887 }  // namespace stream_executor
888