# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generate some SavedModels for use by AOT compilation tests."""

import os

from absl import app
from absl import flags

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.saved_model import save
from tensorflow.python.trackable import autotrackable


flags.DEFINE_string('out_dir', None,
                    'Directory to output saved models to.')

FLAGS = flags.FLAGS


def _create_matmul_savedmodel(out_dir, model_name, lhs_shape, rhs_shape):
  """Save a SavedModel whose function computes `matmul(x, y)` in float32.

  Shared implementation for the public `create_*_matmul_savedmodel` helpers,
  which differ only in operand shapes and in the output directory name.

  Args:
    out_dir: Directory under which the SavedModel directory is created.
    model_name: Name of the SavedModel directory inside `out_dir`.
    lhs_shape: Two-element sequence giving the static shape of operand `x`.
    rhs_shape: Two-element sequence giving the static shape of operand `y`.
  """
  root = autotrackable.AutoTrackable()
  # The lambda (rather than passing `math_ops.matmul` directly) gives the
  # wrapped function exactly two arguments named `x` and `y`, which the
  # keyword call below relies on.
  root.f = def_function.function(
      lambda x, y: math_ops.matmul(x, y),  # pylint: disable=unnecessary-lambda
      input_signature=[
          tensor_spec.TensorSpec(list(lhs_shape), dtypes.float32),
          tensor_spec.TensorSpec(list(rhs_shape), dtypes.float32),
      ])
  # Invoke the function once on zero-filled inputs of the declared shapes
  # before saving.
  root.f(x=array_ops.zeros(tuple(lhs_shape)),
         y=array_ops.zeros(tuple(rhs_shape)))
  save_dir = os.path.join(out_dir, model_name)
  save.save(root, save_dir, root.f)
  # This simple SavedModel lacks any variables, but we need to create a
  # variables.index file to make bazel genrule happy.
  with open(os.path.join(save_dir, 'variables', 'variables.index'), 'w'):
    pass


def create_large_matmul_savedmodel(out_dir):
  """Create a SavedModel that performs a large matmul.

  The model multiplies a [3000, 5000] float32 tensor by a [5000, 4000] float32
  tensor and is written to `<out_dir>/x_matmul_y_large`.

  Args:
    out_dir: Directory under which the SavedModel directory is created.
  """
  _create_matmul_savedmodel(
      out_dir, 'x_matmul_y_large', (3000, 5000), (5000, 4000))


def create_small_matmul_savedmodel(out_dir):
  """Create a SavedModel that performs a small matmul.

  The model multiplies a [3, 5] float32 tensor by a [5, 4] float32 tensor and
  is written to `<out_dir>/x_matmul_y_small`.

  Args:
    out_dir: Directory under which the SavedModel directory is created.
  """
  _create_matmul_savedmodel(out_dir, 'x_matmul_y_small', (3, 5), (5, 4))


def main(unused_args):
  """Write the small and then the large matmul SavedModel to `--out_dir`.

  Args:
    unused_args: Positional arguments supplied by `app.run`; ignored.
  """
  create_small_matmul_savedmodel(FLAGS.out_dir)
  create_large_matmul_savedmodel(FLAGS.out_dir)


if __name__ == '__main__':
  flags.mark_flag_as_required('out_dir')
  app.run(main)