diff --git a/src/distilabel/models/llms/llamacpp.py b/src/distilabel/models/llms/llamacpp.py
index ba30735f61..652b9f2a7e 100644
--- a/src/distilabel/models/llms/llamacpp.py
+++ b/src/distilabel/models/llms/llamacpp.py
@@ -26,10 +26,7 @@
 if TYPE_CHECKING:
     from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList
 
-    from distilabel.steps.tasks.typing import (
-        FormattedInput,
-        StandardInput,
-    )
+    from distilabel.steps.tasks.typing import FormattedInput, StandardInput
 
 
 class LlamaCppLLM(LLM, MagpieChatTemplateMixin):
@@ -173,7 +170,7 @@ class User(BaseModel):
     _logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None)
     _model: Optional["Llama"] = PrivateAttr(...)
 
-    @model_validator(mode="after")  # type: ignore
+    @model_validator(mode="after")
     def validate_magpie_usage(
         self,
     ) -> "LlamaCppLLM":
@@ -195,7 +192,7 @@ def load(self) -> None:
             ) from ie
 
         self._model = Llama(
-            model_path=self.model_path.as_posix(),  # type: ignore
+            model_path=self.model_path.as_posix(),
             seed=self.seed,
             n_ctx=self.n_ctx,
             n_batch=self.n_batch,
@@ -223,7 +220,7 @@ def load(self) -> None:
                 from transformers import AutoTokenizer
             except ImportError as ie:
                 raise ImportError(
-                    "Transformers is not installed. Please install it using `pip install transformers`."
+                    "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
                 ) from ie
             self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
             if self._tokenizer.chat_template is None:
diff --git a/src/distilabel/models/llms/ollama.py b/src/distilabel/models/llms/ollama.py
index ae3f2715ce..b303bd526a 100644
--- a/src/distilabel/models/llms/ollama.py
+++ b/src/distilabel/models/llms/ollama.py
@@ -14,7 +14,6 @@
 
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union
 
-from llama_cpp.llama_types import CreateChatCompletionResponse
 from pydantic import Field, PrivateAttr, model_validator, validate_call
 from typing_extensions import TypedDict
 
@@ -26,13 +25,11 @@
 from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput
 
 if TYPE_CHECKING:
-    from llama_cpp import CreateChatCompletionResponse
     from ollama import AsyncClient
+    from ollama._types import ChatResponse, GenerateResponse
 
     from distilabel.models.llms.typing import LLMStatistics
-    from distilabel.steps.tasks.typing import (
-        StandardInput,
-    )
+    from distilabel.steps.tasks.typing import StandardInput
 
 
 # Copied from `ollama._types.Options`
@@ -146,7 +143,7 @@ class OllamaLLM(AsyncLLM, MagpieChatTemplateMixin):
         "`llama3`, `qwen2` or another pre-query template provided.",
     )
     _num_generations_param_supported = False
-    _aclient: Optional["AsyncClient"] = PrivateAttr(...)
+    _aclient: Optional["AsyncClient"] = PrivateAttr(...)  # type: ignore
 
     @model_validator(mode="after")  # type: ignore
     def validate_magpie_usage(
@@ -183,7 +180,7 @@ def load(self) -> None:
                 from transformers import AutoTokenizer
             except ImportError as ie:
                 raise ImportError(
-                    "Transformers is not installed. Please install it using `pip install transformers`."
+                    "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
                 ) from ie
             self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
             if self._tokenizer.chat_template is None:
@@ -202,7 +199,7 @@ async def _generate_chat_completion(
         format: Literal["", "json"] = "",
         options: Union[Options, None] = None,
         keep_alive: Union[bool, None] = None,
-    ) -> "CreateChatCompletionResponse":
+    ) -> "ChatResponse":
         return await self._aclient.chat(
             model=self.model,
             messages=input,
@@ -239,7 +236,7 @@ async def _generate_with_text_generation(
         format: Literal["", "json"] = None,
         options: Union[Options, None] = None,
         keep_alive: Union[bool, None] = None,
-    ) -> "CreateChatCompletionResponse":
+    ) -> "GenerateResponse":
         input = self.prepare_input(input)
         return await self._aclient.generate(
             model=self.model,
@@ -281,10 +278,8 @@ async def agenerate(
                 )
                 text = completion["message"]["content"]
             else:
-                completion: CreateChatCompletionResponse = (
-                    await self._generate_with_text_generation(
-                        input, format, options, keep_alive
-                    )
+                completion = await self._generate_with_text_generation(
+                    input, format, options, keep_alive
                 )
                 text = completion.response
         except Exception as e: