1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_SEQUENCE_GPU_KERNEL_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_SEQUENCE_GPU_KERNEL_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <iostream> 23 #include "backend/kernel_compiler/gpu/gpu_kernel.h" 24 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" 25 #include "runtime/device/gpu/cuda_common.h" 26 #include "backend/kernel_compiler/gpu/cuda_impl/reverse_sequence_impl.cuh" 27 #include "backend/kernel_compiler/gpu/kernel_constants.h" 28 29 namespace mindspore { 30 namespace kernel { 31 template <typename T, typename S> 32 class ReverseSequenceGpuFwdKernel : public GpuKernel { 33 public: ReverseSequenceGpuFwdKernel()34 ReverseSequenceGpuFwdKernel() 35 : shape_size_(0), 36 input_size_(0), 37 batch_dim_(0), 38 seq_dim_(0), 39 is_null_input_(false), 40 seq_len_size_(0), 41 total_index_dim_(0), 42 output_size_(0), 43 workspace_size_(0) {} 44 ~ReverseSequenceGpuFwdKernel() override = default; GetInputSizeList()45 const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } GetOutputSizeList()46 const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } GetWorkspaceSizeList()47 const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> & workspace,const std::vector<AddressPtr> & outputs,void * stream_ptr)48 bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, 49 const std::vector<AddressPtr> &outputs, void *stream_ptr) override { 50 if (is_null_input_) { 51 return true; 52 } 53 T *input = GetDeviceAddress<T>(inputs, 0); 54 S *seq_len = GetDeviceAddress<S>(inputs, 1); 55 size_t *input_shape_ptr = GetDeviceAddress<size_t>(workspace, 0); 56 size_t *input_cum_shape_ptr = GetDeviceAddress<size_t>(workspace, 1); 57 size_t *cur_pos_arr = GetDeviceAddress<size_t>(workspace, 2); 58 T *output = GetDeviceAddress<T>(outputs, 0); 59 CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, 60 cudaMemcpyAsync(input_shape_ptr, &input_shape_[0], input_shape_.size() * sizeof(size_t), 61 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), 62 "cudaMemcpyAsync input_shape_ failed"); 63 CalReverseSequence(input_size_, input, seq_len, batch_dim_, seq_dim_, cur_pos_arr, input_shape_ptr, 64 input_cum_shape_ptr, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr)); 65 return true; 66 } 67 Init(const CNodePtr & kernel_node)68 bool Init(const CNodePtr &kernel_node) override { 69 batch_dim_ = GetAttr<int64_t>(kernel_node, "batch_dim"); 70 seq_dim_ = GetAttr<int64_t>(kernel_node, "seq_dim"); 71 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); 72 if (input_num != 2) { 73 MS_LOG(ERROR) << "Input number is " << input_num << ", but ReverseSequence needs 2 input."; 74 return false; 75 } 76 size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); 77 if (output_num != 1) { 78 MS_LOG(ERROR) << "Output number is " << output_num << ", but ReverseSequence needs 1 output."; 79 return false; 80 } 81 input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); 82 auto seq_len_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); 83 is_null_input_ = CHECK_NULL_INPUT(input_shape_) || CHECK_NULL_INPUT(seq_len_shape); 84 if (is_null_input_) { 85 MS_LOG(WARNING) << "For 'ReverseSequenceGpuKernel', input is null."; 86 InitSizeLists(); 87 return true; 88 } 89 if (input_shape_.size() < 1) { 90 MS_LOG(EXCEPTION) << "For 'ReverseSequenceGpuKernel', the rank of input cannot be less than 1, but got " 91 << input_shape_.size(); 92 } 93 input_size_ = 1; 94 shape_size_ = input_shape_.size(); // required for calls 95 for (size_t i = 0; i < shape_size_; i++) { 96 input_size_ *= input_shape_[i]; 97 } 98 // get seq len shape 99 seq_len_size_ = seq_len_shape.size(); 100 output_size_ = input_size_; // size does not change 101 // Allocate workspace memory to use for storing indices for each thread to compute with 102 size_t total_threads = GET_BLOCKS(input_size_) * GET_THREADS; 103 total_index_dim_ = total_threads * shape_size_; 104 InitSizeLists(); 105 return true; 106 } 107 108 protected: InitSizeLists()109 void InitSizeLists() override { 110 input_size_list_.push_back(input_size_ * sizeof(T)); 111 input_size_list_.push_back(seq_len_size_ * sizeof(S)); 112 workspace_size_list_.push_back(shape_size_ * sizeof(size_t)); // input_shape 113 workspace_size_list_.push_back(shape_size_ * sizeof(size_t)); // cumulative shape 114 workspace_size_list_.push_back(total_index_dim_ * sizeof(size_t)); // scratch memory for holding indices per thread 115 output_size_list_.push_back(output_size_ * sizeof(T)); 116 } 117 118 private: 119 size_t shape_size_; 120 size_t input_size_; 121 int64_t batch_dim_; 122 int64_t seq_dim_; 123 bool is_null_input_; 124 size_t seq_len_size_; 125 size_t total_index_dim_; 126 size_t output_size_; 127 size_t workspace_size_; 128 std::vector<size_t> input_shape_; 129 std::vector<size_t> input_size_list_; 130 std::vector<size_t> output_size_list_; 131 std::vector<size_t> workspace_size_list_; 132 }; 133 } // namespace kernel 134 } // namespace mindspore 135 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_SEQUENCE_GPU_KERNEL_H_ 136