1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <iostream>
17 #include <memory>
18
19 #include "common/common_test.h"
20
21 #include "pybind11/pybind11.h"
22 #include "pipeline/jit/static_analysis/static_analysis.h"
23 #include "abstract/utils.h"
24 #include "pipeline/jit/static_analysis/prim.h"
25 #include "pipeline/jit/parse/parse.h"
26 #include "pipeline/jit/parse/resolve.h"
27 #include "pipeline/jit/parse/data_converter.h"
28 #include "frontend/operator/ops.h"
29
30 namespace mindspore {
31 namespace abstract {
32
33 class TestAbstract : public UT::Common {
34 public:
TestAbstract()35 TestAbstract() {}
SetUp()36 virtual void SetUp() {}
TearDown()37 virtual void TearDown() {}
38 };
39
/**
 * Parses the Python class `TestFoo` from module `gtest_input.pipeline.parse.parser_test` with
 * parse::ParseDataClass and checks the resulting `Class`:
 *   - its tag is `TestFoo`;
 *   - it has exactly two attributes, `x` typed Float and `y` typed Int;
 *   - it has four methods, one of which is named `inf`.
 * It then converts a ClassObject wrapping the same Python object to an abstract value, takes
 * the first bound argument of the resulting partial closure, and runs the MakeRecord
 * primitive's shape inference on it together with a float and an int64 abstract value.
 */
TEST_F(TestAbstract, TestParseDataClass) {
  // Check initialization before callback to Python.
  if (Py_IsInitialized() == 0) {
    Py_Initialize();
  }
  PyEval_InitThreads();

  py::object fn = parse::python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "TestFoo");

  ClassPtr cls_ptr = parse::ParseDataClass(fn);
  ASSERT_TRUE(nullptr != cls_ptr);
  std::shared_ptr<Class> cls = dyn_cast<Class>(cls_ptr);
  ASSERT_TRUE(nullptr != cls);

  MS_LOG(INFO) << cls->ToString();
  ASSERT_EQ(cls->tag(), Named(std::string("TestFoo")));

  // Track which expected attributes were actually seen. Without this, a class missing `x` or
  // `y` would skip the type checks entirely and the test would still pass.
  const ClassAttrVector attributes = cls->GetAttributes();
  ASSERT_EQ(attributes.size(), 2u);
  bool found_x = false;
  bool found_y = false;
  for (const auto &attr : attributes) {
    if (attr.first == std::string("x")) {
      found_x = true;
      ASSERT_TRUE(nullptr != dyn_cast<Float>(attr.second));
    } else if (attr.first == std::string("y")) {
      found_y = true;
      ASSERT_TRUE(nullptr != dyn_cast<Int>(attr.second));
    }
  }
  ASSERT_TRUE(found_x);
  ASSERT_TRUE(found_y);

  // Bind by const reference; this avoids copying the map if methods() returns a reference.
  const auto &methods = cls->methods();
  ASSERT_EQ(methods.size(), 4u);
  for (const auto &method : methods) {
    MS_LOG(INFO) << method.first;
  }
  ASSERT_EQ(methods.count(std::string("inf")), 1u);

  ValuePtr obj = std::make_shared<parse::ClassObject>(fn, "TestFoo");

  ValueNodePtr fn_node = NewValueNode(obj);
  AnfNodeConfigPtr fn_conf = std::make_shared<AnfNodeConfig>(nullptr, fn_node, nullptr, nullptr);
  AbstractBasePtr foo = ToAbstract(obj, nullptr, fn_conf);
  ASSERT_TRUE(foo != nullptr);

  AbstractBasePtr abstract_x = FromValue(1.1, true);
  AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(5), true);

  // Both casts below are checked. The original dereferenced the closure unconditionally and
  // passed the scalar on unchecked, so a failed cast crashed instead of failing an assertion.
  auto partial_func = dyn_cast<PartialAbstractClosure>(foo);
  ASSERT_TRUE(nullptr != partial_func);
  AbstractBasePtrList args_spec_list = partial_func->args();
  ASSERT_GT(args_spec_list.size(), 0u);
  AbstractScalarPtr abs_scalar = dyn_cast<AbstractScalar>(args_spec_list[0]);
  ASSERT_TRUE(nullptr != abs_scalar);

  AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y};

  auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
  ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_);

  AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list);
  ASSERT_TRUE(nullptr != new_cls);
}
102
103 } // namespace abstract
104 } // namespace mindspore
105