Skip to content
This repository has been archived by the owner on Aug 1, 2024. It is now read-only.

Fix MSA reading in predict.py #538

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions examples/variant-prediction/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,26 @@ def remove_insertions(sequence: str) -> str:
return sequence.translate(translation)


def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]:
def process_sequence(sequence: str, lowercase_type: str = "nonfocus") -> str:
    """Normalise one MSA row according to how its lowercase letters are read.

    Args:
        sequence: A single aligned sequence as read from the MSA file.
        lowercase_type: How lowercase characters are interpreted.
            ``"nonfocus"`` keeps every column and upper-cases the whole
            sequence; ``"insertion"`` hands the sequence to
            ``remove_insertions``.

    Returns:
        The processed sequence.

    Raises:
        ValueError: If ``lowercase_type`` is not one of the two accepted values.
    """
    if lowercase_type == "nonfocus":
        return sequence.upper()
    if lowercase_type == "insertion":
        return remove_insertions(sequence)
    # The message must spell the accepted value exactly ("insertion", not
    # "insert"), otherwise following it leads straight back to this error.
    raise ValueError(
        f"lowercase_type should be nonfocus or insertion but got {lowercase_type}"
    )


def read_msa(filename: str, nseq: int, lowercase_type: str = "nonfocus") -> List[Tuple[str, str]]:
    """Read the first nseq sequences from an MSA file.

    The file is parsed with the SeqIO fasta parser and every sequence is
    passed through process_sequence, so lowercase characters are handled
    according to lowercase_type.

    If lowercase_type is 'insertion', the input file must be in a2m/a3m format
    for remove_insertions to work properly.

    If lowercase_type is 'nonfocus', all sequences should have the same length.

    Args:
        filename: Path to the MSA file.
        nseq: Maximum number of records taken from the start of the file.
        lowercase_type: Either 'nonfocus' or 'insertion'; forwarded unchanged
            to process_sequence, which rejects any other value.

    Returns:
        A list of (description, processed sequence) tuples, in file order.
    """
    # islice stops after nseq records, so the rest of the file is never parsed.
    records = itertools.islice(SeqIO.parse(filename, "fasta"), nseq)
    return [
        (record.description, process_sequence(str(record.seq), lowercase_type))
        for record in records
    ]
Expand Down Expand Up @@ -99,6 +111,13 @@ def create_parser():
default=400,
help="number of sequences to select from the start of the MSA"
)
parser.add_argument(
"--lowercase-type", choices=["insertion", "nonfocus"], default="nonfocus",
help=(
"How lowercase amino acids in MSA should be interpreted: "
"nonfocus for EVMutation/Gym-style, insertion for a2m/a3m."
)
)
# fmt: on
parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available")
return parser
Expand Down Expand Up @@ -159,7 +178,7 @@ def main(args):
batch_converter = alphabet.get_batch_converter()

if isinstance(model, MSATransformer):
data = [read_msa(args.msa_path, args.msa_samples)]
data = [read_msa(args.msa_path, args.msa_samples, lowercase_type=args.lowercase_type)]
assert (
args.scoring_strategy == "masked-marginals"
), "MSA Transformer only supports masked marginal strategy"
Expand Down