• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10 #include <executorch/kernels/portable/cpu/util/math_util.h>
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 #include <executorch/runtime/platform/assert.h>
13 #include <cmath>
14 #include <type_traits>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
floor_divide_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)20 Tensor& floor_divide_out(
21     KernelRuntimeContext& ctx,
22     const Tensor& a,
23     const Tensor& b,
24     Tensor& out) {
25   // Common Dtype
26   ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27 
28   // Check Common Dtype
29   ET_KERNEL_CHECK(
30       ctx,
31       (canCast(common_type, out.scalar_type()) &&
32        common_type != ScalarType::Bool),
33       InvalidArgument,
34       out);
35 
36   // Check Dim Order
37   ET_KERNEL_CHECK(
38       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39 
40   // Resize
41   ET_KERNEL_CHECK(
42       ctx,
43       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44       InvalidArgument,
45       out);
46 
47   // Compute Dtype
48   ScalarType compute_type = utils::get_compute_type(common_type);
49 
50   // @lint-ignore CLANGTIDY facebook-hte-CArray
51   static constexpr const char op_name[] = "floor_divide.out";
52 
53   bool div_by_zero_error = false;
54 
55   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56     utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57         [&div_by_zero_error](
58             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
60             if (val_b == 0) {
61               div_by_zero_error = true;
62               return static_cast<CTYPE_COMPUTE>(0);
63             }
64           }
65           return utils::floor_divide(val_a, val_b);
66         },
67         ctx,
68         a,
69         utils::SupportedTensorDtypes::REALHBBF16,
70         b,
71         utils::SupportedTensorDtypes::REALHBBF16,
72         out,
73         utils::SupportedTensorDtypes::REALHBF16);
74   });
75 
76   ET_KERNEL_CHECK_MSG(
77       ctx,
78       !div_by_zero_error,
79       InvalidArgument,
80       out,
81       "Floor divide operation encountered integer division by zero");
82 
83   return out;
84 }
85 
86 } // namespace native
87 } // namespace executor
88 } // namespace torch
89