diff --git a/pyhon/connection/auth.py b/pyhon/connection/auth.py index 72a7426..3066555 100644 --- a/pyhon/connection/auth.py +++ b/pyhon/connection/auth.py @@ -1,3 +1,4 @@ +import datetime import json import logging import re @@ -146,7 +147,7 @@ class HonAuth: if not await self._get_token(url): return False - post_headers = {"Content-Type": "application/json", "id-token": self._id_token} + post_headers = {"id-token": self._id_token} data = self._device.get() async with self._session.post(f"{const.API_URL}/auth/v1/login", headers=post_headers, json=data) as resp: try: @@ -156,3 +157,18 @@ class HonAuth: return False self._cognito_token = json_data["cognitoUser"]["Token"] return True + + async def refresh(self): + params = { + "client_id": const.CLIENT_ID, + "refresh_token": self._refresh_token, + "grant_type": "refresh_token" + } + async with self._session.post(f"{const.AUTH_API}/services/oauth2/token", params=params) as resp: + if resp.status >= 400: + return False + data = await resp.json() + self._id_token = data["id_token"] + self._access_token = data["access_token"] + + diff --git a/pyhon/connection/connection.py b/pyhon/connection/connection.py index 6c9bf56..66deb5b 100644 --- a/pyhon/connection/connection.py +++ b/pyhon/connection/connection.py @@ -66,23 +66,31 @@ class HonConnectionHandler(HonBaseConnectionHandler): return {h: v for h, v in self._request_headers.items() if h not in headers} @asynccontextmanager - async def get(self, *args, loop=0, **kwargs): + async def _intercept(self, method, *args, loop=0, **kwargs): kwargs["headers"] = await self._check_headers(kwargs.get("headers", {})) - async with self._session.get(*args, **kwargs) as response: + async with method(*args, **kwargs) as response: if response.status == 403 and not loop: + _LOGGER.info("Try refreshing token...") + await self._auth.refresh() + yield await self._intercept(method, *args, loop=loop + 1, **kwargs) + elif response.status == 403 and loop < 2: _LOGGER.warning("%s - Error %s - %s", response.request_info.url, response.status, await response.text()) await self.create() - yield await self.get(*args, loop=loop + 1, **kwargs) + yield await self._intercept(method, *args, loop=loop + 1, **kwargs) elif loop >= 2: _LOGGER.error("%s - Error %s - %s", response.request_info.url, response.status, await response.text()) - raise PermissionError() + raise PermissionError("Login failure") else: yield response + @asynccontextmanager + async def get(self, *args, **kwargs): + async with self._intercept(self._session.get, *args, **kwargs) as response: + yield response + @asynccontextmanager async def post(self, *args, **kwargs): - kwargs["headers"] = await self._check_headers(kwargs.get("headers", {})) - async with self._session.post(*args, **kwargs) as response: + async with self._intercept(self._session.post, *args, **kwargs) as response: yield response