From 58de8c7c5b1f3424d11ee05ca96837a34b39c74c Mon Sep 17 00:00:00 2001
From: Roland
Date: Wed, 17 Apr 2024 11:25:26 -0600
Subject: [PATCH] Upgrading Sanic and yt-dlp.

Major changes were required because Sanic changed how multiprocessing works.

Python 3.11 is now required; older versions caused unpredictable behavior in tests. (Sanic does not yet support 3.12.)

Sanic has been upgraded to 23.6.0, which is the latest version that avoids this bug: https://github.com/sanic-org/sanic/issues/2921

The new multiprocessing strategy is to create all multiprocessing tools in one process, then fork to the other processes. The previous strategy was to declare multiprocessing tools at the top of every file, or wherever they were needed at import/creation. Now all multiprocessing tools are attached to `app.shared_ctx`, which means `api_app` is imported in many, many places.

This forced a change in how the DownloadManager works. Previously, it continually ran download workers which pulled downloads from a multiprocessing.Queue. Now, a single worker checks for new downloads and sends a Sanic signal.

Flags have been reworked to use the `api_app`. I removed the `which` flag functionality because the `which` functions are called at import and would have needed their own multiprocessing.Event.
---
 .circleci/config.yml | 89 +---
 .gitignore | 1 +
 app/src/api.js | 14 +
 app/src/components/Common.js | 9 +
 app/src/components/Map.js | 13 +-
 app/src/components/Zim.js | 10 +-
 app/src/components/admin/Downloads.js | 36 +-
 docker-compose.yml | 3 +-
 docker/api/Dockerfile | 10 +-
 main.py | 295 ++++++-----
 modules/archive/__init__.py | 4 +-
 modules/archive/api.py | 16 +-
 modules/archive/test/test_lib.py | 8 +-
 modules/inventory/__init__.py | 2 +-
 modules/inventory/api.py | 30 +-
 modules/inventory/common.py | 2 +-
 modules/inventory/conftest.py | 3 -
 modules/inventory/test/test_inventory.py | 11 +-
 modules/map/api.py | 14 +-
 modules/map/lib.py | 14 +-
 modules/otp/api.py | 11 +-
 modules/videos/__init__.py | 2 +-
 modules/videos/api.py | 5 +-
 modules/videos/channel/api.py | 2 +-
 modules/videos/channel/test/test_api.py | 8 +-
 modules/videos/channel/test/test_lib.py | 2 +-
 modules/videos/conftest.py | 6 +-
 modules/videos/downloader.py | 5 +-
 modules/videos/lib.py | 39 +-
 modules/videos/models.py | 10 +-
 modules/videos/test/test_api.py | 2 +-
 modules/videos/test/test_common.py | 4 +-
 modules/videos/test/test_downloader.py | 5 +-
 modules/videos/test/test_video.py | 1 +
 modules/videos/video/api.py | 11 +-
 modules/videos/video/test/test_api.py | 4 +-
 modules/zim/api.py | 30 +-
 modules/zim/test/test_lib.py | 6 +-
 pytest.ini | 2 +-
 requirements.txt | 14 +-
 wrolpi/__init__.py | 2 +-
 wrolpi/api_utils.py | 92 ++++
 wrolpi/cmd.py | 17 +-
 wrolpi/common.py | 107 ++--
 wrolpi/conftest.py | 96 ++--
 wrolpi/contexts.py | 100 ++++
 wrolpi/dates.py | 3 +
 wrolpi/db.py | 31 +-
 wrolpi/downloader.py | 630 +++++++++++------------
 wrolpi/events.py | 15 +-
 wrolpi/files/__init__.py | 2 +-
 wrolpi/files/api.py | 53 +-
 wrolpi/files/lib.py | 17 +-
 wrolpi/files/test/test_api.py | 2 +-
 wrolpi/files/test/test_lib.py | 10 +-
 wrolpi/flags.py | 63 +--
 wrolpi/root_api.py | 166 ++----
 wrolpi/status.py | 53 +-
 wrolpi/test/common.py | 2 +-
 wrolpi/test/test_admin.py | 4 +-
 wrolpi/test/test_dates.py | 6 +-
 wrolpi/test/test_downloader.py | 32 +-
 wrolpi/test/test_root_api.py | 26 +-
 wrolpi/test/test_rss.py | 2 +-
 wrolpi/vars.py | 2 +
 65 files changed, 1233 insertions(+), 1053 deletions(-)
 create mode 100644 wrolpi/api_utils.py
 create mode 100644 wrolpi/contexts.py

diff --git a/.circleci/config.yml
b/.circleci/config.yml index 0d5b1a30..d88732ce 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -61,91 +61,10 @@ jobs: paths: - ./app/node_modules - run: cd app && npm run test - api-tests-3-8: - docker: - - image: cimg/python:3.8 - - image: cimg/postgres:12.9 - environment: - POSTGRES_USER: postgres - POSTGRES_DB: wrolpi - POSTGRES_PASSWORD: "wrolpi" - resource_class: medium - steps: - - checkout - - run: sudo apt-get update - - run: sudo apt-get install -y ffmpeg catdoc - - restore_cache: - key: deps-3.8-{{ checksum "requirements.txt" }}-2 - - run: - name: Install Requirements - command: | - python3 -m venv venv - . venv/bin/activate - pip install -r requirements.txt - - save_cache: - key: deps-3.8-{{ checksum "requirements.txt" }} - paths: - - "venv" - - run: - command: './venv/bin/pytest -svv' - api-tests-3-9: - docker: - - image: cimg/python:3.9 - - image: cimg/postgres:13.5 - environment: - POSTGRES_USER: postgres - POSTGRES_DB: wrolpi - POSTGRES_PASSWORD: "wrolpi" - resource_class: medium - steps: - - checkout - - run: sudo apt-get update - - run: sudo apt-get install -y ffmpeg catdoc - - restore_cache: - key: deps-3.9-{{ checksum "requirements.txt" }}-2 - - run: - name: Install Requirements - command: | - python3 -m venv venv - . venv/bin/activate - pip install -r requirements.txt - - save_cache: - key: deps-3.9-{{ checksum "requirements.txt" }} - paths: - - "venv" - - run: - command: './venv/bin/pytest -svv' - api-tests-3-10: - docker: - - image: cimg/python:3.10 - - image: cimg/postgres:14.1 - environment: - POSTGRES_USER: postgres - POSTGRES_DB: wrolpi - POSTGRES_PASSWORD: "wrolpi" - resource_class: medium - steps: - - checkout - - run: sudo apt-get update - - run: sudo apt-get install -y ffmpeg catdoc - - restore_cache: - key: deps-3.10-{{ checksum "requirements.txt" }}-2 - - run: - name: Install Requirements - command: | - python3 -m venv venv - . venv/bin/activate - pip install -r requirements.txt - - save_cache: - key: deps-3.10-{{ checksum "requirements.txt" }} - paths: - - "venv" - - run: - command: './venv/bin/pytest -svv' api-tests-3-11: docker: - image: cimg/python:3.11 - - image: cimg/postgres:14.1 + - image: cimg/postgres:15.6 environment: POSTGRES_USER: postgres POSTGRES_DB: wrolpi @@ -172,7 +91,7 @@ jobs: api-tests-3-12: docker: - image: cimg/python:3.12 - - image: cimg/postgres:14.1 + - image: cimg/postgres:15.6 environment: POSTGRES_USER: postgres POSTGRES_DB: wrolpi @@ -200,10 +119,8 @@ jobs: workflows: wrolpi-api-tests: jobs: - - api-tests-3-8 - - api-tests-3-9 - - api-tests-3-10 - api-tests-3-11 +# - api-tests-3-12 Sanic does not yet support 3.12. wrolpi-app-test: jobs: - app-tests-14 diff --git a/.gitignore b/.gitignore index ab6b45eb..2594898a 100644 --- a/.gitignore +++ b/.gitignore @@ -126,6 +126,7 @@ docker-compose.override.yml # test directory is used as media directory, we don't want to commit what a user downloads. test +pg_data # Directories used to build images /debian-live-config/config/includes.chroot/opt/wrolpi-blobs/gis-map.dump.gz diff --git a/app/src/api.js b/app/src/api.js index 88244f86..b76bdef3 100644 --- a/app/src/api.js +++ b/app/src/api.js @@ -772,6 +772,20 @@ export async function clearFailedDownloads() { } } +export async function deleteOnceDownloads() { + let response = await apiPost(`${API_URI}/download/delete_once`); + if (response.status === 204) { + return null + } else { + toast({ + type: 'error', + title: 'Error!', + description: 'Could not delete once downloads! 
See server logs.', + time: 5000, + }); + } +} + export async function getStatistics() { let response = await apiGet(`${API_URI}/statistics`); if (response.status === 200) { diff --git a/app/src/components/Common.js b/app/src/components/Common.js index 30e3ebb1..536ad970 100644 --- a/app/src/components/Common.js +++ b/app/src/components/Common.js @@ -1583,6 +1583,15 @@ export function InfoMessage({children, size = null}) { } +export function HandPointMessage({children, size = null}) { + return + + + {children} + + +} + export function WarningMessage({children, size = null}) { return diff --git a/app/src/components/Map.js b/app/src/components/Map.js index 8a5e323a..ce0983bd 100644 --- a/app/src/components/Map.js +++ b/app/src/components/Map.js @@ -2,6 +2,7 @@ import React, {useContext} from "react"; import { APIButton, ErrorMessage, + HandPointMessage, HelpPopup, humanFileSize, IframeViewer, @@ -31,6 +32,9 @@ import {Loader, Placeholder, Table} from "./Theme"; import {StatusContext} from "../contexts/contexts"; import _ from "lodash"; + +const VIEWER_URL = `http://${window.location.hostname}:8084/`; + function DockerMapImportWarning() { const {status} = useContext(StatusContext); if (status['dockerized']) { @@ -89,6 +93,12 @@ function DownloadMessage() { } +const ViewerMessage = () => { + return +

You can view your Map at {VIEWER_URL}

+
+} + function SlowImportMessage() { const {status} = useContext(StatusContext); if (status && status['cpu_info'] && status['cpu_info']['temperature'] >= 80) { @@ -277,12 +287,13 @@ class ManageMap extends React.Component { + } } function MapPage() { - return + return } export function MapRoute() { diff --git a/app/src/components/Zim.js b/app/src/components/Zim.js index 68c46d81..06eecc2b 100644 --- a/app/src/components/Zim.js +++ b/app/src/components/Zim.js @@ -33,6 +33,7 @@ import { APIButton, encodeMediaPath, ErrorMessage, + HandPointMessage, humanFileSize, IframeViewer, InfoMessage, @@ -328,14 +329,18 @@ export const ZimSearchView = ({suggestions, loading}) => { } -const ViewerMessage = () => { +const DownloadMessage = () => { return

More Zim files are available from the full Kiwix library https://download.kiwix.org/

+
+} +const ViewerMessage = () => { + return

You can view your Zim files using the Kiwix app, or at {VIEWER_URL}

- +
} const ZimCatalogItemRow = ({item, subscriptions, iso_639_codes, fetchSubscriptions}) => { @@ -524,6 +529,7 @@ class ManageZim extends React.Component { + } diff --git a/app/src/components/admin/Downloads.js b/app/src/components/admin/Downloads.js index 4047c16a..aba36f36 100644 --- a/app/src/components/admin/Downloads.js +++ b/app/src/components/admin/Downloads.js @@ -1,5 +1,12 @@ import React from "react"; -import {clearCompletedDownloads, clearFailedDownloads, deleteDownload, killDownload, restartDownload} from "../../api"; +import { + clearCompletedDownloads, + clearFailedDownloads, + deleteDownload, + deleteOnceDownloads, + killDownload, + restartDownload +} from "../../api"; import {Link} from "react-router-dom"; import { APIButton, @@ -40,7 +47,7 @@ function ClearCompleteDownloads({callback}) { return <> Clear Completed @@ -58,7 +65,7 @@ function ClearFailedDownloads({callback}) { } return } +function DeleteOnceDownloads({callback}) { + async function localDeleteOnce() { + try { + await deleteOnceDownloads(); + } finally { + if (callback) { + callback() + } + } + } + + return + Delete All + +} + class RecurringDownloadRow extends React.Component { constructor(props) { super(props); @@ -306,6 +335,7 @@ export function OnceDownloadsTable({downloads, fetchDownloads}) { + diff --git a/docker-compose.yml b/docker-compose.yml index edcc0822..4e5067bd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,6 +11,8 @@ services: healthcheck: test: [ 'CMD-SHELL', 'pg_isready -U postgres' ] interval: 10s + volumes: + - ./pg_data:/var/lib/postgresql/data api: depends_on: @@ -29,7 +31,6 @@ services: - './alembic.ini:/opt/wrolpi/alembic.ini' ports: - ${REACT_APP_API-127.0.0.1:8081}:8081 - command: '-vv api --host 0.0.0.0' user: '${UID-1000}:${GID-1000}' healthcheck: test: [ 'CMD-SHELL', 'curl http://127.0.0.1:8081/api/echo' ] diff --git a/docker/api/Dockerfile b/docker/api/Dockerfile index 6f871ac5..5f316f4d 100644 --- a/docker/api/Dockerfile +++ b/docker/api/Dockerfile @@ -8,12 +8,14 @@ RUN apt update RUN apt-get install -y ffmpeg catdoc RUN ffmpeg -version -# Install WROLPi +# Install dependencies. +COPY requirements.txt /opt/wrolpi/requirements.txt +RUN pip3 install -r /opt/wrolpi/requirements.txt + +# Install WROLPi. 
COPY main.py /opt/wrolpi/ COPY wrolpi /opt/wrolpi/wrolpi COPY modules /opt/wrolpi/modules -COPY requirements.txt /opt/wrolpi/requirements.txt -RUN pip3 install -r /opt/wrolpi/requirements.txt ENTRYPOINT [ "python3", "-OO", "/opt/wrolpi/main.py"] -CMD ["api", "--host", "0.0.0.0" ] +CMD ["-vv", "api", "--host", "0.0.0.0"] diff --git a/main.py b/main.py index 6330edcb..bb9958ab 100755 --- a/main.py +++ b/main.py @@ -7,14 +7,17 @@ from sanic import Sanic from sanic.signals import Event -from wrolpi import flags, BEFORE_STARTUP_FUNCTIONS -from wrolpi import root_api, admin -from wrolpi.common import logger, get_wrolpi_config, check_media_directory, limit_concurrent, \ - wrol_mode_enabled, cancel_refresh_tasks, set_log_level, background_task, cancel_background_tasks +from wrolpi import flags, BEFORE_STARTUP_FUNCTIONS, admin +from wrolpi import root_api +from wrolpi import tags +from wrolpi.api_utils import api_app +from wrolpi.common import logger, check_media_directory, set_log_level, limit_concurrent, \ + cancel_refresh_tasks, cancel_background_tasks, get_wrolpi_config +from wrolpi.contexts import attach_shared_contexts, reset_shared_contexts, initialize_configs_contexts from wrolpi.dates import Seconds -from wrolpi.downloader import download_manager, import_downloads_config -from wrolpi.root_api import api_app -from wrolpi.vars import PROJECT_DIR, DOCKERIZED, PYTEST +from wrolpi.downloader import import_downloads_config, download_manager, \ + perpetual_download_worker +from wrolpi.vars import PROJECT_DIR, DOCKERIZED from wrolpi.version import get_version_string logger = logger.getChild('wrolpi-main') @@ -22,7 +25,7 @@ def db_main(args): """ - Handle database migrations. Currently this uses Alembic, supported commands are "upgrade" and "downgrade". + Handle database migrations. This uses Alembic, supported commands are "upgrade" and "downgrade". """ from alembic.config import Config from alembic import command @@ -115,46 +118,129 @@ def main(): return 1 logger.warning(f'Starting with: {sys.argv}') - from wrolpi.common import LOG_LEVEL - with LOG_LEVEL.get_lock(): - if args.verbose == 1: - LOG_LEVEL.value = logging.INFO - set_log_level(logging.INFO) - elif args.verbose and args.verbose == 2: - LOG_LEVEL.value = logging.DEBUG - set_log_level(logging.DEBUG) - elif args.verbose and args.verbose >= 3: - # Log everything. Add SQLAlchemy debug logging. - LOG_LEVEL.value = logging.NOTSET - set_log_level(logging.NOTSET) + if args.verbose == 1: + set_log_level(logging.INFO) + elif args.verbose and args.verbose == 2: + set_log_level(logging.DEBUG) + elif args.verbose and args.verbose >= 3: + # Log everything. Add SQLAlchemy debug logging. + set_log_level(logging.NOTSET) logger.info(get_version_string()) if DOCKERIZED: logger.info('Running in Docker') + check_media_directory() + # Run DB migrations before anything else. if args.sub_commands == 'db': return db_main(args) - config = get_wrolpi_config() + # Run the API. + if args.sub_commands == 'api': + return root_api.main(args) + + +@api_app.main_process_start +async def startup(app: Sanic): + """ + Initializes multiprocessing tools, flags, etc. + + Performed only once when the server starts, this is done before server processes are forked. + + @warning: This is NOT run after auto-reload! You must stop and start Sanic. + """ + logger.debug('startup') + # Initialize multiprocessing shared contexts before forking Sanic processes. 
+ attach_shared_contexts(app) + logger.debug('startup done') + + +@api_app.listener('before_server_start') # FileConfigs need to be initialized first. +async def initialize_configs(app: Sanic): + """Each Sanic process runs this once.""" + # Each process will have their own FileConfig object, but share the `app.shared_ctx.*config` + logger.debug('initialize_configs') + + try: + initialize_configs_contexts(app) + except Exception as e: + logger.error('initialize_configs failed with', exc_info=e) + raise + + logger.debug('initialize_configs done') + + +@api_app.signal(Event.SERVER_SHUTDOWN_BEFORE) +@api_app.listener('reload_process_stop') +@limit_concurrent(1) +async def handle_server_shutdown(*args, **kwargs): + """Stop downloads when server is shutting down.""" + logger.debug('handle_server_shutdown') + download_manager.stop() + await cancel_refresh_tasks() + await cancel_background_tasks() + + +@api_app.signal(Event.SERVER_SHUTDOWN_AFTER) +async def handle_server_shutdown_reset(app: Sanic, loop): + """Reset things after shutdown is complete, just in case server is going to start again.""" + reset_shared_contexts(app) + + +# Start periodic tasks after configs are ready. +@api_app.listener('after_server_start') +async def start_single_tasks(app: Sanic): + """Recurring/Single tasks that are started in only one Sanic process.""" + # Only allow one child process to perform periodic tasks. See `handle_server_shutdown` + if app.shared_ctx.single_tasks_started.is_set(): + return + app.shared_ctx.single_tasks_started.set() + + from modules.zim.lib import flag_outdated_zim_files + + logger.debug(f'start_single_tasks started') + if get_wrolpi_config().download_on_startup: + download_manager.enable() + else: + download_manager.disable() + + # Start perpetual tasks. DO NOT AWAIT! + app.add_task(perpetual_check_db_is_up_worker()) # noqa + app.add_task(perpetual_download_worker()) # noqa + + # await app.dispatch('wrolpi.periodic.start_video_missing_comments_download') + await app.dispatch('wrolpi.periodic.bandwidth') + + try: + flag_outdated_zim_files() + except Exception as e: + logger.error('Failed to flag outdated Zims', exc_info=e) + + logger.debug('start_single_tasks waiting for db...') + async with flags.db_up.wait_for(): + logger.debug('start_single_tasks db is up') + + tags.import_tags_config() + + await import_downloads_config() + + if flags.refresh_complete.is_set(): + # Set all downloads to new. + download_manager.reset_downloads() # Hotspot/throttle are not supported in Docker containers. - if not DOCKERIZED and config.hotspot_on_startup: + if not DOCKERIZED and get_wrolpi_config().hotspot_on_startup: try: admin.enable_hotspot() except Exception as e: logger.error('Failed to enable hotspot', exc_info=e) - if not DOCKERIZED and config.throttle_on_startup: + if not DOCKERIZED and get_wrolpi_config().throttle_on_startup: try: admin.throttle_cpu_on() except Exception as e: logger.error('Failed to throttle CPU', exc_info=e) - check_media_directory() - - # Import modules before calling BEFORE_STARTUP_FUNCTIONS. - import modules # noqa - # Run the startup functions for func in BEFORE_STARTUP_FUNCTIONS: try: @@ -163,128 +249,77 @@ def main(): except Exception as e: logger.warning(f'Startup {func} failed!', exc_info=e) - # Run the API. 
- if args.sub_commands == 'api': - return root_api.main(args) + logger.debug(f'start_single_tasks done') -@api_app.before_server_start -@limit_concurrent(1) -async def startup(app: Sanic): - from wrolpi.common import LOG_LEVEL +@api_app.listener('after_server_start') +async def start_sanic_worker(app: Sanic): + logger.debug(f'start_sanic_worker') - # Check database status first. Many functions will reference flags.db_up. - flags.check_db_is_up() + await app.dispatch('wrolpi.periodic.check_log_level') - flags.init_flags() - await import_downloads_config() - async def periodic_check_db_is_up(): - while True: - flags.check_db_is_up() +@api_app.listener('after_server_start') +async def start_initialize_flags(app: Sanic): + # Only allow one child process to initialize flags. + if not app.shared_ctx.flags_initialized.is_set(): + logger.warning('start_initialize_flags') + app.shared_ctx.flags_initialized.set() + async with flags.db_up.wait_for(): flags.init_flags() - await asyncio.sleep(10) - background_task(periodic_check_db_is_up()) - async def periodic_check_log_level(): - while True: - log_level = LOG_LEVEL.value - if log_level != logger.getEffectiveLevel(): - set_log_level(log_level) - await asyncio.sleep(1) - - background_task(periodic_check_log_level()) - - from wrolpi import status - background_task(status.bandwidth_worker()) +@api_app.signal('wrolpi.periodic.check_log_level') +async def periodic_check_log_level(): + """Copies global log level into this Sanic worker's logger.""" + log_level = api_app.shared_ctx.log_level.value + if log_level != logger.getEffectiveLevel(): + logger.info(f'changing log level {log_level}') + set_log_level(log_level) - from modules.zim.lib import flag_outdated_zim_files - flag_outdated_zim_files() - - from modules.videos.video.lib import get_missing_videos_comments - - async def periodic_start_video_missing_comments_download(): - async with flags.refresh_complete.wait_for(): - # We can't search for Videos missing comments until the refresh has completed. - pass - - # Wait for download manager to startup. - await asyncio.sleep(5) - - while True: - # Fetch comments for videos every hour. - if download_manager.disabled.is_set() or download_manager.stopped.is_set(): - await asyncio.sleep(10) - else: - await get_missing_videos_comments() - await asyncio.sleep(int(Seconds.hour)) - - background_task(periodic_start_video_missing_comments_download()) - - -@api_app.after_server_start -async def periodic_downloads(app: Sanic): - """ - Starts the perpetual downloader on download manager. - - Limited to only one process. - """ - async with flags.db_up.wait_for(): + try: + await asyncio.sleep(1) + await api_app.dispatch('wrolpi.periodic.check_log_level') + except asyncio.CancelledError: + # Server is shutting down. pass - if not flags.refresh_complete.is_set(): - logger.warning('Refusing to download without refresh') - download_manager.disable() - return - - # Set all downloads to new. 
- download_manager.reset_downloads() - - if wrol_mode_enabled(): - logger.warning('Not starting download manager because WROL Mode is enabled.') - download_manager.disable() - return - config = get_wrolpi_config() - if config.download_on_startup is False: - logger.warning('Not starting download manager because Downloads are disabled on startup.') - download_manager.disable() - return +async def perpetual_check_db_is_up_worker(): + try: + while True: + try: + flags.check_db_is_up() + flags.init_flags() + except Exception as e: + logger.error('Failed to check if db is up', exc_info=e) - async with flags.db_up.wait_for(): - download_manager.enable() - app.add_task(download_manager.perpetual_download()) + await asyncio.sleep(10) + except asyncio.CancelledError: + logger.info('periodic_check_db_is_up was cancelled...') -@api_app.after_server_start -async def start_workers(app: Sanic): - """All Sanic processes have their own Download workers.""" - if wrol_mode_enabled(): - logger.warning(f'Not starting download workers because WROL Mode is enabled.') - download_manager.stop() - return +@api_app.signal('wrolpi.periodic.start_video_missing_comments_download') +async def periodic_start_video_missing_comments_download(): + logger.debug('periodic_start_video_missing_comments_download is running') - async with flags.db_up.wait_for(): - download_manager.start_workers() + from modules.videos.video.lib import get_missing_videos_comments + async with flags.refresh_complete.wait_for(): + # We can't search for Videos missing comments until the refresh has completed. + pass -@api_app.before_server_start -@limit_concurrent(1) -async def main_import_tags_config(app: Sanic): - from wrolpi import tags - async with flags.db_up.wait_for(): - tags.import_tags_config() + # Wait for download manager to startup. + await asyncio.sleep(5) + # Fetch comments for videos every hour. + if download_manager.disabled.is_set() or download_manager.stopped.is_set(): + await asyncio.sleep(10) + else: + await get_missing_videos_comments() + await asyncio.sleep(int(Seconds.hour)) -@root_api.api_app.signal(Event.SERVER_SHUTDOWN_BEFORE) -@limit_concurrent(1) -async def handle_server_shutdown(*args, **kwargs): - """Stop downloads when server is shutting down.""" - if not PYTEST: - download_manager.stop() - await cancel_refresh_tasks() - await cancel_background_tasks() + await api_app.dispatch('wrolpi.periodic.start_video_missing_comments_download') if __name__ == '__main__': diff --git a/modules/archive/__init__.py b/modules/archive/__init__.py index 088497cc..a5524cf2 100644 --- a/modules/archive/__init__.py +++ b/modules/archive/__init__.py @@ -18,7 +18,7 @@ from wrolpi.files.models import FileGroup from wrolpi.vars import PYTEST, DOCKERIZED from . 
import lib -from .api import bp # noqa +from .api import archive_bp # noqa from .errors import InvalidArchive from .lib import is_singlefile_file, request_archive, SINGLEFILE_HEADER from .models import Archive, Domain @@ -74,6 +74,7 @@ async def do_singlefile(self, download: Download) -> bytes: '--browser-args', '["--no-sandbox"]', '--dump-content') return_code, _, stdout = await self.process_runner( + download.id, download.url, cmd, pathlib.Path('/home/wrolpi'), @@ -91,6 +92,7 @@ async def do_readability(self, download: Download, html: bytes) -> dict: cmd = (READABILITY_BIN, fh.name, download.url) logger.debug(f'readability cmd: {cmd}') return_code, logs, stdout = await self.process_runner( + download.id, download.url, cmd, pathlib.Path('/home/wrolpi'), diff --git a/modules/archive/api.py b/modules/archive/api.py index 2e88c790..01128d88 100644 --- a/modules/archive/api.py +++ b/modules/archive/api.py @@ -1,23 +1,23 @@ from http import HTTPStatus -from sanic import response, Request +from sanic import response, Request, Blueprint from sanic_ext import validate from sanic_ext.extensions.openapi import openapi from wrolpi.common import logger, wrol_mode_check, api_param_limiter from wrolpi.errors import ValidationError -from wrolpi.root_api import get_blueprint, json_response +from wrolpi.api_utils import json_response from wrolpi.schema import JSONErrorResponse from . import lib, schema NAME = 'archive' -bp = get_blueprint('Archive', '/api/archive') +archive_bp = Blueprint('Archive', '/api/archive') logger = logger.getChild(__name__) -@bp.get('/') +@archive_bp.get('/') @openapi.description('Get an archive') @openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse) async def get_archive(_: Request, archive_id: int): @@ -27,8 +27,8 @@ async def get_archive(_: Request, archive_id: int): return json_response({'file_group': archive_file_group, 'history': history}) -@bp.delete('/') -@bp.delete('/') +@archive_bp.delete('/', name='archive_delete_one') +@archive_bp.delete('/', name='archive_delete_many') @openapi.description('Delete an individual archive') @openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse) @wrol_mode_check @@ -41,7 +41,7 @@ async def delete_archive(_: Request, archive_ids: str): return response.empty() -@bp.get('/domains') +@archive_bp.get('/domains') @openapi.summary('Get a list of all Domains and their Archive statistics') @openapi.response(200, schema.GetDomainsResponse, "The list of domains") async def get_domains(_: Request): @@ -52,7 +52,7 @@ async def get_domains(_: Request): archive_limit_limiter = api_param_limiter(100) -@bp.post('/search') +@archive_bp.post('/search') @openapi.definition( summary='A File search with more filtering related to Archives', body=schema.ArchiveSearchRequest, diff --git a/modules/archive/test/test_lib.py b/modules/archive/test/test_lib.py index 330aefbd..3731185f 100644 --- a/modules/archive/test/test_lib.py +++ b/modules/archive/test/test_lib.py @@ -17,7 +17,7 @@ from wrolpi.db import get_db_session from wrolpi.files import lib as files_lib from wrolpi.files.models import FileGroup -from wrolpi.root_api import CustomJSONEncoder +from wrolpi.api_utils import CustomJSONEncoder from wrolpi.test.common import skip_circleci @@ -37,7 +37,7 @@ def make_fake_archive_result(readability=True, screenshot=True, title=True): @pytest.mark.asyncio -async def test_no_screenshot(test_session): +async def test_no_screenshot(test_directory, test_session): singlefile, readability, screenshot = make_fake_archive_result(screenshot=False) archive = await 
model_archive_result('https://example.com', singlefile, readability, screenshot) assert isinstance(archive.singlefile_path, pathlib.Path) @@ -80,7 +80,7 @@ async def test_relationships(test_session, example_singlefile): @pytest.mark.asyncio -async def test_archive_title(test_session, archive_factory, singlefile_contents_factory): +async def test_archive_title(test_async_client, test_session, archive_factory, singlefile_contents_factory): """An Archive's title can be fetched in multiple ways. This tests from most to least reliable.""" # Create some test files, delete all records for a fresh refresh. archive_factory( @@ -329,7 +329,7 @@ async def test_new_archive(test_session, fake_now): @pytest.mark.asyncio -async def test_get_title_from_html(test_session, fake_now): +async def test_get_title_from_html(test_directory, test_session, fake_now): fake_now(datetime(2000, 1, 1)) singlefile, readability, screenshot = make_fake_archive_result() archive = await model_archive_result('https://example.com', singlefile, readability, screenshot) diff --git a/modules/inventory/__init__.py b/modules/inventory/__init__.py index 64c647f3..f51bbc31 100644 --- a/modules/inventory/__init__.py +++ b/modules/inventory/__init__.py @@ -1,6 +1,6 @@ import wrolpi from wrolpi.db import get_db_session -from .api import bp +from .api import inventory_bp from .inventory import logger, DEFAULT_CATEGORIES, DEFAULT_INVENTORIES from .models import Item, Inventory diff --git a/modules/inventory/api.py b/modules/inventory/api.py index c92e0395..545ee36e 100644 --- a/modules/inventory/api.py +++ b/modules/inventory/api.py @@ -1,39 +1,39 @@ from http import HTTPStatus -from sanic import response +from sanic import response, Blueprint from sanic.request import Request from sanic_ext import validate from sanic_ext.extensions.openapi import openapi from modules.inventory import common, inventory, schema +from wrolpi.api_utils import json_response from wrolpi.common import run_after, recursive_map from wrolpi.errors import ValidationError -from wrolpi.root_api import get_blueprint, json_response NAME = 'inventory' -bp = get_blueprint('Inventory', '/api/inventory') +inventory_bp = Blueprint('Inventory', '/api/inventory') -@bp.get('/categories') +@inventory_bp.get('/categories') def get_categories(_: Request): categories = inventory.get_categories() return json_response(dict(categories=categories)) -@bp.get('/brands') +@inventory_bp.get('/brands') def get_brands(_: Request): brands = inventory.get_brands() return json_response(dict(brands=brands)) -@bp.get('/') +@inventory_bp.get('/') def get_inventories(_: Request): inventories = inventory.get_inventories() return json_response(dict(inventories=inventories)) -@bp.get('/') +@inventory_bp.get('/') def get_inventory(_: Request, inventory_id: int): by_category = common.get_inventory_by_category(inventory_id) by_subcategory = common.get_inventory_by_subcategory(inventory_id) @@ -41,7 +41,7 @@ def get_inventory(_: Request, inventory_id: int): return json_response(dict(by_category=by_category, by_subcategory=by_subcategory, by_name=by_name)) -@bp.post('/') +@inventory_bp.post('/') @openapi.definition( summary='Save a new inventory', body=schema.InventoryPostRequest, @@ -55,7 +55,7 @@ def post_inventory(_: Request, body: schema.InventoryPostRequest): return response.empty(HTTPStatus.CREATED) -@bp.put('/') +@inventory_bp.put('/') @openapi.definition( summary='Update an inventory', body=schema.InventoryPutRequest, @@ -69,7 +69,7 @@ def put_inventory(_: Request, inventory_id: int, body: 
schema.InventoryPutReques return response.empty() -@bp.delete('/') +@inventory_bp.delete('/') @openapi.description('Delete an inventory.') @run_after(common.save_inventories_file) def inventory_delete(_: Request, inventory_id: int): @@ -77,14 +77,14 @@ def inventory_delete(_: Request, inventory_id: int): return response.empty() -@bp.get('//item') +@inventory_bp.get('//item') @openapi.description('Get all items from an inventory.') def items_get(_: Request, inventory_id: int): items = inventory.get_items(inventory_id) return json_response({'items': items}) -@bp.post('//item') +@inventory_bp.post('//item') @openapi.definition( summary="Save an item into it's inventory.", body=schema.ItemPostRequest, @@ -98,7 +98,7 @@ def post_item(_: Request, inventory_id: int, body: schema.ItemPostRequest): return response.empty() -@bp.put('/item/') +@inventory_bp.put('/item/') @openapi.definition( summary='Update an item.', body=schema.ItemPutRequest, @@ -112,8 +112,8 @@ def put_item(_: Request, item_id: int, body: schema.ItemPutRequest): return response.empty() -@bp.delete('/item/') -@bp.delete('/item/') +@inventory_bp.delete('/item/', name='item_delete_many') +@inventory_bp.delete('/item/', name='item_delete_one') @openapi.description('Delete items from an inventory.') @run_after(common.save_inventories_file) def item_delete(_: Request, item_ids: str): diff --git a/modules/inventory/common.py b/modules/inventory/common.py index a8020b8a..887c4e04 100644 --- a/modules/inventory/common.py +++ b/modules/inventory/common.py @@ -163,7 +163,7 @@ def set_test_inventories_config(enabled: bool): def save_inventories_file(): - """Write all inventories and their respective items to a YAML file.""" + """Write all inventories and their respective items to a WROLPi Config file.""" config = get_inventories_config() inventories = [] diff --git a/modules/inventory/conftest.py b/modules/inventory/conftest.py index e3990494..ad9ca198 100644 --- a/modules/inventory/conftest.py +++ b/modules/inventory/conftest.py @@ -1,7 +1,6 @@ import pytest from modules.inventory import Inventory, Item, DEFAULT_CATEGORIES, DEFAULT_INVENTORIES -from modules.inventory.common import set_test_inventories_config @pytest.fixture @@ -24,6 +23,4 @@ def init_test_inventory(test_session): inventory = test_session.query(Inventory).filter_by(name='Food Storage').one() test_session.commit() - set_test_inventories_config(True) yield inventory - set_test_inventories_config(False) diff --git a/modules/inventory/test/test_inventory.py b/modules/inventory/test/test_inventory.py index 9d2e9565..13d47ae7 100644 --- a/modules/inventory/test/test_inventory.py +++ b/modules/inventory/test/test_inventory.py @@ -1,5 +1,4 @@ from decimal import Decimal -from decimal import Decimal from itertools import zip_longest from typing import List, Iterable @@ -10,6 +9,7 @@ from wrolpi.common import Base from wrolpi.db import get_db_session +from wrolpi.api_utils import api_app from wrolpi.test.common import PytestCase from .. 
import init from ..common import sum_by_key, get_inventory_by_category, get_inventory_by_subcategory, get_inventory_by_name, \ @@ -274,8 +274,10 @@ def test_no_inventories(test_session, test_directory): pass -def test_inventories_version(test_session, test_directory, init_test_inventory): +def test_inventories_version(test_async_client, test_session, test_directory, init_test_inventory): """You can't save over a newer version of an inventory.""" + config = get_inventories_config() + for item in TEST_ITEMS: item = Item(**item) test_session.add(item) @@ -283,12 +285,10 @@ def test_inventories_version(test_session, test_directory, init_test_inventory): # Version is set to 1 on first save. save_inventories_file() - config = get_inventories_config() assert config.version == 1 # Version is incremented when saving. save_inventories_file() - config = get_inventories_config() assert config.version == 2 # Version is greater than what will be saved. @@ -299,7 +299,8 @@ def test_inventories_version(test_session, test_directory, init_test_inventory): save_inventories_file() -def test_inventories_config(test_session, test_directory, init_test_inventory): +@pytest.mark.asyncio +async def test_inventories_config(test_async_client, test_session, test_directory, init_test_inventory): for item in TEST_ITEMS: item = Item(**item) test_session.add(item) diff --git a/modules/map/api.py b/modules/map/api.py index 77eba4c6..aaae5f9f 100644 --- a/modules/map/api.py +++ b/modules/map/api.py @@ -1,21 +1,21 @@ from http import HTTPStatus from pathlib import Path -from sanic import Request, response +from sanic import Request, response, Blueprint from sanic_ext import validate from sanic_ext.extensions.openapi import openapi from modules.map import lib, schema from wrolpi import flags +from wrolpi.api_utils import json_response from wrolpi.common import wrol_mode_check, get_media_directory, background_task from wrolpi.errors import ValidationError -from wrolpi.root_api import get_blueprint, json_response from wrolpi.vars import PYTEST, DOCKERIZED -bp = get_blueprint('Map', '/api/map') +map_bp = Blueprint('Map', '/api/map') -@bp.post('/import') +@map_bp.post('/import') @openapi.definition( summary='Import PBF/dump map files', body=schema.ImportPost, @@ -38,12 +38,12 @@ async def import_pbfs(_: Request, body: schema.ImportPost): return response.empty() -@bp.get('/files') +@map_bp.get('/files') @openapi.description('Find any map files, get their import status') -def get_files_status(_: Request): +def get_files_status(request: Request): paths = lib.get_import_status() paths = sorted(paths, key=lambda i: str(i.path)) - pending = lib.IMPORTING.get('pending') + pending = request.app.shared_ctx.map_importing.get('pending') if pending: pending = [Path(i).relative_to(get_media_directory()) for i in pending] body = dict( diff --git a/modules/map/lib.py b/modules/map/lib.py index dd025794..595fb64e 100644 --- a/modules/map/lib.py +++ b/modules/map/lib.py @@ -1,6 +1,5 @@ import asyncio import subprocess -from multiprocessing import Manager from pathlib import Path from typing import List @@ -8,6 +7,7 @@ from modules.map.models import MapFile from wrolpi import flags +from wrolpi.api_utils import api_app from wrolpi.cmd import SUDO_BIN from wrolpi.common import get_media_directory, walk, logger, get_wrolpi_config from wrolpi.dates import now, timedelta_to_timestamp, seconds_to_timestamp @@ -17,10 +17,6 @@ logger = logger.getChild(__name__) -IMPORTING = Manager().dict(dict( - pending=None, -)) - def get_map_directory() -> Path: 
map_directory = get_media_directory() / get_wrolpi_config().map_directory @@ -100,7 +96,7 @@ async def import_files(paths: List[str]): if pbfs: success = False try: - IMPORTING.update(dict( + api_app.shared_ctx.map_importing.update(dict( pending=list(pbfs), )) total_elapsed += await run_import_command(*pbfs) @@ -109,7 +105,7 @@ async def import_files(paths: List[str]): except Exception as e: import_logger.warning('Failed to run import', exc_info=e) finally: - IMPORTING.update(dict( + api_app.shared_ctx.map_importing.update(dict( pending=None, )) @@ -137,7 +133,7 @@ async def import_files(paths: List[str]): success = False try: - IMPORTING.update(dict( + api_app.shared_ctx.map_importing.update(dict( pending=str(path), )) total_elapsed += await run_import_command(path) @@ -146,7 +142,7 @@ async def import_files(paths: List[str]): except Exception as e: import_logger.warning('Failed to run import', exc_info=e) finally: - IMPORTING.update(dict( + api_app.shared_ctx.map_importing.update(dict( pending=None, )) diff --git a/modules/otp/api.py b/modules/otp/api.py index 552e195c..febfb6ac 100644 --- a/modules/otp/api.py +++ b/modules/otp/api.py @@ -1,15 +1,14 @@ -from sanic import response +from sanic import response, Blueprint from sanic.request import Request from sanic_ext import validate from sanic_ext.extensions.openapi import openapi -from wrolpi.root_api import get_blueprint from . import lib, schema -bp = get_blueprint('OTP', '/api/otp') +otp_bp = Blueprint('OTP', '/api/otp') -@bp.post('/encrypt_otp') +@otp_bp.post('/encrypt_otp') @openapi.definition( summary='Encrypt a message with OTP.', body=schema.EncryptOTPRequest, @@ -20,7 +19,7 @@ async def post_encrypt_otp(_: Request, body: schema.EncryptOTPRequest): return response.json(data) -@bp.post('/decrypt_otp') +@otp_bp.post('/decrypt_otp') @openapi.definition( summary='Decrypt a message with OTP.', body=schema.DecryptOTPRequest, @@ -31,7 +30,7 @@ async def post_decrypt_otp(_: Request, body: schema.DecryptOTPRequest): return response.json(data) -@bp.get('/html') +@otp_bp.get('/html') async def get_new_otp_html(_: Request): body = lib.generate_html() return response.html(body) diff --git a/modules/videos/__init__.py b/modules/videos/__init__.py index 8263da7f..b9ab8b25 100644 --- a/modules/videos/__init__.py +++ b/modules/videos/__init__.py @@ -45,7 +45,7 @@ async def video_modeler(): if PYTEST: raise i = video.file_group.primary_path if video.file_group else video_id - logger.error(f'Unable to model Video {i=}', exc_info=e) + logger.error(f'Unable to model Video: {str(i)}', exc_info=e) file_group.indexed = True diff --git a/modules/videos/api.py b/modules/videos/api.py index d632834c..a1ed068a 100644 --- a/modules/videos/api.py +++ b/modules/videos/api.py @@ -4,19 +4,18 @@ from sanic.request import Request from sanic_ext.extensions.openapi import openapi +from wrolpi.api_utils import json_response from wrolpi.common import logger -from wrolpi.root_api import add_blueprint, json_response from . 
import lib, schema from .channel.api import channel_bp from .video.api import video_bp content_bp = Blueprint('VideoContent', '/api/videos') -bp = Blueprint('Videos', '/api/videos').group( +videos_bp = Blueprint('Videos', '/api/videos').group( content_bp, # view and manage video content and settings channel_bp, # view and manage channels video_bp, # view videos ) -add_blueprint(bp) logger = logger.getChild(__name__) diff --git a/modules/videos/channel/api.py b/modules/videos/channel/api.py index 276a3bd5..849d42da 100644 --- a/modules/videos/channel/api.py +++ b/modules/videos/channel/api.py @@ -11,7 +11,7 @@ get_media_directory, run_after, get_relative_to_media_directory from wrolpi.downloader import download_manager from wrolpi.events import Events -from wrolpi.root_api import json_response +from wrolpi.api_utils import json_response from wrolpi.schema import JSONErrorResponse from wrolpi.vars import PYTEST from . import lib diff --git a/modules/videos/channel/test/test_api.py b/modules/videos/channel/test/test_api.py index c6de0863..e7df28e3 100644 --- a/modules/videos/channel/test/test_api.py +++ b/modules/videos/channel/test/test_api.py @@ -270,8 +270,8 @@ def test_channel_empty_url_doesnt_conflict(test_client, test_session, test_direc @pytest.mark.asyncio -async def test_channel_download_requires_refresh(test_session, download_channel, video_download_manager, video_factory, - events_history): +async def test_channel_download_requires_refresh( + test_async_client, test_session, download_channel, video_download_manager, video_factory, events_history): """A Channel cannot be downloaded until it has been refreshed. Videos already downloaded are not downloaded again.""" @@ -287,6 +287,8 @@ def assert_refreshed(expected: bool): assert channel.refreshed == expected assert bool(channel.info_json) == expected + await test_async_client.sanic_app.dispatch('wrolpi.download.') + assert_refreshed(False) test_session.commit() @@ -318,7 +320,7 @@ async def do_download(_, download): assert downloaded_urls == ['https://example.com/2'] # Should not send refresh events because downloads are automated. 
- assert events_history == [] + assert list(events_history) == [] def test_channel_post_directory(test_session, test_client, test_directory): diff --git a/modules/videos/channel/test/test_lib.py b/modules/videos/channel/test/test_lib.py index 5711c9f2..1b08eeb3 100644 --- a/modules/videos/channel/test/test_lib.py +++ b/modules/videos/channel/test/test_lib.py @@ -62,7 +62,7 @@ def test_get_channel(test_session, test_directory, channel_factory): assert lib.get_channel(**p) -def test_channels_no_url(test_session, test_directory): +def test_channels_no_url(test_session, test_directory, test_channels_config): """Test that a Channel's URL is coerced to None if it is empty.""" channel1_directory = test_directory / 'channel1' channel1_directory.mkdir() diff --git a/modules/videos/conftest.py b/modules/videos/conftest.py index 60fd3ad3..6b2b3dcc 100644 --- a/modules/videos/conftest.py +++ b/modules/videos/conftest.py @@ -11,10 +11,11 @@ from PIL import Image from modules.videos.downloader import VideoDownloader, ChannelDownloader -from modules.videos.lib import set_test_channels_config, set_test_downloader_config +from modules.videos.lib import set_test_channels_config, set_test_downloader_config, get_channels_config from modules.videos.models import Channel, Video from wrolpi.downloader import DownloadFrequency, DownloadManager, Download from wrolpi.files.models import FileGroup +from wrolpi.api_utils import api_app from wrolpi.vars import PROJECT_DIR @@ -162,7 +163,8 @@ def video_download_manager(test_download_manager) -> DownloadManager: def test_channels_config(test_directory): (test_directory / 'config').mkdir(exist_ok=True) config_path = test_directory / 'config/channels.yaml' - with set_test_channels_config(): + with set_test_channels_config() as config: + config.initialize(api_app.shared_ctx.channels_config) yield config_path diff --git a/modules/videos/downloader.py b/modules/videos/downloader.py index 316677ef..dc020868 100755 --- a/modules/videos/downloader.py +++ b/modules/videos/downloader.py @@ -202,6 +202,7 @@ async def do_download(self, download: Download) -> DownloadResult: if download.attempts >= 10: raise UnrecoverableDownloadError('Max download attempts reached') + download_id = download.id url = normalize_video_url(download.url) # Video may have been downloaded previously, get its location for error reporting. @@ -325,7 +326,7 @@ async def do_download(self, download: Download) -> DownloadResult: '--ppa', 'Merger+ffmpeg_o1:-strict -2', url, ) - return_code, logs, _ = await self.process_runner(url, cmd, out_dir) + return_code, logs, _ = await self.process_runner(download_id, url, cmd, out_dir) stdout = logs['stdout'].decode() if hasattr(logs['stdout'], 'decode') else logs['stdout'] stderr = logs['stderr'].decode() if hasattr(logs['stderr'], 'decode') else logs['stderr'] @@ -408,7 +409,7 @@ async def do_download(self, download: Download) -> DownloadResult: except yt_dlp.utils.UnsupportedError as e: raise UnrecoverableDownloadError('URL is not supported by yt-dlp') from e except Exception as e: - logger.warning(f'VideoDownloader failed to download: {download.url}', exc_info=e) + logger.warning(f'VideoDownloader failed to download: {url}', exc_info=e) if _skip_download(e): # The video failed to download, and the error will never be fixed. Skip it forever. 
try: diff --git a/modules/videos/lib.py b/modules/videos/lib.py index 60bb495a..2430d159 100644 --- a/modules/videos/lib.py +++ b/modules/videos/lib.py @@ -191,14 +191,12 @@ def convert_or_generate_poster(video: Video) -> Tuple[Optional[pathlib.Path], Op class ChannelsConfig(ConfigFile): file_name = 'channels.yaml' default_config = dict( - channels={ - 'wrolpi': dict( - name='WROLPi', - url='https://www.youtube.com/channel/UC4t8bw1besFTyjW7ZBCOIrw/videos', - directory='videos/wrolpi', - download_frequency=604800, - ) - }, + channels=[dict( + name='WROLPi', + url='https://www.youtube.com/channel/UC4t8bw1besFTyjW7ZBCOIrw/videos', + directory='videos/wrolpi', + download_frequency=604800, + )], ) @property @@ -216,7 +214,11 @@ def channels(self, value: dict): def get_channels_config() -> ChannelsConfig: global TEST_CHANNELS_CONFIG - if isinstance(TEST_CHANNELS_CONFIG, ConfigFile): + if PYTEST and not TEST_CHANNELS_CONFIG: + logger.warning('Test did not initialize the channels config') + return + + if TEST_CHANNELS_CONFIG: return TEST_CHANNELS_CONFIG global CHANNELS_CONFIG @@ -227,7 +229,8 @@ def get_channels_config() -> ChannelsConfig: def set_test_channels_config(): global TEST_CHANNELS_CONFIG TEST_CHANNELS_CONFIG = ChannelsConfig() - yield + TEST_CHANNELS_CONFIG.initialize() + yield TEST_CHANNELS_CONFIG TEST_CHANNELS_CONFIG = None @@ -342,10 +345,7 @@ def get_downloader_config() -> VideoDownloaderConfig: def set_test_downloader_config(enabled: bool): global TEST_VIDEO_DOWNLOADER_CONFIG - if enabled: - TEST_VIDEO_DOWNLOADER_CONFIG = VideoDownloaderConfig() - else: - TEST_VIDEO_DOWNLOADER_CONFIG = None + TEST_VIDEO_DOWNLOADER_CONFIG = VideoDownloaderConfig() if enabled else None def get_channels_config_from_db(session: Session) -> dict: @@ -360,6 +360,11 @@ def save_channels_config(session: Session = None): """Get the Channel information from the DB, save it to the config.""" config = get_channels_config_from_db(session) channels_config = get_channels_config() + + if PYTEST and not channels_config: + logger.warning('Refusing to save channels config because test did not initialize a test config!') + return + channels_config.update(config) @@ -371,10 +376,8 @@ def save_channels_config(session: Session = None): @limit_concurrent(1) def import_channels_config(): """Import channel settings to the DB. Existing channels will be updated.""" - if PYTEST and not TEST_CHANNELS_CONFIG: - channel_import_logger.warning( - f'Not importing channels during this test. 
Use `test_channels_config` fixture if you would ' - f'like to call this.') + if PYTEST and not get_channels_config(): + logger.warning('Skipping import_channels_config for this test') return channel_import_logger.info('Importing videos config') diff --git a/modules/videos/models.py b/modules/videos/models.py index 55121865..5ae54269 100644 --- a/modules/videos/models.py +++ b/modules/videos/models.py @@ -9,7 +9,7 @@ from modules.videos.errors import UnknownVideo, UnknownChannel from wrolpi.captions import read_captions -from wrolpi.common import Base, ModelHelper, logger, get_media_directory, background_task +from wrolpi.common import Base, ModelHelper, logger, get_media_directory, background_task, truncate_object_bytes from wrolpi.db import get_db_curs, get_db_session, optional_session from wrolpi.downloader import Download, download_manager from wrolpi.files.lib import refresh_files, split_path_stem_and_suffix @@ -62,15 +62,19 @@ def __json__(self) -> dict: except Exception as e: logger.error(f'{self} ffprobe_json is invalid', exc_info=e) + # TODO these are large objects. Can they be fetched on demand? + captions = self.file_group.d_text + comments = self.get_comments() + # Put live data in "video" instead of "data" to avoid confusion on the frontend. d['video'] = dict( - caption=self.file_group.d_text, + caption=captions, caption_files=self.caption_files, channel=channel, channel_id=self.channel_id, codec_names=codec_names, codec_types=codec_types, - comments=self.get_comments(), + comments=comments, description=self.file_group.c_text or self.get_video_description(), id=self.id, info_json_file=self.info_json_file, diff --git a/modules/videos/test/test_api.py b/modules/videos/test/test_api.py index a53bbbff..527584bf 100644 --- a/modules/videos/test/test_api.py +++ b/modules/videos/test/test_api.py @@ -13,7 +13,7 @@ @pytest.mark.asyncio -async def test_refresh_videos_index(test_session, test_directory, video_factory): +async def test_refresh_videos_index(test_async_client, test_session, test_directory, video_factory): """The video modeler indexes video data into the Video's FileGroup.""" video_factory(with_video_file=True, with_caption_file=True, with_poster_ext='jpg', with_info_json=True) test_session.commit() diff --git a/modules/videos/test/test_common.py b/modules/videos/test/test_common.py index 9fe817c2..2900c570 100644 --- a/modules/videos/test/test_common.py +++ b/modules/videos/test/test_common.py @@ -10,6 +10,7 @@ from wrolpi.common import get_absolute_media_path, get_wrolpi_config from wrolpi.downloader import Download, DownloadFrequency from wrolpi.files import lib as files_lib +from wrolpi.api_utils import api_app from wrolpi.vars import PROJECT_DIR from .. import common from ..common import convert_image, update_view_counts, get_video_duration, generate_video_poster, is_valid_poster @@ -170,7 +171,8 @@ def update_channel_config(conf, source_id, d): assert download.frequency == channel1.download_frequency # Download frequency is adjusted when config file changes. 
- update_channel_config(channels_config, 'foo', {'download_frequency': DownloadFrequency.weekly}) + update_channel_config(channels_config, + 'foo', {'download_frequency': DownloadFrequency.weekly}) import_channels_config() assert len(test_session.query(Channel).all()) == 2 assert download.url == channel1.url diff --git a/modules/videos/test/test_downloader.py b/modules/videos/test/test_downloader.py index cdb575d7..f17e8068 100644 --- a/modules/videos/test/test_downloader.py +++ b/modules/videos/test/test_downloader.py @@ -125,12 +125,11 @@ async def test_download_video_tags(test_session, video_download_manager, video_f @skip_circleci -@pytest.mark.skip # TODO this test fails when running multiple tests. @pytest.mark.asyncio async def test_download_channel(test_session, simple_channel, video_download_manager, video_file, mock_video_extract_info, mock_video_prepare_filename, mock_video_process_runner): - """Downloading (updating the catalog of) a Channel updates it's info_json. + """Downloading (updating the catalog of) a Channel updates its info_json. If a Channel has `match_regex` only those videos with matching titles will be downloaded.""" url = 'https://www.youtube.com/c/LearningSelfReliance/videos' @@ -262,7 +261,7 @@ async def test_video_download(test_session, test_directory, simple_channel, vide await video_download_manager.wait_for_all_downloads() mock_video_process_runner.assert_called_once() - video_url, _, out_dir = mock_video_process_runner.call_args[0] + download_id, video_url, _, out_dir = mock_video_process_runner.call_args[0] download: Download = test_session.query(Download).one() assert video_url == download.url diff --git a/modules/videos/test/test_video.py b/modules/videos/test/test_video.py index bcdc9704..4f90ad0e 100644 --- a/modules/videos/test/test_video.py +++ b/modules/videos/test/test_video.py @@ -2,6 +2,7 @@ import pytest +from modules.videos.lib import save_channels_config from modules.videos.models import Video from wrolpi.errors import FileGroupIsTagged from wrolpi.files import lib as files_lib diff --git a/modules/videos/video/api.py b/modules/videos/video/api.py index c77f6511..ccfef2c7 100644 --- a/modules/videos/video/api.py +++ b/modules/videos/video/api.py @@ -1,14 +1,17 @@ +import json +import sys from http import HTTPStatus +import sanic.response from sanic import response, Blueprint from sanic.request import Request from sanic_ext import validate from sanic_ext.extensions.openapi import openapi +from wrolpi.api_utils import json_response, CustomJSONEncoder from wrolpi.common import logger, wrol_mode_check, run_after from wrolpi.errors import InvalidOrderBy, ValidationError from wrolpi.events import Events -from wrolpi.root_api import json_response from wrolpi.schema import JSONErrorResponse from . import lib from .. 
import schema @@ -35,7 +38,7 @@ def video_get(_: Request, video_id: int): @openapi.response(HTTPStatus.OK, schema.VideoSearchResponse) @openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse) @validate(schema.VideoSearchRequest) -async def search(_: Request, body: schema.VideoSearchRequest): +async def search_videos(_: Request, body: schema.VideoSearchRequest): if body.order_by not in lib.VIDEO_ORDERS: raise InvalidOrderBy('Invalid order by') @@ -53,8 +56,8 @@ async def search(_: Request, body: schema.VideoSearchRequest): return json_response(ret) -@video_bp.delete('/video/') -@video_bp.delete('/video/') +@video_bp.delete('/video/', name='Video Delete Many') +@video_bp.delete('/video/', name='Video Delete One') @openapi.description('Delete videos.') @openapi.response(HTTPStatus.NO_CONTENT) @openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse) diff --git a/modules/videos/video/test/test_api.py b/modules/videos/video/test/test_api.py index 7fba7e73..69f9c18c 100644 --- a/modules/videos/video/test/test_api.py +++ b/modules/videos/video/test/test_api.py @@ -155,7 +155,7 @@ async def test_wrol_mode(test_async_client, simple_channel, simple_video, wrol_m channel = dumps(dict(name=simple_channel.name, directory=str(simple_channel.directory))) tag = tag_factory() - wrol_mode_fixture(True) + await wrol_mode_fixture(True) # Can't create, update, or delete a channel. _, resp = await test_async_client.post('/api/videos/channels', content=channel) @@ -188,5 +188,5 @@ async def test_wrol_mode(test_async_client, simple_channel, simple_video, wrol_m assert test_download_manager.stopped.is_set() - wrol_mode_fixture(False) + await wrol_mode_fixture(False) assert not test_download_manager.stopped.is_set() diff --git a/modules/zim/api.py b/modules/zim/api.py index 144a7c6e..f8598308 100644 --- a/modules/zim/api.py +++ b/modules/zim/api.py @@ -1,25 +1,25 @@ import urllib.parse from http import HTTPStatus -from sanic import Request, response +from sanic import Request, response, Blueprint from sanic_ext import validate from sanic_ext.extensions.openapi import openapi from wrolpi import lang +from wrolpi.api_utils import json_response from wrolpi.common import logger from wrolpi.db import get_db_session from wrolpi.downloader import download_manager from wrolpi.events import Events -from wrolpi.root_api import get_blueprint, json_response from . 
import lib, schema from .models import Zims -bp = get_blueprint('Zim', '/api/zim') +zim_bp = Blueprint('Zim', '/api/zim') logger = logger.getChild(__name__) -@bp.get('/') +@zim_bp.get('/') @openapi.definition( summary='List all known Zim files', response=schema.GetZimsResponse, @@ -33,7 +33,7 @@ async def get_zims(_: Request): return json_response(resp) -@bp.delete('/') +@zim_bp.delete('/') @openapi.definition( summary='Delete all Zim files', ) @@ -43,7 +43,7 @@ async def delete_zims(_: Request, zim_ids: str): return response.empty() -@bp.post('/search/') +@zim_bp.post('/search/') @openapi.definition( summary='Search all entries of a Zim', body=schema.ZimSearchRequest, @@ -56,7 +56,7 @@ async def search_zim(_: Request, zim_id: int, body: schema.ZimSearchRequest): return json_response({'zim': headlines}) -@bp.get('//entry/') +@zim_bp.get('//entry/') @openapi.definition( summary='Read the entry at `zim_path` from the Zim file', ) @@ -74,7 +74,7 @@ async def get_zim_entry(_: Request, zim_id: int, zim_path: str): return resp -@bp.post('/tag') +@zim_bp.post('/tag') @openapi.definition( summary='Tag a Zim entry', body=schema.TagZimEntry, @@ -85,7 +85,7 @@ async def post_zim_tag(_: Request, body: schema.TagZimEntry): return response.empty(HTTPStatus.CREATED) -@bp.post('/untag') +@zim_bp.post('/untag') @openapi.definition( summary='Untag a Zim entry', body=schema.TagZimEntry, @@ -96,7 +96,7 @@ async def post_zim_untag(_: Request, body: schema.TagZimEntry): return response.empty(HTTPStatus.NO_CONTENT) -@bp.get('/subscribe') +@zim_bp.get('/subscribe') @openapi.definition( summary='Retrieve Zim subscriptions', response=schema.ZimSubscriptions, @@ -113,7 +113,7 @@ async def get_zim_subscriptions(_: Request): return json_response(resp) -@bp.post('/subscribe') +@zim_bp.post('/subscribe') @openapi.definition( summary='Subscribe to a particular Kiwix Zim', body=schema.ZimSubscribeRequest, @@ -128,7 +128,7 @@ async def post_zim_subscribe(_: Request, body: schema.ZimSubscribeRequest): return response.empty(HTTPStatus.CREATED) -@bp.delete('/subscribe/') +@zim_bp.delete('/subscribe/') @openapi.definition( summary='Unsubscribe to a particular Kiwix Zim', ) @@ -138,7 +138,7 @@ async def delete_zim_subscription(_: Request, subscription_id: int): return response.empty(HTTPStatus.NO_CONTENT) -@bp.get('/outdated') +@zim_bp.get('/outdated') @openapi.definition( summary='Returns the outdated and current Zim files', response=schema.OutdatedZims, @@ -149,7 +149,7 @@ async def get_outdated_zims(_: Request): return json_response(d) -@bp.delete('/outdated') +@zim_bp.delete('/outdated') @openapi.definition( summary='Remove all outdated Zims, if any.' 
) @@ -161,7 +161,7 @@ async def delete_outdated_zims(_: Request): return response.empty(HTTPStatus.NO_CONTENT) -@bp.post('/search_estimates') +@zim_bp.post('/search_estimates') @validate(json=schema.SearchEstimateRequest) async def post_search_estimates(_: Request, body: schema.SearchEstimateRequest): """Get an estimated count of FileGroups/Zims which may or may not have been tagged.""" diff --git a/modules/zim/test/test_lib.py b/modules/zim/test/test_lib.py index 2ac83f90..7e414778 100644 --- a/modules/zim/test/test_lib.py +++ b/modules/zim/test/test_lib.py @@ -15,7 +15,7 @@ @pytest.mark.asyncio -async def test_get_zim(test_session, zim_path_factory): +async def test_get_zim(test_async_client, test_session, zim_path_factory): zim_path_factory() await files_lib.refresh_files() @@ -26,7 +26,7 @@ async def test_get_zim(test_session, zim_path_factory): @pytest.mark.asyncio -async def test_zim_get_entry(test_session, zim_path_factory): +async def test_zim_get_entry(test_async_client, test_session, zim_path_factory): zim_path_factory() await files_lib.refresh_files() @@ -50,7 +50,7 @@ async def test_zim_get_entry(test_session, zim_path_factory): @pytest.mark.asyncio -async def test_zim_get_entries_tags(test_session, test_zim, tag_factory): +async def test_zim_get_entries_tags(test_async_client, test_session, test_zim, tag_factory): tag1, tag2 = tag_factory('tag1'), tag_factory('tag2') test_zim.tag_entry(tag1.name, 'one') test_zim.tag_entry(tag2.name, 'one') diff --git a/pytest.ini b/pytest.ini index 19652fb6..64d9ad94 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] -asyncio_mode=auto +asyncio_mode = auto testpaths = wrolpi modules diff --git a/requirements.txt b/requirements.txt index af1356eb..87d73fb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ Pillow==10.3.0 Pint==0.21 -PyYAML==6.0 +PyYAML==6.0.1 SQLAlchemy[asyncio]==1.3.22 -aiohttp==3.9.2 +aiohttp~=3.9.5 alembic==1.10.4 beautifulsoup4==4.12.2 cachetools==5.3.0 ebooklib==0.18 feedparser==6.0.10 libzim==3.3.0.post0 -mock==5.0.2 +mock==5.1.0 psutil==5.9.5 psycopg2-binary==2.9.6 pypdf==3.17.0 @@ -22,10 +22,10 @@ python-magic~=0.4.27 pytz==2023.3 sanic-ext==22.6.3 sanic-testing==22.6.0 -sanic==22.6.2 +sanic==23.6.0 # This is the latest version of Sanic that avoids this bug: https://github.com/sanic-org/sanic/issues/2921 selenium==4.9.1 srt==3.5.3 -vininfo[cli]==1.7.0 -websockets==10.4 +vininfo[cli]==1.8.0 +websockets==12.0 webvtt-py==0.4.6 -yt-dlp==2023.11.16 \ No newline at end of file +yt-dlp==2024.4.9 \ No newline at end of file diff --git a/wrolpi/__init__.py b/wrolpi/__init__.py index f59d4abd..e15d054f 100644 --- a/wrolpi/__init__.py +++ b/wrolpi/__init__.py @@ -17,6 +17,6 @@ def after_startup(func: callable): """ Run a function after the startup of the WROLPi Sanic API. This will be run for each process! """ - from .root_api import api_app + from wrolpi.api_utils import api_app api_app.after_server_start(func) return func diff --git a/wrolpi/api_utils.py b/wrolpi/api_utils.py new file mode 100644 index 00000000..de998765 --- /dev/null +++ b/wrolpi/api_utils.py @@ -0,0 +1,92 @@ +import json +from datetime import datetime, timezone, date +from decimal import Decimal +from functools import wraps +from pathlib import Path + +from sanic import response, HTTPResponse, Request, Sanic + +from wrolpi.common import Base, get_media_directory, logger, LOGGING_CONFIG +from wrolpi.errors import APIError + +logger = logger.getChild(__name__) + +# The only Sanic App, this is imported all over. 
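+# For example (illustrative, not part of the patch itself): other modules do `from wrolpi.api_utils import api_app`,
+# then attach blueprints with `api_app.blueprint(...)` or read shared state from `api_app.shared_ctx`.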
+api_app = Sanic(name='api_app', log_config=LOGGING_CONFIG) + + +@wraps(response.json) +def json_response(*a, **kwargs) -> HTTPResponse: + """ + Handles encoding date/datetime in JSON. + """ + resp = response.json(*a, **kwargs, cls=CustomJSONEncoder, dumps=json.dumps) + return resp + + +class CustomJSONEncoder(json.JSONEncoder): + + def default(self, obj): + try: + if hasattr(obj, '__json__'): + # Get __json__ before others. + return obj.__json__() + elif isinstance(obj, datetime): + # API always returns dates in UTC. + if obj.tzinfo: + obj = obj.astimezone(timezone.utc) + else: + # A datetime with no timezone is UTC. + obj = obj.replace(tzinfo=timezone.utc) + obj = obj.isoformat() + return obj + elif isinstance(obj, date): + # API always returns dates in UTC. + obj = datetime(obj.year, obj.month, obj.day, tzinfo=timezone.utc) + return obj.isoformat() + elif isinstance(obj, Decimal): + return str(obj) + elif isinstance(obj, Base): + if hasattr(obj, 'dict'): + return obj.dict() + elif isinstance(obj, Path): + media_directory = get_media_directory() + try: + path = obj.relative_to(media_directory) + except ValueError: + # Path may not be absolute. + path = obj + if str(path) == '.': + return '' + return str(path) + return super(CustomJSONEncoder, self).default(obj) + except Exception as e: + logger.fatal(f'Failed to JSON encode {obj}', exc_info=e) + raise + + +def get_error_json(exception: BaseException): + """Return a JSON representation of the Exception instance.""" + if isinstance(exception, APIError): + # Error especially defined for WROLPi. + body = dict(error=str(exception), summary=exception.summary, code=exception.code) + else: + # Not a WROLPi APIError error. + body = dict( + error=str(exception), + summary=None, + code=None, + ) + if exception.__cause__: + # This exception was caused by another, follow the stack. + body['cause'] = get_error_json(exception.__cause__) + return body + + +def json_error_handler(request: Request, exception: APIError): + body = get_error_json(exception) + error = repr(str(body["error"])) + summary = repr(str(body["summary"])) + code = body['code'] + logger.debug(f'API returning JSON error {exception=} {error=} {summary=} {code=}') + return json_response(body, exception.status) diff --git a/wrolpi/cmd.py b/wrolpi/cmd.py index 382ab957..f146ff01 100644 --- a/wrolpi/cmd.py +++ b/wrolpi/cmd.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Optional -from wrolpi import flags from wrolpi.common import logger from wrolpi.vars import PYTEST, DOCKERIZED @@ -16,7 +15,7 @@ def is_executable(path: Path) -> bool: return path.is_file() and os.access(path, os.X_OK) -def which(*possible_paths: str, warn: bool = False, flag: flags.Flag = None) -> Optional[Path]: +def which(*possible_paths: str, warn: bool = False) -> Optional[Path]: """ Find an executable in the system $PATH. If the executable cannot be found in $PATH, then return the first executable found in possible_paths. 
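+    For example (illustrative usage): `which('ffmpeg', '/usr/bin/ffmpeg')` returns the absolute Path of the
+    first executable found, or None when nothing executable can be found.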
@@ -39,20 +38,18 @@ def which(*possible_paths: str, warn: bool = False, flag: flags.Flag = None) -> logger.warning(f'Cannot find executable {possible_paths[0]}') return elif found: - if flag: - flag.set() return found.absolute() # Admin SUDO_BIN = which('sudo', '/usr/bin/sudo') -NMCLI_BIN = which('nmcli', '/usr/bin/nmcli', flag=flags.nmcli_installed) +NMCLI_BIN = which('nmcli', '/usr/bin/nmcli') CPUFREQ_INFO_BIN = which('cpufreq-info', '/usr/bin/cpufreq-info') CPUFREQ_SET_BIN = which('cpufreq-set', '/usr/bin/cpufreq-set') # Files -WGET_BIN = which('wget', '/usr/bin/wget', flag=flags.wget_installed) +WGET_BIN = which('wget', '/usr/bin/wget') # Map BASH_BIN = which('bash', '/bin/bash') @@ -61,17 +58,14 @@ def which(*possible_paths: str, warn: bool = False, flag: flags.Flag = None) -> SINGLE_FILE_BIN = which('single-file', '/usr/bin/single-file', # rpi os '/usr/local/bin/single-file', # debian - flag=flags.singlefile_installed, ) CHROMIUM = which('chromium-browser', 'chromium', '/usr/bin/chromium-browser', # rpi os '/usr/bin/chromium', # debian - flag=flags.chromium_installed, ) READABILITY_BIN = which('readability-extractor', '/usr/bin/readability-extractor', # rpi os '/usr/local/bin/readability-extractor', # debian - flag=flags.readability_installed, ) # Videos @@ -79,7 +73,6 @@ def which(*possible_paths: str, warn: bool = False, flag: flags.Flag = None) -> 'yt-dlp', '/usr/local/bin/yt-dlp', # Location in docker container '/opt/wrolpi/venv/bin/yt-dlp', # Use virtual environment location - flag=flags.yt_dlp_installed, ) -FFPROBE_BIN = which('ffprobe', '/usr/bin/ffprobe', flag=flags.ffprobe_installed) -FFMPEG_BIN = which('ffmpeg', '/usr/bin/ffmpeg', flag=flags.ffmpeg_installed) +FFPROBE_BIN = which('ffprobe', '/usr/bin/ffprobe') +FFMPEG_BIN = which('ffmpeg', '/usr/bin/ffmpeg') diff --git a/wrolpi/common.py b/wrolpi/common.py index 48dad949..983b90ff 100644 --- a/wrolpi/common.py +++ b/wrolpi/common.py @@ -1,10 +1,10 @@ import asyncio import atexit import contextlib -import ctypes import inspect import json import logging +import logging.config import multiprocessing import os import pathlib @@ -21,7 +21,7 @@ from functools import wraps from http import HTTPStatus from itertools import islice, filterfalse, tee -from multiprocessing import Lock, Manager +from multiprocessing.managers import DictProxy from pathlib import Path from types import GeneratorType from typing import Union, Callable, Tuple, Dict, List, Iterable, Optional, Generator, Any, Set, Coroutine @@ -42,17 +42,34 @@ from wrolpi.errors import WROLModeEnabled, NativeOnly, UnrecoverableDownloadError, LogLevelError from wrolpi.vars import PYTEST, DOCKERIZED, CONFIG_DIR, MEDIA_DIRECTORY, DEFAULT_HTTP_HEADERS -LOG_LEVEL = multiprocessing.Value(ctypes.c_int, 20) - -# Get root handler, delete any existing handlers. +LOGGING_CONFIG = { + 'version': 1, + 'disable_existing_loggers': True, + 'formatters': { + 'standard': { + 'format': '[%(asctime)s] [%(process)d] [%(name)s:%(lineno)d] [%(levelname)s] %(message)s' + }, + 'detailed': { + 'format': '[%(asctime)s] [%(process)d] [%(name)s:%(lineno)d] [%(levelname)s] %(message)s' + } + }, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': 'DEBUG', + 'formatter': 'standard', + 'stream': 'ext://sys.stdout' # Use standard output + } + }, + 'root': { + 'handlers': ['console'], + 'level': 'DEBUG' + } +} + +# Apply logging config. 
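+# (Illustrative note: this same LOGGING_CONFIG dict is also handed to Sanic via `Sanic(log_config=LOGGING_CONFIG)`
+# in wrolpi/api_utils.py, so the API workers and this module share one log format.)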
logger = logging.getLogger() -for handler in logger.handlers: - logger.removeHandler(handler) - -ch = logging.StreamHandler() -formatter = logging.Formatter('[%(asctime)s] [%(process)d] [%(name)s:%(lineno)d] [%(levelname)s] %(message)s') -ch.setFormatter(formatter) -logger.addHandler(ch) +logging.config.dictConfig(LOGGING_CONFIG) logger_ = logger.getChild(__name__) @@ -86,8 +103,9 @@ def set_global_log_level(log_level: int): """Set the global (shared between processes) log level.""" if not isinstance(log_level, int) or 0 > log_level or log_level > 40: raise LogLevelError() - with LOG_LEVEL.get_lock(): - LOG_LEVEL.value = log_level + from wrolpi.api_utils import api_app + with api_app.shared_ctx.log_level.get_lock(): + api_app.shared_ctx.log_level.value = log_level @contextlib.contextmanager @@ -103,7 +121,6 @@ def log_level_context(level): 'ConfigFile', 'DownloadFileInfo', 'DownloadFileInfoLink', - 'LOG_LEVEL', 'ModelHelper', 'WROLPI_CONFIG', 'aiohttp_get', @@ -113,6 +130,7 @@ def log_level_context(level): 'apply_modelers', 'apply_refresh_cleanup', 'background_task', + 'cancel_background_tasks', 'cancel_refresh_tasks', 'cancelable_wrapper', 'chain', @@ -123,6 +141,7 @@ def log_level_context(level): 'compile_tsvector', 'cum_timer', 'date_range', + 'LOGGING_CONFIG', 'disable_wrol_mode', 'download_file', 'enable_wrol_mode', @@ -135,9 +154,11 @@ def log_level_context(level): 'get_absolute_media_path', 'get_download_info', 'get_files_and_directories', + 'get_global_statistics', 'get_html_soup', 'get_media_directory', 'get_relative_to_media_directory', + 'get_title_from_html', 'get_warn_once', 'get_wrolpi_config', 'html_screenshot', @@ -155,6 +176,7 @@ def log_level_context(level): 'remove_whitespace', 'resolve_generators', 'run_after', + 'set_global_log_level', 'set_log_level', 'set_test_config', 'set_test_media_directory', @@ -164,6 +186,7 @@ def log_level_context(level): 'truncate_generator_bytes', 'truncate_object_bytes', 'tsvector', + 'url_strip_host', 'walk', 'wrol_mode_check', 'wrol_mode_enabled', @@ -254,7 +277,7 @@ def get_media_directory() -> Path: global TEST_MEDIA_DIRECTORY if PYTEST and not TEST_MEDIA_DIRECTORY: - raise ValueError('No test media directory set during testing!!') + raise RuntimeError('No test media directory set during testing!!') if isinstance(TEST_MEDIA_DIRECTORY, pathlib.Path): if not str(TEST_MEDIA_DIRECTORY).startswith('/tmp'): @@ -274,28 +297,28 @@ class ConfigFile: width: int = None def __init__(self): - self.file_lock = Lock() self.width = self.width or 90 if PYTEST: - # Do not load a global config on import while testing. A global instance will never be used for testing. + # Do not load a global config on import while testing. A global instance should never be used for testing. self._config = self.default_config.copy() return - self.initialize() def __repr__(self): return f'<{self.__class__.__name__} file={self.get_file()}>' - def initialize(self): + def initialize(self, multiprocessing_dict: Optional[DictProxy] = None): """Initializes this config dict using the default config and the config file.""" config_file = self.get_file() - self._config = Manager().dict() + # Use the provided multiprocessing.Manager().dict(), or dict() for testing. + self._config = multiprocessing_dict or dict() # Use the default settings to initialize the config. self._config.update(deepcopy(self.default_config)) if config_file.is_file(): # Use the config file to get the values the user set. 
with config_file.open('rt') as fh: self._config.update(yaml.load(fh, Loader=yaml.Loader)) + return self def _get_backup_filename(self): """Returns the path for the backup file for today.""" @@ -313,13 +336,16 @@ def save(self): Use the existing config file as a template; if any values are missing in the new config, use the values from the config file. """ + from wrolpi.api_utils import api_app + config_file = self.get_file() # Don't overwrite a real config while testing. if PYTEST and not str(config_file).startswith('/tmp'): raise ValueError(f'Refusing to save config file while testing: {config_file}') - # Only one process can write to the file. - acquired = self.file_lock.acquire(block=True, timeout=5.0) + # Only one process can write to a config. + lock = api_app.shared_ctx.config_save_lock + acquired = lock.acquire(block=True, timeout=5.0) try: # Config directory may not exist. @@ -341,12 +367,15 @@ def save(self): config = dict() config.update({k: v for k, v in self._config.items() if v is not None}) - logger.warning(f'Saving config: {config_file}') + logger.debug(f'Saving config: {config_file}') with config_file.open('wt') as fh: yaml.dump(config, fh, width=self.width) + # Wait for data to be written before releasing lock. + fh.flush() + os.fsync(fh.fileno()) finally: if acquired: - self.file_lock.release() + lock.release() def get_file(self) -> Path: if not self.file_name: @@ -563,7 +592,7 @@ def check(*a, **kw): return check -def enable_wrol_mode(download_manager=None): +def enable_wrol_mode(): """ Modify config to enable WROL Mode. @@ -571,15 +600,11 @@ def enable_wrol_mode(download_manager=None): """ logger_.warning('ENABLING WROL MODE') get_wrolpi_config().wrol_mode = True - if not download_manager: - from wrolpi.downloader import download_manager - download_manager.stop() - else: - # Testing. - download_manager.stop() + from wrolpi.downloader import download_manager + download_manager.stop() -def disable_wrol_mode(download_manager=None): +async def disable_wrol_mode(): """ Modify config to disable WROL Mode. @@ -587,12 +612,8 @@ def disable_wrol_mode(download_manager=None): """ logger_.warning('DISABLING WROL MODE') get_wrolpi_config().wrol_mode = False - if not download_manager: - from wrolpi.downloader import download_manager - download_manager.enable() - else: - # Testing. - download_manager.enable() + from wrolpi.downloader import download_manager + await download_manager.enable() def insert_parameter(func: Callable, parameter_name: str, item, args: Tuple, kwargs: Dict) -> Tuple[Tuple, Dict]: @@ -1276,7 +1297,13 @@ def chunks_by_stem(it: List[Union[pathlib.Path, str, int]], size: int) -> Genera @contextlib.contextmanager def timer(name, level: str = 'debug'): - """Prints out the time elapsed during the call of some block.""" + """Prints out the time elapsed during the call of some block. 
+ + Example: + with timer('sleepy'): + time.sleep(10) + + """ before = datetime.now() log_method = getattr(logger_, level) try: diff --git a/wrolpi/conftest.py b/wrolpi/conftest.py index 750d96f3..d469c3bf 100644 --- a/wrolpi/conftest.py +++ b/wrolpi/conftest.py @@ -13,10 +13,10 @@ from abc import ABC from datetime import datetime from itertools import zip_longest -from typing import List, Callable, Dict, Sequence, Union +from typing import List, Callable, Dict, Sequence, Union, Coroutine from typing import Tuple, Set from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, AsyncMock from uuid import uuid1, uuid4 import pytest @@ -28,16 +28,19 @@ from sqlalchemy.engine import Engine, create_engine from sqlalchemy.orm import Session, sessionmaker +# Import root api so blueprints are attached. +import wrolpi.root_api # noqa from wrolpi import flags -from wrolpi.common import iterify, log_level_context +from wrolpi.api_utils import api_app +from wrolpi.common import iterify, log_level_context, enable_wrol_mode, disable_wrol_mode from wrolpi.common import set_test_media_directory, Base, set_test_config +from wrolpi.contexts import attach_shared_contexts, initialize_configs_contexts from wrolpi.dates import set_test_now from wrolpi.db import postgres_engine, get_db_args from wrolpi.downloader import DownloadManager, DownloadResult, Download, Downloader, \ downloads_manager_config_context from wrolpi.errors import UnrecoverableDownloadError from wrolpi.files.models import Directory, FileGroup -from wrolpi.root_api import BLUEPRINTS, api_app from wrolpi.tags import Tag from wrolpi.vars import PROJECT_DIR @@ -96,7 +99,7 @@ def test_debug_logger(level: int = logging.DEBUG): yield -@pytest.fixture(autouse=True) +@pytest.fixture def test_directory() -> pathlib.Path: """ Overwrite the media directory with a temporary directory. @@ -109,7 +112,7 @@ def test_directory() -> pathlib.Path: yield tmp_path -@pytest.fixture(autouse=True) +@pytest.fixture def test_config(test_directory) -> pathlib.Path: """ Create a test config based off the example config. @@ -124,17 +127,14 @@ def test_config(test_directory) -> pathlib.Path: @pytest.fixture() -def test_client() -> ReusableClient: +def test_client(test_directory) -> ReusableClient: """Get a Reusable Sanic Test Client with all default routes attached. (A non-reusable client would turn on for each request) """ - global ROUTES_ATTACHED - if ROUTES_ATTACHED is False: - # Attach any blueprints for the test. - for bp in BLUEPRINTS: - api_app.blueprint(bp) - ROUTES_ATTACHED = True + attach_shared_contexts(api_app) + + initialize_configs_contexts(api_app) for _ in range(5): # Sometimes the Sanic client tries to use a port already in use, try again... @@ -151,21 +151,18 @@ def test_client() -> ReusableClient: raise RuntimeError('Test never got unused port') -@pytest.fixture() -def test_async_client() -> SanicASGITestClient: +@pytest.fixture +def test_async_client(test_directory) -> SanicASGITestClient: """Get an Async Sanic Test Client with all default routes attached.""" - global ROUTES_ATTACHED - if ROUTES_ATTACHED is False: - # Attach any blueprints for the test. 
- for bp in BLUEPRINTS: - api_app.blueprint(bp) - ROUTES_ATTACHED = True + attach_shared_contexts(api_app) + + initialize_configs_contexts(api_app) return SanicASGITestClient(api_app) @pytest.fixture -def test_download_manager_config(test_directory): +def test_download_manager_config(test_async_client, test_directory): with downloads_manager_config_context(): (test_directory / 'config').mkdir(exist_ok=True) config_path = test_directory / 'config/download_manager.yaml' @@ -174,15 +171,15 @@ def test_download_manager_config(test_directory): @pytest.fixture async def test_download_manager( + test_async_client, test_session, # session is required because downloads can start without the test DB in place. test_download_manager_config, ): + # Needed to use signals in test app? + api_app.signalize() + manager = DownloadManager() - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.get_event_loop() - manager.enable(loop) + await manager.enable() yield manager @@ -243,28 +240,36 @@ class TestDownloader(Downloader, ABC): def __repr__(self): return '' - do_download = MagicMock() + do_download = AsyncMock() def set_test_success(self): async def _(*a, **kwargs): + # Sleep so download happens after testing is waiting. + await asyncio.sleep(1) return DownloadResult(success=True) self.do_download.side_effect = _ def set_test_failure(self): async def _(*a, **kwargs): + # Sleep so download happens after testing is waiting. + await asyncio.sleep(1) return DownloadResult(success=False) self.do_download.side_effect = _ def set_test_exception(self, exception: Exception = Exception('Test downloader exception')): async def _(*a, **kwargs): + # Sleep so download happens after testing is waiting. + await asyncio.sleep(1) raise exception self.do_download.side_effect = _ def set_test_unrecoverable_exception(self): async def _(*a, **kwargs): + # Sleep so download happens after testing is waiting. 
+ await asyncio.sleep(1) raise UnrecoverableDownloadError() self.do_download.side_effect = _ @@ -313,15 +318,6 @@ def video_bytes(): return (PROJECT_DIR / 'test/big_buck_bunny_720p_1mb.mp4').read_bytes() -@pytest.fixture -def corrupted_video_file(test_directory) -> pathlib.Path: - """Return a copy of the corrupted video file in the `test_directory`.""" - destination = test_directory / f'{uuid4()}.mp4' - shutil.copy(PROJECT_DIR / 'test/corrupted.mp4', destination) - - yield destination - - @pytest.fixture def image_file(test_directory) -> pathlib.Path: """Create a small image file in the `test_directory`.""" @@ -445,16 +441,15 @@ def touch_paths(paths_): return create_files -@pytest.fixture -def wrol_mode_fixture(test_config, test_download_manager): - from wrolpi.common import enable_wrol_mode, disable_wrol_mode +async def set_wrol_mode(enabled: bool): + if enabled: + enable_wrol_mode() + else: + await disable_wrol_mode() - def set_wrol_mode(enabled: bool): - if enabled: - enable_wrol_mode(test_download_manager) - else: - disable_wrol_mode(test_download_manager) +@pytest.fixture +def wrol_mode_fixture(test_config, test_download_manager) -> Callable[[bool], Coroutine]: return set_wrol_mode @@ -477,12 +472,10 @@ async def create_subprocess_shell(*a, **kw): return mocker -@pytest.fixture(autouse=True) -def events_history(): - """Give each test it's own Events history.""" - with mock.patch('wrolpi.events.EVENTS_HISTORY', list()): - from wrolpi.events import EVENTS_HISTORY - yield EVENTS_HISTORY +@pytest.fixture +def events_history(test_async_client): + """Give each test its own Events history.""" + yield api_app.shared_ctx.events_history FLAGS_LOCK = multiprocessing.Lock() @@ -575,6 +568,7 @@ def _(file_groups: List[Dict], assert_count: bool = True): @pytest.fixture def assert_files_search(test_client): from wrolpi.test.common import assert_dict_contains + def _(search_str: str, expected: List[dict]): content = json.dumps({'search_str': search_str}) request, response = test_client.post('/api/files/search', content=content) diff --git a/wrolpi/contexts.py b/wrolpi/contexts.py new file mode 100644 index 00000000..e1bd9eb5 --- /dev/null +++ b/wrolpi/contexts.py @@ -0,0 +1,100 @@ +import ctypes +import logging +import multiprocessing + +from sanic import Sanic + + +def attach_shared_contexts(app: Sanic): + """Initializes Sanic's shared context with WROLPi's multiprocessing tools. + + This is called by main.py, and by testing.""" + # Many things wait for flags.db_up, initialize before starting. + from wrolpi import flags + app.shared_ctx.flags = multiprocessing.Manager().dict({i: False for i in flags.FLAG_NAMES}) + + # ConfigFile multiprocessing_dict's. + # Shared Configs + app.shared_ctx.wrolpi_config = multiprocessing.Manager().dict() + app.shared_ctx.tags_config = multiprocessing.Manager().dict() + app.shared_ctx.inventories_config = multiprocessing.Manager().dict() + app.shared_ctx.channels_config = multiprocessing.Manager().dict() + app.shared_ctx.download_manager_config = multiprocessing.Manager().dict() + app.shared_ctx.video_downloader_config = multiprocessing.Manager().dict() + # Shared dicts. 
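+    # (Illustrative note: Manager() dicts are proxy objects, so every forked Sanic worker reads and
+    # writes the same underlying data rather than a per-process copy.)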
+ app.shared_ctx.refresh = multiprocessing.Manager().dict() + app.shared_ctx.uploaded_files = multiprocessing.Manager().dict() + app.shared_ctx.bandwidth = multiprocessing.Manager().dict() + app.shared_ctx.disks_bandwidth = multiprocessing.Manager().dict() + app.shared_ctx.max_disks_bandwidth = multiprocessing.Manager().dict() + app.shared_ctx.map_importing = multiprocessing.Manager().dict() + # Shared lists. + app.shared_ctx.events_history = multiprocessing.Manager().list() + # Shared ints + app.shared_ctx.log_level = multiprocessing.Value(ctypes.c_int, logging.DEBUG) + + # Download Manager + app.shared_ctx.download_manager_data = multiprocessing.Manager().dict() + app.shared_ctx.download_manager_queue = multiprocessing.Queue() + app.shared_ctx.download_manager_disabled = multiprocessing.Event() + app.shared_ctx.download_manager_stopped = multiprocessing.Event() + + # Events. + app.shared_ctx.single_tasks_started = multiprocessing.Event() + app.shared_ctx.flags_initialized = multiprocessing.Event() + + # Locks + app.shared_ctx.config_save_lock = multiprocessing.Lock() + + reset_shared_contexts(app) + + +def reset_shared_contexts(app: Sanic): + """Resets shared contexts (dicts/lists/Events,etc.). + + Should only be called when server is starting, or could start back up.""" + # Should only be called when server is expected to start again. + app.shared_ctx.wrolpi_config.clear() + app.shared_ctx.tags_config.clear() + app.shared_ctx.inventories_config.clear() + app.shared_ctx.channels_config.clear() + app.shared_ctx.download_manager_config.clear() + app.shared_ctx.video_downloader_config.clear() + # Shared dicts. + app.shared_ctx.refresh.clear() + app.shared_ctx.uploaded_files.clear() + app.shared_ctx.bandwidth.clear() + app.shared_ctx.disks_bandwidth.clear() + app.shared_ctx.max_disks_bandwidth.clear() + app.shared_ctx.map_importing.clear() + # Shared ints + app.shared_ctx.log_level.value = logging.DEBUG + + # Download Manager + app.shared_ctx.download_manager_data.clear() + app.shared_ctx.download_manager_data.update(dict( + processing_domains=[], + killed_downloads=[], + )) + + # Events. 
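+    # (Illustrative: clearing these Events returns them to the "not set" state, so a restarting server
+    # does not believe downloads are already stopped or disabled.)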
+ app.shared_ctx.single_tasks_started.clear() + app.shared_ctx.download_manager_disabled.clear() + app.shared_ctx.download_manager_stopped.clear() + app.shared_ctx.flags_initialized.clear() + + +def initialize_configs_contexts(app: Sanic): + """Assign multiprocessing Dicts to their respective FileConfigs in this process.""" + from modules.inventory.common import INVENTORIES_CONFIG + from modules.videos.lib import CHANNELS_CONFIG + from modules.videos.lib import VIDEO_DOWNLOADER_CONFIG + from wrolpi.common import WROLPI_CONFIG + from wrolpi.tags import TAGS_CONFIG + from wrolpi.downloader import DOWNLOAD_MANAGER_CONFIG + INVENTORIES_CONFIG.initialize(app.shared_ctx.inventories_config) + CHANNELS_CONFIG.initialize(app.shared_ctx.channels_config) + VIDEO_DOWNLOADER_CONFIG.initialize(app.shared_ctx.video_downloader_config) + WROLPI_CONFIG.initialize(app.shared_ctx.wrolpi_config) + TAGS_CONFIG.initialize(app.shared_ctx.tags_config) + DOWNLOAD_MANAGER_CONFIG.initialize(app.shared_ctx.download_manager_config) diff --git a/wrolpi/dates.py b/wrolpi/dates.py index 7457f8f2..6aed40db 100644 --- a/wrolpi/dates.py +++ b/wrolpi/dates.py @@ -60,6 +60,9 @@ def strpdate(dt: str) -> datetime: if len(a) == 4: # Y/m/d return datetime.strptime(dt, '%Y/%m/%d') + if len(a) == 2 and len(b) == 2 and len(c) == 13: + # Assume d/m/Y HH:MM:SS + return datetime.strptime(dt, '%m/%d/%Y %H:%M:%S') except ValueError: pass elif dt.count('-') == 2 and len(dt) <= 10: diff --git a/wrolpi/db.py b/wrolpi/db.py index 12379e02..33d8dfea 100644 --- a/wrolpi/db.py +++ b/wrolpi/db.py @@ -100,17 +100,12 @@ def get_db_session(commit: bool = False) -> ContextManager[Session]: @contextmanager -def get_db_curs(commit: bool = False): - """ - Context manager that yields a DictCursor to execute raw SQL statements. - """ +def get_db_conn(isolation_level=psycopg2.extensions.ISOLATION_LEVEL_DEFAULT): local_engine, session = get_db_context() connection = local_engine.raw_connection() - curs = connection.cursor(cursor_factory=psycopg2.extras.DictCursor) + connection.set_isolation_level(isolation_level) try: - yield curs - if commit: - connection.commit() + yield connection, session except sqlalchemy.exc.DatabaseError: session.rollback() raise @@ -120,6 +115,26 @@ def get_db_curs(commit: bool = False): connection.rollback() +@contextmanager +def get_db_curs(commit: bool = False, isolation_level=psycopg2.extensions.ISOLATION_LEVEL_DEFAULT): + """ + Context manager that yields a DictCursor to execute raw SQL statements. + """ + with get_db_conn(isolation_level=isolation_level) as (connection, session): + curs = connection.cursor(cursor_factory=psycopg2.extras.DictCursor) + try: + yield curs + if commit: + connection.commit() + except sqlalchemy.exc.DatabaseError: + session.rollback() + raise + finally: + # Rollback only if a transaction hasn't been committed. + if session.transaction.is_active: + connection.rollback() + + def optional_session(commit: Union[callable, bool] = False) -> callable: """ Wraps a function, if a Session is passed it will be used. 
Otherwise, a new session will be diff --git a/wrolpi/downloader.py b/wrolpi/downloader.py index c163b203..b7ea6617 100644 --- a/wrolpi/downloader.py +++ b/wrolpi/downloader.py @@ -11,7 +11,6 @@ from datetime import timedelta, datetime from enum import Enum from itertools import filterfalse -from queue import Empty from typing import List, Dict, Generator, Iterable, Coroutine from typing import Tuple, Optional from urllib.parse import urlparse @@ -22,18 +21,44 @@ from sqlalchemy import Column, Integer, String, Text from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Session +from sqlalchemy.sql import Delete from wrolpi import flags +from wrolpi.api_utils import api_app from wrolpi.common import Base, ModelHelper, logger, wrol_mode_check, zig_zag, ConfigFile, WROLPI_CONFIG, \ - background_task, limit_concurrent + limit_concurrent, wrol_mode_enabled from wrolpi.dates import TZDateTime, now, Seconds from wrolpi.db import get_db_session, get_db_curs, optional_session from wrolpi.errors import InvalidDownload, UnrecoverableDownloadError -from wrolpi.vars import PYTEST +from wrolpi.vars import PYTEST, SIMULTANEOUS_DOWNLOAD_DOMAINS logger = logger.getChild(__name__) +async def perpetual_download_worker(): + logger.debug('perpetual_download waiting for db...') + async with flags.db_up.wait_for(): + pass + + logger.debug('perpetual_download starting...') + + while True: + try: + logger.debug('perpetual_download is running') + await asyncio.sleep(10) + + if download_manager.is_stopped: + logger.warning('DownloadManager stopped, quitting...') + return + + await download_manager.do_downloads() + except asyncio.CancelledError: + logger.info('perpetual_download cancelled...') + return + except Exception as e: + logger.error('perpetual_download failed from unexpected error', exc_info=e) + + class DownloadFrequency(int, Enum): hourly = 3600 hours3 = hourly * 3 @@ -56,6 +81,14 @@ class DownloadResult: settings: dict = field(default_factory=dict) +class DownloadStatus(str, Enum): + new = 'new' + pending = 'pending' + complete = 'complete' + failed = 'failed' + deferred = 'deferred' + + class Download(ModelHelper, Base): # noqa """Model that is used to schedule downloads.""" __tablename__ = 'download' # noqa @@ -72,7 +105,7 @@ class Download(ModelHelper, Base): # noqa location = Column(Text) # Relative App URL where the item is downloaded next_download = Column(TZDateTime) settings = Column(JSONB) # information about how the download should happen (destination, etc.) - status = Column(String, default='new') # 'new', 'pending', 'complete', 'failed', 'deferred' + status = Column(String, default=DownloadStatus.new) # `DownloadStatus` enum. _manager: 'DownloadManager' = None def __init__(self, *args, **kwargs): @@ -103,49 +136,54 @@ def __json__(self): def renew(self, reset_attempts: bool = False): """Mark this Download as "new" so it will be retried.""" - self.status = 'new' + self.status = DownloadStatus.new if reset_attempts: self.attempts = 0 + @property def is_new(self) -> bool: - return self.status == 'new' + return self.status == DownloadStatus.new def defer(self): """Download should be tried again after a time.""" - self.status = 'deferred' + self.status = DownloadStatus.deferred + @property def is_deferred(self) -> bool: - return self.status == 'deferred' + return self.status == DownloadStatus.deferred def fail(self): """Download should not be attempted again. 
A recurring Download will raise an error.""" if self.frequency: raise ValueError('Recurring download should not be failed.') - self.status = 'failed' + self.status = DownloadStatus.failed + @property def is_failed(self) -> bool: - return self.status == 'failed' + return self.status == DownloadStatus.failed def started(self): """Mark this Download as in progress.""" self.attempts += 1 - self.status = 'pending' + self.status = DownloadStatus.pending + @property def is_pending(self) -> bool: - return self.status == 'pending' + return self.status == DownloadStatus.pending def complete(self): """Mark this Download as successfully downloaded.""" - self.status = 'complete' + self.status = DownloadStatus.complete self.error = None # clear any old errors self.last_successful_download = now() + @property def is_complete(self) -> bool: - return self.status == 'complete' + return self.status == DownloadStatus.complete def get_downloader(self): if self.downloader: - return self.manager.get_downloader_by_name(self.downloader) + return download_manager.get_downloader_by_name(self.downloader) raise UnrecoverableDownloadError(f'Cannot find downloader for {repr(str(self.url))}') @@ -153,16 +191,6 @@ def get_downloader(self): def domain(self): return urlparse(self.url).netloc - @property - def manager(self) -> 'DownloadManager': - if self._manager: - return self._manager - raise ValueError('No manager has been set!') - - @manager.setter - def manager(self, value): - self._manager = value - def filter_excluded(self, urls: List[str]) -> List[str]: """Return any URLs that do not match my excluded_urls.""" if self.settings and (excluded_urls := self.settings.get('excluded_urls')): @@ -173,10 +201,7 @@ def excluded(url: str): return urls def add_to_skip_list(self): - if self.manager: - self.manager.add_to_skip_list(self.url) - else: - raise RuntimeError(f'Cannot add {self} to skip list because I do not have a manager.') + download_manager.add_to_skip_list(self.url) class Downloader: @@ -191,7 +216,6 @@ def __init__(self, name: str = None, timeout: int = None): self.name: str = self.name or name self.timeout: int = timeout or self.timeout - self._kill = multiprocessing.Event() self._manager: DownloadManager = None # noqa @@ -212,25 +236,10 @@ def already_downloaded(self, *urls: List[str], session: Session = None): @property def manager(self): - if self._manager is None: - raise NotImplementedError('This needs to be registered, see DownloadManager.register_downloader') - return self._manager - - @manager.setter - def manager(self, value): - self._manager = value - - def kill(self): - """Kill the running download for this Downloader.""" - if not self._kill.is_set(): - self._kill.set() - - def clear(self): - """Clear any "kill" request for this Downloader.""" - if self._kill.is_set(): - self._kill.clear() + return download_manager - async def process_runner(self, url: str, cmd: Tuple[str, ...], cwd: pathlib.Path, timeout: int = None, + async def process_runner(self, download_id: int, url: str, cmd: Tuple[str, ...], cwd: pathlib.Path, + timeout: int = None, **kwargs) -> Tuple[int, dict, bytes]: """ Run a subprocess using the provided arguments. This process can be killed by the Download Manager. 
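+        (Illustrative: `download_id` lets DownloadManager.kill_download() flag this download; the loop below
+        checks download_manager.download_is_killed(download_id) and kills the subprocess when it is set.)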
@@ -271,9 +280,9 @@ async def process_runner(self, url: str, cmd: Tuple[str, ...], cwd: pathlib.Path elapsed = (now() - start).total_seconds() if timeout and elapsed > timeout: logger.warning(f'Download has exceeded its timeout {elapsed=}') - self.kill() + download_manager.kill_download(download_id) - if self._kill.is_set(): + if download_manager.download_is_killed(download_id): logger.warning(f'Killing download {pid=}, {elapsed} seconds elapsed (timeout was not exceeded).') proc.kill() break @@ -281,8 +290,6 @@ async def process_runner(self, url: str, cmd: Tuple[str, ...], cwd: pathlib.Path logger.error(f'{self}.process_runner had a download error', exc_info=e) raise finally: - self.clear() - # Output all logs from the process. # TODO is there a way to stream this output while the process is running? logger.debug(f'Download exited with {proc.returncode}') @@ -290,26 +297,29 @@ async def process_runner(self, url: str, cmd: Tuple[str, ...], cwd: pathlib.Path return proc.returncode, logs, stdout - async def cancel_wrapper(self, coro: Coroutine, download: Download): + @staticmethod + async def cancel_wrapper(coro: Coroutine, download: Download): """ Converts an async coroutine to a task. If DownloadManager receives a kill request, this method will cancel the task. """ + download_id = download.id task = asyncio.create_task(coro) while not task.done(): - if self._kill.is_set(): + if download_manager.download_is_killed( + download_id) or download_manager.is_disabled or download_manager.is_stopped: logger.warning(f'Cancel download of {download.url}') task.cancel() try: await task except asyncio.CancelledError as e: - logger.debug(f'Successful cancel of {download.url}', exc_info=e) + logger.info(f'Successful cancel of {download.url}', exc_info=e) return DownloadResult( success=False, error='Download was canceled', ) finally: - self.clear() + download_manager.unkill_download(download_id) else: # Wait for the download to complete. Cancel if requested. await asyncio.sleep(0.1) @@ -330,141 +340,58 @@ class DownloadManager: def __init__(self): self.instances: Tuple[Downloader] = tuple() self._instances = dict() - self.disabled = multiprocessing.Event() - self.stopped = multiprocessing.Event() - self.download_queue: multiprocessing.Queue = multiprocessing.Queue() self.workers: List[Dict] = [] self.worker_count: int = 1 self.worker_alive_frequency = timedelta(minutes=10) - self.data = multiprocessing.Manager().dict() - # We haven't started downloads yet, so no domains are downloading. - self.data['processing_domains'] = [] - self.data['workers'] = dict() - def __repr__(self): return f'' - def register_downloader(self, instance: Downloader): - if not isinstance(instance, Downloader): - raise ValueError(f'Invalid downloader cannot be registered! {instance=}') - if instance in self.instances: - raise ValueError(f'Downloader already registered! {instance=}') - - instance.manager = self - - self.instances = (*self.instances, instance) - self._instances[instance.name] = instance - - async def download_worker(self, num: int): - """Fetch a download from the queue, perform the download then store the results. - - Calls DownloadManger.start_downloads() after a download completes. - """ - from wrolpi.db import get_db_session - - pid = os.getpid() - name = f'{pid}.{num}' - worker_logger = logger.getChild(f'download_worker.{name}') - - disabled = 'disabled' if self.disabled.is_set() else 'enabled' - worker_logger.info(f'Starting up. 
DownloadManager is {disabled}.') - last_heartbeat = now() + @property + def disabled(self): + # Server is going to keep running, but downloads should stop. + return api_app.shared_ctx.download_manager_disabled - while True: - if self.stopped.is_set(): - # Service may be restarting, close the worker. - worker_logger.warning("DownloadManager is stopped. I'm stopping.") - return + @property + def is_disabled(self): + return self.disabled.is_set() - if now() - last_heartbeat > self.worker_alive_frequency: - last_heartbeat = now() + @property + def stopped(self): + # Server is stopping and perpetual download should stop. + return api_app.shared_ctx.download_manager_stopped - disabled = self.disabled.is_set() + @property + def is_stopped(self): + return self.stopped.is_set() - if disabled: - # Downloading is disabled, wait for it to enable. - await asyncio.sleep(1) - continue + @property + def download_queue(self) -> multiprocessing.Queue: + return api_app.shared_ctx.download_manager_queue - try: - download_id, url = self.download_queue.get_nowait() - worker_logger.debug(f'Got download {download_id}') + @property + def processing_domains(self): + return api_app.shared_ctx.download_manager_data['processing_domains'] - with get_db_session(commit=True) as session: - # Mark the download as started in new session so the change is committed. - download = session.query(Download).filter_by(id=download_id).one() - download.started() + @processing_domains.setter + def processing_domains(self, value: list): + api_app.shared_ctx.download_manager_data.update({'processing_domains': value}) - # Set the Download's manager. Testing will not use the global manager. - download.manager = self + def _add_processing_domain(self, domain: str): + self.processing_domains = list(self.processing_domains) + [domain, ] - downloader: Downloader = download.get_downloader() - if not downloader: - worker_logger.warning(f'Could not find downloader for {download.downloader=}') + def _delete_processing_domain(self, domain: str): + self.processing_domains = [i for i in self.processing_domains if i != domain] - self.data['processing_domains'].append(download.domain) + def register_downloader(self, instance: Downloader): + if not isinstance(instance, Downloader): + raise ValueError(f'Invalid downloader cannot be registered! {instance=}') + if instance in self.instances: + raise ValueError(f'Downloader already registered! {instance=}') - try_again = True - try: - # Create download coroutine; wrap it, so it can be canceled - coro = downloader.do_download(download) - if not inspect.iscoroutine(coro): - raise RuntimeError(f'Coroutine expected from {downloader} do_download method.') - result = await downloader.cancel_wrapper(coro, download) - except UnrecoverableDownloadError as e: - # Download failed and should not be retried. - worker_logger.warning(f'UnrecoverableDownloadError for {url}', exc_info=e) - result = DownloadResult(success=False, error=str(traceback.format_exc())) - try_again = False - except Exception as e: - worker_logger.warning(f'Failed to download {url}. Will be tried again later.', exc_info=e) - result = DownloadResult(success=False, error=str(traceback.format_exc())) - - error_len = len(result.error) if result.error else 0 - worker_logger.debug( - f'Got success={result.success} from {downloader} download_id={download.id} with {error_len=}') - - with get_db_session(commit=True) as session: - # Modify the download in a new session because downloads may take a long time. 
- download: Download = session.query(Download).filter_by(id=download_id).one() - # Use a new location if provided, keep the old location if no new location is provided, otherwise - # clear out an outdated location. - download.location = result.location or download.location or None - # Clear any old errors if the download succeeded. - download.error = result.error if result.error else None - download.next_download = self.calculate_next_download(download, session) - - if result.downloads: - worker_logger.info(f'Adding {len(result.downloads)} downloads from result of {download.url}') - urls = download.filter_excluded(result.downloads) - self.create_downloads(urls, session, downloader_name=download.sub_downloader, - settings=result.settings) - - if try_again is False and not download.frequency: - # Only once-downloads can fail. - download.fail() - elif result.success: - download.complete() - else: - download.defer() - - # Remove this domain from the running list. - self._remove_domain(download.domain) - # Request any new downloads be added to the queue. - background_task(self.queue_downloads()) - # Save the config now that the Download has finished. - background_task(save_downloads_config()) - except asyncio.CancelledError as e: - worker_logger.warning('Canceled!', exc_info=e) - self.download_queue.task_done() - return - except Empty: - # No work yet. - await asyncio.sleep(0.1) - except Exception as e: - worker_logger.warning(f'Unexpected error', exc_info=e) + self.instances = (*self.instances, instance) + self._instances[instance.name] = instance def log(self, message: str, level=logging.DEBUG, exc_info=None): logger.log(level, f'{self} {message}', exc_info=exc_info) @@ -481,19 +408,6 @@ def log_error(self, message: str, exc_info=None): def log_warning(self, message: str): return self.log(message, logging.WARNING) - def _add_domain(self, domain: str): - """Add a domain to the processing list. - - Raises ValueError if the domain is already processing.""" - processing_domains = self.data['processing_domains'] - if domain in processing_domains: - raise ValueError(f'Domain already being downloaded! {domain}') - self.data['processing_domains'] = [*processing_domains, domain] - - def _remove_domain(self, domain: str): - """Remove a domain from the processing list.""" - self.data['processing_domains'] = [i for i in self.data['processing_domains'] if i != domain] - @wrol_mode_check def start_workers(self, loop=None): """Start all download worker tasks. Does nothing if they are already running.""" @@ -510,40 +424,6 @@ def start_workers(self, loop=None): task = loop.create_task(coro) self.workers.append(task) - def workers_running(self): - for task in self.workers: - if not task.done(): - return True - return False - - def cancel_workers(self): - if self.workers_running(): - for task in self.workers: - task.cancel() - - async def perpetual_download(self): - """ - A method that calls itself forever. It will queue new downloads when they are ready. - - Only one of these can be running at a time. - """ - if self.manager.is_set(): - # Only one manager needs to be running. 
- return - - self.manager.set() - - async def _perpetual_download(): - self.log_debug(f'perpetual download is alive with {self.download_queue.qsize()} queued downloads') - if self.stopped.is_set(): - return - - await download_manager.do_downloads() - await asyncio.sleep(30) - background_task(_perpetual_download()) - - background_task(_perpetual_download()) - def get_downloader_by_name(self, name: str) -> Optional[Downloader]: """Attempt to find a registered Downloader by its name. Returns None if it cannot be found.""" if downloader := self._instances.get(name): @@ -566,7 +446,6 @@ def get_or_create_download(self, url: str, session: Session, reset_attempts: boo download = Download(url=url, status='new') session.add(download) session.flush() - download.manager = self return download @optional_session @@ -599,9 +478,9 @@ def create_downloads(self, urls: List[str], downloader_name: str, session: Sessi try: # Start downloading ASAP. - background_task(self.queue_downloads()) + api_app.add_task(self.dispatch_downloads()) # Save the config now that new Downloads exist. - background_task(save_downloads_config()) + api_app.add_task(save_downloads_config()) except RuntimeError: # Event loop isn't running. Probably testing? if not PYTEST: @@ -636,14 +515,14 @@ def recurring_download(self, url: str, frequency: int, downloader_name: str, ses @wrol_mode_check @optional_session - async def queue_downloads(self, session: Session = None): - """Put all downloads in queue. Will only queue downloads if there are workers to take them. Each worker - only receives one domain, this is to prevent downloading from one domain many times at once.""" - if self.disabled.is_set() or self.stopped.is_set(): + async def dispatch_downloads(self, session: Session = None): + """Dispatch Sanic signals to start downloads. This only starts as many downloads as the + SIMULTANEOUS_DOWNLOAD_DOMAINS variable.""" + if self.is_disabled or self.is_stopped: # Don't queue downloads when disabled. return - if (domains := len(self.data['processing_domains'])) >= 4: + if (domains := len(self.processing_domains)) > SIMULTANEOUS_DOWNLOAD_DOMAINS: self.log_debug( f'Unable to queue downloads because there are more domains than workers: {domains} >= 4') return @@ -651,18 +530,18 @@ async def queue_downloads(self, session: Session = None): # Find download whose domain isn't already being downloaded. new_downloads = list(session.query(Download).filter( Download.status == 'new', - Download.domain not in self.data['processing_domains'], + Download.domain not in self.processing_domains, ).order_by( Download.frequency.is_(None), Download.frequency, Download.id)) # noqa count = 0 for download in new_downloads: - download.manager = self # Assign this Download to this manager. domain = download.domain - if domain not in self.data['processing_domains']: - self._add_domain(domain) - self.download_queue.put((download.id, download.url)) + if domain not in self.processing_domains and len(self.processing_domains) < SIMULTANEOUS_DOWNLOAD_DOMAINS: + self._add_processing_domain(domain) + context = dict(download_id=download.id, download_url=download.url) + await api_app.dispatch('wrolpi.download.download', context=context) count += 1 if count: self.log_debug(f'Added {count} downloads to queue.') @@ -672,7 +551,7 @@ async def do_downloads(self): Warning: Downloads will still be running even after this returns! See `wait_for_all_downloads`. 
""" - if self.disabled.is_set(): + if self.disabled.is_set() or wrol_mode_enabled(): return try: @@ -685,31 +564,34 @@ async def do_downloads(self): except Exception as e: self.log_error(f'Unable to delete old downloads!', exc_info=e) - await self.queue_downloads() + await self.dispatch_downloads() - async def wait_for_all_downloads(self): - """Wait for all Downloads in queue AND any new Downloads to complete. + async def wait_for_all_downloads(self, timeout: int = 10): + """Signals start of all pending Downloads, waits for all Downloads to be processed. - THIS METHOD IS FOR TESTING. - """ - while True: - await asyncio.sleep(0.1) + @param timeout: Give up waiting after this many seconds. + @raises TimeoutError: If timeout is exceeded. - if not self.workers_running(): - raise ValueError('No workers are running!') + @warning: THIS METHOD IS FOR TESTING. + """ + # Use real datetime.now to avoid `fake_now`. + start = datetime.now() - await self.queue_downloads() + while (datetime.now() - start).total_seconds() < timeout: + # Send out download signals. + await self.dispatch_downloads() - try: - next(self.get_new_downloads()) - continue - except StopIteration: - # Give any new downloads a chance to start up. - await asyncio.sleep(0.1) + # Wait for processes to start. + await asyncio.sleep(1) - if self.download_queue.empty(): - # Queue is empty. Wait for background tasks. - break + # Break out of loop only when all downloads have been processed. + with get_db_session() as session: + statuses = {i.status for i in self.get_downloads(session)} + if DownloadStatus.new not in statuses and DownloadStatus.pending not in statuses: + # All downloads must be complete/deferred/failed. + break + else: + raise TimeoutError('Downloads never finished!') @staticmethod def reset_downloads(): @@ -717,7 +599,13 @@ def reset_downloads(): with get_db_curs(commit=True) as curs: curs.execute("UPDATE download SET status='new' WHERE status='pending' OR status='deferred'") - DOWNLOAD_SORT = ('pending', 'failed', 'new', 'deferred', 'complete') + DOWNLOAD_SORT = ( + DownloadStatus.pending, + DownloadStatus.failed, + DownloadStatus.new, + DownloadStatus.deferred, + DownloadStatus.complete, + ) @optional_session def get_new_downloads(self, session: Session) -> Generator[Download, None, None]: @@ -734,7 +622,6 @@ def get_new_downloads(self, session: Session) -> Generator[Download, None, None] # Got the last download again. Is something wrong? return last = download - download.manager = self yield download @optional_session @@ -784,15 +671,15 @@ def renew_recurring_downloads(self, session: Session = None): session.commit() # Save the config now that some Downloads renewed. 
- background_task(save_downloads_config()) + api_app.add_task(save_downloads_config()) - def get_downloads(self, session: Session) -> List[Download]: + @staticmethod + def get_downloads(session: Session) -> List[Download]: downloads = list(session.query(Download).all()) - for download in downloads: - download.manager = self return downloads - def get_download(self, session: Session, url: str = None, id_: int = None) -> Optional[Download]: + @staticmethod + def get_download(session: Session, url: str = None, id_: int = None) -> Optional[Download]: """Attempt to find a Download by its URL or by its id.""" query = session.query(Download) if url: @@ -800,9 +687,7 @@ def get_download(self, session: Session, url: str = None, id_: int = None) -> Op elif id: download = query.filter_by(id=id_).one_or_none() else: - raise ValueError('Cannot find download without some params.') - if download: - download.manager = self + raise RuntimeError('Cannot find download without some params.') return download @optional_session @@ -830,50 +715,46 @@ def restart_download(self, download_id: int, session: Session = None) -> Downloa def kill_download(self, download_id: int): """Fail a Download. If it is pending, kill the Downloader so the download stops.""" + logger.info(f'Killing Download: {download_id}') + download_manager_data = api_app.shared_ctx.download_manager_data.copy() + download_manager_data['killed_downloads'] = download_manager_data['killed_downloads'] + [download_id, ] + api_app.shared_ctx.download_manager_data.update(download_manager_data) + with get_db_session(commit=True) as session: - download = self.get_download(session, id_=download_id) - downloader = download.get_downloader() - self.log_warning(f'Killing download {download_id} in {downloader}') - if download.is_pending(): - downloader.kill() - download.fail() + if download := self.get_download(session, id_=download_id): + download.error = 'User stopped this download' + download.fail() + + @staticmethod + def unkill_download(download_id: int): + """Remove a Download from the killed_downloads list. This allows it to be run again.""" + download_manager_data = api_app.shared_ctx.download_manager_data.copy() + download_manager_data['killed_downloads'] = \ + [i for i in download_manager_data['killed_downloads'] if i != download_id] + api_app.shared_ctx.download_manager_data.update(download_manager_data) + + @staticmethod + def download_is_killed(download_id: int): + return download_id in api_app.shared_ctx.download_manager_data['killed_downloads'] def disable(self): - """Stop all downloads and downloaders. Workers will stay idle.""" - self.log_info('Disabling downloads and downloaders.') - self.disabled.set() - for downloader in self.instances: - downloader.kill() - if flags.db_up.is_set(): - # Only defer downloads if the DB is up. - for download in self.get_pending_downloads(): - download.defer() - self.cancel_workers() + """Stop all downloads and downloaders.""" + if not self.disabled.is_set(): + self.disabled.set() def stop(self): - """Stop all downloads, downloaders and workers, defer all pending downloads.""" - self.log_warning('Stopping all workers') - self.stopped.set() + """Stop all downloads, downloaders and workers. This is called when the server is shutting down.""" + if not self.is_stopped: + self.stopped.set() self.disable() - def enable(self, loop=None): - """Enable downloading. Start downloading. Start workers.""" + async def enable(self): + """Enable downloading. 
Start downloading.""" self.log_info('Enabling downloading') - for downloader in self.instances: - downloader.clear() self.stopped.clear() self.disabled.clear() - self.start_workers(loop) - try: - background_task(self.perpetual_download()) - background_task(self.do_downloads()) - except RuntimeError: - # This may not work while testing. - if not PYTEST: - raise - - FINISHED_STATUSES = ('complete', 'failed') + FINISHED_STATUSES = (DownloadStatus.complete, DownloadStatus.failed) def delete_old_once_downloads(self): """Delete all once-downloads that have expired. @@ -901,7 +782,7 @@ def list_downloaders(self) -> List[Downloader]: END''' def get_fe_downloads(self): - """Get downloads for the Frontend.""" + """Get downloads for the Frontend. Uses raw SQL for faster result.""" # Use custom SQL because SQLAlchemy is slow. with get_db_curs() as curs: stmt = f''' @@ -982,16 +863,14 @@ def get_summary(self) -> dict: summary = dict( pending=counts['pending_downloads'], recurring=counts['recurring_downloads'], - disabled=self.disabled.is_set(), - stopped=self.stopped.is_set(), + disabled=self.is_disabled, + stopped=self.is_stopped, ) return summary @optional_session def get_pending_downloads(self, session: Session) -> List[Download]: - downloads = session.query(Download).filter_by(status='pending').all() - for download in downloads: - download.manager = self + downloads = session.query(Download).filter_by(status=DownloadStatus.pending).all() return downloads @staticmethod @@ -1003,7 +882,7 @@ def calculate_next_download(download: Download, session: Session) -> Optional[da If the download is "complete" and the download has a frequency, schedule the download in it's next iteration. (Next week, month, etc.) """ - if download.is_deferred(): + if download.is_deferred: # Increase next_download slowly at first, then by large gaps later. The largest gap is the download # frequency. hours = 3 ** (download.attempts or 1) @@ -1040,30 +919,42 @@ def calculate_next_download(download: Download, session: Session) -> Optional[da next_download = next_download return next_download + @staticmethod + def _delete_downloads_q(once: bool = False, status: str = None, returning=Download.id) -> Delete: + stmt = Download.__table__.delete().returning(returning) + if once: + stmt = stmt.where(Download.frequency == None) + if status: + stmt = stmt.where(Download.status == status) + return stmt + @optional_session - def delete_completed(self, session: Session): + def delete_completed(self, session: Session) -> List[int]: """Delete any completed download records.""" - session.query(Download).filter( - Download.status == 'complete', - Download.frequency == None, # noqa - ).delete() + stmt = self._delete_downloads_q(once=True, status=DownloadStatus.complete) + deleted_ids = [i for i, in session.execute(stmt).fetchall()] session.commit() + return deleted_ids @optional_session def delete_failed(self, session: Session): """Delete any failed download records.""" - failed_downloads = session.query(Download).filter( - Download.status == 'failed', - Download.frequency == None, # noqa - ).all() + stmt = self._delete_downloads_q(once=True, status=DownloadStatus.failed, returning=Download.url) + deleted_urls = [i for i, in session.execute(stmt).fetchall()] # Add all downloads to permanent skip list. - ids = [i.id for i in failed_downloads] - self.add_to_skip_list(*(i.url for i in failed_downloads)) + self.add_to_skip_list(*deleted_urls) + + session.commit() - # Delete all failed once-downloads. 
- session.execute('DELETE FROM download WHERE id = ANY(:ids)', {'ids': ids}) + @optional_session + def delete_once(self, session: Session): + """Delete any once-download records.""" + stmt = self._delete_downloads_q(once=True) + deleted_ids = [i for i, in session.execute(stmt).fetchall()] session.commit() + api_app.add_task(save_downloads_config()) + return deleted_ids @staticmethod def is_skipped(*urls: str) -> bool: @@ -1081,12 +972,103 @@ def remove_from_skip_list(url: str): get_download_manager_config().skip_urls = [i for i in get_download_manager_config().skip_urls if i != url] get_download_manager_config().save() + @staticmethod + def get_download_by_url(url: str) -> Optional[Download]: + with get_db_session() as session: + download = session.query(Download).filter_by(url=url).one_or_none() + return download + # The global DownloadManager. This should be used everywhere! download_manager = DownloadManager() -class DownloadMangerConfig(ConfigFile): +@api_app.signal('wrolpi.download.download') +async def signal_download_download(download_id: int, download_url: str): + """Calls Downloaders based on the download information provided, as well as what is in the DB.""" + from wrolpi.db import get_db_session + + url = download_url + download_domain = None + + name = f'download_worker' + worker_logger = logger.getChild(name) + + try: + worker_logger.debug(f'Got download {download_id}') + + with get_db_session(commit=True) as session: + # Mark the download as started in new session so the change is committed. + download = session.query(Download).filter_by(id=download_id).one() + download.started() + download_domain = download.domain + + downloader: Downloader = download.get_downloader() + if not downloader: + worker_logger.warning(f'Could not find downloader for {download.downloader=}') + + try_again = True + try: + # Create download coroutine. Wrap it, so it can be canceled. + if not inspect.iscoroutinefunction(downloader.do_download): + raise RuntimeError(f'Coroutine expected from {downloader} do_download method.') + coro = downloader.do_download(download) + result = await downloader.cancel_wrapper(coro, download) + except UnrecoverableDownloadError as e: + # Download failed and should not be retried. + worker_logger.warning(f'UnrecoverableDownloadError for {url}', exc_info=e) + result = DownloadResult(success=False, error=str(traceback.format_exc())) + try_again = False + except Exception as e: + worker_logger.warning(f'Failed to download {url}. Will be tried again later.', exc_info=e) + result = DownloadResult(success=False, error=str(traceback.format_exc())) + + error_len = len(result.error) if result.error else 0 + worker_logger.debug( + f'Got success={result.success} from {downloader} download_id={download.id} with {error_len=}') + + with get_db_session(commit=True) as session: + # Modify the download in a new session because downloads may take a long time. + download: Download = session.query(Download).filter_by(id=download_id).one() + # Use a new location if provided, keep the old location if no new location is provided, otherwise + # clear out an outdated location. + download.location = result.location or download.location or None + # Clear any old errors if the download succeeded. 
+ download.error = result.error if result.error else None + download.next_download = download_manager.calculate_next_download(download, session) + + if result.downloads: + worker_logger.info(f'Adding {len(result.downloads)} downloads from result of {download.url}') + urls = download.filter_excluded(result.downloads) + download_manager.create_downloads(urls, session, downloader_name=download.sub_downloader, + settings=result.settings) + + if try_again is False and not download.frequency: + # Only once-downloads can fail. + download.fail() + elif result.success: + download.complete() + else: + download.defer() + + # Remove this domain from the running list. + download_manager._delete_processing_domain(download_domain) + # Allow the download to resume. + download_manager.unkill_download(download_id) + # Save the config now that the Download has finished. + api_app.add_task(save_downloads_config()) + except asyncio.CancelledError as e: + worker_logger.warning('Canceled!', exc_info=e) + return + except Exception as e: + worker_logger.warning(f'Unexpected error', exc_info=e) + finally: + # Remove this domain from the running list. + if download_domain: + download_manager._delete_processing_domain(download_domain) + + +class DownloadManagerConfig(ConfigFile): file_name = 'download_manager.yaml' default_config = dict( skip_urls=[], @@ -1110,20 +1092,20 @@ def downloads(self, value: List[dict]): self.update({'downloads': value}) -DOWNLOAD_MANAGER_CONFIG: DownloadMangerConfig = DownloadMangerConfig() -TEST_DOWNLOAD_MANAGER_CONFIG: DownloadMangerConfig = None +DOWNLOAD_MANAGER_CONFIG: DownloadManagerConfig = DownloadManagerConfig() +TEST_DOWNLOAD_MANAGER_CONFIG: DownloadManagerConfig = None @contextlib.contextmanager -def downloads_manager_config_context(): +def downloads_manager_config_context() -> DownloadManagerConfig: """Used to create a test config.""" global TEST_DOWNLOAD_MANAGER_CONFIG - TEST_DOWNLOAD_MANAGER_CONFIG = DownloadMangerConfig() - yield + TEST_DOWNLOAD_MANAGER_CONFIG = DownloadManagerConfig() + yield TEST_DOWNLOAD_MANAGER_CONFIG TEST_DOWNLOAD_MANAGER_CONFIG = None -def get_download_manager_config() -> DownloadMangerConfig: +def get_download_manager_config() -> DownloadManagerConfig: global TEST_DOWNLOAD_MANAGER_CONFIG if isinstance(TEST_DOWNLOAD_MANAGER_CONFIG, ConfigFile): return TEST_DOWNLOAD_MANAGER_CONFIG @@ -1264,7 +1246,7 @@ async def do_download(self, download: Download) -> DownloadResult: # Only download URLs that have not yet been downloaded. urls = [] - sub_downloader = self.manager.get_downloader_by_name(download.sub_downloader) + sub_downloader = download_manager.get_downloader_by_name(download.sub_downloader) if not sub_downloader: raise ValueError(f'Unable to find sub_downloader for {download.url}') diff --git a/wrolpi/events.py b/wrolpi/events.py index 468fd539..694d0302 100644 --- a/wrolpi/events.py +++ b/wrolpi/events.py @@ -12,8 +12,6 @@ HISTORY_SIZE = 100 EVENTS_LOCK = multiprocessing.Lock() -EVENTS_HISTORY = multiprocessing.Manager().list() - class Events: @@ -94,6 +92,7 @@ def log_event(event: str, message: str = None, action: str = None, subject: str def send_event(event: str, message: str = None, action: str = None, subject: str = None, url: str = None): + from wrolpi.api_utils import api_app EVENTS_LOCK.acquire() try: # All events will be in time order, they should never be at the exact same time. 
@@ -107,11 +106,11 @@ def send_event(event: str, message: str = None, action: str = None, subject: str subject=subject, url=url, ) - EVENTS_HISTORY.append(e) + api_app.shared_ctx.events_history.append(e) # Keep events below limit. - while len(EVENTS_HISTORY) > HISTORY_SIZE: - EVENTS_HISTORY.pop(0) + while len(api_app.shared_ctx.events_history) > HISTORY_SIZE: + api_app.shared_ctx.events_history.pop(0) finally: EVENTS_LOCK.release() @@ -120,10 +119,12 @@ def send_event(event: str, message: str = None, action: str = None, subject: str @iterify(list) def get_events(after: datetime = None): + from wrolpi.api_utils import api_app + events_history = api_app.shared_ctx.events_history if not after: - events = [i for i in EVENTS_HISTORY] + events = [i for i in events_history] else: - events = [i for i in EVENTS_HISTORY if i['dt'] > after] + events = [i for i in events_history if i['dt'] > after] # Most recent first. return events[::-1] diff --git a/wrolpi/files/__init__.py b/wrolpi/files/__init__.py index 1111210a..efd4d47a 100644 --- a/wrolpi/files/__init__.py +++ b/wrolpi/files/__init__.py @@ -1,4 +1,4 @@ from . import downloader # noqa from . import ebooks # noqa from . import pdfs # noqa -from .api import bp +from .api import files_bp diff --git a/wrolpi/files/api.py b/wrolpi/files/api.py index 2f8d2dbd..ce4c6944 100644 --- a/wrolpi/files/api.py +++ b/wrolpi/files/api.py @@ -1,27 +1,26 @@ import pathlib from http import HTTPStatus -from multiprocessing import Manager from typing import List import sanic.request -from sanic import response, Request +from sanic import response, Request, Blueprint from sanic_ext import validate from sanic_ext.extensions.openapi import openapi from wrolpi.common import get_media_directory, wrol_mode_check, get_relative_to_media_directory, logger, \ background_task, walk from wrolpi.errors import InvalidFile, UnknownDirectory, FileUploadFailed, FileConflict -from wrolpi.root_api import get_blueprint, json_response from . import lib, schema +from ..api_utils import json_response, api_app from ..schema import JSONErrorResponse from ..vars import PYTEST -bp = get_blueprint('Files', '/api/files') +files_bp = Blueprint('Files', '/api/files') logger = logger.getChild(__name__) -@bp.post('/') +@files_bp.post('/') @openapi.definition( summary='List files in a directory', body=schema.FilesRequest, @@ -34,7 +33,7 @@ async def get_files(_: Request, body: schema.FilesRequest): return json_response({'files': files}) -@bp.post('/file') +@files_bp.post('/file') @openapi.definition( summary='Get the dict of one file', body=schema.FileRequest, @@ -48,7 +47,7 @@ async def get_file(_: Request, body: schema.FileRequest): return json_response({'file': file}) -@bp.post('/delete') +@files_bp.post('/delete') @openapi.definition( summary='Delete files or directories. Directories are deleted recursively.' ' Returns an error if WROL Mode is enabled.', @@ -63,7 +62,7 @@ async def delete_file(_: Request, body: schema.DeleteRequest): return response.empty() -@bp.post('/refresh') +@files_bp.post('/refresh') @openapi.definition( summary='Refresh and index all paths (files/directories) in the provided list. 
Refresh all files if not provided.', body=schema.FilesRefreshRequest, @@ -81,7 +80,7 @@ async def refresh(request: Request): return response.empty() -@bp.get('/refresh_progress') +@files_bp.get('/refresh_progress') @openapi.definition( summary='Get the progress of the file refresh' ) @@ -92,7 +91,7 @@ async def refresh_progress(request: Request): )) -@bp.post('/search') +@files_bp.post('/search') @openapi.definition( summary='Search Files', body=schema.FilesSearchRequest, @@ -104,7 +103,7 @@ async def post_search_files(_: Request, body: schema.FilesSearchRequest): return json_response(dict(file_groups=file_groups, totals=dict(file_groups=total))) -@bp.post('/directories') +@files_bp.post('/directories') @openapi.definition( summary='Get all directories that match the search_str, prefixed by the media directory.', body=schema.DirectoriesRequest, @@ -122,7 +121,7 @@ def post_directories(_, body: schema.DirectoriesRequest): return response.json(body) -@bp.post('/search_directories') +@files_bp.post('/search_directories') @openapi.definition( summary='Get all directories whose name matches the provided name.', body=schema.DirectoriesSearchRequest, @@ -156,7 +155,7 @@ async def post_search_directories(_, body: schema.DirectoriesSearchRequest): return json_response(body) -@bp.post('/get_directory') +@files_bp.post('/get_directory') @openapi.definition( summary='Get data about a directory', body=schema.Directory, @@ -182,7 +181,7 @@ async def post_get_directory(_: Request, body: schema.Directory): return json_response(body) -@bp.post('/directory') +@files_bp.post('/directory') @openapi.definition( summary='Create a directory in the media directory.', body=schema.Directory, @@ -203,7 +202,7 @@ async def post_create_directory(_: Request, body: schema.Directory): return response.empty(HTTPStatus.CREATED) -@bp.post('/move') +@files_bp.post('/move') @openapi.definition( summary='Move a file/directory into another directory in the media directory.', body=schema.Move, @@ -228,7 +227,7 @@ async def post_move(_: Request, body: schema.Move): return response.empty(HTTPStatus.NO_CONTENT) -@bp.post('/rename') +@files_bp.post('/rename') @openapi.definition( summary='Rename a file/directory in-place.', body=schema.Rename, @@ -256,7 +255,7 @@ async def post_rename(_: Request, body: schema.Rename): return response.empty(HTTPStatus.NO_CONTENT) -@bp.post('/tag') +@files_bp.post('/tag') @validate(schema.TagFileGroupPost) async def post_tag_file_group(_, body: schema.TagFileGroupPost): if not body.tag_id and not body.tag_name: @@ -274,7 +273,7 @@ async def post_tag_file_group(_, body: schema.TagFileGroupPost): return response.empty(HTTPStatus.CREATED) -@bp.post('/untag') +@files_bp.post('/untag') @validate(schema.TagFileGroupPost) async def post_untag_file_group(_, body: schema.TagFileGroupPost): await lib.remove_file_group_tag(body.file_group_id, body.file_group_primary_path, body.tag_name, body.tag_id) @@ -284,16 +283,16 @@ async def post_untag_file_group(_, body: schema.TagFileGroupPost): # { # '/media/directory/sub-dir/the-file-name.suffix': 2, # The chunk number we will receive next. # } -UPLOADED_FILES = Manager().dict() -@bp.post('/upload') +@files_bp.post('/upload') async def post_upload(request: Request): """Accepts a multipart/form-data request to upload a single file. Tracks the number of chunks and will request the correct chunk of the chunks come out of order. 
Will not overwrite an existing file, unless a previous upload did not complete.""" + try: destination = request.form['destination'][0] except Exception as e: @@ -324,7 +323,7 @@ async def post_upload(request: Request): output_str = str(output) # Chunks start at 0. - expected_chunk_num = UPLOADED_FILES.get(output_str, 0) + expected_chunk_num = api_app.shared_ctx.uploaded_files.get(output_str, 0) logger.debug(f'last_chunk_num is {expected_chunk_num} for {repr(output_str)} received {chunk_num=}') chunk_size = int(request.form['chunkSize'][0]) @@ -333,7 +332,7 @@ async def post_upload(request: Request): if (body_size := len(chunk.body)) != chunk_size: raise FileUploadFailed(f'Chunk size does not match the size of the chunk! {chunk_size} != {body_size=}') - if chunk_num == 0 and output.is_file() and output_str in UPLOADED_FILES: + if chunk_num == 0 and output.is_file() and output_str in api_app.shared_ctx.uploaded_files: # User attempted to upload this same file, but it did not finish. User has started over again. logger.info(f'Restarting upload of {repr(output_str)}') output.unlink() @@ -366,12 +365,12 @@ async def post_upload(request: Request): # Store what we expect to receive next. expected_chunk_num += 1 - UPLOADED_FILES[output_str] = expected_chunk_num + api_app.shared_ctx.uploaded_files[output_str] = expected_chunk_num # Chunks start at 0. if chunk_num == total_chunks: # File upload is complete. - del UPLOADED_FILES[output_str] + del api_app.shared_ctx.uploaded_files[output_str] background_task(lib.refresh_files([output.parent])) return response.empty(HTTPStatus.CREATED) @@ -379,15 +378,15 @@ async def post_upload(request: Request): return json_response({'expected_chunk': expected_chunk_num}, HTTPStatus.OK) -@bp.post('/ignore_directory') +@files_bp.post('/ignore_directory') @validate(schema.Directory) async def post_ignore_directory(request: Request, body: schema.Directory): lib.add_ignore_directory(body.path) return response.empty(HTTPStatus.OK) -@bp.post('/unignore_directory') +@files_bp.post('/unignore_directory') @validate(schema.Directory) -async def post_ignore_directory(request: Request, body: schema.Directory): +async def post_unignore_directory(request: Request, body: schema.Directory): lib.remove_ignored_directory(body.path) return response.empty(HTTPStatus.OK) diff --git a/wrolpi/files/lib.py b/wrolpi/files/lib.py index 814a6fbf..d16973b4 100644 --- a/wrolpi/files/lib.py +++ b/wrolpi/files/lib.py @@ -4,7 +4,6 @@ import functools import glob import json -import multiprocessing import os import pathlib import re @@ -524,21 +523,19 @@ async def refresh_discover_paths(paths: List[pathlib.Path], idempotency: datetim curs.execute(stmt) -REFRESH = multiprocessing.Manager().dict() - - @limit_concurrent(1) # Only one refresh at a time. 
@wrol_mode_check @cancelable_wrapper async def refresh_files(paths: List[pathlib.Path] = None, send_events: bool = True): """Find, model, and index all files in the media directory.""" + from wrolpi.api_utils import api_app if isinstance(paths, str): paths = [pathlib.Path(paths), ] if isinstance(paths, pathlib.Path): paths = [paths, ] idempotency = now() - REFRESH['idempotency'] = idempotency + api_app.shared_ctx.refresh['idempotency'] = idempotency refreshing_all_files = False @@ -563,10 +560,10 @@ async def refresh_files(paths: List[pathlib.Path] = None, send_events: bool = Tr files, dirs = get_files_and_directories(directory) directories.extend(dirs) found_directories |= set(dirs) - REFRESH['counted_files'] = REFRESH.get('counted_files', 0) + len(files) + api_app.shared_ctx.refresh['counted_files'] = api_app.shared_ctx.refresh.get('counted_files', 0) + len(files) # Sleep to catch cancel. await asyncio.sleep(0) - refresh_logger.info(f'Counted {REFRESH["counted_files"]} files') + refresh_logger.info(f'Counted {api_app.shared_ctx.refresh["counted_files"]} files') with flags.refresh_discovery: await refresh_discover_paths(paths, idempotency) @@ -1023,7 +1020,9 @@ def __json__(self): def get_refresh_progress() -> RefreshProgress: - idempotency = REFRESH.get('idempotency') + from wrolpi.api_utils import api_app + + idempotency = api_app.shared_ctx.refresh.get('idempotency') if idempotency: stmt = ''' SELECT @@ -1052,7 +1051,7 @@ def get_refresh_progress() -> RefreshProgress: # TODO counts are wrong if we are not refreshing all files. progress = RefreshProgress( - counted_files=REFRESH.get('counted_files', 0), + counted_files=api_app.shared_ctx.refresh.get('counted_files', 0), counting=flags.refresh_counting.is_set(), discovery=flags.refresh_discovery.is_set(), indexed=int(results['indexed'] or 0), diff --git a/wrolpi/files/test/test_api.py b/wrolpi/files/test/test_api.py index cefb723c..d2190b86 100644 --- a/wrolpi/files/test/test_api.py +++ b/wrolpi/files/test/test_api.py @@ -150,7 +150,7 @@ def test_delete_invalid_file(test_client, paths): @pytest.mark.asyncio async def test_delete_wrol_mode(test_async_client, wrol_mode_fixture): """Can't delete a file when WROL Mode is enabled.""" - wrol_mode_fixture(True) + await wrol_mode_fixture(True) request, response = await test_async_client.post('/api/files/delete', content=json.dumps({'paths': ['foo', ]})) assert response.status_code == HTTPStatus.FORBIDDEN diff --git a/wrolpi/files/test/test_lib.py b/wrolpi/files/test/test_lib.py index f5fdbcf7..db8e008d 100644 --- a/wrolpi/files/test/test_lib.py +++ b/wrolpi/files/test/test_lib.py @@ -22,7 +22,7 @@ @pytest.mark.asyncio -async def test_delete_file(test_session, make_files_structure, test_directory): +async def test_delete_file(test_async_client, test_session, make_files_structure, test_directory): """ File in the media directory can be deleted. 
""" @@ -54,7 +54,7 @@ async def test_delete_file(test_session, make_files_structure, test_directory): @pytest.mark.asyncio -async def test_delete_file_multiple(test_session, make_files_structure, test_directory): +async def test_delete_file_multiple(test_async_client, test_session, make_files_structure, test_directory): """Multiple files can be deleted at once.""" foo, bar, baz = make_files_structure([ 'archives/foo.txt', @@ -72,7 +72,7 @@ async def test_delete_file_multiple(test_session, make_files_structure, test_dir @pytest.mark.asyncio -async def test_delete_file_names(test_session, make_files_structure, test_directory, tag_factory): +async def test_delete_file_names(test_async_client, test_session, make_files_structure, test_directory, tag_factory): """Will not refuse to delete a file that shares the name of a nearby file when they are in different FileGroups.""" foo, foo1 = make_files_structure({ 'archives/foo': 'text', @@ -92,7 +92,7 @@ async def test_delete_file_names(test_session, make_files_structure, test_direct @pytest.mark.asyncio -async def test_delete_file_link(test_session, test_directory): +async def test_delete_file_link(test_async_client, test_session, test_directory): """Links can be deleted.""" foo, bar = test_directory / 'foo', test_directory / 'bar' foo.touch() @@ -524,7 +524,7 @@ def test_split_file_name_words(name, expected): @pytest.mark.asyncio -async def test_large_text_indexer(test_session, make_files_structure): +async def test_large_text_indexer(test_async_client, test_session, make_files_structure): """ Large files have their indexes truncated. """ diff --git a/wrolpi/flags.py b/wrolpi/flags.py index b31f8374..cddd2167 100644 --- a/wrolpi/flags.py +++ b/wrolpi/flags.py @@ -2,7 +2,6 @@ import contextlib import multiprocessing import subprocess -import threading from datetime import datetime from typing import List @@ -22,22 +21,20 @@ TESTING_LOCK = multiprocessing.Event() +FLAG_NAMES = set() + class Flag: """A simple wrapper around multiprocessing.Event. This allows synchronization between the App and this API. - This may store it's value in the DB table wrolpi_flag.""" + This may store its value in the DB table wrolpi_flag.""" def __init__(self, name: str, store_db: bool = False): - if PYTEST: - # Use threading Event during testing to avoid tests clobbering each other. - self._flag = threading.Event() - else: - self._flag = multiprocessing.Event() self.name = name self.store_db = store_db + FLAG_NAMES.add(name) def __repr__(self): return f'' @@ -48,7 +45,8 @@ def set(self): # Testing, but the test does not need flags. return - self._flag.set() + from wrolpi.api_utils import api_app + api_app.shared_ctx.flags.update({self.name: True}) self._save(True) def clear(self): @@ -57,7 +55,8 @@ def clear(self): # Testing, but the test does not need flags. return - self._flag.clear() + from wrolpi.api_utils import api_app + api_app.shared_ctx.flags.update({self.name: False}) self._save(False) def is_set(self): @@ -66,14 +65,15 @@ def is_set(self): # Testing, but the test does not need flags. return - return self._flag.is_set() + from wrolpi.api_utils import api_app + return api_app.shared_ctx.flags[self.name] def __enter__(self): if PYTEST and not TESTING_LOCK.is_set(): # Testing, but the test does not need flags. 
return - if self._flag.is_set(): + if self.is_set(): raise ValueError(f'{self} flag is already set!') self.set() @@ -84,7 +84,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.clear() - def _save(self, value): + def _save(self, value: bool): if self.store_db: from wrolpi.db import get_db_curs # Store the value of this Flag in it's matching column in the DB. Not all flags will do this. @@ -101,7 +101,6 @@ def _save(self, value): @contextlib.asynccontextmanager async def wait_for(self, timeout: int = 0): """Wait for this Flag to be set.""" - logger.debug(f'Waiting for flag {self.name}') async with wait_for_flag(self, timeout=timeout): yield @@ -126,15 +125,6 @@ async def wait_for(self, timeout: int = 0): # The global refresh has been performed. This is False on a fresh instance of WROLPi. refresh_complete = Flag('refresh_complete', store_db=True) -# Third party packages that may be required. See `cmd.py` -chromium_installed = Flag('chromium_installed') -ffmpeg_installed = Flag('ffmpeg_installed') -ffprobe_installed = Flag('ffprobe_installed') -nmcli_installed = Flag('nmcli_installed') -readability_installed = Flag('readability_installed') -singlefile_installed = Flag('singlefile_installed') -wget_installed = Flag('wget_installed') -yt_dlp_installed = Flag('yt_dlp_installed') def get_flags() -> List[str]: @@ -146,20 +136,6 @@ def get_flags() -> List[str]: flags.append('refreshing') if refresh_complete.is_set(): flags.append('refresh_complete') - if chromium_installed.is_set(): - flags.append('chromium_installed') - if ffmpeg_installed.is_set(): - flags.append('ffmpeg_installed') - if ffprobe_installed.is_set(): - flags.append('ffprobe_installed') - if nmcli_installed.is_set(): - flags.append('nmcli_installed') - if readability_installed.is_set(): - flags.append('readability_installed') - if singlefile_installed.is_set(): - flags.append('singlefile_installed') - if yt_dlp_installed.is_set(): - flags.append('yt_dlp_installed') if outdated_zims.is_set(): flags.append('outdated_zims') if kiwix_restart.is_set(): @@ -205,18 +181,17 @@ def check_db_is_up(): db_up.clear() -FLAGS_INITIALIZED = multiprocessing.Event() - - def init_flags(): - """Set flags to match their DB values.""" - if FLAGS_INITIALIZED.is_set(): - return - + """Read flag values from the DB, copy them to the shared context flags.""" if not db_up.is_set(): logger.error(f'Refusing to initialize flags when DB is not up.') return + from wrolpi.api_utils import api_app + if api_app.shared_ctx.flags_initialized.is_set(): + # Only need to read from once DB at startup. 
+ return + from wrolpi.db import get_db_session with get_db_session() as session: flags: WROLPiFlag = session.query(WROLPiFlag).one_or_none() @@ -230,8 +205,6 @@ def init_flags(): else: outdated_zims.clear() - FLAGS_INITIALIZED.set() - @contextlib.asynccontextmanager async def wait_for_flag(flag: Flag, timeout: int = 0): diff --git a/wrolpi/root_api.py b/wrolpi/root_api.py index 37c315ea..b480aa9e 100644 --- a/wrolpi/root_api.py +++ b/wrolpi/root_api.py @@ -1,35 +1,34 @@ import asyncio -import json import pathlib import re -from datetime import datetime, date, timezone -from decimal import Decimal -from functools import wraps from http import HTTPStatus -from pathlib import Path -from typing import Union import vininfo.exceptions -from sanic import Sanic, response, Blueprint, __version__ as sanic_version -from sanic.blueprint_group import BlueprintGroup +from sanic import response, Blueprint, __version__ as sanic_version from sanic.request import Request -from sanic.response import HTTPResponse from sanic_ext import validate from sanic_ext.extensions.openapi import openapi from vininfo import Vin from vininfo.details._base import VinDetails +from modules.archive import archive_bp +from modules.inventory import inventory_bp +from modules.map.api import map_bp +from modules.otp.api import otp_bp +from modules.videos.api import videos_bp +from modules.zim.api import zim_bp from wrolpi import admin, status, flags, schema, dates from wrolpi import tags from wrolpi.admin import HotspotStatus -from wrolpi.common import logger, get_wrolpi_config, wrol_mode_enabled, Base, get_media_directory, \ - wrol_mode_check, native_only, disable_wrol_mode, enable_wrol_mode, get_global_statistics, url_strip_host, LOG_LEVEL, \ +from wrolpi.api_utils import json_response, json_error_handler, api_app +from wrolpi.common import logger, get_wrolpi_config, wrol_mode_enabled, get_media_directory, \ + wrol_mode_check, native_only, disable_wrol_mode, enable_wrol_mode, get_global_statistics, url_strip_host, \ set_global_log_level, get_relative_to_media_directory from wrolpi.dates import now -from wrolpi.downloader import download_manager from wrolpi.errors import WROLModeEnabled, APIError, HotspotError, InvalidDownload, \ HotspotPasswordTooShort, NativeOnly, InvalidConfig from wrolpi.events import get_events, Events +from wrolpi.files import files_bp from wrolpi.files.lib import get_file_statistics, search_file_suggestion_count from wrolpi.vars import API_HOST, API_PORT, DOCKERIZED, API_DEBUG, API_ACCESS_LOG, API_WORKERS, API_AUTO_RELOAD, \ truthy_arg, IS_RPI, IS_RPI4, IS_RPI5 @@ -37,25 +36,19 @@ logger = logger.getChild(__name__) -api_app = Sanic(name='api_app') api_app.config.FALLBACK_ERROR_FORMAT = 'json' api_bp = Blueprint('RootAPI', url_prefix='/api') -BLUEPRINTS = [api_bp, ] - - -def get_blueprint(name: str, url_prefix: str) -> Blueprint: - """ - Create a new Sanic blueprint. This will be attached to the app just before run. See `root_api.run_webserver`. - """ - bp = Blueprint(name, url_prefix) - add_blueprint(bp) - return bp - - -def add_blueprint(bp: Union[Blueprint, BlueprintGroup]): - BLUEPRINTS.append(bp) +# Blueprints order here defines what order they are displayed in OpenAPI Docs. 
+api_app.blueprint(api_bp) +api_app.blueprint(archive_bp) +api_app.blueprint(files_bp) +api_app.blueprint(inventory_bp) +api_app.blueprint(map_bp) +api_app.blueprint(otp_bp) +api_app.blueprint(videos_bp) +api_app.blueprint(zim_bp) def run_webserver( @@ -66,8 +59,6 @@ def run_webserver( access_log: bool = API_ACCESS_LOG, ): # Attach all blueprints after they have been defined. - for bp in BLUEPRINTS: - api_app.blueprint(bp) kwargs = dict( host=host, @@ -140,6 +131,7 @@ async def echo(request: Request): @openapi.description('Get WROLPi settings') @openapi.response(HTTPStatus.OK, schema.SettingsResponse) def get_settings(_: Request): + from wrolpi.downloader import download_manager config = get_wrolpi_config() ignored_directories = [get_relative_to_media_directory(i) for i in config.ignored_directories] @@ -157,7 +149,7 @@ def get_settings(_: Request): 'hotspot_status': admin.hotspot_status().name, 'ignore_outdated_zims': config.ignore_outdated_zims, 'ignored_directories': ignored_directories, - 'log_level': LOG_LEVEL.value, + 'log_level': api_app.shared_ctx.log_level.value, 'map_directory': config.map_directory, 'media_directory': str(get_media_directory()), # Convert to string to avoid conversion to relative. 'throttle_on_startup': config.throttle_on_startup, @@ -173,14 +165,14 @@ def get_settings(_: Request): @api_bp.patch('/settings') @openapi.description('Update WROLPi settings') @validate(json=schema.SettingsRequest) -def update_settings(_: Request, body: schema.SettingsRequest): +async def update_settings(_: Request, body: schema.SettingsRequest): if wrol_mode_enabled() and body.wrol_mode is None: # Cannot update settings while WROL Mode is enabled, unless you want to disable WROL Mode. raise WROLModeEnabled() if body.wrol_mode is False: # Disable WROL Mode - disable_wrol_mode() + await disable_wrol_mode() return response.empty() elif body.wrol_mode is True: # Enable WROL Mode @@ -257,6 +249,7 @@ def valid_regex(_: Request, body: schema.RegexRequest): @validate(schema.DownloadRequest) @wrol_mode_check async def post_download(_: Request, body: schema.DownloadRequest): + from wrolpi.downloader import download_manager downloader = download_manager.get_downloader_by_name(body.downloader) if not downloader: raise InvalidDownload(f'Cannot find downloader with name {body.downloader}') @@ -290,6 +283,7 @@ async def post_download(_: Request, body: schema.DownloadRequest): @api_bp.post('/download//restart') @openapi.description('Restart a download.') async def restart_download(_: Request, download_id: int): + from wrolpi.downloader import download_manager download_manager.restart_download(download_id) return response.empty() @@ -297,6 +291,7 @@ async def restart_download(_: Request, download_id: int): @api_bp.get('/download') @openapi.description('Get all Downloads that need to be processed.') async def get_downloads(_: Request): + from wrolpi.downloader import download_manager data = download_manager.get_fe_downloads() return json_response(data) @@ -304,6 +299,7 @@ async def get_downloads(_: Request): @api_bp.post('/download//kill') @openapi.description('Kill a download. It will be stopped if it is pending.') async def kill_download(_: Request, download_id: int): + from wrolpi.downloader import download_manager download_manager.kill_download(download_id) return response.empty() @@ -311,6 +307,8 @@ async def kill_download(_: Request, download_id: int): @api_bp.post('/download/kill') @openapi.description('Kill all downloads. 
Disable downloading.') async def kill_downloads(_: Request): + from wrolpi.downloader import download_manager + logger.warning('Disabled downloads') download_manager.disable() return response.empty() @@ -318,13 +316,15 @@ async def kill_downloads(_: Request): @api_bp.post('/download/enable') @openapi.description('Enable and start downloading.') async def enable_downloads(_: Request): - download_manager.enable() + from wrolpi.downloader import download_manager + await download_manager.enable() return response.empty() @api_bp.post('/download/clear_completed') @openapi.description('Clear completed downloads') async def clear_completed(_: Request): + from wrolpi.downloader import download_manager download_manager.delete_completed() return response.empty() @@ -332,14 +332,24 @@ async def clear_completed(_: Request): @api_bp.post('/download/clear_failed') @openapi.description('Clear failed downloads') async def clear_failed(_: Request): + from wrolpi.downloader import download_manager download_manager.delete_failed() return response.empty() +@api_bp.post('/download/delete_once') +@openapi.description('Delete all once downloads') +async def delete_once(_: Request): + from wrolpi.downloader import download_manager + download_manager.delete_once() + return response.empty() + + @api_bp.delete('/download/') @openapi.description('Delete a download') @wrol_mode_check async def delete_download(_: Request, download_id: int): + from wrolpi.downloader import download_manager deleted = download_manager.delete_download(download_id) return response.empty(HTTPStatus.NO_CONTENT if deleted else HTTPStatus.NOT_FOUND) @@ -347,6 +357,7 @@ async def delete_download(_: Request, download_id: int): @api_bp.get('/downloaders') @openapi.description('List all Downloaders that can be specified by the user.') async def get_downloaders(_: Request): + from wrolpi.downloader import download_manager downloaders = download_manager.list_downloaders() disabled = download_manager.disabled.is_set() ret = dict(downloaders=downloaders, manager_disabled=disabled) @@ -396,6 +407,8 @@ async def throttle_off(_: Request): @api_bp.get('/status') @openapi.description('Get the status of CPU/load/etc.') async def get_status(_: Request): + from wrolpi.downloader import download_manager + s = await status.get_status() downloads = dict() if flags.db_up.is_set(): @@ -440,8 +453,12 @@ async def get_statistics(_): @api_bp.get('/events/feed') @validate(query=schema.EventsRequest) -async def feed(_: Request, query: schema.EventsRequest): +async def feed(request: Request, query: schema.EventsRequest): + # Get the current datetime from the API. The frontend will use this to request any events that happen after it's + # previous request. The API decides what the time is, just in case the RPi's clock is wrong, or no NTP is + # available. 
start = now() + after = None if query.after == 'None' else dates.strpdate(query.after) events = get_events(after) return json_response(dict(events=events, now=start)) @@ -456,8 +473,8 @@ async def get_tags_request(_: Request): return json_response(dict(tags=tags_)) -@api_bp.post('/tag') -@api_bp.post('/tag/') +@api_bp.post('/tag', name='tag_crate') +@api_bp.post('/tag/', name='tag_update') @validate(schema.TagRequest) @openapi.definition( summary='Create or update a Tag', @@ -574,81 +591,4 @@ async def post_search_file_estimates(_: Request, body: schema.SearchFileEstimate return json_response(ret) -class CustomJSONEncoder(json.JSONEncoder): - - def default(self, obj): - try: - if hasattr(obj, '__json__'): - # Get __json__ before others. - return obj.__json__() - elif isinstance(obj, datetime): - # API always returns dates in UTC. - if obj.tzinfo: - obj = obj.astimezone(timezone.utc) - else: - # A datetime with no timezone is UTC. - obj = obj.replace(tzinfo=timezone.utc) - obj = obj.isoformat() - return obj - elif isinstance(obj, date): - # API always returns dates in UTC. - obj = datetime(obj.year, obj.month, obj.day, tzinfo=timezone.utc) - return obj.isoformat() - elif isinstance(obj, Decimal): - return str(obj) - elif isinstance(obj, Base): - if hasattr(obj, 'dict'): - return obj.dict() - elif isinstance(obj, Path): - media_directory = get_media_directory() - try: - path = obj.relative_to(media_directory) - except ValueError: - # Path may not be absolute. - path = obj - if str(path) == '.': - return '' - return str(path) - return super(CustomJSONEncoder, self).default(obj) - except Exception as e: - logger.fatal(f'Failed to JSON encode {obj}', exc_info=e) - raise - - -@wraps(response.json) -def json_response(*a, **kwargs) -> HTTPResponse: - """ - Handles encoding date/datetime in JSON. - """ - resp = response.json(*a, **kwargs, cls=CustomJSONEncoder, dumps=json.dumps) - return resp - - -def get_error_json(exception: BaseException): - """Return a JSON representation of the Exception instance.""" - if isinstance(exception, APIError): - # Error especially defined for WROLPi. - body = dict(error=str(exception), summary=exception.summary, code=exception.code) - else: - # Not a WROLPi APIError error. - body = dict( - error=str(exception), - summary=None, - code=None, - ) - if exception.__cause__: - # This exception was caused by another, follow the stack. 
- body['cause'] = get_error_json(exception.__cause__) - return body - - -def json_error_handler(request: Request, exception: APIError): - body = get_error_json(exception) - error = repr(str(body["error"])) - summary = repr(str(body["summary"])) - code = body['code'] - logger.debug(f'API returning JSON error {exception=} {error=} {summary=} {code=}') - return json_response(body, exception.status) - - api_app.error_handler.add(APIError, json_error_handler) diff --git a/wrolpi/status.py b/wrolpi/status.py index 01efe1c0..feccefdf 100755 --- a/wrolpi/status.py +++ b/wrolpi/status.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import List, Optional, Tuple +from wrolpi.api_utils import api_app from wrolpi.cmd import which from wrolpi.common import logger, limit_concurrent, get_warn_once from wrolpi.dates import now @@ -337,8 +338,6 @@ def __json__(self): 'br-', } -BANDWIDTH = multiprocessing.Manager().dict() - def get_nic_names() -> List[str]: """Finds all non-virtual and non-docker network interface names.""" @@ -374,18 +373,17 @@ def __json__(self): 'loop', 'ram', ) -DISKS_BANDWIDTH = multiprocessing.Manager().dict() -MAX_DISKS_BANDWIDTH = multiprocessing.Manager().dict() async def get_bandwidth_info() -> Tuple[List[NICBandwidthInfo], List[DiskBandwidthInfo]]: """Get all bandwidth information for all NICs and Disks.""" + from wrolpi.api_utils import api_app nics_info = [] disks_info = [] try: - for name in sorted(BANDWIDTH.keys()): - nic = BANDWIDTH[name] + for name in sorted(api_app.shared_ctx.bandwidth.keys()): + nic = api_app.shared_ctx.bandwidth[name] if 'bytes_recv_ps' not in nic: # Not stats collected yet. continue @@ -397,8 +395,8 @@ async def get_bandwidth_info() -> Tuple[List[NICBandwidthInfo], List[DiskBandwid speed=nic['speed'], )) used_disks = [] - for name in sorted(DISKS_BANDWIDTH.keys()): - disk = DISKS_BANDWIDTH[name] + for name in sorted(api_app.shared_ctx.disks_bandwidth.keys()): + disk = api_app.shared_ctx.disks_bandwidth[name] if 'bytes_read_ps' not in disk: # No status collected yet continue @@ -411,15 +409,17 @@ async def get_bandwidth_info() -> Tuple[List[NICBandwidthInfo], List[DiskBandwid used_disks.append(name) try: - maximum_read_ps = max(disk['bytes_read_ps'], MAX_DISKS_BANDWIDTH[name]['maximum_read_ps']) - maximum_write_ps = max(disk['bytes_write_ps'], MAX_DISKS_BANDWIDTH[name]['maximum_write_ps']) + maximum_read_ps = max(disk['bytes_read_ps'], + api_app.shared_ctx.max_disks_bandwidth[name]['maximum_read_ps']) + maximum_write_ps = max(disk['bytes_write_ps'], + api_app.shared_ctx.max_disks_bandwidth[name]['maximum_write_ps']) except KeyError: # Use a low first value. Hopefully all drives are capable of this speed. maximum_read_ps = 500_000 maximum_write_ps = 500_000 # Always write the new maximums. 
value = {name: {'maximum_read_ps': maximum_read_ps, 'maximum_write_ps': maximum_write_ps}} - MAX_DISKS_BANDWIDTH.update(value) + api_app.shared_ctx.max_disks_bandwidth.update(value) disks_info.append(DiskBandwidthInfo( bytes_read_ps=disk['bytes_read_ps'], @@ -460,9 +460,11 @@ def _calculate_bytes_per_second(history: List[Tuple]) -> Tuple[int, int, int]: return bytes_recv_ps, bytes_sent_ps, elapsed -@limit_concurrent(1) +@api_app.signal('wrolpi.periodic.bandwidth') async def bandwidth_worker(count: int = None): """A background process which will gather historical data about all NIC bandwidth statistics.""" + from wrolpi.api_utils import api_app + if not psutil: return @@ -472,25 +474,25 @@ async def bandwidth_worker(count: int = None): def append_all_stats(): for name_ in nic_names: - nic = BANDWIDTH.get(name_) + nic = api_app.shared_ctx.bandwidth.get(name_) if not nic: # Initialize history for this NIC. - BANDWIDTH.update({ + api_app.shared_ctx.bandwidth.update({ name_: dict(historical=[_get_nic_tick(name_), ]), }) else: # Append to history for this NIC. nic['historical'] = (nic['historical'] + [_get_nic_tick(name_), ])[-21:] - BANDWIDTH.update({name_: nic}) + api_app.shared_ctx.bandwidth.update({name_: nic}) timestamp = now().timestamp() for name_, disk in psutil.disk_io_counters(perdisk=True).items(): tic = timestamp, disk.read_bytes, disk.write_bytes - if bw := DISKS_BANDWIDTH.get(name_): + if bw := api_app.shared_ctx.disks_bandwidth.get(name_): bw['historical'] = (bw['historical'] + [tic, ])[-21:] - DISKS_BANDWIDTH.update({name_: bw}) + api_app.shared_ctx.disks_bandwidth.update({name_: bw}) else: - DISKS_BANDWIDTH.update({ + api_app.shared_ctx.disks_bandwidth.update({ name_: dict(historical=[tic, ]), }) @@ -498,17 +500,22 @@ def append_all_stats(): append_all_stats() while count is None or count > 0: - await asyncio.sleep(1) + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + # Server is restarting. + break + if count is not None: count -= 1 append_all_stats() # Calculate the difference between the first and last bandwidth ticks for all NICs. - for name, nic in BANDWIDTH.items(): + for name, nic in api_app.shared_ctx.bandwidth.items(): historical = nic['historical'] bytes_recv_ps, bytes_sent_ps, elapsed = _calculate_bytes_per_second(historical) - BANDWIDTH.update({ + api_app.shared_ctx.bandwidth.update({ name: dict( historical=historical, bytes_recv_ps=bytes_recv_ps, @@ -517,13 +524,13 @@ def append_all_stats(): speed=historical[-1][-1], # Use the most recent speed. 
) }) - for name, stats in DISKS_BANDWIDTH.items(): + for name, stats in api_app.shared_ctx.disks_bandwidth.items(): if 'historical' not in stats: continue historical = stats['historical'] bytes_read_ps, bytes_write_ps, elapsed = _calculate_bytes_per_second(historical) - DISKS_BANDWIDTH.update({ + api_app.shared_ctx.disks_bandwidth.update({ name: dict( historical=historical, bytes_read_ps=bytes_read_ps, diff --git a/wrolpi/test/common.py b/wrolpi/test/common.py index 7b330808..31452467 100644 --- a/wrolpi/test/common.py +++ b/wrolpi/test/common.py @@ -10,7 +10,7 @@ import pytest from wrolpi.common import get_media_directory -from wrolpi.conftest import ROUTES_ATTACHED, test_db, test_client # noqa +from wrolpi.conftest import test_db, test_client # noqa from wrolpi.db import postgres_engine TEST_CONFIG_PATH = tempfile.NamedTemporaryFile(mode='rt', delete=False) diff --git a/wrolpi/test/test_admin.py b/wrolpi/test/test_admin.py index f046837e..49fc4e77 100644 --- a/wrolpi/test/test_admin.py +++ b/wrolpi/test/test_admin.py @@ -69,7 +69,7 @@ def test_enable_hotspot_connected(): ) -def test_change_hotspot(test_config): +def test_change_hotspot(test_config, test_async_client): """The hotspot can be configured.""" with mock.patch('wrolpi.admin.hotspot_status') as mock_hotspot_status, \ mock.patch('wrolpi.admin.subprocess') as mock_subprocess, \ @@ -113,7 +113,7 @@ def test_throttle_off(): @skip_circleci -def test_hotspot_device(): +def test_hotspot_device(test_async_client): """Changing the hotspot device changes what device is turned on.""" from wrolpi.common import WROLPI_CONFIG WROLPI_CONFIG.hotspot_device = 'wlp2s0' diff --git a/wrolpi/test/test_dates.py b/wrolpi/test/test_dates.py index dbba0f38..c0c42aad 100644 --- a/wrolpi/test/test_dates.py +++ b/wrolpi/test/test_dates.py @@ -120,7 +120,8 @@ def test_timedelta_to_timestamp(td, expected): # PDFs are the Wild West... ("D:20221226113758-07'00", datetime(2022, 12, 26, 11, 37, 58, tzinfo=timezone(timedelta(days=-1, seconds=61200)))), ('D:20200205184724', datetime(2020, 2, 5, 18, 47, 24)), - ("D:20091019120104+", datetime(2009, 10, 19, 12, 1, 4)), + ('D:20091019120104+', datetime(2009, 10, 19, 12, 1, 4)), + ('04/27/2024 18:52:55', datetime(2024, 4, 27, 18, 52, 55)), ]) def test_strpdate(dt, expected): assert dates.strpdate(dt) == expected @@ -139,3 +140,6 @@ def test_invalid_strpdate(): dates.strpdate('2001-2-31') with pytest.raises(InvalidDatetime): dates.strpdate('2001-2-30') + + with pytest.raises(InvalidDatetime): + dates.strpdate('27/04/2024 18:52:55') diff --git a/wrolpi/test/test_downloader.py b/wrolpi/test/test_downloader.py index 7fc38b59..f3349f12 100644 --- a/wrolpi/test/test_downloader.py +++ b/wrolpi/test/test_downloader.py @@ -10,6 +10,7 @@ import pytz import yaml +from wrolpi.common import get_wrolpi_config from wrolpi.dates import Seconds from wrolpi.db import get_db_context from wrolpi.downloader import Downloader, Download, DownloadFrequency, import_downloads_config, \ @@ -178,11 +179,11 @@ async def test_recurring_downloads(test_session, test_download_manager, fake_now test_downloader.set_test_success() # Download every hour. - test_download_manager.recurring_download('https://example.com', Seconds.hour, test_downloader.name) + test_download_manager.recurring_download('https://example.com/recurring', Seconds.hour, test_downloader.name) # One download is scheduled. 
downloads = test_download_manager.get_new_downloads(test_session) - assert [(i.url, i.frequency) for i in downloads] == [('https://example.com', Seconds.hour)] + assert [(i.url, i.frequency) for i in downloads] == [('https://example.com/recurring', Seconds.hour)] now_ = fake_now(datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)) @@ -194,6 +195,7 @@ async def test_recurring_downloads(test_session, test_download_manager, fake_now assert len(downloads) == 1 download = downloads[0] expected = datetime(2020, 1, 1, 1, 0, 0, tzinfo=pytz.UTC) + assert download.is_complete, download.status assert download.next_download == expected assert download.last_successful_download == now_ @@ -208,9 +210,9 @@ async def test_recurring_downloads(test_session, test_download_manager, fake_now test_download_manager.renew_recurring_downloads(test_session) (download,) = list(test_download_manager.get_new_downloads(test_session)) # Download is "new" but has not been downloaded a second time. + assert download.is_new, download.status assert download.next_download == expected assert download.last_successful_download == now_ - assert download.is_new() # Try the download, but it fails. test_downloader.do_download.reset_mock() @@ -219,7 +221,7 @@ async def test_recurring_downloads(test_session, test_download_manager, fake_now test_downloader.do_download.assert_called_once() download = test_session.query(Download).one() # Download is deferred, last successful download remains the same. - assert download.is_deferred() + assert download.is_deferred, download.status assert download.last_successful_download == now_ # Download should be retried after the DEFAULT_RETRY_FREQUENCY. expected = datetime(2020, 1, 1, 3, 0, 0, 997200, tzinfo=pytz.UTC) @@ -233,7 +235,7 @@ async def test_recurring_downloads(test_session, test_download_manager, fake_now await test_download_manager.wait_for_all_downloads() test_downloader.do_download.assert_called_once() download = test_session.query(Download).one() - assert download.is_complete() + assert download.is_complete, download.status assert download.last_successful_download == now_ # Floats cause slightly wrong date. assert download.next_download == datetime(2020, 1, 1, 5, 0, 0, 997200, tzinfo=pytz.UTC) @@ -262,11 +264,12 @@ async def test_max_attempts(test_session, test_download_manager, test_downloader await test_download_manager.wait_for_all_downloads() download = session.query(Download).one() assert download.attempts == 3 - assert download.is_failed() + assert download.is_failed @pytest.mark.asyncio -async def test_skip_urls(test_session, test_download_manager, assert_download_urls, test_downloader): +async def test_skip_urls(test_session, test_download_manager, assert_download_urls, test_downloader, + test_download_manager_config): """The DownloadManager will not create downloads for URLs in its skip list.""" _, session = get_db_context() get_download_manager_config().skip_urls = ['https://example.com/skipme'] @@ -282,7 +285,7 @@ async def test_skip_urls(test_session, test_download_manager, assert_download_ur assert_download_urls({'https://example.com/1', 'https://example.com/2'}) assert get_download_manager_config().skip_urls == ['https://example.com/skipme'] - test_download_manager.delete_completed() + test_download_manager.delete_once() # The user can start a download even if the URL is in the skip list. 
test_download_manager.create_download('https://example.com/skipme', test_downloader.name, reset_attempts=True) @@ -291,28 +294,27 @@ async def test_skip_urls(test_session, test_download_manager, assert_download_ur @pytest.mark.asyncio -async def test_process_runner_timeout(test_directory): - """A Downloader can cancel it's download using a timeout.""" +async def test_process_runner_timeout(test_async_client, test_session, test_directory): + """A Downloader can cancel its download using a timeout.""" # Default timeout of 3 seconds. downloader = Downloader('downloader', timeout=3) # Default timeout is obeyed. start = datetime.now() - await downloader.process_runner('https://example.com', ('sleep', '8'), test_directory) + await downloader.process_runner(1, 'https://example.com', ('sleep', '8'), test_directory) elapsed = datetime.now() - start assert 3 < elapsed.total_seconds() < 4 # One-off timeout is obeyed. start = datetime.now() - await downloader.process_runner('https://example.com', ('sleep', '8'), test_directory, timeout=1) + await downloader.process_runner(2, 'https://example.com', ('sleep', '8'), test_directory, timeout=1) elapsed = datetime.now() - start assert 1 < elapsed.total_seconds() < 2 # Global timeout is obeyed. - from wrolpi.common import WROLPI_CONFIG - WROLPI_CONFIG.download_timeout = 3 + get_wrolpi_config().download_timeout = 3 start = datetime.now() - await downloader.process_runner('https://example.com', ('sleep', '8'), test_directory) + await downloader.process_runner(3, 'https://example.com', ('sleep', '8'), test_directory) elapsed = datetime.now() - start assert 2 < elapsed.total_seconds() < 4 diff --git a/wrolpi/test/test_root_api.py b/wrolpi/test/test_root_api.py index 2805c6cd..2d5737d5 100644 --- a/wrolpi/test/test_root_api.py +++ b/wrolpi/test/test_root_api.py @@ -1,4 +1,3 @@ -import asyncio import json from http import HTTPStatus from itertools import zip_longest @@ -7,10 +6,10 @@ from mock import mock from wrolpi.admin import HotspotStatus +from wrolpi.api_utils import json_error_handler from wrolpi.common import get_wrolpi_config from wrolpi.downloader import Download, get_download_manager_config from wrolpi.errors import ValidationError, SearchEmpty -from wrolpi.root_api import json_error_handler from wrolpi.test.common import skip_circleci, assert_dict_contains @@ -348,6 +347,15 @@ def check_downloads(response_, once_downloads_, recurring_downloads_, status_cod # Failed once-downloads will not be downloaded again. assert get_download_manager_config().skip_urls == ['https://example.com/5', ] + # "Delete All" button deletes all once-downloads. + request, response = await test_async_client.post('/api/download/delete_once') + assert response.status_code == HTTPStatus.NO_CONTENT + request, response = await test_async_client.get('/api/download') + check_downloads(response, [], recurring_downloads, status_code=HTTPStatus.OK) + + # Failed once-downloads will not be downloaded again. 
+ assert get_download_manager_config().skip_urls == ['https://example.com/5', ] + @pytest.mark.asyncio async def test_get_status(test_async_client, test_session): @@ -368,10 +376,10 @@ async def test_get_status(test_async_client, test_session): def test_post_download(test_session, test_client, test_download_manager_config): """Test creating once-downloads and recurring downloads.""" - async def queue_downloads(*a, **kw): + async def dispatch_downloads(*a, **kw): pass - with mock.patch('wrolpi.downloader.DownloadManager.queue_downloads', queue_downloads): + with mock.patch('wrolpi.downloader.DownloadManager.dispatch_downloads', dispatch_downloads): # Create a single recurring Download. content = dict( urls=['https://example.com/1', ], @@ -423,18 +431,20 @@ async def test_restart_download(test_session, test_async_client, test_download_m download = test_download_manager.create_download('https://example.com', test_downloader.name) download.fail() test_session.commit() - assert test_session.query(Download).one().is_failed() + assert test_session.query(Download).one().is_failed test_downloader.set_test_failure() # Download is now "new" again. request, response = await test_async_client.post(f'/api/download/{download.id}/restart') assert response.status_code == HTTPStatus.NO_CONTENT - assert test_session.query(Download).one().is_new() + download = test_session.query(Download).one() + assert download.is_new, download.status # Wait for the background download to fail. It should be deferred. - await asyncio.sleep(0.5) - assert test_session.query(Download).one().is_deferred() + await test_download_manager.wait_for_all_downloads() + download = test_session.query(Download).one() + assert download.is_deferred, download.status def test_get_global_statistics(test_session, test_client): diff --git a/wrolpi/test/test_rss.py b/wrolpi/test/test_rss.py index c5871eed..e1c87750 100644 --- a/wrolpi/test/test_rss.py +++ b/wrolpi/test/test_rss.py @@ -93,5 +93,5 @@ async def test_rss_no_entries(test_session, test_download_manager): await test_download_manager.wait_for_all_downloads() (download,) = test_download_manager.get_downloads(test_session) - assert download.is_deferred() + assert download.is_deferred assert 'entries' in download.error diff --git a/wrolpi/vars.py b/wrolpi/vars.py index 7ec285d9..d9f20694 100644 --- a/wrolpi/vars.py +++ b/wrolpi/vars.py @@ -88,3 +88,5 @@ def truthy_arg(value: str) -> bool: IS_RPI4 = 'Raspberry Pi 4' in RPI_DEVICE_MODEL_CONTENTS if RPI_DEVICE_MODEL_CONTENTS else False IS_RPI5 = 'Raspberry Pi 5' in RPI_DEVICE_MODEL_CONTENTS if RPI_DEVICE_MODEL_CONTENTS else False IS_RPI = IS_RPI4 or IS_RPI5 + +SIMULTANEOUS_DOWNLOAD_DOMAINS = int(os.environ.get('SIMULTANEOUS_DOWNLOAD_DOMAINS', 4))
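
A note on the shared-context pattern that appears throughout these hunks: Sanic only shares `app.shared_ctx` objects across worker processes when they are created in the main process before the workers are forked, which is why `download_manager_data`, `events_history`, `uploaded_files`, `refresh`, `flags`, `flags_initialized`, `log_level`, `bandwidth`, `disks_bandwidth`, and `max_disks_bandwidth` are all read through `api_app.shared_ctx` above. The creation of those objects is not part of this section; the listener below is only a minimal, illustrative sketch of that wiring, and the listener name and the Manager-based containers are assumptions, not the project's actual startup code.

import logging
import multiprocessing

from sanic import Sanic

app = Sanic('api_app')  # stand-in for wrolpi.api_utils.api_app


@app.main_process_start
async def attach_shared_ctx(app_, _loop):
    # Everything attached here is created once in the main process and then
    # inherited by every Sanic worker, so changes stay visible across processes.
    manager = multiprocessing.Manager()
    app_.shared_ctx.download_manager_data = manager.dict(killed_downloads=[])
    app_.shared_ctx.events_history = manager.list()
    app_.shared_ctx.uploaded_files = manager.dict()
    app_.shared_ctx.refresh = manager.dict()
    # Assumed to be seeded with False for every name collected in FLAG_NAMES.
    app_.shared_ctx.flags = manager.dict()
    app_.shared_ctx.flags_initialized = multiprocessing.Event()
    app_.shared_ctx.log_level = multiprocessing.Value('i', logging.INFO)
    app_.shared_ctx.bandwidth = manager.dict()
    app_.shared_ctx.disks_bandwidth = manager.dict()
    app_.shared_ctx.max_disks_bandwidth = manager.dict()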
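
The `kill_download` and `unkill_download` methods above copy the whole `download_manager_data` dict, replace the `killed_downloads` list, and write the dict back instead of appending in place. Assuming the shared dict is a Manager proxy (as sketched above), reading a nested list from it returns a worker-local copy, so in-place mutation never reaches the other processes; only re-assigning the key through the proxy does. A condensed sketch of the idiom, with an example `download_id`:

from wrolpi.api_utils import api_app

download_id = 123  # example value

# In-place mutation of a nested value is silently lost across processes:
api_app.shared_ctx.download_manager_data['killed_downloads'].append(download_id)

# Copy, modify, then write the key back through the proxy instead,
# exactly as kill_download() and unkill_download() do above:
data = api_app.shared_ctx.download_manager_data.copy()
data['killed_downloads'] = data['killed_downloads'] + [download_id]
api_app.shared_ctx.download_manager_data.update(data)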
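
The `wrolpi.download.download` signal handler above replaces the old queue-fed download workers; the tests in this patch mock `DownloadManager.dispatch_downloads`, which is the side that hands Downloads to that handler. The dispatch itself is not shown in this section, but with Sanic signals it would look roughly like the hypothetical sketch below, where the context keys become the handler's `download_id` and `download_url` keyword arguments.

from wrolpi.api_utils import api_app


async def dispatch_download(download) -> None:
    # Hypothetical dispatcher: hand one Download row to the
    # 'wrolpi.download.download' signal handler defined in wrolpi/downloader.py.
    await api_app.dispatch(
        'wrolpi.download.download',
        context=dict(download_id=download.id, download_url=download.url),
    )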
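
`bandwidth_worker` is likewise no longer a `limit_concurrent` background coroutine but a handler for the `wrolpi.periodic.bandwidth` signal, so something must dispatch that signal once the server is running. That wiring is outside this section; the startup hook below is purely illustrative, and the listener name is an assumption.

from wrolpi.api_utils import api_app


@api_app.after_server_start
async def start_bandwidth_worker(app, _loop):
    # Hypothetical: schedule the long-running collector without blocking
    # startup.  With no 'count' in the context, bandwidth_worker loops forever.
    app.add_task(app.dispatch('wrolpi.periodic.bandwidth'))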