Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement SNS Filter/operators $or, suffix, equals-ignore-case, anything-but #10691

Merged
merged 4 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
176 changes: 113 additions & 63 deletions localstack/services/sns/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@ class SubscriptionFilter:
def check_filter_policy_on_message_attributes(
self, filter_policy: dict, message_attributes: dict
):
for criteria, conditions in filter_policy.items():
if not self._evaluate_filter_policy_conditions_on_attribute(
conditions,
message_attributes.get(criteria),
field_exists=criteria in message_attributes,
):
return False
if not filter_policy:
return True

return True
flat_policy_conditions = self.flatten_policy(filter_policy)

return any(
all(
self._evaluate_filter_policy_conditions_on_attribute(
conditions,
message_attributes.get(criteria),
field_exists=criteria in message_attributes,
)
for criteria, conditions in flat_policy.items()
)
for flat_policy in flat_policy_conditions
)

def check_filter_policy_on_message_body(self, filter_policy: dict, message_body: str):
try:
Expand Down Expand Up @@ -45,18 +52,26 @@ def _evaluate_nested_filter_policy_on_dict(self, filter_policy, payload: dict) -
:param payload: a dict, starting at the MessageBody
:return: True if the payload respect the filter policy, otherwise False
"""
flat_policy = self._flatten_dict(filter_policy)
flat_payloads = self._flatten_dict_with_list(payload)
for key, values in flat_policy.items():
if not any(
self._evaluate_condition(
flat_payload.get(key), condition, field_exists=key in flat_payload
if not filter_policy:
return True

# TODO: maybe save/cache the flattened/expanded policy?
flat_policy_conditions = self.flatten_policy(filter_policy)
flat_payloads = self.flatten_payload(payload)

return any(
all(
any(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤣 Ok, I understand why you want to keep it separate from the attributes policy...

self._evaluate_condition(
flat_payload.get(key), condition, field_exists=key in flat_payload
)
for condition in values
for flat_payload in flat_payloads
)
for condition in values
for flat_payload in flat_payloads
):
return False
return True
for key, values in flat_policy.items()
)
for flat_policy in flat_policy_conditions
)

def _evaluate_filter_policy_conditions_on_attribute(
self, conditions, attribute, field_exists: bool
Expand Down Expand Up @@ -94,11 +109,19 @@ def _evaluate_condition(self, value, condition, field_exists: bool):
# the remaining conditions require the value to not be None
return False
elif anything_but := condition.get("anything-but"):
# TODO: support with `prefix`
# https://docs.aws.amazon.com/sns/latest/dg/string-value-matching.html#string-anything-but-matching-prefix
return value not in anything_but
elif prefix := (condition.get("prefix")):
if isinstance(anything_but, dict):
not_prefix = anything_but.get("prefix")
return not value.startswith(not_prefix)
elif isinstance(anything_but, list):
return value not in anything_but
else:
return value != anything_but
elif prefix := condition.get("prefix"):
return value.startswith(prefix)
elif suffix := condition.get("suffix"):
return value.endswith(suffix)
elif equal_ignore_case := condition.get("equals-ignore-case"):
return equal_ignore_case.lower() == value.lower()
elif numeric_condition := condition.get("numeric"):
return self._evaluate_numeric_condition(numeric_condition, value)
return False
Expand Down Expand Up @@ -135,35 +158,59 @@ def _evaluate_numeric_condition(conditions, value):
return True

@staticmethod
def _flatten_dict(nested_dict: dict):
def flatten_policy(nested_dict: dict) -> list[dict]:
"""
Takes a dictionary as input and will output the dictionary on a single level.
Input:
`{"field1": {"field2: {"field3: "val1", "field4": "val2"}}}`
`{"field1": {"field2": {"field3": "val1", "field4": "val2"}}}`
Output:
`{
"field1.field2.field3": "val1",
"field1.field2.field4": "val1"
}`
`[
{
"field1.field2.field3": "val1",
"field1.field2.field4": "val2"
}
]`
Input with $or will create multiple outputs:
`{"$or": [{"field1": "val1"}, {"field2": "val2"}], "field3": "val3"}`
Output:
`[
{"field1": "val1", "field3": "val3"},
{"field2": "val2", "field3": "val3"}
]`
:param nested_dict: a (nested) dictionary
:return: a list of flattened dictionaries with no nested dict or list inside, flattened to a
single level, one list item for every list item encountered
"""
flatten = {}

def _traverse(_policy: dict, parent_key=None):
for key, values in _policy.items():
flattened_parent_key = key if not parent_key else f"{parent_key}.{key}"
if not isinstance(values, dict):
flatten[flattened_parent_key] = values
def _traverse_policy(obj, array=None, parent_key=None) -> list:
if array is None:
array = [{}]

for key, values in obj.items():
if key == "$or" and isinstance(values, list) and len(values) > 1:
# $or will create multiple new branches in the array.
# Each current branch will traverse with each choice in $or
array = [
i for value in values for i in _traverse_policy(value, array, parent_key)
]
else:
_traverse(values, parent_key=flattened_parent_key)
# We update the parent key do that {"key1": {"key2": ""}} becomes "key1.key2"
_parent_key = f"{parent_key}.{key}" if parent_key else key
if isinstance(values, dict):
# If the current key has child dict -- key: "key1", child: {"key2": ["val1", val2"]}
# We only update the parent_key and traverse its children with the current branches
array = _traverse_policy(values, array, _parent_key)
else:
# If the current key has no child, this means we found the values to match -- child: ["val1", val2"]
# we update the branches with the parent chain and the values -- {"key1.key2": ["val1, val2"]}
array = [{**item, _parent_key: values} for item in array]

_traverse(nested_dict)
return flatten
return array

return _traverse_policy(nested_dict)

@staticmethod
def _flatten_dict_with_list(nested_dict: dict) -> list[dict]:
def flatten_payload(nested_dict: dict) -> list[dict]:
"""
Takes a dictionary as input and will output the dictionary on a single level.
The dictionary can have lists containing other dictionaries, and one root level entry will be created for every
Expand All @@ -189,37 +236,22 @@ def _flatten_dict_with_list(nested_dict: dict) -> list[dict]:
:param nested_dict: a (nested) dictionary
:return: flatten_dict: a dictionary with no nested dict inside, flattened to a single level
"""
flattened = []
current_object = {}

def _traverse(_object, parent_key=None):
def _traverse(_object: dict, array=None, parent_key=None) -> list:
if isinstance(_object, dict):
for key, values in _object.items():
flattened_parent_key = key if not parent_key else f"{parent_key}.{key}"
_traverse(values, flattened_parent_key)
# We update the parent key do that {"key1": {"key2": ""}} becomes "key1.key2"
_parent_key = f"{parent_key}.{key}" if parent_key else key
array = _traverse(values, array, _parent_key)

# we don't have to worry about `parent_key` being None for list or any other type, because we have a check
# that the first object is always a dict, thus setting a parent key on first iteration
elif isinstance(_object, list):
for value in _object:
if isinstance(value, (dict, list)):
_traverse(value, parent_key=parent_key)
else:
current_object[parent_key] = value

if current_object:
flattened.append({**current_object})
current_object.clear()
array = [i for value in _object for i in _traverse(value, array, parent_key)]
else:
current_object[parent_key] = _object

_traverse(nested_dict)
array = [{**item, parent_key: _object} for item in array]

# if the payload did not have any list, we manually append the current object
if not flattened:
flattened.append(current_object)
return array

return flattened
return _traverse(nested_dict, array=[{}], parent_key=None)


class FilterPolicyValidator:
Expand Down Expand Up @@ -340,7 +372,6 @@ def _validate_rule(self, rule: t.Any) -> None:
operator, value = k, v

if operator in (
"anything-but",
"equals-ignore-case",
"prefix",
"suffix",
Expand All @@ -351,6 +382,25 @@ def _validate_rule(self, rule: t.Any) -> None:
)
return

elif operator == "anything-but":
# anything-but can actually contain any kind of simple rule (str, number, and list)
if isinstance(value, list):
for v in value:
self._validate_rule(v)

return

# or have a nested `prefix` pattern
elif isinstance(value, dict):
for inner_operator in value.keys():
if inner_operator != "prefix":
raise InvalidParameterException(
f"{self.error_prefix}FilterPolicy: Unsupported anything-but pattern: {inner_operator}"
)

self._validate_rule(value)
return

elif operator == "exists":
if not isinstance(value, bool):
raise InvalidParameterException(
Expand Down
81 changes: 79 additions & 2 deletions tests/aws/services/sns/test_sns_filter_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,6 @@ def get_messages(_queue_url: str, _received_messages: list):
snapshot.match("messages", {"Messages": received_messages})

@markers.aws.validated
@pytest.mark.skip("Not yet supported by LocalStack")
def test_filter_policy_on_message_body_or_attribute(
self,
sqs_create_queue,
Expand Down Expand Up @@ -1237,7 +1236,7 @@ def test_validate_policy_string_operators(
topic_arn = sns_create_topic()["TopicArn"]

def _subscribe(policy: dict):
sns_subscription(
return sns_subscription(
TopicArn=topic_arn,
Protocol="sms",
Endpoint=phone_number,
Expand All @@ -1262,6 +1261,18 @@ def _subscribe(policy: dict):
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-is-not-list-and-operator", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"suffix": []}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-empty-list", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"suffix": ["test", "test2"]}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-list-wrong-type", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": {"suffix": "value", "prefix": "value"}}
_subscribe(filter_policy)
Expand Down Expand Up @@ -1413,6 +1424,72 @@ def _subscribe(policy: dict):
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-string", e.value.response)

@markers.aws.validated
@markers.snapshot.skip_snapshot_verify(paths=["$..Error.Message"])
def test_validate_policy_nested_anything_but_operator(
self,
sns_create_topic,
sns_subscription,
snapshot,
aws_client,
):
phone_number = "+123123123"
topic_arn = sns_create_topic()["TopicArn"]

def _subscribe(policy: dict):
return sns_subscription(
TopicArn=topic_arn,
Protocol="sms",
Endpoint=phone_number,
Attributes={"FilterPolicy": json.dumps(policy)},
)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"anything-but": {"wrong-operator": None}}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-wrong-operator", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"anything-but": {"suffix": "test"}}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-anything-but-suffix", e.value.response)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I kept wondering if this was a thing we needed to add support for! ⚡

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah me too, I thought the documentation was lying... but no, they really do not support anything else than prefix 😆


with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"anything-but": {"exists": False}}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-anything-but-exists", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"anything-but": {"prefix": False}}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-anything-but-prefix-wrong-type", e.value.response)

# positive testing
filter_policy = {"key": [{"anything-but": {"prefix": "test-"}}]}
response = _subscribe(filter_policy)
assert "SubscriptionArn" in response
subscription_arn = response["SubscriptionArn"]

filter_policy = {"key": [{"anything-but": ["test", "test2"]}]}
response = aws_client.sns.set_subscription_attributes(
SubscriptionArn=subscription_arn,
AttributeName="FilterPolicy",
AttributeValue=json.dumps(filter_policy),
)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

filter_policy = {"key": [{"anything-but": "test"}]}
response = aws_client.sns.set_subscription_attributes(
SubscriptionArn=subscription_arn,
AttributeName="FilterPolicy",
AttributeValue=json.dumps(filter_policy),
)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

@markers.aws.validated
def test_policy_complexity(
self,
Expand Down