Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move to yaml1.2 #4415

Merged
merged 2 commits into from
Aug 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions dvc/dependency/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from collections import defaultdict

import dpath.util
import toml
Copy link
Member Author

@skshetry skshetry Aug 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmrowla, tomlkit preserves ordering and comments. We can move to it in the future.

import yaml
from voluptuous import Any

from dvc.dependency.local import LocalDependency
from dvc.exceptions import DvcException
from dvc.utils.serialize import PARSERS, ParseError


class MissingParamsError(DvcException):
Expand All @@ -22,8 +21,6 @@ class ParamsDependency(LocalDependency):
PARAM_PARAMS = "params"
PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)}
DEFAULT_PARAMS_FILE = "params.yaml"
PARAMS_FILE_LOADERS = defaultdict(lambda: yaml.safe_load)
PARAMS_FILE_LOADERS.update({".toml": toml.load})

def __init__(self, stage, path, params):
info = {}
Expand Down Expand Up @@ -88,12 +85,12 @@ def read_params(self):
if not self.exists:
return {}

suffix = self.path_info.suffix.lower()
parser = PARSERS[suffix]
with self.repo.tree.open(self.path_info, "r") as fobj:
try:
config = self.PARAMS_FILE_LOADERS[
self.path_info.suffix.lower()
](fobj)
except (yaml.YAMLError, toml.TomlDecodeError) as exc:
config = parser(fobj.read(), self.path_info)
except ParseError as exc:
skshetry marked this conversation as resolved.
Show resolved Hide resolved
raise BadParamFileError(
f"Unable to read parameters from '{self}'"
) from exc
Expand Down
16 changes: 0 additions & 16 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,22 +182,6 @@ def __init__(self):
)


class YAMLFileCorruptedError(DvcException):
def __init__(self, path):
path = relpath(path)
super().__init__(
f"unable to read: '{path}', YAML file structure is corrupted"
)


class TOMLFileCorruptedError(DvcException):
def __init__(self, path):
path = relpath(path)
super().__init__(
f"unable to read: '{path}', TOML file structure is corrupted"
)


class RecursiveAddingWhileUsingFilename(DvcException):
def __init__(self):
super().__init__(
Expand Down
7 changes: 3 additions & 4 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import logging
import os

import yaml

from dvc.exceptions import NoMetricsError
from dvc.path_info import PathInfo
from dvc.repo import locked
from dvc.repo.tree import RepoTree
from dvc.utils.serialize import YAMLFileCorruptedError, parse_yaml

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,8 +71,8 @@ def _read_metrics(repo, metrics, rev):
try:
with tree.open(metric, "r") as fobj:
# NOTE this also supports JSON
val = yaml.safe_load(fobj)
except (FileNotFoundError, yaml.YAMLError):
val = parse_yaml(fobj.read(), metric)
except (FileNotFoundError, YAMLFileCorruptedError):
skshetry marked this conversation as resolved.
Show resolved Hide resolved
logger.debug(
"failed to read '%s' on '%s'", metric, rev, exc_info=True
)
Expand Down
12 changes: 5 additions & 7 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import logging

import toml
import yaml

from dvc.dependency.param import ParamsDependency
from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
from dvc.repo import locked
from dvc.utils.serialize import PARSERS, ParseError

logger = logging.getLogger(__name__)

Expand All @@ -33,12 +31,12 @@ def _read_params(repo, configs, rev):
if not repo.tree.exists(config):
continue

suffix = config.suffix.lower()
parser = PARSERS[suffix]
with repo.tree.open(config, "r") as fobj:
try:
res[str(config)] = ParamsDependency.PARAMS_FILE_LOADERS[
config.suffix.lower()
](fobj)
except (yaml.YAMLError, toml.TomlDecodeError):
res[str(config)] = parser(fobj.read(), config)
except ParseError:
logger.debug(
"failed to read '%s' on '%s'", config, rev, exc_info=True
)
Expand Down
16 changes: 2 additions & 14 deletions dvc/repo/plots/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from collections import OrderedDict
from copy import copy

import yaml
from funcy import first
from yaml import SafeLoader
from ruamel.yaml import YAML

from dvc.exceptions import DvcException

Expand Down Expand Up @@ -208,18 +207,7 @@ def raw(self, header=True, **kwargs): # pylint: disable=arguments-differ

class YAMLPlotData(PlotData):
def raw(self, **kwargs):
class OrderedLoader(SafeLoader):
pass

def construct_mapping(loader, node):
loader.flatten_mapping(node)
return OrderedDict(loader.construct_pairs(node))

OrderedLoader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping
)

return yaml.load(self.content, OrderedLoader)
return YAML().load(self.content)
skshetry marked this conversation as resolved.
Show resolved Hide resolved

def _processors(self):
parent_processors = super()._processors()
Expand Down
5 changes: 3 additions & 2 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ def reproduce(

def _parse_params(path_params):
from flatten_json import unflatten
from yaml import YAMLError, safe_load
from ruamel.yaml import YAMLError

from dvc.dependency.param import ParamsDependency
from dvc.utils.serialize import loads_yaml

ret = {}
for path_param in path_params:
Expand All @@ -133,7 +134,7 @@ def _parse_params(path_params):
try:
# interpret value strings using YAML rules
key, value = param_str.split("=")
params[key] = safe_load(value)
params[key] = loads_yaml(value)
except (ValueError, YAMLError):
raise InvalidArgumentError(
f"Invalid param/value pair '{param_str}'"
Expand Down
11 changes: 3 additions & 8 deletions dvc/scm/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import shlex
from functools import partial

import yaml
from funcy import cached_property
from pathspec.patterns import GitWildMatchPattern

Expand All @@ -20,6 +19,7 @@
)
from dvc.utils import fix_env, is_binary, relpath
from dvc.utils.fs import path_isin
from dvc.utils.serialize import dump_yaml, load_yaml

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -316,11 +316,7 @@ def install(self, use_pre_commit_tool=False):
return

config_path = os.path.join(self.root_dir, ".pre-commit-config.yaml")

config = {}
if os.path.exists(config_path):
with open(config_path) as fobj:
config = yaml.safe_load(fobj)
config = load_yaml(config_path) if os.path.exists(config_path) else {}

entry = {
"repo": "https://github.com/iterative/dvc",
Expand Down Expand Up @@ -349,8 +345,7 @@ def install(self, use_pre_commit_tool=False):
return

config["repos"].append(entry)
with open(config_path, "w+") as fobj:
yaml.dump(config, fobj)
dump_yaml(config_path, config)
skshetry marked this conversation as resolved.
Show resolved Hide resolved

def cleanup_ignores(self):
for path in self.ignored_paths:
Expand Down
8 changes: 3 additions & 5 deletions dvc/stage/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os
from contextlib import contextmanager

import yaml
from funcy import first
from voluptuous import Invalid

from dvc.cache.local import _log_exceptions
from dvc.schema import COMPILED_LOCK_FILE_STAGE_SCHEMA
from dvc.utils import dict_sha256, relpath
from dvc.utils.fs import makedirs
from dvc.utils.serialize import dump_yaml
from dvc.utils.serialize import YAMLFileCorruptedError, dump_yaml, load_yaml

from .loader import StageLoader
from .serialize import to_single_stage_lockfile
Expand Down Expand Up @@ -54,11 +53,10 @@ def _load_cache(self, key, value):
path = self._get_cache_path(key, value)

try:
with open(path) as fobj:
return COMPILED_LOCK_FILE_STAGE_SCHEMA(yaml.safe_load(fobj))
return COMPILED_LOCK_FILE_STAGE_SCHEMA(load_yaml(path))
except FileNotFoundError:
return None
except (yaml.error.YAMLError, Invalid):
except (YAMLFileCorruptedError, Invalid):
logger.warning("corrupted cache file '%s'.", relpath(path))
os.unlink(path)
return None
Expand Down
14 changes: 0 additions & 14 deletions dvc/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Helpers for other modules."""

import hashlib
import io
import json
import logging
import math
Expand All @@ -12,7 +11,6 @@

import colorama
import nanotime
from ruamel.yaml import YAML
from shortuuid import uuid

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -237,18 +235,6 @@ def current_timestamp():
return int(nanotime.timestamp(time.time()))


def from_yaml_string(s):
return YAML().load(io.StringIO(s))


def to_yaml_string(data):
stream = io.StringIO()
yaml = YAML()
yaml.default_flow_style = False
yaml.dump(data, stream)
return stream.getvalue()


skshetry marked this conversation as resolved.
Show resolved Hide resolved
def colorize(message, color=None):
"""Returns a message in a specified color."""
if not color:
Expand Down
6 changes: 6 additions & 0 deletions dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
from collections import defaultdict

from ._common import * # noqa, pylint: disable=wildcard-import
from ._toml import * # noqa, pylint: disable=wildcard-import
from ._yaml import * # noqa, pylint: disable=wildcard-import

PARSERS = defaultdict(lambda: parse_yaml) # noqa: F405
PARSERS.update({".toml": parse_toml}) # noqa: F405
Copy link
Member Author

@skshetry skshetry Aug 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it'd be good to have JSON utils after the upcoming PR with trees are done. Maybe even move to ujson as it has now wheels on all platforms.

In [1]: %timeit json.loads('{"ja": "son"}')
2.32 Β΅s Β± 35 ns per loop (mean Β± std. dev. of 7 runs, 100000 loops each)

In [2]: %timeit YAML(typ='safe').load('{"ja": "son"}')
180 Β΅s Β± 1.28 Β΅s per loop (mean Β± std. dev. of 7 runs, 10000 loops each)

In [3]: %timeit ujson.loads('{"ja": "son"}')
372 ns Β± 13.4 ns per loop (mean Β± std. dev. of 7 runs, 1000000 loops each)

12 changes: 12 additions & 0 deletions dvc/utils/serialize/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Common utilities for serialize."""

from dvc.exceptions import DvcException
from dvc.utils import relpath


class ParseError(DvcException):
"""Errors while parsing files"""

def __init__(self, path, message):
path = relpath(path)
super().__init__(f"unable to read: '{path}', {message}")
21 changes: 15 additions & 6 deletions dvc/utils/serialize/_toml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import toml
from funcy import reraise

from dvc.exceptions import TOMLFileCorruptedError
from ._common import ParseError


class TOMLFileCorruptedError(ParseError):
def __init__(self, path):
super().__init__(path, "TOML file structure is corrupted")


def parse_toml(text, path, decoder=None):
with reraise(toml.TomlDecodeError, TOMLFileCorruptedError(path)):
return toml.loads(text, decoder=decoder)


def parse_toml_for_update(text, path):
Expand All @@ -10,12 +21,10 @@ def parse_toml_for_update(text, path):
keys may be re-ordered between load/dump, but this function will at
least preserve comments.
"""
try:
return toml.loads(text, decoder=toml.TomlPreserveCommentDecoder())
except toml.TomlDecodeError as exc:
raise TOMLFileCorruptedError(path) from exc
decoder = toml.TomlPreserveCommentDecoder()
return parse_toml(text, path, decoder=decoder)


def dump_toml(path, data):
with open(path, "w", encoding="utf-8") as fobj:
with open(path, "w+", encoding="utf-8") as fobj:
toml.dump(data, fobj, encoder=toml.TomlPreserveCommentEncoder())
Loading