1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20 
21 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_
23 
24 #include <complex>
25 #include <functional>
26 #include <memory>
27 #include <type_traits>
28 
29 #include "absl/base/thread_annotations.h"
30 #include "absl/synchronization/mutex.h"
31 #include "tensorflow/compiler/xla/stream_executor/blas.h"
32 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
33 #include "tensorflow/compiler/xla/stream_executor/dnn.h"
34 #include "tensorflow/compiler/xla/stream_executor/event.h"
35 #include "tensorflow/compiler/xla/stream_executor/fft.h"
36 #include "tensorflow/compiler/xla/stream_executor/kernel.h"
37 #include "tensorflow/compiler/xla/stream_executor/launch_dim.h"
38 #include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
39 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
40 #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
41 #include "tensorflow/compiler/xla/stream_executor/temporary_memory_manager.h"
42 
43 namespace stream_executor {
44 
45 namespace host {
46 class HostBlas;
47 class HostFft;
48 class HostRng;
49 class HostTimer;
50 }  // namespace host
51 
52 namespace ocl {
53 class CLBlas;
54 }  // namespace ocl
55 
56 namespace internal {
57 class StreamInterface;
58 }  // namespace internal
59 
60 class DeviceMemoryBase;
61 template <typename ElemT>
62 class DeviceMemory;
63 
64 class Timer;
65 
66 namespace dnn {
67 class BatchDescriptor;
68 class FilterDescriptor;
69 class ConvolutionDescriptor;
70 class ProfileResult;
71 class AlgorithmDesc;
72 }  // namespace dnn
73 
74 class StreamExecutor;
75 class ScratchAllocator;
76 
77 namespace detail {
78 
79 // Helper class to prevent a template function argument from being deduced. This
80 // is identical to std::type_identity in C++20.
81 template <typename T>
82 struct NonDeduced {
83   using type = T;
84 };
85 template <typename T>
86 using NonDeducedType = typename NonDeduced<T>::type;
87 
88 // Helper to return whether `T` is the same type as `First` or any of `Rest`.
89 template <typename T>
90 constexpr bool is_any_of() {
91   return false;
92 }
93 
94 template <typename T, typename First, typename... Rest>
95 constexpr bool is_any_of() {
96   return std::is_same_v<T, First> || is_any_of<T, Rest...>();
97 }
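// For example, these compile-time checks hold by construction:
//
//   static_assert(is_any_of<float, Eigen::half, float, double>());
//   static_assert(!is_any_of<int, float, double>());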
98 
99 }  // namespace detail
100 
101 // Convert a type to the corresponding QuantizedActivationMode.
102 template <typename ElementType>
103 struct Quantization;
104 
105 // Represents a stream of dependent computations on a GPU device.
106 //
107 // The operations within a stream execute linearly and asynchronously until
108 // BlockHostUntilDone() is invoked, which synchronously joins host code with
109 // the execution of the stream.
110 //
111 // If any given operation fails when entraining work for the stream, ok() will
112 // indicate that an error has occurred. After initialization, once a stream is
113 // !ok(), it will never be ok().
114 //
115 // Thread-safe post-initialization.
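//
// A minimal usage sketch (assuming `executor` is a valid StreamExecutor and
// eliding error handling):
//
//   Stream stream(executor);
//   stream.Init();
//   if (!stream.ok()) { /* initialization failed */ }
//   // ... entrain work via the Then* methods below ...
//   port::Status status = stream.BlockHostUntilDone();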
116 class Stream {
117  public:
118   // Instantiate a stream tied to parent as a platform executor. Work
119   // entrained onto this stream will be launched/managed on that
120   // StreamExecutor's platform.
121   explicit Stream(StreamExecutor *parent);
122 
123   // Deallocates any stream resources that the parent StreamExecutor has
124   // bestowed
125   // upon this object.
126   ~Stream();
127 
128   // Returns whether any errors have occurred while entraining work for this
129   // stream.
130   bool ok() const { return !InErrorState(); }
131 
132   // Retrieves execution status back into the stream from the underlying
133   // implementation without blocking the stream.
134   //
135   // Normally, Stream::BlockHostUntilDone is used to get execution status.
136   // However, some devices use out-of-band mechanisms to ensure their streams
137   // have finished on-device work, without needing to block the streams. (These
138   // devices should also override AllowsSyncOnCompletion to return false.) For
139   // these devices, this method can be used after work is finished to retrieve
140   // execution status.
141   port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_);
142 
143   // Initialize the stream. This must be performed before entraining any other
144   // operations.
145   Stream &Init() TF_LOCKS_EXCLUDED(mu_);
146 
147   // Initializes timer t via the StreamExecutor.
148   Stream &InitTimer(Timer *t);
149 
150   // Convenience wrapper around Init() and InitTimer().
151   Stream &InitWithTimer(Timer *t);
152 
153   // Get or create a sub-stream from this stream. If there is any sub-stream in
154   // the pool that can be reused then just return this sub-stream.  Otherwise
155   // create a new sub-stream.
156   //
157   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
158   Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_);
159 
160   // Return the sub-stream back to the host stream so that it can be reused
161   // later. Sub-streams that are !ok() will not be reused.
162   //
163   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
164   void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_);
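
  // A sketch of the intended sub-stream pattern (assuming `stream` is an
  // initialized Stream and the entrained work succeeds):
  //
  //   Stream *sub = stream.GetOrCreateSubStream();
  //   // ... entrain independent work onto *sub ...
  //   stream.ThenWaitFor(sub);      // rejoin before dependent work
  //   stream.ReturnSubStream(sub);  // allow the sub-stream to be reused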
165 
166   // Allocate temporary memories. The stream will deallocate them when blocked
167   // or destroyed.
168   template <typename T>
169   port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
170   AllocateTemporaryArray(uint64_t element_count);
171 
172   // Entrains onto the stream of operations: a kernel launch with the given
173   // (variadic) parameters for the invocation. These arguments can be things
174   // like DeviceMemory or primitive types such as int. What arguments you may
175   // pass to a given kernel are noted as the template parameters to the
176   // TypedKernel type that the machocc compiler generates.
177   //
178   // Template parameters:
179   //  Params...   The type list of formal parameters that the typed kernel
180   //              expects, which is matched against Args...
181   //  Args...     The deduced type list for passed actual arguments
182   //
183   // Implementation: A compile-time compatibility check is performed that has
184   // some leniency versus an exact parameter pack match -- for example,
185   // `const DeviceMemory<T>` is considered "pack compatible" with a
186   // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
187   // perfect forwarding support without rvalue references. It also attempts to
188   // spit out helpful static_assert error traces with information as to the
189   // argument number and types that were mismatched.
190   template <typename... Params, typename... Args>
191   port::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
192                           const TypedKernel<Params...> &kernel, Args... args);
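
  // For example (a sketch; the kernel object, its TypedKernel signature, and
  // the device buffer are hypothetical and built elsewhere):
  //
  //   TypedKernel<DeviceMemory<float>, int> add_one = ...;
  //   DeviceMemory<float> data = ...;
  //   TF_RETURN_IF_ERROR(stream.ThenLaunch(ThreadDim(32), BlockDim(16),
  //                                        add_one, data, /*count=*/1024));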
193 
194   // Record a "start" event for the interval timer at this point in the
195   // stream's execution (relative to the previously and subsequently enqueued
196   // items in the stream's execution). Streams may be started/stopped multiple
197   // times.
198   Stream &ThenStartTimer(Timer *t);
199 
200   // Record a "stop" event for the interval timer at this point in the
201   // stream's execution. See also Stream::ThenStartTimer.
202   Stream &ThenStopTimer(Timer *t);
203 
204   // TODO(leary) If work is added to the stream that is being depended upon,
205   //              then what? Have to describe what happens.
206   template <typename... Params>
207   Stream &ThenWaitFor(Stream *other, Params... more_streams) {
208     return ThenWaitFor(more_streams...).ThenWaitFor(other);
209   }
210 
211   // Create a dependency for this stream's next work on the other stream
212   // completing. Does not take ownership of other, and other must not be
213   // null.
214   //
215   // Checks that a stream does not wait for itself, and it is up to the
216   // user to guarantee that a stream does not come to wait on itself in a
217   // cyclic manner; in that case, behavior is undefined.
218   //
219   // N.B. Base recursion case for the variadic ThenWaitFor.
220   Stream &ThenWaitFor(Stream *other);
221 
222   // Waits for all stream values in others.
223   // Checks that there is no shallow circular wait (i.e. that "this" is not in
224   // others)
225   template <typename P>
226   Stream &ThenWaitFor(P others) {
227     for (auto &stream : *others) {
228       CHECK_NE(stream.get(), this);
229       ThenWaitFor(stream.get());
230     }
231     return *this;
232   }
233 
234   // Waits for an event object to be set.
235   // Note that ThenRecordEvent must have been called on the event before
236   // you call this function; otherwise the event will be considered complete
237   // and this wait will do nothing.
238   Stream &ThenWaitFor(Event *event);
239 
240   // Inserts the specified event into the end of this stream. Once the stream
241   // has processed all events prior to the insertion point, the event will be
242   // marked as completed.
243   // The stream does not take ownership of event - meaning that event's lifetime
244   // must extend past the point at which it is marked complete!
245   Stream &ThenRecordEvent(Event *event);
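
  // A sketch of cross-stream synchronization through an Event (assuming both
  // streams and the event were created against the same StreamExecutor):
  //
  //   Event event(executor);
  //   event.Init();
  //   producer.ThenRecordEvent(&event);  // completes after producer's prior work
  //   consumer.ThenWaitFor(&event);      // consumer's later work waits on it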
246 
247   ////////////////
248   // DNN support
249   //
250   // See DnnSupport::* for comments on the following methods.
251 
252   Stream &ThenBatchNormalizationForward(
253       const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
254       const DeviceMemory<float> &offset,
255       const DeviceMemory<float> &estimated_mean,
256       const DeviceMemory<float> &estimated_variance,
257       const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
258       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
259       const double exponential_average_factor,
260       dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
261       DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
262       DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
263       bool is_training, ScratchAllocator *reserve_space_allocator,
264       ScratchAllocator *workspace_allocator);
265 
266   Stream &ThenBatchNormalizationBackward(
267       const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
268       const DeviceMemory<float> &scale, const DeviceMemory<float> &offset,
269       const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
270       const DeviceMemory<float> &y, const dnn::BatchDescriptor &x_desc,
271       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
272       dnn::ActivationMode activation_mode, DeviceMemory<float> *x_backprop,
273       DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
274       DeviceMemory<float> *side_input_backprop,
275       DeviceMemory<uint8> *reserve_space_data,
276       ScratchAllocator *workspace_allocator);
277 
278   Stream &ThenBatchNormalizationForward(
279       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
280       const DeviceMemory<float> &offset,
281       const DeviceMemory<float> &estimated_mean,
282       const DeviceMemory<float> &estimated_variance,
283       const DeviceMemory<Eigen::half> &side_input,
284       const dnn::BatchDescriptor &x_desc,
285       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
286       const double exponential_average_factor,
287       dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
288       DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
289       DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
290       bool is_training, ScratchAllocator *reserve_space_allocator,
291       ScratchAllocator *workspace_allocator);
292 
293   Stream &ThenBatchNormalizationBackward(
294       const DeviceMemory<Eigen::half> &y_backprop,
295       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
296       const DeviceMemory<float> &offset, const DeviceMemory<float> &mean,
297       const DeviceMemory<float> &inv_var, const DeviceMemory<Eigen::half> &y,
298       const dnn::BatchDescriptor &x_desc,
299       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
300       dnn::ActivationMode activation_mode,
301       DeviceMemory<Eigen::half> *x_backprop,
302       DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
303       DeviceMemory<Eigen::half> *side_input_backprop,
304       DeviceMemory<uint8> *reserve_space_data,
305       ScratchAllocator *workspace_allocator);
306 
307   Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
308                        const DeviceMemory<float> &input_data,
309                        const dnn::FilterDescriptor &filter_descriptor,
310                        const DeviceMemory<float> &filter_data,
311                        const dnn::ConvolutionDescriptor &convolution_descriptor,
312                        const dnn::BatchDescriptor &output_descriptor,
313                        DeviceMemory<float> *output);
314 
315   Stream &ThenConvolveQuantized(
316       const dnn::BatchDescriptor &input_descriptor,
317       const DeviceMemory<float> &input_data,
318       const dnn::FilterDescriptor &filter_descriptor,
319       const DeviceMemory<int8> &filter_coefficients,
320       const DeviceMemory<float> &coefficient_scales,
321       const dnn::ConvolutionDescriptor &convolution_descriptor,
322       const dnn::BatchDescriptor &output_descriptor,
323       DeviceMemory<float> *output_data);
324 
325   Stream &ThenConvolveQuantized(
326       const dnn::BatchDescriptor &input_descriptor,
327       const DeviceMemory<float> &input_data,
328       const dnn::FilterDescriptor &filter_descriptor,
329       const DeviceMemory<int16> &filter_coefficients,
330       const DeviceMemory<float> &coefficient_scales,
331       const dnn::ConvolutionDescriptor &convolution_descriptor,
332       const dnn::BatchDescriptor &output_descriptor,
333       DeviceMemory<float> *output_data);
334 
335   template <typename InputType, typename OutputType>
336   port::Status ConvolveWithAlgorithm(
337       dnn::ConvolutionKind kind, const dnn::BatchDescriptor &input_descriptor,
338       DeviceMemory<InputType> input_data,
339       const dnn::FilterDescriptor &filter_descriptor,
340       DeviceMemory<InputType> filter_data,
341       const dnn::BatchDescriptor &output_descriptor,
342       DeviceMemory<OutputType> output_data,
343       const dnn::ConvolutionDescriptor &convolution_descriptor,
344       ScratchAllocator *scratch_allocator,
345       const dnn::AlgorithmConfig &algorithm_config,
346       dnn::ProfileResult *output_profile_result) {
347     DeviceMemory<uint8> scratch_memory;
348     dnn::AlgorithmDesc algorithm_desc;
349     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
350       TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
351           kind, this, input_descriptor, input_data, filter_descriptor,
352           filter_data, output_descriptor, output_data, convolution_descriptor,
353           algorithm_config, scratch_allocator, &algorithm_desc,
354           &scratch_memory));
355       return dnn->DoConvolve(kind, dnn::ToDataType<InputType>::value,
356                              dnn::ToDataType<OutputType>::value, this,
357                              input_descriptor, input_data, filter_descriptor,
358                              filter_data, output_descriptor, output_data,
359                              convolution_descriptor, algorithm_desc,
360                              scratch_memory, output_profile_result);
361     }
362     return port::UnimplementedError("DNN library is not found.");
363   }
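
  // Call sketch (descriptors, buffers, and the scratch allocator are set up
  // elsewhere; names here are illustrative only):
  //
  //   TF_RETURN_IF_ERROR(stream.ConvolveWithAlgorithm(
  //       dnn::ConvolutionKind::FORWARD, input_desc, input, filter_desc,
  //       filter, output_desc, output, conv_desc, &scratch_allocator,
  //       algorithm_config, /*output_profile_result=*/nullptr));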
364 
365   template <typename InputT, typename ScaleT, typename SideInputT,
366             typename BiasT, typename OutputT>
367   port::Status FusedConvolveWithAlgorithm(
368       const dnn::BatchDescriptor &conv_input_descriptor,
369       const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale,
370       const dnn::FilterDescriptor &filter_descriptor,
371       const DeviceMemory<InputT> &filter_data,
372       const dnn::ConvolutionDescriptor &convolution_descriptor,
373       const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale,
374       const dnn::BatchDescriptor &bias_descriptor,
375       const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode,
376       const dnn::BatchDescriptor &output_descriptor,
377       DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator,
378       const dnn::AlgorithmConfig &algorithm_config,
379       dnn::ProfileResult *output_profile_result) {
380     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
381       return dnn->DoFusedConvolve(
382           this, dnn::ToDataType<InputT>::value,
383           dnn::ToDataType<SideInputT>::value, dnn::ToDataType<BiasT>::value,
384           dnn::ToDataType<OutputT>::value, conv_input_descriptor,
385           conv_input_data, conv_input_scale, filter_descriptor, filter_data,
386           convolution_descriptor, side_input_data, side_input_scale,
387           bias_descriptor, biases, activation_mode, output_descriptor, *output,
388           scratch_allocator, algorithm_config, output_profile_result);
389     }
390     return port::UnimplementedError("DNN library is not found.");
391   }
392 
393   port::StatusOr<std::unique_ptr<const dnn::ConvRunner>> ConvolveRunnerFromDesc(
394       const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind,
395       dnn::DataType element_type, dnn::DataType output_type,
396       const dnn::BatchDescriptor &input_descriptor,
397       const dnn::FilterDescriptor &filter_descriptor,
398       const dnn::BatchDescriptor &output_descriptor,
399       const dnn::ConvolutionDescriptor &convolution_descriptor) {
400     dnn::DnnSupport *dnn_support = parent_->AsDnn();
401     if (!dnn_support) {
402       return port::UnimplementedError("DNN library is not found.");
403     }
404     return dnn_support->ConvolveRunnerFromDesc(
405         this, algorithm_desc, kind, element_type, output_type, input_descriptor,
406         filter_descriptor, output_descriptor, convolution_descriptor);
407   }
408 
409   port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
410   FusedConvolveRunnerFromDesc(
411       const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind,
412       dnn::DataType element_type, dnn::DataType bias_type,
413       dnn::DataType output_type, double conv_input_scale,
414       double side_input_scale, double leakyrelu_alpha,
415       const dnn::BatchDescriptor &input_descriptor,
416       const dnn::FilterDescriptor &filter_descriptor,
417       const dnn::BatchDescriptor &bias_descriptor,
418       const dnn::BatchDescriptor &output_descriptor,
419       const dnn::ConvolutionDescriptor &convolution_descriptor,
420       dnn::ActivationMode activation_mode) {
421     dnn::DnnSupport *dnn_support = parent_->AsDnn();
422     if (!dnn_support) {
423       return port::UnimplementedError("DNN library is not found.");
424     }
425     return dnn_support->FusedConvolveRunnerFromDesc(
426         this, algorithm_desc, kind, element_type, bias_type, output_type,
427         conv_input_scale, side_input_scale, leakyrelu_alpha, input_descriptor,
428         filter_descriptor, bias_descriptor, output_descriptor,
429         convolution_descriptor, activation_mode);
430   }
431 
432   Stream &ThenSeparableConvolve(
433       const dnn::BatchDescriptor &input_descriptor,
434       const DeviceMemory<float> &input_data,
435       const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
436       const DeviceMemory<float> &first_weights,
437       const DeviceMemory<float> &second_weights,
438       const dnn::ConvolutionDescriptor &convolution_descriptor,
439       const dnn::BatchDescriptor &output_descriptor,
440       DeviceMemory<float> *output);
441 
442   Stream &ThenMatMul(const DeviceMemory<float> &input_data,
443                      const DeviceMemory<float> &weights,
444                      const dnn::BatchDescriptor &input_dimensions,
445                      const dnn::BatchDescriptor &output_dimensions,
446                      DeviceMemory<float> *output_data);
447 
448   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
449                               const DeviceMemory<int8> &weights,
450                               const DeviceMemory<float> &weight_scales,
451                               const dnn::BatchDescriptor &input_dimensions,
452                               const dnn::BatchDescriptor &output_dimensions,
453                               DeviceMemory<float> *output_data);
454 
455   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
456                               const DeviceMemory<int16> &weights,
457                               const DeviceMemory<float> &weight_scales,
458                               const dnn::BatchDescriptor &input_dimensions,
459                               const dnn::BatchDescriptor &output_dimensions,
460                               DeviceMemory<float> *output_data);
461 
462   Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
463                       const DeviceMemory<float> &biases,
464                       const dnn::BatchDescriptor &dimensions,
465                       DeviceMemory<float> *output_data);
466 
467   template <typename ElementType>
468   port::Status ThenPoolForward(
469       const dnn::PoolingDescriptor &pooling_dimensions,
470       const dnn::BatchDescriptor &input_dimensions,
471       const DeviceMemory<ElementType> &input_data,
472       const dnn::BatchDescriptor &output_dimensions,
473       DeviceMemory<ElementType> *output_data,
474       ScratchAllocator *workspace_allocator = nullptr) {
475     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
476       return dnn->DoPoolForward(dnn::ToDataType<ElementType>::value, this,
477                                 pooling_dimensions, input_dimensions,
478                                 input_data, output_dimensions, *output_data,
479                                 workspace_allocator);
480     }
481     return port::UnimplementedError("DNN library is not found.");
482   }
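
  // Call sketch (pooling and batch descriptors plus device buffers are
  // prepared elsewhere; names are illustrative):
  //
  //   TF_RETURN_IF_ERROR(stream.ThenPoolForward(pooling_desc, input_desc,
  //                                             input, output_desc, &output));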
483 
484   template <typename ElementType>
485   port::Status ThenPoolBackward(
486       const dnn::PoolingDescriptor &pooling_dimensions,
487       const dnn::BatchDescriptor &input_dimensions,
488       const DeviceMemory<ElementType> &input_data,
489       const dnn::BatchDescriptor &output_dimensions,
490       const DeviceMemory<ElementType> &output_data,
491       const DeviceMemory<ElementType> &input_diff_data,
492       DeviceMemory<ElementType> *output_diff_data,
493       ScratchAllocator *workspace_allocator = nullptr) {
494     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
495       return dnn->DoPoolBackward(
496           dnn::ToDataType<ElementType>::value, this, pooling_dimensions,
497           input_dimensions, input_data, output_dimensions, output_data,
498           input_diff_data, *output_diff_data, workspace_allocator);
499     }
500     return port::UnimplementedError("DNN library is not found.");
501   }
502 
503   Stream &ThenNormalizeWithDimensions(
504       const dnn::NormalizeDescriptor &normalize_descriptor,
505       const dnn::BatchDescriptor &dimensions,
506       const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
507 
508   Stream &ThenNormalizeBackwardWithDimensions(
509       const dnn::NormalizeDescriptor &normalize_descriptor,
510       const dnn::BatchDescriptor &dimensions,
511       const DeviceMemory<float> &raw_data,
512       const DeviceMemory<float> &normalized_data,
513       const DeviceMemory<float> &normalized_variable_gradient,
514       DeviceMemory<float> *raw_variable_gradient,
515       ScratchAllocator *workspace_allocator = nullptr);
516 
517   Stream &ThenActivate(dnn::ActivationMode activation_mode,
518                        const dnn::BatchDescriptor &dimensions,
519                        const DeviceMemory<float> &input_data,
520                        DeviceMemory<float> *output_data);
521 
522   // Same as ThenActivate, but also takes an options argument that can be used
523   // for platform-specific option flags.
524   Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
525                                   const dnn::BatchDescriptor &dimensions,
526                                   const DeviceMemory<float> &input_data,
527                                   DeviceMemory<float> *output_data,
528                                   uint64_t options);
529 
530   Stream &ThenDepthConcatenate(
531       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
532       port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
533       DeviceMemory<float> *output_data);
534 
535   Stream &ThenSpaceConcatenate(
536       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
537       port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
538       DeviceMemory<float> *output_data,
539       dnn::SpaceConcatenateMode concat_direction);
540 
541   // Change the layout of the data by shrinking one dimension (or set of
542   // dimensions) and growing another dimension (or set of dimensions), while
543   // keeping the total number of data elements constant, and maintaining the
544   // current data ordering.
545   Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
546                       const DeviceMemory<float> &input_data,
547                       const dnn::BatchDescriptor &output_dimensions,
548                       DeviceMemory<float> *output_data);
549 
550   // Depth to space takes an X by Y image with depth D*M² and changes it to an
551   // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
552   // the input image is changed to an MxM contiguous area in the output image,
553   // with the values being laid out in raster order specified by
554   // DepthToSpaceLayout, and will have a new depth of D.
555   // See the DoDepthToSpace comment for more information.
556   Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
557                            const DeviceMemory<float> &input_data,
558                            const dnn::DepthToSpaceLayout &depth_to_space_layout,
559                            const int sqrt_depth_reduction,
560                            DeviceMemory<float> *output_data);
561 
562   // Space to depth is the inverse of depth to space. Space to depth takes each
563   // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
564   // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
565   // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
566   // data elements is not changed.
567   Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
568                            const DeviceMemory<float> &input_data,
569                            const dnn::DepthToSpaceLayout &space_to_depth_layout,
570                            const int sqrt_depth_increase,
571                            DeviceMemory<float> *output_data);
572 
573   Stream &ThenElementwiseOperate(
574       dnn::ElementwiseOperation operation,
575       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
576       port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
577       const dnn::BatchDescriptor &output_dimensions,
578       DeviceMemory<float> *output_data);
579 
580   Stream &ThenElementwiseOperateScaledQuantized(
581       dnn::ElementwiseOperation operation,
582       port::ArraySlice<int> input_multiplicands,  // non-absl ok
583       int output_divisor,
584       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
585       port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
586       const dnn::BatchDescriptor &output_dimensions,
587       DeviceMemory<float> *output_data);
588 
589   Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
590                     const DeviceMemory<float> &input_data, int64_t left_pad,
591                     int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
592                     DeviceMemory<float> *output_data);
593 
594   Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
595                       const DeviceMemory<float> &input_data, int64_t left_trim,
596                       int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
597                       DeviceMemory<float> *output_data);
598 
599   // Grows the input tensor by replicating the X and Y dimensions. The batch and
600   // depth/feature_map dimensions are unchanged. Currently, the input tensor is
601   // limited to X=1 and Y=1.
602   Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
603                           const DeviceMemory<float> &input_data,
604                           int64_t replicate_x, int64_t replicate_y,
605                           DeviceMemory<float> *output_data);
606 
607   // See DnnSupport::DoMemcpyD2HQuantized.
608   Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
609                                  dnn::QuantizedActivationMode mode,
610                                  void *host_dst, uint64_t size);
611 
612   // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
613   // and uses the Quantization trait to call the generic version of
614   // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
615   template <typename ElementType>
616   Stream &ThenMemcpyD2HQuantized(
617       const DeviceMemory<float> &gpu_unquantized_src,
618       port::MutableArraySlice<ElementType> host_dst) {
619     return ThenMemcpyD2HQuantized(
620         gpu_unquantized_src, Quantization<ElementType>::kModeId,
621         host_dst.data(), host_dst.size() * sizeof(ElementType));
622   }
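
  // Usage sketch, assuming a Quantization<uint8> specialization is defined
  // for 8-bit activations (names are illustrative):
  //
  //   std::vector<uint8> host(num_elements);
  //   stream.ThenMemcpyD2HQuantized(gpu_activations,
  //                                 port::MutableArraySlice<uint8>(host));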
623 
624   // See DnnSupport::DoMemcpyH2DQuantized.
625   Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64_t size,
626                                  dnn::QuantizedActivationMode mode,
627                                  DeviceMemory<float> *gpu_unquantized_dst);
628 
629   // Template version of ThenMemcpyH2DQuantized that takes an array slice
630   // and uses the Quantization trait to call the generic version of
631   // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
632   template <typename ElementType>
633   Stream &ThenMemcpyH2DQuantized(
634       port::ArraySlice<ElementType> host_src,  // non-absl ok
635       DeviceMemory<float> *gpu_unquantized_dst) {
636     return ThenMemcpyH2DQuantized(
637         host_src.data(), host_src.size() * sizeof(ElementType),
638         Quantization<ElementType>::kModeId, gpu_unquantized_dst);
639   }
640 
641   // See DnnSupport::DoCopyHostBuffer2Device.
642   Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
643                                     DeviceMemory<float> *gpu_unquantized_dst);
644 
645   // See DnnSupport::DoCopyDevice2HostBuffer.
646   Stream &ThenCopyDevice2HostBuffer(
647       const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
648 
649   /////////////////
650   // BLAS support
651 
652   // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
653   // present in DeviceMemory, it must be an execution-time constant (i.e. a
654   // value
655   // that the stream does not change or populate during the course of
656   // execution). The value is effectively captured at stream-enqueue time.
657   Stream &ThenBlasAxpy(uint64_t elem_count, float alpha,
658                        const DeviceMemory<float> &x, int incx,
659                        DeviceMemory<float> *y, int incy);
660   Stream &ThenBlasAxpy(uint64_t elem_count, double alpha,
661                        const DeviceMemory<double> &x, int incx,
662                        DeviceMemory<double> *y, int incy);
663   Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<float> alpha,
664                        const DeviceMemory<std::complex<float>> &x, int incx,
665                        DeviceMemory<std::complex<float>> *y, int incy);
666   Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<double> alpha,
667                        const DeviceMemory<std::complex<double>> &x, int incx,
668                        DeviceMemory<std::complex<double>> *y, int incy);
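
  // For instance (sketch): computes y = alpha * x + y over elem_count floats
  // already resident on the device.
  //
  //   stream.ThenBlasAxpy(elem_count, /*alpha=*/2.0f, x, /*incx=*/1,
  //                       &y, /*incy=*/1);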
669 
670   // See BlasSupport::DoBlasCopy.
671   Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<float> &x,
672                        int incx, DeviceMemory<float> *y, int incy);
673   Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<double> &x,
674                        int incx, DeviceMemory<double> *y, int incy);
675   Stream &ThenBlasCopy(uint64_t elem_count,
676                        const DeviceMemory<std::complex<float>> &x, int incx,
677                        DeviceMemory<std::complex<float>> *y, int incy);
678   Stream &ThenBlasCopy(uint64_t elem_count,
679                        const DeviceMemory<std::complex<double>> &x, int incx,
680                        DeviceMemory<std::complex<double>> *y, int incy);
681 
682   // See BlasSupport::DoBlasScal.
683   Stream &ThenBlasScal(uint64_t elem_count, float alpha, DeviceMemory<float> *x,
684                        int incx);
685   Stream &ThenBlasScal(uint64_t elem_count, double alpha,
686                        DeviceMemory<double> *x, int incx);
687   Stream &ThenBlasScal(uint64_t elem_count, float alpha,
688                        DeviceMemory<std::complex<float>> *x, int incx);
689   Stream &ThenBlasScal(uint64_t elem_count, double alpha,
690                        DeviceMemory<std::complex<double>> *x, int incx);
691   Stream &ThenBlasScal(uint64_t elem_count, std::complex<float> alpha,
692                        DeviceMemory<std::complex<float>> *x, int incx);
693   Stream &ThenBlasScal(uint64_t elem_count, std::complex<double> alpha,
694                        DeviceMemory<std::complex<double>> *x, int incx);
695 
696   // See BlasSupport::DoBlasGemv.
697   Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, float alpha,
698                        const DeviceMemory<float> &a, int lda,
699                        const DeviceMemory<float> &x, int incx, float beta,
700                        DeviceMemory<float> *y, int incy);
701   Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
702                        double alpha, const DeviceMemory<double> &a, int lda,
703                        const DeviceMemory<double> &x, int incx, double beta,
704                        DeviceMemory<double> *y, int incy);
705   Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
706                        std::complex<float> alpha,
707                        const DeviceMemory<std::complex<float>> &a, int lda,
708                        const DeviceMemory<std::complex<float>> &x, int incx,
709                        std::complex<float> beta,
710                        DeviceMemory<std::complex<float>> *y, int incy);
711   Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
712                        std::complex<double> alpha,
713                        const DeviceMemory<std::complex<double>> &a, int lda,
714                        const DeviceMemory<std::complex<double>> &x, int incx,
715                        std::complex<double> beta,
716                        DeviceMemory<std::complex<double>> *y, int incy);
717 
718   Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n,
719                                     float alpha, const DeviceMemory<float> &a,
720                                     int lda, const DeviceMemory<float> &x,
721                                     int incx, float beta,
722                                     DeviceMemory<float> *y, int incy,
723                                     blas::ProfileResult *output_profile_result);
724   Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n,
725                                     double alpha, const DeviceMemory<double> &a,
726                                     int lda, const DeviceMemory<double> &x,
727                                     int incx, double beta,
728                                     DeviceMemory<double> *y, int incy,
729                                     blas::ProfileResult *output_profile_result);
730   Stream &ThenBlasGemvWithProfiling(
731       blas::Transpose trans, uint64_t m, uint64 n, std::complex<float> alpha,
732       const DeviceMemory<std::complex<float>> &a, int lda,
733       const DeviceMemory<std::complex<float>> &x, int incx,
734       std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
735       blas::ProfileResult *output_profile_result);
736   Stream &ThenBlasGemvWithProfiling(
737       blas::Transpose trans, uint64_t m, uint64 n, std::complex<double> alpha,
738       const DeviceMemory<std::complex<double>> &a, int lda,
739       const DeviceMemory<std::complex<double>> &x, int incx,
740       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
741       int incy, blas::ProfileResult *output_profile_result);
742 
743   // See BlasSupport::DoBlasSbmv.
744   Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, float alpha,
745                        const DeviceMemory<float> &a, int lda,
746                        const DeviceMemory<float> &x, int incx, float beta,
747                        DeviceMemory<float> *y, int incy);
748   Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k,
749                        double alpha, const DeviceMemory<double> &a, int lda,
750                        const DeviceMemory<double> &x, int incx, double beta,
751                        DeviceMemory<double> *y, int incy);
752 
753   template <typename InputType>
754   port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
755                             uint64_t m, uint64 n, uint64 k,
756                             const DeviceMemory<InputType> &a, int lda,
757                             const DeviceMemory<InputType> &b, int ldb,
758                             DeviceMemory<InputType> *c, int ldc,
759                             blas::ComputePrecision precision) {
760     InputType alpha{1.0};
761     InputType beta{0.0};
762     return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
763                         ldc, precision);
764   }
765 
766   // TODO(parkers): Update all callers to pass kDefaultComputePrecision.
767   template <typename InputType>
768   port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
769                             uint64_t m, uint64 n, uint64 k,
770                             const DeviceMemory<InputType> &a, int lda,
771                             const DeviceMemory<InputType> &b, int ldb,
772                             DeviceMemory<InputType> *c, int ldc) {
773     return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc,
774                         blas::kDefaultComputePrecision);
775   }
776 
777   template <typename InputType, typename ConstantType>
778   port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
779                             uint64_t m, uint64 n, uint64 k, ConstantType alpha,
780                             const DeviceMemory<InputType> &a, int lda,
781                             const DeviceMemory<InputType> &b, int ldb,
782                             ConstantType beta, DeviceMemory<InputType> *c,
783                             int ldc, blas::ComputePrecision precision) {
784     static_assert(
785         detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, float,
786                           double, std::complex<float>, std::complex<double>>(),
787         "Input can be half, bf16, float, double, std::complex<float> or "
788         "std::complex<double>");
789     static_assert(!std::is_same_v<InputType, Eigen::half> ||
790                       detail::is_any_of<ConstantType, float, Eigen::half>(),
791                   "If input is Eigen::half, constant has to be either "
792                   "Eigen::half or float");
793     static_assert(
794         detail::is_any_of<InputType, Eigen::half, ConstantType>(),
795         "If input is not Eigen::half, constant and input types have to match");
796     blas::BlasSupport *blas = parent()->AsBlas();
797     if (!blas) {
798       return port::InternalError(
799           "Attempting to perform BLAS operation using "
800           "StreamExecutor without BLAS support");
801     }
802 
803     void *alpha_ptr = &alpha;
804     void *beta_ptr = &beta;
805     float alpha_storage, beta_storage;
806     UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
807                                     &beta_storage);
808 
809     return blas->DoBlasGemm(this, transa, transb, m, n, k,
810                             blas::ToDataType<InputType>::value, alpha_ptr, a,
811                             lda, b, ldb, beta_ptr, c, ldc, precision);
812   }
813 
814   // TODO(parkers): Update all callers to pass kDefaultComputePrecision.
815   template <typename InputType, typename ConstantType>
816   port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
817                             uint64_t m, uint64 n, uint64 k, ConstantType alpha,
818                             const DeviceMemory<InputType> &a, int lda,
819                             const DeviceMemory<InputType> &b, int ldb,
820                             ConstantType beta, DeviceMemory<InputType> *c,
821                             int ldc) {
822     return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
823                         ldc, blas::kDefaultComputePrecision);
824   }
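
  // A usage sketch for the mixed-precision case allowed by the static_asserts
  // above (Eigen::half matrices with float scaling constants; buffers and
  // dimensions are set up elsewhere):
  //
  //   TF_RETURN_IF_ERROR(stream.ThenBlasGemm(
  //       blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose,
  //       m, n, k, /*alpha=*/1.0f, a, /*lda=*/m, b, /*ldb=*/k,
  //       /*beta=*/0.0f, &c, /*ldc=*/m, blas::kDefaultComputePrecision));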
825 
826   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
827                                     blas::Transpose transb, uint64_t m,
828                                     uint64 n, uint64_t k, float alpha,
829                                     const DeviceMemory<Eigen::half> &a, int lda,
830                                     const DeviceMemory<Eigen::half> &b, int ldb,
831                                     float beta, DeviceMemory<Eigen::half> *c,
832                                     int ldc,
833                                     blas::ProfileResult *output_profile_result);
834   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
835                                     blas::Transpose transb, uint64_t m,
836                                     uint64 n, uint64_t k, float alpha,
837                                     const DeviceMemory<float> &a, int lda,
838                                     const DeviceMemory<float> &b, int ldb,
839                                     float beta, DeviceMemory<float> *c, int ldc,
840                                     blas::ProfileResult *output_profile_result);
841   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
842                                     blas::Transpose transb, uint64_t m,
843                                     uint64 n, uint64_t k, double alpha,
844                                     const DeviceMemory<double> &a, int lda,
845                                     const DeviceMemory<double> &b, int ldb,
846                                     double beta, DeviceMemory<double> *c,
847                                     int ldc,
848                                     blas::ProfileResult *output_profile_result);
849   Stream &ThenBlasGemmWithProfiling(
850       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
851       uint64_t k, std::complex<float> alpha,
852       const DeviceMemory<std::complex<float>> &a, int lda,
853       const DeviceMemory<std::complex<float>> &b, int ldb,
854       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
855       blas::ProfileResult *output_profile_result);
856   Stream &ThenBlasGemmWithProfiling(
857       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
858       uint64_t k, std::complex<double> alpha,
859       const DeviceMemory<std::complex<double>> &a, int lda,
860       const DeviceMemory<std::complex<double>> &b, int ldb,
861       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
862       blas::ProfileResult *output_profile_result);
863 
864   template <typename InputType, typename OutputType>
865   port::Status ThenBlasGemmWithAlgorithm(
866       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
867       uint64_t k, const DeviceMemory<InputType> &a, int lda,
868       const DeviceMemory<InputType> &b, int ldb, DeviceMemory<OutputType> *c,
869       int ldc, blas::ComputationType computation_type,
870       blas::AlgorithmType algorithm,
871       blas::ProfileResult *output_profile_result) {
872     OutputType alpha{1};
873     OutputType beta{0};
874     return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b,
875                                      ldb, beta, c, ldc, computation_type,
876                                      algorithm, output_profile_result);
877   }
878 
879   template <typename InputType, typename OutputType, typename ConstantType>
880   port::Status ThenBlasGemmWithAlgorithm(
881       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
882       uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
883       const DeviceMemory<InputType> &b, int ldb, ConstantType beta,
884       DeviceMemory<OutputType> *c, int ldc,
885       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
886       blas::ProfileResult *output_profile_result) {
887     TF_RETURN_IF_ERROR(
888         CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
889             computation_type));
890 
891     blas::BlasSupport *blas = parent()->AsBlas();
892     if (!blas) {
893       return port::InternalError(
894           "Attempting to perform BLAS operation using "
895           "StreamExecutor without BLAS support");
896     }
897 
898     void *alpha_ptr = &alpha;
899     void *beta_ptr = &beta;
900     float alpha_storage, beta_storage;
901     UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
902                                     &beta_storage);
903 
904     port::Status st = blas->DoBlasGemmWithAlgorithm(
905         this, transa, transb, m, n, k, alpha_ptr, a,
906         blas::ToDataType<InputType>::value, lda, b,
907         blas::ToDataType<InputType>::value, ldb, beta_ptr, c,
908         blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm,
909         output_profile_result);
910     if (output_profile_result) {
911       // The error is recorded in the profile.
912       return ::tensorflow::OkStatus();
913     }
914     return st;
915   }
916 
917   template <typename InputType, typename OutputType, typename ConstantType>
918   port::Status ThenBlasGemmStridedBatchedWithAlgorithm(
919       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
920       uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
921       int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
922       int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
923       int64_t stride_c, int batch_count, blas::ComputationType computation_type,
924       blas::AlgorithmType algorithm,
925       blas::ProfileResult *output_profile_result) {
926     TF_RETURN_IF_ERROR(
927         CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
928             computation_type));
929 
930     blas::BlasSupport *blas = parent()->AsBlas();
931     if (!blas) {
932       return port::InternalError(
933           "Attempting to perform BLAS operation using "
934           "StreamExecutor without BLAS support");
935     }
936     void *alpha_ptr = &alpha;
937     void *beta_ptr = &beta;
938     float alpha_storage, beta_storage;
939     UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
940                                     &beta_storage);
941     port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm(
942         this, transa, transb, m, n, k, alpha_ptr, a,
943         blas::ToDataType<InputType>::value, lda, stride_a, b,
944         blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c,
945         blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
946         computation_type, algorithm, output_profile_result);
947     if (output_profile_result) {
948       // The error is recorded in the profile.
949       return ::tensorflow::OkStatus();
950     }
951     return st;
952   }
953 
954   template <typename T>
955   using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>;  // non-absl ok
956 
957   // See BlasSupport::DoBlasGemmBatched.
958   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
959                               uint64_t m, uint64 n, uint64_t k, float alpha,
960                               const DeviceMemorySlice<Eigen::half> &a, int lda,
961                               const DeviceMemorySlice<Eigen::half> &b, int ldb,
962                               float beta,
963                               const DeviceMemorySlice<Eigen::half> &c, int ldc,
964                               int batch_count);
965   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
966                               uint64_t m, uint64 n, uint64 k, float alpha,
967                               const DeviceMemorySlice<float> &a, int lda,
968                               const DeviceMemorySlice<float> &b, int ldb,
969                               float beta, const DeviceMemorySlice<float> &c,
970                               int ldc, int batch_count);
971   Stream &ThenBlasGemmBatched(
972       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
973       uint64 k, double alpha,
974       const port::ArraySlice<DeviceMemory<double> *> &a,  // non-absl ok
975       int lda,
976       const port::ArraySlice<DeviceMemory<double> *> &b,  // non-absl ok
977       int ldb, double beta,
978       const port::ArraySlice<DeviceMemory<double> *> &c,  // non-absl ok
979       int ldc, int batch_count);
980   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
981                               uint64_t m, uint64 n, uint64_t k,
982                               std::complex<float> alpha,
983                               const DeviceMemorySlice<std::complex<float>> &a,
984                               int lda,
985                               const DeviceMemorySlice<std::complex<float>> &b,
986                               int ldb, std::complex<float> beta,
987                               const DeviceMemorySlice<std::complex<float>> &c,
988                               int ldc, int batch_count);
989   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
990                               uint64_t m, uint64 n, uint64_t k,
991                               std::complex<double> alpha,
992                               const DeviceMemorySlice<std::complex<double>> &a,
993                               int lda,
994                               const DeviceMemorySlice<std::complex<double>> &b,
995                               int ldb, std::complex<double> beta,
996                               const DeviceMemorySlice<std::complex<double>> &c,
997                               int ldc, int batch_count);
998   Stream &ThenBlasGemmBatchedWithScratch(
999       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1000       uint64_t k, float alpha, const DeviceMemorySlice<Eigen::half> &a, int lda,
1001       const DeviceMemorySlice<Eigen::half> &b, int ldb, float beta,
1002       const DeviceMemorySlice<Eigen::half> &c, int ldc, int batch_count,
1003       ScratchAllocator *scratch_allocator);
1004   Stream &ThenBlasGemmBatchedWithScratch(
1005       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1006       uint64_t k, float alpha, const DeviceMemorySlice<float> &a, int lda,
1007       const DeviceMemorySlice<float> &b, int ldb, float beta,
1008       const DeviceMemorySlice<float> &c, int ldc, int batch_count,
1009       ScratchAllocator *scratch_allocator);
1010   Stream &ThenBlasGemmBatchedWithScratch(
1011       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1012       uint64_t k, double alpha, const DeviceMemorySlice<double> &a, int lda,
1013       const DeviceMemorySlice<double> &b, int ldb, double beta,
1014       const DeviceMemorySlice<double> &c, int ldc, int batch_count,
1015       ScratchAllocator *scratch_allocator);
1016   Stream &ThenBlasGemmBatchedWithScratch(
1017       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1018       uint64_t k, std::complex<float> alpha,
1019       const DeviceMemorySlice<std::complex<float>> &a, int lda,
1020       const DeviceMemorySlice<std::complex<float>> &b, int ldb,
1021       std::complex<float> beta, const DeviceMemorySlice<std::complex<float>> &c,
1022       int ldc, int batch_count, ScratchAllocator *scratch_allocator);
1023   Stream &ThenBlasGemmBatchedWithScratch(
1024       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1025       uint64_t k, std::complex<double> alpha,
1026       const DeviceMemorySlice<std::complex<double>> &a, int lda,
1027       const DeviceMemorySlice<std::complex<double>> &b, int ldb,
1028       std::complex<double> beta,
1029       const DeviceMemorySlice<std::complex<double>> &c, int ldc,
1030       int batch_count, ScratchAllocator *scratch_allocator);
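
  // Illustrative usage sketch (added commentary, not upstream documentation):
  // the batched overloads above take slices of per-batch pointers. Assuming
  // hypothetical DeviceMemory<float> matrices a0/a1, b0/b1, c0/c1 and suitable
  // m, n, k and leading dimensions, a caller might write:
  //   std::vector<DeviceMemory<float> *> a = {&a0, &a1}, b = {&b0, &b1},
  //                                      c = {&c0, &c1};
  //   stream.ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
  //                              blas::Transpose::kNoTranspose, m, n, k,
  //                              /*alpha=*/1.0f, a, lda, b, ldb,
  //                              /*beta=*/0.0f, c, ldc, /*batch_count=*/2);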
1031 
1032   template <typename InputType, typename ConstantType>
1033   port::Status ThenBlasGemmStridedBatched(
1034       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1035       uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
1036       int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
1037       int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
1038       int64_t stride_c, int batch_count) {
1039     static_assert(
1040         detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16,
1041                           double, std::complex<float>, std::complex<double>>(),
1042         "Unsupported input type");
1043     static_assert(
1044         std::is_same_v<ConstantType, InputType> ||
1045             (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() &&
1046              std::is_same_v<ConstantType, float>),
1047         "Mismatched input and alpha/beta types");
1048     blas::BlasSupport *blas = parent()->AsBlas();
1049     if (!blas) {
1050       return port::InternalError(
1051           "Attempting to perform BLAS operation using "
1052           "StreamExecutor without BLAS support");
1053     }
1054 
1055     void *alpha_ptr = &alpha;
1056     void *beta_ptr = &beta;
1057     float alpha_storage, beta_storage;
1058     UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
1059                                     &beta_storage);
1060 
1061     return blas->DoBlasGemmStridedBatched(
1062         this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
1063         alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
1064         stride_c, batch_count);
1065   }
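
  // Illustrative usage sketch (added commentary; `stream`, `d_a`, `d_b` and
  // `d_c` are hypothetical): with DeviceMemory<float> buffers holding
  // batch_count 64x64 matrices packed back to back, a caller might write:
  //   TF_RETURN_IF_ERROR(stream.ThenBlasGemmStridedBatched(
  //       blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose,
  //       /*m=*/64, /*n=*/64, /*k=*/64, /*alpha=*/1.0f, d_a, /*lda=*/64,
  //       /*stride_a=*/64 * 64, d_b, /*ldb=*/64, /*stride_b=*/64 * 64,
  //       /*beta=*/0.0f, &d_c, /*ldc=*/64, /*stride_c=*/64 * 64,
  //       /*batch_count=*/8));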
1066 
1067   // See BlasSupport::DoBlasTrsm.
1068   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1069                        blas::Transpose transa, blas::Diagonal diag, uint64_t m,
1070                        uint64_t n, float alpha, const DeviceMemory<float> &a,
1071                        int lda, DeviceMemory<float> *b, int ldb);
1072   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1073                        blas::Transpose transa, blas::Diagonal diag, uint64_t m,
1074                        uint64_t n, double alpha, const DeviceMemory<double> &a,
1075                        int lda, DeviceMemory<double> *b, int ldb);
1076   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1077                        blas::Transpose transa, blas::Diagonal diag, uint64_t m,
1078                        uint64_t n, std::complex<float> alpha,
1079                        const DeviceMemory<std::complex<float>> &a, int lda,
1080                        DeviceMemory<std::complex<float>> *b, int ldb);
1081   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1082                        blas::Transpose transa, blas::Diagonal diag, uint64_t m,
1083                        uint64_t n, std::complex<double> alpha,
1084                        const DeviceMemory<std::complex<double>> &a, int lda,
1085                        DeviceMemory<std::complex<double>> *b, int ldb);
1086 
1087   // See BlasSupport::DoBlasTrsmBatched.
1088   Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1089                               blas::Transpose transa, blas::Diagonal diag,
1090                               uint64_t m, uint64 n, float alpha,
1091                               const DeviceMemory<float *> &as, int lda,
1092                               DeviceMemory<float *> *bs, int ldb,
1093                               int batch_count);
1094   Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1095                               blas::Transpose transa, blas::Diagonal diag,
1096                               uint64_t m, uint64 n, double alpha,
1097                               const DeviceMemory<double *> &as, int lda,
1098                               DeviceMemory<double *> *bs, int ldb,
1099                               int batch_count);
1100   Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1101                               blas::Transpose transa, blas::Diagonal diag,
1102                               uint64_t m, uint64 n, std::complex<float> alpha,
1103                               const DeviceMemory<std::complex<float> *> &as,
1104                               int lda, DeviceMemory<std::complex<float> *> *bs,
1105                               int ldb, int batch_count);
1106   Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1107                               blas::Transpose transa, blas::Diagonal diag,
1108                               uint64_t m, uint64 n, std::complex<double> alpha,
1109                               const DeviceMemory<std::complex<double> *> &as,
1110                               int lda, DeviceMemory<std::complex<double> *> *bs,
1111                               int ldb, int batch_count);
1112 
1113   // See FftSupport::DoFft.
1114   Stream &ThenFft(fft::Plan *plan,
1115                   const DeviceMemory<std::complex<float>> &input,
1116                   DeviceMemory<std::complex<float>> *output);
1117   Stream &ThenFft(fft::Plan *plan,
1118                   const DeviceMemory<std::complex<double>> &input,
1119                   DeviceMemory<std::complex<double>> *output);
1120   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
1121                   DeviceMemory<std::complex<float>> *output);
1122   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
1123                   DeviceMemory<std::complex<double>> *output);
1124   Stream &ThenFft(fft::Plan *plan,
1125                   const DeviceMemory<std::complex<float>> &input,
1126                   DeviceMemory<float> *output);
1127   Stream &ThenFft(fft::Plan *plan,
1128                   const DeviceMemory<std::complex<double>> &input,
1129                   DeviceMemory<double> *output);
1130 
1131   // Makes the RNG use the provided value as the basis for further generation.
1132   // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
1133   // sources of seed data if the default (high quality) sources are not
1134   // desired.
1135   // For most use cases, this function will not be necessary; each provided
1136   // back-end implementation will be appropriately seeded by default.
1137   // At a minimum 16 bytes of data are required in the seed buffer.
1138   //
1139   // To seed with good (non-reproducible) data:
1140   //   File* f = File::Open("/dev/random", "r");
1141   //   int64_t bytes_read = f->Read(seed_data, bytes_to_read);
1142   //   < error checking >
1143   //   stream.ThenSetRngSeed(seed_data, bytes_read);
1144   //
1145   // To seed with reproducible data:
1146   //   uint64_t seed_data[2] = { <data> };
1147   //   stream.ThenSetRngSeed(seed_data, 16);
1148   Stream &ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes);
1149 
1150   // Populates the memory indicated by values with uniform-random-distribution
1151   // values. TODO(leary) seeding API/description
1152   //
1153   // Uses the type and size of the DeviceMemory to infer what data should be
1154   // populated.
1155   Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
1156   Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
1157   Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
1158   Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
1159   Stream &ThenPopulateRandGaussian(float mean, float stddev,
1160                                    DeviceMemory<float> *values);
1161   Stream &ThenPopulateRandGaussian(double mean, double stddev,
1162                                    DeviceMemory<double> *values);
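
  // Illustrative usage sketch (added commentary; `values` is a hypothetical
  // DeviceMemory<float> buffer owned by this stream's parent executor):
  //   stream.ThenPopulateRandGaussian(/*mean=*/0.0f, /*stddev=*/1.0f, &values);
  //   stream.ThenPopulateRandUniform(&values);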
1163 
1164   // Entrain onto the stream: a memcpy to a host destination from a GPU source
1165   // of the given target size. host_dst must be a pointer to host memory
1166   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
1167   // then registered with StreamExecutor::HostMemoryRegister.
1168   Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
1169                      uint64_t size);
1170 
1171   // Entrain onto the stream: a memcpy to a GPU destination from a host source
1172   // of the given target size. host_src must be a pointer to host memory
1173   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
1174   // then registered with StreamExecutor::HostMemoryRegister.
1175   Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
1176                      uint64_t size);
1177 
1178   // Alternative interface for memcpying from device to host that takes an
1179   // array slice. Checks that the destination size can accommodate the host
1180   // slice size.
1181   template <typename T>
1182   Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1183                         port::MutableArraySlice<T> host_dst) {
1184     auto host_size = host_dst.size() * sizeof(T);
1185     CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1186     return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1187   }
1188 
1189   // Alternative interface for memcpying from host to device that takes an
1190   // array slice. Checks that the destination size can accommodate the host
1191   // slice size.
1192   template <typename T>
1193   Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,  // non-absl ok
1194                         DeviceMemory<T> *gpu_dst) {
1195     auto host_size = host_src.size() * sizeof(T);
1196     CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1197     return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1198   }
1199 
1200   // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
1201   // of the given target size. gpu_src/dst must be pointers to GPU memory and
1202   // peer access must be enabled between their owning StreamExecutors.
1203   Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
1204                      uint64_t size);
1205 
1206   // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
1207   // ensuring that the host pointer isn't getting confused accidentally with a
1208   // device pointer if you're not doing metaprogramming against the API.
1209   Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
1210                         const DeviceMemoryBase &gpu_src, uint64_t size) {
1211     return ThenMemcpy(gpu_dst, gpu_src, size);
1212   }
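
  // Illustrative round-trip sketch (added commentary; `executor`, `stream`,
  // `device_buf` and `n` are hypothetical):
  //   void *host_buf = executor->HostMemoryAllocate(n * sizeof(float));
  //   stream.ThenMemcpy(&device_buf, host_buf, n * sizeof(float));   // H2D
  //   stream.ThenMemcpy(host_buf, device_buf, n * sizeof(float));    // D2H
  //   TF_CHECK_OK(stream.BlockHostUntilDone());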
1213 
1214   // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
1215   // The location must not be null.
1216   Stream &ThenMemZero(DeviceMemoryBase *location, uint64_t size);
1217 
1218   // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
1219   // size bytes, where size must be evenly divisible by 4 (i.e. a whole number
1220   // of 32-bit words). The location must not be null.
1221   Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
1222                        uint64_t size);
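
  // Illustrative sketch (added commentary; `device_buf` is a hypothetical
  // DeviceMemoryBase whose size is a multiple of 4):
  //   stream.ThenMemZero(&device_buf, device_buf.size());
  //   stream.ThenMemset32(&device_buf, 0xDEADBEEF, device_buf.size());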
1223 
1224   // Enqueue a forward operation of the RNN model onto the stream.
1225   // See DnnSupport::DoRnnForward for more details.
1226   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1227                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1228                          const DeviceMemory<Eigen::half> &input_data,
1229                          const DeviceMemory<int> &seq_lengths_data,
1230                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1231                          const DeviceMemory<Eigen::half> &input_h_data,
1232                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1233                          const DeviceMemory<Eigen::half> &input_c_data,
1234                          const DeviceMemory<Eigen::half> &params,
1235                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1236                          DeviceMemory<Eigen::half> *output_data,
1237                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1238                          DeviceMemory<Eigen::half> *output_h_data,
1239                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1240                          DeviceMemory<Eigen::half> *output_c_data,
1241                          bool is_training,
1242                          ScratchAllocator *reserve_space_allocator,
1243                          ScratchAllocator *workspace_allocator,
1244                          dnn::ProfileResult *output_profile_result);
1245 
1246   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1247                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1248                          const DeviceMemory<float> &input_data,
1249                          const DeviceMemory<int> &seq_lengths_data,
1250                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1251                          const DeviceMemory<float> &input_h_data,
1252                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1253                          const DeviceMemory<float> &input_c_data,
1254                          const DeviceMemory<float> &params,
1255                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1256                          DeviceMemory<float> *output_data,
1257                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1258                          DeviceMemory<float> *output_h_data,
1259                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1260                          DeviceMemory<float> *output_c_data, bool is_training,
1261                          ScratchAllocator *reserve_space_allocator,
1262                          ScratchAllocator *workspace_allocator,
1263                          dnn::ProfileResult *output_profile_result);
1264 
1265   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1266                          const dnn::RnnSequenceTensorDescriptor &input_desc,
1267                          const DeviceMemory<double> &input_data,
1268                          const DeviceMemory<int> &seq_lengths_data,
1269                          const dnn::RnnStateTensorDescriptor &input_h_desc,
1270                          const DeviceMemory<double> &input_h_data,
1271                          const dnn::RnnStateTensorDescriptor &input_c_desc,
1272                          const DeviceMemory<double> &input_c_data,
1273                          const DeviceMemory<double> &params,
1274                          const dnn::RnnSequenceTensorDescriptor &output_desc,
1275                          DeviceMemory<double> *output_data,
1276                          const dnn::RnnStateTensorDescriptor &output_h_desc,
1277                          DeviceMemory<double> *output_h_data,
1278                          const dnn::RnnStateTensorDescriptor &output_c_desc,
1279                          DeviceMemory<double> *output_c_data, bool is_training,
1280                          ScratchAllocator *reserve_space_allocator,
1281                          ScratchAllocator *workspace_allocator,
1282                          dnn::ProfileResult *output_profile_result);
1283 
1284   // Enqueue a backward operation of the RNN model onto the stream.
1285   // See DnnSupport::DoRnnBackward for more details.
1286   Stream &ThenRnnBackward(
1287       const dnn::RnnDescriptor &rnn_desc,
1288       const dnn::RnnSequenceTensorDescriptor &input_desc,
1289       const DeviceMemory<Eigen::half> &input_data,
1290       const DeviceMemory<int> &seq_lengths_data,
1291       const dnn::RnnStateTensorDescriptor &input_h_desc,
1292       const DeviceMemory<Eigen::half> &input_h_data,
1293       const dnn::RnnStateTensorDescriptor &input_c_desc,
1294       const DeviceMemory<Eigen::half> &input_c_data,
1295       const DeviceMemory<Eigen::half> &params,
1296       const dnn::RnnSequenceTensorDescriptor &output_desc,
1297       const DeviceMemory<Eigen::half> &output_data,
1298       const dnn::RnnStateTensorDescriptor &output_h_desc,
1299       const DeviceMemory<Eigen::half> &output_h_data,
1300       const dnn::RnnStateTensorDescriptor &output_c_desc,
1301       const DeviceMemory<Eigen::half> &output_c_data,
1302       const DeviceMemory<Eigen::half> &output_backprop_data,
1303       const DeviceMemory<Eigen::half> &output_h_backprop_data,
1304       const DeviceMemory<Eigen::half> &output_c_backprop_data,
1305       DeviceMemory<Eigen::half> *input_backprop_data,
1306       DeviceMemory<Eigen::half> *input_h_backprop_data,
1307       DeviceMemory<Eigen::half> *input_c_backprop_data,
1308       DeviceMemory<Eigen::half> *params_backprop_data,
1309       DeviceMemory<uint8> *reserve_space_data,
1310       ScratchAllocator *workspace_allocator,
1311       dnn::ProfileResult *output_profile_result);
1312 
1313   Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
1314                           const dnn::RnnSequenceTensorDescriptor &input_desc,
1315                           const DeviceMemory<float> &input_data,
1316                           const DeviceMemory<int> &seq_lengths_data,
1317                           const dnn::RnnStateTensorDescriptor &input_h_desc,
1318                           const DeviceMemory<float> &input_h_data,
1319                           const dnn::RnnStateTensorDescriptor &input_c_desc,
1320                           const DeviceMemory<float> &input_c_data,
1321                           const DeviceMemory<float> &params,
1322                           const dnn::RnnSequenceTensorDescriptor &output_desc,
1323                           const DeviceMemory<float> &output_data,
1324                           const dnn::RnnStateTensorDescriptor &output_h_desc,
1325                           const DeviceMemory<float> &output_h_data,
1326                           const dnn::RnnStateTensorDescriptor &output_c_desc,
1327                           const DeviceMemory<float> &output_c_data,
1328                           const DeviceMemory<float> &output_backprop_data,
1329                           const DeviceMemory<float> &output_h_backprop_data,
1330                           const DeviceMemory<float> &output_c_backprop_data,
1331                           DeviceMemory<float> *input_backprop_data,
1332                           DeviceMemory<float> *input_h_backprop_data,
1333                           DeviceMemory<float> *input_c_backprop_data,
1334                           DeviceMemory<float> *params_backprop_data,
1335                           DeviceMemory<uint8> *reserve_space_data,
1336                           ScratchAllocator *workspace_allocator,
1337                           dnn::ProfileResult *output_profile_result);
1338 
1339   Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
1340                           const dnn::RnnSequenceTensorDescriptor &input_desc,
1341                           const DeviceMemory<double> &input_data,
1342                           const DeviceMemory<int> &seq_lengths_data,
1343                           const dnn::RnnStateTensorDescriptor &input_h_desc,
1344                           const DeviceMemory<double> &input_h_data,
1345                           const dnn::RnnStateTensorDescriptor &input_c_desc,
1346                           const DeviceMemory<double> &input_c_data,
1347                           const DeviceMemory<double> &params,
1348                           const dnn::RnnSequenceTensorDescriptor &output_desc,
1349                           const DeviceMemory<double> &output_data,
1350                           const dnn::RnnStateTensorDescriptor &output_h_desc,
1351                           const DeviceMemory<double> &output_h_data,
1352                           const dnn::RnnStateTensorDescriptor &output_c_desc,
1353                           const DeviceMemory<double> &output_c_data,
1354                           const DeviceMemory<double> &output_backprop_data,
1355                           const DeviceMemory<double> &output_h_backprop_data,
1356                           const DeviceMemory<double> &output_c_backprop_data,
1357                           DeviceMemory<double> *input_backprop_data,
1358                           DeviceMemory<double> *input_h_backprop_data,
1359                           DeviceMemory<double> *input_c_backprop_data,
1360                           DeviceMemory<double> *params_backprop_data,
1361                           DeviceMemory<uint8> *reserve_space_data,
1362                           ScratchAllocator *workspace_allocator,
1363                           dnn::ProfileResult *output_profile_result);
1364 
1365   // Enqueue a CTCLoss operation onto the stream.
1366   // See DnnSupport::DoCtcLoss for more details.
1367   Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
1368                       const DeviceMemory<float> &probs_data,
1369                       absl::Span<const int> labels_data,
1370                       absl::Span<const int> labels_lengths_data,
1371                       absl::Span<const int> input_lengths_data,
1372                       DeviceMemory<float> *costs_data,
1373                       const dnn::RnnStateTensorDescriptor &grads_desc,
1374                       DeviceMemory<float> *grads_data,
1375                       ScratchAllocator *workspace_allocator);
1376 
1377   // Enqueue onto the stream an operation that transforms a tensor.
1378   // See DnnSupport::DoTransformTensor for more details.
1379   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
1380                               dnn::DataType input_type,
1381                               const DeviceMemoryBase &input_data,
1382                               const dnn::BatchDescriptor &output_desc,
1383                               dnn::DataType output_type, float scale,
1384                               DeviceMemoryBase *output_data);
1385 
1386   // The templated version of the above ThenTransformTensor. Useful when the
1387   // input and output types are statically known.
1388   template <typename InElemT, typename OutElemT>
1389   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
1390                               const DeviceMemory<InElemT> &input_data,
1391                               const dnn::BatchDescriptor &output_desc,
1392                               DeviceMemory<OutElemT> *output_data) {
1393     return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
1394                                input_data, output_desc,
1395                                dnn::ToDataType<OutElemT>(), output_data);
1396   }
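
  // Illustrative usage sketch (added commentary; the descriptors, buffers and
  // setter chaining shown here are hypothetical/indicative only):
  //   dnn::BatchDescriptor in_desc, out_desc;
  //   in_desc.set_count(1).set_feature_map_count(3).set_height(8).set_width(8);
  //   out_desc.CloneFrom(in_desc);
  //   stream.ThenTransformTensor(in_desc, half_input, out_desc, &float_output);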
1397 
1398   // (Synchronously) block the host code waiting for the operations
1399   // entrained on the stream (enqueued to this point in program
1400   // execution) to complete.
1401   //
1402   // Returns an OK status if the blocking was successful and the stream is ok().
1403   // Otherwise returns an error describing why the blocking failed.
1404   port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_);
1405 
1406   // Warning! This method interacts with internal threads in
1407   // sometimes-unpredictable ways and is intended for GPU-Executor-internal
1408   // use only. Please check with a member of the FASTR team before making use
1409   // of this method.
1411   //
1412   // Entrains onto the stream a function to be executed on the host at some
1413   // point in the future.
1414   // Async host callbacks DO NOT block the stream the way device functions
1415   // (or synchronous host callbacks) do. No synchronization is possible with
1416   // asynchronous callbacks; they are strictly fire-and-forget.
1417   // This method is private due to the potential for undefined behavior with
1418   // synchronization using OpenCL user events.
1419   // The ONLY lifetime guarantee in these calls is that the StreamExecutor
1420   // parameter will still be valid - this Stream may not be!
1421   // Any callbacks requiring device API calls must use this method.
1422   Stream &ThenEnqueueOnBackgroundThread(
1423       std::function<void(StreamExecutor *)> task);
1424 
1425   // Returns the (opaque) platform-specific backing object. Ownership is not
1426   // transferred to the caller.
1427   internal::StreamInterface *implementation() { return implementation_.get(); }
1428 
1429   // Entrains onto the stream a callback to the host (from the device).
1430   // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
1431   // either never fail or its failure should be inconsequential.
1432   //
1433   // This is kept for backward compatibility. Future code should use
1434   // ThenDoHostCallbackWithStatus and explicitly return a success status.
1435   // TODO(b/112125301): Eventually remove this method.
1436   Stream &ThenDoHostCallback(std::function<void()> callback);
1437 
1438   // Entrains onto the stream a callback to the host (from the device).
1439   // Host callbacks block/occupy the stream just as device functions
1440   // (execute one at a time, block later stream operations).
1441   // Whether the callback return status affects the result of BlockHostUntilDone
1442   // is platform-dependent.
1443   //
1444   // Behavior is undefined when synchronizing using OpenCL user events.
1445   // Behavior is undefined if host callbacks call device routines or insert
1446   // them into any stream.
1447   //
1448   // On certain platforms, ThenDoHostCallback is expected to have significant
1449   // negative effects on performance.
1450   Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
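
  // Illustrative usage sketch (added commentary): a callback that only does
  // host-side bookkeeping and reports success, e.g.
  //   stream.ThenDoHostCallbackWithStatus([]() -> port::Status {
  //     LOG(INFO) << "reached this point in the stream";
  //     return ::tensorflow::OkStatus();
  //   });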
1451 
1452   // Runs the given callback after the next call to BlockHostUntilDone on this
1453   // stream (or after the Stream does BlockHostUntilDone in its destructor).
1454   // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
1455   // some use cases.
1456   Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);
1457 
1458   // Returns the StreamExecutor (parent object) associated with this stream.
1459   StreamExecutor *parent() const {
1460     CHECK(parent_ != nullptr);
1461     return parent_;
1462   }
1463 
1464   // Returns the CUDA compute capability of the device this stream runs on.
1465   CudaComputeCapability GetCudaComputeCapability() const {
1466     return parent()->GetDeviceDescription().cuda_compute_capability();
1467   }
1468 
1469   RocmComputeCapability GetRocmComputeCapability() const {
1470     return parent()->GetDeviceDescription().rocm_compute_capability();
1471   }
1472   // Returns the (internal usage) temporary-memory-allocation manager associated
1473   // with this stream.
1474   internal::TemporaryMemoryManager *temporary_memory_manager();
1475 
1476   // Returns a debugging string "[stream=0x...,impl=0x...]".
1477   std::string DebugStreamPointers() const;
1478 
1479  private:
1480   friend class host::HostBlas;  // for parent_.
1481   friend class host::HostFft;   // for parent_.
1482   friend class host::HostRng;   // for parent_.
1483   template <typename... Args>
1484   friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
1485   friend class ocl::CLBlas;    // for parent_.
1486 
1487   // Checks whether the types match before a call to an extended BLAS routine.
1488   template <typename ABType, typename CType, typename ScaleType>
1489   port::Status CheckTypesForExtendedBlas(
1490       blas::ComputationType computation_type) {
1491     static_assert(
1492         detail::is_any_of<ABType, Eigen::half, Eigen::bfloat16, float, double,
1493                           int8_t, std::complex<float>, std::complex<double>>(),
1494         "The only buffer types supported are: Eigen::half, float, "
1495         "double, int8, std::complex<float> and std::complex<double>");
1496     static_assert(
1497         std::is_same_v<ABType, CType> ||
1498             (std::is_same_v<ABType, int8_t> && std::is_same_v<CType, int32_t>),
1499         "Input and output buffer types should be the same unless input is "
1500         "int8 and output is int32");
1501     static_assert(
1502         std::is_same_v<ScaleType, CType> ||
1503             (std::is_same_v<ScaleType, float> &&
1504              detail::is_any_of<CType, Eigen::half, Eigen::bfloat16>()),
1505         "Mismatched alpha/beta and output types");
1506 
1507     bool valid_computation_type = [computation_type] {
1508       switch (computation_type) {
1509         case blas::ComputationType::kF16:
1510           return std::is_same_v<CType, Eigen::half>;
1511         case blas::ComputationType::kF32:
1512           return detail::is_any_of<CType, Eigen::half, Eigen::bfloat16, float,
1513                                    std::complex<float>>();
1514         case blas::ComputationType::kF64:
1515           return detail::is_any_of<CType, double, std::complex<double>>();
1516         case blas::ComputationType::kI32:
1517           return std::is_same_v<CType, int32_t>;
1518         case blas::ComputationType::kF16AsF32:   // fall-through
1519         case blas::ComputationType::kBF16AsF32:  // fall-through
1520         case blas::ComputationType::kTF32AsF32:
1521           return detail::is_any_of<CType, float, std::complex<float>>();
1522       }
1523     }();
1524 
1525     if (!valid_computation_type) {
1526       return port::InternalError(absl::StrCat(
1527           "Invalid computation type ",
1528           blas::ComputationTypeString(computation_type), " for output type: ",
1529           blas::DataTypeString(blas::ToDataType<CType>::value)));
1530     }
1531     return ::tensorflow::OkStatus();
1532   }
1533 
1534   bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
1535     absl::ReaderMutexLock lock(&mu_);
1536     return !status_.ok();
1537   }
1538 
1539   // Sets the error state if operation_retcode is false.
1540   // This is a useful shorthand for many stream routines.
1541   void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_);
1542 
1543   // Checks the status and logs the error message, if any.
1544   void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_);
1545 
1546   void SetError() { CheckError(false /* = operation_retcode */); }
1547 
1548   void SetErrorAndLogNoDnnSupport() {
1549     SetError();
1550     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
1551                     "without DNN support";
1552   }
1553 
1554   // Runs the set of callbacks that are intended to run after
1555   // BlockHostUntilDone.
1556   void RunAfterBlockHostUntilDoneCallbacks();
1557 
1558   // The StreamExecutor that supports the operation of this stream.
1559   StreamExecutor *parent_;
1560 
1561   // The platform-dependent implementation that the StreamExecutor interface
1562   // delegates to.
1563   std::unique_ptr<internal::StreamInterface> implementation_;
1564 
1565   // mutex that guards the allocation / error state flags.
1566   // Mutable so that it can be obtained via const reader lock.
1567   mutable absl::Mutex mu_;
1568 
1569   // Whether Init() was successfully called to allocate this stream on the
1570   // underlying platform. It simply flips from 0 to 1 with a sanity check.
1571   // See StreamExecutor::AllocateStream.
1572   bool allocated_ ABSL_GUARDED_BY(mu_);
1573 
1574   // The last error (if any) of all method calls.
1575   port::Status status_ ABSL_GUARDED_BY(mu_);
1576 
1577   // Sub-streams that are generated from this stream. Each element has a
1578   // pointer to a sub-stream and a boolean value indicating whether the
1579   // sub-stream is ready to be reused.
1580   std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
1581       ABSL_GUARDED_BY(mu_);
1582 
1583   // Streams can allocate temporary memories to help with work they enqueue
1584   // (e.g. for scratch memory spaces). This member tracks those allocations and
1585   // notes when they can be reclaimed -- reclamation is attempted when
1586   // BlockHostUntilDone() is called.
1587   internal::TemporaryMemoryManager temporary_memory_manager_;
1588 
1589   // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
1590   std::vector<std::function<void()>> after_block_host_until_done_callbacks_
1591       ABSL_GUARDED_BY(mu_);
1592 
1593   // The non-extended BLAS interface requires alpha/beta to be floats when the
1594   // input type is Eigen::half (or Eigen::bfloat16). However, for consistency it
1595   // is convenient for the Stream interface to accept those types, so this
1596   // helper upcasts the scalars to float in place.
1596   template <typename T>
1597   void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr,
1598                          float *alpha_storage, float *beta_storage) {
1599     if (std::is_same<T, Eigen::half>::value) {
1600       *alpha_storage =
1601           static_cast<float>(*reinterpret_cast<Eigen::half *>(*alpha_ptr));
1602       *beta_storage =
1603           static_cast<float>(*reinterpret_cast<Eigen::half *>(*beta_ptr));
1604       *alpha_ptr = alpha_storage;
1605       *beta_ptr = beta_storage;
1606     } else if (std::is_same<T, Eigen::bfloat16>::value) {
1607       *alpha_storage =
1608           static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*alpha_ptr));
1609       *beta_storage =
1610           static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*beta_ptr));
1611       *alpha_ptr = alpha_storage;
1612       *beta_ptr = beta_storage;
1613     }
1614   }
1615 
1616   SE_DISALLOW_COPY_AND_ASSIGN(Stream);
1617 };
1618 
1619 ////////////
1620 // Inlines
1621 
1622 template <typename... Params, typename... Args>
1623 inline port::Status Stream::ThenLaunch(ThreadDim thread_dims,
1624                                        BlockDim block_dims,
1625                                        const TypedKernel<Params...> &kernel,
1626                                        Args... args) {
1627   KernelInvocationChecker<std::tuple<Params...>,
1628                           std::tuple<Args...>>::CheckAllStaticAssert();
1629 
1630   // This is the core that allows type-safe kernel launching.
1631   // Since the platforms take kernel arguments as tuples of (void *, size),
1632   // we pack the variadic parameters passed as ...args into the desired
1633   // tuple form and pass that packed form to the StreamExecutor::Launch()
1634   // implementation.
1635   KernelArgsArray<sizeof...(args)> kernel_args;
1636   kernel.PackParams(&kernel_args, args...);
1637   TF_RETURN_IF_ERROR(
1638       parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args));
1639   return ::tensorflow::OkStatus();
1640 }
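
// Illustrative usage sketch (added commentary; `add_kernel` is a hypothetical
// TypedKernel<DeviceMemory<float>, DeviceMemory<float>, int> loaded via the
// parent StreamExecutor, and `d_in`, `d_out`, `n` are its arguments):
//   TF_RETURN_IF_ERROR(stream.ThenLaunch(ThreadDim(256), BlockDim(32),
//                                        add_kernel, d_in, d_out, n));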
1641 
1642 template <typename T>
1643 inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
1644 Stream::AllocateTemporaryArray(uint64_t element_count) {
1645   return temporary_memory_manager_.AllocateArray<T>(element_count);
1646 }
1647 
1648 inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
1649   return &temporary_memory_manager_;
1650 }
1651 
1652 template <>
1653 struct Quantization<uint8> {
1654   static constexpr dnn::QuantizedActivationMode kModeId =
1655       dnn::QuantizedActivationMode::k8Bit;
1656 };
1657 
1658 template <>
1659 struct Quantization<uint16> {
1660   static constexpr dnn::QuantizedActivationMode kModeId =
1661       dnn::QuantizedActivationMode::k16Bit;
1662 };
1663 
1664 template <>
1665 struct Quantization<int32> {
1666   static constexpr dnn::QuantizedActivationMode kModeId =
1667       dnn::QuantizedActivationMode::k32Bit;
1668 };
1669 
1670 }  // namespace stream_executor
1671 
1672 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_
1673