rewrite type references in mod assemblies to match target platform (#166)
This commit is contained in:
parent
4df1999855
commit
b06aed66c4
|
@ -0,0 +1,318 @@
|
|||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using Mono.Cecil;
|
||||
using Mono.Cecil.Cil;
|
||||
using CallSite = Mono.Cecil.CallSite;
|
||||
|
||||
namespace StardewModdingAPI.Framework.AssemblyRewriting
|
||||
{
|
||||
/// <summary>Rewrites type references.</summary>
|
||||
internal class AssemblyTypeRewriter
|
||||
{
|
||||
/*********
|
||||
** Properties
|
||||
*********/
|
||||
/// <summary>The assemblies to target. Equivalent types will be rewritten to use these assemblies.</summary>
|
||||
private readonly Assembly[] TargetAssemblies;
|
||||
|
||||
/// <summary>>The short assembly names to remove as assembly reference, and replace with the <see cref="TargetAssemblies"/>.</summary>
|
||||
private readonly string[] RemoveAssemblyNames;
|
||||
|
||||
/// <summary>A type => assembly lookup for types which should be rewritten.</summary>
|
||||
private readonly IDictionary<string, Assembly> TypeAssemblies;
|
||||
|
||||
/// <summary>An assembly => reference cache.</summary>
|
||||
private readonly IDictionary<Assembly, AssemblyNameReference> AssemblyNameReferences;
|
||||
|
||||
/// <summary>An assembly => module cache.</summary>
|
||||
private readonly IDictionary<Assembly, ModuleDefinition> AssemblyModules;
|
||||
|
||||
|
||||
/*********
|
||||
** Public methods
|
||||
*********/
|
||||
/// <summary>Construct an instance.</summary>
|
||||
/// <param name="targetAssemblies">The assembly filenames to target. Equivalent types will be rewritten to use these assemblies.</param>
|
||||
/// <param name="removeAssemblyNames">The short assembly names to remove as assembly reference, and replace with the <paramref name="targetAssemblies"/>.</param>
|
||||
public AssemblyTypeRewriter(Assembly[] targetAssemblies, string[] removeAssemblyNames)
|
||||
{
|
||||
// save config
|
||||
this.TargetAssemblies = targetAssemblies;
|
||||
this.RemoveAssemblyNames = removeAssemblyNames;
|
||||
|
||||
// cache assembly metadata
|
||||
this.AssemblyNameReferences = targetAssemblies.ToDictionary(assembly => assembly, assembly => AssemblyNameReference.Parse(assembly.FullName));
|
||||
this.AssemblyModules = targetAssemblies.ToDictionary(assembly => assembly, assembly => ModuleDefinition.ReadModule(assembly.Modules.Single().FullyQualifiedName)); // technically an assembly can contain multiple modules, but none of the build tools (including MSBuild itself) support it
|
||||
|
||||
// collect type => assembly lookup
|
||||
this.TypeAssemblies = new Dictionary<string, Assembly>();
|
||||
foreach (Assembly assembly in targetAssemblies)
|
||||
{
|
||||
ModuleDefinition module = this.AssemblyModules[assembly];
|
||||
foreach (TypeDefinition type in module.GetTypes())
|
||||
{
|
||||
if (!type.IsPublic)
|
||||
continue; // no need to rewrite
|
||||
if (type.Namespace.Contains("<"))
|
||||
continue; // ignore C++ stuff
|
||||
this.TypeAssemblies[type.FullName] = assembly;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Rewrite the types referenced by an assembly.</summary>
|
||||
/// <param name="assembly">The assembly to rewrite.</param>
|
||||
public void RewriteAssembly(AssemblyDefinition assembly)
|
||||
{
|
||||
foreach (ModuleDefinition module in assembly.Modules)
|
||||
{
|
||||
// rewrite assembly references
|
||||
bool shouldRewriteTypes = false;
|
||||
for (int i = 0; i < module.AssemblyReferences.Count; i++)
|
||||
{
|
||||
bool shouldRemove = this.RemoveAssemblyNames.Any(name => module.AssemblyReferences[i].Name == name) || this.TargetAssemblies.Any(a => module.AssemblyReferences[i].Name == a.GetName().Name);
|
||||
if (shouldRemove)
|
||||
{
|
||||
shouldRewriteTypes = true;
|
||||
module.AssemblyReferences.RemoveAt(i);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
foreach (AssemblyNameReference target in this.AssemblyNameReferences.Values)
|
||||
{
|
||||
module.AssemblyReferences.Add(target);
|
||||
shouldRewriteTypes = true;
|
||||
}
|
||||
|
||||
// rewrite references
|
||||
if (shouldRewriteTypes)
|
||||
{
|
||||
// rewrite types
|
||||
foreach (TypeDefinition type in module.GetTypes())
|
||||
this.RewriteReferences(type, module);
|
||||
|
||||
// rewrite type references
|
||||
TypeReference[] refs = (TypeReference[])module.GetTypeReferences();
|
||||
for (int i = 0; i < refs.Length; ++i)
|
||||
refs[i] = this.GetTypeReference(refs[i], module);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/*********
|
||||
** Private methods
|
||||
*********/
|
||||
/// <summary>Rewrite the references for a code object.</summary>
|
||||
/// <param name="type">The type to rewrite.</param>
|
||||
/// <param name="module">The module being rewritten.</param>
|
||||
private void RewriteReferences(TypeDefinition type, ModuleDefinition module)
|
||||
{
|
||||
// rewrite base type
|
||||
type.BaseType = this.GetTypeReference(type.BaseType, module);
|
||||
|
||||
// rewrite interfaces
|
||||
for (int i = 0; i < type.Interfaces.Count; i++)
|
||||
type.Interfaces[i] = this.GetTypeReference(type.Interfaces[i], module);
|
||||
|
||||
// rewrite events
|
||||
foreach (EventDefinition @event in type.Events)
|
||||
{
|
||||
this.RewriteReferences(@event.AddMethod, module);
|
||||
this.RewriteReferences(@event.RemoveMethod, module);
|
||||
this.RewriteReferences(@event.InvokeMethod, module);
|
||||
}
|
||||
|
||||
// rewrite properties
|
||||
foreach (PropertyDefinition property in type.Properties)
|
||||
{
|
||||
this.RewriteReferences(property.GetMethod, module);
|
||||
this.RewriteReferences(property.SetMethod, module);
|
||||
}
|
||||
|
||||
// rewrite methods
|
||||
foreach (MethodDefinition method in type.Methods)
|
||||
this.RewriteReferences(method, module);
|
||||
|
||||
// rewrite fields
|
||||
foreach (FieldDefinition field in type.Fields)
|
||||
this.RewriteReferences(field, module);
|
||||
|
||||
// rewrite nested types
|
||||
foreach (TypeDefinition nestedType in type.NestedTypes)
|
||||
this.RewriteReferences(nestedType, module);
|
||||
|
||||
// rewrite generic parameters
|
||||
foreach (GenericParameter parameter in type.GenericParameters)
|
||||
this.RewriteReferences(parameter, module);
|
||||
|
||||
module.Import(type);
|
||||
}
|
||||
|
||||
/// <summary>Rewrite the references for a code object.</summary>
|
||||
/// <param name="method">The method to rewrite.</param>
|
||||
/// <param name="module">The module being rewritten.</param>
|
||||
private void RewriteReferences(MethodReference method, ModuleDefinition module)
|
||||
{
|
||||
// parameter types
|
||||
if (method.HasParameters)
|
||||
{
|
||||
foreach (ParameterDefinition parameter in method.Parameters)
|
||||
parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
|
||||
}
|
||||
|
||||
// return type
|
||||
method.MethodReturnType.ReturnType = this.GetTypeReference(method.MethodReturnType.ReturnType, module);
|
||||
|
||||
module.Import(method);
|
||||
}
|
||||
|
||||
/// <summary>Rewrite the references for a code object.</summary>
|
||||
/// <param name="method">The method to rewrite.</param>
|
||||
/// <param name="module">The module being rewritten.</param>
|
||||
private void RewriteReferences(MethodDefinition method, ModuleDefinition module)
|
||||
{
|
||||
if (method == null)
|
||||
return;
|
||||
|
||||
this.RewriteReferences((MethodReference)method, module);
|
||||
|
||||
// overrides
|
||||
foreach (MethodReference @override in method.Overrides)
|
||||
this.RewriteReferences(@override, module);
|
||||
|
||||
// body
|
||||
if (method.HasBody)
|
||||
{
|
||||
// this
|
||||
if (method.Body.ThisParameter != null)
|
||||
method.Body.ThisParameter.ParameterType = this.GetTypeReference(method.Body.ThisParameter.ParameterType, module);
|
||||
|
||||
// variables
|
||||
if (method.Body.HasVariables)
|
||||
{
|
||||
foreach (VariableDefinition variable in method.Body.Variables)
|
||||
variable.VariableType = this.GetTypeReference(variable.VariableType, module);
|
||||
}
|
||||
|
||||
// instructions
|
||||
foreach (Instruction instruction in method.Body.Instructions)
|
||||
{
|
||||
object operand = instruction.Operand;
|
||||
|
||||
// type
|
||||
{
|
||||
TypeReference type = operand as TypeReference;
|
||||
if (type != null)
|
||||
{
|
||||
instruction.Operand = this.GetTypeReference(type, module);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// method
|
||||
{
|
||||
MethodReference methodRef = operand as MethodReference;
|
||||
if (methodRef != null)
|
||||
{
|
||||
this.RewriteReferences(methodRef, module);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// field
|
||||
{
|
||||
FieldReference field = operand as FieldReference;
|
||||
if (field != null)
|
||||
{
|
||||
this.RewriteReferences(field, module);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// variable
|
||||
{
|
||||
VariableDefinition variable = operand as VariableDefinition;
|
||||
if (variable != null)
|
||||
{
|
||||
variable.VariableType = this.GetTypeReference(variable.VariableType, module);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// parameter
|
||||
{
|
||||
ParameterDefinition parameter = operand as ParameterDefinition;
|
||||
if (parameter != null)
|
||||
{
|
||||
parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// call site
|
||||
{
|
||||
CallSite call = operand as CallSite;
|
||||
if (call != null)
|
||||
{
|
||||
foreach (ParameterDefinition parameter in call.Parameters)
|
||||
parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
|
||||
call.ReturnType = this.GetTypeReference(call.ReturnType, module);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.Import(method);
|
||||
}
|
||||
|
||||
/// <summary>Rewrite the references for a code object.</summary>
|
||||
/// <param name="parameter">The generic parameter to rewrite.</param>
|
||||
/// <param name="module">The module being rewritten.</param>
|
||||
private void RewriteReferences(GenericParameter parameter, ModuleDefinition module)
|
||||
{
|
||||
// constraints
|
||||
for (int i = 0; i < parameter.Constraints.Count; i++)
|
||||
parameter.Constraints[i] = this.GetTypeReference(parameter.Constraints[i], module);
|
||||
|
||||
// generic parameters
|
||||
foreach (GenericParameter genericParam in parameter.GenericParameters)
|
||||
this.RewriteReferences(genericParam, module);
|
||||
}
|
||||
|
||||
/// <summary>Rewrite the references for a code object.</summary>
|
||||
/// <param name="field">The field to rewrite.</param>
|
||||
/// <param name="module">The module being rewritten.</param>
|
||||
private void RewriteReferences(FieldReference field, ModuleDefinition module)
|
||||
{
|
||||
field.DeclaringType = this.GetTypeReference(field.DeclaringType, module);
|
||||
field.FieldType = this.GetTypeReference(field.FieldType, module);
|
||||
module.Import(field);
|
||||
}
|
||||
|
||||
/// <summary>Get the correct reference to use for compatibility with the current platform.</summary>
|
||||
/// <param name="type">The type reference to rewrite.</param>
|
||||
/// <param name="module">The module being rewritten.</param>
|
||||
private TypeReference GetTypeReference(TypeReference type, ModuleDefinition module)
|
||||
{
|
||||
// check skip conditions
|
||||
if (type == null)
|
||||
return null;
|
||||
if (type.FullName.StartsWith("System."))
|
||||
return type;
|
||||
|
||||
// get assembly
|
||||
Assembly assembly;
|
||||
if (!this.TypeAssemblies.TryGetValue(type.FullName, out assembly))
|
||||
return type;
|
||||
|
||||
// replace type
|
||||
AssemblyNameReference newAssembly = this.AssemblyNameReferences[assembly];
|
||||
ModuleDefinition newModule = this.AssemblyModules[assembly];
|
||||
type = new TypeReference(type.Namespace, type.Name, newModule, newAssembly);
|
||||
|
||||
return module.Import(type);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,6 +17,9 @@ namespace StardewModdingAPI.Framework
|
|||
/// <summary>The directory in which to cache data.</summary>
|
||||
private readonly string CacheDirPath;
|
||||
|
||||
/// <summary>Rewrites assembly types to match the current platform.</summary>
|
||||
private readonly AssemblyTypeRewriter AssemblyTypeRewriter;
|
||||
|
||||
/// <summary>Encapsulates monitoring and logging for a given module.</summary>
|
||||
private readonly IMonitor Monitor;
|
||||
|
||||
|
@ -32,6 +35,7 @@ namespace StardewModdingAPI.Framework
|
|||
{
|
||||
this.CacheDirPath = cacheDirPath;
|
||||
this.Monitor = monitor;
|
||||
this.AssemblyTypeRewriter = this.GetAssemblyRewriter(targetPlatform);
|
||||
}
|
||||
|
||||
/// <summary>Preprocess an assembly and cache the modified version.</summary>
|
||||
|
@ -52,14 +56,17 @@ namespace StardewModdingAPI.Framework
|
|||
this.Monitor.Log($"Preprocessing new assembly {assemblyPath}...");
|
||||
|
||||
// read assembly definition
|
||||
AssemblyDefinition definition;
|
||||
AssemblyDefinition assembly;
|
||||
using (Stream readStream = new MemoryStream(assemblyBytes))
|
||||
definition = AssemblyDefinition.ReadAssembly(readStream);
|
||||
assembly = AssemblyDefinition.ReadAssembly(readStream);
|
||||
|
||||
// rewrite assembly to match platform
|
||||
this.AssemblyTypeRewriter.RewriteAssembly(assembly);
|
||||
|
||||
// write cache
|
||||
using (MemoryStream outStream = new MemoryStream())
|
||||
{
|
||||
definition.Write(outStream);
|
||||
assembly.Write(outStream);
|
||||
byte[] outBytes = outStream.ToArray();
|
||||
Directory.CreateDirectory(cachePaths.Directory);
|
||||
File.WriteAllBytes(cachePaths.Assembly, outBytes);
|
||||
|
@ -92,5 +99,51 @@ namespace StardewModdingAPI.Framework
|
|||
string cacheHashPath = Path.Combine(dirPath, $"{key}.hash");
|
||||
return new CachePaths(dirPath, cacheAssemblyPath, cacheHashPath);
|
||||
}
|
||||
|
||||
/// <summary>Get an assembly rewriter for the target platform.</summary>
|
||||
/// <param name="targetPlatform">The target game platform.</param>
|
||||
private AssemblyTypeRewriter GetAssemblyRewriter(Platform targetPlatform)
|
||||
{
|
||||
// get assembly changes needed for platform
|
||||
string[] removeAssemblyReferences;
|
||||
Assembly[] targetAssemblies;
|
||||
switch (targetPlatform)
|
||||
{
|
||||
case Platform.Mono:
|
||||
removeAssemblyReferences = new[]
|
||||
{
|
||||
"Stardew Valley",
|
||||
"Microsoft.Xna.Framework",
|
||||
"Microsoft.Xna.Framework.Game",
|
||||
"Microsoft.Xna.Framework.Graphics"
|
||||
};
|
||||
targetAssemblies = new[]
|
||||
{
|
||||
typeof(StardewValley.Game1).Assembly,
|
||||
typeof(Microsoft.Xna.Framework.Vector2).Assembly
|
||||
};
|
||||
break;
|
||||
|
||||
case Platform.Windows:
|
||||
removeAssemblyReferences = new[]
|
||||
{
|
||||
"StardewValley",
|
||||
"MonoGame.Framework"
|
||||
};
|
||||
targetAssemblies = new[]
|
||||
{
|
||||
typeof(StardewValley.Game1).Assembly,
|
||||
typeof(Microsoft.Xna.Framework.Vector2).Assembly,
|
||||
typeof(Microsoft.Xna.Framework.Game).Assembly,
|
||||
typeof(Microsoft.Xna.Framework.Graphics.SpriteBatch).Assembly
|
||||
};
|
||||
break;
|
||||
|
||||
default:
|
||||
throw new InvalidOperationException($"Unknown target platform '{targetPlatform}'.");
|
||||
}
|
||||
|
||||
return new AssemblyTypeRewriter(targetAssemblies, removeAssemblyReferences);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -214,6 +214,7 @@
|
|||
<Compile Include="Events\TimeEvents.cs" />
|
||||
<Compile Include="Extensions.cs" />
|
||||
<Compile Include="Framework\AssemblyRewriting\CachePaths.cs" />
|
||||
<Compile Include="Framework\AssemblyRewriting\AssemblyTypeRewriter.cs" />
|
||||
<Compile Include="Framework\DeprecationLevel.cs" />
|
||||
<Compile Include="Framework\DeprecationManager.cs" />
|
||||
<Compile Include="Framework\InternalExtensions.cs" />
|
||||
|
|
Loading…
Reference in New Issue