1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <iostream>
17 #include <memory>
18 #include <vector>
19
20 #include "common/common_test.h"
21
22 #include "ir/anf.h"
23 #include "ir/dtype.h"
24 #include "frontend/operator/prim_to_function.h"
25 #include "base/core_ops.h"
26
27 namespace mindspore {
28 namespace prim {
29
30 class TestPrimFunc : public UT::Common {
31 public:
TestPrimFunc()32 TestPrimFunc() {}
SetUp()33 virtual void SetUp() {}
34 };
35
// Checks that the function signature PrimToFunction reports for the
// ScalarAdd primitive prints the same as a hand-built
// Function(Number, Number) -> Number.
TEST_F(TestPrimFunc, ScalarAddTest) {
  auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
  FunctionPtr func = nullptr;
  PrimToFunction::GetInstance().GetFunction(prim, &func);
  // Fail the test cleanly (instead of crashing on func->ToString() below)
  // if GetFunction left `func` unset.
  ASSERT_TRUE(func != nullptr);

  // Expected signature: two Number arguments, Number return value.
  std::vector<std::shared_ptr<Type>> two_args{std::make_shared<Number>(), std::make_shared<Number>()};
  std::shared_ptr<Type> retval = std::make_shared<Number>();
  Function func_add(two_args, retval);

  std::cout << "func_add: " << func_add.ToString() << std::endl;
  std::cout << "prim_func: " << func->ToString() << std::endl;

  ASSERT_EQ(func_add.ToString(), func->ToString());
}
50
// Checks that the function signature PrimToFunction reports for the
// "scalar_exp" primitive prints the same as a hand-built
// Function(Number) -> Number.
TEST_F(TestPrimFunc, ScalarExpTest) {
  auto prim = std::make_shared<Primitive>("scalar_exp");
  FunctionPtr func = nullptr;
  PrimToFunction::GetInstance().GetFunction(prim, &func);
  // Fail the test cleanly (instead of crashing on func->ToString() below)
  // if GetFunction left `func` unset.
  ASSERT_TRUE(func != nullptr);

  // Expected signature: one Number argument, Number return value.
  std::vector<std::shared_ptr<Type>> one_arg{std::make_shared<Number>()};
  std::shared_ptr<Type> retval = std::make_shared<Number>();
  Function func_exp(one_arg, retval);

  std::cout << "func_exp: " << func_exp.ToString() << std::endl;
  std::cout << "prim_func: " << func->ToString() << std::endl;

  ASSERT_EQ(func_exp.ToString(), func->ToString());
}
65
66 } // namespace prim
67 } // namespace mindspore
68