# Copyright 2022 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import platform
import subprocess
from pathlib import Path

# Command line: a single optional override for where the flatc binary lives.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--flatc",
    help="path of the Flat C compiler relative to the root directory",
)
args = parser.parse_args()

# Directory holding this script. Anchoring every path here means the script
# can be launched from any working directory.
script_path = Path(__file__).parent.resolve()

# Repository root: two levels above this script, made absolute so that every
# path derived from it is absolute as well.
root_path = script_path.parent.parent.absolute()

# The flatc executable: the --flatc value when one was given, otherwise the
# platform's default binary name.
if args.flatc:
    flatc_exe = Path(args.flatc)
elif platform.system() == "Windows":
    flatc_exe = Path("flatc.exe")
else:
    flatc_exe = Path("flatc")

# Find and assert flatc compiler is present.
if root_path in flatc_exe.parents:
    # An absolute --flatc path that lies under the repository root is rewritten
    # relative to that root before being joined back onto it below.
    flatc_exe = flatc_exe.relative_to(root_path)
flatc_path = Path(root_path, flatc_exe)
assert flatc_path.exists(), "Cannot find the flatc compiler " + str(flatc_path)


def flatc(options, cwd=script_path):
    """Execute the flatc compiler with the specified parameters.

    Args:
      options: list of argument strings appended after the flatc binary.
      cwd: working directory for the invocation (defaults to this script's
        directory).

    Raises:
      subprocess.CalledProcessError: if flatc exits with a non-zero status.
    """
    cmd = [str(flatc_path)] + options
    subprocess.check_call(cmd, cwd=str(cwd))


def reflection_fbs_path():
    """Return the path of ``reflection/reflection.fbs`` under the root."""
    return Path(root_path).joinpath("reflection", "reflection.fbs")


def make_absolute(filename, path=script_path):
    """Return ``filename`` joined onto ``path`` as an absolute path string."""
    return str(Path(path, filename).absolute())


def assert_file_exists(filename, path=script_path):
    """Assert that ``filename`` exists under ``path``; return its Path."""
    file = Path(path, filename)
    # str() keeps the message working when ``filename`` is itself a Path.
    assert file.exists(), "could not find file: " + str(filename)
    return file


def assert_file_doesnt_exists(filename, path=script_path):
    """Assert that ``filename`` does NOT exist under ``path``; return its Path."""
    file = Path(path, filename)
    assert not file.exists(), "file exists but shouldn't: " + str(filename)
    return file


def _read_text(file):
    """Return the contents of the path-like ``file`` decoded as UTF-8."""
    # Explicit encoding so results don't depend on the platform locale.
    with open(file, encoding="utf-8") as handle:
        return handle.read()


def _as_needles(needles):
    """Wrap a single string in a list; pass any other iterable through."""
    return [needles] if isinstance(needles, str) else needles


def get_file_contents(filename, path=script_path):
    """Return the text of ``filename`` located under ``path``."""
    return _read_text(Path(path, filename))


def assert_file_contains(file, needles):
    """Assert that ``file`` contains every string in ``needles``.

    Args:
      file: path-like of the file to inspect.
      needles: a single string or an iterable of strings.

    Returns:
      The ``file`` argument, so calls can be chained.
    """
    contents = _read_text(file)
    for needle in _as_needles(needles):
        # Report the path (not the file object) on failure.
        assert needle in contents, (
            "couldn't find '" + needle + "' in file: " + str(file)
        )
    return file


def assert_file_doesnt_contains(file, needles):
    """Assert that ``file`` contains none of the strings in ``needles``.

    Args:
      file: path-like of the file to inspect.
      needles: a single string or an iterable of strings.

    Returns:
      The ``file`` argument, so calls can be chained.
    """
    contents = _read_text(file)
    for needle in _as_needles(needles):
        assert needle not in contents, (
            "Found unexpected '" + needle + "' in file: " + str(file)
        )
    return file


def assert_file_and_contents(
    file, needle, doesnt_contain=None, path=script_path, unlink=True
):
    """Assert a generated file exists and has the expected contents.

    Args:
      file: file name relative to ``path``.
      needle: string(s) that must appear in the file.
      doesnt_contain: optional string(s) that must not appear in the file.
      path: directory the file is resolved against.
      unlink: when true, delete the file after the checks pass.
    """
    target = assert_file_exists(file, path)
    assert_file_contains(target, needle)
    if doesnt_contain:
        assert_file_doesnt_contains(target, doesnt_contain)
    if unlink:
        target.unlink()


def run_all(*modules):
    """Run every test callable on each of ``modules`` and tally the results.

    Each callable attribute whose name does not start with ``"__"`` is invoked
    as ``getattr(module, name)(module)``. Any exception it raises counts as a
    failure and is reported; otherwise it counts as a pass.

    Returns:
      A ``(passing, failing)`` tuple of totals across all modules.
    """
    failing = 0
    passing = 0
    for module in modules:
        methods = [
            func
            for func in dir(module)
            if callable(getattr(module, func)) and not func.startswith("__")
        ]
        module_failing = 0
        module_passing = 0
        for method in methods:
            print("{0}.{1}".format(module.__name__, method))
            try:
                getattr(module, method)(module)
            except Exception as e:
                # A bare ``assert`` has an empty message; fall back to the
                # exception type so the failure line is never blank.
                print(" [FAILED]: " + (str(e) or type(e).__name__))
                module_failing += 1
            else:
                print(" [PASSED]")
                module_passing += 1
        print(
            "{0}: {1} of {2} passed".format(
                module.__name__, module_passing, module_passing + module_failing
            )
        )
        passing += module_passing
        failing += module_failing
    return passing, failing