/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_

#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/core/platform/logging.h"

// TODO(jblespiau): The current implementation moves the Python logic to C++,
// as a preliminary step to executing the `pmap` execution path from C++.
// It implements the current Python behavior (thus, it may not be optimal, and
// we will be able to modify it later).

namespace jax {

// High-level introduction.
//
// pmap and other parallel computation functions distribute a computation over
// several devices. As of December 2020, the device mesh (i.e. the
// N-dimensional array of devices over which the computation is mapped) is
// defined by the user.
//
// A `ShardingSpec` describes how to shard the inputs and how to map the
// shards to the mesh of devices. It is based on two components:
// - `sharding`, which specifies how to shard the inputs.
// - `mesh_mapping`, which specifies how to map shards to devices.

// The following three structs define how to shard one dimension of an
// ndarray.
//
// `NoSharding` (`None` in Python) means no sharding.
struct NoSharding {};

// `Chunked` means that the dimension is split into np.prod(chunks) chunks
// and the split dimension itself is preserved inside the map.
// Those chunks are distributed over `len(chunks)` `ShardedAxis` axes
// (major-to-minor).
// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])]
// (with p dividing N, let S = N // p), the tensor is split into p chunks of
// shape [S], such that sharded_t[k] = t[k * S : (k + 1) * S] (left inclusive,
// right exclusive) for k in {0, ..., p-1}.
struct Chunked {
 public:
  explicit Chunked(std::vector<int> chunks_) : chunks(std::move(chunks_)) {}
  // The number of chunks per axis.
  std::vector<int> chunks;

  bool operator==(const Chunked& other) const { return chunks == other.chunks; }
  bool operator!=(const Chunked& other) const { return chunks != other.chunks; }
};
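// Illustrative sketch (the shape and chunk counts below are example values,
// not anything prescribed by this header): sharding a dimension of size 8
// with `Chunked({2, 2})` produces np.prod({2, 2}) = 4 chunks of size 2,
// distributed over two mesh axes (major-to-minor):
//
//   Chunked chunked({2, 2});
//   // For a tensor t of shape [8], mesh position (i, j), with i, j in
//   // {0, 1}, holds the chunk t[(2 * i + j) * 2 : (2 * i + j + 1) * 2].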
// `Unstacked` means that the dimension is split into chunks of size 1, and
// doesn't appear inside the map. `size` is always the dimension size.
// For example, a tensor t of shape [N] will be sharded into N tensors of
// shape [] when using `Unstacked(N)`.
struct Unstacked {
 public:
  explicit Unstacked(int sz) : size(sz) {}
  int size;

  bool operator==(const Unstacked& other) const { return size == other.size; }
  bool operator!=(const Unstacked& other) const { return size != other.size; }
};

using AvalDimSharding = absl::variant<NoSharding, Chunked, Unstacked>;

// Assigns sharded axes to mesh dimensions.
//
// Each dimension of the device mesh is assigned either one of the sharded
// axes of the array (a `ShardedAxis`) or is marked as `Replicated`, in which
// case the data is replicated along that mesh dimension.
// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually
// sharded axis (i.e. counting as if the `None` dimensions of the sharding
// were filtered out).
// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an
// entry of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`.
struct ShardedAxis {
  int axis;
  bool operator==(const ShardedAxis& other) const { return axis == other.axis; }
  bool operator!=(const ShardedAxis& other) const { return axis != other.axis; }
};

struct Replicated {
  int replicas;
  bool operator==(const Replicated& other) const {
    return replicas == other.replicas;
  }
  bool operator!=(const Replicated& other) const {
    return replicas != other.replicas;
  }
};

using MeshDimAssignment = absl::variant<ShardedAxis, Replicated>;

// Describes how each axis is sharded (if it is), and how it's mapped to the
// device mesh.
class ShardingSpec {
 public:
  ShardingSpec(std::vector<AvalDimSharding> sharding,
               std::vector<MeshDimAssignment> mesh_mapping)
      : sharding_(std::move(sharding)),
        mesh_mapping_(std::move(mesh_mapping)) {}

  const std::vector<AvalDimSharding>& GetSharding() const { return sharding_; }
  const std::vector<MeshDimAssignment>& GetMeshMapping() const {
    return mesh_mapping_;
  }

 private:
  // `sharding` specifies how the array is supposed to get partitioned into
  // chunks. Its length matches the rank of the array. See the docstring
  // of `AvalDimSharding` for the supported partitioning schemes.
  std::vector<AvalDimSharding> sharding_;
  // `mesh_mapping` describes an assignment of the array chunks created by
  // `sharding` to a logical device mesh. Its length equals the rank of the
  // mesh. Each mesh dimension can either get partitions of data varying
  // along one of the sharded dimensions, or the data can be replicated.
  std::vector<MeshDimAssignment> mesh_mapping_;
};
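// Illustrative sketch (the array shape and mesh size are example values, not
// taken from the implementation): a `ShardingSpec` for an array of shape
// [8, 3] mapped onto a one-dimensional mesh of 4 devices, where dimension 0
// is split into 4 chunks and dimension 1 is left unsharded:
//
//   ShardingSpec spec({Chunked({4}), NoSharding()}, {ShardedAxis{0}});
//
// Each device then holds one chunk of shape [2, 3]. Appending `Replicated{2}`
// to the mesh mapping would instead describe a [4, 2] mesh of 8 devices in
// which each chunk is replicated on 2 devices.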
// A ShardedDeviceArray is an ndarray sharded across devices.
//
// The purpose of a ShardedDeviceArray is to reduce the number of transfers
// when executing replicated computations, by allowing results to persist on
// the devices that produced them. That way dispatching a similarly replicated
// computation that consumes the same sharded memory layout does not incur any
// transfers.
//
// A ShardedDeviceArray represents one logical ndarray value, and simulates
// the behavior of an ndarray so that it can be treated by user code as an
// ndarray; that is, it is only an optimization to reduce transfers.
//
// Design note: we move to C++ only what needs to be accessed from C++ to
// execute a pmap computation. A large part of the logic is still in Python.
class ShardedDeviceArray : xla::DeviceArrayBase {
 public:
  ShardedDeviceArray(
      pybind11::handle aval, ShardingSpec sharding_spec,
      // Buffers are expected to be xla::PyBuffer objects, but as there are
      // alternative backend implementations, this may not be guaranteed.
      // TODO(jblespiau): As soon as PjRtBuffer is supported by all
      // implementations, we should be able to store this with the C++
      // objects.
      pybind11::list device_buffers)
      : DeviceArrayBase(),
        aval_(pybind11::cast<pybind11::object>(aval)),
        sharding_spec_(std::move(sharding_spec)),
        device_buffers_(device_buffers) {}

  pybind11::object GetAval() const { return aval_; }
  const ShardingSpec& GetShardingSpec() const { return sharding_spec_; }
  pybind11::list GetDeviceBuffers() const { return device_buffers_; }

 private:
  // A ShapedArray indicating the shape and dtype of this array.
  pybind11::object aval_;
  // Describes how this array is sharded across `device_buffers`.
  ShardingSpec sharding_spec_;
  // The buffers containing the data for this array. Each buffer is the same
  // shape and on a different device. Buffers are in row-major order, with
  // replication treated as an extra innermost dimension.
  pybind11::list device_buffers_;
};

void BuildPmapSubmodule(pybind11::module& m);

}  // namespace jax

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_
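// Usage sketch for BuildPmapSubmodule (hypothetical embedding code; the
// module name below is an assumption, not something defined in this header):
// it is intended to be called from a pybind11 module initializer, e.g.
//
//   PYBIND11_MODULE(xla_extension, m) { jax::BuildPmapSubmodule(m); }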