Skip to content

Commit

Permalink
update fields extension and make sure the app can work without any ex…
Browse files Browse the repository at this point in the history
…tension (#123)

* update fields extension and make sure the app can work without any extension

* Update stac_fastapi/pgstac/core.py

Co-authored-by: Jonathan Healy <[email protected]>

---------

Co-authored-by: Jonathan Healy <[email protected]>
  • Loading branch information
vincentsarago and jonhealy1 committed Jun 18, 2024
1 parent 26f6d91 commit 810bbb5
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## [Unreleased]

- Update stac-fastapi libraries to `~=3.0.0a3`
- make sure the application can work without any extension

## [3.0.0a1] - 2024-05-22

- Update stac-fastapi libraries to `~=3.0.0a1`
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
"orjson",
"pydantic",
"stac_pydantic==3.1.*",
"stac-fastapi.api~=3.0.0a1",
"stac-fastapi.extensions~=3.0.0a1",
"stac-fastapi.types~=3.0.0a1",
"stac-fastapi.api~=3.0.0a3",
"stac-fastapi.extensions~=3.0.0a3",
"stac-fastapi.types~=3.0.0a3",
"asyncpg",
"buildpg",
"brotli_asgi",
Expand Down
4 changes: 2 additions & 2 deletions stac_fastapi/pgstac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@
extensions = list(extensions_map.values())

post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)

get_request_model = create_get_request_model(extensions)
api = StacApi(
settings=settings,
extensions=extensions,
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
response_class=ORJSONResponse,
search_get_request_model=create_get_request_model(extensions),
search_get_request_model=get_request_model,
search_post_request_model=post_request_model,
)
app = api.app
Expand Down
27 changes: 11 additions & 16 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Item crud client."""

import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Set, Union
from urllib.parse import unquote_plus, urljoin

import attr
Expand Down Expand Up @@ -184,12 +184,9 @@ async def _search_base( # noqa: C901
prev: Optional[str] = items.pop("prev", None)
collection = ItemCollection(**items)

exclude = search_request.fields.exclude
if exclude and len(exclude) == 0:
exclude = None
include = search_request.fields.include
if include and len(include) == 0:
include = None
fields = getattr(search_request, "fields", None)
include: Set[str] = fields.include if fields and fields.include else set()
exclude: Set[str] = fields.exclude if fields and fields.exclude else set()

async def _add_item_links(
feature: Item,
Expand All @@ -204,11 +201,7 @@ async def _add_item_links(
collection_id = feature.get("collection") or collection_id
item_id = feature.get("id") or item_id

if (
search_request.fields.exclude is None
or "links" not in search_request.fields.exclude
and all([collection_id, item_id])
):
if not exclude or "links" not in exclude and all([collection_id, item_id]):
feature["links"] = await ItemLinks(
collection_id=collection_id, # type: ignore
item_id=item_id, # type: ignore
Expand Down Expand Up @@ -252,6 +245,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:
next=next,
prev=prev,
).get_links()

return collection

async def item_collection(
Expand Down Expand Up @@ -295,14 +289,14 @@ async def item_collection(
if v is not None and v != []:
clean[k] = v

search_request = self.post_request_model(
**clean,
)
search_request = self.post_request_model(**clean)
item_collection = await self._search_base(search_request, request=request)

links = await ItemCollectionLinks(
collection_id=collection_id, request=request
).get_links(extra_links=item_collection["links"])
item_collection["links"] = links

return item_collection

async def get_item(
Expand Down Expand Up @@ -355,15 +349,16 @@ async def get_search( # noqa: C901
collections: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
bbox: Optional[BBox] = None,
intersects: Optional[str] = None,
datetime: Optional[DateTimeType] = None,
limit: Optional[int] = None,
# Extensions
query: Optional[str] = None,
token: Optional[str] = None,
fields: Optional[List[str]] = None,
sortby: Optional[str] = None,
filter: Optional[str] = None,
filter_lang: Optional[str] = None,
intersects: Optional[str] = None,
**kwargs,
) -> ItemCollection:
"""Cross catalog search (GET).
Expand Down
2 changes: 1 addition & 1 deletion stac_fastapi/pgstac/extensions/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from buildpg import render
from fastapi import Request
from stac_fastapi.types.core import AsyncBaseFiltersClient
from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient
from stac_fastapi.types.errors import NotFoundError


Expand Down
112 changes: 111 additions & 1 deletion tests/api/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from datetime import datetime, timedelta
from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar
from urllib.parse import quote_plus
Expand All @@ -6,9 +7,11 @@
import pytest
from fastapi import Request
from httpx import ASGITransport, AsyncClient
from pypgstac.db import PgstacDB
from pypgstac.load import Loader
from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import create_post_request_model
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension
from stac_fastapi.types import stac as stac_types

Expand All @@ -17,6 +20,9 @@
from stac_fastapi.pgstac.transactions import TransactionsClient
from stac_fastapi.pgstac.types.search import PgstacSearch

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")


STAC_CORE_ROUTES = [
"GET /",
"GET /collections",
Expand Down Expand Up @@ -669,11 +675,13 @@ async def get_collection(
FieldsExtension(),
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)
api = StacApi(
client=Client(post_request_model=post_request_model),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
search_get_request_model=get_request_model,
)
app = api.app
await connect_to_db(app)
Expand All @@ -695,3 +703,105 @@ async def get_collection(
assert response.status_code == 200
finally:
await close_db_connection(app)


@pytest.mark.asyncio
@pytest.mark.parametrize("validation", [True, False])
@pytest.mark.parametrize("hydrate", [True, False])
async def test_no_extension(
hydrate, validation, load_test_data, database, pgstac
) -> None:
"""test PgSTAC with no extension."""
connection = f"postgresql://{database.user}:{database.password}@{database.host}:{database.port}/{database.dbname}"
with PgstacDB(dsn=connection) as db:
loader = Loader(db=db)
loader.load_collections(os.path.join(DATA_DIR, "test_collection.json"))
loader.load_items(os.path.join(DATA_DIR, "test_item.json"))

settings = Settings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
testing=True,
use_api_hydrate=hydrate,
enable_response_models=validation,
)
extensions = []
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
api = StacApi(
client=CoreCrudClient(post_request_model=post_request_model),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
)
app = api.app
await connect_to_db(app)
try:
async with AsyncClient(transport=ASGITransport(app=app)) as client:
landing = await client.get("http://test/")
assert landing.status_code == 200, landing.text

collection = await client.get("http://test/collections/test-collection")
assert collection.status_code == 200, collection.text

collections = await client.get("http://test/collections")
assert collections.status_code == 200, collections.text

item = await client.get(
"http://test/collections/test-collection/items/test-item"
)
assert item.status_code == 200, item.text

item_collection = await client.get(
"http://test/collections/test-collection/items",
params={"limit": 10},
)
assert item_collection.status_code == 200, item_collection.text

get_search = await client.get(
"http://test/search",
params={
"collections": ["test-collection"],
},
)
assert get_search.status_code == 200, get_search.text

post_search = await client.post(
"http://test/search",
json={
"collections": ["test-collection"],
},
)
assert post_search.status_code == 200, post_search.text

get_search = await client.get(
"http://test/search",
params={
"collections": ["test-collection"],
"fields": "properties.datetime",
},
)
# fields should be ignored
assert get_search.status_code == 200, get_search.text
props = get_search.json()["features"][0]["properties"]
assert len(props) > 1

post_search = await client.post(
"http://test/search",
json={
"collections": ["test-collection"],
"fields": {
"include": ["properties.datetime"],
},
},
)
# fields should be ignored
assert post_search.status_code == 200, post_search.text
props = get_search.json()["features"][0]["properties"]
assert len(props) > 1

finally:
await close_db_connection(app)

0 comments on commit 810bbb5

Please sign in to comment.