1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Neural Net operation support for StreamExecutor instances.
17 //
18 // This is an abstract interface for a platform to optionally support common
19 // neural net operations; it accommodates implementations such as the cudnn
20 // library operations.
21
22 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DNN_H_
23 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DNN_H_
24
25 #include <functional>
26 #include <limits>
27 #include <memory>
28 #include <string>
29 #include <tuple>
30 #include <type_traits>
31 #include <utility>
32
33 #include "google/protobuf/wrappers.pb.h"
34 #include "absl/types/optional.h"
35 #include "absl/types/span.h"
36 #include "tensorflow/compiler/xla/stream_executor/data_type.h"
37 #include "tensorflow/compiler/xla/stream_executor/device_description.h"
38 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
39 #include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
40 #include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
41 #include "tensorflow/compiler/xla/stream_executor/lib/status.h"
42 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
43 #include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
44 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
45
// Forward declaration of Eigen::half so it can be named here without
// including Eigen's headers.
namespace Eigen { struct half; }  // namespace Eigen
49
50 namespace stream_executor {
51
// Forward declarations so these StreamExecutor classes can be named without
// including their headers.
class ScratchAllocator;
class HostBuffer;
class Stream;
55
56 namespace dnn {
57
// Selects one spatial dimension. Spatial-dimension vectors are addressed from
// their minor (last) end: X picks the last element, Y the one before it, and Z
// the one before that (see GetDim/SetDim).
enum class DimIndex : int { X = 0, Y = 1, Z = 2 };
64
65 // Return a reordered dims.
66 std::vector<int64_t> ReorderDims(const std::vector<int64_t>& input,
67 const DataLayout& from, const DataLayout& to);
68
69 // Helper functions to make methods more readable.
GetDim(absl::Span<const int64_t> data,DimIndex dim)70 inline int64_t GetDim(absl::Span<const int64_t> data, DimIndex dim) {
71 return data.rbegin()[static_cast<int64_t>(dim)];
72 }
73
SetDim(absl::Span<int64_t> data,DimIndex dim,int64_t value)74 inline void SetDim(absl::Span<int64_t> data, DimIndex dim, int64_t value) {
75 data.rbegin()[static_cast<int64_t>(dim)] = value;
76 }
77
SetDim(std::vector<int64_t> * data,DimIndex dim,int64_t value)78 inline void SetDim(std::vector<int64_t>* data, DimIndex dim, int64_t value) {
79 return SetDim(absl::MakeSpan(*data), dim, value);
80 }
81
82 // int64_t is not the same type as tensorflow::protobuf_int64 in open-source.
83 // This wrapper function gives an int64_t array slice view of a repeated int64
84 // protobuf field.
85 //
86 // T should be a protobuf RepeatedField.
87 template <typename T>
AsInt64Slice(const T & repeated_field)88 inline absl::Span<const int64_t> AsInt64Slice(const T& repeated_field) {
89 using data_ty =
90 typename std::remove_reference<decltype(*repeated_field.data())>::type;
91 static_assert(std::is_integral<data_ty>::value &&
92 std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
93 "repeated_field.data() must return a pointer to a signed "
94 "64-bit integer type.");
95 return absl::Span<const int64_t>(
96 reinterpret_cast<const int64_t*>(repeated_field.data()),
97 repeated_field.size());
98 }
99 template <typename T>
AsInt64Slice(T * repeated_field)100 inline absl::Span<int64_t> AsInt64Slice(T* repeated_field) {
101 using data_ty =
102 typename std::remove_reference<decltype(*repeated_field->data())>::type;
103 static_assert(std::is_integral<data_ty>::value &&
104 std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
105 "repeated_field->data() must return a pointer to a signed "
106 "64-bit integer type.");
107 return absl::Span<int64_t>(
108 reinterpret_cast<int64_t*>(repeated_field->mutable_data()),
109 repeated_field->size());
110 }
111
112 // Returns a string representation of the given data layout.
113 std::string DataLayoutString(DataLayout layout);
114
// Quantization applied to the activations described by a BatchDescriptor.
// Each enumerator's value is its bit width divided by 8, i.e. the number of
// bytes per quantized value.
enum class QuantizedActivationMode { k8Bit = 1, k16Bit = 2, k32Bit = 4 };
121
// Specifies the type of an RNN model: ReLU, tanh, LSTM, or GRU.
enum class RnnMode { kRnnRelu = 0, kRnnTanh = 1, kRnnLstm = 2, kRnnGru = 3 };
129
// Specifies the input model, i.e. whether there is a linear transformation
// between the input state and the first layer's hidden state.
enum class RnnInputMode { kRnnLinearSkip = 0, kRnnSkipInput = 1 };
136
// Specifies how many directions an RNN model runs in. In bidirectional mode,
// the input states and the output sequence contain data for both directions.
enum class RnnDirectionMode { kRnnUnidirectional = 0, kRnnBidirectional = 1 };
144
// Layout used by DepthToSpace and SpaceToDepth: the write layout for depth to
// space and the read layout for space to depth. It lists the most-major
// dimension first and the most-minor dimension last.
//
// In DepthToSpace, the D*M^2 values are read in and then, for
// DepthHeightWidth, written to the output patch by varying width first, then
// height, then depth. See the DepthToSpace comment for more information.
enum class DepthToSpaceLayout {
  DepthHeightWidth,  // As a C array: [depth][height][width].
};
153
// Describes an RNN model.
//
// An example use case:
//  * The user first creates a model through createRnnDescriptor.
//  * The user queries the size of the underlying opaque parameter buffer.
//  * The user creates and initializes a parameter buffer of the proper size.
//  * The user runs forward and backward operations using this RNN descriptor.
//  * Once in a while, the user queries the maintainable weight and bias
//    regions of the underlying parameter buffer. These are more likely to be
//    forward compatible and should be used when saving and restoring a model.
//  * The user releases the RNN descriptor when the model is no longer in use.
//
// This base class only provides placeholders: ParamsSizeInBytes() returns -1
// and both region queries return an empty list. Platform subclasses
// presumably override them.
class RnnDescriptor {
 public:
  // One region of the opaque parameter buffer, given by offset and size.
  struct ParamsRegion {
    int64_t offset;
    int64_t size;
  };
  using ParamsRegions = std::vector<ParamsRegion>;

  virtual ~RnnDescriptor() = default;

  virtual int64_t ParamsSizeInBytes() const { return -1; }
  virtual ParamsRegions ParamsWeightRegions() const { return {}; }
  virtual ParamsRegions ParamsBiasRegions() const { return {}; }
};
177
// Describes the sequence tensor of an RNN model.
//
// The user is responsible for releasing this descriptor when it is no longer
// in use. This base destructor does nothing itself; platform subclasses
// presumably release their underlying descriptors in their destructors.
class RnnSequenceTensorDescriptor {
 public:
  virtual ~RnnSequenceTensorDescriptor() = default;
};
186
// Describes either the input state or the hidden state of an RNN model.
//
// The user is responsible for releasing this descriptor when it is no longer
// in use. This base destructor does nothing itself; platform subclasses
// presumably release their underlying descriptors in their destructors.
class RnnStateTensorDescriptor {
 public:
  virtual ~RnnStateTensorDescriptor() = default;
};
195
196 // Returns a string representation of the given quantization mode.
197 std::string QuantizedActivationModeString(QuantizedActivationMode mode);
198
199 // Describes the dimensions that a layer consumes/produces.
200 //
201 // This is a matrix (height, width), its "depth" (feature_map_count),
202 // how many of these matrices are present (count),
203 // and the maximum and minimum values expected in the matrix (value_max,
204 // value_min).
205 // If input is quantized, all values greater
206 // than value_max will be clipped to value_max and all values less than
207 // value_min will be clipped to value_min.
208 // When quantized output is dequantized no value will be greater than
209 // value_max or less than value_min.
210 //
211 // Uses the named argument construction form:
212 //
213 // auto input_batch_dimensions =
214 // BatchDescriptor().set_count(42).set_feature_map_count(7)...
215 //
216 // Details:
217 //
218 // For a convolutional layer, a single inference takes a 3-dimensional matrix
219 // of input and produces a 3-dimensional matrix of output. We call the three
220 // dimensions height, width and feature_map_count, where for an image, the
221 // height and width correspond to the Y and X pixel indices, respectively, and
222 // the feature_map_count corresponds to the RGB dimension of the input data.
223 // Then the count indicates how many 3D matrices are being presented to be
224 // processed at once; this corresponds to the neural network concept of
225 // minibatch size.
226 //
227 // For a fully connected layer, it's better to put the nodes of the layer in
228 // the feature_map_count, and leave the height and weight as degenerate (== 1).
229 // Count indicates how many input vectors (degenerate 3D matrices) are to be
230 // processed.
231 //
232 // If unspecified, value_max and value_min default to 0.0.
233 // If value_max == value_min the Stream will attempt to derive valid values -
234 // for example the output of Relu6 activation will always be in the range
235 // [0.0, 6.0].
236 //
237 // If unspecified, layout defaults to kYXDepthBatch.
238 class BatchDescriptor {
239 public:
240 // Creates a "blank" batch descriptor, which should be initialized via the
241 // named argument helpers.
242 BatchDescriptor();
243 explicit BatchDescriptor(int ndims);
244
245 // Clones values from 'other' for initialization.
246 void CloneFrom(const BatchDescriptor& other);
247
248 std::string ToString() const;
249 std::string ToShortString() const;
250
251 // Pre-condition:
252 // value_max_ == 0
253 // value_min_ == 0
254 // quantized_activation_mode_ == QuantizedActivationMode::k8Bit
255 TensorDescriptorProto ToProto(DataType data_type) const;
256
257 // Accessors.
count()258 int64_t count() const { return tensor_.dimensions(0); }
feature_map_count()259 int64_t feature_map_count() const { return tensor_.dimensions(1); }
height()260 int64_t height() const { return GetDim(spatial_size(), DimIndex::Y); }
width()261 int64_t width() const { return GetDim(spatial_size(), DimIndex::X); }
spatial_dim(DimIndex dim)262 int64_t spatial_dim(DimIndex dim) const {
263 return GetDim(spatial_size(), dim);
264 }
ndims()265 int ndims() const { return spatial_size().size(); }
value_max()266 float value_max() const { return value_max_; }
value_min()267 float value_min() const { return value_min_; }
layout()268 DataLayout layout() const { return tensor_.data_layout(); }
quantized_activation_mode()269 QuantizedActivationMode quantized_activation_mode() const {
270 return quantized_activation_mode_;
271 }
272 // Full dimensions of the underlying data, ordered according to a specific
273 // layout.
274 std::vector<int64_t> full_dims(const DataLayout& layout) const;
275
276 // Full strides of the underlying data, ordered according to a specific
277 // layout.
278 std::vector<int64_t> full_strides(const DataLayout& layout) const;
279
280 // Vectorized dimensions where users can specify the dimension that the number
281 // of dimensions is reported rather than the full number of elements.
282 std::vector<int64_t> vectorized_dims(const DataLayout& layout,
283 int vector_size, int vector_dim) const;
284
285 // Vectorized strides correspond to the vectorized_dims.
286 std::vector<int64_t> vectorized_strides(const DataLayout& layout,
287 int vector_size,
288 int vector_dim) const;
289
290 // Named-argument helpers for avoiding user error during construction.
set_count(int64_t value)291 BatchDescriptor& set_count(int64_t value) {
292 tensor_.set_dimensions(0, value);
293 return *this;
294 }
set_feature_map_count(int64_t value)295 BatchDescriptor& set_feature_map_count(int64_t value) {
296 tensor_.set_dimensions(1, value);
297 return *this;
298 }
set_height(int64_t value)299 BatchDescriptor& set_height(int64_t value) {
300 SetDim(spatial_size(), DimIndex::Y, value);
301 return *this;
302 }
set_width(int64_t value)303 BatchDescriptor& set_width(int64_t value) {
304 SetDim(spatial_size(), DimIndex::X, value);
305 return *this;
306 }
set_spatial_dim(DimIndex dim,int64_t value)307 BatchDescriptor& set_spatial_dim(DimIndex dim, int64_t value) {
308 SetDim(spatial_size(), dim, value);
309 return *this;
310 }
set_value_max(float value)311 BatchDescriptor& set_value_max(float value) {
312 value_max_ = value;
313 return *this;
314 }
set_value_min(float value)315 BatchDescriptor& set_value_min(float value) {
316 value_min_ = value;
317 return *this;
318 }
set_layout(DataLayout layout)319 BatchDescriptor& set_layout(DataLayout layout) {
320 tensor_.set_data_layout(layout);
321 return *this;
322 }
set_quantized_activation_mode(QuantizedActivationMode quantized_activation_mode)323 BatchDescriptor& set_quantized_activation_mode(
324 QuantizedActivationMode quantized_activation_mode) {
325 quantized_activation_mode_ = quantized_activation_mode;
326 return *this;
327 }
328
329 // Return the number of nodes in a single feature map.
330 int64_t NodesPerFeatureMap() const;
331
332 // Return the number of nodes across all feature maps. Note that this is not
333 // affected by the batch count.
334 int64_t NodesAcrossFeatureMaps() const;
335
336 // Returns the number of elements (e.g. RGB pixel values) required to hold a
337 // given batch descriptor, given a no-padding assumption. Note that this is
338 // affected by the batch count.
339 int64_t ElementCount() const;
340
341 // Return the number of weights required to fully connect a layer with
342 // dimensions given by the 'input' descriptor with a layer with dimensions
343 // given by the 'output' descriptor.
344 static int64_t FullyConnectedWeightCount(const BatchDescriptor& input,
345 const BatchDescriptor& output);
346
347 // Return the number of biases required to fully connect to an output layer
348 // with dimensions given the 'output' descriptor.
349 static int64_t FullyConnectedBiasCount(const BatchDescriptor& output);
350
351 // Return a BatchDescriptor for the output of a depth concatenation
352 // with the given input descriptors. The inputs should have the same
353 // dimensions, except possibly for feature_map_count(), though this
354 // function does not verify that.
355 static BatchDescriptor DepthConcatenateOutputDescriptor(
356 port::ArraySlice<dnn::BatchDescriptor> inputs); // non-absl ok
357
358 private:
spatial_size()359 absl::Span<const int64_t> spatial_size() const {
360 return AsInt64Slice(tensor_.dimensions()).subspan(2);
361 }
362
spatial_size()363 absl::Span<int64_t> spatial_size() {
364 return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
365 }
366
367 TensorDescriptorProto tensor_;
368 float value_max_;
369 float value_min_;
370 QuantizedActivationMode quantized_activation_mode_;
371 };
372
373 // Returns a string representation of the given filter layout.
374 std::string FilterLayoutString(FilterLayout layout);
375
376 // Describes a filter for the convolution. This is the "window" from
377 // height-by-width patches of each of the feature maps in the input layer to the
378 // cells within the output feature map.
379 //
380 // Uses the named argument construction form:
381 //
382 // FilterDescriptor filter_dimensions;
383 // filter_dimensions
384 // .set_output_feature_map_count(42)
385 // .set_input_feature_map_count(7)
386 // ...
387 //
388 // Arguments:
389 // - output_feature_map_count: number of feature maps in the output layer.
390 // - input_feature_map_count: number of feature maps in the input layer (from
391 // which the filter patch is taken).
392 // - input_filter_height: "height" number of neurons used in the sliding window
393 // over the input layer.
394 // - input_filter_width: "width" number of neurons used in the sliding window
395 // over the input layer.
396 //
397 // Sometimes names like "filter input height" are referred to by synonymous
398 // terminology, such as "kernel y size".
399 //
400 // If unspecified, layout defaults to kOutputInputYX.
401 class FilterDescriptor {
402 public:
403 // By default construction, all dimensions are set to zero, so they should all
404 // be populated by the user via the named-argument helpers below. (See class
405 // comment for details.)
406 FilterDescriptor();
407 explicit FilterDescriptor(int ndims);
408 ~FilterDescriptor();
409
410 // Named-argument helpers for avoiding user error during construction.
set_output_feature_map_count(int64_t value)411 FilterDescriptor& set_output_feature_map_count(int64_t value) {
412 tensor_.set_dimensions(0, value);
413 return *this;
414 }
set_input_feature_map_count(int64_t value)415 FilterDescriptor& set_input_feature_map_count(int64_t value) {
416 tensor_.set_dimensions(1, value);
417 return *this;
418 }
set_input_filter_height(int64_t value)419 FilterDescriptor& set_input_filter_height(int64_t value) {
420 SetDim(input_filter_dims(), DimIndex::Y, value);
421 return *this;
422 }
set_input_filter_width(int64_t value)423 FilterDescriptor& set_input_filter_width(int64_t value) {
424 SetDim(input_filter_dims(), DimIndex::X, value);
425 return *this;
426 }
set_layout(FilterLayout layout)427 FilterDescriptor& set_layout(FilterLayout layout) {
428 tensor_.set_filter_layout(layout);
429 return *this;
430 }
set_spatial_dim(DimIndex dim,int64_t value)431 FilterDescriptor& set_spatial_dim(DimIndex dim, int64_t value) {
432 SetDim(input_filter_dims(), dim, value);
433 return *this;
434 }
ndims()435 int ndims() const { return input_filter_dims().size(); }
436
437 void CloneFrom(const FilterDescriptor& other);
438
439 std::string ToString() const;
440 std::string ToShortString() const;
441 TensorDescriptorProto ToProto(DataType data_type) const;
442
443 // Returns the number of weights required as parameters for a convolution
444 // using this filter descriptor.
445 int64_t ComputeWeightCount() const;
446
447 // Returns the number of biases required as parameters for a convolution
448 // using this filter descriptor.
bias_count()449 int64_t bias_count() const { return output_feature_map_count(); }
450
output_feature_map_count()451 int64_t output_feature_map_count() const { return tensor_.dimensions(0); }
input_feature_map_count()452 int64_t input_feature_map_count() const { return tensor_.dimensions(1); }
input_filter_height()453 int64_t input_filter_height() const {
454 return GetDim(input_filter_dims(), DimIndex::Y);
455 }
input_filter_width()456 int64_t input_filter_width() const {
457 return GetDim(input_filter_dims(), DimIndex::X);
458 }
input_filter_dim(DimIndex dim)459 int64_t input_filter_dim(DimIndex dim) const {
460 return GetDim(input_filter_dims(), dim);
461 }
462
layout()463 FilterLayout layout() const { return tensor_.filter_layout(); }
464
input_filter_dims()465 absl::Span<const int64_t> input_filter_dims() const {
466 return AsInt64Slice(tensor_.dimensions()).subspan(2);
467 }
468
469 // Full dimensions of the underlying filter,
470 // ordered according to a specific layout.
471 std::vector<int64_t> full_dims(const FilterLayout& layout) const;
472
473 // Full strides of the underlying filter,
474 // ordered according to a specific layout.
475 std::vector<int64_t> full_strides(const FilterLayout& layout) const;
476
477 // Vectorized dimensions where users can specify the dimension that the number
478 // of dimensions is reported rather than the full number of elements.
479 std::vector<int64_t> vectorized_dims(const FilterLayout& layout,
480 int vector_size, int vector_dim) const;
481
482 // Vectorized strides correspond to the vectorized_dims.
483 std::vector<int64_t> vectorized_strides(const FilterLayout& layout,
484 int vector_size,
485 int vector_dim) const;
486
487 private:
input_filter_dims()488 absl::Span<int64_t> input_filter_dims() {
489 return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
490 }
491
492 TensorDescriptorProto tensor_;
493 };
494
// Describes how padding is aligned when the total number of pad elements is
// odd.
enum class PadAlignment : int64_t {
  kDefault = 0,            // Default padding for the device.
  kCudnnPadding = 1,       // cuDNN padding: prefer to pad at the start.
  kTensorFlowPadding = 2,  // TensorFlow padding: prefer to pad at the end.
};
502
503 // Returns a string representation of the given padding alignment.
504 std::string PadAlignmentString(PadAlignment alignment);
505
506 // Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
507 std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
508
509 // Describes a convolution.
510 //
511 // Uses the named argument construction form:
512 //
513 // ConvolutionDescriptor convolution_dimensions;
514 // convolution_dimensions
515 // .set_vertical_filter_stride(2)
516 // .set_horizontal_filter_stride(2)
517 // ...
518 //
519 // Arguments:
520 // - zero_padding_height: padding of the "y dimension" of the input data. Note
521 // that this is different from the height of the filter.
522 // - zero_padding_width: analogous to the height above, but in the "x
523 // dimension".
524 // - vertical_filter_stride: the convolution slides a 2-dimensional window of
525 // filter-height-by-filter-width over the input layer -- the center of that
526 // window is moved in the "y dimension" according to this stride value.
527 // - horizontal_filter_stride: analogous to the vertical stride above, but in
528 // the "x dimension".
529 // - vertical_dilation_rate: there will be (vertical_dilation_rate - 1) skipped
530 // cells between each filter element in the "y dimension".
531 // - horizontal_dilation_rate: there will be (horizontal_dilation_rate - 1)
532 // skipped cells between each filter element in the "x dimension".
533 // - convolution_not_crosscor: By default (convolution_not_crosscor == false),
534 // we perform cross correlation rather than convolution. With the flag set,
535 // we perform convolution. Convolution and cross correlation are related by
536 // rotating the filter by 180 degrees (or equivalently flipping all spatial
537 // dimensions).
538 class ConvolutionDescriptor {
539 public:
540 // By default construction, there is no zero-padding and the filter stride is
541 // 1x1 (centering the filter on every cell in the input layer's
542 // width-by-height area).
543 ConvolutionDescriptor();
544 explicit ConvolutionDescriptor(int ndims);
545 ~ConvolutionDescriptor();
546
547 std::string ToString() const;
548 std::string ToShortString() const;
ToProto()549 ConvolutionDescriptorProto ToProto() const { return proto_; }
550
set_zero_padding_height(int64_t value)551 ConvolutionDescriptor& set_zero_padding_height(int64_t value) {
552 SetDim(padding(), DimIndex::Y, value);
553 return *this;
554 }
set_zero_padding_width(int64_t value)555 ConvolutionDescriptor& set_zero_padding_width(int64_t value) {
556 SetDim(padding(), DimIndex::X, value);
557 return *this;
558 }
set_zero_padding(DimIndex dim,int64_t value)559 ConvolutionDescriptor& set_zero_padding(DimIndex dim, int64_t value) {
560 SetDim(padding(), dim, value);
561 return *this;
562 }
set_vertical_filter_stride(int64_t value)563 ConvolutionDescriptor& set_vertical_filter_stride(int64_t value) {
564 SetDim(strides(), DimIndex::Y, value);
565 return *this;
566 }
set_horizontal_filter_stride(int64_t value)567 ConvolutionDescriptor& set_horizontal_filter_stride(int64_t value) {
568 SetDim(strides(), DimIndex::X, value);
569 return *this;
570 }
set_filter_stride(DimIndex dim,int64_t value)571 ConvolutionDescriptor& set_filter_stride(DimIndex dim, int64_t value) {
572 SetDim(strides(), dim, value);
573 return *this;
574 }
set_vertical_dilation_rate(int64_t value)575 ConvolutionDescriptor& set_vertical_dilation_rate(int64_t value) {
576 SetDim(dilations(), DimIndex::Y, value);
577 return *this;
578 }
set_horizontal_dilation_rate(int64_t value)579 ConvolutionDescriptor& set_horizontal_dilation_rate(int64_t value) {
580 SetDim(dilations(), DimIndex::X, value);
581 return *this;
582 }
set_dilation_rate(DimIndex dim,int64_t value)583 ConvolutionDescriptor& set_dilation_rate(DimIndex dim, int64_t value) {
584 SetDim(dilations(), dim, value);
585 return *this;
586 }
set_group_count(int group_count)587 ConvolutionDescriptor& set_group_count(int group_count) {
588 proto_.set_group_count(group_count);
589 return *this;
590 }
set_convolution_not_crosscorr(bool conv)591 ConvolutionDescriptor& set_convolution_not_crosscorr(bool conv) {
592 proto_.set_convolution_mode(conv ? ConvolutionMode::CONVOLUTION
593 : ConvolutionMode::CROSS_CORRELATION);
594 return *this;
595 }
set_name(const std::string & name)596 ConvolutionDescriptor& set_name(const std::string& name) {
597 proto_.set_name(name);
598 return *this;
599 }
zero_padding_height()600 int64_t zero_padding_height() const { return GetDim(padding(), DimIndex::Y); }
zero_padding_width()601 int64_t zero_padding_width() const { return GetDim(padding(), DimIndex::X); }
vertical_filter_stride()602 int64_t vertical_filter_stride() const {
603 return GetDim(strides(), DimIndex::Y);
604 }
horizontal_filter_stride()605 int64_t horizontal_filter_stride() const {
606 return GetDim(strides(), DimIndex::X);
607 }
vertical_dilation_rate()608 int64_t vertical_dilation_rate() const {
609 return GetDim(dilations(), DimIndex::Y);
610 }
horizontal_dilation_rate()611 int64_t horizontal_dilation_rate() const {
612 return GetDim(dilations(), DimIndex::X);
613 }
614
zero_padding(DimIndex dim)615 int zero_padding(DimIndex dim) const { return GetDim(padding(), dim); }
filter_stride(DimIndex dim)616 int filter_stride(DimIndex dim) const { return GetDim(strides(), dim); }
dilation_rate(DimIndex dim)617 int dilation_rate(DimIndex dim) const { return GetDim(dilations(), dim); }
618 // TODO(timshen): remove this function. No users of this class is setting a
619 // non-default pad alignment.
pad_alignment()620 PadAlignment pad_alignment() const { return PadAlignment::kDefault; }
group_count()621 int group_count() const { return proto_.group_count(); }
ndims()622 int ndims() const { return padding().size(); }
convolution_not_crosscorr()623 bool convolution_not_crosscorr() const {
624 return proto_.convolution_mode() == ConvolutionMode::CONVOLUTION;
625 }
626
strides()627 absl::Span<const int64_t> strides() const {
628 return AsInt64Slice(proto_.strides());
629 }
630
dilations()631 absl::Span<const int64_t> dilations() const {
632 return AsInt64Slice(proto_.dilations());
633 }
634
padding()635 absl::Span<const int64_t> padding() const {
636 return AsInt64Slice(proto_.paddings());
637 }
638
name()639 std::string name() const { return proto_.name(); }
640
641 private:
strides()642 absl::Span<int64_t> strides() {
643 return AsInt64Slice(proto_.mutable_strides());
644 }
645
dilations()646 absl::Span<int64_t> dilations() {
647 return AsInt64Slice(proto_.mutable_dilations());
648 }
649
padding()650 absl::Span<int64_t> padding() {
651 return AsInt64Slice(proto_.mutable_paddings());
652 }
653
654 ConvolutionDescriptorProto proto_;
655
656 // TODO(leary) cudnn provides these fields, but need to characterize what
657 // their effect is -- they may be boolean rather than integral.
658 // int64_t upscale_input_x;
659 // int64_t upscale_input_y;
660 };
661
// How a patch of input values is pooled: by taking the maximum or the
// average. The underlying type is int64_t; the original rationale was to avoid
// padding in PoolingDescriptor.
enum class PoolingMode : int64_t {
  kMaximum = 0,
  kAverage = 1,
};
669
// The spatial dimension along which inputs are concatenated. The underlying
// type is int64_t, matching PoolingMode.
enum class SpaceConcatenateMode : int64_t {
  XDirection = 0,
  YDirection = 1,
};
676
677 // Returns a short name for the pooling mode, e.g. "Avg".
678 std::string ShortPoolingModeString(PoolingMode mode);
679
680 // Describes a pooling operation to be enqueued onto a stream via a platform's
681 // DnnSupport.
682 //
683 // TODO(broune): describe how padding works and what happens if the
684 // window height/width is not divisible by the vertical/horizontal
685 // stride.
686 //
687 // Arguments:
688 // pooling_mode: pooling operator to use on the input patch
689 // window_height: height of input window
690 // window_width: width of input window
691 // vertical_stride: vertical delta for center of the input patch
692 // horizontal_stride: horizontal delta for center of the input patch
693 class PoolingDescriptor {
694 public:
695 PoolingDescriptor();
696 explicit PoolingDescriptor(int ndims);
697
set_pooling_mode(PoolingMode value)698 PoolingDescriptor& set_pooling_mode(PoolingMode value) {
699 mode_ = value;
700 return *this;
701 }
set_window_height(int64_t value)702 PoolingDescriptor& set_window_height(int64_t value) {
703 SetDim(&window_, DimIndex::Y, value);
704 return *this;
705 }
set_window_width(int64_t value)706 PoolingDescriptor& set_window_width(int64_t value) {
707 SetDim(&window_, DimIndex::X, value);
708 return *this;
709 }
set_window(DimIndex dim,int64_t value)710 PoolingDescriptor& set_window(DimIndex dim, int64_t value) {
711 SetDim(&window_, dim, value);
712 return *this;
713 }
set_vertical_padding(int64_t value)714 PoolingDescriptor& set_vertical_padding(int64_t value) {
715 SetDim(&padding_, DimIndex::Y, value);
716 return *this;
717 }
set_horizontal_padding(int64_t value)718 PoolingDescriptor& set_horizontal_padding(int64_t value) {
719 SetDim(&padding_, DimIndex::X, value);
720 return *this;
721 }
set_padding(DimIndex dim,int64_t value)722 PoolingDescriptor& set_padding(DimIndex dim, int64_t value) {
723 SetDim(&padding_, dim, value);
724 return *this;
725 }
set_vertical_stride(int64_t value)726 PoolingDescriptor& set_vertical_stride(int64_t value) {
727 SetDim(&strides_, DimIndex::Y, value);
728 return *this;
729 }
set_horizontal_stride(int64_t value)730 PoolingDescriptor& set_horizontal_stride(int64_t value) {
731 SetDim(&strides_, DimIndex::X, value);
732 return *this;
733 }
set_stride(DimIndex dim,int64_t value)734 PoolingDescriptor& set_stride(DimIndex dim, int64_t value) {
735 SetDim(&strides_, dim, value);
736 return *this;
737 }
set_propagate_nans(bool value)738 PoolingDescriptor& set_propagate_nans(bool value) {
739 propagate_nans_ = value;
740 return *this;
741 }
set_name(const std::string & name)742 PoolingDescriptor& set_name(const std::string& name) {
743 name_ = name;
744 return *this;
745 }
746
ndims()747 int ndims() const { return ndims_; }
748 void CloneFrom(const PoolingDescriptor& other);
749
750 std::string ToString() const;
751 std::string ToShortString() const;
752
mode()753 PoolingMode mode() const { return mode_; }
window_height()754 int64_t window_height() const { return GetDim(window_, DimIndex::Y); }
window_width()755 int64_t window_width() const { return GetDim(window_, DimIndex::X); }
window(DimIndex dim)756 int64_t window(DimIndex dim) const { return GetDim(window_, dim); }
vertical_padding()757 int64_t vertical_padding() const { return GetDim(padding_, DimIndex::Y); }
horizontal_padding()758 int64_t horizontal_padding() const { return GetDim(padding_, DimIndex::X); }
padding(DimIndex dim)759 int64_t padding(DimIndex dim) const { return GetDim(padding_, dim); }
vertical_stride()760 int64_t vertical_stride() const { return GetDim(strides_, DimIndex::Y); }
horizontal_stride()761 int64_t horizontal_stride() const { return GetDim(strides_, DimIndex::X); }
stride(DimIndex dim)762 int64_t stride(DimIndex dim) const { return GetDim(strides_, dim); }
window()763 absl::Span<const int64_t> window() const { return window_; }
padding()764 absl::Span<const int64_t> padding() const { return padding_; }
strides()765 absl::Span<const int64_t> strides() const { return strides_; }
propagate_nans()766 bool propagate_nans() const { return propagate_nans_; }
name()767 std::string name() const { return name_; }
768
769 private:
770 PoolingMode mode_;
771 int ndims_;
772 bool propagate_nans_;
773 std::string name_; // Name as in Tensorflow NodeDef, for debugging purposes.
774
775 // Stored as: ..., y, x.
776 std::vector<int64_t> window_;
777 std::vector<int64_t> padding_;
778 std::vector<int64_t> strides_;
779 };
780
781 // Collects parameters for DNN algorithms
782 class AlgorithmDesc {
783 public:
784 typedef int64_t Index;
AlgorithmDesc()785 AlgorithmDesc() : AlgorithmDesc(0, false, std::nullopt) {}
AlgorithmDesc(AlgorithmProto proto)786 explicit AlgorithmDesc(AlgorithmProto proto) : proto_(std::move(proto)) {}
AlgorithmDesc(Index algo_id,bool use_tensor_ops)787 AlgorithmDesc(Index algo_id, bool use_tensor_ops)
788 : AlgorithmDesc(algo_id, use_tensor_ops, std::nullopt) {}
AlgorithmDesc(Index algo_id,bool use_tensor_ops,std::optional<uint64_t> workspace_size)789 AlgorithmDesc(Index algo_id, bool use_tensor_ops,
790 std::optional<uint64_t> workspace_size) {
791 proto_.set_is_cudnn_frontend(false);
792 proto_.set_algo_id(algo_id);
793 proto_.set_math_type(use_tensor_ops ? AlgorithmProto::TENSOR_OP_MATH
794 : AlgorithmProto::DEFAULT_MATH);
795 if (workspace_size) {
796 proto_.mutable_workspace_size()->set_value(*workspace_size);
797 }
798 }
799 AlgorithmDesc(int64_t engine_id,
800 const std::vector<std::pair<int64_t, int64_t>>& tuning_knobs,
801 std::optional<uint64_t> workspace_size);
is_cudnn_frontend()802 bool is_cudnn_frontend() const { return proto_.is_cudnn_frontend(); }
803
tensor_ops_enabled()804 bool tensor_ops_enabled() const {
805 return proto_.math_type() == AlgorithmProto::TENSOR_OP_MATH;
806 }
workspace_size()807 std::optional<uint64_t> workspace_size() const {
808 if (proto_.has_workspace_size()) {
809 return proto_.workspace_size().value();
810 }
811 return std::nullopt;
812 }
algo_id()813 Index algo_id() const { return proto_.algo_id(); }
814
815 std::vector<std::pair<int64_t, int64_t>> TuningKnobs() const;
816
817 bool operator==(const AlgorithmDesc& other) const;
818
819 uint64_t hash() const;
820
ToProto()821 AlgorithmProto ToProto() const { return proto_; }
822
823 std::string ToString() const;
824
825 private:
826 AlgorithmProto proto_;
827 };
828
829 // Describes the result from a perf experiment.
830 //
831 // Arguments:
832 // algorithm: returns the exact algorithm that was used.
833 // elapsed_time_in_ms: returns the measured elapsed time in milliseconds.
834 class ProfileResult {
835 public:
is_valid()836 bool is_valid() const {
837 return algorithm_.has_value() &&
838 elapsed_time_in_ms() != std::numeric_limits<float>::max();
839 }
840
algorithm()841 AlgorithmDesc algorithm() const { return *algorithm_; }
set_algorithm(AlgorithmDesc val)842 void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
843
elapsed_time_in_ms()844 float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
set_elapsed_time_in_ms(float val)845 void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
846
scratch_size()847 size_t scratch_size() const { return scratch_size_; }
set_scratch_size(size_t val)848 void set_scratch_size(size_t val) { scratch_size_ = val; }
849
850 private:
851 std::optional<AlgorithmDesc> algorithm_;
852 float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
853 // The scratch size algorithm_ requires. Currently it's only populated by
854 // convolutions.
855 size_t scratch_size_ = 0;
856 };
857
858 // Backend-specific data shared between repeated launches of the same
859 // convolution.
860 template <typename Sig>
861 class OpRunner;
862
863 // An abstract class owning cached state for a particular op/configuration.
864 //
865 // The primary motivation for this is cuDNN backend ExecutionPlans, which are
866 // costly to recreate.
867 //
868 // All OpRunners must be outlived by their parent Stream.
869 template <typename... Args>
870 class OpRunner<void(Args...)> {
871 public:
~OpRunner()872 virtual ~OpRunner() {}
873
874 // Get a description of the runner, for uniqueness of autotune entries.
875 //
876 // Since this is used to determine whether runners are equivalent for the
877 // purpose of scoring autotune entries, it shall be unique among runners of
878 // the same op and parameters.
879 virtual std::string ToString() const = 0;
880
881 // Get the number of bytes of scratch space needed for `operator()`.
882 //
883 // If determining the workspace size can fail, runners should precompute and
884 // cache it at construction time.
885 virtual size_t GetWorkspaceSize() const = 0;
886
887 // Convert to an AlgorithmDesc for AoT compilation or autotuning.
888 virtual port::StatusOr<AlgorithmDesc> ToAlgorithmDesc() const = 0;
889
890 // Launch the operation, with the signature determined by `Sig`.
891 virtual port::Status operator()(Stream*, ProfileResult*,
892 DeviceMemoryBase scratch_memory,
893 Args... args) const = 0;
894 };
895
896 using ConvSignature = void(DeviceMemoryBase /* input_data */,
897 DeviceMemoryBase /* filter_data */,
898 DeviceMemoryBase /* output_data */);
899 using ConvRunner = OpRunner<ConvSignature>;
900
901 using FusedConvSignature = void(DeviceMemoryBase /* input_data */,
902 DeviceMemoryBase /* filter_data */,
903 DeviceMemoryBase /* side_input_data */,
904 DeviceMemoryBase /* bias_data */,
905 DeviceMemoryBase /* output_data */);
906 using FusedConvRunner = OpRunner<FusedConvSignature>;
907
908 // Describes the configuration for the algorithms that will used.
909 //
910 // Arguments:
911 // algorithm: the primary algorithm that should be used.
912 // algorithm_no_scratch: a secondary algorithm that should be used, if the
913 // the allocation for the scratch memory fails.
914 // scrach_size: specify the size of scratch memory in bytes needed for the
915 // algorithm used.
916 //
917 // On CUDA platform with CUDNN library, algorithm and algorithm_no_scratch
918 // would be used. On ROCm platform with MIOpen library, algorithm and
919 // scratch_size would be used. The major difference between the two platforms
920 // are whether it's possible to get an algorithm without scratch memory. On
921 // CUDA + CUDNN it's possible, and algorithm_no_scratch can be used to track
922 // such information, whereas on ROCm + MIOpen there is no guarantee to getting
923 // one without scratch memory, and scratch_size field is used to track it.
924 class AlgorithmConfig {
925 public:
AlgorithmConfig()926 AlgorithmConfig() {}
AlgorithmConfig(AlgorithmDesc algorithm)927 explicit AlgorithmConfig(AlgorithmDesc algorithm) : algorithm_(algorithm) {}
AlgorithmConfig(AlgorithmDesc algorithm,size_t scratch_size)928 AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size)
929 : algorithm_(algorithm), scratch_size_(scratch_size) {}
AlgorithmConfig(AlgorithmDesc algorithm,AlgorithmDesc algorithm_no_scratch)930 AlgorithmConfig(AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch)
931 : algorithm_(algorithm), algorithm_no_scratch_(algorithm_no_scratch) {}
AlgorithmConfig(AlgorithmDesc algorithm,size_t scratch_size,AlgorithmDesc algorithm_no_scratch)932 AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size,
933 AlgorithmDesc algorithm_no_scratch)
934 : algorithm_(algorithm),
935 algorithm_no_scratch_(algorithm_no_scratch),
936 scratch_size_(scratch_size) {}
937
938 // TODO(ruochengw): After cl/380702564, add support for algorithm configs with
939 // cuDNN Frontend APIs.
AlgorithmConfig(const AlgorithmConfigProto & algorithm_config_proto)940 explicit AlgorithmConfig(const AlgorithmConfigProto& algorithm_config_proto) {
941 const AlgorithmProto& algorithm_proto = algorithm_config_proto.algorithm();
942 algorithm_ = AlgorithmDesc(algorithm_proto);
943 if (algorithm_config_proto.optional_scratch_size_case() !=
944 /*ONEOF_NAME_NOT_SET=*/0) {
945 scratch_size_ = algorithm_config_proto.scratch_size();
946 }
947 if (algorithm_config_proto.optional_algorithm_no_scratch_case() !=
948 /*ONEOF_NAME_NOT_SET=*/0) {
949 const AlgorithmProto& algorithm_no_scratch_proto =
950 algorithm_config_proto.algorithm_no_scratch();
951 algorithm_no_scratch_ = AlgorithmDesc(algorithm_no_scratch_proto);
952 }
953 }
954
algorithm()955 std::optional<AlgorithmDesc> algorithm() const { return algorithm_; }
set_algorithm(AlgorithmDesc val)956 void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
algorithm_no_scratch()957 std::optional<AlgorithmDesc> algorithm_no_scratch() const {
958 return algorithm_no_scratch_;
959 }
set_algorithm_no_scratch(AlgorithmDesc val)960 void set_algorithm_no_scratch(AlgorithmDesc val) {
961 algorithm_no_scratch_ = val;
962 }
scratch_size()963 std::optional<size_t> scratch_size() const { return scratch_size_; }
set_scratch_size(size_t val)964 void set_scratch_size(size_t val) { scratch_size_ = val; }
965 bool operator==(const AlgorithmConfig& other) const {
966 return this->algorithm_ == other.algorithm_ &&
967 this->algorithm_no_scratch_ == other.algorithm_no_scratch_ &&
968 this->scratch_size_ == other.scratch_size_;
969 }
970 bool operator!=(const AlgorithmConfig& other) const {
971 return !(*this == other);
972 }
973 std::string ToString() const;
974
975 // TODO(ruochengw): After cl/380702564, add support for algorithm configs with
976 // cuDNN Frontend APIs.
ToProto()977 AlgorithmConfigProto ToProto() const {
978 AlgorithmConfigProto algorithm_config_proto;
979 if (algorithm_.has_value()) {
980 *algorithm_config_proto.mutable_algorithm() =
981 algorithm_.value().ToProto();
982 }
983 if (algorithm_no_scratch_.has_value()) {
984 *algorithm_config_proto.mutable_algorithm_no_scratch() =
985 algorithm_no_scratch_.value().ToProto();
986 }
987 if (scratch_size_.has_value()) {
988 algorithm_config_proto.set_scratch_size(scratch_size_.value());
989 }
990 return algorithm_config_proto;
991 }
992
993 private:
994 std::optional<AlgorithmDesc> algorithm_;
995 std::optional<AlgorithmDesc> algorithm_no_scratch_;
996 std::optional<size_t> scratch_size_;
997 };
998
// Describes a local response normalization (LRN). LRN is used e.g. in
// dist_belief.
//
// Let V be the vector of feature maps at some (batch, y, x) coordinate. LRN
// applies independently to each vector V in the input, across all coordinates
// (batch, y, x), by mapping each V to another vector U of the same size using
// the formula
//
//   U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
//
// where the sum is taken over j in the closed range [i - range, i + range].
//
// When calculating U_i the j in the sum can extend beyond the bounds of V. If
// wrap_around is true, then V_j = V_{j mod F} where F is the size of V, which
// is the number of feature maps. If wrap_around is false, then V_j = 0 for j
// outside [0, F-1].
//
// If segment_size <= F, where F is the number of feature_maps, then
// segment_size has no effect. Otherwise, each consecutive segment of
// segment_size entries in V is normalized separately.
// NOTE(review): the two segment_size cases above read as swapped. A
// segment_size >= F yields a single segment covering all of V, whereas a
// smaller segment_size is what would split V. Confirm against the backend
// implementations.
//
// Not all StreamExecutors allow wrap_around == true or segment_size != 64.
// Some do not implement normalization at all.
class NormalizeDescriptor {
 public:
  // Defined out of line.
  NormalizeDescriptor();

  NormalizeDescriptor& set_bias(float bias) {
    bias_ = bias;
    return *this;
  }

  NormalizeDescriptor& set_range(int32_t range) {
    range_ = range;
    return *this;
  }

  NormalizeDescriptor& set_alpha(float alpha) {
    alpha_ = alpha;
    return *this;
  }

  NormalizeDescriptor& set_beta(float beta) {
    beta_ = beta;
    return *this;
  }

  NormalizeDescriptor& set_wrap_around(bool wrap_around) {
    wrap_around_ = wrap_around;
    return *this;
  }

  NormalizeDescriptor& set_segment_size(int32_t segment_size) {
    segment_size_ = segment_size;
    return *this;
  }

  void CloneFrom(const NormalizeDescriptor& other);

  std::string ToString() const;
  std::string ToShortString() const;

  // Getters and storage use int32_t, matching the setters above, instead of
  // the TF-specific `int32` alias.
  float bias() const { return bias_; }
  int32_t range() const { return range_; }
  float alpha() const { return alpha_; }
  float beta() const { return beta_; }
  bool wrap_around() const { return wrap_around_; }
  int32_t segment_size() const { return segment_size_; }

 private:
  float bias_;
  int32_t range_;
  float alpha_;
  float beta_;
  bool wrap_around_;
  int32_t segment_size_;
};
1076
1077 // Returns a string representation of the given activation mode.
1078 std::string ActivationModeString(ActivationMode mode);
1079
// Selects the operation that DoElementwiseOperation applies to its inputs.
enum class ElementwiseOperation {
  kAdd,
  kMultiply,
};

// Renders `op` as a string (defined out of line).
std::string ElementwiseOperationString(ElementwiseOperation op);
1085
// Holds the major/minor/patch version of the backing library. It is a
// dedicated class, rather than a tuple of ints, to work around the "too
// perfect forwarding" issue in gcc6+ compilers; see PR#16309 and issue #18402
// for links discussing the issue.
class VersionInfo {
 public:
  // Every component defaults to 0.
  VersionInfo(int major = 0, int minor = 0, int patch = 0)
      : major_{major}, minor_{minor}, patch_{patch} {}

  int major_version() const { return major_; }
  int minor_version() const { return minor_; }
  int patch() const { return patch_; }

 private:
  int major_;
  int minor_;
  int patch_;
};
1102
1103 // Suite of operations typically used for implementing Deep/Convolutional Neural
1104 // Nets. Note: A false return value of an operation indicates the
1105 // implementation is not available.
1106 //
1107 // TODO(b/118763918): this class (or rather dispatch table) has several
1108 // problems:
// * Some overloads are missing. Ideally we would use virtual function
//   templates whose template arguments range over a closed set of types, but
//   the language does not provide that.
1112 // * The API is a union of cuDNN and another private backend. Only 10% of the
1113 // functions are actually implemented by both backends, the rest are
1114 // actually backend-specific. The massive interface creates extra mental
1115 // burden.
1116 // * Poor error handling: the API should return Status objects.
1117 //
1118 // PrepareForConvolution is an example for how new APIs should be written.
1119 class DnnSupport {
1120 public:
DnnSupport()1121 DnnSupport() {}
~DnnSupport()1122 virtual ~DnnSupport() {}
1123
1124 virtual port::Status Init() = 0;
1125
1126 // Gets the version of the backing library, as a VersionInfo object.
GetVersion()1127 virtual port::StatusOr<VersionInfo> GetVersion() {
1128 return port::UnimplementedError(
1129 "DnnSupport::GetVersion not implemented on this platform.");
1130 }
1131
1132 // Performs a single-precision forward batch normalization operation onto
1133 // the stream.
1134 //
1135 // Arguments:
1136 // stream: borrowed pointer to the stream that the batch normalization
1137 // operation should be enqueued onto.
1138 // x: input data.
1139 // scale: scaling parameters.
1140 // offset: offset parameters.
1141 // estimated_mean: population mean estimated during training.
1142 // Used for inference only; empty for training.
1143 // estimated_variance: population variance estimated during training,
1144 // used for inference only; empty for training.
1145 // side_input: optional input that is element-wise added to the output of
1146 // batch normalization.
1147 // x_desc: dimensions of the input data, which is the same as the dimensions
1148 // of the output and side input.
1149 // scale_offset_desc: dimensions of scale and offset.
1150 // epsilon: a small floating point number added to the variance of x.
1151 // activation_mode: activation applied to the result of batch normalization
1152 // (or after adding optional side input)
1153 // y: output data.
1154 // batch_mean: batch mean, to be used to compute the running mean.
1155 // batch_variance: batch variance, to be used to compute
1156 // the running variance.
1157 // reserve_space_1: saved mean, to be reused in the backward gradient
1158 // computation.
1159 // reserve_space_2: saved inv_var (1/sqrt(epsilon + variance), to be reused
1160 // in the backward gradient computation.
1161 // is_training: Set to true for training, false for inference.
DoBatchNormalizationForward(Stream * stream,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * reserve_space_1,DeviceMemory<float> * reserve_space_2,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)1162 virtual bool DoBatchNormalizationForward(
1163 Stream* stream, const DeviceMemory<float>& x,
1164 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1165 const DeviceMemory<float>& estimated_mean,
1166 const DeviceMemory<float>& estimated_variance,
1167 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
1168 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1169 const double exponential_average_factor,
1170 dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
1171 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1172 DeviceMemory<float>* reserve_space_1,
1173 DeviceMemory<float>* reserve_space_2, bool is_training,
1174 ScratchAllocator* reserve_space_allocator,
1175 ScratchAllocator* workspace_allocator) {
1176 return false;
1177 }
1178
1179 // Performs a half-precision forwards batch normalization operation onto the
1180 // stream. See DoBatchNormalizationForward above for argument details.
DoBatchNormalizationForward(Stream * stream,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<Eigen::half> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * reserve_space_1,DeviceMemory<float> * reserve_space_2,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)1181 virtual bool DoBatchNormalizationForward(
1182 Stream* stream, const DeviceMemory<Eigen::half>& x,
1183 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1184 const DeviceMemory<float>& estimated_mean,
1185 const DeviceMemory<float>& estimated_variance,
1186 const DeviceMemory<Eigen::half>& side_input,
1187 const dnn::BatchDescriptor& x_desc,
1188 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1189 const double exponential_average_factor,
1190 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
1191 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1192 DeviceMemory<float>* reserve_space_1,
1193 DeviceMemory<float>* reserve_space_2, bool is_training,
1194 ScratchAllocator* reserve_space_allocator,
1195 ScratchAllocator* workspace_allocator) {
1196 return false;
1197 }
1198
1199 // Performs a single-precision backward batch normalization gradient
1200 // computation operation onto the stream.
1201 //
1202 // Arguments:
1203 // stream: borrowed pointer to the stream that the batch normalization
1204 // gradient computation operation should be enqueued onto.
1205 // y_backprop: gradient with regard to output y.
1206 // x: input data.
1207 // scale: scaling parameters.
1208 // inv_var: 1/sqrt(epsilon + variance) of x.
1209 // x_desc: dimensions of the input data, which is the same as the dimensions
1210 // of the output.
1211 // scale_offset_desc: dimensions of scale and offset.
1212 // epsilon: a small floating point number added to the variance of x.
1213 // x_backprop: gradient with respect to input x.
1214 // scale_backprop: gradient with respect to scale.
1215 // offset_backprop: gradient with respect to offset.
DoBatchNormalizationBackward(Stream * stream,const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const DeviceMemory<float> & y,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<float> * side_input_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)1216 virtual bool DoBatchNormalizationBackward(
1217 Stream* stream, const DeviceMemory<float>& y_backprop,
1218 const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
1219 const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
1220 const DeviceMemory<float>& inv_var, const DeviceMemory<float>& y,
1221 const dnn::BatchDescriptor& x_desc,
1222 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1223 dnn::ActivationMode activation_mode, DeviceMemory<float>* x_backprop,
1224 DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
1225 DeviceMemory<float>* side_input_backprop,
1226 DeviceMemory<uint8>* reserve_space_data,
1227 ScratchAllocator* workspace_allocator) {
1228 return false;
1229 }
1230
1231 // Performs a half-precision backward batch normalization gradient computation
1232 // operation onto the stream. See DoBatchNormalizationBackward above for
1233 // argument details.
DoBatchNormalizationBackward(Stream * stream,const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const DeviceMemory<Eigen::half> & y,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<Eigen::half> * side_input_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)1234 virtual bool DoBatchNormalizationBackward(
1235 Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
1236 const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
1237 const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
1238 const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::half>& y,
1239 const dnn::BatchDescriptor& x_desc,
1240 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1241 dnn::ActivationMode activation_mode,
1242 DeviceMemory<Eigen::half>* x_backprop,
1243 DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
1244 DeviceMemory<Eigen::half>* side_input_backprop,
1245 DeviceMemory<uint8>* reserve_space_data,
1246 ScratchAllocator* workspace_allocator) {
1247 return false;
1248 }
1249
1250 // Enqueues a fused convolution operation onto the stream.
1251 // We provide several variants with different types for inputs, biases and
1252 // scaling parameters.
1253 //
1254 // Arguments (all borrowed):
1255 // stream: borrowed pointer to the stream that the 'convolve' operation
1256 // should be enqueued onto.
1257 // conv_input_descriptor: dimensions of the convolution input layer.
1258 // conv_input_data: un-owned device memory region which contains the
1259 // convolution input.
1260 // conv_input_scale: a floating point scale to multiply with each element
1261 // of conv_input_data.
1262 // filter_descriptor: dimensions of the convolution filter.
1263 // filter_data: un-owned device memory region which contains the
1264 // convolution filter weights.
1265 // convolution_descriptor: stride of the convolution filter.
1266 // biases: un-owned device memory region containing biases to add to the
1267 // input.
1268 // activation_mode: Type of activation to perform.
1269 // side_input_data: un-owned device memory region which contains optional
1270 // side input data. If 'side_input_scale' is non-zero, then this must
1271 // point to data in the tensor shape specified by output_shape.
1272 // It will be scaled by 'side_input_scale' and added to the convolution
1273 // result and bias prior to applying the activation function.
1274 // side_input_scale: a floating point scale to multiply with each element
1275 // of side_input_data.
1276 // output_descriptor: dimensions of the output layer.
1277 // output_data: un-owned device memory region in which to place the
1278 // convolution result.
1279 // scratch_allocator: un-owned, may-be-null object that may allocate scratch
1280 // space in order to speed up the convolution operation.
1281 // algorithm_config: specifies which algorithm should be used for the
1282 // operation.
1283 // output_profile_result: the output profile result for this call. The
1284 // profiling is only enabled when this is not nullptr.
1285 //
1286 // conv_input_descriptor, filter_descriptor, convolution_descriptor and
1287 // output_descriptor together specify exactly how the convolution is aligned
1288 // with the input data:
1289 //
1290 // * (input dimensions - filter size + 1) / filter stride == output dimensions
1291 // corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1292 // * input dimensions / filter stride == output dimensions
1293 // corresponds to dist_belief padding = SAME, i.e. input and output are the
1294 // same size - this requires padding the input.
1295 // * (input dimensions + filter size - 1) / filter stride == output dimensions
1296 // corresponds to dist_belief padding = FULL, i.e. the output is sized so
1297 // that if the inverse of the filter is applied to the output in VALID mode
1298 // the result is the same size as the input - this requires even more
1299 // padding of the input.
DoFusedConvolve(Stream * stream,DataType input_type,DataType side_input_type,DataType bias_type,DataType output_type,const dnn::BatchDescriptor & conv_input_descriptor,DeviceMemoryBase conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,DeviceMemoryBase side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,DeviceMemoryBase biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1300 virtual port::Status DoFusedConvolve(
1301 Stream* stream, DataType input_type, DataType side_input_type,
1302 DataType bias_type, DataType output_type,
1303 const dnn::BatchDescriptor& conv_input_descriptor,
1304 DeviceMemoryBase conv_input_data, double conv_input_scale,
1305 const dnn::FilterDescriptor& filter_descriptor,
1306 DeviceMemoryBase filter_data,
1307 const dnn::ConvolutionDescriptor& convolution_descriptor,
1308 DeviceMemoryBase side_input_data, double side_input_scale,
1309 const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
1310 dnn::ActivationMode activation_mode,
1311 const dnn::BatchDescriptor& output_descriptor,
1312 DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
1313 const dnn::AlgorithmConfig& algorithm_config,
1314 dnn::ProfileResult* output_profile_result) {
1315 return port::UnimplementedError(
1316 "DnnSupport::DoFusedConvolve not implemented on this platform.");
1317 }
1318
1319 template <typename ElementType, typename OutputType>
PrepareForConvolution(ConvolutionKind kind,Stream * stream,const BatchDescriptor & batch_descriptor,DeviceMemory<ElementType> input_data,const FilterDescriptor & filter_descriptor,DeviceMemory<ElementType> filter_data,const BatchDescriptor & output_descriptor,DeviceMemory<OutputType> output_data,const ConvolutionDescriptor & convolution_descriptor,const AlgorithmConfig & algorithm_config,ScratchAllocator * scratch_allocator,AlgorithmDesc * algorithm_desc,DeviceMemory<uint8> * scratch_memory)1320 port::Status PrepareForConvolution(
1321 ConvolutionKind kind, Stream* stream,
1322 const BatchDescriptor& batch_descriptor,
1323 DeviceMemory<ElementType> input_data,
1324 const FilterDescriptor& filter_descriptor,
1325 DeviceMemory<ElementType> filter_data,
1326 const BatchDescriptor& output_descriptor,
1327 DeviceMemory<OutputType> output_data,
1328 const ConvolutionDescriptor& convolution_descriptor,
1329 const AlgorithmConfig& algorithm_config,
1330 ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
1331 DeviceMemory<uint8>* scratch_memory) {
1332 return DoPrepareForConvolution(
1333 kind, ToDataType<ElementType>::value, stream, batch_descriptor,
1334 input_data, filter_descriptor, filter_data, output_descriptor,
1335 output_data, convolution_descriptor, algorithm_config,
1336 scratch_allocator, algorithm_desc, scratch_memory);
1337 }
1338
1339 // Enqueues a single-precision convolution operation onto the stream.
1340 //
1341 // Arguments (all borrowed):
1342 // stream: borrowed pointer to the stream that the 'convolve' operation
1343 // should be enqueued onto.
1344 // input_descriptor: dimensions of the input layer.
1345 // input_data: un-owned device memory region which contains the
1346 // convolution input.
1347 // filter_descriptor: dimensions of the convolution filter.
1348 // convolution_descriptor: stride of the convolution filter.
1349 // output_descriptor: dimensions of the output layer.
1350 // output_data: un-owned device memory region in which to place the
1351 // convolution result.
1352 // algorithm_desc: specifies which algorithm should be used for the
1353 // operation.
1354 // scratch: un-owned device memory for scratch space in order to speed up
1355 // the convolution operation.
1356 // output_profile_result: the output profile result for this call. The
1357 // profiling is only enabled when this is not nullptr.
1358 //
1359 // input_descriptor, filter_descriptor, convolution_descriptor and
1360 // output_descriptor together specify exactly how the convolution is aligned
1361 // with the input data:
1362 //
1363 // * (input dimensions - filter size + 1) / filter stride == output dimensions
1364 // corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1365 // * input dimensions / filter stride == output dimensions
1366 // corresponds to dist_belief padding = SAME, i.e. input and output are the
1367 // same size - this requires padding the input.
1368 // * (input dimensions + filter size - 1) / filter stride == output dimensions
1369 // corresponds to dist_belief padding = FULL, i.e. the output is sized so
1370 // that if the inverse of the filter is applied to the output in VALID mode
1371 // the result is the same size as the input - this requires even more
1372 // padding of the input.
1373 virtual port::Status DoConvolve(
1374 ConvolutionKind kind, DataType element_type, DataType output_type,
1375 Stream* stream, const BatchDescriptor& input_descriptor,
1376 DeviceMemoryBase input_data, const FilterDescriptor& filter_descriptor,
1377 DeviceMemoryBase filter_data, const BatchDescriptor& output_descriptor,
1378 DeviceMemoryBase output_data,
1379 const ConvolutionDescriptor& convolution_descriptor,
1380 AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
1381 ProfileResult* output_profile_result) = 0;
1382
1383 // Return a list of algorithms supported by the forward convolution pass.
1384 // cc_major and cc_minor are the compute capabilities of the device.
1385 virtual bool GetConvolveAlgorithms(
1386 CudaComputeCapability cuda_compute_capability,
1387 std::vector<AlgorithmDesc>* out_algorithms);
1388
1389 virtual port::Status GetConvolveRunners(
1390 bool use_cudnn_frontend, dnn::ConvolutionKind kind,
1391 dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
1392 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
1393 const dnn::FilterDescriptor& filter_descriptor,
1394 DeviceMemoryBase filter_data,
1395 const dnn::BatchDescriptor& output_descriptor,
1396 DeviceMemoryBase output_data,
1397 const dnn::ConvolutionDescriptor& convolution_descriptor,
1398 bool use_fallback, ScratchAllocator* scratch_allocator,
1399 std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans);
1400
1401 virtual port::StatusOr<std::unique_ptr<const dnn::ConvRunner>>
1402 ConvolveRunnerFromDesc(
1403 Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
1404 dnn::ConvolutionKind kind, dnn::DataType element_type,
1405 dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
1406 const dnn::FilterDescriptor& filter_descriptor,
1407 const dnn::BatchDescriptor& output_descriptor,
1408 const dnn::ConvolutionDescriptor& convolution_descriptor);
1409
1410 virtual port::Status GetFusedConvolveRunners(
1411 bool use_cudnn_frontend, dnn::ConvolutionKind kind,
1412 dnn::DataType element_type, dnn::DataType bias_type,
1413 dnn::DataType output_type, double conv_input_scale,
1414 double side_input_scale, double leakyrelu_alpha, Stream* stream,
1415 const dnn::BatchDescriptor& input_descriptor,
1416 const dnn::FilterDescriptor& filter_descriptor,
1417 const dnn::BatchDescriptor& bias_descriptor,
1418 const dnn::BatchDescriptor& output_descriptor,
1419 const dnn::ConvolutionDescriptor& convolution_descriptor,
1420 bool use_fallback, dnn::ActivationMode activation_mode,
1421 std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans);
1422
1423 virtual port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
1424 FusedConvolveRunnerFromDesc(
1425 Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
1426 dnn::ConvolutionKind kind, dnn::DataType element_type,
1427 dnn::DataType bias_type, dnn::DataType output_type, double conv_scale,
1428 double side_input_scale, double leakyrelu_alpha,
1429 const dnn::BatchDescriptor& input_descriptor,
1430 const dnn::FilterDescriptor& filter_descriptor,
1431 const dnn::BatchDescriptor& bias_descriptor,
1432 const dnn::BatchDescriptor& output_descriptor,
1433 const dnn::ConvolutionDescriptor& convolution_descriptor,
1434 dnn::ActivationMode activation_mode);
1435
1436 virtual bool GetMIOpenConvolveAlgorithms(
1437 dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
1438 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
1439 const dnn::FilterDescriptor& filter_descriptor,
1440 DeviceMemoryBase filter_data,
1441 const dnn::BatchDescriptor& output_descriptor,
1442 DeviceMemoryBase output_data,
1443 const dnn::ConvolutionDescriptor& convolution_descriptor,
1444 ScratchAllocator* scratch_allocator,
1445 std::vector<ProfileResult>* out_algorithms);
1446
1447 // Returns a list of supported rnn algorithms.
1448 virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);
1449
1450 // Version of DoConvolve that uses pre-quantized 8 bit coefficients.
1451 // coefficient_scales specifies the scaling of each column of coefficients:
1452 // original float coefficient[row * num_columns + column] =
1453 // quantized coefficient[row * num_columns + column] *
1454 // coefficient_scales[column].
1455 virtual bool DoConvolveQuantized(
1456 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1457 const DeviceMemory<float>& input_data,
1458 const dnn::FilterDescriptor& filter_descriptor,
1459 const DeviceMemory<int8>& filter_coefficients,
1460 const DeviceMemory<float>& coefficient_scales,
1461 const dnn::ConvolutionDescriptor& convolution_descriptor,
1462 const dnn::BatchDescriptor& output_descriptor,
1463 DeviceMemory<float>* output_data) = 0;
1464
1465 // Same as DoConvolveQuantized above, but int8 filter coefficients.
1466 virtual bool DoConvolveQuantized(
1467 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1468 const DeviceMemory<float>& input_data,
1469 const dnn::FilterDescriptor& filter_descriptor,
1470 const DeviceMemory<int16>& filter_coefficients,
1471 const DeviceMemory<float>& coefficient_scales,
1472 const dnn::ConvolutionDescriptor& convolution_descriptor,
1473 const dnn::BatchDescriptor& output_descriptor,
1474 DeviceMemory<float>* output_data) = 0;
1475
1476 // Variation of the above with the weight matrix split into two matrices.
1477 // first_weights: Coefficients of the first matrix.
1478 // second_weights: Coefficients of the second matrix.
1479 // depth_multiplier: specifies the columns of the first matrix and rows
1480 // of the second one - first_weights columns = depth_multiplier,
1481 // second_weights rows = depth_multiplier *
1482 // filter_descriptor.input_feature_map_count().
1483 // see go/separable for documentation on separable convolutions.
1484 virtual bool DoSeparableConvolve(
1485 Stream* stream, const BatchDescriptor& input_descriptor,
1486 const DeviceMemory<float>& input_data,
1487 const FilterDescriptor& filter_descriptor, int depth_multiplier,
1488 const DeviceMemory<float>& first_weights,
1489 const DeviceMemory<float>& second_weights,
1490 const ConvolutionDescriptor& convolution_descriptor,
1491 const BatchDescriptor& output_descriptor,
1492 DeviceMemory<float>* output_data) = 0;
1493
1494 // Return a list of algorithms supported by the backward convolution pass for
1495 // data.
1496 virtual bool GetConvolveBackwardDataAlgorithms(
1497 CudaComputeCapability cuda_compute_capability,
1498 std::vector<AlgorithmDesc>* out_algorithms);
1499
1500 // Return a list of algorithms supported by the backward convolution pass for
1501 // filters.
1502 virtual bool GetConvolveBackwardFilterAlgorithms(
1503 CudaComputeCapability cuda_compute_capability,
1504 std::vector<AlgorithmDesc>* out_algorithms);
1505
1506 // Fully connects the "nodes" (float values) in input_data with
1507 // shape input_dimensions to output_data with output_dimensions
1508 // using provided weights. This is equivalent to computing a matrix
1509 // product, hence the name MatMul.
1510 //
1511 // A BatchDescriptor has four dimensions: batch, y, x, depth. Matrix products
1512 // happen in two dimensions. To get down to two dimensions, we consider the
1513 // input y, x and depth dimension as one combined dimension T. For now,
1514 // assume that the output height and width are 1 and let OD be the output
1515 // depth.
1516 //
1517 // There are three device memory buffers passed in to this
1518 // function. We can now view all three as matrices:
1519 //
1520 // input_data: A batch x T matrix
1521 // weights: A T x OD matrix
1522 // output_data: A batch x OD matrix
1523 //
1524 // This function then computes the matrix product of input_data and
1525 // weights and writes the result into output_data.
1526 //
1527 // Here the weights buffer is in row major order, i.e. the first OD
1528 // entries in weights are the first row, the second OD entries in
1529 // weights are the second row and so on.
1530 //
1531 // The case for output width*height > 1 is more complicated. Let K =
1532 // OY * OX where OY is the output height and OX is the output
1533 // width. Then weights is divided into K sub-arrays W_i, for
1534 // i=0,...,k-1, that each represent a T x OD matrix. This function
1535 // then computes the K matrix multiplications of input_data with
1536 // each W_i. This creates K matrices with dimensions batch x
1537 // OD. These K matrices are concatenated horizontally to form one
1538 // larger matrix with dimensions batch x (K*OD); note that this is
1539 // not the same as concatenating the bytes of the matrices. The
1540 // combined matrix can then be interpreted as a tensor with
1541 // dimensions (batch, OY, OX, OD). If the output tensor format is
1542 // not kBatchYXDepth, this function would then need to arrange for
1543 // the output to be in the requested layout, if that is
1544 // supported. Note that the case K=1 is equivalent to the
1545 // description above. It is recommended to prefer the case K=1.
1546 //
1547 // Arguments (all borrowed):
1548 // stream: borrowed pointer to the stream that the 'fully connect' operation
1549 // should be enqueued onto.
1550 // output_data: un-owned device memory region in which to place the
1551 // fully connected result.
1552 virtual bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
1553 const DeviceMemory<float>& weights,
1554 const dnn::BatchDescriptor& input_dimensions,
1555 const dnn::BatchDescriptor& output_dimensions,
1556 DeviceMemory<float>* output_data) = 0;
1557
1558 // Version of DoMatMul that uses pre-quantized 8 bit weights.
1559 // weight_scales specifies the scaling of each column of weights:
1560 // original float weight[row * num_columns + column] =
1561 // quantized_weight[row * nnum_columns + column] * weight_scales[column].
1562 virtual bool DoMatMulQuantized(Stream* stream,
1563 const DeviceMemory<float>& input_data,
1564 const DeviceMemory<int8>& quantized_weights,
1565 const DeviceMemory<float>& weight_scales,
1566 const dnn::BatchDescriptor& input_dimensions,
1567 const dnn::BatchDescriptor& output_dimensions,
1568 DeviceMemory<float>* output_data) = 0;
1569
1570 // Version of DoMatMul that uses pre-quantized 16 bit weights.
1571 // weight_scales specifies the scaling of each column of weights:
1572 // original float weight[row * num_columns + column] =
1573 // quantized_weight[row * nnum_columns + column] * weight_scales[column].
1574 virtual bool DoMatMulQuantized(Stream* stream,
1575 const DeviceMemory<float>& input_data,
1576 const DeviceMemory<int16>& quantized_weights,
1577 const DeviceMemory<float>& weight_scales,
1578 const dnn::BatchDescriptor& input_dimensions,
1579 const dnn::BatchDescriptor& output_dimensions,
1580 DeviceMemory<float>* output_data) = 0;
1581
1582 // Adds biases to the feature maps in input_data producing
1583 // output_data. input_data can equal output_data, but must not
1584 // partially overlap it.
1585 //
1586 // Let K = count() * height() * width() and N = feature_map_count()
1587 // on dimensions. Then input_value contains K*N values and biases
1588 // contains N values. We can thus logically consider input_value to
1589 // contain K vectors of N elements each. This function adds biases
1590 // to each of those N vectors.
1591 //
1592 // TODO(broune): This works differently when width() * height() > 1
1593 // and the call to ThenBiasAdd() follows a call to ThenMatMul(). In
1594 // that case there should be width() * height() *
1595 // feature_map_count() biases, but this is not implemented on all
1596 // StreamExecutors.
1597 //
1598 // Arguments (all borrowed):
1599 // stream: borrowed pointer to the stream that the 'bias add' operation
1600 // should be enqueued onto.
1601 // input_data: un-owned device memory region containing the input.
1602 // biases: un-owned device memory region containing biases to add to the
1603 // input.
1604 // dimensions: dimensions of input_data and output_data.
1605 // output_data: un-owned device memory region in which to place the result.
1606 virtual bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
1607 const DeviceMemory<float>& biases,
1608 const dnn::BatchDescriptor& dimensions,
1609 DeviceMemory<float>* output_data) = 0;
1610
1611 // Performs a forward pooling operation on input_data, writing to
1612 // output_data. See PoolingDescriptor for how to configure the
1613 // pooling operation.
1614 //
1615 // Pooling happens as a window that moves across the Y and X
1616 // dimensions of input_data, where each position of the window
1617 // yields one output value. E.g. for max pooling, the computed value
1618 // is the maximum element in the window. The operation is applied
1619 // independently to each batch and at each feature map (depth), so
1620 // that the output depth and feature_map_count are the same as for
1621 // the input. The output width and height can be different.
1622 //
1623 // See PoolingDescriptor for how to configure the pooling operation.
1624 virtual port::Status DoPoolForward(
1625 DataType element_type, Stream* stream,
1626 const dnn::PoolingDescriptor& pooling_dimensions,
1627 const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
1628 const dnn::BatchDescriptor& output_dimensions,
1629 DeviceMemoryBase output_data, ScratchAllocator* workspace_allocator) = 0;
1630
1631 // Performs differentiation of the pooling operation.
1632 virtual port::Status DoPoolBackward(
1633 DataType element_type, Stream* stream,
1634 const dnn::PoolingDescriptor& pooling_dimensions,
1635 const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
1636 const dnn::BatchDescriptor& output_dimensions,
1637 DeviceMemoryBase output_data, DeviceMemoryBase input_diff_data,
1638 DeviceMemoryBase output_diff_data,
1639 ScratchAllocator* workspace_allocator) = 0;
1640
1641 // Applies local response normalization to the values from input_data and
1642 // writes the result to output_data.
1643 //
1644 // See comments on NormalizeDescriptor for a description of local response
1645 // normalization.
DoNormalizeWithDimensions(Stream * stream,const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1646 virtual bool DoNormalizeWithDimensions(
1647 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1648 const dnn::BatchDescriptor& dimensions,
1649 const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
1650 return false;
1651 }
1652
1653 // Performs backpropagation for the normalization operation
1654 //
1655 // Given raw data, its corresponding normalized output, and a gradient of some
1656 // unspecified function with respect to the normalized variables, computes the
1657 // gradient of that unspecified function with respect to the raw variables.
1658 //
1659 // The normalized data input array is expected to match the output that would
1660 // be obtained by running the raw data input array through the DoNormalize
1661 // method above.
1662 //
1663 // See comments on NormalizeDescriptor for a description of local response
1664 // normalization.
DoNormalizeBackwardWithDimensions(Stream * stream,const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)1665 virtual bool DoNormalizeBackwardWithDimensions(
1666 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1667 const dnn::BatchDescriptor& dimensions,
1668 const DeviceMemory<float>& raw_data,
1669 const DeviceMemory<float>& normalized_data,
1670 const DeviceMemory<float>& normalized_variable_gradient,
1671 DeviceMemory<float>* raw_variable_gradient,
1672 ScratchAllocator* workspace_allocator) {
1673 return false;
1674 }
1675
1676 // Applies an activation function (see ActivationMode) to all of the values
1677 // held on the device in 'input_data', whose dimensions are described by
1678 // 'dimensions'.
1679 //
1680 // Arguments (all borrowed):
1681 // stream: borrowed pointer to the stream that the 'activate' operation
1682 // should be enqueued onto.
1683 // activation_mode: Type of activation to perform.
1684 // input_data: un-owned device memory region which contains the
1685 // activate input.
1686 // output_data: un-owned device memory region in which to place the
1687 // activate result.
DoActivate(Stream * stream,ActivationMode activation_mode,const BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64_t options)1688 virtual bool DoActivate(Stream* stream, ActivationMode activation_mode,
1689 const BatchDescriptor& dimensions,
1690 const DeviceMemory<float>& input_data,
1691 DeviceMemory<float>* output_data, uint64_t options) {
1692 return false;
1693 }
1694
1695 // Concatenates several layers into one, by concatenating the depth of each
1696 // layer at matching x and y coordinates.
1697 // The inputs must all have the same width and height, the output will have
1698 // the same width and height as the inputs and its depth will be the sum of
1699 // the input depths.
1700 //
1701 // Arguments (all borrowed):
1702 // stream: borrowed pointer to the stream that the 'depth concatenate'
1703 // operation should be enqueued onto.
1704 // input_dimensions: The dimensions of each input.
1705 // input_data: un-owned device memory region which contains the
1706 // input data for each input layer.
1707 // output_data: un-owned device memory region in which to place the
1708 // depth concatenate result.
1709 virtual bool DoDepthConcatenate(
1710 Stream* stream,
1711 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
1712 port::ArraySlice<const DeviceMemory<float>*> input_data, // non-absl ok
1713 DeviceMemory<float>* output_data) = 0;
1714
1715 // Concatenates several layers into one, by concatenating each in the
1716 // x-dimension or y-dimension, based on a user-specified flag.
1717 // For x-concatenation, layers are aligned at matching y and depth
1718 // coordinates, and for y-concatenation, they are aligned at matching x and
1719 // depth coordinates. The inputs must all have the same depth and batch size.
1720 // For x-concatenation, the inputs must have the same height (y-size), and the
1721 // output will have the same depth and height as the inputs and its width (x-
1722 // size) will be the sum of the input widths. For y-concatenation, the inputs
1723 // must have the same width, and the output will have the same depth and width
1724 // as the inputs, and its height will be the sum of the input heights.
1725 //
1726 // Arguments:
1727 // stream: borrowed pointer to the stream that the 'space concatenate'
1728 // operation should be enqueued onto.
1729 // input_dimensions: the dimensions of each input.
1730 // input_data: un-owned device memory region which contains the input data
1731 // for each input layer.
1732 // output_data: un-owned device memory region in which to place the space
1733 // concatenate result.
1734 // concat_direction: either dnn:SpaceConcatenateMode::XDirection or
1735 // dnn::SpaceConcatenateMode::YDirection.
DoSpaceConcatenate(Stream * stream,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1736 virtual bool DoSpaceConcatenate(
1737 Stream* stream,
1738 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
1739 port::ArraySlice<const DeviceMemory<float>*> input_data, // non-absl ok
1740 DeviceMemory<float>* output_data,
1741 dnn::SpaceConcatenateMode concat_direction) {
1742 return false;
1743 }
1744
1745 // Change the layout of the data by shrinking one dimension (or set of
1746 // dimensions) and growing another dimension (or set of dimensions), while
1747 // keeping the total number of data elements constant, and maintaining the
1748 // current data ordering.
1749 //
1750 // Currently, the only supported operation is depth into space by a power of
1751 // 2. E.g. (y, x, z) -> (y*2, x*2, z/4)
1752 //
1753 // Note that Reshape may not be a no-op, depending on the platform and which
1754 // dimensions are being changed.
1755 //
1756 // Example: forgetting about batch for the moment, let's take a tensor that's
1757 // 2x1x8 (y by x by z) and reshape to a tensor that's 4x2x2. The memory layout
1758 // is row-major order: y,x,z. I.e. z changes the fastest, then x, then y. The
1759 // elements of the tensor range from 0 to 15. The x,y,z indices are below each
1760 // element.
1761 //
1762 // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
1763 // y0 y0 y0 y0 y0 y0 y0 y0 y1 y1 y1 y1 y1 y1 y1 y1
1764 // x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0
1765 // z0 z1 z2 z3 z4 z5 z6 z7 z0 z1 z2 z3 z4 z5 z6 z7
1766 //
1767 // reshape to 4x2x2
1768 //
1769 // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
1770 // y0 y0 y0 y0 y1 y1 y1 y1 y2 y2 y2 y2 y3 y3 y3 y3
1771 // x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1
1772 // z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1
DoReshape(Stream * stream,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1773 virtual bool DoReshape(Stream* stream,
1774 const dnn::BatchDescriptor& input_dimensions,
1775 const DeviceMemory<float>& input_data,
1776 const dnn::BatchDescriptor& output_dimensions,
1777 DeviceMemory<float>* output_data) {
1778 return false;
1779 }
1780
1781 // Depth to space takes an X by Y image with depth D*M^2 and changes it to an
1782 // MX x MY image with depth D. Each input location (x,y) with depth D*M^2 in
1783 // the input image is changed to an MxM contiguous area in the output image,
1784 // with the values being laid out in the raster order by DepthToSpaceLayout,
1785 // and will have a new depth of D.
1786 //
1787 // Example.
1788 // M=2, Din =8, Xin=2, Yin=2. Xout=4, Yout=4, Dout=2
1789 // DepthHeightWidth layout
1790 // Values within a 'cell' are at different depths and same x & y.
1791 // Input:
1792 // abcdefgh ijklmnop
1793 // qrstuvwx yz012345
1794 // Output:
1795 // ae bf im jn
1796 // cg dh ko lp
1797 // qu rv y2 z3
1798 // sw tx 04 15
1799 //
1800 // sqrt_depth_reduction: 'M' in the comment above
DoDepthToSpace(Stream * stream,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const DepthToSpaceLayout & depth_to_space_layout,const int & sqrt_depth_reduction,DeviceMemory<float> * output_data)1801 virtual bool DoDepthToSpace(Stream* stream,
1802 const dnn::BatchDescriptor& input_dimensions,
1803 const DeviceMemory<float>& input_data,
1804 const DepthToSpaceLayout& depth_to_space_layout,
1805 const int& sqrt_depth_reduction,
1806 DeviceMemory<float>* output_data) {
1807 return false;
1808 }
1809
1810 // Space to depth is the inverse of depth to space. Space to depth takes each
1811 // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
1812 // the input, and transforms it to a 1 by 1 patch with depth D*M^2. If the
1813 // input has size (MX, MY, D), the output has size (X, Y, D*M^2). The number
1814 // of data elements is not changed.
1815 //
1816 // Example.
1817 // M=2, Din =2, Xin=4, Yin=4, Dout=8
1818 // DepthHeightWidth layout
1819 // Values within a 'cell' are at different depths and same x & y.
1820 // Input:
1821 // ae bf im jn
1822 // cg dh ko lp
1823 // qu rv y2 z3
1824 // sw tx 04 15
1825 // Output:
1826 // abcdefgh ijklmnop
1827 // qrstuvwx yz012345
1828 //
1829 // sqrt_depth_increase: 'M' in the comment above
DoSpaceToDepth(Stream * stream,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const DepthToSpaceLayout & space_to_depth_layout,const int & sqrt_depth_increase,DeviceMemory<float> * output_data)1830 virtual bool DoSpaceToDepth(Stream* stream,
1831 const dnn::BatchDescriptor& input_dimensions,
1832 const DeviceMemory<float>& input_data,
1833 const DepthToSpaceLayout& space_to_depth_layout,
1834 const int& sqrt_depth_increase,
1835 DeviceMemory<float>* output_data) {
1836 return false;
1837 }
1838
1839 // Computes the specified operation (e.g. addition or multiplication)
1840 // between corresponding elements in the inputs and stores the result in the
1841 // output element.
1842 // The inputs and output must all have the same dimensions, but may have
1843 // different quantization parameters (min_value and max_value).
1844 //
1845 // Arguments (all borrowed):
1846 // stream: borrowed pointer to the stream that the 'elementwise operation'
1847 // should be enqueued onto.
1848 // operation: The operation to perform.
1849 // input_dimensions: The dimensions of each input.
1850 // input_data: un-owned device memory region which contains the
1851 // input data for each input layer.
1852 // output_dimensions: The dimensions of the output.
1853 // output_data: un-owned device memory region in which to place the
1854 // operation result.
1855 virtual bool DoElementwiseOperate(
1856 Stream* stream, ElementwiseOperation operation,
1857 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
1858 port::ArraySlice<const DeviceMemory<float>*> input_data, // non-absl ok
1859 const dnn::BatchDescriptor& output_dimensions,
1860 DeviceMemory<float>* output_data) = 0;
1861
1862 // Computes the specified operation (e.g. addition or multiplication)
1863 // between corresponding elements in the inputs and stores the result in the
1864 // output element. Each input is multiplied by a scalar constant and the
1865 // result is divided by a scalar constant.
1866 // e.g. To perform Z = 0.9*X + 1.1*Y, set the input multiplicands to 9 and 11
1867 // and the output divisor to 10.
1868 // The inputs and output must all have the same dimensions, but may have
1869 // different quantization parameters (min_value and max_value).
1870 //
1871 // Arguments (all borrowed):
1872 // stream: borrowed pointer to the stream that the 'elementwise operation'
1873 // should be enqueued onto.
1874 // operation: The operation to perform.
1875 // input_multiplicands: Amount to scale each input.
1876 // output_divisor: Amount to divide the output.
1877 // input_dimensions: The dimensions of each input.
1878 // input_data: un-owned device memory region which contains the
1879 // input data for each input layer.
1880 // output_dimensions: The dimensions of the output.
1881 // output_data: un-owned device memory region in which to place the
1882 // operation result.
DoElementwiseOperateScaledQuantized(Stream * stream,ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1883 virtual bool DoElementwiseOperateScaledQuantized(
1884 Stream* stream, ElementwiseOperation operation,
1885 port::ArraySlice<int> input_multiplicands, // non-absl ok
1886 int output_divisor,
1887 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
1888 port::ArraySlice<const DeviceMemory<float>*> input_data, // non-absl ok
1889 const dnn::BatchDescriptor& output_dimensions,
1890 DeviceMemory<float>* output_data) {
1891 return false;
1892 }
1893
1894 // Pads the input with zeros in the X and Y dimensions. The feature_map
1895 // dimension is unchanged.
1896 //
1897 // Arguments (all borrowed):
1898 // stream: borrowed pointer to the stream that the 'elementwise operation'
1899 // should be enqueued onto.
1900 // dimensions: The dimensions of the input.
1901 // input_data: un-owned device memory region which contains the
1902 // input data for the input layer.
1903 // left_pad: Amount to pad the input on the left.
1904 // right_pad: Amount to pad the input on the right.
1905 // top_pad: Amount to pad the input at the top (low Y).
1906 // bottom_pad: Amount to pad the input at the bottom (high Y).
1907 // output_data: un-owned device memory region in which to place the
1908 // padded result.
1909 virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
1910 const DeviceMemory<float>& input_data, int64_t left_pad,
1911 int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
1912 DeviceMemory<float>* output_data) = 0;
1913
1914 // Extracts a slice of the input in the X and Y dimensions. The feature_map
1915 // dimension is unchanged.
1916 //
1917 // Arguments (all borrowed):
1918 // stream: borrowed pointer to the stream that the 'elementwise operation'
1919 // should be enqueued onto.
1920 // dimensions: The dimensions of the input.
1921 // input_data: un-owned device memory region which contains the
1922 // input data for the input layer.
1923 // left_trim: Amount to cut off the input on the left.
1924 // right_trim: Amount to cut off the input on the right.
1925 // top_trim: Amount to cut off the input at the top (low y).
1926 // bottom_trim: Amount to cut off the input at the bottom (high Y).
1927 // output_data: un-owned device memory region in which to place the
1928 // padded result.
1929 virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
1930 const DeviceMemory<float>& input_data,
1931 int64_t left_trim, int64_t right_trim,
1932 int64_t top_trim, int64_t bottom_trim,
1933 DeviceMemory<float>* output_data) = 0;
1934
1935 // Grows the input tensor by replicating the X and Y dimensions. The batch and
1936 // depth/feature_map dimensions are unchanged. Currently, the input tensor is
1937 // limited to X=1 and Y=1.
1938 //
1939 // For example, the input has dimensions x=2, y=3, and replicate_x=3,
1940 // replicate_y=2. The diagonal elements of the output would be: [x0y0, x1y1,
1941 // x0y2, x1y0, x0y1, x1y2].
1942 // Here is the example as a picture. input:
1943 // AB
1944 // CD
1945 // EF
1946 // broadcast result:
1947 // ABABAB
1948 // CDCDCD
1949 // EFEFEF
1950 // ABABAB
1951 // CDCDCD
1952 // EFEFEF
1953 //
1954 // Arguments (all borrowed):
1955 // stream: borrowed pointer to the stream that the 'elementwise operation'
1956 // should be enqueued onto.
1957 // dimensions: The dimensions of the input.
1958 // input_data: un-owned device memory region which contains the
1959 // input data for the input layer.
1960 // replicate_x: Amount to replicate the input's X dimension.
1961 // replicate_y: Amount to replicate the input's Y dimension.
1962 // output_data: un-owned device memory region in which to place the
1963 // padded result.
DoXYBroadcast(Stream * stream,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t replicate_x,int64_t replicate_y,DeviceMemory<float> * output_data)1964 virtual bool DoXYBroadcast(Stream* stream,
1965 const dnn::BatchDescriptor& dimensions,
1966 const DeviceMemory<float>& input_data,
1967 int64_t replicate_x, int64_t replicate_y,
1968 DeviceMemory<float>* output_data) {
1969 return false;
1970 }
1971
1972 // Enqueues an asynchronous memcpy of the *quantized* output of a layer (that
1973 // is, bytes instead of scaled floats) into 'host_dst' if they are available
1974 // for the underlying DNN implementation. If this quantized output is not
1975 // available, false is returned, which will place 'stream' into an error
1976 // state.
1977 //
1978 // Arguments (all borrowed):
1979 // stream: borrowed pointer to the stream that the 'quantized memcpy'
1980 // operation should be enqueued onto.
1981 // gpu_unquantized_src: the device memory that contains the unquantized data
1982 // -- this data should also have a corresponding quantized representation
1983 // on the device for this operation to succeed.
1984 // mode: Type of quantization of the data to write into host_dst.
1985 // host_dst: un-owned host memory region that is mutated in place,
1986 // it is clobbered by the values in 'gpu_unquantized_src' when the enqueued
1987 // (asynchronous) memcpy operation is performed.
1988 // size: size in bytes of the host_dst host memory region.
1989 virtual bool DoMemcpyD2HQuantized(
1990 Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
1991 QuantizedActivationMode mode, void* host_dst, int64_t size) = 0;
1992
1993 // Enqueues an asynchronous memcpy of 'host_dst' into the *quantized* input
1994 // of a layer (that is, bytes instead of scaled floats) if they are supported
1995 // by the underlying DNN implementation. If this quantized input is not
1996 // supported, false is returned, which will place 'stream' into an error
1997 // state.
1998 //
1999 // Arguments (all borrowed):
2000 // stream: borrowed pointer to the stream that the 'quantized memcpy'
2001 // operation should be enqueued onto.
2002 // host_src: un-owned host memory region that contains the quantized data.
2003 // size: size in bytes of the host_src host memory region.
2004 // mode: Type of quantization of the data to read from host_src.
2005 // gpu_unquantized_dst: the device memory that is clobbered by the values in
2006 // 'host_src' when the enqueued (asynchronous) memcpy operation is
2007 // performed. -- this data should also have a corresponding quantized
2008 // representation on the device for this operation to
2009 // succeed.
2010 virtual bool DoMemcpyH2DQuantized(
2011 Stream* stream, const void* host_src, int64_t size,
2012 QuantizedActivationMode mode,
2013 DeviceMemory<float>* gpu_unquantized_dst) = 0;
2014
2015 // Create an RNN descriptor based on model shapes and configurations.
2016 // The caller retains the ownership of the descriptor.
2017 //
2018 // Arguments:
2019 // num_layers: the number of layers for a RNN model.
2020 // hidden_size: the size of the hidden state.
2021 // input_size: the size of the input state.
2022 // cell_size: the size of the cell state
2023 // input_mode: an enum to specify whether a linear transformation is added
2024 // after the input state. If input_size is different from hidden_size, this
2025 // is required.
2026 // direction_mode: an enum to specify whether this model is unidirectional or
2027 // bidirectional.
2028 // rnn_mode: an enum to specify the type of model to build.
2029 // data_type: an enum to specify the data types used in this model.
2030 // dropout: the dropout threshold between layers. When it is 0., no dropout
2031 // is added.
2032 // seed: a seed for initializing the dropout layers.
2033 // state_allocator: an memory allocator that will be used to store the state
2034 // for dropout layer. The user has to maintain the memory until the model
2035 // is no longer in use.
2036 // use_padded_io: a bool to specify whether the input is using padded IO.
2037 virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
createRnnDescriptor(int num_layers,int hidden_size,int input_size,int cell_size,int batch_size,dnn::RnnInputMode input_mode,dnn::RnnDirectionMode direction_mode,dnn::RnnMode rnn_mode,dnn::DataType data_type,const dnn::AlgorithmConfig & algorithm_config,float dropout,uint64_t seed,ScratchAllocator * state_allocator,bool use_padded_io)2038 createRnnDescriptor(int num_layers, int hidden_size, int input_size,
2039 int cell_size, int batch_size,
2040 dnn::RnnInputMode input_mode,
2041 dnn::RnnDirectionMode direction_mode,
2042 dnn::RnnMode rnn_mode, dnn::DataType data_type,
2043 const dnn::AlgorithmConfig& algorithm_config,
2044 float dropout, uint64_t seed,
2045 ScratchAllocator* state_allocator, bool use_padded_io) {
2046 return port::Status(port::error::UNIMPLEMENTED,
2047 "createRnnDescriptor is unimplemented");
2048 }
2049
2050 // Create a RNN sequence descriptor that specifies either the input or output
2051 // sequence. The caller retains the ownership of the returned descriptor.
2052 //
2053 // Arguments:
2054 // max_seq_length: the max length of the sequences.
2055 // batch_size: the size of a minibatch.
2056 // data_size: the size of the state.
2057 // seq_lengths: the lengths of sequences in a batch.
2058 // data_type: an enum to specify the type for the underlying data.
2059 virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int max_seq_length,int batch_size,int data_size,dnn::DataType data_type)2060 createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2061 int data_size, dnn::DataType data_type) {
2062 return port::Status(port::error::UNIMPLEMENTED,
2063 "createRnnSequenceTensorDescriptor is unimplemented");
2064 }
2065
2066 virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int max_seq_length,int batch_size,int data_size,const absl::Span<const int> & seq_lengths,bool time_major,dnn::DataType data_type)2067 createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2068 int data_size,
2069 const absl::Span<const int>& seq_lengths,
2070 bool time_major, dnn::DataType data_type) {
2071 return port::Status(port::error::UNIMPLEMENTED,
2072 "createRnnSequenceTensorDescriptor is unimplemented");
2073 }
2074
2075 // Create an RNN state descriptor that specifies the input or hidden state.
2076 // The caller retains the ownership of the returned descriptor.
2077 virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
createRnnStateTensorDescriptor(int num_layer,int batch_size,int data_size,dnn::DataType data_type)2078 createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
2079 dnn::DataType data_type) {
2080 return port::Status(port::error::UNIMPLEMENTED,
2081 "createRnnStateTensorDescriptor is unimplemented");
2082 }
2083
2084 // Enqueue a forward operation of the RNN model onto the stream.
2085 //
2086 // Arguments:
2087 // stream: pointer to the stream where this operation should be enqueued to.
2088 // rnn_desc: a RNN descriptor created by createRnnDescriptor.
2089 // input_desc: descriptor for the input sequence.
2090 // input_data: the device memory region that contains the input data.
2091 // input_h_desc: descriptor for the input "h" state.
2092 // input_h_data: the device memory region that contains the input "h" data.
2093 // input_c_desc: descriptor for the input "c" state.
2094 // input_c_data: the device memory region that contains the input "c" data.
2095 // This must be specified for LSTM models.
2096 // params: the device memory region that contains the parameters used in this
2097 // model.
2098 // output_desc: descriptor for the output sequence.
2099 // output_data: the memory region that stores the output sequence data.
2100 // output_h_desc: descriptor for the output "h" state.
2101 // output_h_data: the memory region that stores the output "h" data.
2102 // output_c_desc: descriptor for the output "c" state.
2103 // output_c_data: the memory region that stores the output "c" data. This
2104 // must be specified for LSTM models.
2105 // is_training: whether this is used in training or inference. That decides
2106 // whether respace_space data need to be produced.
2107 // reserve_space_allocator: if "is_training" is true, an memory allocator
2108 // to create memory that holds the produced reserve_space. The caller is
2109 // retains the data and feed it to the backward pass.
2110 // workspace_allocator: an allocator to create temporary workspace used in
2111 // this kernel. The caller is responsible for retaining the memory long
2112 // enough for the lifespan of this operation, and recycles afterwards.
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2113 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2114 const dnn::RnnSequenceTensorDescriptor& input_desc,
2115 const DeviceMemory<Eigen::half>& input_data,
2116 const DeviceMemory<int>& seq_lengths_data,
2117 const dnn::RnnStateTensorDescriptor& input_h_desc,
2118 const DeviceMemory<Eigen::half>& input_h_data,
2119 const dnn::RnnStateTensorDescriptor& input_c_desc,
2120 const DeviceMemory<Eigen::half>& input_c_data,
2121 const DeviceMemory<Eigen::half>& params,
2122 const dnn::RnnSequenceTensorDescriptor& output_desc,
2123 DeviceMemory<Eigen::half>* output_data,
2124 const dnn::RnnStateTensorDescriptor& output_h_desc,
2125 DeviceMemory<Eigen::half>* output_h_data,
2126 const dnn::RnnStateTensorDescriptor& output_c_desc,
2127 DeviceMemory<Eigen::half>* output_c_data,
2128 bool is_training,
2129 ScratchAllocator* reserve_space_allocator,
2130 ScratchAllocator* workspace_allocator,
2131 dnn::ProfileResult* output_profile_result) {
2132 return false;
2133 }
2134
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2135 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2136 const dnn::RnnSequenceTensorDescriptor& input_desc,
2137 const DeviceMemory<float>& input_data,
2138 const DeviceMemory<int>& seq_lengths_data,
2139 const dnn::RnnStateTensorDescriptor& input_h_desc,
2140 const DeviceMemory<float>& input_h_data,
2141 const dnn::RnnStateTensorDescriptor& input_c_desc,
2142 const DeviceMemory<float>& input_c_data,
2143 const DeviceMemory<float>& params,
2144 const dnn::RnnSequenceTensorDescriptor& output_desc,
2145 DeviceMemory<float>* output_data,
2146 const dnn::RnnStateTensorDescriptor& output_h_desc,
2147 DeviceMemory<float>* output_h_data,
2148 const dnn::RnnStateTensorDescriptor& output_c_desc,
2149 DeviceMemory<float>* output_c_data,
2150 bool is_training,
2151 ScratchAllocator* reserve_space_allocator,
2152 ScratchAllocator* workspace_allocator,
2153 dnn::ProfileResult* output_profile_result) {
2154 return false;
2155 }
2156
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2157 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2158 const dnn::RnnSequenceTensorDescriptor& input_desc,
2159 const DeviceMemory<double>& input_data,
2160 const DeviceMemory<int>& seq_lengths_data,
2161 const dnn::RnnStateTensorDescriptor& input_h_desc,
2162 const DeviceMemory<double>& input_h_data,
2163 const dnn::RnnStateTensorDescriptor& input_c_desc,
2164 const DeviceMemory<double>& input_c_data,
2165 const DeviceMemory<double>& params,
2166 const dnn::RnnSequenceTensorDescriptor& output_desc,
2167 DeviceMemory<double>* output_data,
2168 const dnn::RnnStateTensorDescriptor& output_h_desc,
2169 DeviceMemory<double>* output_h_data,
2170 const dnn::RnnStateTensorDescriptor& output_c_desc,
2171 DeviceMemory<double>* output_c_data,
2172 bool is_training,
2173 ScratchAllocator* reserve_space_allocator,
2174 ScratchAllocator* workspace_allocator,
2175 dnn::ProfileResult* output_profile_result) {
2176 return false;
2177 }
2178 // Enqueue a backward operation of the RNN model onto the stream.
2179 //
2180 // Arguments:
2181 // stream: pointer to the stream where this operation should be enqueued to.
2182 // rnn_desc: a RNN descriptor created by createRnnDescriptor.
2183 // input_desc: descriptor for the input sequence.
2184 // input_data: the device memory region that contains the input data.
2185 // input_h_desc: descriptor for the input "h" state.
2186 // input_h_data: the device memory region that contains the input "h" data.
2187 // input_c_desc: descriptor for the input "c" state.
2188 // input_c_data: the device memory region that contains the input "c" data.
2189 // This must be specified for LSTM models.
2190 // params: the device memory region that contains the parameters used in this
2191 // model.
2192 // output_desc: descriptor for the output sequence.
2193 // output_data: the memory region that stores the output sequence data.
2194 // output_h_desc: descriptor for the output "h" state.
2195 // output_h_data: the memory region that stores the output "h" data.
2196 // output_c_desc: descriptor for the output "c" state.
2197 // output_c_data: the memory region that stores the output "c" data. This
2198 // must be specified for LSTM models.
2199 // output_backprop_data: the device memory region that contains the backprop
2200 // to the output sequence.
2201 // output_h_backprop_data: the device memory region that contains the
2202 // backprop to the output "h" state.
2203 // output_c_backprop_data: the device memory region that contains the
2204 // backprop to the output "c" state.
2205 // input_backprop_data: the device memory region that stores the backprop
2206 // to the input sequence.
2207 // input_h_backprop_data: the device memory region that stores the backprop
2208 // to the input "h" state.
2209 // input_c_backprop_data: the device memory region that stores the backprop
2210 // to the input "c" state.
2211 // params_backprop_data: the device memory region that stores the backprop
2212 // to the parameters.
2213 // reserve_space_data: the reserve_space data that is produced by the forward
2214 // operation. This memory region could be modified by this operation.
2215 // workspace_allocator: a memory allocator that creates the temporary
2216 // workspace memory used by this operation. The caller is responsible for
2217 // keeping the memory alive long enough for this operation, and recylces
2218 // afterwards.
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2219 virtual bool DoRnnBackward(
2220 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2221 const dnn::RnnSequenceTensorDescriptor& input_desc,
2222 const DeviceMemory<Eigen::half>& input_data,
2223 const DeviceMemory<int>& seq_lengths_data,
2224 const dnn::RnnStateTensorDescriptor& input_h_desc,
2225 const DeviceMemory<Eigen::half>& input_h_data,
2226 const dnn::RnnStateTensorDescriptor& input_c_desc,
2227 const DeviceMemory<Eigen::half>& input_c_data,
2228 const DeviceMemory<Eigen::half>& params,
2229 const dnn::RnnSequenceTensorDescriptor& output_desc,
2230 const DeviceMemory<Eigen::half>& output_data,
2231 const dnn::RnnStateTensorDescriptor& output_h_desc,
2232 const DeviceMemory<Eigen::half>& output_h_data,
2233 const dnn::RnnStateTensorDescriptor& output_c_desc,
2234 const DeviceMemory<Eigen::half>& output_c_data,
2235 const DeviceMemory<Eigen::half>& output_backprop_data,
2236 const DeviceMemory<Eigen::half>& output_h_backprop_data,
2237 const DeviceMemory<Eigen::half>& output_c_backprop_data,
2238 DeviceMemory<Eigen::half>* input_backprop_data,
2239 DeviceMemory<Eigen::half>* input_h_backprop_data,
2240 DeviceMemory<Eigen::half>* input_c_backprop_data,
2241 DeviceMemory<Eigen::half>* params_backprop_data,
2242 DeviceMemory<uint8>* reserve_space_data,
2243 ScratchAllocator* workspace_allocator,
2244 dnn::ProfileResult* output_profile_result) {
2245 return false;
2246 }
2247
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2248 virtual bool DoRnnBackward(
2249 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2250 const dnn::RnnSequenceTensorDescriptor& input_desc,
2251 const DeviceMemory<float>& input_data,
2252 const DeviceMemory<int>& seq_lengths_data,
2253 const dnn::RnnStateTensorDescriptor& input_h_desc,
2254 const DeviceMemory<float>& input_h_data,
2255 const dnn::RnnStateTensorDescriptor& input_c_desc,
2256 const DeviceMemory<float>& input_c_data,
2257 const DeviceMemory<float>& params,
2258 const dnn::RnnSequenceTensorDescriptor& output_desc,
2259 const DeviceMemory<float>& output_data,
2260 const dnn::RnnStateTensorDescriptor& output_h_desc,
2261 const DeviceMemory<float>& output_h_data,
2262 const dnn::RnnStateTensorDescriptor& output_c_desc,
2263 const DeviceMemory<float>& output_c_data,
2264 const DeviceMemory<float>& output_backprop_data,
2265 const DeviceMemory<float>& output_h_backprop_data,
2266 const DeviceMemory<float>& output_c_backprop_data,
2267 DeviceMemory<float>* input_backprop_data,
2268 DeviceMemory<float>* input_h_backprop_data,
2269 DeviceMemory<float>* input_c_backprop_data,
2270 DeviceMemory<float>* params_backprop_data,
2271 DeviceMemory<uint8>* reserve_space_data,
2272 ScratchAllocator* workspace_allocator,
2273 dnn::ProfileResult* output_profile_result) {
2274 return false;
2275 }
2276
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2277 virtual bool DoRnnBackward(
2278 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2279 const dnn::RnnSequenceTensorDescriptor& input_desc,
2280 const DeviceMemory<double>& input_data,
2281 const DeviceMemory<int>& seq_lengths_data,
2282 const dnn::RnnStateTensorDescriptor& input_h_desc,
2283 const DeviceMemory<double>& input_h_data,
2284 const dnn::RnnStateTensorDescriptor& input_c_desc,
2285 const DeviceMemory<double>& input_c_data,
2286 const DeviceMemory<double>& params,
2287 const dnn::RnnSequenceTensorDescriptor& output_desc,
2288 const DeviceMemory<double>& output_data,
2289 const dnn::RnnStateTensorDescriptor& output_h_desc,
2290 const DeviceMemory<double>& output_h_data,
2291 const dnn::RnnStateTensorDescriptor& output_c_desc,
2292 const DeviceMemory<double>& output_c_data,
2293 const DeviceMemory<double>& output_backprop_data,
2294 const DeviceMemory<double>& output_h_backprop_data,
2295 const DeviceMemory<double>& output_c_backprop_data,
2296 DeviceMemory<double>* input_backprop_data,
2297 DeviceMemory<double>* input_h_backprop_data,
2298 DeviceMemory<double>* input_c_backprop_data,
2299 DeviceMemory<double>* params_backprop_data,
2300 DeviceMemory<uint8>* reserve_space_data,
2301 ScratchAllocator* workspace_allocator,
2302 dnn::ProfileResult* output_profile_result) {
2303 return false;
2304 }
2305
2306 template <typename ElementType>
PrepareForCtcLoss(Stream * stream,const RnnStateTensorDescriptor & probs_desc,DeviceMemory<ElementType> probs_data,const RnnStateTensorDescriptor & grads_desc,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,ScratchAllocator * workspace_allocator,DeviceMemory<uint8> * scratch_memory,int * ctc_loss_algo_id)2307 port::Status PrepareForCtcLoss(Stream* stream,
2308 const RnnStateTensorDescriptor& probs_desc,
2309 DeviceMemory<ElementType> probs_data,
2310 const RnnStateTensorDescriptor& grads_desc,
2311 absl::Span<const int> labels_data,
2312 absl::Span<const int> labels_lengths_data,
2313 absl::Span<const int> input_lengths_data,
2314 ScratchAllocator* workspace_allocator,
2315 DeviceMemory<uint8>* scratch_memory,
2316 int* ctc_loss_algo_id) {
2317 return DoPrepareForCtcLoss(
2318 stream, ToDataType<ElementType>::value, probs_desc, grads_desc,
2319 labels_data, labels_lengths_data, input_lengths_data,
2320 workspace_allocator, scratch_memory, ctc_loss_algo_id);
2321 }
2322
2323 // Enqueue a CTC Loss operation onto the stream.
2324 //
2325 // Arguments:
2326 // stream: pointer to the stream where this operation should be enqueued to.
2327 // element_type: date type of the input tensors
2328 // probs_desc: specifies the shape and the data layout of the input tensor.
2329 // probs_data: the device memory region that contains the input tensor.
2330 // labels_data: the device memory region that contains the labels_value
2331 // tensor.
2332 // labels_lengths_data: the device memory region that contains the
2333 // labels_lengths tensor
2334 // input_lengths_data: the device memory region that contains the seq_lengths
2335 // tensor
2336 // costs_data: the device memory region that contains the costs tensor.
2337 // grads_desc: specifies the shape and the data layout of the grads tensor.
2338 // grads_data: the device memory region that contains the grads tensor.
2339 // ctc_loss_desc: a CTCLoss descriptor.
2340 // workspace_allocator: a memory allocator that creates the temporary
2341 // workspace memory used by this operation. The caller is responsible for
2342 // keeping the memory alive long enough for this operation, and recylces
2343 // afterwards.
2344 virtual port::Status DoCtcLoss(
2345 Stream* stream, dnn::DataType element_type,
2346 const RnnStateTensorDescriptor& probs_desc,
2347 const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2348 absl::Span<const int> labels_lengths_data,
2349 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2350 const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
2351 DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
2352
2353 template <typename ElementType>
DoCtcLoss(Stream * stream,const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemory<ElementType> & probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemory<ElementType> * costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemory<ElementType> * grads_data,DeviceMemory<uint8> * scratch_memory,int ctc_loss_algo_id)2354 bool DoCtcLoss(Stream* stream,
2355 const dnn::RnnStateTensorDescriptor& probs_desc,
2356 const DeviceMemory<ElementType>& probs_data,
2357 absl::Span<const int> labels_data,
2358 absl::Span<const int> labels_lengths_data,
2359 absl::Span<const int> input_lengths_data,
2360 DeviceMemory<ElementType>* costs_data,
2361 const dnn::RnnStateTensorDescriptor& grads_desc,
2362 DeviceMemory<ElementType>* grads_data,
2363 DeviceMemory<uint8>* scratch_memory, int ctc_loss_algo_id) {
2364 return IsStatusOk(
2365 DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
2366 probs_data, labels_data, labels_lengths_data,
2367 input_lengths_data, *costs_data, grads_desc, *grads_data,
2368 *scratch_memory, ctc_loss_algo_id),
2369 false);
2370 }
2371
2372 // Transforms a tensor into another tensor with a different layout and/or data
2373 // type.
2374 //
2375 // Arguments:
2376 // stream: pointer to the stream where this operation should be enqueued to.
2377 // input_desc: specifies the shape and the data layout of the input tensor.
2378 // input_type: the data type of the input tensor.
2379 // input_data: the device memory region that contains the input tensor.
2380 // output_desc: specifies the shape and the data layout of the output tensor.
2381 // output_type: the data type of the output tensor.
2382 // scale: an element-wise scaling factor to apply.
2383 // output_data: the device memory region that contains the output tensor.
DoTransformTensor(Stream * stream,const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)2384 virtual bool DoTransformTensor(Stream* stream,
2385 const dnn::BatchDescriptor& input_desc,
2386 dnn::DataType input_type,
2387 const DeviceMemoryBase& input_data,
2388 const dnn::BatchDescriptor& output_desc,
2389 dnn::DataType output_type, float scale,
2390 DeviceMemoryBase* output_data) {
2391 return false;
2392 }
2393
2394 // Enqueues a fused convolution+bias+activation operation onto the stream.
2395 //
2396 // Arguments (all borrowed):
2397 //
2398 // stream: borrowed pointer to the stream that the 'fusion' operation should
2399 // be enqueued onto.
2400 //
2401 // conv_input_descriptor: dimensions of the convolution input layer.
2402 // conv_input_data: device memory which contains the convolution input.
2403 //
2404 // filter_descriptor: dimensions of the convolution filter.
2405 // filter_data: device memory which contains the convolution filter weights.
2406 //
2407 // convolution_descriptor: stride of the convolution filter.
2408 //
2409 // bias_descriptor: dimensions of the bias layer
2410 // biases: device memory region containing biases to add to the convolution
2411 // output
2412 //
2413 // activation_mode: Type of activation to perform.
2414 //
2415 // output_descriptor: dimensions of the output layer.
2416 // output_data: device memory region in which to place the fusion result.
2417 //
2418 // output_profile_result: the output profile result for this call.
2419 // The profiling is only enabled when this is not nullptr.
2420 //
DoFusedConvolutionBiasActivation(Stream * stream,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & bias_data,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output_data,dnn::ProfileResult * output_profile_result)2421 virtual bool DoFusedConvolutionBiasActivation(
2422 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
2423 const DeviceMemory<float>& conv_input_data,
2424 const dnn::FilterDescriptor& filter_descriptor,
2425 const DeviceMemory<float>& filter_data,
2426 const dnn::ConvolutionDescriptor& convolution_descriptor,
2427 const dnn::BatchDescriptor& bias_descriptor,
2428 const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
2429 const dnn::BatchDescriptor& output_descriptor,
2430 DeviceMemory<float>* output_data,
2431 dnn::ProfileResult* output_profile_result) {
2432 return false;
2433 }
2434
2435 // Enqueues a fused batchnorm+activation (inference) operation onto the
2436 // stream.
2437 //
2438 // Arguments (all borrowed):
2439 //
2440 // stream: borrowed pointer to the stream that the 'fusion' operation should
2441 // be enqueued onto.
2442 //
2443 // x_descriptor: dimensions of the batchnorm input layer.
2444 // x_data: device memory which contains the batchnorm input.
2445 //
2446 // scale_offset_mean_variance_descriptor:
2447 // dimensions of the scale/offset/mean/variance tensor.
2448 // scale_data: device memory which contains the scale input.
2449 // offset_data: device memory which contains the offset input.
2450 // mean_data: device memory which contains the mean input.
2451 // variance_data: device memory which contains the variance input.
2452 // epsilon : the epsilon value to use in batchnorm calculation
2453 //
2454 // activation_mode: Type of activation to perform.
2455 //
2456 // y_data: device memory region in which to place the fusion result.
2457 //
2458 // output_profile_result: the output profile result for this call.
2459 // The profiling is only enabled when this is not nullptr.
2460 //
DoFusedBatchNormActivationInference(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<float> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & mean_data,const DeviceMemory<float> & variance_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * y_data,dnn::ProfileResult * output_profile_result)2461 virtual bool DoFusedBatchNormActivationInference(
2462 Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2463 const DeviceMemory<float>& x_data,
2464 const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2465 const DeviceMemory<float>& scale_data,
2466 const DeviceMemory<float>& offset_data,
2467 const DeviceMemory<float>& mean_data,
2468 const DeviceMemory<float>& variance_data, double epsilon,
2469 dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
2470 dnn::ProfileResult* output_profile_result) {
2471 return false;
2472 }
2473
DoFusedBatchNormActivationInference(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<Eigen::half> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & mean_data,const DeviceMemory<float> & variance_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y_data,dnn::ProfileResult * output_profile_result)2474 virtual bool DoFusedBatchNormActivationInference(
2475 Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2476 const DeviceMemory<Eigen::half>& x_data,
2477 const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2478 const DeviceMemory<float>& scale_data,
2479 const DeviceMemory<float>& offset_data,
2480 const DeviceMemory<float>& mean_data,
2481 const DeviceMemory<float>& variance_data, double epsilon,
2482 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
2483 dnn::ProfileResult* output_profile_result) {
2484 return false;
2485 }
2486
2487 // Enqueues a fused batchnorm+activation (training-fwd) operation onto the
2488 // stream.
2489 //
2490 // Arguments (all borrowed):
2491 //
2492 // stream: borrowed pointer to the stream that the 'fusion' operation should
2493 // be enqueued onto.
2494 //
2495 // x_descriptor: dimensions of the batchnorm input layer.
2496 // x_data: device memory which contains the batchnorm input.
2497 //
2498 // scale_offset_mean_variance_descriptor:
2499 // dimensions of the scale/offset/mean/variance tensor.
2500 // scale_data: device memory which contains the scale input.
2501 // offset_data: device memory which contains the offset input.
2502 // epsilon : the epsilon value to use in batchnorm calculation
2503 //
2504 // activation_mode: Type of activation to perform.
2505 //
2506 // y_data: device memory region in which to place the fusion result.
2507 // batch_mean_data: device memory in which to place the batch mean output.
2508 // batch_var_data: device memory in which to place the batch variance output.
2509 // saved_mean_data: device memory in which to save the mean for bwd pass.
2510 // saved_var_data: device memory in which to save the variance for bwd pass.
2511 //
2512 // output_profile_result: the output profile result for this call.
2513 // The profiling is only enabled when this is not nullptr.
2514 //
DoFusedBatchNormActivationForward(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<float> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * y_data,DeviceMemory<float> * batch_mean_data,DeviceMemory<float> * batch_var_data,DeviceMemory<float> * saved_mean_data,DeviceMemory<float> * saved_var_data,dnn::ProfileResult * output_profile_result)2515 virtual bool DoFusedBatchNormActivationForward(
2516 Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2517 const DeviceMemory<float>& x_data,
2518 const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2519 const DeviceMemory<float>& scale_data,
2520 const DeviceMemory<float>& offset_data, double epsilon,
2521 dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
2522 DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
2523 DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
2524 dnn::ProfileResult* output_profile_result) {
2525 return false;
2526 }
2527
DoFusedBatchNormActivationForward(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<Eigen::half> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y_data,DeviceMemory<float> * batch_mean_data,DeviceMemory<float> * batch_var_data,DeviceMemory<float> * saved_mean_data,DeviceMemory<float> * saved_var_data,dnn::ProfileResult * output_profile_result)2528 virtual bool DoFusedBatchNormActivationForward(
2529 Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2530 const DeviceMemory<Eigen::half>& x_data,
2531 const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2532 const DeviceMemory<float>& scale_data,
2533 const DeviceMemory<float>& offset_data, double epsilon,
2534 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
2535 DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
2536 DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
2537 dnn::ProfileResult* output_profile_result) {
2538 return false;
2539 }
2540
2541 // Enqueues a fused batchnorm+activation (training-bwd) operation onto the
2542 // stream.
2543 //
2544 // Arguments (all borrowed):
2545 //
2546 // stream: borrowed pointer to the stream that the 'fusion' operation should
2547 // be enqueued onto.
2548 //
2549 // y_act_backprop_descriptor: dimensions of the backprop input from the
2550 // previous layer. y_act_backprop_data: device memory which contains the
2551 // backprop input.
2552 //
2553 // y_act_data: device memory which contains the actv-fwd output data.
2554 //
2555 // activation_mode: actv-fwd type.
2556 //
2557 // scale_offset_mean_variance_descriptor:
2558 // dimensions of the scale/offset/mean/variance tensor.
2559 // scale_data: device memory which contains the scale input.
2560 // offset_data: device memory which contains the offset input.
2561 // saved_mean_data: device memory which contains the saved mean from fwd
2562 // pass. saved_var_data: device memory which contains the saved variance from
2563 // fwd pass.
2564 //
2565 // x_bn_backprop_data: device memory region in which to place the backprop
2566 // data from this layer scale_backprop_data: device memory in which to place
2567 // the scale backprop output. offset_backprop_data: device memory in which to
2568 // place the offset backprop output.
2569 //
2570 // output_profile_result: the output profile result for this call.
2571 // The profiling is only enabled when this is not nullptr.
2572 //
DoFusedBatchNormActivationBackward(Stream * stream,const dnn::BatchDescriptor & y_act_backprop_descriptor,const DeviceMemory<float> & y_act_backprop_data,const DeviceMemory<float> & y_act_data,dnn::ActivationMode activation_mode,const DeviceMemory<float> & x_bn_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & saved_mean_data,const DeviceMemory<float> & saved_var_data,DeviceMemory<float> * x_bn_backprop_data,DeviceMemory<float> * scale_backprop_data,DeviceMemory<float> * offset_backprop_data,dnn::ProfileResult * output_profile_result)2573 virtual bool DoFusedBatchNormActivationBackward(
2574 Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
2575 const DeviceMemory<float>& y_act_backprop_data,
2576 const DeviceMemory<float>& y_act_data,
2577 dnn::ActivationMode activation_mode, const DeviceMemory<float>& x_bn_data,
2578 const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2579 const DeviceMemory<float>& scale_data,
2580 const DeviceMemory<float>& offset_data,
2581 const DeviceMemory<float>& saved_mean_data,
2582 const DeviceMemory<float>& saved_var_data,
2583 DeviceMemory<float>* x_bn_backprop_data,
2584 DeviceMemory<float>* scale_backprop_data,
2585 DeviceMemory<float>* offset_backprop_data,
2586 dnn::ProfileResult* output_profile_result) {
2587 return false;
2588 }
2589
DoFusedBatchNormActivationBackward(Stream * stream,const dnn::BatchDescriptor & y_act_backprop_descriptor,const DeviceMemory<Eigen::half> & y_act_backprop_data,const DeviceMemory<Eigen::half> & y_act_data,dnn::ActivationMode activation_mode,const DeviceMemory<Eigen::half> & x_bn_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & saved_mean_data,const DeviceMemory<float> & saved_var_data,DeviceMemory<Eigen::half> * x_bn_backprop_data,DeviceMemory<float> * scale_backprop_data,DeviceMemory<float> * offset_backprop_data,dnn::ProfileResult * output_profile_result)2590 virtual bool DoFusedBatchNormActivationBackward(
2591 Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
2592 const DeviceMemory<Eigen::half>& y_act_backprop_data,
2593 const DeviceMemory<Eigen::half>& y_act_data,
2594 dnn::ActivationMode activation_mode,
2595 const DeviceMemory<Eigen::half>& x_bn_data,
2596 const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2597 const DeviceMemory<float>& scale_data,
2598 const DeviceMemory<float>& offset_data,
2599 const DeviceMemory<float>& saved_mean_data,
2600 const DeviceMemory<float>& saved_var_data,
2601 DeviceMemory<Eigen::half>* x_bn_backprop_data,
2602 DeviceMemory<float>* scale_backprop_data,
2603 DeviceMemory<float>* offset_backprop_data,
2604 dnn::ProfileResult* output_profile_result) {
2605 return false;
2606 }
2607
2608 // Notifies that a stream is being destroyed and should be invalidated from
2609 // any internal caching. This exists to allow the CUDA implementation to
2610 // avoid redundant cudnnSetStream calls without risking problems when a stream
2611 // is destroyed and a new stream later created in the same memory.
NotifyStreamDestroyed(Stream * stream)2612 virtual void NotifyStreamDestroyed(Stream* stream) {}
2613
2614 protected:
2615 // Returns whether status is 'ok', and potentially logs the error.
2616 static bool IsStatusOk(const port::Status& status, bool report_error);
2617
2618 private:
DoPrepareForConvolution(ConvolutionKind kind,DataType element_type,Stream * stream,const BatchDescriptor & batch_descriptor,DeviceMemoryBase input_data,const FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const ConvolutionDescriptor & convolution_descriptor,const AlgorithmConfig & algorithm_config,ScratchAllocator * scratch_allocator,AlgorithmDesc * algorithm_desc,DeviceMemory<uint8> * scratch_memory)2619 virtual port::Status DoPrepareForConvolution(
2620 ConvolutionKind kind, DataType element_type, Stream* stream,
2621 const BatchDescriptor& batch_descriptor, DeviceMemoryBase input_data,
2622 const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data,
2623 const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
2624 const ConvolutionDescriptor& convolution_descriptor,
2625 const AlgorithmConfig& algorithm_config,
2626 ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
2627 DeviceMemory<uint8>* scratch_memory) {
2628 *algorithm_desc = {};
2629 *scratch_memory = {};
2630 return ::tensorflow::OkStatus();
2631 }
2632
DoPrepareForCtcLoss(Stream * stream,DataType element_type,const RnnStateTensorDescriptor & probs_desc,const RnnStateTensorDescriptor & grads_desc,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,ScratchAllocator * scratch_allocator,DeviceMemory<uint8> * scratch_memory,int * ctc_loss_algo_id)2633 virtual port::Status DoPrepareForCtcLoss(
2634 Stream* stream, DataType element_type,
2635 const RnnStateTensorDescriptor& probs_desc,
2636 const RnnStateTensorDescriptor& grads_desc,
2637 absl::Span<const int> labels_data,
2638 absl::Span<const int> labels_lengths_data,
2639 absl::Span<const int> input_lengths_data,
2640 ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
2641 int* ctc_loss_algo_id) {
2642 *scratch_memory = {};
2643 return ::tensorflow::OkStatus();
2644 }
2645
2646 SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
2647 };
2648
2649 } // namespace dnn
2650 } // namespace stream_executor
2651
2652 #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DNN_H_
2653