/*
 * Stochastic / probability-weighted pooling over a stack of 2-D planes.
 *
 * NOTE(review): despite its name, this kernel does NOT compute max pooling.
 *
 * Launch: 1-D grid, one thread per pooled output; threads with
 * pooledIndex >= pooledVolume do nothing, so the grid may be rounded up.
 *
 * Layout (from the indexing below): `data` holds consecutive planes of
 * width*height elements, row-major (x fastest). `pooled` is indexed the same
 * way with pooledWidth x pooledHeight planes.
 *
 * For each output, the pooling window is clipped to the plane and its element
 * sum S is computed. If S is not > 0, the output is 0. Otherwise:
 *   mode == 0: sample one element v with probability v / S, using
 *              stochastic_value[pooledIndex] as the threshold
 *              (presumably a uniform sample in [0, 1) -- TODO confirm with caller);
 *   mode != 0: output sum(v * v / S) over the window.
 *
 * Works for any poolWidth * poolHeight: the window is re-read from global
 * memory instead of being copied into a fixed-size local buffer.
 */
template <typename T>
__global__ void
pooling_max_kernel
(T* pooled,
 const T* data,
 const T* stochastic_value,
 const int mode,
 const int pooledWidth,
 const int pooledHeight,
 const int pooledVolume,
 const int width,
 const int height,
 const int poolWidth,
 const int poolHeight,
 const int strideX,
 const int strideY,
 const int padLeft,
 const int padTop)
{
  const int pooledIndex = threadIdx.x + blockIdx.x * blockDim.x;
  if (pooledIndex >= pooledVolume) {
    return;
  }

  // Decompose the flat output index into (px, py, pz).
  int px = pooledIndex;
  int py = px / pooledWidth;
  const int pz = py / pooledHeight;
  px %= pooledWidth;
  py %= pooledHeight;

  // Start of this thread's input plane; 64-bit offset avoids int overflow
  // on large volumes.
  const T* plane = data + static_cast<long long>(pz) * width * height;

  // Pooling window, clipped to the plane (padding contributes nothing).
  int x1 = px * strideX - padLeft;
  int y1 = py * strideY - padTop;
  const int x2 = min(x1 + poolWidth, width);
  const int y2 = min(y1 + poolHeight, height);
  x1 = max(x1, 0);
  y1 = max(y1, 0);

  // Pass 1: total activation in the window.
  T sumdata = 0;
  for (int y = y1; y < y2; ++y) {
    for (int x = x1; x < x2; ++x) {
      sumdata = sumdata + plane[y * width + x];
    }
  }

  // Written as !(sum > 0) so a NaN sum also yields 0, as before.
  if (!(sumdata > 0)) {
    pooled[pooledIndex] = 0;
    return;
  }

  if (mode == 0) {
    // Pass 2 (stochastic): walk the cumulative distribution v / sumdata in
    // row-major order and take the first element whose cumulative
    // probability exceeds the threshold.
    const T randdata = stochastic_value[pooledIndex];
    T cumulative = 0;
    T picked = 0;
    bool found = false;
    for (int y = y1; y < y2 && !found; ++y) {
      for (int x = x1; x < x2; ++x) {
        const T v = plane[y * width + x];
        cumulative = cumulative + v / sumdata;
        // Fallback if rounding keeps the total below randdata: use the last
        // element that actually has positive probability.
        if (v > 0) {
          picked = v;
        }
        if (randdata < cumulative) {
          picked = v;
          found = true;
          break;
        }
      }
    }
    pooled[pooledIndex] = picked;
  } else {
    // Pass 2 (weighted): each element weighted by its own probability v / S.
    T weightsum = 0;
    for (int y = y1; y < y2; ++y) {
      for (int x = x1; x < x2; ++x) {
        const T v = plane[y * width + x];
        weightsum = weightsum + v * v / sumdata;
      }
    }
    pooled[pooledIndex] = weightsum;
  }
}
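另：在 kernel 函数中定义局部变量需要手动分配内存吗？我之前用被注释掉的那段 cudaMalloc 代码分配内存，编译时会报错。（答：不需要。kernel 内声明的定长局部数组（如 `T savedata[9]`）由编译器自动分配在寄存器或 local memory 中；若确实需要运行时才确定大小的缓冲区，可在设备代码中使用 malloc/free 从设备堆分配。）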
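但是 kernel 启动本身不会返回 cudaError_t，这时要检查错误，需要调用：
cudaError_t err = cudaGetLastError();
printf("%s\n",cudaGetErrorString(err));
来打印最近一次发生的错误。注意：cudaGetLastError() 在启动后立即调用只能捕获启动配置错误；kernel 执行过程中的错误（如越界访问）要在同步之后才会报告，例如调用 cudaDeviceSynchronize() 并检查其返回值。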