/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_resource.h"

#include <functional>
#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {

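// Returns a human-readable name for `kind`, used in error messages and debug
// output.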
/*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) {
  switch (kind) {
    case XlaResource::kInvalid:
      return "invalid";
    case XlaResource::kVariable:
      return "variable";
    case XlaResource::kStack:
      return "stack";
    case XlaResource::kTensorArray:
      return "tensorarray";
  }
}

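// Creates an uninitialized stack resource named `name` that holds values of
// `type` and at most `max_size` entries. The element shape starts empty and
// is filled in later via SetTypeAndShape.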
/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
    string name, DataType type, int64_t max_size) {
  return std::make_unique<XlaResource>(
      XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(),
      /*initial_value=*/xla::XlaOp(),
      /*max_array_size=*/max_size,
      /*tensor_array_gradients=*/std::set<string>{},
      /*tensor_array_multiple_writes_aggregate=*/false);
}

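// Creates a TensorArray resource named `name` with element type `type`,
// element shape `shape`, capacity `max_array_size`, and the given initial
// value (which may be an uninitialized XlaOp).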
/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
    string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
    int64_t max_array_size) {
  return std::make_unique<XlaResource>(
      XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
      initial_value, max_array_size,
      /*tensor_array_gradients=*/std::set<string>{},
      /*tensor_array_multiple_writes_aggregate=*/false);
}

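// Constructs a resource of the given kind; `kind` must not be kInvalid. For
// each source in `tensor_array_gradients`, eagerly creates a corresponding
// gradient TensorArray resource that aggregates multiple writes.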
XlaResource::XlaResource(
    Kind kind, int arg_num, string name, DataType type, TensorShape shape,
    xla::XlaOp initial_value, int64_t max_array_size,
    const std::set<string>& tensor_array_gradients,
    bool tensor_array_multiple_writes_aggregate,
    const std::optional<ManagedStackTrace>& definition_stack_trace)
    : kind_(kind),
      arg_num_(arg_num),
      name_(std::move(name)),
      type_(type),
      shape_(std::move(shape)),
      value_(initial_value),
      initial_value_(initial_value),
      max_array_size_(max_array_size),
      tensor_array_multiple_writes_aggregate_(
          tensor_array_multiple_writes_aggregate),
      definition_stack_trace_(definition_stack_trace) {
  CHECK(kind_ != kInvalid);

  for (const string& gradient : tensor_array_gradients) {
    tensor_array_gradients_[gradient].reset(new XlaResource(
        /*kind=*/kTensorArray, /*arg_num=*/-1,
        /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
        xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
        /*tensor_array_multiple_writes_aggregate=*/true));
  }
}

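// Sets the resource's type and shape. Returns an error if `type` is invalid,
// or if the resource is already initialized with a different type or shape:
// neither may change after initialization.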
Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
  if (type == DT_INVALID) {
    return errors::InvalidArgument(
        "Attempted to set type of resource '", name_, "' to an invalid type",
        DefinitionLocationMsg(definition_stack_trace_));
  }
  if (initialized() && type_ != type) {
    return errors::InvalidArgument(
        "Trying to assign variable with wrong dtype. Expected ",
        DataTypeString(type_), ", got ", DataTypeString(type),
        DefinitionLocationMsg(definition_stack_trace_));
  }
  if (initialized() && shape_ != shape) {
    return errors::InvalidArgument(
        "Shape of resource ", name_,
        " cannot be changed after initialization: "
        "old shape was ",
        shape_.DebugString(), ", new shape is ", shape.DebugString(),
        DefinitionLocationMsg(definition_stack_trace_));
  }
  type_ = type;
  shape_ = shape;
  return OkStatus();
}

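// Overwrites the resource's current value and marks it as overwritten. The
// type must have been set (via the constructor or SetTypeAndShape) before a
// value can be assigned.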
Status XlaResource::SetValue(const xla::XlaOp& value) {
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  value_ = value;
  is_overwritten_ = true;
  return OkStatus();
}

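// Resets the resource to an all-zero value of the appropriate shape: a
// broadcast zero of `shape_` for variables, a [max_array_size, shape...]
// zero tensor for TensorArrays, and a (zero tensor, zero index) tuple for
// stacks. Also marks the resource as overwritten.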
Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
  is_overwritten_ = true;
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  switch (kind_) {
    case kVariable: {
      value_ =
          xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
      break;
    }
    case kTensorArray: {
      TensorShape ta_shape;
      ta_shape.AddDim(max_array_size_);
      ta_shape.AppendShape(shape_);
      value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
                              ta_shape.dim_sizes());
      break;
    }
    case kStack: {
      TensorShape ta_shape;
      ta_shape.AddDim(max_array_size_);
      ta_shape.AppendShape(shape_);
      value_ =
          xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
                                              ta_shape.dim_sizes()),
                               xla::ConstantR0<int32>(builder, 0)});
      break;
    }

    case kInvalid:
    default:
      LOG(FATAL) << "Invalid resource type";
  }
  return OkStatus();
}

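// Looks up (or lazily creates) the gradient TensorArray for `source` and
// returns it via `gradient_out`. A newly created gradient starts as an
// all-zero [max_array_size, shape...] tensor and aggregates multiple writes.
// Only valid for TensorArray resources.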
Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
                                                   xla::XlaBuilder* builder,
                                                   XlaResource** gradient_out) {
  VLOG(2) << "Gradient lookup for resource: " << name_
          << " gradient: " << source;
  TF_RET_CHECK(kind_ == kTensorArray);
  std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
  if (!gradient) {
    TensorShape ta_shape;
    ta_shape.AddDim(max_array_size_);
    ta_shape.AppendShape(shape_);
    xla::XlaOp gradient_value =
        xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
    gradient.reset(
        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                        /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
                        type_, shape_, gradient_value, max_array_size_,
                        /*tensor_array_gradients=*/{},
                        /*tensor_array_multiple_writes_aggregate=*/true));
  }
  *gradient_out = gradient.get();
  return OkStatus();
}

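// Packs the resource into a single XlaOp. A resource with no TensorArray
// gradients packs to its value alone; otherwise the value and all gradient
// values are bundled into one tuple.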
Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
  if (tensor_array_gradients_.empty()) {
    *pack = value_;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    std::vector<xla::XlaOp> elems;
    elems.push_back(value_);
    for (const auto& gradient : tensor_array_gradients_) {
      elems.push_back(gradient.second->value_);
    }
    *pack = xla::Tuple(builder, elems);
  }
  return OkStatus();
}

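// Inverse of Pack: restores the resource's value, and the values of the
// gradients named in `gradient_sources`, from the tuple `pack`. Because both
// the std::set here and the std::map iterated by Pack walk sources in sorted
// order, the tuple elements line up, assuming `gradient_sources` matches the
// gradients present when Pack was called.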
Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
                                const xla::XlaOp& pack,
                                xla::XlaBuilder* builder) {
  if (gradient_sources.empty()) {
    if (!initialized()) {
      initial_value_ = pack;
    }
    value_ = pack;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    int pos = 0;
    auto v = xla::GetTupleElement(pack, pos++);
    if (!initialized()) {
      initial_value_ = v;
    }
    value_ = v;

    for (const auto& source : gradient_sources) {
      XlaResource* gradient;
      TF_RETURN_IF_ERROR(
          GetOrCreateTensorArrayGradient(source, builder, &gradient));
      auto v = xla::GetTupleElement(pack, pos++);
      if (!gradient->initialized()) {
        gradient->initial_value_ = v;
      }
      gradient->value_ = v;
    }
  }
  return OkStatus();
}

}  // namespace tensorflow