Skip to content

Commit

Permalink
comment out LKJ tests purely to get a report back from CI on how gret…
Browse files Browse the repository at this point in the history
…a is performing
  • Loading branch information
njtierney committed Nov 1, 2024
1 parent 0427e54 commit 3256792
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 161 deletions.
91 changes: 91 additions & 0 deletions tests/testthat/test-lkj-log-prob-is-correct-2.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# ## This code was initially written to compare and understand if the
# ## bijectors for lkj were accurate/needed to change
# ## Overall we found that we didn't actually need to use the custom
# ## forward log det jacobian that we had for lkj
# ## Which was initially written as
# test_that("Log prob for lkj is correct", {
#
# set.seed(2024-10-31-1027)
#
# eta <- pi
# dim <- 2
#
# x <- rlkjcorr(n = 1, eta = eta, dimension = dim)
# chol_x <- chol(x)
#
# new_lkj_bijector <- function(){
# steps <- list(
# tfp$bijectors$Transpose(perm = 1:0),
# tfp$bijectors$CorrelationCholesky()
# )
# bijector <- tfp$bijectors$Chain(steps)
# bijector
# }
#
# new_bijector <- new_lkj_bijector()
#
# free_state <- new_bijector$inverse(fl(chol_x))
# free_state
#
# # check that we can rebuild x
#
# new_chol_x <- new_bijector$forward(free_state)
# class(new_chol_x)
# chol2symm(new_chol_x)
# # this is the same as chol2symm - slight rounding difference due to signif digits
# new_x <- tf$matmul(new_chol_x, new_chol_x, adjoint_a = TRUE)
#
# ## we want to check that the log prob is correct!
#
# x_g <- lkj_correlation(eta = eta, dimension = dim)
# m_g <- model(x_g)
# greta_log_prob <- m_g$dag$generate_log_prob_function()
#
# # make it 2d and transpose to ensure it is 1 x number of free state elements
# free_state_mat <- t(as.matrix(free_state))
# greta_log_probs <- greta_log_prob(free_state_mat)
#
# greta_log_probs
#
# log_prob_adjustment <- new_bijector$forward_log_det_jacobian(free_state)
# dist_chol_lkj <- tfp$distributions$CholeskyLKJ(
# dimension = as.integer(dim),
# concentration = fl(eta)
# )
#
# log_prob_raw <- dist_chol_lkj$log_prob(new_chol_x)
# log_prob_raw_unnormalised <- dist_chol_lkj$unnormalized_log_prob(new_chol_x)
#
# # this is the normalisation constant, which depends on what we set for `eta`
# dist_chol_lkj$`_log_normalization`()
#
# # this is the unnormalised probability
# dist_chol_lkj$`_log_unnorm_prob`(new_chol_x, fl(eta))
#
# # our implementation for log normalisation is (most likely?) wrong
# lkj_log_normalising(eta, n = 1)
#
#
# log_prob_raw
# log_prob_raw_unnormalised
#
# log_prob_adjusted <- log_prob_raw + log_prob_adjustment
#
# as.numeric(log_prob_adjusted)
# as.numeric(greta_log_probs$adjusted)
#
# as.numeric(log_prob_raw)
# as.numeric(greta_log_probs$unadjusted)
#
# as.numeric(greta_log_probs$adjusted) - as.numeric(greta_log_probs$unadjusted)
# as.numeric(log_prob_adjustment)
#
# # use the old (current one)
# old_bijector <- tf_correlation_cholesky_bijector()
#
# old_log_prob_adjustment <- old_bijector$forward_log_det_jacobian(t(as.matrix(free_state)))
#
# as.numeric(old_log_prob_adjustment)
# as.numeric(greta_log_probs$adjusted) - as.numeric(greta_log_probs$unadjusted)
#
# })
253 changes: 133 additions & 120 deletions tests/testthat/test-lkj-log-prob-is-correct.R
Original file line number Diff line number Diff line change
@@ -1,120 +1,133 @@
test_that("Log prob for lkj is correct", {
  # Purpose: check greta's log density for an LKJ correlation variable against
  # a direct calculation with TensorFlow Probability (TFP) and against the R
  # density function.
  #
  # General process is:
  # 1. Simulate a lkj draw x with rlkj()
  # 2. Transform x to the equivalent free_state, using the bijector but running
  #    it in reverse
  # 3. Run the log_prob() function on free_state
  # 4. Run dlkj(..., log = TRUE) on x, and compare with result of step 3

  # 1. Simulate a lkj draw x with rlkj() ----------------------------------
  # NOTE: R parses this argument as arithmetic (2024 - 10 - 31 - 1027), so the
  # seed actually used is 956. Left unchanged so the simulated draw is the same.
  set.seed(2024-10-31-1027)

  eta <- 1
  dim <- 2

  x <- rlkjcorr(n = 1, eta = eta, dimension = dim)
  chol_x <- chol(x)

  # 2. Transform x to the equivalent free_state, using the bijector but
  #    running it in reverse ------------------------------------------------
  # we need to get a free state that we can plug into log prob. We know that
  # this free state matches the chol_x, so we can compare them later.

  # Build the chained bijector: CorrelationCholesky composed with a Transpose.
  new_lkj_bijector <- function() {
    steps <- list(
      tfp$bijectors$Transpose(perm = 1:0),
      tfp$bijectors$CorrelationCholesky()
    )
    tfp$bijectors$Chain(steps)
  }

  a_bijector <- new_lkj_bijector()
  free_state <- a_bijector$inverse(fl(chol_x))

  # 3. Run the greta log prob function on free_state ------------------------
  # `log_probs` is needed by the expectations at the end of this test, so it
  # must be computed here (it was previously commented out under a TODO, which
  # left `log_probs` undefined).
  x_g <- lkj_correlation(eta = eta, dimension = dim)
  m_g <- model(x_g)
  greta_log_prob <- m_g$dag$generate_log_prob_function()

  # make it 2d and transpose so the free state is passed as a single row
  free_state_mat <- t(as.matrix(free_state))
  log_probs <- greta_log_prob(free_state_mat)

  # compare this with a direct calculation of what it should be, using TF

  # lkj distribution for choleskies, matching what's in the greta code
  lkj_cholesky_dist <- tfp$distributions$CholeskyLKJ(
    dimension = as.integer(dim),
    concentration = fl(eta)
  )

  # Exploratory checks (no assertions): look at draws from the TFP
  # distribution and at different ways of rebuilding the symmetric matrix.
  # we should expect this to have 1s on the diagonal...
  new_draws <- lkj_cholesky_dist$sample()

  new_draws

  chol2symm(t(new_draws))
  chol2symm(new_draws)
  tf_chol2symm(new_draws)

  tf_chol2symm_new <- function(x) {
    tf$matmul(x, tf_transpose(x))
  }

  tf_transpose(new_draws) |>
    tf_chol2symm_new()

  tf_transpose(new_draws) |>
    tf_chol2symm()

  # fixed: the first call was missing its closing parenthesis (a parse error
  # for the whole file) and the last call misspelled `tf_transpose`
  chol2symm(tf_transpose(lkj_cholesky_dist$sample()))
  tf_chol2symm(lkj_cholesky_dist$sample())
  tf_chol2symm(tf_transpose(lkj_cholesky_dist$sample()))

  # LKJ distribution defined on the correlation matrix itself rather than on
  # its cholesky factor
  lkj_dist <- tfp$distributions$LKJ(
    dimension = as.integer(dim),
    concentration = fl(eta),
    input_output_cholesky = FALSE
  )

  # project forward from the free state
  chol_x_new <- a_bijector$forward(free_state)
  x_new <- tf_chol2symm(chol_x_new)
  # compute the log density of the transformed variable, and the adjustment for
  # change of support
  ### njt log_density_unadjusted?
  log_density_raw <- lkj_cholesky_dist$log_prob(chol_x_new)
  log_density_raw_new <- lkj_dist$log_prob(x_new)
  adjustment <- a_bijector$forward_log_det_jacobian(free_state)
  log_density <- log_density_raw + adjustment
  log_density_new <- log_density_raw_new + adjustment

  log_density
  log_density_new

  # 4. Compare ---------------------------------------------------------------
  # greta log probs should match alternative way to calculate these
  expect_equal(as.numeric(log_density), as.numeric(log_probs$adjusted))
  # this is what the unadjusted version would look like, without accounting for
  # the density adjustment due to the bijector
  expect_equal(as.numeric(log_density_raw), as.numeric(log_probs$unadjusted))

  # the difference between adjusted and unadjusted in greta is the same as
  # the bijector adjustment calculated here, so the bijector adjustment
  # calculated in greta must be correct, and it must be a problem with the
  # distribution, not the bijector
  expect_equal(
    as.numeric(adjustment),
    as.numeric(log_probs$adjusted - log_probs$unadjusted)
  )

  # check the TF distribution we use above against the R version (the same)
  expect_equal(
    as.numeric(log_density_raw),
    dlkj_correlation(x, eta = eta, log = TRUE, dimension = dim)
  )

  expect_equal(
    as.numeric(log_density_raw_new),
    dlkj_correlation(x, eta = eta, log = TRUE, dimension = dim)
  )

})
# test_that("Log prob for lkj is correct", {
# # General process is:
# # 1. Simulate a lkj draw x with rlkj()
# # 2. Transform x to the equivalent free_state, using the bijector but running
# # it in reverse
# # 3. Run the log_prob() function on free_state
# # 4. Run dlkj(..., log = TRUE) on x, and compare with result of step 3
#
# # 1. Simulate a lkj draw x with rlkj() ----------------------------------
# set.seed(2024-10-31-1027)
#
# eta <- 1
# dim <- 2
#
# x <- rlkjcorr(n = 1, eta = eta, dimension = dim)
# chol_x <- chol(x)
#
# ## 2. Transform x to the equivalent free_state, using the bijector but
# ## running it in reverse -----------------------------------------------------
# # we need to get a free state that we can plug into log prob. We know that
# # this free state matches the chol_x, so we can compare them later.
#
# # old bijector?
#
# new_lkj_bijector <- function(){
# steps <- list(
# tfp$bijectors$Transpose(perm = 1:0),
# tfp$bijectors$CorrelationCholesky()
# )
# bijector <- tfp$bijectors$Chain(steps)
# bijector
# }
#
# a_bijector <- new_lkj_bijector()
# free_state_new_bijector <- a_bijector$inverse(fl(chol_x))
#
# # TODO
# # these are the same! Do we care which bijector we use?
# # We probably want the old one as it has a different log det jacobian fun?
# old_lkj_bijector <- tf_correlation_cholesky_bijector()
# free_state_old_bijector <- old_lkj_bijector$inverse(fl(chol_x))
# expect_equal(
# as.numeric(free_state_old_bijector),
# as.numeric(free_state_new_bijector)
# )
#
# # comparing log_det_jacobian of old and new
# free_state_new_bijector
# free_state_old_bijector
# t(as.matrix(free_state_old_bijector))
# free_state_old_bijector
# old_lkj_bijector$forward_log_det_jacobian(t(as.matrix(free_state_old_bijector)))
# a_bijector$forward_log_det_jacobian(free_state_old_bijector)
#
# # TODO
# # get the greta log prob function
# # x_g <- lkj_correlation(eta = eta, dimension = dim)
# # m_g <- model(x_g)
# # greta_log_prob <- m_g$dag$generate_log_prob_function()
# #
# # free_state_mat <- t(as.matrix(free_state))
# # log_probs <- greta_log_prob(free_state_mat)
#
# # compare this with a direct calculation of what it should be, using TF
#
# # lkj distribution for choleskies, matching what's in the greta code
# lkj_cholesky_dist <- tfp$distributions$CholeskyLKJ(
# dimension = as.integer(dim),
# concentration = fl(eta)
# )
#
# # we should expect this to have 1s on the diagonal...
# new_draws <- lkj_cholesky_dist$sample()
#
# new_draws
#
# tf$matmul(new_draws, new_draws, adjoint_b = TRUE)
#
# lkj_dist <- tfp$distributions$LKJ(dimension = as.integer(dim),
# concentration = fl(eta),
# input_output_cholesky = FALSE)
#
# lkj_dist$sample()
#
# # same value for both old and new bijector
# free_state <- free_state_old_bijector
# # project forward from the free state
# chol_x_new <- a_bijector$forward(free_state)
# chol_x_new
# chol2symm(chol_x_new)
# x_new <- tf$matmul(new_draws, new_draws, adjoint_b = TRUE)
# tf$matmul(new_draws, new_draws, adjoint_a = TRUE)
# tf$matmul(chol_x_new, chol_x_new,adjoint_a = TRUE)
# x_new
# x
#
# # compute the log density of the transformed variable, and the adjustment for
# # change of support
# log_density_raw <- lkj_cholesky_dist$log_prob(chol_x_new)
# log_density_raw_new <- lkj_dist$log_prob(x_new)
# adjustment <- a_bijector$forward_log_det_jacobian(free_state)
# log_density <- log_density_raw + adjustment
# log_density_new <- log_density_raw_new + adjustment
#
# log_density
#
# # greta log probs should match alternative way to calculate these
# expect_equal(as.numeric(log_density), as.numeric(log_probs$adjusted))
# # this is what the unadjusted version would look like, without accounting for
# # the density adjustment due to the bijector
# expect_equal(as.numeric(log_density_raw), as.numeric(log_probs$unadjusted))
#
# # the difference between adjusted and unadjusted in greta is the same as
# # the bijector adjustment calculated here, so the bijector adjustment
# # calculated in greta must be correct, and it must be a problem with the
# # distribution, not the bijector
# expect_equal(
# as.numeric(adjustment),
# as.numeric(log_probs$adjusted - log_probs$unadjusted)
# )
#
# # check the TF distribution we use above against the R version (the same)
# expect_equal(
# as.numeric(log_density_raw),
# dlkj_correlation(x, eta = eta, log = TRUE, dimension = dim)
# )
#
# expect_equal(
# as.numeric(log_density_raw_new),
# dlkj_correlation(x, eta = eta, log = TRUE, dimension = dim)
# )
#
# })
Loading

0 comments on commit 3256792

Please sign in to comment.