-
Notifications
You must be signed in to change notification settings - Fork 65
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
comment out LKJ tests purely to get a report back from CI on how gret…
…a is performing
- Loading branch information
Showing
3 changed files
with
244 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# ## This code was initially written to compare and understand if the | ||
# ## bijectors for lkj were accurate/needed to change | ||
# ## Overall we found that we didn't actually need to use the custom | ||
# ## forward log det jacobian that we had for lkj | ||
# ## Which was initially written as | ||
# test_that("Log prob for lkj is correct", { | ||
# | ||
# set.seed(2024-10-31-1027) | ||
# | ||
# eta <- pi | ||
# dim <- 2 | ||
# | ||
# x <- rlkjcorr(n = 1, eta = eta, dimension = dim) | ||
# chol_x <- chol(x) | ||
# | ||
# new_lkj_bijector <- function(){ | ||
# steps <- list( | ||
# tfp$bijectors$Transpose(perm = 1:0), | ||
# tfp$bijectors$CorrelationCholesky() | ||
# ) | ||
# bijector <- tfp$bijectors$Chain(steps) | ||
# bijector | ||
# } | ||
# | ||
# new_bijector <- new_lkj_bijector() | ||
# | ||
# free_state <- new_bijector$inverse(fl(chol_x)) | ||
# free_state | ||
# | ||
# # check that we can rebuild x | ||
# | ||
# new_chol_x <- new_bijector$forward(free_state) | ||
# class(new_chol_x) | ||
# chol2symm(new_chol_x) | ||
# # this is the same as chol2symm - slight rounding difference due to signif digits | ||
# new_x <- tf$matmul(new_chol_x, new_chol_x, adjoint_a = TRUE) | ||
# | ||
# ## we want to check that the log prob is correct! | ||
# | ||
# x_g <- lkj_correlation(eta = eta, dimension = dim) | ||
# m_g <- model(x_g) | ||
# greta_log_prob <- m_g$dag$generate_log_prob_function() | ||
# | ||
# # make it 2d and transpose to ensure it is 1 x number of free state elements | ||
# free_state_mat <- t(as.matrix(free_state)) | ||
# greta_log_probs <- greta_log_prob(free_state_mat) | ||
# | ||
# greta_log_probs | ||
# | ||
# log_prob_adjustment <- new_bijector$forward_log_det_jacobian(free_state) | ||
# dist_chol_lkj <- tfp$distributions$CholeskyLKJ( | ||
# dimension = as.integer(dim), | ||
# concentration = fl(eta) | ||
# ) | ||
# | ||
# log_prob_raw <- dist_chol_lkj$log_prob(new_chol_x) | ||
# log_prob_raw_unnormalised <- dist_chol_lkj$unnormalized_log_prob(new_chol_x) | ||
# | ||
# # this is the normalisation constant, which depends on what we set for `eta` | ||
# dist_chol_lkj$`_log_normalization`() | ||
# | ||
# # this is the unnormalised probability | ||
# dist_chol_lkj$`_log_unnorm_prob`(new_chol_x, fl(eta)) | ||
# | ||
# # our implementation for log normalisation is (most likely?) wrong | ||
# lkj_log_normalising(eta, n = 1) | ||
# | ||
# | ||
# log_prob_raw | ||
# log_prob_raw_unnormalised | ||
# | ||
# log_prob_adjusted <- log_prob_raw + log_prob_adjustment | ||
# | ||
# as.numeric(log_prob_adjusted) | ||
# as.numeric(greta_log_probs$adjusted) | ||
# | ||
# as.numeric(log_prob_raw) | ||
# as.numeric(greta_log_probs$unadjusted) | ||
# | ||
# as.numeric(greta_log_probs$adjusted) - as.numeric(greta_log_probs$unadjusted) | ||
# as.numeric(log_prob_adjustment) | ||
# | ||
# # use the old (current one) | ||
# old_bijector <- tf_correlation_cholesky_bijector() | ||
# | ||
# old_log_prob_adjustment <- old_bijector$forward_log_det_jacobian(t(as.matrix(free_state))) | ||
# | ||
# as.numeric(old_log_prob_adjustment) | ||
# as.numeric(greta_log_probs$adjusted) - as.numeric(greta_log_probs$unadjusted) | ||
# | ||
# }) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,120 +1,133 @@ | ||
test_that("Log prob for lkj is correct", { | ||
# General process is: | ||
# 1. Simulate a lkj draw x with rlkj() | ||
# 2. Transform x to the equivalent free_state, using the bijector but running | ||
# it in reverse | ||
# 3. Run the log_prob() function on free_state | ||
# 4. Run dlkj(..., log = TRUE) on x, and compare with result of step 3 | ||
|
||
# 1. Simulate a lkj draw x with rlkj() ---------------------------------- | ||
set.seed(2024-10-31-1027) | ||
|
||
eta <- 1 | ||
dim <- 2 | ||
|
||
x <- rlkjcorr(n = 1, eta = eta, dimension = dim) | ||
chol_x <- chol(x) | ||
|
||
## 2. Transform x to the equivalent free_state, using the bijector but | ||
## running it in reverse ----------------------------------------------------- | ||
# we need to get a free state that we can plug into log prob. We know that | ||
# this free state matches the chol_x, so we can compare them later. | ||
|
||
new_lkj_bijector <- function(){ | ||
steps <- list( | ||
tfp$bijectors$Transpose(perm = 1:0), | ||
tfp$bijectors$CorrelationCholesky() | ||
) | ||
bijector <- tfp$bijectors$Chain(steps) | ||
bijector | ||
} | ||
|
||
a_bijector <- new_lkj_bijector() | ||
free_state <- a_bijector$inverse(fl(chol_x)) | ||
|
||
# TODO | ||
# get the greta log prob function | ||
# x_g <- lkj_correlation(eta = eta, dimension = dim) | ||
# m_g <- model(x_g) | ||
# greta_log_prob <- m_g$dag$generate_log_prob_function() | ||
# | ||
# free_state_mat <- t(as.matrix(free_state)) | ||
# log_probs <- greta_log_prob(free_state_mat) | ||
|
||
# compare this with a direct calculation of what it should be, using TF | ||
|
||
# lkj distribution for choleskies, matching what's in the greta code | ||
lkj_cholesky_dist <- tfp$distributions$CholeskyLKJ( | ||
dimension = as.integer(dim), | ||
concentration = fl(eta) | ||
) | ||
|
||
# we should expect this to have 1s on the diagonal... | ||
new_draws <- lkj_cholesky_dist$sample() | ||
|
||
new_draws | ||
|
||
chol2symm(t(new_draws)) | ||
chol2symm(new_draws) | ||
tf_chol2symm(new_draws) | ||
|
||
tf_chol2symm_new <- function(x) { | ||
tf$matmul(x, tf_transpose(x)) | ||
} | ||
|
||
tf_transpose(new_draws) |> | ||
tf_chol2symm_new() | ||
|
||
tf_transpose(new_draws) |> | ||
tf_chol2symm() | ||
|
||
|
||
chol2symm(tf_transpose(lkj_cholesky_dist$sample()) | ||
tf_chol2symm(lkj_cholesky_dist$sample()) | ||
tf_chol2symm(tf_tranpose(lkj_cholesky_dist$sample())) | ||
|
||
lkj_dist <- tfp$distributions$LKJ(dimension = as.integer(dim), | ||
concentration = fl(eta), | ||
input_output_cholesky = FALSE) | ||
|
||
# project forward from the free state | ||
chol_x_new <- a_bijector$forward(free_state) | ||
x_new <- tf_chol2symm(chol_x_new) | ||
# compute the log density of the transformed variable, and the adjustment for | ||
# change of support | ||
### njt log_density_unadjusted? | ||
log_density_raw <- lkj_cholesky_dist$log_prob(chol_x_new) | ||
log_density_raw_new <- lkj_dist$log_prob(x_new) | ||
adjustment <- a_bijector$forward_log_det_jacobian(free_state) | ||
log_density <- log_density_raw + adjustment | ||
log_density_new <- log_density_raw_new + adjustment | ||
|
||
log_density | ||
log_density_new | ||
# greta log probs should match alternative way to calculate these | ||
expect_equal(as.numeric(log_density), as.numeric(log_probs$adjusted)) | ||
# this is what the unadjusted version would look like, without accounting for | ||
# the density adjustment due to the bijector | ||
expect_equal(as.numeric(log_density_raw), as.numeric(log_probs$unadjusted)) | ||
|
||
# the difference between adjusted and unadjusted in greta is the same as | ||
# the bijector adjustment calculated here, so the bijector adjustment | ||
# calculated in greta must be correct, and it must be a problem with the | ||
# distribution, not the bijector | ||
expect_equal( | ||
as.numeric(adjustment), | ||
as.numeric(log_probs$adjusted - log_probs$unadjusted) | ||
) | ||
|
||
# check the TF distribution we use above against the R version (the same) | ||
expect_equal( | ||
as.numeric(log_density_raw), | ||
dlkj_correlation(x, eta = eta, log = TRUE, dimension = dim) | ||
) | ||
|
||
expect_equal( | ||
as.numeric(log_density_raw_new), | ||
dlkj_correlation(x, eta = eta, log = TRUE, dimension = dim) | ||
) | ||
|
||
}) | ||
# test_that("Log prob for lkj is correct", { | ||
# # General process is: | ||
# # 1. Simulate a lkj draw x with rlkj() | ||
# # 2. Transform x to the equivalent free_state, using the bijector but running | ||
# # it in reverse | ||
# # 3. Run the log_prob() function on free_state | ||
# # 4. Run dlkj(..., log = TRUE) on x, and compare with result of step 3 | ||
# | ||
# # 1. Simulate a lkj draw x with rlkj() ---------------------------------- | ||
# set.seed(2024-10-31-1027) | ||
# | ||
# eta <- 1 | ||
# dim <- 2 | ||
# | ||
# x <- rlkjcorr(n = 1, eta = eta, dimension = dim) | ||
# chol_x <- chol(x) | ||
# | ||
# ## 2. Transform x to the equivalent free_state, using the bijector but | ||
# ## running it in reverse ----------------------------------------------------- | ||
# # we need to get a free state that we can plug into log prob. We know that | ||
# # this free state matches the chol_x, so we can compare them later. | ||
# | ||
# # old bijector? | ||
# | ||
# new_lkj_bijector <- function(){ | ||
# steps <- list( | ||
# tfp$bijectors$Transpose(perm = 1:0), | ||
# tfp$bijectors$CorrelationCholesky() | ||
# ) | ||
# bijector <- tfp$bijectors$Chain(steps) | ||
# bijector | ||
# } | ||
# | ||
# a_bijector <- new_lkj_bijector() | ||
# free_state_new_bijector <- a_bijector$inverse(fl(chol_x)) | ||
# | ||
# # TODO | ||
# # these are the same! Do we care which bijector we use? | ||
# # We probably want the old one as it has a different log det jacobian function? | ||
# old_lkj_bijector <- tf_correlation_cholesky_bijector() | ||
# free_state_old_bijector <- old_lkj_bijector$inverse(fl(chol_x)) | ||
# expect_equal( | ||
# as.numeric(free_state_old_bijector), | ||
# as.numeric(free_state_new_bijector) | ||
# ) | ||
# | ||
# # comparing log_det_jacobian of old and new | ||
# free_state_new_bijector | ||
# free_state_old_bijector | ||
# t(as.matrix(free_state_old_bijector)) | ||
# free_state_old_bijector | ||
# old_lkj_bijector$forward_log_det_jacobian(t(as.matrix(free_state_old_bijector))) | ||
# a_bijector$forward_log_det_jacobian(free_state_old_bijector) | ||
# | ||
# # TODO | ||
# # get the greta log prob function | ||
# # x_g <- lkj_correlation(eta = eta, dimension = dim) | ||
# # m_g <- model(x_g) | ||
# # greta_log_prob <- m_g$dag$generate_log_prob_function() | ||
# # | ||
# # free_state_mat <- t(as.matrix(free_state)) | ||
# # log_probs <- greta_log_prob(free_state_mat) | ||
# | ||
# # compare this with a direct calculation of what it should be, using TF | ||
# | ||
# # lkj distribution for choleskies, matching what's in the greta code | ||
# lkj_cholesky_dist <- tfp$distributions$CholeskyLKJ( | ||
# dimension = as.integer(dim), | ||
# concentration = fl(eta) | ||
# ) | ||
# | ||
# # we should expect this to have 1s on the diagonal... | ||
# new_draws <- lkj_cholesky_dist$sample() | ||
# | ||
# new_draws | ||
# | ||
# tf$matmul(new_draws, new_draws, adjoint_b = TRUE) | ||
# | ||
# lkj_dist <- tfp$distributions$LKJ(dimension = as.integer(dim), | ||
# concentration = fl(eta), | ||
# input_output_cholesky = FALSE) | ||
# | ||
# lkj_dist$sample() | ||
# | ||
# # same value for both old and new bijector | ||
# free_state <- free_state_old_bijector | ||
# # project forward from the free state | ||
# chol_x_new <- a_bijector$forward(free_state) | ||
# chol_x_new | ||
# chol2symm(chol_x_new) | ||
# x_new <- tf$matmul(new_draws, new_draws, adjoint_b = TRUE) | ||
# tf$matmul(new_draws, new_draws, adjoint_a = TRUE) | ||
# tf$matmul(chol_x_new, chol_x_new,adjoint_a = TRUE) | ||
# x_new | ||
# x | ||
# | ||
# # compute the log density of the transformed variable, and the adjustment for | ||
# # change of support | ||
# log_density_raw <- lkj_cholesky_dist$log_prob(chol_x_new) | ||
# log_density_raw_new <- lkj_dist$log_prob(x_new) | ||
# adjustment <- a_bijector$forward_log_det_jacobian(free_state) | ||
# log_density <- log_density_raw + adjustment | ||
# log_density_new <- log_density_raw_new + adjustment | ||
# | ||
# log_density | ||
# | ||
# # greta log probs should match alternative way to calculate these | ||
# expect_equal(as.numeric(log_density), as.numeric(log_probs$adjusted)) | ||
# # this is what the unadjusted version would look like, without accounting for | ||
# # the density adjustment due to the bijector | ||
# expect_equal(as.numeric(log_density_raw), as.numeric(log_probs$unadjusted)) | ||
# | ||
# # the difference between adjusted and unadjusted in greta is the same as | ||
# # the bijector adjustment calculated here, so the bijector adjustment | ||
# # calculated in greta must be correct, and it must be a problem with the | ||
# # distribution, not the bijector | ||
# expect_equal( | ||
# as.numeric(adjustment), | ||
# as.numeric(log_probs$adjusted - log_probs$unadjusted) | ||
# ) | ||
# | ||
# # check the TF distribution we use above against the R version (the same) | ||
# expect_equal( | ||
# as.numeric(log_density_raw), | ||
# dlkj_correlation(x, eta = eta, log = TRUE, dimension = dim) | ||
# ) | ||
# | ||
# expect_equal( | ||
# as.numeric(log_density_raw_new), | ||
# dlkj_correlation(x, eta = eta, log = TRUE, dimension = dim) | ||
# ) | ||
# | ||
# }) |
Oops, something went wrong.