# mypy: allow-untyped-defs import os import pathlib from collections import defaultdict from typing import Any, Dict, List, Set, Tuple, Union def materialize_lines(lines: List[str], indentation: int) -> str: output = "" new_line_with_indent = "\n" + " " * indentation for i, line in enumerate(lines): if i != 0: output += new_line_with_indent output += line.replace("\n", new_line_with_indent) return output def gen_from_template( dir: str, template_name: str, output_name: str, replacements: List[Tuple[str, Any, int]], ): template_path = os.path.join(dir, template_name) output_path = os.path.join(dir, output_name) with open(template_path) as f: content = f.read() for placeholder, lines, indentation in replacements: with open(output_path, "w") as f: content = content.replace( placeholder, materialize_lines(lines, indentation) ) f.write(content) def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str]: """ When given a path to a directory, returns the paths to the relevant files within it. This function does NOT recursive traverse to subdirectories. """ paths: Set[str] = set() for dir_path in dir_paths: all_files = os.listdir(dir_path) python_files = {fname for fname in all_files if ".py" == fname[-3:]} filter_files = { fname for fname in python_files if fname not in files_to_exclude } paths.update({os.path.join(dir_path, fname) for fname in filter_files}) return paths def extract_method_name(line: str) -> str: """Extract method name from decorator in the form of "@functional_datapipe({method_name})".""" if '("' in line: start_token, end_token = '("', '")' elif "('" in line: start_token, end_token = "('", "')" else: raise RuntimeError( f"Unable to find appropriate method name within line:\n{line}" ) start, end = line.find(start_token) + len(start_token), line.find(end_token) return line[start:end] def extract_class_name(line: str) -> str: """Extract class name from class definition in the form of "class {CLASS_NAME}({Type}):".""" start_token = "class " end_token = "(" start, end = line.find(start_token) + len(start_token), line.find(end_token) return line[start:end] def parse_datapipe_file( file_path: str, ) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]: """Given a path to file, parses the file and returns a dictionary of method names to function signatures.""" method_to_signature, method_to_class_name, special_output_type = {}, {}, set() doc_string_dict = defaultdict(list) with open(file_path) as f: open_paren_count = 0 method_name, class_name, signature = "", "", "" skip = False for line in f: if line.count('"""') % 2 == 1: skip = not skip if skip or '"""' in line: # Saving docstrings doc_string_dict[method_name].append(line) continue if "@functional_datapipe" in line: method_name = extract_method_name(line) doc_string_dict[method_name] = [] continue if method_name and "class " in line: class_name = extract_class_name(line) continue if method_name and ("def __init__(" in line or "def __new__(" in line): if "def __new__(" in line: special_output_type.add(method_name) open_paren_count += 1 start = line.find("(") + len("(") line = line[start:] if open_paren_count > 0: open_paren_count += line.count("(") open_paren_count -= line.count(")") if open_paren_count == 0: end = line.rfind(")") signature += line[:end] method_to_signature[method_name] = process_signature(signature) method_to_class_name[method_name] = class_name method_name, class_name, signature = "", "", "" elif open_paren_count < 0: raise RuntimeError( "open parenthesis count < 0. This shouldn't be possible." ) else: signature += line.strip("\n").strip(" ") return ( method_to_signature, method_to_class_name, special_output_type, doc_string_dict, ) def parse_datapipe_files( file_paths: Set[str], ) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]: ( methods_and_signatures, methods_and_class_names, methods_with_special_output_types, ) = ({}, {}, set()) methods_and_doc_strings = {} for path in file_paths: ( method_to_signature, method_to_class_name, methods_needing_special_output_types, doc_string_dict, ) = parse_datapipe_file(path) methods_and_signatures.update(method_to_signature) methods_and_class_names.update(method_to_class_name) methods_with_special_output_types.update(methods_needing_special_output_types) methods_and_doc_strings.update(doc_string_dict) return ( methods_and_signatures, methods_and_class_names, methods_with_special_output_types, methods_and_doc_strings, ) def split_outside_bracket(line: str, delimiter: str = ",") -> List[str]: """Given a line of text, split it on comma unless the comma is within a bracket '[]'.""" bracket_count = 0 curr_token = "" res = [] for char in line: if char == "[": bracket_count += 1 elif char == "]": bracket_count -= 1 elif char == delimiter and bracket_count == 0: res.append(curr_token) curr_token = "" continue curr_token += char res.append(curr_token) return res def process_signature(line: str) -> str: """ Clean up a given raw function signature. This includes removing the self-referential datapipe argument, default arguments of input functions, newlines, and spaces. """ tokens: List[str] = split_outside_bracket(line) for i, token in enumerate(tokens): tokens[i] = token.strip(" ") if token == "cls": tokens[i] = "self" elif i > 0 and ("self" == tokens[i - 1]) and (tokens[i][0] != "*"): # Remove the datapipe after 'self' or 'cls' unless it has '*' tokens[i] = "" elif "Callable =" in token: # Remove default argument if it is a function head, default_arg = token.rsplit("=", 2) tokens[i] = head.strip(" ") + "= ..." tokens = [t for t in tokens if t != ""] line = ", ".join(tokens) return line def get_method_definitions( file_path: Union[str, List[str]], files_to_exclude: Set[str], deprecated_files: Set[str], default_output_type: str, method_to_special_output_type: Dict[str, str], root: str = "", ) -> List[str]: """ #.pyi generation for functional DataPipes Process. # 1. Find files that we want to process (exclude the ones who don't) # 2. Parse method name and signature # 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces """ if root == "": root = str(pathlib.Path(__file__).parent.resolve()) file_path = [file_path] if isinstance(file_path, str) else file_path file_path = [os.path.join(root, path) for path in file_path] file_paths = find_file_paths( file_path, files_to_exclude=files_to_exclude.union(deprecated_files) ) ( methods_and_signatures, methods_and_class_names, methods_w_special_output_types, methods_and_doc_strings, ) = parse_datapipe_files(file_paths) for fn_name in method_to_special_output_type: if fn_name not in methods_w_special_output_types: methods_w_special_output_types.add(fn_name) method_definitions = [] for method_name, arguments in methods_and_signatures.items(): class_name = methods_and_class_names[method_name] if method_name in methods_w_special_output_types: output_type = method_to_special_output_type[method_name] else: output_type = default_output_type doc_string = "".join(methods_and_doc_strings[method_name]) if doc_string == "": doc_string = " ...\n" method_definitions.append( f"# Functional form of '{class_name}'\n" f"def {method_name}({arguments}) -> {output_type}:\n" f"{doc_string}" ) method_definitions.sort( key=lambda s: s.split("\n")[1] ) # sorting based on method_name return method_definitions # Defined outside of main() so they can be imported by TorchData iterDP_file_path: str = "iter" iterDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"} iterDP_deprecated_files: Set[str] = set() iterDP_method_to_special_output_type: Dict[str, str] = { "demux": "List[IterDataPipe]", "fork": "List[IterDataPipe]", } mapDP_file_path: str = "map" mapDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"} mapDP_deprecated_files: Set[str] = set() mapDP_method_to_special_output_type: Dict[str, str] = {"shuffle": "IterDataPipe"} def main() -> None: """ # Inject file into template datapipe.pyi.in. TODO: The current implementation of this script only generates interfaces for built-in methods. To generate interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`. """ iter_method_definitions = get_method_definitions( iterDP_file_path, iterDP_files_to_exclude, iterDP_deprecated_files, "IterDataPipe", iterDP_method_to_special_output_type, ) map_method_definitions = get_method_definitions( mapDP_file_path, mapDP_files_to_exclude, mapDP_deprecated_files, "MapDataPipe", mapDP_method_to_special_output_type, ) path = pathlib.Path(__file__).parent.resolve() replacements = [ ("${IterDataPipeMethods}", iter_method_definitions, 4), ("${MapDataPipeMethods}", map_method_definitions, 4), ] gen_from_template( dir=str(path), template_name="datapipe.pyi.in", output_name="datapipe.pyi", replacements=replacements, ) if __name__ == "__main__": main()