• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
"""Bazel macros and helpers for TensorFlow Lite builds.

Covers compiler/linker flag sets, JNI and shared-object binaries, genrules
that convert models (TensorFlow GraphDef -> TFLite via TOCO, TFLite <-> JSON
via flatc), op-registration generation, and the lists that drive the
generated-example and model-coverage tests.
"""
2
3load(
4    "//tensorflow:tensorflow.bzl",
5    "tf_binary_additional_srcs",
6    "tf_cc_shared_object",
7    "tf_cc_test",
8)
9
def tflite_copts():
    """Returns the compiler flags shared by TFLite C/C++ targets.

    The result concatenates, in order:
      1. flags applied on every platform,
      2. a per-platform select() (Android ARM64/ARM/x86, iOS x86_64,
         Windows, everything else),
      3. a select() that defines GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK unless
         //tensorflow:with_default_optimizations is active.

    Returns:
      A select() expression suitable for a `copts` attribute.
    """
    common_flags = ["-DFARMHASH_NO_CXX_STRING"]

    per_platform_flags = select({
        str(Label("//tensorflow:android_arm64")): [
            "-std=c++11",
            "-O3",
        ],
        str(Label("//tensorflow:android_arm")): [
            "-mfpu=neon",
            "-mfloat-abi=softfp",
            "-std=c++11",
            "-O3",
        ],
        str(Label("//tensorflow:android_x86")): [
            "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
        ],
        str(Label("//tensorflow:ios_x86_64")): [
            "-msse4.1",
        ],
        str(Label("//tensorflow:windows")): [
            "/DTF_COMPILE_LIBRARY",
            "/wd4018",  # MSVC counterpart of -Wno-sign-compare.
        ],
        "//conditions:default": [
            "-Wno-sign-compare",
        ],
    })

    # Builds without the default optimizations get the scalar-fallback define.
    optimization_flags = select({
        str(Label("//tensorflow:with_default_optimizations")): [],
        "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
    })

    return common_flags + per_platform_flags + optimization_flags
44
# Default linker version script for tflite_jni_binary(). It is passed to the
# linker with -Wl,--version-script to limit the exported symbols to the JNI
# functions and classes.
LINKER_SCRIPT = "//tensorflow/lite/java/src/main/native:version_script.lds"
46
def tflite_linkopts_unstripped():
    """Returns size-reducing linker flags for TFLite, without stripping.

    Symbols are kept (no -s), which makes these flags useful when
    investigating the relative size of the symbols in a TFLite binary.

    Returns:
      A select() expression with the proper linkopts.
    """

    # --icf (identical code folding) is deliberately absent: the gains were
    # negligible and it created potential compatibility problems.
    android_flags = [
        "-Wl,--no-export-dynamic",  # Only include syms referenced by dynamic objects.
        "-Wl,--exclude-libs,ALL",  # Exclude syms in all libs from auto export.
        "-Wl,--gc-sections",  # Eliminate unused code and data.
        "-Wl,--as-needed",  # Don't link unused libs.
    ]
    return select({
        "//tensorflow:android": android_flags,
        "//conditions:default": [],
    })
68
def tflite_jni_linkopts_unstripped():
    """Returns size-reducing linker flags for TFLite JNI builds, without stripping.

    Only section garbage collection and --as-needed are applied, and only on
    Android. Symbols are kept (no -s), which makes these flags useful when
    investigating the relative size of the symbols in a TFLite binary.

    Returns:
      A select() expression with the proper linkopts.
    """

    # --icf (identical code folding) is deliberately absent: the gains were
    # negligible and it created potential compatibility problems.
    android_flags = [
        "-Wl,--gc-sections",  # Eliminate unused code and data.
        "-Wl,--as-needed",  # Don't link unused libs.
    ]
    return select({
        "//tensorflow:android": android_flags,
        "//conditions:default": [],
    })
88
def tflite_symbol_opts():
    """Returns linker flags that control symbol handling.

    Android builds link libatomic. Every build except //tensorflow:debug
    passes -s so the symbol table is omitted.

    Returns:
      A select() expression with the proper linkopts.
    """
    atomic_flags = select({
        "//tensorflow:android": [
            "-latomic",  # Required for some uses of ISO C++11 <atomic> in x86.
        ],
        "//conditions:default": [],
    })
    strip_flags = select({
        "//tensorflow:debug": [],
        "//conditions:default": [
            "-s",  # Omit the symbol table for all non-debug builds.
        ],
    })
    return atomic_flags + strip_flags
102
def tflite_linkopts():
    """Returns linker flags that shrink TFLite binaries, stripped except in debug builds."""
    size_flags = tflite_linkopts_unstripped()
    symbol_flags = tflite_symbol_opts()
    return size_flags + symbol_flags
106
def tflite_jni_linkopts():
    """Returns linker flags that shrink TFLite JNI binaries, stripped except in debug builds."""
    size_flags = tflite_jni_linkopts_unstripped()
    symbol_flags = tflite_symbol_opts()
    return size_flags + symbol_flags
110
def tflite_jni_binary(
        name,
        copts = tflite_copts(),
        linkopts = tflite_jni_linkopts(),
        linkscript = LINKER_SCRIPT,
        linkshared = 1,
        linkstatic = 1,
        testonly = 0,
        deps = [],
        srcs = []):
    """Builds a JNI binary (cc_binary) for TFLite.

    The linker version script `linkscript` is appended to `linkopts` via
    -Wl,--version-script. It is also added to `deps`, which lets the
    `$(location ...)` reference to it resolve.

    Args:
      name: Name of the cc_binary target.
      copts: Compiler flags; defaults to tflite_copts().
      linkopts: Linker flags; defaults to tflite_jni_linkopts().
      linkscript: Label of the linker version script.
      linkshared: Forwarded to cc_binary.
      linkstatic: Forwarded to cc_binary.
      testonly: Forwarded to cc_binary.
      deps: Dependencies of the binary.
      srcs: Source files of the binary.
    """
    version_script_flags = [
        "-Wl,--version-script",  # Export only jni functions & classes.
        "$(location %s)" % linkscript,
    ]
    native.cc_binary(
        name = name,
        srcs = srcs,
        deps = deps + [linkscript],
        copts = copts,
        linkopts = linkopts + version_script_flags,
        linkshared = linkshared,
        linkstatic = linkstatic,
        testonly = testonly,
    )
136
def tflite_cc_shared_object(
        name,
        copts = tflite_copts(),
        linkopts = [],
        linkstatic = 1,
        deps = []):
    """Builds a shared object for TFLite via tf_cc_shared_object.

    The caller's `linkopts` are followed by tflite_jni_linkopts(), and no
    framework shared objects are linked in (framework_so is empty).

    Args:
      name: Name of the shared-object target.
      copts: Compiler flags; defaults to tflite_copts().
      linkopts: Extra linker flags placed before tflite_jni_linkopts().
      linkstatic: Forwarded to tf_cc_shared_object.
      deps: Dependencies of the shared object.
    """
    combined_linkopts = linkopts + tflite_jni_linkopts()
    tf_cc_shared_object(
        name = name,
        deps = deps,
        copts = copts,
        linkopts = combined_linkopts,
        linkstatic = linkstatic,
        framework_so = [],
    )
152
def tf_to_tflite(name, src, options, out):
    """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.

    Declares a genrule that runs TOCO with GraphDef input and TFLite output;
    `options` are appended verbatim after the fixed arguments.

    Args:
      name: Name of rule.
      src: name of the input graphdef file.
      options: options passed to TOCO.
      out: name of the output flatbuffer file.
    """
    toco = "//tensorflow/lite/toco:toco"
    toco_args = [
        "$(location %s)" % toco,
        "--input_format=TENSORFLOW_GRAPHDEF",
        "--output_format=TFLITE",
        "--input_file=$(location %s)" % src,
        "--output_file=$(location %s)" % out,
    ]
    native.genrule(
        name = name,
        srcs = [src],
        outs = [out],
        cmd = " ".join(toco_args + options),
        tools = [toco] + tf_binary_additional_srcs(),
    )
177
def tflite_to_json(name, src, out):
    """Convert a TF Lite flatbuffer to JSON.

    flatc runs inside a private scratch directory created with `mktemp -d`.
    Its output goes to that same directory, so the result does not depend
    on $TMPDIR matching /tmp. The directory is removed when the command
    exits, whether it succeeds or fails.

    Args:
      name: Name of rule.
      src: name of the input flatbuffer file.
      out: name of the output JSON file.
    """

    flatc = "@flatbuffers//:flatc"
    schema = "//tensorflow/lite/schema:schema.fbs"
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        # `$$` escapes `$` from Bazel's Make-variable expansion, so the shell
        # sees `$(mktemp -d)` and `${TMP_DIR}`. flatc names its output after
        # the input's base name: model.bin -> model.json.
        cmd = ("TMP_DIR=$$(mktemp -d) && " +
               "trap 'rm -rf $${TMP_DIR}' EXIT && " +
               "cp $(location %s) $${TMP_DIR}/model.bin && " +
               "$(location %s) --raw-binary --strict-json -t" +
               " -o $${TMP_DIR} $(location %s) -- $${TMP_DIR}/model.bin && " +
               "cp $${TMP_DIR}/model.json $(location %s)") %
              (src, flatc, schema, out),
        tools = [flatc],
    )
200
def json_to_tflite(name, src, out):
    """Convert a JSON file to TF Lite's flatbuffer.

    flatc runs inside a private scratch directory created with `mktemp -d`.
    Its output goes to that same directory, so the result does not depend
    on $TMPDIR matching /tmp. The directory is removed when the command
    exits, whether it succeeds or fails.

    Args:
      name: Name of rule.
      src: name of the input JSON file.
      out: name of the output flatbuffer file.
    """

    flatc = "@flatbuffers//:flatc"

    # flatc's schema argument must be the .fbs file itself (the same file
    # tflite_to_json uses), not the `schema_fbs` target, which is presumably
    # the generated flatbuffer library.
    schema = "//tensorflow/lite/schema:schema.fbs"
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        # `$$` escapes `$` from Bazel's Make-variable expansion, so the shell
        # sees `$(mktemp -d)` and `${TMP_DIR}`. flatc names its output after
        # the input's base name: model.json -> model.bin.
        cmd = ("TMP_DIR=$$(mktemp -d) && " +
               "trap 'rm -rf $${TMP_DIR}' EXIT && " +
               "cp $(location %s) $${TMP_DIR}/model.json && " +
               "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
               " -o $${TMP_DIR} $(location %s) $${TMP_DIR}/model.json && " +
               "cp $${TMP_DIR}/model.bin $(location %s)") %
              (src, flatc, schema, out),
        tools = [flatc],
    )
223
# Master list of generated examples that will be made into tests. For every
# entry, a function called make_XXX_tests() must also appear in
# generate_examples.py. To disable a test, add it to the per-conversion-mode
# list returned by generated_test_models_failing().
def generated_test_models():
    """Returns the names of all generated-example test models, in order."""
    models = [
        "abs",
        "add",
        "add_n",
        "arg_min_max",
        "avg_pool",
        "batch_to_space_nd",
        "ceil",
        "concat",
        "constant",
        "control_dep",
        "conv",
        "conv2d_transpose",
        "conv_with_shared_weights",
        "conv_to_depthwiseconv_with_shared_weights",
        "cos",
        "depthwiseconv",
        "div",
        "elu",
        "equal",
        "exp",
        "expand_dims",
        "fill",
        "floor",
        "floor_div",
        "floor_mod",
        "fully_connected",
        "fused_batch_norm",
        "gather",
        "gather_nd",
        "gather_with_constant",
        "global_batch_norm",
        "greater",
        "greater_equal",
        "sum",
        "l2norm",
        "l2norm_shared_epsilon",
        "l2_pool",
        "leaky_relu",
        "less",
        "less_equal",
        "local_response_norm",
        "log_softmax",
        "log",
        "logical_and",
        "logical_or",
        "logical_xor",
        "lstm",
        "max_pool",
        "maximum",
        "mean",
        "minimum",
        "mirror_pad",
        "mul",
        "neg",
        "not_equal",
        "one_hot",
        "pack",
        "pad",
        "padv2",
        "placeholder_with_default",
        "prelu",
        "pow",
        "range",
        "rank",
        "reduce_any",
        "reduce_max",
        "reduce_min",
        "reduce_prod",
        "relu",
        "relu1",
        "relu6",
        "reshape",
        "resize_bilinear",
        "resolve_constant_strided_slice",
        "reverse_sequence",
        "reverse_v2",
        "rsqrt",
        "shape",
        "sigmoid",
        "sin",
        "slice",
        "softmax",
        "space_to_batch_nd",
        "space_to_depth",
        "sparse_to_dense",
        "split",
        "splitv",
        "sqrt",
        "square",
        "squared_difference",
        "squeeze",
        "strided_slice",
        "strided_slice_1d_exhaustive",
        "sub",
        "tile",
        "topk",
        "transpose",
        "transpose_conv",
        "unidirectional_sequence_lstm",
        "unidirectional_sequence_rnn",
        "unique",
        "unpack",
        "unroll_batch_matmul",
        "where",
        "zeros_like",
    ]
    return models
336
# List of models that fail generated tests for each conversion mode.
# If you have to disable a test, add it here with a link to the appropriate
# bug or issue.
def generated_test_models_failing(conversion_mode):
    """Returns the models whose generated tests are known to fail.

    Args:
      conversion_mode: str. Conversion mode, as returned by
        generated_test_conversion_modes().

    Returns:
      List of model names; empty for every mode except "toco-flex".
    """
    if conversion_mode != "toco-flex":
        return []
    return [
        "lstm",  # TODO(b/117510976): Restore when lstm flex conversion works.
        "unroll_batch_matmul",  # TODO(b/123030774): Fails in 1.13 tests.
        "unidirectional_sequence_lstm",
        "unidirectional_sequence_rnn",
    ]
350
def generated_test_conversion_modes():
    """Returns the conversion modes that generated tests are run under.

    The empty string is the default conversion and adds no suffix to test
    names; "toco-flex" tests run with the Flex delegate.
    """

    # TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050.
    modes = ["toco-flex", ""]
    return modes
356
def generated_test_models_all():
    """Generates a list of all tests with the different converters.

    Every model from generated_test_models() is paired with every mode from
    generated_test_conversion_modes(). Each pairing is adjusted as follows:
      - A model listed as failing for a mode is tagged "notap" and "manual".
      - A non-empty mode is appended to the test name as "_<mode>".
      - The "toco-flex" mode also passes --ignore_known_bugs=false.

    Returns:
      List of tuples representing:
            (conversion mode, name of test, test tags, test args).
    """
    modes = generated_test_conversion_modes()
    models = generated_test_models()
    result = []
    for mode in modes:
        known_failures = generated_test_models_failing(mode)
        for model in models:
            test_name = "%s_%s" % (model, mode) if mode else model
            test_tags = ["notap", "manual"] if model in known_failures else []

            # Flex conversion shouldn't suffer from the same conversion bugs
            # listed for the default TFLite kernel backend.
            test_args = ["--ignore_known_bugs=false"] if mode == "toco-flex" else []
            result.append((mode, test_name, test_tags, test_args))
    return result
384
def gen_zip_test(name, test_name, conversion_mode, **kwargs):
    """Generate a zipped-example test and its dependent zip files.

    Creates the zip through gen_zipped_test_file() (a filegroup named
    "zip_<test_name>" wrapping "<test_name>.zip"), then a tf_cc_test named
    `name`.

    Args:
      name: str. Resulting cc_test target name
      test_name: str. Test targets this model. Comes from the list above.
      conversion_mode: str. Which conversion mode to run with. Comes from the
        list above.
      **kwargs: tf_cc_test kwargs
    """
    # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050.
    # if conversion_mode == "pb2lite":
    #     toco = "//tensorflow/lite/experimental/pb2lite:pb2lite"
    toco = "//tensorflow/lite/toco:toco"

    # Any non-empty conversion mode passes these extra flags to the generator.
    flags = "--ignore_toco_errors --run_with_flex" if conversion_mode else ""

    gen_zipped_test_file(
        name = "zip_" + test_name,
        file = test_name + ".zip",
        toco = toco,
        flags = flags,
    )
    tf_cc_test(name, **kwargs)
410
def gen_zipped_test_file(name, file, toco, flags):
    """Generate a zip file of tests by using :generate_examples.

    Two targets are declared:
      - a genrule named "<file>.files" that runs :generate_examples to
        produce `file`;
      - a filegroup named `name` that wraps that output.

    Args:
      name: str. Name of the filegroup exposing the zip file.
      file: str. Output file name passed to --zip_to_output, e.g.
        "transpose.zip" (the form gen_zip_test passes).
      toco: str. Label of the converter binary passed via --toco.
      flags: str. Any additional flags to include
    """
    generator = ":generate_examples"
    generate_cmd = (
        "$(locations %s) --toco $(locations %s)  --zip_to_output %s %s $(@D)" %
        (generator, toco, file, flags)
    )
    native.genrule(
        name = file + ".files",
        cmd = generate_cmd,
        outs = [file],
        tools = [generator, toco],
    )

    native.filegroup(name = name, srcs = [file])
435
def gen_selected_ops(name, model):
    """Generate the library that includes only used ops.

    Declares a genrule that runs generate_op_registrations on `model` and
    writes "<name>_registration.cc".

    Args:
      name: Name of the generated library.
      model: TFLite model to interpret.
    """
    out = name + "_registration.cc"
    tool = "//tensorflow/lite/tools:generate_op_registrations"
    tflite_path = "//tensorflow/lite"
    tool_args = [
        "$(location %s)" % tool,
        "--input_model=$(location %s)" % model,
        "--output_registration=$(location %s)" % out,
        # Drop the leading "//" so the tool receives "tensorflow/lite".
        "--tflite_path=%s" % tflite_path[2:],
    ]
    native.genrule(
        name = name,
        srcs = [model],
        outs = [out],
        cmd = " ".join(tool_args),
        tools = [tool],
    )
454
def flex_dep(target_op_sets):
    """Returns the Flex delegate dependency when Select TF ops are requested.

    Args:
      target_op_sets: Comma-separated op sets, e.g.
        "TFLITE_BUILTINS,SELECT_TF_OPS". Checked by `in` containment.

    Returns:
      A one-element list with the Flex delegate target, or an empty list.
    """
    wants_flex = "SELECT_TF_OPS" in target_op_sets
    return ["//tensorflow/lite/delegates/flex:delegate"] if wants_flex else []
460
def gen_model_coverage_test(src, model_name, data, failure_type, tags):
    """Generates Python test targets for testing TFLite models.

    One py_test is declared per op-set combination. Each is named
    "model_coverage_test_<model_name>_<op sets lowercased, ',' -> '_'>".

    Args:
      src: Main source file.
      model_name: Name of the model to test (must be also listed in the 'data'
        dependencies)
      data: List of BUILD targets linking the data.
      failure_type: List of failure types (none, toco, crash, inference)
        expected for the corresponding combinations of op sets
        ("TFLITE_BUILTINS", "TFLITE_BUILTINS,SELECT_TF_OPS", "SELECT_TF_OPS"),
        one entry per combination in that order. "none" adds no
        --failure_type flag.
      tags: List of strings of additional tags.
    """
    op_set_combinations = [
        "TFLITE_BUILTINS",
        "TFLITE_BUILTINS,SELECT_TF_OPS",
        "SELECT_TF_OPS",
    ]

    # Report a short list up front. Otherwise an opaque index error is raised
    # after some of the py_test targets have already been declared.
    if len(failure_type) < len(op_set_combinations):
        fail("gen_model_coverage_test(%s): failure_type needs one entry per op set %s, got %s" %
             (model_name, op_set_combinations, failure_type))

    for i, target_op_sets in enumerate(op_set_combinations):
        expected_failure = failure_type[i]
        args = []
        if expected_failure != "none":
            args.append("--failure_type=%s" % expected_failure)
        native.py_test(
            name = "model_coverage_test_%s_%s" % (model_name, target_op_sets.lower().replace(",", "_")),
            srcs = [src],
            main = src,
            size = "large",
            args = [
                "--model_name=%s" % model_name,
                "--target_ops=%s" % target_op_sets,
            ] + args,
            data = data,
            srcs_version = "PY2AND3",
            tags = [
                "no_oss",
                "no_windows",
            ] + tags,
            deps = [
                "//tensorflow/lite/testing/model_coverage:model_coverage_lib",
                "//tensorflow/lite/python:lite",
                "//tensorflow/python:client_testlib",
            ] + flex_dep(target_op_sets),
        )
501