# Ops that communicate with other processes via MPI.
2
# Unless a target overrides it, everything declared in this package is visible
# only to packages underneath //tensorflow.
package(default_visibility = ["//tensorflow:__subpackages__"])

licenses(["notice"])  # Apache 2.0
8
9load(
10    "//tensorflow/core:platform/default/build_config.bzl",
11    "tf_additional_mpi_lib_defines",
12    "tf_proto_library_cc",
13)
14
# Protobuf target built from mpi_message.proto. Other rules in this file
# depend on it through the label ":mpi_message_proto_cc".
tf_proto_library_cc(
    name = "mpi_message_proto",
    srcs = ["mpi_message.proto"],
    cc_api_version = 2,
    protodeps = ["//tensorflow/core:protos_all"],
    visibility = ["//tensorflow:__subpackages__"],
)
24
# Source-less library that exists only to carry the preprocessor defines
# returned by tf_additional_mpi_lib_defines() to targets that depend on it.
cc_library(
    name = "mpi_defines",
    defines = tf_additional_mpi_lib_defines(),
)
29
30load(
31    "//tensorflow:tensorflow.bzl",
32    "tf_custom_op_py_library",
33    "tf_custom_op_library",
34    "tf_gen_op_wrapper_py",
35    "tf_gen_op_libs",
36    "tf_kernel_library",
37    "tf_py_test",
38)
39
# Custom-op shared object. The Python library target lists it under `dso`
# and the test target lists it under `data`.
tf_custom_op_library(
    name = "python/ops/_mpi_ops.so",
    # Kernel implementations plus the op source under ops/.
    srcs = [
        "kernels/mpi_ops.cc",
        "kernels/ring.cc",
        "kernels/ring.h",
        "ops/mpi_ops.cc",
    ],
    # NOTE(review): ring.h appears in both srcs and gpu_srcs, presumably
    # because the GPU sources are compiled as a separate unit - confirm
    # against the tf_custom_op_library macro before de-duplicating.
    gpu_srcs = [
        "kernels/ring.cu.cc",
        "kernels/ring.h",
    ],
    deps = [
        ":mpi_defines",
        ":mpi_message_proto_cc",
        "//third_party/mpi",
    ],
)
58
# Kernel library for the MPI ops. It compiles the same kernel sources as the
# ":python/ops/_mpi_ops.so" custom-op target, so it declares the same MPI and
# message-proto dependencies that target lists for those sources.
tf_kernel_library(
    name = "mpi_ops_kernels",
    srcs = [
        "kernels/mpi_ops.cc",
        "kernels/ring.cc",
    ],
    hdrs = [
        "kernels/ring.h",
    ],
    gpu_srcs = [
        "kernels/ring.cu.cc",
    ],
    deps = [
        ":mpi_defines",
        # Declared for the same sources by ":python/ops/_mpi_ops.so".
        ":mpi_message_proto_cc",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:gpu_headers_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:proto_text",
        "//tensorflow/core:stream_executor",
        # Declared for the same sources by ":python/ops/_mpi_ops.so".
        "//third_party/mpi",
    ],
    # TODO: Decide whether `alwayslink = 1` must be set explicitly here.
    # NOTE(review): the tf_kernel_library macro may already force it -
    # confirm in //tensorflow:tensorflow.bzl before adding the attribute.
)
82
# Op library for "mpi_ops"; other rules in this file refer to the result as
# ":mpi_ops_op_lib".
tf_gen_op_libs(op_lib_names = ["mpi_ops"])
86
# Generated Python op wrappers, built from the ":mpi_ops_op_lib" op library.
tf_gen_op_wrapper_py(
    name = "mpi_ops",
    deps = [
        ":mpi_ops_op_lib",
    ],
)
91
# Publicly visible Python library for the MPI collectives package. It ties
# together the Python sources, the custom-op shared object, the kernel/op
# libraries and the generated op wrappers.
tf_custom_op_py_library(
    name = "mpi_collectives_py",
    srcs = [
        "__init__.py",
        "python/ops/mpi_ops.py",
    ],
    # Shared object produced by the tf_custom_op_library rule in this file.
    dso = [":python/ops/_mpi_ops.so"],
    kernels = [
        ":mpi_ops_kernels",
        ":mpi_ops_op_lib",
    ],
    srcs_version = "PY2AND3",
    # Overrides the package default so that code outside //tensorflow can
    # depend on this library.
    visibility = ["//visibility:public"],
    deps = [
        ":mpi_ops",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:device",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:util",
    ],
)
116
# Python test for the MPI ops. The custom-op shared object is supplied as a
# runtime data dependency.
tf_py_test(
    name = "mpi_ops_test",
    srcs = [
        "mpi_ops_test.py",
    ],
    additional_deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:platform",
    ],
    data = [":python/ops/_mpi_ops.so"],
    # "manual" keeps the test out of wildcard target patterns; it only runs
    # when named explicitly.
    # NOTE(review): presumably because it needs an MPI runtime - confirm.
    tags = ["manual"],
)
129