• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
19 
#include <chrono>
#include <limits>
#include <random>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh"
26 
27 namespace mindspore {
28 namespace kernel {
29 template <typename T, typename S>
30 class RandomChoiceWithMaskGpuKernel : public GpuKernel {
31  public:
RandomChoiceWithMaskGpuKernel()32   RandomChoiceWithMaskGpuKernel()
33       : input_shape_size_(0), seed_(0), seed2_(0), input_size_(1), count_(0), ceil_power2_(0), is_null_input_(false) {}
34   ~RandomChoiceWithMaskGpuKernel() override = default;
35 
GetInputSizeList()36   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
GetOutputSizeList()37   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
GetWorkspaceSizeList()38   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
39 
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> & workspaces,const std::vector<AddressPtr> & outputs,void * stream_ptr)40   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
41               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
42     if (is_null_input_) {
43       return true;
44     }
45     T *input = GetDeviceAddress<T>(inputs, 0);
46     S *output_index = GetDeviceAddress<S>(outputs, 0);
47     T *output_mask = GetDeviceAddress<T>(outputs, 1);
48     int seedc = 0;
49     if (seed2_ != 0) {
50       seedc = seed2_;
51     } else if (seed_ != 0) {
52       seedc = seed_;
53     } else {
54       seedc = generator_();
55     }
56     if (count_ > kSmallK || input_shape_size_ > 1) {
57       S *index_buff = GetDeviceAddress<S>(workspaces, 0);
58       S *mask_buff = GetDeviceAddress<S>(workspaces, 1);
59       S *rank_buff = GetDeviceAddress<S>(workspaces, 2);
60       S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3);
61       S *tmp_buff = GetDeviceAddress<S>(workspaces, 4);
62       void *States = GetDeviceAddress<void *>(workspaces, 5);
63       curandState *devStates = reinterpret_cast<curandState *>(States);
64       CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1],
65                               input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc, count_, input,
66                               output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff,
67                               devStates, reinterpret_cast<cudaStream_t>(stream_ptr));
68     } else {
69       CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc, count_, input, output_index, output_mask,
70                                                 reinterpret_cast<cudaStream_t>(stream_ptr));
71     }
72     return true;
73   }
74 
Init(const CNodePtr & kernel_node)75   bool Init(const CNodePtr &kernel_node) override {
76     MS_EXCEPTION_IF_NULL(kernel_node);
77     uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count();
78     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
79     if (input_num != 1) {
80       MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input.";
81       return false;
82     }
83     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
84     if (output_num != 2) {
85       MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomChoiceWithMask has 2 outputs.";
86       return false;
87     }
88     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
89     is_null_input_ = CHECK_NULL_INPUT(input_shape);
90     if (is_null_input_) {
91       MS_LOG(WARNING) << "For 'RandomChoiceWithMaskGpuKernel', input is null";
92       InitSizeLists();
93       return true;
94     }
95     input_shape_size_ = input_shape.size();
96     if (input_shape_size_ < 1 || input_shape_size_ > MAX_DIMENSION) {
97       MS_LOG(ERROR) << "Input is " << input_shape_size_
98                     << "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs.";
99       return false;
100     }
101     // convert size_t to int
102     for (auto i = 0; i < input_shape_size_; i++) {
103       input_shape_5D_.push_back(input_shape[i]);
104     }
105     // convert shape to 5D
106     while (input_shape_5D_.size() != MAX_DIMENSION) {
107       (void)input_shape_5D_.insert(input_shape_5D_.begin(), 1);
108     }
109     // init seedc
110     seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed"));
111     seed2_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2"));
112     generator_.seed(time_interval);
113     // init memory
114     for (size_t i = 0; i < input_shape.size(); i++) {
115       input_size_ *= input_shape[i];
116     }
117     count_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "count"));
118     // upper ceiling for input for ceil_power2
119     if (count_ > kSmallK || input_shape_size_ > 1) {
120       ceil_power2_ = RcwmRoundUpPower2(input_size_);
121     }
122     InitSizeLists();
123     return true;
124   }
125 
126  protected:
InitSizeLists()127   void InitSizeLists() override {
128     input_size_list_.push_back(input_size_ * sizeof(T));
129     output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S));
130     output_size_list_.push_back(count_ * sizeof(T));
131     if (count_ > kSmallK || input_shape_size_ > 1) {
132       workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S));
133       workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
134       workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
135       int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE);
136       workspace_size_list_.push_back(blocknum * sizeof(S));
137       workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
138       workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState));
139     }
140   }
141 
142  private:
143   const int kSmallK = 2048;
144   int input_shape_size_;
145   int seed_;
146   int seed2_;
147   int input_size_;
148   int count_;
149   int ceil_power2_;
150   bool is_null_input_;
151   std::mt19937 generator_;
152   std::vector<int> input_shape_5D_;
153   std::vector<size_t> input_size_list_;
154   std::vector<size_t> output_size_list_;
155   std::vector<size_t> workspace_size_list_;
156 };
157 }  // namespace kernel
158 }  // namespace mindspore
159 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
160