/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/compiler/tf2xla/xla_resource.h"

#include <functional>
#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
26
27 namespace tensorflow {
28
KindToString(XlaResource::Kind kind)29 /*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) {
30 switch (kind) {
31 case XlaResource::kInvalid:
32 return "invalid";
33 case XlaResource::kVariable:
34 return "variable";
35 case XlaResource::kStack:
36 return "stack";
37 case XlaResource::kTensorArray:
38 return "tensorarray";
39 }
40 }
41
CreateStack(string name,DataType type,int64_t max_size)42 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
43 string name, DataType type, int64_t max_size) {
44 return absl::make_unique<XlaResource>(
45 XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(),
46 /*initial_value=*/xla::XlaOp(),
47 /*max_array_size=*/max_size,
48 /*tensor_array_gradients=*/std::set<string>{},
49 /*tensor_array_multiple_writes_aggregate=*/false);
50 }
51
CreateTensorArray(string name,DataType type,TensorShape shape,xla::XlaOp initial_value,int64_t max_array_size)52 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
53 string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
54 int64_t max_array_size) {
55 return absl::make_unique<XlaResource>(
56 XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
57 initial_value, max_array_size,
58 /*tensor_array_gradients=*/std::set<string>{},
59 /*tensor_array_multiple_writes_aggregate=*/false);
60 }
61
XlaResource(Kind kind,int arg_num,string name,DataType type,TensorShape shape,xla::XlaOp initial_value,int64_t max_array_size,const std::set<string> & tensor_array_gradients,bool tensor_array_multiple_writes_aggregate,const absl::optional<ManagedStackTrace> & definition_stack_trace)62 XlaResource::XlaResource(
63 Kind kind, int arg_num, string name, DataType type, TensorShape shape,
64 xla::XlaOp initial_value, int64_t max_array_size,
65 const std::set<string>& tensor_array_gradients,
66 bool tensor_array_multiple_writes_aggregate,
67 const absl::optional<ManagedStackTrace>& definition_stack_trace)
68 : kind_(kind),
69 arg_num_(arg_num),
70 name_(std::move(name)),
71 type_(type),
72 shape_(std::move(shape)),
73 value_(initial_value),
74 initial_value_(initial_value),
75 max_array_size_(max_array_size),
76 tensor_array_multiple_writes_aggregate_(
77 tensor_array_multiple_writes_aggregate),
78 definition_stack_trace_(definition_stack_trace) {
79 CHECK(kind_ != kInvalid);
80
81 for (const string& gradient : tensor_array_gradients) {
82 tensor_array_gradients_[gradient].reset(new XlaResource(
83 /*kind=*/kTensorArray, /*arg_num=*/-1,
84 /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
85 xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
86 /*tensor_array_multiple_writes_aggregate=*/true));
87 }
88 }
89
SetTypeAndShape(DataType type,const TensorShape & shape)90 Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
91 if (type == DT_INVALID) {
92 return errors::InvalidArgument(
93 "Attempted to set type of resource '", name_, "'' to an invalid type",
94 DefinitionLocationMsg(definition_stack_trace_));
95 }
96 if (initialized() && type_ != type) {
97 return errors::InvalidArgument(
98 "Type of resource ", name_,
99 " cannot be changed after initialization: "
100 "old type was ",
101 DataTypeString(type_), ", new type is ", DataTypeString(type),
102 DefinitionLocationMsg(definition_stack_trace_));
103 }
104 if (initialized() && shape_ != shape) {
105 return errors::InvalidArgument(
106 "Shape of resource ", name_,
107 " cannot be changed after initialization: "
108 "old shape was ",
109 shape_.DebugString(), ", new shape is ", shape.DebugString(),
110 DefinitionLocationMsg(definition_stack_trace_));
111 }
112 type_ = type;
113 shape_ = shape;
114 return Status::OK();
115 }
116
SetValue(const xla::XlaOp & value)117 Status XlaResource::SetValue(const xla::XlaOp& value) {
118 if (type_ == DT_INVALID) {
119 return errors::InvalidArgument(
120 "Resource '", name_,
121 "' must be initialized with a valid type before use.");
122 }
123 value_ = value;
124 is_overwritten_ = true;
125 return Status::OK();
126 }
127
SetZeroValue(xla::XlaBuilder * builder)128 Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
129 is_overwritten_ = true;
130 if (type_ == DT_INVALID) {
131 return errors::InvalidArgument(
132 "Resource '", name_,
133 "' must be initialized with a valid type before use.");
134 }
135 switch (kind_) {
136 case kVariable: {
137 value_ =
138 xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
139 break;
140 }
141 case kTensorArray: {
142 TensorShape ta_shape;
143 ta_shape.AddDim(max_array_size_);
144 ta_shape.AppendShape(shape_);
145 value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
146 ta_shape.dim_sizes());
147 break;
148 }
149 case kStack: {
150 TensorShape ta_shape;
151 ta_shape.AddDim(max_array_size_);
152 ta_shape.AppendShape(shape_);
153 value_ =
154 xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
155 ta_shape.dim_sizes()),
156 xla::ConstantR0<int32>(builder, 0)});
157 break;
158 }
159
160 case kInvalid:
161 default:
162 LOG(FATAL) << "Invalid resource type";
163 }
164 return Status::OK();
165 }
166
GetOrCreateTensorArrayGradient(const string & source,xla::XlaBuilder * builder,XlaResource ** gradient_out)167 Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
168 xla::XlaBuilder* builder,
169 XlaResource** gradient_out) {
170 VLOG(2) << "Gradient lookup for resource: " << name_
171 << " gradient: " << source;
172 TF_RET_CHECK(kind_ == kTensorArray);
173 std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
174 if (!gradient) {
175 TensorShape ta_shape;
176 ta_shape.AddDim(max_array_size_);
177 ta_shape.AppendShape(shape_);
178 xla::XlaOp gradient_value =
179 xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
180 gradient.reset(
181 new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
182 /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
183 type_, shape_, gradient_value, max_array_size_,
184 /*tensor_array_gradients=*/{},
185 /*tensor_array_multiple_writes_aggregate=*/true));
186 }
187 *gradient_out = gradient.get();
188 return Status::OK();
189 }
190
Pack(xla::XlaOp * pack,xla::XlaBuilder * builder) const191 Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
192 if (tensor_array_gradients_.empty()) {
193 *pack = value_;
194 } else {
195 TF_RET_CHECK(kind_ == kTensorArray);
196 std::vector<xla::XlaOp> elems;
197 elems.push_back(value_);
198 for (const auto& gradient : tensor_array_gradients_) {
199 elems.push_back(gradient.second->value_);
200 }
201 *pack = xla::Tuple(builder, elems);
202 }
203 return Status::OK();
204 }
205
SetFromPack(const std::set<string> & gradient_sources,const xla::XlaOp & pack,xla::XlaBuilder * builder)206 Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
207 const xla::XlaOp& pack,
208 xla::XlaBuilder* builder) {
209 if (gradient_sources.empty()) {
210 if (!initialized()) {
211 initial_value_ = pack;
212 }
213 value_ = pack;
214 } else {
215 TF_RET_CHECK(kind_ == kTensorArray);
216 int pos = 0;
217 auto v = xla::GetTupleElement(pack, pos++);
218 if (!initialized()) {
219 initial_value_ = v;
220 }
221 value_ = v;
222
223 for (const auto& source : gradient_sources) {
224 XlaResource* gradient;
225 TF_RETURN_IF_ERROR(
226 GetOrCreateTensorArrayGradient(source, builder, &gradient));
227 auto v = xla::GetTupleElement(pack, pos++);
228 if (!gradient->initialized()) {
229 gradient->initial_value_ = v;
230 }
231 gradient->value_ = v;
232 }
233 }
234 return Status::OK();
235 }
236
237 } // namespace tensorflow
238