microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2cf9bab611e9ad563822dee69c44e23bd017fadc

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_torch_cvt.py

246 lines · mode code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6"""
7_torch_cvt.py: Data processing graph converted from PyTorch
8"""
9
10import io
11import onnx
12import torch
13import numpy as np
14
15from onnx import numpy_helper
16
17from ._ortapi2 import make_onnx_model
18from ._cuops import SingleOpGraph
19from ._hf_cvt import HFTokenizerConverter
20from .util import remove_unused_initializers
21
22
23class _WhisperHParams:
24 SAMPLE_RATE = 16000
25 N_FFT = 400
26 N_MELS = 80
27 HOP_LENGTH = 160
28 CHUNK_LENGTH = 30
29 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
30 N_FRAMES = N_SAMPLES // HOP_LENGTH
31
32
33def _mel_filterbank(
34 n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
35 """
36 Compute a Mel-filterbank. The filters are stored in the rows, the columns,
37 and it is Slaney normalized mel-scale filterbank.
38 """
39 fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype)
40
41 # the centers of the frequency bins for the DFT
42 freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
43
44 mel = np.linspace(min_mel, max_mel, n_mels + 2)
45 # Fill in the linear scale
46 f_min = 0.0
47 f_sp = 200.0 / 3
48 freqs = f_min + f_sp * mel
49
50 # And now the nonlinear scale
51 min_log_hz = 1000.0 # beginning of log region (Hz)
52 min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
53 logstep = np.log(6.4) / 27.0 # step size for log region
54
55 log_t = mel >= min_log_mel
56 freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel))
57 mel_bins = freqs
58
59 mel_spacing = np.diff(mel_bins)
60
61 ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1)
62 for i in range(n_mels):
63 left = -ramps[i] / mel_spacing[i]
64 right = ramps[i + 2] / mel_spacing[i + 1]
65
66 # intersect them with each other and zero
67 fbank[i] = np.maximum(0, np.minimum(left, right))
68
69 energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
70 fbank *= energy_norm[:, np.newaxis]
71 return fbank
72
73
class CustomOpStftNorm(torch.autograd.Function):
    """Squared-magnitude STFT that exports to a single contrib-domain op.

    In eager mode :meth:`forward` returns ``|torch.stft(...)|**2``. During ONNX
    export, :meth:`symbolic` emits one ``ai.onnx.contrib::StftNorm`` node
    instead of the decomposed torch STFT.
    """

    @staticmethod
    def symbolic(g, self, n_fft, hop_length, window):
        """Emit ``StftNorm(audio, n_fft, hop_length, window, frame_size)``.

        ``self`` is the audio input value. ``frame_size`` is set to ``n_fft``.
        """
        def _int64_scalar(value):
            return g.op('Constant', value_t=torch.tensor(value, dtype=torch.int64))

        n_fft_const = _int64_scalar(n_fft)
        hop_const = _int64_scalar(hop_length)
        frame_size_const = _int64_scalar(n_fft)
        return g.op("ai.onnx.contrib::StftNorm",
                    self, n_fft_const, hop_const, window, frame_size_const)

    @staticmethod
    def forward(ctx, audio, n_fft, hop_length, window):
        """Return the power spectrogram ``|STFT(audio)|**2``.

        Frames are centered, reflect-padded, unnormalized and one-sided, with
        ``win_length`` equal to ``window.shape[0]``. ``ctx`` is unused; this
        class defines no backward.
        """
        spectrum = torch.stft(
            audio,
            n_fft,
            hop_length=hop_length,
            win_length=window.shape[0],
            window=window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        return spectrum.abs() ** 2
88
89
class WhisperPrePipeline(torch.nn.Module):
    """Compute Whisper-style normalized log-mel features from PCM audio.

    The power-spectrum step goes through ``CustomOpStftNorm``, so ONNX export
    turns it into a single ``ai.onnx.contrib::StftNorm`` node.
    """

    def __init__(self):
        super().__init__()
        hparams = _WhisperHParams
        self.window = torch.hann_window(hparams.N_FFT)
        # (N_MELS, N_FFT // 2 + 1) filterbank matrix, converted from numpy.
        self.mel_filters = torch.from_numpy(
            _mel_filterbank(n_fft=hparams.N_FFT, n_mels=hparams.N_MELS, sr=hparams.SAMPLE_RATE))

    def forward(self, audio_pcm: torch.Tensor):
        """Return ``(log_mel + 4) / 4``, right-padded to ``N_SAMPLES // HOP_LENGTH`` frames.

        Assumes audio_pcm is (batch, samples) — TODO confirm. The mel result is
        unpacked as a rank-3 tensor below.

        NOTE(review): when the input has more than ``N_SAMPLES // HOP_LENGTH``
        frames, the pad size is negative and ``torch.ones`` raises.
        """
        hparams = _WhisperHParams
        power = CustomOpStftNorm.apply(audio_pcm, hparams.N_FFT, hparams.HOP_LENGTH, self.window)
        # Drop the final STFT frame before projecting onto the mel filters.
        mel = self.mel_filters @ power[:, :, :-1]
        log_mel = torch.clamp(mel, min=1e-10).log10()
        # Limit the dynamic range to 8 (log10 units) below the global peak.
        floor = log_mel.max() - 8.0
        log_mel = torch.maximum(log_mel, floor)
        # Pad the frame axis with the floor value up to the fixed chunk length.
        n_batch, n_bands, n_frames = log_mel.shape
        n_pad = hparams.N_SAMPLES // hparams.HOP_LENGTH - n_frames
        padding = torch.ones(n_batch, n_bands, n_pad, dtype=torch.float) * floor
        log_mel = torch.cat((log_mel, padding), dim=2)
        return (log_mel + 4.0) / 4.0
119
120
def _to_onnx_stft(onnx_model):
    """Convert custom-op STFT-Norm to ONNX STFT.

    Finds the first ``StftNorm`` node in ``onnx_model.graph``. At the same
    position it substitutes standard ONNX ops that compute a power spectrum:
    reflect ``Pad`` -> ``STFT`` -> ``Transpose`` -> ``Slice`` -> real**2 + imag**2.
    The final ``Add`` writes to the ``StftNorm`` node's original output name, so
    downstream consumers are unchanged.

    ``StftNorm`` inputs are read positionally as
    ``(audio, n_fft, hop_length, window, frame_size)``, the order emitted by
    ``CustomOpStftNorm.symbolic``. The ONNX ``STFT`` op exists only from opset
    17, so the model must target opset >= 17 for the final check to pass.

    :param onnx_model: ``onnx.ModelProto``; its graph is modified in place.
    :return: the same ``onnx_model``, after ``onnx.checker.check_model`` passes.
    :raises RuntimeError: if the graph contains no ``StftNorm`` node.
    """
    graph_nodes = onnx_model.graph.node
    node_idx, stft_norm_node = next(
        ((idx, node) for idx, node in enumerate(graph_nodes) if node.op_type == "StftNorm"),
        (-1, None))
    if stft_norm_node is None:
        raise RuntimeError("Cannot find STFTNorm node in the graph")

    make_node = onnx.helper.make_node
    replaced_nodes = [
        # pads = [x1_begin, x2_begin, x1_end, x2_end]: pad N_FFT // 2 samples at both
        # ends of axis 1. This mirrors torch.stft(center=True, pad_mode="reflect").
        make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
                  value=numpy_helper.from_array(np.array([0,
                                                          _WhisperHParams.N_FFT // 2, 0,
                                                          _WhisperHParams.N_FFT // 2], dtype='int64'),
                                                name='const_14')),
        make_node('Pad',
                  inputs=[stft_norm_node.input[0], 'const_14_output_0'],
                  outputs=['pad_1_output_0'], mode='reflect'),
        # ONNX STFT inputs are (signal, frame_step, window, frame_length).
        make_node('STFT',
                  inputs=['pad_1_output_0', stft_norm_node.input[2], stft_norm_node.input[3], stft_norm_node.input[4]],
                  outputs=['stft_output_0'], name='stft', domain='', onesided=1),
        # STFT output [batch, frames, bins, 2] -> [batch, bins, frames, 2]
        make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
                  perm=[0, 2, 1, 3]),
        make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
                  value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
        make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
                  value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
        make_node('Constant', inputs=[], outputs=['const_19_output_0'], name='const_19',
                  value=numpy_helper.from_array(np.array([-1], dtype='int64'), name='')),
        make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
                  value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
        # Slice(starts=[0], ends=[-1], axes=[2], steps=[1]) drops the last frame.
        # NOTE(review): the traced graph from WhisperPrePipeline.forward applies its
        # own [:, :, :-1] to this node's output. This path may therefore keep one
        # frame fewer than the eager CustomOpStftNorm path. Confirm this against the
        # C++ StftNorm kernel.
        make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_19_output_0',
                                   'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
                  name='slice_1'),
        # Scalar indices: Gather on the last axis selects real (0) / imaginary (1).
        make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
        make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
        make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
                  name='gather_4', axis=3),
        make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
                  name='gather_5', axis=3),
        # |X|^2 = re^2 + im^2, written to the original StftNorm output name.
        make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
        make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
        make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
    ]

    # Splice: nodes before the StftNorm node, the replacement subgraph, nodes after.
    new_stft_nodes = list(graph_nodes[:node_idx]) + replaced_nodes + list(graph_nodes[node_idx + 1:])
    del graph_nodes[:]
    graph_nodes.extend(new_stft_nodes)
    onnx.checker.check_model(onnx_model)
    return onnx_model
178
179
def _torch_export(*args, **kwargs):
    """Run ``torch.onnx.export`` into memory and return the loaded ONNX model.

    All positional arguments are forwarded first, followed by an in-memory
    buffer as the output-file argument. Keyword arguments are forwarded unchanged.
    """
    buffer = io.BytesIO()
    try:
        torch.onnx.export(*args, buffer, **kwargs)
        serialized = buffer.getvalue()
    finally:
        buffer.close()
    return onnx.load_from_string(serialized)
184
185
class WhisperDataProcGraph:
    """Build the ONNX pre- and post-processing models for a Whisper pipeline.

    :param processor: HuggingFace-style processor. Only its ``tokenizer``
        attribute is read (in :meth:`post_processing`).
    :param kwargs: ``opset`` selects the ONNX opset used for export and model
        construction. A missing or falsy value means 17. Other keys are ignored.
    """

    def __init__(self, processor, **kwargs):
        self.hf_processor = processor
        _opset = kwargs.pop('opset', 17)
        # opset=None / opset=0 also fall back to the default.
        self.opset_version = _opset if _opset else 17

    def pre_processing(self, **kwargs):
        """Build the pre-processing model: PCM audio -> normalized log-mel.

        Exports :class:`WhisperPrePipeline` with input ``audio_pcm`` (axis 1
        dynamic) and output ``log_mel`` at ``self.opset_version``.

        :param kwargs: ``USE_ONNX_STFT`` (default True) rewrites the contrib
            ``StftNorm`` node into standard ONNX ops via ``_to_onnx_stft``; that
            rewrite needs ``opset >= 17``. ``USE_AUDIO_DECODER`` (default True)
            prepends an ``AudioDecoder`` model whose ``floatPCM`` output feeds
            ``audio_pcm``.
        :return: the resulting ONNX model.
        """
        use_audio_decoder = kwargs.pop('USE_AUDIO_DECODER', True)
        use_onnx_stft = kwargs.pop('USE_ONNX_STFT', True)
        whisper_processing = WhisperPrePipeline()

        # Random dummy input used only for tracing. Its length is irrelevant
        # because axis 1 is declared dynamic.
        audio_pcm = torch.rand((1, 32000), dtype=torch.float32)
        model_args = (audio_pcm,)
        pre_model = _torch_export(
            whisper_processing,
            model_args,
            input_names=["audio_pcm"],
            output_names=["log_mel"],
            do_constant_folding=True,
            export_params=True,
            opset_version=self.opset_version,
            dynamic_axes={
                "audio_pcm": {1: "sample_len"},
            }
        )
        if use_onnx_stft:
            pre_model = _to_onnx_stft(pre_model)
            remove_unused_initializers(pre_model.graph)

        pre_full = pre_model
        if use_audio_decoder:
            audecoder_g = SingleOpGraph.build_graph(
                "AudioDecoder", downsampling_rate=_WhisperHParams.SAMPLE_RATE, stereo_to_mono=1)
            # Build the decoder model at the same opset as the exported graph (and
            # as post_processing does). onnx.compose.merge_models rejects models
            # whose opset imports conflict for the same domain.
            audecoder_m = make_onnx_model(audecoder_g, opset_version=self.opset_version)
            pre_full = onnx.compose.merge_models(
                audecoder_m,
                pre_model,
                io_map=[("floatPCM", "audio_pcm")])

        return pre_full

    def post_processing(self, **kwargs):
        """Build the post-processing model: token ids -> text.

        The model takes an INT32 input ``sequences`` with symbolic shape
        ``['N', 'seq_len', 'ids']``, casts it to INT64 (``generated_ids``) and
        feeds a ``BpeDecoder`` node built from ``self.hf_processor.tokenizer``.
        The first graph output is retyped as STRING with shape ``['N', 'text']``.

        :param kwargs: ``skip_special_tokens`` (default True) is forwarded to the
            ``BpeDecoder`` node.
        :return: the ONNX model built at ``self.opset_version``.
        """
        skip_special_tokens = kwargs.get('skip_special_tokens', True)
        g = SingleOpGraph.build_graph(
            "BpeDecoder",
            cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
            skip_special_tokens=skip_special_tokens)

        # Takes g.node[0] as the decoder node. Presumably build_graph emits only
        # that node — confirm, since every node is replaced below.
        bpenode = g.node[0]
        bpenode.input[0] = "generated_ids"
        nodes = [onnx.helper.make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
                 bpenode]
        del g.node[:]
        g.node.extend(nodes)

        inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
        del g.input[:]
        g.input.extend(inputs)
        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))

        return make_onnx_model(g, opset_version=self.opset_version)
247