Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Передача параметра few_shot_examples через pydantic схему #9

Closed
wants to merge 2 commits into from

Conversation

NIK-TIGER-BILL
Copy link
Contributor

Реализация возможности передачи параметра few_shot_examples через pydantic схему. Это необходимо, чтобы удобно использовать функционал .with_structured_output.

Пример передачи параметра few_shot_examples:

class NewCheckManagerMessage(BaseModel):
    """Проверка сообщения от клиентского менеджера"""
    is_valid: int = Field(
        description='Признак валидности сообщения. '
                    'Валидное сообщение должно быть похоже на сообщение представителя банка, '
                    'соблюдается формальный стиль. '
                    'В сообщении не уместны не подходящие темы, шутки, личные разговоры. '
                    '1 - обозначает корректное сообщение от клиентского менеджера банка, '
                    '0 - сообщение написано не корректно.'
    )
    description: str = Field(description='Объясни и аргументируй почему ты выбрал такое решение о признаке валидности')

    @staticmethod
    def few_shot_examples() -> FewShotExamples:
        return [
            {
                'request': 'В нем множество преимуществ. Хотите узнать какие?',
                'params': {
                    'is_valid': 1,
                    'description': 'Сообщение корректно'
                }
            }
        ]


check_manager_message_prompt = ChatPromptTemplate.from_messages(
    [
        (
            'system',
            'Вы - высококлассный сотрудник безопасности банка, который прослушивает разговоры между клиентами и их менеджерами банка. '
            'Твоя задача следить за сообщения от менеджера и проверять соблюдение правил общения представителей банка. '
        ),
        (
            'placeholder',
            '{dialogue}'
        ),
        (
            'user',
            'Учитывая приведённый разговор выше. Реши, корректно ли следующее сообщение от менеджера: "{message}"'
        ),
    ]
)


def get_check_manager_message_chain(llm: BaseChatModel):
    return check_manager_message_prompt | llm.with_structured_output(NewCheckManagerMessage)

Раньше, чтобы не терять совместимость с переключением на другие модели, приходилось выкручиваться примерно таким способом:

class CheckManagerMessage(BaseModel):
    """Проверка сообщения от клиентского менеджера"""
    is_valid: int = Field(
        description='Признак валидности сообщения. '
                    'Валидное сообщение должно быть похоже на сообщение представителя банка, '
                    'соблюдается формальный стиль. '
                    'В сообщении не уместны не подходящие темы, шутки, личные разговоры. '
                    '1 - обозначает корректное сообщение от клиентского менеджера банка, '
                    '0 - сообщение написано не корректно.'
    )
    description: str = Field(description='Объясни и аргументируй почему ты выбрал такое решение о признаке валидности')


def get_structured_output_chain(llm: BaseChatModel, schema: BaseModel, few_show_examples: FewShotExamples):

    class SubTool(BaseTool):
        name: str = schema.__name__
        description: str = schema.__doc__
        args_schema: Type[BaseModel] = schema
        few_shot_examples: FewShotExamples = few_show_examples

        def _run(self):
            pass

    tool = SubTool()

    return llm.bind_tools([tool], tool_choice=tool.name) | PydanticToolsParser(tools=[tool.args_schema])


check_manager_message_prompt = ChatPromptTemplate.from_messages(
    [
        (
            'system',
            'Вы - высококлассный сотрудник безопасности банка, который прослушивает разговоры между клиентами и их менеджерами банка. '
            'Твоя задача следить за сообщения от менеджера и проверять соблюдение правил общения представителей банка. '
        ),
        (
            'placeholder',
            '{dialogue}'
        ),
        (
            'user',
            'Учитывая приведённый разговор выше. Реши, корректно ли следующее сообщение от менеджера: "{message}"'
        ),
    ]
)

check_manager_message_few_shots = [
            {
                'request': 'В нем множество преимуществ. Хотите узнать какие?',
                'params': {
                    'is_valid': 1,
                    'description': 'Сообщение корректно'
                }
            }
        ]


def get_check_manager_message_chain(llm: BaseChatModel,):
    return check_manager_message_prompt | get_structured_output_chain(
        llm, CheckManagerMessage, check_manager_message_few_shots)

Примеры для запуска:

Новая реализация

from dotenv import load_dotenv
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, AIMessage

from langchain_gigachat import GigaChat
from langchain_gigachat.tools.giga_tool import FewShotExamples

load_dotenv()


class NewCheckManagerMessage(BaseModel):
    """Проверка сообщения от клиентского менеджера"""
    is_valid: int = Field(
        description='Признак валидности сообщения. '
                    'Валидное сообщение должно быть похоже на сообщение представителя банка, '
                    'соблюдается формальный стиль. '
                    'В сообщении не уместны не подходящие темы, шутки, личные разговоры. '
                    '1 - обозначает корректное сообщение от клиентского менеджера банка, '
                    '0 - сообщение написано не корректно.'
    )
    description: str = Field(description='Объясни и аргументируй почему ты выбрал такое решение о признаке валидности')

    @staticmethod
    def few_shot_examples() -> FewShotExamples:
        return [
            {
                'request': 'В нем множество преимуществ. Хотите узнать какие?',
                'params': {
                    'is_valid': 1,
                    'description': 'Сообщение корректно'
                }
            }
        ]


check_manager_message_prompt = ChatPromptTemplate.from_messages(
    [
        (
            'system',
            'Вы - высококлассный сотрудник безопасности банка, который прослушивает разговоры между клиентами и их менеджерами банка. '
            'Твоя задача следить за сообщения от менеджера и проверять соблюдение правил общения представителей банка. '
        ),
        (
            'placeholder',
            '{dialogue}'
        ),
        (
            'user',
            'Учитывая приведённый разговор выше. Реши, корректно ли следующее сообщение от менеджера: "{message}"'
        ),
    ]
)


def get_check_manager_message_chain(llm: BaseChatModel):
    return check_manager_message_prompt | llm.with_structured_output(NewCheckManagerMessage)


llm = GigaChat(
            temperature=1.0,
            timeout=90.0,
            verify_ssl_certs=False,
            profanity_check=False,
            model='GigaChat-Pro',
        )
chain = get_check_manager_message_chain(llm)
dialogue = [
    HumanMessage('Привет! Хотите приобрести новый продукт СберПремьер?'),
    AIMessage('Да, конечно. А какие задачи он решает?'),
]
result = chain.invoke({'dialogue': dialogue, 'message': 'В нем множество преимуществ. Хотите узнать какие?'})
print(result)

Старая реализация

from typing import Type

from dotenv import load_dotenv
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import PydanticToolsParser

from langchain_gigachat import GigaChat
from langchain_gigachat.tools.giga_tool import FewShotExamples


load_dotenv()


class CheckManagerMessage(BaseModel):
    """Проверка сообщения от клиентского менеджера"""
    is_valid: int = Field(
        description='Признак валидности сообщения. '
                    'Валидное сообщение должно быть похоже на сообщение представителя банка, '
                    'соблюдается формальный стиль. '
                    'В сообщении не уместны не подходящие темы, шутки, личные разговоры. '
                    '1 - обозначает корректное сообщение от клиентского менеджера банка, '
                    '0 - сообщение написано не корректно.'
    )
    description: str = Field(description='Объясни и аргументируй почему ты выбрал такое решение о признаке валидности')


def get_structured_output_chain(llm: BaseChatModel, schema: BaseModel, few_show_examples: FewShotExamples):

    class SubTool(BaseTool):
        name: str = schema.__name__
        description: str = schema.__doc__
        args_schema: Type[BaseModel] = schema
        few_shot_examples: FewShotExamples = few_show_examples

        def _run(self):
            pass

    tool = SubTool()

    return llm.bind_tools([tool], tool_choice=tool.name) | PydanticToolsParser(tools=[tool.args_schema])


check_manager_message_prompt = ChatPromptTemplate.from_messages(
    [
        (
            'system',
            'Вы - высококлассный сотрудник безопасности банка, который прослушивает разговоры между клиентами и их менеджерами банка. '
            'Твоя задача следить за сообщения от менеджера и проверять соблюдение правил общения представителей банка. '
        ),
        (
            'placeholder',
            '{dialogue}'
        ),
        (
            'user',
            'Учитывая приведённый разговор выше. Реши, корректно ли следующее сообщение от менеджера: "{message}"'
        ),
    ]
)

check_manager_message_few_shots = [
            {
                'request': 'В нем множество преимуществ. Хотите узнать какие?',
                'params': {
                    'is_valid': 1,
                    'description': 'Сообщение корректно'
                }
            }
        ]


def get_check_manager_message_chain(llm: BaseChatModel,):
    return check_manager_message_prompt | get_structured_output_chain(
        llm, CheckManagerMessage, check_manager_message_few_shots)


llm = GigaChat(
            temperature=1.0,
            timeout=90.0,
            verify_ssl_certs=False,
            profanity_check=False,
            model='GigaChat-Pro',
        )
chain = get_check_manager_message_chain(llm)
dialogue = [
    HumanMessage('Привет! Хотите приобрести новый продукт СберПремьер?'),
    AIMessage('Да, конечно. А какие задачи он решает?'),
]
result = chain.invoke({'dialogue': dialogue, 'message': 'В нем множество преимуществ. Хотите узнать какие?'})
print(result)

@Rai220 Rai220 changed the base branch from master to dev November 11, 2024 13:25
@Rai220 Rai220 closed this Nov 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants