Better handling of content types #46

Merged
merged 10 commits on Mar 18, 2023
54 changes: 44 additions & 10 deletions src/curies/mapping_service/__init__.py
@@ -104,7 +104,7 @@
"""

import itertools as itt
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Set, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple, Union, cast

from rdflib import OWL, Graph, URIRef
from rdflib.term import _is_valid_uri
@@ -227,6 +227,36 @@ def triples(
yield subj_query, pred, obj


#: This is the default content type for federated queries
DEFAULT_CONTENT_TYPE = "application/sparql-results+xml"

#: A mapping from content types to the keys used for serializing
#: in :meth:`rdflib.Graph.serialize` and other serialization functions
CONTENT_TYPE_TO_RDFLIB_FORMAT = {
# https://www.w3.org/TR/sparql11-results-json/
"application/sparql-results+json": "json",
"application/json": "json",
"text/json": "json",
# https://www.w3.org/TR/rdf-sparql-XMLres/
"application/sparql-results+xml": "xml",
"application/xml": "xml", # for compatibility
"text/xml": "xml", # not standard
# https://www.w3.org/TR/sparql11-results-csv-tsv/
"application/sparql-results+csv": "csv",
"text/csv": "csv", # for compatibility
# TODO other direct RDF serializations
# "text/turtle": "ttl",
# "text/n3": "n3",
# "application/ld+json": "json-ld",
}


def _handle_header(header: Optional[str]) -> str:
if not header or header == "*/*":
return DEFAULT_CONTENT_TYPE
return header
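
As a quick aside (not part of this diff), here is how the new helper and the lookup table compose; a content type outside the table would raise a KeyError:

CONTENT_TYPE_TO_RDFLIB_FORMAT[_handle_header("text/csv")]  # -> "csv"
CONTENT_TYPE_TO_RDFLIB_FORMAT[_handle_header(None)]        # -> "xml" (falls back to the default)
CONTENT_TYPE_TO_RDFLIB_FORMAT[_handle_header("*/*")]       # -> "xml" (falls back to the default)
# CONTENT_TYPE_TO_RDFLIB_FORMAT[_handle_header("text/html")] would raise KeyError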


def get_flask_mapping_blueprint(
converter: Converter, route: str = "/sparql", **kwargs: Any
) -> "flask.Blueprint":
@@ -249,8 +279,10 @@ def serve_sparql() -> "Response":
sparql = (request.args if request.method == "GET" else request.json).get("query")
if not sparql:
return Response("Missing parameter query", 400)
results = graph.query(sparql, processor=processor).serialize(format="json")
return Response(results)
content_type = _handle_header(request.headers.get("accept"))
results = graph.query(sparql, processor=processor)
response = results.serialize(format=CONTENT_TYPE_TO_RDFLIB_FORMAT[content_type])
return Response(response, content_type=content_type)

return blueprint
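
A minimal client-side sketch of the resulting content negotiation, assuming the blueprint is mounted in a Flask app served at localhost:5000 (host, port, and query are illustrative):

import requests

resp = requests.get(
    "http://localhost:5000/sparql",
    params={"query": "SELECT ?s ?o WHERE { ?s <http://www.w3.org/2002/07/owl#sameAs> ?o }"},
    headers={"accept": "application/sparql-results+json"},
)
print(resp.headers["Content-Type"])        # echoes the negotiated content type
print(resp.json()["results"]["bindings"])  # standard SPARQL 1.1 JSON bindings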

@@ -265,7 +297,7 @@ def get_fastapi_router(
:param kwargs: Keyword arguments passed through to :class:`fastapi.APIRouter`
:return: A router
"""
from fastapi import APIRouter, Query, Response
from fastapi import APIRouter, Query, Request, Response
from pydantic import BaseModel

class QueryModel(BaseModel): # type:ignore
@@ -277,22 +309,24 @@ class QueryModel(BaseModel): # type:ignore
graph = MappingServiceGraph(converter=converter)
processor = MappingServiceSPARQLProcessor(graph=graph)

def _resolve(sparql: str) -> Response:
def _resolve(request: Request, sparql: str) -> Response:
content_type = _handle_header(request.headers.get("accept"))
results = graph.query(sparql, processor=processor)
# TODO enable different serializations
return Response(results.serialize(format="json"), media_type="application/json")
response = results.serialize(format=CONTENT_TYPE_TO_RDFLIB_FORMAT[content_type])
return Response(response, media_type=content_type)

@api_router.get(route) # type:ignore
def resolve_get(
request: Request,
query: str = Query(title="Query", description="The SPARQL query to run"), # noqa:B008
) -> Response:
"""Run a SPARQL query and serve the results."""
return _resolve(query)
return _resolve(request, query)

@api_router.post(route) # type:ignore
def resolve_post(query: QueryModel) -> Response:
def resolve_post(request: Request, query: QueryModel) -> Response:
"""Run a SPARQL query and serve the results."""
return _resolve(query.query)
return _resolve(request, query.query)

return api_router
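
The same negotiation can be exercised against the FastAPI router via the test client; the import location and the converter construction below are assumptions for the sketch, not part of this PR:

from fastapi.testclient import TestClient

from curies import Converter
from curies.mapping_service import get_fastapi_mapping_app  # assumed import location

converter = Converter.from_prefix_map({"CHEBI": "http://purl.obolibrary.org/obo/CHEBI_"})
client = TestClient(get_fastapi_mapping_app(converter))
res = client.post(
    "/sparql",
    json={"query": "SELECT ?s ?o WHERE { ?s <http://www.w3.org/2002/07/owl#sameAs> ?o }"},
    headers={"accept": "text/csv"},
)
print(res.headers["content-type"])  # text/csv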

142 changes: 101 additions & 41 deletions tests/test_mapping_service.py
@@ -6,13 +6,16 @@
import unittest
from typing import Iterable, Set, Tuple
from urllib.parse import quote
from xml.etree import ElementTree

from fastapi.testclient import TestClient
from rdflib import OWL, SKOS
from rdflib.query import ResultRow

from curies import Converter
from curies.mapping_service import (
CONTENT_TYPE_TO_RDFLIB_FORMAT,
DEFAULT_CONTENT_TYPE,
MappingServiceGraph,
MappingServiceSPARQLProcessor,
_prepare_predicates,
@@ -168,6 +171,69 @@ def test_safe_expand(self):
)


def _handle_json(data) -> Set[Tuple[str, str]]:
return {(record["s"]["value"], record["o"]["value"]) for record in data["results"]["bindings"]}
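
For reference, a tiny sanity check of the JSON helper against the standard SPARQL 1.1 JSON results shape (the IRIs are invented for the example):

_example = {
    "results": {
        "bindings": [
            {
                "s": {"type": "uri", "value": "http://example.org/a"},
                "o": {"type": "uri", "value": "http://example.org/b"},
            }
        ]
    }
}
assert _handle_json(_example) == {("http://example.org/a", "http://example.org/b")}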


def _handle_res_xml(res) -> Set[Tuple[str, str]]:
    root = ElementTree.fromstring(res.text) # noqa:S314
results = root.find("{http://www.w3.org/2005/sparql-results#}results")
rv = set()
for result in results:
parsed_result = {
binding.attrib["name"]: binding.find("{http://www.w3.org/2005/sparql-results#}uri").text
for binding in result
}
rv.add((parsed_result["s"], parsed_result["o"]))
return rv


def _handle_res_json(res) -> Set[Tuple[str, str]]:
return _handle_json(json.loads(res.text))


def _handle_res_csv(res) -> Set[Tuple[str, str]]:
header, *lines = (line.strip().split(",") for line in res.text.splitlines())
records = (dict(zip(header, line)) for line in lines)
return {(record["s"], record["o"]) for record in records}
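
And a matching check for the CSV helper, using a made-up response object carrying a minimal SPARQL CSV payload:

class _FakeCSVResponse:
    text = "s,o\r\nhttp://example.org/a,http://example.org/b\r\n"

assert _handle_res_csv(_FakeCSVResponse()) == {
    ("http://example.org/a", "http://example.org/b")
}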


# def _handle_res_rdf(res, format) -> Set[Tuple[str, str]]:
# graph = rdflib.Graph()
# graph.parse(res, format=format)
# return {(str(s), str(o)) for s, o in graph.subject_objects()}


CONTENT_TYPES = {
"application/sparql-results+json": _handle_res_json,
"application/json": _handle_res_json,
"text/json": _handle_res_json,
"application/sparql-results+xml": _handle_res_xml,
"application/xml": _handle_res_xml,
"text/xml": _handle_res_xml,
"application/sparql-results+csv": _handle_res_csv,
"text/csv": _handle_res_csv,
# "text/turtle": partial(_handle_res_rdf, format="ttl"),
# "text/n3": partial(_handle_res_rdf, format="n3"),
# "application/ld+json": partial(_handle_res_rdf, format="json-ld"),
}
CONTENT_TYPES[""] = CONTENT_TYPES[DEFAULT_CONTENT_TYPE]
CONTENT_TYPES["*/*"] = CONTENT_TYPES[DEFAULT_CONTENT_TYPE]


class TestCompleteness(unittest.TestCase):
"""Test that tests are complete."""

def test_content_types(self):
"""Test that all content types are covered."""
self.assertEqual(
sorted(CONTENT_TYPE_TO_RDFLIB_FORMAT),
sorted(
content_type for content_type in CONTENT_TYPES if content_type not in ["", "*/*"]
),
)


class ConverterMixin(unittest.TestCase):
"""A mixin that has a converter."""

@@ -178,13 +244,25 @@ def setUp(self) -> None:

def assert_get_sparql_results(self, client, sparql):
"""Test a sparql query returns expected values."""
res = client.get(f"/sparql?query={quote(sparql)}")
self.assertEqual(200, res.status_code, msg=f"Response: {res}\n\n{res.text}")
records = {
(record["s"]["value"], record["o"]["value"])
for record in json.loads(res.text)["results"]["bindings"]
}
self.assertEqual(EXPECTED, records)
for content_type, parse_func in sorted(CONTENT_TYPES.items()):
with self.subTest(content_type=content_type):
res = client.get(f"/sparql?query={quote(sparql)}", headers={"accept": content_type})
self.assertEqual(200, res.status_code, msg=f"Response: {res}\n\n{res.text}")
self.assertEqual(EXPECTED, parse_func(res))

def assert_post_sparql_results(self, client, sparql):
"""Test a sparql query returns expected values."""
for content_type, parse_func in sorted(CONTENT_TYPES.items()):
with self.subTest(content_type=content_type):
res = client.post(
"/sparql", json={"query": sparql}, headers={"accept": content_type}
)
self.assertEqual(
200,
res.status_code,
msg=f"Response: {res}",
)
self.assertEqual(EXPECTED, parse_func(res))


class TestFlaskMappingWeb(ConverterMixin):
@@ -195,29 +273,21 @@ def setUp(self) -> None:
super().setUp()
self.app = get_flask_mapping_app(self.converter)

def assert_post_sparql_results(self, client, sparql):
"""Test a sparql query returns expected values."""
res = client.post("/sparql", json={"query": sparql})
self.assertEqual(
200, res.status_code, msg=f"\nRequest: {res.request}\nResponse: {res}\n\n{res.json}"
)
records = {
(record["s"]["value"], record["o"]["value"])
for record in json.loads(res.text)["results"]["bindings"]
}
self.assertEqual(EXPECTED, records)

def test_get_missing_query(self):
"""Test error on missing query parameter."""
with self.app.test_client() as client:
res = client.get("/sparql")
self.assertEqual(400, res.status_code, msg=f"Response: {res}")
for content_type in sorted(CONTENT_TYPES):
with self.subTest(content_type=content_type):
res = client.get("/sparql", headers={"accept": content_type})
self.assertEqual(400, res.status_code, msg=f"Response: {res}")

def test_post_missing_query(self):
"""Test error on missing query parameter."""
with self.app.test_client() as client:
res = client.post("/sparql")
self.assertEqual(400, res.status_code, msg=f"Response: {res}")
for content_type in sorted(CONTENT_TYPES):
with self.subTest(content_type=content_type):
res = client.post("/sparql", headers={"accept": content_type})
self.assertEqual(400, res.status_code, msg=f"Response: {res}")

def test_get_query(self):
"""Test querying the app with GET."""
@@ -249,29 +319,19 @@ def setUp(self) -> None:
self.app = get_fastapi_mapping_app(self.converter)
self.client = TestClient(self.app)

def assert_post_sparql_results(self, client, sparql):
"""Test a sparql query returns expected values."""
res = client.post("/sparql", json={"query": sparql})
self.assertEqual(
200,
res.status_code,
msg=f"Response: {res}",
)
records = {
(record["s"]["value"], record["o"]["value"])
for record in res.json()["results"]["bindings"]
}
self.assertEqual(EXPECTED, records)

def test_get_missing_query(self):
"""Test error on missing query parameter."""
res = self.client.get("/sparql")
self.assertEqual(422, res.status_code, msg=f"Response: {res}")
for content_type in sorted(CONTENT_TYPES):
with self.subTest(content_type=content_type):
res = self.client.get("/sparql", headers={"accept": content_type})
self.assertEqual(422, res.status_code, msg=f"Response: {res}")

def test_post_missing_query(self):
"""Test error on missing query parameter."""
res = self.client.post("/sparql")
self.assertEqual(422, res.status_code, msg=f"Response: {res}")
for content_type in sorted(CONTENT_TYPES):
with self.subTest(content_type=content_type):
res = self.client.post("/sparql", headers={"accept": content_type})
self.assertEqual(422, res.status_code, msg=f"Response: {res}")

def test_get_query(self):
"""Test querying the app with GET."""