/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/stream_executor/dnn.h"

#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"

namespace stream_executor {
namespace dnn {

namespace {

// Returns true if every key/value pair in `y` is also present in `x`.
bool ProtoMapIsSubset(const google::protobuf::Map<int64_t, int64_t>& x,
                      const google::protobuf::Map<int64_t, int64_t>& y) {
  for (const auto& ypair : y) {
    const auto it = x.find(ypair.first);
    if (it == x.end() || it->second != ypair.second) return false;
  }
  return true;
}

// Returns true if `x` and `y` contain exactly the same key/value pairs.
bool ProtoMapsEqual(const google::protobuf::Map<int64_t, int64_t>& x,
                    const google::protobuf::Map<int64_t, int64_t>& y) {
  return ProtoMapIsSubset(x, y) && ProtoMapIsSubset(y, x);
}
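
// Illustrative usage note (not part of the upstream API surface): equality of
// tuning-knob maps is order-insensitive, since it is defined as mutual subset
// containment. A minimal sketch:
//
//   google::protobuf::Map<int64_t, int64_t> a, b;
//   a[1] = 2; a[3] = 4;
//   b[3] = 4; b[1] = 2;
//   assert(ProtoMapsEqual(a, b));   // insertion order does not matter
//   b[5] = 6;
//   assert(!ProtoMapsEqual(a, b));  // b now has an extra knob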

}  // namespace

constexpr DataType ToDataType<float>::value;
constexpr DataType ToDataType<double>::value;
constexpr DataType ToDataType<Eigen::half>::value;
constexpr DataType ToDataType<Eigen::bfloat16>::value;
constexpr DataType ToDataType<int8>::value;
constexpr DataType ToDataType<int32>::value;
constexpr DataType ToDataType<std::complex<float>>::value;
constexpr DataType ToDataType<std::complex<double>>::value;

AlgorithmDesc::AlgorithmDesc(
    int64_t engine_id,
    const std::vector<std::pair<int64_t, int64_t>>& tuning_knobs,
    std::optional<uint64_t> workspace_size) {
  proto_.set_is_cudnn_frontend(true);
  proto_.set_algo_id(engine_id);
  if (workspace_size) {
    proto_.mutable_workspace_size()->set_value(*workspace_size);
  }
  for (const auto& pair : tuning_knobs) {
    (*proto_.mutable_tuning_knobs())[pair.first] = pair.second;
  }
}
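
// A minimal construction sketch (hypothetical engine id and knob values):
//
//   AlgorithmDesc desc(/*engine_id=*/2, /*tuning_knobs=*/{{1, 2}, {3, 4}},
//                      /*workspace_size=*/std::nullopt);
//   // desc.ToString() yields "eng2{k1=2,k3=4}"; knob ordering in the string
//   // is not guaranteed, because protobuf map iteration order is unspecified.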

uint64_t AlgorithmDesc::hash() const {
  return tensorflow::DeterministicProtoHash64(proto_);
}

bool AlgorithmDesc::operator==(const AlgorithmDesc& other) const {
  if (is_cudnn_frontend()) {
    return other.is_cudnn_frontend() && algo_id() == other.algo_id() &&
           ProtoMapsEqual(proto_.tuning_knobs(), other.proto_.tuning_knobs());
  }
  return !other.is_cudnn_frontend() && algo_id() == other.algo_id() &&
         tensor_ops_enabled() == other.tensor_ops_enabled();
}

std::string AlgorithmDesc::ToString() const {
  if (is_cudnn_frontend()) {
    // Format similarly to cudnn_frontend::ExecutionPlan::getTag(), e.g.
    // "eng2{k1=2,k3=4}".
    return absl::StrFormat(
        "eng%d{%s}", proto_.algo_id(),
        absl::StrJoin(
            proto_.tuning_knobs(), ",",
            [](std::string* out,
               const google::protobuf::Map<int64_t, int64_t>::value_type&
                   pair) {
              absl::StrAppendFormat(out, "k%d=%d", pair.first, pair.second);
            }));
  }
  if (tensor_ops_enabled()) {
    return absl::StrCat(algo_id(), "#TC");
  } else {
    return absl::StrCat(algo_id());
  }
}

std::vector<std::pair<int64_t, int64_t>> AlgorithmDesc::TuningKnobs() const {
  std::vector<std::pair<int64_t, int64_t>> result;
  result.reserve(proto_.tuning_knobs().size());
  for (const auto& pair : proto_.tuning_knobs()) {
    result.emplace_back(pair.first, pair.second);
  }
  return result;
}

bool DnnSupport::GetConvolveAlgorithms(
    CudaComputeCapability cuda_compute_capability,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}

port::Status DnnSupport::GetConvolveRunners(
    bool /* use_cudnn_frontend */, dnn::ConvolutionKind /*kind*/,
    dnn::DataType /*input_type*/, dnn::DataType /*output_type*/,
    Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/,
    DeviceMemoryBase /*input_data*/,
    const dnn::FilterDescriptor& /*filter_descriptor*/,
    DeviceMemoryBase /*filter_data*/,
    const dnn::BatchDescriptor& /*output_descriptor*/,
    DeviceMemoryBase /*output_data*/,
    const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
    bool /*use_fallback*/, ScratchAllocator* /*scratch_allocator*/,
    std::vector<std::unique_ptr<const dnn::ConvRunner>>* /*exec_plans*/) {
  return port::UnimplementedError("GetConvolveRunners not implemented.");
}

port::StatusOr<std::unique_ptr<const dnn::ConvRunner>>
DnnSupport::ConvolveRunnerFromDesc(
    Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
    dnn::ConvolutionKind kind, dnn::DataType element_type,
    dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor) {
  return port::UnimplementedError("ConvolveRunnerFromDesc not implemented.");
}

port::Status DnnSupport::GetFusedConvolveRunners(
    bool use_cudnn_frontend, dnn::ConvolutionKind kind,
    dnn::DataType element_type, dnn::DataType bias_type,
    dnn::DataType output_type, double conv_input_scale,
    double side_input_scale, double leakyrelu_alpha, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& bias_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    bool use_fallback, dnn::ActivationMode activation_mode,
    std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans) {
  return port::UnimplementedError("GetFusedConvolveRunners not implemented.");
}

port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
DnnSupport::FusedConvolveRunnerFromDesc(
    Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
    dnn::ConvolutionKind kind, dnn::DataType element_type,
    dnn::DataType bias_type, dnn::DataType output_type, double conv_scale,
    double side_input_scale, double leakyrelu_alpha,
    const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& bias_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    dnn::ActivationMode activation_mode) {
  return port::UnimplementedError(
      "FusedConvolveRunnerFromDesc not implemented.");
}

bool DnnSupport::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/,
    Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/,
    DeviceMemoryBase input_data,
    const dnn::FilterDescriptor& /*filter_descriptor*/,
    DeviceMemoryBase filter_data,
    const dnn::BatchDescriptor& /*output_descriptor*/,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
    ScratchAllocator* scratch_allocator,
    std::vector<ProfileResult>* /*out_algorithms*/) {
  return false;
}

bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}

bool DnnSupport::GetConvolveBackwardDataAlgorithms(
    CudaComputeCapability cuda_compute_capability,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}

bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
    CudaComputeCapability cuda_compute_capability,
    std::vector<AlgorithmDesc>* out_algorithms) {
  return false;
}

std::string QuantizedActivationModeString(QuantizedActivationMode mode) {
  switch (mode) {
    case dnn::QuantizedActivationMode::k8Bit:
      return "uint8";
    case dnn::QuantizedActivationMode::k16Bit:
      return "uint16";
    case dnn::QuantizedActivationMode::k32Bit:
      return "int32";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(mode));
  }
}

std::string ActivationModeString(ActivationMode mode) {
  switch (mode) {
    case ActivationMode::kNone:
      return "none";
    case ActivationMode::kSigmoid:
      return "sigmoid";
    case ActivationMode::kRelu:
      return "relu";
    case ActivationMode::kRelu6:
      return "relu6";
    case ActivationMode::kReluX:
      return "reluX";
    case ActivationMode::kTanh:
      return "tanh";
    case ActivationMode::kBandPass:
      return "bandpass";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(mode));
  }
}

std::string ElementwiseOperationString(ElementwiseOperation op) {
  switch (op) {
    case ElementwiseOperation::kAdd:
      return "add";
    case ElementwiseOperation::kMultiply:
      return "multiply";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(op));
  }
}

std::string DataLayoutString(DataLayout layout) {
  switch (layout) {
    case DataLayout::kYXDepthBatch:
      return "YXDepthBatch";
    case DataLayout::kYXBatchDepth:
      return "YXBatchDepth";
    case DataLayout::kBatchYXDepth:
      return "BatchYXDepth";
    case DataLayout::kBatchDepthYX:
      return "BatchDepthYX";
    case DataLayout::kBatchDepthYX4:
      return "BatchDepthYX4";
    case DataLayout::kBatchDepthYX32:
      return "BatchDepthYX32";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(layout));
  }
}

std::string FilterLayoutString(FilterLayout layout) {
  switch (layout) {
    case FilterLayout::kOutputInputYX:
      return "OutputInputYX";
    case FilterLayout::kOutputYXInput:
      return "OutputYXInput";
    case FilterLayout::kOutputInputYX4:
      return "OutputInputYX4";
    case FilterLayout::kOutputInputYX32:
      return "OutputInputYX32";
    case FilterLayout::kInputYXOutput:
      return "InputYXOutput";
    case FilterLayout::kYXInputOutput:
      return "YXInputOutput";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(layout));
  }
}

std::string PadAlignmentString(PadAlignment alignment) {
  switch (alignment) {
    case PadAlignment::kDefault:
      return "default";
    case PadAlignment::kCudnnPadding:
      return "cuDNN padding";
    case PadAlignment::kTensorFlowPadding:
      return "TensorFlow padding";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(alignment));
  }
}

std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
  return str << PadAlignmentString(alignment);
}

std::string ShortPoolingModeString(PoolingMode mode) {
  switch (mode) {
    case PoolingMode::kMaximum:
      return "Max";
    case PoolingMode::kAverage:
      return "Avg";
    default:
      return absl::StrCat("unknown: ", static_cast<int32_t>(mode));
  }
}

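// Helper describing where the depth/batch (or input/output feature) and the
// first spatial dimension live for a given layout; the union reflects that a
// data layout and a filter layout index the same underlying dimension array.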
struct ConvDimIndices {
  union {
    struct {
      int depth_idx;
      int batch_idx;
      int spatial_idx;
    } data;
    struct {
      int output_idx;
      int input_idx;
      int spatial_idx;
    } filter;
  };
};

ConvDimIndices GetDimIndices(const DataLayout& layout, const int data_dims) {
  ConvDimIndices dim_indices;
  switch (layout) {
    case DataLayout::kYXBatchDepth:
      dim_indices.data.depth_idx = data_dims - 1;
      dim_indices.data.batch_idx = data_dims - 2;
      dim_indices.data.spatial_idx = 0;
      break;

    case DataLayout::kYXDepthBatch:
      dim_indices.data.depth_idx = data_dims - 2;
      dim_indices.data.batch_idx = data_dims - 1;
      dim_indices.data.spatial_idx = 0;
      break;

    case DataLayout::kBatchYXDepth:
      dim_indices.data.depth_idx = data_dims - 1;
      dim_indices.data.batch_idx = 0;
      dim_indices.data.spatial_idx = 1;
      break;

    case DataLayout::kBatchDepthYX:
    case DataLayout::kBatchDepthYX4:
    case DataLayout::kBatchDepthYX32:
      dim_indices.data.depth_idx = 1;
      dim_indices.data.batch_idx = 0;
      dim_indices.data.spatial_idx = 2;
      break;

    default:
      LOG(FATAL) << "Unknown layout " << layout;
  }

  return dim_indices;
}

ConvDimIndices GetDimIndices(const FilterLayout& layout, const int data_dims) {
  ConvDimIndices dim_indices;
  switch (layout) {
    case FilterLayout::kOutputInputYX:
    case FilterLayout::kOutputInputYX4:
    case FilterLayout::kOutputInputYX32:
      dim_indices.filter.input_idx = 1;
      dim_indices.filter.output_idx = 0;
      dim_indices.filter.spatial_idx = 2;
      break;

    case FilterLayout::kOutputYXInput:
      dim_indices.filter.input_idx = data_dims - 1;
      dim_indices.filter.output_idx = 0;
      dim_indices.filter.spatial_idx = 1;
      break;

    case FilterLayout::kInputYXOutput:
      dim_indices.filter.input_idx = 0;
      dim_indices.filter.output_idx = data_dims - 1;
      dim_indices.filter.spatial_idx = 1;
      break;

    case FilterLayout::kYXInputOutput:
      dim_indices.filter.input_idx = data_dims - 2;
      dim_indices.filter.output_idx = data_dims - 1;
      dim_indices.filter.spatial_idx = 0;
      break;

    default:
      LOG(FATAL) << "Unknown layout " << layout;
  }

  return dim_indices;
}

std::vector<int64_t> ReorderDims(const std::vector<int64_t>& input,
                                 const DataLayout& from, const DataLayout& to) {
  if (from == to) return input;

  ConvDimIndices from_indices = GetDimIndices(from, input.size());
  ConvDimIndices to_indices = GetDimIndices(to, input.size());

  std::vector<int64_t> reordered(input.size());
  reordered[to_indices.data.batch_idx] = input[from_indices.data.batch_idx];
  reordered[to_indices.data.depth_idx] = input[from_indices.data.depth_idx];

  int spatial_idx_from = from_indices.data.spatial_idx;
  int spatial_idx_to = to_indices.data.spatial_idx;
  for (size_t i = 0; i < input.size() - 2;
       i++, spatial_idx_from++, spatial_idx_to++) {
    reordered[spatial_idx_to] = input[spatial_idx_from];
  }

  return reordered;
}
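
// Worked example (illustrative values): reordering NCHW dims to NHWC.
//
//   std::vector<int64_t> nchw = {32, 16, 28, 28};  // kBatchDepthYX order
//   std::vector<int64_t> nhwc = ReorderDims(nchw, DataLayout::kBatchDepthYX,
//                                           DataLayout::kBatchYXDepth);
//   // nhwc == {32, 28, 28, 16}: batch stays in front, depth moves to the
//   // back, and the two spatial dims shift left by one slot.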

std::vector<int64_t> ReorderDims(const std::vector<int64_t>& input,
                                 const FilterLayout& from,
                                 const FilterLayout& to) {
  if (from == to) return input;

  ConvDimIndices from_indices = GetDimIndices(from, input.size());
  ConvDimIndices to_indices = GetDimIndices(to, input.size());

  std::vector<int64_t> reordered(input.size());
  reordered[to_indices.filter.output_idx] =
      input[from_indices.filter.output_idx];
  reordered[to_indices.filter.input_idx] = input[from_indices.filter.input_idx];

  int spatial_idx_from = from_indices.filter.spatial_idx;
  int spatial_idx_to = to_indices.filter.spatial_idx;
  for (size_t i = 0; i < input.size() - 2;
       i++, spatial_idx_from++, spatial_idx_to++) {
    reordered[spatial_idx_to] = input[spatial_idx_from];
  }

  return reordered;
}

// -- AlgorithmConfig

std::string AlgorithmConfig::ToString() const {
  std::string algo = "none";
  if (algorithm().has_value()) {
    algo = algorithm()->ToString();
  }
  std::string algo_no_scratch = "none";
  if (algorithm_no_scratch().has_value()) {
    algo_no_scratch = algorithm_no_scratch()->ToString();
  }
  return absl::StrCat(algo, ", ", algo_no_scratch);
}

// -- BatchDescriptor

BatchDescriptor::BatchDescriptor(int ndims)
    : value_max_(0.0),
      value_min_(0.0),
      quantized_activation_mode_(QuantizedActivationMode::k8Bit) {
  tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
  set_layout(DataLayout::kYXDepthBatch);
}

BatchDescriptor::BatchDescriptor() : BatchDescriptor(/*ndims=*/2) {}

std::vector<int64_t> BatchDescriptor::full_dims(
    const DataLayout& layout) const {
  std::vector<int64_t> bdyx_dims(ndims() + 2);
  bdyx_dims[0] = count();
  bdyx_dims[1] = feature_map_count();
  std::copy(spatial_size().begin(), spatial_size().end(),
            bdyx_dims.begin() + 2);
  return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout);
}

std::vector<int64_t> BatchDescriptor::full_strides(
    const DataLayout& layout) const {
  std::vector<int64_t> phys_dims = full_dims(this->layout());
  std::vector<int64_t> phys_strides(phys_dims.size());
  phys_strides[ndims() + 1] = 1;
  for (int i = ndims(); i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}
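
// Worked example (illustrative values): for a kBatchDepthYX (NCHW) descriptor
// with n=2, c=3, h=4, w=5, the physical dims are {2, 3, 4, 5} and
// full_strides() yields {60, 20, 5, 1}: the innermost (w) stride is 1 and
// each other stride is the product of all dims to its right.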

std::vector<int64_t> BatchDescriptor::vectorized_dims(const DataLayout& layout,
                                                      int vector_size,
                                                      int vector_dim) const {
  std::vector<int64_t> bdyx_dims = full_dims(dnn::DataLayout::kBatchDepthYX);
  if (vector_dim != -1) {
    bdyx_dims[vector_dim] /= vector_size;
  }
  return dnn::ReorderDims(bdyx_dims, dnn::DataLayout::kBatchDepthYX, layout);
}

std::vector<int64_t> BatchDescriptor::vectorized_strides(
    const DataLayout& layout, int vector_size, int vector_dim) const {
  std::vector<int64_t> phys_dims =
      vectorized_dims(this->layout(), vector_size, vector_dim);
  std::vector<int64_t> phys_strides(phys_dims.size());
  phys_strides[phys_dims.size() - 1] = 1;
  for (int i = phys_dims.size() - 2; i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}
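
// A minimal sketch of the vectorized accessors (hypothetical values,
// mirroring cuDNN's NCHW_VECT_C): with vector_size=4 and vector_dim=1 (the
// channel slot in canonical BatchDepthYX order), a descriptor with n=2, c=32,
// h=4, w=5 reports
//
//   vectorized_dims(DataLayout::kBatchDepthYX, 4, 1) == {2, 8, 4, 5}
//
// i.e. 32 logical channels become 8 vectors of 4 packed channels.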

void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
  tensor_ = other.tensor_;
  value_max_ = other.value_max_;
  value_min_ = other.value_min_;
  quantized_activation_mode_ = other.quantized_activation_mode_;
}

std::string BatchDescriptor::ToString() const {
  std::string spatial;
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
  }
  return absl::StrFormat(
      "{count: %d feature_map_count: %d spatial: %s "
      "value_min: %f value_max: %f layout: %s}",
      count(), feature_map_count(), spatial, value_min_, value_max_,
      DataLayoutString(layout()));
}

std::string BatchDescriptor::ToShortString() const {
  // All the constituent strings are less than 15 characters, so the
  // small string optimization ensures that there will be at most one
  // heap memory allocation.
  std::string depth = absl::StrCat("d", feature_map_count());
  std::string batch = absl::StrCat("b", count());

  std::string spatial = "s";
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
  }

  std::string suffix;
  if (value_min() != value_max()) {
    absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
  }
  if (quantized_activation_mode() == QuantizedActivationMode::k16Bit) {
    suffix += "_16bit";
  }

  switch (layout()) {
    case DataLayout::kYXDepthBatch:
      return absl::StrCat(spatial, depth, batch, suffix);
    case DataLayout::kYXBatchDepth:
      return absl::StrCat(spatial, batch, depth, suffix);
    case DataLayout::kBatchYXDepth:
      return absl::StrCat(batch, spatial, depth, suffix);
    case DataLayout::kBatchDepthYX:
      return absl::StrCat(batch, depth, spatial, suffix);
    case DataLayout::kBatchDepthYX4:
    case DataLayout::kBatchDepthYX32:
      return absl::StrCat(batch, depth, spatial, suffix, "(VECT_C)");
    default:
      LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
      return "";  // Avoid return warning (unreachable)
  }
}

int64_t BatchDescriptor::NodesPerFeatureMap() const {
  int64_t ret = 1;
  for (int i = 0; i < ndims(); i++) {
    ret *= spatial_size()[i];
  }
  return ret;
}

int64_t BatchDescriptor::NodesAcrossFeatureMaps() const {
  return NodesPerFeatureMap() * feature_map_count();
}

int64_t BatchDescriptor::ElementCount() const {
  return count() * feature_map_count() * NodesPerFeatureMap();
}

int64_t BatchDescriptor::FullyConnectedWeightCount(
    const BatchDescriptor& input, const BatchDescriptor& output) {
  return input.NodesAcrossFeatureMaps() * output.NodesAcrossFeatureMaps();
}

int64_t BatchDescriptor::FullyConnectedBiasCount(
    const BatchDescriptor& output) {
  return output.NodesAcrossFeatureMaps();
}

BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
    port::ArraySlice<dnn::BatchDescriptor> inputs) {  // non-absl ok
  if (inputs.empty()) {
    return BatchDescriptor();
  }
  int feature_map_count = 0;
  for (const auto& dimensions : inputs) {
    feature_map_count += dimensions.feature_map_count();
  }
  BatchDescriptor output = inputs[0];
  output.set_feature_map_count(feature_map_count);
  return output;
}
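
// Illustrative example: depth concatenation keeps the first input's batch and
// spatial shape and sums the feature map counts, so two inputs with
// feature_map_count 3 and 5 (and matching batch/spatial dims) produce an
// output descriptor with feature_map_count 8.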

TensorDescriptorProto BatchDescriptor::ToProto(DataType data_type) const {
  CHECK_EQ(0.0, value_max_);
  CHECK_EQ(0.0, value_min_);
  CHECK(quantized_activation_mode_ == QuantizedActivationMode::k8Bit);

  TensorDescriptorProto ret = tensor_;
  ret.set_data_type(data_type);
  return ret;
}

// -- FilterDescriptor

FilterDescriptor::FilterDescriptor(int ndims) {
  tensor_.mutable_dimensions()->Resize(ndims + 2, 0);
  set_layout(FilterLayout::kOutputInputYX);
}

FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {}

FilterDescriptor::~FilterDescriptor() {}

void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
  tensor_ = other.tensor_;
}

std::string FilterDescriptor::ToString() const {
  std::string desc = absl::StrFormat(
      "{output_feature_map_count: %d input_feature_map_count: %d "
      "layout: %s shape: ",
      output_feature_map_count(), input_feature_map_count(),
      FilterLayoutString(layout()));
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&desc, "%d ", input_filter_dims()[i]);
  }
  absl::StrAppend(&desc, "}");

  return desc;
}

std::string FilterDescriptor::ToShortString() const {
  // All the constituent strings are less than 15 characters, so the
  // small string optimization ensures that there will be at most one
  // heap memory allocation.
  std::string od = absl::StrCat("od", output_feature_map_count());
  std::string id = absl::StrCat("id", input_feature_map_count());

  std::string spatial = "s";
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]);
  }

  switch (layout()) {
    case FilterLayout::kOutputInputYX:
      return absl::StrCat(od, id, spatial);
    case FilterLayout::kOutputYXInput:
      return absl::StrCat(od, spatial, id);
    case FilterLayout::kOutputInputYX4:
    case FilterLayout::kOutputInputYX32:
      return absl::StrCat(od, id, spatial, "(VECT_C)");
    case FilterLayout::kInputYXOutput:
      return absl::StrCat(id, spatial, od);
    case FilterLayout::kYXInputOutput:
      return absl::StrCat(spatial, id, od);
    default:
      LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
      return "";  // Avoid return warning (unreachable)
  }
}

int64_t FilterDescriptor::ComputeWeightCount() const {
  int64_t ret = output_feature_map_count() * input_feature_map_count();
  for (int i = 0; i < ndims(); i++) {
    ret *= input_filter_dims()[i];
  }
  return ret;
}
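
// Worked example (illustrative values): a 3x3 filter mapping 3 input feature
// maps to 64 output feature maps holds 64 * 3 * 3 * 3 = 1728 weights.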

std::vector<int64_t> FilterDescriptor::full_dims(
    const FilterLayout& layout) const {
  std::vector<int64_t> oiyx_dims(ndims() + 2);
  oiyx_dims[0] = output_feature_map_count();
  oiyx_dims[1] = input_feature_map_count();
  std::copy(input_filter_dims().begin(), input_filter_dims().end(),
            oiyx_dims.begin() + 2);
  return ReorderDims(oiyx_dims, FilterLayout::kOutputInputYX, layout);
}

std::vector<int64_t> FilterDescriptor::full_strides(
    const FilterLayout& layout) const {
  std::vector<int64_t> phys_dims = full_dims(this->layout());
  std::vector<int64_t> phys_strides(phys_dims.size());
  phys_strides[ndims() + 1] = 1;
  for (int i = ndims(); i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}

std::vector<int64_t> FilterDescriptor::vectorized_dims(
    const FilterLayout& layout, int vector_size, int vector_dim) const {
  std::vector<int64_t> oiyx_dims = full_dims(dnn::FilterLayout::kOutputInputYX);
  if (vector_dim != -1) {
    oiyx_dims[vector_dim] /= vector_size;
  }
  return ReorderDims(oiyx_dims, FilterLayout::kOutputInputYX, layout);
}

std::vector<int64_t> FilterDescriptor::vectorized_strides(
    const FilterLayout& layout, int vector_size, int vector_dim) const {
  std::vector<int64_t> phys_dims =
      vectorized_dims(this->layout(), vector_size, vector_dim);
  std::vector<int64_t> phys_strides(phys_dims.size());
  phys_strides[phys_dims.size() - 1] = 1;
  for (int i = phys_dims.size() - 2; i >= 0; i--) {
    phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
  }
  return ReorderDims(phys_strides, this->layout(), layout);
}

TensorDescriptorProto FilterDescriptor::ToProto(DataType data_type) const {
  TensorDescriptorProto ret = tensor_;
  ret.set_data_type(data_type);
  return ret;
}

// -- ConvolutionDescriptor

ConvolutionDescriptor::ConvolutionDescriptor(int ndims) {
  proto_.mutable_paddings()->Resize(ndims, 0);
  proto_.mutable_strides()->Resize(ndims, 1);
  proto_.mutable_dilations()->Resize(ndims, 1);
  proto_.set_group_count(1);
  proto_.set_convolution_mode(ConvolutionMode::CROSS_CORRELATION);
}

ConvolutionDescriptor::ConvolutionDescriptor()
    : ConvolutionDescriptor(/*ndims=*/2) {}

ConvolutionDescriptor::~ConvolutionDescriptor() {}

std::string ConvolutionDescriptor::ToString() const {
  std::string padding;
  std::string strides;
  std::string dilations;
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&padding, "%d ", this->padding()[i]);
    absl::StrAppendFormat(&strides, "%d ", this->strides()[i]);
    absl::StrAppendFormat(&dilations, "%d ", this->dilations()[i]);
  }

  return absl::StrFormat(
      "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: "
      "%s}",
      padding, PadAlignmentString(pad_alignment()), strides, dilations);
}

std::string ConvolutionDescriptor::ToShortString() const {
  std::string desc;
  for (int i = 0; i < ndims(); i++) {
    if (i > 0) absl::StrAppend(&desc, "_");
    absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]);
  }
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&desc, "_s%d:%d", i, strides()[i]);
  }
  for (int i = 0; i < ndims(); i++) {
    absl::StrAppendFormat(&desc, "_d%d:%d", i, dilations()[i]);
  }
  return desc;
}
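
// Illustrative example: a 2-D convolution with padding 1, stride 2, and
// dilation 1 in both dimensions prints as "p0:1_p1:1_s0:2_s1:2_d0:1_d1:1".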

// -- PoolingDescriptor

PoolingDescriptor::PoolingDescriptor(int ndims)
    : mode_(dnn::PoolingMode::kMaximum),
      ndims_(ndims),
      propagate_nans_(false),
      window_(ndims, 0),
      padding_(ndims, 0),
      strides_(ndims, 1) {}

PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {}

void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
  mode_ = other.mode_;
  ndims_ = other.ndims_;
  window_ = other.window_;
  padding_ = other.padding_;
  strides_ = other.strides_;
  propagate_nans_ = other.propagate_nans_;
}

std::string PoolingDescriptor::ToString() const {
  const char* mode_string =
      mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";

  std::string window, strides, padding;
  for (int i = 0; i < ndims_; i++) {
    absl::StrAppendFormat(&window, "%d ", window_[i]);
    absl::StrAppendFormat(&strides, "%d ", strides_[i]);
    absl::StrAppendFormat(&padding, "%d", padding_[i]);
  }

  const char* propagate_string = propagate_nans_ ? "Yes" : "No";

  return absl::StrFormat(
      "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}",
      mode_string, window, strides, padding, propagate_string);
}

std::string PoolingDescriptor::ToShortString() const {
  std::string window, strides, padding;
  for (int i = 0; i < ndims_; i++) {
    absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]);
    absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]);
    absl::StrAppendFormat(&padding, "_p%d:%d", i, padding_[i]);
  }
  return absl::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
                      window, strides, padding,
                      propagate_nans_ ? "propagate_nans" : "ignore_nans");
}
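
// Illustrative example: 2x2 max pooling with stride 2 and no padding prints
// as "max_w0:2_w1:2_s0:2_s1:2_p0:0_p1:0ignore_nans" (note the NaN suffix is
// appended without a separator).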

// -- NormalizeDescriptor

NormalizeDescriptor::NormalizeDescriptor()
    : bias_(0.0),
      range_(0),
      alpha_(0.0),
      beta_(0.0),
      wrap_around_(false),
      segment_size_(0) {}

void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
  bias_ = other.bias_;
  range_ = other.range_;
  alpha_ = other.alpha_;
  beta_ = other.beta_;
  wrap_around_ = other.wrap_around_;
  segment_size_ = other.segment_size_;
}

std::string NormalizeDescriptor::ToString() const {
  return absl::StrFormat(
      "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
      "segment_size: %d}",
      bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
}

std::string NormalizeDescriptor::ToShortString() const {
  return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
                      "_beta:", beta_, "_wrap:", wrap_around_,
                      "_size:", segment_size_);
}

bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
  if (status.ok()) {
    return true;
  }
  if (report_error) {
    LOG(ERROR) << status.error_message();
  }
  return false;
}

port::Status DnnSupport::DoCtcLoss(
    Stream* stream, dnn::DataType element_type,
    const RnnStateTensorDescriptor& probs_desc,
    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
    absl::Span<const int> labels_lengths_data,
    absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
    const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
    DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
  return port::UnimplementedError("CtcLoss not implemented");
}

}  // namespace dnn
}  // namespace stream_executor