/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_
#define TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {

// TODO(jeff): Get rid of use_exclusive_lock_ option

// Computes *input[0] = input[1]
//
// Abstract base kernel for the Assign op. Input 0 must be a ref-typed tensor
// (the variable being written); input 1 is the value to assign. Subclasses
// supply the device-specific element copy by overriding Copy().
class AssignOp : public OpKernel {
 public:
  explicit AssignOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("use_locking", &use_exclusive_lock_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("validate_shape", &validate_shape_));
    OP_REQUIRES(context, IsRefType(context->input_type(0)),
                errors::InvalidArgument("lhs input needs to be a ref type"));
    // This attribute is only set by Grappler's graph analysis and may be
    // absent; fall back to the conservative default (keep the allocator
    // constraints) when it is not present.
    if (!context
             ->GetAttr("_grappler_relax_allocator_constraints",
                       &relax_constraints_)
             .ok()) {
      relax_constraints_ = false;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& rhs = context->input(1);

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // We can't always know how this value will be used downstream, so make
    // conservative assumptions in specifying constraints on the memory
    // allocation attributes, unless the Grappler graph analysis determined that
    // it was safe not to.
    AllocatorAttributes attr;
    if (!relax_constraints_) {
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
    }

    {
      // Hold the variable's ref mutex while inspecting and possibly replacing
      // the lhs buffer, so the ref is never observed mid-update.
      mutex_lock l(*context->input_ref_mutex(0));
      const Tensor& old_lhs = context->mutable_input(0, /* lock_held */ true);
      const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape());
      if (validate_shape_) {
        OP_REQUIRES(context, same_shape,
                    errors::InvalidArgument(
                        "Assign requires shapes of both tensors to match. "
                        "lhs shape= ",
                        old_lhs.shape().DebugString(),
                        " rhs shape= ", rhs.shape().DebugString()));
      }

      // In the code below we try to minimize the amount of memory allocation
      // and copying by trying the following two shortcuts:
      // 1. If the lhs is initialized and has the same number of elements as
      //    the rhs we can avoid a memory allocation.
      // 2. If we can reuse the rhs buffer we avoid both a memory allocation
      //    and copying.

      // 1. Try to copy into an existing buffer.
      if (old_lhs.IsInitialized() &&
          old_lhs.shape().num_elements() == rhs.shape().num_elements()) {
        // The existing lhs tensor has already been initialized and the right
        // hand side can fit in the underlying buffer.
        Tensor reshaped_old_lhs;
        if (same_shape) {
          reshaped_old_lhs = old_lhs;
        } else {
          // Same element count, different shape: reuse the existing buffer
          // under the rhs shape (shortcut 1 above — no allocation) and point
          // the ref at the reshaped view.
          CHECK(reshaped_old_lhs.CopyFrom(old_lhs, rhs.shape()));
          context->replace_ref_input(0, reshaped_old_lhs,
                                     /* lock_held */ true);
        }
        if (use_exclusive_lock_) {
          // Copy while still holding the ref mutex, then skip the
          // unlocked-copy path below.
          Copy(context, &reshaped_old_lhs, rhs);
          return;
        }
      } else {
        // 2. Try to reuse the rhs.
        std::unique_ptr<Tensor> input_alias = context->forward_input(
            1, OpKernelContext::Params::kNoReservation /*output_index*/,
            rhs.dtype(), rhs.shape(), DEVICE_MEMORY, attr);
        if (input_alias != nullptr) {
          // Update the ref to point to the new buffer.
          context->replace_ref_input(0, *input_alias, /* lock_held */ true);
          return;
        }

        // Otherwise, create a new tensor whose shape matches the
        // right hand side, hand off to lhs and copy the rhs into it.
        Tensor copy_tensor;
        OP_REQUIRES_OK(context,
                       context->allocate_temp(old_lhs.dtype(), rhs.shape(),
                                              &copy_tensor, attr));
        // We track memory of variables in variable ops instead of in this
        // assign op.
        context->clear_recorded_memory();
        context->replace_ref_input(0, copy_tensor, /* lock_held */ true);
        if (use_exclusive_lock_) {
          Copy(context, &copy_tensor, rhs);
          return;
        }
      }
    }

    // The tensor has already been initialized and the right hand side
    // matches the left hand side's shape. We have been told to do the
    // copy outside the lock.
    Tensor old_unlocked_lhs = context->mutable_input(0, /* lock_held */ false);
    Copy(context, &old_unlocked_lhs, rhs);
  }

  // Device-specific element copy of `rhs` into `*lhs`; implemented by the
  // per-device Assign kernels that derive from this class.
  virtual void Copy(OpKernelContext* context, Tensor* lhs,
                    const Tensor& rhs) = 0;

  bool use_exclusive_lock_;  // "use_locking" attr: perform Copy under the ref
                             // mutex instead of after releasing it.
  bool validate_shape_;      // "validate_shape" attr: require lhs/rhs shapes
                             // to match exactly.
  bool relax_constraints_;   // Grappler hint: skip gpu/nic-compatible
                             // allocator attributes when safe.
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_