1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
18
19 #include <vector>
20
21 #include "utils/base/statusor.h"
22 #include "utils/flatbuffers/flatbuffers.h"
23 #include "utils/flatbuffers/reflection.h"
24 #include "utils/flatbuffers/test-utils.h"
25 #include "utils/grammar/semantics/evaluator.h"
26 #include "utils/grammar/semantics/evaluators/const-eval.h"
27 #include "utils/grammar/semantics/expression_generated.h"
28 #include "utils/grammar/testing/utils.h"
29 #include "utils/grammar/testing/value_generated.h"
30 #include "gmock/gmock.h"
31 #include "gtest/gtest.h"
32 #include "flatbuffers/flatbuffers.h"
33
34 namespace libtextclassifier3::grammar {
35 namespace {
36
37 template <typename T>
38 class ArithmeticExpressionEvaluatorTest : public GrammarTest {
39 protected:
Eval(const ArithmeticExpression_::Operator op)40 T Eval(const ArithmeticExpression_::Operator op) {
41 ArithmeticExpressionT arithmetic_expression;
42 arithmetic_expression.base_type = flatbuffers_base_type<T>::value;
43 arithmetic_expression.op = op;
44 arithmetic_expression.values.push_back(
45 CreatePrimitiveConstExpression<T>(1));
46 arithmetic_expression.values.push_back(
47 CreatePrimitiveConstExpression<T>(2));
48 arithmetic_expression.values.push_back(
49 CreatePrimitiveConstExpression<T>(3));
50 arithmetic_expression.values.push_back(
51 CreatePrimitiveConstExpression<T>(4));
52 OwnedFlatbuffer<SemanticExpression> expression =
53 CreateExpression(std::move(arithmetic_expression));
54
55 // Setup evaluators.
56 ConstEvaluator const_eval(semantic_values_schema_.get());
57 ArithmeticExpressionEvaluator arithmetic_eval(&const_eval);
58
59 // Run evaluator.
60 StatusOr<const SemanticValue*> result =
61 arithmetic_eval.Apply(/*context=*/{}, expression.get(), &arena_);
62
63 // Check result.
64 EXPECT_TRUE(result.ok());
65 const SemanticValue* result_value = result.ValueOrDie();
66 EXPECT_TRUE(result_value != nullptr);
67 return result_value->Value<T>();
68 }
69 };
70
71 using NumberTypes = ::testing::Types<int8, uint8, int16, uint16, int32, uint32,
72 int64, uint64, double, float>;
73 TYPED_TEST_SUITE(ArithmeticExpressionEvaluatorTest, NumberTypes);
74
// Evaluates each arithmetic operator over the fixture's constant operands
// 1, 2, 3 and 4 and checks the result for the current numeric type.
TYPED_TEST(ArithmeticExpressionEvaluatorTest, ParsesNumber) {
  // 1 + 2 + 3 + 4
  const TypeParam sum = this->Eval(ArithmeticExpression_::Operator_OP_ADD);
  EXPECT_EQ(sum, 10);

  // 1 * 2 * 3 * 4
  const TypeParam product =
      this->Eval(ArithmeticExpression_::Operator_OP_MUL);
  EXPECT_EQ(product, 24);

  // Smallest operand.
  const TypeParam smallest =
      this->Eval(ArithmeticExpression_::Operator_OP_MIN);
  EXPECT_EQ(smallest, 1);

  // Largest operand.
  const TypeParam largest =
      this->Eval(ArithmeticExpression_::Operator_OP_MAX);
  EXPECT_EQ(largest, 4);
}
81
82 } // namespace
83 } // namespace libtextclassifier3::grammar
84