1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Neural Net operation support for StreamExecutor instances.
17 //
18 // This is an abstract interface for a platform to optionally support common
19 // neural net operations; it accommodates implementations such as the cudnn
20 // library operations.
21 
22 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DNN_H_
23 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DNN_H_
24 
25 #include <functional>
26 #include <limits>
27 #include <memory>
28 #include <string>
29 #include <tuple>
30 #include <type_traits>
31 #include <utility>
32 
33 #include "google/protobuf/wrappers.pb.h"
34 #include "absl/types/optional.h"
35 #include "absl/types/span.h"
36 #include "tensorflow/compiler/xla/stream_executor/data_type.h"
37 #include "tensorflow/compiler/xla/stream_executor/device_description.h"
38 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
39 #include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
40 #include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
41 #include "tensorflow/compiler/xla/stream_executor/lib/status.h"
42 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
43 #include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
44 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
45 
46 namespace Eigen {
47 struct half;
48 }  // namespace Eigen
49 
50 namespace stream_executor {
51 
52 class HostBuffer;
53 class Stream;
54 class ScratchAllocator;
55 
56 namespace dnn {
57 
58 // Specifies an index to use when accessing specific spatial dimensions.
59 enum class DimIndex : int {
60   X = 0,
61   Y = 1,
62   Z = 2,
63 };
64 
65 // Returns the dims reordered from the 'from' layout to the 'to' layout.
66 std::vector<int64_t> ReorderDims(const std::vector<int64_t>& input,
67                                  const DataLayout& from, const DataLayout& to);
68 
69 // Helper functions to make methods more readable.
70 inline int64_t GetDim(absl::Span<const int64_t> data, DimIndex dim) {
71   return data.rbegin()[static_cast<int64_t>(dim)];
72 }
73 
74 inline void SetDim(absl::Span<int64_t> data, DimIndex dim, int64_t value) {
75   data.rbegin()[static_cast<int64_t>(dim)] = value;
76 }
77 
78 inline void SetDim(std::vector<int64_t>* data, DimIndex dim, int64_t value) {
79   return SetDim(absl::MakeSpan(*data), dim, value);
80 }
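// Illustrative note (not part of the API): GetDim/SetDim index spatial
// dimensions from the minor (innermost) end, so for dims stored as
// {..., y, x}, DimIndex::X refers to the last element. A small sketch:
//
//   std::vector<int64_t> spatial = {/*z=*/4, /*y=*/8, /*x=*/16};
//   GetDim(spatial, DimIndex::X);       // == 16 (the last element)
//   SetDim(&spatial, DimIndex::Y, 32);  // spatial becomes {4, 32, 16}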
81 
82 // int64_t is not the same type as tensorflow::protobuf_int64 in open-source.
83 // This wrapper function gives an int64_t array slice view of a repeated int64
84 // protobuf field.
85 //
86 // T should be a protobuf RepeatedField.
87 template <typename T>
88 inline absl::Span<const int64_t> AsInt64Slice(const T& repeated_field) {
89   using data_ty =
90       typename std::remove_reference<decltype(*repeated_field.data())>::type;
91   static_assert(std::is_integral<data_ty>::value &&
92                     std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
93                 "repeated_field.data() must return a pointer to a signed "
94                 "64-bit integer type.");
95   return absl::Span<const int64_t>(
96       reinterpret_cast<const int64_t*>(repeated_field.data()),
97       repeated_field.size());
98 }
99 template <typename T>
100 inline absl::Span<int64_t> AsInt64Slice(T* repeated_field) {
101   using data_ty =
102       typename std::remove_reference<decltype(*repeated_field->data())>::type;
103   static_assert(std::is_integral<data_ty>::value &&
104                     std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
105                 "repeated_field->data() must return a pointer to a signed "
106                 "64-bit integer type.");
107   return absl::Span<int64_t>(
108       reinterpret_cast<int64_t*>(repeated_field->mutable_data()),
109       repeated_field->size());
110 }
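// Illustrative usage sketch (assumes a TensorDescriptorProto 'tensor' with a
// repeated int64 'dimensions' field, as used by the descriptors below):
//
//   absl::Span<const int64_t> dims = AsInt64Slice(tensor.dimensions());
//   absl::Span<int64_t> mutable_dims = AsInt64Slice(tensor.mutable_dimensions());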
111 
112 // Returns a string representation of the given data layout.
113 std::string DataLayoutString(DataLayout layout);
114 
115 // Specifies a quantization for activations in a given BatchDescriptor.
116 enum class QuantizedActivationMode {
117   k8Bit = 1,
118   k16Bit = 2,
119   k32Bit = 4,
120 };
121 
122 // Specifies the type of an RNN model.
123 enum class RnnMode {
124   kRnnRelu = 0,
125   kRnnTanh = 1,
126   kRnnLstm = 2,
127   kRnnGru = 3,
128 };
129 
130 // Specifies the RNN input mode and whether there is a linear transformation
131 // between the input state and the first layer hidden state.
132 enum class RnnInputMode {
133   kRnnLinearSkip = 0,
134   kRnnSkipInput = 1,
135 };
136 
137 // Specifies the number of directions used in an RNN model. When the
138 // bidirectional mode is used, the input states and output sequence contain
139 // data for both directions.
140 enum class RnnDirectionMode {
141   kRnnUnidirectional = 0,
142   kRnnBidirectional = 1,
143 };
144 
145 // Relevant to DepthToSpace and SpaceToDepth. This is the write layout when
146 // performing depth to space and the read layout when performing space to depth.
147 // It's specified with most-major dimension first and most-minor dimension last.
148 // In DepthToSpace, the D*M^2 values are read in and then, for DepthHeightWidth,
149 // written out to the output patch, by varying first width, then height, then
150 // depth. In C array format, it looks like [depth][height][width]. See
151 // DepthToSpace comment for more information.
152 enum class DepthToSpaceLayout { DepthHeightWidth };
153 
154 // Specifies the descriptor for a RNN model.
155 //
156 // An example use case:
157 //   * The user first creates a model through createRnnDescriptor.
158 //   * The user queries the size of the underlying opaque parameter buffer.
159 //   * The user creates and initializes a parameter buffer of the proper size.
160 //   * The user runs forward and backward operations using this RNN descriptor.
161 //   * Once in a while, the user queries the maintainable weight and bias regions
162 //       from the underlying parameter buffer; they are more likely to be forward
163 //       compatible and should be used when saving and restoring a model (see below).
164 //   * The user releases the RNN descriptor when the model is no longer in use.
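//
// A rough sketch of that workflow (illustrative only; 'CreateRnnDescriptor' is
// a stand-in for whatever factory the platform's DnnSupport provides):
//
//   std::unique_ptr<dnn::RnnDescriptor> rnn = CreateRnnDescriptor(...);
//   int64_t params_bytes = rnn->ParamsSizeInBytes();
//   // ... allocate and initialize an opaque parameter buffer of that size ...
//   for (const auto& region : rnn->ParamsWeightRegions()) {
//     // region.offset and region.size identify one weight blob inside the
//     // buffer, e.g. for saving/restoring a model.
//   }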
165 class RnnDescriptor {
166  public:
167   struct ParamsRegion {
168     int64_t offset;
169     int64_t size;
170   };
171   typedef std::vector<ParamsRegion> ParamsRegions;
172   virtual ~RnnDescriptor() {}
173   virtual int64_t ParamsSizeInBytes() const { return -1; }
174   virtual ParamsRegions ParamsWeightRegions() const { return ParamsRegions(); }
175   virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
176 };
177 
178 // Specifies the sequence in a RNN model.
179 //
180 // The user is responsible for releasing this descriptor when it is no longer
181 // in use. The destructor releases the underlying descriptors.
182 class RnnSequenceTensorDescriptor {
183  public:
184   virtual ~RnnSequenceTensorDescriptor() {}
185 };
186 
187 // Specifies either the input or the hidden state in an RNN model.
188 //
189 // The user is responsible for releasing this descriptor when it is no longer
190 // in use. The destructor releases the underlying descriptors.
191 class RnnStateTensorDescriptor {
192  public:
193   virtual ~RnnStateTensorDescriptor() {}
194 };
195 
196 // Returns a string representation of the given quantization mode.
197 std::string QuantizedActivationModeString(QuantizedActivationMode mode);
198 
199 // Describes the dimensions that a layer consumes/produces.
200 //
201 // This is a matrix (height, width), its "depth" (feature_map_count),
202 // how many of these matrices are present (count),
203 // and the maximum and minimum values expected in the matrix (value_max,
204 // value_min).
205 // If the input is quantized, all values greater
206 // than value_max will be clipped to value_max and all values less than
207 // value_min will be clipped to value_min.
208 // When quantized output is dequantized, no value will be greater than
209 // value_max or less than value_min.
210 //
211 // Uses the named argument construction form:
212 //
213 //  auto input_batch_dimensions =
214 //      BatchDescriptor().set_count(42).set_feature_map_count(7)...
215 //
216 // Details:
217 //
218 // For a convolutional layer, a single inference takes a 3-dimensional matrix
219 // of input and produces a 3-dimensional matrix of output. We call the three
220 // dimensions height, width and feature_map_count, where for an image, the
221 // height and width correspond to the Y and X pixel indices, respectively, and
222 // the feature_map_count corresponds to the RGB dimension of the input data.
223 // Then the count indicates how many 3D matrices are being presented to be
224 // processed at once; this corresponds to the neural network concept of
225 // minibatch size.
226 //
227 // For a fully connected layer, it's better to put the nodes of the layer in
228 // the feature_map_count, and leave the height and width as degenerate (== 1).
229 // Count indicates how many input vectors (degenerate 3D matrices) are to be
230 // processed.
231 //
232 // If unspecified, value_max and value_min default to 0.0.
233 // If value_max == value_min the Stream will attempt to derive valid values -
234 // for example the output of Relu6 activation will always be in the range
235 // [0.0, 6.0].
236 //
237 // If unspecified, layout defaults to kYXDepthBatch.
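//
// A slightly fuller construction sketch (illustrative; values are arbitrary):
//
//   dnn::BatchDescriptor input;
//   input.set_count(32)             // minibatch size
//       .set_feature_map_count(3)   // e.g. RGB channels
//       .set_height(224)
//       .set_width(224)
//       .set_layout(dnn::DataLayout::kBatchDepthYX);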
238 class BatchDescriptor {
239  public:
240   // Creates a "blank" batch descriptor, which should be initialized via the
241   // named argument helpers.
242   BatchDescriptor();
243   explicit BatchDescriptor(int ndims);
244 
245   // Clones values from 'other' for initialization.
246   void CloneFrom(const BatchDescriptor& other);
247 
248   std::string ToString() const;
249   std::string ToShortString() const;
250 
251   // Pre-condition:
252   //   value_max_ == 0
253   //   value_min_ == 0
254   //   quantized_activation_mode_ == QuantizedActivationMode::k8Bit
255   TensorDescriptorProto ToProto(DataType data_type) const;
256 
257   // Accessors.
258   int64_t count() const { return tensor_.dimensions(0); }
259   int64_t feature_map_count() const { return tensor_.dimensions(1); }
260   int64_t height() const { return GetDim(spatial_size(), DimIndex::Y); }
261   int64_t width() const { return GetDim(spatial_size(), DimIndex::X); }
262   int64_t spatial_dim(DimIndex dim) const {
263     return GetDim(spatial_size(), dim);
264   }
265   int ndims() const { return spatial_size().size(); }
266   float value_max() const { return value_max_; }
267   float value_min() const { return value_min_; }
268   DataLayout layout() const { return tensor_.data_layout(); }
269   QuantizedActivationMode quantized_activation_mode() const {
270     return quantized_activation_mode_;
271   }
272   // Full dimensions of the underlying data, ordered according to a specific
273   // layout.
274   std::vector<int64_t> full_dims(const DataLayout& layout) const;
275 
276   // Full strides of the underlying data, ordered according to a specific
277   // layout.
278   std::vector<int64_t> full_strides(const DataLayout& layout) const;
279 
280   // Vectorized dimensions, where dimension 'vector_dim' is reported as the number
281   // of vectors of length 'vector_size' rather than as the full number of elements.
282   std::vector<int64_t> vectorized_dims(const DataLayout& layout,
283                                        int vector_size, int vector_dim) const;
284 
285   // Vectorized strides correspond to the vectorized_dims.
286   std::vector<int64_t> vectorized_strides(const DataLayout& layout,
287                                           int vector_size,
288                                           int vector_dim) const;
289 
290   // Named-argument helpers for avoiding user error during construction.
291   BatchDescriptor& set_count(int64_t value) {
292     tensor_.set_dimensions(0, value);
293     return *this;
294   }
295   BatchDescriptor& set_feature_map_count(int64_t value) {
296     tensor_.set_dimensions(1, value);
297     return *this;
298   }
299   BatchDescriptor& set_height(int64_t value) {
300     SetDim(spatial_size(), DimIndex::Y, value);
301     return *this;
302   }
303   BatchDescriptor& set_width(int64_t value) {
304     SetDim(spatial_size(), DimIndex::X, value);
305     return *this;
306   }
307   BatchDescriptor& set_spatial_dim(DimIndex dim, int64_t value) {
308     SetDim(spatial_size(), dim, value);
309     return *this;
310   }
311   BatchDescriptor& set_value_max(float value) {
312     value_max_ = value;
313     return *this;
314   }
315   BatchDescriptor& set_value_min(float value) {
316     value_min_ = value;
317     return *this;
318   }
319   BatchDescriptor& set_layout(DataLayout layout) {
320     tensor_.set_data_layout(layout);
321     return *this;
322   }
323   BatchDescriptor& set_quantized_activation_mode(
324       QuantizedActivationMode quantized_activation_mode) {
325     quantized_activation_mode_ = quantized_activation_mode;
326     return *this;
327   }
328 
329   // Return the number of nodes in a single feature map.
330   int64_t NodesPerFeatureMap() const;
331 
332   // Return the number of nodes across all feature maps. Note that this is not
333   // affected by the batch count.
334   int64_t NodesAcrossFeatureMaps() const;
335 
336   // Returns the number of elements (e.g. RGB pixel values) required to hold a
337   // given batch descriptor, given a no-padding assumption. Note that this is
338   // affected by the batch count.
339   int64_t ElementCount() const;
340 
341   // Return the number of weights required to fully connect a layer with
342   // dimensions given by the 'input' descriptor with a layer with dimensions
343   // given by the 'output' descriptor.
344   static int64_t FullyConnectedWeightCount(const BatchDescriptor& input,
345                                            const BatchDescriptor& output);
346 
347   // Return the number of biases required to fully connect to an output layer
348   // with dimensions given the 'output' descriptor.
349   static int64_t FullyConnectedBiasCount(const BatchDescriptor& output);
350 
351   // Return a BatchDescriptor for the output of a depth concatenation
352   // with the given input descriptors. The inputs should have the same
353   // dimensions, except possibly for feature_map_count(), though this
354   // function does not verify that.
355   static BatchDescriptor DepthConcatenateOutputDescriptor(
356       port::ArraySlice<dnn::BatchDescriptor> inputs);  // non-absl ok
357 
358  private:
359   absl::Span<const int64_t> spatial_size() const {
360     return AsInt64Slice(tensor_.dimensions()).subspan(2);
361   }
362 
363   absl::Span<int64_t> spatial_size() {
364     return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
365   }
366 
367   TensorDescriptorProto tensor_;
368   float value_max_;
369   float value_min_;
370   QuantizedActivationMode quantized_activation_mode_;
371 };
372 
373 // Returns a string representation of the given filter layout.
374 std::string FilterLayoutString(FilterLayout layout);
375 
376 // Describes a filter for the convolution. This is the "window" from
377 // height-by-width patches of each of the feature maps in the input layer to the
378 // cells within the output feature map.
379 //
380 // Uses the named argument construction form:
381 //
382 //  FilterDescriptor filter_dimensions;
383 //  filter_dimensions
384 //    .set_output_feature_map_count(42)
385 //    .set_input_feature_map_count(7)
386 //    ...
387 //
388 // Arguments:
389 // - output_feature_map_count: number of feature maps in the output layer.
390 // - input_feature_map_count: number of feature maps in the input layer (from
391 //      which the filter patch is taken).
392 // - input_filter_height: "height" number of neurons used in the sliding window
393 //      over the input layer.
394 // - input_filter_width: "width" number of neurons used in the sliding window
395 //      over the input layer.
396 //
397 // Sometimes names like "filter input height" are referred to by synonymous
398 // terminology, such as "kernel y size".
399 //
400 // If unspecified, layout defaults to kOutputInputYX.
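//
// A brief construction sketch (illustrative; a 3x3 filter mapping 3 input
// feature maps to 64 output feature maps):
//
//   dnn::FilterDescriptor filter;
//   filter.set_output_feature_map_count(64)
//       .set_input_feature_map_count(3)
//       .set_input_filter_height(3)
//       .set_input_filter_width(3);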
401 class FilterDescriptor {
402  public:
403   // By default construction, all dimensions are set to zero, so they should all
404   // be populated by the user via the named-argument helpers below. (See class
405   // comment for details.)
406   FilterDescriptor();
407   explicit FilterDescriptor(int ndims);
408   ~FilterDescriptor();
409 
410   // Named-argument helpers for avoiding user error during construction.
411   FilterDescriptor& set_output_feature_map_count(int64_t value) {
412     tensor_.set_dimensions(0, value);
413     return *this;
414   }
415   FilterDescriptor& set_input_feature_map_count(int64_t value) {
416     tensor_.set_dimensions(1, value);
417     return *this;
418   }
419   FilterDescriptor& set_input_filter_height(int64_t value) {
420     SetDim(input_filter_dims(), DimIndex::Y, value);
421     return *this;
422   }
423   FilterDescriptor& set_input_filter_width(int64_t value) {
424     SetDim(input_filter_dims(), DimIndex::X, value);
425     return *this;
426   }
427   FilterDescriptor& set_layout(FilterLayout layout) {
428     tensor_.set_filter_layout(layout);
429     return *this;
430   }
431   FilterDescriptor& set_spatial_dim(DimIndex dim, int64_t value) {
432     SetDim(input_filter_dims(), dim, value);
433     return *this;
434   }
435   int ndims() const { return input_filter_dims().size(); }
436 
437   void CloneFrom(const FilterDescriptor& other);
438 
439   std::string ToString() const;
440   std::string ToShortString() const;
441   TensorDescriptorProto ToProto(DataType data_type) const;
442 
443   // Returns the number of weights required as parameters for a convolution
444   // using this filter descriptor.
445   int64_t ComputeWeightCount() const;
446 
447   // Returns the number of biases required as parameters for a convolution
448   // using this filter descriptor.
449   int64_t bias_count() const { return output_feature_map_count(); }
450 
451   int64_t output_feature_map_count() const { return tensor_.dimensions(0); }
452   int64_t input_feature_map_count() const { return tensor_.dimensions(1); }
453   int64_t input_filter_height() const {
454     return GetDim(input_filter_dims(), DimIndex::Y);
455   }
456   int64_t input_filter_width() const {
457     return GetDim(input_filter_dims(), DimIndex::X);
458   }
459   int64_t input_filter_dim(DimIndex dim) const {
460     return GetDim(input_filter_dims(), dim);
461   }
462 
463   FilterLayout layout() const { return tensor_.filter_layout(); }
464 
465   absl::Span<const int64_t> input_filter_dims() const {
466     return AsInt64Slice(tensor_.dimensions()).subspan(2);
467   }
468 
469   // Full dimensions of the underlying filter,
470   // ordered according to a specific layout.
471   std::vector<int64_t> full_dims(const FilterLayout& layout) const;
472 
473   // Full strides of the underlying filter,
474   // ordered according to a specific layout.
475   std::vector<int64_t> full_strides(const FilterLayout& layout) const;
476 
477   // Vectorized dimensions, where dimension 'vector_dim' is reported as the number
478   // of vectors of length 'vector_size' rather than as the full number of elements.
479   std::vector<int64_t> vectorized_dims(const FilterLayout& layout,
480                                        int vector_size, int vector_dim) const;
481 
482   // Vectorized strides correspond to the vectorized_dims.
483   std::vector<int64_t> vectorized_strides(const FilterLayout& layout,
484                                           int vector_size,
485                                           int vector_dim) const;
486 
487  private:
488   absl::Span<int64_t> input_filter_dims() {
489     return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
490   }
491 
492   TensorDescriptorProto tensor_;
493 };
494 
495 // Describes how padding should be aligned when the total number of pad
496 // elements is odd.
497 enum class PadAlignment : int64_t {
498   kDefault = 0,        // default padding for the device.
499   kCudnnPadding,       // cuDNN padding - prefer to pad at the start.
500   kTensorFlowPadding,  // TensorFlow padding - prefer to pad at the end.
501 };
502 
503 // Returns a string representation of the given padding alignment.
504 std::string PadAlignmentString(PadAlignment alignment);
505 
506 // Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
507 std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
508 
509 // Describes a convolution.
510 //
511 // Uses the named argument construction form:
512 //
513 //  ConvolutionDescriptor convolution_dimensions;
514 //  convolution_dimensions
515 //    .set_vertical_filter_stride(2)
516 //    .set_horizontal_filter_stride(2)
517 //    ...
518 //
519 // Arguments:
520 // - zero_padding_height: padding of the "y dimension" of the input data. Note
521 //    that this is different from the height of the filter.
522 // - zero_padding_width: analogous to the height above, but in the "x
523 //    dimension".
524 // - vertical_filter_stride: the convolution slides a 2-dimensional window of
525 //    filter-height-by-filter-width over the input layer -- the center of that
526 //    window is moved in the "y dimension" according to this stride value.
527 // - horizontal_filter_stride: analogous to the vertical stride above, but in
528 //    the "x dimension".
529 // - vertical_dilation_rate: there will be (vertical_dilation_rate - 1) skipped
530 //   cells between each filter element in the "y dimension".
531 // - horizontal_dilation_rate: there will be (horizontal_dilation_rate - 1)
532 //   skipped cells between each filter element in the "x dimension".
533 // - convolution_not_crosscorr: By default (convolution_not_crosscorr == false),
534 //   we perform cross correlation rather than convolution. With the flag set,
535 //   we perform convolution. Convolution and cross correlation are related by
536 //   rotating the filter by 180 degrees (or equivalently flipping all spatial
537 //   dimensions).
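//
// A brief construction sketch (illustrative; one-pixel zero padding with a
// stride of 2 in both spatial dimensions):
//
//   dnn::ConvolutionDescriptor conv;
//   conv.set_zero_padding_height(1)
//       .set_zero_padding_width(1)
//       .set_vertical_filter_stride(2)
//       .set_horizontal_filter_stride(2);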
538 class ConvolutionDescriptor {
539  public:
540   // By default construction, there is no zero-padding and the filter stride is
541   // 1x1 (centering the filter on every cell in the input layer's
542   // width-by-height area).
543   ConvolutionDescriptor();
544   explicit ConvolutionDescriptor(int ndims);
545   ~ConvolutionDescriptor();
546 
547   std::string ToString() const;
548   std::string ToShortString() const;
549   ConvolutionDescriptorProto ToProto() const { return proto_; }
550 
551   ConvolutionDescriptor& set_zero_padding_height(int64_t value) {
552     SetDim(padding(), DimIndex::Y, value);
553     return *this;
554   }
555   ConvolutionDescriptor& set_zero_padding_width(int64_t value) {
556     SetDim(padding(), DimIndex::X, value);
557     return *this;
558   }
559   ConvolutionDescriptor& set_zero_padding(DimIndex dim, int64_t value) {
560     SetDim(padding(), dim, value);
561     return *this;
562   }
563   ConvolutionDescriptor& set_vertical_filter_stride(int64_t value) {
564     SetDim(strides(), DimIndex::Y, value);
565     return *this;
566   }
567   ConvolutionDescriptor& set_horizontal_filter_stride(int64_t value) {
568     SetDim(strides(), DimIndex::X, value);
569     return *this;
570   }
571   ConvolutionDescriptor& set_filter_stride(DimIndex dim, int64_t value) {
572     SetDim(strides(), dim, value);
573     return *this;
574   }
575   ConvolutionDescriptor& set_vertical_dilation_rate(int64_t value) {
576     SetDim(dilations(), DimIndex::Y, value);
577     return *this;
578   }
579   ConvolutionDescriptor& set_horizontal_dilation_rate(int64_t value) {
580     SetDim(dilations(), DimIndex::X, value);
581     return *this;
582   }
583   ConvolutionDescriptor& set_dilation_rate(DimIndex dim, int64_t value) {
584     SetDim(dilations(), dim, value);
585     return *this;
586   }
587   ConvolutionDescriptor& set_group_count(int group_count) {
588     proto_.set_group_count(group_count);
589     return *this;
590   }
591   ConvolutionDescriptor& set_convolution_not_crosscorr(bool conv) {
592     proto_.set_convolution_mode(conv ? ConvolutionMode::CONVOLUTION
593                                      : ConvolutionMode::CROSS_CORRELATION);
594     return *this;
595   }
596   ConvolutionDescriptor& set_name(const std::string& name) {
597     proto_.set_name(name);
598     return *this;
599   }
600   int64_t zero_padding_height() const { return GetDim(padding(), DimIndex::Y); }
601   int64_t zero_padding_width() const { return GetDim(padding(), DimIndex::X); }
602   int64_t vertical_filter_stride() const {
603     return GetDim(strides(), DimIndex::Y);
604   }
605   int64_t horizontal_filter_stride() const {
606     return GetDim(strides(), DimIndex::X);
607   }
608   int64_t vertical_dilation_rate() const {
609     return GetDim(dilations(), DimIndex::Y);
610   }
611   int64_t horizontal_dilation_rate() const {
612     return GetDim(dilations(), DimIndex::X);
613   }
614 
615   int zero_padding(DimIndex dim) const { return GetDim(padding(), dim); }
616   int filter_stride(DimIndex dim) const { return GetDim(strides(), dim); }
617   int dilation_rate(DimIndex dim) const { return GetDim(dilations(), dim); }
618   // TODO(timshen): remove this function. No user of this class sets a
619   // non-default pad alignment.
620   PadAlignment pad_alignment() const { return PadAlignment::kDefault; }
621   int group_count() const { return proto_.group_count(); }
622   int ndims() const { return padding().size(); }
623   bool convolution_not_crosscorr() const {
624     return proto_.convolution_mode() == ConvolutionMode::CONVOLUTION;
625   }
626 
627   absl::Span<const int64_t> strides() const {
628     return AsInt64Slice(proto_.strides());
629   }
630 
631   absl::Span<const int64_t> dilations() const {
632     return AsInt64Slice(proto_.dilations());
633   }
634 
635   absl::Span<const int64_t> padding() const {
636     return AsInt64Slice(proto_.paddings());
637   }
638 
639   std::string name() const { return proto_.name(); }
640 
641  private:
642   absl::Span<int64_t> strides() {
643     return AsInt64Slice(proto_.mutable_strides());
644   }
645 
646   absl::Span<int64_t> dilations() {
647     return AsInt64Slice(proto_.mutable_dilations());
648   }
649 
650   absl::Span<int64_t> padding() {
651     return AsInt64Slice(proto_.mutable_paddings());
652   }
653 
654   ConvolutionDescriptorProto proto_;
655 
656   // TODO(leary) cudnn provides these fields, but need to characterize what
657   // their effect is -- they may be boolean rather than integral.
658   // int64_t upscale_input_x;
659   // int64_t upscale_input_y;
660 };
661 
662 // A patch of values in the input can be pooled via either a max or an average
663 // operation.
664 // Specify int64_t so there's no padding in PoolingDescriptor.
665 enum class PoolingMode : int64_t {
666   kMaximum,
667   kAverage,
668 };
669 
670 // Specify the dimension in which to concatenate inputs in space.
671 // Specify int64_t so there's no padding in SpaceConcatenateMode.
672 enum class SpaceConcatenateMode : int64_t {
673   XDirection,
674   YDirection,
675 };
676 
677 // Returns a short name for the pooling mode, e.g. "Avg".
678 std::string ShortPoolingModeString(PoolingMode mode);
679 
680 // Describes a pooling operation to be enqueued onto a stream via a platform's
681 // DnnSupport.
682 //
683 // TODO(broune): describe how padding works and what happens if the
684 // window height/width is not divisible by the vertical/horizontal
685 // stride.
686 //
687 // Arguments:
688 //  pooling_mode: pooling operator to use on the input patch
689 //  window_height: height of input window
690 //  window_width: width of input window
691 //  vertical_stride: vertical delta for center of the input patch
692 //  horizontal_stride: horizontal delta for center of the input patch
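//
// A brief construction sketch (illustrative; 2x2 max pooling with stride 2 and
// no padding):
//
//   dnn::PoolingDescriptor pool;
//   pool.set_pooling_mode(dnn::PoolingMode::kMaximum)
//       .set_window_height(2)
//       .set_window_width(2)
//       .set_vertical_stride(2)
//       .set_horizontal_stride(2);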
693 class PoolingDescriptor {
694  public:
695   PoolingDescriptor();
696   explicit PoolingDescriptor(int ndims);
697 
698   PoolingDescriptor& set_pooling_mode(PoolingMode value) {
699     mode_ = value;
700     return *this;
701   }
702   PoolingDescriptor& set_window_height(int64_t value) {
703     SetDim(&window_, DimIndex::Y, value);
704     return *this;
705   }
706   PoolingDescriptor& set_window_width(int64_t value) {
707     SetDim(&window_, DimIndex::X, value);
708     return *this;
709   }
710   PoolingDescriptor& set_window(DimIndex dim, int64_t value) {
711     SetDim(&window_, dim, value);
712     return *this;
713   }
714   PoolingDescriptor& set_vertical_padding(int64_t value) {
715     SetDim(&padding_, DimIndex::Y, value);
716     return *this;
717   }
718   PoolingDescriptor& set_horizontal_padding(int64_t value) {
719     SetDim(&padding_, DimIndex::X, value);
720     return *this;
721   }
722   PoolingDescriptor& set_padding(DimIndex dim, int64_t value) {
723     SetDim(&padding_, dim, value);
724     return *this;
725   }
726   PoolingDescriptor& set_vertical_stride(int64_t value) {
727     SetDim(&strides_, DimIndex::Y, value);
728     return *this;
729   }
730   PoolingDescriptor& set_horizontal_stride(int64_t value) {
731     SetDim(&strides_, DimIndex::X, value);
732     return *this;
733   }
734   PoolingDescriptor& set_stride(DimIndex dim, int64_t value) {
735     SetDim(&strides_, dim, value);
736     return *this;
737   }
738   PoolingDescriptor& set_propagate_nans(bool value) {
739     propagate_nans_ = value;
740     return *this;
741   }
742   PoolingDescriptor& set_name(const std::string& name) {
743     name_ = name;
744     return *this;
745   }
746 
747   int ndims() const { return ndims_; }
748   void CloneFrom(const PoolingDescriptor& other);
749 
750   std::string ToString() const;
751   std::string ToShortString() const;
752 
753   PoolingMode mode() const { return mode_; }
754   int64_t window_height() const { return GetDim(window_, DimIndex::Y); }
755   int64_t window_width() const { return GetDim(window_, DimIndex::X); }
756   int64_t window(DimIndex dim) const { return GetDim(window_, dim); }
757   int64_t vertical_padding() const { return GetDim(padding_, DimIndex::Y); }
758   int64_t horizontal_padding() const { return GetDim(padding_, DimIndex::X); }
759   int64_t padding(DimIndex dim) const { return GetDim(padding_, dim); }
760   int64_t vertical_stride() const { return GetDim(strides_, DimIndex::Y); }
761   int64_t horizontal_stride() const { return GetDim(strides_, DimIndex::X); }
762   int64_t stride(DimIndex dim) const { return GetDim(strides_, dim); }
763   absl::Span<const int64_t> window() const { return window_; }
764   absl::Span<const int64_t> padding() const { return padding_; }
765   absl::Span<const int64_t> strides() const { return strides_; }
766   bool propagate_nans() const { return propagate_nans_; }
767   std::string name() const { return name_; }
768 
769  private:
770   PoolingMode mode_;
771   int ndims_;
772   bool propagate_nans_;
773   std::string name_;  // Name as in Tensorflow NodeDef, for debugging purposes.
774 
775   // Stored as: ..., y, x.
776   std::vector<int64_t> window_;
777   std::vector<int64_t> padding_;
778   std::vector<int64_t> strides_;
779 };
780 
781 // Collects parameters for DNN algorithms
782 class AlgorithmDesc {
783  public:
784   typedef int64_t Index;
785   AlgorithmDesc() : AlgorithmDesc(0, false, std::nullopt) {}
786   explicit AlgorithmDesc(AlgorithmProto proto) : proto_(std::move(proto)) {}
787   AlgorithmDesc(Index algo_id, bool use_tensor_ops)
788       : AlgorithmDesc(algo_id, use_tensor_ops, std::nullopt) {}
789   AlgorithmDesc(Index algo_id, bool use_tensor_ops,
790                 std::optional<uint64_t> workspace_size) {
791     proto_.set_is_cudnn_frontend(false);
792     proto_.set_algo_id(algo_id);
793     proto_.set_math_type(use_tensor_ops ? AlgorithmProto::TENSOR_OP_MATH
794                                         : AlgorithmProto::DEFAULT_MATH);
795     if (workspace_size) {
796       proto_.mutable_workspace_size()->set_value(*workspace_size);
797     }
798   }
799   AlgorithmDesc(int64_t engine_id,
800                 const std::vector<std::pair<int64_t, int64_t>>& tuning_knobs,
801                 std::optional<uint64_t> workspace_size);
802   bool is_cudnn_frontend() const { return proto_.is_cudnn_frontend(); }
803 
804   bool tensor_ops_enabled() const {
805     return proto_.math_type() == AlgorithmProto::TENSOR_OP_MATH;
806   }
807   std::optional<uint64_t> workspace_size() const {
808     if (proto_.has_workspace_size()) {
809       return proto_.workspace_size().value();
810     }
811     return std::nullopt;
812   }
813   Index algo_id() const { return proto_.algo_id(); }
814 
815   std::vector<std::pair<int64_t, int64_t>> TuningKnobs() const;
816 
817   bool operator==(const AlgorithmDesc& other) const;
818 
819   uint64_t hash() const;
820 
821   AlgorithmProto ToProto() const { return proto_; }
822 
823   std::string ToString() const;
824 
825  private:
826   AlgorithmProto proto_;
827 };
828 
829 // Describes the result from a perf experiment.
830 //
831 // Arguments:
832 //  algorithm: returns the exact algorithm that was used.
833 //  elapsed_time_in_ms: returns the measured elapsed time in milliseconds.
834 class ProfileResult {
835  public:
836   bool is_valid() const {
837     return algorithm_.has_value() &&
838            elapsed_time_in_ms() != std::numeric_limits<float>::max();
839   }
840 
841   AlgorithmDesc algorithm() const { return *algorithm_; }
842   void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
843 
844   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
845   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
846 
847   size_t scratch_size() const { return scratch_size_; }
848   void set_scratch_size(size_t val) { scratch_size_ = val; }
849 
850  private:
851   std::optional<AlgorithmDesc> algorithm_;
852   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
853   // The scratch size algorithm_ requires. Currently it's only populated by
854   // convolutions.
855   size_t scratch_size_ = 0;
856 };
857 
858 // Backend-specific data shared between repeated launches of the same
859 // convolution.
860 template <typename Sig>
861 class OpRunner;
862 
863 // An abstract class owning cached state for a particular op/configuration.
864 //
865 // The primary motivation for this is cuDNN backend ExecutionPlans, which are
866 // costly to recreate.
867 //
868 // All OpRunners must be outlived by their parent Stream.
869 template <typename... Args>
870 class OpRunner<void(Args...)> {
871  public:
872   virtual ~OpRunner() {}
873 
874   // Get a description of the runner, for uniqueness of autotune entries.
875   //
876   // Since this is used to determine whether runners are equivalent for the
877   // purpose of scoring autotune entries, it shall be unique among runners of
878   // the same op and parameters.
879   virtual std::string ToString() const = 0;
880 
881   // Get the number of bytes of scratch space needed for `operator()`.
882   //
883   // If determining the workspace size can fail, runners should precompute and
884   // cache it at construction time.
885   virtual size_t GetWorkspaceSize() const = 0;
886 
887   // Convert to an AlgorithmDesc for AoT compilation or autotuning.
888   virtual port::StatusOr<AlgorithmDesc> ToAlgorithmDesc() const = 0;
889 
890   // Launch the operation, with the signature determined by `Sig`.
891   virtual port::Status operator()(Stream*, ProfileResult*,
892                                   DeviceMemoryBase scratch_memory,
893                                   Args... args) const = 0;
894 };
895 
896 using ConvSignature = void(DeviceMemoryBase /* input_data */,
897                            DeviceMemoryBase /* filter_data */,
898                            DeviceMemoryBase /* output_data */);
899 using ConvRunner = OpRunner<ConvSignature>;
900 
901 using FusedConvSignature = void(DeviceMemoryBase /* input_data */,
902                                 DeviceMemoryBase /* filter_data */,
903                                 DeviceMemoryBase /* side_input_data */,
904                                 DeviceMemoryBase /* bias_data */,
905                                 DeviceMemoryBase /* output_data */);
906 using FusedConvRunner = OpRunner<FusedConvSignature>;
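// An illustrative invocation sketch (assumes a ConvRunner 'runner' obtained
// from the platform backend, a scratch buffer of runner->GetWorkspaceSize()
// bytes, and pre-allocated device buffers):
//
//   port::Status status = (*runner)(stream, /*profile_result=*/nullptr,
//                                   scratch_memory, input_data, filter_data,
//                                   output_data);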
907 
908 // Describes the configuration for the algorithms that will be used.
909 //
910 // Arguments:
911 //  algorithm: the primary algorithm that should be used.
912 //  algorithm_no_scratch: a secondary algorithm that should be used, if the
913 //    allocation for the scratch memory fails.
914 //  scratch_size: specify the size of scratch memory in bytes needed for the
915 //    algorithm used.
916 //
917 // On CUDA platform with CUDNN library, algorithm and algorithm_no_scratch
918 // would be used. On ROCm platform with MIOpen library, algorithm and
919 // scratch_size would be used. The major difference between the two platforms
920 // is whether it's possible to get an algorithm without scratch memory. On
921 // CUDA + CUDNN it's possible, and algorithm_no_scratch can be used to track
922 // such information, whereas on ROCm + MIOpen there is no guarantee of getting
923 // one without scratch memory, and the scratch_size field is used to track it.
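//
// An illustrative sketch (assumes 'desc' and 'desc_no_scratch' are
// AlgorithmDesc values produced by autotuning):
//
//   dnn::AlgorithmConfig config(desc, /*scratch_size=*/1 << 20,
//                               desc_no_scratch);
//   AlgorithmConfigProto proto = config.ToProto();  // e.g. for serialization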
924 class AlgorithmConfig {
925  public:
926   AlgorithmConfig() {}
927   explicit AlgorithmConfig(AlgorithmDesc algorithm) : algorithm_(algorithm) {}
928   AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size)
929       : algorithm_(algorithm), scratch_size_(scratch_size) {}
930   AlgorithmConfig(AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch)
931       : algorithm_(algorithm), algorithm_no_scratch_(algorithm_no_scratch) {}
932   AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size,
933                   AlgorithmDesc algorithm_no_scratch)
934       : algorithm_(algorithm),
935         algorithm_no_scratch_(algorithm_no_scratch),
936         scratch_size_(scratch_size) {}
937 
938   // TODO(ruochengw): After cl/380702564, add support for algorithm configs with
939   // cuDNN Frontend APIs.
940   explicit AlgorithmConfig(const AlgorithmConfigProto& algorithm_config_proto) {
941     const AlgorithmProto& algorithm_proto = algorithm_config_proto.algorithm();
942     algorithm_ = AlgorithmDesc(algorithm_proto);
943     if (algorithm_config_proto.optional_scratch_size_case() !=
944         /*ONEOF_NAME_NOT_SET=*/0) {
945       scratch_size_ = algorithm_config_proto.scratch_size();
946     }
947     if (algorithm_config_proto.optional_algorithm_no_scratch_case() !=
948         /*ONEOF_NAME_NOT_SET=*/0) {
949       const AlgorithmProto& algorithm_no_scratch_proto =
950           algorithm_config_proto.algorithm_no_scratch();
951       algorithm_no_scratch_ = AlgorithmDesc(algorithm_no_scratch_proto);
952     }
953   }
954 
955   std::optional<AlgorithmDesc> algorithm() const { return algorithm_; }
956   void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
957   std::optional<AlgorithmDesc> algorithm_no_scratch() const {
958     return algorithm_no_scratch_;
959   }
960   void set_algorithm_no_scratch(AlgorithmDesc val) {
961     algorithm_no_scratch_ = val;
962   }
963   std::optional<size_t> scratch_size() const { return scratch_size_; }
964   void set_scratch_size(size_t val) { scratch_size_ = val; }
965   bool operator==(const AlgorithmConfig& other) const {
966     return this->algorithm_ == other.algorithm_ &&
967            this->algorithm_no_scratch_ == other.algorithm_no_scratch_ &&
968            this->scratch_size_ == other.scratch_size_;
969   }
970   bool operator!=(const AlgorithmConfig& other) const {
971     return !(*this == other);
972   }
973   std::string ToString() const;
974 
975   // TODO(ruochengw): After cl/380702564, add support for algorithm configs with
976   // cuDNN Frontend APIs.
977   AlgorithmConfigProto ToProto() const {
978     AlgorithmConfigProto algorithm_config_proto;
979     if (algorithm_.has_value()) {
980       *algorithm_config_proto.mutable_algorithm() =
981           algorithm_.value().ToProto();
982     }
983     if (algorithm_no_scratch_.has_value()) {
984       *algorithm_config_proto.mutable_algorithm_no_scratch() =
985           algorithm_no_scratch_.value().ToProto();
986     }
987     if (scratch_size_.has_value()) {
988       algorithm_config_proto.set_scratch_size(scratch_size_.value());
989     }
990     return algorithm_config_proto;
991   }
992 
993  private:
994   std::optional<AlgorithmDesc> algorithm_;
995   std::optional<AlgorithmDesc> algorithm_no_scratch_;
996   std::optional<size_t> scratch_size_;
997 };
998 
999 // Describes a local response normalization (LRN). LRN is used e.g. in
1000 // dist_belief.
1001 //
1002 // Let V be the vector of feature maps at some (batch, y, x)
1003 // coordinate. LRN applies independently to each vector V in the
1004 // input, across all coordinates (batch, y, x), by mapping each V to
1005 // another vector U of the same size using the formula
1006 //
1007 //   U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
1008 //
1009 // where the sum is taken over j in the closed range [i - range, i + range].
1010 //
1011 // When calculating U_i the j in the sum can extend beyond the bounds
1012 // of V. If wrap_around is true, then V_j = V_{j mod F} where F is the
1013 // size of V, which is the number of feature maps. If wrap_around is
1014 // false, then V_j = 0 for j outside [0, F-1].
1015 //
1016 // If segment_size <= F, where F is the number of feature_maps, then
1017 // segment_size has no effect. Otherwise, each consecutive segment of
1018 // segment_size entries in V are normalized separately.
1019 //
1020 // Not all StreamExecutors allow wrap_around == true or segment_size
1021 // != 64. Some do not implement normalization at all.
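//
// A brief construction sketch mapping the formula's parameters onto the
// descriptor (illustrative values only):
//
//   dnn::NormalizeDescriptor lrn;
//   lrn.set_bias(1.0f)        // 'bias' in the formula above
//       .set_alpha(1e-4f)     // 'alpha'
//       .set_beta(0.75f)      // 'beta'
//       .set_range(5)         // sum over j in [i - range, i + range]
//       .set_wrap_around(false);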
1022 class NormalizeDescriptor {
1023  public:
1024   NormalizeDescriptor();
1025 
1026   NormalizeDescriptor& set_bias(float bias) {
1027     bias_ = bias;
1028     return *this;
1029   }
1030 
1031   NormalizeDescriptor& set_range(int32_t range) {
1032     range_ = range;
1033     return *this;
1034   }
1035 
1036   NormalizeDescriptor& set_alpha(float alpha) {
1037     alpha_ = alpha;
1038     return *this;
1039   }
1040 
1041   NormalizeDescriptor& set_beta(float beta) {
1042     beta_ = beta;
1043     return *this;
1044   }
1045 
1046   NormalizeDescriptor& set_wrap_around(bool wrap_around) {
1047     wrap_around_ = wrap_around;
1048     return *this;
1049   }
1050 
1051   NormalizeDescriptor& set_segment_size(int32_t segment_size) {
1052     segment_size_ = segment_size;
1053     return *this;
1054   }
1055 
1056   void CloneFrom(const NormalizeDescriptor& other);
1057 
1058   std::string ToString() const;
1059   std::string ToShortString() const;
1060 
1061   float bias() const { return bias_; }
1062   int32 range() const { return range_; }
1063   float alpha() const { return alpha_; }
1064   float beta() const { return beta_; }
1065   bool wrap_around() const { return wrap_around_; }
1066   int32 segment_size() const { return segment_size_; }
1067 
1068  private:
1069   float bias_;
1070   int32 range_;
1071   float alpha_;
1072   float beta_;
1073   bool wrap_around_;
1074   int32 segment_size_;
1075 };
1076 
1077 // Returns a string representation of the given activation mode.
1078 std::string ActivationModeString(ActivationMode mode);
1079 
1080 // Describes the operation that DoElementwiseOperation should perform on its
1081 // inputs.
1082 enum class ElementwiseOperation { kAdd, kMultiply };
1083 
1084 std::string ElementwiseOperationString(ElementwiseOperation op);
1085 
1086 // A simple class representing the version of the backing library, to
1087 // workaround the "too perfect forwarding" issue in gcc6+ compilers.
1088 // See PR#16309 and issue #18402 for links discussing the issue.
1089 class VersionInfo {
1090  public:
1091   VersionInfo(int major = 0, int minor = 0, int patch = 0)
1092       : major_(major), minor_(minor), patch_(patch) {}
1093   int major_version() const { return major_; }
1094   int minor_version() const { return minor_; }
1095   int patch() const { return patch_; }
1096 
1097  private:
1098   int major_;
1099   int minor_;
1100   int patch_;
1101 };
1102 
1103 // Suite of operations typically used for implementing Deep/Convolutional Neural
1104 // Nets. Note: A false return value of an operation indicates the
1105 // implementation is not available.
1106 //
1107 // TODO(b/118763918): this class (or rather dispatch table) has several
1108 // problems:
1109 // * Some overloads are missing. Ideally we want to have template virtual
1110 //   functions where the template arguments form a closed set. However, we don't
1111 //   get that from the language.
1112 // * The API is a union of cuDNN and another private backend. Only 10% of the
1113 //   functions are actually implemented by both backends; the rest are
1114 //   backend-specific. The massive interface creates extra mental
1115 //   burden.
1116 // * Poor error handling: the API should return Status objects.
1117 //
1118 // PrepareForConvolution is an example for how new APIs should be written.
1119 class DnnSupport {
1120  public:
1121   DnnSupport() {}
1122   virtual ~DnnSupport() {}
1123 
1124   virtual port::Status Init() = 0;
1125 
1126   // Gets the version of the backing library, as a VersionInfo object.
1127   virtual port::StatusOr<VersionInfo> GetVersion() {
1128     return port::UnimplementedError(
1129         "DnnSupport::GetVersion not implemented on this platform.");
1130   }
1131 
1132   // Performs a single-precision forward batch normalization operation onto
1133   // the stream.
1134   //
1135   // Arguments:
1136   //  stream: borrowed pointer to the stream that the batch normalization
1137   //    operation should be enqueued onto.
1138   //  x: input data.
1139   //  scale: scaling parameters.
1140   //  offset: offset parameters.
1141   //  estimated_mean: population mean estimated during training.
1142   //    Used for inference only; empty for training.
1143   //  estimated_variance: population variance estimated during training,
1144   //    used for inference only; empty for training.
1145   //  side_input: optional input that is element-wise added to the output of
1146   //    batch normalization.
1147   //  x_desc: dimensions of the input data, which is the same as the dimensions
1148   //    of the output and side input.
1149   //  scale_offset_desc: dimensions of scale and offset.
1150   //  epsilon: a small floating point number added to the variance of x.
1151   //  activation_mode: activation applied to the result of batch normalization
1152   //    (or after adding optional side input)
1153   //  y: output data.
1154   //  batch_mean: batch mean, to be used to compute the running mean.
1155   //  batch_variance: batch variance, to be used to compute
1156   //    the running variance.
1157   //  reserve_space_1: saved mean, to be reused in the backward gradient
1158   //    computation.
1159   //  reserve_space_2: saved inv_var (1/sqrt(epsilon + variance), to be reused
1160   //    in the backward gradient computation.
1161   //  is_training: Set to true for training, false for inference.
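  //
  // An illustrative call sketch (inference mode; assumes the buffers and
  // descriptors above are already set up):
  //
  //   bool ok = dnn->DoBatchNormalizationForward(
  //       stream, x, scale, offset, estimated_mean, estimated_variance,
  //       side_input, x_desc, scale_offset_desc, /*epsilon=*/1e-3,
  //       /*exponential_average_factor=*/1.0, dnn::ActivationMode::kNone,
  //       &y, &batch_mean, &batch_var, &reserve_space_1, &reserve_space_2,
  //       /*is_training=*/false, /*reserve_space_allocator=*/nullptr,
  //       /*workspace_allocator=*/nullptr);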
1162   virtual bool DoBatchNormalizationForward(
1163       Stream* stream, const DeviceMemory<float>& x,
1164       const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1165       const DeviceMemory<float>& estimated_mean,
1166       const DeviceMemory<float>& estimated_variance,
1167       const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
1168       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1169       const double exponential_average_factor,
1170       dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
1171       DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1172       DeviceMemory<float>* reserve_space_1,
1173       DeviceMemory<float>* reserve_space_2, bool is_training,
1174       ScratchAllocator* reserve_space_allocator,
1175       ScratchAllocator* workspace_allocator) {
1176     return false;
1177   }
1178 
1179   // Performs a half-precision forwards batch normalization operation onto the
1180   // stream. See DoBatchNormalizationForward above for argument details.
1181   virtual bool DoBatchNormalizationForward(
1182       Stream* stream, const DeviceMemory<Eigen::half>& x,
1183       const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1184       const DeviceMemory<float>& estimated_mean,
1185       const DeviceMemory<float>& estimated_variance,
1186       const DeviceMemory<Eigen::half>& side_input,
1187       const dnn::BatchDescriptor& x_desc,
1188       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1189       const double exponential_average_factor,
1190       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
1191       DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1192       DeviceMemory<float>* reserve_space_1,
1193       DeviceMemory<float>* reserve_space_2, bool is_training,
1194       ScratchAllocator* reserve_space_allocator,
1195       ScratchAllocator* workspace_allocator) {
1196     return false;
1197   }
1198 
1199   // Performs a single-precision backward batch normalization gradient
1200   // computation operation onto the stream.
1201   //
1202   // Arguments:
1203   //  stream: borrowed pointer to the stream that the batch normalization
1204   //    gradient computation operation should be enqueued onto.
1205   //  y_backprop: gradient with regard to output y.
1206   //  x: input data.
1207   //  scale: scaling parameters.
1208   //  inv_var: 1/sqrt(epsilon + variance) of x.
1209   //  x_desc: dimensions of the input data, which is the same as the dimensions
1210   //    of the output.
1211   //  scale_offset_desc: dimensions of scale and offset.
1212   //  epsilon: a small floating point number added to the variance of x.
1213   //  x_backprop: gradient with respect to input x.
1214   //  scale_backprop: gradient with respect to scale.
1215   //  offset_backprop: gradient with respect to offset.
1216   virtual bool DoBatchNormalizationBackward(
1217       Stream* stream, const DeviceMemory<float>& y_backprop,
1218       const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
1219       const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
1220       const DeviceMemory<float>& inv_var, const DeviceMemory<float>& y,
1221       const dnn::BatchDescriptor& x_desc,
1222       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1223       dnn::ActivationMode activation_mode, DeviceMemory<float>* x_backprop,
1224       DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
1225       DeviceMemory<float>* side_input_backprop,
1226       DeviceMemory<uint8>* reserve_space_data,
1227       ScratchAllocator* workspace_allocator) {
1228     return false;
1229   }
1230 
1231   // Performs a half-precision backward batch normalization gradient computation
1232   // operation onto the stream. See DoBatchNormalizationBackward above for
1233   // argument details.
1234   virtual bool DoBatchNormalizationBackward(
1235       Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
1236       const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
1237       const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
1238       const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::half>& y,
1239       const dnn::BatchDescriptor& x_desc,
1240       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1241       dnn::ActivationMode activation_mode,
1242       DeviceMemory<Eigen::half>* x_backprop,
1243       DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
1244       DeviceMemory<Eigen::half>* side_input_backprop,
1245       DeviceMemory<uint8>* reserve_space_data,
1246       ScratchAllocator* workspace_allocator) {
1247     return false;
1248   }
1249 
1250   // Enqueues a fused convolution operation onto the stream.
1251   // We provide several variants with different types for inputs, biases and
1252   // scaling parameters.
1253   //
1254   // Arguments (all borrowed):
1255   //  stream: borrowed pointer to the stream that the 'convolve' operation
1256   //    should be enqueued onto.
1257   //  conv_input_descriptor: dimensions of the convolution input layer.
1258   //  conv_input_data: un-owned device memory region which contains the
1259   //    convolution input.
1260   //  conv_input_scale: a floating point scale to multiply with each element
1261   //    of conv_input_data.
1262   //  filter_descriptor: dimensions of the convolution filter.
1263   //  filter_data: un-owned device memory region which contains the
1264   //    convolution filter weights.
1265   //  convolution_descriptor: stride of the convolution filter.
1266   //  biases: un-owned device memory region containing biases to add to the
1267   //    convolution result.
1268   //  activation_mode: Type of activation to perform.
1269   //  side_input_data: un-owned device memory region which contains optional
1270   //    side input data. If 'side_input_scale' is non-zero, then this must
1271   //    point to data in the tensor shape specified by output_shape.
1272   //    It will be scaled by 'side_input_scale' and added to the convolution
1273   //    result and bias prior to applying the activation function.
1274   //  side_input_scale: a floating point scale to multiply with each element
1275   //    of side_input_data.
1276   //  output_descriptor: dimensions of the output layer.
1277   //  output_data: un-owned device memory region in which to place the
1278   //    convolution result.
1279   //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
1280   //    space in order to speed up the convolution operation.
1281   //  algorithm_config: specifies which algorithm should be used for the
1282   //    operation.
1283   //  output_profile_result: the output profile result for this call. The
1284   //    profiling is only enabled when this is not nullptr.
1285   //
1286   // conv_input_descriptor, filter_descriptor, convolution_descriptor and
1287   // output_descriptor together specify exactly how the convolution is aligned
1288   // with the input data:
1289   //
1290   // * (input dimensions - filter size + 1) / filter stride == output dimensions
1291   //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1292   // * input dimensions / filter stride == output dimensions
1293   //   corresponds to dist_belief padding = SAME, i.e. input and output are the
1294   //   same size - this requires padding the input.
1295   // * (input dimensions + filter size - 1) / filter stride == output dimensions
1296   //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
1297   //   that if the inverse of the filter is applied to the output in VALID mode
1298   //   the result is the same size as the input - this requires even more
1299   //   padding of the input.
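  //
  // Worked 1-D example of the three cases above (illustrative only): with an
  // input extent of 8, a filter size of 3, and a stride of 1,
  //   VALID: (8 - 3 + 1) / 1 = 6 output elements (no input padding);
  //   SAME:   8 / 1 = 8 output elements (input padded so sizes match);
  //   FULL:  (8 + 3 - 1) / 1 = 10 output elements (maximal input padding).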
1300   virtual port::Status DoFusedConvolve(
1301       Stream* stream, DataType input_type, DataType side_input_type,
1302       DataType bias_type, DataType output_type,
1303       const dnn::BatchDescriptor& conv_input_descriptor,
1304       DeviceMemoryBase conv_input_data, double conv_input_scale,
1305       const dnn::FilterDescriptor& filter_descriptor,
1306       DeviceMemoryBase filter_data,
1307       const dnn::ConvolutionDescriptor& convolution_descriptor,
1308       DeviceMemoryBase side_input_data, double side_input_scale,
1309       const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
1310       dnn::ActivationMode activation_mode,
1311       const dnn::BatchDescriptor& output_descriptor,
1312       DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
1313       const dnn::AlgorithmConfig& algorithm_config,
1314       dnn::ProfileResult* output_profile_result) {
1315     return port::UnimplementedError(
1316         "DnnSupport::DoFusedConvolve not implemented on this platform.");
1317   }
1318 
1319   template <typename ElementType, typename OutputType>
1320   port::Status PrepareForConvolution(
1321       ConvolutionKind kind, Stream* stream,
1322       const BatchDescriptor& batch_descriptor,
1323       DeviceMemory<ElementType> input_data,
1324       const FilterDescriptor& filter_descriptor,
1325       DeviceMemory<ElementType> filter_data,
1326       const BatchDescriptor& output_descriptor,
1327       DeviceMemory<OutputType> output_data,
1328       const ConvolutionDescriptor& convolution_descriptor,
1329       const AlgorithmConfig& algorithm_config,
1330       ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
1331       DeviceMemory<uint8>* scratch_memory) {
1332     return DoPrepareForConvolution(
1333         kind, ToDataType<ElementType>::value, stream, batch_descriptor,
1334         input_data, filter_descriptor, filter_data, output_descriptor,
1335         output_data, convolution_descriptor, algorithm_config,
1336         scratch_allocator, algorithm_desc, scratch_memory);
1337   }
1338 
1339   // Enqueues a single-precision convolution operation onto the stream.
1340   //
1341   // Arguments (all borrowed):
1342   //  stream: borrowed pointer to the stream that the 'convolve' operation
1343   //    should be enqueued onto.
1344   //  input_descriptor: dimensions of the input layer.
1345   //  input_data: un-owned device memory region which contains the
1346   //    convolution input.
1347   //  filter_descriptor: dimensions of the convolution filter.
1348   //  convolution_descriptor: stride of the convolution filter.
1349   //  output_descriptor: dimensions of the output layer.
1350   //  output_data: un-owned device memory region in which to place the
1351   //    convolution result.
1352   //  algorithm_desc: specifies which algorithm should be used for the
1353   //    operation.
1354   //  scratch: un-owned device memory for scratch space in order to speed up
1355   //    the convolution operation.
1356   //  output_profile_result: the output profile result for this call. The
1357   //    profiling is only enabled when this is not nullptr.
1358   //
1359   // input_descriptor, filter_descriptor, convolution_descriptor and
1360   // output_descriptor together specify exactly how the convolution is aligned
1361   // with the input data:
1362   //
1363   // * (input dimensions - filter size + 1) / filter stride == output dimensions
1364   //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1365   // * input dimensions / filter stride == output dimensions
1366   //   corresponds to dist_belief padding = SAME, i.e. input and output are the
1367   //   same size - this requires padding the input.
1368   // * (input dimensions + filter size - 1) / filter stride == output dimensions
1369   //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
1370   //   that if the inverse of the filter is applied to the output in VALID mode
1371   //   the result is the same size as the input - this requires even more
1372   //   padding of the input.
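  //
  // Illustrative sketch, not part of the API: one way a caller might pair
  // PrepareForConvolution with DoConvolve, assuming `dnn` points to a concrete
  // DnnSupport implementation and that the descriptors, device buffers,
  // algorithm_config and scratch_allocator were set up elsewhere.
  //
  //   dnn::AlgorithmDesc algo;
  //   DeviceMemory<uint8> scratch;
  //   TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
  //       dnn::ConvolutionKind::FORWARD, stream, input_desc, input_data,
  //       filter_desc, filter_data, output_desc, output_data, conv_desc,
  //       algorithm_config, scratch_allocator, &algo, &scratch));
  //   TF_RETURN_IF_ERROR(dnn->DoConvolve(
  //       dnn::ConvolutionKind::FORWARD, dnn::ToDataType<float>::value,
  //       dnn::ToDataType<float>::value, stream, input_desc, input_data,
  //       filter_desc, filter_data, output_desc, output_data, conv_desc, algo,
  //       scratch, /*output_profile_result=*/nullptr));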
1373   virtual port::Status DoConvolve(
1374       ConvolutionKind kind, DataType element_type, DataType output_type,
1375       Stream* stream, const BatchDescriptor& input_descriptor,
1376       DeviceMemoryBase input_data, const FilterDescriptor& filter_descriptor,
1377       DeviceMemoryBase filter_data, const BatchDescriptor& output_descriptor,
1378       DeviceMemoryBase output_data,
1379       const ConvolutionDescriptor& convolution_descriptor,
1380       AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
1381       ProfileResult* output_profile_result) = 0;
1382 
1383   // Return a list of algorithms supported by the forward convolution pass.
1384   // cuda_compute_capability is the compute capability of the device.
1385   virtual bool GetConvolveAlgorithms(
1386       CudaComputeCapability cuda_compute_capability,
1387       std::vector<AlgorithmDesc>* out_algorithms);
1388 
1389   virtual port::Status GetConvolveRunners(
1390       bool use_cudnn_frontend, dnn::ConvolutionKind kind,
1391       dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
1392       const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
1393       const dnn::FilterDescriptor& filter_descriptor,
1394       DeviceMemoryBase filter_data,
1395       const dnn::BatchDescriptor& output_descriptor,
1396       DeviceMemoryBase output_data,
1397       const dnn::ConvolutionDescriptor& convolution_descriptor,
1398       bool use_fallback, ScratchAllocator* scratch_allocator,
1399       std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans);
1400 
1401   virtual port::StatusOr<std::unique_ptr<const dnn::ConvRunner>>
1402   ConvolveRunnerFromDesc(
1403       Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
1404       dnn::ConvolutionKind kind, dnn::DataType element_type,
1405       dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
1406       const dnn::FilterDescriptor& filter_descriptor,
1407       const dnn::BatchDescriptor& output_descriptor,
1408       const dnn::ConvolutionDescriptor& convolution_descriptor);
1409 
1410   virtual port::Status GetFusedConvolveRunners(
1411       bool use_cudnn_frontend, dnn::ConvolutionKind kind,
1412       dnn::DataType element_type, dnn::DataType bias_type,
1413       dnn::DataType output_type, double conv_input_scale,
1414       double side_input_scale, double leakyrelu_alpha, Stream* stream,
1415       const dnn::BatchDescriptor& input_descriptor,
1416       const dnn::FilterDescriptor& filter_descriptor,
1417       const dnn::BatchDescriptor& bias_descriptor,
1418       const dnn::BatchDescriptor& output_descriptor,
1419       const dnn::ConvolutionDescriptor& convolution_descriptor,
1420       bool use_fallback, dnn::ActivationMode activation_mode,
1421       std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans);
1422 
1423   virtual port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
1424   FusedConvolveRunnerFromDesc(
1425       Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
1426       dnn::ConvolutionKind kind, dnn::DataType element_type,
1427       dnn::DataType bias_type, dnn::DataType output_type, double conv_scale,
1428       double side_input_scale, double leakyrelu_alpha,
1429       const dnn::BatchDescriptor& input_descriptor,
1430       const dnn::FilterDescriptor& filter_descriptor,
1431       const dnn::BatchDescriptor& bias_descriptor,
1432       const dnn::BatchDescriptor& output_descriptor,
1433       const dnn::ConvolutionDescriptor& convolution_descriptor,
1434       dnn::ActivationMode activation_mode);
1435 
1436   virtual bool GetMIOpenConvolveAlgorithms(
1437       dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
1438       const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
1439       const dnn::FilterDescriptor& filter_descriptor,
1440       DeviceMemoryBase filter_data,
1441       const dnn::BatchDescriptor& output_descriptor,
1442       DeviceMemoryBase output_data,
1443       const dnn::ConvolutionDescriptor& convolution_descriptor,
1444       ScratchAllocator* scratch_allocator,
1445       std::vector<ProfileResult>* out_algorithms);
1446 
1447   // Returns a list of supported rnn algorithms.
1448   virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);
1449 
1450   // Version of DoConvolve that uses pre-quantized 8 bit coefficients.
1451   // coefficient_scales specifies the scaling of each column of coefficients:
1452   // original float coefficient[row * num_columns + column] =
1453   //     quantized coefficient[row * num_columns + column] *
1454   //     coefficient_scales[column].
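  //
  // For example (illustrative only): with num_columns = 2 and
  // coefficient_scales = {0.5f, 0.25f}, a quantized coefficient of 6 stored
  // in column 1 dequantizes to 6 * 0.25f = 1.5f before the convolution is
  // applied.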
1455   virtual bool DoConvolveQuantized(
1456       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1457       const DeviceMemory<float>& input_data,
1458       const dnn::FilterDescriptor& filter_descriptor,
1459       const DeviceMemory<int8>& filter_coefficients,
1460       const DeviceMemory<float>& coefficient_scales,
1461       const dnn::ConvolutionDescriptor& convolution_descriptor,
1462       const dnn::BatchDescriptor& output_descriptor,
1463       DeviceMemory<float>* output_data) = 0;
1464 
1465   // Same as DoConvolveQuantized above, but with int16 filter coefficients.
1466   virtual bool DoConvolveQuantized(
1467       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1468       const DeviceMemory<float>& input_data,
1469       const dnn::FilterDescriptor& filter_descriptor,
1470       const DeviceMemory<int16>& filter_coefficients,
1471       const DeviceMemory<float>& coefficient_scales,
1472       const dnn::ConvolutionDescriptor& convolution_descriptor,
1473       const dnn::BatchDescriptor& output_descriptor,
1474       DeviceMemory<float>* output_data) = 0;
1475 
1476   // Variation of the above with the weight matrix split into two matrices.
1477   // first_weights: Coefficients of the first matrix.
1478   // second_weights: Coefficients of the second matrix.
1479   // depth_multiplier: specifies the columns of the first matrix and rows
1480   // of the second one - first_weights columns = depth_multiplier,
1481   // second_weights rows = depth_multiplier *
1482   //                       filter_descriptor.input_feature_map_count().
1483   // see go/separable for documentation on separable convolutions.
1484   virtual bool DoSeparableConvolve(
1485       Stream* stream, const BatchDescriptor& input_descriptor,
1486       const DeviceMemory<float>& input_data,
1487       const FilterDescriptor& filter_descriptor, int depth_multiplier,
1488       const DeviceMemory<float>& first_weights,
1489       const DeviceMemory<float>& second_weights,
1490       const ConvolutionDescriptor& convolution_descriptor,
1491       const BatchDescriptor& output_descriptor,
1492       DeviceMemory<float>* output_data) = 0;
1493 
1494   // Return a list of algorithms supported by the backward convolution pass for
1495   // data.
1496   virtual bool GetConvolveBackwardDataAlgorithms(
1497       CudaComputeCapability cuda_compute_capability,
1498       std::vector<AlgorithmDesc>* out_algorithms);
1499 
1500   // Return a list of algorithms supported by the backward convolution pass for
1501   // filters.
1502   virtual bool GetConvolveBackwardFilterAlgorithms(
1503       CudaComputeCapability cuda_compute_capability,
1504       std::vector<AlgorithmDesc>* out_algorithms);
1505 
1506   // Fully connects the "nodes" (float values) in input_data with
1507   // shape input_dimensions to output_data with output_dimensions
1508   // using provided weights. This is equivalent to computing a matrix
1509   // product, hence the name MatMul.
1510   //
1511   // A BatchDescriptor has four dimensions: batch, y, x, depth. Matrix products
1512   // happen in two dimensions. To get down to two dimensions, we consider the
1513   // input y, x and depth dimension as one combined dimension T. For now,
1514   // assume that the output height and width are 1 and let OD be the output
1515   // depth.
1516   //
1517   // There are three device memory buffers passed in to this
1518   // function. We can now view all three as matrices:
1519   //
1520   //   input_data: A batch x T matrix
1521   //   weights: A T x OD matrix
1522   //   output_data: A batch x OD matrix
1523   //
1524   // This function then computes the matrix product of input_data and
1525   // weights and writes the result into output_data.
1526   //
1527   // Here the weights buffer is in row major order, i.e. the first OD
1528   // entries in weights are the first row, the second OD entries in
1529   // weights are the second row and so on.
1530   //
1531   // The case for output width*height > 1 is more complicated. Let K =
1532   // OY * OX where OY is the output height and OX is the output
1533   // width. Then weights is divided into K sub-arrays W_i, for
1534   // i=0,...,K-1, that each represent a T x OD matrix. This function
1535   // then computes the K matrix multiplications of input_data with
1536   // each W_i. This creates K matrices with dimensions batch x
1537   // OD. These K matrices are concatenated horizontally to form one
1538   // larger matrix with dimensions batch x (K*OD); note that this is
1539   // not the same as concatenating the bytes of the matrices. The
1540   // combined matrix can then be interpreted as a tensor with
1541   // dimensions (batch, OY, OX, OD). If the output tensor format is
1542   // not kBatchYXDepth, this function would then need to arrange for
1543   // the output to be in the requested layout, if that is
1544   // supported. Note that the case K=1 is equivalent to the
1545   // description above. It is recommended to prefer the case K=1.
1546   //
1547   // Arguments (all borrowed):
1548   //  stream: borrowed pointer to the stream that the 'fully connect' operation
1549   //    should be enqueued onto.
1550   //  output_data: un-owned device memory region in which to place the
1551   //    fully connected result.
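  //
  // For example (illustrative only): with input_dimensions describing
  // batch = 32, height = 1, width = 1, depth = 512 (so T = 512) and
  // output_dimensions describing batch = 32, height = 1, width = 1,
  // depth = 10 (so OD = 10), weights holds a row-major 512 x 10 matrix and
  // output_data receives the 32 x 10 product of input_data and weights.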
1552   virtual bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
1553                         const DeviceMemory<float>& weights,
1554                         const dnn::BatchDescriptor& input_dimensions,
1555                         const dnn::BatchDescriptor& output_dimensions,
1556                         DeviceMemory<float>* output_data) = 0;
1557 
1558   // Version of DoMatMul that uses pre-quantized 8 bit weights.
1559   // weight_scales specifies the scaling of each column of weights:
1560   // original float weight[row * num_columns + column] =
1561   //     quantized_weight[row * num_columns + column] * weight_scales[column].
1562   virtual bool DoMatMulQuantized(Stream* stream,
1563                                  const DeviceMemory<float>& input_data,
1564                                  const DeviceMemory<int8>& quantized_weights,
1565                                  const DeviceMemory<float>& weight_scales,
1566                                  const dnn::BatchDescriptor& input_dimensions,
1567                                  const dnn::BatchDescriptor& output_dimensions,
1568                                  DeviceMemory<float>* output_data) = 0;
1569 
1570   // Version of DoMatMul that uses pre-quantized 16 bit weights.
1571   // weight_scales specifies the scaling of each column of weights:
1572   // original float weight[row * num_columns + column] =
1573   //     quantized_weight[row * num_columns + column] * weight_scales[column].
1574   virtual bool DoMatMulQuantized(Stream* stream,
1575                                  const DeviceMemory<float>& input_data,
1576                                  const DeviceMemory<int16>& quantized_weights,
1577                                  const DeviceMemory<float>& weight_scales,
1578                                  const dnn::BatchDescriptor& input_dimensions,
1579                                  const dnn::BatchDescriptor& output_dimensions,
1580                                  DeviceMemory<float>* output_data) = 0;
1581 
1582   // Adds biases to the feature maps in input_data producing
1583   // output_data. input_data can equal output_data, but must not
1584   // partially overlap it.
1585   //
1586   // Let K = count() * height() * width() and N = feature_map_count(), all
1587   // taken from 'dimensions'. Then input_data contains K*N values and biases
1588   // contains N values. We can thus logically consider input_data to
1589   // contain K vectors of N elements each. This function adds biases
1590   // to each of those K vectors.
1591   //
1592   // TODO(broune): This works differently when width() * height() > 1
1593   // and the call to ThenBiasAdd() follows a call to ThenMatMul(). In
1594   // that case there should be width() * height() *
1595   // feature_map_count() biases, but this is not implemented on all
1596   // StreamExecutors.
1597   //
1598   // Arguments (all borrowed):
1599   //  stream: borrowed pointer to the stream that the 'bias add' operation
1600   //    should be enqueued onto.
1601   //  input_data: un-owned device memory region containing the input.
1602   //  biases: un-owned device memory region containing biases to add to the
1603   //    input.
1604   //  dimensions: dimensions of input_data and output_data.
1605   //  output_data: un-owned device memory region in which to place the result.
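  //
  // For example (illustrative only): with dimensions describing count() = 2,
  // height() = 3, width() = 4, and feature_map_count() = 8, input_data holds
  // K = 2 * 3 * 4 = 24 vectors of N = 8 values each, and the same 8 biases
  // are added element-wise to every one of those 24 vectors.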
1606   virtual bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
1607                          const DeviceMemory<float>& biases,
1608                          const dnn::BatchDescriptor& dimensions,
1609                          DeviceMemory<float>* output_data) = 0;
1610 
1611   // Performs a forward pooling operation on input_data, writing to
1612   // output_data. See PoolingDescriptor for how to configure the
1613   // pooling operation.
1614   //
1615   // Pooling happens as a window that moves across the Y and X
1616   // dimensions of input_data, where each position of the window
1617   // yields one output value. E.g. for max pooling, the computed value
1618   // is the maximum element in the window. The operation is applied
1619   // independently to each batch and at each feature map (depth), so
1620   // that the output depth and feature_map_count are the same as for
1621   // the input. The output width and height can be different.
1622   //
1623   // See PoolingDescriptor for how to configure the pooling operation.
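  //
  // For example (illustrative only): max pooling a 4x4 (height x width) input
  // with a 2x2 window and a stride of 2 in both dimensions yields a 2x2
  // output, where each output element is the maximum of the corresponding
  // non-overlapping 2x2 input window; the batch and feature_map dimensions
  // are unchanged.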
1624   virtual port::Status DoPoolForward(
1625       DataType element_type, Stream* stream,
1626       const dnn::PoolingDescriptor& pooling_dimensions,
1627       const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
1628       const dnn::BatchDescriptor& output_dimensions,
1629       DeviceMemoryBase output_data, ScratchAllocator* workspace_allocator) = 0;
1630 
1631   // Performs differentiation of the pooling operation.
1632   virtual port::Status DoPoolBackward(
1633       DataType element_type, Stream* stream,
1634       const dnn::PoolingDescriptor& pooling_dimensions,
1635       const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
1636       const dnn::BatchDescriptor& output_dimensions,
1637       DeviceMemoryBase output_data, DeviceMemoryBase input_diff_data,
1638       DeviceMemoryBase output_diff_data,
1639       ScratchAllocator* workspace_allocator) = 0;
1640 
1641   // Applies local response normalization to the values from input_data and
1642   // writes the result to output_data.
1643   //
1644   // See comments on NormalizeDescriptor for a description of local response
1645   // normalization.
1646   virtual bool DoNormalizeWithDimensions(
1647       Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1648       const dnn::BatchDescriptor& dimensions,
1649       const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
1650     return false;
1651   }
1652 
1653   // Performs backpropagation for the normalization operation
1654   //
1655   // Given raw data, its corresponding normalized output, and a gradient of some
1656   // unspecified function with respect to the normalized variables, computes the
1657   // gradient of that unspecified function with respect to the raw variables.
1658   //
1659   // The normalized data input array is expected to match the output that would
1660   // be obtained by running the raw data input array through the DoNormalize
1661   // method above.
1662   //
1663   // See comments on NormalizeDescriptor for a description of local response
1664   // normalization.
1665   virtual bool DoNormalizeBackwardWithDimensions(
1666       Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1667       const dnn::BatchDescriptor& dimensions,
1668       const DeviceMemory<float>& raw_data,
1669       const DeviceMemory<float>& normalized_data,
1670       const DeviceMemory<float>& normalized_variable_gradient,
1671       DeviceMemory<float>* raw_variable_gradient,
1672       ScratchAllocator* workspace_allocator) {
1673     return false;
1674   }
1675 
1676   // Applies an activation function (see ActivationMode) to all of the values
1677   // held on the device in 'input_data', whose dimensions are described by
1678   // 'dimensions'.
1679   //
1680   // Arguments (all borrowed):
1681   //  stream: borrowed pointer to the stream that the 'activate' operation
1682   //    should be enqueued onto.
1683   //  activation_mode: Type of activation to perform.
1684   //  input_data: un-owned device memory region which contains the
1685   //    activate input.
1686   //  output_data: un-owned device memory region in which to place the
1687   //    activate result.
1688   virtual bool DoActivate(Stream* stream, ActivationMode activation_mode,
1689                           const BatchDescriptor& dimensions,
1690                           const DeviceMemory<float>& input_data,
1691                           DeviceMemory<float>* output_data, uint64_t options) {
1692     return false;
1693   }
1694 
1695   // Concatenates several layers into one, by concatenating the depth of each
1696   // layer at matching x and y coordinates.
1697   // The inputs must all have the same width and height, the output will have
1698   // the same width and height as the inputs and its depth will be the sum of
1699   // the input depths.
1700   //
1701   // Arguments (all borrowed):
1702   //  stream: borrowed pointer to the stream that the 'depth concatenate'
1703   // operation should be enqueued onto.
1704   //  input_dimensions: The dimensions of each input.
1705   //  input_data: un-owned device memory region which contains the
1706   //    input data for each input layer.
1707   //  output_data: un-owned device memory region in which to place the
1708   //    depth concatenate result.
1709   virtual bool DoDepthConcatenate(
1710       Stream* stream,
1711       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,  // non-absl ok
1712       port::ArraySlice<const DeviceMemory<float>*> input_data,  // non-absl ok
1713       DeviceMemory<float>* output_data) = 0;
1714 
1715   // Concatenates several layers into one, by concatenating each in the
1716   // x-dimension or y-dimension, based on a user-specified flag.
1717   // For x-concatenation, layers are aligned at matching y and depth
1718   // coordinates, and for y-concatenation, they are aligned at matching x and
1719   // depth coordinates. The inputs must all have the same depth and batch size.
1720   // For x-concatenation, the inputs must have the same height (y-size), and the
1721   // output will have the same depth and height as the inputs and its width (x-
1722   // size) will be the sum of the input widths.  For y-concatenation, the inputs
1723   // must have the same width, and the output will have the same depth and width
1724   // as the inputs, and its height will be the sum of the input heights.
1725   //
1726   // Arguments:
1727   //  stream: borrowed pointer to the stream that the 'space concatenate'
1728   //    operation should be enqueued onto.
1729   //  input_dimensions: the dimensions of each input.
1730   //  input_data: un-owned device memory region which contains the input data
1731   //    for each input layer.
1732   //  output_data: un-owned device memory region in which to place the space
1733   //    concatenate result.
1734   //  concat_direction: either dnn::SpaceConcatenateMode::XDirection or
1735   //    dnn::SpaceConcatenateMode::YDirection.
1736   virtual bool DoSpaceConcatenate(
1737       Stream* stream,
1738       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,  // non-absl ok
1739       port::ArraySlice<const DeviceMemory<float>*> input_data,  // non-absl ok
1740       DeviceMemory<float>* output_data,
1741       dnn::SpaceConcatenateMode concat_direction) {
1742     return false;
1743   }
1744 
1745   // Change the layout of the data by shrinking one dimension (or set of
1746   // dimensions) and growing another dimension (or set of dimensions), while
1747   // keeping the total number of data elements constant, and maintaining the
1748   // current data ordering.
1749   //
1750   // Currently, the only supported operation is depth into space by a power of
1751   // 2. E.g. (y, x, z) -> (y*2, x*2, z/4)
1752   //
1753   // Note that Reshape may not be a no-op, depending on the platform and which
1754   // dimensions are being changed.
1755   //
1756   // Example: forgetting about batch for the moment, let's take a tensor that's
1757   // 2x1x8 (y by x by z) and reshape to a tensor that's 4x2x2. The memory layout
1758   // is row-major order: y,x,z. I.e. z changes the fastest, then x, then y. The
1759   // elements of the tensor range from 0 to 15. The x,y,z indices are below each
1760   // element.
1761   //
1762   //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
1763   // y0 y0 y0 y0 y0 y0 y0 y0 y1 y1 y1 y1 y1 y1 y1 y1
1764   // x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0
1765   // z0 z1 z2 z3 z4 z5 z6 z7 z0 z1 z2 z3 z4 z5 z6 z7
1766   //
1767   // reshape to 4x2x2
1768   //
1769   //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
1770   // y0 y0 y0 y0 y1 y1 y1 y1 y2 y2 y2 y2 y3 y3 y3 y3
1771   // x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1
1772   // z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1
1773   virtual bool DoReshape(Stream* stream,
1774                          const dnn::BatchDescriptor& input_dimensions,
1775                          const DeviceMemory<float>& input_data,
1776                          const dnn::BatchDescriptor& output_dimensions,
1777                          DeviceMemory<float>* output_data) {
1778     return false;
1779   }
1780 
1781   // Depth to space takes an X by Y image with depth D*M^2 and changes it to an
1782   // MX x MY image with depth D. Each input location (x,y) with depth D*M^2 in
1783   // the input image is changed to an MxM contiguous area in the output image,
1784   // with the values being laid out in the raster order by DepthToSpaceLayout,
1785   // and will have a new depth of D.
1786   //
1787   // Example.
1788   // M=2, Din=8, Xin=2, Yin=2, Xout=4, Yout=4, Dout=2
1789   // DepthHeightWidth layout
1790   // Values within a 'cell' are at different depths and same x & y.
1791   // Input:
1792   // abcdefgh  ijklmnop
1793   // qrstuvwx  yz012345
1794   // Output:
1795   // ae bf im jn
1796   // cg dh ko lp
1797   // qu rv y2 z3
1798   // sw tx 04 15
1799   //
1800   // sqrt_depth_reduction: 'M' in the comment above
1801   virtual bool DoDepthToSpace(Stream* stream,
1802                               const dnn::BatchDescriptor& input_dimensions,
1803                               const DeviceMemory<float>& input_data,
1804                               const DepthToSpaceLayout& depth_to_space_layout,
1805                               const int& sqrt_depth_reduction,
1806                               DeviceMemory<float>* output_data) {
1807     return false;
1808   }
1809 
1810   // Space to depth is the inverse of depth to space. Space to depth takes each
1811   // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
1812   // the input, and transforms it to a 1 by 1 patch with depth D*M^2. If the
1813   // input has size (MX, MY, D), the output has size (X, Y, D*M^2). The number
1814   // of data elements is not changed.
1815   //
1816   // Example.
1817   // M=2, Din=2, Xin=4, Yin=4, Dout=8
1818   // DepthHeightWidth layout
1819   // Values within a 'cell' are at different depths and same x & y.
1820   // Input:
1821   // ae bf im jn
1822   // cg dh ko lp
1823   // qu rv y2 z3
1824   // sw tx 04 15
1825   // Output:
1826   // abcdefgh  ijklmnop
1827   // qrstuvwx  yz012345
1828   //
1829   // sqrt_depth_increase: 'M' in the comment above
1830   virtual bool DoSpaceToDepth(Stream* stream,
1831                               const dnn::BatchDescriptor& input_dimensions,
1832                               const DeviceMemory<float>& input_data,
1833                               const DepthToSpaceLayout& space_to_depth_layout,
1834                               const int& sqrt_depth_increase,
1835                               DeviceMemory<float>* output_data) {
1836     return false;
1837   }
1838 
1839   // Computes the specified operation (e.g. addition or multiplication)
1840   // between corresponding elements in the inputs and stores the result in the
1841   // output element.
1842   // The inputs and output must all have the same dimensions, but may have
1843   // different quantization parameters (min_value and max_value).
1844   //
1845   // Arguments (all borrowed):
1846   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1847   // should be enqueued onto.
1848   //  operation: The operation to perform.
1849   //  input_dimensions: The dimensions of each input.
1850   //  input_data: un-owned device memory region which contains the
1851   //    input data for each input layer.
1852   //  output_dimensions: The dimensions of the output.
1853   //  output_data: un-owned device memory region in which to place the
1854   //    operation result.
1855   virtual bool DoElementwiseOperate(
1856       Stream* stream, ElementwiseOperation operation,
1857       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,  // non-absl ok
1858       port::ArraySlice<const DeviceMemory<float>*> input_data,  // non-absl ok
1859       const dnn::BatchDescriptor& output_dimensions,
1860       DeviceMemory<float>* output_data) = 0;
1861 
1862   // Computes the specified operation (e.g. addition or multiplication)
1863   // between corresponding elements in the inputs and stores the result in the
1864   // output element. Each input is multiplied by a scalar constant and the
1865   // result is divided by a scalar constant.
1866   // e.g. To perform Z = 0.9*X + 1.1*Y, set the input multiplicands to 9 and 11
1867   // and the output divisor to 10.
1868   // The inputs and output must all have the same dimensions, but may have
1869   // different quantization parameters (min_value and max_value).
1870   //
1871   // Arguments (all borrowed):
1872   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1873   // should be enqueued onto.
1874   //  operation: The operation to perform.
1875   //  input_multiplicands: Amount to scale each input.
1876   //  output_divisor: Amount to divide the output.
1877   //  input_dimensions: The dimensions of each input.
1878   //  input_data: un-owned device memory region which contains the
1879   //    input data for each input layer.
1880   //  output_dimensions: The dimensions of the output.
1881   //  output_data: un-owned device memory region in which to place the
1882   //    operation result.
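  //
  // Worked check of the scaling convention above (illustrative only): for
  // element values x = 2.0 and y = 3.0 with input_multiplicands = {9, 11} and
  // output_divisor = 10, the result is (9 * 2.0 + 11 * 3.0) / 10 = 5.1, which
  // equals 0.9 * 2.0 + 1.1 * 3.0.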
1883   virtual bool DoElementwiseOperateScaledQuantized(
1884       Stream* stream, ElementwiseOperation operation,
1885       port::ArraySlice<int> input_multiplicands,  // non-absl ok
1886       int output_divisor,
1887       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,  // non-absl ok
1888       port::ArraySlice<const DeviceMemory<float>*> input_data,  // non-absl ok
1889       const dnn::BatchDescriptor& output_dimensions,
1890       DeviceMemory<float>* output_data) {
1891     return false;
1892   }
1893 
1894   // Pads the input with zeros in the X and Y dimensions. The feature_map
1895   // dimension is unchanged.
1896   //
1897   // Arguments (all borrowed):
1898   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1899   // should be enqueued onto.
1900   //  dimensions: The dimensions of the input.
1901   //  input_data: un-owned device memory region which contains the
1902   //    input data for the input layer.
1903   //  left_pad: Amount to pad the input on the left.
1904   //  right_pad: Amount to pad the input on the right.
1905   //  top_pad: Amount to pad the input at the top (low Y).
1906   //  bottom_pad: Amount to pad the input at the bottom (high Y).
1907   //  output_data: un-owned device memory region in which to place the
1908   //    padded result.
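  //
  // For example (illustrative only): a 3x5 (height x width) input with
  // left_pad = 1, right_pad = 2, top_pad = 1, and bottom_pad = 1 produces a
  // 5x8 output whose added border elements are zero.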
1909   virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
1910                        const DeviceMemory<float>& input_data, int64_t left_pad,
1911                        int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
1912                        DeviceMemory<float>* output_data) = 0;
1913 
1914   // Extracts a slice of the input in the X and Y dimensions. The feature_map
1915   // dimension is unchanged.
1916   //
1917   // Arguments (all borrowed):
1918   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1919   // should be enqueued onto.
1920   //  dimensions: The dimensions of the input.
1921   //  input_data: un-owned device memory region which contains the
1922   //    input data for the input layer.
1923   //  left_trim: Amount to cut off the input on the left.
1924   //  right_trim: Amount to cut off the input on the right.
1925   //  top_trim: Amount to cut off the input at the top (low y).
1926   //  bottom_trim: Amount to cut off the input at the bottom (high Y).
1927   //  output_data: un-owned device memory region in which to place the
1928   //    sliced result.
1929   virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
1930                          const DeviceMemory<float>& input_data,
1931                          int64_t left_trim, int64_t right_trim,
1932                          int64_t top_trim, int64_t bottom_trim,
1933                          DeviceMemory<float>* output_data) = 0;
1934 
1935   // Grows the input tensor by replicating the X and Y dimensions. The batch and
1936   // depth/feature_map dimensions are unchanged. Currently, the input tensor is
1937   // limited to X=1 and Y=1.
1938   //
1939   // For example, the input has dimensions x=2, y=3, and replicate_x=3,
1940   // replicate_y=2. The diagonal elements of the output would be: [x0y0, x1y1,
1941   // x0y2, x1y0, x0y1, x1y2].
1942   // Here is the example as a picture. input:
1943   // AB
1944   // CD
1945   // EF
1946   // broadcast result:
1947   // ABABAB
1948   // CDCDCD
1949   // EFEFEF
1950   // ABABAB
1951   // CDCDCD
1952   // EFEFEF
1953   //
1954   // Arguments (all borrowed):
1955   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1956   // should be enqueued onto.
1957   //  dimensions: The dimensions of the input.
1958   //  input_data: un-owned device memory region which contains the
1959   //    input data for the input layer.
1960   //  replicate_x: Amount to replicate the input's X dimension.
1961   //  replicate_y: Amount to replicate the input's Y dimension.
1962   //  output_data: un-owned device memory region in which to place the
1963   //    broadcast result.
1964   virtual bool DoXYBroadcast(Stream* stream,
1965                              const dnn::BatchDescriptor& dimensions,
1966                              const DeviceMemory<float>& input_data,
1967                              int64_t replicate_x, int64_t replicate_y,
1968                              DeviceMemory<float>* output_data) {
1969     return false;
1970   }
1971 
1972   // Enqueues an asynchronous memcpy of the *quantized* output of a layer (that
1973   // is, bytes instead of scaled floats) into 'host_dst' if they are available
1974   // for the underlying DNN implementation. If this quantized output is not
1975   // available, false is returned, which will place 'stream' into an error
1976   // state.
1977   //
1978   // Arguments (all borrowed):
1979   //  stream: borrowed pointer to the stream that the 'quantized memcpy'
1980   //    operation should be enqueued onto.
1981   //  gpu_unquantized_src: the device memory that contains the unquantized data
1982   //    -- this data should also have a corresponding quantized representation
1983   //    on the device for this operation to succeed.
1984   //  mode: Type of quantization of the data to write into host_dst.
1985   //  host_dst: un-owned host memory region that is mutated in place,
1986   //    it is clobbered by the values in 'gpu_unquantized_src' when the enqueued
1987   //    (asynchronous) memcpy operation is performed.
1988   //  size: size in bytes of the host_dst host memory region.
1989   virtual bool DoMemcpyD2HQuantized(
1990       Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
1991       QuantizedActivationMode mode, void* host_dst, int64_t size) = 0;
1992 
1993   // Enqueues an asynchronous memcpy of 'host_dst' into the *quantized* input
1994   // of a layer (that is, bytes instead of scaled floats) if they are supported
1995   // by the underlying DNN implementation. If this quantized input is not
1996   // supported, false is returned, which will place 'stream' into an error
1997   // state.
1998   //
1999   // Arguments (all borrowed):
2000   //  stream: borrowed pointer to the stream that the 'quantized memcpy'
2001   //    operation should be enqueued onto.
2002   //  host_src: un-owned host memory region that contains the quantized data.
2003   //  size: size in bytes of the host_src host memory region.
2004   //  mode: Type of quantization of the data to read from host_src.
2005   //  gpu_unquantized_dst: the device memory that is clobbered by the values in
2006   //    'host_src' when the enqueued (asynchronous) memcpy operation is
2007   //    performed -- this data should also have a corresponding quantized
2008   //    representation on the device for this operation to
2009   //    succeed.
2010   virtual bool DoMemcpyH2DQuantized(
2011       Stream* stream, const void* host_src, int64_t size,
2012       QuantizedActivationMode mode,
2013       DeviceMemory<float>* gpu_unquantized_dst) = 0;
2014 
2015   // Create an RNN descriptor based on model shapes and configurations.
2016   // The caller retains the ownership of the descriptor.
2017   //
2018   // Arguments:
2019   //  num_layers: the number of layers for an RNN model.
2020   //  hidden_size: the size of the hidden state.
2021   //  input_size: the size of the input state.
2022   //  cell_size: the size of the cell state
2023   //  input_mode: an enum to specify whether a linear transformation is added
2024   //    after the input state. If input_size is different from hidden_size, this
2025   //    is required.
2026   //  direction_mode: an enum to specify whether this model is unidirectional or
2027   //    bidirectional.
2028   //  rnn_mode: an enum to specify the type of model to build.
2029   //  data_type: an enum to specify the data types used in this model.
2030   //  dropout: the dropout threshold between layers. When it is 0., no dropout
2031   //    is added.
2032   //  seed: a seed for initializing the dropout layers.
2033   //  state_allocator: a memory allocator that will be used to store the state
2034   //    for the dropout layer. The user has to maintain the memory until the
2035   //    model is no longer in use.
2036   //  use_padded_io: a bool to specify whether the input is using padded IO.
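  //
  // Illustrative sketch, not part of the API: creating a descriptor for a
  // single-layer unidirectional LSTM, assuming `dnn` points to a concrete
  // DnnSupport implementation, `state_allocator` is a ScratchAllocator*, and
  // the sizes shown are placeholders.
  //
  //   auto rnn_desc_or = dnn->createRnnDescriptor(
  //       /*num_layers=*/1, /*hidden_size=*/128, /*input_size=*/64,
  //       /*cell_size=*/128, /*batch_size=*/32,
  //       dnn::RnnInputMode::kRnnLinearInput,
  //       dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
  //       dnn::DataType::kFloat, dnn::AlgorithmConfig(), /*dropout=*/0.0f,
  //       /*seed=*/0, state_allocator, /*use_padded_io=*/false);
  //   if (!rnn_desc_or.ok()) { /* fall back or propagate the error */ }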
2037   virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2038   createRnnDescriptor(int num_layers, int hidden_size, int input_size,
2039                       int cell_size, int batch_size,
2040                       dnn::RnnInputMode input_mode,
2041                       dnn::RnnDirectionMode direction_mode,
2042                       dnn::RnnMode rnn_mode, dnn::DataType data_type,
2043                       const dnn::AlgorithmConfig& algorithm_config,
2044                       float dropout, uint64_t seed,
2045                       ScratchAllocator* state_allocator, bool use_padded_io) {
2046     return port::Status(port::error::UNIMPLEMENTED,
2047                         "createRnnDescriptor is unimplemented");
2048   }
2049 
2050   // Create an RNN sequence descriptor that specifies either the input or output
2051   // sequence. The caller retains the ownership of the returned descriptor.
2052   //
2053   // Arguments:
2054   //  max_seq_length: the max length of the sequences.
2055   //  batch_size: the size of a minibatch.
2056   //  data_size: the size of the state.
2057   //  seq_lengths: the lengths of sequences in a batch.
2058   //  data_type: an enum to specify the type for the underlying data.
2059   virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2060   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2061                                     int data_size, dnn::DataType data_type) {
2062     return port::Status(port::error::UNIMPLEMENTED,
2063                         "createRnnSequenceTensorDescriptor is unimplemented");
2064   }
2065 
2066   virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2067   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2068                                     int data_size,
2069                                     const absl::Span<const int>& seq_lengths,
2070                                     bool time_major, dnn::DataType data_type) {
2071     return port::Status(port::error::UNIMPLEMENTED,
2072                         "createRnnSequenceTensorDescriptor is unimplemented");
2073   }
2074 
2075   // Create an RNN state descriptor that specifies the input or hidden state.
2076   // The caller retains the ownership of the returned descriptor.
2077   virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2078   createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
2079                                  dnn::DataType data_type) {
2080     return port::Status(port::error::UNIMPLEMENTED,
2081                         "createRnnStateTensorDescriptor is unimplemented");
2082   }
2083 
2084   // Enqueue a forward operation of the RNN model onto the stream.
2085   //
2086   // Arguments:
2087   //  stream: pointer to the stream where this operation should be enqueued to.
2088   //  rnn_desc: an RNN descriptor created by createRnnDescriptor.
2089   //  input_desc: descriptor for the input sequence.
2090   //  input_data: the device memory region that contains the input data.
2091   //  input_h_desc: descriptor for the input "h" state.
2092   //  input_h_data: the device memory region that contains the input "h" data.
2093   //  input_c_desc: descriptor for the input "c" state.
2094   //  input_c_data: the device memory region that contains the input "c" data.
2095   //    This must be specified for LSTM models.
2096   //  params: the device memory region that contains the parameters used in this
2097   //    model.
2098   //  output_desc: descriptor for the output sequence.
2099   //  output_data: the memory region that stores the output sequence data.
2100   //  output_h_desc: descriptor for the output "h" state.
2101   //  output_h_data: the memory region that stores the output "h" data.
2102   //  output_c_desc: descriptor for the output "c" state.
2103   //  output_c_data: the memory region that stores the output "c" data. This
2104   //    must be specified for LSTM models.
2105   //  is_training: whether this is used in training or inference. That decides
2106   //    whether reserve_space data need to be produced.
2107   //  reserve_space_allocator: if "is_training" is true, a memory allocator
2108   //    to create memory that holds the produced reserve_space. The caller
2109   //    retains the data and feeds it to the backward pass.
2110   //  workspace_allocator: an allocator to create temporary workspace used in
2111   //    this kernel. The caller is responsible for retaining the memory long
2112   //    enough for the lifespan of this operation, and recycling it afterwards.
2113   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2114                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2115                             const DeviceMemory<Eigen::half>& input_data,
2116                             const DeviceMemory<int>& seq_lengths_data,
2117                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2118                             const DeviceMemory<Eigen::half>& input_h_data,
2119                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2120                             const DeviceMemory<Eigen::half>& input_c_data,
2121                             const DeviceMemory<Eigen::half>& params,
2122                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2123                             DeviceMemory<Eigen::half>* output_data,
2124                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2125                             DeviceMemory<Eigen::half>* output_h_data,
2126                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2127                             DeviceMemory<Eigen::half>* output_c_data,
2128                             bool is_training,
2129                             ScratchAllocator* reserve_space_allocator,
2130                             ScratchAllocator* workspace_allocator,
2131                             dnn::ProfileResult* output_profile_result) {
2132     return false;
2133   }
2134 
2135   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2136                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2137                             const DeviceMemory<float>& input_data,
2138                             const DeviceMemory<int>& seq_lengths_data,
2139                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2140                             const DeviceMemory<float>& input_h_data,
2141                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2142                             const DeviceMemory<float>& input_c_data,
2143                             const DeviceMemory<float>& params,
2144                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2145                             DeviceMemory<float>* output_data,
2146                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2147                             DeviceMemory<float>* output_h_data,
2148                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2149                             DeviceMemory<float>* output_c_data,
2150                             bool is_training,
2151                             ScratchAllocator* reserve_space_allocator,
2152                             ScratchAllocator* workspace_allocator,
2153                             dnn::ProfileResult* output_profile_result) {
2154     return false;
2155   }
2156 
2157   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2158                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2159                             const DeviceMemory<double>& input_data,
2160                             const DeviceMemory<int>& seq_lengths_data,
2161                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2162                             const DeviceMemory<double>& input_h_data,
2163                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2164                             const DeviceMemory<double>& input_c_data,
2165                             const DeviceMemory<double>& params,
2166                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2167                             DeviceMemory<double>* output_data,
2168                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2169                             DeviceMemory<double>* output_h_data,
2170                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2171                             DeviceMemory<double>* output_c_data,
2172                             bool is_training,
2173                             ScratchAllocator* reserve_space_allocator,
2174                             ScratchAllocator* workspace_allocator,
2175                             dnn::ProfileResult* output_profile_result) {
2176     return false;
2177   }
2178   // Enqueue a backward operation of the RNN model onto the stream.
2179   //
2180   // Arguments:
2181   //  stream: pointer to the stream where this operation should be enqueued to.
2182   //  rnn_desc: a RNN descriptor created by createRnnDescriptor.
2183   //  input_desc: descriptor for the input sequence.
2184   //  input_data: the device memory region that contains the input data.
2185   //  input_h_desc: descriptor for the input "h" state.
2186   //  input_h_data: the device memory region that contains the input "h" data.
2187   //  input_c_desc: descriptor for the input "c" state.
2188   //  input_c_data: the device memory region that contains the input "c" data.
2189   //    This must be specified for LSTM models.
2190   //  params: the device memory region that contains the parameters used in this
2191   //    model.
2192   //  output_desc: descriptor for the output sequence.
2193   //  output_data: the memory region that stores the output sequence data.
2194   //  output_h_desc: descriptor for the output "h" state.
2195   //  output_h_data: the memory region that stores the output "h" data.
2196   //  output_c_desc: descriptor for the output "c" state.
2197   //  output_c_data: the memory region that stores the output "c" data. This
2198   //    must be specified for LSTM models.
2199   //  output_backprop_data: the device memory region that contains the backprop
2200   //    to the output sequence.
2201   //  output_h_backprop_data: the device memory region that contains the
2202   //    backprop to the output "h" state.
2203   //  output_c_backprop_data: the device memory region that contains the
2204   //    backprop to the output "c" state.
2205   //  input_backprop_data: the device memory region that stores the backprop
2206   //    to the input sequence.
2207   //  input_h_backprop_data: the device memory region that stores the backprop
2208   //    to the input "h" state.
2209   //  input_c_backprop_data: the device memory region that stores the backprop
2210   //    to the input "c" state.
2211   //  params_backprop_data: the device memory region that stores the backprop
2212   //    to the parameters.
2213   //  reserve_space_data: the reserve_space data that is produced by the forward
2214   //    operation. This memory region could be modified by this operation.
2215   //  workspace_allocator: a memory allocator that creates the temporary
2216   //    workspace memory used by this operation. The caller is responsible for
2217   //    keeping the memory alive long enough for this operation, and recycling
2218   //    it afterwards.
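  //
  // Illustrative sketch (not part of this interface): during training the
  // reserve_space produced by DoRnnForward is handed back here together with
  // the incoming gradients. Assuming every object below was created elsewhere,
  // a call to the float overload might look roughly like this:
  //
  //   bool launched = dnn->DoRnnBackward(
  //       stream, rnn_desc, input_desc, input_data, seq_lengths_data,
  //       input_h_desc, input_h_data, input_c_desc, input_c_data, params,
  //       output_desc, output_data, output_h_desc, output_h_data,
  //       output_c_desc, output_c_data, output_backprop_data,
  //       output_h_backprop_data, output_c_backprop_data, &input_backprop_data,
  //       &input_h_backprop_data, &input_c_backprop_data, &params_backprop_data,
  //       &reserve_space_data, workspace_allocator,
  //       /*output_profile_result=*/nullptr);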
2219   virtual bool DoRnnBackward(
2220       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2221       const dnn::RnnSequenceTensorDescriptor& input_desc,
2222       const DeviceMemory<Eigen::half>& input_data,
2223       const DeviceMemory<int>& seq_lengths_data,
2224       const dnn::RnnStateTensorDescriptor& input_h_desc,
2225       const DeviceMemory<Eigen::half>& input_h_data,
2226       const dnn::RnnStateTensorDescriptor& input_c_desc,
2227       const DeviceMemory<Eigen::half>& input_c_data,
2228       const DeviceMemory<Eigen::half>& params,
2229       const dnn::RnnSequenceTensorDescriptor& output_desc,
2230       const DeviceMemory<Eigen::half>& output_data,
2231       const dnn::RnnStateTensorDescriptor& output_h_desc,
2232       const DeviceMemory<Eigen::half>& output_h_data,
2233       const dnn::RnnStateTensorDescriptor& output_c_desc,
2234       const DeviceMemory<Eigen::half>& output_c_data,
2235       const DeviceMemory<Eigen::half>& output_backprop_data,
2236       const DeviceMemory<Eigen::half>& output_h_backprop_data,
2237       const DeviceMemory<Eigen::half>& output_c_backprop_data,
2238       DeviceMemory<Eigen::half>* input_backprop_data,
2239       DeviceMemory<Eigen::half>* input_h_backprop_data,
2240       DeviceMemory<Eigen::half>* input_c_backprop_data,
2241       DeviceMemory<Eigen::half>* params_backprop_data,
2242       DeviceMemory<uint8>* reserve_space_data,
2243       ScratchAllocator* workspace_allocator,
2244       dnn::ProfileResult* output_profile_result) {
2245     return false;
2246   }
2247 
2248   virtual bool DoRnnBackward(
2249       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2250       const dnn::RnnSequenceTensorDescriptor& input_desc,
2251       const DeviceMemory<float>& input_data,
2252       const DeviceMemory<int>& seq_lengths_data,
2253       const dnn::RnnStateTensorDescriptor& input_h_desc,
2254       const DeviceMemory<float>& input_h_data,
2255       const dnn::RnnStateTensorDescriptor& input_c_desc,
2256       const DeviceMemory<float>& input_c_data,
2257       const DeviceMemory<float>& params,
2258       const dnn::RnnSequenceTensorDescriptor& output_desc,
2259       const DeviceMemory<float>& output_data,
2260       const dnn::RnnStateTensorDescriptor& output_h_desc,
2261       const DeviceMemory<float>& output_h_data,
2262       const dnn::RnnStateTensorDescriptor& output_c_desc,
2263       const DeviceMemory<float>& output_c_data,
2264       const DeviceMemory<float>& output_backprop_data,
2265       const DeviceMemory<float>& output_h_backprop_data,
2266       const DeviceMemory<float>& output_c_backprop_data,
2267       DeviceMemory<float>* input_backprop_data,
2268       DeviceMemory<float>* input_h_backprop_data,
2269       DeviceMemory<float>* input_c_backprop_data,
2270       DeviceMemory<float>* params_backprop_data,
2271       DeviceMemory<uint8>* reserve_space_data,
2272       ScratchAllocator* workspace_allocator,
2273       dnn::ProfileResult* output_profile_result) {
2274     return false;
2275   }
2276 
2277   virtual bool DoRnnBackward(
2278       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2279       const dnn::RnnSequenceTensorDescriptor& input_desc,
2280       const DeviceMemory<double>& input_data,
2281       const DeviceMemory<int>& seq_lengths_data,
2282       const dnn::RnnStateTensorDescriptor& input_h_desc,
2283       const DeviceMemory<double>& input_h_data,
2284       const dnn::RnnStateTensorDescriptor& input_c_desc,
2285       const DeviceMemory<double>& input_c_data,
2286       const DeviceMemory<double>& params,
2287       const dnn::RnnSequenceTensorDescriptor& output_desc,
2288       const DeviceMemory<double>& output_data,
2289       const dnn::RnnStateTensorDescriptor& output_h_desc,
2290       const DeviceMemory<double>& output_h_data,
2291       const dnn::RnnStateTensorDescriptor& output_c_desc,
2292       const DeviceMemory<double>& output_c_data,
2293       const DeviceMemory<double>& output_backprop_data,
2294       const DeviceMemory<double>& output_h_backprop_data,
2295       const DeviceMemory<double>& output_c_backprop_data,
2296       DeviceMemory<double>* input_backprop_data,
2297       DeviceMemory<double>* input_h_backprop_data,
2298       DeviceMemory<double>* input_c_backprop_data,
2299       DeviceMemory<double>* params_backprop_data,
2300       DeviceMemory<uint8>* reserve_space_data,
2301       ScratchAllocator* workspace_allocator,
2302       dnn::ProfileResult* output_profile_result) {
2303     return false;
2304   }
2305 
2306   template <typename ElementType>
2307   port::Status PrepareForCtcLoss(Stream* stream,
2308                                  const RnnStateTensorDescriptor& probs_desc,
2309                                  DeviceMemory<ElementType> probs_data,
2310                                  const RnnStateTensorDescriptor& grads_desc,
2311                                  absl::Span<const int> labels_data,
2312                                  absl::Span<const int> labels_lengths_data,
2313                                  absl::Span<const int> input_lengths_data,
2314                                  ScratchAllocator* workspace_allocator,
2315                                  DeviceMemory<uint8>* scratch_memory,
2316                                  int* ctc_loss_algo_id) {
2317     return DoPrepareForCtcLoss(
2318         stream, ToDataType<ElementType>::value, probs_desc, grads_desc,
2319         labels_data, labels_lengths_data, input_lengths_data,
2320         workspace_allocator, scratch_memory, ctc_loss_algo_id);
2321   }
2322 
2323   // Enqueue a CTC Loss operation onto the stream.
2324   //
2325   // Arguments:
2326   //  stream: pointer to the stream where this operation should be enqueued to.
2327   //  element_type: the data type of the input tensors.
2328   //  probs_desc: specifies the shape and the data layout of the input tensor.
2329   //  probs_data: the device memory region that contains the input tensor.
2330   //  labels_data: the device memory region that contains the labels_value
2331   //    tensor.
2332   //  labels_lengths_data: the device memory region that contains the
2333   //    labels_lengths tensor.
2334   //  input_lengths_data: the device memory region that contains the seq_lengths
2335   //    tensor.
2336   //  costs_data: the device memory region that contains the costs tensor.
2337   //  grads_desc: specifies the shape and the data layout of the grads tensor.
2338   //  grads_data: the device memory region that contains the grads tensor.
2339   //  ctc_loss_desc: a CTCLoss descriptor.
2340   //  workspace_allocator: a memory allocator that creates the temporary
2341   //    workspace memory used by this operation. The caller is responsible for
2342   //    keeping the memory alive long enough for this operation, and recycling
2343   //    it afterwards.
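  //
  // Illustrative sketch (not part of this interface): callers normally pair
  // PrepareForCtcLoss with the typed DoCtcLoss wrapper below. Assuming float
  // probabilities and that the descriptors, spans, and device buffers already
  // exist:
  //
  //   DeviceMemory<uint8> scratch;
  //   int ctc_loss_algo_id = -1;
  //   auto prep = dnn->PrepareForCtcLoss(
  //       stream, probs_desc, probs_data, grads_desc, labels_data,
  //       labels_lengths_data, input_lengths_data, workspace_allocator,
  //       &scratch, &ctc_loss_algo_id);
  //   if (prep.ok()) {
  //     dnn->DoCtcLoss(stream, probs_desc, probs_data, labels_data,
  //                    labels_lengths_data, input_lengths_data, &costs_data,
  //                    grads_desc, &grads_data, &scratch, ctc_loss_algo_id);
  //   }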
2344   virtual port::Status DoCtcLoss(
2345       Stream* stream, dnn::DataType element_type,
2346       const RnnStateTensorDescriptor& probs_desc,
2347       const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2348       absl::Span<const int> labels_lengths_data,
2349       absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2350       const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
2351       DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
2352 
2353   template <typename ElementType>
2354   bool DoCtcLoss(Stream* stream,
2355                  const dnn::RnnStateTensorDescriptor& probs_desc,
2356                  const DeviceMemory<ElementType>& probs_data,
2357                  absl::Span<const int> labels_data,
2358                  absl::Span<const int> labels_lengths_data,
2359                  absl::Span<const int> input_lengths_data,
2360                  DeviceMemory<ElementType>* costs_data,
2361                  const dnn::RnnStateTensorDescriptor& grads_desc,
2362                  DeviceMemory<ElementType>* grads_data,
2363                  DeviceMemory<uint8>* scratch_memory, int ctc_loss_algo_id) {
2364     return IsStatusOk(
2365         DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
2366                   probs_data, labels_data, labels_lengths_data,
2367                   input_lengths_data, *costs_data, grads_desc, *grads_data,
2368                   *scratch_memory, ctc_loss_algo_id),
2369         false);
2370   }
2371 
2372   // Transforms a tensor into another tensor with a different layout and/or data
2373   // type.
2374   //
2375   // Arguments:
2376   //  stream: pointer to the stream where this operation should be enqueued to.
2377   //  input_desc: specifies the shape and the data layout of the input tensor.
2378   //  input_type: the data type of the input tensor.
2379   //  input_data: the device memory region that contains the input tensor.
2380   //  output_desc: specifies the shape and the data layout of the output tensor.
2381   //  output_type: the data type of the output tensor.
2382   //  scale: an element-wise scaling factor to apply.
2383   //  output_data: the device memory region that contains the output tensor.
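  //
  // Illustrative sketch (not part of this interface): e.g. converting a float
  // tensor to half precision with no extra scaling, assuming the descriptors
  // and buffers already exist:
  //
  //   bool converted = dnn->DoTransformTensor(
  //       stream, input_desc, dnn::DataType::kFloat, input_data, output_desc,
  //       dnn::DataType::kHalf, /*scale=*/1.0f, &output_data);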
2384   virtual bool DoTransformTensor(Stream* stream,
2385                                  const dnn::BatchDescriptor& input_desc,
2386                                  dnn::DataType input_type,
2387                                  const DeviceMemoryBase& input_data,
2388                                  const dnn::BatchDescriptor& output_desc,
2389                                  dnn::DataType output_type, float scale,
2390                                  DeviceMemoryBase* output_data) {
2391     return false;
2392   }
2393 
2394   // Enqueues a fused convolution+bias+activation operation onto the stream.
2395   //
2396   // Arguments (all borrowed):
2397   //
2398   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2399   //  be enqueued onto.
2400   //
2401   //  conv_input_descriptor: dimensions of the convolution input layer.
2402   //  conv_input_data: device memory which contains the convolution input.
2403   //
2404   //  filter_descriptor: dimensions of the convolution filter.
2405   //  filter_data: device memory which contains the convolution filter weights.
2406   //
2407   //  convolution_descriptor: stride of the convolution filter.
2408   //
2409   //  bias_descriptor: dimensions of the bias layer.
2410   //  bias_data: device memory region containing the biases to add to the
2411   //  convolution output.
2412   //
2413   //  activation_mode: Type of activation to perform.
2414   //
2415   //  output_descriptor: dimensions of the output layer.
2416   //  output_data: device memory region in which to place the fusion result.
2417   //
2418   //  output_profile_result: the output profile result for this call.
2419   //         The profiling is only enabled when this is not nullptr.
2420   //
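  //
  // Illustrative sketch (not part of this interface): a single fused
  // convolution + bias + ReLU launch, with all float buffers and descriptors
  // assumed to be created elsewhere:
  //
  //   bool launched = dnn->DoFusedConvolutionBiasActivation(
  //       stream, conv_input_descriptor, conv_input_data, filter_descriptor,
  //       filter_data, convolution_descriptor, bias_descriptor, bias_data,
  //       dnn::ActivationMode::kRelu, output_descriptor, &output_data,
  //       /*output_profile_result=*/nullptr);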
2421   virtual bool DoFusedConvolutionBiasActivation(
2422       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
2423       const DeviceMemory<float>& conv_input_data,
2424       const dnn::FilterDescriptor& filter_descriptor,
2425       const DeviceMemory<float>& filter_data,
2426       const dnn::ConvolutionDescriptor& convolution_descriptor,
2427       const dnn::BatchDescriptor& bias_descriptor,
2428       const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
2429       const dnn::BatchDescriptor& output_descriptor,
2430       DeviceMemory<float>* output_data,
2431       dnn::ProfileResult* output_profile_result) {
2432     return false;
2433   }
2434 
2435   // Enqueues a fused batchnorm+activation (inference) operation onto the
2436   // stream.
2437   //
2438   // Arguments (all borrowed):
2439   //
2440   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2441   //  be enqueued onto.
2442   //
2443   //  x_descriptor: dimensions of the batchnorm input layer.
2444   //  x_data: device memory which contains the batchnorm input.
2445   //
2446   //  scale_offset_mean_variance_descriptor:
2447   //      dimensions of the scale/offset/mean/variance tensor.
2448   //  scale_data: device memory which contains the scale input.
2449   //  offset_data: device memory which contains the offset input.
2450   //  mean_data: device memory which contains the mean input.
2451   //  variance_data: device memory which contains the variance input.
2452   //  epsilon: the epsilon value to use in the batchnorm calculation.
2453   //
2454   //  activation_mode: Type of activation to perform.
2455   //
2456   //  y_data: device memory region in which to place the fusion result.
2457   //
2458   //  output_profile_result: the output profile result for this call.
2459   //         The profiling is only enabled when this is not nullptr.
2460   //
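  //
  // Illustrative sketch (not part of this interface): inference-time
  // batchnorm + ReLU using previously computed running statistics, with all
  // inputs assumed to be prepared by the caller:
  //
  //   bool launched = dnn->DoFusedBatchNormActivationInference(
  //       stream, x_descriptor, x_data, scale_offset_mean_variance_descriptor,
  //       scale_data, offset_data, mean_data, variance_data, /*epsilon=*/1e-3,
  //       dnn::ActivationMode::kRelu, &y_data,
  //       /*output_profile_result=*/nullptr);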
2461   virtual bool DoFusedBatchNormActivationInference(
2462       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2463       const DeviceMemory<float>& x_data,
2464       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2465       const DeviceMemory<float>& scale_data,
2466       const DeviceMemory<float>& offset_data,
2467       const DeviceMemory<float>& mean_data,
2468       const DeviceMemory<float>& variance_data, double epsilon,
2469       dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
2470       dnn::ProfileResult* output_profile_result) {
2471     return false;
2472   }
2473 
2474   virtual bool DoFusedBatchNormActivationInference(
2475       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2476       const DeviceMemory<Eigen::half>& x_data,
2477       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2478       const DeviceMemory<float>& scale_data,
2479       const DeviceMemory<float>& offset_data,
2480       const DeviceMemory<float>& mean_data,
2481       const DeviceMemory<float>& variance_data, double epsilon,
2482       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
2483       dnn::ProfileResult* output_profile_result) {
2484     return false;
2485   }
2486 
2487   // Enqueues a fused batchnorm+activation (training-fwd) operation onto the
2488   // stream.
2489   //
2490   // Arguments (all borrowed):
2491   //
2492   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2493   //  be enqueued onto.
2494   //
2495   //  x_descriptor: dimensions of the batchnorm input layer.
2496   //  x_data: device memory which contains the batchnorm input.
2497   //
2498   //  scale_offset_mean_variance_descriptor:
2499   //      dimensions of the scale/offset/mean/variance tensor.
2500   //  scale_data: device memory which contains the scale input.
2501   //  offset_data: device memory which contains the offset input.
2502   //  epsilon: the epsilon value to use in the batchnorm calculation.
2503   //
2504   //  activation_mode: Type of activation to perform.
2505   //
2506   //  y_data: device memory region in which to place the fusion result.
2507   //  batch_mean_data: device memory in which to place the batch mean output.
2508   //  batch_var_data: device memory in which to place the batch variance output.
2509   //  saved_mean_data: device memory in which to save the mean for bwd pass.
2510   //  saved_var_data: device memory in which to save the variance for bwd pass.
2511   //
2512   //  output_profile_result: the output profile result for this call.
2513   //         The profiling is only enabled when this is not nullptr.
2514   //
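  //
  // Illustrative sketch (not part of this interface): the training-time
  // forward pass also produces the batch statistics and the saved
  // mean/variance that the backward pass consumes:
  //
  //   bool launched = dnn->DoFusedBatchNormActivationForward(
  //       stream, x_descriptor, x_data, scale_offset_mean_variance_descriptor,
  //       scale_data, offset_data, /*epsilon=*/1e-3,
  //       dnn::ActivationMode::kRelu, &y_data, &batch_mean_data,
  //       &batch_var_data, &saved_mean_data, &saved_var_data,
  //       /*output_profile_result=*/nullptr);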
2515   virtual bool DoFusedBatchNormActivationForward(
2516       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2517       const DeviceMemory<float>& x_data,
2518       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2519       const DeviceMemory<float>& scale_data,
2520       const DeviceMemory<float>& offset_data, double epsilon,
2521       dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
2522       DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
2523       DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
2524       dnn::ProfileResult* output_profile_result) {
2525     return false;
2526   }
2527 
2528   virtual bool DoFusedBatchNormActivationForward(
2529       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2530       const DeviceMemory<Eigen::half>& x_data,
2531       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2532       const DeviceMemory<float>& scale_data,
2533       const DeviceMemory<float>& offset_data, double epsilon,
2534       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
2535       DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
2536       DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
2537       dnn::ProfileResult* output_profile_result) {
2538     return false;
2539   }
2540 
2541   // Enqueues a fused batchnorm+activation (training-bwd) operation onto the
2542   // stream.
2543   //
2544   // Arguments (all borrowed):
2545   //
2546   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2547   //  be enqueued onto.
2548   //
2549   //  y_act_backprop_descriptor: dimensions of the backprop input from the
2550   //  previous layer.
2551   //  y_act_backprop_data: device memory which contains the backprop input.
2552   //
2553   //  y_act_data: device memory which contains the actv-fwd output data.
2554   //
2555   //  activation_mode: actv-fwd type.
2556   //
2557   //  scale_offset_mean_variance_descriptor:
2558   //      dimensions of the scale/offset/mean/variance tensor.
2559   //  scale_data: device memory which contains the scale input.
2560   //  offset_data: device memory which contains the offset input.
2561   //  saved_mean_data: device memory which contains the saved mean from the
2562   //  fwd pass.
2563   //  saved_var_data: device memory which contains the saved variance from the fwd pass.
2564   //
2565   //  x_bn_backprop_data: device memory region in which to place the backprop
2566   //  data from this layer.
2567   //  scale_backprop_data: device memory in which to place the scale backprop output.
2568   //  offset_backprop_data: device memory in which to place the offset backprop output.
2569   //
2570   //  output_profile_result: the output profile result for this call.
2571   //         The profiling is only enabled when this is not nullptr.
2572   //
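  //
  // Illustrative sketch (not part of this interface): the backward pass pairs
  // the mean/variance saved by the forward call with the incoming activation
  // gradient:
  //
  //   bool launched = dnn->DoFusedBatchNormActivationBackward(
  //       stream, y_act_backprop_descriptor, y_act_backprop_data, y_act_data,
  //       dnn::ActivationMode::kRelu, x_bn_data,
  //       scale_offset_mean_variance_descriptor, scale_data, offset_data,
  //       saved_mean_data, saved_var_data, &x_bn_backprop_data,
  //       &scale_backprop_data, &offset_backprop_data,
  //       /*output_profile_result=*/nullptr);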
2573   virtual bool DoFusedBatchNormActivationBackward(
2574       Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
2575       const DeviceMemory<float>& y_act_backprop_data,
2576       const DeviceMemory<float>& y_act_data,
2577       dnn::ActivationMode activation_mode, const DeviceMemory<float>& x_bn_data,
2578       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2579       const DeviceMemory<float>& scale_data,
2580       const DeviceMemory<float>& offset_data,
2581       const DeviceMemory<float>& saved_mean_data,
2582       const DeviceMemory<float>& saved_var_data,
2583       DeviceMemory<float>* x_bn_backprop_data,
2584       DeviceMemory<float>* scale_backprop_data,
2585       DeviceMemory<float>* offset_backprop_data,
2586       dnn::ProfileResult* output_profile_result) {
2587     return false;
2588   }
2589 
2590   virtual bool DoFusedBatchNormActivationBackward(
2591       Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
2592       const DeviceMemory<Eigen::half>& y_act_backprop_data,
2593       const DeviceMemory<Eigen::half>& y_act_data,
2594       dnn::ActivationMode activation_mode,
2595       const DeviceMemory<Eigen::half>& x_bn_data,
2596       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2597       const DeviceMemory<float>& scale_data,
2598       const DeviceMemory<float>& offset_data,
2599       const DeviceMemory<float>& saved_mean_data,
2600       const DeviceMemory<float>& saved_var_data,
2601       DeviceMemory<Eigen::half>* x_bn_backprop_data,
2602       DeviceMemory<float>* scale_backprop_data,
2603       DeviceMemory<float>* offset_backprop_data,
2604       dnn::ProfileResult* output_profile_result) {
2605     return false;
2606   }
2607 
2608   // Notifies that a stream is being destroyed and should be invalidated from
2609   // any internal caching.  This exists to allow the CUDA implementation to
2610   // avoid redundant cudnnSetStream calls without risking problems when a stream
2611   // is destroyed and a new stream later created in the same memory.
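  //
  // Illustrative sketch (not part of this interface): a hypothetical backend
  // that caches a per-stream handle might override the hook to drop the
  // cached entry when the stream is torn down:
  //
  //   void NotifyStreamDestroyed(Stream* stream) override {
  //     absl::MutexLock lock(&mu_);          // hypothetical member mutex
  //     per_stream_handles_.erase(stream);   // hypothetical cache member
  //   }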
2612   virtual void NotifyStreamDestroyed(Stream* stream) {}
2613 
2614  protected:
2615   // Returns whether status is 'ok', and potentially logs the error.
2616   static bool IsStatusOk(const port::Status& status, bool report_error);
2617 
2618  private:
2619   virtual port::Status DoPrepareForConvolution(
2620       ConvolutionKind kind, DataType element_type, Stream* stream,
2621       const BatchDescriptor& batch_descriptor, DeviceMemoryBase input_data,
2622       const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data,
2623       const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
2624       const ConvolutionDescriptor& convolution_descriptor,
2625       const AlgorithmConfig& algorithm_config,
2626       ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
2627       DeviceMemory<uint8>* scratch_memory) {
2628     *algorithm_desc = {};
2629     *scratch_memory = {};
2630     return ::tensorflow::OkStatus();
2631   }
2632 
2633   virtual port::Status DoPrepareForCtcLoss(
2634       Stream* stream, DataType element_type,
2635       const RnnStateTensorDescriptor& probs_desc,
2636       const RnnStateTensorDescriptor& grads_desc,
2637       absl::Span<const int> labels_data,
2638       absl::Span<const int> labels_lengths_data,
2639       absl::Span<const int> input_lengths_data,
2640       ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
2641       int* ctc_loss_algo_id) {
2642     *scratch_memory = {};
2643     return ::tensorflow::OkStatus();
2644   }
2645 
2646   SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
2647 };
2648 
2649 }  // namespace dnn
2650 }  // namespace stream_executor
2651 
2652 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DNN_H_
2653