# Description:
# TensorFlow Java API.

# Symbols shared by the targets in this package: Java compiler options, the
# Java op-wrapper generator macro, and the common TensorFlow build macros.
load(":build_defs.bzl", "JAVACOPTS")
load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar")
load(
    "//tensorflow:tensorflow.bzl",
    "VERSION",
    "tf_binary_additional_srcs",
    "tf_cc_binary",
    "tf_cc_test",
    "tf_copts",
    "tf_custom_op_library",
    "tf_java_test",
)

# Targets are private to this package unless they set `visibility` themselves.
package(
    default_visibility = ["//visibility:private"],
    licenses = ["notice"],
)

# The public Java library. It compiles the hand-written and generated sources
# with the :processor annotation plugin, packages :java_resources, and carries
# the JNI shared library (plus tf_binary_additional_srcs()) as runtime data.
java_library(
    name = "tensorflow",
    srcs = [
        ":java_op_sources",
        ":java_sources",
    ],
    data = tf_binary_additional_srcs() + [":libtensorflow_jni"],
    javacopts = JAVACOPTS,
    plugins = [":processor"],
    resources = [":java_resources"],
    visibility = ["//visibility:public"],
)

# Writes a single line, "version=<VERSION>", into the resource file that is
# packaged through :java_resources. VERSION is substituted at load time.
genrule(
    name = "version-info",
    outs = ["src/main/resources/tensorflow-version-info"],
    cmd = "echo version=%s > $@" % VERSION,
    output_to_bindir = 1,
)

# Resources bundled into the :tensorflow library; currently only the generated
# version-info file.
filegroup(
    name = "java_resources",
    srcs = [":version-info"],
    visibility = ["//visibility:private"],
)

# NOTE(ashankar): This rule exists so that the Java API can be included in the
# Android Inference Library .aar. At some point it might make sense to define
# a .aar rule here instead.
#
# The glob is non-recursive: only files directly under org/tensorflow and
# org/tensorflow/types are matched.
filegroup(
    name = "java_sources",
    srcs = glob([
        "src/main/java/org/tensorflow/*.java",
        "src/main/java/org/tensorflow/types/*.java",
    ]),
    visibility = ["//tensorflow/tools/android/inference_interface:__pkg__"],
)

# Annotation-processor plugin backed by :processor_library. It is declared with
# generates_api = True and is applied to :tensorflow through its `plugins`
# attribute.
java_plugin(
    name = "processor",
    generates_api = True,
    processor_class = "org.tensorflow.processor.OperatorProcessor",
    visibility = ["//visibility:public"],
    deps = [":processor_library"],
)

# Implementation of the annotation processor. The single resource is the
# META-INF/services entry named after javax.annotation.processing.Processor.
java_library(
    name = "processor_library",
    srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]),
    javacopts = JAVACOPTS,
    resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]),
    deps = [
        "@com_google_guava",
        "@com_squareup_javapoet",
    ],
)

# All op-wrapper sources: the checked-in files under org/tensorflow/op plus the
# output of :java_op_gen_sources.
filegroup(
    name = "java_op_sources",
    srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [":java_op_gen_sources"],
    visibility = ["//visibility:private"],
)

# Invokes the tf_java_op_gen_srcjar macro (loaded from src/gen/gen_ops.bzl)
# with :java_op_gen_tool as the generator, the base and Java API definitions
# as input, and org.tensorflow.op as the base package.
tf_java_op_gen_srcjar(
    name = "java_op_gen_sources",
    api_def_srcs = [
        "//tensorflow/core/api_def:base_api_def",
        "//tensorflow/core/api_def:java_api_def",
    ],
    base_package = "org.tensorflow.op",
    gen_tool = ":java_op_gen_tool",
)

# Statically linked command-line generator used by :java_op_gen_sources.
tf_cc_binary(
    name = "java_op_gen_tool",
    srcs = [
        "src/gen/cc/op_gen_main.cc",
    ],
    copts = tf_copts(),
    # libm is linked explicitly everywhere except on Windows.
    linkopts = select({
        "//tensorflow:windows": [],
        "//conditions:default": ["-lm"],
    }),
    linkstatic = 1,
    deps = [
        ":java_op_gen_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:ops",
    ],
)

# C++ library behind the op generator: the generator itself, the op specs and
# the source writer. Shared by :java_op_gen_tool and :source_writer_test.
cc_library(
    name = "java_op_gen_lib",
    srcs = [
        "src/gen/cc/op_generator.cc",
        "src/gen/cc/op_specs.cc",
        "src/gen/cc/source_writer.cc",
    ],
    hdrs = [
        "src/gen/cc/java_defs.h",
        "src/gen/cc/op_generator.h",
        "src/gen/cc/op_specs.h",
        "src/gen/cc/source_writer.h",
    ],
    copts = tf_copts(),
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:op_gen_lib",
        "//tensorflow/core:protos_all_cc",
        "@com_googlesource_code_re2//:re2",
    ],
)

# Test-only helper library (TestUtil.java) that most of the Java tests in this
# package depend on.
java_library(
    name = "testutil",
    testonly = 1,
    srcs = ["src/test/java/org/tensorflow/TestUtil.java"],
    javacopts = JAVACOPTS,
    deps = [":tensorflow"],
)

# Small JUnit test running org.tensorflow.EagerSessionTest against :tensorflow.
tf_java_test(
    name = "EagerSessionTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/EagerSessionTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.EagerSessionTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.EagerOperationBuilderTest against
# :tensorflow.
tf_java_test(
    name = "EagerOperationBuilderTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/EagerOperationBuilderTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.EagerOperationBuilderTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.EagerOperationTest against
# :tensorflow.
tf_java_test(
    name = "EagerOperationTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/EagerOperationTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.EagerOperationTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.GraphTest against :tensorflow.
tf_java_test(
    name = "GraphTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/GraphTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.GraphTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.GraphOperationBuilderTest against
# :tensorflow.
tf_java_test(
    name = "GraphOperationBuilderTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/GraphOperationBuilderTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.GraphOperationBuilderTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.GraphOperationTest against
# :tensorflow.
tf_java_test(
    name = "GraphOperationTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/GraphOperationTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.GraphOperationTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.SavedModelBundleTest. The
# saved_model_half_plus_two target is supplied as runtime data.
tf_java_test(
    name = "SavedModelBundleTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/SavedModelBundleTest.java"],
    data = ["//tensorflow/cc/saved_model:saved_model_half_plus_two"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.SavedModelBundleTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.SessionTest against :tensorflow.
tf_java_test(
    name = "SessionTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/SessionTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.SessionTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.ShapeTest against :tensorflow.
tf_java_test(
    name = "ShapeTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/ShapeTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.ShapeTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Custom-op shared library built from the test-only native source; consumed as
# runtime data by :TensorFlowTest.
tf_custom_op_library(
    name = "my_test_op.so",
    srcs = ["src/test/native/my_test_op.cc"],
)

# Small JUnit test running org.tensorflow.TensorFlowTest. It ships
# :my_test_op.so as runtime data and, unlike its sibling tests, does not depend
# on :testutil.
tf_java_test(
    name = "TensorFlowTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/TensorFlowTest.java"],
    data = [":my_test_op.so"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.TensorFlowTest",
    deps = [
        ":tensorflow",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.TensorTest. Tagged "noasan" until the
# referenced bug is resolved.
tf_java_test(
    name = "TensorTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/TensorTest.java"],
    javacopts = JAVACOPTS,
    tags = [
        "noasan",  # TODO(b/177573611): enable after this is fixed.
    ],
    test_class = "org.tensorflow.TensorTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.op.ScopeTest against :tensorflow.
tf_java_test(
    name = "ScopeTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/op/ScopeTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.ScopeTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.op.PrimitiveOpTest against
# :tensorflow.
tf_java_test(
    name = "PrimitiveOpTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/op/PrimitiveOpTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.PrimitiveOpTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.op.OperandsTest against :tensorflow.
tf_java_test(
    name = "OperandsTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/op/OperandsTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.OperandsTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.op.core.ConstantTest against
# :tensorflow.
tf_java_test(
    name = "ConstantTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.core.ConstantTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.op.core.GeneratedOperationsTest
# against :tensorflow.
tf_java_test(
    name = "GeneratedOperationsTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.core.GeneratedOperationsTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.op.core.GradientsTest against
# :tensorflow.
tf_java_test(
    name = "GradientsTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.core.GradientsTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Small JUnit test running org.tensorflow.op.core.ZerosTest against
# :tensorflow.
tf_java_test(
    name = "ZerosTest",
    size = "small",
    srcs = ["src/test/java/org/tensorflow/op/core/ZerosTest.java"],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.core.ZerosTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)

# Java files under src/test/resources/org/tensorflow together with the Operator
# annotation source. No target in this file references this group.
filegroup(
    name = "processor_test_resources",
    srcs = glob([
        "src/test/resources/org/tensorflow/**/*.java",
        "src/main/java/org/tensorflow/op/annotation/Operator.java",
    ]),
)

# Small C++ test for the source writer in :java_op_gen_lib. The snippet file is
# supplied as runtime data.
tf_cc_test(
    name = "source_writer_test",
    size = "small",
    srcs = [
        "src/gen/cc/source_writer_test.cc",
    ],
    data = [
        "src/gen/resources/test.java.snippet",
    ],
    deps = [
        ":java_op_gen_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Platform-neutral alias for the JNI shared library: picks the .dll on Windows,
# the .dylib on macOS and the .so everywhere else.
filegroup(
    name = "libtensorflow_jni",
    srcs = select({
        "//tensorflow:windows": [":tensorflow_jni.dll"],
        "//tensorflow:macos": [":libtensorflow_jni.dylib"],
        "//conditions:default": [":libtensorflow_jni.so"],
    }),
    visibility = ["//visibility:public"],
)

# Linker control files. They are referenced twice below: through $(location)
# in linkopts, and in deps so that $(location) can resolve them.
LINKER_VERSION_SCRIPT = ":config/version_script.lds"

LINKER_EXPORTED_SYMBOLS = ":config/exported_symbols.lds"

# The JNI shared library that backs the Java API.
tf_cc_binary(
    name = "tensorflow_jni",
    # Set linker options to strip out anything except the JNI
    # symbols from the library. This reduces the size of the library
    # considerably (~50% as of January 2017).
    # NOTE(review): if //tensorflow:debug and //tensorflow:macos can match at
    # the same time and neither specializes the other, select() fails with an
    # ambiguous-match error because their values differ -- confirm.
    linkopts = select({
        "//tensorflow:debug": [],  # Disable all custom linker options in debug mode
        "//tensorflow:macos": [
            "-Wl,-exported_symbols_list,$(location {})".format(LINKER_EXPORTED_SYMBOLS),
        ],
        "//tensorflow:windows": [],
        "//conditions:default": [
            "-z defs",
            "-s",
            "-Wl,--version-script,$(location {})".format(LINKER_VERSION_SCRIPT),
        ],
    }),
    linkshared = 1,
    linkstatic = 1,
    per_os_targets = True,
    deps = [
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/java/src/main/native",
        LINKER_VERSION_SCRIPT,
        LINKER_EXPORTED_SYMBOLS,
    ],
)

# Produces pom.xml by running :generate_pom and capturing its standard output.
genrule(
    name = "pom",
    outs = ["pom.xml"],
    cmd = "$(location generate_pom) >$@",
    output_to_bindir = 1,
    tools = [":generate_pom"] + tf_binary_additional_srcs(),
)

# Helper binary run by the :pom genrule; built from generate_pom.cc against the
# TensorFlow C API.
tf_cc_binary(
    name = "generate_pom",
    srcs = ["generate_pom.cc"],
    deps = ["//tensorflow/c:c_api"],
)