# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_context """
import os
import shutil
import pytest

from mindspore import context
from mindspore._c_expression import security
from tests.security_utils import security_off_wrap


# pylint: disable=W0212
# W0212: protected-access


def setup_module(module):
    """Start every test in this module from PYNATIVE_MODE.

    Args:
        module: the module object supplied by pytest; unused here.
    """
    context.set_context(mode=context.PYNATIVE_MODE)


def test_contex_create_context():
    """Create a context object when none is cached, then clear the cache.

    ``context._k_context`` is presumably the module-level cached context
    instance (it is checked for None before ``context._context()`` is called)
    -- confirm against mindspore.context.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    if context._k_context is None:
        ctx = context._context()
        assert ctx is not None
    # Reset the cached context so it does not carry over as-is.
    context._k_context = None


def test_set_save_graphs_in_security():
    """In security mode, setting save_graphs raises ValueError ("not supported")."""
    if security.enable_security():
        with pytest.raises(ValueError) as err:
            context.set_context(save_graphs=True)
        assert "not supported" in str(err.value)


def test_set_save_graphs_path_in_security():
    """In security mode, setting save_graphs_path raises ValueError ("not supported")."""
    if security.enable_security():
        with pytest.raises(ValueError) as err:
            context.set_context(save_graphs_path="ir_files")
        assert "not supported" in str(err.value)


def test_switch_mode():
    """Switching between GRAPH_MODE and PYNATIVE_MODE is reflected by get_context."""
    context.set_context(mode=context.GRAPH_MODE)
    assert context.get_context("mode") == context.GRAPH_MODE
    context.set_context(mode=context.PYNATIVE_MODE)
    assert context.get_context("mode") == context.PYNATIVE_MODE


def test_set_device_id():
    """An int device_id is accepted; a str device_id raises TypeError."""
    context.set_context(device_id=1)
    # Only the call that is expected to fail goes inside pytest.raises, so an
    # unexpected TypeError from the valid call above cannot pass silently.
    with pytest.raises(TypeError):
        context.set_context(device_id="cpu")
    # The rejected call must leave the previously set value unchanged.
    assert context.get_context("device_id") == 1


def test_device_target():
    """A non-str device_target raises TypeError; "GPU" and "Ascend" are stored."""
    with pytest.raises(TypeError):
        context.set_context(device_target=123)
    context.set_context(device_target="GPU")
    assert context.get_context("device_target") == "GPU"
    context.set_context(device_target="Ascend")
    assert context.get_context("device_target") == "Ascend"
    # NOTE(review): relies on test_set_device_id having run earlier in this
    # module (it sets device_id=1); this assertion is order-dependent.
    assert context.get_context("device_id") == 1


def test_variable_memory_max_size():
    """Bool/int variable_memory_max_size raise TypeError; string values are set."""
    with pytest.raises(TypeError):
        context.set_context(variable_memory_max_size=True)
    with pytest.raises(TypeError):
        context.set_context(variable_memory_max_size=1)
    context.set_context(variable_memory_max_size="1G")
    # presumably ``__wrapped__`` calls set_context without the decorator
    # applied to it -- confirm against mindspore.context.
    context.set_context.__wrapped__(variable_memory_max_size="3GB")


def test_max_device_memory_size():
    """Bool/int max_device_memory raise TypeError; string values are set."""
    with pytest.raises(TypeError):
        context.set_context(max_device_memory=True)
    with pytest.raises(TypeError):
        context.set_context(max_device_memory=1)
    context.set_context(max_device_memory="3.5G")
    # presumably bypasses the decorator applied to set_context -- confirm.
    context.set_context.__wrapped__(max_device_memory="3GB")


def test_print_file_path():
    """Passing a directory ("./") as print_file_path raises IOError."""
    with pytest.raises(IOError):
        context.set_context(print_file_path="./")


@security_off_wrap
def test_set_context():
    """Set several context options at once and read each one back.

    Creates the ``mindspore_ir_path`` directory, which teardown_module removes.
    ``security_off_wrap`` presumably skips this test when security mode is
    enabled -- confirm against tests.security_utils.
    """
    context.set_context.__wrapped__(device_id=0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                        save_graphs=True, save_graphs_path="mindspore_ir_path")
    assert context.get_context("device_id") == 0
    assert context.get_context("device_target") == "Ascend"
    assert context.get_context("save_graphs")
    assert os.path.exists("mindspore_ir_path")
    # The stored path may be normalized; only require that it names the dir.
    assert "mindspore_ir_path" in context.get_context("save_graphs_path")
    assert context.get_context("mode") == context.GRAPH_MODE

    context.set_context(mode=context.PYNATIVE_MODE)
    assert context.get_context("mode") == context.PYNATIVE_MODE
    assert context.get_context("device_target") == "Ascend"

    # Unknown option names are rejected.
    with pytest.raises(ValueError):
        context.set_context(modex="ge")


def teardown_module():
    """Remove artifacts the tests in this module write to the working directory."""
    for item in ['mindspore_ir_path']:
        item_name = os.path.join('.', item)
        # isdir/isfile are both False for a missing path, so nothing is
        # touched when the artifact was never created.
        if os.path.isdir(item_name):
            shutil.rmtree(item_name)
        elif os.path.isfile(item_name):
            os.remove(item_name)