microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/_torch_cvt.py
246 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. See License.txt in the project root for |
| 3 | # license information. |
| 4 | ############################################################################### |
| 5 | |
| 6 | """ |
| 7 | _torch_cvt.py: Data processing graph converted from PyTorch |
| 8 | """ |
| 9 | |
| 10 | import io |
| 11 | import onnx |
| 12 | import torch |
| 13 | import numpy as np |
| 14 | |
| 15 | from onnx import numpy_helper |
| 16 | |
| 17 | from ._ortapi2 import make_onnx_model |
| 18 | from ._cuops import SingleOpGraph |
| 19 | from ._hf_cvt import HFTokenizerConverter |
| 20 | from .util import remove_unused_initializers |
| 21 | |
| 22 | |
class _WhisperHParams:
    """Fixed audio front-end constants shared by the Whisper pre-processing code below."""

    # Sampling rate of the PCM input, in samples per second.
    SAMPLE_RATE = 16000
    # FFT size handed to the STFT; also the length of the Hann window.
    N_FFT = 400
    # Number of rows (mel bands) in the mel filterbank.
    N_MELS = 80
    # Step between consecutive STFT frames, in samples.
    HOP_LENGTH = 160
    # Length of one audio chunk, in seconds.
    CHUNK_LENGTH = 30
    # 480000 samples in a 30-second chunk
    N_SAMPLES = SAMPLE_RATE * CHUNK_LENGTH
    # Number of STFT frames covering one full chunk (3000).
    N_FRAMES = N_SAMPLES // HOP_LENGTH
| 31 | |
| 32 | |
def _mel_filterbank(
        n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
    """
    Compute a Slaney-normalized mel-scale filterbank.

    The result has shape (n_mels, n_fft // 2 + 1): each row holds one
    triangular filter and each column corresponds to one one-sided DFT bin.
    """
    # centre frequency of every one-sided DFT bin
    bin_hz = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

    # n_mels + 2 evenly spaced mel points: the edges/peaks of the triangles
    mel_pts = np.linspace(min_mel, max_mel, n_mels + 2)

    # mel -> Hz, linear part of the scale
    f_min = 0.0
    f_sp = 200.0 / 3
    edge_hz = f_min + f_sp * mel_pts

    # mel -> Hz, logarithmic part of the scale
    min_log_hz = 1000.0  # beginning of log region (Hz)
    min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
    logstep = np.log(6.4) / 27.0  # step size for log region
    in_log_region = mel_pts >= min_log_mel
    edge_hz[in_log_region] = min_log_hz * np.exp(logstep * (mel_pts[in_log_region] - min_log_mel))

    # distance between neighbouring edges, and signed offset of every edge to every bin
    edge_gap = np.diff(edge_hz)
    offsets = edge_hz[:, np.newaxis] - bin_hz[np.newaxis, :]

    # rising and falling slope of all triangles at once (row i uses edges i, i+1, i+2)
    rising = -offsets[:n_mels] / edge_gap[:n_mels, np.newaxis]
    falling = offsets[2:] / edge_gap[1:, np.newaxis]

    # intersect the two slopes with each other and with zero, then store as `dtype`
    fbank = np.maximum(0, np.minimum(rising, falling)).astype(dtype)

    # scale every filter by 2 / (width of its triangle in Hz)
    energy_norm = 2.0 / (edge_hz[2: n_mels + 2] - edge_hz[:n_mels])
    fbank *= energy_norm[:, np.newaxis]
    return fbank
| 72 | |
| 73 | |
class CustomOpStftNorm(torch.autograd.Function):
    """Squared-magnitude STFT that is exported as the ``ai.onnx.contrib::StftNorm`` custom op."""

    @staticmethod
    def symbolic(g, self, n_fft, hop_length, window):
        """Emit the custom-op node; the integer arguments are passed as int64 Constant inputs."""
        def int64_const(value):
            return g.op('Constant', value_t=torch.tensor(value, dtype=torch.int64))

        n_fft_in = int64_const(n_fft)
        hop_length_in = int64_const(hop_length)
        # the frame size input carries the same value as n_fft
        frame_size_in = int64_const(n_fft)
        return g.op("ai.onnx.contrib::StftNorm", self, n_fft_in, hop_length_in, window, frame_size_in)

    @staticmethod
    def forward(ctx, audio, n_fft, hop_length, window):
        """Return ``|STFT(audio)|^2`` computed with ``torch.stft``; the window length is taken from `window`."""
        spectrum = torch.stft(
            audio,
            n_fft,
            hop_length,
            win_length=window.shape[0],
            window=window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True)
        return spectrum.abs().pow(2)
| 88 | |
| 89 | |
class WhisperPrePipeline(torch.nn.Module):
    """Torch module that maps PCM audio to a padded, rescaled log-mel spectrogram."""

    def __init__(self):
        super().__init__()
        hp = _WhisperHParams
        self.window = torch.hann_window(hp.N_FFT)
        filterbank = _mel_filterbank(
            sr=hp.SAMPLE_RATE,
            n_fft=hp.N_FFT,
            n_mels=hp.N_MELS)
        self.mel_filters = torch.from_numpy(filterbank)

    def forward(self, audio_pcm: torch.Tensor):
        """Compute the log-mel features, padded along the last axis to N_SAMPLES // HOP_LENGTH frames."""
        hp = _WhisperHParams
        power_spec = CustomOpStftNorm.apply(audio_pcm,
                                            hp.N_FFT,
                                            hp.HOP_LENGTH,
                                            self.window)
        # drop the last frame of the power spectrum
        power_spec = power_spec[:, :, :-1]
        mel_power = self.mel_filters @ power_spec
        log_mel = torch.clamp(mel_power, min=1e-10).log10()

        # nothing may fall more than 8.0 below the overall maximum
        floor = log_mel.max() - 8.0
        log_mel = torch.maximum(log_mel, floor)

        # fill the missing frames with the floor value
        n_batch, n_bands, n_frames = log_mel.shape
        filler = torch.ones(n_batch,
                            n_bands,
                            hp.N_SAMPLES // hp.HOP_LENGTH - n_frames,
                            dtype=torch.float)
        filler *= floor
        log_mel = torch.cat((log_mel, filler), dim=2)

        return (log_mel + 4.0) / 4.0
| 119 | |
| 120 | |
def _to_onnx_stft(onnx_model):
    """
    Convert custom-op STFT-Norm to ONNX STFT.

    The first node whose op_type is "StftNorm" is replaced, in place, by a
    chain of standard nodes (Pad, STFT, Transpose, Slice, Gather, Mul, Add)
    whose final output reuses the replaced node's output name. The model is
    validated with the ONNX checker and returned.

    Raises RuntimeError when the graph holds no such node.
    """
    graph = onnx_model.graph
    node_idx, stft_norm_node = next(
        ((idx, node) for idx, node in enumerate(graph.node) if node.op_type == "StftNorm"),
        (None, None))
    if stft_norm_node is None:
        raise RuntimeError("Cannot find STFTNorm node in the graph")

    make_node = onnx.helper.make_node

    def int64_vector(node_name, values, tensor_name=''):
        # Constant node holding a 1-D int64 tensor; its output is '<node_name>_output_0'
        tensor = numpy_helper.from_array(np.array(values, dtype='int64'), name=tensor_name)
        return make_node('Constant', inputs=[], outputs=[node_name + '_output_0'], name=node_name,
                         value=tensor)

    half_fft = _WhisperHParams.N_FFT // 2
    replaced_nodes = [
        # reflect-pad the audio input by N_FFT // 2 with the pads vector [0, h, 0, h]
        int64_vector('const_14', [0, half_fft, 0, half_fft], tensor_name='const_14'),
        make_node('Pad',
                  inputs=[stft_norm_node.input[0], 'const_14_output_0'],
                  outputs=['pad_1_output_0'], mode='reflect'),
        # reuse inputs 2, 3 and 4 of the custom op as the remaining STFT inputs
        make_node('STFT',
                  inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3],
                          stft_norm_node.input[4]],
                  outputs=['stft_output_0'], name='stft', domain='', onesided=1),
        make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'],
                  name='transpose_1', perm=[0, 2, 1, 3]),
        # Slice along axis 2 with start 0, end -1, step 1
        int64_vector('const_17', [2]),
        int64_vector('const_18', [0]),
        int64_vector('const_19', [-1]),
        int64_vector('const_20', [1]),
        make_node('Slice',
                  inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
                          'const_17_output_0', 'const_20_output_0'],
                  outputs=['slice_1_output_0'], name='slice_1'),
        # take index 0 and index 1 of the last axis, square each and add them
        make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
        make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
        make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'],
                  outputs=['gather_4_output_0'], name='gather_4', axis=3),
        make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'],
                  outputs=['gather_5_output_0'], name='gather_5', axis=3),
        make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'],
                  outputs=['mul_output_0'], name='mul0'),
        make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'],
                  outputs=['mul_1_output_0'], name='mul1'),
        make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'],
                  outputs=[stft_norm_node.output[0]], name='add0'),
    ]

    # splice the replacement chain in at the position of the custom-op node
    rebuilt = [*graph.node[:node_idx], *replaced_nodes, *graph.node[node_idx + 1:]]
    del graph.node[:]
    graph.node.extend(rebuilt)
    onnx.checker.check_model(onnx_model)
    return onnx_model
| 178 | |
| 179 | |
def _torch_export(*arg, **kwargs):
    """Run ``torch.onnx.export`` into an in-memory buffer and return the loaded ONNX model."""
    with io.BytesIO() as buffer:
        # the buffer is passed as the last positional argument, right after *arg
        torch.onnx.export(*arg, buffer, **kwargs)
        serialized = buffer.getvalue()
    return onnx.load_from_string(serialized)
| 184 | |
| 185 | |
class WhisperDataProcGraph:
    """Builds the ONNX pre-processing and post-processing graphs for a Whisper processor."""

    def __init__(self, processor, **kwargs):
        """Keep the processor and the target opset; a missing or falsy 'opset' falls back to 17."""
        self.hf_processor = processor
        self.opset_version = kwargs.pop('opset', 17) or 17

    def pre_processing(self, **kwargs):
        """
        Export WhisperPrePipeline to ONNX and return the resulting model.

        Keyword options (both default to True):
          USE_AUDIO_DECODER: put an AudioDecoder model in front of the exported
            graph, wiring its "floatPCM" output to the "audio_pcm" input.
          USE_ONNX_STFT: replace the StftNorm custom op with standard ONNX
            nodes and remove unused initializers.
        """
        with_audio_decoder = kwargs.pop('USE_AUDIO_DECODER', True)
        with_onnx_stft = kwargs.pop('USE_ONNX_STFT', True)
        pipeline = WhisperPrePipeline()

        # example input for the export; axis 1 is declared dynamic below
        example_pcm = torch.rand((1, 32000), dtype=torch.float32)
        pre_model = _torch_export(
            pipeline,
            (example_pcm,),
            input_names=["audio_pcm"],
            output_names=["log_mel"],
            do_constant_folding=True,
            export_params=True,
            opset_version=self.opset_version,
            dynamic_axes={
                "audio_pcm": {1: "sample_len"},
            }
        )
        if with_onnx_stft:
            pre_model = _to_onnx_stft(pre_model)
            remove_unused_initializers(pre_model.graph)

        if not with_audio_decoder:
            return pre_model

        decoder_graph = SingleOpGraph.build_graph(
            "AudioDecoder", downsampling_rate=_WhisperHParams.SAMPLE_RATE, stereo_to_mono=1)
        decoder_model = make_onnx_model(decoder_graph)
        return onnx.compose.merge_models(
            decoder_model,
            pre_model,
            io_map=[("floatPCM", "audio_pcm")])

    def post_processing(self, **kwargs):
        """
        Build the token-decoding model: int32 "sequences" -> Cast to int64 -> BpeDecoder.

        Keyword options:
          skip_special_tokens: forwarded to the BpeDecoder graph (default True).
        """
        decoder_graph = SingleOpGraph.build_graph(
            "BpeDecoder",
            cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
            skip_special_tokens=kwargs.get('skip_special_tokens', True))

        # feed the decoder node from a Cast node instead of the graph input
        bpe_node = decoder_graph.node[0]
        bpe_node.input[0] = "generated_ids"
        cast_node = onnx.helper.make_node(
            'Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64)
        rewired_nodes = [cast_node, bpe_node]
        del decoder_graph.node[:]
        decoder_graph.node.extend(rewired_nodes)

        # the graph now takes int32 "sequences" and yields strings
        sequences_info = onnx.helper.make_tensor_value_info(
            "sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])
        del decoder_graph.input[:]
        decoder_graph.input.extend([sequences_info])
        decoder_graph.output[0].type.CopyFrom(
            onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))

        return make_onnx_model(decoder_graph, opset_version=self.opset_version)
| 247 | |