Skip to content

Commit

Permalink
edit unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
deigen committed Jan 23, 2025
1 parent 2312951 commit 921ad10
Showing 1 changed file with 15 additions and 34 deletions.
49 changes: 15 additions & 34 deletions tests/runners/test_runners.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,24 @@
# This test will create dummy runner and start the runner at first
# Testing outputs received by client and programmed outputs of runner server
#
import importlib.util
import inspect
import os
import sys
import threading
import uuid

import pytest
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from clarifai_protocol import BaseRunner
from clarifai_protocol.utils.logging import logger
from google.protobuf import json_format

from clarifai.client import BaseClient, Model, User
from clarifai.client.auth.helper import ClarifaiAuthHelper
from clarifai.runners.models.model_builder import ModelBuilder
from clarifai.runners.models.model_runner import ModelRunner


def runner_class(runner_path):

# arbitrary name given to the module to be imported
module = "runner_module"

spec = importlib.util.spec_from_file_location(module, runner_path)
runner_module = importlib.util.module_from_spec(spec)
sys.modules[module] = runner_module
spec.loader.exec_module(runner_module)

# Find all classes in the model.py file that are subclasses of BaseRunner
classes = [
cls for _, cls in inspect.getmembers(runner_module, inspect.isclass)
if issubclass(cls, BaseRunner) and cls.__module__ == runner_module.__name__
]

# Ensure there is exactly one subclass of BaseRunner in the model.py file
if len(classes) != 1:
raise Exception("Expected exactly one subclass of BaseRunner, found: {}".format(len(classes)))

return classes[0]


MyRunner = runner_class(
runner_path=os.path.join(os.path.dirname(__file__), "dummy_runner_models", "1", "model.py"))
MyWrapperRunner = runner_class(runner_path=os.path.join(
os.path.dirname(__file__), "dummy_runner_models", "1", "model_wrapper.py"))
MY_MODEL_PATH = os.path.join(os.path.dirname(__file__), "dummy_runner_models", "1", "model.py")
MY_WRAPPER_MODEL_PATH = os.path.join(
os.path.dirname(__file__), "dummy_runner_models", "1", "model_wrapper.py")

# logger.disabled = True

Expand Down Expand Up @@ -176,7 +149,11 @@ def setup_class(cls):
base_url=cls.AUTH.base,
pat=cls.AUTH.pat,
)
cls.runner = MyRunner(

cls.runner_model = ModelBuilder(MY_MODEL_PATH).create_model_instance()

cls.runner = ModelRunner(
model=cls.runner_model,
runner_id=cls.RUNNER_ID,
nodepool_id=cls.NODEPOOL_ID,
compute_cluster_id=cls.COMPUTE_CLUSTER_ID,
Expand Down Expand Up @@ -492,7 +469,11 @@ def setup_class(cls):
cls.NODEPOOL_ID,
cls.COMPUTE_CLUSTER_ID,
)
cls.runner = MyWrapperRunner(

cls.runner_model = ModelBuilder(MY_WRAPPER_MODEL_PATH).create_model_instance()

cls.runner = ModelRunner(
model=cls.runner_model,
runner_id=cls.RUNNER_ID,
nodepool_id=cls.NODEPOOL_ID,
compute_cluster_id=cls.COMPUTE_CLUSTER_ID,
Expand Down

0 comments on commit 921ad10

Please sign in to comment.