1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20
21 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_
23
24 #include <complex>
25 #include <functional>
26 #include <memory>
27 #include <type_traits>
28
29 #include "absl/base/thread_annotations.h"
30 #include "absl/synchronization/mutex.h"
31 #include "tensorflow/compiler/xla/stream_executor/blas.h"
32 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
33 #include "tensorflow/compiler/xla/stream_executor/dnn.h"
34 #include "tensorflow/compiler/xla/stream_executor/event.h"
35 #include "tensorflow/compiler/xla/stream_executor/fft.h"
36 #include "tensorflow/compiler/xla/stream_executor/kernel.h"
37 #include "tensorflow/compiler/xla/stream_executor/launch_dim.h"
38 #include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
39 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
40 #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
41 #include "tensorflow/compiler/xla/stream_executor/temporary_memory_manager.h"
42
43 namespace stream_executor {
44
45 namespace host {
46 class HostBlas;
47 class HostFft;
48 class HostRng;
49 class HostTimer;
50 } // namespace host
51
52 namespace ocl {
53 class CLBlas;
54 } // namespace ocl
55
56 namespace internal {
57 class StreamInterface;
58 } // namespace internal
59
60 class DeviceMemoryBase;
61 template <typename ElemT>
62 class DeviceMemory;
63
64 class Timer;
65
66 namespace dnn {
67 class BatchDescriptor;
68 class FilterDescriptor;
69 class ConvolutionDescriptor;
70 class ProfileResult;
71 class AlgorithmDesc;
72 } // namespace dnn
73
74 class StreamExecutor;
75 class ScratchAllocator;
76
77 namespace detail {
78
79 // Helper class to prevent a template function argument from being deduced.
80 // This is identical to std::type_identity in C++20.
81 template <typename T>
82 struct NonDeduced {
83 using type = T;
84 };
85 template <typename T>
86 using NonDeducedType = typename NonDeduced<T>::type;
87
88 // Helper to return true if `T` is the same type as `First` or any of `Rest`.
89 template <typename T>
90 constexpr bool is_any_of() {
91 return false;
92 }
93
94 template <typename T, typename First, typename... Rest>
95 constexpr bool is_any_of() {
96 return std::is_same_v<T, First> || is_any_of<T, Rest...>();
97 }
98
99 } // namespace detail
100
101 // Convert a type to the corresponding QuantizedActivationMode.
102 template <typename ElementType>
103 struct Quantization;
104
105 // Represents a stream of dependent computations on a GPU device.
106 //
107 // The operations within a stream execute linearly and asynchronously until
108 // BlockHostUntilDone() is invoked, which synchronously joins host code with
109 // the execution of the stream.
110 //
111 // If any given operation fails when entraining work for the stream, ok() will
112 // indicate that an error has occurred. After initialization, once a stream is
113 // !ok(), it will never be ok().
114 //
115 // Thread-safe post-initialization.
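//
// A minimal usage sketch (illustrative only; `executor`, the host/device
// buffers, and `size` are assumed to have been set up elsewhere):
//
//   Stream stream(executor);
//   stream.Init();
//   stream.ThenMemcpy(&device_dst, host_src, size)
//         .ThenMemcpy(host_dst, device_dst, size);
//   port::Status status = stream.BlockHostUntilDone();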
116 class Stream {
117 public:
118 // Instantiate a stream tied to parent as a platform executor. Work
119 // entrained onto this stream will be launched/managed on that
120 // StreamExecutor's platform.
121 explicit Stream(StreamExecutor *parent);
122
123 // Deallocates any stream resources that the parent StreamExecutor has
124 // bestowed upon this object.
126 ~Stream();
127
128 // Returns whether any errors have occurred while entraining work for this
129 // stream.
130 bool ok() const { return !InErrorState(); }
131
132 // Retrieves execution status back into the stream from the underlying
133 // implementation without blocking the stream.
134 //
135 // Normally, Stream::BlockHostUntilDone is used to get execution status.
136 // However, some devices use out-of-band mechanisms to ensure their streams
137 // have finished on-device work, without needing to block the streams. (These
138 // devices should also override AllowsSyncOnCompletion to return false.) For
139 // these devices, this method can be used after work is finished to retrieve
140 // execution status.
141 port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_);
142
143 // Initialize the stream. This must be performed before entraining any other
144 // operations.
145 Stream &Init() TF_LOCKS_EXCLUDED(mu_);
146
147 // Initializes timer t via the StreamExecutor.
148 Stream &InitTimer(Timer *t);
149
150 // Convenience wrapper around Init() and InitTimer().
151 Stream &InitWithTimer(Timer *t);
152
153 // Get or create a sub-stream from this stream. If there is any sub-stream in
154 // the pool that can be reused then just return this sub-stream. Otherwise
155 // create a new sub-stream.
156 //
157 // TODO(b/112196569): The semantics of failed sub-streams are error-prone.
158 Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_);
159
160 // Return the sub-stream back to the host stream so that it can be reused
161 // later. Sub-streams that are !ok() will not be reused.
162 //
163 // TODO(b/112196569): The semantics of failed sub-streams are error-prone.
164 void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_);
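//
// One possible usage sketch (illustrative only; `staging_buffer`, `host_src`,
// and `size` are assumptions, not part of this API):
//
//   Stream *sub = stream.GetOrCreateSubStream();
//   sub->ThenMemcpy(&staging_buffer, host_src, size);
//   stream.ThenWaitFor(sub);
//   stream.ReturnSubStream(sub);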
165
166 // Allocates temporary memory. The stream will deallocate it when blocked or
167 // destroyed.
168 template <typename T>
169 port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
170 AllocateTemporaryArray(uint64_t element_count);
171
172 // Entrains onto the stream of operations: a kernel launch with the given
173 // (variadic) parameters for the invocation. These arguments can be things
174 // like DeviceMemory or primitive types such as int. What arguments you may
175 // pass to a given kernel are noted as the template parameters to the
176 // TypedKernel type that the machocc compiler generates.
177 //
178 // Template parameters:
179 // Params... The type list of formal parameters that the typed kernel
180 // expects, which is matched against Args...
181 // Args... The deduced type list for passed actual arguments
182 //
183 // Implementation: A compile-time compatibility check is performed that has
184 // some leniency versus an exact parameter pack match -- for example,
185 // `const DeviceMemory<T>` is considered "pack compatible" with a
186 // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
187 // perfect forwarding support without rvalue references. It also attempts to
188 // spit out helpful static_assert error traces with information as to the
189 // argument number and types that were mismatched.
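//
// Illustrative sketch (the kernel and buffer below are hypothetical; a real
// TypedKernel would be created through the parent StreamExecutor):
//
//   TypedKernel<DeviceMemory<float>, float> scale_kernel = ...;
//   DeviceMemory<float> data = ...;
//   TF_RETURN_IF_ERROR(stream.ThenLaunch(ThreadDim(256), BlockDim(32),
//                                        scale_kernel, data, 2.0f));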
190 template <typename... Params, typename... Args>
191 port::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
192 const TypedKernel<Params...> &kernel, Args... args);
193
194 // Record a "start" event for the interval timer at this point in the
195 // stream's execution (relative to the previously and subsequently enqueued
196 // items in the stream's execution). Timers may be started/stopped multiple
197 // times.
198 Stream &ThenStartTimer(Timer *t);
199
200 // Record a "stop" event for the interval timer at this point in the
201 // stream's execution. See also Stream::ThenStartTimer.
202 Stream &ThenStopTimer(Timer *t);
203
204 // TODO(leary) If work is added to the stream that is being depended upon,
205 // then what? Have to describe what happens.
206 template <typename... Params>
207 Stream &ThenWaitFor(Stream *other, Params... more_streams) {
208 return ThenWaitFor(more_streams...).ThenWaitFor(other);
209 }
210
211 // Create a dependency for this stream's next work on the other stream
212 // completing. Does not take ownership of other, and other must not be
213 // null.
214 //
215 // Checks that a stream does not wait for itself, and it is up to the
216 // user to guarantee that a stream does not come to wait on itself in a
217 // cyclic manner; in that case, behavior is undefined.
218 //
219 // N.B. Base recursion case for the variadic ThenWaitFor.
220 Stream &ThenWaitFor(Stream *other);
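//
// Illustrative sketch (stream names are hypothetical): make this stream's
// next work wait on work already enqueued on two helper streams:
//
//   main_stream.ThenWaitFor(&helper_stream_a, &helper_stream_b);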
221
222 // Waits for all the streams in others.
223 // Checks that there is no shallow circular wait (i.e. that "this" is not in
224 // others).
225 template <typename P>
226 Stream &ThenWaitFor(P others) {
227 for (auto &stream : *others) {
228 CHECK_NE(stream.get(), this);
229 ThenWaitFor(stream.get());
230 }
231 return *this;
232 }
233
234 // Waits for an event object to be set.
235 // Note that ThenRecordEvent must have been called on the event before
236 // you call this function; otherwise the event will be considered complete
237 // and this wait will do nothing.
238 Stream &ThenWaitFor(Event *event);
239
240 // Inserts the specified event into the end of this stream. Once the stream
241 // has processed all events prior to the insertion point, the event will be
242 // marked as completed.
243 // The stream does not take ownership of event - meaning that event's lifetime
244 // must extend past the point at which it is marked complete!
245 Stream &ThenRecordEvent(Event *event);
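//
// Illustrative sketch (the `producer`/`consumer` streams and `executor` are
// hypothetical): make `consumer` wait for work already enqueued on `producer`
// via an Event:
//
//   Event event(executor);
//   event.Init();
//   producer.ThenRecordEvent(&event);
//   consumer.ThenWaitFor(&event);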
246
247 ////////////////
248 // DNN support
249 //
250 // See DnnSupport::* for comments on the following methods.
251
252 Stream &ThenBatchNormalizationForward(
253 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
254 const DeviceMemory<float> &offset,
255 const DeviceMemory<float> &estimated_mean,
256 const DeviceMemory<float> &estimated_variance,
257 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
258 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
259 const double exponential_average_factor,
260 dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
261 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
262 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
263 bool is_training, ScratchAllocator *reserve_space_allocator,
264 ScratchAllocator *workspace_allocator);
265
266 Stream &ThenBatchNormalizationBackward(
267 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
268 const DeviceMemory<float> &scale, const DeviceMemory<float> &offset,
269 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
270 const DeviceMemory<float> &y, const dnn::BatchDescriptor &x_desc,
271 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
272 dnn::ActivationMode activation_mode, DeviceMemory<float> *x_backprop,
273 DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
274 DeviceMemory<float> *side_input_backprop,
275 DeviceMemory<uint8> *reserve_space_data,
276 ScratchAllocator *workspace_allocator);
277
278 Stream &ThenBatchNormalizationForward(
279 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
280 const DeviceMemory<float> &offset,
281 const DeviceMemory<float> &estimated_mean,
282 const DeviceMemory<float> &estimated_variance,
283 const DeviceMemory<Eigen::half> &side_input,
284 const dnn::BatchDescriptor &x_desc,
285 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
286 const double exponential_average_factor,
287 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
288 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
289 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
290 bool is_training, ScratchAllocator *reserve_space_allocator,
291 ScratchAllocator *workspace_allocator);
292
293 Stream &ThenBatchNormalizationBackward(
294 const DeviceMemory<Eigen::half> &y_backprop,
295 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
296 const DeviceMemory<float> &offset, const DeviceMemory<float> &mean,
297 const DeviceMemory<float> &inv_var, const DeviceMemory<Eigen::half> &y,
298 const dnn::BatchDescriptor &x_desc,
299 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
300 dnn::ActivationMode activation_mode,
301 DeviceMemory<Eigen::half> *x_backprop,
302 DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
303 DeviceMemory<Eigen::half> *side_input_backprop,
304 DeviceMemory<uint8> *reserve_space_data,
305 ScratchAllocator *workspace_allocator);
306
307 Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
308 const DeviceMemory<float> &input_data,
309 const dnn::FilterDescriptor &filter_descriptor,
310 const DeviceMemory<float> &filter_data,
311 const dnn::ConvolutionDescriptor &convolution_descriptor,
312 const dnn::BatchDescriptor &output_descriptor,
313 DeviceMemory<float> *output);
314
315 Stream &ThenConvolveQuantized(
316 const dnn::BatchDescriptor &input_descriptor,
317 const DeviceMemory<float> &input_data,
318 const dnn::FilterDescriptor &filter_descriptor,
319 const DeviceMemory<int8> &filter_coefficients,
320 const DeviceMemory<float> &coefficient_scales,
321 const dnn::ConvolutionDescriptor &convolution_descriptor,
322 const dnn::BatchDescriptor &output_descriptor,
323 DeviceMemory<float> *output_data);
324
325 Stream &ThenConvolveQuantized(
326 const dnn::BatchDescriptor &input_descriptor,
327 const DeviceMemory<float> &input_data,
328 const dnn::FilterDescriptor &filter_descriptor,
329 const DeviceMemory<int16> &filter_coefficients,
330 const DeviceMemory<float> &coefficient_scales,
331 const dnn::ConvolutionDescriptor &convolution_descriptor,
332 const dnn::BatchDescriptor &output_descriptor,
333 DeviceMemory<float> *output_data);
334
335 template <typename InputType, typename OutputType>
336 port::Status ConvolveWithAlgorithm(
337 dnn::ConvolutionKind kind, const dnn::BatchDescriptor &input_descriptor,
338 DeviceMemory<InputType> input_data,
339 const dnn::FilterDescriptor &filter_descriptor,
340 DeviceMemory<InputType> filter_data,
341 const dnn::BatchDescriptor &output_descriptor,
342 DeviceMemory<OutputType> output_data,
343 const dnn::ConvolutionDescriptor &convolution_descriptor,
344 ScratchAllocator *scratch_allocator,
345 const dnn::AlgorithmConfig &algorithm_config,
346 dnn::ProfileResult *output_profile_result) {
347 DeviceMemory<uint8> scratch_memory;
348 dnn::AlgorithmDesc algorithm_desc;
349 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
350 TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
351 kind, this, input_descriptor, input_data, filter_descriptor,
352 filter_data, output_descriptor, output_data, convolution_descriptor,
353 algorithm_config, scratch_allocator, &algorithm_desc,
354 &scratch_memory));
355 return dnn->DoConvolve(kind, dnn::ToDataType<InputType>::value,
356 dnn::ToDataType<OutputType>::value, this,
357 input_descriptor, input_data, filter_descriptor,
358 filter_data, output_descriptor, output_data,
359 convolution_descriptor, algorithm_desc,
360 scratch_memory, output_profile_result);
361 }
362 return port::UnimplementedError("DNN library is not found.");
363 }
364
365 template <typename InputT, typename ScaleT, typename SideInputT,
366 typename BiasT, typename OutputT>
367 port::Status FusedConvolveWithAlgorithm(
368 const dnn::BatchDescriptor &conv_input_descriptor,
369 const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale,
370 const dnn::FilterDescriptor &filter_descriptor,
371 const DeviceMemory<InputT> &filter_data,
372 const dnn::ConvolutionDescriptor &convolution_descriptor,
373 const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale,
374 const dnn::BatchDescriptor &bias_descriptor,
375 const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode,
376 const dnn::BatchDescriptor &output_descriptor,
377 DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator,
378 const dnn::AlgorithmConfig &algorithm_config,
379 dnn::ProfileResult *output_profile_result) {
380 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
381 return dnn->DoFusedConvolve(
382 this, dnn::ToDataType<InputT>::value,
383 dnn::ToDataType<SideInputT>::value, dnn::ToDataType<BiasT>::value,
384 dnn::ToDataType<OutputT>::value, conv_input_descriptor,
385 conv_input_data, conv_input_scale, filter_descriptor, filter_data,
386 convolution_descriptor, side_input_data, side_input_scale,
387 bias_descriptor, biases, activation_mode, output_descriptor, *output,
388 scratch_allocator, algorithm_config, output_profile_result);
389 }
390 return port::UnimplementedError("DNN library is not found.");
391 }
392
393 port::StatusOr<std::unique_ptr<const dnn::ConvRunner>> ConvolveRunnerFromDesc(
394 const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind,
395 dnn::DataType element_type, dnn::DataType output_type,
396 const dnn::BatchDescriptor &input_descriptor,
397 const dnn::FilterDescriptor &filter_descriptor,
398 const dnn::BatchDescriptor &output_descriptor,
399 const dnn::ConvolutionDescriptor &convolution_descriptor) {
400 dnn::DnnSupport *dnn_support = parent_->AsDnn();
401 if (!dnn_support) {
402 return port::UnimplementedError("DNN library is not found.");
403 }
404 return dnn_support->ConvolveRunnerFromDesc(
405 this, algorithm_desc, kind, element_type, output_type, input_descriptor,
406 filter_descriptor, output_descriptor, convolution_descriptor);
407 }
408
409 port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
410 FusedConvolveRunnerFromDesc(
411 const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind,
412 dnn::DataType element_type, dnn::DataType bias_type,
413 dnn::DataType output_type, double conv_input_scale,
414 double side_input_scale, double leakyrelu_alpha,
415 const dnn::BatchDescriptor &input_descriptor,
416 const dnn::FilterDescriptor &filter_descriptor,
417 const dnn::BatchDescriptor &bias_descriptor,
418 const dnn::BatchDescriptor &output_descriptor,
419 const dnn::ConvolutionDescriptor &convolution_descriptor,
420 dnn::ActivationMode activation_mode) {
421 dnn::DnnSupport *dnn_support = parent_->AsDnn();
422 if (!dnn_support) {
423 return port::UnimplementedError("DNN library is not found.");
424 }
425 return dnn_support->FusedConvolveRunnerFromDesc(
426 this, algorithm_desc, kind, element_type, bias_type, output_type,
427 conv_input_scale, side_input_scale, leakyrelu_alpha, input_descriptor,
428 filter_descriptor, bias_descriptor, output_descriptor,
429 convolution_descriptor, activation_mode);
430 }
431
432 Stream &ThenSeparableConvolve(
433 const dnn::BatchDescriptor &input_descriptor,
434 const DeviceMemory<float> &input_data,
435 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
436 const DeviceMemory<float> &first_weights,
437 const DeviceMemory<float> &second_weights,
438 const dnn::ConvolutionDescriptor &convolution_descriptor,
439 const dnn::BatchDescriptor &output_descriptor,
440 DeviceMemory<float> *output);
441
442 Stream &ThenMatMul(const DeviceMemory<float> &input_data,
443 const DeviceMemory<float> &weights,
444 const dnn::BatchDescriptor &input_dimensions,
445 const dnn::BatchDescriptor &output_dimensions,
446 DeviceMemory<float> *output_data);
447
448 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
449 const DeviceMemory<int8> &weights,
450 const DeviceMemory<float> &weight_scales,
451 const dnn::BatchDescriptor &input_dimensions,
452 const dnn::BatchDescriptor &output_dimensions,
453 DeviceMemory<float> *output_data);
454
455 Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
456 const DeviceMemory<int16> &weights,
457 const DeviceMemory<float> &weight_scales,
458 const dnn::BatchDescriptor &input_dimensions,
459 const dnn::BatchDescriptor &output_dimensions,
460 DeviceMemory<float> *output_data);
461
462 Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
463 const DeviceMemory<float> &biases,
464 const dnn::BatchDescriptor &dimensions,
465 DeviceMemory<float> *output_data);
466
467 template <typename ElementType>
468 port::Status ThenPoolForward(
469 const dnn::PoolingDescriptor &pooling_dimensions,
470 const dnn::BatchDescriptor &input_dimensions,
471 const DeviceMemory<ElementType> &input_data,
472 const dnn::BatchDescriptor &output_dimensions,
473 DeviceMemory<ElementType> *output_data,
474 ScratchAllocator *workspace_allocator = nullptr) {
475 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
476 return dnn->DoPoolForward(dnn::ToDataType<ElementType>::value, this,
477 pooling_dimensions, input_dimensions,
478 input_data, output_dimensions, *output_data,
479 workspace_allocator);
480 }
481 return port::UnimplementedError("DNN library is not found.");
482 }
483
484 template <typename ElementType>
485 port::Status ThenPoolBackward(
486 const dnn::PoolingDescriptor &pooling_dimensions,
487 const dnn::BatchDescriptor &input_dimensions,
488 const DeviceMemory<ElementType> &input_data,
489 const dnn::BatchDescriptor &output_dimensions,
490 const DeviceMemory<ElementType> &output_data,
491 const DeviceMemory<ElementType> &input_diff_data,
492 DeviceMemory<ElementType> *output_diff_data,
493 ScratchAllocator *workspace_allocator = nullptr) {
494 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
495 return dnn->DoPoolBackward(
496 dnn::ToDataType<ElementType>::value, this, pooling_dimensions,
497 input_dimensions, input_data, output_dimensions, output_data,
498 input_diff_data, *output_diff_data, workspace_allocator);
499 }
500 return port::UnimplementedError("DNN library is not found.");
501 }
502
503 Stream &ThenNormalizeWithDimensions(
504 const dnn::NormalizeDescriptor &normalize_descriptor,
505 const dnn::BatchDescriptor &dimensions,
506 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
507
508 Stream &ThenNormalizeBackwardWithDimensions(
509 const dnn::NormalizeDescriptor &normalize_descriptor,
510 const dnn::BatchDescriptor &dimensions,
511 const DeviceMemory<float> &raw_data,
512 const DeviceMemory<float> &normalized_data,
513 const DeviceMemory<float> &normalized_variable_gradient,
514 DeviceMemory<float> *raw_variable_gradient,
515 ScratchAllocator *workspace_allocator = nullptr);
516
517 Stream &ThenActivate(dnn::ActivationMode activation_mode,
518 const dnn::BatchDescriptor &dimensions,
519 const DeviceMemory<float> &input_data,
520 DeviceMemory<float> *output_data);
521
522 // Same as ThenActivate, but also takes an options argument that can be used
523 // for platform-specific option flags.
524 Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
525 const dnn::BatchDescriptor &dimensions,
526 const DeviceMemory<float> &input_data,
527 DeviceMemory<float> *output_data,
528 uint64_t options);
529
530 Stream &ThenDepthConcatenate(
531 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
532 port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
533 DeviceMemory<float> *output_data);
534
535 Stream &ThenSpaceConcatenate(
536 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
537 port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
538 DeviceMemory<float> *output_data,
539 dnn::SpaceConcatenateMode concat_direction);
540
541 // Change the layout of the data by shrinking one dimension (or set of
542 // dimensions) and growing another dimension (or set of dimensions), while
543 // keeping the total number of data elements constant, and maintaining the
544 // current data ordering.
545 Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
546 const DeviceMemory<float> &input_data,
547 const dnn::BatchDescriptor &output_dimensions,
548 DeviceMemory<float> *output_data);
549
550 // Depth to space takes an X by Y image with depth D*M² and changes it to an
551 // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
552 // the input image is changed to an MxM contiguous area in the output image,
553 // with the values being laid out in raster order specified by
554 // DepthToSpaceLayout, and will have a new depth of D.
555 // See the DoDepthToSpace comment for more information.
556 Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
557 const DeviceMemory<float> &input_data,
558 const dnn::DepthToSpaceLayout &depth_to_space_layout,
559 const int sqrt_depth_reduction,
560 DeviceMemory<float> *output_data);
561
562 // Space to depth is the inverse of depth to space. Space to depth takes each
563 // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
564 // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
565 // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
566 // data elements is not changed.
567 Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
568 const DeviceMemory<float> &input_data,
569 const dnn::DepthToSpaceLayout &space_to_depth_layout,
570 const int sqrt_depth_increase,
571 DeviceMemory<float> *output_data);
572
573 Stream &ThenElementwiseOperate(
574 dnn::ElementwiseOperation operation,
575 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
576 port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
577 const dnn::BatchDescriptor &output_dimensions,
578 DeviceMemory<float> *output_data);
579
580 Stream &ThenElementwiseOperateScaledQuantized(
581 dnn::ElementwiseOperation operation,
582 port::ArraySlice<int> input_multiplicands, // non-absl ok
583 int output_divisor,
584 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
585 port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
586 const dnn::BatchDescriptor &output_dimensions,
587 DeviceMemory<float> *output_data);
588
589 Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
590 const DeviceMemory<float> &input_data, int64_t left_pad,
591 int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
592 DeviceMemory<float> *output_data);
593
594 Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
595 const DeviceMemory<float> &input_data, int64_t left_trim,
596 int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
597 DeviceMemory<float> *output_data);
598
599 // Grows the input tensor by replicating the X and Y dimensions. The batch and
600 // depth/feature_map dimensions are unchanged. Currently, the input tensor is
601 // limited to X=1 and Y=1.
602 Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
603 const DeviceMemory<float> &input_data,
604 int64_t replicate_x, int64_t replicate_y,
605 DeviceMemory<float> *output_data);
606
607 // See DnnSupport::DoMemcpyD2HQuantized.
608 Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
609 dnn::QuantizedActivationMode mode,
610 void *host_dst, uint64_t size);
611
612 // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
613 // and uses the Quantization trait to call the generic version of
614 // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
615 template <typename ElementType>
616 Stream &ThenMemcpyD2HQuantized(
617 const DeviceMemory<float> &gpu_unquantized_src,
618 port::MutableArraySlice<ElementType> host_dst) {
619 return ThenMemcpyD2HQuantized(
620 gpu_unquantized_src, Quantization<ElementType>::kModeId,
621 host_dst.data(), host_dst.size() * sizeof(ElementType));
622 }
623
624 // See DnnSupport::DoMemcpyH2DQuantized.
625 Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64_t size,
626 dnn::QuantizedActivationMode mode,
627 DeviceMemory<float> *gpu_unquantized_dst);
628
629 // Template version of ThenMemcpyH2DQuantized that takes an array slice
630 // and uses the Quantization trait to call the generic version of
631 // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
632 template <typename ElementType>
633 Stream &ThenMemcpyH2DQuantized(
634 port::ArraySlice<ElementType> host_src, // non-absl ok
635 DeviceMemory<float> *gpu_unquantized_dst) {
636 return ThenMemcpyH2DQuantized(
637 host_src.data(), host_src.size() * sizeof(ElementType),
638 Quantization<ElementType>::kModeId, gpu_unquantized_dst);
639 }
640
641 // See DnnSupport::DoCopyHostBuffer2Device.
642 Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
643 DeviceMemory<float> *gpu_unquantized_dst);
644
645 // See DnnSupport::DoCopyDevice2HostBuffer.
646 Stream &ThenCopyDevice2HostBuffer(
647 const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
648
649 /////////////////
650 // BLAS support
651
652 // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
653 // present in DeviceMemory, it must be an execution-time constant (i.e. a
654 // value that the stream does not change or populate during the course of
655 // execution). The value is effectively captured at stream-enqueue time.
657 Stream &ThenBlasAxpy(uint64_t elem_count, float alpha,
658 const DeviceMemory<float> &x, int incx,
659 DeviceMemory<float> *y, int incy);
660 Stream &ThenBlasAxpy(uint64_t elem_count, double alpha,
661 const DeviceMemory<double> &x, int incx,
662 DeviceMemory<double> *y, int incy);
663 Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<float> alpha,
664 const DeviceMemory<std::complex<float>> &x, int incx,
665 DeviceMemory<std::complex<float>> *y, int incy);
666 Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<double> alpha,
667 const DeviceMemory<std::complex<double>> &x, int incx,
668 DeviceMemory<std::complex<double>> *y, int incy);
669
670 // See BlasSupport::DoBlasCopy.
671 Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<float> &x,
672 int incx, DeviceMemory<float> *y, int incy);
673 Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<double> &x,
674 int incx, DeviceMemory<double> *y, int incy);
675 Stream &ThenBlasCopy(uint64_t elem_count,
676 const DeviceMemory<std::complex<float>> &x, int incx,
677 DeviceMemory<std::complex<float>> *y, int incy);
678 Stream &ThenBlasCopy(uint64_t elem_count,
679 const DeviceMemory<std::complex<double>> &x, int incx,
680 DeviceMemory<std::complex<double>> *y, int incy);
681
682 // See BlasSupport::DoBlasScal.
683 Stream &ThenBlasScal(uint64_t elem_count, float alpha, DeviceMemory<float> *x,
684 int incx);
685 Stream &ThenBlasScal(uint64_t elem_count, double alpha,
686 DeviceMemory<double> *x, int incx);
687 Stream &ThenBlasScal(uint64_t elem_count, float alpha,
688 DeviceMemory<std::complex<float>> *x, int incx);
689 Stream &ThenBlasScal(uint64_t elem_count, double alpha,
690 DeviceMemory<std::complex<double>> *x, int incx);
691 Stream &ThenBlasScal(uint64_t elem_count, std::complex<float> alpha,
692 DeviceMemory<std::complex<float>> *x, int incx);
693 Stream &ThenBlasScal(uint64_t elem_count, std::complex<double> alpha,
694 DeviceMemory<std::complex<double>> *x, int incx);
695
696 // See BlasSupport::DoBlasGemv.
697 Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, float alpha,
698 const DeviceMemory<float> &a, int lda,
699 const DeviceMemory<float> &x, int incx, float beta,
700 DeviceMemory<float> *y, int incy);
701 Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
702 double alpha, const DeviceMemory<double> &a, int lda,
703 const DeviceMemory<double> &x, int incx, double beta,
704 DeviceMemory<double> *y, int incy);
705 Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
706 std::complex<float> alpha,
707 const DeviceMemory<std::complex<float>> &a, int lda,
708 const DeviceMemory<std::complex<float>> &x, int incx,
709 std::complex<float> beta,
710 DeviceMemory<std::complex<float>> *y, int incy);
711 Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
712 std::complex<double> alpha,
713 const DeviceMemory<std::complex<double>> &a, int lda,
714 const DeviceMemory<std::complex<double>> &x, int incx,
715 std::complex<double> beta,
716 DeviceMemory<std::complex<double>> *y, int incy);
717
718 Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n,
719 float alpha, const DeviceMemory<float> &a,
720 int lda, const DeviceMemory<float> &x,
721 int incx, float beta,
722 DeviceMemory<float> *y, int incy,
723 blas::ProfileResult *output_profile_result);
724 Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n,
725 double alpha, const DeviceMemory<double> &a,
726 int lda, const DeviceMemory<double> &x,
727 int incx, double beta,
728 DeviceMemory<double> *y, int incy,
729 blas::ProfileResult *output_profile_result);
730 Stream &ThenBlasGemvWithProfiling(
731 blas::Transpose trans, uint64_t m, uint64 n, std::complex<float> alpha,
732 const DeviceMemory<std::complex<float>> &a, int lda,
733 const DeviceMemory<std::complex<float>> &x, int incx,
734 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
735 blas::ProfileResult *output_profile_result);
736 Stream &ThenBlasGemvWithProfiling(
737 blas::Transpose trans, uint64_t m, uint64 n, std::complex<double> alpha,
738 const DeviceMemory<std::complex<double>> &a, int lda,
739 const DeviceMemory<std::complex<double>> &x, int incx,
740 std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
741 int incy, blas::ProfileResult *output_profile_result);
742
743 // See BlasSupport::DoBlasSbmv.
744 Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, float alpha,
745 const DeviceMemory<float> &a, int lda,
746 const DeviceMemory<float> &x, int incx, float beta,
747 DeviceMemory<float> *y, int incy);
748 Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k,
749 double alpha, const DeviceMemory<double> &a, int lda,
750 const DeviceMemory<double> &x, int incx, double beta,
751 DeviceMemory<double> *y, int incy);
752
753 template <typename InputType>
754 port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
755 uint64_t m, uint64 n, uint64 k,
756 const DeviceMemory<InputType> &a, int lda,
757 const DeviceMemory<InputType> &b, int ldb,
758 DeviceMemory<InputType> *c, int ldc,
759 blas::ComputePrecision precision) {
760 InputType alpha{1.0};
761 InputType beta{0.0};
762 return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
763 ldc, precision);
764 }
765
766 // TODO(parkers): Update all callers to pass kDefaultComputePrecision.
767 template <typename InputType>
768 port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
769 uint64_t m, uint64 n, uint64 k,
770 const DeviceMemory<InputType> &a, int lda,
771 const DeviceMemory<InputType> &b, int ldb,
772 DeviceMemory<InputType> *c, int ldc) {
773 return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc,
774 blas::kDefaultComputePrecision);
775 }
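// Illustrative sketch (assumes `a`, `b`, and `c` are DeviceMemory<float>
// buffers holding BLAS-layout matrices of compatible shapes; the dimensions
// below are arbitrary):
//
//   TF_RETURN_IF_ERROR(stream.ThenBlasGemm(
//       blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose,
//       /*m=*/64, /*n=*/32, /*k=*/128, a, /*lda=*/64, b, /*ldb=*/128, &c,
//       /*ldc=*/64));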
776
777 template <typename InputType, typename ConstantType>
778 port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
779 uint64_t m, uint64 n, uint64 k, ConstantType alpha,
780 const DeviceMemory<InputType> &a, int lda,
781 const DeviceMemory<InputType> &b, int ldb,
782 ConstantType beta, DeviceMemory<InputType> *c,
783 int ldc, blas::ComputePrecision precision) {
784 static_assert(
785 detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, float,
786 double, std::complex<float>, std::complex<double>>(),
787 "Input can be half, bf16, float, double, std::complex<float> or "
788 "std::complex<double>");
789 static_assert(!std::is_same_v<InputType, Eigen::half> ||
790 detail::is_any_of<ConstantType, float, Eigen::half>(),
791 "If input is Eigen::half, constant has to be either "
792 "Eigen::half or float");
793 static_assert(
794 detail::is_any_of<InputType, Eigen::half, ConstantType>(),
795 "If input is not Eigen::half, constant and input types have to match");
796 blas::BlasSupport *blas = parent()->AsBlas();
797 if (!blas) {
798 return port::InternalError(
799 "Attempting to perform BLAS operation using "
800 "StreamExecutor without BLAS support");
801 }
802
803 void *alpha_ptr = α
804 void *beta_ptr = β
805 float alpha_storage, beta_storage;
806 UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
807 &beta_storage);
808
809 return blas->DoBlasGemm(this, transa, transb, m, n, k,
810 blas::ToDataType<InputType>::value, alpha_ptr, a,
811 lda, b, ldb, beta_ptr, c, ldc, precision);
812 }
813
814 // TODO(parkers): Update all callers to pass kDefaultComputePrecision.
815 template <typename InputType, typename ConstantType>
816 port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
817 uint64_t m, uint64 n, uint64 k, ConstantType alpha,
818 const DeviceMemory<InputType> &a, int lda,
819 const DeviceMemory<InputType> &b, int ldb,
820 ConstantType beta, DeviceMemory<InputType> *c,
821 int ldc) {
822 return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
823 ldc, blas::kDefaultComputePrecision);
824 }
825
826 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
827 blas::Transpose transb, uint64_t m,
828 uint64 n, uint64_t k, float alpha,
829 const DeviceMemory<Eigen::half> &a, int lda,
830 const DeviceMemory<Eigen::half> &b, int ldb,
831 float beta, DeviceMemory<Eigen::half> *c,
832 int ldc,
833 blas::ProfileResult *output_profile_result);
834 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
835 blas::Transpose transb, uint64_t m,
836 uint64 n, uint64_t k, float alpha,
837 const DeviceMemory<float> &a, int lda,
838 const DeviceMemory<float> &b, int ldb,
839 float beta, DeviceMemory<float> *c, int ldc,
840 blas::ProfileResult *output_profile_result);
841 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
842 blas::Transpose transb, uint64_t m,
843 uint64 n, uint64_t k, double alpha,
844 const DeviceMemory<double> &a, int lda,
845 const DeviceMemory<double> &b, int ldb,
846 double beta, DeviceMemory<double> *c,
847 int ldc,
848 blas::ProfileResult *output_profile_result);
849 Stream &ThenBlasGemmWithProfiling(
850 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
851 uint64_t k, std::complex<float> alpha,
852 const DeviceMemory<std::complex<float>> &a, int lda,
853 const DeviceMemory<std::complex<float>> &b, int ldb,
854 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
855 blas::ProfileResult *output_profile_result);
856 Stream &ThenBlasGemmWithProfiling(
857 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
858 uint64_t k, std::complex<double> alpha,
859 const DeviceMemory<std::complex<double>> &a, int lda,
860 const DeviceMemory<std::complex<double>> &b, int ldb,
861 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
862 blas::ProfileResult *output_profile_result);
863
864 template <typename InputType, typename OutputType>
865 port::Status ThenBlasGemmWithAlgorithm(
866 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
867 uint64_t k, const DeviceMemory<InputType> &a, int lda,
868 const DeviceMemory<InputType> &b, int ldb, DeviceMemory<OutputType> *c,
869 int ldc, blas::ComputationType computation_type,
870 blas::AlgorithmType algorithm,
871 blas::ProfileResult *output_profile_result) {
872 OutputType alpha{1};
873 OutputType beta{0};
874 return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b,
875 ldb, beta, c, ldc, computation_type,
876 algorithm, output_profile_result);
877 }
878
879 template <typename InputType, typename OutputType, typename ConstantType>
880 port::Status ThenBlasGemmWithAlgorithm(
881 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
882 uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
883 const DeviceMemory<InputType> &b, int ldb, ConstantType beta,
884 DeviceMemory<OutputType> *c, int ldc,
885 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
886 blas::ProfileResult *output_profile_result) {
887 TF_RETURN_IF_ERROR(
888 CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
889 computation_type));
890
891 blas::BlasSupport *blas = parent()->AsBlas();
892 if (!blas) {
893 return port::InternalError(
894 "Attempting to perform BLAS operation using "
895 "StreamExecutor without BLAS support");
896 }
897
898 void *alpha_ptr = α
899 void *beta_ptr = β
900 float alpha_storage, beta_storage;
901 UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
902 &beta_storage);
903
904 port::Status st = blas->DoBlasGemmWithAlgorithm(
905 this, transa, transb, m, n, k, alpha_ptr, a,
906 blas::ToDataType<InputType>::value, lda, b,
907 blas::ToDataType<InputType>::value, ldb, beta_ptr, c,
908 blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm,
909 output_profile_result);
910 if (output_profile_result) {
911 // The error is recorded in the profile.
912 return ::tensorflow::OkStatus();
913 }
914 return st;
915 }
916
917 template <typename InputType, typename OutputType, typename ConstantType>
918 port::Status ThenBlasGemmStridedBatchedWithAlgorithm(
919 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
920 uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
921 int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
922 int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
923 int64_t stride_c, int batch_count, blas::ComputationType computation_type,
924 blas::AlgorithmType algorithm,
925 blas::ProfileResult *output_profile_result) {
926 TF_RETURN_IF_ERROR(
927 CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
928 computation_type));
929
930 blas::BlasSupport *blas = parent()->AsBlas();
931 if (!blas) {
932 return port::InternalError(
933 "Attempting to perform BLAS operation using "
934 "StreamExecutor without BLAS support");
935 }
936 void *alpha_ptr = α
937 void *beta_ptr = β
938 float alpha_storage, beta_storage;
939 UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
940 &beta_storage);
941 port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm(
942 this, transa, transb, m, n, k, alpha_ptr, a,
943 blas::ToDataType<InputType>::value, lda, stride_a, b,
944 blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c,
945 blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
946 computation_type, algorithm, output_profile_result);
947 if (output_profile_result) {
948 // The error is recorded in the profile.
949 return ::tensorflow::OkStatus();
950 }
951 return st;
952 }
953
954 template <typename T>
955 using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>; // non-absl ok
956
957 // See BlasSupport::DoBlasGemmBatched.
958 Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
959 uint64_t m, uint64 n, uint64_t k, float alpha,
960 const DeviceMemorySlice<Eigen::half> &a, int lda,
961 const DeviceMemorySlice<Eigen::half> &b, int ldb,
962 float beta,
963 const DeviceMemorySlice<Eigen::half> &c, int ldc,
964 int batch_count);
965 Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
966 uint64_t m, uint64 n, uint64 k, float alpha,
967 const DeviceMemorySlice<float> &a, int lda,
968 const DeviceMemorySlice<float> &b, int ldb,
969 float beta, const DeviceMemorySlice<float> &c,
970 int ldc, int batch_count);
971 Stream &ThenBlasGemmBatched(
972 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
973 uint64 k, double alpha,
974 const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok
975 int lda,
976 const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok
977 int ldb, double beta,
978 const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok
979 int ldc, int batch_count);
980 Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
981 uint64_t m, uint64 n, uint64_t k,
982 std::complex<float> alpha,
983 const DeviceMemorySlice<std::complex<float>> &a,
984 int lda,
985 const DeviceMemorySlice<std::complex<float>> &b,
986 int ldb, std::complex<float> beta,
987 const DeviceMemorySlice<std::complex<float>> &c,
988 int ldc, int batch_count);
989 Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
990 uint64_t m, uint64 n, uint64_t k,
991 std::complex<double> alpha,
992 const DeviceMemorySlice<std::complex<double>> &a,
993 int lda,
994 const DeviceMemorySlice<std::complex<double>> &b,
995 int ldb, std::complex<double> beta,
996 const DeviceMemorySlice<std::complex<double>> &c,
997 int ldc, int batch_count);
998 Stream &ThenBlasGemmBatchedWithScratch(
999 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1000 uint64_t k, float alpha, const DeviceMemorySlice<Eigen::half> &a, int lda,
1001 const DeviceMemorySlice<Eigen::half> &b, int ldb, float beta,
1002 const DeviceMemorySlice<Eigen::half> &c, int ldc, int batch_count,
1003 ScratchAllocator *scratch_allocator);
1004 Stream &ThenBlasGemmBatchedWithScratch(
1005 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1006 uint64_t k, float alpha, const DeviceMemorySlice<float> &a, int lda,
1007 const DeviceMemorySlice<float> &b, int ldb, float beta,
1008 const DeviceMemorySlice<float> &c, int ldc, int batch_count,
1009 ScratchAllocator *scratch_allocator);
1010 Stream &ThenBlasGemmBatchedWithScratch(
1011 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1012 uint64_t k, double alpha, const DeviceMemorySlice<double> &a, int lda,
1013 const DeviceMemorySlice<double> &b, int ldb, double beta,
1014 const DeviceMemorySlice<double> &c, int ldc, int batch_count,
1015 ScratchAllocator *scratch_allocator);
1016 Stream &ThenBlasGemmBatchedWithScratch(
1017 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1018 uint64_t k, std::complex<float> alpha,
1019 const DeviceMemorySlice<std::complex<float>> &a, int lda,
1020 const DeviceMemorySlice<std::complex<float>> &b, int ldb,
1021 std::complex<float> beta, const DeviceMemorySlice<std::complex<float>> &c,
1022 int ldc, int batch_count, ScratchAllocator *scratch_allocator);
1023 Stream &ThenBlasGemmBatchedWithScratch(
1024 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1025 uint64_t k, std::complex<double> alpha,
1026 const DeviceMemorySlice<std::complex<double>> &a, int lda,
1027 const DeviceMemorySlice<std::complex<double>> &b, int ldb,
1028 std::complex<double> beta,
1029 const DeviceMemorySlice<std::complex<double>> &c, int ldc,
1030 int batch_count, ScratchAllocator *scratch_allocator);
1031
1032 template <typename InputType, typename ConstantType>
1033 port::Status ThenBlasGemmStridedBatched(
1034 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1035 uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
1036 int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
1037 int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
1038 int64_t stride_c, int batch_count) {
1039 static_assert(
1040 detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16,
1041 double, std::complex<float>, std::complex<double>>(),
1042 "Unsupported input type");
1043 static_assert(
1044 std::is_same_v<ConstantType, InputType> ||
1045 (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() &&
1046 std::is_same_v<ConstantType, float>),
1047 "Mismatched input and alpha/beta types");
1048 blas::BlasSupport *blas = parent()->AsBlas();
1049 if (!blas) {
1050 return port::InternalError(
1051 "Attempting to perform BLAS operation using "
1052 "StreamExecutor without BLAS support");
1053 }
1054
1055 void *alpha_ptr = α
1056 void *beta_ptr = β
1057 float alpha_storage, beta_storage;
1058 UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
1059 &beta_storage);
1060
1061 return blas->DoBlasGemmStridedBatched(
1062 this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
1063 alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
1064 stride_c, batch_count);
1065 }
1066
1067 // See BlasSupport::DoBlasTrsm.
1068 Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1069 blas::Transpose transa, blas::Diagonal diag, uint64_t m,
1070 uint64_t n, float alpha, const DeviceMemory<float> &a,
1071 int lda, DeviceMemory<float> *b, int ldb);
1072 Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1073 blas::Transpose transa, blas::Diagonal diag, uint64_t m,
1074 uint64_t n, double alpha, const DeviceMemory<double> &a,
1075 int lda, DeviceMemory<double> *b, int ldb);
1076 Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1077 blas::Transpose transa, blas::Diagonal diag, uint64_t m,
1078 uint64_t n, std::complex<float> alpha,
1079 const DeviceMemory<std::complex<float>> &a, int lda,
1080 DeviceMemory<std::complex<float>> *b, int ldb);
1081 Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1082 blas::Transpose transa, blas::Diagonal diag, uint64_t m,
1083 uint64_t n, std::complex<double> alpha,
1084 const DeviceMemory<std::complex<double>> &a, int lda,
1085 DeviceMemory<std::complex<double>> *b, int ldb);
1086
1087 // See BlasSupport::DoBlasTrsmBatched.
1088 Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1089 blas::Transpose transa, blas::Diagonal diag,
1090 uint64_t m, uint64 n, float alpha,
1091 const DeviceMemory<float *> &as, int lda,
1092 DeviceMemory<float *> *bs, int ldb,
1093 int batch_count);
1094 Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1095 blas::Transpose transa, blas::Diagonal diag,
1096 uint64_t m, uint64 n, double alpha,
1097 const DeviceMemory<double *> &as, int lda,
1098 DeviceMemory<double *> *bs, int ldb,
1099 int batch_count);
1100 Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1101 blas::Transpose transa, blas::Diagonal diag,
1102 uint64_t m, uint64 n, std::complex<float> alpha,
1103 const DeviceMemory<std::complex<float> *> &as,
1104 int lda, DeviceMemory<std::complex<float> *> *bs,
1105 int ldb, int batch_count);
1106 Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1107 blas::Transpose transa, blas::Diagonal diag,
1108 uint64_t m, uint64 n, std::complex<double> alpha,
1109 const DeviceMemory<std::complex<double> *> &as,
1110 int lda, DeviceMemory<std::complex<double> *> *bs,
1111 int ldb, int batch_count);
1112
1113 // See FftSupport::DoFft.
1114 Stream &ThenFft(fft::Plan *plan,
1115 const DeviceMemory<std::complex<float>> &input,
1116 DeviceMemory<std::complex<float>> *output);
1117 Stream &ThenFft(fft::Plan *plan,
1118 const DeviceMemory<std::complex<double>> &input,
1119 DeviceMemory<std::complex<double>> *output);
1120 Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
1121 DeviceMemory<std::complex<float>> *output);
1122 Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
1123 DeviceMemory<std::complex<double>> *output);
1124 Stream &ThenFft(fft::Plan *plan,
1125 const DeviceMemory<std::complex<float>> &input,
1126 DeviceMemory<float> *output);
1127 Stream &ThenFft(fft::Plan *plan,
1128 const DeviceMemory<std::complex<double>> &input,
1129 DeviceMemory<double> *output);
1130
  // Makes the RNG use the provided value as the basis for further generation.
  // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
  // sources of seed data if the default (high quality) sources are not
  // desired.
  // For most use cases, this function will not be necessary; each provided
  // back-end implementation will be appropriately seeded by default.
  // At a minimum 16 bytes of data are required in the seed buffer.
  //
  // To seed with good (non-reproducible) data:
  //   File* f = File::Open("/dev/random", "r");
  //   int64_t bytes_read = f->Read(seed_data, bytes_to_read);
  //   < error checking >
  //   stream.ThenSetRngSeed(seed_data, bytes_read);
  //
  // To seed with reproducible data:
  //   uint64_t seed_data[2] = { <data> };
  //   stream.ThenSetRngSeed(seed_data, 16);
  Stream &ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes);

  // Populates the memory indicated by values with uniform-random-distribution
  // values. TODO(leary) seeding API/description
  //
  // Uses the type and size of the DeviceMemory to infer what data should be
  // populated.
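  //
  // A minimal usage sketch (assuming an RNG-capable backend and a hypothetical
  // element count `n`); population is asynchronous, so block before reading
  // back:
  //   DeviceMemory<float> values = stream.parent()->AllocateArray<float>(n);
  //   stream.ThenPopulateRandGaussian(/*mean=*/0.0f, /*stddev=*/1.0f, &values);
  //   TF_CHECK_OK(stream.BlockHostUntilDone());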
  Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
  Stream &ThenPopulateRandGaussian(float mean, float stddev,
                                   DeviceMemory<float> *values);
  Stream &ThenPopulateRandGaussian(double mean, double stddev,
                                   DeviceMemory<double> *values);

  // Entrain onto the stream: a memcpy to a host destination from a GPU source
  // of the given target size. host_dst must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
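  //
  // For example (a sketch; `gpu_src` is an assumed DeviceMemory<float> of
  // `count` elements):
  //   float *host_dst = static_cast<float *>(
  //       stream.parent()->HostMemoryAllocate(count * sizeof(float)));
  //   stream.ThenMemcpy(host_dst, gpu_src, count * sizeof(float));
  //   TF_CHECK_OK(stream.BlockHostUntilDone());  // Wait before reading host_dst.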
  Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                     uint64_t size);

  // Entrain onto the stream: a memcpy to a GPU destination from a host source
  // of the given target size. host_src must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                     uint64_t size);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the host destination slice can accommodate the
  // device source size.
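  //
  // For example (a sketch; `gpu_src` is an assumed DeviceMemory<float>, and the
  // host buffer is subject to the host-memory requirements described for
  // ThenMemcpy above):
  //   std::vector<float> host(gpu_src.ElementCount());
  //   stream.ThenMemcpyD2H(
  //       gpu_src, port::MutableArraySlice<float>(host.data(), host.size()));
  //   TF_CHECK_OK(stream.BlockHostUntilDone());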
  template <typename T>
  Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
                        port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
    return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
  }

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <typename T>
  Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,  // non-absl ok
                        DeviceMemory<T> *gpu_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
    return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
  }

  // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
  // of the given target size. gpu_src/dst must be pointers to GPU memory and
  // peer access must be enabled between their owning StreamExecutors.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
                     uint64_t size);

  // Forwards to the device-to-device copy overload of ThenMemcpy -- useful for
  // ensuring that a host pointer isn't accidentally confused with a device
  // pointer when you're not doing metaprogramming against the API.
  Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
                        const DeviceMemoryBase &gpu_src, uint64_t size) {
    return ThenMemcpy(gpu_dst, gpu_src, size);
  }

  // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
  // The location must not be null.
  Stream &ThenMemZero(DeviceMemoryBase *location, uint64_t size);

  // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
  // size bytes, where size must be evenly divisible by 4 (i.e. a whole number
  // of 32-bit words). The location must not be null.
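  //
  // For example (a sketch; `buf` is an assumed DeviceMemory<uint32> of `n`
  // elements):
  //   stream.ThenMemZero(&buf, n * sizeof(uint32));               // All zeros.
  //   stream.ThenMemset32(&buf, 0xDEADBEEF, n * sizeof(uint32));  // Pattern.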
  Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
                       uint64_t size);

  // Enqueue a forward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnForward for more details.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<Eigen::half> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<Eigen::half> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<Eigen::half> &input_c_data,
                         const DeviceMemory<Eigen::half> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<Eigen::half> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<Eigen::half> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<Eigen::half> *output_c_data,
                         bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<float> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<float> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<float> &input_c_data,
                         const DeviceMemory<float> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<float> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<float> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<float> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<double> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<double> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<double> &input_c_data,
                         const DeviceMemory<double> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<double> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<double> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<double> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  // Enqueue a backward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnBackward for more details.
  Stream &ThenRnnBackward(
      const dnn::RnnDescriptor &rnn_desc,
      const dnn::RnnSequenceTensorDescriptor &input_desc,
      const DeviceMemory<Eigen::half> &input_data,
      const DeviceMemory<int> &seq_lengths_data,
      const dnn::RnnStateTensorDescriptor &input_h_desc,
      const DeviceMemory<Eigen::half> &input_h_data,
      const dnn::RnnStateTensorDescriptor &input_c_desc,
      const DeviceMemory<Eigen::half> &input_c_data,
      const DeviceMemory<Eigen::half> &params,
      const dnn::RnnSequenceTensorDescriptor &output_desc,
      const DeviceMemory<Eigen::half> &output_data,
      const dnn::RnnStateTensorDescriptor &output_h_desc,
      const DeviceMemory<Eigen::half> &output_h_data,
      const dnn::RnnStateTensorDescriptor &output_c_desc,
      const DeviceMemory<Eigen::half> &output_c_data,
      const DeviceMemory<Eigen::half> &output_backprop_data,
      const DeviceMemory<Eigen::half> &output_h_backprop_data,
      const DeviceMemory<Eigen::half> &output_c_backprop_data,
      DeviceMemory<Eigen::half> *input_backprop_data,
      DeviceMemory<Eigen::half> *input_h_backprop_data,
      DeviceMemory<Eigen::half> *input_c_backprop_data,
      DeviceMemory<Eigen::half> *params_backprop_data,
      DeviceMemory<uint8> *reserve_space_data,
      ScratchAllocator *workspace_allocator,
      dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<float> &input_data,
                          const DeviceMemory<int> &seq_lengths_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<float> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<float> &input_c_data,
                          const DeviceMemory<float> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<float> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<float> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<float> &output_c_data,
                          const DeviceMemory<float> &output_backprop_data,
                          const DeviceMemory<float> &output_h_backprop_data,
                          const DeviceMemory<float> &output_c_backprop_data,
                          DeviceMemory<float> *input_backprop_data,
                          DeviceMemory<float> *input_h_backprop_data,
                          DeviceMemory<float> *input_c_backprop_data,
                          DeviceMemory<float> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<double> &input_data,
                          const DeviceMemory<int> &seq_lengths_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<double> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<double> &input_c_data,
                          const DeviceMemory<double> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<double> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<double> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<double> &output_c_data,
                          const DeviceMemory<double> &output_backprop_data,
                          const DeviceMemory<double> &output_h_backprop_data,
                          const DeviceMemory<double> &output_c_backprop_data,
                          DeviceMemory<double> *input_backprop_data,
                          DeviceMemory<double> *input_h_backprop_data,
                          DeviceMemory<double> *input_c_backprop_data,
                          DeviceMemory<double> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  // Enqueue a CTCLoss operation onto the stream.
  // See DnnSupport::DoCtcLoss for more details.
  Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                      const DeviceMemory<float> &probs_data,
                      absl::Span<const int> labels_data,
                      absl::Span<const int> labels_lengths_data,
                      absl::Span<const int> input_lengths_data,
                      DeviceMemory<float> *costs_data,
                      const dnn::RnnStateTensorDescriptor &grads_desc,
                      DeviceMemory<float> *grads_data,
                      ScratchAllocator *workspace_allocator);

  // Enqueue onto the stream an operation that transforms a tensor.
  // See DnnSupport::DoTransformTensor for more details.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);

  // The templated version of the above ThenTransformTensor. Useful when the
  // input and output types are statically known.
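  //
  // For example, to convert half-precision data to float (a sketch; the
  // descriptors and device buffers are assumed to already exist):
  //   DeviceMemory<Eigen::half> in_data = ...;
  //   DeviceMemory<float> out_data = ...;
  //   stream.ThenTransformTensor(in_desc, in_data, out_desc, &out_data);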
  template <typename InElemT, typename OutElemT>
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              const DeviceMemory<InElemT> &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              DeviceMemory<OutElemT> *output_data) {
    return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>::value,
                               input_data, output_desc,
                               dnn::ToDataType<OutElemT>::value,
                               /*scale=*/1.0f, output_data);
  }

  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
  //
  // Returns an OK status if the blocking was successful and the stream is ok().
  // Otherwise returns an error describing why the blocking failed.
  port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_);

  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
  //
  // Entrains onto the stream a function to be executed on the host at some
  // point in the future.
  // Async host callbacks DO NOT block the stream as device functions (or as
  // synchronous host callbacks). No synchronization is possible with
  // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is private due to the potential for undefined behavior with
  // synchronization using OpenCL user events.
  // The ONLY lifetime guarantee in these calls is that the StreamExecutor
  // parameter will still be valid - this Stream may not be!
  // Any callbacks requiring device API calls must use this method.
  Stream &ThenEnqueueOnBackgroundThread(
      std::function<void(StreamExecutor *)> task);

  // Returns the (opaque) platform-specific backing object. Ownership is not
  // transferred to the caller.
  internal::StreamInterface *implementation() { return implementation_.get(); }

  // Entrains onto the stream a callback to the host (from the device).
  // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
  // never fail or its failure is inconsequential.
  //
  // This is kept for backward compatibility. Future code should use
  // ThenDoHostCallbackWithStatus and explicitly return a success status.
  // TODO(b/112125301): Eventually remove this method.
  Stream &ThenDoHostCallback(std::function<void()> callback);

  // Entrains onto the stream a callback to the host (from the device).
  // Host callbacks block/occupy the stream just as device functions do
  // (they execute one at a time and block later stream operations).
  // Whether the callback return status affects the result of
  // BlockHostUntilDone is platform-dependent.
  //
  // Behavior is undefined when synchronizing using OpenCL user events.
  // Behavior is undefined if host callbacks call device routines or insert
  // them into any stream.
  //
  // On certain platforms, ThenDoHostCallback is expected to have significant
  // negative effects on performance.
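  //
  // A minimal usage sketch (the callback runs on the host once previously
  // enqueued work on this stream has completed):
  //   stream.ThenDoHostCallbackWithStatus([]() -> port::Status {
  //     LOG(INFO) << "stream work finished";
  //     return ::tensorflow::OkStatus();
  //   });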
  Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);

  // Runs the given callback after the next call to BlockHostUntilDone on this
  // stream (or after the Stream does BlockHostUntilDone in its destructor).
  // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
  // some use cases.
  Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);

  // Returns the StreamExecutor (parent object) associated with this stream.
  StreamExecutor *parent() const {
    CHECK(parent_ != nullptr);
    return parent_;
  }

  // Returns the CUDA compute capability of the device this stream runs on.
  CudaComputeCapability GetCudaComputeCapability() const {
    return parent()->GetDeviceDescription().cuda_compute_capability();
  }

  // Returns the ROCm compute capability of the device this stream runs on.
  RocmComputeCapability GetRocmComputeCapability() const {
    return parent()->GetDeviceDescription().rocm_compute_capability();
  }
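  //
  // These accessors can be used to gate device-generation-specific code paths,
  // e.g. (a sketch):
  //   if (stream.GetCudaComputeCapability().IsAtLeast(8, 0)) {
  //     // Take an Ampere-or-newer code path.
  //   }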

  // Returns the (internal usage) temporary-memory-allocation manager associated
  // with this stream.
  internal::TemporaryMemoryManager *temporary_memory_manager();

  // Returns a debugging string "[stream=0x...,impl=0x...]".
  std::string DebugStreamPointers() const;

 private:
  friend class host::HostBlas;  // for parent_.
  friend class host::HostFft;   // for parent_.
  friend class host::HostRng;   // for parent_.
  template <typename... Args>
  friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
  friend class ocl::CLBlas;    // for parent_.

  // Checks whether types match before a call to the extended BLAS version.
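  //
  // Illustrative sketch of the rules enforced by the static_asserts below:
  //   CheckTypesForExtendedBlas<Eigen::half, Eigen::half, float>(
  //       blas::ComputationType::kF32);  // OK: fp16 buffers, fp32 scale.
  //   CheckTypesForExtendedBlas<int8_t, int32_t, int32_t>(
  //       blas::ComputationType::kI32);  // OK: int8 inputs, int32 output.
  //   // Mixing float inputs with double outputs would fail to compile.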
  template <typename ABType, typename CType, typename ScaleType>
  port::Status CheckTypesForExtendedBlas(
      blas::ComputationType computation_type) {
    static_assert(
        detail::is_any_of<ABType, Eigen::half, Eigen::bfloat16, float, double,
                          int8_t, std::complex<float>, std::complex<double>>(),
        "The only buffer types supported are: Eigen::half, Eigen::bfloat16, "
        "float, double, int8, std::complex<float> and std::complex<double>");
    static_assert(
        std::is_same_v<ABType, CType> ||
            (std::is_same_v<ABType, int8_t> && std::is_same_v<CType, int32_t>),
        "Input and output buffer types should be the same unless input is "
        "int8 and output is int32");
    static_assert(
        std::is_same_v<ScaleType, CType> ||
            (std::is_same_v<ScaleType, float> &&
             detail::is_any_of<CType, Eigen::half, Eigen::bfloat16>()),
        "Mismatched alpha/beta and output types");

    bool valid_computation_type = [computation_type] {
      switch (computation_type) {
        case blas::ComputationType::kF16:
          return std::is_same_v<CType, Eigen::half>;
        case blas::ComputationType::kF32:
          return detail::is_any_of<CType, Eigen::half, Eigen::bfloat16, float,
                                   std::complex<float>>();
        case blas::ComputationType::kF64:
          return detail::is_any_of<CType, double, std::complex<double>>();
        case blas::ComputationType::kI32:
          return std::is_same_v<CType, int32_t>;
        case blas::ComputationType::kF16AsF32:   // fall-through
        case blas::ComputationType::kBF16AsF32:  // fall-through
        case blas::ComputationType::kTF32AsF32:
          return detail::is_any_of<CType, float, std::complex<float>>();
      }
    }();

    if (!valid_computation_type) {
      return port::InternalError(absl::StrCat(
          "Invalid computation type ",
          blas::ComputationTypeString(computation_type), " for output type: ",
          blas::DataTypeString(blas::ToDataType<CType>::value)));
    }
    return ::tensorflow::OkStatus();
  }

  bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
    absl::ReaderMutexLock lock(&mu_);
    return !status_.ok();
  }

  // Sets the error state if operation_retcode is false.
  // This is a useful shorthand for many stream routines.
  void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_);

  // Checks the status and logs the error message, if any.
  void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_);

  void SetError() { CheckError(false /* = operation_retcode */); }

  void SetErrorAndLogNoDnnSupport() {
    SetError();
    LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
                    "without DNN support";
  }

  // Runs the set of callbacks that are intended to run after
  // BlockHostUntilDone.
  void RunAfterBlockHostUntilDoneCallbacks();

  // The StreamExecutor that supports the operation of this stream.
  StreamExecutor *parent_;

  // The platform-dependent implementation that the StreamExecutor interface
  // delegates to.
  std::unique_ptr<internal::StreamInterface> implementation_;

  // Mutex that guards the allocation / error state flags.
  // Mutable so that it can be obtained via const reader lock.
  mutable absl::Mutex mu_;

  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from 0 to 1 with a sanity check.
  // See StreamExecutor::AllocateStream.
  bool allocated_ ABSL_GUARDED_BY(mu_);

  // The last error (if any) of all method calls.
  port::Status status_ ABSL_GUARDED_BY(mu_);

  // Sub-streams that are generated from this stream. Each element has a
  // pointer to a sub-stream and a boolean value indicating whether that
  // sub-stream is ready to be reused.
  std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
      ABSL_GUARDED_BY(mu_);

  // Streams can allocate temporary memory to help with work they enqueue
  // (e.g. for scratch memory spaces). This member tracks those allocations and
  // notes when they can be reclaimed -- reclamation is attempted when
  // BlockHostUntilDone() is called.
  internal::TemporaryMemoryManager temporary_memory_manager_;

  // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
  std::vector<std::function<void()>> after_block_host_until_done_callbacks_
      ABSL_GUARDED_BY(mu_);

  // The non-extended BLAS interface requires alpha/beta to be floats when the
  // input type is Eigen::half or Eigen::bfloat16. However, for consistency it
  // is convenient for this interface to accept those types directly, so they
  // are upcast here.
  template <typename T>
  void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr,
                         float *alpha_storage, float *beta_storage) {
    if (std::is_same<T, Eigen::half>::value) {
      *alpha_storage =
          static_cast<float>(*reinterpret_cast<Eigen::half *>(*alpha_ptr));
      *beta_storage =
          static_cast<float>(*reinterpret_cast<Eigen::half *>(*beta_ptr));
      *alpha_ptr = alpha_storage;
      *beta_ptr = beta_storage;
    } else if (std::is_same<T, Eigen::bfloat16>::value) {
      *alpha_storage =
          static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*alpha_ptr));
      *beta_storage =
          static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*beta_ptr));
      *alpha_ptr = alpha_storage;
      *beta_ptr = beta_storage;
    }
  }

  SE_DISALLOW_COPY_AND_ASSIGN(Stream);
};

////////////
// Inlines

template <typename... Params, typename... Args>
inline port::Status Stream::ThenLaunch(ThreadDim thread_dims,
                                       BlockDim block_dims,
                                       const TypedKernel<Params...> &kernel,
                                       Args... args) {
  KernelInvocationChecker<std::tuple<Params...>,
                          std::tuple<Args...>>::CheckAllStaticAssert();

  // This is the core that allows type-safe kernel launching.
  // Since the platforms take kernel arguments as tuples of (void *, size),
  // we pack the variadic parameters passed as ...args into the desired
  // tuple form and pass that packed form to the StreamExecutor::Launch()
  // implementation.
  KernelArgsArray<sizeof...(args)> kernel_args;
  kernel.PackParams(&kernel_args, args...);
  TF_RETURN_IF_ERROR(
      parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args));
  return ::tensorflow::OkStatus();
}
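
// For example, given a TypedKernel whose parameter pack matches the device
// kernel's signature (a sketch; `add_one_kernel`, `data`, and the element
// count are assumed to exist):
//   TypedKernel<DeviceMemory<float>, int> add_one_kernel = ...;
//   DeviceMemory<float> data = ...;
//   TF_CHECK_OK(stream.ThenLaunch(ThreadDim(256), BlockDim(32),
//                                 add_one_kernel, data,
//                                 /*element_count=*/1024));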

template <typename T>
inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
Stream::AllocateTemporaryArray(uint64_t element_count) {
  return temporary_memory_manager_.AllocateArray<T>(element_count);
}

inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
  return &temporary_memory_manager_;
}

template <>
struct Quantization<uint8> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k8Bit;
};

template <>
struct Quantization<uint16> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k16Bit;
};

template <>
struct Quantization<int32> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k32Bit;
};

}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_