• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15import pytest
16import numpy as np
17import torch
18import mindspore
19from mindspore import Tensor
20from mindspore.nn import Cell
21from mindspore.ops import operations as P
22from mindspore import context
23from mindspore.ops.operations import math_ops as MP
24
25
class NetTest(Cell):
    """Network computing ``sinc(max(x, axes=relu(indices)))``.

    ``construct`` passes ``indices`` through ReLU and uses the result as the
    reduction axes of a ``ReduceMax`` (``keep_dims=False``) over ``x``. It then
    applies ``Sinc`` to the reduced tensor and returns it.
    """

    def __init__(self):
        super().__init__()
        self.sinc = MP.Sinc()
        # Named after the primitive it holds (this is a max, not a sum).
        self.reduce_max = P.ReduceMax(keep_dims=False)
        self.relu = P.ReLU()

    def construct(self, x, indices):
        # ReLU leaves non-negative axis values unchanged. Presumably it is
        # applied so the reduction axes are produced by an in-graph op rather
        # than taken directly from a graph input -- TODO confirm intent.
        axes = self.relu(indices)
        reduced = self.reduce_max(x, axes)
        return self.sinc(reduced)
37
38
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_and_graph_mixed_run():
    """
    Feature: test pynative and graph mixed run
    Description: single op run in pynative, the output to net input which run in graph
    Expectation: run success
    """
    context.set_context(jit_level='O0')
    # Make the PyNative premise explicit instead of relying on the
    # process-global default mode, which an earlier test may have changed.
    context.set_context(mode=mindspore.PYNATIVE_MODE)
    data_x = np.random.randn(7, 3, 8, 8, 8).astype(np.float32)
    # Single op executed in PyNative mode; its output feeds the graph-mode net.
    x = Tensor(data_x) + 100
    # Random non-empty subset of axes {2, 3}: [2], [3] or [2, 3].
    data_indices = np.unique(np.random.randint(2, 4, size=4).astype(np.int32))
    indices = Tensor(data_indices)
    context.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
    out_ms = NetTest()(x, indices)

    # PyTorch reference for the same relu -> amax -> sinc computation.
    y = torch.tensor(data_x) + 100
    # .tolist() yields plain Python ints for ``dim`` (not numpy scalars).
    axes = torch.relu(torch.tensor(data_indices)).tolist()
    y_reduce = torch.amax(y, dim=axes, keepdim=False)
    out_pt = torch.sinc(y_reduce).numpy()
    # Report the randomly chosen axes so a failure can be reproduced.
    assert np.allclose(out_ms.asnumpy(), out_pt, rtol=1e-4, atol=1e-4), (
        f"MindSpore/PyTorch mismatch when reducing over axes {axes}"
    )
63