Fundamental Frequency (Pitch) Detection of Audio Signals (Python Version)

import math
import wave
import array
import functools
from abc import ABC, abstractmethod
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import os
import sys

# ====================== Design patterns ======================
class PreprocessStrategy(ABC):
    """预处理策略基类"""
    @abstractmethod
    def process(self, signal):
        pass

class FrameStrategy(ABC):
    """分帧策略基类"""
    @abstractmethod
    def frame(self, signal, sample_rate):
        pass

class PitchDetector(ABC):
    """基频检测器基类"""
    @abstractmethod
    def detect(self, frames, sample_rate):
        pass

class Visualizer(ABC):
    """可视化器基类"""
    @abstractmethod
    def visualize(self, analysis_data):
        pass

class ProcessorFactory:
    """处理器工厂(工厂模式)"""
    @staticmethod
    def create_preprocessor(strategy_type):
        if strategy_type == "preemphasis":
            return PreEmphasisStrategy()
        raise ValueError("未知的预处理策略")

    @staticmethod
    def create_framer(strategy_type):
        if strategy_type == "fixed":
            return FixedFrameStrategy()
        raise ValueError("未知的分帧策略")

    @staticmethod
    def create_detector(strategy_type):
        if strategy_type == "autocorrelation":
            return AutoCorrelationDetector()
        elif strategy_type == "cepstrum":
            return CepstrumDetector()
        raise ValueError("未知的检测策略")
    
    @staticmethod
    def create_visualizer(strategy_type):
        if strategy_type == "matplotlib":
            return MatplotlibVisualizer()
        elif strategy_type == "text":
            return TextVisualizer()
        raise ValueError("未知的可视化策略")

# ====================== Signal processing ======================
class PreEmphasisStrategy(PreprocessStrategy):
    """Pre-emphasis strategy (boosts high-frequency content)."""
    def process(self, signal):
        return [signal[i] - 0.97 * signal[i-1] for i in range(1, len(signal))]
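
# Pre-emphasis applies the first-order high-pass filter y[n] = x[n] - 0.97 * x[n-1];
# 0.97 is the conventional coefficient for speech. As a worked example, two neighbouring
# samples 0.50 and 0.52 map to 0.52 - 0.97 * 0.50 = 0.035.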

class FixedFrameStrategy(FrameStrategy):
    """固定分帧策略(策略模式)"""
    def __init__(self, frame_ms=25, overlap_ratio=0.5):
        self.frame_ms = frame_ms
        self.overlap_ratio = overlap_ratio
    
    def frame(self, signal, sample_rate):
        frame_length = int(sample_rate * self.frame_ms / 1000)
        step = int(frame_length * (1 - self.overlap_ratio))
        frames = []
        for start in range(0, len(signal) - frame_length, step):
            frames.append(signal[start:start+frame_length])
        return frames
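
# With the defaults above, at a 44.1 kHz sample rate a 25 ms frame is
# int(44100 * 25 / 1000) = 1102 samples and the hop between frames is
# int(1102 * (1 - 0.5)) = 551 samples, i.e. consecutive frames overlap by half.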

class AutoCorrelationDetector(PitchDetector):
    """Autocorrelation-based pitch detection (core algorithm)."""
    def detect(self, frames, sample_rate):
        pitches = []
        acf_values = []  # per-frame autocorrelation values, kept for visualization
        
        for frame in frames:
            # Compute the autocorrelation function
            acf = self._autocorrelation(frame)
            acf_values.append(acf)
            # Locate the pitch peak
            pitch = self._find_pitch(acf, sample_rate)
            pitches.append(pitch)
            
        return pitches, acf_values  # pitch estimates and autocorrelation values

    def _autocorrelation(self, frame):
        n = len(frame)
        acf = []
        for lag in range(n//2):  # only compute lags up to half the frame length
            total = 0.0
            for i in range(n - lag):
                total += frame[i] * frame[i + lag]
            acf.append(total)
        return acf
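
    # For a frame x of length N, the loop above computes the (un-normalized)
    # autocorrelation r[lag] = sum_{i=0}^{N-lag-1} x[i] * x[i+lag]; a periodic signal
    # produces local maxima of r at multiples of its period, which is what the peak
    # search in _find_pitch exploits.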

    def _find_pitch(self, acf, sample_rate):
        # Skip the first 10 lags so the search does not simply pick the lag-0 maximum
        if len(acf) < 20:  # make sure there are enough points beyond the skipped region
            return 0
            
        search_range = acf[10:]
        if not search_range:
            return 0
        
        # A silent or flat frame has no positive autocorrelation peak: report 0 (unvoiced)
        if max(search_range) <= 0.0:
            return 0
        
        # Locate the largest peak
        max_index = search_range.index(max(search_range)) + 10
        
        # Parabolic interpolation for sub-sample accuracy
        if 1 < max_index < len(acf)-1:
            alpha = acf[max_index-1]
            beta = acf[max_index]
            gamma = acf[max_index+1]
            
            # Avoid division by zero
            denominator = alpha - 2*beta + gamma
            if abs(denominator) < 1e-10:  # guard against a near-zero denominator
                peak_index = max_index
            else:
                delta = 0.5 * (alpha - gamma) / denominator
                peak_index = max_index + delta
        else:
            peak_index = max_index
        
        # Convert the peak lag into a frequency
        if peak_index > 0:
            return sample_rate / peak_index
        return 0
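
    # The parabolic-interpolation step above refines the integer peak position by
    # delta = (alpha - gamma) / (2 * (alpha - 2*beta + gamma)), where alpha, beta, gamma
    # are the autocorrelation values around the peak, and the pitch then follows as
    # sample_rate / peak_index: e.g. a peak near lag 100.2 at 44.1 kHz gives
    # 44100 / 100.2 ≈ 440 Hz. A common refinement (not done here) is to restrict the lag
    # search to the plausible pitch range, e.g. lags from sample_rate/800 to sample_rate/50.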

class CepstrumDetector(PitchDetector):
    """Cepstrum-based pitch detection (alternative strategy)."""
    def detect(self, frames, sample_rate):
        # Works like the autocorrelation method, but searches for peaks in the cepstrum.
        # The concrete implementation is omitted here; this stub returns placeholder values.
        return [0] * len(frames), [[]] * len(frames)
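
# Illustrative outline only (the omitted implementation is not part of the original post):
# a cepstrum-based detector would typically, per frame,
#   1. compute the magnitude spectrum S[k] = |DFT(frame)[k]|,
#   2. take the real cepstrum c[q] = IDFT(log(S))[q],
#   3. pick the largest cepstral peak inside the plausible pitch-lag range, and
#   4. return sample_rate / q_peak as the pitch estimate.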

# ====================== Visualization ======================
class MatplotlibVisualizer(Visualizer):
    """Visualization with matplotlib (requires the matplotlib library)."""
    def __init__(self):
        # Font setup (originally added so CJK axis labels render correctly)
        self._configure_matplotlib()
    
    def _configure_matplotlib(self):
        """Configure matplotlib fonts; SimHei etc. are kept for CJK glyph support."""
        plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']
        plt.rcParams['axes.unicode_minus'] = False
    
    def visualize(self, analysis_data):
        # Unpack the analysis results
        original_signal = analysis_data['original_signal']
        processed_signal = analysis_data['processed_signal']
        frames = analysis_data['frames']
        windowed_frames = analysis_data['windowed_frames']
        pitches = analysis_data['pitches']
        acf_values = analysis_data['acf_values']
        sample_rate = analysis_data['sample_rate']
        
        # Create the figure
        fig = plt.figure(figsize=(15, 12))
        gs = GridSpec(4, 2, figure=fig)
        
        # 1. Original signal vs. preprocessed signal
        ax1 = plt.subplot(gs[0, :])
        time_original = [i / sample_rate for i in range(len(original_signal))]
        time_processed = [i / sample_rate for i in range(len(processed_signal))]
        ax1.plot(time_original, original_signal, label='Original signal')
        ax1.plot(time_processed, processed_signal, label='Preprocessed signal', alpha=0.7)
        ax1.set_title('Original vs. preprocessed signal')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Amplitude')
        ax1.legend()
        ax1.grid(True)
        
        # 2. First frame: raw vs. windowed
        ax2 = plt.subplot(gs[1, 0])
        frame_index = 0
        frame_time = [i / sample_rate * 1000 for i in range(len(frames[frame_index]))]  # milliseconds
        ax2.plot(frame_time, frames[frame_index], label='Raw frame')
        ax2.plot(frame_time, windowed_frames[frame_index], label='Windowed frame')
        ax2.set_title(f'Frame {frame_index+1} (raw vs. windowed)')
        ax2.set_xlabel('Time (ms)')
        ax2.set_ylabel('Amplitude')
        ax2.legend()
        ax2.grid(True)
        
        # 3. Autocorrelation of the first frame
        ax3 = plt.subplot(gs[1, 1])
        lags = [i * 1000 / sample_rate for i in range(len(acf_values[frame_index]))]  # milliseconds
        ax3.plot(lags, acf_values[frame_index])
        ax3.set_title(f'Autocorrelation of frame {frame_index+1}')
        ax3.set_xlabel('Lag (ms)')
        ax3.set_ylabel('Autocorrelation')
        ax3.grid(True)
        
        # Mark the pitch period
        pitch = pitches[frame_index]
        if pitch > 0:
            period = 1000 / pitch  # period in milliseconds
            ax3.axvline(period, color='r', linestyle='--', 
                         label=f'Pitch period: {period:.2f} ms')
            ax3.legend()
        
        # 4. Pitch track
        ax4 = plt.subplot(gs[2, :])
        frame_times = [i * (len(frames[0]) * (1-0.5) / sample_rate) 
                      for i in range(len(pitches))]  # frame start times (assumes the default 50% overlap)
        ax4.plot(frame_times, pitches)
        ax4.set_title('Pitch detection result')
        ax4.set_xlabel('Time (s)')
        ax4.set_ylabel('Frequency (Hz)')
        ax4.grid(True)
        
        # 5. Spectrum of the analyzed frame
        ax5 = plt.subplot(gs[3, 0])
        frame_to_show = windowed_frames[frame_index]
        n = len(frame_to_show)
        # Compute the spectrum with a plain DFT
        spectrum = [abs(self._dft(frame_to_show, k)) for k in range(n//2)]
        freqs = [k * sample_rate / n for k in range(n//2)]
        ax5.plot(freqs, spectrum)
        ax5.set_title(f'Spectrum of frame {frame_index+1}')
        ax5.set_xlabel('Frequency (Hz)')
        ax5.set_ylabel('Magnitude')
        ax5.set_xlim(0, 2000)
        ax5.grid(True)
        
        # Mark the fundamental and its harmonics
        if pitch > 0:
            ax5.axvline(pitch, color='r', linestyle='--', label=f'Fundamental: {pitch:.1f} Hz')
            # Mark harmonics 2 through 5
            for i in range(2, 6):
                harmonic = i * pitch
                if harmonic < 2000:
                    ax5.axvline(harmonic, color='g', linestyle=':', 
                                label=f'Harmonic {i}' if i==2 else None)
            ax5.legend()
        
        # 6. Spectrum of the whole signal
        ax6 = plt.subplot(gs[3, 1])
        # DFT over the whole signal; only the bins up to the 2 kHz display limit are
        # computed, since the pure-Python DFT is O(N^2) and the axis is clipped anyway
        n_full = len(original_signal)
        max_bin = min(n_full//2, int(2000 * n_full / sample_rate) + 1)
        full_spectrum = [abs(self._dft(original_signal, k)) for k in range(max_bin)]
        full_freqs = [k * sample_rate / n_full for k in range(max_bin)]
        ax6.plot(full_freqs, full_spectrum)
        ax6.set_title('Spectrum of the whole signal')
        ax6.set_xlabel('Frequency (Hz)')
        ax6.set_ylabel('Magnitude')
        ax6.set_xlim(0, 2000)
        ax6.grid(True)
        
        plt.tight_layout()
        plt.show()
    
    def _dft(self, x, k):
        """离散傅里叶变换 (不使用第三方库)"""
        N = len(x)
        real = 0.0
        imag = 0.0
        for n in range(N):
            angle = 2 * math.pi * k * n / N
            real += x[n] * math.cos(angle)
            imag -= x[n] * math.sin(angle)
        return math.sqrt(real*real + imag*imag) / N
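
    # The value returned above is the magnitude of bin k of an N-point DFT scaled by 1/N,
    # so adjacent bins are sample_rate / N apart: for a 1102-sample frame at 44.1 kHz the
    # frequency resolution of the per-frame spectrum is roughly 44100 / 1102 ≈ 40 Hz.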

class TextVisualizer(Visualizer):
    """Text-only visualizer (for environments without a GUI)."""
    def visualize(self, analysis_data):
        pitches = analysis_data['pitches']
        print("\nPitch detection summary:")
        print(f"Frames analyzed: {len(pitches)}")
        
        # Statistics over the valid pitch estimates
        valid_pitches = [p for p in pitches if 50 < p < 800]
        if valid_pitches:
            avg_pitch = sum(valid_pitches) / len(valid_pitches)
            min_pitch = min(valid_pitches)
            max_pitch = max(valid_pitches)
            print(f"Mean pitch: {avg_pitch:.2f} Hz")
            print(f"Minimum pitch: {min_pitch:.2f} Hz")
            print(f"Maximum pitch: {max_pitch:.2f} Hz")
        else:
            print("No valid pitch detected")
        
        # Print the pitch of the first 5 frames
        print("\nPitch of the first 5 frames:")
        for i, pitch in enumerate(pitches[:5]):
            print(f"Frame {i+1}: {pitch:.2f} Hz")

# ====================== Utility functions ======================
def normalize_signal(signal):
    """Normalize a signal to the range [-1, 1]."""
    max_val = max(abs(x) for x in signal)
    return [x / max_val for x in signal] if max_val > 0 else signal

def hamming_window(frame):
    """Apply a Hamming window to reduce spectral leakage."""
    N = len(frame)
    return [frame[i] * (0.54 - 0.46 * math.cos(2 * math.pi * i / (N - 1))) 
            for i in range(N)]

# ====================== Core pipeline ======================
class PitchAnalyzer:
    """Main pitch-analysis pipeline (facade pattern)."""
    def __init__(self, config):
        self.preprocessor = ProcessorFactory.create_preprocessor(config["preprocess"])
        self.framer = ProcessorFactory.create_framer(config["frame"])
        self.detector = ProcessorFactory.create_detector(config["detect"])
        self.visualizer = ProcessorFactory.create_visualizer(config.get("visualize", "text"))
        self.config = config

    def analyze(self, signal, sample_rate, visualize=False):
        # Keep a copy of the original signal for visualization
        original_signal = signal.copy()
        
        # 1. Preprocessing
        processed = self.preprocessor.process(signal)
        
        # 2. Framing
        frames = self.framer.frame(processed, sample_rate)
        
        # 3. Windowing
        windowed_frames = [hamming_window(frame) for frame in frames]
        
        # 4. Pitch detection
        pitches, acf_values = self.detector.detect(windowed_frames, sample_rate)
        
        # Collect data for visualization
        analysis_data = {
            'original_signal': original_signal,
            'processed_signal': processed,
            'frames': frames,
            'windowed_frames': windowed_frames,
            'pitches': pitches,
            'acf_values': acf_values,
            'sample_rate': sample_rate
        }
        
        # 5. Visualization
        if visualize:
            self.visualizer.visualize(analysis_data)
        
        return pitches, analysis_data

# ====================== File handling ======================
def read_wav_file(filename):
    """Read a WAV file and return (samples, sample_rate); assumes 16-bit PCM."""
    with wave.open(filename, 'rb') as wav:
        n_channels = wav.getnchannels()
        n_frames = wav.getnframes()
        sample_rate = wav.getframerate()
        data = wav.readframes(n_frames)
        # Interpret the bytes as signed 16-bit samples and scale to [-1, 1)
        samples = array.array('h', data)
        # Keep only the first channel if the file is not mono (channels are interleaved)
        if n_channels > 1:
            samples = samples[::n_channels]
        return [s / 32768.0 for s in samples], sample_rate

# ====================== Test cases ======================
def test_pitch_analysis(show_plot=True):
    """Test pitch detection on a synthetic signal."""
    # Generate a 440 Hz sine wave (concert pitch A4)
    sample_rate = 44100
    duration = 0.5  # 0.5 seconds
    freq = 440.0
    t = [i / sample_rate for i in range(int(sample_rate * duration))]
    signal = [math.sin(2 * math.pi * freq * t_i) for t_i in t]
    
    # Add a 3 kHz interfering tone so the signal is less ideal
    signal = [s + 0.1 * math.sin(2 * math.pi * 3000 * t_i) for s, t_i in zip(signal, t)]
    
    # Configure the analyzer
    config = {
        "preprocess": "preemphasis",
        "frame": "fixed",
        "detect": "autocorrelation",
        "visualize": "matplotlib" if show_plot else "text"
    }
    analyzer = PitchAnalyzer(config)
    
    # Run the analysis (optionally with visualization)
    pitches, analysis_data = analyzer.analyze(signal, sample_rate, visualize=show_plot)
    
    # Verify the result (average over the valid frames)
    valid_pitches = [p for p in pitches if 50 < p < 800]
    if valid_pitches:
        avg_pitch = sum(valid_pitches) / len(valid_pitches)
        # Allow a +/-5 Hz tolerance
        assert 435 < avg_pitch < 445, f"Detection failed: expected 440 Hz, got {avg_pitch:.2f} Hz"
        print(f"Test passed! Detected pitch: {avg_pitch:.2f} Hz (expected 440 Hz)")
    else:
        print("Test failed: no valid pitch detected")

def test_zero_division():
    """Test that degenerate (all-zero) input is handled without division-by-zero errors."""
    # An all-zero signal exercises the guards for flat frames and zero denominators
    sample_rate = 44100
    signal = [0.0] * sample_rate  # 1 second of silence
    
    # Configure the analyzer
    config = {
        "preprocess": "preemphasis",
        "frame": "fixed",
        "detect": "autocorrelation",
        "visualize": "text"
    }
    analyzer = PitchAnalyzer(config)
    
    # Run the analysis
    pitches, _ = analyzer.analyze(signal, sample_rate)
    
    # Every frame of silence should report a pitch of (essentially) zero
    assert all(p < 1e-6 for p in pitches), "All-zero signal was not handled correctly"
    print("Zero-division test passed!")

def test_real_audio(filename=r"E:\temp\sample.wav"):
    """Test on a real audio file."""
    try:
        signal, sample_rate = read_wav_file(filename)
    except FileNotFoundError:
        print(f"File {filename} not found, falling back to the synthetic test signal")
        test_pitch_analysis(show_plot=True)
        return
    
    # Use only the first second of audio
    if len(signal) > sample_rate:
        signal = signal[:sample_rate]
    
    # Configure the analyzer
    config = {
        "preprocess": "preemphasis",
        "frame": "fixed",
        "detect": "autocorrelation",
        "visualize": "matplotlib"
    }
    analyzer = PitchAnalyzer(config)
    
    # Run the analysis with visualization
    analyzer.analyze(signal, sample_rate, visualize=True)

if __name__ == "__main__":
    #print("1. Test on a synthetic signal")
    #test_pitch_analysis(show_plot=True)
    
    #print("\n2. Test the division-by-zero handling")
    #test_zero_division()
    
    print("\n3. Test on real audio (requires a sample.wav file)")
    test_real_audio()
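
    # When no display is available (e.g. when running headless), the same checks can be
    # driven through the TextVisualizer instead, for example:
    #   test_pitch_analysis(show_plot=False)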
