Skip to content

Commit

Permalink
Support current workspace for http endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway committed Sep 28, 2024
1 parent dc66b2e commit efc4517
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 26 deletions.
6 changes: 4 additions & 2 deletions hypha/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ async def start(
workspace = context["ws"]
user_info = UserInfo.model_validate(context["user"])

async with self.store.get_workspace_interface(workspace, user_info) as ws:
async with self.store.get_workspace_interface(user_info, workspace) as ws:
token = await ws.generate_token()

if not user_info.check_permission(workspace, UserPermission.read):
Expand Down Expand Up @@ -520,7 +520,9 @@ async def list_running(self, context: Optional[dict] = None) -> List[str]:
async def list_apps(self, context: Optional[dict] = None):
"""List applications in the workspace."""
try:
apps = await self.artifact_manager.read(prefix="applications", context=context)
apps = await self.artifact_manager.read(
prefix="applications", context=context
)
return apps["collection"]
except KeyError:
return []
Expand Down
32 changes: 21 additions & 11 deletions hypha/core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,13 @@ async def start_login(workspace: str = None, expires_in: int = None):
# set the key and with expire time
await redis.setex(LOGIN_KEY_PREFIX + key, MAXIMUM_LOGIN_TIME, "")
return {
"login_url": f"{login_service_url.replace('/services/', '/apps/')}/?key={key}" + (
f"&workspace={workspace}" if workspace else "" +
f"&expires_in={expires_in}" if expires_in else ""
"login_url": f"{login_service_url.replace('/services/', '/apps/')}/?key={key}"
+ (
f"&workspace={workspace}"
if workspace
else "" + f"&expires_in={expires_in}"
if expires_in
else ""
),
"key": key,
"report_url": f"{login_service_url}/report",
Expand All @@ -385,7 +389,9 @@ async def index(event):

async def check_login(key, timeout=MAXIMUM_LOGIN_TIME, profile=False):
"""Check the status of a login session."""
assert await redis.exists(LOGIN_KEY_PREFIX + key), "Invalid key, key does not exist"
assert await redis.exists(
LOGIN_KEY_PREFIX + key
), "Invalid key, key does not exist"
if timeout <= 0:
user_info = await redis.get(LOGIN_KEY_PREFIX + key)
if user_info == b"":
Expand Down Expand Up @@ -434,7 +440,9 @@ async def report_login(
picture=None,
):
"""Report a token associated with a login session."""
assert await redis.exists(LOGIN_KEY_PREFIX + key), "Invalid key, key does not exist or expired"
assert await redis.exists(
LOGIN_KEY_PREFIX + key
), "Invalid key, key does not exist or expired"
# workspace = workspace or ("ws-user-" + user_id)
kwargs = {
"token": token,
Expand All @@ -447,24 +455,26 @@ async def report_login(
"user_id": user_id,
"picture": picture,
}

user_token_info = UserTokenInfo.model_validate(kwargs)
if workspace:
user_info = parse_token(token)
# based on the user token, create a scoped token
workspace = workspace or user_info.get_workspace()
# generate scoped token
workspace_info = await store.load_or_create_workspace(user_info, workspace)
user_info.scope = update_user_scope(
user_info, workspace_info
)
user_info.scope = update_user_scope(user_info, workspace_info)
if not user_info.check_permission(workspace, UserPermission.read):
raise Exception(f"Invalid permission for the workspace {workspace}")

token = generate_presigned_token(user_info, int(expires_in or 3600))
# replace the token
user_token_info.token = token
await redis.setex(LOGIN_KEY_PREFIX + key, MAXIMUM_LOGIN_TIME, user_token_info.model_dump_json())
await redis.setex(
LOGIN_KEY_PREFIX + key,
MAXIMUM_LOGIN_TIME,
user_token_info.model_dump_json(),
)

logger.info(
f"To preview the login page, visit: {login_service_url.replace('/services/', '/apps/')}"
Expand Down
20 changes: 14 additions & 6 deletions hypha/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,8 @@ async def init(self, reset_redis, startup_functions=None):
async def _register_root_services(self):
"""Register root services."""
self._root_workspace_interface = await self.get_workspace_interface(
self._root_user.get_workspace(),
self._root_user,
self._root_user.get_workspace(),
client_id=self._server_id,
silent=False,
)
Expand Down Expand Up @@ -476,7 +476,7 @@ async def get_public_api(self):
"""Get the public API."""
if self._public_workspace_interface is None:
self._public_workspace_interface = await self.get_workspace_interface(
"public", self._root_user, client_id=self._server_id, silent=False
self._root_user, "public", client_id=self._server_id, silent=False
)
return self._public_workspace_interface

Expand Down Expand Up @@ -532,7 +532,10 @@ async def parse_user_token(self, token):
return user_info

async def login_optional(
self, authorization: str = Header(None), access_token: str = Cookie(None), _token: str = Query(None)
self,
authorization: str = Header(None),
access_token: str = Cookie(None),
_token: str = Query(None),
):
"""Return user info or create an anonymouse user.
Expand All @@ -544,7 +547,12 @@ async def login_optional(
user_info = await self.parse_user_token(token)
return user_info
else:
return generate_anonymous_user()
user_info = generate_anonymous_user()
user_workspace = user_info.get_workspace()
user_info.scope = create_scope(
f"{user_workspace}#a", current_workspace=user_workspace
)
return user_info

async def get_all_workspace(self):
"""Get all workspaces."""
Expand Down Expand Up @@ -583,7 +591,7 @@ def connect_to_workspace(
client_id = self._server_id + "-" + random_id(readable=False)
user_info = user_info or self._root_user
return self.get_workspace_interface(
workspace, user_info, client_id=client_id, timeout=timeout, silent=silent
user_info, workspace, client_id=client_id, timeout=timeout, silent=silent
)

def get_manager_id(self):
Expand All @@ -609,8 +617,8 @@ def get_server_info(self):

def get_workspace_interface(
self,
workspace: str,
user_info: UserInfo,
workspace: str,
client_id=None,
timeout=10,
silent=True,
Expand Down
14 changes: 8 additions & 6 deletions hypha/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ async def __call__(self, scope, receive, send):
)

async with self.store.get_workspace_interface(
workspace, user_info
user_info, user_info.scope.current_workspace
) as api:
# Call get_service_type_id to check if it's an ASGI service
service_info = await api.get_service_info(
Expand Down Expand Up @@ -521,7 +521,7 @@ async def get_workspace_services(
"""Route for get services under a workspace."""
try:
async with self.store.get_workspace_interface(
workspace, user_info
user_info, workspace
) as manager:
services = await manager.list_services()
info = serialize(services)
Expand All @@ -548,7 +548,7 @@ async def get_service_info(
"""Route for checking details of a service."""
try:
async with self.store.get_workspace_interface(
workspace, user_info
user_info, workspace
) as api:
if service_id == "ws":
return serialize(api)
Expand Down Expand Up @@ -584,7 +584,7 @@ async def get_workspace_apps(
"""Route for get apps under a workspace."""
try:
async with self.store.get_workspace_interface(
workspace, user_info
user_info, user_info.scope.current_workspace
) as manager:
try:
controller = await manager.get_service("public/server-apps")
Expand Down Expand Up @@ -776,12 +776,14 @@ async def service_function(
try:
workspace, service_id, function_key = function_info
async with self.store.get_workspace_interface(
workspace, user_info
user_info, user_info.scope.current_workspace
) as api:
if service_id == "ws":
service = api
else:
info = await api.get_service_info(service_id, {"mode": _mode})
info = await api.get_service_info(
workspace + "/" + service_id, {"mode": _mode}
)
service = await api.get_service(info.id)
func = get_value(function_key, service)
if not func:
Expand Down
4 changes: 3 additions & 1 deletion hypha/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,9 @@ async def force_disconnect(_):
)

event_bus = self.store.get_event_bus()
assert user_info.scope.current_workspace == workspace, f"Workspace mismatch: {workspace} != {user_info.current_workspace}"
assert (
user_info.scope.current_workspace == workspace
), f"Workspace mismatch: {workspace} != {user_info.current_workspace}"
conn = RedisRPCConnection(event_bus, workspace, client_id, user_info, None)
self._websockets[f"{workspace}/{client_id}"] = websocket
try:
Expand Down

0 comments on commit efc4517

Please sign in to comment.