using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;

/// <summary>
/// A thread-safe dictionary that associates each key with a set of distinct values.
/// Based on https://stackoverflow.com/a/60719233/275504
/// </summary>
/// <remarks>
/// Each key maps to a private <c>Bag</c> (a <see cref="HashSet{T}"/>) that is only touched
/// while holding a lock on that bag. A bag that is about to leave the dictionary is first
/// flagged as discarded under its lock; any operation that finds a discarded bag spins and
/// retries, so a value can never be written into a bag that is no longer reachable.
/// </remarks>
/// <typeparam name="TKey">Type of the keys.</typeparam>
/// <typeparam name="TValue">Type of the values stored in each key's set.</typeparam>
public class ConcurrentMultiDictionary<TKey, TValue> : IEnumerable<KeyValuePair<TKey, TValue[]>>
{
    /// <summary>Per-key value set; also serves as the lock object guarding itself.</summary>
    private class Bag : HashSet<TValue>
    {
        /// <summary>
        /// Set (under the bag's lock) by the thread that is about to remove this bag from
        /// the dictionary. Once true, the bag must not be read or modified any more.
        /// </summary>
        public bool IsDiscarded { get; set; }
    }

    private readonly ConcurrentDictionary<TKey, Bag> _dictionary;

    /// <summary>Creates an empty multi-dictionary.</summary>
    public ConcurrentMultiDictionary()
    {
        _dictionary = new ConcurrentDictionary<TKey, Bag>();
    }

    /// <summary>Number of keys currently present in the underlying dictionary.</summary>
    public int Count => _dictionary.Count;

    /// <summary>Adds <paramref name="value"/> to the set stored under <paramref name="key"/>.</summary>
    /// <param name="key">Key whose set receives the value; the set is created on demand.</param>
    /// <param name="value">Value to add.</param>
    /// <returns><c>true</c> if the value was added; <c>false</c> if the set already contained it.</returns>
    public bool Add(TKey key, TValue value)
    {
        var spinWait = new SpinWait();
        while (true)
        {
            var bag = _dictionary.GetOrAdd(key, _ => new Bag());
            lock (bag)
            {
                if (!bag.IsDiscarded)
                {
                    return bag.Add(value);
                }
            }

            // The bag is being removed by another thread; wait for it to disappear and retry.
            spinWait.SpinOnce();
        }
    }

    /// <summary>
    /// Ensures the set stored under <paramref name="key"/> holds exactly this
    /// <paramref name="value"/> instance, replacing any element that compares equal to it.
    /// </summary>
    /// <param name="key">Key whose set is updated; the set is created on demand.</param>
    /// <param name="value">Value to store.</param>
    /// <returns>
    /// The result of the final add, which is <c>true</c> because any equal element has just
    /// been removed under the same lock.
    /// </returns>
    public bool AddOrReplace(TKey key, TValue value)
    {
        var spinWait = new SpinWait();
        while (true)
        {
            var bag = _dictionary.GetOrAdd(key, _ => new Bag());
            lock (bag)
            {
                if (!bag.IsDiscarded)
                {
                    // Remove and add under one lock so no other thread can observe the gap
                    // and the key's bag is never torn down and recreated in between.
                    bag.Remove(value);
                    return bag.Add(value);
                }
            }

            spinWait.SpinOnce();
        }
    }

    /// <summary>Removes <paramref name="key"/> together with all of its values.</summary>
    /// <param name="key">Key to remove.</param>
    /// <returns><c>true</c> if the key was present and has been removed; otherwise <c>false</c>.</returns>
    public bool Remove(TKey key)
    {
        return Remove(key, out _);
    }

    /// <summary>
    /// Removes <paramref name="key"/> together with all of its values and returns those values.
    /// </summary>
    /// <param name="key">Key to remove.</param>
    /// <param name="items">
    /// On success, a copy of the values that were stored under the key; <c>null</c> when the
    /// key was not present.
    /// </param>
    /// <returns><c>true</c> if the key was present and has been removed; otherwise <c>false</c>.</returns>
    public bool Remove(TKey key, out TValue[]? items)
    {
        var spinWait = new SpinWait();
        while (true)
        {
            if (!_dictionary.TryGetValue(key, out var bag))
            {
                items = null;
                return false;
            }

            TValue[]? snapshot = null;
            lock (bag)
            {
                if (!bag.IsDiscarded)
                {
                    // Flag the bag before detaching it so that a concurrent Add holding a
                    // reference to it retries instead of writing into an orphaned bag, and
                    // copy the contents while no other thread can mutate the set.
                    bag.IsDiscarded = true;
                    snapshot = bag.ToArray();
                }
            }

            if (snapshot != null)
            {
                RemoveBag(key, bag);
                items = snapshot;
                return true;
            }

            // Another thread discarded this bag and is about to remove it; retry.
            spinWait.SpinOnce();
        }
    }

    /// <summary>
    /// Removes a single <paramref name="value"/> from the set stored under
    /// <paramref name="key"/>; the key itself is removed when its set becomes empty.
    /// </summary>
    /// <param name="key">Key whose set is modified.</param>
    /// <param name="value">Value to remove.</param>
    /// <returns><c>true</c> if the value was found and removed; otherwise <c>false</c>.</returns>
    public bool Remove(TKey key, TValue value)
    {
        var spinWait = new SpinWait();
        while (true)
        {
            if (!_dictionary.TryGetValue(key, out var bag))
            {
                return false;
            }

            bool emptied = false;
            lock (bag)
            {
                if (!bag.IsDiscarded)
                {
                    if (!bag.Remove(value))
                    {
                        return false;
                    }

                    if (bag.Count != 0)
                    {
                        return true;
                    }

                    // Last value gone: flag the bag so nobody adds to it while it is detached.
                    bag.IsDiscarded = true;
                    emptied = true;
                }
            }

            if (emptied)
            {
                RemoveBag(key, bag);
                return true;
            }

            // Another thread discarded this bag and is about to remove it; retry.
            spinWait.SpinOnce();
        }
    }

    /// <summary>
    /// Detaches a bag that the calling thread has just flagged as discarded.
    /// </summary>
    /// <remarks>
    /// Uses the key/value-pair removal of <see cref="ConcurrentDictionary{TKey,TValue}"/> so
    /// that only this exact bag instance is removed; a newer bag stored under the same key is
    /// left untouched.
    /// </remarks>
    private void RemoveBag(TKey key, Bag bag)
    {
        var entries = (ICollection<KeyValuePair<TKey, Bag>>)_dictionary;
        bool keyRemoved = entries.Remove(new KeyValuePair<TKey, Bag>(key, bag));
        Debug.Assert(keyRemoved, $"Key {key} was not removed");
    }

    /// <summary>Gets a copy of the values stored under <paramref name="key"/>.</summary>
    /// <param name="key">Key to look up.</param>
    /// <param name="values">
    /// On success, an array copy of the key's values; <c>null</c> when the key is absent or
    /// its bag is in the middle of being removed.
    /// </param>
    /// <returns><c>true</c> if the key was found with a live bag; otherwise <c>false</c>.</returns>
    public bool TryGetValues(TKey key, out TValue[] values)
    {
        if (!_dictionary.TryGetValue(key, out var bag))
        {
            values = null;
            return false;
        }

        lock (bag)
        {
            if (bag.IsDiscarded)
            {
                values = null;
                return false;
            }

            values = bag.ToArray();
            return true;
        }
    }

    /// <summary>Checks whether <paramref name="value"/> is stored under <paramref name="key"/>.</summary>
    /// <returns><c>true</c> if the key has a live bag containing the value; otherwise <c>false</c>.</returns>
    public bool Contains(TKey key, TValue value)
    {
        if (!_dictionary.TryGetValue(key, out var bag))
        {
            return false;
        }

        lock (bag)
        {
            return !bag.IsDiscarded && bag.Contains(value);
        }
    }

    /// <summary>
    /// Checks whether at least one of the given values is stored under <paramref name="key"/>.
    /// </summary>
    /// <param name="key">Key to look up.</param>
    /// <param name="value">Candidate values; enumerated while the key's bag is locked.</param>
    /// <returns>
    /// <c>true</c> if the key has a live bag containing any of the candidates; otherwise <c>false</c>.
    /// </returns>
    public bool Contains(TKey key, IEnumerable<TValue> value)
    {
        if (!_dictionary.TryGetValue(key, out var bag))
        {
            return false;
        }

        lock (bag)
        {
            return !bag.IsDiscarded && value.Any(bag.Contains);
        }
    }

    /// <summary>Checks whether <paramref name="key"/> is present in the underlying dictionary.</summary>
    public bool ContainsKey(TKey key) => _dictionary.ContainsKey(key);

    /// <summary>The key collection of the underlying dictionary.</summary>
    public ICollection<TKey> Keys => _dictionary.Keys;

    /// <summary>
    /// Enumerates every key together with an array copy of its values. Keys that vanish
    /// between listing the keys and reading their values are skipped.
    /// </summary>
    public IEnumerator<KeyValuePair<TKey, TValue[]>> GetEnumerator()
    {
        foreach (var key in _dictionary.Keys)
        {
            if (TryGetValues(key, out var values))
            {
                yield return new KeyValuePair<TKey, TValue[]>(key, values);
            }
        }
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>Checks whether any key's set contains <paramref name="value"/>.</summary>
    public bool ContainsValue(TValue value)
    {
        return _dictionary.Keys.Any(key => Contains(key, value));
    }

    /// <summary>
    /// Checks whether any key's set contains at least one of the given values.
    /// Note that <paramref name="value"/> is enumerated once per inspected key.
    /// </summary>
    public bool ContainsValue(IEnumerable<TValue> value)
    {
        return _dictionary.Keys.Any(key => Contains(key, value));
    }

    /// <summary>
    /// Returns the keys whose set contains at least one of the given values. The filter runs
    /// lazily as the result is enumerated, and <paramref name="value"/> is enumerated once
    /// per inspected key.
    /// </summary>
    public IEnumerable<TKey> GetKeysContainingValue(IEnumerable<TValue> value)
    {
        return _dictionary.Keys.Where(key => Contains(key, value));
    }

    /// <summary>
    /// Returns the keys whose set contains <paramref name="value"/>. The filter runs lazily
    /// as the result is enumerated.
    /// </summary>
    public IEnumerable<TKey> GetKeysContainingValue(TValue value)
    {
        return _dictionary.Keys.Where(key => Contains(key, value));
    }
}