GPUMLib  0.2.2
GPU Machine Learning Library
NMFkernels.cu
/*
    Noel Lopes is an Assistant Professor at the Polytechnic Institute of Guarda, Portugal
    Copyright (C) 2009, 2010, 2011, 2012 Noel de Jesus Mendonça Lopes

    This file is part of GPUMLib.

    GPUMLib is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#include "NMFkernels.h"

namespace GPUMLib {

// NMF_AditiveEuclidianDistance kernels

// Additive update with non-negativity clamp, one thread per matrix element:
//   X[idx] += (X[idx] / deltaX2[idx]) * (deltaX1[idx] - deltaX2[idx]);
// negative results are set to zero so the factor matrix stays non-negative.
KERNEL UpdateMatrix_AE(cudafloat * X, cudafloat * deltaX1, cudafloat * deltaX2, int elements) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < elements) {
        cudafloat v = X[idx] + (X[idx] / deltaX2[idx]) * (deltaX1[idx] - deltaX2[idx]);
        if (v < CUDA_VALUE(0.0)) v = CUDA_VALUE(0.0);
        X[idx] = v;
    }
}

// NMF_MultiplicativeEuclidianDistance kernels

// Multiplicative update, one thread per matrix element: m *= nm / dm, where nm
// and dm hold the numerator and denominator matrices of the update rule and
// SMALL_VALUE_TO_ADD_DENOMINATOR prevents division by zero.
KERNEL UpdateMatrix_ME(cudafloat * nm, cudafloat * dm, cudafloat * m, int elements) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < elements) m[idx] *= nm[idx] / (dm[idx] + SMALL_VALUE_TO_ADD_DENOMINATOR);
}
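
// Illustrative host-side launch of the element-wise update kernel above (a
// documentation sketch, not part of the original library; the helper name and
// the block size of 256 are assumptions). nm, dm and m are device pointers.
void ExampleLaunchUpdateMatrixME(cudafloat * nm, cudafloat * dm, cudafloat * m, int elements) {
    int threads = 256;
    int blocks = (elements + threads - 1) / threads;
    UpdateMatrix_ME<<<blocks, threads>>>(nm, dm, m, elements);
}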

// NMF_MultiplicativeDivergence kernels

// Accessors that hide whether H (r rows by m columns) and W (n rows by r
// columns) are stored in row-major or column-major order.
#ifdef ROW_MAJOR_H
    #define HMATRIX(_ROW, _COL, _R, _M) (H[(_ROW) * (_M) + (_COL)])
#else
    #define HMATRIX(_ROW, _COL, _R, _M) (H[(_COL) * (_R) + (_ROW)])
#endif

#ifdef ROW_MAJOR_W
    #define WMATRIX(_ROW, _COL, _N, _R) (W[(_ROW) * (_R) + (_COL)])
#else
    #define WMATRIX(_ROW, _COL, _N, _R) (W[(_COL) * (_N) + (_ROW)])
#endif

// Computes the sum of each column of W (one block per column; the kernel is
// launched with gridDim.x == r). Each thread accumulates a strided partial sum
// into shared memory, followed by a tree reduction; the last 32 threads use a
// volatile pointer and rely on warp-synchronous execution instead of __syncthreads().
template <int blockSize> KERNEL SumW(cudafloat * W, int n, cudafloat * sumW) {
    extern __shared__ cudafloat w[];

    w[threadIdx.x] = CUDA_VALUE(0.0);
    for(int k = threadIdx.x; k < n; k += blockSize) {
        w[threadIdx.x] += WMATRIX(k, blockIdx.x, n, gridDim.x);
    }
    __syncthreads();

    if (blockSize >= 1024) {
        if (threadIdx.x < 512) w[threadIdx.x] += w[threadIdx.x + 512];
        __syncthreads();
    }

    if (blockSize >= 512) {
        if (threadIdx.x < 256) w[threadIdx.x] += w[threadIdx.x + 256];
        __syncthreads();
    }

    if (blockSize >= 256) {
        if (threadIdx.x < 128) w[threadIdx.x] += w[threadIdx.x + 128];
        __syncthreads();
    }

    if (blockSize >= 128) {
        if (threadIdx.x < 64) w[threadIdx.x] += w[threadIdx.x + 64];
        __syncthreads();
    }

    if (threadIdx.x < 32) {
        volatile cudafloat * _w = w;

        if (blockSize >= 64) _w[threadIdx.x] += _w[threadIdx.x + 32];
        if (blockSize >= 32) _w[threadIdx.x] += _w[threadIdx.x + 16];
        if (blockSize >= 16) _w[threadIdx.x] += _w[threadIdx.x + 8];
        if (blockSize >= 8) _w[threadIdx.x] += _w[threadIdx.x + 4];
        if (blockSize >= 4) _w[threadIdx.x] += _w[threadIdx.x + 2];
        if (blockSize >= 2) _w[threadIdx.x] += _w[threadIdx.x + 1];

        if (threadIdx.x == 0) {
            cudafloat sum = w[0];

            sumW[blockIdx.x] = sum;
        }
    }
}

// Host wrapper that launches the SumW instantiation matching blockSize
// (a power of two; 1024 is only available when FERMI is defined).
void KernelSumW(int blockSize, cudafloat * W, int n, int r, cudafloat * sumW) {
    switch(blockSize) {
        #ifdef FERMI
        case 1024:
            SumW<1024><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        #endif
        case 512:
            SumW<512><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 256:
            SumW<256><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 128:
            SumW<128><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 64:
            SumW<64><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 32:
            SumW<32><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 16:
            SumW<16><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 8:
            SumW<8><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 4:
            SumW<4><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 2:
            SumW<2><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
        case 1:
            SumW<1><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(W, n, sumW);
            break;
    }
}
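
// One plausible way for a caller to pick blockSize for KernelSumW / KernelSumH
// (an assumption for illustration only; the library chooses it elsewhere):
// round the reduction length up to the next power of two, capped at the largest
// block size handled by the switch above (512, or 1024 when FERMI is defined).
int ExampleReductionBlockSize(int length, int maxThreadsPerBlock) {
    int blockSize = 1;
    while (blockSize < length && blockSize < maxThreadsPerBlock) blockSize <<= 1;
    return blockSize;
}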

// Computes the sum of each row of H (one block per row; launched with
// gridDim.x == r), using the same strided load and shared-memory reduction as SumW.
template <int blockSize> KERNEL SumH(cudafloat * H, int m, cudafloat * sumH) {
    extern __shared__ cudafloat h[];

    h[threadIdx.x] = CUDA_VALUE(0.0);
    for(int k = threadIdx.x; k < m; k += blockSize) {
        h[threadIdx.x] += HMATRIX(blockIdx.x, k, gridDim.x, m);
    }
    __syncthreads();

    if (blockSize >= 1024) {
        if (threadIdx.x < 512) h[threadIdx.x] += h[threadIdx.x + 512];
        __syncthreads();
    }

    if (blockSize >= 512) {
        if (threadIdx.x < 256) h[threadIdx.x] += h[threadIdx.x + 256];
        __syncthreads();
    }

    if (blockSize >= 256) {
        if (threadIdx.x < 128) h[threadIdx.x] += h[threadIdx.x + 128];
        __syncthreads();
    }

    if (blockSize >= 128) {
        if (threadIdx.x < 64) h[threadIdx.x] += h[threadIdx.x + 64];
        __syncthreads();
    }

    if (threadIdx.x < 32) {
        volatile cudafloat * _h = h;

        if (blockSize >= 64) _h[threadIdx.x] += _h[threadIdx.x + 32];
        if (blockSize >= 32) _h[threadIdx.x] += _h[threadIdx.x + 16];
        if (blockSize >= 16) _h[threadIdx.x] += _h[threadIdx.x + 8];
        if (blockSize >= 8) _h[threadIdx.x] += _h[threadIdx.x + 4];
        if (blockSize >= 4) _h[threadIdx.x] += _h[threadIdx.x + 2];
        if (blockSize >= 2) _h[threadIdx.x] += _h[threadIdx.x + 1];

        if (threadIdx.x == 0) {
            cudafloat sum = h[0];

            sumH[blockIdx.x] = sum;
        }
    }
}

void KernelSumH(int blockSize, cudafloat * H, int r, int m, cudafloat * sumH) {
    switch(blockSize) {
        #ifdef FERMI
        case 1024:
            SumH<1024><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        #endif
        case 512:
            SumH<512><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 256:
            SumH<256><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 128:
            SumH<128><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 64:
            SumH<64><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 32:
            SumH<32><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 16:
            SumH<16><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 8:
            SumH<8><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 4:
            SumH<4><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 2:
            SumH<2><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
        case 1:
            SumH<1><<<r, blockSize, blockSize * sizeof(cudafloat)>>>(H, m, sumH);
            break;
    }
}

//#define SW(_R, _C) sw[(_R)][(_C)]
#define SW(_R, _C) (sw[(_C)][(_R)])

#define SVH(_R, _C) svh[(_R)][(_C)]
//#define SVH(_R, _C) (svh[(_C)][(_R)])

//#define SH(_R, _C) sh[(_R)][(_C)]
#define SH(_R, _C) sh[(_C)][(_R)]

#define SVW(_R, _C) svw[(_R)][(_C)]
//#define SVW(_R, _C) svw[(_C)][(_R)]

// Multiplicative update of W for the divergence objective. Each block updates
// a 32x32 tile of W, staging tiles of H and of V/(WH) in shared memory; each
// thread accumulates the numerator for two columns of W (x and x + 16), so the
// kernel assumes blocks of 16x32 threads (x by y).
KERNEL UpdateW_MD(cudafloat * W, cudafloat * H, cudafloat * V, cudafloat * WH, cudafloat * sumH, int n, int m, int r) {
    __shared__ cudafloat SH(32, 32);
    __shared__ cudafloat SVW(32, 32);

    int x = blockIdx.x * 32 + threadIdx.x;
    int y = blockIdx.y * 32 + threadIdx.y;

    cudafloat sum1 = CUDA_VALUE(0.0);
    cudafloat sum2 = CUDA_VALUE(0.0);

    for(int k = 0; k < m; k += 32) {
        int tx = threadIdx.x + 16;

        if (x < r && threadIdx.y + k < m) {
            int ky = k + threadIdx.y;
            SH(threadIdx.x, threadIdx.y) = HMATRIX(x, ky, r, m);
            SH(tx, threadIdx.y) = (x + 16 < r) ? HMATRIX(x + 16, ky, r, m) : CUDA_VALUE(0.0);
        } else {
            SH(threadIdx.x, threadIdx.y) = CUDA_VALUE(0.0);
            SH(tx, threadIdx.y) = CUDA_VALUE(0.0);
        }

        if (y < n && k + threadIdx.x < m) {
            int idx = (k + threadIdx.x) * n + y;
            SVW(threadIdx.y, threadIdx.x) = (V[idx] / (WH[idx] + SMALL_VALUE_TO_ADD_DENOMINATOR));

            idx += (n << 4);
            SVW(threadIdx.y, tx) = (k + tx < m) ? (V[idx] / (WH[idx] + SMALL_VALUE_TO_ADD_DENOMINATOR)) : CUDA_VALUE(0.0);
        } else {
            SVW(threadIdx.y, threadIdx.x) = CUDA_VALUE(0.0);
            SVW(threadIdx.y, tx) = CUDA_VALUE(0.0);
        }
        __syncthreads();

        for(int i = 0; i < 32; i++) {
            sum1 += SH(threadIdx.x, i) * SVW(threadIdx.y, i);
            sum2 += SH(tx, i) * SVW(threadIdx.y, i);
        }
        __syncthreads();
    }

    if (y < n && x < r) {
        WMATRIX(y, x, n, r) *= (sum1 / sumH[x]);
        x += 16;
        if (x < r) WMATRIX(y, x, n, r) *= (sum2 / sumH[x]);
    }
}
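
// For reference, the update computed above is the multiplicative rule for the
// divergence objective (Lee & Seung):
//
//   W[i][a] <- W[i][a] * ( sum_j H[a][j] * V[i][j] / (W*H)[i][j] ) / sumH[a]
//
// where sumH[a] = sum_j H[a][j] is precomputed by KernelSumH.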

// Multiplicative update of H for the divergence objective. Each block updates
// a 32x32 tile of H, staging tiles of W and of V/(WH) in shared memory; each
// thread accumulates the numerator for two rows of H (y and y + 16), so the
// kernel assumes blocks of 32x16 threads (x by y).
KERNEL UpdateH_MD(cudafloat * H, cudafloat * W, cudafloat * V, cudafloat * WH, cudafloat * sumW, int n, int m, int r) {
    __shared__ cudafloat SW(32, 32);
    __shared__ cudafloat SVH(32, 32);

    int x = blockIdx.x * 32 + threadIdx.x;
    int y = blockIdx.y * 32 + threadIdx.y;

    cudafloat sum1 = CUDA_VALUE(0.0);
    cudafloat sum2 = CUDA_VALUE(0.0);

    for(int k = 0; k < n; k += 32) {
        int ty = threadIdx.y + 16;

        if (y < r && k + threadIdx.x < n) {
            int kx = k + threadIdx.x;
            SW(threadIdx.x, threadIdx.y) = WMATRIX(kx, y, n, r);
            SW(threadIdx.x, ty) = (y + 16 < r) ? WMATRIX(kx, y + 16, n, r) : CUDA_VALUE(0.0);
        } else {
            SW(threadIdx.x, threadIdx.y) = CUDA_VALUE(0.0);
            SW(threadIdx.x, ty) = CUDA_VALUE(0.0);
        }

        if (x < m && k + threadIdx.y < n) {
            int idx = x * n + (k + threadIdx.y);
            SVH(threadIdx.y, threadIdx.x) = V[idx] / (WH[idx] + SMALL_VALUE_TO_ADD_DENOMINATOR);

            idx += 16;
            SVH(ty, threadIdx.x) = (k + ty < n) ? (V[idx] / (WH[idx] + SMALL_VALUE_TO_ADD_DENOMINATOR)) : CUDA_VALUE(0.0);
        } else {
            SVH(threadIdx.y, threadIdx.x) = CUDA_VALUE(0.0);
            SVH(ty, threadIdx.x) = CUDA_VALUE(0.0);
        }
        __syncthreads();

        for(int i = 0; i < 32; i++) {
            sum1 += SW(i, threadIdx.y) * SVH(i, threadIdx.x);
            sum2 += SW(i, ty) * SVH(i, threadIdx.x);
        }
        __syncthreads();
    }

    if (y < r && x < m) {
        HMATRIX(y, x, r, m) *= (sum1 / sumW[y]);
        y += 16;
        if (y < r) HMATRIX(y, x, r, m) *= (sum2 / sumW[y]);
    }
}
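
// Correspondingly, UpdateH_MD applies
//
//   H[a][j] <- H[a][j] * ( sum_i W[i][a] * V[i][j] / (W*H)[i][j] ) / sumW[a]
//
// where sumW[a] = sum_i W[i][a] is precomputed by KernelSumW.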

} // namespace GPUMLib