/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_resource.h"

#include <functional>
#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {

/*static*/ absl::string_view XlaResource::KindToString(
    XlaResource::Kind kind) {
  switch (kind) {
    case XlaResource::kInvalid:
      return "invalid";
    case XlaResource::kVariable:
      return "variable";
    case XlaResource::kStack:
      return "stack";
    case XlaResource::kTensorArray:
      return "tensorarray";
  }
}

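// Creates an uninitialized Stack resource. A stack is not a function argument
// (arg_num is -1) and starts with no value; SetZeroValue below materializes
// its storage as a (zero-filled buffer, position) tuple.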
/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
    string name, DataType type, int64 max_size) {
  return absl::make_unique<XlaResource>(
      XlaResource::kStack, /*arg_num=*/-1, std::move(name), type,
      TensorShape(),
      /*initial_value=*/xla::XlaOp(),
      /*max_array_size=*/max_size,
      /*tensor_array_gradients=*/std::set<string>{},
      /*tensor_array_multiple_writes_aggregate=*/false);
}

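// Creates a TensorArray resource holding `max_array_size` elements of shape
// `shape` and element type `type`, seeded with `initial_value`.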
/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
    string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
    int64 max_array_size) {
  return absl::make_unique<XlaResource>(
      XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
      initial_value, max_array_size,
      /*tensor_array_gradients=*/std::set<string>{},
      /*tensor_array_multiple_writes_aggregate=*/false);
}

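// The constructor eagerly creates a zero-valued gradient accumulator for each
// source named in `tensor_array_gradients`. Gradient accumulators always
// aggregate multiple writes, so their
// tensor_array_multiple_writes_aggregate flag is forced to true.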
XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
                         TensorShape shape, const xla::XlaOp& initial_value,
                         int64 max_array_size,
                         const std::set<string>& tensor_array_gradients,
                         bool tensor_array_multiple_writes_aggregate)
    : kind_(kind),
      arg_num_(arg_num),
      name_(std::move(name)),
      type_(type),
      shape_(std::move(shape)),
      value_(initial_value),
      initial_value_(initial_value),
      max_array_size_(max_array_size),
      tensor_array_multiple_writes_aggregate_(
          tensor_array_multiple_writes_aggregate) {
  CHECK(kind_ != kInvalid);

  for (const string& gradient : tensor_array_gradients) {
    tensor_array_gradients_[gradient].reset(new XlaResource(
        /*kind=*/kTensorArray, /*arg_num=*/-1,
        /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
        xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
        /*tensor_array_multiple_writes_aggregate=*/true));
  }
}

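// Type and shape are write-once: they may be set (or re-set to the same
// values) before the resource is initialized, but any change after
// initialization is an error.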
Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
  if (type == DT_INVALID) {
    return errors::InvalidArgument("Attempted to set type of resource '",
                                   name_, "' to an invalid type");
  }
  if (initialized() && type_ != type) {
    return errors::Unimplemented("Type of resource ", name_,
                                 " cannot be changed after initialization: "
                                 "old type was ",
                                 DataTypeString(type_), ", new type is ",
                                 DataTypeString(type));
  }
  if (initialized() && shape_ != shape) {
    return errors::Unimplemented("Shape of resource ", name_,
                                 " cannot be changed after initialization: "
                                 "old shape was ",
                                 shape_.DebugString(), ", new shape is ",
                                 shape.DebugString());
  }
  type_ = type;
  shape_ = shape;
  return Status::OK();
}

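// Overwrites the resource's current value; the type must already have been
// established (e.g. via SetTypeAndShape) before the first write.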
Status XlaResource::SetValue(const xla::XlaOp& value) {
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  value_ = value;
  is_overwritten_ = true;
  return Status::OK();
}

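// Resets the resource to an all-zeros value with the layout appropriate to
// its kind: a broadcast zero of `shape_` for variables, a
// [max_array_size, shape...] buffer for tensor arrays, and a
// (buffer, position) tuple for stacks.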
Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
  is_overwritten_ = true;
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  switch (kind_) {
    case kVariable: {
      value_ =
          xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
      break;
    }
    case kTensorArray: {
      TensorShape ta_shape;
      ta_shape.AddDim(max_array_size_);
      ta_shape.AppendShape(shape_);
      value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
                              ta_shape.dim_sizes());
      break;
    }
    case kStack: {
      TensorShape ta_shape;
      ta_shape.AddDim(max_array_size_);
      ta_shape.AppendShape(shape_);
      value_ =
          xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
                                              ta_shape.dim_sizes()),
                               xla::ConstantR0<int32>(builder, 0)});
      break;
    }

    case kInvalid:
    default:
      LOG(FATAL) << "Invalid resource type";
  }
  return Status::OK();
}

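// Looks up (or lazily creates) the gradient accumulator for `source`. A newly
// created gradient starts as a zero-filled [max_array_size, shape...] buffer.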
Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
                                                   xla::XlaBuilder* builder,
                                                   XlaResource** gradient_out) {
  VLOG(2) << "Gradient lookup for resource: " << name_
          << " gradient: " << source;
  TF_RET_CHECK(kind_ == kTensorArray);
  std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
  if (!gradient) {
    TensorShape ta_shape;
    ta_shape.AddDim(max_array_size_);
    ta_shape.AppendShape(shape_);
    xla::XlaOp gradient_value =
        xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
    gradient.reset(
        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                        /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
                        type_, shape_, gradient_value, max_array_size_,
                        /*tensor_array_gradients=*/{},
                        /*tensor_array_multiple_writes_aggregate=*/true));
  }
  *gradient_out = gradient.get();
  return Status::OK();
}

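// Flattens the resource into a single XlaOp: the value itself when there are
// no gradients, otherwise a tuple of (value, gradient values...) in map
// iteration order, i.e. sorted by gradient source name, which matches the
// order SetFromPack consumes below.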
Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
  if (tensor_array_gradients_.empty()) {
    *pack = value_;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    std::vector<xla::XlaOp> elems;
    elems.push_back(value_);
    for (const auto& gradient : tensor_array_gradients_) {
      elems.push_back(gradient.second->value_);
    }
    *pack = xla::Tuple(builder, elems);
  }
  return Status::OK();
}

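// Inverse of Pack: unpacks `pack` into the value and the gradients named in
// `gradient_sources`, creating missing gradient resources on the fly. Sources
// are visited in std::set order, matching the tuple layout produced by Pack.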
Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
                                const xla::XlaOp& pack,
                                xla::XlaBuilder* builder) {
  if (gradient_sources.empty()) {
    if (!initialized()) {
      initial_value_ = pack;
    }
    value_ = pack;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    int pos = 0;
    auto v = xla::GetTupleElement(pack, pos++);
    if (!initialized()) {
      initial_value_ = v;
    }
    value_ = v;

    for (const auto& source : gradient_sources) {
      XlaResource* gradient;
      TF_RETURN_IF_ERROR(
          GetOrCreateTensorArrayGradient(source, builder, &gradient));
      auto v = xla::GetTupleElement(pack, pos++);
      if (!gradient->initialized()) {
        gradient->initial_value_ = v;
      }
      gradient->value_ = v;
    }
  }
  return Status::OK();
}

}  // namespace tensorflow