Skip to content

Commit

Permalink
mask additional attributes in PPI site prediction
Browse files: browse the repository at this point in the history
  • Loading branch information
Jamasb committed Nov 22, 2023
1 parent b30d40f commit 01caa96
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions proteinworkshop/tasks/ppi_site_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def __init__(self, radius: float = 3.5, ca_only: bool = True) -> None:
self.fill_value = 1e-5
self.ca_only = ca_only
charstr: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
self.chain_map: Dict[str, int] = {
charstr[i]: i for i in range(len(charstr))
}
self.chain_map: Dict[str, int] = {charstr[i]: i for i in range(len(charstr))}

def __call__(self, data: Union[Protein, Data]):
# Map the chain labels to integers
Expand Down Expand Up @@ -61,9 +59,7 @@ def __call__(self, data: Union[Protein, Data]):
# Unwrap the coordinates
other_chains = other_chains.reshape(-1, 3)
# Remove any rows whose coordinates all equal the fill value (1e-5)
other_chains = other_chains[
~torch.all(other_chains == self.fill_value, dim=1)
]
other_chains = other_chains[~torch.all(other_chains == self.fill_value, dim=1)]

# Create a KDTree
# If Ca only, we only see if the interacting chains are within the
Expand Down Expand Up @@ -101,6 +97,12 @@ def __call__(self, data: Union[Protein, Data]):
if data.x is not None:
data.x = data.x[mask]

if data.seq_pos is not None:
data.seq_pos = data.seq_pos[mask]

if data.amino_acid_one_hot is not None:
data.amino_acid_one_hot = data.amino_acid_one_hot[mask]

return data


Expand Down

0 comments on commit 01caa96

Please sign in to comment.