Tinker9 70bd052 (Thu Nov 9 12:11:35 2023 -0800)
Loading...
Searching...
No Matches
darray.h
1#pragma once
2#include "math/const.h"
3#include "math/libfunc.h"
4#include "math/parallel.h"
5#include "tool/accasync.h"
6#include "tool/ptrtrait.h"
7
8#include <vector>
9
10namespace tinker {
/// Similar to OpenACC wait and CUDA stream synchronize.
/// \param queue  Identifier of the queue to wait on.
void waitFor(int queue);
15
19 const void* src,
20 size_t nbytes,
21 int queue
22);
23
27 const void* src,
28 size_t nbytes,
29 int queue
30);
31
36 const void* src,
37 size_t nbytes,
38 int queue
39);
40
44 size_t nbytes,
45 int queue
46);
47
51);
52
/// Allocates device pointer.
/// \param pptr    Address of the pointer variable handed to the allocator.
/// \param nbytes  Requested size in bytes.
void deviceMemoryAllocateBytes(void** pptr, size_t nbytes);
58}
59
60namespace tinker {
61inline namespace v1 {
/// Sanity check: compile-time guard on the element types used by the device
/// memory copy helpers.
///
/// Instantiating this function for a type `T` that is not an enum, an
/// integral type, a floating-point type, or a trivial type is a compile
/// error. The function has no runtime effect.
template <class T>
void deviceMemoryCheckType()
{
   static_assert(std::is_enum<T>::value || std::is_integral<T>::value
         || std::is_floating_point<T>::value || std::is_trivial<T>::value,
      "deviceMemoryCheckType: T must be an enum, integral, floating-point, or trivial type");
}
69
72template <class DT, class ST>
74 const ST* src,
75 size_t nelem,
76 int q
77)
78{
79 deviceMemoryCheckType<DT>();
80 deviceMemoryCheckType<ST>();
81 constexpr size_t ds = sizeof(DT); // device type
82 constexpr size_t ss = sizeof(ST); // host type
83
84 size_t size = ds * nelem;
85 if (ds == ss) {
86 deviceMemoryCopyinBytesAsync(dst, src, size, q);
87 } else {
88 std::vector<DT> buf(nelem);
89 for (size_t i = 0; i < nelem; ++i)
90 buf[i] = src[i];
91 deviceMemoryCopyinBytesAsync(dst, buf.data(), size, q);
92 waitFor(q);
93 }
94}
95
98template <class DT, class ST>
100 const ST* src,
101 size_t nelem,
102 int q
103)
104{
105 deviceMemoryCheckType<DT>();
106 deviceMemoryCheckType<ST>();
107 constexpr size_t ds = sizeof(DT); // host type
108 constexpr size_t ss = sizeof(ST); // device type
109
110 size_t size = ss * nelem;
111 if (ds == ss) {
112 deviceMemoryCopyoutBytesAsync(dst, src, size, q);
113 } else {
114 std::vector<ST> buf(nelem);
115 deviceMemoryCopyoutBytesAsync(buf.data(), src, size, q);
116 waitFor(q);
117 for (size_t i = 0; i < nelem; ++i)
118 dst[i] = buf[i];
119 }
120}
121}
122}
123
124namespace tinker {
128{
129private:
130 template <class PTR>
131 static typename PtrTrait<PTR>::type* flatten(PTR p)
132 {
133 typedef typename PtrTrait<PTR>::type T;
134 return reinterpret_cast<T*>(p);
135 }
136
137public:
138 template <class PTR>
139 static void allocate(size_t nelem, PTR* pp)
140 {
141 typedef typename PtrTrait<PTR>::type T;
142 constexpr size_t N = PtrTrait<PTR>::n;
143 deviceMemoryAllocateBytes(reinterpret_cast<void**>(pp), sizeof(T) * nelem * N);
144 }
145
/// Allocates several arrays of `nelem` elements each: handles `pp` first and
/// then forwards the remaining arguments `pps...` to `allocate` again.
template <class PTR, class... PTRS>
static void allocate(size_t nelem, PTR* pp, PTRS... pps)
{
   // First pointer in the list.
   allocate(nelem, pp);
   // Everything after it.
   allocate(nelem, pps...);
}
152
/// Releases the array behind `p` by handing its flattened pointer to
/// `deviceMemoryDeallocate`.
template <class PTR>
static void deallocate(PTR p)
{
   auto* flat = flatten(p);
   deviceMemoryDeallocate(flat);
}
158
/// Releases several arrays: handles `p` first and then forwards the remaining
/// pointers `ps...` to `deallocate` again.
template <class PTR, class... PTRS>
static void deallocate(PTR p, PTRS... ps)
{
   // First pointer in the list.
   deallocate(p);
   // Everything after it.
   deallocate(ps...);
}
165
166 template <class PTR>
167 static void zero(int q, size_t nelem, PTR p)
168 {
169 typedef typename PtrTrait<PTR>::type T;
170 constexpr size_t N = PtrTrait<PTR>::n;
171 deviceMemoryZeroBytesAsync(flatten(p), sizeof(T) * nelem * N, q);
172 }
173
/// Zeroes several arrays of `nelem` elements each on queue `q`: handles `p`
/// first and then forwards the remaining pointers `ps...` to `zero` again.
template <class PTR, class... PTRS>
static void zero(int q, size_t nelem, PTR p, PTRS... ps)
{
   // First pointer in the list.
   zero(q, nelem, p);
   // Everything after it.
   zero(q, nelem, ps...);
}
180
181 template <class PTR, class U>
182 static void copyin(int q, size_t nelem, PTR dst, const U* src)
183 {
184 constexpr size_t N = PtrTrait<PTR>::n;
185 deviceMemoryCopyin1dArray(flatten(dst), flatten(src), nelem * N, q);
186 }
187
188 template <class U, class PTR>
189 static void copyout(int q, size_t nelem, U* dst, const PTR src)
190 {
191 constexpr size_t N = PtrTrait<PTR>::n;
192 deviceMemoryCopyout1dArray(flatten(dst), flatten(src), nelem * N, q);
193 }
194
196 template <class PTR, class U>
197 static void copy(int q, size_t nelem, PTR dst, const U* src)
198 {
199 constexpr size_t N = PtrTrait<PTR>::n;
200 using DT = typename PtrTrait<PTR>::type;
201 using ST = typename PtrTrait<U*>::type;
202 static_assert(std::is_same<DT, ST>::value, "");
203 size_t size = N * sizeof(ST) * nelem;
204 deviceMemoryCopyBytesAsync(flatten(dst), flatten(src), size, q);
205 }
206
208 template <class PTR, class PTR2>
209 static typename PtrTrait<PTR>::type dotThenReturn(int q, size_t nelem, const PTR ptr, const PTR2 b)
210 {
211 typedef typename PtrTrait<PTR>::type T;
212 constexpr size_t N = PtrTrait<PTR>::n;
213 typedef typename PtrTrait<PTR2>::type T2;
214 static_assert(std::is_same<T, T2>::value, "");
215 return dotProd(flatten(ptr), flatten(b), nelem * N, q);
216 }
217
219 template <class ANS, class PTR, class PTR2>
220 static void dot(int q, size_t nelem, ANS ans, const PTR ptr, const PTR2 ptr2)
221 {
222 typedef typename PtrTrait<PTR>::type T;
223 constexpr size_t N = PtrTrait<PTR>::n;
224 typedef typename PtrTrait<PTR2>::type T2;
225 static_assert(std::is_same<T, T2>::value, "");
226 typedef typename PtrTrait<ANS>::type TA;
227 static_assert(std::is_same<T, TA>::value, "");
228 dotProd(ans, flatten(ptr), flatten(ptr2), nelem * N, q);
229 }
230
231 template <class FLT, class PTR>
232 static void scale(int q, size_t nelem, FLT scal, PTR ptr)
233 {
234 constexpr size_t N = PtrTrait<PTR>::n;
235 scaleArray(flatten(ptr), scal, nelem * N, q);
236 }
237
/// Scales several arrays of `nelem` elements each by `scal` on queue `q`:
/// handles `ptr` first and then forwards the remaining pointers `ptrs...` to
/// `scale` again.
template <class FLT, class PTR, class... PTRS>
static void scale(int q, size_t nelem, FLT scal, PTR ptr, PTRS... ptrs)
{
   // First pointer in the list.
   scale(q, nelem, scal, ptr);
   // Everything after it.
   scale(q, nelem, scal, ptrs...);
}
244};
245}
Definition: ptrtrait.h:19
T dotProd(const T *a, const T *b, size_t nelem, int queue)
Dot product of two linear arrays.
Definition: parallel.h:82
void scaleArray(T *dst, T scal, size_t nelem, int queue)
Multiply all of the elements in an 1D array by a scalar.
Definition: parallel.h:102
static void allocate(size_t nelem, PTR *pp, PTRS... pps)
Definition: darray.h:147
static void deallocate(PTR p)
Definition: darray.h:154
static void copyout(int q, size_t nelem, U *dst, const PTR src)
Definition: darray.h:189
static void scale(int q, size_t nelem, FLT scal, PTR ptr)
Definition: darray.h:232
static void zero(int q, size_t nelem, PTR p)
Definition: darray.h:167
static void dot(int q, size_t nelem, ANS ans, const PTR ptr, const PTR2 ptr2)
Calculates the dot product and saves the answer to pointer ans.
Definition: darray.h:220
static void scale(int q, size_t nelem, FLT scal, PTR ptr, PTRS... ptrs)
Definition: darray.h:239
static void copy(int q, size_t nelem, PTR dst, const U *src)
Copies data across two device memory pointers.
Definition: darray.h:197
static void zero(int q, size_t nelem, PTR p, PTRS... ps)
Definition: darray.h:175
static void copyin(int q, size_t nelem, PTR dst, const U *src)
Definition: darray.h:182
static void deallocate(PTR p, PTRS... ps)
Definition: darray.h:160
static PtrTrait< PTR >::type dotThenReturn(int q, size_t nelem, const PTR ptr, const PTR2 b)
Calculates the dot product and returns the answer to the host.
Definition: darray.h:209
static void allocate(size_t nelem, PTR *pp)
Definition: darray.h:139
Device array.
Definition: darray.h:128
void deviceMemoryAllocateBytes(void **pptr, size_t nbytes)
Allocates device pointer.
void deviceMemoryCopyinBytesAsync(void *dst, const void *src, size_t nbytes, int queue)
Similar to OpenACC async copyin, copies data from host to device.
void deviceMemoryZeroBytesAsync(void *dst, size_t nbytes, int queue)
Writes zero bytes on device.
void deviceMemoryCheckType()
Sanity check.
Definition: darray.h:65
void waitFor(int queue)
Similar to OpenACC wait and CUDA stream synchronize.
void deviceMemoryDeallocate(void *ptr)
Deallocates device pointer.
void deviceMemoryCopyoutBytesAsync(void *dst, const void *src, size_t nbytes, int queue)
Similar to OpenACC async copyout, copies data from device to host.
void deviceMemoryCopyBytesAsync(void *dst, const void *src, size_t nbytes, int queue)
Copies data between two pointers on device.
void deviceMemoryCopyin1dArray(DT *dst, const ST *src, size_t nelem, int q)
Copies data to 1D array, host to device.
Definition: darray.h:73
void deviceMemoryCopyout1dArray(DT *dst, const ST *src, size_t nelem, int q)
Copies data to 1D array, device to host.
Definition: darray.h:99
Definition: testrt.h:9