Skip to content

Commit

Permalink
Add reauthentication flow
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus authored and Markus committed Nov 11, 2023
1 parent 00abdd6 commit c204a46
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 69 deletions.
43 changes: 29 additions & 14 deletions custom_components/unifi_voucher/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from aiohttp import CookieJar
import aiounifi
from aiounifi.models.configuration import Configuration
from aiounifi.models.api import ApiRequest
from aiounifi.models.api import (
ApiRequest,
TypedApiResponse,
)

from homeassistant.core import (
callback,
Expand Down Expand Up @@ -169,18 +172,31 @@ async def check_api_user(
try:
async with asyncio.timeout(10):
await self.api.login()

await self.api.sites.update()
for _id, _site in self.api.sites.items():
for _unique_id, _site in self.api.sites.items():
# User must have admin or hotspot permissions
if _site.role in ("admin", "hotspot"):
_sites[_id] = _site.description
_sites[_unique_id] = _site

# No site with the required permissions found
if len(_sites) == 0:
LOGGER.warning(
"Connected to UniFi Network at %s but no access.",
self.host,
)
raise UnifiVoucherApiAccessError
return _sites
except (
aiounifi.LoginRequired,
aiounifi.Unauthorized,
aiounifi.Forbidden,
) as err:
LOGGER.warning(
"Connected to UniFi Network at %s but login required: %s",
self.host,
err,
)
raise UnifiVoucherApiAuthenticationError from err
except (
asyncio.TimeoutError,
aiounifi.BadGateway,
Expand All @@ -195,20 +211,19 @@ async def check_api_user(
)
raise UnifiVoucherApiConnectionError from err
except (
aiounifi.LoginRequired,
aiounifi.Unauthorized,
aiounifi.Forbidden,
aiounifi.AiounifiException,
Exception,
) as err:
LOGGER.warning(
"Connected to UniFi Network at %s but login required: %s",
self.host,
err,
)
raise UnifiVoucherApiAuthenticationError from err
except aiounifi.AiounifiException as err:
LOGGER.exception(
"Unknown UniFi Network communication error occurred: %s",
err,
)
raise UnifiVoucherApiError from err
return False

async def request(
self,
api_request: ApiRequest,
) -> TypedApiResponse:
"""Make a request to the API, retry login on failure."""
return await self.api.request(api_request)
3 changes: 0 additions & 3 deletions custom_components/unifi_voucher/button.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ async def async_setup_entry(
[
UnifiVoucherButton(
coordinator=coordinator,
host=config_entry.data[CONF_HOST],
entity_description=entity_description,
)
for entity_description in entity_descriptions
Expand All @@ -74,13 +73,11 @@ class UnifiVoucherButton(UnifiVoucherEntity, ButtonEntity):
def __init__(
self,
coordinator: UnifiVoucherCoordinator,
host: str,
entity_description: UnifiVoucherButtonDescription,
) -> None:
"""Initialize the button class."""
super().__init__(
coordinator=coordinator,
host=host,
entity_type="button",
entity_key=entity_description.key,
)
Expand Down
160 changes: 122 additions & 38 deletions custom_components/unifi_voucher/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ class UnifiVoucherConfigFlow(ConfigFlow, domain=DOMAIN):

def __init__(self) -> None:
"""Initialize the UniFi Network flow."""
self._title: str | None = None
self._options: dict[str, any] | None = None
self._sites: dict[str, str] | None = None
self.title: str | None = None
self.data: dict[str, any] | None = None
self.sites: dict[str, str] | None = None
self.reauth_config_entry: config_entries.ConfigEntry | None = None
self.reauth_schema: dict[vol.Marker, Any] = {}

async def async_step_user(
self,
Expand All @@ -63,12 +65,6 @@ async def async_step_user(
"""Invoke when a user initiates a flow via the user interface."""
errors: dict[str, str] = {}
if user_input is not None:
self._async_abort_entries_match(
{
CONF_HOST: user_input[CONF_HOST],
CONF_USERNAME: user_input[CONF_USERNAME],
}
)
try:
client = UnifiVoucherApiClient(
self.hass,
Expand All @@ -79,7 +75,7 @@ async def async_step_user(
site_id=DEFAULT_SITE_ID,
verify_ssl=user_input[CONF_VERIFY_SSL],
)
self._sites = await client.check_api_user()
self.sites = await client.check_api_user()
except UnifiVoucherApiConnectionError:
errors["base"] = "cannot_connect"
except UnifiVoucherApiAuthenticationError:
Expand All @@ -92,30 +88,27 @@ async def async_step_user(

if not errors:
# Input is valid, set data
self._options = {
self.data = {
CONF_HOST: user_input.get(CONF_HOST, "").strip(),
CONF_USERNAME: user_input.get(CONF_USERNAME, "").strip(),
CONF_PASSWORD: user_input.get(CONF_PASSWORD, "").strip(),
CONF_PORT: int(user_input[CONF_PORT]),
CONF_SITE_ID: DEFAULT_SITE_ID,
CONF_VERIFY_SSL: user_input.get(CONF_VERIFY_SSL, False),
}
# Reauth
if (
self.reauth_config_entry
and self.reauth_config_entry.unique_id is not None
and self.reauth_config_entry.unique_id in self.sites
):
return await self.async_step_site(
{
CONF_SITE_ID: self.reauth_config_entry.unique_id
}
)
# Go to site selection, if user has access to more than one site
if len(self._sites) > 1:
return await self.async_step_site()

site_id = list(self._sites.keys())[0]

self._title = self._sites.get(site_id)
self._options.update({
CONF_SITE_ID: site_id
})
# User is done, create the config entry.
return self.async_create_entry(
title=self._title,
data={},
options=self._options,
)
return await self.async_step_site()

if await _async_discover_unifi(
self.hass
Expand Down Expand Up @@ -147,7 +140,7 @@ async def async_step_user(
default=(user_input or {}).get(CONF_PASSWORD, DEFAULT_PASSWORD),
): selector.TextSelector(
selector.TextSelectorConfig(
type=selector.TextSelectorType.TEXT
type=selector.TextSelectorType.PASSWORD
),
),
vol.Required(
Expand Down Expand Up @@ -176,23 +169,57 @@ async def async_step_site(
"""Second step in config flow to save site."""
errors: dict[str, str] = {}
if user_input is not None:
site_id = user_input.get(CONF_SITE_ID, "")
unique_id = user_input.get(CONF_SITE_ID, "")

if not self._sites.get(site_id):
if not self.sites.get(unique_id):
errors["base"] = "site_invalid"

config_entry = await self.async_set_unique_id(unique_id)
abort_reason = "configuration_updated"

if self.reauth_config_entry:
config_entry = self.reauth_config_entry
abort_reason = "reauth_successful"
else:
# Abort if site is already configured
self._async_abort_entries_match(
{
CONF_HOST: user_input[CONF_HOST],
CONF_SITE_ID: self.sites[unique_id].name,
}
)

if config_entry:
self.hass.config_entries.async_update_entry(
config_entry, data=self.config
)
await self.hass.config_entries.async_reload(
config_entry.entry_id
)
return self.async_abort(
reason=abort_reason
)

if not errors:
# Input is valid, set data.
self._title = self._sites.get(site_id)
self._options.update({
CONF_SITE_ID: site_id
self.title = self.sites[unique_id].description
self.data.update({
CONF_SITE_ID: self.sites[unique_id].name
})
# User is done, create the config entry.
return self.async_create_entry(
title=self._title,
data={},
options=self._options,
title=self.title,
data=self.data,
)

# Only one site is available, skip selection
if len(self.sites.values()) == 1:
return await self.async_step_site(
{
CONF_SITE_ID: next(iter(self.sites)),
}
)

return self.async_show_form(
step_id="site",
data_schema=vol.Schema(
Expand All @@ -204,10 +231,10 @@ async def async_step_site(
selector.SelectSelectorConfig(
options=[
selector.SelectOptionDict(
value=site_id,
label=site_description,
value=_unique_id,
label=_site.description,
)
for site_id, site_description in self._sites.items()
for _unique_id, _site in self.sites.items()
],
mode=selector.SelectSelectorMode.DROPDOWN,
translation_key=CONF_SITE_ID,
Expand All @@ -219,6 +246,63 @@ async def async_step_site(
errors=errors,
)

async def async_step_reauth(
self,
entry_data: dict[str, any],
) -> FlowResult:
"""Trigger a reauthentication flow."""
config_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
assert config_entry
self.reauth_config_entry = config_entry

self.context["title_placeholders"] = {
CONF_HOST: config_entry.data[CONF_HOST],
CONF_SITE_ID: config_entry.title,
}

self.reauth_schema = {
vol.Required(
CONF_HOST,
default=config_entry.data[CONF_HOST],
): selector.TextSelector(
selector.TextSelectorConfig(
type=selector.TextSelectorType.TEXT
),
),
vol.Required(
CONF_USERNAME,
default=config_entry.data[CONF_USERNAME],
): selector.TextSelector(
selector.TextSelectorConfig(
type=selector.TextSelectorType.TEXT
),
),
vol.Required(
CONF_PASSWORD,
): selector.TextSelector(
selector.TextSelectorConfig(
type=selector.TextSelectorType.PASSWORD
),
),
vol.Required(
CONF_PORT,
default=config_entry.data[CONF_PORT],
): selector.NumberSelector(
selector.NumberSelectorConfig(
mode=selector.NumberSelectorMode.BOX,
min=1,
max=65535,
)
),
vol.Required(
CONF_VERIFY_SSL,
default=config_entry.data[CONF_VERIFY_SSL],
): selector.BooleanSelector(),
}
return await self.async_step_user()

async def _async_discover_unifi(hass: HomeAssistant) -> str | None:
"""Discover UniFi Network address."""
try:
Expand Down
20 changes: 13 additions & 7 deletions custom_components/unifi_voucher/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from .api import (
UnifiVoucherApiClient,
UnifiVoucherListRequest,
UnifiVoucherCreateRequest,
UnifiVoucherApiAuthenticationError,
UnifiVoucherApiConnectionError,
UnifiVoucherApiError,
Expand All @@ -55,14 +57,14 @@ def __init__(
update_interval=update_interval,
)
self.config_entry = config_entry
self.client = UnifiVoucherApiClient(
self.api = UnifiVoucherApiClient(
hass,
host=config_entry.options.get(CONF_HOST, DEFAULT_HOST),
username=config_entry.options.get(CONF_USERNAME, ""),
password=config_entry.options.get(CONF_PASSWORD, ""),
port=int(config_entry.options.get(CONF_PORT, DEFAULT_PORT)),
site_id=config_entry.options.get(CONF_SITE_ID, DEFAULT_SITE_ID),
verify_ssl=config_entry.options.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL),
host=config_entry.data.get(CONF_HOST, DEFAULT_HOST),
username=config_entry.data.get(CONF_USERNAME, ""),
password=config_entry.data.get(CONF_PASSWORD, ""),
port=int(config_entry.data.get(CONF_PORT, DEFAULT_PORT)),
site_id=config_entry.data.get(CONF_SITE_ID, DEFAULT_SITE_ID),
verify_ssl=config_entry.data.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL),
)
self._last_pull = None

Expand All @@ -81,6 +83,10 @@ async def _async_update_data(self):
try:
self._last_pull = dt_util.now()
_available = True

foo = await self.api.request(UnifiVoucherListRequest.create())
LOGGER.debug(foo)

# TODO
#except (UnifiVoucherClientTimeoutError, UnifiVoucherClientCommunicationError, UnifiVoucherClientAuthenticationError) as exception:
# LOGGER.error(str(exception))
Expand Down
3 changes: 1 addition & 2 deletions custom_components/unifi_voucher/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@ class UnifiVoucherEntity(CoordinatorEntity):
def __init__(
self,
coordinator: UnifiVoucherCoordinator,
host: str,
entity_type: str,
entity_key: str,
) -> None:
"""Initialize."""
super().__init__(coordinator)

self._host = host
self._host = "unifi_voucher" # TODO
self._entity_type = entity_type
self._entity_key = entity_key

Expand Down
Loading

0 comments on commit c204a46

Please sign in to comment.