• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20 
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
23 
24 #include <complex>
25 #include <functional>
26 #include <memory>
27 #include <type_traits>
28 
29 #include "absl/synchronization/mutex.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/stream_executor/blas.h"
33 #include "tensorflow/stream_executor/device_memory.h"
34 #include "tensorflow/stream_executor/dnn.h"
35 #include "tensorflow/stream_executor/event.h"
36 #include "tensorflow/stream_executor/fft.h"
37 #include "tensorflow/stream_executor/host_or_device_scalar.h"
38 #include "tensorflow/stream_executor/kernel.h"
39 #include "tensorflow/stream_executor/launch_dim.h"
40 #include "tensorflow/stream_executor/lib/array_slice.h"
41 #include "tensorflow/stream_executor/platform/port.h"
42 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
43 #include "tensorflow/stream_executor/temporary_memory_manager.h"
44 
45 #if GOOGLE_CUDA
46 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
47 #endif  // GOOGLE_CUDA
48 
49 namespace stream_executor {
50 
51 namespace host {
52 class HostBlas;
53 class HostFft;
54 class HostRng;
55 class HostTimer;
56 }  // namespace host
57 
58 namespace ocl {
59 class CLBlas;
60 }  // namespace ocl
61 
62 namespace internal {
63 class StreamInterface;
64 }  // namespace internal
65 
66 class DeviceMemoryBase;
67 template <typename ElemT>
68 class DeviceMemory;
69 
70 class Timer;
71 
72 namespace dnn {
73 class BatchDescriptor;
74 class FilterDescriptor;
75 class ConvolutionDescriptor;
76 class ProfileResult;
77 class AlgorithmDesc;
78 }  // namespace dnn
79 
80 class StreamExecutor;
81 class ScratchAllocator;
82 
83 namespace detail {
84 
// Helper class to prevent a template function argument from being deduced.
// This is identical to std::type_identity in C++20.
template <typename T>
struct NonDeduced {
  using type = T;
};
// Shorthand for NonDeduced<T>::type; use in a parameter position to force
// callers to deduce T from a different argument (or supply it explicitly).
template <typename T>
using NonDeducedType = typename NonDeduced<T>::type;
93 
94 }  // namespace detail
95 
96 // Convert a type to the corresponding QuantizedActivationMode.
97 template <typename ElementType>
98 struct Quantization;
99 
100 // Represents a stream of dependent computations on a GPU device.
101 //
102 // The operations within a stream execute linearly and asynchronously until
103 // BlockHostUntilDone() is invoked, which synchronously joins host code with
104 // the execution of the stream.
105 //
106 // If any given operation fails when entraining work for the stream, ok() will
107 // indicate that an error has occurred. After initialization, once a stream is
108 // !ok(), it will never be ok().
109 //
110 // Thread-safe post-initialization.
111 class Stream {
112  public:
113   // Instantiate a stream tied to parent as a platform executor. Work
114   // entrained onto this stream will be launched/managed on that
115   // StreamExecutor's platform.
116   explicit Stream(StreamExecutor *parent);
117 
118   // Test only. Use an externally-populated value (like a mock) for the
119   // platform-specific stream implementation.
120   Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
121 
122   // Deallocates any stream resources that the parent StreamExecutor has
123   // bestowed
124   // upon this object.
125   ~Stream();
126 
127   // Returns whether any errors have occurred while entraining work for this
128   // stream.
ok()129   bool ok() const { return !InErrorState(); }
130 
131   // Retrieves execution status back into the stream from the underlying
132   // implementation without blocking the stream.
133   //
134   // Normally, Stream::BlockHostUntilDone is used to get execution status.
135   // However, some devices use out-of-band mechnanisms to ensure their streams
136   // have finished on-device work, without needing to block the streams. (These
137   // devices should also override AllowsSyncOnCompletion to return false.) For
138   // these devices, this method can be used after work is finished to retrieve
139   // execution status.
140   port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_);
141 
142   // Initialize the stream. This must be performed before entraining any other
143   // operations.
144   Stream &Init() TF_LOCKS_EXCLUDED(mu_);
145 
146   // Initializes timer t via the StreamExecutor.
147   Stream &InitTimer(Timer *t);
148 
149   // Convenience wrapper around Init() and InitTimer().
150   Stream &InitWithTimer(Timer *t);
151 
152   // Get or create a sub-stream from this stream. If there is any sub-stream in
153   // the pool that can be reused then just return this sub-stream.  Otherwise
154   // create a new sub-stream.
155   //
156   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
157   Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_);
158 
159   // Return the sub-stream back to the host stream so that it can be reused
160   // later. Sub-streams that are !ok() will not be reused.
161   //
162   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
163   void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_);
164 
165   // Allocate temporary memories. The stream will deallocate them when blocked
166   // or destroyed.
167   template <typename T>
168   port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
169   AllocateTemporaryArray(uint64 element_count);
170 
171   // Entrains onto the stream of operations: a kernel launch with the given
172   // (variadic) parameters for the invocation. These arguments can be things
173   // like DeviceMemory or primitive types such as int. What arguments you may
174   // pass to a given kernel are noted as the template parameters to the
175   // TypedKernel type that the machocc compiler generates.
176   //
177   // Template parameters:
178   //  Params...   The type list of formal parameters that the typed kernel
179   //              expects, which is matched against Args...
180   //  Args...     The deduced type list for passed actual arguments
181   //
182   // Implementation: A compile-time compatibility check is performed that has
183   // some leniency versus an exact parameter pack match -- for example,
184   // `const DeviceMemory<T>` is considered "pack compatible" with a
185   // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
186   // perfect forwarding support without rvalue references. It also attempts to
187   // spit out helpful static_assert error traces with information as to the
188   // argument number and types that were mismatched.
189   template <typename... Params, typename... Args>
190   Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
191                      const TypedKernel<Params...> &kernel, Args... args);
192 
193   // Record a "start" event for the interval timer at this point in the
194   // stream's execution (relative to the previously and subsequently enqueued
195   // items in the stream's execution). Streams may be started/stopped multiple
196   // times.
197   Stream &ThenStartTimer(Timer *t);
198 
199   // Record a "stop" event for the interval timer at this point in the
200   // stream's execution. See also Stream::ThenStartTimer.
201   Stream &ThenStopTimer(Timer *t);
202 
203   // TODO(leary) If work is added to the stream that is being depended upon,
204   //              then what? Have to describe what happens.
205   template <typename... Params>
ThenWaitFor(Stream * other,Params...more_streams)206   Stream &ThenWaitFor(Stream *other, Params... more_streams) {
207     return ThenWaitFor(more_streams...).ThenWaitFor(other);
208   }
209 
210   // Create a dependency for this stream's next work on the other stream
211   // completing. Does not take ownership of other, and other must not be
212   // null.
213   //
214   // Checks that a stream does not wait for itself, and it is up to the
215   // user to guarantee that a stream does not come to wait on itself in a
216   // cyclic manner; in that case, behavior is undefined.
217   //
218   // N.B. Base recursion case for the variadic ThenWaitFor.
219   Stream &ThenWaitFor(Stream *other);
220 
221   // Waits for all streams values in others.
222   // Checks that there is no shallow circular wait (i.e. that "this" is not in
223   // others)
224   template <typename P>
ThenWaitFor(P others)225   Stream &ThenWaitFor(P others) {
226     for (auto &stream : *others) {
227       CHECK_NE(stream.get(), this);
228       ThenWaitFor(stream.get());
229     }
230     return *this;
231   }
232 
233   // Waits for an event object to be set.
234   // Note that ThenRecordEvent must have been called on the event before
235   // you call this function; otherwise the event will be considered complete
236   // and this wait will do nothing.
237   Stream &ThenWaitFor(Event *event);
238 
239   // Inserts the specified event into the end of this stream. Once the stream
240   // has processed all events prior to the insertion point, the event will be
241   // marked as completed.
242   // The stream does not take ownership of event - meaning that event's lifetime
243   // must extend past the point at which it is marked complete!
244   Stream &ThenRecordEvent(Event *event);
245 
246   ////////////////
247   // DNN support
248   //
249   // See DnnSupport::* for comments on the following methods.
250 
251   Stream &ThenBatchNormalizationForward(
252       const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
253       const DeviceMemory<float> &offset,
254       const DeviceMemory<float> &estimated_mean,
255       const DeviceMemory<float> &estimated_variance,
256       const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
257       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
258       const double exponential_average_factor,
259       dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
260       DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
261       DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
262       bool is_training,
263       ScratchAllocator *reserve_space_allocator,
264       ScratchAllocator *workspace_allocator);
265 
266   Stream &ThenBatchNormalizationBackward(
267       const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
268       const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
269       const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
270       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
271       DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
272       DeviceMemory<float> *offset_backprop,
273       DeviceMemory<uint8> *reserve_space_data,
274       ScratchAllocator *workspace_allocator);
275 
276   Stream &ThenBatchNormalizationForward(
277       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
278       const DeviceMemory<float> &offset,
279       const DeviceMemory<float> &estimated_mean,
280       const DeviceMemory<float> &estimated_variance,
281       const DeviceMemory<Eigen::half> &side_input,
282       const dnn::BatchDescriptor &x_desc,
283       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
284       const double exponential_average_factor,
285       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
286       DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
287       DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
288       bool is_training, ScratchAllocator *reserve_space_allocator,
289       ScratchAllocator *workspace_allocator);
290 
291   Stream &ThenBatchNormalizationBackward(
292       const DeviceMemory<Eigen::half> &y_backprop,
293       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
294       const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
295       const dnn::BatchDescriptor &x_desc,
296       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
297       DeviceMemory<Eigen::half> *x_backprop,
298       DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
299       DeviceMemory<uint8> *reserve_space_data,
300       ScratchAllocator *workspace_allocator);
301 
302   Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
303                        const DeviceMemory<float> &input_data,
304                        const dnn::FilterDescriptor &filter_descriptor,
305                        const DeviceMemory<float> &filter_data,
306                        const dnn::ConvolutionDescriptor &convolution_descriptor,
307                        const dnn::BatchDescriptor &output_descriptor,
308                        DeviceMemory<float> *output);
309 
310   Stream &ThenConvolveQuantized(
311       const dnn::BatchDescriptor &input_descriptor,
312       const DeviceMemory<float> &input_data,
313       const dnn::FilterDescriptor &filter_descriptor,
314       const DeviceMemory<int8> &filter_coefficients,
315       const DeviceMemory<float> &coefficient_scales,
316       const dnn::ConvolutionDescriptor &convolution_descriptor,
317       const dnn::BatchDescriptor &output_descriptor,
318       DeviceMemory<float> *output_data);
319 
320   Stream &ThenConvolveQuantized(
321       const dnn::BatchDescriptor &input_descriptor,
322       const DeviceMemory<float> &input_data,
323       const dnn::FilterDescriptor &filter_descriptor,
324       const DeviceMemory<int16> &filter_coefficients,
325       const DeviceMemory<float> &coefficient_scales,
326       const dnn::ConvolutionDescriptor &convolution_descriptor,
327       const dnn::BatchDescriptor &output_descriptor,
328       DeviceMemory<float> *output_data);
329 
330   template <typename InputType, typename OutputType>
ConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<InputType> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputType> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputType> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)331   port::Status ConvolveWithAlgorithm(
332       const dnn::BatchDescriptor &input_descriptor,
333       const DeviceMemory<InputType> &input_data,
334       const dnn::FilterDescriptor &filter_descriptor,
335       const DeviceMemory<InputType> &filter_data,
336       const dnn::ConvolutionDescriptor &convolution_descriptor,
337       const dnn::BatchDescriptor &output_descriptor,
338       DeviceMemory<OutputType> *output, ScratchAllocator *scratch_allocator,
339       const dnn::AlgorithmConfig &algorithm_config,
340       dnn::ProfileResult *output_profile_result) {
341     DeviceMemory<uint8> scratch_memory;
342     dnn::AlgorithmDesc algorithm_desc;
343     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
344       TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
345           dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
346           filter_descriptor, filter_data, output_descriptor, *output,
347           convolution_descriptor, algorithm_config, scratch_allocator,
348           &algorithm_desc, &scratch_memory));
349       return dnn->DoConvolve(
350           dnn::ConvolutionKind::FORWARD, dnn::ToDataType<InputType>::value,
351           dnn::ToDataType<OutputType>::value, this, input_descriptor,
352           input_data, filter_descriptor, filter_data, output_descriptor,
353           *output, convolution_descriptor, algorithm_desc, scratch_memory,
354           output_profile_result);
355     }
356     return port::UnimplementedError("DNN library is not found.");
357   }
358 
359   template <typename InputType, typename OutputType>
ConvolveWithExecutionPlan(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<InputType> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputType> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputType> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & plan_config,dnn::ProfileResult * output_profile_result)360   port::Status ConvolveWithExecutionPlan(
361       const dnn::BatchDescriptor &input_descriptor,
362       const DeviceMemory<InputType> &input_data,
363       const dnn::FilterDescriptor &filter_descriptor,
364       const DeviceMemory<InputType> &filter_data,
365       const dnn::ConvolutionDescriptor &convolution_descriptor,
366       const dnn::BatchDescriptor &output_descriptor,
367       DeviceMemory<OutputType> *output, ScratchAllocator *scratch_allocator,
368       const dnn::AlgorithmConfig &plan_config,
369       dnn::ProfileResult *output_profile_result) {
370 #if GOOGLE_CUDA
371     dnn::DnnSupport *dnn = parent_->AsDnn();
372     if (dnn) {
373       gpu::CudnnSupport *cudnn_dnn = dynamic_cast<gpu::CudnnSupport *>(dnn);
374       return cudnn_dnn->DoConvolveWithExecutionPlan(
375           dnn::ConvolutionKind::FORWARD, dnn::ToDataType<InputType>::value,
376           dnn::ToDataType<OutputType>::value, this, input_descriptor,
377           input_data, filter_descriptor, filter_data, output_descriptor,
378           *output, convolution_descriptor, plan_config, scratch_allocator,
379           output_profile_result);
380     }
381 #endif  // GOOGLE_CUDA
382     return port::UnimplementedError("DNN library is not found.");
383   }
384 
385   template <typename InputT, typename ScaleT, typename SideInputT,
386             typename BiasT, typename OutputT>
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<InputT> & conv_input_data,ScaleT conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputT> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<SideInputT> & side_input_data,ScaleT side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<BiasT> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputT> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)387   port::Status FusedConvolveWithAlgorithm(
388       const dnn::BatchDescriptor &conv_input_descriptor,
389       const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale,
390       const dnn::FilterDescriptor &filter_descriptor,
391       const DeviceMemory<InputT> &filter_data,
392       const dnn::ConvolutionDescriptor &convolution_descriptor,
393       const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale,
394       const dnn::BatchDescriptor &bias_descriptor,
395       const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode,
396       const dnn::BatchDescriptor &output_descriptor,
397       DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator,
398       const dnn::AlgorithmConfig &algorithm_config,
399       dnn::ProfileResult *output_profile_result) {
400     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
401       return dnn->DoFusedConvolve(
402           this, dnn::ToDataType<InputT>::value,
403           dnn::ToDataType<SideInputT>::value, dnn::ToDataType<BiasT>::value,
404           dnn::ToDataType<OutputT>::value, conv_input_descriptor,
405           conv_input_data, conv_input_scale, filter_descriptor, filter_data,
406           convolution_descriptor, side_input_data, side_input_scale,
407           bias_descriptor, biases, activation_mode, output_descriptor, *output,
408           scratch_allocator, algorithm_config, output_profile_result);
409     }
410     return port::UnimplementedError("DNN library is not found.");
411   }
412 
413   template <typename InputT, typename ScaleT, typename SideInputT,
414             typename BiasT, typename OutputT>
FusedConvolveWithExecutionPlan(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<InputT> & conv_input_data,ScaleT conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputT> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<SideInputT> & side_input_data,ScaleT side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<BiasT> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputT> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)415   port::Status FusedConvolveWithExecutionPlan(
416       const dnn::BatchDescriptor &conv_input_descriptor,
417       const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale,
418       const dnn::FilterDescriptor &filter_descriptor,
419       const DeviceMemory<InputT> &filter_data,
420       const dnn::ConvolutionDescriptor &convolution_descriptor,
421       const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale,
422       const dnn::BatchDescriptor &bias_descriptor,
423       const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode,
424       const dnn::BatchDescriptor &output_descriptor,
425       DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator,
426       const dnn::AlgorithmConfig &algorithm_config,
427       dnn::ProfileResult *output_profile_result) {
428 #if GOOGLE_CUDA
429     dnn::DnnSupport *dnn = parent_->AsDnn();
430     if (dnn) {
431       gpu::CudnnSupport *cudnn_dnn = dynamic_cast<gpu::CudnnSupport *>(dnn);
432       return cudnn_dnn->DoFusedConvolveWithExecutionPlan(
433           this, dnn::ToDataType<InputT>::value, conv_input_descriptor,
434           conv_input_data, conv_input_scale, filter_descriptor, filter_data,
435           convolution_descriptor, side_input_data, side_input_scale,
436           bias_descriptor, biases, activation_mode, output_descriptor, *output,
437           scratch_allocator, algorithm_config, output_profile_result);
438     }
439 #endif  // GOOGLE_CUDA
440     return port::UnimplementedError("DNN library is not found.");
441   }
442 
443   Stream &ThenSeparableConvolve(
444       const dnn::BatchDescriptor &input_descriptor,
445       const DeviceMemory<float> &input_data,
446       const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
447       const DeviceMemory<float> &first_weights,
448       const DeviceMemory<float> &second_weights,
449       const dnn::ConvolutionDescriptor &convolution_descriptor,
450       const dnn::BatchDescriptor &output_descriptor,
451       DeviceMemory<float> *output);
452 
453   template <typename ElementType>
ConvolveBackwardDataWithExecutionPlan(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<ElementType> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<ElementType> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & plan_config,dnn::ProfileResult * output_profile_result)454   port::Status ConvolveBackwardDataWithExecutionPlan(
455       const dnn::FilterDescriptor &filter_descriptor,
456       const DeviceMemory<ElementType> &filter_data,
457       const dnn::BatchDescriptor &output_descriptor,
458       DeviceMemory<ElementType> backward_output_data,
459       const dnn::ConvolutionDescriptor &convolution_descriptor,
460       const dnn::BatchDescriptor &input_descriptor,
461       DeviceMemory<ElementType> *backward_input_data,
462       ScratchAllocator *scratch_allocator,
463       const dnn::AlgorithmConfig &plan_config,
464       dnn::ProfileResult *output_profile_result) {
465 #if GOOGLE_CUDA
466     dnn::DnnSupport *dnn = parent_->AsDnn();
467     if (dnn) {
468       gpu::CudnnSupport *cudnn_dnn = dynamic_cast<gpu::CudnnSupport *>(dnn);
469       return cudnn_dnn->DoConvolveWithExecutionPlan(
470           dnn::ConvolutionKind::BACKWARD_DATA,
471           dnn::ToDataType<ElementType>::value,
472           dnn::ToDataType<ElementType>::value, this, input_descriptor,
473           *backward_input_data, filter_descriptor, filter_data,
474           output_descriptor, backward_output_data, convolution_descriptor,
475           plan_config, scratch_allocator, output_profile_result);
476     }
477 #endif  // GOOGLE_CUDA
478     return port::UnimplementedError("DNN library is not found.");
479   }
480 
481   template <typename ElementType>
ConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<ElementType> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<ElementType> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)482   port::Status ConvolveBackwardDataWithAlgorithm(
483       const dnn::FilterDescriptor &filter_descriptor,
484       const DeviceMemory<ElementType> &filter_data,
485       const dnn::BatchDescriptor &output_descriptor,
486       DeviceMemory<ElementType> backward_output_data,
487       const dnn::ConvolutionDescriptor &convolution_descriptor,
488       const dnn::BatchDescriptor &input_descriptor,
489       DeviceMemory<ElementType> *backward_input_data,
490       ScratchAllocator *scratch_allocator,
491       const dnn::AlgorithmConfig &algorithm_config,
492       dnn::ProfileResult *output_profile_result) {
493     DeviceMemory<uint8> scratch_memory;
494     dnn::AlgorithmDesc algorithm_desc;
495     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
496       TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
497           dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
498           *backward_input_data, filter_descriptor, filter_data,
499           output_descriptor, backward_output_data, convolution_descriptor,
500           algorithm_config, scratch_allocator, &algorithm_desc,
501           &scratch_memory));
502       return dnn->DoConvolve(
503           dnn::ConvolutionKind::BACKWARD_DATA,
504           dnn::ToDataType<ElementType>::value,
505           dnn::ToDataType<ElementType>::value, this, input_descriptor,
506           *backward_input_data, filter_descriptor, filter_data,
507           output_descriptor, backward_output_data, convolution_descriptor,
508           algorithm_desc, scratch_memory, output_profile_result);
509     }
510     return port::UnimplementedError("DNN library is not found.");
511   }
512 
513   template <typename ElementType>
ConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<ElementType> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<ElementType> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)514   port::Status ConvolveBackwardFilterWithAlgorithm(
515       const dnn::BatchDescriptor &input_descriptor,
516       const DeviceMemory<ElementType> &input_data,
517       const dnn::BatchDescriptor &output_descriptor,
518       DeviceMemory<ElementType> backward_output_data,
519       const dnn::ConvolutionDescriptor &convolution_descriptor,
520       const dnn::FilterDescriptor &filter_descriptor,
521       DeviceMemory<ElementType> *backward_filter_data,
522       ScratchAllocator *scratch_allocator,
523       const dnn::AlgorithmConfig &algorithm_config,
524       dnn::ProfileResult *output_profile_result) {
525     DeviceMemory<uint8> scratch_memory;
526     dnn::AlgorithmDesc algorithm_desc;
527     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
528       TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
529           dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
530           input_data, filter_descriptor, *backward_filter_data,
531           output_descriptor, backward_output_data, convolution_descriptor,
532           algorithm_config, scratch_allocator, &algorithm_desc,
533           &scratch_memory));
534       return dnn->DoConvolve(
535           dnn::ConvolutionKind::BACKWARD_FILTER,
536           dnn::ToDataType<ElementType>::value,
537           dnn::ToDataType<ElementType>::value, this, input_descriptor,
538           input_data, filter_descriptor, *backward_filter_data,
539           output_descriptor, backward_output_data, convolution_descriptor,
540           algorithm_desc, scratch_memory, output_profile_result);
541     }
542     return port::UnimplementedError("DNN library is not found.");
543   }
544 
545   template <typename ElementType>
ConvolveBackwardFilterWithExecutionPlan(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<ElementType> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<ElementType> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & plan_config,dnn::ProfileResult * output_profile_result)546   port::Status ConvolveBackwardFilterWithExecutionPlan(
547       const dnn::BatchDescriptor &input_descriptor,
548       const DeviceMemory<ElementType> &input_data,
549       const dnn::BatchDescriptor &output_descriptor,
550       DeviceMemory<ElementType> backward_output_data,
551       const dnn::ConvolutionDescriptor &convolution_descriptor,
552       const dnn::FilterDescriptor &filter_descriptor,
553       DeviceMemory<ElementType> *backward_filter_data,
554       ScratchAllocator *scratch_allocator,
555       const dnn::AlgorithmConfig &plan_config,
556       dnn::ProfileResult *output_profile_result) {
557 #if GOOGLE_CUDA
558     dnn::DnnSupport *dnn = parent_->AsDnn();
559     if (dnn) {
560       gpu::CudnnSupport *cudnn_dnn = dynamic_cast<gpu::CudnnSupport *>(dnn);
561       return cudnn_dnn->DoConvolveWithExecutionPlan(
562           dnn::ConvolutionKind::BACKWARD_FILTER,
563           dnn::ToDataType<ElementType>::value,
564           dnn::ToDataType<ElementType>::value, this, input_descriptor,
565           input_data, filter_descriptor, *backward_filter_data,
566           output_descriptor, backward_output_data, convolution_descriptor,
567           plan_config, scratch_allocator, output_profile_result);
568     }
569 #endif  // GOOGLE_CUDA
570     return port::UnimplementedError("DNN library is not found.");
571   }
572 
  // Enqueues a matrix multiply of input_data by weights, writing the product
  // to *output_data.
  Stream &ThenMatMul(const DeviceMemory<float> &input_data,
                     const DeviceMemory<float> &weights,
                     const dnn::BatchDescriptor &input_dimensions,
                     const dnn::BatchDescriptor &output_dimensions,
                     DeviceMemory<float> *output_data);

  // Variants of ThenMatMul taking quantized weights (int8 below, int16 in the
  // next overload) together with per-weight scale factors; input and output
  // activations remain float.
  Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
                              const DeviceMemory<int8> &weights,
                              const DeviceMemory<float> &weight_scales,
                              const dnn::BatchDescriptor &input_dimensions,
                              const dnn::BatchDescriptor &output_dimensions,
                              DeviceMemory<float> *output_data);

  Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
                              const DeviceMemory<int16> &weights,
                              const DeviceMemory<float> &weight_scales,
                              const dnn::BatchDescriptor &input_dimensions,
                              const dnn::BatchDescriptor &output_dimensions,
                              DeviceMemory<float> *output_data);

  // Enqueues an elementwise bias addition of biases to input_data, writing
  // the result to *output_data.
  Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
                      const DeviceMemory<float> &biases,
                      const dnn::BatchDescriptor &dimensions,
                      DeviceMemory<float> *output_data);
597 
  // Enqueues a pooling forward pass described by pooling_dimensions.
  // Overloads are provided for double, float, Eigen::half and int8 element
  // types; workspace_allocator, when non-null, may supply scratch space.
  Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                          const dnn::BatchDescriptor &input_dimensions,
                          const DeviceMemory<double> &input_data,
                          const dnn::BatchDescriptor &output_dimensions,
                          DeviceMemory<double> *output_data,
                          ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                          const dnn::BatchDescriptor &input_dimensions,
                          const DeviceMemory<float> &input_data,
                          const dnn::BatchDescriptor &output_dimensions,
                          DeviceMemory<float> *output_data,
                          ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                          const dnn::BatchDescriptor &input_dimensions,
                          const DeviceMemory<Eigen::half> &input_data,
                          const dnn::BatchDescriptor &output_dimensions,
                          DeviceMemory<Eigen::half> *output_data,
                          ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                          const dnn::BatchDescriptor &input_dimensions,
                          const DeviceMemory<int8> &input_data,
                          const dnn::BatchDescriptor &output_dimensions,
                          DeviceMemory<int8> *output_data,
                          ScratchAllocator *workspace_allocator = nullptr);

  // Enqueues the corresponding pooling backward pass, producing
  // *output_diff_data from the forward input/output activations and the
  // incoming gradient input_diff_data.  Overloads for double, float and
  // Eigen::half.  NOTE(review): the exact roles of input_diff_data vs.
  // output_diff_data follow the DNN library's DoPoolBackward contract —
  // confirm there before relying on them.
  Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                           const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<double> &input_data,
                           const dnn::BatchDescriptor &output_dimensions,
                           const DeviceMemory<double> &output_data,
                           const DeviceMemory<double> &input_diff_data,
                           DeviceMemory<double> *output_diff_data,
                           ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                           const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::BatchDescriptor &output_dimensions,
                           const DeviceMemory<float> &output_data,
                           const DeviceMemory<float> &input_diff_data,
                           DeviceMemory<float> *output_diff_data,
                           ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                           const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<Eigen::half> &input_data,
                           const dnn::BatchDescriptor &output_dimensions,
                           const DeviceMemory<Eigen::half> &output_data,
                           const DeviceMemory<Eigen::half> &input_diff_data,
                           DeviceMemory<Eigen::half> *output_diff_data,
                           ScratchAllocator *workspace_allocator = nullptr);
652 
  // Enqueues a normalization of input_data per normalize_descriptor, writing
  // the result to *output_data.
  Stream &ThenNormalizeWithDimensions(
      const dnn::NormalizeDescriptor &normalize_descriptor,
      const dnn::BatchDescriptor &dimensions,
      const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);

  // Backward pass of ThenNormalizeWithDimensions: given the forward input
  // (raw_data), its normalized output, and the gradient w.r.t. that output,
  // produces the gradient w.r.t. the raw input.
  Stream &ThenNormalizeBackwardWithDimensions(
      const dnn::NormalizeDescriptor &normalize_descriptor,
      const dnn::BatchDescriptor &dimensions,
      const DeviceMemory<float> &raw_data,
      const DeviceMemory<float> &normalized_data,
      const DeviceMemory<float> &normalized_variable_gradient,
      DeviceMemory<float> *raw_variable_gradient,
      ScratchAllocator *workspace_allocator = nullptr);

  // Applies the given activation function elementwise to input_data, writing
  // the result to *output_data.
  Stream &ThenActivate(dnn::ActivationMode activation_mode,
                       const dnn::BatchDescriptor &dimensions,
                       const DeviceMemory<float> &input_data,
                       DeviceMemory<float> *output_data);

  // Same as ThenActivate, but also takes an options argument that can be used
  // for platform-specific option flags.
  Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
                                  const dnn::BatchDescriptor &dimensions,
                                  const DeviceMemory<float> &input_data,
                                  DeviceMemory<float> *output_data,
                                  uint64 options);

  // Concatenates the given inputs along the depth dimension into
  // *output_data.
  Stream &ThenDepthConcatenate(
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      DeviceMemory<float> *output_data);

  // Concatenates the given inputs along a spatial dimension selected by
  // concat_direction.
  Stream &ThenSpaceConcatenate(
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      DeviceMemory<float> *output_data,
      dnn::SpaceConcatenateMode concat_direction);
690 
  // Change the layout of the data by shrinking one dimension (or set of
  // dimensions) and growing another dimension (or set of dimensions), while
  // keeping the total number of data elements constant, and maintaining the
  // current data ordering.
  Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
                      const DeviceMemory<float> &input_data,
                      const dnn::BatchDescriptor &output_dimensions,
                      DeviceMemory<float> *output_data);

  // Depth to space takes an X by Y image with depth D*M² and changes it to an
  // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
  // the input image is changed to an MxM contiguous area in the output image,
  // with the values being laid out in raster order specified by
  // DepthToSpaceLayout, and will have a new depth of D.
  // See the DoDepthToSpace comment for more information.
  Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::DepthToSpaceLayout &depth_to_space_layout,
                           const int sqrt_depth_reduction,
                           DeviceMemory<float> *output_data);

  // Space to depth is the inverse of depth to space. Space to depth takes each
  // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
  // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
  // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
  // data elements is not changed.
  Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::DepthToSpaceLayout &space_to_depth_layout,
                           const int sqrt_depth_increase,
                           DeviceMemory<float> *output_data);

  // Applies `operation` elementwise across the given inputs, writing the
  // result to *output_data.
  Stream &ThenElementwiseOperate(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);

  // Variant of ThenElementwiseOperate with integer scaling: each input is
  // associated with an entry of input_multiplicands and the result is scaled
  // down by output_divisor.  See the DNN library's corresponding Do* method
  // for the exact scaling semantics.
  Stream &ThenElementwiseOperateScaledQuantized(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<int> input_multiplicands, int output_divisor,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);

  // Pads the X and Y dimensions of input_data by the given per-edge amounts,
  // writing the padded tensor to *output_data.
  Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
                    const DeviceMemory<float> &input_data, int64_t left_pad,
                    int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
                    DeviceMemory<float> *output_data);

  // Trims the X and Y dimensions of input_data by the given per-edge amounts,
  // writing the sliced tensor to *output_data.
  Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
                      const DeviceMemory<float> &input_data, int64_t left_trim,
                      int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
                      DeviceMemory<float> *output_data);

  // Grows the input tensor by replicating the X and Y dimensions. The batch and
  // depth/feature_map dimensions are unchanged. Currently, the input tensor is
  // limited to X=1 and Y=1.
  Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
                          const DeviceMemory<float> &input_data,
                          int64_t replicate_x, int64_t replicate_y,
                          DeviceMemory<float> *output_data);

  // See DnnSupport::DoMemcpyD2HQuantized.  `size` is the destination size in
  // bytes; `mode` selects the quantized representation written to host_dst.
  Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
                                 dnn::QuantizedActivationMode mode,
                                 void *host_dst, uint64 size);
760 
761   // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
762   // and uses the Quantization trait to call the generic version of
763   // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
764   template <typename ElementType>
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,port::MutableArraySlice<ElementType> host_dst)765   Stream &ThenMemcpyD2HQuantized(
766       const DeviceMemory<float> &gpu_unquantized_src,
767       port::MutableArraySlice<ElementType> host_dst) {
768     return ThenMemcpyD2HQuantized(
769         gpu_unquantized_src, Quantization<ElementType>::kModeId,
770         host_dst.data(), host_dst.size() * sizeof(ElementType));
771   }
772 
  // See DnnSupport::DoMemcpyH2DQuantized.  `size` is the source size in
  // bytes; `mode` describes the quantized representation held in host_src.
  Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
                                 dnn::QuantizedActivationMode mode,
                                 DeviceMemory<float> *gpu_unquantized_dst);
777 
778   // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
779   // and uses the Quantization trait to call the generic version of
780   // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
781   template <typename ElementType>
ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,DeviceMemory<float> * gpu_unquantized_dst)782   Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
783                                  DeviceMemory<float> *gpu_unquantized_dst) {
784     return ThenMemcpyH2DQuantized(
785         host_src.data(), host_src.size() * sizeof(ElementType),
786         Quantization<ElementType>::kModeId, gpu_unquantized_dst);
787   }
788 
  // See DnnSupport::DoCopyHostBuffer2Device.  Enqueues a copy from a
  // HostBuffer into device memory.
  Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
                                    DeviceMemory<float> *gpu_unquantized_dst);

  // See DnnSupport::DoCopyDevice2HostBuffer.  Enqueues a copy from device
  // memory into a HostBuffer.
  Stream &ThenCopyDevice2HostBuffer(
      const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
796 
797   /////////////////
798   // BLAS support
799 
800   // See BlasSupport::DoBlasAsum.
801   Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
802                        int incx, DeviceMemory<float> *result);
803   Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
804                        int incx, DeviceMemory<double> *result);
805   Stream &ThenBlasAsum(uint64 elem_count,
806                        const DeviceMemory<std::complex<float>> &x, int incx,
807                        DeviceMemory<float> *result);
808   Stream &ThenBlasAsum(uint64 elem_count,
809                        const DeviceMemory<std::complex<double>> &x, int incx,
810                        DeviceMemory<double> *result);
811 
812   // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
813   // present in DeviceMemory, it must be an execution-time constant (i.e. a
814   // value
815   // that the stream does not change or populate during the course of
816   // execution). The value is effectively captured at stream-enqueue time.
817   Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
818                        const DeviceMemory<float> &x, int incx,
819                        DeviceMemory<float> *y, int incy);
820   Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
821                        const DeviceMemory<double> &x, int incx,
822                        DeviceMemory<double> *y, int incy);
823   Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
824                        const DeviceMemory<std::complex<float>> &x, int incx,
825                        DeviceMemory<std::complex<float>> *y, int incy);
826   Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
827                        const DeviceMemory<std::complex<double>> &x, int incx,
828                        DeviceMemory<std::complex<double>> *y, int incy);
829 
830   // See BlasSupport::DoBlasCopy.
831   Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
832                        int incx, DeviceMemory<float> *y, int incy);
833   Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
834                        int incx, DeviceMemory<double> *y, int incy);
835   Stream &ThenBlasCopy(uint64 elem_count,
836                        const DeviceMemory<std::complex<float>> &x, int incx,
837                        DeviceMemory<std::complex<float>> *y, int incy);
838   Stream &ThenBlasCopy(uint64 elem_count,
839                        const DeviceMemory<std::complex<double>> &x, int incx,
840                        DeviceMemory<std::complex<double>> *y, int incy);
841 
842   // See BlasSupport::DoBlasDot.
843   Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
844                       const DeviceMemory<float> &y, int incy,
845                       DeviceMemory<float> *result);
846   Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
847                       int incx, const DeviceMemory<double> &y, int incy,
848                       DeviceMemory<double> *result);
849 
850   // See BlasSupport::DoBlasDotc.
851   Stream &ThenBlasDotc(uint64 elem_count,
852                        const DeviceMemory<std::complex<float>> &x, int incx,
853                        const DeviceMemory<std::complex<float>> &y, int incy,
854                        DeviceMemory<std::complex<float>> *result);
855   Stream &ThenBlasDotc(uint64 elem_count,
856                        const DeviceMemory<std::complex<double>> &x, int incx,
857                        const DeviceMemory<std::complex<double>> &y, int incy,
858                        DeviceMemory<std::complex<double>> *result);
859 
860   // See BlasSupport::DoBlasDotu.
861   Stream &ThenBlasDotu(uint64 elem_count,
862                        const DeviceMemory<std::complex<float>> &x, int incx,
863                        const DeviceMemory<std::complex<float>> &y, int incy,
864                        DeviceMemory<std::complex<float>> *result);
865   Stream &ThenBlasDotu(uint64 elem_count,
866                        const DeviceMemory<std::complex<double>> &x, int incx,
867                        const DeviceMemory<std::complex<double>> &y, int incy,
868                        DeviceMemory<std::complex<double>> *result);
869 
  // See BlasSupport::DoBlasNrm2.
  Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *result);
  Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *result);
  Stream &ThenBlasNrm2(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<float> *result);
  Stream &ThenBlasNrm2(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<double> *result);

  // See BlasSupport::DoBlasRot.
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
                      DeviceMemory<float> *y, int incy, float c, float s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
                      DeviceMemory<double> *y, int incy, double c, double s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
                      int incx, DeviceMemory<std::complex<float>> *y, int incy,
                      float c, float s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
                      int incx, DeviceMemory<std::complex<double>> *y, int incy,
                      double c, double s);

  // See BlasSupport::DoBlasRotg.
  Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
                       DeviceMemory<float> *c, DeviceMemory<float> *s);
  Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
                       DeviceMemory<double> *c, DeviceMemory<double> *s);
  Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
                       DeviceMemory<std::complex<float>> *b,
                       DeviceMemory<float> *c,
                       DeviceMemory<std::complex<float>> *s);
  Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
                       DeviceMemory<std::complex<double>> *b,
                       DeviceMemory<double> *c,
                       DeviceMemory<std::complex<double>> *s);

  // See BlasSupport::DoBlasRotm.
  Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
                       DeviceMemory<float> *y, int incy,
                       const DeviceMemory<float> &param);
  Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
                       DeviceMemory<double> *y, int incy,
                       const DeviceMemory<double> &param);

  // See BlasSupport::DoBlasRotmg.
  Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
                        DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
                        DeviceMemory<float> *param);
  Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
                        DeviceMemory<double> *x1,
                        const DeviceMemory<double> &y1,
                        DeviceMemory<double> *param);

  // See BlasSupport::DoBlasScal.  Complex vectors accept either a real or a
  // complex scaling factor (csscal/zdscal vs. cscal/zscal style overloads).
  Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasScal(uint64 elem_count, float alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, double alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasSwap.
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
                       int incx, DeviceMemory<std::complex<float>> *y,
                       int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
                       int incx, DeviceMemory<std::complex<double>> *y,
                       int incy);
950 
  // See BlasSupport::DoBlasIamax.  The index of the selected element is
  // written to *result.
  Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        DeviceMemory<int> *result);

  // See BlasSupport::DoBlasIamin.
  Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        DeviceMemory<int> *result);

  // See BlasSupport::DoBlasGbmv (banded matrix-vector multiply; kl/ku are the
  // sub-/super-diagonal counts of the band).
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &x, int incx,
                       float beta, DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &x, int incx,
                       double beta, DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
996 
  // See BlasSupport::DoBlasGemv.
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // Variants of ThenBlasGemv that additionally report timing information for
  // the enqueued operation via *output_profile_result.
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    float alpha, const DeviceMemory<float> &a,
                                    int lda, const DeviceMemory<float> &x,
                                    int incx, float beta,
                                    DeviceMemory<float> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    double alpha, const DeviceMemory<double> &a,
                                    int lda, const DeviceMemory<double> &x,
                                    int incx, double beta,
                                    DeviceMemory<double> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &x, int incx,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &x, int incx,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
      int incy, blas::ProfileResult *output_profile_result);
1043 
1044   // See BlasSupport::DoBlasGer.
1045   Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
1046                       const DeviceMemory<float> &x, int incx,
1047                       const DeviceMemory<float> &y, int incy,
1048                       DeviceMemory<float> *a, int lda);
1049   Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
1050                       const DeviceMemory<double> &x, int incx,
1051                       const DeviceMemory<double> &y, int incy,
1052                       DeviceMemory<double> *a, int lda);
1053 
1054   // See BlasSupport::DoBlasGerc.
1055   Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
1056                        const DeviceMemory<std::complex<float>> &x, int incx,
1057                        const DeviceMemory<std::complex<float>> &y, int incy,
1058                        DeviceMemory<std::complex<float>> *a, int lda);
1059   Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
1060                        const DeviceMemory<std::complex<double>> &x, int incx,
1061                        const DeviceMemory<std::complex<double>> &y, int incy,
1062                        DeviceMemory<std::complex<double>> *a, int lda);
1063 
1064   // See BlasSupport::DoBlasGeru.
1065   Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
1066                        const DeviceMemory<std::complex<float>> &x, int incx,
1067                        const DeviceMemory<std::complex<float>> &y, int incy,
1068                        DeviceMemory<std::complex<float>> *a, int lda);
1069   Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
1070                        const DeviceMemory<std::complex<double>> &x, int incx,
1071                        const DeviceMemory<std::complex<double>> &y, int incy,
1072                        DeviceMemory<std::complex<double>> *a, int lda);
1073 
  // See BlasSupport::DoBlasHbmv.
  // Hermitian banded matrix-vector product (BLAS HBMV):
  // y <- alpha * a * x + beta * y, where a is an n-by-n Hermitian band
  // matrix with k off-diagonals; only the uplo triangle is referenced.
  Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasHemv.
  // Hermitian matrix-vector product (BLAS HEMV):
  // y <- alpha * a * x + beta * y, with a Hermitian and only the uplo
  // triangle referenced.
  Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
1101 
  // See BlasSupport::DoBlasHer.
  // Hermitian rank-1 update (BLAS HER): a <- alpha * x * conj(x)^T + a.
  // Note alpha is real-valued, which keeps a Hermitian.
  Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<std::complex<float>> &x, int incx,
                      DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<std::complex<double>> &x, int incx,
                      DeviceMemory<std::complex<double>> *a, int lda);

  // See BlasSupport::DoBlasHer2.
  // Hermitian rank-2 update (BLAS HER2):
  // a <- alpha * x * conj(y)^T + conj(alpha) * y * conj(x)^T + a.
  Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);
1121 
  // See BlasSupport::DoBlasHpmv.
  // Hermitian packed matrix-vector product (BLAS HPMV):
  // y <- alpha * a * x + beta * y, where a is supplied in packed form (ap
  // holds the uplo triangle column-by-column with no gaps).
  Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &ap,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &ap,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasHpr.
  // Hermitian packed rank-1 update (BLAS HPR):
  // ap <- alpha * x * conj(x)^T + ap, with real alpha.
  Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<std::complex<float>> &x, int incx,
                      DeviceMemory<std::complex<float>> *ap);
  Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<std::complex<double>> &x, int incx,
                      DeviceMemory<std::complex<double>> *ap);

  // See BlasSupport::DoBlasHpr2.
  // Hermitian packed rank-2 update (BLAS HPR2):
  // ap <- alpha * x * conj(y)^T + conj(alpha) * y * conj(x)^T + ap.
  Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *ap);
  Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *ap);
1155 
  // See BlasSupport::DoBlasSbmv.
  // Symmetric banded matrix-vector product (BLAS SBMV):
  // y <- alpha * a * x + beta * y, where a is symmetric with k
  // off-diagonals; only the uplo triangle is referenced.
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  // See BlasSupport::DoBlasSpmv.
  // Symmetric packed matrix-vector product (BLAS SPMV):
  // y <- alpha * a * x + beta * y, with a supplied in packed form in ap.
  Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &ap,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &ap,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
1175 
  // See BlasSupport::DoBlasSpr.
  // Symmetric packed rank-1 update (BLAS SPR): ap <- alpha * x * x^T + ap.
  Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      DeviceMemory<float> *ap);
  Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      DeviceMemory<double> *ap);

  // See BlasSupport::DoBlasSpr2.
  // Symmetric packed rank-2 update (BLAS SPR2):
  // ap <- alpha * x * y^T + alpha * y * x^T + ap.
  Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       const DeviceMemory<float> &y, int incy,
                       DeviceMemory<float> *ap);
  Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       const DeviceMemory<double> &y, int incy,
                       DeviceMemory<double> *ap);
1193 
  // See BlasSupport::DoBlasSymv.
  // Symmetric matrix-vector product (BLAS SYMV):
  // y <- alpha * a * x + beta * y; only the uplo triangle of a is
  // referenced.
  Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  // See BlasSupport::DoBlasSyr.
  // Symmetric rank-1 update (BLAS SYR): a <- alpha * x * x^T + a.
  Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      DeviceMemory<float> *a, int lda);
  Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      DeviceMemory<double> *a, int lda);

  // See BlasSupport::DoBlasSyr2.
  // Symmetric rank-2 update (BLAS SYR2):
  // a <- alpha * x * y^T + alpha * y * x^T + a.
  Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       const DeviceMemory<float> &y, int incy,
                       DeviceMemory<float> *a, int lda);
  Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       const DeviceMemory<double> &y, int incy,
                       DeviceMemory<double> *a, int lda);
1221 
  // See BlasSupport::DoBlasTbmv.
  // Triangular banded matrix-vector product (BLAS TBMV):
  // x <- op(a) * x, where a is triangular with k off-diagonals and op() is
  // selected by trans; x is updated in place.
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasTbsv.
  // Triangular banded solve (BLAS TBSV): solves op(a) * x = b where b is
  // passed in via x and overwritten in place with the solution.
  Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);
1257 
  // See BlasSupport::DoBlasTpmv.
  // Triangular packed matrix-vector product (BLAS TPMV):
  // x <- op(a) * x, with the triangular matrix supplied in packed form in
  // ap; x is updated in place.
  Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &ap, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &ap,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &ap,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasTpsv.
  // Triangular packed solve (BLAS TPSV): solves op(a) * x = b with a in
  // packed form; b is passed in via x and overwritten with the solution.
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &ap, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &ap,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &ap,
                       DeviceMemory<std::complex<double>> *x, int incx);
1293 
  // See BlasSupport::DoBlasTrmv.
  // Triangular matrix-vector product (BLAS TRMV): x <- op(a) * x, updating
  // x in place; op() is selected by trans, diag says whether the diagonal
  // is assumed unit.
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasTrsv.
  // Triangular solve (BLAS TRSV): solves op(a) * x = b where b is passed in
  // via x and overwritten in place with the solution.
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);
1329 
1330   template <typename InputType>
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const DeviceMemory<InputType> & a,int lda,const DeviceMemory<InputType> & b,int ldb,DeviceMemory<InputType> * c,int ldc)1331   port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1332                             uint64 m, uint64 n, uint64 k,
1333                             const DeviceMemory<InputType> &a, int lda,
1334                             const DeviceMemory<InputType> &b, int ldb,
1335                             DeviceMemory<InputType> *c, int ldc) {
1336     InputType alpha{1.0};
1337     InputType beta{0.0};
1338     return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
1339                         ldc);
1340   }
1341 
  // Enqueues a general matrix-matrix multiply (GEMM) on this stream:
  // c = alpha * op(a) * op(b) + beta * c, with op() selected by
  // transa/transb. Dispatches to the BLAS plugin attached to the parent
  // StreamExecutor and returns its status, or an internal error if no BLAS
  // support is available.
  //
  // InputType is the element type of a/b/c; ConstantType is the type of the
  // alpha/beta scaling factors. For Eigen::half inputs the constants may be
  // float or half; otherwise the two types must match (enforced below).
  template <typename InputType, typename ConstantType>
  port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                            uint64 m, uint64 n, uint64 k, ConstantType alpha,
                            const DeviceMemory<InputType> &a, int lda,
                            const DeviceMemory<InputType> &b, int ldb,
                            ConstantType beta, DeviceMemory<InputType> *c,
                            int ldc) {
    static_assert(!std::is_same<InputType, Eigen::half>::value ||
                      std::is_same<ConstantType, float>::value ||
                      std::is_same<ConstantType, Eigen::half>::value,
                  "If input is Eigen::half, constant has to be either "
                  "Eigen::half or float");
    static_assert(
        std::is_same<InputType, Eigen::half>::value ||
            std::is_same<InputType, ConstantType>::value,
        "If input is not Eigen::half, constant and input types have to match");
    static_assert(
        std::is_same<InputType, Eigen::half>::value ||
            std::is_same<InputType, Eigen::bfloat16>::value ||
            std::is_same<InputType, float>::value ||
            std::is_same<InputType, double>::value ||
            std::is_same<InputType, std::complex<float>>::value ||
            std::is_same<InputType, std::complex<double>>::value,
        "Input can be half, bf16, float, double, std::complex<float> or "
        "std::complex<double>");
    blas::BlasSupport *blas = parent()->AsBlas();
    if (!blas) {
      return port::InternalError(
          "Attempting to perform BLAS operation using "
          "StreamExecutor without BLAS support");
    }

    // The plugin interface takes alpha/beta as type-erased pointers. When
    // ConstantType is Eigen::half the values are presumably widened into the
    // float *_storage slots and the pointers redirected there — see
    // UpcastHalfToFloat.
    void *alpha_ptr = &alpha;
    void *beta_ptr = &beta;
    float alpha_storage, beta_storage;
    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                    &beta_storage);

    return blas->DoBlasGemm(this, transa, transb, m, n, k,
                            blas::ToDataType<InputType>::value, alpha_ptr, a,
                            lda, b, ldb, beta_ptr, c, ldc);
  }
1384 
  // GEMM overloads that additionally report information about the executed
  // call through *output_profile_result (used, e.g., when autotuning).
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, float alpha,
                                    const DeviceMemory<Eigen::half> &a, int lda,
                                    const DeviceMemory<Eigen::half> &b, int ldb,
                                    float beta, DeviceMemory<Eigen::half> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, float alpha,
                                    const DeviceMemory<float> &a, int lda,
                                    const DeviceMemory<float> &b, int ldb,
                                    float beta, DeviceMemory<float> *c, int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, double alpha,
                                    const DeviceMemory<double> &a, int lda,
                                    const DeviceMemory<double> &b, int ldb,
                                    double beta, DeviceMemory<double> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
1422 
1423   template <typename InputType, typename OutputType>
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const DeviceMemory<InputType> & a,int lda,const DeviceMemory<InputType> & b,int ldb,DeviceMemory<OutputType> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1424   port::Status ThenBlasGemmWithAlgorithm(
1425       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1426       uint64 k, const DeviceMemory<InputType> &a, int lda,
1427       const DeviceMemory<InputType> &b, int ldb, DeviceMemory<OutputType> *c,
1428       int ldc, blas::ComputationType computation_type,
1429       blas::AlgorithmType algorithm,
1430       blas::ProfileResult *output_profile_result) {
1431     OutputType alpha{1};
1432     OutputType beta{0};
1433     return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b,
1434                                      ldb, beta, c, ldc, computation_type,
1435                                      algorithm, output_profile_result);
1436   }
1437 
  // GEMM with an explicitly chosen algorithm and computation type:
  // c = alpha * op(a) * op(b) + beta * c. The
  // InputType/OutputType/ConstantType combination is validated against
  // computation_type by CheckTypesForExtendedBlas before dispatching to the
  // BLAS plugin.
  //
  // When output_profile_result is non-null, the dispatch status is reported
  // through the profile result rather than the return value, so a sweep over
  // candidate algorithms does not abort on an unsupported one.
  template <typename InputType, typename OutputType, typename ConstantType>
  port::Status ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
      const DeviceMemory<InputType> &b, int ldb, ConstantType beta,
      DeviceMemory<OutputType> *c, int ldc,
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result) {
    TF_RETURN_IF_ERROR(
        CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
            computation_type));

    blas::BlasSupport *blas = parent()->AsBlas();
    if (!blas) {
      return port::InternalError(
          "Attempting to perform BLAS operation using "
          "StreamExecutor without BLAS support");
    }

    // Type-erase alpha/beta for the plugin interface; half constants are
    // widened to float via the *_storage slots (see UpcastHalfToFloat).
    void *alpha_ptr = &alpha;
    void *beta_ptr = &beta;
    float alpha_storage, beta_storage;
    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                    &beta_storage);

    port::Status st = blas->DoBlasGemmWithAlgorithm(
        this, transa, transb, m, n, k, alpha_ptr, a,
        blas::ToDataType<InputType>::value, lda, b,
        blas::ToDataType<InputType>::value, ldb, beta_ptr, c,
        blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm,
        output_profile_result);
    if (output_profile_result) {
      // The error is recorded in the profile.
      return port::Status::OK();
    }
    return st;
  }
1475 
  // Strided-batched GEMM with an explicitly chosen algorithm and computation
  // type: runs batch_count GEMMs whose consecutive operand matrices are
  // separated by stride_a/stride_b/stride_c elements within the a/b/c
  // buffers. Type combinations are validated against computation_type by
  // CheckTypesForExtendedBlas before dispatch.
  //
  // As with ThenBlasGemmWithAlgorithm, when output_profile_result is
  // non-null the dispatch status is reported through the profile result
  // rather than the return value.
  template <typename InputType, typename OutputType, typename ConstantType>
  port::Status ThenBlasGemmStridedBatchedWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
      int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
      int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
      int64_t stride_c, int batch_count, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result) {
    TF_RETURN_IF_ERROR(
        CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
            computation_type));

    blas::BlasSupport *blas = parent()->AsBlas();
    if (!blas) {
      return port::InternalError(
          "Attempting to perform BLAS operation using "
          "StreamExecutor without BLAS support");
    }
    // Type-erase alpha/beta for the plugin interface; half constants are
    // widened to float via the *_storage slots (see UpcastHalfToFloat).
    void *alpha_ptr = &alpha;
    void *beta_ptr = &beta;
    float alpha_storage, beta_storage;
    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                    &beta_storage);
    port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm(
        this, transa, transb, m, n, k, alpha_ptr, a,
        blas::ToDataType<InputType>::value, stride_a, lda, b,
        blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c,
        blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
        computation_type, algorithm, output_profile_result);
    if (output_profile_result) {
      // The error is recorded in the profile.
      return port::Status::OK();
    }
    return st;
  }
1512 
  // See BlasSupport::DoBlasGemmBatched.
  // Runs batch_count independent GEMMs over per-matrix device-memory
  // pointers: c[i] = alpha * op(a[i]) * op(b[i]) + beta * c[i].
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, float alpha,
                              const port::ArraySlice<DeviceMemory<float> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<float> *> &b,
                              int ldb, float beta,
                              const port::ArraySlice<DeviceMemory<float> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, double alpha,
                              const port::ArraySlice<DeviceMemory<double> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<double> *> &b,
                              int ldb, double beta,
                              const port::ArraySlice<DeviceMemory<double> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count);
  // Batched-GEMM variants that obtain any temporary device memory the
  // implementation needs from scratch_allocator.
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
      double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
1588 
  // Strided-batched GEMM: runs batch_count GEMMs whose consecutive operand
  // matrices are separated by stride_a/stride_b/stride_c elements within the
  // a/b/c buffers (see DoBlasGemmStridedBatched). Returns an error status if
  // no BLAS support is available.
  //
  // ConstantType must be float for Eigen::half/bfloat16 inputs; otherwise it
  // must match InputType (enforced by the static_assert).
  template <typename InputType, typename ConstantType>
  port::Status ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
      int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
      int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
      int64_t stride_c, int batch_count) {
    static_assert(((std::is_same<InputType, Eigen::half>::value ||
                    std::is_same<InputType, Eigen::bfloat16>::value) &&
                   std::is_same<ConstantType, float>::value) ||
                      ((std::is_same<InputType, float>::value ||
                        std::is_same<InputType, Eigen::half>::value ||
                        std::is_same<InputType, Eigen::bfloat16>::value ||
                        std::is_same<InputType, double>::value ||
                        std::is_same<InputType, std::complex<float>>::value ||
                        std::is_same<InputType, std::complex<double>>::value) &&
                       std::is_same<ConstantType, InputType>::value),
                  "Input or constant type mismatch");
    blas::BlasSupport *blas = parent()->AsBlas();
    if (!blas) {
      return port::InternalError(
          "Attempting to perform BLAS operation using "
          "StreamExecutor without BLAS support");
    }

    // Type-erase alpha/beta for the plugin interface; half constants are
    // widened to float via the *_storage slots (see UpcastHalfToFloat).
    void *alpha_ptr = &alpha;
    void *beta_ptr = &beta;
    float alpha_storage, beta_storage;
    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                    &beta_storage);

    return blas->DoBlasGemmStridedBatched(
        this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
        alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
        stride_c, batch_count);
  }
1625 
  // See BlasSupport::DoBlasHemm.
  // Only complex element types are provided (single and double precision).
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasHerk.
  // Note: alpha and beta are real-valued even though a and c are complex.
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       float beta, DeviceMemory<std::complex<float>> *c,
                       int ldc);
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       double beta, DeviceMemory<std::complex<double>> *c,
                       int ldc);

  // See BlasSupport::DoBlasHer2k.
  // Note: beta is real-valued while alpha is complex.
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        float beta, DeviceMemory<std::complex<float>> *c,
                        int ldc);
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        double beta, DeviceMemory<std::complex<double>> *c,
                        int ldc);

  // See BlasSupport::DoBlasSymm.
  // Overloads for float, double, and their complex counterparts.
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &b, int ldb,
                       float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &b, int ldb,
                       double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyrk.
  // Overloads for float, double, and their complex counterparts.
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha, const DeviceMemory<float> &a,
                       int lda, float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha, const DeviceMemory<double> &a,
                       int lda, double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyr2k.
  // Overloads for float, double, and their complex counterparts.
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, float alpha, const DeviceMemory<float> &a,
                        int lda, const DeviceMemory<float> &b, int ldb,
                        float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, double alpha, const DeviceMemory<double> &a,
                        int lda, const DeviceMemory<double> &b, int ldb,
                        double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc);
1727 
  // See BlasSupport::DoBlasTrmm.
  // b is updated in place; overloads for float, double, and complex types.
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);

  // See BlasSupport::DoBlasTrsm.
  // b is updated in place; overloads for float, double, and complex types.
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);
1767 
  // See BlasSupport::DoBlatLtMatmul.
  // Note that we prevent alpha and beta from being used to deduce CType so that
  // they can be constructed implicitly from values of type CType. Without this,
  // type deduction would fail when this function is called with a value of type
  // CType for alpha or beta.
  // Thin forwarding wrapper: all work happens in ThenBlasLtMatmulImpl, which
  // receives the arguments unchanged (bias defaults to an empty DeviceMemory,
  // output_profile_result to nullptr).
  template <typename ABType, typename CType>
  Stream &ThenBlasLtMatmul(
      const blas::IBlasLtMatmulPlan *plan,
      const detail::NonDeducedType<HostOrDeviceScalar<CType>> &alpha,
      const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
      const detail::NonDeducedType<HostOrDeviceScalar<CType>> &beta,
      DeviceMemory<CType> *c, ScratchAllocator *scratch_allocator,
      const blas::IBlasLtMatmulAlgorithm *algorithm,
      const DeviceMemory<CType> &bias = {},
      blas::ProfileResult *output_profile_result = nullptr) {
    return ThenBlasLtMatmulImpl(plan, alpha, a, b, beta, c, scratch_allocator,
                                algorithm, bias, output_profile_result);
  }
1786 
  // See FftSupport::DoFft.
  // Overloads cover complex->complex, real->complex, and complex->real
  // transforms in single and double precision; the transform itself is
  // described by the supplied fft::Plan.
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<float> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<double> *output);
1804 
  // Makes the RNG use the provided value as the basis for further generation.
  // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
  // sources of seed data if the default (high quality) sources are not
  // desired.
  // For most use cases, this function will not be necessary; each provided
  // back-end implementation will be appropriately seeded by default.
  // At a minimum 16 bytes of data are required in the seed buffer.
  //
  // To seed with good (non-reproducible) data:
  //   File* f = File::Open("/dev/random", "r");
  //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
  //   < error checking >
  //   stream.ThenSetRngSeed(seed_data, bytes_read);
  //
  // To seed with reproducible data:
  //   uint64_t seed_data[2] = { <data> };
  //   stream.ThenSetRngSeed(seed_data, 16);
  Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);

  // Populates the memory indicated by values with uniform-random-distribution
  // values. TODO(leary) seeding API/description
  //
  // Uses the type and size of the DeviceMemory to infer what data should be
  // populated.
  Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
  // Populates values with Gaussian-distributed samples of the given mean and
  // standard deviation.
  Stream &ThenPopulateRandGaussian(float mean, float stddev,
                                   DeviceMemory<float> *values);
  Stream &ThenPopulateRandGaussian(double mean, double stddev,
                                   DeviceMemory<double> *values);

  // Entrain onto the stream: a memcpy to a host destination from a GPU source
  // of the given target size. host_dst must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Entrain onto the stream: a memcpy to a GPU destination from a host source
  // of the given target size. host_src must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                     uint64 size);
1851 
1852   // Alternative interface for memcpying from device to host that takes an
1853   // array slice. Checks that the destination size can accommodate the host
1854   // slice size.
1855   template <typename T>
ThenMemcpyD2H(const DeviceMemory<T> & gpu_src,port::MutableArraySlice<T> host_dst)1856   Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1857                         port::MutableArraySlice<T> host_dst) {
1858     auto host_size = host_dst.size() * sizeof(T);
1859     CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1860     return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1861   }
1862 
1863   // Alternative interface for memcpying from host to device that takes an
1864   // array slice. Checks that the destination size can accommodate the host
1865   // slice size.
1866   template <typename T>
ThenMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemory<T> * gpu_dst)1867   Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
1868                         DeviceMemory<T> *gpu_dst) {
1869     auto host_size = host_src.size() * sizeof(T);
1870     CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1871     return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1872   }
1873 
  // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
  // of the given target size. gpu_src/dst must be pointers to GPU memory and
  // peer access must be enabled between their owning StreamExecutors.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);
1879 
1880   // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
1881   // ensuring that the host pointer isn't getting confused accidentally with a
1882   // device pointer if you're not doing metaprogramming against the API.
ThenMemcpyD2D(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)1883   Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
1884                         const DeviceMemoryBase &gpu_src, uint64 size) {
1885     return ThenMemcpy(gpu_dst, gpu_src, size);
1886   }
1887 
  // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
  // The location must not be null.
  Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);

  // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
  // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible
  // by 4). The location must not be null.
  Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
1896 
  // Enqueue a forward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnForward for more details.
  // Overloads are provided for Eigen::half, float, and double element types;
  // output_* buffers are written, is_training selects whether reserve space
  // is allocated via reserve_space_allocator.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<Eigen::half> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<Eigen::half> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<Eigen::half> &input_c_data,
                         const DeviceMemory<Eigen::half> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<Eigen::half> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<Eigen::half> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<Eigen::half> *output_c_data,
                         bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<float> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<float> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<float> &input_c_data,
                         const DeviceMemory<float> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<float> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<float> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<float> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<double> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<double> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<double> &input_c_data,
                         const DeviceMemory<double> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<double> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<double> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<double> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);
1956 
  // Enqueue a backward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnBackward for more details.
  // Overloads are provided for Eigen::half, float, and double element types;
  // the *_backprop_data outputs receive the computed gradients.
  Stream &ThenRnnBackward(
      const dnn::RnnDescriptor &rnn_desc,
      const dnn::RnnSequenceTensorDescriptor &input_desc,
      const DeviceMemory<Eigen::half> &input_data,
      const DeviceMemory<int> &seq_lengths_data,
      const dnn::RnnStateTensorDescriptor &input_h_desc,
      const DeviceMemory<Eigen::half> &input_h_data,
      const dnn::RnnStateTensorDescriptor &input_c_desc,
      const DeviceMemory<Eigen::half> &input_c_data,
      const DeviceMemory<Eigen::half> &params,
      const dnn::RnnSequenceTensorDescriptor &output_desc,
      const DeviceMemory<Eigen::half> &output_data,
      const dnn::RnnStateTensorDescriptor &output_h_desc,
      const DeviceMemory<Eigen::half> &output_h_data,
      const dnn::RnnStateTensorDescriptor &output_c_desc,
      const DeviceMemory<Eigen::half> &output_c_data,
      const DeviceMemory<Eigen::half> &output_backprop_data,
      const DeviceMemory<Eigen::half> &output_h_backprop_data,
      const DeviceMemory<Eigen::half> &output_c_backprop_data,
      DeviceMemory<Eigen::half> *input_backprop_data,
      DeviceMemory<Eigen::half> *input_h_backprop_data,
      DeviceMemory<Eigen::half> *input_c_backprop_data,
      DeviceMemory<Eigen::half> *params_backprop_data,
      DeviceMemory<uint8> *reserve_space_data,
      ScratchAllocator *workspace_allocator,
      dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<float> &input_data,
                          const DeviceMemory<int> &seq_lengths_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<float> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<float> &input_c_data,
                          const DeviceMemory<float> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<float> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<float> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<float> &output_c_data,
                          const DeviceMemory<float> &output_backprop_data,
                          const DeviceMemory<float> &output_h_backprop_data,
                          const DeviceMemory<float> &output_c_backprop_data,
                          DeviceMemory<float> *input_backprop_data,
                          DeviceMemory<float> *input_h_backprop_data,
                          DeviceMemory<float> *input_c_backprop_data,
                          DeviceMemory<float> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<double> &input_data,
                          const DeviceMemory<int> &seq_lengths_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<double> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<double> &input_c_data,
                          const DeviceMemory<double> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<double> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<double> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<double> &output_c_data,
                          const DeviceMemory<double> &output_backprop_data,
                          const DeviceMemory<double> &output_h_backprop_data,
                          const DeviceMemory<double> &output_c_backprop_data,
                          DeviceMemory<double> *input_backprop_data,
                          DeviceMemory<double> *input_h_backprop_data,
                          DeviceMemory<double> *input_c_backprop_data,
                          DeviceMemory<double> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);
2037 
  // Enqueue a CTCLoss operation onto the stream.
  // See DnnSupport::DoCtcLoss for more details.
  // Writes per-sequence losses into costs_data and gradients w.r.t. the
  // probabilities into grads_data.
  Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                      const DeviceMemory<float> &probs_data,
                      absl::Span<const int> labels_data,
                      absl::Span<const int> labels_lengths_data,
                      absl::Span<const int> input_lengths_data,
                      DeviceMemory<float> *costs_data,
                      const dnn::RnnStateTensorDescriptor &grads_desc,
                      DeviceMemory<float> *grads_data,
                      ScratchAllocator *workspace_allocator);
2049 
  // Enqueue onto the stream a operation that transforms a tensor.
  // See DnnSupport::DoTransformTensor for more details.
  // Element types are given dynamically via input_type/output_type; scale is
  // applied during the transformation.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);
2058 
2059   // The templated version of the above ThenTransformTensor. Useful when the
2060   // input and output types are statically known.
2061   template <typename InElemT, typename OutElemT>
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,const DeviceMemory<InElemT> & input_data,const dnn::BatchDescriptor & output_desc,DeviceMemory<OutElemT> * output_data)2062   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
2063                               const DeviceMemory<InElemT> &input_data,
2064                               const dnn::BatchDescriptor &output_desc,
2065                               DeviceMemory<OutElemT> *output_data) {
2066     return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
2067                                input_data, output_desc,
2068                                dnn::ToDataType<OutElemT>(), output_data);
2069   }
2070 
  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
  //
  // Returns an OK status if the blocking was successful and the stream is ok().
  // Otherwise returns an error describing why the blocking failed.
  port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_);
2078 
  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use
  // only. Please check with a member of the FASTR team before making use of
  // this method.
  //
  // Entrains onto the stream a function to be executed on the host at some
  // point in the future.
  // Async host callbacks DO NOT block the stream as device functions (or as
  // synchronous host callbacks). No synchronization is possible with
  // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is private due to the potential for undefined behavior with
  // synchronization using OpenCL user events.
  // The ONLY lifetime guarantee in these calls is that the StreamExecutor
  // parameter will still be valid - this Stream may not be!
  // Any callbacks requiring device API calls must use this method.
  Stream &ThenEnqueueOnBackgroundThread(
      std::function<void(StreamExecutor *)> task);
2097 
2098   // Returns the (opaque) platform-specific backing object. Ownership is not
2099   // transferred to the caller.
implementation()2100   internal::StreamInterface *implementation() { return implementation_.get(); }
2101 
  // Entrains onto the stream a callback to the host (from the device).
  // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
  // never fail or its failure is inconsequential.
  //
  // This is kept for backward compatibility. Future code should use
  // ThenDoHostCallbackWithStatus and explicitly return a success status.
  // TODO(b/112125301): Eventually remove this method.
  Stream &ThenDoHostCallback(std::function<void()> callback);
2110 
  // Entrains onto the stream a callback to the host (from the device).
  // Host callbacks block/occupy the stream just as device functions
  // (execute one at a time, block later stream operations).
  // Whether the callback return status affects the result of BlockHostUntilDone
  // is platform-dependent.
  //
  // Behavior is undefined when synchronizing using OpenCL user events.
  // Behavior is undefined if host callbacks call device routines or insert
  // them into any stream.
  //
  // On certain platforms, ThenDoHostCallback is expected to have significant
  // negative effects on performance.
  Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
2124 
  // Runs the given callback after the next call to BlockHostUntilDone on this
  // stream (or after the Stream does BlockHostUntilDone in its destructor).
  // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
  // some use cases, since the callback does not occupy a slot in the stream's
  // work queue.
  Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);
2130 
2131   // Returns the StreamExecutor (parent object) associated with this stream.
parent()2132   StreamExecutor *parent() const {
2133     CHECK(parent_ != nullptr);
2134     return parent_;
2135   }
2136 
  // Returns the CUDA compute capability of the device this stream runs on,
  // as reported by the parent executor's device description.
  CudaComputeCapability GetCudaComputeCapability() const {
    return parent()->GetDeviceDescription().cuda_compute_capability();
  }
2141 
  // Returns the (internal usage) temporary-memory-allocation manager associated
  // with this stream. Ownership is not transferred to the caller.
  internal::TemporaryMemoryManager *temporary_memory_manager();

  // Returns a debugging string "[stream=0x...,impl=0x...]".
  std::string DebugStreamPointers() const;
2148 
 private:
  // Platform-specific BLAS/FFT/RNG backends need direct access to parent_.
  friend class host::HostBlas;  // for parent_.
  friend class host::HostFft;   // for parent_.
  friend class host::HostRng;   // for parent_.
  template <typename... Args>
  friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
  friend class ocl::CLBlas;    // for parent_.
2156 
2157   // Checks whether types match before a call to extended BLAS version.
2158   template <typename InputType, typename OutputType, typename ConstantType>
CheckTypesForExtendedBlas(blas::ComputationType computation_type)2159   port::Status CheckTypesForExtendedBlas(
2160       blas::ComputationType computation_type) {
2161     static_assert(std::is_same<InputType, Eigen::half>::value ||
2162                       std::is_same<InputType, Eigen::bfloat16>::value ||
2163                       std::is_same<InputType, float>::value ||
2164                       std::is_same<InputType, double>::value ||
2165                       std::is_same<InputType, int8>::value ||
2166                       std::is_same<InputType, std::complex<float>>::value ||
2167                       std::is_same<InputType, std::complex<double>>::value,
2168                   "The only buffer types supported are: Eigen::half, float, "
2169                   "double, int8, std::complex<float> and std::complex<double>");
2170     static_assert(
2171         std::is_same<InputType, OutputType>::value ||
2172             (std::is_same<InputType, int8>::value &&
2173              std::is_same<OutputType, int32>::value),
2174         "Input and output buffer types should be the same unless input is "
2175         "int8 and output is int32");
2176     static_assert(std::is_same<ConstantType, OutputType>::value ||
2177                       (std::is_same<ConstantType, float>::value &&
2178                        (std::is_same<OutputType, Eigen::half>::value ||
2179                         std::is_same<OutputType, Eigen::bfloat16>::value)),
2180                   "Constant and output types should match");
2181     blas::ComputationType expected_computation_type =
2182         blas::ToComputationType<ConstantType>::value;
2183     if (expected_computation_type != computation_type &&
2184         !(computation_type == blas::ComputationType::kF32 &&
2185           (expected_computation_type == blas::ComputationType::kF16 ||
2186            expected_computation_type == blas::ComputationType::kBF16AsF32))) {
2187       return port::InternalError(absl::StrCat(
2188           "Alpha/beta type and computation type have to match, got ",
2189           blas::ComputationTypeString(computation_type),
2190           " for computation type, expected: ",
2191           blas::ComputationTypeString(expected_computation_type)));
2192     }
2193     return port::Status::OK();
2194   }
2195 
InErrorState()2196   bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
2197     absl::ReaderMutexLock lock(&mu_);
2198     return !status_.ok();
2199   }
2200 
2201   // Sets the error state if operation_retcode is false.
2202   // This is a useful shorthand for many stream routines.
CheckError(bool operation_retcode)2203   void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_) {
2204     if (operation_retcode) {
2205       return;
2206     }
2207     absl::MutexLock lock(&mu_);
2208     status_ = port::InternalError("Unknown error");
2209   }
2210 
  // Checks the status and logs the error message, if any.
  void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_);
2213 
  // Unconditionally puts the stream into an error state (see CheckError).
  void SetError() { CheckError(false /* = operation_retcode */); }
2215 
  // Puts the stream into an error state and warns that a DNN operation was
  // attempted on an executor that has no DNN support.
  void SetErrorAndLogNoDnnSupport() {
    SetError();
    LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
                    "without DNN support";
  }
2221 
  // Runs the set of callbacks that are intended to run after
  // BlockHostUntilDone (registered via ThenRunAfterNextBlockHostUntilDone).
  void RunAfterBlockHostUntilDoneCallbacks();
2225 
  // The StreamExecutor that supports the operation of this stream.
  StreamExecutor *parent_;

  // The platform-dependent implementation that the StreamExecutor interface
  // delegates to.
  std::unique_ptr<internal::StreamInterface> implementation_;

  // Mutex that guards the allocation / error state flags.
  // Mutable so that it can be obtained via const reader lock (see
  // InErrorState()).
  mutable absl::Mutex mu_;

  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from 0 to 1 with a sanity check.
  // See StreamExecutor::AllocateStream.
  bool allocated_ TF_GUARDED_BY(mu_);

  // The last error (if any) of all method calls.
  port::Status status_ TF_GUARDED_BY(mu_);

  // Sub-streams that are generated from this stream. Each element has a pointer
  // to sub-stream and a boolean value indicating if this substream is ready to
  // be reused.
  std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
      TF_GUARDED_BY(mu_);

  // Streams can allocate temporary memories to help with work they enqueue
  // (e.g. for scratch memory spaces). This member tracks those allocations and
  // notes when they can be reclaimed -- reclamation is attempted when
  // BlockHostUntilDone() is called.
  internal::TemporaryMemoryManager temporary_memory_manager_;

  // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
  std::vector<std::function<void()>> after_block_host_until_done_callbacks_
      TF_GUARDED_BY(mu_);
2260 
  // Implementation of ThenBlasLtMatmul that is shared by all types.
  // `a`/`b` are the input matrices, `c` both an input (scaled by `beta`) and
  // the output; `output_profile_result` may be populated when profiling is
  // requested.
  template <typename ABType, typename CType>
  Stream &ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan *plan,
                               const HostOrDeviceScalar<CType> &alpha,
                               const DeviceMemory<ABType> &a,
                               const DeviceMemory<ABType> &b,
                               const HostOrDeviceScalar<CType> &beta,
                               DeviceMemory<CType> *c,
                               ScratchAllocator *scratch_allocator,
                               const blas::IBlasLtMatmulAlgorithm *algorithm,
                               const DeviceMemory<CType> &bias,
                               blas::ProfileResult *output_profile_result);
2273 
2274   // Non-extended BLAS interface requires alpha/beta to be floats when input
2275   // type is Eigen::half. However, for consistency purposes it is convenient
2276   // for the interface to accept Eigen::half.
2277   template <typename T>
UpcastHalfToFloat(void ** alpha_ptr,void ** beta_ptr,float * alpha_storage,float * beta_storage)2278   void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr,
2279                          float *alpha_storage, float *beta_storage) {
2280     if (std::is_same<T, Eigen::half>::value) {
2281       *alpha_storage =
2282           static_cast<float>(*reinterpret_cast<Eigen::half *>(*alpha_ptr));
2283       *beta_storage =
2284           static_cast<float>(*reinterpret_cast<Eigen::half *>(*beta_ptr));
2285       *alpha_ptr = alpha_storage;
2286       *beta_ptr = beta_storage;
2287     } else if (std::is_same<T, Eigen::bfloat16>::value) {
2288       *alpha_storage =
2289           static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*alpha_ptr));
2290       *beta_storage =
2291           static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*beta_ptr));
2292       *alpha_ptr = alpha_storage;
2293       *beta_ptr = beta_storage;
2294     }
2295   }
2296 
2297   SE_DISALLOW_COPY_AND_ASSIGN(Stream);
2298 };
2299 
2300 ////////////
2301 // Inlines
2302 
2303 template <typename... Params, typename... Args>
ThenLaunch(ThreadDim thread_dims,BlockDim block_dims,const TypedKernel<Params...> & kernel,Args...args)2304 inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
2305                                   const TypedKernel<Params...> &kernel,
2306                                   Args... args) {
2307   KernelInvocationChecker<std::tuple<Params...>,
2308                           std::tuple<Args...>>::CheckAllStaticAssert();
2309   if (ok()) {
2310     // This is the core that allows type-safe kernel launching.
2311     // Since the platforms take kernel arguments as tuples of (void *, size),
2312     // we pack the variadic parameters passed as ...args into the desired
2313     // tuple form and pass that packed form to the StreamExecutor::Launch()
2314     // implementation.
2315     KernelArgsArray<sizeof...(args)> kernel_args;
2316     kernel.PackParams(&kernel_args, args...);
2317     DCHECK(parent_ != nullptr);
2318     bool ok =
2319         parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args)
2320             .ok();
2321     if (!ok) {
2322       SetError();
2323       LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
2324     }
2325   }
2326   return *this;
2327 }
2328 
// Allocates a temporary device array of `element_count` elements of type T,
// tracked by this stream's TemporaryMemoryManager (reclamation is attempted
// on BlockHostUntilDone). Returns an error status if allocation fails.
template <typename T>
inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
Stream::AllocateTemporaryArray(uint64 element_count) {
  return temporary_memory_manager_.AllocateArray<T>(element_count);
}
2334 
// Returns the stream-owned temporary memory manager; the pointer remains
// owned by this Stream.
inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
  return &temporary_memory_manager_;
}
2338 
// Maps element types to the quantized activation mode used when entraining
// quantized memory copies (one specialization per supported width).
template <>
struct Quantization<uint8> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k8Bit;
};

template <>
struct Quantization<uint16> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k16Bit;
};

template <>
struct Quantization<int32> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k32Bit;
};
2356 
2357 }  // namespace stream_executor
2358 
2359 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
2360