Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improvement: better sql completions #3930

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions frontend/src/components/datasources/components.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,11 @@ export const EmptyState: React.FC<{ content: string; className?: string }> = ({
</div>
);
};

export const ItemSubtext: React.FC<{ content: string }> = ({ content }) => {
return (
<span className="text-xs text-black bg-gray-200 rounded px-1">
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bg-gray-200 won't support dark mode, and i think it could be slightly lighter

maybe bg-[var(--slate-8] text-[var(--slate-9]?

{content}
</span>
);
};
46 changes: 33 additions & 13 deletions frontend/src/components/datasources/datasources.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ import {
import { PythonIcon } from "../editor/cell/code/icons";
import { PreviewSQLTable } from "@/core/functions/FunctionRegistry";
import { useAsyncData } from "@/hooks/useAsyncData";
import { DatasourceLabel, EmptyState, RotatingChevron } from "./components";
import {
DatasourceLabel,
EmptyState,
ItemSubtext,
RotatingChevron,
} from "./components";

const sortedTablesAtom = atom((get) => {
const tables = get(datasetTablesAtom);
Expand Down Expand Up @@ -186,6 +191,7 @@ export const DataSources: React.FC = () => {
>
<SchemaList
schemas={database.schemas}
defaultSchema={connection.default_schema}
engineName={connection.name}
databaseName={database.name}
hasSearch={hasSearch}
Expand Down Expand Up @@ -285,11 +291,19 @@ const DatabaseItem: React.FC<{

const SchemaList: React.FC<{
schemas: DatabaseSchema[];
defaultSchema?: string | null;
engineName: string;
databaseName: string;
hasSearch: boolean;
searchValue?: string;
}> = ({ schemas, engineName, databaseName, hasSearch, searchValue }) => {
}> = ({
schemas,
defaultSchema,
engineName,
databaseName,
hasSearch,
searchValue,
}) => {
if (schemas.length === 0) {
return <EmptyState content="No schemas available" className="pl-6" />;
}
Expand All @@ -310,6 +324,7 @@ const SchemaList: React.FC<{
key={schema.name}
databaseName={databaseName}
schema={schema}
isDefaultSchema={schema.name === defaultSchema}
hasSearch={hasSearch}
>
<TableList
Expand All @@ -319,6 +334,7 @@ const SchemaList: React.FC<{
engine: engineName,
database: databaseName,
schema: schema.name,
defaultSchema: defaultSchema,
}}
/>
</SchemaItem>
Expand All @@ -330,9 +346,10 @@ const SchemaList: React.FC<{
const SchemaItem: React.FC<{
databaseName: string;
schema: DatabaseSchema;
isDefaultSchema?: boolean;
children: React.ReactNode;
hasSearch: boolean;
}> = ({ databaseName, schema, children, hasSearch }) => {
}> = ({ databaseName, schema, isDefaultSchema, children, hasSearch }) => {
const [isExpanded, setIsExpanded] = React.useState(false);
const [isSelected, setIsSelected] = React.useState(false);
const uniqueValue = `${databaseName}:${schema.name}`;
Expand Down Expand Up @@ -361,6 +378,8 @@ const SchemaItem: React.FC<{
<span className={cn(isSelected && isExpanded && "font-semibold")}>
{schema.name}
</span>
{/* Do we want this? They could change the default by executing USE schema.. */}
{isDefaultSchema && <ItemSubtext content="default" />}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can remove for now? we can look if other SQL apps do this:
e.g. jetbrains or dbbeaver, if you can drum up screenshots (but don't need to spend too much time)

</CommandItem>
{isExpanded && children}
</>
Expand All @@ -371,6 +390,7 @@ interface SQLTableContext {
engine: string;
database: string;
schema: string;
defaultSchema?: string | null;
}

const TableList: React.FC<{
Expand Down Expand Up @@ -441,8 +461,10 @@ const DatasetTableItem: React.FC<{
maybeAddMarimoImport(autoInstantiate, createNewCell, lastFocusedCellId);
let code = "";
if (sqlTableContext) {
const { engine, schema } = sqlTableContext;
code = `_df = mo.sql(f"SELECT * FROM ${schema}.${table.name} LIMIT 100", engine=${engine})`;
const { engine, schema, defaultSchema } = sqlTableContext;
const tableName =
defaultSchema === schema ? table.name : `${schema}.${table.name}`;
code = `_df = mo.sql(f"SELECT * FROM ${tableName} LIMIT 100", engine=${engine})`;
} else {
switch (table.source_type) {
case "local":
Expand Down Expand Up @@ -638,16 +660,12 @@ const DatasetColumnItem: React.FC<{
<span>{column.name}</span>
{isPrimaryKey && (
<Tooltip content="Primary Key" delayDuration={100}>
<span className="text-xs text-black bg-gray-200 rounded px-1">
PK
</span>
<ItemSubtext content="PK" />
</Tooltip>
)}
{isIndexed && (
<Tooltip content="Indexed" delayDuration={100}>
<span className="text-xs text-black bg-gray-200 rounded px-1">
IDX
</span>
<ItemSubtext content="IDX" />
</Tooltip>
)}
</div>
Expand Down Expand Up @@ -837,8 +855,10 @@ function sqlCode(
sqlTableContext?: SQLTableContext,
) {
if (sqlTableContext) {
const { engine, schema } = sqlTableContext;
return `_df = mo.sql(f'SELECT ${column.name} FROM ${schema}.${table.name} LIMIT 100', engine=${engine})`;
const { engine, schema, defaultSchema } = sqlTableContext;
const tableName =
defaultSchema === schema ? table.name : `${schema}.${table.name}`;
return `_df = mo.sql(f"SELECT ${column.name} FROM ${tableName} LIMIT 100", engine=${engine})`;
}
return `_df = mo.sql(f'SELECT "${column.name}" FROM ${table.name} LIMIT 100')`;
}
25 changes: 8 additions & 17 deletions frontend/src/components/editor/ai/add-cell-with-ai.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import { Prec } from "@codemirror/state";
import { customPythonLanguageSupport } from "@/core/codemirror/language/python";
import { asURL } from "@/utils/url";
import { useMemo, useState } from "react";
import { datasetTablesAtom } from "@/core/datasets/state";
import { useAtom, useAtomValue } from "jotai";
import type { Completion } from "@codemirror/autocomplete";
import {
Expand All @@ -31,7 +30,7 @@ import { SQLLanguageAdapter } from "@/core/codemirror/language/sql";
import { atomWithStorage } from "jotai/utils";
import { type ResolvedTheme, useTheme } from "@/theme/useTheme";
import { getAICompletionBody, mentions } from "./completion-utils";
import { dataSourceConnectionsAtom } from "@/core/datasets/data-source-connections";
import { allTablesAtom } from "@/core/datasets/data-source-connections";

const pythonExtensions = [
customPythonLanguageSupport(),
Expand Down Expand Up @@ -230,21 +229,12 @@ export const PromptInput = ({
}: PromptInputProps) => {
const handleSubmit = onSubmit;
const handleEscape = onClose;
const tables = useAtomValue(datasetTablesAtom);
const datasources = useAtomValue(dataSourceConnectionsAtom);
const tablesMap = useAtomValue(allTablesAtom);

const extensions = useMemo(() => {
const connections = [...datasources.connectionsMap.values()];
const allTables = [
...tables,
...connections.flatMap((c) =>
c.databases.flatMap((d) => d.schemas.flatMap((s) => s.tables)),
),
];

const completions = allTables.map(
(table): Completion => ({
label: `@${table.name}`,
const completions = [...tablesMap.entries()].map(
([tableName, table]): Completion => ({
label: `@${tableName}`,
detail: table.source,
info: () => {
const shape = [
Expand Down Expand Up @@ -298,7 +288,8 @@ export const PromptInput = ({
}),
);

const matchBeforeRegexes = [/@(\w+)?/]; // Trigger autocompletion for text that begins with @
// Trigger autocompletion for text that begins with @, can contain dots
const matchBeforeRegexes = [/@([\w.]+)?/];
if (additionalCompletions) {
matchBeforeRegexes.push(additionalCompletions.triggerCompletionRegex);
}
Expand Down Expand Up @@ -378,7 +369,7 @@ export const PromptInput = ({
},
]),
];
}, [tables, datasources, additionalCompletions, handleSubmit, handleEscape]);
}, [tablesMap, additionalCompletions, handleSubmit, handleEscape]);

return (
<ReactCodeMirror
Expand Down
18 changes: 3 additions & 15 deletions frontend/src/components/editor/ai/completion-utils.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
/* Copyright 2024 Marimo. All rights reserved. */
import { getCodes } from "@/core/codemirror/copilot/getCodes";
import { dataSourceConnectionsAtom } from "@/core/datasets/data-source-connections";
import { datasetTablesAtom } from "@/core/datasets/state";
import { allTablesAtom } from "@/core/datasets/data-source-connections";
import type { DataTable } from "@/core/kernel/messages";
import type { AiCompletionRequest } from "@/core/network/types";
import { store } from "@/core/state/jotai";
import { Logger } from "@/utils/Logger";
import { Maps } from "@/utils/maps";
import {
autocompletion,
type Completion,
Expand Down Expand Up @@ -43,25 +41,15 @@ export function getAICompletionBody(
* Datasets are referenced with @<dataset_name> in the input.
*/
function extractDatasets(input: string): DataTable[] {
const datasets = store.get(datasetTablesAtom);
const connectionsMap = store.get(dataSourceConnectionsAtom).connectionsMap;
const connections = [...connectionsMap.values()];
const allTables = [
...datasets,
...connections.flatMap((c) =>
c.databases.flatMap((d) => d.schemas.flatMap((s) => s.tables)),
),
];
// TODO: This does not handle duplicates table names.
const existingDatasets = Maps.keyBy(allTables, (dataset) => dataset.name);
const allTables = store.get(allTablesAtom);

// Extract dataset mentions from the input
const mentionedDatasets = input.match(/@([\w.]+)/g) || [];

// Filter to only include datasets that exist
return mentionedDatasets
.map((mention) => mention.slice(1))
.map((name) => existingDatasets.get(name))
.map((name) => allTables.get(name))
.filter(Boolean);
}

Expand Down
Loading
Loading