import os
import sys

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

import warnings

from models.llm_models.configuration_base import BaseConfig

# flake8: noqa: E721


def warning_on_one_line(message, category, filename, lineno, file=None, line=None):
    """Format a warning as a single 'Category: message' line (no file/line context)."""
    return f"{category.__name__}: {message}\n"


warnings.formatwarning = warning_on_one_line


def check_all_chunks_same_num_layer(num_blocks_per_chunk):
    """Ensure every chunk has the same decoder-layer count as the first chunk.

    Args:
        num_blocks_per_chunk: sequence of per-chunk decoder layer counts.

    Raises:
        RuntimeError: if any chunk's count differs from the first chunk's.
    """
    for i in range(1, len(num_blocks_per_chunk)):
        if num_blocks_per_chunk[i] != num_blocks_per_chunk[0]:
            print("num_blocks_per_chunk:", num_blocks_per_chunk)
            raise RuntimeError(
                "This version of the sdk doesn't support different number of "
                "decoder layers per chunk, as shape fixer stage will fail. If you require this support,"
                " please contact Mediatek sdk owner."
            )


def check_between_exclusive(num, min_, max_, message=None):
    """Validate min_ < num < max_.

    Raises:
        TypeError: if num, min_, and max_ are not all the same exact type.
        ValueError: if num is outside the exclusive range.
    """
    # Exact-type comparison is deliberate here (see module-level noqa: E721).
    if not (type(num) == type(min_) == type(max_)):
        raise TypeError(
            f"Got different types for num ({type(num)}), min ({type(min_)}), and max ({type(max_)})"
        )
    if not (min_ < num < max_):
        if message is None:
            raise ValueError(
                f"Expected number between {min_} and {max_} exclusive, but got: {num}"
            )
        else:
            raise ValueError(
                f"{message} must be between {min_} and {max_} exclusive, but got: {num}"
            )


def check_between_inclusive(num, min_, max_, message=None):
    """Validate min_ <= num <= max_.

    Raises:
        TypeError: if num, min_, and max_ are not all the same exact type.
        ValueError: if num is outside the inclusive range.
    """
    # Exact-type comparison is deliberate here (see module-level noqa: E721).
    if not (type(num) == type(min_) == type(max_)):
        raise TypeError(
            f"Got different types for num ({type(num)}), min ({type(min_)}), and max ({type(max_)})"
        )
    if not (min_ <= num <= max_):
        if message is None:
            raise ValueError(
                f"Expected number between {min_} and {max_} inclusive, but got: {num}"
            )
        else:
            raise ValueError(
                f"{message} must be between {min_} and {max_} inclusive, but got: {num}"
            )


def check_exist(file_or_folder, message=None):
    """Raise FileNotFoundError if the given path does not exist."""
    if not os.path.exists(file_or_folder):
        if message is None:
            raise FileNotFoundError(f"{file_or_folder} does not exist.")
        else:
            raise FileNotFoundError(f"{message} does not exist: {file_or_folder}")


def check_ext(file, ext, message=None):
    """Raise RuntimeError if `file` does not end with extension `ext`."""
    if not file.endswith(ext):
        if message is None:
            raise RuntimeError(f"Expected {ext} file, but got: {file}")
        else:
            raise RuntimeError(f"Expected {ext} file for {message}, but got: {file}")


def check_isdir(folder, message=None):
    """Raise if `folder` is not an existing directory.

    NOTE(review): the two branches raise different exception types
    (FileNotFoundError vs RuntimeError); kept as-is since callers may
    depend on it.
    """
    if not os.path.isdir(folder):
        if message is None:
            raise FileNotFoundError(f"{folder} is not a directory.")
        else:
            raise RuntimeError(f"Expected directory for {message}, but got: {folder}")


def check_old_arg(path):
    """Reject the pre-v0.8.0 usage of passing a weight directory as main argument."""
    if os.path.isdir(path):
        raise RuntimeError(
            "This package's main usage has changed starting from v0.8.0. Please use"
            " model's config.json as main argument instead of weight directory."
        )


def _shape_format_error(shape):
    """Build the shared 'wrong shape format' error for check_shapes."""
    # Fixed message: the original concatenation read "...needs to be ofthe format..."
    return RuntimeError(
        f"Shape {shape} is in the wrong format. Every shape needs to be of "
        "the format: xtyc where x and y are integers. (e.g. 32t512c)"
    )


def check_shapes(shapes):
    """Validate that every entry of `shapes` is of the form '<int>t<int>c'.

    Args:
        shapes: list of shape strings, e.g. ["32t512c"].

    Raises:
        TypeError: if shapes is not a list.
        RuntimeError: if any entry is malformed.
    """
    if not isinstance(shapes, list):
        raise TypeError(f"Expected shapes to be a list, but got {type(shapes)} instead")
    for shape in shapes:
        # Exactly one 't' and one 'c' must appear, splitting the string into
        # the token-count and cache-size fields.
        if shape.count("t") != 1 or shape.count("c") != 1:
            raise _shape_format_error(shape)
        try:
            _ = int(shape.split("t")[0])
        except ValueError:
            raise _shape_format_error(shape)
        try:
            _ = int(shape.split("t")[1].split("c")[0])
        except ValueError:
            raise _shape_format_error(shape)


def check_supported_model(config):
    """Verify that `config` subclasses BaseConfig and names a supported model type."""
    SUPPORTED_MODELS = [
        "llama",
        "bloom",
        "baichuan",
        "qwen",
        "qwen1.5",
        "qwen2",
        "milm",
    ]
    if not isinstance(config, BaseConfig):
        raise RuntimeError(
            f"Unsupported config class: {type(config)}. "
            "config needs to be subclassed from BaseConfig"
        )
    if config.model_type not in SUPPORTED_MODELS:
        raise RuntimeError(
            f"Unsupported model: {config.model_type}. Supported models: "
            f"{SUPPORTED_MODELS}"
        )


def check_supported_tokenizer(config):
    """Verify that `config` subclasses BaseConfig and names a supported tokenizer."""
    SUPPORTED_TOKENIZERS = [
        "default",
        "bloom",
        "baichuan",
        "gpt2",
        "gpt2_fast",
        "qwen",
        "qwen2",
        "qwen2_fast",
        "llama",
        "pretrained_fast",
    ]
    if not isinstance(config, BaseConfig):
        raise RuntimeError(
            f"Unsupported config class: {type(config)}. "
            "config needs to be subclassed from BaseConfig"
        )
    if config.tokenizer not in SUPPORTED_TOKENIZERS:
        raise RuntimeError(
            f"Unsupported tokenizer: {config.tokenizer}. Supported tokenizers: "
            f"{SUPPORTED_TOKENIZERS}"
        )


def check_tokenizer_exist(folder):
    """Verify `folder` contains both a tokenizer model file and tokenizer_config.json.

    Accepted model files: tokenizer.model, tokenizer.json, or any *.tiktoken file.

    Raises:
        FileNotFoundError: if either the model file or the config is missing.
    """
    model = config = False
    for f in os.listdir(folder):
        if f == "tokenizer.model" or f == "tokenizer.json" or f.endswith(".tiktoken"):
            model = True
        if f == "tokenizer_config.json":
            config = True
    if not model:
        raise FileNotFoundError(
            f"Tokenizer not found in {folder}. Expected tokenizer.model, "
            "tokenizer.json, or a .tiktoken file"
        )
    if not config:
        raise FileNotFoundError(
            f"Tokenizer config not found in {folder}. Expected "
            "tokenizer_config.json"
        )


def check_weights_exist(weight_dir):
    """Verify `weight_dir` contains weight files, in exactly one format.

    Weight files are pytorch_model*.bin or model*.safetensors. Mixing .bin and
    .safetensors weights (embedding .bin files excluded) is rejected.

    Raises:
        FileNotFoundError: if no weight files are found.
        RuntimeError: if both .bin and .safetensors weights are present.
    """
    files = os.listdir(weight_dir)  # hoisted: was listed three times
    weight_files = [
        f
        for f in files
        if (f.startswith("pytorch_model") and f.endswith(".bin"))
        or (f.startswith("model") and f.endswith(".safetensors"))
    ]
    if len(weight_files) == 0:
        raise FileNotFoundError(
            f"No weight files found in {weight_dir}! Weight files should be either .bin or .safetensors file types."
        )
    safetensors_l = [f for f in files if f.endswith(".safetensors")]
    bin_l = [f for f in files if f.endswith(".bin") and "embedding" not in f]
    # BUG FIX: original used bitwise `&` on the lengths (e.g. 1 & 2 == 0), which
    # silently accepted directories containing BOTH formats. Use logical `and`.
    if safetensors_l and bin_l:
        raise RuntimeError(
            "Weights should only be in either .bin or .safetensors format, not both."
        )