DSPRelated.com
Code

Fast MDCT/IMDCT Based on Forward FFT

August 6, 2011 Coded in C

Doing audio, especially audio coding, and looking for fast and easy-to-use C code for the Modified Discrete Cosine Transform (MDCT) and its inverse, the IMDCT? Here it is. It works well for me, and I hope it works well for you. 

A special flavor of the code is that both the MDCT and IMDCT are based on forward FFT. This is possible because the MDCT/IMDCT are in fact the shifted DCT-IV/IDCT-IV, whose transformation matrix is symmetric and orthogonal (thus, the DCT-IV and the IDCT-IV are identical up to scaling). Some small code space and data space might be saved this way. No big deal generally, but it might be worth the effort on embedded systems. 

Another special flavor of the code is that the MDCT and IMDCT here are normalized, which means, globally, the power of time samples is equal to that of spectral coefficients, provided that a window satisfying the Princen-Bradley condition of perfect reconstruction (e.g., sine and KBD) is applied. However, temporal power and spectral power of a single transform are not necessarily equal, due to the intrinsic property (and allure) of Time Domain Aliasing Cancellation (TDAC). 

The following MDCT and IMDCT code contains two files: mdct.h and mdct.c. To use the code in your project, you just need to 

  • First, initialize a 'mdct_plan' with size M (the number of MDCT coefficients)

              mdct_plan* m_plan = mdct_init(M);  

  • Then, call mdct/imdct as many times as required

              mdct(freq, time, m_plan);  // out: freq[0,...,M-1]; in: time[0,...,2M-1] 

              or

              imdct(time, freq, m_plan); // out: time[0,...,2M-1]; in: freq[0,...,M-1] 

  • Finally, free up the 'mdct_plan'

              mdct_free(m_plan);  

All of the above is exemplified in the following "mdct_test.c". To build 'mdct_test' in double precision float, run

              gcc -o mdct_test -O2 mdct_test.c mdct.c -lfftw3 -lm;

to build 'mdct_test' in single precision float, run 

              gcc -o mdct_test -O2 -DSINGLE_PRECISION mdct_test.c mdct.c -lfftw3f -lm

Notice the 'f' suffix in '-lfftw3f'. 
 
It is assumed that the size of the MDCT here, M, is an even number. This is required by the algorithm used, simply put:
  1. Pre-twiddling: folding and arranging the 2M-point time input into M/2-point complex input;
  2. Forward FFT: M/2-point complex;
  3. Post-twiddling: unfolding the resulting M/2-point complex FFT spectrum into the M-point MDCT spectrum.
In practice, this restriction is not a problem: To my knowledge, all MDCT based audio codecs, such as MP3, AAC, and Ogg Vorbis, use even sized MDCT.

Happy coding:)

/******** begin of mdct.h ******** */
/*
 * Public interface of the FFT-based MDCT/IMDCT.
 *
 * The include guard deliberately avoids a leading double underscore:
 * such identifiers are reserved for the C implementation.
 */
#ifndef MDCT_H
#define MDCT_H

#include <fftw3.h>

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Precision selection. Define SINGLE_PRECISION at build time to work on
 * float samples with the fftwf_* API; otherwise double samples and the
 * fftw_* API are used.
 */
#ifdef SINGLE_PRECISION
typedef float         FLOAT;
typedef fftwf_complex FFTW_COMPLEX;
typedef fftwf_plan    FFTW_PLAN;
#else // DOUBLE_PRECISION
typedef double        FLOAT;
typedef fftw_complex  FFTW_COMPLEX;
typedef fftw_plan     FFTW_PLAN;
#endif  // SINGLE_PRECISION

/*
 * Transform state created by mdct_init() and released by mdct_free().
 * mdct() and imdct() both write into fft_in and fft_out on every call.
 */
typedef struct {
    int           M;            // MDCT spectrum size (number of bins)
    FLOAT*        twiddle;      // M values: interleaved (cos, sin) twiddle pairs, pre-scaled
    FFTW_COMPLEX* fft_in;       // fft workspace, input  (M/2 complex values)
    FFTW_COMPLEX* fft_out;      // fft workspace, output (M/2 complex values)
    FFTW_PLAN     fft_plan;     // forward complex FFT of size M/2 from fft_in to fft_out
} mdct_plan; 

/*
 * Create a plan for M MDCT coefficients (2M time samples per transform).
 * M must be a positive even number.
 * Returns NULL after printing a diagnostic to stderr when M is rejected or
 * an allocation / FFT planning step fails.
 */
mdct_plan* mdct_init(int M);    // MDCT spectrum size (number of bins)

/*
 * Release a plan returned by mdct_init(). Passing NULL is a no-op.
 */
void mdct_free(mdct_plan* m_plan);

/*
 * Forward transform.
 *   mdct_line   : output, M coefficients, indices [0, M-1]
 *   time_signal : input, 2M samples, indices [0, 2M-1]; not modified
 *   m_plan      : plan from mdct_init(); its FFT workspaces are overwritten
 */
void mdct(FLOAT* mdct_line, FLOAT* time_signal, mdct_plan* m_plan);

/*
 * Inverse transform.
 *   time_signal : output, 2M samples, indices [0, 2M-1]
 *   mdct_line   : input, M coefficients, indices [0, M-1]; not modified
 *   m_plan      : plan from mdct_init(); its FFT workspaces are overwritten
 */
void imdct(FLOAT* time_signal, FLOAT* mdct_line, mdct_plan* m_plan);

#ifdef __cplusplus
}
#endif

#endif // MDCT_H
/******** end of mdct.h ******** */

/******** begin of mdct.c ******** */
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "mdct.h"
/*
 * Precision-neutral aliases for the FFTW entry points used in this file.
 * With SINGLE_PRECISION defined they name the fftwf_* (float) functions,
 * otherwise the fftw_* (double) functions.
 */
#ifdef SINGLE_PRECISION

#define FFTW_MALLOC   fftwf_malloc
#define FFTW_FREE     fftwf_free
#define FFTW_PLAN_1D  fftwf_plan_dft_1d
#define FFTW_DESTROY  fftwf_destroy_plan 
#define FFTW_EXECUTE  fftwf_execute

#else // DOUBLE_PRECISION

#define FFTW_MALLOC   fftw_malloc
#define FFTW_FREE     fftw_free
#define FFTW_PLAN_1D  fftw_plan_dft_1d
#define FFTW_DESTROY  fftw_destroy_plan 
#define FFTW_EXECUTE  fftw_execute

#endif // SINGLE_PRECISION

/*
 * Release an MDCT plan and everything it owns.
 *
 * m_plan may be NULL (no-op). It may also be a partially built plan:
 * mdct_init() zero-fills the structure and calls this function from its
 * error paths, so any of the members can still be NULL here.
 */
void mdct_free(mdct_plan* m_plan)
{
    if(!m_plan)
        return;

    /* FFTW does not document its destroy/free calls as accepting NULL,
       so only hand it handles that were actually created. */
    if(m_plan->fft_plan)
        FFTW_DESTROY(m_plan->fft_plan);
    if(m_plan->fft_in)
        FFTW_FREE(m_plan->fft_in);
    if(m_plan->fft_out)
        FFTW_FREE(m_plan->fft_out);

    /* free(NULL) is a no-op, so no guard is needed for twiddle */
    free(m_plan->twiddle);
    free(m_plan);
}

/*
 * Error exit for functions that return a pointer and have a local
 * 'mdct_plan* m_plan' in scope: prints 'msg' (a printf format string
 * literal taking at least one argument) followed by the function name,
 * file and line to stderr, releases m_plan with mdct_free(), and
 * returns NULL from the enclosing function.
 *
 * Wrapped in do { } while (0) so that 'MDCT_CLEAUP(...);' is a single
 * statement and stays correct inside if/else without braces.
 */
#define MDCT_CLEAUP(msg, ...) \
    do { \
        fprintf(stderr, msg ", %s(), %s:%d \n", \
                __VA_ARGS__, __func__, __FILE__, __LINE__); \
        mdct_free(m_plan); \
        return NULL; \
    } while (0)

/*
 * Build an MDCT/IMDCT plan for M spectral coefficients.
 *
 * M must be a positive even number: the transform is computed with a
 * complex FFT of size M/2.
 *
 * The plan owns:
 *   - twiddle : M values, twiddle[2n] = scale * cos(pi*n/M + pi/(8M)),
 *               twiddle[2n+1] = scale * sin(pi*n/M + pi/(8M)),
 *               for n in [0, M/2), with scale = (2/M)^(1/4)
 *   - fft_in / fft_out : FFT workspaces of M/2 complex values each
 *   - fft_plan : forward FFT of size M/2 from fft_in to fft_out
 *
 * Returns the new plan, or NULL after printing a diagnostic to stderr
 * (everything allocated so far is released via mdct_free()).
 */
mdct_plan* mdct_init(int M)
{
    int        n;
    double     alpha, omega, scale;
    mdct_plan* m_plan = NULL;

    /* M <= 0 would give a division by zero below and a zero-size FFT */
    if(M <= 0 || 0x00 != (M & 0x01)) 
    {
        MDCT_CLEAUP(" Expect a positive even number of MDCT coeffs, but meet %d", M);
    }

    /* zero-filled so that mdct_free() can be called on a partial plan */
    m_plan = (mdct_plan*) calloc(1, sizeof(*m_plan));
    if(NULL == m_plan)
    {
        MDCT_CLEAUP(" malloc error: %s", "m_plan");
    }

    m_plan->M = M;

    m_plan->twiddle = (FLOAT*) malloc(sizeof(FLOAT) * (size_t) M);
    if(NULL == m_plan->twiddle)
    {
        MDCT_CLEAUP(" malloc error: %s", "m_plan->twiddle");
    }

    /* Angles and scale are evaluated in double regardless of FLOAT, and
       rounded to FLOAT only once when stored. */
    alpha = M_PI / (8.0 * M);
    omega = M_PI / M;
    scale = sqrt(sqrt(2.0 / M));
    for(n = 0; n < (M >> 1); n++)
    {
        m_plan->twiddle[2*n+0] = (FLOAT) (scale * cos(omega * n + alpha));
        m_plan->twiddle[2*n+1] = (FLOAT) (scale * sin(omega * n + alpha));
    }

    m_plan->fft_in
        = (FFTW_COMPLEX*) FFTW_MALLOC(sizeof(FFTW_COMPLEX) * (size_t) (M >> 1));
    if(NULL == m_plan->fft_in)
    {
        MDCT_CLEAUP(" malloc error: %s", "m_plan->fft_in");
    }

    m_plan->fft_out
        = (FFTW_COMPLEX*) FFTW_MALLOC(sizeof(FFTW_COMPLEX) * (size_t) (M >> 1));
    if(NULL == m_plan->fft_out)
    {
        MDCT_CLEAUP(" malloc error: %s", "m_plan->fft_out");
    }

    m_plan->fft_plan = FFTW_PLAN_1D(M >> 1,
                                    m_plan->fft_in,
                                    m_plan->fft_out,
                                    FFTW_FORWARD,
                                    FFTW_MEASURE);
    if(NULL == m_plan->fft_plan)
    {
        MDCT_CLEAUP(" FFTW planning error: %s", "m_plan->fft_plan");
    }

    return m_plan;
}

/*
 * Forward MDCT: 2M time samples in, M coefficients out.
 *
 * Three stages, all driven by the plan created in mdct_init():
 *   1. fold the 2M inputs into M/2 complex values and pre-twiddle them
 *      into the plan's FFT input workspace,
 *   2. run the size-M/2 forward FFT,
 *   3. post-twiddle the FFT output into mdct_line[0..M-1].
 *
 * The complex workspaces are addressed as flat FLOAT arrays: element k
 * is a real part and element k+1 the matching imaginary part, so every
 * loop advances by two. The twiddle table uses the same layout
 * (cosine at k, sine at k+1).
 */
void mdct(FLOAT* mdct_line, FLOAT* time_signal, mdct_plan* m_plan)
{
    const int    M       = m_plan->M;
    const int    half    = M >> 1;      /* M/2  */
    const int    half3   = 3 * half;    /* 3M/2 */
    const int    half5   = 5 * half;    /* 5M/2 */
    const FLOAT* tw      = m_plan->twiddle;
    FLOAT*       fft_src = (FLOAT*) m_plan->fft_in;
    FLOAT*       fft_dst = (FLOAT*) m_plan->fft_out;
    FLOAT        re, im, c, s;
    int          k = 0;

    /* folding + pre-twiddle, part 1: k below M/2 */
    while(k < half)
    {
        re = time_signal[half3 - 1 - k] + time_signal[half3 + k];
        im = time_signal[half + k] - time_signal[half - 1 - k];

        c = tw[k];
        s = tw[k + 1];

        fft_src[k]     = re * c + im * s;
        fft_src[k + 1] = im * c - re * s;

        k += 2;
    }

    /* folding + pre-twiddle, part 2: continues from where part 1 stopped */
    while(k < M)
    {
        re = time_signal[half3 - 1 - k] - time_signal[k - half];
        im = time_signal[half + k] + time_signal[half5 - 1 - k];

        c = tw[k];
        s = tw[k + 1];

        fft_src[k]     = re * c + im * s;
        fft_src[k + 1] = im * c - re * s;

        k += 2;
    }

    /* forward complex FFT of size M/2: fft_in -> fft_out */
    FFTW_EXECUTE(m_plan->fft_plan);

    /* post-twiddle: each complex bin yields one coefficient counted from
       the front of the spectrum and one counted from the back */
    for(k = 0; k < M; k += 2)
    {
        re = fft_dst[k];
        im = fft_dst[k + 1];

        c = tw[k];
        s = tw[k + 1];

        mdct_line[k]         = - re * c - im * s;
        mdct_line[M - 1 - k] = - re * s + im * c;
    }
}

/*
 * Inverse MDCT: M coefficients in, 2M time samples out.
 *
 * Three stages, all driven by the plan created in mdct_init():
 *   1. pair coefficient k with coefficient M-1-k and pre-twiddle the
 *      pair into the plan's FFT input workspace,
 *   2. run the size-M/2 forward FFT (the same FFT the forward MDCT uses),
 *   3. post-twiddle each FFT bin and expand it into four output samples.
 *
 * The complex workspaces are addressed as flat FLOAT arrays: element k
 * is a real part and element k+1 the matching imaginary part, so every
 * loop advances by two. The twiddle table uses the same layout
 * (cosine at k, sine at k+1).
 */
void imdct(FLOAT* time_signal, FLOAT* mdct_line, mdct_plan* m_plan)
{
    const int    M       = m_plan->M;
    const int    half    = M >> 1;      /* M/2  */
    const int    half3   = 3 * half;    /* 3M/2 */
    const int    half5   = 5 * half;    /* 5M/2 */
    const FLOAT* tw      = m_plan->twiddle;
    FLOAT*       fft_src = (FLOAT*) m_plan->fft_in;
    FLOAT*       fft_dst = (FLOAT*) m_plan->fft_out;
    FLOAT        re, im, c, s, out_re, out_im;
    int          k;

    /* pre-twiddle */
    for(k = 0; k < M; k += 2)
    {
        re = mdct_line[k];
        im = mdct_line[M - 1 - k];

        c = tw[k];
        s = tw[k + 1];

        fft_src[k]     = -im * s - re * c;
        fft_src[k + 1] = -im * c + re * s;
    }

    /* forward complex FFT of size M/2: fft_in -> fft_out */
    FFTW_EXECUTE(m_plan->fft_plan);

    /* expansion + post-twiddle, part 1: k below M/2 */
    k = 0;
    while(k < half)
    {
        re = fft_dst[k];
        im = fft_dst[k + 1];

        c = tw[k];
        s = tw[k + 1];

        out_re = re * c + im * s;
        out_im = re * s - im * c;

        time_signal[half3 - 1 - k] =  out_re;
        time_signal[half3 + k]     =  out_re;
        time_signal[half + k]      =  out_im;
        time_signal[half - 1 - k]  = -out_im;

        k += 2;
    }

    /* expansion + post-twiddle, part 2: continues from where part 1 stopped */
    while(k < M)
    {
        re = fft_dst[k];
        im = fft_dst[k + 1];

        c = tw[k];
        s = tw[k + 1];

        out_re = re * c + im * s;
        out_im = re * s - im * c;

        time_signal[half3 - 1 - k] =  out_re;
        time_signal[k - half]      = -out_re;
        time_signal[half + k]      =  out_im;
        time_signal[half5 - 1 - k] =  out_im;

        k += 2;
    }
}
/******** end of mdct.c ******** */

/******** begin of mdct_test.c ******** */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <sys/time.h>
#include <time.h>
#include "mdct.h"

/*
 * Benchmark driver.
 *
 * Usage: mdct_test <MDCT_SPECTRUM_SIZE> <run_times>
 *
 * Fills a 2M-sample time buffer and an M-sample spectrum buffer with
 * random values in [-1, 1], then times <run_times> calls of mdct() and
 * <run_times> calls of imdct(), printing the average time per call in
 * milliseconds. Returns 0 on success and -1 on a usage, argument, or
 * allocation error.
 */
int main(int argc, char* argv[])
{
    int            M = 0, r = 0, i;
    int            status = -1;
    size_t         k, n_time;
    FLOAT*         time_sig = NULL;   /* not named 'time': that would shadow time() */
    FLOAT*         freq = NULL;
    mdct_plan*     m_plan = NULL;
    const char*    precision = NULL;
    struct timeval t0, t1;
    long long      elps;

    if(3 != argc)
    {
        fprintf(stderr, " Usage: %s <MDCT_SPECTRUM_SIZE> <run_times> \n", argv[0]);
        return -1;
    }

    /* Reject unparsable or non-positive arguments: M and r would otherwise
       be used uninitialized, and r == 0 would divide by zero below. */
    if(1 != sscanf(argv[1], "%d", &M) || 1 != sscanf(argv[2], "%d", &r) ||
       M <= 0 || r <= 0)
    {
        fprintf(stderr, " Expect positive integers for <MDCT_SPECTRUM_SIZE> and <run_times> \n");
        return -1;
    }

    n_time = 2 * (size_t) M;

    /* every failure below goes through 'cleanup' so nothing is leaked */
    if(NULL == (m_plan = mdct_init(M)))
        goto cleanup;
    if(NULL == (time_sig = (FLOAT*) malloc(n_time * sizeof(FLOAT))))
        goto cleanup;
    if(NULL == (freq = (FLOAT*) malloc((size_t) M * sizeof(FLOAT))))
        goto cleanup;

    for(k = 0; k < n_time; k++)
        time_sig[k] = 2.f * rand() / RAND_MAX - 1.f;
    for(k = 0; k < (size_t) M; k++)
        freq[k] = 2.f * rand() / RAND_MAX - 1.f;

    precision = (sizeof(float) == sizeof(FLOAT))? 
                "single precision" : "double precision";

#if 1   /* time the forward transform */
    gettimeofday(&t0, NULL);
    for(i = 0; i < r; i++)
        mdct(freq, time_sig, m_plan);
    gettimeofday(&t1, NULL);

    /* widen before multiplying so the microsecond count cannot overflow */
    elps = (long long) (t1.tv_sec - t0.tv_sec) * 1000000LL
         + (long long) (t1.tv_usec - t0.tv_usec);
    fprintf(stdout, "MDCT size of %d, %s, running %d times, average %.3f ms\n", 
            M, precision, r, (double) elps / r / 1000.0); 
#endif // 1

#if 1   /* time the inverse transform */
    gettimeofday(&t0, NULL);
    for(i = 0; i < r; i++)
        imdct(time_sig, freq, m_plan);
    gettimeofday(&t1, NULL);

    elps = (long long) (t1.tv_sec - t0.tv_sec) * 1000000LL
         + (long long) (t1.tv_usec - t0.tv_usec);
    fprintf(stdout, "IMDCT size of %d, %s, running %d times, average %.3f ms\n", 
            M, precision, r, (double) elps / r / 1000.0); 
#endif // 1

#if 0   /* debug dump of both buffers */
    for(k = 0; k < n_time; k++)
        fprintf(stdout, "%f    ", time_sig[k]);
    fprintf(stdout, "\n");

    for(k = 0; k < (size_t) M; k++)
        fprintf(stdout, "%f    ", freq[k]);
    fprintf(stdout, "\n");
#endif // 0

    status = 0;

cleanup:
    free(time_sig);
    free(freq);
    mdct_free(m_plan);

    return status;
}
/******** end of mdct_test.c ******** */