Refactor SkillTreeViewModel and Update TaskQueueView UI for Task Status (#5269)

* Refactor SkillTreeViewModel and Update TaskQueueView UI for Task Status

* Notify UI when updating benchmark status
This commit is contained in:
hunteraraujo
2023-09-19 23:30:22 -07:00
committed by GitHub
parent 99035103e0
commit 377d0af228
3 changed files with 68 additions and 15 deletions

View File

@@ -0,0 +1,6 @@
enum BenchmarkTaskStatus {
notStarted,
inProgress,
success,
failure,
}

View File

@@ -1,6 +1,7 @@
import 'dart:convert';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_step_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_status.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
import 'package:auto_gpt_flutter_client/models/step.dart';
@@ -20,7 +21,7 @@ class SkillTreeViewModel extends ChangeNotifier {
// TODO: Potentially move to task queue view model when we create one
bool isBenchmarkRunning = false;
// TODO: Potentially move to task queue view model when we create one
List<Map<SkillTreeNode, bool>> benchmarkStatusList = [];
Map<SkillTreeNode, BenchmarkTaskStatus> benchmarkStatusMap = {};
List<SkillTreeNode> _skillTreeNodes = [];
List<SkillTreeEdge> _skillTreeEdges = [];
@@ -144,7 +145,7 @@ class SkillTreeViewModel extends ChangeNotifier {
Future<void> runBenchmark(
ChatViewModel chatViewModel, TaskViewModel taskViewModel) async {
// Clear the benchmarkStatusList
benchmarkStatusList.clear();
benchmarkStatusMap.clear();
// Create a new TestSuite object with the current timestamp
final testSuite =
@@ -159,12 +160,15 @@ class SkillTreeViewModel extends ChangeNotifier {
final reversedSelectedNodeHierarchy =
List.from(_selectedNodeHierarchy!.reversed);
for (var node in reversedSelectedNodeHierarchy) {
benchmarkStatusList.add({node: false});
benchmarkStatusMap[node] = BenchmarkTaskStatus.notStarted;
}
try {
// Loop through the nodes in the hierarchy
for (var node in reversedSelectedNodeHierarchy) {
benchmarkStatusMap[node] = BenchmarkTaskStatus.inProgress;
notifyListeners();
// Create a BenchmarkTaskRequestBody
final benchmarkTaskRequestBody = BenchmarkTaskRequestBody(
input: node.data.task, evalId: node.data.evalId);
@@ -204,10 +208,10 @@ class SkillTreeViewModel extends ChangeNotifier {
// Update the benchmarkStatusList based on the evaluation response
bool successStatus = evaluationResponse['metrics']['success'];
var nodeStatus = benchmarkStatusList.firstWhere(
(element) => element.keys.first.id == node.id,
);
nodeStatus[node] = successStatus;
benchmarkStatusMap[node] = successStatus
? BenchmarkTaskStatus.success
: BenchmarkTaskStatus.failure;
notifyListeners();
// If successStatus is false, break out of the loop
if (!successStatus) {

View File

@@ -1,3 +1,4 @@
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_status.dart';
import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
import 'package:auto_gpt_flutter_client/viewmodels/task_viewmodel.dart';
import 'package:flutter/material.dart';
@@ -14,10 +15,6 @@ class TaskQueueView extends StatelessWidget {
final reversedHierarchy =
viewModel.selectedNodeHierarchy?.reversed.toList() ?? [];
// Convert reversedHierarchy to a list of test names
final List<String> testNames =
reversedHierarchy.map((node) => node.data.name).toList();
return Material(
color: Colors.white,
child: Stack(
@@ -27,15 +24,61 @@ class TaskQueueView extends StatelessWidget {
itemCount: reversedHierarchy.length,
itemBuilder: (context, index) {
final node = reversedHierarchy[index];
// Choose the appropriate leading widget based on the task status
Widget leadingWidget;
switch (viewModel.benchmarkStatusMap[node]) {
case null:
case BenchmarkTaskStatus.notStarted:
leadingWidget = CircleAvatar(
radius: 12,
backgroundColor: Colors.grey,
child: CircleAvatar(
radius: 6,
backgroundColor: Colors.white,
),
);
break;
case BenchmarkTaskStatus.inProgress:
leadingWidget = SizedBox(
width: 24,
height: 24,
child: CircularProgressIndicator(
strokeWidth: 2,
),
);
break;
case BenchmarkTaskStatus.success:
leadingWidget = CircleAvatar(
radius: 12,
backgroundColor: Colors.green,
child: CircleAvatar(
radius: 6,
backgroundColor: Colors.white,
),
);
break;
case BenchmarkTaskStatus.failure:
leadingWidget = CircleAvatar(
radius: 12,
backgroundColor: Colors.red,
child: CircleAvatar(
radius: 6,
backgroundColor: Colors.white,
),
);
break;
}
return Container(
margin: EdgeInsets.fromLTRB(20, 5, 20, 5),
decoration: BoxDecoration(
color: Colors.white, // white background
border: Border.all(
color: Colors.black, width: 1), // thin black border
borderRadius: BorderRadius.circular(4), // small corner radius
color: Colors.white,
border: Border.all(color: Colors.black, width: 1),
borderRadius: BorderRadius.circular(4),
),
child: ListTile(
leading: leadingWidget,
title: Center(child: Text('${node.label}')),
subtitle:
Center(child: Text('${node.data.info.description}')),