Virtual Testbed
Ship dynamics simulator for extreme conditions
core/fourier_transform.hh
1 #ifndef VTESTBED_CORE_FOURIER_TRANSFORM_HH
2 #define VTESTBED_CORE_FOURIER_TRANSFORM_HH
3 
4 #include <complex>
5 #include <limits>
6 #include <stdexcept>
7 
8 #include <vtestbed/base/blitz.hh>
9 #include <vtestbed/base/for_loop.hh>
10 #include <vtestbed/core/math.hh>
11 
12 namespace vtb {
13 
14  namespace core {
15 
16  template <class T>
18 
19  public:
20  typedef blitz::Array<int,1> indices_type;
21  typedef blitz::Array<T,1> array_type;
22 
23  protected:
24  array_type _waves;
25  indices_type _indices;
26 
27  public:
28 
29  inline void
30  free() {
31  this->_waves.free();
32  this->_indices.free();
33  }
34 
35  inline array_type
36  waves() const {
37  return this->_waves;
38  }
39 
40  inline indices_type
41  indices() const {
42  return this->_indices;
43  }
44 
45  };
46 
47  template <class T>
48  class Wave_table;
49 
50  template <class T>
51  class Wave_table: public Wave_table_base<T> {
52 
53  static_assert(std::is_floating_point<T>::value, "bad type");
54 
55  public:
56  typedef blitz::Array<T,1> array_type;
57 
58  public:
59 
60  void
61  shape(int n) {
62  if (n < 1 || !blitz::is_power_of_two(n)) {
63  throw std::invalid_argument(__func__);
64  }
65  auto& indices = this->_indices;
66  auto& waves = this->_waves;
67  const int nw = n/4;
68  waves.resize(2*nw);
69  indices.resize(indices_length_real(n));
70  this->makewt(nw, indices.data(), waves.data());
71  this->makect(nw, indices.data(), waves.data() + nw);
72  }
73 
74  void
75  fourier_transform(array_type& x, int sign) {
76  auto& indices = this->_indices;
77  auto& waves = this->_waves;
78  // TODO the sign???
79  this->rdft(
80  x.extent(0),
81  -sign,
82  x.data(),
83  indices.data(),
84  waves.data()
85  );
86  }
87 
88  private:
89 
90  inline int
91  indices_length_real(int n) {
92  return 2 + static_cast<int>(std::sqrt(n/2));
93  }
94 
95  static void
96  makewt(int nw, int *ip, T *w);
97 
98  static void
99  makect(int nc, int *ip, T *c);
100 
101  static void
102  rdft(int n, int isgn, T *a, int *ip, const T *w);
103 
104  };
105 
106  template <class T>
107  class Wave_table<std::complex<T>>:
108  public Wave_table_base<std::complex<T>> {
109 
110  static_assert(std::is_floating_point<T>::value, "bad type");
111 
112  public:
113  typedef blitz::Array<std::complex<T>,1> array_type;
114 
115  public:
116 
117  void
118  shape(int n) {
119  if (n < 1 || !blitz::is_power_of_two(n)) {
120  throw std::invalid_argument(__func__);
121  }
122  auto& indices = this->_indices;
123  auto& waves = this->_waves;
124  const int nw = n/2;
125  waves.resize(nw);
126  indices.resize(indices_length_complex(n));
127  this->makewt(nw, indices.data(), from_complex(waves.data()));
128  }
129 
130  void
131  fourier_transform(array_type& x, int sign) {
132  auto& indices = this->_indices;
133  auto& waves = this->_waves;
134  this->cdft(
135  2*x.extent(0),
136  sign,
137  from_complex(x.data()),
138  indices.data(),
139  from_complex(waves.data())
140  );
141  }
142 
143  private:
144 
145  inline int
146  indices_length_complex(int n) {
147  return 2 + static_cast<int>(std::sqrt(n));
148  }
149 
150  inline T*
151  from_complex(std::complex<T>* x) {
152  return reinterpret_cast<T*>(x);
153  }
154 
155  inline const T*
156  from_complex(const std::complex<T>* x) {
157  return reinterpret_cast<const T*>(x);
158  }
159 
160  static void
161  makewt(int nw, int *ip, T *w);
162 
163  static void
164  cdft(int n, int isgn, T* a, int* ip, const T* w);
165 
166  };
167 
168  template <class T, int N>
169  inline void
170  fourier_transform(
171  blitz::Array<T,N>& x,
172  int sign,
173  Wave_table<T> waves[N]
174  ) {
175  const int nelems = x.numElements();
176  for (int i=0; i<N; ++i) {
177  const int stride = x.stride(i);
178  const int extent = x.extent(i);
179  const int block_size = extent*stride;
180  const int nblocks = nelems / block_size;
181  #if defined(VTB_DEBUG_FFT)
182  std::clog
183  << "n=" << i
184  << ",bs=" << block_size
185  << ",nblocks=" << nblocks
186  << std::endl;
187  #endif
188  for (int k=0; k<nblocks; ++k) {
189  for (int j=0; j<stride; ++j) {
190  const int offset = block_size*k + j;
191  #if defined(VTB_DEBUG_FFT)
192  std::clog << "FFT: extent=" << extent
193  << ",stride=" << stride
194  << ",offset=" << offset
195  << ",check=" << (offset + stride*(extent-1))
196  << std::endl;
197  #endif
198  blitz::Array<T,1> slice(blitz::shape(extent));
199  for (int idx=0; idx<extent; ++idx) {
200  slice(idx) = *(x.data() + offset + idx*stride);
201  }
202  waves[i].fourier_transform(slice, sign);
203  for (int idx=0; idx<extent; ++idx) {
204  *(x.data() + offset + idx*stride) = slice(idx);
205  }
206  }
207  }
208  }
209  }
210 
211 
219  template <class T, int N>
221 
222  public:
223  typedef blitz::TinyVector<int,N> shape_type;
224  typedef blitz::Array<T,N> array_type;
226 
227  private:
228  shape_type _shape;
229 
230  protected:
231  wavetable_type _waves[N];
232 
233  public:
234 
235  inline explicit
236  Fourier_transform_base(const shape_type& shp) {
237  this->_shape = 0;
238  this->shape(shp);
239  }
240 
241  inline
243  this->_shape = 0;
244  }
245 
246  inline const shape_type&
247  shape() const noexcept {
248  return this->_shape;
249  }
250 
258  inline void
259  shape(const shape_type& rhs) {
260  for (int i=0; i<N; ++i) {
261  if (this->_shape(i) != rhs(i)) {
262  this->_shape(i) = rhs(i);
263  this->_waves[i].shape(rhs(i));
264  }
265  }
266  }
267 
268  inline void
269  clear() {
270  this->_shape = 0;
271  for (int i=0; i<N; ++i) {
272  this->_waves[i].free();
273  }
274  }
275 
276  protected:
277 
278  inline void
279  check(const array_type& x) {
280  if (!blitz::all(x.shape() == this->shape())) {
281  throw std::invalid_argument("Fourier_transform_base::check: bad shape");
282  }
283  }
284 
285  };
286 
292  template <class T, int N>
294 
295  private:
297 
298  public:
299  using typename base_type::shape_type;
300  using typename base_type::array_type;
301 
302  public:
303 
304  inline
306  base_type{} {}
307 
308  inline explicit
309  Fourier_transform(const shape_type& shp):
310  base_type{shp} {}
311 
312  inline void
313  forward(array_type& x) {
314  this->transform(x, -1);
315  }
316 
317  inline void
318  backward(array_type& x) {
319  this->transform(x, 1);
320  }
321 
322  inline void
323  transform(array_type& x, int dir) {
324  #if defined(VTB_DEBUG)
325  this->check(x);
326  #endif
327  fourier_transform<T,N>(x, dir, this->_waves);
328  }
329 
330  };
331 
332  template <class T, int N>
334 
335  public:
336  typedef blitz::TinyVector<int,N> shape_type;
337  typedef blitz::Array<T,N> array_type;
339 
340  private:
341  typedef blitz::TinyVector<T,N> vec;
342  typedef blitz::RectDomain<N> domain;
343  typedef typename T::value_type R;
344 
345  private:
346  fft_type _fft;
347  shape_type _shape{0};
348  bool _power_of_two = false;
349  array_type _chirp;
350 
351  public:
352 
353  Chirp_Z_transform() = default;
354 
355  inline explicit
356  Chirp_Z_transform(const shape_type& shp) {
357  this->shape(shp);
358  }
359 
360  inline const shape_type&
361  shape() const noexcept {
362  return this->_shape;
363  }
364 
365  inline const shape_type&
366  fourier_transform_shape() const noexcept {
367  return this->_fft.shape();
368  }
369 
370  inline void
371  shape(const shape_type& rhs) {
372  using blitz::all;
373  if (all(this->_shape == rhs)) {
374  return;
375  }
376  this->_shape = rhs;
377  shape_type new_shape;
378  if (all(is_power_of_two(rhs))) {
379  new_shape = rhs;
380  _power_of_two = true;
381  } else {
382  new_shape = next_power_of_two(rhs*2 - 1);
383  _power_of_two = false;
384  }
385  _fft.shape(new_shape);
386  this->make_chirp();
387  }
388 
389  inline void
390  forward(array_type& x) {
391  this->transform(x, -1);
392  }
393 
394  inline void
395  backward(array_type& x) {
396  this->transform(x, 1);
397  }
398 
399  inline void
400  transform(array_type& x, int dir) {
401  using blitz::product;
402  if (_power_of_two) {
403  _fft.transform(x, dir);
404  } else {
405  #if defined(VTB_DEBUG_CHIRP_Z)
406  std::clog << "chirp=" << _chirp << std::endl;
407  #endif
408  array_type xp{_fft.shape()};
409  xp = 0;
410  shape_type offset{_shape-1};
411  domain rect{offset, _chirp.shape()-1};
412  domain lrect{xp.base(), offset};
413  #if defined(VTB_DEBUG_CHIRP_Z)
414  std::clog << "offset=" << offset << std::endl;
415  #endif
416  if (dir < 0) {
417  xp(lrect) = x*_chirp(rect);
418  } else {
419  // flip the sign of the imaginary part
420  xp(lrect) = x*blitz::conj(_chirp(rect));
421  }
422  #if defined(VTB_DEBUG_CHIRP_Z)
423  std::clog << "xp=" << xp << std::endl;
424  #endif
425  array_type ichirp{_fft.shape()};
426  ichirp = 0;
427  if (dir < 0) {
428  ichirp(_chirp.domain()) = R{1} / _chirp;
429  } else {
430  ichirp(_chirp.domain()) = R{1} / blitz::conj(_chirp);
431  }
432  #if defined(VTB_DEBUG_CHIRP_Z)
433  std::clog << "ichirp=" << ichirp << std::endl;
434  #endif
435  _fft.transform(xp, dir);
436  #if defined(VTB_DEBUG_CHIRP_Z)
437  std::clog << "fft(xp)=" << xp << std::endl;
438  #endif
439  _fft.transform(ichirp, dir);
440  #if defined(VTB_DEBUG_CHIRP_Z)
441  std::clog << "fft(ichirp)=" << ichirp << std::endl;
442  #endif
443  ichirp *= xp;
444  #if defined(VTB_DEBUG_CHIRP_Z)
445  std::clog << "mult2=" << ichirp << std::endl;
446  #endif
447  _fft.transform(ichirp, -dir);
448  #if defined(VTB_DEBUG_CHIRP_Z)
449  std::clog << "ifft=" << ichirp << std::endl;
450  #endif
451  x = ichirp(rect) * _chirp(rect) / R(product(_fft.shape()));
452  }
453  }
454 
455  private:
456 
/**
\brief Fill _chirp, resized to 2*_shape-1, with the chirp used by transform().

For each dimension d, w_d = exp(-2*pi*i/_shape(d)). Element idx of _chirp
is the product over d of pow(w_d, k_d*k_d/2), where k = idx - (_shape-1).
Each k_d therefore runs from -(_shape(d)-1) to _shape(d)-1, and index
_shape-1 corresponds to k = 0.
NOTE(review): mathematically this is exp(-i*pi*k_d^2/_shape(d)), up to the
branch and rounding behaviour of the complex pow. pow goes through
log(w_d), so the phase error grows with k_d^2. Evaluating the exponent
directly, reducing k_d^2 modulo 2*_shape(d), would be more accurate, but
results would change in the last bits.
NOTE(review): M_PI is not standard C++; some toolchains need
_USE_MATH_DEFINES for it.
*/
457  void
458  make_chirp() {
459  using blitz::exp;
460  using blitz::pow2;
461  using blitz::pow;
462  using blitz::product;
463  const R pi2 = R{M_PI}*2;
464  shape_type chirp_shape{_shape*2-1};
465  _chirp.resize(chirp_shape);
466  T i{0,1};
467  R d{-1};
// w = exp(-2*pi*i/_shape), element-wise per dimension
468  vec w = exp(d*i*pi2/vec{_shape});
469 #if defined(VTB_DEBUG_CHIRP_Z)
470  std::clog << "w=" << w << std::endl;
471 #endif
// shifts the index so that idx == offset corresponds to k == 0
472  shape_type offset{_shape-1};
473  for_loop<N>(
474  chirp_shape,
475  [&offset,this,&w] (const shape_type& idx) {
476  vec k = idx - offset;
// product over dimensions of w_d^(k_d^2 / 2)
477  _chirp(idx) = product(pow(w, pow2(k)*T{0.5,0}));
478  }
479  );
480  }
481 
482  };
483 
484  }
485 
486 }
487 
488 #endif // vim:filetype=cpp
Main namespace.
Definition: convert.hh:9
void shape(const shape_type &rhs)
Update wave table for specified shape.
Base class for all Fourier transforms.
$N$-dimensional Fourier transform.