SMAPI/ModLoader/Harmony/MethodPatcher.cs

419 lines
14 KiB
C#

using Harmony.ILCopying;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
namespace Harmony
{
public static class MethodPatcher
{
// special parameter names that can be used in prefix and postfix methods
//
public static string INSTANCE_PARAM = "__instance";
public static string ORIGINAL_METHOD_PARAM = "__originalMethod";
public static string RESULT_VAR = "__result";
public static string STATE_VAR = "__state";
public static string PARAM_INDEX_PREFIX = "__";
public static string INSTANCE_FIELD_PREFIX = "___";
// in case of trouble, set to true to write dynamic method to desktop as a dll
// won't work for all methods because of the inability to extend a type compared
// to the way DynamicTools.CreateDynamicMethod works
//
static readonly bool DEBUG_METHOD_GENERATION_BY_DLL_CREATION = false;
// for fixing old harmony bugs
[UpgradeToLatestVersion(1)]
public static DynamicMethod CreatePatchedMethod(MethodBase original, List<MethodInfo> prefixes, List<MethodInfo> postfixes, List<MethodInfo> transpilers)
{
return CreatePatchedMethod(original, "HARMONY_PATCH_1.1.1", prefixes, postfixes, transpilers);
}
public static DynamicMethod CreatePatchedMethod(MethodBase original, string harmonyInstanceID, List<MethodInfo> prefixes, List<MethodInfo> postfixes, List<MethodInfo> transpilers)
{
try
{
if (HarmonyInstance.DEBUG) FileLog.LogBuffered("### Patch " + original.DeclaringType + ", " + original);
var idx = prefixes.Count() + postfixes.Count();
var patch = DynamicTools.CreateDynamicMethod(original, "_Patch" + idx);
if (patch == null)
return null;
var il = patch.GetILGenerator();
// for debugging
AssemblyBuilder assemblyBuilder = null;
TypeBuilder typeBuilder = null;
if (DEBUG_METHOD_GENERATION_BY_DLL_CREATION)
il = DynamicTools.CreateSaveableMethod(original, "_Patch" + idx, out assemblyBuilder, out typeBuilder);
var originalVariables = DynamicTools.DeclareLocalVariables(original, il);
var privateVars = new Dictionary<string, LocalBuilder>();
LocalBuilder resultVariable = null;
if (idx > 0)
{
resultVariable = DynamicTools.DeclareLocalVariable(il, AccessTools.GetReturnedType(original));
privateVars[RESULT_VAR] = resultVariable;
}
prefixes.ForEach(prefix =>
{
prefix.GetParameters()
.Where(patchParam => patchParam.Name == STATE_VAR)
.Do(patchParam =>
{
var privateStateVariable = DynamicTools.DeclareLocalVariable(il, patchParam.ParameterType);
privateVars[prefix.DeclaringType.FullName] = privateStateVariable;
});
});
var skipOriginalLabel = il.DefineLabel();
var canHaveJump = AddPrefixes(il, original, prefixes, privateVars, skipOriginalLabel);
var copier = new MethodCopier(original, il, originalVariables);
foreach (var transpiler in transpilers)
copier.AddTranspiler(transpiler);
var endLabels = new List<Label>();
var endBlocks = new List<ExceptionBlock>();
copier.Finalize(endLabels, endBlocks);
foreach (var label in endLabels)
Emitter.MarkLabel(il, label);
foreach (var block in endBlocks)
Emitter.MarkBlockAfter(il, block);
if (resultVariable != null)
Emitter.Emit(il, OpCodes.Stloc, resultVariable);
if (canHaveJump)
Emitter.MarkLabel(il, skipOriginalLabel);
AddPostfixes(il, original, postfixes, privateVars, false);
if (resultVariable != null)
Emitter.Emit(il, OpCodes.Ldloc, resultVariable);
AddPostfixes(il, original, postfixes, privateVars, true);
Emitter.Emit(il, OpCodes.Ret);
if (HarmonyInstance.DEBUG)
{
FileLog.LogBuffered("DONE");
FileLog.LogBuffered("");
FileLog.FlushBuffer();
}
// for debugging
if (DEBUG_METHOD_GENERATION_BY_DLL_CREATION)
{
DynamicTools.SaveMethod(assemblyBuilder, typeBuilder);
return null;
}
DynamicTools.PrepareDynamicMethod(patch);
return patch;
}
catch (Exception ex)
{
throw new Exception("Exception from HarmonyInstance \"" + harmonyInstanceID + "\"", ex);
}
finally
{
if (HarmonyInstance.DEBUG)
FileLog.FlushBuffer();
}
}
static OpCode LoadIndOpCodeFor(Type type)
{
if (type.IsEnum) return OpCodes.Ldind_I4;
if (type == typeof(float)) return OpCodes.Ldind_R4;
if (type == typeof(double)) return OpCodes.Ldind_R8;
if (type == typeof(byte)) return OpCodes.Ldind_U1;
if (type == typeof(ushort)) return OpCodes.Ldind_U2;
if (type == typeof(uint)) return OpCodes.Ldind_U4;
if (type == typeof(ulong)) return OpCodes.Ldind_I8;
if (type == typeof(sbyte)) return OpCodes.Ldind_I1;
if (type == typeof(short)) return OpCodes.Ldind_I2;
if (type == typeof(int)) return OpCodes.Ldind_I4;
if (type == typeof(long)) return OpCodes.Ldind_I8;
return OpCodes.Ldind_Ref;
}
static HarmonyArgument GetArgumentAttribute(this ParameterInfo parameter)
{
return parameter.GetCustomAttributes(false).FirstOrDefault(attr => attr is HarmonyArgument) as HarmonyArgument;
}
static HarmonyArgument[] GetArgumentAttributes(this MethodInfo method)
{
return method.GetCustomAttributes(false).Where(attr => attr is HarmonyArgument).Cast<HarmonyArgument>().ToArray();
}
static HarmonyArgument[] GetArgumentAttributes(this Type type)
{
return type.GetCustomAttributes(false).Where(attr => attr is HarmonyArgument).Cast<HarmonyArgument>().ToArray();
}
static string GetOriginalArgumentName(this ParameterInfo parameter, string[] originalParameterNames)
{
var attribute = parameter.GetArgumentAttribute();
if (attribute == null)
return null;
if (string.IsNullOrEmpty(attribute.OriginalName) == false)
return attribute.OriginalName;
if (attribute.Index >= 0 && attribute.Index < originalParameterNames.Length)
return originalParameterNames[attribute.Index];
return null;
}
static string GetOriginalArgumentName(HarmonyArgument[] attributes, string name, string[] originalParameterNames)
{
if (attributes.Length <= 0)
return null;
var attribute = attributes.SingleOrDefault(p => p.NewName == name);
if (attribute == null)
return null;
if (string.IsNullOrEmpty(attribute.OriginalName) == false)
return attribute.OriginalName;
if (attribute.Index >= 0 && attribute.Index < originalParameterNames.Length)
return originalParameterNames[attribute.Index];
return null;
}
static string GetOriginalArgumentName(this MethodInfo method, string[] originalParameterNames, string name)
{
string argumentName;
argumentName = GetOriginalArgumentName(method.GetArgumentAttributes(), name, originalParameterNames);
if (argumentName != null)
return argumentName;
argumentName = GetOriginalArgumentName(method.DeclaringType.GetArgumentAttributes(), name, originalParameterNames);
if (argumentName != null)
return argumentName;
return name;
}
private static int GetArgumentIndex(MethodInfo patch, string[] originalParameterNames, ParameterInfo patchParam)
{
var originalName = patchParam.GetOriginalArgumentName(originalParameterNames);
if (originalName != null)
return Array.IndexOf(originalParameterNames, originalName);
var patchParamName = patchParam.Name;
originalName = patch.GetOriginalArgumentName(originalParameterNames, patchParamName);
if (originalName != null)
return Array.IndexOf(originalParameterNames, originalName);
return -1;
}
static MethodInfo getMethodMethod = typeof(MethodBase).GetMethod("GetMethodFromHandle", new[] { typeof(RuntimeMethodHandle) });
static void EmitCallParameter(ILGenerator il, MethodBase original, MethodInfo patch, Dictionary<string, LocalBuilder> variables, bool allowFirsParamPassthrough)
{
var isInstance = original.IsStatic == false;
var originalParameters = original.GetParameters();
var originalParameterNames = originalParameters.Select(p => p.Name).ToArray();
// check for passthrough using first parameter (which must have same type as return type)
var parameters = patch.GetParameters().ToList();
if (allowFirsParamPassthrough && patch.ReturnType != typeof(void) && parameters.Count > 0 && parameters[0].ParameterType == patch.ReturnType)
parameters.RemoveRange(0, 1);
foreach (var patchParam in parameters)
{
if (patchParam.Name == ORIGINAL_METHOD_PARAM)
{
var constructorInfo = original as ConstructorInfo;
if (constructorInfo != null)
{
Emitter.Emit(il, OpCodes.Ldtoken, constructorInfo);
Emitter.Emit(il, OpCodes.Call, getMethodMethod);
continue;
}
var methodInfo = original as MethodInfo;
if (methodInfo != null)
{
Emitter.Emit(il, OpCodes.Ldtoken, methodInfo);
Emitter.Emit(il, OpCodes.Call, getMethodMethod);
continue;
}
Emitter.Emit(il, OpCodes.Ldnull);
continue;
}
if (patchParam.Name == INSTANCE_PARAM)
{
if (original.IsStatic)
Emitter.Emit(il, OpCodes.Ldnull);
else if (patchParam.ParameterType.IsByRef)
Emitter.Emit(il, OpCodes.Ldarga, 0); // probably won't work or will be useless
else
Emitter.Emit(il, OpCodes.Ldarg_0);
continue;
}
if (patchParam.Name.StartsWith(INSTANCE_FIELD_PREFIX))
{
var fieldName = patchParam.Name.Substring(INSTANCE_FIELD_PREFIX.Length);
FieldInfo fieldInfo;
if (fieldName.All(char.IsDigit))
{
fieldInfo = AccessTools.Field(original.DeclaringType, int.Parse(fieldName));
if (fieldInfo == null)
throw new ArgumentException("No field found at given index in class " + original.DeclaringType.FullName, fieldName);
}
else
{
fieldInfo = AccessTools.Field(original.DeclaringType, fieldName);
if (fieldInfo == null)
throw new ArgumentException("No such field defined in class " + original.DeclaringType.FullName, fieldName);
}
if (fieldInfo.IsStatic)
{
if (patchParam.ParameterType.IsByRef)
Emitter.Emit(il, OpCodes.Ldsflda, fieldInfo);
else
Emitter.Emit(il, OpCodes.Ldsfld, fieldInfo);
}
else
{
if (patchParam.ParameterType.IsByRef)
{
Emitter.Emit(il, OpCodes.Ldarg_0);
Emitter.Emit(il, OpCodes.Ldflda, fieldInfo);
}
else
{
Emitter.Emit(il, OpCodes.Ldarg_0);
Emitter.Emit(il, OpCodes.Ldfld, fieldInfo);
}
}
continue;
}
if (patchParam.Name == STATE_VAR)
{
var ldlocCode = patchParam.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc;
Emitter.Emit(il, ldlocCode, variables[patch.DeclaringType.FullName]);
continue;
}
if (patchParam.Name == RESULT_VAR)
{
if (AccessTools.GetReturnedType(original) == typeof(void))
throw new Exception("Cannot get result from void method " + original.FullDescription());
var ldlocCode = patchParam.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc;
Emitter.Emit(il, ldlocCode, variables[RESULT_VAR]);
continue;
}
int idx;
if (patchParam.Name.StartsWith(PARAM_INDEX_PREFIX))
{
var val = patchParam.Name.Substring(PARAM_INDEX_PREFIX.Length);
if (!int.TryParse(val, out idx))
throw new Exception("Parameter " + patchParam.Name + " does not contain a valid index");
if (idx < 0 || idx >= originalParameters.Length)
throw new Exception("No parameter found at index " + idx);
}
else
{
idx = GetArgumentIndex(patch, originalParameterNames, patchParam);
if (idx == -1) throw new Exception("Parameter \"" + patchParam.Name + "\" not found in method " + original.FullDescription());
}
// original -> patch opcode
// --------------------------------------
// 1 normal -> normal : LDARG
// 2 normal -> ref/out : LDARGA
// 3 ref/out -> normal : LDARG, LDIND_x
// 4 ref/out -> ref/out : LDARG
//
var originalIsNormal = originalParameters[idx].IsOut == false && originalParameters[idx].ParameterType.IsByRef == false;
var patchIsNormal = patchParam.IsOut == false && patchParam.ParameterType.IsByRef == false;
var patchArgIndex = idx + (isInstance ? 1 : 0);
// Case 1 + 4
if (originalIsNormal == patchIsNormal)
{
Emitter.Emit(il, OpCodes.Ldarg, patchArgIndex);
continue;
}
// Case 2
if (originalIsNormal && patchIsNormal == false)
{
Emitter.Emit(il, OpCodes.Ldarga, patchArgIndex);
continue;
}
// Case 3
Emitter.Emit(il, OpCodes.Ldarg, patchArgIndex);
Emitter.Emit(il, LoadIndOpCodeFor(originalParameters[idx].ParameterType));
}
}
static bool AddPrefixes(ILGenerator il, MethodBase original, List<MethodInfo> prefixes, Dictionary<string, LocalBuilder> variables, Label label)
{
var canHaveJump = false;
prefixes.ForEach(fix =>
{
EmitCallParameter(il, original, fix, variables, false);
Emitter.Emit(il, OpCodes.Call, fix);
if (fix.ReturnType != typeof(void))
{
if (fix.ReturnType != typeof(bool))
throw new Exception("Prefix patch " + fix + " has not \"bool\" or \"void\" return type: " + fix.ReturnType);
Emitter.Emit(il, OpCodes.Brfalse, label);
canHaveJump = true;
}
});
return canHaveJump;
}
static void AddPostfixes(ILGenerator il, MethodBase original, List<MethodInfo> postfixes, Dictionary<string, LocalBuilder> variables, bool passthroughPatches)
{
postfixes
.Where(fix => passthroughPatches == (fix.ReturnType != typeof(void)))
.Do(fix =>
{
EmitCallParameter(il, original, fix, variables, true);
Emitter.Emit(il, OpCodes.Call, fix);
if (fix.ReturnType != typeof(void))
{
var firstFixParam = fix.GetParameters().FirstOrDefault();
var hasPassThroughResultParam = firstFixParam != null && fix.ReturnType == firstFixParam.ParameterType;
if (!hasPassThroughResultParam)
{
if (firstFixParam != null)
throw new Exception("Return type of postfix patch " + fix + " does match type of its first parameter");
throw new Exception("Postfix patch " + fix + " must have a \"void\" return type");
}
}
});
}
}
}