diff --git a/backend/app/app/source/AAAI.py b/backend/app/app/source/AAAI.py index 74ba44c..0a5cf6d 100644 --- a/backend/app/app/source/AAAI.py +++ b/backend/app/app/source/AAAI.py @@ -2,7 +2,7 @@ from scrapy.http import HtmlResponse -from app.source.base import PaperRequestsTask +from app.source.base import PaperRequestsTask, PaperType import xml.dom.minidom @@ -43,7 +43,7 @@ def parse(response: HtmlResponse): return item @staticmethod - def post_parse(item: dict[str, Any]) -> dict[str, Any]: + def post_parse(item: PaperType) -> PaperType: if item["authors"] is not None: for i, author in enumerate(item["authors"]): item["authors"][i] = author.strip() diff --git a/backend/app/app/source/Arxiv.py b/backend/app/app/source/Arxiv.py index adb9246..223fd50 100644 --- a/backend/app/app/source/Arxiv.py +++ b/backend/app/app/source/Arxiv.py @@ -3,7 +3,7 @@ from scrapy.http import HtmlResponse -from app.source.base import RSSTask +from app.source.base import PaperType, RSSTask CATEGORY_MAP = { "cs.AI": "Artificial Intelligence", @@ -48,7 +48,7 @@ def parse(entry: dict) -> dict[str, Any]: "abstract": entry["summary"], } - def post_parse(self, entry: dict[str, Any]) -> dict[str, Any]: + def post_parse(self, entry: PaperType) -> PaperType: category = re.findall(r"\[(.*?)\]", entry["title"])[0] entry["title"] = entry["title"].split("(", 1)[0] entry["authors"] = ( diff --git a/backend/app/app/source/NIPS.py b/backend/app/app/source/NIPS.py index d6f4f7b..918407e 100644 --- a/backend/app/app/source/NIPS.py +++ b/backend/app/app/source/NIPS.py @@ -2,7 +2,7 @@ from scrapy.http import HtmlResponse -from app.source.base import PaperRequestsTask, openreview_url +from app.source.base import PaperRequestsTask, PaperType, openreview_url class NIPS(PaperRequestsTask): @@ -34,9 +34,9 @@ def parse(response: HtmlResponse): return item @staticmethod - def post_parse(item: dict[str, Any]) -> dict[str, Any]: + def post_parse(item: PaperType) -> PaperType: if item["authors"] is not None: - item["authors"] = item["authors"].split(" · ") + item["authors"] = item["authors"].split(" · ") # type: ignore if item["authors"] is not None: for i, author in enumerate(item["authors"]): item["authors"][i] = author.strip() diff --git a/backend/app/app/source/base.py b/backend/app/app/source/base.py index fb535a4..f822ccb 100644 --- a/backend/app/app/source/base.py +++ b/backend/app/app/source/base.py @@ -1,6 +1,6 @@ import logging from datetime import datetime -from typing import Any +from typing import Any, Optional, TypedDict import feedparser import requests @@ -12,6 +12,15 @@ from app.models import CrawledItem, Item +class PaperType(TypedDict): + title: str + abstract: Optional[str] + url: str + authors: Optional[list[str]] + category: Optional[list[str]] + keywords: Optional[list[str]] + + def openreview_url(urls): for url in urls[::-1]: if "openreview" in url: @@ -46,7 +55,7 @@ def get_urls(cls) -> list[str]: return cls.parse_urls(response) @staticmethod - def parse(response: HtmlResponse) -> dict[str, str]: + def parse(response: HtmlResponse) -> PaperType: # you should return dict with fields: # title, abstract, url raise NotImplementedError @@ -81,7 +90,7 @@ def save(self, data: list[tuple[str, dict[str, Any]]]) -> None: db.commit() @staticmethod - def post_parse(data: dict[str, Any]) -> dict[str, Any]: + def post_parse(data: PaperType) -> PaperType: # you can do some post processing here return data @@ -118,13 +127,13 @@ def db(self): return Session(engine) @staticmethod - def parse(entry) -> dict[str, Any]: + def parse(entry) -> PaperType: raise NotImplementedError - def post_parse(self, entry: dict[str, Any]) -> dict[str, Any]: + def post_parse(self, entry: PaperType) -> PaperType: return entry - def save(self, data: list[dict[str, Any]]) -> None: + def save(self, data: list[PaperType]) -> None: with self.db as db: # update Item table if exists for item in data: