Skip to content

Commit

Permalink
Merge branch 'master' into oraclevs_integration
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed May 4, 2024
2 parents 53cd5ac + c9e9470 commit ff92d2e
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 237 deletions.
63 changes: 34 additions & 29 deletions libs/community/langchain_community/tools/gmail/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Gmail tool utils."""

from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, List, Optional, Tuple

from langchain_core.utils import guard_import

if TYPE_CHECKING:
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
Expand All @@ -21,16 +24,15 @@ def import_google() -> Tuple[Request, Credentials]:
Returns:
Tuple[Request, Credentials]: Request and Credentials classes.
"""
# google-auth-httplib2
try:
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
except ImportError:
raise ImportError(
"You need to install google-auth-httplib2 to use this toolkit. "
"Try running pip install --upgrade google-auth-httplib2"
)
return Request, Credentials
return (
guard_import(
module_name="google.auth.transport.requests",
pip_name="google-auth-httplib2",
).Request,
guard_import(
module_name="google.oauth2.credentials", pip_name="google-auth-httplib2"
).Credentials,
)


def import_installed_app_flow() -> InstalledAppFlow:
Expand All @@ -39,14 +41,9 @@ def import_installed_app_flow() -> InstalledAppFlow:
Returns:
InstalledAppFlow: InstalledAppFlow class.
"""
try:
from google_auth_oauthlib.flow import InstalledAppFlow
except ImportError:
raise ImportError(
"You need to install google-auth-oauthlib to use this toolkit. "
"Try running pip install --upgrade google-auth-oauthlib"
)
return InstalledAppFlow
return guard_import(
module_name="google_auth_oauthlib.flow", pip_name="google-auth-oauthlib"
).InstalledAppFlow


def import_googleapiclient_resource_builder() -> build_resource:
Expand All @@ -55,14 +52,9 @@ def import_googleapiclient_resource_builder() -> build_resource:
Returns:
build_resource: googleapiclient.discovery.build function.
"""
try:
from googleapiclient.discovery import build
except ImportError:
raise ImportError(
"You need to install googleapiclient to use this toolkit. "
"Try running pip install --upgrade google-api-python-client"
)
return build
return guard_import(
module_name="googleapiclient.discovery", pip_name="google-api-python-client"
).build


DEFAULT_SCOPES = ["https://mail.google.com/"]
Expand All @@ -77,8 +69,19 @@ def get_gmail_credentials(
) -> Credentials:
"""Get credentials."""
# From https://developers.google.com/gmail/api/quickstart/python
Request, Credentials = import_google()
InstalledAppFlow = import_installed_app_flow()
Request, Credentials = (
guard_import(
module_name="google.auth.transport.requests",
pip_name="google-auth-httplib2",
).Request,
guard_import(
module_name="google.oauth2.credentials", pip_name="google-auth-httplib2"
).Credentials,
)

InstalledAppFlow = guard_import(
module_name="google_auth_oauthlib.flow", pip_name="google-auth-oauthlib"
).InstalledAppFlow
creds = None
scopes = scopes or DEFAULT_SCOPES
token_file = token_file or DEFAULT_CREDS_TOKEN_FILE
Expand Down Expand Up @@ -111,7 +114,9 @@ def build_resource_service(
) -> Resource:
"""Build a Gmail service."""
credentials = credentials or get_gmail_credentials()
builder = import_googleapiclient_resource_builder()
builder = guard_import(
module_name="googleapiclient.discovery", pip_name="google-api-python-client"
).build
return builder(service_name, service_version, credentials=credentials)


Expand Down
20 changes: 9 additions & 11 deletions libs/community/langchain_community/tools/playwright/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils import guard_import

if TYPE_CHECKING:
from playwright.async_api import Browser as AsyncBrowser
Expand All @@ -25,15 +26,10 @@ def lazy_import_playwright_browsers() -> Tuple[Type[AsyncBrowser], Type[SyncBrow
Tuple[Type[AsyncBrowser], Type[SyncBrowser]]:
AsyncBrowser and SyncBrowser classes.
"""
try:
from playwright.async_api import Browser as AsyncBrowser
from playwright.sync_api import Browser as SyncBrowser
except ImportError:
raise ImportError(
"The 'playwright' package is required to use the playwright tools."
" Please install it with 'pip install playwright'."
)
return AsyncBrowser, SyncBrowser
return (
guard_import(module_name="playwright.async_api").AsyncBrowser,
guard_import(module_name="playwright.sync_api").SyncBrowser,
)


class BaseBrowserTool(BaseTool):
Expand All @@ -45,7 +41,8 @@ class BaseBrowserTool(BaseTool):
@root_validator
def validate_browser_provided(cls, values: dict) -> dict:
"""Check that the arguments are valid."""
lazy_import_playwright_browsers()
guard_import(module_name="playwright.async_api").AsyncBrowser
guard_import(module_name="playwright.sync_api").SyncBrowser
if values.get("async_browser") is None and values.get("sync_browser") is None:
raise ValueError("Either async_browser or sync_browser must be specified.")
return values
Expand All @@ -57,5 +54,6 @@ def from_browser(
async_browser: Optional[AsyncBrowser] = None,
) -> BaseBrowserTool:
"""Instantiate the tool."""
lazy_import_playwright_browsers()
guard_import(module_name="playwright.async_api").AsyncBrowser
guard_import(module_name="playwright.sync_api").SyncBrowser
return cls(sync_browser=sync_browser, async_browser=async_browser)
4 changes: 2 additions & 2 deletions libs/core/langchain_core/runnables/graph_mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def draw_mermaid(
subgraph = ""
# Add edges to the graph
for edge in edges:
src_prefix = edge.source.split(":")[0]
tgt_prefix = edge.target.split(":")[0]
src_prefix = edge.source.split(":")[0] if ":" in edge.source else None
tgt_prefix = edge.target.split(":")[0] if ":" in edge.target else None
# exit subgraph if source or target is not in the same subgraph
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
mermaid_graph += "\tend\n"
Expand Down
6 changes: 4 additions & 2 deletions libs/core/langchain_core/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generic utility functions."""

import contextlib
import datetime
import functools
Expand Down Expand Up @@ -88,10 +89,11 @@ def guard_import(
installed."""
try:
module = importlib.import_module(module_name, package)
except ImportError:
except (ImportError, ModuleNotFoundError):
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
raise ImportError(
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name or module_name}`."
f"Please install it with `pip install {pip_name}`."
)
return module

Expand Down
62 changes: 60 additions & 2 deletions libs/core/tests/unit_tests/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import re
from contextlib import AbstractContextManager, nullcontext
from typing import Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Type, Union
from unittest.mock import patch

import pytest

from langchain_core.utils import check_package_version
from langchain_core import utils
from langchain_core.utils import check_package_version, guard_import
from langchain_core.utils._merge import merge_dicts


Expand Down Expand Up @@ -113,3 +114,60 @@ def test_merge_dicts(
with err:
actual = merge_dicts(left, right)
assert actual == expected


@pytest.mark.parametrize(
("module_name", "pip_name", "package", "expected"),
[
("langchain_core.utils", None, None, utils),
("langchain_core.utils", "langchain-core", None, utils),
("langchain_core.utils", None, "langchain-core", utils),
("langchain_core.utils", "langchain-core", "langchain-core", utils),
],
)
def test_guard_import(
module_name: str, pip_name: Optional[str], package: Optional[str], expected: Any
) -> None:
if package is None and pip_name is None:
ret = guard_import(module_name)
elif package is None and pip_name is not None:
ret = guard_import(module_name, pip_name=pip_name)
elif package is not None and pip_name is None:
ret = guard_import(module_name, package=package)
elif package is not None and pip_name is not None:
ret = guard_import(module_name, pip_name=pip_name, package=package)
else:
raise ValueError("Invalid test case")
assert ret == expected


@pytest.mark.parametrize(
("module_name", "pip_name", "package"),
[
("langchain_core.utilsW", None, None),
("langchain_core.utilsW", "langchain-core-2", None),
("langchain_core.utilsW", None, "langchain-coreWX"),
("langchain_core.utilsW", "langchain-core-2", "langchain-coreWX"),
("langchain_coreW", None, None), # ModuleNotFoundError
],
)
def test_guard_import_failure(
module_name: str, pip_name: Optional[str], package: Optional[str]
) -> None:
with pytest.raises(ImportError) as exc_info:
if package is None and pip_name is None:
guard_import(module_name)
elif package is None and pip_name is not None:
guard_import(module_name, pip_name=pip_name)
elif package is not None and pip_name is None:
guard_import(module_name, package=package)
elif package is not None and pip_name is not None:
guard_import(module_name, pip_name=pip_name, package=package)
else:
raise ValueError("Invalid test case")
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
err_msg = (
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name}`."
)
assert exc_info.value.msg == err_msg
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def _get_extraction_function(entity_schema: dict) -> dict:
"https://github.com/langchain-ai/langchain/discussions/18154"
),
removal="0.3.0",
pending=True,
alternative=(
"""
from langchain_core.pydantic_v1 import BaseModel, Field
Expand Down Expand Up @@ -130,7 +129,6 @@ def create_extraction_chain(
"https://github.com/langchain-ai/langchain/discussions/18154"
),
removal="0.3.0",
pending=True,
alternative=(
"""
from langchain_core.pydantic_v1 import BaseModel, Field
Expand Down
1 change: 0 additions & 1 deletion libs/langchain/langchain/chains/openai_tools/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"https://github.com/langchain-ai/langchain/discussions/18154"
),
removal="0.3.0",
pending=True,
alternative=(
"""
from langchain_core.pydantic_v1 import BaseModel, Field
Expand Down
2 changes: 0 additions & 2 deletions libs/langchain/langchain/chains/structured_output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
"https://github.com/langchain-ai/langchain/discussions/18154"
),
removal="0.3.0",
pending=True,
alternative=(
"""
from langchain_core.pydantic_v1 import BaseModel, Field
Expand Down Expand Up @@ -160,7 +159,6 @@ class RecordDog(BaseModel):
"https://github.com/langchain-ai/langchain/discussions/18154"
),
removal="0.3.0",
pending=True,
alternative=(
"""
from langchain_core.pydantic_v1 import BaseModel, Field
Expand Down
9 changes: 5 additions & 4 deletions libs/langchain/langchain/memory/summary_buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.pydantic_v1 import root_validator
Expand All @@ -15,8 +15,9 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
memory_key: str = "history"

@property
def buffer(self) -> List[BaseMessage]:
return self.chat_memory.messages
def buffer(self) -> Union[str, List[BaseMessage]]:
"""String buffer of memory."""
return self.load_memory_variables({})[self.memory_key]

@property
def memory_variables(self) -> List[str]:
Expand All @@ -28,7 +29,7 @@ def memory_variables(self) -> List[str]:

def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
buffer = self.buffer
buffer = self.chat_memory.messages
if self.moving_summary_buffer != "":
first_messages: List[BaseMessage] = [
self.summary_message_cls(content=self.moving_summary_buffer)
Expand Down

0 comments on commit ff92d2e

Please sign in to comment.