• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- Python -*-
2
3load(
4    "//tensorflow:tensorflow.bzl",
5    "tf_binary_additional_srcs",
6)
7
# Generate Java wrapper classes for all registered core operations and package
# them into a single source archive (.srcjar).
#
# For example:
#  tf_java_op_gen_srcjar("gen_sources", ":gen_tool", "my.package")
#
# will create a genrule named "gen_sources" that, with the default out_dir and
# out_src_dir, generates source files under
#     ops/src/main/java/my/package/**/*.java
#
# and then archives those source files into
#     ops/gen_sources.srcjar
#
def tf_java_op_gen_srcjar(
        name,
        gen_tool,
        base_package,
        api_def_srcs = [],
        out_dir = "ops/",
        out_src_dir = "src/main/java/",
        visibility = ["//tensorflow/java:__pkg__"]):
    """Generates Java op wrappers with `gen_tool` and packs them into a .srcjar.

    Creates a genrule named `name` whose single output is
    `<out_dir><name>.srcjar`. The rule runs `gen_tool` to write Java sources
    under `$(@D)/<out_src_dir>`, then archives that source tree with the JDK
    `jar` tool.

    Args:
      name: Name of the generated genrule; also the base name of the .srcjar.
      gen_tool: Label of the op wrapper generator binary. It is invoked with
        `--output_dir`, `--base_package` and `--api_dirs` flags.
      base_package: Java package passed to the generator as `--base_package`.
      api_def_srcs: Labels of ApiDef sources. For each label, the directory of
        its first expanded file is added to the comma-separated `--api_dirs`
        list; all files of one label are assumed to share a directory. When
        empty, `--api_dirs=,` is passed.
      out_dir: Package-relative directory prefix (expected to end with "/")
        for the generated .srcjar.
      out_src_dir: Directory, relative to the genrule output directory, into
        which the generator writes sources. Its first path component is the
        directory archived into the .srcjar.
      visibility: Visibility of the generated genrule.
    """

    # Always start fresh so stale sources from a previous run are not archived.
    gen_cmds = ["rm -rf $(@D)"]

    if api_def_srcs:
        # The generator takes directories, not files: for each label, use the
        # directory of its first expanded file ("$$" escapes "$" for the shell).
        api_def_args_str = ",".join([
            "$$(dirname $$(echo $(locations " + api_def_src +
            ") | cut -d\" \" -f1))"
            for api_def_src in api_def_srcs
        ])
    else:
        api_def_args_str = ","

    gen_cmds.append(
        "$(location " + gen_tool + ")" +
        " --output_dir=$(@D)/" + out_src_dir +
        " --base_package=" + base_package +
        " --api_dirs=" + api_def_args_str,
    )

    # Archive the top-level directory the generator actually wrote into,
    # derived from out_src_dir instead of hard-coding "src", so a non-default
    # out_src_dir still produces a valid archive.
    out_src_parts = [p for p in out_src_dir.split("/") if p and p != "."]
    if not out_src_parts:
        fail("out_src_dir must name a subdirectory of the output directory, got '%s'" % out_src_dir)
    out_src_root = out_src_parts[0]

    # Generate a source archive containing generated code for these ops.
    gen_srcjar = out_dir + name + ".srcjar"
    gen_cmds.append(
        "$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ")" +
        " -C $(@D) " + out_src_root,
    )

    native.genrule(
        name = name,
        # $(locations ...) above requires each ApiDef label to be a src.
        srcs = list(api_def_srcs),
        outs = [gen_srcjar],
        tools = [
            "@local_jdk//:jar",
            "@local_jdk//:jdk",
            gen_tool,
        ] + tf_binary_additional_srcs(),
        cmd = " && ".join(gen_cmds),
        # Previously accepted but never forwarded; honor the caller's value.
        visibility = visibility,
    )
65