/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include namespace torch { namespace executor { namespace native { namespace { ScalarType get_common_type(ScalarType a_type, ScalarType b_type) { if (isFloatingType(a_type) && isFloatingType(b_type)) { return promoteTypes(a_type, b_type); } else if (isFloatingType(a_type)) { return a_type; } else if (isFloatingType(b_type)) { return b_type; } return ScalarType::Float; } } // namespace Tensor& atan2_out( KernelRuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { // Common Dtype ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type()); // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); // Resize ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(a, b, out) == Error::Ok, InvalidArgument, out); // Compute Dtype ScalarType compute_type = utils::get_compute_type(common_type); // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "atan2.out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { utils::apply_bitensor_elementwise_fn( [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return std::atan2(val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, out, utils::SupportedTensorDtypes::FLOATHBF16); }); return out; } } // namespace native } // namespace executor } // namespace torch