1# Lint as: python2, python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Resolves non-system C/C++ includes to their full paths. 17 18Used to generate Arduino and ESP-IDF examples. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import argparse 26import os 27import re 28import sys 29 30import six 31 32 33EXAMPLE_DIR_PATH = 'tensorflow/lite/micro/examples/' 34 35 36def replace_arduino_includes(line, supplied_headers_list): 37 """Updates any includes to reference the new Arduino library paths.""" 38 include_match = re.match(r'(.*#include.*")(.*)(")', line) 39 if include_match: 40 path = include_match.group(2) 41 for supplied_header in supplied_headers_list: 42 if six.ensure_str(supplied_header).endswith(path): 43 path = supplied_header 44 break 45 line = include_match.group(1) + six.ensure_str(path) + include_match.group( 46 3) 47 return line 48 49 50def replace_arduino_main(line): 51 """Updates any occurences of a bare main definition to the Arduino equivalent.""" 52 main_match = re.match(r'(.*int )(main)(\(.*)', line) 53 if main_match: 54 line = main_match.group(1) + 'tflite_micro_main' + main_match.group(3) 55 return line 56 57 58def check_ino_functions(input_text): 59 """Ensures the required functions exist.""" 60 # We're moving to an Arduino-friendly structure for all our examples, so they 61 # have to have a setup() and loop() function, just like their IDE expects. 62 if not re.search(r'void setup\(\) \{', input_text): 63 raise Exception( 64 'All examples must have a setup() function for Arduino compatiblity\n' + 65 input_text) 66 if not re.search(r'void loop\(\) \{', input_text): 67 raise Exception( 68 'All examples must have a loop() function for Arduino compatiblity') 69 return input_text 70 71 72def add_example_ino_library_include(input_text): 73 """Makes sure the example includes the header that loads the library.""" 74 return re.sub(r'#include ', '#include <TensorFlowLite.h>\n\n#include ', 75 input_text, 1) 76 77 78def replace_ardunio_example_includes(line, _): 79 """Updates any includes for local example files.""" 80 # Because the export process moves the example source and header files out of 81 # their default locations into the top-level 'examples' folder in the Arduino 82 # library, we have to update any include references to match. 83 dir_path = 'tensorflow/lite/micro/examples/' 84 include_match = re.match( 85 r'(.*#include.*")' + six.ensure_str(dir_path) + r'([^/]+)/(.*")', line) 86 if include_match: 87 flattened_name = re.sub(r'/', '_', include_match.group(3)) 88 line = include_match.group(1) + flattened_name 89 return line 90 91 92def replace_esp_example_includes(line, source_path): 93 """Updates any includes for local example files.""" 94 # Because the export process moves the example source and header files out of 95 # their default locations into the top-level 'main' folder in the ESP-IDF 96 # project, we have to update any include references to match. 97 include_match = re.match(r'.*#include.*"(' + EXAMPLE_DIR_PATH + r'.*)"', line) 98 99 if include_match: 100 # Compute the target path relative from the source's directory 101 target_path = include_match.group(1) 102 source_dirname = os.path.dirname(source_path) 103 rel_to_target = os.path.relpath(target_path, start=source_dirname) 104 105 line = '#include "%s"' % rel_to_target 106 return line 107 108 109def transform_arduino_sources(input_lines, flags): 110 """Transform sources for the Arduino platform. 111 112 Args: 113 input_lines: A sequence of lines from the input file to process. 114 flags: Flags indicating which transformation(s) to apply. 115 116 Returns: 117 The transformed output as a string. 118 """ 119 supplied_headers_list = six.ensure_str(flags.third_party_headers).split(' ') 120 121 output_lines = [] 122 for line in input_lines: 123 line = replace_arduino_includes(line, supplied_headers_list) 124 if flags.is_example_ino or flags.is_example_source: 125 line = replace_ardunio_example_includes(line, flags.source_path) 126 else: 127 line = replace_arduino_main(line) 128 output_lines.append(line) 129 output_text = '\n'.join(output_lines) 130 131 if flags.is_example_ino: 132 output_text = check_ino_functions(output_text) 133 output_text = add_example_ino_library_include(output_text) 134 135 return output_text 136 137 138def transform_esp_sources(input_lines, flags): 139 """Transform sources for the ESP-IDF platform. 140 141 Args: 142 input_lines: A sequence of lines from the input file to process. 143 flags: Flags indicating which transformation(s) to apply. 144 145 Returns: 146 The transformed output as a string. 147 """ 148 output_lines = [] 149 for line in input_lines: 150 if flags.is_example_source: 151 line = replace_esp_example_includes(line, flags.source_path) 152 output_lines.append(line) 153 154 output_text = '\n'.join(output_lines) 155 return output_text 156 157 158def main(unused_args, flags): 159 """Transforms the input source file to work when exported as example.""" 160 input_file_lines = sys.stdin.read().split('\n') 161 162 output_text = '' 163 if flags.platform == 'arduino': 164 output_text = transform_arduino_sources(input_file_lines, flags) 165 elif flags.platform == 'esp': 166 output_text = transform_esp_sources(input_file_lines, flags) 167 168 sys.stdout.write(output_text) 169 170 171def parse_args(): 172 """Converts the raw arguments into accessible flags.""" 173 parser = argparse.ArgumentParser() 174 parser.add_argument( 175 '--platform', 176 choices=['arduino', 'esp'], 177 required=True, 178 help='Target platform.') 179 parser.add_argument( 180 '--third_party_headers', 181 type=str, 182 default='', 183 help='Space-separated list of headers to resolve.') 184 parser.add_argument( 185 '--is_example_ino', 186 dest='is_example_ino', 187 action='store_true', 188 help='Whether the destination is an example main ino.') 189 parser.add_argument( 190 '--is_example_source', 191 dest='is_example_source', 192 action='store_true', 193 help='Whether the destination is an example cpp or header file.') 194 parser.add_argument( 195 '--source_path', 196 type=str, 197 default='', 198 help='The relative path of the source code file.') 199 flags, unparsed = parser.parse_known_args() 200 201 main(unparsed, flags) 202 203 204if __name__ == '__main__': 205 parse_args() 206