1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <memory> 18 19 #include "common/common_test.h" 20 #include "common/py_func_graph_fetcher.h" 21 #include "pipeline/jit/static_analysis/prim.h" 22 #include "frontend/operator/ops.h" 23 #include "abstract/utils.h" 24 25 namespace mindspore { 26 namespace abstract { 27 28 class TestData : public UT::Common { 29 public: 30 void SetUp(); 31 void TearDown(); 32 }; 33 34 void TestData::SetUp() { UT::InitPythonPath(); } 35 36 void TestData::TearDown() { 37 // destroy resource 38 } 39 40 TEST_F(TestData, test_build_value) { 41 // assert build_value(S(1)) == 1 42 AbstractScalar s1 = AbstractScalar(static_cast<int64_t>(1)); 43 ASSERT_EQ(1, s1.BuildValue()->cast<Int64ImmPtr>()->value()); 44 // assert build_value(S(t=ty.Int[64]), default=ANYTHING) is ANYTHING 45 s1 = AbstractScalar(kAnyValue, kInt64); 46 ASSERT_TRUE(s1.BuildValue()->isa<AnyValue>()); 47 ASSERT_TRUE(s1.BuildValue()->isa<AnyValue>()); 48 49 // assert build_value(T([S(1), S(2)])) == (1, 2) 50 AbstractBasePtr base1 = std::make_shared<AbstractScalar>(static_cast<int64_t>(1)); 51 AbstractBasePtr base2 = std::make_shared<AbstractScalar>(static_cast<int64_t>(2)); 52 AbstractBasePtrList base_list = {base1, base2}; 53 AbstractTuple t1 = AbstractTuple(base_list); 54 55 std::vector<ValuePtr> value_list = {MakeValue(static_cast<int64_t>(1)), MakeValue(static_cast<int64_t>(2))}; 56 auto tup = t1.BuildValue()->cast<ValueTuplePtr>()->value(); 57 58 ASSERT_TRUE(tup.size() == value_list.size()); 59 for (int i = 0; i < value_list.size(); i++) { 60 ASSERT_EQ(*tup[i], *value_list[i]); 61 } 62 63 // BuildValue(AbstractFunction) should return kAnyValue. 64 AbstractBasePtr abs_f1 = FromValue(prim::kPrimReturn, false); 65 ValuePtr abs_f1_built = abs_f1->BuildValue(); 66 ASSERT_EQ(abs_f1_built, prim::kPrimReturn); 67 68 FuncGraphPtr fg1 = std::make_shared<FuncGraph>(); 69 AbstractBasePtr abs_fg1 = FromValue(fg1, false); 70 ValuePtr abs_fg1_built = abs_fg1->BuildValue(); 71 ASSERT_EQ(abs_fg1_built, kAnyValue); 72 73 // BuildValue(Tuple(AbstractFunction)) should return kAnyValue; 74 AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false); 75 AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2})); 76 ValuePtr func_tuple_built = abs_func_tuple->BuildValue(); 77 ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd})); 78 79 // BuildValue(List(AbstractFunction)) should return kAnyValue; 80 AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2})); 81 ValuePtr func_list_built = abs_func_list->BuildValue(); 82 ASSERT_EQ(*func_list_built, ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd})); 83 84 // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue 85 abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2})); 86 func_tuple_built = abs_func_tuple->BuildValue(); 87 ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd})); 88 } 89 90 TEST_F(TestData, test_build_type) { 91 AbstractBasePtr s1 = FromValue(static_cast<int64_t>(1), false); 92 AbstractBasePtr s2 = FromValue(static_cast<int64_t>(2), false); 93 ASSERT_TRUE(Int(64) == *s1->BuildType()); 94 95 AbstractFunctionPtr f1 = std::make_shared<PrimitiveAbstractClosure>(nullptr, nullptr); 96 ASSERT_TRUE(Function() == *f1->BuildType()); 97 98 AbstractList l1 = AbstractList({s1, s2}); 99 ASSERT_TRUE(List({std::make_shared<Int>(64), std::make_shared<Int>(64)}) == *l1.BuildType()); 100 } 101 102 TEST_F(TestData, test_build_shape) { 103 AbstractBasePtr s1 = FromValue(static_cast<int64_t>(1), false); 104 AbstractBasePtr s2 = FromValue(static_cast<int64_t>(2), false); 105 ASSERT_TRUE(NoShape() == *s1->BuildShape()); 106 107 AbstractFunctionPtr f1 = std::make_shared<PrimitiveAbstractClosure>(nullptr, nullptr); 108 ASSERT_TRUE(NoShape() == *f1->BuildShape()); 109 110 AbstractList l1 = AbstractList({s1, s2}); 111 auto lshape = l1.BuildShape(); 112 ASSERT_TRUE(lshape); 113 114 std::vector<int64_t> weight1_dims = {2, 20, 5, 5}; 115 std::vector<int64_t> weight2_dims = {2, 2, 5, 5}; 116 tensor::TensorPtr weight1 = std::make_shared<tensor::Tensor>(kNumberTypeInt64, weight1_dims); 117 tensor::TensorPtr weight2 = std::make_shared<tensor::Tensor>(kNumberTypeInt64, weight2_dims); 118 119 AbstractBasePtr abstract_weight1 = FromValue(weight1, true); 120 AbstractBasePtr abstract_weight2 = FromValue(weight2, true); 121 ShapePtr shape_weight = dyn_cast<Shape>(abstract_weight1->BuildShape()); 122 ASSERT_TRUE(shape_weight); 123 ASSERT_EQ(weight1_dims, shape_weight->shape()); 124 125 std::vector<ValuePtr> vec({weight1, weight2}); 126 AbstractBasePtr abstract_tup = FromValue(vec, true); 127 std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape()); 128 ASSERT_TRUE(shape_tuple); 129 const std::vector<BaseShapePtr> &ptr_vec = shape_tuple->shape(); 130 ASSERT_EQ(ptr_vec.size(), 2); 131 132 ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]); 133 ASSERT_TRUE(shape1); 134 ASSERT_EQ(weight1_dims, shape1->shape()); 135 136 ShapePtr shape2 = dyn_cast<Shape>(ptr_vec[1]); 137 ASSERT_TRUE(shape2); 138 ASSERT_EQ(weight2_dims, shape2->shape()); 139 } 140 141 TEST_F(TestData, test_clone) { 142 AbstractBasePtr s1 = FromValue(static_cast<int64_t>(1), false); 143 AbstractBasePtr s2 = s1->Clone(); 144 ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack()); 145 ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack()); 146 ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack()); 147 148 AbstractFunctionPtr f1 = 149 std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext()); 150 AbstractBasePtr f2 = f1->Clone(); 151 ASSERT_TRUE(*f2 == *f1); 152 153 AbstractList l1 = AbstractList({s1, s2}); 154 AbstractBasePtr l2 = l1.Clone(); 155 AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get()); 156 ASSERT_TRUE(l2_cast != nullptr); 157 ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack()); 158 159 std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}, 160 {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}}; 161 std::unordered_map<std::string, ValuePtr> methods; 162 AbstractBasePtr c1 = std::make_shared<AbstractClass>(Named("Point"), attr, methods); 163 AbstractBasePtr c2 = c1->Clone(); 164 ASSERT_EQ(*c1, *c2); 165 } 166 167 TEST_F(TestData, test_join) { 168 int64_t int1 = 1; 169 AbstractBasePtr s1 = FromValue(int1, false); 170 AbstractBasePtr s2 = s1->Broaden(); 171 172 std::vector<AbstractBasePtr> xx = {s1, s2}; 173 AbstractListPtr l1 = std::make_shared<AbstractList>(xx); 174 AbstractListPtr l2 = std::make_shared<AbstractList>(xx); 175 l1->Join(l2); 176 } 177 178 TEST_F(TestData, test_broaden) { 179 int64_t int1 = 1; 180 AbstractBasePtr s1 = FromValue(int1, false); 181 AbstractBasePtr s2 = s1->Broaden(); 182 ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack()); 183 ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1)); 184 ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>()); 185 186 AbstractFunctionPtr f1 = 187 std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext()); 188 AbstractBasePtr f2 = f1->Broaden(); 189 ASSERT_TRUE(f2 == f1); 190 191 AbstractList l1 = AbstractList({s1, s2}); 192 AbstractBasePtr l2 = l1.Broaden(); 193 AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get()); 194 ASSERT_TRUE(l2_cast != nullptr); 195 AbstractBasePtr csr = AbstractJoin(l2_cast->elements()); 196 ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>()); 197 } 198 199 } // namespace abstract 200 } // namespace mindspore 201