• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20 
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
23 
24 #include <complex>
25 #include <functional>
26 #include <memory>
27 
28 #include "absl/synchronization/mutex.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/stream_executor/blas.h"
31 #include "tensorflow/stream_executor/device_memory.h"
32 #include "tensorflow/stream_executor/dnn.h"
33 #include "tensorflow/stream_executor/event.h"
34 #include "tensorflow/stream_executor/fft.h"
35 #include "tensorflow/stream_executor/host_or_device_scalar.h"
36 #include "tensorflow/stream_executor/kernel.h"
37 #include "tensorflow/stream_executor/launch_dim.h"
38 #include "tensorflow/stream_executor/lib/array_slice.h"
39 #include "tensorflow/stream_executor/platform/port.h"
40 #include "tensorflow/stream_executor/platform/thread_annotations.h"
41 #include "tensorflow/stream_executor/temporary_memory_manager.h"
42 
43 namespace stream_executor {
44 
45 namespace host {
46 class HostBlas;
47 class HostFft;
48 class HostRng;
49 class HostTimer;
50 }  // namespace host
51 
52 namespace ocl {
53 class CLBlas;
54 }  // namespace ocl
55 
56 namespace internal {
57 class StreamInterface;
58 }  // namespace internal
59 
60 class DeviceMemoryBase;
61 template <typename ElemT>
62 class DeviceMemory;
63 
64 class Timer;
65 
66 namespace dnn {
67 class BatchDescriptor;
68 class FilterDescriptor;
69 class ConvolutionDescriptor;
70 class ProfileResult;
71 class AlgorithmDesc;
72 }  // namespace dnn
73 
74 class StreamExecutor;
75 class ScratchAllocator;
76 
77 // Convert a type to the corresponding QuantizedActivationMode.
78 template <typename ElementType>
79 struct Quantization;
80 
81 // Represents a stream of dependent computations on a GPU device.
82 //
83 // The operations within a stream execute linearly and asynchronously until
84 // BlockHostUntilDone() is invoked, which synchronously joins host code with
85 // the execution of the stream.
86 //
87 // If any given operation fails when entraining work for the stream, ok() will
88 // indicate that an error has occurred. After initialization, once a stream is
89 // !ok(), it will never be ok().
90 //
91 // Thread-safe post-initialization.
92 class Stream {
93  public:
94   // Instantiate a stream tied to parent as a platform executor. Work
95   // entrained onto this stream will be launched/managed on that
96   // StreamExecutor's platform.
97   explicit Stream(StreamExecutor *parent);
98 
99   // Test only. Use an externally-populated value (like a mock) for the
100   // platform-specific stream implementation.
101   Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
102 
  // Deallocates any stream resources that the parent StreamExecutor has
  // bestowed upon this object.
106   ~Stream();
107 
108   // Returns whether any errors have occurred while entraining work for this
109   // stream.
ok()110   bool ok() const { return !InErrorState(); }
111 
112   // Retrieves execution status back into the stream from the underlying
113   // implementation without blocking the stream.
114   //
115   // Normally, Stream::BlockHostUntilDone is used to get execution status.
  // However, some devices use out-of-band mechanisms to ensure their streams
117   // have finished on-device work, without needing to block the streams. (These
118   // devices should also override AllowsSyncOnCompletion to return false.) For
119   // these devices, this method can be used after work is finished to retrieve
120   // execution status.
121   port::Status RefreshStatus() LOCKS_EXCLUDED(mu_);
122 
123   // Initialize the stream. This must be performed before entraining any other
124   // operations.
125   Stream &Init() LOCKS_EXCLUDED(mu_);
126 
127   // Initializes timer t via the StreamExecutor.
128   Stream &InitTimer(Timer *t);
129 
130   // Convenience wrapper around Init() and InitTimer().
131   Stream &InitWithTimer(Timer *t);
132 
133   // Get or create a sub-stream from this stream. If there is any sub-stream in
134   // the pool that can be reused then just return this sub-stream.  Otherwise
135   // create a new sub-stream.
136   //
137   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
138   Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
139 
140   // Return the sub-stream back to the host stream so that it can be reused
141   // later. Sub-streams that are !ok() will not be reused.
142   //
143   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
144   void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
145 
146   // Allocate temporary memories. The stream will deallocate them when blocked
147   // or destroyed.
148   template <typename T>
149   port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
150   AllocateTemporaryArray(uint64 element_count);
151 
152   // Entrains onto the stream of operations: a kernel launch with the given
153   // (variadic) parameters for the invocation. These arguments can be things
154   // like DeviceMemory or primitive types such as int. What arguments you may
155   // pass to a given kernel are noted as the template parameters to the
156   // TypedKernel type that the machocc compiler generates.
157   //
158   // Template parameters:
159   //  Params...   The type list of formal parameters that the typed kernel
160   //              expects, which is matched against Args...
161   //  Args...     The deduced type list for passed actual arguments
162   //
163   // Implementation: A compile-time compatibility check is performed that has
164   // some leniency versus an exact parameter pack match -- for example,
165   // `const DeviceMemory<T>` is considered "pack compatible" with a
166   // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
167   // perfect forwarding support without rvalue references. It also attempts to
168   // spit out helpful static_assert error traces with information as to the
169   // argument number and types that were mismatched.
170   template <typename... Params, typename... Args>
171   Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
172                      const TypedKernel<Params...> &kernel, Args... args);
173 
174   // Record a "start" event for the interval timer at this point in the
175   // stream's execution (relative to the previously and subsequently enqueued
176   // items in the stream's execution). Streams may be started/stopped multiple
177   // times.
178   Stream &ThenStartTimer(Timer *t);
179 
180   // Record a "stop" event for the interval timer at this point in the
181   // stream's execution. See also Stream::ThenStartTimer.
182   Stream &ThenStopTimer(Timer *t);
183 
184   // TODO(leary) If work is added to the stream that is being depended upon,
185   //              then what? Have to describe what happens.
186   template <typename... Params>
ThenWaitFor(Stream * other,Params...more_streams)187   Stream &ThenWaitFor(Stream *other, Params... more_streams) {
188     return ThenWaitFor(more_streams...).ThenWaitFor(other);
189   }
190 
191   // Create a dependency for this stream's next work on the other stream
192   // completing. Does not take ownership of other, and other must not be
193   // null.
194   //
195   // Checks that a stream does not wait for itself, and it is up to the
196   // user to guarantee that a stream does not come to wait on itself in a
197   // cyclic manner; in that case, behavior is undefined.
198   //
199   // N.B. Base recursion case for the variadic ThenWaitFor.
200   Stream &ThenWaitFor(Stream *other);
201 
202   // Waits for all streams values in others.
203   // Checks that there is no shallow circular wait (i.e. that "this" is not in
204   // others)
205   template <typename P>
ThenWaitFor(P others)206   Stream &ThenWaitFor(P others) {
207     for (auto &stream : *others) {
208       CHECK_NE(stream.get(), this);
209       ThenWaitFor(stream.get());
210     }
211     return *this;
212   }
213 
214   // Waits for an event object to be set.
215   // Note that ThenRecordEvent must have been called on the event before
216   // you call this function; otherwise the event will be considered complete
217   // and this wait will do nothing.
218   Stream &ThenWaitFor(Event *event);
219 
220   // Inserts the specified event into the end of this stream. Once the stream
221   // has processed all events prior to the insertion point, the event will be
222   // marked as completed.
223   // The stream does not take ownership of event - meaning that event's lifetime
224   // must extend past the point at which it is marked complete!
225   Stream &ThenRecordEvent(Event *event);
226 
227   ////////////////
228   // DNN support
229   //
230   // See DnnSupport::* for comments on the following methods.
231 
232   Stream &ThenBatchNormalizationForward(
233       const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
234       const DeviceMemory<float> &offset,
235       const DeviceMemory<float> &estimated_mean,
236       const DeviceMemory<float> &estimated_variance,
237       const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
238       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
239       const double exponential_average_factor,
240       dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
241       DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
242       DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
243       bool is_training,
244       std::function<const DeviceMemory<float> &()> var_to_inv_var,
245       std::function<void()> inv_var_to_var,
246       ScratchAllocator *reserve_space_allocator,
247       ScratchAllocator *workspace_allocator);
248 
249   Stream &ThenBatchNormalizationBackward(
250       const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
251       const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
252       const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
253       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
254       DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
255       DeviceMemory<float> *offset_backprop,
256       DeviceMemory<uint8> *reserve_space_data,
257       ScratchAllocator *workspace_allocator);
258 
259   Stream &ThenBatchNormalizationForward(
260       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
261       const DeviceMemory<float> &offset,
262       const DeviceMemory<float> &estimated_mean,
263       const DeviceMemory<float> &estimated_variance,
264       const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
265       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
266       const double exponential_average_factor,
267       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
268       DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
269       DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
270       bool is_training,
271       std::function<const DeviceMemory<float> &()> var_to_inv_var,
272       std::function<void()> inv_var_to_var,
273       ScratchAllocator *reserve_space_allocator,
274       ScratchAllocator *workspace_allocator);
275 
276   Stream &ThenBatchNormalizationBackward(
277       const DeviceMemory<Eigen::half> &y_backprop,
278       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
279       const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
280       const dnn::BatchDescriptor &x_desc,
281       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
282       DeviceMemory<Eigen::half> *x_backprop,
283       DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
284       DeviceMemory<uint8> *reserve_space_data,
285       ScratchAllocator *workspace_allocator);
286 
287   Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
288                        const DeviceMemory<float> &input_data,
289                        const dnn::FilterDescriptor &filter_descriptor,
290                        const DeviceMemory<float> &filter_data,
291                        const dnn::ConvolutionDescriptor &convolution_descriptor,
292                        const dnn::BatchDescriptor &output_descriptor,
293                        DeviceMemory<float> *output);
294 
295   Stream &ThenConvolveQuantized(
296       const dnn::BatchDescriptor &input_descriptor,
297       const DeviceMemory<float> &input_data,
298       const dnn::FilterDescriptor &filter_descriptor,
299       const DeviceMemory<int8> &filter_coefficients,
300       const DeviceMemory<float> &coefficient_scales,
301       const dnn::ConvolutionDescriptor &convolution_descriptor,
302       const dnn::BatchDescriptor &output_descriptor,
303       DeviceMemory<float> *output_data);
304 
305   Stream &ThenConvolveQuantized(
306       const dnn::BatchDescriptor &input_descriptor,
307       const DeviceMemory<float> &input_data,
308       const dnn::FilterDescriptor &filter_descriptor,
309       const DeviceMemory<int16> &filter_coefficients,
310       const DeviceMemory<float> &coefficient_scales,
311       const dnn::ConvolutionDescriptor &convolution_descriptor,
312       const dnn::BatchDescriptor &output_descriptor,
313       DeviceMemory<float> *output_data);
314 
315   Stream &ThenConvolveWithAlgorithm(
316       const dnn::BatchDescriptor &input_descriptor,
317       const DeviceMemory<double> &input_data,
318       const dnn::FilterDescriptor &filter_descriptor,
319       const DeviceMemory<double> &filter_data,
320       const dnn::ConvolutionDescriptor &convolution_descriptor,
321       const dnn::BatchDescriptor &output_descriptor,
322       DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
323       const dnn::AlgorithmConfig &algorithm_config,
324       dnn::ProfileResult *output_profile_result);
325 
326   Stream &ThenConvolveWithAlgorithm(
327       const dnn::BatchDescriptor &input_descriptor,
328       const DeviceMemory<float> &input_data,
329       const dnn::FilterDescriptor &filter_descriptor,
330       const DeviceMemory<float> &filter_data,
331       const dnn::ConvolutionDescriptor &convolution_descriptor,
332       const dnn::BatchDescriptor &output_descriptor,
333       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
334       const dnn::AlgorithmConfig &algorithm_config,
335       dnn::ProfileResult *output_profile_result);
336 
337   Stream &ThenConvolveWithAlgorithm(
338       const dnn::BatchDescriptor &input_descriptor,
339       const DeviceMemory<Eigen::half> &input_data,
340       const dnn::FilterDescriptor &filter_descriptor,
341       const DeviceMemory<Eigen::half> &filter_data,
342       const dnn::ConvolutionDescriptor &convolution_descriptor,
343       const dnn::BatchDescriptor &output_descriptor,
344       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
345       const dnn::AlgorithmConfig &algorithm_config,
346       dnn::ProfileResult *output_profile_result);
347 
348   Stream &ThenConvolveWithAlgorithm(
349       const dnn::BatchDescriptor &input_descriptor,
350       const DeviceMemory<int8> &input_data,
351       const dnn::FilterDescriptor &filter_descriptor,
352       const DeviceMemory<int8> &filter_data,
353       const dnn::ConvolutionDescriptor &convolution_descriptor,
354       const dnn::BatchDescriptor &output_descriptor,
355       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
356       const dnn::AlgorithmConfig &algorithm_config,
357       dnn::ProfileResult *output_profile_result);
358 
359   Stream &ThenConvolveWithAlgorithm(
360       const dnn::BatchDescriptor &input_descriptor,
361       const DeviceMemory<int8> &input_data,
362       const dnn::FilterDescriptor &filter_descriptor,
363       const DeviceMemory<int8> &filter_data,
364       const dnn::ConvolutionDescriptor &convolution_descriptor,
365       const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
366       ScratchAllocator *scratch_allocator,
367       const dnn::AlgorithmConfig &algorithm_config,
368       dnn::ProfileResult *output_profile_result);
369 
370   Stream &ThenFusedConvolveWithAlgorithm(
371       const dnn::BatchDescriptor &conv_input_descriptor,
372       const DeviceMemory<double> &conv_input_data, double conv_input_scale,
373       const dnn::FilterDescriptor &filter_descriptor,
374       const DeviceMemory<double> &filter_data,
375       const dnn::ConvolutionDescriptor &convolution_descriptor,
376       const DeviceMemory<double> &side_input_data, double side_input_scale,
377       const dnn::BatchDescriptor &bias_descriptor,
378       const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
379       const dnn::BatchDescriptor &output_descriptor,
380       DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
381       const dnn::AlgorithmConfig &algorithm_config,
382       dnn::ProfileResult *output_profile_result);
383 
384   Stream &ThenFusedConvolveWithAlgorithm(
385       const dnn::BatchDescriptor &conv_input_descriptor,
386       const DeviceMemory<float> &conv_input_data, float conv_input_scale,
387       const dnn::FilterDescriptor &filter_descriptor,
388       const DeviceMemory<float> &filter_data,
389       const dnn::ConvolutionDescriptor &convolution_descriptor,
390       const DeviceMemory<float> &side_input_data, float side_input_scale,
391       const dnn::BatchDescriptor &bias_descriptor,
392       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
393       const dnn::BatchDescriptor &output_descriptor,
394       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
395       const dnn::AlgorithmConfig &algorithm_config,
396       dnn::ProfileResult *output_profile_result);
397 
398   Stream &ThenFusedConvolveWithAlgorithm(
399       const dnn::BatchDescriptor &conv_input_descriptor,
400       const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
401       const dnn::FilterDescriptor &filter_descriptor,
402       const DeviceMemory<Eigen::half> &filter_data,
403       const dnn::ConvolutionDescriptor &convolution_descriptor,
404       const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
405       const dnn::BatchDescriptor &bias_descriptor,
406       const DeviceMemory<Eigen::half> &biases,
407       dnn::ActivationMode activation_mode,
408       const dnn::BatchDescriptor &output_descriptor,
409       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
410       const dnn::AlgorithmConfig &algorithm_config,
411       dnn::ProfileResult *output_profile_result);
412 
413   Stream &ThenFusedConvolveWithAlgorithm(
414       const dnn::BatchDescriptor &conv_input_descriptor,
415       const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
416       const dnn::FilterDescriptor &filter_descriptor,
417       const DeviceMemory<int8> &filter_data,
418       const dnn::ConvolutionDescriptor &convolution_descriptor,
419       const DeviceMemory<int8> &side_input_data, float side_input_scale,
420       const dnn::BatchDescriptor &bias_descriptor,
421       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
422       const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
423       ScratchAllocator *scratch_allocator,
424       const dnn::AlgorithmConfig &algorithm_config,
425       dnn::ProfileResult *output_profile_result);
426 
427   Stream &ThenFusedConvolveWithAlgorithm(
428       const dnn::BatchDescriptor &conv_input_descriptor,
429       const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
430       const dnn::FilterDescriptor &filter_descriptor,
431       const DeviceMemory<int8> &filter_data,
432       const dnn::ConvolutionDescriptor &convolution_descriptor,
433       const DeviceMemory<float> &side_input_data, float side_input_scale,
434       const dnn::BatchDescriptor &bias_descriptor,
435       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
436       const dnn::BatchDescriptor &output_descriptor,
437       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
438       const dnn::AlgorithmConfig &algorithm_config,
439       dnn::ProfileResult *output_profile_result);
440 
441   Stream &ThenSeparableConvolve(
442       const dnn::BatchDescriptor &input_descriptor,
443       const DeviceMemory<float> &input_data,
444       const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
445       const DeviceMemory<float> &first_weights,
446       const DeviceMemory<float> &second_weights,
447       const dnn::ConvolutionDescriptor &convolution_descriptor,
448       const dnn::BatchDescriptor &output_descriptor,
449       DeviceMemory<float> *output);
450 
451   Stream &ThenConvolveBackwardDataWithAlgorithm(
452       const dnn::FilterDescriptor &filter_descriptor,
453       const DeviceMemory<double> &filter_data,
454       const dnn::BatchDescriptor &output_descriptor,
455       DeviceMemory<double> backward_output_data,
456       const dnn::ConvolutionDescriptor &convolution_descriptor,
457       const dnn::BatchDescriptor &input_descriptor,
458       DeviceMemory<double> *backward_input_data,
459       ScratchAllocator *scratch_allocator,
460       const dnn::AlgorithmConfig &algorithm_config,
461       dnn::ProfileResult *output_profile_result);
462 
463   Stream &ThenConvolveBackwardDataWithAlgorithm(
464       const dnn::FilterDescriptor &filter_descriptor,
465       const DeviceMemory<float> &filter_data,
466       const dnn::BatchDescriptor &output_descriptor,
467       DeviceMemory<float> backward_output_data,
468       const dnn::ConvolutionDescriptor &convolution_descriptor,
469       const dnn::BatchDescriptor &input_descriptor,
470       DeviceMemory<float> *backward_input_data,
471       ScratchAllocator *scratch_allocator,
472       const dnn::AlgorithmConfig &algorithm_config,
473       dnn::ProfileResult *output_profile_result);
474 
475   Stream &ThenConvolveBackwardDataWithAlgorithm(
476       const dnn::FilterDescriptor &filter_descriptor,
477       const DeviceMemory<Eigen::half> &filter_data,
478       const dnn::BatchDescriptor &output_descriptor,
479       DeviceMemory<Eigen::half> backward_output_data,
480       const dnn::ConvolutionDescriptor &convolution_descriptor,
481       const dnn::BatchDescriptor &input_descriptor,
482       DeviceMemory<Eigen::half> *backward_input_data,
483       ScratchAllocator *scratch_allocator,
484       const dnn::AlgorithmConfig &algorithm_config,
485       dnn::ProfileResult *output_profile_result);
486 
487   Stream &ThenConvolveBackwardFilterWithAlgorithm(
488       const dnn::BatchDescriptor &input_descriptor,
489       const DeviceMemory<double> &input_data,
490       const dnn::BatchDescriptor &output_descriptor,
491       DeviceMemory<double> backward_output_data,
492       const dnn::ConvolutionDescriptor &convolution_descriptor,
493       const dnn::FilterDescriptor &filter_descriptor,
494       DeviceMemory<double> *backward_filter_data,
495       ScratchAllocator *scratch_allocator,
496       const dnn::AlgorithmConfig &algorithm_config,
497       dnn::ProfileResult *output_profile_result);
498 
499   Stream &ThenConvolveBackwardFilterWithAlgorithm(
500       const dnn::BatchDescriptor &input_descriptor,
501       const DeviceMemory<float> &input_data,
502       const dnn::BatchDescriptor &output_descriptor,
503       DeviceMemory<float> backward_output_data,
504       const dnn::ConvolutionDescriptor &convolution_descriptor,
505       const dnn::FilterDescriptor &filter_descriptor,
506       DeviceMemory<float> *backward_filter_data,
507       ScratchAllocator *scratch_allocator,
508       const dnn::AlgorithmConfig &algorithm_config,
509       dnn::ProfileResult *output_profile_result);
510 
511   Stream &ThenConvolveBackwardFilterWithAlgorithm(
512       const dnn::BatchDescriptor &input_descriptor,
513       const DeviceMemory<Eigen::half> &input_data,
514       const dnn::BatchDescriptor &output_descriptor,
515       DeviceMemory<Eigen::half> backward_output_data,
516       const dnn::ConvolutionDescriptor &convolution_descriptor,
517       const dnn::FilterDescriptor &filter_descriptor,
518       DeviceMemory<Eigen::half> *backward_filter_data,
519       ScratchAllocator *scratch_allocator,
520       const dnn::AlgorithmConfig &algorithm_config,
521       dnn::ProfileResult *output_profile_result);
522 
523   Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
524                                    const DeviceMemory<double> &input_data,
525                                    const dnn::BatchDescriptor &bias_descriptor,
526                                    DeviceMemory<double> *backward_bias_data);
527 
528   Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
529                                    const DeviceMemory<float> &input_data,
530                                    const dnn::BatchDescriptor &bias_descriptor,
531                                    DeviceMemory<float> *backward_bias_data);
532 
533   Stream &ThenConvolveBackwardBias(
534       const dnn::BatchDescriptor &input_descriptor,
535       const DeviceMemory<Eigen::half> &input_data,
536       const dnn::BatchDescriptor &bias_descriptor,
537       DeviceMemory<Eigen::half> *backward_bias_data);
538 
539   Stream &ThenMatMul(const DeviceMemory<float> &input_data,
540                      const DeviceMemory<float> &weights,
541                      const dnn::BatchDescriptor &input_dimensions,
542                      const dnn::BatchDescriptor &output_dimensions,
543                      DeviceMemory<float> *output_data);
544 
545   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
546                               const DeviceMemory<int8> &weights,
547                               const DeviceMemory<float> &weight_scales,
548                               const dnn::BatchDescriptor &input_dimensions,
549                               const dnn::BatchDescriptor &output_dimensions,
550                               DeviceMemory<float> *output_data);
551 
552   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
553                               const DeviceMemory<int16> &weights,
554                               const DeviceMemory<float> &weight_scales,
555                               const dnn::BatchDescriptor &input_dimensions,
556                               const dnn::BatchDescriptor &output_dimensions,
557                               DeviceMemory<float> *output_data);
558 
559   Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
560                       const DeviceMemory<float> &biases,
561                       const dnn::BatchDescriptor &dimensions,
562                       DeviceMemory<float> *output_data);
563 
564   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
565                           const dnn::BatchDescriptor &input_dimensions,
566                           const DeviceMemory<double> &input_data,
567                           const dnn::BatchDescriptor &output_dimensions,
568                           DeviceMemory<double> *output_data,
569                           ScratchAllocator *workspace_allocator = nullptr);
570 
571   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
572                           const dnn::BatchDescriptor &input_dimensions,
573                           const DeviceMemory<float> &input_data,
574                           const dnn::BatchDescriptor &output_dimensions,
575                           DeviceMemory<float> *output_data,
576                           ScratchAllocator *workspace_allocator = nullptr);
577 
578   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
579                           const dnn::BatchDescriptor &input_dimensions,
580                           const DeviceMemory<Eigen::half> &input_data,
581                           const dnn::BatchDescriptor &output_dimensions,
582                           DeviceMemory<Eigen::half> *output_data,
583                           ScratchAllocator *workspace_allocator = nullptr);
584 
585   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
586                           const dnn::BatchDescriptor &input_dimensions,
587                           const DeviceMemory<int8> &input_data,
588                           const dnn::BatchDescriptor &output_dimensions,
589                           DeviceMemory<int8> *output_data,
590                           ScratchAllocator *workspace_allocator = nullptr);
591 
592   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
593                            const dnn::BatchDescriptor &input_dimensions,
594                            const DeviceMemory<double> &input_data,
595                            const dnn::BatchDescriptor &output_dimensions,
596                            const DeviceMemory<double> &output_data,
597                            const DeviceMemory<double> &input_diff_data,
598                            DeviceMemory<double> *output_diff_data,
599                            ScratchAllocator *workspace_allocator = nullptr);
600 
601   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
602                            const dnn::BatchDescriptor &input_dimensions,
603                            const DeviceMemory<float> &input_data,
604                            const dnn::BatchDescriptor &output_dimensions,
605                            const DeviceMemory<float> &output_data,
606                            const DeviceMemory<float> &input_diff_data,
607                            DeviceMemory<float> *output_diff_data,
608                            ScratchAllocator *workspace_allocator = nullptr);
609 
610   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
611                            const dnn::BatchDescriptor &input_dimensions,
612                            const DeviceMemory<Eigen::half> &input_data,
613                            const dnn::BatchDescriptor &output_dimensions,
614                            const DeviceMemory<Eigen::half> &output_data,
615                            const DeviceMemory<Eigen::half> &input_diff_data,
616                            DeviceMemory<Eigen::half> *output_diff_data,
617                            ScratchAllocator *workspace_allocator = nullptr);
618 
619   Stream &ThenNormalizeWithDimensions(
620       const dnn::NormalizeDescriptor &normalize_descriptor,
621       const dnn::BatchDescriptor &dimensions,
622       const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
623 
624   Stream &ThenNormalizeBackwardWithDimensions(
625       const dnn::NormalizeDescriptor &normalize_descriptor,
626       const dnn::BatchDescriptor &dimensions,
627       const DeviceMemory<float> &raw_data,
628       const DeviceMemory<float> &normalized_data,
629       const DeviceMemory<float> &normalized_variable_gradient,
630       DeviceMemory<float> *raw_variable_gradient,
631       ScratchAllocator *workspace_allocator = nullptr);
632 
633   Stream &ThenActivate(dnn::ActivationMode activation_mode,
634                        const dnn::BatchDescriptor &dimensions,
635                        const DeviceMemory<float> &input_data,
636                        DeviceMemory<float> *output_data);
637 
638   // Same as ThenActivate, but also takes an options argument that can be used
639   // for platform-specific option flags.
640   Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
641                                   const dnn::BatchDescriptor &dimensions,
642                                   const DeviceMemory<float> &input_data,
643                                   DeviceMemory<float> *output_data,
644                                   uint64 options);
645 
646   Stream &ThenDepthConcatenate(
647       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
648       port::ArraySlice<const DeviceMemory<float> *> input_data,
649       DeviceMemory<float> *output_data);
650 
651   Stream &ThenSpaceConcatenate(
652       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
653       port::ArraySlice<const DeviceMemory<float> *> input_data,
654       DeviceMemory<float> *output_data,
655       dnn::SpaceConcatenateMode concat_direction);
656 
657   // Change the layout of the data by shrinking one dimension (or set of
658   // dimensions) and growing another dimension (or set of dimensions), while
659   // keeping the total number of data elements constant, and maintaining the
660   // current data ordering.
661   Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
662                       const DeviceMemory<float> &input_data,
663                       const dnn::BatchDescriptor &output_dimensions,
664                       DeviceMemory<float> *output_data);
665 
666   // Depth to space takes an X by Y image with depth D*M² and changes it to an
667   // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
668   // the input image is changed to an MxM contiguous area in the output image,
669   // with the values being laid out in raster order specified by
670   // DepthToSpaceLayout, and will have a new depth of D.
671   // See the DoDepthToSpace comment for more information.
672   Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
673                            const DeviceMemory<float> &input_data,
674                            const dnn::DepthToSpaceLayout &depth_to_space_layout,
675                            const int sqrt_depth_reduction,
676                            DeviceMemory<float> *output_data);
677 
678   // Space to depth is the inverse of depth to space. Space to depth takes each
679   // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
680   // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
681   // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
682   // data elements is not changed.
683   Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
684                            const DeviceMemory<float> &input_data,
685                            const dnn::DepthToSpaceLayout &space_to_depth_layout,
686                            const int sqrt_depth_increase,
687                            DeviceMemory<float> *output_data);
688 
689   Stream &ThenElementwiseOperate(
690       dnn::ElementwiseOperation operation,
691       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
692       port::ArraySlice<const DeviceMemory<float> *> input_data,
693       const dnn::BatchDescriptor &output_dimensions,
694       DeviceMemory<float> *output_data);
695 
696   Stream &ThenElementwiseOperateScaledQuantized(
697       dnn::ElementwiseOperation operation,
698       port::ArraySlice<int> input_multiplicands, int output_divisor,
699       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
700       port::ArraySlice<const DeviceMemory<float> *> input_data,
701       const dnn::BatchDescriptor &output_dimensions,
702       DeviceMemory<float> *output_data);
703 
704   Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
705                     const DeviceMemory<float> &input_data, int64 left_pad,
706                     int64 right_pad, int64 top_pad, int64 bottom_pad,
707                     DeviceMemory<float> *output_data);
708 
709   Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
710                       const DeviceMemory<float> &input_data, int64 left_trim,
711                       int64 right_trim, int64 top_trim, int64 bottom_trim,
712                       DeviceMemory<float> *output_data);
713 
714   // Grows the input tensor by replicating the X and Y dimensions. The batch and
715   // depth/feature_map dimensions are unchanged. Currently, the input tensor is
716   // limited to X=1 and Y=1.
717   Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
718                           const DeviceMemory<float> &input_data,
719                           int64 replicate_x, int64 replicate_y,
720                           DeviceMemory<float> *output_data);
721 
722   // See DnnSupport::DoMemcpyD2HQuantized.
723   Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
724                                  dnn::QuantizedActivationMode mode,
725                                  void *host_dst, uint64 size);
726 
727   // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
728   // and uses the Quantization trait to call the generic version of
729   // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
730   template <typename ElementType>
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,port::MutableArraySlice<ElementType> host_dst)731   Stream &ThenMemcpyD2HQuantized(
732       const DeviceMemory<float> &gpu_unquantized_src,
733       port::MutableArraySlice<ElementType> host_dst) {
734     return ThenMemcpyD2HQuantized(
735         gpu_unquantized_src, Quantization<ElementType>::kModeId,
736         host_dst.data(), host_dst.size() * sizeof(ElementType));
737   }
738 
739   // See DnnSupport::DoMemcpyH2DQuantized.
740   Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
741                                  dnn::QuantizedActivationMode mode,
742                                  DeviceMemory<float> *gpu_unquantized_dst);
743 
744   // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
745   // and uses the Quantization trait to call the generic version of
746   // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
747   template <typename ElementType>
ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,DeviceMemory<float> * gpu_unquantized_dst)748   Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
749                                  DeviceMemory<float> *gpu_unquantized_dst) {
750     return ThenMemcpyH2DQuantized(
751         host_src.data(), host_src.size() * sizeof(ElementType),
752         Quantization<ElementType>::kModeId, gpu_unquantized_dst);
753   }
754 
755   // See DnnSupport::DoCopyHostBuffer2Device.
756   Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
757                                     DeviceMemory<float> *gpu_unquantized_dst);
758 
759   // See DnnSupport::DoCopyDevice2HostBuffer.
760   Stream &ThenCopyDevice2HostBuffer(
761       const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
762 
763   /////////////////
764   // BLAS support
765 
766   // See BlasSupport::DoBlasAsum.
767   Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
768                        int incx, DeviceMemory<float> *result);
769   Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
770                        int incx, DeviceMemory<double> *result);
771   Stream &ThenBlasAsum(uint64 elem_count,
772                        const DeviceMemory<std::complex<float>> &x, int incx,
773                        DeviceMemory<float> *result);
774   Stream &ThenBlasAsum(uint64 elem_count,
775                        const DeviceMemory<std::complex<double>> &x, int incx,
776                        DeviceMemory<double> *result);
777 
778   // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
779   // present in DeviceMemory, it must be an execution-time constant (i.e. a
780   // value
781   // that the stream does not change or populate during the course of
782   // execution). The value is effectively captured at stream-enqueue time.
783   Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
784                        const DeviceMemory<float> &x, int incx,
785                        DeviceMemory<float> *y, int incy);
786   Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
787                        const DeviceMemory<double> &x, int incx,
788                        DeviceMemory<double> *y, int incy);
789   Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
790                        const DeviceMemory<std::complex<float>> &x, int incx,
791                        DeviceMemory<std::complex<float>> *y, int incy);
792   Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
793                        const DeviceMemory<std::complex<double>> &x, int incx,
794                        DeviceMemory<std::complex<double>> *y, int incy);
795 
796   // See BlasSupport::DoBlasCopy.
797   Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
798                        int incx, DeviceMemory<float> *y, int incy);
799   Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
800                        int incx, DeviceMemory<double> *y, int incy);
801   Stream &ThenBlasCopy(uint64 elem_count,
802                        const DeviceMemory<std::complex<float>> &x, int incx,
803                        DeviceMemory<std::complex<float>> *y, int incy);
804   Stream &ThenBlasCopy(uint64 elem_count,
805                        const DeviceMemory<std::complex<double>> &x, int incx,
806                        DeviceMemory<std::complex<double>> *y, int incy);
807 
808   // See BlasSupport::DoBlasDot.
809   Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
810                       const DeviceMemory<float> &y, int incy,
811                       DeviceMemory<float> *result);
812   Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
813                       int incx, const DeviceMemory<double> &y, int incy,
814                       DeviceMemory<double> *result);
815 
816   // See BlasSupport::DoBlasDotc.
817   Stream &ThenBlasDotc(uint64 elem_count,
818                        const DeviceMemory<std::complex<float>> &x, int incx,
819                        const DeviceMemory<std::complex<float>> &y, int incy,
820                        DeviceMemory<std::complex<float>> *result);
821   Stream &ThenBlasDotc(uint64 elem_count,
822                        const DeviceMemory<std::complex<double>> &x, int incx,
823                        const DeviceMemory<std::complex<double>> &y, int incy,
824                        DeviceMemory<std::complex<double>> *result);
825 
826   // See BlasSupport::DoBlasDotu.
827   Stream &ThenBlasDotu(uint64 elem_count,
828                        const DeviceMemory<std::complex<float>> &x, int incx,
829                        const DeviceMemory<std::complex<float>> &y, int incy,
830                        DeviceMemory<std::complex<float>> *result);
831   Stream &ThenBlasDotu(uint64 elem_count,
832                        const DeviceMemory<std::complex<double>> &x, int incx,
833                        const DeviceMemory<std::complex<double>> &y, int incy,
834                        DeviceMemory<std::complex<double>> *result);
835 
836   // See BlasSupport::DoBlasNrm2.
837   Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
838                        int incx, DeviceMemory<float> *result);
839   Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
840                        int incx, DeviceMemory<double> *result);
841   Stream &ThenBlasNrm2(uint64 elem_count,
842                        const DeviceMemory<std::complex<float>> &x, int incx,
843                        DeviceMemory<float> *result);
844   Stream &ThenBlasNrm2(uint64 elem_count,
845                        const DeviceMemory<std::complex<double>> &x, int incx,
846                        DeviceMemory<double> *result);
847 
848   // See BlasSupport::DoBlasRot.
849   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
850                       DeviceMemory<float> *y, int incy, float c, float s);
851   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
852                       DeviceMemory<double> *y, int incy, double c, double s);
853   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
854                       int incx, DeviceMemory<std::complex<float>> *y, int incy,
855                       float c, float s);
856   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
857                       int incx, DeviceMemory<std::complex<double>> *y, int incy,
858                       double c, double s);
859 
860   // See BlasSupport::DoBlasRotg.
861   Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
862                        DeviceMemory<float> *c, DeviceMemory<float> *s);
863   Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
864                        DeviceMemory<double> *c, DeviceMemory<double> *s);
865   Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
866                        DeviceMemory<std::complex<float>> *b,
867                        DeviceMemory<float> *c,
868                        DeviceMemory<std::complex<float>> *s);
869   Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
870                        DeviceMemory<std::complex<double>> *b,
871                        DeviceMemory<double> *c,
872                        DeviceMemory<std::complex<double>> *s);
873 
874   // See BlasSupport::DoBlasRotm.
875   Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
876                        DeviceMemory<float> *y, int incy,
877                        const DeviceMemory<float> &param);
878   Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
879                        DeviceMemory<double> *y, int incy,
880                        const DeviceMemory<double> &param);
881 
882   // See BlasSupport::DoBlasRotmg.
883   Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
884                         DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
885                         DeviceMemory<float> *param);
886   Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
887                         DeviceMemory<double> *x1,
888                         const DeviceMemory<double> &y1,
889                         DeviceMemory<double> *param);
890 
891   // See BlasSupport::DoBlasScal.
892   Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
893                        int incx);
894   Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
895                        int incx);
896   Stream &ThenBlasScal(uint64 elem_count, float alpha,
897                        DeviceMemory<std::complex<float>> *x, int incx);
898   Stream &ThenBlasScal(uint64 elem_count, double alpha,
899                        DeviceMemory<std::complex<double>> *x, int incx);
900   Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
901                        DeviceMemory<std::complex<float>> *x, int incx);
902   Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
903                        DeviceMemory<std::complex<double>> *x, int incx);
904 
905   // See BlasSupport::DoBlasSwap.
906   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
907                        DeviceMemory<float> *y, int incy);
908   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
909                        DeviceMemory<double> *y, int incy);
910   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
911                        int incx, DeviceMemory<std::complex<float>> *y,
912                        int incy);
913   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
914                        int incx, DeviceMemory<std::complex<double>> *y,
915                        int incy);
916 
917   // See BlasSupport::DoBlasIamax.
918   Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
919                         int incx, DeviceMemory<int> *result);
920   Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
921                         int incx, DeviceMemory<int> *result);
922   Stream &ThenBlasIamax(uint64 elem_count,
923                         const DeviceMemory<std::complex<float>> &x, int incx,
924                         DeviceMemory<int> *result);
925   Stream &ThenBlasIamax(uint64 elem_count,
926                         const DeviceMemory<std::complex<double>> &x, int incx,
927                         DeviceMemory<int> *result);
928 
929   // See BlasSupport::DoBlasIamin.
930   Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
931                         int incx, DeviceMemory<int> *result);
932   Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
933                         int incx, DeviceMemory<int> *result);
934   Stream &ThenBlasIamin(uint64 elem_count,
935                         const DeviceMemory<std::complex<float>> &x, int incx,
936                         DeviceMemory<int> *result);
937   Stream &ThenBlasIamin(uint64 elem_count,
938                         const DeviceMemory<std::complex<double>> &x, int incx,
939                         DeviceMemory<int> *result);
940 
941   // See BlasSupport::DoBlasGbmv.
942   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
943                        uint64 ku, float alpha, const DeviceMemory<float> &a,
944                        int lda, const DeviceMemory<float> &x, int incx,
945                        float beta, DeviceMemory<float> *y, int incy);
946   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
947                        uint64 ku, double alpha, const DeviceMemory<double> &a,
948                        int lda, const DeviceMemory<double> &x, int incx,
949                        double beta, DeviceMemory<double> *y, int incy);
950   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
951                        uint64 ku, std::complex<float> alpha,
952                        const DeviceMemory<std::complex<float>> &a, int lda,
953                        const DeviceMemory<std::complex<float>> &x, int incx,
954                        std::complex<float> beta,
955                        DeviceMemory<std::complex<float>> *y, int incy);
956   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
957                        uint64 ku, std::complex<double> alpha,
958                        const DeviceMemory<std::complex<double>> &a, int lda,
959                        const DeviceMemory<std::complex<double>> &x, int incx,
960                        std::complex<double> beta,
961                        DeviceMemory<std::complex<double>> *y, int incy);
962 
963   // See BlasSupport::DoBlasGemv.
964   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
965                        const DeviceMemory<float> &a, int lda,
966                        const DeviceMemory<float> &x, int incx, float beta,
967                        DeviceMemory<float> *y, int incy);
968   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
969                        const DeviceMemory<double> &a, int lda,
970                        const DeviceMemory<double> &x, int incx, double beta,
971                        DeviceMemory<double> *y, int incy);
972   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
973                        std::complex<float> alpha,
974                        const DeviceMemory<std::complex<float>> &a, int lda,
975                        const DeviceMemory<std::complex<float>> &x, int incx,
976                        std::complex<float> beta,
977                        DeviceMemory<std::complex<float>> *y, int incy);
978   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
979                        std::complex<double> alpha,
980                        const DeviceMemory<std::complex<double>> &a, int lda,
981                        const DeviceMemory<std::complex<double>> &x, int incx,
982                        std::complex<double> beta,
983                        DeviceMemory<std::complex<double>> *y, int incy);
984 
985   Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
986                                     float alpha, const DeviceMemory<float> &a,
987                                     int lda, const DeviceMemory<float> &x,
988                                     int incx, float beta,
989                                     DeviceMemory<float> *y, int incy,
990                                     blas::ProfileResult *output_profile_result);
991   Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
992                                     double alpha, const DeviceMemory<double> &a,
993                                     int lda, const DeviceMemory<double> &x,
994                                     int incx, double beta,
995                                     DeviceMemory<double> *y, int incy,
996                                     blas::ProfileResult *output_profile_result);
997   Stream &ThenBlasGemvWithProfiling(
998       blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
999       const DeviceMemory<std::complex<float>> &a, int lda,
1000       const DeviceMemory<std::complex<float>> &x, int incx,
1001       std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1002       blas::ProfileResult *output_profile_result);
1003   Stream &ThenBlasGemvWithProfiling(
1004       blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
1005       const DeviceMemory<std::complex<double>> &a, int lda,
1006       const DeviceMemory<std::complex<double>> &x, int incx,
1007       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
1008       int incy, blas::ProfileResult *output_profile_result);
1009 
1010   // See BlasSupport::DoBlasGer.
1011   Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
1012                       const DeviceMemory<float> &x, int incx,
1013                       const DeviceMemory<float> &y, int incy,
1014                       DeviceMemory<float> *a, int lda);
1015   Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
1016                       const DeviceMemory<double> &x, int incx,
1017                       const DeviceMemory<double> &y, int incy,
1018                       DeviceMemory<double> *a, int lda);
1019 
1020   // See BlasSupport::DoBlasGerc.
1021   Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
1022                        const DeviceMemory<std::complex<float>> &x, int incx,
1023                        const DeviceMemory<std::complex<float>> &y, int incy,
1024                        DeviceMemory<std::complex<float>> *a, int lda);
1025   Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
1026                        const DeviceMemory<std::complex<double>> &x, int incx,
1027                        const DeviceMemory<std::complex<double>> &y, int incy,
1028                        DeviceMemory<std::complex<double>> *a, int lda);
1029 
1030   // See BlasSupport::DoBlasGeru.
1031   Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
1032                        const DeviceMemory<std::complex<float>> &x, int incx,
1033                        const DeviceMemory<std::complex<float>> &y, int incy,
1034                        DeviceMemory<std::complex<float>> *a, int lda);
1035   Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
1036                        const DeviceMemory<std::complex<double>> &x, int incx,
1037                        const DeviceMemory<std::complex<double>> &y, int incy,
1038                        DeviceMemory<std::complex<double>> *a, int lda);
1039 
1040   // See BlasSupport::DoBlasHbmv.
1041   Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
1042                        std::complex<float> alpha,
1043                        const DeviceMemory<std::complex<float>> &a, int lda,
1044                        const DeviceMemory<std::complex<float>> &x, int incx,
1045                        std::complex<float> beta,
1046                        DeviceMemory<std::complex<float>> *y, int incy);
1047   Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
1048                        std::complex<double> alpha,
1049                        const DeviceMemory<std::complex<double>> &a, int lda,
1050                        const DeviceMemory<std::complex<double>> &x, int incx,
1051                        std::complex<double> beta,
1052                        DeviceMemory<std::complex<double>> *y, int incy);
1053 
1054   // See BlasSupport::DoBlasHemv.
1055   Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
1056                        std::complex<float> alpha,
1057                        const DeviceMemory<std::complex<float>> &a, int lda,
1058                        const DeviceMemory<std::complex<float>> &x, int incx,
1059                        std::complex<float> beta,
1060                        DeviceMemory<std::complex<float>> *y, int incy);
1061   Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
1062                        std::complex<double> alpha,
1063                        const DeviceMemory<std::complex<double>> &a, int lda,
1064                        const DeviceMemory<std::complex<double>> &x, int incx,
1065                        std::complex<double> beta,
1066                        DeviceMemory<std::complex<double>> *y, int incy);
1067 
1068   // See BlasSupport::DoBlasHer.
1069   Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
1070                       const DeviceMemory<std::complex<float>> &x, int incx,
1071                       DeviceMemory<std::complex<float>> *a, int lda);
1072   Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
1073                       const DeviceMemory<std::complex<double>> &x, int incx,
1074                       DeviceMemory<std::complex<double>> *a, int lda);
1075 
1076   // See BlasSupport::DoBlasHer2.
1077   Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
1078                        std::complex<float> alpha,
1079                        const DeviceMemory<std::complex<float>> &x, int incx,
1080                        const DeviceMemory<std::complex<float>> &y, int incy,
1081                        DeviceMemory<std::complex<float>> *a, int lda);
1082   Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
1083                        std::complex<double> alpha,
1084                        const DeviceMemory<std::complex<double>> &x, int incx,
1085                        const DeviceMemory<std::complex<double>> &y, int incy,
1086                        DeviceMemory<std::complex<double>> *a, int lda);
1087 
1088   // See BlasSupport::DoBlasHpmv.
1089   Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
1090                        std::complex<float> alpha,
1091                        const DeviceMemory<std::complex<float>> &ap,
1092                        const DeviceMemory<std::complex<float>> &x, int incx,
1093                        std::complex<float> beta,
1094                        DeviceMemory<std::complex<float>> *y, int incy);
1095   Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
1096                        std::complex<double> alpha,
1097                        const DeviceMemory<std::complex<double>> &ap,
1098                        const DeviceMemory<std::complex<double>> &x, int incx,
1099                        std::complex<double> beta,
1100                        DeviceMemory<std::complex<double>> *y, int incy);
1101 
1102   // See BlasSupport::DoBlasHpr.
1103   Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
1104                       const DeviceMemory<std::complex<float>> &x, int incx,
1105                       DeviceMemory<std::complex<float>> *ap);
1106   Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
1107                       const DeviceMemory<std::complex<double>> &x, int incx,
1108                       DeviceMemory<std::complex<double>> *ap);
1109 
1110   // See BlasSupport::DoBlasHpr2.
1111   Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
1112                        std::complex<float> alpha,
1113                        const DeviceMemory<std::complex<float>> &x, int incx,
1114                        const DeviceMemory<std::complex<float>> &y, int incy,
1115                        DeviceMemory<std::complex<float>> *ap);
1116   Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
1117                        std::complex<double> alpha,
1118                        const DeviceMemory<std::complex<double>> &x, int incx,
1119                        const DeviceMemory<std::complex<double>> &y, int incy,
1120                        DeviceMemory<std::complex<double>> *ap);
1121 
1122   // See BlasSupport::DoBlasSbmv.
1123   Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
1124                        const DeviceMemory<float> &a, int lda,
1125                        const DeviceMemory<float> &x, int incx, float beta,
1126                        DeviceMemory<float> *y, int incy);
1127   Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
1128                        const DeviceMemory<double> &a, int lda,
1129                        const DeviceMemory<double> &x, int incx, double beta,
1130                        DeviceMemory<double> *y, int incy);
1131 
1132   // See BlasSupport::DoBlasSpmv.
1133   Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
1134                        const DeviceMemory<float> &ap,
1135                        const DeviceMemory<float> &x, int incx, float beta,
1136                        DeviceMemory<float> *y, int incy);
1137   Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
1138                        const DeviceMemory<double> &ap,
1139                        const DeviceMemory<double> &x, int incx, double beta,
1140                        DeviceMemory<double> *y, int incy);
1141 
1142   // See BlasSupport::DoBlasSpr.
       // float and double overloads. `x` (stride `incx`) is a const input and
       // `ap` is the non-const in/out operand. There is no beta. Named after
       // BLAS ?SPR, so presumably A := alpha*x*x^T + A with A symmetric in
       // packed storage — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1143   Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
1144                       const DeviceMemory<float> &x, int incx,
1145                       DeviceMemory<float> *ap);
1146   Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
1147                       const DeviceMemory<double> &x, int incx,
1148                       DeviceMemory<double> *ap);
1149 
1150   // See BlasSupport::DoBlasSpr2.
       // float and double overloads. `x`/`y` (strides `incx`/`incy`) are const
       // inputs and `ap` is the non-const in/out operand. Named after BLAS
       // ?SPR2, so presumably A := alpha*x*y^T + alpha*y*x^T + A with A
       // symmetric in packed storage — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1151   Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
1152                        const DeviceMemory<float> &x, int incx,
1153                        const DeviceMemory<float> &y, int incy,
1154                        DeviceMemory<float> *ap);
1155   Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
1156                        const DeviceMemory<double> &x, int incx,
1157                        const DeviceMemory<double> &y, int incy,
1158                        DeviceMemory<double> *ap);
1159 
1160   // See BlasSupport::DoBlasSymv.
       // float and double overloads. `a` (leading dimension `lda`) and `x`
       // (stride `incx`) are const inputs; `y` (stride `incy`) is the non-const
       // in/out operand. Named after BLAS ?SYMV, so presumably
       // y := alpha*A*x + beta*y with A symmetric and only the `uplo` triangle
       // referenced — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1161   Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
1162                        const DeviceMemory<float> &a, int lda,
1163                        const DeviceMemory<float> &x, int incx, float beta,
1164                        DeviceMemory<float> *y, int incy);
1165   Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
1166                        const DeviceMemory<double> &a, int lda,
1167                        const DeviceMemory<double> &x, int incx, double beta,
1168                        DeviceMemory<double> *y, int incy);
1169 
1170   // See BlasSupport::DoBlasSyr.
       // float and double overloads. `x` (stride `incx`) is a const input; `a`
       // (leading dimension `lda`) is the non-const in/out operand. Named after
       // BLAS ?SYR, so presumably A := alpha*x*x^T + A with A symmetric —
       // BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1171   Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
1172                       const DeviceMemory<float> &x, int incx,
1173                       DeviceMemory<float> *a, int lda);
1174   Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
1175                       const DeviceMemory<double> &x, int incx,
1176                       DeviceMemory<double> *a, int lda);
1177 
1178   // See BlasSupport::DoBlasSyr2.
       // float and double overloads. `x`/`y` (strides `incx`/`incy`) are const
       // inputs; `a` (leading dimension `lda`) is the non-const in/out operand.
       // Named after BLAS ?SYR2, so presumably A := alpha*x*y^T + alpha*y*x^T + A
       // with A symmetric — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1179   Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
1180                        const DeviceMemory<float> &x, int incx,
1181                        const DeviceMemory<float> &y, int incy,
1182                        DeviceMemory<float> *a, int lda);
1183   Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
1184                        const DeviceMemory<double> &x, int incx,
1185                        const DeviceMemory<double> &y, int incy,
1186                        DeviceMemory<double> *a, int lda);
1187 
1188   // See BlasSupport::DoBlasTbmv.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `a` (leading dimension `lda`) is const; `x`
       // (stride `incx`) is the non-const in/out vector. There is no alpha/beta.
       // Named after BLAS ?TBMV, so presumably x := op(A)*x for a triangular band
       // matrix A with `k` off-diagonals, op selected by `trans` and unit/non-unit
       // diagonal by `diag` — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1189   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1190                        blas::Diagonal diag, uint64 n, uint64 k,
1191                        const DeviceMemory<float> &a, int lda,
1192                        DeviceMemory<float> *x, int incx);
1193   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1194                        blas::Diagonal diag, uint64 n, uint64 k,
1195                        const DeviceMemory<double> &a, int lda,
1196                        DeviceMemory<double> *x, int incx);
1197   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1198                        blas::Diagonal diag, uint64 n, uint64 k,
1199                        const DeviceMemory<std::complex<float>> &a, int lda,
1200                        DeviceMemory<std::complex<float>> *x, int incx);
1201   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1202                        blas::Diagonal diag, uint64 n, uint64 k,
1203                        const DeviceMemory<std::complex<double>> &a, int lda,
1204                        DeviceMemory<std::complex<double>> *x, int incx);
1205 
1206   // See BlasSupport::DoBlasTbsv.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `a` (leading dimension `lda`) is const; `x`
       // (stride `incx`) is the non-const in/out vector. Named after BLAS ?TBSV,
       // so presumably this solves op(A)*x = b in place (b passed in `x`, then
       // overwritten by the solution) for a triangular band matrix A with `k`
       // off-diagonals — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1207   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1208                        blas::Diagonal diag, uint64 n, uint64 k,
1209                        const DeviceMemory<float> &a, int lda,
1210                        DeviceMemory<float> *x, int incx);
1211   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1212                        blas::Diagonal diag, uint64 n, uint64 k,
1213                        const DeviceMemory<double> &a, int lda,
1214                        DeviceMemory<double> *x, int incx);
1215   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1216                        blas::Diagonal diag, uint64 n, uint64 k,
1217                        const DeviceMemory<std::complex<float>> &a, int lda,
1218                        DeviceMemory<std::complex<float>> *x, int incx);
1219   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1220                        blas::Diagonal diag, uint64 n, uint64 k,
1221                        const DeviceMemory<std::complex<double>> &a, int lda,
1222                        DeviceMemory<std::complex<double>> *x, int incx);
1223 
1224   // See BlasSupport::DoBlasTpmv.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `ap` is const and takes no leading-dimension
       // argument (suggesting packed storage); `x` (stride `incx`) is the
       // non-const in/out vector. Named after BLAS ?TPMV, so presumably
       // x := op(A)*x with A triangular — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1225   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1226                        blas::Diagonal diag, uint64 n,
1227                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1228                        int incx);
1229   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1230                        blas::Diagonal diag, uint64 n,
1231                        const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1232                        int incx);
1233   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1234                        blas::Diagonal diag, uint64 n,
1235                        const DeviceMemory<std::complex<float>> &ap,
1236                        DeviceMemory<std::complex<float>> *x, int incx);
1237   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1238                        blas::Diagonal diag, uint64 n,
1239                        const DeviceMemory<std::complex<double>> &ap,
1240                        DeviceMemory<std::complex<double>> *x, int incx);
1241 
1242   // See BlasSupport::DoBlasTpsv.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `ap` is const and takes no leading-dimension
       // argument (suggesting packed storage); `x` (stride `incx`) is the
       // non-const in/out vector. Named after BLAS ?TPSV, so presumably this
       // solves op(A)*x = b in place (b passed in `x`) with A triangular —
       // BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1243   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1244                        blas::Diagonal diag, uint64 n,
1245                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1246                        int incx);
1247   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1248                        blas::Diagonal diag, uint64 n,
1249                        const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1250                        int incx);
1251   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1252                        blas::Diagonal diag, uint64 n,
1253                        const DeviceMemory<std::complex<float>> &ap,
1254                        DeviceMemory<std::complex<float>> *x, int incx);
1255   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1256                        blas::Diagonal diag, uint64 n,
1257                        const DeviceMemory<std::complex<double>> &ap,
1258                        DeviceMemory<std::complex<double>> *x, int incx);
1259 
1260   // See BlasSupport::DoBlasTrmv.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `a` (leading dimension `lda`) is const; `x`
       // (stride `incx`) is the non-const in/out vector. There is no alpha/beta.
       // Named after BLAS ?TRMV, so presumably x := op(A)*x with A triangular —
       // BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1261   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1262                        blas::Diagonal diag, uint64 n,
1263                        const DeviceMemory<float> &a, int lda,
1264                        DeviceMemory<float> *x, int incx);
1265   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1266                        blas::Diagonal diag, uint64 n,
1267                        const DeviceMemory<double> &a, int lda,
1268                        DeviceMemory<double> *x, int incx);
1269   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1270                        blas::Diagonal diag, uint64 n,
1271                        const DeviceMemory<std::complex<float>> &a, int lda,
1272                        DeviceMemory<std::complex<float>> *x, int incx);
1273   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1274                        blas::Diagonal diag, uint64 n,
1275                        const DeviceMemory<std::complex<double>> &a, int lda,
1276                        DeviceMemory<std::complex<double>> *x, int incx);
1277 
1278   // See BlasSupport::DoBlasTrsv.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `a` (leading dimension `lda`) is const; `x`
       // (stride `incx`) is the non-const in/out vector. Named after BLAS ?TRSV,
       // so presumably this solves op(A)*x = b in place (b passed in `x`) with A
       // triangular — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1279   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1280                        blas::Diagonal diag, uint64 n,
1281                        const DeviceMemory<float> &a, int lda,
1282                        DeviceMemory<float> *x, int incx);
1283   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1284                        blas::Diagonal diag, uint64 n,
1285                        const DeviceMemory<double> &a, int lda,
1286                        DeviceMemory<double> *x, int incx);
1287   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1288                        blas::Diagonal diag, uint64 n,
1289                        const DeviceMemory<std::complex<float>> &a, int lda,
1290                        DeviceMemory<std::complex<float>> *x, int incx);
1291   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1292                        blas::Diagonal diag, uint64 n,
1293                        const DeviceMemory<std::complex<double>> &a, int lda,
1294                        DeviceMemory<std::complex<double>> *x, int incx);
1295 
1296   // See BlasSupport::DoBlasGemm.
       // Overloads for Eigen::half (with float alpha/beta), float, double,
       // std::complex<float> and std::complex<double>. `a`/`b` are const inputs
       // with leading dimensions `lda`/`ldb`; `c` (leading dimension `ldc`) is the
       // non-const in/out operand. Named after BLAS ?GEMM, so presumably
       // C := alpha*op(A)*op(B) + beta*C with op(A) m x k and op(B) k x n —
       // BlasSupport is authoritative. TF_EXPORT is a project macro, presumably
       // controlling symbol visibility of these overloads — confirm its
       // definition. Returns Stream& (presumably *this, for chaining).
1297   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1298                                  uint64 m, uint64 n, uint64 k, float alpha,
1299                                  const DeviceMemory<Eigen::half> &a, int lda,
1300                                  const DeviceMemory<Eigen::half> &b, int ldb,
1301                                  float beta, DeviceMemory<Eigen::half> *c,
1302                                  int ldc);
1303   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1304                                  uint64 m, uint64 n, uint64 k, float alpha,
1305                                  const DeviceMemory<float> &a, int lda,
1306                                  const DeviceMemory<float> &b, int ldb,
1307                                  float beta, DeviceMemory<float> *c, int ldc);
1308   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1309                                  uint64 m, uint64 n, uint64 k, double alpha,
1310                                  const DeviceMemory<double> &a, int lda,
1311                                  const DeviceMemory<double> &b, int ldb,
1312                                  double beta, DeviceMemory<double> *c, int ldc);
1313   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1314                                  uint64 m, uint64 n, uint64 k,
1315                                  std::complex<float> alpha,
1316                                  const DeviceMemory<std::complex<float>> &a,
1317                                  int lda,
1318                                  const DeviceMemory<std::complex<float>> &b,
1319                                  int ldb, std::complex<float> beta,
1320                                  DeviceMemory<std::complex<float>> *c, int ldc);
1321   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1322                                  uint64 m, uint64 n, uint64 k,
1323                                  std::complex<double> alpha,
1324                                  const DeviceMemory<std::complex<double>> &a,
1325                                  int lda,
1326                                  const DeviceMemory<std::complex<double>> &b,
1327                                  int ldb, std::complex<double> beta,
1328                                  DeviceMemory<std::complex<double>> *c,
1329                                  int ldc);
1330 
// GEMM overloads with the same operand layout as ThenBlasGemm (Eigen::half
// with float alpha/beta, float, double, std::complex<float>,
// std::complex<double>) plus a trailing `output_profile_result` pointer.
// `c` is the non-const in/out operand. Presumably profiling data for the
// GEMM is written to `*output_profile_result` — confirm in the Stream
// implementation. Returns Stream& (presumably *this, for chaining).
1331   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1332                                     blas::Transpose transb, uint64 m, uint64 n,
1333                                     uint64 k, float alpha,
1334                                     const DeviceMemory<Eigen::half> &a, int lda,
1335                                     const DeviceMemory<Eigen::half> &b, int ldb,
1336                                     float beta, DeviceMemory<Eigen::half> *c,
1337                                     int ldc,
1338                                     blas::ProfileResult *output_profile_result);
1339   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1340                                     blas::Transpose transb, uint64 m, uint64 n,
1341                                     uint64 k, float alpha,
1342                                     const DeviceMemory<float> &a, int lda,
1343                                     const DeviceMemory<float> &b, int ldb,
1344                                     float beta, DeviceMemory<float> *c, int ldc,
1345                                     blas::ProfileResult *output_profile_result);
1346   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1347                                     blas::Transpose transb, uint64 m, uint64 n,
1348                                     uint64 k, double alpha,
1349                                     const DeviceMemory<double> &a, int lda,
1350                                     const DeviceMemory<double> &b, int ldb,
1351                                     double beta, DeviceMemory<double> *c,
1352                                     int ldc,
1353                                     blas::ProfileResult *output_profile_result);
1354   Stream &ThenBlasGemmWithProfiling(
1355       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1356       uint64 k, std::complex<float> alpha,
1357       const DeviceMemory<std::complex<float>> &a, int lda,
1358       const DeviceMemory<std::complex<float>> &b, int ldb,
1359       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1360       blas::ProfileResult *output_profile_result);
1361   Stream &ThenBlasGemmWithProfiling(
1362       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1363       uint64 k, std::complex<double> alpha,
1364       const DeviceMemory<std::complex<double>> &a, int lda,
1365       const DeviceMemory<std::complex<double>> &b, int ldb,
1366       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1367       blas::ProfileResult *output_profile_result);
1368 
1369   // See BlasSupport::DoBlasGemmWithAlgorithm.
       // Overloads for Eigen::half, int8 inputs (with int alpha/beta and an int
       // output `c`), float, double, std::complex<float> and
       // std::complex<double>. Unlike ThenBlasGemm, alpha/beta are passed as
       // HostOrDeviceScalar — judging by the type name, the scalar may live in
       // host or device memory (confirm in host_or_device_scalar.h). The caller
       // also chooses `computation_type` and `algorithm`; presumably profiling
       // data is written to `*output_profile_result` — confirm whether null is
       // accepted. `c` is the non-const in/out operand.
       // Returns Stream& (presumably *this, for chaining).
1370   Stream &ThenBlasGemmWithAlgorithm(
1371       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1372       uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1373       const DeviceMemory<Eigen::half> &a, int lda,
1374       const DeviceMemory<Eigen::half> &b, int ldb,
1375       const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1376       int ldc, blas::ComputationType computation_type,
1377       blas::AlgorithmType algorithm,
1378       blas::ProfileResult *output_profile_result);
1379   Stream &ThenBlasGemmWithAlgorithm(
1380       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1381       uint64 k, const HostOrDeviceScalar<int> &alpha,
1382       const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
1383       int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,
1384       int ldc, blas::ComputationType computation_type,
1385       blas::AlgorithmType algorithm,
1386       blas::ProfileResult *output_profile_result);
1387   Stream &ThenBlasGemmWithAlgorithm(
1388       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1389       uint64 k, const HostOrDeviceScalar<float> &alpha,
1390       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
1391       int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
1392       int ldc, blas::ComputationType computation_type,
1393       blas::AlgorithmType algorithm,
1394       blas::ProfileResult *output_profile_result);
1395   Stream &ThenBlasGemmWithAlgorithm(
1396       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1397       uint64 k, const HostOrDeviceScalar<double> &alpha,
1398       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
1399       int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
1400       int ldc, blas::ComputationType computation_type,
1401       blas::AlgorithmType algorithm,
1402       blas::ProfileResult *output_profile_result);
1403   Stream &ThenBlasGemmWithAlgorithm(
1404       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1405       uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1406       const DeviceMemory<std::complex<float>> &a, int lda,
1407       const DeviceMemory<std::complex<float>> &b, int ldb,
1408       const HostOrDeviceScalar<std::complex<float>> &beta,
1409       DeviceMemory<std::complex<float>> *c, int ldc,
1410       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1411       blas::ProfileResult *output_profile_result);
1412   Stream &ThenBlasGemmWithAlgorithm(
1413       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1414       uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1415       const DeviceMemory<std::complex<double>> &a, int lda,
1416       const DeviceMemory<std::complex<double>> &b, int ldb,
1417       const HostOrDeviceScalar<std::complex<double>> &beta,
1418       DeviceMemory<std::complex<double>> *c, int ldc,
1419       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1420       blas::ProfileResult *output_profile_result);
1421 
1422   // See BlasSupport::DoBlasGemmBatched.
       // Overloads for Eigen::half (with float alpha/beta), float, double,
       // std::complex<float> and std::complex<double>. Each of `a`, `b` and `c`
       // is an ArraySlice of DeviceMemory pointers — presumably one matrix per
       // batch entry, `batch_count` entries, all sharing m/n/k and the leading
       // dimensions. `c` is a const slice of non-const pointers: the slice
       // itself is not modified, but the pointed-to matrices are presumably
       // written. Returns Stream& (presumably *this, for chaining).
1423   Stream &ThenBlasGemmBatched(
1424       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1425       uint64 k, float alpha,
1426       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1427       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
1428       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
1429       int ldc, int batch_count);
1430   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
1431                               uint64 m, uint64 n, uint64 k, float alpha,
1432                               const port::ArraySlice<DeviceMemory<float> *> &a,
1433                               int lda,
1434                               const port::ArraySlice<DeviceMemory<float> *> &b,
1435                               int ldb, float beta,
1436                               const port::ArraySlice<DeviceMemory<float> *> &c,
1437                               int ldc, int batch_count);
1438   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
1439                               uint64 m, uint64 n, uint64 k, double alpha,
1440                               const port::ArraySlice<DeviceMemory<double> *> &a,
1441                               int lda,
1442                               const port::ArraySlice<DeviceMemory<double> *> &b,
1443                               int ldb, double beta,
1444                               const port::ArraySlice<DeviceMemory<double> *> &c,
1445                               int ldc, int batch_count);
1446   Stream &ThenBlasGemmBatched(
1447       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1448       uint64 k, std::complex<float> alpha,
1449       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
1450       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
1451       std::complex<float> beta,
1452       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
1453       int batch_count);
1454   Stream &ThenBlasGemmBatched(
1455       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1456       uint64 k, std::complex<double> alpha,
1457       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
1458       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
1459       std::complex<double> beta,
1460       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
1461       int batch_count);
// Batched GEMM overloads with the same element types and operand layout as
// ThenBlasGemmBatched (Eigen::half with float alpha/beta, float, double,
// std::complex<float>, std::complex<double>), plus a trailing
// `scratch_allocator`. Presumably the allocator supplies temporary device
// memory the batched call needs — confirm in the Stream implementation.
// Returns Stream& (presumably *this, for chaining).
1462   Stream &ThenBlasGemmBatchedWithScratch(
1463       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1464       uint64 k, float alpha,
1465       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1466       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
1467       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
1468       int ldc, int batch_count, ScratchAllocator *scratch_allocator);
1469   Stream &ThenBlasGemmBatchedWithScratch(
1470       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1471       uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
1472       int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
1473       float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
1474       int batch_count, ScratchAllocator *scratch_allocator);
1475   Stream &ThenBlasGemmBatchedWithScratch(
1476       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1477       uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
1478       int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
1479       double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
1480       int batch_count, ScratchAllocator *scratch_allocator);
1481   Stream &ThenBlasGemmBatchedWithScratch(
1482       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1483       uint64 k, std::complex<float> alpha,
1484       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
1485       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
1486       std::complex<float> beta,
1487       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
1488       int batch_count, ScratchAllocator *scratch_allocator);
1489   Stream &ThenBlasGemmBatchedWithScratch(
1490       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1491       uint64 k, std::complex<double> alpha,
1492       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
1493       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
1494       std::complex<double> beta,
1495       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
1496       int batch_count, ScratchAllocator *scratch_allocator);
// Strided-batched GEMM overloads for Eigen::half (with float alpha/beta),
// float, double, std::complex<float> and std::complex<double>. Unlike
// ThenBlasGemmBatched, each operand is a single DeviceMemory buffer plus an
// int64 stride (`stride_a`, `stride_b`, `stride_c`). Presumably matrix i of
// each operand starts i*stride elements into its buffer, for `batch_count`
// matrices — confirm the stride unit in the implementation. `c` is the
// non-const in/out operand. Returns Stream& (presumably *this, for chaining).
1497   Stream &ThenBlasGemmStridedBatched(
1498       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1499       uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
1500       int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
1501       int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
1502       int64 stride_c, int batch_count);
1503   Stream &ThenBlasGemmStridedBatched(
1504       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1505       uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1506       int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
1507       float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
1508       int batch_count);
1509   Stream &ThenBlasGemmStridedBatched(
1510       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1511       uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1512       int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
1513       double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
1514       int batch_count);
1515   Stream &ThenBlasGemmStridedBatched(
1516       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1517       uint64 k, std::complex<float> alpha,
1518       const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
1519       const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
1520       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1521       int64 stride_c, int batch_count);
1522   Stream &ThenBlasGemmStridedBatched(
1523       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1524       uint64 k, std::complex<double> alpha,
1525       const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
1526       const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
1527       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1528       int64 stride_c, int batch_count);
1529 
1530   // See BlasSupport::DoBlasHemm.
       // Complex-only overloads (std::complex<float>, std::complex<double>).
       // `a`/`b` are const inputs; `c` (leading dimension `ldc`) is the non-const
       // in/out operand. Named after BLAS ?HEMM, so presumably
       // C := alpha*A*B + beta*C (or alpha*B*A + beta*C, per `side`) with A
       // Hermitian — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1531   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
1532                        uint64 n, std::complex<float> alpha,
1533                        const DeviceMemory<std::complex<float>> &a, int lda,
1534                        const DeviceMemory<std::complex<float>> &b, int ldb,
1535                        std::complex<float> beta,
1536                        DeviceMemory<std::complex<float>> *c, int ldc);
1537   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
1538                        uint64 n, std::complex<double> alpha,
1539                        const DeviceMemory<std::complex<double>> &a, int lda,
1540                        const DeviceMemory<std::complex<double>> &b, int ldb,
1541                        std::complex<double> beta,
1542                        DeviceMemory<std::complex<double>> *c, int ldc);
1543 
1544   // See BlasSupport::DoBlasHerk.
       // Complex-matrix overloads whose alpha/beta are real (float for
       // std::complex<float>, double for std::complex<double>). `a` is a const
       // input; `c` is the non-const in/out operand. Named after BLAS ?HERK, so
       // presumably C := alpha*op(A)*op(A)^H + beta*C with C Hermitian —
       // BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1545   Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1546                        uint64 k, float alpha,
1547                        const DeviceMemory<std::complex<float>> &a, int lda,
1548                        float beta, DeviceMemory<std::complex<float>> *c,
1549                        int ldc);
1550   Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1551                        uint64 k, double alpha,
1552                        const DeviceMemory<std::complex<double>> &a, int lda,
1553                        double beta, DeviceMemory<std::complex<double>> *c,
1554                        int ldc);
1555 
1556   // See BlasSupport::DoBlasHer2k.
       // Complex overloads: alpha is complex but beta is real (float / double),
       // matching the signatures below. `a`/`b` are const inputs; `c` is the
       // non-const in/out operand. Named after BLAS ?HER2K, so presumably
       // C := alpha*op(A)*op(B)^H + conj(alpha)*op(B)*op(A)^H + beta*C with C
       // Hermitian — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1557   Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1558                         uint64 k, std::complex<float> alpha,
1559                         const DeviceMemory<std::complex<float>> &a, int lda,
1560                         const DeviceMemory<std::complex<float>> &b, int ldb,
1561                         float beta, DeviceMemory<std::complex<float>> *c,
1562                         int ldc);
1563   Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1564                         uint64 k, std::complex<double> alpha,
1565                         const DeviceMemory<std::complex<double>> &a, int lda,
1566                         const DeviceMemory<std::complex<double>> &b, int ldb,
1567                         double beta, DeviceMemory<std::complex<double>> *c,
1568                         int ldc);
1569 
1570   // See BlasSupport::DoBlasSymm.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `a`/`b` are const inputs; `c` (leading dimension
       // `ldc`) is the non-const in/out operand. Named after BLAS ?SYMM, so
       // presumably C := alpha*A*B + beta*C (or alpha*B*A + beta*C, per `side`)
       // with A symmetric — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1571   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1572                        uint64 n, float alpha, const DeviceMemory<float> &a,
1573                        int lda, const DeviceMemory<float> &b, int ldb,
1574                        float beta, DeviceMemory<float> *c, int ldc);
1575   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1576                        uint64 n, double alpha, const DeviceMemory<double> &a,
1577                        int lda, const DeviceMemory<double> &b, int ldb,
1578                        double beta, DeviceMemory<double> *c, int ldc);
1579   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1580                        uint64 n, std::complex<float> alpha,
1581                        const DeviceMemory<std::complex<float>> &a, int lda,
1582                        const DeviceMemory<std::complex<float>> &b, int ldb,
1583                        std::complex<float> beta,
1584                        DeviceMemory<std::complex<float>> *c, int ldc);
1585   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1586                        uint64 n, std::complex<double> alpha,
1587                        const DeviceMemory<std::complex<double>> &a, int lda,
1588                        const DeviceMemory<std::complex<double>> &b, int ldb,
1589                        std::complex<double> beta,
1590                        DeviceMemory<std::complex<double>> *c, int ldc);
1591 
1592   // See BlasSupport::DoBlasSyrk.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `a` is a const input; `c` is the non-const in/out
       // operand. Named after BLAS ?SYRK, so presumably
       // C := alpha*op(A)*op(A)^T + beta*C with C symmetric (only the `uplo`
       // triangle updated) — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1593   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1594                        uint64 k, float alpha, const DeviceMemory<float> &a,
1595                        int lda, float beta, DeviceMemory<float> *c, int ldc);
1596   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1597                        uint64 k, double alpha, const DeviceMemory<double> &a,
1598                        int lda, double beta, DeviceMemory<double> *c, int ldc);
1599   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1600                        uint64 k, std::complex<float> alpha,
1601                        const DeviceMemory<std::complex<float>> &a, int lda,
1602                        std::complex<float> beta,
1603                        DeviceMemory<std::complex<float>> *c, int ldc);
1604   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1605                        uint64 k, std::complex<double> alpha,
1606                        const DeviceMemory<std::complex<double>> &a, int lda,
1607                        std::complex<double> beta,
1608                        DeviceMemory<std::complex<double>> *c, int ldc);
1609 
1610   // See BlasSupport::DoBlasSyr2k.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `a`/`b` are const inputs; `c` is the non-const
       // in/out operand. Named after BLAS ?SYR2K, so presumably
       // C := alpha*op(A)*op(B)^T + alpha*op(B)*op(A)^T + beta*C with C
       // symmetric — BlasSupport is authoritative.
       // Returns Stream& (presumably *this, so Then* calls can be chained).
1611   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1612                         uint64 k, float alpha, const DeviceMemory<float> &a,
1613                         int lda, const DeviceMemory<float> &b, int ldb,
1614                         float beta, DeviceMemory<float> *c, int ldc);
1615   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1616                         uint64 k, double alpha, const DeviceMemory<double> &a,
1617                         int lda, const DeviceMemory<double> &b, int ldb,
1618                         double beta, DeviceMemory<double> *c, int ldc);
1619   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1620                         uint64 k, std::complex<float> alpha,
1621                         const DeviceMemory<std::complex<float>> &a, int lda,
1622                         const DeviceMemory<std::complex<float>> &b, int ldb,
1623                         std::complex<float> beta,
1624                         DeviceMemory<std::complex<float>> *c, int ldc);
1625   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1626                         uint64 k, std::complex<double> alpha,
1627                         const DeviceMemory<std::complex<double>> &a, int lda,
1628                         const DeviceMemory<std::complex<double>> &b, int ldb,
1629                         std::complex<double> beta,
1630                         DeviceMemory<std::complex<double>> *c, int ldc);
1631 
1632   // See BlasSupport::DoBlasTrmm.
       // Overloads for float, double, std::complex<float> and
       // std::complex<double>. `a` (leading dimension `lda`) is const; `b`
       // (leading dimension `ldb`) is the non-const in/out operand. There is no
       // beta. Named after BLAS ?TRMM, so presumably B := alpha*op(A)*B (or
       // alpha*B*op(A), per `side`) with A triangular — BlasSupport is
       // authoritative. Returns Stream& (presumably *this, for chaining).
1633   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1634                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1635                        uint64 n, float alpha, const DeviceMemory<float> &a,
1636                        int lda, DeviceMemory<float> *b, int ldb);
1637   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1638                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1639                        uint64 n, double alpha, const DeviceMemory<double> &a,
1640                        int lda, DeviceMemory<double> *b, int ldb);
1641   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1642                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1643                        uint64 n, std::complex<float> alpha,
1644                        const DeviceMemory<std::complex<float>> &a, int lda,
1645                        DeviceMemory<std::complex<float>> *b, int ldb);
1646   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1647                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1648                        uint64 n, std::complex<double> alpha,
1649                        const DeviceMemory<std::complex<double>> &a, int lda,
1650                        DeviceMemory<std::complex<double>> *b, int ldb);
1651 
1652   // See BlasSupport::DoBlasTrsm.
1653   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1654                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1655                        uint64 n, float alpha, const DeviceMemory<float> &a,
1656                        int lda, DeviceMemory<float> *b, int ldb);
1657   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1658                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1659                        uint64 n, double alpha, const DeviceMemory<double> &a,
1660                        int lda, DeviceMemory<double> *b, int ldb);
1661   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1662                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1663                        uint64 n, std::complex<float> alpha,
1664                        const DeviceMemory<std::complex<float>> &a, int lda,
1665                        DeviceMemory<std::complex<float>> *b, int ldb);
1666   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1667                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1668                        uint64 n, std::complex<double> alpha,
1669                        const DeviceMemory<std::complex<double>> &a, int lda,
1670                        DeviceMemory<std::complex<double>> *b, int ldb);
1671 
1672   // See FftSupport::DoFft.
1673   Stream &ThenFft(fft::Plan *plan,
1674                   const DeviceMemory<std::complex<float>> &input,
1675                   DeviceMemory<std::complex<float>> *output);
1676   Stream &ThenFft(fft::Plan *plan,
1677                   const DeviceMemory<std::complex<double>> &input,
1678                   DeviceMemory<std::complex<double>> *output);
1679   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
1680                   DeviceMemory<std::complex<float>> *output);
1681   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
1682                   DeviceMemory<std::complex<double>> *output);
1683   Stream &ThenFft(fft::Plan *plan,
1684                   const DeviceMemory<std::complex<float>> &input,
1685                   DeviceMemory<float> *output);
1686   Stream &ThenFft(fft::Plan *plan,
1687                   const DeviceMemory<std::complex<double>> &input,
1688                   DeviceMemory<double> *output);
1689 
1690   // Makes the RNG use the provided value as the basis for further generation.
1691   // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
1692   // sources of seed data if the default (high quality) sources are not
1693   // desired.
1694   // For most use cases, this function will not be necessary; each provided
1695   // back-end implementation will be appropriately seeded by default.
1696   // At a minimum 16 bytes of data are required in the seed buffer.
1697   //
1698   // To seed with good (non-reproducible) data:
1699   //   File* f = File::Open("/dev/random", "r");
1700   //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
1701   //   < error checking >
1702   //   stream.ThenSetRngSeed(seed_data, bytes_read);
1703   //
1704   // To seed with reproducible data:
1705   //   uint64_t seed_data[2] = { <data> };
1706   //   stream.ThenSetRngSeed(seed_data, 16);
1707   Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);
1708 
1709   // Populates the memory indicated by values with uniform-random-distribution
1710   // values. TODO(leary) seeding API/description
1711   //
1712   // Uses the type and size of the DeviceMemory to infer what data should be
1713   // populated.
1714   Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
1715   Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
1716   Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
1717   Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
1718   Stream &ThenPopulateRandGaussian(float mean, float stddev,
1719                                    DeviceMemory<float> *values);
1720   Stream &ThenPopulateRandGaussian(double mean, double stddev,
1721                                    DeviceMemory<double> *values);
1722 
1723   // Entrain onto the stream: a memcpy to a host destination from a GPU source
1724   // of the given target size. host_dst must be a pointer to host memory
1725   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
1726   // then registered with StreamExecutor::HostMemoryRegister.
1727   Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
1728                      uint64 size);
1729 
1730   // Entrain onto the stream: a memcpy to a GPU destination from a host source
1731   // of the given target size. host_src must be a pointer to host memory
1732   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
1733   // then registered with StreamExecutor::HostMemoryRegister.
1734   Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
1735                      uint64 size);
1736 
1737   // Alternative interface for memcpying from device to host that takes an
1738   // array slice. Checks that the destination size can accommodate the host
1739   // slice size.
1740   template <typename T>
ThenMemcpyD2H(const DeviceMemory<T> & gpu_src,port::MutableArraySlice<T> host_dst)1741   Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1742                         port::MutableArraySlice<T> host_dst) {
1743     auto host_size = host_dst.size() * sizeof(T);
1744     CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1745     return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1746   }
1747 
1748   // Alternative interface for memcpying from host to device that takes an
1749   // array slice. Checks that the destination size can accommodate the host
1750   // slice size.
1751   template <typename T>
ThenMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemory<T> * gpu_dst)1752   Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
1753                         DeviceMemory<T> *gpu_dst) {
1754     auto host_size = host_src.size() * sizeof(T);
1755     CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1756     return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1757   }
1758 
1759   // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
1760   // of the given target size. gpu_src/dst must be pointers to GPU memory and
1761   // peer access must be enabled between their owning StreamExecutors.
1762   Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
1763                      uint64 size);
1764 
1765   // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
1766   // ensuring that the host pointer isn't getting confused accidentally with a
1767   // device pointer if you're not doing metaprogramming against the API.
ThenMemcpyD2D(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)1768   Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
1769                         const DeviceMemoryBase &gpu_src, uint64 size) {
1770     return ThenMemcpy(gpu_dst, gpu_src, size);
1771   }
1772 
1773   // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
1774   // The location must not be null.
1775   Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);
1776 
1777   // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
1778   // size bytes, where size must be a multiple of 4 (the width of the
1779   // pattern in bytes). The location must not be null.
1780   Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
1781 
1782   // Enqueue a forward operation of the RNN model onto the stream.
1783   // See DnnSupport::DoRnnForward for more details.
1784   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1785                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1786                          const DeviceMemory<Eigen::half> &input_data,
1787                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1788                          const DeviceMemory<Eigen::half> &input_h_data,
1789                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1790                          const DeviceMemory<Eigen::half> &input_c_data,
1791                          const DeviceMemory<Eigen::half> &params,
1792                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1793                          DeviceMemory<Eigen::half> *output_data,
1794                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1795                          DeviceMemory<Eigen::half> *output_h_data,
1796                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1797                          DeviceMemory<Eigen::half> *output_c_data,
1798                          bool is_training,
1799                          ScratchAllocator *reserve_space_allocator,
1800                          ScratchAllocator *workspace_allocator,
1801                          dnn::ProfileResult *output_profile_result);
1802 
1803   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1804                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1805                          const DeviceMemory<float> &input_data,
1806                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1807                          const DeviceMemory<float> &input_h_data,
1808                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1809                          const DeviceMemory<float> &input_c_data,
1810                          const DeviceMemory<float> &params,
1811                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1812                          DeviceMemory<float> *output_data,
1813                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1814                          DeviceMemory<float> *output_h_data,
1815                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1816                          DeviceMemory<float> *output_c_data, bool is_training,
1817                          ScratchAllocator *reserve_space_allocator,
1818                          ScratchAllocator *workspace_allocator,
1819                          dnn::ProfileResult *output_profile_result);
1820 
1821   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1822                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1823                          const DeviceMemory<double> &input_data,
1824                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1825                          const DeviceMemory<double> &input_h_data,
1826                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1827                          const DeviceMemory<double> &input_c_data,
1828                          const DeviceMemory<double> &params,
1829                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1830                          DeviceMemory<double> *output_data,
1831                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1832                          DeviceMemory<double> *output_h_data,
1833                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1834                          DeviceMemory<double> *output_c_data, bool is_training,
1835                          ScratchAllocator *reserve_space_allocator,
1836                          ScratchAllocator *workspace_allocator,
1837                          dnn::ProfileResult *output_profile_result);
1838 
1839   // Enqueue a backward operation of the RNN model onto the stream.
1840   // See DnnSupport::DoRnnBackward for more details.
1841   Stream &ThenRnnBackward(
1842       const dnn::RnnDescriptor &rnn_desc,
1843       const dnn::RnnSequenceTensorDescriptor &input_desc,
1844       const DeviceMemory<Eigen::half> &input_data,
1845       const dnn::RnnStateTensorDescriptor &input_h_desc,
1846       const DeviceMemory<Eigen::half> &input_h_data,
1847       const dnn::RnnStateTensorDescriptor &input_c_desc,
1848       const DeviceMemory<Eigen::half> &input_c_data,
1849       const DeviceMemory<Eigen::half> &params,
1850       const dnn::RnnSequenceTensorDescriptor &output_desc,
1851       const DeviceMemory<Eigen::half> &output_data,
1852       const dnn::RnnStateTensorDescriptor &output_h_desc,
1853       const DeviceMemory<Eigen::half> &output_h_data,
1854       const dnn::RnnStateTensorDescriptor &output_c_desc,
1855       const DeviceMemory<Eigen::half> &output_c_data,
1856       const DeviceMemory<Eigen::half> &output_backprop_data,
1857       const DeviceMemory<Eigen::half> &output_h_backprop_data,
1858       const DeviceMemory<Eigen::half> &output_c_backprop_data,
1859       DeviceMemory<Eigen::half> *input_backprop_data,
1860       DeviceMemory<Eigen::half> *input_h_backprop_data,
1861       DeviceMemory<Eigen::half> *input_c_backprop_data,
1862       DeviceMemory<Eigen::half> *params_backprop_data,
1863       DeviceMemory<uint8> *reserve_space_data,
1864       ScratchAllocator *workspace_allocator,
1865       dnn::ProfileResult *output_profile_result);
1866 
1867   Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
1868                           const dnn::RnnSequenceTensorDescriptor &input_desc,
1869                           const DeviceMemory<float> &input_data,
1870                           const dnn::RnnStateTensorDescriptor &input_h_desc,
1871                           const DeviceMemory<float> &input_h_data,
1872                           const dnn::RnnStateTensorDescriptor &input_c_desc,
1873                           const DeviceMemory<float> &input_c_data,
1874                           const DeviceMemory<float> &params,
1875                           const dnn::RnnSequenceTensorDescriptor &output_desc,
1876                           const DeviceMemory<float> &output_data,
1877                           const dnn::RnnStateTensorDescriptor &output_h_desc,
1878                           const DeviceMemory<float> &output_h_data,
1879                           const dnn::RnnStateTensorDescriptor &output_c_desc,
1880                           const DeviceMemory<float> &output_c_data,
1881                           const DeviceMemory<float> &output_backprop_data,
1882                           const DeviceMemory<float> &output_h_backprop_data,
1883                           const DeviceMemory<float> &output_c_backprop_data,
1884                           DeviceMemory<float> *input_backprop_data,
1885                           DeviceMemory<float> *input_h_backprop_data,
1886                           DeviceMemory<float> *input_c_backprop_data,
1887                           DeviceMemory<float> *params_backprop_data,
1888                           DeviceMemory<uint8> *reserve_space_data,
1889                           ScratchAllocator *workspace_allocator,
1890                           dnn::ProfileResult *output_profile_result);
1891 
1892   Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
1893                           const dnn::RnnSequenceTensorDescriptor &input_desc,
1894                           const DeviceMemory<double> &input_data,
1895                           const dnn::RnnStateTensorDescriptor &input_h_desc,
1896                           const DeviceMemory<double> &input_h_data,
1897                           const dnn::RnnStateTensorDescriptor &input_c_desc,
1898                           const DeviceMemory<double> &input_c_data,
1899                           const DeviceMemory<double> &params,
1900                           const dnn::RnnSequenceTensorDescriptor &output_desc,
1901                           const DeviceMemory<double> &output_data,
1902                           const dnn::RnnStateTensorDescriptor &output_h_desc,
1903                           const DeviceMemory<double> &output_h_data,
1904                           const dnn::RnnStateTensorDescriptor &output_c_desc,
1905                           const DeviceMemory<double> &output_c_data,
1906                           const DeviceMemory<double> &output_backprop_data,
1907                           const DeviceMemory<double> &output_h_backprop_data,
1908                           const DeviceMemory<double> &output_c_backprop_data,
1909                           DeviceMemory<double> *input_backprop_data,
1910                           DeviceMemory<double> *input_h_backprop_data,
1911                           DeviceMemory<double> *input_c_backprop_data,
1912                           DeviceMemory<double> *params_backprop_data,
1913                           DeviceMemory<uint8> *reserve_space_data,
1914                           ScratchAllocator *workspace_allocator,
1915                           dnn::ProfileResult *output_profile_result);
1916 
1917   // Enqueue a CTCLoss operation onto the stream.
1918   // See DnnSupport::DoCtcLoss for more details.
1919   Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
1920                       const DeviceMemory<float> &probs_data,
1921                       absl::Span<const int> labels_data,
1922                       absl::Span<const int> labels_lengths_data,
1923                       absl::Span<const int> input_lengths_data,
1924                       DeviceMemory<float> *costs_data,
1925                       const dnn::RnnStateTensorDescriptor &grads_desc,
1926                       DeviceMemory<float> *grads_data,
1927                       ScratchAllocator *workspace_allocator);
1928 
1929   // Enqueue onto the stream an operation that transforms a tensor.
1930   // See DnnSupport::DoTransformTensor for more details.
1931   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
1932                               dnn::DataType input_type,
1933                               const DeviceMemoryBase &input_data,
1934                               const dnn::BatchDescriptor &output_desc,
1935                               dnn::DataType output_type, float scale,
1936                               DeviceMemoryBase *output_data);
1937 
1938   // The templated version of the above ThenTransformTensor. Useful when the
1939   // input and output types are statically known.
1940   template <typename InElemT, typename OutElemT>
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,const DeviceMemory<InElemT> & input_data,const dnn::BatchDescriptor & output_desc,DeviceMemory<OutElemT> * output_data)1941   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
1942                               const DeviceMemory<InElemT> &input_data,
1943                               const dnn::BatchDescriptor &output_desc,
1944                               DeviceMemory<OutElemT> *output_data) {
1945     return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
1946                                input_data, output_desc,
1947                                dnn::ToDataType<OutElemT>(), output_data);
1948   }
1949 
1950   // (Synchronously) block the host code waiting for the operations
1951   // entrained on the stream (enqueued to this point in program
1952   // execution) to complete.
1953   //
1954   // Returns an OK status if the blocking was successful and the stream is ok().
1955   // Otherwise returns an error describing why the blocking failed.
1956   port::Status BlockHostUntilDone() LOCKS_EXCLUDED(mu_);
1957 
1958   // Warning! This method interacts with internal threads in
1959   // sometimes-unpredictable ways and is intended for GPU-Executor-internal
1960   // use only. Please check with a member of the FASTR team before making
1961   // use of this method.
1963   //
1964   // Entrains onto the stream a function to be executed on the host at some
1965   // point in the future.
1966   // Async host callbacks DO NOT block the stream as device functions (or as
1967   // synchronous host callbacks). No synchronization is possible with
1968   // asynchronous callbacks; they are strictly fire-and-forget.
1969   // Although declared publicly, treat this method as internal-only due to the
1970   // potential for undefined behavior with synchronization using OpenCL user
1970   // events.
1971   // The ONLY lifetime guarantee in these calls is that the StreamExecutor
1972   // parameter will still be valid - this Stream may not be!
1973   // Any callbacks requiring device API calls must use this method.
1974   Stream &ThenEnqueueOnBackgroundThread(
1975       std::function<void(StreamExecutor *)> task);
1976 
1977   // Returns the (opaque) platform-specific backing object. Ownership is not
1978   // transferred to the caller.
implementation()1979   internal::StreamInterface *implementation() { return implementation_.get(); }
1980 
1981   // Entrains onto the stream a callback to the host (from the device).
1982   // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
1983   // never fail or its failure is inconsequential.
1984   //
1985   // This is kept for backward compatibility. Future code should use
1986   // ThenDoHostCallbackWithStatus and explicitly return a success status.
1987   // TODO(b/112125301): Eventually remove this method.
1988   Stream &ThenDoHostCallback(std::function<void()> callback);
1989 
1990   // Entrains onto the stream a callback to the host (from the device).
1991   // Host callbacks block/occupy the stream just as device functions
1992   // (execute one at a time, block later stream operations).
1993   // Whether the callback return status affects the result of BlockHostUntilDone
1994   // is platform-dependent.
1995   //
1996   // Behavior is undefined when synchronizing using OpenCL user events.
1997   // Behavior is undefined if host callbacks call device routines or insert
1998   // them into any stream.
1999   //
2000   // On certain platforms, ThenDoHostCallback is expected to have significant
2001   // negative effects on performance.
2002   Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
2003 
2004   // Runs the given callback after the next call to BlockHostUntilDone on this
2005   // stream (or after the Stream does BlockHostUntilDone in its destructor).
2006   // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
2007   // some use cases.
2008   Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);
2009 
2010   // Returns the StreamExecutor (parent object) associated with this stream.
parent()2011   StreamExecutor *parent() const {
2012     CHECK(parent_ != nullptr);
2013     return parent_;
2014   }
2015 
2016   // Returns the (internal usage) temporary-memory-allocation manager associated
2017   // with this stream.
2018   internal::TemporaryMemoryManager *temporary_memory_manager();
2019 
2020   // Returns a debugging string "[stream=0x...,impl=0x...]".
2021   string DebugStreamPointers() const;
2022 
2023  private:
2024   friend class host::HostBlas;  // for parent_.
2025   friend class host::HostFft;   // for parent_.
2026   friend class host::HostRng;   // for parent_.
2027   template <typename... Args>
2028   friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
2029   friend class ocl::CLBlas;    // for parent_.
2030 
InErrorState()2031   bool InErrorState() const LOCKS_EXCLUDED(mu_) {
2032     absl::ReaderMutexLock lock(&mu_);
2033     return !ok_;
2034   }
2035 
2036   // Sets the error state if operation_retcode is false.
2037   // This is a useful shorthand for many stream routines.
CheckError(bool operation_retcode)2038   void CheckError(bool operation_retcode) LOCKS_EXCLUDED(mu_) {
2039     if (operation_retcode) {
2040       return;
2041     }
2042     absl::MutexLock lock(&mu_);
2043     ok_ = false;
2044   }
2045 
2046   // Checks the status and logs the error message, if any.
2047   void CheckStatus(port::Status status) LOCKS_EXCLUDED(mu_);
2048 
SetError()2049   void SetError() { CheckError(false /* = operation_retcode */); }
2050 
SetErrorAndLogNoDnnSupport()2051   void SetErrorAndLogNoDnnSupport() {
2052     SetError();
2053     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
2054                     "without DNN support";
2055   }
2056 
2057   // Runs the set of callbacks that are intended to run after
2058   // BlockHostUntilDone.
2059   void RunAfterBlockHostUntilDoneCallbacks();
2060 
2061   // The StreamExecutor that supports the operation of this stream.
2062   StreamExecutor *parent_;
2063 
2064   // The platform-dependent implementation that the StreamExecutor interface
2065   // delegates to.
2066   std::unique_ptr<internal::StreamInterface> implementation_;
2067 
2068   // mutex that guards the allocation / error state flags.
2069   // Mutable so that it can be obtained via const reader lock.
2070   mutable absl::Mutex mu_;
2071 
2072   // Whether Init() was successfully called to allocate this stream on the
2073   // underlying platform. It flips from false to true, with a sanity check.
2074   // See StreamExecutor::AllocateStream.
2075   bool allocated_ GUARDED_BY(mu_);
2076 
2077   // Whether all operations have entrained successfully to the current program
2078   // point.
2079   bool ok_ GUARDED_BY(mu_);
2080 
2081   // Sub-streams that are generated from this stream. Each element has a pointer
2082   // to sub-stream and a boolean value indicating if this substream is ready to
2083   // be reused.
2084   std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
2085       GUARDED_BY(mu_);
2086 
2087   // Streams can allocate temporary memories to help with work they enqueue
2088   // (e.g. for scratch memory spaces). This member tracks those allocations and
2089   // notes when they can be reclaimed -- reclamation is attempted when
2090   // BlockHostUntilDone() is called.
2091   internal::TemporaryMemoryManager temporary_memory_manager_;
2092 
2093   // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
2094   std::vector<std::function<void()>> after_block_host_until_done_callbacks_
2095       GUARDED_BY(mu_);
2096 
2097   // Implementation of ThenConvolveBackwardBias that is shared by all types.
2098   template <typename T>
2099   Stream &ThenConvolveBackwardBiasImpl(
2100       const dnn::BatchDescriptor &input_descriptor,
2101       const DeviceMemory<T> &input_data,
2102       const dnn::BatchDescriptor &bias_descriptor,
2103       DeviceMemory<T> *backward_bias_data);
2104 
2105   SE_DISALLOW_COPY_AND_ASSIGN(Stream);
2106 };
2107 
2108 ////////////
2109 // Inlines
2110 
2111 template <typename T>
2112 inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
AllocateTemporaryArray(uint64 element_count)2113 Stream::AllocateTemporaryArray(uint64 element_count) {
2114   return temporary_memory_manager_.AllocateArray<T>(element_count);
2115 }
2116 
temporary_memory_manager()2117 inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
2118   return &temporary_memory_manager_;
2119 }
2120 
2121 template <>
2122 struct Quantization<uint8> {
2123   static constexpr dnn::QuantizedActivationMode kModeId =
2124       dnn::QuantizedActivationMode::k8Bit;
2125 };
2126 
2127 template <>
2128 struct Quantization<uint16> {
2129   static constexpr dnn::QuantizedActivationMode kModeId =
2130       dnn::QuantizedActivationMode::k16Bit;
2131 };
2132 
2133 template <>
2134 struct Quantization<int32> {
2135   static constexpr dnn::QuantizedActivationMode kModeId =
2136       dnn::QuantizedActivationMode::k32Bit;
2137 };
2138 
2139 }  // namespace stream_executor
2140 
2141 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
2142