diff --git a/shuffle/shuffle.go b/shuffle/shuffle.go index 8db20b2..666bb64 100644 --- a/shuffle/shuffle.go +++ b/shuffle/shuffle.go @@ -51,15 +51,39 @@ func (s Shuffler) UnshuffleInts(x []int) { } } +// maxMultiple returns the highest multiple of n that fits in a uint32 +func maxMultiple(n uint32) uint32 { + uint32Max := ^uint32(0) + return uint32Max - (uint32Max % n) +} + +// intn returns a random number uniformly distributed between 0 and n (not +// including n). +// +// rand should be a source of random bytes +// +// buf should be a temporary buffer with length at least 4 func intn(rand *bufio.Reader, n uint32, buf []byte) int { - max := ^uint32(0) - m := max - (max % n) + // intn does not simply take a random uint32 mod n because this is biased. + // Consider n=3 and a random uint32 u. (2^32-2)%3 == 2, so for u from 0 to + // 2^32-2, u%3 evenly rotates among 0, 1, and 2. However, (2^32-1)%3 == 0, + // so there is a slight bias in favor of u%3 == 0 in the case where u == + // 2^32-1. + // + // To solve this problem, intn rejection-samples a number x between 0 and a + // multiple of n (not including the upper bound), then takes x%n, which is + // truly uniform. + + m := maxMultiple(n) for { if _, err := rand.Read(buf); err != nil { panic(err) } + // Get a uniform random number in [0, 2^32) x := binary.BigEndian.Uint32(buf) if x < m { + // Accept only random numbers in [0, m). Because m is a multiple of + // n, x % n is uniformly distributed in [0, n). return int(x % n) } } diff --git a/shuffle/shuffle_test.go b/shuffle/shuffle_test.go index 0ef2ce8..321d8bf 100644 --- a/shuffle/shuffle_test.go +++ b/shuffle/shuffle_test.go @@ -41,6 +41,22 @@ func TestShuffle(t *testing.T) { } } +func TestMaxMultiple(t *testing.T) { + for _, n := range []uint32{2, 3, 5, 10, 15, 1<<10} { + m := maxMultiple(n) + if m%n != 0 { + t.Errorf("maxMultiple(%d) is not a multiple", n) + continue + } + // note that m + n will wrap around if m is maximal; this relies on + // uint32 modular arithmetic + if m + n > m { + t.Errorf("maxMultiple(%d) is not maximal", n) + continue + } + } +} + func BenchmarkNew(b *testing.B) { for i := 0; i < b.N; i++ { New(rand.Reader, 50000)