From a5f3e4ae85cd2099e941d1b8bb05e7e1782f4165 Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Fri, 7 Aug 2020 19:53:00 +0200 Subject: [PATCH 1/5] PoC recursive schemas --- src/hypothesis_jsonschema/_canonicalise.py | 8 +- src/hypothesis_jsonschema/_from_schema.py | 88 ++++++++++++++++------ tests/test_from_schema.py | 33 ++++++++ 3 files changed, 102 insertions(+), 27 deletions(-) diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index 157100c..5d7ccc9 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -579,7 +579,12 @@ def resolve_all_refs( f"resolver={resolver} (type {type(resolver).__name__}) is not a RefResolver" ) - if "$ref" in schema: + def is_recursive(reference: str) -> bool: + return reference == "#" or resolver.resolution_scope == reference # type: ignore + + # To avoid infinite recursion, we skip all recursive definitions, and such references will be processed later + # A definition is recursive if it contains a reference to itself or one of its ancestors. + if "$ref" in schema and not is_recursive(schema["$ref"]): # type: ignore s = dict(schema) ref = s.pop("$ref") with resolver.resolving(ref) as got: @@ -590,7 +595,6 @@ def resolve_all_refs( msg = f"$ref:{ref!r} had incompatible base schema {s!r}" raise HypothesisRefResolutionError(msg) return resolve_all_refs(m, resolver=resolver) - assert "$ref" not in schema for key in SCHEMA_KEYS: val = schema.get(key, False) diff --git a/src/hypothesis_jsonschema/_from_schema.py b/src/hypothesis_jsonschema/_from_schema.py index 46cc3e1..aedbf75 100644 --- a/src/hypothesis_jsonschema/_from_schema.py +++ b/src/hypothesis_jsonschema/_from_schema.py @@ -4,6 +4,7 @@ import math import operator import re +from copy import deepcopy from fractions import Fraction from functools import partial from typing import Any, Callable, Dict, List, NoReturn, Optional, Set, Union @@ -18,6 +19,8 @@ TRUTHY, TYPE_STRINGS, HypothesisRefResolutionError, + JSONType, + LocalResolver, Schema, canonicalish, get_integer_bounds, @@ -42,11 +45,13 @@ def merged_as_strategies( - schemas: List[Schema], custom_formats: Optional[Dict[str, st.SearchStrategy[str]]] + schemas: List[Schema], + custom_formats: Optional[Dict[str, st.SearchStrategy[str]]], + resolver: LocalResolver, ) -> st.SearchStrategy[JSONType]: assert schemas, "internal error: must pass at least one schema to merge" if len(schemas) == 1: - return from_schema(schemas[0], custom_formats=custom_formats) + return from_schema(schemas[0], custom_formats=custom_formats, resolver=resolver) # Try to merge combinations of strategies. strats = [] combined: Set[str] = set() @@ -60,7 +65,7 @@ def merged_as_strategies( if s is not None and s != FALSEY: validators = [make_validator(s) for s in schemas] strats.append( - from_schema(s, custom_formats=custom_formats).filter( + from_schema(s, custom_formats=custom_formats, resolver=resolver).filter( lambda obj: all(v.is_valid(obj) for v in validators) ) ) @@ -72,6 +77,7 @@ def from_schema( schema: Union[bool, Schema], *, custom_formats: Dict[str, st.SearchStrategy[str]] = None, + resolver: Optional[LocalResolver] = None, ) -> st.SearchStrategy[JSONType]: """Take a JSON schema and return a strategy for allowed JSON objects. @@ -79,7 +85,7 @@ def from_schema( everything else in drafts 04, 05, and 07 is fully tested and working. """ try: - return __from_schema(schema, custom_formats=custom_formats) + return __from_schema(schema, custom_formats=custom_formats, resolver=resolver) except Exception as err: error = err @@ -112,9 +118,10 @@ def __from_schema( schema: Union[bool, Schema], *, custom_formats: Dict[str, st.SearchStrategy[str]] = None, + resolver: Optional[LocalResolver] = None, ) -> st.SearchStrategy[JSONType]: try: - schema = resolve_all_refs(schema) + schema = resolve_all_refs(schema, resolver=resolver) except RecursionError: raise HypothesisRefResolutionError( f"Could not resolve recursive references in schema={schema!r}" @@ -141,6 +148,9 @@ def __from_schema( } custom_formats[_FORMATS_TOKEN] = None # type: ignore + if resolver is None: + resolver = LocalResolver.from_schema(deepcopy(schema)) + schema = canonicalish(schema) # Boolean objects are special schemata; False rejects all and True accepts all. if schema == FALSEY: @@ -155,24 +165,36 @@ def __from_schema( assert isinstance(schema, dict) # Now we handle as many validation keywords as we can... + if "$ref" in schema: + ref = schema["$ref"] + + def _recurse() -> st.SearchStrategy[JSONType]: + _, resolved = resolver.resolve(ref) # type: ignore + return from_schema( + resolved, custom_formats=custom_formats, resolver=resolver + ) + + return st.deferred(_recurse) # Applying subschemata with boolean logic if "not" in schema: not_ = schema.pop("not") assert isinstance(not_, dict) validator = make_validator(not_).is_valid - return from_schema(schema, custom_formats=custom_formats).filter( - lambda v: not validator(v) - ) + return from_schema( + schema, custom_formats=custom_formats, resolver=resolver + ).filter(lambda v: not validator(v)) if "anyOf" in schema: tmp = schema.copy() ao = tmp.pop("anyOf") assert isinstance(ao, list) - return st.one_of([merged_as_strategies([tmp, s], custom_formats) for s in ao]) + return st.one_of( + [merged_as_strategies([tmp, s], custom_formats, resolver) for s in ao] + ) if "allOf" in schema: tmp = schema.copy() ao = tmp.pop("allOf") assert isinstance(ao, list) - return merged_as_strategies([tmp] + ao, custom_formats) + return merged_as_strategies([tmp] + ao, custom_formats, resolver) if "oneOf" in schema: tmp = schema.copy() oo = tmp.pop("oneOf") @@ -180,7 +202,7 @@ def __from_schema( schemas = [merged([tmp, s]) for s in oo] return st.one_of( [ - from_schema(s, custom_formats=custom_formats) + from_schema(s, custom_formats=custom_formats, resolver=resolver) for s in schemas if s is not None ] @@ -198,8 +220,8 @@ def __from_schema( "number": number_schema, "integer": integer_schema, "string": partial(string_schema, custom_formats), - "array": partial(array_schema, custom_formats), - "object": partial(object_schema, custom_formats), + "array": partial(array_schema, custom_formats, resolver), + "object": partial(object_schema, custom_formats, resolver), } assert set(map_) == set(TYPE_STRINGS) return st.one_of([map_[t](schema) for t in get_type(schema)]) @@ -422,10 +444,14 @@ def string_schema( def array_schema( - custom_formats: Dict[str, st.SearchStrategy[str]], schema: dict + custom_formats: Dict[str, st.SearchStrategy[str]], + resolver: LocalResolver, + schema: dict, ) -> st.SearchStrategy[List[JSONType]]: """Handle schemata for arrays.""" - _from_schema_ = partial(from_schema, custom_formats=custom_formats) + _from_schema_ = partial( + from_schema, custom_formats=custom_formats, resolver=resolver + ) items = schema.get("items", {}) additional_items = schema.get("additionalItems", {}) min_size = schema.get("minItems", 0) @@ -436,14 +462,16 @@ def array_schema( if max_size is not None: max_size -= len(items) - items_strats = [_from_schema_(s) for s in items] + items_strats = [_from_schema_(s) for s in deepcopy(items)] additional_items_strat = _from_schema_(additional_items) # If we have a contains schema to satisfy, we try generating from it when # allowed to do so. We'll skip the None (unmergable / no contains) cases # below, and let Hypothesis ignore the FALSEY cases for us. if "contains" in schema: - for i, mrgd in enumerate(merged([schema["contains"], s]) for s in items): + for i, mrgd in enumerate( + merged([schema["contains"], s]) for s in deepcopy(items) + ): if mrgd is not None: items_strats[i] |= _from_schema_(mrgd) contains_additional = merged([schema["contains"], additional_items]) @@ -480,10 +508,10 @@ def not_seen(elem: JSONType) -> bool: st.lists(additional_items_strat, min_size=min_size, max_size=max_size), ) else: - items_strat = _from_schema_(items) + items_strat = _from_schema_(deepcopy(items)) if "contains" in schema: contains_strat = _from_schema_(schema["contains"]) - if merged([items, schema["contains"]]) != schema["contains"]: + if merged([deepcopy(items), schema["contains"]]) != schema["contains"]: # We only need this filter if we couldn't merge items in when # canonicalising. Note that for list-items, above, we just skip # the mixed generation in this case (because they tend to be @@ -504,7 +532,9 @@ def not_seen(elem: JSONType) -> bool: def object_schema( - custom_formats: Dict[str, st.SearchStrategy[str]], schema: dict + custom_formats: Dict[str, st.SearchStrategy[str]], + resolver: LocalResolver, + schema: dict, ) -> st.SearchStrategy[Dict[str, JSONType]]: """Handle a manageable subset of possible schemata for objects.""" required = schema.get("required", []) # required keys @@ -518,7 +548,7 @@ def object_schema( return st.builds(dict) names["type"] = "string" - properties = schema.get("properties", {}) # exact name: value schema + properties = deepcopy(schema.get("properties", {})) # exact name: value schema patterns = schema.get("patternProperties", {}) # regex for names: value schema # schema for other values; handled specially if nothing matches additional = schema.get("additionalProperties", {}) @@ -533,7 +563,7 @@ def object_schema( st.sampled_from(sorted(dep_names) + sorted(dep_schemas) + sorted(properties)) if (dep_names or dep_schemas or properties) else st.nothing(), - from_schema(names, custom_formats=custom_formats) + from_schema(names, custom_formats=custom_formats, resolver=resolver) if additional_allowed else st.nothing(), st.one_of([st.from_regex(p) for p in sorted(patterns)]), @@ -579,12 +609,20 @@ def from_object_schema(draw: Any) -> Any: if re.search(rgx, string=key) is not None ] if key in properties: - pattern_schemas.insert(0, properties[key]) + pattern_schemas.insert(0, deepcopy(properties[key])) if pattern_schemas: - out[key] = draw(merged_as_strategies(pattern_schemas, custom_formats)) + out[key] = draw( + merged_as_strategies(pattern_schemas, custom_formats, resolver) + ) else: - out[key] = draw(from_schema(additional, custom_formats=custom_formats)) + out[key] = draw( + from_schema( + deepcopy(additional), + custom_formats=custom_formats, + resolver=resolver, + ) + ) for k, v in dep_schemas.items(): if k in out and not make_validator(v).is_valid(out): diff --git a/tests/test_from_schema.py b/tests/test_from_schema.py index 4a6e246..2959432 100644 --- a/tests/test_from_schema.py +++ b/tests/test_from_schema.py @@ -425,3 +425,36 @@ def test_allowed_custom_format(num): def test_allowed_unknown_custom_format(string): assert string == "hello world" assert "not registered" not in jsonschema.FormatChecker().checkers + + +@pytest.mark.parametrize( + "schema", + ( + { + "properties": {"foo": {"$ref": "#"}}, + "additionalProperties": False, + "type": "object", + }, + { + "definitions": { + "Node": { + "type": "object", + "properties": { + "children": { + "type": "array", + "items": {"$ref": "#/definitions/Node"}, + "maxItems": 2, + } + }, + "required": ["children"], + "additionalProperties": False, + }, + }, + "$ref": "#/definitions/Node", + }, + ), +) +@given(data=st.data()) +def test_recursive_reference(data, schema): + value = data.draw(from_schema(schema)) + jsonschema.validate(value, schema) From f06c423865dc02adc1f51728321b49286a222298 Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Sun, 9 Aug 2020 12:31:09 +0200 Subject: [PATCH 2/5] Further improvements --- src/hypothesis_jsonschema/_canonicalise.py | 6 +- src/hypothesis_jsonschema/_from_schema.py | 20 +++---- tests/test_from_schema.py | 70 ++++++++++++++++++++++ 3 files changed, 82 insertions(+), 14 deletions(-) diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index 5d7ccc9..85c1dc6 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -580,7 +580,7 @@ def resolve_all_refs( ) def is_recursive(reference: str) -> bool: - return reference == "#" or resolver.resolution_scope == reference # type: ignore + return reference == "#" or reference in resolver._scopes_stack # type: ignore # To avoid infinite recursion, we skip all recursive definitions, and such references will be processed later # A definition is recursive if it contains a reference to itself or one of its ancestors. @@ -612,7 +612,9 @@ def is_recursive(reference: str) -> bool: subschema = schema[key] assert isinstance(subschema, dict) schema[key] = { - k: resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v + k: resolve_all_refs(deepcopy(v), resolver=resolver) + if isinstance(v, dict) + else v for k, v in subschema.items() } assert isinstance(schema, dict) diff --git a/src/hypothesis_jsonschema/_from_schema.py b/src/hypothesis_jsonschema/_from_schema.py index aedbf75..1f5712f 100644 --- a/src/hypothesis_jsonschema/_from_schema.py +++ b/src/hypothesis_jsonschema/_from_schema.py @@ -171,7 +171,7 @@ def __from_schema( def _recurse() -> st.SearchStrategy[JSONType]: _, resolved = resolver.resolve(ref) # type: ignore return from_schema( - resolved, custom_formats=custom_formats, resolver=resolver + deepcopy(resolved), custom_formats=custom_formats, resolver=resolver ) return st.deferred(_recurse) @@ -462,16 +462,14 @@ def array_schema( if max_size is not None: max_size -= len(items) - items_strats = [_from_schema_(s) for s in deepcopy(items)] + items_strats = [_from_schema_(s) for s in items] additional_items_strat = _from_schema_(additional_items) # If we have a contains schema to satisfy, we try generating from it when # allowed to do so. We'll skip the None (unmergable / no contains) cases # below, and let Hypothesis ignore the FALSEY cases for us. if "contains" in schema: - for i, mrgd in enumerate( - merged([schema["contains"], s]) for s in deepcopy(items) - ): + for i, mrgd in enumerate(merged([schema["contains"], s]) for s in items): if mrgd is not None: items_strats[i] |= _from_schema_(mrgd) contains_additional = merged([schema["contains"], additional_items]) @@ -508,10 +506,10 @@ def not_seen(elem: JSONType) -> bool: st.lists(additional_items_strat, min_size=min_size, max_size=max_size), ) else: - items_strat = _from_schema_(deepcopy(items)) + items_strat = _from_schema_(items) if "contains" in schema: contains_strat = _from_schema_(schema["contains"]) - if merged([deepcopy(items), schema["contains"]]) != schema["contains"]: + if merged([items, schema["contains"]]) != schema["contains"]: # We only need this filter if we couldn't merge items in when # canonicalising. Note that for list-items, above, we just skip # the mixed generation in this case (because they tend to be @@ -548,7 +546,7 @@ def object_schema( return st.builds(dict) names["type"] = "string" - properties = deepcopy(schema.get("properties", {})) # exact name: value schema + properties = schema.get("properties", {}) # exact name: value schema patterns = schema.get("patternProperties", {}) # regex for names: value schema # schema for other values; handled specially if nothing matches additional = schema.get("additionalProperties", {}) @@ -609,7 +607,7 @@ def from_object_schema(draw: Any) -> Any: if re.search(rgx, string=key) is not None ] if key in properties: - pattern_schemas.insert(0, deepcopy(properties[key])) + pattern_schemas.insert(0, properties[key]) if pattern_schemas: out[key] = draw( @@ -618,9 +616,7 @@ def from_object_schema(draw: Any) -> Any: else: out[key] = draw( from_schema( - deepcopy(additional), - custom_formats=custom_formats, - resolver=resolver, + additional, custom_formats=custom_formats, resolver=resolver, ) ) diff --git a/tests/test_from_schema.py b/tests/test_from_schema.py index 2959432..95bf3ca 100644 --- a/tests/test_from_schema.py +++ b/tests/test_from_schema.py @@ -452,9 +452,79 @@ def test_allowed_unknown_custom_format(string): }, "$ref": "#/definitions/Node", }, + # Simplified Open API schema + { + "type": "object", + "required": ["paths"], + "properties": {"paths": {"$ref": "#/definitions/Paths"}}, + "additionalProperties": False, + "definitions": { + "Schema": { + "type": "object", + "properties": {"items": {"$ref": "#/definitions/Schema"}}, + "additionalProperties": False, + }, + "MediaType": { + "type": "object", + "properties": {"schema": {"$ref": "#/definitions/Schema"}}, + "patternProperties": {"^x-": {}}, + "additionalProperties": False, + }, + "Paths": { + "type": "object", + "patternProperties": { + "^\\/": {"$ref": "#/definitions/PathItem"}, + "^x-": {}, + }, + "additionalProperties": False, + }, + "PathItem": { + "type": "object", + "properties": { + "parameters": { + "type": "array", + "items": {"$ref": "#/definitions/Parameter"}, + "uniqueItems": True, + }, + }, + "patternProperties": { + "^(get|put|post|delete|options|head|patch|trace)$": { + "$ref": "#/definitions/Operation" + }, + "^x-": {}, + }, + "additionalProperties": False, + }, + "Operation": { + "type": "object", + "required": ["responses"], + "properties": { + "parameters": { + "type": "array", + "items": {"$ref": "#/definitions/Parameter"}, + "uniqueItems": True, + }, + }, + "additionalProperties": False, + }, + "Parameter": { + "type": "object", + "properties": { + "schema": {"$ref": "#/definitions/Schema"}, + "content": { + "type": "object", + "minProperties": 1, + "maxProperties": 1, + }, + }, + "additionalProperties": False, + }, + }, + }, ), ) @given(data=st.data()) +@settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much]) def test_recursive_reference(data, schema): value = data.draw(from_schema(schema)) jsonschema.validate(value, schema) From f03d92a9dbc2fd07451462c161799ed096a5ab36 Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Tue, 25 Aug 2020 17:03:18 +0200 Subject: [PATCH 3/5] Copy data before recursing --- src/hypothesis_jsonschema/_canonicalise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index 85c1dc6..64caf38 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -600,11 +600,11 @@ def is_recursive(reference: str) -> bool: val = schema.get(key, False) if isinstance(val, list): schema[key] = [ - resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v + resolve_all_refs(deepcopy(v), resolver=resolver) if isinstance(v, dict) else v for v in val ] elif isinstance(val, dict): - schema[key] = resolve_all_refs(val, resolver=resolver) + schema[key] = resolve_all_refs(deepcopy(val), resolver=resolver) else: assert isinstance(val, bool) for key in SCHEMA_OBJECT_KEYS: # values are keys-to-schema-dicts, not schemas From 448e36ab9455785b8999be21326c38818229cb26 Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Sun, 30 Aug 2020 16:37:26 +0200 Subject: [PATCH 4/5] Pass resolver where it is needed --- src/hypothesis_jsonschema/_canonicalise.py | 99 +++++++++++++--------- src/hypothesis_jsonschema/_from_schema.py | 72 +++++++++------- tests/test_canonicalise.py | 2 +- tests/test_from_schema.py | 7 +- 4 files changed, 102 insertions(+), 78 deletions(-) diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index 64caf38..117c8ea 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -68,6 +68,13 @@ def next_down(val: float) -> float: return out +class LocalResolver(jsonschema.RefResolver): + def resolve_remote(self, uri: str) -> NoReturn: + raise HypothesisRefResolutionError( + f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})" + ) + + def _get_validator_class(schema: Schema) -> JSONSchemaValidator: try: validator = jsonschema.validators.validator_for(schema) @@ -202,7 +209,9 @@ def get_integer_bounds(schema: Schema) -> Tuple[Optional[int], Optional[int]]: return lower, upper -def canonicalish(schema: JSONType) -> Dict[str, Any]: +def canonicalish( + schema: JSONType, resolver: Optional[LocalResolver] = None +) -> Dict[str, Any]: """Convert a schema into a more-canonical form. This is obviously incomplete, but improves best-effort recognition of @@ -224,12 +233,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: "but expected a dict." ) + if resolver is None: + resolver = LocalResolver.from_schema(deepcopy(schema)) + if "const" in schema: - if not make_validator(schema).is_valid(schema["const"]): + if not make_validator(schema, resolver=resolver).is_valid(schema["const"]): return FALSEY return {"const": schema["const"]} if "enum" in schema: - validator = make_validator(schema) + validator = make_validator(schema, resolver=resolver) enum_ = sorted( (v for v in schema["enum"] if validator.is_valid(v)), key=sort_key ) @@ -253,15 +265,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: # Recurse into the value of each keyword with a schema (or list of them) as a value for key in SCHEMA_KEYS: if isinstance(schema.get(key), list): - schema[key] = [canonicalish(v) for v in schema[key]] + schema[key] = [canonicalish(v, resolver=resolver) for v in schema[key]] elif isinstance(schema.get(key), (bool, dict)): - schema[key] = canonicalish(schema[key]) + schema[key] = canonicalish(schema[key], resolver=resolver) else: assert key not in schema, (key, schema[key]) for key in SCHEMA_OBJECT_KEYS: if key in schema: schema[key] = { - k: v if isinstance(v, list) else canonicalish(v) + k: v if isinstance(v, list) else canonicalish(v, resolver=resolver) for k, v in schema[key].items() } @@ -307,7 +319,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: if "array" in type_ and "contains" in schema: if isinstance(schema.get("items"), dict): - contains_items = merged([schema["contains"], schema["items"]]) + contains_items = merged( + [schema["contains"], schema["items"]], resolver=resolver + ) if contains_items is not None: schema["contains"] = contains_items @@ -432,7 +446,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: type_.remove("object") else: propnames = schema.get("propertyNames", {}) - validator = make_validator(propnames) + validator = make_validator(propnames, resolver=resolver) if not all(validator.is_valid(name) for name in schema["required"]): type_.remove("object") @@ -461,9 +475,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: type_.remove(t) if t not in ("integer", "number"): not_["type"].remove(t) - not_ = canonicalish(not_) + not_ = canonicalish(not_, resolver=resolver) - m = merged([not_, {**schema, "type": type_}]) + m = merged([not_, {**schema, "type": type_}], resolver=resolver) if m is not None: not_ = m if not_ != FALSEY: @@ -525,7 +539,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: else: tmp = schema.copy() ao = tmp.pop("allOf") - out = merged([tmp] + ao) + out = merged([tmp] + ao, resolver=resolver) if isinstance(out, dict): # pragma: no branch schema = out # TODO: this assertion is soley because mypy 0.750 doesn't know @@ -537,7 +551,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: one_of = sorted(one_of, key=encode_canonical_json) one_of = [s for s in one_of if s != FALSEY] if len(one_of) == 1: - m = merged([schema, one_of[0]]) + m = merged([schema, one_of[0]], resolver=resolver) if m is not None: # pragma: no branch return m if (not one_of) or one_of.count(TRUTHY) > 1: @@ -552,13 +566,6 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: FALSEY = canonicalish(False) -class LocalResolver(jsonschema.RefResolver): - def resolve_remote(self, uri: str) -> NoReturn: - raise HypothesisRefResolutionError( - f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})" - ) - - def resolve_all_refs( schema: Union[bool, Schema], *, resolver: LocalResolver = None ) -> Schema: @@ -590,7 +597,7 @@ def is_recursive(reference: str) -> bool: with resolver.resolving(ref) as got: if s == {}: return resolve_all_refs(got, resolver=resolver) - m = merged([s, got]) + m = merged([s, got], resolver=resolver) if m is None: # pragma: no cover msg = f"$ref:{ref!r} had incompatible base schema {s!r}" raise HypothesisRefResolutionError(msg) @@ -600,7 +607,9 @@ def is_recursive(reference: str) -> bool: val = schema.get(key, False) if isinstance(val, list): schema[key] = [ - resolve_all_refs(deepcopy(v), resolver=resolver) if isinstance(v, dict) else v + resolve_all_refs(deepcopy(v), resolver=resolver) + if isinstance(v, dict) + else v for v in val ] elif isinstance(val, dict): @@ -621,7 +630,9 @@ def is_recursive(reference: str) -> bool: return schema -def merged(schemas: List[Any]) -> Optional[Schema]: +def merged( + schemas: List[Any], resolver: Optional[LocalResolver] = None +) -> Optional[Schema]: """Merge *n* schemas into a single schema, or None if result is invalid. Takes the logical intersection, so any object that validates against the returned @@ -634,7 +645,9 @@ def merged(schemas: List[Any]) -> Optional[Schema]: It's currently also used for keys that could be merged but aren't yet. """ assert schemas, "internal error: must pass at least one schema to merge" - schemas = sorted((canonicalish(s) for s in schemas), key=upper_bound_instances) + schemas = sorted( + (canonicalish(s, resolver=resolver) for s in schemas), key=upper_bound_instances + ) if any(s == FALSEY for s in schemas): return FALSEY out = schemas[0] @@ -643,11 +656,11 @@ def merged(schemas: List[Any]) -> Optional[Schema]: continue # If we have a const or enum, this is fairly easy by filtering: if "const" in out: - if make_validator(s).is_valid(out["const"]): + if make_validator(s, resolver=resolver).is_valid(out["const"]): continue return FALSEY if "enum" in out: - validator = make_validator(s) + validator = make_validator(s, resolver=resolver) enum_ = [v for v in out["enum"] if validator.is_valid(v)] if not enum_: return FALSEY @@ -698,21 +711,23 @@ def merged(schemas: List[Any]) -> Optional[Schema]: else: out_combined = merged( [s for p, s in out_pat.items() if re.search(p, prop_name)] - or [out_add] + or [out_add], + resolver=resolver, ) if prop_name in s_props: s_combined = s_props[prop_name] else: s_combined = merged( [s for p, s in s_pat.items() if re.search(p, prop_name)] - or [s_add] + or [s_add], + resolver=resolver, ) if out_combined is None or s_combined is None: # pragma: no cover # Note that this can only be the case if we were actually going to # use the schema which we attempted to merge, i.e. prop_name was # not in the schema and there were unmergable pattern schemas. return None - m = merged([out_combined, s_combined]) + m = merged([out_combined, s_combined], resolver=resolver) if m is None: return None out_props[prop_name] = m @@ -720,14 +735,17 @@ def merged(schemas: List[Any]) -> Optional[Schema]: # simpler as we merge with either an identical pattern, or additionalProperties. if out_pat or s_pat: for pattern in set(out_pat) | set(s_pat): - m = merged([out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)]) + m = merged( + [out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)], + resolver=resolver, + ) if m is None: # pragma: no cover return None out_pat[pattern] = m out["patternProperties"] = out_pat # Finally, we merge togther the additionalProperties schemas. if out_add or s_add: - m = merged([out_add, s_add]) + m = merged([out_add, s_add], resolver=resolver) if m is None: # pragma: no cover return None out["additionalProperties"] = m @@ -761,7 +779,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: return None if "contains" in out and "contains" in s and out["contains"] != s["contains"]: # If one `contains` schema is a subset of the other, we can discard it. - m = merged([out["contains"], s["contains"]]) + m = merged([out["contains"], s["contains"]], resolver=resolver) if m == out["contains"] or m == s["contains"]: out["contains"] = m s.pop("contains") @@ -791,7 +809,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: v = {"required": v} elif isinstance(sval, list): sval = {"required": sval} - m = merged([v, sval]) + m = merged([v, sval], resolver=resolver) if m is None: return None odeps[k] = m @@ -805,26 +823,27 @@ def merged(schemas: List[Any]) -> Optional[Schema]: [ out.get("additionalItems", TRUTHY), s.get("additionalItems", TRUTHY), - ] + ], + resolver=resolver, ) for a, b in itertools.zip_longest(oitems, sitems): if a is None: a = out.get("additionalItems", TRUTHY) elif b is None: b = s.get("additionalItems", TRUTHY) - out["items"].append(merged([a, b])) + out["items"].append(merged([a, b], resolver=resolver)) elif isinstance(oitems, list): - out["items"] = [merged([x, sitems]) for x in oitems] + out["items"] = [merged([x, sitems], resolver=resolver) for x in oitems] out["additionalItems"] = merged( - [out.get("additionalItems", TRUTHY), sitems] + [out.get("additionalItems", TRUTHY), sitems], resolver=resolver ) elif isinstance(sitems, list): - out["items"] = [merged([x, oitems]) for x in sitems] + out["items"] = [merged([x, oitems], resolver=resolver) for x in sitems] out["additionalItems"] = merged( - [s.get("additionalItems", TRUTHY), oitems] + [s.get("additionalItems", TRUTHY), oitems], resolver=resolver ) else: - out["items"] = merged([oitems, sitems]) + out["items"] = merged([oitems, sitems], resolver=resolver) if out["items"] is None: return None if isinstance(out["items"], list) and None in out["items"]: @@ -848,7 +867,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: # If non-validation keys like `title` or `description` don't match, # that doesn't really matter and we'll just go with first we saw. return None - out = canonicalish(out) + out = canonicalish(out, resolver=resolver) if out == FALSEY: return FALSEY assert isinstance(out, dict) diff --git a/src/hypothesis_jsonschema/_from_schema.py b/src/hypothesis_jsonschema/_from_schema.py index 1f5712f..91a5c7e 100644 --- a/src/hypothesis_jsonschema/_from_schema.py +++ b/src/hypothesis_jsonschema/_from_schema.py @@ -18,7 +18,6 @@ FALSEY, TRUTHY, TYPE_STRINGS, - HypothesisRefResolutionError, JSONType, LocalResolver, Schema, @@ -61,9 +60,9 @@ def merged_as_strategies( ): if combined.issuperset(group): continue - s = merged([inputs[g] for g in group]) + s = merged([inputs[g] for g in group], resolver=resolver) if s is not None and s != FALSEY: - validators = [make_validator(s) for s in schemas] + validators = [make_validator(s, resolver=resolver) for s in schemas] strats.append( from_schema(s, custom_formats=custom_formats, resolver=resolver).filter( lambda obj: all(v.is_valid(obj) for v in validators) @@ -120,12 +119,7 @@ def __from_schema( custom_formats: Dict[str, st.SearchStrategy[str]] = None, resolver: Optional[LocalResolver] = None, ) -> st.SearchStrategy[JSONType]: - try: - schema = resolve_all_refs(schema, resolver=resolver) - except RecursionError: - raise HypothesisRefResolutionError( - f"Could not resolve recursive references in schema={schema!r}" - ) from None + schema = resolve_all_refs(schema, resolver=resolver) # We check for _FORMATS_TOKEN to avoid re-validating known good data. if custom_formats is not None and _FORMATS_TOKEN not in custom_formats: assert isinstance(custom_formats, dict) @@ -151,7 +145,7 @@ def __from_schema( if resolver is None: resolver = LocalResolver.from_schema(deepcopy(schema)) - schema = canonicalish(schema) + schema = canonicalish(schema, resolver) # Boolean objects are special schemata; False rejects all and True accepts all. if schema == FALSEY: return st.nothing() @@ -169,8 +163,9 @@ def __from_schema( ref = schema["$ref"] def _recurse() -> st.SearchStrategy[JSONType]: - _, resolved = resolver.resolve(ref) # type: ignore - return from_schema( + url, resolved = resolver.resolve(ref) # type: ignore + resolver.push_scope(url) # type: ignore + return __from_schema( deepcopy(resolved), custom_formats=custom_formats, resolver=resolver ) @@ -179,7 +174,7 @@ def _recurse() -> st.SearchStrategy[JSONType]: if "not" in schema: not_ = schema.pop("not") assert isinstance(not_, dict) - validator = make_validator(not_).is_valid + validator = make_validator(not_, resolver=resolver).is_valid return from_schema( schema, custom_formats=custom_formats, resolver=resolver ).filter(lambda v: not validator(v)) @@ -199,14 +194,14 @@ def _recurse() -> st.SearchStrategy[JSONType]: tmp = schema.copy() oo = tmp.pop("oneOf") assert isinstance(oo, list) - schemas = [merged([tmp, s]) for s in oo] + schemas = [merged([tmp, s], resolver=resolver) for s in oo] return st.one_of( [ from_schema(s, custom_formats=custom_formats, resolver=resolver) for s in schemas if s is not None ] - ).filter(make_validator(schema).is_valid) + ).filter(make_validator(schema, resolver=resolver).is_valid) # Simple special cases if "enum" in schema: assert schema["enum"], "Canonicalises to non-empty list or FALSEY" @@ -217,8 +212,9 @@ def _recurse() -> st.SearchStrategy[JSONType]: map_: Dict[str, Callable[[Schema], st.SearchStrategy[JSONType]]] = { "null": lambda _: st.none(), "boolean": lambda _: st.booleans(), - "number": number_schema, - "integer": integer_schema, + # Mypy doesn't recognize that `resolver` has the `LocalResolver` type + "number": lambda s: number_schema(s, resolver=resolver), # type: ignore + "integer": lambda s: integer_schema(s, resolver=resolver), # type: ignore "string": partial(string_schema, custom_formats), "array": partial(array_schema, custom_formats, resolver), "object": partial(object_schema, custom_formats, resolver), @@ -228,7 +224,10 @@ def _recurse() -> st.SearchStrategy[JSONType]: def _numeric_with_multiplier( - min_value: Optional[float], max_value: Optional[float], schema: Schema + min_value: Optional[float], + max_value: Optional[float], + schema: Schema, + resolver: LocalResolver, ) -> st.SearchStrategy[float]: """Handle numeric schemata containing the multipleOf key.""" multiple_of = schema["multipleOf"] @@ -246,23 +245,23 @@ def _numeric_with_multiplier( return ( st.integers(min_value, max_value) .map(lambda x: x * multiple_of) - .filter(make_validator(schema).is_valid) + .filter(make_validator(schema, resolver=resolver).is_valid) ) -def integer_schema(schema: dict) -> st.SearchStrategy[float]: +def integer_schema(schema: dict, resolver: LocalResolver) -> st.SearchStrategy[float]: """Handle integer schemata.""" min_value, max_value = get_integer_bounds(schema) if "multipleOf" in schema: - return _numeric_with_multiplier(min_value, max_value, schema) + return _numeric_with_multiplier(min_value, max_value, schema, resolver) return st.integers(min_value, max_value) -def number_schema(schema: dict) -> st.SearchStrategy[float]: +def number_schema(schema: dict, resolver: LocalResolver) -> st.SearchStrategy[float]: """Handle numeric schemata.""" min_value, max_value, exclude_min, exclude_max = get_number_bounds(schema) if "multipleOf" in schema: - return _numeric_with_multiplier(min_value, max_value, schema) + return _numeric_with_multiplier(min_value, max_value, schema, resolver) return st.floats( min_value=min_value, max_value=max_value, @@ -469,10 +468,14 @@ def array_schema( # allowed to do so. We'll skip the None (unmergable / no contains) cases # below, and let Hypothesis ignore the FALSEY cases for us. if "contains" in schema: - for i, mrgd in enumerate(merged([schema["contains"], s]) for s in items): + for i, mrgd in enumerate( + merged([schema["contains"], s], resolver=resolver) for s in items + ): if mrgd is not None: items_strats[i] |= _from_schema_(mrgd) - contains_additional = merged([schema["contains"], additional_items]) + contains_additional = merged( + [schema["contains"], additional_items], resolver=resolver + ) if contains_additional is not None: additional_items_strat |= _from_schema_(contains_additional) @@ -509,12 +512,17 @@ def not_seen(elem: JSONType) -> bool: items_strat = _from_schema_(items) if "contains" in schema: contains_strat = _from_schema_(schema["contains"]) - if merged([items, schema["contains"]]) != schema["contains"]: + if ( + merged([items, schema["contains"]], resolver=resolver) + != schema["contains"] + ): # We only need this filter if we couldn't merge items in when # canonicalising. Note that for list-items, above, we just skip # the mixed generation in this case (because they tend to be # heterogeneous) and hope it works out anyway. - contains_strat = contains_strat.filter(make_validator(items).is_valid) + contains_strat = contains_strat.filter( + make_validator(items, resolver=resolver).is_valid + ) items_strat |= contains_strat strat = st.lists( @@ -525,7 +533,7 @@ def not_seen(elem: JSONType) -> bool: ) if "contains" not in schema: return strat - contains = make_validator(schema["contains"]).is_valid + contains = make_validator(schema["contains"], resolver=resolver).is_valid return strat.filter(lambda val: any(contains(x) for x in val)) @@ -567,7 +575,7 @@ def object_schema( st.one_of([st.from_regex(p) for p in sorted(patterns)]), ) all_names_strategy = st.one_of([s for s in name_strats if not s.is_empty]).filter( - make_validator(names).is_valid + make_validator(names, resolver=resolver).is_valid ) @st.composite # type: ignore @@ -611,7 +619,9 @@ def from_object_schema(draw: Any) -> Any: if pattern_schemas: out[key] = draw( - merged_as_strategies(pattern_schemas, custom_formats, resolver) + merged_as_strategies( + pattern_schemas, custom_formats, resolver=resolver + ) ) else: out[key] = draw( @@ -621,7 +631,7 @@ def from_object_schema(draw: Any) -> Any: ) for k, v in dep_schemas.items(): - if k in out and not make_validator(v).is_valid(out): + if k in out and not make_validator(v, resolver=resolver).is_valid(out): out.pop(key) elements.reject() diff --git a/tests/test_canonicalise.py b/tests/test_canonicalise.py index 45f4a03..6d0376e 100644 --- a/tests/test_canonicalise.py +++ b/tests/test_canonicalise.py @@ -20,7 +20,7 @@ def is_valid(instance, schema): - return make_validator(schema).is_valid(instance) + return make_validator(schema, resolver=None).is_valid(instance) @settings(suppress_health_check=[HealthCheck.too_slow], deadline=None) diff --git a/tests/test_from_schema.py b/tests/test_from_schema.py index 95bf3ca..f2b9dd8 100644 --- a/tests/test_from_schema.py +++ b/tests/test_from_schema.py @@ -242,16 +242,11 @@ def inner(*args, **kwargs): assert isinstance(name, str) try: f(*args, **kwargs) - assert name not in RECURSIVE_REFS except jsonschema.exceptions.RefResolutionError as err: if ( isinstance(err, HypothesisRefResolutionError) or isinstance(err._cause, HypothesisRefResolutionError) - ) and ( - "does not fetch remote references" in str(err) - or name in RECURSIVE_REFS - and "Could not resolve recursive references" in str(err) - ): + ) and "does not fetch remote references" in str(err): pytest.xfail() raise From c8f78f97b0135fe29ade2aecd17996be19209004 Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Sun, 30 Aug 2020 18:53:02 +0200 Subject: [PATCH 5/5] Add seen_map --- src/hypothesis_jsonschema/_canonicalise.py | 60 ++++++++++++---------- src/hypothesis_jsonschema/_from_schema.py | 1 - 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index 117c8ea..93175f3 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -17,7 +17,8 @@ import math import re from copy import deepcopy -from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union +from typing import Any, Dict, List, NoReturn, Optional, Set, Tuple, Union +from urllib.parse import urljoin import jsonschema from hypothesis.errors import InvalidArgument @@ -85,9 +86,9 @@ def _get_validator_class(schema: Schema) -> JSONSchemaValidator: return validator -def make_validator(schema: Schema) -> JSONSchemaValidator: +def make_validator(schema: Schema, resolver: LocalResolver) -> JSONSchemaValidator: validator = _get_validator_class(schema) - return validator(schema) + return validator(schema, resolver=resolver) class HypothesisRefResolutionError(jsonschema.exceptions.RefResolutionError): @@ -567,15 +568,14 @@ def canonicalish( def resolve_all_refs( - schema: Union[bool, Schema], *, resolver: LocalResolver = None + schema: Union[bool, Schema], + *, + resolver: LocalResolver = None, + seen_map: Dict[str, Set[str]] = None, ) -> Schema: - """ - Resolve all references in the given schema. - - This handles nested definitions, but not recursive definitions. - The latter require special handling to convert to strategies and are much - less common, so we just ignore them (and error out) for now. - """ + """Resolve all non-recursive references in the given schema.""" + if seen_map is None: + seen_map = {} if isinstance(schema, bool): return canonicalish(schema) assert isinstance(schema, dict), schema @@ -587,33 +587,41 @@ def resolve_all_refs( ) def is_recursive(reference: str) -> bool: - return reference == "#" or reference in resolver._scopes_stack # type: ignore + full_ref = urljoin(resolver.base_uri, reference) # type: ignore + return reference == "#" or reference in resolver._scopes_stack or full_ref in resolver._scopes_stack # type: ignore # To avoid infinite recursion, we skip all recursive definitions, and such references will be processed later # A definition is recursive if it contains a reference to itself or one of its ancestors. - if "$ref" in schema and not is_recursive(schema["$ref"]): # type: ignore - s = dict(schema) - ref = s.pop("$ref") - with resolver.resolving(ref) as got: - if s == {}: - return resolve_all_refs(got, resolver=resolver) - m = merged([s, got], resolver=resolver) - if m is None: # pragma: no cover - msg = f"$ref:{ref!r} had incompatible base schema {s!r}" - raise HypothesisRefResolutionError(msg) - return resolve_all_refs(m, resolver=resolver) + if "$ref" in schema: + path = "-".join(resolver._scopes_stack) + seen_paths = seen_map.setdefault(path, set()) + if schema["$ref"] not in seen_paths and not is_recursive(schema["$ref"]): # type: ignore + seen_paths.add(schema["$ref"]) # type: ignore + s = dict(schema) + ref = s.pop("$ref") + with resolver.resolving(ref) as got: + if s == {}: + return resolve_all_refs(got, resolver=resolver, seen_map=seen_map) + m = merged([s, got]) + if m is None: # pragma: no cover + msg = f"$ref:{ref!r} had incompatible base schema {s!r}" + raise HypothesisRefResolutionError(msg) + + return resolve_all_refs(m, resolver=resolver, seen_map=seen_map) for key in SCHEMA_KEYS: val = schema.get(key, False) if isinstance(val, list): schema[key] = [ - resolve_all_refs(deepcopy(v), resolver=resolver) + resolve_all_refs(deepcopy(v), resolver=resolver, seen_map=seen_map) if isinstance(v, dict) else v for v in val ] elif isinstance(val, dict): - schema[key] = resolve_all_refs(deepcopy(val), resolver=resolver) + schema[key] = resolve_all_refs( + deepcopy(val), resolver=resolver, seen_map=seen_map + ) else: assert isinstance(val, bool) for key in SCHEMA_OBJECT_KEYS: # values are keys-to-schema-dicts, not schemas @@ -621,7 +629,7 @@ def is_recursive(reference: str) -> bool: subschema = schema[key] assert isinstance(subschema, dict) schema[key] = { - k: resolve_all_refs(deepcopy(v), resolver=resolver) + k: resolve_all_refs(deepcopy(v), resolver=resolver, seen_map=seen_map) if isinstance(v, dict) else v for k, v in subschema.items() diff --git a/src/hypothesis_jsonschema/_from_schema.py b/src/hypothesis_jsonschema/_from_schema.py index 91a5c7e..658cb42 100644 --- a/src/hypothesis_jsonschema/_from_schema.py +++ b/src/hypothesis_jsonschema/_from_schema.py @@ -18,7 +18,6 @@ FALSEY, TRUTHY, TYPE_STRINGS, - JSONType, LocalResolver, Schema, canonicalish,