• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/runtime/kernel/arm/base/random_standard_normal.h"
18 #include <random>
19 #include "src/kernel_registry.h"
20 #include "include/errorcode.h"
21 #ifndef CONTROLFLOW_TENSORLIST_CLIP
22 #include "src/tensorlist.h"
23 #endif
24 
25 using mindspore::lite::KernelRegistrar;
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_OK;
28 using mindspore::schema::PrimitiveType_RandomStandardNormal;
29 
30 namespace mindspore::kernel {
Init()31 int RandomStandardNormalCPUKernel::Init() { return RET_OK; }
32 
ReSize()33 int RandomStandardNormalCPUKernel::ReSize() { return RET_OK; }
34 
Run()35 int RandomStandardNormalCPUKernel::Run() {
36   size_t random_seed = 0;
37   if (param_->seed2_ != 0) {
38     random_seed = static_cast<size_t>(param_->seed2_);
39   } else if (param_->seed_ != 0) {
40     random_seed = static_cast<size_t>(param_->seed_);
41   } else {
42     random_seed = static_cast<size_t>(clock());
43   }
44   std::default_random_engine engine{static_cast<unsigned int>(random_seed)};
45   std::normal_distribution<double> nums(0, 1.0);
46   auto all_data_nums = out_tensors_[0]->ElementsNum();
47   auto out_data = out_tensors_[0]->data();
48   MS_ASSERT(out_data != nullptr);
49   auto output = reinterpret_cast<float *>(out_data);
50 
51   std::generate_n(output, all_data_nums, [&]() { return nums(engine); });
52   return RET_OK;
53 }
54 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_RandomStandardNormal, LiteKernelCreator<RandomStandardNormalCPUKernel>)
55 }  // namespace mindspore::kernel
56