require "narray"  # This line is needed for rake test when making a gem package.
require "numru/fftw3"
require "test/unit"
include NumRu

class FFTW3_R2R_Test < Test::Unit::TestCase
  # Round-trip tests for the real-to-real (DCT/DST) transforms exposed by
  # NumRu::FFTW3.fft_r2r.
  #
  # The r2r transforms are unnormalized: a forward transform followed by its
  # inverse multiplies the data by the logical size N of every transformed
  # dimension, where N = 2(n-1) for REDFT00, N = 2(n+1) for RODFT00 and
  # N = 2n for the other kinds used here. Each round trip below therefore
  # scales the spectrum by 1/N per transformed dimension before inverting.

  def setup
    @eps = 1e-10   # tolerance for double-precision (NArray.float) data
    @seps = 1e-5   # looser tolerance for single-precision (NArray.sfloat) data
  end

  def test_r2r_all_dims
    nx = 8
    ny = 4
    na = NArray.float(nx, ny).indgen!

    # cosine trans at 0, 1, 2,... (REDFT00 is its own inverse)
    assert_r2r_roundtrip(na, FFTW3::REDFT00, FFTW3::REDFT00,
                         1.0 / (2*(nx-1)) / (2*(ny-1)), @eps, "REDFT00 -> REDFT00")

    # cosine trans at 1/2, 1+1/2,... (REDFT11 is its own inverse)
    assert_r2r_roundtrip(na, FFTW3::REDFT11, FFTW3::REDFT11,
                         1.0 / (2*nx) / (2*ny), @eps, "REDFT11 -> REDFT11")

    # REDFT01 and REDFT10 are inverses of each other
    assert_r2r_roundtrip(na, FFTW3::REDFT01, FFTW3::REDFT10,
                         1.0 / (2*nx) / (2*ny), @eps, "REDFT01 -> REDFT10")

    # sine trans at 1, 2,... (RODFT00 is its own inverse)
    assert_r2r_roundtrip(na, FFTW3::RODFT00, FFTW3::RODFT00,
                         1.0 / (2*(nx+1)) / (2*(ny+1)), @eps, "RODFT00 -> RODFT00")

    # sine trans at 1/2, 1+1/2,... (RODFT11 is its own inverse)
    assert_r2r_roundtrip(na, FFTW3::RODFT11, FFTW3::RODFT11,
                         1.0 / (2*nx) / (2*ny), @eps, "RODFT11 -> RODFT11")

    # RODFT01 and RODFT10 are inverses of each other
    assert_r2r_roundtrip(na, FFTW3::RODFT01, FFTW3::RODFT10,
                         1.0 / (2*nx) / (2*ny), @eps, "RODFT01 -> RODFT10")
  end

  def test_r2r_single
    nx = 8
    ny = 4
    na = NArray.sfloat(nx, ny).indgen!

    # Single precision: only REDFT00 is exercised, with the looser tolerance.
    assert_r2r_roundtrip(na, FFTW3::REDFT00, FFTW3::REDFT00,
                         1.0 / (2*(nx-1)) / (2*(ny-1)), @seps,
                         "REDFT00 -> REDFT00 (sfloat)")
  end

  def test_r2r_some_dims
    nx = 8
    ny = 4
    na = NArray.float(nx, ny).indgen!

    # REDFT00 along dim 0 only; applying it twice scales by 2(nx-1).
    fc_dim0 = FFTW3.fft_r2r(na, FFTW3::REDFT00, 0)
    nb = FFTW3.fft_r2r(fc_dim0, FFTW3::REDFT00, 0) / (2*(nx-1))
    assert_max_abs_diff(na, nb, @eps, "REDFT00 round trip along dim 0")

    # Transforming dim 1 of the dim-0 result with RODFT11 must equal a single
    # call with per-dimension kinds, whether the dims are implicit or explicit.
    fc_seq = FFTW3.fft_r2r(fc_dim0, FFTW3::RODFT11, 1)
    fc_all = FFTW3.fft_r2r(na, [FFTW3::REDFT00, FFTW3::RODFT11])
    fc_dims = FFTW3.fft_r2r(na, [FFTW3::REDFT00, FFTW3::RODFT11], 0, 1)
    assert_max_abs_diff(fc_seq, fc_all, @eps,
                        "per-dim kinds (all dims) vs. sequential transforms")
    assert_max_abs_diff(fc_seq, fc_dims, @eps,
                        "per-dim kinds (dims 0, 1) vs. sequential transforms")
  end

  private

  # Transforms every dimension of +na+ with +forward+, multiplies the result
  # by +scale+ (the reciprocal of the product of logical sizes), transforms
  # back with +backward+, and asserts the original data is recovered to
  # within +eps+ (maximum absolute error). +label+ identifies the transform
  # pair in the failure message.
  def assert_r2r_roundtrip(na, forward, backward, scale, eps, label = nil)
    fc = FFTW3.fft_r2r(na, forward)
    nb = FFTW3.fft_r2r(fc*scale, backward)
    assert_max_abs_diff(na, nb, eps, label)
  end

  # Asserts max(|expected - actual|) < eps. Unlike a bare assert, a failure
  # reports the measured error together with +message+.
  def assert_max_abs_diff(expected, actual, eps, message = nil)
    err = (expected - actual).abs.max
    assert_operator(err, :<, eps, message)
  end
end
