diff --git a/lib/THC/THCTensorRandom.cuh b/lib/THC/THCTensorRandom.cuh index 5afd8fed..bc519196 100644 --- a/lib/THC/THCTensorRandom.cuh +++ b/lib/THC/THCTensorRandom.cuh @@ -228,7 +228,7 @@ sampleMultinomialOnce(long* dest, template __global__ void -sampleMultinomialWithReplacement(curandStateMtgp32* state, +sampleMultinomialWithReplacement(T* uniform_idx, int totalSamples, long* dest, long distributions, @@ -236,9 +236,7 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state, T* normDistPrefixSum) { // At the moment, each warp computes one sample value in the binary // search due to divergence. It seems possible to compute multiple - // values and limit divergence though later on. However, no matter - // what, all block threads must participate in the curand_uniform - // call to update the generator state. + // values and limit divergence though later on. // The block determines the distribution for which we generate a point for (long curDist = blockIdx.x; @@ -250,7 +248,7 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state, int sample = sampleBase + threadIdx.y; // All threads participate in this - T r = ScalarConvert::to(curand_uniform(&state[blockIdx.x])); + T r = uniform_idx[sample]; if (threadIdx.x == 0 && sample < totalSamples) { // Find the bucket that a uniform sample lies in @@ -284,6 +282,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state, // The block and warp determines the distribution for which we // generate a point + T zero = ScalarConvert::to(0); for (long curDistBase = blockIdx.x * blockDim.y; curDistBase < distributions; curDistBase += gridDim.x * blockDim.y) { @@ -292,7 +291,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state, // All threads must participate in this T r = ScalarConvert::to(curand_uniform(&state[blockIdx.x])); - + if (threadIdx.x == 0 && curDist < distributions) { // Find the bucket that a uniform sample lies in int choice = binarySearchForMultinomial( @@ -305,7 +304,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state, // Without replacement, so update the original probability so it // is not considered a second time - origDist[curDist * categories + choice] = ScalarConvert::to(0); + origDist[curDist * categories + choice] = zero; } } } diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu index 4c6d2fb2..341c4fd6 100644 --- a/lib/THC/generic/THCTensorRandom.cu +++ b/lib/THC/generic/THCTensorRandom.cu @@ -202,13 +202,21 @@ THC_API void THCTensor_(multinomial)(struct THCState *state, // distribution concurrently. dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS); + + //Create the matrix of uniformly sampled numbers + THCTensor *uniform_idx = THCTensor_(newWithSize1d)(state, n_sample); + THCTensor_(uniform)(state, uniform_idx, 0, 1); + + sampleMultinomialWithReplacement <<>>( - gen->gen_states, - n_sample, - THCudaLongTensor_data(state, self), - numDist, numCategories, - THCTensor_(data)(state, prefixSum)); + THCTensor_(data)(state, uniform_idx), + n_sample, + THCudaLongTensor_data(state, self), + numDist, numCategories, + THCTensor_(data)(state, prefixSum)); + + THCTensor_(free)(state, uniform_idx); } else { // Sample without replacement @@ -237,13 +245,13 @@ THC_API void THCTensor_(multinomial)(struct THCState *state, // recalculate our distribution sampleMultinomialWithoutReplacement <<>>( - gen->gen_states, - n_sample, - sample, - THCudaLongTensor_data(state, self), - numDist, numCategories, - THCTensor_(data)(state, origDist), - THCTensor_(data)(state, prefixSum)); + gen->gen_states, + n_sample, + sample, + THCudaLongTensor_data(state, self), + numDist, numCategories, + THCTensor_(data)(state, origDist), + THCTensor_(data)(state, prefixSum)); } } diff --git a/test/multinomial.lua b/test/multinomial.lua new file mode 100644 index 00000000..0c02da1d --- /dev/null +++ b/test/multinomial.lua @@ -0,0 +1,52 @@ +local tester = torch.Tester() + +cmd = torch.CmdLine() +cmd:text() +cmd:text() +cmd:text('Testing alias multinomial on cuda') +cmd:text() +cmd:text('Options') +cmd:option('--compare',false,'compare with cutorch multinomial') +cmd:text() + +-- parse input params +params = cmd:parse(arg) + +require 'cutorch' +local function checkMultinomial() + local n_class = {10, 100, 1000} + local n_sample = {10, 100, 1000, 10000} + local n_dist = 100 + for _, curr_n_class in pairs(n_class) do + for _, curr_n_sample in pairs(n_sample) do + print("") + print("Benchmarking multinomial with "..curr_n_class.." classes and "..curr_n_sample.." samples") + torch.seed() + local probs = torch.CudaDoubleTensor(n_dist, curr_n_class):uniform(0,1) + local a = torch.Timer() + local cold_time = a:time().real + a:reset() + cutorch.synchronize() + a:reset() + for i = 1,10 do + torch.multinomial(probs, curr_n_sample, true) + cutorch.synchronize() + end + print("[CUDA] : torch.multinomial draw: "..(a:time().real/10).." seconds (hot)") + end + torch.seed() + local probs = torch.CudaDoubleTensor(3, curr_n_class):uniform(0,1) + for i =1,3 do + probs[i]:div(probs[i]:sum()) + end + local output = torch.multinomial(probs, 5000000, true) + local counts = torch.Tensor(3, curr_n_class):zero() + for i=1,3 do + output[i]:long():apply(function(x) counts[{i, x}] = counts[{i, x}] + 1 end) + counts[i]:div(counts[i]:sum()) + end + tester:eq(probs:double(), counts, 0.01, "probs and counts should be approximately equal for n_class = "..curr_n_class) + end +end +tester:add(checkMultinomial) +tester:run()