Skip to content

Commit

Permalink
Handling requests exceptions with friendly error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
yeisonvargasf committed Feb 23, 2021
1 parent d8609b9 commit d853b0a
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 45 deletions.
46 changes: 32 additions & 14 deletions safety/cli.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import

import itertools
import sys

import click
from safety import __version__
from safety import safety
from safety.formatter import report, license_report
import itertools
from safety.util import read_requirements, read_vulnerabilities, get_proxy_dict, get_packages_licenses
from safety.errors import DatabaseFetchError, DatabaseFileNotFoundError, InvalidKeyError, TooManyRequestsError

from safety import __version__, safety
from safety.errors import (DatabaseFetchError, DatabaseFileNotFoundError,
InvalidKeyError, NetworkConnectionError,
RequestTimeoutError, ServerError,
TooManyRequestsError)
from safety.formatter import license_report, report
from safety.util import (get_packages_licenses, get_proxy_dict,
read_requirements, read_vulnerabilities)

try:
from json.decoder import JSONDecodeError
Expand Down Expand Up @@ -65,16 +71,16 @@ def check(key, db, json, full_report, bare, stdin, files, cache, ignore, output,
packages = [
d for d in pkg_resources.working_set
if d.key not in {"python", "wsgiref", "argparse"}
]
]
proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport)
try:
vulns = safety.check(packages=packages, key=key, db_mirror=db, cached=cache, ignore_ids=ignore, proxy=proxy_dictionary)
output_report = report(vulns=vulns,
full=full_report,
json_report=json,
output_report = report(vulns=vulns,
full=full_report,
json_report=json,
bare_report=bare,
checked_packages=len(packages),
db=db,
db=db,
key=key)

if output:
Expand All @@ -92,6 +98,18 @@ def check(key, db, json, full_report, bare, stdin, files, cache, ignore, output,
except DatabaseFileNotFoundError:
click.secho("Unable to load vulnerability database from {db}".format(db=db), fg="red", file=sys.stderr)
sys.exit(-1)
except NetworkConnectionError:
click.secho("Check your network connection, unable to reach the server", fg="red", file=sys.stderr)
sys.exit(-1)
except RequestTimeoutError:
click.secho("Check your network connection, the request timed out.", fg="red", file=sys.stderr)
sys.exit(-1)
except ServerError:
click.secho(
"Sorry, something went wrong.\n" + "Safety CLI can not connect to the server.\n" +
"Our engineers are working quickly to resolve the issue",
fg="red", file=sys.stderr)
sys.exit(-1)
except DatabaseFetchError:
click.secho("Unable to load vulnerability database", fg="red", file=sys.stderr)
sys.exit(-1)
Expand Down Expand Up @@ -154,15 +172,15 @@ def license(key, db, json, bare, cache, files, proxyprotocol, proxyhost, proxypo
packages = [
d for d in pkg_resources.working_set
if d.key not in {"python", "wsgiref", "argparse"}
]
]

proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport)
try:
licenses_db = safety.get_licenses(key, db, cache, proxy_dictionary)
except InvalidKeyError as invalid_key_error:
if str(invalid_key_error):
message = str(invalid_key_error)
else:
else:
message = "Your API Key '{key}' is invalid. See {link}".format(
key=key, link='https://goo.gl/O7Y1rS'
)
Expand Down
13 changes: 13 additions & 0 deletions safety/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ class InvalidKeyError(DatabaseFetchError):

class TooManyRequestsError(DatabaseFetchError):
pass


class NetworkConnectionError(DatabaseFetchError):
pass


class RequestTimeoutError(DatabaseFetchError):
pass


class ServerError(DatabaseFetchError):
pass

23 changes: 19 additions & 4 deletions safety/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from .constants import (API_MIRRORS, CACHE_FILE, CACHE_LICENSES_VALID_SECONDS,
CACHE_VALID_SECONDS, OPEN_MIRRORS, REQUEST_TIMEOUT)
from .errors import (DatabaseFetchError, DatabaseFileNotFoundError,
InvalidKeyError, TooManyRequestsError)
InvalidKeyError, NetworkConnectionError,
RequestTimeoutError, ServerError, TooManyRequestsError)
from .util import RequirementFile


Expand Down Expand Up @@ -87,17 +88,31 @@ def fetch_database_url(mirror, db_name, key, cached, proxy):
if cached_data:
return cached_data
url = mirror + db_name
r = requests.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers, proxies=proxy)

try:
r = requests.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers, proxies=proxy)
except requests.exceptions.ConnectionError:
raise NetworkConnectionError()
except requests.exceptions.Timeout:
raise RequestTimeoutError()
except requests.exceptions.RequestException:
raise DatabaseFetchError()

if r.status_code == 200:
data = r.json()
if cached:
write_to_cache(db_name, data)
return data
elif r.status_code == 403:

if r.status_code == 403:
raise InvalidKeyError()
elif r.status_code == 429:

if r.status_code == 429:
raise TooManyRequestsError()

if 500 <= r.status_code < 600:
raise ServerError()


def fetch_database_file(path, db_name):
full_path = os.path.join(path, db_name)
Expand Down
99 changes: 72 additions & 27 deletions tests/test_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,23 @@
"""


import unittest
import json
import os
import textwrap
from click.testing import CliRunner
import unittest
from unittest.mock import Mock, patch

from safety import safety
from safety import cli
from safety import formatter
from safety import util
import os
import json
import requests
from click.testing import CliRunner

from safety import cli, formatter, safety, util

try:
from StringIO import StringIO
except ImportError:
from io import StringIO
from safety.util import read_requirements
from safety.util import read_vulnerabilities

from safety.util import read_requirements, read_vulnerabilities


class TestSafetyCLI(unittest.TestCase):
Expand Down Expand Up @@ -220,6 +220,50 @@ def test_multiple_versions(self):
)
self.assertEqual(len(vulns), 4)

@patch.object(requests, 'get', side_effect=requests.exceptions.ConnectionError)
def test_check_fetch_database_url_connection_error(self, requests_mock):
from safety.errors import NetworkConnectionError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(NetworkConnectionError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

@patch.object(requests, 'get', side_effect=requests.exceptions.Timeout)
def test_check_fetch_database_url_timeout_error(self, requests_mock):
from safety.errors import RequestTimeoutError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(RequestTimeoutError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

@patch.object(requests, 'get', side_effect=requests.exceptions.RequestException)
def test_check_fetch_database_url_generic_exception_error(self, requests_mock):
from safety.errors import DatabaseFetchError

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(DatabaseFetchError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

@patch("safety.safety.requests")
def test_check_fetch_database_url_server_error(self, mocked_requests):
from safety.errors import ServerError

mock = Mock()
mock.status_code = 502
mocked_requests.get.return_value = mock

db_name = "insecure.json"
mirror = 'https://safety.test'

with self.assertRaises(ServerError):
safety.fetch_database_url(mirror, db_name=db_name, key="INVALID", cached=False, proxy={})

def test_check_live(self):
reqs = StringIO("insecure-package==0.1")
packages = util.read_requirements(reqs)
Expand Down Expand Up @@ -295,20 +339,6 @@ def test_get_packages_licenses(self):
"unexpected package '" + pkg_license['package'] + "' was found"
)

def test_get_packages_licenses_without_api_key(self):
from safety.errors import InvalidKeyError

# without providing an API-KEY
with self.assertRaises(InvalidKeyError) as error:
safety.get_licenses(
db_mirror=False,
cached=False,
proxy={},
key=None
)
db_generic_exception = error.exception
self.assertEqual(str(db_generic_exception), 'The API-KEY was not provided.')

@patch("safety.safety.requests")
def test_get_packages_licenses_with_invalid_api_key(self, requests):
from safety.errors import InvalidKeyError
Expand All @@ -326,6 +356,20 @@ def test_get_packages_licenses_with_invalid_api_key(self, requests):
key="INVALID"
)

def test_get_packages_licenses_without_api_key(self):
from safety.errors import InvalidKeyError

# without providing an API-KEY
with self.assertRaises(InvalidKeyError) as error:
safety.get_licenses(
db_mirror=False,
cached=False,
proxy={},
key=None
)
db_generic_exception = error.exception
self.assertEqual(str(db_generic_exception), 'The API-KEY was not provided.')

@patch("safety.safety.requests")
def test_get_packages_licenses_db_fetch_error(self, requests):
from safety.errors import DatabaseFetchError
Expand All @@ -341,7 +385,7 @@ def test_get_packages_licenses_db_fetch_error(self, requests):
proxy={},
key="MY-VALID-KEY"
)

def test_get_packages_licenses_with_invalid_db_file(self):
from safety.errors import DatabaseFileNotFoundError

Expand Down Expand Up @@ -373,6 +417,7 @@ def test_get_packages_licenses_very_often(self, requests):
@patch("safety.safety.requests")
def test_get_cached_packages_licenses(self, requests):
import copy

from safety.constants import CACHE_FILE

licenses_db = {
Expand Down Expand Up @@ -401,7 +446,7 @@ def test_get_cached_packages_licenses(self, requests):
f.write(json.dumps({}))
except Exception:
pass

# In order to cache the db (and get), we must set cached as True
response = safety.get_licenses(
db_mirror=False,
Expand All @@ -421,7 +466,7 @@ def test_get_cached_packages_licenses(self, requests):
proxy={},
key="MY-VALID-KEY"
)

self.assertNotEqual(resp, licenses_db)
self.assertEqual(resp, original_db)

Expand Down

0 comments on commit d853b0a

Please sign in to comment.