-
-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
329b969
commit 8312464
Showing
3 changed files
with
111 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from fastapi import APIRouter, Depends, HTTPException, Request | ||
from logger import get_logger | ||
from middlewares.auth import AuthBearer, get_current_user | ||
from modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput | ||
from modules.sync.service.sync_service import SyncService | ||
from modules.user.entity.user_identity import UserIdentity | ||
from msal import PublicClientApplication | ||
|
||
# Initialize logger | ||
logger = get_logger(__name__) | ||
|
||
# Initialize sync service | ||
sync_service = SyncService() | ||
|
||
# Initialize API router | ||
azure_sync_router = APIRouter() | ||
|
||
# Constants | ||
CLIENT_ID = "511dce23-02f3-4724-8684-05da226df5f3" | ||
AUTHORITY = "https://login.microsoftonline.com/common" | ||
REDIRECT_URI = "http://localhost:5050/sync/azure/oauth2callback" | ||
SCOPE = [ | ||
"https://graph.microsoft.com/Files.Read", | ||
"https://graph.microsoft.com/User.Read", | ||
"https://graph.microsoft.com/Sites.Read.All", | ||
] | ||
|
||
client = PublicClientApplication(CLIENT_ID, authority=AUTHORITY) | ||
|
||
|
||
@azure_sync_router.get( | ||
"/sync/azure/authorize", | ||
dependencies=[Depends(AuthBearer())], | ||
tags=["Sync"], | ||
) | ||
def authorize_azure( | ||
request: Request, current_user: UserIdentity = Depends(get_current_user) | ||
): | ||
""" | ||
Authorize Azure sync for the current user. | ||
Args: | ||
request (Request): The request object. | ||
current_user (UserIdentity): The current authenticated user. | ||
Returns: | ||
dict: A dictionary containing the authorization URL. | ||
""" | ||
logger.debug(f"Authorizing Azure sync for user: {current_user.id}") | ||
state = f"user_id={current_user.id}" | ||
authorization_url = client.get_authorization_request_url( | ||
scopes=SCOPE, redirect_uri=REDIRECT_URI, state=state | ||
) | ||
|
||
sync_user_input = SyncsUserInput( | ||
user_id=str(current_user.id), | ||
sync_name="Azure", | ||
credentials={}, | ||
state={"state": state}, | ||
) | ||
sync_service.create_sync_user(sync_user_input) | ||
return {"authorization_url": authorization_url} | ||
|
||
|
||
@azure_sync_router.get("/sync/azure/oauth2callback", tags=["Sync"]) | ||
def oauth2callback_azure(request: Request): | ||
""" | ||
Handle OAuth2 callback from Azure. | ||
Args: | ||
request (Request): The request object. | ||
Returns: | ||
dict: A dictionary containing a success message. | ||
""" | ||
state = request.query_params.get("state") | ||
state_dict = {"state": state} | ||
current_user = state.split('=')[1] # Extract user_id from state | ||
logger.debug( | ||
f"Handling OAuth2 callback for user: {current_user} with state: {state}" | ||
) | ||
sync_user_state = sync_service.get_sync_user_by_state(state_dict) | ||
logger.info(f"Retrieved sync user state: {sync_user_state}") | ||
|
||
if state_dict != sync_user_state["state"]: | ||
logger.error("Invalid state parameter") | ||
raise HTTPException(status_code=400, detail="Invalid state parameter") | ||
if sync_user_state.get("user_id") != current_user: | ||
logger.error("Invalid user") | ||
raise HTTPException(status_code=400, detail="Invalid user") | ||
|
||
result = client.acquire_token_by_authorization_code( | ||
request.query_params.get("code"), scopes=SCOPE, redirect_uri=REDIRECT_URI | ||
) | ||
if "access_token" not in result: | ||
logger.error("Failed to acquire token") | ||
raise HTTPException(status_code=400, detail="Failed to acquire token") | ||
|
||
creds = result | ||
logger.info(f"Fetched OAuth2 token for user: {current_user}") | ||
|
||
sync_user_input = SyncUserUpdateInput( | ||
credentials=creds, | ||
state={}, | ||
) | ||
sync_service.update_sync_user(current_user, state_dict, sync_user_input) | ||
logger.info(f"Azure sync created successfully for user: {current_user}") | ||
return {"message": "Azure sync created successfully"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters