一. 什么是 WaveNet?
简单来说, WaveNet 是一种生成模型, 类似 VAE,GAN 等, WaveNet 最大的特点是可以直接生成 raw audio 的模型, 由 2017 年 DeepMind 提出, 在 TTS(文字转语音) 任务上可以达到 state-of-art 的效果.
此外 WaveNet 也可以用来做生成文字, 生成图片, 语音识别等.
WaveNet 的具体相关可以参考以下资料:
- Ref :
- DeepMind WaveNet Blog https://deepmind.com/blog/wavenet-generative-model-raw-audio/
keras 实现代码 https://github.com/basveeling/wavenet
WaveNet 论文 https://arxiv.org/pdf/1609.03499.pdf
窃以为, 学习一种网络结构应该结合论文和代码, 而理解模型的基础首先是知道模型的输入输出.
DeepMind 博客上的动图非常清楚地展示了这个模型的工作过程. 一定要看! 体会!
WaveNet 的网络结构并不复杂, 说白了其实就是一类变种 CNN. 但是介绍 WaveNet 的各种文章只对 WaveNet 的结构夸夸其谈, 丝毫没有涉及模型的输入输出到底是什么, 对小白非常不友好.
本文着重介绍 WaveNet keras 实现代码中的输入数据组织.
二. 模型的运作过程
这里不谈模型的原理和结构 (实际上只要理解了 CNN, 理解 WaveNet 非常容易). 我们先谈谈 WaveNet 到底 "做了什么"?
- def wav_to_float(x):
- try:
- max_value = np.iinfo(x.dtype).max
- min_value = np.iinfo(x.dtype).min
- except:
- max_value = np.finfo(x.dtype).max
- min_value = np.iinfo(x.dtype).min
- x = x.astype('float64', casting='safe')
- x -= min_value
- x /= ((max_value - min_value) / 2.)
- x -= 1.
- return x
- def float_to_uint8(x):
- x += 1.
- x /= 2.
- uint8_max_value = np.iinfo('uint8').max
- x *= uint8_max_value
- x = x.astype('uint8')
- return x
- def process_wav(desired_sample_rate, filename, use_ulaw):
- # print('reading wavfile...',filename)
- with warnings.catch_warnings():
- # warnings.simplefilter("error") # 提升警告等级? 会导致 np.fromstring 报错
- channels = scipy.io.wavfile.read(filename)
- file_sample_rate, audio = channels
- audio = ensure_mono(audio)
- audio = wav_to_float(audio)
- if use_ulaw:
- audio = ulaw(audio)
- audio = ensure_sample_rate(desired_sample_rate, file_sample_rate, audio)
- audio = float_to_uint8(audio)
- return audio
- def fragment_indices(full_sequences, fragment_length, batch_size, fragment_stride, nb_output_bins):
- for seq_i, sequence in enumerate(full_sequences):
- for i in range(0, sequence.shape[0] - fragment_length, fragment_stride):
- yield seq_i, i
- # i 为 input sequence 的起点 seq_i 为音频文件的 id
- batches = cycle(partition_all(batch_size, indices)) # indices 为列表
- for batch in batches:
- if len(batch) < batch_size:
- continue
- yield np.array(
- [one_hot(full_sequences[e[0]][e[1]:e[1] + fragment_length]) for e in batch], dtype='uint8'), np.array(
- [one_hot(full_sequences[e[0]][e[1] + 1:e[1] + fragment_length + 1]) for e in batch], dtype='uint8')
来源: https://www.cnblogs.com/seanliao/p/9595536.html