• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Each .bzl file is loaded exactly once; symbols from the same file share one load().
load("@flatbuffers//:build_defs.bzl", "flatbuffer_py_library")
load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable", "if_portable", "pytype_strict_library")
load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist")
5
# Package-wide defaults: targets that do not set `visibility` themselves are
# visible to the two label groups listed here.
package(
    default_visibility = [
        "//tensorflow:internal",
        "//third_party/tflite_micro:__subpackages__",
    ],
    licenses = ["notice"],
)
13
# Export the converter script so it can be referenced as a file label.
exports_files([
    "tflite_convert.py",
])
15
# Python flatbuffer library produced from the TFLite schema definition.
flatbuffer_py_library(
    name = "schema_py",
    srcs = [
        "//tensorflow/lite/schema:schema.fbs",
    ],
)
20
# Library target for interpreter.py. Visible to every package; its
# `compatible_with` value comes from get_compatible_with_portable().
py_library(
    name = "interpreter",
    srcs = ["interpreter.py"],
    compatible_with = get_compatible_with_portable(),
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":metrics",
        "//tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper",
        "//tensorflow/python:util",
        "//tensorflow/python/util:tf_export",
        "//third_party/py/numpy",
    ],
)
37
# Test for :interpreter. Needs model files and a test delegate shared object
# at runtime (see `data`).
py_test(
    name = "interpreter_test",
    srcs = [
        "interpreter_test.py",
    ],
    data = [
        "//tensorflow/lite:testdata/sparse_tensor.bin",
        "//tensorflow/lite/python/testdata:interpreter_test_data",
        "//tensorflow/lite/python/testdata:test_delegate.so",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_oss",  # TODO(b/190842754): Enable test in OSS.
        "no_pip",  # TODO(b/187847053): Enable test in pip.
    ],
    deps = [
        ":interpreter",
        "//tensorflow/lite/python/testdata:_pywrap_test_registerer",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)
62
# Executable built from tflite_convert.py; visible to every package.
py_binary(
    name = "tflite_convert",
    srcs = [
        "tflite_convert.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":tflite_convert_main_lib",
        "//tensorflow:tensorflow_py",
        "@absl_py//absl:app",
        "@six_archive//:six",
    ],
)
76
# Library form of tflite_convert.py that also pulls in the full TensorFlow
# Python target; :tflite_convert depends on it.
py_library(
    name = "tflite_convert_main_lib",
    srcs = [
        "tflite_convert.py",
    ],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":tflite_convert_lib",
        "//tensorflow:tensorflow_py",
        "@absl_py//absl:app",
        "@six_archive//:six",
    ],
)
89
# Library form of tflite_convert.py that depends on :lite and the toco
# logging targets rather than on //tensorflow:tensorflow_py.
py_library(
    name = "tflite_convert_lib",
    srcs = [
        "tflite_convert.py",
    ],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":lite",
        "//tensorflow/lite/toco/logging:gen_html",
        "//tensorflow/lite/toco/logging:toco_conversion_log_proto_py",
        "//tensorflow/python:util",
        "@absl_py//absl:app",
        "@six_archive//:six",
    ],
)
104
# Library target for test_util.py; inherits the package default visibility.
py_library(
    name = "test_util",
    srcs = [
        "test_util.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":lite",
        ":schema_util",
        "//tensorflow/lite/tools:visualize",
        "//tensorflow/python:framework",
    ],
)
116
# Test for :test_util; two model files are supplied as runtime data.
py_test(
    name = "test_util_test",
    srcs = [
        "test_util_test.py",
    ],
    data = [
        "//tensorflow/lite:testdata/add.bin",
        "//tensorflow/lite:testdata/softplus_flex.bin",
    ],
    python_version = "PY3",
    deps = [":test_util"],
)
129
# Test for the tflite_convert binary; it consumes the packaged binary
# (:tflite_convert.par) and a frozen graph as runtime data.
py_test(
    name = "tflite_convert_test",
    srcs = [
        "tflite_convert_test.py",
    ],
    data = [
        ":tflite_convert.par",
        "@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
    ],
    python_version = "PY3",
    # Split into 10 shards to reduce timeout failures.
    shard_count = 10,
    srcs_version = "PY3",
    tags = [
        "no_oss",
        "no_pip",
        "no_windows",
        "noasan",  # b/144707533
        "notsan",  # b/160824139
    ],
    deps = [
        ":convert",
        ":test_util",
        ":tflite_convert",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:tf2",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/saved_model",
        "//tensorflow/python/saved_model:save",
        "//tensorflow/python/training:training_util",
        "//tensorflow/python/training/tracking",
        "//third_party/py/numpy",
    ],
)
172
# Library target for lite.py; visible to every package. It depends on most of
# the other libraries declared in this package (convert, interpreter, util, ...).
py_library(
    name = "lite",
    srcs = [
        "lite.py",
    ],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":convert",
        ":convert_phase",
        ":convert_saved_model",
        ":interpreter",
        ":lite_constants",
        ":metrics",
        ":op_hint",
        ":util",
        "//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
        "//tensorflow/lite/python/optimize:calibrator",
        "//tensorflow/python/client:session",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/framework",
        "//tensorflow/python/framework:convert_to_constants",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:errors",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/platform",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:tag_constants",
        "//tensorflow/python/util",
        "//tensorflow/python/util:tf_export",
        "@absl_py//absl/logging",
        "@six_archive//:six",
    ],
)
209
# Test for :lite, run in 4 shards and tagged no_windows.
py_test(
    name = "lite_test",
    srcs = [
        "lite_test.py",
    ],
    data = [
        "//tensorflow/lite/python/testdata:control_flow_v1.pbtxt",
        "@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
    ],
    python_version = "PY3",
    shard_count = 4,
    srcs_version = "PY3",
    tags = ["no_windows"],
    deps = [
        ":lite",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "@six_archive//:six",
    ],
)
231
# Test for :lite built on :lite_v2_test_util; run in 12 shards and tagged
# no_windows.
py_test(
    name = "lite_v2_test",
    srcs = [
        "lite_v2_test.py",
    ],
    data = [
        "//tensorflow/lite/python/testdata:test_delegate.so",
        "//tensorflow/lite/python/testdata/control_flow_v1_saved_model:saved_model.pb",
    ],
    python_version = "PY3",
    shard_count = 12,
    srcs_version = "PY3",
    tags = ["no_windows"],
    deps = [
        ":lite",
        ":lite_v2_test_util",
        ":test_util",
        "//tensorflow:tensorflow_py",
        "//tensorflow/lite/python/testdata:_pywrap_test_registerer",
        "//tensorflow/lite/python/testdata:double_op",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "@six_archive//:six",
    ],
)
257
# Test-only helper library (lite_v2_test_util.py); :lite_v2_test depends on it.
py_library(
    name = "lite_v2_test_util",
    # `testonly` is a boolean attribute, so use a real boolean rather than 1.
    testonly = True,
    srcs = ["lite_v2_test_util.py"],
    srcs_version = "PY3",
    tags = [
        "no_windows",
    ],
    deps = [
        ":lite",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "@six_archive//:six",
    ],
)
273
# Test for :lite that additionally links the double_op test target.
py_test(
    name = "lite_flex_test",
    srcs = [
        "lite_flex_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":lite",
        ":test_util",
        "//tensorflow/lite/python/testdata:double_op",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
    ],
)
287
# Library target for util.py. Visibility is whatever
# internal_visibility_allowlist() returns.
py_library(
    name = "util",
    srcs = ["util.py"],
    srcs_version = "PY3",
    visibility = internal_visibility_allowlist(),
    deps = [
        ":op_hint",
        ":schema_py",
        ":schema_util",
        # Declared in this package, so referenced with the package-relative
        # form like the other local deps.
        ":tflite_keras_util",
        "//tensorflow/lite/toco:toco_flags_proto_py",
        "//tensorflow/python:convert_to_constants",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:error_interpolation",
        "//tensorflow/python:graph_util",
        "//tensorflow/python:tf_optimizer",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/training:saver",
        "@absl_py//absl/logging",
        "@flatbuffers//:runtime_py",
        "@six_archive//:six",
    ],
)
311
# Test for :util; tagged no_windows.
py_test(
    name = "util_test",
    srcs = [
        "util_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = ["no_windows"],
    deps = [
        ":util",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:convert_to_constants",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:session",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
        "@six_archive//:six",
    ],
)
337
# Library target for tflite_keras_util.py; :util depends on it.
py_library(
    name = "tflite_keras_util",
    srcs = ["tflite_keras_util.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:util",
        "//tensorflow/python/eager:def_function",
    ],
)
349
# Library target for wrap_toco.py, layered over the _pywrap_toco_api
# extension target; :convert depends on it.
py_library(
    name = "wrap_toco",
    srcs = ["wrap_toco.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:_pywrap_toco_api",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:util",
    ],
)
362
# Library target for lite_constants.py.
py_library(
    name = "lite_constants",
    srcs = [
        "lite_constants.py",
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/lite/toco:toco_flags_proto_py",
        "//tensorflow/python:dtypes",
    ],
)
372
# Library target for convert.py; visible to every package.
py_library(
    name = "convert",
    srcs = [
        "convert.py",
    ],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":convert_phase",
        ":lite_constants",
        ":util",
        ":wrap_toco",
        "//tensorflow/lite/python/metrics_wrapper",
        "//tensorflow/lite/toco:model_flags_proto_py",
        "//tensorflow/lite/toco:toco_flags_proto_py",
        "//tensorflow/lite/toco/python:toco_from_protos",
        "//tensorflow/lite/tools:flatbuffer_utils",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:platform",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/util",
        "//tensorflow/python/util:tf_export",
        "@six_archive//:six",
    ],
)
396
# Library target for op_hint.py; visible to every package.
py_library(
    name = "op_hint",
    srcs = [
        "op_hint.py",
    ],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:graph_util",
        "//tensorflow/python:platform",
        "//tensorflow/python:util",
    ],
)
409
# Test for :convert; also depends on :interpreter and :op_hint.
py_test(
    name = "convert_test",
    srcs = [
        "convert_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":convert",
        ":interpreter",
        ":op_hint",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:session",
    ],
)
426
# Library target for convert_saved_model.py; visible only under
# //tensorflow/lite.
py_library(
    name = "convert_saved_model",
    srcs = [
        "convert_saved_model.py",
    ],
    srcs_version = "PY3",
    visibility = ["//tensorflow/lite:__subpackages__"],
    deps = [
        ":convert_phase",
        ":util",
        "//tensorflow/python:graph_util",
        "//tensorflow/python:platform",
        "//tensorflow/python/saved_model",
    ],
)
442
# Test for :convert_saved_model; tagged no_windows.
py_test(
    name = "convert_saved_model_test",
    srcs = [
        "convert_saved_model_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = ["no_windows"],
    visibility = ["//visibility:public"],
    deps = [
        ":convert_saved_model",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:layers",
        "//tensorflow/python:nn",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:session",
        "//tensorflow/python/ops/losses",
        "//tensorflow/python/saved_model",
    ],
)
463
# Executable built from convert_file_to_c_source.py; visible to every package.
py_binary(
    name = "convert_file_to_c_source",
    srcs = [
        "convert_file_to_c_source.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":util",
        "@absl_py//absl:app",
        "@absl_py//absl/flags",
    ],
)
476
# Shell test that receives the :convert_file_to_c_source binary as runtime data.
sh_test(
    name = "convert_file_to_c_source_test",
    srcs = [
        "convert_file_to_c_source_test.sh",
    ],
    data = [
        ":convert_file_to_c_source",
    ],
)
482
# Library target for schema_util.py; visible to the utils_friends group
# declared under //tensorflow/lite/schema.
py_library(
    name = "schema_util",
    srcs = [
        "schema_util.py",
    ],
    srcs_version = "PY3",
    visibility = ["//tensorflow/lite/schema:utils_friends"],
    deps = ["//tensorflow/python:util"],
)
492
# Package-private library for metrics_interface.py; both :metrics_portable
# and :metrics_nonportable depend on it.
pytype_strict_library(
    name = "metrics_interface",
    srcs = [
        "metrics_interface.py",
    ],
    compatible_with = get_compatible_with_portable(),
    srcs_version = "PY3",
    visibility = ["//visibility:private"],
)
500
# Declared with py_library instead of pytype_strict_library: the metrics module
# is imported inside a try-except block, which pytype_strict_library does not
# support.
py_library(
    name = "convert_phase",
    srcs = [
        "convert_phase.py",
    ],
    srcs_version = "PY3",
    visibility = ["//tensorflow/lite:__subpackages__"],
    deps = [
        ":metrics",
        "//tensorflow/lite/python/metrics_wrapper:converter_error_data_proto_py",
    ],
)
513
# Package-private library for metrics_nonportable.py; :metrics selects it
# through the `if_false` branch of if_portable().
pytype_strict_library(
    name = "metrics_nonportable",
    srcs = [
        "metrics_nonportable.py",
    ],
    srcs_version = "PY3",
    visibility = ["//visibility:private"],
    deps = [
        ":metrics_interface",
        "//tensorflow/lite/python/metrics_wrapper",
        "//tensorflow/lite/python/metrics_wrapper:converter_error_data_proto_py",
        "//tensorflow/python/eager:monitoring",
    ],
)
526
# Test for :metrics_nonportable; tagged no_oss and given a saved model as
# runtime data.
py_test(
    name = "metrics_nonportable_test",
    srcs = [
        "metrics_nonportable_test.py",
    ],
    data = ["//tensorflow/lite/python/testdata/control_flow_v1_saved_model:saved_model.pb"],
    python_version = "PY3",
    tags = [
        "no_oss",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":lite",
        ":metrics_nonportable",
        "//tensorflow:tensorflow_py",
        "//tensorflow/lite/python/metrics_wrapper:converter_error_data_proto_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
    ],
)
545
# Package-private library for metrics_portable.py; :metrics selects it
# through the `if_true` branch of if_portable().
pytype_strict_library(
    name = "metrics_portable",
    srcs = [
        "metrics_portable.py",
    ],
    compatible_with = get_compatible_with_portable(),
    srcs_version = "PY3",
    visibility = ["//visibility:private"],
    deps = [":metrics_interface"],
)
556
# Test for :metrics_portable.
py_test(
    name = "metrics_portable_test",
    srcs = [
        "metrics_portable_test.py",
    ],
    python_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":metrics_portable",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
    ],
)
568
# Source-less target that forwards to one metrics implementation: if_portable()
# is given :metrics_portable as `if_true` and :metrics_nonportable as `if_false`.
pytype_strict_library(
    name = "metrics",
    compatible_with = get_compatible_with_portable(),
    srcs_version = "PY3",
    visibility = [
        "//tensorflow/lite:__subpackages__",
    ],
    deps = if_portable(
        if_false = [":metrics_nonportable"],
        if_true = [":metrics_portable"],
    ),
)
579
# Library target for analyzer.py, layered over the _pywrap_analyzer_wrapper
# extension target; visible to every package.
py_library(
    name = "analyzer",
    srcs = ["analyzer.py"],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = ["//tensorflow/lite/python/analyzer_wrapper:_pywrap_analyzer_wrapper"],
)
591
# Test for :analyzer; three model files are supplied as runtime data.
py_test(
    name = "analyzer_test",
    srcs = [
        "analyzer_test.py",
    ],
    data = [
        "//tensorflow/lite:testdata/add.bin",
        "//tensorflow/lite:testdata/conv_huge_im2col.bin",
        "//tensorflow/lite:testdata/multi_add_flex.bin",
    ],
    python_version = "PY3",
    deps = [
        ":analyzer",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
    ],
)
608