Virtual Testbed
Ship dynamics simulator for extreme conditions
opencl/fourier_transform.cc
1 #include <atomic>
2 #include <cassert>
3 #include <chrono>
4 #include <cstdio>
5 #include <iomanip>
6 #include <iostream>
7 #include <ostream>
8 #include <sstream>
9 
10 #include <openclx/compiler>
11 
12 #include <vtestbed/config/openmp.hh>
13 #include <vtestbed/opencl/fourier_transform.hh>
14 #include <vtestbed/opencl/pipeline.hh>
15 
16 namespace {
17 
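 // Debug helper, compiled only with VTB_DEBUG_CHIRP_Z: copies a device buffer back to the host and prints it (note that it relies on a pipeline object named ppl being visible at this point).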
18  #if defined(VTB_DEBUG_CHIRP_Z)
19  template <class T, int N>
20  void
21  dump(vtb::opencl::Buffer<T> d_x, const blitz::TinyVector<int,N>& shape,
22  const char* name) {
23  blitz::Array<T,N> x(shape);
24  ppl.copy(d_x, x);
25  ppl.wait();
26  std::clog << name << '=' << x << std::endl;
27  }
28  #define VTB_DUMP(x, shape, name) ::dump(x,shape,name)
29  #else
30  #define VTB_DUMP(x, shape, name)
31  #endif
32 
33 
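 // OpenCL helper macros shared by every generated kernel: complex arithmetic plus radix-2/4/8/16/32 butterflies and their bit-reversal permutations.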
34  const char* base_kernels =
35  "#ifndef M_PI\n"
36  "#define M_PI 0x1.921fb54442d18p+1\n"
37  "#endif\n"
38  "#define complexMul(a,b) ((float2)(mad(-(a).y, (b).y, (a).x * (b).x), mad((a).y, (b).x, (a).x * (b).y)))\n"
39  "#define conj(a) ((float2)((a).x, -(a).y))\n"
40  "#define conjTransp(a) ((float2)(-(a).y, (a).x))\n"
41  "\n"
42  "#define fftKernel2(a,dir) \\\n"
43  "{ \\\n"
44  " float2 c = (a)[0]; \\\n"
45  " (a)[0] = c + (a)[1]; \\\n"
46  " (a)[1] = c - (a)[1]; \\\n"
47  "}\n"
48  "\n"
49  "#define fftKernel2S(d1,d2,dir) \\\n"
50  "{ \\\n"
51  " float2 c = (d1); \\\n"
52  " (d1) = c + (d2); \\\n"
53  " (d2) = c - (d2); \\\n"
54  "}\n"
55  "\n"
56  "#define fftKernel4(a,dir) \\\n"
57  "{ \\\n"
58  " fftKernel2S((a)[0], (a)[2], dir); \\\n"
59  " fftKernel2S((a)[1], (a)[3], dir); \\\n"
60  " fftKernel2S((a)[0], (a)[1], dir); \\\n"
61  " (a)[3] = (float2)(dir)*(conjTransp((a)[3])); \\\n"
62  " fftKernel2S((a)[2], (a)[3], dir); \\\n"
63  " float2 c = (a)[1]; \\\n"
64  " (a)[1] = (a)[2]; \\\n"
65  " (a)[2] = c; \\\n"
66  "}\n"
67  "\n"
68  "#define fftKernel4s(a0,a1,a2,a3,dir) \\\n"
69  "{ \\\n"
70  " fftKernel2S((a0), (a2), dir); \\\n"
71  " fftKernel2S((a1), (a3), dir); \\\n"
72  " fftKernel2S((a0), (a1), dir); \\\n"
73  " (a3) = (float2)(dir)*(conjTransp((a3))); \\\n"
74  " fftKernel2S((a2), (a3), dir); \\\n"
75  " float2 c = (a1); \\\n"
76  " (a1) = (a2); \\\n"
77  " (a2) = c; \\\n"
78  "}\n"
79  "\n"
80  "#define bitreverse8(a) \\\n"
81  "{ \\\n"
82  " float2 c; \\\n"
83  " c = (a)[1]; \\\n"
84  " (a)[1] = (a)[4]; \\\n"
85  " (a)[4] = c; \\\n"
86  " c = (a)[3]; \\\n"
87  " (a)[3] = (a)[6]; \\\n"
88  " (a)[6] = c; \\\n"
89  "}\n"
90  "\n"
91  "#define fftKernel8(a,dir) \\\n"
92  "{ \\\n"
93  " const float2 w1 = (float2)(0x1.6a09e6p-1f, dir*0x1.6a09e6p-1f); \\\n"
94  " const float2 w3 = (float2)(-0x1.6a09e6p-1f, dir*0x1.6a09e6p-1f); \\\n"
95  " float2 c; \\\n"
96  " fftKernel2S((a)[0], (a)[4], dir); \\\n"
97  " fftKernel2S((a)[1], (a)[5], dir); \\\n"
98  " fftKernel2S((a)[2], (a)[6], dir); \\\n"
99  " fftKernel2S((a)[3], (a)[7], dir); \\\n"
100  " (a)[5] = complexMul(w1, (a)[5]); \\\n"
101  " (a)[6] = (float2)(dir)*(conjTransp((a)[6])); \\\n"
102  " (a)[7] = complexMul(w3, (a)[7]); \\\n"
103  " fftKernel2S((a)[0], (a)[2], dir); \\\n"
104  " fftKernel2S((a)[1], (a)[3], dir); \\\n"
105  " fftKernel2S((a)[4], (a)[6], dir); \\\n"
106  " fftKernel2S((a)[5], (a)[7], dir); \\\n"
107  " (a)[3] = (float2)(dir)*(conjTransp((a)[3])); \\\n"
108  " (a)[7] = (float2)(dir)*(conjTransp((a)[7])); \\\n"
109  " fftKernel2S((a)[0], (a)[1], dir); \\\n"
110  " fftKernel2S((a)[2], (a)[3], dir); \\\n"
111  " fftKernel2S((a)[4], (a)[5], dir); \\\n"
112  " fftKernel2S((a)[6], (a)[7], dir); \\\n"
113  " bitreverse8((a)); \\\n"
114  "}\n"
115  "\n"
116  "#define bitreverse4x4(a) \\\n"
117  "{ \\\n"
118  " float2 c; \\\n"
119  " c = (a)[1]; (a)[1] = (a)[4]; (a)[4] = c; \\\n"
120  " c = (a)[2]; (a)[2] = (a)[8]; (a)[8] = c; \\\n"
121  " c = (a)[3]; (a)[3] = (a)[12]; (a)[12] = c; \\\n"
122  " c = (a)[6]; (a)[6] = (a)[9]; (a)[9] = c; \\\n"
123  " c = (a)[7]; (a)[7] = (a)[13]; (a)[13] = c; \\\n"
124  " c = (a)[11]; (a)[11] = (a)[14]; (a)[14] = c; \\\n"
125  "}\n"
126  "\n"
127  "#define fftKernel16(a,dir) \\\n"
128  "{ \\\n"
129  " const float w0 = 0x1.d906bcp-1f; \\\n"
130  " const float w1 = 0x1.87de2ap-2f; \\\n"
131  " const float w2 = 0x1.6a09e6p-1f; \\\n"
132  " fftKernel4s((a)[0], (a)[4], (a)[8], (a)[12], dir); \\\n"
133  " fftKernel4s((a)[1], (a)[5], (a)[9], (a)[13], dir); \\\n"
134  " fftKernel4s((a)[2], (a)[6], (a)[10], (a)[14], dir); \\\n"
135  " fftKernel4s((a)[3], (a)[7], (a)[11], (a)[15], dir); \\\n"
136  " (a)[5] = complexMul((a)[5], (float2)(w0, dir*w1)); \\\n"
137  " (a)[6] = complexMul((a)[6], (float2)(w2, dir*w2)); \\\n"
138  " (a)[7] = complexMul((a)[7], (float2)(w1, dir*w0)); \\\n"
139  " (a)[9] = complexMul((a)[9], (float2)(w2, dir*w2)); \\\n"
140  " (a)[10] = (float2)(dir)*(conjTransp((a)[10])); \\\n"
141  " (a)[11] = complexMul((a)[11], (float2)(-w2, dir*w2)); \\\n"
142  " (a)[13] = complexMul((a)[13], (float2)(w1, dir*w0)); \\\n"
143  " (a)[14] = complexMul((a)[14], (float2)(-w2, dir*w2)); \\\n"
144  " (a)[15] = complexMul((a)[15], (float2)(-w0, dir*-w1)); \\\n"
145  " fftKernel4((a), dir); \\\n"
146  " fftKernel4((a) + 4, dir); \\\n"
147  " fftKernel4((a) + 8, dir); \\\n"
148  " fftKernel4((a) + 12, dir); \\\n"
149  " bitreverse4x4((a)); \\\n"
150  "}\n"
151  "\n"
152  "#define bitreverse32(a) \\\n"
153  "{ \\\n"
154  " float2 c1, c2; \\\n"
155  " c1 = (a)[2]; (a)[2] = (a)[1]; c2 = (a)[4]; (a)[4] = c1; c1 = (a)[8]; (a)[8] = c2; c2 = (a)[16]; (a)[16] = c1; (a)[1] = c2; \\\n"
156  " c1 = (a)[6]; (a)[6] = (a)[3]; c2 = (a)[12]; (a)[12] = c1; c1 = (a)[24]; (a)[24] = c2; c2 = (a)[17]; (a)[17] = c1; (a)[3] = c2; \\\n"
157  " c1 = (a)[10]; (a)[10] = (a)[5]; c2 = (a)[20]; (a)[20] = c1; c1 = (a)[9]; (a)[9] = c2; c2 = (a)[18]; (a)[18] = c1; (a)[5] = c2; \\\n"
158  " c1 = (a)[14]; (a)[14] = (a)[7]; c2 = (a)[28]; (a)[28] = c1; c1 = (a)[25]; (a)[25] = c2; c2 = (a)[19]; (a)[19] = c1; (a)[7] = c2; \\\n"
159  " c1 = (a)[22]; (a)[22] = (a)[11]; c2 = (a)[13]; (a)[13] = c1; c1 = (a)[26]; (a)[26] = c2; c2 = (a)[21]; (a)[21] = c1; (a)[11] = c2; \\\n"
160  " c1 = (a)[30]; (a)[30] = (a)[15]; c2 = (a)[29]; (a)[29] = c1; c1 = (a)[27]; (a)[27] = c2; c2 = (a)[23]; (a)[23] = c1; (a)[15] = c2; \\\n"
161  "}\n"
162  "\n"
163  "#define fftKernel32(a,dir) \\\n"
164  "{ \\\n"
165  " fftKernel2S((a)[0], (a)[16], dir); \\\n"
166  " fftKernel2S((a)[1], (a)[17], dir); \\\n"
167  " fftKernel2S((a)[2], (a)[18], dir); \\\n"
168  " fftKernel2S((a)[3], (a)[19], dir); \\\n"
169  " fftKernel2S((a)[4], (a)[20], dir); \\\n"
170  " fftKernel2S((a)[5], (a)[21], dir); \\\n"
171  " fftKernel2S((a)[6], (a)[22], dir); \\\n"
172  " fftKernel2S((a)[7], (a)[23], dir); \\\n"
173  " fftKernel2S((a)[8], (a)[24], dir); \\\n"
174  " fftKernel2S((a)[9], (a)[25], dir); \\\n"
175  " fftKernel2S((a)[10], (a)[26], dir); \\\n"
176  " fftKernel2S((a)[11], (a)[27], dir); \\\n"
177  " fftKernel2S((a)[12], (a)[28], dir); \\\n"
178  " fftKernel2S((a)[13], (a)[29], dir); \\\n"
179  " fftKernel2S((a)[14], (a)[30], dir); \\\n"
180  " fftKernel2S((a)[15], (a)[31], dir); \\\n"
181  " (a)[17] = complexMul((a)[17], (float2)(0x1.f6297cp-1f, dir*0x1.8f8b84p-3f)); \\\n"
182  " (a)[18] = complexMul((a)[18], (float2)(0x1.d906bcp-1f, dir*0x1.87de2ap-2f)); \\\n"
183  " (a)[19] = complexMul((a)[19], (float2)(0x1.a9b662p-1f, dir*0x1.1c73b4p-1f)); \\\n"
184  " (a)[20] = complexMul((a)[20], (float2)(0x1.6a09e6p-1f, dir*0x1.6a09e6p-1f)); \\\n"
185  " (a)[21] = complexMul((a)[21], (float2)(0x1.1c73b4p-1f, dir*0x1.a9b662p-1f)); \\\n"
186  " (a)[22] = complexMul((a)[22], (float2)(0x1.87de2ap-2f, dir*0x1.d906bcp-1f)); \\\n"
187  " (a)[23] = complexMul((a)[23], (float2)(0x1.8f8b84p-3f, dir*0x1.f6297cp-1f)); \\\n"
188  " (a)[24] = complexMul((a)[24], (float2)(0x0p+0f, dir*0x1p+0f)); \\\n"
189  " (a)[25] = complexMul((a)[25], (float2)(-0x1.8f8b84p-3f, dir*0x1.f6297cp-1f)); \\\n"
190  " (a)[26] = complexMul((a)[26], (float2)(-0x1.87de2ap-2f, dir*0x1.d906bcp-1f)); \\\n"
191  " (a)[27] = complexMul((a)[27], (float2)(-0x1.1c73b4p-1f, dir*0x1.a9b662p-1f)); \\\n"
192  " (a)[28] = complexMul((a)[28], (float2)(-0x1.6a09e6p-1f, dir*0x1.6a09e6p-1f)); \\\n"
193  " (a)[29] = complexMul((a)[29], (float2)(-0x1.a9b662p-1f, dir*0x1.1c73b4p-1f)); \\\n"
194  " (a)[30] = complexMul((a)[30], (float2)(-0x1.d906bcp-1f, dir*0x1.87de2ap-2f)); \\\n"
195  " (a)[31] = complexMul((a)[31], (float2)(-0x1.f6297cp-1f, dir*0x1.8f8b84p-3f)); \\\n"
196  " fftKernel16((a), dir); \\\n"
197  " fftKernel16((a) + 16, dir); \\\n"
198  " bitreverse32((a)); \\\n"
199  "}\n\n";
200 
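 // Greedily factor n into radices, each no larger than max.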
 201  std::vector<int>
 202  radix_array(int n, int max) {
203  std::vector<int> result;
204  max = std::min(n, max);
205  while (n > max) {
206  result.push_back(max);
207  n /= max;
208  }
209  result.push_back(n);
210  return result;
211  }
212 
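 // Hand-picked radix decompositions for power-of-two lengths up to 2048.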
 213  std::vector<int>
 214  radix_array(int n) {
215  std::vector<int> result;
216  switch (n) {
217  case 2: result = {2}; break;
218  case 4: result = {4}; break;
219  case 8: result = {8}; break;
220  case 16: result = {8,2}; break;
221  case 32: result = {8,4}; break;
222  case 64: result = {8,8}; break;
223  case 128: result = {8,4,4}; break;
224  case 256: result = {4,4,4,4}; break;
225  case 512: result = {8,8,8}; break;
226  case 1024: result = {16,16,4}; break;
227  case 2048: result = {8,8,8,4}; break;
228  default: throw std::runtime_error{"unable to generate radix array"};
229  }
230  return result;
231  }
232 
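 // Emit a statement that loads global element gIndex into register a[aIndex], for interleaved or split complex data (formattedStore below is the symmetric store).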
233  void
234  formattedLoad(
235  std::ostream& out,
236  int aIndex,
237  int gIndex,
238  vtb::opencl::Fourier_transform_format dataFormat
239  ) {
240  using vtb::opencl::Fourier_transform_format;
241  if (dataFormat == Fourier_transform_format::Interleaved_complex) {
242  out << " a[" << (aIndex) << "] = in[" << (gIndex) << "];\n";
243  } else {
244  out << " a[" << (aIndex) << "].x = in_real[" << (gIndex) << "];\n";
245  out << " a[" << (aIndex) << "].y = in_imag[" << (gIndex) << "];\n";
246  }
247  }
248 
249  void
250  formattedStore(
251  std::ostream& out,
252  int aIndex,
253  int gIndex,
254  vtb::opencl::Fourier_transform_format dataFormat
255  ) {
256  using vtb::opencl::Fourier_transform_format;
257  if (dataFormat == Fourier_transform_format::Interleaved_complex) {
258  out << " out[" << (gIndex) << "] = a[" << (aIndex) << "];\n";
259  } else {
260  out << " out_real[" << (gIndex) << "] = a[" << (aIndex) << "].x;\n";
261  out << " out_imag[" << (gIndex) << "] = a[" << (aIndex) << "].y;\n";
262  }
263  }
264 
265  void
266  insertHeader(
267  std::ostream& out,
268  std::string kernelName,
269  vtb::opencl::Fourier_transform_format dataFormat
270  ) {
271  using vtb::opencl::Fourier_transform_format;
272  if (dataFormat == Fourier_transform_format::Split_complex) {
273  out << "__kernel void " + kernelName
274  << "(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n";
275  } else {
276  out << "__kernel void " + kernelName
277  << "(__global float2 *in, __global float2 *out, int dir, int S)\n";
278  }
279  }
280 
281  void
282  insertVariables(std::ostream& out, int maxRadix) {
283  out << " int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n";
284  out << " int s, ii, jj, offset;\n";
285  out << " float2 w;\n";
286  out << " float ang, angf, ang1;\n";
287  out << " __local float *lMemStore, *lMemLoad;\n";
288  out << " float2 a[" << maxRadix << "];\n";
289  out << " int lId = get_local_id( 0 );\n";
290  out << " int groupId = get_group_id( 0 );\n";
291  }
292 
293  void
294  insertfftKernel(std::ostream& out, int Nr, int numIter) {
295  for (int i=0; i<numIter; ++i) {
296  out << " fftKernel" << (Nr) << "(a+" << (i*Nr) << ", dir);\n";
297  }
298  }
299 
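 // Emit the twiddle-factor multiplications applied between radix stages.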
300  void
301  insertTwiddleKernel(
302  std::ostream& out,
303  int Nr,
304  int numIter,
305  int Nprev,
306  int len,
307  int numWorkItemsPerXForm
308  ) {
309  int logNPrev = (int)std::log2(Nprev);
310  for (int z=0; z<numIter; ++z) {
311  if (z == 0) {
312  if (Nprev > 1) {
313  out << " angf = (float) (ii >> " << (logNPrev) << ");\n";
314  } else {
315  out << " angf = (float) ii;\n";
316  }
317  } else {
318  if (Nprev > 1) {
319  out << " angf = (float) ((" << (z*numWorkItemsPerXForm) << " + ii) >>" << (logNPrev) << ");\n";
320  } else {
321  out << " angf = (float) (" << (z*numWorkItemsPerXForm) << " + ii);\n";
322  }
323  }
324  for (int k=1; k<Nr; ++k) {
325  int ind = z*Nr + k;
326  //float fac = (float) (2.0 * M_PI * (double) k / (double) len);
327  out << " ang = dir * ( 2.0f * M_PI * " << (k) << ".0f / " << (len) << ".0f )" << " * angf;\n";
328  out << " w = (float2)(native_cos(ang), native_sin(ang));\n";
329  out << " a[" << (ind) << "] = complexMul(a[" << (ind) << "], w);\n";
330  }
331  }
332  }
333 
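 // Emit the load phase: global memory into registers, transposing through local memory when the access pattern would not coalesce; returns the local memory size required.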
334  int
335  insertGlobalLoadsAndTranspose(
336  std::ostream& out,
337  int N,
338  int numWorkItemsPerXForm,
339  int numXFormsPerWG,
340  int R0,
341  int mem_coalesce_width,
342  vtb::opencl::Fourier_transform_format dataFormat
343  ) {
344  using vtb::opencl::Fourier_transform_format;
345  int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm);
346  int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
347  int lMemSize = 0;
348  if (numXFormsPerWG > 1) {
349  out << " s = S & " << (numXFormsPerWG-1) << ";\n";
350  }
351  if (numWorkItemsPerXForm >= mem_coalesce_width) {
352  if (numXFormsPerWG > 1) {
353  out << " ii = lId & " << (numWorkItemsPerXForm-1) << ";\n";
354  out << " jj = lId >> " << log2NumWorkItemsPerXForm << ";\n";
355  out << " if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n";
356  out << " offset = mad24( mad24(groupId, "
357  << numXFormsPerWG
358  << ", jj), " << N
359  << ", ii );\n";
360  if (dataFormat == Fourier_transform_format::Interleaved_complex) {
361  out << " in += offset;\n";
362  out << " out += offset;\n";
363  } else {
364  out << " in_real += offset;\n";
365  out << " in_imag += offset;\n";
366  out << " out_real += offset;\n";
367  out << " out_imag += offset;\n";
368  }
369  for (int i=0; i<R0; ++i) {
370  formattedLoad(out, i, i*numWorkItemsPerXForm, dataFormat);
371  }
372  out << " }\n";
373  } else {
374  out << " ii = lId;\n";
375  out << " jj = 0;\n";
376  out << " offset = mad24(groupId, " << N << ", ii);\n";
377  if (dataFormat == Fourier_transform_format::Interleaved_complex) {
378  out << " in += offset;\n";
379  out << " out += offset;\n";
380  } else {
381  out << " in_real += offset;\n";
382  out << " in_imag += offset;\n";
383  out << " out_real += offset;\n";
384  out << " out_imag += offset;\n";
385  }
386  for (int i=0; i<R0; ++i) {
387  formattedLoad(out, i, i*numWorkItemsPerXForm, dataFormat);
388  }
389  }
390  } else if (N >= mem_coalesce_width) {
391  int numInnerIter = N / mem_coalesce_width;
392  int numOuterIter = numXFormsPerWG / (groupSize / mem_coalesce_width);
393 
394  out << " ii = lId & " << (mem_coalesce_width - 1) << ";\n";
395  out << " jj = lId >> " << ((int)log2(mem_coalesce_width)) << ";\n";
396  out << " lMemStore = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) << ", ii );\n";
397  out << " offset = mad24( groupId, " << (numXFormsPerWG) << ", jj);\n";
398  out << " offset = mad24( offset, " << (N) << ", ii );\n";
399  if (dataFormat == Fourier_transform_format::Interleaved_complex) {
400  out << " in += offset;\n";
401  out << " out += offset;\n";
402  } else {
403  out << " in_real += offset;\n";
404  out << " in_imag += offset;\n";
405  out << " out_real += offset;\n";
406  out << " out_imag += offset;\n";
407  }
408  out << "if((groupId == get_num_groups(0)-1) && s) {\n";
409  for(int i=0; i<numOuterIter; ++i) {
410  out << " if( jj < s ) {\n";
411  for (int j=0; j<numInnerIter; ++j) {
412  formattedLoad(
413  out,
414  i * numInnerIter + j,
415  j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N,
416  dataFormat
417  );
418  }
419  out << " }\n";
420  if (i != numOuterIter-1) {
421  out << " jj += " << (groupSize / mem_coalesce_width) << ";\n";
422  }
423  }
424  out << "}\n ";
425  out << "else {\n";
426  for (int i = 0; i < numOuterIter; i++ ) {
427  for (int j = 0; j < numInnerIter; j++ ) {
428  formattedLoad(
429  out,
430  i * numInnerIter + j,
431  j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N,
432  dataFormat
433  );
434  }
435  }
436  out << "}\n";
437  out << " ii = lId & " << (numWorkItemsPerXForm - 1) << ";\n";
438  out << " jj = lId >> " << (log2NumWorkItemsPerXForm) << ";\n";
439  out << " lMemLoad = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) << ", ii);\n";
440  for (int i=0; i<numOuterIter; ++i) {
441  for (int j=0; j<numInnerIter; ++j) {
442  out << " lMemStore["
443  << (j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm ))
444  << "] = a["
445  << (i * numInnerIter + j)
446  << "].x;\n";
447  }
448  }
449  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
450  for (int i=0; i<R0; ++i) {
451  out << " a["
452  << (i)
453  << "].x = lMemLoad["
454  << (i * numWorkItemsPerXForm)
455  << "];\n";
456  }
457  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
458  for (int i=0; i<numOuterIter; ++i) {
459  for (int j=0; j<numInnerIter; ++j) {
460  out << " lMemStore["
461  << (j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm ))
462  << "] = a["
463  << (i * numInnerIter + j)
464  << "].y;\n";
465  }
466  }
467  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
468  for (int i=0; i<R0; ++i) {
469  out << " a["
470  << (i)
471  << "].y = lMemLoad["
472  << (i * numWorkItemsPerXForm) << "];\n";
473  }
474  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
475  lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
476  } else {
477  out << " offset = mad24( groupId, " << (N * numXFormsPerWG) << ", lId );\n";
478  if (dataFormat == Fourier_transform_format::Interleaved_complex) {
479  out << " in += offset;\n";
480  out << " out += offset;\n";
481  } else {
482  out << " in_real += offset;\n";
483  out << " in_imag += offset;\n";
484  out << " out_real += offset;\n";
485  out << " out_imag += offset;\n";
486  }
487  out << " ii = lId & " << (N-1) << ";\n";
488  out << " jj = lId >> " << ((int)log2(N)) << ";\n";
489  out << " lMemStore = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) << ", ii );\n";
490  out << "if((groupId == get_num_groups(0)-1) && s) {\n";
491  for (int i=0; i<R0; ++i) {
492  out << " if(jj < s )\n";
493  formattedLoad(out, i, i*groupSize, dataFormat);
494  if (i != R0-1) {
495  out << " jj += " << (groupSize / N) << ";\n";
496  }
497  }
498  out << "}\n";
499  out << "else {\n";
500  for (int i=0; i<R0; ++i) {
501  formattedLoad(out, i, i*groupSize, dataFormat);
502  }
503  out << "}\n";
504  if (numWorkItemsPerXForm > 1) {
505  out << " ii = lId & " << (numWorkItemsPerXForm - 1) << ";\n";
506  out << " jj = lId >> " << (log2NumWorkItemsPerXForm) << ";\n";
507  out << " lMemLoad = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) << ", ii );\n";
508  } else {
509  out << " ii = 0;\n";
510  out << " jj = lId;\n";
511  out << " lMemLoad = sMem + mul24( jj, " << (N + numWorkItemsPerXForm) << ");\n";
512  }
513  for (int i=0; i<R0; ++i) {
514  out << " lMemStore[" << (i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) << "] = a[" << (i) << "].x;\n";
515  }
516  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
517 
518  for (int i=0; i<R0; ++i) {
519  out << " a[" << (i) << "].x = lMemLoad[" << (i * numWorkItemsPerXForm) << "];\n";
520  }
521  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
522  for (int i=0; i<R0; ++i) {
523  out << " lMemStore[" << (i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) << "] = a[" << (i) << "].y;\n";
524  }
525  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
526  for (int i=0; i<R0; ++i) {
527  out << " a[" << (i) << "].y = lMemLoad[" << (i * numWorkItemsPerXForm) << "];\n";
528  }
529  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
530  lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
531  }
532  return lMemSize;
533  }
534 
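 // Emit the store phase: registers back to global memory, again transposing through local memory when needed; returns the local memory size required.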
535  int
536  insertGlobalStoresAndTranspose(
537  std::ostream& out,
538  int N,
539  int maxRadix,
540  int Nr,
541  int numWorkItemsPerXForm,
542  int numXFormsPerWG,
543  int mem_coalesce_width,
544  vtb::opencl::Fourier_transform_format dataFormat
545  ) {
546  int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
547  int i, j, k, ind;
548  int lMemSize = 0;
549  int numIter = maxRadix / Nr;
550  if( numWorkItemsPerXForm >= mem_coalesce_width )
551  {
552  if(numXFormsPerWG > 1)
553  {
554  out << " if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n";
555  }
556  for(i = 0; i < maxRadix; i++)
557  {
558  j = i % numIter;
559  k = i / numIter;
560  ind = j * Nr + k;
561  formattedStore(out, ind, i*numWorkItemsPerXForm, dataFormat);
562  }
563  if(numXFormsPerWG > 1)
564  out << " }\n";
565  }
566  else if( N >= mem_coalesce_width )
567  {
568  int numInnerIter = N / mem_coalesce_width;
569  int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
570  out << " lMemLoad = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) << ", ii );\n";
571  out << " ii = lId & " << (mem_coalesce_width - 1) << ";\n";
572  out << " jj = lId >> " << ((int)log2(mem_coalesce_width)) << ";\n";
573  out << " lMemStore = sMem + mad24( jj," << (N + numWorkItemsPerXForm) << ", ii );\n";
574  for( i = 0; i < maxRadix; i++ )
575  {
576  j = i % numIter;
577  k = i / numIter;
578  ind = j * Nr + k;
579  out << " lMemLoad[" << (i*numWorkItemsPerXForm) << "] = a[" << (ind) << "].x;\n";
580  }
581  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
582  for( i = 0; i < numOuterIter; i++ )
583  for( j = 0; j < numInnerIter; j++ )
584  out << " a[" << (i*numInnerIter + j) << "].x = lMemStore[" << (j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) << "];\n";
585  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
586  for( i = 0; i < maxRadix; i++ )
587  {
588  j = i % numIter;
589  k = i / numIter;
590  ind = j * Nr + k;
591  out << " lMemLoad[" << (i*numWorkItemsPerXForm) << "] = a[" << (ind) << "].y;\n";
592  }
593  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
594  for( i = 0; i < numOuterIter; i++ )
595  for( j = 0; j < numInnerIter; j++ )
596  out << " a[" << (i*numInnerIter + j) << "].y = lMemStore[" << (j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) << "];\n";
597  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
598  out << "if((groupId == get_num_groups(0)-1) && s) {\n";
599  for(i = 0; i < numOuterIter; i++ )
600  {
601  out << " if( jj < s ) {\n";
602  for (int j = 0; j < numInnerIter; j++ )
603  formattedStore(out, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
604  out << " }\n";
605  if(i != numOuterIter - 1)
606  out << " jj += " << (groupSize / mem_coalesce_width) << ";\n";
607  }
608  out << "}\n";
609  out << "else {\n";
610  for(i = 0; i < numOuterIter; i++ )
611  {
612  for (int j = 0; j < numInnerIter; j++ )
613  formattedStore(out, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
614  }
615  out << "}\n";
616  lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
617  }
618  else
619  {
620  out << " lMemLoad = sMem + mad24( jj," << (N + numWorkItemsPerXForm) << ", ii );\n";
621  out << " ii = lId & " << (N - 1) << ";\n";
622  out << " jj = lId >> " << ((int) log2(N)) << ";\n";
623  out << " lMemStore = sMem + mad24( jj," << (N + numWorkItemsPerXForm) << ", ii );\n";
624  for( i = 0; i < maxRadix; i++ )
625  {
626  j = i % numIter;
627  k = i / numIter;
628  ind = j * Nr + k;
629  out << " lMemLoad[" << (i*numWorkItemsPerXForm) << "] = a[" << (ind) << "].x;\n";
630  }
631  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
632  for( i = 0; i < maxRadix; i++ )
633  out << " a[" << (i) << "].x = lMemStore[" << (i*( groupSize / N )*( N + numWorkItemsPerXForm )) << "];\n";
634  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
635  for( i = 0; i < maxRadix; i++ )
636  {
637  j = i % numIter;
638  k = i / numIter;
639  ind = j * Nr + k;
640  out << " lMemLoad[" << (i*numWorkItemsPerXForm) << "] = a[" << (ind) << "].y;\n";
641  }
642  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
643  for( i = 0; i < maxRadix; i++ )
644  out << " a[" << (i) << "].y = lMemStore[" << (i*( groupSize / N )*( N + numWorkItemsPerXForm )) << "];\n";
645  out << " barrier( CLK_LOCAL_MEM_FENCE );\n";
646  out << "if((groupId == get_num_groups(0)-1) && s) {\n";
647  for( i = 0; i < maxRadix; i++ )
648  {
649  out << " if(jj < s ) {\n";
650  formattedStore(out, i, i*groupSize, dataFormat);
651  out << " }\n";
652  if( i != maxRadix - 1)
653  out << " jj +=" << (groupSize / N) << ";\n";
654  }
655  out << "}\n";
656  out << "else {\n";
657  for( i = 0; i < maxRadix; i++ )
658  {
659  formattedStore(out, i, i*groupSize, dataFormat);
660  }
661  out << "}\n";
662  lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
663  }
664  return lMemSize;
665  }
666 
667  static void
668  insertLocalLoadIndexArithmatic(
669  std::ostream& out,
670  int Nprev,
671  int Nr,
672  int numWorkItemsReq,
673  int numWorkItemsPerXForm,
674  int numXFormsPerWG,
675  int offset,
676  int midPad
677  )
678  {
679  int Ncurr = Nprev * Nr;
680  int logNcurr = (int)std::log2(Ncurr);
681  int logNprev = (int)std::log2(Nprev);
682  int incr = (numWorkItemsReq + offset) * Nr + midPad;
683  if (Ncurr < numWorkItemsPerXForm) {
684  if (Nprev == 1) {
685  out << " j = ii & " << (Ncurr - 1) << ";\n";
686  } else {
687  out << " j = (ii & " << (Ncurr - 1) << ") >> " << logNprev << ";\n";
688  }
689  if (Nprev == 1) {
690  out << " i = ii >> " << logNcurr << ";\n";
691  } else {
692  out << " i = mad24(ii >> " << logNcurr << ", "
693  << Nprev << ", ii & " << (Nprev-1) << ");\n";
694  }
695  } else {
696  if (Nprev == 1) {
697  out << " j = ii;\n";
698  } else {
699  out << " j = ii >> " << logNprev << ";\n";
700  }
701  if (Nprev == 1) {
702  out << " i = 0;\n";
703  } else {
704  out << " i = ii & " << (Nprev-1) << ";\n";
705  }
706  }
707  if (numXFormsPerWG > 1) {
708  out << " i = mad24(jj, " << incr << ", i);\n";
709  }
710  out << " lMemLoad = sMem + mad24(j, "
711  << (numWorkItemsReq + offset)
712  << ", i);\n";
713  }
714 
715  void
716  insertLocalStoreIndexArithmatic(
717  std::ostream& out,
718  int numWorkItemsReq,
719  int numXFormsPerWG,
720  int Nr,
721  int offset,
722  int midPad
723  ) {
724  if (numXFormsPerWG == 1) {
725  out << " lMemStore = sMem + ii;\n";
726  } else {
727  out << " lMemStore = sMem + mad24(jj, "
728  << ((numWorkItemsReq + offset)*Nr + midPad)
729  <<", ii);\n";
730  }
731  }
732 
733  void
734  insertLocalStores(
735  std::ostream& out,
736  int numIter,
737  int Nr,
738  int numWorkItemsPerXForm,
739  int numWorkItemsReq,
740  int offset,
741  const char* comp
742  ) {
743  for (int z=0; z<numIter; ++z) {
744  for (int k=0; k<Nr; ++k) {
745  int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm;
746  out << " lMemStore[" << (index)
747  << "] = a[" << (z*Nr + k) << "]." << comp << ";\n";
748  }
749  }
750  out << " barrier(CLK_LOCAL_MEM_FENCE);\n";
751  }
752 
753  void
754  insertLocalLoads(
755  std::ostream& out,
756  int n,
757  int Nr,
758  int Nrn,
759  int Nprev,
760  int Ncurr,
761  int numWorkItemsPerXForm,
762  int numWorkItemsReq,
763  int offset,
764  const char* comp
765  ) {
766  int numWorkItemsReqN = n / Nrn;
767  int interBlockHNum = std::max( Nprev / numWorkItemsPerXForm, 1 );
768  int interBlockHStride = numWorkItemsPerXForm;
769  int vertWidth = std::max(numWorkItemsPerXForm / Nprev, 1);
770  vertWidth = std::min( vertWidth, Nr);
771  int vertNum = Nr / vertWidth;
772  int vertStride = ( n / Nr + offset ) * vertWidth;
773  int iter = std::max( numWorkItemsReqN / numWorkItemsPerXForm, 1);
774  int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1;
775  intraBlockHStride *= Nprev;
776  int stride = numWorkItemsReq / Nrn;
777  int i;
778  for(i = 0; i < iter; i++) {
779  int ii = i / (interBlockHNum * vertNum);
780  int zz = i % (interBlockHNum * vertNum);
781  int jj = zz % interBlockHNum;
782  int kk = zz / interBlockHNum;
783  int z;
784  for(z = 0; z < Nrn; z++) {
785  int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride;
786  out << " a[" << (i*Nrn + z) << "]."
787  << comp << " = lMemLoad[" << (st) << "];\n";
788  }
789  }
790  out << " barrier(CLK_LOCAL_MEM_FENCE);\n";
791  }
792 
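 // Pick local-memory padding (offset and midPad) that avoids bank conflicts; returns the padded local memory size.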
793  int
794  getPadding(
795  int numWorkItemsPerXForm,
796  int Nprev,
797  int numWorkItemsReq,
798  int numXFormsPerWG,
799  int Nr,
800  int numBanks,
801  int* offset,
802  int* midPad
803  ) {
804  if((numWorkItemsPerXForm <= Nprev) || (Nprev >= numBanks))
805  *offset = 0;
806  else {
807  int numRowsReq = ((numWorkItemsPerXForm < numBanks) ? numWorkItemsPerXForm : numBanks) / Nprev;
808  int numColsReq = 1;
809  if(numRowsReq > Nr)
810  numColsReq = numRowsReq / Nr;
811  numColsReq = Nprev * numColsReq;
812  *offset = numColsReq;
813  }
814  if(numWorkItemsPerXForm >= numBanks || numXFormsPerWG == 1)
815  *midPad = 0;
816  else {
817  int bankNum = ( (numWorkItemsReq + *offset) * Nr ) & (numBanks - 1);
818  if( bankNum >= numWorkItemsPerXForm )
819  *midPad = 0;
820  else
821  *midPad = numWorkItemsPerXForm - bankNum;
822  }
823  int lMemSize = ( numWorkItemsReq + *offset) * Nr * numXFormsPerWG + *midPad * (numXFormsPerWG - 1);
824  return lMemSize;
825  }
826 
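 // Factor n into radices of at most 128 and split every radix B into R1*R2 with R1 >= R2.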
827  void
828  getGlobalRadixInfo(
829  int n,
830  int *radix,
831  int *R1,
832  int *R2,
833  int *numRadices
834  ) {
835  int baseRadix = std::min(n, 128);
836  int numR = 0;
837  int N = n;
838  while (N > baseRadix) {
839  N /= baseRadix;
840  numR++;
841  }
842  for (int i = 0; i < numR; i++) {
843  radix[i] = baseRadix;
844  }
845  radix[numR] = N;
846  numR++;
847  *numRadices = numR;
848  for (int i = 0; i < numR; i++) {
849  int B = radix[i];
850  if (B <= 8) {
851  R1[i] = B;
852  R2[i] = 1;
853  continue;
854  }
855  int r1 = 2;
856  int r2 = B / r1;
857  while (r2 > r1) {
858  r1 *= 2;
859  r2 = B / r1;
860  }
861  R1[i] = r1;
862  R2[i] = r2;
863  }
864  }
865 
866  inline void
867  trim_right(std::string& rhs) {
868  while (!rhs.empty() && rhs.back() <= ' ') { rhs.pop_back(); }
869  }
870 
871  inline std::string
872  trim(std::string rhs) {
873  trim_right(rhs);
874  return rhs;
875  }
876 
877  template <class T, int N>
878  inline size_t
879  num_bytes(const blitz::TinyVector<int,N>& n) {
880  return blitz::product(n)*sizeof(T);
881  }
882 
883 }
884 
885 std::string
886 vtb::opencl::Fourier_transform_base::kernel_name(const char* prefix) {
887  const char sep = '_';
888  std::stringstream name;
889  name << prefix << sep
890  << this->_shape(0) << sep
891  << this->_shape(1) << sep
892  << this->_shape(2) << sep
893  << ++this->_kindex;
894  return name.str();
895 }
896 
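// Regenerate the whole OpenCL source: the shared macros plus one set of kernels per axis; record whether any pass needs a temporary buffer.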
897 void
898 vtb::opencl::Fourier_transform_base::generate_source_code() {
899  this->_src = base_kernels;
900  this->_kindex = 0;
901  this->_kernels.clear();
902  for (int i=0; i<3; ++i) {
903  this->generate_fft(i);
904  }
905  for (const auto& kernel : this->_kernels) {
906  if (!kernel.in_place_possible) {
907  this->temp_buffer_needed = true;
908  break;
909  }
910  }
911 }
912 
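// Choose between the local-memory and the global-memory FFT for one axis, depending on its length.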
913 void
914 vtb::opencl::Fourier_transform_base::generate_fft(int axis) {
915  if (axis == 0) {
916  int nx = this->_shape(0);
917  if (nx > this->max_localmem_fft_size) {
918  generate_fft_global(nx, 1, axis, 1);
919  } else if (nx > 1) {
920  std::vector<int> radices{radix_array(nx)};
921  if (nx/radices[0] <= this->_maxworkgroupsize) {
922  generate_fft_local();
923  } else {
924  radices = radix_array(nx, this->_maxradix);
925  if (nx/radices[0] <= this->_maxworkgroupsize) {
926  generate_fft_local();
927  } else {
928  generate_fft_global(nx, 1, axis, 1);
929  }
930  }
931  }
932  }
933  if (axis == 1) {
934  int ny = this->_shape(1);
935  if (ny > 1) {
936  int stride = this->_shape(0);
937  generate_fft_global(ny, stride, axis, 1);
938  }
939  }
940  if (axis == 2) {
941  int nz = this->_shape(2);
942  if (nz > 1) {
943  int stride = _shape(0)*_shape(1);
944  generate_fft_global(nz, stride, axis, 1);
945  }
946  }
947 }
948 
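// Generate a single kernel that transforms axis 0 entirely inside one work-group, staging data through registers and local memory.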
949 void
950 vtb::opencl::Fourier_transform_base::generate_fft_local() {
951  int n = this->_shape(0);
952  if (n > this->_maxworkgroupsize*this->_maxradix) {
953  throw std::invalid_argument{"signal length too big for local mem fft"};
954  }
955  std::vector<int> radices{radix_array(n)};
956  if (n/radices[0] > this->_maxworkgroupsize) {
957  radices = radix_array(n, this->_maxradix);
958  }
959  if (radices.front() > this->_maxradix) {
960  throw std::invalid_argument{"bad radix array"};
961  }
962  if (n/radices.front() > this->_maxworkgroupsize) {
963  throw std::invalid_argument{
964  "required work items per xform greater than "
965  "maximum work items allowed per work group for local mem fft"
966  };
967  }
968  int numRadix = radices.size();
969  {
970  int prod = 1;
971  for (int i=0; i<numRadix; ++i) {
972  prod *= radices[i];
973  }
974  if (prod != n) {
975  throw std::invalid_argument{"bad radices"};
976  }
977  }
978  int offset, midPad;
979  std::stringstream out;
980  Kernel_info kernel{};
981  kernel.name = this->kernel_name("fft_local");
982  kernel.axis = 0;
983  kernel.in_place_possible = true;
984  int numWorkItemsPerXForm = n / radices[0];
985  int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm;
986  assert(numWorkItemsPerWG <= this->_maxworkgroupsize);
987  int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm;
988  kernel.num_workgroups = 1;
989  kernel.num_xforms_per_workgroup = numXFormsPerWG;
990  kernel.num_workitems_per_workgroup = numWorkItemsPerWG;
991  int maxRadix = radices[0];
992  int lMemSize = 0;
993  insertVariables(out, maxRadix);
994  lMemSize = insertGlobalLoadsAndTranspose(
995  out,
996  n,
997  numWorkItemsPerXForm,
998  numXFormsPerWG,
999  maxRadix,
1000  this->min_mem_coalesce_width,
1001  this->_format
1002  );
1003  kernel.lmem_size = (lMemSize > kernel.lmem_size) ? lMemSize : kernel.lmem_size;
1004  int Nprev = 1;
1005  int len = n;
1006  for(int r = 0; r<numRadix; ++r) {
1007  int numIter = radices[0] / radices[r];
1008  int numWorkItemsReq = n / radices[r];
1009  int Ncurr = Nprev * radices[r];
1010  insertfftKernel(out, radices[r], numIter);
1011  if (r < (numRadix-1)) {
1012  insertTwiddleKernel(
1013  out,
1014  radices[r],
1015  numIter,
1016  Nprev,
1017  len,
1018  numWorkItemsPerXForm
1019  );
1020  lMemSize = getPadding(
1021  numWorkItemsPerXForm,
1022  Nprev,
1023  numWorkItemsReq,
1024  numXFormsPerWG,
1025  radices[r],
1026  this->num_local_mem_banks,
1027  &offset,
1028  &midPad
1029  );
1030  kernel.lmem_size = (lMemSize > kernel.lmem_size)
1031  ? lMemSize
1032  : kernel.lmem_size;
1033  insertLocalStoreIndexArithmatic(
1034  out,
1035  numWorkItemsReq,
1036  numXFormsPerWG,
1037  radices[r],
1038  offset,
1039  midPad
1040  );
1041  insertLocalLoadIndexArithmatic(
1042  out,
1043  Nprev,
1044  radices[r],
1045  numWorkItemsReq,
1046  numWorkItemsPerXForm,
1047  numXFormsPerWG,
1048  offset,
1049  midPad
1050  );
1051  insertLocalStores(
1052  out,
1053  numIter,
1054  radices[r],
1055  numWorkItemsPerXForm,
1056  numWorkItemsReq,
1057  offset,
1058  "x"
1059  );
1060  insertLocalLoads(
1061  out,
1062  n,
1063  radices[r],
1064  radices[r+1],
1065  Nprev,
1066  Ncurr,
1067  numWorkItemsPerXForm,
1068  numWorkItemsReq,
1069  offset,
1070  "x"
1071  );
1072  insertLocalStores(
1073  out,
1074  numIter,
1075  radices[r],
1076  numWorkItemsPerXForm,
1077  numWorkItemsReq,
1078  offset,
1079  "y"
1080  );
1081  insertLocalLoads(
1082  out,
1083  n,
1084  radices[r],
1085  radices[r+1],
1086  Nprev,
1087  Ncurr,
1088  numWorkItemsPerXForm,
1089  numWorkItemsReq,
1090  offset,
1091  "y"
1092  );
1093  Nprev = Ncurr;
1094  len = len / radices[r];
1095  }
1096  }
1097  lMemSize = insertGlobalStoresAndTranspose(
1098  out,
1099  n,
1100  maxRadix,
1101  radices[numRadix - 1],
1102  numWorkItemsPerXForm,
1103  numXFormsPerWG,
1104  this->min_mem_coalesce_width,
1105  this->_format
1106  );
1107  kernel.lmem_size = (lMemSize > kernel.lmem_size) ? lMemSize : kernel.lmem_size;
1108  std::stringstream result;
1109  result << this->_src;
1110  insertHeader(result, kernel.name, this->_format);
1111  result << "{\n";
1112  if (kernel.lmem_size) {
1113  result << " __local float sMem[" << kernel.lmem_size << "];\n";
1114  }
1115  result << out.str();
1116  result << "}\n";
1117  this->_src += result.str();
1118  this->_kernels.emplace_back(kernel);
1119 }
1120 
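// Generate a multi-pass FFT for long or strided axes: one kernel per radix pass, each reading from and writing to global memory.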
1121 void
1122 vtb::opencl::Fourier_transform_base::generate_fft_global(
1123  int n,
1124  int BS,
1125  int axis,
1126  int vertBS
1127 ) {
1128  int k, t;
1129  int radixArr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
1130  int R1Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
1131  int R2Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
1132  int radix, R1, R2;
1133  int numRadices;
1134  int maxThreadsPerBlock = this->_maxworkgroupsize;
1135  int batchSize = this->min_mem_coalesce_width;
1136  int vertical = (axis == 0) ? 0 : 1;
1137  getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices);
1138  int numPasses = numRadices;
1139  int N = n;
1140  int m = (int)log2(n);
1141  int Rinit = vertical ? BS : 1;
1142  batchSize = vertical ? std::min(BS, batchSize) : batchSize;
1143  std::stringstream out;
1144  for (int passNum=0; passNum<numPasses; ++passNum) {
1145  out.str("");
1146  radix = radixArr[passNum];
1147  R1 = R1Arr[passNum];
1148  R2 = R2Arr[passNum];
1149  int strideI = Rinit;
1150  for (int i=0; i<numPasses; ++i) {
1151  if (i != passNum){
1152  strideI *= radixArr[i];
1153  }
1154  }
1155  int strideO = Rinit;
1156  for (int i=0; i<passNum; ++i) {
1157  strideO *= radixArr[i];
1158  }
1159  int threadsPerXForm = R2;
1160  batchSize = R2 == 1 ? this->_maxworkgroupsize : batchSize;
1161  batchSize = std::min(batchSize, strideI);
1162  int threadsPerBlock = batchSize * threadsPerXForm;
1163  threadsPerBlock = std::min(threadsPerBlock, maxThreadsPerBlock);
1164  batchSize = threadsPerBlock / threadsPerXForm;
1165  assert(R2 <= R1);
1166  assert(R1*R2 == radix);
1167  assert(R1 <= this->_maxradix);
1168  assert(threadsPerBlock <= maxThreadsPerBlock);
1169  int numIter = R1 / R2;
1170  int gInInc = threadsPerBlock / batchSize;
1171  int lgStrideO = (int)log2(strideO);
1172  int numBlocksPerXForm = strideI / batchSize;
1173  int numBlocks = numBlocksPerXForm;
1174  if (!vertical) {
1175  numBlocks *= BS;
1176  } else {
1177  numBlocks *= vertBS;
1178  }
1179  Kernel_info kernel{};
1180  kernel.name = this->kernel_name("fft_global");
1181  if (R2 == 1) {
1182  kernel.lmem_size = 0;
1183  } else {
1184  if (strideO == 1) {
1185  kernel.lmem_size = (radix + 1)*batchSize;
1186  } else {
1187  kernel.lmem_size = threadsPerBlock*R1;
1188  }
1189  }
1190  kernel.num_workgroups = numBlocks;
1191  kernel.num_xforms_per_workgroup = 1;
1192  kernel.num_workitems_per_workgroup = threadsPerBlock;
1193  kernel.axis = axis;
1194  if((passNum == (numPasses - 1)) && (numPasses & 1)) {
1195  kernel.in_place_possible = true;
1196  } else {
1197  kernel.in_place_possible = false;
1198  }
1199  insertVariables(out, R1);
1200  if (vertical) {
1201  out << "xNum = groupId >> " << ((int)log2(numBlocksPerXForm)) << ";\n";
1202  out << "groupId = groupId & " << (numBlocksPerXForm - 1) << ";\n";
1203  out << "indexIn = mad24(groupId, " << (batchSize) << ", xNum << " << ((int)log2(n*BS)) << ");\n";
1204  out << "tid = mul24(groupId, " << (batchSize) << ");\n";
1205  out << "i = tid >> " << (lgStrideO) << ";\n";
1206  out << "j = tid & " << (strideO - 1) << ";\n";
1207  int stride = radix*Rinit;
1208  for (int i=0; i<passNum; ++i) {
1209  stride *= radixArr[i];
1210  }
1211  out << "indexOut = mad24(i, " << (stride) << ", j + " << "(xNum << " << ((int) log2(n*BS)) << "));\n";
1212  out << "bNum = groupId;\n";
1213  } else {
1214  int lgNumBlocksPerXForm = (int)log2(numBlocksPerXForm);
1215  out << "bNum = groupId & " << (numBlocksPerXForm - 1) << ";\n";
1216  out << "xNum = groupId >> " << (lgNumBlocksPerXForm) << ";\n";
1217  out << "indexIn = mul24(bNum, " << (batchSize) << ");\n";
1218  out << "tid = indexIn;\n";
1219  out << "i = tid >> " << (lgStrideO) << ";\n";
1220  out << "j = tid & " << (strideO - 1) << ";\n";
1221  int stride = radix*Rinit;
1222  for (int i=0; i<passNum; ++i) {
1223  stride *= radixArr[i];
1224  }
1225  out << "indexOut = mad24(i, " << (stride) << ", j);\n";
1226  out << "indexIn += (xNum << " << (m) << ");\n";
1227  out << "indexOut += (xNum << " << (m) << ");\n";
1228  }
1229  // Load Data
1230  int lgBatchSize = (int)log2(batchSize);
1231  out << "tid = lId;\n";
1232  out << "i = tid & " << (batchSize - 1) << ";\n";
1233  out << "j = tid >> " << (lgBatchSize) << ";\n";
1234  out << "indexIn += mad24(j, " << (strideI) << ", i);\n";
1235  if (this->_format == Fourier_transform_format::Split_complex) {
1236  out << "in_real += indexIn;\n";
1237  out << "in_imag += indexIn;\n";
1238  for (int j=0; j<R1; ++j)
1239  out << "a[" << (j) << "].x = in_real[" << (j*gInInc*strideI) << "];\n";
1240  for (int j=0; j<R1; ++j)
1241  out << "a[" << (j) << "].y = in_imag[" << (j*gInInc*strideI) << "];\n";
1242  } else {
1243  out << "in += indexIn;\n";
1244  for (int j=0; j<R1; ++j) {
1245  out << "a[" << (j) << "] = in[" << (j*gInInc*strideI) << "];\n";
1246  }
1247  }
1248  out << "fftKernel" << (R1) << "(a, dir);\n";
1249  if (R2 > 1) {
1250  // twiddle
1251  for (int k = 1; k < R1; k++) {
1252  out << "ang = dir*(2.0f*M_PI*" << (k) << "/" << (radix) << ")*j;\n";
1253  out << "w = (float2)(native_cos(ang), native_sin(ang));\n";
1254  out << "a[" << (k) << "] = complexMul(a[" << (k) << "], w);\n";
1255  }
1256  // shuffle
1257  numIter = R1 / R2;
1258  out << "indexIn = mad24(j, " << (threadsPerBlock*numIter) << ", i);\n";
1259  out << "lMemStore = sMem + tid;\n";
1260  out << "lMemLoad = sMem + indexIn;\n";
1261  for (int k = 0; k < R1; k++) {
1262  out << "lMemStore[" << (k*threadsPerBlock) << "] = a[" << (k) << "].x;\n";
1263  }
1264  out << "barrier(CLK_LOCAL_MEM_FENCE);\n";
1265  for(k = 0; k < numIter; k++)
1266  for(t = 0; t < R2; t++)
1267  out << "a[" << (k*R2+t) << "].x = lMemLoad[" << (t*batchSize + k*threadsPerBlock) << "];\n";
1268  out << "barrier(CLK_LOCAL_MEM_FENCE);\n";
1269  for(k = 0; k < R1; k++)
1270  out << "lMemStore[" << (k*threadsPerBlock) << "] = a[" << (k) << "].y;\n";
1271  out << "barrier(CLK_LOCAL_MEM_FENCE);\n";
1272  for(k = 0; k < numIter; k++)
1273  for(t = 0; t < R2; t++)
1274  out << "a[" << (k*R2+t) << "].y = lMemLoad[" << (t*batchSize + k*threadsPerBlock) << "];\n";
1275  out << "barrier(CLK_LOCAL_MEM_FENCE);\n";
1276  for(int j = 0; j < numIter; j++)
1277  out << "fftKernel" << (R2) << "(a + " << (j*R2) << ", dir);\n";
1278  }
1279  // twiddle
1280  if (passNum < (numPasses - 1)) {
1281  out << "l = ((bNum << " << (lgBatchSize) << ") + i) >> " << (lgStrideO) << ";\n";
1282  out << "k = j << " << ((int)log2(R1/R2)) << ";\n";
1283  out << "ang1 = dir*(2.0f*M_PI/" << (N) << ")*l;\n";
1284  for(t = 0; t < R1; t++)
1285  {
1286  out << "ang = ang1*(k + " << ((t%R2)*R1 + (t/R2)) << ");\n";
1287  out << "w = (float2)(native_cos(ang), native_sin(ang));\n";
1288  out << "a[" << (t) << "] = complexMul(a[" << (t) << "], w);\n";
1289  }
1290  }
1291  // Store Data
1292  if(strideO == 1) {
1293  out << "lMemStore = sMem + mad24(i, " << (radix + 1) << ", j << " << ((int)log2(R1/R2)) << ");\n";
1294  out << "lMemLoad = sMem + mad24(tid >> " << ((int)log2(radix)) << ", " << (radix+1) << ", tid & " << (radix-1) << ");\n";
1295  for(int i = 0; i < R1/R2; i++)
1296  for(int j = 0; j < R2; j++)
1297  out << "lMemStore[ " << (i + j*R1) << "] = a[" << (i*R2+j) << "].x;\n";
1298  out << "barrier(CLK_LOCAL_MEM_FENCE);\n";
1299  if(threadsPerBlock >= radix)
1300  {
1301  for(int i = 0; i < R1; i++)
1302  out << "a[" << (i) << "].x = lMemLoad[" << (i*(radix+1)*(threadsPerBlock/radix)) << "];\n";
1303  }
1304  else
1305  {
1306  int innerIter = radix/threadsPerBlock;
1307  int outerIter = R1/innerIter;
1308  for(int i = 0; i < outerIter; i++)
1309  for(int j = 0; j < innerIter; j++)
1310  out << "a[" << (i*innerIter+j) << "].x = lMemLoad[" << (j*threadsPerBlock + i*(radix+1)) << "];\n";
1311  }
1312  out << "barrier(CLK_LOCAL_MEM_FENCE);\n";
1313  for (int i = 0; i < R1/R2; i++)
1314  for(int j = 0; j < R2; j++)
1315  out << "lMemStore[ " << (i + j*R1) << "] = a[" << (i*R2+j) << "].y;\n";
1316  out << "barrier(CLK_LOCAL_MEM_FENCE);\n";
1317  if (threadsPerBlock >= radix) {
1318  for(int i = 0; i < R1; i++)
1319  out << "a[" << (i) << "].y = lMemLoad[" << (i*(radix+1)*(threadsPerBlock/radix)) << "];\n";
1320  }
1321  else
1322  {
1323  int innerIter = radix/threadsPerBlock;
1324  int outerIter = R1/innerIter;
1325  for(int i = 0; i < outerIter; i++)
1326  for (int j = 0; j < innerIter; j++)
1327  out << "a[" << (i*innerIter+j) << "].y = lMemLoad[" << (j*threadsPerBlock + i*(radix+1)) << "];\n";
1328  }
1329  out << "barrier(CLK_LOCAL_MEM_FENCE);\n";
1330  out << "indexOut += tid;\n";
1331  if(this->_format == Fourier_transform_format::Split_complex) {
1332  out << "out_real += indexOut;\n";
1333  out << "out_imag += indexOut;\n";
1334  for(k = 0; k < R1; k++)
1335  out << "out_real[" << (k*threadsPerBlock) << "] = a[" << (k) << "].x;\n";
1336  for(k = 0; k < R1; k++)
1337  out << "out_imag[" << (k*threadsPerBlock) << "] = a[" << (k) << "].y;\n";
1338  }
1339  else {
1340  out << "out += indexOut;\n";
1341  for(k = 0; k < R1; k++)
1342  out << "out[" << (k*threadsPerBlock) << "] = a[" << (k) << "];\n";
1343  }
1344  } else {
1345  out << "indexOut += mad24(j, " << (numIter*strideO) << ", i);\n";
1346  if(this->_format == Fourier_transform_format::Split_complex) {
1347  out << "out_real += indexOut;\n";
1348  out << "out_imag += indexOut;\n";
1349  for(k = 0; k < R1; k++)
1350  out << "out_real[" << (((k%R2)*R1 + (k/R2))*strideO) << "] = a[" << (k) << "].x;\n";
1351  for(k = 0; k < R1; k++)
1352  out << "out_imag[" << (((k%R2)*R1 + (k/R2))*strideO) << "] = a[" << (k) << "].y;\n";
1353  }
1354  else {
1355  out << "out += indexOut;\n";
1356  for(k = 0; k < R1; k++)
1357  out << "out[" << (((k%R2)*R1 + (k/R2))*strideO) << "] = a[" << (k) << "];\n";
1358  }
1359  }
1360  std::stringstream result;
1361  insertHeader(result, kernel.name, this->_format);
1362  result << "{\n";
1363  if (kernel.lmem_size) {
1364  result << " __local float sMem[" << (kernel.lmem_size) << "];\n";
1365  }
1366  result << out.str();
1367  result << "}\n";
1368  this->_src += result.str();
1369  N /= radix;
1370  this->_kernels.emplace_back(kernel);
1371  }
1372 }
1373 
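// Translate kernel metadata and the requested batch size into global and local work-item counts for the kernel's axis.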
1374 void
1375 vtb::opencl::Fourier_transform_base::getKernelWorkDimensions(
1376  const Kernel_info& kernel,
1377  int* batchSize,
1378  size_t* gWorkItems,
1379  size_t* lWorkItems
1380 ) {
1381  *lWorkItems = kernel.num_workitems_per_workgroup;
1382  int numWorkGroups = kernel.num_workgroups;
1383  int numXFormsPerWG = kernel.num_xforms_per_workgroup;
1384  int ny = this->_shape(1);
1385  int nz = this->_shape(2);
1386  if (kernel.axis == 0) {
1387  *batchSize *= ny*nz;
1388  numWorkGroups = (*batchSize % numXFormsPerWG)
1389  ? (*batchSize/numXFormsPerWG + 1)
1390  : (*batchSize/numXFormsPerWG);
1391  numWorkGroups *= kernel.num_workgroups;
1392  }
1393  if (kernel.axis == 1) {
1394  *batchSize *= nz;
1395  numWorkGroups *= *batchSize;
1396  }
1397  if (kernel.axis == 2) {
1398  numWorkGroups *= *batchSize;
1399  }
1400  *gWorkItems = numWorkGroups * *lWorkItems;
1401 }
1402 
1403 void
1404 vtb::opencl::Fourier_transform_base::precompile(const int3& max_power, Context* context) {
1405  using clock = std::chrono::system_clock;
1406  using std::chrono::seconds;
1407  int ni = max_power(0);
1408  int nj = max_power(1);
1409  std::atomic<int> count{};
1410  auto t0 = clock::now();
1411  int max_count = product(max_power);
1412  std::atomic<bool> slow{false};
1413  #if defined(VTB_WITH_OPENMP)
1414  #pragma omp parallel for collapse(2) schedule(dynamic,1)
1415  #endif
1416  for (int i=1; i<=ni; ++i) {
1417  for (int j=1; j<=nj; ++j) {
1419  fft.context(context);
1420  fft.shape({1<<i, 1<<j, 1});
1421  if (clock::now()-t0 > seconds(1)) { slow = true; }
1422  auto cnt = ++count;
1423  if (slow && cnt%10 == 0 && cnt >= 10) {
1424  std::fprintf(stderr, "%5d/%-5d compile fft\n", cnt, max_count);
1425  }
1426  }
1427  }
1428 }
1429 
1430 void
1431 vtb::opencl::Fourier_transform_base::dump(std::ostream& out) {
1432  size_t global = 0, local = 0;
1433  for (const auto& kernel : this->_kernels) {
1434  cl_int batch_size = 1;
1435  getKernelWorkDimensions(kernel, &batch_size, &global, &local);
1436  out << std::setw(20) << kernel.name
1437  << std::setw(20) << global
1438  << std::setw(20) << local
1439  << std::endl;
1440  }
1441  out << this->_src;
1442 }
1443 
1444 vtb::opencl::Fourier_transform_base::Fourier_transform_base(const int3& shape):
1445 _shape{shape} {
1446  this->init();
1447 }
1448 
1449 void
1450 vtb::opencl::Fourier_transform_base::allocate_temporary_buffer(int batch_size) {
1451  if (this->last_batch_size != batch_size) {
1452  this->last_batch_size = batch_size;
1453  size_t n = this->buffer_size(batch_size);
1454  this->_workarea = context()->context().buffer(clx::memory_flags::read_write, n);
1455  }
1456 }
1457 
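// Enqueue the transform: kernels that can work in place run directly on x; any remaining passes ping-pong between x and the temporary buffer.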
1458 void
1459 vtb::opencl::Fourier_transform_base::enqueue(
1460  clx::buffer x,
1461  int direction,
1462  int batch_size
1463 ) {
1464  auto& ppl = context()->pipeline();
1465  clx::buffer buffer_in = x, buffer_out(nullptr);
1466  Kernel_info* first = this->_kernels.data();
1467  Kernel_info* last = this->_kernels.data() + this->_kernels.size();
1468  // compute in-place transforms
1469  while (first != last && first->in_place_possible) {
1470  auto& k = *first;
1471  int new_batch_size = batch_size;
1472  size_t local = 0, global = 0;
1473  this->getKernelWorkDimensions(k, &new_batch_size, &global, &local);
1474  k.kernel.argument(0, buffer_in);
1475  k.kernel.argument(1, buffer_in);
1476  k.kernel.argument(2, direction);
1477  k.kernel.argument(3, new_batch_size);
1478  ppl.step();
1479  ppl.kernel(k.kernel, clx::range{global}, clx::range{local});
1480  ++first;
1481  }
1482  // allocate temporary buffer to compute the remaining out-of-place transforms
1483  if (first != last) {
1484  this->allocate_temporary_buffer(batch_size);
1485  buffer_out = this->_workarea;
1486  auto remaining = last - first;
1487  if (remaining%2 != 0) {
1488  // TODO this is not tested
1489  ppl.step();
1490  ppl.copy(buffer_in, buffer_out);
1491  std::swap(buffer_in, buffer_out);
1492  }
1493  }
1494  // compute out-of-place transforms by swapping between actual and
1495  // temporary buffer
1496  while (first != last) {
1497  auto& k = *first;
1498  int new_batch_size = batch_size;
1499  size_t local = 0, global = 0;
1500  this->getKernelWorkDimensions(k, &new_batch_size, &global, &local);
1501  k.kernel.argument(0, buffer_in);
1502  k.kernel.argument(1, buffer_out);
1503  k.kernel.argument(2, direction);
1504  k.kernel.argument(3, new_batch_size);
1505  ppl.step();
1506  ppl.kernel(k.kernel, clx::range{global}, clx::range{local});
1507  std::swap(buffer_in, buffer_out);
1508  ++first;
1509  }
1510 }
1511 
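// Validate the shape, then compile the generated source, shrinking the assumed work-group size and regenerating until every kernel fits the device limit.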
1512 void
1513 vtb::opencl::Fourier_transform_base::init() {
1514  if (!blitz::is_power_of_two(blitz::product(this->_shape))) {
1515  throw std::invalid_argument{"bad shape"};
1516  }
1517  int dim = 0;
1518  while (dim < 3 && this->_shape(dim) != 1) {
1519  ++dim;
1520  }
1521  if (dim < 0 || dim > 2) {
1522  throw std::invalid_argument("OpenCL FFT supports 1,2,3 dimensions only");
1523  }
1524  this->_ndimensions = dim;
1525  bool success = false;
1526  clx::compiler cc = context()->compiler_copy();
1527  cc.options(cc.options() + " -cl-mad-enable");
1528  auto device = cc.devices().front();
1529  this->_maxworkgroupsize = device.max_work_group_size();
1530  while (!success) {
1531  this->generate_source_code();
1532  clx::program prog = cc.compile("fft.cl", this->_src);
1533  auto all_kernels = prog.kernels();
1534  for (auto& kernel : all_kernels) {
1535  auto name = kernel.name();
1536  for (auto& k : this->_kernels) {
1537  if (name == k.name) { k.kernel = kernel; break; }
1538  }
1539  }
1540  success = true;
1541  size_t min_wg_size = std::numeric_limits<size_t>::max();
1542  for (const auto& k : this->_kernels) {
1543  auto wg = k.kernel.work_group(device);
1544  if (wg.size < size_t(k.num_workitems_per_workgroup)) { success = false; }
1545  if (wg.size < min_wg_size) { min_wg_size = wg.size; }
1546  }
1547  this->_maxworkgroupsize = min_wg_size;
1548  }
1549 }
1550 
1551 void
1552 vtb::opencl::Chirp_Z_transform_base::context(Context* rhs) {
1553  this->_fft.context(rhs);
1554  auto prog = context()->compiler().compile("chirp_z_transform.cl");
1555  _makechirp = prog.kernel("make_chirp");
1556  _reciprocal_chirp = prog.kernel("reciprocal_chirp");
1557  _mult1 = prog.kernel("multiply_1");
1558  _mult2 = prog.kernel("multiply_2");
1559  _mult3 = prog.kernel("multiply_3");
1560  _zero_init = prog.kernel("zero_init");
1561 }
1562 
1563 void
1564 vtb::opencl::Chirp_Z_transform_base::make_chirp(
1565  const int3& shape,
1566  const int3& fft_shape
1567 ) {
1568  auto& ppl = context()->pipeline();
1569  _shape = shape;
1570  int3 chirp_shape{shape*2-1};
1571  ppl.allocate(product(chirp_shape), this->_chirp);
1572  ppl.allocate(product(fft_shape), this->_xp);
1573  ppl.allocate(product(fft_shape), this->_ichirp);
1574  auto& kernel = this->_makechirp;
1575  kernel.arguments(this->_chirp, shape(0), shape(1), shape(2));
1576  ppl.kernel(kernel, chirp_shape);
1577  ppl.step();
1578  VTB_DUMP(this->_chirp, chirp_shape, "chirp");
1579 }
1580 
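// Chirp-Z (Bluestein) transform: multiply x by the chirp, FFT it, FFT the reciprocal chirp, multiply the two spectra, transform back, and finish with the chirp multiply and 1/N scaling.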
1581 void
1582 vtb::opencl::Chirp_Z_transform_base::enqueue(
1583  clx::buffer x,
1584  int direction,
1585  int batch_size
1586 ) {
1587  using blitz::product;
1588  if (batch_size != 1) {
1589  throw std::runtime_error{"batch size > 1 not supported"};
1590  }
1591  auto& ppl = this->_fft.context()->pipeline();
1592  // TODO Step is deprecated. Replace it with Stack.
1593  // TODO zero_init is not copied, but this is not a problem until
1594  // we run two steps in parallel, which modern gpus cannot do
1595 // Step st;
1596  {
1597 // Step st1;
1598  clx::kernel zero = _zero_init;
1599  zero.argument(0, _xp);
1600  _mult1.argument(0, x);
1601  _mult1.argument(1, _chirp);
1602  _mult1.argument(2, direction);
1603  _mult1.argument(3, _fft.shape()(0));
1604  _mult1.argument(4, _fft.shape()(1));
1605  _mult1.argument(5, _fft.shape()(2));
1606  _mult1.argument(6, _xp);
1607  ppl.kernel(zero, _fft.shape());
1608  ppl.step();
1609  ppl.kernel(_mult1, _shape);
1610  ppl.step();
1611  VTB_DUMP(this->_xp, this->_fft.shape(), "xp");
1612  _fft.enqueue(_xp, direction, batch_size);
1613  VTB_DUMP(this->_xp, this->_fft.shape(), "fft(xp)");
1614  }
1615  int3 chirp_shape{_shape*2-1};
1616  {
1617 // Step st1;
1618  clx::kernel zero = _zero_init;
1619  zero.argument(0, _ichirp);
1620  _reciprocal_chirp.argument(0, _chirp);
1621  _reciprocal_chirp.argument(1, direction);
1622  _reciprocal_chirp.argument(2, _fft.shape()(0));
1623  _reciprocal_chirp.argument(3, _fft.shape()(1));
1624  _reciprocal_chirp.argument(4, _fft.shape()(2));
1625  _reciprocal_chirp.argument(5, _ichirp);
1626  ppl.kernel(zero, _fft.shape());
1627  ppl.step();
1628  ppl.kernel(_reciprocal_chirp, chirp_shape);
1629  ppl.step();
1630  VTB_DUMP(this->_ichirp, this->_fft.shape(), "ichirp");
1631  _fft.enqueue(_ichirp, direction, batch_size);
1632  VTB_DUMP(this->_ichirp, this->_fft.shape(), "fft(ichirp)");
1633  }
1634  _mult2.argument(0, _ichirp);
1635  _mult2.argument(1, _xp);
1636  ppl.step();
1637  ppl.kernel(_mult2, _fft.shape());
1638  VTB_DUMP(this->_ichirp, this->_fft.shape(), "mult2");
1639  ppl.step();
1640  _fft.enqueue(_ichirp, -direction, batch_size);
1641  VTB_DUMP(this->_ichirp, this->_fft.shape(), "ifft");
1642  ppl.step();
1643  _mult3.argument(0, _ichirp);
1644  _mult3.argument(1, _chirp);
1645  _mult3.argument(2, x);
1646  _mult3.argument(3, _fft.shape()(0));
1647  _mult3.argument(4, _fft.shape()(1));
1648  _mult3.argument(5, _fft.shape()(2));
1649  _mult3.argument(6, 1.f / product(_fft.shape()));
1650  ppl.kernel(_mult3, _shape);
1651 }