1# Lint as: python2, python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Moves source files to match Arduino library conventions.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import argparse 23import glob 24import os 25 26import six 27 28 29def rename_example_subfolder_files(library_dir): 30 """Moves source files in example subfolders to equivalents at root.""" 31 patterns = ['*.h', '*.cpp', '*.c'] 32 for pattern in patterns: 33 search_path = os.path.join(library_dir, 'examples/*/*', pattern) 34 for source_file_path in glob.glob(search_path): 35 source_file_dir = os.path.dirname(source_file_path) 36 source_file_base = os.path.basename(source_file_path) 37 new_source_file_path = source_file_dir + '_' + source_file_base 38 os.rename(source_file_path, new_source_file_path) 39 40 41def move_person_data(library_dir): 42 """Moves the downloaded person model into the examples folder.""" 43 old_person_data_path = os.path.join( 44 library_dir, 'src/tensorflow/lite/micro/tools/make/downloads/' + 45 'person_model_int8/person_detect_model_data.cpp') 46 new_person_data_path = os.path.join( 47 library_dir, 'examples/person_detection/person_detect_model_data.cpp') 48 if os.path.exists(old_person_data_path): 49 os.rename(old_person_data_path, new_person_data_path) 50 # Update include. 51 with open(new_person_data_path, 'r') as source_file: 52 file_contents = source_file.read() 53 file_contents = file_contents.replace( 54 six.ensure_str('#include "tensorflow/lite/micro/examples/' + 55 'person_detection/person_detect_model_data.h"'), 56 '#include "person_detect_model_data.h"') 57 with open(new_person_data_path, 'w') as source_file: 58 source_file.write(file_contents) 59 60 61def move_image_data_experimental(library_dir): 62 """Moves the downloaded image detection model into the examples folder.""" 63 old_image_data_path = os.path.join( 64 library_dir, 'src/tensorflow/lite/micro/tools/make/downloads/' + 65 'image_recognition_model/image_recognition_model.cpp') 66 new_image_data_path = os.path.join( 67 library_dir, 68 'examples/image_recognition_experimental/image_recognition_model.cpp') 69 if os.path.exists(old_image_data_path): 70 os.rename(old_image_data_path, new_image_data_path) 71 # Update include. 72 with open(new_image_data_path, 'r') as source_file: 73 file_contents = source_file.read() 74 file_contents = file_contents.replace( 75 six.ensure_str('#include "tensorflow/lite/micro/examples/' + 76 'image_recognition_example/image_recognition_model.h"'), 77 '#include "image_recognition_model.h"') 78 with open(new_image_data_path, 'w') as source_file: 79 source_file.write(file_contents) 80 81 82def rename_example_main_inos(library_dir): 83 """Makes sure the .ino sketch files match the example name.""" 84 search_path = os.path.join(library_dir, 'examples/*', 'main.ino') 85 for ino_path in glob.glob(search_path): 86 example_path = os.path.dirname(ino_path) 87 example_name = os.path.basename(example_path) 88 new_ino_path = os.path.join(example_path, example_name + '.ino') 89 os.rename(ino_path, new_ino_path) 90 91 92def main(unparsed_args): 93 """Control the rewriting of source files.""" 94 library_dir = unparsed_args[0] 95 rename_example_subfolder_files(library_dir) 96 rename_example_main_inos(library_dir) 97 move_person_data(library_dir) 98 move_image_data_experimental(library_dir) 99 100 101def parse_args(): 102 """Converts the raw arguments into accessible flags.""" 103 parser = argparse.ArgumentParser() 104 _, unparsed_args = parser.parse_known_args() 105 106 main(unparsed_args) 107 108 109if __name__ == '__main__': 110 parse_args() 111