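"""Starlark build macros for TensorFlow Lite.

Provides compiler/linker flag helpers, JNI and shared-object target macros,
model-format conversion genrules (TensorFlow GraphDef -> TFLite via TOCO,
TFLite flatbuffer <-> JSON via flatc), and helpers that generate model tests.
"""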
4    "//tensorflow:tensorflow.bzl",
5    "tf_binary_additional_srcs",
6    "tf_cc_shared_object",
7    "tf_cc_test",
def tflite_copts():
    """Returns the C/C++ compiler flags shared by TFLite targets.

    The result concatenates, in order: a common base list, a per-platform
    select() (Android arm64/arm/x86, iOS x86_64, Windows, everything else),
    and a select() that adds -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK unless
    //tensorflow:with_default_optimizations is set.

    Returns:
      A configurable (select-based) list suitable for a `copts` attribute.
    """

    # NOTE(review): the common base list is empty here; confirm that no
    # platform-independent flags were intended.
    base_flags = []

    per_platform_flags = select({
        str(Label("//tensorflow:android_arm64")): [
            "-std=c++11",
            "-O3",
        ],
        str(Label("//tensorflow:android_arm")): [
            "-mfpu=neon",
            "-mfloat-abi=softfp",
            "-std=c++11",
            "-O3",
        ],
        # NOTE(review): no flags are listed for android_x86; confirm intended.
        str(Label("//tensorflow:android_x86")): [],
        str(Label("//tensorflow:ios_x86_64")): [
            "-msse4.1",
        ],
        str(Label("//tensorflow:windows")): [
            "/DTF_COMPILE_LIBRARY",
            "/wd4018",  # MSVC counterpart of -Wno-sign-compare.
        ],
        "//conditions:default": [
            "-Wno-sign-compare",
        ],
    })

    optimization_flags = select({
        str(Label("//tensorflow:with_default_optimizations")): [],
        "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
    })

    return base_flags + per_platform_flags + optimization_flags
# Default linker version script used by tflite_jni_binary (its `linkscript`
# default); it is passed to the linker via -Wl,--version-script to limit which
# symbols the JNI binary exports.
LINKER_SCRIPT = "//tensorflow/lite/java/src/main/native:version_script.lds"
def tflite_linkopts_unstripped():
    """Returns size-reducing linker flags for TFLite, without stripping symbols.

    Because the symbol table is kept, these flags are handy when investigating
    how much each symbol contributes to the size of a TFLite binary.

    Returns:
      A select() that yields the Android flags below, or [] elsewhere.
    """

    # --icf is deliberately absent: its gains were negligible and it created
    # potential compatibility problems.
    android_flags = [
        "-Wl,--no-export-dynamic",  # Only include syms referenced by dynamic objects.
        "-Wl,--exclude-libs,ALL",  # Do not auto-export syms from any library.
        "-Wl,--gc-sections",  # Drop unused code and data.
        "-Wl,--as-needed",  # Do not link libraries that are not used.
    ]
    return select({
        "//tensorflow:android": android_flags,
        "//conditions:default": [],
    })
def tflite_jni_linkopts_unstripped():
    """Returns size-reducing linker flags for TFLite JNI binaries, unstripped.

    The symbol table is kept, so these flags are useful for investigating the
    relative size of symbols in TFLite. Compared with
    tflite_linkopts_unstripped(), this list omits --no-export-dynamic and
    --exclude-libs,ALL.

    Returns:
      A select() that yields the Android flags below, or [] elsewhere.
    """

    # --icf is deliberately absent: its gains were negligible and it created
    # potential compatibility problems.
    android_flags = [
        "-Wl,--gc-sections",  # Drop unused code and data.
        "-Wl,--as-needed",  # Do not link libraries that are not used.
    ]
    return select({
        "//tensorflow:android": android_flags,
        "//conditions:default": [],
    })
def tflite_symbol_opts():
    """Returns linker flags that control whether symbols are kept.

    Returns:
      Two concatenated select()s: -latomic on Android, and -s (omit the
      symbol table) on every build other than //tensorflow:debug.
    """
    atomic_flags = select({
        "//tensorflow:android": [
            "-latomic",  # Required for some uses of ISO C++11 <atomic> in x86.
        ],
        "//conditions:default": [],
    })
    strip_flags = select({
        "//tensorflow:debug": [],
        "//conditions:default": [
            "-s",  # Omit the symbol table for all non-debug builds.
        ],
    })
    return atomic_flags + strip_flags
def tflite_linkopts():
    """Returns the full set of size-reducing linker flags for TFLite.

    This is tflite_linkopts_unstripped() followed by tflite_symbol_opts().
    """
    size_flags = tflite_linkopts_unstripped()
    symbol_flags = tflite_symbol_opts()
    return size_flags + symbol_flags
def tflite_jni_linkopts():
    """Returns the full set of size-reducing linker flags for TFLite JNI.

    This is tflite_jni_linkopts_unstripped() followed by tflite_symbol_opts().
    """
    size_flags = tflite_jni_linkopts_unstripped()
    symbol_flags = tflite_symbol_opts()
    return size_flags + symbol_flags
def tflite_jni_binary(
        name,
        copts = tflite_copts(),
        linkopts = tflite_jni_linkopts(),
        linkscript = LINKER_SCRIPT,
        linkshared = 1,
        linkstatic = 1,
        testonly = 0,
        deps = [],
        srcs = []):
    """Builds a JNI binary for TFLite as a native cc_binary.

    Args:
      name: Name of the cc_binary target.
      copts: Compiler flags passed through to cc_binary.
      linkopts: Linker flags; the version-script flags are appended to them.
      linkscript: Label of the linker version script. It is also added to
        `deps` so that `$(location)` can resolve it in linkopts.
      linkshared: Passed through to cc_binary.
      linkstatic: Passed through to cc_binary.
      testonly: Passed through to cc_binary.
      deps: Dependencies of the binary (the linkscript is appended).
      srcs: Source files of the binary.
    """
    version_script_flags = [
        "-Wl,--version-script",  # Export only JNI functions & classes.
        "$(location {})".format(linkscript),
    ]
    native.cc_binary(
        name = name,
        srcs = srcs,
        deps = deps + [linkscript],
        copts = copts,
        linkopts = linkopts + version_script_flags,
        linkshared = linkshared,
        linkstatic = linkstatic,
        testonly = testonly,
    )
def tflite_cc_shared_object(
        name,
        copts = tflite_copts(),
        linkopts = [],
        linkstatic = 1,
        deps = []):
    """Builds a TFLite shared object through tf_cc_shared_object.

    The caller's linkopts are extended with tflite_jni_linkopts(). No
    framework shared objects are attached (framework_so is always empty).
    NOTE(review): the JNI variant is presumably used because it omits the
    symbol-hiding flags of tflite_linkopts(); confirm before changing.

    Args:
      name: Name of the shared object target.
      copts: Compiler flags passed through.
      linkopts: Extra linker flags, placed before tflite_jni_linkopts().
      linkstatic: Passed through.
      deps: Dependencies passed through.
    """
    combined_linkopts = linkopts + tflite_jni_linkopts()
    tf_cc_shared_object(
        name = name,
        copts = copts,
        linkstatic = linkstatic,
        linkopts = combined_linkopts,
        framework_so = [],
        deps = deps,
    )
def tf_to_tflite(name, src, options, out):
    """Converts a frozen TensorFlow GraphDef into a TF Lite flatbuffer via TOCO.

    Args:
      name: Name of the generated genrule.
      src: Label of the input GraphDef file.
      options: List of extra TOCO flags, appended after the fixed ones.
      out: Name of the output flatbuffer file.
    """
    toco = "//tensorflow/lite/toco:toco"
    toco_args = [
        "$(location %s)" % toco,
        "--input_format=TENSORFLOW_GRAPHDEF",
        "--output_format=TFLITE",
        "--input_file=$(location %s)" % src,
        "--output_file=$(location %s)" % out,
    ]
    native.genrule(
        name = name,
        srcs = [src],
        outs = [out],
        cmd = " ".join(toco_args + options),
        tools = [toco] + tf_binary_additional_srcs(),
    )
def tflite_to_json(name, src, out):
    """Convert a TF Lite flatbuffer to JSON.

    Runs flatc with the TF Lite schema inside a genrule. flatc names its
    output after the input file's stem and writes it into the `-o`
    directory, so the input is copied to "${TMP}.bin" and the result is read
    back from "${TMP}.json" in the same directory.

    Args:
      name: Name of rule.
      src: name of the input flatbuffer file.
      out: name of the output JSON file.
    """

    flatc = "@flatbuffers//:flatc"
    schema = "//tensorflow/lite/schema:schema.fbs"
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        # `-o $(dirname $TMP)` keeps flatc's output next to $TMP even when
        # mktemp honours a $TMPDIR other than /tmp. The EXIT trap removes the
        # mktemp placeholder and both scratch files. `&&` after mktemp stops
        # the command if no temp file could be created.
        cmd = ("TMP=$$(mktemp) && " +
               "trap 'rm -f $${TMP} $${TMP}.bin $${TMP}.json' EXIT && " +
               "cp $(location %s) $${TMP}.bin && " +
               "$(location %s) --raw-binary --strict-json -t" +
               " -o $$(dirname $${TMP}) $(location %s) -- $${TMP}.bin && " +
               "cp $${TMP}.json $(location %s)") %
              (src, flatc, schema, out),
        tools = [flatc],
    )
def json_to_tflite(name, src, out):
    """Convert a JSON file to TF Lite's flatbuffer.

    Runs flatc with the TF Lite schema inside a genrule. flatc names its
    output after the input file's stem and writes it into the `-o`
    directory, so the input is copied to "${TMP}.json" and the result is
    read back from "${TMP}.bin" in the same directory.

    Args:
      name: Name of rule.
      src: name of the input JSON file.
      out: name of the output flatbuffer file.
    """

    flatc = "@flatbuffers//:flatc"

    # flatc's positional schema argument must be the .fbs file itself (the
    # same label tflite_to_json uses), not the generated library target.
    schema = "//tensorflow/lite/schema:schema.fbs"
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        # `-o $(dirname $TMP)` keeps flatc's output next to $TMP even when
        # mktemp honours a $TMPDIR other than /tmp. The EXIT trap removes the
        # mktemp placeholder and both scratch files. `&&` after mktemp stops
        # the command if no temp file could be created.
        cmd = ("TMP=$$(mktemp) && " +
               "trap 'rm -f $${TMP} $${TMP}.json $${TMP}.bin' EXIT && " +
               "cp $(location %s) $${TMP}.json && " +
               "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
               " -o $$(dirname $${TMP}) $(location %s) $${TMP}.json && " +
               "cp $${TMP}.bin $(location %s)") %
              (src, flatc, schema, out),
        tools = [flatc],
    )
# Canonical list of generated examples that are turned into tests. Each entry
# XXX needs a matching make_XXX_tests() function in generate_examples.py.
# To disable a test for a conversion mode, add it to the list returned by
# generated_test_models_failing().
def generated_test_models():
    """Returns the names of all generated example models, in test order."""
    model_names = [
        "abs",
        "add",
        "add_n",
        "arg_min_max",
        "avg_pool",
        "batch_to_space_nd",
        "ceil",
        "concat",
        "constant",
        "control_dep",
        "conv",
        "conv2d_transpose",
        "conv_with_shared_weights",
        "conv_to_depthwiseconv_with_shared_weights",
        "cos",
        "depthwiseconv",
        "div",
        "elu",
        "equal",
        "exp",
        "expand_dims",
        "fill",
        "floor",
        "floor_div",
        "floor_mod",
        "fully_connected",
        "fused_batch_norm",
        "gather",
        "gather_nd",
        "gather_with_constant",
        "global_batch_norm",
        "greater",
        "greater_equal",
        "sum",
        "l2norm",
        "l2norm_shared_epsilon",
        "l2_pool",
        "leaky_relu",
        "less",
        "less_equal",
        "local_response_norm",
        "log_softmax",
        "log",
        "logical_and",
        "logical_or",
        "logical_xor",
        "lstm",
        "max_pool",
        "maximum",
        "mean",
        "minimum",
        "mirror_pad",
        "mul",
        "neg",
        "not_equal",
        "one_hot",
        "pack",
        "pad",
        "padv2",
        "placeholder_with_default",
        "prelu",
        "pow",
        "range",
        "rank",
        "reduce_any",
        "reduce_max",
        "reduce_min",
        "reduce_prod",
        "relu",
        "relu1",
        "relu6",
        "reshape",
        "resize_bilinear",
        "resolve_constant_strided_slice",
        "reverse_sequence",
        "reverse_v2",
        "rsqrt",
        "shape",
        "sigmoid",
        "sin",
        "slice",
        "softmax",
        "space_to_batch_nd",
        "space_to_depth",
        "sparse_to_dense",
        "split",
        "splitv",
        "sqrt",
        "square",
        "squared_difference",
        "squeeze",
        "strided_slice",
        "strided_slice_1d_exhaustive",
        "sub",
        "tile",
        "topk",
        "transpose",
        "transpose_conv",
        "unidirectional_sequence_lstm",
        "unidirectional_sequence_rnn",
        "unique",
        "unpack",
        "unroll_batch_matmul",
        "where",
        "zeros_like",
    ]
    return model_names
# Models whose generated tests are known to fail for a given conversion mode.
# When disabling a test here, link the tracking bug or issue next to it.
def generated_test_models_failing(conversion_mode):
    """Returns the generated test names known to fail under `conversion_mode`.

    Args:
      conversion_mode: str. One of generated_test_conversion_modes().

    Returns:
      A new list of failing test names; empty for modes with no known failures.
    """
    if conversion_mode != "toco-flex":
        return []

    return [
        "lstm",  # TODO(b/117510976): Restore when lstm flex conversion works.
        "unroll_batch_matmul",  # TODO(b/123030774): Fails in 1.13 tests.
        "unidirectional_sequence_lstm",
        "unidirectional_sequence_rnn",
    ]
def generated_test_conversion_modes():
    """Lists the conversion modes that generated tests run under.

    The empty string stands for the default mode; generated_test_models_all()
    appends no suffix to test names for it.
    """

    # TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050.
    modes = ["toco-flex", ""]
    return modes
def generated_test_models_all():
    """Expands every generated test across every conversion mode.

    Returns:
      A list of (conversion mode, test name, test tags, test args) tuples.
      The test name gets "_<mode>" appended when the mode is non-empty.
      Tests listed by generated_test_models_failing() for a mode are tagged
      "notap" and "manual". Tests in the "toco-flex" mode get
      "--ignore_known_bugs=false" as an argument.
    """
    modes = generated_test_conversion_modes()
    base_names = generated_test_models()
    expanded = []
    for mode in modes:
        known_failures = generated_test_models_failing(mode)
        for base_name in base_names:
            # Failure lookup uses the un-suffixed name.
            test_tags = ["notap", "manual"] if base_name in known_failures else []

            # Flex conversion shouldn't suffer from the same conversion bugs
            # listed for the default TFLite kernel backend.
            test_args = ["--ignore_known_bugs=false"] if mode == "toco-flex" else []

            full_name = "%s_%s" % (base_name, mode) if mode else base_name
            expanded.append((mode, full_name, test_tags, test_args))
    return expanded
def gen_zip_test(name, test_name, conversion_mode, **kwargs):
    """Generate a zipped-example test and its dependent zip files.

    Creates a "zip_<test_name>" filegroup for "<test_name>.zip" (through
    gen_zipped_test_file), then a tf_cc_test named `name`. When
    conversion_mode is non-empty, the example generator additionally receives
    "--ignore_toco_errors --run_with_flex".

    Args:
      name: str. Resulting cc_test target name
      test_name: str. Test targets this model. Comes from the list above.
      conversion_mode: str. Which conversion mode to run with. Comes from the
        list above.
      **kwargs: tf_cc_test kwargs
    """

    # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050.
    # if conversion_mode == "pb2lite":
    #     toco = "//tensorflow/lite/experimental/pb2lite:pb2lite"
    toco = "//tensorflow/lite/toco:toco"
    flags = "--ignore_toco_errors --run_with_flex" if conversion_mode else ""

    gen_zipped_test_file(
        name = "zip_%s" % test_name,
        file = "%s.zip" % test_name,
        toco = toco,
        flags = flags,
    )
    tf_cc_test(name, **kwargs)
def gen_zipped_test_file(name, file, toco, flags):
    """Generate a zip file of tests by using :generate_examples.

    Two targets are created: a genrule named `file` + ".files" that produces
    `file`, and a filegroup named `name` that wraps that output.

    Args:
      name: str. Name of the filegroup wrapping the generated zip.
      file: str. Output file name, e.g. "transpose.zip".
      toco: str. Label of the converter binary passed via --toco.
      flags: str. Any additional flags for :generate_examples.
    """
    generate_cmd = ("$(locations :generate_examples) --toco $(locations %s) " +
                    " --zip_to_output %s %s $(@D)") % (toco, file, flags)

    native.genrule(
        name = file + ".files",
        cmd = generate_cmd,
        outs = [file],
        tools = [
            ":generate_examples",
            toco,
        ],
    )

    native.filegroup(
        name = name,
        srcs = [file],
    )
def gen_selected_ops(name, model):
    """Generate the library that includes only used ops.

    Runs generate_op_registrations on `model` and writes
    "<name>_registration.cc".

    Args:
      name: Name of the generated genrule.
      model: TFLite model to interpret.
    """
    out = name + "_registration.cc"
    tool = "//tensorflow/lite/tools:generate_op_registrations"
    tflite_path = "//tensorflow/lite"

    cmd_parts = [
        "$(location %s)" % tool,
        "--input_model=$(location %s)" % model,
        "--output_registration=$(location %s)" % out,
        # Strip the leading "//" to pass a workspace-relative path.
        "--tflite_path=%s" % tflite_path[2:],
    ]
    native.genrule(
        name = name,
        srcs = [model],
        outs = [out],
        cmd = " ".join(cmd_parts),
        tools = [tool],
    )
def flex_dep(target_op_sets):
    """Returns the extra deps needed when the Flex (select TF ops) set is used.

    Args:
      target_op_sets: Op set name(s) to check; a comma-separated string such
        as "TFLITE_BUILTINS,SELECT_TF_OPS" is matched by substring.

    Returns:
      The Flex delegate label in a list when "SELECT_TF_OPS" appears,
      otherwise an empty list.
    """
    uses_flex = "SELECT_TF_OPS" in target_op_sets
    return ["//tensorflow/lite/delegates/flex:delegate"] if uses_flex else []
def gen_model_coverage_test(src, model_name, data, failure_type, tags):
    """Generates Python test targets for testing TFLite models.

    One py_test is created per target op set combination, in this order:
    "TFLITE_BUILTINS", "SELECT_TF_OPS", "TFLITE_BUILTINS,SELECT_TF_OPS".
    Each target is named "model_coverage_test_<model_name>_<op sets>", where
    the op sets are lower-cased and commas become underscores.

    Args:
      src: Main source file.
      model_name: Name of the model to test (must be also listed in the 'data'
        dependencies)
      data: List of BUILD targets linking the data.
      failure_type: List of failure types (none, toco, crash, inference)
        expected for the corresponding combinations of op sets, in the order
        listed above. Entries other than "none" are passed as
        --failure_type.
      tags: List of strings of additional tags.
    """
    op_set_combinations = [
        "TFLITE_BUILTINS",
        "SELECT_TF_OPS",
        "TFLITE_BUILTINS,SELECT_TF_OPS",
    ]

    # Fail loudly instead of an opaque index error (or silently skipped
    # combinations) when too few failure types are provided.
    if len(failure_type) < len(op_set_combinations):
        fail("gen_model_coverage_test(%s): failure_type needs %d entries, got %d" %
             (model_name, len(op_set_combinations), len(failure_type)))

    for target_op_sets, expected_failure in zip(op_set_combinations, failure_type):
        args = []
        if expected_failure != "none":
            args.append("--failure_type=%s" % expected_failure)
        native.py_test(
            name = "model_coverage_test_%s_%s" % (model_name, target_op_sets.lower().replace(",", "_")),
            srcs = [src],
            main = src,
            size = "large",
            args = [
                "--model_name=%s" % model_name,
                "--target_ops=%s" % target_op_sets,
            ] + args,
            data = data,
            srcs_version = "PY2AND3",
            tags = [
                "no_oss",
                "no_windows",
            ] + tags,
            deps = [
                "//tensorflow/lite/testing/model_coverage:model_coverage_lib",
                "//tensorflow/lite/python:lite",
                "//tensorflow/python:client_testlib",
            ] + flex_dep(target_op_sets),
        )