1 /**
2 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "cpu_kernel/ms_kernel/log_normal_reverse.h"
18
#include <algorithm>
#include <cmath>
#include <ctime>
#include <iostream>
#include <mutex>
#include <random>
#include <set>

#include <Eigen/Core>
25
26 #include "context/inc/cpu_kernel_utils.h"
27 #include "utils/eigen_tensor.h"
28 #include "utils/kernel_util.h"
29
namespace {
// Expected tensor counts, passed to NormalCheck() in GetInputAndCheck().
constexpr uint32_t kNumInput = 1;
constexpr uint32_t kNumOutput = 1;

// Op type name used when registering this kernel.
constexpr const char *kLogNormalReverse = "LogNormalReverse";

// Element count at or above which DoCompute() splits the work with ParallelFor.
constexpr int64_t kParallelDataNumSameShape = 16 * 1024;
// Up to this element count, a parallel run is capped at 4 cores.
constexpr int64_t kParallelDataNumMid = 128 * 1024;
}  // namespace
38 namespace aicpu {
GetInputAndCheck(CpuKernelContext & ctx)39 uint32_t LogNormalReverseCpuKernel::GetInputAndCheck(CpuKernelContext &ctx) {
40 CUST_KERNEL_HANDLE_ERROR(ctx, NormalCheck(ctx, kNumInput, kNumOutput),
41 "LogNormalReverse check input and output failed.");
42 // get and check input
43 Tensor *input = ctx.Input(0);
44 inputs_.push_back(input);
45
46 // get output Tensors
47 Tensor *output = ctx.Output(0);
48 outputs_.push_back(output);
49
50 return KERNEL_STATUS_OK;
51 }
52
53 template <typename T>
DoCompute(CpuKernelContext & ctx)54 uint32_t LogNormalReverseCpuKernel::DoCompute(CpuKernelContext &ctx) {
55 float input_mean = 1.0;
56 float input_std = 2.0;
57
58 auto mean_value = ctx.GetAttr("mean");
59 auto std_value = ctx.GetAttr("std");
60
61 if (mean_value != nullptr) {
62 input_mean = mean_value->GetFloat();
63 }
64 if (std_value != nullptr) {
65 input_std = std_value->GetFloat();
66 }
67
68 T *output_y = reinterpret_cast<T *>(outputs_[0]->GetData());
69
70 static std::default_random_engine random_engine(time(0));
71 static std::normal_distribution<float> normal_value(input_mean, input_std);
72
73 int64_t Nums = inputs_[0]->GetTensorShape()->NumElements();
74
75 int64_t data_num = Nums;
76 if (data_num >= kParallelDataNumSameShape) {
77 uint32_t max_core_num = std::max(1U, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
78
79 if (data_num <= kParallelDataNumMid) {
80 max_core_num = std::min(max_core_num, 4U);
81 }
82 if (max_core_num > data_num) {
83 max_core_num = data_num;
84 }
85
86 auto shared_lognormalreverse = [&](size_t start, size_t end) {
87 for (size_t i = start; i < end; i++) {
88 output_y[i] = static_cast<T>(std::exp(normal_value(random_engine)));
89 }
90 };
91
92 if (max_core_num == 0) {
93 max_core_num = 1;
94 }
95 CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shared_lognormalreverse);
96 } else {
97 for (int64_t i = 0; i < Nums; i++) {
98 output_y[i] = static_cast<T>(std::exp(normal_value(random_engine)));
99 }
100 }
101 return KERNEL_STATUS_OK;
102 }
103
Compute(CpuKernelContext & ctx)104 uint32_t LogNormalReverseCpuKernel::Compute(CpuKernelContext &ctx) {
105 uint32_t res = GetInputAndCheck(ctx);
106 if (res != KERNEL_STATUS_OK) {
107 return res;
108 }
109
110 DataType input_type{ctx.Input(0)->GetDataType()};
111 switch (input_type) {
112 case (DT_FLOAT16): {
113 DoCompute<Eigen::half>(ctx);
114 break;
115 }
116 case (DT_FLOAT): {
117 DoCompute<float>(ctx);
118 break;
119 }
120 default:
121 CUST_KERNEL_LOG_ERROR(ctx, "[%s] Data type of input is not support, input data type is [%s].",
122 ctx.GetOpType().c_str(), DTypeStr(input_type).c_str());
123 res = KERNEL_STATUS_PARAM_INVALID;
124 }
125 if (res != KERNEL_STATUS_OK) {
126 CUST_KERNEL_LOG_ERROR(ctx, "log normal reverse failed");
127 return res;
128 }
129 return KERNEL_STATUS_OK;
130 }
// Associates LogNormalReverseCpuKernel with the op type name kLogNormalReverse
// ("LogNormalReverse") via the framework macro — presumably so the AICPU runtime
// instantiates this class for that op type.
REGISTER_MS_CPU_KERNEL(kLogNormalReverse, LogNormalReverseCpuKernel);
132 } // namespace aicpu
133