mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2026-01-31 20:04:28 +01:00)
Refactor Benchmarking Workflow and Introduce New Data Models (#5264)
* New benchmark data models
* Update _benchmarkBaseUrl
* Remove ReportRequestBody
* Update benchmark service methods for proxy approach
* Add eval id to SkillNodeData
* Refactor runBenchmark method for proxy approach

In short: the Flutter client no longer posts to the benchmark's /reports endpoints. It now drives the benchmark through its agent-protocol proxy (/ap/v1 agent/tasks, .../steps, .../evaluation), using new request-body models and an eval_id carried on each skill-tree node.
lib/models/benchmark_service/benchmark_step_request_body.dart (new file)

@@ -0,0 +1,12 @@
+class BenchmarkStepRequestBody {
+  final String? input;
+
+  BenchmarkStepRequestBody({required this.input});
+
+  Map<String, dynamic> toJson() {
+    if (input == null) {
+      return {};
+    }
+    return {'input': input};
+  }
+}
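Aside (not part of the diff): a minimal sketch of how this model serializes, assuming the class above is in scope. A null input yields an empty JSON object, which the benchmark proxy is presumed to treat as "continue with the next step"; a non-null input is passed through.

void main() {
  // Null input -> empty body (assumed to mean "continue").
  print(BenchmarkStepRequestBody(input: null).toJson()); // {}
  // Non-null input is forwarded as-is.
  print(BenchmarkStepRequestBody(input: 'hello').toJson()); // {input: hello}
}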
lib/models/benchmark_service/benchmark_task_request_body.dart (new file)

@@ -0,0 +1,13 @@
+class BenchmarkTaskRequestBody {
+  final String input;
+  final String evalId;
+
+  BenchmarkTaskRequestBody({required this.input, required this.evalId});
+
+  Map<String, dynamic> toJson() {
+    return {
+      'input': input,
+      'eval_id': evalId,
+    };
+  }
+}
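Aside (not part of the diff): a hypothetical payload, with made-up values, as it would be POSTed to the benchmark's agent/tasks endpoint. Note the camelCase Dart field maps to the snake_case eval_id key.

import 'dart:convert';

void main() {
  final body = BenchmarkTaskRequestBody(
      input: 'Write the word Washington to a .txt file', // example task text
      evalId: '021c695a-6cc4-46c2-b93a-f3a9b0f4d123'); // made-up eval id
  print(jsonEncode(body.toJson()));
  // {"input":"Write the word Washington to a .txt file",
  //  "eval_id":"021c695a-6cc4-46c2-b93a-f3a9b0f4d123"}
}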
lib/models/benchmark_service/report_request_body.dart (deleted)

@@ -1,16 +0,0 @@
-class ReportRequestBody {
-  final String test;
-  final String testRunId;
-  final bool mock;
-
-  ReportRequestBody(
-      {required this.test, required this.testRunId, required this.mock});
-
-  Map<String, dynamic> toJson() {
-    return {
-      'test': test,
-      'test_run_id': testRunId,
-      'mock': mock,
-    };
-  }
-}
skill_node_data.dart (SkillNodeData model)

@@ -9,6 +9,7 @@ class SkillNodeData {
   final int cutoff;
   final Ground ground;
   final Info info;
+  final String evalId;

   SkillNodeData({
     required this.name,
@@ -18,6 +19,7 @@ class SkillNodeData {
     required this.cutoff,
     required this.ground,
     required this.info,
+    required this.evalId,
   });

   factory SkillNodeData.fromJson(Map<String, dynamic> json) {
@@ -29,6 +31,7 @@ class SkillNodeData {
       cutoff: json['cutoff'] ?? 0,
       ground: Ground.fromJson(json['ground'] ?? {}),
       info: Info.fromJson(json['info'] ?? {}),
+      evalId: json['eval_id'] ?? "",
     );
   }
 }
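Aside (not part of the diff): how the new field behaves, sketched under the assumption that the surrounding SkillNodeData, Ground, and Info models from this file are in scope. A node JSON without an eval_id falls back to an empty string instead of throwing, matching the other defaults in the factory.

void main() {
  final node = SkillNodeData.fromJson({
    // Other fields (name, task, cutoff, ground, info) omitted here;
    // the factory above defaults each of them.
    'eval_id': '021c695a-6cc4-46c2-b93a-f3a9b0f4d123', // made-up id
  });
  print(node.evalId); // the id later fed into BenchmarkTaskRequestBody

  print(SkillNodeData.fromJson({}).evalId.isEmpty); // true
}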
lib/services/benchmark_service.dart

@@ -1,5 +1,6 @@
 import 'dart:async';
-import 'package:auto_gpt_flutter_client/models/benchmark_service/report_request_body.dart';
+import 'package:auto_gpt_flutter_client/models/benchmark_service/benchmark_step_request_body.dart';
+import 'package:auto_gpt_flutter_client/models/benchmark_service/benchmark_task_request_body.dart';
 import 'package:auto_gpt_flutter_client/utils/rest_api_utility.dart';
 import 'package:auto_gpt_flutter_client/models/benchmark_service/api_type.dart';

@@ -8,30 +9,43 @@ class BenchmarkService {

   BenchmarkService(this.api);

-  /// Generates a single report using POST REST API at the /reports URL.
+  /// Creates a new benchmark task.
   ///
-  /// [reportRequestBody] is a Map representing the request body for generating a single report.
-  Future<Map<String, dynamic>> generateSingleReport(
-      ReportRequestBody reportRequestBody) async {
+  /// [benchmarkTaskRequestBody] is a Map representing the request body for creating a task.
+  Future<Map<String, dynamic>> createBenchmarkTask(
+      BenchmarkTaskRequestBody benchmarkTaskRequestBody) async {
     try {
-      return await api.post('reports', reportRequestBody.toJson(),
+      return await api.post('agent/tasks', benchmarkTaskRequestBody.toJson(),
           apiType: ApiType.benchmark);
     } catch (e) {
-      throw Exception('Failed to generate single report: $e');
+      throw Exception('Failed to create a new task: $e');
     }
   }

-  /// Generates a combined report using POST REST API at the /reports/query URL.
+  /// Executes a step in a specific benchmark task.
   ///
-  /// [testRunIds] is a list of strings representing the test run IDs to be combined into a single report.
-  Future<Map<String, dynamic>> generateCombinedReport(
-      List<String> testRunIds) async {
+  /// [taskId] is the ID of the task.
+  /// [benchmarkStepRequestBody] is a Map representing the request body for executing a step.
+  Future<Map<String, dynamic>> executeBenchmarkStep(
+      String taskId, BenchmarkStepRequestBody benchmarkStepRequestBody) async {
     try {
-      final Map<String, dynamic> requestBody = {'test_run_ids': testRunIds};
-      return await api.post('reports/query', requestBody,
+      return await api.post(
+          'agent/tasks/$taskId/steps', benchmarkStepRequestBody.toJson(),
           apiType: ApiType.benchmark);
     } catch (e) {
-      throw Exception('Failed to generate combined report: $e');
+      throw Exception('Failed to execute step: $e');
     }
   }
+
+  /// Triggers an evaluation for a specific benchmark task.
+  ///
+  /// [taskId] is the ID of the task.
+  Future<Map<String, dynamic>> triggerEvaluation(String taskId) async {
+    try {
+      return await api.post('agent/tasks/$taskId/evaluation', {},
+          apiType: ApiType.benchmark);
+    } catch (e) {
+      throw Exception('Failed to trigger evaluation: $e');
+    }
+  }
 }
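Aside (not part of the diff): taken together, the three methods walk the agent-protocol lifecycle of a benchmark run. A minimal sketch of the intended call order, assuming a configured RestApiUtility; the 'task_id' key comes from the runBenchmark code below, while 'is_last' is assumed to be the agent-protocol field that the client's Step model reads.

Future<void> runOneBenchmark(BenchmarkService service) async {
  // 1. POST agent/tasks: register the task with the benchmark proxy.
  final created = await service.createBenchmarkTask(BenchmarkTaskRequestBody(
      input: 'example task text', evalId: 'example-eval-id')); // made-up values
  final taskId = created['task_id'] as String;

  // 2. POST agent/tasks/{id}/steps with an empty body until the agent
  //    reports its last step.
  Map<String, dynamic> step;
  do {
    step = await service.executeBenchmarkStep(
        taskId, BenchmarkStepRequestBody(input: null));
  } while (step['is_last'] != true); // assumed agent-protocol key

  // 3. POST agent/tasks/{id}/evaluation: ask the benchmark to score the run.
  print(await service.triggerEvaluation(taskId));
}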
lib/utils/rest_api_utility.dart

@@ -5,7 +5,7 @@ import 'package:http/http.dart' as http;

 class RestApiUtility {
   String _agentBaseUrl;
-  final String _benchmarkBaseUrl = "http://127.0.0.1:8080";
+  final String _benchmarkBaseUrl = "http://127.0.0.1:8080/ap/v1";

   RestApiUtility(this._agentBaseUrl);

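Aside (not part of the diff): with the new /ap/v1 suffix, the relative paths used by BenchmarkService resolve to agent-protocol endpoints. A hypothetical illustration, assuming RestApiUtility joins the base URL and the relative path with a slash (the joining code is not shown in this diff).

void main() {
  const benchmarkBaseUrl = 'http://127.0.0.1:8080/ap/v1';
  const taskId = '123'; // placeholder
  for (final path in [
    'agent/tasks',
    'agent/tasks/$taskId/steps',
    'agent/tasks/$taskId/evaluation',
  ]) {
    print(Uri.parse('$benchmarkBaseUrl/$path'));
  }
  // http://127.0.0.1:8080/ap/v1/agent/tasks
  // http://127.0.0.1:8080/ap/v1/agent/tasks/123/steps
  // http://127.0.0.1:8080/ap/v1/agent/tasks/123/evaluation
}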
lib/viewmodels/skill_tree_viewmodel.dart

@@ -1,12 +1,17 @@
 import 'dart:convert';
-import 'package:auto_gpt_flutter_client/models/benchmark_service/report_request_body.dart';
+import 'package:auto_gpt_flutter_client/models/benchmark_service/benchmark_step_request_body.dart';
+import 'package:auto_gpt_flutter_client/models/benchmark_service/benchmark_task_request_body.dart';
 import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
 import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
+import 'package:auto_gpt_flutter_client/models/step.dart';
+import 'package:auto_gpt_flutter_client/models/task.dart';
 import 'package:auto_gpt_flutter_client/services/benchmark_service.dart';
+import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
 import 'package:collection/collection.dart';
 import 'package:flutter/foundation.dart';
 import 'package:flutter/services.dart';
 import 'package:graphview/GraphView.dart';
 import 'package:provider/provider.dart';
 import 'package:uuid/uuid.dart';

 class SkillTreeViewModel extends ChangeNotifier {
@@ -131,56 +136,64 @@ class SkillTreeViewModel extends ChangeNotifier {
     }
   }

-  // TODO: Update to actual implementation
-  Future<void> runBenchmark() async {
-    // Set the benchmark running flag to true
+  // TODO: Move to task queue view model
+  // TODO: We should check if the test passed. If not we short circuit.
+  // TODO: We should create a model to track our active tests
+  Future<void> runBenchmark(ChatViewModel chatViewModel) async {
+    // 1. Set the benchmark running flag to true
     isBenchmarkRunning = true;
+    // 2. Notify listeners
     notifyListeners();

-    // Initialize an empty list to collect unique UUIDs for test runs
-    List<String> testRunIds = [];
+    // 3. Grab the reversed node hierarchy
+    final reversedSelectedNodeHierarchy =
+        List.from(_selectedNodeHierarchy!.reversed);

     try {
-      // Reverse the selected node hierarchy
-      final reversedSelectedNodeHierarchy =
-          List.from(_selectedNodeHierarchy!.reversed);
-
-      // Loop through the reversed node hierarchy to generate reports for each node
+      // 4. Loop through the nodes in the hierarchy
       for (var node in reversedSelectedNodeHierarchy) {
-        // Generate a unique UUID for the test run
-        final uuid = const Uuid().v4();
+        // 5. Create a BenchmarkTaskRequestBody
+        final benchmarkTaskRequestBody = BenchmarkTaskRequestBody(
+            input: node.data.task, evalId: node.data.evalId);

-        // Create a ReportRequestBody object
-        final reportRequestBody = ReportRequestBody(
-            test: node.data.name, testRunId: uuid, mock: true);
+        // 6. Create a new benchmark task
+        final createdTask = await benchmarkService
+            .createBenchmarkTask(benchmarkTaskRequestBody);

-        // Call generateSingleReport with the created ReportRequestBody object
-        final singleReport =
-            await benchmarkService.generateSingleReport(reportRequestBody);
-        print("Single report generated: $singleReport");
+        // 7. Create a new Task object
+        final task =
+            Task(id: createdTask['task_id'], title: createdTask['input']);

-        // Add the unique UUID to the list
-        // TODO: We should check if the test passed. If not we short circuit.
-        // TODO: We should create a model to track our active tests
-        testRunIds.add(uuid);
+        // 8. Update the current task ID in ChatViewModel
+        chatViewModel.setCurrentTaskId(task.id);

-        // Notify the UI
-        notifyListeners();
+        // 9. Execute the first step and initialize the Step object
+        Map<String, dynamic> stepResponse =
+            await benchmarkService.executeBenchmarkStep(
+                task.id, BenchmarkStepRequestBody(input: null));
+        Step step = Step.fromMap(stepResponse);
+
+        // 11. Check if it's the last step
+        while (!step.isLast) {
+          // Fetch chats for the task
+          chatViewModel.fetchChatsForTask();
+
+          // Execute next step and update the Step object
+          stepResponse = await benchmarkService.executeBenchmarkStep(
+              task.id, BenchmarkStepRequestBody(input: null));
+          step = Step.fromMap(stepResponse);
+        }
+
+        // 12. Trigger the evaluation
+        final evaluationResponse =
+            await benchmarkService.triggerEvaluation(task.id);
+        print("Evaluation response: $evaluationResponse");
       }
-
-      // Generate a combined report using all the unique UUIDs
-      final combinedReport =
-          await benchmarkService.generateCombinedReport(testRunIds);
-
-      // Pretty-print the JSON result
-      String prettyResult =
-          JsonEncoder.withIndent(' ').convert(combinedReport);
-      print("Combined report generated: $prettyResult");
     } catch (e) {
-      print("Failed to generate reports: $e");
+      print("Error while running benchmark: $e");
     }

-    // Set the benchmark running flag to false
+    // Reset the benchmark running flag
     isBenchmarkRunning = false;
     notifyListeners();
   }
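Aside (not part of the diff): one caveat worth noting. The refactored method dereferences _selectedNodeHierarchy! before the try block, so calling runBenchmark with no node selected throws synchronously. A hedged sketch of a guard at the top of the method, using only members visible above (hypothetical, not in this commit):

    final hierarchy = _selectedNodeHierarchy;
    if (hierarchy == null || hierarchy.isEmpty) {
      // Nothing selected: reset the flag set in step 1 and bail out
      // before any benchmark-service call is made.
      isBenchmarkRunning = false;
      notifyListeners();
      return;
    }
    final reversedSelectedNodeHierarchy = List.from(hierarchy.reversed);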
task_queue_view.dart (TaskQueueView)

@@ -1,4 +1,4 @@
-import 'package:auto_gpt_flutter_client/models/benchmark_service/report_request_body.dart';
+import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
 import 'package:flutter/material.dart';
 import 'package:auto_gpt_flutter_client/viewmodels/skill_tree_viewmodel.dart';
 import 'package:provider/provider.dart';
@@ -56,8 +56,12 @@ class TaskQueueView extends StatelessWidget {
           onPressed: viewModel.isBenchmarkRunning
               ? null
               : () {
+                  // TODO: Handle this better
+                  final chatViewModel =
+                      Provider.of<ChatViewModel>(context, listen: false);
+                  chatViewModel.clearCurrentTaskAndChats();
                   // Call runBenchmark method from SkillTreeViewModel
-                  viewModel.runBenchmark();
+                  viewModel.runBenchmark(chatViewModel);
                 },
           child: Row(
             mainAxisAlignment: