• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <memory>
18 
19 #include "common/common_test.h"
20 
21 #include "pybind11/pybind11.h"
22 #include "pipeline/jit/static_analysis/static_analysis.h"
23 #include "abstract/utils.h"
24 #include "pipeline/jit/static_analysis/prim.h"
25 #include "pipeline/jit/parse/parse.h"
26 #include "pipeline/jit/parse/resolve.h"
27 #include "pipeline/jit/parse/data_converter.h"
28 #include "frontend/operator/ops.h"
29 
30 namespace mindspore {
31 namespace abstract {
32 
33 class TestAbstract : public UT::Common {
34  public:
TestAbstract()35   TestAbstract() {}
SetUp()36   virtual void SetUp() {}
TearDown()37   virtual void TearDown() {}
38 };
39 
TEST_F(TestAbstract,TestParseDataClass)40 TEST_F(TestAbstract, TestParseDataClass) {
41   // Check initialization before callback to Python.
42   if (Py_IsInitialized() == 0) {
43     Py_Initialize();
44   }
45   PyEval_InitThreads();
46 
47   py::object fn = parse::python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "TestFoo");
48 
49   ClassPtr cls_ptr = parse::ParseDataClass(fn);
50   ASSERT_TRUE(nullptr != cls_ptr);
51   std::shared_ptr<Class> cls = dyn_cast<Class>(cls_ptr);
52   ASSERT_TRUE(nullptr != cls);
53 
54   MS_LOG(INFO) << "" << cls->ToString();
55   ASSERT_EQ(cls->tag(), Named(std::string("TestFoo")));
56 
57   ClassAttrVector attributes = cls->GetAttributes();
58   ASSERT_EQ(attributes.size(), 2);
59   for (auto &v : attributes) {
60     if (v.first == std::string("x")) {
61       ASSERT_TRUE(nullptr != dyn_cast<Float>(v.second));
62     }
63     if (v.first == std::string("y")) {
64       ASSERT_TRUE(nullptr != dyn_cast<Int>(v.second));
65     }
66   }
67 
68   std::unordered_map<std::string, ValuePtr> methods = cls->methods();
69   ASSERT_EQ(methods.size(), 4);
70   int counts = 0;
71   for (auto &v : methods) {
72     if (v.first == std::string("inf")) {
73       counts++;
74     }
75     MS_LOG(INFO) << "" << v.first;
76   }
77   ASSERT_EQ(counts, 1);
78 
79   ValuePtr obj = std::make_shared<parse::ClassObject>(fn, "TestFoo");
80 
81   ValueNodePtr fn_node = NewValueNode(obj);
82   AnfNodeConfigPtr fn_conf = std::make_shared<AnfNodeConfig>(nullptr, fn_node, nullptr, nullptr);
83   AbstractBasePtr foo = ToAbstract(obj, nullptr, fn_conf);
84   ASSERT_TRUE(foo != nullptr);
85 
86   AbstractBasePtr abstract_x = FromValue(1.1, true);
87   AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(5), true);
88 
89   auto partical_func = dyn_cast<PartialAbstractClosure>(foo);
90   AbstractBasePtrList args_spec_list = partical_func->args();
91   ASSERT_GT(args_spec_list.size(), 0);
92   AbstractScalarPtr abs_scalar = dyn_cast<AbstractScalar>(args_spec_list[0]);
93 
94   AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y};
95 
96   auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
97   ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_);
98 
99   AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list);
100   ASSERT_TRUE(nullptr != new_cls);
101 }
102 
103 }  // namespace abstract
104 }  // namespace mindspore
105