# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""Parameter for cell."""
17from copy import copy
18import numbers
19import numpy as np
20from .._c_expression import ParamInfo
21from . import dtype as mstype
22from .. import context
23from ..parallel._utils import _get_parallel_mode
24from .initializer import initializer
25from .tensor import Tensor
26from .._checkparam import Validator
27from .._c_expression import Tensor as Tensor_
28from ..parallel._tensor import _get_slice_index
29from ..parallel._auto_parallel_context import auto_parallel_context
30from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table
31from ..parallel._ps_context import _reinsert_hash_table_size
32from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
33from .seed import _get_global_and_op_seed
34
35__all__ = ['Parameter', 'ParameterTuple']
36
37PARAMETER_NAME_DEFAULT = "Parameter"
38PARAMETER_NAME_PREFIX_MAX_LEN = 1024
39

def _is_in_parallel_mode():
    """Check whether the task is running in semi_auto_parallel or auto_parallel mode."""
    return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]


def init_to_value(init):
    """Get the float value corresponding to an initializer alias or a number."""
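    # e.g. init_to_value('zeros') -> 0.0, init_to_value('ones') -> 1.0, init_to_value(3) -> 3.0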
    if isinstance(init, str):
        if init == 'zeros':
            return 0.0
        if init == 'ones':
            return 1.0
        raise ValueError("The argument 'init' should be one of ['zeros', 'ones'].")
    if isinstance(init, numbers.Number):
        return float(init)
    raise ValueError("The argument 'init' should be a number or a string, but got {}.".format(type(init)))


class Parameter(Tensor_):
    """
    An object holding the weights of cells. After initialization, a `Parameter` is a subtype of `Tensor`.

    Note:
        In the "semi_auto_parallel" and "auto_parallel" parallel modes, if a `Parameter` is initialized by
        a `Tensor`, the type of the Parameter will be `Tensor`. Such a `Tensor`
        only saves the shape and type info of a tensor with no memory usage, so the shape can be changed while
        compiling for auto-parallel. Calling `init_data` will return a `Parameter` with initialized data.
        If there is an operator in the network that requires part of its inputs to be Parameters,
        then the Parameters used as that part of the inputs are not allowed to be cast.
        It is recommended to use the default value of `name` when initializing a parameter as an attribute of
        a cell; otherwise, the parameter name may be different from expected.

    Args:
        default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data,
            used to initialize the parameter data.
        name (str): Name of the parameter. Default: None.
        requires_grad (bool): True if the parameter requires gradient. Default: True.
        layerwise_parallel (bool): When layerwise_parallel is True in data/hybrid parallel mode,
            broadcast and gradient communication would not be applied to this parameter. Default: False.
        parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel
            mode. It works only when the parallel optimizer is enabled in
            `mindspore.context.set_auto_parallel_context()`. Default: True.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Parameter, Tensor
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> import mindspore
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.matmul = ops.MatMul()
        ...         self.weight = Parameter(Tensor(np.ones((1, 2)), mindspore.float32), name="w", requires_grad=True)
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(self.weight, x)
        ...         return out
        >>> net = Net()
        >>> x = Tensor(np.ones((2, 1)), mindspore.float32)
        >>> print(net(x))
        [[2.]]
        >>> net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
        >>> print(net(x))
        [[0.]]
    """
    __base_type__ = {}

    def __new__(cls, default_input, *args, **kwargs):
        init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
        input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
        new_type = Parameter._get_base_class(input_class)
        obj = input_class.__new__(new_type)
        input_class.__init__(obj, *class_init_args)
        # it's better to make the Initializer a kind of tensor.
        obj.init_mode = None
        obj.is_default_input_init = init_data_flag
        if obj.has_init:
            obj.init_mode = default_input
        return obj

    def __reduce_ex__(self, _):
        data = self
        if self.init_mode is not None:
            data = self.init_mode
        else:
            # cast to a Tensor to break the infinite recursion during deepcopy
            data = Tensor(self)
        return (
            Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False,
                 parallel_optimizer=True):
        self.param_info = ParamInfo()
        self.init_in_server = False
        self.cache_enable = False
        self.name = name
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        self.parallel_optimizer = parallel_optimizer
        # this flag is for tensor copy data.
        self.init_flag = False
        # this flag is for GE variable copy data.
        self._is_init = False
        self._inited_param = None
        self._sliced = False
        self.is_param_ps = False
        self.push_weight_to_server = False
        self.pull_weight_from_server = False
        self.requires_aggr = True
        self._cast_type = None
        self._unique = False
        self.is_in_parallel = _is_in_parallel_mode()
        self._pipeline_stage_list = []
        if isinstance(default_input, (Tensor_, Tensor)):
            Tensor_.__init__(self, default_input.dtype, default_input.shape)
        elif isinstance(default_input, int):
            Tensor_.__init__(self, mstype.int64, ())
        elif isinstance(default_input, float):
            Tensor_.__init__(self, mstype.float32, ())
        elif isinstance(default_input, (np.ndarray, list)):
            Tensor_.__init__(self, default_input)
        else:
            raise TypeError(f"The type of the argument 'default_input' must be in ['Tensor', 'int', 'float',"
                            f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")

    def __deepcopy__(self, memodict):
        new_obj = Parameter(self)
        new_obj.name = self.name
        new_obj._inited_param = self._inited_param  # pylint: disable=W0212
        return new_obj

    @staticmethod
    def _get_base_class(input_class):
        input_class_name = f'Parameter{input_class.__name__}'
        if input_class_name in Parameter.__base_type__:
            new_type = Parameter.__base_type__[input_class_name]
        else:
            new_type = type(input_class_name, (Parameter, input_class), {})
            Parameter.__base_type__[input_class_name] = new_type
        return new_type

    @staticmethod
    def _get_parameter_new_args(data):
        """Get the tensor class and constructor arguments used to build the current `Parameter`."""
        if isinstance(data, bool):
            raise ValueError('Parameter data can not be `bool`')
        if isinstance(data, Tensor) and data.has_init:
            if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched() or _is_role_pserver():
                # do not init data while in auto parallel.
                return (Tensor, None, data.dtype, data.shape, data.init)
            data = data.init_data().asnumpy()
        elif isinstance(data, Tensor):
            # make a copy of the Tensor to init the parameter
            return (Tensor, data.asnumpy(),)
        if isinstance(data, int):
            return (Tensor, data, mstype.int32)
        if isinstance(data, float):
            return (Tensor, data, mstype.float32)
        return (Tensor, data)

    def __str__(self):
        return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
               f'requires_grad={self.requires_grad})'

    def __repr__(self):
        return self.__str__()

    def __parameter__(self):
        """For parse check."""

    def set_param_ps(self, init_in_server=False):
        """
        Set whether the trainable parameter is updated by the parameter server and whether the
        trainable parameter is initialized on the server.

        Note:
            It only works when the running task is in parameter server mode.

        Args:
            init_in_server (bool): Whether the trainable parameter updated by the parameter server is
                initialized on the server. Default: False.
        """
        if not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
            raise RuntimeError("Must complete the following two steps before calling set_param_ps: \n"
                               "1. context.set_ps_context(enable_ps=True) \n"
                               "2. export the MS_ROLE environment variable \n"
                               "Please refer to the official website for detailed usage.")
        if init_in_server and (not self.name.endswith("embedding_table")):
            raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
                               "sparse operators support initialization in server.".format(self.name))
        self.is_param_ps = True
        self.init_in_server = init_in_server
        self.param_info.init_in_server = init_in_server

    def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
        """
        Set the way the parameter interacts with the server.

        Args:
            push_to_server (bool): Whether the parameter should be pushed to the server. Default: False.
            pull_from_server (bool): Whether the parameter should be pulled from the server. Default: False.
            requires_aggr (bool): Whether the parameter should be aggregated on the server. Default: True.
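
        A hedged sketch (only meaningful in a federated learning deployment; here it merely sets the
        interaction flags, and the parameter below is illustrative):

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Parameter, Tensor
            >>> # hypothetical parameter; the push/pull flags take effect only when a server is present
            >>> w = Parameter(Tensor(np.ones((2,)), mindspore.float32), name="w")
            >>> w.set_param_fl(push_to_server=True, pull_from_server=True)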
        """
        if push_to_server:
            self.push_weight_to_server = True
        if pull_from_server:
            self.pull_weight_from_server = True
        if not requires_aggr:
            self.requires_aggr = False
            self.param_info.requires_aggr = False

    @property
    def inited_param(self):
        """
        Get the new parameter created by calling `init_data`.

        Default is None. If `self` is a Parameter without data, the initialized Parameter with data
        will be recorded here after `init_data` is called.
        """
        return self._inited_param

    @property
    def name(self):
        """Get the name of the parameter."""
        return self.param_info.name

    @name.setter
    def name(self, name_):
        """
        Define a name for the parameter.

        Args:
            name_ (`str` or `None`): The name of the parameter. When `name_` is None or an empty string,
                the default value `PARAMETER_NAME_DEFAULT` is used.
        """
        if name_ is None:
            name_ = PARAMETER_NAME_DEFAULT
        elif isinstance(name_, str):
            name_ = name_.strip()
            if name_ == '':
                name_ = PARAMETER_NAME_DEFAULT
            if len(name_) > PARAMETER_NAME_PREFIX_MAX_LEN:
                raise ValueError("The length of the '{}' name should not exceed {}.".
                                 format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
        else:
            raise ValueError("The type of the Parameter's name should be 'string' or 'None', "
                             "but got {}.".format(type(name_)))

        if _is_role_worker() and self.cache_enable:
            if len(self.shape) != 2:
                raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
                                   .format(self.name, len(self.shape)))
            _reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1])

        self.param_info.name = name_

    @property
    def sliced(self):
        """Get slice status of the parameter."""
        return self._sliced

    @sliced.setter
    def sliced(self, sliced_):
        self._sliced = sliced_

    @property
    def comm_fusion(self):
        """
        Get and set the fusion type (int) for communication operators corresponding to this parameter.

        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameter or
        gradient aggregation are inserted automatically. This attribute sets the fusion type for the
        communication operators generated for this parameter. The value of fusion must be greater than or
        equal to 0. When the value of fusion is 0, operators will not be fused together.

        Only supported in the Ascend environment with Graph mode.
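
        A hedged sketch (the setting is only meaningful in a SEMI_AUTO_PARALLEL / AUTO_PARALLEL job on
        Ascend with Graph mode; the parameter below is illustrative):

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Parameter, Tensor
            >>> w = Parameter(Tensor(np.ones((2, 2)), mindspore.float32), name="w")
            >>> # fuse the communication operators generated for this parameter into group 2
            >>> w.comm_fusion = 2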
        """
        return self.param_info.comm_fusion

    @comm_fusion.setter
    def comm_fusion(self, comm_fusion_):
        if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
            raise RuntimeError("'comm_fusion' is not supported in PYNATIVE_MODE.")
        Validator.check_non_negative_int(comm_fusion_)
        self.param_info.comm_fusion = comm_fusion_

    @property
    def parallel_optimizer_comm_recompute(self):
        """
        Get and set whether to recompute the communication operators corresponding to this parameter
        when applying the parallel optimizer.

        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, when applying the parallel optimizer, some all_gather
        operators used for parameter gathering are inserted automatically.
        This interface is used to control the recompute attribute of those all_gather operators.

        Note:
            - Only the `Ascend` backend with `Graph` mode is supported.
            - It is recommended to use cell.recompute(parallel_optimizer_comm_recompute=True/False) to configure
              the all_gather operators introduced by the parallel optimizer rather than using this interface
              directly.
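
        A hedged sketch of direct assignment (the parameter below is illustrative; in real networks the
        `cell.recompute(parallel_optimizer_comm_recompute=...)` interface recommended above is preferred):

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Parameter, Tensor
            >>> w = Parameter(Tensor(np.ones((2, 2)), mindspore.float32), name="w")
            >>> w.parallel_optimizer_comm_recompute = True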
        """
        return self.param_info.parallel_optimizer_comm_recompute

    @parallel_optimizer_comm_recompute.setter
    def parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute_):
        Validator.check_bool(parallel_optimizer_comm_recompute_)
        self.param_info.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute_

    @property
    def unique(self):
        """Whether the parameter is already unique or not."""
        return self._unique

    @unique.setter
    def unique(self, unique_):
        self._unique = unique_

    @property
    def is_init(self):
        """
        Get the initialization status of the parameter.

        This flag only works in GE, and it will be set to False in other backends.
        """
        return self._is_init

    @is_init.setter
    def is_init(self, is_init_):
        """
        Set the init status of the parameter.

        Args:
            is_init_ (bool): The init status of the parameter.
        """
        self._is_init = is_init_

    def clone(self, init='same'):
        """
        Clone the parameter.

        Args:
            init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameter.
                If `init` is a `Tensor` or `numbers.Number`, clone a new parameter with the same shape
                and dtype, and the data of the new parameter will be set according to `init`. If `init`
                is a `str`, the `init` should be the alias of the class inheriting from `Initializer`.
                For example, if `init` is 'same', clone a new parameter with the same data, shape, and
                dtype. Default: 'same'.

        Returns:
            Parameter, a new parameter.
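
        A minimal usage sketch (the parameter name, shape and `init` value below are illustrative only):

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Parameter, Tensor
            >>> w = Parameter(Tensor(np.ones((2, 3)), mindspore.float32), name="w")
            >>> # clone with the same shape and dtype but zero-initialized data
            >>> w_zeros = w.clone(init='zeros')
            >>> print(w_zeros.shape)
            (2, 3)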
        """
        x = copy(self)
        x.param_info = self.param_info.clone()
        x.is_init = False
        x.init = self.init
        x.is_param_ps = self.is_param_ps
        x.init_in_server = self.init_in_server
        x.cache_enable = self.cache_enable
        x.requires_aggr = self.requires_aggr
        if self.cache_shape:
            x.cache_shape = self.cache_shape
        if init != 'same':
            shape = self.shape
            dtype = self.dtype
            x.set_data(initializer(init, shape=shape, dtype=dtype))
        return x

    @property
    def layerwise_parallel(self):
        """
        When layerwise_parallel is True in data/hybrid parallel mode, broadcast and gradient communication would not
        be applied to this parameter.
        """
        return self.param_info.layerwise_parallel

    @layerwise_parallel.setter
    def layerwise_parallel(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `layerwise_parallel` must be bool type.")
        self.param_info.layerwise_parallel = value

    @property
    def parallel_optimizer(self):
        """
        It is used to filter the weight shard operation in semi auto or auto parallel mode. It works only
        when the parallel optimizer is enabled in `mindspore.context.set_auto_parallel_context()`.
        """
        return self.param_info.parallel_optimizer

    @parallel_optimizer.setter
    def parallel_optimizer(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `parallel_optimizer` must be bool type.")
        self.param_info.parallel_optimizer = value

    @property
    def cache_enable(self):
        """Return whether cache is enabled for the parameter."""
        return self.param_info.cache_enable

    @cache_enable.setter
    def cache_enable(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `cache_enable` must be bool type.")
        self.param_info.cache_enable = value

    @property
    def cache_shape(self):
        """Return the cache shape corresponding to the parameter if cache is enabled."""
        return self.param_info.cache_shape

    @cache_shape.setter
    def cache_shape(self, value):
        if not isinstance(value, (tuple, list)):
            raise TypeError("The argument `cache_shape` must be tuple or list type.")
        self.param_info.cache_shape = value

    @property
    def requires_grad(self):
        """Return whether the parameter requires gradient."""
        return self.param_info.requires_grad

    @requires_grad.setter
    def requires_grad(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `requires_grad` must be bool type.")
        self.param_info.requires_grad = value

    @property
    def data(self):
        """Return the parameter object."""
        return self

    def _update_tensor_data(self, data):
        """Update the parameter by a Tensor."""
        if isinstance(self, Tensor):
            self.init_flag = False
            self.init = None
            return self.assign_value(data)
        new_param = Parameter(data, self.name, self.requires_grad)
        new_param.param_info = self.param_info
        return new_param

    def add_pipeline_stage(self, stage):
        if not isinstance(stage, int) or stage < 0:
            raise TypeError("`stage` must be a non-negative integer.")
        self._pipeline_stage_list.append(stage)

    def set_data(self, data, slice_shape=False):
        """
        Set the Parameter's data.

        Args:
            data (Union[Tensor, int, float]): New data.
            slice_shape (bool): If slice_shape is set to True, the shape of the incoming data is not checked
                for consistency. Default: False.

        Returns:
            Parameter, the parameter after data is set.
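
        A minimal usage sketch (the parameter name and values below are illustrative only):

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Parameter, Tensor
            >>> w = Parameter(Tensor(np.zeros((1, 2)), mindspore.float32), name="w")
            >>> # set_data returns the parameter itself holding the new data
            >>> w = w.set_data(Tensor(np.ones((1, 2)), mindspore.float32))
            >>> print(w.asnumpy())
            [[1. 1.]]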
        """
        def raise_type_error(incoming):
            raise TypeError(f"Incoming Parameter dtype can not be converted to the current dtype implicitly. "
                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
                            f"Use .set_dtype(xxx) to change the dtype.")

        if not isinstance(data, (Tensor, int, float)):
            raise TypeError(f"Parameter data must be a `Tensor`, an `int` or a `float`, "
                            f"but got type {type(data)}.")
        if isinstance(data, (int, float)):
            if self.dtype in mstype.int_type and isinstance(data, float):
                raise_type_error(mstype.float_)
            data = Tensor(data, self.dtype)
        # check whether the incoming data and the current parameter hold concrete (initialized) data.
        incoming_tensor_is_init = isinstance(data, Tensor) and not data.has_init
        current_tensor_is_init = isinstance(self, Tensor) and not self.has_init

        if incoming_tensor_is_init and not current_tensor_is_init:
            raise TypeError("The current parameter data is not initialized, but the incoming 'data' is already "
                            "initialized. Please call 'init_data' on this parameter before calling this method.")
        if tuple(self.shape) != tuple(data.shape):
            # If the Parameter is created by slicing, the shape is allowed to change.
            if not slice_shape:
                raise ValueError(f"Can not change the shape of a Parameter which has been initialized."
                                 f" Current shape is {self.shape}, and incoming is {data.shape}.")
        if self.dtype != data.dtype:
            if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
                raise_type_error(data.dtype)
            else:
                from mindspore.ops import functional as F
                data = F.cast(data, self.dtype)
        if isinstance(data, Tensor) and data.has_init:
            # The parameter has been initialized, directly update it by the data
            if current_tensor_is_init:
                self._update_tensor_data(data.init_data())
            else:
                # also update the related inited parameter data
                if self.inited_param is not None:
                    self.inited_param.set_data(data)
                self.init_mode = data
        elif incoming_tensor_is_init or current_tensor_is_init:
            self._update_tensor_data(data)
        self.sliced = slice_shape
        return self

    def init_data(self, layout=None, set_sliced=False):
        """
        Initialize the parameter's data.

        Args:
            layout (Union[None, tuple(list(int))]): Parameter slice
                layout [dev_mat, tensor_map, slice_shape]. Default: None.

                - dev_mat (list(int)): Device matrix.
                - tensor_map (list(int)): Tensor map.
                - slice_shape (list(int)): Shape of slice.

            set_sliced (bool): True if the parameter is set sliced after initializing the data.
                Default: False.

        Raises:
            RuntimeError: If it is from Initializer, and the parallel mode has changed after the Initializer
                was created.
            ValueError: If the length of the layout is less than 6.
            TypeError: If `layout` is not a tuple.

        Returns:
            Parameter, the `Parameter` after initializing data. If the current `Parameter` was already
            initialized before, returns the same initialized `Parameter`.
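
        A minimal usage sketch (the initializer alias, shape and name below are illustrative only):

        Examples:
            >>> import mindspore
            >>> from mindspore import Parameter
            >>> from mindspore.common.initializer import initializer
            >>> # hypothetical parameter built from an Initializer
            >>> p = Parameter(initializer('ones', [1, 2], mindspore.float32), name="p")
            >>> p = p.init_data()
            >>> print(p.asnumpy())
            [[1. 1.]]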
        """
        if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
            raise RuntimeError("The parallel mode must be set or changed before any Tensor is created.")
        if self.init_mode is None:
            return self
        if self.inited_param is not None:
            return self.inited_param
        if _is_role_worker() and self.cache_enable:
            global_seed, op_seed = _get_global_and_op_seed()
            _insert_weight_init_info(self.name, global_seed, op_seed)

        init_data_args = ()
        if layout is not None:
            if not isinstance(layout, tuple):
                raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout)))
            if len(layout) < 6:
                raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout)))
            slice_index = int(_get_slice_index(layout[0], layout[1]))
            init_data_args += (slice_index, layout[2], layout[5])

        if _is_role_pserver():
            return self

        if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
           self.init_mode.init is not None and (_is_role_worker() or _is_role_sched()):
            data = self.init_mode.init_data(0, [1])
        else:
            data = self.init_mode.init_data(*init_data_args)

        obj = self._update_tensor_data(data)
        if id(obj) != id(self):
            self._inited_param = obj
        obj.init_mode = None
        obj.sliced = set_sliced
        return obj


class ParameterTuple(tuple):
    """
    Class for storing a tuple of parameters.

    Note:
        It is used to store the parameters of the network into the parameter tuple collection.
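
    A brief illustrative sketch (the parameter names and values below are arbitrary):

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Parameter, ParameterTuple, Tensor
        >>> a = Parameter(Tensor(np.ones((2,)), mindspore.float32), name="a")
        >>> b = Parameter(Tensor(np.zeros((2,)), mindspore.float32), name="b")
        >>> params = ParameterTuple((a, b))
        >>> print(len(params))
        2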
    """
    def __new__(cls, iterable):
        """Create an instance object of ParameterTuple."""
        data = tuple(iterable)
        ids = set()
        orders = {}
        for x in data:
            if not isinstance(x, Parameter):
                raise TypeError(f"ParameterTuple input should be a `Parameter` collection. "
                                f"But got a {type(iterable)}, {iterable}")
            if id(x) not in ids:
                ids.add(id(x))
                if x.name not in orders.keys():
                    orders[x.name] = [0, x]
                else:
                    if isinstance(orders[x.name], list):
                        name = x.name
                        orders[name][1].name = name + "_" + str(0)
                        x.name = x.name + "_" + str(1)
                        orders[name] = 1
                    else:
                        orders[x.name] += 1
                        x.name = x.name + "_" + str(orders[x.name])
        return tuple.__new__(ParameterTuple, tuple(data))

    def clone(self, prefix, init='same'):
        """
        Clone the parameters in the ParameterTuple element-wise to generate a new ParameterTuple.

        Args:
            prefix (str): Namespace of the parameters.
            init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameters.
                The definition of `init` is the same as in the `Parameter` API. If `init` is 'same', the
                parameters in the new parameter tuple are the same as those in the original parameter tuple.
                Default: 'same'.

        Raises:
            RuntimeError: If the name of a cache-enabled parameter does not end with 'embedding_table'.

        Returns:
            Tuple, the new Parameter tuple.
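
        A minimal usage sketch (the prefix and parameter below are illustrative only):

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Parameter, ParameterTuple, Tensor
            >>> params = ParameterTuple((Parameter(Tensor(np.ones((2,)), mindspore.float32), name="w"),))
            >>> cloned = params.clone(prefix="backup", init='zeros')
            >>> print(cloned[0].name)
            backup.w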
        """
        Validator.check_str_by_regular(prefix)
        new = []
        for x in self:
            x1 = x.clone(init)
            x1.name = prefix + "." + x1.name
            new.append(x1)

            if not x1.cache_enable:
                continue
            if not x1.name.endswith("embedding_table"):
                raise RuntimeError("Can not enable cache for parameter '{}'. Only parameters of "
                                   "sparse operators support cache.".format(x1.name))

            if _is_role_worker():
                _clone_hash_table(x.name, x1.name)
                _insert_accumu_init_info(x1.name, init_to_value(init))
        return ParameterTuple(new)

    def __parameter_tuple__(self):
        """For parse check."""