Merge pull request #564 from kevinmgamboa/feature/save-qa-logging
Added feature to save the user questions and model answers to CSV bas…
PromtEngineer committed Oct 30, 2023
2 parents 1aa6d14 + 7b31977 commit 8cb7168
Showing 4 changed files with 43 additions and 6 deletions.
1 change: 1 addition & 0 deletions 3.20.2
@@ -0,0 +1 @@
Requirement already satisfied: protobuf in c:\users\kevin\anaconda3\lib\site-packages (4.24.4)
10 changes: 5 additions & 5 deletions load_models.py
@@ -141,11 +141,11 @@ def load_full_model(model_id, model_basename, device_type, logging):
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
         cache_dir=MODELS_PATH,
-        # trust_remote_code=True, # set these if you are using NVIDIA GPU
-        # load_in_4bit=True,
-        # bnb_4bit_quant_type="nf4",
-        # bnb_4bit_compute_dtype=torch.float16,
-        # max_memory={0: "15GB"} # Uncomment this line with you encounter CUDA out of memory errors
+        trust_remote_code=True,  # set these if you are using an NVIDIA GPU
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.float16,
+        max_memory={0: "15GB"}  # set this if you encounter CUDA out of memory errors
     )
     model.tie_weights()
     return model, tokenizer
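
The kwargs activated above enable 4-bit NF4 quantization through bitsandbytes. For orientation only (not part of this commit), recent transformers releases bundle the same settings into a BitsAndBytesConfig object; a minimal sketch, assuming a transformers version with bitsandbytes support and a placeholder model id:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same 4-bit options as above, expressed as a config object.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4 bits at load time
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization scheme
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id, not from this commit
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)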
13 changes: 12 additions & 1 deletion run_localGPT.py
@@ -2,6 +2,7 @@
 import logging
 import click
 import torch
+import utils
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms import HuggingFacePipeline
@@ -207,7 +208,13 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     ),
     help="model type, llama, mistral or non_llama",
 )
-def main(device_type, show_sources, use_history, model_type):
+@click.option(
+    "--save_qa",
+    is_flag=True,
+    help="whether to save Q&A pairs to a CSV file (default: False)",
+)
+
+def main(device_type, show_sources, use_history, model_type, save_qa):
     """
     Implements the main information retrieval task for a localGPT.
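
With the flag merged, CSV logging becomes opt-in from the command line; omitting --save_qa leaves the previous behavior unchanged. A usage sketch, assuming the existing --device_type option from run_localGPT.py:

python run_localGPT.py --device_type cuda --save_qa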
@@ -259,6 +266,10 @@ def main(device_type, show_sources, use_history, model_type):
print("\n> " + document.metadata["source"] + ":")
print(document.page_content)
print("----------------------------------SOURCE DOCUMENTS---------------------------")

# Log the Q&A to CSV only if save_qa is True
if save_qa:
utils.log_to_csv(query, answer)


if __name__ == "__main__":
25 changes: 25 additions & 0 deletions utils.py
@@ -0,0 +1,25 @@
import os
import csv
from datetime import datetime

def log_to_csv(question, answer):
    """Append a timestamped question/answer pair to the CSV log."""
    log_dir, log_file = "local_chat_history", "qa_log.csv"
    # Ensure log directory exists, create if not
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Construct the full file path
    log_path = os.path.join(log_dir, log_file)

    # Check if file exists, if not create and write headers
    if not os.path.isfile(log_path):
        with open(log_path, mode='w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(["timestamp", "question", "answer"])

    # Append the log entry
    with open(log_path, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        writer.writerow([timestamp, question, answer])
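
For illustration, the helper can be exercised directly; a minimal sketch with made-up question and answer strings:

from utils import log_to_csv

# First call creates local_chat_history/qa_log.csv and writes the header row,
# then appends one timestamped record; subsequent calls only append.
log_to_csv("What does this repo do?", "It answers questions over local documents.")

Because csv.writer applies minimal quoting by default, commas and newlines inside an answer stay within a single CSV record.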
