1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <iostream> 18 #include <memory> 19 #include "common/common_test.h" 20 21 #ifdef OPEN_SOURCE 22 #include "ge/client/ge_api.h" 23 #else 24 #include "external/ge/ge_api.h" 25 #endif 26 27 #define private public 28 #include "transform/graph_ir/df_graph_manager.h" 29 30 using UT::Common; 31 32 namespace mindspore { 33 namespace transform { 34 35 class TestDfGraphManager : public UT::Common { 36 public: 37 TestDfGraphManager() {} 38 }; 39 40 TEST_F(TestDfGraphManager, TestAPI) { 41 // test public interface: 42 DfGraphManager& graph_manager = DfGraphManager::GetInstance(); 43 ASSERT_EQ(0, graph_manager.GetAllGraphs().size()); 44 45 // test public interface: 46 std::shared_ptr<ge::Graph> ge_graph = std::make_shared<ge::Graph>(); 47 ASSERT_TRUE(graph_manager.AddGraph("test_graph", nullptr) != Status::SUCCESS); 48 graph_manager.AddGraph("test_graph", ge_graph); 49 ASSERT_EQ(1, graph_manager.GetAllGraphs().size()); 50 std::vector<DfGraphWrapperPtr> wrappers = graph_manager.GetAllGraphs(); 51 ASSERT_EQ("test_graph", wrappers.back()->name_); 52 ASSERT_EQ(ge_graph, wrappers.back()->graph_ptr_); 53 54 // test public interface: 55 DfGraphWrapperPtr wrappers2 = graph_manager.GetGraphByName("test_graph"); 56 ASSERT_EQ(ge_graph, wrappers2->graph_ptr_); 57 58 // test public interface: 59 graph_manager.ClearGraph(); 60 ASSERT_EQ(0, graph_manager.GetAllGraphs().size()); 61 62 // test public interface: 63 int id = graph_manager.GenerateId(); 64 assert(id > 0); 65 } 66 67 } // namespace transform 68 } // namespace mindspore 69