# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for pw_config_loader."""

from pathlib import Path
import tempfile
from typing import Any
import unittest

from pw_config_loader import yaml_config_loader_mixin
import yaml

# pylint: disable=no-member,no-self-use


def _write_yaml(path: Path, data: dict[str, Any]) -> Path:
    """Serialize ``data`` as YAML into ``path`` and return ``path``.

    The encoding is pinned to UTF-8 so the fixture bytes do not depend on the
    locale of the machine running the tests.
    """
    path.write_text(yaml.safe_dump(data), encoding='utf-8')
    return path


class YamlConfigLoader(yaml_config_loader_mixin.YamlConfigLoaderMixin):
    """Minimal loader that exposes the loaded config to the tests."""

    @property
    def config(self) -> dict[str, Any]:
        """Return the loader's ``_config`` mapping."""
        # NOTE(review): ``_config`` is presumably populated by the mixin's
        # ``config_init()``; confirm against YamlConfigLoaderMixin.
        return self._config


class TestOneFile(unittest.TestCase):
    """Tests for loading a config section from one file."""

    def setUp(self):
        self._title = 'title'

    def init(self, config: dict[str, Any]) -> dict[str, Any]:
        """Write ``config`` to a temporary user file and load it.

        Args:
            config: Complete contents of the YAML file, written as-is.

        Returns:
            The config exposed by the loader after ``config_init()``.
        """
        loader = YamlConfigLoader()
        with tempfile.TemporaryDirectory() as folder:
            loader.config_init(
                user_file=_write_yaml(Path(folder, 'foo.yaml'), config),
                config_section_title=self._title,
            )
            # Read the config while the temporary file still exists.
            return loader.config

    def test_normal(self):
        """Values nested under the section title are loaded."""
        content = {'a': 1, 'b': 2}
        config = self.init({self._title: content})
        self.assertEqual(content['a'], config['a'])
        self.assertEqual(content['b'], config['b'])

    def test_config_title(self):
        """Top-level values load when ``config_title`` names the section."""
        content = {'a': 1, 'b': 2, 'config_title': self._title}
        config = self.init(content)
        self.assertEqual(content['a'], config['a'])
        self.assertEqual(content['b'], config['b'])


class TestMultipleFiles(unittest.TestCase):
    """Tests for loading config sections from multiple files."""

    def init(
        self,
        project_config: dict[str, Any],
        project_user_config: dict[str, Any],
        user_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Write config files then read and parse them.

        Each argument is written to its own file, nested under the section
        title, and the three files are passed to ``config_init()``.

        Args:
            project_config: Section contents for the project file.
            project_user_config: Section contents for the project-user file.
            user_config: Section contents for the user file.

        Returns:
            The config exposed by the loader after ``config_init()``.
        """
        loader = YamlConfigLoader()
        title = 'title'

        with tempfile.TemporaryDirectory() as folder:
            root = Path(folder)
            loader.config_init(
                user_file=_write_yaml(
                    root / 'user.yaml', {title: user_config}
                ),
                project_user_file=_write_yaml(
                    root / 'project_user.yaml', {title: project_user_config}
                ),
                project_file=_write_yaml(
                    root / 'project.yaml', {title: project_config}
                ),
                config_section_title=title,
            )
            # Read the config while the temporary files still exist.
            return loader.config

    def test_user_override(self):
        """The user file wins over both project-level files."""
        config = self.init(
            user_config={'a': 1},
            project_user_config={'a': 2},
            project_config={'a': 3},
        )
        self.assertEqual(config['a'], 1)

    def test_project_user_override(self):
        """The project-user file wins over the project file."""
        config = self.init(
            user_config={},
            project_user_config={'a': 2},
            project_config={'a': 3},
        )
        self.assertEqual(config['a'], 2)

    def test_not_overridden(self):
        """A project value survives when no other file sets the key."""
        config = self.init(
            user_config={},
            project_user_config={},
            project_config={'a': 3},
        )
        self.assertEqual(config['a'], 3)

    def test_different_keys(self):
        """Distinct keys from all three files are all present."""
        config = self.init(
            user_config={'a': 1},
            project_user_config={'b': 2},
            project_config={'c': 3},
        )
        self.assertEqual(config['a'], 1)
        self.assertEqual(config['b'], 2)
        self.assertEqual(config['c'], 3)


class TestNestedTitle(unittest.TestCase):
    """Tests for nested config section loading."""

    def setUp(self):
        self._title = ('title', 'subtitle', 'subsubtitle', 'subsubsubtitle')

    def init(self, config: dict[str, Any]) -> dict[str, Any]:
        """Write ``config`` to a temporary user file and load it.

        Args:
            config: Complete contents of the YAML file, written as-is.

        Returns:
            The config exposed by the loader after ``config_init()``.
        """
        loader = YamlConfigLoader()
        with tempfile.TemporaryDirectory() as folder:
            loader.config_init(
                user_file=_write_yaml(Path(folder, 'foo.yaml'), config),
                config_section_title=self._title,
            )
            # Read the config while the temporary file still exists.
            return loader.config

    def test_normal(self):
        """Values nested one level per title component are loaded."""
        nested: dict[str, Any] = {'a': 1, 'b': 2}
        # Wrap from the innermost title component outwards so the first
        # component ends up as the top-level key.
        for part in reversed(self._title):
            nested = {part: nested}
        config = self.init(nested)
        self.assertEqual(config['a'], 1)
        self.assertEqual(config['b'], 2)

    def test_config_title(self):
        """A dotted ``config_title`` identifies a nested section."""
        content = {'a': 1, 'b': 2, 'config_title': '.'.join(self._title)}
        config = self.init(content)
        self.assertEqual(config['a'], 1)
        self.assertEqual(config['b'], 2)


class CustomOverloadYamlConfigLoader(
    yaml_config_loader_mixin.YamlConfigLoaderMixin
):
    """Custom config loader that implements handle_overloaded_value()."""

    @property
    def config(self) -> dict[str, Any]:
        """Return the loader's ``_config`` mapping."""
        return self._config

    def handle_overloaded_value(  # pylint: disable=no-self-use
        self,
        key: str,
        stage: yaml_config_loader_mixin.Stage,
        original_value: Any,
        overriding_value: Any,
    ) -> Any:
        """Combine an existing value with an overriding one, by key.

        Merge policy per key:
            * ``extend``: append the overriding value to the original.
            * ``extend_sort``: as ``extend``, then return a sorted list.
            * ``do_not_override``: keep a truthy original value.
            * ``max``: keep the larger of the two values.
            * anything else: the overriding value replaces the original.

        Args:
            key: Name of the config entry being set.
            stage: Stage supplying the overriding value; unused here.
            original_value: Value currently held for ``key``.
            overriding_value: Value that would replace it.

        Returns:
            The value to store for ``key``.
        """
        # NOTE(review): presumably called by the mixin only when ``key``
        # already has a value from an earlier stage; confirm.
        if key == 'extend':
            if original_value:
                return original_value + overriding_value
            return overriding_value

        if key == 'extend_sort':
            if original_value:
                result = original_value + overriding_value
            else:
                result = overriding_value
            return sorted(result)

        # A falsy original falls through to the default replacement below.
        if key == 'do_not_override' and original_value:
            return original_value

        if key == 'max':
            return max(original_value, overriding_value)

        return overriding_value


class TestOverloading(unittest.TestCase):
    """Tests for custom handling of overloaded config values."""

    def init(
        self,
        project_config: dict[str, Any],
        project_user_config: dict[str, Any],
        user_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Write config files then read and parse them.

        Each argument is written to its own file, nested under the section
        title, and the three files are passed to ``config_init()`` on a
        ``CustomOverloadYamlConfigLoader``.

        Args:
            project_config: Section contents for the project file.
            project_user_config: Section contents for the project-user file.
            user_config: Section contents for the user file.

        Returns:
            The config exposed by the loader after ``config_init()``.
        """
        loader = CustomOverloadYamlConfigLoader()
        title = 'title'

        with tempfile.TemporaryDirectory() as folder:
            root = Path(folder)
            loader.config_init(
                user_file=_write_yaml(
                    root / 'user.yaml', {title: user_config}
                ),
                project_user_file=_write_yaml(
                    root / 'project_user.yaml', {title: project_user_config}
                ),
                project_file=_write_yaml(
                    root / 'project.yaml', {title: project_config}
                ),
                config_section_title=title,
            )
            # Read the config while the temporary files still exist.
            return loader.config

    def test_lists(self):
        """List values follow the per-key merge policy."""
        config = self.init(
            project_config={
                'extend': list('abc'),
                'extend_sort': list('az'),
                'do_not_override': ['persists'],
                'override': ['hidden'],
            },
            project_user_config={
                'extend': list('def'),
                'extend_sort': list('by'),
                'do_not_override': ['ignored'],
                'override': ['ignored'],
            },
            user_config={
                'extend': list('ghi'),
                'extend_sort': list('cx'),
                'do_not_override': ['ignored_2'],
                'override': ['overrides'],
            },
        )
        self.assertEqual(config['extend'], list('abcdefghi'))
        self.assertEqual(config['extend_sort'], list('abcxyz'))
        self.assertEqual(config['do_not_override'], ['persists'])
        self.assertEqual(config['override'], ['overrides'])

    def test_scalars(self):
        """Scalar values follow the per-key merge policy."""
        config = self.init(
            project_config={'extend': 'abc', 'max': 1},
            project_user_config={'extend': 'def', 'max': 3},
            user_config={'extend': 'ghi', 'max': 2},
        )
        self.assertEqual(config['extend'], 'abcdefghi')
        self.assertEqual(config['max'], 3)


if __name__ == '__main__':
    unittest.main()