1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <memory> 18 19 #include "common/common_test.h" 20 21 #include "pybind11/pybind11.h" 22 #include "pipeline/jit/static_analysis/static_analysis.h" 23 #include "abstract/utils.h" 24 #include "pipeline/jit/static_analysis/prim.h" 25 #include "pipeline/jit/parse/parse.h" 26 #include "pipeline/jit/parse/resolve.h" 27 #include "pipeline/jit/parse/data_converter.h" 28 #include "frontend/operator/ops.h" 29 30 namespace mindspore { 31 namespace abstract { 32 33 class TestAbstract : public UT::Common { 34 public: 35 TestAbstract() {} 36 virtual void SetUp() {} 37 virtual void TearDown() {} 38 }; 39 40 TEST_F(TestAbstract, TestParseDataClass) { 41 // Check initialization before callback to Python. 42 if (Py_IsInitialized() == 0) { 43 Py_Initialize(); 44 } 45 PyEval_InitThreads(); 46 47 py::object fn = parse::python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "TestFoo"); 48 49 ClassPtr cls_ptr = parse::ParseDataClass(fn); 50 ASSERT_TRUE(nullptr != cls_ptr); 51 std::shared_ptr<Class> cls = dyn_cast<Class>(cls_ptr); 52 ASSERT_TRUE(nullptr != cls); 53 54 MS_LOG(INFO) << "" << cls->ToString(); 55 ASSERT_EQ(cls->tag(), Named(std::string("TestFoo"))); 56 57 ClassAttrVector attributes = cls->GetAttributes(); 58 ASSERT_EQ(attributes.size(), 2); 59 for (auto &v : attributes) { 60 if (v.first == std::string("x")) { 61 ASSERT_TRUE(nullptr != dyn_cast<Float>(v.second)); 62 } 63 if (v.first == std::string("y")) { 64 ASSERT_TRUE(nullptr != dyn_cast<Int>(v.second)); 65 } 66 } 67 68 std::unordered_map<std::string, ValuePtr> methods = cls->methods(); 69 ASSERT_EQ(methods.size(), 4); 70 int counts = 0; 71 for (auto &v : methods) { 72 if (v.first == std::string("inf")) { 73 counts++; 74 } 75 MS_LOG(INFO) << "" << v.first; 76 } 77 ASSERT_EQ(counts, 1); 78 79 ValuePtr obj = std::make_shared<parse::ClassObject>(fn, "TestFoo"); 80 81 ValueNodePtr fn_node = NewValueNode(obj); 82 AnfNodeConfigPtr fn_conf = std::make_shared<AnfNodeConfig>(nullptr, fn_node, nullptr, nullptr); 83 AbstractBasePtr foo = ToAbstract(obj, nullptr, fn_conf); 84 ASSERT_TRUE(foo != nullptr); 85 86 AbstractBasePtr abstract_x = FromValue(1.1, true); 87 AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(5), true); 88 89 auto partical_func = dyn_cast<PartialAbstractClosure>(foo); 90 AbstractBasePtrList args_spec_list = partical_func->args(); 91 ASSERT_GT(args_spec_list.size(), 0); 92 AbstractScalarPtr abs_scalar = dyn_cast<AbstractScalar>(args_spec_list[0]); 93 94 AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; 95 96 auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); 97 ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_); 98 99 AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list); 100 ASSERT_TRUE(nullptr != new_cls); 101 } 102 103 } // namespace abstract 104 } // namespace mindspore 105