Skip to content

Commit

Permalink
Handling requests exceptions with friendly error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
yeisonvargasf committed Feb 26, 2021
1 parent d8609b9 commit 79c1120
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 5 deletions.
14 changes: 13 additions & 1 deletion safety/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from safety.formatter import report, license_report
import itertools
from safety.util import read_requirements, read_vulnerabilities, get_proxy_dict, get_packages_licenses
from safety.errors import DatabaseFetchError, DatabaseFileNotFoundError, InvalidKeyError, TooManyRequestsError
from safety.errors import DatabaseFetchError, DatabaseFileNotFoundError, InvalidKeyError, NetworkConnectionError, RequestTimeoutError, ServerError, TooManyRequestsError

try:
from json.decoder import JSONDecodeError
Expand Down Expand Up @@ -92,6 +92,18 @@ def check(key, db, json, full_report, bare, stdin, files, cache, ignore, output,
except DatabaseFileNotFoundError:
click.secho("Unable to load vulnerability database from {db}".format(db=db), fg="red", file=sys.stderr)
sys.exit(-1)
except NetworkConnectionError:
click.secho("Check your network connection, unable to reach the server", fg="red", file=sys.stderr)
sys.exit(-1)
except RequestTimeoutError:
click.secho("Check your network connection, the request timed out.", fg="red", file=sys.stderr)
sys.exit(-1)
except ServerError:
click.secho(
"Sorry, something went wrong.\n" + "Safety CLI can not connect to the server.\n" +
"Our engineers are working quickly to resolve the issue",
fg="red", file=sys.stderr)
sys.exit(-1)
except DatabaseFetchError:
click.secho("Unable to load vulnerability database", fg="red", file=sys.stderr)
sys.exit(-1)
Expand Down
13 changes: 13 additions & 0 deletions safety/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ class InvalidKeyError(DatabaseFetchError):

class TooManyRequestsError(DatabaseFetchError):
pass


class NetworkConnectionError(DatabaseFetchError):
pass


class RequestTimeoutError(DatabaseFetchError):
pass


class ServerError(DatabaseFetchError):
pass

23 changes: 19 additions & 4 deletions safety/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from .constants import (API_MIRRORS, CACHE_FILE, CACHE_LICENSES_VALID_SECONDS,
CACHE_VALID_SECONDS, OPEN_MIRRORS, REQUEST_TIMEOUT)
from .errors import (DatabaseFetchError, DatabaseFileNotFoundError,
InvalidKeyError, TooManyRequestsError)
InvalidKeyError, NetworkConnectionError,
RequestTimeoutError, ServerError, TooManyRequestsError)
from .util import RequirementFile


Expand Down Expand Up @@ -87,17 +88,31 @@ def fetch_database_url(mirror, db_name, key, cached, proxy):
if cached_data:
return cached_data
url = mirror + db_name
r = requests.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers, proxies=proxy)

try:
r = requests.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers, proxies=proxy)
except requests.exceptions.ConnectionError:
raise NetworkConnectionError()
except requests.exceptions.Timeout:
raise RequestTimeoutError()
except requests.exceptions.RequestException:
raise DatabaseFetchError()

if r.status_code == 200:
data = r.json()
if cached:
write_to_cache(db_name, data)
return data
elif r.status_code == 403:

if r.status_code == 403:
raise InvalidKeyError()
elif r.status_code == 429:

if r.status_code == 429:
raise TooManyRequestsError()

if 500 <= r.status_code < 600:
raise ServerError()


def fetch_database_file(path, db_name):
full_path = os.path.join(path, db_name)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from safety import util
import os
import json
import requests
try:
from StringIO import StringIO
except ImportError:
Expand Down Expand Up @@ -220,6 +221,50 @@ def test_multiple_versions(self):
)
self.assertEqual(len(vulns), 4)

@patch.object(requests, 'get', side_effect=requests.exceptions.ConnectionError)
def test_check_fetch_database_url_connection_error(self, requests_mock):
from safety.errors import NetworkConnectionError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(NetworkConnectionError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

@patch.object(requests, 'get', side_effect=requests.exceptions.Timeout)
def test_check_fetch_database_url_timeout_error(self, requests_mock):
from safety.errors import RequestTimeoutError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(RequestTimeoutError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

@patch.object(requests, 'get', side_effect=requests.exceptions.RequestException)
def test_check_fetch_database_url_generic_exception_error(self, requests_mock):
from safety.errors import DatabaseFetchError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(DatabaseFetchError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

@patch("safety.safety.requests")
def test_check_fetch_database_url_server_error(self, mocked_requests):
from safety.errors import ServerError

mock = Mock()
mock.status_code = 502
mocked_requests.get.return_value = mock

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(ServerError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

def test_check_live(self):
reqs = StringIO("insecure-package==0.1")
packages = util.read_requirements(reqs)
Expand Down

0 comments on commit 79c1120

Please sign in to comment.