• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <memory>
18 #include <vector>
19 
20 #include "common/common_test.h"
21 
22 #include "ir/anf.h"
23 #include "ir/dtype.h"
24 #include "frontend/operator/prim_to_function.h"
25 #include "base/core_ops.h"
26 
27 namespace mindspore {
28 namespace prim {
29 
30 class TestPrimFunc : public UT::Common {
31  public:
TestPrimFunc()32   TestPrimFunc() {}
SetUp()33   virtual void SetUp() {}
34 };
35 
// Checks that PrimToFunction maps the ScalarAdd primitive to the function
// signature (Number, Number) -> Number. The two signatures are compared
// through their Function::ToString() representations.
TEST_F(TestPrimFunc, ScalarAddTest) {
  auto prim = std::make_shared<Primitive>(prim::kScalarAdd);
  FunctionPtr func = nullptr;
  PrimToFunction::GetInstance().GetFunction(prim, &func);
  // If no signature is registered, `func` stays null. Fail the test cleanly
  // here instead of crashing on func->ToString() below.
  ASSERT_NE(func, nullptr) << "PrimToFunction returned no function for ScalarAdd";

  // Expected signature: two Number arguments returning a Number.
  std::vector<std::shared_ptr<Type>> two_args{std::make_shared<Number>(), std::make_shared<Number>()};
  std::shared_ptr<Type> retval = std::make_shared<Number>();
  Function func_add(two_args, retval);

  std::cout << "func_add: " << func_add.ToString() << std::endl;
  std::cout << "prim_func: " << func->ToString() << std::endl;

  ASSERT_EQ(func_add.ToString(), func->ToString());
}
50 
// Checks that PrimToFunction maps the primitive named "scalar_exp" to the
// function signature (Number) -> Number. The two signatures are compared
// through their Function::ToString() representations.
TEST_F(TestPrimFunc, ScalarExpTest) {
  auto prim = std::make_shared<Primitive>("scalar_exp");
  FunctionPtr func = nullptr;
  PrimToFunction::GetInstance().GetFunction(prim, &func);
  // If no signature is registered, `func` stays null. Fail the test cleanly
  // here instead of crashing on func->ToString() below.
  ASSERT_NE(func, nullptr) << "PrimToFunction returned no function for scalar_exp";

  // Expected signature: one Number argument returning a Number.
  std::vector<std::shared_ptr<Type>> one_arg{std::make_shared<Number>()};
  std::shared_ptr<Type> retval = std::make_shared<Number>();
  Function func_exp(one_arg, retval);

  std::cout << "func_exp: " << func_exp.ToString() << std::endl;
  std::cout << "prim_func: " << func->ToString() << std::endl;

  ASSERT_EQ(func_exp.ToString(), func->ToString());
}
65 
66 }  // namespace prim
67 }  // namespace mindspore
68