1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <memory> 18 19 #include "common/common_test.h" 20 21 #include "ir/anf.h" 22 #include "base/base_ref.h" 23 24 namespace mindspore { 25 namespace utils { 26 class TestBaseRef : public UT::Common { 27 public: 28 TestBaseRef() {} 29 virtual void SetUp() {} 30 virtual void TearDown() {} 31 }; 32 33 TEST_F(TestBaseRef, TestScalar) { 34 BaseRef a = static_cast<int64_t>(1); 35 BaseRef b = 1.0; 36 if (isa<int64_t>(a)) { 37 ASSERT_EQ(cast<int64_t>(a), 1); 38 Int64ImmPtr c = cast<Int64ImmPtr>(a); 39 ASSERT_EQ(cast<int64_t>(c), 1); 40 } 41 ASSERT_TRUE(isa<Int64Imm>(a)); 42 ASSERT_TRUE(isa<BaseRef>(a)); 43 ASSERT_TRUE(isa<double>(b)); 44 ASSERT_TRUE(isa<FP64Imm>(b)); 45 BaseRef c = static_cast<int64_t>(1); 46 ASSERT_EQ(a == c, true); 47 } 48 49 void func(const BaseRef& sexp) { 50 if (isa<VectorRef>(sexp)) { 51 const VectorRef& a = cast<VectorRef>(sexp); 52 for (size_t i = 0; i < a.size(); i++) { 53 BaseRef v = a[i]; 54 MS_LOG(INFO) << "for is i:" << i << ", " << v.ToString() << "\n"; 55 } 56 MS_LOG(INFO) << "in func is valuesequeue:" << sexp.ToString() << "\n"; 57 } 58 } 59 60 TEST_F(TestBaseRef, TestNode) { 61 AnfNodePtr anf = NewValueNode(static_cast<int64_t>(1)); 62 BaseRef d = anf; 63 MS_LOG(INFO) << "anf typeid:" << dyn_cast<AnfNode>(anf).get(); 64 MS_LOG(INFO) << "anf typeid:" << NewValueNode(static_cast<int64_t>(1))->tid(); 65 MS_LOG(INFO) << "node reftypeid:" << d.tid(); 66 ASSERT_EQ(isa<AnfNodePtr>(d), true); 67 ASSERT_EQ(isa<AnfNode>(d), true); 68 ASSERT_EQ(isa<ValueNode>(d), true); 69 AnfNodePtr c = cast<ValueNodePtr>(d); 70 ASSERT_NE(c, nullptr); 71 } 72 73 TEST_F(TestBaseRef, TestVector) { 74 AnfNodePtr anf = NewValueNode(static_cast<int64_t>(1)); 75 VectorRef a({static_cast<int64_t>(1), static_cast<int64_t>(2), anf, NewValueNode(static_cast<int64_t>(1))}); 76 ASSERT_TRUE(isa<VectorRef>(a)); 77 func(a); 78 BaseRef b; 79 b = static_cast<int64_t>(1); 80 ASSERT_TRUE(isa<int64_t>(b)); 81 std::vector<int64_t> int64({1, 2, 3}); 82 VectorRef k; 83 k.insert(k.end(), int64.begin(), int64.end()); 84 85 k = a; 86 func(k); 87 88 BaseRef c = std::make_shared<VectorRef>(a); 89 BaseRef c1 = std::make_shared<VectorRef>(a); 90 ASSERT_TRUE(c == c1); 91 92 ASSERT_TRUE(isa<VectorRef>(c)); 93 VectorRefPtr d = cast<VectorRefPtr>(c); 94 ASSERT_TRUE(isa<VectorRef>(d)); 95 VectorRef e1({static_cast<int64_t>(1), static_cast<int64_t>(2), anf}); 96 VectorRef e({static_cast<int64_t>(1), static_cast<int64_t>(2), anf}); 97 ASSERT_EQ(e1 == e, true); 98 } 99 } // namespace utils 100 } // namespace mindspore 101