Commit 69eaba9

chore clean code
davidberenstein1957 committed Jan 8, 2025
1 parent 4e291e7 commit 69eaba9
Showing 2 changed files with 12 additions and 20 deletions.
11 changes: 4 additions & 7 deletions src/distilabel/models/llms/llamacpp.py
@@ -26,10 +26,7 @@
 if TYPE_CHECKING:
     from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList

-    from distilabel.steps.tasks.typing import (
-        FormattedInput,
-        StandardInput,
-    )
+    from distilabel.steps.tasks.typing import FormattedInput, StandardInput


 class LlamaCppLLM(LLM, MagpieChatTemplateMixin):
@@ -173,7 +170,7 @@ class User(BaseModel):
     _logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None)
     _model: Optional["Llama"] = PrivateAttr(...)

-    @model_validator(mode="after")  # type: ignore
+    @model_validator(mode="after")
     def validate_magpie_usage(
         self,
     ) -> "LlamaCppLLM":
@@ -195,7 +192,7 @@ def load(self) -> None:
             ) from ie

         self._model = Llama(
-            model_path=self.model_path.as_posix(),  # type: ignore
+            model_path=self.model_path.as_posix(),
             seed=self.seed,
             n_ctx=self.n_ctx,
             n_batch=self.n_batch,
@@ -223,7 +220,7 @@ def load(self) -> None:
                 from transformers import AutoTokenizer
             except ImportError as ie:
                 raise ImportError(
-                    "Transformers is not installed. Please install it using `pip install transformers`."
+                    "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
                 ) from ie
             self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
             if self._tokenizer.chat_template is None:
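
Side note (not part of the diff): a minimal usage sketch of the LlamaCppLLM class touched above. The constructor arguments model_path, seed, n_ctx, and n_batch appear in the Llama(...) call in this file; the import path, the GGUF file location, and the generation keyword arguments are assumptions for illustration only.

# Hedged sketch: exercising LlamaCppLLM after this change. The GGUF path is a
# hypothetical local file, and llama-cpp-python must be installed, otherwise
# load() raises the ImportError defined in this module.
from distilabel.models.llms.llamacpp import LlamaCppLLM

llm = LlamaCppLLM(
    model_path="./models/llama-3-8b-instruct.Q4_K_M.gguf",  # hypothetical local GGUF file
    n_ctx=2048,
    n_batch=512,
    seed=42,
)
llm.load()  # builds the underlying Llama(model_path=..., seed=..., n_ctx=..., n_batch=...) instance

# A FormattedInput is a list of chat messages, as suggested by the consolidated typing import above.
result = llm.generate(
    inputs=[[{"role": "user", "content": "Say hello in one short sentence."}]],
    max_new_tokens=32,  # assumed generation kwarg, not shown in this diff
)
print(result)
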
21 changes: 8 additions & 13 deletions src/distilabel/models/llms/ollama.py
@@ -14,7 +14,6 @@

 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union

-from llama_cpp.llama_types import CreateChatCompletionResponse
 from pydantic import Field, PrivateAttr, model_validator, validate_call
 from typing_extensions import TypedDict

@@ -26,13 +25,11 @@
 from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput

 if TYPE_CHECKING:
-    from llama_cpp import CreateChatCompletionResponse
     from ollama import AsyncClient
+    from ollama._types import ChatResponse, GenerateResponse

     from distilabel.models.llms.typing import LLMStatistics
-    from distilabel.steps.tasks.typing import (
-        StandardInput,
-    )
+    from distilabel.steps.tasks.typing import StandardInput


 # Copied from `ollama._types.Options`
@@ -146,7 +143,7 @@ class OllamaLLM(AsyncLLM, MagpieChatTemplateMixin):
         "`llama3`, `qwen2` or another pre-query template provided.",
     )
     _num_generations_param_supported = False
-    _aclient: Optional["AsyncClient"] = PrivateAttr(...)
+    _aclient: Optional["AsyncClient"] = PrivateAttr(...)  # type: ignore

     @model_validator(mode="after")  # type: ignore
     def validate_magpie_usage(
@@ -183,7 +180,7 @@ def load(self) -> None:
                 from transformers import AutoTokenizer
             except ImportError as ie:
                 raise ImportError(
-                    "Transformers is not installed. Please install it using `pip install transformers`."
+                    "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
                 ) from ie
             self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
             if self._tokenizer.chat_template is None:
@@ -202,7 +199,7 @@ async def _generate_chat_completion(
         format: Literal["", "json"] = "",
         options: Union[Options, None] = None,
         keep_alive: Union[bool, None] = None,
-    ) -> "CreateChatCompletionResponse":
+    ) -> "ChatResponse":
         return await self._aclient.chat(
             model=self.model,
             messages=input,
@@ -239,7 +236,7 @@ async def _generate_with_text_generation(
         format: Literal["", "json"] = None,
         options: Union[Options, None] = None,
         keep_alive: Union[bool, None] = None,
-    ) -> "CreateChatCompletionResponse":
+    ) -> "GenerateResponse":
         input = self.prepare_input(input)
         return await self._aclient.generate(
             model=self.model,
@@ -281,10 +278,8 @@ async def agenerate(
                 )
                 text = completion["message"]["content"]
             else:
-                completion: CreateChatCompletionResponse = (
-                    await self._generate_with_text_generation(
-                        input, format, options, keep_alive
-                    )
-                )
+                completion = await self._generate_with_text_generation(
+                    input, format, options, keep_alive
+                )
                 text = completion.response
         except Exception as e:
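
Side note (not part of the diff): the corrected annotations match how agenerate consumes the two responses above, where the chat branch indexes completion["message"]["content"] (a ChatResponse) and the else-branch reads completion.response (a GenerateResponse). A minimal, hedged usage sketch follows, assuming a local Ollama server with a pulled llama3 model; anything not visible in the diff (import path, exact signature defaults, the shape of the returned value) is an assumption.

# Hedged sketch: calling OllamaLLM.agenerate after this change. Requires a running
# Ollama server and the model pulled locally; names outside the diff are assumptions.
import asyncio

from distilabel.models.llms.ollama import OllamaLLM


async def main() -> None:
    llm = OllamaLLM(model="llama3")
    llm.load()  # creates the AsyncClient stored in _aclient

    # The chat branch returns a mapping-like ChatResponse; the text-generation
    # branch returns a GenerateResponse read via .response, mirroring agenerate above.
    generations = await llm.agenerate(
        input=[{"role": "user", "content": "Reply with a single word."}]
    )
    print(generations)


asyncio.run(main())
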
