diff --git a/.github/workflows/openhands-resolver.yml b/.github/workflows/openhands-resolver.yml index 028316ee05d5..a9d90c38b139 100644 --- a/.github/workflows/openhands-resolver.yml +++ b/.github/workflows/openhands-resolver.yml @@ -184,6 +184,7 @@ jobs: }); - name: Install OpenHands + id: install_openhands uses: actions/github-script@v7 env: COMMENT_BODY: ${{ github.event.comment.body || '' }} @@ -196,7 +197,6 @@ jobs: const reviewBody = process.env.REVIEW_BODY.trim(); const labelName = process.env.LABEL_NAME.trim(); const eventName = process.env.EVENT_NAME.trim(); - // Check conditions const isExperimentalLabel = labelName === "fix-me-experimental"; const isIssueCommentExperimental = @@ -205,6 +205,9 @@ jobs: const isReviewCommentExperimental = eventName === "pull_request_review" && reviewBody.includes("@openhands-agent-exp"); + // Set output variable + core.setOutput('isExperimental', isExperimentalLabel || isIssueCommentExperimental || isReviewCommentExperimental); + // Perform package installation if (isExperimentalLabel || isIssueCommentExperimental || isReviewCommentExperimental) { console.log("Installing experimental OpenHands..."); @@ -230,7 +233,8 @@ jobs: --issue-number ${{ env.ISSUE_NUMBER }} \ --issue-type ${{ env.ISSUE_TYPE }} \ --max-iterations ${{ env.MAX_ITERATIONS }} \ - --comment-id ${{ env.COMMENT_ID }} + --comment-id ${{ env.COMMENT_ID }} \ + --is-experimental ${{ steps.install_openhands.outputs.isExperimental }} - name: Check resolution result id: check_result diff --git a/config.template.toml b/config.template.toml index 8f26eaf92b88..ccb7b1159747 100644 --- a/config.template.toml +++ b/config.template.toml @@ -23,6 +23,9 @@ workspace_base = "./workspace" # Cache directory path #cache_dir = "/tmp/cache" +# Reasoning effort for o1 models (low, medium, high, or not set) +#reasoning_effort = "medium" + # Debugging enabled #debug = false @@ -220,8 +223,8 @@ codeact_enable_jupyter = true # LLM config group to use #llm_config = 'your-llm-config-group' -# Whether to use microagents at all -#use_microagents = true +# Whether to use prompt extension (e.g., microagent, repo/runtime info) at all +#enable_prompt_extensions = true # List of microagents to disable #disabled_microagents = [] diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/usage/configuration-options.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/usage/configuration-options.md index 848b85a53164..b79a65073acc 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/usage/configuration-options.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/usage/configuration-options.md @@ -373,7 +373,7 @@ Agent 配置选项在 `config.toml` 文件的 `[agent]` 和 `[agent. - 描述: 是否在 action space 中启用 Jupyter **Microagent 使用** -- `use_microagents` +- `enable_prompt_extensions` - 类型: `bool` - 默认值: `true` - 描述: 是否使用 microagents diff --git a/docs/modules/usage/configuration-options.md b/docs/modules/usage/configuration-options.md index 422dc1cc4913..a3c11de52ed8 100644 --- a/docs/modules/usage/configuration-options.md +++ b/docs/modules/usage/configuration-options.md @@ -336,7 +336,7 @@ The agent configuration options are defined in the `[agent]` and `[agent. test_output.txt 2>&1" ) - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -292,7 +292,7 @@ def complete_runtime( ) # Read test output action = CmdRunAction(command='cat test_output.txt') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) # logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -305,7 +305,7 @@ def complete_runtime( # Save pytest exit code action = CmdRunAction(command='echo $?') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) # logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -318,7 +318,7 @@ def complete_runtime( # Read the test report action = CmdRunAction(command='cat report.json') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) # logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -330,7 +330,7 @@ def complete_runtime( repo_name = instance['repo'].split('/')[1] repo_name = repo_name.replace('.', '-') action = CmdRunAction(command=f'commit0 get-tests {repo_name}') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) # logger.info(obs, extra={'msg_type': 'OBSERVATION'}) diff --git a/evaluation/benchmarks/discoverybench/run_infer.py b/evaluation/benchmarks/discoverybench/run_infer.py index 05ff44003517..30af2d19d473 100644 --- a/evaluation/benchmarks/discoverybench/run_infer.py +++ b/evaluation/benchmarks/discoverybench/run_infer.py @@ -78,7 +78,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False agent_config = AgentConfig( function_calling=False, codeact_enable_jupyter=True, diff --git a/evaluation/benchmarks/gaia/run_infer.py b/evaluation/benchmarks/gaia/run_infer.py index 7974a092903c..b4c704e497f7 100644 --- a/evaluation/benchmarks/gaia/run_infer.py +++ b/evaluation/benchmarks/gaia/run_infer.py @@ -63,7 +63,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/evaluation/benchmarks/gorilla/run_infer.py b/evaluation/benchmarks/gorilla/run_infer.py index 740a3c3ada8f..e97be5ed836c 100644 --- a/evaluation/benchmarks/gorilla/run_infer.py +++ b/evaluation/benchmarks/gorilla/run_infer.py @@ -56,7 +56,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/evaluation/benchmarks/gpqa/run_infer.py b/evaluation/benchmarks/gpqa/run_infer.py index eb1c808ec8a4..cf4106b97136 100644 --- a/evaluation/benchmarks/gpqa/run_infer.py +++ b/evaluation/benchmarks/gpqa/run_infer.py @@ -77,7 +77,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/evaluation/benchmarks/humanevalfix/run_infer.py b/evaluation/benchmarks/humanevalfix/run_infer.py index ba802ddf9dfa..fec040079cc6 100644 --- a/evaluation/benchmarks/humanevalfix/run_infer.py +++ b/evaluation/benchmarks/humanevalfix/run_infer.py @@ -98,7 +98,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/evaluation/benchmarks/logic_reasoning/run_infer.py b/evaluation/benchmarks/logic_reasoning/run_infer.py index ee48f5ea76c8..acd07edef26e 100644 --- a/evaluation/benchmarks/logic_reasoning/run_infer.py +++ b/evaluation/benchmarks/logic_reasoning/run_infer.py @@ -62,7 +62,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/evaluation/benchmarks/mint/run_infer.py b/evaluation/benchmarks/mint/run_infer.py index 61223572ae83..ddfef0ea685b 100644 --- a/evaluation/benchmarks/mint/run_infer.py +++ b/evaluation/benchmarks/mint/run_infer.py @@ -120,7 +120,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/evaluation/benchmarks/ml_bench/run_infer.py b/evaluation/benchmarks/ml_bench/run_infer.py index 4e396b3c3fe1..c2fcc1ae3e26 100644 --- a/evaluation/benchmarks/ml_bench/run_infer.py +++ b/evaluation/benchmarks/ml_bench/run_infer.py @@ -93,7 +93,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/evaluation/benchmarks/swe_bench/eval_infer.py b/evaluation/benchmarks/swe_bench/eval_infer.py index 7beacf344408..ab7feb38b678 100644 --- a/evaluation/benchmarks/swe_bench/eval_infer.py +++ b/evaluation/benchmarks/swe_bench/eval_infer.py @@ -174,7 +174,7 @@ def process_instance( # Set +x action = CmdRunAction(command='chmod +x /tmp/eval.sh') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -189,7 +189,7 @@ def process_instance( "echo 'APPLY_PATCH_FAIL')))" ) action = CmdRunAction(command=exec_command) - action.timeout = 600 + action.set_hard_timeout(600) obs = runtime.run_action(action) assert isinstance(obs, CmdOutputObservation) apply_patch_output = obs.content @@ -212,7 +212,7 @@ def process_instance( # Run eval script in background and save output to log file log_file = '/tmp/eval_output.log' action = CmdRunAction(command=f'/tmp/eval.sh > {log_file} 2>&1 & echo $!') - action.timeout = 60 # Short timeout just to get the process ID + action.set_hard_timeout(60) # Short timeout just to get the process ID obs = runtime.run_action(action) if isinstance(obs, CmdOutputObservation) and obs.exit_code == 0: @@ -235,7 +235,7 @@ def process_instance( check_action = CmdRunAction( command=f'ps -p {pid} > /dev/null; echo $?' ) - check_action.timeout = 60 + check_action.set_hard_timeout(60) check_obs = runtime.run_action(check_action) if ( isinstance(check_obs, CmdOutputObservation) @@ -252,7 +252,7 @@ def process_instance( # Read the log file cat_action = CmdRunAction(command=f'cat {log_file}') - cat_action.timeout = 300 + cat_action.set_hard_timeout(300) cat_obs = runtime.run_action(cat_action) # Grade answer diff --git a/evaluation/benchmarks/swe_bench/run_infer.py b/evaluation/benchmarks/swe_bench/run_infer.py index bf065ada9734..ac9e85b60b10 100644 --- a/evaluation/benchmarks/swe_bench/run_infer.py +++ b/evaluation/benchmarks/swe_bench/run_infer.py @@ -173,7 +173,7 @@ def initialize_runtime( action = CmdRunAction( command=f"""echo 'export SWE_INSTANCE_ID={instance['instance_id']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc""" ) - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -182,7 +182,7 @@ def initialize_runtime( ) action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """) - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -194,7 +194,7 @@ def initialize_runtime( # inject the instance info action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -223,14 +223,14 @@ def initialize_runtime( '/swe_util/', ) action = CmdRunAction(command='cat ~/.bashrc') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {str(obs)}') action = CmdRunAction(command='source ~/.bashrc') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -239,7 +239,7 @@ def initialize_runtime( assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}') action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh') - action.timeout = 3600 + action.set_hard_timeout(3600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -249,7 +249,7 @@ def initialize_runtime( ) else: action = CmdRunAction(command='source /swe_util/swe_entry.sh') - action.timeout = 1800 + action.set_hard_timeout(1800) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -259,7 +259,7 @@ def initialize_runtime( ) action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -269,7 +269,7 @@ def initialize_runtime( ) action = CmdRunAction(command='git reset --hard') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -278,14 +278,14 @@ def initialize_runtime( action = CmdRunAction( command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done' ) - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {str(obs)}') action = CmdRunAction(command='which python') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -316,7 +316,7 @@ def complete_runtime( workspace_dir_name = _get_swebench_workspace_dir_name(instance) action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -326,7 +326,7 @@ def complete_runtime( ) action = CmdRunAction(command='git config --global core.pager ""') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -336,7 +336,7 @@ def complete_runtime( ) action = CmdRunAction(command='git add -A') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -351,7 +351,7 @@ def complete_runtime( action = CmdRunAction( command=f'git diff --no-color --cached {instance["base_commit"]}' ) - action.timeout = 600 + 100 * n_retries + action.set_hard_timeout(600 + 100 * n_retries) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) diff --git a/evaluation/benchmarks/the_agent_company/browsing.py b/evaluation/benchmarks/the_agent_company/browsing.py index 7384dddbdfce..5ce97129777a 100644 --- a/evaluation/benchmarks/the_agent_company/browsing.py +++ b/evaluation/benchmarks/the_agent_company/browsing.py @@ -262,7 +262,7 @@ def pre_login( instruction = action.to_instruction() browser_action = BrowseInteractiveAction(browser_actions=instruction) - browser_action.timeout = 10000 + browser_action.set_hard_timeout(10000) logger.info(browser_action, extra={'msg_type': 'ACTION'}) obs: BrowserOutputObservation = runtime.run_action(browser_action) logger.debug(obs, extra={'msg_type': 'OBSERVATION'}) diff --git a/evaluation/benchmarks/the_agent_company/run_infer.py b/evaluation/benchmarks/the_agent_company/run_infer.py index a82db6d56081..8f8a1b599e6f 100644 --- a/evaluation/benchmarks/the_agent_company/run_infer.py +++ b/evaluation/benchmarks/the_agent_company/run_infer.py @@ -86,7 +86,7 @@ def init_task_env(runtime: Runtime, hostname: str, env_llm_config: LLMConfig): 'bash /utils/init.sh' ) action = CmdRunAction(command=command) - action.timeout = 900 + action.set_hard_timeout(900) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -172,7 +172,7 @@ def run_evaluator( f'python_default /utils/eval.py --trajectory_path {trajectory_path} --result_path {result_path}' ) action = CmdRunAction(command=command) - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) diff --git a/evaluation/benchmarks/toolqa/run_infer.py b/evaluation/benchmarks/toolqa/run_infer.py index 8586f9a7bb7c..8306292d8f2f 100644 --- a/evaluation/benchmarks/toolqa/run_infer.py +++ b/evaluation/benchmarks/toolqa/run_infer.py @@ -57,7 +57,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/evaluation/benchmarks/webarena/run_infer.py b/evaluation/benchmarks/webarena/run_infer.py index c35c79ba2cce..79b7fc4371aa 100644 --- a/evaluation/benchmarks/webarena/run_infer.py +++ b/evaluation/benchmarks/webarena/run_infer.py @@ -78,7 +78,7 @@ def get_config( ) config.set_llm_config(metadata.llm_config) agent_config = config.get_agent_config(metadata.agent_class) - agent_config.use_microagents = False + agent_config.enable_prompt_extensions = False return config diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index d8b5702a235d..37c52855148a 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -111,7 +111,7 @@ def __init__( os.path.dirname(os.path.dirname(openhands.__file__)), 'microagents', ) - if self.config.use_microagents + if self.config.enable_prompt_extensions else None, prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'), disabled_microagents=self.config.disabled_microagents, @@ -448,6 +448,17 @@ def _get_messages(self, state: State) -> list[Message]: ) ) + # Repository and runtime info + additional_info = self.prompt_manager.get_additional_info() + if self.config.enable_prompt_extensions and additional_info: + # only add these if prompt extension is enabled + messages.append( + Message( + role='user', + content=[TextContent(text=additional_info)], + ) + ) + pending_tool_call_action_messages: dict[str, Message] = {} tool_call_id_to_message: dict[str, Message] = {} diff --git a/openhands/agenthub/codeact_agent/prompts/system_prompt.j2 b/openhands/agenthub/codeact_agent/prompts/system_prompt.j2 index b6dfcd9bda75..325392f2e662 100644 --- a/openhands/agenthub/codeact_agent/prompts/system_prompt.j2 +++ b/openhands/agenthub/codeact_agent/prompts/system_prompt.j2 @@ -3,26 +3,4 @@ You are OpenHands agent, a helpful AI assistant that can interact with a compute * If user provides a path, you should NOT assume it's relative to the current working directory. Instead, you should explore the file system to find the file before working on it. * When configuring git credentials, use "openhands" as the user.name and "openhands@all-hands.dev" as the user.email by default, unless explicitly instructed otherwise. * The assistant MUST NOT include comments in the code unless they are necessary to describe non-obvious behavior. -{{ runtime_info }} -{% if repository_info %} - -At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}. - -{% endif %} -{% if repository_instructions -%} - -{{ repository_instructions }} - -{% endif %} -{% if runtime_info and runtime_info.available_hosts -%} - -The user has access to the following hosts for accessing a web application, -each of which has a corresponding port: -{% for host, port in runtime_info.available_hosts.items() -%} -* {{ host }} (port {{ port }}) -{% endfor %} -When starting a web server, use the corresponding ports. You should also -set any options to allow iframes and CORS requests. - -{% endif %} diff --git a/openhands/core/config/agent_config.py b/openhands/core/config/agent_config.py index 77e9dbc1e32d..375fd9b12e8a 100644 --- a/openhands/core/config/agent_config.py +++ b/openhands/core/config/agent_config.py @@ -17,7 +17,7 @@ class AgentConfig: memory_enabled: Whether long-term memory (embeddings) is enabled. memory_max_threads: The maximum number of threads indexing at the same time for embeddings. llm_config: The name of the llm config to use. If specified, this will override global llm config. - use_microagents: Whether to use microagents at all. Default is True. + enable_prompt_extensions: Whether to use prompt extensions (e.g., microagents, inject runtime info). Default is True. disabled_microagents: A list of microagents to disable. Default is None. condenser: Configuration for the memory condenser. Default is NoOpCondenserConfig. """ @@ -29,7 +29,7 @@ class AgentConfig: memory_enabled: bool = False memory_max_threads: int = 3 llm_config: str | None = None - use_microagents: bool = True + enable_prompt_extensions: bool = True disabled_microagents: list[str] | None = None condenser: CondenserConfig = field(default_factory=NoOpCondenserConfig) # type: ignore diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py index 16c08a7693f0..bae58373811d 100644 --- a/openhands/core/config/llm_config.py +++ b/openhands/core/config/llm_config.py @@ -40,6 +40,7 @@ class LLMConfig: drop_params: Drop any unmapped (unsupported) params without causing an exception. modify_params: Modify params allows litellm to do transformations like adding a default message, when a message is empty. disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction). + reasoning_effort: The effort to put into reasoning. This is a string that can be one of 'low', 'medium', 'high', or 'none'. Exclusive for o1 models. caching_prompt: Use the prompt caching feature if provided by the LLM and supported by the provider. log_completions: Whether to log LLM completions to the state. log_completions_folder: The folder to log LLM completions to. Required if log_completions is True. @@ -79,6 +80,7 @@ class LLMConfig: # Note: this setting is actually global, unlike drop_params modify_params: bool = True disable_vision: bool | None = None + reasoning_effort: str | None = None caching_prompt: bool = True log_completions: bool = False log_completions_folder: str = os.path.join(LOG_DIR, 'completions') diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py index 0ea40f29faab..3a0b705dd02d 100644 --- a/openhands/core/config/sandbox_config.py +++ b/openhands/core/config/sandbox_config.py @@ -41,7 +41,7 @@ class SandboxConfig: remote_runtime_api_url: str = 'http://localhost:8000' local_runtime_url: str = 'http://localhost' - keep_runtime_alive: bool = False + keep_runtime_alive: bool = True rm_all_containers: bool = False api_key: str | None = None base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime @@ -60,7 +60,7 @@ class SandboxConfig: runtime_startup_env_vars: dict[str, str] = field(default_factory=dict) browsergym_eval_env: str | None = None platform: str | None = None - close_delay: int = 15 + close_delay: int = 900 remote_runtime_resource_factor: int = 1 enable_gpu: bool = False docker_runtime_kwargs: str | None = None diff --git a/openhands/events/event.py b/openhands/events/event.py index 6c7a2d8a3ac1..1bdece59eb75 100644 --- a/openhands/events/event.py +++ b/openhands/events/event.py @@ -64,8 +64,12 @@ def timeout(self) -> int | None: return self._timeout # type: ignore[attr-defined] return None - @timeout.setter - def timeout(self, value: int | None) -> None: + def set_hard_timeout(self, value: int | None, blocking: bool = True) -> None: + """Set the timeout for the event. + + NOTE, this is a hard timeout, meaning that the event will be blocked + until the timeout is reached. + """ self._timeout = value if value is not None and value > 600: from openhands.core.logger import openhands_logger as logger @@ -78,7 +82,7 @@ def timeout(self, value: int | None) -> None: # Check if .blocking is an attribute of the event if hasattr(self, 'blocking'): # .blocking needs to be set to True if .timeout is set - self.blocking = True + self.blocking = blocking # optional metadata, LLM call cost of the edit @property diff --git a/openhands/events/serialization/action.py b/openhands/events/serialization/action.py index 90945c1d4dfd..be9990750fc6 100644 --- a/openhands/events/serialization/action.py +++ b/openhands/events/serialization/action.py @@ -74,7 +74,8 @@ def action_from_dict(action: dict) -> Action: try: decoded_action = action_class(**args) if 'timeout' in action: - decoded_action.timeout = action['timeout'] + blocking = args.get('blocking', False) + decoded_action.set_hard_timeout(action['timeout'], blocking=blocking) # Set timestamp if it was provided if timestamp: diff --git a/openhands/llm/async_llm.py b/openhands/llm/async_llm.py index ed84273c737b..f553ae173fd6 100644 --- a/openhands/llm/async_llm.py +++ b/openhands/llm/async_llm.py @@ -6,7 +6,11 @@ from openhands.core.exceptions import UserCancelledError from openhands.core.logger import openhands_logger as logger -from openhands.llm.llm import LLM, LLM_RETRY_EXCEPTIONS +from openhands.llm.llm import ( + LLM, + LLM_RETRY_EXCEPTIONS, + REASONING_EFFORT_SUPPORTED_MODELS, +) from openhands.utils.shutdown_listener import should_continue @@ -55,6 +59,10 @@ async def async_completion_wrapper(*args, **kwargs): elif 'messages' in kwargs: messages = kwargs['messages'] + # Set reasoning effort for models that support it + if self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS: + kwargs['reasoning_effort'] = self.config.reasoning_effort + # ensure we work with a list of messages messages = messages if isinstance(messages, list) else [messages] diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 743d6535ba3b..88cda96c5f00 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -71,6 +71,15 @@ 'claude-3-5-haiku-20241022', 'gpt-4o-mini', 'gpt-4o', + 'o1-2024-12-17', +] + +REASONING_EFFORT_SUPPORTED_MODELS = [ + 'o1-2024-12-17', +] + +MODELS_WITHOUT_STOP_WORDS = [ + 'o1-mini', ] @@ -186,7 +195,8 @@ def wrapper(*args, **kwargs): messages, kwargs['tools'] ) kwargs['messages'] = messages - kwargs['stop'] = STOP_WORDS + if self.config.model not in MODELS_WITHOUT_STOP_WORDS: + kwargs['stop'] = STOP_WORDS mock_fncall_tools = kwargs.pop('tools') # if we have no messages, something went very wrong @@ -205,6 +215,10 @@ def wrapper(*args, **kwargs): 'anthropic-beta': 'prompt-caching-2024-07-31', } + # Set reasoning effort for models that support it + if self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS: + kwargs['reasoning_effort'] = self.config.reasoning_effort + # set litellm modify_params to the configured value # True by default to allow litellm to do transformations like adding a default message, when a message is empty # NOTE: this setting is global; unlike drop_params, it cannot be overridden in the litellm completion partial @@ -213,7 +227,6 @@ def wrapper(*args, **kwargs): try: # Record start time for latency measurement start_time = time.time() - # we don't support streaming here, thus we get a ModelResponse resp: ModelResponse = self._completion_unwrapped(*args, **kwargs) @@ -597,17 +610,16 @@ def _completion_cost(self, response) -> float: logger.debug(f'Using custom cost per token: {cost_per_token}') extra_kwargs['custom_cost_per_token'] = cost_per_token - try: - # try directly get response_cost from response - _hidden_params = getattr(response, '_hidden_params', {}) - cost = _hidden_params.get('response_cost', None) - if cost is None: - cost = float( - _hidden_params.get('additional_headers', {}).get( - 'llm_provider-x-litellm-response-cost', 0.0 - ) - ) + # try directly get response_cost from response + _hidden_params = getattr(response, '_hidden_params', {}) + cost = _hidden_params.get('additional_headers', {}).get( + 'llm_provider-x-litellm-response-cost', None + ) + if cost is not None: + cost = float(cost) + logger.debug(f'Got response_cost from response: {cost}') + try: if cost is None: try: cost = litellm_completion_cost( diff --git a/openhands/llm/streaming_llm.py b/openhands/llm/streaming_llm.py index 77d999fadcd3..10925b9564cf 100644 --- a/openhands/llm/streaming_llm.py +++ b/openhands/llm/streaming_llm.py @@ -5,6 +5,7 @@ from openhands.core.exceptions import UserCancelledError from openhands.core.logger import openhands_logger as logger from openhands.llm.async_llm import LLM_RETRY_EXCEPTIONS, AsyncLLM +from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS class StreamingLLM(AsyncLLM): @@ -61,6 +62,10 @@ async def async_streaming_completion_wrapper(*args, **kwargs): 'The messages list is empty. At least one message is required.' ) + # Set reasoning effort for models that support it + if self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS: + kwargs['reasoning_effort'] = self.config.reasoning_effort + self.log_prompt(messages) try: diff --git a/openhands/resolver/resolve_issue.py b/openhands/resolver/resolve_issue.py index f50b37d79447..4e0b2b4ad96c 100644 --- a/openhands/resolver/resolve_issue.py +++ b/openhands/resolver/resolve_issue.py @@ -118,7 +118,7 @@ async def complete_runtime( git_patch = None while n_retries < 5: action = CmdRunAction(command=f'git diff --no-color --cached {base_commit}') - action.timeout = 600 + 100 * n_retries + action.set_hard_timeout(600 + 100 * n_retries) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) diff --git a/openhands/runtime/action_execution_server.py b/openhands/runtime/action_execution_server.py index b483c183cbdc..8a5fcdc0edd9 100644 --- a/openhands/runtime/action_execution_server.py +++ b/openhands/runtime/action_execution_server.py @@ -120,6 +120,9 @@ async def ainit(self): self.bash_session = BashSession( work_dir=self._initial_cwd, username=self.username, + no_change_timeout_seconds=int( + os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 30) + ), ) self.bash_session.initialize() await wait_all( @@ -163,7 +166,7 @@ async def _init_bash_commands(self): logger.debug(f'Initializing by running {len(INIT_COMMANDS)} bash commands...') for command in INIT_COMMANDS: action = CmdRunAction(command=command) - action.timeout = 300 + action.set_hard_timeout(300) logger.debug(f'Executing init command: {command}') obs = await self.run(action) assert isinstance(obs, CmdOutputObservation) diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index 114289f390b2..94a059f06aa3 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -182,7 +182,8 @@ def on_event(self, event: Event) -> None: async def _handle_action(self, event: Action) -> None: if event.timeout is None: - event.timeout = self.config.sandbox.timeout + # We don't block the command if this is a default timeout action + event.set_hard_timeout(self.config.sandbox.timeout, blocking=False) assert event.timeout is not None try: observation: Observation = await call_sync_from_async( diff --git a/openhands/runtime/builder/remote.py b/openhands/runtime/builder/remote.py index a728460a374e..b2e869eca3bf 100644 --- a/openhands/runtime/builder/remote.py +++ b/openhands/runtime/builder/remote.py @@ -9,7 +9,6 @@ from openhands.core.logger import openhands_logger as logger from openhands.runtime.builder import RuntimeBuilder from openhands.runtime.utils.request import send_request -from openhands.utils.http_session import HttpSession from openhands.utils.shutdown_listener import ( should_continue, sleep_if_should_continue, @@ -19,10 +18,12 @@ class RemoteRuntimeBuilder(RuntimeBuilder): """This class interacts with the remote Runtime API for building and managing container images.""" - def __init__(self, api_url: str, api_key: str, session: HttpSession | None = None): + def __init__( + self, api_url: str, api_key: str, session: requests.Session | None = None + ): self.api_url = api_url self.api_key = api_key - self.session = session or HttpSession() + self.session = session or requests.Session() self.session.headers.update({'X-API-Key': self.api_key}) def build( diff --git a/openhands/runtime/impl/action_execution/action_execution_client.py b/openhands/runtime/impl/action_execution/action_execution_client.py index 4965fc1752af..f8c93dd5561f 100644 --- a/openhands/runtime/impl/action_execution/action_execution_client.py +++ b/openhands/runtime/impl/action_execution/action_execution_client.py @@ -35,7 +35,6 @@ from openhands.runtime.base import Runtime from openhands.runtime.plugins import PluginRequirement from openhands.runtime.utils.request import send_request -from openhands.utils.http_session import HttpSession class ActionExecutionClient(Runtime): @@ -56,7 +55,7 @@ def __init__( attach_to_existing: bool = False, headless_mode: bool = True, ): - self.session = HttpSession() + self.session = requests.Session() self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time self._runtime_initialized: bool = False self._vscode_token: str | None = None # initial dummy value @@ -217,7 +216,8 @@ def send_action_for_execution(self, action: Action) -> Observation: # set timeout to default if not set if action.timeout is None: - action.timeout = self.config.sandbox.timeout + # We don't block the command if this is a default timeout action + action.set_hard_timeout(self.config.sandbox.timeout, blocking=False) with self.action_semaphore: if not action.runnable: diff --git a/openhands/runtime/impl/docker/docker_runtime.py b/openhands/runtime/impl/docker/docker_runtime.py index 5111f0f36831..bf06e00e854f 100644 --- a/openhands/runtime/impl/docker/docker_runtime.py +++ b/openhands/runtime/impl/docker/docker_runtime.py @@ -228,6 +228,8 @@ def _init_container(self): } if self.config.debug or DEBUG: environment['DEBUG'] = 'true' + # also update with runtime_startup_env_vars + environment.update(self.config.sandbox.runtime_startup_env_vars) self.log('debug', f'Workspace Base: {self.config.workspace_base}') if ( diff --git a/openhands/runtime/utils/bash.py b/openhands/runtime/utils/bash.py index 351d990dcda6..87b2ae405f1d 100644 --- a/openhands/runtime/utils/bash.py +++ b/openhands/runtime/utils/bash.py @@ -174,7 +174,7 @@ def __init__( self, work_dir: str, username: str | None = None, - no_change_timeout_seconds: float = 30.0, + no_change_timeout_seconds: int = 30, ): self.NO_CHANGE_TIMEOUT_SECONDS = no_change_timeout_seconds self.work_dir = work_dir @@ -369,7 +369,7 @@ def _handle_nochange_timeout_command( command, raw_command_output, metadata, - continue_prefix='[Command output continued from previous command]\n', + continue_prefix='[Below is the output of the previous command.]\n', ) return CmdOutputObservation( content=command_output, @@ -404,7 +404,7 @@ def _handle_hard_timeout_command( command, raw_command_output, metadata, - continue_prefix='[Command output continued from previous command]\n', + continue_prefix='[Below is the output of the previous command.]\n', ) return CmdOutputObservation( @@ -441,6 +441,8 @@ def _combine_outputs_between_matches( else: # The command output is the content after the last PS1 prompt return pane_content[ps1_matches[0].end() + 1 :] + elif len(ps1_matches) == 0: + return pane_content combined_output = '' for i in range(len(ps1_matches) - 1): # Extract content between current and next PS1 prompt @@ -459,6 +461,9 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati # Strip the command of any leading/trailing whitespace logger.debug(f'RECEIVED ACTION: {action}') command = action.command.strip() + is_special_key = self._is_special_key(command) + + # Handle when prev command is hard timeout if command == '' and self.prev_status not in { BashCommandStatus.CONTINUE, @@ -486,13 +491,45 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati last_change_time = start_time last_pane_output = self._get_pane_content() - if command != '': + # Do not check hard timeout if the command is a special key + if command != '' and is_special_key: + logger.debug(f'SENDING SPECIAL KEY: {command!r}') + self.pane.send_keys(command, enter=False) + # When prev command is hard timeout, and we are trying to execute new command + elif self.prev_status == BashCommandStatus.HARD_TIMEOUT and command != '': + if not last_pane_output.endswith(CMD_OUTPUT_PS1_END): + _ps1_matches = CmdOutputMetadata.matches_ps1_metadata(last_pane_output) + raw_command_output = self._combine_outputs_between_matches( + last_pane_output, _ps1_matches + ) + metadata = CmdOutputMetadata() # No metadata available + metadata.suffix = ( + f'\n[Your command "{command}" is NOT executed. ' + f'The previous command was timed out but still running. Above is the output of the previous command. ' + "You may wait longer to see additional output of the previous command by sending empty command '', " + 'send other commands to interact with the current process, ' + 'or send keys ("C-c", "C-z", "C-d") to interrupt/kill the previous command before sending your new command.]' + ) + command_output = self._get_command_output( + command, + raw_command_output, + metadata, + continue_prefix='[Below is the output of the previous command.]\n', + ) + return CmdOutputObservation( + command=command, + content=command_output, + metadata=metadata, + ) + # Only send the command to the pane if it's not a special key and it's not empty + # AND previous hard timeout command is resolved + elif command != '' and not is_special_key: # convert command to raw string command = escape_bash_special_chars(command) logger.debug(f'SENDING COMMAND: {command!r}') self.pane.send_keys( command, - enter=not self._is_special_key(command), + enter=True, ) # Loop until the command completes or times out @@ -525,7 +562,7 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati # We ignore this if the command is *blocking time_since_last_change = time.time() - last_change_time logger.debug( - f'CHECKING NO CHANGE TIMEOUT ({self.NO_CHANGE_TIMEOUT_SECONDS}s): elapsed {time_since_last_change}' + f'CHECKING NO CHANGE TIMEOUT ({self.NO_CHANGE_TIMEOUT_SECONDS}s): elapsed {time_since_last_change}. Action blocking: {action.blocking}' ) if ( not action.blocking diff --git a/openhands/runtime/utils/request.py b/openhands/runtime/utils/request.py index 0117e019a6a8..e05a083e7b0d 100644 --- a/openhands/runtime/utils/request.py +++ b/openhands/runtime/utils/request.py @@ -4,7 +4,6 @@ import requests from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential -from openhands.utils.http_session import HttpSession from openhands.utils.tenacity_stop import stop_if_should_exit @@ -35,7 +34,7 @@ def is_retryable_error(exception): wait=wait_exponential(multiplier=1, min=4, max=60), ) def send_request( - session: HttpSession, + session: requests.Session, method: str, url: str, timeout: int = 10, @@ -49,11 +48,11 @@ def send_request( _json = response.json() except (requests.exceptions.JSONDecodeError, json.decoder.JSONDecodeError): _json = None + finally: + response.close() raise RequestHTTPError( e, response=e.response, detail=_json.get('detail') if _json is not None else None, ) from e - finally: - response.close() return response diff --git a/openhands/server/listen_socket.py b/openhands/server/listen_socket.py index 81816a1325e1..baa79a4c6e35 100644 --- a/openhands/server/listen_socket.py +++ b/openhands/server/listen_socket.py @@ -44,7 +44,8 @@ async def connect(connection_id: str, environ, auth): conversation_store = await ConversationStoreImpl.get_instance(config, user_id) metadata = await conversation_store.get_metadata(conversation_id) - if metadata.github_user_id != user_id: + + if metadata.github_user_id != str(user_id): logger.error( f'User {user_id} is not allowed to join conversation {conversation_id}' ) diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index f622c5bad9cf..0015cf0e90aa 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -12,6 +12,7 @@ from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl from openhands.server.session.conversation_init_data import ConversationInitData from openhands.server.shared import config, session_manager +from openhands.server.types import LLMAuthenticationError, MissingSettingsError from openhands.storage.data_models.conversation_info import ConversationInfo from openhands.storage.data_models.conversation_info_result_set import ( ConversationInfoResultSet, @@ -33,16 +34,12 @@ class InitSessionRequest(BaseModel): selected_repository: str | None = None -@app.post('/conversations') -async def new_conversation(request: Request, data: InitSessionRequest): - """Initialize a new session or join an existing one. - After successful initialization, the client should connect to the WebSocket - using the returned conversation ID - """ - logger.info('Initializing new conversation') - +async def _create_new_conversation( + user_id: str | None, + token: str | None, + selected_repository: str | None, +): logger.info('Loading settings') - user_id = get_user_id(request) settings_store = await SettingsStoreImpl.get_instance(config, user_id) settings = await settings_store.load() logger.info('Settings loaded') @@ -54,25 +51,16 @@ async def new_conversation(request: Request, data: InitSessionRequest): # but that would run a tiny inference. if not settings.llm_api_key or settings.llm_api_key.isspace(): logger.warn(f'Missing api key for model {settings.llm_model}') - return JSONResponse( - content={ - 'status': 'error', - 'message': 'Error authenticating with the LLM provider. Please check your API key', - 'msg_id': 'STATUS$ERROR_LLM_AUTHENTICATION', - } + raise LLMAuthenticationError( + 'Error authenticating with the LLM provider. Please check your API key' ) + else: logger.warn('Settings not present, not starting conversation') - return JSONResponse( - content={ - 'status': 'error', - 'message': 'Settings not found', - 'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND', - } - ) - github_token = getattr(request.state, 'github_token', '') - session_init_args['github_token'] = github_token or data.github_token or '' - session_init_args['selected_repository'] = data.selected_repository + raise MissingSettingsError('Settings not found') + + session_init_args['github_token'] = token or '' + session_init_args['selected_repository'] = selected_repository conversation_init_data = ConversationInitData(**session_init_args) logger.info('Loading conversation store') conversation_store = await ConversationStoreImpl.get_instance(config, user_id) @@ -85,7 +73,7 @@ async def new_conversation(request: Request, data: InitSessionRequest): logger.info(f'New conversation ID: {conversation_id}') repository_title = ( - data.selected_repository.split('/')[-1] if data.selected_repository else None + selected_repository.split('/')[-1] if selected_repository else None ) conversation_title = f'{repository_title or "Conversation"} {conversation_id[:5]}' @@ -95,7 +83,7 @@ async def new_conversation(request: Request, data: InitSessionRequest): conversation_id=conversation_id, title=conversation_title, github_user_id=user_id, - selected_repository=data.selected_repository, + selected_repository=selected_repository, ) ) @@ -112,7 +100,47 @@ async def new_conversation(request: Request, data: InitSessionRequest): except ValueError: pass # Already subscribed - take no action logger.info(f'Finished initializing conversation {conversation_id}') - return JSONResponse(content={'status': 'ok', 'conversation_id': conversation_id}) + + return conversation_id + + +@app.post('/conversations') +async def new_conversation(request: Request, data: InitSessionRequest): + """Initialize a new session or join an existing one. + After successful initialization, the client should connect to the WebSocket + using the returned conversation ID + """ + logger.info('Initializing new conversation') + user_id = get_user_id(request) + github_token = getattr(request.state, 'github_token', '') or data.github_token + selected_repository = data.selected_repository + + try: + conversation_id = await _create_new_conversation( + user_id, github_token, selected_repository + ) + + return JSONResponse( + content={'status': 'ok', 'conversation_id': conversation_id} + ) + except MissingSettingsError as e: + return JSONResponse( + content={ + 'status': 'error', + 'message': str(e), + 'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND', + }, + status_code=400, + ) + + except LLMAuthenticationError as e: + return JSONResponse( + content={ + 'status': 'error', + 'message': str(e), + 'msg_id': 'STATUS$ERROR_LLM_AUTHENTICATION', + }, + ) @app.get('/conversations') @@ -130,7 +158,7 @@ async def search_conversations( for conversation in conversation_metadata_result_set.results if hasattr(conversation, 'created_at') ) - running_conversations = await session_manager.get_running_agent_loops( + running_conversations = await session_manager.get_agent_loop_running( get_user_id(request), set(conversation_ids) ) result = ConversationInfoResultSet( diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 285acccbfbe4..70bf6eeca6bb 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -1,5 +1,4 @@ import asyncio -import time from typing import Callable, Optional from openhands.controller import AgentController @@ -17,10 +16,10 @@ from openhands.runtime.base import Runtime from openhands.security import SecurityAnalyzer, options from openhands.storage.files import FileStore -from openhands.utils.async_utils import call_sync_from_async +from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async from openhands.utils.shutdown_listener import should_continue -WAIT_TIME_BEFORE_CLOSE = 90 +WAIT_TIME_BEFORE_CLOSE = 300 WAIT_TIME_BEFORE_CLOSE_INTERVAL = 5 @@ -37,8 +36,7 @@ class AgentSession: controller: AgentController | None = None runtime: Runtime | None = None security_analyzer: SecurityAnalyzer | None = None - _starting: bool = False - _started_at: float = 0 + _initializing: bool = False _closed: bool = False loop: asyncio.AbstractEventLoop | None = None @@ -90,8 +88,7 @@ async def start( if self._closed: logger.warning('Session closed before starting') return - self._starting = True - self._started_at = time.time() + self._initializing = True self._create_security_analyzer(config.security.security_analyzer) await self._create_runtime( runtime_name=runtime_name, @@ -112,19 +109,24 @@ async def start( self.event_stream.add_event( ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT ) - self._starting = False + self._initializing = False - async def close(self): + def close(self): """Closes the Agent session""" if self._closed: return self._closed = True - while self._starting and should_continue(): + call_async_from_sync(self._close) + + async def _close(self): + seconds_waited = 0 + while self._initializing and should_continue(): logger.debug( f'Waiting for initialization to finish before closing session {self.sid}' ) await asyncio.sleep(WAIT_TIME_BEFORE_CLOSE_INTERVAL) - if time.time() <= self._started_at + WAIT_TIME_BEFORE_CLOSE: + seconds_waited += WAIT_TIME_BEFORE_CLOSE_INTERVAL + if seconds_waited > WAIT_TIME_BEFORE_CLOSE: logger.error( f'Waited too long for initialization to finish before closing session {self.sid}' ) @@ -309,12 +311,3 @@ def _maybe_restore_state(self) -> State | None: else: logger.debug('No events found, no state to restore') return restored_state - - def get_state(self) -> AgentState | None: - controller = self.controller - if controller: - return controller.state.agent_state - if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE: - # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong - return AgentState.ERROR - return None diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index 3c4d929a72de..67358f61fbe8 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -2,7 +2,6 @@ import json import time from dataclasses import dataclass, field -from typing import Generic, Iterable, TypeVar from uuid import uuid4 import socketio @@ -10,28 +9,26 @@ from openhands.core.config import AppConfig from openhands.core.exceptions import AgentRuntimeUnavailableError from openhands.core.logger import openhands_logger as logger -from openhands.core.schema.agent import AgentState from openhands.events.stream import EventStream, session_exists from openhands.server.session.conversation import Conversation from openhands.server.session.session import ROOM_KEY, Session from openhands.server.settings import Settings from openhands.storage.files import FileStore -from openhands.utils.async_utils import wait_all +from openhands.utils.async_utils import call_sync_from_async from openhands.utils.shutdown_listener import should_continue _REDIS_POLL_TIMEOUT = 1.5 _CHECK_ALIVE_INTERVAL = 15 _CLEANUP_INTERVAL = 15 -MAX_RUNNING_CONVERSATIONS = 3 -T = TypeVar('T') +_CLEANUP_EXCEPTION_WAIT_TIME = 15 @dataclass -class _ClusterQuery(Generic[T]): - query_id: str - request_ids: set[str] | None - result: T +class _SessionIsRunningCheck: + request_id: str + request_sids: list[str] + running_sids: set[str] = field(default_factory=set) flag: asyncio.Event = field(default_factory=asyncio.Event) @@ -41,10 +38,10 @@ class SessionManager: config: AppConfig file_store: FileStore _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict) - _local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) + local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) _last_alive_timestamps: dict[str, float] = field(default_factory=dict) _redis_listen_task: asyncio.Task | None = None - _running_sid_queries: dict[str, _ClusterQuery[set[str]]] = field( + _session_is_running_checks: dict[str, _SessionIsRunningCheck] = field( default_factory=dict ) _active_conversations: dict[str, tuple[Conversation, int]] = field( @@ -55,7 +52,7 @@ class SessionManager: ) _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock) _cleanup_task: asyncio.Task | None = None - _connection_queries: dict[str, _ClusterQuery[dict[str, str]]] = field( + _has_remote_connections_flags: dict[str, asyncio.Event] = field( default_factory=dict ) @@ -63,7 +60,7 @@ async def __aenter__(self): redis_client = self._get_redis_client() if redis_client: self._redis_listen_task = asyncio.create_task(self._redis_subscribe()) - self._cleanup_task = asyncio.create_task(self._cleanup_stale()) + self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations()) return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -85,7 +82,7 @@ async def _redis_subscribe(self): logger.debug('_redis_subscribe') redis_client = self._get_redis_client() pubsub = redis_client.pubsub() - await pubsub.subscribe('session_msg') + await pubsub.subscribe('oh_event') while should_continue(): try: message = await pubsub.get_message( @@ -111,71 +108,59 @@ async def _process_message(self, message: dict): session = self._local_agent_loops_by_sid.get(sid) if session: await session.dispatch(data['data']) - elif message_type == 'running_agent_loops_query': + elif message_type == 'is_session_running': # Another node in the cluster is asking if the current node is running the session given. - query_id = data['query_id'] - sids = self._get_running_agent_loops_locally( - data.get('user_id'), data.get('filter_to_sids') - ) + request_id = data['request_id'] + sids = [ + sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid + ] if sids: await self._get_redis_client().publish( - 'session_msg', + 'oh_event', json.dumps( { - 'query_id': query_id, - 'sids': list(sids), - 'message_type': 'running_agent_loops_response', + 'request_id': request_id, + 'sids': sids, + 'message_type': 'session_is_running', } ), ) - elif message_type == 'running_agent_loops_response': - query_id = data['query_id'] + elif message_type == 'session_is_running': + request_id = data['request_id'] for sid in data['sids']: self._last_alive_timestamps[sid] = time.time() - running_query = self._running_sid_queries.get(query_id) - if running_query: - running_query.result.update(data['sids']) - if running_query.request_ids is not None and len( - running_query.request_ids - ) == len(running_query.result): - running_query.flag.set() - elif message_type == 'connections_query': + check = self._session_is_running_checks.get(request_id) + if check: + check.running_sids.update(data['sids']) + if len(check.request_sids) == len(check.running_sids): + check.flag.set() + elif message_type == 'has_remote_connections_query': # Another node in the cluster is asking if the current node is connected to a session - query_id = data['query_id'] - connections = self._get_connections_locally( - data.get('user_id'), data.get('filter_to_sids') - ) - if connections: + sid = data['sid'] + required = sid in self.local_connection_id_to_session_id.values() + if required: await self._get_redis_client().publish( - 'session_msg', + 'oh_event', json.dumps( - { - 'query_id': query_id, - 'connections': connections, - 'message_type': 'connections_response', - } + {'sid': sid, 'message_type': 'has_remote_connections_response'} ), ) - elif message_type == 'connections_response': - query_id = data['query_id'] - connection_query = self._connection_queries.get(query_id) - if connection_query: - connection_query.result.update(**data['connections']) - if connection_query.request_ids is not None and len( - connection_query.request_ids - ) == len(connection_query.result): - connection_query.flag.set() + elif message_type == 'has_remote_connections_response': + sid = data['sid'] + flag = self._has_remote_connections_flags.get(sid) + if flag: + flag.set() elif message_type == 'close_session': sid = data['sid'] if sid in self._local_agent_loops_by_sid: - await self._close_session(sid) + await self._on_close_session(sid) elif message_type == 'session_closing': # Session closing event - We only get this in the event of graceful shutdown, # which can't be guaranteed - nodes can simply vanish unexpectedly! sid = data['sid'] logger.debug(f'session_closing:{sid}') # Create a list of items to process to avoid modifying dict during iteration - items = list(self._local_connection_id_to_session_id.items()) + items = list(self.local_connection_id_to_session_id.items()) for connection_id, local_sid in items: if sid == local_sid: logger.warning( @@ -223,7 +208,7 @@ async def join_conversation( ): logger.info(f'join_conversation:{sid}:{connection_id}') await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid)) - self._local_connection_id_to_session_id[connection_id] = sid + self.local_connection_id_to_session_id[connection_id] = sid event_stream = await self._get_event_stream(sid) if not event_stream: return await self.maybe_start_agent_loop(sid, settings, user_id) @@ -241,7 +226,7 @@ async def detach_from_conversation(self, conversation: Conversation): self._active_conversations.pop(sid) self._detached_conversations[sid] = (conversation, time.time()) - async def _cleanup_stale(self): + async def _cleanup_detached_conversations(self): while should_continue(): if self._get_redis_client(): # Debug info for HA envs @@ -255,7 +240,7 @@ async def _cleanup_stale(self): f'Running agent loops: {len(self._local_agent_loops_by_sid)}' ) logger.info( - f'Local connections: {len(self._local_connection_id_to_session_id)}' + f'Local connections: {len(self.local_connection_id_to_session_id)}' ) try: async with self._conversations_lock: @@ -265,176 +250,97 @@ async def _cleanup_stale(self): await conversation.disconnect() self._detached_conversations.pop(sid, None) - close_threshold = time.time() - self.config.sandbox.close_delay - running_loops = list(self._local_agent_loops_by_sid.items()) - running_loops.sort(key=lambda item: item[1].last_active_ts) - sid_to_close: list[str] = [] - for sid, session in running_loops: - state = session.agent_session.get_state() - if session.last_active_ts < close_threshold and state not in [ - AgentState.RUNNING, - None, - ]: - sid_to_close.append(sid) - - connections = self._get_connections_locally( - filter_to_sids=set(sid_to_close) - ) - connected_sids = {sid for _, sid in connections.items()} - sid_to_close = [ - sid for sid in sid_to_close if sid not in connected_sids - ] - - if sid_to_close: - connections = await self._get_connections_remotely( - filter_to_sids=set(sid_to_close) - ) - connected_sids = {sid for _, sid in connections.items()} - sid_to_close = [ - sid for sid in sid_to_close if sid not in connected_sids - ] - - await wait_all(self._close_session(sid) for sid in sid_to_close) await asyncio.sleep(_CLEANUP_INTERVAL) except asyncio.CancelledError: async with self._conversations_lock: for conversation, _ in self._detached_conversations.values(): await conversation.disconnect() self._detached_conversations.clear() - await wait_all( - self._close_session(sid) for sid in self._local_agent_loops_by_sid - ) return except Exception as e: - logger.warning(f'error_cleaning_stale: {str(e)}') - await asyncio.sleep(_CLEANUP_INTERVAL) + logger.warning(f'error_cleaning_detached_conversations: {str(e)}') + await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME) + + async def get_agent_loop_running(self, user_id, sids: set[str]) -> set[str]: + running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid) + check_cluster_sids = [sid for sid in sids if sid not in running_sids] + running_cluster_sids = await self.get_agent_loop_running_in_cluster( + check_cluster_sids + ) + running_sids.union(running_cluster_sids) + return running_sids async def is_agent_loop_running(self, sid: str) -> bool: - sids = await self.get_running_agent_loops(filter_to_sids={sid}) - return bool(sids) - - async def get_running_agent_loops( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> set[str]: - """Get the running session ids. If a user is supplied, then the results are limited to session ids for that user. If a set of filter_to_sids is supplied, then results are limited to these ids of interest.""" - sids = self._get_running_agent_loops_locally(user_id, filter_to_sids) - remote_sids = await self._get_running_agent_loops_remotely( - user_id, filter_to_sids - ) - return sids.union(remote_sids) - - def _get_running_agent_loops_locally( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> set[str]: - items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items() - if filter_to_sids is not None: - items = (item for item in items if item[0] in filter_to_sids) - if user_id: - items = (item for item in items if item[1].user_id == user_id) - sids = {sid for sid, _ in items} - return sids - - async def _get_running_agent_loops_remotely( - self, - user_id: str | None = None, - filter_to_sids: set[str] | None = None, - ) -> set[str]: + if await self.is_agent_loop_running_locally(sid): + return True + if await self.is_agent_loop_running_in_cluster(sid): + return True + return False + + async def is_agent_loop_running_locally(self, sid: str) -> bool: + return sid in self._local_agent_loops_by_sid + + async def is_agent_loop_running_in_cluster(self, sid: str) -> bool: + running_sids = await self.get_agent_loop_running_in_cluster([sid]) + return bool(running_sids) + + async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]: """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply""" redis_client = self._get_redis_client() if not redis_client: return set() flag = asyncio.Event() - query_id = str(uuid4()) - query = _ClusterQuery[set[str]]( - query_id=query_id, request_ids=filter_to_sids, result=set() - ) - self._running_sid_queries[query_id] = query + request_id = str(uuid4()) + check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids) + self._session_is_running_checks[request_id] = check try: - logger.debug( - f'publish:_get_running_agent_loops_remotely_query:{user_id}:{filter_to_sids}' + logger.debug(f'publish:is_session_running:{sids}') + await redis_client.publish( + 'oh_event', + json.dumps( + { + 'request_id': request_id, + 'sids': sids, + 'message_type': 'is_session_running', + } + ), ) - data: dict = { - 'query_id': query_id, - 'message_type': 'running_agent_loops_query', - } - if user_id: - data['user_id'] = user_id - if filter_to_sids: - data['filter_to_sids'] = list(filter_to_sids) - await redis_client.publish('session_msg', json.dumps(data)) async with asyncio.timeout(_REDIS_POLL_TIMEOUT): await flag.wait() - return query.result + return check.running_sids except TimeoutError: # Nobody replied in time - return query.result + return check.running_sids finally: - self._running_sid_queries.pop(query_id, None) - - async def get_connections( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> dict[str, str]: - connection_ids = self._get_connections_locally(user_id, filter_to_sids) - remote_connection_ids = await self._get_connections_remotely( - user_id, filter_to_sids - ) - connection_ids.update(**remote_connection_ids) - return connection_ids - - def _get_connections_locally( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> dict[str, str]: - connections = dict(**self._local_connection_id_to_session_id) - if filter_to_sids is not None: - connections = { - connection_id: sid - for connection_id, sid in connections.items() - if sid in filter_to_sids - } - if user_id: - for connection_id, sid in list(connections.items()): - session = self._local_agent_loops_by_sid.get(sid) - if not session or session.user_id != user_id: - connections.pop(connection_id) - return connections - - async def _get_connections_remotely( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> dict[str, str]: - redis_client = self._get_redis_client() - if not redis_client: - return {} + self._session_is_running_checks.pop(request_id, None) + async def _has_remote_connections(self, sid: str) -> bool: + """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply""" + # Create a flag for the callback flag = asyncio.Event() - query_id = str(uuid4()) - query = _ClusterQuery[dict[str, str]]( - query_id=query_id, request_ids=filter_to_sids, result={} - ) - self._connection_queries[query_id] = query + self._has_remote_connections_flags[sid] = flag try: - logger.debug( - f'publish:get_connections_remotely_query:{user_id}:{filter_to_sids}' + await self._get_redis_client().publish( + 'oh_event', + json.dumps( + { + 'sid': sid, + 'message_type': 'has_remote_connections_query', + } + ), ) - data: dict = { - 'query_id': query_id, - 'message_type': 'connections_query', - } - if user_id: - data['user_id'] = user_id - if filter_to_sids: - data['filter_to_sids'] = list(filter_to_sids) - await redis_client.publish('session_msg', json.dumps(data)) async with asyncio.timeout(_REDIS_POLL_TIMEOUT): await flag.wait() - return query.result + result = flag.is_set() + return result except TimeoutError: # Nobody replied in time - return query.result + return False finally: - self._connection_queries.pop(query_id, None) + self._has_remote_connections_flags.pop(sid, None) async def maybe_start_agent_loop( self, sid: str, settings: Settings, user_id: str | None @@ -443,18 +349,8 @@ async def maybe_start_agent_loop( session: Session | None = None if not await self.is_agent_loop_running(sid): logger.info(f'start_agent_loop:{sid}') - - response_ids = await self.get_running_agent_loops(user_id) - if len(response_ids) >= MAX_RUNNING_CONVERSATIONS: - logger.info('too_many_sessions_for:{user_id}') - await self.close_session(next(iter(response_ids))) - session = Session( - sid=sid, - file_store=self.file_store, - config=self.config, - sio=self.sio, - user_id=user_id, + sid=sid, file_store=self.file_store, config=self.config, sio=self.sio ) self._local_agent_loops_by_sid[sid] = session asyncio.create_task(session.initialize_agent(settings)) @@ -463,6 +359,7 @@ async def maybe_start_agent_loop( if not event_stream: logger.error(f'No event stream after starting agent loop: {sid}') raise RuntimeError(f'no_event_stream:{sid}') + asyncio.create_task(self._cleanup_session_later(sid)) return event_stream async def _get_event_stream(self, sid: str) -> EventStream | None: @@ -472,7 +369,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None: logger.info(f'found_local_agent_loop:{sid}') return session.agent_session.event_stream - if await self._get_running_agent_loops_remotely(filter_to_sids={sid}): + if await self.is_agent_loop_running_in_cluster(sid): logger.info(f'found_remote_agent_loop:{sid}') return EventStream(sid, self.file_store) @@ -480,7 +377,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None: async def send_to_event_stream(self, connection_id: str, data: dict): # If there is a local session running, send to that - sid = self._local_connection_id_to_session_id.get(connection_id) + sid = self.local_connection_id_to_session_id.get(connection_id) if not sid: raise RuntimeError(f'no_connected_session:{connection_id}') @@ -496,11 +393,11 @@ async def send_to_event_stream(self, connection_id: str, data: dict): next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL if ( next_alive_check > time.time() - or await self._get_running_agent_loops_remotely(filter_to_sids={sid}) + or await self.is_agent_loop_running_in_cluster(sid) ): # Send the event to the other pod await redis_client.publish( - 'session_msg', + 'oh_event', json.dumps( { 'sid': sid, @@ -514,37 +411,75 @@ async def send_to_event_stream(self, connection_id: str, data: dict): raise RuntimeError(f'no_connected_session:{connection_id}:{sid}') async def disconnect_from_session(self, connection_id: str): - sid = self._local_connection_id_to_session_id.pop(connection_id, None) + sid = self.local_connection_id_to_session_id.pop(connection_id, None) logger.info(f'disconnect_from_session:{connection_id}:{sid}') if not sid: # This can occur if the init action was never run. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}') return + if should_continue(): + asyncio.create_task(self._cleanup_session_later(sid)) + else: + await self._on_close_session(sid) + + async def _cleanup_session_later(self, sid: str): + # Once there have been no connections to a session for a reasonable period, we close it + try: + await asyncio.sleep(self.config.sandbox.close_delay) + finally: + # If the sleep was cancelled, we still want to close these + await self._cleanup_session(sid) + + async def _cleanup_session(self, sid: str) -> bool: + # Get local connections + logger.info(f'_cleanup_session:{sid}') + has_local_connections = next( + (True for v in self.local_connection_id_to_session_id.values() if v == sid), + False, + ) + if has_local_connections: + return False + + # If no local connections, get connections through redis + redis_client = self._get_redis_client() + if redis_client and await self._has_remote_connections(sid): + return False + + # We alert the cluster in case they are interested + if redis_client: + await redis_client.publish( + 'oh_event', + json.dumps({'sid': sid, 'message_type': 'session_closing'}), + ) + + await self._on_close_session(sid) + return True + async def close_session(self, sid: str): session = self._local_agent_loops_by_sid.get(sid) if session: - await self._close_session(sid) + await self._on_close_session(sid) redis_client = self._get_redis_client() if redis_client: await redis_client.publish( - 'session_msg', + 'oh_event', json.dumps({'sid': sid, 'message_type': 'close_session'}), ) - async def _close_session(self, sid: str): + async def _on_close_session(self, sid: str): logger.info(f'_close_session:{sid}') # Clear up local variables connection_ids_to_remove = list( connection_id - for connection_id, conn_sid in self._local_connection_id_to_session_id.items() + for connection_id, conn_sid in self.local_connection_id_to_session_id.items() if sid == conn_sid ) logger.info(f'removing connections: {connection_ids_to_remove}') for connnnection_id in connection_ids_to_remove: - self._local_connection_id_to_session_id.pop(connnnection_id, None) + self.local_connection_id_to_session_id.pop(connnnection_id, None) session = self._local_agent_loops_by_sid.pop(sid, None) if not session: @@ -553,17 +488,12 @@ async def _close_session(self, sid: str): logger.info(f'closing_session:{session.sid}') # We alert the cluster in case they are interested - try: - redis_client = self._get_redis_client() - if redis_client: - await redis_client.publish( - 'session_msg', - json.dumps({'sid': session.sid, 'message_type': 'session_closing'}), - ) - except Exception: - logger.info( - 'error_publishing_close_session_event', exc_info=True, stack_info=True + redis_client = self._get_redis_client() + if redis_client: + await redis_client.publish( + 'oh_event', + json.dumps({'sid': session.sid, 'message_type': 'session_closing'}), ) - await session.close() + await call_sync_from_async(session.close) logger.info(f'closed_session:{session.sid}') diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index e77a77101b20..8318ab773129 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -62,17 +62,9 @@ def __init__( self.loop = asyncio.get_event_loop() self.user_id = user_id - async def close(self): - if self.sio: - await self.sio.emit( - 'oh_event', - event_to_dict( - AgentStateChangedObservation('', AgentState.STOPPED.value) - ), - to=ROOM_KEY.format(sid=self.sid), - ) + def close(self): self.is_alive = False - await self.agent_session.close() + self.agent_session.close() async def initialize_agent( self, diff --git a/openhands/server/types.py b/openhands/server/types.py index 8ecb898a76cf..cbf9389d2b44 100644 --- a/openhands/server/types.py +++ b/openhands/server/types.py @@ -35,3 +35,15 @@ async def verify_github_repo_list(self, installation_id: int | None) -> None: async def get_config(self) -> dict[str, str]: """Configure attributes for frontend""" raise NotImplementedError + + +class MissingSettingsError(ValueError): + """Raised when settings are missing or not found.""" + + pass + + +class LLMAuthenticationError(ValueError): + """Raised when there is an issue with LLM authentication.""" + + pass diff --git a/openhands/utils/http_session.py b/openhands/utils/http_session.py deleted file mode 100644 index 4edc4e6546c3..000000000000 --- a/openhands/utils/http_session.py +++ /dev/null @@ -1,24 +0,0 @@ -from dataclasses import dataclass, field - -import requests - - -@dataclass -class HttpSession: - """ - request.Session is reusable after it has been closed. This behavior makes it - likely to leak file descriptors (Especially when combined with tenacity). - We wrap the session to make it unusable after being closed - """ - - session: requests.Session | None = field(default_factory=requests.Session) - - def __getattr__(self, name): - if self.session is None: - raise ValueError('session_was_closed') - return object.__getattribute__(self.session, name) - - def close(self): - if self.session is not None: - self.session.close() - self.session = None diff --git a/openhands/utils/prompt.py b/openhands/utils/prompt.py index 8d81cbbdf9d6..1861c45308b5 100644 --- a/openhands/utils/prompt.py +++ b/openhands/utils/prompt.py @@ -28,6 +28,33 @@ class RepositoryInfo: repo_directory: str | None = None +ADDITIONAL_INFO_TEMPLATE = Template( + """ +{% if repository_info %} + +At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}. + +{% endif %} +{% if repository_instructions -%} + +{{ repository_instructions }} + +{% endif %} +{% if runtime_info and runtime_info.available_hosts -%} + +The user has access to the following hosts for accessing a web application, +each of which has a corresponding port: +{% for host, port in runtime_info.available_hosts.items() -%} +* {{ host }} (port {{ port }}) +{% endfor %} +When starting a web server, use the corresponding ports. You should also +set any options to allow iframes and CORS requests. + +{% endif %} +""" +) + + class PromptManager: """ Manages prompt templates and micro-agents for AI interactions. @@ -59,6 +86,9 @@ def __init__( self.repo_microagents: dict[str, RepoMicroAgent] = {} if microagent_dir: + # This loads micro-agents from the microagent_dir + # which is typically the OpenHands/microagents (i.e., the PUBLIC microagents) + # Only load KnowledgeMicroAgents repo_microagents, knowledge_microagents, _ = load_microagents_from_dir( microagent_dir @@ -79,6 +109,10 @@ def __init__( self.repo_microagents[name] = microagent def load_microagents(self, microagents: list[BaseMicroAgent]): + """Load microagents from a list of BaseMicroAgents. + + This is typically used when loading microagents from inside a repo. + """ # Only keep KnowledgeMicroAgents and RepoMicroAgents for microagent in microagents: if microagent.name in self.disabled_microagents: @@ -98,6 +132,13 @@ def _load_template(self, template_name: str) -> Template: return Template(file.read()) def get_system_message(self) -> str: + return self.system_template.render().strip() + + def get_additional_info(self) -> str: + """Gets information about the repository and runtime. + + This is used to inject information about the repository and runtime into the initial user message. + """ repo_instructions = '' assert ( len(self.repo_microagents) <= 1 @@ -108,7 +149,7 @@ def get_system_message(self) -> str: repo_instructions += '\n\n' repo_instructions += microagent.content - return self.system_template.render( + return ADDITIONAL_INFO_TEMPLATE.render( repository_instructions=repo_instructions, repository_info=self.repository_info, runtime_info=self.runtime_info, diff --git a/poetry.lock b/poetry.lock index 3cd9f4706f2e..06c119ee3497 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3751,19 +3751,19 @@ pydantic = ">=1.10" [[package]] name = "llama-index" -version = "0.12.10" +version = "0.12.11" description = "Interface between LLMs and your data" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "llama_index-0.12.10-py3-none-any.whl", hash = "sha256:c397e1355d48a043a4636857519185f9a47eb25e6482134c28e75f64cd4fe11e"}, - {file = "llama_index-0.12.10.tar.gz", hash = "sha256:942bd89f6363a553ff30f053df3c12703ac81c726d1afb7fc14555b0ede5e8a2"}, + {file = "llama_index-0.12.11-py3-none-any.whl", hash = "sha256:007361c35e1981a1656cef287b7bcdf22aa88e7d41b8e3a8ee261bb5a10519a9"}, + {file = "llama_index-0.12.11.tar.gz", hash = "sha256:b1116946a2414aec104a6c417b847da5b4f077a0966c50ebd2fc445cd713adce"}, ] [package.dependencies] llama-index-agent-openai = ">=0.4.0,<0.5.0" llama-index-cli = ">=0.4.0,<0.5.0" -llama-index-core = ">=0.12.10,<0.13.0" +llama-index-core = ">=0.12.11,<0.13.0" llama-index-embeddings-openai = ">=0.3.0,<0.4.0" llama-index-indices-managed-llama-cloud = ">=0.4.0" llama-index-llms-openai = ">=0.3.0,<0.4.0" @@ -3808,13 +3808,13 @@ llama-index-llms-openai = ">=0.3.0,<0.4.0" [[package]] name = "llama-index-core" -version = "0.12.10.post1" +version = "0.12.11" description = "Interface between LLMs and your data" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "llama_index_core-0.12.10.post1-py3-none-any.whl", hash = "sha256:897e8cd4efeff6842580b043bdf4008ac60f693df1de2bfd975307a4845707c2"}, - {file = "llama_index_core-0.12.10.post1.tar.gz", hash = "sha256:af27bea4d1494ba84983a649976e60e3de677a73946aa45ed12ce27e3a623ddf"}, + {file = "llama_index_core-0.12.11-py3-none-any.whl", hash = "sha256:3b1e019c899e9e011dfa01c96b7e3f666e0c161035fbca6cb787b4c61e0c94db"}, + {file = "llama_index_core-0.12.11.tar.gz", hash = "sha256:9a41ca91167ea5eec9ebaac7f5e958b7feddbd8af3bfbf7c393a5edfb994d566"}, ] [package.dependencies] diff --git a/tests/runtime/test_bash.py b/tests/runtime/test_bash.py index 3a25fd01ddee..828c859f11dd 100644 --- a/tests/runtime/test_bash.py +++ b/tests/runtime/test_bash.py @@ -1,6 +1,7 @@ """Bash-related tests for the EventStreamRuntime, which connects to the ActionExecutor running in the sandbox.""" import os +import time from pathlib import Path import pytest @@ -45,7 +46,7 @@ def test_bash_server(temp_dir, runtime_cls, run_as_openhands): runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands) try: action = CmdRunAction(command='python3 -m http.server 8080') - action.timeout = 1 + action.set_hard_timeout(1) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert isinstance(obs, CmdOutputObservation) @@ -57,7 +58,7 @@ def test_bash_server(temp_dir, runtime_cls, run_as_openhands): ) action = CmdRunAction(command='C-c') - action.timeout = 30 + action.set_hard_timeout(30) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert isinstance(obs, CmdOutputObservation) @@ -66,7 +67,7 @@ def test_bash_server(temp_dir, runtime_cls, run_as_openhands): assert '/workspace' in obs.metadata.working_dir action = CmdRunAction(command='ls') - action.timeout = 1 + action.set_hard_timeout(1) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert isinstance(obs, CmdOutputObservation) @@ -76,7 +77,7 @@ def test_bash_server(temp_dir, runtime_cls, run_as_openhands): # run it again! action = CmdRunAction(command='python3 -m http.server 8080') - action.timeout = 1 + action.set_hard_timeout(1) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert isinstance(obs, CmdOutputObservation) @@ -555,15 +556,20 @@ def test_basic_command(temp_dir, runtime_cls, run_as_openhands): def test_interactive_command(temp_dir, runtime_cls, run_as_openhands): - runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands) + runtime = _load_runtime( + temp_dir, + runtime_cls, + run_as_openhands, + runtime_startup_env_vars={'NO_CHANGE_TIMEOUT_SECONDS': '1'}, + ) try: # Test interactive command action = CmdRunAction('read -p "Enter name: " name && echo "Hello $name"') - action.timeout = 1 obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - # assert 'Enter name:' in obs.content # FIXME: this is not working - assert '[The command timed out after 1 seconds.' in obs.metadata.suffix + # This should trigger SOFT timeout, so no need to set hard timeout + assert 'Enter name:' in obs.content + assert '[The command has no new output after 1 seconds.' in obs.metadata.suffix action = CmdRunAction('John') obs = runtime.run_action(action) @@ -590,7 +596,7 @@ def test_long_output(temp_dir, runtime_cls, run_as_openhands): try: # Generate a long output action = CmdRunAction('for i in $(seq 1 5000); do echo "Line $i"; done') - action.timeout = 10 + action.set_hard_timeout(10) obs = runtime.run_action(action) assert obs.exit_code == 0 assert 'Line 1' in obs.content @@ -604,7 +610,7 @@ def test_long_output_exceed_history_limit(temp_dir, runtime_cls, run_as_openhand try: # Generate a long output action = CmdRunAction('for i in $(seq 1 50000); do echo "Line $i"; done') - action.timeout = 30 + action.set_hard_timeout(30) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert obs.exit_code == 0 @@ -621,13 +627,13 @@ def test_long_output_from_nested_directories(temp_dir, runtime_cls, run_as_openh # Create nested directories with many files setup_cmd = 'mkdir -p /tmp/test_dir && cd /tmp/test_dir && for i in $(seq 1 100); do mkdir -p "folder_$i"; for j in $(seq 1 100); do touch "folder_$i/file_$j.txt"; done; done' setup_action = CmdRunAction(setup_cmd.strip()) - setup_action.timeout = 60 + setup_action.set_hard_timeout(60) obs = runtime.run_action(setup_action) assert obs.exit_code == 0 # List the directory structure recursively action = CmdRunAction('ls -R /tmp/test_dir') - action.timeout = 60 + action.set_hard_timeout(60) obs = runtime.run_action(action) assert obs.exit_code == 0 @@ -672,7 +678,7 @@ def test_command_output_continuation(temp_dir, runtime_cls, run_as_openhands): try: # Start a command that produces output slowly action = CmdRunAction('for i in {1..5}; do echo $i; sleep 3; done') - action.timeout = 2.5 # Set timeout to 2.5 seconds + action.set_hard_timeout(2.5) obs = runtime.run_action(action) assert obs.content.strip() == '1' assert obs.metadata.prefix == '' @@ -680,20 +686,19 @@ def test_command_output_continuation(temp_dir, runtime_cls, run_as_openhands): # Continue watching output action = CmdRunAction('') - action.timeout = 2.5 + action.set_hard_timeout(2.5) obs = runtime.run_action(action) - assert '[Command output continued from previous command]' in obs.metadata.prefix + assert '[Below is the output of the previous command.]' in obs.metadata.prefix assert obs.content.strip() == '2' assert '[The command timed out after 2.5 seconds.' in obs.metadata.suffix # Continue until completion for expected in ['3', '4', '5']: action = CmdRunAction('') - action.timeout = 2.5 + action.set_hard_timeout(2.5) obs = runtime.run_action(action) assert ( - '[Command output continued from previous command]' - in obs.metadata.prefix + '[Below is the output of the previous command.]' in obs.metadata.prefix ) assert obs.content.strip() == expected assert '[The command timed out after 2.5 seconds.' in obs.metadata.suffix @@ -713,8 +718,7 @@ def test_long_running_command_follow_by_execute( try: # Test command that produces output slowly action = CmdRunAction('for i in {1..3}; do echo $i; sleep 3; done') - action.timeout = 2.5 - action.blocking = False + action.set_hard_timeout(2.5) obs = runtime.run_action(action) assert '1' in obs.content # First number should appear before timeout assert obs.metadata.exit_code == -1 # -1 indicates command is still running @@ -723,25 +727,32 @@ def test_long_running_command_follow_by_execute( # Continue watching output action = CmdRunAction('') - action.timeout = 2.5 + action.set_hard_timeout(2.5) obs = runtime.run_action(action) assert '2' in obs.content - assert ( - obs.metadata.prefix == '[Command output continued from previous command]\n' - ) + assert obs.metadata.prefix == '[Below is the output of the previous command.]\n' assert '[The command timed out after 2.5 seconds.' in obs.metadata.suffix assert obs.metadata.exit_code == -1 # -1 indicates command is still running # Test command that produces no output action = CmdRunAction('sleep 15') - action.timeout = 2.5 + action.set_hard_timeout(2.5) obs = runtime.run_action(action) - assert '3' in obs.content + logger.info(obs, extra={'msg_type': 'OBSERVATION'}) + assert '3' not in obs.content + assert obs.metadata.prefix == '[Below is the output of the previous command.]\n' assert ( - obs.metadata.prefix == '[Command output continued from previous command]\n' + 'The previous command was timed out but still running.' + in obs.metadata.suffix ) - assert '[The command timed out after 2.5 seconds.' in obs.metadata.suffix assert obs.metadata.exit_code == -1 # -1 indicates command is still running + + # Finally continue again + action = CmdRunAction('') + obs = runtime.run_action(action) + logger.info(obs, extra={'msg_type': 'OBSERVATION'}) + assert '3' in obs.content + assert '[The command completed with exit code 0.]' in obs.metadata.suffix finally: _close_test_runtime(runtime) @@ -783,3 +794,96 @@ def test_python_interactive_input(temp_dir, runtime_cls, run_as_openhands): assert '[The command completed with exit code 0.]' in obs.metadata.suffix finally: _close_test_runtime(runtime) + + +def test_stress_long_output_with_soft_and_hard_timeout( + temp_dir, runtime_cls, run_as_openhands +): + runtime = _load_runtime( + temp_dir, + runtime_cls, + run_as_openhands, + runtime_startup_env_vars={'NO_CHANGE_TIMEOUT_SECONDS': '1'}, + docker_runtime_kwargs={ + 'cpu_period': 100000, # 100ms + 'cpu_quota': 100000, # Can use 100ms out of each 100ms period (1 CPU) + 'mem_limit': '4G', # 4 GB of memory + }, + ) + try: + # Run a command that generates long output multiple times + for i in range(10): + start_time = time.time() + + # Check tmux memory usage (in KB) + mem_action = CmdRunAction( + 'ps aux | awk \'{printf "%8.1f KB %s\\n", $6, $0}\' | sort -nr | grep "/usr/bin/tmux" | grep -v grep | awk \'{print $1}\'' + ) + mem_obs = runtime.run_action(mem_action) + assert mem_obs.exit_code == 0 + logger.info( + f'Tmux memory usage (iteration {i}): {mem_obs.content.strip()} KB' + ) + + # Check action_execution_server mem + mem_action = CmdRunAction( + 'ps aux | awk \'{printf "%8.1f KB %s\\n", $6, $0}\' | sort -nr | grep "action_execution_server" | grep "/openhands/poetry" | grep -v grep | awk \'{print $1}\'' + ) + mem_obs = runtime.run_action(mem_action) + assert mem_obs.exit_code == 0 + logger.info( + f'Action execution server memory usage (iteration {i}): {mem_obs.content.strip()} KB' + ) + + # Test soft timeout + action = CmdRunAction( + 'read -p "Do you want to continue? [Y/n] " answer; if [[ $answer == "Y" ]]; then echo "Proceeding with operation..."; echo "Operation completed successfully!"; else echo "Operation cancelled."; exit 1; fi' + ) + obs = runtime.run_action(action) + assert 'Do you want to continue?' in obs.content + assert obs.exit_code == -1 # Command is still running, waiting for input + + # Send the confirmation + action = CmdRunAction('Y') + obs = runtime.run_action(action) + assert 'Proceeding with operation...' in obs.content + assert 'Operation completed successfully!' in obs.content + assert obs.exit_code == 0 + assert '[The command completed with exit code 0.]' in obs.metadata.suffix + + # Test hard timeout w/ long output + # Generate long output with 1000 asterisks per line + action = CmdRunAction( + f'export i={i}; for j in $(seq 1 100); do echo "Line $j - Iteration $i - $(printf \'%1000s\' | tr " " "*")"; sleep 1; done' + ) + action.set_hard_timeout(2) + obs = runtime.run_action(action) + + # Verify the output + assert obs.exit_code == -1 + assert f'Line 1 - Iteration {i}' in obs.content + # assert f'Line 1000 - Iteration {i}' in obs.content + # assert '[The command completed with exit code 0.]' in obs.metadata.suffix + + # Because hard-timeout is triggered, the terminal will in a weird state + # where it will not accept any new commands. + obs = runtime.run_action(CmdRunAction('ls')) + assert obs.exit_code == -1 + assert ( + 'The previous command was timed out but still running.' + in obs.metadata.suffix + ) + + # We need to send a Ctrl+C to reset the terminal. + obs = runtime.run_action(CmdRunAction('C-c')) + assert obs.exit_code == 130 + + # Now make sure the terminal is in a good state + obs = runtime.run_action(CmdRunAction('ls')) + assert obs.exit_code == 0 + + duration = time.time() - start_time + logger.info(f'Completed iteration {i} in {duration:.2f} seconds') + + finally: + _close_test_runtime(runtime) diff --git a/tests/runtime/test_stress_docker_runtime.py b/tests/runtime/test_stress_docker_runtime.py index d0e141ee3142..6e8a9d5957e8 100644 --- a/tests/runtime/test_stress_docker_runtime.py +++ b/tests/runtime/test_stress_docker_runtime.py @@ -28,7 +28,7 @@ def test_stress_docker_runtime(temp_dir, runtime_cls, repeat=1): for _ in range(repeat): # run stress-ng stress tests for 1 minute action = CmdRunAction(command='stress-ng --all 1 -t 1m') - action.timeout = 120 + action.set_hard_timeout(120) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) diff --git a/tests/runtime/test_stress_remote_runtime.py b/tests/runtime/test_stress_remote_runtime.py index 367af20467be..a2f6c7d2082b 100644 --- a/tests/runtime/test_stress_remote_runtime.py +++ b/tests/runtime/test_stress_remote_runtime.py @@ -92,14 +92,14 @@ def initialize_runtime( obs: CmdOutputObservation action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """) - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}') action = CmdRunAction(command='mkdir -p /dummy_dir') - action.timeout = 600 + action.set_hard_timeout(600) logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) diff --git a/tests/unit/test_bash_session.py b/tests/unit/test_bash_session.py index 13f32ad27d25..fc29eaffb2a5 100644 --- a/tests/unit/test_bash_session.py +++ b/tests/unit/test_bash_session.py @@ -94,7 +94,7 @@ def test_long_running_command_follow_by_execute(): obs = session.execute(CmdRunAction('')) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert '2' in obs.content - assert obs.metadata.prefix == '[Command output continued from previous command]\n' + assert obs.metadata.prefix == '[Below is the output of the previous command.]\n' assert obs.metadata.suffix == ( '\n[The command has no new output after 2 seconds. ' "You may wait longer to see additional output by sending empty command '', " @@ -108,7 +108,7 @@ def test_long_running_command_follow_by_execute(): obs = session.execute(CmdRunAction('sleep 15')) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert '3' in obs.content - assert obs.metadata.prefix == '[Command output continued from previous command]\n' + assert obs.metadata.prefix == '[Below is the output of the previous command.]\n' assert obs.metadata.suffix == ( '\n[The command has no new output after 2 seconds. ' "You may wait longer to see additional output by sending empty command '', " @@ -175,7 +175,7 @@ def test_interactive_command(): 'send other commands to interact with the current process, ' 'or send keys to interrupt/kill the command.]' ) - assert obs.metadata.prefix == '[Command output continued from previous command]\n' + assert obs.metadata.prefix == '[Below is the output of the previous command.]\n' obs = session.execute(CmdRunAction('line 2')) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -187,7 +187,7 @@ def test_interactive_command(): 'send other commands to interact with the current process, ' 'or send keys to interrupt/kill the command.]' ) - assert obs.metadata.prefix == '[Command output continued from previous command]\n' + assert obs.metadata.prefix == '[Below is the output of the previous command.]\n' obs = session.execute(CmdRunAction('EOF')) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) @@ -266,14 +266,14 @@ def test_command_output_continuation(): obs = session.execute(CmdRunAction('')) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert '[Command output continued from previous command]' in obs.metadata.prefix + assert '[Below is the output of the previous command.]' in obs.metadata.prefix assert obs.content.strip() == '2' assert '[The command has no new output after 2 seconds.' in obs.metadata.suffix assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT obs = session.execute(CmdRunAction('')) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert '[Command output continued from previous command]' in obs.metadata.prefix + assert '[Below is the output of the previous command.]' in obs.metadata.prefix assert obs.content.strip() == '3' assert '[The command has no new output after 2 seconds.' in obs.metadata.suffix @@ -281,14 +281,14 @@ def test_command_output_continuation(): obs = session.execute(CmdRunAction('')) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert '[Command output continued from previous command]' in obs.metadata.prefix + assert '[Below is the output of the previous command.]' in obs.metadata.prefix assert obs.content.strip() == '4' assert '[The command has no new output after 2 seconds.' in obs.metadata.suffix assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT obs = session.execute(CmdRunAction('')) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert '[Command output continued from previous command]' in obs.metadata.prefix + assert '[Below is the output of the previous command.]' in obs.metadata.prefix assert obs.content.strip() == '5' assert '[The command has no new output after 2 seconds.' in obs.metadata.suffix assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py index b1f5e420c3b4..84f0b8fc1993 100644 --- a/tests/unit/test_codeact_agent.py +++ b/tests/unit/test_codeact_agent.py @@ -471,7 +471,7 @@ def test_mock_function_calling(): llm = Mock() llm.is_function_calling_active = lambda: False config = AgentConfig() - config.use_microagents = False + config.enable_prompt_extensions = False agent = CodeActAgent(llm=llm, config=config) assert agent.mock_function_calling is True @@ -509,7 +509,7 @@ def test_step_with_no_pending_actions(mock_state: State): # Create agent with mocked LLM config = AgentConfig() - config.use_microagents = False + config.enable_prompt_extensions = False agent = CodeActAgent(llm=llm, config=config) # Test step with no pending actions diff --git a/tests/unit/test_manager.py b/tests/unit/test_manager.py index cd2ddf6ba0a6..f0ac68ff8361 100644 --- a/tests/unit/test_manager.py +++ b/tests/unit/test_manager.py @@ -44,28 +44,28 @@ async def test_session_not_running_in_cluster(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - result = await session_manager._get_running_agent_loops_remotely( - filter_to_sids={'non-existant-session'} + result = await session_manager.is_agent_loop_running_in_cluster( + 'non-existant-session' ) - assert result == set() + assert result is False assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'session_msg', - '{"query_id": "' + 'oh_event', + '{"request_id": "' + str(id) - + '", "message_type": "running_agent_loops_query", "filter_to_sids": ["non-existant-session"]}', + + '", "sids": ["non-existant-session"], "message_type": "is_session_running"}', ) @pytest.mark.asyncio -async def test_get_running_agent_loops_remotely(): +async def test_session_is_running_in_cluster(): id = uuid4() sio = get_mock_sio( GetMessageMock( { - 'query_id': str(id), + 'request_id': str(id), 'sids': ['existing-session'], - 'message_type': 'running_agent_loops_response', + 'message_type': 'session_is_running', } ) ) @@ -76,16 +76,16 @@ async def test_get_running_agent_loops_remotely(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - result = await session_manager._get_running_agent_loops_remotely( - 1, {'existing-session'} + result = await session_manager.is_agent_loop_running_in_cluster( + 'existing-session' ) - assert result == {'existing-session'} + assert result is True assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'session_msg', - '{"query_id": "' + 'oh_event', + '{"request_id": "' + str(id) - + '", "message_type": "running_agent_loops_query", "user_id": 1, "filter_to_sids": ["existing-session"]}', + + '", "sids": ["existing-session"], "message_type": "is_session_running"}', ) @@ -96,8 +96,8 @@ async def test_init_new_local_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = set() + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = False with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1), @@ -106,8 +106,8 @@ async def test_init_new_local_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.get_running_agent_loops', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -130,8 +130,8 @@ async def test_join_local_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = set() + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = False with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -140,8 +140,8 @@ async def test_join_local_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.get_running_agent_loops', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -167,8 +167,8 @@ async def test_join_cluster_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = {'new-session-id'} + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = True with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -177,8 +177,8 @@ async def test_join_cluster_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -198,8 +198,8 @@ async def test_add_to_local_event_stream(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = set() + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = False with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -208,8 +208,8 @@ async def test_add_to_local_event_stream(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.get_running_agent_loops', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -234,8 +234,8 @@ async def test_add_to_cluster_event_stream(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = {'new-session-id'} + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = True with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -244,8 +244,8 @@ async def test_add_to_cluster_event_stream(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -259,7 +259,7 @@ async def test_add_to_cluster_event_stream(): ) assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'session_msg', + 'oh_event', '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}', ) @@ -277,7 +277,7 @@ async def test_cleanup_session_connections(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - session_manager._local_connection_id_to_session_id.update( + session_manager.local_connection_id_to_session_id.update( { 'conn1': 'session1', 'conn2': 'session1', @@ -286,9 +286,9 @@ async def test_cleanup_session_connections(): } ) - await session_manager._close_session('session1') + await session_manager._on_close_session('session1') - remaining_connections = session_manager._local_connection_id_to_session_id + remaining_connections = session_manager.local_connection_id_to_session_id assert 'conn1' not in remaining_connections assert 'conn2' not in remaining_connections assert 'conn3' in remaining_connections diff --git a/tests/unit/test_prompt_manager.py b/tests/unit/test_prompt_manager.py index 4f2a69f7f0d5..46f1f5a254a1 100644 --- a/tests/unit/test_prompt_manager.py +++ b/tests/unit/test_prompt_manager.py @@ -59,9 +59,10 @@ def test_prompt_manager_with_microagent(prompt_dir): # Test with GitHub repo manager.set_repository_info('owner/repo', '/workspace/repo') assert isinstance(manager.get_system_message(), str) - assert '' in manager.get_system_message() - assert 'owner/repo' in manager.get_system_message() - assert '/workspace/repo' in manager.get_system_message() + additional_info = manager.get_additional_info() + assert '' in additional_info + assert 'owner/repo' in additional_info + assert '/workspace/repo' in additional_info assert isinstance(manager.get_example_user_message(), str) @@ -85,13 +86,7 @@ def test_prompt_manager_file_not_found(prompt_dir): def test_prompt_manager_template_rendering(prompt_dir): # Create temporary template files with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f: - f.write("""System prompt: bar -{% if repository_info %} - -At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}. - -{% endif %} -{{ repo_instructions }}""") + f.write("""System prompt: bar""") with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f: f.write('User prompt: foo') @@ -106,12 +101,13 @@ def test_prompt_manager_template_rendering(prompt_dir): assert manager.repository_info.repo_name == 'owner/repo' system_msg = manager.get_system_message() assert 'System prompt: bar' in system_msg - assert '' in system_msg + additional_info = manager.get_additional_info() + assert '' in additional_info assert ( "At the user's request, repository owner/repo has been cloned to directory /workspace/repo." - in system_msg + in additional_info ) - assert '' in system_msg + assert '' in additional_info assert manager.get_example_user_message() == 'User prompt: foo' # Clean up temporary files diff --git a/tests/unit/test_runtime_reboot.py b/tests/unit/test_runtime_reboot.py index e3ae31815a3e..c78fbe029ad6 100644 --- a/tests/unit/test_runtime_reboot.py +++ b/tests/unit/test_runtime_reboot.py @@ -27,7 +27,7 @@ def runtime(mock_session): def test_runtime_timeout_error(runtime, mock_session): # Create a command action action = CmdRunAction(command='test command') - action.timeout = 120 + action.set_hard_timeout(120) # Mock the runtime to raise a timeout error runtime.send_action_for_execution.side_effect = AgentRuntimeTimeoutError( @@ -78,7 +78,7 @@ def test_runtime_disconnected_error( # Create a command action action = CmdRunAction(command='test command') - action.timeout = 120 + action.set_hard_timeout(120) # Verify that the error message is correct with pytest.raises(AgentRuntimeDisconnectedError) as exc_info: