1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
23
24 #include <complex>
25 #include <functional>
26 #include <memory>
27
28 #include "absl/synchronization/mutex.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/thread_annotations.h"
31 #include "tensorflow/stream_executor/blas.h"
32 #include "tensorflow/stream_executor/device_memory.h"
33 #include "tensorflow/stream_executor/dnn.h"
34 #include "tensorflow/stream_executor/event.h"
35 #include "tensorflow/stream_executor/fft.h"
36 #include "tensorflow/stream_executor/host_or_device_scalar.h"
37 #include "tensorflow/stream_executor/kernel.h"
38 #include "tensorflow/stream_executor/launch_dim.h"
39 #include "tensorflow/stream_executor/lib/array_slice.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
42 #include "tensorflow/stream_executor/temporary_memory_manager.h"
43
44 namespace stream_executor {
45
46 namespace host {
47 class HostBlas;
48 class HostFft;
49 class HostRng;
50 class HostTimer;
51 } // namespace host
52
53 namespace ocl {
54 class CLBlas;
55 } // namespace ocl
56
57 namespace internal {
58 class StreamInterface;
59 } // namespace internal
60
61 class DeviceMemoryBase;
62 template <typename ElemT>
63 class DeviceMemory;
64
65 class Timer;
66
67 namespace dnn {
68 class BatchDescriptor;
69 class FilterDescriptor;
70 class ConvolutionDescriptor;
71 class ProfileResult;
72 class AlgorithmDesc;
73 } // namespace dnn
74
75 class StreamExecutor;
76 class ScratchAllocator;
77
namespace detail {

// Helper class to prevent a template function argument from being deduced.
// Wrapping a parameter type as `NonDeducedType<T>` forces callers to supply
// T explicitly (or lets it be fixed by another parameter). This is identical
// to std::type_identity in C++20.
template <typename T>
struct NonDeduced {
  using type = T;
};
template <typename T>
using NonDeducedType = typename NonDeduced<T>::type;

}  // namespace detail
90
91 // Convert a type to the corresponding QuantizedActivationMode.
92 template <typename ElementType>
93 struct Quantization;
94
95 // Represents a stream of dependent computations on a GPU device.
96 //
97 // The operations within a stream execute linearly and asynchronously until
98 // BlockHostUntilDone() is invoked, which synchronously joins host code with
99 // the execution of the stream.
100 //
101 // If any given operation fails when entraining work for the stream, ok() will
102 // indicate that an error has occurred. After initialization, once a stream is
103 // !ok(), it will never be ok().
104 //
105 // Thread-safe post-initialization.
106 class Stream {
107 public:
108 // Instantiate a stream tied to parent as a platform executor. Work
109 // entrained onto this stream will be launched/managed on that
110 // StreamExecutor's platform.
111 explicit Stream(StreamExecutor *parent);
112
113 // Test only. Use an externally-populated value (like a mock) for the
114 // platform-specific stream implementation.
115 Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
116
  // Deallocates any stream resources that the parent StreamExecutor has
  // bestowed upon this object.
120 ~Stream();
121
  // Returns whether any errors have occurred while entraining work for this
  // stream. Per the class comment above, once a stream is !ok() it never
  // returns to ok().
  bool ok() const { return !InErrorState(); }
125
126 // Retrieves execution status back into the stream from the underlying
127 // implementation without blocking the stream.
128 //
129 // Normally, Stream::BlockHostUntilDone is used to get execution status.
  // However, some devices use out-of-band mechanisms to ensure their streams
131 // have finished on-device work, without needing to block the streams. (These
132 // devices should also override AllowsSyncOnCompletion to return false.) For
133 // these devices, this method can be used after work is finished to retrieve
134 // execution status.
135 port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_);
136
137 // Initialize the stream. This must be performed before entraining any other
138 // operations.
139 Stream &Init() TF_LOCKS_EXCLUDED(mu_);
140
141 // Initializes timer t via the StreamExecutor.
142 Stream &InitTimer(Timer *t);
143
144 // Convenience wrapper around Init() and InitTimer().
145 Stream &InitWithTimer(Timer *t);
146
147 // Get or create a sub-stream from this stream. If there is any sub-stream in
148 // the pool that can be reused then just return this sub-stream. Otherwise
149 // create a new sub-stream.
150 //
151 // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
152 Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_);
153
154 // Return the sub-stream back to the host stream so that it can be reused
155 // later. Sub-streams that are !ok() will not be reused.
156 //
157 // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
158 void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_);
159
160 // Allocate temporary memories. The stream will deallocate them when blocked
161 // or destroyed.
162 template <typename T>
163 port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
164 AllocateTemporaryArray(uint64 element_count);
165
166 // Entrains onto the stream of operations: a kernel launch with the given
167 // (variadic) parameters for the invocation. These arguments can be things
168 // like DeviceMemory or primitive types such as int. What arguments you may
169 // pass to a given kernel are noted as the template parameters to the
170 // TypedKernel type that the machocc compiler generates.
171 //
172 // Template parameters:
173 // Params... The type list of formal parameters that the typed kernel
174 // expects, which is matched against Args...
175 // Args... The deduced type list for passed actual arguments
176 //
177 // Implementation: A compile-time compatibility check is performed that has
178 // some leniency versus an exact parameter pack match -- for example,
179 // `const DeviceMemory<T>` is considered "pack compatible" with a
180 // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
181 // perfect forwarding support without rvalue references. It also attempts to
182 // spit out helpful static_assert error traces with information as to the
183 // argument number and types that were mismatched.
184 template <typename... Params, typename... Args>
185 Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
186 const TypedKernel<Params...> &kernel, Args... args);
187
188 // Record a "start" event for the interval timer at this point in the
189 // stream's execution (relative to the previously and subsequently enqueued
190 // items in the stream's execution). Streams may be started/stopped multiple
191 // times.
192 Stream &ThenStartTimer(Timer *t);
193
194 // Record a "stop" event for the interval timer at this point in the
195 // stream's execution. See also Stream::ThenStartTimer.
196 Stream &ThenStopTimer(Timer *t);
197
198 // TODO(leary) If work is added to the stream that is being depended upon,
199 // then what? Have to describe what happens.
200 template <typename... Params>
ThenWaitFor(Stream * other,Params...more_streams)201 Stream &ThenWaitFor(Stream *other, Params... more_streams) {
202 return ThenWaitFor(more_streams...).ThenWaitFor(other);
203 }
204
205 // Create a dependency for this stream's next work on the other stream
206 // completing. Does not take ownership of other, and other must not be
207 // null.
208 //
209 // Checks that a stream does not wait for itself, and it is up to the
210 // user to guarantee that a stream does not come to wait on itself in a
211 // cyclic manner; in that case, behavior is undefined.
212 //
213 // N.B. Base recursion case for the variadic ThenWaitFor.
214 Stream &ThenWaitFor(Stream *other);
215
216 // Waits for all streams values in others.
217 // Checks that there is no shallow circular wait (i.e. that "this" is not in
218 // others)
219 template <typename P>
ThenWaitFor(P others)220 Stream &ThenWaitFor(P others) {
221 for (auto &stream : *others) {
222 CHECK_NE(stream.get(), this);
223 ThenWaitFor(stream.get());
224 }
225 return *this;
226 }
227
228 // Waits for an event object to be set.
229 // Note that ThenRecordEvent must have been called on the event before
230 // you call this function; otherwise the event will be considered complete
231 // and this wait will do nothing.
232 Stream &ThenWaitFor(Event *event);
233
234 // Inserts the specified event into the end of this stream. Once the stream
235 // has processed all events prior to the insertion point, the event will be
236 // marked as completed.
237 // The stream does not take ownership of event - meaning that event's lifetime
238 // must extend past the point at which it is marked complete!
239 Stream &ThenRecordEvent(Event *event);
240
241 ////////////////
242 // DNN support
243 //
244 // See DnnSupport::* for comments on the following methods.
245
246 Stream &ThenBatchNormalizationForward(
247 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
248 const DeviceMemory<float> &offset,
249 const DeviceMemory<float> &estimated_mean,
250 const DeviceMemory<float> &estimated_variance,
251 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
252 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
253 const double exponential_average_factor,
254 dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
255 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
256 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
257 bool is_training,
258 ScratchAllocator *reserve_space_allocator,
259 ScratchAllocator *workspace_allocator);
260
261 Stream &ThenBatchNormalizationBackward(
262 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
263 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
264 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
265 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
266 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
267 DeviceMemory<float> *offset_backprop,
268 DeviceMemory<uint8> *reserve_space_data,
269 ScratchAllocator *workspace_allocator);
270
271 Stream &ThenBatchNormalizationForward(
272 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
273 const DeviceMemory<float> &offset,
274 const DeviceMemory<float> &estimated_mean,
275 const DeviceMemory<float> &estimated_variance,
276 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
277 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
278 const double exponential_average_factor,
279 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
280 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
281 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
282 bool is_training,
283 ScratchAllocator *reserve_space_allocator,
284 ScratchAllocator *workspace_allocator);
285
286 Stream &ThenBatchNormalizationBackward(
287 const DeviceMemory<Eigen::half> &y_backprop,
288 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
289 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
290 const dnn::BatchDescriptor &x_desc,
291 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
292 DeviceMemory<Eigen::half> *x_backprop,
293 DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
294 DeviceMemory<uint8> *reserve_space_data,
295 ScratchAllocator *workspace_allocator);
296
297 Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
298 const DeviceMemory<float> &input_data,
299 const dnn::FilterDescriptor &filter_descriptor,
300 const DeviceMemory<float> &filter_data,
301 const dnn::ConvolutionDescriptor &convolution_descriptor,
302 const dnn::BatchDescriptor &output_descriptor,
303 DeviceMemory<float> *output);
304
305 Stream &ThenConvolveQuantized(
306 const dnn::BatchDescriptor &input_descriptor,
307 const DeviceMemory<float> &input_data,
308 const dnn::FilterDescriptor &filter_descriptor,
309 const DeviceMemory<int8> &filter_coefficients,
310 const DeviceMemory<float> &coefficient_scales,
311 const dnn::ConvolutionDescriptor &convolution_descriptor,
312 const dnn::BatchDescriptor &output_descriptor,
313 DeviceMemory<float> *output_data);
314
315 Stream &ThenConvolveQuantized(
316 const dnn::BatchDescriptor &input_descriptor,
317 const DeviceMemory<float> &input_data,
318 const dnn::FilterDescriptor &filter_descriptor,
319 const DeviceMemory<int16> &filter_coefficients,
320 const DeviceMemory<float> &coefficient_scales,
321 const dnn::ConvolutionDescriptor &convolution_descriptor,
322 const dnn::BatchDescriptor &output_descriptor,
323 DeviceMemory<float> *output_data);
324
325 template <typename InputType, typename OutputType>
ConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<InputType> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputType> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputType> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)326 port::Status ConvolveWithAlgorithm(
327 const dnn::BatchDescriptor &input_descriptor,
328 const DeviceMemory<InputType> &input_data,
329 const dnn::FilterDescriptor &filter_descriptor,
330 const DeviceMemory<InputType> &filter_data,
331 const dnn::ConvolutionDescriptor &convolution_descriptor,
332 const dnn::BatchDescriptor &output_descriptor,
333 DeviceMemory<OutputType> *output, ScratchAllocator *scratch_allocator,
334 const dnn::AlgorithmConfig &algorithm_config,
335 dnn::ProfileResult *output_profile_result) {
336 DeviceMemory<uint8> scratch_memory;
337 dnn::AlgorithmDesc algorithm_desc;
338 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
339 TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
340 dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
341 filter_descriptor, filter_data, output_descriptor, *output,
342 convolution_descriptor, algorithm_config, scratch_allocator,
343 &algorithm_desc, &scratch_memory));
344 return dnn->DoConvolve(
345 dnn::ConvolutionKind::FORWARD, dnn::ToDataType<InputType>::value,
346 dnn::ToDataType<OutputType>::value, this, input_descriptor,
347 input_data, filter_descriptor, filter_data, output_descriptor,
348 *output, convolution_descriptor, algorithm_desc, scratch_memory,
349 output_profile_result);
350 }
351 return port::UnimplementedError("DNN library is not found.");
352 }
353
354 port::Status FusedConvolveWithAlgorithm(
355 const dnn::BatchDescriptor &conv_input_descriptor,
356 const DeviceMemory<double> &conv_input_data, double conv_input_scale,
357 const dnn::FilterDescriptor &filter_descriptor,
358 const DeviceMemory<double> &filter_data,
359 const dnn::ConvolutionDescriptor &convolution_descriptor,
360 const DeviceMemory<double> &side_input_data, double side_input_scale,
361 const dnn::BatchDescriptor &bias_descriptor,
362 const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
363 const dnn::BatchDescriptor &output_descriptor,
364 DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
365 const dnn::AlgorithmConfig &algorithm_config,
366 dnn::ProfileResult *output_profile_result);
367
368 port::Status FusedConvolveWithAlgorithm(
369 const dnn::BatchDescriptor &conv_input_descriptor,
370 const DeviceMemory<float> &conv_input_data, float conv_input_scale,
371 const dnn::FilterDescriptor &filter_descriptor,
372 const DeviceMemory<float> &filter_data,
373 const dnn::ConvolutionDescriptor &convolution_descriptor,
374 const DeviceMemory<float> &side_input_data, float side_input_scale,
375 const dnn::BatchDescriptor &bias_descriptor,
376 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
377 const dnn::BatchDescriptor &output_descriptor,
378 DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
379 const dnn::AlgorithmConfig &algorithm_config,
380 dnn::ProfileResult *output_profile_result);
381
382 port::Status FusedConvolveWithAlgorithm(
383 const dnn::BatchDescriptor &conv_input_descriptor,
384 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
385 const dnn::FilterDescriptor &filter_descriptor,
386 const DeviceMemory<Eigen::half> &filter_data,
387 const dnn::ConvolutionDescriptor &convolution_descriptor,
388 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
389 const dnn::BatchDescriptor &bias_descriptor,
390 const DeviceMemory<Eigen::half> &biases,
391 dnn::ActivationMode activation_mode,
392 const dnn::BatchDescriptor &output_descriptor,
393 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
394 const dnn::AlgorithmConfig &algorithm_config,
395 dnn::ProfileResult *output_profile_result);
396
397 port::Status FusedConvolveWithAlgorithm(
398 const dnn::BatchDescriptor &conv_input_descriptor,
399 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
400 const dnn::FilterDescriptor &filter_descriptor,
401 const DeviceMemory<int8> &filter_data,
402 const dnn::ConvolutionDescriptor &convolution_descriptor,
403 const DeviceMemory<int8> &side_input_data, float side_input_scale,
404 const dnn::BatchDescriptor &bias_descriptor,
405 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
406 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
407 ScratchAllocator *scratch_allocator,
408 const dnn::AlgorithmConfig &algorithm_config,
409 dnn::ProfileResult *output_profile_result);
410
411 port::Status FusedConvolveWithAlgorithm(
412 const dnn::BatchDescriptor &conv_input_descriptor,
413 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
414 const dnn::FilterDescriptor &filter_descriptor,
415 const DeviceMemory<int8> &filter_data,
416 const dnn::ConvolutionDescriptor &convolution_descriptor,
417 const DeviceMemory<float> &side_input_data, float side_input_scale,
418 const dnn::BatchDescriptor &bias_descriptor,
419 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
420 const dnn::BatchDescriptor &output_descriptor,
421 DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
422 const dnn::AlgorithmConfig &algorithm_config,
423 dnn::ProfileResult *output_profile_result);
424
425 Stream &ThenSeparableConvolve(
426 const dnn::BatchDescriptor &input_descriptor,
427 const DeviceMemory<float> &input_data,
428 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
429 const DeviceMemory<float> &first_weights,
430 const DeviceMemory<float> &second_weights,
431 const dnn::ConvolutionDescriptor &convolution_descriptor,
432 const dnn::BatchDescriptor &output_descriptor,
433 DeviceMemory<float> *output);
434
435 template <typename ElementType>
ConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<ElementType> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<ElementType> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)436 port::Status ConvolveBackwardDataWithAlgorithm(
437 const dnn::FilterDescriptor &filter_descriptor,
438 const DeviceMemory<ElementType> &filter_data,
439 const dnn::BatchDescriptor &output_descriptor,
440 DeviceMemory<ElementType> backward_output_data,
441 const dnn::ConvolutionDescriptor &convolution_descriptor,
442 const dnn::BatchDescriptor &input_descriptor,
443 DeviceMemory<ElementType> *backward_input_data,
444 ScratchAllocator *scratch_allocator,
445 const dnn::AlgorithmConfig &algorithm_config,
446 dnn::ProfileResult *output_profile_result) {
447 DeviceMemory<uint8> scratch_memory;
448 dnn::AlgorithmDesc algorithm_desc;
449 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
450 TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
451 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
452 *backward_input_data, filter_descriptor, filter_data,
453 output_descriptor, backward_output_data, convolution_descriptor,
454 algorithm_config, scratch_allocator, &algorithm_desc,
455 &scratch_memory));
456 return dnn->DoConvolve(
457 dnn::ConvolutionKind::BACKWARD_DATA,
458 dnn::ToDataType<ElementType>::value,
459 dnn::ToDataType<ElementType>::value, this, input_descriptor,
460 *backward_input_data, filter_descriptor, filter_data,
461 output_descriptor, backward_output_data, convolution_descriptor,
462 algorithm_desc, scratch_memory, output_profile_result);
463 }
464 return port::UnimplementedError("DNN library is not found.");
465 }
466
467 template <typename ElementType>
ConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<ElementType> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<ElementType> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)468 port::Status ConvolveBackwardFilterWithAlgorithm(
469 const dnn::BatchDescriptor &input_descriptor,
470 const DeviceMemory<ElementType> &input_data,
471 const dnn::BatchDescriptor &output_descriptor,
472 DeviceMemory<ElementType> backward_output_data,
473 const dnn::ConvolutionDescriptor &convolution_descriptor,
474 const dnn::FilterDescriptor &filter_descriptor,
475 DeviceMemory<ElementType> *backward_filter_data,
476 ScratchAllocator *scratch_allocator,
477 const dnn::AlgorithmConfig &algorithm_config,
478 dnn::ProfileResult *output_profile_result) {
479 DeviceMemory<uint8> scratch_memory;
480 dnn::AlgorithmDesc algorithm_desc;
481 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
482 TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
483 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
484 input_data, filter_descriptor, *backward_filter_data,
485 output_descriptor, backward_output_data, convolution_descriptor,
486 algorithm_config, scratch_allocator, &algorithm_desc,
487 &scratch_memory));
488 return dnn->DoConvolve(
489 dnn::ConvolutionKind::BACKWARD_FILTER,
490 dnn::ToDataType<ElementType>::value,
491 dnn::ToDataType<ElementType>::value, this, input_descriptor,
492 input_data, filter_descriptor, *backward_filter_data,
493 output_descriptor, backward_output_data, convolution_descriptor,
494 algorithm_desc, scratch_memory, output_profile_result);
495 }
496 return port::UnimplementedError("DNN library is not found.");
497 }
498
499 Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
500 const DeviceMemory<double> &input_data,
501 const dnn::BatchDescriptor &bias_descriptor,
502 DeviceMemory<double> *backward_bias_data);
503
504 Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
505 const DeviceMemory<float> &input_data,
506 const dnn::BatchDescriptor &bias_descriptor,
507 DeviceMemory<float> *backward_bias_data);
508
509 Stream &ThenConvolveBackwardBias(
510 const dnn::BatchDescriptor &input_descriptor,
511 const DeviceMemory<Eigen::half> &input_data,
512 const dnn::BatchDescriptor &bias_descriptor,
513 DeviceMemory<Eigen::half> *backward_bias_data);
514
515 Stream &ThenMatMul(const DeviceMemory<float> &input_data,
516 const DeviceMemory<float> &weights,
517 const dnn::BatchDescriptor &input_dimensions,
518 const dnn::BatchDescriptor &output_dimensions,
519 DeviceMemory<float> *output_data);
520
521 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
522 const DeviceMemory<int8> &weights,
523 const DeviceMemory<float> &weight_scales,
524 const dnn::BatchDescriptor &input_dimensions,
525 const dnn::BatchDescriptor &output_dimensions,
526 DeviceMemory<float> *output_data);
527
528 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
529 const DeviceMemory<int16> &weights,
530 const DeviceMemory<float> &weight_scales,
531 const dnn::BatchDescriptor &input_dimensions,
532 const dnn::BatchDescriptor &output_dimensions,
533 DeviceMemory<float> *output_data);
534
535 Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
536 const DeviceMemory<float> &biases,
537 const dnn::BatchDescriptor &dimensions,
538 DeviceMemory<float> *output_data);
539
540 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
541 const dnn::BatchDescriptor &input_dimensions,
542 const DeviceMemory<double> &input_data,
543 const dnn::BatchDescriptor &output_dimensions,
544 DeviceMemory<double> *output_data,
545 ScratchAllocator *workspace_allocator = nullptr);
546
547 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
548 const dnn::BatchDescriptor &input_dimensions,
549 const DeviceMemory<float> &input_data,
550 const dnn::BatchDescriptor &output_dimensions,
551 DeviceMemory<float> *output_data,
552 ScratchAllocator *workspace_allocator = nullptr);
553
554 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
555 const dnn::BatchDescriptor &input_dimensions,
556 const DeviceMemory<Eigen::half> &input_data,
557 const dnn::BatchDescriptor &output_dimensions,
558 DeviceMemory<Eigen::half> *output_data,
559 ScratchAllocator *workspace_allocator = nullptr);
560
561 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
562 const dnn::BatchDescriptor &input_dimensions,
563 const DeviceMemory<int8> &input_data,
564 const dnn::BatchDescriptor &output_dimensions,
565 DeviceMemory<int8> *output_data,
566 ScratchAllocator *workspace_allocator = nullptr);
567
568 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
569 const dnn::BatchDescriptor &input_dimensions,
570 const DeviceMemory<double> &input_data,
571 const dnn::BatchDescriptor &output_dimensions,
572 const DeviceMemory<double> &output_data,
573 const DeviceMemory<double> &input_diff_data,
574 DeviceMemory<double> *output_diff_data,
575 ScratchAllocator *workspace_allocator = nullptr);
576
577 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
578 const dnn::BatchDescriptor &input_dimensions,
579 const DeviceMemory<float> &input_data,
580 const dnn::BatchDescriptor &output_dimensions,
581 const DeviceMemory<float> &output_data,
582 const DeviceMemory<float> &input_diff_data,
583 DeviceMemory<float> *output_diff_data,
584 ScratchAllocator *workspace_allocator = nullptr);
585
586 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
587 const dnn::BatchDescriptor &input_dimensions,
588 const DeviceMemory<Eigen::half> &input_data,
589 const dnn::BatchDescriptor &output_dimensions,
590 const DeviceMemory<Eigen::half> &output_data,
591 const DeviceMemory<Eigen::half> &input_diff_data,
592 DeviceMemory<Eigen::half> *output_diff_data,
593 ScratchAllocator *workspace_allocator = nullptr);
594
595 Stream &ThenNormalizeWithDimensions(
596 const dnn::NormalizeDescriptor &normalize_descriptor,
597 const dnn::BatchDescriptor &dimensions,
598 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
599
600 Stream &ThenNormalizeBackwardWithDimensions(
601 const dnn::NormalizeDescriptor &normalize_descriptor,
602 const dnn::BatchDescriptor &dimensions,
603 const DeviceMemory<float> &raw_data,
604 const DeviceMemory<float> &normalized_data,
605 const DeviceMemory<float> &normalized_variable_gradient,
606 DeviceMemory<float> *raw_variable_gradient,
607 ScratchAllocator *workspace_allocator = nullptr);
608
609 Stream &ThenActivate(dnn::ActivationMode activation_mode,
610 const dnn::BatchDescriptor &dimensions,
611 const DeviceMemory<float> &input_data,
612 DeviceMemory<float> *output_data);
613
614 // Same as ThenActivate, but also takes an options argument that can be used
615 // for platform-specific option flags.
616 Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
617 const dnn::BatchDescriptor &dimensions,
618 const DeviceMemory<float> &input_data,
619 DeviceMemory<float> *output_data,
620 uint64 options);
621
622 Stream &ThenDepthConcatenate(
623 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
624 port::ArraySlice<const DeviceMemory<float> *> input_data,
625 DeviceMemory<float> *output_data);
626
627 Stream &ThenSpaceConcatenate(
628 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
629 port::ArraySlice<const DeviceMemory<float> *> input_data,
630 DeviceMemory<float> *output_data,
631 dnn::SpaceConcatenateMode concat_direction);
632
633 // Change the layout of the data by shrinking one dimension (or set of
634 // dimensions) and growing another dimension (or set of dimensions), while
635 // keeping the total number of data elements constant, and maintaining the
636 // current data ordering.
637 Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
638 const DeviceMemory<float> &input_data,
639 const dnn::BatchDescriptor &output_dimensions,
640 DeviceMemory<float> *output_data);
641
642 // Depth to space takes an X by Y image with depth D*M² and changes it to an
643 // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
644 // the input image is changed to an MxM contiguous area in the output image,
645 // with the values being laid out in raster order specified by
646 // DepthToSpaceLayout, and will have a new depth of D.
647 // See the DoDepthToSpace comment for more information.
648 Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
649 const DeviceMemory<float> &input_data,
650 const dnn::DepthToSpaceLayout &depth_to_space_layout,
651 const int sqrt_depth_reduction,
652 DeviceMemory<float> *output_data);
653
654 // Space to depth is the inverse of depth to space. Space to depth takes each
655 // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
656 // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
657 // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
658 // data elements is not changed.
659 Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
660 const DeviceMemory<float> &input_data,
661 const dnn::DepthToSpaceLayout &space_to_depth_layout,
662 const int sqrt_depth_increase,
663 DeviceMemory<float> *output_data);
664
  // Enqueues the given elementwise operation across the input tensors
  // (described by input_dimensions), writing the float result to output_data.
  Stream &ThenElementwiseOperate(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);

  // Like ThenElementwiseOperate, but additionally takes per-input integer
  // multiplicands and an integer output divisor — presumably to express the
  // fixed-point scaling used by quantized arithmetic; see the corresponding
  // DnnSupport method for the authoritative contract.
  Stream &ThenElementwiseOperateScaledQuantized(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<int> input_multiplicands, int output_divisor,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);
679
  // Pads the input tensor in the X and Y dimensions by the given number of
  // elements on each edge, writing the enlarged result to output_data.
  Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
                    const DeviceMemory<float> &input_data, int64 left_pad,
                    int64 right_pad, int64 top_pad, int64 bottom_pad,
                    DeviceMemory<float> *output_data);

  // Slices the input tensor in the X and Y dimensions, trimming the given
  // number of elements from each edge, and writes the result to output_data.
  Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
                      const DeviceMemory<float> &input_data, int64 left_trim,
                      int64 right_trim, int64 top_trim, int64 bottom_trim,
                      DeviceMemory<float> *output_data);
689
690 // Grows the input tensor by replicating the X and Y dimensions. The batch and
691 // depth/feature_map dimensions are unchanged. Currently, the input tensor is
692 // limited to X=1 and Y=1.
693 Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
694 const DeviceMemory<float> &input_data,
695 int64 replicate_x, int64 replicate_y,
696 DeviceMemory<float> *output_data);
697
698 // See DnnSupport::DoMemcpyD2HQuantized.
699 Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
700 dnn::QuantizedActivationMode mode,
701 void *host_dst, uint64 size);
702
703 // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
704 // and uses the Quantization trait to call the generic version of
705 // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
706 template <typename ElementType>
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,port::MutableArraySlice<ElementType> host_dst)707 Stream &ThenMemcpyD2HQuantized(
708 const DeviceMemory<float> &gpu_unquantized_src,
709 port::MutableArraySlice<ElementType> host_dst) {
710 return ThenMemcpyD2HQuantized(
711 gpu_unquantized_src, Quantization<ElementType>::kModeId,
712 host_dst.data(), host_dst.size() * sizeof(ElementType));
713 }
714
715 // See DnnSupport::DoMemcpyH2DQuantized.
716 Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
717 dnn::QuantizedActivationMode mode,
718 DeviceMemory<float> *gpu_unquantized_dst);
719
720 // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
721 // and uses the Quantization trait to call the generic version of
722 // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
723 template <typename ElementType>
ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,DeviceMemory<float> * gpu_unquantized_dst)724 Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
725 DeviceMemory<float> *gpu_unquantized_dst) {
726 return ThenMemcpyH2DQuantized(
727 host_src.data(), host_src.size() * sizeof(ElementType),
728 Quantization<ElementType>::kModeId, gpu_unquantized_dst);
729 }
730
731 // See DnnSupport::DoCopyHostBuffer2Device.
732 Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
733 DeviceMemory<float> *gpu_unquantized_dst);
734
735 // See DnnSupport::DoCopyDevice2HostBuffer.
736 Stream &ThenCopyDevice2HostBuffer(
737 const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
738
739 /////////////////
740 // BLAS support
741
742 // See BlasSupport::DoBlasAsum.
743 Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
744 int incx, DeviceMemory<float> *result);
745 Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
746 int incx, DeviceMemory<double> *result);
747 Stream &ThenBlasAsum(uint64 elem_count,
748 const DeviceMemory<std::complex<float>> &x, int incx,
749 DeviceMemory<float> *result);
750 Stream &ThenBlasAsum(uint64 elem_count,
751 const DeviceMemory<std::complex<double>> &x, int incx,
752 DeviceMemory<double> *result);
753
  // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
  // present in DeviceMemory, it must be an execution-time constant (i.e. a
  // value that the stream does not change or populate during the course of
  // execution). The value is effectively captured at stream-enqueue time.
759 Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
760 const DeviceMemory<float> &x, int incx,
761 DeviceMemory<float> *y, int incy);
762 Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
763 const DeviceMemory<double> &x, int incx,
764 DeviceMemory<double> *y, int incy);
765 Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
766 const DeviceMemory<std::complex<float>> &x, int incx,
767 DeviceMemory<std::complex<float>> *y, int incy);
768 Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
769 const DeviceMemory<std::complex<double>> &x, int incx,
770 DeviceMemory<std::complex<double>> *y, int incy);
771
772 // See BlasSupport::DoBlasCopy.
773 Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
774 int incx, DeviceMemory<float> *y, int incy);
775 Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
776 int incx, DeviceMemory<double> *y, int incy);
777 Stream &ThenBlasCopy(uint64 elem_count,
778 const DeviceMemory<std::complex<float>> &x, int incx,
779 DeviceMemory<std::complex<float>> *y, int incy);
780 Stream &ThenBlasCopy(uint64 elem_count,
781 const DeviceMemory<std::complex<double>> &x, int incx,
782 DeviceMemory<std::complex<double>> *y, int incy);
783
784 // See BlasSupport::DoBlasDot.
785 Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
786 const DeviceMemory<float> &y, int incy,
787 DeviceMemory<float> *result);
788 Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
789 int incx, const DeviceMemory<double> &y, int incy,
790 DeviceMemory<double> *result);
791
792 // See BlasSupport::DoBlasDotc.
793 Stream &ThenBlasDotc(uint64 elem_count,
794 const DeviceMemory<std::complex<float>> &x, int incx,
795 const DeviceMemory<std::complex<float>> &y, int incy,
796 DeviceMemory<std::complex<float>> *result);
797 Stream &ThenBlasDotc(uint64 elem_count,
798 const DeviceMemory<std::complex<double>> &x, int incx,
799 const DeviceMemory<std::complex<double>> &y, int incy,
800 DeviceMemory<std::complex<double>> *result);
801
802 // See BlasSupport::DoBlasDotu.
803 Stream &ThenBlasDotu(uint64 elem_count,
804 const DeviceMemory<std::complex<float>> &x, int incx,
805 const DeviceMemory<std::complex<float>> &y, int incy,
806 DeviceMemory<std::complex<float>> *result);
807 Stream &ThenBlasDotu(uint64 elem_count,
808 const DeviceMemory<std::complex<double>> &x, int incx,
809 const DeviceMemory<std::complex<double>> &y, int incy,
810 DeviceMemory<std::complex<double>> *result);
811
812 // See BlasSupport::DoBlasNrm2.
813 Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
814 int incx, DeviceMemory<float> *result);
815 Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
816 int incx, DeviceMemory<double> *result);
817 Stream &ThenBlasNrm2(uint64 elem_count,
818 const DeviceMemory<std::complex<float>> &x, int incx,
819 DeviceMemory<float> *result);
820 Stream &ThenBlasNrm2(uint64 elem_count,
821 const DeviceMemory<std::complex<double>> &x, int incx,
822 DeviceMemory<double> *result);
823
824 // See BlasSupport::DoBlasRot.
825 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
826 DeviceMemory<float> *y, int incy, float c, float s);
827 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
828 DeviceMemory<double> *y, int incy, double c, double s);
829 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
830 int incx, DeviceMemory<std::complex<float>> *y, int incy,
831 float c, float s);
832 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
833 int incx, DeviceMemory<std::complex<double>> *y, int incy,
834 double c, double s);
835
836 // See BlasSupport::DoBlasRotg.
837 Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
838 DeviceMemory<float> *c, DeviceMemory<float> *s);
839 Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
840 DeviceMemory<double> *c, DeviceMemory<double> *s);
841 Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
842 DeviceMemory<std::complex<float>> *b,
843 DeviceMemory<float> *c,
844 DeviceMemory<std::complex<float>> *s);
845 Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
846 DeviceMemory<std::complex<double>> *b,
847 DeviceMemory<double> *c,
848 DeviceMemory<std::complex<double>> *s);
849
850 // See BlasSupport::DoBlasRotm.
  Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
                       DeviceMemory<float> *y, int incy,
                       const DeviceMemory<float> &param);
  Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
                       DeviceMemory<double> *y, int incy,
                       const DeviceMemory<double> &param);
857
858 // See BlasSupport::DoBlasRotmg.
859 Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
860 DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
861 DeviceMemory<float> *param);
862 Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
863 DeviceMemory<double> *x1,
864 const DeviceMemory<double> &y1,
865 DeviceMemory<double> *param);
866
867 // See BlasSupport::DoBlasScal.
868 Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
869 int incx);
870 Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
871 int incx);
872 Stream &ThenBlasScal(uint64 elem_count, float alpha,
873 DeviceMemory<std::complex<float>> *x, int incx);
874 Stream &ThenBlasScal(uint64 elem_count, double alpha,
875 DeviceMemory<std::complex<double>> *x, int incx);
876 Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
877 DeviceMemory<std::complex<float>> *x, int incx);
878 Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
879 DeviceMemory<std::complex<double>> *x, int incx);
880
881 // See BlasSupport::DoBlasSwap.
882 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
883 DeviceMemory<float> *y, int incy);
884 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
885 DeviceMemory<double> *y, int incy);
886 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
887 int incx, DeviceMemory<std::complex<float>> *y,
888 int incy);
889 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
890 int incx, DeviceMemory<std::complex<double>> *y,
891 int incy);
892
893 // See BlasSupport::DoBlasIamax.
894 Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
895 int incx, DeviceMemory<int> *result);
896 Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
897 int incx, DeviceMemory<int> *result);
898 Stream &ThenBlasIamax(uint64 elem_count,
899 const DeviceMemory<std::complex<float>> &x, int incx,
900 DeviceMemory<int> *result);
901 Stream &ThenBlasIamax(uint64 elem_count,
902 const DeviceMemory<std::complex<double>> &x, int incx,
903 DeviceMemory<int> *result);
904
905 // See BlasSupport::DoBlasIamin.
906 Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
907 int incx, DeviceMemory<int> *result);
908 Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
909 int incx, DeviceMemory<int> *result);
910 Stream &ThenBlasIamin(uint64 elem_count,
911 const DeviceMemory<std::complex<float>> &x, int incx,
912 DeviceMemory<int> *result);
913 Stream &ThenBlasIamin(uint64 elem_count,
914 const DeviceMemory<std::complex<double>> &x, int incx,
915 DeviceMemory<int> *result);
916
917 // See BlasSupport::DoBlasGbmv.
918 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
919 uint64 ku, float alpha, const DeviceMemory<float> &a,
920 int lda, const DeviceMemory<float> &x, int incx,
921 float beta, DeviceMemory<float> *y, int incy);
922 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
923 uint64 ku, double alpha, const DeviceMemory<double> &a,
924 int lda, const DeviceMemory<double> &x, int incx,
925 double beta, DeviceMemory<double> *y, int incy);
926 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
927 uint64 ku, std::complex<float> alpha,
928 const DeviceMemory<std::complex<float>> &a, int lda,
929 const DeviceMemory<std::complex<float>> &x, int incx,
930 std::complex<float> beta,
931 DeviceMemory<std::complex<float>> *y, int incy);
932 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
933 uint64 ku, std::complex<double> alpha,
934 const DeviceMemory<std::complex<double>> &a, int lda,
935 const DeviceMemory<std::complex<double>> &x, int incx,
936 std::complex<double> beta,
937 DeviceMemory<std::complex<double>> *y, int incy);
938
939 // See BlasSupport::DoBlasGemv.
940 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
941 const DeviceMemory<float> &a, int lda,
942 const DeviceMemory<float> &x, int incx, float beta,
943 DeviceMemory<float> *y, int incy);
944 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
945 const DeviceMemory<double> &a, int lda,
946 const DeviceMemory<double> &x, int incx, double beta,
947 DeviceMemory<double> *y, int incy);
948 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
949 std::complex<float> alpha,
950 const DeviceMemory<std::complex<float>> &a, int lda,
951 const DeviceMemory<std::complex<float>> &x, int incx,
952 std::complex<float> beta,
953 DeviceMemory<std::complex<float>> *y, int incy);
954 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
955 std::complex<double> alpha,
956 const DeviceMemory<std::complex<double>> &a, int lda,
957 const DeviceMemory<std::complex<double>> &x, int incx,
958 std::complex<double> beta,
959 DeviceMemory<std::complex<double>> *y, int incy);
960
  // Variants of ThenBlasGemv that also report timing/algorithm information
  // for the enqueued call through *output_profile_result.
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    float alpha, const DeviceMemory<float> &a,
                                    int lda, const DeviceMemory<float> &x,
                                    int incx, float beta,
                                    DeviceMemory<float> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    double alpha, const DeviceMemory<double> &a,
                                    int lda, const DeviceMemory<double> &x,
                                    int incx, double beta,
                                    DeviceMemory<double> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &x, int incx,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &x, int incx,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
      int incy, blas::ProfileResult *output_profile_result);
985
986 // See BlasSupport::DoBlasGer.
987 Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
988 const DeviceMemory<float> &x, int incx,
989 const DeviceMemory<float> &y, int incy,
990 DeviceMemory<float> *a, int lda);
991 Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
992 const DeviceMemory<double> &x, int incx,
993 const DeviceMemory<double> &y, int incy,
994 DeviceMemory<double> *a, int lda);
995
996 // See BlasSupport::DoBlasGerc.
997 Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
998 const DeviceMemory<std::complex<float>> &x, int incx,
999 const DeviceMemory<std::complex<float>> &y, int incy,
1000 DeviceMemory<std::complex<float>> *a, int lda);
1001 Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
1002 const DeviceMemory<std::complex<double>> &x, int incx,
1003 const DeviceMemory<std::complex<double>> &y, int incy,
1004 DeviceMemory<std::complex<double>> *a, int lda);
1005
1006 // See BlasSupport::DoBlasGeru.
1007 Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
1008 const DeviceMemory<std::complex<float>> &x, int incx,
1009 const DeviceMemory<std::complex<float>> &y, int incy,
1010 DeviceMemory<std::complex<float>> *a, int lda);
1011 Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
1012 const DeviceMemory<std::complex<double>> &x, int incx,
1013 const DeviceMemory<std::complex<double>> &y, int incy,
1014 DeviceMemory<std::complex<double>> *a, int lda);
1015
1016 // See BlasSupport::DoBlasHbmv.
1017 Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
1018 std::complex<float> alpha,
1019 const DeviceMemory<std::complex<float>> &a, int lda,
1020 const DeviceMemory<std::complex<float>> &x, int incx,
1021 std::complex<float> beta,
1022 DeviceMemory<std::complex<float>> *y, int incy);
1023 Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
1024 std::complex<double> alpha,
1025 const DeviceMemory<std::complex<double>> &a, int lda,
1026 const DeviceMemory<std::complex<double>> &x, int incx,
1027 std::complex<double> beta,
1028 DeviceMemory<std::complex<double>> *y, int incy);
1029
1030 // See BlasSupport::DoBlasHemv.
1031 Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
1032 std::complex<float> alpha,
1033 const DeviceMemory<std::complex<float>> &a, int lda,
1034 const DeviceMemory<std::complex<float>> &x, int incx,
1035 std::complex<float> beta,
1036 DeviceMemory<std::complex<float>> *y, int incy);
1037 Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
1038 std::complex<double> alpha,
1039 const DeviceMemory<std::complex<double>> &a, int lda,
1040 const DeviceMemory<std::complex<double>> &x, int incx,
1041 std::complex<double> beta,
1042 DeviceMemory<std::complex<double>> *y, int incy);
1043
1044 // See BlasSupport::DoBlasHer.
1045 Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
1046 const DeviceMemory<std::complex<float>> &x, int incx,
1047 DeviceMemory<std::complex<float>> *a, int lda);
1048 Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
1049 const DeviceMemory<std::complex<double>> &x, int incx,
1050 DeviceMemory<std::complex<double>> *a, int lda);
1051
1052 // See BlasSupport::DoBlasHer2.
1053 Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
1054 std::complex<float> alpha,
1055 const DeviceMemory<std::complex<float>> &x, int incx,
1056 const DeviceMemory<std::complex<float>> &y, int incy,
1057 DeviceMemory<std::complex<float>> *a, int lda);
1058 Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
1059 std::complex<double> alpha,
1060 const DeviceMemory<std::complex<double>> &x, int incx,
1061 const DeviceMemory<std::complex<double>> &y, int incy,
1062 DeviceMemory<std::complex<double>> *a, int lda);
1063
1064 // See BlasSupport::DoBlasHpmv.
1065 Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
1066 std::complex<float> alpha,
1067 const DeviceMemory<std::complex<float>> &ap,
1068 const DeviceMemory<std::complex<float>> &x, int incx,
1069 std::complex<float> beta,
1070 DeviceMemory<std::complex<float>> *y, int incy);
1071 Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
1072 std::complex<double> alpha,
1073 const DeviceMemory<std::complex<double>> &ap,
1074 const DeviceMemory<std::complex<double>> &x, int incx,
1075 std::complex<double> beta,
1076 DeviceMemory<std::complex<double>> *y, int incy);
1077
1078 // See BlasSupport::DoBlasHpr.
1079 Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
1080 const DeviceMemory<std::complex<float>> &x, int incx,
1081 DeviceMemory<std::complex<float>> *ap);
1082 Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
1083 const DeviceMemory<std::complex<double>> &x, int incx,
1084 DeviceMemory<std::complex<double>> *ap);
1085
1086 // See BlasSupport::DoBlasHpr2.
1087 Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
1088 std::complex<float> alpha,
1089 const DeviceMemory<std::complex<float>> &x, int incx,
1090 const DeviceMemory<std::complex<float>> &y, int incy,
1091 DeviceMemory<std::complex<float>> *ap);
1092 Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
1093 std::complex<double> alpha,
1094 const DeviceMemory<std::complex<double>> &x, int incx,
1095 const DeviceMemory<std::complex<double>> &y, int incy,
1096 DeviceMemory<std::complex<double>> *ap);
1097
1098 // See BlasSupport::DoBlasSbmv.
1099 Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
1100 const DeviceMemory<float> &a, int lda,
1101 const DeviceMemory<float> &x, int incx, float beta,
1102 DeviceMemory<float> *y, int incy);
1103 Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
1104 const DeviceMemory<double> &a, int lda,
1105 const DeviceMemory<double> &x, int incx, double beta,
1106 DeviceMemory<double> *y, int incy);
1107
1108 // See BlasSupport::DoBlasSpmv.
1109 Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
1110 const DeviceMemory<float> &ap,
1111 const DeviceMemory<float> &x, int incx, float beta,
1112 DeviceMemory<float> *y, int incy);
1113 Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
1114 const DeviceMemory<double> &ap,
1115 const DeviceMemory<double> &x, int incx, double beta,
1116 DeviceMemory<double> *y, int incy);
1117
1118 // See BlasSupport::DoBlasSpr.
1119 Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
1120 const DeviceMemory<float> &x, int incx,
1121 DeviceMemory<float> *ap);
1122 Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
1123 const DeviceMemory<double> &x, int incx,
1124 DeviceMemory<double> *ap);
1125
1126 // See BlasSupport::DoBlasSpr2.
1127 Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
1128 const DeviceMemory<float> &x, int incx,
1129 const DeviceMemory<float> &y, int incy,
1130 DeviceMemory<float> *ap);
1131 Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
1132 const DeviceMemory<double> &x, int incx,
1133 const DeviceMemory<double> &y, int incy,
1134 DeviceMemory<double> *ap);
1135
1136 // See BlasSupport::DoBlasSymv.
1137 Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
1138 const DeviceMemory<float> &a, int lda,
1139 const DeviceMemory<float> &x, int incx, float beta,
1140 DeviceMemory<float> *y, int incy);
1141 Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
1142 const DeviceMemory<double> &a, int lda,
1143 const DeviceMemory<double> &x, int incx, double beta,
1144 DeviceMemory<double> *y, int incy);
1145
1146 // See BlasSupport::DoBlasSyr.
1147 Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
1148 const DeviceMemory<float> &x, int incx,
1149 DeviceMemory<float> *a, int lda);
1150 Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
1151 const DeviceMemory<double> &x, int incx,
1152 DeviceMemory<double> *a, int lda);
1153
1154 // See BlasSupport::DoBlasSyr2.
1155 Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
1156 const DeviceMemory<float> &x, int incx,
1157 const DeviceMemory<float> &y, int incy,
1158 DeviceMemory<float> *a, int lda);
1159 Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
1160 const DeviceMemory<double> &x, int incx,
1161 const DeviceMemory<double> &y, int incy,
1162 DeviceMemory<double> *a, int lda);
1163
1164 // See BlasSupport::DoBlasTbmv.
1165 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1166 blas::Diagonal diag, uint64 n, uint64 k,
1167 const DeviceMemory<float> &a, int lda,
1168 DeviceMemory<float> *x, int incx);
1169 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1170 blas::Diagonal diag, uint64 n, uint64 k,
1171 const DeviceMemory<double> &a, int lda,
1172 DeviceMemory<double> *x, int incx);
1173 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1174 blas::Diagonal diag, uint64 n, uint64 k,
1175 const DeviceMemory<std::complex<float>> &a, int lda,
1176 DeviceMemory<std::complex<float>> *x, int incx);
1177 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1178 blas::Diagonal diag, uint64 n, uint64 k,
1179 const DeviceMemory<std::complex<double>> &a, int lda,
1180 DeviceMemory<std::complex<double>> *x, int incx);
1181
1182 // See BlasSupport::DoBlasTbsv.
1183 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1184 blas::Diagonal diag, uint64 n, uint64 k,
1185 const DeviceMemory<float> &a, int lda,
1186 DeviceMemory<float> *x, int incx);
1187 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1188 blas::Diagonal diag, uint64 n, uint64 k,
1189 const DeviceMemory<double> &a, int lda,
1190 DeviceMemory<double> *x, int incx);
1191 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1192 blas::Diagonal diag, uint64 n, uint64 k,
1193 const DeviceMemory<std::complex<float>> &a, int lda,
1194 DeviceMemory<std::complex<float>> *x, int incx);
1195 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1196 blas::Diagonal diag, uint64 n, uint64 k,
1197 const DeviceMemory<std::complex<double>> &a, int lda,
1198 DeviceMemory<std::complex<double>> *x, int incx);
1199
1200 // See BlasSupport::DoBlasTpmv.
1201 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1202 blas::Diagonal diag, uint64 n,
1203 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1204 int incx);
1205 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1206 blas::Diagonal diag, uint64 n,
1207 const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1208 int incx);
1209 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1210 blas::Diagonal diag, uint64 n,
1211 const DeviceMemory<std::complex<float>> &ap,
1212 DeviceMemory<std::complex<float>> *x, int incx);
1213 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1214 blas::Diagonal diag, uint64 n,
1215 const DeviceMemory<std::complex<double>> &ap,
1216 DeviceMemory<std::complex<double>> *x, int incx);
1217
1218 // See BlasSupport::DoBlasTpsv.
1219 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1220 blas::Diagonal diag, uint64 n,
1221 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1222 int incx);
1223 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1224 blas::Diagonal diag, uint64 n,
1225 const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1226 int incx);
1227 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1228 blas::Diagonal diag, uint64 n,
1229 const DeviceMemory<std::complex<float>> &ap,
1230 DeviceMemory<std::complex<float>> *x, int incx);
1231 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1232 blas::Diagonal diag, uint64 n,
1233 const DeviceMemory<std::complex<double>> &ap,
1234 DeviceMemory<std::complex<double>> *x, int incx);
1235
1236 // See BlasSupport::DoBlasTrmv.
1237 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1238 blas::Diagonal diag, uint64 n,
1239 const DeviceMemory<float> &a, int lda,
1240 DeviceMemory<float> *x, int incx);
1241 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1242 blas::Diagonal diag, uint64 n,
1243 const DeviceMemory<double> &a, int lda,
1244 DeviceMemory<double> *x, int incx);
1245 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1246 blas::Diagonal diag, uint64 n,
1247 const DeviceMemory<std::complex<float>> &a, int lda,
1248 DeviceMemory<std::complex<float>> *x, int incx);
1249 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1250 blas::Diagonal diag, uint64 n,
1251 const DeviceMemory<std::complex<double>> &a, int lda,
1252 DeviceMemory<std::complex<double>> *x, int incx);
1253
1254 // See BlasSupport::DoBlasTrsv.
1255 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1256 blas::Diagonal diag, uint64 n,
1257 const DeviceMemory<float> &a, int lda,
1258 DeviceMemory<float> *x, int incx);
1259 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1260 blas::Diagonal diag, uint64 n,
1261 const DeviceMemory<double> &a, int lda,
1262 DeviceMemory<double> *x, int incx);
1263 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1264 blas::Diagonal diag, uint64 n,
1265 const DeviceMemory<std::complex<float>> &a, int lda,
1266 DeviceMemory<std::complex<float>> *x, int incx);
1267 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1268 blas::Diagonal diag, uint64 n,
1269 const DeviceMemory<std::complex<double>> &a, int lda,
1270 DeviceMemory<std::complex<double>> *x, int incx);
1271
1272 // See BlasSupport::DoBlasGemm.
1273 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1274 uint64 m, uint64 n, uint64 k, float alpha,
1275 const DeviceMemory<Eigen::half> &a, int lda,
1276 const DeviceMemory<Eigen::half> &b, int ldb,
1277 float beta, DeviceMemory<Eigen::half> *c,
1278 int ldc);
1279 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1280 uint64 m, uint64 n, uint64 k, float alpha,
1281 const DeviceMemory<float> &a, int lda,
1282 const DeviceMemory<float> &b, int ldb,
1283 float beta, DeviceMemory<float> *c, int ldc);
1284 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1285 uint64 m, uint64 n, uint64 k, double alpha,
1286 const DeviceMemory<double> &a, int lda,
1287 const DeviceMemory<double> &b, int ldb,
1288 double beta, DeviceMemory<double> *c, int ldc);
1289 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1290 uint64 m, uint64 n, uint64 k,
1291 std::complex<float> alpha,
1292 const DeviceMemory<std::complex<float>> &a,
1293 int lda,
1294 const DeviceMemory<std::complex<float>> &b,
1295 int ldb, std::complex<float> beta,
1296 DeviceMemory<std::complex<float>> *c, int ldc);
1297 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1298 uint64 m, uint64 n, uint64 k,
1299 std::complex<double> alpha,
1300 const DeviceMemory<std::complex<double>> &a,
1301 int lda,
1302 const DeviceMemory<std::complex<double>> &b,
1303 int ldb, std::complex<double> beta,
1304 DeviceMemory<std::complex<double>> *c,
1305 int ldc);
1306
  // Variants of ThenBlasGemm that also report timing/algorithm information
  // for the enqueued call through *output_profile_result.
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, float alpha,
                                    const DeviceMemory<Eigen::half> &a, int lda,
                                    const DeviceMemory<Eigen::half> &b, int ldb,
                                    float beta, DeviceMemory<Eigen::half> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, float alpha,
                                    const DeviceMemory<float> &a, int lda,
                                    const DeviceMemory<float> &b, int ldb,
                                    float beta, DeviceMemory<float> *c, int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, double alpha,
                                    const DeviceMemory<double> &a, int lda,
                                    const DeviceMemory<double> &b, int ldb,
                                    double beta, DeviceMemory<double> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
1344
  // See BlasSupport::DoBlasGemmWithAlgorithm.
  // GEMM with an explicitly selected algorithm and computation type. alpha
  // and beta are HostOrDeviceScalar, so each may reside either on the host or
  // in device memory. The int8 overload writes its result to a
  // DeviceMemory<int> output buffer.
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
      const DeviceMemory<Eigen::half> &a, int lda,
      const DeviceMemory<Eigen::half> &b, int ldb,
      const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<int> &alpha,
      const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
      int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<float> &alpha,
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
      int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<double> &alpha,
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
      int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      const HostOrDeviceScalar<std::complex<float>> &beta,
      DeviceMemory<std::complex<float>> *c, int ldc,
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      const HostOrDeviceScalar<std::complex<double>> &beta,
      DeviceMemory<std::complex<double>> *c, int ldc,
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
1397
  // See BlasSupport::DoBlasGemmBatched.
  // Performs batch_count independent GEMMs; a, b and c are slices of
  // per-batch device pointers (one DeviceMemory* per batch member).
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, float alpha,
                              const port::ArraySlice<DeviceMemory<float> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<float> *> &b,
                              int ldb, float beta,
                              const port::ArraySlice<DeviceMemory<float> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, double alpha,
                              const port::ArraySlice<DeviceMemory<double> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<double> *> &b,
                              int ldb, double beta,
                              const port::ArraySlice<DeviceMemory<double> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count);
  // As above, but any temporary memory the batched operation needs is
  // obtained from the supplied ScratchAllocator.
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
      double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  // Strided-batched variants: each operand is a single allocation, with
  // stride_a/stride_b/stride_c giving the offset between consecutive batch
  // members of a, b and c respectively.
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
      int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
      int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
      int64 stride_c, int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
      int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
      float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
      int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
      int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
      double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
      int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      int64 stride_c, int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      int64 stride_c, int batch_count);
1505
  // See BlasSupport::DoBlasHemm.
  // Hermitian matrix-matrix multiply; declared for complex element types only.
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasHerk.
  // Hermitian rank-k update. Note that alpha and beta are real-valued
  // (float/double) even though a and c hold complex elements.
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       float beta, DeviceMemory<std::complex<float>> *c,
                       int ldc);
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       double beta, DeviceMemory<std::complex<double>> *c,
                       int ldc);

  // See BlasSupport::DoBlasHer2k.
  // Hermitian rank-2k update. alpha is complex while beta is real-valued.
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        float beta, DeviceMemory<std::complex<float>> *c,
                        int ldc);
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        double beta, DeviceMemory<std::complex<double>> *c,
                        int ldc);
1545
  // See BlasSupport::DoBlasSymm.
  // Symmetric matrix-matrix multiply; real and complex element types.
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &b, int ldb,
                       float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &b, int ldb,
                       double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyrk.
  // Symmetric rank-k update.
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha, const DeviceMemory<float> &a,
                       int lda, float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha, const DeviceMemory<double> &a,
                       int lda, double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyr2k.
  // Symmetric rank-2k update.
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, float alpha, const DeviceMemory<float> &a,
                        int lda, const DeviceMemory<float> &b, int ldb,
                        float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, double alpha, const DeviceMemory<double> &a,
                        int lda, const DeviceMemory<double> &b, int ldb,
                        double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc);
1607
  // See BlasSupport::DoBlasTrmm.
  // Triangular matrix-matrix multiply; the result is written back into b.
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);

  // See BlasSupport::DoBlasTrsm.
  // Triangular solve with multiple right-hand sides; b is overwritten with
  // the solution.
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);
1647
  // See BlasSupport::DoBlasLtMatmul.
  // Note that we prevent alpha and beta from being used to deduce CType so that
  // they can be constructed implicitly from values of type CType. Without this,
  // type deduction would fail when this function is called with a value of type
  // CType for alpha or beta.
  template <typename ABType, typename CType>
  Stream &ThenBlasLtMatmul(
      const blas::IBlasLtMatmulPlan *plan,
      const detail::NonDeducedType<HostOrDeviceScalar<CType>> &alpha,
      const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
      const detail::NonDeducedType<HostOrDeviceScalar<CType>> &beta,
      DeviceMemory<CType> *c, ScratchAllocator *scratch_allocator,
      const blas::IBlasLtMatmulAlgorithm *algorithm,
      const DeviceMemory<CType> &bias = {},
      blas::ProfileResult *output_profile_result = nullptr) {
    // Forwards directly to the non-template implementation; this template
    // exists only to perform the deduction described above.
    return ThenBlasLtMatmulImpl(plan, alpha, a, b, beta, c, scratch_allocator,
                                algorithm, bias, output_profile_result);
  }
1666
  // See FftSupport::DoFft.
  // Overloads cover complex-to-complex (float and double), real-to-complex,
  // and complex-to-real transforms, as determined by the input/output element
  // types.
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<float> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<double> *output);
1684
  // Makes the RNG use the provided value as the basis for further generation.
  // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
  // sources of seed data if the default (high quality) sources are not
  // desired.
  // For most use cases, this function will not be necessary; each provided
  // back-end implementation will be appropriately seeded by default.
  // At a minimum 16 bytes of data are required in the seed buffer.
  //
  // To seed with good (non-reproducible) data:
  //   File* f = File::Open("/dev/random", "r");
  //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
  //   < error checking >
  //   stream.ThenSetRngSeed(seed_data, bytes_read);
  //
  // To seed with reproducible data:
  //   uint64_t seed_data[2] = { <data> };
  //   stream.ThenSetRngSeed(seed_data, 16);
  Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);

  // Populates the memory indicated by values with uniform-random-distribution
  // values. TODO(leary) seeding API/description
  //
  // Uses the type and size of the DeviceMemory to infer what data should be
  // populated.
  Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);

  // Populates the memory indicated by values with Gaussian (normal) random
  // values using the given mean and standard deviation.
  Stream &ThenPopulateRandGaussian(float mean, float stddev,
                                   DeviceMemory<float> *values);
  Stream &ThenPopulateRandGaussian(double mean, double stddev,
                                   DeviceMemory<double> *values);
1717
  // Entrain onto the stream: a memcpy to a host destination from a GPU source
  // of the given target size. host_dst must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  // size is in bytes (the typed helpers below compute it as count * sizeof(T)).
  Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Entrain onto the stream: a memcpy to a GPU destination from a host source
  // of the given target size. host_src must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  // size is in bytes.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                     uint64 size);
1731
1732 // Alternative interface for memcpying from device to host that takes an
1733 // array slice. Checks that the destination size can accommodate the host
1734 // slice size.
1735 template <typename T>
ThenMemcpyD2H(const DeviceMemory<T> & gpu_src,port::MutableArraySlice<T> host_dst)1736 Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1737 port::MutableArraySlice<T> host_dst) {
1738 auto host_size = host_dst.size() * sizeof(T);
1739 CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1740 return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1741 }
1742
1743 // Alternative interface for memcpying from host to device that takes an
1744 // array slice. Checks that the destination size can accommodate the host
1745 // slice size.
1746 template <typename T>
ThenMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemory<T> * gpu_dst)1747 Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
1748 DeviceMemory<T> *gpu_dst) {
1749 auto host_size = host_src.size() * sizeof(T);
1750 CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1751 return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1752 }
1753
  // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
  // of the given target size. gpu_src/dst must be pointers to GPU memory and
  // peer access must be enabled between their owning StreamExecutors.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Calls the device-to-device copy overload of ThenMemcpy -- useful for
  // ensuring that the host pointer isn't getting confused accidentally with a
  // device pointer if you're not doing metaprogramming against the API.
  Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
                        const DeviceMemoryBase &gpu_src, uint64 size) {
    return ThenMemcpy(gpu_dst, gpu_src, size);
  }
1767
  // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
  // The location must not be null.
  Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);

  // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
  // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible
  // by 4). The location must not be null.
  Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
1776
  // Enqueue a forward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnForward for more details.
  // Overloads exist for Eigen::half, float, and double element types. When
  // is_training is true, reserve_space_allocator presumably provides storage
  // the backward pass will consume -- see DoRnnForward to confirm.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<Eigen::half> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<Eigen::half> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<Eigen::half> &input_c_data,
                         const DeviceMemory<Eigen::half> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<Eigen::half> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<Eigen::half> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<Eigen::half> *output_c_data,
                         bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<float> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<float> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<float> &input_c_data,
                         const DeviceMemory<float> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<float> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<float> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<float> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<double> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<double> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<double> &input_c_data,
                         const DeviceMemory<double> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<double> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<double> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<double> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);
1833
  // Enqueue a backward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnBackward for more details.
  // Takes the forward pass's inputs/outputs plus the output-side backprop
  // tensors, and produces backprop data for the inputs and params;
  // reserve_space_data is the state saved by the forward pass.
  Stream &ThenRnnBackward(
      const dnn::RnnDescriptor &rnn_desc,
      const dnn::RnnSequenceTensorDescriptor &input_desc,
      const DeviceMemory<Eigen::half> &input_data,
      const dnn::RnnStateTensorDescriptor &input_h_desc,
      const DeviceMemory<Eigen::half> &input_h_data,
      const dnn::RnnStateTensorDescriptor &input_c_desc,
      const DeviceMemory<Eigen::half> &input_c_data,
      const DeviceMemory<Eigen::half> &params,
      const dnn::RnnSequenceTensorDescriptor &output_desc,
      const DeviceMemory<Eigen::half> &output_data,
      const dnn::RnnStateTensorDescriptor &output_h_desc,
      const DeviceMemory<Eigen::half> &output_h_data,
      const dnn::RnnStateTensorDescriptor &output_c_desc,
      const DeviceMemory<Eigen::half> &output_c_data,
      const DeviceMemory<Eigen::half> &output_backprop_data,
      const DeviceMemory<Eigen::half> &output_h_backprop_data,
      const DeviceMemory<Eigen::half> &output_c_backprop_data,
      DeviceMemory<Eigen::half> *input_backprop_data,
      DeviceMemory<Eigen::half> *input_h_backprop_data,
      DeviceMemory<Eigen::half> *input_c_backprop_data,
      DeviceMemory<Eigen::half> *params_backprop_data,
      DeviceMemory<uint8> *reserve_space_data,
      ScratchAllocator *workspace_allocator,
      dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<float> &input_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<float> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<float> &input_c_data,
                          const DeviceMemory<float> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<float> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<float> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<float> &output_c_data,
                          const DeviceMemory<float> &output_backprop_data,
                          const DeviceMemory<float> &output_h_backprop_data,
                          const DeviceMemory<float> &output_c_backprop_data,
                          DeviceMemory<float> *input_backprop_data,
                          DeviceMemory<float> *input_h_backprop_data,
                          DeviceMemory<float> *input_c_backprop_data,
                          DeviceMemory<float> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<double> &input_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<double> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<double> &input_c_data,
                          const DeviceMemory<double> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<double> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<double> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<double> &output_c_data,
                          const DeviceMemory<double> &output_backprop_data,
                          const DeviceMemory<double> &output_h_backprop_data,
                          const DeviceMemory<double> &output_c_backprop_data,
                          DeviceMemory<double> *input_backprop_data,
                          DeviceMemory<double> *input_h_backprop_data,
                          DeviceMemory<double> *input_c_backprop_data,
                          DeviceMemory<double> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);
1911
  // Enqueue a CTCLoss operation onto the stream.
  // See DnnSupport::DoCtcLoss for more details.
  // labels_data, labels_lengths_data and input_lengths_data are host-side
  // spans; per-sequence losses go to costs_data and gradients w.r.t. the
  // probabilities go to grads_data.
  Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                      const DeviceMemory<float> &probs_data,
                      absl::Span<const int> labels_data,
                      absl::Span<const int> labels_lengths_data,
                      absl::Span<const int> input_lengths_data,
                      DeviceMemory<float> *costs_data,
                      const dnn::RnnStateTensorDescriptor &grads_desc,
                      DeviceMemory<float> *grads_data,
                      ScratchAllocator *workspace_allocator);
1923
  // Enqueue onto the stream an operation that transforms a tensor.
  // See DnnSupport::DoTransformTensor for more details.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);
1932
1933 // The templated version of the above ThenTransformTensor. Useful when the
1934 // input and output types are statically known.
  template <typename InElemT, typename OutElemT>
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              const DeviceMemory<InElemT> &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              DeviceMemory<OutElemT> *output_data) {
    // Maps the static element types to runtime dnn::DataType tags and defers
    // to the dynamically-typed overload.
    // NOTE(review): the overload declared above also takes a float `scale`
    // argument that is not passed here, and dnn::ToDataType is used as a call
    // expression. Because this call is dependent it is only checked when the
    // template is instantiated -- confirm a matching overload exists (or that
    // a scale argument should be added) before relying on this template.
    return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
                               input_data, output_desc,
                               dnn::ToDataType<OutElemT>(), output_data);
  }
1944
  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
  //
  // Returns an OK status if the blocking was successful and the stream is ok().
  // Otherwise returns an error describing why the blocking failed.
  // Must not be called while mu_ is held (TF_LOCKS_EXCLUDED).
  port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_);
1952
  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
  //
  // Entrains onto the stream a function to be executed on the host at some
  // point in the future.
  // Async host callbacks DO NOT block the stream as device functions (or as
  // synchronous host callbacks) do. No synchronization is possible with
  // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is private due to the potential for undefined behavior with
  // synchronization using OpenCL user events.
  // The ONLY lifetime guarantee in these calls is that the StreamExecutor
  // parameter will still be valid - this Stream may not be!
  // Any callbacks requiring device API calls must use this method.
  Stream &ThenEnqueueOnBackgroundThread(
      std::function<void(StreamExecutor *)> task);
1971
1972 // Returns the (opaque) platform-specific backing object. Ownership is not
1973 // transferred to the caller.
implementation()1974 internal::StreamInterface *implementation() { return implementation_.get(); }
1975
  // Entrains onto the stream a callback to the host (from the device).
  // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
  // never fail or its failure is inconsequential.
  //
  // This is kept for backward compatibility. Future code should use
  // ThenDoHostCallbackWithStatus and explicitly return a success status.
  // TODO(b/112125301): Eventually remove this method.
  Stream &ThenDoHostCallback(std::function<void()> callback);

  // Entrains onto the stream a callback to the host (from the device).
  // Host callbacks block/occupy the stream just as device functions do
  // (execute one at a time, block later stream operations).
  // Whether the callback return status affects the result of BlockHostUntilDone
  // is platform-dependent.
  //
  // Behavior is undefined when synchronizing using OpenCL user events.
  // Behavior is undefined if host callbacks call device routines or insert
  // them into any stream.
  //
  // On certain platforms, ThenDoHostCallback is expected to have significant
  // negative effects on performance.
  Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);

  // Runs the given callback after the next call to BlockHostUntilDone on this
  // stream (or after the Stream does BlockHostUntilDone in its destructor).
  // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
  // some use cases.
  Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);
2004
2005 // Returns the StreamExecutor (parent object) associated with this stream.
parent()2006 StreamExecutor *parent() const {
2007 CHECK(parent_ != nullptr);
2008 return parent_;
2009 }
2010
  // Returns the (internal usage) temporary-memory-allocation manager associated
  // with this stream. Non-owning pointer; the manager lives as long as the
  // stream.
  internal::TemporaryMemoryManager *temporary_memory_manager();

  // Returns a debugging string "[stream=0x...,impl=0x...]".
  std::string DebugStreamPointers() const;
2017
 private:
  // Host-platform BLAS/FFT/RNG implementations need direct access to parent_.
  friend class host::HostBlas;  // for parent_.
  friend class host::HostFft;   // for parent_.
  friend class host::HostRng;   // for parent_.
  template <typename... Args>
  friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
  friend class ocl::CLBlas;    // for parent_.
2025
InErrorState()2026 bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
2027 absl::ReaderMutexLock lock(&mu_);
2028 return !status_.ok();
2029 }
2030
2031 // Sets the error state if operation_retcode is false.
2032 // This is a useful shorthand for many stream routines.
CheckError(bool operation_retcode)2033 void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_) {
2034 if (operation_retcode) {
2035 return;
2036 }
2037 absl::MutexLock lock(&mu_);
2038 status_ = port::InternalError("Unknown error");
2039 }
2040
  // Checks the given status and logs the error message, if any. Acquires mu_
  // internally, so must not be called with it held.
  void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_);
2043
SetError()2044 void SetError() { CheckError(false /* = operation_retcode */); }
2045
SetErrorAndLogNoDnnSupport()2046 void SetErrorAndLogNoDnnSupport() {
2047 SetError();
2048 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
2049 "without DNN support";
2050 }
2051
  // Runs the set of callbacks that are intended to run after
  // BlockHostUntilDone (registered via ThenRunAfterNextBlockHostUntilDone).
  void RunAfterBlockHostUntilDoneCallbacks();
2055
  // The StreamExecutor that supports the operation of this stream. Non-owning;
  // expected to outlive this Stream (see the lifetime note on
  // ThenEnqueueOnBackgroundThread above).
  StreamExecutor *parent_;

  // The platform-dependent implementation that the StreamExecutor interface
  // delegates to. Owned by this Stream.
  std::unique_ptr<internal::StreamInterface> implementation_;

  // Mutex that guards the allocation / error state flags.
  // Mutable so that it can be obtained via const reader lock.
  mutable absl::Mutex mu_;

  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from 0 to 1 with a sanity check.
  // See StreamExecutor::AllocateStream.
  bool allocated_ TF_GUARDED_BY(mu_);

  // The last error (if any) of all method calls.
  port::Status status_ TF_GUARDED_BY(mu_);

  // Sub-streams that are generated from this stream. Each element has a pointer
  // to sub-stream and a boolean value indicating if this substream is ready to
  // be reused.
  std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
      TF_GUARDED_BY(mu_);

  // Streams can allocate temporary memories to help with work they enqueue
  // (e.g. for scratch memory spaces). This member tracks those allocations and
  // notes when they can be reclaimed -- reclamation is attempted when
  // BlockHostUntilDone() is called.
  internal::TemporaryMemoryManager temporary_memory_manager_;

  // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
  std::vector<std::function<void()>> after_block_host_until_done_callbacks_
      TF_GUARDED_BY(mu_);

  // Implementation of ThenConvolveBackwardBias that is shared by all types.
  template <typename T>
  Stream &ThenConvolveBackwardBiasImpl(
      const dnn::BatchDescriptor &input_descriptor,
      const DeviceMemory<T> &input_data,
      const dnn::BatchDescriptor &bias_descriptor,
      DeviceMemory<T> *backward_bias_data);

  // Implementation of ThenBlasLtMatmul that is shared by all types.
  template <typename ABType, typename CType>
  Stream &ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan *plan,
                               const HostOrDeviceScalar<CType> &alpha,
                               const DeviceMemory<ABType> &a,
                               const DeviceMemory<ABType> &b,
                               const HostOrDeviceScalar<CType> &beta,
                               DeviceMemory<CType> *c,
                               ScratchAllocator *scratch_allocator,
                               const blas::IBlasLtMatmulAlgorithm *algorithm,
                               const DeviceMemory<CType> &bias,
                               blas::ProfileResult *output_profile_result);

  SE_DISALLOW_COPY_AND_ASSIGN(Stream);
};
2114
2115 ////////////
2116 // Inlines
2117
2118 template <typename... Params, typename... Args>
ThenLaunch(ThreadDim thread_dims,BlockDim block_dims,const TypedKernel<Params...> & kernel,Args...args)2119 inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
2120 const TypedKernel<Params...> &kernel,
2121 Args... args) {
2122 KernelInvocationChecker<std::tuple<Params...>,
2123 std::tuple<Args...>>::CheckAllStaticAssert();
2124 if (ok()) {
2125 // This is the core that allows type-safe kernel launching.
2126 // Since the platforms take kernel arguments as tuples of (void *, size),
2127 // we pack the variadic parameters passed as ...args into the desired
2128 // tuple form and pass that packed form to the StreamExecutor::Launch()
2129 // implementation.
2130 KernelArgsArray<sizeof...(args)> kernel_args;
2131 kernel.PackParams(&kernel_args, args...);
2132 DCHECK(parent_ != nullptr);
2133 bool ok =
2134 parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args)
2135 .ok();
2136 if (!ok) {
2137 SetError();
2138 LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
2139 }
2140 }
2141 return *this;
2142 }
2143
2144 template <typename T>
2145 inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
AllocateTemporaryArray(uint64 element_count)2146 Stream::AllocateTemporaryArray(uint64 element_count) {
2147 return temporary_memory_manager_.AllocateArray<T>(element_count);
2148 }
2149
temporary_memory_manager()2150 inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
2151 return &temporary_memory_manager_;
2152 }
2153
// Quantization trait specializations mapping an element type to the
// dnn::QuantizedActivationMode used for that type. The specializations
// visible here cover uint8, uint16 and int32.
template <>
struct Quantization<uint8> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k8Bit;
};

template <>
struct Quantization<uint16> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k16Bit;
};

template <>
struct Quantization<int32> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k32Bit;
};
2171
2172 } // namespace stream_executor
2173
2174 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
2175