# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parameter for cell."""
from copy import copy
import numbers
import numpy as np
from .._c_expression import ParamInfo
from . import dtype as mstype
from .. import context
from ..parallel._utils import _get_parallel_mode
from .initializer import initializer
from .tensor import Tensor
from .._checkparam import Validator
from .._c_expression import Tensor as Tensor_
from ..parallel._tensor import _get_slice_index
from ..parallel._auto_parallel_context import auto_parallel_context
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table
from ..parallel._ps_context import _reinsert_hash_table_size
from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
from .seed import _get_global_and_op_seed

__all__ = ['Parameter', 'ParameterTuple']

PARAMETER_NAME_DEFAULT = "Parameter"
PARAMETER_NAME_PREFIX_MAX_LEN = 1024


def _is_in_parallel_mode():
    """Check whether the current parallel mode is semi_auto_parallel or auto_parallel."""
    return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]


def init_to_value(init):
    """Get the scalar value corresponding to an initializer."""
    if isinstance(init, str):
        if init == 'zeros':
            return 0.0
        if init == 'ones':
            return 1.0
        raise ValueError("The argument 'init' should be one of values in ['zeros', 'ones'].")
    if isinstance(init, numbers.Number):
        return float(init)
    raise ValueError("The argument 'init' should be number or string, but got {}.".format(type(init)))


class Parameter(Tensor_):
    """
    An object holding the weights of cells. Once initialized, a `Parameter` is a subtype of `Tensor`.

    Note:
        In the "semi_auto_parallel" and "auto_parallel" parallel modes, if a `Parameter` is initialized
        from an initializer-backed `Tensor`, it only records the shape and type information without
        allocating memory, so that the shape can still be changed while compiling for auto-parallel.
        Calling `init_data` will return a `Parameter` with the data initialized.
        If an operator in the network requires part of its inputs to be `Parameter`,
        those inputs are not allowed to be cast to another type.
        It is recommended to use the default value of `name` when initializing a parameter as an
        attribute of a cell; otherwise, the parameter name may differ from what is expected.

    Args:
        default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data,
            used to initialize the parameter data.
        name (str): Name of the parameter. Default: None.
        requires_grad (bool): True if the parameter requires gradient. Default: True.
        layerwise_parallel (bool): When layerwise_parallel is True in data/hybrid parallel mode,
            broadcast and gradients communication are not applied to this parameter. Default: False.
        parallel_optimizer (bool): It is used to filter the weight shard operation in semi-auto or auto
            parallel mode. It works only when the parallel optimizer is enabled in
            `mindspore.context.set_auto_parallel_context()`. Default: True.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Parameter, Tensor
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> import mindspore
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.matmul = ops.MatMul()
        ...         self.weight = Parameter(Tensor(np.ones((1, 2)), mindspore.float32), name="w", requires_grad=True)
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(self.weight, x)
        ...         return out
        >>> net = Net()
        >>> x = Tensor(np.ones((2, 1)), mindspore.float32)
        >>> print(net(x))
        [[2.]]
        >>> net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
        >>> print(net(x))
        [[0.]]
    """
    __base_type__ = {}

    def __new__(cls, default_input, *args, **kwargs):
        init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
        input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
        new_type = Parameter._get_base_class(input_class)
        obj = input_class.__new__(new_type)
        input_class.__init__(obj, *class_init_args)
        # it's better to make the Initializer a kind of tensor.
        obj.init_mode = None
        obj.is_default_input_init = init_data_flag
        if obj.has_init:
            obj.init_mode = default_input
        return obj

    def __reduce_ex__(self, _):
        data = self
        if self.init_mode is not None:
            data = self.init_mode
        else:
            # cast to a plain Tensor to break the infinite recursion during deepcopy.
            data = Tensor(self)
        return (
            Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
        self.param_info = ParamInfo()
        self.init_in_server = False
        self.cache_enable = False
        self.name = name
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        self.parallel_optimizer = parallel_optimizer
        # this flag is for tensor copy data.
        self.init_flag = False
        # this flag is for ge variable copy data.
        self._is_init = False
        self._inited_param = None
        self._sliced = False
        self.is_param_ps = False
        self.push_weight_to_server = False
        self.pull_weight_from_server = False
        self.requires_aggr = True
        self._cast_type = None
        self._unique = False
        self.is_in_parallel = _is_in_parallel_mode()
        self._pipeline_stage_list = []
        if isinstance(default_input, (Tensor_, Tensor)):
            Tensor_.__init__(self, default_input.dtype, default_input.shape)
        elif isinstance(default_input, int):
            Tensor_.__init__(self, mstype.int64, ())
        elif isinstance(default_input, float):
            Tensor_.__init__(self, mstype.float32, ())
        elif isinstance(default_input, (np.ndarray, list)):
            Tensor_.__init__(self, default_input)
        else:
            raise TypeError(f"The type of the argument 'default_input' must be in ['Tensor', 'int', 'float',"
                            f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")
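
    # Illustrative sketch (not upstream code): the kinds of `default_input` accepted
    # above, using only names imported at the top of this module; shapes and values
    # are arbitrary examples.
    #
    #     Parameter(Tensor(np.ones((1, 2)), mstype.float32), name="w")  # from a Tensor
    #     Parameter(2.0, name="scale")                                  # from a Python float scalar
    #     Parameter(np.zeros((4,)), name="bias")                        # from a numpy.ndarray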

    def __deepcopy__(self, memodict):
        new_obj = Parameter(self)
        new_obj.name = self.name
        new_obj._inited_param = self._inited_param  # pylint: disable=W0212
        return new_obj

    @staticmethod
    def _get_base_class(input_class):
        input_class_name = f'Parameter{input_class.__name__}'
        if input_class_name in Parameter.__base_type__:
            new_type = Parameter.__base_type__[input_class_name]
        else:
            new_type = type(input_class_name, (Parameter, input_class), {})
            Parameter.__base_type__[input_class_name] = new_type
        return new_type

    @staticmethod
    def _get_parameter_new_args(data):
        """Get the class and constructor arguments used by `__new__` for the given data."""
        if isinstance(data, bool):
            raise ValueError('Parameter data can not be `bool`.')
        if isinstance(data, Tensor) and data.has_init:
            if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched() or _is_role_pserver():
                # do not init data while in auto parallel.
                return (Tensor, None, data.dtype, data.shape, data.init)
            data = data.init_data().asnumpy()
        elif isinstance(data, Tensor):
            # make a copy of the Tensor to init the parameter.
            return (Tensor, data.asnumpy(),)
        if isinstance(data, int):
            return (Tensor, data, mstype.int32)
        if isinstance(data, float):
            return (Tensor, data, mstype.float32)
        return (Tensor, data)

    def __str__(self):
        return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
               f'requires_grad={self.requires_grad})'

    def __repr__(self):
        return self.__str__()

    def __parameter__(self):
        """For parse check."""

    def set_param_ps(self, init_in_server=False):
        """
        Set whether the trainable parameter is updated by the parameter server and whether the
        trainable parameter is initialized on the server.

        Note:
            It only works when the running task is in parameter server mode.

        Args:
            init_in_server (bool): Whether the trainable parameter updated by the parameter server is
                initialized on the server. Default: False.
        """
        if not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
            raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"
                               "1. context.set_ps_context(enable_ps=True) \n"
                               "2. export MS_ROLE environment variable \n"
                               "Please refer to the official website for detailed usage.")
        if init_in_server and (not self.name.endswith("embedding_table")):
            raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
                               "sparse operators support initialization in server.".format(self.name))
        self.is_param_ps = True
        self.init_in_server = init_in_server
        self.param_info.init_in_server = init_in_server

    def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
        """
        Set the way the parameter interacts with the server.

        Args:
            push_to_server (bool): Whether the parameter should be pushed to the server. Default: False.
            pull_from_server (bool): Whether the parameter should be pulled from the server. Default: False.
            requires_aggr (bool): Whether the parameter should be aggregated on the server. Default: True.
        """
        if push_to_server:
            self.push_weight_to_server = True
        if pull_from_server:
            self.pull_weight_from_server = True
        if not requires_aggr:
            self.requires_aggr = False
            self.param_info.requires_aggr = False
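
    # Illustrative sketch (not upstream code): the two methods above only take effect in a
    # parameter-server / federated job, i.e. after context.set_ps_context(enable_ps=True)
    # and with the MS_ROLE environment variable set. `net` is a hypothetical Cell.
    #
    #     net.embedding_table.set_param_ps(init_in_server=True)
    #     net.weight.set_param_fl(push_to_server=True, pull_from_server=True)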

    @property
    def inited_param(self):
        """
        Get the new parameter created by `init_data`.

        Default is None. If `self` is a Parameter without data, the initialized Parameter
        returned by `init_data` is recorded here.
        """
        return self._inited_param

    @property
    def name(self):
        """Get the name of the parameter."""
        return self.param_info.name

    @name.setter
    def name(self, name_):
        """
        Define a name for the parameter.

        Args:
            name_ (`str` or `None`): The name of the parameter. When the name is None or an empty string,
                the default value `PARAMETER_NAME_DEFAULT` is used.
        """
        if name_ is None:
            name_ = PARAMETER_NAME_DEFAULT
        elif isinstance(name_, str):
            name_ = name_.strip()
            if name_ == '':
                name_ = PARAMETER_NAME_DEFAULT
            if len(name_) > PARAMETER_NAME_PREFIX_MAX_LEN:
                raise ValueError("The length of the '{}' name should be less than {}.".
                                 format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
        else:
            raise ValueError("The type of the Parameter's name should be 'string' or 'None', "
                             "but got {}.".format(type(name_)))

        if _is_role_worker() and self.cache_enable:
            if len(self.shape) != 2:
                raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
                                   .format(self.name, len(self.shape)))
            _reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1])

        self.param_info.name = name_

    @property
    def sliced(self):
        """Get the slice status of the parameter."""
        return self._sliced

    @sliced.setter
    def sliced(self, sliced_):
        self._sliced = sliced_

    @property
    def comm_fusion(self):
        """
        Get and set the fusion type (int) for the communication operators corresponding to this parameter.

        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameter or
        gradient aggregation are inserted automatically. This attribute sets the fusion type for the
        communication operators generated for this parameter. The value of fusion must be greater than or
        equal to 0; when it is 0, the operators are not fused.

        Only supported in the Ascend environment with graph mode.
        """
        return self.param_info.comm_fusion

    @comm_fusion.setter
    def comm_fusion(self, comm_fusion_):
        if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
            raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
        Validator.check_non_negative_int(comm_fusion_)
        self.param_info.comm_fusion = comm_fusion_
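
    # Illustrative sketch (not upstream code): assigning fusion indices under
    # SEMI_AUTO_PARALLEL/AUTO_PARALLEL graph mode on Ascend. `net` and the index
    # values are hypothetical examples.
    #
    #     net.fc1.weight.comm_fusion = 1   # fuse this parameter's comm ops into group 1
    #     net.fc2.weight.comm_fusion = 1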

    @property
    def parallel_optimizer_comm_recompute(self):
        """
        Get and set whether to recompute the communication operators corresponding to this parameter
        when applying the parallel optimizer.

        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, when the parallel optimizer is applied, some
        all_gather operators used for parameter gathering are inserted automatically.
        This attribute controls the recompute attribute of those all_gather operators.

        Note:
            - Only `Ascend` with `Graph` mode is supported.
            - It is recommended to use cell.recompute(parallel_optimizer_comm_recompute=True/False) to configure
              the all_gather operators introduced by the parallel optimizer rather than using this interface
              directly.
        """
        return self.param_info.parallel_optimizer_comm_recompute

    @parallel_optimizer_comm_recompute.setter
    def parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute_):
        Validator.check_bool(parallel_optimizer_comm_recompute_)
        self.param_info.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute_

    @property
    def unique(self):
        """Whether the parameter is already unique or not."""
        return self._unique

    @unique.setter
    def unique(self, unique_):
        self._unique = unique_

    @property
    def is_init(self):
        """
        Get the initialization status of the parameter.

        This flag only works with the GE backend and is set to False on other backends.
        """
        return self._is_init

    @is_init.setter
    def is_init(self, is_init_):
        """
        Set the initialization status of the parameter.

        Args:
            is_init_ (bool): The init status of the parameter.
        """
        self._is_init = is_init_

    def clone(self, init='same'):
        """
        Clone the parameter.

        Args:
            init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameter.
                If `init` is a `Tensor` or `numbers.Number`, clone a new parameter with the same shape
                and dtype, and the data of the new parameter will be set according to `init`. If `init`
                is a `str`, the `init` should be the alias of the class inheriting from `Initializer`.
                For example, if `init` is 'same', clone a new parameter with the same data, shape, and
                dtype. Default: 'same'.

        Returns:
            Parameter, a new parameter.
        """
        x = copy(self)
        x.param_info = self.param_info.clone()
        x.is_init = False
        x.init = self.init
        x.is_param_ps = self.is_param_ps
        x.init_in_server = self.init_in_server
        x.cache_enable = self.cache_enable
        x.requires_aggr = self.requires_aggr
        if self.cache_shape:
            x.cache_shape = self.cache_shape
        if init != 'same':
            shape = self.shape
            dtype = self.dtype
            x.set_data(initializer(init, shape=shape, dtype=dtype))
        return x
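
    # Illustrative sketch (not upstream code), using only names imported at the top of
    # this module; shapes and values are arbitrary examples.
    #
    #     w = Parameter(Tensor(np.ones((1, 2)), mstype.float32), name="w")
    #     w_same = w.clone()               # same data, shape and dtype
    #     w_zero = w.clone(init='zeros')   # same shape and dtype, data re-initialized to zeros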

    @property
    def layerwise_parallel(self):
        """
        When layerwise_parallel is True in data/hybrid parallel mode, broadcast and gradients
        communication are not applied to this parameter.
        """
        return self.param_info.layerwise_parallel

    @layerwise_parallel.setter
    def layerwise_parallel(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `layerwise_parallel` must be bool type.")
        self.param_info.layerwise_parallel = value

    @property
    def parallel_optimizer(self):
        """
        It is used to filter the weight shard operation in semi-auto or auto parallel mode. It works only
        when the parallel optimizer is enabled in `mindspore.context.set_auto_parallel_context()`.
        """
        return self.param_info.parallel_optimizer

    @parallel_optimizer.setter
    def parallel_optimizer(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `parallel_optimizer` must be bool type.")
        self.param_info.parallel_optimizer = value

    @property
    def cache_enable(self):
        """Return whether cache is enabled for the parameter."""
        return self.param_info.cache_enable

    @cache_enable.setter
    def cache_enable(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `cache_enable` must be bool type.")
        self.param_info.cache_enable = value

    @property
    def cache_shape(self):
        """Return the cache shape corresponding to the parameter if cache is enabled."""
        return self.param_info.cache_shape

    @cache_shape.setter
    def cache_shape(self, value):
        if not isinstance(value, (tuple, list)):
            raise TypeError("The argument `cache_shape` must be tuple or list type.")
        self.param_info.cache_shape = value

    @property
    def requires_grad(self):
        """Return whether the parameter requires gradient."""
        return self.param_info.requires_grad

    @requires_grad.setter
    def requires_grad(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `requires_grad` must be bool type.")
        self.param_info.requires_grad = value

    @property
    def data(self):
        """Return the parameter object itself."""
        return self

    def _update_tensor_data(self, data):
        """Update the parameter with a Tensor."""
        if isinstance(self, Tensor):
            self.init_flag = False
            self.init = None
            return self.assign_value(data)
        new_param = Parameter(data, self.name, self.requires_grad)
        new_param.param_info = self.param_info
        return new_param

    def add_pipeline_stage(self, stage):
        if not isinstance(stage, int) or stage < 0:
            raise TypeError("`stage` must be a non-negative int.")
        self._pipeline_stage_list.append(stage)
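
    # Illustrative sketch (not upstream code) for `set_data` below; `w` is a hypothetical
    # float32 Parameter of shape (1, 2).
    #
    #     w.set_data(Tensor(np.zeros((1, 2)), mstype.float32))   # same shape and dtype: accepted
    #     w.set_data(Tensor(np.zeros((4, 2)), mstype.float32))   # ValueError: shape mismatch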

    def set_data(self, data, slice_shape=False):
        """
        Set the Parameter's data.

        Args:
            data (Union[Tensor, int, float]): New data.
            slice_shape (bool): If True, the shape consistency check is skipped, which is used when
                setting a sliced parameter. Default: False.

        Returns:
            Parameter, the parameter after the data is set.
        """
        def raise_type_error(incoming):
            raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
                            f"Use .set_dtype(xxx) to change the dtype.")

        if not isinstance(data, (Tensor, int, float)):
            raise TypeError(f"Parameter data must be [`Tensor`, `int`, `float`] or a kind of `Tensor` "
                            f"(like `Tensor`). But got type {type(data)}.")
        if isinstance(data, (int, float)):
            if self.dtype in mstype.int_type and isinstance(data, float):
                raise_type_error(mstype.float_)
            data = Tensor(data, self.dtype)
        # both not init.
        incoming_tensor_is_init = isinstance(data, Tensor) and not data.has_init
        current_tensor_is_init = isinstance(self, Tensor) and not self.has_init

        if incoming_tensor_is_init and not current_tensor_is_init:
            raise TypeError("The original tensor data is initialized, but the argument 'data' is not "
                            "initialized. Please initialize 'data' before calling this method.")
        if tuple(self.shape) != tuple(data.shape):
            # The shape is allowed to change when the Parameter is created by slicing.
            if not slice_shape:
                raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                                 f" Current shape is {self.shape}, and incoming is {data.shape}.")
        if self.dtype != data.dtype:
            if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
                raise_type_error(data.dtype)
            else:
                from mindspore.ops import functional as F
                data = F.cast(data, self.dtype)
        if isinstance(data, Tensor) and data.has_init:
            # The parameter has been initialized, directly update by the data.
            if current_tensor_is_init:
                self._update_tensor_data(data.init_data())
            else:
                # also update the related inited parameter data.
                if self.inited_param is not None:
                    self.inited_param.set_data(data)
                self.init_mode = data
        elif incoming_tensor_is_init or current_tensor_is_init:
            self._update_tensor_data(data)
        self.sliced = slice_shape
        return self

    def init_data(self, layout=None, set_sliced=False):
        """
        Initialize the parameter's data.

        Args:
            layout (Union[None, tuple(list(int))]): The parameter slice
                layout [dev_mat, tensor_map, slice_shape, ...]. Default: None.

                - dev_mat (list(int)): Device matrix.
                - tensor_map (list(int)): Tensor map.
                - slice_shape (list(int)): Shape of slice.

            set_sliced (bool): True if the parameter is set sliced after initializing the data.
                Default: False.

        Raises:
            RuntimeError: If the parameter comes from an Initializer and the parallel mode has changed
                after the Initializer was created.
            ValueError: If the length of the layout is less than 6.
            TypeError: If `layout` is not a tuple.

        Returns:
            Parameter, the `Parameter` after initializing data. If the current `Parameter` was already
            initialized before, returns the same initialized `Parameter`.
        """
        if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
            raise RuntimeError("Must set or change parallel mode before any Tensor created.")
        if self.init_mode is None:
            return self
        if self.inited_param is not None:
            return self.inited_param
        if _is_role_worker() and self.cache_enable:
            global_seed, op_seed = _get_global_and_op_seed()
            _insert_weight_init_info(self.name, global_seed, op_seed)

        init_data_args = ()
        if layout is not None:
            if not isinstance(layout, tuple):
                raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout)))
            if len(layout) < 6:
                raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout)))
            slice_index = int(_get_slice_index(layout[0], layout[1]))
            init_data_args += (slice_index, layout[2], layout[5])

        if _is_role_pserver():
            return self

        if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
                self.init_mode.init is not None and (_is_role_worker() or _is_role_sched()):
            data = self.init_mode.init_data(0, [1])
        else:
            data = self.init_mode.init_data(*init_data_args)

        obj = self._update_tensor_data(data)
        if id(obj) != id(self):
            self._inited_param = obj
        obj.init_mode = None
        obj.sliced = set_sliced
        return obj
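

# Illustrative sketch (not upstream code) of deferred initialization with `init_data`,
# using only names imported at the top of this module; the shape and the 'normal'
# initializer alias are arbitrary examples.
#
#     w = Parameter(initializer('normal', [64, 32], mstype.float32), name="fc.weight")
#     w = w.init_data()   # materializes the data (or the slice of it under auto parallel)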


class ParameterTuple(tuple):
    """
    Class for storing a tuple of parameters.

    Note:
        It is used to store the parameters of the network into the parameter tuple collection.
    """
    def __new__(cls, iterable):
        """Create an instance object of ParameterTuple."""
        data = tuple(iterable)
        ids = set()
        orders = {}
        for x in data:
            if not isinstance(x, Parameter):
                raise TypeError(f"ParameterTuple input should be a `Parameter` collection. "
                                f"But got a {type(iterable)}, {iterable}")
            if id(x) not in ids:
                ids.add(id(x))
                if x.name not in orders.keys():
                    orders[x.name] = [0, x]
                else:
                    if isinstance(orders[x.name], list):
                        name = x.name
                        orders[name][1].name = name + "_" + str(0)
                        x.name = x.name + "_" + str(1)
                        orders[name] = 1
                    else:
                        orders[x.name] += 1
                        x.name = x.name + "_" + str(orders[x.name])
        return tuple.__new__(ParameterTuple, tuple(data))

    def clone(self, prefix, init='same'):
        """
        Clone the parameters in the ParameterTuple element-wise to generate a new ParameterTuple.

        Args:
            prefix (str): Namespace of the parameters.
            init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameters.
                The definition of `init` is the same as in the `Parameter` API. If `init` is 'same', the
                parameters in the new parameter tuple are the same as those in the original parameter tuple.
                Default: 'same'.

        Raises:
            RuntimeError: If cache is enabled for a cloned parameter whose name does not end with
                "embedding_table".

        Returns:
            Tuple, the new Parameter tuple.
        """
        Validator.check_str_by_regular(prefix)
        new = []
        for x in self:
            x1 = x.clone(init)
            x1.name = prefix + "." + x1.name
            new.append(x1)

            if not x1.cache_enable:
                continue
            if not x1.name.endswith("embedding_table"):
                raise RuntimeError("Can not enable cache for parameter '{}', only parameters of "
                                   "sparse operators support enabling cache.".format(x1.name))

            if _is_role_worker():
                _clone_hash_table(x.name, x1.name)
                _insert_accumu_init_info(x1.name, init_to_value(init))
        return ParameterTuple(new)

    def __parameter_tuple__(self):
        """For parse check."""
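

# Illustrative sketch (not upstream code): optimizers typically clone a ParameterTuple
# to hold per-parameter state. `net` is a hypothetical mindspore.nn.Cell.
#
#     params = ParameterTuple(net.trainable_params())
#     moments = params.clone(prefix="moments", init='zeros')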