• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
18 
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/managed_stack_trace.h"
28 
29 namespace tensorflow {
30 
31 // Represents a resource, such as a Variable or TensorArray.
32 class XlaResource {
33  public:
34   enum Kind {
35     kInvalid,
36     kVariable,
37     kTensorArray,
38     kStack,
39   };
40   static absl::string_view KindToString(Kind kind);
41 
42   // Creates a new Stack resource.
43   static std::unique_ptr<XlaResource> CreateStack(string name, DataType type,
44                                                   int64_t max_size);
45 
46   // Creates a new TensorArray resource.
47   static std::unique_ptr<XlaResource> CreateTensorArray(
48       string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
49       int64_t max_array_size);
50 
51   XlaResource(Kind kind, int arg_num, string name, DataType type,
52               TensorShape shape, xla::XlaOp initial_value,
53               int64_t max_array_size,
54               const std::set<string>& tensor_array_gradients,
55               bool tensor_array_multiple_writes_aggregate,
56               const std::optional<ManagedStackTrace>& definition_stack_trace =
57                   std::nullopt);
58 
59   XlaResource(const XlaResource&) = delete;
60   XlaResource(XlaResource&&) = delete;
61   XlaResource& operator=(const XlaResource&) = delete;
62   XlaResource& operator=(XlaResource&&) = delete;
63 
kind()64   Kind kind() const { return kind_; }
65 
66   // If this resource is visible externally to the computation, what was its
67   // argument number?
68   // < 0 means "not visible externally".
arg_num()69   int arg_num() const { return arg_num_; }
70 
71   // A descriptive name for the resource, used in error messages.
name()72   const string& name() const { return name_; }
73 
74   // Current type and value of the resource. Uninitialized resources are
75   // represented by a default (zero) handle and type DT_INVALID.
76   // While the type of a resource is notionally fixed during execution, when
77   // a resource is first initialized we do not yet know its type, so we keep
78   // track of its type dynamically.
type()79   DataType type() const { return type_; }
80 
81   // Shape of the resource. For an uninitialized resource, this is ignored.
82   // For a Variable, this is the shape of the value. For a TensorArray or Stack
83   // this is the shape of each entry in the TensorArray/Stack.
shape()84   const TensorShape& shape() const { return shape_; }
85 
value()86   const xla::XlaOp& value() const { return value_; }
87 
88   // Value of the resource at computation entry. Used to detect which
89   // variables have new values that need to be written back.
initial_value()90   const xla::XlaOp& initial_value() const { return initial_value_; }
91 
92   // An xla shape that indicates how this resource variable is represented on
93   // device.
representation_shape()94   const std::optional<xla::Shape>& representation_shape() const {
95     return representation_shape_;
96   }
97 
98   // A variable is initialized if it has a value.
initialized()99   bool initialized() const { return value_.valid(); }
100 
101   // Sets the type and shape of the resource. The type and shape of a resource
102   // must not change once the variable has been initialized.
103   Status SetTypeAndShape(DataType type, const TensorShape& shape);
104 
105   // Sets the current value of the resource. Returns an error if the type is not
106   // set to a valid value.
107   Status SetValue(const xla::XlaOp& value);
108 
109   // Sets the current value of the resource to an all-zero value.
110   Status SetZeroValue(xla::XlaBuilder* builder);
111 
112   // Sets the representational shape of the resource on device.
SetRepresentationShape(const xla::Shape & shape)113   void SetRepresentationShape(const xla::Shape& shape) {
114     representation_shape_ = absl::make_optional(shape);
115   }
116 
117   // Looks up the gradient for `source`, or creates it if it does not already
118   // exist. The call target must be an initialized TensorArray resource. A
119   // TensorArray can have multiple named gradients; see the operator
120   // documentation for TensorArrayGradV3 for details.
121   Status GetOrCreateTensorArrayGradient(const string& source,
122                                         xla::XlaBuilder* builder,
123                                         XlaResource** gradient_out);
124 
125   // Packs a resource into a single XLA value `pack`, suitable for use as
126   // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without
127   // gradients, sets `*pack` to `value`.
128   // For TensorArrays with gradients, packs the value and its gradient values in
129   // a tuple; the gradients values are packed in order by source name.
130   Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const;
131 
132   // Updates the resource with values from `pack`. If `gradient_sources` is
133   // non-empty, treats `pack` as a tuple that represents a TensorArray and
134   // its gradients, and unpacks and updates the gradient resources.
135   // If `reset_initial_values` is true, sets the initial_values as well as the
136   // values.
137   // Opposite of Pack().
138   Status SetFromPack(const std::set<string>& gradient_sources,
139                      const xla::XlaOp& pack, xla::XlaBuilder* builder);
140 
IsOverwritten()141   bool IsOverwritten() { return is_overwritten_; }
142 
143   // TensorArray and Stack specific fields
144   // TODO(phawkins): refactor this code to use subclasses, rather than putting
145   // kind-specific fields in XlaResource.
146 
147   // 'max_array_size' stores the expected size of the TensorArray or Stack.
148   // We need to store this since sometimes TensorArrays must be initialized
149   // lazily since we do not know the element shape at construction time.
150   // Used by both TensorArrays and Stacks.
max_array_size()151   int64_t max_array_size() const { return max_array_size_; }
set_max_array_size(int64_t size)152   void set_max_array_size(int64_t size) { max_array_size_ = size; }
153 
tensor_array_multiple_writes_aggregate()154   bool tensor_array_multiple_writes_aggregate() const {
155     return tensor_array_multiple_writes_aggregate_;
156   }
157 
158   // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
159   // to an XlaResource containing the gradient TensorArrays. We store a pointer
160   // here since there should only be one gradient TensorArray per 'source'
161   // string, irrespective of the number of calls to TensorArrayGrad. The map
162   // is ordered since values are packed into tuples by Pack() sorted by name
163   // order.
tensor_array_gradients()164   const std::map<string, std::unique_ptr<XlaResource>>& tensor_array_gradients()
165       const {
166     return tensor_array_gradients_;
167   }
168 
169  private:
170   const Kind kind_;
171   const int arg_num_;
172   const string name_;
173 
174   DataType type_;
175   TensorShape shape_;
176   xla::XlaOp value_;
177   xla::XlaOp initial_value_;
178 
179   // An xla shape that indicates how this resource variable is represented on
180   // device.
181   std::optional<xla::Shape> representation_shape_;
182 
183   int64_t max_array_size_ = -1;
184   bool tensor_array_multiple_writes_aggregate_ = false;
185 
186   std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
187   bool is_overwritten_ = false;
188 
189   std::optional<ManagedStackTrace> definition_stack_trace_;
190 };
191 
192 }  // namespace tensorflow
193 
194 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
195