simple text data sources (#19)
Signed-off-by: Max Pumperla <[email protected]>
maxpumperla committed Aug 22, 2023
1 parent e802f1b commit e99f095
Showing 3 changed files with 26 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -107,3 +107,6 @@ airflow/airflow.db
 
 # scraped folders
 docs.ray.io/
+
+# book and other source folders
+data/
24 changes: 21 additions & 3 deletions app/index.py
@@ -87,7 +87,7 @@ def path_to_uri(path, scheme="https://", domain="docs.ray.io"):
     return scheme + domain + path.split(domain)[-1]
 
 
-def parse_file(record):
+def parse_html_file(record):
     html_content = load_html_file(record["path"])
     if not html_content:
         return []
@@ -100,6 +100,17 @@ def parse_file(record):
     ]
 
 
+def parse_text_file(record):
+    with open(record["path"]) as f:
+        text = f.read()
+    return [
+        {
+            "source": str(record["path"]),
+            "text": text,
+        }
+    ]
+
+
 class EmbedChunks:
     def __init__(self, model_name):
         self.embedding_model = HuggingFaceEmbeddings(
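A quick usage sketch of the new parser (the file path below is hypothetical): unlike parse_html_file, which may yield several sections per page, parse_text_file returns exactly one section spanning the whole file.

from pathlib import Path

from app.index import parse_text_file

# Hypothetical plain-text source, e.g. a book chapter under data/
sections = parse_text_file({"path": Path("data/chapter_01.txt")})
# -> [{"source": "data/chapter_01.txt", "text": "<entire file contents>"}]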
@@ -139,6 +150,7 @@ def __call__(self, batch):
 @app.command()
 def create_index(
     docs_path: Annotated[str, typer.Option(help="location of data")] = DOCS_PATH,
+    extension_type: Annotated[str, typer.Option(help="type of data")] = "html",
     embedding_model: Annotated[str, typer.Option(help="embedder")] = EMBEDDING_MODEL,
     chunk_size: Annotated[int, typer.Option(help="chunk size")] = CHUNK_SIZE,
     chunk_overlap: Annotated[int, typer.Option(help="chunk overlap")] = CHUNK_OVERLAP,
@@ -148,11 +160,17 @@ def create_index(
 
     # Dataset
     ds = ray.data.from_items(
-        [{"path": path} for path in Path(docs_path).rglob("*.html") if not path.is_dir()]
+        [
+            {"path": path}
+            for path in Path(docs_path).rglob(f"*.{extension_type}")
+            if not path.is_dir()
+        ]
     )
 
     # Sections
-    sections_ds = ds.flat_map(parse_file)
+    parser = parse_html_file if extension_type == "html" else parse_text_file
+    sections_ds = ds.flat_map(parser)
+    # TODO: do we really need to take_all()? Bring the splitter to the cluster
     sections = sections_ds.take_all()
 
     # Chunking
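And a minimal way to exercise the new option end to end (a sketch, assuming the data/ folder from the new .gitignore entry holds .txt files; create_index stays a plain function under the typer decorator, so it can also be called directly):

from app.index import create_index

# Index plain-text files instead of the default HTML docs;
# any extension other than "html" is routed to parse_text_file.
create_index(docs_path="data", extension_type="txt")

Note that the dispatch is binary: extension_type="md", say, would also be read as plain text by parse_text_file.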
4 changes: 2 additions & 2 deletions dashboard/pages/1_✨_Generation.py
@@ -9,7 +9,7 @@
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from pgvector.psycopg import register_vector
 
-from app.index import parse_file
+from app.index import parse_html_file
 from app.query import generate_response
 
 
@@ -38,7 +38,7 @@ def get_ds(docs_path):
 docs_page_url = st.text_input("Docs page URL", "https://docs.ray.io/en/master/train/faq.html")
 docs_page_path = docs_path_str + docs_page_url.split("docs.ray.io/en/master/")[-1]
 with st.expander("View sections"):
-    sections = parse_file({"path": docs_page_path})
+    sections = parse_html_file({"path": docs_page_path})
     st.write(sections)
 
 # Chunks
