1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ 18 19 #include <memory> 20 21 #include "absl/strings/string_view.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 #include "tensorflow/compiler/xla/xla_data.pb.h" 24 #include "tensorflow/core/framework/tensor_shape.h" 25 #include "tensorflow/core/framework/types.pb.h" 26 #include "tensorflow/core/lib/core/status.h" 27 28 namespace tensorflow { 29 30 // Represents a resource, such as a Variable or TensorArray. 31 class XlaResource { 32 public: 33 enum Kind { 34 kInvalid, 35 kVariable, 36 kTensorArray, 37 kStack, 38 }; 39 static absl::string_view KindToString(Kind kind); 40 41 // Creates a new Stack resource. 42 static std::unique_ptr<XlaResource> CreateStack(string name, DataType type, 43 int64 max_size); 44 45 // Creates a new TensorArray resource. 46 static std::unique_ptr<XlaResource> CreateTensorArray( 47 string name, DataType type, TensorShape shape, xla::XlaOp initial_value, 48 int64 max_array_size); 49 50 XlaResource(Kind kind, int arg_num, string name, DataType type, 51 TensorShape shape, const xla::XlaOp& initial_value, 52 int64 max_array_size, 53 const std::set<string>& tensor_array_gradients, 54 bool tensor_array_multiple_writes_aggregate); 55 56 XlaResource(const XlaResource&) = delete; 57 XlaResource(XlaResource&&) = delete; 58 XlaResource& operator=(const XlaResource&) = delete; 59 XlaResource& operator=(XlaResource&&) = delete; 60 kind()61 Kind kind() const { return kind_; } 62 63 // If this resource is visible externally to the computation, what was its 64 // argument number? 65 // < 0 means "not visible externally". arg_num()66 int arg_num() const { return arg_num_; } 67 68 // A descriptive name for the resource, used in error messages. name()69 const string& name() const { return name_; } 70 71 // Current type and value of the resource. Uninitialized resources are 72 // represented by a default (zero) handle and type DT_INVALID. 73 // While the type of a resource is notionally fixed during execution, when 74 // a resource is first initialized we do not yet know its type, so we keep 75 // track of its type dynamically. type()76 DataType type() const { return type_; } 77 78 // Shape of the resource. For an uninitialized resource, this is ignored. 79 // For a Variable, this is the shape of the value. For a TensorArray or Stack 80 // this is the shape of each entry in the TensorArray/Stack. shape()81 const TensorShape& shape() const { return shape_; } 82 value()83 const xla::XlaOp& value() const { return value_; } 84 85 // Value of the resource at computation entry. Used to detect which 86 // variables have new values that need to be written back. initial_value()87 const xla::XlaOp& initial_value() const { return initial_value_; } 88 89 // An xla shape that indicates how this resource variable is represented on 90 // device. representation_shape()91 const absl::optional<xla::Shape>& representation_shape() const { 92 return representation_shape_; 93 } 94 95 // A variable is initialized if it has a value. initialized()96 bool initialized() const { return value_.valid(); } 97 98 // Sets the type and shape of the resource. The type and shape of a resource 99 // must not change once the variable has been initialized. 100 Status SetTypeAndShape(DataType type, const TensorShape& shape); 101 102 // Sets the current value of the resource. Returns an error if the type is not 103 // set to a valid value. 104 Status SetValue(const xla::XlaOp& value); 105 106 // Sets the current value of the resource to an all-zero value. 107 Status SetZeroValue(xla::XlaBuilder* builder); 108 109 // Sets the representational shape of the resource on device. SetRepresentationShape(const xla::Shape & shape)110 void SetRepresentationShape(const xla::Shape& shape) { 111 representation_shape_ = absl::make_optional(shape); 112 } 113 114 // Looks up the gradient for `source`, or creates it if it does not already 115 // exist. The call target must be an initialized TensorArray resource. A 116 // TensorArray can have multiple named gradients; see the operator 117 // documentation for TensorArrayGradV3 for details. 118 Status GetOrCreateTensorArrayGradient(const string& source, 119 xla::XlaBuilder* builder, 120 XlaResource** gradient_out); 121 122 // Packs a resource into a single XLA value `pack`, suitable for use as 123 // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without 124 // gradients, sets `*pack` to `value`. 125 // For TensorArrays with gradients, packs the value and its gradient values in 126 // a tuple; the gradients values are packed in order by source name. 127 Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const; 128 129 // Updates the resource with values from `pack`. If `gradient_sources` is 130 // non-empty, treats `pack` as a tuple that represents a TensorArray and 131 // its gradients, and unpacks and updates the gradient resources. 132 // If `reset_initial_values` is true, sets the initial_values as well as the 133 // values. 134 // Opposite of Pack(). 135 Status SetFromPack(const std::set<string>& gradient_sources, 136 const xla::XlaOp& pack, xla::XlaBuilder* builder); 137 IsOverwritten()138 bool IsOverwritten() { return is_overwritten_; } 139 140 // TensorArray and Stack specific fields 141 // TODO(phawkins): refactor this code to use subclasses, rather than putting 142 // kind-specific fields in XlaResource. 143 144 // 'max_array_size' stores the expected size of the TensorArray or Stack. 145 // We need to store this since sometimes TensorArrays must be initialized 146 // lazily since we do not know the element shape at construction time. 147 // Used by both TensorArrays and Stacks. max_array_size()148 int64 max_array_size() const { return max_array_size_; } set_max_array_size(int64 size)149 void set_max_array_size(int64 size) { max_array_size_ = size; } 150 tensor_array_multiple_writes_aggregate()151 bool tensor_array_multiple_writes_aggregate() const { 152 return tensor_array_multiple_writes_aggregate_; 153 } 154 155 // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes 156 // to an XlaResource containing the gradient TensorArrays. We store a pointer 157 // here since there should only be one gradient TensorArray per 'source' 158 // string, irrespective of the number of calls to TensorArrayGrad. The map 159 // is ordered since values are packed into tuples by Pack() sorted by name 160 // order. tensor_array_gradients()161 const std::map<string, std::unique_ptr<XlaResource>>& tensor_array_gradients() 162 const { 163 return tensor_array_gradients_; 164 } 165 166 private: 167 const Kind kind_; 168 const int arg_num_; 169 const string name_; 170 171 DataType type_; 172 TensorShape shape_; 173 xla::XlaOp value_; 174 xla::XlaOp initial_value_; 175 176 // An xla shape that indicates how this resource variable is represented on 177 // device. 178 absl::optional<xla::Shape> representation_shape_; 179 180 int64 max_array_size_ = -1; 181 bool tensor_array_multiple_writes_aggregate_ = false; 182 183 std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_; 184 bool is_overwritten_ = false; 185 }; 186 187 } // namespace tensorflow 188 189 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ 190