diff --git a/src/SMAPI/Framework/ModLoading/TypeReferenceComparer.cs b/src/SMAPI/Framework/ModLoading/TypeReferenceComparer.cs
index 8d128b37..f7497789 100644
--- a/src/SMAPI/Framework/ModLoading/TypeReferenceComparer.cs
+++ b/src/SMAPI/Framework/ModLoading/TypeReferenceComparer.cs
@@ -1,21 +1,23 @@
+using System;
using System.Collections.Generic;
-using System.Text.RegularExpressions;
+using System.Linq;
using Mono.Cecil;
namespace StardewModdingAPI.Framework.ModLoading
{
/// Performs heuristic equality checks for instances.
+ ///
+ /// This implementation compares instances to see if they likely
+ /// refer to the same type. While the implementation is obvious for types like System.Bool,
+ /// this class mainly exists to handle cases like System.Collections.Generic.Dictionary`2<!0,Netcode.NetRoot`1<!1>>
+ /// and System.Collections.Generic.Dictionary`2<TKey,Netcode.NetRoot`1<TValue>>
+ /// which are compatible, but not directly comparable. It does this by splitting each type name
+ /// into its component token types, and performing placeholder substitution (e.g. !0 to
+ /// TKey in the above example). If all components are equal after substitution, and the
+ /// tokens can all be mapped to the same generic type, the types are considered equal.
+ ///
internal class TypeReferenceComparer : IEqualityComparer
{
- /*********
- ** Properties
- *********/
- /// A pattern matching type name substrings to strip for display.
- private readonly Regex StripTypeNamePattern = new Regex(@"`\d+(?=<)", RegexOptions.Compiled);
-
- private List symbolBoundaries = new List { '<', '>', ',' };
-
-
/*********
** Public methods
*********/
@@ -24,25 +26,13 @@ namespace StardewModdingAPI.Framework.ModLoading
/// The second object to compare.
public bool Equals(TypeReference a, TypeReference b)
{
- string typeA = this.GetComparableTypeID(a);
- string typeB = this.GetComparableTypeID(b);
+ if (a == null || b == null)
+ return a == b;
- string placeholderType = "", actualType = "";
-
- if (this.HasPlaceholder(typeA))
- {
- placeholderType = typeA;
- actualType = typeB;
- }
- else if (this.HasPlaceholder(typeB))
- {
- placeholderType = typeB;
- actualType = typeA;
- }
- else
- return typeA == typeB;
-
- return this.PlaceholderTypeValidates(placeholderType, actualType);
+ return
+ a == b
+ || a.FullName == b.FullName
+ || this.HeuristicallyEquals(a, b);
}
/// Get a hash code for the specified object.
@@ -57,153 +47,155 @@ namespace StardewModdingAPI.Framework.ModLoading
/*********
** Private methods
*********/
- /// Get a unique string representation of a type.
- /// The type reference.
- private string GetComparableTypeID(TypeReference type)
+ /// Get whether two types are heuristically equal based on generic type token substitution.
+ /// The first type to compare.
+ /// The second type to compare.
+ private bool HeuristicallyEquals(TypeReference typeA, TypeReference typeB)
{
- return this.StripTypeNamePattern.Replace(type.FullName, "");
- }
-
- /// Determine whether this type ID has a placeholder such as !0.
- /// The type to check.
- /// true if the type ID contains a placeholder, false if not.
- private bool HasPlaceholder(string typeID)
- {
- return typeID.Contains("!0");
- }
-
- /// returns whether this type ID is a placeholder, i.e., it begins with "!".
- /// The symbol to validate.
- /// true if the symbol is a placeholder, false if not
- private bool IsPlaceholder(string symbol)
- {
- return symbol.StartsWith("!");
- }
-
- /// Traverses and parses out symbols from a type which does not contain placeholder values.
- /// The type to traverse.
- /// A List in which to store the parsed symbols.
- private void TraverseActualType(string type, List typeSymbols)
- {
- int depth = 0;
- string symbol = "";
-
- foreach (char c in type)
+ bool HeuristicallyEquals(string typeNameA, string typeNameB, IDictionary tokenMap)
{
- if (this.symbolBoundaries.Contains(c))
+ // analyse type names
+ bool hasTokensA = typeNameA.Contains("!");
+ bool hasTokensB = typeNameB.Contains("!");
+ bool isTokenA = hasTokensA && typeNameA[0] == '!';
+ bool isTokenB = hasTokensB && typeNameB[0] == '!';
+
+ // validate
+ if (!hasTokensA && !hasTokensB)
+ return typeNameA == typeNameB; // no substitution needed
+ if (hasTokensA && hasTokensB)
+ throw new InvalidOperationException("Can't compare two type names when both contain generic type tokens.");
+
+ // perform substitution if applicable
+ if (isTokenA)
+ typeNameA = this.MapPlaceholder(placeholder: typeNameA, type: typeNameB, map: tokenMap);
+ if (isTokenB)
+ typeNameB = this.MapPlaceholder(placeholder: typeNameB, type: typeNameA, map: tokenMap);
+
+ // compare inner tokens
+ string[] symbolsA = this.GetTypeSymbols(typeNameA).ToArray();
+ string[] symbolsB = this.GetTypeSymbols(typeNameB).ToArray();
+ if (symbolsA.Length != symbolsB.Length)
+ return false;
+
+ for (int i = 0; i < symbolsA.Length; i++)
{
- typeSymbols.Add(new SymbolLocation(symbol, depth));
- symbol = "";
- switch (c)
- {
- case '<':
- depth++;
- break;
- case '>':
- depth--;
- break;
- default:
- break;
- }
- }
- else
- symbol += c;
- }
- }
-
- /// Determines whether two symbols in a type ID match, accounting for placeholders such as !0.
- /// A symbol in a typename which contains placeholders.
- /// A symbol in a typename which does not contain placeholders.
- /// A dictionary containing a mapping of placeholders to concrete types.
- /// true if the symbols match, false if not.
- private bool SymbolsMatch(SymbolLocation symbolA, SymbolLocation symbolB, Dictionary placeholderMap)
- {
- if (symbolA.depth != symbolB.depth)
- return false;
-
- if (!this.IsPlaceholder(symbolA.symbol))
- {
- return symbolA.symbol == symbolB.symbol;
- }
-
- if (placeholderMap.ContainsKey(symbolA.symbol))
- {
- return placeholderMap[symbolA.symbol] == symbolB.symbol;
- }
-
- placeholderMap[symbolA.symbol] = symbolB.symbol;
-
- return true;
- }
-
- /// Determines whether a type which has placeholders correctly resolves to the concrete type provided.
- /// A type containing placeholders such as !0.
- /// The list of symbols extracted from the concrete type.
- /// true if the type resolves correctly, false if not.
- private bool PlaceholderTypeResolvesToActualType(string type, List typeSymbols)
- {
- Dictionary placeholderMap = new Dictionary();
-
- int depth = 0, symbolCount = 0;
- string symbol = "";
-
- foreach (char c in type)
- {
- if (this.symbolBoundaries.Contains(c))
- {
- bool match = this.SymbolsMatch(new SymbolLocation(symbol, depth), typeSymbols[symbolCount], placeholderMap);
- if (typeSymbols.Count <= symbolCount ||
- !match)
+ if (!HeuristicallyEquals(symbolsA[i], symbolsB[i], tokenMap))
return false;
-
- symbolCount++;
- symbol = "";
- switch (c)
- {
- case '<':
- depth++;
- break;
- case '>':
- depth--;
- break;
- default:
- break;
- }
}
- else
- symbol += c;
+
+ return true;
}
- return true;
+ return HeuristicallyEquals(typeA.FullName, typeB.FullName, new Dictionary());
}
- /// Determines whether a type with placeholders in it matches a type without placeholders.
- /// The type with placeholders in it.
- /// The type without placeholders.
- /// true if the placeholder type can resolve to the actual type, false if not.
- private bool PlaceholderTypeValidates(string placeholderType, string actualType)
+ /// Map a generic type placeholder (like !0) to its actual type.
+ /// The token placeholder.
+ /// The actual type.
+ /// The map of token to map substitutions.
+ /// Returns the previously-mapped type if applicable, else the .
+ private string MapPlaceholder(string placeholder, string type, IDictionary map)
{
- List typeSymbols = new List();
+ if (map.TryGetValue(placeholder, out string result))
+ return result;
- this.TraverseActualType(actualType, typeSymbols);
- return PlaceholderTypeResolvesToActualType(placeholderType, typeSymbols);
+ map[placeholder] = type;
+ return type;
}
-
-
- /*********
- ** Inner classes
- *********/
- protected class SymbolLocation
+ /// Get the top-level type symbols in a type name (e.g. List
and NetRef<T>
in List<NetRef<T>>
)
+ /// The full type name.
+ private IEnumerable GetTypeSymbols(string typeName)
{
- public string symbol;
- public int depth;
+ int openGenerics = 0;
- public SymbolLocation(string symbol, int depth)
+ Queue queue = new Queue(typeName);
+ string symbol = "";
+ while (queue.Any())
{
- this.symbol = symbol;
- this.depth = depth;
+ char ch = queue.Dequeue();
+ switch (ch)
+ {
+ // skip `1 generic type identifiers
+ case '`':
+ while (int.TryParse(queue.Peek().ToString(), out int _))
+ queue.Dequeue();
+ break;
+
+ // start generic args
+ case '<':
+ switch (openGenerics)
+ {
+ // start new generic symbol
+ case 0:
+ yield return symbol;
+ symbol = "";
+ openGenerics++;
+ break;
+
+ // continue accumulating nested type symbol
+ default:
+ symbol += ch;
+ openGenerics++;
+ break;
+ }
+ break;
+
+ // generic args delimiter
+ case ',':
+ switch (openGenerics)
+ {
+ // invalid
+ case 0:
+ throw new InvalidOperationException($"Encountered unexpected comma in type name: {typeName}.");
+
+ // start next generic symbol
+ case 1:
+ yield return symbol;
+ symbol = "";
+ break;
+
+ // continue accumulating nested type symbol
+ default:
+ symbol += ch;
+ break;
+ }
+ break;
+
+
+ // end generic args
+ case '>':
+ switch (openGenerics)
+ {
+ // invalid
+ case 0:
+ throw new InvalidOperationException($"Encountered unexpected closing generic in type name: {typeName}.");
+
+ // end generic symbol
+ case 1:
+ yield return symbol;
+ symbol = "";
+ openGenerics--;
+ break;
+
+ // continue accumulating nested type symbol
+ default:
+ symbol += ch;
+ openGenerics--;
+ break;
+ }
+ break;
+
+ // continue symbol
+ default:
+ symbol += ch;
+ break;
+ }
}
+
+ if (symbol != "")
+ yield return symbol;
}
}
}