【转】一维离散小波变换(DWT)库,完全按matlab的wavelet toolbox 的API实现的
来源:
一维离散小波变换(DWT)库,完全按 MATLAB Wavelet Toolbox 的 API 实现(2008-12-01 20:37)
最近项目中需要用到,就自己写了一个,发在这里算是备忘。需要的朋友也可以拿去试试,经测试没有发现 bug,基于 STL 实现。如果发现 bug 或有什么建议,请通知我,谢谢。
/************************************************************************/
/* wavelet.h
 * Author: Collin
 * Date: 2008/12/01
 *
 * One-dimensional discrete wavelet transform modelled on the MATLAB
 * Wavelet Toolbox API (wavedec, dwt, idwt, wrcoef, appcoef, detcoef and
 * the helpers upsconv1, wextend, wconv1). Only the "sym4" wavelet is
 * available, and signal extension is always half-point symmetric ("sym");
 * the mode/shape string parameters are accepted for API compatibility
 * but are not consulted.
 * Errors (unknown wavelet, bad level, bad 'a'/'d'/'r' selector) are
 * reported on std::cerr and terminate the program with exit(1).
 */
/************************************************************************/
#ifndef WAVELET_H
#define WAVELET_H

#include <vector>

namespace Wavelet {

using std::vector;

// Wavelet decomposition structure in MATLAB's [C, L] layout, as built by WaveDec:
//   C = [cA_n, cD_n, cD_(n-1), ..., cD_1]                  (concatenated)
//   L = [len(cA_n), len(cD_n), ..., len(cD_1), len(signal)]
struct C_L {
    vector<double> C;
    vector<int> L;
};

// Low-pass / high-pass filter pair (either decomposition or reconstruction).
struct WaveFilter {
    vector<double> Low;
    vector<double> High;
};

// Result of one DWT level: approximation and detail coefficients.
struct WaveCoeff {
    vector<double> app;
    vector<double> det;
};

// sym4 filters at full double precision. The four-decimal values MATLAB
// prints are only a display rounding: with them the filters are no longer
// orthonormal and reconstruction is off by roughly 1e-4 of the signal.
// Relations: Lo_D = reverse(Lo_R), Hi_R = qmf(Lo_R), Hi_D = reverse(Hi_R).
const double sym4_Lo_D[] = {
    -0.07576571478927333, -0.02963552764599851,  0.49761866763201545,  0.8037387518059161,
     0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427};
const double sym4_Hi_D[] = {
    -0.0322231006040427,  -0.012603967262037833, 0.09921954357684722,  0.29785779560527736,
    -0.8037387518059161,   0.49761866763201545,  0.02963552764599851, -0.07576571478927333};
const double sym4_Lo_R[] = {
     0.0322231006040427,  -0.012603967262037833, -0.09921954357684722, 0.29785779560527736,
     0.8037387518059161,   0.49761866763201545,  -0.02963552764599851, -0.07576571478927333};
const double sym4_Hi_R[] = {
    -0.07576571478927333,  0.02963552764599851,  0.49761866763201545, -0.8037387518059161,
     0.29785779560527736,  0.09921954357684722, -0.012603967262037833, -0.0322231006040427};

const static WaveFilter sym4_d = {vector<double>(sym4_Lo_D, sym4_Lo_D + 8),
                                  vector<double>(sym4_Hi_D, sym4_Hi_D + 8)};
const static WaveFilter sym4_r = {vector<double>(sym4_Lo_R, sym4_Lo_R + 8),
                                  vector<double>(sym4_Hi_R, sym4_Hi_R + 8)};

// Filter pair for strWaveName; d_or_r selects decomposition ('d') or reconstruction ('r').
const WaveFilter& WFilters( const char* strWaveName, const char d_or_r );

// Multilevel decomposition (MATLAB wavedec) down to nMaxLevel levels.
C_L WaveDec( const vector<double>& signal, const int nMaxLevel, const char* strWaveName );

// Single-level decomposition (MATLAB dwt) with symmetric extension.
WaveCoeff DWT( const vector<double>& signal, const vector<double>& Lo_D, const vector<double>& Hi_D );

// Reconstruct the approximation ('a') or detail ('d') branch of level nLevel
// at the full signal length (MATLAB wrcoef).
vector<double> WRCoef( const char a_or_d, const vector<double>& C, const vector<int>& L,
                       const char* strWaveName, const int nLevel );

// Approximation coefficients at level nLevel (MATLAB appcoef); level 0 is the signal itself.
vector<double> AppCoef( const vector<double>& C, const vector<int>& L,
                        const char* strWaveName, const int nLevel );

// Detail coefficients stored for level nLevel (MATLAB detcoef).
vector<double> DetCoef( const vector<double>& C, const vector<int>& L, const int nLevel );

// Upsample-and-convolve helpers.
// Upsample by 2, convolve with filter, keep the central nLen samples (MATLAB upsconv1).
vector<double> UpsConv1( const vector<double>& signal, const vector<double>& filter,
                         const int nLen, const char* strMode = "sym" );

// Full linear convolution (MATLAB conv); empty if either operand is empty.
vector<double> Conv( const vector<double>& vecSignal, const vector<double>& vecFilter );

// Single-level reconstruction (MATLAB idwt) to nLenCentral samples.
vector<double> IDWT( const vector<double>& app, const vector<double>& det,
                     const vector<double>& Lo_R, const vector<double>& Hi_R, const int nLenCentral );

// Half-point symmetric extension by nLenExt samples on each side (MATLAB wextend 'sym').
vector<double> WExtend( const vector<double>& signal, const int nLenExt, const char* mode = "sym" );

// 'valid' part of the convolution of signal with filter (MATLAB wconv1).
vector<double> WConv1( const vector<double>& signal, const vector<double>& filter,
                       const char* shape = "valid" );

} // namespace Wavelet

#endif // WAVELET_H

/************************************************************************/
/* wavelet.cpp
 * Author: Collin
 * Date: 2008/12/01
 */
/************************************************************************/
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Redundant external include guard: the include is skipped when this source
// is compiled in one unit together with the header listing above.
#ifndef WAVELET_H
#include "wavelet.h"
#endif

using namespace std;
using namespace Wavelet;

C_L Wavelet::WaveDec( const vector<double>& signal, const int nMaxLevel, const char* strWaveName )
{
    const WaveFilter& filters = WFilters(strWaveName, 'd');
    C_L cl;

    // MATLAB returns c = [] and l = zeros(1, n+2) for an empty signal; without
    // this guard DWT would read signal[0] of an empty vector.
    if (signal.empty()) {
        const size_t count = static_cast<size_t>(nMaxLevel > 0 ? nMaxLevel + 2 : 2);
        cl.L.assign(count, 0);
        return cl;
    }

    cl.L.push_back(static_cast<int>(signal.size()));
    WaveCoeff waveCoeff;
    waveCoeff.app = signal;
    for (int i = 0; i < nMaxLevel; ++i) {
        waveCoeff = DWT(waveCoeff.app, filters.Low, filters.High);
        // Each coarser level goes in front of the finer ones already stored.
        cl.C.insert(cl.C.begin(), waveCoeff.det.begin(), waveCoeff.det.end());
        cl.L.insert(cl.L.begin(), static_cast<int>(waveCoeff.det.size()));
    }
    // The last approximation heads both C and L.
    cl.C.insert(cl.C.begin(), waveCoeff.app.begin(), waveCoeff.app.end());
    cl.L.insert(cl.L.begin(), static_cast<int>(waveCoeff.app.size()));
    return cl;
}

vector<double> Wavelet::WRCoef( const char a_or_d, const vector<double>& C, const vector<int>& L,
                                const char* strWaveName, const int nLevel )
{
    const WaveFilter& filter = WFilters(strWaveName, 'r');
    const int nMax = static_cast<int>(L.size()) - 2;
    const char type = static_cast<char>(std::tolower(static_cast<unsigned char>(a_or_d)));

    int nMin;
    if ('a' == type) {
        nMin = 0;
    } else if ('d' == type) {
        nMin = 1;
    } else {
        cerr << "bad parameter: a_or_d: " << a_or_d << "\n";
        std::exit(1);
    }
    if (nLevel < nMin || nLevel > nMax) {
        cerr << "bad parameter for level\n";
        std::exit(1);
    }

    vector<double> Coef;
    if ('a' == type) {
        Coef = AppCoef(C, L, strWaveName, nLevel);
        if (0 == nLevel)
            return Coef;
    } else {
        Coef = DetCoef(C, L, nLevel);
    }
    // The first upward step uses the branch's own filter; every further step
    // towards the signal length is a low-pass reconstruction.
    const vector<double>& F1 = ('a' == type) ? filter.Low : filter.High;
    const int iMin = static_cast<int>(L.size()) - nLevel;
    Coef = UpsConv1(Coef, F1, L[iMin], "sym");
    for (int k = 1; k < nLevel; ++k)
        Coef = UpsConv1(Coef, filter.Low, L[iMin + k], "sym");
    return Coef;
}

vector<double> Wavelet::UpsConv1( const vector<double>& signal, const vector<double>& filter,
                                  const int nLen, const char* strMode )
{
    // strMode is not consulted: only the "sym" behaviour is implemented.
    // MATLAB returns the scalar 0 for empty input, which broadcasts in idwt's
    // sum; the vector equivalent is nLen zeros.
    if (signal.empty())
        return vector<double>(static_cast<size_t>(nLen > 0 ? nLen : 0), 0.0);

    // dyadup(signal, 0): a zero between consecutive samples, none at the ends.
    vector<double> y(2 * signal.size() - 1, 0.0);
    for (size_t i = 0; i < signal.size(); ++i)
        y[2 * i] = signal[i];
    y = Conv(y, filter);

    // wkeep1(y, nLen, 'c'): keep the central nLen samples. Like MATLAB, return
    // y untouched when nLen does not select a strict sub-range.
    if (nLen < 0 || static_cast<size_t>(nLen) >= y.size())
        return y;
    const size_t first = (y.size() - static_cast<size_t>(nLen)) / 2;
    return vector<double>(y.begin() + first, y.begin() + first + nLen);
}

vector<double> Wavelet::Conv( const vector<double>& vecSignal, const vector<double>& vecFilter )
{
    // MATLAB's conv of an empty operand is empty; the size formula below
    // would otherwise underflow.
    if (vecSignal.empty() || vecFilter.empty())
        return vector<double>();

    const size_t lenSignal = vecSignal.size();
    const size_t lenFilter = vecFilter.size();
    vector<double> result(lenSignal + lenFilter - 1, 0.0);
    for (size_t i = 0; i < lenSignal; ++i)
        for (size_t j = 0; j < lenFilter; ++j)
            result[i + j] += vecSignal[i] * vecFilter[j];
    return result;
}

vector<double> Wavelet::DetCoef( const vector<double>& C, const vector<int>& L, const int nLevel )
{
    // Compare in int: L.size() - 2 would wrap around for a short L.
    if (nLevel < 1 || nLevel > static_cast<int>(L.size()) - 2) {
        cerr << "bad level parameter\n";
        std::exit(1);
    }
    // cD_1 occupies the tail of C. Skip L.back() (the signal length), add up
    // the lengths of the finer levels to find where cD_nLevel ends (counted
    // from the end of C), then add its own length to find where it starts.
    vector<int>::const_reverse_iterator it = L.rbegin() + 1;
    int nlast = 0;
    for (int i = 1; i < nLevel; ++i, ++it)
        nlast += *it;
    const int nfirst = nlast + *it;
    return vector<double>(C.end() - nfirst, C.end() - nlast);
}

WaveCoeff Wavelet::DWT( const vector<double>& signal, const vector<double>& Lo_D, const vector<double>& Hi_D )
{
    // Extend by len(filter)-1 on each side, take the 'valid' convolution and
    // keep the odd-indexed samples (MATLAB z(2:2:end)).
    const int nLenExt = static_cast<int>(Lo_D.size()) - 1;
    const vector<double> y = WExtend(signal, nLenExt, "sym");

    WaveCoeff coeff;
    const vector<double> zLow = WConv1(y, Lo_D, "valid");
    for (size_t i = 1; i < zLow.size(); i += 2)
        coeff.app.push_back(zLow[i]);
    const vector<double> zHigh = WConv1(y, Hi_D, "valid");
    for (size_t i = 1; i < zHigh.size(); i += 2)
        coeff.det.push_back(zHigh[i]);
    return coeff;
}

const WaveFilter& Wavelet::WFilters( const char* strWaveName, const char d_or_r )
{
    const char type = static_cast<char>(std::tolower(static_cast<unsigned char>(d_or_r)));
    if (!strWaveName || std::strcmp(strWaveName, "sym4") != 0) {
        cerr << "not implement \n";
        std::exit(1);
    }
    switch (type) {
    case 'd':
        return Wavelet::sym4_d;
    case 'r':
        return Wavelet::sym4_r;
    default:
        cerr << "bad parameter for d_or_r\n";
        std::exit(1);
    }
}

vector<double> Wavelet::AppCoef( const vector<double>& C, const vector<int>& L,
                                 const char* strWaveName, const int nLevel )
{
    const int nMaxLevel = static_cast<int>(L.size()) - 2;
    if (nLevel < 0 || nLevel > nMaxLevel) {
        cerr << "bad parameter for level\n";
        std::exit(1);
    }
    const WaveFilter& filters = WFilters(strWaveName, 'r');

    // Start from the coarsest approximation at the head of C and climb one
    // level per iteration, reconstructing to the length stored for that level.
    vector<double> app(C.begin(), C.begin() + L[0]);
    for (int p = nMaxLevel; p > nLevel; --p)
        app = IDWT(app, DetCoef(C, L, p), filters.Low, filters.High, L[nMaxLevel - p + 2]);
    return app;
}

vector<double> Wavelet::IDWT( const vector<double>& app, const vector<double>& det,
                              const vector<double>& Lo_R, const vector<double>& Hi_R, const int nLenCentral )
{
    vector<double> result = UpsConv1(app, Lo_R, nLenCentral, "sym");
    const vector<double> detPart = UpsConv1(det, Hi_R, nLenCentral, "sym");
    // Both parts normally hold exactly nLenCentral samples; bounding by the
    // actual sizes keeps mismatched inputs from indexing past either vector.
    for (size_t i = 0; i < result.size() && i < detPart.size(); ++i)
        result[i] += detPart[i];
    return result;
}

vector<double> Wavelet::WExtend( const vector<double>& signal, const int nLenExt, const char* mode )
{
    // mode is not consulted: only half-point symmetric extension is implemented.
    const int signalLen = static_cast<int>(signal.size());
    // Nothing to mirror or nothing to add: return the signal as-is instead of
    // reading signal[0] of an empty vector or sizing the result negatively.
    if (signalLen == 0 || nLenExt <= 0)
        return signal;

    // Half-point symmetric extension repeats with period 2*signalLen:
    //   ... x1 x0 | x0 x1 ... x(N-1) | x(N-1) x(N-2) ...
    // Folding each output position into one period also covers extensions
    // longer than the signal itself.
    const int period = 2 * signalLen;
    const int total = signalLen + 2 * nLenExt;
    vector<double> result(static_cast<size_t>(total));
    for (int idx = 0; idx < total; ++idx) {
        int pos = (idx - nLenExt) % period;
        if (pos < 0)
            pos += period;
        result[idx] = (pos < signalLen) ? signal[pos] : signal[period - 1 - pos];
    }
    return result;
}

vector<double> Wavelet::WConv1( const vector<double>& signal, const vector<double>& filter,
                                const char* shape )
{
    // shape is not consulted: only the 'valid' part is returned, i.e. the
    // samples computed without zero padding. That part is empty when the
    // signal is shorter than the filter.
    if (filter.empty() || signal.size() < filter.size())
        return vector<double>();
    const vector<double> y = Conv(signal, filter);
    const size_t nLenExt = filter.size() - 1;
    return vector<double>(y.begin() + nLenExt, y.end() - nLenExt);
}