1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Neural Net operation support for StreamExecutor instances.
17 //
18 // This is an abstract interface for a platform to optionally support common
19 // neural net operations; it accommodates implementations such as the cudnn
20 // library operations.
21 
22 #ifndef TENSORFLOW_STREAM_EXECUTOR_DNN_H_
23 #define TENSORFLOW_STREAM_EXECUTOR_DNN_H_
24 
25 #include <functional>
26 #include <limits>
27 #include <memory>
28 #include <tuple>
29 #include <type_traits>
30 
31 #include "absl/types/optional.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/stream_executor/device_memory.h"
34 #include "tensorflow/stream_executor/dnn.pb.h"
35 #include "tensorflow/stream_executor/lib/array_slice.h"
36 #include "tensorflow/stream_executor/lib/status.h"
37 #include "tensorflow/stream_executor/lib/statusor.h"
38 #include "tensorflow/stream_executor/platform/logging.h"
39 #include "tensorflow/stream_executor/platform/port.h"
40 
41 namespace Eigen {
42 struct half;
43 }  // namespace Eigen
44 
45 namespace stream_executor {
46 
47 class HostBuffer;
48 class Stream;
49 class ScratchAllocator;
50 
51 namespace dnn {
52 
53 // Specifies an index to use when accessing specific spatial dimensions.
54 enum class DimIndex : int {
55   X = 0,
56   Y = 1,
57   Z = 2,
58 };
59 
60 // Helper functions to make methods more readable.
61 inline int64 GetDim(absl::Span<const int64> data, DimIndex dim) {
62   return data.rbegin()[static_cast<int64>(dim)];
63 }
64 
65 inline void SetDim(absl::Span<int64> data, DimIndex dim, int64 value) {
66   data.rbegin()[static_cast<int64>(dim)] = value;
67 }
68 
69 inline void SetDim(std::vector<int64>* data, DimIndex dim, int64 value) {
70   return SetDim(absl::MakeSpan(*data), dim, value);
71 }
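
// Illustrative example: dimensions are indexed from the minor (last) end of
// the span, so for spatial dimensions stored as {..., y, x}, DimIndex::X
// always refers to the last element. A small sketch with hypothetical values:
//
//   std::vector<int64> dims = {4, 8};   // stored as {y, x}
//   GetDim(dims, DimIndex::X);          // returns 8
//   GetDim(dims, DimIndex::Y);          // returns 4
//   SetDim(&dims, DimIndex::X, 16);     // dims becomes {4, 16}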
72 
73 // int64 is not the same type as tensorflow::protobuf_int64 in open-source. This
74 // wrapper function gives an int64 array slice view of a repeated int64 protobuf
75 // field.
76 //
77 // T should be a protobuf RepeatedField.
78 template <typename T>
79 inline absl::Span<const int64> AsInt64Slice(const T& repeated_field) {
80   using data_ty =
81       typename std::remove_reference<decltype(*repeated_field.data())>::type;
82   static_assert(std::is_integral<data_ty>::value &&
83                     std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
84                 "repeated_field.data() must return a pointer to a signed "
85                 "64-bit integer type.");
86   return absl::Span<const int64>(
87       reinterpret_cast<const int64*>(repeated_field.data()),
88       repeated_field.size());
89 }
90 template <typename T>
91 inline absl::Span<int64> AsInt64Slice(T* repeated_field) {
92   using data_ty =
93       typename std::remove_reference<decltype(*repeated_field->data())>::type;
94   static_assert(std::is_integral<data_ty>::value &&
95                     std::is_signed<data_ty>::value && sizeof(data_ty) == 8,
96                 "repeated_field->data() must return a pointer to a signed "
97                 "64-bit integer type.");
98   return absl::Span<int64>(
99       reinterpret_cast<int64*>(repeated_field->mutable_data()),
100       repeated_field->size());
101 }
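
// Illustrative example (assuming a TensorDescriptorProto 'proto', as used
// elsewhere in this file): view its repeated int64 'dimensions' field as a
// span without copying.
//
//   absl::Span<const int64> dims = AsInt64Slice(proto.dimensions());
//   absl::Span<int64> mutable_dims = AsInt64Slice(proto.mutable_dimensions());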
102 
103 // Returns a string representation of the given data layout.
104 string DataLayoutString(DataLayout layout);
105 
106 // Specifies a quantization for activations in a given BatchDescriptor.
107 enum class QuantizedActivationMode {
108   k8Bit = 1,
109   k16Bit = 2,
110   k32Bit = 4,
111 };
112 
113 // A helper class to convert C/C++ types to the proper enums.
114 template <typename T>
115 struct ToDataType;
116 template <>
117 struct ToDataType<float> {
118   static constexpr DataType value = DataType::kFloat;
119 };
120 template <>
121 struct ToDataType<double> {
122   static constexpr DataType value = DataType::kDouble;
123 };
124 template <>
125 struct ToDataType<Eigen::half> {
126   static constexpr DataType value = DataType::kHalf;
127 };
128 template <>
129 struct ToDataType<int8> {
130   static constexpr DataType value = DataType::kInt8;
131 };
132 template <>
133 struct ToDataType<int32> {
134   static constexpr DataType value = DataType::kInt32;
135 };
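
// For example (illustrative), the mapping can be checked at compile time:
//
//   static_assert(ToDataType<float>::value == DataType::kFloat, "");
//   static_assert(ToDataType<Eigen::half>::value == DataType::kHalf, "");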
136 
137 // Specifies the type of an RNN model.
138 enum class RnnMode {
139   kRnnRelu = 0,
140   kRnnTanh = 1,
141   kRnnLstm = 2,
142   kRnnGru = 3,
143 };
144 
145 // Specifies the RNN input mode and whether there is a linear transformation
146 // between the input state and the first layer hidden state.
147 enum class RnnInputMode {
148   kRnnLinearSkip = 0,
149   kRnnSkipInput = 1,
150 };
151 
152 // Specifies the number of directions used in an RNN model. When the
153 // bidirectional mode is used, the input states and output sequence contain
154 // data for both directions.
155 enum class RnnDirectionMode {
156   kRnnUnidirectional = 0,
157   kRnnBidirectional = 1,
158 };
159 
160 // Relevant to DepthToSpace and SpaceToDepth. This is the write layout when
161 // performing depth to space and the read layout when performing space to depth.
162 // It's specified with most-major dimension first and most-minor dimension last.
163 // In DepthToSpace, the D*M^2 values are read in and then, for DepthHeightWidth,
164 // written out to the output patch, by varying first width, then height, then
165 // depth. In C array format, it looks like [depth][height][width]. See
166 // DepthToSpace comment for more information.
167 enum class DepthToSpaceLayout { DepthHeightWidth };
168 
169 // Specifies the descriptor for an RNN model.
170 //
171 // An example use case:
172 //   * The user first creates a model through createRnnDescriptor.
173 //   * The user queries the size of the underlying opaque parameter buffer.
174 //   * The user creates and initializes a parameter buffer of the proper size.
175 //   * The user runs forward and backward operations using this RNN descriptor.
176 //   * Once in a while, the user queries the maintainable weights and bias
177 //       regions from the underlying parameter buffer. They are more likely to be
178 //       forward compatible and should be used when saving and restoring a model.
179 //   * The user releases the RNN descriptor when the model is no longer in use.
180 class RnnDescriptor {
181  public:
182   struct ParamsRegion {
183     int64 offset;
184     int64 size;
185   };
186   typedef std::vector<ParamsRegion> ParamsRegions;
187   virtual ~RnnDescriptor() {}
188   virtual int64 ParamsSizeInBytes() const { return -1; }
189   virtual ParamsRegions ParamsWeightRegions() const { return ParamsRegions(); }
190   virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
191 };
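
// Illustrative sketch of the parameter-buffer workflow described above (the
// 'rnn' variable and the buffer-handling steps are hypothetical):
//
//   const RnnDescriptor& rnn = ...;  // created via createRnnDescriptor
//   int64 params_bytes = rnn.ParamsSizeInBytes();
//   // Allocate an opaque parameter buffer of 'params_bytes' bytes, then use
//   // the weight/bias regions when saving or restoring the model:
//   for (const RnnDescriptor::ParamsRegion& region : rnn.ParamsWeightRegions()) {
//     // region.offset and region.size locate one weight matrix in the buffer.
//   }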
192 
193 // Specifies the sequence in an RNN model.
194 //
195 // The user is responsible for releasing this descriptor when it is no longer
196 // in use. The destructor releases the underlying descriptors.
197 class RnnSequenceTensorDescriptor {
198  public:
199   virtual ~RnnSequenceTensorDescriptor() {}
200 };
201 
202 // Specifies either the input or the hidden state in an RNN model.
203 //
204 // The user is responsible for releasing this descriptor when it is no longer
205 // in use. The destructor releases the underlying descriptors.
206 class RnnStateTensorDescriptor {
207  public:
208   virtual ~RnnStateTensorDescriptor() {}
209 };
210 
211 // Returns a string representation of the given quantization mode.
212 string QuantizedActivationModeString(QuantizedActivationMode mode);
213 
214 // Describes the dimensions that a layer consumes/produces.
215 //
216 // This is a matrix (height, width), its "depth" (feature_map_count),
217 // how many of these matrices are present (count),
218 // and the maximum and minimum values expected in the matrix (value_max,
219 // value_min).
220 // If input is quantized, all values greater
221 // than value_max will be clipped to value_max and all values less than
222 // value_min will be clipped to value_min.
223 // When quantized output is dequantized, no value will be greater than
224 // value_max or less than value_min.
225 //
226 // Uses the named argument construction form:
227 //
228 //  auto input_batch_dimensions =
229 //      BatchDescriptor().set_count(42).set_feature_map_count(7)...
230 //
231 // Details:
232 //
233 // For a convolutional layer, a single inference takes a 3-dimensional matrix
234 // of input and produces a 3-dimensional matrix of output. We call the three
235 // dimensions height, width and feature_map_count, where for an image, the
236 // height and width correspond to the Y and X pixel indices, respectively, and
237 // the feature_map_count corresponds to the RGB dimension of the input data.
238 // Then the count indicates how many 3D matrices are being presented to be
239 // processed at once; this corresponds to the neural network concept of
240 // minibatch size.
241 //
242 // For a fully connected layer, it's better to put the nodes of the layer in
243 // the feature_map_count, and leave the height and width as degenerate (== 1).
244 // Count indicates how many input vectors (degenerate 3D matrices) are to be
245 // processed.
246 //
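// For example (illustrative), a fully connected layer with 128 nodes and a
// minibatch of 32 input vectors could be described as:
//
//  BatchDescriptor input_dimensions;
//  input_dimensions.set_count(32)
//      .set_feature_map_count(128)
//      .set_height(1)
//      .set_width(1);
//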
247 // If unspecified, value_max and value_min default to 0.0.
248 // If value_max == value_min the Stream will attempt to derive valid values -
249 // for example the output of Relu6 activation will always be in the range
250 // [0.0, 6.0].
251 //
252 // If unspecified, layout defaults to kYXDepthBatch.
253 class BatchDescriptor {
254  public:
255   // Creates a "blank" batch descriptor, which should be initialized via the
256   // named argument helpers.
257   BatchDescriptor();
258   explicit BatchDescriptor(int ndims);
259 
260   // Clones values from 'other' for initialization.
261   void CloneFrom(const BatchDescriptor& other);
262 
263   string ToString() const;
264   string ToShortString() const;
265 
266   // Pre-condition:
267   //   value_max_ == 0
268   //   value_min_ == 0
269   //   quantized_activation_mode_ == QuantizedActivationMode::k8Bit
270   TensorDescriptorProto ToProto(DataType data_type) const;
271 
272   // Accessors.
273   int64 count() const { return tensor_.dimensions(0); }
274   int64 feature_map_count() const { return tensor_.dimensions(1); }
275   int64 height() const { return GetDim(spatial_size(), DimIndex::Y); }
276   int64 width() const { return GetDim(spatial_size(), DimIndex::X); }
277   int64 spatial_dim(DimIndex dim) const { return GetDim(spatial_size(), dim); }
278   int ndims() const { return spatial_size().size(); }
279   float value_max() const { return value_max_; }
280   float value_min() const { return value_min_; }
281   DataLayout layout() const { return tensor_.data_layout(); }
282   QuantizedActivationMode quantized_activation_mode() const {
283     return quantized_activation_mode_;
284   }
285   // Full dimensions of the underlying data, ordered according to a specific
286   // layout.
287   std::vector<int64> full_dims(const DataLayout& layout) const;
288 
289   // Full strides of the underlying data, ordered according to a specific
290   // layout.
291   std::vector<int64> full_strides(const DataLayout& layout) const;
292 
293   // Named-argument helpers for avoiding user error during construction.
294   BatchDescriptor& set_count(int64 value) {
295     tensor_.set_dimensions(0, value);
296     return *this;
297   }
298   BatchDescriptor& set_feature_map_count(int64 value) {
299     tensor_.set_dimensions(1, value);
300     return *this;
301   }
302   BatchDescriptor& set_height(int64 value) {
303     SetDim(spatial_size(), DimIndex::Y, value);
304     return *this;
305   }
306   BatchDescriptor& set_width(int64 value) {
307     SetDim(spatial_size(), DimIndex::X, value);
308     return *this;
309   }
310   BatchDescriptor& set_spatial_dim(DimIndex dim, int64 value) {
311     SetDim(spatial_size(), dim, value);
312     return *this;
313   }
314   BatchDescriptor& set_value_max(float value) {
315     value_max_ = value;
316     return *this;
317   }
318   BatchDescriptor& set_value_min(float value) {
319     value_min_ = value;
320     return *this;
321   }
322   BatchDescriptor& set_layout(DataLayout layout) {
323     tensor_.set_data_layout(layout);
324     return *this;
325   }
326   BatchDescriptor& set_quantized_activation_mode(
327       QuantizedActivationMode quantized_activation_mode) {
328     quantized_activation_mode_ = quantized_activation_mode;
329     return *this;
330   }
331 
332   // Return the number of nodes in a single feature map.
333   int64 NodesPerFeatureMap() const;
334 
335   // Return the number of nodes across all feature maps. Note that this is not
336   // affected by the batch count.
337   int64 NodesAcrossFeatureMaps() const;
338 
339   // Returns the number of elements (e.g. RGB pixel values) required to hold a
340   // given batch descriptor, given a no-padding assumption. Note that this is
341   // affected by the batch count.
342   int64 ElementCount() const;
343 
344   // Return the number of weights required to fully connect a layer with
345   // dimensions given by the 'input' descriptor with a layer with dimensions
346   // given by the 'output' descriptor.
347   static int64 FullyConnectedWeightCount(const BatchDescriptor& input,
348                                          const BatchDescriptor& output);
349 
350   // Return the number of biases required to fully connect to an output layer
351   // with dimensions given by the 'output' descriptor.
352   static int64 FullyConnectedBiasCount(const BatchDescriptor& output);
353 
354   // Return a BatchDescriptor for the output of a depth concatenation
355   // with the given input descriptors. The inputs should have the same
356   // dimensions, except possibly for feature_map_count(), though this
357   // function does not verify that.
358   static BatchDescriptor DepthConcatenateOutputDescriptor(
359       port::ArraySlice<dnn::BatchDescriptor> inputs);
360 
361  private:
362   absl::Span<const int64> spatial_size() const {
363     return AsInt64Slice(tensor_.dimensions()).subspan(2);
364   }
365 
366   absl::Span<int64> spatial_size() {
367     return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
368   }
369 
370   TensorDescriptorProto tensor_;
371   float value_max_;
372   float value_min_;
373   QuantizedActivationMode quantized_activation_mode_;
374 };
375 
376 // Returns a string representation of the given filter layout.
377 string FilterLayoutString(FilterLayout layout);
378 
379 // Describes a filter for the convolution. This is the "window" from
380 // height-by-width patches of each of the feature maps in the input layer to the
381 // cells within the output feature map.
382 //
383 // Uses the named argument construction form:
384 //
385 //  FilterDescriptor filter_dimensions;
386 //  filter_dimensions
387 //    .set_output_feature_map_count(42)
388 //    .set_input_feature_map_count(7)
389 //    ...
390 //
391 // Arguments:
392 // - output_feature_map_count: number of feature maps in the output layer.
393 // - input_feature_map_count: number of feature maps in the input layer (from
394 //      which the filter patch is taken).
395 // - input_filter_height: "height" number of neurons used in the sliding window
396 //      over the input layer.
397 // - input_filter_width: "width" number of neurons used in the sliding window
398 //      over the input layer.
399 //
400 // Sometimes names like "filter input height" are referred to by synonymous
401 // terminology, such as "kernel y size".
402 //
403 // If unspecified, layout defaults to kOutputInputYX.
404 class FilterDescriptor {
405  public:
406   // By default construction, all dimensions are set to zero, so they should all
407   // be populated by the user via the named-argument helpers below. (See class
408   // comment for details.)
409   FilterDescriptor();
410   explicit FilterDescriptor(int ndims);
411   ~FilterDescriptor();
412 
413   // Named-argument helpers for avoiding user error during construction.
414   FilterDescriptor& set_output_feature_map_count(int64 value) {
415     tensor_.set_dimensions(0, value);
416     return *this;
417   }
418   FilterDescriptor& set_input_feature_map_count(int64 value) {
419     tensor_.set_dimensions(1, value);
420     return *this;
421   }
422   FilterDescriptor& set_input_filter_height(int64 value) {
423     SetDim(input_filter_dims(), DimIndex::Y, value);
424     return *this;
425   }
426   FilterDescriptor& set_input_filter_width(int64 value) {
427     SetDim(input_filter_dims(), DimIndex::X, value);
428     return *this;
429   }
430   FilterDescriptor& set_layout(FilterLayout layout) {
431     tensor_.set_filter_layout(layout);
432     return *this;
433   }
434   FilterDescriptor& set_spatial_dim(DimIndex dim, int64 value) {
435     SetDim(input_filter_dims(), dim, value);
436     return *this;
437   }
438   int ndims() const { return input_filter_dims().size(); }
439 
440   void CloneFrom(const FilterDescriptor& other);
441 
442   string ToString() const;
443   string ToShortString() const;
444   TensorDescriptorProto ToProto(DataType data_type) const;
445 
446   // Returns the number of weights required as parameters for a convolution
447   // using this filter descriptor.
448   int64 ComputeWeightCount() const;
449 
450   // Returns the number of biases required as parameters for a convolution
451   // using this filter descriptor.
452   int64 bias_count() const { return output_feature_map_count(); }
453 
454   int64 output_feature_map_count() const { return tensor_.dimensions(0); }
455   int64 input_feature_map_count() const { return tensor_.dimensions(1); }
456   int64 input_filter_height() const {
457     return GetDim(input_filter_dims(), DimIndex::Y);
458   }
459   int64 input_filter_width() const {
460     return GetDim(input_filter_dims(), DimIndex::X);
461   }
462   int64 input_filter_dim(DimIndex dim) const {
463     return GetDim(input_filter_dims(), dim);
464   }
465 
466   FilterLayout layout() const { return tensor_.filter_layout(); }
467 
468   absl::Span<const int64> input_filter_dims() const {
469     return AsInt64Slice(tensor_.dimensions()).subspan(2);
470   }
471 
472  private:
473   absl::Span<int64> input_filter_dims() {
474     return AsInt64Slice(tensor_.mutable_dimensions()).subspan(2);
475   }
476 
477   TensorDescriptorProto tensor_;
478 };
479 
480 // Describes how padding should be aligned when the total number of pad
481 // elements is odd.
482 enum class PadAlignment : int64 {
483   kDefault = 0,        // default padding for the device.
484   kCudnnPadding,       // cuDNN padding - prefer to pad at the start.
485   kTensorFlowPadding,  // TensorFlow padding - prefer to pad at the end.
486 };
487 
488 // Returns a string representation of the given padding alignment.
489 string PadAlignmentString(PadAlignment alignment);
490 
491 // Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
492 std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
493 
494 // Describes a convolution.
495 //
496 // Uses the named argument construction form:
497 //
498 //  ConvolutionDescriptor convolution_dimensions;
499 //  convolution_dimensions
500 //    .set_vertical_filter_stride(2)
501 //    .set_horizontal_filter_stride(2)
502 //    ...
503 //
504 // Arguments:
505 // - zero_padding_height: padding of the "y dimension" of the input data. Note
506 //    that this is different from the height of the filter.
507 // - zero_padding_width: analogous to the height above, but in the "x
508 //    dimension".
509 // - vertical_filter_stride: the convolution slides a 2-dimensional window of
510 //    filter-height-by-filter-width over the input layer -- the center of that
511 //    window is moved in the "y dimension" according to this stride value.
512 // - horizontal_filter_stride: analogous to the vertical stride above, but in
513 //    the "x dimension".
514 // - vertical_dilation_rate: there will be (vertical_dilation_rate - 1) skipped
515 //   cells between each filter element in the "y dimension".
516 // - horizontal_dilation_rate: there will be (horizontal_dilation_rate - 1)
517 //   skipped cells between each filter element in the "x dimension".
518 // - convolution_not_crosscorr: By default (convolution_not_crosscorr == false),
519 //   we perform cross correlation rather than convolution. With the flag set,
520 //   we perform convolution. Convolution and cross correlation are related by
521 //   rotating the filter by 180 degrees (or equivalently flipping all spatial
522 //   dimensions).
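//
// With dilation, the effective filter extent in a dimension is
//   filter_size + (filter_size - 1) * (dilation_rate - 1);
// for example (illustrative), a 3-wide filter with a horizontal dilation rate
// of 2 covers an effective width of 5 input cells.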
523 class ConvolutionDescriptor {
524  public:
525   // By default construction, there is no zero-padding and the filter stride is
526   // 1x1 (centering the filter on every cell in the input layer's
527   // width-by-height area).
528   ConvolutionDescriptor();
529   explicit ConvolutionDescriptor(int ndims);
530   ~ConvolutionDescriptor();
531 
532   string ToString() const;
533   string ToShortString() const;
534   ConvolutionDescriptorProto ToProto() const { return proto_; }
535 
536   ConvolutionDescriptor& set_zero_padding_height(int64 value) {
537     SetDim(padding(), DimIndex::Y, value);
538     return *this;
539   }
540   ConvolutionDescriptor& set_zero_padding_width(int64 value) {
541     SetDim(padding(), DimIndex::X, value);
542     return *this;
543   }
544   ConvolutionDescriptor& set_zero_padding(DimIndex dim, int64 value) {
545     SetDim(padding(), dim, value);
546     return *this;
547   }
548   ConvolutionDescriptor& set_vertical_filter_stride(int64 value) {
549     SetDim(strides(), DimIndex::Y, value);
550     return *this;
551   }
552   ConvolutionDescriptor& set_horizontal_filter_stride(int64 value) {
553     SetDim(strides(), DimIndex::X, value);
554     return *this;
555   }
556   ConvolutionDescriptor& set_filter_stride(DimIndex dim, int64 value) {
557     SetDim(strides(), dim, value);
558     return *this;
559   }
560   ConvolutionDescriptor& set_vertical_dilation_rate(int64 value) {
561     SetDim(dilations(), DimIndex::Y, value);
562     return *this;
563   }
564   ConvolutionDescriptor& set_horizontal_dilation_rate(int64 value) {
565     SetDim(dilations(), DimIndex::X, value);
566     return *this;
567   }
568   ConvolutionDescriptor& set_dilation_rate(DimIndex dim, int64 value) {
569     SetDim(dilations(), dim, value);
570     return *this;
571   }
572   ConvolutionDescriptor& set_group_count(int group_count) {
573     proto_.set_group_count(group_count);
574     return *this;
575   }
576   ConvolutionDescriptor& set_convolution_not_crosscorr(bool conv) {
577     proto_.set_convolution_mode(conv ? ConvolutionMode::CONVOLUTION
578                                      : ConvolutionMode::CROSS_CORRELATION);
579     return *this;
580   }
581   ConvolutionDescriptor& set_name(const string& name) {
582     proto_.set_name(name);
583     return *this;
584   }
585   int64 zero_padding_height() const { return GetDim(padding(), DimIndex::Y); }
586   int64 zero_padding_width() const { return GetDim(padding(), DimIndex::X); }
587   int64 vertical_filter_stride() const {
588     return GetDim(strides(), DimIndex::Y);
589   }
590   int64 horizontal_filter_stride() const {
591     return GetDim(strides(), DimIndex::X);
592   }
593   int64 vertical_dilation_rate() const {
594     return GetDim(dilations(), DimIndex::Y);
595   }
596   int64 horizontal_dilation_rate() const {
597     return GetDim(dilations(), DimIndex::X);
598   }
599 
600   int zero_padding(DimIndex dim) const { return GetDim(padding(), dim); }
601   int filter_stride(DimIndex dim) const { return GetDim(strides(), dim); }
602   int dilation_rate(DimIndex dim) const { return GetDim(dilations(), dim); }
603   // TODO(timshen): remove this function. No user of this class sets a
604   // non-default pad alignment.
605   PadAlignment pad_alignment() const { return PadAlignment::kDefault; }
606   int group_count() const { return proto_.group_count(); }
607   int ndims() const { return padding().size(); }
608   bool convolution_not_crosscorr() const {
609     return proto_.convolution_mode() == ConvolutionMode::CONVOLUTION;
610   }
611 
612   absl::Span<const int64> strides() const {
613     return AsInt64Slice(proto_.strides());
614   }
615 
616   absl::Span<const int64> dilations() const {
617     return AsInt64Slice(proto_.dilations());
618   }
619 
620   absl::Span<const int64> padding() const {
621     return AsInt64Slice(proto_.paddings());
622   }
623 
624   string name() const { return proto_.name(); }
625 
626  private:
627   absl::Span<int64> strides() { return AsInt64Slice(proto_.mutable_strides()); }
628 
629   absl::Span<int64> dilations() {
630     return AsInt64Slice(proto_.mutable_dilations());
631   }
632 
633   absl::Span<int64> padding() {
634     return AsInt64Slice(proto_.mutable_paddings());
635   }
636 
637   ConvolutionDescriptorProto proto_;
638 
639   // TODO(leary) cudnn provides these fields, but need to characterize what
640   // their effect is -- they may be boolean rather than integral.
641   // int64 upscale_input_x;
642   // int64 upscale_input_y;
643 };
644 
645 // A patch of values in the input can be pooled via either a max or an average
646 // operation.
647 // Specify int64 so there's no padding in PoolingDescriptor.
648 enum class PoolingMode : int64 {
649   kMaximum,
650   kAverage,
651 };
652 
653 // Specify the dimension in which to concatenate inputs in space.
654 // Specify int64 so there's no padding in SpaceConcatenateMode.
655 enum class SpaceConcatenateMode : int64 {
656   XDirection,
657   YDirection,
658 };
659 
660 // Returns a short name for the pooling mode, e.g. "Avg".
661 string ShortPoolingModeString(PoolingMode mode);
662 
663 // Describes a pooling operation to be enqueued onto a stream via a platform's
664 // DnnSupport.
665 //
666 // TODO(broune): describe how padding works and what happens if the
667 // window height/width is not divisible by the vertical/horizontal
668 // stride.
669 //
670 // Arguments:
671 //  pooling_mode: pooling operator to use on the input patch
672 //  window_height: height of input window
673 //  window_width: width of input window
674 //  vertical_stride: vertical delta for center of the input patch
675 //  horizontal_stride: horizontal delta for center of the input patch
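//
// Uses the named argument construction form, for example (illustrative):
//
//  PoolingDescriptor pooling_dimensions;
//  pooling_dimensions
//    .set_pooling_mode(PoolingMode::kMaximum)
//    .set_window_height(2)
//    .set_window_width(2)
//    .set_vertical_stride(2)
//    .set_horizontal_stride(2);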
676 class PoolingDescriptor {
677  public:
678   PoolingDescriptor();
679   explicit PoolingDescriptor(int ndims);
680 
681   PoolingDescriptor& set_pooling_mode(PoolingMode value) {
682     mode_ = value;
683     return *this;
684   }
685   PoolingDescriptor& set_window_height(int64 value) {
686     SetDim(&window_, DimIndex::Y, value);
687     return *this;
688   }
689   PoolingDescriptor& set_window_width(int64 value) {
690     SetDim(&window_, DimIndex::X, value);
691     return *this;
692   }
693   PoolingDescriptor& set_window(DimIndex dim, int64 value) {
694     SetDim(&window_, dim, value);
695     return *this;
696   }
697   PoolingDescriptor& set_vertical_padding(int64 value) {
698     SetDim(&padding_, DimIndex::Y, value);
699     return *this;
700   }
701   PoolingDescriptor& set_horizontal_padding(int64 value) {
702     SetDim(&padding_, DimIndex::X, value);
703     return *this;
704   }
705   PoolingDescriptor& set_padding(DimIndex dim, int64 value) {
706     SetDim(&padding_, dim, value);
707     return *this;
708   }
709   PoolingDescriptor& set_vertical_stride(int64 value) {
710     SetDim(&strides_, DimIndex::Y, value);
711     return *this;
712   }
713   PoolingDescriptor& set_horizontal_stride(int64 value) {
714     SetDim(&strides_, DimIndex::X, value);
715     return *this;
716   }
717   PoolingDescriptor& set_stride(DimIndex dim, int64 value) {
718     SetDim(&strides_, dim, value);
719     return *this;
720   }
721   PoolingDescriptor& set_propagate_nans(bool value) {
722     propagate_nans_ = value;
723     return *this;
724   }
725   PoolingDescriptor& set_name(const string& name) {
726     name_ = name;
727     return *this;
728   }
729 
730   int ndims() const { return ndims_; }
731   void CloneFrom(const PoolingDescriptor& other);
732 
733   string ToString() const;
734   string ToShortString() const;
735 
736   PoolingMode mode() const { return mode_; }
737   int64 window_height() const { return GetDim(window_, DimIndex::Y); }
738   int64 window_width() const { return GetDim(window_, DimIndex::X); }
739   int64 window(DimIndex dim) const { return GetDim(window_, dim); }
740   int64 vertical_padding() const { return GetDim(padding_, DimIndex::Y); }
741   int64 horizontal_padding() const { return GetDim(padding_, DimIndex::X); }
742   int64 padding(DimIndex dim) const { return GetDim(padding_, dim); }
743   int64 vertical_stride() const { return GetDim(strides_, DimIndex::Y); }
744   int64 horizontal_stride() const { return GetDim(strides_, DimIndex::X); }
745   int64 stride(DimIndex dim) const { return GetDim(strides_, dim); }
746   absl::Span<const int64> window() const { return window_; }
747   absl::Span<const int64> padding() const { return padding_; }
748   absl::Span<const int64> strides() const { return strides_; }
749   bool propagate_nans() const { return propagate_nans_; }
750   string name() const { return name_; }
751 
752  private:
753   PoolingMode mode_;
754   int ndims_;
755   bool propagate_nans_;
756   string name_;  // Name as in Tensorflow NodeDef, for debugging purposes.
757 
758   // Stored as: ..., y, x.
759   std::vector<int64> window_;
760   std::vector<int64> padding_;
761   std::vector<int64> strides_;
762 };
763 
764 // Collects parameters for DNN algorithms
765 class AlgorithmDesc {
766  public:
767   typedef int64 Index;
768   AlgorithmDesc() : AlgorithmDesc(0, false) {}
769   AlgorithmDesc(Index a, bool use_tensor_ops) {
770     proto_.set_algo_id(a);
771     proto_.set_math_type(use_tensor_ops ? AlgorithmProto::TENSOR_OP_MATH
772                                         : AlgorithmProto::DEFAULT_MATH);
773   }
774   bool tensor_ops_enabled() const {
775     return proto_.math_type() == AlgorithmProto::TENSOR_OP_MATH;
776   }
777   Index algo_id() const { return proto_.algo_id(); }
778   bool operator==(const AlgorithmDesc& other) const {
779     return algo_id() == other.algo_id() &&
780            tensor_ops_enabled() == other.tensor_ops_enabled();
781   }
782   uint64 hash() const;
783 
784   AlgorithmProto ToProto() const { return proto_; }
785 
786   string ToString() const;
787 
788  private:
789   AlgorithmProto proto_;
790 };
791 
792 // Describes the result from a perf experiment.
793 //
794 // Arguments:
795 //  algorithm: returns the exact algorithm that was used.
796 //  elapsed_time_in_ms: returns the measured elapsed time in milliseconds.
797 class ProfileResult {
798  public:
799   bool is_valid() const {
800     return algorithm_.has_value() &&
801            elapsed_time_in_ms() != std::numeric_limits<float>::max();
802   }
803 
804   AlgorithmDesc algorithm() const { return *algorithm_; }
805   void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
806 
807   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
808   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
809 
810   size_t scratch_size() const { return scratch_size_; }
811   void set_scratch_size(size_t val) { scratch_size_ = val; }
812 
813  private:
814   absl::optional<AlgorithmDesc> algorithm_;
815   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
816   // The scratch size algorithm_ requires. Currently it's only populated by
817   // convolutions.
818   size_t scratch_size_ = 0;
819 };
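
// Illustrative sketch of how a ProfileResult is typically consumed when
// autotuning (the candidate list and the way each run is timed are
// hypothetical):
//
//   ProfileResult best;
//   for (const AlgorithmDesc& algo : candidate_algorithms) {
//     ProfileResult result;
//     // ... run the operation with 'algo', filling in 'result' ...
//     if (result.is_valid() &&
//         (!best.is_valid() ||
//          result.elapsed_time_in_ms() < best.elapsed_time_in_ms())) {
//       best = result;
//     }
//   }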
820 
821 // Describes the configuration for the algorithms that will be used.
822 //
823 // Arguments:
824 //  algorithm: the primary algorithm that should be used.
825 //  algorithm_no_scratch: a secondary algorithm that should be used if the
826 //    allocation for the scratch memory fails.
827 //  scratch_size: specifies the size of scratch memory in bytes needed for
828 //    the algorithm used.
829 //
830 // On the CUDA platform with the cuDNN library, algorithm and
831 // algorithm_no_scratch are used. On the ROCm platform with the MIOpen
832 // library, algorithm and scratch_size are used. The major difference between
833 // the two platforms is whether it's possible to get an algorithm without
834 // scratch memory. On CUDA + cuDNN that is possible, and algorithm_no_scratch
835 // tracks such an algorithm, whereas on ROCm + MIOpen there is no guarantee of
836 // getting one without scratch memory, so the scratch_size field tracks it.
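//
// For example (illustrative; algo_id, fallback_id and scratch_bytes are
// hypothetical values), on CUDA + cuDNN one might carry a primary and a
// no-scratch fallback algorithm:
//
//   AlgorithmConfig config(AlgorithmDesc(algo_id, /*use_tensor_ops=*/true),
//                          AlgorithmDesc(fallback_id, /*use_tensor_ops=*/false));
//
// whereas on ROCm + MIOpen one would carry the algorithm together with its
// scratch size:
//
//   AlgorithmConfig config(AlgorithmDesc(algo_id, /*use_tensor_ops=*/false),
//                          scratch_bytes);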
837 class AlgorithmConfig {
838  public:
839   AlgorithmConfig() {}
840   explicit AlgorithmConfig(AlgorithmDesc algorithm) : algorithm_(algorithm) {}
841   AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size)
842       : algorithm_(algorithm), scratch_size_(scratch_size) {}
843   AlgorithmConfig(AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch)
844       : algorithm_(algorithm), algorithm_no_scratch_(algorithm_no_scratch) {}
845   absl::optional<AlgorithmDesc> algorithm() const { return algorithm_; }
846   void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
847   absl::optional<AlgorithmDesc> algorithm_no_scratch() const {
848     return algorithm_no_scratch_;
849   }
850   void set_algorithm_no_scratch(AlgorithmDesc val) {
851     algorithm_no_scratch_ = val;
852   }
853   absl::optional<size_t> scratch_size() const { return scratch_size_; }
854   void set_scratch_size(size_t val) { scratch_size_ = val; }
855   bool operator==(const AlgorithmConfig& other) const {
856     return this->algorithm_ == other.algorithm_ &&
857            this->algorithm_no_scratch_ == other.algorithm_no_scratch_ &&
858            this->scratch_size_ == other.scratch_size_;
859   }
860   bool operator!=(const AlgorithmConfig& other) const {
861     return !(*this == other);
862   }
863   string ToString() const;
864 
865  private:
866   absl::optional<AlgorithmDesc> algorithm_;
867   absl::optional<AlgorithmDesc> algorithm_no_scratch_;
868   absl::optional<size_t> scratch_size_;
869 };
870 
871 // Describes a local response normalization (LRN). LRN is used e.g. in
872 // dist_belief.
873 //
874 // Let V be the vector of feature maps at some (batch, y, x)
875 // coordinate. LRN applies independently to each vector V in the
876 // input, across all coordinates (batch, y, x), by mapping each V to
877 // another vector U of the same size using the formula
878 //
879 //   U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
880 //
881 // where the sum is taken over j in the closed range [i - range, i + range].
882 //
883 // When calculating U_i the j in the sum can extend beyond the bounds
884 // of V. If wrap_around is true, then V_j = V_{j mod F} where F is the
885 // size of V, which is the number of feature maps. If wrap_around is
886 // false, then V_j = 0 for j outside [0, F-1].
887 //
888 // If segment_size <= F, where F is the number of feature_maps, then
889 // segment_size has no effect. Otherwise, each consecutive segment of
890 // segment_size entries in V is normalized separately.
891 //
892 // Not all StreamExecutors allow wrap_around == true or segment_size
893 // != 64. Some do not implement normalization at all.
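//
// Uses the named argument construction form, for example (illustrative, with
// arbitrarily chosen parameter values):
//
//  NormalizeDescriptor normalize_descriptor;
//  normalize_descriptor
//    .set_bias(1.0f)
//    .set_range(2)
//    .set_alpha(1e-4f)
//    .set_beta(0.75f);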
894 class NormalizeDescriptor {
895  public:
896   NormalizeDescriptor();
897 
898   NormalizeDescriptor& set_bias(float bias) {
899     bias_ = bias;
900     return *this;
901   }
902 
903   NormalizeDescriptor& set_range(int32 range) {
904     range_ = range;
905     return *this;
906   }
907 
908   NormalizeDescriptor& set_alpha(float alpha) {
909     alpha_ = alpha;
910     return *this;
911   }
912 
913   NormalizeDescriptor& set_beta(float beta) {
914     beta_ = beta;
915     return *this;
916   }
917 
918   NormalizeDescriptor& set_wrap_around(bool wrap_around) {
919     wrap_around_ = wrap_around;
920     return *this;
921   }
922 
923   NormalizeDescriptor& set_segment_size(int32 segment_size) {
924     segment_size_ = segment_size;
925     return *this;
926   }
927 
928   void CloneFrom(const NormalizeDescriptor& other);
929 
930   string ToString() const;
931   string ToShortString() const;
932 
933   float bias() const { return bias_; }
934   int32 range() const { return range_; }
935   float alpha() const { return alpha_; }
936   float beta() const { return beta_; }
937   bool wrap_around() const { return wrap_around_; }
938   int32 segment_size() const { return segment_size_; }
939 
940  private:
941   float bias_;
942   int32 range_;
943   float alpha_;
944   float beta_;
945   bool wrap_around_;
946   int32 segment_size_;
947 };
948 
949 // Returns a string representation of the given activation mode.
950 string ActivationModeString(ActivationMode mode);
951 
952 // Describes the operation that DoElementwiseOperation should perform on its
953 // inputs.
954 enum class ElementwiseOperation { kAdd, kMultiply };
955 
956 string ElementwiseOperationString(ElementwiseOperation op);
957 
958 // A simple class representing the version of the backing library, to
959 // work around the "too perfect forwarding" issue in gcc6+ compilers.
960 // See PR#16309 and issue #18402 for links discussing the issue.
961 class VersionInfo {
962  public:
963   VersionInfo(int major = 0, int minor = 0, int patch = 0)
964       : major_(major), minor_(minor), patch_(patch) {}
965   int major_version() const { return major_; }
966   int minor_version() const { return minor_; }
967   int patch() const { return patch_; }
968 
969  private:
970   int major_;
971   int minor_;
972   int patch_;
973 };
974 
975 // Suite of operations typically used for implementing Deep/Convolutional Neural
976 // Nets. Note: A false return value of an operation indicates the
977 // implementation is not available.
978 //
979 // TODO(b/118763918): this class (or rather dispatch table) has several
980 // problems:
981 // * Some overloads are missing. Ideally we want to have template virtual
982 //   functions while the template arguments form a closed set. However, we don't
983 //   get that from the language.
984 // * The API is a union of cuDNN and another private backend. Only 10% of the
985 //   functions are actually implemented by both backends; the rest are
986 //   backend-specific. The massive interface creates extra mental
987 //   burden.
988 // * Poor error handling: the API should return Status objects.
989 //
990 // PrepareForConvolution is an example for how new APIs should be written.
991 class DnnSupport {
992  public:
993   DnnSupport() {}
994   virtual ~DnnSupport() {}
995 
996   virtual port::Status Init() = 0;
997 
998   // Gets the version of the backing library, as a VersionInfo object.
999   virtual port::StatusOr<VersionInfo> GetVersion() {
1000     return port::UnimplementedError(
1001         "DnnSupport::GetVersion not implemented on this platform.");
1002   }
1003 
1004   // Performs a single-precision forward batch normalization operation onto
1005   // the stream.
1006   //
1007   // Arguments:
1008   //  stream: borrowed pointer to the stream that the batch normalization
1009   //    operation should be enqueued onto.
1010   //  x: input data.
1011   //  scale: scaling parameters.
1012   //  offset: offset parameters.
1013   //  estimated_mean: population mean estimated during training.
1014   //    Used for inference only; empty for training.
1015   //  estimated_variance: population variance estimated during training,
1016   //    used for inference only; empty for training.
1017   //  side_input: optional input that is element-wise added to the output of
1018   //    batch normalization.
1019   //  x_desc: dimensions of the input data, which is the same as the dimensions
1020   //    of the output and side input.
1021   //  scale_offset_desc: dimensions of scale and offset.
1022   //  epsilon: a small floating point number added to the variance of x.
1023   //  activation_mode: activation applied to the result of batch normalization
1024   //    (or after adding optional side input)
1025   //  y: output data.
1026   //  batch_mean: batch mean, to be used to compute the running mean.
1027   //  batch_variance: batch variance, to be used to compute
1028   //    the running variance.
1029   //  reserve_space_1: saved mean, to be reused in the backward gradient
1030   //    computation.
1031   //  reserve_space_2: saved inv_var (1/sqrt(epsilon + variance)), to be reused
1032   //    in the backward gradient computation.
1033   //  is_training: Set to true for training, false for inference.
1034   //  var_to_inv_var: a function to convert the variance to inverted variance
1035   //    for cuDNN v4 forward inference.
1036   //  inv_var_to_var: a function to convert the inverted variance to
1037   //    variance for cuDNN v4 forward training, to be used for TensorFlow
1038   //    to calculate the running variance.
1039   virtual bool DoBatchNormalizationForward(
1040       Stream* stream, const DeviceMemory<float>& x,
1041       const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1042       const DeviceMemory<float>& estimated_mean,
1043       const DeviceMemory<float>& estimated_variance,
1044       const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
1045       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1046       const double exponential_average_factor,
1047       dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
1048       DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1049       DeviceMemory<float>* reserve_space_1,
1050       DeviceMemory<float>* reserve_space_2, bool is_training,
1051       ScratchAllocator* reserve_space_allocator,
1052       ScratchAllocator* workspace_allocator,
1053       std::function<const DeviceMemory<float>&()> var_to_inv_var,
1054       std::function<void()> inv_var_to_var) {
1055     return false;
1056   }
1057 
1058   // Performs a half-precision forward batch normalization operation onto the
1059   // stream. See DoBatchNormalizationForward above for argument details.
1060   virtual bool DoBatchNormalizationForward(
1061       Stream* stream, const DeviceMemory<Eigen::half>& x,
1062       const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
1063       const DeviceMemory<float>& estimated_mean,
1064       const DeviceMemory<float>& estimated_variance,
1065       const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
1066       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1067       const double exponential_average_factor,
1068       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
1069       DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
1070       DeviceMemory<float>* reserve_space_1,
1071       DeviceMemory<float>* reserve_space_2, bool is_training,
1072       ScratchAllocator* reserve_space_allocator,
1073       ScratchAllocator* workspace_allocator,
1074       std::function<const DeviceMemory<float>&()> var_to_inv_var,
1075       std::function<void()> inv_var_to_var) {
1076     return false;
1077   }
1078 
1079   // Performs a single-precision backward batch normalization gradient
1080   // computation operation onto the stream.
1081   //
1082   // Arguments:
1083   //  stream: borrowed pointer to the stream that the batch normalization
1084   //    gradient computation operation should be enqueued onto.
1085   //  y_backprop: gradient with regard to output y.
1086   //  x: input data.
1087   //  scale: scaling parameters.
1088   //  inv_var: 1/sqrt(epsilon + variance) of x.
1089   //  x_desc: dimensions of the input data, which is the same as the dimensions
1090   //    of the output.
1091   //  scale_offset_desc: dimensions of scale and offset.
1092   //  epsilon: a small floating point number added to the variance of x.
1093   //  x_backprop: gradient with respect to input x.
1094   //  scale_backprop: gradient with respect to scale.
1095   //  offset_backprop: gradient with respect to offset.
1096   virtual bool DoBatchNormalizationBackward(
1097       Stream* stream, const DeviceMemory<float>& y_backprop,
1098       const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
1099       const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
1100       const dnn::BatchDescriptor& x_desc,
1101       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1102       DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
1103       DeviceMemory<float>* offset_backprop,
1104       DeviceMemory<uint8>* reserve_space_data,
1105       ScratchAllocator* workspace_allocator) {
1106     return false;
1107   }
1108 
1109   // Performs a half-precision backward batch normalization gradient computation
1110   // operation onto the stream. See DoBatchNormalizationBackward above for
1111   // argument details.
1112   virtual bool DoBatchNormalizationBackward(
1113       Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
1114       const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
1115       const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
1116       const dnn::BatchDescriptor& x_desc,
1117       const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
1118       DeviceMemory<Eigen::half>* x_backprop,
1119       DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
1120       DeviceMemory<uint8>* reserve_space_data,
1121       ScratchAllocator* workspace_allocator) {
1122     return false;
1123   }
1124 
1125   // Enqueues a fused convolution operation onto the stream.
1126   // We provide several variants with different types for inputs, biases and
1127   // scaling parameters.
1128   //
1129   // Arguments (all borrowed):
1130   //  stream: borrowed pointer to the stream that the 'convolve' operation
1131   //    should be enqueued onto.
1132   //  conv_input_descriptor: dimensions of the convolution input layer.
1133   //  conv_input_data: un-owned device memory region which contains the
1134   //    convolution input.
1135   //  conv_input_scale: a floating point scale to multiply with each element
1136   //    of conv_input_data.
1137   //  filter_descriptor: dimensions of the convolution filter.
1138   //  filter_data: un-owned device memory region which contains the
1139   //    convolution filter weights.
1140   //  convolution_descriptor: stride of the convolution filter.
1141   //  biases: un-owned device memory region containing biases to add to the
1142   //    input.
1143   //  activation_mode: Type of activation to perform.
1144   //  side_input_data: un-owned device memory region which contains optional
1145   //    side input data. If 'side_input_scale' is non-zero, then this must
1146   //    point to data in the tensor shape specified by output_shape.
1147   //    It will be scaled by 'side_input_scale' and added to the convolution
1148   //    result and bias prior to applying the activation function.
1149   //  side_input_scale: a floating point scale to multiply with each element
1150   //    of side_input_data.
1151   //  output_descriptor: dimensions of the output layer.
1152   //  output_data: un-owned device memory region in which to place the
1153   //    convolution result.
1154   //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
1155   //    space in order to speed up the convolution operation.
1156   //  algorithm_config: specifies which algorithm should be used for the
1157   //    operation.
1158   //  output_profile_result: the output profile result for this call. The
1159   //    profiling is only enabled when this is not nullptr.
1160   //
1161   // conv_input_descriptor, filter_descriptor, convolution_descriptor and
1162   // output_descriptor together specify exactly how the convolution is aligned
1163   // with the input data:
1164   //
1165   // * (input dimensions - filter size + 1) / filter stride == output dimensions
1166   //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1167   // * input dimensions / filter stride == output dimensions
1168   //   corresponds to dist_belief padding = SAME, i.e. input and output are the
1169   //   same size - this requires padding the input.
1170   // * (input dimensions + filter size - 1) / filter stride == output dimensions
1171   //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
1172   //   that if the inverse of the filter is applied to the output in VALID mode
1173   //   the result is the same size as the input - this requires even more
1174   //   padding of the input.
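  //
  // As a concrete (illustrative) 1-D example of the relationships above, with
  // input dimension 10, filter size 3 and filter stride 1:
  //
  // * VALID: output dimension (10 - 3 + 1) / 1 = 8; the input is not padded.
  // * SAME:  output dimension 10 / 1 = 10; the input is padded by a total of
  //   2 (one zero on each side for this filter size).
  // * FULL:  output dimension (10 + 3 - 1) / 1 = 12; the input is padded by a
  //   total of 4 (two zeros on each side).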
1175   virtual bool DoFusedConvolve(
1176       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
1177       const DeviceMemory<double>& conv_input_data, double conv_input_scale,
1178       const dnn::FilterDescriptor& filter_descriptor,
1179       const DeviceMemory<double>& filter_data,
1180       const dnn::ConvolutionDescriptor& convolution_descriptor,
1181       const DeviceMemory<double>& side_input_data, double side_input_scale,
1182       const dnn::BatchDescriptor& bias_descriptor,
1183       const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
1184       const dnn::BatchDescriptor& output_descriptor,
1185       DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
1186       const dnn::AlgorithmConfig& algorithm_config,
1187       dnn::ProfileResult* output_profile_result) {
1188     return false;
1189   }
1190 
1191   // This is the float version of DoFusedConvolve.
1192   virtual bool DoFusedConvolve(
1193       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
1194       const DeviceMemory<float>& conv_input_data, float conv_input_scale,
1195       const dnn::FilterDescriptor& filter_descriptor,
1196       const DeviceMemory<float>& filter_data,
1197       const dnn::ConvolutionDescriptor& convolution_descriptor,
1198       const DeviceMemory<float>& side_input_data, float side_input_scale,
1199       const dnn::BatchDescriptor& bias_descriptor,
1200       const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
1201       const dnn::BatchDescriptor& output_descriptor,
1202       DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
1203       const dnn::AlgorithmConfig& algorithm_config,
1204       dnn::ProfileResult* output_profile_result) {
1205     return false;
1206   }
1207 
1208   // This is the Eigen::half version of DoFusedConvolve.
1209   // The scaling parameters are still floats.
1210   virtual bool DoFusedConvolve(
1211       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
1212       const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
1213       const dnn::FilterDescriptor& filter_descriptor,
1214       const DeviceMemory<Eigen::half>& filter_data,
1215       const dnn::ConvolutionDescriptor& convolution_descriptor,
1216       const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
1217       const dnn::BatchDescriptor& bias_descriptor,
1218       const DeviceMemory<Eigen::half>& biases,
1219       dnn::ActivationMode activation_mode,
1220       const dnn::BatchDescriptor& output_descriptor,
1221       DeviceMemory<Eigen::half>* output_data,
1222       ScratchAllocator* scratch_allocator,
1223       const dnn::AlgorithmConfig& algorithm_config,
1224       dnn::ProfileResult* output_profile_result) {
1225     return false;
1226   }
1227 
1228   // This is the int8 version of DoFusedConvolve.
1229   // The bias input and scaling parameters are floats.
1230   virtual bool DoFusedConvolve(
1231       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
1232       const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
1233       const dnn::FilterDescriptor& filter_descriptor,
1234       const DeviceMemory<int8>& filter_data,
1235       const dnn::ConvolutionDescriptor& convolution_descriptor,
1236       const DeviceMemory<int8>& side_input_data, float side_input_scale,
1237       const dnn::BatchDescriptor& bias_descriptor,
1238       const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
1239       const dnn::BatchDescriptor& output_descriptor,
1240       DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
1241       const dnn::AlgorithmConfig& algorithm_config,
1242       dnn::ProfileResult* output_profile_result) {
1243     return false;
1244   }
1245 
1246   // This is the int8 version of DoFusedConvolve.
1247   // The output, bias input and scaling parameters are floats.
1248   virtual bool DoFusedConvolve(
1249       Stream* /*stream*/, const dnn::BatchDescriptor& /*conv_input_descriptor*/,
1250       const DeviceMemory<int8>& /*conv_input_data*/, float /*conv_input_scale*/,
1251       const dnn::FilterDescriptor& /*filter_descriptor*/,
1252       const DeviceMemory<int8>& /*filter_data*/,
1253       const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
1254       const DeviceMemory<float>& /*side_input_data*/,
1255       float /*side_input_scale*/,
1256       const dnn::BatchDescriptor& /*bias_descriptor*/,
1257       const DeviceMemory<float>& /*biases*/,
1258       dnn::ActivationMode /*activation_mode*/,
1259       const dnn::BatchDescriptor& /*output_descriptor*/,
1260       DeviceMemory<float>* /*output_data*/,
1261       ScratchAllocator* /*scratch_allocator*/,
1262       const dnn::AlgorithmConfig& /*algorithm_config*/,
1263       dnn::ProfileResult* /*output_profile_result*/) {
1264     return false;
1265   }
1266 
1267   template <typename ElementType, typename OutputType>
1268   port::Status PrepareForConvolution(
1269       ConvolutionKind kind, Stream* stream,
1270       const BatchDescriptor& batch_descriptor,
1271       DeviceMemory<ElementType> input_data,
1272       const FilterDescriptor& filter_descriptor,
1273       DeviceMemory<ElementType> filter_data,
1274       const BatchDescriptor& output_descriptor,
1275       DeviceMemory<OutputType> output_data,
1276       const ConvolutionDescriptor& convolution_descriptor,
1277       const AlgorithmConfig& algorithm_config,
1278       ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
1279       DeviceMemory<uint8>* scratch_memory) {
1280     return DoPrepareForConvolution(
1281         kind, ToDataType<ElementType>::value, stream, batch_descriptor,
1282         input_data, filter_descriptor, filter_data, output_descriptor,
1283         output_data, convolution_descriptor, algorithm_config,
1284         scratch_allocator, algorithm_desc, scratch_memory);
1285   }
1286 
1287   // Enqueues a single-precision convolution operation onto the stream.
1288   //
1289   // Arguments (all borrowed):
1290   //  stream: borrowed pointer to the stream that the 'convolve' operation
1291   //    should be enqueued onto.
1292   //  input_descriptor: dimensions of the input layer.
1293   //  input_data: un-owned device memory region which contains the
1294   //    convolution input.
1295   //  filter_descriptor: dimensions of the convolution filter.
1296   //  convolution_descriptor: stride of the convolution filter.
1297   //  output_descriptor: dimensions of the output layer.
1298   //  output_data: un-owned device memory region in which to place the
1299   //    convolution result.
1300   //  algorithm_desc: specifies which algorithm should be used for the
1301   //    operation.
1302   //  scratch: un-owned device memory for scratch space in order to speed up
1303   //    the convolution operation.
1304   //  output_profile_result: the output profile result for this call. The
1305   //    profiling is only enabled when this is not nullptr.
1306   //
1307   // input_descriptor, filter_descriptor, convolution_descriptor and
1308   // output_descriptor together specify exactly how the convolution is aligned
1309   // with the input data:
1310   //
1311   // * (input dimensions - filter size + 1) / filter stride == output dimensions
1312   //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
1313   // * input dimensions / filter stride == output dimensions
1314   //   corresponds to dist_belief padding = SAME, i.e. input and output are the
1315   //   same size - this requires padding the input.
1316   // * (input dimensions + filter size - 1) / filter stride == output dimensions
1317   //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
1318   //   that if the inverse of the filter is applied to the output in VALID mode
1319   //   the result is the same size as the input - this requires even more
1320   //   padding of the input.
1321   virtual port::Status DoConvolve(
1322       ConvolutionKind kind, DataType element_type, DataType output_type,
1323       Stream* stream, const BatchDescriptor& input_descriptor,
1324       DeviceMemoryBase input_data, const FilterDescriptor& filter_descriptor,
1325       DeviceMemoryBase filter_data, const BatchDescriptor& output_descriptor,
1326       DeviceMemoryBase output_data,
1327       const ConvolutionDescriptor& convolution_descriptor,
1328       AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
1329       ProfileResult* output_profile_result) = 0;
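  // As a concrete check of the padding relationships above (illustrative
  // numbers only), take one spatial dimension with input size 5, filter size 3
  // and stride 1:
  //   VALID: (5 - 3 + 1) / 1 = 3 output elements (no input padding).
  //   SAME:   5 / 1         = 5 output elements (input padded by 1 on each side).
  //   FULL:  (5 + 3 - 1) / 1 = 7 output elements (input padded by 2 on each side).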
1330 
1331   template <typename ElementType, typename OutputType>
1332   bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1333                   const DeviceMemory<ElementType>& input_data,
1334                   const dnn::FilterDescriptor& filter_descriptor,
1335                   const DeviceMemory<ElementType>& filter_data,
1336                   const dnn::ConvolutionDescriptor& convolution_descriptor,
1337                   const dnn::BatchDescriptor& output_descriptor,
1338                   DeviceMemory<OutputType>* output_data,
1339                   const dnn::AlgorithmDesc& algorithm_desc,
1340                   DeviceMemory<uint8>* scratch_memory,
1341                   ProfileResult* output_profile_result) {
1342     return IsStatusOk(
1343         DoConvolve(ConvolutionKind::FORWARD, ToDataType<ElementType>::value,
1344                    ToDataType<OutputType>::value, stream, input_descriptor,
1345                    input_data, filter_descriptor, filter_data,
1346                    output_descriptor, *output_data, convolution_descriptor,
1347                    algorithm_desc, *scratch_memory, output_profile_result),
1348         !output_profile_result);
1349   }
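  // A minimal usage sketch of the typed convolution path (illustrative only):
  // it assumes `dnn` points at an instance of this class and that the stream,
  // descriptors, device buffers, algorithm_config and scratch_allocator shown
  // here have already been set up by the caller.
  //
  //   dnn::AlgorithmDesc algorithm_desc;
  //   DeviceMemory<uint8> scratch;
  //   port::Status prep = dnn->PrepareForConvolution(
  //       dnn::ConvolutionKind::FORWARD, stream, input_descriptor, input_data,
  //       filter_descriptor, filter_data, output_descriptor, output_data,
  //       convolution_descriptor, algorithm_config, scratch_allocator,
  //       &algorithm_desc, &scratch);
  //   if (prep.ok()) {
  //     dnn->DoConvolve(stream, input_descriptor, input_data, filter_descriptor,
  //                     filter_data, convolution_descriptor, output_descriptor,
  //                     &output_data, algorithm_desc, &scratch,
  //                     /*output_profile_result=*/nullptr);
  //   }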
1350 
1351   // Return a list of algorithms supported by the forward convolution pass.
1352   // cc_major and cc_minor are the compute capabilities of the device.
1353   virtual bool GetConvolveAlgorithms(
1354       bool with_winograd_nonfused, int cc_major, int cc_minor,
1355       std::vector<AlgorithmDesc>* out_algorithms);
1356 
1357   virtual bool GetMIOpenConvolveAlgorithms(
1358       dnn::ConvolutionKind kind, Stream* stream, dnn::DataType element_type,
1359       const dnn::BatchDescriptor& input_descriptor,
1360       const dnn::FilterDescriptor& filter_descriptor,
1361       const dnn::ConvolutionDescriptor& convolution_descriptor,
1362       const dnn::BatchDescriptor& output_descriptor,
1363       std::vector<ProfileResult>* out_algorithms);
1364 
1365   // Returns a list of supported rnn algorithms.
1366   virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);
1367 
1368   // Version of DoConvolve that uses pre-quantized 8 bit coefficients.
1369   // coefficient_scales specifies the scaling of each column of coefficients:
1370   // original float coefficient[row * num_columns + column] =
1371   //     quantized coefficient[row * num_columns + column] *
1372   //     coefficient_scales[column].
1373   virtual bool DoConvolveQuantized(
1374       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1375       const DeviceMemory<float>& input_data,
1376       const dnn::FilterDescriptor& filter_descriptor,
1377       const DeviceMemory<int8>& filter_coefficients,
1378       const DeviceMemory<float>& coefficient_scales,
1379       const dnn::ConvolutionDescriptor& convolution_descriptor,
1380       const dnn::BatchDescriptor& output_descriptor,
1381       DeviceMemory<float>* output_data) = 0;
1382 
1383   // Same as DoConvolveQuantized above, but with int16 filter coefficients.
1384   virtual bool DoConvolveQuantized(
1385       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
1386       const DeviceMemory<float>& input_data,
1387       const dnn::FilterDescriptor& filter_descriptor,
1388       const DeviceMemory<int16>& filter_coefficients,
1389       const DeviceMemory<float>& coefficient_scales,
1390       const dnn::ConvolutionDescriptor& convolution_descriptor,
1391       const dnn::BatchDescriptor& output_descriptor,
1392       DeviceMemory<float>* output_data) = 0;
1393 
1394   // Variation of the above with the weight matrix split into two matrices.
1395   // first_weights: Coefficients of the first matrix.
1396   // second_weights: Coefficients of the second matrix.
1397   // depth_multiplier: specifies the columns of the first matrix and rows
1398   // of the second one - first_weights columns = depth_multiplier,
1399   // second_weights rows = depth_multiplier *
1400   //                       filter_descriptor.input_feature_map_count().
1401   // see go/separable for documentation on separable convolutions.
1402   virtual bool DoSeparableConvolve(
1403       Stream* stream, const BatchDescriptor& input_descriptor,
1404       const DeviceMemory<float>& input_data,
1405       const FilterDescriptor& filter_descriptor, int depth_multiplier,
1406       const DeviceMemory<float>& first_weights,
1407       const DeviceMemory<float>& second_weights,
1408       const ConvolutionDescriptor& convolution_descriptor,
1409       const BatchDescriptor& output_descriptor,
1410       DeviceMemory<float>* output_data) = 0;
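  // Illustrative sizing example for the split weight matrices above: with
  // depth_multiplier = 4 and filter_descriptor.input_feature_map_count() = 16,
  // first_weights has 4 columns and second_weights has 4 * 16 = 64 rows.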
1411 
1412   // Enqueues a single-precision backward convolution (for data) operation onto
1413   // the stream.
1414   //
1415   // Arguments:
1416   //  stream: borrowed pointer to the stream that the 'convolve' operation
1417   //    should be enqueued onto.
1418   //  filter_descriptor: dimensions of the convolution filter.
1419   //  filter_data: coefficients for the convolution filter.
1420   //  output_descriptor: dimensions of the output gradients, which are the same
1421   //    as the dimensions of the output.
1422   //  backward_output_data: un-owned device memory region which contains the
1423   //    backprop of the output.
1424   //  convolution_descriptor: stride and padding of the convolution.
1425   //  input_descriptor: dimensions of the input layer.
1426   //  backward_input_data: un-owned device memory region in which to place the
1427   //    backprop of the input.
1428   //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
1429   //    space in order to speed up the convolution operation.
1430   template <typename ElementType>
1431   bool DoConvolveBackwardData(
1432       Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
1433       const DeviceMemory<ElementType>& filter_data,
1434       const dnn::BatchDescriptor& output_descriptor,
1435       const DeviceMemory<ElementType>& backward_output_data,
1436       const dnn::ConvolutionDescriptor& convolution_descriptor,
1437       const dnn::BatchDescriptor& input_descriptor,
1438       DeviceMemory<ElementType>* backward_input_data,
1439       const dnn::AlgorithmDesc& algorithm_desc,
1440       DeviceMemory<uint8>* scratch_memory,
1441       ProfileResult* output_profile_result) {
1442     return IsStatusOk(
1443         DoConvolve(
1444             ConvolutionKind::BACKWARD_DATA, ToDataType<ElementType>::value,
1445             ToDataType<ElementType>::value, stream, input_descriptor,
1446             *backward_input_data, filter_descriptor, filter_data,
1447             output_descriptor, backward_output_data, convolution_descriptor,
1448             algorithm_desc, *scratch_memory, output_profile_result),
1449         !output_profile_result);
1450   }
1451 
1452   // Return a list of algorithms supported by the backward convolution pass for
1453   // data.
1454   virtual bool GetConvolveBackwardDataAlgorithms(
1455       bool with_winograd_nonfused, int cc_major, int cc_minor,
1456       std::vector<AlgorithmDesc>* out_algorithms);
1457 
1458   // Enqueues a single-precision backward convolution (for filter) operation
1459   // onto the stream.
1460   //
1461   // Arguments:
1462   //  stream: borrowed pointer to the stream that the 'convolve' operation
1463   //    should be enqueued onto.
1464   //  input_descriptor: dimensions of the input layer.
1465   //  input_data: un-owned device memory region which contains the
1466   //    convolution input.
1467   //  output_descriptor: dimensions of the output gradients, which are the same
1468   //    as the dimensions of the output.
1469   //  backward_output_data: un-owned device memory region which contains the
1470   //    backprop of the output.
1471   //  convolution_descriptor: stride and padding of the convolution.
1472   //  filter_descriptor: dimensions of the convolution filter.
1473   //  backward_filter_data: un-owned device memory region in which to place the
1474   //    backprop of the filter.
1475   //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
1476   //    space in order to speed up the convolution operation.
1477   template <typename ElementType>
1478   bool DoConvolveBackwardFilter(
1479       Stream* stream, const BatchDescriptor& input_descriptor,
1480       const DeviceMemory<ElementType>& input_data,
1481       const BatchDescriptor& output_descriptor,
1482       const DeviceMemory<ElementType>& backward_output_data,
1483       const ConvolutionDescriptor& convolution_descriptor,
1484       const FilterDescriptor& filter_descriptor,
1485       DeviceMemory<ElementType>* backward_filter_data,
1486       const dnn::AlgorithmDesc& algorithm_desc,
1487       DeviceMemory<uint8>* scratch_memory,
1488       ProfileResult* output_profile_result) {
1489     return IsStatusOk(
1490         DoConvolve(
1491             ConvolutionKind::BACKWARD_FILTER, ToDataType<ElementType>::value,
1492             ToDataType<ElementType>::value, stream, input_descriptor,
1493             input_data, filter_descriptor, *backward_filter_data,
1494             output_descriptor, backward_output_data, convolution_descriptor,
1495             algorithm_desc, *scratch_memory, output_profile_result),
1496         !output_profile_result);
1497   }
1498 
1499   // Return a list of algorithms supported by the backward convolution pass for
1500   // filters.
1501   virtual bool GetConvolveBackwardFilterAlgorithms(
1502       bool with_winograd_nonfused, int cc_major, int cc_minor,
1503       std::vector<AlgorithmDesc>* out_algorithms);
1504 
1505   // Enqueues a single-precision backward convolution (for bias) operation onto
1506   // the stream.
1507   //
1508   // Arguments:
1509   //  stream: borrowed pointer to the stream that the 'convolve' operation
1510   //    should be enqueued onto.
1511   //  input_descriptor: dimensions of the input layer.
1512   //  input_data: un-owned device memory region which contains the
1513   //    convolution input.
1514   //  bias_descriptor: dimensions of the bias tensor. Should be the same as the
1515   //    input dimensions, but with the spatial dimensions set to 1.
1516   //  backward_bias_data: un-owned device memory region in which to place the
1517   //    backprop of the bias.
1518   virtual bool DoConvolveBackwardBias(Stream* stream,
1519                                       const BatchDescriptor& input_descriptor,
1520                                       const DeviceMemory<float>& input_data,
1521                                       const BatchDescriptor& bias_descriptor,
1522                                       DeviceMemory<float>* backward_bias_data) {
1523     return false;
1524   }
1525 
1526   virtual bool DoConvolveBackwardBias(
1527       Stream* stream, const BatchDescriptor& input_descriptor,
1528       const DeviceMemory<double>& input_data,
1529       const BatchDescriptor& bias_descriptor,
1530       DeviceMemory<double>* backward_bias_data) {
1531     return false;
1532   }
1533 
1534   virtual bool DoConvolveBackwardBias(
1535       Stream* stream, const BatchDescriptor& input_descriptor,
1536       const DeviceMemory<Eigen::half>& input_data,
1537       const BatchDescriptor& bias_descriptor,
1538       DeviceMemory<Eigen::half>* backward_bias_data) {
1539     return false;
1540   }
1541 
1542   // Fully connects the "nodes" (float values) in input_data with
1543   // shape input_dimensions to output_data with output_dimensions
1544   // using provided weights. This is equivalent to computing a matrix
1545   // product, hence the name MatMul.
1546   //
1547   // A BatchDescriptor has four dimensions: batch, y, x, depth. Matrix products
1548   // happen in two dimensions. To get down to two dimensions, we consider the
1549   // input y, x and depth dimension as one combined dimension T. For now,
1550   // assume that the output height and width are 1 and let OD be the output
1551   // depth.
1552   //
1553   // There are three device memory buffers passed in to this
1554   // function. We can now view all three as matrices:
1555   //
1556   //   input_data: A batch x T matrix
1557   //   weights: A T x OD matrix
1558   //   output_data: A batch x OD matrix
1559   //
1560   // This function then computes the matrix product of input_data and
1561   // weights and writes the result into output_data.
1562   //
1563   // Here the weights buffer is in row major order, i.e. the first OD
1564   // entries in weights are the first row, the second OD entries in
1565   // weights are the second row and so on.
1566   //
1567   // The case for output width*height > 1 is more complicated. Let K =
1568   // OY * OX where OY is the output height and OX is the output
1569   // width. Then weights is divided into K sub-arrays W_i, for
1570   // i=0,...,k-1, that each represent a T x OD matrix. This function
1571   // then computes the K matrix multiplications of input_data with
1572   // each W_i. This creates K matrices with dimensions batch x
1573   // OD. These K matrices are concatenated horizontally to form one
1574   // larger matrix with dimensions batch x (K*OD); note that this is
1575   // not the same as concatenating the bytes of the matrices. The
1576   // combined matrix can then be interpreted as a tensor with
1577   // dimensions (batch, OY, OX, OD). If the output tensor format is
1578   // not kBatchYXDepth, this function would then need to arrange for
1579   // the output to be in the requested layout, if that is
1580   // supported. Note that the case K=1 is equivalent to the
1581   // description above. It is recommended to prefer the case K=1.
1582   //
1583   // Arguments (all borrowed):
1584   //  stream: borrowed pointer to the stream that the 'fully connect' operation
1585   //    should be enqueued onto.
1586   //  output_data: un-owned device memory region in which to place the
1587   //    fully connected result.
1588   virtual bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
1589                         const DeviceMemory<float>& weights,
1590                         const dnn::BatchDescriptor& input_dimensions,
1591                         const dnn::BatchDescriptor& output_dimensions,
1592                         DeviceMemory<float>* output_data) = 0;
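  // A worked shape example for the K = 1 case described above (illustrative
  // numbers): with batch = 32 and an input of height 4, width 4, depth 8,
  // T = 4 * 4 * 8 = 128. For an output depth OD = 10:
  //   input_data  : 32 x 128 matrix
  //   weights     : 128 x 10 matrix (row-major, 10 entries per row)
  //   output_data : 32 x 10 matrix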
1593 
1594   // Version of DoMatMul that uses pre-quantized 8 bit weights.
1595   // weight_scales specifies the scaling of each column of weights:
1596   // original float weight[row * num_columns + column] =
1597   //     quantized_weight[row * num_columns + column] * weight_scales[column].
1598   virtual bool DoMatMulQuantized(Stream* stream,
1599                                  const DeviceMemory<float>& input_data,
1600                                  const DeviceMemory<int8>& quantized_weights,
1601                                  const DeviceMemory<float>& weight_scales,
1602                                  const dnn::BatchDescriptor& input_dimensions,
1603                                  const dnn::BatchDescriptor& output_dimensions,
1604                                  DeviceMemory<float>* output_data) = 0;
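  // Illustrative dequantization example for the scaling rule above: if
  // quantized_weight[row * num_columns + column] = 64 and
  // weight_scales[column] = 0.02f, the effective float weight is
  // 64 * 0.02f = 1.28f.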
1605 
1606   // Version of DoMatMul that uses pre-quantized 16 bit weights.
1607   // weight_scales specifies the scaling of each column of weights:
1608   // original float weight[row * num_columns + column] =
1609   //     quantized_weight[row * num_columns + column] * weight_scales[column].
1610   virtual bool DoMatMulQuantized(Stream* stream,
1611                                  const DeviceMemory<float>& input_data,
1612                                  const DeviceMemory<int16>& quantized_weights,
1613                                  const DeviceMemory<float>& weight_scales,
1614                                  const dnn::BatchDescriptor& input_dimensions,
1615                                  const dnn::BatchDescriptor& output_dimensions,
1616                                  DeviceMemory<float>* output_data) = 0;
1617 
1618   // Adds biases to the feature maps in input_data producing
1619   // output_data. input_data can equal output_data, but must not
1620   // partially overlap it.
1621   //
1622   // Let K = count() * height() * width() and N = feature_map_count()
1623   // on dimensions. Then input_value contains K*N values and biases
1624   // contains N values. We can thus logically consider input_value to
1625   // contain K vectors of N elements each. This function adds biases
1626   // to each of those N vectors.
1627   //
1628   // TODO(broune): This works differently when width() * height() > 1
1629   // and the call to ThenBiasAdd() follows a call to ThenMatMul(). In
1630   // that case there should be width() * height() *
1631   // feature_map_count() biases, but this is not implemented on all
1632   // StreamExecutors.
1633   //
1634   // Arguments (all borrowed):
1635   //  stream: borrowed pointer to the stream that the 'bias add' operation
1636   //    should be enqueued onto.
1637   //  input_data: un-owned device memory region containing the input.
1638   //  biases: un-owned device memory region containing biases to add to the
1639   //    input.
1640   //  dimensions: dimensions of input_data and output_data.
1641   //  output_data: un-owned device memory region in which to place the result.
1642   virtual bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
1643                          const DeviceMemory<float>& biases,
1644                          const dnn::BatchDescriptor& dimensions,
1645                          DeviceMemory<float>* output_data) = 0;
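  // Illustrative shape example for the bias addition above: for dimensions
  // with count() = 2, height() = 2, width() = 2 and feature_map_count() = 3,
  // K = 2 * 2 * 2 = 8 and N = 3, so input_data holds 8 vectors of 3 values
  // and the same 3 biases are added element-wise to each of them.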
1646 
1647   // Performs a forward pooling operation on input_data, writing to
1648   // output_data. See PoolingDescriptor for how to configure the
1649   // pooling operation.
1650   //
1651   // Pooling happens as a window that moves across the Y and X
1652   // dimensions of input_data, where each position of the window
1653   // yields one output value. E.g. for max pooling, the computed value
1654   // is the maximum element in the window. The operation is applied
1655   // independently to each batch and at each feature map (depth), so
1656   // that the output depth and feature_map_count are the same as for
1657   // the input. The output width and height can be different.
1658   //
1659   // See PoolingDescriptor for how to configure the pooling operation.
1660   virtual bool DoPoolForward(Stream* stream,
1661                              const dnn::PoolingDescriptor& pooling_dimensions,
1662                              const dnn::BatchDescriptor& input_dimensions,
1663                              const DeviceMemory<float>& input_data,
1664                              const dnn::BatchDescriptor& output_dimensions,
1665                              DeviceMemory<float>* output_data,
1666                              ScratchAllocator* workspace_allocator) = 0;
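  // Illustrative example of the windowed reduction above: max pooling a 4x4
  // input with a 2x2 window and stride 2 (independently per batch element and
  // feature map) yields a 2x2 output, where each output value is the maximum
  // of the four input values covered by its window.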
1667 
1668   virtual bool DoPoolForward(Stream* stream,
1669                              const dnn::PoolingDescriptor& pooling_dimensions,
1670                              const dnn::BatchDescriptor& input_dimensions,
1671                              const DeviceMemory<double>& input_data,
1672                              const dnn::BatchDescriptor& output_dimensions,
1673                              DeviceMemory<double>* output_data,
1674                              ScratchAllocator* workspace_allocator) {
1675     LOG(FATAL) << "DoPoolForward not implemented for double.";
1676     return false;
1677   }
1678 
1679   virtual bool DoPoolForward(Stream* stream,
1680                              const dnn::PoolingDescriptor& pooling_dimensions,
1681                              const dnn::BatchDescriptor& input_dimensions,
1682                              const DeviceMemory<Eigen::half>& input_data,
1683                              const dnn::BatchDescriptor& output_dimensions,
1684                              DeviceMemory<Eigen::half>* output_data,
1685                              ScratchAllocator* workspace_allocator) {
1686     LOG(FATAL) << "DoPoolForward not implemented for float16.";
1687     return false;
1688   }
1689 
1690   virtual bool DoPoolForward(Stream* stream,
1691                              const dnn::PoolingDescriptor& pooling_dimensions,
1692                              const dnn::BatchDescriptor& input_dimensions,
1693                              const DeviceMemory<int8>& input_data,
1694                              const dnn::BatchDescriptor& output_dimensions,
1695                              DeviceMemory<int8>* output_data,
1696                              ScratchAllocator* workspace_allocator) {
1697     LOG(FATAL) << "DoPoolForward not implemented for int8.";
1698     return false;
1699   }
1700 
1701   // Performs differentiation of the pooling operation.
1702   virtual bool DoPoolBackward(Stream* stream,
1703                               const dnn::PoolingDescriptor& pooling_dimensions,
1704                               const dnn::BatchDescriptor& input_dimensions,
1705                               const DeviceMemory<double>& input_data,
1706                               const dnn::BatchDescriptor& output_dimensions,
1707                               const DeviceMemory<double>& output_data,
1708                               const DeviceMemory<double>& input_diff_data,
1709                               DeviceMemory<double>* output_diff_data,
1710                               ScratchAllocator* workspace_allocator) {
1711     LOG(FATAL) << "DoPoolBackward not implemented.";
1712     return false;
1713   }
1714 
1715   virtual bool DoPoolBackward(Stream* stream,
1716                               const dnn::PoolingDescriptor& pooling_dimensions,
1717                               const dnn::BatchDescriptor& input_dimensions,
1718                               const DeviceMemory<float>& input_data,
1719                               const dnn::BatchDescriptor& output_dimensions,
1720                               const DeviceMemory<float>& output_data,
1721                               const DeviceMemory<float>& input_diff_data,
1722                               DeviceMemory<float>* output_diff_data,
1723                               ScratchAllocator* workspace_allocator) {
1724     LOG(FATAL) << "DoPoolBackward not implemented.";
1725     return false;
1726   }
1727 
1728   virtual bool DoPoolBackward(Stream* stream,
1729                               const dnn::PoolingDescriptor& pooling_dimensions,
1730                               const dnn::BatchDescriptor& input_dimensions,
1731                               const DeviceMemory<Eigen::half>& input_data,
1732                               const dnn::BatchDescriptor& output_dimensions,
1733                               const DeviceMemory<Eigen::half>& output_data,
1734                               const DeviceMemory<Eigen::half>& input_diff_data,
1735                               DeviceMemory<Eigen::half>* output_diff_data,
1736                               ScratchAllocator* workspace_allocator) {
1737     LOG(FATAL) << "DoPoolBackward not implemented.";
1738     return false;
1739   }
1740 
1741   // Applies local response normalization to the values from input_data and
1742   // writes the result to output_data.
1743   //
1744   // See comments on NormalizeDescriptor for a description of local response
1745   // normalization.
1746   virtual bool DoNormalizeWithDimensions(
1747       Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1748       const dnn::BatchDescriptor& dimensions,
1749       const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
1750     return false;
1751   }
1752 
1753   // Performs backpropagation for the normalization operation
1754   //
1755   // Given raw data, its corresponding normalized output, and a gradient of some
1756   // unspecified function with respect to the normalized variables, computes the
1757   // gradient of that unspecified function with respect to the raw variables.
1758   //
1759   // The normalized data input array is expected to match the output that would
1760   // be obtained by running the raw data input array through the DoNormalize
1761   // method above.
1762   //
1763   // See comments on NormalizeDescriptor for a description of local response
1764   // normalization.
1765   virtual bool DoNormalizeBackwardWithDimensions(
1766       Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
1767       const dnn::BatchDescriptor& dimensions,
1768       const DeviceMemory<float>& raw_data,
1769       const DeviceMemory<float>& normalized_data,
1770       const DeviceMemory<float>& normalized_variable_gradient,
1771       DeviceMemory<float>* raw_variable_gradient,
1772       ScratchAllocator* workspace_allocator) {
1773     return false;
1774   }
1775 
1776   // Applies an activation function (see ActivationMode) to all of the values
1777   // held on the device in 'input_data', whose dimensions are described by
1778   // 'dimensions'.
1779   //
1780   // Arguments (all borrowed):
1781   //  stream: borrowed pointer to the stream that the 'activate' operation
1782   //    should be enqueued onto.
1783   //  activation_mode: Type of activation to perform.
1784   //  input_data: un-owned device memory region which contains the
1785   //    activate input.
1786   //  output_data: un-owned device memory region in which to place the
1787   //    activate result.
1788   virtual bool DoActivate(Stream* stream, ActivationMode activation_mode,
1789                           const BatchDescriptor& dimensions,
1790                           const DeviceMemory<float>& input_data,
1791                           DeviceMemory<float>* output_data, uint64 options) {
1792     return false;
1793   }
1794 
1795   // Concatenates several layers into one, by concatenating the depth of each
1796   // layer at matching x and y coordinates.
1797   // The inputs must all have the same width and height, the output will have
1798   // the same width and height as the inputs and its depth will be the sum of
1799   // the input depths.
1800   //
1801   // Arguments (all borrowed):
1802   //  stream: borrowed pointer to the stream that the 'depth concatenate'
1803   // operation should be enqueued onto.
1804   //  input_dimensions: The dimensions of each input.
1805   //  input_data: un-owned device memory region which contains the
1806   //    input data for each input layer.
1807   //  output_data: un-owned device memory region in which to place the
1808   //    depth concatenate result.
1809   virtual bool DoDepthConcatenate(
1810       Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1811       port::ArraySlice<const DeviceMemory<float>*> input_data,
1812       DeviceMemory<float>* output_data) = 0;
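  // Illustrative example: concatenating two inputs of identical height and
  // width with depths 3 and 5 produces an output of the same height and width
  // with depth 3 + 5 = 8; at every (x, y) the first 3 output channels come
  // from the first input and the remaining 5 from the second.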
1813 
1814   // Concatenates several layers into one, by concatenating each in the
1815   // x-dimension or y-dimension, based on a user-specified flag.
1816   // For x-concatenation, layers are aligned at matching y and depth
1817   // coordinates, and for y-concatenation, they are aligned at matching x and
1818   // depth coordinates. The inputs must all have the same depth and batch size.
1819   // For x-concatenation, the inputs must have the same height (y-size), and the
1820   // output will have the same depth and height as the inputs and its width (x-
1821   // size) will be the sum of the input widths.  For y-concatenation, the inputs
1822   // must have the same width, and the output will have the same depth and width
1823   // as the inputs, and its height will be the sum of the input heights.
1824   //
1825   // Arguments:
1826   //  stream: borrowed pointer to the stream that the 'space concatenate'
1827   //    operation should be enqueued onto.
1828   //  input_dimensions: the dimensions of each input.
1829   //  input_data: un-owned device memory region which contains the input data
1830   //    for each input layer.
1831   //  output_data: un-owned device memory region in which to place the space
1832   //    concatenate result.
1833   //  concat_direction:  either dnn:SpaceConcatenateMode::XDirection or
1834   //    dnn::SpaceConcatenateMode::YDirection.
1835   virtual bool DoSpaceConcatenate(
1836       Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1837       port::ArraySlice<const DeviceMemory<float>*> input_data,
1838       DeviceMemory<float>* output_data,
1839       dnn::SpaceConcatenateMode concat_direction) {
1840     return false;
1841   }
1842 
1843   // Change the layout of the data by shrinking one dimension (or set of
1844   // dimensions) and growing another dimension (or set of dimensions), while
1845   // keeping the total number of data elements constant, and maintaining the
1846   // current data ordering.
1847   //
1848   // Currently, the only supported operation is depth into space by a power of
1849   // 2. E.g. (y, x, z) -> (y*2, x*2, z/4)
1850   //
1851   // Note that Reshape may not be a no-op, depending on the platform and which
1852   // dimensions are being changed.
1853   //
1854   // Example: forgetting about batch for the moment, let's take a tensor that's
1855   // 2x1x8 (y by x by z) and reshape to a tensor that's 4x2x2. The memory layout
1856   // is row-major order: y,x,z. I.e. z changes the fastest, then x, then y. The
1857   // elements of the tensor range from 0 to 15. The x,y,z indices are below each
1858   // element.
1859   //
1860   //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
1861   // y0 y0 y0 y0 y0 y0 y0 y0 y1 y1 y1 y1 y1 y1 y1 y1
1862   // x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0 x0
1863   // z0 z1 z2 z3 z4 z5 z6 z7 z0 z1 z2 z3 z4 z5 z6 z7
1864   //
1865   // reshape to 4x2x2
1866   //
1867   //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
1868   // y0 y0 y0 y0 y1 y1 y1 y1 y2 y2 y2 y2 y3 y3 y3 y3
1869   // x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1 x0 x0 x1 x1
1870   // z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1 z0 z1
1871   virtual bool DoReshape(Stream* stream,
1872                          const dnn::BatchDescriptor& input_dimensions,
1873                          const DeviceMemory<float>& input_data,
1874                          const dnn::BatchDescriptor& output_dimensions,
1875                          DeviceMemory<float>* output_data) {
1876     return false;
1877   }
1878 
1879   // Depth to space takes an X by Y image with depth D*M^2 and changes it to an
1880   // MX x MY image with depth D. Each input location (x,y) with depth D*M^2 in
1881   // the input image is changed to an MxM contiguous area in the output image,
1882   // with the values being laid out in the raster order by DepthToSpaceLayout,
1883   // and will have a new depth of D.
1884   //
1885   // Example.
1886   // M=2, Din=8, Xin=2, Yin=2. Xout=4, Yout=4, Dout=2
1887   // DepthHeightWidth layout
1888   // Values within a 'cell' are at different depths and same x & y.
1889   // Input:
1890   // abcdefgh  ijklmnop
1891   // qrstuvwx  yz012345
1892   // Output:
1893   // ae bf im jn
1894   // cg dh ko lp
1895   // qu rv y2 z3
1896   // sw tx 04 15
1897   //
1898   // sqrt_depth_reduction: 'M' in the comment above
1899   virtual bool DoDepthToSpace(Stream* stream,
1900                               const dnn::BatchDescriptor& input_dimensions,
1901                               const DeviceMemory<float>& input_data,
1902                               const DepthToSpaceLayout& depth_to_space_layout,
1903                               const int& sqrt_depth_reduction,
1904                               DeviceMemory<float>* output_data) {
1905     return false;
1906   }
1907 
1908   // Space to depth is the inverse of depth to space. Space to depth takes each
1909   // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
1910   // the input, and transforms it to a 1 by 1 patch with depth D*M^2. If the
1911   // input has size (MX, MY, D), the output has size (X, Y, D*M^2). The number
1912   // of data elements is not changed.
1913   //
1914   // Example.
1915   // M=2, Din=2, Xin=4, Yin=4. Xout=2, Yout=2, Dout=8
1916   // DepthHeightWidth layout
1917   // Values within a 'cell' are at different depths and same x & y.
1918   // Input:
1919   // ae bf im jn
1920   // cg dh ko lp
1921   // qu rv y2 z3
1922   // sw tx 04 15
1923   // Output:
1924   // abcdefgh  ijklmnop
1925   // qrstuvwx  yz012345
1926   //
1927   // sqrt_depth_increase: 'M' in the comment above
1928   virtual bool DoSpaceToDepth(Stream* stream,
1929                               const dnn::BatchDescriptor& input_dimensions,
1930                               const DeviceMemory<float>& input_data,
1931                               const DepthToSpaceLayout& space_to_depth_layout,
1932                               const int& sqrt_depth_increase,
1933                               DeviceMemory<float>* output_data) {
1934     return false;
1935   }
1936 
1937   // Computes the specified operation (e.g. addition or multiplication)
1938   // between corresponding elements in the inputs and stores the result in the
1939   // output element.
1940   // The inputs and output must all have the same dimensions, but may have
1941   // different quantization parameters (min_value and max_value).
1942   //
1943   // Arguments (all borrowed):
1944   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1945   // should be enqueued onto.
1946   //  operation: The operation to perform.
1947   //  input_dimensions: The dimensions of each input.
1948   //  input_data: un-owned device memory region which contains the
1949   //    input data for each input layer.
1950   //  output_dimensions: The dimensions of the output.
1951   //  output_data: un-owned device memory region in which to place the
1952   //    operation result.
1953   virtual bool DoElementwiseOperate(
1954       Stream* stream, ElementwiseOperation operation,
1955       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1956       port::ArraySlice<const DeviceMemory<float>*> input_data,
1957       const dnn::BatchDescriptor& output_dimensions,
1958       DeviceMemory<float>* output_data) = 0;
1959 
1960   // Computes the specified operation (e.g. addition or multiplication)
1961   // between corresponding elements in the inputs and stores the result in the
1962   // output element. Each input is multiplied by a scalar constant and the
1963   // result is divided by a scalar constant.
1964   // e.g. To perform Z = 0.9*X + 1.1*Y, set the input multiplicands to 9 and 11
1965   // and the output divisor to 10.
1966   // The inputs and output must all have the same dimensions, but may have
1967   // different quantization parameters (min_value and max_value).
1968   //
1969   // Arguments (all borrowed):
1970   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1971   // should be enqueued onto.
1972   //  operation: The operation to perform.
1973   //  input_multiplicands: Amount to scale each input.
1974   //  output_divisor: Amount to divide the output.
1975   //  input_dimensions: The dimensions of each input.
1976   //  input_data: un-owned device memory region which contains the
1977   //    input data for each input layer.
1978   //  output_dimensions: The dimensions of the output.
1979   //  output_data: un-owned device memory region in which to place the
1980   //    operation result.
1981   virtual bool DoElementwiseOperateScaledQuantized(
1982       Stream* stream, ElementwiseOperation operation,
1983       port::ArraySlice<int> input_multiplicands, int output_divisor,
1984       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1985       port::ArraySlice<const DeviceMemory<float>*> input_data,
1986       const dnn::BatchDescriptor& output_dimensions,
1987       DeviceMemory<float>* output_data) {
1988     return false;
1989   }
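  // Checking the example above: with input multiplicands {9, 11} and output
  // divisor 10, the computed value is (9 * X + 11 * Y) / 10 = 0.9 * X + 1.1 * Y.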
1990 
1991   // Pads the input with zeros in the X and Y dimensions. The feature_map
1992   // dimension is unchanged.
1993   //
1994   // Arguments (all borrowed):
1995   //  stream: borrowed pointer to the stream that the 'elementwise operation'
1996   // should be enqueued onto.
1997   //  dimensions: The dimensions of the input.
1998   //  input_data: un-owned device memory region which contains the
1999   //    input data for the input layer.
2000   //  left_pad: Amount to pad the input on the left.
2001   //  right_pad: Amount to pad the input on the right.
2002   //  top_pad: Amount to pad the input at the top (low Y).
2003   //  bottom_pad: Amount to pad the input at the bottom (high Y).
2004   //  output_data: un-owned device memory region in which to place the
2005   //    padded result.
2006   virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions,
2007                        const DeviceMemory<float> &input_data,
2008                        int64 left_pad, int64 right_pad, int64 top_pad,
2009                        int64 bottom_pad, DeviceMemory<float> *output_data) = 0;
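  // Illustrative sizing example (hypothetical values): padding a 3x3 (x by y)
  // feature map with left_pad = 1, right_pad = 1, top_pad = 2, bottom_pad = 0
  // yields a (3 + 1 + 1) x (3 + 2 + 0) = 5x5 output per feature map.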
2010 
2011   // Extracts a slice of the input in the X and Y dimensions. The feature_map
2012   // dimension is unchanged.
2013   //
2014   // Arguments (all borrowed):
2015   //  stream: borrowed pointer to the stream that the 'elementwise operation'
2016   // should be enqueued onto.
2017   //  dimensions: The dimensions of the input.
2018   //  input_data: un-owned device memory region which contains the
2019   //    input data for the input layer.
2020   //  left_trim: Amount to cut off the input on the left.
2021   //  right_trim: Amount to cut off the input on the right.
2022   //  top_trim: Amount to cut off the input at the top (low y).
2023   //  bottom_trim: Amount to cut off the input at the bottom (high Y).
2024   //  output_data: un-owned device memory region in which to place the
2025   //    sliced result.
2026   virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions,
2027                     const DeviceMemory<float> &input_data,
2028                     int64 left_trim, int64 right_trim, int64 top_trim,
2029                     int64 bottom_trim, DeviceMemory<float> *output_data) = 0;
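  // Illustrative sizing example (hypothetical values): slicing a 5x5 (x by y)
  // feature map with left_trim = 1, right_trim = 1, top_trim = 2,
  // bottom_trim = 0 yields a (5 - 1 - 1) x (5 - 2 - 0) = 3x3 output.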
2030 
2031   // Grows the input tensor by replicating the X and Y dimensions. The batch and
2032   // depth/feature_map dimensions are unchanged. Currently, the input tensor is
2033   // limited to X=1 and Y=1.
2034   //
2035   // For example, the input has dimensions x=2, y=3, and replicate_x=3,
2036   // replicate_y=2. The diagonal elements of the output would be: [x0y0, x1y1,
2037   // x0y2, x1y0, x0y1, x1y2].
2038   // Here is the example as a picture. input:
2039   // AB
2040   // CD
2041   // EF
2042   // broadcast result:
2043   // ABABAB
2044   // CDCDCD
2045   // EFEFEF
2046   // ABABAB
2047   // CDCDCD
2048   // EFEFEF
2049   //
2050   // Arguments (all borrowed):
2051   //  stream: borrowed pointer to the stream that the 'elementwise operation'
2052   // should be enqueued onto.
2053   //  dimensions: The dimensions of the input.
2054   //  input_data: un-owned device memory region which contains the
2055   //    input data for the input layer.
2056   //  replicate_x: Amount to replicate the input's X dimension.
2057   //  replicate_y: Amount to replicate the input's Y dimension.
2058   //  output_data: un-owned device memory region in which to place the
2059   //    broadcast result.
2060   virtual bool DoXYBroadcast(Stream* stream,
2061                              const dnn::BatchDescriptor& dimensions,
2062                              const DeviceMemory<float>& input_data,
2063                              int64 replicate_x, int64 replicate_y,
2064                              DeviceMemory<float>* output_data) {
2065     return false;
2066   }
2067 
2068   // Enqueues an asynchronous memcpy of the *quantized* output of a layer (that
2069   // is, bytes instead of scaled floats) into 'host_dst' if they are available
2070   // for the underlying DNN implementation. If this quantized output is not
2071   // available, false is returned, which will place 'stream' into an error
2072   // state.
2073   //
2074   // Arguments (all borrowed):
2075   //  stream: borrowed pointer to the stream that the 'quantized memcpy'
2076   //    operation should be enqueued onto.
2077   //  gpu_unquantized_src: the device memory that contains the unquantized data
2078   //    -- this data should also have a corresponding quantized representation
2079   //    on the device for this operation to succeed.
2080   //  mode: Type of quantization of the data to write into host_dst.
2081   //  host_dst: un-owned host memory region that is mutated in place;
2082   //    it is clobbered by the values in 'gpu_unquantized_src' when the enqueued
2083   //    (asynchronous) memcpy operation is performed.
2084   //  size: size in bytes of the host_dst host memory region.
2085   virtual bool DoMemcpyD2HQuantized(
2086       Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
2087       QuantizedActivationMode mode, void* host_dst, int64 size) = 0;
2088 
2089   // Enqueues an asynchronous memcpy of 'host_dst' into the *quantized* input
2090   // of a layer (that is, bytes instead of scaled floats) if they are supported
2091   // by the underlying DNN implementation. If this quantized input is not
2092   // supported, false is returned, which will place 'stream' into an error
2093   // state.
2094   //
2095   // Arguments (all borrowed):
2096   //  stream: borrowed pointer to the stream that the 'quantized memcpy'
2097   //    operation should be enqueued onto.
2098   //  host_src: un-owned host memory region that contains the quantized data.
2099   //  size: size in bytes of the host_src host memory region.
2100   //  mode: Type of quantization of the data to read from host_src.
2101   //  gpu_unquantized_dst: the device memory that is clobbered by the values in
2102   //    'host_src' when the enqueued (asynchronous) memcpy operation is
2103   //    performed. -- this data should also have a corresponding quantized
2104   //    representation on the device for this operation to
2105   //    succeed.
2106   virtual bool DoMemcpyH2DQuantized(
2107       Stream* stream, const void* host_src, int64 size,
2108       QuantizedActivationMode mode,
2109       DeviceMemory<float>* gpu_unquantized_dst) = 0;
2110 
2111   // Create an RNN descriptor based on model shapes and configurations.
2112   // The caller retains the ownership of the descriptor.
2113   //
2114   // Arguments:
2115   //  num_layers: the number of layers for a RNN model.
2116   //  hidden_size: the size of the hidden state.
2117   //  input_size: the size of the input state.
2118   //  cell_size: the size of the cell state.
2119   //  input_mode: an enum to specify whether a linear transformation is added
2120   //    after the input state. If input_size is different from hidden_size, this
2121   //    is required.
2122   //  direction_mode: an enum to specify whether this model is unidirectional or
2123   //    bidirectional.
2124   //  rnn_mode: an enum to specify the type of model to build.
2125   //  data_type: an enum to specify the data types used in this model.
2126   //  dropout: the dropout threshold between layers. When it is 0., no dropout
2127   //    is added.
2128   //  seed: a seed for initializing the dropout layers.
2129   //  state_allocator: a memory allocator that will be used to store the
2130   //    state for the dropout layer. The user has to maintain the memory
2131   //    until the model is no longer in use.
2132   //  use_padded_io: a bool to specify whether the input is using padded IO.
2133   virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2134   createRnnDescriptor(int num_layers, int hidden_size, int input_size,
2135                       int cell_size, int batch_size,
2136                       dnn::RnnInputMode input_mode,
2137                       dnn::RnnDirectionMode direction_mode,
2138                       dnn::RnnMode rnn_mode, dnn::DataType data_type,
2139                       const dnn::AlgorithmConfig& algorithm_config,
2140                       float dropout, uint64 seed,
2141                       ScratchAllocator* state_allocator, bool use_padded_io) {
2142     return port::Status(port::error::UNIMPLEMENTED,
2143                         "createRnnDescriptor is unimplemented");
2144   }
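  // A minimal usage sketch (illustrative only): `dnn`, `state_allocator`, and
  // the specific enum values shown here are assumptions standing in for
  // whatever the caller actually has in scope.
  //
  //   auto rnn_desc_or = dnn->createRnnDescriptor(
  //       /*num_layers=*/2, /*hidden_size=*/256, /*input_size=*/128,
  //       /*cell_size=*/256, /*batch_size=*/32,
  //       dnn::RnnInputMode::kRnnLinearSkip,
  //       dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
  //       dnn::DataType::kFloat, dnn::AlgorithmConfig(), /*dropout=*/0.0f,
  //       /*seed=*/0, state_allocator, /*use_padded_io=*/false);
  //   if (!rnn_desc_or.ok()) {
  //     // Handle UNIMPLEMENTED or other errors from the platform.
  //   }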
2145 
2146   // Create a RNN sequence descriptor that specifies either the input or output
2147   // sequence. The caller retains the ownership of the returned descriptor.
2148   //
2149   // Arguments:
2150   //  max_seq_length: the max length of the sequences.
2151   //  batch_size: the size of a minibatch.
2152   //  data_size: the size of the state.
2153   //  seq_lengths: the lengths of sequences in a batch.
2154   //  data_type: an enum to specify the type for the underlying data.
2155   virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2156   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2157                                     int data_size, dnn::DataType data_type) {
2158     return port::Status(port::error::UNIMPLEMENTED,
2159                         "createRnnSequenceTensorDescriptor is unimplemented");
2160   }
2161 
2162   virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2163   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
2164                                     int data_size,
2165                                     const absl::Span<const int>& seq_lengths,
2166                                     bool time_major, dnn::DataType data_type) {
2167     return port::Status(port::error::UNIMPLEMENTED,
2168                         "createRnnSequenceTensorDescriptor is unimplemented");
2169   }
2170 
2171   // Create an RNN state descriptor that specifies the input or hidden state.
2172   // The caller retains the ownership of the returned descriptor.
2173   virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2174   createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
2175                                  dnn::DataType data_type) {
2176     return port::Status(port::error::UNIMPLEMENTED,
2177                         "createRnnStateTensorDescriptor is unimplemented");
2178   }
2179 
2180   // Enqueue a forward operation of the RNN model onto the stream.
2181   //
2182   // Arguments:
2183   //  stream: pointer to the stream where this operation should be enqueued to.
2184   //  rnn_desc: a RNN descriptor created by createRnnDescriptor.
2185   //  input_desc: descriptor for the input sequence.
2186   //  input_data: the device memory region that contains the input data.
2187   //  input_h_desc: descriptor for the input "h" state.
2188   //  input_h_data: the device memory region that contains the input "h" data.
2189   //  input_c_desc: descriptor for the input "c" state.
2190   //  input_c_data: the device memory region that contains the input "c" data.
2191   //    This must be specified for LSTM models.
2192   //  params: the device memory region that contains the parameters used in this
2193   //    model.
2194   //  output_desc: descriptor for the output sequence.
2195   //  output_data: the memory region that stores the output sequence data.
2196   //  output_h_desc: descriptor for the output "h" state.
2197   //  output_h_data: the memory region that stores the output "h" data.
2198   //  output_c_desc: descriptor for the output "c" state.
2199   //  output_c_data: the memory region that stores the output "c" data. This
2200   //    must be specified for LSTM models.
2201   //  is_training: whether this is used in training or inference. That decides
2202   //    whether reserve_space data needs to be produced.
2203   //  reserve_space_allocator: if "is_training" is true, a memory allocator
2204   //    to create memory that holds the produced reserve_space. The caller
2205   //    retains the data and feeds it to the backward pass.
2206   //  workspace_allocator: an allocator to create temporary workspace used in
2207   //    this kernel. The caller is responsible for retaining the memory long
2208   //    enough for the lifespan of this operation and may recycle it afterwards.
2209   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2210                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2211                             const DeviceMemory<Eigen::half>& input_data,
2212                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2213                             const DeviceMemory<Eigen::half>& input_h_data,
2214                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2215                             const DeviceMemory<Eigen::half>& input_c_data,
2216                             const DeviceMemory<Eigen::half>& params,
2217                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2218                             DeviceMemory<Eigen::half>* output_data,
2219                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2220                             DeviceMemory<Eigen::half>* output_h_data,
2221                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2222                             DeviceMemory<Eigen::half>* output_c_data,
2223                             bool is_training,
2224                             ScratchAllocator* reserve_space_allocator,
2225                             ScratchAllocator* workspace_allocator,
2226                             dnn::ProfileResult* output_profile_result) {
2227     return false;
2228   }
2229 
2230   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2231                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2232                             const DeviceMemory<float>& input_data,
2233                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2234                             const DeviceMemory<float>& input_h_data,
2235                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2236                             const DeviceMemory<float>& input_c_data,
2237                             const DeviceMemory<float>& params,
2238                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2239                             DeviceMemory<float>* output_data,
2240                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2241                             DeviceMemory<float>* output_h_data,
2242                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2243                             DeviceMemory<float>* output_c_data,
2244                             bool is_training,
2245                             ScratchAllocator* reserve_space_allocator,
2246                             ScratchAllocator* workspace_allocator,
2247                             dnn::ProfileResult* output_profile_result) {
2248     return false;
2249   }
2250 
2251   virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2252                             const dnn::RnnSequenceTensorDescriptor& input_desc,
2253                             const DeviceMemory<double>& input_data,
2254                             const dnn::RnnStateTensorDescriptor& input_h_desc,
2255                             const DeviceMemory<double>& input_h_data,
2256                             const dnn::RnnStateTensorDescriptor& input_c_desc,
2257                             const DeviceMemory<double>& input_c_data,
2258                             const DeviceMemory<double>& params,
2259                             const dnn::RnnSequenceTensorDescriptor& output_desc,
2260                             DeviceMemory<double>* output_data,
2261                             const dnn::RnnStateTensorDescriptor& output_h_desc,
2262                             DeviceMemory<double>* output_h_data,
2263                             const dnn::RnnStateTensorDescriptor& output_c_desc,
2264                             DeviceMemory<double>* output_c_data,
2265                             bool is_training,
2266                             ScratchAllocator* reserve_space_allocator,
2267                             ScratchAllocator* workspace_allocator,
2268                             dnn::ProfileResult* output_profile_result) {
2269     return false;
2270   }
2271   // Enqueue a backward operation of the RNN model onto the stream.
2272   //
2273   // Arguments:
2274   //  stream: pointer to the stream where this operation should be enqueued to.
2275   //  rnn_desc: a RNN descriptor created by createRnnDescriptor.
2276   //  input_desc: descriptor for the input sequence.
2277   //  input_data: the device memory region that contains the input data.
2278   //  input_h_desc: descriptor for the input "h" state.
2279   //  input_h_data: the device memory region that contains the input "h" data.
2280   //  input_c_desc: descriptor for the input "c" state.
2281   //  input_c_data: the device memory region that contains the input "c" data.
2282   //    This must be specified for LSTM models.
2283   //  params: the device memory region that contains the parameters used in this
2284   //    model.
2285   //  output_desc: descriptor for the output sequence.
2286   //  output_data: the memory region that stores the output sequence data.
2287   //  output_h_desc: descriptor for the output "h" state.
2288   //  output_h_data: the memory region that stores the output "h" data.
2289   //  output_c_desc: descriptor for the output "c" state.
2290   //  output_c_data: the memory region that stores the output "c" data. This
2291   //    must be specified for LSTM models.
2292   //  output_backprop_data: the device memory region that contains the backprop
2293   //    to the output sequence.
2294   //  output_h_backprop_data: the device memory region that contains the
2295   //    backprop to the output "h" state.
2296   //  output_c_backprop_data: the device memory region that contains the
2297   //    backprop to the output "c" state.
2298   //  input_backprop_data: the device memory region that stores the backprop
2299   //    to the input sequence.
2300   //  input_h_backprop_data: the device memory region that stores the backprop
2301   //    to the input "h" state.
2302   //  input_c_backprop_data: the device memory region that stores the backprop
2303   //    to the input "c" state.
2304   //  params_backprop_data: the device memory region that stores the backprop
2305   //    to the parameters.
2306   //  reserve_space_data: the reserve_space data that is produced by the forward
2307   //    operation. This memory region could be modified by this operation.
2308   //  workspace_allocator: a memory allocator that creates the temporary
2309   //    workspace memory used by this operation. The caller is responsible for
2310   //    keeping the memory alive long enough for this operation, and may
2311   //    recycle it afterwards.
  virtual bool DoRnnBackward(
      Stream* stream, const dnn::RnnDescriptor& rnn_desc,
      const dnn::RnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<Eigen::half>& input_data,
      const dnn::RnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<Eigen::half>& input_h_data,
      const dnn::RnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<Eigen::half>& input_c_data,
      const DeviceMemory<Eigen::half>& params,
      const dnn::RnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<Eigen::half>& output_data,
      const dnn::RnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<Eigen::half>& output_h_data,
      const dnn::RnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<Eigen::half>& output_c_data,
      const DeviceMemory<Eigen::half>& output_backprop_data,
      const DeviceMemory<Eigen::half>& output_h_backprop_data,
      const DeviceMemory<Eigen::half>& output_c_backprop_data,
      DeviceMemory<Eigen::half>* input_backprop_data,
      DeviceMemory<Eigen::half>* input_h_backprop_data,
      DeviceMemory<Eigen::half>* input_c_backprop_data,
      DeviceMemory<Eigen::half>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }

  virtual bool DoRnnBackward(
      Stream* stream, const dnn::RnnDescriptor& rnn_desc,
      const dnn::RnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<float>& input_data,
      const dnn::RnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<float>& input_h_data,
      const dnn::RnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<float>& input_c_data,
      const DeviceMemory<float>& params,
      const dnn::RnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<float>& output_data,
      const dnn::RnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<float>& output_h_data,
      const dnn::RnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<float>& output_c_data,
      const DeviceMemory<float>& output_backprop_data,
      const DeviceMemory<float>& output_h_backprop_data,
      const DeviceMemory<float>& output_c_backprop_data,
      DeviceMemory<float>* input_backprop_data,
      DeviceMemory<float>* input_h_backprop_data,
      DeviceMemory<float>* input_c_backprop_data,
      DeviceMemory<float>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }

  virtual bool DoRnnBackward(
      Stream* stream, const dnn::RnnDescriptor& rnn_desc,
      const dnn::RnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<double>& input_data,
      const dnn::RnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<double>& input_h_data,
      const dnn::RnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<double>& input_c_data,
      const DeviceMemory<double>& params,
      const dnn::RnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<double>& output_data,
      const dnn::RnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<double>& output_h_data,
      const dnn::RnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<double>& output_c_data,
      const DeviceMemory<double>& output_backprop_data,
      const DeviceMemory<double>& output_h_backprop_data,
      const DeviceMemory<double>& output_c_backprop_data,
      DeviceMemory<double>* input_backprop_data,
      DeviceMemory<double>* input_h_backprop_data,
      DeviceMemory<double>* input_c_backprop_data,
      DeviceMemory<double>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }
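
  // Example (sketch only): launching the float overload of DoRnnBackward
  // declared above. The descriptors are assumed to come from
  // createRnnDescriptor / createRnnSequenceTensorDescriptor /
  // createRnnStateTensorDescriptor, and every DeviceMemory<float> buffer and
  // the reserve space produced by the forward pass are assumed to be prepared
  // by the caller; all variable names below are illustrative only.
  //
  //   dnn::ProfileResult profile;  // pass nullptr instead to disable profiling
  //   bool launched = dnn->DoRnnBackward(
  //       stream, *rnn_desc, *input_desc, input_data, *input_h_desc,
  //       input_h_data, *input_c_desc, input_c_data, params, *output_desc,
  //       output_data, *output_h_desc, output_h_data, *output_c_desc,
  //       output_c_data, output_backprop, output_h_backprop, output_c_backprop,
  //       &input_backprop, &input_h_backprop, &input_c_backprop,
  //       &params_backprop, &reserve_space, workspace_allocator, &profile);
  //   if (!launched) {
  //     // The backend does not support this configuration; fall back or fail.
  //   }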

  template <typename ElementType>
  port::Status PrepareForCtcLoss(Stream* stream,
                                 const RnnStateTensorDescriptor& probs_desc,
                                 DeviceMemory<ElementType> probs_data,
                                 const RnnStateTensorDescriptor& grads_desc,
                                 absl::Span<const int> labels_data,
                                 absl::Span<const int> labels_lengths_data,
                                 absl::Span<const int> input_lengths_data,
                                 ScratchAllocator* workspace_allocator,
                                 DeviceMemory<uint8>* scratch_memory) {
    return DoPrepareForCtcLoss(stream, ToDataType<ElementType>::value,
                               probs_desc, grads_desc, labels_data,
                               labels_lengths_data, input_lengths_data,
                               workspace_allocator, scratch_memory);
  }

  // Enqueues a CTC Loss operation onto the stream.
  //
  // Arguments:
  //  stream: pointer to the stream where this operation should be enqueued.
  //  element_type: data type of the input tensors.
  //  probs_desc: specifies the shape and the data layout of the input tensor.
  //  probs_data: the device memory region that contains the input tensor.
  //  labels_data: the device memory region that contains the labels_value
  //    tensor.
  //  labels_lengths_data: the device memory region that contains the
  //    labels_lengths tensor.
  //  input_lengths_data: the device memory region that contains the seq_lengths
  //    tensor.
  //  costs_data: the device memory region in which to place the costs tensor.
  //  grads_desc: specifies the shape and the data layout of the grads tensor.
  //  grads_data: the device memory region in which to place the grads tensor.
  //  scratch_memory: the temporary workspace memory used by this operation,
  //    typically reserved beforehand via PrepareForCtcLoss. The caller is
  //    responsible for keeping it alive long enough for this operation.
  virtual port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                                 const RnnStateTensorDescriptor& probs_desc,
                                 const DeviceMemoryBase probs_data,
                                 absl::Span<const int> labels_data,
                                 absl::Span<const int> labels_lengths_data,
                                 absl::Span<const int> input_lengths_data,
                                 DeviceMemoryBase costs_data,
                                 const RnnStateTensorDescriptor& grads_desc,
                                 DeviceMemoryBase grads_data,
                                 DeviceMemory<uint8> scratch_memory);

  template <typename ElementType>
  bool DoCtcLoss(Stream* stream,
                 const dnn::RnnStateTensorDescriptor& probs_desc,
                 const DeviceMemory<ElementType>& probs_data,
                 absl::Span<const int> labels_data,
                 absl::Span<const int> labels_lengths_data,
                 absl::Span<const int> input_lengths_data,
                 DeviceMemory<ElementType>* costs_data,
                 const dnn::RnnStateTensorDescriptor& grads_desc,
                 DeviceMemory<ElementType>* grads_data,
                 DeviceMemory<uint8>* scratch_memory) {
    return IsStatusOk(
        DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
                  probs_data, labels_data, labels_lengths_data,
                  input_lengths_data, *costs_data, grads_desc, *grads_data,
                  *scratch_memory),
        false);
  }
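
  // Example (sketch only): the two-step CTC loss flow using the templated
  // helpers above, first reserving scratch space and then launching the loss
  // computation. The descriptors, the int spans holding labels and lengths,
  // the allocator, and the DeviceMemory<float> buffers are assumed to exist;
  // names are illustrative only.
  //
  //   DeviceMemory<uint8> scratch;
  //   port::Status prep = dnn->PrepareForCtcLoss(
  //       stream, probs_desc, probs_data, grads_desc, labels, label_lengths,
  //       seq_lengths, workspace_allocator, &scratch);
  //   if (prep.ok()) {
  //     dnn->DoCtcLoss(stream, probs_desc, probs_data, labels, label_lengths,
  //                    seq_lengths, &costs, grads_desc, &grads, &scratch);
  //   }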

  // Transforms a tensor into another tensor with a different layout and/or data
  // type.
  //
  // Arguments:
  //  stream: pointer to the stream where this operation should be enqueued.
  //  input_desc: specifies the shape and the data layout of the input tensor.
  //  input_type: the data type of the input tensor.
  //  input_data: the device memory region that contains the input tensor.
  //  output_desc: specifies the shape and the data layout of the output tensor.
  //  output_type: the data type of the output tensor.
  //  scale: an element-wise scaling factor to apply.
  //  output_data: the device memory region in which to place the output tensor.
  virtual bool DoTransformTensor(Stream* stream,
                                 const dnn::BatchDescriptor& input_desc,
                                 dnn::DataType input_type,
                                 const DeviceMemoryBase& input_data,
                                 const dnn::BatchDescriptor& output_desc,
                                 dnn::DataType output_type, float scale,
                                 DeviceMemoryBase* output_data) {
    return false;
  }
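
  // Example (sketch only): converting a float tensor described by one
  // BatchDescriptor into a half-precision tensor with a different layout,
  // without scaling. The descriptors and device buffers are assumed to be set
  // up by the caller; whether this particular conversion is supported depends
  // on the backend.
  //
  //   bool ok = dnn->DoTransformTensor(
  //       stream, float_desc, dnn::DataType::kFloat, float_data,
  //       half_desc, dnn::DataType::kHalf, /*scale=*/1.0f, &half_data);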

  // Enqueues a fused convolution+bias+activation operation onto the stream.
  //
  // Arguments (all borrowed):
  //
  //  stream: borrowed pointer to the stream that the 'fusion' operation should
  //  be enqueued onto.
  //
  //  conv_input_descriptor: dimensions of the convolution input layer.
  //  conv_input_data: device memory which contains the convolution input.
  //
  //  filter_descriptor: dimensions of the convolution filter.
  //  filter_data: device memory which contains the convolution filter weights.
  //
  //  convolution_descriptor: stride, padding, and dilation of the convolution
  //  filter.
  //
  //  bias_descriptor: dimensions of the bias layer.
  //  bias_data: device memory region containing the biases to add to the
  //  convolution output.
  //
  //  activation_mode: Type of activation to perform.
  //
  //  output_descriptor: dimensions of the output layer.
  //  output_data: device memory region in which to place the fusion result.
  //
  //  output_profile_result: the output profile result for this call.
  //         The profiling is only enabled when this is not nullptr.
  //
  virtual bool DoFusedConvolutionBiasActivation(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<float>& conv_input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<float>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }
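
  // Example (sketch only): fusing a float convolution with a bias add and a
  // ReLU activation. The descriptors and DeviceMemory<float> buffers are
  // assumed to have been configured by the caller; profiling is disabled by
  // passing nullptr.
  //
  //   bool ok = dnn->DoFusedConvolutionBiasActivation(
  //       stream, conv_input_desc, conv_input, filter_desc, filter_weights,
  //       conv_desc, bias_desc, biases, dnn::ActivationMode::kRelu,
  //       output_desc, &output, /*output_profile_result=*/nullptr);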

  // Enqueues a fused batchnorm+activation (inference) operation onto the
  // stream.
  //
  // Arguments (all borrowed):
  //
  //  stream: borrowed pointer to the stream that the 'fusion' operation should
  //  be enqueued onto.
  //
  //  x_descriptor: dimensions of the batchnorm input layer.
  //  x_data: device memory which contains the batchnorm input.
  //
  //  scale_offset_mean_variance_descriptor:
  //      dimensions of the scale/offset/mean/variance tensor.
  //  scale_data: device memory which contains the scale input.
  //  offset_data: device memory which contains the offset input.
  //  mean_data: device memory which contains the mean input.
  //  variance_data: device memory which contains the variance input.
  //  epsilon: the epsilon value to use in the batchnorm calculation.
  //
  //  activation_mode: Type of activation to perform.
  //
  //  y_data: device memory region in which to place the fusion result.
  //
  //  output_profile_result: the output profile result for this call.
  //         The profiling is only enabled when this is not nullptr.
  //
  virtual bool DoFusedBatchNormActivationInference(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<float>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& mean_data,
      const DeviceMemory<float>& variance_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }

  virtual bool DoFusedBatchNormActivationInference(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<Eigen::half>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& mean_data,
      const DeviceMemory<float>& variance_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }
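
  // Example (sketch only): inference-mode fused batchnorm + ReLU on float
  // data, using previously computed population mean and variance. All
  // descriptors and buffers are assumed to exist; names are illustrative only.
  //
  //   bool ok = dnn->DoFusedBatchNormActivationInference(
  //       stream, x_desc, x, scale_offset_desc, scale, offset, mean, variance,
  //       /*epsilon=*/1e-3, dnn::ActivationMode::kRelu, &y,
  //       /*output_profile_result=*/nullptr);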

  // Enqueues a fused batchnorm+activation (training-fwd) operation onto the
  // stream.
  //
  // Arguments (all borrowed):
  //
  //  stream: borrowed pointer to the stream that the 'fusion' operation should
  //  be enqueued onto.
  //
  //  x_descriptor: dimensions of the batchnorm input layer.
  //  x_data: device memory which contains the batchnorm input.
  //
  //  scale_offset_mean_variance_descriptor:
  //      dimensions of the scale/offset/mean/variance tensor.
  //  scale_data: device memory which contains the scale input.
  //  offset_data: device memory which contains the offset input.
  //  epsilon: the epsilon value to use in the batchnorm calculation.
  //
  //  activation_mode: Type of activation to perform.
  //
  //  y_data: device memory region in which to place the fusion result.
  //  batch_mean_data: device memory in which to place the batch mean output.
  //  batch_var_data: device memory in which to place the batch variance output.
  //  saved_mean_data: device memory in which to save the mean for bwd pass.
  //  saved_var_data: device memory in which to save the variance for bwd pass.
  //
  //  output_profile_result: the output profile result for this call.
  //         The profiling is only enabled when this is not nullptr.
  //
  virtual bool DoFusedBatchNormActivationForward(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<float>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
      DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
      DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }

  virtual bool DoFusedBatchNormActivationForward(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<Eigen::half>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
      DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
      DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }
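
  // Example (sketch only): training-mode fused batchnorm + ReLU forward pass
  // on float data. The batch statistics and the saved mean/variance written
  // here are later consumed by DoFusedBatchNormActivationBackward. Buffers and
  // descriptors are assumed to be prepared by the caller.
  //
  //   bool ok = dnn->DoFusedBatchNormActivationForward(
  //       stream, x_desc, x, scale_offset_desc, scale, offset,
  //       /*epsilon=*/1e-3, dnn::ActivationMode::kRelu, &y, &batch_mean,
  //       &batch_var, &saved_mean, &saved_var,
  //       /*output_profile_result=*/nullptr);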

  // Enqueues a fused batchnorm+activation (training-bwd) operation onto the
  // stream.
  //
  // Arguments (all borrowed):
  //
  //  stream: borrowed pointer to the stream that the 'fusion' operation should
  //  be enqueued onto.
  //
  //  y_act_backprop_descriptor: dimensions of the backprop input from the
  //      previous layer.
  //  y_act_backprop_data: device memory which contains the backprop input.
  //  y_act_data: device memory which contains the actv-fwd output data.
  //
  //  activation_mode: actv-fwd type.
  //
  //  x_bn_data: device memory which contains the batchnorm input (x) from the
  //      fwd pass.
  //
  //  scale_offset_mean_variance_descriptor:
  //      dimensions of the scale/offset/mean/variance tensor.
  //  scale_data: device memory which contains the scale input.
  //  offset_data: device memory which contains the offset input.
  //  saved_mean_data: device memory which contains the saved mean from the fwd
  //      pass.
  //  saved_var_data: device memory which contains the saved variance from the
  //      fwd pass.
  //
  //  x_bn_backprop_data: device memory region in which to place the backprop
  //      data from this layer.
  //  scale_backprop_data: device memory in which to place the scale backprop
  //      output.
  //  offset_backprop_data: device memory in which to place the offset backprop
  //      output.
  //
  //  output_profile_result: the output profile result for this call.
  //         The profiling is only enabled when this is not nullptr.
  //
  virtual bool DoFusedBatchNormActivationBackward(
      Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<float>& y_act_backprop_data,
      const DeviceMemory<float>& y_act_data,
      dnn::ActivationMode activation_mode, const DeviceMemory<float>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& saved_mean_data,
      const DeviceMemory<float>& saved_var_data,
      DeviceMemory<float>* x_bn_backprop_data,
      DeviceMemory<float>* scale_backprop_data,
      DeviceMemory<float>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }

  virtual bool DoFusedBatchNormActivationBackward(
      Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<Eigen::half>& y_act_backprop_data,
      const DeviceMemory<Eigen::half>& y_act_data,
      dnn::ActivationMode activation_mode,
      const DeviceMemory<Eigen::half>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& saved_mean_data,
      const DeviceMemory<float>& saved_var_data,
      DeviceMemory<Eigen::half>* x_bn_backprop_data,
      DeviceMemory<float>* scale_backprop_data,
      DeviceMemory<float>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result) {
    return false;
  }
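
  // Example (sketch only): the backward pass matching the training-mode
  // forward call above, consuming the saved mean and variance it produced.
  // Buffers and descriptors are assumed to be prepared by the caller; names
  // are illustrative only.
  //
  //   bool ok = dnn->DoFusedBatchNormActivationBackward(
  //       stream, y_act_backprop_desc, y_act_backprop, y_act,
  //       dnn::ActivationMode::kRelu, x, scale_offset_desc, scale, offset,
  //       saved_mean, saved_var, &x_backprop, &scale_backprop,
  //       &offset_backprop, /*output_profile_result=*/nullptr);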

 protected:
  // Returns whether status is 'ok', and potentially logs the error.
  static bool IsStatusOk(const port::Status& status, bool report_error);

 private:
  virtual port::Status DoPrepareForConvolution(
      ConvolutionKind kind, DataType element_type, Stream* stream,
      const BatchDescriptor& batch_descriptor, DeviceMemoryBase input_data,
      const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data,
      const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
      const ConvolutionDescriptor& convolution_descriptor,
      const AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) {
    *algorithm_desc = {};
    *scratch_memory = {};
    return port::Status::OK();
  }

  virtual port::Status DoPrepareForCtcLoss(
      Stream* stream, DataType element_type,
      const RnnStateTensorDescriptor& probs_desc,
      const RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator,
      DeviceMemory<uint8>* scratch_memory) {
    *scratch_memory = {};
    return port::Status::OK();
  }

  SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
};

}  // namespace dnn
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_DNN_H_