1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
23
24 #include <complex>
25 #include <functional>
26 #include <memory>
27
28 #include "absl/synchronization/mutex.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/stream_executor/blas.h"
31 #include "tensorflow/stream_executor/device_memory.h"
32 #include "tensorflow/stream_executor/dnn.h"
33 #include "tensorflow/stream_executor/event.h"
34 #include "tensorflow/stream_executor/fft.h"
35 #include "tensorflow/stream_executor/host_or_device_scalar.h"
36 #include "tensorflow/stream_executor/kernel.h"
37 #include "tensorflow/stream_executor/launch_dim.h"
38 #include "tensorflow/stream_executor/lib/array_slice.h"
39 #include "tensorflow/stream_executor/platform/port.h"
40 #include "tensorflow/stream_executor/platform/thread_annotations.h"
41 #include "tensorflow/stream_executor/temporary_memory_manager.h"
42
43 namespace stream_executor {
44
45 namespace host {
46 class HostBlas;
47 class HostFft;
48 class HostRng;
49 class HostTimer;
50 } // namespace host
51
52 namespace ocl {
53 class CLBlas;
54 } // namespace ocl
55
56 namespace internal {
57 class StreamInterface;
58 } // namespace internal
59
60 class DeviceMemoryBase;
61 template <typename ElemT>
62 class DeviceMemory;
63
64 class Timer;
65
66 namespace dnn {
67 class BatchDescriptor;
68 class FilterDescriptor;
69 class ConvolutionDescriptor;
70 class ProfileResult;
71 class AlgorithmDesc;
72 } // namespace dnn
73
74 class StreamExecutor;
75 class ScratchAllocator;
76
77 // Convert a type to the corresponding QuantizedActivationMode.
78 template <typename ElementType>
79 struct Quantization;
80
81 // Represents a stream of dependent computations on a GPU device.
82 //
83 // The operations within a stream execute linearly and asynchronously until
84 // BlockHostUntilDone() is invoked, which synchronously joins host code with
85 // the execution of the stream.
86 //
87 // If any given operation fails when entraining work for the stream, ok() will
88 // indicate that an error has occurred. After initialization, once a stream is
89 // !ok(), it will never be ok().
90 //
91 // Thread-safe post-initialization.
92 class Stream {
93 public:
94 // Instantiate a stream tied to parent as a platform executor. Work
95 // entrained onto this stream will be launched/managed on that
96 // StreamExecutor's platform.
97 explicit Stream(StreamExecutor *parent);
98
99 // Test only. Use an externally-populated value (like a mock) for the
100 // platform-specific stream implementation.
101 Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
102
  // Deallocates any stream resources that the parent StreamExecutor has
  // bestowed upon this object.
106 ~Stream();
107
108 // Returns whether any errors have occurred while entraining work for this
109 // stream.
ok()110 bool ok() const { return !InErrorState(); }
111
112 // Retrieves execution status back into the stream from the underlying
113 // implementation without blocking the stream.
114 //
115 // Normally, Stream::BlockHostUntilDone is used to get execution status.
  // However, some devices use out-of-band mechanisms to ensure their streams
117 // have finished on-device work, without needing to block the streams. (These
118 // devices should also override AllowsSyncOnCompletion to return false.) For
119 // these devices, this method can be used after work is finished to retrieve
120 // execution status.
121 port::Status RefreshStatus() LOCKS_EXCLUDED(mu_);
122
123 // Initialize the stream. This must be performed before entraining any other
124 // operations.
125 Stream &Init() LOCKS_EXCLUDED(mu_);
126
127 // Initializes timer t via the StreamExecutor.
128 Stream &InitTimer(Timer *t);
129
130 // Convenience wrapper around Init() and InitTimer().
131 Stream &InitWithTimer(Timer *t);
132
133 // Get or create a sub-stream from this stream. If there is any sub-stream in
134 // the pool that can be reused then just return this sub-stream. Otherwise
135 // create a new sub-stream.
136 //
137 // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
138 Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
139
140 // Return the sub-stream back to the host stream so that it can be reused
141 // later. Sub-streams that are !ok() will not be reused.
142 //
143 // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
144 void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
145
146 // Allocate temporary memories. The stream will deallocate them when blocked
147 // or destroyed.
148 template <typename T>
149 port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
150 AllocateTemporaryArray(uint64 element_count);
151
152 // Entrains onto the stream of operations: a kernel launch with the given
153 // (variadic) parameters for the invocation. These arguments can be things
154 // like DeviceMemory or primitive types such as int. What arguments you may
155 // pass to a given kernel are noted as the template parameters to the
156 // TypedKernel type that the machocc compiler generates.
157 //
158 // Template parameters:
159 // Params... The type list of formal parameters that the typed kernel
160 // expects, which is matched against Args...
161 // Args... The deduced type list for passed actual arguments
162 //
163 // Implementation: A compile-time compatibility check is performed that has
164 // some leniency versus an exact parameter pack match -- for example,
165 // `const DeviceMemory<T>` is considered "pack compatible" with a
166 // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
167 // perfect forwarding support without rvalue references. It also attempts to
168 // spit out helpful static_assert error traces with information as to the
169 // argument number and types that were mismatched.
170 template <typename... Params, typename... Args>
171 Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
172 const TypedKernel<Params...> &kernel, Args... args);
173
174 // Record a "start" event for the interval timer at this point in the
175 // stream's execution (relative to the previously and subsequently enqueued
176 // items in the stream's execution). Streams may be started/stopped multiple
177 // times.
178 Stream &ThenStartTimer(Timer *t);
179
180 // Record a "stop" event for the interval timer at this point in the
181 // stream's execution. See also Stream::ThenStartTimer.
182 Stream &ThenStopTimer(Timer *t);
183
184 // TODO(leary) If work is added to the stream that is being depended upon,
185 // then what? Have to describe what happens.
186 template <typename... Params>
ThenWaitFor(Stream * other,Params...more_streams)187 Stream &ThenWaitFor(Stream *other, Params... more_streams) {
188 return ThenWaitFor(more_streams...).ThenWaitFor(other);
189 }
190
191 // Create a dependency for this stream's next work on the other stream
192 // completing. Does not take ownership of other, and other must not be
193 // null.
194 //
195 // Checks that a stream does not wait for itself, and it is up to the
196 // user to guarantee that a stream does not come to wait on itself in a
197 // cyclic manner; in that case, behavior is undefined.
198 //
199 // N.B. Base recursion case for the variadic ThenWaitFor.
200 Stream &ThenWaitFor(Stream *other);
201
202 // Waits for all streams values in others.
203 // Checks that there is no shallow circular wait (i.e. that "this" is not in
204 // others)
205 template <typename P>
ThenWaitFor(P others)206 Stream &ThenWaitFor(P others) {
207 for (auto &stream : *others) {
208 CHECK_NE(stream.get(), this);
209 ThenWaitFor(stream.get());
210 }
211 return *this;
212 }
213
214 // Waits for an event object to be set.
215 // Note that ThenRecordEvent must have been called on the event before
216 // you call this function; otherwise the event will be considered complete
217 // and this wait will do nothing.
218 Stream &ThenWaitFor(Event *event);
219
220 // Inserts the specified event into the end of this stream. Once the stream
221 // has processed all events prior to the insertion point, the event will be
222 // marked as completed.
223 // The stream does not take ownership of event - meaning that event's lifetime
224 // must extend past the point at which it is marked complete!
225 Stream &ThenRecordEvent(Event *event);
226
227 ////////////////
228 // DNN support
229 //
230 // See DnnSupport::* for comments on the following methods.
231
232 Stream &ThenBatchNormalizationForward(
233 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
234 const DeviceMemory<float> &offset,
235 const DeviceMemory<float> &estimated_mean,
236 const DeviceMemory<float> &estimated_variance,
237 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
238 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
239 const double exponential_average_factor,
240 dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
241 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
242 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
243 bool is_training,
244 std::function<const DeviceMemory<float> &()> var_to_inv_var,
245 std::function<void()> inv_var_to_var,
246 ScratchAllocator *reserve_space_allocator,
247 ScratchAllocator *workspace_allocator);
248
249 Stream &ThenBatchNormalizationBackward(
250 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
251 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
252 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
253 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
254 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
255 DeviceMemory<float> *offset_backprop,
256 DeviceMemory<uint8> *reserve_space_data,
257 ScratchAllocator *workspace_allocator);
258
259 Stream &ThenBatchNormalizationForward(
260 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
261 const DeviceMemory<float> &offset,
262 const DeviceMemory<float> &estimated_mean,
263 const DeviceMemory<float> &estimated_variance,
264 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
265 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
266 const double exponential_average_factor,
267 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
268 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
269 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
270 bool is_training,
271 std::function<const DeviceMemory<float> &()> var_to_inv_var,
272 std::function<void()> inv_var_to_var,
273 ScratchAllocator *reserve_space_allocator,
274 ScratchAllocator *workspace_allocator);
275
276 Stream &ThenBatchNormalizationBackward(
277 const DeviceMemory<Eigen::half> &y_backprop,
278 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
279 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
280 const dnn::BatchDescriptor &x_desc,
281 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
282 DeviceMemory<Eigen::half> *x_backprop,
283 DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
284 DeviceMemory<uint8> *reserve_space_data,
285 ScratchAllocator *workspace_allocator);
286
287 Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
288 const DeviceMemory<float> &input_data,
289 const dnn::FilterDescriptor &filter_descriptor,
290 const DeviceMemory<float> &filter_data,
291 const dnn::ConvolutionDescriptor &convolution_descriptor,
292 const dnn::BatchDescriptor &output_descriptor,
293 DeviceMemory<float> *output);
294
295 Stream &ThenConvolveQuantized(
296 const dnn::BatchDescriptor &input_descriptor,
297 const DeviceMemory<float> &input_data,
298 const dnn::FilterDescriptor &filter_descriptor,
299 const DeviceMemory<int8> &filter_coefficients,
300 const DeviceMemory<float> &coefficient_scales,
301 const dnn::ConvolutionDescriptor &convolution_descriptor,
302 const dnn::BatchDescriptor &output_descriptor,
303 DeviceMemory<float> *output_data);
304
305 Stream &ThenConvolveQuantized(
306 const dnn::BatchDescriptor &input_descriptor,
307 const DeviceMemory<float> &input_data,
308 const dnn::FilterDescriptor &filter_descriptor,
309 const DeviceMemory<int16> &filter_coefficients,
310 const DeviceMemory<float> &coefficient_scales,
311 const dnn::ConvolutionDescriptor &convolution_descriptor,
312 const dnn::BatchDescriptor &output_descriptor,
313 DeviceMemory<float> *output_data);
314
315 Stream &ThenConvolveWithAlgorithm(
316 const dnn::BatchDescriptor &input_descriptor,
317 const DeviceMemory<double> &input_data,
318 const dnn::FilterDescriptor &filter_descriptor,
319 const DeviceMemory<double> &filter_data,
320 const dnn::ConvolutionDescriptor &convolution_descriptor,
321 const dnn::BatchDescriptor &output_descriptor,
322 DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
323 const dnn::AlgorithmConfig &algorithm_config,
324 dnn::ProfileResult *output_profile_result);
325
326 Stream &ThenConvolveWithAlgorithm(
327 const dnn::BatchDescriptor &input_descriptor,
328 const DeviceMemory<float> &input_data,
329 const dnn::FilterDescriptor &filter_descriptor,
330 const DeviceMemory<float> &filter_data,
331 const dnn::ConvolutionDescriptor &convolution_descriptor,
332 const dnn::BatchDescriptor &output_descriptor,
333 DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
334 const dnn::AlgorithmConfig &algorithm_config,
335 dnn::ProfileResult *output_profile_result);
336
337 Stream &ThenConvolveWithAlgorithm(
338 const dnn::BatchDescriptor &input_descriptor,
339 const DeviceMemory<Eigen::half> &input_data,
340 const dnn::FilterDescriptor &filter_descriptor,
341 const DeviceMemory<Eigen::half> &filter_data,
342 const dnn::ConvolutionDescriptor &convolution_descriptor,
343 const dnn::BatchDescriptor &output_descriptor,
344 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
345 const dnn::AlgorithmConfig &algorithm_config,
346 dnn::ProfileResult *output_profile_result);
347
348 Stream &ThenConvolveWithAlgorithm(
349 const dnn::BatchDescriptor &input_descriptor,
350 const DeviceMemory<int8> &input_data,
351 const dnn::FilterDescriptor &filter_descriptor,
352 const DeviceMemory<int8> &filter_data,
353 const dnn::ConvolutionDescriptor &convolution_descriptor,
354 const dnn::BatchDescriptor &output_descriptor,
355 DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
356 const dnn::AlgorithmConfig &algorithm_config,
357 dnn::ProfileResult *output_profile_result);
358
359 Stream &ThenConvolveWithAlgorithm(
360 const dnn::BatchDescriptor &input_descriptor,
361 const DeviceMemory<int8> &input_data,
362 const dnn::FilterDescriptor &filter_descriptor,
363 const DeviceMemory<int8> &filter_data,
364 const dnn::ConvolutionDescriptor &convolution_descriptor,
365 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
366 ScratchAllocator *scratch_allocator,
367 const dnn::AlgorithmConfig &algorithm_config,
368 dnn::ProfileResult *output_profile_result);
369
370 Stream &ThenFusedConvolveWithAlgorithm(
371 const dnn::BatchDescriptor &conv_input_descriptor,
372 const DeviceMemory<double> &conv_input_data, double conv_input_scale,
373 const dnn::FilterDescriptor &filter_descriptor,
374 const DeviceMemory<double> &filter_data,
375 const dnn::ConvolutionDescriptor &convolution_descriptor,
376 const DeviceMemory<double> &side_input_data, double side_input_scale,
377 const dnn::BatchDescriptor &bias_descriptor,
378 const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
379 const dnn::BatchDescriptor &output_descriptor,
380 DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
381 const dnn::AlgorithmConfig &algorithm_config,
382 dnn::ProfileResult *output_profile_result);
383
384 Stream &ThenFusedConvolveWithAlgorithm(
385 const dnn::BatchDescriptor &conv_input_descriptor,
386 const DeviceMemory<float> &conv_input_data, float conv_input_scale,
387 const dnn::FilterDescriptor &filter_descriptor,
388 const DeviceMemory<float> &filter_data,
389 const dnn::ConvolutionDescriptor &convolution_descriptor,
390 const DeviceMemory<float> &side_input_data, float side_input_scale,
391 const dnn::BatchDescriptor &bias_descriptor,
392 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
393 const dnn::BatchDescriptor &output_descriptor,
394 DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
395 const dnn::AlgorithmConfig &algorithm_config,
396 dnn::ProfileResult *output_profile_result);
397
398 Stream &ThenFusedConvolveWithAlgorithm(
399 const dnn::BatchDescriptor &conv_input_descriptor,
400 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
401 const dnn::FilterDescriptor &filter_descriptor,
402 const DeviceMemory<Eigen::half> &filter_data,
403 const dnn::ConvolutionDescriptor &convolution_descriptor,
404 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
405 const dnn::BatchDescriptor &bias_descriptor,
406 const DeviceMemory<Eigen::half> &biases,
407 dnn::ActivationMode activation_mode,
408 const dnn::BatchDescriptor &output_descriptor,
409 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
410 const dnn::AlgorithmConfig &algorithm_config,
411 dnn::ProfileResult *output_profile_result);
412
413 Stream &ThenFusedConvolveWithAlgorithm(
414 const dnn::BatchDescriptor &conv_input_descriptor,
415 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
416 const dnn::FilterDescriptor &filter_descriptor,
417 const DeviceMemory<int8> &filter_data,
418 const dnn::ConvolutionDescriptor &convolution_descriptor,
419 const DeviceMemory<int8> &side_input_data, float side_input_scale,
420 const dnn::BatchDescriptor &bias_descriptor,
421 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
422 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
423 ScratchAllocator *scratch_allocator,
424 const dnn::AlgorithmConfig &algorithm_config,
425 dnn::ProfileResult *output_profile_result);
426
427 Stream &ThenFusedConvolveWithAlgorithm(
428 const dnn::BatchDescriptor &conv_input_descriptor,
429 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
430 const dnn::FilterDescriptor &filter_descriptor,
431 const DeviceMemory<int8> &filter_data,
432 const dnn::ConvolutionDescriptor &convolution_descriptor,
433 const DeviceMemory<float> &side_input_data, float side_input_scale,
434 const dnn::BatchDescriptor &bias_descriptor,
435 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
436 const dnn::BatchDescriptor &output_descriptor,
437 DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
438 const dnn::AlgorithmConfig &algorithm_config,
439 dnn::ProfileResult *output_profile_result);
440
441 Stream &ThenSeparableConvolve(
442 const dnn::BatchDescriptor &input_descriptor,
443 const DeviceMemory<float> &input_data,
444 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
445 const DeviceMemory<float> &first_weights,
446 const DeviceMemory<float> &second_weights,
447 const dnn::ConvolutionDescriptor &convolution_descriptor,
448 const dnn::BatchDescriptor &output_descriptor,
449 DeviceMemory<float> *output);
450
451 Stream &ThenConvolveBackwardDataWithAlgorithm(
452 const dnn::FilterDescriptor &filter_descriptor,
453 const DeviceMemory<double> &filter_data,
454 const dnn::BatchDescriptor &output_descriptor,
455 DeviceMemory<double> backward_output_data,
456 const dnn::ConvolutionDescriptor &convolution_descriptor,
457 const dnn::BatchDescriptor &input_descriptor,
458 DeviceMemory<double> *backward_input_data,
459 ScratchAllocator *scratch_allocator,
460 const dnn::AlgorithmConfig &algorithm_config,
461 dnn::ProfileResult *output_profile_result);
462
463 Stream &ThenConvolveBackwardDataWithAlgorithm(
464 const dnn::FilterDescriptor &filter_descriptor,
465 const DeviceMemory<float> &filter_data,
466 const dnn::BatchDescriptor &output_descriptor,
467 DeviceMemory<float> backward_output_data,
468 const dnn::ConvolutionDescriptor &convolution_descriptor,
469 const dnn::BatchDescriptor &input_descriptor,
470 DeviceMemory<float> *backward_input_data,
471 ScratchAllocator *scratch_allocator,
472 const dnn::AlgorithmConfig &algorithm_config,
473 dnn::ProfileResult *output_profile_result);
474
475 Stream &ThenConvolveBackwardDataWithAlgorithm(
476 const dnn::FilterDescriptor &filter_descriptor,
477 const DeviceMemory<Eigen::half> &filter_data,
478 const dnn::BatchDescriptor &output_descriptor,
479 DeviceMemory<Eigen::half> backward_output_data,
480 const dnn::ConvolutionDescriptor &convolution_descriptor,
481 const dnn::BatchDescriptor &input_descriptor,
482 DeviceMemory<Eigen::half> *backward_input_data,
483 ScratchAllocator *scratch_allocator,
484 const dnn::AlgorithmConfig &algorithm_config,
485 dnn::ProfileResult *output_profile_result);
486
487 Stream &ThenConvolveBackwardFilterWithAlgorithm(
488 const dnn::BatchDescriptor &input_descriptor,
489 const DeviceMemory<double> &input_data,
490 const dnn::BatchDescriptor &output_descriptor,
491 DeviceMemory<double> backward_output_data,
492 const dnn::ConvolutionDescriptor &convolution_descriptor,
493 const dnn::FilterDescriptor &filter_descriptor,
494 DeviceMemory<double> *backward_filter_data,
495 ScratchAllocator *scratch_allocator,
496 const dnn::AlgorithmConfig &algorithm_config,
497 dnn::ProfileResult *output_profile_result);
498
499 Stream &ThenConvolveBackwardFilterWithAlgorithm(
500 const dnn::BatchDescriptor &input_descriptor,
501 const DeviceMemory<float> &input_data,
502 const dnn::BatchDescriptor &output_descriptor,
503 DeviceMemory<float> backward_output_data,
504 const dnn::ConvolutionDescriptor &convolution_descriptor,
505 const dnn::FilterDescriptor &filter_descriptor,
506 DeviceMemory<float> *backward_filter_data,
507 ScratchAllocator *scratch_allocator,
508 const dnn::AlgorithmConfig &algorithm_config,
509 dnn::ProfileResult *output_profile_result);
510
511 Stream &ThenConvolveBackwardFilterWithAlgorithm(
512 const dnn::BatchDescriptor &input_descriptor,
513 const DeviceMemory<Eigen::half> &input_data,
514 const dnn::BatchDescriptor &output_descriptor,
515 DeviceMemory<Eigen::half> backward_output_data,
516 const dnn::ConvolutionDescriptor &convolution_descriptor,
517 const dnn::FilterDescriptor &filter_descriptor,
518 DeviceMemory<Eigen::half> *backward_filter_data,
519 ScratchAllocator *scratch_allocator,
520 const dnn::AlgorithmConfig &algorithm_config,
521 dnn::ProfileResult *output_profile_result);
522
523 Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
524 const DeviceMemory<double> &input_data,
525 const dnn::BatchDescriptor &bias_descriptor,
526 DeviceMemory<double> *backward_bias_data);
527
528 Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
529 const DeviceMemory<float> &input_data,
530 const dnn::BatchDescriptor &bias_descriptor,
531 DeviceMemory<float> *backward_bias_data);
532
533 Stream &ThenConvolveBackwardBias(
534 const dnn::BatchDescriptor &input_descriptor,
535 const DeviceMemory<Eigen::half> &input_data,
536 const dnn::BatchDescriptor &bias_descriptor,
537 DeviceMemory<Eigen::half> *backward_bias_data);
538
539 Stream &ThenMatMul(const DeviceMemory<float> &input_data,
540 const DeviceMemory<float> &weights,
541 const dnn::BatchDescriptor &input_dimensions,
542 const dnn::BatchDescriptor &output_dimensions,
543 DeviceMemory<float> *output_data);
544
545 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
546 const DeviceMemory<int8> &weights,
547 const DeviceMemory<float> &weight_scales,
548 const dnn::BatchDescriptor &input_dimensions,
549 const dnn::BatchDescriptor &output_dimensions,
550 DeviceMemory<float> *output_data);
551
552 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
553 const DeviceMemory<int16> &weights,
554 const DeviceMemory<float> &weight_scales,
555 const dnn::BatchDescriptor &input_dimensions,
556 const dnn::BatchDescriptor &output_dimensions,
557 DeviceMemory<float> *output_data);
558
559 Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
560 const DeviceMemory<float> &biases,
561 const dnn::BatchDescriptor &dimensions,
562 DeviceMemory<float> *output_data);
563
564 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
565 const dnn::BatchDescriptor &input_dimensions,
566 const DeviceMemory<double> &input_data,
567 const dnn::BatchDescriptor &output_dimensions,
568 DeviceMemory<double> *output_data,
569 ScratchAllocator *workspace_allocator = nullptr);
570
571 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
572 const dnn::BatchDescriptor &input_dimensions,
573 const DeviceMemory<float> &input_data,
574 const dnn::BatchDescriptor &output_dimensions,
575 DeviceMemory<float> *output_data,
576 ScratchAllocator *workspace_allocator = nullptr);
577
578 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
579 const dnn::BatchDescriptor &input_dimensions,
580 const DeviceMemory<Eigen::half> &input_data,
581 const dnn::BatchDescriptor &output_dimensions,
582 DeviceMemory<Eigen::half> *output_data,
583 ScratchAllocator *workspace_allocator = nullptr);
584
585 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
586 const dnn::BatchDescriptor &input_dimensions,
587 const DeviceMemory<int8> &input_data,
588 const dnn::BatchDescriptor &output_dimensions,
589 DeviceMemory<int8> *output_data,
590 ScratchAllocator *workspace_allocator = nullptr);
591
592 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
593 const dnn::BatchDescriptor &input_dimensions,
594 const DeviceMemory<double> &input_data,
595 const dnn::BatchDescriptor &output_dimensions,
596 const DeviceMemory<double> &output_data,
597 const DeviceMemory<double> &input_diff_data,
598 DeviceMemory<double> *output_diff_data,
599 ScratchAllocator *workspace_allocator = nullptr);
600
601 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
602 const dnn::BatchDescriptor &input_dimensions,
603 const DeviceMemory<float> &input_data,
604 const dnn::BatchDescriptor &output_dimensions,
605 const DeviceMemory<float> &output_data,
606 const DeviceMemory<float> &input_diff_data,
607 DeviceMemory<float> *output_diff_data,
608 ScratchAllocator *workspace_allocator = nullptr);
609
610 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
611 const dnn::BatchDescriptor &input_dimensions,
612 const DeviceMemory<Eigen::half> &input_data,
613 const dnn::BatchDescriptor &output_dimensions,
614 const DeviceMemory<Eigen::half> &output_data,
615 const DeviceMemory<Eigen::half> &input_diff_data,
616 DeviceMemory<Eigen::half> *output_diff_data,
617 ScratchAllocator *workspace_allocator = nullptr);
618
619 Stream &ThenNormalizeWithDimensions(
620 const dnn::NormalizeDescriptor &normalize_descriptor,
621 const dnn::BatchDescriptor &dimensions,
622 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
623
624 Stream &ThenNormalizeBackwardWithDimensions(
625 const dnn::NormalizeDescriptor &normalize_descriptor,
626 const dnn::BatchDescriptor &dimensions,
627 const DeviceMemory<float> &raw_data,
628 const DeviceMemory<float> &normalized_data,
629 const DeviceMemory<float> &normalized_variable_gradient,
630 DeviceMemory<float> *raw_variable_gradient,
631 ScratchAllocator *workspace_allocator = nullptr);
632
633 Stream &ThenActivate(dnn::ActivationMode activation_mode,
634 const dnn::BatchDescriptor &dimensions,
635 const DeviceMemory<float> &input_data,
636 DeviceMemory<float> *output_data);
637
638 // Same as ThenActivate, but also takes an options argument that can be used
639 // for platform-specific option flags.
640 Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
641 const dnn::BatchDescriptor &dimensions,
642 const DeviceMemory<float> &input_data,
643 DeviceMemory<float> *output_data,
644 uint64 options);
645
646 Stream &ThenDepthConcatenate(
647 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
648 port::ArraySlice<const DeviceMemory<float> *> input_data,
649 DeviceMemory<float> *output_data);
650
651 Stream &ThenSpaceConcatenate(
652 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
653 port::ArraySlice<const DeviceMemory<float> *> input_data,
654 DeviceMemory<float> *output_data,
655 dnn::SpaceConcatenateMode concat_direction);
656
657 // Change the layout of the data by shrinking one dimension (or set of
658 // dimensions) and growing another dimension (or set of dimensions), while
659 // keeping the total number of data elements constant, and maintaining the
660 // current data ordering.
661 Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
662 const DeviceMemory<float> &input_data,
663 const dnn::BatchDescriptor &output_dimensions,
664 DeviceMemory<float> *output_data);
665
666 // Depth to space takes an X by Y image with depth D*M² and changes it to an
667 // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
668 // the input image is changed to an MxM contiguous area in the output image,
669 // with the values being laid out in raster order specified by
670 // DepthToSpaceLayout, and will have a new depth of D.
671 // See the DoDepthToSpace comment for more information.
672 Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
673 const DeviceMemory<float> &input_data,
674 const dnn::DepthToSpaceLayout &depth_to_space_layout,
675 const int sqrt_depth_reduction,
676 DeviceMemory<float> *output_data);
677
678 // Space to depth is the inverse of depth to space. Space to depth takes each
679 // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
680 // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
681 // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
682 // data elements is not changed.
683 Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
684 const DeviceMemory<float> &input_data,
685 const dnn::DepthToSpaceLayout &space_to_depth_layout,
686 const int sqrt_depth_increase,
687 DeviceMemory<float> *output_data);
688
  // Enqueues the given elementwise `operation` over the supplied inputs,
  // writing the result into *output_data.
  Stream &ThenElementwiseOperate(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);

  // Scaled-quantized variant of ThenElementwiseOperate.
  // NOTE(review): based on the parameter names, each input appears to be
  // scaled by the corresponding entry of `input_multiplicands` and the result
  // divided by `output_divisor` — confirm against
  // DnnSupport::DoElementwiseOperateScaledQuantized.
  Stream &ThenElementwiseOperateScaledQuantized(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<int> input_multiplicands, int output_divisor,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);
703
  // Pads the X/Y dimensions of the input by the given element counts on each
  // edge (left/right/top/bottom), writing the result to *output_data.
  Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
                    const DeviceMemory<float> &input_data, int64 left_pad,
                    int64 right_pad, int64 top_pad, int64 bottom_pad,
                    DeviceMemory<float> *output_data);

  // Trims the given element counts from each edge of the X/Y dimensions of
  // the input (the slicing counterpart of ThenXYPad).
  Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
                      const DeviceMemory<float> &input_data, int64 left_trim,
                      int64 right_trim, int64 top_trim, int64 bottom_trim,
                      DeviceMemory<float> *output_data);

  // Grows the input tensor by replicating the X and Y dimensions. The batch
  // and depth/feature_map dimensions are unchanged. Currently, the input
  // tensor is limited to X=1 and Y=1.
  Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
                          const DeviceMemory<float> &input_data,
                          int64 replicate_x, int64 replicate_y,
                          DeviceMemory<float> *output_data);
721
  // See DnnSupport::DoMemcpyD2HQuantized.
  // Quantizes the device-side float data per `mode` and copies `size` bytes
  // of the quantized result into `host_dst`.
  Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
                                 dnn::QuantizedActivationMode mode,
                                 void *host_dst, uint64 size);
726
727 // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
728 // and uses the Quantization trait to call the generic version of
729 // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
730 template <typename ElementType>
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,port::MutableArraySlice<ElementType> host_dst)731 Stream &ThenMemcpyD2HQuantized(
732 const DeviceMemory<float> &gpu_unquantized_src,
733 port::MutableArraySlice<ElementType> host_dst) {
734 return ThenMemcpyD2HQuantized(
735 gpu_unquantized_src, Quantization<ElementType>::kModeId,
736 host_dst.data(), host_dst.size() * sizeof(ElementType));
737 }
738
  // See DnnSupport::DoMemcpyH2DQuantized.
  // Copies `size` bytes of quantized data from `host_src`, dequantizing per
  // `mode` into the device-side float buffer *gpu_unquantized_dst.
  Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
                                 dnn::QuantizedActivationMode mode,
                                 DeviceMemory<float> *gpu_unquantized_dst);
743
744 // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
745 // and uses the Quantization trait to call the generic version of
746 // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
747 template <typename ElementType>
ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,DeviceMemory<float> * gpu_unquantized_dst)748 Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
749 DeviceMemory<float> *gpu_unquantized_dst) {
750 return ThenMemcpyH2DQuantized(
751 host_src.data(), host_src.size() * sizeof(ElementType),
752 Quantization<ElementType>::kModeId, gpu_unquantized_dst);
753 }
754
  // See DnnSupport::DoCopyHostBuffer2Device.
  // Enqueues a copy from a host-side HostBuffer into device float memory.
  Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
                                    DeviceMemory<float> *gpu_unquantized_dst);

  // See DnnSupport::DoCopyDevice2HostBuffer.
  // Enqueues a copy from device float memory into a host-side HostBuffer.
  Stream &ThenCopyDevice2HostBuffer(
      const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
762
  /////////////////
  // BLAS support
  //
  // Each ThenBlas* method enqueues the corresponding standard BLAS operation
  // on this stream; see the matching BlasSupport::DoBlas* declaration for the
  // full contract. `incx`/`incy` are element strides through the vectors.

  // See BlasSupport::DoBlasAsum.
  // *ASUM: *result = sum of absolute values of the elements of x.
  Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *result);
  Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *result);
  Stream &ThenBlasAsum(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<float> *result);
  Stream &ThenBlasAsum(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<double> *result);

  // See BlasSupport::DoBlasAxpy (*AXPY: y = alpha * x + y). Note that, even
  // for the case where alpha is present in DeviceMemory, it must be an
  // execution-time constant (i.e. a value that the stream does not change or
  // populate during the course of execution). The value is effectively
  // captured at stream-enqueue time.
  Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasCopy (*COPY: y = x).
  Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasDot (*DOT: *result = x . y).
  Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
                      const DeviceMemory<float> &y, int incy,
                      DeviceMemory<float> *result);
  Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
                      int incx, const DeviceMemory<double> &y, int incy,
                      DeviceMemory<double> *result);

  // See BlasSupport::DoBlasDotc (*DOTC: *result = conj(x) . y).
  Stream &ThenBlasDotc(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *result);
  Stream &ThenBlasDotc(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *result);

  // See BlasSupport::DoBlasDotu (*DOTU: *result = x . y, unconjugated).
  Stream &ThenBlasDotu(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *result);
  Stream &ThenBlasDotu(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *result);

  // See BlasSupport::DoBlasNrm2 (*NRM2: *result = Euclidean norm of x).
  Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *result);
  Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *result);
  Stream &ThenBlasNrm2(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<float> *result);
  Stream &ThenBlasNrm2(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<double> *result);
847
  // See BlasSupport::DoBlasRot.
  // *ROT: applies a plane (Givens) rotation with cosine c and sine s to the
  // vector pair (x, y), in place.
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
                      DeviceMemory<float> *y, int incy, float c, float s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
                      DeviceMemory<double> *y, int incy, double c, double s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
                      int incx, DeviceMemory<std::complex<float>> *y, int incy,
                      float c, float s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
                      int incx, DeviceMemory<std::complex<double>> *y, int incy,
                      double c, double s);

  // See BlasSupport::DoBlasRotg.
  // *ROTG: generates the Givens rotation (c, s) that zeroes the second
  // component of (a, b); all four arguments are in-out per the BLAS standard.
  Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
                       DeviceMemory<float> *c, DeviceMemory<float> *s);
  Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
                       DeviceMemory<double> *c, DeviceMemory<double> *s);
  Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
                       DeviceMemory<std::complex<float>> *b,
                       DeviceMemory<float> *c,
                       DeviceMemory<std::complex<float>> *s);
  Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
                       DeviceMemory<std::complex<double>> *b,
                       DeviceMemory<double> *c,
                       DeviceMemory<std::complex<double>> *s);
873
874 // See BlasSupport::DoBlasRotm.
875 Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
876 DeviceMemory<float> *y, int incy,
877 const DeviceMemory<float> ¶m);
878 Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
879 DeviceMemory<double> *y, int incy,
880 const DeviceMemory<double> ¶m);
881
  // See BlasSupport::DoBlasRotmg.
  // *ROTMG: constructs the modified Givens rotation (written into *param)
  // that zeroes the second component of (sqrt(d1)*x1, sqrt(d2)*y1);
  // d1, d2 and x1 are updated in place.
  Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
                        DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
                        DeviceMemory<float> *param);
  Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
                        DeviceMemory<double> *x1,
                        const DeviceMemory<double> &y1,
                        DeviceMemory<double> *param);

  // See BlasSupport::DoBlasScal (*SCAL: x = alpha * x).
  Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasScal(uint64 elem_count, float alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, double alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);
904
  // See BlasSupport::DoBlasSwap (*SWAP: exchanges the contents of x and y).
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
                       int incx, DeviceMemory<std::complex<float>> *y,
                       int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
                       int incx, DeviceMemory<std::complex<double>> *y,
                       int incy);

  // See BlasSupport::DoBlasIamax.
  // I*AMAX: *result = index of the element of x with the largest absolute
  // value.
  Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        DeviceMemory<int> *result);

  // See BlasSupport::DoBlasIamin.
  // I*AMIN: *result = index of the element of x with the smallest absolute
  // value.
  Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        DeviceMemory<int> *result);
940
  // See BlasSupport::DoBlasGbmv.
  // *GBMV: y = alpha * op(A) * x + beta * y, where A is an m-by-n band
  // matrix with kl sub-diagonals and ku super-diagonals, and op is given by
  // `trans`.
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &x, int incx,
                       float beta, DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &x, int incx,
                       double beta, DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasGemv.
  // *GEMV: y = alpha * op(A) * x + beta * y, where A is a general m-by-n
  // matrix and op is given by `trans`.
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
984
  // As ThenBlasGemv, but additionally reports timing/algorithm information
  // into *output_profile_result.
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    float alpha, const DeviceMemory<float> &a,
                                    int lda, const DeviceMemory<float> &x,
                                    int incx, float beta,
                                    DeviceMemory<float> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    double alpha, const DeviceMemory<double> &a,
                                    int lda, const DeviceMemory<double> &x,
                                    int incx, double beta,
                                    DeviceMemory<double> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &x, int incx,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &x, int incx,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
      int incy, blas::ProfileResult *output_profile_result);
1009
  // See BlasSupport::DoBlasGer.
  // *GER: rank-1 update A = alpha * x * y^T + A.
  Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      const DeviceMemory<float> &y, int incy,
                      DeviceMemory<float> *a, int lda);
  Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      const DeviceMemory<double> &y, int incy,
                      DeviceMemory<double> *a, int lda);

  // See BlasSupport::DoBlasGerc.
  // *GERC: conjugated rank-1 update A = alpha * x * conj(y)^T + A.
  Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);

  // See BlasSupport::DoBlasGeru.
  // *GERU: unconjugated rank-1 update A = alpha * x * y^T + A.
  Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);
1039
  // See BlasSupport::DoBlasHbmv.
  // *HBMV: y = alpha * A * x + beta * y, where A is a Hermitian band matrix
  // with k super-diagonals; `uplo` selects the stored triangle.
  Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasHemv.
  // *HEMV: y = alpha * A * x + beta * y, where A is an n-by-n Hermitian
  // matrix.
  Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasHer.
  // *HER: Hermitian rank-1 update A = alpha * x * conj(x)^T + A.
  Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<std::complex<float>> &x, int incx,
                      DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<std::complex<double>> &x, int incx,
                      DeviceMemory<std::complex<double>> *a, int lda);

  // See BlasSupport::DoBlasHer2.
  // *HER2: Hermitian rank-2 update
  // A = alpha * x * conj(y)^T + conj(alpha) * y * conj(x)^T + A.
  Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);
1087
  // See BlasSupport::DoBlasHpmv.
  // *HPMV: as *HEMV, but A is supplied in packed form (`ap`).
  Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &ap,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &ap,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasHpr.
  // *HPR: as *HER, but A is supplied in packed form (`ap`).
  Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<std::complex<float>> &x, int incx,
                      DeviceMemory<std::complex<float>> *ap);
  Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<std::complex<double>> &x, int incx,
                      DeviceMemory<std::complex<double>> *ap);

  // See BlasSupport::DoBlasHpr2.
  // *HPR2: as *HER2, but A is supplied in packed form (`ap`).
  Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *ap);
  Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *ap);
1121
  // See BlasSupport::DoBlasSbmv.
  // *SBMV: y = alpha * A * x + beta * y, where A is a symmetric band matrix
  // with k super-diagonals.
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  // See BlasSupport::DoBlasSpmv.
  // *SPMV: y = alpha * A * x + beta * y, where A is a symmetric matrix in
  // packed form (`ap`).
  Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &ap,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &ap,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  // See BlasSupport::DoBlasSpr.
  // *SPR: symmetric packed rank-1 update A = alpha * x * x^T + A.
  Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      DeviceMemory<float> *ap);
  Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      DeviceMemory<double> *ap);

  // See BlasSupport::DoBlasSpr2.
  // *SPR2: symmetric packed rank-2 update
  // A = alpha * x * y^T + alpha * y * x^T + A.
  Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       const DeviceMemory<float> &y, int incy,
                       DeviceMemory<float> *ap);
  Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       const DeviceMemory<double> &y, int incy,
                       DeviceMemory<double> *ap);
1159
  // See BlasSupport::DoBlasSymv.
  // *SYMV: y = alpha * A * x + beta * y, where A is an n-by-n symmetric
  // matrix.
  Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  // See BlasSupport::DoBlasSyr.
  // *SYR: symmetric rank-1 update A = alpha * x * x^T + A.
  Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      DeviceMemory<float> *a, int lda);
  Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      DeviceMemory<double> *a, int lda);

  // See BlasSupport::DoBlasSyr2.
  // *SYR2: symmetric rank-2 update
  // A = alpha * x * y^T + alpha * y * x^T + A.
  Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       const DeviceMemory<float> &y, int incy,
                       DeviceMemory<float> *a, int lda);
  Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       const DeviceMemory<double> &y, int incy,
                       DeviceMemory<double> *a, int lda);
1187
  // See BlasSupport::DoBlasTbmv.
  // *TBMV: x = op(A) * x, where A is a triangular band matrix with k
  // off-diagonals; `diag` indicates whether A is unit-triangular.
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasTbsv.
  // *TBSV: solves op(A) * x = b for x in place, where A is a triangular band
  // matrix with k off-diagonals and b is passed in via x.
  Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);
1223
  // See BlasSupport::DoBlasTpmv.
  // *TPMV: x = op(A) * x, where A is a triangular matrix in packed form
  // (`ap`).
  Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &ap, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &ap,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &ap,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasTpsv.
  // *TPSV: solves op(A) * x = b for x in place, where A is a triangular
  // matrix in packed form (`ap`) and b is passed in via x.
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &ap, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &ap,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &ap,
                       DeviceMemory<std::complex<double>> *x, int incx);
1259
  // See BlasSupport::DoBlasTrmv.
  // *TRMV: x = op(A) * x, where A is an n-by-n triangular matrix.
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasTrsv.
  // *TRSV: solves op(A) * x = b for x in place, where A is an n-by-n
  // triangular matrix and b is passed in via x.
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);
1295
  // See BlasSupport::DoBlasGemm.
  // *GEMM: C = alpha * op(A) * op(B) + beta * C, where op(A) is m-by-k,
  // op(B) is k-by-n, and C is m-by-n; op is selected per operand by
  // transa/transb. Note the half-precision overload takes float alpha/beta.
  TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                                 uint64 m, uint64 n, uint64 k, float alpha,
                                 const DeviceMemory<Eigen::half> &a, int lda,
                                 const DeviceMemory<Eigen::half> &b, int ldb,
                                 float beta, DeviceMemory<Eigen::half> *c,
                                 int ldc);
  TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                                 uint64 m, uint64 n, uint64 k, float alpha,
                                 const DeviceMemory<float> &a, int lda,
                                 const DeviceMemory<float> &b, int ldb,
                                 float beta, DeviceMemory<float> *c, int ldc);
  TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                                 uint64 m, uint64 n, uint64 k, double alpha,
                                 const DeviceMemory<double> &a, int lda,
                                 const DeviceMemory<double> &b, int ldb,
                                 double beta, DeviceMemory<double> *c, int ldc);
  TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                                 uint64 m, uint64 n, uint64 k,
                                 std::complex<float> alpha,
                                 const DeviceMemory<std::complex<float>> &a,
                                 int lda,
                                 const DeviceMemory<std::complex<float>> &b,
                                 int ldb, std::complex<float> beta,
                                 DeviceMemory<std::complex<float>> *c, int ldc);
  TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                                 uint64 m, uint64 n, uint64 k,
                                 std::complex<double> alpha,
                                 const DeviceMemory<std::complex<double>> &a,
                                 int lda,
                                 const DeviceMemory<std::complex<double>> &b,
                                 int ldb, std::complex<double> beta,
                                 DeviceMemory<std::complex<double>> *c,
                                 int ldc);
1330
  // As ThenBlasGemm, but additionally reports timing/algorithm information
  // into *output_profile_result.
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, float alpha,
                                    const DeviceMemory<Eigen::half> &a, int lda,
                                    const DeviceMemory<Eigen::half> &b, int ldb,
                                    float beta, DeviceMemory<Eigen::half> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, float alpha,
                                    const DeviceMemory<float> &a, int lda,
                                    const DeviceMemory<float> &b, int ldb,
                                    float beta, DeviceMemory<float> *c, int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, double alpha,
                                    const DeviceMemory<double> &a, int lda,
                                    const DeviceMemory<double> &b, int ldb,
                                    double beta, DeviceMemory<double> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
1368
  // See BlasSupport::DoBlasGemmWithAlgorithm.
  // As ThenBlasGemm, but runs with an explicitly chosen `algorithm` and
  // `computation_type`, and alpha/beta may live on the host or on the device
  // (HostOrDeviceScalar). Profiling information is written to
  // *output_profile_result.
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
      const DeviceMemory<Eigen::half> &a, int lda,
      const DeviceMemory<Eigen::half> &b, int ldb,
      const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  // int8 inputs with int32 accumulation/output.
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<int> &alpha,
      const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
      int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<float> &alpha,
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
      int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<double> &alpha,
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
      int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
1403 Stream &ThenBlasGemmWithAlgorithm(
1404 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1405 uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1406 const DeviceMemory<std::complex<float>> &a, int lda,
1407 const DeviceMemory<std::complex<float>> &b, int ldb,
1408 const HostOrDeviceScalar<std::complex<float>> &beta,
1409 DeviceMemory<std::complex<float>> *c, int ldc,
1410 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1411 blas::ProfileResult *output_profile_result);
1412 Stream &ThenBlasGemmWithAlgorithm(
1413 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1414 uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1415 const DeviceMemory<std::complex<double>> &a, int lda,
1416 const DeviceMemory<std::complex<double>> &b, int ldb,
1417 const HostOrDeviceScalar<std::complex<double>> &beta,
1418 DeviceMemory<std::complex<double>> *c, int ldc,
1419 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1420 blas::ProfileResult *output_profile_result);
1421
  // See BlasSupport::DoBlasGemmBatched.
  //
  // Performs batch_count independent GEMMs; a, b and c are slices of
  // per-batch device-memory pointers, all sharing the same leading
  // dimensions lda/ldb/ldc.
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, float alpha,
                              const port::ArraySlice<DeviceMemory<float> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<float> *> &b,
                              int ldb, float beta,
                              const port::ArraySlice<DeviceMemory<float> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, double alpha,
                              const port::ArraySlice<DeviceMemory<double> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<double> *> &b,
                              int ldb, double beta,
                              const port::ArraySlice<DeviceMemory<double> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count);
  // Variants of ThenBlasGemmBatched that additionally take a
  // ScratchAllocator from which any temporary device memory for the batched
  // call is obtained.
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
      double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  // See BlasSupport::DoBlasGemmStridedBatched.
  //
  // Batched GEMM in which the per-batch matrices live at fixed offsets
  // (stride_a, stride_b, stride_c) from single base allocations, avoiding
  // the per-batch pointer arrays the slice-based overloads above require.
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
      int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
      int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
      int64 stride_c, int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
      int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
      float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
      int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
      int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
      double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
      int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      int64 stride_c, int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      int64 stride_c, int batch_count);
1529
  // See BlasSupport::DoBlasHemm.
  // BLAS HEMM: matrix-matrix multiply where `a` is a Hermitian matrix
  // (complex element types only); `side` selects A*B vs B*A.
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasHerk.
  // BLAS HERK: Hermitian rank-k update; note that alpha and beta are real
  // scalars even though the matrices are complex.
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       float beta, DeviceMemory<std::complex<float>> *c,
                       int ldc);
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       double beta, DeviceMemory<std::complex<double>> *c,
                       int ldc);

  // See BlasSupport::DoBlasHer2k.
  // BLAS HER2K: Hermitian rank-2k update; alpha is complex, beta is real.
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        float beta, DeviceMemory<std::complex<float>> *c,
                        int ldc);
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        double beta, DeviceMemory<std::complex<double>> *c,
                        int ldc);

  // See BlasSupport::DoBlasSymm.
  // BLAS SYMM: matrix-matrix multiply where `a` is a symmetric matrix.
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &b, int ldb,
                       float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &b, int ldb,
                       double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyrk.
  // BLAS SYRK: symmetric rank-k update.
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha, const DeviceMemory<float> &a,
                       int lda, float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha, const DeviceMemory<double> &a,
                       int lda, double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyr2k.
  // BLAS SYR2K: symmetric rank-2k update.
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, float alpha, const DeviceMemory<float> &a,
                        int lda, const DeviceMemory<float> &b, int ldb,
                        float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, double alpha, const DeviceMemory<double> &a,
                        int lda, const DeviceMemory<double> &b, int ldb,
                        double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasTrmm.
  // BLAS TRMM: triangular matrix-matrix multiply; `b` is overwritten with
  // the result.
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);

  // See BlasSupport::DoBlasTrsm.
  // BLAS TRSM: triangular solve; `b` holds the right-hand sides on entry and
  // is overwritten with the solution.
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);
1671
  // See FftSupport::DoFft.
  // The transform parameters (size, batching, direction) are carried by
  // *plan; the overloads below cover complex->complex, real->complex and
  // complex->real data in single and double precision.
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<std::complex<double>> *output);
  // Real-to-complex transforms.
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
                  DeviceMemory<std::complex<double>> *output);
  // Complex-to-real transforms.
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<float> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<double> *output);
1689
  // Makes the RNG use the provided value as the basis for further generation.
  // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
  // sources of seed data if the default (high quality) sources are not
  // desired.
  // For most use cases, this function will not be necessary; each provided
  // back-end implementation will be appropriately seeded by default.
  // At a minimum 16 bytes of data are required in the seed buffer.
  //
  // To seed with good (non-reproducible) data:
  //   File* f = File::Open("/dev/random", "r");
  //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
  //   < error checking >
  //   stream.ThenSetRngSeed(seed_data, bytes_read);
  //
  // To seed with reproducible data:
  //   uint64_t seed_data[2] = { <data> };
  //   stream.ThenSetRngSeed(seed_data, 16);
  Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);

  // Populates the memory indicated by values with uniform-random-distribution
  // values. TODO(leary) seeding API/description
  //
  // Uses the type and size of the DeviceMemory to infer what data should be
  // populated.
  Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
  // Gaussian (normal-distribution) variants with the given mean and standard
  // deviation.
  Stream &ThenPopulateRandGaussian(float mean, float stddev,
                                   DeviceMemory<float> *values);
  Stream &ThenPopulateRandGaussian(double mean, double stddev,
                                   DeviceMemory<double> *values);
1722
  // Entrain onto the stream: a memcpy to a host destination from a GPU source
  // of the given target size. host_dst must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Entrain onto the stream: a memcpy to a GPU destination from a host source
  // of the given target size. host_src must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                     uint64 size);
1736
1737 // Alternative interface for memcpying from device to host that takes an
1738 // array slice. Checks that the destination size can accommodate the host
1739 // slice size.
1740 template <typename T>
ThenMemcpyD2H(const DeviceMemory<T> & gpu_src,port::MutableArraySlice<T> host_dst)1741 Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1742 port::MutableArraySlice<T> host_dst) {
1743 auto host_size = host_dst.size() * sizeof(T);
1744 CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1745 return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1746 }
1747
1748 // Alternative interface for memcpying from host to device that takes an
1749 // array slice. Checks that the destination size can accommodate the host
1750 // slice size.
1751 template <typename T>
ThenMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemory<T> * gpu_dst)1752 Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
1753 DeviceMemory<T> *gpu_dst) {
1754 auto host_size = host_src.size() * sizeof(T);
1755 CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1756 return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1757 }
1758
  // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
  // of the given target size. gpu_src/dst must be pointers to GPU memory and
  // peer access must be enabled between their owning StreamExecutors.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Calls into the device-to-device copy overload of ThenMemcpy -- useful for
  // ensuring that the host pointer isn't getting confused accidentally with a
  // device pointer if you're not doing metaprogramming against the API.
  Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
                        const DeviceMemoryBase &gpu_src, uint64 size) {
    return ThenMemcpy(gpu_dst, gpu_src, size);
  }
1772
  // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
  // The location must not be null.
  Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);

  // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
  // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible
  // by 4). The location must not be null.
  Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
1781
  // Enqueue a forward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnForward for more details.
  //
  // Overloads are provided for Eigen::half (fp16), float and double element
  // types. The reserve_space/workspace allocators supply scratch memory; see
  // DnnSupport::DoRnnForward for their exact roles and for the meaning of
  // is_training.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<Eigen::half> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<Eigen::half> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<Eigen::half> &input_c_data,
                         const DeviceMemory<Eigen::half> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<Eigen::half> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<Eigen::half> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<Eigen::half> *output_c_data,
                         bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<float> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<float> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<float> &input_c_data,
                         const DeviceMemory<float> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<float> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<float> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<float> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<double> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<double> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<double> &input_c_data,
                         const DeviceMemory<double> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<double> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<double> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<double> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);
1838
  // Enqueue a backward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnBackward for more details.
  //
  // Overloads mirror ThenRnnForward (Eigen::half, float, double). The
  // *_backprop_data outputs receive the computed gradients;
  // reserve_space_data carries state produced by the corresponding forward
  // pass (see DoRnnBackward).
  Stream &ThenRnnBackward(
      const dnn::RnnDescriptor &rnn_desc,
      const dnn::RnnSequenceTensorDescriptor &input_desc,
      const DeviceMemory<Eigen::half> &input_data,
      const dnn::RnnStateTensorDescriptor &input_h_desc,
      const DeviceMemory<Eigen::half> &input_h_data,
      const dnn::RnnStateTensorDescriptor &input_c_desc,
      const DeviceMemory<Eigen::half> &input_c_data,
      const DeviceMemory<Eigen::half> &params,
      const dnn::RnnSequenceTensorDescriptor &output_desc,
      const DeviceMemory<Eigen::half> &output_data,
      const dnn::RnnStateTensorDescriptor &output_h_desc,
      const DeviceMemory<Eigen::half> &output_h_data,
      const dnn::RnnStateTensorDescriptor &output_c_desc,
      const DeviceMemory<Eigen::half> &output_c_data,
      const DeviceMemory<Eigen::half> &output_backprop_data,
      const DeviceMemory<Eigen::half> &output_h_backprop_data,
      const DeviceMemory<Eigen::half> &output_c_backprop_data,
      DeviceMemory<Eigen::half> *input_backprop_data,
      DeviceMemory<Eigen::half> *input_h_backprop_data,
      DeviceMemory<Eigen::half> *input_c_backprop_data,
      DeviceMemory<Eigen::half> *params_backprop_data,
      DeviceMemory<uint8> *reserve_space_data,
      ScratchAllocator *workspace_allocator,
      dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<float> &input_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<float> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<float> &input_c_data,
                          const DeviceMemory<float> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<float> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<float> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<float> &output_c_data,
                          const DeviceMemory<float> &output_backprop_data,
                          const DeviceMemory<float> &output_h_backprop_data,
                          const DeviceMemory<float> &output_c_backprop_data,
                          DeviceMemory<float> *input_backprop_data,
                          DeviceMemory<float> *input_h_backprop_data,
                          DeviceMemory<float> *input_c_backprop_data,
                          DeviceMemory<float> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<double> &input_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<double> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<double> &input_c_data,
                          const DeviceMemory<double> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<double> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<double> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<double> &output_c_data,
                          const DeviceMemory<double> &output_backprop_data,
                          const DeviceMemory<double> &output_h_backprop_data,
                          const DeviceMemory<double> &output_c_backprop_data,
                          DeviceMemory<double> *input_backprop_data,
                          DeviceMemory<double> *input_h_backprop_data,
                          DeviceMemory<double> *input_c_backprop_data,
                          DeviceMemory<double> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);
1916
  // Enqueue a CTCLoss operation onto the stream.
  // See DnnSupport::DoCtcLoss for more details.
  Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                      const DeviceMemory<float> &probs_data,
                      absl::Span<const int> labels_data,
                      absl::Span<const int> labels_lengths_data,
                      absl::Span<const int> input_lengths_data,
                      DeviceMemory<float> *costs_data,
                      const dnn::RnnStateTensorDescriptor &grads_desc,
                      DeviceMemory<float> *grads_data,
                      ScratchAllocator *workspace_allocator);

  // Enqueue onto the stream an operation that transforms a tensor.
  // See DnnSupport::DoTransformTensor for more details.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);
1937
1938 // The templated version of the above ThenTransformTensor. Useful when the
1939 // input and output types are statically known.
1940 template <typename InElemT, typename OutElemT>
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,const DeviceMemory<InElemT> & input_data,const dnn::BatchDescriptor & output_desc,DeviceMemory<OutElemT> * output_data)1941 Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
1942 const DeviceMemory<InElemT> &input_data,
1943 const dnn::BatchDescriptor &output_desc,
1944 DeviceMemory<OutElemT> *output_data) {
1945 return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
1946 input_data, output_desc,
1947 dnn::ToDataType<OutElemT>(), output_data);
1948 }
1949
  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
  //
  // Returns an OK status if the blocking was successful and the stream is ok().
  // Otherwise returns an error describing why the blocking failed.
  // Must not be called while holding mu_ (LOCKS_EXCLUDED).
  port::Status BlockHostUntilDone() LOCKS_EXCLUDED(mu_);
1957
  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
  //
  // Entrains onto the stream a function to be executed on the host at some
  // point in the future.
  // Async host callbacks DO NOT block the stream as device functions (or as
  // synchronous host callbacks). No synchronization is possible with
  // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is private due to the potential for undefined behavior with
  // synchronization using OpenCL user events.
  // The ONLY lifetime guarantee in these calls is that the StreamExecutor
  // parameter will still be valid - this Stream may not be!
  // Any callbacks requiring device API calls must use this method.
  Stream &ThenEnqueueOnBackgroundThread(
      std::function<void(StreamExecutor *)> task);
1976
  // Returns the (opaque) platform-specific backing object. Ownership is not
  // transferred to the caller; the pointer remains owned by this Stream
  // (implementation_).
  internal::StreamInterface *implementation() { return implementation_.get(); }
1980
  // Entrains onto the stream a callback to the host (from the device).
  // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
  // never fail or its failure is inconsequential.
  //
  // This is kept for backward compatibility. Future code should use
  // ThenDoHostCallbackWithStatus and explicitly return a success status.
  // TODO(b/112125301): Eventually remove this method.
  Stream &ThenDoHostCallback(std::function<void()> callback);

  // Entrains onto the stream a callback to the host (from the device).
  // Host callbacks block/occupy the stream just as device functions
  // (execute one at a time, block later stream operations).
  // Whether the callback return status affects the result of BlockHostUntilDone
  // is platform-dependent.
  //
  // Behavior is undefined when synchronizing using OpenCL user events.
  // Behavior is undefined if host callbacks call device routines or insert
  // them into any stream.
  //
  // On certain platforms, ThenDoHostCallback is expected to have significant
  // negative effects on performance.
  Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);

  // Runs the given callback after the next call to BlockHostUntilDone on this
  // stream (or after the Stream does BlockHostUntilDone in its destructor).
  // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
  // some use cases.
  Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);
2009
  // Returns the StreamExecutor (parent object) associated with this stream.
  // CHECK-fails (crashes) if this stream has no parent; never returns null.
  StreamExecutor *parent() const {
    CHECK(parent_ != nullptr);
    return parent_;
  }
2015
  // Returns the (internal usage) temporary-memory-allocation manager associated
  // with this stream.
  internal::TemporaryMemoryManager *temporary_memory_manager();

  // Returns a debugging string "[stream=0x...,impl=0x...]".
  string DebugStreamPointers() const;

 private:
  // Host-side BLAS/FFT/RNG implementations need direct access to parent_.
  friend class host::HostBlas;  // for parent_.
  friend class host::HostFft;   // for parent_.
  friend class host::HostRng;   // for parent_.
  template <typename... Args>
  friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
  friend class ocl::CLBlas;  // for parent_.
2030
InErrorState()2031 bool InErrorState() const LOCKS_EXCLUDED(mu_) {
2032 absl::ReaderMutexLock lock(&mu_);
2033 return !ok_;
2034 }
2035
2036 // Sets the error state if operation_retcode is false.
2037 // This is a useful shorthand for many stream routines.
CheckError(bool operation_retcode)2038 void CheckError(bool operation_retcode) LOCKS_EXCLUDED(mu_) {
2039 if (operation_retcode) {
2040 return;
2041 }
2042 absl::MutexLock lock(&mu_);
2043 ok_ = false;
2044 }
2045
  // Checks the status and logs the error message, if any.
  void CheckStatus(port::Status status) LOCKS_EXCLUDED(mu_);

  // Unconditionally puts the stream into an error state (clears ok_).
  void SetError() { CheckError(false /* = operation_retcode */); }
2050
  // Puts the stream into an error state and emits a warning that a DNN
  // operation was attempted on a StreamExecutor without DNN support.
  void SetErrorAndLogNoDnnSupport() {
    SetError();
    LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
                    "without DNN support";
  }
2056
  // Runs the set of callbacks that are intended to run after
  // BlockHostUntilDone (see after_block_host_until_done_callbacks_ below).
  void RunAfterBlockHostUntilDoneCallbacks();

  // The StreamExecutor that supports the operation of this stream.
  StreamExecutor *parent_;

  // The platform-dependent implementation that the StreamExecutor interface
  // delegates to.
  std::unique_ptr<internal::StreamInterface> implementation_;

  // Mutex that guards the allocation / error state flags.
  // Mutable so that it can be obtained via const reader lock.
  mutable absl::Mutex mu_;

  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from 0 to 1 with a sanity check.
  // See StreamExecutor::AllocateStream.
  bool allocated_ GUARDED_BY(mu_);

  // Whether all operations have entrained successfully to the current program
  // point.
  bool ok_ GUARDED_BY(mu_);

  // Sub-streams that are generated from this stream. Each element has a pointer
  // to sub-stream and a boolean value indicating if this substream is ready to
  // be reused.
  std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
      GUARDED_BY(mu_);

  // Streams can allocate temporary memories to help with work they enqueue
  // (e.g. for scratch memory spaces). This member tracks those allocations and
  // notes when they can be reclaimed -- reclamation is attempted when
  // BlockHostUntilDone() is called.
  internal::TemporaryMemoryManager temporary_memory_manager_;

  // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
  std::vector<std::function<void()>> after_block_host_until_done_callbacks_
      GUARDED_BY(mu_);

  // Implementation of ThenConvolveBackwardBias that is shared by all types.
  template <typename T>
  Stream &ThenConvolveBackwardBiasImpl(
      const dnn::BatchDescriptor &input_descriptor,
      const DeviceMemory<T> &input_data,
      const dnn::BatchDescriptor &bias_descriptor,
      DeviceMemory<T> *backward_bias_data);

  // Streams are neither copyable nor assignable.
  SE_DISALLOW_COPY_AND_ASSIGN(Stream);
};
2107
2108 ////////////
2109 // Inlines
2110
2111 template <typename T>
2112 inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
AllocateTemporaryArray(uint64 element_count)2113 Stream::AllocateTemporaryArray(uint64 element_count) {
2114 return temporary_memory_manager_.AllocateArray<T>(element_count);
2115 }
2116
// Returns the (internal usage) temporary-memory-allocation manager associated
// with this stream. Never null; points at a member owned by this Stream.
inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
  return &temporary_memory_manager_;
}
2120
// Maps uint8 element storage to the 8-bit quantized activation mode.
template <>
struct Quantization<uint8> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k8Bit;
};
2126
// Maps uint16 element storage to the 16-bit quantized activation mode.
template <>
struct Quantization<uint16> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k16Bit;
};
2132
// Maps int32 element storage to the 32-bit quantized activation mode.
template <>
struct Quantization<int32> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k32Bit;
};
2138
2139 } // namespace stream_executor
2140
2141 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
2142