# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Env related operations.

Defines the ``env_get`` multitype function graph, with overloads for
``(EnvType, Tensor)`` and ``(EnvType, MapTensor)`` arguments, both built on
the ``EnvironGet`` primitive.
"""
from __future__ import absolute_import
from mindspore.ops.composite.base import MultitypeFuncGraph
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _grad_ops
from mindspore.ops import operations as P

# Dispatcher: the overloads below attach themselves via ``env_get.register``.
env_get = MultitypeFuncGraph("env_get")

# Operator instances created once at import time and used by the overloads.
environ_get = Primitive('EnvironGet')
ref_to_embed = _grad_ops.RefToEmbed()
tensor_zeros_like = P.ZerosLike()


@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
    """
    Look up the ``env`` entry associated with a Tensor parameter.

    Args:
        env: The environment object (dispatched as ``EnvType``).
        parameter: Tensor whose ``RefToEmbed`` result serves as the lookup key.

    Returns:
        ``EnvironGet(env, RefToEmbed(parameter), ZerosLike(parameter))``.
        The third argument is presumably the value returned when ``env`` has
        no entry for the key -- NOTE(review): confirm against ``EnvironGet``.
    """
    lookup_key = ref_to_embed(parameter)
    fallback = tensor_zeros_like(parameter)
    return environ_get(env, lookup_key, fallback)


@env_get.register("EnvType", "MapTensor")
def _map_tensor_env_get(env, map_parameter):
    """
    Look up the ``env`` entry associated with a MapTensor parameter.

    Args:
        env: The environment object (dispatched as ``EnvType``).
        map_parameter: MapTensor whose ``RefToEmbed`` result serves as the
            lookup key.

    Returns:
        ``EnvironGet(env, RefToEmbed(map_parameter), zeros_like(map_parameter))``.
    """
    lookup_key = ref_to_embed(map_parameter)
    # Uses the composite multitype ``zeros_like`` rather than the ZerosLike
    # primitive; NOTE(review): presumably because the primitive does not
    # accept MapTensor -- confirm.
    fallback = zeros_like(map_parameter)
    return environ_get(env, lookup_key, fallback)