async update
bigsk1 committed Jun 25, 2024
1 parent b1c2dd2 commit 23924ad
Showing 6 changed files with 170 additions and 88 deletions.
109 changes: 61 additions & 48 deletions app/app.py
@@ -149,8 +149,7 @@ def open_file(filepath):

# Function to play audio using PyAudio
async def play_audio(file_path):
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, sync_play_audio, file_path)
await asyncio.to_thread(sync_play_audio, file_path)

def sync_play_audio(file_path):
print("Starting audio playback")
@@ -215,10 +214,10 @@ async def process_and_play(prompt, audio_file_pth):
await send_message_to_clients(json.dumps({"action": "ai_stop_speaking"}))
else:
print("Error: Audio file not found.")
else:
tts_model = xtts_model
else: # XTTS
try:
outputs = await tts_model.synthesize(
tts_model = xtts_model
outputs = await asyncio.to_thread(tts_model.synthesize,
prompt,
xtts_config,
speaker_wav=audio_file_pth,
@@ -389,6 +388,7 @@ def adjust_prompt(mood):

def chatgpt_streamed(user_input, system_message, mood_prompt, conversation_history):
full_response = ""
print(f"Debug: chatgpt_streamed started. MODEL_PROVIDER: {MODEL_PROVIDER}")

if MODEL_PROVIDER == 'ollama':
headers = {'Content-Type': 'application/json'}
@@ -399,6 +399,7 @@ def chatgpt_streamed(user_input, system_message, mood_prompt, conversation_history):
"options": {"num_predict": -2, "temperature": 1.0}
}
try:
print(f"Debug: Sending request to Ollama: {OLLAMA_BASE_URL}/v1/chat/completions")
response = requests.post(f'{OLLAMA_BASE_URL}/v1/chat/completions', headers=headers, json=payload, stream=True, timeout=30)
response.raise_for_status()

@@ -426,12 +427,14 @@ def chatgpt_streamed(user_input, system_message, mood_prompt, conversation_history):

except requests.exceptions.RequestException as e:
full_response = f"Error connecting to Ollama model: {e}"
print(f"Debug: Ollama error - {e}")

elif MODEL_PROVIDER == 'openai':
messages = [{"role": "system", "content": system_message + "\n" + mood_prompt}] + conversation_history + [{"role": "user", "content": user_input}]
headers = {'Authorization': f'Bearer {OPENAI_API_KEY}', 'Content-Type': 'application/json'}
payload = {"model": OPENAI_MODEL, "messages": messages, "stream": True}
try:
print(f"Debug: Sending request to OpenAI: {OPENAI_BASE_URL}")
response = requests.post(OPENAI_BASE_URL, headers=headers, json=payload, stream=True, timeout=30)
response.raise_for_status()

@@ -461,7 +464,9 @@ def chatgpt_streamed(user_input, system_message, mood_prompt, conversation_history):

except requests.exceptions.RequestException as e:
full_response = f"Error connecting to OpenAI model: {e}"
print(f"Debug: OpenAI error - {e}")

print(f"Debug: chatgpt_streamed completed. Response length: {len(full_response)}")
return full_response

def transcribe_with_whisper(audio_file):
@@ -505,7 +510,7 @@ def record_audio(file_path, silence_threshold=512, silence_duration=4.0, chunk_s
wf.writeframes(b''.join(frames))
wf.close()

def execute_once(question_prompt):
async def execute_once(question_prompt):
temp_image_path = os.path.join(output_dir, 'temp_img.jpg')

# Determine the audio file format based on the TTS provider
@@ -519,8 +524,8 @@ def execute_once(question_prompt):
temp_audio_path = os.path.join(output_dir, 'temp_audio.wav') # Use wav for XTTS
max_char_length = 250 # Set a lower limit for XTTS

image_path = take_screenshot(temp_image_path)
response = analyze_image(image_path, question_prompt)
image_path = await take_screenshot(temp_image_path)
response = await analyze_image(image_path, question_prompt)
text_response = response.get('choices', [{}])[0].get('message', {}).get('content', 'No response received.')

# Truncate response based on the TTS provider's limit
@@ -529,40 +534,40 @@ def execute_once(question_prompt):

print(text_response)

generate_speech(text_response, temp_audio_path)
await generate_speech(text_response, temp_audio_path)

if TTS_PROVIDER == 'elevenlabs':
# Convert MP3 to WAV if ElevenLabs is used
temp_wav_path = os.path.join(output_dir, 'temp_output.wav')
audio = AudioSegment.from_mp3(temp_audio_path)
audio.export(temp_wav_path, format="wav")
play_audio(temp_wav_path)
await play_audio(temp_wav_path)
else:
play_audio(temp_audio_path)
await play_audio(temp_audio_path)

os.remove(image_path)


def execute_screenshot_and_analyze():
async def execute_screenshot_and_analyze():
question_prompt = "What do you see in this image? Keep it short but detailed and answer any follow up questions about it"
print("Taking screenshot and analyzing...")
execute_once(question_prompt)
await execute_once(question_prompt)
print("\nReady for the next question....")

def take_screenshot(temp_image_path):
time.sleep(5)
async def take_screenshot(temp_image_path):
await asyncio.sleep(5)
screenshot = ImageGrab.grab()
screenshot = screenshot.resize((1024, 1024))
screenshot.save(temp_image_path, 'JPEG')
return temp_image_path

def encode_image(image_path):
# Encode Image
async def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')

def analyze_image(image_path, question_prompt):
encoded_image = encode_image(image_path)

# Analyze Image
async def analyze_image(image_path, question_prompt):
encoded_image = await encode_image(image_path)
if MODEL_PROVIDER == 'ollama':
headers = {'Content-Type': 'application/json'}
payload = {
@@ -572,15 +577,17 @@ def analyze_image(image_path, question_prompt):
"stream": False
}
try:
response = requests.post(f'{OLLAMA_BASE_URL}/api/generate', headers=headers, json=payload, timeout=30)
print(f"Response status code: {response.status_code}")
if response.status_code == 200:
return {"choices": [{"message": {"content": response.json().get('response', 'No response received.')}}]}
elif response.status_code == 404:
return {"choices": [{"message": {"content": "The llava model is not available on this server."}}]}
else:
response.raise_for_status()
except requests.exceptions.RequestException as e:
async with aiohttp.ClientSession() as session:
async with session.post(f'{OLLAMA_BASE_URL}/api/generate', headers=headers, json=payload, timeout=30) as response:
print(f"Response status code: {response.status}")
if response.status == 200:
response_json = await response.json()
return {"choices": [{"message": {"content": response_json.get('response', 'No response received.')}}]}
elif response.status == 404:
return {"choices": [{"message": {"content": "The llava model is not available on this server."}}]}
else:
response.raise_for_status()
except aiohttp.ClientError as e:
print(f"Request failed: {e}")
return {"choices": [{"message": {"content": "Failed to process the image with the llava model."}}]}
else:
@@ -594,29 +601,35 @@ def analyze_image(image_path, question_prompt):
}
payload = {"model": OPENAI_MODEL, "temperature": 0.5, "messages": [message], "max_tokens": 1000}
try:
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=30)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
async with aiohttp.ClientSession() as session:
async with session.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=30) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientError as e:
print(f"Request failed: {e}")
return {"choices": [{"message": {"content": "Failed to process the image with the OpenAI model."}}]}

def generate_speech(text, temp_audio_path):

async def generate_speech(text, temp_audio_path):
if TTS_PROVIDER == 'openai':
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {OPENAI_API_KEY}"}
payload = {"model": "tts-1", "voice": OPENAI_TTS_VOICE, "input": text, "response_format": "wav"}
response = requests.post(OPENAI_TTS_URL, headers=headers, json=payload, timeout=30)
if response.status_code == 200:
with open(temp_audio_path, "wb") as audio_file:
audio_file.write(response.content)
else:
print(f"Failed to generate speech: {response.status_code} - {response.text}")
async with aiohttp.ClientSession() as session:
async with session.post(OPENAI_TTS_URL, headers=headers, json=payload, timeout=30) as response:
if response.status == 200:
with open(temp_audio_path, "wb") as audio_file:
audio_file.write(await response.read())
else:
print(f"Failed to generate speech: {response.status} - {await response.text()}")

elif TTS_PROVIDER == 'elevenlabs':
elevenlabs_text_to_speech(text, temp_audio_path)
else:
await elevenlabs_text_to_speech(text, temp_audio_path)

else: # XTTS
tts_model = xtts_model
try:
outputs = tts_model.synthesize(
outputs = await asyncio.to_thread(
tts_model.synthesize,
text,
xtts_config,
speaker_wav=character_audio_file,
@@ -632,7 +645,7 @@ def generate_speech(text, temp_audio_path):
except Exception as e:
print(f"Error during XTTS audio generation: {e}")

def user_chatbot_conversation():
async def user_chatbot_conversation():
conversation_history = []
base_system_message = open_file(character_prompt_file)
quit_phrases = ["quit", "Quit", "Quit.", "Exit.", "exit", "Exit", "leave", "Leave."]
@@ -659,7 +672,7 @@ def user_chatbot_conversation():
conversation_history.append({"role": "user", "content": user_input})

if any(phrase in user_input.lower() for phrase in screenshot_phrases):
execute_screenshot_and_analyze()
await execute_screenshot_and_analyze() # Note the 'await' here
continue

mood = analyze_mood(user_input)
@@ -672,11 +685,11 @@ def user_chatbot_conversation():
if len(sanitized_response) > 400:
sanitized_response = sanitized_response[:400] + "..."
prompt2 = sanitized_response
process_and_play(prompt2, character_audio_file)
await process_and_play(prompt2, character_audio_file) # Note the 'await' here
if len(conversation_history) > 20:
conversation_history = conversation_history[-20:]
except KeyboardInterrupt:
print("Quitting the conversation...")

if __name__ == "__main__":
user_chatbot_conversation()
asyncio.run(user_chatbot_conversation())
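
Because user_chatbot_conversation is now a coroutine, every async helper it reaches must be awaited and the top level needs asyncio.run. A stripped-down sketch of that shape, using placeholder names rather than the project's functions:

import asyncio

async def speak(text):
    # Stands in for process_and_play / play_audio.
    await asyncio.sleep(0)
    print(text)

async def conversation():
    # Stands in for user_chatbot_conversation.
    for turn in ("hello", "goodbye"):
        await speak(turn)

if __name__ == "__main__":
    asyncio.run(conversation())  # replaces a former plain conversation() call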
2 changes: 1 addition & 1 deletion app/app_logic.py
@@ -93,7 +93,7 @@ async def conversation_loop():
break

if any(phrase in user_input.lower() for phrase in screenshot_phrases):
execute_screenshot_and_analyze()
await execute_screenshot_and_analyze()
continue

try:
7 changes: 6 additions & 1 deletion app/main.py
@@ -10,6 +10,7 @@
from fastapi.middleware.cors import CORSMiddleware
from .shared import clients, get_current_character, set_current_character
from .app_logic import start_conversation, stop_conversation, set_env_variable
# from .app import user_chatbot_conversation

app = FastAPI()

@@ -25,6 +26,10 @@
allow_headers=["*"],
)

# @app.on_event("startup")
# async def startup_event():
# asyncio.create_task(user_chatbot_conversation())

@app.get("/")
async def get(request: Request):
model_provider = os.getenv("MODEL_PROVIDER")
@@ -85,7 +90,7 @@ async def websocket_endpoint(websocket: WebSocket):
await start_conversation()
elif message["action"] == "set_character":
set_current_character(message["character"])
await websocket.send_json({"message": f"Character set to {message['character']}"})
await websocket.send_json({"message": f"Character: {message['character']}"})
elif message["action"] == "set_provider":
set_env_variable("MODEL_PROVIDER", message["provider"])
elif message["action"] == "set_tts":
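
For reference, a self-contained sketch of the startup background-task pattern that the commented-out lines in main.py gesture at; the task body here is a placeholder, not the project's conversation loop, and on_event("startup") is FastAPI's older-style hook (newer releases prefer lifespan handlers).

import asyncio
from fastapi import FastAPI

app = FastAPI()

async def background_loop():
    # Placeholder for a long-running coroutine such as user_chatbot_conversation().
    while True:
        await asyncio.sleep(60)

@app.on_event("startup")
async def startup_event():
    # Fire-and-forget: the task runs alongside normal request handling.
    asyncio.create_task(background_loop())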