diff --git a/CHANGES.md b/CHANGES.md index 323c1c7..fdab5fd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,9 @@ ## [Unreleased] +- Update stac-fastapi libraries to `~=3.0.0a3` +- make sure the application can work without any extension + ## [3.0.0a1] - 2024-05-22 - Update stac-fastapi libraries to `~=3.0.0a1` diff --git a/setup.py b/setup.py index 2dfc728..590319a 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,9 @@ "orjson", "pydantic", "stac_pydantic==3.1.*", - "stac-fastapi.api~=3.0.0a1", - "stac-fastapi.extensions~=3.0.0a1", - "stac-fastapi.types~=3.0.0a1", + "stac-fastapi.api~=3.0.0a3", + "stac-fastapi.extensions~=3.0.0a3", + "stac-fastapi.types~=3.0.0a3", "asyncpg", "buildpg", "brotli_asgi", diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 3c08c23..067a244 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -50,13 +50,13 @@ extensions = list(extensions_map.values()) post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) - +get_request_model = create_get_request_model(extensions) api = StacApi( settings=settings, extensions=extensions, client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, - search_get_request_model=create_get_request_model(extensions), + search_get_request_model=get_request_model, search_post_request_model=post_request_model, ) app = api.app diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 403f70d..3bc0dce 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -1,7 +1,7 @@ """Item crud client.""" import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Set, Union from urllib.parse import unquote_plus, urljoin import attr @@ -184,12 +184,9 @@ async def _search_base( # noqa: C901 prev: Optional[str] = items.pop("prev", None) collection = ItemCollection(**items) - exclude = search_request.fields.exclude - if exclude and len(exclude) == 0: - exclude = None - include = search_request.fields.include - if include and len(include) == 0: - include = None + fields = getattr(search_request, "fields", None) + include: Set[str] = fields.include if fields and fields.include else set() + exclude: Set[str] = fields.exclude if fields and fields.exclude else set() async def _add_item_links( feature: Item, @@ -204,11 +201,7 @@ async def _add_item_links( collection_id = feature.get("collection") or collection_id item_id = feature.get("id") or item_id - if ( - search_request.fields.exclude is None - or "links" not in search_request.fields.exclude - and all([collection_id, item_id]) - ): + if not exclude or "links" not in exclude and all([collection_id, item_id]): feature["links"] = await ItemLinks( collection_id=collection_id, # type: ignore item_id=item_id, # type: ignore @@ -252,6 +245,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]: next=next, prev=prev, ).get_links() + return collection async def item_collection( @@ -295,14 +289,14 @@ async def item_collection( if v is not None and v != []: clean[k] = v - search_request = self.post_request_model( - **clean, - ) + search_request = self.post_request_model(**clean) item_collection = await self._search_base(search_request, request=request) + links = await ItemCollectionLinks( collection_id=collection_id, request=request ).get_links(extra_links=item_collection["links"]) item_collection["links"] = links + return item_collection async def get_item( @@ -355,15 +349,16 @@ async def get_search( # noqa: C901 collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, bbox: Optional[BBox] = None, + intersects: Optional[str] = None, datetime: Optional[DateTimeType] = None, limit: Optional[int] = None, + # Extensions query: Optional[str] = None, token: Optional[str] = None, fields: Optional[List[str]] = None, sortby: Optional[str] = None, filter: Optional[str] = None, filter_lang: Optional[str] = None, - intersects: Optional[str] = None, **kwargs, ) -> ItemCollection: """Cross catalog search (GET). diff --git a/stac_fastapi/pgstac/extensions/filter.py b/stac_fastapi/pgstac/extensions/filter.py index 15c3d0f..b6ed99c 100644 --- a/stac_fastapi/pgstac/extensions/filter.py +++ b/stac_fastapi/pgstac/extensions/filter.py @@ -4,7 +4,7 @@ from buildpg import render from fastapi import Request -from stac_fastapi.types.core import AsyncBaseFiltersClient +from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient from stac_fastapi.types.errors import NotFoundError diff --git a/tests/api/test_api.py b/tests/api/test_api.py index b7a5892..3546867 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -1,3 +1,4 @@ +import os from datetime import datetime, timedelta from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar from urllib.parse import quote_plus @@ -6,9 +7,11 @@ import pytest from fastapi import Request from httpx import ASGITransport, AsyncClient +from pypgstac.db import PgstacDB +from pypgstac.load import Loader from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent from stac_fastapi.api.app import StacApi -from stac_fastapi.api.models import create_post_request_model +from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension from stac_fastapi.types import stac as stac_types @@ -17,6 +20,9 @@ from stac_fastapi.pgstac.transactions import TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") + + STAC_CORE_ROUTES = [ "GET /", "GET /collections", @@ -669,11 +675,13 @@ async def get_collection( FieldsExtension(), ] post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) + get_request_model = create_get_request_model(extensions) api = StacApi( client=Client(post_request_model=post_request_model), settings=settings, extensions=extensions, search_post_request_model=post_request_model, + search_get_request_model=get_request_model, ) app = api.app await connect_to_db(app) @@ -695,3 +703,105 @@ async def get_collection( assert response.status_code == 200 finally: await close_db_connection(app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("validation", [True, False]) +@pytest.mark.parametrize("hydrate", [True, False]) +async def test_no_extension( + hydrate, validation, load_test_data, database, pgstac +) -> None: + """test PgSTAC with no extension.""" + connection = f"postgresql://{database.user}:{database.password}@{database.host}:{database.port}/{database.dbname}" + with PgstacDB(dsn=connection) as db: + loader = Loader(db=db) + loader.load_collections(os.path.join(DATA_DIR, "test_collection.json")) + loader.load_items(os.path.join(DATA_DIR, "test_item.json")) + + settings = Settings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + testing=True, + use_api_hydrate=hydrate, + enable_response_models=validation, + ) + extensions = [] + post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) + api = StacApi( + client=CoreCrudClient(post_request_model=post_request_model), + settings=settings, + extensions=extensions, + search_post_request_model=post_request_model, + ) + app = api.app + await connect_to_db(app) + try: + async with AsyncClient(transport=ASGITransport(app=app)) as client: + landing = await client.get("http://test/") + assert landing.status_code == 200, landing.text + + collection = await client.get("http://test/collections/test-collection") + assert collection.status_code == 200, collection.text + + collections = await client.get("http://test/collections") + assert collections.status_code == 200, collections.text + + item = await client.get( + "http://test/collections/test-collection/items/test-item" + ) + assert item.status_code == 200, item.text + + item_collection = await client.get( + "http://test/collections/test-collection/items", + params={"limit": 10}, + ) + assert item_collection.status_code == 200, item_collection.text + + get_search = await client.get( + "http://test/search", + params={ + "collections": ["test-collection"], + }, + ) + assert get_search.status_code == 200, get_search.text + + post_search = await client.post( + "http://test/search", + json={ + "collections": ["test-collection"], + }, + ) + assert post_search.status_code == 200, post_search.text + + get_search = await client.get( + "http://test/search", + params={ + "collections": ["test-collection"], + "fields": "properties.datetime", + }, + ) + # fields should be ignored + assert get_search.status_code == 200, get_search.text + props = get_search.json()["features"][0]["properties"] + assert len(props) > 1 + + post_search = await client.post( + "http://test/search", + json={ + "collections": ["test-collection"], + "fields": { + "include": ["properties.datetime"], + }, + }, + ) + # fields should be ignored + assert post_search.status_code == 200, post_search.text + props = get_search.json()["features"][0]["properties"] + assert len(props) > 1 + + finally: + await close_db_connection(app)