Skip to content

Commit

Permalink
frontend: added temperature gauge to assistant form (#901)
Browse files Browse the repository at this point in the history
* feat(assistants_web): added temperature to agent settings

* feat(assistants_web): use agent temperature for chat requests

* chore(assistants_web): removed not working ticks from slider component

* chore(assistants_web): refactored setter/getter with more verbose name

* feat(backend): if no temperature is sent with chat request, use agent temperature
  • Loading branch information
ezawadski authored Jan 14, 2025
1 parent f4d0172 commit 949953a
Show file tree
Hide file tree
Showing 13 changed files with 134 additions and 28 deletions.
11 changes: 4 additions & 7 deletions src/backend/routers/chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Generator

from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends
from sse_starlette.sse import EventSourceResponse

from backend.chat.custom.custom import CustomChat
Expand Down Expand Up @@ -31,7 +31,6 @@
async def chat_stream(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
ctx: Context = Depends(get_context),
) -> Generator[ChatResponseEvent, Any, None]:
"""
Expand All @@ -58,7 +57,7 @@ async def chat_stream(
managed_tools,
next_message_position,
ctx,
) = process_chat(session, chat_request, request, ctx)
) = process_chat(session, chat_request, ctx)

return EventSourceResponse(
generate_chat_stream(
Expand Down Expand Up @@ -86,7 +85,6 @@ async def chat_stream(
async def regenerate_chat_stream(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
ctx: Context = Depends(get_context),
) -> EventSourceResponse:
"""
Expand Down Expand Up @@ -127,7 +125,7 @@ async def regenerate_chat_stream(
previous_response_message_ids,
managed_tools,
ctx,
) = process_message_regeneration(session, chat_request, request, ctx)
) = process_message_regeneration(session, chat_request, ctx)

return EventSourceResponse(
generate_chat_stream(
Expand Down Expand Up @@ -155,7 +153,6 @@ async def regenerate_chat_stream(
async def chat(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
ctx: Context = Depends(get_context),
) -> NonStreamedChatResponse:
"""
Expand Down Expand Up @@ -197,7 +194,7 @@ async def chat(
managed_tools,
next_message_position,
ctx,
) = process_chat(session, chat_request, request, ctx)
) = process_chat(session, chat_request, ctx)

response = await generate_chat_response(
session,
Expand Down
20 changes: 12 additions & 8 deletions src/backend/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import nltk
from cohere.types import StreamedChatResponse
from fastapi import HTTPException, Request
from fastapi import HTTPException
from fastapi.encoders import jsonable_encoder

from backend.chat.collate import to_dict
Expand Down Expand Up @@ -74,19 +74,17 @@ def generate_tools_preamble(chat_request: CohereChatRequest) -> str:

def process_chat(
session: DBSessionDep,
chat_request: BaseChatRequest,
request: Request,
chat_request: CohereChatRequest,
ctx: Context,
) -> tuple[
DBSessionDep, BaseChatRequest, Union[list[str], None], Message, str, str, dict
DBSessionDep, CohereChatRequest, Union[list[str], None], Message, str, str, Context
]:
"""
Process a chat request.
Args:
chat_request (BaseChatRequest): Chat request data.
chat_request (CohereChatRequest): Chat request data.
session (DBSessionDep): Database session.
request (Request): Request object.
ctx (Context): Context object.
Returns:
Expand Down Expand Up @@ -124,6 +122,10 @@ def process_chat(
chat_request.model = agent.model
chat_request.preamble = agent.preamble

# If temperature is not defined in the chat request, use the temperature from the agent
if not chat_request.temperature:
chat_request.temperature = agent.temperature

should_store = chat_request.chat_history is None and not is_custom_tool_call(
chat_request
)
Expand Down Expand Up @@ -193,7 +195,6 @@ def process_chat(
def process_message_regeneration(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
ctx: Context,
) -> tuple[Any, CohereChatRequest, Message, list[str], bool, Context]:
"""
Expand All @@ -202,7 +203,6 @@ def process_message_regeneration(
Args:
session (DBSessionDep): Database session.
chat_request (CohereChatRequest): Chat request data.
request (Request): Request object.
ctx (Context): Context object.
Returns:
Expand All @@ -224,6 +224,10 @@ def process_message_regeneration(
# Set the agent settings in the chat request
chat_request.preamble = agent.preamble

# If temperature is not defined in the chat request, use the temperature from the agent
if not chat_request.temperature:
chat_request.temperature = agent.temperature

conversation_id = chat_request.conversation_id
ctx.with_conversation_id(conversation_id)

Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/unit/factories/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Meta:
description = factory.Faker("sentence")
preamble = factory.Faker("sentence")
version = factory.Faker("random_int")
temperature = factory.Faker("pyfloat")
temperature = factory.Faker("pyfloat", min_value=0.0, max_value=1.0)
created_at = factory.Faker("date_time")
updated_at = factory.Faker("date_time")
tools = factory.List(
Expand Down
1 change: 0 additions & 1 deletion src/backend/tests/unit/routers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def test_streaming_fail_chat_missing_message(
"loc": ["body", "message"],
"msg": "Field required",
"input": {},
"url": "https://errors.pydantic.dev/2.10/v/missing",
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({
const fileIds = conversation?.files.map((file) => file.id);

setParams({
temperature: agent?.temperature,
tools: agentTools,
fileIds,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import { AgentSettingsFields, AgentSettingsForm } from '@/components/AgentSettin
import { MobileHeader } from '@/components/Global';
import { DeleteAgent } from '@/components/Modals/DeleteAgent';
import { Button, Icon, Spinner, Text } from '@/components/UI';
import { DEFAULT_AGENT_MODEL, DEPLOYMENT_COHERE_PLATFORM } from '@/constants';
import {
DEFAULT_AGENT_MODEL,
DEFAULT_AGENT_TEMPERATURE,
DEPLOYMENT_COHERE_PLATFORM,
} from '@/constants';
import { useContextStore } from '@/context';
import { useIsAgentNameUnique, useNotify, useUpdateAgent } from '@/hooks';

Expand All @@ -28,6 +32,7 @@ export const UpdateAgent: React.FC<Props> = ({ agent }) => {
description: agent.description,
deployment: agent.deployment ?? DEPLOYMENT_COHERE_PLATFORM,
model: agent.model ?? DEFAULT_AGENT_MODEL,
temperature: agent.temperature ?? DEFAULT_AGENT_TEMPERATURE,
tools: agent.tools,
preamble: agent.preamble,
tools_metadata: agent.tools_metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { Button, Icon, Text } from '@/components/UI';
import {
BACKGROUND_TOOLS,
DEFAULT_AGENT_MODEL,
DEFAULT_AGENT_TEMPERATURE,
DEFAULT_PREAMBLE,
DEPLOYMENT_COHERE_PLATFORM,
} from '@/constants';
Expand All @@ -23,6 +24,7 @@ const DEFAULT_FIELD_VALUES = {
preamble: DEFAULT_PREAMBLE,
deployment: DEPLOYMENT_COHERE_PLATFORM,
model: DEFAULT_AGENT_MODEL,
temperature: DEFAULT_AGENT_TEMPERATURE,
tools: BACKGROUND_TOOLS,
is_private: false,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import { useState } from 'react';

import { AgentSettingsFields } from '@/components/AgentSettingsForm';
import { Dropdown } from '@/components/UI';
import { Dropdown, Slider } from '@/components/UI';
import { useListAllDeployments } from '@/hooks';

type Props = {
Expand All @@ -13,7 +13,10 @@ type Props = {
};

export const ConfigStep: React.FC<Props> = ({ fields, setFields }) => {
const [selectedValue, setSelectedValue] = useState<string | undefined>(fields.model);
const [selectedModelValue, setSelectedModelValue] = useState<string | undefined>(fields.model);
const [selectedTemperatureValue, setSelectedTemperatureValue] = useState<number | undefined>(
fields.temperature
);

const { data: deployments } = useListAllDeployments();

Expand All @@ -27,12 +30,23 @@ export const ConfigStep: React.FC<Props> = ({ fields, setFields }) => {
<Dropdown
label="Model"
options={modelOptions ?? []}
value={selectedValue}
value={selectedModelValue}
onChange={(model) => {
setFields({ ...fields, model: model });
setSelectedValue(model);
setSelectedModelValue(model);
}}
/>
<Slider
label="Temperature"
min={0}
max={1.0}
step={0.1}
value={selectedTemperatureValue || 0}
onChange={(temperature) => {
setFields({ ...fields, temperature: temperature });
setSelectedTemperatureValue(temperature);
}}
></Slider>
</div>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ type RequiredAndNotNull<T> = {
type RequireAndNotNullSome<T, K extends keyof T> = RequiredAndNotNull<Pick<T, K>> & Omit<T, K>;

type CreateAgentSettingsFields = RequireAndNotNullSome<
Omit<CreateAgentRequest, 'version' | 'temperature'>,
'name' | 'model' | 'deployment'
Omit<CreateAgentRequest, 'version'>,
'name' | 'model' | 'deployment' | 'temperature'
>;

type UpdateAgentSettingsFields = RequireAndNotNullSome<
Omit<UpdateAgentRequest, 'version' | 'temperature'>,
'name' | 'model' | 'deployment'
Omit<UpdateAgentRequest, 'version'>,
'name' | 'model' | 'deployment' | 'temperature'
> & { is_private?: boolean };

export type AgentSettingsFields = CreateAgentSettingsFields | UpdateAgentSettingsFields;
Expand Down
82 changes: 82 additions & 0 deletions src/interfaces/assistants_web/src/components/UI/Slider.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
'use client';

import { ChangeEvent, useEffect, useMemo } from 'react';

import { InputLabel, Text } from '@/components/UI';
import { cn } from '@/utils';

type Props = {
label: string;
min: number;
max: number;
step: number;
value: number;
onChange: (value: number) => void;
sublabel?: string;
className?: string;
tooltipLabel?: React.ReactNode;
formatValue?: (value: number) => string;
};

/**
*
* Renders a slider with a label, a minimum, maximum and step value, and optional subLabel and tooltip.
* Styling for the thumb is located in main.css
*/
export const Slider: React.FC<Props> = ({
label,
sublabel,
min,
max,
step,
value,
onChange,
tooltipLabel,
formatValue,
className = '',
}) => {
// if `max` is changed dynamically don't allow the value to surpass it
useEffect(() => {
if (value > max) onChange(Math.min(value, max));
}, [max, onChange, value]);

// if `min` is changed dynamically don't allow the value to go below it
useEffect(() => {
if (value < min) onChange(Math.max(value, min));
}, [min, onChange, value]);

const handleChange = (e: ChangeEvent<HTMLInputElement>) => {
const value = Number(e.target.value);

onChange(value);
};

const ticks = useMemo(() => {
return Array.from({ length: (max - min) / step + 1 }, (_, i) => {
return i * step + min;
});
}, [max, min, step]);

return (
<div className={cn('flex flex-col space-y-4', className)}>
<div className="flex w-full items-center justify-between">
<InputLabel label={label} tooltipLabel={tooltipLabel} sublabel={sublabel} />
<Text>{formatValue ? formatValue(value) : value}</Text>
</div>
<div className="flex items-center">
<input
type="range"
value={value}
max={max}
min={min}
step={step}
onChange={handleChange}
className={cn(
'flex w-full cursor-pointer appearance-none items-center rounded-lg border outline-none active:cursor-grabbing',
'focus-visible:outline focus-visible:outline-1 focus-visible:outline-offset-4 focus-visible:outline-volcanic-100'
)}
/>
</div>
</div>
);
};
1 change: 1 addition & 0 deletions src/interfaces/assistants_web/src/components/UI/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export * from './RadioGroup';
export * from './Shortcut';
export * from './ShowStepsToggle';
export * from './Skeleton';
export * from './Slider';
export * from './Spinner';
export * from './Switch';
export * from './Tabs';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { FileAccept } from '@/components/UI';
export const DEFAULT_CONVERSATION_NAME = 'New Conversation';
export const DEFAULT_AGENT_MODEL = 'command-r-plus';
export const DEFAULT_AGENT_ID = 'default';
export const DEFAULT_AGENT_TEMPERATURE = 0.3;
export const DEFAULT_TYPING_VELOCITY = 35;
export const CONVERSATION_HISTORY_OFFSET = 100;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { StateCreator } from 'zustand';

import { CohereChatRequest, DEFAULT_CHAT_TEMPERATURE } from '@/cohere-client';
import { CohereChatRequest } from '@/cohere-client';

import { StoreState } from '..';

const INITIAL_STATE: ConfigurableParams = {
model: undefined,
temperature: DEFAULT_CHAT_TEMPERATURE,
temperature: undefined,
preamble: '',
tools: [],
fileIds: [],
Expand Down

0 comments on commit 949953a

Please sign in to comment.