From ac8ed5902a2c96019ea1137b5138d48017fabf4e Mon Sep 17 00:00:00 2001 From: Bin Du Date: Tue, 28 Nov 2023 11:05:49 -0800 Subject: [PATCH] `SimpleSentimentModel` should inherit from `BatchedModel`. This class implements the abstract method `predict_minibatch` in `BatchedModel`. Addressing https://github.com/PAIR-code/lit/issues/1361. PiperOrigin-RevId: 586041904 --- lit_nlp/examples/simple_pytorch_demo.py | 4 ++-- lit_nlp/examples/sst_pytorch_demo.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lit_nlp/examples/simple_pytorch_demo.py b/lit_nlp/examples/simple_pytorch_demo.py index 2a925895..cde642db 100644 --- a/lit_nlp/examples/simple_pytorch_demo.py +++ b/lit_nlp/examples/simple_pytorch_demo.py @@ -79,7 +79,7 @@ def _from_pretrained(cls, *args, **kw): return cls.from_pretrained(*args, from_tf=True, **kw) -class SimpleSentimentModel(lit_model.Model): +class SimpleSentimentModel(lit_model.BatchedModel): """Simple sentiment analysis model.""" LABELS = ["0", "1"] # negative, positive @@ -103,7 +103,7 @@ def __init__(self, model_name_or_path): ## # LIT API implementation def max_minibatch_size(self): - # This tells lit_model.Model.predict() how to batch inputs to + # This tells lit_model.BatchedModel.predict() how to batch inputs to # predict_minibatch(). # Alternately, you can just override predict() and handle batching yourself. return 32 diff --git a/lit_nlp/examples/sst_pytorch_demo.py b/lit_nlp/examples/sst_pytorch_demo.py index fd341b37..dede8a61 100644 --- a/lit_nlp/examples/sst_pytorch_demo.py +++ b/lit_nlp/examples/sst_pytorch_demo.py @@ -70,7 +70,7 @@ def _from_pretrained(cls, *args, **kw): return cls.from_pretrained(*args, from_tf=True, **kw) -class SimpleSentimentModel(lit_model.Model): +class SimpleSentimentModel(lit_model.BatchedModel): """Simple sentiment analysis model.""" LABELS = ["0", "1"] # negative, positive @@ -95,7 +95,7 @@ def __init__(self, model_name_or_path): ## # LIT API implementation def max_minibatch_size(self): - # This tells lit_model.Model.predict() how to batch inputs to + # This tells lit_model.BatchedModel.predict() how to batch inputs to # predict_minibatch(). # Alternately, you can just override predict() and handle batching yourself. return 32