1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/ops/math_ops.h"
16
17 #include "tensorflow/c/eager/abstract_context.h"
18 #include "tensorflow/c/eager/abstract_tensor_handle.h"
19 #include "tensorflow/c/eager/tracing_utils.h"
20 #include "tensorflow/c/experimental/ops/array_ops.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/platform/errors.h"
23
24 using tensorflow::tracing::MaybeSetOpName;
25
26 namespace tensorflow {
27 namespace ops {
28
Mul(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)29 Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
30 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
31 AbstractOperationPtr mul_op(ctx->CreateOperation());
32 TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
33 TF_RETURN_IF_ERROR(MaybeSetOpName(mul_op.get(), name));
34 TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
35 TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
36 int num_retvals = 1;
37 return mul_op->Execute(outputs, &num_retvals);
38 }
39
Conj(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)40 Status Conj(AbstractContext* ctx,
41 absl::Span<AbstractTensorHandle* const> inputs,
42 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
43 auto dtype = inputs[0]->DataType();
44 if (DataTypeIsFloating(BaseType(dtype)) ||
45 DataTypeIsInteger(BaseType(dtype))) {
46 TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
47 } else if (DataTypeIsComplex(BaseType(dtype)) ||
48 BaseType(dtype) == DT_VARIANT) {
49 AbstractOperationPtr conj_op(ctx->CreateOperation());
50 TF_RETURN_IF_ERROR(conj_op->Reset("Conj", /*raw_device_name=*/nullptr));
51 TF_RETURN_IF_ERROR(MaybeSetOpName(conj_op.get(), name));
52 TF_RETURN_IF_ERROR(conj_op->AddInput(inputs[0]));
53
54 int num_retvals = 1;
55 TF_RETURN_IF_ERROR(conj_op->Execute(outputs, &num_retvals));
56 } else {
57 return errors::InvalidArgument(
58 "Expected numeric or variant tensor, got dtype ", dtype);
59 }
60 return Status::OK();
61 }
62
Add(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)63 Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
64 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
65 AbstractOperationPtr add_op(ctx->CreateOperation());
66 TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));
67 TF_RETURN_IF_ERROR(MaybeSetOpName(add_op.get(), name));
68 TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
69 TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));
70
71 int num_retvals = 1;
72 TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals));
73 return Status::OK();
74 }
75
Sub(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)76 Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
77 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
78 AbstractOperationPtr sub_op(ctx->CreateOperation());
79 TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));
80 TF_RETURN_IF_ERROR(MaybeSetOpName(sub_op.get(), name));
81 TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));
82 TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));
83
84 int num_retvals = 1;
85 TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals));
86 return Status::OK();
87 }
88
MatMul(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name,bool transpose_a=false,bool transpose_b=false)89 Status MatMul(AbstractContext* ctx,
90 absl::Span<AbstractTensorHandle* const> inputs,
91 absl::Span<AbstractTensorHandle*> outputs, const char* name,
92 bool transpose_a = false, bool transpose_b = false) {
93 AbstractOperationPtr matmul_op(ctx->CreateOperation());
94 TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));
95 TF_RETURN_IF_ERROR(MaybeSetOpName(matmul_op.get(), name));
96 TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
97 TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));
98
99 TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a));
100 TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b));
101
102 int num_retvals = 1;
103 TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals));
104 return Status::OK();
105 }
106
Neg(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)107 Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
108 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
109 AbstractOperationPtr neg_op(ctx->CreateOperation());
110 TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
111 TF_RETURN_IF_ERROR(MaybeSetOpName(neg_op.get(), name));
112 TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));
113
114 int num_retvals = 1;
115 return neg_op->Execute(outputs, &num_retvals);
116 }
117
Sum(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)118 Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
119 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
120 AbstractOperationPtr sum_op(ctx->CreateOperation());
121 TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr));
122 TF_RETURN_IF_ERROR(MaybeSetOpName(sum_op.get(), name));
123 TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals
124 TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices
125
126 int num_retvals = 1;
127 TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
128 return Status::OK();
129 }
130
Div(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)131 Status Div(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
132 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
133 AbstractOperationPtr div_op(ctx->CreateOperation());
134 TF_RETURN_IF_ERROR(div_op->Reset("Div", /*raw_device_name=*/nullptr));
135 TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
136 TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
137 TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
138
139 int num_retvals = 1;
140 TF_RETURN_IF_ERROR(div_op->Execute(outputs, &num_retvals)); // z = x / y
141 return Status::OK();
142 }
143
DivNoNan(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)144 Status DivNoNan(AbstractContext* ctx,
145 absl::Span<AbstractTensorHandle* const> inputs,
146 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
147 AbstractOperationPtr div_op(ctx->CreateOperation());
148 TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr));
149 TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
150 TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
151 TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
152
153 int num_retvals = 1;
154 TF_RETURN_IF_ERROR(div_op->Execute(
155 outputs, &num_retvals)); // z = x / y, (z_i = 0 if y_i = 0)
156 return Status::OK();
157 }
158
Exp(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)159 Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
160 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
161 AbstractOperationPtr exp_op(ctx->CreateOperation());
162 TF_RETURN_IF_ERROR(exp_op->Reset("Exp", /*raw_device_name=*/nullptr));
163 TF_RETURN_IF_ERROR(MaybeSetOpName(exp_op.get(), name));
164 TF_RETURN_IF_ERROR(exp_op->AddInput(inputs[0]));
165
166 int num_retvals = 1;
167 return exp_op->Execute(outputs, &num_retvals);
168 }
169
Sqrt(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)170 Status Sqrt(AbstractContext* ctx,
171 absl::Span<AbstractTensorHandle* const> inputs,
172 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
173 AbstractOperationPtr sqrt_op(ctx->CreateOperation());
174 TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
175 TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
176 TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
177
178 int num_retvals = 1;
179 Status s = sqrt_op->Execute(outputs, &num_retvals);
180 return s;
181 }
182
SqrtGrad(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)183 Status SqrtGrad(AbstractContext* ctx,
184 absl::Span<AbstractTensorHandle* const> inputs,
185 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
186 AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
187 TF_RETURN_IF_ERROR(
188 sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
189 TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
190 TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
191 TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
192
193 int num_retvals = 1;
194 Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
195 return s;
196 }
197
Log1p(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)198 Status Log1p(AbstractContext* ctx,
199 absl::Span<AbstractTensorHandle* const> inputs,
200 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
201 AbstractOperationPtr log1p_op(ctx->CreateOperation());
202 TF_RETURN_IF_ERROR(log1p_op->Reset("Log1p", /*raw_device_name=*/nullptr));
203 TF_RETURN_IF_ERROR(MaybeSetOpName(log1p_op.get(), name));
204 TF_RETURN_IF_ERROR(log1p_op->AddInput(inputs[0]));
205
206 int num_retvals = 1;
207 Status s = log1p_op->Execute(outputs, &num_retvals);
208 return s;
209 }
210
211 } // namespace ops
212 } // namespace tensorflow
213