/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_resource.h"
17
18 #include <functional>
19 #include <memory>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/sharding_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26
27 namespace tensorflow {
28
KindToString(XlaResource::Kind kind)29 /*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) {
30 switch (kind) {
31 case XlaResource::kInvalid:
32 return "invalid";
33 case XlaResource::kVariable:
34 return "variable";
35 case XlaResource::kStack:
36 return "stack";
37 case XlaResource::kTensorArray:
38 return "tensorarray";
39 }
40 }
41
CreateStack(string name,DataType type,int64_t max_size)42 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
43 string name, DataType type, int64_t max_size) {
44 return std::make_unique<XlaResource>(
45 XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(),
46 /*initial_value=*/xla::XlaOp(),
47 /*max_array_size=*/max_size,
48 /*tensor_array_gradients=*/std::set<string>{},
49 /*tensor_array_multiple_writes_aggregate=*/false);
50 }
51
CreateTensorArray(string name,DataType type,TensorShape shape,xla::XlaOp initial_value,int64_t max_array_size)52 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
53 string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
54 int64_t max_array_size) {
55 return std::make_unique<XlaResource>(
56 XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
57 initial_value, max_array_size,
58 /*tensor_array_gradients=*/std::set<string>{},
59 /*tensor_array_multiple_writes_aggregate=*/false);
60 }
61
XlaResource(Kind kind,int arg_num,string name,DataType type,TensorShape shape,xla::XlaOp initial_value,int64_t max_array_size,const std::set<string> & tensor_array_gradients,bool tensor_array_multiple_writes_aggregate,const std::optional<ManagedStackTrace> & definition_stack_trace)62 XlaResource::XlaResource(
63 Kind kind, int arg_num, string name, DataType type, TensorShape shape,
64 xla::XlaOp initial_value, int64_t max_array_size,
65 const std::set<string>& tensor_array_gradients,
66 bool tensor_array_multiple_writes_aggregate,
67 const std::optional<ManagedStackTrace>& definition_stack_trace)
68 : kind_(kind),
69 arg_num_(arg_num),
70 name_(std::move(name)),
71 type_(type),
72 shape_(std::move(shape)),
73 value_(initial_value),
74 initial_value_(initial_value),
75 max_array_size_(max_array_size),
76 tensor_array_multiple_writes_aggregate_(
77 tensor_array_multiple_writes_aggregate),
78 definition_stack_trace_(definition_stack_trace) {
79 CHECK(kind_ != kInvalid);
80
81 for (const string& gradient : tensor_array_gradients) {
82 tensor_array_gradients_[gradient].reset(new XlaResource(
83 /*kind=*/kTensorArray, /*arg_num=*/-1,
84 /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
85 xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
86 /*tensor_array_multiple_writes_aggregate=*/true));
87 }
88 }
89
SetTypeAndShape(DataType type,const TensorShape & shape)90 Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
91 if (type == DT_INVALID) {
92 return errors::InvalidArgument(
93 "Attempted to set type of resource '", name_, "'' to an invalid type",
94 DefinitionLocationMsg(definition_stack_trace_));
95 }
96 if (initialized() && type_ != type) {
97 return errors::InvalidArgument(
98 "Trying to assign variable with wrong dtype. Expected ",
99 DataTypeString(type_), " got ", DataTypeString(type),
100 DefinitionLocationMsg(definition_stack_trace_));
101 }
102 if (initialized() && shape_ != shape) {
103 return errors::InvalidArgument(
104 "Shape of resource ", name_,
105 " cannot be changed after initialization: "
106 "old shape was ",
107 shape_.DebugString(), ", new shape is ", shape.DebugString(),
108 DefinitionLocationMsg(definition_stack_trace_));
109 }
110 type_ = type;
111 shape_ = shape;
112 return OkStatus();
113 }
114
SetValue(const xla::XlaOp & value)115 Status XlaResource::SetValue(const xla::XlaOp& value) {
116 if (type_ == DT_INVALID) {
117 return errors::InvalidArgument(
118 "Resource '", name_,
119 "' must be initialized with a valid type before use.");
120 }
121 value_ = value;
122 is_overwritten_ = true;
123 return OkStatus();
124 }
125
SetZeroValue(xla::XlaBuilder * builder)126 Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
127 is_overwritten_ = true;
128 if (type_ == DT_INVALID) {
129 return errors::InvalidArgument(
130 "Resource '", name_,
131 "' must be initialized with a valid type before use.");
132 }
133 switch (kind_) {
134 case kVariable: {
135 value_ =
136 xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
137 break;
138 }
139 case kTensorArray: {
140 TensorShape ta_shape;
141 ta_shape.AddDim(max_array_size_);
142 ta_shape.AppendShape(shape_);
143 value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
144 ta_shape.dim_sizes());
145 break;
146 }
147 case kStack: {
148 TensorShape ta_shape;
149 ta_shape.AddDim(max_array_size_);
150 ta_shape.AppendShape(shape_);
151 value_ =
152 xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
153 ta_shape.dim_sizes()),
154 xla::ConstantR0<int32>(builder, 0)});
155 break;
156 }
157
158 case kInvalid:
159 default:
160 LOG(FATAL) << "Invalid resource type";
161 }
162 return OkStatus();
163 }
164
GetOrCreateTensorArrayGradient(const string & source,xla::XlaBuilder * builder,XlaResource ** gradient_out)165 Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
166 xla::XlaBuilder* builder,
167 XlaResource** gradient_out) {
168 VLOG(2) << "Gradient lookup for resource: " << name_
169 << " gradient: " << source;
170 TF_RET_CHECK(kind_ == kTensorArray);
171 std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
172 if (!gradient) {
173 TensorShape ta_shape;
174 ta_shape.AddDim(max_array_size_);
175 ta_shape.AppendShape(shape_);
176 xla::XlaOp gradient_value =
177 xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
178 gradient.reset(
179 new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
180 /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
181 type_, shape_, gradient_value, max_array_size_,
182 /*tensor_array_gradients=*/{},
183 /*tensor_array_multiple_writes_aggregate=*/true));
184 }
185 *gradient_out = gradient.get();
186 return OkStatus();
187 }
188
Pack(xla::XlaOp * pack,xla::XlaBuilder * builder) const189 Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
190 if (tensor_array_gradients_.empty()) {
191 *pack = value_;
192 } else {
193 TF_RET_CHECK(kind_ == kTensorArray);
194 std::vector<xla::XlaOp> elems;
195 elems.push_back(value_);
196 for (const auto& gradient : tensor_array_gradients_) {
197 elems.push_back(gradient.second->value_);
198 }
199 *pack = xla::Tuple(builder, elems);
200 }
201 return OkStatus();
202 }
203
SetFromPack(const std::set<string> & gradient_sources,const xla::XlaOp & pack,xla::XlaBuilder * builder)204 Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
205 const xla::XlaOp& pack,
206 xla::XlaBuilder* builder) {
207 if (gradient_sources.empty()) {
208 if (!initialized()) {
209 initial_value_ = pack;
210 }
211 value_ = pack;
212 } else {
213 TF_RET_CHECK(kind_ == kTensorArray);
214 int pos = 0;
215 auto v = xla::GetTupleElement(pack, pos++);
216 if (!initialized()) {
217 initial_value_ = v;
218 }
219 value_ = v;
220
221 for (const auto& source : gradient_sources) {
222 XlaResource* gradient;
223 TF_RETURN_IF_ERROR(
224 GetOrCreateTensorArrayGradient(source, builder, &gradient));
225 auto v = xla::GetTupleElement(pack, pos++);
226 if (!gradient->initialized()) {
227 gradient->initial_value_ = v;
228 }
229 gradient->value_ = v;
230 }
231 }
232 return OkStatus();
233 }
234
235 } // namespace tensorflow
236