Skip to content

Commit

Permalink
Fix workspace permission
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway committed Aug 1, 2024
1 parent d83dd47 commit c5b98e4
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 28 deletions.
2 changes: 1 addition & 1 deletion hypha/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.20.1b2"
"version": "0.20.1b3"
}
3 changes: 3 additions & 0 deletions hypha/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ def get_workspace_interface(
"""Get the interface of a workspace."""
assert workspace, "Workspace name is required"
assert user_info and isinstance(user_info, UserInfo), "User info is required"
# Check if workspace exists
if not self._redis.hexists("workspaces", workspace):
raise KeyError(f"Workspace {workspace} does not exist")
# the client will be hidden if client_id is None
if silent is None:
silent = client_id is None
Expand Down
14 changes: 5 additions & 9 deletions hypha/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,7 @@ async def list_services(
client_id = "*"
service_id = "*"

if query == "public":
query = {
"visibility": visibility,
"workspace": workspace,
"client_id": client_id,
"service_id": service_id,
}
elif "/" in query and ":" in query:
if "/" in query and ":" in query:
parts = query.split("/")
workspace_part = parts[0]
remaining = parts[1]
Expand Down Expand Up @@ -414,7 +407,7 @@ async def list_services(
"service_id": service_id,
}
else:
service_id = query
workspace = query
query = {
"visibility": visibility,
"workspace": workspace,
Expand Down Expand Up @@ -786,6 +779,9 @@ async def _get_service_api(self, service_id: str, context=None):
self.validate_context(context, permission=UserPermission.read)
ws = context["ws"]
user_info = UserInfo.model_validate(context["user"])
# Check if workspace exists
if not self._redis.hexists("workspaces", ws):
raise KeyError(f"Workspace {ws} does not exist")
# Now launch the app and get the service
svc = await self.get_service(service_id, mode="random", context=context)
# Create a rpc client for getting the launcher service as user.
Expand Down
27 changes: 9 additions & 18 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ async def test_workspace(fastapi_server, test_user_token):
public_svc = await api.list_services("public")
assert len(public_svc) > 0

current_svcs = await api.list_services(api.config.workspace)
assert len(current_svcs) == 1

current_svcs = await api.list_services()
assert len(current_svcs) == 1

login_svc = find_item(public_svc, "name", "Hypha Login")
assert len(login_svc["description"]) > 0

Expand Down Expand Up @@ -260,7 +266,7 @@ def test(context=None):
context = await service.test()
assert "from" in context and "to" in context and "user" in context

svcs = await api2.list_services("public")
svcs = await api2.list_services(service_info.config.workspace)
assert find_item(svcs, "name", "test_service_2")

assert api2.config["workspace"] == "my-test-workspace"
Expand Down Expand Up @@ -322,10 +328,8 @@ async def test_services(fastapi_server):
)
service = await api.get_service(service_info)
assert service["name"] == "test_service"
services = await api.list_services(
{"workspace": api.config.workspace, "id": "test_service"}
)
assert len(services) == 1
services = await api.list_services({"workspace": api.config.workspace})
assert find_item(services, "name", "test_service")

service_info = await api.register_service(
{
Expand Down Expand Up @@ -381,19 +385,6 @@ async def test_services(fastapi_server):
== 2
)

# service_info = await api.register_service(
# {
# "name": "test_service",
# "type": "#test",
# "idx": 4,
# "config": {"flags": ["single-instance"]}, # mark it as single instance
# },
# overwrite=True,
# )
# # it should remove other services because it's single instance service
# assert len(await api.list_services({"name": "test_service"})) == 1
# assert (await api.get_service("test_service"))["idx"] == 4


async def test_server_reconnection(fastapi_server):
"""Test reconnecting to the server."""
Expand Down

0 comments on commit c5b98e4

Please sign in to comment.