1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <cmath>
#include <cstdint>
#include <limits>

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
14
15 using exec_aten::ScalarType;
16 using exec_aten::Tensor;
17 using executorch::aten::RuntimeContext;
18 using torch::executor::Error;
19
20 namespace cadence {
21 namespace impl {
22 namespace HiFi {
23 namespace native {
24
25 using Tensor = exec_aten::Tensor;
26
sigmoid_out(RuntimeContext & ctx,const Tensor & in,Tensor & out)27 Tensor& sigmoid_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
28 (void)ctx;
29
30 ET_KERNEL_CHECK(
31 ctx, in.scalar_type() != ScalarType::Bool, InvalidArgument, out);
32 ET_KERNEL_CHECK(
33 ctx,
34 executorch::runtime::tensor_is_floating_type(out),
35 InvalidArgument,
36 out);
37
38 // Resize for dynamic shape
39 ET_KERNEL_CHECK_MSG(
40 ctx,
41 resize_tensor(out, in.sizes()) == Error::Ok,
42 InvalidArgument,
43 out,
44 "Failed to resize output tensor.");
45
46 ScalarType in_type = in.scalar_type();
47 ScalarType out_type = out.scalar_type();
48
49 bool optimized = 1;
50 if ((in_type != ScalarType::Float) || (out_type != ScalarType::Float))
51 optimized = 0;
52
53 if (optimized) {
54 float* data_in = in.mutable_data_ptr<float>();
55 float* data_out = out.mutable_data_ptr<float>();
56 xa_nn_vec_sigmoid_f32_f32(data_out, data_in, in.numel());
57
58 return out;
59 }
60
61 ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() {
62 ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() {
63 torch::executor::apply_unary_map_fn(
64 [](const CTYPE_IN val_in) {
65 // perform math in double to preserve precision
66 double in_casted = static_cast<double>(val_in);
67 double out_val = 1.0 / (1.0 + exp(-in_casted));
68 return static_cast<CTYPE_OUT>(out_val);
69 },
70 in.const_data_ptr<CTYPE_IN>(),
71 out.mutable_data_ptr<CTYPE_OUT>(),
72 in.numel());
73 });
74 });
75
76 return out;
77 }
78
79 } // namespace native
80 } // namespace HiFi
81 } // namespace impl
82 } // namespace cadence