• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/stream_executor/stream.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/strings/str_cat.h"
22 #include "third_party/eigen3/Eigen/Core"
23 #include "tensorflow/compiler/xla/stream_executor/blas.h"
24 #include "tensorflow/compiler/xla/stream_executor/lib/stacktrace.h"
25 #include "tensorflow/compiler/xla/stream_executor/platform.h"
26 #include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
27 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
28 #include "tensorflow/compiler/xla/stream_executor/rng.h"
29 #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
30 #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
31 
32 namespace stream_executor {
33 
namespace {
// Code to turn parameters to functions on stream into strings that
// will be VLOG'ed. We need overloads, instead of
// e.g. BatchDescriptorToVlogString(), as the code that calls these
// functions does not know what the type of the parameter is.
std::string ToVlogString(const dnn::BatchDescriptor &descriptor) {
  return descriptor.ToShortString();
}

std::string ToVlogString(const dnn::FilterDescriptor &descriptor) {
  return descriptor.ToShortString();
}

std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
  return descriptor.ToShortString();
}

std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
  return descriptor.ToShortString();
}

std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
  return descriptor.ToShortString();
}

std::string ToVlogString(dnn::ActivationMode mode) {
  return dnn::ActivationModeString(mode);
}

std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
  return algo_config.ToString();
}

std::string ToVlogString(dnn::ElementwiseOperation op) {
  return dnn::ElementwiseOperationString(op);
}

std::string ToVlogString(dnn::QuantizedActivationMode mode) {
  return dnn::QuantizedActivationModeString(mode);
}

std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }

std::string ToVlogString(blas::UpperLower ul) {
  return blas::UpperLowerString(ul);
}

std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }

std::string ToVlogString(blas::Side s) { return blas::SideString(s); }

std::string ToVlogString(blas::ComputationType ty) {
  return blas::ComputationTypeString(ty);
}

std::string ToVlogString(const void *ptr) {
  if (ptr == nullptr) {
    return "null";
  }

  // StrCat does not convert pointers to text.
  std::ostringstream out;
  out << ptr;
  return out.str();
}

template <class T>
std::string ToVlogString(const std::complex<T> &c) {
  // StrCat does not convert std::complex to text.
  std::ostringstream out;
  out << c;
  return out.str();
}

template <class T>
std::string ToVlogString(const std::function<T> &f) {
  return f == nullptr ? "null" : "<non-null function>";
}

// Device memory is logged by its opaque device address, not its contents.
std::string ToVlogString(const DeviceMemoryBase &memory) {
  return ToVlogString(memory.opaque());
}

std::string ToVlogString(const DeviceMemoryBase *memory) {
  return memory == nullptr ? "null" : ToVlogString(*memory);
}

std::string ToVlogString(const Eigen::half &h) {
  return absl::StrCat(static_cast<float>(h));
}

std::string ToVlogString(int i) { return absl::StrCat(i); }

std::string ToVlogString(uint32 i) { return absl::StrCat(i); }

std::string ToVlogString(uint64_t i) { return absl::StrCat(i); }

std::string ToVlogString(int64_t i) { return absl::StrCat(i); }

std::string ToVlogString(float f) { return absl::StrCat(f); }

std::string ToVlogString(double d) { return absl::StrCat(d); }

// Formats a slice as "<address>[<size>]{e0, e1, ...}", truncating the
// element list based on the active VLOG verbosity level.
template <class T>
std::string ToVlogString(port::ArraySlice<T> elements) {  // non-absl ok
  std::string str = absl::StrCat(
      ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
      elements.size(), "]{");
  const char *separator = "";
  // Higher verbosity levels show more elements; VLOG(11)+ shows everything.
  size_t max_to_show = std::numeric_limits<size_t>::max();
  if (!VLOG_IS_ON(2)) {
    max_to_show = 5;
  } else if (!VLOG_IS_ON(3)) {
    max_to_show = 20;
  } else if (!VLOG_IS_ON(11)) {
    max_to_show = 1000;
  }
  for (size_t i = 0; i < elements.size(); ++i) {
    if (i == max_to_show) {
      str += ", ...";
      break;
    }
    absl::StrAppend(&str, separator, ToVlogString(elements[i]));
    separator = ", ";
  }
  str += "}";
  return str;
}

template <class T>
std::string ToVlogString(port::MutableArraySlice<T> elements) {  // non-absl ok
  return ToVlogString(port::ArraySlice<T>(elements));            // non-absl ok
}

std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
  switch (depth_to_space_layout) {
    case dnn::DepthToSpaceLayout::DepthHeightWidth:
      return "DepthToSpaceLayout::DepthHeightWidth";
  }
  return "unknown DepthToSpaceLayout";
}

std::string ToVlogString(dnn::DataType data_type) {
  switch (data_type) {
    case dnn::DataType::kFloat:
      return "dnn::DataType::kFloat";
    case dnn::DataType::kDouble:
      return "dnn::DataType::kDouble";
    case dnn::DataType::kHalf:
      return "dnn::DataType::kHalf";
    case dnn::DataType::kInt8:
      return "dnn::DataType::kInt8";
    case dnn::DataType::kInt32:
      return "dnn::DataType::kInt32";
    default:
      return "unknown DataType";
  }
}

// Used together with PARAM to VLOG calls made to the stream. Intended
// to be used like this:
//
//   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
//
// where a and b are the parameters to MyFunction.
//
// See VLOG_CALL for a short-hand for this. This way of doing it saves
// a tremendous amount of boilerplate code given how many functions
// there are on Stream and how many parameters they each have.
std::string CallStr(const char *function_name, Stream *stream,
                    std::vector<std::pair<const char *, std::string>> params) {
  // Do not call this function unless VLOG is on since just
  // constructing all the strings in params is expensive.
  CHECK(VLOG_IS_ON(1));

  std::string str = absl::StrCat(stream->DebugStreamPointers(),
                                 " Called Stream::", function_name, "(");
  const char *separator = "";
  for (const auto &param : params) {
    absl::StrAppend(&str, separator, param.first, "=", param.second);
    separator = ", ";
  }
  absl::StrAppend(&str, ")");
  if (VLOG_IS_ON(10)) {
    // At very high verbosity also record where the call originated.
    absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
  }
  return str;
}

// Use this macro to avoid having to type every parameter twice to log
// it with VLOG and CallStr.
#define PARAM(parameter)                \
  {                                     \
#parameter, ToVlogString(parameter) \
  }

// Use this macro to avoid having to type out the name of each
// function and to save some boilerplate. Intended to be used like this:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// This saves a tremendous amount of boilerplate compared to the alternative:
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// Note here that most of the parameter names are not short and that
// most of the functions take many more than 2 parameters.
#define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})

}  // namespace
245 
// Constructs a stream attached to `parent`. The stream starts out
// unallocated and in a !ok() error state; Init() must succeed before the
// stream is usable.
Stream::Stream(StreamExecutor *parent)
    : parent_(parent),
      implementation_(parent->implementation()->GetStreamImplementation()),
      allocated_(false),
      status_(port::InternalError("Uninitialized stream")),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent));
}
254 
// Drains the stream, releases temporary memory, runs deferred callbacks,
// and finally returns the underlying platform stream to the executor.
Stream::~Stream() {
  VLOG_CALL();

  // Ensure the stream is completed.
  auto status = BlockHostUntilDone();
  if (!status.ok()) {
    // Best effort only: the destructor cannot propagate the failure.
    LOG(WARNING) << "Error blocking host until done in stream destructor: "
                 << status;
  }
  temporary_memory_manager_.ForceDeallocateAll();
  RunAfterBlockHostUntilDoneCallbacks();

  // Only deallocate if Init() actually acquired a platform stream.
  if (allocated_) {
    parent_->DeallocateStream(this);
  }
}
271 
// Queries the executor for this stream's current status and, when the query
// is supported, folds it into the stream's recorded status via CheckStatus.
port::Status Stream::RefreshStatus() {
  port::Status status = parent_->GetStatus(this);
  // We should not put the stream in an error state, just because the GetStatus
  // method is unimplemented.
  // NOTE(review): this relies on matching the exact UNIMPLEMENTED status
  // (code and message) produced by executors without GetStatus support —
  // confirm the sentinel message stays in sync with those executors.
  if (status != port::Status(port::error::UNIMPLEMENTED,
                             "GetStatus is not supported on this executor.")) {
    CheckStatus(status);
  }
  return status;
}
282 
Init()283 Stream &Stream::Init() {
284   VLOG_CALL();
285 
286   absl::MutexLock lock(&mu_);
287   CHECK_EQ(false, allocated_)
288       << "stream appears to already have been initialized";
289   CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization";
290 
291   if (parent_->AllocateStream(this)) {
292     // Successful initialization!
293     allocated_ = true;
294     status_ = ::tensorflow::OkStatus();
295   } else {
296     LOG(ERROR) << "failed to allocate stream during initialization";
297   }
298 
299   return *this;
300 }
301 
InitTimer(Timer * timer)302 Stream &Stream::InitTimer(Timer *timer) {
303   VLOG_CALL(PARAM(timer));
304 
305   CheckError(parent_->AllocateTimer(timer));
306   return *this;
307 }
308 
InitWithTimer(Timer * timer)309 Stream &Stream::InitWithTimer(Timer *timer) {
310   VLOG_CALL(PARAM(timer));
311 
312   return Init().InitTimer(timer);
313 }
314 
// Enqueues the recording of `event` on this stream. A failure is logged but
// deliberately does NOT put the stream into an error state, since the Event
// object itself may be the faulty party.
Stream &Stream::ThenRecordEvent(Event *event) {
  VLOG_CALL(PARAM(event));

  port::Status status = parent_->RecordEvent(this, event);
  if (!status.ok()) {
    LOG(ERROR) << "Error recording event in stream: " << status.error_message()
               << "; not marking stream as bad, as the Event object may be "
               << "at fault. Monitor for further errors.";
  }

  return *this;
}
327 
// Enqueues a float batch-normalization forward pass, delegating to the
// executor's DnnSupport::DoBatchNormalizationForward. If the executor has no
// DNN support the stream is put into an error state.
Stream &Stream::ThenBatchNormalizationForward(
    const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    const double exponential_average_factor,
    dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
    DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
    DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
    bool is_training, ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoBatchNormalizationForward(
        this, x, scale, offset, estimated_mean, estimated_variance, side_input,
        x_desc, scale_offset_desc, epsilon, exponential_average_factor,
        activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
        is_training, reserve_space_allocator, workspace_allocator));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
354 
// Enqueues a float batch-normalization backward pass (gradient computation),
// delegating to DnnSupport::DoBatchNormalizationBackward. If the executor has
// no DNN support the stream is put into an error state.
Stream &Stream::ThenBatchNormalizationBackward(
    const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
    const DeviceMemory<float> &scale, const DeviceMemory<float> &offset,
    const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
    const DeviceMemory<float> &y, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    dnn::ActivationMode activation_mode, DeviceMemory<float> *x_backprop,
    DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
    DeviceMemory<float> *side_input_backprop,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
            PARAM(scale_backprop), PARAM(offset_backprop));
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoBatchNormalizationBackward(
        this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc,
        scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop,
        offset_backprop, side_input_backprop, reserve_space_data,
        workspace_allocator));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
380 
// Half-precision (Eigen::half data, float scale/offset/statistics) overload
// of the batch-normalization forward pass; mirrors the float overload above.
Stream &Stream::ThenBatchNormalizationForward(
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const DeviceMemory<Eigen::half> &side_input,
    const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    const double exponential_average_factor,
    dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
    DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
    DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
    bool is_training, ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoBatchNormalizationForward(
        this, x, scale, offset, estimated_mean, estimated_variance, side_input,
        x_desc, scale_offset_desc, epsilon, exponential_average_factor,
        activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
        is_training, reserve_space_allocator, workspace_allocator));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
408 
// Half-precision (Eigen::half data, float scale/statistics) overload of the
// batch-normalization backward pass; mirrors the float overload above.
Stream &Stream::ThenBatchNormalizationBackward(
    const DeviceMemory<Eigen::half> &y_backprop,
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset, const DeviceMemory<float> &mean,
    const DeviceMemory<float> &inv_var, const DeviceMemory<Eigen::half> &y,
    const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *x_backprop,
    DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
    DeviceMemory<Eigen::half> *side_input_backprop,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
            PARAM(scale_backprop), PARAM(offset_backprop));
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoBatchNormalizationBackward(
        this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc,
        scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop,
        offset_backprop, side_input_backprop, reserve_space_data,
        workspace_allocator));

  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
436 
// Enqueues a forward convolution with the default algorithm configuration,
// no scratch allocator, and no profiling. Only dispatches when the stream is
// already in an ok() state.
Stream &Stream::ThenConvolve(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  if (ok()) {
    // NOTE: .ok() collapses the returned Status to a bool, so the detailed
    // error message from ConvolveWithAlgorithm is dropped here; only the
    // stream's error flag is set on failure.
    CheckError(ConvolveWithAlgorithm(
                   dnn::ConvolutionKind::FORWARD, input_descriptor, input_data,
                   filter_descriptor, filter_data, output_descriptor, *output,
                   convolution_descriptor,
                   /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
                   /*output_profile_result=*/nullptr)
                   .ok());
  }
  return *this;
}
456 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)457 Stream &Stream::ThenConvolveQuantized(
458     const dnn::BatchDescriptor &input_descriptor,
459     const DeviceMemory<float> &input_data,
460     const dnn::FilterDescriptor &filter_descriptor,
461     const DeviceMemory<int8> &filter_coefficients,
462     const DeviceMemory<float> &coefficient_scales,
463     const dnn::ConvolutionDescriptor &convolution_descriptor,
464     const dnn::BatchDescriptor &output_descriptor,
465     DeviceMemory<float> *output) {
466   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
467             PARAM(filter_descriptor), PARAM(filter_coefficients),
468             PARAM(coefficient_scales), PARAM(convolution_descriptor),
469             PARAM(output_descriptor), PARAM(output));
470 
471   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
472     CheckError(dnn->DoConvolveQuantized(
473         this, input_descriptor, input_data, filter_descriptor,
474         filter_coefficients, coefficient_scales, convolution_descriptor,
475         output_descriptor, output));
476   } else {
477     SetError();
478     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
479                     "without DNN support";
480   }
481   return *this;
482 }
483 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)484 Stream &Stream::ThenConvolveQuantized(
485     const dnn::BatchDescriptor &input_descriptor,
486     const DeviceMemory<float> &input_data,
487     const dnn::FilterDescriptor &filter_descriptor,
488     const DeviceMemory<int16> &filter_coefficients,
489     const DeviceMemory<float> &coefficient_scales,
490     const dnn::ConvolutionDescriptor &convolution_descriptor,
491     const dnn::BatchDescriptor &output_descriptor,
492     DeviceMemory<float> *output) {
493   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
494             PARAM(filter_descriptor), PARAM(filter_coefficients),
495             PARAM(coefficient_scales), PARAM(convolution_descriptor),
496             PARAM(output_descriptor), PARAM(output));
497 
498   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
499     CheckError(dnn->DoConvolveQuantized(
500         this, input_descriptor, input_data, filter_descriptor,
501         filter_coefficients, coefficient_scales, convolution_descriptor,
502         output_descriptor, output));
503   } else {
504     SetError();
505     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
506                     "without DNN support";
507   }
508   return *this;
509 }
510 
// Enqueues a depthwise-separable convolution (a depthwise pass using
// `first_weights` followed by a pointwise pass using `second_weights`),
// delegating to DnnSupport::DoSeparableConvolve.
Stream &Stream::ThenSeparableConvolve(
    const dnn::BatchDescriptor &batch_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
    const DeviceMemory<float> &first_weights,
    const DeviceMemory<float> &second_weights,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  VLOG_CALL(
      PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
      PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
      PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoSeparableConvolve(
        this, batch_descriptor, input_data, filter_descriptor, depth_multiplier,
        first_weights, second_weights, convolution_descriptor,
        output_descriptor, output));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
535 
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)536 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
537                            const DeviceMemory<float> &weights,
538                            const dnn::BatchDescriptor &input_dimensions,
539                            const dnn::BatchDescriptor &output_dimensions,
540                            DeviceMemory<float> *output_data) {
541   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
542             PARAM(output_dimensions), PARAM(output_data));
543 
544   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
545     CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
546                              output_dimensions, output_data));
547   } else {
548     SetErrorAndLogNoDnnSupport();
549   }
550   return *this;
551 }
552 
// Enqueues a matrix multiply whose weights are quantized int8 values with
// associated per-weight scaling factors, delegating to
// DnnSupport::DoMatMulQuantized.
Stream &Stream::ThenMatMulQuantized(
    const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
    const DeviceMemory<float> &weight_scales,
    const dnn::BatchDescriptor &input_dimensions,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
            PARAM(input_dimensions), PARAM(output_dimensions),
            PARAM(output_data));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
                                      input_dimensions, output_dimensions,
                                      output_data));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
572 
// int16-weight overload of the quantized matrix multiply; mirrors the int8
// overload above.
Stream &Stream::ThenMatMulQuantized(
    const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
    const DeviceMemory<float> &weight_scales,
    const dnn::BatchDescriptor &input_dimensions,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
            PARAM(input_dimensions), PARAM(output_dimensions),
            PARAM(output_data));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
                                      input_dimensions, output_dimensions,
                                      output_data));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
592 
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)593 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
594                             const DeviceMemory<float> &biases,
595                             const dnn::BatchDescriptor &dimensions,
596                             DeviceMemory<float> *output_data) {
597   VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
598             PARAM(output_data));
599 
600   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
601     CheckError(
602         dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
603   } else {
604     SetErrorAndLogNoDnnSupport();
605   }
606   return *this;
607 }
608 
// Enqueues a normalization (e.g. local response normalization, per
// `normalize_descriptor`) over `input_data`, delegating to
// DnnSupport::DoNormalizeWithDimensions.
Stream &Stream::ThenNormalizeWithDimensions(
    const dnn::NormalizeDescriptor &normalize_descriptor,
    const dnn::BatchDescriptor &dimensions,
    const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
            PARAM(output_data));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoNormalizeWithDimensions(
        this, normalize_descriptor, dimensions, input_data, output_data));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
624 
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)625 Stream &Stream::ThenNormalizeBackwardWithDimensions(
626     const dnn::NormalizeDescriptor &normalize_descriptor,
627     const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
628     const DeviceMemory<float> &normalized_data,
629     const DeviceMemory<float> &normalized_variable_gradient,
630     DeviceMemory<float> *raw_variable_gradient,
631     ScratchAllocator *workspace_allocator) {
632   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
633             PARAM(normalized_data), PARAM(normalized_variable_gradient),
634             PARAM(raw_variable_gradient), PARAM(workspace_allocator));
635 
636   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
637     CheckError(dnn->DoNormalizeBackwardWithDimensions(
638         this, normalize_descriptor, dimensions, raw_data, normalized_data,
639         normalized_variable_gradient, raw_variable_gradient,
640         workspace_allocator));
641   } else {
642     SetErrorAndLogNoDnnSupport();
643   }
644   return *this;
645 }
646 
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)647 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
648                              const dnn::BatchDescriptor &dimensions,
649                              const DeviceMemory<float> &input_data,
650                              DeviceMemory<float> *output_data) {
651   return ThenActivateWithOptions(activation_mode, dimensions, input_data,
652                                  output_data, /*options=*/0);
653 }
654 
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64_t options)655 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
656                                         const dnn::BatchDescriptor &dimensions,
657                                         const DeviceMemory<float> &input_data,
658                                         DeviceMemory<float> *output_data,
659                                         uint64_t options) {
660   VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
661             PARAM(output_data), PARAM(options));
662 
663   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
664     CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
665                                output_data, options));
666   } else {
667     SetErrorAndLogNoDnnSupport();
668   }
669   return *this;
670 }
671 
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)672 Stream &Stream::ThenDepthConcatenate(
673     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
674     port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
675     DeviceMemory<float> *output_data) {
676   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
677 
678   for (size_t i = 1; i < input_dimensions.size(); ++i) {
679     if (input_dimensions[i].count() != input_dimensions[0].count() ||
680         input_dimensions[i].height() != input_dimensions[0].height() ||
681         input_dimensions[i].width() != input_dimensions[0].width()) {
682       SetError();
683       LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
684                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
685                  << "input_dimensions[" << i
686                  << "]: " << input_dimensions[i].ToString();
687       return *this;
688     }
689   }
690 
691   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
692     CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
693                                        output_data));
694   } else {
695     SetErrorAndLogNoDnnSupport();
696   }
697   return *this;
698 }
699 
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)700 Stream &Stream::ThenSpaceConcatenate(
701     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
702     port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
703     DeviceMemory<float> *output_data,
704     dnn::SpaceConcatenateMode concat_direction) {
705   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
706 
707   // Check that the input dimensions of all the other batches match those of the
708   // first batch.
709   for (size_t i = 1; i < input_dimensions.size(); ++i) {
710     if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
711         (input_dimensions[i].count() != input_dimensions[0].count() ||
712          input_dimensions[i].height() != input_dimensions[0].height() ||
713          input_dimensions[i].feature_map_count() !=
714              input_dimensions[0].feature_map_count())) {
715       SetError();
716       LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
717                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
718                  << "input_dimensions[" << i
719                  << "]: " << input_dimensions[i].ToString();
720       return *this;
721     }
722 
723     if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
724         (input_dimensions[i].count() != input_dimensions[0].count() ||
725          input_dimensions[i].width() != input_dimensions[0].width() ||
726          input_dimensions[i].feature_map_count() !=
727              input_dimensions[0].feature_map_count())) {
728       SetError();
729       LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
730                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
731                  << "input_dimensions[" << i
732                  << "]: " << input_dimensions[i].ToString();
733       return *this;
734     }
735   }
736   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
737     CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
738                                        output_data, concat_direction));
739   } else {
740     SetErrorAndLogNoDnnSupport();
741   }
742   return *this;
743 }
744 
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)745 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
746                             const DeviceMemory<float> &input_data,
747                             const dnn::BatchDescriptor &output_dimensions,
748                             DeviceMemory<float> *output_data) {
749   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
750             PARAM(output_dimensions), PARAM(output_data));
751 
752   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
753     CheckError(dnn->DoReshape(this, input_dimensions, input_data,
754                               output_dimensions, output_data));
755   } else {
756     SetErrorAndLogNoDnnSupport();
757   }
758   return *this;
759 }
760 
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)761 Stream &Stream::ThenDepthToSpace(
762     const dnn::BatchDescriptor &input_dimensions,
763     const DeviceMemory<float> &input_data,
764     const dnn::DepthToSpaceLayout &depth_to_space_layout,
765     const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
766   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
767             PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
768             PARAM(output_data));
769 
770   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
771     CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
772                                    depth_to_space_layout, sqrt_depth_reduction,
773                                    output_data));
774   } else {
775     SetErrorAndLogNoDnnSupport();
776   }
777   return *this;
778 }
779 
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)780 Stream &Stream::ThenSpaceToDepth(
781     const dnn::BatchDescriptor &input_dimensions,
782     const DeviceMemory<float> &input_data,
783     const dnn::DepthToSpaceLayout &space_to_depth_layout,
784     const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
785   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
786             PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
787             PARAM(output_data));
788 
789   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
790     CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
791                                    space_to_depth_layout, sqrt_depth_increase,
792                                    output_data));
793   } else {
794     SetErrorAndLogNoDnnSupport();
795   }
796   return *this;
797 }
798 
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)799 Stream &Stream::ThenElementwiseOperate(
800     dnn::ElementwiseOperation operation,
801     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
802     port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
803     const dnn::BatchDescriptor &output_dimensions,
804     DeviceMemory<float> *output_data) {
805   VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
806             PARAM(output_dimensions), PARAM(output_data));
807 
808   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
809     CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
810                                          input_data, output_dimensions,
811                                          output_data));
812   } else {
813     SetErrorAndLogNoDnnSupport();
814   }
815   return *this;
816 }
817 
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)818 Stream &Stream::ThenElementwiseOperateScaledQuantized(
819     dnn::ElementwiseOperation operation,
820     port::ArraySlice<int> input_multiplicands,  // non-absl ok
821     int output_divisor,
822     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
823     port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
824     const dnn::BatchDescriptor &output_dimensions,
825     DeviceMemory<float> *output_data) {
826   VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
827             PARAM(input_dimensions), PARAM(input_data),
828             PARAM(output_dimensions), PARAM(output_data));
829 
830   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
831     CheckError(dnn->DoElementwiseOperateScaledQuantized(
832         this, operation, input_multiplicands, output_divisor, input_dimensions,
833         input_data, output_dimensions, output_data));
834   } else {
835     SetErrorAndLogNoDnnSupport();
836   }
837   return *this;
838 }
839 
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t left_pad,int64_t right_pad,int64_t top_pad,int64_t bottom_pad,DeviceMemory<float> * output_data)840 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
841                           const DeviceMemory<float> &input_data,
842                           int64_t left_pad, int64_t right_pad, int64_t top_pad,
843                           int64_t bottom_pad,
844                           DeviceMemory<float> *output_data) {
845   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
846             PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
847             PARAM(output_data));
848 
849   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
850     CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
851                             top_pad, bottom_pad, output_data));
852   } else {
853     SetErrorAndLogNoDnnSupport();
854   }
855   return *this;
856 }
857 
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t left_trim,int64_t right_trim,int64_t top_trim,int64_t bottom_trim,DeviceMemory<float> * output_data)858 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
859                             const DeviceMemory<float> &input_data,
860                             int64_t left_trim, int64_t right_trim,
861                             int64_t top_trim, int64_t bottom_trim,
862                             DeviceMemory<float> *output_data) {
863   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
864             PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
865             PARAM(output_data));
866 
867   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
868     CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
869                               right_trim, top_trim, bottom_trim, output_data));
870   } else {
871     SetErrorAndLogNoDnnSupport();
872   }
873   return *this;
874 }
875 
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t replicate_x,int64_t replicate_y,DeviceMemory<float> * output_data)876 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
877                                 const DeviceMemory<float> &input_data,
878                                 int64_t replicate_x, int64_t replicate_y,
879                                 DeviceMemory<float> *output_data) {
880   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
881             PARAM(replicate_y), PARAM(output_data));
882 
883   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
884     CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
885                                   replicate_y, output_data));
886   } else {
887     SetErrorAndLogNoDnnSupport();
888   }
889   return *this;
890 }
891 
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64_t size)892 Stream &Stream::ThenMemcpyD2HQuantized(
893     const DeviceMemory<float> &gpu_unquantized_src,
894     dnn::QuantizedActivationMode mode, void *host_dst, uint64_t size) {
895   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
896             PARAM(size));
897 
898   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
899     CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
900                                          host_dst, size));
901   } else {
902     SetErrorAndLogNoDnnSupport();
903   }
904   return *this;
905 }
906 
ThenMemcpyH2DQuantized(const void * host_src,uint64_t size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)907 Stream &Stream::ThenMemcpyH2DQuantized(
908     const void *host_src, uint64_t size, dnn::QuantizedActivationMode mode,
909     DeviceMemory<float> *gpu_unquantized_dst) {
910   VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
911             PARAM(gpu_unquantized_dst));
912 
913   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
914     CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
915                                          gpu_unquantized_dst));
916   } else {
917     SetErrorAndLogNoDnnSupport();
918   }
919   return *this;
920 }
921 
// Returns a sub-stream owned by this stream: either the first reusable
// sub-stream that is still ok(), or a freshly created one. !ok() reusable
// sub-streams encountered during the scan are removed from the pool.
Stream *Stream::GetOrCreateSubStream() {
  // Do not destroy bad streams when holding mu_ because ~Stream() may
  // BlockHostUntilDone and its host callbacks might attempt to acquire mu_.
  // bad_streams is destroyed (outside the lock) when this function returns.
  std::vector<std::unique_ptr<Stream>> bad_streams;

  absl::MutexLock lock(&mu_);

  // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
  // we encounter along the way.
  for (size_t index = 0; index < sub_streams_.size();) {
    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
    if (pair.second) {
      // The sub_stream is reusable.
      Stream *sub_stream = pair.first.get();
      if (sub_stream->ok()) {
        VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
                << sub_stream->DebugStreamPointers();
        // Mark the sub-stream as checked out before handing it to the caller.
        pair.second = false;
        return sub_stream;
      }

      // The stream is reusable and not ok. Streams have a monotonic state
      // machine; the stream will remain in !ok forever. Swap it with the last
      // stream and pop it off.
      const int64_t last = sub_streams_.size() - 1;
      if (index != last) {
        std::swap(pair, sub_streams_[last]);
      }
      // Ownership moves to bad_streams so destruction happens after mu_ is
      // released; `sub_stream` stays valid for the VLOG below because
      // bad_streams still owns it. Note: index is intentionally not advanced —
      // the swapped-in element must be examined on the next iteration.
      bad_streams.push_back(std::move(sub_streams_.back().first));
      sub_streams_.pop_back();
      VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
              << sub_stream->DebugStreamPointers();
    } else {
      // The sub_stream is not reusable, move on to the next one.
      ++index;
    }
  }

  // No streams are reusable; create a new stream.
  sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
                            false);
  Stream *sub_stream = sub_streams_.back().first.get();
  sub_stream->Init();
  if (!sub_stream->ok()) {
    // Initialization failure is logged but the stream is still returned; the
    // caller observes the failure via sub_stream->ok().
    LOG(ERROR) << "sub-stream failed to be initialized";
  }
  VLOG(1) << DebugStreamPointers() << " created new sub_stream "
          << sub_stream->DebugStreamPointers();

  return sub_stream;
}
973 
// Returns a sub-stream previously handed out by GetOrCreateSubStream back to
// the pool. An ok() sub-stream is marked reusable; a !ok() sub-stream is
// removed from the pool and destroyed. Fatal if sub_stream was not created by
// this stream.
void Stream::ReturnSubStream(Stream *sub_stream) {
  // Do not destroy bad streams when holding mu_ because ~Stream() may
  // BlockHostUntilDone and its host callbacks might attempt to acquire mu_.
  // bad_stream (if populated) is destroyed after the lock is released.
  std::unique_ptr<Stream> bad_stream;

  absl::MutexLock lock(&mu_);

  // Look for the sub-stream.
  for (int64_t index = 0, end = sub_streams_.size(); index < end; ++index) {
    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
    if (pair.first.get() != sub_stream) {
      continue;
    }

    // Found the sub_stream.
    if (sub_stream->ok()) {
      VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
              << sub_stream->DebugStreamPointers();
      // Flag the entry as reusable for future GetOrCreateSubStream calls.
      pair.second = true;
    } else {
      // The returned stream is not ok. Streams have a monotonic state
      // machine; the stream will remain in !ok forever. Swap it with the last
      // stream and pop it off.
      VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
              << sub_stream->DebugStreamPointers();
      const int64_t last = sub_streams_.size() - 1;
      if (index != last) {
        std::swap(pair, sub_streams_[last]);
      }
      // Transfer ownership out of the pool; actual destruction happens after
      // mu_ is released (see comment at top).
      std::swap(bad_stream, sub_streams_.back().first);
      sub_streams_.pop_back();
    }
    return;
  }

  // Reaching here means the pointer was never in sub_streams_ — a caller bug.
  LOG(FATAL) << DebugStreamPointers()
             << " did not create the returned sub-stream "
             << sub_stream->DebugStreamPointers();
}
1013 
ThenStartTimer(Timer * t)1014 Stream &Stream::ThenStartTimer(Timer *t) {
1015   VLOG_CALL(PARAM(t));
1016 
1017   CheckError(parent_->StartTimer(this, t));
1018   return *this;
1019 }
1020 
ThenStopTimer(Timer * t)1021 Stream &Stream::ThenStopTimer(Timer *t) {
1022   VLOG_CALL(PARAM(t));
1023 
1024   CheckError(parent_->StopTimer(this, t));
1025   return *this;
1026 }
1027 
ThenWaitFor(Stream * other)1028 Stream &Stream::ThenWaitFor(Stream *other) {
1029   VLOG_CALL(PARAM(other));
1030 
1031   CHECK(this != other) << "stream cannot wait for itself";
1032   if (ok() && other->ok()) {
1033     CheckError(parent_->CreateStreamDependency(this, other));
1034   } else {
1035     SetError();
1036     LOG(INFO) << DebugStreamPointers() << " did not wait for "
1037               << other->DebugStreamPointers();
1038   }
1039   return *this;
1040 }
1041 
ThenWaitFor(Event * event)1042 Stream &Stream::ThenWaitFor(Event *event) {
1043   VLOG_CALL(PARAM(event));
1044 
1045   if (ok()) {
1046     port::Status status = parent_->WaitForEvent(this, event);
1047     if (!status.ok()) {
1048       LOG(ERROR) << "Error waiting for event in stream: "
1049                  << status.error_message()
1050                  << "; not marking stream as bad, as the Event object may be "
1051                  << "at fault. Monitor for further errors.";
1052     }
1053   } else {
1054     LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
1055   }
1056   return *this;
1057 }
1058 
1059 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1060 // functions and logs for errors.
template <typename... Args>
struct ThenBlasImpl {
  // Invokes blas_func on the stream's BLAS backend, recording any failure on
  // the stream. blas_func is the DoBlasXXX member function pointer, and args
  // are its arguments except the first one of Stream* type.
  Stream &operator()(Stream *stream,
                     bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
                     Args... args) {
    return Run(stream, blas_func, /*record_error=*/true, args...);
  }

  // Like operator(), but only calls stream->CheckError() if record_error is
  // true.
  Stream &Run(Stream *stream,
              bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
              bool record_error, Args... args);
};
1077 
// Calls blas_func on the stream's BLAS backend. No-op if the stream is
// already in an error state. A missing BLAS backend counts as a failure,
// which (when record_error is true) is recorded on the stream.
template <typename... Args>
Stream &ThenBlasImpl<Args...>::Run(
    Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
    bool record_error, Args... args) {
  if (stream->ok()) {
    bool ok;
    if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
      ok = (blas->*blas_func)(stream, args...);
    } else {
      LOG(WARNING)
          << "attempting to perform BLAS operation using StreamExecutor "
             "without BLAS support";
      ok = false;
    }
    if (record_error) {
      stream->CheckError(ok);
    }
  }
  return *stream;
}
1098 
ThenBlasAxpy(uint64_t elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1099 Stream &Stream::ThenBlasAxpy(uint64_t elem_count, float alpha,
1100                              const DeviceMemory<float> &x, int incx,
1101                              DeviceMemory<float> *y, int incy) {
1102   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1103             PARAM(incy));
1104 
1105   ThenBlasImpl<uint64_t, float, const DeviceMemory<float> &, int,
1106                DeviceMemory<float> *, int>
1107       impl;
1108   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1109               y, incy);
1110 }
1111 
ThenBlasAxpy(uint64_t elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1112 Stream &Stream::ThenBlasAxpy(uint64_t elem_count, double alpha,
1113                              const DeviceMemory<double> &x, int incx,
1114                              DeviceMemory<double> *y, int incy) {
1115   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1116             PARAM(incy));
1117 
1118   ThenBlasImpl<uint64_t, double, const DeviceMemory<double> &, int,
1119                DeviceMemory<double> *, int>
1120       impl;
1121   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1122               y, incy);
1123 }
1124 
ThenBlasAxpy(uint64_t elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1125 Stream &Stream::ThenBlasAxpy(uint64_t elem_count, std::complex<float> alpha,
1126                              const DeviceMemory<std::complex<float>> &x,
1127                              int incx, DeviceMemory<std::complex<float>> *y,
1128                              int incy) {
1129   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1130             PARAM(incy));
1131 
1132   ThenBlasImpl<uint64_t, std::complex<float>,
1133                const DeviceMemory<std::complex<float>> &, int,
1134                DeviceMemory<std::complex<float>> *, int>
1135       impl;
1136   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1137               y, incy);
1138 }
1139 
ThenBlasAxpy(uint64_t elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1140 Stream &Stream::ThenBlasAxpy(uint64_t elem_count, std::complex<double> alpha,
1141                              const DeviceMemory<std::complex<double>> &x,
1142                              int incx, DeviceMemory<std::complex<double>> *y,
1143                              int incy) {
1144   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1145             PARAM(incy));
1146 
1147   ThenBlasImpl<uint64_t, std::complex<double>,
1148                const DeviceMemory<std::complex<double>> &, int,
1149                DeviceMemory<std::complex<double>> *, int>
1150       impl;
1151   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1152               y, incy);
1153 }
1154 
ThenBlasCopy(uint64_t elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1155 Stream &Stream::ThenBlasCopy(uint64_t elem_count, const DeviceMemory<float> &x,
1156                              int incx, DeviceMemory<float> *y, int incy) {
1157   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1158 
1159   ThenBlasImpl<uint64_t, const DeviceMemory<float> &, int,
1160                DeviceMemory<float> *, int>
1161       impl;
1162   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1163               incy);
1164 }
1165 
ThenBlasCopy(uint64_t elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1166 Stream &Stream::ThenBlasCopy(uint64_t elem_count, const DeviceMemory<double> &x,
1167                              int incx, DeviceMemory<double> *y, int incy) {
1168   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1169 
1170   ThenBlasImpl<uint64_t, const DeviceMemory<double> &, int,
1171                DeviceMemory<double> *, int>
1172       impl;
1173   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1174               incy);
1175 }
1176 
ThenBlasCopy(uint64_t elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1177 Stream &Stream::ThenBlasCopy(uint64_t elem_count,
1178                              const DeviceMemory<std::complex<float>> &x,
1179                              int incx, DeviceMemory<std::complex<float>> *y,
1180                              int incy) {
1181   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1182 
1183   ThenBlasImpl<uint64_t, const DeviceMemory<std::complex<float>> &, int,
1184                DeviceMemory<std::complex<float>> *, int>
1185       impl;
1186   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1187               incy);
1188 }
1189 
ThenBlasCopy(uint64_t elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1190 Stream &Stream::ThenBlasCopy(uint64_t elem_count,
1191                              const DeviceMemory<std::complex<double>> &x,
1192                              int incx, DeviceMemory<std::complex<double>> *y,
1193                              int incy) {
1194   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1195 
1196   ThenBlasImpl<uint64_t, const DeviceMemory<std::complex<double>> &, int,
1197                DeviceMemory<std::complex<double>> *, int>
1198       impl;
1199   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1200               incy);
1201 }
1202 
ThenBlasScal(uint64_t elem_count,float alpha,DeviceMemory<float> * x,int incx)1203 Stream &Stream::ThenBlasScal(uint64_t elem_count, float alpha,
1204                              DeviceMemory<float> *x, int incx) {
1205   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1206 
1207   ThenBlasImpl<uint64_t, float, DeviceMemory<float> *, int> impl;
1208   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1209 }
1210 
ThenBlasScal(uint64_t elem_count,double alpha,DeviceMemory<double> * x,int incx)1211 Stream &Stream::ThenBlasScal(uint64_t elem_count, double alpha,
1212                              DeviceMemory<double> *x, int incx) {
1213   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1214 
1215   ThenBlasImpl<uint64_t, double, DeviceMemory<double> *, int> impl;
1216   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1217 }
1218 
ThenBlasScal(uint64_t elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)1219 Stream &Stream::ThenBlasScal(uint64_t elem_count, float alpha,
1220                              DeviceMemory<std::complex<float>> *x, int incx) {
1221   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1222 
1223   ThenBlasImpl<uint64_t, float, DeviceMemory<std::complex<float>> *, int> impl;
1224   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1225 }
1226 
ThenBlasScal(uint64_t elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)1227 Stream &Stream::ThenBlasScal(uint64_t elem_count, double alpha,
1228                              DeviceMemory<std::complex<double>> *x, int incx) {
1229   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1230 
1231   ThenBlasImpl<uint64_t, double, DeviceMemory<std::complex<double>> *, int>
1232       impl;
1233   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1234 }
1235 
ThenBlasScal(uint64_t elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)1236 Stream &Stream::ThenBlasScal(uint64_t elem_count, std::complex<float> alpha,
1237                              DeviceMemory<std::complex<float>> *x, int incx) {
1238   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1239 
1240   ThenBlasImpl<uint64_t, std::complex<float>,
1241                DeviceMemory<std::complex<float>> *, int>
1242       impl;
1243   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1244 }
1245 
ThenBlasScal(uint64_t elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)1246 Stream &Stream::ThenBlasScal(uint64_t elem_count, std::complex<double> alpha,
1247                              DeviceMemory<std::complex<double>> *x, int incx) {
1248   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1249 
1250   ThenBlasImpl<uint64_t, std::complex<double>,
1251                DeviceMemory<std::complex<double>> *, int>
1252       impl;
1253   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1254 }
1255 
ThenBlasGemv(blas::Transpose trans,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1256 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
1257                              float alpha, const DeviceMemory<float> &a, int lda,
1258                              const DeviceMemory<float> &x, int incx, float beta,
1259                              DeviceMemory<float> *y, int incy) {
1260   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1261             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1262             PARAM(incy));
1263 
1264   ThenBlasImpl<blas::Transpose, uint64_t, uint64_t, float,
1265                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
1266                int, float, DeviceMemory<float> *, int>
1267       impl;
1268   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1269               x, incx, beta, y, incy);
1270 }
1271 
ThenBlasGemv(blas::Transpose trans,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1272 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
1273                              double alpha, const DeviceMemory<double> &a,
1274                              int lda, const DeviceMemory<double> &x, int incx,
1275                              double beta, DeviceMemory<double> *y, int incy) {
1276   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1277             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1278             PARAM(incy));
1279 
1280   ThenBlasImpl<blas::Transpose, uint64_t, uint64_t, double,
1281                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
1282                int, double, DeviceMemory<double> *, int>
1283       impl;
1284   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1285               x, incx, beta, y, incy);
1286 }
1287 
ThenBlasGemv(blas::Transpose trans,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1288 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
1289                              std::complex<float> alpha,
1290                              const DeviceMemory<std::complex<float>> &a,
1291                              int lda,
1292                              const DeviceMemory<std::complex<float>> &x,
1293                              int incx, std::complex<float> beta,
1294                              DeviceMemory<std::complex<float>> *y, int incy) {
1295   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1296             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1297             PARAM(incy));
1298 
1299   ThenBlasImpl<blas::Transpose, uint64_t, uint64_t, std::complex<float>,
1300                const DeviceMemory<std::complex<float>> &, int,
1301                const DeviceMemory<std::complex<float>> &, int,
1302                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
1303       impl;
1304   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1305               x, incx, beta, y, incy);
1306 }
1307 
ThenBlasGemv(blas::Transpose trans,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1308 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
1309                              std::complex<double> alpha,
1310                              const DeviceMemory<std::complex<double>> &a,
1311                              int lda,
1312                              const DeviceMemory<std::complex<double>> &x,
1313                              int incx, std::complex<double> beta,
1314                              DeviceMemory<std::complex<double>> *y, int incy) {
1315   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1316             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1317             PARAM(incy));
1318 
1319   ThenBlasImpl<blas::Transpose, uint64_t, uint64_t, std::complex<double>,
1320                const DeviceMemory<std::complex<double>> &, int,
1321                const DeviceMemory<std::complex<double>> &, int,
1322                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
1323       impl;
1324   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1325               x, incx, beta, y, incy);
1326 }
1327 
ThenBlasSbmv(blas::UpperLower uplo,uint64_t n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1328 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k,
1329                              float alpha, const DeviceMemory<float> &a, int lda,
1330                              const DeviceMemory<float> &x, int incx, float beta,
1331                              DeviceMemory<float> *y, int incy) {
1332   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
1333             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
1334 
1335   ThenBlasImpl<blas::UpperLower, uint64_t, uint64_t, float,
1336                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
1337                int, float, DeviceMemory<float> *, int>
1338       impl;
1339   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
1340               x, incx, beta, y, incy);
1341 }
1342 
ThenBlasSbmv(blas::UpperLower uplo,uint64_t n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1343 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k,
1344                              double alpha, const DeviceMemory<double> &a,
1345                              int lda, const DeviceMemory<double> &x, int incx,
1346                              double beta, DeviceMemory<double> *y, int incy) {
1347   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
1348             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
1349 
1350   ThenBlasImpl<blas::UpperLower, uint64_t, uint64_t, double,
1351                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
1352                int, double, DeviceMemory<double> *, int>
1353       impl;
1354   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
1355               x, incx, beta, y, incy);
1356 }
1357 
1358 namespace {
1359 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
1360 // blas::ProfileResult*.  This functor doesn't put the stream into an error
1361 // state if the op fails and the profile result is non-null.  Instead, the
1362 // error-ness is returned in the profile result itself.
1363 template <typename... Args>
1364 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon0297cec10211::ThenBlasWithProfileImpl1365   Stream &operator()(Stream *stream,
1366                      bool (blas::BlasSupport::*blas_func)(
1367                          Stream *, Args..., blas::ProfileResult *),
1368                      Args... args, blas::ProfileResult *profile_result) {
1369     ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
1370     bool record_error = profile_result == nullptr;
1371     return Runner.Run(stream, blas_func, record_error, args..., profile_result);
1372   }
1373 };
1374 }  // anonymous namespace
1375 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)1376 Stream &Stream::ThenBlasGemvWithProfiling(
1377     blas::Transpose trans, uint64_t m, uint64 n, float alpha,
1378     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1379     int incx, float beta, DeviceMemory<float> *y, int incy,
1380     blas::ProfileResult *output_profile_result) {
1381   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1382             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1383             PARAM(incy));
1384 
1385   ThenBlasWithProfileImpl<
1386       blas::Transpose, uint64_t, uint64_t, float, const DeviceMemory<float> &,
1387       int, const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
1388       impl;
1389   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
1390               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
1391 }
1392 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)1393 Stream &Stream::ThenBlasGemvWithProfiling(
1394     blas::Transpose trans, uint64_t m, uint64 n, double alpha,
1395     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1396     int incx, double beta, DeviceMemory<double> *y, int incy,
1397     blas::ProfileResult *output_profile_result) {
1398   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1399             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1400             PARAM(incy));
1401 
1402   ThenBlasWithProfileImpl<blas::Transpose, uint64_t, uint64_t, double,
1403                           const DeviceMemory<double> &, int,
1404                           const DeviceMemory<double> &, int, double,
1405                           DeviceMemory<double> *, int>
1406       impl;
1407   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
1408               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
1409 }
1410 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)1411 Stream &Stream::ThenBlasGemvWithProfiling(
1412     blas::Transpose trans, uint64_t m, uint64 n, std::complex<float> alpha,
1413     const DeviceMemory<std::complex<float>> &a, int lda,
1414     const DeviceMemory<std::complex<float>> &x, int incx,
1415     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1416     blas::ProfileResult *output_profile_result) {
1417   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1418             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1419             PARAM(incy));
1420 
1421   ThenBlasWithProfileImpl<
1422       blas::Transpose, uint64_t, uint64_t, std::complex<float>,
1423       const DeviceMemory<std::complex<float>> &, int,
1424       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
1425       DeviceMemory<std::complex<float>> *, int>
1426       impl;
1427   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
1428               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
1429 }
1430 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)1431 Stream &Stream::ThenBlasGemvWithProfiling(
1432     blas::Transpose trans, uint64_t m, uint64 n, std::complex<double> alpha,
1433     const DeviceMemory<std::complex<double>> &a, int lda,
1434     const DeviceMemory<std::complex<double>> &x, int incx,
1435     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1436     blas::ProfileResult *output_profile_result) {
1437   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1438             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1439             PARAM(incy));
1440 
1441   ThenBlasWithProfileImpl<
1442       blas::Transpose, uint64_t, uint64_t, std::complex<double>,
1443       const DeviceMemory<std::complex<double>> &, int,
1444       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
1445       DeviceMemory<std::complex<double>> *, int>
1446       impl;
1447   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
1448               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
1449 }
1450 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)1451 Stream &Stream::ThenBlasGemmWithProfiling(
1452     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1453     uint64_t k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
1454     const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1455     DeviceMemory<Eigen::half> *c, int ldc,
1456     blas::ProfileResult *output_profile_result) {
1457   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1458             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1459             PARAM(beta), PARAM(c), PARAM(ldc));
1460 
1461   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t,
1462                           uint64_t, float, const DeviceMemory<Eigen::half> &,
1463                           int, const DeviceMemory<Eigen::half> &, int, float,
1464                           DeviceMemory<Eigen::half> *, int>
1465       impl;
1466   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1467               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1468               output_profile_result);
1469 }
1470 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)1471 Stream &Stream::ThenBlasGemmWithProfiling(
1472     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1473     uint64_t k, float alpha, const DeviceMemory<float> &a, int lda,
1474     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1475     int ldc, blas::ProfileResult *output_profile_result) {
1476   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1477             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1478             PARAM(beta), PARAM(c), PARAM(ldc));
1479 
1480   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t,
1481                           uint64_t, float, const DeviceMemory<float> &, int,
1482                           const DeviceMemory<float> &, int, float,
1483                           DeviceMemory<float> *, int>
1484       impl;
1485   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1486               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1487               output_profile_result);
1488 }
1489 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)1490 Stream &Stream::ThenBlasGemmWithProfiling(
1491     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1492     uint64_t k, double alpha, const DeviceMemory<double> &a, int lda,
1493     const DeviceMemory<double> &b, int ldb, double beta,
1494     DeviceMemory<double> *c, int ldc,
1495     blas::ProfileResult *output_profile_result) {
1496   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1497             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1498             PARAM(beta), PARAM(c), PARAM(ldc));
1499 
1500   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t,
1501                           uint64_t, double, const DeviceMemory<double> &, int,
1502                           const DeviceMemory<double> &, int, double,
1503                           DeviceMemory<double> *, int>
1504       impl;
1505   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1506               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1507               output_profile_result);
1508 }
1509 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)1510 Stream &Stream::ThenBlasGemmWithProfiling(
1511     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1512     uint64_t k, std::complex<float> alpha,
1513     const DeviceMemory<std::complex<float>> &a, int lda,
1514     const DeviceMemory<std::complex<float>> &b, int ldb,
1515     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1516     blas::ProfileResult *output_profile_result) {
1517   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1518             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1519             PARAM(beta), PARAM(c), PARAM(ldc));
1520 
1521   ThenBlasWithProfileImpl<
1522       blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
1523       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
1524       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
1525       DeviceMemory<std::complex<float>> *, int>
1526       impl;
1527   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1528               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1529               output_profile_result);
1530 }
1531 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)1532 Stream &Stream::ThenBlasGemmWithProfiling(
1533     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1534     uint64_t k, std::complex<double> alpha,
1535     const DeviceMemory<std::complex<double>> &a, int lda,
1536     const DeviceMemory<std::complex<double>> &b, int ldb,
1537     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1538     blas::ProfileResult *output_profile_result) {
1539   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1540             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1541             PARAM(beta), PARAM(c), PARAM(ldc));
1542 
1543   ThenBlasWithProfileImpl<
1544       blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
1545       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
1546       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
1547       DeviceMemory<std::complex<double>> *, int>
1548       impl;
1549   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1550               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1551               output_profile_result);
1552 }
1553 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)1554 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1555                              blas::Transpose transa, blas::Diagonal diag,
1556                              uint64_t m, uint64 n, float alpha,
1557                              const DeviceMemory<float> &a, int lda,
1558                              DeviceMemory<float> *b, int ldb) {
1559   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1560             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
1561 
1562   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1563                uint64_t, uint64_t, float, const DeviceMemory<float> &, int,
1564                DeviceMemory<float> *, int>
1565       impl;
1566   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
1567               n, alpha, a, lda, b, ldb);
1568 }
1569 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)1570 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1571                              blas::Transpose transa, blas::Diagonal diag,
1572                              uint64_t m, uint64 n, double alpha,
1573                              const DeviceMemory<double> &a, int lda,
1574                              DeviceMemory<double> *b, int ldb) {
1575   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1576             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
1577 
1578   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1579                uint64_t, uint64_t, double, const DeviceMemory<double> &, int,
1580                DeviceMemory<double> *, int>
1581       impl;
1582   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
1583               n, alpha, a, lda, b, ldb);
1584 }
1585 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)1586 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1587                              blas::Transpose transa, blas::Diagonal diag,
1588                              uint64_t m, uint64 n, std::complex<float> alpha,
1589                              const DeviceMemory<std::complex<float>> &a,
1590                              int lda, DeviceMemory<std::complex<float>> *b,
1591                              int ldb) {
1592   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1593             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
1594 
1595   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1596                uint64_t, uint64_t, std::complex<float>,
1597                const DeviceMemory<std::complex<float>> &, int,
1598                DeviceMemory<std::complex<float>> *, int>
1599       impl;
1600   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
1601               n, alpha, a, lda, b, ldb);
1602 }
1603 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)1604 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1605                              blas::Transpose transa, blas::Diagonal diag,
1606                              uint64_t m, uint64 n, std::complex<double> alpha,
1607                              const DeviceMemory<std::complex<double>> &a,
1608                              int lda, DeviceMemory<std::complex<double>> *b,
1609                              int ldb) {
1610   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1611             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
1612 
1613   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1614                uint64_t, uint64_t, std::complex<double>,
1615                const DeviceMemory<std::complex<double>> &, int,
1616                DeviceMemory<std::complex<double>> *, int>
1617       impl;
1618   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
1619               n, alpha, a, lda, b, ldb);
1620 }
1621 
ThenBlasTrsmBatched(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,float alpha,const DeviceMemory<float * > & as,int lda,DeviceMemory<float * > * bs,int ldb,int batch_count)1622 Stream &Stream::ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1623                                     blas::Transpose transa, blas::Diagonal diag,
1624                                     uint64_t m, uint64 n, float alpha,
1625                                     const DeviceMemory<float *> &as, int lda,
1626                                     DeviceMemory<float *> *bs, int ldb,
1627                                     int batch_count) {
1628   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1629             PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs),
1630             PARAM(ldb), PARAM(batch_count));
1631 
1632   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1633                uint64_t, uint64_t, float, const DeviceMemory<float *> &, int,
1634                DeviceMemory<float *> *, int, int>
1635       impl;
1636   return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa,
1637               diag, m, n, alpha, as, lda, bs, ldb, batch_count);
1638 }
1639 
ThenBlasTrsmBatched(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,double alpha,const DeviceMemory<double * > & as,int lda,DeviceMemory<double * > * bs,int ldb,int batch_count)1640 Stream &Stream::ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1641                                     blas::Transpose transa, blas::Diagonal diag,
1642                                     uint64_t m, uint64 n, double alpha,
1643                                     const DeviceMemory<double *> &as, int lda,
1644                                     DeviceMemory<double *> *bs, int ldb,
1645                                     int batch_count) {
1646   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1647             PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs),
1648             PARAM(ldb), PARAM(batch_count));
1649 
1650   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1651                uint64_t, uint64_t, double, const DeviceMemory<double *> &, int,
1652                DeviceMemory<double *> *, int, int>
1653       impl;
1654   return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa,
1655               diag, m, n, alpha, as, lda, bs, ldb, batch_count);
1656 }
1657 
ThenBlasTrsmBatched(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float> * > & as,int lda,DeviceMemory<std::complex<float> * > * bs,int ldb,int batch_count)1658 Stream &Stream::ThenBlasTrsmBatched(
1659     blas::Side side, blas::UpperLower uplo, blas::Transpose transa,
1660     blas::Diagonal diag, uint64_t m, uint64 n, std::complex<float> alpha,
1661     const DeviceMemory<std::complex<float> *> &as, int lda,
1662     DeviceMemory<std::complex<float> *> *bs, int ldb, int batch_count) {
1663   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1664             PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs),
1665             PARAM(ldb), PARAM(batch_count));
1666 
1667   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1668                uint64_t, uint64_t, std::complex<float>,
1669                const DeviceMemory<std::complex<float> *> &, int,
1670                DeviceMemory<std::complex<float> *> *, int, int>
1671       impl;
1672   return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa,
1673               diag, m, n, alpha, as, lda, bs, ldb, batch_count);
1674 }
1675 
ThenBlasTrsmBatched(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double> * > & as,int lda,DeviceMemory<std::complex<double> * > * bs,int ldb,int batch_count)1676 Stream &Stream::ThenBlasTrsmBatched(
1677     blas::Side side, blas::UpperLower uplo, blas::Transpose transa,
1678     blas::Diagonal diag, uint64_t m, uint64 n, std::complex<double> alpha,
1679     const DeviceMemory<std::complex<double> *> &as, int lda,
1680     DeviceMemory<std::complex<double> *> *bs, int ldb, int batch_count) {
1681   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1682             PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs),
1683             PARAM(ldb), PARAM(batch_count));
1684 
1685   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1686                uint64_t, uint64_t, std::complex<double>,
1687                const DeviceMemory<std::complex<double> *> &, int,
1688                DeviceMemory<std::complex<double> *> *, int, int>
1689       impl;
1690   return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa,
1691               diag, m, n, alpha, as, lda, bs, ldb, batch_count);
1692 }
1693 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)1694 Stream &Stream::ThenBlasGemmBatched(
1695     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1696     uint64_t k, float alpha,
1697     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a,  // non-absl ok
1698     int lda,
1699     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b,  // non-absl ok
1700     int ldb, float beta,
1701     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,  // non-absl ok
1702     int ldc, int batch_count) {
1703   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1704                                         b, ldb, beta, c, ldc, batch_count,
1705                                         /*scratch_allocator=*/nullptr);
1706 }
1707 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1708 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1709     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1710     uint64_t k, float alpha,
1711     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a,  // non-absl ok
1712     int lda,
1713     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b,  // non-absl ok
1714     int ldb, float beta,
1715     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,  // non-absl ok
1716     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1717   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1718             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1719             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1720 
1721   ThenBlasImpl<
1722       blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, float,
1723       const port::ArraySlice<DeviceMemory<Eigen::half> *> &,  // non-absl ok
1724       int,
1725       const port::ArraySlice<DeviceMemory<Eigen::half> *> &,  // non-absl ok
1726       int, float,
1727       const port::ArraySlice<DeviceMemory<Eigen::half> *> &,  // non-absl ok
1728       int, int, ScratchAllocator *>
1729       impl;
1730   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1731               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1732               scratch_allocator);
1733 }
1734 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)1735 Stream &Stream::ThenBlasGemmBatched(
1736     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1737     uint64_t k, float alpha,
1738     const port::ArraySlice<DeviceMemory<float> *> &a,           // non-absl ok
1739     int lda, const port::ArraySlice<DeviceMemory<float> *> &b,  // non-absl ok
1740     int ldb, float beta,
1741     const port::ArraySlice<DeviceMemory<float> *> &c,  // non-absl ok
1742     int ldc, int batch_count) {
1743   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1744                                         b, ldb, beta, c, ldc, batch_count,
1745                                         /*scratch_allocator=*/nullptr);
1746 }
1747 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1748 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1749     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1750     uint64_t k, float alpha,
1751     const port::ArraySlice<DeviceMemory<float> *> &a,           // non-absl ok
1752     int lda, const port::ArraySlice<DeviceMemory<float> *> &b,  // non-absl ok
1753     int ldb, float beta,
1754     const port::ArraySlice<DeviceMemory<float> *> &c,  // non-absl ok
1755     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1756   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1757             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1758             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1759 
1760   ThenBlasImpl<
1761       blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, float,
1762       const port::ArraySlice<DeviceMemory<float> *> &, int,    // non-absl ok
1763       const port::ArraySlice<DeviceMemory<float> *> &, int,    // non-absl ok
1764       float, const port::ArraySlice<DeviceMemory<float> *> &,  // non-absl ok
1765       int, int, ScratchAllocator *>
1766       impl;
1767   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1768               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1769               scratch_allocator);
1770 }
1771 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)1772 Stream &Stream::ThenBlasGemmBatched(
1773     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1774     uint64_t k, double alpha,
1775     const port::ArraySlice<DeviceMemory<double> *> &a,           // non-absl ok
1776     int lda, const port::ArraySlice<DeviceMemory<double> *> &b,  // non-absl ok
1777     int ldb, double beta,
1778     const port::ArraySlice<DeviceMemory<double> *> &c,  // non-absl ok
1779     int ldc, int batch_count) {
1780   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1781                                         b, ldb, beta, c, ldc, batch_count,
1782                                         /*scratch_allocator=*/nullptr);
1783 }
1784 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1785 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1786     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1787     uint64_t k, double alpha,
1788     const port::ArraySlice<DeviceMemory<double> *> &a,           // non-absl ok
1789     int lda, const port::ArraySlice<DeviceMemory<double> *> &b,  // non-absl ok
1790     int ldb, double beta,
1791     const port::ArraySlice<DeviceMemory<double> *> &c,  // non-absl ok
1792     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1793   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1794             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1795             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1796 
1797   ThenBlasImpl<
1798       blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, double,
1799       const port::ArraySlice<DeviceMemory<double> *> &,       // non-absl ok
1800       int, const port::ArraySlice<DeviceMemory<double> *> &,  // non-absl ok
1801       int, double,
1802       const port::ArraySlice<DeviceMemory<double> *> &,  // non-absl ok
1803       int, int, ScratchAllocator *>
1804       impl;
1805   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1806               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1807               scratch_allocator);
1808 }
1809 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)1810 Stream &Stream::ThenBlasGemmBatched(
1811     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1812     uint64_t k, std::complex<float> alpha,
1813     const port::ArraySlice<DeviceMemory<std::complex<float>> *>  // non-absl ok
1814         &a,
1815     int lda,
1816     const port::ArraySlice<DeviceMemory<std::complex<float>> *>  // non-absl ok
1817         &b,
1818     int ldb, std::complex<float> beta,
1819     const port::ArraySlice<DeviceMemory<std::complex<float>> *>  // non-absl ok
1820         &c,
1821     int ldc, int batch_count) {
1822   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1823                                         b, ldb, beta, c, ldc, batch_count,
1824                                         /*scratch_allocator=*/nullptr);
1825 }
1826 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1827 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1828     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1829     uint64_t k, std::complex<float> alpha,
1830     const port::ArraySlice<DeviceMemory<std::complex<float>> *>  // non-absl ok
1831         &a,
1832     int lda,
1833     const port::ArraySlice<DeviceMemory<std::complex<float>> *>  // non-absl ok
1834         &b,
1835     int ldb, std::complex<float> beta,
1836     const port::ArraySlice<DeviceMemory<std::complex<float>> *>  // non-absl ok
1837         &c,
1838     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1839   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1840             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1841             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1842 
1843   ThenBlasImpl<
1844       blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
1845       std::complex<float>, const DeviceMemorySlice<std::complex<float>> &, int,
1846       const DeviceMemorySlice<std::complex<float>> &, int, std::complex<float>,
1847       const DeviceMemorySlice<std::complex<float>> &, int, int,
1848       ScratchAllocator *>
1849       impl;
1850   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1851               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1852               scratch_allocator);
1853 }
1854 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<double> alpha,const DeviceMemorySlice<std::complex<double>> & a,int lda,const DeviceMemorySlice<std::complex<double>> & b,int ldb,std::complex<double> beta,const DeviceMemorySlice<std::complex<double>> & c,int ldc,int batch_count)1855 Stream &Stream::ThenBlasGemmBatched(
1856     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1857     uint64_t k, std::complex<double> alpha,
1858     const DeviceMemorySlice<std::complex<double>> &a, int lda,
1859     const DeviceMemorySlice<std::complex<double>> &b, int ldb,
1860     std::complex<double> beta, const DeviceMemorySlice<std::complex<double>> &c,
1861     int ldc, int batch_count) {
1862   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1863                                         b, ldb, beta, c, ldc, batch_count,
1864                                         /*scratch_allocator=*/nullptr);
1865 }
1866 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<double> alpha,const DeviceMemorySlice<std::complex<double>> & a,int lda,const DeviceMemorySlice<std::complex<double>> & b,int ldb,std::complex<double> beta,const DeviceMemorySlice<std::complex<double>> & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1867 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1868     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1869     uint64_t k, std::complex<double> alpha,
1870     const DeviceMemorySlice<std::complex<double>> &a, int lda,
1871     const DeviceMemorySlice<std::complex<double>> &b, int ldb,
1872     std::complex<double> beta,
1873     const DeviceMemorySlice<std::complex<double>> &c,  // non-absl ok
1874     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1875   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1876             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1877             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1878 
1879   ThenBlasImpl<
1880       blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
1881       std::complex<double>, const DeviceMemorySlice<std::complex<double>> &,
1882       int, const DeviceMemorySlice<std::complex<double>> &, int,
1883       std::complex<double>, const DeviceMemorySlice<std::complex<double>> &,
1884       int, int, ScratchAllocator *>
1885       impl;
1886   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1887               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1888               scratch_allocator);
1889 }
1890 
ThenSetRngSeed(const uint8 * seed,uint64_t seed_bytes)1891 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes) {
1892   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
1893 
1894   if (rng::RngSupport *rng = parent_->AsRng()) {
1895     CheckError(rng->SetSeed(this, seed, seed_bytes));
1896   } else {
1897     SetError();
1898     LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
1899   }
1900   return *this;
1901 }
1902 
ThenPopulateRandUniform(DeviceMemory<float> * values)1903 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
1904   VLOG_CALL(PARAM(values));
1905 
1906   if (rng::RngSupport *rng = parent_->AsRng()) {
1907     CheckError(rng->DoPopulateRandUniform(this, values));
1908   } else {
1909     SetError();
1910     LOG(INFO) << DebugStreamPointers()
1911               << " attempting to perform RNG operation using StreamExecutor"
1912                  " without RNG support.";
1913   }
1914   return *this;
1915 }
1916 
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)1917 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
1918                                          DeviceMemory<float> *values) {
1919   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
1920 
1921   if (rng::RngSupport *rng = parent_->AsRng()) {
1922     CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
1923   } else {
1924     SetError();
1925     LOG(INFO) << DebugStreamPointers()
1926               << " attempting to perform RNG operation using StreamExecutor"
1927                  " without RNG support.";
1928   }
1929   return *this;
1930 }
1931 
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)1932 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
1933                                          DeviceMemory<double> *values) {
1934   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
1935 
1936   if (rng::RngSupport *rng = parent_->AsRng()) {
1937     CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
1938   } else {
1939     SetError();
1940     LOG(INFO) << DebugStreamPointers()
1941               << " attempting to perform RNG operation using StreamExecutor"
1942                  " without RNG support.";
1943   }
1944   return *this;
1945 }
1946 
ThenPopulateRandUniform(DeviceMemory<double> * values)1947 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
1948   VLOG_CALL(PARAM(values));
1949 
1950   if (rng::RngSupport *rng = parent_->AsRng()) {
1951     CheckError(rng->DoPopulateRandUniform(this, values));
1952   } else {
1953     SetError();
1954     LOG(INFO) << DebugStreamPointers()
1955               << " attempting to perform RNG operation using StreamExecutor"
1956                  " without RNG support.";
1957   }
1958   return *this;
1959 }
1960 
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)1961 Stream &Stream::ThenPopulateRandUniform(
1962     DeviceMemory<std::complex<float>> *values) {
1963   VLOG_CALL(PARAM(values));
1964 
1965   if (rng::RngSupport *rng = parent_->AsRng()) {
1966     CheckError(rng->DoPopulateRandUniform(this, values));
1967   } else {
1968     SetError();
1969     LOG(INFO) << DebugStreamPointers()
1970               << " attempting to perform RNG operation using StreamExecutor"
1971                  " without RNG support.";
1972   }
1973   return *this;
1974 }
1975 
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)1976 Stream &Stream::ThenPopulateRandUniform(
1977     DeviceMemory<std::complex<double>> *values) {
1978   VLOG_CALL(PARAM(values));
1979 
1980   if (rng::RngSupport *rng = parent_->AsRng()) {
1981     CheckError(rng->DoPopulateRandUniform(this, values));
1982   } else {
1983     SetError();
1984     LOG(INFO) << DebugStreamPointers()
1985               << " attempting to perform RNG operation using StreamExecutor"
1986                  " without RNG support.";
1987   }
1988   return *this;
1989 }
1990 
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64_t size)1991 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
1992                            uint64_t size) {
1993   VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
1994 
1995   CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
1996   return *this;
1997 }
1998 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64_t size)1999 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
2000                            uint64_t size) {
2001   VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
2002 
2003   CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
2004   return *this;
2005 }
2006 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64_t size)2007 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
2008                            const DeviceMemoryBase &gpu_src, uint64_t size) {
2009   VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
2010 
2011   CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
2012   return *this;
2013 }
2014 
ThenMemZero(DeviceMemoryBase * location,uint64_t size)2015 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64_t size) {
2016   VLOG_CALL(PARAM(location), PARAM(size));
2017 
2018   CheckStatus(parent_->MemZero(this, location, size));
2019   return *this;
2020 }
2021 
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64_t size)2022 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
2023                              uint64_t size) {
2024   VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
2025 
2026   CheckStatus(parent_->Memset32(this, location, pattern, size));
2027   return *this;
2028 }
2029 
// Enqueues the RNN forward pass described by `rnn_desc` on this stream
// (Eigen::half element type). Layouts of the sequence/state tensors come
// from the paired *_desc arguments; `seq_lengths_data` presumably holds
// per-batch sequence lengths (see the DnnSupport::DoRnnForward contract).
// When `is_training` is true the backend is given `reserve_space_allocator`,
// presumably to stash activations for the backward pass — confirm against
// the DNN implementation. Puts the stream into an error state if the call
// fails (unless profiling was requested) or if no DNN support is available.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<Eigen::half> &input_data,
    const DeviceMemory<int> &seq_lengths_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<Eigen::half> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<Eigen::half> &input_c_data,
    const DeviceMemory<Eigen::half> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<Eigen::half> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<Eigen::half> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<Eigen::half> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
        input_h_data, input_c_desc, input_c_data, params, output_desc,
        output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
        is_training, reserve_space_allocator, workspace_allocator,
        output_profile_result);
    // A failure is only recorded as a stream error when the caller did not
    // ask for profiling; with a profile result requested, a failed attempt is
    // tolerated (presumably so autotuning can probe algorithms that fail).
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
2065 
// Enqueues the RNN forward pass described by `rnn_desc` on this stream
// (float element type). Mirrors the Eigen::half overload: tensor layouts
// come from the paired *_desc arguments, and `reserve_space_allocator` is
// only meaningful when `is_training` is true (presumably for saving
// activations — confirm against the DNN implementation). Puts the stream
// into an error state if the call fails (unless profiling was requested) or
// if no DNN support is available.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<float> &input_data,
    const DeviceMemory<int> &seq_lengths_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<float> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<float> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<float> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<float> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
        input_h_data, input_c_desc, input_c_data, params, output_desc,
        output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
        is_training, reserve_space_allocator, workspace_allocator,
        output_profile_result);
    // Failures are tolerated when a profile result was requested (autotuning
    // may probe failing algorithms); otherwise the stream is marked in error.
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
2100 
// Enqueues the RNN forward pass described by `rnn_desc` on this stream
// (double element type). Mirrors the float/Eigen::half overloads: tensor
// layouts come from the paired *_desc arguments, and
// `reserve_space_allocator` is only meaningful when `is_training` is true
// (presumably for saving activations — confirm against the DNN
// implementation). Puts the stream into an error state if the call fails
// (unless profiling was requested) or if no DNN support is available.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<double> &input_data,
    const DeviceMemory<int> &seq_lengths_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<double> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<double> &input_c_data,
    const DeviceMemory<double> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<double> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<double> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<double> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
        input_h_data, input_c_desc, input_c_data, params, output_desc,
        output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
        is_training, reserve_space_allocator, workspace_allocator,
        output_profile_result);
    // Failures are tolerated when a profile result was requested (autotuning
    // may probe failing algorithms); otherwise the stream is marked in error.
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
2136 
// Enqueues the RNN backward pass for the network described by `rnn_desc`
// (Eigen::half element type). Takes the forward pass's inputs/outputs plus
// the incoming gradients (`output_*_backprop_data`) and produces gradients
// w.r.t. the inputs, initial h/c states, and `params`.
// `reserve_space_data` is presumably the activation buffer saved by the
// training forward pass — confirm against the DNN implementation.
// Puts the stream into an error state if the call fails (unless profiling
// was requested) or if no DNN support is available.
Stream &Stream::ThenRnnBackward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<Eigen::half> &input_data,
    const DeviceMemory<int> &seq_lengths_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<Eigen::half> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<Eigen::half> &input_c_data,
    const DeviceMemory<Eigen::half> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    const DeviceMemory<Eigen::half> &output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    const DeviceMemory<Eigen::half> &output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    const DeviceMemory<Eigen::half> &output_c_data,
    const DeviceMemory<Eigen::half> &output_backprop_data,
    const DeviceMemory<Eigen::half> &output_h_backprop_data,
    const DeviceMemory<Eigen::half> &output_c_backprop_data,
    DeviceMemory<Eigen::half> *input_backprop_data,
    DeviceMemory<Eigen::half> *input_h_backprop_data,
    DeviceMemory<Eigen::half> *input_c_backprop_data,
    DeviceMemory<Eigen::half> *params_backprop_data,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnBackward(
        this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
        input_h_data, input_c_desc, input_c_data, params, output_desc,
        output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
        output_backprop_data, output_h_backprop_data, output_c_backprop_data,
        input_backprop_data, input_h_backprop_data, input_c_backprop_data,
        params_backprop_data, reserve_space_data, workspace_allocator,
        output_profile_result);
    // Failures are tolerated when a profile result was requested (autotuning
    // may probe failing algorithms); otherwise the stream is marked in error.
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    // NOTE(review): inlined error handling here, while ThenRnnForward uses
    // SetErrorAndLogNoDnnSupport() — consider unifying (changes log text).
    SetError();
    LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
  }
  return *this;
}
2182 
// Enqueues the RNN backward pass for the network described by `rnn_desc`
// (float element type). Mirrors the Eigen::half overload: forward-pass
// tensors plus incoming gradients (`output_*_backprop_data`) in, gradients
// w.r.t. inputs, initial h/c states, and `params` out.
// `reserve_space_data` is presumably the activation buffer saved by the
// training forward pass — confirm against the DNN implementation.
// Puts the stream into an error state if the call fails (unless profiling
// was requested) or if no DNN support is available.
Stream &Stream::ThenRnnBackward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<float> &input_data,
    const DeviceMemory<int> &seq_lengths_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<float> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    const DeviceMemory<float> &output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    const DeviceMemory<float> &output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    const DeviceMemory<float> &output_c_data,
    const DeviceMemory<float> &output_backprop_data,
    const DeviceMemory<float> &output_h_backprop_data,
    const DeviceMemory<float> &output_c_backprop_data,
    DeviceMemory<float> *input_backprop_data,
    DeviceMemory<float> *input_h_backprop_data,
    DeviceMemory<float> *input_c_backprop_data,
    DeviceMemory<float> *params_backprop_data,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnBackward(
        this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
        input_h_data, input_c_desc, input_c_data, params, output_desc,
        output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
        output_backprop_data, output_h_backprop_data, output_c_backprop_data,
        input_backprop_data, input_h_backprop_data, input_c_backprop_data,
        params_backprop_data, reserve_space_data, workspace_allocator,
        output_profile_result);
    // Failures are tolerated when a profile result was requested (autotuning
    // may probe failing algorithms); otherwise the stream is marked in error.
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    // NOTE(review): inlined error handling here, while ThenRnnForward uses
    // SetErrorAndLogNoDnnSupport() — consider unifying (changes log text).
    SetError();
    LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
  }
  return *this;
}
2227 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2228 Stream &Stream::ThenRnnBackward(
2229     const dnn::RnnDescriptor &rnn_desc,
2230     const dnn::RnnSequenceTensorDescriptor &input_desc,
2231     const DeviceMemory<double> &input_data,
2232     const DeviceMemory<int> &seq_lengths_data,
2233     const dnn::RnnStateTensorDescriptor &input_h_desc,
2234     const DeviceMemory<double> &input_h_data,
2235     const dnn::RnnStateTensorDescriptor &input_c_desc,
2236     const DeviceMemory<double> &input_c_data,
2237     const DeviceMemory<double> &params,
2238     const dnn::RnnSequenceTensorDescriptor &output_desc,
2239     const DeviceMemory<double> &output_data,
2240     const dnn::RnnStateTensorDescriptor &output_h_desc,
2241     const DeviceMemory<double> &output_h_data,
2242     const dnn::RnnStateTensorDescriptor &output_c_desc,
2243     const DeviceMemory<double> &output_c_data,
2244     const DeviceMemory<double> &output_backprop_data,
2245     const DeviceMemory<double> &output_h_backprop_data,
2246     const DeviceMemory<double> &output_c_backprop_data,
2247     DeviceMemory<double> *input_backprop_data,
2248     DeviceMemory<double> *input_h_backprop_data,
2249     DeviceMemory<double> *input_c_backprop_data,
2250     DeviceMemory<double> *params_backprop_data,
2251     DeviceMemory<uint8> *reserve_space_data,
2252     ScratchAllocator *workspace_allocator,
2253     dnn::ProfileResult *output_profile_result) {
2254   // TODO(zhengxq): add VLOG PARAM calls.
2255   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
2256     auto status = dnn->DoRnnBackward(
2257         this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
2258         input_h_data, input_c_desc, input_c_data, params, output_desc,
2259         output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
2260         output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2261         input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2262         params_backprop_data, reserve_space_data, workspace_allocator,
2263         output_profile_result);
2264     if (!status && !output_profile_result) {
2265       SetError();
2266     }
2267   } else {
2268     SetError();
2269     LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
2270   }
2271   return *this;
2272 }
2273 
ThenCtcLoss(const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemory<float> & probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemory<float> * costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemory<float> * grads_data,ScratchAllocator * workspace_allocator)2274 Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
2275                             const DeviceMemory<float> &probs_data,
2276                             absl::Span<const int> labels_data,
2277                             absl::Span<const int> labels_lengths_data,
2278                             absl::Span<const int> input_lengths_data,
2279                             DeviceMemory<float> *costs_data,
2280                             const dnn::RnnStateTensorDescriptor &grads_desc,
2281                             DeviceMemory<float> *grads_data,
2282                             ScratchAllocator *workspace_allocator) {
2283   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
2284     DeviceMemory<uint8> scratch_memory;
2285     int ctc_loss_algo_id;
2286     auto status =
2287         dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
2288                                labels_data, labels_lengths_data,
2289                                input_lengths_data, workspace_allocator,
2290                                &scratch_memory, &ctc_loss_algo_id)
2291             .ok();
2292     if (status) {
2293       status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
2294                               labels_lengths_data, input_lengths_data,
2295                               costs_data, grads_desc, grads_data,
2296                               &scratch_memory, ctc_loss_algo_id);
2297     }
2298     if (!status) {
2299       SetError();
2300     }
2301   } else {
2302     SetErrorAndLogNoDnnSupport();
2303   }
2304   return *this;
2305 }
2306 
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)2307 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
2308                                     dnn::DataType input_type,
2309                                     const DeviceMemoryBase &input_data,
2310                                     const dnn::BatchDescriptor &output_desc,
2311                                     dnn::DataType output_type, float scale,
2312                                     DeviceMemoryBase *output_data) {
2313   VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
2314             PARAM(output_desc), PARAM(output_type), PARAM(scale),
2315             PARAM(output_data));
2316   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
2317     CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data,
2318                                       output_desc, output_type, scale,
2319                                       output_data));
2320   } else {
2321     SetErrorAndLogNoDnnSupport();
2322   }
2323   return *this;
2324 }
2325 
ThenDoHostCallback(std::function<void ()> callback)2326 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
2327   VLOG_CALL(PARAM(callback));
2328 
2329   if (!ok()) {
2330     LOG(INFO) << DebugStreamPointers()
2331               << " was in error state before adding host callback";
2332   }
2333   CheckError(parent_->HostCallback(this, std::move(callback)));
2334   return *this;
2335 }
2336 
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)2337 Stream &Stream::ThenDoHostCallbackWithStatus(
2338     std::function<port::Status()> callback) {
2339   VLOG_CALL(PARAM(callback));
2340 
2341   if (!ok()) {
2342     LOG(INFO) << DebugStreamPointers()
2343               << " was in error state before adding host callback";
2344   }
2345   CheckError(parent_->HostCallback(this, std::move(callback)));
2346   return *this;
2347 }
2348 
ThenRunAfterNextBlockHostUntilDone(std::function<void ()> callback)2349 Stream &Stream::ThenRunAfterNextBlockHostUntilDone(
2350     std::function<void()> callback) {
2351   VLOG_CALL(PARAM(callback));
2352 
2353   if (!ok()) {
2354     LOG(INFO) << DebugStreamPointers()
2355               << " was in error state before adding callback to be run after "
2356                  "next block-host-until-done.";
2357   }
2358   absl::MutexLock lock(&mu_);
2359   after_block_host_until_done_callbacks_.push_back(std::move(callback));
2360   return *this;
2361 }
2362 
CheckError(bool operation_retcode)2363 void Stream::CheckError(bool operation_retcode) {
2364   if (operation_retcode) {
2365     return;
2366   }
2367   absl::MutexLock lock(&mu_);
2368   status_ = port::InternalError("Unknown error");
2369 }
2370 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)2371 Stream &Stream::ThenFft(fft::Plan *plan,
2372                         const DeviceMemory<std::complex<float>> &input,
2373                         DeviceMemory<std::complex<float>> *output) {
2374   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2375 
2376   if (fft::FftSupport *fft = parent_->AsFft()) {
2377     CheckError(fft->DoFft(this, plan, input, output));
2378   } else {
2379     SetError();
2380     LOG(INFO) << DebugStreamPointers()
2381               << " attempting to perform FFT operation using StreamExecutor"
2382                  " without FFT support";
2383   }
2384   return *this;
2385 }
2386 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)2387 Stream &Stream::ThenFft(fft::Plan *plan,
2388                         const DeviceMemory<std::complex<double>> &input,
2389                         DeviceMemory<std::complex<double>> *output) {
2390   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2391 
2392   if (fft::FftSupport *fft = parent_->AsFft()) {
2393     CheckError(fft->DoFft(this, plan, input, output));
2394   } else {
2395     SetError();
2396     LOG(INFO) << DebugStreamPointers()
2397               << " attempting to perform FFT operation using StreamExecutor"
2398                  " without FFT support";
2399   }
2400   return *this;
2401 }
2402 
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)2403 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
2404                         DeviceMemory<std::complex<float>> *output) {
2405   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2406 
2407   if (fft::FftSupport *fft = parent_->AsFft()) {
2408     CheckError(fft->DoFft(this, plan, input, output));
2409   } else {
2410     SetError();
2411     LOG(INFO) << DebugStreamPointers()
2412               << " attempting to perform FFT operation using StreamExecutor"
2413                  " without FFT support";
2414   }
2415   return *this;
2416 }
2417 
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)2418 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
2419                         DeviceMemory<std::complex<double>> *output) {
2420   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2421 
2422   if (fft::FftSupport *fft = parent_->AsFft()) {
2423     CheckError(fft->DoFft(this, plan, input, output));
2424   } else {
2425     SetError();
2426     LOG(INFO) << DebugStreamPointers()
2427               << " attempting to perform FFT operation using StreamExecutor"
2428                  " without FFT support";
2429   }
2430   return *this;
2431 }
2432 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)2433 Stream &Stream::ThenFft(fft::Plan *plan,
2434                         const DeviceMemory<std::complex<float>> &input,
2435                         DeviceMemory<float> *output) {
2436   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2437 
2438   if (fft::FftSupport *fft = parent_->AsFft()) {
2439     CheckError(fft->DoFft(this, plan, input, output));
2440   } else {
2441     SetError();
2442     LOG(INFO) << DebugStreamPointers()
2443               << " attempting to perform FFT operation using StreamExecutor"
2444                  " without FFT support";
2445   }
2446   return *this;
2447 }
2448 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)2449 Stream &Stream::ThenFft(fft::Plan *plan,
2450                         const DeviceMemory<std::complex<double>> &input,
2451                         DeviceMemory<double> *output) {
2452   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2453 
2454   if (fft::FftSupport *fft = parent_->AsFft()) {
2455     CheckError(fft->DoFft(this, plan, input, output));
2456   } else {
2457     SetError();
2458     LOG(INFO) << DebugStreamPointers()
2459               << " attempting to perform FFT operation using StreamExecutor"
2460                  " without FFT support";
2461   }
2462   return *this;
2463 }
2464 
// Inserts a host callback at the present point in the stream; when the
// stream reaches that point, the callback hands the task off to the host
// executor's background-thread pool.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)2467 Stream &Stream::ThenEnqueueOnBackgroundThread(
2468     std::function<void(StreamExecutor *)> task) {
2469   VLOG_CALL(PARAM(task));
2470 
2471   StreamExecutor *stream_executor = this->parent_;
2472   std::function<void()> bound_task = std::bind(task, stream_executor);
2473 
2474   return ThenDoHostCallback([stream_executor, bound_task]() {
2475     stream_executor->EnqueueOnBackgroundThread(bound_task);
2476   });
2477 }
2478 
BlockHostUntilDone()2479 port::Status Stream::BlockHostUntilDone() {
2480   VLOG_CALL();
2481 
2482   if (!ok()) {
2483     absl::MutexLock lock(&mu_);
2484     LOG(INFO) << status_.ToString();
2485     port::Status status = port::Status(
2486         port::error::INTERNAL,
2487         "stream did not block host until done; was already in an error state");
2488     LOG(INFO) << DebugStreamPointers() << " " << status;
2489     return status;
2490   }
2491 
2492   temporary_memory_manager_.DeallocateFinalizedTemporaries();
2493 
2494   port::Status error = parent_->BlockHostUntilDone(this);
2495   CheckError(error.ok());
2496 
2497   RunAfterBlockHostUntilDoneCallbacks();
2498   return error;
2499 }
2500 
RunAfterBlockHostUntilDoneCallbacks()2501 void Stream::RunAfterBlockHostUntilDoneCallbacks() {
2502   std::vector<std::function<void()>> callbacks;
2503   {
2504     absl::MutexLock lock(&mu_);
2505     std::swap(callbacks, after_block_host_until_done_callbacks_);
2506   }
2507   for (const auto &fn : callbacks) {
2508     fn();
2509   }
2510 }
2511 
DebugStreamPointers() const2512 std::string Stream::DebugStreamPointers() const {
2513   // Relies on the ToVlogString(const void*) overload above.
2514   return absl::StrCat("[stream=", ToVlogString(this),
2515                       ",impl=", ToVlogString(implementation_.get()), "]");
2516 }
2517 
CheckStatus(port::Status status)2518 void Stream::CheckStatus(port::Status status) {
2519   if (status.ok()) {
2520     return;
2521   }
2522   LOG(ERROR) << status;
2523   absl::MutexLock lock(&mu_);
2524   status_ = status;
2525 }
2526 
2527 }  // namespace stream_executor
2528