diff --git a/package/kedro_viz/api/rest/responses.py b/package/kedro_viz/api/rest/responses.py index 5a38ef6b4c..1e885eced1 100644 --- a/package/kedro_viz/api/rest/responses.py +++ b/package/kedro_viz/api/rest/responses.py @@ -12,15 +12,13 @@ from kedro_viz.api.rest.utils import get_package_compatibilities from kedro_viz.data_access import data_access_manager -from kedro_viz.models.flowchart import ( - DataNode, +from kedro_viz.models.flowchart.node_metadata import ( DataNodeMetadata, ParametersNodeMetadata, - TaskNode, TaskNodeMetadata, - TranscodedDataNode, TranscodedDataNodeMetadata, ) +from kedro_viz.models.flowchart.nodes import DataNode, TaskNode, TranscodedDataNode from kedro_viz.models.metadata import Metadata, PackageCompatibility logger = logging.getLogger(__name__) diff --git a/package/kedro_viz/data_access/managers.py b/package/kedro_viz/data_access/managers.py index 40e8ac56f6..4468804c77 100644 --- a/package/kedro_viz/data_access/managers.py +++ b/package/kedro_viz/data_access/managers.py @@ -20,15 +20,15 @@ from kedro_viz.constants import DEFAULT_REGISTERED_PIPELINE_ID, ROOT_MODULAR_PIPELINE_ID from kedro_viz.integrations.utils import UnavailableDataset -from kedro_viz.models.flowchart import ( +from kedro_viz.models.flowchart.edge import GraphEdge +from kedro_viz.models.flowchart.model_utils import GraphNodeType +from kedro_viz.models.flowchart.named_entities import RegisteredPipeline +from kedro_viz.models.flowchart.nodes import ( DataNode, - GraphEdge, GraphNode, - GraphNodeType, ModularPipelineChild, ModularPipelineNode, ParametersNode, - RegisteredPipeline, TaskNode, TranscodedDataNode, ) diff --git a/package/kedro_viz/data_access/repositories/graph.py b/package/kedro_viz/data_access/repositories/graph.py index 601e52d060..bea6095bc9 100644 --- a/package/kedro_viz/data_access/repositories/graph.py +++ b/package/kedro_viz/data_access/repositories/graph.py @@ -3,7 +3,8 @@ from typing import Dict, Generator, List, Optional, Set -from kedro_viz.models.flowchart import GraphEdge, GraphNode +from kedro_viz.models.flowchart.edge import GraphEdge +from kedro_viz.models.flowchart.nodes import GraphNode class GraphNodesRepository: diff --git a/package/kedro_viz/data_access/repositories/modular_pipelines.py b/package/kedro_viz/data_access/repositories/modular_pipelines.py index 746f6700df..dc51df7f80 100644 --- a/package/kedro_viz/data_access/repositories/modular_pipelines.py +++ b/package/kedro_viz/data_access/repositories/modular_pipelines.py @@ -8,9 +8,9 @@ from kedro.pipeline.node import Node as KedroNode from kedro_viz.constants import ROOT_MODULAR_PIPELINE_ID -from kedro_viz.models.flowchart import ( +from kedro_viz.models.flowchart.model_utils import GraphNodeType +from kedro_viz.models.flowchart.nodes import ( GraphNode, - GraphNodeType, ModularPipelineChild, ModularPipelineNode, ) diff --git a/package/kedro_viz/data_access/repositories/registered_pipelines.py b/package/kedro_viz/data_access/repositories/registered_pipelines.py index d73f621867..1309548fac 100644 --- a/package/kedro_viz/data_access/repositories/registered_pipelines.py +++ b/package/kedro_viz/data_access/repositories/registered_pipelines.py @@ -4,7 +4,7 @@ from collections import OrderedDict, defaultdict from typing import Dict, List, Optional, Set -from kedro_viz.models.flowchart import RegisteredPipeline +from kedro_viz.models.flowchart.named_entities import RegisteredPipeline class RegisteredPipelinesRepository: diff --git a/package/kedro_viz/data_access/repositories/tags.py 
b/package/kedro_viz/data_access/repositories/tags.py index 0bb46949ac..a7bd33e31f 100644 --- a/package/kedro_viz/data_access/repositories/tags.py +++ b/package/kedro_viz/data_access/repositories/tags.py @@ -3,7 +3,7 @@ from typing import Iterable, List, Set -from kedro_viz.models.flowchart import Tag +from kedro_viz.models.flowchart.named_entities import Tag class TagsRepository: diff --git a/package/kedro_viz/models/flowchart/__init__.py b/package/kedro_viz/models/flowchart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/package/kedro_viz/models/flowchart/edge.py b/package/kedro_viz/models/flowchart/edge.py new file mode 100644 index 0000000000..439cafc782 --- /dev/null +++ b/package/kedro_viz/models/flowchart/edge.py @@ -0,0 +1,15 @@ +"""`kedro_viz.models.flowchart.edge` defines data models to represent Kedro edges in a viz graph.""" + +from pydantic import BaseModel + + +class GraphEdge(BaseModel, frozen=True): + """Represent an edge in the graph + + Args: + source (str): The id of the source node. + target (str): The id of the target node. + """ + + source: str + target: str diff --git a/package/kedro_viz/models/flowchart/model_utils.py b/package/kedro_viz/models/flowchart/model_utils.py new file mode 100644 index 0000000000..f12e94b669 --- /dev/null +++ b/package/kedro_viz/models/flowchart/model_utils.py @@ -0,0 +1,45 @@ +"""`kedro_viz.models.flowchart.model_utils` defines utils for Kedro entities in a viz graph.""" + +import logging +from enum import Enum +from types import FunctionType +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +def _parse_filepath(dataset_description: Dict[str, Any]) -> Optional[str]: + """ + Extract the file path from a dataset description dictionary. + """ + filepath = dataset_description.get("filepath") or dataset_description.get("path") + return str(filepath) if filepath else None + + +def _extract_wrapped_func(func: FunctionType) -> FunctionType: + """Extract a wrapped decorated function to inspect the source code if available. + Adapted from https://stackoverflow.com/a/43506509/1684058 + """ + if func.__closure__ is None: + return func + closure = (c.cell_contents for c in func.__closure__) + wrapped_func = next((c for c in closure if isinstance(c, FunctionType)), None) + # return the original function if it's not a decorated function + return func if wrapped_func is None else wrapped_func + + +# ============================================================================= +# Shared base classes and enumerations for model components +# ============================================================================= + + +class GraphNodeType(str, Enum): + """Represent all possible node types in the graph representation of a Kedro pipeline. + The type needs to inherit from str as well so FastAPI can serialise it. 
See:
+    https://fastapi.tiangolo.com/tutorial/path-params/#working-with-python-enumerations
+    """
+
+    TASK = "task"
+    DATA = "data"
+    PARAMETERS = "parameters"
+    MODULAR_PIPELINE = "modularPipeline"  # camelCase for frontend compatibility
diff --git a/package/kedro_viz/models/flowchart/named_entities.py b/package/kedro_viz/models/flowchart/named_entities.py
new file mode 100644
index 0000000000..65944c0764
--- /dev/null
+++ b/package/kedro_viz/models/flowchart/named_entities.py
@@ -0,0 +1,41 @@
+"""`kedro_viz.models.flowchart.named_entities` defines data models for representing named entities
+such as tags and registered pipelines within a Kedro visualization graph."""
+
+from typing import Optional
+
+from pydantic import BaseModel, Field, ValidationInfo, field_validator
+
+
+class NamedEntity(BaseModel):
+    """Represent a named entity (Tag/Registered Pipeline) in a Kedro project.
+    Args:
+        id (str): Id of the entity
+
+    Raises:
+        AssertionError: If id is not supplied during instantiation
+    """
+
+    id: str
+    name: Optional[str] = Field(
+        default=None,
+        validate_default=True,
+        description="The name of the entity",
+    )
+
+    @field_validator("name")
+    @classmethod
+    def set_name(cls, _, info: ValidationInfo):
+        """Ensures that the 'name' field is set to the value of 'id' if 'name' is not provided."""
+        assert "id" in info.data
+        return info.data["id"]
+
+
+class RegisteredPipeline(NamedEntity):
+    """Represent a registered pipeline in a Kedro project."""
+
+
+class Tag(NamedEntity):
+    """Represent a tag in a Kedro project."""
+
+    def __hash__(self) -> int:
+        return hash(self.id)
diff --git a/package/kedro_viz/models/flowchart/node_metadata.py b/package/kedro_viz/models/flowchart/node_metadata.py
new file mode 100644
index 0000000000..20940a9b3a
--- /dev/null
+++ b/package/kedro_viz/models/flowchart/node_metadata.py
@@ -0,0 +1,406 @@
+"""
+`kedro_viz.models.flowchart.node_metadata` defines data models to represent
+Kedro metadata in a visualization graph.
+"""
+
+import inspect
+import logging
+from abc import ABC
+from pathlib import Path
+from typing import ClassVar, Dict, List, Optional, Union, cast
+
+from kedro.pipeline.node import Node as KedroNode
+from pydantic import BaseModel, Field, field_validator, model_validator
+
+try:
+    # kedro 0.18.12 onwards
+    from kedro.io.core import AbstractDataset
+except ImportError:  # pragma: no cover
+    # older versions
+    from kedro.io.core import AbstractDataSet as AbstractDataset  # type: ignore
+
+from kedro_viz.models.utils import get_dataset_type
+
+from .model_utils import _extract_wrapped_func, _parse_filepath
+from .nodes import DataNode, ParametersNode, TaskNode, TranscodedDataNode
+
+logger = logging.getLogger(__name__)
+
+
+class GraphNodeMetadata(BaseModel, ABC):
+    """Represent a graph node's metadata."""
+
+
+class TaskNodeMetadata(GraphNodeMetadata):
+    """Represent the metadata of a TaskNode.
+
+    Args:
+        task_node (TaskNode): Task node to which this metadata belongs.
+
+    Raises:
+        AssertionError: If task_node is not supplied during instantiation.
+ """ + + task_node: TaskNode = Field(..., exclude=True) + + code: Optional[str] = Field( + default=None, + validate_default=True, + description="Source code of the node's function", + ) + + filepath: Optional[str] = Field( + default=None, + validate_default=True, + description="Path to the file where the node is defined", + ) + + parameters: Optional[Dict] = Field( + default=None, + validate_default=True, + description="The parameters of the node, if available", + ) + run_command: Optional[str] = Field( + default=None, + validate_default=True, + description="The command to run the pipeline to this node", + ) + + inputs: Optional[List[str]] = Field( + default=None, validate_default=True, description="The inputs to the TaskNode" + ) + outputs: Optional[List[str]] = Field( + default=None, validate_default=True, description="The outputs from the TaskNode" + ) + + @model_validator(mode="before") + @classmethod + def check_task_node_exists(cls, values): + assert "task_node" in values + cls.set_task_and_kedro_node(values["task_node"]) + return values + + @classmethod + def set_task_and_kedro_node(cls, task_node): + cls.task_node = task_node + cls.kedro_node = cast(KedroNode, task_node.kedro_obj) + + @field_validator("code") + @classmethod + def set_code(cls, code): + # this is required to handle partial, curry functions + if inspect.isfunction(cls.kedro_node.func): + code = inspect.getsource(_extract_wrapped_func(cls.kedro_node.func)) + return code + + return None + + @field_validator("filepath") + @classmethod + def set_filepath(cls, filepath): + # this is required to handle partial, curry functions + if inspect.isfunction(cls.kedro_node.func): + code_full_path = ( + Path(inspect.getfile(cls.kedro_node.func)).expanduser().resolve() + ) + + try: + filepath = code_full_path.relative_to(Path.cwd().parent) + except ValueError: # pragma: no cover + # if the filepath can't be resolved relative to the current directory, + # e.g. either during tests or during launching development server + # outside of a Kedro project, simply return the fullpath to the file. + filepath = code_full_path + + return str(filepath) + + return None + + @field_validator("parameters") + @classmethod + def set_parameters(cls, _): + return cls.task_node.parameters + + @field_validator("run_command") + @classmethod + def set_run_command(cls, _): + return f"kedro run --to-nodes='{cls.kedro_node.name}'" + + @field_validator("inputs") + @classmethod + def set_inputs(cls, _): + return cls.kedro_node.inputs + + @field_validator("outputs") + @classmethod + def set_outputs(cls, _): + return cls.kedro_node.outputs + + +class DataNodeMetadata(GraphNodeMetadata): + """Represent the metadata of a DataNode. + + Args: + data_node (DataNode): Data node to which this metadata belongs to. + + Attributes: + is_all_previews_enabled (bool): Class-level attribute to determine if + previews are enabled for all nodes. This can be configured via CLI + or UI to manage the preview settings. + + Raises: + AssertionError: If data_node is not supplied during instantiation. 
+ """ + + data_node: DataNode = Field(..., exclude=True) + + is_all_previews_enabled: ClassVar[bool] = True + + type: Optional[str] = Field( + default=None, validate_default=True, description="The type of the data node" + ) + + filepath: Optional[str] = Field( + default=None, + validate_default=True, + description="The path to the actual data file for the underlying dataset", + ) + + run_command: Optional[str] = Field( + default=None, + validate_default=True, + description="Command to run the pipeline to this node", + ) + + preview: Optional[Union[Dict, str]] = Field( + default=None, + validate_default=True, + description="Preview data for the underlying datanode", + ) + + preview_type: Optional[str] = Field( + default=None, + validate_default=True, + description="Type of preview for the dataset", + ) + + stats: Optional[Dict] = Field( + default=None, + validate_default=True, + description="The statistics for the data node.", + ) + + @model_validator(mode="before") + @classmethod + def check_data_node_exists(cls, values): + assert "data_node" in values + cls.set_data_node_and_dataset(values["data_node"]) + return values + + @classmethod + def set_is_all_previews_enabled(cls, value: bool): + cls.is_all_previews_enabled = value + + @classmethod + def set_data_node_and_dataset(cls, data_node): + cls.data_node = data_node + cls.dataset = cast(AbstractDataset, data_node.kedro_obj) + + # dataset.release clears the cache before loading to ensure that this issue + # does not arise: https://github.com/kedro-org/kedro-viz/pull/573. + cls.dataset.release() + + @field_validator("type") + @classmethod + def set_type(cls, _): + return cls.data_node.dataset_type + + @field_validator("filepath") + @classmethod + def set_filepath(cls, _): + dataset_description = cls.dataset._describe() + return _parse_filepath(dataset_description) + + @field_validator("run_command") + @classmethod + def set_run_command(cls, _): + if not cls.data_node.is_free_input: + return f"kedro run --to-outputs={cls.data_node.name}" + return None + + @field_validator("preview") + @classmethod + def set_preview(cls, _): + if ( + not cls.data_node.is_preview_enabled() + or not hasattr(cls.dataset, "preview") + or not cls.is_all_previews_enabled + ): + return None + + try: + preview_args = ( + cls.data_node.get_preview_args() if cls.data_node.viz_metadata else None + ) + if preview_args is None: + return cls.dataset.preview() + return cls.dataset.preview(**preview_args) + + except Exception as exc: # noqa: BLE001 + logger.warning( + "'%s' could not be previewed. Full exception: %s: %s", + cls.data_node.name, + type(exc).__name__, + exc, + ) + return None + + @field_validator("preview_type") + @classmethod + def set_preview_type(cls, _): + if ( + not cls.data_node.is_preview_enabled() + or not hasattr(cls.dataset, "preview") + or not cls.is_all_previews_enabled + ): + return None + + try: + preview_type_annotation = inspect.signature( + cls.dataset.preview + ).return_annotation + # Attempt to get the name attribute, if it exists. + # Otherwise, use str to handle the annotation directly. + preview_type_name = getattr( + preview_type_annotation, "__name__", str(preview_type_annotation) + ) + return preview_type_name + + except Exception as exc: # noqa: BLE001 # pragma: no cover + logger.warning( + "'%s' did not have preview type. 
Full exception: %s: %s", + cls.data_node.name, + type(exc).__name__, + exc, + ) + return None + + @field_validator("stats") + @classmethod + def set_stats(cls, _): + return cls.data_node.stats + + +class TranscodedDataNodeMetadata(GraphNodeMetadata): + """Represent the metadata of a TranscodedDataNode. + Args: + transcoded_data_node: The transcoded data node to which this metadata belongs. + + Raises: + AssertionError: If `transcoded_data_node` is not supplied during instantiation. + """ + + transcoded_data_node: TranscodedDataNode = Field(..., exclude=True) + + # Only available if the dataset has filepath set. + filepath: Optional[str] = Field( + default=None, + validate_default=True, + description="The path to the actual data file for the underlying dataset", + ) + + run_command: Optional[str] = Field( + default=None, + validate_default=True, + description="Command to run the pipeline to this node", + ) + original_type: Optional[str] = Field( + default=None, + validate_default=True, + description="The dataset type of the underlying transcoded data node original version", + ) + transcoded_types: Optional[List[str]] = Field( + default=None, + validate_default=True, + description="The list of all dataset types for the transcoded versions", + ) + + # Statistics for the underlying data node + stats: Optional[Dict] = Field( + default=None, + validate_default=True, + description="The statistics for the transcoded data node metadata.", + ) + + @model_validator(mode="before") + @classmethod + def check_transcoded_data_node_exists(cls, values): + assert "transcoded_data_node" in values + cls.transcoded_data_node = values["transcoded_data_node"] + return values + + @field_validator("filepath") + @classmethod + def set_filepath(cls, _): + dataset_description = cls.transcoded_data_node.original_version._describe() + return _parse_filepath(dataset_description) + + @field_validator("run_command") + @classmethod + def set_run_command(cls, _): + if not cls.transcoded_data_node.is_free_input: + return f"kedro run --to-outputs={cls.transcoded_data_node.original_name}" + return None + + @field_validator("original_type") + @classmethod + def set_original_type(cls, _): + return get_dataset_type(cls.transcoded_data_node.original_version) + + @field_validator("transcoded_types") + @classmethod + def set_transcoded_types(cls, _): + return [ + get_dataset_type(transcoded_version) + for transcoded_version in cls.transcoded_data_node.transcoded_versions + ] + + @field_validator("stats") + @classmethod + def set_stats(cls, _): + return cls.transcoded_data_node.stats + + +class ParametersNodeMetadata(GraphNodeMetadata): + """Represent the metadata of a ParametersNode. + + Args: + parameters_node (ParametersNode): The underlying parameters node + for the parameters metadata node. + + Raises: + AssertionError: If parameters_node is not supplied during instantiation. 
+ """ + + parameters_node: ParametersNode = Field(..., exclude=True) + parameters: Optional[Dict] = Field( + default=None, + validate_default=True, + description="The parameters dictionary for the parameters metadata node", + ) + + @model_validator(mode="before") + @classmethod + def check_parameters_node_exists(cls, values): + assert "parameters_node" in values + cls.parameters_node = values["parameters_node"] + return values + + @field_validator("parameters") + @classmethod + def set_parameters(cls, _): + if cls.parameters_node.is_single_parameter(): + return { + cls.parameters_node.parameter_name: cls.parameters_node.parameter_value + } + return cls.parameters_node.parameter_value diff --git a/package/kedro_viz/models/flowchart.py b/package/kedro_viz/models/flowchart/nodes.py similarity index 53% rename from package/kedro_viz/models/flowchart.py rename to package/kedro_viz/models/flowchart/nodes.py index 299dbc120e..0289fe1e1e 100644 --- a/package/kedro_viz/models/flowchart.py +++ b/package/kedro_viz/models/flowchart/nodes.py @@ -1,12 +1,8 @@ -"""`kedro_viz.models.flowchart` defines data models to represent Kedro entities in a viz graph.""" +"""`kedro_viz.models.flowchart.nodes` defines models to represent Kedro nodes in a viz graph.""" -import abc -import inspect import logging -from enum import Enum -from pathlib import Path -from types import FunctionType -from typing import Any, ClassVar, Dict, List, Optional, Set, Union, cast +from abc import ABC +from typing import Any, Dict, Optional, Set, Union, cast from fastapi.encoders import jsonable_encoder from kedro.pipeline.node import Node as KedroNode @@ -19,9 +15,6 @@ model_validator, ) -from kedro_viz.models.utils import get_dataset_type -from kedro_viz.utils import TRANSCODING_SEPARATOR, _strip_transcoding - try: # kedro 0.18.11 onwards from kedro.io.core import DatasetError @@ -35,75 +28,15 @@ # older versions from kedro.io.core import AbstractDataSet as AbstractDataset # type: ignore -logger = logging.getLogger(__name__) - - -def _parse_filepath(dataset_description: Dict[str, Any]) -> Optional[str]: - filepath = dataset_description.get("filepath") or dataset_description.get("path") - return str(filepath) if filepath else None - - -class NamedEntity(BaseModel): - """Represent a named entity (Tag/Registered Pipeline) in a Kedro project - Args: - id (str): Id of the registered pipeline - - Raises: - AssertionError: If id is not supplied during instantiation - """ - - id: str - name: Optional[str] = Field( - default=None, - validate_default=True, - description="The name of the registered pipeline", - ) - - @field_validator("name") - @classmethod - def set_name(cls, _, info: ValidationInfo): - assert "id" in info.data - return info.data["id"] - - -class RegisteredPipeline(NamedEntity): - """Represent a registered pipeline in a Kedro project""" - - -class GraphNodeType(str, Enum): - """Represent all possible node types in the graph representation of a Kedro pipeline. - The type needs to inherit from str as well so FastAPI can serialise it. See: - https://fastapi.tiangolo.com/tutorial/path-params/#working-with-python-enumerations - """ - - TASK = "task" - DATA = "data" - PARAMETERS = "parameters" - MODULAR_PIPELINE = ( - "modularPipeline" # camelCase so it can be referred directly to in the frontend - ) - - -class ModularPipelineChild(BaseModel, frozen=True): - """Represent a child of a modular pipeline. 
- - Args: - id (str): Id of the modular pipeline child - type (GraphNodeType): Type of modular pipeline child - """ - - id: str - type: GraphNodeType - +from kedro_viz.models.utils import get_dataset_type +from kedro_viz.utils import TRANSCODING_SEPARATOR, _strip_transcoding -class Tag(NamedEntity): - """Represent a tag in a Kedro project""" +from .model_utils import GraphNodeType - def __hash__(self) -> int: - return hash(self.id) +logger = logging.getLogger(__name__) -class GraphNode(BaseModel, abc.ABC): +class GraphNode(BaseModel, ABC): """Represent a node in the graph representation of a Kedro pipeline. All node models except the metadata node models should inherit from this class @@ -281,8 +214,16 @@ def has_metadata(self) -> bool: return self.kedro_obj is not None -class GraphNodeMetadata(BaseModel, abc.ABC): - """Represent a graph node's metadata""" +class ModularPipelineChild(BaseModel, frozen=True): + """Represent a child of a modular pipeline. + + Args: + id (str): Id of the modular pipeline child + type (GraphNodeType): Type of modular pipeline child + """ + + id: str + type: GraphNodeType class TaskNode(GraphNode): @@ -317,154 +258,6 @@ def set_namespace(cls, _, info: ValidationInfo): return info.data["kedro_obj"].namespace -def _extract_wrapped_func(func: FunctionType) -> FunctionType: - """Extract a wrapped decorated function to inspect the source code if available. - Adapted from https://stackoverflow.com/a/43506509/1684058 - """ - if func.__closure__ is None: - return func - closure = (c.cell_contents for c in func.__closure__) - wrapped_func = next((c for c in closure if isinstance(c, FunctionType)), None) - # return the original function if it's not a decorated function - return func if wrapped_func is None else wrapped_func - - -class ModularPipelineNode(GraphNode): - """Represent a modular pipeline node in the graph""" - - # A modular pipeline doesn't belong to any other modular pipeline, - # in the same sense as other types of GraphNode do. - # Therefore it's default to None. - # The parent-child relationship between modular pipeline themselves is modelled explicitly. - modular_pipelines: Optional[Set[str]] = None - - # Model the modular pipelines tree using a child-references representation of a tree. - # See: https://docs.mongodb.com/manual/tutorial/model-tree-structures-with-child-references/ - # for more details. - # For example, if a node namespace is "uk.data_science", - # the "uk" modular pipeline node's children are ["uk.data_science"] - children: Set[ModularPipelineChild] = Field( - set(), description="The children for the modular pipeline node" - ) - - inputs: Set[str] = Field( - set(), description="The input datasets to the modular pipeline node" - ) - - outputs: Set[str] = Field( - set(), description="The output datasets from the modular pipeline node" - ) - - # The type for Modular Pipeline Node - type: str = GraphNodeType.MODULAR_PIPELINE.value - - -class TaskNodeMetadata(GraphNodeMetadata): - """Represent the metadata of a TaskNode - - Args: - task_node (TaskNode): Task node to which this metadata belongs to. 
- - Raises: - AssertionError: If task_node is not supplied during instantiation - """ - - task_node: TaskNode = Field(..., exclude=True) - - code: Optional[str] = Field( - default=None, - validate_default=True, - description="Source code of the node's function", - ) - - filepath: Optional[str] = Field( - default=None, - validate_default=True, - description="Path to the file where the node is defined", - ) - - parameters: Optional[Dict] = Field( - default=None, - validate_default=True, - description="The parameters of the node, if available", - ) - run_command: Optional[str] = Field( - default=None, - validate_default=True, - description="The command to run the pipeline to this node", - ) - - inputs: Optional[List[str]] = Field( - default=None, validate_default=True, description="The inputs to the TaskNode" - ) - outputs: Optional[List[str]] = Field( - default=None, validate_default=True, description="The outputs from the TaskNode" - ) - - @model_validator(mode="before") - @classmethod - def check_task_node_exists(cls, values): - assert "task_node" in values - cls.set_task_and_kedro_node(values["task_node"]) - return values - - @classmethod - def set_task_and_kedro_node(cls, task_node): - cls.task_node = task_node - cls.kedro_node = cast(KedroNode, task_node.kedro_obj) - - @field_validator("code") - @classmethod - def set_code(cls, code): - # this is required to handle partial, curry functions - if inspect.isfunction(cls.kedro_node.func): - code = inspect.getsource(_extract_wrapped_func(cls.kedro_node.func)) - return code - - return None - - @field_validator("filepath") - @classmethod - def set_filepath(cls, filepath): - # this is required to handle partial, curry functions - if inspect.isfunction(cls.kedro_node.func): - code_full_path = ( - Path(inspect.getfile(cls.kedro_node.func)).expanduser().resolve() - ) - - try: - filepath = code_full_path.relative_to(Path.cwd().parent) - except ValueError: # pragma: no cover - # if the filepath can't be resolved relative to the current directory, - # e.g. either during tests or during launching development server - # outside of a Kedro project, simply return the fullpath to the file. - filepath = code_full_path - - return str(filepath) - - return None - - @field_validator("parameters") - @classmethod - def set_parameters(cls, _): - return cls.task_node.parameters - - @field_validator("run_command") - @classmethod - def set_run_command(cls, _): - return f"kedro run --to-nodes='{cls.kedro_node.name}'" - - @field_validator("inputs") - @classmethod - def set_inputs(cls, _): - return cls.kedro_node.inputs - - @field_validator("outputs") - @classmethod - def set_outputs(cls, _): - return cls.kedro_node.outputs - - class DataNode(GraphNode): """Represent a graph node of type data @@ -580,241 +373,6 @@ def has_metadata(self) -> bool: return True -class DataNodeMetadata(GraphNodeMetadata): - """Represent the metadata of a DataNode - - Args: - data_node (DataNode): Data node to which this metadata belongs to. - - Attributes: - is_all_previews_enabled (bool): Class-level attribute to determine if - previews are enabled for all nodes. This can be configured via CLI - or UI to manage the preview settings. 
- - Raises: - AssertionError: If data_node is not supplied during instantiation - """ - - data_node: DataNode = Field(..., exclude=True) - - is_all_previews_enabled: ClassVar[bool] = True - - type: Optional[str] = Field( - default=None, validate_default=True, description="The type of the data node" - ) - - filepath: Optional[str] = Field( - default=None, - validate_default=True, - description="The path to the actual data file for the underlying dataset", - ) - - run_command: Optional[str] = Field( - default=None, - validate_default=True, - description="Command to run the pipeline to this node", - ) - - preview: Optional[Union[Dict, str]] = Field( - default=None, - validate_default=True, - description="Preview data for the underlying datanode", - ) - - preview_type: Optional[str] = Field( - default=None, - validate_default=True, - description="Type of preview for the dataset", - ) - - stats: Optional[Dict] = Field( - default=None, - validate_default=True, - description="The statistics for the data node.", - ) - - @model_validator(mode="before") - @classmethod - def check_data_node_exists(cls, values): - assert "data_node" in values - cls.set_data_node_and_dataset(values["data_node"]) - return values - - @classmethod - def set_is_all_previews_enabled(cls, value: bool): - cls.is_all_previews_enabled = value - - @classmethod - def set_data_node_and_dataset(cls, data_node): - cls.data_node = data_node - cls.dataset = cast(AbstractDataset, data_node.kedro_obj) - - # dataset.release clears the cache before loading to ensure that this issue - # does not arise: https://github.com/kedro-org/kedro-viz/pull/573. - cls.dataset.release() - - @field_validator("type") - @classmethod - def set_type(cls, _): - return cls.data_node.dataset_type - - @field_validator("filepath") - @classmethod - def set_filepath(cls, _): - dataset_description = cls.dataset._describe() - return _parse_filepath(dataset_description) - - @field_validator("run_command") - @classmethod - def set_run_command(cls, _): - if not cls.data_node.is_free_input: - return f"kedro run --to-outputs={cls.data_node.name}" - return None - - @field_validator("preview") - @classmethod - def set_preview(cls, _): - if ( - not cls.data_node.is_preview_enabled() - or not hasattr(cls.dataset, "preview") - or not cls.is_all_previews_enabled - ): - return None - - try: - preview_args = ( - cls.data_node.get_preview_args() if cls.data_node.viz_metadata else None - ) - if preview_args is None: - return cls.dataset.preview() - return cls.dataset.preview(**preview_args) - - except Exception as exc: # noqa: BLE001 - logger.warning( - "'%s' could not be previewed. Full exception: %s: %s", - cls.data_node.name, - type(exc).__name__, - exc, - ) - return None - - @field_validator("preview_type") - @classmethod - def set_preview_type(cls, _): - if ( - not cls.data_node.is_preview_enabled() - or not hasattr(cls.dataset, "preview") - or not cls.is_all_previews_enabled - ): - return None - - try: - preview_type_annotation = inspect.signature( - cls.dataset.preview - ).return_annotation - # Attempt to get the name attribute, if it exists. - # Otherwise, use str to handle the annotation directly. - preview_type_name = getattr( - preview_type_annotation, "__name__", str(preview_type_annotation) - ) - return preview_type_name - - except Exception as exc: # noqa: BLE001 # pragma: no cover - logger.warning( - "'%s' did not have preview type. 
Full exception: %s: %s", - cls.data_node.name, - type(exc).__name__, - exc, - ) - return None - - @field_validator("stats") - @classmethod - def set_stats(cls, _): - return cls.data_node.stats - - -class TranscodedDataNodeMetadata(GraphNodeMetadata): - """Represent the metadata of a TranscodedDataNode - Args: - transcoded_data_node (TranscodedDataNode): The underlying transcoded - data node to which this metadata belongs to. - - Raises: - AssertionError: If transcoded_data_node is not supplied during instantiation - """ - - transcoded_data_node: TranscodedDataNode = Field(..., exclude=True) - - # Only available if the dataset has filepath set. - filepath: Optional[str] = Field( - default=None, - validate_default=True, - description="The path to the actual data file for the underlying dataset", - ) - - run_command: Optional[str] = Field( - default=None, - validate_default=True, - description="Command to run the pipeline to this node", - ) - original_type: Optional[str] = Field( - default=None, - validate_default=True, - description="The dataset type of the underlying transcoded data node original version", - ) - transcoded_types: Optional[List[str]] = Field( - default=None, - validate_default=True, - description="The list of all dataset types for the transcoded versions", - ) - - # Statistics for the underlying data node - stats: Optional[Dict] = Field( - default=None, - validate_default=True, - description="The statistics for the transcoded data node metadata.", - ) - - @model_validator(mode="before") - @classmethod - def check_transcoded_data_node_exists(cls, values): - assert "transcoded_data_node" in values - cls.transcoded_data_node = values["transcoded_data_node"] - return values - - @field_validator("filepath") - @classmethod - def set_filepath(cls, _): - dataset_description = cls.transcoded_data_node.original_version._describe() - return _parse_filepath(dataset_description) - - @field_validator("run_command") - @classmethod - def set_run_command(cls, _): - if not cls.transcoded_data_node.is_free_input: - return f"kedro run --to-outputs={cls.transcoded_data_node.original_name}" - return None - - @field_validator("original_type") - @classmethod - def set_original_type(cls, _): - return get_dataset_type(cls.transcoded_data_node.original_version) - - @field_validator("transcoded_types") - @classmethod - def set_transcoded_types(cls, _): - return [ - get_dataset_type(transcoded_version) - for transcoded_version in cls.transcoded_data_node.transcoded_versions - ] - - @field_validator("stats") - @classmethod - def set_stats(cls, _): - return cls.transcoded_data_node.stats - - class ParametersNode(GraphNode): """Represent a graph node of type parameters Args: @@ -882,48 +440,31 @@ def parameter_value(self) -> Any: return None -class ParametersNodeMetadata(GraphNodeMetadata): - """Represent the metadata of a ParametersNode - - Args: - parameters_node (ParametersNode): The underlying parameters node - for the parameters metadata node. +class ModularPipelineNode(GraphNode): + """Represent a modular pipeline node in the graph""" - Raises: - AssertionError: If parameters_node is not supplied during instantiation - """ + # A modular pipeline doesn't belong to any other modular pipeline, + # in the same sense as other types of GraphNode do. + # Therefore, it's default to None. + # The parent-child relationship between modular pipeline themselves is modelled explicitly. 
+ modular_pipelines: Optional[Set[str]] = None - parameters_node: ParametersNode = Field(..., exclude=True) - parameters: Optional[Dict] = Field( - default=None, - validate_default=True, - description="The parameters dictionary for the parameters metadata node", + # Model the modular pipelines tree using a child-references representation of a tree. + # See: https://docs.mongodb.com/manual/tutorial/model-tree-structures-with-child-references/ + # for more details. + # For example, if a node namespace is "uk.data_science", + # the "uk" modular pipeline node's children are ["uk.data_science"] + children: Set[ModularPipelineChild] = Field( + set(), description="The children for the modular pipeline node" ) - @model_validator(mode="before") - @classmethod - def check_parameters_node_exists(cls, values): - assert "parameters_node" in values - cls.parameters_node = values["parameters_node"] - return values - - @field_validator("parameters") - @classmethod - def set_parameters(cls, _): - if cls.parameters_node.is_single_parameter(): - return { - cls.parameters_node.parameter_name: cls.parameters_node.parameter_value - } - return cls.parameters_node.parameter_value - - -class GraphEdge(BaseModel, frozen=True): - """Represent an edge in the graph + inputs: Set[str] = Field( + set(), description="The input datasets to the modular pipeline node" + ) - Args: - source (str): The id of the source node. - target (str): The id of the target node. - """ + outputs: Set[str] = Field( + set(), description="The output datasets from the modular pipeline node" + ) - source: str - target: str + # The type for Modular Pipeline Node + type: str = GraphNodeType.MODULAR_PIPELINE.value diff --git a/package/kedro_viz/services/layers.py b/package/kedro_viz/services/layers.py index f8840534fc..7cba369aa1 100644 --- a/package/kedro_viz/services/layers.py +++ b/package/kedro_viz/services/layers.py @@ -5,7 +5,7 @@ from graphlib import CycleError, TopologicalSorter from typing import Dict, List, Set -from kedro_viz.models.flowchart import GraphNode +from kedro_viz.models.flowchart.nodes import GraphNode logger = logging.getLogger(__name__) diff --git a/package/tests/conftest.py b/package/tests/conftest.py index 7c66051328..c6b802974a 100644 --- a/package/tests/conftest.py +++ b/package/tests/conftest.py @@ -21,7 +21,8 @@ ) from kedro_viz.integrations.kedro.hooks import DatasetStatsHook from kedro_viz.integrations.kedro.sqlite_store import SQLiteStore -from kedro_viz.models.flowchart import DataNodeMetadata, GraphNode +from kedro_viz.models.flowchart.node_metadata import DataNodeMetadata +from kedro_viz.models.flowchart.nodes import GraphNode from kedro_viz.server import populate_data diff --git a/package/tests/test_api/test_rest/test_responses.py b/package/tests/test_api/test_rest/test_responses.py index 6f4581d3a3..8dbf549416 100644 --- a/package/tests/test_api/test_rest/test_responses.py +++ b/package/tests/test_api/test_rest/test_responses.py @@ -19,7 +19,7 @@ save_api_responses_to_fs, write_api_response_to_fs, ) -from kedro_viz.models.flowchart import TaskNode +from kedro_viz.models.flowchart.nodes import TaskNode from kedro_viz.models.metadata import Metadata diff --git a/package/tests/test_data_access/test_managers.py b/package/tests/test_data_access/test_managers.py index 66bd08f1e9..abb8df9be5 100644 --- a/package/tests/test_data_access/test_managers.py +++ b/package/tests/test_data_access/test_managers.py @@ -15,11 +15,11 @@ ModularPipelinesRepository, ) from kedro_viz.integrations.utils import UnavailableDataset 
-from kedro_viz.models.flowchart import ( +from kedro_viz.models.flowchart.edge import GraphEdge +from kedro_viz.models.flowchart.named_entities import Tag +from kedro_viz.models.flowchart.nodes import ( DataNode, - GraphEdge, ParametersNode, - Tag, TaskNode, TranscodedDataNode, ) diff --git a/package/tests/test_data_access/test_repositories/test_graph.py b/package/tests/test_data_access/test_repositories/test_graph.py index c45232ebd1..51f8684368 100644 --- a/package/tests/test_data_access/test_repositories/test_graph.py +++ b/package/tests/test_data_access/test_repositories/test_graph.py @@ -4,7 +4,8 @@ GraphEdgesRepository, GraphNodesRepository, ) -from kedro_viz.models.flowchart import GraphEdge, GraphNode +from kedro_viz.models.flowchart.edge import GraphEdge +from kedro_viz.models.flowchart.nodes import GraphNode class TestGraphNodeRepository: diff --git a/package/tests/test_data_access/test_repositories/test_modular_pipelines.py b/package/tests/test_data_access/test_repositories/test_modular_pipelines.py index 5b5a5e783b..ef6058ca8b 100644 --- a/package/tests/test_data_access/test_repositories/test_modular_pipelines.py +++ b/package/tests/test_data_access/test_repositories/test_modular_pipelines.py @@ -6,11 +6,8 @@ from kedro_viz.constants import ROOT_MODULAR_PIPELINE_ID from kedro_viz.data_access.repositories import ModularPipelinesRepository -from kedro_viz.models.flowchart import ( - GraphNodeType, - ModularPipelineChild, - ModularPipelineNode, -) +from kedro_viz.models.flowchart.model_utils import GraphNodeType +from kedro_viz.models.flowchart.nodes import ModularPipelineChild, ModularPipelineNode @pytest.fixture diff --git a/package/tests/test_models/test_flowchart/__init__.py b/package/tests/test_models/test_flowchart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/package/tests/test_models/test_flowchart.py b/package/tests/test_models/test_flowchart/test_node_metadata.py similarity index 55% rename from package/tests/test_models/test_flowchart.py rename to package/tests/test_models/test_flowchart/test_node_metadata.py index 01238f286d..f8ebd4f8ec 100644 --- a/package/tests/test_models/test_flowchart.py +++ b/package/tests/test_models/test_flowchart/test_node_metadata.py @@ -1,7 +1,6 @@ from functools import partial from pathlib import Path from textwrap import dedent -from unittest.mock import call, patch import pytest from kedro.io import MemoryDataset @@ -9,18 +8,13 @@ from kedro_datasets.pandas import CSVDataset, ParquetDataset from kedro_datasets.partitions.partitioned_dataset import PartitionedDataset -from kedro_viz.models.flowchart import ( - DataNode, +from kedro_viz.models.flowchart.node_metadata import ( DataNodeMetadata, - GraphNode, - ParametersNode, ParametersNodeMetadata, - RegisteredPipeline, - TaskNode, TaskNodeMetadata, - TranscodedDataNode, TranscodedDataNodeMetadata, ) +from kedro_viz.models.flowchart.nodes import GraphNode def identity(x): @@ -56,264 +50,6 @@ def full_func(a, b, c, x): partial_func = partial(full_func, 3, 1, 4) -class TestGraphNodeCreation: - @pytest.mark.parametrize( - "namespace,expected_modular_pipelines", - [ - (None, set()), - ( - "uk.data_science.model_training", - set( - [ - "uk", - "uk.data_science", - "uk.data_science.model_training", - ] - ), - ), - ], - ) - def test_create_task_node(self, namespace, expected_modular_pipelines): - kedro_node = node( - identity, - inputs="x", - outputs="y", - name="identity_node", - tags={"tag"}, - namespace=namespace, - ) - task_node = GraphNode.create_task_node( - 
kedro_node, "identity_node", expected_modular_pipelines - ) - assert isinstance(task_node, TaskNode) - assert task_node.kedro_obj is kedro_node - assert task_node.name == "identity_node" - assert task_node.tags == {"tag"} - assert task_node.pipelines == set() - assert task_node.modular_pipelines == expected_modular_pipelines - assert task_node.namespace == namespace - - @pytest.mark.parametrize( - "dataset_name, expected_modular_pipelines", - [ - ("dataset", set()), - ( - "uk.data_science.model_training.dataset", - set( - [ - "uk", - "uk.data_science", - "uk.data_science.model_training", - ] - ), - ), - ], - ) - def test_create_data_node(self, dataset_name, expected_modular_pipelines): - kedro_dataset = CSVDataset(filepath="foo.csv") - data_node = GraphNode.create_data_node( - dataset_id=dataset_name, - dataset_name=dataset_name, - layer="raw", - tags=set(), - dataset=kedro_dataset, - stats={"rows": 10, "columns": 5, "file_size": 1024}, - modular_pipelines=set(expected_modular_pipelines), - ) - assert isinstance(data_node, DataNode) - assert data_node.kedro_obj is kedro_dataset - assert data_node.id == dataset_name - assert data_node.name == dataset_name - assert data_node.layer == "raw" - assert data_node.tags == set() - assert data_node.pipelines == set() - assert data_node.modular_pipelines == expected_modular_pipelines - assert data_node.stats["rows"] == 10 - assert data_node.stats["columns"] == 5 - assert data_node.stats["file_size"] == 1024 - - @pytest.mark.parametrize( - "transcoded_dataset_name, original_name", - [ - ("dataset@pandas2", "dataset"), - ( - "uk.data_science.model_training.dataset@pandas2", - "uk.data_science.model_training.dataset", - ), - ], - ) - def test_create_transcoded_data_node(self, transcoded_dataset_name, original_name): - kedro_dataset = CSVDataset(filepath="foo.csv") - data_node = GraphNode.create_data_node( - dataset_id=original_name, - dataset_name=transcoded_dataset_name, - layer="raw", - tags=set(), - dataset=kedro_dataset, - stats={"rows": 10, "columns": 2, "file_size": 1048}, - modular_pipelines=set(), - ) - assert isinstance(data_node, TranscodedDataNode) - assert data_node.id == original_name - assert data_node.name == original_name - assert data_node.layer == "raw" - assert data_node.tags == set() - assert data_node.pipelines == set() - assert data_node.stats["rows"] == 10 - assert data_node.stats["columns"] == 2 - assert data_node.stats["file_size"] == 1048 - - def test_create_parameters_all_parameters(self): - parameters_dataset = MemoryDataset( - data={"test_split_ratio": 0.3, "num_epochs": 1000} - ) - parameters_node = GraphNode.create_parameters_node( - dataset_id="parameters", - dataset_name="parameters", - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=set(), - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.kedro_obj is parameters_dataset - assert parameters_node.id == "parameters" - assert parameters_node.is_all_parameters() - assert not parameters_node.is_single_parameter() - assert parameters_node.parameter_value == { - "test_split_ratio": 0.3, - "num_epochs": 1000, - } - assert not parameters_node.modular_pipelines - - @pytest.mark.parametrize( - "dataset_name,expected_modular_pipelines", - [ - ("params:test_split_ratio", set()), - ( - "params:uk.data_science.model_training.test_split_ratio", - set(["uk", "uk.data_science", "uk.data_science.model_training"]), - ), - ], - ) - def test_create_parameters_node_single_parameter( - self, dataset_name, expected_modular_pipelines 
- ): - parameters_dataset = MemoryDataset(data=0.3) - parameters_node = GraphNode.create_parameters_node( - dataset_id=dataset_name, - dataset_name=dataset_name, - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=expected_modular_pipelines, - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.kedro_obj is parameters_dataset - assert not parameters_node.is_all_parameters() - assert parameters_node.is_single_parameter() - assert parameters_node.parameter_value == 0.3 - assert parameters_node.modular_pipelines == expected_modular_pipelines - - def test_create_single_parameter_with_complex_type(self): - parameters_dataset = MemoryDataset(data=object()) - parameters_node = GraphNode.create_parameters_node( - dataset_id="params:test_split_ratio", - dataset_name="params:test_split_ratio", - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=set(), - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.kedro_obj is parameters_dataset - assert not parameters_node.is_all_parameters() - assert parameters_node.is_single_parameter() - assert isinstance(parameters_node.parameter_value, str) - - def test_create_all_parameters_with_complex_type(self): - mock_object = object() - parameters_dataset = MemoryDataset( - data={ - "test_split_ratio": 0.3, - "num_epochs": 1000, - "complex_param": mock_object, - } - ) - parameters_node = GraphNode.create_parameters_node( - dataset_id="parameters", - dataset_name="parameters", - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=set(), - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.kedro_obj is parameters_dataset - assert parameters_node.id == "parameters" - assert parameters_node.is_all_parameters() - assert not parameters_node.is_single_parameter() - assert isinstance(parameters_node.parameter_value, str) - - def test_create_non_existing_parameter_node(self): - """Test the case where ``parameters`` is equal to None""" - parameters_node = GraphNode.create_parameters_node( - dataset_id="non_existing", - dataset_name="non_existing", - layer=None, - tags=set(), - parameters=None, - modular_pipelines=set(), - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.parameter_value is None - - @patch("logging.Logger.warning") - def test_create_non_existing_parameter_node_empty_dataset(self, patched_warning): - """Test the case where ``parameters`` is equal to a MemoryDataset with no data""" - parameters_dataset = MemoryDataset() - parameters_node = GraphNode.create_parameters_node( - dataset_id="non_existing", - dataset_name="non_existing", - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=set(), - ) - assert parameters_node.parameter_value is None - patched_warning.assert_has_calls( - [call("Cannot find parameter `%s` in the catalog.", "non_existing")] - ) - - -class TestGraphNodePipelines: - def test_registered_pipeline_name(self): - pipeline = RegisteredPipeline(id="__default__") - assert pipeline.name == "__default__" - - def test_modular_pipeline_name(self): - pipeline = GraphNode.create_modular_pipeline_node("data_engineering") - assert pipeline.name == "data_engineering" - - def test_add_node_to_pipeline(self): - default_pipeline = RegisteredPipeline(id="__default__") - another_pipeline = RegisteredPipeline(id="testing") - kedro_dataset = CSVDataset(filepath="foo.csv") - data_node = GraphNode.create_data_node( - 
dataset_id="dataset@transcoded", - dataset_name="dataset@transcoded", - layer="raw", - tags=set(), - dataset=kedro_dataset, - stats={"rows": 10, "columns": 2, "file_size": 1048}, - modular_pipelines=set(), - ) - assert data_node.pipelines == set() - data_node.add_pipeline(default_pipeline.id) - assert data_node.belongs_to_pipeline(default_pipeline.id) - assert not data_node.belongs_to_pipeline(another_pipeline.id) - - class TestGraphNodeMetadata: @pytest.mark.parametrize( "dataset,has_metadata", [(MemoryDataset(data=1), True), (None, False)] diff --git a/package/tests/test_models/test_flowchart/test_nodes.py b/package/tests/test_models/test_flowchart/test_nodes.py new file mode 100644 index 0000000000..2d7a59d338 --- /dev/null +++ b/package/tests/test_models/test_flowchart/test_nodes.py @@ -0,0 +1,248 @@ +from unittest.mock import call, patch + +import pytest +from kedro.io import MemoryDataset +from kedro.pipeline.node import node +from kedro_datasets.pandas import CSVDataset + +from kedro_viz.models.flowchart.nodes import ( + DataNode, + GraphNode, + ParametersNode, + TaskNode, + TranscodedDataNode, +) + + +def identity(x): + return x + + +class TestGraphNodeCreation: + @pytest.mark.parametrize( + "namespace,expected_modular_pipelines", + [ + (None, set()), + ( + "uk.data_science.model_training", + set( + [ + "uk", + "uk.data_science", + "uk.data_science.model_training", + ] + ), + ), + ], + ) + def test_create_task_node(self, namespace, expected_modular_pipelines): + kedro_node = node( + identity, + inputs="x", + outputs="y", + name="identity_node", + tags={"tag"}, + namespace=namespace, + ) + task_node = GraphNode.create_task_node( + kedro_node, "identity_node", expected_modular_pipelines + ) + assert isinstance(task_node, TaskNode) + assert task_node.kedro_obj is kedro_node + assert task_node.name == "identity_node" + assert task_node.tags == {"tag"} + assert task_node.pipelines == set() + assert task_node.modular_pipelines == expected_modular_pipelines + assert task_node.namespace == namespace + + @pytest.mark.parametrize( + "dataset_name, expected_modular_pipelines", + [ + ("dataset", set()), + ( + "uk.data_science.model_training.dataset", + set( + [ + "uk", + "uk.data_science", + "uk.data_science.model_training", + ] + ), + ), + ], + ) + def test_create_data_node(self, dataset_name, expected_modular_pipelines): + kedro_dataset = CSVDataset(filepath="foo.csv") + data_node = GraphNode.create_data_node( + dataset_id=dataset_name, + dataset_name=dataset_name, + layer="raw", + tags=set(), + dataset=kedro_dataset, + stats={"rows": 10, "columns": 5, "file_size": 1024}, + modular_pipelines=set(expected_modular_pipelines), + ) + assert isinstance(data_node, DataNode) + assert data_node.kedro_obj is kedro_dataset + assert data_node.id == dataset_name + assert data_node.name == dataset_name + assert data_node.layer == "raw" + assert data_node.tags == set() + assert data_node.pipelines == set() + assert data_node.modular_pipelines == expected_modular_pipelines + assert data_node.stats["rows"] == 10 + assert data_node.stats["columns"] == 5 + assert data_node.stats["file_size"] == 1024 + + @pytest.mark.parametrize( + "transcoded_dataset_name, original_name", + [ + ("dataset@pandas2", "dataset"), + ( + "uk.data_science.model_training.dataset@pandas2", + "uk.data_science.model_training.dataset", + ), + ], + ) + def test_create_transcoded_data_node(self, transcoded_dataset_name, original_name): + kedro_dataset = CSVDataset(filepath="foo.csv") + data_node = GraphNode.create_data_node( + 
dataset_id=original_name, + dataset_name=transcoded_dataset_name, + layer="raw", + tags=set(), + dataset=kedro_dataset, + stats={"rows": 10, "columns": 2, "file_size": 1048}, + modular_pipelines=set(), + ) + assert isinstance(data_node, TranscodedDataNode) + assert data_node.id == original_name + assert data_node.name == original_name + assert data_node.layer == "raw" + assert data_node.tags == set() + assert data_node.pipelines == set() + assert data_node.stats["rows"] == 10 + assert data_node.stats["columns"] == 2 + assert data_node.stats["file_size"] == 1048 + + def test_create_parameters_all_parameters(self): + parameters_dataset = MemoryDataset( + data={"test_split_ratio": 0.3, "num_epochs": 1000} + ) + parameters_node = GraphNode.create_parameters_node( + dataset_id="parameters", + dataset_name="parameters", + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=set(), + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.kedro_obj is parameters_dataset + assert parameters_node.id == "parameters" + assert parameters_node.is_all_parameters() + assert not parameters_node.is_single_parameter() + assert parameters_node.parameter_value == { + "test_split_ratio": 0.3, + "num_epochs": 1000, + } + assert not parameters_node.modular_pipelines + + @pytest.mark.parametrize( + "dataset_name,expected_modular_pipelines", + [ + ("params:test_split_ratio", set()), + ( + "params:uk.data_science.model_training.test_split_ratio", + set(["uk", "uk.data_science", "uk.data_science.model_training"]), + ), + ], + ) + def test_create_parameters_node_single_parameter( + self, dataset_name, expected_modular_pipelines + ): + parameters_dataset = MemoryDataset(data=0.3) + parameters_node = GraphNode.create_parameters_node( + dataset_id=dataset_name, + dataset_name=dataset_name, + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=expected_modular_pipelines, + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.kedro_obj is parameters_dataset + assert not parameters_node.is_all_parameters() + assert parameters_node.is_single_parameter() + assert parameters_node.parameter_value == 0.3 + assert parameters_node.modular_pipelines == expected_modular_pipelines + + def test_create_single_parameter_with_complex_type(self): + parameters_dataset = MemoryDataset(data=object()) + parameters_node = GraphNode.create_parameters_node( + dataset_id="params:test_split_ratio", + dataset_name="params:test_split_ratio", + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=set(), + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.kedro_obj is parameters_dataset + assert not parameters_node.is_all_parameters() + assert parameters_node.is_single_parameter() + assert isinstance(parameters_node.parameter_value, str) + + def test_create_all_parameters_with_complex_type(self): + mock_object = object() + parameters_dataset = MemoryDataset( + data={ + "test_split_ratio": 0.3, + "num_epochs": 1000, + "complex_param": mock_object, + } + ) + parameters_node = GraphNode.create_parameters_node( + dataset_id="parameters", + dataset_name="parameters", + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=set(), + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.kedro_obj is parameters_dataset + assert parameters_node.id == "parameters" + assert parameters_node.is_all_parameters() + assert not 
parameters_node.is_single_parameter() + assert isinstance(parameters_node.parameter_value, str) + + def test_create_non_existing_parameter_node(self): + """Test the case where ``parameters`` is equal to None""" + parameters_node = GraphNode.create_parameters_node( + dataset_id="non_existing", + dataset_name="non_existing", + layer=None, + tags=set(), + parameters=None, + modular_pipelines=set(), + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.parameter_value is None + + @patch("logging.Logger.warning") + def test_create_non_existing_parameter_node_empty_dataset(self, patched_warning): + """Test the case where ``parameters`` is equal to a MemoryDataset with no data""" + parameters_dataset = MemoryDataset() + parameters_node = GraphNode.create_parameters_node( + dataset_id="non_existing", + dataset_name="non_existing", + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=set(), + ) + assert parameters_node.parameter_value is None + patched_warning.assert_has_calls( + [call("Cannot find parameter `%s` in the catalog.", "non_existing")] + ) diff --git a/package/tests/test_models/test_flowchart/test_pipeline.py b/package/tests/test_models/test_flowchart/test_pipeline.py new file mode 100644 index 0000000000..520aff01d9 --- /dev/null +++ b/package/tests/test_models/test_flowchart/test_pipeline.py @@ -0,0 +1,32 @@ +from kedro_datasets.pandas import CSVDataset + +from kedro_viz.models.flowchart.named_entities import RegisteredPipeline +from kedro_viz.models.flowchart.nodes import GraphNode + + +class TestGraphNodePipelines: + def test_registered_pipeline_name(self): + pipeline = RegisteredPipeline(id="__default__") + assert pipeline.name == "__default__" + + def test_modular_pipeline_name(self): + pipeline = GraphNode.create_modular_pipeline_node("data_engineering") + assert pipeline.name == "data_engineering" + + def test_add_node_to_pipeline(self): + default_pipeline = RegisteredPipeline(id="__default__") + another_pipeline = RegisteredPipeline(id="testing") + kedro_dataset = CSVDataset(filepath="foo.csv") + data_node = GraphNode.create_data_node( + dataset_id="dataset@transcoded", + dataset_name="dataset@transcoded", + layer="raw", + tags=set(), + dataset=kedro_dataset, + stats={"rows": 10, "columns": 2, "file_size": 1048}, + modular_pipelines=set(), + ) + assert data_node.pipelines == set() + data_node.add_pipeline(default_pipeline.id) + assert data_node.belongs_to_pipeline(default_pipeline.id) + assert not data_node.belongs_to_pipeline(another_pipeline.id) diff --git a/package/tests/test_services/test_layers.py b/package/tests/test_services/test_layers.py index 80d76fae5a..c949a9f98b 100644 --- a/package/tests/test_services/test_layers.py +++ b/package/tests/test_services/test_layers.py @@ -1,6 +1,6 @@ import pytest -from kedro_viz.models.flowchart import GraphNode +from kedro_viz.models.flowchart.nodes import GraphNode from kedro_viz.services.layers import sort_layers diff --git a/ruff.toml b/ruff.toml index 52a1d6c8f3..166d54a4a7 100644 --- a/ruff.toml +++ b/ruff.toml @@ -45,7 +45,8 @@ ignore = [ "package/features/steps/sh_run.py" = ["PLW1510"] # `subprocess.run` without explicit `check` argument "*/tests/*.py" = ["SLF", "D", "ARG"] "package/kedro_viz/models/experiment_tracking.py" = ["SLF"] -"package/kedro_viz/models/flowchart.py" = ["SLF"] +"package/kedro_viz/models/flowchart/nodes.py" = ["SLF"] +"package/kedro_viz/models/flowchart/node_metadata.py" = ["SLF"] "package/kedro_viz/integrations/kedro/hooks.py" = ["SLF", "BLE"] 
"package/kedro_viz/integrations/kedro/sqlite_store.py" = ["BLE"] "package/kedro_viz/integrations/kedro/data_loader.py" = ["SLF"]