1 #ifndef VTESTBED_OPENCL_FOURIER_TRANSFORM_HH 2 #define VTESTBED_OPENCL_FOURIER_TRANSFORM_HH 9 #include <openclx/forward> 11 #include <vtestbed/base/blitz.hh> 12 #include <vtestbed/opencl/opencl.hh> 19 inline blitz::TinyVector<int,3>
20 make_shape(
const blitz::TinyVector<int,N>& rhs) {
21 static_assert(0 < N && N <= 3,
"bad N");
22 blitz::TinyVector<int,3> result;
24 for (
int i=0; i<N; ++i) {
31 inline blitz::TinyVector<int,3>
32 make_shape<3>(
const blitz::TinyVector<int,3>& rhs) {
37 inline blitz::TinyVector<int,N>
38 reduce_shape(
const blitz::TinyVector<int,3>& rhs) {
39 static_assert(0 < N && N <= 3,
"bad N");
40 blitz::TinyVector<int,N> result;
41 for (
int i=0; i<N; ++i) {
48 inline blitz::TinyVector<int,3>
49 reduce_shape<3>(
const blitz::TinyVector<int,3>& rhs) {
53 enum class Fourier_transform_format {
55 Interleaved_complex = 1
62 int num_workgroups = 0;
63 int num_xforms_per_workgroup = 0;
64 int num_workitems_per_workgroup = 0;
66 bool in_place_possible = 0;
72 typedef blitz::TinyVector<int,3> int3;
73 typedef blitz::TinyVector<int,2> int2;
74 typedef blitz::TinyVector<int,1> int1;
80 Fourier_transform_format _format = Fourier_transform_format::Interleaved_complex;
91 bool temp_buffer_needed =
false;
100 int last_batch_size = 0;
103 clx::buffer _workarea{
nullptr};
108 int max_localmem_fft_size = 2048;
112 int _maxworkgroupsize = 256;
121 int min_mem_coalesce_width = 16;
126 int num_local_mem_banks = 16;
139 enqueue(clx::buffer x,
int direction,
int batch_size=1);
142 forward(clx::buffer x,
int batch_size=1) {
143 this->enqueue(x, -1, batch_size);
147 backward(clx::buffer x,
int batch_size=1) {
148 this->enqueue(x, 1, batch_size);
155 shape()
const noexcept {
160 shape(
const int3& rhs) {
161 if (!blitz::all(this->_shape == rhs)) {
170 inline void context(
Context* rhs) { this->_context = rhs; }
171 inline Context* context() {
return this->_context; }
176 unique_kernel_index() {
177 return ++this->_kindex;
181 kernel_name(
const char* prefix);
184 buffer_size(
int batch_size)
const {
185 return blitz::product(this->_shape) *
186 batch_size * 2 *
sizeof(cl_float);
190 generate_source_code();
193 generate_fft(
int axis);
196 generate_fft_local();
199 generate_fft_global(
int n,
int BS,
int axis,
int vertBS);
202 getKernelWorkDimensions(
210 allocate_temporary_buffer(
int batch_size);
217 template <
class T,
int N>
221 static_assert(0 < N && N <= 3,
"bad N");
227 using shape_type = blitz::TinyVector<int,N>;
228 using Fourier_transform_base::context;
239 shape()
const noexcept {
240 return reduce_shape<N>(this->base_type::shape());
244 shape(
const shape_type& rhs) {
245 this->base_type::shape(make_shape<N>(rhs));
249 precompile(
const shape_type& shp,
Context* context) {
254 forward(clx::buffer x,
int batch_size=1) {
255 this->base_type::forward(x, batch_size);
259 backward(clx::buffer x,
int batch_size=1) {
260 this->base_type::backward(x, batch_size);
263 using base_type::dump;
270 using int3 = blitz::TinyVector<int,3>;
280 clx::kernel _makechirp, _reciprocal_chirp, _mult1, _mult2, _mult3, _zero_init;
285 enqueue(clx::buffer x,
int direction,
int batch_size=1);
288 inline Context* context() {
return this->_fft.context(); }
293 make_chirp(
const int3& shape,
const int3& fft_shape);
297 template <
class T,
int N>
301 static_assert(0 < N && N <= 3,
"bad N");
304 typedef blitz::TinyVector<int,N> shape_type;
307 typedef blitz::TinyVector<T,N> vec;
308 typedef blitz::RectDomain<N> domain;
309 typedef typename T::value_type R;
312 shape_type _shape{0};
313 bool _power_of_two =
false;
324 inline const shape_type&
325 shape()
const noexcept {
330 shape(
const shape_type& rhs) {
332 if (all(this->_shape == rhs)) {
336 shape_type new_shape;
337 if (all(blitz::is_power_of_two(rhs))) {
339 _power_of_two =
true;
341 new_shape = next_power_of_two(rhs*2 - 1);
342 _power_of_two =
false;
344 auto fft_shape{make_shape<N>(new_shape)};
345 _fft.shape(fft_shape);
346 this->make_chirp(make_shape<N>(_shape), fft_shape);
350 forward(clx::buffer x) {
351 this->transform(x, -1);
355 backward(clx::buffer x) {
356 this->transform(x, 1);
360 transform(clx::buffer x,
int dir) {
361 using blitz::product;
363 _fft.enqueue(x, dir, 1);
365 this->enqueue(x, dir, 1);
375 #endif // vim:filetype=cpp