1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <map> 17 #include <string> 18 #include "pybind11/pybind11.h" 19 #include "utils/callbacks.h" 20 #include "common/common_test.h" 21 #include "pipeline/jit/pipeline.h" 22 #include "pipeline/jit/parse/python_adapter.h" 23 #include "transform/graph_ir/df_graph_manager.h" 24 #include "debug/draw.h" 25 #ifdef ENABLE_GE 26 #include "utils/callbacks_ge.h" 27 #endif 28 29 namespace mindspore { 30 namespace python_adapter = mindspore::parse::python_adapter; 31 32 class TestCallback : public UT::Common { 33 public: TestCallback()34 TestCallback() {} 35 }; 36 37 /* 38 * # ut and python static info not share 39 TEST_F(TestCallback, test_get_anf_tensor_shape) { 40 py::object obj = python_adapter::CallPyFn("gtest_input.pipeline.parse.parse_class", "test_get_object_graph"); 41 FuncGraphPtr func_graph = pipeline::GraphExecutorPy::GetInstance()->GetFuncGraphPy(obj); 42 transform::DfGraphManager::GetInstance().SetAnfGraph(func_graph); 43 std::shared_ptr<std::vector<int64_t>> param_shape_ptr = std::make_shared<std::vector<int64_t>>(); 44 bool get_shape = callbacks::GetParameterShape(func_graph, "weight", param_shape_ptr); 45 ASSERT_TRUE(get_shape == true); 46 } 47 48 TEST_F(TestCallback, test_checkpoint_save_op) { 49 py::object obj = python_adapter::CallPyFn("gtest_input.pipeline.parse.parse_class", "test_get_object_graph"); 50 FuncGraphPtr func_graph = pipeline::GraphExecutorPy::GetInstance()->GetFuncGraphPy(obj); 51 transform::DfGraphManager::GetInstance().SetAnfGraph(func_graph); 52 53 #define DTYPE float 54 ge::DataType dt = ge::DataType::DT_FLOAT; 55 56 std::vector<float> data1 = {1.1, 2.2, 3.3, 4.4, 6.6, 7.7, 8.8, 9.9}; 57 auto data = data1; 58 ge::Shape shape({2, 2, 2, 1}); 59 ge::Format format = ge::Format::FORMAT_NCHW; 60 ge::TensorDesc desc(shape, format, dt); 61 transform::GeTensorPtr ge_tensor_ptr = 62 std::make_shared<GeTensor>(desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(DTYPE)); 63 std::map<std::string, GeTensor> param_map; 64 param_map.insert(std::pair<std::string, GeTensor>("weight", *ge_tensor_ptr)); 65 param_map.insert(std::pair<std::string, GeTensor>("network.weight", *ge_tensor_ptr)); 66 int ret = callbacks::CheckpointSaveCallback(0, param_map); 67 MS_LOG(INFO) << "ret=" << ret; 68 ASSERT_EQ(ret, 0); 69 } 70 */ 71 72 /* 73 TEST_F(TestCallback, test_summary_save_op) { 74 py::object obj = python_adapter::CallPyFn( 75 "gtest_input.pipeline.parse.parse_class", "test_get_object_graph"); 76 FuncGraphPtr func_graph = obj.cast<FuncGraphPtr>(); 77 transform::DfGraphManager::GetInstance().SetAnfGraph(func_graph); 78 79 #define DTYPE float 80 ge::DataType dt = ge::DataType::DT_FLOAT; 81 82 float data1 = 1.1; 83 float data2 = 2.1; 84 ge::Shape shape({1, 1, 1, 1}); 85 ge::Format format = ge::Format::FORMAT_NCHW; 86 ge::TensorDesc desc(shape, format, dt); 87 GeTensorPtr ge_tensor_ptr1 = std::make_shared<GeTensor>(desc, 88 reinterpret_cast<uint8_t *>(&data1), 89 sizeof(DTYPE)); 90 GeTensorPtr ge_tensor_ptr2 = std::make_shared<GeTensor>(desc, 91 reinterpret_cast<uint8_t *>(&data2), 92 sizeof(DTYPE)); 93 std::map<std::string, GeTensor> param_map; 94 param_map.insert(std::pair<std::string, GeTensor>("x1[:Scalar]", *ge_tensor_ptr1)); 95 param_map.insert(std::pair<std::string, GeTensor>("x2[:Scalar]", *ge_tensor_ptr2)); 96 int ret = callbacks::SummarySaveCallback(0, param_map); 97 MS_LOG(INFO) << "ret=" << ret; 98 ASSERT_TRUE(ret == 0); 99 } 100 */ 101 } // namespace mindspore 102