Virtual Testbed
Ship dynamics simulator for extreme conditions
convolution.hh
1 #ifndef VTESTBED_CORE_CONVOLUTION_HH
2 #define VTESTBED_CORE_CONVOLUTION_HH
3 
4 #include <stdexcept>
5 
6 #include <vtestbed/base/blitz.hh>
7 #include <vtestbed/core/fourier_transform.hh>
8 
9 namespace vtb {
10 
11  namespace core {
12 
20  template <class T, int N>
21  class Convolution {
22 
23  public:
25  typedef typename transform_type::shape_type shape_type;
26  typedef typename transform_type::array_type array_type;
27  typedef blitz::RectDomain<N> domain_type;
28 
29  private:
30  shape_type _blockshape;
31  shape_type _padding;
32  transform_type _fft;
33 
34  public:
35 
36  Convolution() = default;
37 
38  inline explicit
40  const shape_type& signal_shape,
41  const shape_type& kernel_shape
42  ) {
43  this->shape(signal_shape, kernel_shape);
44  }
45 
46  void
47  shape(const shape_type& signal_shape, const shape_type& kernel_shape);
48 
49  inline void
50  padded_block_shape(const shape_type& block_size, const shape_type& padding) {
51  this->check(block_size, padding);
52  this->_blockshape = block_size;
53  this->_padding = padding;
54  this->_fft.shape(this->padded_block_shape());
55  }
56 
57  inline shape_type
58  padded_block_shape() const noexcept {
59  return this->_blockshape + this->_padding;
60  }
61 
62  inline const shape_type&
63  block_shape() const noexcept {
64  return this->_blockshape;
65  }
66 
67  inline const shape_type&
68  padding() const noexcept {
69  return this->_padding;
70  }
71 
72  inline array_type
73  operator()(array_type signal, array_type kernel) {
74  return this->convolve(signal, kernel);
75  }
76 
77  array_type
78  convolve(array_type signal, array_type kernel);
79 
80  private:
81 
82  static inline void
83  check(const shape_type& blocksize, const shape_type& padding) {
84  using blitz::all;
85  if (!all(padding >= 0)) {
86  throw std::length_error("bad padding");
87  }
88  if (!all(blocksize > 0)) {
89  throw std::length_error("bad block size");
90  }
91  if (!all(blocksize >= padding)) {
92  throw std::length_error("bad block size/padding ratio");
93  }
94  if (!all(is_power_of_two(blocksize + padding))) {
95  throw std::length_error("bad padded block size");
96  }
97  }
98 
99  shape_type
100  get_block_shape(
101  const shape_type& signal_shape,
102  const shape_type& kernel_shape
103  );
104 
105  };
106 
107  }
108 
109 }
110 
111 #endif // vim:filetype=cpp
Main namespace.
Definition: convert.hh:9
Multidimensional convolution based on Fourier transform.
Definition: convolution.hh:21
array_type convolve(array_type signal, array_type kernel)
Definition: convolution.cc:36
N-dimensional Fourier transform.