语音项目中我们通常会使用 stft 对特征进行提取, 很多 python 库也提供了接口. 本文主要介绍使用 librosa,torch, 以及卷积方式进行 stft 和 istft 的运算.
1. stft 运算
关于傅里叶变换和逆变换的基础知识在之前文章中已经做过介绍:
这里就不再介绍了, 下面直接通过代码来得出音频振幅谱和相位谱.
2. librosa 接口
librosa 提供的接口非常简单, 我们通过一个例子进行 stft 和 istft 来恢复一段音频
- def test_lib(data):
- win_len = 320
- win_hop = 160
- fft_len = 512
- spec = librosa.stft(data, Windows='hann', win_length=win_len, n_fft=fft_len, hop_length=win_hop,
- center=True)
- outputs = librosa.istft(spec, Windows='hann', win_length=win_len, hop_length=win_hop,
- center=True)
- sf.write('./lib_stft.wav', outputs, 16000)
- return outputs
其中 librosa_stft 是一个复数形式, 我们可以获取其中的一些特征, 比如
- # 实部
- real = np.real(spec)
- # 虚部
- imag = np.imag(spec)
- # 振幅谱
- mags = np.sqrt(real ** 2 + imag ** 2)
- # 相位谱
- phase = np.angle(spec)
3. torch 接口
同样我们通过一个例子使用 torch 提供的接口来进行 stft 和 istft 恢复一段音频
- def test_torch(inputs):
- fft_len=512
- win_len=320
- len_hop=160
- inputs = torch.from_numpy(inputs.reshape(1,-1).astype(np.float32))
- Windows = torch.hann_window(win_len)
- spec = torch.stft(inputs, fft_len, len_hop, win_len, Windows, center=True, return_complex=False)
- print("stft out", spec.shape)
- out = torch.istft(spec, fft_len, len_hop, win_len, Windows, True, return_complex=False)
- return out
其中 spec 是一个虚部和实部 concatenate 一起的, 我们同样可以获取其中的一些特征:
- real = spec[:, :, :, 0] # 实部
- imag = spec[:, :, :, 1] # 虚部
- mags = torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2)))
- phase = torch.atan2(imag.data, real.data)
4. 利用卷积实现 stft
python 中使用 librosa 以及 pytorch 中使用接口都是很常用的特征提取方式, 但是有时我们需要将算子移植到终端就比较麻烦, 框架通常不直接提供这两个 op, 所以使用卷积实现 stft 和 istft 更容易进行工程移植.
我参考了这里的实现:
其中在使用 test_fft() 测试时会提示错误, 所以对代码进行了一点修改, 其中修改地方添加了注释:
- import torch
- import torch.nn as nn
- import numpy as np
- import torch.nn.functional as F
- from scipy.signal import get_window
- def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
- if win_type == 'None' or win_type is None:
- Windows = np.ones(win_len)
- else:
- Windows = get_window(win_type, win_len, fftbins=True)#**0.5
- N = fft_len
- fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
- real_kernel = np.real(fourier_basis)
- imag_kernel = np.imag(fourier_basis)
- kernel = np.concatenate([real_kernel, imag_kernel], 1).T
- if invers :
- kernel = np.linalg.pinv(kernel).T
- kernel = kernel*Windows
- kernel = kernel[:, None, :]
- return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(Windows[None,:,None].astype(np.float32))
- class ConvSTFT(nn.Module):
- def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
- super(ConvSTFT, self).__init__()
- if fft_len == None:
- self.fft_len = np.int(2**np.ceil(np.log2(win_len)))
- else:
- self.fft_len = fft_len
- kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
- #self.weight = nn.Parameter(kernel, requires_grad=(not fix))
- self.register_buffer('weight', kernel)
- self.feature_type = feature_type
- self.stride = win_inc
- self.win_len = win_len
- self.dim = self.fft_len
- def forward(self, inputs):
- if inputs.dim() == 2:
- inputs = torch.unsqueeze(inputs, 1)
- # 注意这里 pad 方式的对齐
- inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride], mode='reflect')
- outputs = F.conv1d(inputs, self.weight, stride=self.stride)
- # 前半段系数为实数, 后半段系数为虚数
- if self.feature_type == 'complex':
- return outputs
- else:
- dim = self.dim//2+1
- real = outputs[:, :dim, :]
- imag = outputs[:, dim:, :]
- mags = torch.sqrt(real**2+imag**2)
- phase = torch.atan2(imag, real)
- return mags, phase
- class ConviSTFT(nn.Module):
- def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
- super(ConviSTFT, self).__init__()
- if fft_len == None:
- self.fft_len = np.int(2**np.ceil(np.log2(win_len)))
- else:
- self.fft_len = fft_len
- kernel, Windows = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
- #self.weight = nn.Parameter(kernel, requires_grad=(not fix))
- self.register_buffer('weight', kernel)
- self.feature_type = feature_type
- self.win_type = win_type
- self.win_len = win_len
- self.stride = win_inc
- self.stride = win_inc
- self.dim = self.fft_len
- self.register_buffer('window', Windows)
- self.register_buffer('enframe', torch.eye(win_len)[:,None,:])
- def forward(self, inputs, phase=None):
- """
- inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
- phase: [B, N//2+1, T] (if not none)
- """
- if phase is not None:
- real = inputs*torch.cos(phase)
- imag = inputs*torch.sin(phase)
- inputs = torch.cat([real, imag], 1)
- outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
- # this is from torch-stft: https://github.com/pseeth/torch-stft
- t = self.Windows.repeat(1,1,inputs.size(-1))**2
- coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
- outputs = outputs/(coff+1e-8)
- #outputs = torch.where(coff == 0, outputs, outputs/coff)
- outputs = outputs[...,self.win_len-self.stride:-(self.win_len-self.stride)]
- return outputs
- def test_fft():
- torch.manual_seed(20)
- win_len = 320
- win_inc = 160
- fft_len = 512
- inputs = torch.randn([1, 1, 16000*4])
- fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real')
- import librosa
- outputs1 = fft(inputs)[0]
- outputs1 = outputs1.numpy()[0]
- np_inputs = inputs.numpy().reshape([-1])
- # center=True, 在 input 的两侧, 分别镜像填充 n_fft//2 个数据
- librosa_stft = librosa.stft(np_inputs, Windows='hann',win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=True)
- print(np.mean((outputs1 - np.abs(librosa_stft))**2))
- def test_conv_complex(data):
- inputs = data.reshape([1, 1, -1])
- N = 320
- inc = 160
- fft_len = 512
- fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
- ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
- inputs = torch.from_numpy(inputs.astype(np.float32))
- outputs1 = fft(inputs)
- outputs2 = ifft(outputs1)
- sf.write('./conv_stft_complex.wav', outputs2.numpy()[0, 0, :], 16000)
- return outputs2.numpy()[0, 0, :]
- if __name__ == '__main__':
- test_fft()
- #test_conv_complex(data)
总结下如果是 python 项目可以直接使用 librosa 接口, 如果是 pytorch 项目可以直接使用 torch 接口, 如果是需要模型移植到终端的项目, 建议可使用卷积方式方便移植~
来源: https://www.qcloud.com/developer/article/1909240