Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling requests exceptions with friendly error messages #338

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion safety/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from safety.formatter import report, license_report
import itertools
from safety.util import read_requirements, read_vulnerabilities, get_proxy_dict, get_packages_licenses
from safety.errors import DatabaseFetchError, DatabaseFileNotFoundError, InvalidKeyError, TooManyRequestsError
from safety.errors import DatabaseFetchError, DatabaseFileNotFoundError, InvalidKeyError, NetworkConnectionError, RequestTimeoutError, ServerError, TooManyRequestsError

try:
from json.decoder import JSONDecodeError
Expand Down Expand Up @@ -92,6 +92,18 @@ def check(key, db, json, full_report, bare, stdin, files, cache, ignore, output,
except DatabaseFileNotFoundError:
click.secho("Unable to load vulnerability database from {db}".format(db=db), fg="red", file=sys.stderr)
sys.exit(-1)
except NetworkConnectionError:
click.secho("Check your network connection, unable to reach the server", fg="red", file=sys.stderr)
sys.exit(-1)
except RequestTimeoutError:
click.secho("Check your network connection, the request timed out.", fg="red", file=sys.stderr)
sys.exit(-1)
Comment on lines +95 to +100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we create a pattern for passing default messages to our error classes, you can just reuse the error message of the thrown exception here and maybe only catch the base class like DatabaseFetchError.

except ServerError:
click.secho(
"Sorry, something went wrong.\n" + "Safety CLI can not connect to the server.\n" +
"Our engineers are working quickly to resolve the issue.",
fg="red", file=sys.stderr)
sys.exit(-1)
except DatabaseFetchError:
click.secho("Unable to load vulnerability database", fg="red", file=sys.stderr)
sys.exit(-1)
Expand Down
13 changes: 13 additions & 0 deletions safety/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ class InvalidKeyError(DatabaseFetchError):

class TooManyRequestsError(DatabaseFetchError):
pass


class NetworkConnectionError(DatabaseFetchError):
pass


class RequestTimeoutError(DatabaseFetchError):
pass


class ServerError(DatabaseFetchError):
pass
Comment on lines +17 to +26
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you catch all of them and just change the message, not to mention they have all the same class, give a custom message to these classes so you don't need to repeat when catching them.


23 changes: 19 additions & 4 deletions safety/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from .constants import (API_MIRRORS, CACHE_FILE, CACHE_LICENSES_VALID_SECONDS,
CACHE_VALID_SECONDS, OPEN_MIRRORS, REQUEST_TIMEOUT)
from .errors import (DatabaseFetchError, DatabaseFileNotFoundError,
InvalidKeyError, TooManyRequestsError)
InvalidKeyError, NetworkConnectionError,
RequestTimeoutError, ServerError, TooManyRequestsError)
from .util import RequirementFile


Expand Down Expand Up @@ -87,17 +88,31 @@ def fetch_database_url(mirror, db_name, key, cached, proxy):
if cached_data:
return cached_data
url = mirror + db_name
r = requests.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers, proxies=proxy)

try:
r = requests.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers, proxies=proxy)
except requests.exceptions.ConnectionError:
raise NetworkConnectionError()
except requests.exceptions.Timeout:
raise RequestTimeoutError()
except requests.exceptions.RequestException:
raise DatabaseFetchError()
Comment on lines +92 to +99
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good thing you decided to add our own custom error classes here instead of just letting requests exceptions to slip through. Can you please pass the original exception as an inner attribute here?


if r.status_code == 200:
data = r.json()
if cached:
write_to_cache(db_name, data)
return data
Comment on lines 101 to 105
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you modified the elif-statements below because of short-circuits, can we bring this to the end of the function without the if-statement? So, the "last case" would be the default, returning the data.

elif r.status_code == 403:

if r.status_code == 403:
raise InvalidKeyError()
elif r.status_code == 429:

if r.status_code == 429:
Comment on lines -96 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these condition ranges doesn't look good. There are several codes that would not raise an error. Can you please review the ranges here including those beyond 599 or less than 200?

raise TooManyRequestsError()

if 500 <= r.status_code < 600:
raise ServerError()
Comment on lines +113 to +114
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you pass response reason as a message to ServerError, please?



def fetch_database_file(path, db_name):
full_path = os.path.join(path, db_name)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from safety import util
import os
import json
import requests
try:
from StringIO import StringIO
except ImportError:
Expand Down Expand Up @@ -220,6 +221,50 @@ def test_multiple_versions(self):
)
self.assertEqual(len(vulns), 4)

@patch.object(requests, 'get', side_effect=requests.exceptions.ConnectionError)
def test_check_fetch_database_url_connection_error(self, requests_mock):
from safety.errors import NetworkConnectionError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(NetworkConnectionError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

Comment on lines +224 to +233
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the tests. 👏

Codecov provides some line-by-line comments as part of our PRs about test coverage. Please make sure we are watching them and respecting as much as possible.

@patch.object(requests, 'get', side_effect=requests.exceptions.Timeout)
def test_check_fetch_database_url_timeout_error(self, requests_mock):
from safety.errors import RequestTimeoutError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(RequestTimeoutError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

@patch.object(requests, 'get', side_effect=requests.exceptions.RequestException)
def test_check_fetch_database_url_generic_exception_error(self, requests_mock):
from safety.errors import DatabaseFetchError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(DatabaseFetchError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

@patch("safety.safety.requests")
def test_check_fetch_database_url_server_error(self, mocked_requests):
from safety.errors import ServerError

mock = Mock()
mock.status_code = 502
mocked_requests.get.return_value = mock

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(ServerError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

def test_check_live(self):
reqs = StringIO("insecure-package==0.1")
packages = util.read_requirements(reqs)
Expand Down