Skip to content

Commit

Permalink
Merge pull request #13 from tanaylab/feat@interactions
Browse files Browse the repository at this point in the history
Feat@interactions
  • Loading branch information
aviezerl authored Sep 30, 2024
2 parents c75e8e8 + 61fb78a commit 1902702
Show file tree
Hide file tree
Showing 24 changed files with 459 additions and 83 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
^man/figures/pipeline\.png$
^example_data$
^data-raw$
^vignettes/articles/iceqream\.R$
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ example_data
vignettes/articles/iceqream_cache/
data-raw/*.ipynb
data-raw/*.png
vignettes/articles/iceqream.R
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Authors@R: c(
)
Description: iceqream is a package for regressing accessibility from sequences using physical models of TF binding, It models TF effective concentrations as latent variables that activate or repress regulatory elements in a nonlinear fashion, with possible contribution from pairwise interactions and synergistic chromosomal domain effects. iceqream allows inference and synthesis of models explaining accessibility dynamics over an entire single cell manifold.
License: MIT + file LICENSE
Depends:
R (>= 2.10),
misha (>= 4.2.0)
Imports:
cli,
dplyr,
Expand Down Expand Up @@ -50,11 +53,9 @@ Remotes:
tanaylab/prego
Config/testthat/edition: 3
Encoding: UTF-8
Language: es
Language: en-US
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Config/Needs/website: rmarkdown
Depends:
R (>= 2.10)
LazyData: true
URL: https://tanaylab.github.io/iceqream/
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

export("%>%")
export(add_features_r2)
export(add_interactions)
export(compute_motif_directional_hits)
export(compute_motif_energies)
export(compute_tracks_q)
Expand Down Expand Up @@ -65,6 +66,7 @@ export(rescale)
export(split_traj_model_to_train_test)
export(traj_model_to_iq_feature_list)
export(traj_model_to_pbm_list)
export(traj_model_variable_response)
exportClasses(IQFeature)
exportClasses(IQFeatureGroup)
exportClasses(IQSeqFeature)
Expand Down
16 changes: 13 additions & 3 deletions R/TrajectoryModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
#' @slot normalization_intervals data.frame
#' A data frame containing the intervals used for energy normalization.
#'
#' @slot interactions matrix
#' A matrix of the interaction features.
#'
#'
#'
#' @exportClass TrajectoryModel
Expand All @@ -65,6 +68,7 @@ TrajectoryModel <- setClass(
normalization_intervals = "data.frame",
additional_features = "data.frame",
features_r2 = "numeric",
interactions = "matrix",
params = "list"
)
)
Expand All @@ -74,7 +78,12 @@ TrajectoryModel <- setClass(
#' @exportMethod show
setMethod("show", signature = "TrajectoryModel", definition = function(object) {
cli::cli({
cli::cli_text("{.cls TrajectoryModel} with {.val {length(object@motif_models)}} motifs and {.val {length(object@additional_features)}} additional features\n")
if (has_interactions(object)) {
cli::cli_text("{.cls TrajectoryModel} with {.val {length(object@motif_models)}} motifs, {.val {length(object@additional_features)}} additional features and {.val {n_interactions(object)}} interaction terms\n")
} else {
cli::cli_text("{.cls TrajectoryModel} with {.val {length(object@motif_models)}} motifs and {.val {length(object@additional_features)}} additional features\n")
}

cli::cli_text("\n")
cli::cli_text("Slots include:")
cli_ul(c("{.field @model}: A GLM model object. Number of non-zero coefficients: {.val {sum(object@model$beta[, 1] != 0)}}"))
Expand All @@ -88,8 +97,9 @@ setMethod("show", signature = "TrajectoryModel", definition = function(object) {
cli_ul(c("{.field @predicted_diff_score}: A numeric value representing the predicted difference score"))
cli_ul(c("{.field @initial_prego_models}: A list of prego models used in the initial phase of the algorithm ({.val {length(object@initial_prego_models)}} models)"))
cli_ul(c("{.field @peak_intervals}: A data frame containing the peak intervals ({.val {nrow(object@peak_intervals)}} elements)"))
if ("normalization_intervals" %in% slotNames(object)) { # here for backwards compatibility
cli_ul(c("{.field @normalization_intervals}: A data frame containing the intervals used for energy normalization ({.val {nrow(object@normalization_intervals)}} elements)"))
cli_ul(c("{.field @normalization_intervals}: A data frame containing the intervals used for energy normalization ({.val {nrow(object@normalization_intervals)}} elements)"))
if (has_interactions(object)) {
cli_ul(c("{.field @interactions}: A matrix of the interaction features ({.val {nrow(object@interactions)}}x{.val {ncol(object@interactions)}})"))
}
if (length(object@features_r2) > 0) {
cli_ul(c("{.field @features_r2}: A numeric vector of the added R^2 values for each feature ({.val {length(object@features_r2)}} elements)"))
Expand Down
4 changes: 2 additions & 2 deletions R/distill-motifs.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ distill_motifs <- function(features, target_number, glm_model, y, seqs, norm_seq
select(pos, A, C, G, T)
optimize_pwm <- FALSE
} else {
cli::cli_alert_warning("No current model found for {.val {x$feat}}. Distilling on a single motif")
cli::cli_alert_warning("No current model found for {.val {x$feat}}.")
}
} else {
cli_alert_info("Running {.field prego} on cluster {.val {x$feat}} (distilling {.val {n_feats}} motifs)")
cli_alert_info("Running {.field prego} on cluster {.val {x$feat}} (fusing {.val {n_feats}} motifs)")
}

res <- run_prego_on_clust_residuals(
Expand Down
6 changes: 3 additions & 3 deletions R/distill-multi-traj.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ distill_traj_model_multi <- function(traj_models, max_motif_num = NULL, min_diff
optimize_pwm <- FALSE
motif <- motif_models[[x$feat[1]]]$pssm
} else {
cli_alert_info("Running {.field prego} on cluster {.val {clust_name}}, distilling {.val {n_feats}} features")
cli_alert_info("Running {.field prego} on cluster {.val {clust_name}}, fusing {.val {n_feats}} features")
}


Expand Down Expand Up @@ -205,7 +205,7 @@ distill_traj_model_multi <- function(traj_models, max_motif_num = NULL, min_diff
) %>%
cli::cli_fmt()

cli::cli_alert_success("Finished distilling cluster {.val {clust_name}}")
cli::cli_alert_success("Finished fusing cluster {.val {clust_name}}")

return(prego::export_regression_model(prego_model))
}, .parallel = TRUE)
Expand Down Expand Up @@ -247,7 +247,7 @@ distill_traj_model_multi <- function(traj_models, max_motif_num = NULL, min_diff
compute_traj_list_stats(traj_models_full) %>% mutate(type = "full")
)

cli_alert_success("Finished distilling trajectory models")
cli_alert_success("Finished fusing trajectory models")
purrr::walk(names(traj_models), ~ {
cli_alert_info("Model {.field {.x}}: R^2: {.val {traj_models_new[[.x]]@params$stats$r2_all}} ({.val {length(traj_models_new[[.x]]@motif_models)}} motifs), before distillation: {.val {traj_models[[.x]]@params$stats$r2_all}} ({.val {length(traj_models[[.x]]@motif_models)}} motifs)")
})
Expand Down
43 changes: 0 additions & 43 deletions R/filter-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,46 +95,3 @@ filter_traj_model <- function(traj_model, r2_threshold = 0.0005, bits_threshold

return(traj_model_new)
}


filter_model_using_coefs <- function(X, coefs, diff_score, alpha, lambda, seed, full_model, n_motifs, ignore_variables = NULL) {
y <- norm01(diff_score)
variables <- coefs$variable
if (!is.null(ignore_variables)) {
variables <- variables[!(variables %in% ignore_variables)]
}
coefs_max <- coefs %>%
tibble::column_to_rownames("variable") %>%
.[variables, ] %>%
apply(1, max) %>%
sort(decreasing = TRUE)

vars_f <- names(coefs_max)[1:n_motifs]

X_f <- X[, grep(paste0("(", paste(c(vars_f, ignore_variables), collapse = "|"), ").+"), colnames(X))]
cli_alert_info("Number of features left: {.val {length(vars_f)}}")

model_f <- glmnet::glmnet(X_f, y, binomial(link = "logit"), alpha = alpha, lambda = lambda, parallel = FALSE, seed = seed)
pred_f <- logist(glmnet::predict.glmnet(model_f, newx = X_f, type = "link", s = lambda))[, 1]
pred_f <- norm01(pred_f)
pred_f <- rescale(pred_f, diff_score)
r2_f <- cor(pred_f, y)^2
cli_alert_info("R^2 after filtering: {.val {r2_f}}")

return(list(model = model_f, pred = pred_f, X = X_f, r2 = r2_f, vars = vars_f))
}


filter_traj_model_using_coefs <- function(traj_model, n_motifs) {
res <- filter_model_using_coefs(traj_model@model_features, traj_model@coefs, traj_model@diff_score, traj_model@params$alpha, traj_model@params$lambda, traj_model@params$seed, traj_model@model, ignore_variables = colnames(traj_model@additional_features), n_motifs = n_motifs)

traj_model@model <- res$model
traj_model@predicted_diff_score <- res$pred
traj_model@model_features <- res$X
traj_model@coefs <- get_model_coefs(res$model)
traj_model@normalized_energies <- traj_model@normalized_energies[, res$vars, drop = FALSE]

cli_alert_success("After filtering: Number of non-zero coefficients: {.val {sum(traj_model@model$beta != 0)}} (out of {.val {ncol(traj_model@model_features)}}). R^2: {.val {cor(traj_model@predicted_diff_score, norm01(traj_model@diff_score))^2}}")

return(traj_model)
}
14 changes: 14 additions & 0 deletions R/inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,20 @@ infer_trajectory_motifs <- function(traj_model, peak_intervals, atac_scores = NU
e_test <- cbind(e_test, additional_features)
}

if (has_interactions(traj_model)) {
ftv_inter <- feat_to_variable(traj_model, add_type = TRUE) %>%
filter(type == "interaction") %>%
distinct(variable, term1, term2)
cli::cli_alert_info("Computing {.val {nrow(ftv_inter)}} interaction terms")
interactions <- create_specifc_terms(e_test, ftv_inter)
interactions <- interactions[, colnames(traj_model@interactions), drop = FALSE]
e_test <- cbind(e_test, interactions)
}

e_test_logist <- create_logist_features(e_test)
e_test_logist <- e_test_logist[, colnames(traj_model@model_features), drop = FALSE]

cli::cli_alert_info("Inferring the model on {.val {nrow(e_test_logist)}} intervals")
pred <- predict_traj_model(traj_model, e_test_logist)
traj_model@predicted_diff_score <- c(traj_model@predicted_diff_score, pred)

Expand All @@ -70,6 +81,9 @@ infer_trajectory_motifs <- function(traj_model, peak_intervals, atac_scores = NU
traj_model@additional_features <- bind_rows(traj_model@additional_features, as.data.frame(additional_features))
}

if (has_interactions(traj_model)) {
traj_model@interactions <- rbind(traj_model@interactions, interactions)
}

return(traj_model)
}
Expand Down
143 changes: 143 additions & 0 deletions R/interactions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
has_interactions <- function(traj_model) {
n_interactions(traj_model) > 0
}

n_interactions <- function(traj_model) {
ncol(traj_model@interactions)
}

create_specifc_terms <- function(energies, terms) {
term1_matrix <- energies[, terms$term1]
term2_matrix <- energies[, terms$term2]
inter <- term1_matrix * term2_matrix
inter <- t(t(inter) / apply(inter, 2, max, na.rm = TRUE))
inter <- apply(inter, 2, norm01) * 1
colnames(inter) <- terms$variable
return(inter)
}


create_interaction_terms <- function(energies, motif_feats = NULL, add_feats = NULL, additional_features = NULL, max_motif_n = NULL, max_add_n = NULL) {
create_interactions <- function(features, data, max_n) {
if (is.null(features) || is.null(data)) {
return(NULL)
}

features <- head(features, n = max_n %||% length(features))

interactions <- purrr::map_dfc(features, ~ {
inter <- energies[, setdiff(colnames(energies), .x)] * data[, .x]
inter <- t(t(inter) / apply(inter, 2, max, na.rm = TRUE))
colnames(inter) <- paste0(.x, ":", colnames(inter))
inter
})

interactions <- apply(interactions, 2, norm01) * 1
interactions
}


add_inter <- create_interactions(add_feats, additional_features, max_add_n)

if (!is.null(add_inter)) {
cli::cli_alert_info("Created {.val {ncol(add_inter)}} interactions between additional features and motif features.")
}


motif_inter <- create_interactions(motif_feats, energies, max_motif_n)
if (!is.null(motif_inter)) {
cli::cli_alert_info("Created {.val {ncol(motif_inter)}} interactions between motif features.")
}

interactions <- cbind(motif_inter, add_inter)
if (!is.null(interactions)) {
if (!is.null(rownames(energies))) {
rownames(interactions) <- rownames(energies)
}
cli::cli_alert_info("Created {.val {ncol(interactions)}} interactions in total.")
}

return(interactions)
}

get_significant_interactions <- function(
energies, y, interaction_threshold, max_motif_n = NULL, max_add_n = NULL,
additional_features = NULL, lambda = 1e-5, alpha = 1, seed = 60427,
ignore_feats = c("TT", "CT", "GT", "AT", "TC", "CC", "GC", "AC", "TG", "CG", "GG", "AG", "TA", "CA", "GA", "AA")) {
glm_model_lin <- glmnet::glmnet(as.matrix(energies), y, binomial(link = "logit"), alpha = alpha, lambda = lambda, seed = seed)

feats_all <- abs(stats::coef(glm_model_lin)[-1])
names(feats_all) <- rownames(stats::coef(glm_model_lin))[-1]
sig_feats <- names(feats_all)[feats_all > interaction_threshold]
sig_feats <- setdiff(sig_feats, ignore_feats)

if (length(sig_feats) == 0) {
cli::cli_alert_warning("No significant features to consider for interactions.")
return(NULL)
}

add_feats <- intersect(sig_feats, colnames(additional_features))
motif_feats <- setdiff(sig_feats, add_feats)

cli::cli_alert_info("# of significant features to consider for interactions: {.val {length(sig_feats)}} (out of {.val {ncol(energies)}}) above the threshold of {.val {interaction_threshold}}. Of these, {.val {length(motif_feats)}} are motif features and {.val {length(add_feats)}} are additional features.")

if (!is.null(additional_features)) {
# remove the features from energies
energies <- energies[, setdiff(colnames(energies), colnames(additional_features))]

# remove the ignored features
energies <- energies[, setdiff(colnames(energies), ignore_feats)]
}

create_interaction_terms(energies,
motif_feats = motif_feats, add_feats = add_feats,
additional_features = additional_features, max_motif_n = max_motif_n, max_add_n = max_add_n
)
}

#' Add interactions to a trajectory model
#'
#' This function adds significant interactions to a given trajectory model if they do not already exist.
#' It identifies significant interactions based on the provided threshold and updates the model features
#' with logistic features derived from these interactions. The trajectory model is then re-learned with
#' the new features.
#'
#' @inheritParams regress_trajectory_motifs
#'
#' @return The updated trajectory model with added interactions.
#' @export
add_interactions <- function(traj_model, interaction_threshold = 0.001, max_motif_n = NULL, max_add_n = NULL, lambda = 1e-5, alpha = 1, seed = 60427) {
if (!has_interactions(traj_model)) {
cli::cli_alert("Adding interactions")
interactions <- get_significant_interactions(
cbind(traj_model@normalized_energies, traj_model@additional_features), norm01(traj_model@diff_score), interaction_threshold,
max_motif_n = max_motif_n, max_add_n = max_add_n,
additional_features = traj_model@additional_features, lambda = lambda, alpha = alpha, seed = seed
)

if (!is.null(interactions)) {
traj_model@interactions <- interactions
}

logist_inter <- create_logist_features(interactions)
traj_model@model_features <- cbind(traj_model@model_features, logist_inter)

cli::cli_alert_info("Re-learning the model with the new interactions. Number of features: {.val {ncol(traj_model@model_features)}}")
cli::cli_alert_info("R^2 all before learning: {.val {cor(traj_model@diff_score, traj_model@predicted_diff_score)^2}}")
if (traj_model_has_test(traj_model)) {
cli::cli_alert_info("R^2 train before learning: {.val {cor(traj_model@diff_score[traj_model@type == 'train'], traj_model@predicted_diff_score[traj_model@type == 'train'])^2}}")
cli::cli_alert_info("R^2 test before learning: {.val {cor(traj_model@diff_score[traj_model@type == 'test'], traj_model@predicted_diff_score[traj_model@type == 'test'])^2}}")
}

traj_model <- relearn_traj_model(traj_model, new_energies = FALSE, new_logist = FALSE, use_additional_features = TRUE, use_motifs = TRUE, verbose = FALSE)
cli::cli_alert_info("R^2 all after learning: {.val {cor(traj_model@diff_score, traj_model@predicted_diff_score)^2}}")
if (traj_model_has_test(traj_model)) {
cli::cli_alert_info("R^2 train after learning: {.val {cor(traj_model@diff_score[traj_model@type == 'train'], traj_model@predicted_diff_score[traj_model@type == 'train'])^2}}")
cli::cli_alert_info("R^2 test after learning: {.val {cor(traj_model@diff_score[traj_model@type == 'test'], traj_model@predicted_diff_score[traj_model@type == 'test'])^2}}")
}
} else {
cli::cli_alert_warning("Interactions already exist.")
}

return(traj_model)
}
Loading

0 comments on commit 1902702

Please sign in to comment.