From a37b486227ddc309278739371e53c2d2964c4a5a Mon Sep 17 00:00:00 2001 From: hunteraraujo Date: Tue, 19 Sep 2023 20:20:31 -0700 Subject: [PATCH] Enhance SkillTreeViewModel to Manage Benchmark Status (#5266) Enhance SkillTreeViewModel to Manage Benchmark Execution and Status --- .../lib/viewmodels/skill_tree_viewmodel.dart | 50 +++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/frontend/lib/viewmodels/skill_tree_viewmodel.dart b/frontend/lib/viewmodels/skill_tree_viewmodel.dart index 23d712a8..d48fd2f9 100644 --- a/frontend/lib/viewmodels/skill_tree_viewmodel.dart +++ b/frontend/lib/viewmodels/skill_tree_viewmodel.dart @@ -11,14 +11,14 @@ import 'package:collection/collection.dart'; import 'package:flutter/foundation.dart'; import 'package:flutter/services.dart'; import 'package:graphview/GraphView.dart'; -import 'package:provider/provider.dart'; -import 'package:uuid/uuid.dart'; class SkillTreeViewModel extends ChangeNotifier { // TODO: Potentially move to task queue view model when we create one final BenchmarkService benchmarkService; // TODO: Potentially move to task queue view model when we create one bool isBenchmarkRunning = false; + // TODO: Potentially move to task queue view model when we create one + List<Map<SkillTreeNode, bool>> benchmarkStatusList = []; List _skillTreeNodes = []; List _skillTreeEdges = []; @@ -26,6 +26,8 @@ class SkillTreeViewModel extends ChangeNotifier { // TODO: Potentially move to task queue view model when we create one List? _selectedNodeHierarchy; + List get skillTreeNodes => _skillTreeNodes; + List get skillTreeEdges => _skillTreeEdges; SkillTreeNode? get selectedNode => _selectedNode; List? get selectedNodeHierarchy => _selectedNodeHierarchy; @@ -137,43 +139,48 @@ class SkillTreeViewModel extends ChangeNotifier { } // TODO: Move to task queue view model - // TODO: We should check if the test passed. If not we short circuit. 
- // TODO: We should create a model to track our active tests + // TODO: We should be creating TestSuite objects Future runBenchmark(ChatViewModel chatViewModel) async { - // 1. Set the benchmark running flag to true + // Clear the benchmarkStatusList + benchmarkStatusList.clear(); + + // Set the benchmark running flag to true isBenchmarkRunning = true; - // 2. Notify listeners + // Notify listeners notifyListeners(); - // 3. Grab the reversed node hierarchy + // Populate benchmarkStatusList with reversed node hierarchy final reversedSelectedNodeHierarchy = List.from(_selectedNodeHierarchy!.reversed); + for (var node in reversedSelectedNodeHierarchy) { + benchmarkStatusList.add({node: false}); + } try { - // 4. Loop through the nodes in the hierarchy + // Loop through the nodes in the hierarchy for (var node in reversedSelectedNodeHierarchy) { - // 5. Create a BenchmarkTaskRequestBody + // Create a BenchmarkTaskRequestBody final benchmarkTaskRequestBody = BenchmarkTaskRequestBody( input: node.data.task, evalId: node.data.evalId); - // 6. Create a new benchmark task + // Create a new benchmark task final createdTask = await benchmarkService .createBenchmarkTask(benchmarkTaskRequestBody); - // 7. Create a new Task object + // Create a new Task object final task = Task(id: createdTask['task_id'], title: createdTask['input']); - // 8. Update the current task ID in ChatViewModel + // Update the current task ID in ChatViewModel chatViewModel.setCurrentTaskId(task.id); - // 9. Execute the first step and initialize the Step object + // Execute the first step and initialize the Step object Map stepResponse = await benchmarkService.executeBenchmarkStep( task.id, BenchmarkStepRequestBody(input: null)); Step step = Step.fromMap(stepResponse); - // 11. 
Check if it's the last step + // Check if it's the last step while (!step.isLast) { // Fetch chats for the task chatViewModel.fetchChatsForTask(); @@ -184,10 +191,17 @@ class SkillTreeViewModel extends ChangeNotifier { step = Step.fromMap(stepResponse); } - // 12. Trigger the evaluation + // Trigger the evaluation final evaluationResponse = await benchmarkService.triggerEvaluation(task.id); print("Evaluation response: $evaluationResponse"); + + // Update the benchmarkStatusList based on the evaluation response + bool successStatus = evaluationResponse['metrics']['success']; + var nodeStatus = benchmarkStatusList.firstWhere( + (element) => element.keys.first.id == node.id, + ); + nodeStatus[node] = successStatus; } } catch (e) { print("Error while running benchmark: $e"); @@ -197,10 +211,4 @@ class SkillTreeViewModel extends ChangeNotifier { isBenchmarkRunning = false; notifyListeners(); } - - // Getter to expose nodes for the View - List get skillTreeNodes => _skillTreeNodes; - - // Getter to expose edges for the View - List get skillTreeEdges => _skillTreeEdges; }