/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include namespace torch { namespace executor { namespace native { namespace internal { #define DEFINE_BINARY_OPERATOR_TEMPLATE(name, op) \ template \ T name(const T val_a, const T val_b) { \ return val_a op val_b; \ } DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_and, &) DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_or, |) DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_xor, ^) template using bitwise_fn = T (*)(const T, const T); template constexpr bitwise_fn get_bitwise_fn() { std::string_view op = op_name; if (op == "bitwise_and.Tensor_out" || op == "bitwise_and.Scalar_out") { return bitwise_and; } if (op == "bitwise_or.Tensor_out" || op == "bitwise_or.Scalar_out") { return bitwise_or; } if (op == "bitwise_xor.Tensor_out" || op == "bitwise_xor.Scalar_out") { return bitwise_xor; } return nullptr; }; template struct BitwiseFnForOp { static constexpr auto value = get_bitwise_fn(); static_assert(value != nullptr, "unknown op_name!"); }; template Tensor& bitwise_tensor_out( RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { // Common Dtype ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); // Check Common Dtype ET_KERNEL_CHECK( ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); // Resize ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(a, b, out) == Error::Ok, InvalidArgument, out); // Compute Dtype ScalarType compute_type = utils::get_compute_type(common_type); ET_SWITCH_INT_TYPES_AND( Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { utils::apply_bitensor_elementwise_fn( BitwiseFnForOp::value, ctx, a, utils::SupportedTensorDtypes::INTB, b, utils::SupportedTensorDtypes::INTB, out, utils::SupportedTensorDtypes::REALHBBF16); }); return out; } template Tensor& bitwise_scalar_out( RuntimeContext& ctx, const Tensor& a, const Scalar& b, Tensor& out) { // Common Dtype ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); // Check Common Dtype ET_KERNEL_CHECK( ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); // Resize ET_KERNEL_CHECK( ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); // Compute Dtype ScalarType compute_type = utils::get_compute_type(common_type); ET_SWITCH_INT_TYPES_AND( Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); utils::apply_unitensor_elementwise_fn( [val_b](const CTYPE_COMPUTE val_a) { return BitwiseFnForOp::value( val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::INTB, out, utils::SupportedTensorDtypes::REALHBBF16); }); return out; } } // namespace internal } // namespace native } // namespace executor } // namespace torch