Skip to content

Commit

Permalink
Better handling of content types (#46)
Browse files Browse the repository at this point in the history
Closes biopragmatics/bioregistry#775. 

This PR adds handling of headers to both the Flask and FastAPI
implementations of the apps.

- [x] Add Flask implementation
- [x] Add FastAPI implementation
- [x] Add Flask tests
- [x] Add FastAPI tests
- [ ] Should the `output` parameter be supported?

CC @vemonet. Ideally, I'd like to use
https://github.com/vemonet/rdflib-endpoint and not re-implement this
code, but we'll have to work through a few issues first (improving code
modularity, documentation, and figuring out Flask support) before I can
give that a try.
  • Loading branch information
cthoyt committed Mar 18, 2023
1 parent 89ef83a commit 047a63c
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 51 deletions.
54 changes: 44 additions & 10 deletions src/curies/mapping_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
"""

import itertools as itt
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Set, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple, Union, cast

from rdflib import OWL, Graph, URIRef
from rdflib.term import _is_valid_uri
Expand Down Expand Up @@ -227,6 +227,36 @@ def triples(
yield subj_query, pred, obj


#: This is default for federated queries. It is also what is served when the
#: client sends no usable ``Accept`` header.
DEFAULT_CONTENT_TYPE = "application/sparql-results+xml"

#: A mapping from content types to the keys used for serializing
#: in :meth:`rdflib.Graph.serialize` and other serialization functions
CONTENT_TYPE_TO_RDFLIB_FORMAT = {
    # https://www.w3.org/TR/sparql11-results-json/
    "application/sparql-results+json": "json",
    "application/json": "json",
    "text/json": "json",
    # https://www.w3.org/TR/rdf-sparql-XMLres/
    "application/sparql-results+xml": "xml",
    "application/xml": "xml",  # for compatibility
    "text/xml": "xml",  # not standard
    # https://www.w3.org/TR/sparql11-results-csv-tsv/
    "application/sparql-results+csv": "csv",
    "text/csv": "csv",  # for compatibility
    # TODO other direct RDF serializations
    # "text/turtle": "ttl",
    # "text/n3": "n3",
    # "application/ld+json": "json-ld",
}


def _get_quality(parameters: Iterable[str]) -> float:
    """Get the ``q`` weight from the parameters of a single media range.

    :param parameters: The ``key=value`` strings that followed the media type
    :return: The parsed weight, or 1.0 if no (well-formed) weight is given
    """
    for parameter in parameters:
        key, _, value = parameter.partition("=")
        if key.strip().lower() != "q":
            continue
        try:
            return float(value.strip())
        except ValueError:
            # a malformed weight is ignored rather than rejecting the request
            return 1.0
    return 1.0


def _handle_header(header: Optional[str]) -> str:
    """Choose the content type to serve based on an ``Accept`` header.

    The header may list several media ranges, each with optional parameters
    such as a ``q`` weight. The most preferred range that is a key of
    :data:`CONTENT_TYPE_TO_RDFLIB_FORMAT` is returned, so the result can
    always be used for a lookup in that mapping.

    :param header: The raw value of the ``Accept`` header, if any
    :return: A key of :data:`CONTENT_TYPE_TO_RDFLIB_FORMAT`. Falls back to
        :data:`DEFAULT_CONTENT_TYPE` if the header is missing or empty, if
        ``*/*`` is the most preferred usable range, or if nothing listed
        is supported.
    """
    if not header:
        return DEFAULT_CONTENT_TYPE

    weighted = []
    for media_range in header.split(","):
        media_type, *parameters = (piece.strip() for piece in media_range.split(";"))
        if not media_type:
            continue
        quality = _get_quality(parameters)
        # a weight of zero means "not acceptable", so drop the range
        if quality > 0:
            # media types are case-insensitive
            weighted.append((quality, media_type.lower()))

    # sorting is stable, so equally weighted ranges keep the client's order
    weighted.sort(key=lambda pair: pair[0], reverse=True)
    for _, media_type in weighted:
        if media_type == "*/*":
            return DEFAULT_CONTENT_TYPE
        if media_type in CONTENT_TYPE_TO_RDFLIB_FORMAT:
            return media_type

    # Nothing the client listed can be produced. Serve the default
    # representation instead of failing on the format lookup.
    return DEFAULT_CONTENT_TYPE


def get_flask_mapping_blueprint(
converter: Converter, route: str = "/sparql", **kwargs: Any
) -> "flask.Blueprint":
Expand All @@ -249,8 +279,10 @@ def serve_sparql() -> "Response":
sparql = (request.args if request.method == "GET" else request.json).get("query")
if not sparql:
return Response("Missing parameter query", 400)
results = graph.query(sparql, processor=processor).serialize(format="json")
return Response(results)
content_type = _handle_header(request.headers.get("accept"))
results = graph.query(sparql, processor=processor)
response = results.serialize(format=CONTENT_TYPE_TO_RDFLIB_FORMAT[content_type])
return Response(response, content_type=content_type)

return blueprint

Expand All @@ -265,7 +297,7 @@ def get_fastapi_router(
:param kwargs: Keyword arguments passed through to :class:`fastapi.APIRouter`
:return: A router
"""
from fastapi import APIRouter, Query, Response
from fastapi import APIRouter, Query, Request, Response
from pydantic import BaseModel

class QueryModel(BaseModel): # type:ignore
Expand All @@ -277,22 +309,24 @@ class QueryModel(BaseModel): # type:ignore
graph = MappingServiceGraph(converter=converter)
processor = MappingServiceSPARQLProcessor(graph=graph)

def _resolve(sparql: str) -> Response:
def _resolve(request: Request, sparql: str) -> Response:
content_type = _handle_header(request.headers.get("accept"))
results = graph.query(sparql, processor=processor)
# TODO enable different serializations
return Response(results.serialize(format="json"), media_type="application/json")
response = results.serialize(format=CONTENT_TYPE_TO_RDFLIB_FORMAT[content_type])
return Response(response, media_type=content_type)

@api_router.get(route) # type:ignore
def resolve_get(
request: Request,
query: str = Query(title="Query", description="The SPARQL query to run"), # noqa:B008
) -> Response:
"""Run a SPARQL query and serve the results."""
return _resolve(query)
return _resolve(request, query)

@api_router.post(route) # type:ignore
def resolve_post(query: QueryModel) -> Response:
def resolve_post(request: Request, query: QueryModel) -> Response:
"""Run a SPARQL query and serve the results."""
return _resolve(query.query)
return _resolve(request, query.query)

return api_router

Expand Down
142 changes: 101 additions & 41 deletions tests/test_mapping_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
import unittest
from typing import Iterable, Set, Tuple
from urllib.parse import quote
from xml import etree

from fastapi.testclient import TestClient
from rdflib import OWL, SKOS
from rdflib.query import ResultRow

from curies import Converter
from curies.mapping_service import (
CONTENT_TYPE_TO_RDFLIB_FORMAT,
DEFAULT_CONTENT_TYPE,
MappingServiceGraph,
MappingServiceSPARQLProcessor,
_prepare_predicates,
Expand Down Expand Up @@ -168,6 +171,69 @@ def test_safe_expand(self):
)


def _handle_json(data) -> Set[Tuple[str, str]]:
    """Collect the (subject, object) pairs from parsed SPARQL JSON results.

    :param data: The parsed JSON document. Its ``results.bindings`` list is
        read, and each entry must have a ``value`` for both ``s`` and ``o``.
    :return: The set of ``(s, o)`` value pairs, one per binding
    """
    pairs = set()
    for binding in data["results"]["bindings"]:
        subject = binding["s"]["value"]
        obj = binding["o"]["value"]
        pairs.add((subject, obj))
    return pairs


def _handle_res_xml(res) -> Set[Tuple[str, str]]:
    """Collect the (subject, object) pairs from a SPARQL XML results response.

    :param res: A response object whose ``text`` attribute holds the XML
        document. Every binding is read through its ``uri`` child element.
    :return: The set of ``(s, o)`` URI pairs, one per ``result`` element
    """
    # Import the submodule explicitly. ``from xml import etree`` alone does
    # not load ``xml.etree.ElementTree``, so ``etree.ElementTree`` would only
    # resolve if some other package had already imported it.
    from xml.etree import ElementTree

    namespace = "{http://www.w3.org/2005/sparql-results#}"
    root = ElementTree.fromstring(res.text)  # noqa:S314
    results = root.find(f"{namespace}results")
    rv = set()
    for result in results:
        # map each variable name to the URI it is bound to in this result
        parsed_result = {
            binding.attrib["name"]: binding.find(f"{namespace}uri").text for binding in result
        }
        rv.add((parsed_result["s"], parsed_result["o"]))
    return rv


def _handle_res_json(res) -> Set[Tuple[str, str]]:
    """Collect the (subject, object) pairs from a SPARQL JSON results response.

    :param res: A response object whose ``text`` attribute holds the JSON document
    :return: The pairs extracted by :func:`_handle_json` from the decoded body
    """
    payload = json.loads(res.text)
    return _handle_json(payload)


def _handle_res_csv(res) -> Set[Tuple[str, str]]:
    """Collect the (subject, object) pairs from a SPARQL CSV results response.

    :param res: A response object whose ``text`` attribute holds the CSV
        document. The first row must be a header containing ``s`` and ``o``.
    :return: The set of ``(s, o)`` value pairs, one per data row
    """
    import csv
    import io

    # Use a real CSV parser instead of splitting on commas, so quoted fields
    # that contain commas or line breaks stay intact and blank lines are skipped.
    reader = csv.DictReader(io.StringIO(res.text, newline=""))
    return {(record["s"], record["o"]) for record in reader}


# def _handle_res_rdf(res, format) -> Set[Tuple[str, str]]:
# graph = rdflib.Graph()
# graph.parse(res, format=format)
# return {(str(s), str(o)) for s, o in graph.subject_objects()}


#: A mapping from each ``Accept`` header value exercised by the tests to the
#: function that parses a response body served for it
CONTENT_TYPES = {
    **dict.fromkeys(
        ("application/sparql-results+json", "application/json", "text/json"),
        _handle_res_json,
    ),
    **dict.fromkeys(
        ("application/sparql-results+xml", "application/xml", "text/xml"),
        _handle_res_xml,
    ),
    **dict.fromkeys(
        ("application/sparql-results+csv", "text/csv"),
        _handle_res_csv,
    ),
    # TODO direct RDF serializations are not covered yet
    # "text/turtle": partial(_handle_res_rdf, format="ttl"),
    # "text/n3": partial(_handle_res_rdf, format="n3"),
    # "application/ld+json": partial(_handle_res_rdf, format="json-ld"),
}
# An empty or wildcard header is parsed like the default content type
CONTENT_TYPES.update(dict.fromkeys(("", "*/*"), CONTENT_TYPES[DEFAULT_CONTENT_TYPE]))


class TestCompleteness(unittest.TestCase):
    """Test that tests are complete."""

    def test_content_types(self):
        """Test that every servable content type has a parser in the tests."""
        # the empty and wildcard headers are test-only aliases of the default
        aliases = {"", "*/*"}
        servable = sorted(CONTENT_TYPE_TO_RDFLIB_FORMAT)
        covered = sorted(set(CONTENT_TYPES) - aliases)
        self.assertEqual(servable, covered)


class ConverterMixin(unittest.TestCase):
"""A mixin that has a converter."""

Expand All @@ -178,13 +244,25 @@ def setUp(self) -> None:

def assert_get_sparql_results(self, client, sparql):
"""Test a sparql query returns expected values."""
res = client.get(f"/sparql?query={quote(sparql)}")
self.assertEqual(200, res.status_code, msg=f"Response: {res}\n\n{res.text}")
records = {
(record["s"]["value"], record["o"]["value"])
for record in json.loads(res.text)["results"]["bindings"]
}
self.assertEqual(EXPECTED, records)
for content_type, parse_func in sorted(CONTENT_TYPES.items()):
with self.subTest(content_type=content_type):
res = client.get(f"/sparql?query={quote(sparql)}", headers={"accept": content_type})
self.assertEqual(200, res.status_code, msg=f"Response: {res}\n\n{res.text}")
self.assertEqual(EXPECTED, parse_func(res))

def assert_post_sparql_results(self, client, sparql):
    """Test a sparql query sent via POST returns expected values.

    The query is posted once per entry of ``CONTENT_TYPES`` with that entry's
    key as the ``accept`` header. Each response must have status 200 and
    parse to ``EXPECTED`` with the entry's parsing function.

    :param client: A test client with a ``post`` method
    :param sparql: The SPARQL query to send as the ``query`` field of the JSON body
    """
    for content_type, parse_func in sorted(CONTENT_TYPES.items()):
        with self.subTest(content_type=content_type):
            res = client.post(
                "/sparql", json={"query": sparql}, headers={"accept": content_type}
            )
            # include the body, like the GET helper does, so that a failing
            # status code can be diagnosed from the test output
            self.assertEqual(200, res.status_code, msg=f"Response: {res}\n\n{res.text}")
            self.assertEqual(EXPECTED, parse_func(res))


class TestFlaskMappingWeb(ConverterMixin):
Expand All @@ -195,29 +273,21 @@ def setUp(self) -> None:
super().setUp()
self.app = get_flask_mapping_app(self.converter)

def assert_post_sparql_results(self, client, sparql):
"""Test a sparql query returns expected values."""
res = client.post("/sparql", json={"query": sparql})
self.assertEqual(
200, res.status_code, msg=f"\nRequest: {res.request}\nResponse: {res}\n\n{res.json}"
)
records = {
(record["s"]["value"], record["o"]["value"])
for record in json.loads(res.text)["results"]["bindings"]
}
self.assertEqual(EXPECTED, records)

def test_get_missing_query(self):
"""Test error on missing query parameter."""
with self.app.test_client() as client:
res = client.get("/sparql")
self.assertEqual(400, res.status_code, msg=f"Response: {res}")
for content_type in sorted(CONTENT_TYPES):
with self.subTest(content_type=content_type):
res = client.get("/sparql", headers={"accept": content_type})
self.assertEqual(400, res.status_code, msg=f"Response: {res}")

def test_post_missing_query(self):
"""Test error on missing query parameter."""
with self.app.test_client() as client:
res = client.post("/sparql")
self.assertEqual(400, res.status_code, msg=f"Response: {res}")
for content_type in sorted(CONTENT_TYPES):
with self.subTest(content_type=content_type):
res = client.post("/sparql", headers={"accept": content_type})
self.assertEqual(400, res.status_code, msg=f"Response: {res}")

def test_get_query(self):
"""Test querying the app with GET."""
Expand Down Expand Up @@ -249,29 +319,19 @@ def setUp(self) -> None:
self.app = get_fastapi_mapping_app(self.converter)
self.client = TestClient(self.app)

def assert_post_sparql_results(self, client, sparql):
"""Test a sparql query returns expected values."""
res = client.post("/sparql", json={"query": sparql})
self.assertEqual(
200,
res.status_code,
msg=f"Response: {res}",
)
records = {
(record["s"]["value"], record["o"]["value"])
for record in res.json()["results"]["bindings"]
}
self.assertEqual(EXPECTED, records)

def test_get_missing_query(self):
"""Test error on missing query parameter."""
res = self.client.get("/sparql")
self.assertEqual(422, res.status_code, msg=f"Response: {res}")
for content_type in sorted(CONTENT_TYPES):
with self.subTest(content_type=content_type):
res = self.client.get("/sparql", headers={"accept": content_type})
self.assertEqual(422, res.status_code, msg=f"Response: {res}")

def test_post_missing_query(self):
"""Test error on missing query parameter."""
res = self.client.post("/sparql")
self.assertEqual(422, res.status_code, msg=f"Response: {res}")
for content_type in sorted(CONTENT_TYPES):
with self.subTest(content_type=content_type):
res = self.client.post("/sparql", headers={"accept": content_type})
self.assertEqual(422, res.status_code, msg=f"Response: {res}")

def test_get_query(self):
"""Test querying the app with GET."""
Expand Down

0 comments on commit 047a63c

Please sign in to comment.