Tinker9 70bd052 (Thu Nov 9 12:11:35 2023 -0800)
Loading...
Searching...
No Matches
reduce.h
1#pragma once
2#include "tool/macro.h"
3
4namespace tinker {
template <class T>
struct OpPlus
{
   // Identity element of summation: the literal 0 converted to T.
   __device__
   T init() const { return 0; }

   // Combines two partial results by addition.
   __device__
   T operator()(T lhs, T rhs) const { return lhs + rhs; }

   // In-place accumulation (`acc += v`). The target is volatile-qualified so
   // the warp-level helpers can update elements reached through a
   // `volatile T*` without casting the qualifier away.
   __device__
   void x(volatile T& acc, T v) const { acc += v; }
};
26
// Reduction operator for logical OR over values of type T: the identity is
// `false`, operator() combines two values with `||`, and x() updates a value
// in place with `|=`.
// NOTE(review): the line between `template <class T>` and `{` (upstream
// line 28, which should declare this struct's name) is missing from this
// extracted listing; restore it from the upstream header -- the struct cannot
// be referenced or compiled without it.
27template <class T>
29{
// Identity element: `false`, converted to T.
30 __device__
31 T init() const
32 {
33 return false;
34 }
35
// Combines two partial results with logical OR; the result is 0 or 1.
36 __device__
37 T operator()(T a, T b) const
38 {
39 return a || b;
40 }
41
// In-place update of a (possibly volatile) target with bitwise OR.
// NOTE(review): operator() above uses logical `||` (always 0 or 1) while
// this uses bitwise `|=`; for integer T with inputs other than 0/1 the two
// can produce different values, although truthiness agrees. Consider
// `a = a || b;` if a canonical 0/1 result is intended.
42 __device__
43 void x(volatile T& a, T b) const
44 {
45 a |= b;
46 }
47};
48
// Final, single-warp stages of a tree reduction over the B entries of `sd`.
// Must be called by lanes t = 0..31 of one warp (by every thread of the block
// when B < 32). On return sd[0] holds op-reduction of sd[0..B-1]; the other
// entries hold intermediate values. B is assumed to be a power of two.
//
// Op must provide `void x(volatile T&, T) const`, which folds its second
// argument into the first in place.
template <class T, unsigned int B, class Op>
__device__
inline void warp_reduce(volatile T* sd, unsigned int t, Op op)
{
   // Stage `off` folds sd[t + off] into sd[t] and exists only when
   // B >= 2 * off. B is a compile-time constant, so unused stages fold away
   // once the loop is unrolled.
   //
   // Each stage reads the partner into a register, synchronizes the warp,
   // writes, and synchronizes again. This keeps a lane from reading a value
   // that another lane already overwrote in the same stage, which implicit
   // warp lockstep no longer guarantees (independent thread scheduling).
   #pragma unroll
   for (unsigned int off = 32; off > 0; off >>= 1) {
      if (B >= 2 * off) {
         // Lanes whose partner lies outside sd[0..B-1] skip the read and the
         // update but still reach both barriers. For B >= 64 every lane is in
         // range. For B == 32 this stops lanes 16..31 from reading past the
         // end of sd in the off == 16 stage. Skipped lanes never feed sd[0].
         const bool in_range = (t + off < B);
         T var;
         if (in_range)
            var = sd[t + off];
         __syncwarp();
         if (in_range)
            op.x(sd[t], var);
         __syncwarp();
      }
   }
}
73
// Column-wise variant of warp_reduce: performs the final single-warp stages
// of a tree reduction independently for each of the HN rows sd[j][0..B-1].
// Must be called by lanes t = 0..31 of one warp (by every thread of the block
// when B < 32). On return sd[j][0] holds the reduction of row j. B is assumed
// to be a power of two.
//
// Op must provide `void x(volatile T&, T) const` (in-place fold).
template <class T, unsigned int HN, unsigned int B, class Op>
__device__
inline void warp_reduce2(volatile T (*sd)[B], unsigned int t, Op op)
{
   // Stage `off` exists only when B >= 2 * off and folds sd[j][t + off] into
   // sd[j][t] for every row j, with read / __syncwarp / write / __syncwarp
   // ordering so no lane reads a value overwritten in the same step.
   #pragma unroll
   for (unsigned int off = 32; off > 0; off >>= 1) {
      if (B >= 2 * off) {
         // Out-of-range partners (possible only when B < 64) are skipped, but
         // the lane still reaches every barrier. This avoids reading past the
         // end of each row. Skipped lanes never feed sd[j][0].
         const bool in_range = (t + off < B);
         #pragma unroll
         for (unsigned int j = 0; j < HN; ++j) {
            T var;
            if (in_range)
               var = sd[j][t + off];
            __syncwarp();
            if (in_range)
               op.x(sd[j][t], var);
            __syncwarp();
         }
      }
   }
}
98
// Block-level reduction kernel: reduces g_idata[0..n-1] with `op` and writes
// one partial result per block to g_odata[blockIdx.x]. g_odata therefore
// needs at least gridDim.x elements; combining those per-block partials is
// left to the caller.
//
// Preconditions (not checked here):
//   - blockDim.x == B (the shared array and the loop stride are sized by B);
//   - B is a power of two no larger than 1024.
// Op must provide init() (identity element), operator()(T, T), and the
// in-place x(volatile T&, T) used by warp_reduce.
template <class T, unsigned int B, class Op>
__global__
void reduce(T* g_odata, const T* g_idata, size_t n, Op op = Op())
{
   __shared__ T sd[B];
   const unsigned int t = threadIdx.x;

   // Each thread folds a strided subset of g_idata into a register (a
   // grid-stride loop, given blockDim.x == B). size_t indexing avoids int
   // overflow when n exceeds INT_MAX.
   T acc = op.init();
   const size_t stride = static_cast<size_t>(B) * gridDim.x;
   for (size_t i = t + static_cast<size_t>(blockIdx.x) * B; i < n; i += stride)
      acc = op(acc, g_idata[i]);
   sd[t] = acc;
   __syncthreads();

   // Tree reduction in shared memory. B is a compile-time constant, so each
   // guard is uniform across the block and the __syncthreads() calls are never
   // divergent. The B >= 1024 stage is needed so that sd[512..1023] is not
   // dropped when B == 1024.
   // clang-format off
   if (B >= 1024) { if (t < 512) { sd[t] = op(sd[t], sd[t + 512]); } __syncthreads(); }
   if (B >= 512)  { if (t < 256) { sd[t] = op(sd[t], sd[t + 256]); } __syncthreads(); }
   if (B >= 256)  { if (t < 128) { sd[t] = op(sd[t], sd[t + 128]); } __syncthreads(); }
   if (B >= 128)  { if (t < 64 ) { sd[t] = op(sd[t], sd[t + 64 ]); } __syncthreads(); }
   if (t < 32) warp_reduce<T, B, Op>(sd, t, op);
   // clang-format on

   // Thread 0 took part in warp_reduce, so it sees the final sd[0].
   if (t == 0)
      g_odata[blockIdx.x] = sd[0];
}
119
// Column-wise block-level reduction kernel. For each column j in [0, HN),
// reduces g_idata[0..n-1][j] with `op` and writes one partial result per
// block to g_odata[blockIdx.x][j]. Each input row has N entries, of which
// only the first HN are read. g_odata needs at least gridDim.x rows;
// combining the per-block partials is left to the caller.
//
// Preconditions (not checked here):
//   - blockDim.x == B (the shared array and the loop stride are sized by B);
//   - B is a power of two no larger than 1024.
// Op must provide init() (identity element), operator()(T, T), and the
// in-place x(volatile T&, T) used by warp_reduce2.
template <class T, unsigned int B, unsigned int HN, size_t N, class Op>
__global__
void reduce2(T (*g_odata)[HN], const T (*g_idata)[N], size_t n, Op op = Op())
{
   __shared__ T sd[HN][B];
   const unsigned int t = threadIdx.x;

   // Per-thread partials start from op.init(), matching reduce(). The
   // original seeded them with a literal 0, which is only correct for
   // operators whose identity is 0.
   T acc[HN];
   #pragma unroll
   for (unsigned int j = 0; j < HN; ++j)
      acc[j] = op.init();

   // Grid-stride loop (given blockDim.x == B) with size_t indexing, so large
   // n cannot overflow the index.
   const size_t stride = static_cast<size_t>(B) * gridDim.x;
   for (size_t i = t + static_cast<size_t>(blockIdx.x) * B; i < n; i += stride) {
      #pragma unroll
      for (unsigned int j = 0; j < HN; ++j)
         acc[j] = op(acc[j], g_idata[i][j]);
   }

   #pragma unroll
   for (unsigned int j = 0; j < HN; ++j)
      sd[j][t] = acc[j];
   __syncthreads();

   // Tree reduction per column. The guards depend only on the template
   // constant B, so the barriers are uniform across the block. The
   // B >= 1024 stage keeps sd[j][512..1023] from being dropped when
   // B == 1024.
   // clang-format off
   if (B >= 1024) { if (t < 512) { _Pragma("unroll") for (unsigned int j = 0; j < HN; ++j) sd[j][t] = op(sd[j][t], sd[j][t + 512]); } __syncthreads(); }
   if (B >= 512)  { if (t < 256) { _Pragma("unroll") for (unsigned int j = 0; j < HN; ++j) sd[j][t] = op(sd[j][t], sd[j][t + 256]); } __syncthreads(); }
   if (B >= 256)  { if (t < 128) { _Pragma("unroll") for (unsigned int j = 0; j < HN; ++j) sd[j][t] = op(sd[j][t], sd[j][t + 128]); } __syncthreads(); }
   if (B >= 128)  { if (t < 64 ) { _Pragma("unroll") for (unsigned int j = 0; j < HN; ++j) sd[j][t] = op(sd[j][t], sd[j][t + 64 ]); } __syncthreads(); }
   if (t < 32) warp_reduce2<T, HN, B, Op>(sd, t, op);
   // clang-format on

   // Braces make the pragma's target loop unambiguous.
   if (t == 0) {
      #pragma unroll
      for (unsigned int j = 0; j < HN; ++j)
         g_odata[blockIdx.x][j] = sd[j][0];
   }
}
147}
int n
Number of atoms padded by WARP_SIZE.
Definition: testrt.h:9
__global__ void reduce(T *g_odata, const T *g_idata, size_t n, Op op=Op())
Definition: reduce.h:101
__device__ void warp_reduce2(volatile T(*sd)[B], unsigned int t, Op op)
Definition: reduce.h:76
__global__ void reduce2(T(*g_odata)[HN], const T(*g_idata)[N], size_t n, Op op=Op())
Definition: reduce.h:122
__device__ void warp_reduce(volatile T *sd, unsigned int t, Op op)
Definition: reduce.h:51
Definition: reduce.h:29
__device__ void x(volatile T &a, T b) const
Definition: reduce.h:43
__device__ T operator()(T a, T b) const
Definition: reduce.h:37
__device__ T init() const
Definition: reduce.h:31
Definition: reduce.h:7
__device__ T init() const
Definition: reduce.h:9
__device__ T operator()(T a, T b) const
Definition: reduce.h:15
__device__ void x(volatile T &a, T b) const
Definition: reduce.h:21