"""Bazel build macros and flag helpers for TensorFlow Lite.

Provides compiler/linker flag helpers, wrappers for JNI binaries and shared
objects, genrules converting between TensorFlow graphdefs, TFLite flatbuffers
and JSON, and helpers used to generate the TFLite example tests.
"""

load(
    "//tensorflow:tensorflow.bzl",
    "tf_binary_additional_srcs",
    "tf_cc_shared_object",
    "tf_cc_test",
)

def tflite_copts():
    """Returns the default compiler flags for TFLite targets.

    The value is a list concatenated with `select()`s, so the platform-specific
    part is resolved per configuration at analysis time.
    """
    copts = [
        "-DFARMHASH_NO_CXX_STRING",
    ] + select({
        str(Label("//tensorflow:android_arm64")): [
            "-std=c++11",
            "-O3",
        ],
        str(Label("//tensorflow:android_arm")): [
            "-mfpu=neon",
            "-mfloat-abi=softfp",
            "-std=c++11",
            "-O3",
        ],
        str(Label("//tensorflow:android_x86")): [
            "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
        ],
        str(Label("//tensorflow:ios_x86_64")): [
            "-msse4.1",
        ],
        str(Label("//tensorflow:windows")): [
            "/DTF_COMPILE_LIBRARY",
            "/wd4018",  # -Wno-sign-compare
        ],
        "//conditions:default": [
            "-Wno-sign-compare",
        ],
    }) + select({
        str(Label("//tensorflow:with_default_optimizations")): [],
        "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
    })

    return copts

# Linker version script that tflite_jni_binary() passes via
# -Wl,--version-script to restrict the exported symbols.
LINKER_SCRIPT = "//tensorflow/lite/java/src/main/native:version_script.lds"

def tflite_linkopts_unstripped():
    """Defines linker flags to reduce size of TFLite binary.

    These are useful when trying to investigate the relative size of the
    symbols in TFLite.

    Returns:
      a select object with proper linkopts (empty on non-Android platforms)
    """

    # In case you wonder why there's no --icf is because the gains were
    # negligible, and created potential compatibility problems.
    return select({
        "//tensorflow:android": [
            "-Wl,--no-export-dynamic",  # Only inc syms referenced by dynamic obj.
            "-Wl,--exclude-libs,ALL",  # Exclude syms in all libs from auto export.
            "-Wl,--gc-sections",  # Eliminate unused code and data.
            "-Wl,--as-needed",  # Don't link unused libs.
        ],
        "//conditions:default": [],
    })

def tflite_jni_linkopts_unstripped():
    """Defines linker flags to reduce size of TFLite binary with JNI.

    These are useful when trying to investigate the relative size of the
    symbols in TFLite. Unlike tflite_linkopts_unstripped(), this does not hide
    dynamic exports, since a JNI library must export its entry points.

    Returns:
      a select object with proper linkopts (empty on non-Android platforms)
    """

    # In case you wonder why there's no --icf is because the gains were
    # negligible, and created potential compatibility problems.
    return select({
        "//tensorflow:android": [
            "-Wl,--gc-sections",  # Eliminate unused code and data.
            "-Wl,--as-needed",  # Don't link unused libs.
        ],
        "//conditions:default": [],
    })

def tflite_symbol_opts():
    """Defines linker flags controlling libatomic and symbol stripping.

    Returns:
      select objects adding -latomic on Android and -s (strip the symbol
      table) on every configuration except //tensorflow:debug.
    """
    return select({
        "//tensorflow:android": [
            "-latomic",  # Required for some uses of ISO C++11 <atomic> in x86.
        ],
        "//conditions:default": [],
    }) + select({
        "//tensorflow:debug": [],
        "//conditions:default": [
            "-s",  # Omit symbol table, for all non debug builds
        ],
    })

def tflite_linkopts():
    """Defines linker flags to reduce size of TFLite binary."""
    return tflite_linkopts_unstripped() + tflite_symbol_opts()

def tflite_jni_linkopts():
    """Defines linker flags to reduce size of TFLite binary with JNI."""
    return tflite_jni_linkopts_unstripped() + tflite_symbol_opts()

def tflite_jni_binary(
        name,
        copts = tflite_copts(),
        linkopts = tflite_jni_linkopts(),
        linkscript = LINKER_SCRIPT,
        linkshared = 1,
        linkstatic = 1,
        testonly = 0,
        deps = [],
        srcs = []):
    """Builds a jni binary for TFLite.

    Args:
      name: Name of the cc_binary target.
      copts: Compiler flags.
      linkopts: Linker flags; the version-script flags are appended to them.
      linkscript: Label of the linker version script. It is added to `deps`
        so that `$(location ...)` can resolve it in the link flags.
      linkshared: Forwarded to cc_binary.
      linkstatic: Forwarded to cc_binary.
      testonly: Forwarded to cc_binary.
      deps: Dependencies of the binary.
      srcs: Sources of the binary.
    """
    linkopts = linkopts + [
        "-Wl,--version-script",  # Export only jni functions & classes.
        "$(location {})".format(linkscript),
    ]
    native.cc_binary(
        name = name,
        copts = copts,
        linkshared = linkshared,
        linkstatic = linkstatic,
        deps = deps + [linkscript],
        srcs = srcs,
        linkopts = linkopts,
        testonly = testonly,
    )

def tflite_cc_shared_object(
        name,
        copts = tflite_copts(),
        linkopts = [],
        linkstatic = 1,
        deps = []):
    """Builds a shared object for TFLite.

    Args:
      name: Name of the shared object target.
      copts: Compiler flags.
      linkopts: Extra linker flags; tflite_jni_linkopts() is always appended.
      linkstatic: Forwarded to tf_cc_shared_object.
      deps: Dependencies of the shared object.
    """
    tf_cc_shared_object(
        name = name,
        copts = copts,
        linkstatic = linkstatic,
        linkopts = linkopts + tflite_jni_linkopts(),
        framework_so = [],
        deps = deps,
    )

def tf_to_tflite(name, src, options, out):
    """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.

    Args:
      name: Name of rule.
      src: name of the input graphdef file.
      options: options passed to TOCO (appended after the fixed flags).
      out: name of the output flatbuffer file.
    """

    toco_cmdline = " ".join([
        "$(location //tensorflow/lite/toco:toco)",
        "--input_format=TENSORFLOW_GRAPHDEF",
        "--output_format=TFLITE",
        ("--input_file=$(location %s)" % src),
        ("--output_file=$(location %s)" % out),
    ] + options)
    native.genrule(
        name = name,
        srcs = [src],
        outs = [out],
        cmd = toco_cmdline,
        tools = ["//tensorflow/lite/toco:toco"] + tf_binary_additional_srcs(),
    )

def tflite_to_json(name, src, out):
    """Convert a TF Lite flatbuffer to JSON.

    Args:
      name: Name of rule.
      src: name of the input flatbuffer file.
      out: name of the output JSON file.
    """

    flatc = "@flatbuffers//:flatc"
    schema = "//tensorflow/lite/schema:schema.fbs"

    # flatc names its output after the input file and writes it to the -o
    # directory, so -o must be the directory mktemp actually used (it honors
    # $TMPDIR, which is not always /tmp). Temp files are removed on success.
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        cmd = ("TMP=`mktemp` && cp $(location %s) $${TMP}.bin && " +
               "$(location %s) --raw-binary --strict-json -t" +
               " -o $$(dirname $${TMP}) $(location %s) -- $${TMP}.bin && " +
               "cp $${TMP}.json $(location %s) && " +
               "rm -f $${TMP} $${TMP}.bin $${TMP}.json") %
              (src, flatc, schema, out),
        tools = [flatc],
    )

def json_to_tflite(name, src, out):
    """Convert a JSON file to TF Lite's flatbuffer.

    Args:
      name: Name of rule.
      src: name of the input JSON file.
      out: name of the output flatbuffer file.
    """

    flatc = "@flatbuffers//:flatc"

    # NOTE(review): tflite_to_json() hands flatc the schema file
    # (":schema.fbs") while this rule uses ":schema_fbs"; flatc needs the
    # .fbs file here, so confirm this label expands to it.
    schema = "//tensorflow/lite/schema:schema_fbs"

    # As in tflite_to_json(): write flatc output next to the mktemp file
    # instead of a hard-coded /tmp, and clean up temp files on success.
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        cmd = ("TMP=`mktemp` && cp $(location %s) $${TMP}.json && " +
               "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
               " -o $$(dirname $${TMP}) $(location %s) $${TMP}.json && " +
               "cp $${TMP}.bin $(location %s) && " +
               "rm -f $${TMP} $${TMP}.json $${TMP}.bin") %
              (src, flatc, schema, out),
        tools = [flatc],
    )

# This is the master list of generated examples that will be made into tests. A
# function called make_XXX_tests() must also appear in generate_examples.py.
# Disable a test by adding it to the per-conversion-mode failure lists returned
# by generated_test_models_failing().
def generated_test_models():
    """Returns the master list of generated example model names.

    Every name needs a matching make_<name>_tests() function in
    generate_examples.py. Known failures are not removed from this list;
    they are listed per conversion mode by generated_test_models_failing().
    """
    models = [
        "abs", "add", "add_n", "arg_min_max", "avg_pool",
        "batch_to_space_nd", "ceil", "concat", "constant", "control_dep",
        "conv", "conv2d_transpose", "conv_with_shared_weights",
        "conv_to_depthwiseconv_with_shared_weights",
        "cos", "depthwiseconv", "div", "elu", "equal", "exp", "expand_dims",
        "fill", "floor", "floor_div", "floor_mod", "fully_connected",
        "fused_batch_norm", "gather", "gather_nd", "gather_with_constant",
        "global_batch_norm", "greater", "greater_equal", "sum",
        "l2norm", "l2norm_shared_epsilon", "l2_pool", "leaky_relu",
        "less", "less_equal", "local_response_norm", "log_softmax", "log",
        "logical_and", "logical_or", "logical_xor", "lstm",
        "max_pool", "maximum", "mean", "minimum", "mirror_pad", "mul",
        "neg", "not_equal", "one_hot", "pack", "pad", "padv2",
        "placeholder_with_default", "prelu", "pow",
        "range", "rank", "reduce_any", "reduce_max", "reduce_min", "reduce_prod",
        "relu", "relu1", "relu6", "reshape", "resize_bilinear",
        "resolve_constant_strided_slice", "reverse_sequence", "reverse_v2",
        "rsqrt", "shape", "sigmoid", "sin", "slice", "softmax",
        "space_to_batch_nd", "space_to_depth", "sparse_to_dense",
        "split", "splitv", "sqrt", "square", "squared_difference", "squeeze",
        "strided_slice", "strided_slice_1d_exhaustive", "sub",
        "tile", "topk", "transpose", "transpose_conv",
        "unidirectional_sequence_lstm", "unidirectional_sequence_rnn",
        "unique", "unpack", "unroll_batch_matmul", "where", "zeros_like",
    ]
    return models

# List of models that fail generated tests for the conversion mode.
# If you have to disable a test, please add here with a link to the appropriate
# bug or issue.
def generated_test_models_failing(conversion_mode):
    """Returns the generated tests known to fail for `conversion_mode`.

    Args:
      conversion_mode: str. One of generated_test_conversion_modes().

    Returns:
      List of entries of generated_test_models(); generated_test_models_all()
      tags these "notap" and "manual" for the given mode.
    """
    if conversion_mode == "toco-flex":
        return [
            "lstm",  # TODO(b/117510976): Restore when lstm flex conversion works.
            "unroll_batch_matmul",  # TODO(b/123030774): Fails in 1.13 tests.
            # NOTE(review): the two entries below have no tracking bug linked.
            "unidirectional_sequence_lstm",
            "unidirectional_sequence_rnn",
        ]

    return []

def generated_test_conversion_modes():
    """Returns a list of conversion modes ("" is the default TOCO mode)."""

    # TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050.
    return ["toco-flex", ""]

def generated_test_models_all():
    """Generates a list of all tests with the different converters.

    Returns:
      List of tuples representing:
        (conversion mode, name of test, test tags, test args).
      For a non-empty conversion mode the test name is suffixed with
      "_<conversion mode>".
    """
    conversion_modes = generated_test_conversion_modes()
    tests = generated_test_models()
    options = []
    for conversion_mode in conversion_modes:
        failing_tests = generated_test_models_failing(conversion_mode)
        for test in tests:
            tags = []
            args = []
            if test in failing_tests:
                # "manual" keeps known failures out of wildcard target patterns.
                tags.append("notap")
                tags.append("manual")

            test_name = test
            if conversion_mode:
                test_name = "%s_%s" % (test, conversion_mode)

            # Flex conversion shouldn't suffer from the same conversion bugs
            # listed for the default TFLite kernel backend.
            if conversion_mode == "toco-flex":
                args.append("--ignore_known_bugs=false")
            options.append((conversion_mode, test_name, tags, args))
    return options

def gen_zip_test(name, test_name, conversion_mode, **kwargs):
    """Generate a zipped-example test and its dependent zip files.

    Args:
      name: str. Resulting cc_test target name
      test_name: str. Test targets this model. Comes from the list above.
      conversion_mode: str. Which conversion mode to run with. Comes from the
        list above. Any non-empty mode currently runs with the flex delegate.
      **kwargs: tf_cc_test kwargs
    """
    toco = "//tensorflow/lite/toco:toco"
    flags = ""
    if conversion_mode:
        # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050.
        # if conversion_mode == "pb2lite":
        #     toco = "//tensorflow/lite/experimental/pb2lite:pb2lite"
        flags = "--ignore_toco_errors --run_with_flex"

    gen_zipped_test_file(
        name = "zip_%s" % test_name,
        file = "%s.zip" % test_name,
        toco = toco,
        flags = flags,
    )
    tf_cc_test(name, **kwargs)

def gen_zipped_test_file(name, file, toco, flags):
    """Generate a zip file of tests by using :generate_examples.

    Args:
      name: str. Name of the filegroup exposing the zip. The generating
        genrule itself is named "`file`.files".
      file: str. The name of one of the generated_examples targets, e.g. "transpose"
      toco: str. Pathname of toco binary to run
      flags: str. Any additional flags to include
    """
    native.genrule(
        name = file + ".files",
        cmd = (("$(locations :generate_examples) --toco $(locations {0}) " +
                " --zip_to_output {1} {2} $(@D)").format(toco, file, flags)),
        outs = [file],
        tools = [
            ":generate_examples",
            toco,
        ],
    )

    native.filegroup(
        name = name,
        srcs = [file],
    )

def gen_selected_ops(name, model):
    """Generate the library that includes only used ops.

    Args:
      name: Name of the generated library. The generated source file is
        "<name>_registration.cc".
      model: TFLite model to interpret.
    """
    out = name + "_registration.cc"
    tool = "//tensorflow/lite/tools:generate_op_registrations"
    tflite_path = "//tensorflow/lite"
    native.genrule(
        name = name,
        srcs = [model],
        outs = [out],
        # tflite_path[2:] strips the leading "//" from the package label.
        cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") %
              (tool, model, out, tflite_path[2:]),
        tools = [tool],
    )

def flex_dep(target_op_sets):
    """Returns the flex delegate dependency if SELECT_TF_OPS is requested.

    Args:
      target_op_sets: str. Comma-separated op sets, e.g.
        "TFLITE_BUILTINS,SELECT_TF_OPS".

    Returns:
      A one-element dependency list, or an empty list.
    """
    if "SELECT_TF_OPS" in target_op_sets:
        return ["//tensorflow/lite/delegates/flex:delegate"]
    else:
        return []

def gen_model_coverage_test(src, model_name, data, failure_type, tags):
    """Generates Python test targets for testing TFLite models.

    One py_test is generated per op-set combination, named
    "model_coverage_test_<model_name>_<op sets, lowercased, "," -> "_">".

    Args:
      src: Main source file.
      model_name: Name of the model to test (must be also listed in the 'data'
        dependencies)
      data: List of BUILD targets linking the data.
      failure_type: List of failure types (none, toco, crash, inference)
        expected for the corresponding combinations of op sets
        ("TFLITE_BUILTINS", "TFLITE_BUILTINS,SELECT_TF_OPS", "SELECT_TF_OPS").
        Must have at least one entry per combination; extra entries are
        ignored.
      tags: List of strings of additional tags.
    """
    op_set_combinations = [
        "TFLITE_BUILTINS",
        "TFLITE_BUILTINS,SELECT_TF_OPS",
        "SELECT_TF_OPS",
    ]
    if len(failure_type) < len(op_set_combinations):
        fail("gen_model_coverage_test(%s): failure_type needs one entry per op set combination %s, got %s" %
             (model_name, op_set_combinations, failure_type))

    for target_op_sets, expected_failure in zip(op_set_combinations, failure_type):
        args = []
        if expected_failure != "none":
            args.append("--failure_type=%s" % expected_failure)
        native.py_test(
            name = "model_coverage_test_%s_%s" % (model_name, target_op_sets.lower().replace(",", "_")),
            srcs = [src],
            main = src,
            size = "large",
            args = [
                "--model_name=%s" % model_name,
                "--target_ops=%s" % target_op_sets,
            ] + args,
            data = data,
            srcs_version = "PY2AND3",
            tags = [
                "no_oss",
                "no_windows",
            ] + tags,
            deps = [
                "//tensorflow/lite/testing/model_coverage:model_coverage_lib",
                "//tensorflow/lite/python:lite",
                "//tensorflow/python:client_testlib",
            ] + flex_dep(target_op_sets),
        )