diff --git a/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs b/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs new file mode 100644 index 00000000..93003a64 --- /dev/null +++ b/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs @@ -0,0 +1,318 @@ +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Mono.Cecil; +using Mono.Cecil.Cil; +using CallSite = Mono.Cecil.CallSite; + +namespace StardewModdingAPI.Framework.AssemblyRewriting +{ + /// Rewrites type references. + internal class AssemblyTypeRewriter + { + /********* + ** Properties + *********/ + /// The assemblies to target. Equivalent types will be rewritten to use these assemblies. + private readonly Assembly[] TargetAssemblies; + + /// >The short assembly names to remove as assembly reference, and replace with the . + private readonly string[] RemoveAssemblyNames; + + /// A type => assembly lookup for types which should be rewritten. + private readonly IDictionary TypeAssemblies; + + /// An assembly => reference cache. + private readonly IDictionary AssemblyNameReferences; + + /// An assembly => module cache. + private readonly IDictionary AssemblyModules; + + + /********* + ** Public methods + *********/ + /// Construct an instance. + /// The assembly filenames to target. Equivalent types will be rewritten to use these assemblies. + /// The short assembly names to remove as assembly reference, and replace with the . + public AssemblyTypeRewriter(Assembly[] targetAssemblies, string[] removeAssemblyNames) + { + // save config + this.TargetAssemblies = targetAssemblies; + this.RemoveAssemblyNames = removeAssemblyNames; + + // cache assembly metadata + this.AssemblyNameReferences = targetAssemblies.ToDictionary(assembly => assembly, assembly => AssemblyNameReference.Parse(assembly.FullName)); + this.AssemblyModules = targetAssemblies.ToDictionary(assembly => assembly, assembly => ModuleDefinition.ReadModule(assembly.Modules.Single().FullyQualifiedName)); // technically an assembly can contain multiple modules, but none of the build tools (including MSBuild itself) support it + + // collect type => assembly lookup + this.TypeAssemblies = new Dictionary(); + foreach (Assembly assembly in targetAssemblies) + { + ModuleDefinition module = this.AssemblyModules[assembly]; + foreach (TypeDefinition type in module.GetTypes()) + { + if (!type.IsPublic) + continue; // no need to rewrite + if (type.Namespace.Contains("<")) + continue; // ignore C++ stuff + this.TypeAssemblies[type.FullName] = assembly; + } + } + } + + /// Rewrite the types referenced by an assembly. + /// The assembly to rewrite. + public void RewriteAssembly(AssemblyDefinition assembly) + { + foreach (ModuleDefinition module in assembly.Modules) + { + // rewrite assembly references + bool shouldRewriteTypes = false; + for (int i = 0; i < module.AssemblyReferences.Count; i++) + { + bool shouldRemove = this.RemoveAssemblyNames.Any(name => module.AssemblyReferences[i].Name == name) || this.TargetAssemblies.Any(a => module.AssemblyReferences[i].Name == a.GetName().Name); + if (shouldRemove) + { + shouldRewriteTypes = true; + module.AssemblyReferences.RemoveAt(i); + i--; + } + } + foreach (AssemblyNameReference target in this.AssemblyNameReferences.Values) + { + module.AssemblyReferences.Add(target); + shouldRewriteTypes = true; + } + + // rewrite references + if (shouldRewriteTypes) + { + // rewrite types + foreach (TypeDefinition type in module.GetTypes()) + this.RewriteReferences(type, module); + + // rewrite type references + TypeReference[] refs = (TypeReference[])module.GetTypeReferences(); + for (int i = 0; i < refs.Length; ++i) + refs[i] = this.GetTypeReference(refs[i], module); + } + } + } + + + /********* + ** Private methods + *********/ + /// Rewrite the references for a code object. + /// The type to rewrite. + /// The module being rewritten. + private void RewriteReferences(TypeDefinition type, ModuleDefinition module) + { + // rewrite base type + type.BaseType = this.GetTypeReference(type.BaseType, module); + + // rewrite interfaces + for (int i = 0; i < type.Interfaces.Count; i++) + type.Interfaces[i] = this.GetTypeReference(type.Interfaces[i], module); + + // rewrite events + foreach (EventDefinition @event in type.Events) + { + this.RewriteReferences(@event.AddMethod, module); + this.RewriteReferences(@event.RemoveMethod, module); + this.RewriteReferences(@event.InvokeMethod, module); + } + + // rewrite properties + foreach (PropertyDefinition property in type.Properties) + { + this.RewriteReferences(property.GetMethod, module); + this.RewriteReferences(property.SetMethod, module); + } + + // rewrite methods + foreach (MethodDefinition method in type.Methods) + this.RewriteReferences(method, module); + + // rewrite fields + foreach (FieldDefinition field in type.Fields) + this.RewriteReferences(field, module); + + // rewrite nested types + foreach (TypeDefinition nestedType in type.NestedTypes) + this.RewriteReferences(nestedType, module); + + // rewrite generic parameters + foreach (GenericParameter parameter in type.GenericParameters) + this.RewriteReferences(parameter, module); + + module.Import(type); + } + + /// Rewrite the references for a code object. + /// The method to rewrite. + /// The module being rewritten. + private void RewriteReferences(MethodReference method, ModuleDefinition module) + { + // parameter types + if (method.HasParameters) + { + foreach (ParameterDefinition parameter in method.Parameters) + parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module); + } + + // return type + method.MethodReturnType.ReturnType = this.GetTypeReference(method.MethodReturnType.ReturnType, module); + + module.Import(method); + } + + /// Rewrite the references for a code object. + /// The method to rewrite. + /// The module being rewritten. + private void RewriteReferences(MethodDefinition method, ModuleDefinition module) + { + if (method == null) + return; + + this.RewriteReferences((MethodReference)method, module); + + // overrides + foreach (MethodReference @override in method.Overrides) + this.RewriteReferences(@override, module); + + // body + if (method.HasBody) + { + // this + if (method.Body.ThisParameter != null) + method.Body.ThisParameter.ParameterType = this.GetTypeReference(method.Body.ThisParameter.ParameterType, module); + + // variables + if (method.Body.HasVariables) + { + foreach (VariableDefinition variable in method.Body.Variables) + variable.VariableType = this.GetTypeReference(variable.VariableType, module); + } + + // instructions + foreach (Instruction instruction in method.Body.Instructions) + { + object operand = instruction.Operand; + + // type + { + TypeReference type = operand as TypeReference; + if (type != null) + { + instruction.Operand = this.GetTypeReference(type, module); + continue; + } + } + + // method + { + MethodReference methodRef = operand as MethodReference; + if (methodRef != null) + { + this.RewriteReferences(methodRef, module); + continue; + } + } + + // field + { + FieldReference field = operand as FieldReference; + if (field != null) + { + this.RewriteReferences(field, module); + continue; + } + } + + // variable + { + VariableDefinition variable = operand as VariableDefinition; + if (variable != null) + { + variable.VariableType = this.GetTypeReference(variable.VariableType, module); + continue; + } + } + + // parameter + { + ParameterDefinition parameter = operand as ParameterDefinition; + if (parameter != null) + { + parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module); + continue; + } + } + + // call site + { + CallSite call = operand as CallSite; + if (call != null) + { + foreach (ParameterDefinition parameter in call.Parameters) + parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module); + call.ReturnType = this.GetTypeReference(call.ReturnType, module); + } + } + } + } + + module.Import(method); + } + + /// Rewrite the references for a code object. + /// The generic parameter to rewrite. + /// The module being rewritten. + private void RewriteReferences(GenericParameter parameter, ModuleDefinition module) + { + // constraints + for (int i = 0; i < parameter.Constraints.Count; i++) + parameter.Constraints[i] = this.GetTypeReference(parameter.Constraints[i], module); + + // generic parameters + foreach (GenericParameter genericParam in parameter.GenericParameters) + this.RewriteReferences(genericParam, module); + } + + /// Rewrite the references for a code object. + /// The field to rewrite. + /// The module being rewritten. + private void RewriteReferences(FieldReference field, ModuleDefinition module) + { + field.DeclaringType = this.GetTypeReference(field.DeclaringType, module); + field.FieldType = this.GetTypeReference(field.FieldType, module); + module.Import(field); + } + + /// Get the correct reference to use for compatibility with the current platform. + /// The type reference to rewrite. + /// The module being rewritten. + private TypeReference GetTypeReference(TypeReference type, ModuleDefinition module) + { + // check skip conditions + if (type == null) + return null; + if (type.FullName.StartsWith("System.")) + return type; + + // get assembly + Assembly assembly; + if (!this.TypeAssemblies.TryGetValue(type.FullName, out assembly)) + return type; + + // replace type + AssemblyNameReference newAssembly = this.AssemblyNameReferences[assembly]; + ModuleDefinition newModule = this.AssemblyModules[assembly]; + type = new TypeReference(type.Namespace, type.Name, newModule, newAssembly); + + return module.Import(type); + } + } +} diff --git a/src/StardewModdingAPI/Framework/ModAssemblyLoader.cs b/src/StardewModdingAPI/Framework/ModAssemblyLoader.cs index bde23e3b..4e59bb08 100644 --- a/src/StardewModdingAPI/Framework/ModAssemblyLoader.cs +++ b/src/StardewModdingAPI/Framework/ModAssemblyLoader.cs @@ -17,6 +17,9 @@ namespace StardewModdingAPI.Framework /// The directory in which to cache data. private readonly string CacheDirPath; + /// Rewrites assembly types to match the current platform. + private readonly AssemblyTypeRewriter AssemblyTypeRewriter; + /// Encapsulates monitoring and logging for a given module. private readonly IMonitor Monitor; @@ -32,6 +35,7 @@ namespace StardewModdingAPI.Framework { this.CacheDirPath = cacheDirPath; this.Monitor = monitor; + this.AssemblyTypeRewriter = this.GetAssemblyRewriter(targetPlatform); } /// Preprocess an assembly and cache the modified version. @@ -52,14 +56,17 @@ namespace StardewModdingAPI.Framework this.Monitor.Log($"Preprocessing new assembly {assemblyPath}..."); // read assembly definition - AssemblyDefinition definition; + AssemblyDefinition assembly; using (Stream readStream = new MemoryStream(assemblyBytes)) - definition = AssemblyDefinition.ReadAssembly(readStream); + assembly = AssemblyDefinition.ReadAssembly(readStream); + + // rewrite assembly to match platform + this.AssemblyTypeRewriter.RewriteAssembly(assembly); // write cache using (MemoryStream outStream = new MemoryStream()) { - definition.Write(outStream); + assembly.Write(outStream); byte[] outBytes = outStream.ToArray(); Directory.CreateDirectory(cachePaths.Directory); File.WriteAllBytes(cachePaths.Assembly, outBytes); @@ -92,5 +99,51 @@ namespace StardewModdingAPI.Framework string cacheHashPath = Path.Combine(dirPath, $"{key}.hash"); return new CachePaths(dirPath, cacheAssemblyPath, cacheHashPath); } + + /// Get an assembly rewriter for the target platform. + /// The target game platform. + private AssemblyTypeRewriter GetAssemblyRewriter(Platform targetPlatform) + { + // get assembly changes needed for platform + string[] removeAssemblyReferences; + Assembly[] targetAssemblies; + switch (targetPlatform) + { + case Platform.Mono: + removeAssemblyReferences = new[] + { + "Stardew Valley", + "Microsoft.Xna.Framework", + "Microsoft.Xna.Framework.Game", + "Microsoft.Xna.Framework.Graphics" + }; + targetAssemblies = new[] + { + typeof(StardewValley.Game1).Assembly, + typeof(Microsoft.Xna.Framework.Vector2).Assembly + }; + break; + + case Platform.Windows: + removeAssemblyReferences = new[] + { + "StardewValley", + "MonoGame.Framework" + }; + targetAssemblies = new[] + { + typeof(StardewValley.Game1).Assembly, + typeof(Microsoft.Xna.Framework.Vector2).Assembly, + typeof(Microsoft.Xna.Framework.Game).Assembly, + typeof(Microsoft.Xna.Framework.Graphics.SpriteBatch).Assembly + }; + break; + + default: + throw new InvalidOperationException($"Unknown target platform '{targetPlatform}'."); + } + + return new AssemblyTypeRewriter(targetAssemblies, removeAssemblyReferences); + } } } diff --git a/src/StardewModdingAPI/StardewModdingAPI.csproj b/src/StardewModdingAPI/StardewModdingAPI.csproj index c835df42..2abcdc23 100644 --- a/src/StardewModdingAPI/StardewModdingAPI.csproj +++ b/src/StardewModdingAPI/StardewModdingAPI.csproj @@ -214,6 +214,7 @@ +