diff --git a/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs b/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs
new file mode 100644
index 00000000..93003a64
--- /dev/null
+++ b/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs
@@ -0,0 +1,318 @@
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using Mono.Cecil;
+using Mono.Cecil.Cil;
+using CallSite = Mono.Cecil.CallSite;
+
+namespace StardewModdingAPI.Framework.AssemblyRewriting
+{
+ /// Rewrites type references.
+ internal class AssemblyTypeRewriter
+ {
+ /*********
+ ** Properties
+ *********/
+ /// The assemblies to target. Equivalent types will be rewritten to use these assemblies.
+ private readonly Assembly[] TargetAssemblies;
+
+ /// >The short assembly names to remove as assembly reference, and replace with the .
+ private readonly string[] RemoveAssemblyNames;
+
+ /// A type => assembly lookup for types which should be rewritten.
+ private readonly IDictionary TypeAssemblies;
+
+ /// An assembly => reference cache.
+ private readonly IDictionary AssemblyNameReferences;
+
+ /// An assembly => module cache.
+ private readonly IDictionary AssemblyModules;
+
+
+ /*********
+ ** Public methods
+ *********/
+ /// Construct an instance.
+ /// The assembly filenames to target. Equivalent types will be rewritten to use these assemblies.
+ /// The short assembly names to remove as assembly reference, and replace with the .
+ public AssemblyTypeRewriter(Assembly[] targetAssemblies, string[] removeAssemblyNames)
+ {
+ // save config
+ this.TargetAssemblies = targetAssemblies;
+ this.RemoveAssemblyNames = removeAssemblyNames;
+
+ // cache assembly metadata
+ this.AssemblyNameReferences = targetAssemblies.ToDictionary(assembly => assembly, assembly => AssemblyNameReference.Parse(assembly.FullName));
+ this.AssemblyModules = targetAssemblies.ToDictionary(assembly => assembly, assembly => ModuleDefinition.ReadModule(assembly.Modules.Single().FullyQualifiedName)); // technically an assembly can contain multiple modules, but none of the build tools (including MSBuild itself) support it
+
+ // collect type => assembly lookup
+ this.TypeAssemblies = new Dictionary();
+ foreach (Assembly assembly in targetAssemblies)
+ {
+ ModuleDefinition module = this.AssemblyModules[assembly];
+ foreach (TypeDefinition type in module.GetTypes())
+ {
+ if (!type.IsPublic)
+ continue; // no need to rewrite
+ if (type.Namespace.Contains("<"))
+ continue; // ignore C++ stuff
+ this.TypeAssemblies[type.FullName] = assembly;
+ }
+ }
+ }
+
+ /// Rewrite the types referenced by an assembly.
+ /// The assembly to rewrite.
+ public void RewriteAssembly(AssemblyDefinition assembly)
+ {
+ foreach (ModuleDefinition module in assembly.Modules)
+ {
+ // rewrite assembly references
+ bool shouldRewriteTypes = false;
+ for (int i = 0; i < module.AssemblyReferences.Count; i++)
+ {
+ bool shouldRemove = this.RemoveAssemblyNames.Any(name => module.AssemblyReferences[i].Name == name) || this.TargetAssemblies.Any(a => module.AssemblyReferences[i].Name == a.GetName().Name);
+ if (shouldRemove)
+ {
+ shouldRewriteTypes = true;
+ module.AssemblyReferences.RemoveAt(i);
+ i--;
+ }
+ }
+ foreach (AssemblyNameReference target in this.AssemblyNameReferences.Values)
+ {
+ module.AssemblyReferences.Add(target);
+ shouldRewriteTypes = true;
+ }
+
+ // rewrite references
+ if (shouldRewriteTypes)
+ {
+ // rewrite types
+ foreach (TypeDefinition type in module.GetTypes())
+ this.RewriteReferences(type, module);
+
+ // rewrite type references
+ TypeReference[] refs = (TypeReference[])module.GetTypeReferences();
+ for (int i = 0; i < refs.Length; ++i)
+ refs[i] = this.GetTypeReference(refs[i], module);
+ }
+ }
+ }
+
+
+ /*********
+ ** Private methods
+ *********/
+ /// Rewrite the references for a code object.
+ /// The type to rewrite.
+ /// The module being rewritten.
+ private void RewriteReferences(TypeDefinition type, ModuleDefinition module)
+ {
+ // rewrite base type
+ type.BaseType = this.GetTypeReference(type.BaseType, module);
+
+ // rewrite interfaces
+ for (int i = 0; i < type.Interfaces.Count; i++)
+ type.Interfaces[i] = this.GetTypeReference(type.Interfaces[i], module);
+
+ // rewrite events
+ foreach (EventDefinition @event in type.Events)
+ {
+ this.RewriteReferences(@event.AddMethod, module);
+ this.RewriteReferences(@event.RemoveMethod, module);
+ this.RewriteReferences(@event.InvokeMethod, module);
+ }
+
+ // rewrite properties
+ foreach (PropertyDefinition property in type.Properties)
+ {
+ this.RewriteReferences(property.GetMethod, module);
+ this.RewriteReferences(property.SetMethod, module);
+ }
+
+ // rewrite methods
+ foreach (MethodDefinition method in type.Methods)
+ this.RewriteReferences(method, module);
+
+ // rewrite fields
+ foreach (FieldDefinition field in type.Fields)
+ this.RewriteReferences(field, module);
+
+ // rewrite nested types
+ foreach (TypeDefinition nestedType in type.NestedTypes)
+ this.RewriteReferences(nestedType, module);
+
+ // rewrite generic parameters
+ foreach (GenericParameter parameter in type.GenericParameters)
+ this.RewriteReferences(parameter, module);
+
+ module.Import(type);
+ }
+
+ /// Rewrite the references for a code object.
+ /// The method to rewrite.
+ /// The module being rewritten.
+ private void RewriteReferences(MethodReference method, ModuleDefinition module)
+ {
+ // parameter types
+ if (method.HasParameters)
+ {
+ foreach (ParameterDefinition parameter in method.Parameters)
+ parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
+ }
+
+ // return type
+ method.MethodReturnType.ReturnType = this.GetTypeReference(method.MethodReturnType.ReturnType, module);
+
+ module.Import(method);
+ }
+
+ /// Rewrite the references for a code object.
+ /// The method to rewrite.
+ /// The module being rewritten.
+ private void RewriteReferences(MethodDefinition method, ModuleDefinition module)
+ {
+ if (method == null)
+ return;
+
+ this.RewriteReferences((MethodReference)method, module);
+
+ // overrides
+ foreach (MethodReference @override in method.Overrides)
+ this.RewriteReferences(@override, module);
+
+ // body
+ if (method.HasBody)
+ {
+ // this
+ if (method.Body.ThisParameter != null)
+ method.Body.ThisParameter.ParameterType = this.GetTypeReference(method.Body.ThisParameter.ParameterType, module);
+
+ // variables
+ if (method.Body.HasVariables)
+ {
+ foreach (VariableDefinition variable in method.Body.Variables)
+ variable.VariableType = this.GetTypeReference(variable.VariableType, module);
+ }
+
+ // instructions
+ foreach (Instruction instruction in method.Body.Instructions)
+ {
+ object operand = instruction.Operand;
+
+ // type
+ {
+ TypeReference type = operand as TypeReference;
+ if (type != null)
+ {
+ instruction.Operand = this.GetTypeReference(type, module);
+ continue;
+ }
+ }
+
+ // method
+ {
+ MethodReference methodRef = operand as MethodReference;
+ if (methodRef != null)
+ {
+ this.RewriteReferences(methodRef, module);
+ continue;
+ }
+ }
+
+ // field
+ {
+ FieldReference field = operand as FieldReference;
+ if (field != null)
+ {
+ this.RewriteReferences(field, module);
+ continue;
+ }
+ }
+
+ // variable
+ {
+ VariableDefinition variable = operand as VariableDefinition;
+ if (variable != null)
+ {
+ variable.VariableType = this.GetTypeReference(variable.VariableType, module);
+ continue;
+ }
+ }
+
+ // parameter
+ {
+ ParameterDefinition parameter = operand as ParameterDefinition;
+ if (parameter != null)
+ {
+ parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
+ continue;
+ }
+ }
+
+ // call site
+ {
+ CallSite call = operand as CallSite;
+ if (call != null)
+ {
+ foreach (ParameterDefinition parameter in call.Parameters)
+ parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
+ call.ReturnType = this.GetTypeReference(call.ReturnType, module);
+ }
+ }
+ }
+ }
+
+ module.Import(method);
+ }
+
+ /// Rewrite the references for a code object.
+ /// The generic parameter to rewrite.
+ /// The module being rewritten.
+ private void RewriteReferences(GenericParameter parameter, ModuleDefinition module)
+ {
+ // constraints
+ for (int i = 0; i < parameter.Constraints.Count; i++)
+ parameter.Constraints[i] = this.GetTypeReference(parameter.Constraints[i], module);
+
+ // generic parameters
+ foreach (GenericParameter genericParam in parameter.GenericParameters)
+ this.RewriteReferences(genericParam, module);
+ }
+
+ /// Rewrite the references for a code object.
+ /// The field to rewrite.
+ /// The module being rewritten.
+ private void RewriteReferences(FieldReference field, ModuleDefinition module)
+ {
+ field.DeclaringType = this.GetTypeReference(field.DeclaringType, module);
+ field.FieldType = this.GetTypeReference(field.FieldType, module);
+ module.Import(field);
+ }
+
+ /// Get the correct reference to use for compatibility with the current platform.
+ /// The type reference to rewrite.
+ /// The module being rewritten.
+ private TypeReference GetTypeReference(TypeReference type, ModuleDefinition module)
+ {
+ // check skip conditions
+ if (type == null)
+ return null;
+ if (type.FullName.StartsWith("System."))
+ return type;
+
+ // get assembly
+ Assembly assembly;
+ if (!this.TypeAssemblies.TryGetValue(type.FullName, out assembly))
+ return type;
+
+ // replace type
+ AssemblyNameReference newAssembly = this.AssemblyNameReferences[assembly];
+ ModuleDefinition newModule = this.AssemblyModules[assembly];
+ type = new TypeReference(type.Namespace, type.Name, newModule, newAssembly);
+
+ return module.Import(type);
+ }
+ }
+}
diff --git a/src/StardewModdingAPI/Framework/ModAssemblyLoader.cs b/src/StardewModdingAPI/Framework/ModAssemblyLoader.cs
index bde23e3b..4e59bb08 100644
--- a/src/StardewModdingAPI/Framework/ModAssemblyLoader.cs
+++ b/src/StardewModdingAPI/Framework/ModAssemblyLoader.cs
@@ -17,6 +17,9 @@ namespace StardewModdingAPI.Framework
/// The directory in which to cache data.
private readonly string CacheDirPath;
+ /// Rewrites assembly types to match the current platform.
+ private readonly AssemblyTypeRewriter AssemblyTypeRewriter;
+
/// Encapsulates monitoring and logging for a given module.
private readonly IMonitor Monitor;
@@ -32,6 +35,7 @@ namespace StardewModdingAPI.Framework
{
this.CacheDirPath = cacheDirPath;
this.Monitor = monitor;
+ this.AssemblyTypeRewriter = this.GetAssemblyRewriter(targetPlatform);
}
/// Preprocess an assembly and cache the modified version.
@@ -52,14 +56,17 @@ namespace StardewModdingAPI.Framework
this.Monitor.Log($"Preprocessing new assembly {assemblyPath}...");
// read assembly definition
- AssemblyDefinition definition;
+ AssemblyDefinition assembly;
using (Stream readStream = new MemoryStream(assemblyBytes))
- definition = AssemblyDefinition.ReadAssembly(readStream);
+ assembly = AssemblyDefinition.ReadAssembly(readStream);
+
+ // rewrite assembly to match platform
+ this.AssemblyTypeRewriter.RewriteAssembly(assembly);
// write cache
using (MemoryStream outStream = new MemoryStream())
{
- definition.Write(outStream);
+ assembly.Write(outStream);
byte[] outBytes = outStream.ToArray();
Directory.CreateDirectory(cachePaths.Directory);
File.WriteAllBytes(cachePaths.Assembly, outBytes);
@@ -92,5 +99,51 @@ namespace StardewModdingAPI.Framework
string cacheHashPath = Path.Combine(dirPath, $"{key}.hash");
return new CachePaths(dirPath, cacheAssemblyPath, cacheHashPath);
}
+
+ /// Get an assembly rewriter for the target platform.
+ /// The target game platform.
+ private AssemblyTypeRewriter GetAssemblyRewriter(Platform targetPlatform)
+ {
+ // get assembly changes needed for platform
+ string[] removeAssemblyReferences;
+ Assembly[] targetAssemblies;
+ switch (targetPlatform)
+ {
+ case Platform.Mono:
+ removeAssemblyReferences = new[]
+ {
+ "Stardew Valley",
+ "Microsoft.Xna.Framework",
+ "Microsoft.Xna.Framework.Game",
+ "Microsoft.Xna.Framework.Graphics"
+ };
+ targetAssemblies = new[]
+ {
+ typeof(StardewValley.Game1).Assembly,
+ typeof(Microsoft.Xna.Framework.Vector2).Assembly
+ };
+ break;
+
+ case Platform.Windows:
+ removeAssemblyReferences = new[]
+ {
+ "StardewValley",
+ "MonoGame.Framework"
+ };
+ targetAssemblies = new[]
+ {
+ typeof(StardewValley.Game1).Assembly,
+ typeof(Microsoft.Xna.Framework.Vector2).Assembly,
+ typeof(Microsoft.Xna.Framework.Game).Assembly,
+ typeof(Microsoft.Xna.Framework.Graphics.SpriteBatch).Assembly
+ };
+ break;
+
+ default:
+ throw new InvalidOperationException($"Unknown target platform '{targetPlatform}'.");
+ }
+
+ return new AssemblyTypeRewriter(targetAssemblies, removeAssemblyReferences);
+ }
}
}
diff --git a/src/StardewModdingAPI/StardewModdingAPI.csproj b/src/StardewModdingAPI/StardewModdingAPI.csproj
index c835df42..2abcdc23 100644
--- a/src/StardewModdingAPI/StardewModdingAPI.csproj
+++ b/src/StardewModdingAPI/StardewModdingAPI.csproj
@@ -214,6 +214,7 @@
+