• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20 
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
23 
24 #include <complex>
25 #include <functional>
26 #include <memory>
27 
28 #include "absl/synchronization/mutex.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/thread_annotations.h"
31 #include "tensorflow/stream_executor/blas.h"
32 #include "tensorflow/stream_executor/device_memory.h"
33 #include "tensorflow/stream_executor/dnn.h"
34 #include "tensorflow/stream_executor/event.h"
35 #include "tensorflow/stream_executor/fft.h"
36 #include "tensorflow/stream_executor/host_or_device_scalar.h"
37 #include "tensorflow/stream_executor/kernel.h"
38 #include "tensorflow/stream_executor/launch_dim.h"
39 #include "tensorflow/stream_executor/lib/array_slice.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
42 #include "tensorflow/stream_executor/temporary_memory_manager.h"
43 
44 namespace stream_executor {
45 
46 namespace host {
47 class HostBlas;
48 class HostFft;
49 class HostRng;
50 class HostTimer;
51 }  // namespace host
52 
53 namespace ocl {
54 class CLBlas;
55 }  // namespace ocl
56 
57 namespace internal {
58 class StreamInterface;
59 }  // namespace internal
60 
61 class DeviceMemoryBase;
62 template <typename ElemT>
63 class DeviceMemory;
64 
65 class Timer;
66 
67 namespace dnn {
68 class BatchDescriptor;
69 class FilterDescriptor;
70 class ConvolutionDescriptor;
71 class ProfileResult;
72 class AlgorithmDesc;
73 }  // namespace dnn
74 
75 class StreamExecutor;
76 class ScratchAllocator;
77 
namespace detail {

// Identity type function. Wrapping a template parameter in it places that
// parameter in a non-deduced context, so the compiler must take its type from
// the other arguments instead of deducing it here. Equivalent to C++20's
// std::type_identity.
template <class Ty>
struct NonDeduced {
  using type = Ty;
};

// Shorthand for NonDeduced<Ty>::type: names exactly Ty while suppressing
// template argument deduction for that position.
template <class Ty>
using NonDeducedType = typename NonDeduced<Ty>::type;

}  // namespace detail
90 
91 // Convert a type to the corresponding QuantizedActivationMode.
92 template <typename ElementType>
93 struct Quantization;
94 
95 // Represents a stream of dependent computations on a GPU device.
96 //
97 // The operations within a stream execute linearly and asynchronously until
98 // BlockHostUntilDone() is invoked, which synchronously joins host code with
99 // the execution of the stream.
100 //
101 // If any given operation fails when entraining work for the stream, ok() will
102 // indicate that an error has occurred. After initialization, once a stream is
103 // !ok(), it will never be ok().
104 //
105 // Thread-safe post-initialization.
106 class Stream {
107  public:
108   // Instantiate a stream tied to parent as a platform executor. Work
109   // entrained onto this stream will be launched/managed on that
110   // StreamExecutor's platform.
111   explicit Stream(StreamExecutor *parent);
112 
113   // Test only. Use an externally-populated value (like a mock) for the
114   // platform-specific stream implementation.
115   Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
116 
117   // Deallocates any stream resources that the parent StreamExecutor has
118   // bestowed upon this object.
120   ~Stream();
121 
122   // Returns whether any errors have occurred while entraining work for this
123   // stream.
ok()124   bool ok() const { return !InErrorState(); }
125 
126   // Retrieves execution status back into the stream from the underlying
127   // implementation without blocking the stream.
128   //
129   // Normally, Stream::BlockHostUntilDone is used to get execution status.
130   // However, some devices use out-of-band mechanisms to ensure their streams
131   // have finished on-device work, without needing to block the streams. (These
132   // devices should also override AllowsSyncOnCompletion to return false.) For
133   // these devices, this method can be used after work is finished to retrieve
134   // execution status.
135   port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_);
136 
137   // Initialize the stream. This must be performed before entraining any other
138   // operations.
139   Stream &Init() TF_LOCKS_EXCLUDED(mu_);
140 
141   // Initializes timer t via the StreamExecutor.
142   Stream &InitTimer(Timer *t);
143 
144   // Convenience wrapper around Init() and InitTimer().
145   Stream &InitWithTimer(Timer *t);
146 
147   // Get or create a sub-stream from this stream. If the pool holds a
148   // sub-stream that can be reused, return it; otherwise create a new
149   // sub-stream.
150   //
151   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
152   Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_);
153 
154   // Return the sub-stream back to the host stream so that it can be reused
155   // later. Sub-streams that are !ok() will not be reused.
156   //
157   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
158   void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_);
159 
160   // Allocate temporary memories. The stream will deallocate them when blocked
161   // or destroyed.
162   template <typename T>
163   port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
164   AllocateTemporaryArray(uint64 element_count);
165 
166   // Entrains onto the stream of operations: a kernel launch with the given
167   // (variadic) parameters for the invocation. These arguments can be things
168   // like DeviceMemory or primitive types such as int. What arguments you may
169   // pass to a given kernel are noted as the template parameters to the
170   // TypedKernel type that the machocc compiler generates.
171   //
172   // Template parameters:
173   //  Params...   The type list of formal parameters that the typed kernel
174   //              expects, which is matched against Args...
175   //  Args...     The deduced type list for passed actual arguments
176   //
177   // Implementation: A compile-time compatibility check is performed that has
178   // some leniency versus an exact parameter pack match -- for example,
179   // `const DeviceMemory<T>` is considered "pack compatible" with a
180   // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
181   // perfect forwarding support without rvalue references. It also attempts to
182   // spit out helpful static_assert error traces with information as to the
183   // argument number and types that were mismatched.
184   template <typename... Params, typename... Args>
185   Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
186                      const TypedKernel<Params...> &kernel, Args... args);
187 
188   // Record a "start" event for the interval timer at this point in the
189   // stream's execution (relative to the previously and subsequently enqueued
190   // items in the stream's execution). Streams may be started/stopped multiple
191   // times.
192   Stream &ThenStartTimer(Timer *t);
193 
194   // Record a "stop" event for the interval timer at this point in the
195   // stream's execution. See also Stream::ThenStartTimer.
196   Stream &ThenStopTimer(Timer *t);
197 
198   // TODO(leary) If work is added to the stream that is being depended upon,
199   //              then what? Have to describe what happens.
200   template <typename... Params>
ThenWaitFor(Stream * other,Params...more_streams)201   Stream &ThenWaitFor(Stream *other, Params... more_streams) {
202     return ThenWaitFor(more_streams...).ThenWaitFor(other);
203   }
204 
205   // Create a dependency for this stream's next work on the other stream
206   // completing. Does not take ownership of other, and other must not be
207   // null.
208   //
209   // Checks that a stream does not wait for itself, and it is up to the
210   // user to guarantee that a stream does not come to wait on itself in a
211   // cyclic manner; in that case, behavior is undefined.
212   //
213   // N.B. Base recursion case for the variadic ThenWaitFor.
214   Stream &ThenWaitFor(Stream *other);
215 
216   // Waits for all streams values in others.
217   // Checks that there is no shallow circular wait (i.e. that "this" is not in
218   // others)
219   template <typename P>
ThenWaitFor(P others)220   Stream &ThenWaitFor(P others) {
221     for (auto &stream : *others) {
222       CHECK_NE(stream.get(), this);
223       ThenWaitFor(stream.get());
224     }
225     return *this;
226   }
227 
228   // Waits for an event object to be set.
229   // Note that ThenRecordEvent must have been called on the event before
230   // you call this function; otherwise the event will be considered complete
231   // and this wait will do nothing.
232   Stream &ThenWaitFor(Event *event);
233 
234   // Inserts the specified event into the end of this stream. Once the stream
235   // has processed all events prior to the insertion point, the event will be
236   // marked as completed.
237   // The stream does not take ownership of event - meaning that event's lifetime
238   // must extend past the point at which it is marked complete!
239   Stream &ThenRecordEvent(Event *event);
240 
241   ////////////////
242   // DNN support
243   //
244   // See DnnSupport::* for comments on the following methods.
245 
246   Stream &ThenBatchNormalizationForward(
247       const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
248       const DeviceMemory<float> &offset,
249       const DeviceMemory<float> &estimated_mean,
250       const DeviceMemory<float> &estimated_variance,
251       const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
252       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
253       const double exponential_average_factor,
254       dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
255       DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
256       DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
257       bool is_training,
258       ScratchAllocator *reserve_space_allocator,
259       ScratchAllocator *workspace_allocator);
260 
261   Stream &ThenBatchNormalizationBackward(
262       const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
263       const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
264       const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
265       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
266       DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
267       DeviceMemory<float> *offset_backprop,
268       DeviceMemory<uint8> *reserve_space_data,
269       ScratchAllocator *workspace_allocator);
270 
271   Stream &ThenBatchNormalizationForward(
272       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
273       const DeviceMemory<float> &offset,
274       const DeviceMemory<float> &estimated_mean,
275       const DeviceMemory<float> &estimated_variance,
276       const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
277       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
278       const double exponential_average_factor,
279       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
280       DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
281       DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
282       bool is_training,
283       ScratchAllocator *reserve_space_allocator,
284       ScratchAllocator *workspace_allocator);
285 
286   Stream &ThenBatchNormalizationBackward(
287       const DeviceMemory<Eigen::half> &y_backprop,
288       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
289       const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
290       const dnn::BatchDescriptor &x_desc,
291       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
292       DeviceMemory<Eigen::half> *x_backprop,
293       DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
294       DeviceMemory<uint8> *reserve_space_data,
295       ScratchAllocator *workspace_allocator);
296 
297   Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
298                        const DeviceMemory<float> &input_data,
299                        const dnn::FilterDescriptor &filter_descriptor,
300                        const DeviceMemory<float> &filter_data,
301                        const dnn::ConvolutionDescriptor &convolution_descriptor,
302                        const dnn::BatchDescriptor &output_descriptor,
303                        DeviceMemory<float> *output);
304 
305   Stream &ThenConvolveQuantized(
306       const dnn::BatchDescriptor &input_descriptor,
307       const DeviceMemory<float> &input_data,
308       const dnn::FilterDescriptor &filter_descriptor,
309       const DeviceMemory<int8> &filter_coefficients,
310       const DeviceMemory<float> &coefficient_scales,
311       const dnn::ConvolutionDescriptor &convolution_descriptor,
312       const dnn::BatchDescriptor &output_descriptor,
313       DeviceMemory<float> *output_data);
314 
315   Stream &ThenConvolveQuantized(
316       const dnn::BatchDescriptor &input_descriptor,
317       const DeviceMemory<float> &input_data,
318       const dnn::FilterDescriptor &filter_descriptor,
319       const DeviceMemory<int16> &filter_coefficients,
320       const DeviceMemory<float> &coefficient_scales,
321       const dnn::ConvolutionDescriptor &convolution_descriptor,
322       const dnn::BatchDescriptor &output_descriptor,
323       DeviceMemory<float> *output_data);
324 
325   template <typename InputType, typename OutputType>
ConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<InputType> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputType> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputType> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)326   port::Status ConvolveWithAlgorithm(
327       const dnn::BatchDescriptor &input_descriptor,
328       const DeviceMemory<InputType> &input_data,
329       const dnn::FilterDescriptor &filter_descriptor,
330       const DeviceMemory<InputType> &filter_data,
331       const dnn::ConvolutionDescriptor &convolution_descriptor,
332       const dnn::BatchDescriptor &output_descriptor,
333       DeviceMemory<OutputType> *output, ScratchAllocator *scratch_allocator,
334       const dnn::AlgorithmConfig &algorithm_config,
335       dnn::ProfileResult *output_profile_result) {
336     DeviceMemory<uint8> scratch_memory;
337     dnn::AlgorithmDesc algorithm_desc;
338     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
339       TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
340           dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
341           filter_descriptor, filter_data, output_descriptor, *output,
342           convolution_descriptor, algorithm_config, scratch_allocator,
343           &algorithm_desc, &scratch_memory));
344       return dnn->DoConvolve(
345           dnn::ConvolutionKind::FORWARD, dnn::ToDataType<InputType>::value,
346           dnn::ToDataType<OutputType>::value, this, input_descriptor,
347           input_data, filter_descriptor, filter_data, output_descriptor,
348           *output, convolution_descriptor, algorithm_desc, scratch_memory,
349           output_profile_result);
350     }
351     return port::UnimplementedError("DNN library is not found.");
352   }
353 
354   port::Status FusedConvolveWithAlgorithm(
355       const dnn::BatchDescriptor &conv_input_descriptor,
356       const DeviceMemory<double> &conv_input_data, double conv_input_scale,
357       const dnn::FilterDescriptor &filter_descriptor,
358       const DeviceMemory<double> &filter_data,
359       const dnn::ConvolutionDescriptor &convolution_descriptor,
360       const DeviceMemory<double> &side_input_data, double side_input_scale,
361       const dnn::BatchDescriptor &bias_descriptor,
362       const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
363       const dnn::BatchDescriptor &output_descriptor,
364       DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
365       const dnn::AlgorithmConfig &algorithm_config,
366       dnn::ProfileResult *output_profile_result);
367 
368   port::Status FusedConvolveWithAlgorithm(
369       const dnn::BatchDescriptor &conv_input_descriptor,
370       const DeviceMemory<float> &conv_input_data, float conv_input_scale,
371       const dnn::FilterDescriptor &filter_descriptor,
372       const DeviceMemory<float> &filter_data,
373       const dnn::ConvolutionDescriptor &convolution_descriptor,
374       const DeviceMemory<float> &side_input_data, float side_input_scale,
375       const dnn::BatchDescriptor &bias_descriptor,
376       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
377       const dnn::BatchDescriptor &output_descriptor,
378       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
379       const dnn::AlgorithmConfig &algorithm_config,
380       dnn::ProfileResult *output_profile_result);
381 
382   port::Status FusedConvolveWithAlgorithm(
383       const dnn::BatchDescriptor &conv_input_descriptor,
384       const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
385       const dnn::FilterDescriptor &filter_descriptor,
386       const DeviceMemory<Eigen::half> &filter_data,
387       const dnn::ConvolutionDescriptor &convolution_descriptor,
388       const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
389       const dnn::BatchDescriptor &bias_descriptor,
390       const DeviceMemory<Eigen::half> &biases,
391       dnn::ActivationMode activation_mode,
392       const dnn::BatchDescriptor &output_descriptor,
393       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
394       const dnn::AlgorithmConfig &algorithm_config,
395       dnn::ProfileResult *output_profile_result);
396 
397   port::Status FusedConvolveWithAlgorithm(
398       const dnn::BatchDescriptor &conv_input_descriptor,
399       const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
400       const dnn::FilterDescriptor &filter_descriptor,
401       const DeviceMemory<int8> &filter_data,
402       const dnn::ConvolutionDescriptor &convolution_descriptor,
403       const DeviceMemory<int8> &side_input_data, float side_input_scale,
404       const dnn::BatchDescriptor &bias_descriptor,
405       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
406       const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
407       ScratchAllocator *scratch_allocator,
408       const dnn::AlgorithmConfig &algorithm_config,
409       dnn::ProfileResult *output_profile_result);
410 
411   port::Status FusedConvolveWithAlgorithm(
412       const dnn::BatchDescriptor &conv_input_descriptor,
413       const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
414       const dnn::FilterDescriptor &filter_descriptor,
415       const DeviceMemory<int8> &filter_data,
416       const dnn::ConvolutionDescriptor &convolution_descriptor,
417       const DeviceMemory<float> &side_input_data, float side_input_scale,
418       const dnn::BatchDescriptor &bias_descriptor,
419       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
420       const dnn::BatchDescriptor &output_descriptor,
421       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
422       const dnn::AlgorithmConfig &algorithm_config,
423       dnn::ProfileResult *output_profile_result);
424 
425   Stream &ThenSeparableConvolve(
426       const dnn::BatchDescriptor &input_descriptor,
427       const DeviceMemory<float> &input_data,
428       const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
429       const DeviceMemory<float> &first_weights,
430       const DeviceMemory<float> &second_weights,
431       const dnn::ConvolutionDescriptor &convolution_descriptor,
432       const dnn::BatchDescriptor &output_descriptor,
433       DeviceMemory<float> *output);
434 
435   template <typename ElementType>
ConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<ElementType> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<ElementType> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)436   port::Status ConvolveBackwardDataWithAlgorithm(
437       const dnn::FilterDescriptor &filter_descriptor,
438       const DeviceMemory<ElementType> &filter_data,
439       const dnn::BatchDescriptor &output_descriptor,
440       DeviceMemory<ElementType> backward_output_data,
441       const dnn::ConvolutionDescriptor &convolution_descriptor,
442       const dnn::BatchDescriptor &input_descriptor,
443       DeviceMemory<ElementType> *backward_input_data,
444       ScratchAllocator *scratch_allocator,
445       const dnn::AlgorithmConfig &algorithm_config,
446       dnn::ProfileResult *output_profile_result) {
447     DeviceMemory<uint8> scratch_memory;
448     dnn::AlgorithmDesc algorithm_desc;
449     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
450       TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
451           dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
452           *backward_input_data, filter_descriptor, filter_data,
453           output_descriptor, backward_output_data, convolution_descriptor,
454           algorithm_config, scratch_allocator, &algorithm_desc,
455           &scratch_memory));
456       return dnn->DoConvolve(
457           dnn::ConvolutionKind::BACKWARD_DATA,
458           dnn::ToDataType<ElementType>::value,
459           dnn::ToDataType<ElementType>::value, this, input_descriptor,
460           *backward_input_data, filter_descriptor, filter_data,
461           output_descriptor, backward_output_data, convolution_descriptor,
462           algorithm_desc, scratch_memory, output_profile_result);
463     }
464     return port::UnimplementedError("DNN library is not found.");
465   }
466 
467   template <typename ElementType>
ConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<ElementType> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<ElementType> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)468   port::Status ConvolveBackwardFilterWithAlgorithm(
469       const dnn::BatchDescriptor &input_descriptor,
470       const DeviceMemory<ElementType> &input_data,
471       const dnn::BatchDescriptor &output_descriptor,
472       DeviceMemory<ElementType> backward_output_data,
473       const dnn::ConvolutionDescriptor &convolution_descriptor,
474       const dnn::FilterDescriptor &filter_descriptor,
475       DeviceMemory<ElementType> *backward_filter_data,
476       ScratchAllocator *scratch_allocator,
477       const dnn::AlgorithmConfig &algorithm_config,
478       dnn::ProfileResult *output_profile_result) {
479     DeviceMemory<uint8> scratch_memory;
480     dnn::AlgorithmDesc algorithm_desc;
481     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
482       TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
483           dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
484           input_data, filter_descriptor, *backward_filter_data,
485           output_descriptor, backward_output_data, convolution_descriptor,
486           algorithm_config, scratch_allocator, &algorithm_desc,
487           &scratch_memory));
488       return dnn->DoConvolve(
489           dnn::ConvolutionKind::BACKWARD_FILTER,
490           dnn::ToDataType<ElementType>::value,
491           dnn::ToDataType<ElementType>::value, this, input_descriptor,
492           input_data, filter_descriptor, *backward_filter_data,
493           output_descriptor, backward_output_data, convolution_descriptor,
494           algorithm_desc, scratch_memory, output_profile_result);
495     }
496     return port::UnimplementedError("DNN library is not found.");
497   }
498 
499   Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
500                                    const DeviceMemory<double> &input_data,
501                                    const dnn::BatchDescriptor &bias_descriptor,
502                                    DeviceMemory<double> *backward_bias_data);
503 
504   Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
505                                    const DeviceMemory<float> &input_data,
506                                    const dnn::BatchDescriptor &bias_descriptor,
507                                    DeviceMemory<float> *backward_bias_data);
508 
509   Stream &ThenConvolveBackwardBias(
510       const dnn::BatchDescriptor &input_descriptor,
511       const DeviceMemory<Eigen::half> &input_data,
512       const dnn::BatchDescriptor &bias_descriptor,
513       DeviceMemory<Eigen::half> *backward_bias_data);
514 
515   Stream &ThenMatMul(const DeviceMemory<float> &input_data,
516                      const DeviceMemory<float> &weights,
517                      const dnn::BatchDescriptor &input_dimensions,
518                      const dnn::BatchDescriptor &output_dimensions,
519                      DeviceMemory<float> *output_data);
520 
521   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
522                               const DeviceMemory<int8> &weights,
523                               const DeviceMemory<float> &weight_scales,
524                               const dnn::BatchDescriptor &input_dimensions,
525                               const dnn::BatchDescriptor &output_dimensions,
526                               DeviceMemory<float> *output_data);
527 
528   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
529                               const DeviceMemory<int16> &weights,
530                               const DeviceMemory<float> &weight_scales,
531                               const dnn::BatchDescriptor &input_dimensions,
532                               const dnn::BatchDescriptor &output_dimensions,
533                               DeviceMemory<float> *output_data);
534 
535   Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
536                       const DeviceMemory<float> &biases,
537                       const dnn::BatchDescriptor &dimensions,
538                       DeviceMemory<float> *output_data);
539 
540   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
541                           const dnn::BatchDescriptor &input_dimensions,
542                           const DeviceMemory<double> &input_data,
543                           const dnn::BatchDescriptor &output_dimensions,
544                           DeviceMemory<double> *output_data,
545                           ScratchAllocator *workspace_allocator = nullptr);
546 
547   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
548                           const dnn::BatchDescriptor &input_dimensions,
549                           const DeviceMemory<float> &input_data,
550                           const dnn::BatchDescriptor &output_dimensions,
551                           DeviceMemory<float> *output_data,
552                           ScratchAllocator *workspace_allocator = nullptr);
553 
554   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
555                           const dnn::BatchDescriptor &input_dimensions,
556                           const DeviceMemory<Eigen::half> &input_data,
557                           const dnn::BatchDescriptor &output_dimensions,
558                           DeviceMemory<Eigen::half> *output_data,
559                           ScratchAllocator *workspace_allocator = nullptr);
560 
561   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
562                           const dnn::BatchDescriptor &input_dimensions,
563                           const DeviceMemory<int8> &input_data,
564                           const dnn::BatchDescriptor &output_dimensions,
565                           DeviceMemory<int8> *output_data,
566                           ScratchAllocator *workspace_allocator = nullptr);
567 
568   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
569                            const dnn::BatchDescriptor &input_dimensions,
570                            const DeviceMemory<double> &input_data,
571                            const dnn::BatchDescriptor &output_dimensions,
572                            const DeviceMemory<double> &output_data,
573                            const DeviceMemory<double> &input_diff_data,
574                            DeviceMemory<double> *output_diff_data,
575                            ScratchAllocator *workspace_allocator = nullptr);
576 
577   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
578                            const dnn::BatchDescriptor &input_dimensions,
579                            const DeviceMemory<float> &input_data,
580                            const dnn::BatchDescriptor &output_dimensions,
581                            const DeviceMemory<float> &output_data,
582                            const DeviceMemory<float> &input_diff_data,
583                            DeviceMemory<float> *output_diff_data,
584                            ScratchAllocator *workspace_allocator = nullptr);
585 
586   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
587                            const dnn::BatchDescriptor &input_dimensions,
588                            const DeviceMemory<Eigen::half> &input_data,
589                            const dnn::BatchDescriptor &output_dimensions,
590                            const DeviceMemory<Eigen::half> &output_data,
591                            const DeviceMemory<Eigen::half> &input_diff_data,
592                            DeviceMemory<Eigen::half> *output_diff_data,
593                            ScratchAllocator *workspace_allocator = nullptr);
594 
  // Float-only normalization of input_data, configured by
  // normalize_descriptor and shaped by dimensions, written to output_data.
  // Unlike ThenNormalizeBackwardWithDimensions, it takes no ScratchAllocator.
595   Stream &ThenNormalizeWithDimensions(
596       const dnn::NormalizeDescriptor &normalize_descriptor,
597       const dnn::BatchDescriptor &dimensions,
598       const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
599 
  // Float-only backward counterpart of ThenNormalizeWithDimensions.
  // raw_data, normalized_data and normalized_variable_gradient are read-only
  // inputs; raw_variable_gradient is the only non-const (output) buffer.
  // workspace_allocator is optional (defaults to nullptr).
600   Stream &ThenNormalizeBackwardWithDimensions(
601       const dnn::NormalizeDescriptor &normalize_descriptor,
602       const dnn::BatchDescriptor &dimensions,
603       const DeviceMemory<float> &raw_data,
604       const DeviceMemory<float> &normalized_data,
605       const DeviceMemory<float> &normalized_variable_gradient,
606       DeviceMemory<float> *raw_variable_gradient,
607       ScratchAllocator *workspace_allocator = nullptr);
608 
  // Applies the activation selected by activation_mode to the float
  // input_data (shaped by dimensions) and writes the result to output_data.
609   Stream &ThenActivate(dnn::ActivationMode activation_mode,
610                        const dnn::BatchDescriptor &dimensions,
611                        const DeviceMemory<float> &input_data,
612                        DeviceMemory<float> *output_data);
613 
614   // Same as ThenActivate, but also takes an options argument that can be used
615   // for platform-specific option flags.
  // NOTE(review): the meaning of individual bits in `options` is defined by
  // the platform backend and is not documented here.
616   Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
617                                   const dnn::BatchDescriptor &dimensions,
618                                   const DeviceMemory<float> &input_data,
619                                   DeviceMemory<float> *output_data,
620                                   uint64 options);
621 
  // Concatenates the float tensors referenced by input_data into output_data.
  // input_dimensions and input_data are parallel slices.
  // NOTE(review): presumably they must have equal length (one descriptor per
  // input) and concatenation is along the depth/feature-map dimension, per
  // the method name - confirm.
622   Stream &ThenDepthConcatenate(
623       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
624       port::ArraySlice<const DeviceMemory<float> *> input_data,
625       DeviceMemory<float> *output_data);
626 
  // Same inputs as ThenDepthConcatenate plus concat_direction, which
  // presumably selects the spatial axis to concatenate along - confirm.
627   Stream &ThenSpaceConcatenate(
628       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
629       port::ArraySlice<const DeviceMemory<float> *> input_data,
630       DeviceMemory<float> *output_data,
631       dnn::SpaceConcatenateMode concat_direction);
632 
633   // Change the layout of the data by shrinking one dimension (or set of
634   // dimensions) and growing another dimension (or set of dimensions), while
635   // keeping the total number of data elements constant, and maintaining the
636   // current data ordering.
  // Float only. input_dimensions and output_dimensions must therefore describe
  // the same total element count (NOTE(review): whether this is validated is
  // not visible here).
637   Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
638                       const DeviceMemory<float> &input_data,
639                       const dnn::BatchDescriptor &output_dimensions,
640                       DeviceMemory<float> *output_data);
641 
642   // Depth to space takes an X by Y image with depth D*M² and changes it to an
643   // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
644   // the input image is changed to an MxM contiguous area in the output image,
645   // with the values being laid out in raster order specified by
646   // DepthToSpaceLayout, and will have a new depth of D.
647   // See the DoDepthToSpace comment for more information.
  // Float only. NOTE(review): sqrt_depth_reduction presumably corresponds to
  // M in the description above - confirm.
648   Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
649                            const DeviceMemory<float> &input_data,
650                            const dnn::DepthToSpaceLayout &depth_to_space_layout,
651                            const int sqrt_depth_reduction,
652                            DeviceMemory<float> *output_data);
653 
654   // Space to depth is the inverse of depth to space. Space to depth takes each
655   // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
656   // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
657   // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
658   // data elements is not changed.
  // Float only; the layout argument reuses dnn::DepthToSpaceLayout.
  // NOTE(review): sqrt_depth_increase presumably corresponds to M above -
  // confirm.
659   Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
660                            const DeviceMemory<float> &input_data,
661                            const dnn::DepthToSpaceLayout &space_to_depth_layout,
662                            const int sqrt_depth_increase,
663                            DeviceMemory<float> *output_data);
664 
  // Applies `operation` across the float tensors referenced by input_data and
  // writes the result, shaped by output_dimensions, to output_data.
  // NOTE(review): input_dimensions presumably has one entry per input_data
  // element - confirm.
665   Stream &ThenElementwiseOperate(
666       dnn::ElementwiseOperation operation,
667       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
668       port::ArraySlice<const DeviceMemory<float> *> input_data,
669       const dnn::BatchDescriptor &output_dimensions,
670       DeviceMemory<float> *output_data);
671 
  // Like ThenElementwiseOperate, but additionally takes integer
  // input_multiplicands (presumably one per input) and an integer
  // output_divisor.
  // NOTE(review): despite the "Quantized" name, every buffer here is typed
  // DeviceMemory<float>; the exact scaling rule is backend-defined - confirm.
672   Stream &ThenElementwiseOperateScaledQuantized(
673       dnn::ElementwiseOperation operation,
674       port::ArraySlice<int> input_multiplicands, int output_divisor,
675       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
676       port::ArraySlice<const DeviceMemory<float> *> input_data,
677       const dnn::BatchDescriptor &output_dimensions,
678       DeviceMemory<float> *output_data);
679 
  // Pads the X and Y edges of the float input by left_pad / right_pad /
  // top_pad / bottom_pad and writes the result to output_data.
  // NOTE(review): the fill value is not visible here - confirm.
680   Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
681                     const DeviceMemory<float> &input_data, int64 left_pad,
682                     int64 right_pad, int64 top_pad, int64 bottom_pad,
683                     DeviceMemory<float> *output_data);
684 
  // Presumably the counterpart of ThenXYPad: trims left_trim / right_trim /
  // top_trim / bottom_trim from the X and Y edges of the float input.
685   Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
686                       const DeviceMemory<float> &input_data, int64 left_trim,
687                       int64 right_trim, int64 top_trim, int64 bottom_trim,
688                       DeviceMemory<float> *output_data);
689 
690   // Grows the input tensor by replicating the X and Y dimensions. The batch and
691   // depth/feature_map dimensions are unchanged. Currently, the input tensor is
692   // limited to X=1 and Y=1.
  // NOTE(review): replicate_x / replicate_y presumably give the replication
  // count along X and Y - confirm.
693   Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
694                           const DeviceMemory<float> &input_data,
695                           int64 replicate_x, int64 replicate_y,
696                           DeviceMemory<float> *output_data);
697 
698   // See DnnSupport::DoMemcpyD2HQuantized.
  // Presumably copies float data from gpu_unquantized_src into the raw host
  // buffer host_dst, quantized according to `mode`. `size` is a byte count:
  // the typed template overload passes elements * sizeof(element).
699   Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
700                                  dnn::QuantizedActivationMode mode,
701                                  void *host_dst, uint64 size);
702 
703   // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
704   // and uses the Quantization trait to call the generic version of
705   // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
706   template <typename ElementType>
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,port::MutableArraySlice<ElementType> host_dst)707   Stream &ThenMemcpyD2HQuantized(
708       const DeviceMemory<float> &gpu_unquantized_src,
709       port::MutableArraySlice<ElementType> host_dst) {
710     return ThenMemcpyD2HQuantized(
711         gpu_unquantized_src, Quantization<ElementType>::kModeId,
712         host_dst.data(), host_dst.size() * sizeof(ElementType));
713   }
714 
715   // See DnnSupport::DoMemcpyH2DQuantized.
  // Presumably reads `size` bytes of quantized data (format given by `mode`)
  // from host_src into the float buffer gpu_unquantized_dst. `size` is a byte
  // count: the typed template overload passes elements * sizeof(element).
716   Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
717                                  dnn::QuantizedActivationMode mode,
718                                  DeviceMemory<float> *gpu_unquantized_dst);
719 
720   // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
721   // and uses the Quantization trait to call the generic version of
722   // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
723   template <typename ElementType>
ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,DeviceMemory<float> * gpu_unquantized_dst)724   Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
725                                  DeviceMemory<float> *gpu_unquantized_dst) {
726     return ThenMemcpyH2DQuantized(
727         host_src.data(), host_src.size() * sizeof(ElementType),
728         Quantization<ElementType>::kModeId, gpu_unquantized_dst);
729   }
730 
731   // See DnnSupport::DoCopyHostBuffer2Device.
  // Float destination only. NOTE(review): buffer_src is a non-const
  // HostBuffer*; its lifetime requirement is not visible here - presumably it
  // must outlive the enqueued copy.
732   Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
733                                     DeviceMemory<float> *gpu_unquantized_dst);
734 
735   // See DnnSupport::DoCopyDevice2HostBuffer.
  // Float source only. NOTE(review): as above, buffer_dst presumably must
  // outlive the enqueued copy - confirm.
736   Stream &ThenCopyDevice2HostBuffer(
737       const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
738 
739   /////////////////
740   // BLAS support
  //
  // The ThenBlas* methods below wrap the BlasSupport::DoBlas* routines named in
  // their comments. Throughout this section read-only vectors/matrices are
  // passed as const DeviceMemory<T>&, buffers that may be written are passed
  // as DeviceMemory<T>*, and host scalars such as alpha/beta are passed by
  // value.
741 
742   // See BlasSupport::DoBlasAsum.
  // For complex inputs the result buffer is real-typed (float/double).
  // NOTE(review): presumably reference-BLAS ?ASUM semantics (sum of absolute
  // values; |re| + |im| per complex element) - confirm in BlasSupport.
743   Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
744                        int incx, DeviceMemory<float> *result);
745   Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
746                        int incx, DeviceMemory<double> *result);
747   Stream &ThenBlasAsum(uint64 elem_count,
748                        const DeviceMemory<std::complex<float>> &x, int incx,
749                        DeviceMemory<float> *result);
750   Stream &ThenBlasAsum(uint64 elem_count,
751                        const DeviceMemory<std::complex<double>> &x, int incx,
752                        DeviceMemory<double> *result);
753 
754   // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
755   // present in DeviceMemory, it must be an execution-time constant (i.e. a
756   // value that the stream does not change or populate during the course of
757   // execution). The value is effectively captured at stream-enqueue time.
758   // In the overloads below alpha is a plain host value and y is the only
  // non-const buffer (presumably updated in place as y = alpha*x + y).
759   Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
760                        const DeviceMemory<float> &x, int incx,
761                        DeviceMemory<float> *y, int incy);
762   Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
763                        const DeviceMemory<double> &x, int incx,
764                        DeviceMemory<double> *y, int incy);
765   Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
766                        const DeviceMemory<std::complex<float>> &x, int incx,
767                        DeviceMemory<std::complex<float>> *y, int incy);
768   Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
769                        const DeviceMemory<std::complex<double>> &x, int incx,
770                        DeviceMemory<std::complex<double>> *y, int incy);
771 
772   // See BlasSupport::DoBlasCopy.
  // x is read-only and y is written; presumably copies elem_count strided
  // elements from x into y.
773   Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
774                        int incx, DeviceMemory<float> *y, int incy);
775   Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
776                        int incx, DeviceMemory<double> *y, int incy);
777   Stream &ThenBlasCopy(uint64 elem_count,
778                        const DeviceMemory<std::complex<float>> &x, int incx,
779                        DeviceMemory<std::complex<float>> *y, int incy);
780   Stream &ThenBlasCopy(uint64 elem_count,
781                        const DeviceMemory<std::complex<double>> &x, int incx,
782                        DeviceMemory<std::complex<double>> *y, int incy);
783 
784   // See BlasSupport::DoBlasDot.
  // Real element types only; complex inputs use ThenBlasDotc / ThenBlasDotu.
785   Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
786                       const DeviceMemory<float> &y, int incy,
787                       DeviceMemory<float> *result);
788   Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
789                       int incx, const DeviceMemory<double> &y, int incy,
790                       DeviceMemory<double> *result);
791 
792   // See BlasSupport::DoBlasDotc.
  // Complex element types only. NOTE(review): by reference-BLAS convention
  // ?DOTC conjugates x before forming the product - confirm in BlasSupport.
793   Stream &ThenBlasDotc(uint64 elem_count,
794                        const DeviceMemory<std::complex<float>> &x, int incx,
795                        const DeviceMemory<std::complex<float>> &y, int incy,
796                        DeviceMemory<std::complex<float>> *result);
797   Stream &ThenBlasDotc(uint64 elem_count,
798                        const DeviceMemory<std::complex<double>> &x, int incx,
799                        const DeviceMemory<std::complex<double>> &y, int incy,
800                        DeviceMemory<std::complex<double>> *result);
801 
802   // See BlasSupport::DoBlasDotu.
  // Complex element types only. NOTE(review): by reference-BLAS convention
  // ?DOTU forms the unconjugated product - confirm in BlasSupport.
803   Stream &ThenBlasDotu(uint64 elem_count,
804                        const DeviceMemory<std::complex<float>> &x, int incx,
805                        const DeviceMemory<std::complex<float>> &y, int incy,
806                        DeviceMemory<std::complex<float>> *result);
807   Stream &ThenBlasDotu(uint64 elem_count,
808                        const DeviceMemory<std::complex<double>> &x, int incx,
809                        const DeviceMemory<std::complex<double>> &y, int incy,
810                        DeviceMemory<std::complex<double>> *result);
811 
812   // See BlasSupport::DoBlasNrm2.
  // For complex inputs the result buffer is real-typed (float/double).
  // NOTE(review): presumably the Euclidean norm, per reference-BLAS ?NRM2.
813   Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
814                        int incx, DeviceMemory<float> *result);
815   Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
816                        int incx, DeviceMemory<double> *result);
817   Stream &ThenBlasNrm2(uint64 elem_count,
818                        const DeviceMemory<std::complex<float>> &x, int incx,
819                        DeviceMemory<float> *result);
820   Stream &ThenBlasNrm2(uint64 elem_count,
821                        const DeviceMemory<std::complex<double>> &x, int incx,
822                        DeviceMemory<double> *result);
823 
824   // See BlasSupport::DoBlasRot.
  // x and y are both non-const and presumably rotated in place by the plane
  // rotation (c, s). The complex overloads still take real-typed c and s.
825   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
826                       DeviceMemory<float> *y, int incy, float c, float s);
827   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
828                       DeviceMemory<double> *y, int incy, double c, double s);
829   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
830                       int incx, DeviceMemory<std::complex<float>> *y, int incy,
831                       float c, float s);
832   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
833                       int incx, DeviceMemory<std::complex<double>> *y, int incy,
834                       double c, double s);
835 
836   // See BlasSupport::DoBlasRotg.
  // Takes no element count: all four arguments are device-memory scalars
  // passed as non-const pointers. In the complex overloads c is real-typed
  // while a, b and s are complex. NOTE(review): presumably a and b are the
  // inputs and c and s receive the generated rotation - confirm.
837   Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
838                        DeviceMemory<float> *c, DeviceMemory<float> *s);
839   Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
840                        DeviceMemory<double> *c, DeviceMemory<double> *s);
841   Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
842                        DeviceMemory<std::complex<float>> *b,
843                        DeviceMemory<float> *c,
844                        DeviceMemory<std::complex<float>> *s);
845   Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
846                        DeviceMemory<std::complex<double>> *b,
847                        DeviceMemory<double> *c,
848                        DeviceMemory<std::complex<double>> *s);
849 
850   // See BlasSupport::DoBlasRotm.
  // Real element types only. x and y are non-const; param is a read-only
  // device buffer. NOTE(review): presumably the 5-element modified-rotation
  // parameter array of reference-BLAS ?ROTM - confirm.
851   Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
852                        DeviceMemory<float> *y, int incy,
853                        const DeviceMemory<float> &param);
854   Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
855                        DeviceMemory<double> *y, int incy,
856                        const DeviceMemory<double> &param);
857 
858   // See BlasSupport::DoBlasRotmg.
  // Real element types only. y1 is read-only; d1, d2, x1 and param are
  // non-const. NOTE(review): param presumably receives the modified-rotation
  // parameters that ThenBlasRotm consumes - confirm.
859   Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
860                         DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
861                         DeviceMemory<float> *param);
862   Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
863                         DeviceMemory<double> *x1,
864                         const DeviceMemory<double> &y1,
865                         DeviceMemory<double> *param);
866 
867   // See BlasSupport::DoBlasScal.
  // Six overloads: real alpha with real x, real alpha with complex x, and
  // complex alpha with complex x. x is the only buffer and is non-const
  // (presumably scaled in place).
868   Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
869                        int incx);
870   Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
871                        int incx);
872   Stream &ThenBlasScal(uint64 elem_count, float alpha,
873                        DeviceMemory<std::complex<float>> *x, int incx);
874   Stream &ThenBlasScal(uint64 elem_count, double alpha,
875                        DeviceMemory<std::complex<double>> *x, int incx);
876   Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
877                        DeviceMemory<std::complex<float>> *x, int incx);
878   Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
879                        DeviceMemory<std::complex<double>> *x, int incx);
880 
881   // See BlasSupport::DoBlasSwap.
  // x and y are both non-const; presumably their first elem_count strided
  // elements are exchanged.
882   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
883                        DeviceMemory<float> *y, int incy);
884   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
885                        DeviceMemory<double> *y, int incy);
886   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
887                        int incx, DeviceMemory<std::complex<float>> *y,
888                        int incy);
889   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
890                        int incx, DeviceMemory<std::complex<double>> *y,
891                        int incy);
892 
893   // See BlasSupport::DoBlasIamax.
  // For every element type the result is a single int in device memory.
  // NOTE(review): whether the index is 0- or 1-based (reference BLAS and
  // cuBLAS are 1-based) is defined by BlasSupport - confirm before use.
894   Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
895                         int incx, DeviceMemory<int> *result);
896   Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
897                         int incx, DeviceMemory<int> *result);
898   Stream &ThenBlasIamax(uint64 elem_count,
899                         const DeviceMemory<std::complex<float>> &x, int incx,
900                         DeviceMemory<int> *result);
901   Stream &ThenBlasIamax(uint64 elem_count,
902                         const DeviceMemory<std::complex<double>> &x, int incx,
903                         DeviceMemory<int> *result);
904 
905   // See BlasSupport::DoBlasIamin.
  // For every element type the result is a single int in device memory.
  // NOTE(review): whether the index is 0- or 1-based (cuBLAS is 1-based) is
  // defined by BlasSupport - confirm before use.
906   Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
907                         int incx, DeviceMemory<int> *result);
908   Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
909                         int incx, DeviceMemory<int> *result);
910   Stream &ThenBlasIamin(uint64 elem_count,
911                         const DeviceMemory<std::complex<float>> &x, int incx,
912                         DeviceMemory<int> *result);
913   Stream &ThenBlasIamin(uint64 elem_count,
914                         const DeviceMemory<std::complex<double>> &x, int incx,
915                         DeviceMemory<int> *result);
916 
917   // See BlasSupport::DoBlasGbmv.
  // Overloads for float, double, complex<float> and complex<double>. a and x
  // are read-only; y is written. NOTE(review): kl/ku presumably are the
  // sub-/super-diagonal counts of the band matrix a, as in reference-BLAS
  // ?GBMV - confirm.
918   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
919                        uint64 ku, float alpha, const DeviceMemory<float> &a,
920                        int lda, const DeviceMemory<float> &x, int incx,
921                        float beta, DeviceMemory<float> *y, int incy);
922   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
923                        uint64 ku, double alpha, const DeviceMemory<double> &a,
924                        int lda, const DeviceMemory<double> &x, int incx,
925                        double beta, DeviceMemory<double> *y, int incy);
926   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
927                        uint64 ku, std::complex<float> alpha,
928                        const DeviceMemory<std::complex<float>> &a, int lda,
929                        const DeviceMemory<std::complex<float>> &x, int incx,
930                        std::complex<float> beta,
931                        DeviceMemory<std::complex<float>> *y, int incy);
932   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
933                        uint64 ku, std::complex<double> alpha,
934                        const DeviceMemory<std::complex<double>> &a, int lda,
935                        const DeviceMemory<std::complex<double>> &x, int incx,
936                        std::complex<double> beta,
937                        DeviceMemory<std::complex<double>> *y, int incy);
938 
939   // See BlasSupport::DoBlasGemv.
  // Overloads for float, double, complex<float> and complex<double>. a and x
  // are read-only and y is written. trans selects the operation applied to a
  // (presumably none / transpose / conjugate-transpose - confirm in blas.h).
940   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
941                        const DeviceMemory<float> &a, int lda,
942                        const DeviceMemory<float> &x, int incx, float beta,
943                        DeviceMemory<float> *y, int incy);
944   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
945                        const DeviceMemory<double> &a, int lda,
946                        const DeviceMemory<double> &x, int incx, double beta,
947                        DeviceMemory<double> *y, int incy);
948   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
949                        std::complex<float> alpha,
950                        const DeviceMemory<std::complex<float>> &a, int lda,
951                        const DeviceMemory<std::complex<float>> &x, int incx,
952                        std::complex<float> beta,
953                        DeviceMemory<std::complex<float>> *y, int incy);
954   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
955                        std::complex<double> alpha,
956                        const DeviceMemory<std::complex<double>> &a, int lda,
957                        const DeviceMemory<std::complex<double>> &x, int incx,
958                        std::complex<double> beta,
959                        DeviceMemory<std::complex<double>> *y, int incy);
960 
  // Same parameters as ThenBlasGemv, plus a trailing output_profile_result
  // out-parameter of type blas::ProfileResult*.
  // NOTE(review): presumably receives profiling/timing information for the
  // call; whether nullptr is accepted is not visible here - confirm.
961   Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
962                                     float alpha, const DeviceMemory<float> &a,
963                                     int lda, const DeviceMemory<float> &x,
964                                     int incx, float beta,
965                                     DeviceMemory<float> *y, int incy,
966                                     blas::ProfileResult *output_profile_result);
967   Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
968                                     double alpha, const DeviceMemory<double> &a,
969                                     int lda, const DeviceMemory<double> &x,
970                                     int incx, double beta,
971                                     DeviceMemory<double> *y, int incy,
972                                     blas::ProfileResult *output_profile_result);
973   Stream &ThenBlasGemvWithProfiling(
974       blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
975       const DeviceMemory<std::complex<float>> &a, int lda,
976       const DeviceMemory<std::complex<float>> &x, int incx,
977       std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
978       blas::ProfileResult *output_profile_result);
979   Stream &ThenBlasGemvWithProfiling(
980       blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
981       const DeviceMemory<std::complex<double>> &a, int lda,
982       const DeviceMemory<std::complex<double>> &x, int incx,
983       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
984       int incy, blas::ProfileResult *output_profile_result);
985 
986   // See BlasSupport::DoBlasGer.
  // Real element types only (complex uses ThenBlasGerc / ThenBlasGeru). x and
  // y are read-only; a is written. NOTE(review): presumably the rank-1 update
  // a = alpha*x*y^T + a.
987   Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
988                       const DeviceMemory<float> &x, int incx,
989                       const DeviceMemory<float> &y, int incy,
990                       DeviceMemory<float> *a, int lda);
991   Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
992                       const DeviceMemory<double> &x, int incx,
993                       const DeviceMemory<double> &y, int incy,
994                       DeviceMemory<double> *a, int lda);
995 
996   // See BlasSupport::DoBlasGerc.
  // Complex element types only; x and y are read-only, a is written.
  // NOTE(review): presumably the conjugated rank-1 update a = alpha*x*y^H + a.
997   Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
998                        const DeviceMemory<std::complex<float>> &x, int incx,
999                        const DeviceMemory<std::complex<float>> &y, int incy,
1000                        DeviceMemory<std::complex<float>> *a, int lda);
1001   Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
1002                        const DeviceMemory<std::complex<double>> &x, int incx,
1003                        const DeviceMemory<std::complex<double>> &y, int incy,
1004                        DeviceMemory<std::complex<double>> *a, int lda);
1005 
1006   // See BlasSupport::DoBlasGeru.
  // Complex element types only; x and y are read-only, a is written.
  // NOTE(review): presumably the unconjugated rank-1 update a = alpha*x*y^T + a.
1007   Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
1008                        const DeviceMemory<std::complex<float>> &x, int incx,
1009                        const DeviceMemory<std::complex<float>> &y, int incy,
1010                        DeviceMemory<std::complex<float>> *a, int lda);
1011   Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
1012                        const DeviceMemory<std::complex<double>> &x, int incx,
1013                        const DeviceMemory<std::complex<double>> &y, int incy,
1014                        DeviceMemory<std::complex<double>> *a, int lda);
1015 
1016   // See BlasSupport::DoBlasHbmv.
  // Complex element types only; a and x are read-only, y is written.
  // NOTE(review): per reference-BLAS ?HBMV, uplo presumably selects the
  // referenced triangle and k the number of super-diagonals of a - confirm.
1017   Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
1018                        std::complex<float> alpha,
1019                        const DeviceMemory<std::complex<float>> &a, int lda,
1020                        const DeviceMemory<std::complex<float>> &x, int incx,
1021                        std::complex<float> beta,
1022                        DeviceMemory<std::complex<float>> *y, int incy);
1023   Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
1024                        std::complex<double> alpha,
1025                        const DeviceMemory<std::complex<double>> &a, int lda,
1026                        const DeviceMemory<std::complex<double>> &x, int incx,
1027                        std::complex<double> beta,
1028                        DeviceMemory<std::complex<double>> *y, int incy);
1029 
1030   // See BlasSupport::DoBlasHemv.
  // Complex element types only; a and x are read-only, y is written.
  // uplo presumably selects the referenced triangle of a.
1031   Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
1032                        std::complex<float> alpha,
1033                        const DeviceMemory<std::complex<float>> &a, int lda,
1034                        const DeviceMemory<std::complex<float>> &x, int incx,
1035                        std::complex<float> beta,
1036                        DeviceMemory<std::complex<float>> *y, int incy);
1037   Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
1038                        std::complex<double> alpha,
1039                        const DeviceMemory<std::complex<double>> &a, int lda,
1040                        const DeviceMemory<std::complex<double>> &x, int incx,
1041                        std::complex<double> beta,
1042                        DeviceMemory<std::complex<double>> *y, int incy);
1043 
1044   // See BlasSupport::DoBlasHer.
  // Complex data with a real-typed alpha (float/double). x is read-only and
  // a is written; uplo presumably selects the updated triangle.
1045   Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
1046                       const DeviceMemory<std::complex<float>> &x, int incx,
1047                       DeviceMemory<std::complex<float>> *a, int lda);
1048   Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
1049                       const DeviceMemory<std::complex<double>> &x, int incx,
1050                       DeviceMemory<std::complex<double>> *a, int lda);
1051 
1052   // See BlasSupport::DoBlasHer2.
  // Complex data and complex alpha. x and y are read-only; a is written.
1053   Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
1054                        std::complex<float> alpha,
1055                        const DeviceMemory<std::complex<float>> &x, int incx,
1056                        const DeviceMemory<std::complex<float>> &y, int incy,
1057                        DeviceMemory<std::complex<float>> *a, int lda);
1058   Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
1059                        std::complex<double> alpha,
1060                        const DeviceMemory<std::complex<double>> &x, int incx,
1061                        const DeviceMemory<std::complex<double>> &y, int incy,
1062                        DeviceMemory<std::complex<double>> *a, int lda);
1063 
1064   // See BlasSupport::DoBlasHpmv.
  // Complex element types only. The matrix is passed as ap with no
  // leading-dimension argument (presumably packed storage, per reference-BLAS
  // ?HPMV). ap and x are read-only; y is written.
1065   Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
1066                        std::complex<float> alpha,
1067                        const DeviceMemory<std::complex<float>> &ap,
1068                        const DeviceMemory<std::complex<float>> &x, int incx,
1069                        std::complex<float> beta,
1070                        DeviceMemory<std::complex<float>> *y, int incy);
1071   Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
1072                        std::complex<double> alpha,
1073                        const DeviceMemory<std::complex<double>> &ap,
1074                        const DeviceMemory<std::complex<double>> &x, int incx,
1075                        std::complex<double> beta,
1076                        DeviceMemory<std::complex<double>> *y, int incy);
1077 
1078   // See BlasSupport::DoBlasHpr.
  // Complex data with a real-typed alpha. x is read-only and ap (no leading
  // dimension; presumably packed storage) is written.
1079   Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
1080                       const DeviceMemory<std::complex<float>> &x, int incx,
1081                       DeviceMemory<std::complex<float>> *ap);
1082   Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
1083                       const DeviceMemory<std::complex<double>> &x, int incx,
1084                       DeviceMemory<std::complex<double>> *ap);
1085 
1086   // See BlasSupport::DoBlasHpr2.
  // Complex data and complex alpha. x and y are read-only and ap (no leading
  // dimension; presumably packed storage) is written.
1087   Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
1088                        std::complex<float> alpha,
1089                        const DeviceMemory<std::complex<float>> &x, int incx,
1090                        const DeviceMemory<std::complex<float>> &y, int incy,
1091                        DeviceMemory<std::complex<float>> *ap);
1092   Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
1093                        std::complex<double> alpha,
1094                        const DeviceMemory<std::complex<double>> &x, int incx,
1095                        const DeviceMemory<std::complex<double>> &y, int incy,
1096                        DeviceMemory<std::complex<double>> *ap);
1097 
1098   // See BlasSupport::DoBlasSbmv.
  // Real element types only; a and x are read-only, y is written.
  // NOTE(review): k presumably is the number of super-diagonals of the
  // symmetric band matrix a - confirm.
1099   Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
1100                        const DeviceMemory<float> &a, int lda,
1101                        const DeviceMemory<float> &x, int incx, float beta,
1102                        DeviceMemory<float> *y, int incy);
1103   Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
1104                        const DeviceMemory<double> &a, int lda,
1105                        const DeviceMemory<double> &x, int incx, double beta,
1106                        DeviceMemory<double> *y, int incy);
1107 
1108   // See BlasSupport::DoBlasSpmv.
  // Real element types only. ap (no leading dimension; presumably packed
  // symmetric storage) and x are read-only; y is written.
1109   Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
1110                        const DeviceMemory<float> &ap,
1111                        const DeviceMemory<float> &x, int incx, float beta,
1112                        DeviceMemory<float> *y, int incy);
1113   Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
1114                        const DeviceMemory<double> &ap,
1115                        const DeviceMemory<double> &x, int incx, double beta,
1116                        DeviceMemory<double> *y, int incy);
1117 
1118   // See BlasSupport::DoBlasSpr.
  // Real element types only. x is read-only and ap (no leading dimension;
  // presumably packed storage) is written.
1119   Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
1120                       const DeviceMemory<float> &x, int incx,
1121                       DeviceMemory<float> *ap);
1122   Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
1123                       const DeviceMemory<double> &x, int incx,
1124                       DeviceMemory<double> *ap);
1125 
1126   // See BlasSupport::DoBlasSpr2.
       // Symmetric packed rank-2 update; float and double overloads.
       // Assumed reference-BLAS ?SPR2 semantics (confirm against DoBlasSpr2):
       //   A := alpha*x*y**T + alpha*y*x**T + A,
       // where A is n x n symmetric in packed storage `ap`, updated in place.
1127   Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
1128                        const DeviceMemory<float> &x, int incx,
1129                        const DeviceMemory<float> &y, int incy,
1130                        DeviceMemory<float> *ap);
1131   Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
1132                        const DeviceMemory<double> &x, int incx,
1133                        const DeviceMemory<double> &y, int incy,
1134                        DeviceMemory<double> *ap);
1135 
1136   // See BlasSupport::DoBlasSymv.
       // Symmetric matrix-vector multiply; float and double overloads.
       // Assumed reference-BLAS ?SYMV semantics (confirm against DoBlasSymv):
       //   y := alpha*A*x + beta*y,
       // where A is an n x n symmetric matrix in `a` (leading dimension lda);
       // only the triangle selected by `uplo` is referenced. `y` is the output.
1137   Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
1138                        const DeviceMemory<float> &a, int lda,
1139                        const DeviceMemory<float> &x, int incx, float beta,
1140                        DeviceMemory<float> *y, int incy);
1141   Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
1142                        const DeviceMemory<double> &a, int lda,
1143                        const DeviceMemory<double> &x, int incx, double beta,
1144                        DeviceMemory<double> *y, int incy);
1145 
1146   // See BlasSupport::DoBlasSyr.
       // Symmetric rank-1 update; float and double overloads.
       // Assumed reference-BLAS ?SYR semantics (confirm against DoBlasSyr):
       //   A := alpha*x*x**T + A,
       // where A is n x n symmetric in `a` (leading dimension lda), updated in
       // place (`a` is the only non-const operand).
1147   Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
1148                       const DeviceMemory<float> &x, int incx,
1149                       DeviceMemory<float> *a, int lda);
1150   Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
1151                       const DeviceMemory<double> &x, int incx,
1152                       DeviceMemory<double> *a, int lda);
1153 
1154   // See BlasSupport::DoBlasSyr2.
       // Symmetric rank-2 update; float and double overloads.
       // Assumed reference-BLAS ?SYR2 semantics (confirm against DoBlasSyr2):
       //   A := alpha*x*y**T + alpha*y*x**T + A,
       // where A is n x n symmetric in `a` (leading dimension lda), updated in
       // place.
1155   Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
1156                        const DeviceMemory<float> &x, int incx,
1157                        const DeviceMemory<float> &y, int incy,
1158                        DeviceMemory<float> *a, int lda);
1159   Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
1160                        const DeviceMemory<double> &x, int incx,
1161                        const DeviceMemory<double> &y, int incy,
1162                        DeviceMemory<double> *a, int lda);
1163 
1164   // See BlasSupport::DoBlasTbmv.
       // Triangular band matrix-vector multiply; overloads for float, double,
       // std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?TBMV semantics (confirm against DoBlasTbmv):
       //   x := op(A)*x,
       // where op() is chosen by `trans`, A is an n x n triangular band matrix
       // with k off-diagonals in `a` (leading dimension lda), and `diag` says
       // whether A has an implicit unit diagonal. `x` is overwritten in place.
1165   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1166                        blas::Diagonal diag, uint64 n, uint64 k,
1167                        const DeviceMemory<float> &a, int lda,
1168                        DeviceMemory<float> *x, int incx);
1169   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1170                        blas::Diagonal diag, uint64 n, uint64 k,
1171                        const DeviceMemory<double> &a, int lda,
1172                        DeviceMemory<double> *x, int incx);
1173   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1174                        blas::Diagonal diag, uint64 n, uint64 k,
1175                        const DeviceMemory<std::complex<float>> &a, int lda,
1176                        DeviceMemory<std::complex<float>> *x, int incx);
1177   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1178                        blas::Diagonal diag, uint64 n, uint64 k,
1179                        const DeviceMemory<std::complex<double>> &a, int lda,
1180                        DeviceMemory<std::complex<double>> *x, int incx);
1181 
1182   // See BlasSupport::DoBlasTbsv.
       // Triangular band solve; overloads for float, double,
       // std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?TBSV semantics (confirm against DoBlasTbsv):
       // solves op(A)*x = b, where b is the incoming contents of `x` and `x`
       // is overwritten with the solution. A is an n x n triangular band
       // matrix with k off-diagonals in `a` (leading dimension lda).
1183   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1184                        blas::Diagonal diag, uint64 n, uint64 k,
1185                        const DeviceMemory<float> &a, int lda,
1186                        DeviceMemory<float> *x, int incx);
1187   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1188                        blas::Diagonal diag, uint64 n, uint64 k,
1189                        const DeviceMemory<double> &a, int lda,
1190                        DeviceMemory<double> *x, int incx);
1191   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1192                        blas::Diagonal diag, uint64 n, uint64 k,
1193                        const DeviceMemory<std::complex<float>> &a, int lda,
1194                        DeviceMemory<std::complex<float>> *x, int incx);
1195   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1196                        blas::Diagonal diag, uint64 n, uint64 k,
1197                        const DeviceMemory<std::complex<double>> &a, int lda,
1198                        DeviceMemory<std::complex<double>> *x, int incx);
1199 
1200   // See BlasSupport::DoBlasTpmv.
       // Triangular packed matrix-vector multiply; overloads for float,
       // double, std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?TPMV semantics (confirm against DoBlasTpmv):
       //   x := op(A)*x,
       // where A is n x n triangular in packed storage `ap` (no leading
       // dimension) and `x` is overwritten in place.
1201   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1202                        blas::Diagonal diag, uint64 n,
1203                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1204                        int incx);
1205   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1206                        blas::Diagonal diag, uint64 n,
1207                        const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1208                        int incx);
1209   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1210                        blas::Diagonal diag, uint64 n,
1211                        const DeviceMemory<std::complex<float>> &ap,
1212                        DeviceMemory<std::complex<float>> *x, int incx);
1213   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1214                        blas::Diagonal diag, uint64 n,
1215                        const DeviceMemory<std::complex<double>> &ap,
1216                        DeviceMemory<std::complex<double>> *x, int incx);
1217 
1218   // See BlasSupport::DoBlasTpsv.
       // Triangular packed solve; overloads for float, double,
       // std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?TPSV semantics (confirm against DoBlasTpsv):
       // solves op(A)*x = b in place (`x` holds b on entry and the solution on
       // return), with A n x n triangular in packed storage `ap`.
1219   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1220                        blas::Diagonal diag, uint64 n,
1221                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1222                        int incx);
1223   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1224                        blas::Diagonal diag, uint64 n,
1225                        const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1226                        int incx);
1227   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1228                        blas::Diagonal diag, uint64 n,
1229                        const DeviceMemory<std::complex<float>> &ap,
1230                        DeviceMemory<std::complex<float>> *x, int incx);
1231   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1232                        blas::Diagonal diag, uint64 n,
1233                        const DeviceMemory<std::complex<double>> &ap,
1234                        DeviceMemory<std::complex<double>> *x, int incx);
1235 
1236   // See BlasSupport::DoBlasTrmv.
       // Triangular matrix-vector multiply; overloads for float, double,
       // std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?TRMV semantics (confirm against DoBlasTrmv):
       //   x := op(A)*x,
       // where A is n x n triangular in `a` (leading dimension lda) and `x` is
       // overwritten in place.
1237   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1238                        blas::Diagonal diag, uint64 n,
1239                        const DeviceMemory<float> &a, int lda,
1240                        DeviceMemory<float> *x, int incx);
1241   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1242                        blas::Diagonal diag, uint64 n,
1243                        const DeviceMemory<double> &a, int lda,
1244                        DeviceMemory<double> *x, int incx);
1245   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1246                        blas::Diagonal diag, uint64 n,
1247                        const DeviceMemory<std::complex<float>> &a, int lda,
1248                        DeviceMemory<std::complex<float>> *x, int incx);
1249   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1250                        blas::Diagonal diag, uint64 n,
1251                        const DeviceMemory<std::complex<double>> &a, int lda,
1252                        DeviceMemory<std::complex<double>> *x, int incx);
1253 
1254   // See BlasSupport::DoBlasTrsv.
       // Triangular solve with a single right-hand side; overloads for float,
       // double, std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?TRSV semantics (confirm against DoBlasTrsv):
       // solves op(A)*x = b in place (`x` holds b on entry and the solution on
       // return), with A n x n triangular in `a` (leading dimension lda).
1255   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1256                        blas::Diagonal diag, uint64 n,
1257                        const DeviceMemory<float> &a, int lda,
1258                        DeviceMemory<float> *x, int incx);
1259   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1260                        blas::Diagonal diag, uint64 n,
1261                        const DeviceMemory<double> &a, int lda,
1262                        DeviceMemory<double> *x, int incx);
1263   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1264                        blas::Diagonal diag, uint64 n,
1265                        const DeviceMemory<std::complex<float>> &a, int lda,
1266                        DeviceMemory<std::complex<float>> *x, int incx);
1267   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1268                        blas::Diagonal diag, uint64 n,
1269                        const DeviceMemory<std::complex<double>> &a, int lda,
1270                        DeviceMemory<std::complex<double>> *x, int incx);
1271 
1272   // See BlasSupport::DoBlasGemm.
       // General matrix-matrix multiply. Assumed reference-BLAS ?GEMM
       // semantics (confirm against DoBlasGemm):
       //   C := alpha*op(A)*op(B) + beta*C,
       // where op(A) is m x k, op(B) is k x n, C is m x n, and op() is chosen
       // by transa/transb. Overloads: Eigen::half (alpha/beta are float, not
       // half), float, double, std::complex<float>, std::complex<double>.
       // NOTE(review): unlike the other declarations in this section, these
       // overloads carry TF_EXPORT, presumably to make them visible across
       // shared-library boundaries; confirm against the macro's definition.
1273   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1274                                  uint64 m, uint64 n, uint64 k, float alpha,
1275                                  const DeviceMemory<Eigen::half> &a, int lda,
1276                                  const DeviceMemory<Eigen::half> &b, int ldb,
1277                                  float beta, DeviceMemory<Eigen::half> *c,
1278                                  int ldc);
1279   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1280                                  uint64 m, uint64 n, uint64 k, float alpha,
1281                                  const DeviceMemory<float> &a, int lda,
1282                                  const DeviceMemory<float> &b, int ldb,
1283                                  float beta, DeviceMemory<float> *c, int ldc);
1284   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1285                                  uint64 m, uint64 n, uint64 k, double alpha,
1286                                  const DeviceMemory<double> &a, int lda,
1287                                  const DeviceMemory<double> &b, int ldb,
1288                                  double beta, DeviceMemory<double> *c, int ldc);
1289   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1290                                  uint64 m, uint64 n, uint64 k,
1291                                  std::complex<float> alpha,
1292                                  const DeviceMemory<std::complex<float>> &a,
1293                                  int lda,
1294                                  const DeviceMemory<std::complex<float>> &b,
1295                                  int ldb, std::complex<float> beta,
1296                                  DeviceMemory<std::complex<float>> *c, int ldc);
1297   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1298                                  uint64 m, uint64 n, uint64 k,
1299                                  std::complex<double> alpha,
1300                                  const DeviceMemory<std::complex<double>> &a,
1301                                  int lda,
1302                                  const DeviceMemory<std::complex<double>> &b,
1303                                  int ldb, std::complex<double> beta,
1304                                  DeviceMemory<std::complex<double>> *c,
1305                                  int ldc);
1306 
       // GEMM with profiling. The operand lists and overload set are identical
       // to ThenBlasGemm, plus a trailing `output_profile_result` out-parameter.
       // Presumably dispatches to BlasSupport::DoBlasGemmWithProfiling and
       // fills the profile result when it is non-null; confirm in the
       // implementation.
1307   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1308                                     blas::Transpose transb, uint64 m, uint64 n,
1309                                     uint64 k, float alpha,
1310                                     const DeviceMemory<Eigen::half> &a, int lda,
1311                                     const DeviceMemory<Eigen::half> &b, int ldb,
1312                                     float beta, DeviceMemory<Eigen::half> *c,
1313                                     int ldc,
1314                                     blas::ProfileResult *output_profile_result);
1315   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1316                                     blas::Transpose transb, uint64 m, uint64 n,
1317                                     uint64 k, float alpha,
1318                                     const DeviceMemory<float> &a, int lda,
1319                                     const DeviceMemory<float> &b, int ldb,
1320                                     float beta, DeviceMemory<float> *c, int ldc,
1321                                     blas::ProfileResult *output_profile_result);
1322   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1323                                     blas::Transpose transb, uint64 m, uint64 n,
1324                                     uint64 k, double alpha,
1325                                     const DeviceMemory<double> &a, int lda,
1326                                     const DeviceMemory<double> &b, int ldb,
1327                                     double beta, DeviceMemory<double> *c,
1328                                     int ldc,
1329                                     blas::ProfileResult *output_profile_result);
1330   Stream &ThenBlasGemmWithProfiling(
1331       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1332       uint64 k, std::complex<float> alpha,
1333       const DeviceMemory<std::complex<float>> &a, int lda,
1334       const DeviceMemory<std::complex<float>> &b, int ldb,
1335       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1336       blas::ProfileResult *output_profile_result);
1337   Stream &ThenBlasGemmWithProfiling(
1338       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1339       uint64 k, std::complex<double> alpha,
1340       const DeviceMemory<std::complex<double>> &a, int lda,
1341       const DeviceMemory<std::complex<double>> &b, int ldb,
1342       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1343       blas::ProfileResult *output_profile_result);
1344 
1345   // See BlasSupport::DoBlasGemmWithAlgorithm.
       // GEMM with an explicitly chosen algorithm and computation type.
       // Differences from ThenBlasGemm visible in the signatures:
       //  - alpha/beta are HostOrDeviceScalar<T> rather than plain values; per
       //    the type name they may reside in host or device memory.
       //  - `computation_type` and `algorithm` select how the product is
       //    computed, and `output_profile_result` receives a blas::ProfileResult
       //    (presumably optional/nullable; confirm in DoBlasGemmWithAlgorithm).
       // Overloads: Eigen::half; int8 A/B with int scalars and int output C;
       // float; double; std::complex<float>; std::complex<double>.
1346   Stream &ThenBlasGemmWithAlgorithm(
1347       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1348       uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1349       const DeviceMemory<Eigen::half> &a, int lda,
1350       const DeviceMemory<Eigen::half> &b, int ldb,
1351       const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1352       int ldc, blas::ComputationType computation_type,
1353       blas::AlgorithmType algorithm,
1354       blas::ProfileResult *output_profile_result);
1355   Stream &ThenBlasGemmWithAlgorithm(
1356       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1357       uint64 k, const HostOrDeviceScalar<int> &alpha,
1358       const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
1359       int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,
1360       int ldc, blas::ComputationType computation_type,
1361       blas::AlgorithmType algorithm,
1362       blas::ProfileResult *output_profile_result);
1363   Stream &ThenBlasGemmWithAlgorithm(
1364       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1365       uint64 k, const HostOrDeviceScalar<float> &alpha,
1366       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
1367       int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
1368       int ldc, blas::ComputationType computation_type,
1369       blas::AlgorithmType algorithm,
1370       blas::ProfileResult *output_profile_result);
1371   Stream &ThenBlasGemmWithAlgorithm(
1372       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1373       uint64 k, const HostOrDeviceScalar<double> &alpha,
1374       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
1375       int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
1376       int ldc, blas::ComputationType computation_type,
1377       blas::AlgorithmType algorithm,
1378       blas::ProfileResult *output_profile_result);
1379   Stream &ThenBlasGemmWithAlgorithm(
1380       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1381       uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1382       const DeviceMemory<std::complex<float>> &a, int lda,
1383       const DeviceMemory<std::complex<float>> &b, int ldb,
1384       const HostOrDeviceScalar<std::complex<float>> &beta,
1385       DeviceMemory<std::complex<float>> *c, int ldc,
1386       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1387       blas::ProfileResult *output_profile_result);
1388   Stream &ThenBlasGemmWithAlgorithm(
1389       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1390       uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1391       const DeviceMemory<std::complex<double>> &a, int lda,
1392       const DeviceMemory<std::complex<double>> &b, int ldb,
1393       const HostOrDeviceScalar<std::complex<double>> &beta,
1394       DeviceMemory<std::complex<double>> *c, int ldc,
1395       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1396       blas::ProfileResult *output_profile_result);
1397 
1398   // See BlasSupport::DoBlasGemmBatched.
       // Batched GEMM: batch_count independent products, each assumed to follow
       // reference-BLAS ?GEMM semantics, C_i := alpha*op(A_i)*op(B_i) + beta*C_i.
       // Each operand is an ArraySlice of DeviceMemory pointers, presumably one
       // entry per batch element (batch_count entries each); confirm in
       // DoBlasGemmBatched. All batch elements share transa/transb, m, n, k,
       // lda, ldb, ldc, alpha and beta, since each is a single parameter.
       // Overloads: Eigen::half (float alpha/beta), float, double,
       // std::complex<float>, std::complex<double>.
1399   Stream &ThenBlasGemmBatched(
1400       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1401       uint64 k, float alpha,
1402       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1403       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
1404       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
1405       int ldc, int batch_count);
1406   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
1407                               uint64 m, uint64 n, uint64 k, float alpha,
1408                               const port::ArraySlice<DeviceMemory<float> *> &a,
1409                               int lda,
1410                               const port::ArraySlice<DeviceMemory<float> *> &b,
1411                               int ldb, float beta,
1412                               const port::ArraySlice<DeviceMemory<float> *> &c,
1413                               int ldc, int batch_count);
1414   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
1415                               uint64 m, uint64 n, uint64 k, double alpha,
1416                               const port::ArraySlice<DeviceMemory<double> *> &a,
1417                               int lda,
1418                               const port::ArraySlice<DeviceMemory<double> *> &b,
1419                               int ldb, double beta,
1420                               const port::ArraySlice<DeviceMemory<double> *> &c,
1421                               int ldc, int batch_count);
1422   Stream &ThenBlasGemmBatched(
1423       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1424       uint64 k, std::complex<float> alpha,
1425       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
1426       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
1427       std::complex<float> beta,
1428       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
1429       int batch_count);
1430   Stream &ThenBlasGemmBatched(
1431       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1432       uint64 k, std::complex<double> alpha,
1433       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
1434       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
1435       std::complex<double> beta,
1436       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
1437       int batch_count);
       // Same operand lists and overload set as ThenBlasGemmBatched, plus a
       // trailing `scratch_allocator`. Presumably the implementation draws any
       // temporary device memory it needs from this allocator (an assumption
       // based on the name; confirm in the BlasSupport implementation).
1438   Stream &ThenBlasGemmBatchedWithScratch(
1439       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1440       uint64 k, float alpha,
1441       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1442       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
1443       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
1444       int ldc, int batch_count, ScratchAllocator *scratch_allocator);
1445   Stream &ThenBlasGemmBatchedWithScratch(
1446       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1447       uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
1448       int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
1449       float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
1450       int batch_count, ScratchAllocator *scratch_allocator);
1451   Stream &ThenBlasGemmBatchedWithScratch(
1452       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1453       uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
1454       int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
1455       double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
1456       int batch_count, ScratchAllocator *scratch_allocator);
1457   Stream &ThenBlasGemmBatchedWithScratch(
1458       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1459       uint64 k, std::complex<float> alpha,
1460       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
1461       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
1462       std::complex<float> beta,
1463       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
1464       int batch_count, ScratchAllocator *scratch_allocator);
1465   Stream &ThenBlasGemmBatchedWithScratch(
1466       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1467       uint64 k, std::complex<double> alpha,
1468       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
1469       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
1470       std::complex<double> beta,
1471       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
1472       int batch_count, ScratchAllocator *scratch_allocator);
       // Strided batched GEMM: unlike ThenBlasGemmBatched, each operand is a
       // single DeviceMemory buffer plus an int64 stride (stride_a, stride_b,
       // stride_c), presumably the element offset between consecutive matrices
       // of the batch; confirm in the BlasSupport implementation.
       // Overloads: Eigen::half (float alpha/beta), float, double,
       // std::complex<float>, std::complex<double>.
1473   Stream &ThenBlasGemmStridedBatched(
1474       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1475       uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
1476       int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
1477       int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
1478       int64 stride_c, int batch_count);
1479   Stream &ThenBlasGemmStridedBatched(
1480       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1481       uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1482       int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
1483       float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
1484       int batch_count);
1485   Stream &ThenBlasGemmStridedBatched(
1486       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1487       uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1488       int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
1489       double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
1490       int batch_count);
1491   Stream &ThenBlasGemmStridedBatched(
1492       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1493       uint64 k, std::complex<float> alpha,
1494       const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
1495       const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
1496       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1497       int64 stride_c, int batch_count);
1498   Stream &ThenBlasGemmStridedBatched(
1499       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1500       uint64 k, std::complex<double> alpha,
1501       const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
1502       const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
1503       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1504       int64 stride_c, int batch_count);
1505 
1506   // See BlasSupport::DoBlasHemm.
       // Hermitian matrix-matrix multiply; std::complex<float> and
       // std::complex<double> overloads only.
       // Assumed reference-BLAS ?HEMM semantics (confirm against DoBlasHemm):
       //   C := alpha*A*B + beta*C   (side = left), or
       //   C := alpha*B*A + beta*C   (side = right),
       // where A is Hermitian (stored triangle chosen by `uplo`) and B, C are
       // m x n. `c` is the only non-const operand.
1507   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
1508                        uint64 n, std::complex<float> alpha,
1509                        const DeviceMemory<std::complex<float>> &a, int lda,
1510                        const DeviceMemory<std::complex<float>> &b, int ldb,
1511                        std::complex<float> beta,
1512                        DeviceMemory<std::complex<float>> *c, int ldc);
1513   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
1514                        uint64 n, std::complex<double> alpha,
1515                        const DeviceMemory<std::complex<double>> &a, int lda,
1516                        const DeviceMemory<std::complex<double>> &b, int ldb,
1517                        std::complex<double> beta,
1518                        DeviceMemory<std::complex<double>> *c, int ldc);
1519 
1520   // See BlasSupport::DoBlasHerk.
       // Hermitian rank-k update; std::complex<float> and std::complex<double>
       // overloads. Note that alpha and beta are real (float/double), not
       // complex, which matches the real scaling factors of reference ?HERK.
       // Assumed reference-BLAS semantics (confirm against DoBlasHerk):
       //   C := alpha*A*A**H + beta*C   (trans = no transpose), or
       //   C := alpha*A**H*A + beta*C   (conjugate transpose),
       // where C is n x n Hermitian and k is the inner dimension.
1521   Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1522                        uint64 k, float alpha,
1523                        const DeviceMemory<std::complex<float>> &a, int lda,
1524                        float beta, DeviceMemory<std::complex<float>> *c,
1525                        int ldc);
1526   Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1527                        uint64 k, double alpha,
1528                        const DeviceMemory<std::complex<double>> &a, int lda,
1529                        double beta, DeviceMemory<std::complex<double>> *c,
1530                        int ldc);
1531 
1532   // See BlasSupport::DoBlasHer2k.
       // Hermitian rank-2k update; std::complex<float> and std::complex<double>
       // overloads. alpha is complex but beta is real (float/double), as in
       // reference ?HER2K.
       // Assumed reference-BLAS semantics (confirm against DoBlasHer2k):
       //   C := alpha*A*B**H + conj(alpha)*B*A**H + beta*C,
       // or the conjugate-transposed form selected by `trans`; C is n x n
       // Hermitian.
1533   Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1534                         uint64 k, std::complex<float> alpha,
1535                         const DeviceMemory<std::complex<float>> &a, int lda,
1536                         const DeviceMemory<std::complex<float>> &b, int ldb,
1537                         float beta, DeviceMemory<std::complex<float>> *c,
1538                         int ldc);
1539   Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1540                         uint64 k, std::complex<double> alpha,
1541                         const DeviceMemory<std::complex<double>> &a, int lda,
1542                         const DeviceMemory<std::complex<double>> &b, int ldb,
1543                         double beta, DeviceMemory<std::complex<double>> *c,
1544                         int ldc);
1545 
1546   // See BlasSupport::DoBlasSymm.
       // Symmetric matrix-matrix multiply; overloads for float, double,
       // std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?SYMM semantics (confirm against DoBlasSymm):
       //   C := alpha*A*B + beta*C   (side = left), or
       //   C := alpha*B*A + beta*C   (side = right),
       // where A is symmetric (stored triangle chosen by `uplo`) and B, C are
       // m x n.
1547   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1548                        uint64 n, float alpha, const DeviceMemory<float> &a,
1549                        int lda, const DeviceMemory<float> &b, int ldb,
1550                        float beta, DeviceMemory<float> *c, int ldc);
1551   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1552                        uint64 n, double alpha, const DeviceMemory<double> &a,
1553                        int lda, const DeviceMemory<double> &b, int ldb,
1554                        double beta, DeviceMemory<double> *c, int ldc);
1555   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1556                        uint64 n, std::complex<float> alpha,
1557                        const DeviceMemory<std::complex<float>> &a, int lda,
1558                        const DeviceMemory<std::complex<float>> &b, int ldb,
1559                        std::complex<float> beta,
1560                        DeviceMemory<std::complex<float>> *c, int ldc);
1561   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1562                        uint64 n, std::complex<double> alpha,
1563                        const DeviceMemory<std::complex<double>> &a, int lda,
1564                        const DeviceMemory<std::complex<double>> &b, int ldb,
1565                        std::complex<double> beta,
1566                        DeviceMemory<std::complex<double>> *c, int ldc);
1567 
1568   // See BlasSupport::DoBlasSyrk.
       // Symmetric rank-k update; overloads for float, double,
       // std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?SYRK semantics (confirm against DoBlasSyrk):
       //   C := alpha*A*A**T + beta*C   (trans = no transpose), or
       //   C := alpha*A**T*A + beta*C   (transpose),
       // where C is n x n symmetric and k is the inner dimension.
1569   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1570                        uint64 k, float alpha, const DeviceMemory<float> &a,
1571                        int lda, float beta, DeviceMemory<float> *c, int ldc);
1572   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1573                        uint64 k, double alpha, const DeviceMemory<double> &a,
1574                        int lda, double beta, DeviceMemory<double> *c, int ldc);
1575   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1576                        uint64 k, std::complex<float> alpha,
1577                        const DeviceMemory<std::complex<float>> &a, int lda,
1578                        std::complex<float> beta,
1579                        DeviceMemory<std::complex<float>> *c, int ldc);
1580   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1581                        uint64 k, std::complex<double> alpha,
1582                        const DeviceMemory<std::complex<double>> &a, int lda,
1583                        std::complex<double> beta,
1584                        DeviceMemory<std::complex<double>> *c, int ldc);
1585 
1586   // See BlasSupport::DoBlasSyr2k.
       // Symmetric rank-2k update; overloads for float, double,
       // std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?SYR2K semantics (confirm against DoBlasSyr2k):
       //   C := alpha*A*B**T + alpha*B*A**T + beta*C,
       // or the transposed form selected by `trans`; C is n x n symmetric.
1587   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1588                         uint64 k, float alpha, const DeviceMemory<float> &a,
1589                         int lda, const DeviceMemory<float> &b, int ldb,
1590                         float beta, DeviceMemory<float> *c, int ldc);
1591   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1592                         uint64 k, double alpha, const DeviceMemory<double> &a,
1593                         int lda, const DeviceMemory<double> &b, int ldb,
1594                         double beta, DeviceMemory<double> *c, int ldc);
1595   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1596                         uint64 k, std::complex<float> alpha,
1597                         const DeviceMemory<std::complex<float>> &a, int lda,
1598                         const DeviceMemory<std::complex<float>> &b, int ldb,
1599                         std::complex<float> beta,
1600                         DeviceMemory<std::complex<float>> *c, int ldc);
1601   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1602                         uint64 k, std::complex<double> alpha,
1603                         const DeviceMemory<std::complex<double>> &a, int lda,
1604                         const DeviceMemory<std::complex<double>> &b, int ldb,
1605                         std::complex<double> beta,
1606                         DeviceMemory<std::complex<double>> *c, int ldc);
1607 
1608   // See BlasSupport::DoBlasTrmm.
       // Triangular matrix-matrix multiply; overloads for float, double,
       // std::complex<float> and std::complex<double>.
       // Assumed reference-BLAS ?TRMM semantics (confirm against DoBlasTrmm):
       //   B := alpha*op(A)*B   (side = left), or
       //   B := alpha*B*op(A)   (side = right),
       // where A is triangular and B (m x n) is overwritten in place (`b` is
       // the only non-const operand).
1609   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1610                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1611                        uint64 n, float alpha, const DeviceMemory<float> &a,
1612                        int lda, DeviceMemory<float> *b, int ldb);
1613   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1614                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1615                        uint64 n, double alpha, const DeviceMemory<double> &a,
1616                        int lda, DeviceMemory<double> *b, int ldb);
1617   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1618                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1619                        uint64 n, std::complex<float> alpha,
1620                        const DeviceMemory<std::complex<float>> &a, int lda,
1621                        DeviceMemory<std::complex<float>> *b, int ldb);
1622   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1623                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1624                        uint64 n, std::complex<double> alpha,
1625                        const DeviceMemory<std::complex<double>> &a, int lda,
1626                        DeviceMemory<std::complex<double>> *b, int ldb);
1627 
1628   // See BlasSupport::DoBlasTrsm.
1629   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1630                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1631                        uint64 n, float alpha, const DeviceMemory<float> &a,
1632                        int lda, DeviceMemory<float> *b, int ldb);
1633   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1634                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1635                        uint64 n, double alpha, const DeviceMemory<double> &a,
1636                        int lda, DeviceMemory<double> *b, int ldb);
1637   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1638                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1639                        uint64 n, std::complex<float> alpha,
1640                        const DeviceMemory<std::complex<float>> &a, int lda,
1641                        DeviceMemory<std::complex<float>> *b, int ldb);
1642   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1643                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1644                        uint64 n, std::complex<double> alpha,
1645                        const DeviceMemory<std::complex<double>> &a, int lda,
1646                        DeviceMemory<std::complex<double>> *b, int ldb);
1647 
1648   // See BlasSupport::DoBlatLtMatmul.
1649   // Note that we prevent alpha and beta from being used to deduce CType so that
1650   // they can be constructed implicitly from values of type CType. Without this,
1651   // type deduction would fail when this function is called with a value of type
1652   // CType for alpha or beta.
1653   template <typename ABType, typename CType>
1654   Stream &ThenBlasLtMatmul(
1655       const blas::IBlasLtMatmulPlan *plan,
1656       const detail::NonDeducedType<HostOrDeviceScalar<CType>> &alpha,
1657       const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
1658       const detail::NonDeducedType<HostOrDeviceScalar<CType>> &beta,
1659       DeviceMemory<CType> *c, ScratchAllocator *scratch_allocator,
1660       const blas::IBlasLtMatmulAlgorithm *algorithm,
1661       const DeviceMemory<CType> &bias = {},
1662       blas::ProfileResult *output_profile_result = nullptr) {
1663     return ThenBlasLtMatmulImpl(plan, alpha, a, b, beta, c, scratch_allocator,
1664                                 algorithm, bias, output_profile_result);
1665   }
1666 
1667   // See FftSupport::DoFft.
1668   Stream &ThenFft(fft::Plan *plan,
1669                   const DeviceMemory<std::complex<float>> &input,
1670                   DeviceMemory<std::complex<float>> *output);
1671   Stream &ThenFft(fft::Plan *plan,
1672                   const DeviceMemory<std::complex<double>> &input,
1673                   DeviceMemory<std::complex<double>> *output);
1674   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
1675                   DeviceMemory<std::complex<float>> *output);
1676   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
1677                   DeviceMemory<std::complex<double>> *output);
1678   Stream &ThenFft(fft::Plan *plan,
1679                   const DeviceMemory<std::complex<float>> &input,
1680                   DeviceMemory<float> *output);
1681   Stream &ThenFft(fft::Plan *plan,
1682                   const DeviceMemory<std::complex<double>> &input,
1683                   DeviceMemory<double> *output);
1684 
1685   // Makes the RNG use the provided value as the basis for further generation.
1686   // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
1687   // sources of seed data if the default (high quality) sources are not
1688   // desired.
1689   // For most use cases, this function will not be necessary; each provided
1690   // back-end implementation will be appropriately seeded by default.
1691   // At a minimum 16 bytes of data are required in the seed buffer.
1692   //
1693   // To seed with good (non-reproducible) data:
1694   //   File* f = File::Open("/dev/random", "r");
1695   //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
1696   //   < error checking >
1697   //   stream.ThenSetRngSeed(seed_data, bytes_read);
1698   //
1699   // To seed with reproducible data:
1700   //   uint64_t seed_data[2] = { <data> };
1701   //   stream.ThenSetRngSeed(seed_data, 16);
1702   Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);
1703 
1704   // Populates the memory indicated by values with uniform-random-distribution
1705   // values. TODO(leary) seeding API/description
1706   //
1707   // Uses the type and size of the DeviceMemory to infer what data should be
1708   // populated.
1709   Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
1710   Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
1711   Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
1712   Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
1713   Stream &ThenPopulateRandGaussian(float mean, float stddev,
1714                                    DeviceMemory<float> *values);
1715   Stream &ThenPopulateRandGaussian(double mean, double stddev,
1716                                    DeviceMemory<double> *values);
1717 
1718   // Entrain onto the stream: a memcpy to a host destination from a GPU source
1719   // of the given target size. host_dst must be a pointer to host memory
1720   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
1721   // then registered with StreamExecutor::HostMemoryRegister.
1722   Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
1723                      uint64 size);
1724 
1725   // Entrain onto the stream: a memcpy to a GPU destination from a host source
1726   // of the given target size. host_src must be a pointer to host memory
1727   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
1728   // then registered with StreamExecutor::HostMemoryRegister.
1729   Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
1730                      uint64 size);
1731 
1732   // Alternative interface for memcpying from device to host that takes an
1733   // array slice. Checks that the destination size can accommodate the host
1734   // slice size.
1735   template <typename T>
ThenMemcpyD2H(const DeviceMemory<T> & gpu_src,port::MutableArraySlice<T> host_dst)1736   Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1737                         port::MutableArraySlice<T> host_dst) {
1738     auto host_size = host_dst.size() * sizeof(T);
1739     CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1740     return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1741   }
1742 
1743   // Alternative interface for memcpying from host to device that takes an
1744   // array slice. Checks that the destination size can accommodate the host
1745   // slice size.
1746   template <typename T>
ThenMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemory<T> * gpu_dst)1747   Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
1748                         DeviceMemory<T> *gpu_dst) {
1749     auto host_size = host_src.size() * sizeof(T);
1750     CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1751     return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1752   }
1753 
1754   // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
1755   // of the given target size. gpu_src/dst must be pointers to GPU memory and
1756   // peer access must be enabled between their owning StreamExecutors.
1757   Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
1758                      uint64 size);
1759 
1760   // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
1761   // ensuring that the host pointer isn't getting confused accidentally with a
1762   // device pointer if you're not doing metaprogramming against the API.
ThenMemcpyD2D(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)1763   Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
1764                         const DeviceMemoryBase &gpu_src, uint64 size) {
1765     return ThenMemcpy(gpu_dst, gpu_src, size);
1766   }
1767 
1768   // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
1769   // The location must not be null.
1770   Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);
1771 
1772   // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
1773   // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible
1774   // by 4). The location must not be null.
1775   Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
1776 
1777   // Enqueue a forward operation of the RNN model onto the stream.
1778   // See DnnSupport::DoRnnForward for more details.
1779   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1780                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1781                          const DeviceMemory<Eigen::half> &input_data,
1782                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1783                          const DeviceMemory<Eigen::half> &input_h_data,
1784                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1785                          const DeviceMemory<Eigen::half> &input_c_data,
1786                          const DeviceMemory<Eigen::half> &params,
1787                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1788                          DeviceMemory<Eigen::half> *output_data,
1789                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1790                          DeviceMemory<Eigen::half> *output_h_data,
1791                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1792                          DeviceMemory<Eigen::half> *output_c_data,
1793                          bool is_training,
1794                          ScratchAllocator *reserve_space_allocator,
1795                          ScratchAllocator *workspace_allocator,
1796                          dnn::ProfileResult *output_profile_result);
1797 
1798   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1799                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1800                          const DeviceMemory<float> &input_data,
1801                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1802                          const DeviceMemory<float> &input_h_data,
1803                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1804                          const DeviceMemory<float> &input_c_data,
1805                          const DeviceMemory<float> &params,
1806                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1807                          DeviceMemory<float> *output_data,
1808                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1809                          DeviceMemory<float> *output_h_data,
1810                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1811                          DeviceMemory<float> *output_c_data, bool is_training,
1812                          ScratchAllocator *reserve_space_allocator,
1813                          ScratchAllocator *workspace_allocator,
1814                          dnn::ProfileResult *output_profile_result);
1815 
1816   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1817                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1818                          const DeviceMemory<double> &input_data,
1819                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1820                          const DeviceMemory<double> &input_h_data,
1821                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1822                          const DeviceMemory<double> &input_c_data,
1823                          const DeviceMemory<double> &params,
1824                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1825                          DeviceMemory<double> *output_data,
1826                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1827                          DeviceMemory<double> *output_h_data,
1828                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1829                          DeviceMemory<double> *output_c_data, bool is_training,
1830                          ScratchAllocator *reserve_space_allocator,
1831                          ScratchAllocator *workspace_allocator,
1832                          dnn::ProfileResult *output_profile_result);
1833 
1834   // Enqueue a backward operation of the RNN model onto the stream.
1835   // See DnnSupport::DoRnnBackward for more details.
1836   Stream &ThenRnnBackward(
1837       const dnn::RnnDescriptor &rnn_desc,
1838       const dnn::RnnSequenceTensorDescriptor &input_desc,
1839       const DeviceMemory<Eigen::half> &input_data,
1840       const dnn::RnnStateTensorDescriptor &input_h_desc,
1841       const DeviceMemory<Eigen::half> &input_h_data,
1842       const dnn::RnnStateTensorDescriptor &input_c_desc,
1843       const DeviceMemory<Eigen::half> &input_c_data,
1844       const DeviceMemory<Eigen::half> &params,
1845       const dnn::RnnSequenceTensorDescriptor &output_desc,
1846       const DeviceMemory<Eigen::half> &output_data,
1847       const dnn::RnnStateTensorDescriptor &output_h_desc,
1848       const DeviceMemory<Eigen::half> &output_h_data,
1849       const dnn::RnnStateTensorDescriptor &output_c_desc,
1850       const DeviceMemory<Eigen::half> &output_c_data,
1851       const DeviceMemory<Eigen::half> &output_backprop_data,
1852       const DeviceMemory<Eigen::half> &output_h_backprop_data,
1853       const DeviceMemory<Eigen::half> &output_c_backprop_data,
1854       DeviceMemory<Eigen::half> *input_backprop_data,
1855       DeviceMemory<Eigen::half> *input_h_backprop_data,
1856       DeviceMemory<Eigen::half> *input_c_backprop_data,
1857       DeviceMemory<Eigen::half> *params_backprop_data,
1858       DeviceMemory<uint8> *reserve_space_data,
1859       ScratchAllocator *workspace_allocator,
1860       dnn::ProfileResult *output_profile_result);
1861 
1862   Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
1863                           const dnn::RnnSequenceTensorDescriptor &input_desc,
1864                           const DeviceMemory<float> &input_data,
1865                           const dnn::RnnStateTensorDescriptor &input_h_desc,
1866                           const DeviceMemory<float> &input_h_data,
1867                           const dnn::RnnStateTensorDescriptor &input_c_desc,
1868                           const DeviceMemory<float> &input_c_data,
1869                           const DeviceMemory<float> &params,
1870                           const dnn::RnnSequenceTensorDescriptor &output_desc,
1871                           const DeviceMemory<float> &output_data,
1872                           const dnn::RnnStateTensorDescriptor &output_h_desc,
1873                           const DeviceMemory<float> &output_h_data,
1874                           const dnn::RnnStateTensorDescriptor &output_c_desc,
1875                           const DeviceMemory<float> &output_c_data,
1876                           const DeviceMemory<float> &output_backprop_data,
1877                           const DeviceMemory<float> &output_h_backprop_data,
1878                           const DeviceMemory<float> &output_c_backprop_data,
1879                           DeviceMemory<float> *input_backprop_data,
1880                           DeviceMemory<float> *input_h_backprop_data,
1881                           DeviceMemory<float> *input_c_backprop_data,
1882                           DeviceMemory<float> *params_backprop_data,
1883                           DeviceMemory<uint8> *reserve_space_data,
1884                           ScratchAllocator *workspace_allocator,
1885                           dnn::ProfileResult *output_profile_result);
1886 
1887   Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
1888                           const dnn::RnnSequenceTensorDescriptor &input_desc,
1889                           const DeviceMemory<double> &input_data,
1890                           const dnn::RnnStateTensorDescriptor &input_h_desc,
1891                           const DeviceMemory<double> &input_h_data,
1892                           const dnn::RnnStateTensorDescriptor &input_c_desc,
1893                           const DeviceMemory<double> &input_c_data,
1894                           const DeviceMemory<double> &params,
1895                           const dnn::RnnSequenceTensorDescriptor &output_desc,
1896                           const DeviceMemory<double> &output_data,
1897                           const dnn::RnnStateTensorDescriptor &output_h_desc,
1898                           const DeviceMemory<double> &output_h_data,
1899                           const dnn::RnnStateTensorDescriptor &output_c_desc,
1900                           const DeviceMemory<double> &output_c_data,
1901                           const DeviceMemory<double> &output_backprop_data,
1902                           const DeviceMemory<double> &output_h_backprop_data,
1903                           const DeviceMemory<double> &output_c_backprop_data,
1904                           DeviceMemory<double> *input_backprop_data,
1905                           DeviceMemory<double> *input_h_backprop_data,
1906                           DeviceMemory<double> *input_c_backprop_data,
1907                           DeviceMemory<double> *params_backprop_data,
1908                           DeviceMemory<uint8> *reserve_space_data,
1909                           ScratchAllocator *workspace_allocator,
1910                           dnn::ProfileResult *output_profile_result);
1911 
1912   // Enqueue a CTCLoss operation onto the stream.
1913   // See DnnSupport::DoCtcLoss for more details.
1914   Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
1915                       const DeviceMemory<float> &probs_data,
1916                       absl::Span<const int> labels_data,
1917                       absl::Span<const int> labels_lengths_data,
1918                       absl::Span<const int> input_lengths_data,
1919                       DeviceMemory<float> *costs_data,
1920                       const dnn::RnnStateTensorDescriptor &grads_desc,
1921                       DeviceMemory<float> *grads_data,
1922                       ScratchAllocator *workspace_allocator);
1923 
1924   // Enqueue onto the stream a operation that transforms a tensor.
1925   // See DnnSupport::DoTransformTensor for more details.
1926   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
1927                               dnn::DataType input_type,
1928                               const DeviceMemoryBase &input_data,
1929                               const dnn::BatchDescriptor &output_desc,
1930                               dnn::DataType output_type, float scale,
1931                               DeviceMemoryBase *output_data);
1932 
1933   // The templated version of the above ThenTransformTensor. Useful when the
1934   // input and output types are statically known.
1935   template <typename InElemT, typename OutElemT>
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,const DeviceMemory<InElemT> & input_data,const dnn::BatchDescriptor & output_desc,DeviceMemory<OutElemT> * output_data)1936   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
1937                               const DeviceMemory<InElemT> &input_data,
1938                               const dnn::BatchDescriptor &output_desc,
1939                               DeviceMemory<OutElemT> *output_data) {
1940     return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
1941                                input_data, output_desc,
1942                                dnn::ToDataType<OutElemT>(), output_data);
1943   }
1944 
1945   // (Synchronously) block the host code waiting for the operations
1946   // entrained on the stream (enqueued to this point in program
1947   // execution) to complete.
1948   //
1949   // Returns an OK status if the blocking was successful and the stream is ok().
1950   // Otherwise returns an error describing why the blocking failed.
1951   port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_);
1952 
1953   // Warning! This method interacts with internal threads in
1954   // sometimes-unpredictable ways and is intended for GPU-Executor-internal
1955   // use
1956   // only. Please check with a member of the FASTR team before making use of
1957   // this method.
1958   //
1959   // Entrains onto the stream a function to be executed on the host at some
1960   // point in the future.
1961   // Async host callbacks DO NOT block the stream as device functions (or as
1962   // synchronous host callbacks). No synchronization is possible with
1963   // asynchronous callbacks; they are strictly fire-and-forget.
1964   // This method is private due to the potential for undefined behavior with
1965   // synchronization using OpenCL user events.
1966   // The ONLY lifetime guarantee in these calls is that the StreamExecutor
1967   // parameter will still be valid - this Stream may not be!
1968   // Any callbacks requiring device API calls must use this method.
1969   Stream &ThenEnqueueOnBackgroundThread(
1970       std::function<void(StreamExecutor *)> task);
1971 
1972   // Returns the (opaque) platform-specific backing object. Ownership is not
1973   // transferred to the caller.
implementation()1974   internal::StreamInterface *implementation() { return implementation_.get(); }
1975 
1976   // Entrains onto the stream a callback to the host (from the device).
1977   // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
1978   // never fail or its failure is inconsequential.
1979   //
1980   // This is kept for backward compatibility. Future code should use
1981   // ThenDoHostCallbackWithStatus and explicitly return a success status.
1982   // TODO(b/112125301): Eventually remove this method.
1983   Stream &ThenDoHostCallback(std::function<void()> callback);
1984 
1985   // Entrains onto the stream a callback to the host (from the device).
1986   // Host callbacks block/occupy the stream just as device functions
1987   // (execute one at a time, block later stream operations).
1988   // Whether the callback return status affects the result of BlockHostUntilDone
1989   // is platform-dependent.
1990   //
1991   // Behavior is undefined when synchronizing using OpenCL user events.
1992   // Behavior is undefined if host callbacks call device routines or insert
1993   // them into any stream.
1994   //
1995   // On certain platforms, ThenDoHostCallback is expected to have significant
1996   // negative effects on performance.
1997   Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
1998 
1999   // Runs the given callback after the next call to BlockHostUntilDone on this
2000   // stream (or after the Stream does BlockHostUntilDone in its destructor).
2001   // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
2002   // some use cases.
2003   Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);
2004 
2005   // Returns the StreamExecutor (parent object) associated with this stream.
parent()2006   StreamExecutor *parent() const {
2007     CHECK(parent_ != nullptr);
2008     return parent_;
2009   }
2010 
2011   // Returns the (internal usage) temporary-memory-allocation manager associated
2012   // with this stream.
2013   internal::TemporaryMemoryManager *temporary_memory_manager();
2014 
2015   // Returns a debugging string "[stream=0x...,impl=0x...]".
2016   std::string DebugStreamPointers() const;
2017 
2018  private:
2019   friend class host::HostBlas;  // for parent_.
2020   friend class host::HostFft;   // for parent_.
2021   friend class host::HostRng;   // for parent_.
2022   template <typename... Args>
2023   friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
2024   friend class ocl::CLBlas;    // for parent_.
2025 
InErrorState()2026   bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
2027     absl::ReaderMutexLock lock(&mu_);
2028     return !status_.ok();
2029   }
2030 
2031   // Sets the error state if operation_retcode is false.
2032   // This is a useful shorthand for many stream routines.
CheckError(bool operation_retcode)2033   void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_) {
2034     if (operation_retcode) {
2035       return;
2036     }
2037     absl::MutexLock lock(&mu_);
2038     status_ = port::InternalError("Unknown error");
2039   }
2040 
2041   // Checks the status and logs the error message, if any.
2042   void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_);
2043 
SetError()2044   void SetError() { CheckError(false /* = operation_retcode */); }
2045 
SetErrorAndLogNoDnnSupport()2046   void SetErrorAndLogNoDnnSupport() {
2047     SetError();
2048     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
2049                     "without DNN support";
2050   }
2051 
2052   // Runs the set of callbacks that are intended to run after
2053   // BlockHostUntilDone.
2054   void RunAfterBlockHostUntilDoneCallbacks();
2055 
2056   // The StreamExecutor that supports the operation of this stream.
2057   StreamExecutor *parent_;
2058 
2059   // The platform-dependent implementation that the StreamExecutor interface
2060   // delegates to.
2061   std::unique_ptr<internal::StreamInterface> implementation_;
2062 
2063   // mutex that guards the allocation / error state flags.
2064   // Mutable so that it can be obtained via const reader lock.
2065   mutable absl::Mutex mu_;
2066 
2067   // Whether Init() was successfully called to allocate this stream on the
2068   // underlying platform. It simply flips from 0 to 1 with a sanity check.
2069   // See StreamExecutor::AllocateStream.
2070   bool allocated_ TF_GUARDED_BY(mu_);
2071 
2072   // The last error (if any) of all method calls.
2073   port::Status status_ TF_GUARDED_BY(mu_);
2074 
2075   // Sub-streams that are generated from this stream. Each element has a pointer
2076   // to sub-stream and a boolean value indicating if this substream is ready to
2077   // be reused.
2078   std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
2079       TF_GUARDED_BY(mu_);
2080 
2081   // Streams can allocate temporary memories to help with work they enqueue
2082   // (e.g. for scratch memory spaces). This member tracks those allocations and
2083   // notes when they can be reclaimed -- reclamation is attempted when
2084   // BlockHostUntilDone() is called.
2085   internal::TemporaryMemoryManager temporary_memory_manager_;
2086 
2087   // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
2088   std::vector<std::function<void()>> after_block_host_until_done_callbacks_
2089       TF_GUARDED_BY(mu_);
2090 
2091   // Implementation of ThenConvolveBackwardBias that is shared by all types.
2092   template <typename T>
2093   Stream &ThenConvolveBackwardBiasImpl(
2094       const dnn::BatchDescriptor &input_descriptor,
2095       const DeviceMemory<T> &input_data,
2096       const dnn::BatchDescriptor &bias_descriptor,
2097       DeviceMemory<T> *backward_bias_data);
2098 
2099   // Implementation of ThenBlasLtMatmul that is shared by all types.
2100   template <typename ABType, typename CType>
2101   Stream &ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan *plan,
2102                                const HostOrDeviceScalar<CType> &alpha,
2103                                const DeviceMemory<ABType> &a,
2104                                const DeviceMemory<ABType> &b,
2105                                const HostOrDeviceScalar<CType> &beta,
2106                                DeviceMemory<CType> *c,
2107                                ScratchAllocator *scratch_allocator,
2108                                const blas::IBlasLtMatmulAlgorithm *algorithm,
2109                                const DeviceMemory<CType> &bias,
2110                                blas::ProfileResult *output_profile_result);
2111 
2112   SE_DISALLOW_COPY_AND_ASSIGN(Stream);
2113 };
2114 
2115 ////////////
2116 // Inlines
2117 
2118 template <typename... Params, typename... Args>
ThenLaunch(ThreadDim thread_dims,BlockDim block_dims,const TypedKernel<Params...> & kernel,Args...args)2119 inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
2120                                   const TypedKernel<Params...> &kernel,
2121                                   Args... args) {
2122   KernelInvocationChecker<std::tuple<Params...>,
2123                           std::tuple<Args...>>::CheckAllStaticAssert();
2124   if (ok()) {
2125     // This is the core that allows type-safe kernel launching.
2126     // Since the platforms take kernel arguments as tuples of (void *, size),
2127     // we pack the variadic parameters passed as ...args into the desired
2128     // tuple form and pass that packed form to the StreamExecutor::Launch()
2129     // implementation.
2130     KernelArgsArray<sizeof...(args)> kernel_args;
2131     kernel.PackParams(&kernel_args, args...);
2132     DCHECK(parent_ != nullptr);
2133     bool ok =
2134         parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args)
2135             .ok();
2136     if (!ok) {
2137       SetError();
2138       LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
2139     }
2140   }
2141   return *this;
2142 }
2143 
2144 template <typename T>
2145 inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
AllocateTemporaryArray(uint64 element_count)2146 Stream::AllocateTemporaryArray(uint64 element_count) {
2147   return temporary_memory_manager_.AllocateArray<T>(element_count);
2148 }
2149 
temporary_memory_manager()2150 inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
2151   return &temporary_memory_manager_;
2152 }
2153 
2154 template <>
2155 struct Quantization<uint8> {
2156   static constexpr dnn::QuantizedActivationMode kModeId =
2157       dnn::QuantizedActivationMode::k8Bit;
2158 };
2159 
2160 template <>
2161 struct Quantization<uint16> {
2162   static constexpr dnn::QuantizedActivationMode kModeId =
2163       dnn::QuantizedActivationMode::k16Bit;
2164 };
2165 
2166 template <>
2167 struct Quantization<int32> {
2168   static constexpr dnn::QuantizedActivationMode kModeId =
2169       dnn::QuantizedActivationMode::k32Bit;
2170 };
2171 
2172 }  // namespace stream_executor
2173 
2174 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
2175