Refactor Benchmarking Workflow and Introduce New Data Models (#5264)

* New benchmark data models

* Update _benchmarkBaseUrl

* Remove ReportRequestBody

* Update benchmark service methods for proxy approach

* Add eval id to SkillNodeData

* Refactor runBenchmark Method for proxy approach
This commit is contained in:
hunteraraujo
2023-09-19 17:01:15 -07:00
committed by GitHub
parent 2098e192da
commit 5afab461ee
8 changed files with 113 additions and 70 deletions

View File

@@ -0,0 +1,12 @@
/// Request body for executing a single step of a benchmark task.
///
/// Serializes to `{'input': ...}` when [input] is non-null, and to an
/// empty map otherwise (the benchmark API treats an empty body as
/// "continue with no new input").
class BenchmarkStepRequestBody {
  /// Optional input text forwarded to the agent for this step.
  final String? input;

  BenchmarkStepRequestBody({required this.input});

  /// Converts this request body to a JSON-compatible map.
  Map<String, dynamic> toJson() =>
      input == null ? <String, dynamic>{} : <String, dynamic>{'input': input};
}

View File

@@ -0,0 +1,13 @@
/// Request body for creating a new benchmark task.
class BenchmarkTaskRequestBody {
  /// Task input prompt sent to the agent.
  final String input;

  /// Identifier of the evaluation associated with this task.
  final String evalId;

  BenchmarkTaskRequestBody({required this.input, required this.evalId});

  /// Converts this request body to a JSON-compatible map using the
  /// snake_case keys expected by the benchmark API.
  Map<String, dynamic> toJson() => <String, dynamic>{
        'input': input,
        'eval_id': evalId,
      };
}

View File

@@ -1,16 +0,0 @@
/// Request body for generating a single benchmark report.
class ReportRequestBody {
  /// Name of the test to generate a report for.
  final String test;

  /// Unique identifier of the test run.
  final String testRunId;

  /// Whether the run is mocked — NOTE(review): semantics inferred from
  /// the field name; confirm against the benchmark API.
  final bool mock;

  ReportRequestBody(
      {required this.test, required this.testRunId, required this.mock});

  /// Converts this request body to a JSON-compatible map using the
  /// snake_case keys expected by the benchmark API.
  Map<String, dynamic> toJson() => <String, dynamic>{
        'test': test,
        'test_run_id': testRunId,
        'mock': mock,
      };
}

View File

@@ -9,6 +9,7 @@ class SkillNodeData {
final int cutoff;
final Ground ground;
final Info info;
final String evalId;
SkillNodeData({
required this.name,
@@ -18,6 +19,7 @@ class SkillNodeData {
required this.cutoff,
required this.ground,
required this.info,
required this.evalId,
});
factory SkillNodeData.fromJson(Map<String, dynamic> json) {
@@ -29,6 +31,7 @@ class SkillNodeData {
cutoff: json['cutoff'] ?? 0,
ground: Ground.fromJson(json['ground'] ?? {}),
info: Info.fromJson(json['info'] ?? {}),
evalId: json['eval_id'] ?? "",
);
}
}

View File

@@ -1,5 +1,6 @@
import 'dart:async';
import 'package:auto_gpt_flutter_client/models/benchmark_service/report_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark_service/benchmark_step_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark_service/benchmark_task_request_body.dart';
import 'package:auto_gpt_flutter_client/utils/rest_api_utility.dart';
import 'package:auto_gpt_flutter_client/models/benchmark_service/api_type.dart';
@@ -8,30 +9,43 @@ class BenchmarkService {
BenchmarkService(this.api);
/// Generates a single report using POST REST API at the /reports URL.
/// Creates a new benchmark task.
///
/// [reportRequestBody] is a Map representing the request body for generating a single report.
Future<Map<String, dynamic>> generateSingleReport(
ReportRequestBody reportRequestBody) async {
/// [benchmarkTaskRequestBody] is a Map representing the request body for creating a task.
Future<Map<String, dynamic>> createBenchmarkTask(
BenchmarkTaskRequestBody benchmarkTaskRequestBody) async {
try {
return await api.post('reports', reportRequestBody.toJson(),
return await api.post('agent/tasks', benchmarkTaskRequestBody.toJson(),
apiType: ApiType.benchmark);
} catch (e) {
throw Exception('Failed to generate single report: $e');
throw Exception('Failed to create a new task: $e');
}
}
/// Generates a combined report using POST REST API at the /reports/query URL.
/// Executes a step in a specific benchmark task.
///
/// [testRunIds] is a list of strings representing the test run IDs to be combined into a single report.
Future<Map<String, dynamic>> generateCombinedReport(
List<String> testRunIds) async {
/// [taskId] is the ID of the task.
/// [benchmarkStepRequestBody] is a Map representing the request body for executing a step.
Future<Map<String, dynamic>> executeBenchmarkStep(
String taskId, BenchmarkStepRequestBody benchmarkStepRequestBody) async {
try {
final Map<String, dynamic> requestBody = {'test_run_ids': testRunIds};
return await api.post('reports/query', requestBody,
return await api.post(
'agent/tasks/$taskId/steps', benchmarkStepRequestBody.toJson(),
apiType: ApiType.benchmark);
} catch (e) {
throw Exception('Failed to generate combined report: $e');
throw Exception('Failed to execute step: $e');
}
}
/// Triggers an evaluation for a specific benchmark task.
///
/// [taskId] is the ID of the task.
Future<Map<String, dynamic>> triggerEvaluation(String taskId) async {
try {
return await api.post('agent/tasks/$taskId/evaluation', {},
apiType: ApiType.benchmark);
} catch (e) {
throw Exception('Failed to trigger evaluation: $e');
}
}
}

View File

@@ -5,7 +5,7 @@ import 'package:http/http.dart' as http;
class RestApiUtility {
String _agentBaseUrl;
final String _benchmarkBaseUrl = "http://127.0.0.1:8080";
final String _benchmarkBaseUrl = "http://127.0.0.1:8080/ap/v1";
RestApiUtility(this._agentBaseUrl);

View File

@@ -1,12 +1,17 @@
import 'dart:convert';
import 'package:auto_gpt_flutter_client/models/benchmark_service/report_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark_service/benchmark_step_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark_service/benchmark_task_request_body.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
import 'package:auto_gpt_flutter_client/models/step.dart';
import 'package:auto_gpt_flutter_client/models/task.dart';
import 'package:auto_gpt_flutter_client/services/benchmark_service.dart';
import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
import 'package:collection/collection.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:graphview/GraphView.dart';
import 'package:provider/provider.dart';
import 'package:uuid/uuid.dart';
class SkillTreeViewModel extends ChangeNotifier {
@@ -131,56 +136,64 @@ class SkillTreeViewModel extends ChangeNotifier {
}
}
// TODO: Update to actual implementation
Future<void> runBenchmark() async {
// Set the benchmark running flag to true
// TODO: Move to task queue view model
// TODO: We should check if the test passed. If not we short circuit.
// TODO: We should create a model to track our active tests
Future<void> runBenchmark(ChatViewModel chatViewModel) async {
// 1. Set the benchmark running flag to true
isBenchmarkRunning = true;
// 2. Notify listeners
notifyListeners();
// Initialize an empty list to collect unique UUIDs for test runs
List<String> testRunIds = [];
// 3. Grab the reversed node hierarchy
final reversedSelectedNodeHierarchy =
List.from(_selectedNodeHierarchy!.reversed);
try {
// Reverse the selected node hierarchy
final reversedSelectedNodeHierarchy =
List.from(_selectedNodeHierarchy!.reversed);
// Loop through the reversed node hierarchy to generate reports for each node
// 4. Loop through the nodes in the hierarchy
for (var node in reversedSelectedNodeHierarchy) {
// Generate a unique UUID for the test run
final uuid = const Uuid().v4();
// 5. Create a BenchmarkTaskRequestBody
final benchmarkTaskRequestBody = BenchmarkTaskRequestBody(
input: node.data.task, evalId: node.data.evalId);
// Create a ReportRequestBody object
final reportRequestBody = ReportRequestBody(
test: node.data.name, testRunId: uuid, mock: true);
// 6. Create a new benchmark task
final createdTask = await benchmarkService
.createBenchmarkTask(benchmarkTaskRequestBody);
// Call generateSingleReport with the created ReportRequestBody object
final singleReport =
await benchmarkService.generateSingleReport(reportRequestBody);
print("Single report generated: $singleReport");
// 7. Create a new Task object
final task =
Task(id: createdTask['task_id'], title: createdTask['input']);
// Add the unique UUID to the list
// TODO: We should check if the test passed. If not we short circuit.
// TODO: We should create a model to track our active tests
testRunIds.add(uuid);
// 8. Update the current task ID in ChatViewModel
chatViewModel.setCurrentTaskId(task.id);
// Notify the UI
notifyListeners();
// 9. Execute the first step and initialize the Step object
Map<String, dynamic> stepResponse =
await benchmarkService.executeBenchmarkStep(
task.id, BenchmarkStepRequestBody(input: null));
Step step = Step.fromMap(stepResponse);
// 11. Check if it's the last step
while (!step.isLast) {
// Fetch chats for the task
chatViewModel.fetchChatsForTask();
// Execute next step and update the Step object
stepResponse = await benchmarkService.executeBenchmarkStep(
task.id, BenchmarkStepRequestBody(input: null));
step = Step.fromMap(stepResponse);
}
// 12. Trigger the evaluation
final evaluationResponse =
await benchmarkService.triggerEvaluation(task.id);
print("Evaluation response: $evaluationResponse");
}
// Generate a combined report using all the unique UUIDs
final combinedReport =
await benchmarkService.generateCombinedReport(testRunIds);
// Pretty-print the JSON result
String prettyResult =
JsonEncoder.withIndent(' ').convert(combinedReport);
print("Combined report generated: $prettyResult");
} catch (e) {
print("Failed to generate reports: $e");
print("Error while running benchmark: $e");
}
// Set the benchmark running flag to false
// Reset the benchmark running flag
isBenchmarkRunning = false;
notifyListeners();
}

View File

@@ -1,4 +1,4 @@
import 'package:auto_gpt_flutter_client/models/benchmark_service/report_request_body.dart';
import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
import 'package:flutter/material.dart';
import 'package:auto_gpt_flutter_client/viewmodels/skill_tree_viewmodel.dart';
import 'package:provider/provider.dart';
@@ -56,8 +56,12 @@ class TaskQueueView extends StatelessWidget {
onPressed: viewModel.isBenchmarkRunning
? null
: () {
// TODO: Handle this better
final chatViewModel =
Provider.of<ChatViewModel>(context, listen: false);
chatViewModel.clearCurrentTaskAndChats();
// Call runBenchmark method from SkillTreeViewModel
viewModel.runBenchmark();
viewModel.runBenchmark(chatViewModel);
},
child: Row(
mainAxisAlignment: