From 122548fdae9210c61193ec1df27e1cb822156d98 Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Fri, 22 Mar 2024 12:01:24 +0000 Subject: [PATCH] Remove the copying hack and add proper params querying capabilities Signed-off-by: Ivan Danov --- kedro/framework/context/context.py | 32 ++---------------------------- kedro/io/data_catalog.py | 14 +++++++++++++ 2 files changed, 16 insertions(+), 30 deletions(-) diff --git a/kedro/framework/context/context.py b/kedro/framework/context/context.py index b82a979ffa..46f0a2c2da 100644 --- a/kedro/framework/context/context.py +++ b/kedro/framework/context/context.py @@ -235,7 +235,8 @@ def _get_catalog( save_version=save_version, ) - feed_dict = self._get_feed_dict() + params = self.params + feed_dict = {"parameters": params, "params": params} catalog.add_feed_dict(feed_dict) _validate_transcoded_datasets(catalog) self._hook_manager.hook.after_catalog_created( @@ -248,35 +249,6 @@ def _get_catalog( ) return catalog - def _get_feed_dict(self) -> dict[str, Any]: - """Get parameters and return the feed dictionary.""" - params = self.params - feed_dict = {"parameters": params} - - def _add_param_to_feed_dict(param_name: str, param_value: Any) -> None: - """This recursively adds parameter paths to the `feed_dict`, - whenever `param_value` is a dictionary itself, so that users can - specify specific nested parameters in their node inputs. - - Example: - - >>> param_name = "a" - >>> param_value = {"b": 1} - >>> _add_param_to_feed_dict(param_name, param_value) - >>> assert feed_dict["params:a"] == {"b": 1} - >>> assert feed_dict["params:a.b"] == 1 - """ - key = f"params:{param_name}" - feed_dict[key] = param_value - if isinstance(param_value, dict): - for key, val in param_value.items(): - _add_param_to_feed_dict(f"{param_name}.{key}", val) - - for param_name, param_value in params.items(): - _add_param_to_feed_dict(param_name, param_value) - - return feed_dict - def _get_config_credentials(self) -> dict[str, Any]: """Getter for credentials specified in credentials directory.""" try: diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index 411cf14e09..db5badc7f6 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -12,6 +12,7 @@ import re from typing import Any, Dict +from omegaconf import OmegaConf from parse import parse from kedro.io.core import ( @@ -416,6 +417,8 @@ def _get_dataset( def __contains__(self, dataset_name: str) -> bool: """Check if an item is in the catalog as a materialised dataset or pattern""" + if ":" in dataset_name: + dataset_name, _ = dataset_name.split(":", 1) matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name) if dataset_name in self._datasets or matched_pattern: return True @@ -477,6 +480,10 @@ def load(self, name: str, version: str | None = None) -> Any: >>> >>> df = io.load("cars") """ + query = None + if ":" in name: + name, query = name.split(":", 1) + load_version = Version(version, None) if version else None dataset = self._get_dataset(name, version=load_version) @@ -488,6 +495,13 @@ def load(self, name: str, version: str | None = None) -> Any: ) result = dataset.load() + if query and isinstance(result, dict): + result = OmegaConf.select(OmegaConf.create(result), query) + result = ( + OmegaConf.to_container(result) + if OmegaConf.is_config(result) + else result + ) return result