import Numeric, fftpack, copy

# Work arrays from the FFTPACK init routines, keyed by transform length n and
# filled on first use by _raw_fft.  The complex routines (cffti) and the real
# routines (rffti) produce different work arrays, so each has its own cache.
_fft_cache = dict()
_real_fft_cache = dict()

def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti,
		work_function=fftpack.cfftf, fft_cache=_fft_cache):
	"""Run one FFTPACK transform of `a` along `axis` and return the result.

	a             -- array or sequence; converted with Numeric.asarray.
	n             -- transform length; defaults to a.shape[axis].  The input
	                 is truncated or zero-padded along `axis` to length n.
	axis          -- axis to transform (default: the last one).
	init_function -- builds the work array for length n (e.g. fftpack.cffti).
	work_function -- performs the transform; called as
	                 work_function(a, wsave) with the transform axis last.
	fft_cache     -- dict mapping n to the work array returned by
	                 init_function, so it is built only once per length.

	Raises ValueError if n < 1.
	"""
	a = Numeric.asarray(a)

	if n is None:
		n = a.shape[axis]
	if n < 1:
		raise ValueError("Invalid number of FFT data points (%d) specified." % n)

	try:
		wsave = fft_cache[n]
	except KeyError:
		wsave = init_function(n)
		fft_cache[n] = wsave

	if a.shape[axis] != n:
		s = list(a.shape)
		if s[axis] > n:
			# Truncate: keep only the first n points along `axis`.
			index = [slice(None)] * len(s)
			index[axis] = slice(0, n)
			# Index with a tuple: a list of slices is not reliably treated
			# as a multi-axis index by every array implementation.
			a = a[tuple(index)]
		else:
			# Zero-pad: append n - s[axis] zeros (same typecode) along `axis`.
			s[axis] = n - s[axis]
			z = Numeric.zeros(s, a.typecode())
			a = Numeric.concatenate((a, z), axis=axis)

	# work_function operates on the last axis, so move `axis` there and back.
	if axis != -1:
		a = Numeric.swapaxes(a, axis, -1)
	r = work_function(a, wsave)
	if axis != -1:
		r = Numeric.swapaxes(r, axis, -1)
	return r

def fft(a, n=None, axis=-1):
	"""Return the complex discrete Fourier transform of `a` along `axis`.

	`n` is the transform length (default: a.shape[axis]); the input is
	truncated or zero-padded along `axis` to that length by _raw_fft.
	"""
	return _raw_fft(a, n, axis,
			init_function=fftpack.cffti,
			work_function=fftpack.cfftf,
			fft_cache=_fft_cache)

def inverse_fft(a, n=None, axis=-1):
	"""Return the inverse complex discrete Fourier transform of `a`.

	Runs the backward transform fftpack.cfftb along `axis` and divides the
	result by n.  `n` is the transform length and defaults to the length
	of `a` along `axis`.
	"""
	if n is None:
		n = Numeric.shape(a)[axis]
	return _raw_fft(a, n, axis, fftpack.cffti, fftpack.cfftb, _fft_cache) / n

def real_fft(a, n=None, axis=-1):
	"""Return the discrete Fourier transform of the real data `a`.

	`a` may be any array or sequence (like fft); it is converted to a
	Float array first.  The transform uses fftpack.rfftf along `axis` with
	length n (default: a.shape[axis]).
	"""
	# asarray first: sequences such as lists have no .astype method.
	a = Numeric.asarray(a).astype(Numeric.Float)
	# NOTE(review): the output layout is whatever fftpack.rfftf returns
	# (presumably the n//2+1 complex non-negative-frequency terms) - confirm.
	return _raw_fft(a, n, axis, fftpack.rffti, fftpack.rfftf, _real_fft_cache)

def inverse_real_fft(a, n=None, axis=-1):
	"""Return the real inverse transform of real_fft output `a`, divided by n.

	`a` is the complex half-spectrum and may be any array or sequence; it is
	cast to Complex so its imaginary parts are kept.  `n` is the length of
	the real result.  It defaults to 2*(m-1), where m = a.shape[axis], which
	inverts real_fft of an even-length signal; pass n explicitly for odd
	lengths.
	"""
	# NOTE(review): assumes fftpack.rfftf returns n//2+1 complex terms and
	# fftpack.rfftb consumes them (as in later Numeric / NumPy) - confirm.
	a = Numeric.asarray(a).astype(Numeric.Complex)
	if n is None:
		n = (Numeric.shape(a)[axis] - 1) * 2
	return _raw_fft(a, n, axis, fftpack.rffti, fftpack.rfftb, _real_fft_cache) / n

def _raw_fft2d(a, s=None, axes=(-2,-1), function=fft):
	"""Apply the 1-D transform `function` along two axes of `a`.

	The transform runs first along axes[1] with length s[1], then along
	axes[0] with length s[0].  `s` defaults to the lengths of `a` along
	`axes`.
	"""
	a = Numeric.asarray(a)
	if s is None:
		# Read the default lengths from the requested axes.  Reading them
		# from the last two axes would give wrong sizes for non-default axes.
		s = (a.shape[axes[0]], a.shape[axes[1]])
	f1 = function(a, n=s[1], axis=axes[1])
	return function(f1, n=s[0], axis=axes[0])

def fft2d(a, s=None, axes=(-2,-1)):
	"""Return the 2-D complex discrete Fourier transform of `a` over `axes`.

	`s`, if given, holds the transform lengths along axes[0] and axes[1].
	"""
	return _raw_fft2d(a, s=s, axes=axes, function=fft)

def inverse_fft2d(a, s=None, axes=(-2,-1)):
	"""Return the 2-D inverse complex discrete Fourier transform of `a`.

	Applies inverse_fft along axes[1] and then axes[0]; `s`, if given,
	holds the transform lengths along axes[0] and axes[1].
	"""
	return _raw_fft2d(a, s=s, axes=axes, function=inverse_fft)

def real_fft2d(a, s=None, axes=(-2,-1)):
	"""Return the 2-D discrete Fourier transform of the real array `a`.

	real_fft is applied along axes[1] with length s[1].  The remaining pass
	along axes[0] uses the complex fft with length s[0]: real_fft casts its
	input to Float, so using it again would discard the imaginary parts of
	the intermediate result.  `s` defaults to the lengths of `a` along
	`axes`.
	"""
	# NOTE(review): assumes fftpack.rfftf returns a complex array (as in later
	# Numeric releases and numpy.fft.rfft2) - confirm.
	a = Numeric.asarray(a).astype(Numeric.Float)
	if s is None:
		s = (a.shape[axes[0]], a.shape[axes[1]])
	f1 = real_fft(a, n=s[1], axis=axes[1])
	return fft(f1, n=s[0], axis=axes[0])

def test():
	"""Print a few sample transforms as a quick manual smoke check."""
	signal = (0, 1) * 4
	print(fft(signal))
	print(inverse_fft(fft(signal)))
	print(fft(signal, n=16))
	print(fft(signal, n=4))

	grid = [(0, 1), (1, 0)]
	print(fft2d(grid))
	print(inverse_fft2d(fft2d(grid)))
	print(real_fft2d(grid))
	print(real_fft2d([(1, 1), (1, 1)]))

if __name__ == '__main__':
	# Running the module as a script prints the sample transforms.
	test()
