/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helper routines for the XLA device.

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/managed_stack_trace.h"

namespace tensorflow {

// Helper methods for building XLA computations.
class XlaHelpers {
 public:
  // Returns a handle representing the zero value of a scalar
  // element of data_type.
  static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type);

  // Returns a handle representing the one value of a scalar
  // element of data_type.
  static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type);

  // Returns a handle representing the given value of an integer scalar
  // element of data_type. Note that unlike One and Zero, this does not work
  // on boolean types.
  static xla::XlaOp IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
                                   int64_t value);

  // Returns a handle representing the given value of a floating-point scalar
  // element of data_type.
  static xla::XlaOp FloatLiteral(xla::XlaBuilder* b, DataType data_type,
                                 double value);

  // Reshapes literal 'input' to have 'shape'. Both the original shape and
  // 'shape' must contain the same number of elements.
  static Status ReshapeLiteral(const xla::Literal& input,
                               absl::Span<const int64_t> shape,
                               xla::Literal* output);

  // Converts `indices` into a one-hot representation. `depth` is the size
  // of the new axis to add. `axis` is the position at which to add the new
  // axis. `indices_shape` is the shape of `indices`. `on_value` and
  // `off_value` represent the values to use for the on and off positions,
  // respectively.
  static Status OneHot(xla::XlaBuilder* builder, int64_t depth, int axis,
                       DataType index_type, const TensorShape& indices_shape,
                       const xla::XlaOp& indices, const xla::XlaOp& on_value,
                       const xla::XlaOp& off_value, xla::XlaOp* one_hot);

  // Certain DataTypes should use increased precision DataTypes when performing
  // reductions. This function remaps a given DataType to a higher precision
  // DataType if needed.
  static DataType SumAccumulationType(const DataType& dtype);

  // A helper for creating a ConvertElementType xla op given a DataType rather
  // than the xla::PrimitiveType.
  static xla::XlaOp ConvertElementType(const xla::XlaOp& operand,
                                       const DataType new_element_type);

  typedef std::function<StatusOr<xla::Shape>(const TensorShape&, DataType,
                                             bool)>
      ShapeRepresentationFn;
};
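
// Illustrative usage sketch (not part of the API surface): building a one-hot
// value with the helpers above. `builder` and `indices` are assumed to be
// provided by the caller; the depth/axis values are arbitrary.
//
//   xla::XlaOp one_hot;
//   TF_RETURN_IF_ERROR(XlaHelpers::OneHot(
//       builder, /*depth=*/4, /*axis=*/1, DT_INT32,
//       /*indices_shape=*/TensorShape({2}), indices,
//       /*on_value=*/XlaHelpers::One(builder, DT_FLOAT),
//       /*off_value=*/XlaHelpers::Zero(builder, DT_FLOAT), &one_hot));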

// Creates an identity shape representation function.
XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn();

// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_shape);

// Adds reshapes to fix the layout of an output, if a shape_representation_fn
// or sharding is present.
StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    absl::optional<xla::OpSharding> sharding, bool fast_mem);
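
// Illustrative sketch (an assumption, not part of this header): a
// ShapeRepresentationFn that simply falls back to the default XLA layout via
// TensorShapeToXLAShape from tensorflow/compiler/tf2xla/shape_util.h.
//
//   XlaHelpers::ShapeRepresentationFn shape_fn =
//       [](const TensorShape& shape, DataType dtype,
//          bool use_fast_memory) -> StatusOr<xla::Shape> {
//         xla::Shape xla_shape;
//         TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
//         return xla_shape;
//       };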

struct XlaOutputDescription {
  // Type and shape of the output. The shape is the unflattened shape.
  // When `type` is DT_RESOURCE, `shape` is the shape of the resource
  // variable's value.
  DataType type;
  TensorShape shape;

  // Constant output value, if known to be constant at JIT compilation time.
  // 'Tensor' is in host memory.
  bool is_constant = false;
  Tensor constant_value;

  // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
  // the index of the input that contains the resource.
  int input_index;

  // Whether this output is a TensorList.
  bool is_tensor_list = false;
};

// Describes a variable write side effect of the computation.
struct XlaResourceUpdate {
  // Index of the input that contains the variable resource to write to.
  int input_index;

  // Type and shape of the tensor to be written back.
  // The `shape` field has the same meaning as the Argument::shape field.
  DataType type;
  TensorShape shape;

  // Was the value of the variable modified by the computation?
  // (Always true, unless `return_updated_values_for_all_resources` is true.)
  bool modified;

  // If the resource is a TensorArray, the set of gradients read or written.
  std::set<string> tensor_array_gradients_accessed;
};

struct XlaCompilationResult {
  // Vector that maps from the parameters of the XLA computation to their
  // original argument positions. To handle compile-time constant inputs, the
  // parameters to the XLA computation may be a subset of the original
  // arguments. The relative ordering of parameters is maintained.
  std::vector<int> input_mapping;

  // Input shapes of the computation. If we are flattening inputs, these are
  // the flattened shapes.
  std::vector<xla::Shape> xla_input_shapes;

  // Output shape in XLA format. The output shape is always a tuple. If we
  // are flattening outputs, these are the flattened shapes.
  xla::Shape xla_output_shape;

  // TensorFlow shapes of outputs, together with the values of any
  // constant arguments. Vector indexed by TensorFlow _Retval number,
  // containing both constant and non-constant results.
  std::vector<XlaOutputDescription> outputs;

  // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
  // matching RecvAtHost/SendFromHost Ops in the outer graph.
  tf2xla::HostComputeMetadata host_compute_metadata;

  // Resources whose values were updated by the computation, ordered
  // by return value position (which is the same as the order the resources
  // were passed as arguments). Resource updates follow the non-constant
  // results in the outputs of the XLA computation.
  std::vector<XlaResourceUpdate> resource_updates;

  // The XLA computation built from the TensorFlow subgraph.
  std::shared_ptr<xla::XlaComputation> computation;

  // Meta-info about encountered CollectiveReduceV2 ops.
  struct CollectiveReduceV2OpInfo {
    int group_key;
    int group_size;
  };

  // Group key and group size of the collectives encountered during the
  // translation, if any.
  absl::optional<CollectiveReduceV2OpInfo> collective_reduce_info;
};

// Resolves the device assignment based on CollectiveReduceV2OpInfo, which
// records collective ops in the cluster. Note that this relies on a
// rendezvous and blocks until all replicas are present.
StatusOr<absl::optional<xla::DeviceAssignment>> ResolveDeviceAssignment(
    OpKernelContext* ctx,
    const absl::optional<XlaCompilationResult::CollectiveReduceV2OpInfo>&
        collective_reduce_info);

// Generates a message with a definition location based on a provided stack
// trace, or an empty one if the stack trace is empty.
std::string DefinitionLocationMsg(
    const absl::optional<ManagedStackTrace>& stack_trace);

}  // end namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_