Skip to content

Commit

Permalink
Merge pull request #1995 from hlohaus/gemini-
Browse files Browse the repository at this point in the history
Add streaming and conversation support to gemini
  • Loading branch information
hlohaus committed May 22, 2024
2 parents 7eb41cf + 62b2b27 commit 6830dfc
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 43 deletions.
12 changes: 8 additions & 4 deletions g4f/Provider/You.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..typing import AsyncResult, Messages, ImageType, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt
from ..image import ImageResponse, ImagePreview, to_bytes, is_accepted_format
from ..image import ImageResponse, ImagePreview, EXTENSIONS_MAP, to_bytes, is_accepted_format
from ..requests import StreamSession, FormData, raise_for_status
from .you.har_file import get_telemetry_ids
from .. import debug
Expand Down Expand Up @@ -94,6 +94,8 @@ async def create_async_generator(
"q": format_prompt(messages),
"domain": "youchat",
"selectedChatMode": chat_mode,
"conversationTurnId": str(uuid.uuid4()),
"chatId": str(uuid.uuid4()),
}
params = {
"userFiles": upload,
Expand All @@ -106,8 +108,8 @@ async def create_async_generator(

async with (session.post if chat_mode == "default" else session.get)(
f"{cls.url}/api/streamingSearch",
data=data,
params=params,
data=data if chat_mode == "default" else None,
params=params if chat_mode == "default" else data,
headers=headers,
cookies=cookies
) as response:
Expand Down Expand Up @@ -142,7 +144,9 @@ async def upload_file(cls, client: StreamSession, cookies: Cookies, file: bytes,
await raise_for_status(response)
upload_nonce = await response.text()
data = FormData()
data.add_field('file', file, content_type=is_accepted_format(file), filename=filename)
content_type = is_accepted_format(file)
filename = f"image.{EXTENSIONS_MAP[content_type]}" if filename is None else filename
data.add_field('file', file, content_type=content_type, filename=filename)
async with client.post(
f"{cls.url}/api/upload",
data=data,
Expand Down
1 change: 1 addition & 0 deletions g4f/Provider/base_provider.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..providers.base_provider import *
from ..providers.types import FinishReason, Streaming
from ..providers.conversation import BaseConversation
from .helper import get_cookies, format_prompt
115 changes: 76 additions & 39 deletions g4f/Provider/needs_auth/Gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@

from ... import debug
from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator
from ..base_provider import AsyncGeneratorProvider
from ..base_provider import AsyncGeneratorProvider, BaseConversation
from ..helper import format_prompt, get_cookies
from ...requests.raise_for_status import raise_for_status
from ...errors import MissingAuthError, MissingRequirementsError
from ...image import to_bytes, ImageResponse, ImageDataResponse
from ...image import ImageResponse, to_bytes
from ...webdriver import get_browser, get_driver_cookies

REQUEST_HEADERS = {
Expand All @@ -32,7 +32,7 @@
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
'x-same-domain': '1',
}
REQUEST_BL_PARAM = "boq_assistant-bard-web-server_20240421.18_p0"
REQUEST_BL_PARAM = "boq_assistant-bard-web-server_20240519.16_p0"
REQUEST_URL = "https://gemini.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate"
UPLOAD_IMAGE_URL = "https://content-push.googleapis.com/upload/"
UPLOAD_IMAGE_HEADERS = {
Expand All @@ -57,6 +57,8 @@ class Gemini(AsyncGeneratorProvider):
image_models = ["gemini"]
default_vision_model = "gemini"
_cookies: Cookies = None
_snlm0e: str = None
_sid: str = None

@classmethod
async def nodriver_login(cls, proxy: str = None) -> AsyncIterator[str]:
Expand Down Expand Up @@ -117,56 +119,58 @@ async def create_async_generator(
model: str,
messages: Messages,
proxy: str = None,
api_key: str = None,
cookies: Cookies = None,
connector: BaseConnector = None,
image: ImageType = None,
image_name: str = None,
response_format: str = None,
return_conversation: bool = False,
conversation: Conversation = None,
language: str = "en",
**kwargs
) -> AsyncResult:
prompt = format_prompt(messages)
if api_key is not None:
if cookies is None:
cookies = {}
cookies["__Secure-1PSID"] = api_key
prompt = format_prompt(messages) if conversation is None else messages[-1]["content"]
cls._cookies = cookies or cls._cookies or get_cookies(".google.com", False, True)
base_connector = get_connector(connector, proxy)
async with ClientSession(
headers=REQUEST_HEADERS,
connector=base_connector
) as session:
snlm0e = await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None
if not snlm0e:
if not cls._snlm0e:
await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None
if not cls._snlm0e:
async for chunk in cls.nodriver_login(proxy):
yield chunk
if cls._cookies is None:
async for chunk in cls.webdriver_login(proxy):
yield chunk

if not snlm0e:
if not cls._snlm0e:
if cls._cookies is None or "__Secure-1PSID" not in cls._cookies:
raise MissingAuthError('Missing "__Secure-1PSID" cookie')
snlm0e = await cls.fetch_snlm0e(session, cls._cookies)
if not snlm0e:
await cls.fetch_snlm0e(session, cls._cookies)
if not cls._snlm0e:
raise RuntimeError("Invalid cookies. SNlM0e not found")

image_url = await cls.upload_image(base_connector, to_bytes(image), image_name) if image else None

async with ClientSession(
cookies=cls._cookies,
headers=REQUEST_HEADERS,
connector=base_connector,
) as client:
params = {
'bl': REQUEST_BL_PARAM,
'hl': language,
'_reqid': random.randint(1111, 9999),
'rt': 'c'
'rt': 'c',
"f.sid": cls._sid,
}
data = {
'at': snlm0e,
'at': cls._snlm0e,
'f.req': json.dumps([None, json.dumps(cls.build_request(
prompt,
language=language,
conversation=conversation,
image_url=image_url,
image_name=image_name
))])
Expand All @@ -177,19 +181,33 @@ async def create_async_generator(
params=params,
) as response:
await raise_for_status(response)
response = await response.text()
response_part = json.loads(json.loads(response.splitlines()[-5])[0][2])
if response_part[4] is None:
response_part = json.loads(json.loads(response.splitlines()[-7])[0][2])

content = response_part[4][0][1][0]
image_prompt = None
match = re.search(r'\[Imagen of (.*?)\]', content)
if match:
image_prompt = match.group(1)
content = content.replace(match.group(0), '')

yield content
image_prompt = response_part = None
last_content_len = 0
async for line in response.content:
try:
try:
line = json.loads(line)
except ValueError:
continue
if not isinstance(line, list):
continue
if len(line[0]) < 3 or not line[0][2]:
continue
response_part = json.loads(line[0][2])
if not response_part[4]:
continue
if return_conversation:
yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0])
content = response_part[4][0][1][0]
except (ValueError, KeyError, TypeError, IndexError) as e:
print(f"{cls.__name__}:{e.__class__.__name__}:{e}")
continue
match = re.search(r'\[Imagen of (.*?)\]', content)
if match:
image_prompt = match.group(1)
content = content.replace(match.group(0), '')
yield content[last_content_len:]
last_content_len = len(content)
if image_prompt:
images = [image[0][3][3] for image in response_part[4][0][12][7][0]]
if response_format == "b64_json":
Expand All @@ -208,18 +226,24 @@ async def create_async_generator(

def build_request(
prompt: str,
conversation_id: str = "",
response_id: str = "",
choice_id: str = "",
language: str,
conversation: Conversation = None,
image_url: str = None,
image_name: str = None,
tools: list[list[str]] = []
) -> list:
image_list = [[[image_url, 1], image_name]] if image_url else []
return [
[prompt, 0, None, image_list, None, None, 0],
["en"],
[conversation_id, response_id, choice_id, None, None, []],
[language],
[
None if conversation is None else conversation.conversation_id,
None if conversation is None else conversation.response_id,
None if conversation is None else conversation.choice_id,
None,
None,
[]
],
None,
None,
None,
Expand Down Expand Up @@ -265,7 +289,20 @@ async def upload_image(connector: BaseConnector, image: bytes, image_name: str =
async def fetch_snlm0e(cls, session: ClientSession, cookies: Cookies):
async with session.get(cls.url, cookies=cookies) as response:
await raise_for_status(response)
text = await response.text()
match = re.search(r'SNlM0e\":\"(.*?)\"', text)
response_text = await response.text()
match = re.search(r'SNlM0e\":\"(.*?)\"', response_text)
if match:
return match.group(1)
cls._snlm0e = match.group(1)
sid_match = re.search(r'"FdrFJe":"([\d-]+)"', response_text)
if sid_match:
cls._sid = sid_match.group(1)

class Conversation(BaseConversation):
def __init__(self,
conversation_id: str = "",
response_id: str = "",
choice_id: str = ""
) -> None:
self.conversation_id = conversation_id
self.response_id = response_id
self.choice_id = choice_id
4 changes: 4 additions & 0 deletions g4f/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..typing import Union, Messages, AsyncIterator, ImageType
from ..errors import NoImageResponseError, ProviderNotFoundError
from ..requests.aiohttp import get_connector
from ..providers.conversation import BaseConversation
from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse

try:
Expand All @@ -42,6 +43,9 @@ async def iter_response(
if isinstance(chunk, FinishReason):
finish_reason = chunk.reason
break
elif isinstance(chunk, BaseConversation):
yield chunk
continue
content += str(chunk)
count += 1
if max_tokens is not None and count >= max_tokens:
Expand Down
4 changes: 4 additions & 0 deletions g4f/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ..typing import Union, Iterator, Messages, ImageType
from ..providers.types import BaseProvider, ProviderType, FinishReason
from ..providers.conversation import BaseConversation
from ..image import ImageResponse as ImageProviderResponse
from ..errors import NoImageResponseError
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
Expand All @@ -29,6 +30,9 @@ def iter_response(
if isinstance(chunk, FinishReason):
finish_reason = chunk.reason
break
elif isinstance(chunk, BaseConversation):
yield chunk
continue
content += str(chunk)
if max_tokens is not None and idx + 1 >= max_tokens:
finish_reason = "length"
Expand Down
7 changes: 7 additions & 0 deletions g4f/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}

EXTENSIONS_MAP: dict[str, str] = {
"image/png": "png",
"image/jpeg": "jpg",
"image/gif": "gif",
"image/webp": "webp",
}

def to_image(image: ImageType, is_svg: bool = False) -> Image:
"""
Converts the input image to a PIL Image object.
Expand Down

0 comments on commit 6830dfc

Please sign in to comment.