Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/add-remote-reference-support #1

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions openapi_python_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
import shutil
import subprocess
import sys
import urllib
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union, cast

import httpcore
import httpx
import yaml
from jinja2 import BaseLoader, ChoiceLoader, Environment, FileSystemLoader, PackageLoader

from openapi_python_client import utils

from .parser import GeneratorData, import_string_from_reference
from .parser.errors import GeneratorError
from .resolver.schema_resolver import SchemaResolver
from .utils import snake_case

if sys.version_info.minor < 8: # version did not exist before 3.8, need to use a backport
Expand Down Expand Up @@ -287,20 +288,21 @@ def update_existing_client(


def _get_document(*, url: Optional[str], path: Optional[Path]) -> Union[Dict[str, Any], GeneratorError]:
yaml_bytes: bytes
if url is not None and path is not None:
return GeneratorError(header="Provide URL or Path, not both.")
if url is not None:
try:
response = httpx.get(url)
yaml_bytes = response.content
except (httpx.HTTPError, httpcore.NetworkError):
return GeneratorError(header="Could not get OpenAPI document from provided URL")
elif path is not None:
yaml_bytes = path.read_bytes()
else:

if url is None and path is None:
return GeneratorError(header="No URL or Path provided")

source = cast(Union[str, Path], (url if url is not None else path))
try:
return yaml.safe_load(yaml_bytes)
except yaml.YAMLError:
resolver = SchemaResolver(source)
result = resolver.resolve()
if len(result.errors) > 0:
return GeneratorError(header="; ".join(result.errors))
except (httpx.HTTPError, httpcore.NetworkError, urllib.error.URLError):
return GeneratorError(header="Could not get OpenAPI document from provided URL")
except Exception:
return GeneratorError(header="Invalid YAML from provided source")

return result.schema
Empty file.
145 changes: 145 additions & 0 deletions openapi_python_client/resolver/collision_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import hashlib
from typing import Any, Dict, List, Tuple

from .reference import Reference
from .resolver_types import SchemaData


class CollisionResolver:
def __init__(self, root: SchemaData, refs: Dict[str, SchemaData], errors: List[str], parent: str):
self._root: SchemaData = root
self._refs: Dict[str, SchemaData] = refs
self._errors: List[str] = errors
self._parent = parent
self._refs_index: Dict[str, str] = dict()
self._schema_index: Dict[str, Reference] = dict()
self._keys_to_replace: Dict[str, Tuple[int, SchemaData, List[str]]] = dict()

def _browse_schema(self, attr: Any, root_attr: Any) -> None:
if isinstance(attr, dict):
attr_copy = {**attr} # Create a shallow copy
for key, val in attr_copy.items():
if key == "$ref":
ref = Reference(val, self._parent)
value = ref.pointer.value

assert value

schema = self._get_from_ref(ref, root_attr)
hashed_schema = self._reference_schema_hash(schema)

if value in self._refs_index.keys():
if self._refs_index[value] != hashed_schema:
if ref.is_local():
self._increment_ref(ref, root_attr, hashed_schema, attr, key)
else:
assert ref.abs_path in self._refs.keys()
self._increment_ref(ref, self._refs[ref.abs_path], hashed_schema, attr, key)
else:
self._refs_index[value] = hashed_schema

if hashed_schema in self._schema_index.keys():
existing_ref = self._schema_index[hashed_schema]
if (
existing_ref.pointer.value != ref.pointer.value
and ref.pointer.tokens()[-1] == existing_ref.pointer.tokens()[-1]
):
print("Found same schema for different pointer")
else:
self._schema_index[hashed_schema] = ref

else:
self._browse_schema(val, root_attr)

elif isinstance(attr, list):
for val in attr:
self._browse_schema(val, root_attr)

def _get_from_ref(self, ref: Reference, attr: SchemaData) -> SchemaData:
if ref.is_remote():
assert ref.abs_path in self._refs.keys()
attr = self._refs[ref.abs_path]
cursor = attr
query_parts = ref.pointer.tokens()

for key in query_parts:
if key == "":
continue

if isinstance(cursor, dict) and key in cursor:
cursor = cursor[key]
else:
print("ERROR")

if list(cursor) == ["$ref"]:
ref2 = Reference(cursor["$ref"], self._parent)
if ref2.is_remote():
attr = self._refs[ref2.abs_path]
return self._get_from_ref(ref2, attr)

return cursor

def _increment_ref(
self, ref: Reference, schema: SchemaData, hashed_schema: str, attr: Dict[str, Any], key: str
) -> None:
i = 2
value = ref.pointer.value
incremented_value = value + "_" + str(i)

while incremented_value in self._refs_index.keys():
if self._refs_index[incremented_value] == hashed_schema:
if ref.value not in self._keys_to_replace.keys():
break # have to increment target key aswell
else:
attr[key] = ref.value + "_" + str(i)
return
else:
i = i + 1
incremented_value = value + "_" + str(i)

attr[key] = ref.value + "_" + str(i)
self._refs_index[incremented_value] = hashed_schema
self._keys_to_replace[ref.value] = (i, schema, ref.pointer.tokens())

def _modify_root_ref_name(self, query_parts: List[str], i: int, attr: SchemaData) -> None:
cursor = attr
last_key = query_parts[-1]

for key in query_parts:
if key == "":
continue

if key == last_key and key + "_" + str(i) not in cursor:
assert key in cursor, "Didnt find %s in %s" % (key, attr)
cursor[key + "_" + str(i)] = cursor.pop(key)
return

if isinstance(cursor, dict) and key in cursor:
cursor = cursor[key]
else:
return

def resolve(self) -> None:
self._browse_schema(self._root, self._root)
for file, schema in self._refs.items():
self._browse_schema(schema, schema)
for a, b in self._keys_to_replace.items():
self._modify_root_ref_name(b[2], b[0], b[1])

def _reference_schema_hash(self, schema: Dict[str, Any]) -> str:
md5 = hashlib.md5()
hash_elms = []
for key in schema.keys():
if key == "description":
hash_elms.append(schema[key])
if key == "type":
hash_elms.append(schema[key])
if key == "allOf":
for item in schema[key]:
hash_elms.append(str(item))

hash_elms.append(key)

hash_elms.sort()
md5.update(";".join(hash_elms).encode("utf-8"))
return md5.hexdigest()
22 changes: 22 additions & 0 deletions openapi_python_client/resolver/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import yaml

from .resolver_types import SchemaData


class DataLoader:
@classmethod
def load(cls, path: str, data: bytes) -> SchemaData:
data_type = path.split(".")[-1].casefold()

if data_type == "json":
return cls.load_json(data)
else:
return cls.load_yaml(data)

@classmethod
def load_json(cls, data: bytes) -> SchemaData:
raise NotImplementedError()

@classmethod
def load_yaml(cls, data: bytes) -> SchemaData:
return yaml.safe_load(data)
48 changes: 48 additions & 0 deletions openapi_python_client/resolver/pointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import urllib.parse
from typing import List, Union


class Pointer:
""" https://tools.ietf.org/html/rfc6901 """

def __init__(self, pointer: str) -> None:
if pointer is None or pointer != "" and not pointer.startswith("/"):
raise ValueError(f'Invalid pointer value {pointer}, it must match: *( "/" reference-token )')

self._pointer = pointer

@property
def value(self) -> str:
return self._pointer

@property
def parent(self) -> Union["Pointer", None]:
tokens = self.tokens(False)

if len(tokens) > 1:
tokens.pop()
return Pointer("/".join(tokens))
else:
assert tokens[-1] == ""
return None

def tokens(self, unescape: bool = True) -> List[str]:
tokens = []

if unescape:
for token in self._pointer.split("/"):
tokens.append(self._unescape(token))
else:
tokens = self._pointer.split("/")

return tokens

@property
def unescapated_value(self) -> str:
return self._unescape(self._pointer)

def _unescape(self, data: str) -> str:
data = urllib.parse.unquote(data)
data = data.replace("~1", "/")
data = data.replace("~0", "~")
return data
68 changes: 68 additions & 0 deletions openapi_python_client/resolver/reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import urllib.parse
from pathlib import Path
from typing import Union

from .pointer import Pointer


class Reference:
""" https://tools.ietf.org/html/draft-pbryan-zyp-json-ref-03 """

def __init__(self, reference: str, parent: str = None):
self._ref = reference
self._parsed_ref = urllib.parse.urlparse(reference)
self._parent = parent

@property
def path(self) -> str:
return urllib.parse.urldefrag(self._parsed_ref.geturl()).url

@property
def abs_path(self) -> str:
if self._parent:
parent_dir = Path(self._parent)
abs_path = parent_dir.joinpath(self.path)
abs_path = abs_path.resolve()
return str(abs_path)
else:
return self.path

@property
def parent(self) -> Union[str, None]:
return self._parent

@property
def pointer(self) -> Pointer:
frag = self._parsed_ref.fragment
if self.is_url() and frag != "" and not frag.startswith("/"):
frag = f"/{frag}"

return Pointer(frag)

def is_relative(self) -> bool:
""" return True if reference path is a relative path """
return not self.is_absolute()

def is_absolute(self) -> bool:
""" return True is reference path is an absolute path """
return self._parsed_ref.netloc != ""

@property
def value(self) -> str:
return self._ref

def is_url(self) -> bool:
""" return True if the reference path is pointing to an external url location """
return self.is_remote() and self._parsed_ref.netloc != ""

def is_remote(self) -> bool:
""" return True if the reference pointer is pointing to a remote document """
return not self.is_local()

def is_local(self) -> bool:
""" return True if the reference pointer is pointing to the current document """
return self._parsed_ref.path == ""

def is_full_document(self) -> bool:
""" return True if the reference pointer is pointing to the whole document content """
return self.pointer.parent is None
Loading