• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Env related operations."""
19from __future__ import absolute_import
20from mindspore.ops.composite.base import MultitypeFuncGraph
21from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
22from mindspore.ops.primitive import Primitive
23from mindspore.ops.operations import _grad_ops
24from mindspore.ops import operations as P
25
26env_get = MultitypeFuncGraph("env_get")
27environ_get = Primitive('EnvironGet')
28ref_to_embed = _grad_ops.RefToEmbed()
29tensor_zeros_like = P.ZerosLike()
30
31
32@env_get.register("EnvType", "Tensor")
33def _tensor_env_get(env, parameter):
34    """Used to get env."""
35    return environ_get(env, ref_to_embed(parameter), tensor_zeros_like(parameter))
36
37
38@env_get.register("EnvType", "MapTensor")
39def _map_tensor_env_get(env, map_parameter):
40    """Used to get env for map parameter."""
41    return environ_get(env, ref_to_embed(map_parameter), zeros_like(map_parameter))
42