diff --git a/webview-ui/src/__mocks__/lucide-react.ts b/webview-ui/src/__mocks__/lucide-react.ts new file mode 100644 index 0000000000..d85cd25d6a --- /dev/null +++ b/webview-ui/src/__mocks__/lucide-react.ts @@ -0,0 +1,6 @@ +import React from "react" + +export const Check = () => React.createElement("div") +export const ChevronsUpDown = () => React.createElement("div") +export const Loader = () => React.createElement("div") +export const X = () => React.createElement("div") diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 1303e79c7a..c9777fd9be 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -36,17 +36,13 @@ import { import { ExtensionMessage } from "../../../../src/shared/ExtensionMessage" import { vscode } from "../../utils/vscode" import VSCodeButtonLink from "../common/VSCodeButtonLink" -import { OpenRouterModelPicker } from "./OpenRouterModelPicker" -import OpenAiModelPicker from "./OpenAiModelPicker" -import { GlamaModelPicker } from "./GlamaModelPicker" -import { UnboundModelPicker } from "./UnboundModelPicker" import { ModelInfoView } from "./ModelInfoView" import { DROPDOWN_Z_INDEX } from "./styles" -import { RequestyModelPicker } from "./RequestyModelPicker" +import { ModelPicker } from "./ModelPicker" interface ApiOptionsProps { uriScheme: string | undefined - apiConfiguration: ApiConfiguration | undefined + apiConfiguration: ApiConfiguration setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void apiErrorMessage?: string modelIdErrorMessage?: string @@ -64,6 +60,20 @@ const ApiOptions = ({ const [ollamaModels, setOllamaModels] = useState([]) const [lmStudioModels, setLmStudioModels] = useState([]) const [vsCodeLmModels, setVsCodeLmModels] = useState([]) + const [openRouterModels, setOpenRouterModels] = useState | null>({ + [openRouterDefaultModelId]: openRouterDefaultModelInfo, + }) + const [glamaModels, setGlamaModels] = useState | null>({ + [glamaDefaultModelId]: glamaDefaultModelInfo, + }) + const [unboundModels, setUnboundModels] = useState | null>({ + [unboundDefaultModelId]: unboundDefaultModelInfo, + }) + const [requestyModels, setRequestyModels] = useState | null>({ + [requestyDefaultModelId]: requestyDefaultModelInfo, + }) + const [openAiModels, setOpenAiModels] = useState | null>(null) + const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl) const [azureApiVersionSelected, setAzureApiVersionSelected] = useState(!!apiConfiguration?.azureApiVersion) const [openRouterBaseUrlSelected, setOpenRouterBaseUrlSelected] = useState(!!apiConfiguration?.openRouterBaseUrl) @@ -98,22 +108,92 @@ const ApiOptions = ({ vscode.postMessage({ type: "requestLmStudioModels", text: apiConfiguration?.lmStudioBaseUrl }) } else if (selectedProvider === "vscode-lm") { vscode.postMessage({ type: "requestVsCodeLmModels" }) + } else if (selectedProvider === "openai") { + vscode.postMessage({ + type: "refreshOpenAiModels", + values: { + baseUrl: apiConfiguration?.openAiBaseUrl, + apiKey: apiConfiguration?.openAiApiKey, + }, + }) + } else if (selectedProvider === "openrouter") { + vscode.postMessage({ type: "refreshOpenRouterModels", values: {} }) + } else if (selectedProvider === "glama") { + vscode.postMessage({ type: "refreshGlamaModels", values: {} }) + } else if (selectedProvider === "requesty") { + vscode.postMessage({ + type: "refreshRequestyModels", + values: { + apiKey: apiConfiguration?.requestyApiKey, + }, + }) } }, 250, - [selectedProvider, apiConfiguration?.ollamaBaseUrl, apiConfiguration?.lmStudioBaseUrl], + [ + selectedProvider, + apiConfiguration?.ollamaBaseUrl, + apiConfiguration?.lmStudioBaseUrl, + apiConfiguration?.openAiBaseUrl, + apiConfiguration?.openAiApiKey, + apiConfiguration?.requestyApiKey, + ], ) const handleMessage = useCallback((event: MessageEvent) => { const message: ExtensionMessage = event.data - if (message.type === "ollamaModels" && Array.isArray(message.ollamaModels)) { - const newModels = message.ollamaModels - setOllamaModels(newModels) - } else if (message.type === "lmStudioModels" && Array.isArray(message.lmStudioModels)) { - const newModels = message.lmStudioModels - setLmStudioModels(newModels) - } else if (message.type === "vsCodeLmModels" && Array.isArray(message.vsCodeLmModels)) { - const newModels = message.vsCodeLmModels - setVsCodeLmModels(newModels) + switch (message.type) { + case "ollamaModels": + { + const newModels = message.ollamaModels ?? [] + setOllamaModels(newModels) + } + break + case "lmStudioModels": + { + const newModels = message.lmStudioModels ?? [] + setLmStudioModels(newModels) + } + break + case "vsCodeLmModels": + { + const newModels = message.vsCodeLmModels ?? [] + setVsCodeLmModels(newModels) + } + break + case "glamaModels": { + const updatedModels = message.glamaModels ?? {} + setGlamaModels({ + [glamaDefaultModelId]: glamaDefaultModelInfo, // in case the extension sent a model list without the default model + ...updatedModels, + }) + break + } + case "openRouterModels": { + const updatedModels = message.openRouterModels ?? {} + setOpenRouterModels({ + [openRouterDefaultModelId]: openRouterDefaultModelInfo, // in case the extension sent a model list without the default model + ...updatedModels, + }) + break + } + case "openAiModels": { + const updatedModels = message.openAiModels ?? [] + setOpenAiModels(Object.fromEntries(updatedModels.map((item) => [item, openAiModelInfoSaneDefaults]))) + break + } + case "unboundModels": { + const updatedModels = message.unboundModels ?? {} + setUnboundModels(updatedModels) + break + } + case "requestyModels": { + const updatedModels = message.requestyModels ?? {} + setRequestyModels({ + [requestyDefaultModelId]: requestyDefaultModelInfo, // in case the extension sent a model list without the default model + ...updatedModels, + }) + break + } } }, []) useEvent("message", handleMessage) @@ -604,7 +684,17 @@ const ApiOptions = ({ placeholder="Enter API Key..."> API Key - +
{ + onInput={handleInputChange("openAiCustomModelInfo", (e) => { const value = parseInt((e.target as HTMLInputElement).value) return { ...(apiConfiguration?.openAiCustomModelInfo || @@ -738,7 +828,7 @@ const ApiOptions = ({ })(), }} title="Total number of tokens (input + output) the model can process in a single request" - onChange={handleInputChange("openAiCustomModelInfo", (e) => { + onInput={handleInputChange("openAiCustomModelInfo", (e) => { const value = (e.target as HTMLInputElement).value const parsed = parseInt(value) return { @@ -884,7 +974,7 @@ const ApiOptions = ({ : "var(--vscode-errorForeground)" })(), }} - onChange={handleInputChange("openAiCustomModelInfo", (e) => { + onInput={handleInputChange("openAiCustomModelInfo", (e) => { const value = (e.target as HTMLInputElement).value const parsed = parseInt(value) return { @@ -929,7 +1019,7 @@ const ApiOptions = ({ : "var(--vscode-errorForeground)" })(), }} - onChange={handleInputChange("openAiCustomModelInfo", (e) => { + onInput={handleInputChange("openAiCustomModelInfo", (e) => { const value = (e.target as HTMLInputElement).value const parsed = parseInt(value) return { @@ -1207,7 +1297,18 @@ const ApiOptions = ({ }}> This key is stored locally and only used to make API requests from this extension.

- +
)} @@ -1223,10 +1324,49 @@ const ApiOptions = ({

)} - {selectedProvider === "glama" && } + {selectedProvider === "glama" && ( + + )} - {selectedProvider === "openrouter" && } - {selectedProvider === "requesty" && } + {selectedProvider === "openrouter" && ( + + )} + {selectedProvider === "requesty" && ( + + )} {selectedProvider !== "glama" && selectedProvider !== "openrouter" && diff --git a/webview-ui/src/components/settings/GlamaModelPicker.tsx b/webview-ui/src/components/settings/GlamaModelPicker.tsx deleted file mode 100644 index cb813a0d05..0000000000 --- a/webview-ui/src/components/settings/GlamaModelPicker.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { ModelPicker } from "./ModelPicker" -import { glamaDefaultModelId } from "../../../../src/shared/api" - -export const GlamaModelPicker = () => ( - -) diff --git a/webview-ui/src/components/settings/ModelInfoView.tsx b/webview-ui/src/components/settings/ModelInfoView.tsx index 397d04e02f..5edbec12b9 100644 --- a/webview-ui/src/components/settings/ModelInfoView.tsx +++ b/webview-ui/src/components/settings/ModelInfoView.tsx @@ -12,61 +12,49 @@ export const ModelInfoView = ({ setIsDescriptionExpanded, }: { selectedModelId: string - modelInfo: ModelInfo + modelInfo: ModelInfo | null isDescriptionExpanded: boolean setIsDescriptionExpanded: (isExpanded: boolean) => void }) => { const isGemini = Object.keys(geminiModels).includes(selectedModelId) const infoItems = [ - modelInfo.description && ( + modelInfo?.description && ( ), - , - , + , + , !isGemini && ( - + ), - modelInfo.maxTokens !== undefined && modelInfo.maxTokens > 0 && ( + modelInfo?.maxTokens !== undefined && modelInfo?.maxTokens > 0 && ( Max output: {modelInfo.maxTokens?.toLocaleString()} tokens ), - modelInfo.inputPrice !== undefined && modelInfo.inputPrice > 0 && ( + modelInfo?.inputPrice !== undefined && modelInfo.inputPrice > 0 && ( Input price: {formatPrice(modelInfo.inputPrice)}/million tokens ), - modelInfo.supportsPromptCache && modelInfo.cacheWritesPrice && ( + modelInfo?.supportsPromptCache && modelInfo.cacheWritesPrice && ( Cache writes price:{" "} {formatPrice(modelInfo.cacheWritesPrice || 0)}/million tokens ), - modelInfo.supportsPromptCache && modelInfo.cacheReadsPrice && ( + modelInfo?.supportsPromptCache && modelInfo.cacheReadsPrice && ( Cache reads price:{" "} {formatPrice(modelInfo.cacheReadsPrice || 0)}/million tokens ), - modelInfo.outputPrice !== undefined && modelInfo.outputPrice > 0 && ( + modelInfo?.outputPrice !== undefined && modelInfo.outputPrice > 0 && ( Output price: {formatPrice(modelInfo.outputPrice)}/million tokens @@ -95,15 +83,7 @@ export const ModelInfoView = ({ ) } -const ModelInfoSupportsItem = ({ - isSupported, - supportsLabel, - doesNotSupportLabel, -}: { - isSupported: boolean - supportsLabel: string - doesNotSupportLabel: string -}) => ( +const ModelInfoSupportsItem = ({ isSupported, label }: { isSupported: boolean | undefined | null; label: string }) => ( - {isSupported ? supportsLabel : doesNotSupportLabel} + {label} + {": "} + {isSupported == null ? "Unknown" : isSupported ? "Yes" : "No"} ) diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx index b21b37ef0f..bf1b79121d 100644 --- a/webview-ui/src/components/settings/ModelPicker.tsx +++ b/webview-ui/src/components/settings/ModelPicker.tsx @@ -1,185 +1,80 @@ import { VSCodeLink } from "@vscode/webview-ui-toolkit/react" -import debounce from "debounce" -import { useMemo, useState, useCallback, useEffect, useRef } from "react" -import { useMount } from "react-use" -import { CaretSortIcon, CheckIcon } from "@radix-ui/react-icons" +import { useMemo, useState, useCallback } from "react" -import { cn } from "@/lib/utils" -import { - Button, - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, - Popover, - PopoverContent, - PopoverTrigger, -} from "@/components/ui" - -import { useExtensionState } from "../../context/ExtensionStateContext" -import { vscode } from "../../utils/vscode" import { normalizeApiConfiguration } from "./ApiOptions" import { ModelInfoView } from "./ModelInfoView" - -type ModelProvider = "glama" | "openRouter" | "unbound" | "requesty" | "openAi" - -type ModelKeys = `${T}Models` -type ConfigKeys = `${T}ModelId` -type InfoKeys = `${T}ModelInfo` -type RefreshMessageType = `refresh${Capitalize}Models` - -interface ModelPickerProps { - defaultModelId: string - modelsKey: ModelKeys - configKey: ConfigKeys - infoKey: InfoKeys - refreshMessageType: RefreshMessageType - refreshValues?: Record +import { ApiConfiguration, ModelInfo } from "../../../../src/shared/api" +import { Combobox, ComboboxContent, ComboboxEmpty, ComboboxInput, ComboboxItem } from "../ui/combobox" + +type ExtractType = Exclude< + { [K in keyof ApiConfiguration]: Required[K] extends T ? K : never }[keyof ApiConfiguration], + undefined +> + +type ModelIdKeys = Exclude< + { [K in keyof ApiConfiguration]: K extends `${string}ModelId` ? K : never }[keyof ApiConfiguration], + undefined +> + +interface ModelPickerProps { + defaultModelId?: string + models: Record | null + modelIdKey: ModelIdKeys + modelInfoKey: ExtractType serviceName: string serviceUrl: string recommendedModel: string - allowCustomModel?: boolean + apiConfiguration: ApiConfiguration + setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void + defaultModelInfo?: ModelInfo } export const ModelPicker = ({ defaultModelId, - modelsKey, - configKey, - infoKey, - refreshMessageType, - refreshValues, + models, + modelIdKey, + modelInfoKey, serviceName, serviceUrl, recommendedModel, - allowCustomModel = false, + apiConfiguration, + setApiConfigurationField, + defaultModelInfo, }: ModelPickerProps) => { - const [customModelId, setCustomModelId] = useState("") - const [isCustomModel, setIsCustomModel] = useState(false) - const [open, setOpen] = useState(false) - const [value, setValue] = useState(defaultModelId) const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false) - const prevRefreshValuesRef = useRef | undefined>() - - const { apiConfiguration, [modelsKey]: models, onUpdateApiConfig, setApiConfiguration } = useExtensionState() - const modelIds = useMemo( - () => (Array.isArray(models) ? models : Object.keys(models)).sort((a, b) => a.localeCompare(b)), - [models], - ) + const modelIds = useMemo(() => Object.keys(models ?? {}).sort((a, b) => a.localeCompare(b)), [models]) const { selectedModelId, selectedModelInfo } = useMemo( () => normalizeApiConfiguration(apiConfiguration), [apiConfiguration], ) - - const onSelectCustomModel = useCallback( - (modelId: string) => { - setCustomModelId(modelId) - const modelInfo = { id: modelId } - const apiConfig = { ...apiConfiguration, [configKey]: modelId, [infoKey]: modelInfo } - setApiConfiguration(apiConfig) - onUpdateApiConfig(apiConfig) - setValue(modelId) - setOpen(false) - setIsCustomModel(false) - }, - [apiConfiguration, configKey, infoKey, onUpdateApiConfig, setApiConfiguration], - ) - const onSelect = useCallback( (modelId: string) => { - const modelInfo = Array.isArray(models) - ? { id: modelId } // For OpenAI models which are just strings - : models[modelId] // For other models that have full info objects - const apiConfig = { ...apiConfiguration, [configKey]: modelId, [infoKey]: modelInfo } - setApiConfiguration(apiConfig) - onUpdateApiConfig(apiConfig) - setValue(modelId) - setOpen(false) + const modelInfo = models?.[modelId] + setApiConfigurationField(modelIdKey, modelId) + setApiConfigurationField(modelInfoKey, modelInfo ?? defaultModelInfo) }, - [apiConfiguration, configKey, infoKey, models, onUpdateApiConfig, setApiConfiguration], + [modelIdKey, modelInfoKey, models, setApiConfigurationField, defaultModelInfo], ) - const debouncedRefreshModels = useMemo(() => { - return debounce(() => { - const message = refreshValues - ? { type: refreshMessageType, values: refreshValues } - : { type: refreshMessageType } - vscode.postMessage(message) - }, 100) - }, [refreshMessageType, refreshValues]) - - useMount(() => { - debouncedRefreshModels() - return () => debouncedRefreshModels.clear() - }) - - useEffect(() => { - if (!refreshValues) { - prevRefreshValuesRef.current = undefined - return - } - - // Check if all values in refreshValues are truthy - if (Object.values(refreshValues).some((value) => !value)) { - prevRefreshValuesRef.current = undefined - return - } - - // Compare with previous values - const prevValues = prevRefreshValuesRef.current - if (prevValues && JSON.stringify(prevValues) === JSON.stringify(refreshValues)) { - return - } - - prevRefreshValuesRef.current = refreshValues - debouncedRefreshModels() - }, [debouncedRefreshModels, refreshValues]) - - useEffect(() => setValue(selectedModelId), [selectedModelId]) - return ( <>
Model
- - - - - - - - - No model found. - - {modelIds.map((model) => ( - - {model} - - - ))} - - {allowCustomModel && ( - - { - setIsCustomModel(true) - setOpen(false) - }}> - + Add custom model - - - )} - - - - + + + + No model found. + {modelIds.map((model) => ( + + {model} + + ))} + + {selectedModelId && selectedModelInfo && ( onSelect(recommendedModel)}>{recommendedModel}. You can also try searching "free" for no-cost options currently available.

- {allowCustomModel && isCustomModel && ( -
-
-

Add Custom Model

- setCustomModelId(e.target.value)} - /> -
- - -
-
-
- )} ) } diff --git a/webview-ui/src/components/settings/OpenAiModelPicker.tsx b/webview-ui/src/components/settings/OpenAiModelPicker.tsx deleted file mode 100644 index 040da1d421..0000000000 --- a/webview-ui/src/components/settings/OpenAiModelPicker.tsx +++ /dev/null @@ -1,27 +0,0 @@ -import React from "react" -import { useExtensionState } from "../../context/ExtensionStateContext" -import { ModelPicker } from "./ModelPicker" - -const OpenAiModelPicker: React.FC = () => { - const { apiConfiguration } = useExtensionState() - - return ( - - ) -} - -export default OpenAiModelPicker diff --git a/webview-ui/src/components/settings/OpenRouterModelPicker.tsx b/webview-ui/src/components/settings/OpenRouterModelPicker.tsx deleted file mode 100644 index 9111407cd6..0000000000 --- a/webview-ui/src/components/settings/OpenRouterModelPicker.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { ModelPicker } from "./ModelPicker" -import { openRouterDefaultModelId } from "../../../../src/shared/api" - -export const OpenRouterModelPicker = () => ( - -) diff --git a/webview-ui/src/components/settings/RequestyModelPicker.tsx b/webview-ui/src/components/settings/RequestyModelPicker.tsx deleted file mode 100644 index e0759a43ba..0000000000 --- a/webview-ui/src/components/settings/RequestyModelPicker.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import { ModelPicker } from "./ModelPicker" -import { requestyDefaultModelId } from "../../../../src/shared/api" -import { useExtensionState } from "@/context/ExtensionStateContext" - -export const RequestyModelPicker = () => { - const { apiConfiguration } = useExtensionState() - return ( - - ) -} diff --git a/webview-ui/src/components/settings/SettingsView.tsx b/webview-ui/src/components/settings/SettingsView.tsx index 495bf49bd7..2c526e6ef1 100644 --- a/webview-ui/src/components/settings/SettingsView.tsx +++ b/webview-ui/src/components/settings/SettingsView.tsx @@ -100,11 +100,14 @@ const SettingsView = forwardRef(({ onDone }, const setApiConfigurationField = useCallback( (field: K, value: ApiConfiguration[K]) => { + console.trace("setApiConfigurationField", field, value) setCachedState((prevState) => { if (prevState.apiConfiguration?.[field] === value) { return prevState } + console.trace("setApiConfigurationField.setChangeDetected", field, value) setChangeDetected(true) + return { ...prevState, apiConfiguration: { @@ -343,7 +346,7 @@ const SettingsView = forwardRef(({ onDone }, /> ( - -) diff --git a/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx b/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx index 4e7c67c187..4045547176 100644 --- a/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx +++ b/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx @@ -3,7 +3,7 @@ import { screen, fireEvent, render } from "@testing-library/react" import { act } from "react" import { ModelPicker } from "../ModelPicker" -import { useExtensionState } from "../../../context/ExtensionStateContext" +import { glamaDefaultModelInfo } from "../../../../../src/shared/api" jest.mock("../../../context/ExtensionStateContext", () => ({ useExtensionState: jest.fn(), @@ -20,36 +20,30 @@ global.ResizeObserver = MockResizeObserver Element.prototype.scrollIntoView = jest.fn() describe("ModelPicker", () => { - const mockOnUpdateApiConfig = jest.fn() - const mockSetApiConfiguration = jest.fn() + const mockSetApiConfigurationField = jest.fn() + const mockModels = { + model1: { name: "Model 1", description: "Test model 1", ...glamaDefaultModelInfo }, + model2: { name: "Model 2", description: "Test model 2", ...glamaDefaultModelInfo }, + } const defaultProps = { + apiConfiguration: {}, defaultModelId: "model1", - modelsKey: "glamaModels" as const, - configKey: "glamaModelId" as const, - infoKey: "glamaModelInfo" as const, - refreshMessageType: "refreshGlamaModels" as const, + defaultModelInfo: glamaDefaultModelInfo, + modelIdKey: "glamaModelId" as const, + modelInfoKey: "glamaModelInfo" as const, serviceName: "Test Service", serviceUrl: "https://test.service", recommendedModel: "recommended-model", - } - - const mockModels = { - model1: { name: "Model 1", description: "Test model 1" }, - model2: { name: "Model 2", description: "Test model 2" }, + models: mockModels, + setApiConfigurationField: mockSetApiConfigurationField, } beforeEach(() => { jest.clearAllMocks() - ;(useExtensionState as jest.Mock).mockReturnValue({ - apiConfiguration: {}, - setApiConfiguration: mockSetApiConfiguration, - glamaModels: mockModels, - onUpdateApiConfig: mockOnUpdateApiConfig, - }) }) - it("calls onUpdateApiConfig when a model is selected", async () => { + it("calls setApiConfigurationField when a model is selected", async () => { await act(async () => { render() }) @@ -67,20 +61,12 @@ describe("ModelPicker", () => { await act(async () => { // Find and click the model item by its value. - const modelItem = screen.getByRole("option", { name: "model2" }) - fireEvent.click(modelItem) + const modelItem = screen.getByTestId("model-input") + fireEvent.input(modelItem, { target: { value: "model2" } }) }) // Verify the API config was updated. - expect(mockSetApiConfiguration).toHaveBeenCalledWith({ - glamaModelId: "model2", - glamaModelInfo: mockModels["model2"], - }) - - // Verify onUpdateApiConfig was called with the new config. - expect(mockOnUpdateApiConfig).toHaveBeenCalledWith({ - glamaModelId: "model2", - glamaModelInfo: mockModels["model2"], - }) + expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, "model2") + expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelInfoKey, mockModels.model2) }) }) diff --git a/webview-ui/src/components/ui/combobox-primitive.tsx b/webview-ui/src/components/ui/combobox-primitive.tsx new file mode 100644 index 0000000000..13bad87aba --- /dev/null +++ b/webview-ui/src/components/ui/combobox-primitive.tsx @@ -0,0 +1,522 @@ +/* eslint-disable react/jsx-pascal-case */ +"use client" + +import * as React from "react" +import { composeEventHandlers } from "@radix-ui/primitive" +import { useComposedRefs } from "@radix-ui/react-compose-refs" +import * as PopoverPrimitive from "@radix-ui/react-popover" +import { Primitive } from "@radix-ui/react-primitive" +import * as RovingFocusGroupPrimitive from "@radix-ui/react-roving-focus" +import { useControllableState } from "@radix-ui/react-use-controllable-state" +import { Command as CommandPrimitive } from "cmdk" + +export type ComboboxContextProps = { + inputValue: string + onInputValueChange: (inputValue: string, reason: "inputChange" | "itemSelect" | "clearClick") => void + onInputBlur?: (e: React.FocusEvent) => void + open: boolean + onOpenChange: (open: boolean) => void + currentTabStopId: string | null + onCurrentTabStopIdChange: (currentTabStopId: string | null) => void + inputRef: React.RefObject + tagGroupRef: React.RefObject> + disabled?: boolean + required?: boolean +} & ( + | Required> + | Required> +) + +const ComboboxContext = React.createContext({ + type: "single", + value: "", + onValueChange: () => {}, + inputValue: "", + onInputValueChange: () => {}, + onInputBlur: () => {}, + open: false, + onOpenChange: () => {}, + currentTabStopId: null, + onCurrentTabStopIdChange: () => {}, + inputRef: { current: null }, + tagGroupRef: { current: null }, + disabled: false, + required: false, +}) + +export const useComboboxContext = () => React.useContext(ComboboxContext) + +export type ComboboxType = "single" | "multiple" + +export interface ComboboxBaseProps + extends React.ComponentProps, + Omit, "value" | "defaultValue" | "onValueChange"> { + type?: ComboboxType | undefined + inputValue?: string + defaultInputValue?: string + onInputValueChange?: (inputValue: string, reason: "inputChange" | "itemSelect" | "clearClick") => void + onInputBlur?: (e: React.FocusEvent) => void + disabled?: boolean + required?: boolean +} + +export type ComboboxValue = T extends "single" + ? string + : T extends "multiple" + ? string[] + : never + +export interface ComboboxSingleProps { + type: "single" + value?: string + defaultValue?: string + onValueChange?: (value: string) => void +} + +export interface ComboboxMultipleProps { + type: "multiple" + value?: string[] + defaultValue?: string[] + onValueChange?: (value: string[]) => void +} + +export type ComboboxProps = ComboboxBaseProps & (ComboboxSingleProps | ComboboxMultipleProps) + +export const Combobox = React.forwardRef( + ( + { + type = "single" as T, + open: openProp, + onOpenChange, + defaultOpen, + modal, + children, + value: valueProp, + defaultValue, + onValueChange, + inputValue: inputValueProp, + defaultInputValue, + onInputValueChange, + onInputBlur, + disabled, + required, + ...props + }: ComboboxProps, + ref: React.ForwardedRef>, + ) => { + const [value = type === "multiple" ? [] : "", setValue] = useControllableState>({ + prop: valueProp as ComboboxValue, + defaultProp: defaultValue as ComboboxValue, + onChange: onValueChange as (value: ComboboxValue) => void, + }) + const [inputValue = "", setInputValue] = useControllableState({ + prop: inputValueProp, + defaultProp: defaultInputValue, + }) + const [open = false, setOpen] = useControllableState({ + prop: openProp, + defaultProp: defaultOpen, + onChange: onOpenChange, + }) + const [currentTabStopId, setCurrentTabStopId] = React.useState(null) + const inputRef = React.useRef(null) + const tagGroupRef = React.useRef>(null) + + const handleInputValueChange: ComboboxContextProps["onInputValueChange"] = React.useCallback( + (inputValue, reason) => { + setInputValue(inputValue) + onInputValueChange?.(inputValue, reason) + }, + [setInputValue, onInputValueChange], + ) + + return ( + + + + {children} + {!open && + + + ) + }, +) +Combobox.displayName = "Combobox" + +export const ComboboxTagGroup = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>((props, ref) => { + const { currentTabStopId, onCurrentTabStopIdChange, tagGroupRef, type } = useComboboxContext() + + if (type !== "multiple") { + throw new Error(' should only be used when type is "multiple"') + } + + const composedRefs = useComposedRefs(ref, tagGroupRef) + + return ( + onCurrentTabStopIdChange(null)} + {...props} + /> + ) +}) +ComboboxTagGroup.displayName = "ComboboxTagGroup" + +export interface ComboboxTagGroupItemProps + extends React.ComponentPropsWithoutRef { + value: string + disabled?: boolean +} + +const ComboboxTagGroupItemContext = React.createContext>({ + value: "", + disabled: false, +}) + +const useComboboxTagGroupItemContext = () => React.useContext(ComboboxTagGroupItemContext) + +export const ComboboxTagGroupItem = React.forwardRef< + React.ElementRef, + ComboboxTagGroupItemProps +>(({ onClick, onKeyDown, value: valueProp, disabled, ...props }, ref) => { + const { value, onValueChange, inputRef, currentTabStopId, type } = useComboboxContext() + + if (type !== "multiple") { + throw new Error(' should only be used when type is "multiple"') + } + + const lastItemValue = value.at(-1) + + return ( + + { + if (event.key === "Escape") { + inputRef.current?.focus() + } + if (event.key === "ArrowUp" || event.key === "ArrowDown") { + event.preventDefault() + inputRef.current?.focus() + } + if (event.key === "ArrowRight" && currentTabStopId === lastItemValue) { + inputRef.current?.focus() + } + if (event.key === "Backspace" || event.key === "Delete") { + onValueChange(value.filter((v) => v !== currentTabStopId)) + inputRef.current?.focus() + } + })} + onClick={composeEventHandlers(onClick, () => disabled && inputRef.current?.focus())} + tabStopId={valueProp} + focusable={!disabled} + data-disabled={disabled} + active={valueProp === lastItemValue} + {...props} + /> + + ) +}) +ComboboxTagGroupItem.displayName = "ComboboxTagGroupItem" + +export const ComboboxTagGroupItemRemove = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ onClick, ...props }, ref) => { + const { value, onValueChange, type } = useComboboxContext() + + if (type !== "multiple") { + throw new Error(' should only be used when type is "multiple"') + } + + const { value: valueProp, disabled } = useComboboxTagGroupItemContext() + + return ( + onValueChange(value.filter((v) => v !== valueProp)))} + {...props} + /> + ) +}) +ComboboxTagGroupItemRemove.displayName = "ComboboxTagGroupItemRemove" + +export const ComboboxInput = React.forwardRef< + React.ElementRef, + Omit, "value" | "onValueChange"> +>(({ onKeyDown, onMouseDown, onFocus, onBlur, ...props }, ref) => { + const { + type, + inputValue, + onInputValueChange, + onInputBlur, + open, + onOpenChange, + value, + onValueChange, + inputRef, + disabled, + required, + tagGroupRef, + } = useComboboxContext() + + const composedRefs = useComposedRefs(ref, inputRef) + + return ( + { + if (!open) { + onOpenChange(true) + } + // Schedule input value change to the next tick. + setTimeout(() => onInputValueChange(search, "inputChange")) + if (!search && type === "single") { + onValueChange("") + } + }} + onKeyDown={composeEventHandlers(onKeyDown, (event) => { + if (event.key === "ArrowUp" || event.key === "ArrowDown") { + if (!open) { + event.preventDefault() + onOpenChange(true) + } + } + if (type !== "multiple") { + return + } + if (event.key === "ArrowLeft" && !inputValue && value.length) { + tagGroupRef.current?.focus() + } + if (event.key === "Backspace" && !inputValue) { + onValueChange(value.slice(0, -1)) + } + })} + onMouseDown={composeEventHandlers(onMouseDown, () => onOpenChange(!!inputValue || !open))} + onFocus={composeEventHandlers(onFocus, () => onOpenChange(true))} + onBlur={composeEventHandlers(onBlur, (event) => { + if (!event.relatedTarget?.hasAttribute("cmdk-list")) { + onInputBlur?.(event) + } + })} + {...props} + /> + ) +}) +ComboboxInput.displayName = "ComboboxInput" + +export const ComboboxClear = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ onClick, ...props }, ref) => { + const { value, onValueChange, inputValue, onInputValueChange, type } = useComboboxContext() + + const isValueEmpty = type === "single" ? !value : !value.length + + return ( + { + if (type === "single") { + onValueChange("") + } else { + onValueChange([]) + } + onInputValueChange("", "clearClick") + })} + {...props} + /> + ) +}) +ComboboxClear.displayName = "ComboboxClear" + +export const ComboboxTrigger = PopoverPrimitive.Trigger + +export const ComboboxAnchor = PopoverPrimitive.Anchor + +export const ComboboxPortal = PopoverPrimitive.Portal + +export const ComboboxContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ children, onOpenAutoFocus, onInteractOutside, ...props }, ref) => ( + event.preventDefault())} + onCloseAutoFocus={composeEventHandlers(onOpenAutoFocus, (event) => event.preventDefault())} + onInteractOutside={composeEventHandlers(onInteractOutside, (event) => { + if (event.target instanceof Element && event.target.hasAttribute("cmdk-input")) { + event.preventDefault() + } + })} + {...props}> + {children} + +)) +ComboboxContent.displayName = "ComboboxContent" + +export const ComboboxEmpty = CommandPrimitive.Empty + +export const ComboboxLoading = CommandPrimitive.Loading + +export interface ComboboxItemProps extends Omit, "value"> { + value: string +} + +const ComboboxItemContext = React.createContext({ isSelected: false }) + +const useComboboxItemContext = () => React.useContext(ComboboxItemContext) + +const findComboboxItemText = (children: React.ReactNode) => { + let text = "" + + React.Children.forEach(children, (child) => { + if (text) { + return + } + + if (React.isValidElement<{ children: React.ReactNode }>(child)) { + if (child.type === ComboboxItemText) { + text = child.props.children as string + } else { + text = findComboboxItemText(child.props.children) + } + } + }) + + return text +} + +export const ComboboxItem = React.forwardRef, ComboboxItemProps>( + ({ value: valueProp, children, onMouseDown, ...props }, ref) => { + const { type, value, onValueChange, onInputValueChange, onOpenChange } = useComboboxContext() + + const inputValue = React.useMemo(() => findComboboxItemText(children), [children]) + + const isSelected = type === "single" ? value === valueProp : value.includes(valueProp) + + return ( + + event.preventDefault())} + onSelect={() => { + if (type === "multiple") { + onValueChange( + value.includes(valueProp) + ? value.filter((v) => v !== valueProp) + : [...value, valueProp], + ) + onInputValueChange("", "itemSelect") + } else { + onValueChange(valueProp) + onInputValueChange(inputValue, "itemSelect") + // Schedule open change to the next tick. + setTimeout(() => onOpenChange(false)) + } + }} + value={inputValue} + {...props}> + {children} + + + ) + }, +) +ComboboxItem.displayName = "ComboboxItem" + +export const ComboboxItemIndicator = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>((props, ref) => { + const { isSelected } = useComboboxItemContext() + + if (!isSelected) { + return null + } + + return +}) +ComboboxItemIndicator.displayName = "ComboboxItemIndicator" + +export interface ComboboxItemTextProps extends React.ComponentPropsWithoutRef { + children: string +} + +export const ComboboxItemText = (props: ComboboxItemTextProps) => +ComboboxItemText.displayName = "ComboboxItemText" + +export const ComboboxGroup = CommandPrimitive.Group + +export const ComboboxSeparator = CommandPrimitive.Separator + +const Root = Combobox +const TagGroup = ComboboxTagGroup +const TagGroupItem = ComboboxTagGroupItem +const TagGroupItemRemove = ComboboxTagGroupItemRemove +const Input = ComboboxInput +const Clear = ComboboxClear +const Trigger = ComboboxTrigger +const Anchor = ComboboxAnchor +const Portal = ComboboxPortal +const Content = ComboboxContent +const Empty = ComboboxEmpty +const Loading = ComboboxLoading +const Item = ComboboxItem +const ItemIndicator = ComboboxItemIndicator +const ItemText = ComboboxItemText +const Group = ComboboxGroup +const Separator = ComboboxSeparator + +export { + Root, + TagGroup, + TagGroupItem, + TagGroupItemRemove, + Input, + Clear, + Trigger, + Anchor, + Portal, + Content, + Empty, + Loading, + Item, + ItemIndicator, + ItemText, + Group, + Separator, +} diff --git a/webview-ui/src/components/ui/combobox.tsx b/webview-ui/src/components/ui/combobox.tsx new file mode 100644 index 0000000000..24b2f7be1f --- /dev/null +++ b/webview-ui/src/components/ui/combobox.tsx @@ -0,0 +1,177 @@ +"use client" + +import * as React from "react" +import { Slottable } from "@radix-ui/react-slot" +import { cva } from "class-variance-authority" +import { Check, ChevronsUpDown, Loader, X } from "lucide-react" + +import { cn } from "@/lib/utils" +import * as ComboboxPrimitive from "@/components/ui/combobox-primitive" +import { badgeVariants } from "@/components/ui/badge" +// import * as ComboboxPrimitive from "@/registry/default/ui/combobox-primitive" +import { + InputBase, + InputBaseAdornmentButton, + InputBaseControl, + InputBaseFlexWrapper, + InputBaseInput, +} from "@/components/ui/input-base" + +export const Combobox = ComboboxPrimitive.Root + +const ComboboxInputBase = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ children, ...props }, ref) => ( + + + {children} + + + + + + + + + + + + +)) +ComboboxInputBase.displayName = "ComboboxInputBase" + +export const ComboboxInput = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>((props, ref) => ( + + + + + + + +)) +ComboboxInput.displayName = "ComboboxInput" + +export const ComboboxTagsInput = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ children, ...props }, ref) => ( + + + + {children} + + + + + + + + +)) +ComboboxTagsInput.displayName = "ComboboxTagsInput" + +export const ComboboxTag = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ children, className, ...props }, ref) => ( + + {children} + + + Remove + + +)) +ComboboxTag.displayName = "ComboboxTag" + +export const ComboboxContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, align = "start", alignOffset = 0, ...props }, ref) => ( + + + +)) +ComboboxContent.displayName = "ComboboxContent" + +export const ComboboxEmpty = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +ComboboxEmpty.displayName = "ComboboxEmpty" + +export const ComboboxLoading = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + +)) +ComboboxLoading.displayName = "ComboboxLoading" + +export const ComboboxGroup = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +ComboboxGroup.displayName = "ComboboxGroup" + +const ComboboxSeparator = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +ComboboxSeparator.displayName = "ComboboxSeparator" + +export const comboboxItemStyle = cva( + "relative flex w-full cursor-pointer select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none data-[disabled=true]:pointer-events-none data-[selected=true]:bg-accent data-[selected=true]:text-vscode-dropdown-foreground data-[disabled=true]:opacity-50", +) + +export const ComboboxItem = React.forwardRef< + React.ElementRef, + Omit, "children"> & + Pick, "children"> +>(({ className, children, ...props }, ref) => ( + + {children} + + + + +)) +ComboboxItem.displayName = "ComboboxItem" diff --git a/webview-ui/src/components/ui/input-base.tsx b/webview-ui/src/components/ui/input-base.tsx new file mode 100644 index 0000000000..9dbda6eb13 --- /dev/null +++ b/webview-ui/src/components/ui/input-base.tsx @@ -0,0 +1,157 @@ +/* eslint-disable react/jsx-no-comment-textnodes */ +/* eslint-disable react/jsx-pascal-case */ +"use client" + +import * as React from "react" +import { composeEventHandlers } from "@radix-ui/primitive" +import { composeRefs } from "@radix-ui/react-compose-refs" +import { Primitive } from "@radix-ui/react-primitive" +import { Slot } from "@radix-ui/react-slot" + +import { cn } from "@/lib/utils" +import { Button } from "./button" + +export type InputBaseContextProps = Pick & { + controlRef: React.RefObject + onFocusedChange: (focused: boolean) => void +} + +const InputBaseContext = React.createContext({ + autoFocus: false, + controlRef: { current: null }, + disabled: false, + onFocusedChange: () => {}, +}) + +const useInputBaseContext = () => React.useContext(InputBaseContext) + +export interface InputBaseProps extends React.ComponentPropsWithoutRef { + autoFocus?: boolean + disabled?: boolean +} + +export const InputBase = React.forwardRef, InputBaseProps>( + ({ autoFocus, disabled, className, onClick, ...props }, ref) => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const [focused, setFocused] = React.useState(false) + + const controlRef = React.useRef(null) + + return ( + + { + // Based on MUI's implementation. + // https://github.com/mui/material-ui/blob/master/packages/mui-material/src/InputBase/InputBase.js#L458~L460 + if (controlRef.current && event.currentTarget === event.target) { + controlRef.current.focus() + } + })} + className={cn( + "flex w-full text-vscode-input-foreground border border-vscode-dropdown-border bg-vscode-input-background rounded-xs px-3 py-0.5 text-base transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium file:text-foreground placeholder:text-muted-foreground focus:outline-0 focus-visible:outline-none focus-visible:border-vscode-focusBorder disabled:cursor-not-allowed disabled:opacity-50", + disabled && "cursor-not-allowed opacity-50", + className, + )} + {...props} + /> + + ) + }, +) +InputBase.displayName = "InputBase" + +export const InputBaseFlexWrapper = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +InputBaseFlexWrapper.displayName = "InputBaseFlexWrapper" + +export const InputBaseControl = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ onFocus, onBlur, ...props }, ref) => { + const { controlRef, autoFocus, disabled, onFocusedChange } = useInputBaseContext() + + return ( + onFocusedChange(true))} + onBlur={composeEventHandlers(onBlur, () => onFocusedChange(false))} + {...{ disabled }} + {...props} + /> + ) +}) +InputBaseControl.displayName = "InputBaseControl" + +export interface InputBaseAdornmentProps extends React.ComponentPropsWithoutRef<"div"> { + asChild?: boolean + disablePointerEvents?: boolean +} + +export const InputBaseAdornment = React.forwardRef, InputBaseAdornmentProps>( + ({ className, disablePointerEvents, asChild, children, ...props }, ref) => { + const Comp = asChild ? Slot : typeof children === "string" ? "p" : "div" + + const isAction = React.isValidElement(children) && children.type === InputBaseAdornmentButton + + return ( + + {children} + + ) + }, +) +InputBaseAdornment.displayName = "InputBaseAdornment" + +export const InputBaseAdornmentButton = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ type = "button", variant = "ghost", size = "icon", disabled: disabledProp, className, ...props }, ref) => { + const { disabled } = useInputBaseContext() + + return ( +