mirror of
https://github.com/aljazceru/notedeck.git
synced 2025-12-19 09:34:19 +01:00
dave: introduce model config
so you can switch between openai and ollama models Signed-off-by: William Casarin <jb55@jb55.com>
This commit is contained in:
@@ -304,6 +304,46 @@ pub struct Dave {
|
||||
tools: Arc<HashMap<String, Tool>>,
|
||||
client: async_openai::Client<OpenAIConfig>,
|
||||
incoming_tokens: Option<Receiver<DaveResponse>>,
|
||||
model_config: ModelConfig,
|
||||
}
|
||||
|
||||
pub struct ModelConfig {
|
||||
endpoint: Option<String>,
|
||||
model: String,
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for ModelConfig {
|
||||
fn default() -> Self {
|
||||
ModelConfig {
|
||||
endpoint: None,
|
||||
model: "gpt-4o".to_string(),
|
||||
api_key: std::env::var("OPENAI_API_KEY").ok(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelConfig {
|
||||
pub fn ollama() -> Self {
|
||||
ModelConfig {
|
||||
endpoint: std::env::var("OLLAMA_HOST").ok(),
|
||||
model: "hhao/qwen2.5-coder-tools:latest".to_string(),
|
||||
api_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_api(&self) -> OpenAIConfig {
|
||||
let mut cfg = OpenAIConfig::new();
|
||||
if let Some(endpoint) = &self.endpoint {
|
||||
cfg = cfg.with_api_base(endpoint.to_owned());
|
||||
}
|
||||
|
||||
if let Some(api_key) = &self.api_key {
|
||||
cfg = cfg.with_api_key(api_key.to_owned());
|
||||
}
|
||||
|
||||
cfg
|
||||
}
|
||||
}
|
||||
|
||||
impl Dave {
|
||||
@@ -312,12 +352,9 @@ impl Dave {
|
||||
}
|
||||
|
||||
pub fn new(render_state: Option<&RenderState>) -> Self {
|
||||
let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
|
||||
if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
|
||||
let client = Client::with_config(config);
|
||||
//let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
|
||||
let model_config = ModelConfig::default();
|
||||
let client = Client::with_config(model_config.to_api());
|
||||
|
||||
let input = "".to_string();
|
||||
let pubkey = "32e1827635450ebb3c5a7d12c1f8e7b2b514439ac10a67eef3d9fd9c5c68e245".to_string();
|
||||
@@ -336,6 +373,7 @@ impl Dave {
|
||||
incoming_tokens: None,
|
||||
tools: Arc::new(tools),
|
||||
input,
|
||||
model_config,
|
||||
chat: vec![system_prompt],
|
||||
}
|
||||
}
|
||||
@@ -513,6 +551,7 @@ impl Dave {
|
||||
let ctx = ctx.clone();
|
||||
let client = self.client.clone();
|
||||
let tools = self.tools.clone();
|
||||
let model_name = self.model_config.model.clone();
|
||||
|
||||
let (tx, rx) = mpsc::channel();
|
||||
self.incoming_tokens = Some(rx);
|
||||
@@ -521,8 +560,7 @@ impl Dave {
|
||||
let mut token_stream = match client
|
||||
.chat()
|
||||
.create_stream(CreateChatCompletionRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
//model: "gpt-4o".to_string(),
|
||||
model: model_name,
|
||||
stream: Some(true),
|
||||
messages,
|
||||
tools: Some(dave_tools().iter().map(|t| t.to_api()).collect()),
|
||||
|
||||
Reference in New Issue
Block a user