Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

utils.serialize: modify_yaml contextmanager to reduce boilerplate #4426

Merged
merged 1 commit into from
Aug 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 24 additions & 29 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dvc.utils.serialize import (
dump_yaml,
load_yaml,
modify_yaml,
parse_yaml,
parse_yaml_for_update,
)
Expand Down Expand Up @@ -193,30 +194,27 @@ def _dump_lockfile(self, stage):
self._lockfile.dump(stage)

def _dump_pipeline_file(self, stage):
data = {}
if self.exists():
with open(self.path) as fd:
data = parse_yaml_for_update(fd.read(), self.path)
else:
logger.info("Creating '%s'", self.relpath)
open(self.path, "w+").close()

data["stages"] = data.get("stages", {})
stage_data = serialize.to_pipeline_file(stage)
existing_entry = stage.name in data["stages"]

action = "Modifying" if existing_entry else "Adding"
logger.info("%s stage '%s' in '%s'", action, stage.name, self.relpath)
with modify_yaml(self.path, tree=self.repo.tree) as data:
if not data:
logger.info("Creating '%s'", self.relpath)

if existing_entry:
orig_stage_data = data["stages"][stage.name]
if "meta" in orig_stage_data:
stage_data[stage.name]["meta"] = orig_stage_data["meta"]
apply_diff(stage_data[stage.name], orig_stage_data)
else:
data["stages"].update(stage_data)
data["stages"] = data.get("stages", {})
existing_entry = stage.name in data["stages"]
action = "Modifying" if existing_entry else "Adding"
logger.info(
"%s stage '%s' in '%s'", action, stage.name, self.relpath
)

if existing_entry:
orig_stage_data = data["stages"][stage.name]
if "meta" in orig_stage_data:
stage_data[stage.name]["meta"] = orig_stage_data["meta"]
apply_diff(stage_data[stage.name], orig_stage_data)
else:
data["stages"].update(stage_data)

dump_yaml(self.path, data)
self.repo.scm.track_file(self.relpath)

@property
Expand Down Expand Up @@ -281,21 +279,18 @@ def load(self):

def dump(self, stage, **kwargs):
stage_data = serialize.to_lockfile(stage)
if not self.exists():
modified = True
logger.info("Generating lock file '%s'", self.relpath)
data = stage_data
open(self.path, "w+").close()
else:
with self.repo.tree.open(self.path, "r") as fd:
data = parse_yaml_for_update(fd.read(), self.path)

with modify_yaml(self.path, tree=self.repo.tree) as data:
if not data:
logger.info("Generating lock file '%s'", self.relpath)

modified = data.get(stage.name, {}) != stage_data.get(
stage.name, {}
)
if modified:
logger.info("Updating lock file '%s'", self.relpath)
data.update(stage_data)
dump_yaml(self.path, data)

if modified:
self.repo.scm.track_file(self.relpath)

Expand Down
21 changes: 4 additions & 17 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import re
import tempfile
from collections import defaultdict
from collections.abc import Mapping
from concurrent.futures import (
ProcessPoolExecutor,
Expand Down Expand Up @@ -213,12 +212,7 @@ def _unpack_args(self, tree=None):

def _update_params(self, params: dict):
"""Update experiment params files with the specified values."""
from dvc.utils.serialize import (
dump_toml,
dump_yaml,
parse_toml_for_update,
parse_yaml_for_update,
)
from dvc.utils.serialize import MODIFIERS

logger.debug("Using experiment params '%s'", params)

Expand All @@ -231,19 +225,12 @@ def _update(dict_, other):
dict_[key] = value
return dict_

loaders = defaultdict(lambda: parse_yaml_for_update)
loaders.update({".toml": parse_toml_for_update})
dumpers = defaultdict(lambda: dump_yaml)
dumpers.update({".toml": dump_toml})

for params_fname in params:
path = PathInfo(self.exp_dvc.root_dir) / params_fname
with self.exp_dvc.tree.open(path, "r") as fobj:
text = fobj.read()
suffix = path.suffix.lower()
data = loaders[suffix](text, path)
_update(data, params[params_fname])
dumpers[suffix](path, data)
modify_data = MODIFIERS[suffix]
with modify_data(path, tree=self.exp_dvc.tree) as data:
_update(data, params[params_fname])

def _commit(self, exp_hash, check_exists=True, branch=True):
"""Commit stages as an experiment and return the commit SHA."""
Expand Down
58 changes: 27 additions & 31 deletions dvc/scm/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from dvc.utils import fix_env, is_binary, relpath
from dvc.utils.fs import path_isin
from dvc.utils.serialize import dump_yaml, load_yaml
from dvc.utils.serialize import modify_yaml

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -330,36 +330,32 @@ def install(self, use_pre_commit_tool=False):
return

config_path = os.path.join(self.root_dir, ".pre-commit-config.yaml")
config = load_yaml(config_path) if os.path.exists(config_path) else {}

entry = {
"repo": "https://github.com/iterative/dvc",
"rev": "master",
"hooks": [
{
"id": "dvc-pre-commit",
"language_version": "python3",
"stages": ["commit"],
},
{
"id": "dvc-pre-push",
"language_version": "python3",
"stages": ["push"],
},
{
"id": "dvc-post-checkout",
"language_version": "python3",
"stages": ["post-checkout"],
"always_run": True,
},
],
}

if entry in config["repos"]:
return

config["repos"].append(entry)
dump_yaml(config_path, config)
with modify_yaml(config_path) as config:
entry = {
"repo": "https://github.com/iterative/dvc",
"rev": "master",
"hooks": [
{
"id": "dvc-pre-commit",
"language_version": "python3",
"stages": ["commit"],
},
{
"id": "dvc-pre-push",
"language_version": "python3",
"stages": ["push"],
},
{
"id": "dvc-post-checkout",
"language_version": "python3",
"stages": ["post-checkout"],
"always_run": True,
},
],
}

if entry not in config["repos"]:
config["repos"].append(entry)

def cleanup_ignores(self):
for path in self.ignored_paths:
Expand Down
3 changes: 3 additions & 0 deletions dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@

# Suffix -> loader; every suffix other than ".toml" falls back to YAML.
LOADERS = defaultdict(lambda: load_yaml, {".toml": load_toml})  # noqa: F405

# Suffix -> modifying context manager, with the same YAML fallback.
MODIFIERS = defaultdict(
    lambda: modify_yaml, {".toml": modify_toml}  # noqa: F405
)
10 changes: 10 additions & 0 deletions dvc/utils/serialize/_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Common utilities for serialize."""
import os
from contextlib import contextmanager

from dvc.exceptions import DvcException
from dvc.utils import relpath
Expand All @@ -22,3 +24,11 @@ def _dump_data(path, data, dumper, tree=None):
open_fn = tree.open if tree else open
with open_fn(path, "w+", encoding="utf-8") as fd:
dumper(data, fd)


@contextmanager
def _modify_data(path, parser, dumper, tree=None):
    """Yield the parsed contents of ``path`` and write them back on exit.

    Args:
        path: file to modify.
        parser: forwarded to ``_load_data`` to parse the file when it exists.
        dumper: callable invoked as ``dumper(path, data, tree=tree)`` once
            the ``with`` body has finished.
        tree: optional object whose ``exists`` is used for the existence
            check (``os.path.exists`` otherwise); it is also forwarded to
            the loader and to the dumper.

    Yields:
        The loaded data, or a new empty dict when ``path`` does not exist.
        Callers must modify it in place: the yielded object itself is what
        gets dumped.
    """
    exists = tree.exists if tree else os.path.exists
    data = _load_data(path, parser=parser, tree=tree) if exists(path) else {}
    yield data
    # Only reached when the ``with`` body completes without raising, so a
    # failed modification never overwrites the file.
    dumper(path, data, tree=tree)
10 changes: 9 additions & 1 deletion dvc/utils/serialize/_toml.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from contextlib import contextmanager

import toml
from funcy import reraise

from ._common import ParseError, _dump_data, _load_data
from ._common import ParseError, _dump_data, _load_data, _modify_data


class TOMLFileCorruptedError(ParseError):
Expand Down Expand Up @@ -35,3 +37,9 @@ def _dump(data, stream):

def dump_toml(path, data, tree=None):
    """Write ``data`` to ``path`` using this module's ``_dump`` serializer.

    Thin wrapper around ``_dump_data``; ``tree`` is forwarded unchanged and
    the result of ``_dump_data`` is returned.
    """
    return _dump_data(path, data, dumper=_dump, tree=tree)


@contextmanager
def modify_toml(path, tree=None):
    """Context manager for editing the TOML file at ``path`` in place.

    Delegates to ``_modify_data`` with ``parse_toml_for_update`` as the
    parser and ``dump_toml`` as the dumper, and yields the data it provides.

    Args:
        path: file to modify.
        tree: optional tree object forwarded to ``_modify_data``.
    """
    with _modify_data(path, parse_toml_for_update, dump_toml, tree=tree) as d:
        yield d
12 changes: 9 additions & 3 deletions dvc/utils/serialize/_yaml.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import io
from collections import OrderedDict
from contextlib import contextmanager

from funcy import reraise
from ruamel.yaml import YAML
from ruamel.yaml.error import YAMLError

from ._common import ParseError, _dump_data, _load_data
from ._common import ParseError, _dump_data, _load_data, _modify_data


class YAMLFileCorruptedError(ParseError):
Expand Down Expand Up @@ -60,6 +61,11 @@ def loads_yaml(s, typ="safe"):

def dumps_yaml(d):
stream = io.StringIO()
yaml = _get_yaml()
yaml.dump(d, stream)
_dump(d, stream)
return stream.getvalue()


@contextmanager
def modify_yaml(path, tree=None):
    """Context manager for editing the YAML file at ``path`` in place.

    Delegates to ``_modify_data`` with ``parse_yaml_for_update`` as the
    parser and ``dump_yaml`` as the dumper, and yields the data it provides.

    Args:
        path: file to modify.
        tree: optional tree object forwarded to ``_modify_data``.
    """
    with _modify_data(path, parse_yaml_for_update, dump_yaml, tree=tree) as d:
        yield d