Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

Commit

Permalink
Merge pull request #24 from emrgnt-cmplxty/feature/add-hugging-face-r…
Browse files Browse the repository at this point in the history
…ebased2

Feature/add hugging face rebased2
  • Loading branch information
emrgnt-cmplxty committed Aug 25, 2023
2 parents 200bfd8 + 659b980 commit 3861cf1
Show file tree
Hide file tree
Showing 9 changed files with 704 additions and 15 deletions.
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Overview

The Zero-Shot Replication Framework is a minimal environment designed to replicate zero-shot results from past academic papers. It currently supports OpenAI models to generate completions for various datasets and provides tools for handling, evaluating, and storing these completions.
The Zero-Shot Replication Framework is a minimal environment designed to replicate zero-shot results from past academic papers. It currently supports OpenAI, Anthropic, and HuggingFace models to generate completions for various datasets and provides tools for handling, evaluating, and storing these completions.

## Features

Expand Down Expand Up @@ -71,15 +71,15 @@ To see explicit commands ran to generate the reported results, check out the [co

## Results (all models accessed on 08/24)

| Category | gpt-3.5-turbo-0301 | gpt-3.5-turbo-0613 | Claude 2 | GPT-4-0314 | GPT-4-0613 | GPT-4 Baseline | Sources |
|------------------|--------------------|--------------------|----------|------------|------------|----------------|----------|
| HumanEval | 81.7 | XX | 65.2 | 87.2 | 84.1 | 67 | [1] |
| EvalPlus | 71.3 | XX | 54.9 | 79.2 | 74.4 | N/A | |
| Leetcode Easy | XX | XX | XX | 91.0 | 88.0 | 72.2-75.6 | [1,2] |
| Leetcode Medium | XX | XX | XX | 26.0 | 17.0 | 26.2-38.7 | [1,2] |
| Leetcode Hard | XX | XX | XX | 6.0 | 4.0 | 6.7-7 | [1,2] |
| GSM8K | XX | XX | XX | X | X | 87.1 | |
| MATH | XX | XX | XX | 49.0 | 46.4 | 42.2 | [3] |
| Category | gpt-3.5-turbo-0301 | gpt-3.5-turbo-0613 | claude-2 | gpt-4-0314 | gpt-4-0613 | gpt-4 Baseline | Sources |
|----------------------|--------------------|--------------------|----------|------------|------------|----------------|----------|
| HumanEval | 81.7 | 61.5 | 65.2 | 87.2 | 84.1 | 67 | [1] |
| EvalPlus | 71.3 | 54.2 | 54.9 | 79.2 | 74.4 | N/A | |
| LeetCode_100 Easy | XX | XX | 73.0 | 91.0 | 88.0 | 72.2-75.6 | [1,2] |
| LeetCode_100 Medium | XX | XX | 16.0 | 26.0 | 17.0 | 26.2-38.7 | [1,2] |
| LeetCode_100 Hard | XX | XX | 2.0 | 6.0 | 4.0 | 6.7-7 | [1,2] |
| GSM8K | XX | XX | XX | X | X | 87.1 | |
| MATH | XX | XX | XX | 49.0 | 46.4 | 42.2 | [3] |

## License

Expand Down
512 changes: 511 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ python-leetcode = "1.2.1"
astunparse = "1.6.3"
anthropic = "^0.3.10"
numpy = "^1.25.2"
transformers = "^4.32.0"
torch = "^2.0.1"
accelerate = "^0.22.0"

[tool.poetry.group.dev.dependencies]
sourcery = "^1.6.0"
Expand Down

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion zero_shot_replication/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def parse_arguments() -> argparse.Namespace:
def prep_for_file_path(in_path: str) -> str:
    """Prepare a string to be used in a file path.

    Characters that are awkward inside a file name are substituted:

    * ``-`` becomes ``_``
    * ``.`` becomes ``p`` (e.g. ``3.5`` -> ``3p5``)
    * ``/`` becomes ``_`` so that namespaced model ids such as
      ``meta-llama/Llama-2-7b`` do not introduce a directory separator.

    Args:
        in_path: The raw string (typically a model or dataset name).

    Returns:
        The sanitized string, safe to embed in a single path component.
    """
    # A single return: the earlier, shorter replacement chain made the
    # "/" handling unreachable.
    return in_path.replace("-", "_").replace(".", "p").replace("/", "_")


def extract_code(raw_response: str) -> str:
Expand Down
8 changes: 8 additions & 0 deletions zero_shot_replication/llm_providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@


class LLMProvider(ABC):
    """Common abstract interface that every LLM provider must implement."""

    @abstractmethod
    def __init__(self, model: str, temperature: float) -> None:
        """Configure the provider for ``model`` at the given ``temperature``."""

    @abstractmethod
    def get_completion(self, prompt: str) -> str:
        """Return the completion text produced for ``prompt``."""


@dataclass
class ProviderConfig:
    """A dataclass to hold the configuration for a provider.

    Attributes:
        name: Provider identifier that a requested provider name is
            compared against when looking up a provider.
        models: Model names this provider accepts; a requested model must
            appear in this list.
        llm_class: The ``LLMProvider`` subclass that is instantiated for a
            matching provider/model pair.
    """

    name: str
    models: List[str]
    llm_class: Type[LLMProvider]
48 changes: 48 additions & 0 deletions zero_shot_replication/llm_providers/huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
import transformers
from transformers import AutoTokenizer

from zero_shot_replication.llm_providers.base import LLMProvider


class HuggingFaceZeroShotProvider(LLMProvider):
    """A class to provide zero-shot completions from a HuggingFace model.

    The model is loaded once, at construction time, into a ``transformers``
    text-generation pipeline that is reused by every ``get_completion`` call.
    """

    MAX_TOKENS_TO_SAMPLE = (
        4_096  # This is a large value, we should check if it makes sense
    )

    def __init__(
        self,
        model: str = "facebook/opt-125m",
        temperature: float = 0.7,
        stream: bool = False,
    ) -> None:
        """Load the tokenizer and text-generation pipeline for ``model``.

        Args:
            model: HuggingFace model id to load.
            temperature: Sampling temperature used by ``get_completion``.
            stream: Stored on the instance; not used by this class yet.
        """
        self.model = model
        self.temperature = temperature
        self.stream = stream
        self.tokenizer = AutoTokenizer.from_pretrained(self.model)

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    def get_completion(self, prompt: str) -> str:
        """Get a completion from the HuggingFace model for the given prompt.

        Args:
            prompt: The text the model should continue.

        Returns:
            The generated continuation, without the prompt itself.
        """
        if self.temperature > 0:
            sampling_kwargs = {
                "do_sample": True,
                "temperature": self.temperature,
                "top_k": 10,
            }
        else:
            # Sampling needs a strictly positive temperature, so a
            # temperature of zero (or below) falls back to greedy decoding.
            sampling_kwargs = {"do_sample": False}

        sequences = self.pipeline(
            prompt,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            max_length=self.MAX_TOKENS_TO_SAMPLE,
            # Return only the newly generated text, not prompt + completion.
            return_full_text=False,
            **sampling_kwargs,
        )
        # num_return_sequences=1, so there is exactly one candidate.
        return sequences[0]["generated_text"]
21 changes: 19 additions & 2 deletions zero_shot_replication/llm_providers/provider_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
LLMProvider,
ProviderConfig,
)
from zero_shot_replication.llm_providers.huggingface import (
HuggingFaceZeroShotProvider,
)
from zero_shot_replication.llm_providers.openai import OpenAIZeroShotProvider


Expand All @@ -25,14 +28,28 @@ class ProviderManager:
["claude-2", "claude-instant-1"],
AnthropicZeroShotProvider,
),
ProviderConfig(
"huggingface",
[
"facebook/opt-125m", # for testing
"meta-llama/Llama-2-7b",
"meta-llama/Llama-2-13b",
"meta-llama/Llama-2-70b",
],
HuggingFaceZeroShotProvider,
),
]

@staticmethod
def get_provider(provider_name: str, model_name: str) -> LLMProvider:
def get_provider(
provider_name: str, model_name: str, temperature: float
) -> LLMProvider:
for provider in ProviderManager.PROVIDERS:
if provider.name == provider_name:
if model_name in provider.models:
return provider.llm_class()
return provider.llm_class(
model=model_name, temperature=temperature
)
raise ValueError(
f"Model '{model_name}' not supported by provider '{provider_name}'"
)
Expand Down
4 changes: 3 additions & 1 deletion zero_shot_replication/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def get_output_path(args: argparse.Namespace) -> str:
out_path = get_output_path(args)

# Build an LLM provider instance
llm_provider = ProviderManager.get_provider(args.provider, args.model)
llm_provider = ProviderManager.get_provider(
args.provider, args.model, args.temperature
)

if not llm_provider:
raise NotImplementedError(f"Provider '{args.provider}' not supported.")
Expand Down

0 comments on commit 3861cf1

Please sign in to comment.