Skip to content

Commit 80ed0c8

Browse files
committed
Whisper pipeline
1 parent a1222f4 commit 80ed0c8

File tree

5 files changed

+645
-0
lines changed

5 files changed

+645
-0
lines changed
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Linq;
5+
using System.Numerics;
6+
using TensorStack.Common.Tensor;
7+
8+
namespace TensorStack.TextGeneration.Pipelines.Whisper
9+
{
10+
public class WhisperPreprocessor
11+
{
12+
private readonly float[] _melFilterBank;
13+
private readonly int _sampleRate = 16000;
14+
private readonly int _frameSize = 400; // 25ms @ 16kHz
15+
private readonly int _hopLength = 160; // 10ms @ 16kHz
16+
private readonly int _numMelBins = 80;
17+
private readonly int _nfft = 512;
18+
19+
public WhisperPreprocessor()
20+
{
21+
_melFilterBank = CreateMelFilterBank();
22+
}
23+
24+
25+
public Tensor<float> Process(string wavPath)
26+
{
27+
var samples = LoadWavPcm16Mono(wavPath);
28+
29+
// Pre-emphasis (optional, Whisper doesn’t strictly need it)
30+
// for (int i = samples.Length - 1; i > 0; i--)
31+
// samples[i] -= 0.97f * samples[i - 1];
32+
33+
var stft = STFT(samples);
34+
var melSpec = ApplyMel(stft);
35+
36+
// log10(mel + epsilon)
37+
var result = new Tensor<float>([1, melSpec.GetLength(0), melSpec.GetLength(1)]);
38+
for (int i = 0; i < melSpec.GetLength(0); i++)
39+
for (int j = 0; j < melSpec.GetLength(1); j++)
40+
result[0, i, j] = (float)Math.Log10(Math.Max(1e-10f, melSpec[i, j]));
41+
42+
return result;
43+
}
44+
45+
46+
private float[] LoadWavPcm16Mono(string path)
47+
{
48+
using var br = new BinaryReader(File.OpenRead(path));
49+
br.ReadBytes(44); // skip WAV header
50+
var data = new List<float>();
51+
while (br.BaseStream.Position < br.BaseStream.Length)
52+
data.Add(br.ReadInt16() / 32768f);
53+
return data.ToArray();
54+
}
55+
56+
57+
private Complex[][] STFT(float[] samples)
58+
{
59+
int numFrames = 1 + (samples.Length - _frameSize) / _hopLength;
60+
var frames = new Complex[numFrames][];
61+
62+
// Hann window
63+
float[] window = Enumerable.Range(0, _frameSize)
64+
.Select(n => 0.5f - 0.5f * (float)Math.Cos(2 * Math.PI * n / _frameSize))
65+
.ToArray();
66+
67+
for (int i = 0; i < numFrames; i++)
68+
{
69+
var frame = new Complex[_nfft];
70+
int start = i * _hopLength;
71+
for (int j = 0; j < _frameSize; j++)
72+
frame[j] = samples[start + j] * window[j];
73+
for (int j = _frameSize; j < _nfft; j++)
74+
frame[j] = Complex.Zero;
75+
76+
FFT(frame); // in-place
77+
frames[i] = frame;
78+
}
79+
80+
return frames;
81+
}
82+
83+
84+
private float[,] ApplyMel(Complex[][] stft)
85+
{
86+
int numFrames = stft.Length;
87+
float[,] melSpec = new float[_numMelBins, numFrames];
88+
89+
for (int t = 0; t < numFrames; t++)
90+
{
91+
float[] power = new float[_nfft / 2 + 1];
92+
for (int f = 0; f < power.Length; f++)
93+
power[f] = (float)(stft[t][f].Magnitude * stft[t][f].Magnitude);
94+
95+
for (int m = 0; m < _numMelBins; m++)
96+
{
97+
float sum = 0;
98+
for (int f = 0; f < power.Length; f++)
99+
sum += power[f] * _melFilterBank[m * power.Length + f];
100+
melSpec[m, t] = sum;
101+
}
102+
}
103+
104+
return melSpec;
105+
}
106+
107+
108+
private float[] CreateMelFilterBank()
109+
{
110+
int numFreqs = _nfft / 2 + 1;
111+
float[] filterBank = new float[_numMelBins * numFreqs];
112+
113+
double fMin = 0;
114+
double fMax = _sampleRate / 2;
115+
double melMin = HzToMel(fMin);
116+
double melMax = HzToMel(fMax);
117+
118+
double[] melPoints = Enumerable.Range(0, _numMelBins + 2)
119+
.Select(i => melMin + (melMax - melMin) * i / (_numMelBins + 1))
120+
.ToArray();
121+
122+
double[] hzPoints = melPoints.Select(MelToHz).ToArray();
123+
int[] bins = hzPoints.Select(hz => (int)Math.Floor((_nfft + 1) * hz / _sampleRate)).ToArray();
124+
125+
for (int m = 1; m <= _numMelBins; m++)
126+
{
127+
int f0 = bins[m - 1], f1 = bins[m], f2 = bins[m + 1];
128+
for (int f = f0; f < f1; f++)
129+
filterBank[(m - 1) * numFreqs + f] = (float)(f - f0) / (f1 - f0);
130+
for (int f = f1; f < f2; f++)
131+
filterBank[(m - 1) * numFreqs + f] = (float)(f2 - f) / (f2 - f1);
132+
}
133+
134+
return filterBank;
135+
}
136+
137+
private static double HzToMel(double hz) => 2595 * Math.Log10(1 + hz / 700);
138+
private static double MelToHz(double mel) => 700 * (Math.Pow(10, mel / 2595) - 1);
139+
140+
141+
private void FFT(Complex[] buffer)
142+
{
143+
int n = buffer.Length;
144+
int bits = (int)Math.Log2(n);
145+
146+
// bit-reversal
147+
for (int i = 1, j = 0; i < n; i++)
148+
{
149+
int bit = n >> 1;
150+
for (; (j & bit) != 0; bit >>= 1) j ^= bit;
151+
j ^= bit;
152+
if (i < j)
153+
{
154+
var temp = buffer[i];
155+
buffer[i] = buffer[j];
156+
buffer[j] = temp;
157+
}
158+
}
159+
160+
// FFT
161+
for (int len = 2; len <= n; len <<= 1)
162+
{
163+
double ang = -2 * Math.PI / len;
164+
Complex wlen = new Complex(Math.Cos(ang), Math.Sin(ang));
165+
for (int i = 0; i < n; i += len)
166+
{
167+
Complex w = Complex.One;
168+
for (int j = 0; j < len / 2; j++)
169+
{
170+
Complex u = buffer[i + j];
171+
Complex v = buffer[i + j + len / 2] * w;
172+
buffer[i + j] = u + v;
173+
buffer[i + j + len / 2] = u - v;
174+
w *= wlen;
175+
}
176+
}
177+
}
178+
}
179+
}
180+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
using TensorStack.TextGeneration.Common;
2+
3+
namespace TensorStack.TextGeneration.Pipelines.Whisper
4+
{
5+
public record WhisperConfig : TransformerConfig
6+
{
7+
}
8+
}

0 commit comments

Comments
 (0)