# Description:
# TensorFlow Java API.
3
# Targets in this package are private unless they declare `visibility` themselves.
package(
    default_visibility = ["//visibility:private"],
)

licenses(["notice"])  # Apache 2.0
7
8load(":build_defs.bzl", "JAVACOPTS")
9load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar")
10load(
11    "//tensorflow:tensorflow.bzl",
12    "tf_binary_additional_srcs",
13    "tf_cc_binary",
14    "tf_copts",
15    "tf_custom_op_library",
16    "tf_java_test",
17    "tf_cc_test",
18)
19
# Public Java API library: hand-written sources plus the op sources (which
# include generated wrappers), compiled with the ":processor" annotation
# processor and shipped with the per-platform JNI library as runtime data.
java_library(
    name = "tensorflow",
    visibility = ["//visibility:public"],
    srcs = [
        ":java_op_sources",
        ":java_sources",
    ],
    plugins = [":processor"],
    javacopts = JAVACOPTS,
    data = [":libtensorflow_jni"],
)
31
# Hand-written Java API sources from the org.tensorflow and
# org.tensorflow.types packages.
# NOTE(ashankar): Rule to include the Java API in the Android Inference Library
# .aar. At some point, might make sense for a .aar rule here instead.
filegroup(
    name = "java_sources",
    visibility = [
        "//tensorflow/contrib/android:__pkg__",
        "//tensorflow/java:__pkg__",
    ],
    srcs = glob([
        "src/main/java/org/tensorflow/*.java",
        "src/main/java/org/tensorflow/types/*.java",
    ]),
)
45
# Annotation processor (org.tensorflow.processor.OperatorProcessor); it is
# attached to ":tensorflow" through its `plugins` attribute.
java_plugin(
    name = "processor",
    visibility = ["//visibility:public"],
    processor_class = "org.tensorflow.processor.OperatorProcessor",
    generates_api = True,
    deps = [":processor_library"],
)
53
# Implementation classes of the annotation processor, together with the
# META-INF/services registration file for javax.annotation.processing.Processor.
java_library(
    name = "processor_library",
    srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]),
    resources = glob([
        "src/gen/resources/META-INF/services/javax.annotation.processing.Processor",
    ]),
    deps = [
        "@com_google_guava",
        "@com_squareup_javapoet",
    ],
    javacopts = JAVACOPTS,
)
64
# Op-layer Java sources: hand-written files under org/tensorflow/op followed by
# the generated output of ":java_op_gen_sources".
filegroup(
    name = "java_op_sources",
    visibility = ["//tensorflow/java:__pkg__"],
    srcs = glob([
        "src/main/java/org/tensorflow/op/**/*.java",
    ]) + [
        ":java_op_gen_sources",
    ],
)
72
# Runs ":java_op_gen_tool" over the base and Java-specific API definitions to
# produce Java op sources rooted at package org.tensorflow.op.
tf_java_op_gen_srcjar(
    name = "java_op_gen_sources",
    gen_tool = ":java_op_gen_tool",
    base_package = "org.tensorflow.op",
    api_def_srcs = [
        "//tensorflow/core/api_def:base_api_def",
        "//tensorflow/core/api_def:java_api_def",
    ],
)
82
# Command-line generator invoked by ":java_op_gen_sources" to emit Java op classes.
tf_cc_binary(
    name = "java_op_gen_tool",
    srcs = ["src/gen/cc/op_gen_main.cc"],
    deps = [
        ":java_op_gen_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:ops",
    ],
    copts = tf_copts(),
    # libm is linked explicitly on every platform except Windows.
    linkopts = select({
        "//tensorflow:windows": [],
        "//conditions:default": ["-lm"],
    }),
    linkstatic = 1,
)
102
# Core of the Java op generator (op specs, source writer, generator logic),
# shared by ":java_op_gen_tool" and ":source_writer_test".
cc_library(
    name = "java_op_gen_lib",
    hdrs = [
        "src/gen/cc/java_defs.h",
        "src/gen/cc/op_generator.h",
        "src/gen/cc/op_specs.h",
        "src/gen/cc/source_writer.h",
    ],
    srcs = [
        "src/gen/cc/op_generator.cc",
        "src/gen/cc/op_specs.cc",
        "src/gen/cc/source_writer.cc",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:op_gen_lib",
        "//tensorflow/core:protos_all_cc",
        "@com_googlesource_code_re2//:re2",
    ],
    copts = tf_copts(),
)
127
# Test-only helper class (TestUtil) that most tf_java_test targets in this
# package depend on.
java_library(
    name = "testutil",
    testonly = 1,
    deps = [":tensorflow"],
    srcs = ["src/test/java/org/tensorflow/TestUtil.java"],
    javacopts = JAVACOPTS,
)
135
# JUnit test class org.tensorflow.GraphTest.
tf_java_test(
    name = "GraphTest",
    size = "small",
    test_class = "org.tensorflow.GraphTest",
    srcs = ["src/test/java/org/tensorflow/GraphTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
148
# JUnit test class org.tensorflow.OperationBuilderTest.
tf_java_test(
    name = "OperationBuilderTest",
    size = "small",
    test_class = "org.tensorflow.OperationBuilderTest",
    srcs = ["src/test/java/org/tensorflow/OperationBuilderTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
161
# JUnit test class org.tensorflow.OperationTest.
tf_java_test(
    name = "OperationTest",
    size = "small",
    test_class = "org.tensorflow.OperationTest",
    srcs = ["src/test/java/org/tensorflow/OperationTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
174
# JUnit test class org.tensorflow.SavedModelBundleTest; it receives the
# half_plus_two SavedModel fixture as runtime data.
tf_java_test(
    name = "SavedModelBundleTest",
    size = "small",
    test_class = "org.tensorflow.SavedModelBundleTest",
    srcs = ["src/test/java/org/tensorflow/SavedModelBundleTest.java"],
    data = ["//tensorflow/cc/saved_model:saved_model_half_plus_two"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
188
# JUnit test class org.tensorflow.SessionTest.
tf_java_test(
    name = "SessionTest",
    size = "small",
    test_class = "org.tensorflow.SessionTest",
    srcs = ["src/test/java/org/tensorflow/SessionTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
201
# JUnit test class org.tensorflow.ShapeTest.
tf_java_test(
    name = "ShapeTest",
    size = "small",
    test_class = "org.tensorflow.ShapeTest",
    srcs = ["src/test/java/org/tensorflow/ShapeTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
214
# Custom-op shared library supplied as runtime data to ":TensorFlowTest".
tf_custom_op_library(
    name = "my_test_op.so",
    srcs = [
        "src/test/native/my_test_op.cc",
    ],
)
219
# JUnit test class org.tensorflow.TensorFlowTest. Unlike most tests here it
# does not depend on ":testutil"; it receives ":my_test_op.so" as runtime data.
tf_java_test(
    name = "TensorFlowTest",
    size = "small",
    test_class = "org.tensorflow.TensorFlowTest",
    srcs = ["src/test/java/org/tensorflow/TensorFlowTest.java"],
    data = [":my_test_op.so"],
    deps = [
        ":tensorflow",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
232
# JUnit test class org.tensorflow.TensorTest.
tf_java_test(
    name = "TensorTest",
    size = "small",
    test_class = "org.tensorflow.TensorTest",
    srcs = ["src/test/java/org/tensorflow/TensorTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
245
# JUnit test class org.tensorflow.op.ScopeTest.
tf_java_test(
    name = "ScopeTest",
    size = "small",
    test_class = "org.tensorflow.op.ScopeTest",
    srcs = ["src/test/java/org/tensorflow/op/ScopeTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
258
# JUnit test class org.tensorflow.op.PrimitiveOpTest.
tf_java_test(
    name = "PrimitiveOpTest",
    size = "small",
    test_class = "org.tensorflow.op.PrimitiveOpTest",
    srcs = ["src/test/java/org/tensorflow/op/PrimitiveOpTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
271
# JUnit test class org.tensorflow.op.OperandsTest.
tf_java_test(
    name = "OperandsTest",
    size = "small",
    test_class = "org.tensorflow.op.OperandsTest",
    srcs = ["src/test/java/org/tensorflow/op/OperandsTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
284
# JUnit test class org.tensorflow.op.core.ConstantTest.
tf_java_test(
    name = "ConstantTest",
    size = "small",
    test_class = "org.tensorflow.op.core.ConstantTest",
    srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
297
# JUnit test class org.tensorflow.op.core.GeneratedOperationsTest.
tf_java_test(
    name = "GeneratedOperationsTest",
    size = "small",
    test_class = "org.tensorflow.op.core.GeneratedOperationsTest",
    srcs = ["src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
310
# JUnit test class org.tensorflow.op.core.GradientsTest.
tf_java_test(
    name = "GradientsTest",
    size = "small",
    test_class = "org.tensorflow.op.core.GradientsTest",
    srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
323
# JUnit test class org.tensorflow.op.core.ZerosTest.
tf_java_test(
    name = "ZerosTest",
    size = "small",
    test_class = "org.tensorflow.op.core.ZerosTest",
    srcs = ["src/test/java/org/tensorflow/op/core/ZerosTest.java"],
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
    javacopts = JAVACOPTS,
)
336
# Java inputs for annotation-processor tests: the test resource sources plus
# op/annotation/Operator.java.
# NOTE(review): nothing in this file references this target and it inherits
# the package's private default visibility — confirm it is still consumed.
filegroup(
    name = "processor_test_resources",
    srcs = glob([
        "src/test/resources/org/tensorflow/**/*.java",
        "src/main/java/org/tensorflow/op/annotation/Operator.java",
    ]),
)
344
# C++ unit test for src/gen/cc/source_writer.cc; the Java snippet resource is
# made available to it as runtime data.
tf_cc_test(
    name = "source_writer_test",
    size = "small",
    srcs = ["src/gen/cc/source_writer_test.cc"],
    data = ["src/gen/resources/test.java.snippet"],
    deps = [
        ":java_op_gen_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
361
# Selects the JNI shared library file for the target OS. The three file names
# are presumably the per-OS outputs of ":tensorflow_jni" (which sets
# per_os_targets = True) — confirm against tf_cc_binary in tensorflow.bzl.
filegroup(
    name = "libtensorflow_jni",
    visibility = ["//visibility:public"],
    srcs = select({
        "//tensorflow:macos": [":libtensorflow_jni.dylib"],
        "//tensorflow:windows": [":tensorflow_jni.dll"],
        "//conditions:default": [":libtensorflow_jni.so"],
    }),
)
371
# Linker input passed via -Wl,--version-script on non-macOS, non-Windows builds
# of ":tensorflow_jni".
LINKER_VERSION_SCRIPT = ":config/version_script.lds"

# Linker input passed via -Wl,-exported_symbols_list on macOS builds of
# ":tensorflow_jni".
LINKER_EXPORTED_SYMBOLS = ":config/exported_symbols.lds"
375
# Shared JNI library for the Java API, built as one target per OS.
tf_cc_binary(
    name = "tensorflow_jni",
    linkshared = 1,
    linkstatic = 1,
    per_os_targets = True,
    # The linker scripts are also listed here; that is what lets the
    # $(location ...) references in `linkopts` resolve them.
    deps = [
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/java/src/main/native",
        LINKER_VERSION_SCRIPT,
        LINKER_EXPORTED_SYMBOLS,
    ],
    # Strip every symbol except the JNI entry points from the library, which
    # shrinks it considerably (~50% as of January 2017). Debug and Windows
    # builds get no custom linker options.
    # NOTE(review): if "//tensorflow:debug" and an OS setting can both match
    # (e.g. -c dbg on macOS), select() reports an ambiguous match unless one
    # config_setting refines the other — confirm against their definitions.
    linkopts = select({
        "//tensorflow:debug": [],
        "//tensorflow:windows": [],
        "//tensorflow:macos": [
            "-Wl,-exported_symbols_list,$(location {})".format(LINKER_EXPORTED_SYMBOLS),
        ],
        "//conditions:default": [
            "-z defs",
            "-s",
            "-Wl,--version-script,$(location {})".format(LINKER_VERSION_SCRIPT),
        ],
    }),
)
403
# Produces pom.xml by running ":generate_pom" and capturing its standard output.
genrule(
    name = "pom",
    outs = ["pom.xml"],
    # Reference the tool by the same explicit label listed in `tools`.
    cmd = "$(location :generate_pom) > $@",
    output_to_bindir = True,
    tools = [":generate_pom"] + tf_binary_additional_srcs(),
)
411
# Helper binary executed by the ":pom" genrule; links against the TensorFlow C API.
tf_cc_binary(
    name = "generate_pom",
    deps = ["//tensorflow/c:c_api"],
    srcs = ["generate_pom.cc"],
)
417