/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_resource.h"

#include <functional>
#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {

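// Returns a human-readable name for `kind`, suitable for error messages.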
/*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) {
  switch (kind) {
    case XlaResource::kInvalid:
      return "invalid";
    case XlaResource::kVariable:
      return "variable";
    case XlaResource::kStack:
      return "stack";
    case XlaResource::kTensorArray:
      return "tensorarray";
  }
}

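// Creates an uninitialized stack resource holding at most `max_size` entries
// of `type`. Stacks are not function arguments, so arg_num is -1; the element
// shape is left empty here and set later via SetTypeAndShape.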
/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
    string name, DataType type, int64 max_size) {
  return absl::make_unique<XlaResource>(
      XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(),
      /*initial_value=*/xla::XlaOp(),
      /*max_array_size=*/max_size,
      /*tensor_array_gradients=*/std::set<string>{},
      /*tensor_array_multiple_writes_aggregate=*/false);
}

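// Creates a TensorArray resource of `type` with element shape `shape`,
// holding at most `max_array_size` entries and starting from `initial_value`.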
/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
    string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
    int64 max_array_size) {
  return absl::make_unique<XlaResource>(
      XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
      initial_value, max_array_size,
      /*tensor_array_gradients=*/std::set<string>{},
      /*tensor_array_multiple_writes_aggregate=*/false);
}

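// The constructor eagerly creates an uninitialized gradient TensorArray for
// each source named in `tensor_array_gradients`; gradient arrays always
// aggregate multiple writes.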
XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
                         TensorShape shape, const xla::XlaOp& initial_value,
                         int64 max_array_size,
                         const std::set<string>& tensor_array_gradients,
                         bool tensor_array_multiple_writes_aggregate)
    : kind_(kind),
      arg_num_(arg_num),
      name_(std::move(name)),
      type_(type),
      shape_(std::move(shape)),
      value_(initial_value),
      initial_value_(initial_value),
      max_array_size_(max_array_size),
      tensor_array_multiple_writes_aggregate_(
          tensor_array_multiple_writes_aggregate) {
  CHECK(kind_ != kInvalid);

  for (const string& gradient : tensor_array_gradients) {
    tensor_array_gradients_[gradient].reset(new XlaResource(
        /*kind=*/kTensorArray, /*arg_num=*/-1,
        /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
        xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
        /*tensor_array_multiple_writes_aggregate=*/true));
  }
}

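// Sets the type and shape of the resource. Once the resource has been
// initialized, neither its type nor its shape may change.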
Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
  if (type == DT_INVALID) {
    return errors::InvalidArgument("Attempted to set type of resource '",
                                   name_, "' to an invalid type");
  }
  if (initialized() && type_ != type) {
    return errors::InvalidArgument("Type of resource ", name_,
                                   " cannot be changed after initialization: "
                                   "old type was ",
                                   DataTypeString(type_), ", new type is ",
                                   DataTypeString(type));
  }
  if (initialized() && shape_ != shape) {
    return errors::InvalidArgument("Shape of resource ", name_,
                                   " cannot be changed after initialization: "
                                   "old shape was ",
                                   shape_.DebugString(), ", new shape is ",
                                   shape.DebugString());
  }
  type_ = type;
  shape_ = shape;
  return Status::OK();
}

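// Replaces the resource's current value. The type must already have been set
// to something valid (e.g. via SetTypeAndShape).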
Status XlaResource::SetValue(const xla::XlaOp& value) {
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  value_ = value;
  return Status::OK();
}

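// Initializes the resource with an all-zero value of the appropriate shape:
// zeros of `shape_` for a variable, a zero-filled [max_array_size, shape_]
// buffer for a TensorArray, and for a stack that same buffer tupled with a
// zero int32 stack position.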
Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  switch (kind_) {
    case kVariable: {
      value_ =
          xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
      break;
    }
    case kTensorArray: {
      TensorShape ta_shape;
      ta_shape.AddDim(max_array_size_);
      ta_shape.AppendShape(shape_);
      value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
                              ta_shape.dim_sizes());
      break;
    }
    case kStack: {
      TensorShape ta_shape;
      ta_shape.AddDim(max_array_size_);
      ta_shape.AppendShape(shape_);
      value_ =
          xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
                                              ta_shape.dim_sizes()),
                               xla::ConstantR0<int32>(builder, 0)});
      break;
    }

    case kInvalid:
    default:
      LOG(FATAL) << "Invalid resource type";
  }
  return Status::OK();
}

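// Returns the gradient TensorArray for `source`, creating a zero-valued one
// on first lookup. Gradient arrays share the forward array's type, element
// shape, and capacity, and aggregate multiple writes.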
Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
                                                   xla::XlaBuilder* builder,
                                                   XlaResource** gradient_out) {
  VLOG(2) << "Gradient lookup for resource: " << name_
          << " gradient: " << source;
  TF_RET_CHECK(kind_ == kTensorArray);
  std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
  if (!gradient) {
    TensorShape ta_shape;
    ta_shape.AddDim(max_array_size_);
    ta_shape.AppendShape(shape_);
    xla::XlaOp gradient_value =
        xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
    gradient.reset(
        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                        /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
                        type_, shape_, gradient_value, max_array_size_,
                        /*tensor_array_gradients=*/{},
                        /*tensor_array_multiple_writes_aggregate=*/true));
  }
  *gradient_out = gradient.get();
  return Status::OK();
}

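// Flattens the resource into a single XlaOp: the value itself if there are no
// registered gradients, otherwise a tuple of the value followed by each
// gradient's value in gradient-source order.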
Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
  if (tensor_array_gradients_.empty()) {
    *pack = value_;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    std::vector<xla::XlaOp> elems;
    elems.push_back(value_);
    for (const auto& gradient : tensor_array_gradients_) {
      elems.push_back(gradient.second->value_);
    }
    *pack = xla::Tuple(builder, elems);
  }
  return Status::OK();
}

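// Inverse of Pack: restores the value (and each gradient's value, iterating
// `gradient_sources` in order) from an op produced by Pack. An op unpacked
// into a not-yet-initialized resource also becomes its initial value.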
Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
                                const xla::XlaOp& pack,
                                xla::XlaBuilder* builder) {
  if (gradient_sources.empty()) {
    if (!initialized()) {
      initial_value_ = pack;
    }
    value_ = pack;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    int pos = 0;
    auto v = xla::GetTupleElement(pack, pos++);
    if (!initialized()) {
      initial_value_ = v;
    }
    value_ = v;

    for (const auto& source : gradient_sources) {
      XlaResource* gradient;
      TF_RETURN_IF_ERROR(
          GetOrCreateTensorArrayGradient(source, builder, &gradient));
      auto v = xla::GetTupleElement(pack, pos++);
      if (!gradient->initialized()) {
        gradient->initial_value_ = v;
      }
      gradient->value_ = v;
    }
  }
  return Status::OK();
}

}  // namespace tensorflow