1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
23
24 #include <complex>
25 #include <functional>
26 #include <memory>
27
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/stream_executor/blas.h"
30 #include "tensorflow/stream_executor/device_memory.h"
31 #include "tensorflow/stream_executor/dnn.h"
32 #include "tensorflow/stream_executor/event.h"
33 #include "tensorflow/stream_executor/fft.h"
34 #include "tensorflow/stream_executor/host_or_device_scalar.h"
35 #include "tensorflow/stream_executor/kernel.h"
36 #include "tensorflow/stream_executor/launch_dim.h"
37 #include "tensorflow/stream_executor/lib/array_slice.h"
38 #include "tensorflow/stream_executor/platform/mutex.h"
39 #include "tensorflow/stream_executor/platform/port.h"
40 #include "tensorflow/stream_executor/platform/thread_annotations.h"
41 #include "tensorflow/stream_executor/temporary_memory_manager.h"
42
43 namespace stream_executor {
44
45 namespace host {
46 class HostBlas;
47 class HostFft;
48 class HostRng;
49 class HostTimer;
50 } // namespace host
51
52 namespace ocl {
53 class CLBlas;
54 } // namespace ocl
55
56 namespace internal {
57 class StreamInterface;
58 } // namespace internal
59
60 class DeviceMemoryBase;
61 template <typename ElemT>
62 class DeviceMemory;
63
64 class Timer;
65
66 namespace dnn {
67 class BatchDescriptor;
68 class FilterDescriptor;
69 class ConvolutionDescriptor;
70 class ProfileResult;
71 class AlgorithmDesc;
72 } // namespace dnn
73
74 class StreamExecutor;
75 class ScratchAllocator;
76
77 // Convert a type to the corresponding QuantizedActivationMode.
78 template <typename ElementType>
79 struct Quantization;
80
81 // Represents a stream of dependent computations on a GPU device.
82 //
83 // The operations within a stream execute linearly and asynchronously until
84 // BlockHostUntilDone() is invoked, which synchronously joins host code with
85 // the execution of the stream.
86 //
87 // If any given operation fails when entraining work for the stream, ok() will
88 // indicate that an error has occurred. After initialization, once a stream is
89 // !ok(), it will never be ok().
90 //
91 // Thread-safe post-initialization.
92 class Stream {
93 public:
94 // Instantiate a stream tied to parent as a platform executor. Work
95 // entrained onto this stream will be launched/managed on that
96 // StreamExecutor's platform.
97 explicit Stream(StreamExecutor *parent);
98
99 // Test only. Use an externally-populated value (like a mock) for the
100 // platform-specific stream implementation.
101 Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
102
  // Deallocates any stream resources that the parent StreamExecutor has
  // bestowed upon this object.
106 ~Stream();
107
108 // Returns whether any errors have occurred while entraining work for this
109 // stream.
ok()110 bool ok() const { return !InErrorState(); }
111
112 // Retrieves execution status back into the stream from the underlying
113 // implementation without blocking the stream.
114 //
115 // Normally, Stream::BlockHostUntilDone is used to get execution status.
  // However, some devices use out-of-band mechanisms to ensure their streams
117 // have finished on-device work, without needing to block the streams. (These
118 // devices should also override AllowsSyncOnCompletion to return false.) For
119 // these devices, this method can be used after work is finished to retrieve
120 // execution status.
121 port::Status RefreshStatus() LOCKS_EXCLUDED(mu_);
122
123 // Initialize the stream. This must be performed before entraining any other
124 // operations.
125 Stream &Init() LOCKS_EXCLUDED(mu_);
126
127 // Initializes timer t via the StreamExecutor.
128 Stream &InitTimer(Timer *t);
129
130 // Convenience wrapper around Init() and InitTimer().
131 Stream &InitWithTimer(Timer *t);
132
133 // Get or create a sub-stream from this stream. If there is any sub-stream in
134 // the pool that can be reused then just return this sub-stream. Otherwise
135 // create a new sub-stream.
136 //
137 // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
138 Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
139
140 // Return the sub-stream back to the host stream so that it can be reused
141 // later. Sub-streams that are !ok() will not be reused.
142 //
143 // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
144 void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
145
146 // Allocate temporary memories. The stream will deallocate them when blocked
147 // or destroyed.
148 template <typename T>
149 port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
150 AllocateTemporaryArray(uint64 element_count);
151
152 // Entrains onto the stream of operations: a kernel launch with the given
153 // (variadic) parameters for the invocation. These arguments can be things
154 // like DeviceMemory or primitive types such as int. What arguments you may
155 // pass to a given kernel are noted as the template parameters to the
156 // TypedKernel type that the machocc compiler generates.
157 //
158 // Template parameters:
159 // Params... The type list of formal parameters that the typed kernel
160 // expects, which is matched against Args...
161 // Args... The deduced type list for passed actual arguments
162 //
163 // Implementation: A compile-time compatibility check is performed that has
164 // some leniency versus an exact parameter pack match -- for example,
165 // `const DeviceMemory<T>` is considered "pack compatible" with a
166 // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
167 // perfect forwarding support without rvalue references. It also attempts to
168 // spit out helpful static_assert error traces with information as to the
169 // argument number and types that were mismatched.
170 template <typename... Params, typename... Args>
171 Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
172 const TypedKernel<Params...> &kernel, Args... args);
173
174 // Record a "start" event for the interval timer at this point in the
175 // stream's execution (relative to the previously and subsequently enqueued
176 // items in the stream's execution). Streams may be started/stopped multiple
177 // times.
178 Stream &ThenStartTimer(Timer *t);
179
180 // Record a "stop" event for the interval timer at this point in the
181 // stream's execution. See also Stream::ThenStartTimer.
182 Stream &ThenStopTimer(Timer *t);
183
184 // TODO(leary) If work is added to the stream that is being depended upon,
185 // then what? Have to describe what happens.
186 template <typename... Params>
ThenWaitFor(Stream * other,Params...more_streams)187 Stream &ThenWaitFor(Stream *other, Params... more_streams) {
188 return ThenWaitFor(more_streams...).ThenWaitFor(other);
189 }
190
191 // Create a dependency for this stream's next work on the other stream
192 // completing. Does not take ownership of other, and other must not be
193 // null.
194 //
195 // Checks that a stream does not wait for itself, and it is up to the
196 // user to guarantee that a stream does not come to wait on itself in a
197 // cyclic manner; in that case, behavior is undefined.
198 //
199 // N.B. Base recursion case for the variadic ThenWaitFor.
200 Stream &ThenWaitFor(Stream *other);
201
202 // Waits for all streams values in others.
203 // Checks that there is no shallow circular wait (i.e. that "this" is not in
204 // others)
205 template <typename P>
ThenWaitFor(P others)206 Stream &ThenWaitFor(P others) {
207 for (auto &stream : *others) {
208 CHECK_NE(stream.get(), this);
209 ThenWaitFor(stream.get());
210 }
211 return *this;
212 }
213
214 // Waits for an event object to be set.
215 // Note that ThenRecordEvent must have been called on the event before
216 // you call this function; otherwise the event will be considered complete
217 // and this wait will do nothing.
218 Stream &ThenWaitFor(Event *event);
219
220 // Inserts the specified event into the end of this stream. Once the stream
221 // has processed all events prior to the insertion point, the event will be
222 // marked as completed.
223 // The stream does not take ownership of event - meaning that event's lifetime
224 // must extend past the point at which it is marked complete!
225 Stream &ThenRecordEvent(Event *event);
226
227 ////////////////
228 // DNN support
229 //
230 // See DnnSupport::* for comments on the following methods.
231
232 Stream &ThenBatchNormalizationForward(
233 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
234 const DeviceMemory<float> &offset,
235 const DeviceMemory<float> &estimated_mean,
236 const DeviceMemory<float> &estimated_variance,
237 const dnn::BatchDescriptor &x_desc,
238 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
239 DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
240 DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
241 DeviceMemory<float> *saved_inv_var, bool is_training,
242 std::function<const DeviceMemory<float> &()> var_to_inv_var,
243 std::function<void()> inv_var_to_var);
244
245 Stream &ThenBatchNormalizationBackward(
246 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
247 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
248 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
249 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
250 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
251 DeviceMemory<float> *offset_backprop);
252
253 Stream &ThenBatchNormalizationForward(
254 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
255 const DeviceMemory<float> &offset,
256 const DeviceMemory<float> &estimated_mean,
257 const DeviceMemory<float> &estimated_variance,
258 const dnn::BatchDescriptor &x_desc,
259 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
260 DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
261 DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
262 DeviceMemory<float> *saved_inv_var, bool is_training,
263 std::function<const DeviceMemory<float> &()> var_to_inv_var,
264 std::function<void()> inv_var_to_var);
265
266 Stream &ThenBatchNormalizationBackward(
267 const DeviceMemory<Eigen::half> &y_backprop,
268 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
269 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
270 const dnn::BatchDescriptor &x_desc,
271 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
272 DeviceMemory<Eigen::half> *x_backprop,
273 DeviceMemory<float> *scale_backprop,
274 DeviceMemory<float> *offset_backprop);
275
276 Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
277 const DeviceMemory<float> &input_data,
278 const dnn::FilterDescriptor &filter_descriptor,
279 const DeviceMemory<float> &filter_data,
280 const dnn::ConvolutionDescriptor &convolution_descriptor,
281 const dnn::BatchDescriptor &output_descriptor,
282 DeviceMemory<float> *output);
283
284 Stream &ThenConvolveQuantized(
285 const dnn::BatchDescriptor &input_descriptor,
286 const DeviceMemory<float> &input_data,
287 const dnn::FilterDescriptor &filter_descriptor,
288 const DeviceMemory<int8> &filter_coefficients,
289 const DeviceMemory<float> &coefficient_scales,
290 const dnn::ConvolutionDescriptor &convolution_descriptor,
291 const dnn::BatchDescriptor &output_descriptor,
292 DeviceMemory<float> *output_data);
293
294 Stream &ThenConvolveQuantized(
295 const dnn::BatchDescriptor &input_descriptor,
296 const DeviceMemory<float> &input_data,
297 const dnn::FilterDescriptor &filter_descriptor,
298 const DeviceMemory<int16> &filter_coefficients,
299 const DeviceMemory<float> &coefficient_scales,
300 const dnn::ConvolutionDescriptor &convolution_descriptor,
301 const dnn::BatchDescriptor &output_descriptor,
302 DeviceMemory<float> *output_data);
303
304 Stream &ThenConvolveWithAlgorithm(
305 const dnn::BatchDescriptor &input_descriptor,
306 const DeviceMemory<double> &input_data,
307 const dnn::FilterDescriptor &filter_descriptor,
308 const DeviceMemory<double> &filter_data,
309 const dnn::ConvolutionDescriptor &convolution_descriptor,
310 const dnn::BatchDescriptor &output_descriptor,
311 DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
312 const dnn::AlgorithmConfig &algorithm_config,
313 dnn::ProfileResult *output_profile_result);
314
315 Stream &ThenConvolveWithAlgorithm(
316 const dnn::BatchDescriptor &input_descriptor,
317 const DeviceMemory<float> &input_data,
318 const dnn::FilterDescriptor &filter_descriptor,
319 const DeviceMemory<float> &filter_data,
320 const dnn::ConvolutionDescriptor &convolution_descriptor,
321 const dnn::BatchDescriptor &output_descriptor,
322 DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
323 const dnn::AlgorithmConfig &algorithm_config,
324 dnn::ProfileResult *output_profile_result);
325
326 Stream &ThenConvolveWithAlgorithm(
327 const dnn::BatchDescriptor &input_descriptor,
328 const DeviceMemory<Eigen::half> &input_data,
329 const dnn::FilterDescriptor &filter_descriptor,
330 const DeviceMemory<Eigen::half> &filter_data,
331 const dnn::ConvolutionDescriptor &convolution_descriptor,
332 const dnn::BatchDescriptor &output_descriptor,
333 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
334 const dnn::AlgorithmConfig &algorithm_config,
335 dnn::ProfileResult *output_profile_result);
336
337 Stream &ThenFusedConvolveWithAlgorithm(
338 const dnn::BatchDescriptor &conv_input_descriptor,
339 const DeviceMemory<double> &conv_input_data, double conv_input_scale,
340 const dnn::FilterDescriptor &filter_descriptor,
341 const DeviceMemory<double> &filter_data,
342 const dnn::ConvolutionDescriptor &convolution_descriptor,
343 const DeviceMemory<double> &side_input_data, double side_input_scale,
344 const dnn::BatchDescriptor &bias_descriptor,
345 const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
346 const dnn::BatchDescriptor &output_descriptor,
347 DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
348 const dnn::AlgorithmConfig &algorithm_config,
349 dnn::ProfileResult *output_profile_result);
350
351 Stream &ThenFusedConvolveWithAlgorithm(
352 const dnn::BatchDescriptor &conv_input_descriptor,
353 const DeviceMemory<float> &conv_input_data, float conv_input_scale,
354 const dnn::FilterDescriptor &filter_descriptor,
355 const DeviceMemory<float> &filter_data,
356 const dnn::ConvolutionDescriptor &convolution_descriptor,
357 const DeviceMemory<float> &side_input_data, float side_input_scale,
358 const dnn::BatchDescriptor &bias_descriptor,
359 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
360 const dnn::BatchDescriptor &output_descriptor,
361 DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
362 const dnn::AlgorithmConfig &algorithm_config,
363 dnn::ProfileResult *output_profile_result);
364
365 Stream &ThenFusedConvolveWithAlgorithm(
366 const dnn::BatchDescriptor &conv_input_descriptor,
367 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
368 const dnn::FilterDescriptor &filter_descriptor,
369 const DeviceMemory<Eigen::half> &filter_data,
370 const dnn::ConvolutionDescriptor &convolution_descriptor,
371 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
372 const dnn::BatchDescriptor &bias_descriptor,
373 const DeviceMemory<Eigen::half> &biases,
374 dnn::ActivationMode activation_mode,
375 const dnn::BatchDescriptor &output_descriptor,
376 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
377 const dnn::AlgorithmConfig &algorithm_config,
378 dnn::ProfileResult *output_profile_result);
379
380 Stream &ThenFusedConvolveWithAlgorithm(
381 const dnn::BatchDescriptor &conv_input_descriptor,
382 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
383 const dnn::FilterDescriptor &filter_descriptor,
384 const DeviceMemory<int8> &filter_data,
385 const dnn::ConvolutionDescriptor &convolution_descriptor,
386 const DeviceMemory<int8> &side_input_data, float side_input_scale,
387 const dnn::BatchDescriptor &bias_descriptor,
388 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
389 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
390 ScratchAllocator *scratch_allocator,
391 const dnn::AlgorithmConfig &algorithm_config,
392 dnn::ProfileResult *output_profile_result);
393
394 Stream &ThenSeparableConvolve(
395 const dnn::BatchDescriptor &input_descriptor,
396 const DeviceMemory<float> &input_data,
397 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
398 const DeviceMemory<float> &first_weights,
399 const DeviceMemory<float> &second_weights,
400 const dnn::ConvolutionDescriptor &convolution_descriptor,
401 const dnn::BatchDescriptor &output_descriptor,
402 DeviceMemory<float> *output);
403
404 Stream &ThenConvolveBackwardDataWithAlgorithm(
405 const dnn::FilterDescriptor &filter_descriptor,
406 const DeviceMemory<double> &filter_data,
407 const dnn::BatchDescriptor &output_descriptor,
408 DeviceMemory<double> backward_output_data,
409 const dnn::ConvolutionDescriptor &convolution_descriptor,
410 const dnn::BatchDescriptor &input_descriptor,
411 DeviceMemory<double> *backward_input_data,
412 ScratchAllocator *scratch_allocator,
413 const dnn::AlgorithmConfig &algorithm_config,
414 dnn::ProfileResult *output_profile_result);
415
416 Stream &ThenConvolveBackwardDataWithAlgorithm(
417 const dnn::FilterDescriptor &filter_descriptor,
418 const DeviceMemory<float> &filter_data,
419 const dnn::BatchDescriptor &output_descriptor,
420 DeviceMemory<float> backward_output_data,
421 const dnn::ConvolutionDescriptor &convolution_descriptor,
422 const dnn::BatchDescriptor &input_descriptor,
423 DeviceMemory<float> *backward_input_data,
424 ScratchAllocator *scratch_allocator,
425 const dnn::AlgorithmConfig &algorithm_config,
426 dnn::ProfileResult *output_profile_result);
427
428 Stream &ThenConvolveBackwardDataWithAlgorithm(
429 const dnn::FilterDescriptor &filter_descriptor,
430 const DeviceMemory<Eigen::half> &filter_data,
431 const dnn::BatchDescriptor &output_descriptor,
432 DeviceMemory<Eigen::half> backward_output_data,
433 const dnn::ConvolutionDescriptor &convolution_descriptor,
434 const dnn::BatchDescriptor &input_descriptor,
435 DeviceMemory<Eigen::half> *backward_input_data,
436 ScratchAllocator *scratch_allocator,
437 const dnn::AlgorithmConfig &algorithm_config,
438 dnn::ProfileResult *output_profile_result);
439
440 Stream &ThenConvolveBackwardFilterWithAlgorithm(
441 const dnn::BatchDescriptor &input_descriptor,
442 const DeviceMemory<double> &input_data,
443 const dnn::BatchDescriptor &output_descriptor,
444 DeviceMemory<double> backward_output_data,
445 const dnn::ConvolutionDescriptor &convolution_descriptor,
446 const dnn::FilterDescriptor &filter_descriptor,
447 DeviceMemory<double> *backward_filter_data,
448 ScratchAllocator *scratch_allocator,
449 const dnn::AlgorithmConfig &algorithm_config,
450 dnn::ProfileResult *output_profile_result);
451
452 Stream &ThenConvolveBackwardFilterWithAlgorithm(
453 const dnn::BatchDescriptor &input_descriptor,
454 const DeviceMemory<float> &input_data,
455 const dnn::BatchDescriptor &output_descriptor,
456 DeviceMemory<float> backward_output_data,
457 const dnn::ConvolutionDescriptor &convolution_descriptor,
458 const dnn::FilterDescriptor &filter_descriptor,
459 DeviceMemory<float> *backward_filter_data,
460 ScratchAllocator *scratch_allocator,
461 const dnn::AlgorithmConfig &algorithm_config,
462 dnn::ProfileResult *output_profile_result);
463
464 Stream &ThenConvolveBackwardFilterWithAlgorithm(
465 const dnn::BatchDescriptor &input_descriptor,
466 const DeviceMemory<Eigen::half> &input_data,
467 const dnn::BatchDescriptor &output_descriptor,
468 DeviceMemory<Eigen::half> backward_output_data,
469 const dnn::ConvolutionDescriptor &convolution_descriptor,
470 const dnn::FilterDescriptor &filter_descriptor,
471 DeviceMemory<Eigen::half> *backward_filter_data,
472 ScratchAllocator *scratch_allocator,
473 const dnn::AlgorithmConfig &algorithm_config,
474 dnn::ProfileResult *output_profile_result);
475
476 Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
477 const DeviceMemory<double> &input_data,
478 const dnn::BatchDescriptor &bias_descriptor,
479 DeviceMemory<double> *backward_bias_data);
480
481 Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
482 const DeviceMemory<float> &input_data,
483 const dnn::BatchDescriptor &bias_descriptor,
484 DeviceMemory<float> *backward_bias_data);
485
486 Stream &ThenConvolveBackwardBias(
487 const dnn::BatchDescriptor &input_descriptor,
488 const DeviceMemory<Eigen::half> &input_data,
489 const dnn::BatchDescriptor &bias_descriptor,
490 DeviceMemory<Eigen::half> *backward_bias_data);
491
492 Stream &ThenMatMul(const DeviceMemory<float> &input_data,
493 const DeviceMemory<float> &weights,
494 const dnn::BatchDescriptor &input_dimensions,
495 const dnn::BatchDescriptor &output_dimensions,
496 DeviceMemory<float> *output_data);
497
498 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
499 const DeviceMemory<int8> &weights,
500 const DeviceMemory<float> &weight_scales,
501 const dnn::BatchDescriptor &input_dimensions,
502 const dnn::BatchDescriptor &output_dimensions,
503 DeviceMemory<float> *output_data);
504
505 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
506 const DeviceMemory<int16> &weights,
507 const DeviceMemory<float> &weight_scales,
508 const dnn::BatchDescriptor &input_dimensions,
509 const dnn::BatchDescriptor &output_dimensions,
510 DeviceMemory<float> *output_data);
511
512 Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
513 const DeviceMemory<float> &biases,
514 const dnn::BatchDescriptor &dimensions,
515 DeviceMemory<float> *output_data);
516
517 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
518 const dnn::BatchDescriptor &input_dimensions,
519 const DeviceMemory<double> &input_data,
520 const dnn::BatchDescriptor &output_dimensions,
521 DeviceMemory<double> *output_data,
522 ScratchAllocator *workspace_allocator = nullptr);
523
524 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
525 const dnn::BatchDescriptor &input_dimensions,
526 const DeviceMemory<float> &input_data,
527 const dnn::BatchDescriptor &output_dimensions,
528 DeviceMemory<float> *output_data,
529 ScratchAllocator *workspace_allocator = nullptr);
530
531 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
532 const dnn::BatchDescriptor &input_dimensions,
533 const DeviceMemory<Eigen::half> &input_data,
534 const dnn::BatchDescriptor &output_dimensions,
535 DeviceMemory<Eigen::half> *output_data,
536 ScratchAllocator *workspace_allocator = nullptr);
537
538 Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
539 const dnn::BatchDescriptor &input_dimensions,
540 const DeviceMemory<int8> &input_data,
541 const dnn::BatchDescriptor &output_dimensions,
542 DeviceMemory<int8> *output_data,
543 ScratchAllocator *workspace_allocator = nullptr);
544
545 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
546 const dnn::BatchDescriptor &input_dimensions,
547 const DeviceMemory<double> &input_data,
548 const dnn::BatchDescriptor &output_dimensions,
549 const DeviceMemory<double> &output_data,
550 const DeviceMemory<double> &input_diff_data,
551 DeviceMemory<double> *output_diff_data,
552 ScratchAllocator *workspace_allocator = nullptr);
553
554 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
555 const dnn::BatchDescriptor &input_dimensions,
556 const DeviceMemory<float> &input_data,
557 const dnn::BatchDescriptor &output_dimensions,
558 const DeviceMemory<float> &output_data,
559 const DeviceMemory<float> &input_diff_data,
560 DeviceMemory<float> *output_diff_data,
561 ScratchAllocator *workspace_allocator = nullptr);
562
563 Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
564 const dnn::BatchDescriptor &input_dimensions,
565 const DeviceMemory<Eigen::half> &input_data,
566 const dnn::BatchDescriptor &output_dimensions,
567 const DeviceMemory<Eigen::half> &output_data,
568 const DeviceMemory<Eigen::half> &input_diff_data,
569 DeviceMemory<Eigen::half> *output_diff_data,
570 ScratchAllocator *workspace_allocator = nullptr);
571
572 Stream &ThenNormalizeWithDimensions(
573 const dnn::NormalizeDescriptor &normalize_descriptor,
574 const dnn::BatchDescriptor &dimensions,
575 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
576
577 Stream &ThenNormalizeBackwardWithDimensions(
578 const dnn::NormalizeDescriptor &normalize_descriptor,
579 const dnn::BatchDescriptor &dimensions,
580 const DeviceMemory<float> &raw_data,
581 const DeviceMemory<float> &normalized_data,
582 const DeviceMemory<float> &normalized_variable_gradient,
583 DeviceMemory<float> *raw_variable_gradient,
584 ScratchAllocator *workspace_allocator = nullptr);
585
586 Stream &ThenActivate(dnn::ActivationMode activation_mode,
587 const dnn::BatchDescriptor &dimensions,
588 const DeviceMemory<float> &input_data,
589 DeviceMemory<float> *output_data);
590
591 // Same as ThenActivate, but also takes an options argument that can be used
592 // for platform-specific option flags.
593 Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
594 const dnn::BatchDescriptor &dimensions,
595 const DeviceMemory<float> &input_data,
596 DeviceMemory<float> *output_data,
597 uint64 options);
598
599 Stream &ThenDepthConcatenate(
600 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
601 port::ArraySlice<const DeviceMemory<float> *> input_data,
602 DeviceMemory<float> *output_data);
603
604 Stream &ThenSpaceConcatenate(
605 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
606 port::ArraySlice<const DeviceMemory<float> *> input_data,
607 DeviceMemory<float> *output_data,
608 dnn::SpaceConcatenateMode concat_direction);
609
610 // Change the layout of the data by shrinking one dimension (or set of
611 // dimensions) and growing another dimension (or set of dimensions), while
612 // keeping the total number of data elements constant, and maintaining the
613 // current data ordering.
614 Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
615 const DeviceMemory<float> &input_data,
616 const dnn::BatchDescriptor &output_dimensions,
617 DeviceMemory<float> *output_data);
618
619 // Depth to space takes an X by Y image with depth D*M² and changes it to an
620 // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
621 // the input image is changed to an MxM contiguous area in the output image,
622 // with the values being laid out in raster order specified by
623 // DepthToSpaceLayout, and will have a new depth of D.
624 // See the DoDepthToSpace comment for more information.
625 Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
626 const DeviceMemory<float> &input_data,
627 const dnn::DepthToSpaceLayout &depth_to_space_layout,
628 const int sqrt_depth_reduction,
629 DeviceMemory<float> *output_data);
630
631 // Space to depth is the inverse of depth to space. Space to depth takes each
632 // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
633 // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
634 // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
635 // data elements is not changed.
636 Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
637 const DeviceMemory<float> &input_data,
638 const dnn::DepthToSpaceLayout &space_to_depth_layout,
639 const int sqrt_depth_increase,
640 DeviceMemory<float> *output_data);
641
642 Stream &ThenElementwiseOperate(
643 dnn::ElementwiseOperation operation,
644 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
645 port::ArraySlice<const DeviceMemory<float> *> input_data,
646 const dnn::BatchDescriptor &output_dimensions,
647 DeviceMemory<float> *output_data);
648
649 Stream &ThenElementwiseOperateScaledQuantized(
650 dnn::ElementwiseOperation operation,
651 port::ArraySlice<int> input_multiplicands, int output_divisor,
652 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
653 port::ArraySlice<const DeviceMemory<float> *> input_data,
654 const dnn::BatchDescriptor &output_dimensions,
655 DeviceMemory<float> *output_data);
656
657 Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
658 const DeviceMemory<float> &input_data, int64 left_pad,
659 int64 right_pad, int64 top_pad, int64 bottom_pad,
660 DeviceMemory<float> *output_data);
661
662 Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
663 const DeviceMemory<float> &input_data, int64 left_trim,
664 int64 right_trim, int64 top_trim, int64 bottom_trim,
665 DeviceMemory<float> *output_data);
666
667 // Grows the input tensor by replicating the X and Y dimensions. The batch and
668 // depth/feature_map dimensions are unchanged. Currently, the input tensor is
669 // limited to X=1 and Y=1.
670 Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
671 const DeviceMemory<float> &input_data,
672 int64 replicate_x, int64 replicate_y,
673 DeviceMemory<float> *output_data);
674
675 // See DnnSupport::DoMemcpyD2HQuantized.
676 Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
677 dnn::QuantizedActivationMode mode,
678 void *host_dst, uint64 size);
679
680 // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
681 // and uses the Quantization trait to call the generic version of
682 // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
683 template <typename ElementType>
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,port::MutableArraySlice<ElementType> host_dst)684 Stream &ThenMemcpyD2HQuantized(
685 const DeviceMemory<float> &gpu_unquantized_src,
686 port::MutableArraySlice<ElementType> host_dst) {
687 return ThenMemcpyD2HQuantized(
688 gpu_unquantized_src, Quantization<ElementType>::kModeId,
689 host_dst.data(), host_dst.size() * sizeof(ElementType));
690 }
691
692 // See DnnSupport::DoMemcpyH2DQuantized.
693 Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
694 dnn::QuantizedActivationMode mode,
695 DeviceMemory<float> *gpu_unquantized_dst);
696
697 // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
698 // and uses the Quantization trait to call the generic version of
699 // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
700 template <typename ElementType>
ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,DeviceMemory<float> * gpu_unquantized_dst)701 Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
702 DeviceMemory<float> *gpu_unquantized_dst) {
703 return ThenMemcpyH2DQuantized(
704 host_src.data(), host_src.size() * sizeof(ElementType),
705 Quantization<ElementType>::kModeId, gpu_unquantized_dst);
706 }
707
708 // See DnnSupport::DoCopyHostBuffer2Device.
709 Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
710 DeviceMemory<float> *gpu_unquantized_dst);
711
712 // See DnnSupport::DoCopyDevice2HostBuffer.
713 Stream &ThenCopyDevice2HostBuffer(
714 const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
715
716 /////////////////
717 // BLAS support
718
719 // See BlasSupport::DoBlasAsum.
720 Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
721 int incx, DeviceMemory<float> *result);
722 Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
723 int incx, DeviceMemory<double> *result);
724 Stream &ThenBlasAsum(uint64 elem_count,
725 const DeviceMemory<std::complex<float>> &x, int incx,
726 DeviceMemory<float> *result);
727 Stream &ThenBlasAsum(uint64 elem_count,
728 const DeviceMemory<std::complex<double>> &x, int incx,
729 DeviceMemory<double> *result);
730
731 // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
732 // present in DeviceMemory, it must be an execution-time constant (i.e. a
733 // value
734 // that the stream does not change or populate during the course of
735 // execution). The value is effectively captured at stream-enqueue time.
736 Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
737 const DeviceMemory<float> &x, int incx,
738 DeviceMemory<float> *y, int incy);
739 Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
740 const DeviceMemory<double> &x, int incx,
741 DeviceMemory<double> *y, int incy);
742 Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
743 const DeviceMemory<std::complex<float>> &x, int incx,
744 DeviceMemory<std::complex<float>> *y, int incy);
745 Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
746 const DeviceMemory<std::complex<double>> &x, int incx,
747 DeviceMemory<std::complex<double>> *y, int incy);
748
749 // See BlasSupport::DoBlasCopy.
750 Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
751 int incx, DeviceMemory<float> *y, int incy);
752 Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
753 int incx, DeviceMemory<double> *y, int incy);
754 Stream &ThenBlasCopy(uint64 elem_count,
755 const DeviceMemory<std::complex<float>> &x, int incx,
756 DeviceMemory<std::complex<float>> *y, int incy);
757 Stream &ThenBlasCopy(uint64 elem_count,
758 const DeviceMemory<std::complex<double>> &x, int incx,
759 DeviceMemory<std::complex<double>> *y, int incy);
760
761 // See BlasSupport::DoBlasDot.
762 Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
763 const DeviceMemory<float> &y, int incy,
764 DeviceMemory<float> *result);
765 Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
766 int incx, const DeviceMemory<double> &y, int incy,
767 DeviceMemory<double> *result);
768
769 // See BlasSupport::DoBlasDotc.
770 Stream &ThenBlasDotc(uint64 elem_count,
771 const DeviceMemory<std::complex<float>> &x, int incx,
772 const DeviceMemory<std::complex<float>> &y, int incy,
773 DeviceMemory<std::complex<float>> *result);
774 Stream &ThenBlasDotc(uint64 elem_count,
775 const DeviceMemory<std::complex<double>> &x, int incx,
776 const DeviceMemory<std::complex<double>> &y, int incy,
777 DeviceMemory<std::complex<double>> *result);
778
779 // See BlasSupport::DoBlasDotu.
780 Stream &ThenBlasDotu(uint64 elem_count,
781 const DeviceMemory<std::complex<float>> &x, int incx,
782 const DeviceMemory<std::complex<float>> &y, int incy,
783 DeviceMemory<std::complex<float>> *result);
784 Stream &ThenBlasDotu(uint64 elem_count,
785 const DeviceMemory<std::complex<double>> &x, int incx,
786 const DeviceMemory<std::complex<double>> &y, int incy,
787 DeviceMemory<std::complex<double>> *result);
788
789 // See BlasSupport::DoBlasNrm2.
790 Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
791 int incx, DeviceMemory<float> *result);
792 Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
793 int incx, DeviceMemory<double> *result);
794 Stream &ThenBlasNrm2(uint64 elem_count,
795 const DeviceMemory<std::complex<float>> &x, int incx,
796 DeviceMemory<float> *result);
797 Stream &ThenBlasNrm2(uint64 elem_count,
798 const DeviceMemory<std::complex<double>> &x, int incx,
799 DeviceMemory<double> *result);
800
801 // See BlasSupport::DoBlasRot.
802 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
803 DeviceMemory<float> *y, int incy, float c, float s);
804 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
805 DeviceMemory<double> *y, int incy, double c, double s);
806 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
807 int incx, DeviceMemory<std::complex<float>> *y, int incy,
808 float c, float s);
809 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
810 int incx, DeviceMemory<std::complex<double>> *y, int incy,
811 double c, double s);
812
813 // See BlasSupport::DoBlasRotg.
814 Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
815 DeviceMemory<float> *c, DeviceMemory<float> *s);
816 Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
817 DeviceMemory<double> *c, DeviceMemory<double> *s);
818 Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
819 DeviceMemory<std::complex<float>> *b,
820 DeviceMemory<float> *c,
821 DeviceMemory<std::complex<float>> *s);
822 Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
823 DeviceMemory<std::complex<double>> *b,
824 DeviceMemory<double> *c,
825 DeviceMemory<std::complex<double>> *s);
826
827 // See BlasSupport::DoBlasRotm.
828 Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
829 DeviceMemory<float> *y, int incy,
830 const DeviceMemory<float> ¶m);
831 Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
832 DeviceMemory<double> *y, int incy,
833 const DeviceMemory<double> ¶m);
834
835 // See BlasSupport::DoBlasRotmg.
836 Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
837 DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
838 DeviceMemory<float> *param);
839 Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
840 DeviceMemory<double> *x1,
841 const DeviceMemory<double> &y1,
842 DeviceMemory<double> *param);
843
844 // See BlasSupport::DoBlasScal.
845 Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
846 int incx);
847 Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
848 int incx);
849 Stream &ThenBlasScal(uint64 elem_count, float alpha,
850 DeviceMemory<std::complex<float>> *x, int incx);
851 Stream &ThenBlasScal(uint64 elem_count, double alpha,
852 DeviceMemory<std::complex<double>> *x, int incx);
853 Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
854 DeviceMemory<std::complex<float>> *x, int incx);
855 Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
856 DeviceMemory<std::complex<double>> *x, int incx);
857
858 // See BlasSupport::DoBlasSwap.
859 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
860 DeviceMemory<float> *y, int incy);
861 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
862 DeviceMemory<double> *y, int incy);
863 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
864 int incx, DeviceMemory<std::complex<float>> *y,
865 int incy);
866 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
867 int incx, DeviceMemory<std::complex<double>> *y,
868 int incy);
869
870 // See BlasSupport::DoBlasIamax.
871 Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
872 int incx, DeviceMemory<int> *result);
873 Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
874 int incx, DeviceMemory<int> *result);
875 Stream &ThenBlasIamax(uint64 elem_count,
876 const DeviceMemory<std::complex<float>> &x, int incx,
877 DeviceMemory<int> *result);
878 Stream &ThenBlasIamax(uint64 elem_count,
879 const DeviceMemory<std::complex<double>> &x, int incx,
880 DeviceMemory<int> *result);
881
882 // See BlasSupport::DoBlasIamin.
883 Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
884 int incx, DeviceMemory<int> *result);
885 Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
886 int incx, DeviceMemory<int> *result);
887 Stream &ThenBlasIamin(uint64 elem_count,
888 const DeviceMemory<std::complex<float>> &x, int incx,
889 DeviceMemory<int> *result);
890 Stream &ThenBlasIamin(uint64 elem_count,
891 const DeviceMemory<std::complex<double>> &x, int incx,
892 DeviceMemory<int> *result);
893
894 // See BlasSupport::DoBlasGbmv.
895 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
896 uint64 ku, float alpha, const DeviceMemory<float> &a,
897 int lda, const DeviceMemory<float> &x, int incx,
898 float beta, DeviceMemory<float> *y, int incy);
899 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
900 uint64 ku, double alpha, const DeviceMemory<double> &a,
901 int lda, const DeviceMemory<double> &x, int incx,
902 double beta, DeviceMemory<double> *y, int incy);
903 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
904 uint64 ku, std::complex<float> alpha,
905 const DeviceMemory<std::complex<float>> &a, int lda,
906 const DeviceMemory<std::complex<float>> &x, int incx,
907 std::complex<float> beta,
908 DeviceMemory<std::complex<float>> *y, int incy);
909 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
910 uint64 ku, std::complex<double> alpha,
911 const DeviceMemory<std::complex<double>> &a, int lda,
912 const DeviceMemory<std::complex<double>> &x, int incx,
913 std::complex<double> beta,
914 DeviceMemory<std::complex<double>> *y, int incy);
915
916 // See BlasSupport::DoBlasGemv.
917 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
918 const DeviceMemory<float> &a, int lda,
919 const DeviceMemory<float> &x, int incx, float beta,
920 DeviceMemory<float> *y, int incy);
921 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
922 const DeviceMemory<double> &a, int lda,
923 const DeviceMemory<double> &x, int incx, double beta,
924 DeviceMemory<double> *y, int incy);
925 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
926 std::complex<float> alpha,
927 const DeviceMemory<std::complex<float>> &a, int lda,
928 const DeviceMemory<std::complex<float>> &x, int incx,
929 std::complex<float> beta,
930 DeviceMemory<std::complex<float>> *y, int incy);
931 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
932 std::complex<double> alpha,
933 const DeviceMemory<std::complex<double>> &a, int lda,
934 const DeviceMemory<std::complex<double>> &x, int incx,
935 std::complex<double> beta,
936 DeviceMemory<std::complex<double>> *y, int incy);
937
938 Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
939 float alpha, const DeviceMemory<float> &a,
940 int lda, const DeviceMemory<float> &x,
941 int incx, float beta,
942 DeviceMemory<float> *y, int incy,
943 blas::ProfileResult *output_profile_result);
944 Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
945 double alpha, const DeviceMemory<double> &a,
946 int lda, const DeviceMemory<double> &x,
947 int incx, double beta,
948 DeviceMemory<double> *y, int incy,
949 blas::ProfileResult *output_profile_result);
950 Stream &ThenBlasGemvWithProfiling(
951 blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
952 const DeviceMemory<std::complex<float>> &a, int lda,
953 const DeviceMemory<std::complex<float>> &x, int incx,
954 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
955 blas::ProfileResult *output_profile_result);
956 Stream &ThenBlasGemvWithProfiling(
957 blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
958 const DeviceMemory<std::complex<double>> &a, int lda,
959 const DeviceMemory<std::complex<double>> &x, int incx,
960 std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
961 int incy, blas::ProfileResult *output_profile_result);
962
963 // See BlasSupport::DoBlasGer.
964 Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
965 const DeviceMemory<float> &x, int incx,
966 const DeviceMemory<float> &y, int incy,
967 DeviceMemory<float> *a, int lda);
968 Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
969 const DeviceMemory<double> &x, int incx,
970 const DeviceMemory<double> &y, int incy,
971 DeviceMemory<double> *a, int lda);
972
973 // See BlasSupport::DoBlasGerc.
974 Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
975 const DeviceMemory<std::complex<float>> &x, int incx,
976 const DeviceMemory<std::complex<float>> &y, int incy,
977 DeviceMemory<std::complex<float>> *a, int lda);
978 Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
979 const DeviceMemory<std::complex<double>> &x, int incx,
980 const DeviceMemory<std::complex<double>> &y, int incy,
981 DeviceMemory<std::complex<double>> *a, int lda);
982
983 // See BlasSupport::DoBlasGeru.
984 Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
985 const DeviceMemory<std::complex<float>> &x, int incx,
986 const DeviceMemory<std::complex<float>> &y, int incy,
987 DeviceMemory<std::complex<float>> *a, int lda);
988 Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
989 const DeviceMemory<std::complex<double>> &x, int incx,
990 const DeviceMemory<std::complex<double>> &y, int incy,
991 DeviceMemory<std::complex<double>> *a, int lda);
992
993 // See BlasSupport::DoBlasHbmv.
994 Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
995 std::complex<float> alpha,
996 const DeviceMemory<std::complex<float>> &a, int lda,
997 const DeviceMemory<std::complex<float>> &x, int incx,
998 std::complex<float> beta,
999 DeviceMemory<std::complex<float>> *y, int incy);
1000 Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
1001 std::complex<double> alpha,
1002 const DeviceMemory<std::complex<double>> &a, int lda,
1003 const DeviceMemory<std::complex<double>> &x, int incx,
1004 std::complex<double> beta,
1005 DeviceMemory<std::complex<double>> *y, int incy);
1006
1007 // See BlasSupport::DoBlasHemv.
1008 Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
1009 std::complex<float> alpha,
1010 const DeviceMemory<std::complex<float>> &a, int lda,
1011 const DeviceMemory<std::complex<float>> &x, int incx,
1012 std::complex<float> beta,
1013 DeviceMemory<std::complex<float>> *y, int incy);
1014 Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
1015 std::complex<double> alpha,
1016 const DeviceMemory<std::complex<double>> &a, int lda,
1017 const DeviceMemory<std::complex<double>> &x, int incx,
1018 std::complex<double> beta,
1019 DeviceMemory<std::complex<double>> *y, int incy);
1020
1021 // See BlasSupport::DoBlasHer.
1022 Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
1023 const DeviceMemory<std::complex<float>> &x, int incx,
1024 DeviceMemory<std::complex<float>> *a, int lda);
1025 Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
1026 const DeviceMemory<std::complex<double>> &x, int incx,
1027 DeviceMemory<std::complex<double>> *a, int lda);
1028
1029 // See BlasSupport::DoBlasHer2.
1030 Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
1031 std::complex<float> alpha,
1032 const DeviceMemory<std::complex<float>> &x, int incx,
1033 const DeviceMemory<std::complex<float>> &y, int incy,
1034 DeviceMemory<std::complex<float>> *a, int lda);
1035 Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
1036 std::complex<double> alpha,
1037 const DeviceMemory<std::complex<double>> &x, int incx,
1038 const DeviceMemory<std::complex<double>> &y, int incy,
1039 DeviceMemory<std::complex<double>> *a, int lda);
1040
1041 // See BlasSupport::DoBlasHpmv.
1042 Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
1043 std::complex<float> alpha,
1044 const DeviceMemory<std::complex<float>> &ap,
1045 const DeviceMemory<std::complex<float>> &x, int incx,
1046 std::complex<float> beta,
1047 DeviceMemory<std::complex<float>> *y, int incy);
1048 Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
1049 std::complex<double> alpha,
1050 const DeviceMemory<std::complex<double>> &ap,
1051 const DeviceMemory<std::complex<double>> &x, int incx,
1052 std::complex<double> beta,
1053 DeviceMemory<std::complex<double>> *y, int incy);
1054
1055 // See BlasSupport::DoBlasHpr.
1056 Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
1057 const DeviceMemory<std::complex<float>> &x, int incx,
1058 DeviceMemory<std::complex<float>> *ap);
1059 Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
1060 const DeviceMemory<std::complex<double>> &x, int incx,
1061 DeviceMemory<std::complex<double>> *ap);
1062
1063 // See BlasSupport::DoBlasHpr2.
1064 Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
1065 std::complex<float> alpha,
1066 const DeviceMemory<std::complex<float>> &x, int incx,
1067 const DeviceMemory<std::complex<float>> &y, int incy,
1068 DeviceMemory<std::complex<float>> *ap);
1069 Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
1070 std::complex<double> alpha,
1071 const DeviceMemory<std::complex<double>> &x, int incx,
1072 const DeviceMemory<std::complex<double>> &y, int incy,
1073 DeviceMemory<std::complex<double>> *ap);
1074
1075 // See BlasSupport::DoBlasSbmv.
1076 Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
1077 const DeviceMemory<float> &a, int lda,
1078 const DeviceMemory<float> &x, int incx, float beta,
1079 DeviceMemory<float> *y, int incy);
1080 Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
1081 const DeviceMemory<double> &a, int lda,
1082 const DeviceMemory<double> &x, int incx, double beta,
1083 DeviceMemory<double> *y, int incy);
1084
1085 // See BlasSupport::DoBlasSpmv.
1086 Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
1087 const DeviceMemory<float> &ap,
1088 const DeviceMemory<float> &x, int incx, float beta,
1089 DeviceMemory<float> *y, int incy);
1090 Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
1091 const DeviceMemory<double> &ap,
1092 const DeviceMemory<double> &x, int incx, double beta,
1093 DeviceMemory<double> *y, int incy);
1094
1095 // See BlasSupport::DoBlasSpr.
1096 Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
1097 const DeviceMemory<float> &x, int incx,
1098 DeviceMemory<float> *ap);
1099 Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
1100 const DeviceMemory<double> &x, int incx,
1101 DeviceMemory<double> *ap);
1102
1103 // See BlasSupport::DoBlasSpr2.
1104 Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
1105 const DeviceMemory<float> &x, int incx,
1106 const DeviceMemory<float> &y, int incy,
1107 DeviceMemory<float> *ap);
1108 Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
1109 const DeviceMemory<double> &x, int incx,
1110 const DeviceMemory<double> &y, int incy,
1111 DeviceMemory<double> *ap);
1112
1113 // See BlasSupport::DoBlasSymv.
1114 Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
1115 const DeviceMemory<float> &a, int lda,
1116 const DeviceMemory<float> &x, int incx, float beta,
1117 DeviceMemory<float> *y, int incy);
1118 Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
1119 const DeviceMemory<double> &a, int lda,
1120 const DeviceMemory<double> &x, int incx, double beta,
1121 DeviceMemory<double> *y, int incy);
1122
1123 // See BlasSupport::DoBlasSyr.
1124 Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
1125 const DeviceMemory<float> &x, int incx,
1126 DeviceMemory<float> *a, int lda);
1127 Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
1128 const DeviceMemory<double> &x, int incx,
1129 DeviceMemory<double> *a, int lda);
1130
1131 // See BlasSupport::DoBlasSyr2.
1132 Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
1133 const DeviceMemory<float> &x, int incx,
1134 const DeviceMemory<float> &y, int incy,
1135 DeviceMemory<float> *a, int lda);
1136 Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
1137 const DeviceMemory<double> &x, int incx,
1138 const DeviceMemory<double> &y, int incy,
1139 DeviceMemory<double> *a, int lda);
1140
1141 // See BlasSupport::DoBlasTbmv.
1142 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1143 blas::Diagonal diag, uint64 n, uint64 k,
1144 const DeviceMemory<float> &a, int lda,
1145 DeviceMemory<float> *x, int incx);
1146 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1147 blas::Diagonal diag, uint64 n, uint64 k,
1148 const DeviceMemory<double> &a, int lda,
1149 DeviceMemory<double> *x, int incx);
1150 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1151 blas::Diagonal diag, uint64 n, uint64 k,
1152 const DeviceMemory<std::complex<float>> &a, int lda,
1153 DeviceMemory<std::complex<float>> *x, int incx);
1154 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1155 blas::Diagonal diag, uint64 n, uint64 k,
1156 const DeviceMemory<std::complex<double>> &a, int lda,
1157 DeviceMemory<std::complex<double>> *x, int incx);
1158
1159 // See BlasSupport::DoBlasTbsv.
1160 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1161 blas::Diagonal diag, uint64 n, uint64 k,
1162 const DeviceMemory<float> &a, int lda,
1163 DeviceMemory<float> *x, int incx);
1164 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1165 blas::Diagonal diag, uint64 n, uint64 k,
1166 const DeviceMemory<double> &a, int lda,
1167 DeviceMemory<double> *x, int incx);
1168 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1169 blas::Diagonal diag, uint64 n, uint64 k,
1170 const DeviceMemory<std::complex<float>> &a, int lda,
1171 DeviceMemory<std::complex<float>> *x, int incx);
1172 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1173 blas::Diagonal diag, uint64 n, uint64 k,
1174 const DeviceMemory<std::complex<double>> &a, int lda,
1175 DeviceMemory<std::complex<double>> *x, int incx);
1176
1177 // See BlasSupport::DoBlasTpmv.
1178 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1179 blas::Diagonal diag, uint64 n,
1180 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1181 int incx);
1182 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1183 blas::Diagonal diag, uint64 n,
1184 const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1185 int incx);
1186 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1187 blas::Diagonal diag, uint64 n,
1188 const DeviceMemory<std::complex<float>> &ap,
1189 DeviceMemory<std::complex<float>> *x, int incx);
1190 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1191 blas::Diagonal diag, uint64 n,
1192 const DeviceMemory<std::complex<double>> &ap,
1193 DeviceMemory<std::complex<double>> *x, int incx);
1194
1195 // See BlasSupport::DoBlasTpsv.
1196 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1197 blas::Diagonal diag, uint64 n,
1198 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1199 int incx);
1200 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1201 blas::Diagonal diag, uint64 n,
1202 const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1203 int incx);
1204 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1205 blas::Diagonal diag, uint64 n,
1206 const DeviceMemory<std::complex<float>> &ap,
1207 DeviceMemory<std::complex<float>> *x, int incx);
1208 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1209 blas::Diagonal diag, uint64 n,
1210 const DeviceMemory<std::complex<double>> &ap,
1211 DeviceMemory<std::complex<double>> *x, int incx);
1212
1213 // See BlasSupport::DoBlasTrmv.
1214 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1215 blas::Diagonal diag, uint64 n,
1216 const DeviceMemory<float> &a, int lda,
1217 DeviceMemory<float> *x, int incx);
1218 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1219 blas::Diagonal diag, uint64 n,
1220 const DeviceMemory<double> &a, int lda,
1221 DeviceMemory<double> *x, int incx);
1222 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1223 blas::Diagonal diag, uint64 n,
1224 const DeviceMemory<std::complex<float>> &a, int lda,
1225 DeviceMemory<std::complex<float>> *x, int incx);
1226 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1227 blas::Diagonal diag, uint64 n,
1228 const DeviceMemory<std::complex<double>> &a, int lda,
1229 DeviceMemory<std::complex<double>> *x, int incx);
1230
1231 // See BlasSupport::DoBlasTrsv.
1232 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1233 blas::Diagonal diag, uint64 n,
1234 const DeviceMemory<float> &a, int lda,
1235 DeviceMemory<float> *x, int incx);
1236 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1237 blas::Diagonal diag, uint64 n,
1238 const DeviceMemory<double> &a, int lda,
1239 DeviceMemory<double> *x, int incx);
1240 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1241 blas::Diagonal diag, uint64 n,
1242 const DeviceMemory<std::complex<float>> &a, int lda,
1243 DeviceMemory<std::complex<float>> *x, int incx);
1244 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1245 blas::Diagonal diag, uint64 n,
1246 const DeviceMemory<std::complex<double>> &a, int lda,
1247 DeviceMemory<std::complex<double>> *x, int incx);
1248
1249 // See BlasSupport::DoBlasGemm.
1250 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1251 uint64 m, uint64 n, uint64 k, float alpha,
1252 const DeviceMemory<Eigen::half> &a, int lda,
1253 const DeviceMemory<Eigen::half> &b, int ldb,
1254 float beta, DeviceMemory<Eigen::half> *c,
1255 int ldc);
1256 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1257 uint64 m, uint64 n, uint64 k, float alpha,
1258 const DeviceMemory<float> &a, int lda,
1259 const DeviceMemory<float> &b, int ldb,
1260 float beta, DeviceMemory<float> *c, int ldc);
1261 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1262 uint64 m, uint64 n, uint64 k, double alpha,
1263 const DeviceMemory<double> &a, int lda,
1264 const DeviceMemory<double> &b, int ldb,
1265 double beta, DeviceMemory<double> *c, int ldc);
1266 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1267 uint64 m, uint64 n, uint64 k,
1268 std::complex<float> alpha,
1269 const DeviceMemory<std::complex<float>> &a,
1270 int lda,
1271 const DeviceMemory<std::complex<float>> &b,
1272 int ldb, std::complex<float> beta,
1273 DeviceMemory<std::complex<float>> *c, int ldc);
1274 TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1275 uint64 m, uint64 n, uint64 k,
1276 std::complex<double> alpha,
1277 const DeviceMemory<std::complex<double>> &a,
1278 int lda,
1279 const DeviceMemory<std::complex<double>> &b,
1280 int ldb, std::complex<double> beta,
1281 DeviceMemory<std::complex<double>> *c,
1282 int ldc);
1283
1284 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1285 blas::Transpose transb, uint64 m, uint64 n,
1286 uint64 k, float alpha,
1287 const DeviceMemory<Eigen::half> &a, int lda,
1288 const DeviceMemory<Eigen::half> &b, int ldb,
1289 float beta, DeviceMemory<Eigen::half> *c,
1290 int ldc,
1291 blas::ProfileResult *output_profile_result);
1292 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1293 blas::Transpose transb, uint64 m, uint64 n,
1294 uint64 k, float alpha,
1295 const DeviceMemory<float> &a, int lda,
1296 const DeviceMemory<float> &b, int ldb,
1297 float beta, DeviceMemory<float> *c, int ldc,
1298 blas::ProfileResult *output_profile_result);
1299 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1300 blas::Transpose transb, uint64 m, uint64 n,
1301 uint64 k, double alpha,
1302 const DeviceMemory<double> &a, int lda,
1303 const DeviceMemory<double> &b, int ldb,
1304 double beta, DeviceMemory<double> *c,
1305 int ldc,
1306 blas::ProfileResult *output_profile_result);
1307 Stream &ThenBlasGemmWithProfiling(
1308 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1309 uint64 k, std::complex<float> alpha,
1310 const DeviceMemory<std::complex<float>> &a, int lda,
1311 const DeviceMemory<std::complex<float>> &b, int ldb,
1312 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1313 blas::ProfileResult *output_profile_result);
1314 Stream &ThenBlasGemmWithProfiling(
1315 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1316 uint64 k, std::complex<double> alpha,
1317 const DeviceMemory<std::complex<double>> &a, int lda,
1318 const DeviceMemory<std::complex<double>> &b, int ldb,
1319 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1320 blas::ProfileResult *output_profile_result);
1321
1322 // See BlasSupport::DoBlasGemmWithAlgorithm.
1323 Stream &ThenBlasGemmWithAlgorithm(
1324 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1325 uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1326 const DeviceMemory<Eigen::half> &a, int lda,
1327 const DeviceMemory<Eigen::half> &b, int ldb,
1328 const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1329 int ldc, blas::ComputationType computation_type,
1330 blas::AlgorithmType algorithm,
1331 blas::ProfileResult *output_profile_result);
  // See BlasSupport::DoBlasGemmWithAlgorithm.
  //
  // Overloads cover int8 inputs (with int scalars/output), float, double,
  // std::complex<float> and std::complex<double> element types. `algorithm`
  // selects a backend-specific GEMM implementation for the requested
  // `computation_type`; `output_profile_result`, if non-null, presumably
  // receives timing for the selected algorithm -- see
  // BlasSupport::DoBlasGemmWithAlgorithm for the exact contract.
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<int> &alpha,
      const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
      int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<float> &alpha,
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
      int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<double> &alpha,
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
      int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      const HostOrDeviceScalar<std::complex<float>> &beta,
      DeviceMemory<std::complex<float>> *c, int ldc,
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      const HostOrDeviceScalar<std::complex<double>> &beta,
      DeviceMemory<std::complex<double>> *c, int ldc,
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
1374
  // See BlasSupport::DoBlasGemmBatched.
  //
  // Performs `batch_count` independent GEMMs; each of a/b/c is a slice of
  // per-batch device buffers. Overloads cover Eigen::half, float, double and
  // complex element types.
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, float alpha,
                              const port::ArraySlice<DeviceMemory<float> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<float> *> &b,
                              int ldb, float beta,
                              const port::ArraySlice<DeviceMemory<float> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, double alpha,
                              const port::ArraySlice<DeviceMemory<double> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<double> *> &b,
                              int ldb, double beta,
                              const port::ArraySlice<DeviceMemory<double> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count);
  // Variants of the batched GEMM above that take an explicit
  // `scratch_allocator` for any temporary device memory the implementation
  // needs (see BlasSupport::DoBlasGemmBatchedWithScratch).
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
      double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  // Strided batched GEMM: each of a/b/c is a single device buffer holding all
  // batches, with consecutive batches `stride_*` elements apart.
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
      int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
      int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
      int64 stride_c, int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
      int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
      float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
      int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
      int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
      double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
      int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      int64 stride_c, int batch_count);
  Stream &ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      int64 stride_c, int batch_count);
1482
  // See BlasSupport::DoBlasHemm.
  // Hermitian matrix-matrix multiply; complex-valued overloads only.
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasHerk.
  // Hermitian rank-k update; note alpha/beta are real scalars even though the
  // matrices are complex.
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       float beta, DeviceMemory<std::complex<float>> *c,
                       int ldc);
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       double beta, DeviceMemory<std::complex<double>> *c,
                       int ldc);

  // See BlasSupport::DoBlasHer2k.
  // Hermitian rank-2k update; alpha is complex but beta is real.
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        float beta, DeviceMemory<std::complex<float>> *c,
                        int ldc);
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        double beta, DeviceMemory<std::complex<double>> *c,
                        int ldc);
1522
  // See BlasSupport::DoBlasSymm.
  // Symmetric matrix-matrix multiply; overloads for float, double, and the
  // complex element types.
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &b, int ldb,
                       float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &b, int ldb,
                       double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyrk.
  // Symmetric rank-k update.
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha, const DeviceMemory<float> &a,
                       int lda, float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha, const DeviceMemory<double> &a,
                       int lda, double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyr2k.
  // Symmetric rank-2k update.
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, float alpha, const DeviceMemory<float> &a,
                        int lda, const DeviceMemory<float> &b, int ldb,
                        float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, double alpha, const DeviceMemory<double> &a,
                        int lda, const DeviceMemory<double> &b, int ldb,
                        double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc);
1584
  // See BlasSupport::DoBlasTrmm.
  // Triangular matrix-matrix multiply; `b` is read and overwritten in place.
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);

  // See BlasSupport::DoBlasTrsm.
  // Triangular solve with multiple right-hand sides; the solution overwrites
  // `b`.
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);
1624
  // See FftSupport::DoFft.
  // Overloads (by input/output element type): complex-to-complex,
  // real-to-complex, and complex-to-real, each in single and double
  // precision. The transform's configuration is carried by `plan`.
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<float> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<double> *output);
1642
  // Makes the RNG use the provided value as the basis for further generation.
  // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
  // sources of seed data if the default (high quality) sources are not
  // desired.
  // For most use cases, this function will not be necessary; each provided
  // back-end implementation will be appropriately seeded by default.
  // At a minimum 16 bytes of data are required in the seed buffer.
  //
  // To seed with good (non-reproducible) data:
  //   File* f = File::Open("/dev/random", "r");
  //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
  //   < error checking >
  //   stream.ThenSetRngSeed(seed_data, bytes_read);
  //
  // To seed with reproducible data:
  //   uint64_t seed_data[2] = { <data> };
  //   stream.ThenSetRngSeed(seed_data, 16);
  Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);

  // Populates the memory indicated by values with uniform-random-distribution
  // values. TODO(leary) seeding API/description
  //
  // Uses the type and size of the DeviceMemory to infer what data should be
  // populated.
  Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);

  // Populates the memory indicated by values with random values drawn from a
  // Gaussian distribution with the given mean and standard deviation.
  Stream &ThenPopulateRandGaussian(float mean, float stddev,
                                   DeviceMemory<float> *values);
  Stream &ThenPopulateRandGaussian(double mean, double stddev,
                                   DeviceMemory<double> *values);
1675
  // Entrain onto the stream: a memcpy to a host destination from a GPU source
  // of the given target size (in bytes). host_dst must be a pointer to host
  // memory allocated by StreamExecutor::HostMemoryAllocate or otherwise
  // allocated and then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Entrain onto the stream: a memcpy to a GPU destination from a host source
  // of the given target size (in bytes). host_src must be a pointer to host
  // memory allocated by StreamExecutor::HostMemoryAllocate or otherwise
  // allocated and then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                     uint64 size);
1689
1690 // Alternative interface for memcpying from device to host that takes an
1691 // array slice. Checks that the destination size can accommodate the host
1692 // slice size.
1693 template <typename T>
ThenMemcpyD2H(const DeviceMemory<T> & gpu_src,port::MutableArraySlice<T> host_dst)1694 Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1695 port::MutableArraySlice<T> host_dst) {
1696 auto host_size = host_dst.size() * sizeof(T);
1697 CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1698 return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1699 }
1700
1701 // Alternative interface for memcpying from host to device that takes an
1702 // array slice. Checks that the destination size can accommodate the host
1703 // slice size.
1704 template <typename T>
ThenMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemory<T> * gpu_dst)1705 Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
1706 DeviceMemory<T> *gpu_dst) {
1707 auto host_size = host_src.size() * sizeof(T);
1708 CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1709 return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1710 }
1711
  // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
  // of the given target size. gpu_src/dst must be pointers to GPU memory and
  // peer access must be enabled between their owning StreamExecutors.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Forwards to the device-to-device copy overload of ThenMemcpy -- useful for
  // ensuring that the host pointer isn't getting confused accidentally with a
  // device pointer if you're not doing metaprogramming against the API.
  Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
                        const DeviceMemoryBase &gpu_src, uint64 size) {
    return ThenMemcpy(gpu_dst, gpu_src, size);
  }
1725
  // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
  // The location must not be null.
  Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);

  // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
  // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible
  // by 4). The 32-bit `pattern` is written repeatedly across the region. The
  // location must not be null.
  Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
1734
  // Enqueue a forward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnForward for more details.
  //
  // Overloads are provided for Eigen::half, float and double element types.
  // When `is_training` is true, `reserve_space_allocator` presumably supplies
  // memory that must be preserved for the matching backward pass -- see
  // DnnSupport::DoRnnForward to confirm. `output_profile_result`, if
  // non-null, receives profiling information.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<Eigen::half> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<Eigen::half> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<Eigen::half> &input_c_data,
                         const DeviceMemory<Eigen::half> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<Eigen::half> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<Eigen::half> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<Eigen::half> *output_c_data,
                         bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<float> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<float> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<float> &input_c_data,
                         const DeviceMemory<float> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<float> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<float> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<float> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<double> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<double> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<double> &input_c_data,
                         const DeviceMemory<double> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<double> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<double> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<double> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);
1791
  // Enqueue a backward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnBackward for more details.
  //
  // Overloads are provided for Eigen::half, float and double element types.
  // The `*_backprop_data` output pointers presumably receive the computed
  // gradients, and `reserve_space_data` is the buffer produced by the
  // corresponding forward pass -- see DnnSupport::DoRnnBackward to confirm.
  Stream &ThenRnnBackward(
      const dnn::RnnDescriptor &rnn_desc,
      const dnn::RnnSequenceTensorDescriptor &input_desc,
      const DeviceMemory<Eigen::half> &input_data,
      const dnn::RnnStateTensorDescriptor &input_h_desc,
      const DeviceMemory<Eigen::half> &input_h_data,
      const dnn::RnnStateTensorDescriptor &input_c_desc,
      const DeviceMemory<Eigen::half> &input_c_data,
      const DeviceMemory<Eigen::half> &params,
      const dnn::RnnSequenceTensorDescriptor &output_desc,
      const DeviceMemory<Eigen::half> &output_data,
      const dnn::RnnStateTensorDescriptor &output_h_desc,
      const DeviceMemory<Eigen::half> &output_h_data,
      const dnn::RnnStateTensorDescriptor &output_c_desc,
      const DeviceMemory<Eigen::half> &output_c_data,
      const DeviceMemory<Eigen::half> &output_backprop_data,
      const DeviceMemory<Eigen::half> &output_h_backprop_data,
      const DeviceMemory<Eigen::half> &output_c_backprop_data,
      DeviceMemory<Eigen::half> *input_backprop_data,
      DeviceMemory<Eigen::half> *input_h_backprop_data,
      DeviceMemory<Eigen::half> *input_c_backprop_data,
      DeviceMemory<Eigen::half> *params_backprop_data,
      DeviceMemory<uint8> *reserve_space_data,
      ScratchAllocator *workspace_allocator,
      dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<float> &input_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<float> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<float> &input_c_data,
                          const DeviceMemory<float> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<float> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<float> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<float> &output_c_data,
                          const DeviceMemory<float> &output_backprop_data,
                          const DeviceMemory<float> &output_h_backprop_data,
                          const DeviceMemory<float> &output_c_backprop_data,
                          DeviceMemory<float> *input_backprop_data,
                          DeviceMemory<float> *input_h_backprop_data,
                          DeviceMemory<float> *input_c_backprop_data,
                          DeviceMemory<float> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<double> &input_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<double> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<double> &input_c_data,
                          const DeviceMemory<double> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<double> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<double> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<double> &output_c_data,
                          const DeviceMemory<double> &output_backprop_data,
                          const DeviceMemory<double> &output_h_backprop_data,
                          const DeviceMemory<double> &output_c_backprop_data,
                          DeviceMemory<double> *input_backprop_data,
                          DeviceMemory<double> *input_h_backprop_data,
                          DeviceMemory<double> *input_c_backprop_data,
                          DeviceMemory<double> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);
1869
  // Enqueue onto the stream an operation that transforms a tensor.
  // See DnnSupport::DoTransformTensor for more details.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);

  // The templated version of the above ThenTransformTensor. Useful when the
  // input and output types are statically known; the data-type tags are
  // derived from the template arguments via dnn::ToDataType.
  //
  // NOTE(review): this forwards six arguments, while the overload declared
  // above takes seven (including `scale`); it appears to target a different
  // overload than the one visible here -- confirm which overload is intended.
  template <typename InElemT, typename OutElemT>
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              const DeviceMemory<InElemT> &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              DeviceMemory<OutElemT> *output_data) {
    return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
                               input_data, output_desc,
                               dnn::ToDataType<OutElemT>(), output_data);
  }
1890
  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
  //
  // Returns an OK status if the blocking was successful and the stream is ok().
  // Otherwise returns an error describing why the blocking failed.
  port::Status BlockHostUntilDone() LOCKS_EXCLUDED(mu_);

  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
  //
  // Entrains onto the stream a function to be executed on the host at some
  // point in the future.
  // Async host callbacks DO NOT block the stream as device functions (or as
  // synchronous host callbacks). No synchronization is possible with
  // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is private due to the potential for undefined behavior with
  // synchronization using OpenCL user events.
  // The ONLY lifetime guarantee in these calls is that the StreamExecutor
  // parameter will still be valid - this Stream may not be!
  // Any callbacks requiring device API calls must use this method.
  Stream &ThenEnqueueOnBackgroundThread(
      std::function<void(StreamExecutor *)> task);
1917
  // Returns the (opaque) platform-specific backing object. Ownership is not
  // transferred to the caller; the pointer remains owned by (and valid for
  // the lifetime of) this Stream.
  internal::StreamInterface *implementation() { return implementation_.get(); }
1921
  // Entrains onto the stream a callback to the host (from the device).
  // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
  // never fail or its failure is inconsequential.
  //
  // This is kept for backward compatibility. Future code should use
  // ThenDoHostCallbackWithStatus and explicitly return a success status.
  // TODO(b/112125301): Eventually remove this method.
  Stream &ThenDoHostCallback(std::function<void()> callback);

  // Entrains onto the stream a callback to the host (from the device).
  // Host callbacks block/occupy the stream just as device functions
  // (execute one at a time, block later stream operations).
  // Whether the callback return status affects the result of BlockHostUntilDone
  // is platform-dependent.
  //
  // Behavior is undefined when synchronizing using OpenCL user events.
  // Behavior is undefined if host callbacks call device routines or insert
  // them into any stream.
  //
  // On certain platforms, ThenDoHostCallback is expected to have significant
  // negative effects on performance.
  Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
1944
  // Returns the StreamExecutor (parent object) associated with this stream.
  // CHECK-fails if the parent pointer is null.
  StreamExecutor *parent() const {
    CHECK(parent_ != nullptr);
    return parent_;
  }

  // Returns the (internal usage) temporary-memory-allocation manager
  // associated with this stream.
  internal::TemporaryMemoryManager *temporary_memory_manager();

  // Returns a debugging string "[stream=0x...,impl=0x...]".
  string DebugStreamPointers() const;
1957
 private:
  friend class host::HostBlas;  // for parent_.
  friend class host::HostFft;   // for parent_.
  friend class host::HostRng;   // for parent_.
  template <typename... Args>
  friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
  friend class ocl::CLBlas;    // for parent_.

  // Returns true iff the stream has been flipped into the error state (i.e.
  // a previously entrained operation failed). Takes a shared (reader) lock
  // on mu_ to read ok_.
  bool InErrorState() const LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    return !ok_;
  }
1970
1971 // Sets the error state if operation_retcode is false.
1972 // This is a useful shorthand for many stream routines.
CheckError(bool operation_retcode)1973 void CheckError(bool operation_retcode) LOCKS_EXCLUDED(mu_) {
1974 if (operation_retcode) {
1975 return;
1976 }
1977 mutex_lock lock(mu_);
1978 ok_ = false;
1979 }
1980
  // Checks the status and logs the error message, if any.
  void CheckStatus(port::Status status) LOCKS_EXCLUDED(mu_);

  // Shorthand for CheckError(false): unconditionally puts the stream into the
  // error state.
  void SetError() { CheckError(false /* = operation_retcode */); }

  // Puts the stream into the error state and warns that a DNN operation was
  // attempted without DNN support.
  void SetErrorAndLogNoDnnSupport() {
    SetError();
    LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
                    "without DNN support";
  }
1991
  // The StreamExecutor that supports the operation of this stream.
  // NOTE(review): raw pointer, so presumably not owned by this stream --
  // confirm lifetime is managed by the caller/executor.
  StreamExecutor *parent_;

  // The platform-dependent implementation that the StreamExecutor interface
  // delegates to. Owned by this stream.
  std::unique_ptr<internal::StreamInterface> implementation_;

  // mutex that guards the allocation / error state flags.
  // Mutable so that it can be obtained via const reader lock.
  mutable mutex mu_;

  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from 0 to 1 with a sanity check.
  // See StreamExecutor::AllocateStream.
  bool allocated_ GUARDED_BY(mu_);

  // Whether all operations have entrained successfully to the current program
  // point. Cleared (set false) by CheckError/SetError and never reset here.
  bool ok_ GUARDED_BY(mu_);

  // Sub-streams that are generated from this stream. Each element has a pointer
  // to sub-stream and a boolean value indicating if this substream is ready to
  // be reused.
  std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
      GUARDED_BY(mu_);

  // Streams can allocate temporary memories to help with work they enqueue
  // (e.g. for scratch memory spaces). This member tracks those allocations and
  // notes when they can be reclaimed -- reclamation is attempted when
  // BlockHostUntilDone() is called.
  internal::TemporaryMemoryManager temporary_memory_manager_;

  // Implementation of ThenConvolveBackwardBias that is shared by all types.
  template <typename T>
  Stream &ThenConvolveBackwardBiasImpl(
      const dnn::BatchDescriptor &input_descriptor,
      const DeviceMemory<T> &input_data,
      const dnn::BatchDescriptor &bias_descriptor,
      DeviceMemory<T> *backward_bias_data);

  // Streams are neither copyable nor assignable.
  SE_DISALLOW_COPY_AND_ASSIGN(Stream);
};
2034
2035 ////////////
2036 // Inlines
2037
// Allocates a temporary device array of `element_count` elements of type T,
// tracked (and eventually reclaimed) by this stream's TemporaryMemoryManager.
template <typename T>
inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
Stream::AllocateTemporaryArray(uint64 element_count) {
  return temporary_memory_manager_.AllocateArray<T>(element_count);
}
2043
// Accessor for the stream's temporary-memory manager; internal usage only.
inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
  return &temporary_memory_manager_;
}
2047
// Quantization<T> specializations: map a host element type to the
// dnn::QuantizedActivationMode of matching bit width. (Primary template is
// declared earlier in this file.)

template <>
struct Quantization<uint8> {
  // 8-bit quantized activations.
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k8Bit;
};

template <>
struct Quantization<uint16> {
  // 16-bit quantized activations.
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k16Bit;
};

template <>
struct Quantization<int32> {
  // 32-bit quantized activations.
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k32Bit;
};
2065
2066 } // namespace stream_executor
2067
2068 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
2069