• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2021-2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
import importlib
import logging
import subprocess
from os import path, listdir, getenv, chdir, getcwd
from typing import Dict, List, Optional

from runner.logger import Log
24
# Module logger; handed to the Log helper for every registry message.
_LOGGER = logging.getLogger("runner.plugins_registry")
26
27
class PluginsRegistry:
    """Discover runner plugins and map each plugin name to its runner class.

    Plugins are sub-directories of the ``plugins`` folder that sits next to
    this module. Extra plugin directories listed in the ``PLUGIN_PATH``
    environment variable are first symlinked into that folder
    (``load_from_env``) and removed again by ``cleanup``.
    """

    ENV_PLUGIN_PATH = "PLUGIN_PATH"
    BUILTIN_PLUGINS = "plugins"
    BUILTIN_ROOT = "runner"

    def __init__(self) -> None:
        # plugin name -> runner class
        self.registry: Dict[str, type] = {}
        # symlinks created by load_from_env(); removed by cleanup()
        self.side_plugins: List[str] = []
        # Side plugins must be linked in before the plugins folder is scanned.
        self.load_from_env()
        self.load_builtin_plugins()

    @staticmethod
    def filter_builtins(items: List[str]) -> List[str]:
        """Return *items* without the names that start with ``__``."""
        return [item for item in items if not item.startswith("__")]

    @staticmethod
    def my_dir(obj: object) -> List[str]:
        """Return ``dir(obj)`` without the names that start with ``__``."""
        return PluginsRegistry.filter_builtins(dir(obj))

    def load_plugin(self, plugin_name: str, plugin_path: str) -> None:
        """Register the runner class found in one plugin directory.

        The directory is searched for an entry whose name starts with
        ``runner_``. That module is imported as
        ``runner.plugins.<plugin_name>.<module>``. Its attributes are searched
        for a name that starts with ``Runner`` and ends, case-insensitively,
        with *plugin_name*. If either lookup finds nothing, nothing is
        registered.

        :param plugin_name: plugin directory name; also the registry key.
        :param plugin_path: filesystem path of the plugin directory.
        """
        candidates = [
            entry
            for entry in PluginsRegistry.filter_builtins(listdir(plugin_path))
            if entry.startswith("runner_")
        ]
        if not candidates:
            return
        # NOTE(review): with several runner_* entries the pick depends on
        # listdir() order, which is arbitrary.
        # splitext keeps extension-less names intact; the former
        # name[:name.rfind(".")] dropped the last character when there was no dot.
        runner_module_name = path.splitext(candidates.pop())[0]
        module_name = f"{PluginsRegistry.BUILTIN_ROOT}" \
                      f".{PluginsRegistry.BUILTIN_PLUGINS}" \
                      f".{plugin_name}.{runner_module_name}"
        # import_module returns the leaf module directly (no getattr chain needed).
        runner_module = importlib.import_module(module_name)
        suffix = plugin_name.lower()
        classes = [
            name
            for name in PluginsRegistry.my_dir(runner_module)
            if name.startswith("Runner") and name.lower().endswith(suffix)
        ]
        if classes:
            self.add(plugin_name, getattr(runner_module, classes.pop()))

    def load_builtin_plugins(self) -> None:
        """Call ``load_plugin`` for every sub-directory of the plugins folder.

        Entries whose names start with ``__`` are skipped. Other
        non-directory entries are only logged.
        """
        starting_path = path.join(path.dirname(__file__), PluginsRegistry.BUILTIN_PLUGINS)
        plugin_names = PluginsRegistry.filter_builtins(listdir(starting_path))

        for plugin_name in plugin_names:
            plugin_path = path.join(starting_path, plugin_name)
            if path.isdir(plugin_path):
                self.load_plugin(plugin_name, plugin_path)
            else:
                Log.all(_LOGGER, f"Found extra file '{plugin_path}' at plugins folder")

    def load_from_env(self) -> None:
        """Symlink the entries of ``PLUGIN_PATH`` into the plugins folder.

        ``PLUGIN_PATH`` is an ``os.pathsep``-separated list. Empty entries are
        ignored. Entries that do not exist are skipped. ``ln -s`` runs with
        the plugins folder as the working directory, so relative entries are
        resolved against that folder. The caller's working directory is
        restored afterwards. Each created link is recorded in
        ``side_plugins`` so that ``cleanup`` can remove it.

        :raises subprocess.CalledProcessError: if ``ln`` fails.
        """
        builtin_plugins_path = path.join(path.dirname(__file__), PluginsRegistry.BUILTIN_PLUGINS)
        # "".split(sep) yields [""], not [], so empty entries must be filtered;
        # otherwise an unset variable still reaches the chdir() below.
        side_plugins = [
            entry
            for entry in getenv(PluginsRegistry.ENV_PLUGIN_PATH, "").split(path.pathsep)
            if entry
        ]
        if not side_plugins:
            return

        previous_cwd = getcwd()
        chdir(builtin_plugins_path)
        try:
            for side_plugin in side_plugins:
                if not path.exists(side_plugin):
                    continue
                cmd = ["ln", "-s", side_plugin]
                subprocess.run(cmd, check=True)
                # normpath drops trailing separators; basename("a/b/") would be ""
                # and cleanup() would then target the plugins folder itself.
                self.side_plugins.append(path.join(
                    path.dirname(__file__),
                    PluginsRegistry.BUILTIN_PLUGINS,
                    path.basename(path.normpath(side_plugin)),
                ))
        finally:
            # Do not leak the directory change to the rest of the process.
            chdir(previous_cwd)

    def add(self, runner_name: str, runner: type) -> None:
        """Register *runner* under *runner_name*.

        A duplicate name is reported through ``Log.exception_and_raise``
        (presumably raising; see runner.logger).
        """
        if runner_name not in self.registry:
            self.registry[runner_name] = runner
            Log.all(_LOGGER, f"Registered plugin '{runner_name}' with class '{runner.__name__}'")
        else:
            Log.exception_and_raise(_LOGGER, f"Plugin '{runner_name}' already registered")

    def get_runner(self, name: str) -> Optional[type]:
        """Return the runner class registered under *name*, or ``None``."""
        return self.registry.get(name)

    def cleanup(self) -> None:
        """Remove the symlinks created by ``load_from_env``.

        The record is cleared after a successful removal, so calling this
        again is a no-op.

        :raises subprocess.CalledProcessError: if ``rm`` fails.
        """
        if not self.side_plugins:
            return
        cmd = ["rm"] + self.side_plugins
        subprocess.run(cmd, check=True)
        self.side_plugins.clear()
114