1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ 18 19 #include <memory> 20 21 #include "absl/strings/string_view.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 #include "tensorflow/compiler/xla/xla_data.pb.h" 24 #include "tensorflow/core/framework/tensor_shape.h" 25 #include "tensorflow/core/framework/types.pb.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/util/managed_stack_trace.h" 28 29 namespace tensorflow { 30 31 // Represents a resource, such as a Variable or TensorArray. 32 class XlaResource { 33 public: 34 enum Kind { 35 kInvalid, 36 kVariable, 37 kTensorArray, 38 kStack, 39 }; 40 static absl::string_view KindToString(Kind kind); 41 42 // Creates a new Stack resource. 43 static std::unique_ptr<XlaResource> CreateStack(string name, DataType type, 44 int64_t max_size); 45 46 // Creates a new TensorArray resource. 47 static std::unique_ptr<XlaResource> CreateTensorArray( 48 string name, DataType type, TensorShape shape, xla::XlaOp initial_value, 49 int64_t max_array_size); 50 51 XlaResource(Kind kind, int arg_num, string name, DataType type, 52 TensorShape shape, xla::XlaOp initial_value, 53 int64_t max_array_size, 54 const std::set<string>& tensor_array_gradients, 55 bool tensor_array_multiple_writes_aggregate, 56 const std::optional<ManagedStackTrace>& definition_stack_trace = 57 std::nullopt); 58 59 XlaResource(const XlaResource&) = delete; 60 XlaResource(XlaResource&&) = delete; 61 XlaResource& operator=(const XlaResource&) = delete; 62 XlaResource& operator=(XlaResource&&) = delete; 63 kind()64 Kind kind() const { return kind_; } 65 66 // If this resource is visible externally to the computation, what was its 67 // argument number? 68 // < 0 means "not visible externally". arg_num()69 int arg_num() const { return arg_num_; } 70 71 // A descriptive name for the resource, used in error messages. name()72 const string& name() const { return name_; } 73 74 // Current type and value of the resource. Uninitialized resources are 75 // represented by a default (zero) handle and type DT_INVALID. 76 // While the type of a resource is notionally fixed during execution, when 77 // a resource is first initialized we do not yet know its type, so we keep 78 // track of its type dynamically. type()79 DataType type() const { return type_; } 80 81 // Shape of the resource. For an uninitialized resource, this is ignored. 82 // For a Variable, this is the shape of the value. For a TensorArray or Stack 83 // this is the shape of each entry in the TensorArray/Stack. shape()84 const TensorShape& shape() const { return shape_; } 85 value()86 const xla::XlaOp& value() const { return value_; } 87 88 // Value of the resource at computation entry. Used to detect which 89 // variables have new values that need to be written back. initial_value()90 const xla::XlaOp& initial_value() const { return initial_value_; } 91 92 // An xla shape that indicates how this resource variable is represented on 93 // device. representation_shape()94 const std::optional<xla::Shape>& representation_shape() const { 95 return representation_shape_; 96 } 97 98 // A variable is initialized if it has a value. initialized()99 bool initialized() const { return value_.valid(); } 100 101 // Sets the type and shape of the resource. The type and shape of a resource 102 // must not change once the variable has been initialized. 103 Status SetTypeAndShape(DataType type, const TensorShape& shape); 104 105 // Sets the current value of the resource. Returns an error if the type is not 106 // set to a valid value. 107 Status SetValue(const xla::XlaOp& value); 108 109 // Sets the current value of the resource to an all-zero value. 110 Status SetZeroValue(xla::XlaBuilder* builder); 111 112 // Sets the representational shape of the resource on device. SetRepresentationShape(const xla::Shape & shape)113 void SetRepresentationShape(const xla::Shape& shape) { 114 representation_shape_ = absl::make_optional(shape); 115 } 116 117 // Looks up the gradient for `source`, or creates it if it does not already 118 // exist. The call target must be an initialized TensorArray resource. A 119 // TensorArray can have multiple named gradients; see the operator 120 // documentation for TensorArrayGradV3 for details. 121 Status GetOrCreateTensorArrayGradient(const string& source, 122 xla::XlaBuilder* builder, 123 XlaResource** gradient_out); 124 125 // Packs a resource into a single XLA value `pack`, suitable for use as 126 // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without 127 // gradients, sets `*pack` to `value`. 128 // For TensorArrays with gradients, packs the value and its gradient values in 129 // a tuple; the gradients values are packed in order by source name. 130 Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const; 131 132 // Updates the resource with values from `pack`. If `gradient_sources` is 133 // non-empty, treats `pack` as a tuple that represents a TensorArray and 134 // its gradients, and unpacks and updates the gradient resources. 135 // If `reset_initial_values` is true, sets the initial_values as well as the 136 // values. 137 // Opposite of Pack(). 138 Status SetFromPack(const std::set<string>& gradient_sources, 139 const xla::XlaOp& pack, xla::XlaBuilder* builder); 140 IsOverwritten()141 bool IsOverwritten() { return is_overwritten_; } 142 143 // TensorArray and Stack specific fields 144 // TODO(phawkins): refactor this code to use subclasses, rather than putting 145 // kind-specific fields in XlaResource. 146 147 // 'max_array_size' stores the expected size of the TensorArray or Stack. 148 // We need to store this since sometimes TensorArrays must be initialized 149 // lazily since we do not know the element shape at construction time. 150 // Used by both TensorArrays and Stacks. max_array_size()151 int64_t max_array_size() const { return max_array_size_; } set_max_array_size(int64_t size)152 void set_max_array_size(int64_t size) { max_array_size_ = size; } 153 tensor_array_multiple_writes_aggregate()154 bool tensor_array_multiple_writes_aggregate() const { 155 return tensor_array_multiple_writes_aggregate_; 156 } 157 158 // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes 159 // to an XlaResource containing the gradient TensorArrays. We store a pointer 160 // here since there should only be one gradient TensorArray per 'source' 161 // string, irrespective of the number of calls to TensorArrayGrad. The map 162 // is ordered since values are packed into tuples by Pack() sorted by name 163 // order. tensor_array_gradients()164 const std::map<string, std::unique_ptr<XlaResource>>& tensor_array_gradients() 165 const { 166 return tensor_array_gradients_; 167 } 168 169 private: 170 const Kind kind_; 171 const int arg_num_; 172 const string name_; 173 174 DataType type_; 175 TensorShape shape_; 176 xla::XlaOp value_; 177 xla::XlaOp initial_value_; 178 179 // An xla shape that indicates how this resource variable is represented on 180 // device. 181 std::optional<xla::Shape> representation_shape_; 182 183 int64_t max_array_size_ = -1; 184 bool tensor_array_multiple_writes_aggregate_ = false; 185 186 std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_; 187 bool is_overwritten_ = false; 188 189 std::optional<ManagedStackTrace> definition_stack_trace_; 190 }; 191 192 } // namespace tensorflow 193 194 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ 195