Very slow calculation of multivariate density over multiple at positions #120

Closed
robjhyndman opened this issue Jul 31, 2024 · 1 comment

@robjhyndman
Contributor

library(distributional)
mu <- c(0, 0)
Sigma <- diag(2)
dist <- distributional::dist_multivariate_normal(list(mu), list(Sigma))

at <- expand.grid(x = seq(-3,3,by=0.5), y = seq(-2, 10, by = 2)) 
# This seems the most obvious way to specify multiple at positions
d <- density(dist, at) 
#> Error in `FUN()`:
#> ! Cannot recycle input of size 91 to match the distributions (size 1).
# This works but is very slow and produces unnecessary messages
system.time(density(dist, apply(at, 1, list)))
#> New names:
#> • `` -> `...1`
#> • `` -> `...2`
#> • `` -> `...3`
#> • `` -> `...4`
#> • `` -> `...5`
#> • `` -> `...6`
#> • `` -> `...7`
#> • `` -> `...8`
#> • `` -> `...9`
#> • `` -> `...10`
#> • `` -> `...11`
#> • `` -> `...12`
#> • `` -> `...13`
#> • `` -> `...14`
#> • `` -> `...15`
#> • `` -> `...16`
#> • `` -> `...17`
#> • `` -> `...18`
#> • `` -> `...19`
#> • `` -> `...20`
#> • `` -> `...21`
#> • `` -> `...22`
#> • `` -> `...23`
#> • `` -> `...24`
#> • `` -> `...25`
#> • `` -> `...26`
#> • `` -> `...27`
#> • `` -> `...28`
#> • `` -> `...29`
#> • `` -> `...30`
#> • `` -> `...31`
#> • `` -> `...32`
#> • `` -> `...33`
#> • `` -> `...34`
#> • `` -> `...35`
#> • `` -> `...36`
#> • `` -> `...37`
#> • `` -> `...38`
#> • `` -> `...39`
#> • `` -> `...40`
#> • `` -> `...41`
#> • `` -> `...42`
#> • `` -> `...43`
#> • `` -> `...44`
#> • `` -> `...45`
#> • `` -> `...46`
#> • `` -> `...47`
#> • `` -> `...48`
#> • `` -> `...49`
#> • `` -> `...50`
#> • `` -> `...51`
#> • `` -> `...52`
#> • `` -> `...53`
#> • `` -> `...54`
#> • `` -> `...55`
#> • `` -> `...56`
#> • `` -> `...57`
#> • `` -> `...58`
#> • `` -> `...59`
#> • `` -> `...60`
#> • `` -> `...61`
#> • `` -> `...62`
#> • `` -> `...63`
#> • `` -> `...64`
#> • `` -> `...65`
#> • `` -> `...66`
#> • `` -> `...67`
#> • `` -> `...68`
#> • `` -> `...69`
#> • `` -> `...70`
#> • `` -> `...71`
#> • `` -> `...72`
#> • `` -> `...73`
#> • `` -> `...74`
#> • `` -> `...75`
#> • `` -> `...76`
#> • `` -> `...77`
#> • `` -> `...78`
#> • `` -> `...79`
#> • `` -> `...80`
#> • `` -> `...81`
#> • `` -> `...82`
#> • `` -> `...83`
#> • `` -> `...84`
#> • `` -> `...85`
#> • `` -> `...86`
#> • `` -> `...87`
#> • `` -> `...88`
#> • `` -> `...89`
#> • `` -> `...90`
#> • `` -> `...91`
#>    user  system elapsed 
#>   0.675   0.007   0.682
system.time(mvtnorm::dmvnorm(at, mu, Sigma))
#>    user  system elapsed 
#>       0       0       0

Created on 2024-07-31 with reprex v2.1.1

@mitchelloharawild
Owner

Here's the preferred approach: pass the positions as a matrix with one row per position, since multivariate distribution operations are built on matrices.

List input for vectorised operations is treated as more of a wide form, where the number of rows (or size) of each input must be 1 or match the length of the distribution (documentation needed, #107). In this case, each of the 91 inputs would be a named list element (as you've done, but without names, hence the messages) or a data frame column.

Most of the speed loss in the benchmark below is attributable to vectorisation improvements that are still needed (#25).

library(distributional)
mu <- c(0, 0)
Sigma <- diag(2)
dist <- distributional::dist_multivariate_normal(list(mu), list(Sigma))

at <- expand.grid(x = seq(-3,3,by=0.5), y = seq(-2, 10, by = 2)) 

bench::mark(
  density(dist, as.matrix(at)),
  mvtnorm::dmvnorm(at, mu, Sigma),
  check = FALSE
)
#> # A tibble: 2 x 6
#>   expression                           min   median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr>                      <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
#> 1 density(dist, as.matrix(at))      11.3ms   13.4ms      73.0   448.1KB     11.8
#> 2 mvtnorm::dmvnorm(at, mu, Sigma)  157.2us    174us    5165.     14.2KB     12.6

Created on 2024-07-31 with reprex v2.1.0
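
For reference, the quantity being benchmarked can be sketched in base R on a matrix with one row per position. This only illustrates the matrix-based layout and is not how either package implements it; the helper name `dmvn_sketch` is hypothetical and belongs to neither distributional nor mvtnorm.

```r
# Hypothetical helper (an assumption for illustration, not a package function):
# multivariate normal density evaluated at each row of a matrix.
dmvn_sketch <- function(at, mu, Sigma) {
  at <- as.matrix(at)                              # one row per evaluation point
  k <- length(mu)                                  # dimension of the distribution
  d2 <- mahalanobis(at, center = mu, cov = Sigma)  # squared Mahalanobis distances
  unname(exp(-0.5 * d2) / sqrt((2 * pi)^k * det(Sigma)))
}

mu <- c(0, 0)
Sigma <- diag(2)
at <- expand.grid(x = seq(-3, 3, by = 0.5), y = seq(-2, 10, by = 2))

d <- dmvn_sketch(at, mu, Sigma)
length(d)  # one density value per row of the grid
```

Under these assumptions the values should agree with `mvtnorm::dmvnorm(at, mu, Sigma)` up to floating-point error, since both evaluate the same closed-form density row by row.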
