• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "cpu_kernel/ms_kernel/log_normal_reverse.h"
18 
19 #include <Eigen/Core>
20 #include <ctime>
21 #include <set>
22 #include <iostream>
23 #include <random>
24 #include <algorithm>
25 
26 #include "context/inc/cpu_kernel_utils.h"
27 #include "utils/eigen_tensor.h"
28 #include "utils/kernel_util.h"
29 
namespace {
// Operator arity checked by NormalCheck: exactly one input and one output tensor.
constexpr uint32_t kNumInput = 1;
constexpr uint32_t kNumOutput = 1;

// Name under which this kernel is registered.
const char *kLogNormalReverse = "LogNormalReverse";

// Element-count thresholds used by DoCompute: below kParallelDataNumSameShape the
// output is filled serially; up to kParallelDataNumMid at most 4 cores are used.
constexpr int64_t kParallelDataNumSameShape = 16 * 1024;
constexpr int64_t kParallelDataNumMid = 128 * 1024;
}  // namespace
38 namespace aicpu {
GetInputAndCheck(CpuKernelContext & ctx)39 uint32_t LogNormalReverseCpuKernel::GetInputAndCheck(CpuKernelContext &ctx) {
40   CUST_KERNEL_HANDLE_ERROR(ctx, NormalCheck(ctx, kNumInput, kNumOutput),
41                            "LogNormalReverse check input and output failed.");
42   // get and check input
43   Tensor *input = ctx.Input(0);
44   inputs_.push_back(input);
45 
46   // get output Tensors
47   Tensor *output = ctx.Output(0);
48   outputs_.push_back(output);
49 
50   return KERNEL_STATUS_OK;
51 }
52 
53 template <typename T>
DoCompute(CpuKernelContext & ctx)54 uint32_t LogNormalReverseCpuKernel::DoCompute(CpuKernelContext &ctx) {
55   float input_mean = 1.0;
56   float input_std = 2.0;
57 
58   auto mean_value = ctx.GetAttr("mean");
59   auto std_value = ctx.GetAttr("std");
60 
61   if (mean_value != nullptr) {
62     input_mean = mean_value->GetFloat();
63   }
64   if (std_value != nullptr) {
65     input_std = std_value->GetFloat();
66   }
67 
68   T *output_y = reinterpret_cast<T *>(outputs_[0]->GetData());
69 
70   static std::default_random_engine random_engine(time(0));
71   static std::normal_distribution<float> normal_value(input_mean, input_std);
72 
73   int64_t Nums = inputs_[0]->GetTensorShape()->NumElements();
74 
75   int64_t data_num = Nums;
76   if (data_num >= kParallelDataNumSameShape) {
77     uint32_t max_core_num = std::max(1U, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
78 
79     if (data_num <= kParallelDataNumMid) {
80       max_core_num = std::min(max_core_num, 4U);
81     }
82     if (max_core_num > data_num) {
83       max_core_num = data_num;
84     }
85 
86     auto shared_lognormalreverse = [&](size_t start, size_t end) {
87       for (size_t i = start; i < end; i++) {
88         output_y[i] = static_cast<T>(std::exp(normal_value(random_engine)));
89       }
90     };
91 
92     if (max_core_num == 0) {
93       max_core_num = 1;
94     }
95     CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shared_lognormalreverse);
96   } else {
97     for (int64_t i = 0; i < Nums; i++) {
98       output_y[i] = static_cast<T>(std::exp(normal_value(random_engine)));
99     }
100   }
101   return KERNEL_STATUS_OK;
102 }
103 
Compute(CpuKernelContext & ctx)104 uint32_t LogNormalReverseCpuKernel::Compute(CpuKernelContext &ctx) {
105   uint32_t res = GetInputAndCheck(ctx);
106   if (res != KERNEL_STATUS_OK) {
107     return res;
108   }
109 
110   DataType input_type{ctx.Input(0)->GetDataType()};
111   switch (input_type) {
112     case (DT_FLOAT16): {
113       DoCompute<Eigen::half>(ctx);
114       break;
115     }
116     case (DT_FLOAT): {
117       DoCompute<float>(ctx);
118       break;
119     }
120     default:
121       CUST_KERNEL_LOG_ERROR(ctx, "[%s] Data type of input is not support, input data type is [%s].",
122                             ctx.GetOpType().c_str(), DTypeStr(input_type).c_str());
123       res = KERNEL_STATUS_PARAM_INVALID;
124   }
125   if (res != KERNEL_STATUS_OK) {
126     CUST_KERNEL_LOG_ERROR(ctx, "log normal reverse failed");
127     return res;
128   }
129   return KERNEL_STATUS_OK;
130 }
131 REGISTER_MS_CPU_KERNEL(kLogNormalReverse, LogNormalReverseCpuKernel);
132 }  // namespace aicpu
133