#include <openclx/compiler>

#include <vtestbed/config/openmp.hh>
#include <vtestbed/opencl/fourier_transform.hh>
#include <vtestbed/opencl/pipeline.hh>

#if defined(VTB_DEBUG_CHIRP_Z)
template <class T, int N>
    blitz::Array<T,N> x(shape);
    std::clog << name << '=' << x << std::endl;
#define VTB_DUMP(x, shape, name) ::dump(x,shape,name)
#else
#define VTB_DUMP(x, shape, name)
#endif

const char* base_kernels =
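    // Complex-arithmetic helpers and radix-2/4/8/16/32 butterfly macros shared by every
    // generated kernel; the twiddle constants are exact hexadecimal float literals.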
36 "#define M_PI 0x1.921fb54442d18p+1\n" 38 "#define complexMul(a,b) ((float2)(mad(-(a).y, (b).y, (a).x * (b).x), mad((a).y, (b).x, (a).x * (b).y)))\n" 39 "#define conj(a) ((float2)((a).x, -(a).y))\n" 40 "#define conjTransp(a) ((float2)(-(a).y, (a).x))\n" 42 "#define fftKernel2(a,dir) \\\n" 44 " float2 c = (a)[0]; \\\n" 45 " (a)[0] = c + (a)[1]; \\\n" 46 " (a)[1] = c - (a)[1]; \\\n" 49 "#define fftKernel2S(d1,d2,dir) \\\n" 51 " float2 c = (d1); \\\n" 52 " (d1) = c + (d2); \\\n" 53 " (d2) = c - (d2); \\\n" 56 "#define fftKernel4(a,dir) \\\n" 58 " fftKernel2S((a)[0], (a)[2], dir); \\\n" 59 " fftKernel2S((a)[1], (a)[3], dir); \\\n" 60 " fftKernel2S((a)[0], (a)[1], dir); \\\n" 61 " (a)[3] = (float2)(dir)*(conjTransp((a)[3])); \\\n" 62 " fftKernel2S((a)[2], (a)[3], dir); \\\n" 63 " float2 c = (a)[1]; \\\n" 64 " (a)[1] = (a)[2]; \\\n" 68 "#define fftKernel4s(a0,a1,a2,a3,dir) \\\n" 70 " fftKernel2S((a0), (a2), dir); \\\n" 71 " fftKernel2S((a1), (a3), dir); \\\n" 72 " fftKernel2S((a0), (a1), dir); \\\n" 73 " (a3) = (float2)(dir)*(conjTransp((a3))); \\\n" 74 " fftKernel2S((a2), (a3), dir); \\\n" 75 " float2 c = (a1); \\\n" 80 "#define bitreverse8(a) \\\n" 84 " (a)[1] = (a)[4]; \\\n" 87 " (a)[3] = (a)[6]; \\\n" 91 "#define fftKernel8(a,dir) \\\n" 93 " const float2 w1 = (float2)(0x1.6a09e6p-1f, dir*0x1.6a09e6p-1f); \\\n" 94 " const float2 w3 = (float2)(-0x1.6a09e6p-1f, dir*0x1.6a09e6p-1f); \\\n" 96 " fftKernel2S((a)[0], (a)[4], dir); \\\n" 97 " fftKernel2S((a)[1], (a)[5], dir); \\\n" 98 " fftKernel2S((a)[2], (a)[6], dir); \\\n" 99 " fftKernel2S((a)[3], (a)[7], dir); \\\n" 100 " (a)[5] = complexMul(w1, (a)[5]); \\\n" 101 " (a)[6] = (float2)(dir)*(conjTransp((a)[6])); \\\n" 102 " (a)[7] = complexMul(w3, (a)[7]); \\\n" 103 " fftKernel2S((a)[0], (a)[2], dir); \\\n" 104 " fftKernel2S((a)[1], (a)[3], dir); \\\n" 105 " fftKernel2S((a)[4], (a)[6], dir); \\\n" 106 " fftKernel2S((a)[5], (a)[7], dir); \\\n" 107 " (a)[3] = (float2)(dir)*(conjTransp((a)[3])); \\\n" 108 " (a)[7] = (float2)(dir)*(conjTransp((a)[7])); \\\n" 109 " fftKernel2S((a)[0], (a)[1], dir); \\\n" 110 " fftKernel2S((a)[2], (a)[3], dir); \\\n" 111 " fftKernel2S((a)[4], (a)[5], dir); \\\n" 112 " fftKernel2S((a)[6], (a)[7], dir); \\\n" 113 " bitreverse8((a)); \\\n" 116 "#define bitreverse4x4(a) \\\n" 119 " c = (a)[1]; (a)[1] = (a)[4]; (a)[4] = c; \\\n" 120 " c = (a)[2]; (a)[2] = (a)[8]; (a)[8] = c; \\\n" 121 " c = (a)[3]; (a)[3] = (a)[12]; (a)[12] = c; \\\n" 122 " c = (a)[6]; (a)[6] = (a)[9]; (a)[9] = c; \\\n" 123 " c = (a)[7]; (a)[7] = (a)[13]; (a)[13] = c; \\\n" 124 " c = (a)[11]; (a)[11] = (a)[14]; (a)[14] = c; \\\n" 127 "#define fftKernel16(a,dir) \\\n" 129 " const float w0 = 0x1.d906bcp-1f; \\\n" 130 " const float w1 = 0x1.87de2ap-2f; \\\n" 131 " const float w2 = 0x1.6a09e6p-1f; \\\n" 132 " fftKernel4s((a)[0], (a)[4], (a)[8], (a)[12], dir); \\\n" 133 " fftKernel4s((a)[1], (a)[5], (a)[9], (a)[13], dir); \\\n" 134 " fftKernel4s((a)[2], (a)[6], (a)[10], (a)[14], dir); \\\n" 135 " fftKernel4s((a)[3], (a)[7], (a)[11], (a)[15], dir); \\\n" 136 " (a)[5] = complexMul((a)[5], (float2)(w0, dir*w1)); \\\n" 137 " (a)[6] = complexMul((a)[6], (float2)(w2, dir*w2)); \\\n" 138 " (a)[7] = complexMul((a)[7], (float2)(w1, dir*w0)); \\\n" 139 " (a)[9] = complexMul((a)[9], (float2)(w2, dir*w2)); \\\n" 140 " (a)[10] = (float2)(dir)*(conjTransp((a)[10])); \\\n" 141 " (a)[11] = complexMul((a)[11], (float2)(-w2, dir*w2)); \\\n" 142 " (a)[13] = complexMul((a)[13], (float2)(w1, dir*w0)); \\\n" 143 " (a)[14] = complexMul((a)[14], (float2)(-w2, dir*w2)); \\\n" 144 " 
(a)[15] = complexMul((a)[15], (float2)(-w0, dir*-w1)); \\\n" 145 " fftKernel4((a), dir); \\\n" 146 " fftKernel4((a) + 4, dir); \\\n" 147 " fftKernel4((a) + 8, dir); \\\n" 148 " fftKernel4((a) + 12, dir); \\\n" 149 " bitreverse4x4((a)); \\\n" 152 "#define bitreverse32(a) \\\n" 154 " float2 c1, c2; \\\n" 155 " c1 = (a)[2]; (a)[2] = (a)[1]; c2 = (a)[4]; (a)[4] = c1; c1 = (a)[8]; (a)[8] = c2; c2 = (a)[16]; (a)[16] = c1; (a)[1] = c2; \\\n" 156 " c1 = (a)[6]; (a)[6] = (a)[3]; c2 = (a)[12]; (a)[12] = c1; c1 = (a)[24]; (a)[24] = c2; c2 = (a)[17]; (a)[17] = c1; (a)[3] = c2; \\\n" 157 " c1 = (a)[10]; (a)[10] = (a)[5]; c2 = (a)[20]; (a)[20] = c1; c1 = (a)[9]; (a)[9] = c2; c2 = (a)[18]; (a)[18] = c1; (a)[5] = c2; \\\n" 158 " c1 = (a)[14]; (a)[14] = (a)[7]; c2 = (a)[28]; (a)[28] = c1; c1 = (a)[25]; (a)[25] = c2; c2 = (a)[19]; (a)[19] = c1; (a)[7] = c2; \\\n" 159 " c1 = (a)[22]; (a)[22] = (a)[11]; c2 = (a)[13]; (a)[13] = c1; c1 = (a)[26]; (a)[26] = c2; c2 = (a)[21]; (a)[21] = c1; (a)[11] = c2; \\\n" 160 " c1 = (a)[30]; (a)[30] = (a)[15]; c2 = (a)[29]; (a)[29] = c1; c1 = (a)[27]; (a)[27] = c2; c2 = (a)[23]; (a)[23] = c1; (a)[15] = c2; \\\n" 163 "#define fftKernel32(a,dir) \\\n" 165 " fftKernel2S((a)[0], (a)[16], dir); \\\n" 166 " fftKernel2S((a)[1], (a)[17], dir); \\\n" 167 " fftKernel2S((a)[2], (a)[18], dir); \\\n" 168 " fftKernel2S((a)[3], (a)[19], dir); \\\n" 169 " fftKernel2S((a)[4], (a)[20], dir); \\\n" 170 " fftKernel2S((a)[5], (a)[21], dir); \\\n" 171 " fftKernel2S((a)[6], (a)[22], dir); \\\n" 172 " fftKernel2S((a)[7], (a)[23], dir); \\\n" 173 " fftKernel2S((a)[8], (a)[24], dir); \\\n" 174 " fftKernel2S((a)[9], (a)[25], dir); \\\n" 175 " fftKernel2S((a)[10], (a)[26], dir); \\\n" 176 " fftKernel2S((a)[11], (a)[27], dir); \\\n" 177 " fftKernel2S((a)[12], (a)[28], dir); \\\n" 178 " fftKernel2S((a)[13], (a)[29], dir); \\\n" 179 " fftKernel2S((a)[14], (a)[30], dir); \\\n" 180 " fftKernel2S((a)[15], (a)[31], dir); \\\n" 181 " (a)[17] = complexMul((a)[17], (float2)(0x1.f6297cp-1f, dir*0x1.8f8b84p-3f)); \\\n" 182 " (a)[18] = complexMul((a)[18], (float2)(0x1.d906bcp-1f, dir*0x1.87de2ap-2f)); \\\n" 183 " (a)[19] = complexMul((a)[19], (float2)(0x1.a9b662p-1f, dir*0x1.1c73b4p-1f)); \\\n" 184 " (a)[20] = complexMul((a)[20], (float2)(0x1.6a09e6p-1f, dir*0x1.6a09e6p-1f)); \\\n" 185 " (a)[21] = complexMul((a)[21], (float2)(0x1.1c73b4p-1f, dir*0x1.a9b662p-1f)); \\\n" 186 " (a)[22] = complexMul((a)[22], (float2)(0x1.87de2ap-2f, dir*0x1.d906bcp-1f)); \\\n" 187 " (a)[23] = complexMul((a)[23], (float2)(0x1.8f8b84p-3f, dir*0x1.f6297cp-1f)); \\\n" 188 " (a)[24] = complexMul((a)[24], (float2)(0x0p+0f, dir*0x1p+0f)); \\\n" 189 " (a)[25] = complexMul((a)[25], (float2)(-0x1.8f8b84p-3f, dir*0x1.f6297cp-1f)); \\\n" 190 " (a)[26] = complexMul((a)[26], (float2)(-0x1.87de2ap-2f, dir*0x1.d906bcp-1f)); \\\n" 191 " (a)[27] = complexMul((a)[27], (float2)(-0x1.1c73b4p-1f, dir*0x1.a9b662p-1f)); \\\n" 192 " (a)[28] = complexMul((a)[28], (float2)(-0x1.6a09e6p-1f, dir*0x1.6a09e6p-1f)); \\\n" 193 " (a)[29] = complexMul((a)[29], (float2)(-0x1.a9b662p-1f, dir*0x1.1c73b4p-1f)); \\\n" 194 " (a)[30] = complexMul((a)[30], (float2)(-0x1.d906bcp-1f, dir*0x1.87de2ap-2f)); \\\n" 195 " (a)[31] = complexMul((a)[31], (float2)(-0x1.f6297cp-1f, dir*0x1.8f8b84p-3f)); \\\n" 196 " fftKernel16((a), dir); \\\n" 197 " fftKernel16((a) + 16, dir); \\\n" 198 " bitreverse32((a)); \\\n" 202 radix_array(
        int n,
        int max
) {
    std::vector<int> result;
    if (max > 1) {
        max = std::min(n, max);
        while (n > max) {
            result.push_back(max);
            n /= max;
        }
        result.push_back(n);
        return result;
    }
    switch (n) {
        case 2: result = {2}; break;
        case 4: result = {4}; break;
        case 8: result = {8}; break;
        case 16: result = {8,2}; break;
        case 32: result = {8,4}; break;
        case 64: result = {8,8}; break;
        case 128: result = {8,4,4}; break;
        case 256: result = {4,4,4,4}; break;
        case 512: result = {8,8,8}; break;
        case 1024: result = {16,16,4}; break;
        case 2048: result = {8,8,8,4}; break;
    }
    return result;
}
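// radix_array above decomposes a transform length into the radices used by the generated
// kernels, either greedily from `max` or from the fixed table of power-of-two cases.
// formattedLoad / formattedStore below emit a single global-memory load or store of the
// generated kernel: one float2 element in the interleaved-complex format, or separate
// real/imaginary elements in the split-complex format.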
void
formattedLoad(
    std::ostream& out,
    int aIndex,
    int gIndex,
    vtb::opencl::Fourier_transform_format dataFormat
) {
    using vtb::opencl::Fourier_transform_format;
    if (dataFormat == Fourier_transform_format::Interleaved_complex) {
        out << " a[" << (aIndex) << "] = in[" << (gIndex) << "];\n";
    } else {
        out << " a[" << (aIndex) << "].x = in_real[" << (gIndex) << "];\n";
        out << " a[" << (aIndex) << "].y = in_imag[" << (gIndex) << "];\n";
    }
}
void
formattedStore(
    std::ostream& out,
    int aIndex,
    int gIndex,
    vtb::opencl::Fourier_transform_format dataFormat
) {
    using vtb::opencl::Fourier_transform_format;
    if (dataFormat == Fourier_transform_format::Interleaved_complex) {
        out << " out[" << (gIndex) << "] = a[" << (aIndex) << "];\n";
    } else {
        out << " out_real[" << (gIndex) << "] = a[" << (aIndex) << "].x;\n";
        out << " out_imag[" << (gIndex) << "] = a[" << (aIndex) << "].y;\n";
    }
}
void
insertHeader(
    std::ostream& out,
    const std::string& kernelName,
    vtb::opencl::Fourier_transform_format dataFormat
) {
    using vtb::opencl::Fourier_transform_format;
    if (dataFormat == Fourier_transform_format::Split_complex) {
        out << "__kernel void " + kernelName
            << "(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n";
    } else {
        out << "__kernel void " + kernelName
            << "(__global float2 *in, __global float2 *out, int dir, int S)\n";
    }
}

void
insertVariables(std::ostream& out, int maxRadix) {
    out << " int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n";
    out << " int s, ii, jj, offset;\n";
    out << " float2 w;\n";
    out << " float ang, angf, ang1;\n";
    out << " __local float *lMemStore, *lMemLoad;\n";
    out << " float2 a[" << maxRadix << "];\n";
    out << " int lId = get_local_id( 0 );\n";
    out << " int groupId = get_group_id( 0 );\n";
}
void
insertfftKernel(std::ostream& out, int Nr, int numIter) {
    for (int i=0; i<numIter; ++i) {
        out << " fftKernel" << (Nr) << "(a+" << (i*Nr) << ", dir);\n";
    }
}
void
insertTwiddleKernel(
    std::ostream& out,
    int Nr,
    int numIter,
    int Nprev,
    int len,
    int numWorkItemsPerXForm
) {
    int logNPrev = (int)std::log2(Nprev);
    for (int z=0; z<numIter; ++z) {
        if (z == 0) {
            if (Nprev > 1) {
                out << " angf = (float) (ii >> " << (logNPrev) << ");\n";
            } else {
                out << " angf = (float) ii;\n";
            }
        } else {
            if (Nprev > 1) {
                out << " angf = (float) ((" << (z*numWorkItemsPerXForm) << " + ii) >>" << (logNPrev) << ");\n";
            } else {
                out << " angf = (float) (" << (z*numWorkItemsPerXForm) << " + ii);\n";
            }
        }
        for (int k=1; k<Nr; ++k) {
            int ind = z*Nr + k;
            out << " ang = dir * ( 2.0f * M_PI * " << (k) << ".0f / " << (len) << ".0f )" << " * angf;\n";
            out << " w = (float2)(native_cos(ang), native_sin(ang));\n";
            out << " a[" << (ind) << "] = complexMul(a[" << (ind) << "], w);\n";
        }
    }
}
335 insertGlobalLoadsAndTranspose(
338 int numWorkItemsPerXForm,
341 int mem_coalesce_width,
342 vtb::opencl::Fourier_transform_format dataFormat
344 using vtb::opencl::Fourier_transform_format;
345 int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm);
346 int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
348 if (numXFormsPerWG > 1) {
349 out <<
" s = S & " << (numXFormsPerWG-1) <<
";\n";
351 if (numWorkItemsPerXForm >= mem_coalesce_width) {
352 if (numXFormsPerWG > 1) {
353 out <<
" ii = lId & " << (numWorkItemsPerXForm-1) <<
";\n";
354 out <<
" jj = lId >> " << log2NumWorkItemsPerXForm <<
";\n";
355 out <<
" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n";
356 out <<
" offset = mad24( mad24(groupId, " 360 if (dataFormat == Fourier_transform_format::Interleaved_complex) {
361 out <<
" in += offset;\n";
362 out <<
" out += offset;\n";
364 out <<
" in_real += offset;\n";
365 out <<
" in_imag += offset;\n";
366 out <<
" out_real += offset;\n";
367 out <<
" out_imag += offset;\n";
369 for (
int i=0; i<R0; ++i) {
370 formattedLoad(out, i, i*numWorkItemsPerXForm, dataFormat);
374 out <<
" ii = lId;\n";
376 out <<
" offset = mad24(groupId, " << N <<
", ii);\n";
377 if (dataFormat == Fourier_transform_format::Interleaved_complex) {
378 out <<
" in += offset;\n";
379 out <<
" out += offset;\n";
381 out <<
" in_real += offset;\n";
382 out <<
" in_imag += offset;\n";
383 out <<
" out_real += offset;\n";
384 out <<
" out_imag += offset;\n";
386 for (
int i=0; i<R0; ++i) {
387 formattedLoad(out, i, i*numWorkItemsPerXForm, dataFormat);
390 }
else if (N >= mem_coalesce_width) {
391 int numInnerIter = N / mem_coalesce_width;
392 int numOuterIter = numXFormsPerWG / (groupSize / mem_coalesce_width);
394 out <<
" ii = lId & " << (mem_coalesce_width - 1) <<
";\n";
395 out <<
" jj = lId >> " << ((int)log2(mem_coalesce_width)) <<
";\n";
396 out <<
" lMemStore = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) <<
", ii );\n";
397 out <<
" offset = mad24( groupId, " << (numXFormsPerWG) <<
", jj);\n";
398 out <<
" offset = mad24( offset, " << (N) <<
", ii );\n";
399 if (dataFormat == Fourier_transform_format::Interleaved_complex) {
400 out <<
" in += offset;\n";
401 out <<
" out += offset;\n";
403 out <<
" in_real += offset;\n";
404 out <<
" in_imag += offset;\n";
405 out <<
" out_real += offset;\n";
406 out <<
" out_imag += offset;\n";
408 out <<
"if((groupId == get_num_groups(0)-1) && s) {\n";
409 for(
int i=0; i<numOuterIter; ++i) {
410 out <<
" if( jj < s ) {\n";
411 for (
int j=0; j<numInnerIter; ++j) {
414 i * numInnerIter + j,
415 j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N,
420 if (i != numOuterIter-1) {
421 out <<
" jj += " << (groupSize / mem_coalesce_width) <<
";\n";
426 for (
int i = 0; i < numOuterIter; i++ ) {
427 for (
int j = 0; j < numInnerIter; j++ ) {
430 i * numInnerIter + j,
431 j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N,
437 out <<
" ii = lId & " << (numWorkItemsPerXForm - 1) <<
";\n";
438 out <<
" jj = lId >> " << (log2NumWorkItemsPerXForm) <<
";\n";
439 out <<
" lMemLoad = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) <<
", ii);\n";
440 for (
int i=0; i<numOuterIter; ++i) {
441 for (
int j=0; j<numInnerIter; ++j) {
443 << (j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm ))
445 << (i * numInnerIter + j)
449 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
450 for (
int i=0; i<R0; ++i) {
454 << (i * numWorkItemsPerXForm)
457 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
458 for (
int i=0; i<numOuterIter; ++i) {
459 for (
int j=0; j<numInnerIter; ++j) {
461 << (j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm ))
463 << (i * numInnerIter + j)
467 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
468 for (
int i=0; i<R0; ++i) {
472 << (i * numWorkItemsPerXForm) <<
"];\n";
474 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
475 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
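// Padding each staging row by numWorkItemsPerXForm floats staggers the transposed
// local-memory accesses across banks.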
477 out <<
" offset = mad24( groupId, " << (N * numXFormsPerWG) <<
", lId );\n";
478 if (dataFormat == Fourier_transform_format::Interleaved_complex) {
479 out <<
" in += offset;\n";
480 out <<
" out += offset;\n";
482 out <<
" in_real += offset;\n";
483 out <<
" in_imag += offset;\n";
484 out <<
" out_real += offset;\n";
485 out <<
" out_imag += offset;\n";
487 out <<
" ii = lId & " << (N-1) <<
";\n";
488 out <<
" jj = lId >> " << ((int)log2(N)) <<
";\n";
489 out <<
" lMemStore = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) <<
", ii );\n";
490 out <<
"if((groupId == get_num_groups(0)-1) && s) {\n";
491 for (
int i=0; i<R0; ++i) {
492 out <<
" if(jj < s )\n";
493 formattedLoad(out, i, i*groupSize, dataFormat);
495 out <<
" jj += " << (groupSize / N) <<
";\n";
500 for (
int i=0; i<R0; ++i) {
501 formattedLoad(out, i, i*groupSize, dataFormat);
504 if (numWorkItemsPerXForm > 1) {
505 out <<
" ii = lId & " << (numWorkItemsPerXForm - 1) <<
";\n";
506 out <<
" jj = lId >> " << (log2NumWorkItemsPerXForm) <<
";\n";
507 out <<
" lMemLoad = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) <<
", ii );\n";
510 out <<
" jj = lId;\n";
511 out <<
" lMemLoad = sMem + mul24( jj, " << (N + numWorkItemsPerXForm) <<
");\n";
513 for (
int i=0; i<R0; ++i) {
514 out <<
" lMemStore[" << (i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) <<
"] = a[" << (i) <<
"].x;\n";
516 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
518 for (
int i=0; i<R0; ++i) {
519 out <<
" a[" << (i) <<
"].x = lMemLoad[" << (i * numWorkItemsPerXForm) <<
"];\n";
521 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
522 for (
int i=0; i<R0; ++i) {
523 out <<
" lMemStore[" << (i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) <<
"] = a[" << (i) <<
"].y;\n";
525 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
526 for (
int i=0; i<R0; ++i) {
527 out <<
" a[" << (i) <<
"].y = lMemLoad[" << (i * numWorkItemsPerXForm) <<
"];\n";
529 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
530 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
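// insertGlobalStoresAndTranspose mirrors the load stage: where necessary it transposes the
// results through local memory and writes them back to global memory in coalesced order,
// again returning the local-memory size it requires.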
536 insertGlobalStoresAndTranspose(
541 int numWorkItemsPerXForm,
543 int mem_coalesce_width,
544 vtb::opencl::Fourier_transform_format dataFormat
546 int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
549 int numIter = maxRadix / Nr;
550 if( numWorkItemsPerXForm >= mem_coalesce_width )
552 if(numXFormsPerWG > 1)
554 out <<
" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n";
556 for(i = 0; i < maxRadix; i++)
561 formattedStore(out, ind, i*numWorkItemsPerXForm, dataFormat);
563 if(numXFormsPerWG > 1)
566 else if( N >= mem_coalesce_width )
568 int numInnerIter = N / mem_coalesce_width;
569 int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
570 out <<
" lMemLoad = sMem + mad24( jj, " << (N + numWorkItemsPerXForm) <<
", ii );\n";
571 out <<
" ii = lId & " << (mem_coalesce_width - 1) <<
";\n";
572 out <<
" jj = lId >> " << ((int)log2(mem_coalesce_width)) <<
";\n";
573 out <<
" lMemStore = sMem + mad24( jj," << (N + numWorkItemsPerXForm) <<
", ii );\n";
574 for( i = 0; i < maxRadix; i++ )
579 out <<
" lMemLoad[" << (i*numWorkItemsPerXForm) <<
"] = a[" << (ind) <<
"].x;\n";
581 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
582 for( i = 0; i < numOuterIter; i++ )
583 for( j = 0; j < numInnerIter; j++ )
584 out <<
" a[" << (i*numInnerIter + j) <<
"].x = lMemStore[" << (j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) <<
"];\n";
585 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
586 for( i = 0; i < maxRadix; i++ )
591 out <<
" lMemLoad[" << (i*numWorkItemsPerXForm) <<
"] = a[" << (ind) <<
"].y;\n";
593 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
594 for( i = 0; i < numOuterIter; i++ )
595 for( j = 0; j < numInnerIter; j++ )
596 out <<
" a[" << (i*numInnerIter + j) <<
"].y = lMemStore[" << (j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) <<
"];\n";
597 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
598 out <<
"if((groupId == get_num_groups(0)-1) && s) {\n";
599 for(i = 0; i < numOuterIter; i++ )
601 out <<
" if( jj < s ) {\n";
602 for (
int j = 0; j < numInnerIter; j++ )
603 formattedStore(out, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
605 if(i != numOuterIter - 1)
606 out <<
" jj += " << (groupSize / mem_coalesce_width) <<
";\n";
610 for(i = 0; i < numOuterIter; i++ )
612 for (
int j = 0; j < numInnerIter; j++ )
613 formattedStore(out, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
616 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
620 out <<
" lMemLoad = sMem + mad24( jj," << (N + numWorkItemsPerXForm) <<
", ii );\n";
621 out <<
" ii = lId & " << (N - 1) <<
";\n";
622 out <<
" jj = lId >> " << ((int) log2(N)) <<
";\n";
623 out <<
" lMemStore = sMem + mad24( jj," << (N + numWorkItemsPerXForm) <<
", ii );\n";
624 for( i = 0; i < maxRadix; i++ )
629 out <<
" lMemLoad[" << (i*numWorkItemsPerXForm) <<
"] = a[" << (ind) <<
"].x;\n";
631 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
632 for( i = 0; i < maxRadix; i++ )
633 out <<
" a[" << (i) <<
"].x = lMemStore[" << (i*( groupSize / N )*( N + numWorkItemsPerXForm )) <<
"];\n";
634 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
635 for( i = 0; i < maxRadix; i++ )
640 out <<
" lMemLoad[" << (i*numWorkItemsPerXForm) <<
"] = a[" << (ind) <<
"].y;\n";
642 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
643 for( i = 0; i < maxRadix; i++ )
644 out <<
" a[" << (i) <<
"].y = lMemStore[" << (i*( groupSize / N )*( N + numWorkItemsPerXForm )) <<
"];\n";
645 out <<
" barrier( CLK_LOCAL_MEM_FENCE );\n";
646 out <<
"if((groupId == get_num_groups(0)-1) && s) {\n";
647 for( i = 0; i < maxRadix; i++ )
649 out <<
" if(jj < s ) {\n";
650 formattedStore(out, i, i*groupSize, dataFormat);
652 if( i != maxRadix - 1)
653 out <<
" jj +=" << (groupSize / N) <<
";\n";
657 for( i = 0; i < maxRadix; i++ )
659 formattedStore(out, i, i*groupSize, dataFormat);
662 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
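// insertLocalLoadIndexArithmatic and insertLocalStoreIndexArithmatic compute the lMemLoad
// and lMemStore offsets used to exchange data between radix passes through local memory;
// the helpers that follow emit the actual store/load statements with a barrier per exchange.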
668 insertLocalLoadIndexArithmatic(
673 int numWorkItemsPerXForm,
679 int Ncurr = Nprev * Nr;
680 int logNcurr = (int)std::log2(Ncurr);
681 int logNprev = (int)std::log2(Nprev);
682 int incr = (numWorkItemsReq + offset) * Nr + midPad;
683 if (Ncurr < numWorkItemsPerXForm) {
685 out <<
" j = ii & " << (Ncurr - 1) <<
";\n";
687 out <<
" j = (ii & " << (Ncurr - 1) <<
") >> " << logNprev <<
";\n";
690 out <<
" i = ii >> " << logNcurr <<
";\n";
692 out <<
" i = mad24(ii >> " << logNcurr <<
", " 693 << Nprev <<
", ii & " << (Nprev-1) <<
");\n";
699 out <<
" j = ii >> " << logNprev <<
";\n";
704 out <<
" i = ii & " << (Nprev-1) <<
";\n";
707 if (numXFormsPerWG > 1) {
708 out <<
" i = mad24(jj, " << incr <<
", i);\n";
710 out <<
" lMemLoad = sMem + mad24(j, " 711 << (numWorkItemsReq + offset)
716 insertLocalStoreIndexArithmatic(
724 if (numXFormsPerWG == 1) {
725 out <<
" lMemStore = sMem + ii;\n";
727 out <<
" lMemStore = sMem + mad24(jj, " 728 << ((numWorkItemsReq + offset)*Nr + midPad)
738 int numWorkItemsPerXForm,
743 for (
int z=0; z<numIter; ++z) {
744 for (
int k=0; k<Nr; ++k) {
745 int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm;
746 out <<
" lMemStore[" << (index)
747 <<
"] = a[" << (z*Nr + k) <<
"]." << comp <<
";\n";
750 out <<
" barrier(CLK_LOCAL_MEM_FENCE);\n";
761 int numWorkItemsPerXForm,
766 int numWorkItemsReqN = n / Nrn;
767 int interBlockHNum = std::max( Nprev / numWorkItemsPerXForm, 1 );
768 int interBlockHStride = numWorkItemsPerXForm;
769 int vertWidth = std::max(numWorkItemsPerXForm / Nprev, 1);
770 vertWidth = std::min( vertWidth, Nr);
771 int vertNum = Nr / vertWidth;
772 int vertStride = ( n / Nr + offset ) * vertWidth;
773 int iter = std::max( numWorkItemsReqN / numWorkItemsPerXForm, 1);
774 int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1;
775 intraBlockHStride *= Nprev;
776 int stride = numWorkItemsReq / Nrn;
778 for(i = 0; i < iter; i++) {
779 int ii = i / (interBlockHNum * vertNum);
780 int zz = i % (interBlockHNum * vertNum);
781 int jj = zz % interBlockHNum;
782 int kk = zz / interBlockHNum;
784 for(z = 0; z < Nrn; z++) {
785 int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride;
786 out <<
" a[" << (i*Nrn + z) <<
"]." 787 << comp <<
" = lMemLoad[" << (st) <<
"];\n";
790 out <<
" barrier(CLK_LOCAL_MEM_FENCE);\n";
795 int numWorkItemsPerXForm,
804 if((numWorkItemsPerXForm <= Nprev) || (Nprev >= numBanks))
807 int numRowsReq = ((numWorkItemsPerXForm < numBanks) ? numWorkItemsPerXForm : numBanks) / Nprev;
810 numColsReq = numRowsReq / Nr;
811 numColsReq = Nprev * numColsReq;
812 *offset = numColsReq;
814 if(numWorkItemsPerXForm >= numBanks || numXFormsPerWG == 1)
817 int bankNum = ( (numWorkItemsReq + *offset) * Nr ) & (numBanks - 1);
818 if( bankNum >= numWorkItemsPerXForm )
821 *midPad = numWorkItemsPerXForm - bankNum;
823 int lMemSize = ( numWorkItemsReq + *offset) * Nr * numXFormsPerWG + *midPad * (numXFormsPerWG - 1);
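// getGlobalRadixInfo decomposes n into base radices of at most 128 and splits each pass's
// radix into R1*R2: every work item keeps R1 elements in registers while R2 work items
// cooperate on one transform.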
835 int baseRadix = std::min(n, 128);
838 while (N > baseRadix) {
842 for (
int i = 0; i < numR; i++) {
843 radix[i] = baseRadix;
848 for (
int i = 0; i < numR; i++) {
    while (!rhs.empty() && rhs.back() <= ' ') { rhs.pop_back(); }
template <class T, int N>
size_t
num_bytes(const blitz::TinyVector<int,N>& n) {
    return blitz::product(n)*sizeof(T);
}
886 vtb::opencl::Fourier_transform_base::kernel_name(
const char* prefix) {
887 const char sep =
'_';
889 name << prefix << sep
890 << this->_shape(0) << sep
891 << this->_shape(1) << sep
892 << this->_shape(2) << sep
void
vtb::opencl::Fourier_transform_base::generate_source_code() {
    this->_src = base_kernels;
    this->_kernels.clear();
    for (int i=0; i<3; ++i) {
        this->generate_fft(i);
    }
    for (const auto& kernel : this->_kernels) {
        if (!kernel.in_place_possible) {
            this->temp_buffer_needed = true;
        }
    }
}
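// generate_fft emits the kernel(s) for one transform axis: a single local-memory kernel
// when a whole row fits into one work group, otherwise one or more global-memory passes.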
914 vtb::opencl::Fourier_transform_base::generate_fft(
int axis) {
916 int nx = this->_shape(0);
917 if (nx > this->max_localmem_fft_size) {
918 generate_fft_global(nx, 1, axis, 1);
921 if (nx/radices[0] <= this->_maxworkgroupsize) {
922 generate_fft_local();
924 radices = radix_array(nx, this->_maxradix);
925 if (nx/radices[0] <= this->_maxworkgroupsize) {
926 generate_fft_local();
928 generate_fft_global(nx, 1, axis, 1);
934 int ny = this->_shape(1);
936 int stride = this->_shape(0);
937 generate_fft_global(ny, stride, axis, 1);
941 int nz = this->_shape(2);
943 int stride = _shape(0)*_shape(1);
944 generate_fft_global(nz, stride, axis, 1);
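// generate_fft_local builds one kernel that performs the whole transform of a row in local
// memory: coalesced global loads, radix passes with twiddles and local-memory exchanges,
// then coalesced global stores.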
950 vtb::opencl::Fourier_transform_base::generate_fft_local() {
951 int n = this->_shape(0);
952 if (n > this->_maxworkgroupsize*this->_maxradix) {
956 if (n/radices[0] > this->_maxworkgroupsize) {
957 radices = radix_array(n, this->_maxradix);
959 if (radices.front() > this->_maxradix) {
962 if (n/radices.front() > this->_maxworkgroupsize) {
964 "required work items per xform greater than " 965 "maximum work items allowed per work group for local mem fft" 968 int numRadix = radices.size();
971 for (
int i=0; i<numRadix; ++i) {
980 Kernel_info kernel{};
981 kernel.name = this->kernel_name(
"fft_local");
983 kernel.in_place_possible =
true;
984 int numWorkItemsPerXForm = n / radices[0];
985 int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm;
986 assert(numWorkItemsPerWG <= this->_maxworkgroupsize);
987 int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm;
988 kernel.num_workgroups = 1;
989 kernel.num_xforms_per_workgroup = numXFormsPerWG;
990 kernel.num_workitems_per_workgroup = numWorkItemsPerWG;
991 int maxRadix = radices[0];
993 insertVariables(out, maxRadix);
994 lMemSize = insertGlobalLoadsAndTranspose(
997 numWorkItemsPerXForm,
1000 this->min_mem_coalesce_width,
1003 kernel.lmem_size = (lMemSize > kernel.lmem_size) ? lMemSize : kernel.lmem_size;
1006 for(
int r = 0; r<numRadix; ++r) {
1007 int numIter = radices[0] / radices[r];
1008 int numWorkItemsReq = n / radices[r];
1009 int Ncurr = Nprev * radices[r];
1010 insertfftKernel(out, radices[r], numIter);
1011 if (r < (numRadix-1)) {
1012 insertTwiddleKernel(
1018 numWorkItemsPerXForm
1020 lMemSize = getPadding(
1021 numWorkItemsPerXForm,
1026 this->num_local_mem_banks,
1030 kernel.lmem_size = (lMemSize > kernel.lmem_size)
1033 insertLocalStoreIndexArithmatic(
1041 insertLocalLoadIndexArithmatic(
1046 numWorkItemsPerXForm,
1055 numWorkItemsPerXForm,
1067 numWorkItemsPerXForm,
1076 numWorkItemsPerXForm,
1088 numWorkItemsPerXForm,
1094 len = len / radices[r];
1097 lMemSize = insertGlobalStoresAndTranspose(
1101 radices[numRadix - 1],
1102 numWorkItemsPerXForm,
1104 this->min_mem_coalesce_width,
1107 kernel.lmem_size = (lMemSize > kernel.lmem_size) ? lMemSize : kernel.lmem_size;
1109 result << this->_src;
1110 insertHeader(result, kernel.name, this->_format);
1112 if (kernel.lmem_size) {
1113 result <<
" __local float sMem[" << kernel.lmem_size <<
"];\n";
1115 result << out.str();
1117 this->_src += result.str();
1118 this->_kernels.emplace_back(kernel);
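// generate_fft_global emits one kernel per radix pass for rows that do not fit into local
// memory and for the strided y/z axes; getGlobalRadixInfo picks the per-pass radix split,
// each pass shuffles its R1xR2 blocks through sMem, and successive passes read and write
// the global in/out buffers.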
1122 vtb::opencl::Fourier_transform_base::generate_fft_global(
1129 int radixArr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
1130 int R1Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
1131 int R2Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
1134 int maxThreadsPerBlock = this->_maxworkgroupsize;
1135 int batchSize = this->min_mem_coalesce_width;
1136 int vertical = (axis == 0) ? 0 : 1;
1137 getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices);
1138 int numPasses = numRadices;
1140 int m = (int)log2(n);
1141 int Rinit = vertical ? BS : 1;
1142 batchSize = vertical ? std::min(BS, batchSize) : batchSize;
1144 for (
int passNum=0; passNum<numPasses; ++passNum) {
1146 radix = radixArr[passNum];
1147 R1 = R1Arr[passNum];
1148 R2 = R2Arr[passNum];
1149 int strideI = Rinit;
1150 for (
int i=0; i<numPasses; ++i) {
1152 strideI *= radixArr[i];
1155 int strideO = Rinit;
1156 for (
int i=0; i<passNum; ++i) {
1157 strideO *= radixArr[i];
1159 int threadsPerXForm = R2;
1160 batchSize = R2 == 1 ? this->_maxworkgroupsize : batchSize;
1161 batchSize = std::min(batchSize, strideI);
1162 int threadsPerBlock = batchSize * threadsPerXForm;
1163 threadsPerBlock = std::min(threadsPerBlock, maxThreadsPerBlock);
1164 batchSize = threadsPerBlock / threadsPerXForm;
1166 assert(R1*R2 == radix);
1167 assert(R1 <= this->_maxradix);
1168 assert(threadsPerBlock <= maxThreadsPerBlock);
1169 int numIter = R1 / R2;
1170 int gInInc = threadsPerBlock / batchSize;
1171 int lgStrideO = (int)log2(strideO);
1172 int numBlocksPerXForm = strideI / batchSize;
1173 int numBlocks = numBlocksPerXForm;
1177 numBlocks *= vertBS;
1179 Kernel_info kernel{};
1180 kernel.name = this->kernel_name(
"fft_global");
1182 kernel.lmem_size = 0;
1185 kernel.lmem_size = (radix + 1)*batchSize;
1187 kernel.lmem_size = threadsPerBlock*R1;
1190 kernel.num_workgroups = numBlocks;
1191 kernel.num_xforms_per_workgroup = 1;
1192 kernel.num_workitems_per_workgroup = threadsPerBlock;
1194 if((passNum == (numPasses - 1)) && (numPasses & 1)) {
1195 kernel.in_place_possible =
true;
1197 kernel.in_place_possible =
false;
1199 insertVariables(out, R1);
1201 out <<
"xNum = groupId >> " << ((int)log2(numBlocksPerXForm)) <<
";\n";
1202 out <<
"groupId = groupId & " << (numBlocksPerXForm - 1) <<
";\n";
1203 out <<
"indexIn = mad24(groupId, " << (batchSize) <<
", xNum << " << ((
int)log2(n*BS)) <<
");\n";
1204 out <<
"tid = mul24(groupId, " << (batchSize) <<
");\n";
1205 out <<
"i = tid >> " << (lgStrideO) <<
";\n";
1206 out <<
"j = tid & " << (strideO - 1) <<
";\n";
1207 int stride = radix*Rinit;
1208 for (
int i=0; i<passNum; ++i) {
1209 stride *= radixArr[i];
1211 out <<
"indexOut = mad24(i, " << (stride) <<
", j + " <<
"(xNum << " << ((
int) log2(n*BS)) <<
"));\n";
1212 out <<
"bNum = groupId;\n";
1214 int lgNumBlocksPerXForm = (int)log2(numBlocksPerXForm);
1215 out <<
"bNum = groupId & " << (numBlocksPerXForm - 1) <<
";\n";
1216 out <<
"xNum = groupId >> " << (lgNumBlocksPerXForm) <<
";\n";
1217 out <<
"indexIn = mul24(bNum, " << (batchSize) <<
");\n";
1218 out <<
"tid = indexIn;\n";
1219 out <<
"i = tid >> " << (lgStrideO) <<
";\n";
1220 out <<
"j = tid & " << (strideO - 1) <<
";\n";
1221 int stride = radix*Rinit;
1222 for (
int i=0; i<passNum; ++i) {
1223 stride *= radixArr[i];
1225 out <<
"indexOut = mad24(i, " << (stride) <<
", j);\n";
1226 out <<
"indexIn += (xNum << " << (m) <<
");\n";
1227 out <<
"indexOut += (xNum << " << (m) <<
");\n";
1230 int lgBatchSize = (int)log2(batchSize);
1231 out <<
"tid = lId;\n";
1232 out <<
"i = tid & " << (batchSize - 1) <<
";\n";
1233 out <<
"j = tid >> " << (lgBatchSize) <<
";\n";
1234 out <<
"indexIn += mad24(j, " << (strideI) <<
", i);\n";
1235 if (this->_format == Fourier_transform_format::Split_complex) {
1236 out <<
"in_real += indexIn;\n";
1237 out <<
"in_imag += indexIn;\n";
1238 for (
int j=0; j<R1; ++j)
1239 out <<
"a[" << (j) <<
"].x = in_real[" << (j*gInInc*strideI) <<
"];\n";
1240 for (
int j=0; j<R1; ++j)
1241 out <<
"a[" << (j) <<
"].y = in_imag[" << (j*gInInc*strideI) <<
"];\n";
1243 out <<
"in += indexIn;\n";
1244 for (
int j=0; j<R1; ++j) {
1245 out <<
"a[" << (j) <<
"] = in[" << (j*gInInc*strideI) <<
"];\n";
1248 out <<
"fftKernel" << (R1) <<
"(a, dir);\n";
1251 for (
int k = 1; k < R1; k++) {
1252 out <<
"ang = dir*(2.0f*M_PI*" << (k) <<
"/" << (radix) <<
")*j;\n";
1253 out <<
"w = (float2)(native_cos(ang), native_sin(ang));\n";
1254 out <<
"a[" << (k) <<
"] = complexMul(a[" << (k) <<
"], w);\n";
1258 out <<
"indexIn = mad24(j, " << (threadsPerBlock*numIter) <<
", i);\n";
1259 out <<
"lMemStore = sMem + tid;\n";
1260 out <<
"lMemLoad = sMem + indexIn;\n";
1261 for (
int k = 0; k < R1; k++) {
1262 out <<
"lMemStore[" << (k*threadsPerBlock) <<
"] = a[" << (k) <<
"].x;\n";
1264 out <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
1265 for(k = 0; k < numIter; k++)
1266 for(t = 0; t < R2; t++)
1267 out <<
"a[" << (k*R2+t) <<
"].x = lMemLoad[" << (t*batchSize + k*threadsPerBlock) <<
"];\n";
1268 out <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
1269 for(k = 0; k < R1; k++)
1270 out <<
"lMemStore[" << (k*threadsPerBlock) <<
"] = a[" << (k) <<
"].y;\n";
1271 out <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
1272 for(k = 0; k < numIter; k++)
1273 for(t = 0; t < R2; t++)
1274 out <<
"a[" << (k*R2+t) <<
"].y = lMemLoad[" << (t*batchSize + k*threadsPerBlock) <<
"];\n";
1275 out <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
1276 for(
int j = 0; j < numIter; j++)
1277 out <<
"fftKernel" << (R2) <<
"(a + " << (j*R2) <<
", dir);\n";
1280 if (passNum < (numPasses - 1)) {
1281 out <<
"l = ((bNum << " << (lgBatchSize) <<
") + i) >> " << (lgStrideO) <<
";\n";
1282 out <<
"k = j << " << ((int)log2(R1/R2)) <<
";\n";
1283 out <<
"ang1 = dir*(2.0f*M_PI/" << (N) <<
")*l;\n";
1284 for(t = 0; t < R1; t++)
1286 out <<
"ang = ang1*(k + " << ((t%R2)*R1 + (t/R2)) <<
");\n";
1287 out <<
"w = (float2)(native_cos(ang), native_sin(ang));\n";
1288 out <<
"a[" << (t) <<
"] = complexMul(a[" << (t) <<
"], w);\n";
1293 out <<
"lMemStore = sMem + mad24(i, " << (radix + 1) <<
", j << " << ((
int)log2(R1/R2)) <<
");\n";
1294 out <<
"lMemLoad = sMem + mad24(tid >> " << ((int)log2(radix)) <<
", " << (radix+1) <<
", tid & " << (radix-1) <<
");\n";
1295 for(
int i = 0; i < R1/R2; i++)
1296 for(
int j = 0; j < R2; j++)
1297 out <<
"lMemStore[ " << (i + j*R1) <<
"] = a[" << (i*R2+j) <<
"].x;\n";
1298 out <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
1299 if(threadsPerBlock >= radix)
1301 for(
int i = 0; i < R1; i++)
1302 out <<
"a[" << (i) <<
"].x = lMemLoad[" << (i*(radix+1)*(threadsPerBlock/radix)) <<
"];\n";
1306 int innerIter = radix/threadsPerBlock;
1307 int outerIter = R1/innerIter;
1308 for(
int i = 0; i < outerIter; i++)
1309 for(
int j = 0; j < innerIter; j++)
1310 out <<
"a[" << (i*innerIter+j) <<
"].x = lMemLoad[" << (j*threadsPerBlock + i*(radix+1)) <<
"];\n";
1312 out <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
1313 for (
int i = 0; i < R1/R2; i++)
1314 for(
int j = 0; j < R2; j++)
1315 out <<
"lMemStore[ " << (i + j*R1) <<
"] = a[" << (i*R2+j) <<
"].y;\n";
1316 out <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
1317 if (threadsPerBlock >= radix) {
1318 for(
int i = 0; i < R1; i++)
1319 out <<
"a[" << (i) <<
"].y = lMemLoad[" << (i*(radix+1)*(threadsPerBlock/radix)) <<
"];\n";
1323 int innerIter = radix/threadsPerBlock;
1324 int outerIter = R1/innerIter;
1325 for(
int i = 0; i < outerIter; i++)
1326 for (
int j = 0; j < innerIter; j++)
1327 out <<
"a[" << (i*innerIter+j) <<
"].y = lMemLoad[" << (j*threadsPerBlock + i*(radix+1)) <<
"];\n";
1329 out <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
1330 out <<
"indexOut += tid;\n";
1331 if(this->_format == Fourier_transform_format::Split_complex) {
1332 out <<
"out_real += indexOut;\n";
1333 out <<
"out_imag += indexOut;\n";
1334 for(k = 0; k < R1; k++)
1335 out <<
"out_real[" << (k*threadsPerBlock) <<
"] = a[" << (k) <<
"].x;\n";
1336 for(k = 0; k < R1; k++)
1337 out <<
"out_imag[" << (k*threadsPerBlock) <<
"] = a[" << (k) <<
"].y;\n";
1340 out <<
"out += indexOut;\n";
1341 for(k = 0; k < R1; k++)
1342 out <<
"out[" << (k*threadsPerBlock) <<
"] = a[" << (k) <<
"];\n";
1345 out <<
"indexOut += mad24(j, " << (numIter*strideO) <<
", i);\n";
1346 if(this->_format == Fourier_transform_format::Split_complex) {
1347 out <<
"out_real += indexOut;\n";
1348 out <<
"out_imag += indexOut;\n";
1349 for(k = 0; k < R1; k++)
1350 out <<
"out_real[" << (((k%R2)*R1 + (k/R2))*strideO) <<
"] = a[" << (k) <<
"].x;\n";
1351 for(k = 0; k < R1; k++)
1352 out <<
"out_imag[" << (((k%R2)*R1 + (k/R2))*strideO) <<
"] = a[" << (k) <<
"].y;\n";
1355 out <<
"out += indexOut;\n";
1356 for(k = 0; k < R1; k++)
1357 out <<
"out[" << (((k%R2)*R1 + (k/R2))*strideO) <<
"] = a[" << (k) <<
"];\n";
1361 insertHeader(result, kernel.name, this->_format);
1363 if (kernel.lmem_size) {
1364 result <<
" __local float sMem[" << (kernel.lmem_size) <<
"];\n";
1366 result << out.str();
1368 this->_src += result.str();
1370 this->_kernels.emplace_back(kernel);
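// getKernelWorkDimensions converts a kernel's per-work-group geometry and the requested
// batch size into the global and local work sizes that are passed to the OpenCL pipeline.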
void
vtb::opencl::Fourier_transform_base::getKernelWorkDimensions(
    const Kernel_info& kernel,
    int* batchSize,
    size_t* gWorkItems,
    size_t* lWorkItems
) {
    *lWorkItems = kernel.num_workitems_per_workgroup;
    int numWorkGroups = kernel.num_workgroups;
    int numXFormsPerWG = kernel.num_xforms_per_workgroup;
    int ny = this->_shape(1);
    int nz = this->_shape(2);
    if (kernel.axis == 0) {
        *batchSize *= ny*nz;
        numWorkGroups = (*batchSize % numXFormsPerWG)
            ? (*batchSize/numXFormsPerWG + 1)
            : (*batchSize/numXFormsPerWG);
        numWorkGroups *= kernel.num_workgroups;
    }
    if (kernel.axis == 1) {
        *batchSize *= nz;
        numWorkGroups *= *batchSize;
    }
    if (kernel.axis == 2) {
        numWorkGroups *= *batchSize;
    }
    *gWorkItems = numWorkGroups * *lWorkItems;
}
    int ni = max_power(0);
    int nj = max_power(1);
    auto t0 = clock::now();
    int max_count = product(max_power);
    #if defined(VTB_WITH_OPENMP)
    #pragma omp parallel for collapse(2) schedule(dynamic,1)
    #endif
    for (int i=1; i<=ni; ++i) {
        for (int j=1; j<=nj; ++j) {
            fft.context(context);
            fft.shape({1<<i, 1<<j, 1});
            if (clock::now()-t0 > seconds(1)) { slow = true; }
            if (slow && cnt%10 == 0 && cnt >= 10) {
                std::fprintf(stderr, "%5d/%-5d compile fft\n", cnt, max_count);
void
vtb::opencl::Fourier_transform_base::dump(std::ostream& out) {
    size_t global = 0, local = 0;
    for (const auto& kernel : this->_kernels) {
        cl_int batch_size = 1;
        getKernelWorkDimensions(kernel, &batch_size, &global, &local);
        out << std::setw(20) << kernel.name
            << std::setw(20) << global
            << std::setw(20) << local
            << '\n';
    }
}
vtb::opencl::Fourier_transform_base::Fourier_transform_base(const int3& shape):
void
vtb::opencl::Fourier_transform_base::allocate_temporary_buffer(int batch_size) {
    if (this->last_batch_size != batch_size) {
        this->last_batch_size = batch_size;
        size_t n = this->buffer_size(batch_size);
        this->_workarea = context()->context().buffer(clx::memory_flags::read_write, n);
    }
}
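// enqueue runs the generated kernels for one transform: leading passes that can execute in
// place run directly on x; the remaining passes ping-pong between x and the temporary work
// area, with an initial copy when their count is odd so that the final result lands in x.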
void
vtb::opencl::Fourier_transform_base::enqueue(clx::buffer x, int direction, int batch_size) {
    auto& ppl = context()->pipeline();
    clx::buffer buffer_in = x, buffer_out(nullptr);
    Kernel_info* first = this->_kernels.data();
    Kernel_info* last = this->_kernels.data() + this->_kernels.size();
    while (first != last && first->in_place_possible) {
        auto& k = *first;
        int new_batch_size = batch_size;
        size_t local = 0, global = 0;
        this->getKernelWorkDimensions(k, &new_batch_size, &global, &local);
        k.kernel.argument(0, buffer_in);
        k.kernel.argument(1, buffer_in);
        k.kernel.argument(2, direction);
        k.kernel.argument(3, new_batch_size);
        ppl.kernel(k.kernel, clx::range{global}, clx::range{local});
        ++first;
    }
    if (first != last) {
        this->allocate_temporary_buffer(batch_size);
        buffer_out = this->_workarea;
        auto remaining = last - first;
        if (remaining%2 != 0) {
            ppl.copy(buffer_in, buffer_out);
            std::swap(buffer_in, buffer_out);
        }
    }
    while (first != last) {
        auto& k = *first;
        int new_batch_size = batch_size;
        size_t local = 0, global = 0;
        this->getKernelWorkDimensions(k, &new_batch_size, &global, &local);
        k.kernel.argument(0, buffer_in);
        k.kernel.argument(1, buffer_out);
        k.kernel.argument(2, direction);
        k.kernel.argument(3, new_batch_size);
        ppl.kernel(k.kernel, clx::range{global}, clx::range{local});
        std::swap(buffer_in, buffer_out);
        ++first;
    }
}
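// init validates that the shape is a power of two, determines the number of transform
// dimensions, generates and compiles the kernel source, matches the compiled kernels back
// to their Kernel_info entries, and lowers _maxworkgroupsize to the smallest work-group
// size the device reports for them.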
1513 vtb::opencl::Fourier_transform_base::init() {
1514 if (!blitz::is_power_of_two(blitz::product(this->_shape))) {
1518 while (dim < 3 && this->_shape(dim) != 1) {
1521 if (dim < 0 || dim > 2) {
1524 this->_ndimensions = dim;
1525 bool success =
false;
1526 clx::compiler cc = context()->compiler_copy();
1527 cc.options(cc.options() +
" -cl-mad-enable");
1528 auto device = cc.devices().front();
1529 this->_maxworkgroupsize = device.max_work_group_size();
1531 this->generate_source_code();
1532 clx::program prog = cc.compile(
"fft.cl", this->_src);
1533 auto all_kernels = prog.kernels();
1534 for (
auto& kernel : all_kernels) {
1535 auto name = kernel.name();
1536 for (
auto& k : this->_kernels) {
1537 if (name == k.name) { k.kernel = kernel;
break; }
1542 for (
const auto& k : this->_kernels) {
1543 auto wg = k.kernel.work_group(device);
1544 if (wg.size <
size_t(k.num_workitems_per_workgroup)) { success =
false; }
1545 if (wg.size < min_wg_size) { min_wg_size = wg.size; }
1547 this->_maxworkgroupsize = min_wg_size;
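// Chirp_Z_transform_base implements Bluestein's algorithm: a transform of arbitrary length
// is expressed as a convolution with a chirp sequence, and that convolution is evaluated
// with the power-of-two FFTs generated above.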
void
vtb::opencl::Chirp_Z_transform_base::context(Context* rhs) {
    this->_fft.context(rhs);
    auto prog = context()->compiler().compile("chirp_z_transform.cl");
    _makechirp = prog.kernel("make_chirp");
    _reciprocal_chirp = prog.kernel("reciprocal_chirp");
    _mult1 = prog.kernel("multiply_1");
    _mult2 = prog.kernel("multiply_2");
    _mult3 = prog.kernel("multiply_3");
    _zero_init = prog.kernel("zero_init");
}
void
vtb::opencl::Chirp_Z_transform_base::make_chirp(
    const int3& shape,
    const int3& fft_shape
) {
    using blitz::product;
    auto& ppl = context()->pipeline();
    int3 chirp_shape{shape*2-1};
    ppl.allocate(product(chirp_shape), this->_chirp);
    ppl.allocate(product(fft_shape), this->_xp);
    ppl.allocate(product(fft_shape), this->_ichirp);
    auto& kernel = this->_makechirp;
    kernel.arguments(this->_chirp, shape(0), shape(1), shape(2));
    ppl.kernel(kernel, chirp_shape);
    VTB_DUMP(this->_chirp, chirp_shape, "chirp");
}
1582 vtb::opencl::Chirp_Z_transform_base::enqueue(
1587 using blitz::product;
1588 if (batch_size != 1) {
1591 auto& ppl = this->_fft.context()->pipeline();
1598 clx::kernel zero = _zero_init;
1599 zero.argument(0, _xp);
1600 _mult1.argument(0, x);
1601 _mult1.argument(1, _chirp);
1602 _mult1.argument(2, direction);
1603 _mult1.argument(3, _fft.shape()(0));
1604 _mult1.argument(4, _fft.shape()(1));
1605 _mult1.argument(5, _fft.shape()(2));
1606 _mult1.argument(6, _xp);
1607 ppl.kernel(zero, _fft.shape());
1609 ppl.kernel(_mult1, _shape);
1611 VTB_DUMP(this->_xp, this->_fft.shape(),
"xp");
1612 _fft.enqueue(_xp, direction, batch_size);
1613 VTB_DUMP(this->_xp, this->_fft.shape(),
"fft(xp)");
1615 int3 chirp_shape{_shape*2-1};
1618 clx::kernel zero = _zero_init;
1619 zero.argument(0, _ichirp);
1620 _reciprocal_chirp.argument(0, _chirp);
1621 _reciprocal_chirp.argument(1, direction);
1622 _reciprocal_chirp.argument(2, _fft.shape()(0));
1623 _reciprocal_chirp.argument(3, _fft.shape()(1));
1624 _reciprocal_chirp.argument(4, _fft.shape()(2));
1625 _reciprocal_chirp.argument(5, _ichirp);
1626 ppl.kernel(zero, _fft.shape());
1628 ppl.kernel(_reciprocal_chirp, chirp_shape);
1630 VTB_DUMP(this->_ichirp, this->_fft.shape(),
"ichirp");
1631 _fft.enqueue(_ichirp, direction, batch_size);
1632 VTB_DUMP(this->_ichirp, this->_fft.shape(),
"fft(ichirp)");
1634 _mult2.argument(0, _ichirp);
1635 _mult2.argument(1, _xp);
1637 ppl.kernel(_mult2, _fft.shape());
1638 VTB_DUMP(this->_ichirp, this->_fft.shape(),
"mult2");
1640 _fft.enqueue(_ichirp, -direction, batch_size);
1641 VTB_DUMP(this->_ichirp, this->_fft.shape(),
"ifft");
1643 _mult3.argument(0, _ichirp);
1644 _mult3.argument(1, _chirp);
1645 _mult3.argument(2, x);
1646 _mult3.argument(3, _fft.shape()(0));
1647 _mult3.argument(4, _fft.shape()(1));
1648 _mult3.argument(5, _fft.shape()(2));
1649 _mult3.argument(6, 1.f / product(_fft.shape()));
1650 ppl.kernel(_mult3, _shape);