/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include // Declares the operator #include #include #include #include #include #include #include using namespace ::testing; using exec_aten::ArrayRef; using exec_aten::optional; using exec_aten::ScalarType; using exec_aten::Tensor; using torch::executor::testing::TensorFactory; class OpArgminTest : public OperatorTest { protected: Tensor& op_argmin_out( const Tensor& in, optional dim, bool keepdim, Tensor& out) { return torch::executor::aten::argmin_outf(context_, in, dim, keepdim, out); } }; TEST_F(OpArgminTest, SanityCheckLong) { TensorFactory tf; // clang-format off Tensor in = tf.make( { 2, 3, 4 }, { 1, 4, 1, 6, 5, 8, 5, 6, 5, 3, 9, 2, 3, 9, 1, 4, 9, 7, 5, 5, 7, 7, 6, 3 }); Tensor out = tf.zeros({2, 4}); Tensor expected = tf.make({2, 4}, { 0, 2, 0, 2, 0, 1, 0, 2 }); Tensor ret = op_argmin_out(in, 1, false, out); EXPECT_TENSOR_EQ(out, ret); EXPECT_TENSOR_EQ(out, expected); // clang-format on } TEST_F(OpArgminTest, SanityCheckShort) { TensorFactory tfl; TensorFactory tfs; // clang-format off Tensor in = tfs.make( { 2, 3, 4 }, { 1, 4, 1, 6, 5, 8, 5, 6, 5, 3, 9, 2, 3, 9, 1, 4, 9, 7, 5, 5, 7, 7, 6, 3 }); Tensor out = tfl.zeros({2, 4}); Tensor expected = tfl.make({2, 4}, { 0, 2, 0, 2, 0, 1, 0, 2 }); Tensor ret = op_argmin_out(in, 1, false, out); EXPECT_TENSOR_EQ(out, ret); EXPECT_TENSOR_EQ(out, expected); // clang-format on } TEST_F(OpArgminTest, SanityCheckNullDim) { TensorFactory tf; // clang-format off Tensor in = tf.make( { 2, 3, 4 }, { 9, 4, 1, 6, 5, 8, 5, 6, 5, 3, 9, 2, 3, 9, 1, 4, 9, 7, 5, 5, 7, 7, 6, 3 }); Tensor out = tf.zeros({}); Tensor expected = tf.make({}, {2}); optional dim; Tensor ret = op_argmin_out(in, dim, false, out); EXPECT_TENSOR_EQ(out, ret); EXPECT_TENSOR_EQ(out, expected); // clang-format on }