• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 
11 #include <executorch/backends/cadence/hifi/kernels/kernels.h>
12 #include <executorch/kernels/portable/cpu/util/functional_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 
15 using exec_aten::ScalarType;
16 using exec_aten::Tensor;
17 using executorch::aten::RuntimeContext;
18 using torch::executor::Error;
19 
20 namespace cadence {
21 namespace impl {
22 namespace HiFi {
23 namespace native {
24 
25 using Tensor = exec_aten::Tensor;
26 
sigmoid_out(RuntimeContext & ctx,const Tensor & in,Tensor & out)27 Tensor& sigmoid_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
28   (void)ctx;
29 
30   ET_KERNEL_CHECK(
31       ctx, in.scalar_type() != ScalarType::Bool, InvalidArgument, out);
32   ET_KERNEL_CHECK(
33       ctx,
34       executorch::runtime::tensor_is_floating_type(out),
35       InvalidArgument,
36       out);
37 
38   // Resize for dynamic shape
39   ET_KERNEL_CHECK_MSG(
40       ctx,
41       resize_tensor(out, in.sizes()) == Error::Ok,
42       InvalidArgument,
43       out,
44       "Failed to resize output tensor.");
45 
46   ScalarType in_type = in.scalar_type();
47   ScalarType out_type = out.scalar_type();
48 
49   bool optimized = 1;
50   if ((in_type != ScalarType::Float) || (out_type != ScalarType::Float))
51     optimized = 0;
52 
53   if (optimized) {
54     float* data_in = in.mutable_data_ptr<float>();
55     float* data_out = out.mutable_data_ptr<float>();
56     xa_nn_vec_sigmoid_f32_f32(data_out, data_in, in.numel());
57 
58     return out;
59   }
60 
61   ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() {
62     ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() {
63       torch::executor::apply_unary_map_fn(
64           [](const CTYPE_IN val_in) {
65             // perform math in double to preserve precision
66             double in_casted = static_cast<double>(val_in);
67             double out_val = 1.0 / (1.0 + exp(-in_casted));
68             return static_cast<CTYPE_OUT>(out_val);
69           },
70           in.const_data_ptr<CTYPE_IN>(),
71           out.mutable_data_ptr<CTYPE_OUT>(),
72           in.numel());
73     });
74   });
75 
76   return out;
77 }
78 
79 } // namespace native
80 } // namespace HiFi
81 } // namespace impl
82 } // namespace cadence