refactor models component (#1535)

This commit is contained in:
Lily Delalande
2025-03-07 11:56:03 -08:00
committed by GitHub
parent 148737c1cf
commit 4f9c08a025
12 changed files with 113 additions and 90 deletions

View File

@@ -167,7 +167,9 @@ export default function BottomMenu({
className="flex items-center cursor-pointer"
onClick={() => setIsModelMenuOpen(!isModelMenuOpen)}
>
<span>{envModelProvider || currentModel?.name || 'Select Model'}</span>
<span>
{envModelProvider || (currentModel?.alias ?? currentModel?.name) || 'Select Model'}
</span>
{isModelMenuOpen ? (
<ChevronDown className="w-4 h-4 ml-1" />
) : (
@@ -182,14 +184,14 @@ export default function BottomMenu({
<ModelRadioList
className="divide-y divide-borderSubtle"
renderItem={({ model, isSelected, onSelect }) => (
<label key={model.name} className="block cursor-pointer">
<label key={model.alias ?? model.name} className="block cursor-pointer">
<div
className="flex items-center justify-between p-2 text-textStandard hover:bg-bgSubtle transition-colors"
onClick={onSelect}
>
<div>
<p className="text-sm ">{model.name}</p>
<p className="text-xs text-textSubtle">{model.provider}</p>
<p className="text-sm ">{model.alias ?? model.name}</p>
<p className="text-xs text-textSubtle">{model.subtext ?? model.provider}</p>
</div>
<div className="relative">
<input

View File

@@ -247,23 +247,10 @@ export default function SettingsView({
{/* Content Area */}
<div className="flex-1 py-8 pt-[20px]">
<div className="space-y-8">
{/*Models Section*/}
<section id="models">
<div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8">
<h2 className="text-xl font-medium text-textStandard">Models</h2>
<button
onClick={() => {
setView('moreModels');
}}
className="text-indigo-500 hover:text-indigo-600 text-sm"
>
Browse
</button>
</div>
<div className="px-8">
<RecentModelsRadio />
</div>
<RecentModelsRadio setView={setView} />
</section>
<section id="extensions">
<div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8">
<h2 className="text-xl font-semibold text-textStandard">Extensions</h2>

View File

@@ -5,7 +5,7 @@ import Select from 'react-select';
import { Plus } from 'lucide-react';
import { createSelectedModel, useHandleModelSelection } from './utils';
import { useActiveKeys } from '../api_keys/ActiveKeysContext';
import { goose_models } from './hardcoded_stuff';
import { gooseModels } from './GooseModels';
import { createDarkSelectStyles, darkSelectTheme } from '../../ui/select-styles';
export function AddModelInline() {
@@ -31,7 +31,7 @@ export function AddModelInline() {
return;
}
const filtered = goose_models
const filtered = gooseModels
.filter(
(model) =>
model.provider.toLowerCase() === selectedProvider &&

View File

@@ -0,0 +1,30 @@
import { Model } from './ModelContext';
// TODO: move into backends / fetch dynamically
// this is used by ModelContext
export const gooseModels: Model[] = [
{ id: 1, name: 'gpt-4o-mini', provider: 'OpenAI' },
{ id: 2, name: 'gpt-4o', provider: 'OpenAI' },
{ id: 3, name: 'gpt-4-turbo', provider: 'OpenAI' },
{ id: 5, name: 'o1', provider: 'OpenAI' },
{ id: 7, name: 'claude-3-5-sonnet-latest', provider: 'Anthropic' },
{ id: 8, name: 'claude-3-5-haiku-latest', provider: 'Anthropic' },
{ id: 9, name: 'claude-3-opus-latest', provider: 'Anthropic' },
{ id: 10, name: 'gemini-1.5-pro', provider: 'Google' },
{ id: 11, name: 'gemini-1.5-flash', provider: 'Google' },
{ id: 12, name: 'gemini-2.0-flash', provider: 'Google' },
{ id: 13, name: 'gemini-2.0-flash-lite-preview-02-05', provider: 'Google' },
{ id: 14, name: 'gemini-2.0-flash-thinking-exp-01-21', provider: 'Google' },
{ id: 15, name: 'gemini-2.0-pro-exp-02-05', provider: 'Google' },
{ id: 16, name: 'llama-3.3-70b-versatile', provider: 'Groq' },
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
{ id: 20, name: 'claude-3-7-sonnet@20250219', provider: 'GCP Vertex AI' },
{ id: 21, name: 'claude-3-5-sonnet-v2@20241022', provider: 'GCP Vertex AI' },
{ id: 22, name: 'claude-3-5-sonnet@20240620', provider: 'GCP Vertex AI' },
{ id: 23, name: 'claude-3-5-haiku@20241022', provider: 'GCP Vertex AI' },
{ id: 24, name: 'gemini-2.0-pro-exp-02-05', provider: 'GCP Vertex AI' },
{ id: 25, name: 'gemini-2.0-flash-001', provider: 'GCP Vertex AI' },
{ id: 26, name: 'gemini-1.5-pro-002', provider: 'GCP Vertex AI' },
];

View File

@@ -1,6 +1,6 @@
import React, { createContext, useContext, useState, ReactNode } from 'react';
import { GOOSE_MODEL, GOOSE_PROVIDER } from '../../../env_vars';
import { goose_models } from './hardcoded_stuff'; // Assuming hardcoded models are here
import { gooseModels } from './GooseModels'; // Assuming hardcoded models are here
// TODO: API keys
export interface Model {
@@ -8,6 +8,8 @@ export interface Model {
name: string;
provider: string;
lastUsed?: string;
alias?: string; // optional model display name
subtext?: string; // goes below model name if not the provider
}
interface ModelContextValue {
@@ -31,7 +33,7 @@ export const ModelProvider = ({ children }: { children: ReactNode }) => {
const switchModel = (model: Model) => {
const newModel = model.id
? goose_models.find((m) => m.id === model.id) || model
? gooseModels.find((m) => m.id === model.id) || model
: { id: Date.now(), ...model }; // Assign unique ID for user-defined models
updateModel(newModel);
};

View File

@@ -1,8 +1,16 @@
import React, { useState, useEffect } from 'react';
import { Model } from './ModelContext';
import { useRecentModels } from './RecentModels';
import { useModel } from './ModelContext';
import { useHandleModelSelection } from './utils';
import { useRecentModels } from './RecentModels';
import type { View } from '@/src/App';
import { SettingsViewOptions } from '@/src/components/settings/SettingsView';
export interface Model {
id?: number; // Make `id` optional to allow user-defined models
name: string;
provider: string;
lastUsed?: string;
}
interface ModelRadioListProps {
renderItem: (props: {
@@ -13,6 +21,22 @@ interface ModelRadioListProps {
className?: string;
}
export function SeeMoreModelsButtons({ setView }: { setView: (view: View) => void }) {
return (
<div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8">
<h2 className="text-xl font-medium text-textStandard">Models</h2>
<button
onClick={() => {
setView('moreModels');
}}
className="text-indigo-500 hover:text-indigo-600 text-sm"
>
Browse
</button>
</div>
);
}
export function ModelRadioList({ renderItem, className = '' }: ModelRadioListProps) {
const { recentModels } = useRecentModels();
const { currentModel } = useModel();

View File

@@ -2,7 +2,8 @@ import React, { useState, useEffect } from 'react';
import { Button } from '../../ui/button';
import { Switch } from '../../ui/switch';
import { useActiveKeys } from '../api_keys/ActiveKeysContext';
import { model_docs_link, goose_models } from './hardcoded_stuff';
import { model_docs_link } from './hardcoded_stuff';
import { gooseModels } from './GooseModels';
import { useModel } from './ModelContext';
import { useHandleModelSelection } from './utils';
@@ -31,7 +32,7 @@ export function ProviderButtons() {
// Filter models by provider
const providerModels = selectedProvider
? goose_models.filter((model) => model.provider === selectedProvider)
? gooseModels.filter((model) => model.provider === selectedProvider)
: [];
return (

View File

@@ -1,9 +1,10 @@
import React, { useState, useEffect } from 'react';
import { Clock } from 'lucide-react';
import { Model } from './ModelContext';
import { useHandleModelSelection } from './utils';
import { ModelRadioList, SeeMoreModelsButtons } from './ModelRadioList';
import { useModel } from './ModelContext';
import { ModelRadioList } from './ModelRadioList';
import { useHandleModelSelection } from './utils';
import type { View } from '../../../App';
const MAX_RECENT_MODELS = 3;
@@ -129,10 +130,12 @@ export function RecentModels() {
);
}
export function RecentModelsRadio() {
export function RecentModelsRadio({ setView }: { setView: (view: View) => void }) {
return (
<div>
<SeeMoreModelsButtons setView={setView} />
<div className="px-8">
<div className="space-y-2">
<h2 className="text-md font-medium text-textStandard">Recently used</h2>
<ModelRadioList
renderItem={({ model, isSelected, onSelect }) => (
<label key={model.name} className="flex items-center py-2 cursor-pointer">
@@ -154,12 +157,14 @@ export function RecentModelsRadio() {
</div>
<div className="">
<p className="text-sm text-textStandard">{model.name}</p>
<p className="text-xs text-textSubtle">{model.provider}</p>
<p className="text-sm text-textStandard">{model.alias ?? model.name}</p>
<p className="text-xs text-textSubtle">{model.subtext ?? model.provider}</p>
</div>
</label>
)}
/>
</div>
</div>
</div>
);
}

View File

@@ -1,7 +1,7 @@
import React, { useState, useEffect, useRef } from 'react';
import { Search } from 'lucide-react';
import { Switch } from '../../ui/switch';
import { goose_models } from './hardcoded_stuff';
import { gooseModels } from './GooseModels';
import { useModel } from './ModelContext';
import { useHandleModelSelection } from './utils';
import { useActiveKeys } from '../api_keys/ActiveKeysContext';
@@ -22,7 +22,7 @@ export function SearchBar() {
// results set will only include models that have a configured provider
const { activeKeys } = useActiveKeys(); // Access active keys from context
const model_options = goose_models.filter((model) => activeKeys.includes(model.provider));
const model_options = gooseModels.filter((model) => activeKeys.includes(model.provider));
const filteredModels = model_options
.filter((model) => model.name.toLowerCase().includes(search.toLowerCase()))

View File

@@ -1,33 +1,5 @@
import { Model } from './ModelContext';
// TODO: move into backends / fetch dynamically
export const goose_models: Model[] = [
{ id: 1, name: 'gpt-4o-mini', provider: 'OpenAI' },
{ id: 2, name: 'gpt-4o', provider: 'OpenAI' },
{ id: 3, name: 'gpt-4-turbo', provider: 'OpenAI' },
{ id: 5, name: 'o1', provider: 'OpenAI' },
{ id: 7, name: 'claude-3-5-sonnet-latest', provider: 'Anthropic' },
{ id: 8, name: 'claude-3-5-haiku-latest', provider: 'Anthropic' },
{ id: 9, name: 'claude-3-opus-latest', provider: 'Anthropic' },
{ id: 10, name: 'gemini-1.5-pro', provider: 'Google' },
{ id: 11, name: 'gemini-1.5-flash', provider: 'Google' },
{ id: 12, name: 'gemini-2.0-flash', provider: 'Google' },
{ id: 13, name: 'gemini-2.0-flash-lite-preview-02-05', provider: 'Google' },
{ id: 14, name: 'gemini-2.0-flash-thinking-exp-01-21', provider: 'Google' },
{ id: 15, name: 'gemini-2.0-pro-exp-02-05', provider: 'Google' },
{ id: 16, name: 'llama-3.3-70b-versatile', provider: 'Groq' },
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
{ id: 20, name: 'claude-3-7-sonnet@20250219', provider: 'GCP Vertex AI' },
{ id: 21, name: 'claude-3-5-sonnet-v2@20241022', provider: 'GCP Vertex AI' },
{ id: 22, name: 'claude-3-5-sonnet@20240620', provider: 'GCP Vertex AI' },
{ id: 23, name: 'claude-3-5-haiku@20241022', provider: 'GCP Vertex AI' },
{ id: 24, name: 'gemini-2.0-pro-exp-02-05', provider: 'GCP Vertex AI' },
{ id: 25, name: 'gemini-2.0-flash-001', provider: 'GCP Vertex AI' },
{ id: 26, name: 'gemini-1.5-pro-002', provider: 'GCP Vertex AI' },
];
export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
export const anthropic_models = [

View File

@@ -6,7 +6,7 @@ export function ToastSuccessModelSwitch(model: Model) {
return toast.success(
<div>
<strong>Model Changed</strong>
<div>Switched to {model.name}</div>
<div>Switched to {model.alias ?? model.name}</div>
</div>,
{
position: 'top-right',

View File

@@ -1,7 +1,7 @@
import { useModel } from './ModelContext'; // Import the useModel hook
import { Model } from './ModelContext';
import { useMemo } from 'react';
import { goose_models } from './hardcoded_stuff';
import { gooseModels } from './GooseModels';
import { ToastFailureGeneral, ToastSuccessModelSwitch } from './toasts';
import { initializeSystem } from '../../../utils/providerUtils';
import { useRecentModels } from './RecentModels';
@@ -43,7 +43,7 @@ export function useHandleModelSelection() {
}
export function createSelectedModel(selectedProvider, modelName) {
let selectedModel = goose_models.find(
let selectedModel = gooseModels.find(
(model) =>
model.provider.toLowerCase() === selectedProvider &&
model.name.toLowerCase() === modelName.toLowerCase()
@@ -52,7 +52,7 @@ export function createSelectedModel(selectedProvider, modelName) {
if (!selectedModel) {
// Normalize the casing for the provider using the first matching model
const normalizedProvider =
goose_models.find((model) => model.provider.toLowerCase() === selectedProvider)?.provider ||
gooseModels.find((model) => model.provider.toLowerCase() === selectedProvider)?.provider ||
selectedProvider;
// Construct a model object
@@ -67,7 +67,7 @@ export function createSelectedModel(selectedProvider, modelName) {
export function useFilteredModels(search: string, activeKeys: string[]) {
const filteredModels = useMemo(() => {
const modelOptions = goose_models.filter((model) => activeKeys.includes(model.provider));
const modelOptions = gooseModels.filter((model) => activeKeys.includes(model.provider));
if (!search) {
return modelOptions; // Return all models if no search term