KeyError: 'decontXcounts' #15

Open
namratabhattacharya opened this issue Mar 9, 2023 · 6 comments
Comments

@namratabhattacharya

Upon running preprocess.py, I get the following error:

python scripts/preprocess.py --ts_path input.h5ad --out_path input

Loaded data.
Subsetted to CCLE genes.
Traceback (most recent call last):
File "/home/namratab/namratab/TransferLearning/Comparison/exceiver/scripts/preprocess.py", line 150, in
sys.exit(main(args))
File "/home/namratab/namratab/TransferLearning/Comparison/exceiver/scripts/preprocess.py", line 53, in main
tsdata = ad.AnnData(tsdata.layers[args.count_var], obs=tsdata.obs, var=tsdata.var, uns=tsdata.uns)
File "/home/namratab/.local/lib/python3.9/site-packages/anndata/_core/aligned_mapping.py", line 113, in getitem
_subset(self.parent_mapping[key], self.subset_idx),
File "/home/namratab/.local/lib/python3.9/site-packages/anndata/_core/aligned_mapping.py", line 148, in getitem
return self._data[key]
KeyError: 'decontXcounts'

How do I resolve this?

@khanu263
Collaborator

Hi @namratabhattacharya -- you need to use the --count_var command-line argument to specify which layer of your AnnData object to use. We worked a lot with the Tabula Sapiens dataset, so the script defaults to the decontXcounts layer. If you know the analogous layer in your AnnData object, add --count_var [layer name] when running the script, as shown below. If you don't need to select a specific layer, you can just comment out that line in preprocess.py.
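
For example, using a hypothetical layer name raw_counts (substitute whatever layer your AnnData object actually contains):

python scripts/preprocess.py --ts_path input.h5ad --out_path input --count_var raw_counts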

@namratabhattacharya
Author

I have run the model successfully. However, I cannot retrieve the Leiden clusters. What steps are needed for that after the model is trained?

@khanu263
Collaborator

If you're referring to the sample embeddings (Figure 1c in the preprint), these come from the self-attention step in the model. In other words, you should go through the forward method in models.py until line 205, at which point each cell will be represented as a 128x4 vector. To do the clustering, we simply reshaped these into 512x1 vectors and created an AnnData object, to which we could apply scanpy's clustering functions.
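
As a rough sketch of that reshape-and-cluster step (assuming latent is the tensor taken from the forward method at that point, with one 128x4 representation per cell, and that anndata / scanpy are installed):

import anndata as ad
import scanpy as sc

# flatten each cell's 128x4 representation into a single 512-dimensional vector
sample_emb = latent.detach().cpu().numpy().reshape(latent.shape[0], -1)
emb_adata = ad.AnnData(X=sample_emb)
sc.pp.neighbors(emb_adata, use_rep="X")  # kNN graph built directly on the embeddings
sc.tl.leiden(emb_adata)                  # Leiden labels end up in emb_adata.obs["leiden"]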

If you're referring to the gene embeddings (Figure 2a), these are stored within the gene_emb variable of the Exceiver class. You can query that nn.Embedding object (excluding the padding index) to retrieve a NumPy array and similarly use AnnData / scanpy to perform the clustering.

# retrieve gene embeddings
embeddings = model.gene_emb(torch.arange(model.gene_emb.num_embeddings - 2)).detach().cpu().numpy()

@namratabhattacharya
Author

I want to find the sample embeddings mentioned in Fig 1c in the pre-print. Is it possible to fetch the embeddings without making changes in the code?

@khanu263
Collaborator

Unfortunately, we don't currently have a simple function that will give you the sample embeddings for a given input in one line. That's a good idea and something we should add, though! For now you'll need to partially go through the forward method as I described above.

@AidenSb

AidenSb commented Jun 23, 2023

@namratabhattacharya
you can add this function to the model and use it to get your sample embeddings:

def get_latent(self, dataloader):
    """Return the post-self-attention latent representation for every sample in the dataloader."""
    device = 'cpu'
    latents = []
    self.eval()
    with torch.no_grad():
        for batch in dataloader:
            # unpack the batch and move the tensors to the target device
            gene_ids, gene_vals, _, _, key_padding_mask = [x.to(device) for x in batch]
            # repeat the learned query embedding for every sample in the batch
            input_query = self.query_emb.repeat(len(gene_ids), 1, 1)
            # cross-attention over the input genes, then the self-attention step
            latent, _ = self.encoder_attn_step(gene_ids, gene_vals, input_query, key_padding_mask)
            latent = self.process_self_attn(latent)
            latents.append(latent.detach())
    all_latents = torch.cat(latents, dim=0)
    return all_latents.cpu().numpy()
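
Usage would then look roughly like this (a sketch -- model is your trained Exceiver instance with get_latent added, and dataloader is assumed to yield batches in the tuple order unpacked above):

import anndata as ad

latents = model.get_latent(dataloader)   # NumPy array, one row per cell
latent_adata = ad.AnnData(X=latents)     # then cluster with scanpy as described earlier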
