• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Implementation for internal polymorphism `add` operations."""
17
18from . import _compile_utils as utils
19from ...composite import base
20from ... import functional as F
21
22
#: Polymorphic ``add`` metafuncgraph. Each supported pair of input types is
#: attached to it further down in this module with ``@add.register(...)``.
add = base.MultitypeFuncGraph('add', True)

#: Metafuncgraph named 'add_backward'. Its overloads are attached further down
#: with ``@_add_backward.register(...)``, and it is the function wrapped by
#: ``hyper_add`` at the end of this module.
_add_backward = base.MultitypeFuncGraph('add_backward')
32
33
class _TupleAdd(base.TupleAdd_):
    """
    Metafuncgraph that concatenates two tuples.

    This Python subclass only forwards its name to ``base.TupleAdd_``.
    Presumably the concatenation itself is implemented by that base class;
    confirm against its definition.

    Args:
        x (tuple): The leading tuple.
        y (tuple): The trailing tuple.

    Returns:
        Tuple, the elements of `x` followed by the elements of `y`.
    """

    def __init__(self, name):
        """Create the metafuncgraph under the given `name`."""
        super().__init__(name)

    def __call__(self, *args):
        """Python-side stub: accepts any arguments, does nothing and returns None."""
52
53
#: Metafuncgraph instance (named 'tuple_add') that concatenates two tuples
#: into a single tuple; used by the ("Tuple", "Tuple") overload of ``add``.
_tuple_add = _TupleAdd('tuple_add')
56
57
@add.register("Number", "Number")
@_add_backward.register("Number", "Number")
def _scalar_add_scalar(x, y):
    """
    Sum two scalars with ``F.scalar_add``.

    This overload is registered on both `add` and `_add_backward`.

    Args:
        x (Number): Left operand.
        y (Number): Right operand.

    Returns:
        Number, the result of ``F.scalar_add(x, y)``.
    """
    total = F.scalar_add(x, y)
    return total
72
73
@add.register("String", "String")
def _string_concat_string(x, y):
    """
    Join two strings, placing `y` after `x`.

    Args:
        x (str): Leading string.
        y (str): Trailing string.

    Returns:
        str, the result of ``F.string_concat(x, y)``.
    """
    joined = F.string_concat(x, y)
    return joined
87
88
@add.register("Number", "Tensor")
def _scalar_add_tensor(x, y):
    """
    Add a scalar (left) to a tensor (right) via ``F.add``.

    Args:
        x (Number): Scalar operand.
        y (Tensor): Tensor operand.

    Returns:
        Tensor, the result of ``F.add(x, y)``.
    """
    total = F.add(x, y)
    return total
102
103
@add.register("Tensor", "Number")
def _tensor_add_scalar(x, y):
    """
    Add a tensor (left) and a scalar (right) via ``F.add``.

    Args:
        x (Tensor): Tensor operand.
        y (Number): Scalar operand.

    Returns:
        Tensor, the result of ``F.add(x, y)``.
    """
    total = F.add(x, y)
    return total
117
118
@add.register("Tuple", "Tensor")
def _tuple_add_tensor(x, y):
    """
    Add a tuple (left) to a tensor (right).

    The tuple is first converted to a tensor of ``y.dtype`` with
    ``utils.sequence_to_tensor``; the two tensors are then summed with
    ``F.tensor_add``.

    Args:
        x (Tuple): Sequence operand, converted to a tensor before adding.
        y (Tensor): Tensor operand; its dtype drives the conversion of `x`.

    Returns:
        Tensor, the result of ``F.tensor_add`` on the converted `x` and `y`.
    """
    return F.tensor_add(utils.sequence_to_tensor(x, y.dtype), y)
133
134
@add.register("Tensor", "Tuple")
def _tensor_add_tuple(x, y):
    """
    Add a tensor (left) and a tuple (right).

    The tuple is first converted to a tensor of ``x.dtype`` with
    ``utils.sequence_to_tensor``; the two tensors are then summed with
    ``F.tensor_add``.

    Args:
        x (Tensor): Tensor operand; its dtype drives the conversion of `y`.
        y (Tuple): Sequence operand, converted to a tensor before adding.

    Returns:
        Tensor, the result of ``F.tensor_add`` on `x` and the converted `y`.
    """
    y = utils.sequence_to_tensor(y, x.dtype)
    return F.tensor_add(x, y)
149
150
@add.register("List", "Tensor")
def _list_add_tensor(x, y):
    """
    Add a list (left) to a tensor (right).

    The list is first converted to a tensor of ``y.dtype`` with
    ``utils.sequence_to_tensor``; the two tensors are then summed with
    ``F.tensor_add``.

    Args:
        x (List): Sequence operand, converted to a tensor before adding.
        y (Tensor): Tensor operand; its dtype drives the conversion of `x`.

    Returns:
        Tensor, the result of ``F.tensor_add`` on the converted `x` and `y`.
    """
    x = utils.sequence_to_tensor(x, y.dtype)
    return F.tensor_add(x, y)
165
166
@add.register("Tensor", "List")
def _tensor_add_list(x, y):
    """
    Add a tensor (left) and a list (right).

    The list is first converted to a tensor of ``x.dtype`` with
    ``utils.sequence_to_tensor``; the two tensors are then summed with
    ``F.tensor_add``.

    Args:
        x (Tensor): Tensor operand; its dtype drives the conversion of `y`.
        y (List): Sequence operand, converted to a tensor before adding.

    Returns:
        Tensor, the result of ``F.tensor_add`` on `x` and the converted `y`.
    """
    y = utils.sequence_to_tensor(y, x.dtype)
    return F.tensor_add(x, y)
181
182
@add.register("List", "List")
def _list_add_list(x, y):
    """
    Concatenate two lists.

    A new list is returned and neither input is modified, so the caller's
    `x` is left intact when this runs as ordinary Python.

    Args:
        x (list): Leading list.
        y (list): Trailing list.

    Returns:
        list, a new list holding the elements of `x` followed by the elements of `y`.
    """
    # Build a fresh list. `x + y` is avoided because `+` on two lists
    # presumably dispatches back to this very ("List", "List") overload.
    result = []
    for item in x:
        result.append(item)
    for item in y:
        result.append(item)
    return result
198
199
@add.register("Tensor", "Tensor")
def _tensor_add_tensor(x, y):
    """
    Element-wise sum of two tensors via ``F.add``.

    Args:
        x (Tensor): Left operand.
        y (Tensor): Right operand.

    Returns:
        Tensor, the result of ``F.add(x, y)``.
    """
    total = F.add(x, y)
    return total
213
214
@add.register("RowTensor", "Tensor")
def add_rowtensor_tensor(x, y):
    """
    Add a RowTensor (left) and a dense Tensor (right) via ``F.row_tensor_add``.

    Args:
        x (RowTensor): Left operand.
        y (Tensor): Right operand.

    Returns:
        The result of ``F.row_tensor_add(x, y)`` (documented previously as a
        RowTensor; confirm against ``F.row_tensor_add``).
    """
    combined = F.row_tensor_add(x, y)
    return combined
228
229
@add.register("None", "None")
def _none_add_none(x, y):
    """
    Adding None to None yields None.

    Args:
        x (None): Ignored.
        y (None): Ignored.

    Returns:
        None, always.
    """
    return None
243
244
@_add_backward.register("EnvType", "EnvType")
def _add_env(x, y):
    """
    Combine two EnvType values with ``F.env_add``.

    Args:
        x (EnvType): Left operand.
        y (EnvType): Right operand.

    Returns:
        EnvType, the result of ``F.env_add(x, y)``.
    """
    merged = F.env_add(x, y)
    return merged
258
259
@add.register("Tuple", "Tuple")
def _add_tuple(x, y):
    """
    Concatenate two tuples using the ``_tuple_add`` metafuncgraph.

    Args:
        x (tuple): Leading tuple.
        y (tuple): Trailing tuple.

    Returns:
        Tuple, the elements of `x` followed by the elements of `y`.
    """
    concatenated = _tuple_add(x, y)
    return concatenated
273
274
@_add_backward.register("Tensor", "Tensor")
def _add_addn(x, y):
    """
    Sum two tensors by passing the pair ``(x, y)`` to ``F.addn``.

    Args:
        x (Tensor): First operand.
        y (Tensor): Second operand.

    Returns:
        Tensor, the result of ``F.addn((x, y))``.
    """
    operands = (x, y)
    return F.addn(operands)
288
289
@_add_backward.register("UMonad", "UMonad")
def _add_umonad_umonad(x, y):
    """
    Merge two UMonad values by keeping the first one.

    Args:
        x (UMonad): Returned unchanged.
        y (UMonad): Ignored.

    Returns:
        UMonad, `x` itself.
    """
    kept = x
    return kept
303
304
@_add_backward.register("IOMonad", "IOMonad")
def _add_iomonad_iomonad(x, y):
    """
    Merge two IOMonad values by keeping the first one.

    Args:
        x (IOMonad): Returned unchanged.
        y (IOMonad): Ignored.

    Returns:
        IOMonad, `x` itself.
    """
    kept = x
    return kept
318
319
@_add_backward.register("RowTensor", "Tensor")
def _add_rowtensor_tensor(x, y):
    """
    Add a RowTensor and a Tensor by delegating to the ``+`` operator.

    Presumably, under graph compilation, ``+`` here resolves through the
    ("RowTensor", "Tensor") overload of `add`; confirm.

    Args:
        x (RowTensor): Left operand.
        y (Tensor): Right operand.

    Returns:
        The result of ``x + y``.
    """
    summed = x + y
    return summed
333
334
@_add_backward.register("None", "None")
def _add_nonetensor_tensor(x, y):
    """
    Adding None to None yields None.

    The result is returned directly rather than computed as ``x + y``.
    ``None + None`` raises TypeError whenever this function is executed as
    plain Python; this matches the ("None", "None") overload of `add`.

    Args:
        x (None): Ignored.
        y (None): Ignored.

    Returns:
        None, always.
    """
    return None
348
#: HyperMap wrapping ``_add_backward``, so the ``_add_backward`` overloads are
#: applied through HyperMap's mapping semantics (presumably element-wise over
#: nested tuples/lists; confirm against the HyperMap documentation).
hyper_add = base.HyperMap(_add_backward)
350