Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Enable strict schema enforcement #3514

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion ludwig/config_validation/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_schema(model_type: str = MODEL_ECD):
"title": "model_options",
"description": "Settings for Ludwig configuration",
"required": required,
"additionalProperties": True, # TODO: Set to false after 0.8 releases.
"additionalProperties": False, # TODO: May cause collision
}


Expand Down
1 change: 0 additions & 1 deletion ludwig/schema/features/augmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def get_augmentation_list_jsonschema(feature_type: str, default: List[Dict[str,
"description": "Type of augmentation to apply.",
},
},
"additionalProperties": True,
"allOf": get_augmentation_list_conds(feature_type),
"required": ["type"],
},
Expand Down
2 changes: 1 addition & 1 deletion ludwig/schema/features/preprocessing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _jsonschema_type_mapping(self):
"type": "object",
"properties": props,
"title": "preprocessing_options",
"additionalProperties": True,
"additionalProperties": False, # TODO: May cause collision
}

try:
Expand Down
2 changes: 0 additions & 2 deletions ludwig/schema/features/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def get_input_feature_jsonschema(model_type: str):
},
"column": {"type": "string", "title": "column", "description": "Name of the column."},
},
"additionalProperties": True,
"allOf": get_input_feature_conds(model_type),
"required": ["name", "type"],
"title": "input_feature",
Expand Down Expand Up @@ -126,7 +125,6 @@ def get_output_feature_jsonschema(model_type: str):
},
"column": {"type": "string", "title": "column", "description": "Name of the column."},
},
"additionalProperties": True,
"allOf": get_output_feature_conds(model_type),
"required": ["name", "type"],
"title": "output_feature",
Expand Down
4 changes: 0 additions & 4 deletions ludwig/schema/model_types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,6 @@ def from_dict(config: ModelConfigDict) -> "ModelConfig":
# have `additionalProperties=False`, does not.
#
# Illustrative example: test_validate_config_misc.py::test_validate_no_trainer_type
#
# TODO: Set `additionalProperties=False` for all Ludwig schema, and look into passing in `unknown='RAISE'` to
# marshmallow.load(), which raises an error for unknown fields during deserialization.
# https://marshmallow.readthedocs.io/en/stable/marshmallow.schema.html#marshmallow.schema.Schema.load
check_schema(config)

cls = model_type_schema_registry[model_type]
Expand Down
8 changes: 4 additions & 4 deletions ludwig/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import marshmallow_dataclass
import yaml
from marshmallow import EXCLUDE, fields, pre_load, schema, validate, ValidationError
from marshmallow import EXCLUDE, fields, pre_load, RAISE, schema, validate, ValidationError
from marshmallow.utils import missing
from marshmallow_dataclass import dataclass as m_dataclass
from marshmallow_jsonschema import JSONSchema as js
Expand Down Expand Up @@ -152,7 +152,7 @@ def to_list(self) -> TList:


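# Strict schemas are enforced by default (RAISE); the policy can be overridden via the environment variable read below.
# TODO: Update the schema descriptions to reflect strict enforcement.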
LUDWIG_SCHEMA_VALIDATION_POLICY_VAR = os.environ.get(LUDWIG_SCHEMA_VALIDATION_POLICY, EXCLUDE).lower()
LUDWIG_SCHEMA_VALIDATION_POLICY_VAR = os.environ.get(LUDWIG_SCHEMA_VALIDATION_POLICY, RAISE).lower()


@DeveloperAPI
Expand Down Expand Up @@ -227,7 +227,7 @@ def assert_is_a_marshmallow_class(cls):


@DeveloperAPI
def unload_jsonschema_from_marshmallow_class(mclass, additional_properties: bool = True, title: str = None) -> TDict:
def unload_jsonschema_from_marshmallow_class(mclass, additional_properties: bool = False, title: str = None) -> TDict:
"""Helper method to directly get a marshmallow class's JSON schema without extra wrapping props."""
assert_is_a_marshmallow_class(mclass)
schema = js(props_ordered=True).dump(mclass.Schema())["definitions"][mclass.__name__]
Expand All @@ -236,7 +236,7 @@ def unload_jsonschema_from_marshmallow_class(mclass, additional_properties: bool
prop_schema = schema["properties"][prop]
if "parameter_metadata" in prop_schema:
prop_schema["parameter_metadata"] = copy.deepcopy(prop_schema["parameter_metadata"])
schema["additionalProperties"] = additional_properties
schema["additionalProperties"] = additional_properties # TODO: May cause collision
if title is not None:
schema["title"] = title
return schema
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def category_feature(output_feature=False, **kwargs):
else:
feature.update(
{
ENCODER: {"vocab_size": 10, "embedding_size": 5},
ENCODER: {"embedding_size": 5},
}
)
recursive_update(feature, kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,12 @@ def test_decoder_descriptions():
def test_deprecation_warning_raised_for_unknown_parameters():
config = {
"input_features": [
category_feature(encoder={"type": "dense", "vocab_size": 2}, reduce_input="sum"),
category_feature(encoder={"type": "dense"}),
number_feature(),
],
"output_features": [binary_feature()],
"combiner": {
"type": "tabnet",
"unknown_parameter_combiner": False,
},
TRAINER: {
"epochs": 1000,
Expand Down