1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ 19 20 #include <vector> 21 #include <chrono> 22 #include <random> 23 #include "backend/kernel_compiler/gpu/gpu_kernel.h" 24 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" 25 #include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" 26 27 namespace mindspore { 28 namespace kernel { 29 template <typename T, typename S> 30 class RandomChoiceWithMaskGpuKernel : public GpuKernel { 31 public: RandomChoiceWithMaskGpuKernel()32 RandomChoiceWithMaskGpuKernel() 33 : input_shape_size_(0), seed_(0), seed2_(0), input_size_(1), count_(0), ceil_power2_(0), is_null_input_(false) {} 34 ~RandomChoiceWithMaskGpuKernel() override = default; 35 GetInputSizeList()36 const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } GetOutputSizeList()37 const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } GetWorkspaceSizeList()38 const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } 39 Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> & workspaces,const std::vector<AddressPtr> & outputs,void * stream_ptr)40 bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, 41 const std::vector<AddressPtr> &outputs, void *stream_ptr) override { 42 if (is_null_input_) { 43 return true; 44 } 45 T *input = GetDeviceAddress<T>(inputs, 0); 46 S *output_index = GetDeviceAddress<S>(outputs, 0); 47 T *output_mask = GetDeviceAddress<T>(outputs, 1); 48 int seedc = 0; 49 if (seed2_ != 0) { 50 seedc = seed2_; 51 } else if (seed_ != 0) { 52 seedc = seed_; 53 } else { 54 seedc = generator_(); 55 } 56 if (count_ > kSmallK || input_shape_size_ > 1) { 57 S *index_buff = GetDeviceAddress<S>(workspaces, 0); 58 S *mask_buff = GetDeviceAddress<S>(workspaces, 1); 59 S *rank_buff = GetDeviceAddress<S>(workspaces, 2); 60 S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3); 61 S *tmp_buff = GetDeviceAddress<S>(workspaces, 4); 62 void *States = GetDeviceAddress<void *>(workspaces, 5); 63 curandState *devStates = reinterpret_cast<curandState *>(States); 64 CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], 65 input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc, count_, input, 66 output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, 67 devStates, reinterpret_cast<cudaStream_t>(stream_ptr)); 68 } else { 69 CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc, count_, input, output_index, output_mask, 70 reinterpret_cast<cudaStream_t>(stream_ptr)); 71 } 72 return true; 73 } 74 Init(const CNodePtr & kernel_node)75 bool Init(const CNodePtr &kernel_node) override { 76 MS_EXCEPTION_IF_NULL(kernel_node); 77 uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count(); 78 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); 79 if (input_num != 1) { 80 MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input."; 81 return false; 82 } 83 size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); 84 if (output_num != 2) { 85 MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomChoiceWithMask has 2 outputs."; 86 return false; 87 } 88 auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); 89 is_null_input_ = CHECK_NULL_INPUT(input_shape); 90 if (is_null_input_) { 91 MS_LOG(WARNING) << "For 'RandomChoiceWithMaskGpuKernel', input is null"; 92 InitSizeLists(); 93 return true; 94 } 95 input_shape_size_ = input_shape.size(); 96 if (input_shape_size_ < 1 || input_shape_size_ > MAX_DIMENSION) { 97 MS_LOG(ERROR) << "Input is " << input_shape_size_ 98 << "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs."; 99 return false; 100 } 101 // convert size_t to int 102 for (auto i = 0; i < input_shape_size_; i++) { 103 input_shape_5D_.push_back(input_shape[i]); 104 } 105 // convert shape to 5D 106 while (input_shape_5D_.size() != MAX_DIMENSION) { 107 (void)input_shape_5D_.insert(input_shape_5D_.begin(), 1); 108 } 109 // init seedc 110 seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); 111 seed2_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2")); 112 generator_.seed(time_interval); 113 // init memory 114 for (size_t i = 0; i < input_shape.size(); i++) { 115 input_size_ *= input_shape[i]; 116 } 117 count_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "count")); 118 // upper ceiling for input for ceil_power2 119 if (count_ > kSmallK || input_shape_size_ > 1) { 120 ceil_power2_ = RcwmRoundUpPower2(input_size_); 121 } 122 InitSizeLists(); 123 return true; 124 } 125 126 protected: InitSizeLists()127 void InitSizeLists() override { 128 input_size_list_.push_back(input_size_ * sizeof(T)); 129 output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S)); 130 output_size_list_.push_back(count_ * sizeof(T)); 131 if (count_ > kSmallK || input_shape_size_ > 1) { 132 workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S)); 133 workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); 134 workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); 135 int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE); 136 workspace_size_list_.push_back(blocknum * sizeof(S)); 137 workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); 138 workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState)); 139 } 140 } 141 142 private: 143 const int kSmallK = 2048; 144 int input_shape_size_; 145 int seed_; 146 int seed2_; 147 int input_size_; 148 int count_; 149 int ceil_power2_; 150 bool is_null_input_; 151 std::mt19937 generator_; 152 std::vector<int> input_shape_5D_; 153 std::vector<size_t> input_size_list_; 154 std::vector<size_t> output_size_list_; 155 std::vector<size_t> workspace_size_list_; 156 }; 157 } // namespace kernel 158 } // namespace mindspore 159 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ 160