using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
/// <summary>
/// A thread-safe dictionary that associates each key with a set of distinct values.
/// Based on https://stackoverflow.com/a/60719233/275504
/// </summary>
/// <remarks>
/// Every key maps to a private <c>Bag</c> (a <see cref="HashSet{T}"/>) that is only
/// read or mutated while holding a lock on that bag. A bag that is about to be
/// unlinked from the dictionary is first flagged as discarded under its lock; any
/// thread that observes a discarded bag spins and retries, so a value can never be
/// written into a bag that is no longer reachable.
/// </remarks>
/// <typeparam name="TKey">The type of the keys.</typeparam>
/// <typeparam name="TValue">The type of the values stored under each key.</typeparam>
public class ConcurrentMultiDictionary<TKey, TValue>
    : IEnumerable<KeyValuePair<TKey, TValue[]>>
{
    /// <summary>
    /// The per-key value set. <see cref="IsDiscarded"/> is set (under a lock on the
    /// bag) by the thread that is about to remove the bag from the dictionary.
    /// </summary>
    private sealed class Bag : HashSet<TValue>
    {
        public bool IsDiscarded { get; set; }
    }

    private readonly ConcurrentDictionary<TKey, Bag> _dictionary;

    /// <summary>Creates an empty multi-dictionary.</summary>
    public ConcurrentMultiDictionary()
    {
        _dictionary = new ConcurrentDictionary<TKey, Bag>();
    }

    /// <summary>Gets the number of keys (not the number of values).</summary>
    public int Count => _dictionary.Count;

    /// <summary>Adds <paramref name="value"/> to the set stored under <paramref name="key"/>.</summary>
    /// <returns><c>true</c> if the value was added; <c>false</c> if the set already contained it.</returns>
    public bool Add(TKey key, TValue value)
    {
        var spinWait = new SpinWait();
        while (true)
        {
            var bag = _dictionary.GetOrAdd(key, _ => new Bag());
            lock (bag)
            {
                if (!bag.IsDiscarded)
                {
                    return bag.Add(value);
                }
            }

            // The bag is being unlinked by another thread; wait for it to disappear.
            spinWait.SpinOnce();
        }
    }

    /// <summary>
    /// Ensures that <paramref name="value"/> is the instance stored under
    /// <paramref name="key"/>, replacing any existing element that compares equal to it.
    /// </summary>
    /// <remarks>
    /// The removal and the insertion happen under a single lock on the bag, so other
    /// threads never observe the value (or the key) as temporarily missing.
    /// </remarks>
    /// <returns>Always <c>true</c>: the value is present when the method returns.</returns>
    public bool AddOrReplace(TKey key, TValue value)
    {
        var spinWait = new SpinWait();
        while (true)
        {
            var bag = _dictionary.GetOrAdd(key, _ => new Bag());
            lock (bag)
            {
                if (!bag.IsDiscarded)
                {
                    bag.Remove(value);
                    return bag.Add(value);
                }
            }

            spinWait.SpinOnce();
        }
    }

    /// <summary>Removes <paramref name="key"/> together with all of its values.</summary>
    /// <returns><c>true</c> if the key was found and removed; otherwise <c>false</c>.</returns>
    public bool Remove(TKey key)
    {
        return Remove(key, out _);
    }

    /// <summary>
    /// Removes <paramref name="key"/> together with all of its values and returns those values.
    /// </summary>
    /// <param name="key">The key to remove.</param>
    /// <param name="items">
    /// The values that were stored under the key, or <c>null</c> when the key was not found.
    /// </param>
    /// <returns><c>true</c> if the key was found and removed; otherwise <c>false</c>.</returns>
    public bool Remove(TKey key, out TValue[]? items)
    {
        var spinWait = new SpinWait();
        while (true)
        {
            if (!_dictionary.TryGetValue(key, out var bag))
            {
                items = null;
                return false;
            }

            // Flag the bag as discarded *before* unlinking it, and take the snapshot
            // under the same lock. Writers that already hold a reference to this bag
            // will then retry instead of adding to an orphaned set.
            TValue[]? snapshot = null;
            lock (bag)
            {
                if (!bag.IsDiscarded)
                {
                    snapshot = bag.ToArray();
                    bag.IsDiscarded = true;
                }
            }

            if (snapshot == null)
            {
                // Another thread discarded this bag and is removing it; retry.
                spinWait.SpinOnce();
                continue;
            }

            RemoveDiscardedBag(key, bag);
            items = snapshot;
            return true;
        }
    }

    /// <summary>
    /// Removes <paramref name="value"/> from the set stored under <paramref name="key"/>.
    /// The key itself is removed when its set becomes empty.
    /// </summary>
    /// <returns><c>true</c> if the value was found and removed; otherwise <c>false</c>.</returns>
    public bool Remove(TKey key, TValue value)
    {
        var spinWait = new SpinWait();
        while (true)
        {
            if (!_dictionary.TryGetValue(key, out var bag))
            {
                return false;
            }

            bool spinAndRetry = false;
            lock (bag)
            {
                if (bag.IsDiscarded)
                {
                    spinAndRetry = true;
                }
                else
                {
                    if (!bag.Remove(value))
                    {
                        return false;
                    }

                    if (bag.Count != 0)
                    {
                        return true;
                    }

                    // Last value gone: this thread now owns removal of the bag.
                    bag.IsDiscarded = true;
                }
            }

            if (spinAndRetry)
            {
                spinWait.SpinOnce();
                continue;
            }

            RemoveDiscardedBag(key, bag);
            return true;
        }
    }

    /// <summary>Gets a snapshot of the values stored under <paramref name="key"/>.</summary>
    /// <param name="key">The key to look up.</param>
    /// <param name="values">
    /// A copy of the values, or <c>null</c> when the key is absent or is being removed.
    /// </param>
    /// <returns><c>true</c> if a live set was found for the key; otherwise <c>false</c>.</returns>
    public bool TryGetValues(TKey key, out TValue[] values)
    {
        if (_dictionary.TryGetValue(key, out var bag))
        {
            lock (bag)
            {
                if (!bag.IsDiscarded)
                {
                    values = bag.ToArray();
                    return true;
                }
            }
        }

        values = null;
        return false;
    }

    /// <summary>Determines whether <paramref name="value"/> is stored under <paramref name="key"/>.</summary>
    public bool Contains(TKey key, TValue value)
    {
        if (!_dictionary.TryGetValue(key, out var bag))
        {
            return false;
        }

        lock (bag)
        {
            return !bag.IsDiscarded && bag.Contains(value);
        }
    }

    /// <summary>
    /// Determines whether at least one element of <paramref name="value"/> is stored
    /// under <paramref name="key"/>.
    /// </summary>
    /// <remarks>The sequence is enumerated while the lock on the key's bag is held.</remarks>
    public bool Contains(TKey key, IEnumerable<TValue> value)
    {
        if (!_dictionary.TryGetValue(key, out var bag))
        {
            return false;
        }

        lock (bag)
        {
            return !bag.IsDiscarded && value.Any(bag.Contains);
        }
    }

    /// <summary>Determines whether the dictionary currently holds an entry for <paramref name="key"/>.</summary>
    public bool ContainsKey(TKey key) => _dictionary.ContainsKey(key);

    /// <summary>Gets the keys of the underlying dictionary.</summary>
    public ICollection<TKey> Keys => _dictionary.Keys;

    /// <summary>
    /// Enumerates every key together with a snapshot of its values. Keys whose values
    /// can no longer be read when they are visited are skipped.
    /// </summary>
    public IEnumerator<KeyValuePair<TKey, TValue[]>> GetEnumerator()
    {
        foreach (var key in _dictionary.Keys)
        {
            if (TryGetValues(key, out var values))
            {
                yield return new KeyValuePair<TKey, TValue[]>(key, values);
            }
        }
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>Determines whether any key holds <paramref name="value"/>.</summary>
    public bool ContainsValue(TValue value)
    {
        return _dictionary.Keys.Any(key => Contains(key, value));
    }

    /// <summary>
    /// Determines whether any key holds at least one element of <paramref name="value"/>.
    /// The sequence is enumerated once per key that is inspected.
    /// </summary>
    public bool ContainsValue(IEnumerable<TValue> value)
    {
        return _dictionary.Keys.Any(key => Contains(key, value));
    }

    /// <summary>
    /// Returns the keys that hold at least one element of <paramref name="value"/>.
    /// The result is a deferred query: it is evaluated when enumerated.
    /// </summary>
    public IEnumerable<TKey> GetKeysContainingValue(IEnumerable<TValue> value)
    {
        return _dictionary.Keys.Where(key => Contains(key, value));
    }

    /// <summary>
    /// Returns the keys that hold <paramref name="value"/>.
    /// The result is a deferred query: it is evaluated when enumerated.
    /// </summary>
    public IEnumerable<TKey> GetKeysContainingValue(TValue value)
    {
        return _dictionary.Keys.Where(key => Contains(key, value));
    }

    /// <summary>
    /// Unlinks a bag that the calling thread has just flagged as discarded.
    /// While the flag is set no other thread removes or replaces the entry, so the
    /// removal is expected to succeed and to yield that same bag.
    /// </summary>
    private void RemoveDiscardedBag(TKey key, Bag bag)
    {
        bool keyRemoved = _dictionary.TryRemove(key, out var currentBag);
        Debug.Assert(keyRemoved, $"Key {key} was not removed");
        Debug.Assert(bag == currentBag, $"Removed wrong bag");
    }
}