• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2021-2023 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
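# This file contains only a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# NOTE(review): the three lines above appear to be copied from the Sphinx
# conf.py template and do not seem to describe this module (JIT options of
# the test runner); consider removing them.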
20
21from functools import cached_property
22from typing import Dict, Optional
23
24from runner.options.decorator_value import value, _to_bool, _to_jit_preheats, _to_str
25
26
class JitOptions:
    """JIT-related runner options.

    Each option is a ``cached_property`` wrapped by the project's ``value``
    decorator, which receives a YAML path (``ark.jit.*``), a CLI option name
    and a cast function. The body of each decorated method returns a fallback
    value; presumably ``value`` uses it when neither the YAML config nor the
    CLI supplies the option (TODO confirm against
    ``runner.options.decorator_value``).
    """

    # Fallback for num_repeats; get_command_line omits num_repeats when it
    # equals this value.
    __DEFAULT_REPEATS = 0
    # Passed as ``default_if_empty`` to _to_jit_preheats for num_repeats.
    __DEFAULT_REPEATS_IF_EMPTY = 30
    # Passed as ``default_if_empty`` to _to_jit_preheats for compiler_threshold.
    __DEFAULT_THRESHOLD_IF_EMPTY = 20

    def __str__(self) -> str:
        """Return a text dump of the options produced by ``_to_str``.

        The second argument (2) is presumably an indentation/nesting level
        understood by ``_to_str`` -- TODO confirm.
        """
        return _to_str(self, 2)

    def to_dict(self) -> Dict[str, object]:
        """Return the three JIT options as a plain dictionary.

        Keys: ``"enable"``, ``"num_repeats"`` and ``"compiler_threshold"``.
        """
        return {
            "enable": self.enable,
            "num_repeats": self.num_repeats,
            "compiler_threshold": self.compiler_threshold,
        }

    @cached_property
    @value(yaml_path="ark.jit.enable", cli_name="jit", cast_to_type=_to_bool)
    def enable(self) -> bool:
        """Whether ``--jit`` is emitted by get_command_line (fallback: False)."""
        return False

    # num_repeats and compiler_threshold share the same CLI option
    # ("jit_preheat_repeats"); ``prop`` tells _to_jit_preheats which field of
    # that option to extract.
    # The lambdas live lexically inside the class body, so
    # ``JitOptions.__DEFAULT_*`` is name-mangled to
    # ``JitOptions._JitOptions__DEFAULT_*``. The lambdas only run when called,
    # i.e. after the class object exists.
    @cached_property
    @value(
        yaml_path="ark.jit.num_repeats",
        cli_name="jit_preheat_repeats",
        cast_to_type=lambda x: _to_jit_preheats(
            cli_value=x, prop="num_repeats",
            default_if_empty=JitOptions.__DEFAULT_REPEATS_IF_EMPTY)
    )
    def num_repeats(self) -> int:
        """Preheat repeat count; fallback is ``__DEFAULT_REPEATS`` (0)."""
        return JitOptions.__DEFAULT_REPEATS

    @cached_property
    @value(
        yaml_path="ark.jit.compiler_threshold",
        cli_name="jit_preheat_repeats",
        cast_to_type=lambda x: _to_jit_preheats(
            cli_value=x, prop="compiler_threshold",
            default_if_empty=JitOptions.__DEFAULT_THRESHOLD_IF_EMPTY)
    )
    def compiler_threshold(self) -> Optional[int]:
        """Compiler threshold for preheating; fallback is None (not emitted)."""
        return None

    def get_command_line(self) -> str:
        """Build the command-line fragment for the JIT options.

        Returns:
            A space-separated string made of:

            * ``--jit`` when ``enable`` is true;
            * ``--jit-preheat-repeats="<k=v,...>"`` when ``num_repeats``
              differs from ``__DEFAULT_REPEATS`` and/or
              ``compiler_threshold`` is not None.

            Returns an empty string when no option applies. There is never a
            leading or trailing space.
        """
        parts = ['--jit'] if self.enable else []

        preheat = []
        if self.num_repeats != JitOptions.__DEFAULT_REPEATS:
            preheat.append(f'num_repeats={self.num_repeats}')
        if self.compiler_threshold is not None:
            preheat.append(f'compiler_threshold={self.compiler_threshold}')

        if preheat:
            parts.append(f'--jit-preheat-repeats="{",".join(preheat)}"')
        # Joining the parts, instead of prefixing the preheat option with a
        # space, avoids a stray leading space when the JIT itself is disabled.
        return ' '.join(parts)
80