1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Neural Net operation support for StreamExecutor instances.
17 //
18 // This is an abstract interface for a platform to optionally support common
19 // neural net operations; it accommodates implementations such as the cudnn
20 // library operations.
21 
22 #ifndef TENSORFLOW_STREAM_EXECUTOR_DNN_H_
23 #define TENSORFLOW_STREAM_EXECUTOR_DNN_H_
24 
25 #include <functional>
26 #include <limits>
27 #include <memory>
28 #include <tuple>
29 #include <type_traits>
30 
31 #include "absl/types/optional.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/stream_executor/data_type.h"
34 #include "tensorflow/stream_executor/device_memory.h"
35 #include "tensorflow/stream_executor/dnn.pb.h"
36 #include "tensorflow/stream_executor/lib/array_slice.h"
37 #include "tensorflow/stream_executor/lib/status.h"
38 #include "tensorflow/stream_executor/lib/statusor.h"
39 #include "tensorflow/stream_executor/platform/logging.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 
42 namespace Eigen {
43 struct half;
44 }  // namespace Eigen
45 
46 namespace stream_executor {
47 
48 class HostBuffer;
49 class Stream;
50 class ScratchAllocator;
51 
52 namespace dnn {
53 
54 // Specifies an index to use when accessing specific spatial dimensions.
55 enum class DimIndex : int {
56   X = 0,
57   Y = 1,
58   Z = 2,
59 };
60 
61 // Helper functions to make methods more readable.
62 inline int64 GetDim(absl::Span<const int64> data, DimIndex dim) {
63   return data.rbegin()[static_cast<int64>(dim)];
64 }
65 
66 inline void SetDim(absl::Span<int64> data, DimIndex dim, int64 value) {
67   data.rbegin()[static_cast<int64>(dim)] = value;
68 }
69 
70 inline void SetDim(std::vector<int64>* data, DimIndex dim, int64 value) {
71   return SetDim(absl::MakeSpan(*data), dim, value);
72 }
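
// GetDim and SetDim index from the minor (trailing) end of the dimension
// array, so for data stored as {..., y, x}, DimIndex::X refers to the last
// element. A minimal illustrative sketch (the values are hypothetical):
//
//   std::vector<int64> dims = {4, 8, 16};  // {z, y, x}
//   GetDim(dims, DimIndex::X);             // == 16
//   GetDim(dims, DimIndex::Y);             // == 8
//   SetDim(&dims, DimIndex::Z, 2);         // dims becomes {2, 8, 16}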
73 
74 // int64 is not the same type as tensorflow::protobuf_int64 in open-source. This
75 // wrapper function gives an int64 array slice view of a repeated int64 protobuf
76 // field.
77 //
78 // T should be a protobuf RepeatedField.
79 template <typename T>
80 inline absl::Span<const int64> AsInt64Slice(const T& repeated_field) {
81   using data_ty =
82       typename std::remove_reference<decltype(*repeated_field.data())>::type;
83   static_assert(std::is_integral<data_ty>::value &&
84                     std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
85                 "repeated_field.data() must return a pointer to a signed "
86                 "64-bit integer type.");
87   return absl::Span<const int64>(
88       reinterpret_cast<const int64*>(repeated_field.data()),
89       repeated_field.size());
90 }
91 template <typename T>
92 inline absl::Span<int64> AsInt64Slice(T* repeated_field) {
93   using data_ty =
94       typename std::remove_reference<decltype(*repeated_field->data())>::type;
95   static_assert(std::is_integral<data_ty>::value &&
96                     std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
97                 "repeated_field->data() must return a pointer to a signed "
98                 "64-bit integer type.");
99   return absl::Span<int64>(
100       reinterpret_cast<int64*>(repeated_field->mutable_data()),
101       repeated_field->size());
102 }
103 
104 // Returns a string representation of the given data layout.
105 std::string DataLayoutString(DataLayout layout);
106 
107 // Specifies a quantization for activations in a given BatchDescriptor.
108 enum class QuantizedActivationMode {
109   k8Bit = 1,
110   k16Bit = 2,
111   k32Bit = 4,
112 };
113 
114 // Specifies the types of a RNN model.
115 enum class RnnMode {
116   kRnnRelu = 0,
117   kRnnTanh = 1,
118   kRnnLstm = 2,
119   kRnnGru = 3,
120 };
121 
122 // Specifies the input model and whether there is a linear transformation
123 // between the input state and the first layer hidden state.
124 enum class RnnInputMode {
125   kRnnLinearSkip = 0,
126   kRnnSkipInput = 1,
127 };
128 
129 // Specifies the number of directions used in a RNN model. When bidirectional
130 // mode is used, the input states and output sequence contain data for both
131 // directions.
132 enum class RnnDirectionMode {
133   kRnnUnidirectional = 0,
134   kRnnBidirectional = 1,
135 };
136 
137 // Relevant to DepthToSpace and SpaceToDepth. This is the write layout when
138 // performing depth to space and the read layout when performing space to depth.
139 // It's specified with most-major dimension first and most-minor dimension last.
140 // In DepthToSpace, the D*M^2 values are read in and then, for DepthHeightWidth,
141 // written out to the output patch, by varying first width, then height, then
142 // depth. In C array format, it looks like [depth][height][width]. See
143 // DepthToSpace comment for more information.
144 enum class DepthToSpaceLayout { DepthHeightWidth };
145 
146 // Specifies the descriptor for a RNN model.
147 //
148 // An example use case:
149 //   * The user first creates a model through createRnnDescriptor.
150 //   * The user queries the size of the underlying opaque parameter buffer.
151 //   * The user creates and initializes a parameter buffer of the proper size.
152 //   * The user runs forward and backward operations using this RNN descriptor.
153 //   * Once in a while, the user queries maintainable weights and bias regions
154 //       from the underlying parameter buffer. They are more likely to be forward
155 //       compatible and should be used when saving and restoring a model.
156 //   * The user releases the RNN descriptor when the model is no longer in use.
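//
// A minimal sketch of that flow (buffer allocation is elided;
// createRnnDescriptor is the factory mentioned above, provided by the
// platform's DnnSupport backend):
//
//   std::unique_ptr<RnnDescriptor> rnn_desc = ...;  // from createRnnDescriptor
//   int64 params_bytes = rnn_desc->ParamsSizeInBytes();
//   // ... allocate and initialize a parameter buffer of params_bytes ...
//   RnnDescriptor::ParamsRegions weights = rnn_desc->ParamsWeightRegions();
//   RnnDescriptor::ParamsRegions biases = rnn_desc->ParamsBiasRegions();
//   // Each ParamsRegion is an {offset, size} pair into that buffer.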
157 class RnnDescriptor {
158  public:
159   struct ParamsRegion {
160     int64 offset;
161     int64 size;
162   };
163   typedef std::vector<ParamsRegion> ParamsRegions;
164   virtual ~RnnDescriptor() {}
165   virtual int64 ParamsSizeInBytes() const { return -1; }
166   virtual ParamsRegions ParamsWeightRegions() const { return ParamsRegions(); }
167   virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
168 };
169 
170 // Specifies the sequence in a RNN model.
171 //
172 // The user is responsible for releasing this descriptor when it is no longer
173 // in use. The destructor releases the underlying descriptors.
174 class RnnSequenceTensorDescriptor {
175  public:
176   virtual ~RnnSequenceTensorDescriptor() {}
177 };
178 
179 // Specifies either the input or hidden state in a RNN model.
180 //
181 // The user is responsible for releasing this descriptor when it is no longer
182 // in use. The destructor releases the underlying descriptors.
183 class RnnStateTensorDescriptor {
184  public:
185   virtual ~RnnStateTensorDescriptor() {}
186 };
187 
188 // Returns a string representation of the given quantization mode.
189 std::string QuantizedActivationModeString(QuantizedActivationMode mode);
190 
191 // Describes the dimensions that a layer consumes/produces.
192 //
193 // This is a matrix (height, width), its "depth" (feature_map_count),
194 // how many of these matrices are present (count),
195 // and the maximum and minimum values expected in the matrix (value_max,
196 // value_min).
197 // If input is quantized, all values greater
198 // than value_max will be clipped to value_max and all values less than
199 // value_min will be clipped to value_min.
200 // When quantized output is dequantized no value will be greater than
201 // value_max or less than value_min.
202 //
203 // Uses the named argument construction form:
204 //
205 //  auto input_batch_dimensions =
206 //      BatchDescriptor().set_count(42).set_feature_map_count(7)...
207 //
208 // Details:
209 //
210 // For a convolutional layer, a single inference takes a 3-dimensional matrix
211 // of input and produces a 3-dimensional matrix of output. We call the three
212 // dimensions height, width and feature_map_count, where for an image, the
213 // height and width correspond to the Y and X pixel indices, respectively, and
214 // the feature_map_count corresponds to the RGB dimension of the input data.
215 // Then the count indicates how many 3D matrices are being presented to be
216 // processed at once; this corresponds to the neural network concept of
217 // minibatch size.
218 //
219 // For a fully connected layer, it's better to put the nodes of the layer in
220 // the feature_map_count, and leave the height and width as degenerate (== 1).
221 // Count indicates how many input vectors (degenerate 3D matrices) are to be
222 // processed.
223 //
224 // If unspecified, value_max and value_min default to 0.0.
225 // If value_max == value_min the Stream will attempt to derive valid values -
226 // for example the output of Relu6 activation will always be in the range
227 // [0.0, 6.0].
228 //
229 // If unspecified, layout defaults to kYXDepthBatch.
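//
// For example, a minibatch of 32 RGB images of size 28x28 could be described
// as follows (a sketch; the values are illustrative):
//
//   BatchDescriptor input;
//   input.set_count(32)
//        .set_feature_map_count(3)
//        .set_height(28)
//        .set_width(28);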
230 class BatchDescriptor {
231  public:
232   // Creates a "blank" batch descriptor, which should be initialized via the
233   // named argument helpers.
234   BatchDescriptor();
235   explicit BatchDescriptor(int ndims);
236 
237   // Clones values from 'other' for initialization.
238   void CloneFrom(const BatchDescriptor& other);
239 
240   std::string ToString() const;
241   std::string ToShortString() const;
242 
243   // Pre-condition:
244   //   value_max_ == 0
245   //   value_min_ == 0
246   //   quantized_activation_mode_ == QuantizedActivationMode::k8Bit
247   TensorDescriptorProto ToProto(DataType data_type) const;
248 
249   // Accessors.
250   int64 count() const { return tensor_.dimensions(0); }
251   int64 feature_map_count() const { return tensor_.dimensions(1); }
252   int64 height() const { return GetDim(spatial_size(), DimIndex::Y); }
253   int64 width() const { return GetDim(spatial_size(), DimIndex::X); }
254   int64 spatial_dim(DimIndex dim) const { return GetDim(spatial_size(), dim); }
255   int ndims() const { return spatial_size().size(); }
256   float value_max() const { return value_max_; }
257   float value_min() const { return value_min_; }
258   DataLayout layout() const { return tensor_.data_layout(); }
259   QuantizedActivationMode quantized_activation_mode() const {
260     return quantized_activation_mode_;
261   }
262   // Full dimensions of the underlying data, ordered according to a specific
263   // layout.
264   std::vector<int64> full_dims(const DataLayout& layout) const;
265 
266   // Full strides of the underlying data, ordered according to a specific
267   // layout.
268   std::vector<int64> full_strides(const DataLayout& layout) const;
269 
270   // Named-argument helpers for avoiding user error during construction.
271   BatchDescriptor& set_count(int64 value) {
272     tensor_.set_dimensions(0, value);
273     return *this;
274   }
275   BatchDescriptor& set_feature_map_count(int64 value) {
276     tensor_.set_dimensions(1, value);
277     return *this;
278   }
279   BatchDescriptor& set_height(int64 value) {
280     SetDim(spatial_size(), DimIndex::Y, value);
281     return *this;
282   }
283   BatchDescriptor& set_width(int64 value) {
284     SetDim(spatial_size(), DimIndex::X, value);
285     return *this;
286   }
287   BatchDescriptor& set_spatial_dim(DimIndex dim, int64 value) {
288     SetDim(spatial_size(), dim, value);
289     return *this;
290   }
291   BatchDescriptor& set_value_max(float value) {
292     value_max_ = value;
293     return *this;
294   }
295   BatchDescriptor& set_value_min(float value) {
296     value_min_ = value;
297     return *this;
298   }
299   BatchDescriptor& set_layout(DataLayout layout) {
300     tensor_.set_data_layout(layout);
301     return *this;
302   }
303   BatchDescriptor& set_quantized_activation_mode(
304       QuantizedActivationMode quantized_activation_mode) {
305     quantized_activation_mode_ = quantized_activation_mode;
306     return *this;
307   }
308 
309   // Return the number of nodes in a single feature map.
310   int64 NodesPerFeatureMap() const;
311 
312   // Return the number of nodes across all feature maps. Note that this is not
313   // affected by the batch count.
314   int64 NodesAcrossFeatureMaps() const;
315 
316   // Returns the number of elements (e.g. RGB pixel values) required to hold a
317   // given batch descriptor, given a no-padding assumption. Note that this is
318   // affected by the batch count.
319   int64 ElementCount() const;
320 
321   // Return the number of weights required to fully connect a layer with
322   // dimensions given by the 'input' descriptor with a layer with dimensions
323   // given by the 'output' descriptor.
324   static int64 FullyConnectedWeightCount(const BatchDescriptor& input,
325                                          const BatchDescriptor& output);
326 
327   // Return the number of biases required to fully connect to an output layer
328   // with dimensions given the 'output' descriptor.
329   static int64 FullyConnectedBiasCount(const BatchDescriptor& output);
330 
331   // Return a BatchDescriptor for the output of a depth concatenation
332   // with the given input descriptors. The inputs should have the same
333   // dimensions, except possibly for feature_map_count(), though this
334   // function does not verify that.
335   static BatchDescriptor DepthConcatenateOutputDescriptor(
336       port::ArraySlice<dnn::BatchDescriptor> inputs);
337 
338  private:
339   absl::Span<const int64> spatial_size() const {
340     return AsInt64Slice(tensor_.dimensions()).subspan(2);
341   }
342 
343   absl::Span<int64> spatial_size() {
344     return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
345   }
346 
347   TensorDescriptorProto tensor_;
348   float value_max_;
349   float value_min_;
350   QuantizedActivationMode quantized_activation_mode_;
351 };
352 
353 // Returns a string representation of the given filter layout.
354 std::string FilterLayoutString(FilterLayout layout);
355 
356 // Describes a filter for the convolution. This is the "window" from
357 // height-by-width patches of each of the feature maps in the input layer to the
358 // cells within the output feature map.
359 //
360 // Uses the named argument construction form:
361 //
362 //  FilterDescriptor filter_dimensions;
363 //  filter_dimensions
364 //    .set_output_feature_map_count(42)
365 //    .set_input_feature_map_count(7)
366 //    ...
367 //
368 // Arguments:
369 // - output_feature_map_count: number of feature maps in the output layer.
370 // - input_feature_map_count: number of feature maps in the input layer (from
371 //      which the filter patch is taken).
372 // - input_filter_height: "height" number of neurons used in the sliding window
373 //      over the input layer.
374 // - input_filter_width: "width" number of neurons used in the sliding window
375 //      over the input layer.
376 //
377 // Sometimes names like "filter input height" are referred to by synonymous
378 // terminology, such as "kernel y size".
379 //
380 // If unspecified, layout defaults to kOutputInputYX.
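//
// For example, a 5x5 filter mapping 3 input feature maps to 64 output feature
// maps could be described as follows (a sketch; the values are illustrative):
//
//   FilterDescriptor filter;
//   filter.set_output_feature_map_count(64)
//         .set_input_feature_map_count(3)
//         .set_input_filter_height(5)
//         .set_input_filter_width(5);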
381 class FilterDescriptor {
382  public:
383   // When default-constructed, all dimensions are set to zero, so they should all
384   // be populated by the user via the named-argument helpers below. (See class
385   // comment for details.)
386   FilterDescriptor();
387   explicit FilterDescriptor(int ndims);
388   ~FilterDescriptor();
389 
390   // Named-argument helpers for avoiding user error during construction.
391   FilterDescriptor& set_output_feature_map_count(int64 value) {
392     tensor_.set_dimensions(0, value);
393     return *this;
394   }
395   FilterDescriptor& set_input_feature_map_count(int64 value) {
396     tensor_.set_dimensions(1, value);
397     return *this;
398   }
399   FilterDescriptor& set_input_filter_height(int64 value) {
400     SetDim(input_filter_dims(), DimIndex::Y, value);
401     return *this;
402   }
403   FilterDescriptor& set_input_filter_width(int64 value) {
404     SetDim(input_filter_dims(), DimIndex::X, value);
405     return *this;
406   }
407   FilterDescriptor& set_layout(FilterLayout layout) {
408     tensor_.set_filter_layout(layout);
409     return *this;
410   }
411   FilterDescriptor& set_spatial_dim(DimIndex dim, int64 value) {
412     SetDim(input_filter_dims(), dim, value);
413     return *this;
414   }
415   int ndims() const { return input_filter_dims().size(); }
416 
417   void CloneFrom(const FilterDescriptor& other);
418 
419   std::string ToString() const;
420   std::string ToShortString() const;
421   TensorDescriptorProto ToProto(DataType data_type) const;
422 
423   // Returns the number of weights required as parameters for a convolution
424   // using this filter descriptor.
425   int64 ComputeWeightCount() const;
426 
427   // Returns the number of biases required as parameters for a convolution
428   // using this filter descriptor.
429   int64 bias_count() const { return output_feature_map_count(); }
430 
431   int64 output_feature_map_count() const { return tensor_.dimensions(0); }
432   int64 input_feature_map_count() const { return tensor_.dimensions(1); }
433   int64 input_filter_height() const {
434     return GetDim(input_filter_dims(), DimIndex::Y);
435   }
436   int64 input_filter_width() const {
437     return GetDim(input_filter_dims(), DimIndex::X);
438   }
439   int64 input_filter_dim(DimIndex dim) const {
440     return GetDim(input_filter_dims(), dim);
441   }
442 
443   FilterLayout layout() const { return tensor_.filter_layout(); }
444 
445   absl::Span<const int64> input_filter_dims() const {
446     return AsInt64Slice(tensor_.dimensions()).subspan(2);
447   }
448 
449  private:
450   absl::Span<int64> input_filter_dims() {
451     return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
452   }
453 
454   TensorDescriptorProto tensor_;
455 };
456 
457 // Describes how padding should be aligned when the total number of pad
458 // elements is odd.
459 enum class PadAlignment : int64 {
460   kDefault = 0,        // default padding for the device.
461   kCudnnPadding,       // cuDNN padding - prefer to pad at the start.
462   kTensorFlowPadding,  // TensorFlow padding - prefer to pad at the end.
463 };
464 
465 // Returns a string representation of the given padding alignment.
466 std::string PadAlignmentString(PadAlignment alignment);
467 
468 // Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
469 std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
470 
471 // Describes a convolution.
472 //
473 // Uses the named argument construction form:
474 //
475 //  ConvolutionDescriptor convolution_dimensions;
476 //  convolution_dimensions
477 //    .set_vertical_filter_stride(2)
478 //    .set_horizontal_filter_stride(2)
479 //    ...
480 //
481 // Arguments:
482 // - zero_padding_height: padding of the "y dimension" of the input data. Note
483 //    that this is different from the height of the filter.
484 // - zero_padding_width: analogous to the height above, but in the "x
485 //    dimension".
486 // - vertical_filter_stride: the convolution slides a 2-dimensional window of
487 //    filter-height-by-filter-width over the input layer -- the center of that
488 //    window is moved in the "y dimension" according to this stride value.
489 // - horizontal_filter_stride: analogous to the vertical stride above, but in
490 //    the "x dimension".
491 // - vertical_dilation_rate: there will be (vertical_dilation_rate - 1) skipped
492 //   cells between each filter element in the "y dimension".
493 // - horizontal_dilation_rate: there will be (horizontal_dilation_rate - 1)
494 //   skipped cells between each filter element in the "x dimension".
495 // - convolution_not_crosscorr: By default (convolution_not_crosscorr == false),
496 //   we perform cross correlation rather than convolution. With the flag set,
497 //   we perform convolution. Convolution and cross correlation are related by
498 //   rotating the filter by 180 degrees (or equivalently flipping all spatial
499 //   dimensions).
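//
// For example, a 2D convolution with one pixel of zero-padding and a stride of
// 2 in both spatial dimensions could be described as follows (a sketch; the
// values are illustrative):
//
//   ConvolutionDescriptor conv;
//   conv.set_zero_padding_height(1)
//       .set_zero_padding_width(1)
//       .set_vertical_filter_stride(2)
//       .set_horizontal_filter_stride(2);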
500 class ConvolutionDescriptor {
501  public:
502   // When default-constructed, there is no zero-padding and the filter stride is
503   // 1x1 (centering the filter on every cell in the input layer's
504   // width-by-height area).
505   ConvolutionDescriptor();
506   explicit ConvolutionDescriptor(int ndims);
507   ~ConvolutionDescriptor();
508 
509   std::string ToString() const;
510   std::string ToShortString() const;
511   ConvolutionDescriptorProto ToProto() const { return proto_; }
512 
513   ConvolutionDescriptor& set_zero_padding_height(int64 value) {
514     SetDim(padding(), DimIndex::Y, value);
515     return *this;
516   }
517   ConvolutionDescriptor& set_zero_padding_width(int64 value) {
518     SetDim(padding(), DimIndex::X, value);
519     return *this;
520   }
521   ConvolutionDescriptor& set_zero_padding(DimIndex dim, int64 value) {
522     SetDim(padding(), dim, value);
523     return *this;
524   }
525   ConvolutionDescriptor& set_vertical_filter_stride(int64 value) {
526     SetDim(strides(), DimIndex::Y, value);
527     return *this;
528   }
529   ConvolutionDescriptor& set_horizontal_filter_stride(int64 value) {
530     SetDim(strides(), DimIndex::X, value);
531     return *this;
532   }
533   ConvolutionDescriptor& set_filter_stride(DimIndex dim, int64 value) {
534     SetDim(strides(), dim, value);
535     return *this;
536   }
537   ConvolutionDescriptor& set_vertical_dilation_rate(int64 value) {
538     SetDim(dilations(), DimIndex::Y, value);
539     return *this;
540   }
541   ConvolutionDescriptor& set_horizontal_dilation_rate(int64 value) {
542     SetDim(dilations(), DimIndex::X, value);
543     return *this;
544   }
545   ConvolutionDescriptor& set_dilation_rate(DimIndex dim, int64 value) {
546     SetDim(dilations(), dim, value);
547     return *this;
548   }
549   ConvolutionDescriptor& set_group_count(int group_count) {
550     proto_.set_group_count(group_count);
551     return *this;
552   }
553   ConvolutionDescriptor& set_convolution_not_crosscorr(bool conv) {
554     proto_.set_convolution_mode(conv ? ConvolutionMode::CONVOLUTION
555                                      : ConvolutionMode::CROSS_CORRELATION);
556     return *this;
557   }
558   ConvolutionDescriptor& set_name(const std::string& name) {
559     proto_.set_name(name);
560     return *this;
561   }
562   int64 zero_padding_height() const { return GetDim(padding(), DimIndex::Y); }
563   int64 zero_padding_width() const { return GetDim(padding(), DimIndex::X); }
564   int64 vertical_filter_stride() const {
565     return GetDim(strides(), DimIndex::Y);
566   }
567   int64 horizontal_filter_stride() const {
568     return GetDim(strides(), DimIndex::X);
569   }
570   int64 vertical_dilation_rate() const {
571     return GetDim(dilations(), DimIndex::Y);
572   }
573   int64 horizontal_dilation_rate() const {
574     return GetDim(dilations(), DimIndex::X);
575   }
576 
577   int zero_padding(DimIndex dim) const { return GetDim(padding(), dim); }
578   int filter_stride(DimIndex dim) const { return GetDim(strides(), dim); }
579   int dilation_rate(DimIndex dim) const { return GetDim(dilations(), dim); }
580   // TODO(timshen): remove this function. No user of this class sets a
581   // non-default pad alignment.
582   PadAlignment pad_alignment() const { return PadAlignment::kDefault; }
583   int group_count() const { return proto_.group_count(); }
584   int ndims() const { return padding().size(); }
585   bool convolution_not_crosscorr() const {
586     return proto_.convolution_mode() == ConvolutionMode::CONVOLUTION;
587   }
588 
589   absl::Span<const int64> strides() const {
590     return AsInt64Slice(proto_.strides());
591   }
592 
593   absl::Span<const int64> dilations() const {
594     return AsInt64Slice(proto_.dilations());
595   }
596 
597   absl::Span<const int64> padding() const {
598     return AsInt64Slice(proto_.paddings());
599   }
600 
601   std::string name() const { return proto_.name(); }
602 
603  private:
604   absl::Span<int64> strides() { return AsInt64Slice(proto_.mutable_strides()); }
605 
606   absl::Span<int64> dilations() {
607     return AsInt64Slice(proto_.mutable_dilations());
608   }
609 
610   absl::Span<int64> padding() {
611     return AsInt64Slice(proto_.mutable_paddings());
612   }
613 
614   ConvolutionDescriptorProto proto_;
615 
616   // TODO(leary) cudnn provides these fields, but need to characterize what
617   // their effect is -- they may be boolean rather than integral.
618   // int64 upscale_input_x;
619   // int64 upscale_input_y;
620 };
621 
622 // A patch of values in the input can be pooled via either a max or an average
623 // operation.
624 // Specify int64 so there's no padding in PoolingDescriptor.
625 enum class PoolingMode : int64 {
626   kMaximum,
627   kAverage,
628 };
629 
630 // Specify the dimension in which to concatenate inputs in space.
631 // Specify int64 so there's no padding in SpaceConcatenateMode.
632 enum class SpaceConcatenateMode : int64 {
633   XDirection,
634   YDirection,
635 };
636 
637 // Returns a short name for the pooling mode, e.g. "Avg".
638 std::string ShortPoolingModeString(PoolingMode mode);
639 
640 // Describes a pooling operation to be enqueued onto a stream via a platform's
641 // DnnSupport.
642 //
643 // TODO(broune): describe how padding works and what happens if the
644 // window height/width is not divisible by the vertical/horizontal
645 // stride.
646 //
647 // Arguments:
648 //  pooling_mode: pooling operator to use on the input patch
649 //  window_height: height of input window
650 //  window_width: width of input window
651 //  vertical_stride: vertical delta for center of the input patch
652 //  horizontal_stride: horizontal delta for center of the input patch
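//
// For example, 2x2 max pooling with stride 2 and no padding could be described
// as follows (a sketch; the values are illustrative):
//
//   PoolingDescriptor pool;
//   pool.set_pooling_mode(PoolingMode::kMaximum)
//       .set_window_height(2)
//       .set_window_width(2)
//       .set_vertical_stride(2)
//       .set_horizontal_stride(2);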
653 class PoolingDescriptor {
654  public:
655   PoolingDescriptor();
656   explicit PoolingDescriptor(int ndims);
657 
658   PoolingDescriptor& set_pooling_mode(PoolingMode value) {
659     mode_ = value;
660     return *this;
661   }
662   PoolingDescriptor& set_window_height(int64 value) {
663     SetDim(&window_, DimIndex::Y, value);
664     return *this;
665   }
666   PoolingDescriptor& set_window_width(int64 value) {
667     SetDim(&window_, DimIndex::X, value);
668     return *this;
669   }
670   PoolingDescriptor& set_window(DimIndex dim, int64 value) {
671     SetDim(&window_, dim, value);
672     return *this;
673   }
674   PoolingDescriptor& set_vertical_padding(int64 value) {
675     SetDim(&padding_, DimIndex::Y, value);
676     return *this;
677   }
678   PoolingDescriptor& set_horizontal_padding(int64 value) {
679     SetDim(&padding_, DimIndex::X, value);
680     return *this;
681   }
682   PoolingDescriptor& set_padding(DimIndex dim, int64 value) {
683     SetDim(&padding_, dim, value);
684     return *this;
685   }
686   PoolingDescriptor& set_vertical_stride(int64 value) {
687     SetDim(&strides_, DimIndex::Y, value);
688     return *this;
689   }
690   PoolingDescriptor& set_horizontal_stride(int64 value) {
691     SetDim(&strides_, DimIndex::X, value);
692     return *this;
693   }
694   PoolingDescriptor& set_stride(DimIndex dim, int64 value) {
695     SetDim(&strides_, dim, value);
696     return *this;
697   }
698   PoolingDescriptor& set_propagate_nans(bool value) {
699     propagate_nans_ = value;
700     return *this;
701   }
702   PoolingDescriptor& set_name(const std::string& name) {
703     name_ = name;
704     return *this;
705   }
706 
707   int ndims() const { return ndims_; }
708   void CloneFrom(const PoolingDescriptor& other);
709 
710   std::string ToString() const;
711   std::string ToShortString() const;
712 
713   PoolingMode mode() const { return mode_; }
714   int64 window_height() const { return GetDim(window_, DimIndex::Y); }
715   int64 window_width() const { return GetDim(window_, DimIndex::X); }
716   int64 window(DimIndex dim) const { return GetDim(window_, dim); }
717   int64 vertical_padding() const { return GetDim(padding_, DimIndex::Y); }
718   int64 horizontal_padding() const { return GetDim(padding_, DimIndex::X); }
719   int64 padding(DimIndex dim) const { return GetDim(padding_, dim); }
720   int64 vertical_stride() const { return GetDim(strides_, DimIndex::Y); }
721   int64 horizontal_stride() const { return GetDim(strides_, DimIndex::X); }
722   int64 stride(DimIndex dim) const { return GetDim(strides_, dim); }
723   absl::Span<const int64> window() const { return window_; }
724   absl::Span<const int64> padding() const { return padding_; }
725   absl::Span<const int64> strides() const { return strides_; }
726   bool propagate_nans() const { return propagate_nans_; }
727   std::string name() const { return name_; }
728 
729  private:
730   PoolingMode mode_;
731   int ndims_;
732   bool propagate_nans_;
733   std::string name_;  // Name as in Tensorflow NodeDef, for debugging purposes.
734 
735   // Stored as: ..., y, x.
736   std::vector<int64> window_;
737   std::vector<int64> padding_;
738   std::vector<int64> strides_;
739 };
740 
741 // Collects parameters for DNN algorithms
742 class AlgorithmDesc {
743  public:
744   typedef int64 Index;
745   AlgorithmDesc() : AlgorithmDesc(0, false) {}
746   AlgorithmDesc(Index a, bool use_tensor_ops) {
747     proto_.set_algo_id(a);
748     proto_.set_math_type(use_tensor_ops ? AlgorithmProto::TENSOR_OP_MATH
749                                         : AlgorithmProto::DEFAULT_MATH);
750   }
751   bool tensor_ops_enabled() const {
752     return proto_.math_type() == AlgorithmProto::TENSOR_OP_MATH;
753   }
754   Index algo_id() const { return proto_.algo_id(); }
755   bool operator==(const AlgorithmDesc& other) const {
756     return algo_id() == other.algo_id() &&
757            tensor_ops_enabled() == other.tensor_ops_enabled();
758   }
759   uint64 hash() const;
760 
761   AlgorithmProto ToProto() const { return proto_; }
762 
763   std::string ToString() const;
764 
765  private:
766   AlgorithmProto proto_;
767 };
768 
769 // Describes the result from a perf experiment.
770 //
771 // Arguments:
772 //  algorithm: returns the exact algorithm that was used.
773 //  elapsed_time_in_ms: returns the measured elapsed time in milliseconds.
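//
// A minimal sketch of consuming a profile result (illustrative only):
//
//   ProfileResult result;
//   // ... pass &result as output_profile_result to a profiled operation ...
//   if (result.is_valid()) {
//     AlgorithmDesc best = result.algorithm();
//     float elapsed_ms = result.elapsed_time_in_ms();
//   }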
774 class ProfileResult {
775  public:
776   bool is_valid() const {
777     return algorithm_.has_value() &&
778            elapsed_time_in_ms() != std::numeric_limits<float>::max();
779   }
780 
781   AlgorithmDesc algorithm() const { return *algorithm_; }
782   void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
783 
784   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
785   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
786 
787   size_t scratch_size() const { return scratch_size_; }
788   void set_scratch_size(size_t val) { scratch_size_ = val; }
789 
790  private:
791   absl::optional<AlgorithmDesc> algorithm_;
792   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
793   // The scratch size algorithm_ requires. Currently it's only populated by
794   // convolutions.
795   size_t scratch_size_ = 0;
796 };
797 
798 // Describes the configuration for the algorithms that will be used.
799 //
800 // Arguments:
801 //  algorithm: the primary algorithm that should be used.
802 //  algorithm_no_scratch: a secondary algorithm that should be used if the
803 //    allocation for the scratch memory fails.
804 //  scratch_size: specifies the size of scratch memory in bytes needed for the
805 //    algorithm used.
806 //
807 // On CUDA platform with CUDNN library, algorithm and algorithm_no_scratch
808 // would be used. On ROCm platform with MIOpen library, algorithm and
809 // scratch_size would be used. The major difference between the two platforms
810 // is whether it's possible to get an algorithm without scratch memory. On
811 // CUDA + CUDNN it's possible, and algorithm_no_scratch can be used to track
812 // such information, whereas on ROCm + MIOpen there is no guarantee of getting
813 // one without scratch memory, so the scratch_size field is used to track it.
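//
// A minimal sketch of constructing a config for each platform (the algorithm
// indices and scratch size are illustrative):
//
//   // CUDA + cuDNN: primary algorithm plus a no-scratch fallback.
//   AlgorithmConfig cuda_config(AlgorithmDesc(/*a=*/1, /*use_tensor_ops=*/true),
//                               AlgorithmDesc(/*a=*/0, /*use_tensor_ops=*/false));
//   // ROCm + MIOpen: algorithm plus the scratch size it requires.
//   AlgorithmConfig rocm_config(AlgorithmDesc(/*a=*/1, /*use_tensor_ops=*/false),
//                               /*scratch_size=*/1 << 20);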
814 class AlgorithmConfig {
815  public:
816   AlgorithmConfig() {}
817   explicit AlgorithmConfig(AlgorithmDesc algorithm) : algorithm_(algorithm) {}
818   AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size)
819       : algorithm_(algorithm), scratch_size_(scratch_size) {}
820   AlgorithmConfig(AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch)
821       : algorithm_(algorithm), algorithm_no_scratch_(algorithm_no_scratch) {}
822   absl::optional<AlgorithmDesc> algorithm() const { return algorithm_; }
823   void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
824   absl::optional<AlgorithmDesc> algorithm_no_scratch() const {
825     return algorithm_no_scratch_;
826   }
827   void set_algorithm_no_scratch(AlgorithmDesc val) {
828     algorithm_no_scratch_ = val;
829   }
830   absl::optional<size_t> scratch_size() const { return scratch_size_; }
831   void set_scratch_size(size_t val) { scratch_size_ = val; }
832   bool operator==(const AlgorithmConfig& other) const {
833     return this->algorithm_ == other.algorithm_ &&
834            this->algorithm_no_scratch_ == other.algorithm_no_scratch_ &&
835            this->scratch_size_ == other.scratch_size_;
836   }
837   bool operator!=(const AlgorithmConfig& other) const {
838     return !(*this == other);
839   }
840   std::string ToString() const;
841 
842  private:
843   absl::optional<AlgorithmDesc> algorithm_;
844   absl::optional<AlgorithmDesc> algorithm_no_scratch_;
845   absl::optional<size_t> scratch_size_;
846 };
847 
848 // Describes a local response normalization (LRN). LRN is used e.g. in
849 // dist_belief.
850 //
851 // Let V be the vector of feature maps at some (batch, y, x)
852 // coordinate. LRN applies independently to each vector V in the
853 // input, across all coordinates (batch, y, x), by mapping each V to
854 // another vector U of the same size using the formula
855 //
856 //   U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
857 //
858 // where the sum is taken over j in the closed range [i - range, i + range].
859 //
860 // When calculating U_i the j in the sum can extend beyond the bounds
861 // of V. If wrap_around is true, then V_j = V_{j mod F} where F is the
862 // size of V, which is the number of feature maps. If wrap_around is
863 // false, then V_j = 0 for j outside [0, F-1].
864 //
865 // If segment_size <= F, where F is the number of feature_maps, then
866 // segment_size has no effect. Otherwise, each consecutive segment of
867 // segment_size entries in V are normalized separately.
868 //
869 // Not all StreamExecutors allow wrap_around == true or segment_size
870 // != 64. Some do not implement normalization at all.
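//
// A minimal sketch of building an LRN descriptor (the parameter values are
// illustrative, not recommended defaults):
//
//   NormalizeDescriptor lrn;
//   lrn.set_bias(1.0f)
//      .set_range(2)
//      .set_alpha(1e-4f)
//      .set_beta(0.75f);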
871 class NormalizeDescriptor {
872  public:
873   NormalizeDescriptor();
874 
875   NormalizeDescriptor& set_bias(float bias) {
876     bias_ = bias;
877     return *this;
878   }
879 
880   NormalizeDescriptor& set_range(int32 range) {
881     range_ = range;
882     return *this;
883   }
884 
885   NormalizeDescriptor& set_alpha(float alpha) {
886     alpha_ = alpha;
887     return *this;
888   }
889 
890   NormalizeDescriptor& set_beta(float beta) {
891     beta_ = beta;
892     return *this;
893   }
894 
895   NormalizeDescriptor& set_wrap_around(bool wrap_around) {
896     wrap_around_ = wrap_around;
897     return *this;
898   }
899 
900   NormalizeDescriptor& set_segment_size(int32 segment_size) {
901     segment_size_ = segment_size;
902     return *this;
903   }
904 
905   void CloneFrom(const NormalizeDescriptor& other);
906 
907   std::string ToString() const;
908   std::string ToShortString() const;
909 
910   float bias() const { return bias_; }
911   int32 range() const { return range_; }
912   float alpha() const { return alpha_; }
913   float beta() const { return beta_; }
914   bool wrap_around() const { return wrap_around_; }
915   int32 segment_size() const { return segment_size_; }
916 
917  private:
918   float bias_;
919   int32 range_;
920   float alpha_;
921   float beta_;
922   bool wrap_around_;
923   int32 segment_size_;
924 };
925 
926 // Returns a string representation of the given activation mode.
927 std::string ActivationModeString(ActivationMode mode);
928 
929 // Describes the operation that DoElementwiseOperation should perform on its
930 // inputs.
931 enum class ElementwiseOperation { kAdd, kMultiply };
932 
933 std::string ElementwiseOperationString(ElementwiseOperation op);
934 
935 // A simple class representing the version of the backing library, to
936 // work around the "too perfect forwarding" issue in gcc6+ compilers.
937 // See PR#16309 and issue #18402 for links discussing the issue.
938 class VersionInfo {
939  public:
940   VersionInfo(int major = 0, int minor = 0, int patch = 0)
941       : major_(major), minor_(minor), patch_(patch) {}
942   int major_version() const { return major_; }
943   int minor_version() const { return minor_; }
944   int patch() const { return patch_; }
945 
946  private:
947   int major_;
948   int minor_;
949   int patch_;
950 };
951 
952 // Suite of operations typically used for implementing Deep/Convolutional Neural
953 // Nets. Note: A false return value of an operation indicates the
954 // implementation is not available.
955 //
956 // TODO(b/118763918): this class (or rather dispatch table) has several
957 // problems:
958 // * Some overloads are missing. Ideally we want to have template virtual
959 //   functions where the template arguments form a closed set. However, we don't
960 //   get that from the language.
961 // * The API is a union of cuDNN and another private backend. Only 10% of the
962 //   functions are actually implemented by both backends; the rest are
963 //   actually backend-specific. The massive interface creates extra mental
964 //   burden.
965 // * Poor error handling: the API should return Status objects.
966 //
967 // PrepareForConvolution is an example for how new APIs should be written.
968 class DnnSupport {
969  public:
970   DnnSupport() {}
971   virtual ~DnnSupport() {}
972 
973   virtual port::Status Init() = 0;
974 
975   // Gets the version of the backing library, as a VersionInfo object.
976   virtual port::StatusOr<VersionInfo> GetVersion() {
977     return port::UnimplementedError(
978         "DnnSupport::GetVersion not implemented on this platform.");
979   }
980 
981   // Performs a single-precision forward batch normalization operation onto
982   // the stream.
983   //
984   // Arguments:
985   //  stream: borrowed pointer to the stream that the batch normalization
986   //    operation should be enqueued onto.
987   //  x: input data.
988   //  scale: scaling parameters.
989   //  offset: offset parameters.
990   //  estimated_mean: population mean estimated during training.
991   //    Used for inference only; empty for training.
992   //  estimated_variance: population variance estimated during training,
993   //    used for inference only; empty for training.
994   //  side_input: optional input that is element-wise added to the output of
995   //    batch normalization.
996   //  x_desc: dimensions of the input data, which is the same as the dimensions
997   //    of the output and side input.
998   //  scale_offset_desc: dimensions of scale and offset.
999   //  epsilon: a small floating point number added to the variance of x.
1000   //  activation_mode: activation applied to the result of batch normalization
1001   //    (or after adding optional side input)
1002   //  y: output data.
1003   //  batch_mean: batch mean, to be used to compute the running mean.
1004   //  batch_variance: batch variance, to be used to compute
1005   //    the running variance.
1006   //  reserve_space_1: saved mean, to be reused in the backward gradient
1007   //    computation.
1008   //  reserve_space_2: saved inv_var (1/sqrt(epsilon + variance), to be reused
1009   //    in the backward gradient computation.
1010   //  is_training: Set to true for training, false for inference.
1011   virtual bool DoBatchNormalizationForward(
1012       Stream* stream, const DeviceMemory<float>& x,
1013       const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1014       const DeviceMemory<float>& estimated_mean,
1015       const DeviceMemory<float>& estimated_variance,
1016       const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
1017       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1018       const double exponential_average_factor,
1019       dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
1020       DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1021       DeviceMemory<float>* reserve_space_1,
1022       DeviceMemory<float>* reserve_space_2, bool is_training,
1023       ScratchAllocator* reserve_space_allocator,
1024       ScratchAllocator* workspace_allocator) {
1025     return false;
1026   }
1027 
1028   // Performs a half-precision forward batch normalization operation onto the
1029   // stream. See DoBatchNormalizationForward above for argument details.
1030   virtual bool DoBatchNormalizationForward(
1031       Stream* stream, const DeviceMemory<Eigen::half>& x,
1032       const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1033       const DeviceMemory<float>& estimated_mean,
1034       const DeviceMemory<float>& estimated_variance,
1035       const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
1036       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1037       const double exponential_average_factor,
1038       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
1039       DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1040       DeviceMemory<float>* reserve_space_1,
1041       DeviceMemory<float>* reserve_space_2, bool is_training,
1042       ScratchAllocator* reserve_space_allocator,
1043       ScratchAllocator* workspace_allocator) {
1044     return false;
1045   }
1046 
1047   // Performs a single-precision backward batch normalization gradient
1048   // computation operation onto the stream.
1049   //
1050   // Arguments:
1051   //  stream: borrowed pointer to the stream that the batch normalization
1052   //    gradient computation operation should be enqueued onto.
1053   //  y_backprop: gradient with regard to output y.
1054   //  x: input data.
1055   //  scale: scaling parameters.
1056   //  inv_var: 1/sqrt(epsilon + variance) of x.
1057   //  x_desc: dimensions of the input data, which is the same as the dimensions
1058   //    of the output.
1059   //  scale_offset_desc: dimensions of scale and offset.
1060   //  epsilon: a small floating point number added to the variance of x.
1061   //  x_backprop: gradient with respect to input x.
1062   //  scale_backprop: gradient with respect to scale.
1063   //  offset_backprop: gradient with respect to offset.
1064   virtual bool DoBatchNormalizationBackward(
1065       Stream* stream, const DeviceMemory<float>& y_backprop,
1066       const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
1067       const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
1068       const dnn::BatchDescriptor& x_desc,
1069       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1070       DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
1071       DeviceMemory<float>* offset_backprop,
1072       DeviceMemory<uint8>* reserve_space_data,
1073       ScratchAllocator* workspace_allocator) {
1074     return false;
1075   }
1076 
1077   // Performs a half-precision backward batch normalization gradient computation
1078   // operation onto the stream. See DoBatchNormalizationBackward above for
1079   // argument details.
1080   virtual bool DoBatchNormalizationBackward(
1081       Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
1082       const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
1083       const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
1084       const dnn::BatchDescriptor& x_desc,
1085       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1086       DeviceMemory<Eigen::half>* x_backprop,
1087       DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
1088       DeviceMemory<uint8>* reserve_space_data,
1089       ScratchAllocator* workspace_allocator) {
1090     return false;
1091   }
1092 
1093   // Enqueues a fused convolution operation onto the stream.
1094   // We provide several variants with different types for inputs, biases and
1095   // scaling parameters.
1096   //
1097   // Arguments (all borrowed):
1098   //  stream: borrowed pointer to the stream that the 'convolve' operation
1099   //    should be enqueued onto.
1100   //  conv_input_descriptor: dimensions of the convolution input layer.
1101   //  conv_input_data: un-owned device memory region which contains the
1102   //    convolution input.
1103   //  conv_input_scale: a floating point scale to multiply with each element
1104   //    of conv_input_data.
1105   //  filter_descriptor: dimensions of the convolution filter.
1106   //  filter_data: un-owned device memory region which contains the
1107   //    convolution filter weights.
1108   //  convolution_descriptor: stride of the convolution filter.
1109   //  biases: un-owned device memory region containing biases to add to the
1110   //    input.
1111   //  activation_mode: Type of activation to perform.
1112   //  side_input_data: un-owned device memory region which contains optional
1113   //    side input data. If 'side_input_scale' is non-zero, then this must
1114   //    point to data in the tensor shape specified by output_shape.
1115   //    It will be scaled by 'side_input_scale' and added to the convolution
1116   //    result and bias prior to applying the activation function.
1117   //  side_input_scale: a floating point scale to multiply with each element
1118   //    of side_input_data.
1119   //  output_descriptor: dimensions of the output layer.
1120   //  output_data: un-owned device memory region in which to place the
1121   //    convolution result.
1122   //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
1123   //    space in order to speed up the convolution operation.
1124   //  algorithm_config: specifies which algorithm should be used for the
1125   //    operation.
1126   //  output_profile_result: the output profile result for this call. The
1127   //    profiling is only enabled when this is not nullptr.
1128   //
1129   // conv_input_descriptor, filter_descriptor, convolution_descriptor and
1130   // output_descriptor together specify exactly how the convolution is aligned
1131   // with the input data:
1132   //
1133   // * (input dimensions - filter size + 1) / filter stride == output dimensions
1134   //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1135   // * input dimensions / filter stride == output dimensions
1136   //   corresponds to dist_belief padding = SAME, i.e. input and output are the
1137   //   same size - this requires padding the input.
1138   // * (input dimensions + filter size - 1) / filter stride == output dimensions
1139   //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
1140   //   that if the inverse of the filter is applied to the output in VALID mode
1141   //   the result is the same size as the input - this requires even more
1142   //   padding of the input.
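  //
  // For example, with a 1-D input of size 5, a filter of size 3 and stride 1
  // (an illustrative application of the formulas above):
  //   VALID: (5 - 3 + 1) / 1 == 3 outputs, with no padding of the input.
  //   SAME:  5 / 1 == 5 outputs, with the input padded to preserve its size.
  //   FULL:  (5 + 3 - 1) / 1 == 7 outputs, with even more padding of the input.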
1143   virtual port::Status DoFusedConvolve(
1144       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
1145       const DeviceMemory<double>& conv_input_data, double conv_input_scale,
1146       const dnn::FilterDescriptor& filter_descriptor,
1147       const DeviceMemory<double>& filter_data,
1148       const dnn::ConvolutionDescriptor& convolution_descriptor,
1149       const DeviceMemory<double>& side_input_data, double side_input_scale,
1150       const dnn::BatchDescriptor& bias_descriptor,
1151       const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
1152       const dnn::BatchDescriptor& output_descriptor,
1153       DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
1154       const dnn::AlgorithmConfig& algorithm_config,
1155       dnn::ProfileResult* output_profile_result) {
1156     return port::UnimplementedError(
1157         "DnnSupport::DoFusedConvolve not implemented on this platform.");
1158   }
1159 
1160   // This is the float version of DoFusedConvolve.
1161   virtual port::Status DoFusedConvolve(
1162       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
1163       const DeviceMemory<float>& conv_input_data, float conv_input_scale,
1164       const dnn::FilterDescriptor& filter_descriptor,
1165       const DeviceMemory<float>& filter_data,
1166       const dnn::ConvolutionDescriptor& convolution_descriptor,
1167       const DeviceMemory<float>& side_input_data, float side_input_scale,
1168       const dnn::BatchDescriptor& bias_descriptor,
1169       const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
1170       const dnn::BatchDescriptor& output_descriptor,
1171       DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
1172       const dnn::AlgorithmConfig& algorithm_config,
1173       dnn::ProfileResult* output_profile_result) {
1174     return port::UnimplementedError(
1175         "DnnSupport::DoFusedConvolve not implemented on this platform.");
1176   }
1177 
1178   // This is the Eigen::half version of DoFusedConvolve.
1179   // The scaling parameters are still floats.
1180   virtual port::Status DoFusedConvolve(
1181       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
1182       const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
1183       const dnn::FilterDescriptor& filter_descriptor,
1184       const DeviceMemory<Eigen::half>& filter_data,
1185       const dnn::ConvolutionDescriptor& convolution_descriptor,
1186       const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
1187       const dnn::BatchDescriptor& bias_descriptor,
1188       const DeviceMemory<Eigen::half>& biases,
1189       dnn::ActivationMode activation_mode,
1190       const dnn::BatchDescriptor& output_descriptor,
1191       DeviceMemory<Eigen::half>* output_data,
1192       ScratchAllocator* scratch_allocator,
1193       const dnn::AlgorithmConfig& algorithm_config,
1194       dnn::ProfileResult* output_profile_result) {
1195     return port::UnimplementedError(
1196         "DnnSupport::DoFusedConvolve not implemented on this platform.");
1197   }
1198 
1199   // This is the int8 version of DoFusedConvolve.
1200   // The bias input and scaling parameters are floats.
1201   virtual port::Status DoFusedConvolve(
1202       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
1203       const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
1204       const dnn::FilterDescriptor& filter_descriptor,
1205       const DeviceMemory<int8>& filter_data,
1206       const dnn::ConvolutionDescriptor& convolution_descriptor,
1207       const DeviceMemory<int8>& side_input_data, float side_input_scale,
1208       const dnn::BatchDescriptor& bias_descriptor,
1209       const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
1210       const dnn::BatchDescriptor& output_descriptor,
1211       DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
1212       const dnn::AlgorithmConfig& algorithm_config,
1213       dnn::ProfileResult* output_profile_result) {
1214     return port::UnimplementedError(
1215         "DnnSupport::DoFusedConvolve not implemented on this platform.");
1216   }
1217 
1218   // This is the int8 version of DoFusedConvolve.
1219   // The output, bias input and scaling parameters are floats.
1220   virtual port::Status DoFusedConvolve(
1221       Stream* /*stream*/, const dnn::BatchDescriptor& /*conv_input_descriptor*/,
1222       const DeviceMemory<int8>& /*conv_input_data*/, float /*conv_input_scale*/,
1223       const dnn::FilterDescriptor& /*filter_descriptor*/,
1224       const DeviceMemory<int8>& /*filter_data*/,
1225       const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
1226       const DeviceMemory<float>& /*side_input_data*/,
1227       float /*side_input_scale*/,
1228       const dnn::BatchDescriptor& /*bias_descriptor*/,
1229       const DeviceMemory<float>& /*biases*/,
1230       dnn::ActivationMode /*activation_mode*/,
1231       const dnn::BatchDescriptor& /*output_descriptor*/,
1232       DeviceMemory<float>* /*output_data*/,
1233       ScratchAllocator* /*scratch_allocator*/,
1234       const dnn::AlgorithmConfig& /*algorithm_config*/,
1235       dnn::ProfileResult* /*output_profile_result*/) {
1236     return port::UnimplementedError(
1237         "DnnSupport::DoFusedConvolve not implemented on this platform.");
1238   }
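  // A minimal usage sketch for the float overload above (illustrative only;
  // 'dnn' stands for a DnnSupport*, and the descriptors, device buffers,
  // allocator and algorithm config are assumed to be set up by the caller):
  //
  //   port::Status status = dnn->DoFusedConvolve(
  //       stream, conv_input_desc, conv_input, /*conv_input_scale=*/1.0f,
  //       filter_desc, filter, conv_desc, side_input,
  //       /*side_input_scale=*/0.0f, bias_desc, biases,
  //       dnn::ActivationMode::kRelu, output_desc, &output, scratch_allocator,
  //       algorithm_config, /*output_profile_result=*/nullptr);
  //
  // The base implementations above simply return an UnimplementedError, so a
  // platform must override them to provide fused convolutions.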
1239 
1240   template <typename ElementType, typename OutputType>
1241   port::Status PrepareForConvolution(
1242       ConvolutionKind kind, Stream* stream,
1243       const BatchDescriptor& batch_descriptor,
1244       DeviceMemory<ElementType> input_data,
1245       const FilterDescriptor& filter_descriptor,
1246       DeviceMemory<ElementType> filter_data,
1247       const BatchDescriptor& output_descriptor,
1248       DeviceMemory<OutputType> output_data,
1249       const ConvolutionDescriptor& convolution_descriptor,
1250       const AlgorithmConfig& algorithm_config,
1251       ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
1252       DeviceMemory<uint8>* scratch_memory) {
1253     return DoPrepareForConvolution(
1254         kind, ToDataType<ElementType>::value, stream, batch_descriptor,
1255         input_data, filter_descriptor, filter_data, output_descriptor,
1256         output_data, convolution_descriptor, algorithm_config,
1257         scratch_allocator, algorithm_desc, scratch_memory);
1258   }
1259 
1260   // Enqueues a single-precision convolution operation onto the stream.
1261   //
1262   // Arguments (all borrowed):
1263   //  stream: borrowed pointer to the stream that the 'convolve' operation
1264   //    should be enqueued onto.
1265   //  input_descriptor: dimensions of the input layer.
1266   //  input_data: un-owned device memory region which contains the
1267   //    convolution input.
1268   //  filter_descriptor: dimensions of the convolution filter.
1269   //  convolution_descriptor: stride of the convolution filter.
1270   //  output_descriptor: dimensions of the output layer.
1271   //  output_data: un-owned device memory region in which to place the
1272   //    convolution result.
1273   //  algorithm_desc: specifies which algorithm should be used for the
1274   //    operation.
1275   //  scratch: un-owned device memory for scratch space in order to speed up
1276   //    the convolution operation.
1277   //  output_profile_result: the output profile result for this call. The
1278   //    profiling is only enabled when this is not nullptr.
1279   //
1280   // input_descriptor, filter_descriptor, convolution_descriptor and
1281   // output_descriptor together specify exactly how the convolution is aligned
1282   // with the input data:
1283   //
1284   // * (input dimensions - filter size + 1) / filter stride == output dimensions
1285   //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1286   // * input dimensions / filter stride == output dimensions
1287   //   corresponds to dist_belief padding = SAME, i.e. input and output are the
1288   //   same size - this requires padding the input.
1289   // * (input dimensions + filter size - 1) / filter stride == output dimensions
1290   //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
1291   //   that if the inverse of the filter is applied to the output in VALID mode
1292   //   the result is the same size as the input - this requires even more
1293   //   padding of the input.
1294   virtual port::Status DoConvolve(
1295       ConvolutionKind kind, DataType element_type, DataType output_type,
1296       Stream* stream, const BatchDescriptor& input_descriptor,
1297       DeviceMemoryBase input_data, const FilterDescriptor& filter_descriptor,
1298       DeviceMemoryBase filter_data, const BatchDescriptor& output_descriptor,
1299       DeviceMemoryBase output_data,
1300       const ConvolutionDescriptor& convolution_descriptor,
1301       AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
1302       ProfileResult* output_profile_result) = 0;
1303 
1304   template <typename ElementType, typename OutputType>
1305   bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1306                   const DeviceMemory<ElementType>& input_data,
1307                   const dnn::FilterDescriptor& filter_descriptor,
1308                   const DeviceMemory<ElementType>& filter_data,
1309                   const dnn::ConvolutionDescriptor& convolution_descriptor,
1310                   const dnn::BatchDescriptor& output_descriptor,
1311                   DeviceMemory<OutputType>* output_data,
1312                   const dnn::AlgorithmDesc& algorithm_desc,
1313                   DeviceMemory<uint8>* scratch_memory,
1314                   ProfileResult* output_profile_result) {
1315     return IsStatusOk(
1316         DoConvolve(ConvolutionKind::FORWARD, ToDataType<ElementType>::value,
1317                    ToDataType<OutputType>::value, stream, input_descriptor,
1318                    input_data, filter_descriptor, filter_data,
1319                    output_descriptor, *output_data, convolution_descriptor,
1320                    algorithm_desc, *scratch_memory, output_profile_result),
1321         !output_profile_result);
1322   }
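  // A rough sketch of the prepare-then-run flow enabled by the two helpers
  // above (illustrative only; every name other than the member functions
  // themselves is a caller-side assumption):
  //
  //   AlgorithmDesc algorithm;
  //   DeviceMemory<uint8> scratch;
  //   port::Status prepared = dnn->PrepareForConvolution(
  //       ConvolutionKind::FORWARD, stream, input_desc, input, filter_desc,
  //       filter, output_desc, output, conv_desc, algorithm_config,
  //       scratch_allocator, &algorithm, &scratch);
  //   if (prepared.ok()) {
  //     dnn->DoConvolve(stream, input_desc, input, filter_desc, filter,
  //                     conv_desc, output_desc, &output, algorithm, &scratch,
  //                     /*output_profile_result=*/nullptr);
  //   }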
1323 
1324   // Return a list of algorithms supported by the forward convolution pass.
1325   // cc_major and cc_minor are the compute capabilities of the device.
1326   virtual bool GetConvolveAlgorithms(
1327       bool with_winograd_nonfused, int cc_major, int cc_minor,
1328       std::vector<AlgorithmDesc>* out_algorithms);
1329 
1330   virtual bool GetMIOpenConvolveAlgorithms(
1331       dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
1332       const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
1333       const dnn::FilterDescriptor& filter_descriptor,
1334       DeviceMemoryBase filter_data,
1335       const dnn::BatchDescriptor& output_descriptor,
1336       DeviceMemoryBase output_data,
1337       const dnn::ConvolutionDescriptor& convolution_descriptor,
1338       ScratchAllocator* scratch_allocator,
1339       std::vector<ProfileResult>* out_algorithms);
1340 
1341   // Returns a list of supported RNN algorithms.
1342   virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);
1343 
1344   // Version of DoConvolve that uses pre-quantized 8 bit coefficients.
1345   // coefficient_scales specifies the scaling of each column of coefficients:
1346   // original float coefficient[row * num_columns + column] =
1347   //     quantized coefficient[row * num_columns + column] *
1348   //     coefficient_scales[column].
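  //
  // For example (illustrative numbers only): if coefficient_scales[2] == 0.5f,
  // a stored int8 value of 6 in column 2 represents the float coefficient
  // 3.0f.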
1349   virtual bool DoConvolveQuantized(
1350       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1351       const DeviceMemory<float>& input_data,
1352       const dnn::FilterDescriptor& filter_descriptor,
1353       const DeviceMemory<int8>& filter_coefficients,
1354       const DeviceMemory<float>& coefficient_scales,
1355       const dnn::ConvolutionDescriptor& convolution_descriptor,
1356       const dnn::BatchDescriptor& output_descriptor,
1357       DeviceMemory<float>* output_data) = 0;
1358 
1359   // Same as DoConvolveQuantized above, but with int16 filter coefficients.
1360   virtual bool DoConvolveQuantized(
1361       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1362       const DeviceMemory<float>& input_data,
1363       const dnn::FilterDescriptor& filter_descriptor,
1364       const DeviceMemory<int16>& filter_coefficients,
1365       const DeviceMemory<float>& coefficient_scales,
1366       const dnn::ConvolutionDescriptor& convolution_descriptor,
1367       const dnn::BatchDescriptor& output_descriptor,
1368       DeviceMemory<float>* output_data) = 0;
1369 
1370   // Variation of the above with the weight matrix split into two matrices.
1371   // first_weights: Coefficients of the first matrix.
1372   // second_weights: Coefficients of the second matrix.
1373   // depth_multiplier: specifies the columns of the first matrix and rows
1374   // of the second one - first_weights columns = depth_multiplier,
1375   // second_weights rows = depth_multiplier *
1376   //                       filter_descriptor.input_feature_map_count().
1377   // see go/separable for documentation on separable convolutions.
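  //
  // For example (illustrative only): with depth_multiplier == 2 and
  // filter_descriptor.input_feature_map_count() == 3, first_weights has 2
  // columns and second_weights has 2 * 3 == 6 rows.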
1378   virtual bool DoSeparableConvolve(
1379       Stream* stream, const BatchDescriptor& input_descriptor,
1380       const DeviceMemory<float>& input_data,
1381       const FilterDescriptor& filter_descriptor, int depth_multiplier,
1382       const DeviceMemory<float>& first_weights,
1383       const DeviceMemory<float>& second_weights,
1384       const ConvolutionDescriptor& convolution_descriptor,
1385       const BatchDescriptor& output_descriptor,
1386       DeviceMemory<float>* output_data) = 0;
1387 
1388   // Enqueues a single-precision backward convolution (for data) operation onto
1389   // the stream.
1390   //
1391   // Arguments:
1392   //  stream: borrowed pointer to the stream that the 'convolve' operation
1393   //    should be enqueued onto.
1394   //  filter_descriptor: dimensions of the convolution filter.
1395   //  filter_data: coefficients for the convolution filter.
1396   //  output_descriptor: dimensions of the output gradients, which are the same
1397   //    as the dimensions of the output.
1398   //  backward_output_data: un-owned device memory region which contains the
1399   //    backprop of the output.
1400   //  convolution_descriptor: stride of the convolution filter.
1401   //  input_descriptor: dimensions of the input layer.
1402   //  backward_input_data: un-owned device memory region in which to place the
1403   //    backprop of the input.
1404   //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
1405   //    space in order to speed up the convolution operation.
1406   template <typename ElementType>
1407   bool DoConvolveBackwardData(
1408       Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
1409       const DeviceMemory<ElementType>& filter_data,
1410       const dnn::BatchDescriptor& output_descriptor,
1411       const DeviceMemory<ElementType>& backward_output_data,
1412       const dnn::ConvolutionDescriptor& convolution_descriptor,
1413       const dnn::BatchDescriptor& input_descriptor,
1414       DeviceMemory<ElementType>* backward_input_data,
1415       const dnn::AlgorithmDesc& algorithm_desc,
1416       DeviceMemory<uint8>* scratch_memory,
1417       ProfileResult* output_profile_result) {
1418     return IsStatusOk(
1419         DoConvolve(
1420             ConvolutionKind::BACKWARD_DATA, ToDataType<ElementType>::value,
1421             ToDataType<ElementType>::value, stream, input_descriptor,
1422             *backward_input_data, filter_descriptor, filter_data,
1423             output_descriptor, backward_output_data, convolution_descriptor,
1424             algorithm_desc, *scratch_memory, output_profile_result),
1425         !output_profile_result);
1426   }
1427 
1428   // Return a list of algorithms supported by the backward convolution pass for
1429   // data.
1430   virtual bool GetConvolveBackwardDataAlgorithms(
1431       bool with_winograd_nonfused, int cc_major, int cc_minor,
1432       std::vector<AlgorithmDesc>* out_algorithms);
1433 
1434   // Enqueues a single-precision backward convolution (for filter) operation
1435   // onto the stream.
1436   //
1437   // Arguments:
1438   //  stream: borrowed pointer to the stream that the 'convolve' operation
1439   //    should be enqueued onto.
1440   //  input_descriptor: dimensions of the input layer.
1441   //  input_data: un-owned device memory region which contains the
1442   //    convolution input.
1443   //  output_descriptor: dimensions of the output gradients, which are the same
1444   //    as the dimensions of the output.
1445   //  backward_output_data: un-owned device memory region which contains the
1446   //    backprop of the output.
1447   //  convolution_descriptor: stride of the convolution filter.
1448   //  filter_descriptor: dimensions of the convolution filter.
1449   //  backward_filter_data: un-owned device memory region in which to place the
1450   //    backprop of the filter.
1451   //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
1452   //    space in order to speed up the convolution operation.
1453   template <typename ElementType>
1454   bool DoConvolveBackwardFilter(
1455       Stream* stream, const BatchDescriptor& input_descriptor,
1456       const DeviceMemory<ElementType>& input_data,
1457       const BatchDescriptor& output_descriptor,
1458       const DeviceMemory<ElementType>& backward_output_data,
1459       const ConvolutionDescriptor& convolution_descriptor,
1460       const FilterDescriptor& filter_descriptor,
1461       DeviceMemory<ElementType>* backward_filter_data,
1462       const dnn::AlgorithmDesc& algorithm_desc,
1463       DeviceMemory<uint8>* scratch_memory,
1464       ProfileResult* output_profile_result) {
1465     return IsStatusOk(
1466         DoConvolve(
1467             ConvolutionKind::BACKWARD_FILTER, ToDataType<ElementType>::value,
1468             ToDataType<ElementType>::value, stream, input_descriptor,
1469             input_data, filter_descriptor, *backward_filter_data,
1470             output_descriptor, backward_output_data, convolution_descriptor,
1471             algorithm_desc, *scratch_memory, output_profile_result),
1472         !output_profile_result);
1473   }
1474 
1475   // Return a list of algorithms supported by the backward convolution pass for
1476   // filters.
1477   virtual bool GetConvolveBackwardFilterAlgorithms(
1478       bool with_winograd_nonfused, int cc_major, int cc_minor,
1479       std::vector<AlgorithmDesc>* out_algorithms);
1480 
1481   // Enqueues a single-precision backward convolution (for bias) operation onto
1482   // the stream.
1483   //
1484   // Arguments:
1485   //  stream: borrowed pointer to the stream that the 'convolve' operation
1486   //    should be enqueued onto.
1487   //  input_descriptor: dimensions of the input layer.
1488   //  input_data: un-owned device memory region which contains the
1489   //    convolution input.
1490   //  bias_descriptor: dimensions of the bias tensor. Should be the same as the
1491   //    input dimensions, but with the spatial dimensions set to 1.
1492   //  backward_bias_data: un-owned device memory region in which to place the
1493   //    backprop of the bias.
1494   virtual bool DoConvolveBackwardBias(Stream* stream,
1495                                       const BatchDescriptor& input_descriptor,
1496                                       const DeviceMemory<float>& input_data,
1497                                       const BatchDescriptor& bias_descriptor,
1498                                       DeviceMemory<float>* backward_bias_data) {
1499     return false;
1500   }
1501 
1502   virtual bool DoConvolveBackwardBias(
1503       Stream* stream, const BatchDescriptor& input_descriptor,
1504       const DeviceMemory<double>& input_data,
1505       const BatchDescriptor& bias_descriptor,
1506       DeviceMemory<double>* backward_bias_data) {
1507     return false;
1508   }
1509 
1510   virtual bool DoConvolveBackwardBias(
1511       Stream* stream, const BatchDescriptor& input_descriptor,
1512       const DeviceMemory<Eigen::half>& input_data,
1513       const BatchDescriptor& bias_descriptor,
1514       DeviceMemory<Eigen::half>* backward_bias_data) {
1515     return false;
1516   }
1517 
1518   // Fully connects the "nodes" (float values) in input_data with
1519   // shape input_dimensions to output_data with output_dimensions
1520   // using provided weights. This is equivalent to computing a matrix
1521   // product, hence the name MatMul.
1522   //
1523   // A BatchDescriptor has four dimensions: batch, y, x, depth. Matrix products
1524   // happen in two dimensions. To get down to two dimensions, we consider the
1525   // input y, x and depth dimension as one combined dimension T. For now,
1526   // assume that the output height and width are 1 and let OD be the output
1527   // depth.
1528   //
1529   // There are three device memory buffers passed in to this
1530   // function. We can now view all three as matrices:
1531   //
1532   //   input_data: A batch x T matrix
1533   //   weights: A T x OD matrix
1534   //   output_data: A batch x OD matrix
1535   //
1536   // This function then computes the matrix product of input_data and
1537   // weights and writes the result into output_data.
1538   //
1539   // Here the weights buffer is in row major order, i.e. the first OD
1540   // entries in weights are the first row, the second OD entries in
1541   // weights are the second row and so on.
1542   //
1543   // The case for output width*height > 1 is more complicated. Let K =
1544   // OY * OX where OY is the output height and OX is the output
1545   // width. Then weights is divided into K sub-arrays W_i, for
1546   // i=0,...,K-1, that each represent a T x OD matrix. This function
1547   // then computes the K matrix multiplications of input_data with
1548   // each W_i. This creates K matrices with dimensions batch x
1549   // OD. These K matrices are concatenated horizontally to form one
1550   // larger matrix with dimensions batch x (K*OD); note that this is
1551   // not the same as concatenating the bytes of the matrices. The
1552   // combined matrix can then be interpreted as a tensor with
1553   // dimensions (batch, OY, OX, OD). If the output tensor format is
1554   // not kBatchYXDepth, this function would then need to arrange for
1555   // the output to be in the requested layout, if that is
1556   // supported. Note that the case K=1 is equivalent to the
1557   // description above. It is recommended to prefer the case K=1.
1558   //
1559   // Arguments (all borrowed):
1560   //  stream: borrowed pointer to the stream that the 'fully connect' operation
1561   //    should be enqueued onto.
1562   //  output_data: un-owned device memory region in which to place the
1563   //    fully connected result.
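  //
  // A concrete shape example (illustrative only): with batch == 32, an input
  // of height == width == 1 and depth == 128 (so T == 128), and output depth
  // OD == 10, input_data is viewed as a 32 x 128 matrix, weights as a
  // 128 x 10 matrix, and output_data receives the 32 x 10 product.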
1564   virtual bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
1565                         const DeviceMemory<float>& weights,
1566                         const dnn::BatchDescriptor& input_dimensions,
1567                         const dnn::BatchDescriptor& output_dimensions,
1568                         DeviceMemory<float>* output_data) = 0;
1569 
1570   // Version of DoMatMul that uses pre-quantized 8 bit weights.
1571   // weight_scales specifies the scaling of each column of weights:
1572   // original float weight[row * num_columns + column] =
1573   //     quantized_weight[row * num_columns + column] * weight_scales[column].
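  //
  // For example (illustrative numbers only): if weight_scales[0] == 0.25f, a
  // stored int8 value of 8 in column 0 represents the float weight 2.0f.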
1574   virtual bool DoMatMulQuantized(Stream* stream,
1575                                  const DeviceMemory<float>& input_data,
1576                                  const DeviceMemory<int8>& quantized_weights,
1577                                  const DeviceMemory<float>& weight_scales,
1578                                  const dnn::BatchDescriptor& input_dimensions,
1579                                  const dnn::BatchDescriptor& output_dimensions,
1580                                  DeviceMemory<float>* output_data) = 0;
1581 
1582   // Version of DoMatMul that uses pre-quantized 16 bit weights.
1583   // weight_scales specifies the scaling of each column of weights:
1584   // original float weight[row * num_columns + column] =
1585   //     quantized_weight[row * num_columns + column] * weight_scales[column].
1586   virtual bool DoMatMulQuantized(Stream* stream,
1587                                  const DeviceMemory<float>& input_data,
1588                                  const DeviceMemory<int16>& quantized_weights,
1589                                  const DeviceMemory<float>& weight_scales,
1590                                  const dnn::BatchDescriptor& input_dimensions,
1591                                  const dnn::BatchDescriptor& output_dimensions,
1592                                  DeviceMemory<float>* output_data) = 0;
1593 
1594   // Adds biases to the feature maps in input_data producing
1595   // output_data. input_data can equal output_data, but must not
1596   // partially overlap it.
1597   //
1598   // Let K = count() * height() * width() and N = feature_map_count()
1599   // of 'dimensions'. Then input_data contains K*N values and biases
1600   // contains N values. We can thus logically consider input_data to
1601   // contain K vectors of N elements each. This function adds biases
1602   // to each of those K vectors.
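  //
  // For example (illustrative only): with count() == 4,
  // height() == width() == 1 and feature_map_count() == 16, K == 4 and
  // N == 16, so the 16 bias values are added to each of the 4 rows of
  // input_data.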
1603   //
1604   // TODO(broune): This works differently when width() * height() > 1
1605   // and the call to ThenBiasAdd() follows a call to ThenMatMul(). In
1606   // that case there should be width() * height() *
1607   // feature_map_count() biases, but this is not implemented on all
1608   // StreamExecutors.
1609   //
1610   // Arguments (all borrowed):
1611   //  stream: borrowed pointer to the stream that the 'bias add' operation
1612   //    should be enqueued onto.
1613   //  input_data: un-owned device memory region containing the input.
1614   //  biases: un-owned device memory region containing biases to add to the
1615   //    input.
1616   //  dimensions: dimensions of input_data and output_data.
1617   //  output_data: un-owned device memory region in which to place the result.
1618   virtual bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
1619                          const DeviceMemory<float>& biases,
1620                          const dnn::BatchDescriptor& dimensions,
1621                          DeviceMemory<float>* output_data) = 0;
1622 
1623   // Performs a forward pooling operation on input_data, writing to
1624   // output_data. See PoolingDescriptor for how to configure the
1625   // pooling operation.
1626   //
1627   // Pooling happens as a window that moves across the Y and X
1628   // dimensions of input_data, where each position of the window
1629   // yields one output value. E.g. for max pooling, the computed value
1630   // is the maximum element in the window. The operation is applied
1631   // independently to each batch and at each feature map (depth), so
1632   // that the output depth and feature_map_count are the same as for
1633   // the input. The output width and height can be different.
1634   //
1635   // See PoolingDescriptor for how to configure the pooling operation.
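  //
  // For example (illustrative only): a 2x2 max-pooling window with stride 2
  // and no padding over an 8x8 feature map yields a 4x4 output; the batch and
  // feature_map_count dimensions are unchanged.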
1636   virtual bool DoPoolForward(Stream* stream,
1637                              const dnn::PoolingDescriptor& pooling_dimensions,
1638                              const dnn::BatchDescriptor& input_dimensions,
1639                              const DeviceMemory<float>& input_data,
1640                              const dnn::BatchDescriptor& output_dimensions,
1641                              DeviceMemory<float>* output_data,
1642                              ScratchAllocator* workspace_allocator) = 0;
1643 
1644   virtual bool DoPoolForward(Stream* stream,
1645                              const dnn::PoolingDescriptor& pooling_dimensions,
1646                              const dnn::BatchDescriptor& input_dimensions,
1647                              const DeviceMemory<double>& input_data,
1648                              const dnn::BatchDescriptor& output_dimensions,
1649                              DeviceMemory<double>* output_data,
1650                              ScratchAllocator* workspace_allocator) {
1651     LOG(FATAL) << "DoPoolForward not implemented for double.";
1652     return false;
1653   }
1654 
1655   virtual bool DoPoolForward(Stream* stream,
1656                              const dnn::PoolingDescriptor& pooling_dimensions,
1657                              const dnn::BatchDescriptor& input_dimensions,
1658                              const DeviceMemory<Eigen::half>& input_data,
1659                              const dnn::BatchDescriptor& output_dimensions,
1660                              DeviceMemory<Eigen::half>* output_data,
1661                              ScratchAllocator* workspace_allocator) {
1662     LOG(FATAL) << "DoPoolForward not implemented for float16.";
1663     return false;
1664   }
1665 
1666   virtual bool DoPoolForward(Stream* stream,
1667                              const dnn::PoolingDescriptor& pooling_dimensions,
1668                              const dnn::BatchDescriptor& input_dimensions,
1669                              const DeviceMemory<int8>& input_data,
1670                              const dnn::BatchDescriptor& output_dimensions,
1671                              DeviceMemory<int8>* output_data,
1672                              ScratchAllocator* workspace_allocator) {
1673     LOG(FATAL) << "DoPoolForward not implemented for int8.";
1674     return false;
1675   }
1676 
1677   // Performs differentiation of the pooling operation.
1678   virtual bool DoPoolBackward(Stream* stream,
1679                               const dnn::PoolingDescriptor& pooling_dimensions,
1680                               const dnn::BatchDescriptor& input_dimensions,
1681                               const DeviceMemory<double>& input_data,
1682                               const dnn::BatchDescriptor& output_dimensions,
1683                               const DeviceMemory<double>& output_data,
1684                               const DeviceMemory<double>& input_diff_data,
1685                               DeviceMemory<double>* output_diff_data,
1686                               ScratchAllocator* workspace_allocator) {
1687     LOG(FATAL) << "DoPoolBackward not implemented.";
1688     return false;
1689   }
1690 
1691   virtual bool DoPoolBackward(Stream* stream,
1692                               const dnn::PoolingDescriptor& pooling_dimensions,
1693                               const dnn::BatchDescriptor& input_dimensions,
1694                               const DeviceMemory<float>& input_data,
1695                               const dnn::BatchDescriptor& output_dimensions,
1696                               const DeviceMemory<float>& output_data,
1697                               const DeviceMemory<float>& input_diff_data,
1698                               DeviceMemory<float>* output_diff_data,
1699                               ScratchAllocator* workspace_allocator) {
1700     LOG(FATAL) << "DoPoolBackward not implemented.";
1701     return false;
1702   }
1703 
1704   virtual bool DoPoolBackward(Stream* stream,
1705                               const dnn::PoolingDescriptor& pooling_dimensions,
1706                               const dnn::BatchDescriptor& input_dimensions,
1707                               const DeviceMemory<Eigen::half>& input_data,
1708                               const dnn::BatchDescriptor& output_dimensions,
1709                               const DeviceMemory<Eigen::half>& output_data,
1710                               const DeviceMemory<Eigen::half>& input_diff_data,
1711                               DeviceMemory<Eigen::half>* output_diff_data,
1712                               ScratchAllocator* workspace_allocator) {
1713     LOG(FATAL) << "DoPoolBackward not implemented.";
1714     return false;
1715   }
1716 
1717   // Applies local response normalization to the values from input_data and
1718   // writes the result to output_data.
1719   //
1720   // See comments on NormalizeDescriptor for a description of local response
1721   // normalization.
1722   virtual bool DoNormalizeWithDimensions(
1723       Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1724       const dnn::BatchDescriptor& dimensions,
1725       const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
1726     return false;
1727   }
1728 
1729   // Performs backpropagation for the normalization operation
1730   //
1731   // Given raw data, its corresponding normalized output, and a gradient of some
1732   // unspecified function with respect to the normalized variables, computes the
1733   // gradient of that unspecified function with respect to the raw variables.
1734   //
1735   // The normalized data input array is expected to match the output that would
1736   // be obtained by running the raw data input array through the DoNormalize
1737   // method above.
1738   //
1739   // See comments on NormalizeDescriptor for a description of local response
1740   // normalization.
1741   virtual bool DoNormalizeBackwardWithDimensions(
1742       Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1743       const dnn::BatchDescriptor& dimensions,
1744       const DeviceMemory<float>& raw_data,
1745       const DeviceMemory<float>& normalized_data,
1746       const DeviceMemory<float>& normalized_variable_gradient,
1747       DeviceMemory<float>* raw_variable_gradient,
1748       ScratchAllocator* workspace_allocator) {
1749     return false;
1750   }
1751 
1752   // Applies an activation function (see ActivationMode) to all of the values
1753   // held on the device in 'input_data', whose dimensions are described by
1754   // 'dimensions'.
1755   //
1756   // Arguments (all borrowed):
1757   //  stream: borrowed pointer to the stream that the 'activate' operation
1758   //    should be enqueued onto.
1759   //  activation_mode: Type of activation to perform.
1760   //  input_data: un-owned device memory region which contains the
1761   //    activate input.
1762   //  output_data: un-owned device memory region in which to place the
1763   //    activate result.
1764   virtual bool DoActivate(Stream* stream, ActivationMode activation_mode,
1765                           const BatchDescriptor& dimensions,
1766                           const DeviceMemory<float>& input_data,
1767                           DeviceMemory<float>* output_data, uint64 options) {
1768     return false;
1769   }
1770 
1771   // Concatenates several layers into one, by concatenating the depth of each
1772   // layer at matching x and y coordinates.
1773   // The inputs must all have the same width and height, the output will have
1774   // the same width and height as the inputs and its depth will be the sum of
1775   // the input depths.
1776   //
1777   // Arguments (all borrowed):
1778   //  stream: borrowed pointer to the stream that the 'depth concatenate'
1779   //    operation should be enqueued onto.
1780   //  input_dimensions: The dimensions of each input.
1781   //  input_data: un-owned device memory region which contains the
1782   //    input data for each input layer.
1783   //  output_data: un-owned device memory region in which to place the
1784   //    depth concatenate result.
1785   virtual bool DoDepthConcatenate(
1786       Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1787       port::ArraySlice<const DeviceMemory<float>*> input_data,
1788       DeviceMemory<float>* output_data) = 0;
1789 
1790   // Concatenates several layers into one, by concatenating each in the
1791   // x-dimension or y-dimension, based on a user-specified flag.
1792   // For x-concatenation, layers are aligned at matching y and depth
1793   // coordinates, and for y-concatenation, they are aligned at matching x and
1794   // depth coordinates. The inputs must all have the same depth and batch size.
1795   // For x-concatenation, the inputs must have the same height (y-size), and the
1796   // output will have the same depth and height as the inputs and its width (x-
1797   // size) will be the sum of the input widths.  For y-concatenation, the inputs
1798   // must have the same width, and the output will have the same depth and width
1799   // as the inputs, and its height will be the sum of the input heights.
1800   //
1801   // Arguments:
1802   //  stream: borrowed pointer to the stream that the 'space concatenate'
1803   //    operation should be enqueued onto.
1804   //  input_dimensions: the dimensions of each input.
1805   //  input_data: un-owned device memory region which contains the input data
1806   //    for each input layer.
1807   //  output_data: un-owned device memory region in which to place the space
1808   //    concatenate result.
1809   //  concat_direction: either dnn::SpaceConcatenateMode::XDirection or
1810   //    dnn::SpaceConcatenateMode::YDirection.
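  //
  // A shape example (illustrative only): x-concatenating two inputs of shape
  // (batch=1, y=4, x=3, depth=8) and (batch=1, y=4, x=5, depth=8) produces an
  // output of shape (batch=1, y=4, x=8, depth=8).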
1811   virtual bool DoSpaceConcatenate(
1812       Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1813       port::ArraySlice<const DeviceMemory<float>*> input_data,
1814       DeviceMemory<float>* output_data,
1815       dnn::SpaceConcatenateMode concat_direction) {
1816     return false;
1817   }
1818 
1819   // Change the layout of the data by shrinking one dimension (or set of
1820   // dimensions) and growing another dimension (or set of dimensions), while
1821   // keeping the total number of data elements constant, and maintaining the
1822   // current data ordering.
1823   //
1824   // Currently, the only supported operation is depth into space by a power of
1825   // 2. E.g. (y, x, z) -> (y*2, x*2, z/4)
1826   //
1827   // Note that Reshape may not be a no-op, depending on the platform and which
1828   // dimensions are being changed.
1829   //
1830   // Example: forgetting about batch for the moment, let's take a tensor that's
1831   // 2x1x8 (y by x by z) and reshape to a tensor that's 4x2x2. The memory layout
1832   // is row-major order: y,x,z. I.e. z changes the fastest, then x, then y. The
1833   // elements of the tensor range from 0 to 15. The x,y,z indices are below each
1834   // element.
1835   //
1836   //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
1837   // y0 y0 y0 y0 y0 y0 y0 y0 y1 y1 y1 y1 y1 y1 y1 y1
1838   // x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0
1839   // z0 z1 z2 z3 z4 z5 z6 z7 z0 z1 z2 z3 z4 z5 z6 z7
1840   //
1841   // reshape to 4x2x2
1842   //
1843   //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
1844   // y0 y0 y0 y0 y1 y1 y1 y1 y2 y2 y2 y2 y3 y3 y3 y3
1845   // x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1
1846   // z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1
1847   virtual bool DoReshape(Stream* stream,
1848                          const dnn::BatchDescriptor& input_dimensions,
1849                          const DeviceMemory<float>& input_data,
1850                          const dnn::BatchDescriptor& output_dimensions,
1851                          DeviceMemory<float>* output_data) {
1852     return false;
1853   }
1854 
1855   // Depth to space takes an X by Y image with depth D*M^2 and changes it to an
1856   // MX x MY image with depth D. Each input location (x,y) with depth D*M^2 in
1857   // the input image is changed to an MxM contiguous area in the output image,
1858   // with the values being laid out in the raster order by DepthToSpaceLayout,
1859   // and will have a new depth of D.
1860   //
1861   // Example.
1862   // M=2, Din=8, Xin=2, Yin=2, Xout=4, Yout=4, Dout=2
1863   // DepthHeightWidth layout
1864   // Values within a 'cell' are at different depths and same x & y.
1865   // Input:
1866   // abcdefgh  ijklmnop
1867   // qrstuvwx  yz012345
1868   // Output:
1869   // ae bf im jn
1870   // cg dh ko lp
1871   // qu rv y2 z3
1872   // sw tx 04 15
1873   //
1874   // sqrt_depth_reduction: 'M' in the comment above
1875   virtual bool DoDepthToSpace(Stream* stream,
1876                               const dnn::BatchDescriptor& input_dimensions,
1877                               const DeviceMemory<float>& input_data,
1878                               const DepthToSpaceLayout& depth_to_space_layout,
1879                               const int& sqrt_depth_reduction,
1880                               DeviceMemory<float>* output_data) {
1881     return false;
1882   }
1883 
1884   // Space to depth is the inverse of depth to space. Space to depth takes each
1885   // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
1886   // the input, and transforms it to a 1 by 1 patch with depth D*M^2. If the
1887   // input has size (MX, MY, D), the output has size (X, Y, D*M^2). The number
1888   // of data elements is not changed.
1889   //
1890   // Example.
1891   // M=2, Din=2, Xin=4, Yin=4, Dout=8
1892   // DepthHeightWidth layout
1893   // Values within a 'cell' are at different depths and same x & y.
1894   // Input:
1895   // ae bf im jn
1896   // cg dh ko lp
1897   // qu rv y2 z3
1898   // sw tx 04 15
1899   // Output:
1900   // abcdefgh  ijklmnop
1901   // qrstuvwx  yz012345
1902   //
1903   // sqrt_depth_increase: 'M' in the comment above
1904   virtual bool DoSpaceToDepth(Stream* stream,
1905                               const dnn::BatchDescriptor& input_dimensions,
1906                               const DeviceMemory<float>& input_data,
1907                               const DepthToSpaceLayout& space_to_depth_layout,
1908                               const int& sqrt_depth_increase,
1909                               DeviceMemory<float>* output_data) {
1910     return false;
1911   }
1912 
1913   // Computes the specified operation (e.g. addition or multiplication)
1914   // between corresponding elements in the inputs and stores the result in the
1915   // output element.
1916   // The inputs and output must all have the same dimensions, but may have
1917   // different quantization parameters (min_value and max_value).
1918   //
1919   // Arguments (all borrowed):
1920   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1921   //    should be enqueued onto.
1922   //  operation: The operation to perform.
1923   //  input_dimensions: The dimensions of each input.
1924   //  input_data: un-owned device memory region which contains the
1925   //    input data for each input layer.
1926   //  output_dimensions: The dimensions of the output.
1927   //  output_data: un-owned device memory region in which to place the
1928   //    operation result.
1929   virtual bool DoElementwiseOperate(
1930       Stream* stream, ElementwiseOperation operation,
1931       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1932       port::ArraySlice<const DeviceMemory<float>*> input_data,
1933       const dnn::BatchDescriptor& output_dimensions,
1934       DeviceMemory<float>* output_data) = 0;
1935 
1936   // Computes the specified operation (e.g. addition or multiplication)
1937   // between corresponding elements in the inputs and stores the result in the
1938   // output element. Each input is multiplied by a scalar constant and the
1939   // result is divided by a scalar constant.
1940   // e.g. To perform Z = 0.9*X + 1.1*Y, set the input multiplicands to 9 and 11
1941   // and the output divisor to 10.
1942   // The inputs and output must all have the same dimensions, but may have
1943   // different quantization parameters (min_value and max_value).
1944   //
1945   // Arguments (all borrowed):
1946   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1947   //    should be enqueued onto.
1948   //  operation: The operation to perform.
1949   //  input_multiplicands: Amount to scale each input.
1950   //  output_divisor: Amount to divide the output.
1951   //  input_dimensions: The dimensions of each input.
1952   //  input_data: un-owned device memory region which contains the
1953   //    input data for each input layer.
1954   //  output_dimensions: The dimensions of the output.
1955   //  output_data: un-owned device memory region in which to place the
1956   //    operation result.
1957   virtual bool DoElementwiseOperateScaledQuantized(
1958       Stream* stream, ElementwiseOperation operation,
1959       port::ArraySlice<int> input_multiplicands, int output_divisor,
1960       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1961       port::ArraySlice<const DeviceMemory<float>*> input_data,
1962       const dnn::BatchDescriptor& output_dimensions,
1963       DeviceMemory<float>* output_data) {
1964     return false;
1965   }
1966 
1967   // Pads the input with zeros in the X and Y dimensions. The feature_map
1968   // dimension is unchanged.
1969   //
1970   // Arguments (all borrowed):
1971   //  stream: borrowed pointer to the stream that the 'XY pad' operation
1972   //    should be enqueued onto.
1973   //  dimensions: The dimensions of the input.
1974   //  input_data: un-owned device memory region which contains the
1975   //    input data for the input layer.
1976   //  left_pad: Amount to pad the input on the left.
1977   //  right_pad: Amount to pad the input on the right.
1978   //  top_pad: Amount to pad the input at the top (low Y).
1979   //  bottom_pad: Amount to pad the input at the bottom (high Y).
1980   //  output_data: un-owned device memory region in which to place the
1981   //    padded result.
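  //
  // For example (illustrative only): an input of width 5 and height 4 with
  // left_pad == right_pad == 1 and top_pad == bottom_pad == 2 produces an
  // output of width 5 + 1 + 1 == 7 and height 4 + 2 + 2 == 8.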
1982   virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions,
1983                        const DeviceMemory<float> &input_data,
1984                        int64 left_pad, int64 right_pad, int64 top_pad,
1985                        int64 bottom_pad, DeviceMemory<float> *output_data) = 0;
1986 
1987   // Extracts a slice of the input in the X and Y dimensions. The feature_map
1988   // dimension is unchanged.
1989   //
1990   // Arguments (all borrowed):
1991   //  stream: borrowed pointer to the stream that the 'XY slice' operation
1992   //    should be enqueued onto.
1993   //  dimensions: The dimensions of the input.
1994   //  input_data: un-owned device memory region which contains the
1995   //    input data for the input layer.
1996   //  left_trim: Amount to cut off the input on the left.
1997   //  right_trim: Amount to cut off the input on the right.
1998   //  top_trim: Amount to cut off the input at the top (low y).
1999   //  bottom_trim: Amount to cut off the input at the bottom (high Y).
2000   //  output_data: un-owned device memory region in which to place the
2001   //    sliced result.
2002   virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions,
2003                     const DeviceMemory<float> &input_data,
2004                     int64 left_trim, int64 right_trim, int64 top_trim,
2005                     int64 bottom_trim, DeviceMemory<float> *output_data) = 0;
2006 
2007   // Grows the input tensor by replicating the X and Y dimensions. The batch and
2008   // depth/feature_map dimensions are unchanged. Currently, the input tensor is
2009   // limited to X=1 and Y=1.
2010   //
2011   // For example, the input has dimensions x=2, y=3, and replicate_x=3,
2012   // replicate_y=2. The diagonal elements of the output would be: [x0y0, x1y1,
2013   // x0y2, x1y0, x0y1, x1y2].
2014   // Here is the example as a picture. input:
2015   // AB
2016   // CD
2017   // EF
2018   // broadcast result:
2019   // ABABAB
2020   // CDCDCD
2021   // EFEFEF
2022   // ABABAB
2023   // CDCDCD
2024   // EFEFEF
2025   //
2026   // Arguments (all borrowed):
2027   //  stream: borrowed pointer to the stream that the 'XY broadcast' operation
2028   //    should be enqueued onto.
2029   //  dimensions: The dimensions of the input.
2030   //  input_data: un-owned device memory region which contains the
2031   //    input data for the input layer.
2032   //  replicate_x: Amount to replicate the input's X dimension.
2033   //  replicate_y: Amount to replicate the input's Y dimension.
2034   //  output_data: un-owned device memory region in which to place the
2035   //    broadcast result.
2036   virtual bool DoXYBroadcast(Stream* stream,
2037                              const dnn::BatchDescriptor& dimensions,
2038                              const DeviceMemory<float>& input_data,
2039                              int64 replicate_x, int64 replicate_y,
2040                              DeviceMemory<float>* output_data) {
2041     return false;
2042   }
2043 
2044   // Enqueues an asynchronous memcpy of the *quantized* output of a layer (that
2045   // is, bytes instead of scaled floats) into 'host_dst' if they are available
2046   // for the underlying DNN implementation. If this quantized output is not
2047   // available, false is returned, which will place 'stream' into an error
2048   // state.
2049   //
2050   // Arguments (all borrowed):
2051   //  stream: borrowed pointer to the stream that the 'quantized memcpy'
2052   //    operation should be enqueued onto.
2053   //  gpu_unquantized_src: the device memory that contains the unquantized data
2054   //    -- this data should also have a corresponding quantized representation
2055   //    on the device for this operation to succeed.
2056   //  mode: Type of quantization of the data to write into host_dst.
2057   //  host_dst: un-owned host memory region that is mutated in place,
2058   //    it is clobbered by the values in 'gpu_unquantized_src' when the enqueued
2059   //    (asynchronous) memcpy operation is performed.
2060   //  size: size in bytes of the host_dst host memory region.
2061   virtual bool DoMemcpyD2HQuantized(
2062       Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
2063       QuantizedActivationMode mode, void* host_dst, int64 size) = 0;
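
  // Illustrative sketch (not part of the interface): copying a quantized view
  // of 'gpu_output' (a DeviceMemory<float> assumed to exist and to have a
  // quantized counterpart on the device) into a host buffer of 'bytes' bytes;
  // QuantizedActivationMode::k8Bit is assumed from the enum declared earlier
  // in this file:
  //
  //   std::vector<uint8> host(bytes);
  //   bool ok = dnn->DoMemcpyD2HQuantized(&stream, gpu_output,
  //                                       QuantizedActivationMode::k8Bit,
  //                                       host.data(), host.size());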
2064 
2065   // Enqueues an asynchronous memcpy of 'host_src' into the *quantized* input
2066   // of a layer (that is, bytes instead of scaled floats) if they are supported
2067   // by the underlying DNN implementation. If this quantized input is not
2068   // supported, false is returned, which will place 'stream' into an error
2069   // state.
2070   //
2071   // Arguments (all borrowed):
2072   //  stream: borrowed pointer to the stream that the 'quantized memcpy'
2073   //    operation should be enqueued onto.
2074   //  host_src: un-owned host memory region that contains the quantized data.
2075   //  size: size in bytes of the host_src host memory region.
2076   //  mode: Type of quantization of the data to read from host_src.
2077   //  gpu_unquantized_dst: the device memory that is clobbered by the values in
2078   //    'host_src' when the enqueued (asynchronous) memcpy operation is
2079   //    performed. This data should also have a corresponding quantized
2080   //    representation on the device for this operation to
2081   //    succeed.
2082   virtual bool DoMemcpyH2DQuantized(
2083       Stream* stream, const void* host_src, int64 size,
2084       QuantizedActivationMode mode,
2085       DeviceMemory<float>* gpu_unquantized_dst) = 0;
2086 
2087   // Create an RNN descriptor based on model shapes and configurations.
2088   // The caller retains the ownership of the descriptor.
2089   //
2090   // Arguments:
2091   //  num_layers: the number of layers for an RNN model.
2092   //  hidden_size: the size of the hidden state.
2093   //  input_size: the size of the input state.
2094   //  cell_size: the size of the cell state.
2095   //  input_mode: an enum to specify whether a linear transformation is added
2096   //    after the input state. If input_size is different from hidden_size, this
2097   //    is required.
2098   //  direction_mode: an enum to specify whether this model is unidirectional or
2099   //    bidirectional.
2100   //  rnn_mode: an enum to specify the type of model to build.
2101   //  data_type: an enum to specify the data types used in this model.
2102   //  dropout: the dropout threshold between layers. When it is 0., no dropout
2103   //    is added.
2104   //  seed: a seed for initializing the dropout layers.
2105   //  state_allocator: a memory allocator that will be used to store the state
2106   //    for the dropout layer. The user has to maintain the memory until the
2107   //    model is no longer in use.
2108   //  use_padded_io: a bool to specify whether the input is using padded IO.
2109   virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2110   createRnnDescriptor(int num_layers, int hidden_size, int input_size,
2111                       int cell_size, int batch_size,
2112                       dnn::RnnInputMode input_mode,
2113                       dnn::RnnDirectionMode direction_mode,
2114                       dnn::RnnMode rnn_mode, dnn::DataType data_type,
2115                       const dnn::AlgorithmConfig& algorithm_config,
2116                       float dropout, uint64 seed,
2117                       ScratchAllocator* state_allocator, bool use_padded_io) {
2118     return port::Status(port::error::UNIMPLEMENTED,
2119                         "createRnnDescriptor is unimplemented");
2120   }
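
  // Illustrative sketch (not part of the interface): creating a descriptor for
  // a hypothetical single-layer LSTM. The enum values are assumed from their
  // declarations earlier in this file; unwrap the returned StatusOr as
  // appropriate at the call site:
  //
  //   auto rnn_desc_or = dnn->createRnnDescriptor(
  //       /*num_layers=*/1, /*hidden_size=*/128, /*input_size=*/64,
  //       /*cell_size=*/128, /*batch_size=*/32,
  //       dnn::RnnInputMode::kRnnLinearSkip,
  //       dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
  //       dnn::DataType::kFloat, dnn::AlgorithmConfig(), /*dropout=*/0.0f,
  //       /*seed=*/0, /*state_allocator=*/nullptr, /*use_padded_io=*/false);
  //   if (!rnn_desc_or.ok()) { /* handle UNIMPLEMENTED or other errors */ }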
2121 
2122   // Create an RNN sequence descriptor that specifies either the input or output
2123   // sequence. The caller retains the ownership of the returned descriptor.
2124   //
2125   // Arguments:
2126   //  max_seq_length: the max length of the sequences.
2127   //  batch_size: the size of a minibatch.
2128   //  data_size: the size of the state.
2129   //  seq_lengths: the lengths of sequences in a batch.
2130   //  data_type: an enum to specify the type for the underlying data.
2131   virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2132   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2133                                     int data_size, dnn::DataType data_type) {
2134     return port::Status(port::error::UNIMPLEMENTED,
2135                         "createRnnSequenceTensorDescriptor is unimplemented");
2136   }
2137 
2138   virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2139   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2140                                     int data_size,
2141                                     const absl::Span<const int>& seq_lengths,
2142                                     bool time_major, dnn::DataType data_type) {
2143     return port::Status(port::error::UNIMPLEMENTED,
2144                         "createRnnSequenceTensorDescriptor is unimplemented");
2145   }
2146 
2147   // Create an RNN state descriptor that specifies the input or hidden state.
2148   // The caller retains the ownership of the returned descriptor.
2149   virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2150   createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
2151                                  dnn::DataType data_type) {
2152     return port::Status(port::error::UNIMPLEMENTED,
2153                         "createRnnStateTensorDescriptor is unimplemented");
2154   }
2155 
2156   // Enqueue a forward operation of the RNN model onto the stream.
2157   //
2158   // Arguments:
2159   //  stream: pointer to the stream where this operation should be enqueued to.
2160   //  rnn_desc: an RNN descriptor created by createRnnDescriptor.
2161   //  input_desc: descriptor for the input sequence.
2162   //  input_data: the device memory region that contains the input data.
2163   //  input_h_desc: descriptor for the input "h" state.
2164   //  input_h_data: the device memory region that contains the input "h" data.
2165   //  input_c_desc: descriptor for the input "c" state.
2166   //  input_c_data: the device memory region that contains the input "c" data.
2167   //    This must be specified for LSTM models.
2168   //  params: the device memory region that contains the parameters used in this
2169   //    model.
2170   //  output_desc: descriptor for the output sequence.
2171   //  output_data: the memory region that stores the output sequence data.
2172   //  output_h_desc: descriptor for the output "h" state.
2173   //  output_h_data: the memory region that stores the output "h" data.
2174   //  output_c_desc: descriptor for the output "c" state.
2175   //  output_c_data: the memory region that stores the output "c" data. This
2176   //    must be specified for LSTM models.
2177   //  is_training: whether this is used in training or inference. That decides
2178   //    whether reserve_space data needs to be produced.
2179   //  reserve_space_allocator: if "is_training" is true, a memory allocator
2180   //    to create memory that holds the produced reserve_space. The caller
2181   //    retains the data and feeds it to the backward pass.
2182   //  workspace_allocator: an allocator to create temporary workspace used in
2183   //    this kernel. The caller is responsible for retaining the memory long
2184   //    enough for the lifespan of this operation, and for recycling it afterwards.
2185   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2186                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2187                             const DeviceMemory<Eigen::half>& input_data,
2188                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2189                             const DeviceMemory<Eigen::half>& input_h_data,
2190                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2191                             const DeviceMemory<Eigen::half>& input_c_data,
2192                             const DeviceMemory<Eigen::half>& params,
2193                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2194                             DeviceMemory<Eigen::half>* output_data,
2195                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2196                             DeviceMemory<Eigen::half>* output_h_data,
2197                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2198                             DeviceMemory<Eigen::half>* output_c_data,
2199                             bool is_training,
2200                             ScratchAllocator* reserve_space_allocator,
2201                             ScratchAllocator* workspace_allocator,
2202                             dnn::ProfileResult* output_profile_result) {
2203     return false;
2204   }
2205 
2206   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2207                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2208                             const DeviceMemory<float>& input_data,
2209                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2210                             const DeviceMemory<float>& input_h_data,
2211                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2212                             const DeviceMemory<float>& input_c_data,
2213                             const DeviceMemory<float>& params,
2214                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2215                             DeviceMemory<float>* output_data,
2216                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2217                             DeviceMemory<float>* output_h_data,
2218                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2219                             DeviceMemory<float>* output_c_data,
2220                             bool is_training,
2221                             ScratchAllocator* reserve_space_allocator,
2222                             ScratchAllocator* workspace_allocator,
2223                             dnn::ProfileResult* output_profile_result) {
2224     return false;
2225   }
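
  // Illustrative sketch (not part of the interface): a float forward pass,
  // assuming 'rnn_desc' and the sequence/state descriptors were created with
  // the factory methods above (and the StatusOr results unwrapped into the
  // unique_ptr variables shown), the DeviceMemory<float> buffers were
  // allocated elsewhere, and the two ScratchAllocator pointers are valid:
  //
  //   bool ok = dnn->DoRnnForward(
  //       &stream, *rnn_desc, *input_desc, input_data, *input_h_desc,
  //       input_h_data, *input_c_desc, input_c_data, params, *output_desc,
  //       &output_data, *output_h_desc, &output_h_data, *output_c_desc,
  //       &output_c_data, /*is_training=*/true, reserve_space_allocator,
  //       workspace_allocator, /*output_profile_result=*/nullptr);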
2226 
2227   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2228                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2229                             const DeviceMemory<double>& input_data,
2230                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2231                             const DeviceMemory<double>& input_h_data,
2232                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2233                             const DeviceMemory<double>& input_c_data,
2234                             const DeviceMemory<double>& params,
2235                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2236                             DeviceMemory<double>* output_data,
2237                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2238                             DeviceMemory<double>* output_h_data,
2239                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2240                             DeviceMemory<double>* output_c_data,
2241                             bool is_training,
2242                             ScratchAllocator* reserve_space_allocator,
2243                             ScratchAllocator* workspace_allocator,
2244                             dnn::ProfileResult* output_profile_result) {
2245     return false;
2246   }
2247   // Enqueue a backward operation of the RNN model onto the stream.
2248   //
2249   // Arguments:
2250   //  stream: pointer to the stream where this operation should be enqueued to.
2251   //  rnn_desc: an RNN descriptor created by createRnnDescriptor.
2252   //  input_desc: descriptor for the input sequence.
2253   //  input_data: the device memory region that contains the input data.
2254   //  input_h_desc: descriptor for the input "h" state.
2255   //  input_h_data: the device memory region that contains the input "h" data.
2256   //  input_c_desc: descriptor for the input "c" state.
2257   //  input_c_data: the device memory region that contains the input "c" data.
2258   //    This must be specified for LSTM models.
2259   //  params: the device memory region that contains the parameters used in this
2260   //    model.
2261   //  output_desc: descriptor for the output sequence.
2262   //  output_data: the memory region that stores the output sequence data.
2263   //  output_h_desc: descriptor for the output "h" state.
2264   //  output_h_data: the memory region that stores the output "h" data.
2265   //  output_c_desc: descriptor for the output "c" state.
2266   //  output_c_data: the memory region that stores the output "c" data. This
2267   //    must be specified for LSTM models.
2268   //  output_backprop_data: the device memory region that contains the backprop
2269   //    to the output sequence.
2270   //  output_h_backprop_data: the device memory region that contains the
2271   //    backprop to the output "h" state.
2272   //  output_c_backprop_data: the device memory region that contains the
2273   //    backprop to the output "c" state.
2274   //  input_backprop_data: the device memory region that stores the backprop
2275   //    to the input sequence.
2276   //  input_h_backprop_data: the device memory region that stores the backprop
2277   //    to the input "h" state.
2278   //  input_c_backprop_data: the device memory region that stores the backprop
2279   //    to the input "c" state.
2280   //  params_backprop_data: the device memory region that stores the backprop
2281   //    to the parameters.
2282   //  reserve_space_data: the reserve_space data that is produced by the forward
2283   //    operation. This memory region could be modified by this operation.
2284   //  workspace_allocator: a memory allocator that creates the temporary
2285   //    workspace memory used by this operation. The caller is responsible for
2286   //    keeping the memory alive long enough for this operation, and recycles
2287   //    it afterwards.
2288   virtual bool DoRnnBackward(
2289       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2290       const dnn::RnnSequenceTensorDescriptor& input_desc,
2291       const DeviceMemory<Eigen::half>& input_data,
2292       const dnn::RnnStateTensorDescriptor& input_h_desc,
2293       const DeviceMemory<Eigen::half>& input_h_data,
2294       const dnn::RnnStateTensorDescriptor& input_c_desc,
2295       const DeviceMemory<Eigen::half>& input_c_data,
2296       const DeviceMemory<Eigen::half>& params,
2297       const dnn::RnnSequenceTensorDescriptor& output_desc,
2298       const DeviceMemory<Eigen::half>& output_data,
2299       const dnn::RnnStateTensorDescriptor& output_h_desc,
2300       const DeviceMemory<Eigen::half>& output_h_data,
2301       const dnn::RnnStateTensorDescriptor& output_c_desc,
2302       const DeviceMemory<Eigen::half>& output_c_data,
2303       const DeviceMemory<Eigen::half>& output_backprop_data,
2304       const DeviceMemory<Eigen::half>& output_h_backprop_data,
2305       const DeviceMemory<Eigen::half>& output_c_backprop_data,
2306       DeviceMemory<Eigen::half>* input_backprop_data,
2307       DeviceMemory<Eigen::half>* input_h_backprop_data,
2308       DeviceMemory<Eigen::half>* input_c_backprop_data,
2309       DeviceMemory<Eigen::half>* params_backprop_data,
2310       DeviceMemory<uint8>* reserve_space_data,
2311       ScratchAllocator* workspace_allocator,
2312       dnn::ProfileResult* output_profile_result) {
2313     return false;
2314   }
2315 
2316   virtual bool DoRnnBackward(
2317       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2318       const dnn::RnnSequenceTensorDescriptor& input_desc,
2319       const DeviceMemory<float>& input_data,
2320       const dnn::RnnStateTensorDescriptor& input_h_desc,
2321       const DeviceMemory<float>& input_h_data,
2322       const dnn::RnnStateTensorDescriptor& input_c_desc,
2323       const DeviceMemory<float>& input_c_data,
2324       const DeviceMemory<float>& params,
2325       const dnn::RnnSequenceTensorDescriptor& output_desc,
2326       const DeviceMemory<float>& output_data,
2327       const dnn::RnnStateTensorDescriptor& output_h_desc,
2328       const DeviceMemory<float>& output_h_data,
2329       const dnn::RnnStateTensorDescriptor& output_c_desc,
2330       const DeviceMemory<float>& output_c_data,
2331       const DeviceMemory<float>& output_backprop_data,
2332       const DeviceMemory<float>& output_h_backprop_data,
2333       const DeviceMemory<float>& output_c_backprop_data,
2334       DeviceMemory<float>* input_backprop_data,
2335       DeviceMemory<float>* input_h_backprop_data,
2336       DeviceMemory<float>* input_c_backprop_data,
2337       DeviceMemory<float>* params_backprop_data,
2338       DeviceMemory<uint8>* reserve_space_data,
2339       ScratchAllocator* workspace_allocator,
2340       dnn::ProfileResult* output_profile_result) {
2341     return false;
2342   }
2343 
2344   virtual bool DoRnnBackward(
2345       Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2346       const dnn::RnnSequenceTensorDescriptor& input_desc,
2347       const DeviceMemory<double>& input_data,
2348       const dnn::RnnStateTensorDescriptor& input_h_desc,
2349       const DeviceMemory<double>& input_h_data,
2350       const dnn::RnnStateTensorDescriptor& input_c_desc,
2351       const DeviceMemory<double>& input_c_data,
2352       const DeviceMemory<double>& params,
2353       const dnn::RnnSequenceTensorDescriptor& output_desc,
2354       const DeviceMemory<double>& output_data,
2355       const dnn::RnnStateTensorDescriptor& output_h_desc,
2356       const DeviceMemory<double>& output_h_data,
2357       const dnn::RnnStateTensorDescriptor& output_c_desc,
2358       const DeviceMemory<double>& output_c_data,
2359       const DeviceMemory<double>& output_backprop_data,
2360       const DeviceMemory<double>& output_h_backprop_data,
2361       const DeviceMemory<double>& output_c_backprop_data,
2362       DeviceMemory<double>* input_backprop_data,
2363       DeviceMemory<double>* input_h_backprop_data,
2364       DeviceMemory<double>* input_c_backprop_data,
2365       DeviceMemory<double>* params_backprop_data,
2366       DeviceMemory<uint8>* reserve_space_data,
2367       ScratchAllocator* workspace_allocator,
2368       dnn::ProfileResult* output_profile_result) {
2369     return false;
2370   }
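
  // Illustrative sketch (not part of the interface): the backward pass reuses
  // the forward-pass descriptors and buffers and consumes the reserve space
  // produced when DoRnnForward ran with is_training=true (all names below are
  // assumed to be set up as in the forward sketch above):
  //
  //   bool ok = dnn->DoRnnBackward(
  //       &stream, *rnn_desc, *input_desc, input_data, *input_h_desc,
  //       input_h_data, *input_c_desc, input_c_data, params, *output_desc,
  //       output_data, *output_h_desc, output_h_data, *output_c_desc,
  //       output_c_data, output_backprop_data, output_h_backprop_data,
  //       output_c_backprop_data, &input_backprop_data, &input_h_backprop_data,
  //       &input_c_backprop_data, &params_backprop_data, &reserve_space_data,
  //       workspace_allocator, /*output_profile_result=*/nullptr);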
2371 
2372   template <typename ElementType>
2373   port::Status PrepareForCtcLoss(Stream* stream,
2374                                  const RnnStateTensorDescriptor& probs_desc,
2375                                  DeviceMemory<ElementType> probs_data,
2376                                  const RnnStateTensorDescriptor& grads_desc,
2377                                  absl::Span<const int> labels_data,
2378                                  absl::Span<const int> labels_lengths_data,
2379                                  absl::Span<const int> input_lengths_data,
2380                                  ScratchAllocator* workspace_allocator,
2381                                  DeviceMemory<uint8>* scratch_memory,
2382                                  int* ctc_loss_algo_id) {
2383     return DoPrepareForCtcLoss(
2384         stream, ToDataType<ElementType>::value, probs_desc, grads_desc,
2385         labels_data, labels_lengths_data, input_lengths_data,
2386         workspace_allocator, scratch_memory, ctc_loss_algo_id);
2387   }
2388 
2389   // Enqueue a CTC Loss operation onto the stream.
2390   //
2391   // Arguments:
2392   //  stream: pointer to the stream where this operation should be enqueued to.
2393   //  element_type: data type of the input tensors.
2394   //  probs_desc: specifies the shape and the data layout of the input tensor.
2395   //  probs_data: the device memory region that contains the input tensor.
2396   //  labels_data: the device memory region that contains the labels_value
2397   //    tensor.
2398   //  labels_lengths_data: the device memory region that contains the
2399   //    labels_lengths tensor.
2400   //  input_lengths_data: the device memory region that contains the seq_lengths
2401   //    tensor.
2402   //  costs_data: the device memory region that contains the costs tensor.
2403   //  grads_desc: specifies the shape and the data layout of the grads tensor.
2404   //  grads_data: the device memory region that contains the grads tensor.
2405   //  ctc_loss_desc: a CTCLoss descriptor.
2406   //  workspace_allocator: a memory allocator that creates the temporary
2407   //    workspace memory used by this operation. The caller is responsible for
2408   //    keeping the memory alive long enough for this operation, and recycles
2409   //    it afterwards.
2410   virtual port::Status DoCtcLoss(
2411       Stream* stream, dnn::DataType element_type,
2412       const RnnStateTensorDescriptor& probs_desc,
2413       const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2414       absl::Span<const int> labels_lengths_data,
2415       absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2416       const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
2417       DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
2418 
2419   template <typename ElementType>
2420   bool DoCtcLoss(Stream* stream,
2421                  const dnn::RnnStateTensorDescriptor& probs_desc,
2422                  const DeviceMemory<ElementType>& probs_data,
2423                  absl::Span<const int> labels_data,
2424                  absl::Span<const int> labels_lengths_data,
2425                  absl::Span<const int> input_lengths_data,
2426                  DeviceMemory<ElementType>* costs_data,
2427                  const dnn::RnnStateTensorDescriptor& grads_desc,
2428                  DeviceMemory<ElementType>* grads_data,
2429                  DeviceMemory<uint8>* scratch_memory, int ctc_loss_algo_id) {
2430     return IsStatusOk(
2431         DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
2432                   probs_data, labels_data, labels_lengths_data,
2433                   input_lengths_data, *costs_data, grads_desc, *grads_data,
2434                   *scratch_memory, ctc_loss_algo_id),
2435         false);
2436   }
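
  // Illustrative sketch (not part of the interface): a typical CTC loss call
  // first sizes scratch space via PrepareForCtcLoss and then runs the loss
  // itself, assuming the descriptors, integer spans, and DeviceMemory<float>
  // buffers are set up elsewhere:
  //
  //   DeviceMemory<uint8> scratch;
  //   int algo_id = 0;
  //   auto status = dnn->PrepareForCtcLoss<float>(
  //       &stream, probs_desc, probs_data, grads_desc, labels_data,
  //       labels_lengths_data, input_lengths_data, workspace_allocator,
  //       &scratch, &algo_id);
  //   if (status.ok()) {
  //     dnn->DoCtcLoss<float>(&stream, probs_desc, probs_data, labels_data,
  //                           labels_lengths_data, input_lengths_data,
  //                           &costs_data, grads_desc, &grads_data, &scratch,
  //                           algo_id);
  //   }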
2437 
2438   // Transforms a tensor into another tensor with a different layout and/or data
2439   // type.
2440   //
2441   // Arguments:
2442   //  stream: pointer to the stream where this operation should be enqueued to.
2443   //  input_desc: specifies the shape and the data layout of the input tensor.
2444   //  input_type: the data type of the input tensor.
2445   //  input_data: the device memory region that contains the input tensor.
2446   //  output_desc: specifies the shape and the data layout of the output tensor.
2447   //  output_type: the data type of the output tensor.
2448   //  scale: an element-wise scaling factor to apply.
2449   //  output_data: the device memory region that contains the output tensor.
2450   virtual bool DoTransformTensor(Stream* stream,
2451                                  const dnn::BatchDescriptor& input_desc,
2452                                  dnn::DataType input_type,
2453                                  const DeviceMemoryBase& input_data,
2454                                  const dnn::BatchDescriptor& output_desc,
2455                                  dnn::DataType output_type, float scale,
2456                                  DeviceMemoryBase* output_data) {
2457     return false;
2458   }
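
  // Illustrative sketch (not part of the interface): converting a tensor from
  // one layout/type to another, e.g. float input to half output, assuming the
  // two BatchDescriptors and DeviceMemoryBase regions exist; kFloat/kHalf are
  // assumed from the DataType declaration included above:
  //
  //   bool ok = dnn->DoTransformTensor(&stream, input_desc,
  //                                    dnn::DataType::kFloat, input_data,
  //                                    output_desc, dnn::DataType::kHalf,
  //                                    /*scale=*/1.0f, &output_data);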
2459 
2460   // Enqueues a fused convolution+bias+activation operation onto the stream.
2461   //
2462   // Arguments (all borrowed):
2463   //
2464   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2465   //  be enqueued onto.
2466   //
2467   //  conv_input_descriptor: dimensions of the convolution input layer.
2468   //  conv_input_data: device memory which contains the convolution input.
2469   //
2470   //  filter_descriptor: dimensions of the convolution filter.
2471   //  filter_data: device memory which contains the convolution filter weights.
2472   //
2473   //  convolution_descriptor: stride of the convolution filter.
2474   //
2475   //  bias_descriptor: dimensions of the bias layer.
2476   //  bias_data: device memory region containing biases to add to the
2477   //    convolution output.
2478   //
2479   //  activation_mode: Type of activation to perform.
2480   //
2481   //  output_descriptor: dimensions of the output layer.
2482   //  output_data: device memory region in which to place the fusion result.
2483   //
2484   //  output_profile_result: the output profile result for this call.
2485   //         The profiling is only enabled when this is not nullptr.
2486   //
2487   virtual bool DoFusedConvolutionBiasActivation(
2488       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
2489       const DeviceMemory<float>& conv_input_data,
2490       const dnn::FilterDescriptor& filter_descriptor,
2491       const DeviceMemory<float>& filter_data,
2492       const dnn::ConvolutionDescriptor& convolution_descriptor,
2493       const dnn::BatchDescriptor& bias_descriptor,
2494       const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
2495       const dnn::BatchDescriptor& output_descriptor,
2496       DeviceMemory<float>* output_data,
2497       dnn::ProfileResult* output_profile_result) {
2498     return false;
2499   }
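
  // Illustrative sketch (not part of the interface): a fused conv+bias+ReLU
  // call, assuming the descriptors and float buffers are prepared elsewhere;
  // ActivationMode::kRelu is assumed from the enum declared earlier in this
  // file:
  //
  //   bool ok = dnn->DoFusedConvolutionBiasActivation(
  //       &stream, conv_input_desc, conv_input_data, filter_desc, filter_data,
  //       conv_desc, bias_desc, bias_data, dnn::ActivationMode::kRelu,
  //       output_desc, &output_data, /*output_profile_result=*/nullptr);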
2500 
2501   // Enqueues a fused batchnorm+activation (inference) operation onto the
2502   // stream.
2503   //
2504   // Arguments (all borrowed):
2505   //
2506   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2507   //  be enqueued onto.
2508   //
2509   //  x_descriptor: dimensions of the batchnorm input layer.
2510   //  x_data: device memory which contains the batchnorm input.
2511   //
2512   //  scale_offset_mean_variance_descriptor:
2513   //      dimensions of the scale/offset/mean/variance tensor.
2514   //  scale_data: device memory which contains the scale input.
2515   //  offset_data: device memory which contains the offset input.
2516   //  mean_data: device memory which contains the mean input.
2517   //  variance_data: device memory which contains the variance input.
2518   //  epsilon: the epsilon value to use in the batchnorm calculation.
2519   //
2520   //  activation_mode: Type of activation to perform.
2521   //
2522   //  y_data: device memory region in which to place the fusion result.
2523   //
2524   //  output_profile_result: the output profile result for this call.
2525   //         The profiling is only enabled when this is not nullptr.
2526   //
2527   virtual bool DoFusedBatchNormActivationInference(
2528       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2529       const DeviceMemory<float>& x_data,
2530       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2531       const DeviceMemory<float>& scale_data,
2532       const DeviceMemory<float>& offset_data,
2533       const DeviceMemory<float>& mean_data,
2534       const DeviceMemory<float>& variance_data, double epsilon,
2535       dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
2536       dnn::ProfileResult* output_profile_result) {
2537     return false;
2538   }
2539 
2540   virtual bool DoFusedBatchNormActivationInference(
2541       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2542       const DeviceMemory<Eigen::half>& x_data,
2543       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2544       const DeviceMemory<float>& scale_data,
2545       const DeviceMemory<float>& offset_data,
2546       const DeviceMemory<float>& mean_data,
2547       const DeviceMemory<float>& variance_data, double epsilon,
2548       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
2549       dnn::ProfileResult* output_profile_result) {
2550     return false;
2551   }
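
  // Illustrative sketch (not part of the interface): fused batchnorm+ReLU
  // inference over a float tensor, with all descriptors and buffers assumed to
  // be set up elsewhere:
  //
  //   bool ok = dnn->DoFusedBatchNormActivationInference(
  //       &stream, x_desc, x_data, scale_offset_desc, scale_data, offset_data,
  //       mean_data, variance_data, /*epsilon=*/1e-3,
  //       dnn::ActivationMode::kRelu, &y_data,
  //       /*output_profile_result=*/nullptr);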
2552 
2553   // Enqueues a fused batchnorm+activation (training-fwd) operation onto the
2554   // stream.
2555   //
2556   // Arguments (all borrowed):
2557   //
2558   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2559   //  be enqueued onto.
2560   //
2561   //  x_descriptor: dimensions of the batchnorm input layer.
2562   //  x_data: device memory which contains the batchnorm input.
2563   //
2564   //  scale_offset_mean_variance_descriptor:
2565   //      dimensions of the scale/offset/mean/variance tensor.
2566   //  scale_data: device memory which contains the scale input.
2567   //  offset_data: device memory which contains the offset input.
2568   //  epsilon: the epsilon value to use in the batchnorm calculation.
2569   //
2570   //  activation_mode: Type of activation to perform.
2571   //
2572   //  y_data: device memory region in which to place the fusion result.
2573   //  batch_mean_data: device memory in which to place the batch mean output.
2574   //  batch_var_data: device memory in which to place the batch variance output.
2575   //  saved_mean_data: device memory in which to save the mean for bwd pass.
2576   //  saved_var_data: device memory in which to save the variance for bwd pass.
2577   //
2578   //  output_profile_result: the output profile result for this call.
2579   //         The profiling is only enabled when this is not nullptr.
2580   //
2581   virtual bool DoFusedBatchNormActivationForward(
2582       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2583       const DeviceMemory<float>& x_data,
2584       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2585       const DeviceMemory<float>& scale_data,
2586       const DeviceMemory<float>& offset_data, double epsilon,
2587       dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
2588       DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
2589       DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
2590       dnn::ProfileResult* output_profile_result) {
2591     return false;
2592   }
2593 
2594   virtual bool DoFusedBatchNormActivationForward(
2595       Stream* stream, const dnn::BatchDescriptor& x_descriptor,
2596       const DeviceMemory<Eigen::half>& x_data,
2597       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2598       const DeviceMemory<float>& scale_data,
2599       const DeviceMemory<float>& offset_data, double epsilon,
2600       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
2601       DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
2602       DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
2603       dnn::ProfileResult* output_profile_result) {
2604     return false;
2605   }
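
  // Illustrative sketch (not part of the interface): the training-forward
  // variant also produces batch statistics and the saved mean/variance that
  // the backward fusion below expects; all buffers are assumed to be allocated
  // elsewhere:
  //
  //   bool ok = dnn->DoFusedBatchNormActivationForward(
  //       &stream, x_desc, x_data, scale_offset_desc, scale_data, offset_data,
  //       /*epsilon=*/1e-3, dnn::ActivationMode::kRelu, &y_data,
  //       &batch_mean_data, &batch_var_data, &saved_mean_data, &saved_var_data,
  //       /*output_profile_result=*/nullptr);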
2606 
2607   // Enqueues a fused batchnorm+activation (training-bwd) operation onto the
2608   // stream.
2609   //
2610   // Arguments (all borrowed):
2611   //
2612   //  stream: borrowed pointer to the stream that the 'fusion' operation should
2613   //  be enqueued onto.
2614   //
2615   //  y_act_backprop_descriptor: dimensions of the backprop input from the
2616   //    previous layer.
2617   //  y_act_backprop_data: device memory which contains the backprop input.
2618   //
2619   //  y_act_data: device memory which contains the actv-fwd output data.
2620   //
2621   //  activation_mode: actv-fwd type.
2622   //
2623   //  scale_offset_mean_variance_descriptor:
2624   //      dimensions of the scale/offset/mean/variance tensor.
2625   //  scale_data: device memory which contains the scale input.
2626   //  offset_data: device memory which contains the offset input.
2627   //  saved_mean_data: device memory which contains the saved mean from the
2628   //    fwd pass.
2629   //  saved_var_data: device memory which contains the saved variance from the fwd pass.
2630   //
2631   //  x_bn_backprop_data: device memory region in which to place the backprop
2632   //    data from this layer.
2633   //  scale_backprop_data: device memory in which to place the scale backprop output.
2634   //  offset_backprop_data: device memory in which to place the offset backprop output.
2635   //
2636   //  output_profile_result: the output profile result for this call.
2637   //         The profiling is only enabled when this is not nullptr.
2638   //
2639   virtual bool DoFusedBatchNormActivationBackward(
2640       Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
2641       const DeviceMemory<float>& y_act_backprop_data,
2642       const DeviceMemory<float>& y_act_data,
2643       dnn::ActivationMode activation_mode, const DeviceMemory<float>& x_bn_data,
2644       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2645       const DeviceMemory<float>& scale_data,
2646       const DeviceMemory<float>& offset_data,
2647       const DeviceMemory<float>& saved_mean_data,
2648       const DeviceMemory<float>& saved_var_data,
2649       DeviceMemory<float>* x_bn_backprop_data,
2650       DeviceMemory<float>* scale_backprop_data,
2651       DeviceMemory<float>* offset_backprop_data,
2652       dnn::ProfileResult* output_profile_result) {
2653     return false;
2654   }
2655 
2656   virtual bool DoFusedBatchNormActivationBackward(
2657       Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
2658       const DeviceMemory<Eigen::half>& y_act_backprop_data,
2659       const DeviceMemory<Eigen::half>& y_act_data,
2660       dnn::ActivationMode activation_mode,
2661       const DeviceMemory<Eigen::half>& x_bn_data,
2662       const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
2663       const DeviceMemory<float>& scale_data,
2664       const DeviceMemory<float>& offset_data,
2665       const DeviceMemory<float>& saved_mean_data,
2666       const DeviceMemory<float>& saved_var_data,
2667       DeviceMemory<Eigen::half>* x_bn_backprop_data,
2668       DeviceMemory<float>* scale_backprop_data,
2669       DeviceMemory<float>* offset_backprop_data,
2670       dnn::ProfileResult* output_profile_result) {
2671     return false;
2672   }
2673 
2674  protected:
2675   // Returns whether status is 'ok', and potentially logs the error.
2676   static bool IsStatusOk(const port::Status& status, bool report_error);
2677 
2678  private:
2679   virtual port::Status DoPrepareForConvolution(
2680       ConvolutionKind kind, DataType element_type, Stream* stream,
2681       const BatchDescriptor& batch_descriptor, DeviceMemoryBase input_data,
2682       const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data,
2683       const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
2684       const ConvolutionDescriptor& convolution_descriptor,
2685       const AlgorithmConfig& algorithm_config,
2686       ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
2687       DeviceMemory<uint8>* scratch_memory) {
2688     *algorithm_desc = {};
2689     *scratch_memory = {};
2690     return port::Status::OK();
2691   }
2692 
2693   virtual port::Status DoPrepareForCtcLoss(
2694       Stream* stream, DataType element_type,
2695       const RnnStateTensorDescriptor& probs_desc,
2696       const RnnStateTensorDescriptor& grads_desc,
2697       absl::Span<const int> labels_data,
2698       absl::Span<const int> labels_lengths_data,
2699       absl::Span<const int> input_lengths_data,
2700       ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
2701       int* ctc_loss_algo_id) {
2702     *scratch_memory = {};
2703     return port::Status::OK();
2704   }
2705 
2706   SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
2707 };
2708 
2709 }  // namespace dnn
2710 }  // namespace stream_executor
2711 
2712 #endif  // TENSORFLOW_STREAM_EXECUTOR_DNN_H_
2713