Skip to content

Commit

Permalink
Feat/docling-support (#1763)
Browse files Browse the repository at this point in the history
* added tool for docling support

* docling support installation

* use file_paths instead of file_path

* fix import

* organized imports

* run_type docs

* needs to be list

* fixed logic

* logged but file_path is backwards compatible

* use file_paths instead of file_path 2

* added test for multiple sources for file_paths

* fix run-types

* enabling local files to work and type cleanup

* linted

* fix test and types

* fixed run types

* fix types

* renamed to CrewDoclingSource

* linted

* added docs

* resolve conflicts

---------

Co-authored-by: Brandon Hancock (bhancock_ai) <[email protected]>
Co-authored-by: Brandon Hancock <[email protected]>
  • Loading branch information
3 people authored Dec 23, 2024
1 parent c887ff1 commit b3185ad
Show file tree
Hide file tree
Showing 8 changed files with 1,166 additions and 35 deletions.
49 changes: 49 additions & 0 deletions docs/concepts/knowledge.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,55 @@ crew = Crew(
result = crew.kickoff(inputs={"question": "What city does John live in and how old is he?"})
```


Here's another example with the `CrewDoclingSource`
```python Code
from crewai import LLM, Agent, Crew, Process, Task
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource

# Create a knowledge source
content_source = CrewDoclingSource(
file_paths=[
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking",
"https://lilianweng.github.io/posts/2024-07-07-hallucination",
],
)

# Create an LLM with a temperature of 0 to ensure deterministic outputs
llm = LLM(model="gpt-4o-mini", temperature=0)

# Create an agent with the knowledge store
agent = Agent(
role="About papers",
goal="You know everything about the papers.",
backstory="""You are a master at understanding papers and their content.""",
verbose=True,
allow_delegation=False,
llm=llm,
)
task = Task(
description="Answer the following questions about the papers: {question}",
expected_output="An answer to the question.",
agent=agent,
)

crew = Crew(
agents=[agent],
tasks=[task],
verbose=True,
process=Process.sequential,
knowledge_sources=[
content_source
], # Enable knowledge by adding the sources here. You can also add more sources to the sources list.
)

result = crew.kickoff(
inputs={
"question": "What is the reward hacking paper about? Be sure to provide sources."
}
)
```

## Knowledge Configuration

### Chunking Configuration
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ openpyxl = [
"openpyxl>=3.1.5",
]
mem0 = ["mem0ai>=0.1.29"]
docling = [
"docling>=2.12.0",
]

[tool.uv]
dev-dependencies = [
Expand Down
53 changes: 40 additions & 13 deletions src/crewai/knowledge/source/base_file_knowledge_source.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

from pydantic import Field
from pydantic import Field, field_validator

from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
Expand All @@ -14,25 +14,36 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""

_logger: Logger = Logger(verbose=True)
file_path: Union[Path, List[Path], str, List[str]] = Field(
..., description="The path to the file"
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
safe_file_paths: List[Path] = Field(default_factory=list)

@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, values):
"""Validate that at least one of file_path or file_paths is provided."""
if v is None and ("file_path" not in values or values.get("file_path") is None):
raise ValueError("Either file_path or file_paths must be provided")
return v

def model_post_init(self, _):
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_paths()
self.validate_content()
self.content = self.load_content()

@abstractmethod
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
pass

def validate_paths(self):
def validate_content(self):
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
Expand All @@ -59,13 +70,29 @@ def convert_to_path(self, path: Union[Path, str]) -> Path:

def _process_file_paths(self) -> List[Path]:
"""Convert file_path to a list of Path objects."""
paths = (
[self.file_path]
if isinstance(self.file_path, (str, Path))
else self.file_path
)

if not isinstance(paths, list):
raise ValueError("file_path must be a Path, str, or a list of these types")
# Check if old file_path is being used
if hasattr(self, "file_path") and self.file_path is not None:
self._logger.log(
"warning",
"The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
color="yellow",
)
paths = (
[self.file_path]
if isinstance(self.file_path, (str, Path))
else self.file_path
)
else:
if self.file_paths is None:
raise ValueError("Your source must be provided with a file_paths: []")
elif isinstance(self.file_paths, list) and len(self.file_paths) == 0:
raise ValueError("Empty file_paths are not allowed")
else:
paths = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else self.file_paths
)

return [self.convert_to_path(path) for path in paths]
2 changes: 1 addition & 1 deletion src/crewai/knowledge/source/base_knowledge_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
collection_name: Optional[str] = Field(default=None)

@abstractmethod
def load_content(self) -> Dict[Any, str]:
def validate_content(self) -> Any:
"""Load and preprocess content from the source."""
pass

Expand Down
120 changes: 120 additions & 0 deletions src/crewai/knowledge/source/crew_docling_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from pathlib import Path
from typing import Iterator, List, Optional, Union
from urllib.parse import urlparse

from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument
from pydantic import Field

from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger


class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""

_logger: Logger = Logger(verbose=True)

file_path: Optional[List[Union[Path, str]]] = Field(default=None)
file_paths: List[Union[Path, str]] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
content: List[DoclingDocument] = Field(default_factory=list)
document_converter: DocumentConverter = Field(
default_factory=lambda: DocumentConverter(
allowed_formats=[
InputFormat.MD,
InputFormat.ASCIIDOC,
InputFormat.PDF,
InputFormat.DOCX,
InputFormat.HTML,
InputFormat.IMAGE,
InputFormat.XLSX,
InputFormat.PPTX,
]
)
)

def model_post_init(self, _) -> None:
if self.file_path:
self._logger.log(
"warning",
"The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
color="yellow",
)
self.file_paths = self.file_path
self.safe_file_paths = self.validate_content()
self.content = self._load_content()

def _load_content(self) -> List[DoclingDocument]:
try:
return self._convert_source_to_docling_documents()
except ConversionError as e:
self._logger.log(
"error",
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
"red",
)
raise e
except Exception as e:
self._logger.log("error", f"Error loading content: {e}")
raise e

def add(self) -> None:
if self.content is None:
return
for doc in self.content:
new_chunks_iterable = self._chunk_doc(doc)
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()

def _convert_source_to_docling_documents(self) -> List[DoclingDocument]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]

def _chunk_doc(self, doc: DoclingDocument) -> Iterator[str]:
chunker = HierarchicalChunker()
for chunk in chunker.chunk(doc):
yield chunk.text

def validate_content(self) -> List[Union[Path, str]]:
processed_paths: List[Union[Path, str]] = []
for path in self.file_paths:
if isinstance(path, str):
if path.startswith(("http://", "https://")):
try:
if self._validate_url(path):
processed_paths.append(path)
else:
raise ValueError(f"Invalid URL format: {path}")
except Exception as e:
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
else:
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():
processed_paths.append(local_path)
else:
raise FileNotFoundError(f"File not found: {local_path}")
else:
# this is an instance of Path
processed_paths.append(path)
return processed_paths

def _validate_url(self, url: str) -> bool:
try:
result = urlparse(url)
return all(
[
result.scheme in ("http", "https"),
result.netloc,
len(result.netloc.split(".")) >= 2, # Ensure domain has TLD
]
)
except Exception:
return False
4 changes: 2 additions & 2 deletions src/crewai/knowledge/source/string_knowledge_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class StringKnowledgeSource(BaseKnowledgeSource):

def model_post_init(self, _):
"""Post-initialization method to validate content."""
self.load_content()
self.validate_content()

def load_content(self):
def validate_content(self):
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")
Expand Down
Loading

0 comments on commit b3185ad

Please sign in to comment.