• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""
17Note:
18  SPONGE operators. This is an experimental interface that is subject to change and/or deletion.
19"""
20
21import math
22
23from ..primitive import PrimitiveWithInfer, prim_attr_register
24from ..._checkparam import Rel
25from ..._checkparam import Validator as validator
26from ...common import dtype as mstype
27
28
class BondForce(PrimitiveWithInfer):
    """
    Calculate the force exerted by the simple harmonic bond on the corresponding atoms.
    Assume the number of harmonic bonds is m and the number of atoms is n.

    Because there are many inputs and they are all interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr = (x_1-x_2, y_1-y_2, z_1-z_2)

    .. math::

        F = (F_x, F_y, F_z) = 2*k*(1 - r_0/|dr|)*dr

    Args:
        bond_numbers (int): the number of harmonic bonds m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
          between the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The first atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The second atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **bond_k** (Tensor) - The force constant of each bond.
          The data type is float32 and the shape is :math:`(m,)`.
        - **bond_r0** (Tensor) - The equilibrium length of each bond.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, bond_numbers, atom_numbers):
        """Initialize BondForce."""
        validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.bond_numbers = bond_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('bond_numbers', self.bond_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
                                outputs=['frc_f'])

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
        """Validate input ranks/sizes against n and m; the output has the shape of `uint_crd_f`, i.e. (n, 3)."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.bond_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
        validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        # Label previously said "uint_crd_f_shape", which misreported which input was wrong.
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
        validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
        return uint_crd_f_shape

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
        """Validate input dtypes; the output dtype is that of `bond_r0` (checked to be float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
        return bond_r0_type
112
113
class BondEnergy(PrimitiveWithInfer):
    """
    Calculate the harmonic potential energy between each bonded atom pair.
    Assume our system has n atoms and m harmonic bonds.

    Because there are many inputs and they are all interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr = (x_1-x_2, y_1-y_2, z_1-z_2)

    .. math::

        E = k*(|dr| - r_0)^2

    Args:
        bond_numbers (int): the number of harmonic bonds m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
          between the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The first atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The second atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **bond_k** (Tensor) - The force constant of each bond.
          The data type is float32 and the shape is :math:`(m,)`.
        - **bond_r0** (Tensor) - The equilibrium length of each bond.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **bond_ene** (Tensor) - The harmonic potential energy for each bond.
          The data type is float32 and the shape is :math:`(m,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, bond_numbers, atom_numbers):
        """Initialize BondEnergy."""
        validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.bond_numbers = bond_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('bond_numbers', self.bond_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
                                outputs=['bond_ene'])

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
        """Validate input ranks/sizes against n and m; the output has the shape of `bond_k`, i.e. (m,)."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.bond_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
        validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        # Label previously said "uint_crd_f_shape", which misreported which input was wrong.
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
        validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
        return bond_k_shape

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
        """Validate input dtypes; the output dtype is that of `bond_r0` (checked to be float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
        return bond_r0_type
198
199
class BondAtomEnergy(PrimitiveWithInfer):
    """
    Add the potential energy caused by simple harmonic bonds to the total
    potential energy of each atom.

    The calculation formula is the same as operator BondEnergy().

    Because there are many inputs and they are all interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        bond_numbers (int): the number of harmonic bonds m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
          between the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The first atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The second atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **bond_k** (Tensor) - The force constant of each bond.
          The data type is float32 and the shape is :math:`(m,)`.
        - **bond_r0** (Tensor) - The equilibrium length of each bond.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **atom_ene** (Tensor) - The accumulated potential energy for each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, bond_numbers, atom_numbers):
        """Initialize BondAtomEnergy."""
        validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.bond_numbers = bond_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('bond_numbers', self.bond_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
                                outputs=['atom_ene'])

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
        """Validate input ranks/sizes against n and m; the output is one value per atom, shape (n,)."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.bond_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
        validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        # Label previously said "uint_crd_f_shape", which misreported which input was wrong.
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
        validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
        return [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
        """Validate input dtypes; the output dtype is that of `bond_r0` (checked to be float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
        return bond_r0_type
277
278
class BondForceWithAtomEnergy(PrimitiveWithInfer):
    """
    Calculate bond force and harmonic potential energy together.

    The calculation formula is the same as operator BondForce() and BondEnergy().

    Because there are many inputs and they are all interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        bond_numbers (int): the number of harmonic bonds m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
          between the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The first atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The second atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **bond_k** (Tensor) - The force constant of each bond.
          The data type is float32 and the shape is :math:`(m,)`.
        - **bond_r0** (Tensor) - The equilibrium length of each bond.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **atom_e** (Tensor) - The accumulated potential energy for each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, bond_numbers, atom_numbers):
        """Initialize BondForceWithAtomEnergy."""
        validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.bond_numbers = bond_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('bond_numbers', self.bond_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
                                outputs=['frc_f', 'atom_e'])

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
        """Validate input ranks/sizes against n and m; outputs are shaped like `uint_crd_f` (n, 3) and (n,)."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.bond_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
        validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        # Label previously said "uint_crd_f_shape", which misreported which input was wrong.
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
        validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
        return uint_crd_f_shape, [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
        """Validate input dtypes; both outputs take the dtype of `bond_r0` (checked to be float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
        return bond_r0_type, bond_r0_type
360
361
class BondForceWithAtomVirial(PrimitiveWithInfer):
    """
    Calculate bond force and the virial coefficient caused by simple harmonic
    bond for each atom together.

    The calculation formula of the force part is the same as operator BondForce().

    Because there are many inputs and they are all interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    The Virial part is as follows:

    .. math::

        dr = (x_1-x_2, y_1-y_2, z_1-z_2)

    .. math::

        virial = |dr|*(|dr| - r_0)*k

    Args:
        bond_numbers (int): the number of harmonic bonds m.
        atom_numbers (int): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor (x, y, z),
          between the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The first atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The second atom index of each bond.
          The data type is int32 and the shape is :math:`(m,)`.
        - **bond_k** (Tensor) - The force constant of each bond.
          The data type is float32 and the shape is :math:`(m,)`.
        - **bond_r0** (Tensor) - The equilibrium length of each bond.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - Same as operator BondForce().
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **atom_v** (Tensor) - The accumulated virial coefficient for each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, bond_numbers, atom_numbers):
        """Initialize BondForceWithAtomVirial."""
        validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.bond_numbers = bond_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('bond_numbers', self.bond_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'],
                                outputs=['frc_f', 'atom_v'])

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape):
        """Validate input ranks/sizes against n and m; outputs are shaped like `uint_crd_f` (n, 3) and (n,)."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.bond_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
        validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        # Label previously said "uint_crd_f_shape", which misreported which input was wrong.
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(bond_k_shape[0], m, Rel.EQ, "bond_k_shape", cls_name)
        validator.check_int(bond_r0_shape[0], m, Rel.EQ, "bond_r0_shape", cls_name)
        return uint_crd_f_shape, [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
        """Validate input dtypes; both outputs take the dtype of `bond_r0` (checked to be float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('bond_k', bond_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('bond_r0', bond_r0_type, [mstype.float32], self.name)
        return bond_r0_type, bond_r0_type
454
455
class DihedralForce(PrimitiveWithInfer):
    """
    Calculate the force that dihedral terms (each defined by 4 atoms) exert
    on the corresponding atoms. Assume the number of dihedral terms is m and
    the number of atoms is n.

    Because there are many inputs and they are all interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        dihedral_numbers (int): the number of dihedral terms m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinates
          value of each atom. The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_d** (Tensor) - The 4th atom index of each dihedral.
          4 atoms are connected in the form a-b-c-d.
          The data type is int32 and the shape is :math:`(m,)`.
        - **ipn** (Tensor) - The period of dihedral angle of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **pk** (Tensor) - The force constant of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gamc** (Tensor) - k*cos(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gams** (Tensor) - k*sin(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **pn** (Tensor) - The floating point form of ipn.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, dihedral_numbers):
        """Initialize DihedralForce."""
        validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
        self.dihedral_numbers = dihedral_numbers
        io_inputs = ['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d',
                     'ipn', 'pk', 'gamc', 'gams', 'pn']
        self.init_prim_io_names(inputs=io_inputs, outputs=['frc_f'])
        self.add_prim_attr('dihedral_numbers', self.dihedral_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape,
                    ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
        """Check that every per-dihedral input is 1-D of length m; the output is shaped like `uint_crd_f`."""
        prim_name = self.name
        term_count = self.dihedral_numbers
        # (rank label, length label, shape) for every input that holds one entry per dihedral term,
        # listed in the order the checks are performed.
        per_term_inputs = (
            ("atom_a_dim", "atom_a_shape", atom_a_shape),
            ("atom_b_dim", "atom_b_shape", atom_b_shape),
            ("atom_c_dim", "atom_c_shape", atom_c_shape),
            ("atom_d_dim", "atom_d_shape", atom_d_shape),
            ("ipn_dim", "ipn_shape", ipn_shape),
            ("pk_dim", "pk_shape", pk_shape),
            ("gamc_dim", "gamc_shape", gamc_shape),
            ("gams_dim", "gams_shape", gams_shape),
            ("pn_dim", "pn_shape", pn_shape),
        )

        # Rank checks first.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", prim_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", prim_name)
        for rank_label, _, shape in per_term_inputs:
            validator.check_int(len(shape), 1, Rel.EQ, rank_label, prim_name)

        # Then size checks. The atom count n is not an attribute of this primitive,
        # so only the coordinate width (3) is verified for uint_crd_f.
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", prim_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", prim_name)
        for _, length_label, shape in per_term_inputs:
            validator.check_int(shape[0], term_count, Rel.EQ, length_label, prim_name)
        return uint_crd_f_shape

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
                    ipn_type, pk_type, gamc_type, gams_type, pn_type):
        """Check each input's dtype; the output dtype is that of `pn` (checked to be float32)."""
        # (argument name, actual dtype, accepted dtype), checked in this order.
        dtype_table = (
            ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
            ('scaler_f', scaler_f_type, mstype.float32),
            ('atom_a', atom_a_type, mstype.int32),
            ('atom_b', atom_b_type, mstype.int32),
            ('atom_c', atom_c_type, mstype.int32),
            ('atom_d', atom_d_type, mstype.int32),
            ('ipn', ipn_type, mstype.int32),
            ('pk', pk_type, mstype.float32),
            ('gamc', gamc_type, mstype.float32),
            ('gams', gams_type, mstype.float32),
            ('pn', pn_type, mstype.float32),
        )
        for arg_name, actual_dtype, accepted_dtype in dtype_table:
            validator.check_tensor_dtype_valid(arg_name, actual_dtype, [accepted_dtype], self.name)
        return pn_type
557
558
class DihedralEnergy(PrimitiveWithInfer):
    """
    Calculate the potential energy caused by dihedral terms for each 4-atom group.
    Assume our system has n atoms and m dihedral terms.

    Because there are a large number of inputs and they are interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        dihedral_numbers(int): the number of dihedral terms m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinates
          value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_d** (Tensor) - The 4th atom index of each dihedral.
          4 atoms are connected in the form a-b-c-d.
          The data type is int32 and the shape is :math:`(m,)`.
        - **ipn** (Tensor) - The period of dihedral angle of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **pk** (Tensor) - The force constant of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gamc** (Tensor) - k*cos(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gams** (Tensor) - k*sin(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **pn** (Tensor) - The floating point form of ipn.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The potential energy for each
          dihedral term. The data type is float32 and the shape is :math:`(m,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, dihedral_numbers):
        """Initialize DihedralEnergy."""
        validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
        self.dihedral_numbers = dihedral_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
                                        'gamc', 'gams', 'pn'],
                                outputs=['ene'])
        self.add_prim_attr('dihedral_numbers', self.dihedral_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape,
                    ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
        """Validate input ranks and sizes; the output holds one energy per dihedral, shape ``[m]``."""
        cls_name = self.name
        m = self.dihedral_numbers
        # Every per-dihedral input must be a 1-D tensor with exactly m entries.
        per_dihedral = (("atom_a", atom_a_shape), ("atom_b", atom_b_shape), ("atom_c", atom_c_shape),
                        ("atom_d", atom_d_shape), ("ipn", ipn_shape), ("pk", pk_shape),
                        ("gamc", gamc_shape), ("gams", gams_shape), ("pn", pn_shape))

        # Rank checks first, so the indexing below is safe.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        for arg_name, shape in per_dihedral:
            validator.check_int(len(shape), 1, Rel.EQ, arg_name + "_dim", cls_name)

        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        for arg_name, shape in per_dihedral:
            validator.check_int(shape[0], m, Rel.EQ, arg_name + "_shape", cls_name)
        return [m,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
                    ipn_type, pk_type, gamc_type, gams_type, pn_type):
        """Validate input element types; the output dtype is that of ``pn`` (float32)."""
        # Checked in signature order so the first reported mismatch is stable.
        expected = (('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
                    ('scaler_f', scaler_f_type, mstype.float32),
                    ('atom_a', atom_a_type, mstype.int32),
                    ('atom_b', atom_b_type, mstype.int32),
                    ('atom_c', atom_c_type, mstype.int32),
                    ('atom_d', atom_d_type, mstype.int32),
                    ('ipn', ipn_type, mstype.int32),
                    ('pk', pk_type, mstype.float32),
                    ('gamc', gamc_type, mstype.float32),
                    ('gams', gams_type, mstype.float32),
                    ('pn', pn_type, mstype.float32))
        for arg_name, arg_dtype, allowed in expected:
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [allowed], self.name)

        return pn_type
660
661
class DihedralAtomEnergy(PrimitiveWithInfer):
    """
    Add the potential energy caused by dihedral terms to the total potential
    energy of each atom.

    Because there are a large number of inputs and they are interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    The calculation formula is the same as operator DihedralEnergy().

    Args:
        dihedral_numbers(int): the number of dihedral terms m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinates
          value of each atom. The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_d** (Tensor) - The 4th atom index of each dihedral.
          4 atoms are connected in the form a-b-c-d. The data type is int32 and the shape is :math:`(m,)`.
        - **ipn** (Tensor) - The period of dihedral angle of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **pk** (Tensor) - The force constant of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gamc** (Tensor) - k*cos(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gams** (Tensor) - k*sin(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **pn** (Tensor) - The floating point form of ipn.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The accumulated potential
          energy for each atom. The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, dihedral_numbers):
        """Initialize DihedralAtomEnergy."""
        validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
        self.dihedral_numbers = dihedral_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
                                        'gamc', 'gams', 'pn'],
                                outputs=['ene'])
        self.add_prim_attr('dihedral_numbers', self.dihedral_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape,
                    ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
        """Validate input ranks and sizes; the output holds one energy per atom, shape ``[n]``."""
        cls_name = self.name
        m = self.dihedral_numbers
        # Every per-dihedral input must be a 1-D tensor with exactly m entries.
        per_dihedral = (("atom_a", atom_a_shape), ("atom_b", atom_b_shape), ("atom_c", atom_c_shape),
                        ("atom_d", atom_d_shape), ("ipn", ipn_shape), ("pk", pk_shape),
                        ("gamc", gamc_shape), ("gams", gams_shape), ("pn", pn_shape))

        # Rank checks come before any indexing so a malformed shape reports a
        # validator error instead of a bare IndexError.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        for arg_name, shape in per_dihedral:
            validator.check_int(len(shape), 1, Rel.EQ, arg_name + "_dim", cls_name)

        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        for arg_name, shape in per_dihedral:
            validator.check_int(shape[0], m, Rel.EQ, arg_name + "_shape", cls_name)

        n = uint_crd_f_shape[0]  # number of atoms; safe now that the rank is known to be 2
        return [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
                    ipn_type, pk_type, gamc_type, gams_type, pn_type):
        """Validate input element types; the output dtype is that of ``pn`` (float32)."""
        # Checked in signature order so the first reported mismatch is stable.
        expected = (('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
                    ('scaler_f', scaler_f_type, mstype.float32),
                    ('atom_a', atom_a_type, mstype.int32),
                    ('atom_b', atom_b_type, mstype.int32),
                    ('atom_c', atom_c_type, mstype.int32),
                    ('atom_d', atom_d_type, mstype.int32),
                    ('ipn', ipn_type, mstype.int32),
                    ('pk', pk_type, mstype.float32),
                    ('gamc', gamc_type, mstype.float32),
                    ('gams', gams_type, mstype.float32),
                    ('pn', pn_type, mstype.float32))
        for arg_name, arg_dtype, allowed in expected:
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [allowed], self.name)

        return pn_type
764
765
class DihedralForceWithAtomEnergy(PrimitiveWithInfer):
    """
    Calculate dihedral force and potential energy together.

    The calculation formula is the same as operator DihedralForce() and DihedralEnergy().

    Because there are a large number of inputs and they are interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        dihedral_numbers(int): the number of dihedral terms m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinates
          value of each atom. The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_d** (Tensor) - The 4th atom index of each dihedral.
          4 atoms are connected in the form a-b-c-d. The data type is int32 and the shape is :math:`(m,)`.
        - **ipn** (Tensor) - The period of dihedral angle of each dihedral.
          The data type is int32 and the shape is :math:`(m,)`.
        - **pk** (Tensor) - The force constant of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gamc** (Tensor) - k*cos(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **gams** (Tensor) - k*sin(phi_0) of each dihedral.
          The data type is float32 and the shape is :math:`(m,)`.
        - **pn** (Tensor) - The floating point form of ipn.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - Same as operator DihedralForce().
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **ene** (Tensor) - Same as operator DihedralAtomEnergy().
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, dihedral_numbers):
        """Initialize DihedralForceWithAtomEnergy."""
        validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
        self.dihedral_numbers = dihedral_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
                                        'gamc', 'gams', 'pn'],
                                outputs=['frc_f', 'ene'])
        self.add_prim_attr('dihedral_numbers', self.dihedral_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape,
                    ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
        """Validate input ranks and sizes.

        Returns:
            The force shape (same as ``uint_crd_f``, i.e. ``(n, 3)``) and the
            per-atom energy shape ``[n]``.
        """
        cls_name = self.name
        m = self.dihedral_numbers
        # Every per-dihedral input must be a 1-D tensor with exactly m entries.
        per_dihedral = (("atom_a", atom_a_shape), ("atom_b", atom_b_shape), ("atom_c", atom_c_shape),
                        ("atom_d", atom_d_shape), ("ipn", ipn_shape), ("pk", pk_shape),
                        ("gamc", gamc_shape), ("gams", gams_shape), ("pn", pn_shape))

        # Rank checks come before any indexing so a malformed shape reports a
        # validator error instead of a bare IndexError.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        for arg_name, shape in per_dihedral:
            validator.check_int(len(shape), 1, Rel.EQ, arg_name + "_dim", cls_name)

        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        for arg_name, shape in per_dihedral:
            validator.check_int(shape[0], m, Rel.EQ, arg_name + "_shape", cls_name)

        n = uint_crd_f_shape[0]  # number of atoms; safe now that the rank is known to be 2
        return uint_crd_f_shape, [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
                    ipn_type, pk_type, gamc_type, gams_type, pn_type):
        """Validate input element types; both outputs take the dtype of ``pn`` (float32)."""
        # Checked in signature order so the first reported mismatch is stable.
        expected = (('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
                    ('scaler_f', scaler_f_type, mstype.float32),
                    ('atom_a', atom_a_type, mstype.int32),
                    ('atom_b', atom_b_type, mstype.int32),
                    ('atom_c', atom_c_type, mstype.int32),
                    ('atom_d', atom_d_type, mstype.int32),
                    ('ipn', ipn_type, mstype.int32),
                    ('pk', pk_type, mstype.float32),
                    ('gamc', gamc_type, mstype.float32),
                    ('gams', gams_type, mstype.float32),
                    ('pn', pn_type, mstype.float32))
        for arg_name, arg_dtype, allowed in expected:
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [allowed], self.name)

        return pn_type, pn_type
869
870
class AngleForce(PrimitiveWithInfer):
    r"""
    Calculate the force exerted by angles made of 3 atoms on the
    corresponding atoms. Assume the number of angles is m and the
    number of atoms is n.

    Because there are a large number of inputs and they are interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr_{ab} = (x_b-x_a, y_b-y_a, z_b-z_a)

    .. math::

        dr_{cb} = (x_b-x_c, y_b-y_c, z_b-z_c)

    .. math::

        \theta = \arccos\left(\frac{dr_{ab} \cdot dr_{cb}}{|dr_{ab}| |dr_{cb}|}\right)

    .. math::

        F_a = -2k(\theta-\theta_0)\frac{1}{\sin\theta}
            \left[\frac{\cos\theta}{|dr_{ab}|^2} dr_{ab} - \frac{1}{|dr_{ab}| |dr_{cb}|} dr_{cb}\right]

    .. math::

        F_c = -2k(\theta-\theta_0)\frac{1}{\sin\theta}
            \left[\frac{\cos\theta}{|dr_{cb}|^2} dr_{cb} - \frac{1}{|dr_{cb}| |dr_{ab}|} dr_{ab}\right]

    .. math::

        F_b = -F_a - F_c

    Args:
        angle_numbers(int): the number of angles m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd and the central atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **angle_k** (Tensor) - The force constant for each angle.
          The data type is float32 and the shape is :math:`(m,)`.
        - **angle_theta0** (Tensor) - The equilibrium position value for each angle.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, angle_numbers):
        """Initialize AngleForce."""
        validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
        self.angle_numbers = angle_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
                                        'angle_theta0'],
                                outputs=['frc_f'])
        self.add_prim_attr('angle_numbers', self.angle_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape,
                    angle_theta0_shape):
        """Validate input ranks and sizes; the force output has the shape of ``uint_crd_f``."""
        cls_name = self.name
        m = self.angle_numbers
        # Every per-angle input must be a 1-D tensor with exactly m entries.
        per_angle = (("atom_a", atom_a_shape), ("atom_b", atom_b_shape), ("atom_c", atom_c_shape),
                     ("angle_k", angle_k_shape), ("angle_theta0", angle_theta0_shape))

        # Rank checks first, so the indexing below is safe.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        for arg_name, shape in per_angle:
            validator.check_int(len(shape), 1, Rel.EQ, arg_name + "_dim", cls_name)

        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        for arg_name, shape in per_angle:
            validator.check_int(shape[0], m, Rel.EQ, arg_name + "_shape", cls_name)
        return uint_crd_f_shape

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
                    angle_theta0_type):
        """Validate input element types; the output dtype is that of ``angle_k`` (float32)."""
        # Checked in signature order so the first reported mismatch is stable.
        expected = (('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
                    ('scaler_f', scaler_f_type, mstype.float32),
                    ('atom_a', atom_a_type, mstype.int32),
                    ('atom_b', atom_b_type, mstype.int32),
                    ('atom_c', atom_c_type, mstype.int32),
                    ('angle_k', angle_k_type, mstype.float32),
                    ('angle_theta0', angle_theta0_type, mstype.float32))
        for arg_name, arg_dtype, allowed in expected:
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [allowed], self.name)
        return angle_k_type
965
966
class AngleEnergy(PrimitiveWithInfer):
    r"""
    Calculate the energy caused by 3-atoms angle term. Assume the number of angles is m and the
    number of atoms is n.

    Because there are a large number of inputs and they are interrelated,
    there is no way to construct `Examples` using random methods. For details,
    refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr_{ab} = (x_b-x_a, y_b-y_a, z_b-z_a)

    .. math::

        dr_{cb} = (x_b-x_c, y_b-y_c, z_b-z_c)

    .. math::

        \theta = \arccos\left(\frac{dr_{ab} \cdot dr_{cb}}{|dr_{ab}| |dr_{cb}|}\right)

    .. math::

        E = k(\theta - \theta_0)^2

    Args:
        angle_numbers(int): the number of angles m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd and the central atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **angle_k** (Tensor) - The force constant for each angle.
          The data type is float32 and the shape is :math:`(m,)`.
        - **angle_theta0** (Tensor) - The equilibrium position value for each angle.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The potential energy for each angle term.
          The data type is float32 and the shape is :math:`(m,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, angle_numbers):
        """Initialize AngleEnergy."""
        validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
        self.angle_numbers = angle_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
                                        'angle_theta0'],
                                outputs=['ene'])
        self.add_prim_attr('angle_numbers', self.angle_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape,
                    angle_theta0_shape):
        """Validate input ranks and sizes; the output holds one energy per angle, shape ``[m]``."""
        cls_name = self.name
        m = self.angle_numbers
        # Every per-angle input must be a 1-D tensor with exactly m entries.
        per_angle = (("atom_a", atom_a_shape), ("atom_b", atom_b_shape), ("atom_c", atom_c_shape),
                     ("angle_k", angle_k_shape), ("angle_theta0", angle_theta0_shape))

        # Rank checks first, so the indexing below is safe.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        for arg_name, shape in per_angle:
            validator.check_int(len(shape), 1, Rel.EQ, arg_name + "_dim", cls_name)

        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        for arg_name, shape in per_angle:
            validator.check_int(shape[0], m, Rel.EQ, arg_name + "_shape", cls_name)
        return [m,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
                    angle_theta0_type):
        """Validate input element types; the output dtype is that of ``angle_k`` (float32)."""
        # Checked in signature order so the first reported mismatch is stable.
        expected = (('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
                    ('scaler_f', scaler_f_type, mstype.float32),
                    ('atom_a', atom_a_type, mstype.int32),
                    ('atom_b', atom_b_type, mstype.int32),
                    ('atom_c', atom_c_type, mstype.int32),
                    ('angle_k', angle_k_type, mstype.float32),
                    ('angle_theta0', angle_theta0_type, mstype.float32))
        for arg_name, arg_dtype, allowed in expected:
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [allowed], self.name)
        return angle_k_type
1054
1055
class AngleAtomEnergy(PrimitiveWithInfer):
    """
    Add the potential energy caused by angle terms to the total potential
    energy of each atom. Assume the number of angles is m and the
    number of atoms is n.

    The calculation formula is the same as operator AngleEnergy().

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        angle_numbers(int32): the number of angles m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd and the central atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **angle_k** (Tensor) - The force constant for each angle.
          The data type is float32 and the shape is :math:`(m,)`.
        - **angle_theta0** (Tensor) - The equilibrium position value for each angle.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The accumulated potential energy for each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, angle_numbers):
        """Initialize AngleAtomEnergy."""
        validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
        self.angle_numbers = angle_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
                                        'angle_theta0'],
                                outputs=['ene'])
        self.add_prim_attr('angle_numbers', self.angle_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape,
                    angle_theta0_shape):
        """
        Validate the ranks and sizes of all inputs and return the output shape ``[n]``.

        ``n`` is taken from the first dimension of ``uint_crd_f``; ``m`` is the
        ``angle_numbers`` attribute given at construction.
        """
        cls_name = self.name
        m = self.angle_numbers
        # Rank checks come first so that every shape index below is safe; a wrongly
        # ranked input is then reported by the validator rather than as an IndexError.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
        validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
        validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)

        n = uint_crd_f_shape[0]
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        # All per-angle inputs must have exactly m entries.
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(atom_c_shape[0], m, Rel.EQ, "atom_c_shape", cls_name)
        validator.check_int(angle_k_shape[0], m, Rel.EQ, "angle_k_shape", cls_name)
        validator.check_int(angle_theta0_shape[0], m, Rel.EQ, "angle_theta0_shape", cls_name)
        return [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
                    angle_theta0_type):
        """Check every input dtype and return the output dtype (that of ``angle_k``, i.e. float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_c', atom_c_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('angle_k', angle_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('angle_theta0', angle_theta0_type, [mstype.float32], self.name)
        return angle_k_type
1138
1139
class AngleForceWithAtomEnergy(PrimitiveWithInfer):
    """
    Calculate angle force and potential energy together. Assume the number of angles is m and the
    number of atoms is n.

    The calculation formula is the same as operator AngleForce() and AngleEnergy().

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        angle_numbers(int32): the number of angles m.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler_f** (Tensor) - The 3-D scale factor between
          the real space float coordinates and the unsigned int coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **atom_a** (Tensor) - The 1st atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_b** (Tensor) - The 2nd and the central atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_c** (Tensor) - The 3rd atom index of each angle.
          The data type is int32 and the shape is :math:`(m,)`.
        - **angle_k** (Tensor) - The force constant for each angle.
          The data type is float32 and the shape is :math:`(m,)`.
        - **angle_theta0** (Tensor) - The equilibrium position value for each angle.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **frc_f** (Tensor) - same as operator AngleForce().
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **ene** (Tensor) - same as operator AngleAtomEnergy().
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, angle_numbers):
        """Initialize AngleForceWithAtomEnergy."""
        validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
        self.angle_numbers = angle_numbers
        self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
                                        'angle_theta0'],
                                outputs=['frc_f', 'ene'])
        self.add_prim_attr('angle_numbers', self.angle_numbers)

    def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape,
                    angle_theta0_shape):
        """
        Validate the ranks and sizes of all inputs and return the output shapes.

        Returns ``(uint_crd_f_shape, [n])`` for ``frc_f`` and ``ene`` respectively,
        where ``n`` is the first dimension of ``uint_crd_f``.
        """
        cls_name = self.name
        m = self.angle_numbers
        # Rank checks come first so that every shape index below is safe; a wrongly
        # ranked input is then reported by the validator rather than as an IndexError.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
        validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
        validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
        validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
        validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
        validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)

        n = uint_crd_f_shape[0]
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
        # All per-angle inputs must have exactly m entries.
        validator.check_int(atom_a_shape[0], m, Rel.EQ, "atom_a_shape", cls_name)
        validator.check_int(atom_b_shape[0], m, Rel.EQ, "atom_b_shape", cls_name)
        validator.check_int(atom_c_shape[0], m, Rel.EQ, "atom_c_shape", cls_name)
        validator.check_int(angle_k_shape[0], m, Rel.EQ, "angle_k_shape", cls_name)
        validator.check_int(angle_theta0_shape[0], m, Rel.EQ, "angle_theta0_shape", cls_name)
        # The force has the same (n, 3) layout as the coordinates.
        return uint_crd_f_shape, [n,]

    def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
                    angle_theta0_type):
        """Check every input dtype; both outputs take the dtype of ``angle_k`` (float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('scaler_f', scaler_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_a', atom_a_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_b', atom_b_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_c', atom_c_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('angle_k', angle_k_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('angle_theta0', angle_theta0_type, [mstype.float32], self.name)
        return angle_k_type, angle_k_type
1223
1224
class Dihedral14LJForce(PrimitiveWithInfer):
    """
    Calculate the Lennard-Jones part of 1,4 dihedral force correction
    for each necessary dihedral term on the corresponding atoms.

    Assume the number of necessary dihedral 1,4 terms is m, the number of atoms is n,
    and the number of Lennard-Jones types for all atoms is P, which means
    there will be q = P*(P+1)/2 types of possible Lennard-Jones interactions
    for all kinds of atom pairs.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::
        dr = (x_a-x_b, y_a-y_b, z_a-z_b)
    .. math::
        F = k*(-12*A/|dr|^{14} + 6*B/|dr|^{8})*dr

    Args:
        nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **lj_scale_factor** (Tensor) - The scale factor for the
          Lennard-Jones part of force correction of each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14LJForce."""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers
        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
                    'LJ_type_A', 'LJ_type_B'],
            outputs=['frc_f'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
        """
        Validate the ranks and sizes of all inputs and return the force shape.

        ``n`` and ``m`` come from the construction attributes; ``q`` is taken from
        ``LJ_type_A`` and ``LJ_type_B`` must match it.
        """
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        # Rank checks come first so that every shape index below is safe. LJ_type_A is
        # checked too, since q is read from it.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
        validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

        q = lj_type_a_shape[0]
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
        # Both LJ parameter tables must describe the same q pair types.
        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
        # The force has the same (n, 3) layout as the coordinates.
        return uint_crd_f_shape

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
        """Check every input dtype and return the output dtype (that of ``LJ_type_B``, i.e. float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)

        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)

        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
        return lj_type_b_type
1331
1332
class Dihedral14LJEnergy(PrimitiveWithInfer):
    """
    Calculate the Lennard-Jones part of 1,4 dihedral energy correction for
    each necessary dihedral term on the corresponding atoms.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::
        dr = (x_a-x_b, y_a-y_b, z_a-z_b)
    .. math::
        E = k*(A/|dr|^{12} - B/|dr|^{6})

    Args:
        nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **lj_scale_factor** (Tensor) - The scale factor for the
          Lennard-Jones part of force correction of each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **ene** (Tensor) - The Lennard-Jones potential energy correction.
          The data type is float32 and the shape is :math:`(m,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14LJEnergy"""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
                    'LJ_type_A', 'LJ_type_B'],
            outputs=['ene'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
        """
        Validate the ranks and sizes of all inputs and return the output shape ``[m]``.

        ``n`` and ``m`` come from the construction attributes; ``q`` is taken from
        ``LJ_type_A`` and ``LJ_type_B`` must match it.
        """
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        # Rank checks come first so that every shape index below is safe. LJ_type_A is
        # checked too, since q is read from it.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
        validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

        q = lj_type_a_shape[0]
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
        # Both LJ parameter tables must describe the same q pair types.
        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
        # One energy value per dihedral 1,4 term.
        return [m,]

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
        """Check every input dtype and return the output dtype (that of ``LJ_type_A``, i.e. float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

        return lj_type_a_type
1434
1435
class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
    """
    Calculate the Lennard-Jones part and the Coulomb part of force correction
    for each necessary dihedral 1,4 term.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    The calculation formula of the Lennard-Jones part is the same as operator
    Dihedral14LJForce(), and the Coulomb part is as follows:

    .. math::
        dr = (x_a-x_b, y_a-y_b, z_a-z_b)
    .. math::
        F = -k*q_a*q_b/|dr|^3*dr

    Args:
        nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **lj_scale_factor** (Tensor) - The scale factor for the
          Lennard-Jones part of force correction of each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **cf_scale_factor** (Tensor) - The scale factor for the
          Coulomb part of force correction for each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **frc_f** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14LJForceWithDirectCF."""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
                    'cf_scale_factor',
                    'LJ_type_A', 'LJ_type_B'],
            outputs=['frc_f'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
        """
        Validate the ranks and sizes of all inputs and return the force shape ``[n, 3]``.

        ``n`` and ``m`` come from the construction attributes; ``q`` is taken from
        ``LJ_type_A`` and ``LJ_type_B`` must match it.
        """
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        # Rank checks come first so that every shape index below is safe. LJ_type_A is
        # checked too, since q is read from it.
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
        validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
        validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

        q = lj_type_a_shape[0]
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        # Both LJ parameter tables must describe the same q pair types.
        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
        validator.check_int(cf_scale_factor_shape[0], m, Rel.EQ, "cf_scale_factor_shape", cls_name)
        return [n, 3]

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
        """Check every input dtype and return the output dtype (that of ``LJ_type_A``, i.e. float32)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

        return lj_type_a_type
1547
1548
1549class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
1550    """
1551    Calculate the Lennard-Jones and Coulumb energy correction and force correction
1552    for each necessary dihedral 1,4 terms together and add them to the total force
1553    and potential energy for each atom.
1554
1555    The calculation formula of force correction is the same as operator
1556    :class:`Dihedral14LJForceWithDirectCF`, and the energy correction part is the same
1557    as operator :class:`Dihedral14LJEnergy` and :class:`Dihedral14CFEnergy`.
1558
1559    Because there is a large amount of inputs and each of them are related,
1560    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
1561    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
1562
1563    Args:
1564        nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
1565        atom_numbers (int32): the number of atoms n.
1566
1567    Inputs:
1568        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
1569          The data type is uint32 and the shape is :math:`(n, 3)`.
1570        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
1571          The data type is int32 and the shape is :math:`(n,)`.
1572        - **charge** (Tensor) - The charge of each atom.
1573          The data type is float32 and the shape is :math:`(n,)`.
1574        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
1575          The data type is float32 and the shape is :math:`(3,)`.
1576        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
1577          The data type is int32 and the shape is :math:`(m,)`.
1578        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
1579          The data type is int32 and the shape is :math:`(m,)`.
1580        - **lj_scale_factor** (Tensor) - The scale factor for the
1581          Lennard-Jones part of force correction of each dihedral 1,4 term.
1582          The data type is float32 and the shape is :math:`(m,)`.
1583        - **cf_scale_factor** (Tensor) - The scale factor for the
1584          Coulomb part of force correction for each dihedral 1,4 terms.
1585          The data type is float32 and the shape is :math:`(m,)`.
1586        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
1587          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
1588        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones shceme of each atom pair type.
1589          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
1590
1591    Outputs:
1592        - **frc_f** (Tensor) - The force felt by each atom.
1593          The data type is float32 and the shape is :math:`(n, 3)`.
1594        - **atom_energy** (Tensor) - The accumulated potential energy for each atom.
1595          The data type is float32 and the shape is :math:`(n,)`.
1596
1597    Supported Platforms:
1598        ``GPU``
1599    """
1600
1601    @prim_attr_register
1602    def __init__(self, nb14_numbers, atom_numbers):
1603        """Initialize Dihedral14LJCFForceWithAtomEnergy."""
1604        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
1605        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
1606        self.dihedral_14_numbers = nb14_numbers
1607        self.atom_numbers = atom_numbers
1608
1609        self.init_prim_io_names(
1610            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
1611                    'cf_scale_factor',
1612                    'LJ_type_A', 'LJ_type_B'],
1613            outputs=['frc_f', 'atom_energy'])
1614        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
1615        self.add_prim_attr('atom_numbers', self.atom_numbers)
1616
1617    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
1618                    lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
1619        cls_name = self.name
1620        n = self.atom_numbers
1621        m = self.dihedral_14_numbers
1622        q = lj_type_a_shape[0]
1623        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
1624        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
1625        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
1626        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
1627        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
1628        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
1629        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
1630        validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
1631        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
1632
1633        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
1634        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
1635        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
1636        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
1637        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
1638        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
1639        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
1640        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
1641        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
1642        validator.check_int(cf_scale_factor_shape[0], m, Rel.EQ, "cf_scale_factor_shape", cls_name)
1643        return uint_crd_f_shape, charge_shape
1644
1645    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
1646                    lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
1647        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
1648        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
1649        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
1650        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
1651        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
1652        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
1653        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
1654        validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
1655        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
1656        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
1657
1658        return charge_dtype, charge_dtype
1659
1660
class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
    """
    Add the potential energy caused by Lennard-Jones energy correction for each
    necessary dihedral 1,4 term to the total potential energy of each atom.

    The calculation formula is the same as operator :class:`Dihedral14LJEnergy`.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **lj_scale_factor** (Tensor) - The scale factor for the
          Lennard-Jones part of force correction of each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.
        - **LJ_type_A** (Tensor) - The A parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair types. The data type is float32 and the shape is :math:`(q,)`.
        - **LJ_type_B** (Tensor) - The B parameter in Lennard-Jones scheme of each atom pair type.
          q is the number of atom pair types. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **ene** (Tensor) - The accumulated potential energy of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14LJAtomEnergy."""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'lj_scale_factor',
                    'LJ_type_A', 'LJ_type_B'],
            outputs=['ene'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
        """Validate the input shapes and return the shape of ``ene``, :math:`(n,)`."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
        validator.check_int(len(lj_type_a_shape), 1, Rel.EQ, "LJ_type_A_dim", cls_name)
        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

        # Read q only after LJ_type_A is known to be 1-D, so bad input fails validation
        # rather than raising IndexError.
        q = lj_type_a_shape[0]
        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
        return ljtype_shape

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
        """Check every input dtype and return the dtype of ``ene`` (that of ``LJ_type_A``)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

        return lj_type_a_type
1763
1764
class Dihedral14CFEnergy(PrimitiveWithInfer):
    """
    Calculate the Coulomb part of 1,4 dihedral energy correction for
    each necessary dihedral 1,4 term on the corresponding atoms.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr = (x_a-x_b, y_a-y_b, z_a-z_b)

    .. math::

        E = k*q_a*q_b/|dr|

    Args:
        nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **cf_scale_factor** (Tensor) - The scale factor for the
          Coulomb part of force correction for each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The Coulomb part of the 1,4 energy correction of each
          dihedral 1,4 term. The data type is float32 and the shape is :math:`(m,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14CFEnergy."""
        validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        # The last input is the Coulomb scale factor; its name matches the docstring,
        # infer_shape/infer_dtype, and Dihedral14CFAtomEnergy.
        self.init_prim_io_names(
            inputs=['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'cf_scale_factor'],
            outputs=['ene'])
        self.add_prim_attr('dihedral_14_numbers', self.dihedral_14_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    cf_scale_factor_shape):
        """Validate the input shapes and return the shape of ``ene``, :math:`(m,)`."""
        cls_name = self.name
        n = self.atom_numbers
        m = self.dihedral_14_numbers
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
        validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
        validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
        validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)

        validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
        validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
        validator.check_int(cf_scale_factor_shape[0], m, Rel.EQ, "cf_scale_factor_shape", cls_name)
        # One energy value per dihedral 1,4 term.
        return [m]

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    cf_scale_factor_type):
        """Check every input dtype and return the dtype of ``ene`` (that of ``charge``)."""
        validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('boxlength_f', boxlength_f_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
        # Report dtype errors under the input's real name (previously mislabelled 'lj_scale_factor').
        validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32],
                                           self.name)

        return charge_dtype
1859
1860
class Dihedral14CFAtomEnergy(PrimitiveWithInfer):
    """
    Add the potential energy caused by Coulomb energy correction for each
    necessary dihedral 1,4 term to the total potential energy of each atom.

    The calculation formula is the same as operator :class:`Dihedral14CFEnergy`.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        nb14_numbers (int32): the number of necessary dihedral 1,4 terms m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **uint_crd_f** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJ_type** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **boxlength_f** (Tensor) - The length of molecular simulation box in 3 dimensions.
          The data type is float32 and the shape is :math:`(3,)`.
        - **a_14** (Tensor) - The first atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **b_14** (Tensor) - The second atom index of each dihedral 1,4 term.
          The data type is int32 and the shape is :math:`(m,)`.
        - **cf_scale_factor** (Tensor) - The scale factor for the
          Coulomb part of force correction for each dihedral 1,4 term.
          The data type is float32 and the shape is :math:`(m,)`.

    Outputs:
        - **ene** (Tensor) - The accumulated potential energy of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, nb14_numbers, atom_numbers):
        """Initialize Dihedral14CFAtomEnergy."""
        for arg_name, arg_value in (('nb14_numbers', nb14_numbers), ('atom_numbers', atom_numbers)):
            validator.check_value_type(arg_name, arg_value, int, self.name)
        self.dihedral_14_numbers = nb14_numbers
        self.atom_numbers = atom_numbers

        input_names = ['uint_crd_f', 'LJtype', 'charge', 'boxlength_f', 'a_14', 'b_14', 'cf_scale_factor']
        self.init_prim_io_names(inputs=input_names, outputs=['ene'])
        for attr_name, attr_value in (('dihedral_14_numbers', nb14_numbers), ('atom_numbers', atom_numbers)):
            self.add_prim_attr(attr_name, attr_value)

    def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
                    cf_scale_factor_shape):
        """Validate input ranks and extents; the output ``ene`` has the per-atom shape :math:`(n,)`."""
        prim_name = self.name
        n_atoms = self.atom_numbers
        n_terms = self.dihedral_14_numbers

        # Ranks are checked first so that the indexing below is always in bounds.
        expected_ranks = (
            (uint_crd_f_shape, 2, "uint_crd_f_dim"),
            (ljtype_shape, 1, "LJtype_dim"),
            (charge_shape, 1, "charge_dim"),
            (boxlength_f_shape, 1, "boxlength_f_dim"),
            (a_14_shape, 1, "a_14_dim"),
            (b_14_shape, 1, "b_14_dim"),
            (cf_scale_factor_shape, 1, "cf_scale_factor_dim"),
        )
        for shape, rank, label in expected_ranks:
            validator.check_int(len(shape), rank, Rel.EQ, label, prim_name)

        expected_extents = (
            (uint_crd_f_shape[0], n_atoms, "uint_crd_f_shape[0]"),
            (uint_crd_f_shape[1], 3, "uint_crd_f_shape[1]"),
            (ljtype_shape[0], n_atoms, "LJtype_shape"),
            (charge_shape[0], n_atoms, "charge_shape"),
            (boxlength_f_shape[0], 3, "boxlength_f_shape"),
            (a_14_shape[0], n_terms, "a_14_shape"),
            (b_14_shape[0], n_terms, "b_14_shape"),
            (cf_scale_factor_shape[0], n_terms, "cf_scale_factor_shape"),
        )
        for extent, wanted, label in expected_extents:
            validator.check_int(extent, wanted, Rel.EQ, label, prim_name)
        return ljtype_shape

    def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
                    cf_scale_factor_type):
        """Check every input dtype; ``ene`` takes the dtype of ``charge``."""
        dtype_table = (
            ('uint_crd_f', uint_crd_f_dtype, mstype.uint32),
            ('LJtype', ljtype_dtype, mstype.int32),
            ('charge', charge_dtype, mstype.float32),
            ('boxlength_f', boxlength_f_type, mstype.float32),
            ('a_14', a_14_type, mstype.int32),
            ('b_14', b_14_type, mstype.int32),
            ('cf_scale_factor', cf_scale_factor_type, mstype.float32),
        )
        for arg_name, arg_dtype, accepted in dtype_table:
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [accepted], self.name)
        return charge_dtype
1951
1952
class PMEReciprocalForce(PrimitiveWithInfer):
    """
    Calculate the reciprocal part of long-range Coulomb force using
    PME (Particle Mesh Ewald) method. Assume the number of atoms is n.

    The detailed calculation formula of PME (Particle Mesh Ewald) method
    can be found in this paper: A Smooth Particle Mesh Ewald Method. DOI:
    10.1063/1.470117.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        atom_numbers(int32): the number of atoms, n.
        beta(float32): the PME beta parameter, determined by the
          non-bond cutoff value and simulation precision tolerance.
        fftx(int32): the number of points for Fourier transform in dimension X.
        ffty(int32): the number of points for Fourier transform in dimension Y.
        fftz(int32): the number of points for Fourier transform in dimension Z.
        box_length_0(float32): the value of boxlength idx 0.
        box_length_1(float32): the value of boxlength idx 1.
        box_length_2(float32): the value of boxlength idx 2.

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Outputs:
        - **force** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1, box_length_2):
        """Initialize PMEReciprocalForce."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('beta', beta, float, self.name)
        validator.check_value_type('fftx', fftx, int, self.name)
        validator.check_value_type('ffty', ffty, int, self.name)
        validator.check_value_type('fftz', fftz, int, self.name)
        validator.check_value_type('box_length_0', box_length_0, float, self.name)
        validator.check_value_type('box_length_1', box_length_1, float, self.name)
        validator.check_value_type('box_length_2', box_length_2, float, self.name)
        self.atom_numbers = atom_numbers
        self.beta = beta
        self.fftx = fftx
        self.ffty = ffty
        self.fftz = fftz
        self.box_length_0 = box_length_0
        self.box_length_1 = box_length_1
        self.box_length_2 = box_length_2

        # The box lengths arrive as attributes (box_length_0/1/2), so the primitive has only
        # two tensor inputs. This matches infer_shape/infer_dtype; a stale 'boxlength' input
        # name used to be declared here as well.
        self.init_prim_io_names(inputs=['uint_crd', 'charge'],
                                outputs=['force'])
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('beta', self.beta)
        self.add_prim_attr('fftx', self.fftx)
        self.add_prim_attr('ffty', self.ffty)
        self.add_prim_attr('fftz', self.fftz)
        self.add_prim_attr('box_length_0', self.box_length_0)
        self.add_prim_attr('box_length_1', self.box_length_1)
        self.add_prim_attr('box_length_2', self.box_length_2)

    def infer_shape(self, uint_crd_shape, charge_shape):
        """Validate the input shapes; ``force`` has the shape of ``uint_crd``, :math:`(n, 3)`."""
        cls_name = self.name
        n = self.atom_numbers
        validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)

        validator.check_int(uint_crd_shape[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
        validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
        return uint_crd_shape

    def infer_dtype(self, uint_crd_type, charge_type):
        """Check the input dtypes; ``force`` takes the dtype of ``charge``."""
        validator.check_tensor_dtype_valid('uint_crd', uint_crd_type, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_type, [mstype.float32], self.name)
        return charge_type
2037
2038
class PMEExcludedForce(PrimitiveWithInfer):
    """
    Calculate the excluded part of long-range Coulomb force using
    PME (Particle Mesh Ewald) method. Assume the number of atoms is
    n, and the length of excluded list is E.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        atom_numbers(int32): the number of atoms, n.
        excluded_numbers(int32): the length of excluded list, E.
        beta(float32): the PME beta parameter, determined by the
          non-bond cutoff value and simulation precision tolerance.

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **scaler** (Tensor) - The scale factor between real space
          coordinates and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`.
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **excluded_list_start** (Tensor) - The start excluded index
          in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`.
        - **excluded_list** (Tensor) - The contiguous join of excluded
          list of each atom. E is the number of excluded atoms. The data type is int32 and the shape is :math:`(E,)`.
        - **excluded_atom_numbers** (Tensor) - The number of atom excluded
          in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`.

    Outputs:
        - **force** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, excluded_numbers, beta):
        """Initialize PMEExcludedForce."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
        validator.check_value_type('beta', beta, float, self.name)
        self.atom_numbers = atom_numbers
        self.excluded_numbers = excluded_numbers
        self.beta = beta
        # Input name spelled 'scaler' to match the documented input (was misspelled 'sacler').
        self.init_prim_io_names(
            inputs=['uint_crd', 'scaler', 'charge', 'excluded_list_start', 'excluded_list', 'excluded_atom_numbers'],
            outputs=['force'])
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('excluded_numbers', self.excluded_numbers)
        self.add_prim_attr('beta', self.beta)

    def infer_shape(self, uint_crd_shape, scaler_shape, charge_shape, excluded_list_start_shape, excluded_list_shape,
                    excluded_atom_numbers_shape):
        """Validate the input shapes; ``force`` has the shape of ``uint_crd``, :math:`(n, 3)`."""
        cls_name = self.name
        n = self.atom_numbers
        validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
        validator.check_int(len(scaler_shape), 1, Rel.EQ, "scaler_dim", cls_name)
        validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(excluded_list_start_shape), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
        # excluded_list is documented as 1-D; it was previously not validated at all.
        validator.check_int(len(excluded_list_shape), 1, Rel.EQ, "excluded_list_dim", cls_name)
        validator.check_int(len(excluded_atom_numbers_shape), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)

        validator.check_int(uint_crd_shape[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
        validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
        validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
        validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(excluded_list_start_shape[0], n, Rel.EQ, "excluded_list_start_shape", cls_name)
        validator.check_int(excluded_atom_numbers_shape[0], n, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
        return uint_crd_shape

    def infer_dtype(self, uint_crd_type, scaler_type, charge_type, excluded_list_start_type, excluded_list_type,
                    excluded_atom_numbers_type):
        """Check every input dtype; ``force`` takes the dtype of ``charge``."""
        validator.check_tensor_dtype_valid('scaler', scaler_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('uint_crd', uint_crd_type, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('charge', charge_type, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('excluded_list_start', excluded_list_start_type, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('excluded_list', excluded_list_type, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('excluded_atom_numbers', excluded_atom_numbers_type, [mstype.int32],
                                           self.name)
        return charge_type
2123
2124
class PMEEnergy(PrimitiveWithInfer):
    """
    Calculate the Coulomb energy of the system using the PME method.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        E = sum_{ij} q_iq_j/r_{ij}

    Args:
        atom_numbers(int32): the number of atoms, n.
        excluded_numbers(int32): the length of excluded list, E.
        beta(float32): the PME beta parameter, determined by the
                       non-bond cutoff value and simulation precision tolerance.
        fftx(int32): the number of points for Fourier transform in dimension X.
        ffty(int32): the number of points for Fourier transform in dimension Y.
        fftz(int32): the number of points for Fourier transform in dimension Z.
        box_length_0(float32): the box length in dimension X (index 0).
        box_length_1(float32): the box length in dimension Y (index 1).
        box_length_2(float32): the box length in dimension Z (index 2).

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **nl_numbers** - (Tensor) - The number of neighbor-list entries of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **nl_serial** - (Tensor) - The neighbor list of each atom, the max number is 800.
          The data type is int32 and the shape is :math:`(n, 800)`.
        - **scaler** (Tensor) - The scale factor between real space
          coordinates and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`.
        - **excluded_list_start** (Tensor) - The start excluded index
          in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`.
        - **excluded_list** (Tensor) - The contiguous join of excluded
          list of each atom. E is the number of excluded atoms. The data type is int32 and the shape is :math:`(E,)`.
        - **excluded_atom_numbers** (Tensor) - The number of atom excluded
          in excluded list for each atom. The data type is int32 and the shape is :math:`(n,)`.

    Outputs:
        - **reciprocal_ene** (Tensor) - The reciprocal term of PME energy.
          The data type is float32 and the shape is :math:`(1,)`.
        - **self_ene** (Tensor) - The self term of PME energy.
          The data type is float32 and the shape is :math:`(1,)`.
        - **direct_ene** (Tensor) - The direct term of PME energy.
          The data type is float32 and the shape is :math:`(1,)`.
        - **correction_ene** (Tensor) - The correction term of PME energy.
          The data type is float32 and the shape is :math:`(1,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1,
                 box_length_2):
        """Initialize PMEEnergy."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
        validator.check_value_type('beta', beta, float, self.name)
        validator.check_value_type('fftx', fftx, int, self.name)
        validator.check_value_type('ffty', ffty, int, self.name)
        validator.check_value_type('fftz', fftz, int, self.name)
        validator.check_value_type('box_length_0', box_length_0, float, self.name)
        validator.check_value_type('box_length_1', box_length_1, float, self.name)
        validator.check_value_type('box_length_2', box_length_2, float, self.name)
        self.atom_numbers = atom_numbers
        self.excluded_numbers = excluded_numbers
        self.beta = beta
        self.fftx = fftx
        self.ffty = ffty
        self.fftz = fftz
        self.box_length_0 = box_length_0
        self.box_length_1 = box_length_1
        self.box_length_2 = box_length_2
        # The box lengths are attributes, not tensor inputs. The input names must match
        # the eight inputs taken by infer_shape/infer_dtype one-to-one.
        self.init_prim_io_names(
            inputs=['uint_crd', 'charge', 'nl_numbers', 'nl_serial', 'scaler', 'excluded_list_start',
                    'excluded_list', 'excluded_atom_numbers'],
            outputs=['reciprocal_ene', 'self_ene', 'direct_ene', 'correction_ene'])
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('excluded_numbers', self.excluded_numbers)
        self.add_prim_attr('beta', self.beta)
        self.add_prim_attr('fftx', self.fftx)
        self.add_prim_attr('ffty', self.ffty)
        self.add_prim_attr('fftz', self.fftz)
        self.add_prim_attr('box_length_0', self.box_length_0)
        self.add_prim_attr('box_length_1', self.box_length_1)
        self.add_prim_attr('box_length_2', self.box_length_2)

    def infer_shape(self, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
                    excluded_list, excluded_atom_numbers):
        """Validate input shapes; every output is a single-element tensor of shape (1,)."""
        cls_name = self.name
        n = self.atom_numbers
        # Check ranks first so the per-dimension indexing below cannot raise IndexError.
        validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
        validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
        # nl_serial[1] is indexed below, so the rank must be exactly 2 (was Rel.LE).
        validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
        validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
        validator.check_int(len(excluded_list_start), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
        validator.check_int(len(excluded_atom_numbers), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
        validator.check_int(len(excluded_list), 1, Rel.GE, "excluded_list", cls_name)

        validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
        validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
        validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape[0]", cls_name)
        # One neighbor-list row per atom, as documented: shape (n, <=800).
        validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
        validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
        validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
        validator.check_int(excluded_list_start[0], n, Rel.EQ, "excluded_list_start_shape", cls_name)
        validator.check_int(excluded_atom_numbers[0], n, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
        validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list_shape", cls_name)
        return (1,), (1,), (1,), (1,)

    def infer_dtype(self, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
                    excluded_list, excluded_atom_numbers):
        """Validate input dtypes; all four outputs share the dtype of ``charge`` (float32)."""
        validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('excluded_list_start', excluded_list_start, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('excluded_list', excluded_list, [mstype.int32],
                                           self.name)
        validator.check_tensor_dtype_valid('excluded_atom_numbers', excluded_atom_numbers, [mstype.int32],
                                           self.name)
        return charge, charge, charge, charge
2255
2256
class LJEnergy(PrimitiveWithInfer):
    """
    Calculate the Van der Waals interaction energy described by Lennard-Jones
    potential for each atom. Assume the number of atoms is n, and the number
    of Lennard-Jones types for all atoms is P, which means there will be
    q = P*(P+1)/2 types of possible Lennard-Jones interactions for all kinds
    of atom pairs.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr = (x_a-x_b, y_a-y_b, z_a-z_b)

    .. math::

        E = A/|dr|^{12} - B/|dr|^{6}

    Args:
        atom_numbers(int32): the number of atoms, n.
        cutoff_square(float32): the square value of cutoff.

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **scaler** (Tensor) - The scale factor between real
          space coordinate and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`.
        - **nl_numbers** - (Tensor) - The number of neighbor-list entries of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **nl_serial** - (Tensor) - The neighbor list of each atom, the max number is 800.
          The data type is int32 and the shape is :math:`(n, 800)`.
        - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **d_LJ_energy_atom** (Tensor) - The Lennard-Jones potential energy of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, cutoff_square):
        """Initialize LJEnergy."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
        self.atom_numbers = atom_numbers
        self.cutoff_square = cutoff_square
        self.init_prim_io_names(
            inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B'],
            outputs=['d_LJ_energy_atom'])
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('cutoff_square', self.cutoff_square)

    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        """Validate input shapes; the per-atom energy output has the shape of ``charge``, (n,)."""
        cls_name = self.name
        n = self.atom_numbers
        # Check every rank before indexing any dimension, so a malformed input
        # (e.g. a 0-D d_LJ_A) is reported by the validator instead of raising IndexError.
        validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
        validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
        validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
        validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
        validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
        validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

        validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
        validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
        validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
        validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
        validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
        validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
        # q, the number of atom-pair types, is defined by the length of d_LJ_A; d_LJ_B must agree.
        q = d_lj_a[0]
        validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
        return charge

    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        """Validate input dtypes; the output shares the dtype of ``charge`` (float32)."""
        validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
        return charge
2356
2357
class LJForce(PrimitiveWithInfer):
    """
    Calculate the Van der Waals interaction force described by Lennard-Jones
    potential energy for each atom.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    .. math::

        dr = (x_a-x_b, y_a-y_b, z_a-z_b)

    .. math::

        F = (-12*A/|dr|^{14} + 6*B/|dr|^{8}) * dr

    Args:
        atom_numbers(int32): the number of atoms, n.
        cutoff_square(float32): the square value of cutoff.

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinates value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **scaler** (Tensor) - The scale factor between real space
          coordinates and its unsigned int value. The data type is float32 and the shape is :math:`(3,)`.
        - **nl_numbers** - (Tensor) - The number of neighbor-list entries of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **nl_serial** - (Tensor) - The neighbor list of each atom, the max number is 800.
          The data type is int32 and the shape is :math:`(n, 800)`.
        - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **frc** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, cutoff_square):
        """Initialize LJForce."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
        self.atom_numbers = atom_numbers
        self.cutoff_square = cutoff_square
        self.init_prim_io_names(
            inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B'],
            outputs=['frc'])
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('cutoff_square', self.cutoff_square)

    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        """Validate input shapes; the force output has the shape of ``uint_crd``, (n, 3)."""
        cls_name = self.name
        n = self.atom_numbers
        # Check every rank before indexing any dimension, so a malformed input
        # (e.g. a 0-D d_LJ_A) is reported by the validator instead of raising IndexError.
        validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
        validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
        validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
        validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
        validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
        validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

        validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
        validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
        validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
        validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
        validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
        validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
        # q, the number of atom-pair types, is defined by the length of d_LJ_A; d_LJ_B must agree.
        q = d_lj_a[0]
        validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
        return uint_crd

    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        """Validate input dtypes; the output shares the dtype of ``charge`` (float32)."""
        validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
        return charge
2453
2454
class LJForceWithPMEDirectForce(PrimitiveWithInfer):
    """
    Calculate the Lennard-Jones force and PME direct force together.

    The calculation formula of Lennard-Jones part is the same as operator
    LJForce(), and the PME direct part is within PME method.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        atom_numbers(int32): the number of atoms, n.
        cutoff(float32): the non-bond cutoff value.
        pme_beta(float32): PME beta parameter, same as operator PMEReciprocalForce().

    Inputs:
        - **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.
        - **LJtype** (Tensor) - The Lennard-Jones type of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **charge** (Tensor) - The charge carried by each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **scaler** (Tensor) - The scale factor between real
          space coordinate and its unsigned int value.
          The data type is float32 and the shape is :math:`(3,)`.
        - **nl_numbers** - (Tensor) - The number of neighbor-list entries of each atom.
          The data type is int32 and the shape is :math:`(n,)`.
        - **nl_serial** - (Tensor) - The neighbor list of each atom, the max number is 800.
          The data type is int32 and the shape is :math:`(n, 800)`.
        - **d_LJ_A** (Tensor) - The Lennard-Jones A coefficient of each kind of atom pair.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.
        - **d_LJ_B** (Tensor) - The Lennard-Jones B coefficient of each kind of atom pair.
          q is the number of atom pair. The data type is float32 and the shape is :math:`(q,)`.

    Outputs:
        - **frc** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, cutoff, pme_beta):
        """Initialize LJForceWithPMEDirectForce."""
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('cutoff', cutoff, float, self.name)
        validator.check_value_type('pme_beta', pme_beta, float, self.name)
        self.atom_numbers = atom_numbers
        # NOTE(review): an earlier version of this docstring described this argument as
        # ``cutoff_square``; confirm whether the GPU kernel expects the cutoff or its square.
        self.cutoff = cutoff
        self.pme_beta = pme_beta
        self.init_prim_io_names(
            inputs=['uint_crd', 'LJtype', 'charge', 'scaler', 'nl_numbers', 'nl_serial', 'd_LJ_A', 'd_LJ_B'],
            outputs=['frc'])
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('cutoff', self.cutoff)
        self.add_prim_attr('pme_beta', self.pme_beta)

    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        """Validate input shapes; the force output has the shape of ``uint_crd``, (n, 3)."""
        cls_name = self.name
        n = self.atom_numbers
        # Check every rank before indexing any dimension, so a malformed input
        # (e.g. a 0-D d_LJ_A) is reported by the validator instead of raising IndexError.
        validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
        validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
        validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
        validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
        validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
        validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
        validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
        validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

        validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
        validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
        validator.check_int(ljtype[0], n, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
        validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
        validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
        validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
        # q, the number of atom-pair types, is defined by the length of d_LJ_A; d_LJ_B must agree.
        q = d_lj_a[0]
        validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
        return uint_crd

    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        """Validate input dtypes; the output shares the dtype of ``charge`` (float32)."""
        validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
        validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
        return charge
2549
2550
class MDTemperature(PrimitiveWithInfer):
    """
    Compute the MD temperature.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        residue_numbers (int32): the number of residues m.
        atom_numbers (int32): the number of atoms n.

    Inputs:
        - **start** (Tensor) - The start atom index of each residue.
          The data type is int32 and the shape is :math:`(m,)`.
        - **end** (Tensor) - The end atom index of each residue.
          The data type is int32 and the shape is :math:`(m,)`.
        - **atom_vel_f** (Tensor) - The velocity of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **atom_mass** (Tensor) - The mass of each atom.
          The data type is float32 and the shape is :math:`(n,)`.

    Outputs:
        - **ek** (Tensor) - The temperature of each residue.
          The data type is float32 and the shape is :math:`(m,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, residue_numbers, atom_numbers):
        """Initialize MDTemperature."""
        validator.check_value_type('residue_numbers', residue_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.residue_numbers = residue_numbers
        self.atom_numbers = atom_numbers
        self.add_prim_attr('residue_numbers', self.residue_numbers)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(
            inputs=['start', 'end', 'atom_vel_f', 'atom_mass'],
            outputs=['ek'])

    def infer_shape(self, start_shape, end_shape, atom_vel_f_shape, atom_mass_shape):
        """Validate input shapes; the output holds one value per residue, shape (m,)."""
        cls_name = self.name
        # Local names follow the docstring: m residues, n atoms.
        m = self.residue_numbers
        n = self.atom_numbers
        validator.check_int(len(start_shape), 1, Rel.EQ, "start_dim", cls_name)
        validator.check_int(start_shape[0], m, Rel.EQ, "start_shape", cls_name)
        validator.check_int(len(end_shape), 1, Rel.EQ, "end_dim", cls_name)
        validator.check_int(end_shape[0], m, Rel.EQ, "end_shape", cls_name)
        # Check the rank before indexing atom_vel_f_shape[1] so a 1-D input cannot raise IndexError.
        validator.check_int(len(atom_vel_f_shape), 2, Rel.EQ, "atom_vel_f_dim", cls_name)
        validator.check_int(atom_vel_f_shape[0], n, Rel.EQ, "atom_vel_f_shape[0]", cls_name)
        validator.check_int(atom_vel_f_shape[1], 3, Rel.EQ, "atom_vel_f_shape[1]", cls_name)
        validator.check_int(len(atom_mass_shape), 1, Rel.EQ, "atom_mass_dim", cls_name)
        validator.check_int(atom_mass_shape[0], n, Rel.EQ, "atom_mass_shape", cls_name)
        return [m]

    def infer_dtype(self, start_dtype, end_dtype, atom_vel_f_dtype, atom_mass_dtype):
        """Validate input dtypes; the output shares the dtype of ``atom_mass`` (float32)."""
        validator.check_tensor_dtype_valid('start', start_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('end', end_dtype, [mstype.int32], self.name)
        validator.check_tensor_dtype_valid('atom_vel_f', atom_vel_f_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('atom_mass', atom_mass_dtype, [mstype.float32], self.name)
        return atom_mass_dtype
2614
2615
2616class MDIterationLeapFrogWithRF(PrimitiveWithInfer):
2617    """
2618    One step of classical leap frog algorithm to solve the finite difference
2619    Hamiltonian equations of motion for certain system, using Langevin dynamics
2620    with Liu's thermostat scheme. Assume the number of atoms is n and the target
2621    control temperature is T.
2622
2623    Detailed iteration formula can be found in this paper: A unified thermostat
2624    scheme for efficient configurational sampling for classical/quantum canonical
2625    ensembles via molecular dynamics. DOI: 10.1063/1.4991621.
2626
2627    Because there is a large amount of inputs and each of them are related,
2628    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
2629    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.
2630
2631    Inputs:
2632        - **float4_numbers** (Scalar) - total length to store random numbers.
2633          The data type is int32.
2634        - **atom_numbers** (Scalar) - The number of atoms n.
2635          The data type is int32.
2636        - **dt** (Scalar) - time step for finite difference. The data type is float32.
2637        - **half_dt** (Scalar) - half of time step for finite difference.
2638          The data type is float32.
2639        - **exp_gamma** (Scalar) - parameter in Liu's dynamic, equals
2640          exp(-gamma_ln * dt), where gamma_ln is the firction factor in Langvin
2641          dynamics. The data type is float32.
2642        - **max_velocity** (Scalar) - The upper limit of velocity, when the
2643          veclocity overflows, scale it to the upper limit. The data type is float32.
2644        - **is_max_velocity** (Scalar) - whether the max velocity control is
2645          open or not. The data type is int32.
2646        - **mass_inverse** (Tensor) - The inverse value of
2647          mass of each atom. The data type is float32 and the shape is :math:`(n,)`.
2648        - **sqrt_mass** (Tensor) - The inverse square root value
2649          of effect mass in Liu's dynamics of each atom.
2650          The data type is float32 and the shape is :math:`(n,)`.
2651        - **vel** (Tensor) - The velocity of each atom.
2652          The data type is float32 and the shape is :math:`(n, 3)`.
2653        - **crd** (Tensor) - The coordinate of each atom.
2654          The data type is float32 and the shape is :math:`(n, 3)`.
2655        - **frc** (Tensor) - The force felt by each atom.
2656          The data type is float32 and the shape is :math:`(n, 3)`.
2657        - **acc** (Tensor) - The acceleration of each atom.
2658          The data type is float32 and the shape is :math:`(n, 3)`.
        - **random_force** (Tensor) - The random forces.
2660          The data type is float32 and the shape is :math:`(n, 3)`.
2661
2662    Outputs:
        - **res** (Tensor) - The data type is float32 and the shape is :math:`(1,)`.
2664
2665    Supported Platforms:
2666        ``GPU``
2667    Examples:
2668    """
2669
2670    @prim_attr_register
2671    def __init__(self, float4_numbers, atom_numbers, half_dt, dt, exp_gamma, is_max_velocity, max_velocity):
2672        """Initialize MDIterationLeapFrogWithRF."""
2673        validator.check_value_type('float4_numbers', float4_numbers, int, self.name)
2674        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
2675        validator.check_value_type('half_dt', half_dt, float, self.name)
2676        validator.check_value_type('dt', dt, float, self.name)
2677        validator.check_value_type('exp_gamma', exp_gamma, float, self.name)
2678        validator.check_value_type('is_max_velocity', is_max_velocity, int, self.name)
2679        validator.check_value_type('max_velocity', max_velocity, float, self.name)
2680        self.float4_numbers = float4_numbers
2681        self.atom_numbers = atom_numbers
2682        self.half_dt = half_dt
2683        self.dt = dt
2684        self.exp_gamma = exp_gamma
2685        self.is_max_velocity = is_max_velocity
2686        self.max_velocity = max_velocity
2687
2688        self.init_prim_io_names(
2689            inputs=['mass_inverse', 'sqrt_mass', 'vel_in', 'crd_in', 'frc_in', 'acc_in', 'random_force'],
2690            outputs=['res'])
2691        self.add_prim_attr('float4_numbers', self.float4_numbers)
2692        self.add_prim_attr('atom_numbers', self.atom_numbers)
2693        self.add_prim_attr('half_dt', self.half_dt)
2694        self.add_prim_attr('dt', self.dt)
2695        self.add_prim_attr('exp_gamma', self.exp_gamma)
2696        self.add_prim_attr('is_max_velocity', self.is_max_velocity)
2697        self.add_prim_attr('max_velocity', self.max_velocity)
2698
2699    def infer_shape(self, mass_inverse_shape, sqrt_mass_shape, vel_in_shape, crd_in_shape, frc_in_shape, acc_in_shape,
2700                    random_force_shape):
2701        n = self.atom_numbers
2702        validator.check_int(len(mass_inverse_shape), 1, Rel.EQ, "mass_inverse", self.name)
2703        validator.check_int(len(sqrt_mass_shape), 1, Rel.EQ, "mass_inverse", self.name)
2704        validator.check_int(mass_inverse_shape[0], n, Rel.EQ, "mass_inverse", self.name)
2705        validator.check_int(sqrt_mass_shape[0], n, Rel.EQ, "mass_inverse", self.name)
2706        validator.check_int(vel_in_shape[0], n, Rel.EQ, "vel_in", self.name)
2707        validator.check_int(vel_in_shape[1], 3, Rel.EQ, "vel_in", self.name)
2708        validator.check_int(crd_in_shape[0], n, Rel.EQ, "crd_in", self.name)
2709        validator.check_int(crd_in_shape[1], 3, Rel.EQ, "crd_in", self.name)
2710        validator.check_int(frc_in_shape[0], n, Rel.EQ, "frc_in", self.name)
2711        validator.check_int(frc_in_shape[1], 3, Rel.EQ, "frc_in", self.name)
2712        validator.check_int(acc_in_shape[0], n, Rel.EQ, "acc_in", self.name)
2713        validator.check_int(acc_in_shape[1], 3, Rel.EQ, "acc_in", self.name)
2714        validator.check_int(random_force_shape[0], n, Rel.EQ, "random_force", self.name)
2715        validator.check_int(random_force_shape[1], 3, Rel.EQ, "random_force", self.name)
2716
2717        return [1,]
2718
2719    def infer_dtype(self, mass_inverse_dtype, sqrt_mass_dtype, vel_in_dtype, crd_in_dtype, frc_in_dtype, acc_in_dtype,
2720                    rf_dtype):
2721        validator.check_tensor_dtype_valid('mass_inverse', mass_inverse_dtype, [mstype.float32], self.name)
2722        validator.check_tensor_dtype_valid('sqrt_mass', sqrt_mass_dtype, [mstype.float32], self.name)
2723        validator.check_tensor_dtype_valid('vel_in', vel_in_dtype, [mstype.float32], self.name)
2724        validator.check_tensor_dtype_valid('crd_in', crd_in_dtype, [mstype.float32], self.name)
2725        validator.check_tensor_dtype_valid('frc_in', frc_in_dtype, [mstype.float32], self.name)
2726        validator.check_tensor_dtype_valid('acc_in', acc_in_dtype, [mstype.float32], self.name)
2727        validator.check_tensor_dtype_valid('rf', rf_dtype, [mstype.float32], self.name)
2728        return mstype.float32
2729
2730
class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
    """
    Advance the system by one step of the classical leap frog algorithm, solving the
    finite difference Hamiltonian equations of motion with Langevin dynamics under
    Liu's thermostat scheme. Assume the number of atoms is n and the target control
    temperature is T.

    The detailed iteration formula is given in the paper: A unified thermostat
    scheme for efficient configurational sampling for classical/quantum canonical
    ensembles via molecular dynamics. DOI: 10.1063/1.4991621.

    Because the inputs are numerous and interdependent, `Examples` cannot be
    constructed from random data. For details, refer to the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        atom_numbers(int32): the number of atoms n.
        dt(float32): time step for finite difference.
        half_dt(float32): half of time step for finite difference.
        exp_gamma(float32): parameter in Liu's dynamics.

    Inputs:
        - **inverse_mass** (Tensor) - The inverse of the mass of each atom.
          The data type is float32 and the shape is :math:`(n,)`.
        - **sqrt_mass_inverse** (Tensor) - The inverse square root of the effective
          mass of each atom in Liu's dynamics.
          The data type is float32 and the shape is :math:`(n,)`.
        - **vel** (Tensor) - The velocity of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **frc** (Tensor) - The force felt by each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **acc** (Tensor) - The acceleration of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.
        - **rand_state** (Tensor) - The random state used to generate the random force.
          The data type is float32 and the shape is :math:`(math.ceil(n * 3.0 / 4.0) * 16, )`.
        - **rand_frc** (Tensor) - The random forces.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Outputs:
        - **output** (Tensor) - The output coordinates.
          The data type is float32, and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, half_dt, dt, exp_gamma):
        """Initialize MDIterationLeapFrogLiujian."""
        attr_pairs = (
            ('atom_numbers', atom_numbers),
            ('half_dt', half_dt),
            ('dt', dt),
            ('exp_gamma', exp_gamma),
        )
        for attr_name, attr_value in attr_pairs:
            setattr(self, attr_name, attr_value)
        for attr_name, _ in attr_pairs:
            self.add_prim_attr(attr_name, getattr(self, attr_name))
        self.init_prim_io_names(
            inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
            outputs=['output'])

    def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
        atom_count = self.atom_numbers

        # Per-atom scalar inputs: rank first, then length.
        mass_shapes = (("inverse_mass", inverse_mass), ("sqrt_mass_inverse", sqrt_mass_inverse))
        for label, shape in mass_shapes:
            validator.check_int(len(shape), 1, Rel.EQ, label, self.name)
        for label, shape in mass_shapes:
            validator.check_int(shape[0], atom_count, Rel.EQ, label, self.name)

        # Rank checks for the remaining inputs, in the original order.
        rank_checks = (
            (rand_state, 1, "rand_state_dim"),
            (rand_frc, 2, "rand_frc_dim"),
            (vel, 2, "vel_dim"),
            (crd, 2, "crd_dim"),
            (frc, 2, "frc_dim"),
            (acc, 2, "acc_dim"),
        )
        for shape, rank, label in rank_checks:
            validator.check_int(len(shape), rank, Rel.EQ, label, self.name)

        # Every per-atom vector input must be (n, 3).
        xyz_checks = (
            (vel, "vel_shape[0]", "vel_shape[1]"),
            (crd, "crd_shape[0]", "crd_shape[1]"),
            (frc, "frc_shape[0]", "frc_shape[1]"),
            (acc, "acc_shape[0]", "acc_shape[1]"),
            (rand_frc, "rand_frc_shape[0]", "rand_frc_shape[1]"),
        )
        for shape, rows_label, cols_label in xyz_checks:
            validator.check_int(shape[0], atom_count, Rel.EQ, rows_label, self.name)
            validator.check_int(shape[1], 3, Rel.EQ, cols_label, self.name)

        # Same length rule as MDIterationSetupRandState's output: ceil(3n / 4) * 16.
        state_length = math.ceil(atom_count * 3 / 4.0) * 16
        validator.check_int(rand_state[0], state_length, Rel.EQ, "rand_state", self.name)
        return [atom_count, 3]

    def infer_dtype(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
        named_dtypes = (
            ('inverse_mass', inverse_mass),
            ('sqrt_mass_inverse', sqrt_mass_inverse),
            ('vel', vel),
            ('crd', crd),
            ('frc', frc),
            ('acc', acc),
            ('rand_frc', rand_frc),
            ('rand_state', rand_state),
        )
        # All inputs, including the random state, must be float32.
        for arg_name, arg_dtype in named_dtypes:
            validator.check_tensor_dtype_valid(arg_name, arg_dtype, [mstype.float32], self.name)
        return mstype.float32
2830
2831
class CrdToUintCrd(PrimitiveWithInfer):
    """
    Convert FP32 coordinate to Uint32 coordinate.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        atom_numbers(int32): the number of atoms n.

    Inputs:
        - **crd_to_uint_crd_cof** (Tensor) - The scale factor
          between the unsigned int value and the real space coordinates.
          The data type is float32 and the shape is :math:`(3,)`.
        - **crd** (Tensor) - The coordinate of each atom.
          The data type is float32 and the shape is :math:`(n, 3)`.

    Outputs:
        - **output** (Tensor) - The unsigned int coordinate of each atom.
          The data type is uint32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers):
        """Initialize CrdToUintCrd."""
        self.atom_numbers = atom_numbers
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(
            inputs=['crd_to_uint_crd_cof', 'crd'],
            outputs=['output'])

    def infer_shape(self, crd_to_uint_crd_cof, crd):
        """Validate the input shapes; the output has the same (n, 3) shape as `crd`."""
        # Check ranks before indexing so malformed shapes raise a validation error
        # rather than an IndexError (or slipping through with extra dimensions).
        validator.check_int(len(crd_to_uint_crd_cof), 1, Rel.EQ, "crd_to_uint_crd_cof_dim", self.name)
        validator.check_int(crd_to_uint_crd_cof[0], 3, Rel.EQ, "crd_to_uint_crd_cof_shape", self.name)
        validator.check_int(len(crd), 2, Rel.EQ, "crd_dim", self.name)
        validator.check_int(crd[0], self.atom_numbers, Rel.EQ, "crd[0]", self.name)
        validator.check_int(crd[1], 3, Rel.EQ, "crd[1]", self.name)
        return crd

    def infer_dtype(self, crd_to_uint_crd_cof, crd):
        """Both inputs must be float32; the converted coordinates are uint32."""
        validator.check_tensor_dtype_valid('crd_to_uint_crd_cof', crd_to_uint_crd_cof, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
        return mstype.uint32
2876
2877
class MDIterationSetupRandState(PrimitiveWithInfer):
    """
    Compute the random state of the iteration.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        atom_numbers(int32): the number of atoms n.
        seed(int32): random seed.

    Outputs:
        - **output** (Tensor) - The random state.
          The data type is float32 and the shape is :math:`(math.ceil(n * 3.0 / 4.0) * 16,)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, atom_numbers, seed):
        """Initialize MDIterationSetupRandState."""
        self.atom_numbers = atom_numbers
        self.seed = seed
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.add_prim_attr('seed', self.seed)
        self.init_prim_io_names(
            inputs=[],
            outputs=['output'])

    def infer_shape(self):
        """Return the 1-D state shape: ceil(3n / 4) groups of 16 float32 values."""
        float4_numbers = math.ceil(self.atom_numbers * 3 / 4.0)
        # 64 / 4 == 16 slots per group (presumably a 64-byte generator state held in
        # 4-byte float32 slots); integer division avoids a float round-trip.
        curandsize = 64 // 4
        return [float4_numbers * curandsize,]

    def infer_dtype(self):
        """The random state is stored as float32."""
        return mstype.float32
2916
2917
class TransferCrd(PrimitiveWithInfer):
    """
    Transfer the coordinates to angular and radial.

    Because there is a large amount of inputs and each of them are related,
    there is no way to construct `Examples` using random methods. For details, refer the webpage `SPONGE in MindSpore
    <https://gitee.com/mindspore/mindscience/blob/master/MindSPONGE/docs/simple_formula.md>`_.

    Args:
        start_serial(int32): the index start position.
        end_serial(int32): the index end position.
        number(int32): the length of angular and radial.
        atom_numbers(int32): the number of atoms n.

    Inputs:
        - **crd** (Tensor) - The coordinate of each atom.
          n is the number of atoms. The data type is float32 and the shape is :math:`(n, 3)`.
        - **old_crd** (Tensor) - The last coordinate of each atom.
          n is the number of atoms. The data type is float32 and the shape is :math:`(n, 3)`.
        - **box** (Tensor) - The length of 3 dimensions of the simulation box.
          The data type is float32 and the shape is :math:`(3,)`.

    Outputs:
        - **radial** (Tensor) - The array of radial transferred from coordinates.
          The data type is float32 and the shape is :math:`(number,)`.
        - **angular** (Tensor) - The array of angular transferred from coordinates.
          The data type is float32 and the shape is :math:`(number,)`.
        - **nowarp_crd** (Tensor) - The modified coordinate of each atom for
          computing radial and angular. The data type is float32 and the shape is :math:`(n, 3)`.
        - **box_map_times** (Tensor) - The box map times for radial and angular.
          The data type is int32 and the shape is :math:`(n, 3)`.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, start_serial, end_serial, number, atom_numbers):
        """Initialize TransferCrd; every argument must be a Python int."""
        validator.check_value_type('start_serial', start_serial, int, self.name)
        validator.check_value_type('end_serial', end_serial, int, self.name)
        validator.check_value_type('number', number, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.start_serial = start_serial
        self.end_serial = end_serial
        self.number = number
        self.atom_numbers = atom_numbers
        self.add_prim_attr('start_serial', self.start_serial)
        self.add_prim_attr('end_serial', self.end_serial)
        self.add_prim_attr('number', self.number)
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(
            inputs=['crd', 'old_crd', 'box'],
            outputs=['radial', 'angular', 'nowarp_crd', 'box_map_times'])

    def infer_shape(self, crd_shape, old_crd_shape, box_shape):
        """Validate (n, 3) coordinates and a (3,) box; return the four output shapes."""
        n = self.atom_numbers
        validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name)
        validator.check_int(crd_shape[0], n, Rel.EQ, "crd_shape[0]", self.name)
        validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", self.name)
        validator.check_int(len(old_crd_shape), 2, Rel.EQ, "old_crd_dim", self.name)
        validator.check_int(old_crd_shape[0], n, Rel.EQ, "old_crd_shape[0]", self.name)
        validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd_shape[1]", self.name)
        validator.check_int(len(box_shape), 1, Rel.EQ, "box_dim", self.name)
        validator.check_int(box_shape[0], 3, Rel.EQ, "box_shape[0]", self.name)
        return [self.number,], [self.number,], [n, 3], [n, 3]

    def infer_dtype(self, crd_dtype, old_crd_dtype, box_dtype):
        """All inputs are float32; box_map_times is int32, the other outputs float32."""
        validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name)
        return mstype.float32, mstype.float32, mstype.float32, mstype.int32
2989
2990
class FFT3D(PrimitiveWithInfer):
    """
    Forward FFT with Three-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Three dimensional tensor, supported
          data type is float32. Denote its shape as ``(X, Y, Z)``.

    Outputs:
        - **output_tensor** (Tensor) - The tensor after undergoing fast Fourier
          transform, the data type is complex64 and the shape is ``(X, Y, Z // 2 + 1)``.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize FFT3D."""
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Register the FFT sizes from the input shape and return the half-spectrum shape."""
        # Reject non-3-D inputs up front; otherwise input_shape[2] raises a bare IndexError.
        validator.check_int(len(input_shape), 3, Rel.EQ, "input_tensor_dim", self.name)
        # The FFT sizes are only known once the input shape is, so they are registered here.
        self.add_prim_attr('fftx', input_shape[0])
        self.add_prim_attr('ffty', input_shape[1])
        self.add_prim_attr('fftz', input_shape[2])
        # Real-to-complex transform: only Z // 2 + 1 entries of the last axis are kept.
        return [input_shape[0], input_shape[1], input_shape[2] // 2 + 1]

    def infer_dtype(self, input_dtype):
        """Input must be float32; the spectrum is complex64."""
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, [mstype.float32], self.name)
        return mstype.complex64
3025
3026
class IFFT3D(PrimitiveWithInfer):
    """
    Inverse FFT with Three-Dimensional Input.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **input_tensor** (Tensor) - Three dimensional input tensor, supported data
          type is complex64. Denote its shape as ``(X, Y, Z)``.

    Outputs:
        - **output_tensor** (Tensor) - Returns the tensor after undergoing
          inverse Fourier transform, the data type is float32 and the shape is
          ``(X, Y, 2 * (Z - 1))``.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IFFT3D."""
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        """Register the FFT sizes from the input shape and return the real output shape."""
        # Reject non-3-D inputs up front; otherwise input_shape[2] raises a bare IndexError.
        validator.check_int(len(input_shape), 3, Rel.EQ, "input_tensor_dim", self.name)
        self.add_prim_attr('fftx', input_shape[0])
        self.add_prim_attr('ffty', input_shape[1])
        self.add_prim_attr('fftz', input_shape[2])
        # Inverse of the half-spectrum rule; the reconstructed last axis is always even-length.
        return [input_shape[0], input_shape[1], (input_shape[2] - 1) * 2]

    def infer_dtype(self, input_dtype):
        """Input must be complex64; the reconstructed tensor is float32."""
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, [mstype.complex64], self.name)
        return mstype.float32
3061
3062
3063class NeighborListUpdate(PrimitiveWithInfer):
3064    """
3065    Update (or construct if first time) the Verlet neighbor list for the
3066    calculation of short-ranged force. Assume the number of atoms is N,
3067    the number of grids divided is G, the maximum number of atoms in one
3068    grid is m, the maximum number of atoms in single atom's neighbor list
3069    is L, and the number of total atom in excluded list is E.
3070
3071    Args:
3072        grid_numbers(int32): the total number of grids divided.
3073        not_first_time(int32): whether to construct the neighbor
3074          list first time or not.
3075        nxy(int32): the total number of grids divided in xy plane.
3076        excluded_atom_numbers(int32): the total atom numbers in the excluded list.
3077        cutoff(float32): the cutoff distance for short-range force calculation.
3078        skin(float32): the overflow value of cutoff to maintain a neighbor list.
3079        cutoff_square(float32): the suqare value of cutoff.
3080        half_skin_square(float32): skin*skin/4, indicates the maximum
3081          square value of the distance atom allowed to move between two updates.
3082        cutoff_with_skin(float32): cutoff + skin, indicates the
3083          radius of the neighbor list for each atom.
3084        half_cutoff_with_skin(float32): cutoff_with_skin/2.
3085        cutoff_with_skin_square(float32): the square value of cutoff_with_skin.
3086        refresh_interval(int32): the number of iteration steps between two updates of neighbor list.
3087        max_atom_in_grid_numbers(int32): the maximum number of atoms in one grid k.
3088
3089    Inputs:
3090        - **atom_numbers_in_grid_bucket** (Tensor) - The number of atoms in each grid bucket.
3091          The data type is int32 and the shape is :math:`(G,)`.
3092        - **bucket** (Tensor) - (Tensor) - The atom indices in each grid bucket.
3093          The data type is int32 and the shape is :math:`(G, k)`.
3094        - **crd** (Tensor) - The coordinates of each atom.
3095          The data type is float32 and the shape is :math:`(n, 3)`.
3096        - **box_length** (Tensor) - The box length of the simulation box.
3097          The data type is float32 and the shape is :math:`(3,)`.
3098        - **grid_N** (Tensor) - The number of grids divided of 3 dimensions of the simulation box.
3099          The data type is int32 and the shape is :math:`(3,)`.
3100        - **grid_length_inverse** (Tensor) - The inverse value of grid length.
3101          The data type is float32 and the shape is :math:`(3,)`.
3102        - **atom_in_grid_serial** (Tensor) - The grid index for each atom.
3103          The data type is int32 and the shape is :math:`(n,)`.
3104        - **old_crd** (Tensor) - The coordinates before update of each atom.
3105          The data type is float32 and the shape is :math:`(n, 3)`.
3106        - **crd_to_uint_crd_cof** (Tensor) - The scale factor between the unsigned int coordinate and the real one.
3107          The data type is float32 and the shape is :math:`(3,)`.
3108        - **uint_crd** (Tensor) - The unsigned int coordinates value fo each atom.
3109          The data type is unsigned int32 and the shape is :math:`(n, 3)`.
3110        - **gpointer** (Tensor) - The nearest neighbor grids (including self) of each grid.
3111          The data type is int32 and the shape is :math:`(G, 125)`.
3112        - **nl_atom_numbers** (Tensor) - The number of atoms in neighbor list of each atom.
3113          The data type is int32 and the shape is :math:`(n,)`.
3114        - **nl_atom_serial** (Tensor) - The indices of atoms in neighbor list of each atom.
3115          The data type is int32 and the shape is :math:`(n, m)`.
3116        - **uint_dr_to_dr_cof** (Tensor) - The scale factor.
3117          The data type is float32 and the shape is :math:`(3,)`.
3118        - **excluded_list_start** (Tensor) - The start excluded index in excluded list for each atom.
3119          The data type is int32 and the shape is :math:`(n,)`.
3120        - **excluded_list** (Tensor) - The contiguous join of excluded list of each atom.
3121          The data type is int32 and the shape is :math:`(E,)`.
3122        - **excluded_numbers** (Tensor) - The number of atom excluded in excluded list for each atom.
3123          The data type is int32 and the shape is :math:`(n,)`.
3124        - **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
3125          The data type is int32 and the shape is :math:`(1,)`.
3126        - **refresh_count** (Tensor) - Count how many iteration steps have passed since last update.
3127          The data type is int32 and the shape is :math:`(1,)` or :math:`()`.
3128
3129    Outputs:
3130        - **res** (Tensor) - The return value after updating successfully.
3131          The data type is float32 and the shape is :math:`(1,)`.
3132
3133    Supported Platforms:
3134        ``GPU``
3135    """
3136
3137    @prim_attr_register
3138    def __init__(self, grid_numbers, atom_numbers, not_first_time, nxy, excluded_atom_numbers,
3139                 cutoff_square, half_skin_square, cutoff_with_skin, half_cutoff_with_skin, cutoff_with_skin_square,
3140                 refresh_interval=20, cutoff=10.0, skin=2.0, max_atom_in_grid_numbers=64, max_neighbor_numbers=800):
3141        self.grid_numbers = grid_numbers
3142        self.atom_numbers = atom_numbers
3143        self.refresh_interval = refresh_interval
3144        self.not_first_time = not_first_time
3145        self.cutoff = cutoff
3146        self.skin = skin
3147        self.max_atom_in_grid_numbers = max_atom_in_grid_numbers
3148        self.nxy = nxy
3149        self.excluded_atom_numbers = excluded_atom_numbers
3150        self.cutoff_square = cutoff_square
3151        self.half_skin_square = half_skin_square
3152        self.cutoff_with_skin = cutoff_with_skin
3153        self.half_cutoff_with_skin = half_cutoff_with_skin
3154        self.cutoff_with_skin_square = cutoff_with_skin_square
3155        self.max_neighbor_numbers = max_neighbor_numbers
3156        validator.check_value_type('grid_numbers', grid_numbers, int, self.name)
3157        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
3158        validator.check_value_type('refresh_interval', refresh_interval, int, self.name)
3159        validator.check_value_type('not_first_time', not_first_time, int, self.name)
3160        validator.check_value_type('cutoff', cutoff, float, self.name)
3161        validator.check_value_type('skin', skin, float, self.name)
3162        validator.check_value_type('max_atom_in_grid_numbers', max_atom_in_grid_numbers, int, self.name)
3163        validator.check_value_type('nxy', nxy, int, self.name)
3164        validator.check_value_type('excluded_atom_numbers', excluded_atom_numbers, int, self.name)
3165        validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
3166        validator.check_value_type('half_skin_square', half_skin_square, float, self.name)
3167        validator.check_value_type('cutoff_with_skin', cutoff_with_skin, float, self.name)
3168        validator.check_value_type('half_cutoff_with_skin', half_cutoff_with_skin, float, self.name)
3169        validator.check_value_type('cutoff_with_skin_square', cutoff_with_skin_square, float, self.name)
3170        validator.check_value_type('max_neighbor_numbers', max_neighbor_numbers, int, self.name)
3171        self.init_prim_io_names(
3172            inputs=['atom_numbers_in_grid_bucket', 'bucket', 'crd', 'box_length', 'grid_N', 'grid_length_inverse',
3173                    'atom_in_grid_serial', 'old_crd', 'crd_to_uint_crd_cof', 'uint_crd', 'gpointer', 'nl_atom_numbers',
3174                    'nl_atom_serial', 'uint_dr_to_dr_cof', 'excluded_list_start', 'excluded_list', 'excluded_numbers',
3175                    'need_refresh_flag', 'refresh_count'], outputs=['res'])
3176
3177        self.add_prim_attr('grid_numbers', self.grid_numbers)
3178        self.add_prim_attr('atom_numbers', self.atom_numbers)
3179        self.add_prim_attr('refresh_interval', self.refresh_interval)
3180        self.add_prim_attr('not_first_time', self.not_first_time)
3181        self.add_prim_attr('cutoff', self.cutoff)
3182        self.add_prim_attr('skin', self.skin)
3183        self.add_prim_attr('max_atom_in_grid_numbers', self.max_atom_in_grid_numbers)
3184        self.add_prim_attr('nxy', self.nxy)
3185        self.add_prim_attr('excluded_atom_numbers', self.excluded_atom_numbers)
3186        self.add_prim_attr('cutoff_square', self.cutoff_square)
3187        self.add_prim_attr('half_skin_square', self.half_skin_square)
3188        self.add_prim_attr('cutoff_with_skin', self.cutoff_with_skin)
3189        self.add_prim_attr('half_cutoff_with_skin', self.half_cutoff_with_skin)
3190        self.add_prim_attr('cutoff_with_skin_square', self.cutoff_with_skin_square)
3191
3192    def infer_shape(self, atom_numbers_in_grid_bucket_shape, bucket_shape, crd_shape, box_length_shape, grid_n_shape,
3193                    grid_length_inverse_shape, atom_in_grid_serial_shape, old_crd_shape, crd_to_uint_crd_cof_shape,
3194                    uint_crd_shape, gpointer_shape, nl_atom_numbers_shape, nl_atom_serial_shape,
3195                    uint_dr_to_dr_cof_shape, excluded_list_start_shape, excluded_list_shape, excluded_numbers_shape,
3196                    need_refresh_flag_shape, refresh_count_shape):
3197        validator.check_int(len(atom_numbers_in_grid_bucket_shape), 1, Rel.EQ,
3198                            "atom_numbers_in_grid_bucket_dim", self.name)
3199        validator.check_int(len(bucket_shape), 2, Rel.EQ, "bucket_dim", self.name)
3200        validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name)
3201        validator.check_int(len(box_length_shape), 1, Rel.EQ, "box_length_dim", self.name)
3202        validator.check_int(len(grid_n_shape), 1, Rel.EQ, "grid_n_dim", self.name)
3203        validator.check_int(len(grid_length_inverse_shape), 1, Rel.EQ, "grid_length_inverse_dim", self.name)
3204        validator.check_int(len(atom_in_grid_serial_shape), 1, Rel.EQ, "atom_in_grid_serial_dim", self.name)
3205        validator.check_int(len(old_crd_shape), 2, Rel.EQ, "old_crd_dim", self.name)
3206        validator.check_int(len(crd_to_uint_crd_cof_shape), 1, Rel.EQ, "crd_to_uint_crd_cof_dim", self.name)
3207        validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", self.name)
3208        validator.check_int(len(gpointer_shape), 2, Rel.EQ, "gpointer_dim", self.name)
3209        validator.check_int(len(nl_atom_numbers_shape), 1, Rel.EQ, "nl_atom_numbers_dim", self.name)
3210        validator.check_int(len(nl_atom_serial_shape), 2, Rel.EQ, "nl_atom_serial_dim", self.name)
3211        validator.check_int(len(uint_dr_to_dr_cof_shape), 1, Rel.EQ, "uint_dr_to_dr_cof_dim", self.name)
3212        validator.check_int(len(excluded_list_start_shape), 1, Rel.EQ, "excluded_list_start_dim", self.name)
3213        validator.check_int(len(excluded_list_shape), 1, Rel.EQ, "excluded_list_dim", self.name)
3214        validator.check_int(len(excluded_numbers_shape), 1, Rel.EQ, "excluded_numbers_dim", self.name)
3215        validator.check_int(len(need_refresh_flag_shape), 1, Rel.EQ, "need_refresh_flag_dim", self.name)
3216        validator.check_int(len(refresh_count_shape), 1, Rel.LE, "refresh_count_dim", self.name)
3217        validator.check_int(atom_numbers_in_grid_bucket_shape[0], self.grid_numbers, Rel.EQ,
3218                            "atom_numbers_in_grid_bucket", self.name)
3219        validator.check_int(bucket_shape[0], self.grid_numbers, Rel.EQ, "bucket", self.name)
3220        validator.check_int(bucket_shape[1], self.max_atom_in_grid_numbers, Rel.EQ, "bucket", self.name)
3221        validator.check_int(crd_shape[0], self.atom_numbers, Rel.EQ, "crd", self.name)
3222        validator.check_int(crd_shape[1], 3, Rel.EQ, "crd", self.name)
3223        validator.check_int(box_length_shape[0], 3, Rel.EQ, "box_length", self.name)
3224        validator.check_int(grid_n_shape[0], 3, Rel.EQ, "grid_N", self.name)
3225        validator.check_int(grid_length_inverse_shape[0], 3, Rel.EQ, "grid_length_inverse", self.name)
3226        validator.check_int(atom_in_grid_serial_shape[0], self.atom_numbers, Rel.EQ, "atom_in_grid_serial",
3227                            self.name)
3228        validator.check_int(old_crd_shape[0], self.atom_numbers, Rel.EQ, "old_crd", self.name)
3229        validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd", self.name)
3230        validator.check_int(crd_to_uint_crd_cof_shape[0], 3, Rel.EQ, "crd_to_uint_crd_cof", self.name)
3231        validator.check_int(uint_crd_shape[0], self.atom_numbers, Rel.EQ, "uint_crd", self.name)
3232        validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd", self.name)
3233        validator.check_int(gpointer_shape[0], self.grid_numbers, Rel.EQ, "gpointer", self.name)
3234        validator.check_int(gpointer_shape[1], 125, Rel.EQ, "gpointer", self.name)
3235        validator.check_int(nl_atom_numbers_shape[0], self.atom_numbers, Rel.EQ, "nl_atom_numbers", self.name)
3236        validator.check_int(nl_atom_serial_shape[0], self.atom_numbers, Rel.EQ, "nl_atom_serial", self.name)
3237        validator.check_int(nl_atom_serial_shape[1], self.max_neighbor_numbers, Rel.EQ, "nl_atom_serial",
3238                            self.name)
3239        validator.check_int(uint_dr_to_dr_cof_shape[0], 3, Rel.EQ, "uint_dr_to_dr_cof", self.name)
3240        validator.check_int(excluded_list_start_shape[0], self.atom_numbers, Rel.EQ, "excluded_list_start",
3241                            self.name)
3242        validator.check_int(excluded_list_shape[0], self.excluded_atom_numbers, Rel.EQ, "excluded_list",
3243                            self.name)
3244        validator.check_int(excluded_numbers_shape[0], self.atom_numbers, Rel.EQ, "excluded_numbers", self.name)
3245        validator.check_int(need_refresh_flag_shape[0], 1, Rel.EQ, "need_refresh_flag", self.name)
3246        if refresh_count_shape:
3247            validator.check_int(refresh_count_shape[0], 1, Rel.EQ, "refresh_count_shape", self.name)
3248        return [1,]
3249
3250    def infer_dtype(self, atom_numbers_in_grid_bucket_dtype, bucket_dtype, crd_dtype, box_length_dtype, grid_n_dtype,
3251                    grid_length_inverse_dtype, atom_in_grid_serial_dtype, old_crd_dtype, crd_to_uint_crd_cof_dtype,
3252                    uint_crd_dtype, gpointer_dtype, nl_atom_numbers_dtype, nl_atom_serial_dtype,
3253                    uint_dr_to_dr_cof_dtype, excluded_list_start_dtype, excluded_list_dtype, excluded_numbers_dtype,
3254                    need_refresh_flag_dtype, refresh_count_dtype):
3255        validator.check_tensor_dtype_valid('atom_numbers_in_grid_bucket', atom_numbers_in_grid_bucket_dtype,
3256                                           [mstype.int32], self.name)
3257        validator.check_tensor_dtype_valid('bucket', bucket_dtype, [mstype.int32], self.name)
3258        validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
3259        validator.check_tensor_dtype_valid('box_length', box_length_dtype, [mstype.float32], self.name)
3260        validator.check_tensor_dtype_valid('grid_N', grid_n_dtype, [mstype.int32], self.name)
3261        validator.check_tensor_dtype_valid('grid_length_inverse', grid_length_inverse_dtype, [mstype.float32],
3262                                           self.name)
3263        validator.check_tensor_dtype_valid('atom_in_grid_serial', atom_in_grid_serial_dtype, [mstype.int32],
3264                                           self.name)
3265        validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
3266        validator.check_tensor_dtype_valid('crd_to_uint_crd_cof', crd_to_uint_crd_cof_dtype, [mstype.float32],
3267                                           self.name)
3268        validator.check_tensor_dtype_valid('uint_crd', uint_crd_dtype, [mstype.uint32], self.name)
3269        validator.check_tensor_dtype_valid('gpointer', gpointer_dtype, [mstype.int32], self.name)
3270        validator.check_tensor_dtype_valid('nl_atom_numbers', nl_atom_numbers_dtype, [mstype.int32], self.name)
3271        validator.check_tensor_dtype_valid('nl_atom_serial', nl_atom_serial_dtype, [mstype.int32], self.name)
3272        validator.check_tensor_dtype_valid('uint_dr_to_dr_cof', uint_dr_to_dr_cof_dtype, [mstype.float32],
3273                                           self.name)
3274        validator.check_tensor_dtype_valid('excluded_list_start', excluded_list_start_dtype, [mstype.int32],
3275                                           self.name)
3276        validator.check_tensor_dtype_valid('excluded_list', excluded_list_dtype, [mstype.int32], self.name)
3277        validator.check_tensor_dtype_valid('excluded_numbers', excluded_numbers_dtype, [mstype.int32], self.name)
3278        validator.check_tensor_dtype_valid('need_refresh_flag', need_refresh_flag_dtype, [mstype.int32],
3279                                           self.name)
3280        validator.check_tensor_dtype_valid('refresh_count', refresh_count_dtype, [mstype.int32],
3281                                           self.name)
3282        return mstype.float32
3283