File: speech_probability_estimator.cc

package info (click to toggle)
chromium 138.0.7204.183-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,071,908 kB
  • sloc: cpp: 34,937,088; ansic: 7,176,967; javascript: 4,110,704; python: 1,419,953; asm: 946,768; xml: 739,971; pascal: 187,324; sh: 89,623; perl: 88,663; objc: 79,944; sql: 50,304; cs: 41,786; fortran: 24,137; makefile: 21,806; php: 13,980; tcl: 13,166; yacc: 8,925; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (104 lines) | stat: -rw-r--r-- 3,987 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/*
 *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/ns/speech_probability_estimator.h"

#include <math.h>

#include <algorithm>

#include "modules/audio_processing/ns/fast_math.h"
#include "rtc_base/checks.h"

namespace webrtc {

SpeechProbabilityEstimator::SpeechProbabilityEstimator() {
  speech_probability_.fill(0.f);
}

void SpeechProbabilityEstimator::Update(
    int32_t num_analyzed_frames,
    ArrayView<const float, kFftSizeBy2Plus1> prior_snr,
    ArrayView<const float, kFftSizeBy2Plus1> post_snr,
    ArrayView<const float, kFftSizeBy2Plus1> conservative_noise_spectrum,
    ArrayView<const float, kFftSizeBy2Plus1> signal_spectrum,
    float signal_spectral_sum,
    float signal_energy) {
  // Update models.
  if (num_analyzed_frames < kLongStartupPhaseBlocks) {
    signal_model_estimator_.AdjustNormalization(num_analyzed_frames,
                                                signal_energy);
  }
  signal_model_estimator_.Update(prior_snr, post_snr,
                                 conservative_noise_spectrum, signal_spectrum,
                                 signal_spectral_sum, signal_energy);

  const SignalModel& model = signal_model_estimator_.get_model();
  const PriorSignalModel& prior_model =
      signal_model_estimator_.get_prior_model();

  // Width parameter in sigmoid map for prior model.
  constexpr float kWidthPrior0 = 4.f;
  // Width for pause region: lower range, so increase width in tanh map.
  constexpr float kWidthPrior1 = 2.f * kWidthPrior0;

  // Average LRT feature: use larger width in tanh map for pause regions.
  float width_prior = model.lrt < prior_model.lrt ? kWidthPrior1 : kWidthPrior0;

  // Compute indicator function: sigmoid map.
  float indicator0 =
      0.5f * (tanh(width_prior * (model.lrt - prior_model.lrt)) + 1.f);

  // Spectral flatness feature: use larger width in tanh map for pause regions.
  width_prior = model.spectral_flatness > prior_model.flatness_threshold
                    ? kWidthPrior1
                    : kWidthPrior0;

  // Compute indicator function: sigmoid map.
  float indicator1 =
      0.5f * (tanh(1.f * width_prior *
                   (prior_model.flatness_threshold - model.spectral_flatness)) +
              1.f);

  // For template spectrum-difference : use larger width in tanh map for pause
  // regions.
  width_prior = model.spectral_diff < prior_model.template_diff_threshold
                    ? kWidthPrior1
                    : kWidthPrior0;

  // Compute indicator function: sigmoid map.
  float indicator2 =
      0.5f * (tanh(width_prior * (model.spectral_diff -
                                  prior_model.template_diff_threshold)) +
              1.f);

  // Combine the indicator function with the feature weights.
  float ind_prior = prior_model.lrt_weighting * indicator0 +
                    prior_model.flatness_weighting * indicator1 +
                    prior_model.difference_weighting * indicator2;

  // Compute the prior probability.
  prior_speech_prob_ += 0.1f * (ind_prior - prior_speech_prob_);

  // Make sure probabilities are within range: keep floor to 0.01.
  prior_speech_prob_ = std::max(std::min(prior_speech_prob_, 1.f), 0.01f);

  // Final speech probability: combine prior model with LR factor:.
  float gain_prior =
      (1.f - prior_speech_prob_) / (prior_speech_prob_ + 0.0001f);

  std::array<float, kFftSizeBy2Plus1> inv_lrt;
  ExpApproximationSignFlip(model.avg_log_lrt, inv_lrt);
  for (size_t i = 0; i < kFftSizeBy2Plus1; ++i) {
    speech_probability_[i] = 1.f / (1.f + gain_prior * inv_lrt[i]);
  }
}

}  // namespace webrtc