Skip to content

Commit

Permalink
fix: update the unit_test
Browse files Browse the repository at this point in the history
Signed-off-by: taieeuu <[email protected]>
  • Loading branch information
taieeuu committed Dec 22, 2024
1 parent 8f28333 commit db7cb2f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 28 deletions.
18 changes: 1 addition & 17 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import requests
from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest
from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub
from grpc_health.v1 import health_pb2, health_pb2_grpc

from flytekit.clients.auth.authenticator import (
Authenticator,
Expand Down Expand Up @@ -234,22 +233,7 @@ def wrap_exceptions_channel(cfg: PlatformConfig, in_channel: grpc.Channel) -> gr
:param in_channel: grpc.Channel
:return: grpc.Channel
"""

try:
health_stub = health_pb2_grpc.HealthStub(in_channel)
request = health_pb2.HealthCheckRequest()
health_stub.Check(request)

except grpc.RpcError as e:
logging.warning(f"RPC error occurred: {e.code()}")
if e.code() == grpc.StatusCode.UNAUTHENTICATED:
in_channel = wrap_exceptions_channel(
cfg,
upgrade_channel_to_authenticated(
cfg, upgrade_channel_to_proxy_authenticated(cfg, get_channel(cfg, options=cfg.options))
),
)

print("wrap_exceptions_channel")
return grpc.intercept_channel(in_channel, RetryExceptionWrapperInterceptor(max_retries=cfg.rpc_retries))


Expand Down
35 changes: 31 additions & 4 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service
from flyteidl.service import signal_pb2_grpc as signal_service
from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub
from grpc_health.v1 import health_pb2, health_pb2_grpc

from flytekit.clients.auth_helper import (
get_channel,
Expand All @@ -18,6 +19,12 @@
wrap_exceptions_channel,
)
from flytekit.configuration import PlatformConfig
from flytekit.exceptions.system import FlyteSystemUnavailableException
from flytekit.exceptions.user import (
FlyteEntityAlreadyExistsException,
FlyteEntityNotExistException,
FlyteInvalidInputException,
)
from flytekit.loggers import logger


Expand Down Expand Up @@ -51,11 +58,10 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
# 32KB for error messages, 20MB for actual messages.
options = (("grpc.max_metadata_size", 32 * 1024), ("grpc.max_receive_message_length", 20 * 1024 * 1024))
self._cfg = cfg
self.skip_auth = True
if self.skip_auth:
base_channel = get_channel(cfg, options=options)
base_channel = get_channel(cfg, options=options)

if self.check_grpc_health_with_authentication(base_channel):
self._channel = wrap_exceptions_channel(cfg, base_channel)
self.skip_auth = False
else:
self._channel = wrap_exceptions_channel(
cfg,
Expand All @@ -74,6 +80,27 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
# metadata will hold the value of the token to send to the various endpoints.
self._metadata = None

@staticmethod
def check_grpc_health_with_authentication(in_channel):
health_stub = health_pb2_grpc.HealthStub(in_channel)
request = health_pb2.HealthCheckRequest()
try:
response = health_stub.Check(request)
if response.status == health_pb2.HealthCheckResponse.SERVING:
print("Service is healthy and ready to serve.")
return True
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.UNAUTHENTICATED:
return False
elif e.code() == grpc.StatusCode.ALREADY_EXISTS:
raise FlyteEntityAlreadyExistsException() from e
elif e.code() == grpc.StatusCode.NOT_FOUND:
raise FlyteEntityNotExistException() from e
elif e.code() == grpc.StatusCode.INVALID_ARGUMENT:
raise FlyteInvalidInputException(request) from e
elif e.code() == grpc.StatusCode.UNAVAILABLE:
raise FlyteSystemUnavailableException() from e

@classmethod
def with_root_certificate(cls, cfg: PlatformConfig, root_cert_file: str) -> RawSynchronousFlyteClient:
b = None
Expand Down
11 changes: 7 additions & 4 deletions tests/flytekit/unit/clients/test_friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,29 @@
from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient
from flytekit.configuration import PlatformConfig
from flytekit.models.project import Project as _Project

from grpc_health.v1 import health_pb2

@mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.update_project")
def test_update_project(mock_raw_update_project):
@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.check_grpc_health_with_authentication", return_value=health_pb2.HealthCheckResponse.SERVING)
def test_update_project(mock_check_health, mock_raw_update_project):
client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True))
project = _Project("foo", "name", "description", state=_Project.ProjectState.ACTIVE)
client.update_project(project)
mock_raw_update_project.assert_called_with(project.to_flyte_idl())


@mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.list_projects")
def test_list_projects_paginated(mock_raw_list_projects):
@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.check_grpc_health_with_authentication", return_value=health_pb2.HealthCheckResponse.SERVING)
def test_list_projects_paginated(mock_check_health, mock_raw_list_projects):
client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True))
client.list_projects_paginated(limit=100, token="")
project_list_request = _project_pb2.ProjectListRequest(limit=100, token="", filters=None, sort_by=None)
mock_raw_list_projects.assert_called_with(project_list_request=project_list_request)


@mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.create_upload_location")
def test_create_upload_location(mock_raw_create_upload_location):
@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.check_grpc_health_with_authentication", return_value=health_pb2.HealthCheckResponse.SERVING)
def test_create_upload_location(mock_check_health, mock_raw_create_upload_location):
client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True))
client.get_upload_signed_url("foo", "bar", bytes(), "baz.qux", timedelta(minutes=42), add_content_md5_metadata=True)
duration_pb = Duration()
Expand Down
9 changes: 6 additions & 3 deletions tests/flytekit/unit/clients/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

from flytekit.clients.raw import RawSynchronousFlyteClient
from flytekit.configuration import PlatformConfig

from grpc_health.v1 import health_pb2

@mock.patch("flytekit.clients.raw._admin_service")
@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
def test_update_project(mock_channel, mock_admin):
@mock.patch.object(RawSynchronousFlyteClient, "check_grpc_health_with_authentication", return_value=True)
def test_update_project(mock_check_health, mock_channel, mock_admin):
mock_health_stub = mock.Mock()
client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True))
project = _project_pb2.Project(id="foo", name="name", description="description", state=_project_pb2.Project.ACTIVE)
client.update_project(project)
Expand All @@ -17,7 +19,8 @@ def test_update_project(mock_channel, mock_admin):

@mock.patch("flytekit.clients.raw._admin_service")
@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
def test_list_projects_paginated(mock_channel, mock_admin):
@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.check_grpc_health_with_authentication", return_value=health_pb2.HealthCheckResponse.SERVING)
def test_list_projects_paginated(mock_check_health, mock_channel, mock_admin):
client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True))
project_list_request = _project_pb2.ProjectListRequest(limit=100, token="", filters=None, sort_by=None)
client.list_projects(project_list_request)
Expand Down

0 comments on commit db7cb2f

Please sign in to comment.