diff --git a/frontend/lib/viewmodels/skill_tree_viewmodel.dart b/frontend/lib/viewmodels/skill_tree_viewmodel.dart index 9358e27e..833d8c11 100644 --- a/frontend/lib/viewmodels/skill_tree_viewmodel.dart +++ b/frontend/lib/viewmodels/skill_tree_viewmodel.dart @@ -18,6 +18,7 @@ import 'package:flutter/foundation.dart'; import 'package:flutter/services.dart'; import 'package:graphview/GraphView.dart'; import 'package:uuid/uuid.dart'; +import 'package:auto_gpt_flutter_client/utils/stack.dart'; class SkillTreeViewModel extends ChangeNotifier { // TODO: Potentially move to task queue view model when we create one @@ -38,6 +39,9 @@ class SkillTreeViewModel extends ChangeNotifier { // TODO: Potentially move to task queue view model when we create one List? _selectedNodeHierarchy; + String _selectedOption = 'Run single test'; + String get selectedOption => _selectedOption; + List get skillTreeNodes => _skillTreeNodes; List get skillTreeEdges => _skillTreeEdges; SkillTreeNode? get selectedNode => _selectedNode; @@ -101,11 +105,95 @@ class SkillTreeViewModel extends ChangeNotifier { } else { // Select the new node _selectedNode = _skillTreeNodes.firstWhere((node) => node.id == nodeId); - populateSelectedNodeHierarchy(nodeId); + updateSelectedNodeHierarchyBasedOnOption(_selectedOption); } notifyListeners(); } + void updateSelectedNodeHierarchyBasedOnOption(String selectedOption) { + _selectedOption = selectedOption; + switch (selectedOption) { + // TODO: Turn this into enum + case 'Run single test': + _selectedNodeHierarchy = _selectedNode != null ? [_selectedNode!] : []; + break; + + case 'Run test suite including selected node and ancestors': + if (_selectedNode != null) { + populateSelectedNodeHierarchy(_selectedNode!.id); + } + break; + + case 'Run all tests in category': + if (_selectedNode != null) { + _getAllNodesInDepthFirstOrderEnsuringParents(); + } + break; + } + notifyListeners(); + } + + void _getAllNodesInDepthFirstOrderEnsuringParents() { + var nodes = []; + var stack = Stack(); + var visited = {}; + + // Identify the root node by its label + var root = _skillTreeNodes.firstWhere((node) => node.label == "WriteFile"); + + stack.push(root); + visited.add(root.id); + + while (stack.isNotEmpty) { + var node = stack.peek(); // Peek the top node, but do not remove it yet + var parents = _getParentsOfNodeUsingEdges(node.id); + + // Check if all parents are visited + if (parents.every((parent) => visited.contains(parent.id))) { + nodes.add(node); + stack.pop(); // Remove the node only when all its parents are visited + + // Get the children of the current node using edges + var children = _getChildrenOfNodeUsingEdges(node.id) + .where((child) => !visited.contains(child.id)); + + children.forEach((child) { + visited.add(child.id); + stack.push(child); + }); + } else { + stack + .pop(); // Remove the node if not all parents are visited, it will be re-added when its parents are visited + } + } + + _selectedNodeHierarchy = nodes; + } + + List _getParentsOfNodeUsingEdges(String nodeId) { + var parents = []; + + for (var edge in _skillTreeEdges) { + if (edge.to == nodeId) { + parents.add(_skillTreeNodes.firstWhere((node) => node.id == edge.from)); + } + } + + return parents; + } + + List _getChildrenOfNodeUsingEdges(String nodeId) { + var children = []; + + for (var edge in _skillTreeEdges) { + if (edge.from == nodeId) { + children.add(_skillTreeNodes.firstWhere((node) => node.id == edge.to)); + } + } + + return children; + } + // TODO: Do we want to continue testing other branches of tree if one branch side fails benchmarking? void populateSelectedNodeHierarchy(String startNodeId) { // Initialize an empty list to hold the nodes in all hierarchies. diff --git a/frontend/lib/views/task_queue/task_queue_view.dart b/frontend/lib/views/task_queue/task_queue_view.dart index 8132bb5d..4562b0a5 100644 --- a/frontend/lib/views/task_queue/task_queue_view.dart +++ b/frontend/lib/views/task_queue/task_queue_view.dart @@ -97,18 +97,22 @@ class TaskQueueView extends StatelessWidget { children: [ // TestSuiteButton TestSuiteButton( - onPressed: viewModel.isBenchmarkRunning - ? null - : () { - final chatViewModel = Provider.of( - context, - listen: false); - final taskViewModel = Provider.of( - context, - listen: false); - chatViewModel.clearCurrentTaskAndChats(); - viewModel.runBenchmark(chatViewModel, taskViewModel); - }, + isDisabled: viewModel.isBenchmarkRunning, + selectedOption: viewModel.selectedOption, + onOptionSelected: (selectedOption) { + print('Option Selected: $selectedOption'); + viewModel.updateSelectedNodeHierarchyBasedOnOption( + selectedOption); + }, + onPlayPressed: (selectedOption) { + print('Starting benchmark with option: $selectedOption'); + final chatViewModel = + Provider.of(context, listen: false); + final taskViewModel = + Provider.of(context, listen: false); + chatViewModel.clearCurrentTaskAndChats(); + viewModel.runBenchmark(chatViewModel, taskViewModel); + }, ), SizedBox(height: 8), // Gap of 8 points between buttons // LeaderboardSubmissionButton diff --git a/frontend/lib/views/task_queue/test_suite_button.dart b/frontend/lib/views/task_queue/test_suite_button.dart index e6086a8e..1e1f871f 100644 --- a/frontend/lib/views/task_queue/test_suite_button.dart +++ b/frontend/lib/views/task_queue/test_suite_button.dart @@ -1,47 +1,115 @@ import 'package:auto_gpt_flutter_client/constants/app_colors.dart'; import 'package:flutter/material.dart'; -class TestSuiteButton extends StatelessWidget { - final VoidCallback? onPressed; +class TestSuiteButton extends StatefulWidget { final bool isDisabled; + final Function(String) onOptionSelected; + final Function(String) onPlayPressed; + String selectedOption; - TestSuiteButton({required this.onPressed, this.isDisabled = false}); + TestSuiteButton({ + this.isDisabled = false, + required this.onOptionSelected, + required this.onPlayPressed, + required this.selectedOption, + }); + @override + _TestSuiteButtonState createState() => _TestSuiteButtonState(); +} + +class _TestSuiteButtonState extends State { @override Widget build(BuildContext context) { - return SizedBox( - height: 50, - child: ElevatedButton( - style: ElevatedButton.styleFrom( - backgroundColor: isDisabled ? Colors.grey : AppColors.primaryLight, - shape: RoundedRectangleBorder( - borderRadius: BorderRadius.circular(8.0), - ), - padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 8), - elevation: 5.0, - ), - onPressed: isDisabled ? null : onPressed, - child: const Row( - mainAxisAlignment: MainAxisAlignment.center, - children: [ - Text( - 'Initiate test suite', - style: TextStyle( - color: Colors.white, - fontSize: 12.50, - fontFamily: 'Archivo', - fontWeight: FontWeight.w400, + return Row( + children: [ + // Dropdown button with test options + Expanded( + // Added Expanded to make sure it takes the available space + child: PopupMenuButton( + enabled: !widget.isDisabled, + onSelected: (value) { + setState(() { + widget.selectedOption = value; + }); + widget.onOptionSelected(widget.selectedOption); + }, + itemBuilder: (BuildContext context) { + return [ + const PopupMenuItem( + value: 'Run single test', + child: Text('Run single test'), + ), + const PopupMenuItem( + value: 'Run test suite including selected node and ancestors', + child: Text( + 'Run test suite including selected node and ancestors'), + ), + const PopupMenuItem( + value: 'Run all tests in category', + child: Text('Run all tests in category'), + ), + ]; + }, + child: Container( + height: 50, + padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 8), + decoration: BoxDecoration( + color: widget.isDisabled ? Colors.grey : AppColors.primaryLight, + borderRadius: BorderRadius.circular(8.0), + ), + child: Row( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Flexible( + child: Text( + widget.selectedOption, + style: const TextStyle( + color: Colors.white, + fontSize: 12.50, + fontFamily: 'Archivo', + fontWeight: FontWeight.w400, + ), + overflow: TextOverflow.ellipsis, + maxLines: 2, + ), + ), + const Icon( + Icons.arrow_drop_down, + color: Colors.white, + ) + ], ), ), - SizedBox(width: 10), - Icon( + ), + ), + // Play button + const SizedBox(width: 10), + SizedBox( + height: 50, + child: ElevatedButton( + style: ElevatedButton.styleFrom( + backgroundColor: + widget.isDisabled ? Colors.grey : AppColors.primaryLight, + shape: RoundedRectangleBorder( + borderRadius: BorderRadius.circular(8.0), + ), + padding: const EdgeInsets.symmetric(horizontal: 16, vertical: 8), + elevation: 5.0, + ), + onPressed: widget.isDisabled + ? null + : () { + widget.onPlayPressed(widget.selectedOption); + }, + child: const Icon( Icons.play_arrow, color: Colors.white, size: 24, ), - ], + ), ), - ), + ], ); } }