• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Resolves non-system C/C++ includes to their full paths.
17
18Used to generate Arduino and ESP-IDF examples.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import argparse
26import os
27import re
28import sys
29
30import six
31
32
# Repository path prefix of the example sources. replace_esp_example_includes
# rewrites quoted #include paths that start with this prefix.
EXAMPLE_DIR_PATH = 'tensorflow/lite/micro/examples/'
34
35
def replace_arduino_includes(line, supplied_headers_list):
  """Updates any includes to reference the new Arduino library paths.

  A quoted include path is replaced by the first entry of
  `supplied_headers_list` that either equals it or ends with "/" + it, so the
  match always falls on a whole path component.

  Args:
    line: A single line of C/C++ source.
    supplied_headers_list: Header paths that may be substituted for a quoted
      include path.

  Returns:
    The line with the quoted include path resolved, keeping any text that
    follows the closing quote; other lines are returned unchanged.
  """
  # Group 1: everything up to and including the opening quote.
  # Group 2: the include path (no quotes).
  # Group 3: the closing quote plus the rest of the line (e.g. a comment).
  include_match = re.match(r'(.*#include[^"]*")([^"]*)(".*)', line)
  if not include_match:
    return line
  path = include_match.group(2)
  for supplied_header in supplied_headers_list:
    supplied_header = six.ensure_str(supplied_header)
    # Compare whole path components so that e.g. "util.h" does not resolve to
    # "third_party/foo/my_util.h".
    if supplied_header == path or supplied_header.endswith('/' + path):
      path = supplied_header
      break
  return include_match.group(1) + six.ensure_str(path) + include_match.group(3)
48
49
def replace_arduino_main(line):
  """Renames a bare `int main(` definition on this line to `tflite_micro_main`.

  Args:
    line: A single line of C/C++ source.

  Returns:
    The line with `main` replaced by `tflite_micro_main` when it matches the
    `...int main(...` pattern; otherwise the line unchanged.
  """
  matched = re.match(r'(.*int )(main)(\(.*)', line)
  if matched is None:
    return line
  before_name, _, after_name = matched.groups()
  return ''.join((before_name, 'tflite_micro_main', after_name))
56
57
def check_ino_functions(input_text):
  """Ensures the required setup() and loop() functions exist.

  Args:
    input_text: Full text of the example sketch.

  Returns:
    `input_text`, unchanged.

  Raises:
    Exception: If no `void setup()` or no `void loop()` definition is found.
  """
  # We're moving to an Arduino-friendly structure for all our examples, so they
  # have to have a setup() and loop() function, just like their IDE expects.
  # Accept any whitespace layout ("void setup(){", brace on the next line) and
  # an explicit "(void)" parameter list; the old exact-spacing form still
  # matches.
  if not re.search(r'void\s+setup\s*\(\s*(?:void\s*)?\)\s*\{', input_text):
    raise Exception(
        'All examples must have a setup() function for Arduino compatiblity\n' +
        input_text)
  if not re.search(r'void\s+loop\s*\(\s*(?:void\s*)?\)\s*\{', input_text):
    raise Exception(
        'All examples must have a loop() function for Arduino compatiblity')
  return input_text
70
71
def add_example_ino_library_include(input_text):
  """Makes sure the example includes the header that loads the library.

  Inserts `#include <TensorFlowLite.h>` and a blank line directly before the
  first line that starts (after optional spaces/tabs) with `#include `. If no
  such line exists, the header is prepended to the text instead.

  Args:
    input_text: Full text of the example sketch.

  Returns:
    The text with the library include added.
  """
  library_include = '#include <TensorFlowLite.h>\n\n'
  # Anchor on line starts so an "#include " mentioned inside a comment (e.g. a
  # license header) is not mistaken for the first real include directive.
  first_include = re.search(r'^[ \t]*#include ', input_text, flags=re.MULTILINE)
  if first_include is None:
    return library_include + input_text
  insert_at = first_include.start()
  return input_text[:insert_at] + library_include + input_text[insert_at:]
76
77
def replace_ardunio_example_includes(line, _):
  """Updates any includes for local example files.

  Because the export process moves the example source and header files out of
  their default locations into the top-level 'examples' folder in the Arduino
  library, include references have to be updated to match. The example
  directory prefix and the example name are dropped, and any remaining '/' in
  the path becomes '_'.

  Args:
    line: A single line of C/C++ source.
    _: Unused; transform_arduino_sources passes the source path here.

  Returns:
    The rewritten line, keeping any text after the closing quote, or the line
    unchanged if it has no quoted include under the example directory.
  """
  # Same prefix as the module-level EXAMPLE_DIR_PATH.
  dir_path = 'tensorflow/lite/micro/examples/'
  # Group 1: text up to and including the opening quote.
  # Group 2: path below the example's own directory.
  # Group 3: closing quote plus the rest of the line.
  include_match = re.match(
      r'(.*#include[^"]*")' + re.escape(dir_path) + r'[^/"]+/([^"]*)(".*)',
      line)
  if not include_match:
    return line
  flattened_name = include_match.group(2).replace('/', '_')
  return include_match.group(1) + flattened_name + include_match.group(3)
90
91
def replace_esp_example_includes(line, source_path):
  """Updates any includes for local example files.

  Because the export process moves the example source and header files out of
  their default locations into the top-level 'main' folder in the ESP-IDF
  project, include references have to be updated to match. A quoted include
  path starting with EXAMPLE_DIR_PATH is rewritten relative to the directory
  of `source_path`.

  Args:
    line: A single line of C/C++ source.
    source_path: Relative path of the file that contains `line`.

  Returns:
    The line with the include path made relative. Text before the opening
    quote and after the closing quote is kept, so commented-out includes stay
    commented out. Other lines are returned unchanged.
  """
  include_match = re.match(
      r'(.*#include[^"]*")(' + re.escape(EXAMPLE_DIR_PATH) + r'[^"]*)(".*)',
      line)
  if not include_match:
    return line
  # Compute the target path relative from the source's directory.
  target_path = include_match.group(2)
  source_dirname = os.path.dirname(source_path)
  rel_to_target = os.path.relpath(target_path, start=source_dirname)
  # C include paths use '/' regardless of the host OS path separator.
  rel_to_target = rel_to_target.replace(os.sep, '/')
  return include_match.group(1) + rel_to_target + include_match.group(3)
107
108
def transform_arduino_sources(input_lines, flags):
  """Transform sources for the Arduino platform.

  Args:
    input_lines: A sequence of lines from the input file to process.
    flags: Flags indicating which transformation(s) to apply; this reads
      third_party_headers, is_example_ino, is_example_source and source_path.

  Returns:
    The transformed output as a string.

  Raises:
    Exception: Propagated from check_ino_functions when an example .ino lacks
      setup() or loop().
  """
  # split() with no argument tolerates runs of spaces, tabs or newlines between
  # header paths and never yields empty entries (unlike split(' ')).
  supplied_headers_list = six.ensure_str(flags.third_party_headers).split()
  is_example = flags.is_example_ino or flags.is_example_source

  output_lines = []
  for line in input_lines:
    line = replace_arduino_includes(line, supplied_headers_list)
    if is_example:
      line = replace_ardunio_example_includes(line, flags.source_path)
    else:
      # Non-example sources may define a bare main() that must be renamed.
      line = replace_arduino_main(line)
    output_lines.append(line)
  output_text = '\n'.join(output_lines)

  if flags.is_example_ino:
    output_text = check_ino_functions(output_text)
    output_text = add_example_ino_library_include(output_text)

  return output_text
136
137
def transform_esp_sources(input_lines, flags):
  """Apply the ESP-IDF source transformations.

  Args:
    input_lines: Iterable of lines read from the input file.
    flags: Parsed flags. When `is_example_source` is set, each line goes
      through replace_esp_example_includes with `source_path`.

  Returns:
    The (possibly rewritten) lines joined with newlines.
  """
  return '\n'.join(
      replace_esp_example_includes(line, flags.source_path)
      if flags.is_example_source else line
      for line in input_lines)
156
157
def main(unused_args, flags):
  """Reads source text from stdin, transforms it, and writes it to stdout.

  Args:
    unused_args: Extra command-line arguments left over from parsing; ignored.
    flags: Parsed flags; `platform` selects the Arduino or ESP-IDF transform.
  """
  source_lines = sys.stdin.read().split('\n')
  transformers = {
      'arduino': transform_arduino_sources,
      'esp': transform_esp_sources,
  }
  transform = transformers.get(flags.platform)
  # An unrecognized platform produces empty output.
  sys.stdout.write(transform(source_lines, flags) if transform else '')
169
170
def parse_args():
  """Parses the command-line flags and runs main() with them.

  Unrecognized arguments are not an error; they are passed to main() as its
  first argument.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--platform', choices=['arduino', 'esp'], required=True,
      help='Target platform.')
  parser.add_argument(
      '--third_party_headers', type=str, default='',
      help='Space-separated list of headers to resolve.')
  boolean_options = (
      ('--is_example_ino', 'is_example_ino',
       'Whether the destination is an example main ino.'),
      ('--is_example_source', 'is_example_source',
       'Whether the destination is an example cpp or header file.'),
  )
  for option, destination, description in boolean_options:
    parser.add_argument(
        option, dest=destination, action='store_true', help=description)
  parser.add_argument(
      '--source_path', type=str, default='',
      help='The relative path of the source code file.')
  known_flags, leftover_args = parser.parse_known_args()
  main(leftover_args, known_flags)
202
203
if __name__ == '__main__':
  # Script entry point: parse the command-line flags, then run the transform.
  parse_args()
206