From 047a63cf947dac461f51562170393bd08b5cd55b Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Sat, 18 Mar 2023 14:43:58 +0100 Subject: [PATCH] Better handling of content types (#46) Closes https://github.com/biopragmatics/bioregistry/issues/775. This PR adds handling of headers to both the Flask and FastAPI implementations of the apps. - [x] Add Flask implementation - [x] Add FastAPI implementation - [x] Add Flask tests - [x] Add FastAPI tests - [ ] Should the `output` parameter be supported? CC @vemonet. Ideally, I'd like to use https://github.com/vemonet/rdflib-endpoint and not re-implement this code, but we'll have to work through a few issues first (improving code modularity, documentation, and figuring out Flask support) before I can give that a try. --- src/curies/mapping_service/__init__.py | 54 ++++++++-- tests/test_mapping_service.py | 142 ++++++++++++++++++------- 2 files changed, 145 insertions(+), 51 deletions(-) diff --git a/src/curies/mapping_service/__init__.py b/src/curies/mapping_service/__init__.py index a4dcc2b..a5b1cd9 100644 --- a/src/curies/mapping_service/__init__.py +++ b/src/curies/mapping_service/__init__.py @@ -104,7 +104,7 @@ """ import itertools as itt -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Set, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple, Union, cast from rdflib import OWL, Graph, URIRef from rdflib.term import _is_valid_uri @@ -227,6 +227,36 @@ def triples( yield subj_query, pred, obj +#: This is default for federated queries +DEFAULT_CONTENT_TYPE = "application/sparql-results+xml" + +#: A mapping from content types to the keys used for serializing +#: in :meth:`rdflib.Graph.serialize` and other serialization functions +CONTENT_TYPE_TO_RDFLIB_FORMAT = { + # https://www.w3.org/TR/sparql11-results-json/ + "application/sparql-results+json": "json", + "application/json": "json", + "text/json": "json", + # 
https://www.w3.org/TR/rdf-sparql-XMLres/ + "application/sparql-results+xml": "xml", + "application/xml": "xml", # for compatibility + "text/xml": "xml", # not standard + # https://www.w3.org/TR/sparql11-results-csv-tsv/ + "application/sparql-results+csv": "csv", + "text/csv": "csv", # for compatibility + # TODO other direct RDF serializations + # "text/turtle": "ttl", + # "text/n3": "n3", + # "application/ld+json": "json-ld", +} + + +def _handle_header(header: Optional[str]) -> str: + if not header or header == "*/*": + return DEFAULT_CONTENT_TYPE + return header + + def get_flask_mapping_blueprint( converter: Converter, route: str = "/sparql", **kwargs: Any ) -> "flask.Blueprint": @@ -249,8 +279,10 @@ def serve_sparql() -> "Response": sparql = (request.args if request.method == "GET" else request.json).get("query") if not sparql: return Response("Missing parameter query", 400) - results = graph.query(sparql, processor=processor).serialize(format="json") - return Response(results) + content_type = _handle_header(request.headers.get("accept")) + results = graph.query(sparql, processor=processor) + response = results.serialize(format=CONTENT_TYPE_TO_RDFLIB_FORMAT[content_type]) + return Response(response, content_type=content_type) return blueprint @@ -265,7 +297,7 @@ def get_fastapi_router( :param kwargs: Keyword arguments passed through to :class:`fastapi.APIRouter` :return: A router """ - from fastapi import APIRouter, Query, Response + from fastapi import APIRouter, Query, Request, Response from pydantic import BaseModel class QueryModel(BaseModel): # type:ignore @@ -277,22 +309,24 @@ class QueryModel(BaseModel): # type:ignore graph = MappingServiceGraph(converter=converter) processor = MappingServiceSPARQLProcessor(graph=graph) - def _resolve(sparql: str) -> Response: + def _resolve(request: Request, sparql: str) -> Response: + content_type = _handle_header(request.headers.get("accept")) results = graph.query(sparql, processor=processor) - # TODO enable 
different serializations - return Response(results.serialize(format="json"), media_type="application/json") + response = results.serialize(format=CONTENT_TYPE_TO_RDFLIB_FORMAT[content_type]) + return Response(response, media_type=content_type) @api_router.get(route) # type:ignore def resolve_get( + request: Request, query: str = Query(title="Query", description="The SPARQL query to run"), # noqa:B008 ) -> Response: """Run a SPARQL query and serve the results.""" - return _resolve(query) + return _resolve(request, query) @api_router.post(route) # type:ignore - def resolve_post(query: QueryModel) -> Response: + def resolve_post(request: Request, query: QueryModel) -> Response: """Run a SPARQL query and serve the results.""" - return _resolve(query.query) + return _resolve(request, query.query) return api_router diff --git a/tests/test_mapping_service.py b/tests/test_mapping_service.py index b091553..162c22b 100644 --- a/tests/test_mapping_service.py +++ b/tests/test_mapping_service.py @@ -6,6 +6,7 @@ import unittest from typing import Iterable, Set, Tuple from urllib.parse import quote +from xml import etree from fastapi.testclient import TestClient from rdflib import OWL, SKOS @@ -13,6 +14,8 @@ from curies import Converter from curies.mapping_service import ( + CONTENT_TYPE_TO_RDFLIB_FORMAT, + DEFAULT_CONTENT_TYPE, MappingServiceGraph, MappingServiceSPARQLProcessor, _prepare_predicates, @@ -168,6 +171,69 @@ def test_safe_expand(self): ) +def _handle_json(data) -> Set[Tuple[str, str]]: + return {(record["s"]["value"], record["o"]["value"]) for record in data["results"]["bindings"]} + + +def _handle_res_xml(res) -> Set[Tuple[str, str]]: + root = etree.ElementTree.fromstring(res.text) # noqa:S314 + results = root.find("{http://www.w3.org/2005/sparql-results#}results") + rv = set() + for result in results: + parsed_result = { + binding.attrib["name"]: binding.find("{http://www.w3.org/2005/sparql-results#}uri").text + for binding in result + } + 
rv.add((parsed_result["s"], parsed_result["o"])) + return rv + + +def _handle_res_json(res) -> Set[Tuple[str, str]]: + return _handle_json(json.loads(res.text)) + + +def _handle_res_csv(res) -> Set[Tuple[str, str]]: + header, *lines = (line.strip().split(",") for line in res.text.splitlines()) + records = (dict(zip(header, line)) for line in lines) + return {(record["s"], record["o"]) for record in records} + + +# def _handle_res_rdf(res, format) -> Set[Tuple[str, str]]: +# graph = rdflib.Graph() +# graph.parse(res, format=format) +# return {(str(s), str(o)) for s, o in graph.subject_objects()} + + +CONTENT_TYPES = { + "application/sparql-results+json": _handle_res_json, + "application/json": _handle_res_json, + "text/json": _handle_res_json, + "application/sparql-results+xml": _handle_res_xml, + "application/xml": _handle_res_xml, + "text/xml": _handle_res_xml, + "application/sparql-results+csv": _handle_res_csv, + "text/csv": _handle_res_csv, + # "text/turtle": partial(_handle_res_rdf, format="ttl"), + # "text/n3": partial(_handle_res_rdf, format="n3"), + # "application/ld+json": partial(_handle_res_rdf, format="json-ld"), +} +CONTENT_TYPES[""] = CONTENT_TYPES[DEFAULT_CONTENT_TYPE] +CONTENT_TYPES["*/*"] = CONTENT_TYPES[DEFAULT_CONTENT_TYPE] + + +class TestCompleteness(unittest.TestCase): + """Test that tests are complete.""" + + def test_content_types(self): + """Test that all content types are covered.""" + self.assertEqual( + sorted(CONTENT_TYPE_TO_RDFLIB_FORMAT), + sorted( + content_type for content_type in CONTENT_TYPES if content_type not in ["", "*/*"] + ), + ) + + class ConverterMixin(unittest.TestCase): """A mixin that has a converter.""" @@ -178,13 +244,25 @@ def setUp(self) -> None: def assert_get_sparql_results(self, client, sparql): """Test a sparql query returns expected values.""" - res = client.get(f"/sparql?query={quote(sparql)}") - self.assertEqual(200, res.status_code, msg=f"Response: {res}\n\n{res.text}") - records = { - (record["s"]["value"], 
record["o"]["value"]) - for record in json.loads(res.text)["results"]["bindings"] - } - self.assertEqual(EXPECTED, records) + for content_type, parse_func in sorted(CONTENT_TYPES.items()): + with self.subTest(content_type=content_type): + res = client.get(f"/sparql?query={quote(sparql)}", headers={"accept": content_type}) + self.assertEqual(200, res.status_code, msg=f"Response: {res}\n\n{res.text}") + self.assertEqual(EXPECTED, parse_func(res)) + + def assert_post_sparql_results(self, client, sparql): + """Test a sparql query returns expected values.""" + for content_type, parse_func in sorted(CONTENT_TYPES.items()): + with self.subTest(content_type=content_type): + res = client.post( + "/sparql", json={"query": sparql}, headers={"accept": content_type} + ) + self.assertEqual( + 200, + res.status_code, + msg=f"Response: {res}", + ) + self.assertEqual(EXPECTED, parse_func(res)) class TestFlaskMappingWeb(ConverterMixin): @@ -195,29 +273,21 @@ def setUp(self) -> None: super().setUp() self.app = get_flask_mapping_app(self.converter) - def assert_post_sparql_results(self, client, sparql): - """Test a sparql query returns expected values.""" - res = client.post("/sparql", json={"query": sparql}) - self.assertEqual( - 200, res.status_code, msg=f"\nRequest: {res.request}\nResponse: {res}\n\n{res.json}" - ) - records = { - (record["s"]["value"], record["o"]["value"]) - for record in json.loads(res.text)["results"]["bindings"] - } - self.assertEqual(EXPECTED, records) - def test_get_missing_query(self): """Test error on missing query parameter.""" with self.app.test_client() as client: - res = client.get("/sparql") - self.assertEqual(400, res.status_code, msg=f"Response: {res}") + for content_type in sorted(CONTENT_TYPES): + with self.subTest(content_type=content_type): + res = client.get("/sparql", headers={"accept": content_type}) + self.assertEqual(400, res.status_code, msg=f"Response: {res}") def test_post_missing_query(self): """Test error on missing query parameter.""" 
with self.app.test_client() as client: - res = client.post("/sparql") - self.assertEqual(400, res.status_code, msg=f"Response: {res}") + for content_type in sorted(CONTENT_TYPES): + with self.subTest(content_type=content_type): + res = client.post("/sparql", headers={"accept": content_type}) + self.assertEqual(400, res.status_code, msg=f"Response: {res}") def test_get_query(self): """Test querying the app with GET.""" @@ -249,29 +319,19 @@ def setUp(self) -> None: self.app = get_fastapi_mapping_app(self.converter) self.client = TestClient(self.app) - def assert_post_sparql_results(self, client, sparql): - """Test a sparql query returns expected values.""" - res = client.post("/sparql", json={"query": sparql}) - self.assertEqual( - 200, - res.status_code, - msg=f"Response: {res}", - ) - records = { - (record["s"]["value"], record["o"]["value"]) - for record in res.json()["results"]["bindings"] - } - self.assertEqual(EXPECTED, records) - def test_get_missing_query(self): """Test error on missing query parameter.""" - res = self.client.get("/sparql") - self.assertEqual(422, res.status_code, msg=f"Response: {res}") + for content_type in sorted(CONTENT_TYPES): + with self.subTest(content_type=content_type): + res = self.client.get("/sparql", headers={"accept": content_type}) + self.assertEqual(422, res.status_code, msg=f"Response: {res}") def test_post_missing_query(self): """Test error on missing query parameter.""" - res = self.client.post("/sparql") - self.assertEqual(422, res.status_code, msg=f"Response: {res}") + for content_type in sorted(CONTENT_TYPES): + with self.subTest(content_type=content_type): + res = self.client.post("/sparql", headers={"accept": content_type}) + self.assertEqual(422, res.status_code, msg=f"Response: {res}") def test_get_query(self): """Test querying the app with GET."""