From 57af09129f16669de666d13c15f9f7c85f391ba5 Mon Sep 17 00:00:00 2001 From: Marouane <42834703+MarouaneMaatouk@users.noreply.github.com> Date: Thu, 25 Jan 2024 16:50:40 +0100 Subject: [PATCH] Add XML loader (#887) * Add XML loader * Fix unit test, llama_hub link * Fix llama_hub link in readme --- llama_hub/file/xml/README.md | 19 ++++++ llama_hub/file/xml/__init__.py | 6 ++ llama_hub/file/xml/base.py | 95 +++++++++++++++++++++++++++++ llama_hub/file/xml/requirements.txt | 0 llama_hub/library.json | 4 ++ tests/file/json/test_json.py | 1 + tests/file/xml/test_xml.py | 64 +++++++++++++++++++ 7 files changed, 189 insertions(+) create mode 100644 llama_hub/file/xml/README.md create mode 100644 llama_hub/file/xml/__init__.py create mode 100644 llama_hub/file/xml/base.py create mode 100644 llama_hub/file/xml/requirements.txt create mode 100644 tests/file/xml/test_xml.py diff --git a/llama_hub/file/xml/README.md b/llama_hub/file/xml/README.md new file mode 100644 index 0000000000..c0d0d23608 --- /dev/null +++ b/llama_hub/file/xml/README.md @@ -0,0 +1,19 @@ +# XML Loader + +This loader extracts the text from a local XML file. A single local file is passed in each time you call `load_data`. + +## Usage + +To use this loader, you need to pass in a `Path` to a local file. + +```python +from pathlib import Path +from llama_index import download_loader + +XMLReader = download_loader("XMLReader") + +loader = XMLReader() +documents = loader.load_data(file=Path('../example.xml')) +``` + +This loader is designed to be used as a way to load data into [LlamaIndex](https://github.com/run-llama/llama_index/tree/main/llama_index) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent. See [here](https://github.com/run-llama/llama-hub/tree/main/llama_hub) for examples. 
def _get_leaf_nodes_up_to_level(root: ET.Element, level: int) -> List[ET.Element]:
    """Collect the nodes found ``level`` levels below ``root``, plus shallower leaves.

    The tree is walked from ``root`` (depth 0). A node is collected when it
    sits exactly at depth ``level`` or when it has no children (a leaf reached
    before that depth). Nodes are returned in document order.

    Args:
        root (ET.Element): XML root element to start from.
        level (int): Number of levels to descend. 0 returns ``[root]``.
            ``None`` is treated as 0.

    Returns:
        List[ET.Element]: The collected nodes, in document order.

    Raises:
        ValueError: If ``level`` is negative. Previously a negative level
            silently produced an empty list for any root with children.
    """
    if level is None:
        level = 0
    if level < 0:
        raise ValueError(f"level must be a non-negative integer, got {level}")

    nodes: List[ET.Element] = []
    # Explicit stack instead of recursion; children are pushed in reverse so
    # they are popped (and therefore collected) in document order.
    pending = [(root, 0)]
    while pending:
        node, depth = pending.pop()
        # len() is used on purpose: the truth value of an Element is not a
        # reliable "has children" test.
        if len(node) == 0 or depth == level:
            nodes.append(node)
        else:
            pending.extend((child, depth + 1) for child in reversed(list(node)))
    return nodes
class XMLReader(BaseReader):
    """XML reader.

    Parses a local XML file and emits one Document per selected subtree. The
    subtrees are chosen by descending ``tree_level_split`` levels from the
    root; each Document's text is the serialized XML of its subtree.

    Args:
        tree_level_split (Optional[int]): Level of the XML tree at which the
            file is split into documents. The default, 0, is the root, so the
            whole file becomes a single document. ``None`` is treated as 0.

    """

    def __init__(self, tree_level_split: Optional[int] = 0) -> None:
        """Initialize with arguments."""
        super().__init__()
        # The annotation permits None, but the traversal compares this value
        # with integer depths, so fall back to the documented default.
        self.tree_level_split = 0 if tree_level_split is None else tree_level_split

    def _parse_xmlelt_to_document(
        self, root: ET.Element, extra_info: Optional[Dict] = None
    ) -> List[Document]:
        """Convert an XML element into a list of Documents.

        Args:
            root (ET.Element): The XML element to be converted.
            extra_info (Optional[Dict]): Additional information attached to
                every document. Default is None.

        Returns:
            List[Document]: One document per node selected at
                ``self.tree_level_split``.
        """
        nodes = _get_leaf_nodes_up_to_level(root, self.tree_level_split)
        documents = []
        for node in nodes:
            # encoding="unicode" returns a str without an XML declaration, so
            # no bytes round-trip or declaration-stripping regex is needed.
            content = ET.tostring(node, encoding="unicode").strip()
            # Pass each document its own copy so the caller's dict is not
            # shared between all documents.
            metadata = dict(extra_info) if extra_info else {}
            documents.append(Document(text=content, extra_info=metadata))

        return documents

    def load_data(
        self,
        file: Path,
        extra_info: Optional[Dict] = None,
    ) -> List[Document]:
        """Load data from the input file.

        NOTE: the file is parsed with ``xml.etree.ElementTree``, which is not
        hardened against maliciously constructed XML; only load trusted files.

        Args:
            file (Path): Path to the input file. Other path-like values are
                converted with ``Path(file)``.
            extra_info (Optional[Dict]): Additional information. Default is None.

        Returns:
            List[Document]: List of documents.

        Raises:
            xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
        """
        if not isinstance(file, Path):
            file = Path(file)

        tree = ET.parse(file)
        return self._parse_xmlelt_to_document(tree.getroot(), extra_info)
# Sample XML data for testing.
# NOTE(review): the element markup of this sample was lost when the patch text
# was extracted (only the text values survive). The tag and attribute names
# below are reconstructed so that the assertions in this module hold --
# confirm against the upstream test file.
SAMPLE_XML = """
<data>
    <item type="fruit">
        <name>Apple</name>
        <color>Red</color>
        <price>1.20</price>
    </item>
    <item type="vegetable">
        <name>Carrot</name>
        <color>Orange</color>
        <price>0.50</price>
    </item>
    <item type="fruit">
        <name>Banana</name>
        <color>Yellow</color>
        <price>0.30</price>
    </item>
    <company>
        <name>Fresh Produce Ltd.</name>
        <address>
            <street>123 Green Lane</street>
            <city>Garden City</city>
            <state>Harvest</state>
            <zip>54321</zip>
        </address>
    </company>
</data>
"""


# Fixture to create a temporary XML file
@pytest.fixture
def xml_file(tmp_path):
    """Write SAMPLE_XML to a temporary file and return its path."""
    file = tmp_path / "test.xml"
    # Explicit encoding so the bytes on disk do not depend on the locale.
    file.write_text(SAMPLE_XML, encoding="utf-8")
    return file


def test_xml_reader_init():
    """The constructor stores the requested split level."""
    reader = XMLReader(tree_level_split=2)
    assert reader.tree_level_split == 2


def test_parse_xml_to_document():
    """Splitting at level 1 yields one document per child of the root."""
    reader = XMLReader(1)
    root = ET.fromstring(SAMPLE_XML)
    documents = reader._parse_xmlelt_to_document(root)
    assert "Fresh Produce Ltd." in documents[-1].text
    assert "fruit" in documents[0].text


def test_load_data_xml(xml_file):
    """The default split level keeps the whole file as a single document."""
    reader = XMLReader()

    documents = reader.load_data(xml_file)
    assert len(documents) == 1
    assert "Apple" in documents[0].text
    assert "Garden City" in documents[0].text