1 #ifndef VTESTBED_CORE_FOURIER_TRANSFORM_HH 2 #define VTESTBED_CORE_FOURIER_TRANSFORM_HH 8 #include <vtestbed/base/blitz.hh> 9 #include <vtestbed/base/for_loop.hh> 10 #include <vtestbed/core/math.hh> 20 typedef blitz::Array<int,1> indices_type;
21 typedef blitz::Array<T,1> array_type;
25 indices_type _indices;
32 this->_indices.free();
42 return this->_indices;
56 typedef blitz::Array<T,1> array_type;
62 if (n < 1 || !blitz::is_power_of_two(n)) {
65 auto& indices = this->_indices;
66 auto& waves = this->_waves;
69 indices.resize(indices_length_real(n));
70 this->makewt(nw, indices.data(), waves.data());
71 this->makect(nw, indices.data(), waves.data() + nw);
75 fourier_transform(array_type& x,
int sign) {
76 auto& indices = this->_indices;
77 auto& waves = this->_waves;
91 indices_length_real(
int n) {
92 return 2 + static_cast<int>(std::sqrt(n/2));
96 makewt(
int nw,
int *ip, T *w);
99 makect(
int nc,
int *ip, T *c);
102 rdft(
int n,
int isgn, T *a,
int *ip,
const T *w);
107 class Wave_table<std::complex<T>>:
108 public Wave_table_base<std::complex<T>> {
113 typedef blitz::Array<std::complex<T>,1> array_type;
119 if (n < 1 || !blitz::is_power_of_two(n)) {
122 auto& indices = this->_indices;
123 auto& waves = this->_waves;
126 indices.resize(indices_length_complex(n));
127 this->makewt(nw, indices.data(), from_complex(waves.data()));
131 fourier_transform(array_type& x,
int sign) {
132 auto& indices = this->_indices;
133 auto& waves = this->_waves;
137 from_complex(x.data()),
139 from_complex(waves.data())
146 indices_length_complex(
int n) {
147 return 2 + static_cast<int>(std::sqrt(n));
152 return reinterpret_cast<T*>(x);
157 return reinterpret_cast<const T*>(x);
161 makewt(
int nw,
int *ip, T *w);
164 cdft(
int n,
int isgn, T* a,
int* ip,
const T* w);
168 template <
class T,
int N>
171 blitz::Array<T,N>& x,
173 Wave_table<T> waves[N]
175 const int nelems = x.numElements();
176 for (
int i=0; i<N; ++i) {
177 const int stride = x.stride(i);
178 const int extent = x.extent(i);
179 const int block_size = extent*stride;
180 const int nblocks = nelems / block_size;
181 #if defined(VTB_DEBUG_FFT) 184 <<
",bs=" << block_size
185 <<
",nblocks=" << nblocks
188 for (
int k=0; k<nblocks; ++k) {
189 for (
int j=0; j<stride; ++j) {
190 const int offset = block_size*k + j;
191 #if defined(VTB_DEBUG_FFT) 193 <<
",stride=" << stride
194 <<
",offset=" << offset
195 <<
",check=" << (offset + stride*(extent-1))
198 blitz::Array<T,1> slice(blitz::shape(extent));
199 for (
int idx=0; idx<extent; ++idx) {
200 slice(idx) = *(x.data() + offset + idx*stride);
202 waves[i].fourier_transform(slice, sign);
203 for (
int idx=0; idx<extent; ++idx) {
204 *(x.data() + offset + idx*stride) = slice(idx);
219 template <
class T,
int N>
223 typedef blitz::TinyVector<int,N> shape_type;
224 typedef blitz::Array<T,N> array_type;
246 inline const shape_type&
247 shape()
const noexcept {
260 for (
int i=0; i<N; ++i) {
261 if (this->_shape(i) != rhs(i)) {
262 this->_shape(i) = rhs(i);
263 this->_waves[i].shape(rhs(i));
271 for (
int i=0; i<N; ++i) {
272 this->_waves[i].free();
279 check(
const array_type& x) {
280 if (!blitz::all(x.shape() == this->shape())) {
292 template <
class T,
int N>
299 using typename base_type::shape_type;
300 using typename base_type::array_type;
313 forward(array_type& x) {
314 this->transform(x, -1);
318 backward(array_type& x) {
319 this->transform(x, 1);
323 transform(array_type& x,
int dir) {
324 #if defined(VTB_DEBUG) 327 fourier_transform<T,N>(x, dir, this->_waves);
332 template <
class T,
int N>
336 typedef blitz::TinyVector<int,N> shape_type;
337 typedef blitz::Array<T,N> array_type;
341 typedef blitz::TinyVector<T,N> vec;
342 typedef blitz::RectDomain<N> domain;
343 typedef typename T::value_type R;
347 shape_type _shape{0};
348 bool _power_of_two =
false;
360 inline const shape_type&
361 shape()
const noexcept {
365 inline const shape_type&
366 fourier_transform_shape()
const noexcept {
367 return this->_fft.shape();
371 shape(
const shape_type& rhs) {
373 if (all(this->_shape == rhs)) {
377 shape_type new_shape;
378 if (all(is_power_of_two(rhs))) {
380 _power_of_two =
true;
382 new_shape = next_power_of_two(rhs*2 - 1);
383 _power_of_two =
false;
385 _fft.shape(new_shape);
390 forward(array_type& x) {
391 this->transform(x, -1);
395 backward(array_type& x) {
396 this->transform(x, 1);
400 transform(array_type& x,
int dir) {
401 using blitz::product;
403 _fft.transform(x, dir);
405 #if defined(VTB_DEBUG_CHIRP_Z) 406 std::clog <<
"chirp=" << _chirp << std::endl;
408 array_type xp{_fft.shape()};
410 shape_type offset{_shape-1};
411 domain rect{offset, _chirp.shape()-1};
412 domain lrect{xp.base(), offset};
413 #if defined(VTB_DEBUG_CHIRP_Z) 414 std::clog <<
"offset=" << offset << std::endl;
417 xp(lrect) = x*_chirp(rect);
420 xp(lrect) = x*blitz::conj(_chirp(rect));
422 #if defined(VTB_DEBUG_CHIRP_Z) 425 array_type ichirp{_fft.shape()};
428 ichirp(_chirp.domain()) = R{1} / _chirp;
430 ichirp(_chirp.domain()) = R{1} / blitz::conj(_chirp);
432 #if defined(VTB_DEBUG_CHIRP_Z) 433 std::clog <<
"ichirp=" << ichirp << std::endl;
435 _fft.transform(xp, dir);
436 #if defined(VTB_DEBUG_CHIRP_Z) 437 std::clog <<
"fft(xp)=" << xp << std::endl;
439 _fft.transform(ichirp, dir);
440 #if defined(VTB_DEBUG_CHIRP_Z) 441 std::clog <<
"fft(ichirp)=" << ichirp << std::endl;
444 #if defined(VTB_DEBUG_CHIRP_Z) 445 std::clog <<
"mult2=" << ichirp << std::endl;
447 _fft.transform(ichirp, -dir);
448 #if defined(VTB_DEBUG_CHIRP_Z) 449 std::clog <<
"ifft=" << ichirp << std::endl;
451 x = ichirp(rect) * _chirp(rect) / R(product(_fft.shape()));
462 using blitz::product;
463 const R pi2 = R{M_PI}*2;
464 shape_type chirp_shape{_shape*2-1};
465 _chirp.resize(chirp_shape);
468 vec w = exp(d*i*pi2/vec{_shape});
469 #if defined(VTB_DEBUG_CHIRP_Z) 472 shape_type offset{_shape-1};
475 [&offset,
this,&w] (
const shape_type& idx) {
476 vec k = idx - offset;
477 _chirp(idx) = product(pow(w, pow2(k)*T{0.5,0}));
488 #endif // vim:filetype=cpp