Discrete Wavelet Transform via linear convolution matrix
The following code computes the Discrete Wavelet Transform via a linear convolution matrix, which is built from equivalent filters that the function "formfilter.m" generates from the base filters.
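Orthogonal wavelets are used by default (Daubechies 'db3' in STEP 1 of the code); biorthogonal 'bior2.2' filters can be selected instead.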
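This program requires:
http://www.dsprelated.com/showcode/10.php (upsample2.m)
http://www.dsprelated.com/showcode/12.php (formfilter.m)
It also calls wfilters, biorwavf and biorfilt from the MATLAB Wavelet Toolbox.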
For more details on this code and its dependencies, please visit these blog posts:
http://www.dsprelated.com/showarticle/115.php
http://www.dsprelated.com/showarticle/116.php
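The matrix construction described above can be tried on its own with the minimal MATLAB/Octave sketch below. It is a simplified setup under stated assumptions, not the author's program: a single decomposition level, Haar filters written out by hand (so neither the Wavelet Toolbox nor formfilter.m/upsample2.m is needed), and each matrix entry taken from the index formula W(i+1,k) = h(2i+2-k), which yields the same rows as sliding a zero-padded, time-reversed filter. The high-pass sign convention [1 -1]/sqrt(2) is an assumption; wfilters may return the opposite sign.

```matlab
% Minimal one-level Haar DWT via linear convolution matrices (sketch).
% Self-contained: no Wavelet Toolbox, formfilter.m or upsample2.m needed.
N  = 8;                          % number of input samples
x  = (1:N)';                     % test input, as in the full program
h0 = [1  1] / sqrt(2);           % Haar low-pass analysis filter
h1 = [1 -1] / sqrt(2);           % Haar high-pass analysis filter (sign assumed)
L  = length(h0);

% Full linear convolution has N+L-1 outputs; decimating by 2 keeps
% outputs 1, 3, 5, ..., i.e. ceil((N+L-1)/2) rows per branch.
n_rows = ceil((N + L - 1) / 2);
W0 = zeros(n_rows, N);
W1 = zeros(n_rows, N);
for i = 0:n_rows-1
    for k = 1:N
        m = 2*i + 2 - k;         % filter tap that multiplies x(k) in output 2*i+1
        if m >= 1 && m <= L
            W0(i+1, k) = h0(m);
            W1(i+1, k) = h1(m);
        end
    end
end

y0 = W0 * x;                     % approximation coefficients
y1 = W1 * x;                     % detail coefficients
xr = W0' * y0 + W1' * y1;        % synthesis by transposes (orthogonal filters)
mse_err = mean((x - xr).^2);
fprintf('Reconstruction MSE = %g\n', mse_err);
```

The printed error should be at floating-point round-off level. Because full linear convolution keeps the boundary samples, the sketch produces 2*ceil((N+L-1)/2) = 10 coefficients for N = 8, slightly more than N.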
% ----------------------------------------------------
% Title: Discrete Wavelet Transform
% by means of a linear convolution matrix
%
% Author: David Valencia
% UPIITA IPN
%
% Posted at: http://www.dsprelated.com/showcode/11.php
%
% Description:
% Computes an n-level Discrete Wavelet Transform
% (and all its branches) of an N-sample input.
% Base filters can be of any length.
%
% Generalized DWT filter bank computed via the
% convolution matrix method (linear convolution)
%
% Dependencies:
%
% formfilter.m
% http://www.dsprelated.com/showcode/12.php
% upsample2.m
% http://www.dsprelated.com/showcode/10.php
%
% Revisions:
% v1.0a Commented and translated into English
%
% ----------------------------------------------------
close all; clear all; clc;
disp('== GENERALIZED MATRIX DWT ==')
%% STEP 1 - Base Filters
% Select the type of base filters ('b' = biorthogonal, 'o' = orthogonal)
typeofbasis = 'o';
typbior = 'bior2.2';
typor = 'db3';
% Obtain base filters
if (typeofbasis == 'b')
    [Rf,Df] = biorwavf(typbior);
    [h0,h1,g0,g1] = biorfilt(Df,Rf);
elseif (typeofbasis == 'o')
    [h0,h1,g0,g1] = wfilters(typor);
end
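% Added check (a sketch, not part of the original program): orthogonal
% base filters are expected to satisfy sum_n h0(n)*h0(n+2k) = delta(k),
% i.e. the autocorrelation of h0 is 1 at lag 0 and 0 at every other even
% lag. This is one of the conditions that let STEP 4 use the transposed
% matrices for synthesis. The check is skipped for biorthogonal filters.
if strcmp(typeofbasis, 'o')
    r0 = conv(h0, fliplr(h0));          % autocorrelation, lags -(numel(h0)-1)..numel(h0)-1
    r0_even = r0(numel(h0):2:end);      % lags 0, 2, 4, ...
    assert(abs(r0_even(1) - 1) < 1e-8 && all(abs(r0_even(2:end)) < 1e-8), ...
        'h0 is not orthonormal to its even shifts');
end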
%% STEP 2 - Parameter configuration
% One can declare or recover an input vector from another
% program, but here an example vector is built for testing.
N = 16;                     % Number of input samples
x = (1:N)';                 % Test input (column vector)
L = length(h0);             % Base filter length
n_stages = 2;               % Number of low-pass decomposition stages
n_branches = 2^n_stages;    % Number of branches
dec_fact = n_branches;      % Decimation factor of the multi-stage branches
% L = base filter length (h0 or h1)
% N = input vector length
% n_stages = number of stages; it generates (2^n_stages) branches
hx = formfilter(n_stages, 1, h0, h1);
Lhx = length(hx);
fprintf('\nLinear convolution matrix processing\n');
%% STEP 3: Build analysis matrices
% -------------------------------------------------------------
% High-pass analysis matrix (single scale) [linear convolution]
% The DWT has a single high-pass branch at the first scale, so its
% matrix is built directly from the base filter h1.
% Pad the time-reversed filter with N-1 zeros on each side so it can
% be used as a sliding window when building the matrix rows.
movil_1 = [zeros(1,N-1),fliplr(h1),zeros(1,N-1)];
lm = length(movil_1);
n_rows = ceil((L+N-1)/2); %Compute the number of rows
W1 = zeros(n_rows,N); % Allocate memory for the high pass matrix
dec_factor = 2; % Decimation factor for the high-pass matrix
% Build matrix (with sequential shifting)
for i = 0:(n_rows-1)
    W1(i+1,:) = movil_1(lm-(i*dec_factor)-N+1 : lm-(i*dec_factor)); % High-pass branch
end
disp('High-pass analysis matrix: ');
W1
disp('High-pass branch results');
y1 = W1 * x
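% Added check (a sketch, not part of the original program): row i+1 of W1
% holds h1 time-reversed and shifted by 2*i columns, so W1*x should equal
% the full linear convolution of h1 and x keeping samples 1, 3, 5, ...
y1_ref = conv(h1(:), x);
y1_ref = y1_ref(1:dec_factor:end);
assert(max(abs(y1 - y1_ref)) < 1e-9, 'W1*x does not match conv + decimation');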
% -------------------------------------------------------------
% Multi-stage low-pass analysis matrix [linear convolution]
% Generate one equivalent filter to obtain its length
hx = formfilter(n_stages, 1, h0, h1);
Lhx = length(hx);
% Compute the number of rows per branch
rowsperfilter = ceil((N+Lhx)/dec_fact);
% Allocate memory for the low-pass matrix
W0 = zeros((n_branches/2)*rowsperfilter,N);
tic; % Start timer
% Build the low-pass matrix
for i = 0:(n_branches/2)-1
    % Get the equivalent filter depending on the number of stages
    if n_stages < 3
        hx = formfilter(n_stages, (n_branches/2) - i, h0, h1);
    else
        hx = formfilter(n_stages, (n_branches/2) - i - 1, h0, h1);
    end
    % Pad the time-reversed equivalent filter with N-1 zeros on each side
    movil = [zeros(1,N-1),fliplr(hx),zeros(1,N-1)];
    lm = length(movil);
    % Copy successively shifted windows (step dec_fact) into the matrix rows
    for j = 0:rowsperfilter-1
        W0(i*rowsperfilter + j + 1,:) = movil(lm-(dec_fact*j)-N + 1 : lm-(dec_fact*j));
    end
end
disp('Multi-stage low-pass analysis matrix: ');
W0
disp('Low-pass filter bank results: ');
y0 = W0*x
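% Added check (a sketch, not part of the original program): each block of
% rowsperfilter rows in W0 should equal the full linear convolution of x
% with that block's equivalent filter, keeping one sample in every
% dec_fact (starting with the first). The equivalent filters are
% regenerated exactly as in the loop above.
for i = 0:(n_branches/2)-1
    if n_stages < 3
        hx = formfilter(n_stages, (n_branches/2) - i, h0, h1);
    else
        hx = formfilter(n_stages, (n_branches/2) - i - 1, h0, h1);
    end
    ref = conv(hx(:), x);
    ref = ref(1:dec_fact:end);
    blk = y0(i*rowsperfilter + (1:rowsperfilter));
    nr  = min(numel(ref), rowsperfilter);   % compare the overlapping part
    assert(max(abs(blk(1:nr) - ref(1:nr))) < 1e-9, 'W0 block %d mismatch', i+1);
end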
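%% STEP 4: Recover signal
% Synthesis with the transposed analysis matrices; this inverts the
% analysis exactly only for orthogonal base filters.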
disp('Recovered signal: ');
xn1 = (W1')*y1;
xn0 = (W0')*y0;
xn = xn1+xn0
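% Added check (a sketch, not part of the original program): print how far
% the stacked analysis matrix W = [W1; W0] is from having orthonormal
% columns. A value near round-off means the transposes above invert the
% analysis; this is expected for orthogonal filters, assuming formfilter.m
% returns the DWT equivalent filters. Biorthogonal filters would instead
% need synthesis matrices built from g0 and g1.
W = [W1; W0];
fprintf('max |W''*W - I| = %g\n', max(max(abs(W'*W - eye(N)))));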
%% STEP 5: Compute the mean squared reconstruction error
err = (x - xn).^2;
disp('Mean squared error:');
err = (sum(err)/N);
disp(err);
fprintf('\nProcessing time = %g s\n', toc); % time elapsed since tic in STEP 3
fprintf('\nProgram end >>> \n');