HarvardX TinyML Notes 2 (Extra 3: Keyword Spotting)

1 Principles

Overall flow

Keyword spotting faces several practical constraints: latency, accuracy, privacy/security, power consumption, and so on.

The overall workflow is much like any other deep-learning project: data collection, data preprocessing, model design, training, and evaluation.

Overall process:

(flow diagram from the edX course page)

Audio processing pipeline:

Take single-channel 16 kHz audio and a 1-second window as an example:

(a) Preprocessing

  • Remove DC offset / normalize; optional pre-emphasis y[n]=x[n]-\alpha x[n-1] (\alpha≈0.97) to boost the high frequencies.

  • Optional removal of silent segments (VAD).

(b) Framing + windowing

  • Frame length 25 ms (400 samples), frame shift 10 ms (160 samples), Hamming window w[n]

  • Number of frames in 1 second =\left\lfloor\frac{16000-400}{160}\right\rfloor+1=98

(c) STFT and power spectrum

  • Apply an FFT to each frame (typically 512 or 256 points) to get the spectrum X(f,t);

  • Power P(f,t)=\frac{|X(f,t)|^2}{N}, where N is the FFT length

(d) Mel filterbank

  • Use 40 triangular mel filters H_k(f) to pool the linear-frequency power into mel bands: E_k(t)=\sum_f P(f,t)\,H_k(f)

  • Take the logarithm (dB or natural log) to compress the dynamic range:

  • S_k(t)=\log(E_k(t)+\epsilon)

(e) The resulting "image" / feature tensor

  • You get a (number of bands × number of frames) = (40 × 98) matrix, i.e. a "striped" image over time

  • Normalize (per-band mean/variance normalization, or a global min–max to [0,1]);

  • This is the grayscale image fed to the CNN; you can also stack the **first/second deltas (Δ/ΔΔ)** as 3 channels → shape (40×98×3)

Optional: MFCCs apply a DCT to the \{S_k\} from the previous step and keep the first 13 coefficients (+Δ/ΔΔ). They are more compact, but recent KWS work more often uses the log-mel spectrogram directly.
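Below is a minimal NumPy sketch of steps (b)–(d), my own illustration rather than course code; it assumes librosa is available (used only for the mel filterbank matrix), and the function name and parameter values are just for demonstration.

import numpy as np
import librosa  # used here only for the mel filterbank matrix

def logmel_features(x, sr=16000, frame_len=400, hop=160, n_fft=512, n_mels=40, eps=1e-6):
    # (b) framing + Hamming window
    n_frames = (len(x) - frame_len) // hop + 1                        # 98 for a 1-second clip
    window = np.hamming(frame_len)
    frames = np.stack([x[i*hop:i*hop+frame_len] * window for i in range(n_frames)])
    # (c) FFT and power spectrum P(f, t)
    power = np.abs(np.fft.rfft(frames, n=n_fft)) ** 2 / n_fft         # (98, 257)
    # (d) mel filterbank + log compression -> S_k(t)
    mel_fb = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)   # (40, 257)
    return np.log(power @ mel_fb.T + eps).T                           # (40, 98): bands x frames

x = np.random.randn(16000).astype(np.float32)                         # stands in for 1 s of 16 kHz audio
print(logmel_features(x).shape)                                       # (40, 98)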

The first step is analog-to-digital conversion (sampling the microphone signal).

The second step converts the sound signal into a frequency-domain signal using the FFT (Fast Fourier Transform). (Incidentally, my undergraduate thesis was on exactly this topic.) In short, the time-domain view of the raw audio tells you "when the sound occurs and how loud it is"; after the FFT you can tell "what the sound is and what components it contains", which is the key to most practical audio-analysis problems.
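As a tiny illustration (my own, not from the notebook): the FFT magnitude of a pure tone peaks at the tone's frequency, which is exactly the "what is in the sound" information the time-domain view hides.

import numpy as np

sr = 16000
t = np.arange(sr) / sr                      # 1 second of samples
x = np.sin(2 * np.pi * 440 * t)             # a 440 Hz tone
spectrum = np.abs(np.fft.rfft(x))           # single-sided magnitude spectrum
freqs = np.fft.rfftfreq(len(x), d=1/sr)     # frequency axis in Hz
print(freqs[np.argmax(spectrum)])           # -> 440.0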

The third step converts the frequency-domain signal into a spectrogram.

This conversion uses mel (MEL) filtering. The mel filterbank compresses the high-frequency detail the ear is insensitive to (the filters get progressively wider at higher frequencies) while keeping the perceptually important low-frequency resolution, so the spectrum is closer to how humans actually hear sound, which makes the later recognition step easier.
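For reference, the commonly used (HTK-style) mapping from frequency f in Hz to mel is m = 2595\,\log_{10}\!\left(1+\frac{f}{700}\right): roughly linear below about 1 kHz and logarithmic above it, which is why the higher-frequency filters are wider.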

MFCC (Mel-Frequency Cepstral Coefficients) is one of the most widely used feature-extraction techniques in speech/audio signal processing. In essence, it is a compact, robust frequency-domain feature derived from the signal based on the characteristics of the human auditory system. (Even though CNNs, Transformers and the like can now learn features directly from raw audio, the simplicity and robustness of MFCCs keep them hard to replace in lightweight scenarios.)
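Continuing the sketch above (again my own illustration, not course code): MFCCs are simply a DCT-II over each frame's log-mel bands, keeping the first 13 coefficients.

from scipy.fftpack import dct

logmel = logmel_features(x)                                  # (40, 98) from the earlier sketch
mfcc = dct(logmel.T, type=2, axis=1, norm='ortho')[:, :13]   # (98, 13): 13 coefficients per frame
print(mfcc.shape)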

A quick summary:

We're actually completely converting it into a different representation. We went from an audio time-series signal into a frequency signal. And we're looking at the frequency as a picture.

That sentence sums up the whole pipeline: first the audio signal is converted into a frequency signal with the FFT, then the frequency signal is turned into an image you can look at, and from there the same techniques as image recognition can be used.

2 Other considerations

Building the dataset is probably still the most labor-intensive part; it has to reflect the final usage scenario and the target users.

Additional audio processing, including normalization, noise reduction, and so on.

When using the model, pick the best operating point (detection threshold) for the intended scenario.

The final deployment is two-stage: the first stage uses a TinyML model to wake the system up, and the second stage uses a larger NN to improve accuracy; a rough sketch of the idea follows.
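A minimal sketch of such a cascade; tiny_score, big_score and the thresholds are hypothetical placeholders, not course code.

def keyword_detected(frame, tiny_score, big_score, t1=0.6, t2=0.9):
    # Stage 1: cheap, always-on TinyML model; its threshold t1 is the operating
    # point (lower it to miss fewer wake words, at the cost of waking stage 2 more often).
    if tiny_score(frame) < t1:
        return False
    # Stage 2: wake the larger NN only for candidate frames, trading a little
    # latency and power for better accuracy.
    return big_score(frame) >= t2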

3 Reading the code

3.1 Spectrograms and MFCCs

https://colab.research.google.com/github/tinyMLx/colabs/blob/master/3-5-10-SpectrogramsMFCCs.ipynb

This example is mostly code, so I won't list all of it in detail (too long; it would just pad the word count). I'll highlight a few key snippets.

The required packages: besides TensorFlow, there are ffmpeg-python, tensorflow-io and python_speech_features.

!pip install ffmpeg-python &> 0
!pip install tensorflow-io &> 0
!pip install python_speech_features &> 0
print("Packages Installed")

Next, a web interface for recording audio in the browser is defined, which is quite interesting.

<script>
var my_div = document.createElement("DIV");
var my_p = document.createElement("P");
var my_btn = document.createElement("BUTTON");
var t = document.createTextNode("Press to start recording");

my_btn.appendChild(t);
//my_p.appendChild(my_btn);
my_div.appendChild(my_btn);
document.body.appendChild(my_div);

var base64data = 0;
var reader;
var recorder, gumStream;
var recordButton = my_btn;

var handleSuccess = function(stream) {
  gumStream = stream;
  var options = {
    //bitsPerSecond: 8000, //chrome seems to ignore, always 48k
    mimeType : 'audio/webm;codecs=opus'
    //mimeType : 'audio/webm;codecs=pcm'
  };            
  //recorder = new MediaRecorder(stream, options);
  recorder = new MediaRecorder(stream);
  recorder.ondataavailable = function(e) {            
    var url = URL.createObjectURL(e.data);
    var preview = document.createElement('audio');
    preview.controls = true;
    preview.src = url;
    document.body.appendChild(preview);

    reader = new FileReader();
    reader.readAsDataURL(e.data); 
    reader.onloadend = function() {
      base64data = reader.result;
      //console.log("Inside FileReader:" + base64data);
    }
  };
  recorder.start();
  };

recordButton.innerText = "Recording... press to stop";

navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);


function toggleRecording() {
  if (recorder && recorder.state == "recording") {
      recorder.stop();
      gumStream.getAudioTracks()[0].stop();
      recordButton.innerText = "Saving the recording... pls wait!"
  }
}

// https://stackoverflow.com/a/951057
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

var data = new Promise(resolve=>{
//recordButton.addEventListener("click", toggleRecording);
recordButton.onclick = ()=>{
toggleRecording()

sleep(2000).then(() => {
  // wait 2000ms for the data to be available...
  // ideally this should use something like await...
  //console.log("Inside data:" + base64data)
  resolve(base64data.toString())

});

}
});
      
</script>
AUDIO_HTML = """
... (the JavaScript shown above) ...
"""

def get_audio():
  display(HTML(AUDIO_HTML))
  data = eval_js("data")
  binary = b64decode(data.split(',')[1])
  
  process = (ffmpeg
    .input('pipe:0')
    .output('pipe:1', format='wav')
    .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True, quiet=True, overwrite_output=True)
  )
  output, err = process.communicate(input=binary)
  
  riff_chunk_size = len(output) - 8
  # Break up the chunk size into four bytes, held in b.
  q = riff_chunk_size
  b = []
  for i in range(4):
      q, r = divmod(q, 256)
      b.append(r)

  # Replace bytes 4:8 in proc.stdout with the actual size of the RIFF chunk.
  riff = output[:4] + bytes(b) + output[8:]

  sr, audio = wav_read(io.BytesIO(riff))

  return audio, sr

print("Chrome Audio Recorder Defined")

Roughly: the JS creates a recorder with recorder = new MediaRecorder(stream), then recorder.start() begins recording. Once recording finishes, the data is handed back to Python by resolving a Promise.

var data = new Promise(resolve => {
  recordButton.onclick = () => {
    toggleRecording();
    // wait 2 s for ondataavailable / FileReader to finish, then hand the
    // base64-encoded audio back to Python
    sleep(2000).then(() => {
      resolve(base64data.toString());
    });
  };
});
On the Python side, display(HTML(AUDIO_HTML)) first runs the JS, then eval_js fetches the data back. I had never seen this API before; it is Colab-specific (google.colab.output.eval_js). After that, ffmpeg converts the data to WAV, and finally the WAV's RIFF header is patched with the correct chunk size.

process = (ffmpeg
    .input('pipe:0')
    .output('pipe:1', format='wav')
    .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True, quiet=True, overwrite_output=True)
  )

The notebook then calls this function four times to record: "yes" (loud, quiet) and "no" (loud, quiet).

It then plots the raw waveforms of the four recordings. (You can actually already see some differences.)

Next, each signal is put through a fast Fourier transform, using np.fft.fft.

#Adapted from https://makersportal.com/blog/2018/9/13/audio-processing-in-python-part-i-sampling-and-the-fast-fourier-transform
# compute the FFT and take the single-sided spectrum only and remove imaginary part
ft_audio_yes_loud = np.abs(2*np.fft.fft(audio_yes_loud))
ft_audio_yes_quiet = np.abs(2*np.fft.fft(audio_yes_quiet))
ft_audio_no_loud = np.abs(2*np.fft.fft(audio_no_loud))
ft_audio_no_quiet = np.abs(2*np.fft.fft(audio_no_quiet))

Honestly, I find the transformed plots even less readable than the raw waveforms.
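For what it's worth, the snippet above keeps the full (mirrored) FFT output; to read the plot in Hz you would normally keep only the first half of the bins and attach a frequency axis, e.g. (my own sketch, not notebook code):

half = len(audio_yes_loud) // 2
freqs = np.fft.fftfreq(len(audio_yes_loud), d=1/sr_yes_loud)[:half]   # frequency axis in Hz
magnitude = np.abs(np.fft.fft(audio_yes_loud))[:half]                 # single-sided magnitude
# plt.plot(freqs, magnitude) would then show energy versus frequency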

Next, import tensorflow_io as tfio is used to turn the signals into spectrograms.

# Convert to spectrogram and display
# adapted from https://aruno14.medium.com/comparaison-of-audio-representation-in-tensorflow-b6c33a83d77f
spectrogram_yes_loud = tfio.audio.spectrogram(audio_yes_loud/1.0, nfft=2048, window=len(audio_yes_loud), stride=int(sr_yes_loud * 0.008))
spectrogram_yes_quiet = tfio.audio.spectrogram(audio_yes_quiet/1.0, nfft=2048, window=len(audio_yes_quiet), stride=int(sr_yes_quiet * 0.008))
spectrogram_no_loud = tfio.audio.spectrogram(audio_no_loud/1.0, nfft=2048, window=len(audio_no_loud), stride=int(sr_no_loud * 0.008))
spectrogram_no_quiet = tfio.audio.spectrogram(audio_no_quiet/1.0, nfft=2048, window=len(audio_no_quiet), stride=int(sr_no_quiet * 0.008))

Then they are displayed. This is much more direct; a lot of the structure is visible to the naked eye.

Finally comes the "MFCC" step, done with import librosa: librosa.feature.melspectrogram applies the mel filterbank and librosa.power_to_db converts the result to decibels. (Strictly speaking, since no DCT is applied, this produces a log-mel spectrogram rather than true MFCCs, but the notebook labels it MFCC.)

# Convert to MFCC using the Mel Scale
# adapted from: https://towardsdatascience.com/getting-to-know-the-mel-spectrogram-31bca3e2d9d0
mfcc_yes_loud = librosa.power_to_db(librosa.feature.melspectrogram(
    y=np.float32(audio_yes_loud), sr=sr_yes_loud, n_fft=2048, hop_length=512, n_mels=128), ref=np.max)
mfcc_yes_quiet = librosa.power_to_db(librosa.feature.melspectrogram(
    y=np.float32(audio_yes_quiet), sr=sr_yes_quiet, n_fft=2048, hop_length=512, n_mels=128), ref=np.max)
mfcc_no_loud = librosa.power_to_db(librosa.feature.melspectrogram(
    y=np.float32(audio_no_loud), sr=sr_no_loud, n_fft=2048, hop_length=512, n_mels=128), ref=np.max)
mfcc_no_quiet = librosa.power_to_db(librosa.feature.melspectrogram(
    y=np.float32(audio_no_quiet), sr=sr_no_quiet, n_fft=2048, hop_length=512, n_mels=128), ref=np.max)

The resulting plots (not reproduced here) do look much clearer.

3.2 Keyword Spotting

https://colab.research.google.com/github/tinyMLx/colabs/blob/master/3-5-13-PretrainedModel.ipynb

Course material: Course | edX

This exercise covers essentially the whole pipeline.

3.2.1 Generating the pre-trained model

TensorFlow 2.14 is still required, but this time the course staff seem to have noticed the problem with the Colab environment, so the sources are fetched and set up manually.

!wget https://github.com/tensorflow/tensorflow/archive/v2.14.0.zip
!unzip v2.14.0.zip &> 0
!mv tensorflow-2.14.0/ tensorflow

When importing, the speech_commands modules have to be added to sys.path manually.

# We add this path so we can import the speech processing modules.
sys.path.append("/content/tensorflow/tensorflow/examples/speech_commands/")

This downloads the pre-trained model archive, roughly 2021 KB.

!curl -O "https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_micro_train_2020_05_10.tgz"
!tar xzf speech_micro_train_2020_05_10.tgz
TOTAL_STEPS = 15000 # used to identify which checkpoint file

Next, freeze.py is used to freeze the model.

!rm -rf {SAVED_MODEL}
!python tensorflow/tensorflow/examples/speech_commands/freeze.py \
--wanted_words=$WANTED_WORDS \
--window_stride_ms=$WINDOW_STRIDE \
--preprocess=$PREPROCESS \
--model_architecture=$MODEL_ARCHITECTURE \
--start_checkpoint=$TRAIN_DIR$MODEL_ARCHITECTURE'.ckpt-'{TOTAL_STEPS} \
--save_format=saved_model \
--output_file={SAVED_MODEL}

In machine learning (especially deep learning), "freezing a model" usually means fixing some or all of the model's parameters so they are no longer updated during later training. It is common in transfer learning, fine-tuning and multi-stage training, as a way to balance reusing existing knowledge against adapting to a new task. Here, though, freeze.py does something slightly different: it takes the training checkpoint and exports an inference-only graph in which the trained weights (and the feature-extraction front end) are baked in as constants, ready for conversion and deployment.

Generate the TFLite models.

with tf.Session() as sess:
# with tf.compat.v1.Session() as sess: #replaces the above line for use with TF2.x
  float_converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL)
  float_tflite_model = float_converter.convert()
  float_tflite_model_size = open(FLOAT_MODEL_TFLITE, "wb").write(float_tflite_model)
  print("Float model is %d bytes" % float_tflite_model_size)

  converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converter.inference_input_type = tf.lite.constants.INT8
  # converter.inference_input_type = tf.compat.v1.lite.constants.INT8 #replaces the above line for use with TF2.x   
  converter.inference_output_type = tf.lite.constants.INT8
  # converter.inference_output_type = tf.compat.v1.lite.constants.INT8 #replaces the above line for use with TF2.x
  def representative_dataset_gen():
    for i in range(100):
      data, _ = audio_processor.get_data(1, i*1, model_settings,
                                         BACKGROUND_FREQUENCY, 
                                         BACKGROUND_VOLUME_RANGE,
                                         TIME_SHIFT_MS,
                                         'testing',
                                         sess)
      flattened_data = np.array(data.flatten(), dtype=np.float32).reshape(1, 1960)
      yield [flattened_data]
  converter.representative_dataset = representative_dataset_gen
  tflite_model = converter.convert()
  tflite_model_size = open(MODEL_TFLITE, "wb").write(tflite_model)
  print("Quantized model is %d bytes" % tflite_model_size)

The final quantized model is about 19 KB, roughly a quarter of the float model's size (consistent with float32 weights becoming int8).

Float model is 68392 bytes
Quantized model is 19160 bytes

Testing accuracy

The test helper function is as follows:

# Helper function to run inference
def run_tflite_inference_testSet(tflite_model_path, model_type="Float"):
  #
  # Load test data
  #
  np.random.seed(0) # set random seed for reproducible test results.
  with tf.Session() as sess:
  # with tf.compat.v1.Session() as sess: #replaces the above line for use with TF2.x
    test_data, test_labels = audio_processor.get_data(
        -1, 0, model_settings, BACKGROUND_FREQUENCY, BACKGROUND_VOLUME_RANGE,
        TIME_SHIFT_MS, 'testing', sess)
  test_data = np.expand_dims(test_data, axis=1).astype(np.float32)

  #
  # Initialize the interpreter
  #
  interpreter = tf.lite.Interpreter(tflite_model_path)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]
  
  #
  # For quantized models, manually quantize the input data from float to integer
  #
  if model_type == "Quantized":
    input_scale, input_zero_point = input_details["quantization"]
    test_data = test_data / input_scale + input_zero_point
    test_data = test_data.astype(input_details["dtype"])

  #
  # Evaluate the predictions
  #
  correct_predictions = 0
  for i in range(len(test_data)):
    interpreter.set_tensor(input_details["index"], test_data[i])
    interpreter.invoke()
    output = interpreter.get_tensor(output_details["index"])[0]
    top_prediction = output.argmax()
    correct_predictions += (top_prediction == test_labels[i])

  print('%s model accuracy is %f%% (Number of test samples=%d)' % (
      model_type, (correct_predictions * 100) / len(test_data), len(test_data)))
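Presumably it is then called once per model, along the lines of (using the file names defined earlier in the notebook):

run_tflite_inference_testSet(FLOAT_MODEL_TFLITE)
run_tflite_inference_testSet(MODEL_TFLITE, model_type='Quantized')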

In the end, quantization turned out not to change the accuracy at all...

Float model accuracy is 90.857605% (Number of test samples=1236)
Quantized model accuracy is 90.857605% (Number of test samples=1236)

3.2.2 Testing with audio data

The data source:

from IPython.display import HTML, Audio
!wget --no-check-certificate --content-disposition https://github.com/tinyMLx/colabs/blob/master/yes_no.pkl?raw=true
print("Wait a minute for the file to sync in the Colab and then run the next cell!")

Read the audio data, run the FFT/MFCC preprocessing, and finally feed the result into the model.

# Helper function to run inference (on a single input this time)
# Note: this also includes additional manual pre-processing
TF_SESS = tf.compat.v1.InteractiveSession()
def run_tflite_inference_singleFile(tflite_model_path, custom_audio, sr_custom_audio, model_type="Float"):
  #
  # Preprocess the sample to get the features we pass to the model
  #
  # First re-sample to the needed rate (and convert to mono if needed)
  custom_audio_resampled = librosa.resample(librosa.to_mono(np.float64(custom_audio)), orig_sr = sr_custom_audio, target_sr = SAMPLE_RATE)
  # Then extract the loudest one second
  scipy.io.wavfile.write('custom_audio.wav', SAMPLE_RATE, np.int16(custom_audio_resampled))
  !/tmp/extract_loudest_section/gen/bin/extract_loudest_section custom_audio.wav ./trimmed
  # Finally pass it through the TFLiteMicro preprocessor to produce the 
  # spectrogram/MFCC input that the model expects
  custom_model_settings = models.prepare_model_settings(
      0, SAMPLE_RATE, CLIP_DURATION_MS, WINDOW_SIZE_MS,
      WINDOW_STRIDE, FEATURE_BIN_COUNT, PREPROCESS)
  custom_audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                                    model_settings, None)
  custom_audio_preprocessed = custom_audio_processor.get_features_for_wav(
                                        'trimmed/custom_audio.wav', model_settings, TF_SESS)
  # Reshape the output into a 1,1960 matrix as that is what the model expects
  custom_audio_input = custom_audio_preprocessed[0].flatten()
  test_data = np.reshape(custom_audio_input,(1,len(custom_audio_input)))

  #
  # Initialize the interpreter
  #
  interpreter = tf.lite.Interpreter(tflite_model_path)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]

  #
  # For quantized models, manually quantize the input data from float to integer
  #
  if model_type == "Quantized":
    input_scale, input_zero_point = input_details["quantization"]
    test_data = test_data / input_scale + input_zero_point
    test_data = test_data.astype(input_details["dtype"])

  #
  # Run the interpreter
  #
  interpreter.set_tensor(input_details["index"], test_data)
  interpreter.invoke()
  output = interpreter.get_tensor(output_details["index"])[0]
  top_prediction = output.argmax()

  #
  # Translate the output
  #
  top_prediction_str = ''
  if top_prediction == 2 or top_prediction == 3:
    top_prediction_str = WANTED_WORDS.split(',')[top_prediction-2]
  elif top_prediction == 0:
    top_prediction_str = 'silence'
  else:
    top_prediction_str = 'unknown'

  print('%s model guessed the value to be %s' % (model_type, top_prediction_str))

Here, the extract_loudest_section tool is first used to pull out the loudest one-second section of the recording.

!/tmp/extract_loudest_section/gen/bin/extract_loudest_section custom_audio.wav ./trimmed

Then Google's speech_commands tooling performs the FFT and MFCC processing.

Library source: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands

custom_model_settings = models.prepare_model_settings(...)
custom_audio_processor = input_data.AudioProcessor(...)
custom_audio_preprocessed = custom_audio_processor.get_features_for_wav(
    'trimmed/custom_audio.wav', model_settings, TF_SESS)

Then the input is quantized.

    input_scale, input_zero_point = input_details["quantization"]
    test_data = test_data / input_scale + input_zero_point
    test_data = test_data.astype(input_details["dtype"])
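This is the standard affine quantization mapping, with scale and zero_point read from input_details["quantization"]: a float value x maps to an integer q = \mathrm{round}(x/\text{scale}) + \text{zero\_point}, and back via x \approx \text{scale}\,(q - \text{zero\_point}).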

Finally, inference.

The results are only so-so: a normal "yes" or "no" is recognized, but the moment you put on a slightly funny voice it either gets the answer wrong or fails to recognize it at all.

Saved to './trimmed/custom_audio.wav'
Quantized model guessed the value to be unknown
Testing yes4
Saved to './trimmed/custom_audio.wav'
Quantized model guessed the value to be yes
Testing no1
Saved to './trimmed/custom_audio.wav'
Quantized model guessed the value to be no
Testing no2
Saved to './trimmed/custom_audio.wav'
Quantized model guessed the value to be unknown
Testing no3
Saved to './trimmed/custom_audio.wav'
Quantized model guessed the value to be unknown
Testing no4
Saved to './trimmed/custom_audio.wav'
Quantized model guessed the value to be no

As you can see, even with the official sample recordings, roughly three of them fail to be recognized.

The final part runs recognition on my own voice, with much the same result: saying yes/no normally works, but as soon as I deliberately change my voice it fails.

3.3 Training a Keyword Spotting Model

https://colab.research.google.com/github/tinyMLx/colabs/blob/master/3-5-18-TrainingKeywordSpotting.ipynb

This one is an assignment, but this part has not been updated and still has the earlier problem that TF 2.14 cannot be downloaded, so I could only take a quick look.

The blanks to be filled in are:

WANTED_WORDS = # YOUR CODE HERE #
TRAINING_STEPS =  # YOUR CODE HERE #
LEARNING_RATE = # YOUR CODE HERE #

The answers are spelled out in the notebook's own comments. The first one is "yes,no".

The second and third set up the two-phase training schedule: how many steps at what learning rate in the first phase, and how many at what rate in the second. The comments give the common configuration "12000,3000" / "0.001,0.0001", which is said to be close to the empirically best values, so there is no need to change it.

A quick distinction between step and epoch: an epoch is one full pass over the entire training set (N samples); a step is just one batch drawn from it, run through a single forward and backward pass.
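For a rough sense of scale (assuming train.py's default batch size of 100, which I have not re-checked):

steps = 12000 + 3000            # total training steps from the two phases
batch_size = 100                # assumed train.py default
samples_seen = steps * batch_size
print(samples_seen)             # 1,500,000 sample presentations; divide by the
                                # training-set size to get the approximate number of epochs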

Compared with the previous notebook, this one adds a calculation of the flattened input shape that will be fed to the model. It is basically derived from CLIP_DURATION_MS, WINDOW_SIZE_MS and WINDOW_STRIDE, as follows:

# Calculate the correct flattened input data shape for later use in model conversion
# since the model takes a flattened version of the spectrogram. The shape is number of 
# overlapping windows times the number of frequency bins. For the default settings we have
# 40 bins (as set above) times 49 windows (as calculated below) so the shape is (1,1960)
def window_counter(total_samples, window_size, stride):
  '''helper function to count the number of full-length overlapping windows'''
  window_count = 0
  sample_index = 0
  while True:
    window = range(sample_index,sample_index+stride)
    if window.stop < total_samples:
      window_count += 1
    else:
      break
    
    sample_index += stride
  return window_count

OVERLAPPING_WINDOWS = window_counter(CLIP_DURATION_MS, int(WINDOW_SIZE_MS), WINDOW_STRIDE)
FLATTENED_SPECTROGRAM_SHAPE = (1, OVERLAPPING_WINDOWS * FEATURE_BIN_COUNT)
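With 1000 ms clips, 40 feature bins, and the 30 ms window / 20 ms stride settings implied by the comment's "49 windows", this works out to:

print(window_counter(1000, 30, 20))      # 49 overlapping windows
print(49 * 40)                           # 1960 -> FLATTENED_SPECTROGRAM_SHAPE = (1, 1960)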

The dataset here is the Google Speech Commands Dataset.

# URL for the dataset and train/val/test split
DATA_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'

Training the full model is said to take about 2 hours with a GPU, and about 10 hours on CPU only... and roughly every hour you have to click a confirmation manually to keep the Colab session alive...

The model is trained with train.py.

!python tensorflow/tensorflow/examples/speech_commands/train.py \
--data_dir={DATASET_DIR} \
--wanted_words={WANTED_WORDS} \
--silence_percentage={SILENT_PERCENTAGE} \
--unknown_percentage={UNKNOWN_PERCENTAGE} \
--preprocess={PREPROCESS} \
--window_stride={WINDOW_STRIDE} \
--model_architecture={MODEL_ARCHITECTURE} \
--how_many_training_steps={TRAINING_STEPS} \
--learning_rate={LEARNING_RATE} \
--train_dir={TRAIN_DIR} \
--summaries_dir={LOGS_DIR} \
--verbosity={VERBOSITY} \
--eval_step_interval={EVAL_STEP_INTERVAL} \
--save_step_interval={SAVE_STEP_INTERVAL}

After that come freezing the model, inference and so on, which is no different from the example in 3.2.

4 References

3Blue1Brown's channel on Bilibili (三棕一蓝)

I was a little surprised that this is actually an official recommendation from the Harvard course.
