# Ops that communicate with other processes via MPI.
#
# Targets defined here:
#   mpi_message_proto      - C++ protobuf library for mpi_message.proto
#                            (consumed as :mpi_message_proto_cc).
#   mpi_defines            - carries tf_additional_mpi_lib_defines() to dependents.
#   python/ops/_mpi_ops.so - custom-op shared object (kernels + op definitions).
#   mpi_ops_kernels        - the same kernel sources as a kernel library.
#   mpi_collectives_py     - public Python library exposing the ops.
#   mpi_ops_test           - Python test, only run when requested explicitly.

load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_library",
    "tf_custom_op_py_library",
    "tf_gen_op_libs",
    "tf_gen_op_wrapper_py",
    "tf_kernel_library",
    "tf_py_test",
)
load(
    "//tensorflow/core:platform/default/build_config.bzl",
    "tf_additional_mpi_lib_defines",
    "tf_proto_library_cc",
)

package(default_visibility = [
    "//tensorflow:__subpackages__",
])

licenses(["notice"])  # Apache 2.0

# Protobuf messages from mpi_message.proto; the C++ library is referenced
# elsewhere in this file as :mpi_message_proto_cc.
tf_proto_library_cc(
    name = "mpi_message_proto",
    srcs = ["mpi_message.proto"],
    cc_api_version = 2,
    protodeps = ["//tensorflow/core:protos_all"],
    visibility = [
        "//tensorflow:__subpackages__",
    ],
)

# Header-less library whose only job is to propagate the platform-specific
# MPI preprocessor defines (cc_library `defines` apply to all dependents).
cc_library(
    name = "mpi_defines",
    defines = tf_additional_mpi_lib_defines(),
)

# Self-contained custom-op shared object: kernels, ring implementation and
# op definitions in one .so. Used as the `dso` of :mpi_collectives_py and as
# `data` for :mpi_ops_test.
tf_custom_op_library(
    name = "python/ops/_mpi_ops.so",
    srcs = [
        "kernels/mpi_ops.cc",
        "kernels/ring.cc",
        "kernels/ring.h",
        "ops/mpi_ops.cc",
    ],
    gpu_srcs = [
        "kernels/ring.cu.cc",
        "kernels/ring.h",
    ],
    deps = [
        ":mpi_defines",
        ":mpi_message_proto_cc",
        "//third_party/mpi",
    ],
)

# The same kernel sources as above, packaged as a kernel library so they can
# be linked statically (listed under `kernels` of :mpi_collectives_py).
tf_kernel_library(
    name = "mpi_ops_kernels",
    srcs = [
        "kernels/mpi_ops.cc",
        "kernels/ring.cc",
    ],
    hdrs = [
        "kernels/ring.h",
    ],
    gpu_srcs = [
        "kernels/ring.cu.cc",
    ],
    # kernels/mpi_ops.cc is also compiled into _mpi_ops.so above, which
    # declares :mpi_message_proto_cc and //third_party/mpi; declare them here
    # too so this target does not depend on them arriving transitively.
    deps = [
        ":mpi_defines",
        ":mpi_message_proto_cc",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:gpu_headers_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:proto_text",
        "//tensorflow/core:stream_executor",
        "//third_party/mpi",
    ],
    # alwayslink keeps every object file of this library in the final link
    # even when no symbol from it is referenced directly.
    # NOTE(review): an earlier "TODO: Include?" questioned this setting;
    # presumably it is needed so statically-registered kernels are not dropped
    # by the linker — confirm before removing.
    alwayslink = 1,
)

# Op-definition library; provides :mpi_ops_op_lib (used below).
tf_gen_op_libs(
    op_lib_names = ["mpi_ops"],
)

# Generated Python wrappers for the ops in :mpi_ops_op_lib.
tf_gen_op_wrapper_py(
    name = "mpi_ops",
    deps = [":mpi_ops_op_lib"],
)

# Public Python entry point for the MPI collectives.
tf_custom_op_py_library(
    name = "mpi_collectives_py",
    srcs = [
        "__init__.py",
        "python/ops/mpi_ops.py",
    ],
    dso = [
        ":python/ops/_mpi_ops.so",
    ],
    kernels = [
        ":mpi_ops_kernels",
        ":mpi_ops_op_lib",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":mpi_ops",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:device",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:util",
    ],
)

# Tagged "manual", so wildcard patterns such as //... skip it; it must be
# requested by name. NOTE(review): presumably because it needs an MPI
# runtime — confirm.
tf_py_test(
    name = "mpi_ops_test",
    srcs = ["mpi_ops_test.py"],
    additional_deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:platform",
    ],
    data = [
        ":python/ops/_mpi_ops.so",
    ],
    tags = ["manual"],
)