diff --git a/.github/workflows/ghcr-build.yml b/.github/workflows/ghcr-build.yml
index 7948291d6ec3..09cfffaaf9ff 100644
--- a/.github/workflows/ghcr-build.yml
+++ b/.github/workflows/ghcr-build.yml
@@ -68,6 +68,9 @@ jobs:
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v3
+ - name: "Set up docker layer caching"
+ uses: satackey/action-docker-layer-caching@v0.0.11
+ continue-on-error: true
- name: Build and push app image
if: "!github.event.pull_request.head.repo.fork"
run: |
diff --git a/.github/workflows/py-unit-tests.yml b/.github/workflows/py-unit-tests.yml
index 6e624904a82f..28a14095d7c1 100644
--- a/.github/workflows/py-unit-tests.yml
+++ b/.github/workflows/py-unit-tests.yml
@@ -42,7 +42,7 @@ jobs:
- name: Build Environment
run: make build
- name: Run Tests
- run: poetry run pytest --forked --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_memory.py
+ run: poetry run pytest --forked -n auto --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_memory.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
env:
diff --git a/config.template.toml b/config.template.toml
index 9d84c5a7ff34..6f626e6bee90 100644
--- a/config.template.toml
+++ b/config.template.toml
@@ -217,6 +217,9 @@ llm_config = 'gpt3'
# Use host network
#use_host_network = false
+# Extra build arguments passed to the runtime image builder
+#runtime_extra_build_args = ["--network=host", "--add-host=host.docker.internal:host-gateway"]
+
# Enable auto linting after editing
#enable_auto_lint = false
diff --git a/docs/modules/usage/how-to/headless-mode.md b/docs/modules/usage/how-to/headless-mode.md
index dfd4dd5e3e14..ff5e622de6a4 100644
--- a/docs/modules/usage/how-to/headless-mode.md
+++ b/docs/modules/usage/how-to/headless-mode.md
@@ -55,5 +55,5 @@ docker run -it \
--add-host host.docker.internal:host-gateway \
--name openhands-app-$(date +%Y%m%d%H%M%S) \
docker.all-hands.dev/all-hands-ai/openhands:0.15 \
- python -m openhands.core.main -t "write a bash script that prints hi"
+ python -m openhands.core.main -t "write a bash script that prints hi" --no-auto-continue
```
diff --git a/frontend/__tests__/components/chat/expandable-message.test.tsx b/frontend/__tests__/components/chat/expandable-message.test.tsx
new file mode 100644
index 000000000000..8eab988339de
--- /dev/null
+++ b/frontend/__tests__/components/chat/expandable-message.test.tsx
@@ -0,0 +1,60 @@
+import { describe, expect, it } from "vitest";
+import { screen } from "@testing-library/react";
+import { renderWithProviders } from "test-utils";
+import { ExpandableMessage } from "#/components/features/chat/expandable-message";
+
+describe("ExpandableMessage", () => {
+ it("should render with neutral border for non-action messages", () => {
+ renderWithProviders();
+ const element = screen.getByText("Hello");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
+ });
+
+ it("should render with neutral border for error messages", () => {
+ renderWithProviders();
+ const element = screen.getByText("Error occurred");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
+ });
+
+ it("should render with success icon for successful action messages", () => {
+ renderWithProviders(
+
+ );
+ const element = screen.getByText("Command executed successfully");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ const icon = screen.getByTestId("status-icon");
+ expect(icon).toHaveClass("fill-success");
+ });
+
+ it("should render with error icon for failed action messages", () => {
+ renderWithProviders(
+
+ );
+ const element = screen.getByText("Command failed");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ const icon = screen.getByTestId("status-icon");
+ expect(icon).toHaveClass("fill-danger");
+ });
+
+ it("should render with neutral border and no icon for action messages without success prop", () => {
+ renderWithProviders();
+ const element = screen.getByText("Running command");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
+ });
+});
diff --git a/frontend/__tests__/components/interactive-chat-box.test.tsx b/frontend/__tests__/components/interactive-chat-box.test.tsx
index fa0d3a1b8e30..fe6ba329763b 100644
--- a/frontend/__tests__/components/interactive-chat-box.test.tsx
+++ b/frontend/__tests__/components/interactive-chat-box.test.tsx
@@ -1,4 +1,4 @@
-import { render, screen, within } from "@testing-library/react";
+import { render, screen, within, fireEvent } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { afterEach, beforeAll, describe, expect, it, vi } from "vitest";
import { InteractiveChatBox } from "#/components/features/chat/interactive-chat-box";
@@ -131,4 +131,60 @@ describe("InteractiveChatBox", () => {
await user.click(stopButton);
expect(onStopMock).toHaveBeenCalledOnce();
});
+
+ it("should handle image upload and message submission correctly", async () => {
+ const user = userEvent.setup();
+ const onSubmit = vi.fn();
+ const onStop = vi.fn();
+ const onChange = vi.fn();
+
+ const { rerender } = render(
+
+ );
+
+ // Upload an image via the upload button - this should NOT clear the text input
+ const file = new File(["dummy content"], "test.png", { type: "image/png" });
+ const input = screen.getByTestId("upload-image-input");
+ await user.upload(input, file);
+
+ // Verify text input was not cleared
+ expect(screen.getByRole("textbox")).toHaveValue("test message");
+ expect(onChange).not.toHaveBeenCalledWith("");
+
+ // Submit the message with image
+ const submitButton = screen.getByRole("button", { name: "Send" });
+ await user.click(submitButton);
+
+ // Verify onSubmit was called with the message and image
+ expect(onSubmit).toHaveBeenCalledWith("test message", [file]);
+
+ // Verify onChange was called to clear the text input
+ expect(onChange).toHaveBeenCalledWith("");
+
+ // Simulate parent component updating the value prop
+ rerender(
+
+ );
+
+ // Verify the text input was cleared
+ expect(screen.getByRole("textbox")).toHaveValue("");
+
+ // Upload another image - this should NOT clear the text input
+ onChange.mockClear();
+ await user.upload(input, file);
+
+ // Verify text input is still empty and onChange was not called
+ expect(screen.getByRole("textbox")).toHaveValue("");
+ expect(onChange).not.toHaveBeenCalled();
+ });
});
diff --git a/frontend/src/components/features/chat/chat-input.tsx b/frontend/src/components/features/chat/chat-input.tsx
index b69bcadfa50c..d999d8d1afb9 100644
--- a/frontend/src/components/features/chat/chat-input.tsx
+++ b/frontend/src/components/features/chat/chat-input.tsx
@@ -84,9 +84,13 @@ export function ChatInput({
const handleSubmitMessage = () => {
const trimmedValue = textareaRef.current?.value.trim();
- if (trimmedValue) {
- onSubmit(trimmedValue);
- textareaRef.current!.value = "";
+ if (value || (trimmedValue && !value)) {
+ onSubmit(value || trimmedValue || "");
+ if (value) {
+ onChange?.("");
+ } else if (textareaRef.current) {
+ textareaRef.current.value = "";
+ }
}
};
diff --git a/frontend/src/components/features/chat/expandable-message.tsx b/frontend/src/components/features/chat/expandable-message.tsx
index f42b3f0b13ad..6ebcaa3aeed5 100644
--- a/frontend/src/components/features/chat/expandable-message.tsx
+++ b/frontend/src/components/features/chat/expandable-message.tsx
@@ -6,17 +6,21 @@ import { code } from "../markdown/code";
import { ol, ul } from "../markdown/list";
import ArrowUp from "#/icons/angle-up-solid.svg?react";
import ArrowDown from "#/icons/angle-down-solid.svg?react";
+import CheckCircle from "#/icons/check-circle-solid.svg?react";
+import XCircle from "#/icons/x-circle-solid.svg?react";
interface ExpandableMessageProps {
id?: string;
message: string;
type: string;
+ success?: boolean;
}
export function ExpandableMessage({
id,
message,
type,
+ success,
}: ExpandableMessageProps) {
const { t, i18n } = useTranslation();
const [showDetails, setShowDetails] = useState(true);
@@ -31,22 +35,14 @@ export function ExpandableMessage({
}
}, [id, message, i18n.language]);
- const border = type === "error" ? "border-danger" : "border-neutral-300";
- const textColor = type === "error" ? "text-danger" : "text-neutral-300";
- let arrowClasses = "h-4 w-4 ml-2 inline";
- if (type === "error") {
- arrowClasses += " fill-danger";
- } else {
- arrowClasses += " fill-neutral-300";
- }
+ const arrowClasses = "h-4 w-4 ml-2 inline fill-neutral-300";
+ const statusIconClasses = "h-4 w-4 ml-2 inline";
return (
-
+
{headline && (
-
+
{headline}
+ {type === "action" && success !== undefined && (
+
+ {success ? (
+
+ ) : (
+
+ )}
+
+ )}
);
}
diff --git a/frontend/src/components/features/chat/interactive-chat-box.tsx b/frontend/src/components/features/chat/interactive-chat-box.tsx
index e96339adf0a7..09dcf84b32d6 100644
--- a/frontend/src/components/features/chat/interactive-chat-box.tsx
+++ b/frontend/src/components/features/chat/interactive-chat-box.tsx
@@ -38,6 +38,9 @@ export function InteractiveChatBox({
const handleSubmit = (message: string) => {
onSubmit(message, images);
setImages([]);
+ if (message) {
+ onChange?.("");
+ }
};
return (
diff --git a/frontend/src/components/features/chat/messages.tsx b/frontend/src/components/features/chat/messages.tsx
index 8b0d703b7553..e1bd34637472 100644
--- a/frontend/src/components/features/chat/messages.tsx
+++ b/frontend/src/components/features/chat/messages.tsx
@@ -20,6 +20,7 @@ export function Messages({
type={message.type}
id={message.translationID}
message={message.content}
+ success={message.success}
/>
);
}
diff --git a/frontend/src/icons/check-circle-solid.svg b/frontend/src/icons/check-circle-solid.svg
new file mode 100644
index 000000000000..a07362b4ab3f
--- /dev/null
+++ b/frontend/src/icons/check-circle-solid.svg
@@ -0,0 +1,4 @@
+
+
diff --git a/frontend/src/icons/x-circle-solid.svg b/frontend/src/icons/x-circle-solid.svg
new file mode 100644
index 000000000000..f673bbf0b1e5
--- /dev/null
+++ b/frontend/src/icons/x-circle-solid.svg
@@ -0,0 +1,4 @@
+
+
diff --git a/frontend/src/message.d.ts b/frontend/src/message.d.ts
index 5b70e39c8f56..65bd7e0cb193 100644
--- a/frontend/src/message.d.ts
+++ b/frontend/src/message.d.ts
@@ -4,6 +4,7 @@ type Message = {
timestamp: string;
imageUrls?: string[];
type?: "thought" | "error" | "action";
+ success?: boolean;
pending?: boolean;
translationID?: string;
eventID?: number;
diff --git a/frontend/src/routes/_oh.app/hooks/use-ws-status-change.ts b/frontend/src/routes/_oh.app/hooks/use-ws-status-change.ts
index 6c465e0d7088..789af7dadfec 100644
--- a/frontend/src/routes/_oh.app/hooks/use-ws-status-change.ts
+++ b/frontend/src/routes/_oh.app/hooks/use-ws-status-change.ts
@@ -52,6 +52,7 @@ export const useWSStatusChange = () => {
if (gitHubToken && selectedRepository) {
dispatch(clearSelectedRepository());
+ additionalInfo = `Repository ${selectedRepository} has been cloned to /workspace. Please check the /workspace for files.`;
} else if (importedProjectZip) {
// if there's an uploaded project zip, add it to the chat
additionalInfo =
diff --git a/frontend/src/state/chat-slice.ts b/frontend/src/state/chat-slice.ts
index 1a95d31b4429..a9236b0f0afb 100644
--- a/frontend/src/state/chat-slice.ts
+++ b/frontend/src/state/chat-slice.ts
@@ -1,14 +1,25 @@
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
import { ActionSecurityRisk } from "#/state/security-analyzer-slice";
-import { OpenHandsObservation } from "#/types/core/observations";
+import {
+ OpenHandsObservation,
+ CommandObservation,
+ IPythonObservation,
+} from "#/types/core/observations";
import { OpenHandsAction } from "#/types/core/actions";
+import { OpenHandsEventType } from "#/types/core/base";
type SliceState = { messages: Message[] };
const MAX_CONTENT_LENGTH = 1000;
-const HANDLED_ACTIONS = ["run", "run_ipython", "write", "read", "browse"];
+const HANDLED_ACTIONS: OpenHandsEventType[] = [
+ "run",
+ "run_ipython",
+ "write",
+ "read",
+ "browse",
+];
function getRiskText(risk: ActionSecurityRisk) {
switch (risk) {
@@ -131,6 +142,18 @@ export const chatSlice = createSlice({
return;
}
causeMessage.translationID = translationID;
+ // Set success property based on observation type
+ if (observationID === "run") {
+ const commandObs = observation.payload as CommandObservation;
+ causeMessage.success = commandObs.extras.exit_code === 0;
+ } else if (observationID === "run_ipython") {
+ // For IPython, treat the run as successful unless the observation's message contains "error" (case-insensitive)
+ const ipythonObs = observation.payload as IPythonObservation;
+ causeMessage.success = !ipythonObs.message
+ .toLowerCase()
+ .includes("error");
+ }
+
if (observationID === "run" || observationID === "run_ipython") {
let { content } = observation.payload;
if (content.length > MAX_CONTENT_LENGTH) {
diff --git a/frontend/src/types/core/observations.ts b/frontend/src/types/core/observations.ts
index 0b95099a8384..7ddc3f05dd94 100644
--- a/frontend/src/types/core/observations.ts
+++ b/frontend/src/types/core/observations.ts
@@ -52,6 +52,21 @@ export interface BrowseObservation extends OpenHandsObservationEvent<"browse"> {
};
}
+export interface WriteObservation extends OpenHandsObservationEvent<"write"> {
+ source: "agent";
+ extras: {
+ path: string;
+ content: string;
+ };
+}
+
+export interface ReadObservation extends OpenHandsObservationEvent<"read"> {
+ source: "agent";
+ extras: {
+ path: string;
+ };
+}
+
export interface ErrorObservation extends OpenHandsObservationEvent<"error"> {
source: "user";
extras: {
@@ -65,4 +80,6 @@ export type OpenHandsObservation =
| IPythonObservation
| DelegateObservation
| BrowseObservation
+ | WriteObservation
+ | ReadObservation
| ErrorObservation;
diff --git a/frontend/tailwind.config.js b/frontend/tailwind.config.js
index 1e665d23fdbb..1a57ebcd8bce 100644
--- a/frontend/tailwind.config.js
+++ b/frontend/tailwind.config.js
@@ -14,6 +14,7 @@ export default {
'root-secondary': '#262626',
'hyperlink': '#007AFF',
'danger': '#EF3744',
+ 'success': '#4CAF50',
},
},
},
diff --git a/frontend/test-utils.tsx b/frontend/test-utils.tsx
index 4b336602fbf6..6739e3be6e15 100644
--- a/frontend/test-utils.tsx
+++ b/frontend/test-utils.tsx
@@ -6,10 +6,31 @@ import { configureStore } from "@reduxjs/toolkit";
// eslint-disable-next-line import/no-extraneous-dependencies
import { RenderOptions, render } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
+import { I18nextProvider } from "react-i18next";
+import i18n from "i18next";
+import { initReactI18next } from "react-i18next";
import { AppStore, RootState, rootReducer } from "./src/store";
import { AuthProvider } from "#/context/auth-context";
import { UserPrefsProvider } from "#/context/user-prefs-context";
+// Initialize i18n for tests
+i18n
+ .use(initReactI18next)
+ .init({
+ lng: "en",
+ fallbackLng: "en",
+ ns: ["translation"],
+ defaultNS: "translation",
+ resources: {
+ en: {
+ translation: {},
+ },
+ },
+ interpolation: {
+ escapeValue: false,
+ },
+ });
+
const setupStore = (preloadedState?: Partial
): AppStore =>
configureStore({
reducer: rootReducer,
@@ -40,7 +61,9 @@ export function renderWithProviders(
- {children}
+
+ {children}
+
diff --git a/frontend/vitest.setup.ts b/frontend/vitest.setup.ts
index 105337e75eba..e9a89c8677f6 100644
--- a/frontend/vitest.setup.ts
+++ b/frontend/vitest.setup.ts
@@ -12,7 +12,13 @@ HTMLElement.prototype.scrollTo = vi.fn();
// Mock the i18n provider
vi.mock("react-i18next", async (importOriginal) => ({
...(await importOriginal()),
- useTranslation: () => ({ t: (key: string) => key }),
+ useTranslation: () => ({
+ t: (key: string) => key,
+ i18n: {
+ language: "en",
+ exists: () => false,
+ },
+ }),
}));
// Mock requests during tests
diff --git a/openhands/agenthub/codeact_agent/README.md b/openhands/agenthub/codeact_agent/README.md
index 9a5093820e85..0e15939cdfb8 100644
--- a/openhands/agenthub/codeact_agent/README.md
+++ b/openhands/agenthub/codeact_agent/README.md
@@ -110,4 +110,4 @@ The agent is implemented in two main files:
2. `function_calling.py`: Tool definitions and function calling interface with:
- Tool parameter specifications
- Tool descriptions and examples
- - Function calling response parsing
\ No newline at end of file
+ - Function calling response parsing
diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py
index 8e2380d6552f..b940a2e35eb0 100644
--- a/openhands/core/config/sandbox_config.py
+++ b/openhands/core/config/sandbox_config.py
@@ -51,6 +51,7 @@ class SandboxConfig:
False # once enabled, OpenHands would lint files after editing
)
use_host_network: bool = True
+ runtime_extra_build_args: list[str] | None = None
initialize_plugins: bool = True
force_rebuild_runtime: bool = False
runtime_extra_deps: str | None = None
diff --git a/openhands/events/observation/commands.py b/openhands/events/observation/commands.py
index f3f94adf2eac..dccd9b6c9ccd 100644
--- a/openhands/events/observation/commands.py
+++ b/openhands/events/observation/commands.py
@@ -23,6 +23,10 @@ def error(self) -> bool:
def message(self) -> str:
return f'Command `{self.command}` executed with exit code {self.exit_code}.'
+ @property
+ def success(self) -> bool:
+ return not self.error
+
def __str__(self) -> str:
return f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}'
@@ -42,5 +46,9 @@ def error(self) -> bool:
def message(self) -> str:
return 'Code executed in IPython cell.'
+ @property
+ def success(self) -> bool:
+ return True # IPython cells are always considered successful
+
def __str__(self) -> str:
return f'**IPythonRunCellObservation (source={self.source})**\n{self.content}'
diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py
index b2c9b0284610..cefc04398651 100644
--- a/openhands/events/serialization/event.py
+++ b/openhands/events/serialization/event.py
@@ -88,6 +88,9 @@ def event_to_dict(event: 'Event') -> dict:
elif 'observation' in d:
d['content'] = props.pop('content', '')
d['extras'] = props
+ # Include the success field for any observation that defines it (e.g. CmdOutputObservation, IPythonRunCellObservation)
+ if hasattr(event, 'success'):
+ d['success'] = event.success
elif 'log' in d:
pass
else:
diff --git a/openhands/events/serialization/observation.py b/openhands/events/serialization/observation.py
index 9030ccb1e1dd..d9d8dc51adaf 100644
--- a/openhands/events/serialization/observation.py
+++ b/openhands/events/serialization/observation.py
@@ -50,4 +50,5 @@ def observation_from_dict(observation: dict) -> Observation:
observation.pop('message', None)
content = observation.pop('content', '')
extras = observation.pop('extras', {})
+
return observation_class(content=content, **extras)
diff --git a/openhands/llm/debug_mixin.py b/openhands/llm/debug_mixin.py
index c56980e78da9..25761192ed13 100644
--- a/openhands/llm/debug_mixin.py
+++ b/openhands/llm/debug_mixin.py
@@ -17,7 +17,7 @@ def log_prompt(self, messages: list[Message] | Message):
debug_message = MESSAGE_SEPARATOR.join(
self._format_message_content(msg)
for msg in messages
- if msg.get('content', None)
+ if msg['content'] is not None
)
if debug_message:
diff --git a/openhands/llm/fn_call_converter.py b/openhands/llm/fn_call_converter.py
index 491ef906eaa1..ae4d87f8d62b 100644
--- a/openhands/llm/fn_call_converter.py
+++ b/openhands/llm/fn_call_converter.py
@@ -321,7 +321,7 @@ def convert_fncall_messages_to_non_fncall_messages(
first_user_message_encountered = False
for message in messages:
role = message['role']
- content = message.get('content', '')
+ content = message['content']
# 1. SYSTEM MESSAGES
# append system prompt suffix to content
diff --git a/openhands/runtime/builder/base.py b/openhands/runtime/builder/base.py
index df2ee99035c9..acfe3c60fb89 100644
--- a/openhands/runtime/builder/base.py
+++ b/openhands/runtime/builder/base.py
@@ -8,6 +8,7 @@ def build(
path: str,
tags: list[str],
platform: str | None = None,
+ extra_build_args: list[str] | None = None,
) -> str:
"""Build the runtime image.
@@ -15,6 +16,7 @@ def build(
path (str): The path to the runtime image's build directory.
tags (list[str]): The tags to apply to the runtime image (e.g., ["repo:my-repo", "sha:my-sha"]).
platform (str, optional): The target platform for the build. Defaults to None.
+ extra_build_args (list[str], optional): Additional build arguments to pass to the builder. Defaults to None.
Returns:
str: The name:tag of the runtime image after build (e.g., "repo:sha").
diff --git a/openhands/runtime/builder/docker.py b/openhands/runtime/builder/docker.py
index 9cdf0f998fcd..880b1c73c578 100644
--- a/openhands/runtime/builder/docker.py
+++ b/openhands/runtime/builder/docker.py
@@ -28,8 +28,8 @@ def build(
path: str,
tags: list[str],
platform: str | None = None,
- use_local_cache: bool = False,
extra_build_args: list[str] | None = None,
+ use_local_cache: bool = False,
) -> str:
"""Builds a Docker image using BuildKit and handles the build logs appropriately.
diff --git a/openhands/runtime/builder/remote.py b/openhands/runtime/builder/remote.py
index c9d3228a70af..5cfe1a4943a4 100644
--- a/openhands/runtime/builder/remote.py
+++ b/openhands/runtime/builder/remote.py
@@ -23,7 +23,13 @@ def __init__(self, api_url: str, api_key: str):
self.session = requests.Session()
self.session.headers.update({'X-API-Key': self.api_key})
- def build(self, path: str, tags: list[str], platform: str | None = None) -> str:
+ def build(
+ self,
+ path: str,
+ tags: list[str],
+ platform: str | None = None,
+ extra_build_args: list[str] | None = None,
+ ) -> str:
"""Builds a Docker image using the Runtime API's /build endpoint."""
# Create a tar archive of the build context
tar_buffer = io.BytesIO()
diff --git a/openhands/runtime/impl/eventstream/eventstream_runtime.py b/openhands/runtime/impl/eventstream/eventstream_runtime.py
index 1cd1068f4867..521d61473de1 100644
--- a/openhands/runtime/impl/eventstream/eventstream_runtime.py
+++ b/openhands/runtime/impl/eventstream/eventstream_runtime.py
@@ -249,6 +249,7 @@ async def connect(self):
platform=self.config.sandbox.platform,
extra_deps=self.config.sandbox.runtime_extra_deps,
force_rebuild=self.config.sandbox.force_rebuild_runtime,
+ extra_build_args=self.config.sandbox.runtime_extra_build_args,
)
self.log(
diff --git a/openhands/runtime/utils/runtime_build.py b/openhands/runtime/utils/runtime_build.py
index eab98befe538..de939efd9a38 100644
--- a/openhands/runtime/utils/runtime_build.py
+++ b/openhands/runtime/utils/runtime_build.py
@@ -111,6 +111,7 @@ def build_runtime_image(
build_folder: str | None = None,
dry_run: bool = False,
force_rebuild: bool = False,
+ extra_build_args: List[str] | None = None,
) -> str:
"""Prepares the final docker build folder.
If dry_run is False, it will also build the OpenHands runtime Docker image using the docker build folder.
@@ -123,6 +124,7 @@ def build_runtime_image(
- build_folder (str): The directory to use for the build. If not provided a temporary directory will be used
- dry_run (bool): if True, it will only ready the build folder. It will not actually build the Docker image
- force_rebuild (bool): if True, it will create the Dockerfile which uses the base_image
+ - extra_build_args (List[str]): Additional build arguments to pass to the builder
Returns:
- str: :. Where MD5 hash is the hash of the docker build folder
@@ -139,6 +141,7 @@ def build_runtime_image(
dry_run=dry_run,
force_rebuild=force_rebuild,
platform=platform,
+ extra_build_args=extra_build_args,
)
return result
@@ -150,6 +153,7 @@ def build_runtime_image(
dry_run=dry_run,
force_rebuild=force_rebuild,
platform=platform,
+ extra_build_args=extra_build_args,
)
return result
@@ -162,6 +166,7 @@ def build_runtime_image_in_folder(
dry_run: bool,
force_rebuild: bool,
platform: str | None = None,
+ extra_build_args: List[str] | None = None,
) -> str:
runtime_image_repo, _ = get_runtime_image_repo_and_tag(base_image)
lock_tag = f'oh_v{oh_version}_{get_hash_for_lock_files(base_image)}'
@@ -193,6 +198,7 @@ def build_runtime_image_in_folder(
lock_tag,
versioned_tag,
platform,
+ extra_build_args=extra_build_args,
)
return hash_image_name
@@ -234,6 +240,7 @@ def build_runtime_image_in_folder(
if build_from == BuildFromImageType.SCRATCH
else None,
platform=platform,
+ extra_build_args=extra_build_args,
)
return hash_image_name
@@ -339,6 +346,7 @@ def _build_sandbox_image(
lock_tag: str,
versioned_tag: str | None,
platform: str | None = None,
+ extra_build_args: List[str] | None = None,
):
"""Build and tag the sandbox image. The image will be tagged with all tags that do not yet exist"""
names = [
@@ -350,7 +358,10 @@ def _build_sandbox_image(
names = [name for name in names if not runtime_builder.image_exists(name, False)]
image_name = runtime_builder.build(
- path=str(build_folder), tags=names, platform=platform
+ path=str(build_folder),
+ tags=names,
+ platform=platform,
+ extra_build_args=extra_build_args,
)
if not image_name:
raise RuntimeError(f'Build failed for image {names}')
diff --git a/openhands/server/app.py b/openhands/server/app.py
index 3ea05dc29754..fd4500339a7a 100644
--- a/openhands/server/app.py
+++ b/openhands/server/app.py
@@ -1,4 +1,5 @@
import warnings
+from contextlib import asynccontextmanager
from fastapi.responses import RedirectResponse
@@ -23,10 +24,17 @@
from openhands.server.routes.files import app as files_api_router
from openhands.server.routes.public import app as public_api_router
from openhands.server.routes.security import app as security_api_router
-from openhands.server.shared import config
+from openhands.server.shared import config, session_manager
from openhands.utils.import_utils import get_impl
-app = FastAPI()
+
+@asynccontextmanager
+async def _lifespan(app: FastAPI):
+ async with session_manager:
+ yield
+
+
+app = FastAPI(lifespan=_lifespan)
app.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,
diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py
index 7e5d697b0258..c2a2d64940bf 100644
--- a/openhands/server/session/manager.py
+++ b/openhands/server/session/manager.py
@@ -52,6 +52,7 @@ async def _redis_subscribe(self):
"""
We use a redis backchannel to send actions between server nodes
"""
+ logger.debug('_redis_subscribe')
redis_client = self._get_redis_client()
pubsub = redis_client.pubsub()
await pubsub.subscribe('oh_event')
@@ -75,7 +76,7 @@ async def _redis_subscribe(self):
async def _process_message(self, message: dict):
data = json.loads(message['data'])
- logger.info(f'got_published_message:{message}')
+ logger.debug(f'got_published_message:{message}')
sid = data['sid']
message_type = data['message_type']
if message_type == 'event':
@@ -112,7 +113,7 @@ async def _process_message(self, message: dict):
elif message_type == 'session_closing':
# Session closing event - We only get this in the event of graceful shutdown,
# which can't be guaranteed - nodes can simply vanish unexpectedly!
- logger.info(f'session_closing:{sid}')
+ logger.debug(f'session_closing:{sid}')
for (
connection_id,
local_sid,
@@ -142,7 +143,9 @@ async def attach_to_conversation(self, sid: str) -> Conversation | None:
async def detach_from_conversation(self, conversation: Conversation):
await conversation.disconnect()
- async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData):
+ async def init_or_join_session(
+ self, sid: str, connection_id: str, session_init_data: SessionInitData
+ ):
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid
@@ -165,6 +168,7 @@ async def _is_session_running_in_cluster(self, sid: str) -> bool:
flag = asyncio.Event()
self._session_is_running_flags[sid] = flag
try:
+ logger.debug(f'publish:is_session_running:{sid}')
await self._get_redis_client().publish(
'oh_event',
json.dumps(
diff --git a/poetry.lock b/poetry.lock
index 11cb9c8862ad..a0459d9fa5ec 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -553,17 +553,17 @@ files = [
[[package]]
name = "boto3"
-version = "1.35.68"
+version = "1.35.78"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "boto3-1.35.68-py3-none-any.whl", hash = "sha256:9b26fa31901da7793c1dcd65eee9bab7e897d8aa1ffed0b5e1c3bce93d2aefe4"},
- {file = "boto3-1.35.68.tar.gz", hash = "sha256:091d6bed1422370987a839bff3f8755df7404fc15e9fac2a48e8505356f07433"},
+ {file = "boto3-1.35.78-py3-none-any.whl", hash = "sha256:5ef7166fe5060637b92af8dc152cd7acecf96b3fc9c5456706a886cadb534391"},
+ {file = "boto3-1.35.78.tar.gz", hash = "sha256:fc8001519c8842e766ad3793bde3fbd0bb39e821a582fc12cf67876b8f3cf7f1"},
]
[package.dependencies]
-botocore = ">=1.35.68,<1.36.0"
+botocore = ">=1.35.78,<1.36.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
@@ -572,13 +572,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
-version = "1.35.68"
+version = "1.35.78"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
- {file = "botocore-1.35.68-py3-none-any.whl", hash = "sha256:599139d5564291f5be873800711f9e4e14a823395ae9ce7b142be775e9849b94"},
- {file = "botocore-1.35.68.tar.gz", hash = "sha256:42c3700583a82f2b5316281a073d644a521d6358837e2b446dc458ba5d990fb4"},
+ {file = "botocore-1.35.78-py3-none-any.whl", hash = "sha256:41c37bd7c0326f25122f33ec84fb80fc0a14d7fcc9961431b0e57568e88c9cb5"},
+ {file = "botocore-1.35.78.tar.gz", hash = "sha256:6905036c25449ae8dba5e950e4b908e4b8a6fe6b516bf61e007ecb62fa21f323"},
]
[package.dependencies]
@@ -3739,22 +3739,23 @@ types-tqdm = "*"
[[package]]
name = "litellm"
-version = "1.52.15"
+version = "1.54.1"
description = "Library to easily interface with LLM API providers"
optional = false
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
files = [
- {file = "litellm-1.52.15-py3-none-any.whl", hash = "sha256:8a2d8e2526c5e7afb3006b0214d3c348778462fefafd582fd76bb7f5c35d28d0"},
- {file = "litellm-1.52.15.tar.gz", hash = "sha256:11a61b1b033ddff9d480da66c00acc9d3e4fbfeed166d1b0de8eda16c684116e"},
+ {file = "litellm-1.54.1-py3-none-any.whl", hash = "sha256:d8e60d4a5e8decb0234a1e8c20351c904aec561fb4025df7df3d0d7ea81ca442"},
+ {file = "litellm-1.54.1.tar.gz", hash = "sha256:b5a8fc99160fab0699b9258457432b3975499218ffcf1b515709808b2ce5a2d7"},
]
[package.dependencies]
aiohttp = "*"
click = "*"
+httpx = ">=0.23.0,<0.28.0"
importlib-metadata = ">=6.8.0"
jinja2 = ">=3.1.2,<4.0.0"
jsonschema = ">=4.22.0,<5.0.0"
-openai = ">=1.54.0"
+openai = ">=1.55.3"
pydantic = ">=2.0.0,<3.0.0"
python-dotenv = ">=0.2.0"
requests = ">=2.31.0,<3.0.0"
@@ -5413,13 +5414,13 @@ sympy = "*"
[[package]]
name = "openai"
-version = "1.55.0"
+version = "1.57.2"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.8"
files = [
- {file = "openai-1.55.0-py3-none-any.whl", hash = "sha256:446e08918f8dd70d8723274be860404c8c7cc46b91b93bbc0ef051f57eb503c1"},
- {file = "openai-1.55.0.tar.gz", hash = "sha256:6c0975ac8540fe639d12b4ff5a8e0bf1424c844c4a4251148f59f06c4b2bd5db"},
+ {file = "openai-1.57.2-py3-none-any.whl", hash = "sha256:f7326283c156fdee875746e7e54d36959fb198eadc683952ee05e3302fbd638d"},
+ {file = "openai-1.57.2.tar.gz", hash = "sha256:5f49fd0f38e9f2131cda7deb45dafdd1aee4f52a637e190ce0ecf40147ce8cee"},
]
[package.dependencies]
@@ -10060,4 +10061,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
-content-hash = "ff3daee70a197e3f6ff460bd1e14be7ed443a100805947ee18df7afb7d898584"
+content-hash = "039581f859df4446dc9491bf39913a54f53c5d71e9bad86ff71ddd1d1682f9af"
diff --git a/pyproject.toml b/pyproject.toml
index acae091bb6e9..4da046981212 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ packages = [
python = "^3.12"
datasets = "*"
pandas = "*"
-litellm = "^1.52.3"
+litellm = "^1.54.1"
google-generativeai = "*" # To use litellm with Gemini Pro API
google-api-python-client = "*" # For Google Sheets API
google-auth-httplib2 = "*" # For Google Sheets authentication
@@ -98,6 +98,7 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]
+
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -128,6 +129,7 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"
+
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"
diff --git a/tests/unit/test_command_success.py b/tests/unit/test_command_success.py
new file mode 100644
index 000000000000..b52ceb4815c7
--- /dev/null
+++ b/tests/unit/test_command_success.py
@@ -0,0 +1,27 @@
+from openhands.events.observation.commands import (
+ CmdOutputObservation,
+ IPythonRunCellObservation,
+)
+
+
+def test_cmd_output_success():
+ # Test successful command
+ obs = CmdOutputObservation(
+ command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0
+ )
+ assert obs.success is True
+ assert obs.error is False
+
+ # Test failed command
+ obs = CmdOutputObservation(
+ command_id=2, command='ls', content='No such file or directory', exit_code=1
+ )
+ assert obs.success is False
+ assert obs.error is True
+
+
+def test_ipython_cell_success():
+ # IPython cells are always successful
+ obs = IPythonRunCellObservation(code='print("Hello")', content='Hello')
+ assert obs.success is True
+ assert obs.error is False
diff --git a/tests/unit/test_event_serialization.py b/tests/unit/test_event_serialization.py
new file mode 100644
index 000000000000..d1989a30bb09
--- /dev/null
+++ b/tests/unit/test_event_serialization.py
@@ -0,0 +1,18 @@
+from openhands.events.observation import CmdOutputObservation
+from openhands.events.serialization import event_to_dict
+
+
+def test_command_output_success_serialization():
+ # Test successful command
+ obs = CmdOutputObservation(
+ command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0
+ )
+ serialized = event_to_dict(obs)
+ assert serialized['success'] is True
+
+ # Test failed command
+ obs = CmdOutputObservation(
+ command_id=2, command='ls', content='No such file or directory', exit_code=1
+ )
+ serialized = event_to_dict(obs)
+ assert serialized['success'] is False
diff --git a/tests/unit/test_observation_serialization.py b/tests/unit/test_observation_serialization.py
index ae636ddf562b..67a95449b719 100644
--- a/tests/unit/test_observation_serialization.py
+++ b/tests/unit/test_observation_serialization.py
@@ -40,36 +40,23 @@ def serialization_deserialization(
# Additional tests for various observation subclasses can be included here
-def test_observation_event_props_serialization_deserialization():
- original_observation_dict = {
- 'id': 42,
- 'source': 'agent',
- 'timestamp': '2021-08-01T12:00:00',
- 'observation': 'run',
- 'message': 'Command `ls -l` executed with exit code 0.',
- 'extras': {
- 'exit_code': 0,
- 'command': 'ls -l',
- 'command_id': 3,
- 'hidden': False,
- 'interpreter_details': '',
- },
- 'content': 'foo.txt',
- }
- serialization_deserialization(original_observation_dict, CmdOutputObservation)
-
+def test_success_field_serialization():
+ # Test success=True
+ obs = CmdOutputObservation(
+ content='Command succeeded',
+ exit_code=0,
+ command='ls -l',
+ command_id=3,
+ )
+ serialized = event_to_dict(obs)
+ assert serialized['success'] is True
-def test_command_output_observation_serialization_deserialization():
- original_observation_dict = {
- 'observation': 'run',
- 'extras': {
- 'exit_code': 0,
- 'command': 'ls -l',
- 'command_id': 3,
- 'hidden': False,
- 'interpreter_details': '',
- },
- 'message': 'Command `ls -l` executed with exit code 0.',
- 'content': 'foo.txt',
- }
- serialization_deserialization(original_observation_dict, CmdOutputObservation)
+ # Test success=False
+ obs = CmdOutputObservation(
+ content='No such file or directory',
+ exit_code=1,
+ command='ls -l',
+ command_id=3,
+ )
+ serialized = event_to_dict(obs)
+ assert serialized['success'] is False
diff --git a/tests/unit/test_runtime_build.py b/tests/unit/test_runtime_build.py
index 79a7c9a22b6b..b5cbd91056e8 100644
--- a/tests/unit/test_runtime_build.py
+++ b/tests/unit/test_runtime_build.py
@@ -239,6 +239,7 @@ def test_build_runtime_image_from_scratch():
f'{get_runtime_image_repo()}:{OH_VERSION}_mock-versioned-tag',
],
platform=None,
+ extra_build_args=None,
)
assert (
image_name
@@ -333,6 +334,7 @@ def image_exists_side_effect(image_name, *args):
# VERSION tag will NOT be included except from scratch
],
platform=None,
+ extra_build_args=None,
)
mock_prep_build_folder.assert_called_once_with(
ANY,
@@ -391,6 +393,7 @@ def image_exists_side_effect(image_name, *args):
# VERSION tag will NOT be included except from scratch
],
platform=None,
+ extra_build_args=None,
)
mock_prep_build_folder.assert_called_once_with(
ANY,
diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py
index 8eb7432ccdc1..c886f9d80b89 100644
--- a/tests/unit/test_security.py
+++ b/tests/unit/test_security.py
@@ -51,24 +51,48 @@ def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]])
def test_msg(temp_dir: str):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- policy = """
- raise "Disallow ABC [risk=medium]" if:
- (msg: Message)
- "ABC" in msg.content
- """
- InvariantAnalyzer(event_stream, policy)
- data = [
- (MessageAction('Hello world!'), EventSource.USER),
- (MessageAction('AB!'), EventSource.AGENT),
- (MessageAction('Hello world!'), EventSource.USER),
- (MessageAction('ABC!'), EventSource.AGENT),
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
+
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [], # First check
+ [], # Second check
+ [], # Third check
+ [
+ 'PolicyViolation(Disallow ABC [risk=medium], ranges=[<2 ranges>])'
+ ], # Fourth check
]
- add_events(event_stream, data)
- for i in range(3):
- assert data[i][0].security_risk == ActionSecurityRisk.LOW
- assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ policy = """
+ raise "Disallow ABC [risk=medium]" if:
+ (msg: Message)
+ "ABC" in msg.content
+ """
+ InvariantAnalyzer(event_stream, policy)
+ data = [
+ (MessageAction('Hello world!'), EventSource.USER),
+ (MessageAction('AB!'), EventSource.AGENT),
+ (MessageAction('Hello world!'), EventSource.USER),
+ (MessageAction('ABC!'), EventSource.AGENT),
+ ]
+ add_events(event_stream, data)
+ for i in range(3):
+ assert data[i][0].security_risk == ActionSecurityRisk.LOW
+ assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
@pytest.mark.parametrize(
@@ -76,22 +100,44 @@ def test_msg(temp_dir: str):
[('rm -rf root_dir', ActionSecurityRisk.MEDIUM), ['ls', ActionSecurityRisk.LOW]],
)
def test_cmd(cmd, expected_risk, temp_dir: str):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- policy = """
- raise "Disallow rm -rf [risk=medium]" if:
- (call: ToolCall)
- call is tool:run
- match("rm -rf", call.function.arguments.command)
- """
- InvariantAnalyzer(event_stream, policy)
- data = [
- (MessageAction('Hello world!'), EventSource.USER),
- (CmdRunAction(cmd), EventSource.USER),
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
+
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [], # First check
+ ['PolicyViolation(Disallow rm -rf [risk=medium], ranges=[<2 ranges>])']
+ if expected_risk == ActionSecurityRisk.MEDIUM
+ else [], # Second check
]
- add_events(event_stream, data)
- assert data[0][0].security_risk == ActionSecurityRisk.LOW
- assert data[1][0].security_risk == expected_risk
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ policy = """
+ raise "Disallow rm -rf [risk=medium]" if:
+ (call: ToolCall)
+ call is tool:run
+ match("rm -rf", call.function.arguments.command)
+ """
+ InvariantAnalyzer(event_stream, policy)
+ data = [
+ (MessageAction('Hello world!'), EventSource.USER),
+ (CmdRunAction(cmd), EventSource.USER),
+ ]
+ add_events(event_stream, data)
+ assert data[0][0].security_risk == ActionSecurityRisk.LOW
+ assert data[1][0].security_risk == expected_risk
@pytest.mark.parametrize(
@@ -102,26 +148,49 @@ def test_cmd(cmd, expected_risk, temp_dir: str):
],
)
def test_leak_secrets(code, expected_risk, temp_dir: str):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- policy = """
- from invariant.detectors import secrets
-
- raise "Disallow writing secrets [risk=medium]" if:
- (call: ToolCall)
- call is tool:run_ipython
- any(secrets(call.function.arguments.code))
- """
- InvariantAnalyzer(event_stream, policy)
- data = [
- (MessageAction('Hello world!'), EventSource.USER),
- (IPythonRunCellAction(code), EventSource.AGENT),
- (IPythonRunCellAction('hello'), EventSource.AGENT),
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
+
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [], # First check
+ ['PolicyViolation(Disallow writing secrets [risk=medium], ranges=[<2 ranges>])']
+ if expected_risk == ActionSecurityRisk.MEDIUM
+ else [], # Second check
+ [], # Third check
]
- add_events(event_stream, data)
- assert data[0][0].security_risk == ActionSecurityRisk.LOW
- assert data[1][0].security_risk == expected_risk
- assert data[2][0].security_risk == ActionSecurityRisk.LOW
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ policy = """
+ from invariant.detectors import secrets
+
+ raise "Disallow writing secrets [risk=medium]" if:
+ (call: ToolCall)
+ call is tool:run_ipython
+ any(secrets(call.function.arguments.code))
+ """
+ InvariantAnalyzer(event_stream, policy)
+ data = [
+ (MessageAction('Hello world!'), EventSource.USER),
+ (IPythonRunCellAction(code), EventSource.AGENT),
+ (IPythonRunCellAction('hello'), EventSource.AGENT),
+ ]
+ add_events(event_stream, data)
+ assert data[0][0].security_risk == ActionSecurityRisk.LOW
+ assert data[1][0].security_risk == expected_risk
+ assert data[2][0].security_risk == ActionSecurityRisk.LOW
def test_unsafe_python_code(temp_dir: str):
@@ -460,26 +529,48 @@ def default_config():
def test_check_usertask(
mock_litellm_completion, usertask, is_appropriate, default_config, temp_dir: str
):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- analyzer = InvariantAnalyzer(event_stream)
- mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
- mock_litellm_completion.return_value = mock_response
- analyzer.guardrail_llm = LLM(config=default_config)
- analyzer.check_browsing_alignment = True
- data = [
- (MessageAction(usertask), EventSource.USER),
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
+
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [],
+ [
+ 'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
+ ],
]
- add_events(event_stream, data)
- event_list = list(event_stream.get_events())
- if is_appropriate == 'No':
- assert len(event_list) == 2
- assert type(event_list[0]) == MessageAction
- assert type(event_list[1]) == ChangeAgentStateAction
- elif is_appropriate == 'Yes':
- assert len(event_list) == 1
- assert type(event_list[0]) == MessageAction
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ analyzer = InvariantAnalyzer(event_stream)
+ mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
+ mock_litellm_completion.return_value = mock_response
+ analyzer.guardrail_llm = LLM(config=default_config)
+ analyzer.check_browsing_alignment = True
+ data = [
+ (MessageAction(usertask), EventSource.USER),
+ ]
+ add_events(event_stream, data)
+ event_list = list(event_stream.get_events())
+
+ if is_appropriate == 'No':
+ assert len(event_list) == 2
+ assert type(event_list[0]) == MessageAction
+ assert type(event_list[1]) == ChangeAgentStateAction
+ elif is_appropriate == 'Yes':
+ assert len(event_list) == 1
+ assert type(event_list[0]) == MessageAction
@pytest.mark.parametrize(
@@ -493,23 +584,45 @@ def test_check_usertask(
def test_check_fillaction(
mock_litellm_completion, fillaction, is_harmful, default_config, temp_dir: str
):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- analyzer = InvariantAnalyzer(event_stream)
- mock_response = {'choices': [{'message': {'content': is_harmful}}]}
- mock_litellm_completion.return_value = mock_response
- analyzer.guardrail_llm = LLM(config=default_config)
- analyzer.check_browsing_alignment = True
- data = [
- (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
+
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [],
+ [
+ 'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
+ ],
]
- add_events(event_stream, data)
- event_list = list(event_stream.get_events())
-
- if is_harmful == 'Yes':
- assert len(event_list) == 2
- assert type(event_list[0]) == BrowseInteractiveAction
- assert type(event_list[1]) == ChangeAgentStateAction
- elif is_harmful == 'No':
- assert len(event_list) == 1
- assert type(event_list[0]) == BrowseInteractiveAction
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ analyzer = InvariantAnalyzer(event_stream)
+ mock_response = {'choices': [{'message': {'content': is_harmful}}]}
+ mock_litellm_completion.return_value = mock_response
+ analyzer.guardrail_llm = LLM(config=default_config)
+ analyzer.check_browsing_alignment = True
+ data = [
+ (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
+ ]
+ add_events(event_stream, data)
+ event_list = list(event_stream.get_events())
+
+ if is_harmful == 'Yes':
+ assert len(event_list) == 2
+ assert type(event_list[0]) == BrowseInteractiveAction
+ assert type(event_list[1]) == ChangeAgentStateAction
+ elif is_harmful == 'No':
+ assert len(event_list) == 1
+ assert type(event_list[0]) == BrowseInteractiveAction