1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
23
24 #include <complex>
25 #include <functional>
26 #include <memory>
27 #include <type_traits>
28
29 #include "absl/synchronization/mutex.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/stream_executor/blas.h"
33 #include "tensorflow/stream_executor/device_memory.h"
34 #include "tensorflow/stream_executor/dnn.h"
35 #include "tensorflow/stream_executor/event.h"
36 #include "tensorflow/stream_executor/fft.h"
37 #include "tensorflow/stream_executor/host_or_device_scalar.h"
38 #include "tensorflow/stream_executor/kernel.h"
39 #include "tensorflow/stream_executor/launch_dim.h"
40 #include "tensorflow/stream_executor/lib/array_slice.h"
41 #include "tensorflow/stream_executor/platform/port.h"
42 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
43 #include "tensorflow/stream_executor/temporary_memory_manager.h"
44
45 #if GOOGLE_CUDA
46 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
47 #endif // GOOGLE_CUDA
48
49 namespace stream_executor {
50
51 namespace host {
52 class HostBlas;
53 class HostFft;
54 class HostRng;
55 class HostTimer;
56 } // namespace host
57
58 namespace ocl {
59 class CLBlas;
60 } // namespace ocl
61
62 namespace internal {
63 class StreamInterface;
64 } // namespace internal
65
// Forward declarations so this header does not need the full definitions of
// the device-memory and timer types; only pointers/references appear below.
66 class DeviceMemoryBase;
67 template <typename ElemT>
68 class DeviceMemory;
69
70 class Timer;
71
72 namespace dnn {
73 class BatchDescriptor;
74 class FilterDescriptor;
75 class ConvolutionDescriptor;
76 class ProfileResult;
77 class AlgorithmDesc;
78 } // namespace dnn
79
80 class StreamExecutor;
81 class ScratchAllocator;
82
83 namespace detail {
84
85 // Helper class to prevent a template function argument from being deduced.
86 // This is identical to std::type_identity in C++20.
87 template <typename T>
88 struct NonDeduced {
89 using type = T;
90 };
91 template <typename T>
92 using NonDeducedType = typename NonDeduced<T>::type;
93
94 } // namespace detail
95
96 // Convert a type to the corresponding QuantizedActivationMode.
97 template <typename ElementType>
98 struct Quantization;
99
100 // Represents a stream of dependent computations on a GPU device.
101 //
102 // The operations within a stream execute linearly and asynchronously until
103 // BlockHostUntilDone() is invoked, which synchronously joins host code with
104 // the execution of the stream.
105 //
106 // If any given operation fails when entraining work for the stream, ok() will
107 // indicate that an error has occurred. After initialization, once a stream is
108 // !ok(), it will never be ok().
109 //
110 // Thread-safe post-initialization.
111 class Stream {
112 public:
113 // Instantiate a stream tied to parent as a platform executor. Work
114 // entrained onto this stream will be launched/managed on that
115 // StreamExecutor's platform.
116 explicit Stream(StreamExecutor *parent);
117
118 // Test only. Use an externally-populated value (like a mock) for the
119 // platform-specific stream implementation.
120 Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
121
122 // Deallocates any stream resources that the parent StreamExecutor has
123 // bestowed upon this object.
124 //
125 ~Stream();
126
127 // Returns whether any errors have occurred while entraining work for this
128 // stream.
ok()129 bool ok() const { return !InErrorState(); }
130
131 // Retrieves execution status back into the stream from the underlying
132 // implementation without blocking the stream.
133 //
134 // Normally, Stream::BlockHostUntilDone is used to get execution status.
135 // However, some devices use out-of-band mechanisms to ensure their streams
136 // have finished on-device work, without needing to block the streams. (These
137 // devices should also override AllowsSyncOnCompletion to return false.) For
138 // these devices, this method can be used after work is finished to retrieve
139 // execution status.
140 port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_);
141
142 // Initialize the stream. This must be performed before entraining any other
143 // operations.
144 Stream &Init() TF_LOCKS_EXCLUDED(mu_);
145
146 // Initializes timer t via the StreamExecutor.
147 Stream &InitTimer(Timer *t);
148
149 // Convenience wrapper around Init() and InitTimer().
150 Stream &InitWithTimer(Timer *t);
151
152 // Get or create a sub-stream from this stream. If there is any sub-stream in
153 // the pool that can be reused then just return this sub-stream. Otherwise
154 // create a new sub-stream.
155 //
156 // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
157 Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_);
158
159 // Return the sub-stream back to the host stream so that it can be reused
160 // later. Sub-streams that are !ok() will not be reused.
161 //
162 // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
163 void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_);
164
165 // Allocate temporary memories. The stream will deallocate them when blocked
166 // or destroyed.
167 template <typename T>
168 port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
169 AllocateTemporaryArray(uint64 element_count);
170
171 // Entrains onto the stream of operations: a kernel launch with the given
172 // (variadic) parameters for the invocation. These arguments can be things
173 // like DeviceMemory or primitive types such as int. What arguments you may
174 // pass to a given kernel are noted as the template parameters to the
175 // TypedKernel type that the machocc compiler generates.
176 //
177 // Template parameters:
178 // Params... The type list of formal parameters that the typed kernel
179 // expects, which is matched against Args...
180 // Args... The deduced type list for passed actual arguments
181 //
182 // Implementation: A compile-time compatibility check is performed that has
183 // some leniency versus an exact parameter pack match -- for example,
184 // `const DeviceMemory<T>` is considered "pack compatible" with a
185 // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
186 // perfect forwarding support without rvalue references. It also attempts to
187 // spit out helpful static_assert error traces with information as to the
188 // argument number and types that were mismatched.
189 template <typename... Params, typename... Args>
190 Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
191 const TypedKernel<Params...> &kernel, Args... args);
192
193 // Record a "start" event for the interval timer at this point in the
194 // stream's execution (relative to the previously and subsequently enqueued
195 // items in the stream's execution). Streams may be started/stopped multiple
196 // times.
197 Stream &ThenStartTimer(Timer *t);
198
199 // Record a "stop" event for the interval timer at this point in the
200 // stream's execution. See also Stream::ThenStartTimer.
201 Stream &ThenStopTimer(Timer *t);
202
203 // TODO(leary) If work is added to the stream that is being depended upon,
204 // then what? Have to describe what happens.
205 template <typename... Params>
ThenWaitFor(Stream * other,Params...more_streams)206 Stream &ThenWaitFor(Stream *other, Params... more_streams) {
207 return ThenWaitFor(more_streams...).ThenWaitFor(other);
208 }
209
210 // Create a dependency for this stream's next work on the other stream
211 // completing. Does not take ownership of other, and other must not be null.
212 //
213 //
214 // Checks that a stream does not wait for itself, and it is up to the
215 // user to guarantee that a stream does not come to wait on itself in a
216 // cyclic manner; in that case, behavior is undefined.
217 //
218 // N.B. Base recursion case for the variadic ThenWaitFor.
219 Stream &ThenWaitFor(Stream *other);
220
221 // Waits for all streams values in others.
222 // Checks that there is no shallow circular wait (i.e. that "this" is not in
223 // others)
224 template <typename P>
ThenWaitFor(P others)225 Stream &ThenWaitFor(P others) {
226 for (auto &stream : *others) {
227 CHECK_NE(stream.get(), this);
228 ThenWaitFor(stream.get());
229 }
230 return *this;
231 }
232
233 // Waits for an event object to be set.
234 // Note that ThenRecordEvent must have been called on the event before
235 // you call this function; otherwise the event will be considered complete
236 // and this wait will do nothing.
237 Stream &ThenWaitFor(Event *event);
238
239 // Inserts the specified event into the end of this stream. Once the stream
240 // has processed all events prior to the insertion point, the event will be
241 // marked as completed.
242 // The stream does not take ownership of event - meaning that event's lifetime
243 // must extend past the point at which it is marked complete!
244 Stream &ThenRecordEvent(Event *event);
245
246 ////////////////
247 // DNN support
248 //
249 // See DnnSupport::* for comments on the following methods.
250
251 Stream &ThenBatchNormalizationForward(
252 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
253 const DeviceMemory<float> &offset,
254 const DeviceMemory<float> &estimated_mean,
255 const DeviceMemory<float> &estimated_variance,
256 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
257 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
258 const double exponential_average_factor,
259 dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
260 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
261 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
262 bool is_training,
263 ScratchAllocator *reserve_space_allocator,
264 ScratchAllocator *workspace_allocator);
265
266 Stream &ThenBatchNormalizationBackward(
267 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
268 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
269 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
270 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
271 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
272 DeviceMemory<float> *offset_backprop,
273 DeviceMemory<uint8> *reserve_space_data,
274 ScratchAllocator *workspace_allocator);
275
// Half-precision (Eigen::half) variants of the batch-normalization entry
// points above; scale/offset and the statistics stay in float.
276 Stream &ThenBatchNormalizationForward(
277 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
278 const DeviceMemory<float> &offset,
279 const DeviceMemory<float> &estimated_mean,
280 const DeviceMemory<float> &estimated_variance,
281 const DeviceMemory<Eigen::half> &side_input,
282 const dnn::BatchDescriptor &x_desc,
283 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
284 const double exponential_average_factor,
285 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
286 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
287 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
288 bool is_training, ScratchAllocator *reserve_space_allocator,
289 ScratchAllocator *workspace_allocator);
290
291 Stream &ThenBatchNormalizationBackward(
292 const DeviceMemory<Eigen::half> &y_backprop,
293 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
294 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
295 const dnn::BatchDescriptor &x_desc,
296 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
297 DeviceMemory<Eigen::half> *x_backprop,
298 DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
299 DeviceMemory<uint8> *reserve_space_data,
300 ScratchAllocator *workspace_allocator);
301
302 Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
303 const DeviceMemory<float> &input_data,
304 const dnn::FilterDescriptor &filter_descriptor,
305 const DeviceMemory<float> &filter_data,
306 const dnn::ConvolutionDescriptor &convolution_descriptor,
307 const dnn::BatchDescriptor &output_descriptor,
308 DeviceMemory<float> *output);
309
// Quantized-filter convolutions: int8/int16 filter coefficients plus
// per-coefficient float scales.
310 Stream &ThenConvolveQuantized(
311 const dnn::BatchDescriptor &input_descriptor,
312 const DeviceMemory<float> &input_data,
313 const dnn::FilterDescriptor &filter_descriptor,
314 const DeviceMemory<int8> &filter_coefficients,
315 const DeviceMemory<float> &coefficient_scales,
316 const dnn::ConvolutionDescriptor &convolution_descriptor,
317 const dnn::BatchDescriptor &output_descriptor,
318 DeviceMemory<float> *output_data);
319
320 Stream &ThenConvolveQuantized(
321 const dnn::BatchDescriptor &input_descriptor,
322 const DeviceMemory<float> &input_data,
323 const dnn::FilterDescriptor &filter_descriptor,
324 const DeviceMemory<int16> &filter_coefficients,
325 const DeviceMemory<float> &coefficient_scales,
326 const dnn::ConvolutionDescriptor &convolution_descriptor,
327 const dnn::BatchDescriptor &output_descriptor,
328 DeviceMemory<float> *output_data);
329
330 template <typename InputType, typename OutputType>
ConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<InputType> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputType> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputType> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)331 port::Status ConvolveWithAlgorithm(
332 const dnn::BatchDescriptor &input_descriptor,
333 const DeviceMemory<InputType> &input_data,
334 const dnn::FilterDescriptor &filter_descriptor,
335 const DeviceMemory<InputType> &filter_data,
336 const dnn::ConvolutionDescriptor &convolution_descriptor,
337 const dnn::BatchDescriptor &output_descriptor,
338 DeviceMemory<OutputType> *output, ScratchAllocator *scratch_allocator,
339 const dnn::AlgorithmConfig &algorithm_config,
340 dnn::ProfileResult *output_profile_result) {
341 DeviceMemory<uint8> scratch_memory;
342 dnn::AlgorithmDesc algorithm_desc;
343 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
344 TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
345 dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
346 filter_descriptor, filter_data, output_descriptor, *output,
347 convolution_descriptor, algorithm_config, scratch_allocator,
348 &algorithm_desc, &scratch_memory));
349 return dnn->DoConvolve(
350 dnn::ConvolutionKind::FORWARD, dnn::ToDataType<InputType>::value,
351 dnn::ToDataType<OutputType>::value, this, input_descriptor,
352 input_data, filter_descriptor, filter_data, output_descriptor,
353 *output, convolution_descriptor, algorithm_desc, scratch_memory,
354 output_profile_result);
355 }
356 return port::UnimplementedError("DNN library is not found.");
357 }
358
359 template <typename InputType, typename OutputType>
ConvolveWithExecutionPlan(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<InputType> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputType> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputType> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & plan_config,dnn::ProfileResult * output_profile_result)360 port::Status ConvolveWithExecutionPlan(
361 const dnn::BatchDescriptor &input_descriptor,
362 const DeviceMemory<InputType> &input_data,
363 const dnn::FilterDescriptor &filter_descriptor,
364 const DeviceMemory<InputType> &filter_data,
365 const dnn::ConvolutionDescriptor &convolution_descriptor,
366 const dnn::BatchDescriptor &output_descriptor,
367 DeviceMemory<OutputType> *output, ScratchAllocator *scratch_allocator,
368 const dnn::AlgorithmConfig &plan_config,
369 dnn::ProfileResult *output_profile_result) {
370 #if GOOGLE_CUDA
371 dnn::DnnSupport *dnn = parent_->AsDnn();
372 if (dnn) {
373 gpu::CudnnSupport *cudnn_dnn = dynamic_cast<gpu::CudnnSupport *>(dnn);
374 return cudnn_dnn->DoConvolveWithExecutionPlan(
375 dnn::ConvolutionKind::FORWARD, dnn::ToDataType<InputType>::value,
376 dnn::ToDataType<OutputType>::value, this, input_descriptor,
377 input_data, filter_descriptor, filter_data, output_descriptor,
378 *output, convolution_descriptor, plan_config, scratch_allocator,
379 output_profile_result);
380 }
381 #endif // GOOGLE_CUDA
382 return port::UnimplementedError("DNN library is not found.");
383 }
384
385 template <typename InputT, typename ScaleT, typename SideInputT,
386 typename BiasT, typename OutputT>
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<InputT> & conv_input_data,ScaleT conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputT> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<SideInputT> & side_input_data,ScaleT side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<BiasT> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputT> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)387 port::Status FusedConvolveWithAlgorithm(
388 const dnn::BatchDescriptor &conv_input_descriptor,
389 const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale,
390 const dnn::FilterDescriptor &filter_descriptor,
391 const DeviceMemory<InputT> &filter_data,
392 const dnn::ConvolutionDescriptor &convolution_descriptor,
393 const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale,
394 const dnn::BatchDescriptor &bias_descriptor,
395 const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode,
396 const dnn::BatchDescriptor &output_descriptor,
397 DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator,
398 const dnn::AlgorithmConfig &algorithm_config,
399 dnn::ProfileResult *output_profile_result) {
400 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
401 return dnn->DoFusedConvolve(
402 this, dnn::ToDataType<InputT>::value,
403 dnn::ToDataType<SideInputT>::value, dnn::ToDataType<BiasT>::value,
404 dnn::ToDataType<OutputT>::value, conv_input_descriptor,
405 conv_input_data, conv_input_scale, filter_descriptor, filter_data,
406 convolution_descriptor, side_input_data, side_input_scale,
407 bias_descriptor, biases, activation_mode, output_descriptor, *output,
408 scratch_allocator, algorithm_config, output_profile_result);
409 }
410 return port::UnimplementedError("DNN library is not found.");
411 }
412
413 template <typename InputT, typename ScaleT, typename SideInputT,
414 typename BiasT, typename OutputT>
FusedConvolveWithExecutionPlan(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<InputT> & conv_input_data,ScaleT conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<InputT> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<SideInputT> & side_input_data,ScaleT side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<BiasT> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputT> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)415 port::Status FusedConvolveWithExecutionPlan(
416 const dnn::BatchDescriptor &conv_input_descriptor,
417 const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale,
418 const dnn::FilterDescriptor &filter_descriptor,
419 const DeviceMemory<InputT> &filter_data,
420 const dnn::ConvolutionDescriptor &convolution_descriptor,
421 const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale,
422 const dnn::BatchDescriptor &bias_descriptor,
423 const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode,
424 const dnn::BatchDescriptor &output_descriptor,
425 DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator,
426 const dnn::AlgorithmConfig &algorithm_config,
427 dnn::ProfileResult *output_profile_result) {
428 #if GOOGLE_CUDA
429 dnn::DnnSupport *dnn = parent_->AsDnn();
430 if (dnn) {
431 gpu::CudnnSupport *cudnn_dnn = dynamic_cast<gpu::CudnnSupport *>(dnn);
432 return cudnn_dnn->DoFusedConvolveWithExecutionPlan(
433 this, dnn::ToDataType<InputT>::value, conv_input_descriptor,
434 conv_input_data, conv_input_scale, filter_descriptor, filter_data,
435 convolution_descriptor, side_input_data, side_input_scale,
436 bias_descriptor, biases, activation_mode, output_descriptor, *output,
437 scratch_allocator, algorithm_config, output_profile_result);
438 }
439 #endif // GOOGLE_CUDA
440 return port::UnimplementedError("DNN library is not found.");
441 }
442
// Entrains a separable convolution using the two weight sets; see
// DnnSupport::* for the semantics of depth_multiplier and the weights.
443 Stream &ThenSeparableConvolve(
444 const dnn::BatchDescriptor &input_descriptor,
445 const DeviceMemory<float> &input_data,
446 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
447 const DeviceMemory<float> &first_weights,
448 const DeviceMemory<float> &second_weights,
449 const dnn::ConvolutionDescriptor &convolution_descriptor,
450 const dnn::BatchDescriptor &output_descriptor,
451 DeviceMemory<float> *output);
452
453 template <typename ElementType>
ConvolveBackwardDataWithExecutionPlan(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<ElementType> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<ElementType> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & plan_config,dnn::ProfileResult * output_profile_result)454 port::Status ConvolveBackwardDataWithExecutionPlan(
455 const dnn::FilterDescriptor &filter_descriptor,
456 const DeviceMemory<ElementType> &filter_data,
457 const dnn::BatchDescriptor &output_descriptor,
458 DeviceMemory<ElementType> backward_output_data,
459 const dnn::ConvolutionDescriptor &convolution_descriptor,
460 const dnn::BatchDescriptor &input_descriptor,
461 DeviceMemory<ElementType> *backward_input_data,
462 ScratchAllocator *scratch_allocator,
463 const dnn::AlgorithmConfig &plan_config,
464 dnn::ProfileResult *output_profile_result) {
465 #if GOOGLE_CUDA
466 dnn::DnnSupport *dnn = parent_->AsDnn();
467 if (dnn) {
468 gpu::CudnnSupport *cudnn_dnn = dynamic_cast<gpu::CudnnSupport *>(dnn);
469 return cudnn_dnn->DoConvolveWithExecutionPlan(
470 dnn::ConvolutionKind::BACKWARD_DATA,
471 dnn::ToDataType<ElementType>::value,
472 dnn::ToDataType<ElementType>::value, this, input_descriptor,
473 *backward_input_data, filter_descriptor, filter_data,
474 output_descriptor, backward_output_data, convolution_descriptor,
475 plan_config, scratch_allocator, output_profile_result);
476 }
477 #endif // GOOGLE_CUDA
478 return port::UnimplementedError("DNN library is not found.");
479 }
480
481 template <typename ElementType>
ConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<ElementType> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<ElementType> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)482 port::Status ConvolveBackwardDataWithAlgorithm(
483 const dnn::FilterDescriptor &filter_descriptor,
484 const DeviceMemory<ElementType> &filter_data,
485 const dnn::BatchDescriptor &output_descriptor,
486 DeviceMemory<ElementType> backward_output_data,
487 const dnn::ConvolutionDescriptor &convolution_descriptor,
488 const dnn::BatchDescriptor &input_descriptor,
489 DeviceMemory<ElementType> *backward_input_data,
490 ScratchAllocator *scratch_allocator,
491 const dnn::AlgorithmConfig &algorithm_config,
492 dnn::ProfileResult *output_profile_result) {
493 DeviceMemory<uint8> scratch_memory;
494 dnn::AlgorithmDesc algorithm_desc;
495 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
496 TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
497 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
498 *backward_input_data, filter_descriptor, filter_data,
499 output_descriptor, backward_output_data, convolution_descriptor,
500 algorithm_config, scratch_allocator, &algorithm_desc,
501 &scratch_memory));
502 return dnn->DoConvolve(
503 dnn::ConvolutionKind::BACKWARD_DATA,
504 dnn::ToDataType<ElementType>::value,
505 dnn::ToDataType<ElementType>::value, this, input_descriptor,
506 *backward_input_data, filter_descriptor, filter_data,
507 output_descriptor, backward_output_data, convolution_descriptor,
508 algorithm_desc, scratch_memory, output_profile_result);
509 }
510 return port::UnimplementedError("DNN library is not found.");
511 }
512
513 template <typename ElementType>
ConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<ElementType> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<ElementType> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)514 port::Status ConvolveBackwardFilterWithAlgorithm(
515 const dnn::BatchDescriptor &input_descriptor,
516 const DeviceMemory<ElementType> &input_data,
517 const dnn::BatchDescriptor &output_descriptor,
518 DeviceMemory<ElementType> backward_output_data,
519 const dnn::ConvolutionDescriptor &convolution_descriptor,
520 const dnn::FilterDescriptor &filter_descriptor,
521 DeviceMemory<ElementType> *backward_filter_data,
522 ScratchAllocator *scratch_allocator,
523 const dnn::AlgorithmConfig &algorithm_config,
524 dnn::ProfileResult *output_profile_result) {
525 DeviceMemory<uint8> scratch_memory;
526 dnn::AlgorithmDesc algorithm_desc;
527 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
528 TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
529 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
530 input_data, filter_descriptor, *backward_filter_data,
531 output_descriptor, backward_output_data, convolution_descriptor,
532 algorithm_config, scratch_allocator, &algorithm_desc,
533 &scratch_memory));
534 return dnn->DoConvolve(
535 dnn::ConvolutionKind::BACKWARD_FILTER,
536 dnn::ToDataType<ElementType>::value,
537 dnn::ToDataType<ElementType>::value, this, input_descriptor,
538 input_data, filter_descriptor, *backward_filter_data,
539 output_descriptor, backward_output_data, convolution_descriptor,
540 algorithm_desc, scratch_memory, output_profile_result);
541 }
542 return port::UnimplementedError("DNN library is not found.");
543 }
544
545 template <typename ElementType>
ConvolveBackwardFilterWithExecutionPlan(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<ElementType> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<ElementType> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<ElementType> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & plan_config,dnn::ProfileResult * output_profile_result)546 port::Status ConvolveBackwardFilterWithExecutionPlan(
547 const dnn::BatchDescriptor &input_descriptor,
548 const DeviceMemory<ElementType> &input_data,
549 const dnn::BatchDescriptor &output_descriptor,
550 DeviceMemory<ElementType> backward_output_data,
551 const dnn::ConvolutionDescriptor &convolution_descriptor,
552 const dnn::FilterDescriptor &filter_descriptor,
553 DeviceMemory<ElementType> *backward_filter_data,
554 ScratchAllocator *scratch_allocator,
555 const dnn::AlgorithmConfig &plan_config,
556 dnn::ProfileResult *output_profile_result) {
557 #if GOOGLE_CUDA
558 dnn::DnnSupport *dnn = parent_->AsDnn();
559 if (dnn) {
560 gpu::CudnnSupport *cudnn_dnn = dynamic_cast<gpu::CudnnSupport *>(dnn);
561 return cudnn_dnn->DoConvolveWithExecutionPlan(
562 dnn::ConvolutionKind::BACKWARD_FILTER,
563 dnn::ToDataType<ElementType>::value,
564 dnn::ToDataType<ElementType>::value, this, input_descriptor,
565 input_data, filter_descriptor, *backward_filter_data,
566 output_descriptor, backward_output_data, convolution_descriptor,
567 plan_config, scratch_allocator, output_profile_result);
568 }
569 #endif // GOOGLE_CUDA
570 return port::UnimplementedError("DNN library is not found.");
571 }
572
573 Stream &ThenMatMul(const DeviceMemory<float> &input_data,
574 const DeviceMemory<float> &weights,
575 const dnn::BatchDescriptor &input_dimensions,
576 const dnn::BatchDescriptor &output_dimensions,
577 DeviceMemory<float> *output_data);
578
579 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
580 const DeviceMemory<int8> &weights,
581 const DeviceMemory<float> &weight_scales,
582 const dnn::BatchDescriptor &input_dimensions,
583 const dnn::BatchDescriptor &output_dimensions,
584 DeviceMemory<float> *output_data);
585
586 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
587 const DeviceMemory<int16> &weights,
588 const DeviceMemory<float> &weight_scales,
589 const dnn::BatchDescriptor &input_dimensions,
590 const dnn::BatchDescriptor &output_dimensions,
591 DeviceMemory<float> *output_data);
592
593 Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
594 const DeviceMemory<float> &biases,
595 const dnn::BatchDescriptor &dimensions,
596 DeviceMemory<float> *output_data);
597
  // Enqueues the pooling operation described by `pooling_dimensions` on
  // `input_data`, writing the pooled result to *output_data. Overloads exist
  // for double, float, Eigen::half and int8 element types. The optional
  // `workspace_allocator`, when non-null, lets the backend allocate scratch
  // memory for the operation.
  Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                          const dnn::BatchDescriptor &input_dimensions,
                          const DeviceMemory<double> &input_data,
                          const dnn::BatchDescriptor &output_dimensions,
                          DeviceMemory<double> *output_data,
                          ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                          const dnn::BatchDescriptor &input_dimensions,
                          const DeviceMemory<float> &input_data,
                          const dnn::BatchDescriptor &output_dimensions,
                          DeviceMemory<float> *output_data,
                          ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                          const dnn::BatchDescriptor &input_dimensions,
                          const DeviceMemory<Eigen::half> &input_data,
                          const dnn::BatchDescriptor &output_dimensions,
                          DeviceMemory<Eigen::half> *output_data,
                          ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                          const dnn::BatchDescriptor &input_dimensions,
                          const DeviceMemory<int8> &input_data,
                          const dnn::BatchDescriptor &output_dimensions,
                          DeviceMemory<int8> *output_data,
                          ScratchAllocator *workspace_allocator = nullptr);

  // Backward (gradient) pass of the pooling operation: given the forward
  // input/output and the incoming gradient `input_diff_data`, produces the
  // gradient w.r.t. the pooling input in *output_diff_data. Overloads for
  // double, float and Eigen::half.
  Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                           const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<double> &input_data,
                           const dnn::BatchDescriptor &output_dimensions,
                           const DeviceMemory<double> &output_data,
                           const DeviceMemory<double> &input_diff_data,
                           DeviceMemory<double> *output_diff_data,
                           ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                           const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::BatchDescriptor &output_dimensions,
                           const DeviceMemory<float> &output_data,
                           const DeviceMemory<float> &input_diff_data,
                           DeviceMemory<float> *output_diff_data,
                           ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                           const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<Eigen::half> &input_data,
                           const dnn::BatchDescriptor &output_dimensions,
                           const DeviceMemory<Eigen::half> &output_data,
                           const DeviceMemory<Eigen::half> &input_diff_data,
                           DeviceMemory<Eigen::half> *output_diff_data,
                           ScratchAllocator *workspace_allocator = nullptr);
652
  // Applies the normalization described by `normalize_descriptor` to
  // `input_data` (shape given by `dimensions`), writing to *output_data.
  Stream &ThenNormalizeWithDimensions(
      const dnn::NormalizeDescriptor &normalize_descriptor,
      const dnn::BatchDescriptor &dimensions,
      const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);

  // Gradient of the normalization above: consumes the original raw input,
  // the normalized forward output and the gradient w.r.t. that output, and
  // produces the gradient w.r.t. the raw input.
  Stream &ThenNormalizeBackwardWithDimensions(
      const dnn::NormalizeDescriptor &normalize_descriptor,
      const dnn::BatchDescriptor &dimensions,
      const DeviceMemory<float> &raw_data,
      const DeviceMemory<float> &normalized_data,
      const DeviceMemory<float> &normalized_variable_gradient,
      DeviceMemory<float> *raw_variable_gradient,
      ScratchAllocator *workspace_allocator = nullptr);

  // Applies the element-wise activation function selected by
  // `activation_mode` to `input_data`.
  Stream &ThenActivate(dnn::ActivationMode activation_mode,
                       const dnn::BatchDescriptor &dimensions,
                       const DeviceMemory<float> &input_data,
                       DeviceMemory<float> *output_data);

  // Same as ThenActivate, but also takes an options argument that can be used
  // for platform-specific option flags.
  Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
                                  const dnn::BatchDescriptor &dimensions,
                                  const DeviceMemory<float> &input_data,
                                  DeviceMemory<float> *output_data,
                                  uint64 options);

  // Concatenates several inputs along the depth dimension into *output_data.
  Stream &ThenDepthConcatenate(
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      DeviceMemory<float> *output_data);

  // Concatenates several inputs along a spatial dimension selected by
  // `concat_direction`.
  Stream &ThenSpaceConcatenate(
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      DeviceMemory<float> *output_data,
      dnn::SpaceConcatenateMode concat_direction);

  // Change the layout of the data by shrinking one dimension (or set of
  // dimensions) and growing another dimension (or set of dimensions), while
  // keeping the total number of data elements constant, and maintaining the
  // current data ordering.
  Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
                      const DeviceMemory<float> &input_data,
                      const dnn::BatchDescriptor &output_dimensions,
                      DeviceMemory<float> *output_data);
699
  // Depth to space takes an X by Y image with depth D*M² and changes it to an
  // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
  // the input image is changed to an MxM contiguous area in the output image,
  // with the values being laid out in raster order specified by
  // DepthToSpaceLayout, and will have a new depth of D.
  // See the DoDepthToSpace comment for more information.
  Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::DepthToSpaceLayout &depth_to_space_layout,
                           const int sqrt_depth_reduction,
                           DeviceMemory<float> *output_data);

  // Space to depth is the inverse of depth to space. Space to depth takes each
  // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
  // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
  // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
  // data elements is not changed.
  Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::DepthToSpaceLayout &space_to_depth_layout,
                           const int sqrt_depth_increase,
                           DeviceMemory<float> *output_data);

  // Applies `operation` element-wise across the given inputs, writing the
  // combined result to *output_data.
  Stream &ThenElementwiseOperate(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);

  // Like ThenElementwiseOperate, but each input is scaled by its entry in
  // `input_multiplicands` and the result is divided by `output_divisor`.
  Stream &ThenElementwiseOperateScaledQuantized(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<int> input_multiplicands, int output_divisor,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);

  // Pads the X/Y edges of the input by the given element counts.
  Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
                    const DeviceMemory<float> &input_data, int64_t left_pad,
                    int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
                    DeviceMemory<float> *output_data);

  // Trims the X/Y edges of the input by the given element counts (inverse of
  // ThenXYPad).
  Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
                      const DeviceMemory<float> &input_data, int64_t left_trim,
                      int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
                      DeviceMemory<float> *output_data);

  // Grows the input tensor by replicating the X and Y dimensions. The batch and
  // depth/feature_map dimensions are unchanged. Currently, the input tensor is
  // limited to X=1 and Y=1.
  Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
                          const DeviceMemory<float> &input_data,
                          int64_t replicate_x, int64_t replicate_y,
                          DeviceMemory<float> *output_data);

  // See DnnSupport::DoMemcpyD2HQuantized.
  Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
                                 dnn::QuantizedActivationMode mode,
                                 void *host_dst, uint64 size);
760
761 // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
762 // and uses the Quantization trait to call the generic version of
763 // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
764 template <typename ElementType>
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,port::MutableArraySlice<ElementType> host_dst)765 Stream &ThenMemcpyD2HQuantized(
766 const DeviceMemory<float> &gpu_unquantized_src,
767 port::MutableArraySlice<ElementType> host_dst) {
768 return ThenMemcpyD2HQuantized(
769 gpu_unquantized_src, Quantization<ElementType>::kModeId,
770 host_dst.data(), host_dst.size() * sizeof(ElementType));
771 }
772
  // See DnnSupport::DoMemcpyH2DQuantized.
  // Copies `size` bytes of quantized host data into the unquantized (float)
  // device buffer, dequantizing per `mode`.
  Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
                                 dnn::QuantizedActivationMode mode,
                                 DeviceMemory<float> *gpu_unquantized_dst);
777
778 // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
779 // and uses the Quantization trait to call the generic version of
780 // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
781 template <typename ElementType>
ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,DeviceMemory<float> * gpu_unquantized_dst)782 Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
783 DeviceMemory<float> *gpu_unquantized_dst) {
784 return ThenMemcpyH2DQuantized(
785 host_src.data(), host_src.size() * sizeof(ElementType),
786 Quantization<ElementType>::kModeId, gpu_unquantized_dst);
787 }
788
  // See DnnSupport::DoCopyHostBuffer2Device.
  Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
                                    DeviceMemory<float> *gpu_unquantized_dst);

  // See DnnSupport::DoCopyDevice2HostBuffer.
  Stream &ThenCopyDevice2HostBuffer(
      const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);

  /////////////////
  // BLAS support

  // See BlasSupport::DoBlasAsum.
  // Sum of absolute values of `x` (stride `incx`); complex inputs produce a
  // real-typed result.
  Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *result);
  Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *result);
  Stream &ThenBlasAsum(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<float> *result);
  Stream &ThenBlasAsum(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<double> *result);
811
  // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
  // present in DeviceMemory, it must be an execution-time constant (i.e. a
  // value
  // that the stream does not change or populate during the course of
  // execution). The value is effectively captured at stream-enqueue time.
  // Computes y <- alpha * x + y.
  Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasCopy.
  // Copies x into y.
  Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasDot.
  // Dot product of real vectors x and y.
  Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
                      const DeviceMemory<float> &y, int incy,
                      DeviceMemory<float> *result);
  Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
                      int incx, const DeviceMemory<double> &y, int incy,
                      DeviceMemory<double> *result);

  // See BlasSupport::DoBlasDotc.
  // Dot product with the first vector conjugated.
  Stream &ThenBlasDotc(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *result);
  Stream &ThenBlasDotc(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *result);

  // See BlasSupport::DoBlasDotu.
  // Unconjugated complex dot product.
  Stream &ThenBlasDotu(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *result);
  Stream &ThenBlasDotu(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *result);

  // See BlasSupport::DoBlasNrm2.
  // Euclidean norm of x; complex inputs produce a real-typed result.
  Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *result);
  Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *result);
  Stream &ThenBlasNrm2(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<float> *result);
  Stream &ThenBlasNrm2(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<double> *result);

  // See BlasSupport::DoBlasRot.
  // Applies a plane (Givens) rotation with cosine c and sine s to (x, y).
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
                      DeviceMemory<float> *y, int incy, float c, float s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
                      DeviceMemory<double> *y, int incy, double c, double s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
                      int incx, DeviceMemory<std::complex<float>> *y, int incy,
                      float c, float s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
                      int incx, DeviceMemory<std::complex<double>> *y, int incy,
                      double c, double s);

  // See BlasSupport::DoBlasRotg.
  // Constructs the Givens rotation (c, s) that annihilates b given (a, b).
  Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
                       DeviceMemory<float> *c, DeviceMemory<float> *s);
  Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
                       DeviceMemory<double> *c, DeviceMemory<double> *s);
  Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
                       DeviceMemory<std::complex<float>> *b,
                       DeviceMemory<float> *c,
                       DeviceMemory<std::complex<float>> *s);
  Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
                       DeviceMemory<std::complex<double>> *b,
                       DeviceMemory<double> *c,
                       DeviceMemory<std::complex<double>> *s);
907
908 // See BlasSupport::DoBlasRotm.
909 Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
910 DeviceMemory<float> *y, int incy,
911 const DeviceMemory<float> ¶m);
912 Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
913 DeviceMemory<double> *y, int incy,
914 const DeviceMemory<double> ¶m);
915
  // See BlasSupport::DoBlasRotmg.
  // Constructs the parameters of a modified Givens rotation.
  Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
                        DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
                        DeviceMemory<float> *param);
  Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
                        DeviceMemory<double> *x1,
                        const DeviceMemory<double> &y1,
                        DeviceMemory<double> *param);

  // See BlasSupport::DoBlasScal.
  // Scales x in place by alpha; real-alpha overloads exist for complex x.
  Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasScal(uint64 elem_count, float alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, double alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasSwap.
  // Exchanges the contents of x and y.
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
                       int incx, DeviceMemory<std::complex<float>> *y,
                       int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
                       int incx, DeviceMemory<std::complex<double>> *y,
                       int incy);

  // See BlasSupport::DoBlasIamax.
  // Index of the element of x with the largest absolute value.
  Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        DeviceMemory<int> *result);

  // See BlasSupport::DoBlasIamin.
  // Index of the element of x with the smallest absolute value.
  Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        DeviceMemory<int> *result);

  // See BlasSupport::DoBlasGbmv.
  // Banded matrix-vector multiply: y <- alpha*op(a)*x + beta*y, where `a` is
  // an m-by-n band matrix with kl sub- and ku super-diagonals.
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &x, int incx,
                       float beta, DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &x, int incx,
                       double beta, DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasGemv.
  // General matrix-vector multiply: y <- alpha*op(a)*x + beta*y.
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
1018
  // GEMV variants that additionally report timing/algorithm information into
  // *output_profile_result.
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    float alpha, const DeviceMemory<float> &a,
                                    int lda, const DeviceMemory<float> &x,
                                    int incx, float beta,
                                    DeviceMemory<float> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    double alpha, const DeviceMemory<double> &a,
                                    int lda, const DeviceMemory<double> &x,
                                    int incx, double beta,
                                    DeviceMemory<double> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &x, int incx,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &x, int incx,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
      int incy, blas::ProfileResult *output_profile_result);

  // See BlasSupport::DoBlasGer.
  // Rank-1 update: a <- alpha*x*y^T + a.
  Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      const DeviceMemory<float> &y, int incy,
                      DeviceMemory<float> *a, int lda);
  Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      const DeviceMemory<double> &y, int incy,
                      DeviceMemory<double> *a, int lda);

  // See BlasSupport::DoBlasGerc.
  // Rank-1 update with conjugated y (complex).
  Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);

  // See BlasSupport::DoBlasGeru.
  // Rank-1 update with unconjugated y (complex).
  Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);

  // See BlasSupport::DoBlasHbmv.
  // Hermitian banded matrix-vector multiply.
  Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasHemv.
  // Hermitian matrix-vector multiply.
  Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasHer.
  // Hermitian rank-1 update (real alpha).
  Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<std::complex<float>> &x, int incx,
                      DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<std::complex<double>> &x, int incx,
                      DeviceMemory<std::complex<double>> *a, int lda);

  // See BlasSupport::DoBlasHer2.
  // Hermitian rank-2 update.
  Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);
1121
  // See BlasSupport::DoBlasHpmv.
  // Hermitian packed matrix-vector multiply; `ap` is the packed triangle
  // selected by `uplo`.
  Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &ap,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &ap,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasHpr.
  // Hermitian packed rank-1 update (real alpha).
  Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<std::complex<float>> &x, int incx,
                      DeviceMemory<std::complex<float>> *ap);
  Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<std::complex<double>> &x, int incx,
                      DeviceMemory<std::complex<double>> *ap);

  // See BlasSupport::DoBlasHpr2.
  // Hermitian packed rank-2 update.
  Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *ap);
  Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *ap);

  // See BlasSupport::DoBlasSbmv.
  // Symmetric banded matrix-vector multiply.
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  // See BlasSupport::DoBlasSpmv.
  // Symmetric packed matrix-vector multiply.
  Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &ap,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &ap,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  // See BlasSupport::DoBlasSpr.
  // Symmetric packed rank-1 update.
  Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      DeviceMemory<float> *ap);
  Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      DeviceMemory<double> *ap);

  // See BlasSupport::DoBlasSpr2.
  // Symmetric packed rank-2 update.
  Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       const DeviceMemory<float> &y, int incy,
                       DeviceMemory<float> *ap);
  Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       const DeviceMemory<double> &y, int incy,
                       DeviceMemory<double> *ap);

  // See BlasSupport::DoBlasSymv.
  // Symmetric matrix-vector multiply.
  Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  // See BlasSupport::DoBlasSyr.
  // Symmetric rank-1 update.
  Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      DeviceMemory<float> *a, int lda);
  Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      DeviceMemory<double> *a, int lda);

  // See BlasSupport::DoBlasSyr2.
  // Symmetric rank-2 update.
  Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       const DeviceMemory<float> &y, int incy,
                       DeviceMemory<float> *a, int lda);
  Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       const DeviceMemory<double> &y, int incy,
                       DeviceMemory<double> *a, int lda);
1221
1222 // See BlasSupport::DoBlasTbmv.
1223 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1224 blas::Diagonal diag, uint64 n, uint64 k,
1225 const DeviceMemory<float> &a, int lda,
1226 DeviceMemory<float> *x, int incx);
1227 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1228 blas::Diagonal diag, uint64 n, uint64 k,
1229 const DeviceMemory<double> &a, int lda,
1230 DeviceMemory<double> *x, int incx);
1231 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1232 blas::Diagonal diag, uint64 n, uint64 k,
1233 const DeviceMemory<std::complex<float>> &a, int lda,
1234 DeviceMemory<std::complex<float>> *x, int incx);
1235 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
1236 blas::Diagonal diag, uint64 n, uint64 k,
1237 const DeviceMemory<std::complex<double>> &a, int lda,
1238 DeviceMemory<std::complex<double>> *x, int incx);
1239
1240 // See BlasSupport::DoBlasTbsv.
1241 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1242 blas::Diagonal diag, uint64 n, uint64 k,
1243 const DeviceMemory<float> &a, int lda,
1244 DeviceMemory<float> *x, int incx);
1245 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1246 blas::Diagonal diag, uint64 n, uint64 k,
1247 const DeviceMemory<double> &a, int lda,
1248 DeviceMemory<double> *x, int incx);
1249 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1250 blas::Diagonal diag, uint64 n, uint64 k,
1251 const DeviceMemory<std::complex<float>> &a, int lda,
1252 DeviceMemory<std::complex<float>> *x, int incx);
1253 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1254 blas::Diagonal diag, uint64 n, uint64 k,
1255 const DeviceMemory<std::complex<double>> &a, int lda,
1256 DeviceMemory<std::complex<double>> *x, int incx);
1257
1258 // See BlasSupport::DoBlasTpmv.
1259 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1260 blas::Diagonal diag, uint64 n,
1261 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1262 int incx);
1263 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1264 blas::Diagonal diag, uint64 n,
1265 const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1266 int incx);
1267 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1268 blas::Diagonal diag, uint64 n,
1269 const DeviceMemory<std::complex<float>> &ap,
1270 DeviceMemory<std::complex<float>> *x, int incx);
1271 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1272 blas::Diagonal diag, uint64 n,
1273 const DeviceMemory<std::complex<double>> &ap,
1274 DeviceMemory<std::complex<double>> *x, int incx);
1275
  // See BlasSupport::DoBlasTpsv.
  // Triangular packed solve (BLAS TPSV); ap holds the triangular matrix in
  // packed storage and the solution overwrites x.
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &ap, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &ap,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &ap,
                       DeviceMemory<std::complex<double>> *x, int incx);
1293
  // See BlasSupport::DoBlasTrmv.
  // Triangular matrix-vector multiply (BLAS TRMV); a is a conventional
  // (non-packed) triangular matrix and x is updated in place.
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);
1311
  // See BlasSupport::DoBlasTrsv.
  // Triangular solve (BLAS TRSV); solves op(a) * x = b where b is passed in
  // x and is overwritten with the solution.
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);
1329
1330 template <typename InputType>
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const DeviceMemory<InputType> & a,int lda,const DeviceMemory<InputType> & b,int ldb,DeviceMemory<InputType> * c,int ldc)1331 port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1332 uint64 m, uint64 n, uint64 k,
1333 const DeviceMemory<InputType> &a, int lda,
1334 const DeviceMemory<InputType> &b, int ldb,
1335 DeviceMemory<InputType> *c, int ldc) {
1336 InputType alpha{1.0};
1337 InputType beta{0.0};
1338 return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
1339 ldc);
1340 }
1341
1342 template <typename InputType, typename ConstantType>
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,ConstantType alpha,const DeviceMemory<InputType> & a,int lda,const DeviceMemory<InputType> & b,int ldb,ConstantType beta,DeviceMemory<InputType> * c,int ldc)1343 port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1344 uint64 m, uint64 n, uint64 k, ConstantType alpha,
1345 const DeviceMemory<InputType> &a, int lda,
1346 const DeviceMemory<InputType> &b, int ldb,
1347 ConstantType beta, DeviceMemory<InputType> *c,
1348 int ldc) {
1349 static_assert(!std::is_same<InputType, Eigen::half>::value ||
1350 std::is_same<ConstantType, float>::value ||
1351 std::is_same<ConstantType, Eigen::half>::value,
1352 "If input is Eigen::half, constant has to be either "
1353 "Eigen::half or float");
1354 static_assert(
1355 std::is_same<InputType, Eigen::half>::value ||
1356 std::is_same<InputType, ConstantType>::value,
1357 "If input is not Eigen::half, constant and input types have to match");
1358 static_assert(
1359 std::is_same<InputType, Eigen::half>::value ||
1360 std::is_same<InputType, Eigen::bfloat16>::value ||
1361 std::is_same<InputType, float>::value ||
1362 std::is_same<InputType, double>::value ||
1363 std::is_same<InputType, std::complex<float>>::value ||
1364 std::is_same<InputType, std::complex<double>>::value,
1365 "Input can be half, bf16, float, double, std::complex<float> or "
1366 "std::complex<double>");
1367 blas::BlasSupport *blas = parent()->AsBlas();
1368 if (!blas) {
1369 return port::InternalError(
1370 "Attempting to perform BLAS operation using "
1371 "StreamExecutor without BLAS support");
1372 }
1373
1374 void *alpha_ptr = α
1375 void *beta_ptr = β
1376 float alpha_storage, beta_storage;
1377 UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
1378 &beta_storage);
1379
1380 return blas->DoBlasGemm(this, transa, transb, m, n, k,
1381 blas::ToDataType<InputType>::value, alpha_ptr, a,
1382 lda, b, ldb, beta_ptr, c, ldc);
1383 }
1384
  // See BlasSupport::DoBlasGemmWithProfiling.
  // GEMM variants that record timing information for the executed algorithm
  // into output_profile_result.
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, float alpha,
                                    const DeviceMemory<Eigen::half> &a, int lda,
                                    const DeviceMemory<Eigen::half> &b, int ldb,
                                    float beta, DeviceMemory<Eigen::half> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, float alpha,
                                    const DeviceMemory<float> &a, int lda,
                                    const DeviceMemory<float> &b, int ldb,
                                    float beta, DeviceMemory<float> *c, int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64 m, uint64 n,
                                    uint64 k, double alpha,
                                    const DeviceMemory<double> &a, int lda,
                                    const DeviceMemory<double> &b, int ldb,
                                    double beta, DeviceMemory<double> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
1422
1423 template <typename InputType, typename OutputType>
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const DeviceMemory<InputType> & a,int lda,const DeviceMemory<InputType> & b,int ldb,DeviceMemory<OutputType> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1424 port::Status ThenBlasGemmWithAlgorithm(
1425 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1426 uint64 k, const DeviceMemory<InputType> &a, int lda,
1427 const DeviceMemory<InputType> &b, int ldb, DeviceMemory<OutputType> *c,
1428 int ldc, blas::ComputationType computation_type,
1429 blas::AlgorithmType algorithm,
1430 blas::ProfileResult *output_profile_result) {
1431 OutputType alpha{1};
1432 OutputType beta{0};
1433 return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b,
1434 ldb, beta, c, ldc, computation_type,
1435 algorithm, output_profile_result);
1436 }
1437
1438 template <typename InputType, typename OutputType, typename ConstantType>
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,ConstantType alpha,const DeviceMemory<InputType> & a,int lda,const DeviceMemory<InputType> & b,int ldb,ConstantType beta,DeviceMemory<OutputType> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1439 port::Status ThenBlasGemmWithAlgorithm(
1440 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1441 uint64 k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
1442 const DeviceMemory<InputType> &b, int ldb, ConstantType beta,
1443 DeviceMemory<OutputType> *c, int ldc,
1444 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1445 blas::ProfileResult *output_profile_result) {
1446 TF_RETURN_IF_ERROR(
1447 CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
1448 computation_type));
1449
1450 blas::BlasSupport *blas = parent()->AsBlas();
1451 if (!blas) {
1452 return port::InternalError(
1453 "Attempting to perform BLAS operation using "
1454 "StreamExecutor without BLAS support");
1455 }
1456
1457 void *alpha_ptr = α
1458 void *beta_ptr = β
1459 float alpha_storage, beta_storage;
1460 UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
1461 &beta_storage);
1462
1463 port::Status st = blas->DoBlasGemmWithAlgorithm(
1464 this, transa, transb, m, n, k, alpha_ptr, a,
1465 blas::ToDataType<InputType>::value, lda, b,
1466 blas::ToDataType<InputType>::value, ldb, beta_ptr, c,
1467 blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm,
1468 output_profile_result);
1469 if (output_profile_result) {
1470 // The error is recorded in the profile.
1471 return port::Status::OK();
1472 }
1473 return st;
1474 }
1475
1476 template <typename InputType, typename OutputType, typename ConstantType>
ThenBlasGemmStridedBatchedWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,ConstantType alpha,const DeviceMemory<InputType> & a,int lda,int64_t stride_a,const DeviceMemory<InputType> & b,int ldb,int64_t stride_b,ConstantType beta,DeviceMemory<OutputType> * c,int ldc,int64_t stride_c,int batch_count,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1477 port::Status ThenBlasGemmStridedBatchedWithAlgorithm(
1478 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1479 uint64 k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
1480 int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
1481 int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
1482 int64_t stride_c, int batch_count, blas::ComputationType computation_type,
1483 blas::AlgorithmType algorithm,
1484 blas::ProfileResult *output_profile_result) {
1485 TF_RETURN_IF_ERROR(
1486 CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
1487 computation_type));
1488
1489 blas::BlasSupport *blas = parent()->AsBlas();
1490 if (!blas) {
1491 return port::InternalError(
1492 "Attempting to perform BLAS operation using "
1493 "StreamExecutor without BLAS support");
1494 }
1495 void *alpha_ptr = α
1496 void *beta_ptr = β
1497 float alpha_storage, beta_storage;
1498 UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
1499 &beta_storage);
1500 port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm(
1501 this, transa, transb, m, n, k, alpha_ptr, a,
1502 blas::ToDataType<InputType>::value, stride_a, lda, b,
1503 blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c,
1504 blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
1505 computation_type, algorithm, output_profile_result);
1506 if (output_profile_result) {
1507 // The error is recorded in the profile.
1508 return port::Status::OK();
1509 }
1510 return st;
1511 }
1512
  // See BlasSupport::DoBlasGemmBatched.
  // Batched GEMM over pointer arrays: the i-th problem multiplies a[i] and
  // b[i] into c[i]. All problems share the same dimensions and scalars.
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, float alpha,
                              const port::ArraySlice<DeviceMemory<float> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<float> *> &b,
                              int ldb, float beta,
                              const port::ArraySlice<DeviceMemory<float> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64 m, uint64 n, uint64 k, double alpha,
                              const port::ArraySlice<DeviceMemory<double> *> &a,
                              int lda,
                              const port::ArraySlice<DeviceMemory<double> *> &b,
                              int ldb, double beta,
                              const port::ArraySlice<DeviceMemory<double> *> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count);
  // Variants that take a ScratchAllocator for any backend workspace the
  // batched operation needs.
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
      int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
      double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
      uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
1588
1589 template <typename InputType, typename ConstantType>
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,ConstantType alpha,const DeviceMemory<InputType> & a,int lda,int64_t stride_a,const DeviceMemory<InputType> & b,int ldb,int64_t stride_b,ConstantType beta,DeviceMemory<InputType> * c,int ldc,int64_t stride_c,int batch_count)1590 port::Status ThenBlasGemmStridedBatched(
1591 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1592 uint64 k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
1593 int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
1594 int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
1595 int64_t stride_c, int batch_count) {
1596 static_assert(((std::is_same<InputType, Eigen::half>::value ||
1597 std::is_same<InputType, Eigen::bfloat16>::value) &&
1598 std::is_same<ConstantType, float>::value) ||
1599 ((std::is_same<InputType, float>::value ||
1600 std::is_same<InputType, Eigen::half>::value ||
1601 std::is_same<InputType, Eigen::bfloat16>::value ||
1602 std::is_same<InputType, double>::value ||
1603 std::is_same<InputType, std::complex<float>>::value ||
1604 std::is_same<InputType, std::complex<double>>::value) &&
1605 std::is_same<ConstantType, InputType>::value),
1606 "Input or constant type mismatch");
1607 blas::BlasSupport *blas = parent()->AsBlas();
1608 if (!blas) {
1609 return port::InternalError(
1610 "Attempting to perform BLAS operation using "
1611 "StreamExecutor without BLAS support");
1612 }
1613
1614 void *alpha_ptr = α
1615 void *beta_ptr = β
1616 float alpha_storage, beta_storage;
1617 UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
1618 &beta_storage);
1619
1620 return blas->DoBlasGemmStridedBatched(
1621 this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
1622 alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
1623 stride_c, batch_count);
1624 }
1625
  // See BlasSupport::DoBlasHemm.
  // Hermitian matrix-matrix multiply (BLAS HEMM); complex-only, as the
  // Hermitian property is meaningful only for complex matrices.
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasHerk.
  // Hermitian rank-k update (BLAS HERK); note alpha and beta are real-valued
  // even though the matrices are complex.
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       float beta, DeviceMemory<std::complex<float>> *c,
                       int ldc);
  Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       double beta, DeviceMemory<std::complex<double>> *c,
                       int ldc);

  // See BlasSupport::DoBlasHer2k.
  // Hermitian rank-2k update (BLAS HER2K); alpha is complex but beta is real.
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        float beta, DeviceMemory<std::complex<float>> *c,
                        int ldc);
  Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        double beta, DeviceMemory<std::complex<double>> *c,
                        int ldc);
1665
  // See BlasSupport::DoBlasSymm.
  // Symmetric matrix-matrix multiply (BLAS SYMM); side selects whether the
  // symmetric matrix a appears on the left or right of b.
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &b, int ldb,
                       float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &b, int ldb,
                       double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &b, int ldb,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &b, int ldb,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyrk.
  // Symmetric rank-k update (BLAS SYRK).
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, float alpha, const DeviceMemory<float> &a,
                       int lda, float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, double alpha, const DeviceMemory<double> &a,
                       int lda, double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                       uint64 k, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *c, int ldc);

  // See BlasSupport::DoBlasSyr2k.
  // Symmetric rank-2k update (BLAS SYR2K).
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, float alpha, const DeviceMemory<float> &a,
                        int lda, const DeviceMemory<float> &b, int ldb,
                        float beta, DeviceMemory<float> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, double alpha, const DeviceMemory<double> &a,
                        int lda, const DeviceMemory<double> &b, int ldb,
                        double beta, DeviceMemory<double> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc);
  Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
                        uint64 k, std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc);
1727
  // See BlasSupport::DoBlasTrmm.
  // Triangular matrix-matrix multiply (BLAS TRMM); b is overwritten with the
  // product.
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);

  // See BlasSupport::DoBlasTrsm.
  // Triangular solve with multiple right-hand sides (BLAS TRSM); the solution
  // overwrites b.
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64 m,
                       uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);
1767
  // See BlasSupport::DoBlasLtMatmul.
  // Note that we prevent alpha and beta from being used to deduce CType so that
  // they can be constructed implicitly from values of type CType. Without this,
  // type deduction would fail when this function is called with a value of type
  // CType for alpha or beta.
  // This is a thin forwarder: all arguments are passed unchanged to the
  // ThenBlasLtMatmulImpl helper. bias defaults to an empty DeviceMemory.
  template <typename ABType, typename CType>
  Stream &ThenBlasLtMatmul(
      const blas::IBlasLtMatmulPlan *plan,
      const detail::NonDeducedType<HostOrDeviceScalar<CType>> &alpha,
      const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
      const detail::NonDeducedType<HostOrDeviceScalar<CType>> &beta,
      DeviceMemory<CType> *c, ScratchAllocator *scratch_allocator,
      const blas::IBlasLtMatmulAlgorithm *algorithm,
      const DeviceMemory<CType> &bias = {},
      blas::ProfileResult *output_profile_result = nullptr) {
    return ThenBlasLtMatmulImpl(plan, alpha, a, b, beta, c, scratch_allocator,
                                algorithm, bias, output_profile_result);
  }
1786
  // See FftSupport::DoFft.
  // Overloads cover complex-to-complex, real-to-complex, and complex-to-real
  // transforms in both single and double precision; the transform direction
  // and shape come from the supplied plan.
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<float> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<double> *output);
1804
  // Makes the RNG use the provided value as the basis for further generation.
  // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
  // sources of seed data if the default (high quality) sources are not
  // desired.
  // For most use cases, this function will not be necessary; each provided
  // back-end implementation will be appropriately seeded by default.
  // At a minimum 16 bytes of data are required in the seed buffer.
  //
  // To seed with good (non-reproducible) data:
  //   File* f = File::Open("/dev/random", "r");
  //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
  //   < error checking >
  //   stream.ThenSetRngSeed(seed_data, bytes_read);
  //
  // To seed with reproducible data:
  //   uint64_t seed_data[2] = { <data> };
  //   stream.ThenSetRngSeed(seed_data, 16);
  Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);

  // Populates the memory indicated by values with uniform-random-distribution
  // values. TODO(leary) seeding API/description
  //
  // Uses the type and size of the DeviceMemory to infer what data should be
  // populated.
  Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
  // Populates values with draws from a Gaussian distribution with the given
  // mean and standard deviation.
  Stream &ThenPopulateRandGaussian(float mean, float stddev,
                                   DeviceMemory<float> *values);
  Stream &ThenPopulateRandGaussian(double mean, double stddev,
                                   DeviceMemory<double> *values);
1837
  // Entrain onto the stream: a memcpy to a host destination from a GPU source
  // of the given target size. host_dst must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  // size is in bytes.
  Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Entrain onto the stream: a memcpy to a GPU destination from a host source
  // of the given target size. host_src must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  // size is in bytes.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                     uint64 size);
1851
1852 // Alternative interface for memcpying from device to host that takes an
1853 // array slice. Checks that the destination size can accommodate the host
1854 // slice size.
1855 template <typename T>
ThenMemcpyD2H(const DeviceMemory<T> & gpu_src,port::MutableArraySlice<T> host_dst)1856 Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1857 port::MutableArraySlice<T> host_dst) {
1858 auto host_size = host_dst.size() * sizeof(T);
1859 CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1860 return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1861 }
1862
1863 // Alternative interface for memcpying from host to device that takes an
1864 // array slice. Checks that the destination size can accommodate the host
1865 // slice size.
1866 template <typename T>
ThenMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemory<T> * gpu_dst)1867 Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
1868 DeviceMemory<T> *gpu_dst) {
1869 auto host_size = host_src.size() * sizeof(T);
1870 CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1871 return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1872 }
1873
  // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
  // of the given target size. gpu_src/dst must be pointers to GPU memory and
  // peer access must be enabled between their owning StreamExecutors.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);

  // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
  // ensuring that the host pointer isn't getting confused accidentally with a
  // device pointer if you're not doing metaprogramming against the API.
  Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
                        const DeviceMemoryBase &gpu_src, uint64 size) {
    return ThenMemcpy(gpu_dst, gpu_src, size);
  }

  // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
  // The location must not be null.
  Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);

  // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
  // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible
  // by 4). The location must not be null.
  Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
1896
  // Enqueue a forward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnForward for more details.
  // Overloads are provided per element type (Eigen::half and float here);
  // when is_training is true, reserve_space_allocator receives the state
  // needed by the corresponding backward pass.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<Eigen::half> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<Eigen::half> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<Eigen::half> &input_c_data,
                         const DeviceMemory<Eigen::half> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<Eigen::half> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<Eigen::half> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<Eigen::half> *output_c_data,
                         bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<float> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<float> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<float> &input_c_data,
                         const DeviceMemory<float> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<float> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<float> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<float> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);
1937
1938 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1939 const dnn::RnnSequenceTensorDescriptor &input_desc,
1940 const DeviceMemory<double> &input_data,
1941 const DeviceMemory<int> &seq_lengths_data,
1942 const dnn::RnnStateTensorDescriptor &input_h_desc,
1943 const DeviceMemory<double> &input_h_data,
1944 const dnn::RnnStateTensorDescriptor &input_c_desc,
1945 const DeviceMemory<double> &input_c_data,
1946 const DeviceMemory<double> ¶ms,
1947 const dnn::RnnSequenceTensorDescriptor &output_desc,
1948 DeviceMemory<double> *output_data,
1949 const dnn::RnnStateTensorDescriptor &output_h_desc,
1950 DeviceMemory<double> *output_h_data,
1951 const dnn::RnnStateTensorDescriptor &output_c_desc,
1952 DeviceMemory<double> *output_c_data, bool is_training,
1953 ScratchAllocator *reserve_space_allocator,
1954 ScratchAllocator *workspace_allocator,
1955 dnn::ProfileResult *output_profile_result);
1956
  // Enqueue a backward operation of the RNN model onto the stream.
  // Consumes the `reserve_space_data` produced by the matching forward pass
  // and writes gradients into the *_backprop_data buffers.
  // See DnnSupport::DoRnnBackward for more details.
  //
  // Overload for Eigen::half element data.
  Stream &ThenRnnBackward(
      const dnn::RnnDescriptor &rnn_desc,
      const dnn::RnnSequenceTensorDescriptor &input_desc,
      const DeviceMemory<Eigen::half> &input_data,
      const DeviceMemory<int> &seq_lengths_data,
      const dnn::RnnStateTensorDescriptor &input_h_desc,
      const DeviceMemory<Eigen::half> &input_h_data,
      const dnn::RnnStateTensorDescriptor &input_c_desc,
      const DeviceMemory<Eigen::half> &input_c_data,
      const DeviceMemory<Eigen::half> &params,
      const dnn::RnnSequenceTensorDescriptor &output_desc,
      const DeviceMemory<Eigen::half> &output_data,
      const dnn::RnnStateTensorDescriptor &output_h_desc,
      const DeviceMemory<Eigen::half> &output_h_data,
      const dnn::RnnStateTensorDescriptor &output_c_desc,
      const DeviceMemory<Eigen::half> &output_c_data,
      const DeviceMemory<Eigen::half> &output_backprop_data,
      const DeviceMemory<Eigen::half> &output_h_backprop_data,
      const DeviceMemory<Eigen::half> &output_c_backprop_data,
      DeviceMemory<Eigen::half> *input_backprop_data,
      DeviceMemory<Eigen::half> *input_h_backprop_data,
      DeviceMemory<Eigen::half> *input_c_backprop_data,
      DeviceMemory<Eigen::half> *params_backprop_data,
      DeviceMemory<uint8> *reserve_space_data,
      ScratchAllocator *workspace_allocator,
      dnn::ProfileResult *output_profile_result);

  // Overload for float element data.
  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<float> &input_data,
                          const DeviceMemory<int> &seq_lengths_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<float> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<float> &input_c_data,
                          const DeviceMemory<float> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<float> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<float> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<float> &output_c_data,
                          const DeviceMemory<float> &output_backprop_data,
                          const DeviceMemory<float> &output_h_backprop_data,
                          const DeviceMemory<float> &output_c_backprop_data,
                          DeviceMemory<float> *input_backprop_data,
                          DeviceMemory<float> *input_h_backprop_data,
                          DeviceMemory<float> *input_c_backprop_data,
                          DeviceMemory<float> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  // Overload for double element data.
  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<double> &input_data,
                          const DeviceMemory<int> &seq_lengths_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<double> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<double> &input_c_data,
                          const DeviceMemory<double> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<double> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<double> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<double> &output_c_data,
                          const DeviceMemory<double> &output_backprop_data,
                          const DeviceMemory<double> &output_h_backprop_data,
                          const DeviceMemory<double> &output_c_backprop_data,
                          DeviceMemory<double> *input_backprop_data,
                          DeviceMemory<double> *input_h_backprop_data,
                          DeviceMemory<double> *input_c_backprop_data,
                          DeviceMemory<double> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);
2037
  // Enqueue a CTCLoss operation onto the stream.
  // Writes per-batch losses into `costs_data` and gradients w.r.t. the
  // probabilities into `grads_data` (described by `grads_desc`).
  // See DnnSupport::DoCtcLoss for more details.
  Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                      const DeviceMemory<float> &probs_data,
                      absl::Span<const int> labels_data,
                      absl::Span<const int> labels_lengths_data,
                      absl::Span<const int> input_lengths_data,
                      DeviceMemory<float> *costs_data,
                      const dnn::RnnStateTensorDescriptor &grads_desc,
                      DeviceMemory<float> *grads_data,
                      ScratchAllocator *workspace_allocator);
2049
  // Enqueue onto the stream an operation that transforms a tensor, converting
  // between the given input/output layouts and element types and applying
  // `scale`.
  // See DnnSupport::DoTransformTensor for more details.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);
2058
2059 // The templated version of the above ThenTransformTensor. Useful when the
2060 // input and output types are statically known.
2061 template <typename InElemT, typename OutElemT>
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,const DeviceMemory<InElemT> & input_data,const dnn::BatchDescriptor & output_desc,DeviceMemory<OutElemT> * output_data)2062 Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
2063 const DeviceMemory<InElemT> &input_data,
2064 const dnn::BatchDescriptor &output_desc,
2065 DeviceMemory<OutElemT> *output_data) {
2066 return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
2067 input_data, output_desc,
2068 dnn::ToDataType<OutElemT>(), output_data);
2069 }
2070
  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
  //
  // Returns an OK status if the blocking was successful and the stream is ok().
  // Otherwise returns an error describing why the blocking failed.
  port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_);
2078
  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
  //
  // Entrains onto the stream a function to be executed on the host at some
  // point in the future.
  // Async host callbacks DO NOT block the stream as device functions (or as
  // synchronous host callbacks). No synchronization is possible with
  // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is private due to the potential for undefined behavior with
  // synchronization using OpenCL user events.
  // The ONLY lifetime guarantee in these calls is that the StreamExecutor
  // parameter will still be valid - this Stream may not be!
  // Any callbacks requiring device API calls must use this method.
  Stream &ThenEnqueueOnBackgroundThread(
      std::function<void(StreamExecutor *)> task);
2097
2098 // Returns the (opaque) platform-specific backing object. Ownership is not
2099 // transferred to the caller.
implementation()2100 internal::StreamInterface *implementation() { return implementation_.get(); }
2101
  // Entrains onto the stream a callback to the host (from the device).
  // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
  // never fail or its failure is inconsequential.
  //
  // This is kept for backward compatibility. Future code should use
  // ThenDoHostCallbackWithStatus and explicitly return a success status.
  // TODO(b/112125301): Eventually remove this method.
  Stream &ThenDoHostCallback(std::function<void()> callback);

  // Entrains onto the stream a callback to the host (from the device).
  // Host callbacks block/occupy the stream just as device functions
  // (execute one at a time, block later stream operations).
  // Whether the callback return status affects the result of BlockHostUntilDone
  // is platform-dependent.
  //
  // Behavior is undefined when synchronizing using OpenCL user events.
  // Behavior is undefined if host callbacks call device routines or insert
  // them into any stream.
  //
  // On certain platforms, ThenDoHostCallback is expected to have significant
  // negative effects on performance.
  Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);

  // Runs the given callback after the next call to BlockHostUntilDone on this
  // stream (or after the Stream does BlockHostUntilDone in its destructor).
  // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
  // some use cases.
  Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);
2130
  // Returns the StreamExecutor (parent object) associated with this stream.
  // CHECK-fails rather than returning null: a Stream always has a parent.
  StreamExecutor *parent() const {
    CHECK(parent_ != nullptr);
    return parent_;
  }
2136
  // Returns the CUDA compute capability reported by the device description of
  // this stream's parent StreamExecutor.
  CudaComputeCapability GetCudaComputeCapability() const {
    return parent()->GetDeviceDescription().cuda_compute_capability();
  }
2141
  // Returns the (internal usage) temporary-memory-allocation manager associated
  // with this stream.
  internal::TemporaryMemoryManager *temporary_memory_manager();

  // Returns a debugging string "[stream=0x...,impl=0x...]".
  std::string DebugStreamPointers() const;
2148
 private:
  // Platform plugins that need direct access to private members.
  friend class host::HostBlas;  // for parent_.
  friend class host::HostFft;   // for parent_.
  friend class host::HostRng;   // for parent_.
  template <typename... Args>
  friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
  friend class ocl::CLBlas;    // for parent_.
2156
2157 // Checks whether types match before a call to extended BLAS version.
2158 template <typename InputType, typename OutputType, typename ConstantType>
CheckTypesForExtendedBlas(blas::ComputationType computation_type)2159 port::Status CheckTypesForExtendedBlas(
2160 blas::ComputationType computation_type) {
2161 static_assert(std::is_same<InputType, Eigen::half>::value ||
2162 std::is_same<InputType, Eigen::bfloat16>::value ||
2163 std::is_same<InputType, float>::value ||
2164 std::is_same<InputType, double>::value ||
2165 std::is_same<InputType, int8>::value ||
2166 std::is_same<InputType, std::complex<float>>::value ||
2167 std::is_same<InputType, std::complex<double>>::value,
2168 "The only buffer types supported are: Eigen::half, float, "
2169 "double, int8, std::complex<float> and std::complex<double>");
2170 static_assert(
2171 std::is_same<InputType, OutputType>::value ||
2172 (std::is_same<InputType, int8>::value &&
2173 std::is_same<OutputType, int32>::value),
2174 "Input and output buffer types should be the same unless input is "
2175 "int8 and output is int32");
2176 static_assert(std::is_same<ConstantType, OutputType>::value ||
2177 (std::is_same<ConstantType, float>::value &&
2178 (std::is_same<OutputType, Eigen::half>::value ||
2179 std::is_same<OutputType, Eigen::bfloat16>::value)),
2180 "Constant and output types should match");
2181 blas::ComputationType expected_computation_type =
2182 blas::ToComputationType<ConstantType>::value;
2183 if (expected_computation_type != computation_type &&
2184 !(computation_type == blas::ComputationType::kF32 &&
2185 (expected_computation_type == blas::ComputationType::kF16 ||
2186 expected_computation_type == blas::ComputationType::kBF16AsF32))) {
2187 return port::InternalError(absl::StrCat(
2188 "Alpha/beta type and computation type have to match, got ",
2189 blas::ComputationTypeString(computation_type),
2190 " for computation type, expected: ",
2191 blas::ComputationTypeString(expected_computation_type)));
2192 }
2193 return port::Status::OK();
2194 }
2195
InErrorState()2196 bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
2197 absl::ReaderMutexLock lock(&mu_);
2198 return !status_.ok();
2199 }
2200
2201 // Sets the error state if operation_retcode is false.
2202 // This is a useful shorthand for many stream routines.
CheckError(bool operation_retcode)2203 void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_) {
2204 if (operation_retcode) {
2205 return;
2206 }
2207 absl::MutexLock lock(&mu_);
2208 status_ = port::InternalError("Unknown error");
2209 }
2210
  // Checks the status and logs the error message, if any.
  void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_);

  // Unconditionally puts the stream into an error state.
  void SetError() { CheckError(false /* = operation_retcode */); }
2215
SetErrorAndLogNoDnnSupport()2216 void SetErrorAndLogNoDnnSupport() {
2217 SetError();
2218 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
2219 "without DNN support";
2220 }
2221
  // Runs the set of callbacks that are intended to run after
  // BlockHostUntilDone.
  void RunAfterBlockHostUntilDoneCallbacks();

  // The StreamExecutor that supports the operation of this stream.
  StreamExecutor *parent_;

  // The platform-dependent implementation that the StreamExecutor interface
  // delegates to.
  std::unique_ptr<internal::StreamInterface> implementation_;

  // Mutex that guards the allocation / error state flags.
  // Mutable so that it can be obtained via const reader lock.
  mutable absl::Mutex mu_;

  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from 0 to 1 with a sanity check.
  // See StreamExecutor::AllocateStream.
  bool allocated_ TF_GUARDED_BY(mu_);

  // The last error (if any) of all method calls.
  port::Status status_ TF_GUARDED_BY(mu_);

  // Sub-streams that are generated from this stream. Each element has a
  // pointer to a sub-stream and a boolean value indicating if this substream
  // is ready to be reused.
  std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
      TF_GUARDED_BY(mu_);

  // Streams can allocate temporary memories to help with work they enqueue
  // (e.g. for scratch memory spaces). This member tracks those allocations and
  // notes when they can be reclaimed -- reclamation is attempted when
  // BlockHostUntilDone() is called.
  internal::TemporaryMemoryManager temporary_memory_manager_;

  // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
  std::vector<std::function<void()>> after_block_host_until_done_callbacks_
      TF_GUARDED_BY(mu_);
2260
  // Implementation of ThenBlasLtMatmul that is shared by all types.
  // ABType is the element type of the a/b operands; CType is the element type
  // of c, bias, and the alpha/beta scalars.
  template <typename ABType, typename CType>
  Stream &ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan *plan,
                               const HostOrDeviceScalar<CType> &alpha,
                               const DeviceMemory<ABType> &a,
                               const DeviceMemory<ABType> &b,
                               const HostOrDeviceScalar<CType> &beta,
                               DeviceMemory<CType> *c,
                               ScratchAllocator *scratch_allocator,
                               const blas::IBlasLtMatmulAlgorithm *algorithm,
                               const DeviceMemory<CType> &bias,
                               blas::ProfileResult *output_profile_result);
2273
  // Non-extended BLAS interface requires alpha/beta to be floats when input
  // type is Eigen::half. However, for consistency purposes it is convenient
  // for the interface to accept Eigen::half.
  //
  // If T is Eigen::half or Eigen::bfloat16, reads *alpha_ptr/*beta_ptr as that
  // 16-bit type, widens the values to float into the caller-provided
  // alpha_storage/beta_storage, and repoints *alpha_ptr/*beta_ptr at that
  // storage. For any other T this is a no-op. The storage must outlive every
  // use of the updated pointers.
  template <typename T>
  void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr,
                         float *alpha_storage, float *beta_storage) {
    // The branches are guarded by compile-time-constant conditions, so each
    // reinterpret_cast only executes when T really is the matching type.
    if (std::is_same<T, Eigen::half>::value) {
      *alpha_storage =
          static_cast<float>(*reinterpret_cast<Eigen::half *>(*alpha_ptr));
      *beta_storage =
          static_cast<float>(*reinterpret_cast<Eigen::half *>(*beta_ptr));
      *alpha_ptr = alpha_storage;
      *beta_ptr = beta_storage;
    } else if (std::is_same<T, Eigen::bfloat16>::value) {
      *alpha_storage =
          static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*alpha_ptr));
      *beta_storage =
          static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*beta_ptr));
      *alpha_ptr = alpha_storage;
      *beta_ptr = beta_storage;
    }
  }
2296
2297 SE_DISALLOW_COPY_AND_ASSIGN(Stream);
2298 };
2299
2300 ////////////
2301 // Inlines
2302
2303 template <typename... Params, typename... Args>
ThenLaunch(ThreadDim thread_dims,BlockDim block_dims,const TypedKernel<Params...> & kernel,Args...args)2304 inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
2305 const TypedKernel<Params...> &kernel,
2306 Args... args) {
2307 KernelInvocationChecker<std::tuple<Params...>,
2308 std::tuple<Args...>>::CheckAllStaticAssert();
2309 if (ok()) {
2310 // This is the core that allows type-safe kernel launching.
2311 // Since the platforms take kernel arguments as tuples of (void *, size),
2312 // we pack the variadic parameters passed as ...args into the desired
2313 // tuple form and pass that packed form to the StreamExecutor::Launch()
2314 // implementation.
2315 KernelArgsArray<sizeof...(args)> kernel_args;
2316 kernel.PackParams(&kernel_args, args...);
2317 DCHECK(parent_ != nullptr);
2318 bool ok =
2319 parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args)
2320 .ok();
2321 if (!ok) {
2322 SetError();
2323 LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
2324 }
2325 }
2326 return *this;
2327 }
2328
// Allocates a temporary device array of `element_count` elements of type T,
// tracked by this stream's TemporaryMemoryManager so it can be reclaimed once
// the stream no longer needs it.
template <typename T>
inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
Stream::AllocateTemporaryArray(uint64 element_count) {
  return temporary_memory_manager_.AllocateArray<T>(element_count);
}
2334
temporary_memory_manager()2335 inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
2336 return &temporary_memory_manager_;
2337 }
2338
// Quantization trait specializations: map an integral element type to the
// corresponding dnn::QuantizedActivationMode tag.
template <>
struct Quantization<uint8> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k8Bit;
};

template <>
struct Quantization<uint16> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k16Bit;
};

template <>
struct Quantization<int32> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k32Bit;
};
2356
2357 } // namespace stream_executor
2358
2359 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
2360