1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <iostream>
17 #include <memory>
18
19 #include "common/common_test.h"
20
21 #include "ir/anf.h"
22 #include "base/base_ref.h"
23
24 namespace mindspore {
25 namespace utils {
26 class TestBaseRef : public UT::Common {
27 public:
TestBaseRef()28 TestBaseRef() {}
SetUp()29 virtual void SetUp() {}
TearDown()30 virtual void TearDown() {}
31 };
32
// Scalars stored in a BaseRef must be recognised both as the raw C++ type and as the
// matching immediate (Int64Imm / FP64Imm), must round-trip through cast<>, and two
// BaseRefs built from the same int64_t value must compare equal.
TEST_F(TestBaseRef, TestScalar) {
  BaseRef a = static_cast<int64_t>(1);
  BaseRef b = 1.0;
  // Assert the type check instead of branching on it, so the round-trip checks below
  // can never be silently skipped.
  ASSERT_TRUE(isa<int64_t>(a));
  ASSERT_EQ(cast<int64_t>(a), 1);
  Int64ImmPtr imm = cast<Int64ImmPtr>(a);
  ASSERT_NE(imm, nullptr);
  ASSERT_EQ(cast<int64_t>(imm), 1);

  ASSERT_TRUE(isa<Int64Imm>(a));
  ASSERT_TRUE(isa<BaseRef>(a));
  ASSERT_TRUE(isa<double>(b));
  ASSERT_TRUE(isa<FP64Imm>(b));
  BaseRef c = static_cast<int64_t>(1);
  ASSERT_TRUE(a == c);
}
48
func(const BaseRef & sexp)49 void func(const BaseRef& sexp) {
50 if (isa<VectorRef>(sexp)) {
51 const VectorRef& a = cast<VectorRef>(sexp);
52 for (size_t i = 0; i < a.size(); i++) {
53 BaseRef v = a[i];
54 MS_LOG(INFO) << "for is i:" << i << ", " << v.ToString() << "\n";
55 }
56 MS_LOG(INFO) << "in func is valuesequeue:" << sexp.ToString() << "\n";
57 }
58 }
59
// A ValueNode wrapped in a BaseRef must be recognised as AnfNodePtr, AnfNode and
// ValueNode, and cast<ValueNodePtr> must recover a non-null node.
TEST_F(TestBaseRef, TestNode) {
  AnfNodePtr anf = NewValueNode(static_cast<int64_t>(1));
  BaseRef d = anf;
  MS_LOG(INFO) << "anf typeid:" << dyn_cast<AnfNode>(anf).get();
  MS_LOG(INFO) << "anf typeid:" << NewValueNode(static_cast<int64_t>(1))->tid();
  MS_LOG(INFO) << "node reftypeid:" << d.tid();
  // Boolean predicates use ASSERT_TRUE rather than ASSERT_EQ(expr, true).
  ASSERT_TRUE(isa<AnfNodePtr>(d));
  ASSERT_TRUE(isa<AnfNode>(d));
  ASSERT_TRUE(isa<ValueNode>(d));
  AnfNodePtr c = cast<ValueNodePtr>(d);
  ASSERT_NE(c, nullptr);
}
72
// Exercises VectorRef construction from mixed scalars/nodes, range insertion,
// copy-assignment, wrapping in a BaseRef via shared_ptr, and element-wise equality.
TEST_F(TestBaseRef, TestVector) {
  AnfNodePtr anf = NewValueNode(static_cast<int64_t>(1));
  VectorRef a({static_cast<int64_t>(1), static_cast<int64_t>(2), anf, NewValueNode(static_cast<int64_t>(1))});
  ASSERT_TRUE(isa<VectorRef>(a));
  ASSERT_EQ(a.size(), static_cast<size_t>(4));
  func(a);

  BaseRef b;
  b = static_cast<int64_t>(1);
  ASSERT_TRUE(isa<int64_t>(b));

  // Verify the range insert before `k` is overwritten below; otherwise this path is
  // executed but never checked.
  std::vector<int64_t> values({1, 2, 3});
  VectorRef k;
  k.insert(k.end(), values.begin(), values.end());
  ASSERT_EQ(k.size(), values.size());

  // Copy-assignment must replace the previous contents entirely.
  k = a;
  ASSERT_TRUE(k == a);
  func(k);

  // Two independently allocated VectorRefs with the same contents compare equal as BaseRefs.
  BaseRef c = std::make_shared<VectorRef>(a);
  BaseRef c1 = std::make_shared<VectorRef>(a);
  ASSERT_TRUE(c == c1);

  ASSERT_TRUE(isa<VectorRef>(c));
  VectorRefPtr d = cast<VectorRefPtr>(c);
  ASSERT_TRUE(isa<VectorRef>(d));
  VectorRef e1({static_cast<int64_t>(1), static_cast<int64_t>(2), anf});
  VectorRef e({static_cast<int64_t>(1), static_cast<int64_t>(2), anf});
  ASSERT_TRUE(e1 == e);
}
99 } // namespace utils
100 } // namespace mindspore
101