# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torchaudio

from ..model_base import EagerModelBase

FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT)

__all__ = [
    "EmformerRnntTranscriberModel",
    "EmformerRnntPredictorModel",
    "EmformerRnntJoinerModel",
]


class EmformerRnntTranscriberExample(torch.nn.Module):
    """
    Wrapper for validating the transcriber of the Emformer RNN-T architecture.
    It does not reflect actual usage (such as beam search); it is only an
    example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        # Wrap the RNN-T module from the pretrained LibriSpeech bundle.
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        self.rnnt = decoder.model

    def forward(self, transcribe_inputs):
        return self.rnnt.transcribe(*transcribe_inputs)


class EmformerRnntTranscriberModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt transcriber")
        m = EmformerRnntTranscriberExample()
        logging.info("Loaded emformer rnnt transcriber")
        return m

    def get_example_inputs(self):
        # (features, feature lengths): batch of 1, 128 frames, 80 feature bins.
        transcribe_inputs = (
            torch.randn(1, 128, 80),
            torch.tensor([128]),
        )
        return (transcribe_inputs,)


class EmformerRnntPredictorExample(torch.nn.Module):
    """
    Wrapper for validating the predictor of the Emformer RNN-T architecture.
    It does not reflect actual usage (such as beam search); it is only an
    example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        # Wrap the RNN-T module from the pretrained LibriSpeech bundle.
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        self.rnnt = decoder.model

    def forward(self, predict_inputs):
        return self.rnnt.predict(*predict_inputs)


class EmformerRnntPredictorModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt predictor")
        m = EmformerRnntPredictorExample()
        logging.info("Loaded emformer rnnt predictor")
        return m

    def get_example_inputs(self):
        # (targets, target lengths, state): token ids, lengths, and no prior state.
        predict_inputs = (
            torch.zeros([1, 128], dtype=int),
            torch.tensor([128], dtype=int),
            None,
        )
        return (predict_inputs,)


class EmformerRnntJoinerExample(torch.nn.Module):
    """
    Wrapper for validating the joiner of the Emformer RNN-T architecture.
    It does not reflect actual usage (such as beam search); it is only an
    example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        # Wrap the RNN-T module from the pretrained LibriSpeech bundle.
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        self.rnnt = decoder.model

    def forward(self, join_inputs):
        return self.rnnt.join(*join_inputs)


class EmformerRnntJoinerModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt joiner")
        m = EmformerRnntJoinerExample()
        logging.info("Loaded emformer rnnt joiner")
        return m

    def get_example_inputs(self):
        # (source encodings, source lengths, target encodings, target lengths).
        join_inputs = (
            torch.rand([1, 128, 1024]),
            torch.tensor([128]),
            torch.rand([1, 128, 1024]),
            torch.tensor([128]),
        )
        return (join_inputs,)
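

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API): one way the
# transcriber wrapper might be validated and captured with torch.export. This
# assumes the pretrained torchaudio bundle can be downloaded and that
# torch.export.export is available in your PyTorch build; whether the graph
# captures cleanly depends on the torchaudio and PyTorch versions in use.
# Guarded so importing this module has no side effects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build the eager wrapper and fetch matching example inputs.
    model = EmformerRnntTranscriberModel()
    eager_module = model.get_eager_model()
    example_inputs = model.get_example_inputs()

    # Run the eager module once to sanity-check the example inputs.
    with torch.no_grad():
        eager_output = eager_module(*example_inputs)
    logging.info("Eager transcriber ran on example inputs")

    # Capture an exported program from the same inputs (assumption: this
    # module is traceable with torch.export on your setup).
    exported_program = torch.export.export(eager_module, example_inputs)
    logging.info("Exported transcriber graph: %s", exported_program)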