1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10 #include <executorch/kernels/portable/cpu/util/math_util.h>
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 #include <executorch/runtime/platform/assert.h>
13 #include <cmath>
14 #include <type_traits>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
floor_divide_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)20 Tensor& floor_divide_out(
21 KernelRuntimeContext& ctx,
22 const Tensor& a,
23 const Tensor& b,
24 Tensor& out) {
25 // Common Dtype
26 ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27
28 // Check Common Dtype
29 ET_KERNEL_CHECK(
30 ctx,
31 (canCast(common_type, out.scalar_type()) &&
32 common_type != ScalarType::Bool),
33 InvalidArgument,
34 out);
35
36 // Check Dim Order
37 ET_KERNEL_CHECK(
38 ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39
40 // Resize
41 ET_KERNEL_CHECK(
42 ctx,
43 resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44 InvalidArgument,
45 out);
46
47 // Compute Dtype
48 ScalarType compute_type = utils::get_compute_type(common_type);
49
50 // @lint-ignore CLANGTIDY facebook-hte-CArray
51 static constexpr const char op_name[] = "floor_divide.out";
52
53 bool div_by_zero_error = false;
54
55 ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56 utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57 [&div_by_zero_error](
58 const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59 if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
60 if (val_b == 0) {
61 div_by_zero_error = true;
62 return static_cast<CTYPE_COMPUTE>(0);
63 }
64 }
65 return utils::floor_divide(val_a, val_b);
66 },
67 ctx,
68 a,
69 utils::SupportedTensorDtypes::REALHBBF16,
70 b,
71 utils::SupportedTensorDtypes::REALHBBF16,
72 out,
73 utils::SupportedTensorDtypes::REALHBF16);
74 });
75
76 ET_KERNEL_CHECK_MSG(
77 ctx,
78 !div_by_zero_error,
79 InvalidArgument,
80 out,
81 "Floor divide operation encountered integer division by zero");
82
83 return out;
84 }
85
86 } // namespace native
87 } // namespace executor
88 } // namespace torch
89