1 #include <vtestbed/base/blitz.hh> 2 #include <vtestbed/base/for_loop.hh> 3 #include <vtestbed/config/openmp.hh> 4 #include <vtestbed/config/real_type.hh> 5 #include <vtestbed/core/convolution.hh> 9 #if defined(VTB_WITH_OPENMP) 13 template <
class T,
int N>
// Chooses a block shape from the signal and kernel shapes and stores it
// through this->padded_block_shape(block_size, kernel_shape).
// NOTE(review): the line carrying the function name and return type is not
// visible in this view; only the two parameters below are.
16 const shape_type& signal_shape,
17 const shape_type& kernel_shape
21 using blitz::next_power_of_two;
// Degree of parallelism: omp_get_max_threads() when VTB_WITH_OPENMP is
// defined, otherwise the constant 2.
// NOTE(review): the #else/#endif lines separating the two definitions are
// not visible here — confirm.
22 #if defined(VTB_WITH_OPENMP) 23 const int parallelism = omp_get_max_threads();
25 const int parallelism = 2;
// First guess: four times the kernel shape, capped by the signal shape
// (presumably per dimension — confirm shape_type semantics).
27 shape_type guess1 =
min(signal_shape, 4*kernel_shape);
// Second guess: the signal divided among `parallelism` workers, but never
// smaller than the kernel shape.
28 shape_type guess2 =
max(kernel_shape, signal_shape / parallelism);
// min + |difference| / 2 is the midpoint between the two guesses.
29 shape_type block_size =
min(guess1, guess2) + abs(guess1 - guess2) / 2;
// Adjust so that block_size + kernel_shape equals the value returned by
// next_power_of_two (presumably a power of two, going by the name).
30 block_size = next_power_of_two(block_size + kernel_shape) - kernel_shape;
31 this->padded_block_shape(block_size, kernel_shape);
// Block-wise convolution of `signal` with `kernel`: each block of the signal
// is copied into a padded buffer, transformed with this->_fft, multiplied by
// the transformed kernel, transformed back, and accumulated into the output.
// NOTE(review): the signature line is not visible in this view; the
// parameter names `signal` and `kernel` are taken from their uses below.
34 template <
class T,
int N>
39 using blitz::div_ceil;
// Precondition check: the kernel must not exceed the block shape.
// The handling of a failed check is not visible in this view.
41 if (!all(kernel.shape() <= this->_blockshape)) {
// Copy the kernel into the low corner of a buffer of the padded block shape.
// NOTE(review): no explicit zero fill of the remaining elements is visible;
// confirm that array_type initialises its storage.
45 const shape_type padded_block = this->padded_block_shape();
46 domain_type orig_domain(shape_type(0), kernel.shape()-1);
47 array_type padded_kernel(padded_block);
48 padded_kernel(orig_domain) = kernel;
// Element count of the padded buffer; used below to scale the result of the
// backward transform.
49 const auto nelements = padded_kernel.numElements();
// Transform the kernel once; the result is reused for every block
// (padded_kernel is multiplied into each transformed part below).
51 this->_fft.forward(padded_kernel);
53 const shape_type bs = this->_blockshape;
54 const shape_type pad = this->_padding;
// Number of blocks needed to cover the signal (div_ceil: presumably
// division rounded up — confirm).
55 const shape_type nparts = div_ceil(signal.shape(), bs);
// Precondition check: the block shape must not exceed the signal shape.
// The handling of a failed check is not visible in this view.
56 if (!all(bs <= signal.shape())) {
// Output has the same shape as the input signal.
// NOTE(review): no initialisation of out_signal is visible here, yet it is
// accumulated into with += below — confirm it is zeroed.
59 array_type out_signal(signal.shape());
// Per-block work, invoked with the block index `idx`.
// NOTE(review): the loop construct that calls this lambda is not visible.
62 [&] (
const shape_type& idx) {
// Bounds of this block in signal coordinates; `to` is clamped to the last
// valid signal index so the final block may be smaller than `bs`.
65 const shape_type offset = idx*bs;
66 const shape_type from = offset;
67 const shape_type to = min(signal.shape(), offset+bs) - 1;
68 const domain_type part_domain(from, to);
// The same region shifted so that it starts at the origin of the buffer.
69 const domain_type dom_to(from-offset, to-offset);
// NOTE(review): as with padded_kernel, no zero fill of the padding is visible.
70 array_type padded_part(padded_block);
71 padded_part(dom_to) = signal(part_domain);
// Forward transform, element-wise product with the transformed kernel,
// backward transform, then scaling by the element count.
73 this->_fft.forward(padded_part);
75 padded_part *= padded_kernel;
77 this->_fft.backward(padded_part);
78 padded_part /= nelements;
// Destination region: the block extended by `pad` past its upper bound,
// clamped to the last valid signal index; padded_to is that region in
// buffer-local coordinates.
80 const domain_type padded_from(from, min(to+pad, signal.shape()-1));
81 const domain_type padded_to(from-offset, padded_from.ubound()-offset);
// Accumulate with += so that contributions of neighbouring blocks to the
// overlapping region are summed.
// NOTE(review): the lines guarded by VTB_WITH_OPENMP are not visible here.
82 #if defined(VTB_WITH_OPENMP) 86 out_signal(padded_from) += padded_part(padded_to);
Multidimensional convolution based on Fourier transform.
array_type convolve(array_type signal, array_type kernel)