• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <memory>
18 
19 #include "common/common_test.h"
20 
21 #include "ir/anf.h"
22 #include "base/base_ref.h"
23 
24 namespace mindspore {
25 namespace utils {
26 class TestBaseRef : public UT::Common {
27  public:
TestBaseRef()28   TestBaseRef() {}
SetUp()29   virtual void SetUp() {}
TearDown()30   virtual void TearDown() {}
31 };
32 
// Checks that scalars wrapped in a BaseRef can be type-queried with isa<>,
// recovered with cast<>, and compared with operator==.
TEST_F(TestBaseRef, TestScalar) {
  BaseRef a = static_cast<int64_t>(1);
  BaseRef b = 1.0;

  // Assert the type check instead of branching on it, so the cast checks
  // below can never be silently skipped.
  ASSERT_TRUE(isa<int64_t>(a));
  ASSERT_EQ(cast<int64_t>(a), 1);
  Int64ImmPtr imm = cast<Int64ImmPtr>(a);
  ASSERT_EQ(cast<int64_t>(imm), 1);

  // The wrapped integer must also be recognised through its Imm type and as a BaseRef.
  ASSERT_TRUE(isa<Int64Imm>(a));
  ASSERT_TRUE(isa<BaseRef>(a));

  // Same checks for the floating-point value.
  ASSERT_TRUE(isa<double>(b));
  ASSERT_TRUE(isa<FP64Imm>(b));

  // Two BaseRefs built from the same integer value compare equal.
  BaseRef c = static_cast<int64_t>(1);
  ASSERT_TRUE(a == c);
}
48 
func(const BaseRef & sexp)49 void func(const BaseRef& sexp) {
50   if (isa<VectorRef>(sexp)) {
51     const VectorRef& a = cast<VectorRef>(sexp);
52     for (size_t i = 0; i < a.size(); i++) {
53       BaseRef v = a[i];
54       MS_LOG(INFO) << "for is i:" << i << ", " << v.ToString() << "\n";
55     }
56     MS_LOG(INFO) << "in func  is valuesequeue:" << sexp.ToString() << "\n";
57   }
58 }
59 
// Checks that an AnfNodePtr stored in a BaseRef is recognised as the pointer type,
// as AnfNode and as ValueNode, and can be cast back to a non-null ValueNodePtr.
TEST_F(TestBaseRef, TestNode) {
  AnfNodePtr node = NewValueNode(static_cast<int64_t>(1));
  BaseRef node_ref = node;

  // Diagnostic output only; nothing below depends on these values.
  MS_LOG(INFO) << "anf typeid:" << dyn_cast<AnfNode>(node).get();
  MS_LOG(INFO) << "anf typeid:" << NewValueNode(static_cast<int64_t>(1))->tid();
  MS_LOG(INFO) << "node reftypeid:" << node_ref.tid();

  ASSERT_TRUE(isa<AnfNodePtr>(node_ref));
  ASSERT_TRUE(isa<AnfNode>(node_ref));
  ASSERT_TRUE(isa<ValueNode>(node_ref));

  AnfNodePtr recovered = cast<ValueNodePtr>(node_ref);
  ASSERT_NE(recovered, nullptr);
}
72 
// Exercises VectorRef: construction from mixed elements, assignment, range insert,
// wrapping in a BaseRef through a shared pointer, and equality comparison.
TEST_F(TestBaseRef, TestVector) {
  AnfNodePtr anf = NewValueNode(static_cast<int64_t>(1));

  // A VectorRef may mix scalars and nodes.
  VectorRef mixed({static_cast<int64_t>(1), static_cast<int64_t>(2), anf, NewValueNode(static_cast<int64_t>(1))});
  ASSERT_TRUE(isa<VectorRef>(mixed));
  func(mixed);

  // A default-constructed BaseRef can be assigned a scalar afterwards.
  BaseRef scalar_ref;
  scalar_ref = static_cast<int64_t>(1);
  ASSERT_TRUE(isa<int64_t>(scalar_ref));

  // Range-insert from a plain vector, then overwrite the result by assignment.
  std::vector<int64_t> raw_values({1, 2, 3});
  VectorRef copy_target;
  copy_target.insert(copy_target.end(), raw_values.begin(), raw_values.end());

  copy_target = mixed;
  func(copy_target);

  // Two separately allocated copies of the same VectorRef compare equal as BaseRefs.
  BaseRef shared_first = std::make_shared<VectorRef>(mixed);
  BaseRef shared_second = std::make_shared<VectorRef>(mixed);
  ASSERT_TRUE(shared_first == shared_second);

  ASSERT_TRUE(isa<VectorRef>(shared_first));
  VectorRefPtr vec_ptr = cast<VectorRefPtr>(shared_first);
  ASSERT_TRUE(isa<VectorRef>(vec_ptr));

  // VectorRefs built from the same elements compare equal.
  VectorRef lhs({static_cast<int64_t>(1), static_cast<int64_t>(2), anf});
  VectorRef rhs({static_cast<int64_t>(1), static_cast<int64_t>(2), anf});
  ASSERT_TRUE(lhs == rhs);
}
99 }  // namespace utils
100 }  // namespace mindspore
101