Support local Ollama server, explicitly name local model servers
ad-si committed Apr 14, 2024
1 parent 1d1483e commit 92cffcf
Showing 2 changed files with 61 additions and 24 deletions.
25 changes: 18 additions & 7 deletions src/lib.rs
@@ -24,7 +24,8 @@ pub enum Provider {
Anthropic,
Groq,
OpenAI,
Local,
Llamafile,
Ollama,
}

impl std::fmt::Display for Provider {
@@ -33,7 +34,8 @@ impl std::fmt::Display for Provider {
Provider::Anthropic => write!(f, "Anthropic"),
Provider::Groq => write!(f, "Groq"),
Provider::OpenAI => write!(f, "OpenAI"),
Provider::Local => write!(f, "Local"),
Provider::Llamafile => write!(f, "Llamafile"),
Provider::Ollama => write!(f, "Ollama"),
}
}
}
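As an aside (not part of the commit), a minimal sketch of how the two new variants render, assuming the Display impl above:

// Sketch: the variant names double as their display labels.
assert_eq!(Provider::Llamafile.to_string(), "Llamafile");
assert_eq!(Provider::Ollama.to_string(), "Ollama");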
@@ -103,7 +105,9 @@ struct AnthropicAiResponse {
content: Vec<AnthropicAiContent>,
}

fn default_req_for_provider(provider: &Provider) -> AiRequest {
fn default_req_for_model(model: &Model) -> AiRequest {
let Model::Model(provider, model_id) = model;

match provider {
Provider::Groq => AiRequest {
provider: Provider::Groq,
@@ -124,11 +128,17 @@ fn default_req_for_provider(provider: &Provider) -> AiRequest {
max_tokens: 4096,
..Default::default()
},
Provider::Local => AiRequest {
provider: Provider::Local,
Provider::Llamafile => AiRequest {
provider: Provider::Llamafile,
url: "http://localhost:8080/v1/chat/completions".to_string(),
..Default::default()
},
Provider::Ollama => AiRequest {
provider: Provider::Ollama,
url: "http://localhost:11434/v1/chat/completions".to_string(),
model: model_id.to_string(),
..Default::default()
},
}
}
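Switching from default_req_for_provider to default_req_for_model lets the Ollama arm carry the chosen model id into the request, while the other providers keep their fixed defaults. A rough usage sketch (not part of the diff), assuming the AiRequest fields url and model are accessible as shown above:

// Sketch: default request for a locally served Ollama model.
let model = Model::Model(Provider::Ollama, "mistral".to_string());
let req = default_req_for_model(&model);
assert_eq!(req.url, "http://localhost:11434/v1/chat/completions");
assert_eq!(req.model, "mistral");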

@@ -159,7 +169,8 @@ fn get_api_request(
Provider::Groq => full_config.get("groq_api_key"),
Provider::OpenAI => full_config.get("openai_api_key"),
Provider::Anthropic => full_config.get("anthropic_api_key"),
Provider::Local => Some(&dummy_key),
Provider::Llamafile => Some(&dummy_key),
Provider::Ollama => Some(&dummy_key),
}
}
.and_then(|api_key| {
@@ -173,7 +184,7 @@ fn get_api_request(
.ok_or(get_key_setup_msg(secrets_path_str))
.map(|api_key| AiRequest {
api_key: api_key.clone(),
..(default_req_for_provider(provider)).clone()
..(default_req_for_model(model)).clone()
})
}

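Both local providers bypass the secrets lookup with a placeholder key, so a request to Llamafile or Ollama never fails for lack of an API key. A hedged sketch of the request that falls out of this branch, with the placeholder value assumed rather than taken from the diff:

// Sketch: local servers need no real credentials.
let req = AiRequest {
    api_key: "dummy".to_string(), // assumed placeholder value
    ..default_req_for_model(&Model::Model(Provider::Llamafile, "".to_string()))
};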
60 changes: 43 additions & 17 deletions src/main.rs
@@ -50,14 +50,22 @@ enum Commands {
/// The prompt to send to the AI model
prompt: Vec<String>,
},
/// Local model hosted at http://localhost:8080 (e.g. Llamafile)
#[clap(visible_alias = "lo")]
Local {
/// Llamafile server hosted at http://localhost:8080
#[clap(visible_alias = "lf")]
Llamafile {
/// The prompt to send to the AI model
prompt: Vec<String>,
},
/// Send the prompt to every provider's default model simultaneously
/// (Claude Haiku, Groq Mixtral, GPT 4 Turbo, Local)
/// Ollama server hosted at http://localhost:11434
#[clap(visible_alias = "ol")]
Ollama {
/// The model to use (e.g. llama2, mistral, …)
model: String,
/// The prompt to send to the AI model
prompt: Vec<String>,
},
/// Send prompt to each provider's default model simultaneously
/// (Claude Haiku, Groq Mixtral, GPT 4 Turbo, Llamafile, Ollama Llama2)
All {
/// The prompt to send to the AI models simultaneously
prompt: Vec<String>,
@@ -80,13 +88,17 @@ enum Commands {
<dim># Send a prompt to the default model</dim>
<b>cai</b> How high is the Eiffel Tower in meters
<dim># Send a prompt to the default model of each provider</dim>
<dim># Send a prompt to each provider's default model</dim>
<b>cai all</b> How high is the Eiffel Tower in meters
<dim># Send a prompt to Anthropic's Claude Opus (+ alias)</dim>
<b>cai claude-opus</b> How high is the Eiffel Tower in meters
<b>cai op</b> How high is the Eiffel Tower in meters
<dim># Send a prompt to a locally running Ollama server</dim>
<b>cai ollama mistral</b> How high is the Eiffel Tower in meters
<b>cai ol mistral</b> How high is the Eiffel Tower in meters
<dim># Add data via stdin</dim>
cat main.rs | <b>cai</b> Explain this code
"
@@ -113,62 +125,75 @@ fn capitalize_str(str: &str) -> String {
}

async fn exec_with_args(args: Args, stdin: &str) {
let stdin = if stdin.is_empty() {
"".into()
} else {
format!("{}\n", stdin)
};

match args.command {
None => {
// No subcommand provided -> Use input as prompt for the default model
submit_prompt(
&None,
&format!("{stdin}\n{}", &args.prompt.join(" ")), //
&format!("{stdin}{}", &args.prompt.join(" ")), //
)
.await
}
Some(cmd) => match cmd {
Commands::Mixtral { prompt } => {
submit_prompt(
&Some(Model::Model(Provider::Groq, GROQ_MIXTRAL.to_string())),
&format!("{stdin}\n{}", prompt.join(" ")),
&format!("{stdin}{}", prompt.join(" ")),
)
.await
}
Commands::GptTurbo { prompt } => {
submit_prompt(
&Some(Model::Model(Provider::OpenAI, OPENAI_GPT_TURBO.to_string())),
&format!("{stdin}\n{}", prompt.join(" ")),
&format!("{stdin}{}", prompt.join(" ")),
)
.await
}
Commands::Gpt { prompt } => {
submit_prompt(
&Some(Model::Model(Provider::OpenAI, OPENAI_GPT.to_string())),
&format!("{stdin}\n{}", prompt.join(" ")),
&format!("{stdin}{}", prompt.join(" ")),
)
.await
}
Commands::ClaudeOpus { prompt } => {
submit_prompt(
&Some(Model::Model(Provider::Anthropic, CLAUDE_OPUS.to_string())),
&format!("{stdin}\n{}", prompt.join(" ")),
&format!("{stdin}{}", prompt.join(" ")),
)
.await
}
Commands::ClaudeSonnet { prompt } => {
submit_prompt(
&Some(Model::Model(Provider::Anthropic, CLAUDE_SONNET.to_string())),
&format!("{stdin}\n{}", prompt.join(" ")),
&format!("{stdin}{}", prompt.join(" ")),
)
.await
}
Commands::ClaudeHaiku { prompt } => {
submit_prompt(
&Some(Model::Model(Provider::Anthropic, CLAUDE_HAIKU.to_string())),
&format!("{stdin}\n{}", prompt.join(" ")),
&format!("{stdin}{}", prompt.join(" ")),
)
.await
}
Commands::Local { prompt } => {
Commands::Llamafile { prompt } => {
submit_prompt(
&Some(Model::Model(Provider::Llamafile, "".to_string())),
&format!("{stdin}{}", prompt.join(" ")),
)
.await //
}
Commands::Ollama { model, prompt } => {
submit_prompt(
&Some(Model::Model(Provider::Local, "".to_string())),
&format!("{stdin}\n{}", prompt.join(" ")),
&Some(Model::Model(Provider::Ollama, model)),
&format!("{stdin}{}", prompt.join(" ")),
)
.await //
}
@@ -177,7 +202,8 @@ async fn exec_with_args(args: Args, stdin: &str) {
Model::Model(Provider::Anthropic, CLAUDE_HAIKU.to_string()),
Model::Model(Provider::Groq, GROQ_MIXTRAL.to_string()),
Model::Model(Provider::OpenAI, OPENAI_GPT_TURBO.to_string()),
Model::Model(Provider::Local, "".to_string()),
Model::Model(Provider::Llamafile, "".to_string()),
Model::Model(Provider::Ollama, "llama2".to_string()),
];

let mut handles = vec![];
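Across exec_with_args the prompt is now assembled with format!("{stdin}{}", ...) instead of format!("{stdin}\n{}", ...), because the newline is appended once up front when stdin is non-empty. A minimal sketch of the resulting behaviour (not part of the diff):

// Sketch: empty stdin adds nothing; non-empty stdin ends in exactly one newline.
let join = |stdin: &str, prompt: &str| {
    let stdin = if stdin.is_empty() {
        "".to_string()
    } else {
        format!("{}\n", stdin)
    };
    format!("{stdin}{prompt}")
};
assert_eq!(join("", "Explain this code"), "Explain this code");
assert_eq!(join("fn main() {}", "Explain this code"), "fn main() {}\nExplain this code");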
