diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 35fe44f033..17f7ab2b19 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -2,12 +2,11 @@ from collections import defaultdict import dpath.util -import toml -import yaml from voluptuous import Any from dvc.dependency.local import LocalDependency from dvc.exceptions import DvcException +from dvc.utils.serialize import PARSERS, ParseError class MissingParamsError(DvcException): @@ -22,8 +21,6 @@ class ParamsDependency(LocalDependency): PARAM_PARAMS = "params" PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)} DEFAULT_PARAMS_FILE = "params.yaml" - PARAMS_FILE_LOADERS = defaultdict(lambda: yaml.safe_load) - PARAMS_FILE_LOADERS.update({".toml": toml.load}) def __init__(self, stage, path, params): info = {} @@ -88,12 +85,12 @@ def read_params(self): if not self.exists: return {} + suffix = self.path_info.suffix.lower() + parser = PARSERS[suffix] with self.repo.tree.open(self.path_info, "r") as fobj: try: - config = self.PARAMS_FILE_LOADERS[ - self.path_info.suffix.lower() - ](fobj) - except (yaml.YAMLError, toml.TomlDecodeError) as exc: + config = parser(fobj.read(), self.path_info) + except ParseError as exc: raise BadParamFileError( f"Unable to read parameters from '{self}'" ) from exc diff --git a/dvc/exceptions.py b/dvc/exceptions.py index d6586a35db..57446f2871 100644 --- a/dvc/exceptions.py +++ b/dvc/exceptions.py @@ -182,22 +182,6 @@ def __init__(self): ) -class YAMLFileCorruptedError(DvcException): - def __init__(self, path): - path = relpath(path) - super().__init__( - f"unable to read: '{path}', YAML file structure is corrupted" - ) - - -class TOMLFileCorruptedError(DvcException): - def __init__(self, path): - path = relpath(path) - super().__init__( - f"unable to read: '{path}', TOML file structure is corrupted" - ) - - class RecursiveAddingWhileUsingFilename(DvcException): def __init__(self): super().__init__( diff --git a/dvc/repo/metrics/show.py b/dvc/repo/metrics/show.py index 33906ff3a6..b1b0e9b96c 100644 --- a/dvc/repo/metrics/show.py +++ b/dvc/repo/metrics/show.py @@ -1,12 +1,11 @@ import logging import os -import yaml - from dvc.exceptions import NoMetricsError from dvc.path_info import PathInfo from dvc.repo import locked from dvc.repo.tree import RepoTree +from dvc.utils.serialize import YAMLFileCorruptedError, parse_yaml logger = logging.getLogger(__name__) @@ -72,8 +71,8 @@ def _read_metrics(repo, metrics, rev): try: with tree.open(metric, "r") as fobj: # NOTE this also supports JSON - val = yaml.safe_load(fobj) - except (FileNotFoundError, yaml.YAMLError): + val = parse_yaml(fobj.read(), metric) + except (FileNotFoundError, YAMLFileCorruptedError): logger.debug( "failed to read '%s' on '%s'", metric, rev, exc_info=True ) diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index e60aebaedf..a7803ff6ff 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -1,12 +1,10 @@ import logging -import toml -import yaml - from dvc.dependency.param import ParamsDependency from dvc.exceptions import DvcException from dvc.path_info import PathInfo from dvc.repo import locked +from dvc.utils.serialize import PARSERS, ParseError logger = logging.getLogger(__name__) @@ -33,12 +31,12 @@ def _read_params(repo, configs, rev): if not repo.tree.exists(config): continue + suffix = config.suffix.lower() + parser = PARSERS[suffix] with repo.tree.open(config, "r") as fobj: try: - res[str(config)] = ParamsDependency.PARAMS_FILE_LOADERS[ - config.suffix.lower() - ](fobj) - except (yaml.YAMLError, toml.TomlDecodeError): + res[str(config)] = parser(fobj.read(), config) + except ParseError: logger.debug( "failed to read '%s' on '%s'", config, rev, exc_info=True ) diff --git a/dvc/repo/plots/data.py b/dvc/repo/plots/data.py index 78d44164f5..02caca70f4 100644 --- a/dvc/repo/plots/data.py +++ b/dvc/repo/plots/data.py @@ -5,9 +5,8 @@ from collections import OrderedDict from copy import copy -import yaml from funcy import first -from yaml import SafeLoader +from ruamel.yaml import YAML from dvc.exceptions import DvcException @@ -208,18 +207,7 @@ def raw(self, header=True, **kwargs): # pylint: disable=arguments-differ class YAMLPlotData(PlotData): def raw(self, **kwargs): - class OrderedLoader(SafeLoader): - pass - - def construct_mapping(loader, node): - loader.flatten_mapping(node) - return OrderedDict(loader.construct_pairs(node)) - - OrderedLoader.add_constructor( - yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping - ) - - return yaml.load(self.content, OrderedLoader) + return YAML().load(self.content) def _processors(self): parent_processors = super()._processors() diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 6fc0cc6cde..cd8bc17ef2 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -120,9 +120,10 @@ def reproduce( def _parse_params(path_params): from flatten_json import unflatten - from yaml import YAMLError, safe_load + from ruamel.yaml import YAMLError from dvc.dependency.param import ParamsDependency + from dvc.utils.serialize import loads_yaml ret = {} for path_param in path_params: @@ -133,7 +134,7 @@ def _parse_params(path_params): try: # interpret value strings using YAML rules key, value = param_str.split("=") - params[key] = safe_load(value) + params[key] = loads_yaml(value) except (ValueError, YAMLError): raise InvalidArgumentError( f"Invalid param/value pair '{param_str}'" diff --git a/dvc/scm/git.py b/dvc/scm/git.py index 5df270bdae..db515e8b6e 100644 --- a/dvc/scm/git.py +++ b/dvc/scm/git.py @@ -5,7 +5,6 @@ import shlex from functools import partial -import yaml from funcy import cached_property from pathspec.patterns import GitWildMatchPattern @@ -20,6 +19,7 @@ ) from dvc.utils import fix_env, is_binary, relpath from dvc.utils.fs import path_isin +from dvc.utils.serialize import dump_yaml, load_yaml logger = logging.getLogger(__name__) @@ -316,11 +316,7 @@ def install(self, use_pre_commit_tool=False): return config_path = os.path.join(self.root_dir, ".pre-commit-config.yaml") - - config = {} - if os.path.exists(config_path): - with open(config_path) as fobj: - config = yaml.safe_load(fobj) + config = load_yaml(config_path) if os.path.exists(config_path) else {} entry = { "repo": "https://github.com/iterative/dvc", @@ -349,8 +345,7 @@ def install(self, use_pre_commit_tool=False): return config["repos"].append(entry) - with open(config_path, "w+") as fobj: - yaml.dump(config, fobj) + dump_yaml(config_path, config) def cleanup_ignores(self): for path in self.ignored_paths: diff --git a/dvc/stage/cache.py b/dvc/stage/cache.py index 9eceb465e3..0e312c6d65 100644 --- a/dvc/stage/cache.py +++ b/dvc/stage/cache.py @@ -2,7 +2,6 @@ import os from contextlib import contextmanager -import yaml from funcy import first from voluptuous import Invalid @@ -10,7 +9,7 @@ from dvc.schema import COMPILED_LOCK_FILE_STAGE_SCHEMA from dvc.utils import dict_sha256, relpath from dvc.utils.fs import makedirs -from dvc.utils.serialize import dump_yaml +from dvc.utils.serialize import YAMLFileCorruptedError, dump_yaml, load_yaml from .loader import StageLoader from .serialize import to_single_stage_lockfile @@ -54,11 +53,10 @@ def _load_cache(self, key, value): path = self._get_cache_path(key, value) try: - with open(path) as fobj: - return COMPILED_LOCK_FILE_STAGE_SCHEMA(yaml.safe_load(fobj)) + return COMPILED_LOCK_FILE_STAGE_SCHEMA(load_yaml(path)) except FileNotFoundError: return None - except (yaml.error.YAMLError, Invalid): + except (YAMLFileCorruptedError, Invalid): logger.warning("corrupted cache file '%s'.", relpath(path)) os.unlink(path) return None diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index bcceb1f760..a2ada45ad3 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -1,7 +1,6 @@ """Helpers for other modules.""" import hashlib -import io import json import logging import math @@ -12,7 +11,6 @@ import colorama import nanotime -from ruamel.yaml import YAML from shortuuid import uuid logger = logging.getLogger(__name__) @@ -237,18 +235,6 @@ def current_timestamp(): return int(nanotime.timestamp(time.time())) -def from_yaml_string(s): - return YAML().load(io.StringIO(s)) - - -def to_yaml_string(data): - stream = io.StringIO() - yaml = YAML() - yaml.default_flow_style = False - yaml.dump(data, stream) - return stream.getvalue() - - def colorize(message, color=None): """Returns a message in a specified color.""" if not color: diff --git a/dvc/utils/serialize/__init__.py b/dvc/utils/serialize/__init__.py index ed51128cc7..ab658a5c53 100644 --- a/dvc/utils/serialize/__init__.py +++ b/dvc/utils/serialize/__init__.py @@ -1,2 +1,8 @@ +from collections import defaultdict + +from ._common import * # noqa, pylint: disable=wildcard-import from ._toml import * # noqa, pylint: disable=wildcard-import from ._yaml import * # noqa, pylint: disable=wildcard-import + +PARSERS = defaultdict(lambda: parse_yaml) # noqa: F405 +PARSERS.update({".toml": parse_toml}) # noqa: F405 diff --git a/dvc/utils/serialize/_common.py b/dvc/utils/serialize/_common.py new file mode 100644 index 0000000000..16c5b5e1a1 --- /dev/null +++ b/dvc/utils/serialize/_common.py @@ -0,0 +1,12 @@ +"""Common utilities for serialize.""" + +from dvc.exceptions import DvcException +from dvc.utils import relpath + + +class ParseError(DvcException): + """Errors while parsing files""" + + def __init__(self, path, message): + path = relpath(path) + super().__init__(f"unable to read: '{path}', {message}") diff --git a/dvc/utils/serialize/_toml.py b/dvc/utils/serialize/_toml.py index 74aa59da38..7c002ba7df 100644 --- a/dvc/utils/serialize/_toml.py +++ b/dvc/utils/serialize/_toml.py @@ -1,6 +1,17 @@ import toml +from funcy import reraise -from dvc.exceptions import TOMLFileCorruptedError +from ._common import ParseError + + +class TOMLFileCorruptedError(ParseError): + def __init__(self, path): + super().__init__(path, "TOML file structure is corrupted") + + +def parse_toml(text, path, decoder=None): + with reraise(toml.TomlDecodeError, TOMLFileCorruptedError(path)): + return toml.loads(text, decoder=decoder) def parse_toml_for_update(text, path): @@ -10,12 +21,10 @@ def parse_toml_for_update(text, path): keys may be re-ordered between load/dump, but this function will at least preserve comments. """ - try: - return toml.loads(text, decoder=toml.TomlPreserveCommentDecoder()) - except toml.TomlDecodeError as exc: - raise TOMLFileCorruptedError(path) from exc + decoder = toml.TomlPreserveCommentDecoder() + return parse_toml(text, path, decoder=decoder) def dump_toml(path, data): - with open(path, "w", encoding="utf-8") as fobj: + with open(path, "w+", encoding="utf-8") as fobj: toml.dump(data, fobj, encoder=toml.TomlPreserveCommentEncoder()) diff --git a/dvc/utils/serialize/_yaml.py b/dvc/utils/serialize/_yaml.py index 70fb46f31a..f60f7c64a2 100644 --- a/dvc/utils/serialize/_yaml.py +++ b/dvc/utils/serialize/_yaml.py @@ -1,14 +1,16 @@ +import io from collections import OrderedDict +from funcy import reraise from ruamel.yaml import YAML from ruamel.yaml.error import YAMLError -from dvc.exceptions import YAMLFileCorruptedError +from ._common import ParseError -try: - from yaml import CSafeLoader as SafeLoader -except ImportError: - from yaml import SafeLoader + +class YAMLFileCorruptedError(ParseError): + def __init__(self, path): + super().__init__(path, "YAML file structure is corrupted") def load_yaml(path): @@ -16,13 +18,10 @@ def load_yaml(path): return parse_yaml(fd.read(), path) -def parse_yaml(text, path): - try: - import yaml - - return yaml.load(text, Loader=SafeLoader) or {} - except yaml.error.YAMLError as exc: - raise YAMLFileCorruptedError(path) from exc +def parse_yaml(text, path, typ="safe"): + yaml = YAML(typ=typ) + with reraise(YAMLError, YAMLFileCorruptedError(path)): + return yaml.load(text) or {} def parse_yaml_for_update(text, path): @@ -34,20 +33,30 @@ def parse_yaml_for_update(text, path): This one is, however, several times slower than simple `parse_yaml()`. """ - try: - yaml = YAML() - return yaml.load(text) or {} - except YAMLError as exc: - raise YAMLFileCorruptedError(path) from exc + return parse_yaml(text, path, typ="rt") + + +def _get_yaml(): + yaml = YAML() + yaml.default_flow_style = False + + # tell Dumper to represent OrderedDict as normal dict + yaml_repr_cls = yaml.Representer + yaml_repr_cls.add_representer(OrderedDict, yaml_repr_cls.represent_dict) + return yaml def dump_yaml(path, data): - with open(path, "w", encoding="utf-8") as fd: - yaml = YAML() - yaml.default_flow_style = False - # tell Dumper to represent OrderedDict as - # normal dict - yaml.Representer.add_representer( - OrderedDict, yaml.Representer.represent_dict - ) + yaml = _get_yaml() + with open(path, "w+", encoding="utf-8") as fd: yaml.dump(data, fd) + + +def loads_yaml(s): + return YAML(typ="safe").load(s) + + +def dumps_yaml(d): + stream = io.StringIO() + YAML().dump(d, stream) + return stream.getvalue() diff --git a/setup.py b/setup.py index 687f7353c1..6031981bd8 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,6 @@ def run(self): "grandalf==0.6", "distro>=1.3.0", "appdirs>=1.4.3", - "PyYAML>=5.1.2,<5.4", # Compatibility with awscli "ruamel.yaml>=0.16.1", "toml>=0.10.1", "funcy>=1.14", diff --git a/tests/func/metrics/test_diff.py b/tests/func/metrics/test_diff.py index 752094df84..c7ea99bf33 100644 --- a/tests/func/metrics/test_diff.py +++ b/tests/func/metrics/test_diff.py @@ -1,9 +1,8 @@ import json import logging -import yaml - from dvc.main import main +from dvc.utils.serialize import dump_yaml def test_metrics_diff_simple(tmp_dir, scm, dvc, run_copy_metrics): @@ -24,7 +23,7 @@ def _gen(val): def test_metrics_diff_yaml(tmp_dir, scm, dvc, run_copy_metrics): def _gen(val): metrics = {"a": {"b": {"c": val, "d": 1, "e": str(val)}}} - tmp_dir.gen({"m_temp.yaml": yaml.dump(metrics)}) + dump_yaml("m_temp.yaml", metrics) run_copy_metrics( "m_temp.yaml", "m.yaml", metrics=["m.yaml"], commit=str(val) ) diff --git a/tests/func/plots/test_plots.py b/tests/func/plots/test_plots.py index 0e7427c769..60faeac34b 100644 --- a/tests/func/plots/test_plots.py +++ b/tests/func/plots/test_plots.py @@ -6,7 +6,6 @@ from collections import OrderedDict import pytest -import yaml from funcy import first from dvc.repo.plots.data import ( @@ -21,6 +20,7 @@ NoFieldInDataError, TemplateNotFoundError, ) +from dvc.utils.serialize import dump_yaml, dumps_yaml def _write_csv(metric, filename, header=True): @@ -493,9 +493,7 @@ def test_plot_default_choose_column(tmp_dir, scm, dvc, run_copy_metrics): def test_plot_yaml(tmp_dir, scm, dvc, run_copy_metrics): metric = [{"val": 2}, {"val": 3}] - with open("metric_t.yaml", "w") as fobj: - yaml.dump(metric, fobj) - + dump_yaml("metric_t.yaml", metric) run_copy_metrics( "metric_t.yaml", "metric.yaml", plots_no_cache=["metric.yaml"] ) @@ -543,7 +541,7 @@ def test_load_metric_from_dict_yaml(tmp_dir): metric = [{"acccuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}] dmetric = {"train": metric} - plot_data = YAMLPlotData("-", "revision", yaml.dump(dmetric)) + plot_data = YAMLPlotData("-", "revision", dumps_yaml(dmetric)) expected = metric for d in expected: diff --git a/tests/func/test_add.py b/tests/func/test_add.py index 50bda72de9..2aea0e58eb 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -17,7 +17,6 @@ OutputDuplicationError, OverlappingOutputPathsError, RecursiveAddingWhileUsingFilename, - YAMLFileCorruptedError, ) from dvc.main import main from dvc.output.base import OutputAlreadyTrackedError, OutputIsStageFileError @@ -27,7 +26,7 @@ from dvc.tree.local import LocalTree from dvc.utils import LARGE_DIR_SIZE, file_md5, relpath from dvc.utils.fs import path_isin -from dvc.utils.serialize import load_yaml +from dvc.utils.serialize import YAMLFileCorruptedError, load_yaml from tests.basic_env import TestDvc from tests.utils import get_gitignore_content diff --git a/tests/func/test_lockfile.py b/tests/func/test_lockfile.py index b9a25ff184..c1de074b0e 100644 --- a/tests/func/test_lockfile.py +++ b/tests/func/test_lockfile.py @@ -1,20 +1,18 @@ from collections import OrderedDict from operator import itemgetter -import yaml - from dvc.dvcfile import PIPELINE_LOCK from dvc.stage.utils import split_params_deps from dvc.utils.fs import remove -from dvc.utils.serialize import parse_yaml_for_update +from dvc.utils.serialize import dumps_yaml, parse_yaml_for_update from tests.func.test_run_multistage import supported_params FS_STRUCTURE = { "foo": "bar\nfoobar", "bar": "foo\nfoobar", "foobar": "foobar\nbar", - "params.yaml": yaml.dump(supported_params), - "params2.yaml": yaml.dump(supported_params), + "params.yaml": dumps_yaml(supported_params), + "params2.yaml": dumps_yaml(supported_params), } diff --git a/tests/func/test_repro_multistage.py b/tests/func/test_repro_multistage.py index 979e063831..37ba758d61 100644 --- a/tests/func/test_repro_multistage.py +++ b/tests/func/test_repro_multistage.py @@ -3,7 +3,6 @@ from textwrap import dedent import pytest -import yaml from funcy import lsplit from dvc.dvcfile import PIPELINE_FILE, PIPELINE_LOCK @@ -474,11 +473,8 @@ def test_repro_multiple_params(tmp_dir, dvc): from dvc.stage.utils import split_params_deps from tests.func.test_run_multistage import supported_params - with (tmp_dir / "params2.yaml").open("w+") as f: - yaml.dump(supported_params, f) - - with (tmp_dir / "params.yaml").open("w+") as f: - yaml.dump(supported_params, f) + dump_yaml(tmp_dir / "params2.yaml", supported_params) + dump_yaml(tmp_dir / "params.yaml", supported_params) (tmp_dir / "foo").write_text("foo") stage = dvc.run( @@ -518,9 +514,8 @@ def test_repro_multiple_params(tmp_dir, dvc): assert set(defaults) == {"answer", "floats", "nested.nested1"} assert not dvc.reproduce(stage.addressing) - with (tmp_dir / "params.yaml").open("w+") as f: - params = deepcopy(supported_params) - params["answer"] = 43 - yaml.dump(params, f) + params = deepcopy(supported_params) + params["answer"] = 43 + dump_yaml(tmp_dir / "params.yaml", params) assert dvc.reproduce(stage.addressing) == [stage] diff --git a/tests/func/test_run_multistage.py b/tests/func/test_run_multistage.py index 29799e99b5..ec05642fa9 100644 --- a/tests/func/test_run_multistage.py +++ b/tests/func/test_run_multistage.py @@ -2,12 +2,11 @@ import textwrap import pytest -import yaml from dvc.exceptions import InvalidArgumentError from dvc.repo import Repo from dvc.stage.exceptions import DuplicateStageName, InvalidStageName -from dvc.utils.serialize import parse_yaml_for_update +from dvc.utils.serialize import dump_yaml, parse_yaml_for_update def test_run_with_name(tmp_dir, dvc, run_copy): @@ -236,9 +235,7 @@ def test_run_already_exists(tmp_dir, dvc, run_copy): def test_run_params_default(tmp_dir, dvc): from dvc.dependency import ParamsDependency - with (tmp_dir / "params.yaml").open("w+") as f: - yaml.dump(supported_params, f) - + dump_yaml(tmp_dir / "params.yaml", supported_params) stage = dvc.run( name="read_params", params=["nested.nested1.nested2"], @@ -261,9 +258,7 @@ def test_run_params_default(tmp_dir, dvc): def test_run_params_custom_file(tmp_dir, dvc): from dvc.dependency import ParamsDependency - with (tmp_dir / "params2.yaml").open("w+") as f: - yaml.dump(supported_params, f) - + dump_yaml(tmp_dir / "params2.yaml", supported_params) stage = dvc.run( name="read_params", params=["params2.yaml:lists"], @@ -286,9 +281,7 @@ def test_run_params_custom_file(tmp_dir, dvc): def test_run_params_no_exec(tmp_dir, dvc): from dvc.dependency import ParamsDependency - with (tmp_dir / "params2.yaml").open("w+") as f: - yaml.dump(supported_params, f) - + dump_yaml(tmp_dir / "params2.yaml", supported_params) stage = dvc.run( name="read_params", params=["params2.yaml:lists"], diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py index 3dacc1691a..0ad370a1d0 100644 --- a/tests/unit/dependency/test_params.py +++ b/tests/unit/dependency/test_params.py @@ -1,11 +1,9 @@ import pytest -import toml -import yaml from dvc.dependency import ParamsDependency, loadd_from, loads_params from dvc.dependency.param import BadParamFileError, MissingParamsError from dvc.stage import Stage -from dvc.utils.serialize import load_yaml +from dvc.utils.serialize import dump_toml, dump_yaml, load_yaml PARAMS = { "foo": 1, @@ -92,9 +90,8 @@ def test_read_params_unsupported_format(tmp_dir, dvc): def test_read_params_nested(tmp_dir, dvc): - tmp_dir.gen( - DEFAULT_PARAMS_FILE, - yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}), + dump_yaml( + DEFAULT_PARAMS_FILE, {"some": {"path": {"foo": ["val1", "val2"]}}} ) dep = ParamsDependency(Stage(dvc), None, ["some.path.foo"]) assert dep.read_params() == {"some.path.foo": ["val1", "val2"]} @@ -102,20 +99,14 @@ def test_read_params_nested(tmp_dir, dvc): def test_read_params_default_loader(tmp_dir, dvc): parameters_file = "parameters.foo" - tmp_dir.gen( - parameters_file, - yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}), - ) + dump_yaml(parameters_file, {"some": {"path": {"foo": ["val1", "val2"]}}}) dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"]) assert dep.read_params() == {"some.path.foo": ["val1", "val2"]} def test_read_params_wrong_suffix(tmp_dir, dvc): parameters_file = "parameters.toml" - tmp_dir.gen( - parameters_file, - yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}), - ) + dump_yaml(parameters_file, {"some": {"path": {"foo": ["val1", "val2"]}}}) dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"]) with pytest.raises(BadParamFileError): dep.read_params() @@ -123,10 +114,7 @@ def test_read_params_wrong_suffix(tmp_dir, dvc): def test_read_params_toml(tmp_dir, dvc): parameters_file = "parameters.toml" - tmp_dir.gen( - parameters_file, - toml.dumps({"some": {"path": {"foo": ["val1", "val2"]}}}), - ) + dump_toml(parameters_file, {"some": {"path": {"foo": ["val1", "val2"]}}}) dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"]) assert dep.read_params() == {"some.path.foo": ["val1", "val2"]} @@ -144,13 +132,10 @@ def test_save_info_missing_param(tmp_dir, dvc): dep.save_info() -@pytest.mark.parametrize( - "param_value", - ["", "false", "[]", "{}", "null", "no", "off"] - # we use pyyaml to load params.yaml, which only supports YAML 1.1 - # so, some of the above are boolean values -) +@pytest.mark.regression_4184 +@pytest.mark.parametrize("param_value", ["", "false", "[]", "{}", "null"]) def test_params_with_false_values(tmp_dir, dvc, param_value): + """These falsy params values should not ignored by `status` on loading.""" key = "param" dep = ParamsDependency(Stage(dvc), DEFAULT_PARAMS_FILE, [key]) (tmp_dir / DEFAULT_PARAMS_FILE).write_text(f"{key}: {param_value}") diff --git a/tests/unit/test_lockfile.py b/tests/unit/test_lockfile.py index 26bcd13a5a..a5a8c867fb 100644 --- a/tests/unit/test_lockfile.py +++ b/tests/unit/test_lockfile.py @@ -1,8 +1,8 @@ import pytest -import yaml from dvc.dvcfile import Lockfile, LockfileCorruptedError from dvc.stage import PipelineStage +from dvc.utils.serialize import dump_yaml def test_stage_dump_no_outs_deps(tmp_dir, dvc): @@ -14,8 +14,7 @@ def test_stage_dump_no_outs_deps(tmp_dir, dvc): def test_stage_dump_when_already_exists(tmp_dir, dvc): data = {"s1": {"cmd": "command", "deps": [], "outs": []}} - with open("path.lock", "w+") as f: - yaml.dump(data, f) + dump_yaml("path.lock", data) stage = PipelineStage(name="s2", repo=dvc, path="path", cmd="command2") lockfile = Lockfile(dvc, "path.lock") lockfile.dump(stage) @@ -33,9 +32,7 @@ def test_stage_dump_with_deps_and_outs(tmp_dir, dvc): "outs": [{"md5": "2.txt", "path": "checksum"}], } } - with open("path.lock", "w+") as f: - yaml.dump(data, f) - + dump_yaml("path.lock", data) lockfile = Lockfile(dvc, "path.lock") stage = PipelineStage(name="s2", repo=dvc, path="path", cmd="command2") lockfile.dump(stage) @@ -77,8 +74,7 @@ def test_load_when_lockfile_does_not_exist(tmp_dir, dvc): ], ) def test_load_when_lockfile_is_corrupted(tmp_dir, dvc, corrupt_data): - with open("Dvcfile.lock", "w+") as f: - yaml.dump(corrupt_data, f) + dump_yaml("Dvcfile.lock", corrupt_data) lockfile = Lockfile(dvc, "Dvcfile.lock") with pytest.raises(LockfileCorruptedError) as exc_info: lockfile.load()