1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Neural Net operation support for StreamExecutor instances.
17 //
18 // This is an abstract interface for a platform to optionally support common
19 // neural net operations; it accommodates implementations such as the cudnn
20 // library operations.
21 
22 #ifndef TENSORFLOW_STREAM_EXECUTOR_DNN_H_
23 #define TENSORFLOW_STREAM_EXECUTOR_DNN_H_
24 
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <limits>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/stream_executor/data_type.h"
#include "tensorflow/stream_executor/device_description.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dnn.pb.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
42 
43 namespace Eigen {
44 struct half;
45 }  // namespace Eigen
46 
47 namespace stream_executor {
48 
49 class HostBuffer;
50 class Stream;
51 class ScratchAllocator;
52 
53 namespace dnn {
54 
// Selects one spatial dimension. The numeric value is an offset counted from
// the innermost (last) dimension: X is the last one, Y the one before it, and
// Z the one before that.
enum class DimIndex : int { X = 0, Y = 1, Z = 2 };
61 
// Returns a reordered copy of `input`. Judging by the parameter names, the
// entries are permuted from the dimension order of layout `from` into the
// dimension order of layout `to`. The definition is out of line; confirm the
// exact expectations on `input` (e.g. its length) there.
std::vector<int64> ReorderDims(const std::vector<int64>& input,
                               const DataLayout& from, const DataLayout& to);
65 
66 // Helper functions to make methods more readable.
GetDim(absl::Span<const int64> data,DimIndex dim)67 inline int64 GetDim(absl::Span<const int64> data, DimIndex dim) {
68   return data.rbegin()[static_cast<int64>(dim)];
69 }
70 
SetDim(absl::Span<int64> data,DimIndex dim,int64_t value)71 inline void SetDim(absl::Span<int64> data, DimIndex dim, int64_t value) {
72   data.rbegin()[static_cast<int64>(dim)] = value;
73 }
74 
SetDim(std::vector<int64> * data,DimIndex dim,int64_t value)75 inline void SetDim(std::vector<int64>* data, DimIndex dim, int64_t value) {
76   return SetDim(absl::MakeSpan(*data), dim, value);
77 }
78 
79 // int64 is not the same type as tensorflow::protobuf_int64 in open-source. This
80 // wrapper function gives an int64 array slice view of a repeated int64 protobuf
81 // field.
82 //
83 // T should be a protobuf RepeatedField.
84 template <typename T>
AsInt64Slice(const T & repeated_field)85 inline absl::Span<const int64> AsInt64Slice(const T& repeated_field) {
86   using data_ty =
87       typename std::remove_reference<decltype(*repeated_field.data())>::type;
88   static_assert(std::is_integral<data_ty>::value &&
89                     std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
90                 "repeated_field.data() must return a pointer to a signed "
91                 "64-bit integer type.");
92   return absl::Span<const int64>(
93       reinterpret_cast<const int64*>(repeated_field.data()),
94       repeated_field.size());
95 }
96 template <typename T>
AsInt64Slice(T * repeated_field)97 inline absl::Span<int64> AsInt64Slice(T* repeated_field) {
98   using data_ty =
99       typename std::remove_reference<decltype(*repeated_field->data())>::type;
100   static_assert(std::is_integral<data_ty>::value &&
101                     std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
102                 "repeated_field->data() must return a pointer to a signed "
103                 "64-bit integer type.");
104   return absl::Span<int64>(
105       reinterpret_cast<int64*>(repeated_field->mutable_data()),
106       repeated_field->size());
107 }
108 
// Returns a string representation of the given data layout. Defined out of
// line; the exact format is determined there.
std::string DataLayoutString(DataLayout layout);
111 
// Quantization applied to activations described by a BatchDescriptor.
// Each enumerator's value equals the element width in bytes (8 bits -> 1,
// 16 bits -> 2, 32 bits -> 4).
enum class QuantizedActivationMode {
  k8Bit = 1,   // 1 byte per element.
  k16Bit = 2,  // 2 bytes per element.
  k32Bit = 4,  // 4 bytes per element.
};
118 
// Cell type of an RNN model.
enum class RnnMode {
  kRnnRelu = 0,
  kRnnTanh = 1,
  kRnnLstm = 2,
  kRnnGru = 3,
};
126 
// Input mode of an RNN model: whether a linear transformation is applied
// between the input state and the hidden state of the first layer.
enum class RnnInputMode {
  kRnnLinearSkip = 0,
  kRnnSkipInput = 1,
};
133 
// Number of directions an RNN model runs in. In the bidirectional case the
// input states and the output sequence carry data for both directions.
enum class RnnDirectionMode {
  kRnnUnidirectional = 0,
  kRnnBidirectional = 1,
};
141 
// Layout used by DepthToSpace and SpaceToDepth: the write layout for depth to
// space and the read layout for space to depth. Dimensions are listed from
// most-major to most-minor. For DepthHeightWidth, DepthToSpace reads the D*M^2
// values and writes them to the output patch varying width fastest, then
// height, then depth; as a C array this is [depth][height][width]. See the
// DepthToSpace comment for more information.
enum class DepthToSpaceLayout {
  DepthHeightWidth,
};
150 
151 // Specifies the descriptor for a RNN model.
152 //
153 // An example use case:
154 //   * The user first creates a model through createRnnDescriptor.
155 //   * The user queries the size of the underlying opaque parameter buffer.
156 //   * The user creates and initializes a parameter buffer of the proper size.
157 //   * The user runs forward and backward operations using this RNN descriptor.
158 //   * Once a while, user queries maintainable weights and bias regions from
159 //       the underlying parameter buffer. They are more likely to be forward
160 //       compatible and should used in saving and restoring a model.
161 //   * The user releases the RNN descriptor when the model is no longer in use.
162 class RnnDescriptor {
163  public:
164   struct ParamsRegion {
165     int64 offset;
166     int64 size;
167   };
168   typedef std::vector<ParamsRegion> ParamsRegions;
~RnnDescriptor()169   virtual ~RnnDescriptor() {}
ParamsSizeInBytes()170   virtual int64 ParamsSizeInBytes() const { return -1; }
ParamsWeightRegions()171   virtual ParamsRegions ParamsWeightRegions() const { return ParamsRegions(); }
ParamsBiasRegions()172   virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
173 };
174 
// Describes the sequence of an RNN model.
//
// Callers own instances and must destroy them when they are no longer needed.
// Derived implementations presumably release their underlying descriptors in
// their destructors; the base destructor itself does nothing.
class RnnSequenceTensorDescriptor {
 public:
  virtual ~RnnSequenceTensorDescriptor() = default;
};
183 
// Describes either the input state or the hidden state of an RNN model.
//
// Callers own instances and must destroy them when they are no longer needed.
// Derived implementations presumably release their underlying descriptors in
// their destructors; the base destructor itself does nothing.
class RnnStateTensorDescriptor {
 public:
  virtual ~RnnStateTensorDescriptor() = default;
};
192 
// Execution plan for a convolution. The base class returns placeholder
// answers; platform-specific subclasses presumably override them.
class ConvolveExecutionPlan {
 public:
  virtual ~ConvolveExecutionPlan() = default;

  // Plan identifier; "unknown" unless overridden.
  virtual std::string getTag() { return "unknown"; }

  // Opaque backend descriptor; nullptr unless overridden.
  virtual void* get_raw_desc() { return nullptr; }

  // Workspace requirement of the plan; -1 unless overridden.
  virtual int64_t getWorkspaceSize() { return -1; }
};
201 
// Returns a string representation of the given quantization mode. Defined out
// of line; the exact format is determined there.
std::string QuantizedActivationModeString(QuantizedActivationMode mode);
204 
205 // Describes the dimensions that a layer consumes/produces.
206 //
207 // This is a matrix (height, width), its "depth" (feature_map_count),
208 // how many of these matrices are present (count),
209 // and the maximum and minimum values expected in the matrix (value_max,
210 // value_min).
211 // If input is quantized, all values greater
212 // than value_max will be clipped to value_max and all values less than
213 // value_min will be clipped to value_min.
214 // When quantized output is dequantized no value will be greater than
215 // value_max or less than value_min.
216 //
217 // Uses the named argument construction form:
218 //
219 //  auto input_batch_dimensions =
220 //      BatchDescriptor().set_count(42).set_feature_map_count(7)...
221 //
222 // Details:
223 //
224 // For a convolutional layer, a single inference takes a 3-dimensional matrix
225 // of input and produces a 3-dimensional matrix of output. We call the three
226 // dimensions height, width and feature_map_count, where for an image, the
227 // height and width correspond to the Y and X pixel indices, respectively, and
228 // the feature_map_count corresponds to the RGB dimension of the input data.
229 // Then the count indicates how many 3D matrices are being presented to be
230 // processed at once; this corresponds to the neural network concept of
231 // minibatch size.
232 //
233 // For a fully connected layer, it's better to put the nodes of the layer in
// the feature_map_count, and leave the height and width as degenerate (== 1).
235 // Count indicates how many input vectors (degenerate 3D matrices) are to be
236 // processed.
237 //
238 // If unspecified, value_max and value_min default to 0.0.
239 // If value_max == value_min the Stream will attempt to derive valid values -
240 // for example the output of Relu6 activation will always be in the range
241 // [0.0, 6.0].
242 //
243 // If unspecified, layout defaults to kYXDepthBatch.
244 class BatchDescriptor {
245  public:
246   // Creates a "blank" batch descriptor, which should be initialized via the
247   // named argument helpers.
248   BatchDescriptor();
249   explicit BatchDescriptor(int ndims);
250 
251   // Clones values from 'other' for initialization.
252   void CloneFrom(const BatchDescriptor& other);
253 
254   std::string ToString() const;
255   std::string ToShortString() const;
256 
257   // Pre-condition:
258   //   value_max_ == 0
259   //   value_min_ == 0
260   //   quantized_activation_mode_ == QuantizedActivationMode::k8Bit
261   TensorDescriptorProto ToProto(DataType data_type) const;
262 
263   // Accessors.
count()264   int64 count() const { return tensor_.dimensions(0); }
feature_map_count()265   int64 feature_map_count() const { return tensor_.dimensions(1); }
height()266   int64 height() const { return GetDim(spatial_size(), DimIndex::Y); }
width()267   int64 width() const { return GetDim(spatial_size(), DimIndex::X); }
spatial_dim(DimIndex dim)268   int64 spatial_dim(DimIndex dim) const { return GetDim(spatial_size(), dim); }
ndims()269   int ndims() const { return spatial_size().size(); }
value_max()270   float value_max() const { return value_max_; }
value_min()271   float value_min() const { return value_min_; }
layout()272   DataLayout layout() const { return tensor_.data_layout(); }
quantized_activation_mode()273   QuantizedActivationMode quantized_activation_mode() const {
274     return quantized_activation_mode_;
275   }
276   // Full dimensions of the underlying data, ordered according to a specific
277   // layout.
278   std::vector<int64> full_dims(const DataLayout& layout) const;
279 
280   // Full strides of the underlying data, ordered according to a specific
281   // layout.
282   std::vector<int64> full_strides(const DataLayout& layout) const;
283 
284   // Vectorized dimensions where users can specify the dimension that the number
285   // of dimensions is reported rather than the full number of elements.
286   std::vector<int64> vectorized_dims(const DataLayout& layout, int vector_size,
287                                      int vector_dim) const;
288 
289   // Vectorized strides correspond to the vectorized_dims.
290   std::vector<int64> vectorized_strides(const DataLayout& layout,
291                                         int vector_size, int vector_dim) const;
292 
293   // Named-argument helpers for avoiding user error during construction.
set_count(int64_t value)294   BatchDescriptor& set_count(int64_t value) {
295     tensor_.set_dimensions(0, value);
296     return *this;
297   }
set_feature_map_count(int64_t value)298   BatchDescriptor& set_feature_map_count(int64_t value) {
299     tensor_.set_dimensions(1, value);
300     return *this;
301   }
set_height(int64_t value)302   BatchDescriptor& set_height(int64_t value) {
303     SetDim(spatial_size(), DimIndex::Y, value);
304     return *this;
305   }
set_width(int64_t value)306   BatchDescriptor& set_width(int64_t value) {
307     SetDim(spatial_size(), DimIndex::X, value);
308     return *this;
309   }
set_spatial_dim(DimIndex dim,int64_t value)310   BatchDescriptor& set_spatial_dim(DimIndex dim, int64_t value) {
311     SetDim(spatial_size(), dim, value);
312     return *this;
313   }
set_value_max(float value)314   BatchDescriptor& set_value_max(float value) {
315     value_max_ = value;
316     return *this;
317   }
set_value_min(float value)318   BatchDescriptor& set_value_min(float value) {
319     value_min_ = value;
320     return *this;
321   }
set_layout(DataLayout layout)322   BatchDescriptor& set_layout(DataLayout layout) {
323     tensor_.set_data_layout(layout);
324     return *this;
325   }
set_quantized_activation_mode(QuantizedActivationMode quantized_activation_mode)326   BatchDescriptor& set_quantized_activation_mode(
327       QuantizedActivationMode quantized_activation_mode) {
328     quantized_activation_mode_ = quantized_activation_mode;
329     return *this;
330   }
331 
332   // Return the number of nodes in a single feature map.
333   int64 NodesPerFeatureMap() const;
334 
335   // Return the number of nodes across all feature maps. Note that this is not
336   // affected by the batch count.
337   int64 NodesAcrossFeatureMaps() const;
338 
339   // Returns the number of elements (e.g. RGB pixel values) required to hold a
340   // given batch descriptor, given a no-padding assumption. Note that this is
341   // affected by the batch count.
342   int64 ElementCount() const;
343 
344   // Return the number of weights required to fully connect a layer with
345   // dimensions given by the 'input' descriptor with a layer with dimensions
346   // given by the 'output' descriptor.
347   static int64 FullyConnectedWeightCount(const BatchDescriptor& input,
348                                          const BatchDescriptor& output);
349 
350   // Return the number of biases required to fully connect to an output layer
351   // with dimensions given the 'output' descriptor.
352   static int64 FullyConnectedBiasCount(const BatchDescriptor& output);
353 
354   // Return a BatchDescriptor for the output of a depth concatenation
355   // with the given input descriptors. The inputs should have the same
356   // dimensions, except possibly for feature_map_count(), though this
357   // function does not verify that.
358   static BatchDescriptor DepthConcatenateOutputDescriptor(
359       port::ArraySlice<dnn::BatchDescriptor> inputs);
360 
361  private:
spatial_size()362   absl::Span<const int64> spatial_size() const {
363     return AsInt64Slice(tensor_.dimensions()).subspan(2);
364   }
365 
spatial_size()366   absl::Span<int64> spatial_size() {
367     return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
368   }
369 
370   TensorDescriptorProto tensor_;
371   float value_max_;
372   float value_min_;
373   QuantizedActivationMode quantized_activation_mode_;
374 };
375 
// Returns a string representation of the given filter layout. Defined out of
// line; the exact format is determined there.
std::string FilterLayoutString(FilterLayout layout);
378 
379 // Describes a filter for the convolution. This is the "window" from
380 // height-by-width patches of each of the feature maps in the input layer to the
381 // cells within the output feature map.
382 //
383 // Uses the named argument construction form:
384 //
385 //  FilterDescriptor filter_dimensions;
386 //  filter_dimensions
387 //    .set_output_feature_map_count(42)
388 //    .set_input_feature_map_count(7)
389 //    ...
390 //
391 // Arguments:
392 // - output_feature_map_count: number of feature maps in the output layer.
393 // - input_feature_map_count: number of feature maps in the input layer (from
394 //      which the filter patch is taken).
395 // - input_filter_height: "height" number of neurons used in the sliding window
396 //      over the input layer.
397 // - input_filter_width: "width" number of neurons used in the sliding window
398 //      over the input layer.
399 //
400 // Sometimes names like "filter input height" are referred to by synonymous
401 // terminology, such as "kernel y size".
402 //
403 // If unspecified, layout defaults to kOutputInputYX.
404 class FilterDescriptor {
405  public:
406   // By default construction, all dimensions are set to zero, so they should all
407   // be populated by the user via the named-argument helpers below. (See class
408   // comment for details.)
409   FilterDescriptor();
410   explicit FilterDescriptor(int ndims);
411   ~FilterDescriptor();
412 
413   // Named-argument helpers for avoiding user error during construction.
set_output_feature_map_count(int64_t value)414   FilterDescriptor& set_output_feature_map_count(int64_t value) {
415     tensor_.set_dimensions(0, value);
416     return *this;
417   }
set_input_feature_map_count(int64_t value)418   FilterDescriptor& set_input_feature_map_count(int64_t value) {
419     tensor_.set_dimensions(1, value);
420     return *this;
421   }
set_input_filter_height(int64_t value)422   FilterDescriptor& set_input_filter_height(int64_t value) {
423     SetDim(input_filter_dims(), DimIndex::Y, value);
424     return *this;
425   }
set_input_filter_width(int64_t value)426   FilterDescriptor& set_input_filter_width(int64_t value) {
427     SetDim(input_filter_dims(), DimIndex::X, value);
428     return *this;
429   }
set_layout(FilterLayout layout)430   FilterDescriptor& set_layout(FilterLayout layout) {
431     tensor_.set_filter_layout(layout);
432     return *this;
433   }
set_spatial_dim(DimIndex dim,int64_t value)434   FilterDescriptor& set_spatial_dim(DimIndex dim, int64_t value) {
435     SetDim(input_filter_dims(), dim, value);
436     return *this;
437   }
ndims()438   int ndims() const { return input_filter_dims().size(); }
439 
440   void CloneFrom(const FilterDescriptor& other);
441 
442   std::string ToString() const;
443   std::string ToShortString() const;
444   TensorDescriptorProto ToProto(DataType data_type) const;
445 
446   // Returns the number of weights required as parameters for a convolution
447   // using this filter descriptor.
448   int64 ComputeWeightCount() const;
449 
450   // Returns the number of biases required as parameters for a convolution
451   // using this filter descriptor.
bias_count()452   int64 bias_count() const { return output_feature_map_count(); }
453 
output_feature_map_count()454   int64 output_feature_map_count() const { return tensor_.dimensions(0); }
input_feature_map_count()455   int64 input_feature_map_count() const { return tensor_.dimensions(1); }
input_filter_height()456   int64 input_filter_height() const {
457     return GetDim(input_filter_dims(), DimIndex::Y);
458   }
input_filter_width()459   int64 input_filter_width() const {
460     return GetDim(input_filter_dims(), DimIndex::X);
461   }
input_filter_dim(DimIndex dim)462   int64 input_filter_dim(DimIndex dim) const {
463     return GetDim(input_filter_dims(), dim);
464   }
465 
layout()466   FilterLayout layout() const { return tensor_.filter_layout(); }
467 
input_filter_dims()468   absl::Span<const int64> input_filter_dims() const {
469     return AsInt64Slice(tensor_.dimensions()).subspan(2);
470   }
471 
472   // Full dimensions of the underlying filter,
473   // ordered according to a specific layout.
474   std::vector<int64> full_dims(const FilterLayout& layout) const;
475 
476   // Full strides of the underlying filter,
477   // ordered according to a specific layout.
478   std::vector<int64> full_strides(const FilterLayout& layout) const;
479 
480   // Vectorized dimensions where users can specify the dimension that the number
481   // of dimensions is reported rather than the full number of elements.
482   std::vector<int64> vectorized_dims(const FilterLayout& layout,
483                                      int vector_size, int vector_dim) const;
484 
485   // Vectorized strides correspond to the vectorized_dims.
486   std::vector<int64> vectorized_strides(const FilterLayout& layout,
487                                         int vector_size, int vector_dim) const;
488 
489  private:
input_filter_dims()490   absl::Span<int64> input_filter_dims() {
491     return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
492   }
493 
494   TensorDescriptorProto tensor_;
495 };
496 
497 // Describes how padding should be aligned when the total number of pad
498 // elements is odd.
499 enum class PadAlignment : int64 {
500   kDefault = 0,        // default padding for the device.
501   kCudnnPadding,       // cuDNN padding - prefer to pad at the start.
502   kTensorFlowPadding,  // TensorFlow padding - prefer to pad at the end.
503 };
504 
// Returns a string representation of the given padding alignment. Defined out
// of line; the exact format is determined there.
std::string PadAlignmentString(PadAlignment alignment);

// Prints `alignment` to `str`. Needed so that CHECK_EQ can be used between two
// PadAlignment values (it must be able to print them on failure).
std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
510 
511 // Describes a convolution.
512 //
513 // Uses the named argument construction form:
514 //
515 //  ConvolutionDescriptor convolution_dimensions;
516 //  convolution_dimensions
517 //    .set_vertical_filter_stride(2)
518 //    .set_horizontal_filter_stride(2)
519 //    ...
520 //
521 // Arguments:
522 // - zero_padding_height: padding of the "y dimension" of the input data. Note
523 //    that this is different from the height of the filter.
524 // - zero_padding_width: analogous to the height above, but in the "x
525 //    dimension".
526 // - vertical_filter_stride: the convolution slides a 2-dimensional window of
527 //    filter-height-by-filter-width over the input layer -- the center of that
528 //    window is moved in the "y dimension" according to this stride value.
529 // - horizontal_filter_stride: analogous to the vertical stride above, but in
530 //    the "x dimension".
531 // - vertical_dilation_rate: there will be (vertical_dilation_rate - 1) skipped
532 //   cells between each filter element in the "y dimension".
533 // - horizontal_dilation_rate: there will be (horizontal_dilation_rate - 1)
534 //   skipped cells between each filter element in the "x dimension".
// - convolution_not_crosscorr: By default (convolution_not_crosscorr == false),
536 //   we perform cross correlation rather than convolution. With the flag set,
537 //   we perform convolution. Convolution and cross correlation are related by
538 //   rotating the filter by 180 degrees (or equivalently flipping all spatial
539 //   dimensions).
540 class ConvolutionDescriptor {
541  public:
542   // By default construction, there is no zero-padding and the filter stride is
543   // 1x1 (centering the filter on every cell in the input layer's
544   // width-by-height area).
545   ConvolutionDescriptor();
546   explicit ConvolutionDescriptor(int ndims);
547   ~ConvolutionDescriptor();
548 
549   std::string ToString() const;
550   std::string ToShortString() const;
ToProto()551   ConvolutionDescriptorProto ToProto() const { return proto_; }
552 
set_zero_padding_height(int64_t value)553   ConvolutionDescriptor& set_zero_padding_height(int64_t value) {
554     SetDim(padding(), DimIndex::Y, value);
555     return *this;
556   }
set_zero_padding_width(int64_t value)557   ConvolutionDescriptor& set_zero_padding_width(int64_t value) {
558     SetDim(padding(), DimIndex::X, value);
559     return *this;
560   }
set_zero_padding(DimIndex dim,int64_t value)561   ConvolutionDescriptor& set_zero_padding(DimIndex dim, int64_t value) {
562     SetDim(padding(), dim, value);
563     return *this;
564   }
set_vertical_filter_stride(int64_t value)565   ConvolutionDescriptor& set_vertical_filter_stride(int64_t value) {
566     SetDim(strides(), DimIndex::Y, value);
567     return *this;
568   }
set_horizontal_filter_stride(int64_t value)569   ConvolutionDescriptor& set_horizontal_filter_stride(int64_t value) {
570     SetDim(strides(), DimIndex::X, value);
571     return *this;
572   }
set_filter_stride(DimIndex dim,int64_t value)573   ConvolutionDescriptor& set_filter_stride(DimIndex dim, int64_t value) {
574     SetDim(strides(), dim, value);
575     return *this;
576   }
set_vertical_dilation_rate(int64_t value)577   ConvolutionDescriptor& set_vertical_dilation_rate(int64_t value) {
578     SetDim(dilations(), DimIndex::Y, value);
579     return *this;
580   }
set_horizontal_dilation_rate(int64_t value)581   ConvolutionDescriptor& set_horizontal_dilation_rate(int64_t value) {
582     SetDim(dilations(), DimIndex::X, value);
583     return *this;
584   }
set_dilation_rate(DimIndex dim,int64_t value)585   ConvolutionDescriptor& set_dilation_rate(DimIndex dim, int64_t value) {
586     SetDim(dilations(), dim, value);
587     return *this;
588   }
set_group_count(int group_count)589   ConvolutionDescriptor& set_group_count(int group_count) {
590     proto_.set_group_count(group_count);
591     return *this;
592   }
set_convolution_not_crosscorr(bool conv)593   ConvolutionDescriptor& set_convolution_not_crosscorr(bool conv) {
594     proto_.set_convolution_mode(conv ? ConvolutionMode::CONVOLUTION
595                                      : ConvolutionMode::CROSS_CORRELATION);
596     return *this;
597   }
set_name(const std::string & name)598   ConvolutionDescriptor& set_name(const std::string& name) {
599     proto_.set_name(name);
600     return *this;
601   }
zero_padding_height()602   int64 zero_padding_height() const { return GetDim(padding(), DimIndex::Y); }
zero_padding_width()603   int64 zero_padding_width() const { return GetDim(padding(), DimIndex::X); }
vertical_filter_stride()604   int64 vertical_filter_stride() const {
605     return GetDim(strides(), DimIndex::Y);
606   }
horizontal_filter_stride()607   int64 horizontal_filter_stride() const {
608     return GetDim(strides(), DimIndex::X);
609   }
vertical_dilation_rate()610   int64 vertical_dilation_rate() const {
611     return GetDim(dilations(), DimIndex::Y);
612   }
horizontal_dilation_rate()613   int64 horizontal_dilation_rate() const {
614     return GetDim(dilations(), DimIndex::X);
615   }
616 
zero_padding(DimIndex dim)617   int zero_padding(DimIndex dim) const { return GetDim(padding(), dim); }
filter_stride(DimIndex dim)618   int filter_stride(DimIndex dim) const { return GetDim(strides(), dim); }
dilation_rate(DimIndex dim)619   int dilation_rate(DimIndex dim) const { return GetDim(dilations(), dim); }
620   // TODO(timshen): remove this function. No users of this class is setting a
621   // non-default pad alignment.
pad_alignment()622   PadAlignment pad_alignment() const { return PadAlignment::kDefault; }
group_count()623   int group_count() const { return proto_.group_count(); }
ndims()624   int ndims() const { return padding().size(); }
convolution_not_crosscorr()625   bool convolution_not_crosscorr() const {
626     return proto_.convolution_mode() == ConvolutionMode::CONVOLUTION;
627   }
628 
strides()629   absl::Span<const int64> strides() const {
630     return AsInt64Slice(proto_.strides());
631   }
632 
dilations()633   absl::Span<const int64> dilations() const {
634     return AsInt64Slice(proto_.dilations());
635   }
636 
padding()637   absl::Span<const int64> padding() const {
638     return AsInt64Slice(proto_.paddings());
639   }
640 
name()641   std::string name() const { return proto_.name(); }
642 
643  private:
strides()644   absl::Span<int64> strides() { return AsInt64Slice(proto_.mutable_strides()); }
645 
dilations()646   absl::Span<int64> dilations() {
647     return AsInt64Slice(proto_.mutable_dilations());
648   }
649 
padding()650   absl::Span<int64> padding() {
651     return AsInt64Slice(proto_.mutable_paddings());
652   }
653 
654   ConvolutionDescriptorProto proto_;
655 
656   // TODO(leary) cudnn provides these fields, but need to characterize what
657   // their effect is -- they may be boolean rather than integral.
658   // int64 upscale_input_x;
659   // int64 upscale_input_y;
660 };
661 
662 // A patch of values in the input can be pooled via either a max or an average
663 // operation.
664 // Specify int64 so there's no padding in PoolingDescriptor.
665 enum class PoolingMode : int64 {
666   kMaximum,
667   kAverage,
668 };
669 
670 // Specify the dimension in which to concatenate inputs in space.
671 // Specify int64 so there's no padding in SpaceConcatenateMode.
672 enum class SpaceConcatenateMode : int64 {
673   XDirection,
674   YDirection,
675 };
676 
// Returns a short name for the pooling mode, e.g. "Avg". Defined out of line;
// the exact names are determined there.
std::string ShortPoolingModeString(PoolingMode mode);
679 
680 // Describes a pooling operation to be enqueued onto a stream via a platform's
681 // DnnSupport.
682 //
683 // TODO(broune): describe how padding works and what happens if the
684 // window height/width is not divisible by the vertical/horizontal
685 // stride.
686 //
687 // Arguments:
688 //  pooling_mode: pooling operator to use on the input patch
689 //  window_height: height of input window
690 //  window_width: width of input window
691 //  vertical_stride: vertical delta for center of the input patch
692 //  horizontal_stride: horizontal delta for center of the input patch
693 class PoolingDescriptor {
694  public:
695   PoolingDescriptor();
696   explicit PoolingDescriptor(int ndims);
697 
set_pooling_mode(PoolingMode value)698   PoolingDescriptor& set_pooling_mode(PoolingMode value) {
699     mode_ = value;
700     return *this;
701   }
set_window_height(int64_t value)702   PoolingDescriptor& set_window_height(int64_t value) {
703     SetDim(&window_, DimIndex::Y, value);
704     return *this;
705   }
set_window_width(int64_t value)706   PoolingDescriptor& set_window_width(int64_t value) {
707     SetDim(&window_, DimIndex::X, value);
708     return *this;
709   }
set_window(DimIndex dim,int64_t value)710   PoolingDescriptor& set_window(DimIndex dim, int64_t value) {
711     SetDim(&window_, dim, value);
712     return *this;
713   }
set_vertical_padding(int64_t value)714   PoolingDescriptor& set_vertical_padding(int64_t value) {
715     SetDim(&padding_, DimIndex::Y, value);
716     return *this;
717   }
set_horizontal_padding(int64_t value)718   PoolingDescriptor& set_horizontal_padding(int64_t value) {
719     SetDim(&padding_, DimIndex::X, value);
720     return *this;
721   }
set_padding(DimIndex dim,int64_t value)722   PoolingDescriptor& set_padding(DimIndex dim, int64_t value) {
723     SetDim(&padding_, dim, value);
724     return *this;
725   }
set_vertical_stride(int64_t value)726   PoolingDescriptor& set_vertical_stride(int64_t value) {
727     SetDim(&strides_, DimIndex::Y, value);
728     return *this;
729   }
set_horizontal_stride(int64_t value)730   PoolingDescriptor& set_horizontal_stride(int64_t value) {
731     SetDim(&strides_, DimIndex::X, value);
732     return *this;
733   }
set_stride(DimIndex dim,int64_t value)734   PoolingDescriptor& set_stride(DimIndex dim, int64_t value) {
735     SetDim(&strides_, dim, value);
736     return *this;
737   }
set_propagate_nans(bool value)738   PoolingDescriptor& set_propagate_nans(bool value) {
739     propagate_nans_ = value;
740     return *this;
741   }
set_name(const std::string & name)742   PoolingDescriptor& set_name(const std::string& name) {
743     name_ = name;
744     return *this;
745   }
746 
ndims()747   int ndims() const { return ndims_; }
748   void CloneFrom(const PoolingDescriptor& other);
749 
750   std::string ToString() const;
751   std::string ToShortString() const;
752 
mode()753   PoolingMode mode() const { return mode_; }
window_height()754   int64 window_height() const { return GetDim(window_, DimIndex::Y); }
window_width()755   int64 window_width() const { return GetDim(window_, DimIndex::X); }
window(DimIndex dim)756   int64 window(DimIndex dim) const { return GetDim(window_, dim); }
vertical_padding()757   int64 vertical_padding() const { return GetDim(padding_, DimIndex::Y); }
horizontal_padding()758   int64 horizontal_padding() const { return GetDim(padding_, DimIndex::X); }
padding(DimIndex dim)759   int64 padding(DimIndex dim) const { return GetDim(padding_, dim); }
vertical_stride()760   int64 vertical_stride() const { return GetDim(strides_, DimIndex::Y); }
horizontal_stride()761   int64 horizontal_stride() const { return GetDim(strides_, DimIndex::X); }
stride(DimIndex dim)762   int64 stride(DimIndex dim) const { return GetDim(strides_, dim); }
window()763   absl::Span<const int64> window() const { return window_; }
padding()764   absl::Span<const int64> padding() const { return padding_; }
strides()765   absl::Span<const int64> strides() const { return strides_; }
propagate_nans()766   bool propagate_nans() const { return propagate_nans_; }
name()767   std::string name() const { return name_; }
768 
769  private:
770   PoolingMode mode_;
771   int ndims_;
772   bool propagate_nans_;
773   std::string name_;  // Name as in Tensorflow NodeDef, for debugging purposes.
774 
775   // Stored as: ..., y, x.
776   std::vector<int64> window_;
777   std::vector<int64> padding_;
778   std::vector<int64> strides_;
779 };
780 
781 // Collects parameters for DNN algorithms
782 class AlgorithmDesc {
783  public:
784   typedef int64 Index;
785   typedef std::string Tag;
AlgorithmDesc()786   AlgorithmDesc() : AlgorithmDesc(0, false) {}
AlgorithmDesc(Index a,bool use_tensor_ops)787   AlgorithmDesc(Index a, bool use_tensor_ops) {
788     proto_.set_algo_id(a);
789     proto_.set_math_type(use_tensor_ops ? AlgorithmProto::TENSOR_OP_MATH
790                                         : AlgorithmProto::DEFAULT_MATH);
791   }
AlgorithmDesc(Tag a,void * b)792   AlgorithmDesc(Tag a, void* b) {
793     proto_.set_exec_plan_id(a);
794     exec_plan_desc_ = b;
795   }
IsExecutionPlan()796   bool IsExecutionPlan() const { return exec_plan_id() != ""; }
tensor_ops_enabled()797   bool tensor_ops_enabled() const {
798     return proto_.math_type() == AlgorithmProto::TENSOR_OP_MATH;
799   }
algo_id()800   Index algo_id() const { return proto_.algo_id(); }
exec_plan_id()801   Tag exec_plan_id() const { return proto_.exec_plan_id(); }
exec_plan_desc()802   void* exec_plan_desc() const { return exec_plan_desc_; }
803   bool operator==(const AlgorithmDesc& other) const {
804     if (IsExecutionPlan()) {
805       return exec_plan_id() == other.exec_plan_id();
806     }
807     return algo_id() == other.algo_id() &&
808            tensor_ops_enabled() == other.tensor_ops_enabled();
809   }
810   uint64 hash() const;
811 
ToProto()812   AlgorithmProto ToProto() const { return proto_; }
813 
814   std::string ToString() const;
815 
816  private:
817   AlgorithmProto proto_;
818   // We keep a pointer for the execution plan if cuDNN v8 is used. Note,
819   // AlgorithmDesc doesn't own it.
820   void* exec_plan_desc_;
821 };
822 
823 // Describes the result from a perf experiment.
824 //
825 // Arguments:
826 //  algorithm: returns the exact algorithm that was used.
827 //  elapsed_time_in_ms: returns the measured elapsed time in milliseconds.
828 class ProfileResult {
829  public:
is_valid()830   bool is_valid() const {
831     return algorithm_.has_value() &&
832            elapsed_time_in_ms() != std::numeric_limits<float>::max();
833   }
834 
algorithm()835   AlgorithmDesc algorithm() const { return *algorithm_; }
set_algorithm(AlgorithmDesc val)836   void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
837 
elapsed_time_in_ms()838   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
set_elapsed_time_in_ms(float val)839   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
840 
scratch_size()841   size_t scratch_size() const { return scratch_size_; }
set_scratch_size(size_t val)842   void set_scratch_size(size_t val) { scratch_size_ = val; }
843 
844  private:
845   absl::optional<AlgorithmDesc> algorithm_;
846   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
847   // The scratch size algorithm_ requires. Currently it's only populated by
848   // convolutions.
849   size_t scratch_size_ = 0;
850 };
851 
// Describes the configuration for the algorithms that will be used.
//
// Arguments:
//  algorithm: the primary algorithm that should be used.
//  algorithm_no_scratch: a secondary algorithm that should be used if the
//    allocation for the scratch memory fails.
//  scratch_size: the size of scratch memory, in bytes, needed by the
//    algorithm used.
//
// On the CUDA platform with the cuDNN library, algorithm and
// algorithm_no_scratch are used. On the ROCm platform with the MIOpen library,
// algorithm and scratch_size are used. The major difference between the two
// platforms is whether it's possible to get an algorithm without scratch
// memory. On CUDA + cuDNN it's possible, and algorithm_no_scratch can be used
// to track such an algorithm, whereas on ROCm + MIOpen there is no guarantee
// of getting one without scratch memory, so the scratch_size field is used to
// track the required size instead.
868 class AlgorithmConfig {
869  public:
AlgorithmConfig()870   AlgorithmConfig() {}
AlgorithmConfig(AlgorithmDesc algorithm)871   explicit AlgorithmConfig(AlgorithmDesc algorithm) : algorithm_(algorithm) {}
AlgorithmConfig(AlgorithmDesc algorithm,size_t scratch_size)872   AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size)
873       : algorithm_(algorithm), scratch_size_(scratch_size) {}
AlgorithmConfig(AlgorithmDesc algorithm,AlgorithmDesc algorithm_no_scratch)874   AlgorithmConfig(AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch)
875       : algorithm_(algorithm), algorithm_no_scratch_(algorithm_no_scratch) {}
AlgorithmConfig(AlgorithmDesc algorithm,size_t scratch_size,AlgorithmDesc algorithm_no_scratch)876   AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size,
877                   AlgorithmDesc algorithm_no_scratch)
878       : algorithm_(algorithm),
879         algorithm_no_scratch_(algorithm_no_scratch),
880         scratch_size_(scratch_size) {}
881 
882   // TODO(ruochengw): After cl/380702564, add support for algorithm configs with
883   // cuDNN Frontend APIs.
AlgorithmConfig(const AlgorithmConfigProto & algorithm_config_proto)884   explicit AlgorithmConfig(const AlgorithmConfigProto& algorithm_config_proto) {
885     const AlgorithmProto& algorithm_proto = algorithm_config_proto.algorithm();
886     algorithm_ = AlgorithmDesc(
887         algorithm_proto.algo_id(),
888         algorithm_proto.math_type() == AlgorithmProto::TENSOR_OP_MATH);
889     if (algorithm_config_proto.optional_scratch_size_case() !=
890         /*ONEOF_NAME_NOT_SET=*/0) {
891       scratch_size_ = algorithm_config_proto.scratch_size();
892     }
893     if (algorithm_config_proto.optional_algorithm_no_scratch_case() !=
894         /*ONEOF_NAME_NOT_SET=*/0) {
895       const AlgorithmProto& algorithm_no_scratch_proto =
896           algorithm_config_proto.algorithm_no_scratch();
897       algorithm_no_scratch_ = AlgorithmDesc(
898           algorithm_no_scratch_proto.algo_id(),
899           /*use_tensor_ops=*/algorithm_no_scratch_proto.math_type() ==
900               AlgorithmProto::TENSOR_OP_MATH);
901     }
902   }
903 
algorithm()904   absl::optional<AlgorithmDesc> algorithm() const { return algorithm_; }
set_algorithm(AlgorithmDesc val)905   void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
algorithm_no_scratch()906   absl::optional<AlgorithmDesc> algorithm_no_scratch() const {
907     return algorithm_no_scratch_;
908   }
set_algorithm_no_scratch(AlgorithmDesc val)909   void set_algorithm_no_scratch(AlgorithmDesc val) {
910     algorithm_no_scratch_ = val;
911   }
scratch_size()912   absl::optional<size_t> scratch_size() const { return scratch_size_; }
set_scratch_size(size_t val)913   void set_scratch_size(size_t val) { scratch_size_ = val; }
914   bool operator==(const AlgorithmConfig& other) const {
915     return this->algorithm_ == other.algorithm_ &&
916            this->algorithm_no_scratch_ == other.algorithm_no_scratch_ &&
917            this->scratch_size_ == other.scratch_size_;
918   }
919   bool operator!=(const AlgorithmConfig& other) const {
920     return !(*this == other);
921   }
922   std::string ToString() const;
set_plan(std::unique_ptr<dnn::ConvolveExecutionPlan> & plan)923   void set_plan(std::unique_ptr<dnn::ConvolveExecutionPlan>& plan) {
924     plan_ = std::move(plan);
925   }
set_plan_no_scratch(std::unique_ptr<dnn::ConvolveExecutionPlan> & plan)926   void set_plan_no_scratch(std::unique_ptr<dnn::ConvolveExecutionPlan>& plan) {
927     plan_no_scratch_ = std::move(plan);
928   }
929 
930   // TODO(ruochengw): After cl/380702564, add support for algorithm configs with
931   // cuDNN Frontend APIs.
ToProto()932   AlgorithmConfigProto ToProto() const {
933     AlgorithmConfigProto algorithm_config_proto;
934     if (algorithm_.has_value()) {
935       *algorithm_config_proto.mutable_algorithm() =
936           algorithm_.value().ToProto();
937     }
938     if (algorithm_no_scratch_.has_value()) {
939       *algorithm_config_proto.mutable_algorithm_no_scratch() =
940           algorithm_no_scratch_.value().ToProto();
941     }
942     if (scratch_size_.has_value()) {
943       algorithm_config_proto.set_scratch_size(scratch_size_.value());
944     }
945     return algorithm_config_proto;
946   }
947 
948  private:
949   absl::optional<AlgorithmDesc> algorithm_;
950   absl::optional<AlgorithmDesc> algorithm_no_scratch_;
951   absl::optional<size_t> scratch_size_;
952   std::shared_ptr<const dnn::ConvolveExecutionPlan> plan_;
953   std::shared_ptr<const dnn::ConvolveExecutionPlan> plan_no_scratch_;
954 };
955 
956 // Describes a local response normalization (LRN). LRN is used e.g. in
957 // dist_belief.
958 //
959 // Let V be the vector of feature maps at some (batch, y, x)
960 // coordinate. LRN applies independently to each vector V in the
961 // input, across all coordinates (batch, y, x), by mapping each V to
962 // another vector U of the same size using the formula
963 //
964 //   U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
965 //
966 // where the sum is taken over j in the closed range [i - range, i + range].
967 //
968 // When calculating U_i the j in the sum can extend beyond the bounds
969 // of V. If wrap_around is true, then V_j = V_{j mod F} where F is the
970 // size of V, which is the number of feature maps. If wrap_around is
971 // false, then V_j = 0 for j outside [0, F-1].
972 //
// If segment_size >= F, where F is the number of feature maps, then
// segment_size has no effect, since a single segment covers all of V.
// Otherwise, each consecutive segment of segment_size entries in V is
// normalized separately.
976 //
977 // Not all StreamExecutors allow wrap_around == true or segment_size
978 // != 64. Some do not implement normalization at all.
979 class NormalizeDescriptor {
980  public:
981   NormalizeDescriptor();
982 
set_bias(float bias)983   NormalizeDescriptor& set_bias(float bias) {
984     bias_ = bias;
985     return *this;
986   }
987 
set_range(int32_t range)988   NormalizeDescriptor& set_range(int32_t range) {
989     range_ = range;
990     return *this;
991   }
992 
set_alpha(float alpha)993   NormalizeDescriptor& set_alpha(float alpha) {
994     alpha_ = alpha;
995     return *this;
996   }
997 
set_beta(float beta)998   NormalizeDescriptor& set_beta(float beta) {
999     beta_ = beta;
1000     return *this;
1001   }
1002 
set_wrap_around(bool wrap_around)1003   NormalizeDescriptor& set_wrap_around(bool wrap_around) {
1004     wrap_around_ = wrap_around;
1005     return *this;
1006   }
1007 
set_segment_size(int32_t segment_size)1008   NormalizeDescriptor& set_segment_size(int32_t segment_size) {
1009     segment_size_ = segment_size;
1010     return *this;
1011   }
1012 
1013   void CloneFrom(const NormalizeDescriptor& other);
1014 
1015   std::string ToString() const;
1016   std::string ToShortString() const;
1017 
bias()1018   float bias() const { return bias_; }
range()1019   int32 range() const { return range_; }
alpha()1020   float alpha() const { return alpha_; }
beta()1021   float beta() const { return beta_; }
wrap_around()1022   bool wrap_around() const { return wrap_around_; }
segment_size()1023   int32 segment_size() const { return segment_size_; }
1024 
1025  private:
1026   float bias_;
1027   int32 range_;
1028   float alpha_;
1029   float beta_;
1030   bool wrap_around_;
1031   int32 segment_size_;
1032 };
1033 
1034 // Returns a string representation of the given activation mode.
1035 std::string ActivationModeString(ActivationMode mode);
1036 
1037 // Describes the operation that DoElementwiseOperation should perform on its
1038 // inputs.
1039 enum class ElementwiseOperation { kAdd, kMultiply };
1040 
1041 std::string ElementwiseOperationString(ElementwiseOperation op);
1042 
// A simple value class holding the (major, minor, patch) version of the
// backing library. It exists to work around the "too perfect forwarding"
// issue in gcc6+ compilers; see PR#16309 and issue #18402 for links
// discussing the issue.
//
// The constructor is not explicit, and every component defaults to 0, so a
// default-constructed VersionInfo represents version 0.0.0.
class VersionInfo {
 public:
  VersionInfo(int major_num = 0, int minor_num = 0, int patch_num = 0)
      : major_(major_num), minor_(minor_num), patch_(patch_num) {}

  int major_version() const { return major_; }
  int minor_version() const { return minor_; }
  int patch() const { return patch_; }

 private:
  int major_;
  int minor_;
  int patch_;
};
1059 
1060 // Suite of operations typically used for implementing Deep/Convolutional Neural
1061 // Nets. Note: A false return value of an operation indicates the
1062 // implementation is not available.
1063 //
1064 // TODO(b/118763918): this class (or rather dispatch table) has several
1065 // problems:
// * Some overloads are missing. Ideally we would have template virtual
//   functions whose template arguments range over a closed set of types.
//   However, the language does not support that.
1069 // * The API is a union of cuDNN and another private backend. Only 10% of the
1070 //   functions are actually implemented by both backends, the rest are
1071 //   actually backend-specific. The massive interface creates extra mental
1072 //   burden.
1073 // * Poor error handling: the API should return Status objects.
1074 //
// PrepareForConvolution is an example of how new APIs should be written.
1076 class DnnSupport {
1077  public:
DnnSupport()1078   DnnSupport() {}
~DnnSupport()1079   virtual ~DnnSupport() {}
1080 
1081   virtual port::Status Init() = 0;
1082 
1083   // Gets the version of the backing library, as a VersionInfo object.
GetVersion()1084   virtual port::StatusOr<VersionInfo> GetVersion() {
1085     return port::UnimplementedError(
1086         "DnnSupport::GetVersion not implemented on this platform.");
1087   }
1088 
1089   // Performs a single-precision forward batch normalization operation onto
1090   // the stream.
1091   //
1092   // Arguments:
1093   //  stream: borrowed pointer to the stream that the batch normalization
1094   //    operation should be enqueued onto.
1095   //  x: input data.
1096   //  scale: scaling parameters.
1097   //  offset: offset parameters.
1098   //  estimated_mean: population mean estimated during training.
1099   //    Used for inference only; empty for training.
1100   //  estimated_variance: population variance estimated during training,
1101   //    used for inference only; empty for training.
1102   //  side_input: optional input that is element-wise added to the output of
1103   //    batch normalization.
1104   //  x_desc: dimensions of the input data, which is the same as the dimensions
1105   //    of the output and side input.
1106   //  scale_offset_desc: dimensions of scale and offset.
1107   //  epsilon: a small floating point number added to the variance of x.
1108   //  activation_mode: activation applied to the result of batch normalization
1109   //    (or after adding optional side input)
1110   //  y: output data.
1111   //  batch_mean: batch mean, to be used to compute the running mean.
1112   //  batch_variance: batch variance, to be used to compute
1113   //    the running variance.
1114   //  reserve_space_1: saved mean, to be reused in the backward gradient
1115   //    computation.
1116   //  reserve_space_2: saved inv_var (1/sqrt(epsilon + variance), to be reused
1117   //    in the backward gradient computation.
1118   //  is_training: Set to true for training, false for inference.
DoBatchNormalizationForward(Stream * stream,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * reserve_space_1,DeviceMemory<float> * reserve_space_2,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)1119   virtual bool DoBatchNormalizationForward(
1120       Stream* stream, const DeviceMemory<float>& x,
1121       const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1122       const DeviceMemory<float>& estimated_mean,
1123       const DeviceMemory<float>& estimated_variance,
1124       const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
1125       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1126       const double exponential_average_factor,
1127       dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
1128       DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1129       DeviceMemory<float>* reserve_space_1,
1130       DeviceMemory<float>* reserve_space_2, bool is_training,
1131       ScratchAllocator* reserve_space_allocator,
1132       ScratchAllocator* workspace_allocator) {
1133     return false;
1134   }
1135 
1136   // Performs a half-precision forwards batch normalization operation onto the
1137   // stream. See DoBatchNormalizationForward above for argument details.
DoBatchNormalizationForward(Stream * stream,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<Eigen::half> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * reserve_space_1,DeviceMemory<float> * reserve_space_2,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)1138   virtual bool DoBatchNormalizationForward(
1139       Stream* stream, const DeviceMemory<Eigen::half>& x,
1140       const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1141       const DeviceMemory<float>& estimated_mean,
1142       const DeviceMemory<float>& estimated_variance,
1143       const DeviceMemory<Eigen::half>& side_input,
1144       const dnn::BatchDescriptor& x_desc,
1145       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1146       const double exponential_average_factor,
1147       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
1148       DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1149       DeviceMemory<float>* reserve_space_1,
1150       DeviceMemory<float>* reserve_space_2, bool is_training,
1151       ScratchAllocator* reserve_space_allocator,
1152       ScratchAllocator* workspace_allocator) {
1153     return false;
1154   }
1155 
1156   // Performs a single-precision backward batch normalization gradient
1157   // computation operation onto the stream.
1158   //
1159   // Arguments:
1160   //  stream: borrowed pointer to the stream that the batch normalization
1161   //    gradient computation operation should be enqueued onto.
1162   //  y_backprop: gradient with regard to output y.
1163   //  x: input data.
1164   //  scale: scaling parameters.
1165   //  inv_var: 1/sqrt(epsilon + variance) of x.
1166   //  x_desc: dimensions of the input data, which is the same as the dimensions
1167   //    of the output.
1168   //  scale_offset_desc: dimensions of scale and offset.
1169   //  epsilon: a small floating point number added to the variance of x.
1170   //  x_backprop: gradient with respect to input x.
1171   //  scale_backprop: gradient with respect to scale.
1172   //  offset_backprop: gradient with respect to offset.
DoBatchNormalizationBackward(Stream * stream,const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)1173   virtual bool DoBatchNormalizationBackward(
1174       Stream* stream, const DeviceMemory<float>& y_backprop,
1175       const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
1176       const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
1177       const dnn::BatchDescriptor& x_desc,
1178       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1179       DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
1180       DeviceMemory<float>* offset_backprop,
1181       DeviceMemory<uint8>* reserve_space_data,
1182       ScratchAllocator* workspace_allocator) {
1183     return false;
1184   }
1185 
1186   // Performs a half-precision backward batch normalization gradient computation
1187   // operation onto the stream. See DoBatchNormalizationBackward above for
1188   // argument details.
DoBatchNormalizationBackward(Stream * stream,const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)1189   virtual bool DoBatchNormalizationBackward(
1190       Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
1191       const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
1192       const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
1193       const dnn::BatchDescriptor& x_desc,
1194       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1195       DeviceMemory<Eigen::half>* x_backprop,
1196       DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
1197       DeviceMemory<uint8>* reserve_space_data,
1198       ScratchAllocator* workspace_allocator) {
1199     return false;
1200   }
1201 
1202   // Enqueues a fused convolution operation onto the stream.
1203   // We provide several variants with different types for inputs, biases and
1204   // scaling parameters.
1205   //
1206   // Arguments (all borrowed):
1207   //  stream: borrowed pointer to the stream that the 'convolve' operation
1208   //    should be enqueued onto.
1209   //  conv_input_descriptor: dimensions of the convolution input layer.
1210   //  conv_input_data: un-owned device memory region which contains the
1211   //    convolution input.
1212   //  conv_input_scale: a floating point scale to multiply with each element
1213   //    of conv_input_data.
1214   //  filter_descriptor: dimensions of the convolution filter.
1215   //  filter_data: un-owned device memory region which contains the
1216   //    convolution filter weights.
1217   //  convolution_descriptor: stride of the convolution filter.
1218   //  biases: un-owned device memory region containing biases to add to the
1219   //    input.
1220   //  activation_mode: Type of activation to perform.
1221   //  side_input_data: un-owned device memory region which contains optional
1222   //    side input data. If 'side_input_scale' is non-zero, then this must
1223   //    point to data in the tensor shape specified by output_shape.
1224   //    It will be scaled by 'side_input_scale' and added to the convolution
1225   //    result and bias prior to applying the activation function.
1226   //  side_input_scale: a floating point scale to multiply with each element
1227   //    of side_input_data.
1228   //  output_descriptor: dimensions of the output layer.
1229   //  output_data: un-owned device memory region in which to place the
1230   //    convolution result.
1231   //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
1232   //    space in order to speed up the convolution operation.
1233   //  algorithm_config: specifies which algorithm should be used for the
1234   //    operation.
1235   //  output_profile_result: the output profile result for this call. The
1236   //    profiling is only enabled when this is not nullptr.
1237   //
1238   // conv_input_descriptor, filter_descriptor, convolution_descriptor and
1239   // output_descriptor together specify exactly how the convolution is aligned
1240   // with the input data:
1241   //
1242   // * (input dimensions - filter size + 1) / filter stride == output dimensions
1243   //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1244   // * input dimensions / filter stride == output dimensions
1245   //   corresponds to dist_belief padding = SAME, i.e. input and output are the
1246   //   same size - this requires padding the input.
1247   // * (input dimensions + filter size - 1) / filter stride == output dimensions
1248   //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
1249   //   that if the inverse of the filter is applied to the output in VALID mode
1250   //   the result is the same size as the input - this requires even more
1251   //   padding of the input.
DoFusedConvolve(Stream * stream,DataType input_type,DataType side_input_type,DataType bias_type,DataType output_type,const dnn::BatchDescriptor & conv_input_descriptor,DeviceMemoryBase conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,DeviceMemoryBase side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,DeviceMemoryBase biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1252   virtual port::Status DoFusedConvolve(
1253       Stream* stream, DataType input_type, DataType side_input_type,
1254       DataType bias_type, DataType output_type,
1255       const dnn::BatchDescriptor& conv_input_descriptor,
1256       DeviceMemoryBase conv_input_data, double conv_input_scale,
1257       const dnn::FilterDescriptor& filter_descriptor,
1258       DeviceMemoryBase filter_data,
1259       const dnn::ConvolutionDescriptor& convolution_descriptor,
1260       DeviceMemoryBase side_input_data, double side_input_scale,
1261       const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
1262       dnn::ActivationMode activation_mode,
1263       const dnn::BatchDescriptor& output_descriptor,
1264       DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
1265       const dnn::AlgorithmConfig& algorithm_config,
1266       dnn::ProfileResult* output_profile_result) {
1267     return port::UnimplementedError(
1268         "DnnSupport::DoFusedConvolve not implemented on this platform.");
1269   }
1270 
1271   template <typename ElementType, typename OutputType>
PrepareForConvolution(ConvolutionKind kind,Stream * stream,const BatchDescriptor & batch_descriptor,DeviceMemory<ElementType> input_data,const FilterDescriptor & filter_descriptor,DeviceMemory<ElementType> filter_data,const BatchDescriptor & output_descriptor,DeviceMemory<OutputType> output_data,const ConvolutionDescriptor & convolution_descriptor,const AlgorithmConfig & algorithm_config,ScratchAllocator * scratch_allocator,AlgorithmDesc * algorithm_desc,DeviceMemory<uint8> * scratch_memory)1272   port::Status PrepareForConvolution(
1273       ConvolutionKind kind, Stream* stream,
1274       const BatchDescriptor& batch_descriptor,
1275       DeviceMemory<ElementType> input_data,
1276       const FilterDescriptor& filter_descriptor,
1277       DeviceMemory<ElementType> filter_data,
1278       const BatchDescriptor& output_descriptor,
1279       DeviceMemory<OutputType> output_data,
1280       const ConvolutionDescriptor& convolution_descriptor,
1281       const AlgorithmConfig& algorithm_config,
1282       ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
1283       DeviceMemory<uint8>* scratch_memory) {
1284     return DoPrepareForConvolution(
1285         kind, ToDataType<ElementType>::value, stream, batch_descriptor,
1286         input_data, filter_descriptor, filter_data, output_descriptor,
1287         output_data, convolution_descriptor, algorithm_config,
1288         scratch_allocator, algorithm_desc, scratch_memory);
1289   }
1290 
1291   // Enqueues a single-precision convolution operation onto the stream.
1292   //
1293   // Arguments (all borrowed):
1294   //  stream: borrowed pointer to the stream that the 'convolve' operation
1295   //    should be enqueued onto.
1296   //  input_descriptor: dimensions of the input layer.
1297   //  input_data: un-owned device memory region which contains the
1298   //    convolution input.
1299   //  filter_descriptor: dimensions of the convolution filter.
1300   //  convolution_descriptor: stride of the convolution filter.
1301   //  output_descriptor: dimensions of the output layer.
1302   //  output_data: un-owned device memory region in which to place the
1303   //    convolution result.
1304   //  algorithm_desc: specifies which algorithm should be used for the
1305   //    operation.
1306   //  scratch: un-owned device memory for scratch space in order to speed up
1307   //    the convolution operation.
1308   //  output_profile_result: the output profile result for this call. The
1309   //    profiling is only enabled when this is not nullptr.
1310   //
1311   // input_descriptor, filter_descriptor, convolution_descriptor and
1312   // output_descriptor together specify exactly how the convolution is aligned
1313   // with the input data:
1314   //
1315   // * (input dimensions - filter size + 1) / filter stride == output dimensions
1316   //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1317   // * input dimensions / filter stride == output dimensions
1318   //   corresponds to dist_belief padding = SAME, i.e. input and output are the
1319   //   same size - this requires padding the input.
1320   // * (input dimensions + filter size - 1) / filter stride == output dimensions
1321   //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
1322   //   that if the inverse of the filter is applied to the output in VALID mode
1323   //   the result is the same size as the input - this requires even more
1324   //   padding of the input.
1325   virtual port::Status DoConvolve(
1326       ConvolutionKind kind, DataType element_type, DataType output_type,
1327       Stream* stream, const BatchDescriptor& input_descriptor,
1328       DeviceMemoryBase input_data, const FilterDescriptor& filter_descriptor,
1329       DeviceMemoryBase filter_data, const BatchDescriptor& output_descriptor,
1330       DeviceMemoryBase output_data,
1331       const ConvolutionDescriptor& convolution_descriptor,
1332       AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
1333       ProfileResult* output_profile_result) = 0;
1334 
1335   // Return a list of algorithms supported by the forward convolution pass.
1336   // cc_major and cc_minor are the compute capabilities of the device.
1337   virtual bool GetConvolveAlgorithms(
1338       CudaComputeCapability cuda_compute_capability,
1339       std::vector<AlgorithmDesc>* out_algorithms);
1340 
1341   virtual bool GetConvolveExecutionPlans(
1342       dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
1343       const dnn::BatchDescriptor& input_descriptor,
1344       const dnn::FilterDescriptor& filter_descriptor,
1345       const dnn::BatchDescriptor& output_descriptor,
1346       const dnn::ConvolutionDescriptor& convolution_descriptor,
1347       std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>* out_exec_plans);
1348 
1349   virtual bool GetMIOpenConvolveAlgorithms(
1350       dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
1351       const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
1352       const dnn::FilterDescriptor& filter_descriptor,
1353       DeviceMemoryBase filter_data,
1354       const dnn::BatchDescriptor& output_descriptor,
1355       DeviceMemoryBase output_data,
1356       const dnn::ConvolutionDescriptor& convolution_descriptor,
1357       ScratchAllocator* scratch_allocator,
1358       std::vector<ProfileResult>* out_algorithms);
1359 
1360   // Returns a list of supported rnn algorithms.
1361   virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);
1362 
1363   // Version of DoConvolve that uses pre-quantized 8 bit coefficients.
1364   // coefficient_scales specifies the scaling of each column of coefficients:
1365   // original float coefficient[row * num_columns + column] =
1366   //     quantized coefficient[row * num_columns + column] *
1367   //     coefficient_scales[column].
1368   virtual bool DoConvolveQuantized(
1369       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1370       const DeviceMemory<float>& input_data,
1371       const dnn::FilterDescriptor& filter_descriptor,
1372       const DeviceMemory<int8>& filter_coefficients,
1373       const DeviceMemory<float>& coefficient_scales,
1374       const dnn::ConvolutionDescriptor& convolution_descriptor,
1375       const dnn::BatchDescriptor& output_descriptor,
1376       DeviceMemory<float>* output_data) = 0;
1377 
1378   // Same as DoConvolveQuantized above, but int8 filter coefficients.
1379   virtual bool DoConvolveQuantized(
1380       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1381       const DeviceMemory<float>& input_data,
1382       const dnn::FilterDescriptor& filter_descriptor,
1383       const DeviceMemory<int16>& filter_coefficients,
1384       const DeviceMemory<float>& coefficient_scales,
1385       const dnn::ConvolutionDescriptor& convolution_descriptor,
1386       const dnn::BatchDescriptor& output_descriptor,
1387       DeviceMemory<float>* output_data) = 0;
1388 
1389   // Variation of the above with the weight matrix split into two matrices.
1390   // first_weights: Coefficients of the first matrix.
1391   // second_weights: Coefficients of the second matrix.
1392   // depth_multiplier: specifies the columns of the first matrix and rows
1393   // of the second one - first_weights columns = depth_multiplier,
1394   // second_weights rows = depth_multiplier *
1395   //                       filter_descriptor.input_feature_map_count().
1396   // see go/separable for documentation on separable convolutions.
1397   virtual bool DoSeparableConvolve(
1398       Stream* stream, const BatchDescriptor& input_descriptor,
1399       const DeviceMemory<float>& input_data,
1400       const FilterDescriptor& filter_descriptor, int depth_multiplier,
1401       const DeviceMemory<float>& first_weights,
1402       const DeviceMemory<float>& second_weights,
1403       const ConvolutionDescriptor& convolution_descriptor,
1404       const BatchDescriptor& output_descriptor,
1405       DeviceMemory<float>* output_data) = 0;
1406 
1407   // Return a list of algorithms supported by the backward convolution pass for
1408   // data.
1409   virtual bool GetConvolveBackwardDataAlgorithms(
1410       CudaComputeCapability cuda_compute_capability,
1411       std::vector<AlgorithmDesc>* out_algorithms);
1412 
1413   // Return a list of algorithms supported by the backward convolution pass for
1414   // filters.
1415   virtual bool GetConvolveBackwardFilterAlgorithms(
1416       CudaComputeCapability cuda_compute_capability,
1417       std::vector<AlgorithmDesc>* out_algorithms);
1418 
1419   // Fully connects the "nodes" (float values) in input_data with
1420   // shape input_dimensions to output_data with output_dimensions
1421   // using provided weights. This is equivalent to computing a matrix
1422   // product, hence the name MatMul.
1423   //
1424   // A BatchDescriptor has four dimensions: batch, y, x, depth. Matrix products
1425   // happen in two dimensions. To get down to two dimensions, we consider the
1426   // input y, x and depth dimension as one combined dimension T. For now,
1427   // assume that the output height and width are 1 and let OD be the output
1428   // depth.
1429   //
1430   // There are three device memory buffers passed in to this
1431   // function. We can now view all three as matrices:
1432   //
1433   //   input_data: A batch x T matrix
1434   //   weights: A T x OD matrix
1435   //   output_data: A batch x OD matrix
1436   //
1437   // This function then computes the matrix product of input_data and
1438   // weights and writes the result into output_data.
1439   //
1440   // Here the weights buffer is in row major order, i.e. the first OD
1441   // entries in weights are the first row, the second OD entries in
1442   // weights are the second row and so on.
1443   //
1444   // The case for output width*height > 1 is more complicated. Let K =
1445   // OY * OX where OY is the output height and OX is the output
1446   // width. Then weights is divided into K sub-arrays W_i, for
1447   // i=0,...,k-1, that each represent a T x OD matrix. This function
1448   // then computes the K matrix multiplications of input_data with
1449   // each W_i. This creates K matrices with dimensions batch x
1450   // OD. These K matrices are concatenated horizontally to form one
1451   // larger matrix with dimensions batch x (K*OD); note that this is
1452   // not the same as concatenating the bytes of the matrices. The
1453   // combined matrix can then be interpreted as a tensor with
1454   // dimensions (batch, OY, OX, OD). If the output tensor format is
1455   // not kBatchYXDepth, this function would then need to arrange for
1456   // the output to be in the requested layout, if that is
1457   // supported. Note that the case K=1 is equivalent to the
1458   // description above. It is recommended to prefer the case K=1.
1459   //
1460   // Arguments (all borrowed):
1461   //  stream: borrowed pointer to the stream that the 'fully connect' operation
1462   //    should be enqueued onto.
1463   //  output_data: un-owned device memory region in which to place the
1464   //    fully connected result.
1465   virtual bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
1466                         const DeviceMemory<float>& weights,
1467                         const dnn::BatchDescriptor& input_dimensions,
1468                         const dnn::BatchDescriptor& output_dimensions,
1469                         DeviceMemory<float>* output_data) = 0;
1470 
1471   // Version of DoMatMul that uses pre-quantized 8 bit weights.
1472   // weight_scales specifies the scaling of each column of weights:
1473   // original float weight[row * num_columns + column] =
1474   //     quantized_weight[row * nnum_columns + column] * weight_scales[column].
1475   virtual bool DoMatMulQuantized(Stream* stream,
1476                                  const DeviceMemory<float>& input_data,
1477                                  const DeviceMemory<int8>& quantized_weights,
1478                                  const DeviceMemory<float>& weight_scales,
1479                                  const dnn::BatchDescriptor& input_dimensions,
1480                                  const dnn::BatchDescriptor& output_dimensions,
1481                                  DeviceMemory<float>* output_data) = 0;
1482 
1483   // Version of DoMatMul that uses pre-quantized 16 bit weights.
1484   // weight_scales specifies the scaling of each column of weights:
1485   // original float weight[row * num_columns + column] =
1486   //     quantized_weight[row * nnum_columns + column] * weight_scales[column].
1487   virtual bool DoMatMulQuantized(Stream* stream,
1488                                  const DeviceMemory<float>& input_data,
1489                                  const DeviceMemory<int16>& quantized_weights,
1490                                  const DeviceMemory<float>& weight_scales,
1491                                  const dnn::BatchDescriptor& input_dimensions,
1492                                  const dnn::BatchDescriptor& output_dimensions,
1493                                  DeviceMemory<float>* output_data) = 0;
1494 
1495   // Adds biases to the feature maps in input_data producing
1496   // output_data. input_data can equal output_data, but must not
1497   // partially overlap it.
1498   //
1499   // Let K = count() * height() * width() and N = feature_map_count()
1500   // on dimensions. Then input_value contains K*N values and biases
1501   // contains N values. We can thus logically consider input_value to
1502   // contain K vectors of N elements each. This function adds biases
1503   // to each of those N vectors.
1504   //
1505   // TODO(broune): This works differently when width() * height() > 1
1506   // and the call to ThenBiasAdd() follows a call to ThenMatMul(). In
1507   // that case there should be width() * height() *
1508   // feature_map_count() biases, but this is not implemented on all
1509   // StreamExecutors.
1510   //
1511   // Arguments (all borrowed):
1512   //  stream: borrowed pointer to the stream that the 'bias add' operation
1513   //    should be enqueued onto.
1514   //  input_data: un-owned device memory region containing the input.
1515   //  biases: un-owned device memory region containing biases to add to the
1516   //    input.
1517   //  dimensions: dimensions of input_data and output_data.
1518   //  output_data: un-owned device memory region in which to place the result.
1519   virtual bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
1520                          const DeviceMemory<float>& biases,
1521                          const dnn::BatchDescriptor& dimensions,
1522                          DeviceMemory<float>* output_data) = 0;
1523 
1524   // Performs a forward pooling operation on input_data, writing to
1525   // output_data. See PoolingDescriptor for how to configure the
1526   // pooling operation.
1527   //
1528   // Pooling happens as a window that moves across the Y and X
1529   // dimensions of input_data, where each position of the window
1530   // yields one output value. E.g. for max pooling, the computed value
1531   // is the maximum element in the window. The operation is applied
1532   // independently to each batch and at each feature map (depth), so
1533   // that the output depth and feature_map_count are the same as for
1534   // the input. The output width and height can be different.
1535   //
1536   // See PoolingDescriptor for how to configure the pooling operation.
1537   virtual bool DoPoolForward(Stream* stream,
1538                              const dnn::PoolingDescriptor& pooling_dimensions,
1539                              const dnn::BatchDescriptor& input_dimensions,
1540                              const DeviceMemory<float>& input_data,
1541                              const dnn::BatchDescriptor& output_dimensions,
1542                              DeviceMemory<float>* output_data,
1543                              ScratchAllocator* workspace_allocator) = 0;
1544 
DoPoolForward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)1545   virtual bool DoPoolForward(Stream* stream,
1546                              const dnn::PoolingDescriptor& pooling_dimensions,
1547                              const dnn::BatchDescriptor& input_dimensions,
1548                              const DeviceMemory<double>& input_data,
1549                              const dnn::BatchDescriptor& output_dimensions,
1550                              DeviceMemory<double>* output_data,
1551                              ScratchAllocator* workspace_allocator) {
1552     LOG(FATAL) << "DoPoolForward not implemented for double.";
1553     return false;
1554   }
1555 
DoPoolForward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)1556   virtual bool DoPoolForward(Stream* stream,
1557                              const dnn::PoolingDescriptor& pooling_dimensions,
1558                              const dnn::BatchDescriptor& input_dimensions,
1559                              const DeviceMemory<Eigen::half>& input_data,
1560                              const dnn::BatchDescriptor& output_dimensions,
1561                              DeviceMemory<Eigen::half>* output_data,
1562                              ScratchAllocator* workspace_allocator) {
1563     LOG(FATAL) << "DoPoolForward not implemented for float16.";
1564     return false;
1565   }
1566 
DoPoolForward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<int8> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<int8> * output_data,ScratchAllocator * workspace_allocator)1567   virtual bool DoPoolForward(Stream* stream,
1568                              const dnn::PoolingDescriptor& pooling_dimensions,
1569                              const dnn::BatchDescriptor& input_dimensions,
1570                              const DeviceMemory<int8>& input_data,
1571                              const dnn::BatchDescriptor& output_dimensions,
1572                              DeviceMemory<int8>* output_data,
1573                              ScratchAllocator* workspace_allocator) {
1574     LOG(FATAL) << "DoPoolForward not implemented for int8.";
1575     return false;
1576   }
1577 
1578   // Performs differentiation of the pooling operation.
DoPoolBackward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)1579   virtual bool DoPoolBackward(Stream* stream,
1580                               const dnn::PoolingDescriptor& pooling_dimensions,
1581                               const dnn::BatchDescriptor& input_dimensions,
1582                               const DeviceMemory<double>& input_data,
1583                               const dnn::BatchDescriptor& output_dimensions,
1584                               const DeviceMemory<double>& output_data,
1585                               const DeviceMemory<double>& input_diff_data,
1586                               DeviceMemory<double>* output_diff_data,
1587                               ScratchAllocator* workspace_allocator) {
1588     LOG(FATAL) << "DoPoolBackward not implemented.";
1589     return false;
1590   }
1591 
DoPoolBackward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)1592   virtual bool DoPoolBackward(Stream* stream,
1593                               const dnn::PoolingDescriptor& pooling_dimensions,
1594                               const dnn::BatchDescriptor& input_dimensions,
1595                               const DeviceMemory<float>& input_data,
1596                               const dnn::BatchDescriptor& output_dimensions,
1597                               const DeviceMemory<float>& output_data,
1598                               const DeviceMemory<float>& input_diff_data,
1599                               DeviceMemory<float>* output_diff_data,
1600                               ScratchAllocator* workspace_allocator) {
1601     LOG(FATAL) << "DoPoolBackward not implemented.";
1602     return false;
1603   }
1604 
DoPoolBackward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)1605   virtual bool DoPoolBackward(Stream* stream,
1606                               const dnn::PoolingDescriptor& pooling_dimensions,
1607                               const dnn::BatchDescriptor& input_dimensions,
1608                               const DeviceMemory<Eigen::half>& input_data,
1609                               const dnn::BatchDescriptor& output_dimensions,
1610                               const DeviceMemory<Eigen::half>& output_data,
1611                               const DeviceMemory<Eigen::half>& input_diff_data,
1612                               DeviceMemory<Eigen::half>* output_diff_data,
1613                               ScratchAllocator* workspace_allocator) {
1614     LOG(FATAL) << "DoPoolBackward not implemented.";
1615     return false;
1616   }
1617 
1618   // Applies local response normalization to the values from input_data and
1619   // writes the result to output_data.
1620   //
1621   // See comments on NormalizeDescriptor for a description of local response
1622   // normalization.
DoNormalizeWithDimensions(Stream * stream,const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1623   virtual bool DoNormalizeWithDimensions(
1624       Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1625       const dnn::BatchDescriptor& dimensions,
1626       const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
1627     return false;
1628   }
1629 
1630   // Performs backpropagation for the normalization operation
1631   //
1632   // Given raw data, its corresponding normalized output, and a gradient of some
1633   // unspecified function with respect to the normalized variables, computes the
1634   // gradient of that unspecified function with respect to the raw variables.
1635   //
1636   // The normalized data input array is expected to match the output that would
1637   // be obtained by running the raw data input array through the DoNormalize
1638   // method above.
1639   //
1640   // See comments on NormalizeDescriptor for a description of local response
1641   // normalization.
DoNormalizeBackwardWithDimensions(Stream * stream,const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)1642   virtual bool DoNormalizeBackwardWithDimensions(
1643       Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1644       const dnn::BatchDescriptor& dimensions,
1645       const DeviceMemory<float>& raw_data,
1646       const DeviceMemory<float>& normalized_data,
1647       const DeviceMemory<float>& normalized_variable_gradient,
1648       DeviceMemory<float>* raw_variable_gradient,
1649       ScratchAllocator* workspace_allocator) {
1650     return false;
1651   }
1652 
1653   // Applies an activation function (see ActivationMode) to all of the values
1654   // held on the device in 'input_data', whose dimensions are described by
1655   // 'dimensions'.
1656   //
1657   // Arguments (all borrowed):
1658   //  stream: borrowed pointer to the stream that the 'activate' operation
1659   //    should be enqueued onto.
1660   //  activation_mode: Type of activation to perform.
1661   //  input_data: un-owned device memory region which contains the
1662   //    activate input.
1663   //  output_data: un-owned device memory region in which to place the
1664   //    activate result.
DoActivate(Stream * stream,ActivationMode activation_mode,const BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)1665   virtual bool DoActivate(Stream* stream, ActivationMode activation_mode,
1666                           const BatchDescriptor& dimensions,
1667                           const DeviceMemory<float>& input_data,
1668                           DeviceMemory<float>* output_data, uint64 options) {
1669     return false;
1670   }
1671 
1672   // Concatenates several layers into one, by concatenating the depth of each
1673   // layer at matching x and y coordinates.
1674   // The inputs must all have the same width and height, the output will have
1675   // the same width and height as the inputs and its depth will be the sum of
1676   // the input depths.
1677   //
1678   // Arguments (all borrowed):
1679   //  stream: borrowed pointer to the stream that the 'depth concatenate'
1680   // operation should be enqueued onto.
1681   //  input_dimensions: The dimensions of each input.
1682   //  input_data: un-owned device memory region which contains the
1683   //    input data for each input layer.
1684   //  output_data: un-owned device memory region in which to place the
1685   //    depth concatenate result.
1686   virtual bool DoDepthConcatenate(
1687       Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1688       port::ArraySlice<const DeviceMemory<float>*> input_data,
1689       DeviceMemory<float>* output_data) = 0;
1690 
1691   // Concatenates several layers into one, by concatenating each in the
1692   // x-dimension or y-dimension, based on a user-specified flag.
1693   // For x-concatenation, layers are aligned at matching y and depth
1694   // coordinates, and for y-concatenation, they are aligned at matching x and
1695   // depth coordinates. The inputs must all have the same depth and batch size.
1696   // For x-concatenation, the inputs must have the same height (y-size), and the
1697   // output will have the same depth and height as the inputs and its width (x-
1698   // size) will be the sum of the input widths.  For y-concatenation, the inputs
1699   // must have the same width, and the output will have the same depth and width
1700   // as the inputs, and its height will be the sum of the input heights.
1701   //
1702   // Arguments:
1703   //  stream: borrowed pointer to the stream that the 'space concatenate'
1704   //    operation should be enqueued onto.
1705   //  input_dimensions: the dimensions of each input.
1706   //  input_data: un-owned device memory region which contains the input data
1707   //    for each input layer.
1708   //  output_data: un-owned device memory region in which to place the space
1709   //    concatenate result.
1710   //  concat_direction:  either dnn:SpaceConcatenateMode::XDirection or
1711   //    dnn::SpaceConcatenateMode::YDirection.
DoSpaceConcatenate(Stream * stream,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1712   virtual bool DoSpaceConcatenate(
1713       Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1714       port::ArraySlice<const DeviceMemory<float>*> input_data,
1715       DeviceMemory<float>* output_data,
1716       dnn::SpaceConcatenateMode concat_direction) {
1717     return false;
1718   }
1719 
1720   // Change the layout of the data by shrinking one dimension (or set of
1721   // dimensions) and growing another dimension (or set of dimensions), while
1722   // keeping the total number of data elements constant, and maintaining the
1723   // current data ordering.
1724   //
1725   // Currently, the only supported operation is depth into space by a power of
1726   // 2. E.g. (y, x, z) -> (y*2, x*2, z/4)
1727   //
1728   // Note that Reshape may not be a no-op, depending on the platform and which
1729   // dimensions are being changed.
1730   //
1731   // Example: forgetting about batch for the moment, let's take a tensor that's
1732   // 2x1x8 (y by x by z) and reshape to a tensor that's 4x2x2. The memory layout
1733   // is row-major order: y,x,z. I.e. z changes the fastest, then x, then y. The
1734   // elements of the tensor range from 0 to 15. The x,y,z indices are below each
1735   // element.
1736   //
1737   //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
1738   // y0 y0 y0 y0 y0 y0 y0 y0 y1 y1 y1 y1 y1 y1 y1 y1
1739   // x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0
1740   // z0 z1 z2 z3 z4 z5 z6 z7 z0 z1 z2 z3 z4 z5 z6 z7
1741   //
1742   // reshape to 4x2x2
1743   //
1744   //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
1745   // y0 y0 y0 y0 y1 y1 y1 y1 y2 y2 y2 y2 y3 y3 y3 y3
1746   // x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1
1747   // z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1
DoReshape(Stream * stream,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1748   virtual bool DoReshape(Stream* stream,
1749                          const dnn::BatchDescriptor& input_dimensions,
1750                          const DeviceMemory<float>& input_data,
1751                          const dnn::BatchDescriptor& output_dimensions,
1752                          DeviceMemory<float>* output_data) {
1753     return false;
1754   }
1755 
1756   // Depth to space takes an X by Y image with depth D*M^2 and changes it to an
1757   // MX x MY image with depth D. Each input location (x,y) with depth D*M^2 in
1758   // the input image is changed to an MxM contiguous area in the output image,
1759   // with the values being laid out in the raster order by DepthToSpaceLayout,
1760   // and will have a new depth of D.
1761   //
1762   // Example.
1763   // M=2, Din =8, Xin=2, Yin=2. Xout=4, Yout=4,  Dout=2
1764   // DepthHeightWidth layout
1765   // Values within a 'cell' are at different depths and same x & y.
1766   // Input:
1767   // abcdefgh  ijklmnop
1768   // qrstuvwx  yz012345
1769   // Output:
1770   // ae bf im jn
1771   // cg dh ko lp
1772   // qu rv y2 z3
1773   // sw tx 04 15
1774   //
1775   // sqrt_depth_reduction: 'M' in the comment above
DoDepthToSpace(Stream * stream,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const DepthToSpaceLayout & depth_to_space_layout,const int & sqrt_depth_reduction,DeviceMemory<float> * output_data)1776   virtual bool DoDepthToSpace(Stream* stream,
1777                               const dnn::BatchDescriptor& input_dimensions,
1778                               const DeviceMemory<float>& input_data,
1779                               const DepthToSpaceLayout& depth_to_space_layout,
1780                               const int& sqrt_depth_reduction,
1781                               DeviceMemory<float>* output_data) {
1782     return false;
1783   }
1784 
1785   // Space to depth is the inverse of depth to space. Space to depth takes each
1786   // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
1787   // the input, and transforms it to a 1 by 1 patch with depth D*M^2. If the
1788   // input has size (MX, MY, D), the output has size (X, Y, D*M^2). The number
1789   // of data elements is not changed.
1790   //
1791   // Example.
1792   // M=2, Din =2, Xin=4, Yin=4,  Dout=8
1793   // DepthHeightWidth layout
1794   // Values within a 'cell' are at different depths and same x & y.
1795   // Input:
1796   // ae bf im jn
1797   // cg dh ko lp
1798   // qu rv y2 z3
1799   // sw tx 04 15
1800   // Output:
1801   // abcdefgh  ijklmnop
1802   // qrstuvwx  yz012345
1803   //
1804   // sqrt_depth_increase: 'M' in the comment above
DoSpaceToDepth(Stream * stream,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const DepthToSpaceLayout & space_to_depth_layout,const int & sqrt_depth_increase,DeviceMemory<float> * output_data)1805   virtual bool DoSpaceToDepth(Stream* stream,
1806                               const dnn::BatchDescriptor& input_dimensions,
1807                               const DeviceMemory<float>& input_data,
1808                               const DepthToSpaceLayout& space_to_depth_layout,
1809                               const int& sqrt_depth_increase,
1810                               DeviceMemory<float>* output_data) {
1811     return false;
1812   }
1813 
1814   // Computes the specified operation (e.g. addition or multiplication)
1815   // between corresponding elements in the inputs and stores the result in the
1816   // output element.
1817   // The inputs and output must all have the same dimensions, but may have
1818   // different quantization parameters (min_value and max_value).
1819   //
1820   // Arguments (all borrowed):
1821   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1822   // should be enqueued onto.
1823   //  operation: The operation to perform.
1824   //  input_dimensions: The dimensions of each input.
1825   //  input_data: un-owned device memory region which contains the
1826   //    input data for each input layer.
1827   //  output_dimensions: The dimensions of the output.
1828   //  output_data: un-owned device memory region in which to place the
1829   //    operation result.
1830   virtual bool DoElementwiseOperate(
1831       Stream* stream, ElementwiseOperation operation,
1832       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1833       port::ArraySlice<const DeviceMemory<float>*> input_data,
1834       const dnn::BatchDescriptor& output_dimensions,
1835       DeviceMemory<float>* output_data) = 0;
1836 
1837   // Computes the specified operation (e.g. addition or multiplication)
1838   // between corresponding elements in the inputs and stores the result in the
1839   // output element. Each input is multiplied by a scalar constant and the
1840   // result is divided by a scalar constant.
1841   // e.g. To perform Z = 0.9*X + 1.1*Y, set the input multiplicands to 9 and 11
1842   // and the output divisor to 10.
1843   // The inputs and output must all have the same dimensions, but may have
1844   // different quantization parameters (min_value and max_value).
1845   //
1846   // Arguments (all borrowed):
1847   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1848   // should be enqueued onto.
1849   //  operation: The operation to perform.
1850   //  input_multiplicands: Amount to scale each input.
1851   //  output_divisor: Amount to divide the output.
1852   //  input_dimensions: The dimensions of each input.
1853   //  input_data: un-owned device memory region which contains the
1854   //    input data for each input layer.
1855   //  output_dimensions: The dimensions of the output.
1856   //  output_data: un-owned device memory region in which to place the
1857   //    operation result.
DoElementwiseOperateScaledQuantized(Stream * stream,ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1858   virtual bool DoElementwiseOperateScaledQuantized(
1859       Stream* stream, ElementwiseOperation operation,
1860       port::ArraySlice<int> input_multiplicands, int output_divisor,
1861       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1862       port::ArraySlice<const DeviceMemory<float>*> input_data,
1863       const dnn::BatchDescriptor& output_dimensions,
1864       DeviceMemory<float>* output_data) {
1865     return false;
1866   }
1867 
1868   // Pads the input with zeros in the X and Y dimensions. The feature_map
1869   // dimension is unchanged.
1870   //
1871   // Arguments (all borrowed):
  //  stream: borrowed pointer to the stream that the 'pad' operation
  //    should be enqueued onto.
1874   //  dimensions: The dimensions of the input.
1875   //  input_data: un-owned device memory region which contains the
1876   //    input data for the input layer.
1877   //  left_pad: Amount to pad the input on the left.
1878   //  right_pad: Amount to pad the input on the right.
1879   //  top_pad: Amount to pad the input at the top (low Y).
1880   //  bottom_pad: Amount to pad the input at the bottom (high Y).
1881   //  output_data: un-owned device memory region in which to place the
1882   //    padded result.
1883   virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
1884                        const DeviceMemory<float>& input_data, int64_t left_pad,
1885                        int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
1886                        DeviceMemory<float>* output_data) = 0;
1887 
1888   // Extracts a slice of the input in the X and Y dimensions. The feature_map
1889   // dimension is unchanged.
1890   //
1891   // Arguments (all borrowed):
  //  stream: borrowed pointer to the stream that the 'slice' operation
  //    should be enqueued onto.
1894   //  dimensions: The dimensions of the input.
1895   //  input_data: un-owned device memory region which contains the
1896   //    input data for the input layer.
1897   //  left_trim: Amount to cut off the input on the left.
1898   //  right_trim: Amount to cut off the input on the right.
  //  top_trim: Amount to cut off the input at the top (low Y).
1900   //  bottom_trim: Amount to cut off the input at the bottom (high Y).
  //  output_data: un-owned device memory region in which to place the
  //    sliced result.
1903   virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
1904                          const DeviceMemory<float>& input_data,
1905                          int64_t left_trim, int64_t right_trim,
1906                          int64_t top_trim, int64_t bottom_trim,
1907                          DeviceMemory<float>* output_data) = 0;
1908 
  // Grows the input tensor by replicating the X and Y dimensions. The batch and
  // depth/feature_map dimensions are unchanged. Currently, the input tensor is
  // limited to X=1 and Y=1; the larger example below only illustrates the
  // general replication semantics.
1912   //
1913   // For example, the input has dimensions x=2, y=3, and replicate_x=3,
1914   // replicate_y=2. The diagonal elements of the output would be: [x0y0, x1y1,
1915   // x0y2, x1y0, x0y1, x1y2].
1916   // Here is the example as a picture. input:
1917   // AB
1918   // CD
1919   // EF
1920   // broadcast result:
1921   // ABABAB
1922   // CDCDCD
1923   // EFEFEF
1924   // ABABAB
1925   // CDCDCD
1926   // EFEFEF
1927   //
1928   // Arguments (all borrowed):
  //  stream: borrowed pointer to the stream that the 'broadcast' operation
  //    should be enqueued onto.
1931   //  dimensions: The dimensions of the input.
1932   //  input_data: un-owned device memory region which contains the
1933   //    input data for the input layer.
1934   //  replicate_x: Amount to replicate the input's X dimension.
1935   //  replicate_y: Amount to replicate the input's Y dimension.
  //  output_data: un-owned device memory region in which to place the
  //    broadcast result.
DoXYBroadcast(Stream * stream,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t replicate_x,int64_t replicate_y,DeviceMemory<float> * output_data)1938   virtual bool DoXYBroadcast(Stream* stream,
1939                              const dnn::BatchDescriptor& dimensions,
1940                              const DeviceMemory<float>& input_data,
1941                              int64_t replicate_x, int64_t replicate_y,
1942                              DeviceMemory<float>* output_data) {
1943     return false;
1944   }
1945 
1946   // Enqueues an asynchronous memcpy of the *quantized* output of a layer (that
1947   // is, bytes instead of scaled floats) into 'host_dst' if they are available
1948   // for the underlying DNN implementation. If this quantized output is not
1949   // available, false is returned, which will place 'stream' into an error
1950   // state.
1951   //
1952   // Arguments (all borrowed):
1953   //  stream: borrowed pointer to the stream that the 'quantized memcpy'
1954   //    operation should be enqueued onto.
1955   //  gpu_unquantized_src: the device memory that contains the unquantized data
1956   //    -- this data should also have a corresponding quantized representation
1957   //    on the device for this operation to succeed.
1958   //  mode: Type of quantization of the data to write into host_dst.
1959   //  host_dst: un-owned host memory region that is mutated in place,
1960   //    it is clobbered by the values in 'gpu_unquantized_src' when the enqueued
1961   //    (asynchronous) memcpy operation is performed.
1962   //  size: size in bytes of the host_dst host memory region.
1963   virtual bool DoMemcpyD2HQuantized(
1964       Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
1965       QuantizedActivationMode mode, void* host_dst, int64_t size) = 0;
1966 
  // Enqueues an asynchronous memcpy of 'host_src' into the *quantized* input
1968   // of a layer (that is, bytes instead of scaled floats) if they are supported
1969   // by the underlying DNN implementation. If this quantized input is not
1970   // supported, false is returned, which will place 'stream' into an error
1971   // state.
1972   //
1973   // Arguments (all borrowed):
1974   //  stream: borrowed pointer to the stream that the 'quantized memcpy'
1975   //    operation should be enqueued onto.
1976   //  host_src: un-owned host memory region that contains the quantized data.
1977   //  size: size in bytes of the host_src host memory region.
1978   //  mode: Type of quantization of the data to read from host_src.
1979   //  gpu_unquantized_dst: the device memory that is clobbered by the values in
1980   //    'host_src' when the enqueued (asynchronous) memcpy operation is
1981   //    performed. -- this data should also have a corresponding quantized
1982   //    representation on the device for this operation to
1983   //    succeed.
1984   virtual bool DoMemcpyH2DQuantized(
1985       Stream* stream, const void* host_src, int64_t size,
1986       QuantizedActivationMode mode,
1987       DeviceMemory<float>* gpu_unquantized_dst) = 0;
1988 
1989   // Create an RNN descriptor based on model shapes and configurations.
1990   // The caller retains the ownership of the descriptor.
1991   //
1992   // Arguments:
1993   //  num_layers: the number of layers for a RNN model.
1994   //  hidden_size: the size of the hidden state.
1995   //  input_size: the size of the input state.
  //  cell_size: the size of the cell state.
  //  batch_size: the size of a minibatch.
1997   //  input_mode: an enum to specify whether a linear transformation is added
1998   //    after the input state. If input_size is different from hidden_size, this
1999   //    is required.
2000   //  direction_mode: an enum to specify whether this model is unidirectional or
2001   //    bidirectional.
2002   //  rnn_mode: an enum to specify the type of model to build.
2003   //  data_type: an enum to specify the data types used in this model.
2004   //  dropout: the dropout threshold between layers. When it is 0., no dropout
2005   //    is added.
2006   //  seed: a seed for initializing the dropout layers.
  //  state_allocator: a memory allocator that will be used to store the state
2008   //    for dropout layer. The user has to maintain the memory until the model
2009   //    is no longer in use.
2010   //  use_padded_io: a bool to specify whether the input is using padded IO.
2011   virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
createRnnDescriptor(int num_layers,int hidden_size,int input_size,int cell_size,int batch_size,dnn::RnnInputMode input_mode,dnn::RnnDirectionMode direction_mode,dnn::RnnMode rnn_mode,dnn::DataType data_type,const dnn::AlgorithmConfig & algorithm_config,float dropout,uint64 seed,ScratchAllocator * state_allocator,bool use_padded_io)2012   createRnnDescriptor(int num_layers, int hidden_size, int input_size,
2013                       int cell_size, int batch_size,
2014                       dnn::RnnInputMode input_mode,
2015                       dnn::RnnDirectionMode direction_mode,
2016                       dnn::RnnMode rnn_mode, dnn::DataType data_type,
2017                       const dnn::AlgorithmConfig& algorithm_config,
2018                       float dropout, uint64 seed,
2019                       ScratchAllocator* state_allocator, bool use_padded_io) {
2020     return port::Status(port::error::UNIMPLEMENTED,
2021                         "createRnnDescriptor is unimplemented");
2022   }
2023 
2024   // Create a RNN sequence descriptor that specifies either the input or output
2025   // sequence. The caller retains the ownership of the returned descriptor.
2026   //
2027   // Arguments:
2028   //  max_seq_length: the max length of the sequences.
2029   //  batch_size: the size of a minibatch.
2030   //  data_size: the size of the state.
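  //  seq_lengths: the lengths of sequences in a batch (second overload only).
  //  time_major: whether the sequence data is laid out time-major (second
  //    overload only).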
2032   //  data_type: an enum to specify the type for the underlying data.
2033   virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int max_seq_length,int batch_size,int data_size,dnn::DataType data_type)2034   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2035                                     int data_size, dnn::DataType data_type) {
2036     return port::Status(port::error::UNIMPLEMENTED,
2037                         "createRnnSequenceTensorDescriptor is unimplemented");
2038   }
2039 
2040   virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int max_seq_length,int batch_size,int data_size,const absl::Span<const int> & seq_lengths,bool time_major,dnn::DataType data_type)2041   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2042                                     int data_size,
2043                                     const absl::Span<const int>& seq_lengths,
2044                                     bool time_major, dnn::DataType data_type) {
2045     return port::Status(port::error::UNIMPLEMENTED,
2046                         "createRnnSequenceTensorDescriptor is unimplemented");
2047   }
2048 
2049   // Create an RNN state descriptor that specifies the input or hidden state.
2050   // The caller retains the ownership of the returned descriptor.
2051   virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
createRnnStateTensorDescriptor(int num_layer,int batch_size,int data_size,dnn::DataType data_type)2052   createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
2053                                  dnn::DataType data_type) {
2054     return port::Status(port::error::UNIMPLEMENTED,
2055                         "createRnnStateTensorDescriptor is unimplemented");
2056   }
2057 
2058   // Enqueue a forward operation of the RNN model onto the stream.
2059   //
2060   // Arguments:
2061   //  stream: pointer to the stream where this operation should be enqueued to.
2062   //  rnn_desc: a RNN descriptor created by createRnnDescriptor.
2063   //  input_desc: descriptor for the input sequence.
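  //  input_data: the device memory region that contains the input data.
  //  seq_lengths_data: the device memory region that contains the length of
  //    each sequence in the batch.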
2065   //  input_h_desc: descriptor for the input "h" state.
2066   //  input_h_data: the device memory region that contains the input "h" data.
2067   //  input_c_desc: descriptor for the input "c" state.
2068   //  input_c_data: the device memory region that contains the input "c" data.
2069   //    This must be specified for LSTM models.
2070   //  params: the device memory region that contains the parameters used in this
2071   //    model.
2072   //  output_desc: descriptor for the output sequence.
2073   //  output_data: the memory region that stores the output sequence data.
2074   //  output_h_desc: descriptor for the output "h" state.
2075   //  output_h_data: the memory region that stores the output "h" data.
2076   //  output_c_desc: descriptor for the output "c" state.
2077   //  output_c_data: the memory region that stores the output "c" data. This
2078   //    must be specified for LSTM models.
  //  is_training: whether this is used in training or inference. That decides
  //    whether reserve_space data needs to be produced.
  //  reserve_space_allocator: if "is_training" is true, a memory allocator
  //    to create memory that holds the produced reserve_space. The caller
  //    retains the data and feeds it to the backward pass.
  //  workspace_allocator: an allocator to create temporary workspace used in
  //    this kernel. The caller is responsible for retaining the memory long
  //    enough for the lifespan of this operation, and recycling it afterwards.
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2087   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2088                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2089                             const DeviceMemory<Eigen::half>& input_data,
2090                             const DeviceMemory<int>& seq_lengths_data,
2091                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2092                             const DeviceMemory<Eigen::half>& input_h_data,
2093                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2094                             const DeviceMemory<Eigen::half>& input_c_data,
2095                             const DeviceMemory<Eigen::half>& params,
2096                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2097                             DeviceMemory<Eigen::half>* output_data,
2098                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2099                             DeviceMemory<Eigen::half>* output_h_data,
2100                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2101                             DeviceMemory<Eigen::half>* output_c_data,
2102                             bool is_training,
2103                             ScratchAllocator* reserve_space_allocator,
2104                             ScratchAllocator* workspace_allocator,
2105                             dnn::ProfileResult* output_profile_result) {
2106     return false;
2107   }
2108 
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2109   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2110                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2111                             const DeviceMemory<float>& input_data,
2112                             const DeviceMemory<int>& seq_lengths_data,
2113                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2114                             const DeviceMemory<float>& input_h_data,
2115                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2116                             const DeviceMemory<float>& input_c_data,
2117                             const DeviceMemory<float>& params,
2118                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2119                             DeviceMemory<float>* output_data,
2120                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2121                             DeviceMemory<float>* output_h_data,
2122                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2123                             DeviceMemory<float>* output_c_data,
2124                             bool is_training,
2125                             ScratchAllocator* reserve_space_allocator,
2126                             ScratchAllocator* workspace_allocator,
2127                             dnn::ProfileResult* output_profile_result) {
2128     return false;
2129   }
2130 
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2131   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2132                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2133                             const DeviceMemory<double>& input_data,
2134                             const DeviceMemory<int>& seq_lengths_data,
2135                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2136                             const DeviceMemory<double>& input_h_data,
2137                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2138                             const DeviceMemory<double>& input_c_data,
2139                             const DeviceMemory<double>& params,
2140                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2141                             DeviceMemory<double>* output_data,
2142                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2143                             DeviceMemory<double>* output_h_data,
2144                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2145                             DeviceMemory<double>* output_c_data,
2146                             bool is_training,
2147                             ScratchAllocator* reserve_space_allocator,
2148                             ScratchAllocator* workspace_allocator,
2149                             dnn::ProfileResult* output_profile_result) {
2150     return false;
2151   }
2152   // Enqueue a backward operation of the RNN model onto the stream.
2153   //
2154   // Arguments:
2155   //  stream: pointer to the stream where this operation should be enqueued to.
2156   //  rnn_desc: a RNN descriptor created by createRnnDescriptor.
2157   //  input_desc: descriptor for the input sequence.
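  //  input_data: the device memory region that contains the input data.
  //  seq_lengths_data: the device memory region that contains the length of
  //    each sequence in the batch.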
2159   //  input_h_desc: descriptor for the input "h" state.
2160   //  input_h_data: the device memory region that contains the input "h" data.
2161   //  input_c_desc: descriptor for the input "c" state.
2162   //  input_c_data: the device memory region that contains the input "c" data.
2163   //    This must be specified for LSTM models.
2164   //  params: the device memory region that contains the parameters used in this
2165   //    model.
2166   //  output_desc: descriptor for the output sequence.
2167   //  output_data: the memory region that stores the output sequence data.
2168   //  output_h_desc: descriptor for the output "h" state.
2169   //  output_h_data: the memory region that stores the output "h" data.
2170   //  output_c_desc: descriptor for the output "c" state.
2171   //  output_c_data: the memory region that stores the output "c" data. This
2172   //    must be specified for LSTM models.
2173   //  output_backprop_data: the device memory region that contains the backprop
2174   //    to the output sequence.
2175   //  output_h_backprop_data: the device memory region that contains the
2176   //    backprop to the output "h" state.
2177   //  output_c_backprop_data: the device memory region that contains the
2178   //    backprop to the output "c" state.
2179   //  input_backprop_data: the device memory region that stores the backprop
2180   //    to the input sequence.
2181   //  input_h_backprop_data: the device memory region that stores the backprop
2182   //    to the input "h" state.
2183   //  input_c_backprop_data: the device memory region that stores the backprop
2184   //    to the input "c" state.
2185   //  params_backprop_data: the device memory region that stores the backprop
2186   //    to the parameters.
2187   //  reserve_space_data: the reserve_space data that is produced by the forward
2188   //    operation. This memory region could be modified by this operation.
2189   //  workspace_allocator: a memory allocator that creates the temporary
2190   //    workspace memory used by this operation. The caller is responsible for
  //    keeping the memory alive long enough for this operation, and recycling
  //    it afterwards.
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2193   virtual bool DoRnnBackward(
2194       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2195       const dnn::RnnSequenceTensorDescriptor& input_desc,
2196       const DeviceMemory<Eigen::half>& input_data,
2197       const DeviceMemory<int>& seq_lengths_data,
2198       const dnn::RnnStateTensorDescriptor& input_h_desc,
2199       const DeviceMemory<Eigen::half>& input_h_data,
2200       const dnn::RnnStateTensorDescriptor& input_c_desc,
2201       const DeviceMemory<Eigen::half>& input_c_data,
2202       const DeviceMemory<Eigen::half>& params,
2203       const dnn::RnnSequenceTensorDescriptor& output_desc,
2204       const DeviceMemory<Eigen::half>& output_data,
2205       const dnn::RnnStateTensorDescriptor& output_h_desc,
2206       const DeviceMemory<Eigen::half>& output_h_data,
2207       const dnn::RnnStateTensorDescriptor& output_c_desc,
2208       const DeviceMemory<Eigen::half>& output_c_data,
2209       const DeviceMemory<Eigen::half>& output_backprop_data,
2210       const DeviceMemory<Eigen::half>& output_h_backprop_data,
2211       const DeviceMemory<Eigen::half>& output_c_backprop_data,
2212       DeviceMemory<Eigen::half>* input_backprop_data,
2213       DeviceMemory<Eigen::half>* input_h_backprop_data,
2214       DeviceMemory<Eigen::half>* input_c_backprop_data,
2215       DeviceMemory<Eigen::half>* params_backprop_data,
2216       DeviceMemory<uint8>* reserve_space_data,
2217       ScratchAllocator* workspace_allocator,
2218       dnn::ProfileResult* output_profile_result) {
2219     return false;
2220   }
2221 
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2222   virtual bool DoRnnBackward(
2223       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2224       const dnn::RnnSequenceTensorDescriptor& input_desc,
2225       const DeviceMemory<float>& input_data,
2226       const DeviceMemory<int>& seq_lengths_data,
2227       const dnn::RnnStateTensorDescriptor& input_h_desc,
2228       const DeviceMemory<float>& input_h_data,
2229       const dnn::RnnStateTensorDescriptor& input_c_desc,
2230       const DeviceMemory<float>& input_c_data,
2231       const DeviceMemory<float>& params,
2232       const dnn::RnnSequenceTensorDescriptor& output_desc,
2233       const DeviceMemory<float>& output_data,
2234       const dnn::RnnStateTensorDescriptor& output_h_desc,
2235       const DeviceMemory<float>& output_h_data,
2236       const dnn::RnnStateTensorDescriptor& output_c_desc,
2237       const DeviceMemory<float>& output_c_data,
2238       const DeviceMemory<float>& output_backprop_data,
2239       const DeviceMemory<float>& output_h_backprop_data,
2240       const DeviceMemory<float>& output_c_backprop_data,
2241       DeviceMemory<float>* input_backprop_data,
2242       DeviceMemory<float>* input_h_backprop_data,
2243       DeviceMemory<float>* input_c_backprop_data,
2244       DeviceMemory<float>* params_backprop_data,
2245       DeviceMemory<uint8>* reserve_space_data,
2246       ScratchAllocator* workspace_allocator,
2247       dnn::ProfileResult* output_profile_result) {
2248     return false;
2249   }
2250 
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2251   virtual bool DoRnnBackward(
2252       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2253       const dnn::RnnSequenceTensorDescriptor& input_desc,
2254       const DeviceMemory<double>& input_data,
2255       const DeviceMemory<int>& seq_lengths_data,
2256       const dnn::RnnStateTensorDescriptor& input_h_desc,
2257       const DeviceMemory<double>& input_h_data,
2258       const dnn::RnnStateTensorDescriptor& input_c_desc,
2259       const DeviceMemory<double>& input_c_data,
2260       const DeviceMemory<double>& params,
2261       const dnn::RnnSequenceTensorDescriptor& output_desc,
2262       const DeviceMemory<double>& output_data,
2263       const dnn::RnnStateTensorDescriptor& output_h_desc,
2264       const DeviceMemory<double>& output_h_data,
2265       const dnn::RnnStateTensorDescriptor& output_c_desc,
2266       const DeviceMemory<double>& output_c_data,
2267       const DeviceMemory<double>& output_backprop_data,
2268       const DeviceMemory<double>& output_h_backprop_data,
2269       const DeviceMemory<double>& output_c_backprop_data,
2270       DeviceMemory<double>* input_backprop_data,
2271       DeviceMemory<double>* input_h_backprop_data,
2272       DeviceMemory<double>* input_c_backprop_data,
2273       DeviceMemory<double>* params_backprop_data,
2274       DeviceMemory<uint8>* reserve_space_data,
2275       ScratchAllocator* workspace_allocator,
2276       dnn::ProfileResult* output_profile_result) {
2277     return false;
2278   }
2279 
2280   template <typename ElementType>
PrepareForCtcLoss(Stream * stream,const RnnStateTensorDescriptor & probs_desc,DeviceMemory<ElementType> probs_data,const RnnStateTensorDescriptor & grads_desc,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,ScratchAllocator * workspace_allocator,DeviceMemory<uint8> * scratch_memory,int * ctc_loss_algo_id)2281   port::Status PrepareForCtcLoss(Stream* stream,
2282                                  const RnnStateTensorDescriptor& probs_desc,
2283                                  DeviceMemory<ElementType> probs_data,
2284                                  const RnnStateTensorDescriptor& grads_desc,
2285                                  absl::Span<const int> labels_data,
2286                                  absl::Span<const int> labels_lengths_data,
2287                                  absl::Span<const int> input_lengths_data,
2288                                  ScratchAllocator* workspace_allocator,
2289                                  DeviceMemory<uint8>* scratch_memory,
2290                                  int* ctc_loss_algo_id) {
2291     return DoPrepareForCtcLoss(
2292         stream, ToDataType<ElementType>::value, probs_desc, grads_desc,
2293         labels_data, labels_lengths_data, input_lengths_data,
2294         workspace_allocator, scratch_memory, ctc_loss_algo_id);
2295   }
2296 
  // Enqueues a CTC Loss operation onto the stream.
  //
  // Arguments:
  //  stream: pointer to the stream where this operation should be enqueued to.
  //  element_type: data type of the input tensors.
  //  probs_desc: specifies the shape and the data layout of the input tensor.
  //  probs_data: the device memory region that contains the input tensor.
  //  labels_data: span holding the labels_value tensor.
  //  labels_lengths_data: span holding the labels_lengths tensor.
  //  input_lengths_data: span holding the seq_lengths tensor.
  //    NOTE(review): the three spans above are plain absl::Span<const int>,
  //    not DeviceMemory, so they are likely host-side arrays — confirm.
  //  costs_data: the device memory region that contains the costs tensor.
  //  grads_desc: specifies the shape and the data layout of the grads tensor.
  //  grads_data: the device memory region that contains the grads tensor.
  //  scratch_memory: scratch space used by this operation, typically the
  //    buffer produced by PrepareForCtcLoss. The caller is responsible for
  //    keeping the memory alive long enough for this operation, and recycles
  //    it afterwards.
  //  ctc_loss_algo_id: algorithm identifier, typically the one selected by
  //    PrepareForCtcLoss.
  //
  // Returns a port::Status; a non-OK status indicates failure.
  virtual port::Status DoCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const RnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
2326 
2327   template <typename ElementType>
DoCtcLoss(Stream * stream,const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemory<ElementType> & probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemory<ElementType> * costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemory<ElementType> * grads_data,DeviceMemory<uint8> * scratch_memory,int ctc_loss_algo_id)2328   bool DoCtcLoss(Stream* stream,
2329                  const dnn::RnnStateTensorDescriptor& probs_desc,
2330                  const DeviceMemory<ElementType>& probs_data,
2331                  absl::Span<const int> labels_data,
2332                  absl::Span<const int> labels_lengths_data,
2333                  absl::Span<const int> input_lengths_data,
2334                  DeviceMemory<ElementType>* costs_data,
2335                  const dnn::RnnStateTensorDescriptor& grads_desc,
2336                  DeviceMemory<ElementType>* grads_data,
2337                  DeviceMemory<uint8>* scratch_memory, int ctc_loss_algo_id) {
2338     return IsStatusOk(
2339         DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
2340                   probs_data, labels_data, labels_lengths_data,
2341                   input_lengths_data, *costs_data, grads_desc, *grads_data,
2342                   *scratch_memory, ctc_loss_algo_id),
2343         false);
2344   }
2345 
2346   // Transforms a tensor into another tensor with a different layout and/or data
2347   // type.
2348   //
2349   // Arguments:
2350   //  stream: pointer to the stream where this operation should be enqueued to.
2351   //  input_desc: specifies the shape and the data layout of the input tensor.
2352   //  input_type: the data type of the input tensor.
2353   //  input_data: the device memory region that contains the input tensor.
2354   //  output_desc: specifies the shape and the data layout of the output tensor.
2355   //  output_type: the data type of the output tensor.
2356   //  scale: an element-wise scaling factor to apply.
2357   //  output_data: the device memory region that contains the output tensor.
DoTransformTensor(Stream * stream,const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)2358   virtual bool DoTransformTensor(Stream* stream,
2359                                  const dnn::BatchDescriptor& input_desc,
2360                                  dnn::DataType input_type,
2361                                  const DeviceMemoryBase& input_data,
2362                                  const dnn::BatchDescriptor& output_desc,
2363                                  dnn::DataType output_type, float scale,
2364                                  DeviceMemoryBase* output_data) {
2365     return false;
2366   }
2367 
2368   // Enqueues a fused convolution+bias+activation operation onto the stream.
2369   //
2370   // Arguments (all borrowed):
2371   //
2372   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2373   //  be enqueued onto.
2374   //
2375   //  conv_input_descriptor: dimensions of the convolution input layer.
2376   //  conv_input_data: device memory which contains the convolution input.
2377   //
2378   //  filter_descriptor: dimensions of the convolution filter.
2379   //  filter_data: device memory which contains the convolution filter weights.
2380   //
2381   //  convolution_descriptor: stride of the convolution filter.
2382   //
2383   //  bias_descriptor: dimensions of the bias layer
2384   //  biases: device memory region containing biases to add to the convolution
2385   //  output
2386   //
2387   //  activation_mode: Type of activation to perform.
2388   //
2389   //  output_descriptor: dimensions of the output layer.
2390   //  output_data: device memory region in which to place the fusion result.
2391   //
2392   //  output_profile_result: the output profile result for this call.
2393   //         The profiling is only enabled when this is not nullptr.
2394   //
DoFusedConvolutionBiasActivation(Stream * stream,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & bias_data,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output_data,dnn::ProfileResult * output_profile_result)2395   virtual bool DoFusedConvolutionBiasActivation(
2396       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
2397       const DeviceMemory<float>& conv_input_data,
2398       const dnn::FilterDescriptor& filter_descriptor,
2399       const DeviceMemory<float>& filter_data,
2400       const dnn::ConvolutionDescriptor& convolution_descriptor,
2401       const dnn::BatchDescriptor& bias_descriptor,
2402       const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
2403       const dnn::BatchDescriptor& output_descriptor,
2404       DeviceMemory<float>* output_data,
2405       dnn::ProfileResult* output_profile_result) {
2406     return false;
2407   }
2408 
2409   // Enqueues a fused batchnorm+activation (inference) operation onto the
2410   // stream.
2411   //
2412   // Arguments (all borrowed):
2413   //
2414   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2415   //  be enqueued onto.
2416   //
2417   //  x_descriptor: dimensions of the batchnorm input layer.
2418   //  x_data: device memory which contains the batchnorm input.
2419   //
2420   //  scale_offset_mean_variance_descriptor:
2421   //      dimensions of the scale/offset/mean/variance tensor.
2422   //  scale_data: device memory which contains the scale input.
2423   //  offset_data: device memory which contains the offset input.
2424   //  mean_data: device memory which contains the mean input.
2425   //  variance_data: device memory which contains the variance input.
2426   //  epsilon : the epsilon value to use in batchnorm calculation
2427   //
2428   //  activation_mode: Type of activation to perform.
2429   //
2430   //  y_data: device memory region in which to place the fusion result.
2431   //
2432   //  output_profile_result: the output profile result for this call.
2433   //         The profiling is only enabled when this is not nullptr.
2434   //
DoFusedBatchNormActivationInference(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<float> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & mean_data,const DeviceMemory<float> & variance_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * y_data,dnn::ProfileResult * output_profile_result)2435   virtual bool DoFusedBatchNormActivationInference(
2436       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2437       const DeviceMemory<float>& x_data,
2438       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2439       const DeviceMemory<float>& scale_data,
2440       const DeviceMemory<float>& offset_data,
2441       const DeviceMemory<float>& mean_data,
2442       const DeviceMemory<float>& variance_data, double epsilon,
2443       dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
2444       dnn::ProfileResult* output_profile_result) {
2445     return false;
2446   }
2447 
DoFusedBatchNormActivationInference(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<Eigen::half> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & mean_data,const DeviceMemory<float> & variance_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y_data,dnn::ProfileResult * output_profile_result)2448   virtual bool DoFusedBatchNormActivationInference(
2449       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2450       const DeviceMemory<Eigen::half>& x_data,
2451       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2452       const DeviceMemory<float>& scale_data,
2453       const DeviceMemory<float>& offset_data,
2454       const DeviceMemory<float>& mean_data,
2455       const DeviceMemory<float>& variance_data, double epsilon,
2456       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
2457       dnn::ProfileResult* output_profile_result) {
2458     return false;
2459   }
2460 
2461   // Enqueues a fused batchnorm+activation (training-fwd) operation onto the
2462   // stream.
2463   //
2464   // Arguments (all borrowed):
2465   //
2466   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2467   //  be enqueued onto.
2468   //
2469   //  x_descriptor: dimensions of the batchnorm input layer.
2470   //  x_data: device memory which contains the batchnorm input.
2471   //
2472   //  scale_offset_mean_variance_descriptor:
2473   //      dimensions of the scale/offset/mean/variance tensor.
2474   //  scale_data: device memory which contains the scale input.
2475   //  offset_data: device memory which contains the offset input.
2476   //  epsilon : the epsilon value to use in batchnorm calculation
2477   //
2478   //  activation_mode: Type of activation to perform.
2479   //
2480   //  y_data: device memory region in which to place the fusion result.
2481   //  batch_mean_data: device memory in which to place the batch mean output.
2482   //  batch_var_data: device memory in which to place the batch variance output.
2483   //  saved_mean_data: device memory in which to save the mean for bwd pass.
2484   //  saved_var_data: device memory in which to save the variance for bwd pass.
2485   //
2486   //  output_profile_result: the output profile result for this call.
2487   //         The profiling is only enabled when this is not nullptr.
2488   //
DoFusedBatchNormActivationForward(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<float> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * y_data,DeviceMemory<float> * batch_mean_data,DeviceMemory<float> * batch_var_data,DeviceMemory<float> * saved_mean_data,DeviceMemory<float> * saved_var_data,dnn::ProfileResult * output_profile_result)2489   virtual bool DoFusedBatchNormActivationForward(
2490       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2491       const DeviceMemory<float>& x_data,
2492       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2493       const DeviceMemory<float>& scale_data,
2494       const DeviceMemory<float>& offset_data, double epsilon,
2495       dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
2496       DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
2497       DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
2498       dnn::ProfileResult* output_profile_result) {
2499     return false;
2500   }
2501 
DoFusedBatchNormActivationForward(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<Eigen::half> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y_data,DeviceMemory<float> * batch_mean_data,DeviceMemory<float> * batch_var_data,DeviceMemory<float> * saved_mean_data,DeviceMemory<float> * saved_var_data,dnn::ProfileResult * output_profile_result)2502   virtual bool DoFusedBatchNormActivationForward(
2503       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2504       const DeviceMemory<Eigen::half>& x_data,
2505       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2506       const DeviceMemory<float>& scale_data,
2507       const DeviceMemory<float>& offset_data, double epsilon,
2508       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
2509       DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
2510       DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
2511       dnn::ProfileResult* output_profile_result) {
2512     return false;
2513   }
2514 
2515   // Enqueues a fused batchnorm+activation (training-bwd) operation onto the
2516   // stream.
2517   //
2518   // Arguments (all borrowed):
2519   //
2520   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2521   //  be enqueued onto.
2522   //
2523   //  y_act_backprop_descriptor: dimensions of the backprop input from the
2524   //  previous layer. y_act_backprop_data: device memory which contains the
2525   //  backprop input.
2526   //
2527   //  y_act_data: device memory which contains the actv-fwd output data.
2528   //
2529   //  activation_mode: actv-fwd type.
2530   //
2531   //  scale_offset_mean_variance_descriptor:
2532   //      dimensions of the scale/offset/mean/variance tensor.
2533   //  scale_data: device memory which contains the scale input.
2534   //  offset_data: device memory which contains the offset input.
2535   //  saved_mean_data: device memory which contains the saved mean from fwd
2536   //  pass. saved_var_data: device memory which contains the saved variance from
2537   //  fwd pass.
2538   //
2539   //  x_bn_backprop_data: device memory region in which to place the backprop
2540   //  data from this layer scale_backprop_data: device memory in which to place
2541   //  the scale backprop output. offset_backprop_data: device memory in which to
2542   //  place the offset backprop output.
2543   //
2544   //  output_profile_result: the output profile result for this call.
2545   //         The profiling is only enabled when this is not nullptr.
2546   //
DoFusedBatchNormActivationBackward(Stream * stream,const dnn::BatchDescriptor & y_act_backprop_descriptor,const DeviceMemory<float> & y_act_backprop_data,const DeviceMemory<float> & y_act_data,dnn::ActivationMode activation_mode,const DeviceMemory<float> & x_bn_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & saved_mean_data,const DeviceMemory<float> & saved_var_data,DeviceMemory<float> * x_bn_backprop_data,DeviceMemory<float> * scale_backprop_data,DeviceMemory<float> * offset_backprop_data,dnn::ProfileResult * output_profile_result)2547   virtual bool DoFusedBatchNormActivationBackward(
2548       Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
2549       const DeviceMemory<float>& y_act_backprop_data,
2550       const DeviceMemory<float>& y_act_data,
2551       dnn::ActivationMode activation_mode, const DeviceMemory<float>& x_bn_data,
2552       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2553       const DeviceMemory<float>& scale_data,
2554       const DeviceMemory<float>& offset_data,
2555       const DeviceMemory<float>& saved_mean_data,
2556       const DeviceMemory<float>& saved_var_data,
2557       DeviceMemory<float>* x_bn_backprop_data,
2558       DeviceMemory<float>* scale_backprop_data,
2559       DeviceMemory<float>* offset_backprop_data,
2560       dnn::ProfileResult* output_profile_result) {
2561     return false;
2562   }
2563 
DoFusedBatchNormActivationBackward(Stream * stream,const dnn::BatchDescriptor & y_act_backprop_descriptor,const DeviceMemory<Eigen::half> & y_act_backprop_data,const DeviceMemory<Eigen::half> & y_act_data,dnn::ActivationMode activation_mode,const DeviceMemory<Eigen::half> & x_bn_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & saved_mean_data,const DeviceMemory<float> & saved_var_data,DeviceMemory<Eigen::half> * x_bn_backprop_data,DeviceMemory<float> * scale_backprop_data,DeviceMemory<float> * offset_backprop_data,dnn::ProfileResult * output_profile_result)2564   virtual bool DoFusedBatchNormActivationBackward(
2565       Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
2566       const DeviceMemory<Eigen::half>& y_act_backprop_data,
2567       const DeviceMemory<Eigen::half>& y_act_data,
2568       dnn::ActivationMode activation_mode,
2569       const DeviceMemory<Eigen::half>& x_bn_data,
2570       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2571       const DeviceMemory<float>& scale_data,
2572       const DeviceMemory<float>& offset_data,
2573       const DeviceMemory<float>& saved_mean_data,
2574       const DeviceMemory<float>& saved_var_data,
2575       DeviceMemory<Eigen::half>* x_bn_backprop_data,
2576       DeviceMemory<float>* scale_backprop_data,
2577       DeviceMemory<float>* offset_backprop_data,
2578       dnn::ProfileResult* output_profile_result) {
2579     return false;
2580   }
2581 
2582  protected:
  // Returns whether status is 'ok', and potentially logs the error.
  // NOTE(review): presumably the error is logged only when report_error is
  // true — confirm against the out-of-line definition.
  static bool IsStatusOk(const port::Status& status, bool report_error);
2585 
2586  private:
DoPrepareForConvolution(ConvolutionKind kind,DataType element_type,Stream * stream,const BatchDescriptor & batch_descriptor,DeviceMemoryBase input_data,const FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const ConvolutionDescriptor & convolution_descriptor,const AlgorithmConfig & algorithm_config,ScratchAllocator * scratch_allocator,AlgorithmDesc * algorithm_desc,DeviceMemory<uint8> * scratch_memory)2587   virtual port::Status DoPrepareForConvolution(
2588       ConvolutionKind kind, DataType element_type, Stream* stream,
2589       const BatchDescriptor& batch_descriptor, DeviceMemoryBase input_data,
2590       const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data,
2591       const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
2592       const ConvolutionDescriptor& convolution_descriptor,
2593       const AlgorithmConfig& algorithm_config,
2594       ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
2595       DeviceMemory<uint8>* scratch_memory) {
2596     *algorithm_desc = {};
2597     *scratch_memory = {};
2598     return port::Status::OK();
2599   }
2600 
DoPrepareForCtcLoss(Stream * stream,DataType element_type,const RnnStateTensorDescriptor & probs_desc,const RnnStateTensorDescriptor & grads_desc,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,ScratchAllocator * scratch_allocator,DeviceMemory<uint8> * scratch_memory,int * ctc_loss_algo_id)2601   virtual port::Status DoPrepareForCtcLoss(
2602       Stream* stream, DataType element_type,
2603       const RnnStateTensorDescriptor& probs_desc,
2604       const RnnStateTensorDescriptor& grads_desc,
2605       absl::Span<const int> labels_data,
2606       absl::Span<const int> labels_lengths_data,
2607       absl::Span<const int> input_lengths_data,
2608       ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
2609       int* ctc_loss_algo_id) {
2610     *scratch_memory = {};
2611     return port::Status::OK();
2612   }
2613 
2614   SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
2615 };
2616 
2617 }  // namespace dnn
2618 }  // namespace stream_executor
2619 
2620 #endif  // TENSORFLOW_STREAM_EXECUTOR_DNN_H_
2621