Support deep cloning for BattleRandom

This commit is contained in:
Deukhoofd 2024-12-29 15:00:15 +01:00
parent 40803f0269
commit 9bdd584b54
Signed by: Deukhoofd
GPG Key ID: F63E044490819F6F
3 changed files with 64 additions and 15 deletions

View File

@ -6,7 +6,7 @@ namespace PkmnLib.Dynamic.Models;
/// <summary> /// <summary>
/// Random number generator for battles. /// Random number generator for battles.
/// </summary> /// </summary>
public interface IBattleRandom : IRandom public interface IBattleRandom : IRandom, IDeepCloneable
{ {
/// <summary> /// <summary>
/// Gets whether or not a move triggers its secondary effect. This takes its chance, and /// Gets whether or not a move triggers its secondary effect. This takes its chance, and

View File

@ -2,6 +2,7 @@ using System.Collections;
using System.Linq.Expressions; using System.Linq.Expressions;
using System.Reflection; using System.Reflection;
using System.Runtime.Serialization; using System.Runtime.Serialization;
using Pcg;
namespace PkmnLib.Static.Utils; namespace PkmnLib.Static.Utils;
@ -25,11 +26,16 @@ public static class DeepCloneHandler
/// references. /// references.
/// </summary> /// </summary>
public static T DeepClone<T>(this T? obj, Dictionary<(Type, int), object>? objects = null) where T : IDeepCloneable public static T DeepClone<T>(this T? obj, Dictionary<(Type, int), object>? objects = null) where T : IDeepCloneable
{
return (T)DeepClone((object?)obj, objects)!;
}
private static object? DeepClone(this object? obj, Dictionary<(Type, int), object>? objects = null)
{ {
if (obj == null) if (obj == null)
return default!; return null;
if (objects != null && objects.TryGetValue((obj.GetType(), obj.GetHashCode()), out var value)) if (objects != null && objects.TryGetValue((obj.GetType(), obj.GetHashCode()), out var value))
return (T)value; return value;
var type = obj.GetType(); var type = obj.GetType();
// We use GetUninitializedObject to create an object without calling the constructor. This is necessary to prevent // We use GetUninitializedObject to create an object without calling the constructor. This is necessary to prevent
@ -52,24 +58,30 @@ public static class DeepCloneHandler
setter.Invoke(newObj, cloned); setter.Invoke(newObj, cloned);
} }
return (T)newObj; return newObj;
} }
private static object DeepCloneInternal(object? obj, Type type, Dictionary<(Type, int), object> objects) private static readonly HashSet<Type> ExternalDeepCloneTypes = new()
{
typeof(PcgRandom),
typeof(Pcg32Single),
};
private static object? DeepCloneInternal(object? obj, Type type, Dictionary<(Type, int), object> objects)
{ {
if (obj == null) if (obj == null)
return null!; return null;
// If the object is a value type or a string, we can just return it. // If the object is a value type or a string, we can just return it.
if (type.IsValueType || type == typeof(string)) if (type.IsValueType || type == typeof(string))
return obj; return obj;
// If the object is marked as deep cloneable, we will clone it. // If the object is marked as deep cloneable, we will clone it.
if (type.GetInterface(nameof(IDeepCloneable)) != null) if (type.GetInterface(nameof(IDeepCloneable)) != null || ExternalDeepCloneTypes.Contains(type))
{ {
// If the object is already cloned, we return the cloned object to prevent infinite loops and invalid references. // If the object is already cloned, we return the cloned object to prevent infinite loops and invalid references.
if (objects.TryGetValue((obj.GetType(), obj.GetHashCode()), out var value)) if (objects.TryGetValue((obj.GetType(), obj.GetHashCode()), out var value))
return value; return value;
var o = DeepClone((IDeepCloneable)obj, objects); var o = DeepClone(obj, objects);
return o; return o;
} }
@ -104,7 +116,7 @@ public static class DeepCloneHandler
var newDictionary = (IDictionary)Activator.CreateInstance(type); var newDictionary = (IDictionary)Activator.CreateInstance(type);
foreach (DictionaryEntry entry in dictionary) foreach (DictionaryEntry entry in dictionary)
newDictionary.Add( newDictionary.Add(
DeepCloneInternal(entry.Key, type.GetGenericArguments()[0], objects), DeepCloneInternal(entry.Key, type.GetGenericArguments()[0], objects)!,
DeepCloneInternal(entry.Value, type.GetGenericArguments()[1], objects)); DeepCloneInternal(entry.Value, type.GetGenericArguments()[1], objects));
return newDictionary; return newDictionary;
} }
@ -121,7 +133,7 @@ public static class DeepCloneHandler
/// This method is thread safe, and will only create the expressions once for each type. It returns compiled expressions for /// This method is thread safe, and will only create the expressions once for each type. It returns compiled expressions for
/// each field in the type, so that we can get high performance deep cloning. /// each field in the type, so that we can get high performance deep cloning.
/// </remarks> /// </remarks>
private static (Func<object, object> getter, Action<object, object> setter)[] private static (Func<object, object?> getter, Action<object, object?> setter)[]
GetDeepCloneExpressions(Type type) GetDeepCloneExpressions(Type type)
{ {
// We use a lock here to prevent multiple threads from trying to create the expressions at the same time. // We use a lock here to prevent multiple threads from trying to create the expressions at the same time.
@ -131,7 +143,7 @@ public static class DeepCloneHandler
return value; return value;
var fields = GetFields(type).ToArray(); var fields = GetFields(type).ToArray();
var expressions = new (Func<object, object> getter, Action<object, object> setter)[fields.Length]; var expressions = new (Func<object, object?> getter, Action<object, object?> setter)[fields.Length];
for (var i = 0; i < fields.Length; i++) for (var i = 0; i < fields.Length; i++)
{ {
var field = fields[i]; var field = fields[i];
@ -152,7 +164,7 @@ public static class DeepCloneHandler
// does this. This is not ideal as it is slower, but works for now. // does this. This is not ideal as it is slower, but works for now.
if (field.IsInitOnly) if (field.IsInitOnly)
{ {
void Setter(object instance, object v) void Setter(object instance, object? v)
{ {
field.SetValue(instance, v); field.SetValue(instance, v);
} }
@ -169,7 +181,7 @@ public static class DeepCloneHandler
// 3. Assign the value to the field. // 3. Assign the value to the field.
var assign = Expression.Assign(Expression.Field(cast, field), valueCast); var assign = Expression.Assign(Expression.Field(cast, field), valueCast);
// 4. Wrap the assign in a lambda so we can compile it. // 4. Wrap the assign in a lambda so we can compile it.
var set = Expression.Lambda<Action<object, object>>(assign, obj, valueLambda); var set = Expression.Lambda<Action<object, object?>>(assign, obj, valueLambda);
expressions[i] = (getter: lambda.Compile(), setter: set.Compile()); expressions[i] = (getter: lambda.Compile(), setter: set.Compile());
} }
} }
@ -191,6 +203,6 @@ public static class DeepCloneHandler
return fields; return fields;
} }
private static readonly Dictionary<Type, (Func<object, object> getter, Action<object, object> setter)[]> private static readonly Dictionary<Type, (Func<object, object?> getter, Action<object, object?> setter)[]>
DeepCloneExpressions = new(); DeepCloneExpressions = new();
} }

View File

@ -1,5 +1,6 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Reflection; using System.Reflection;
using Pcg;
using PkmnLib.Dynamic.Models; using PkmnLib.Dynamic.Models;
using PkmnLib.Static; using PkmnLib.Static;
using PkmnLib.Static.Species; using PkmnLib.Static.Species;
@ -123,7 +124,7 @@ public class DeepCloneTests
await Assert.That(clone).IsNotEqualTo(battle); await Assert.That(clone).IsNotEqualTo(battle);
await Assert.That(clone.Sides[0].Pokemon[0]).IsNotEqualTo(battle.Sides[0].Pokemon[0]); await Assert.That(clone.Sides[0].Pokemon[0]).IsNotEqualTo(battle.Sides[0].Pokemon[0]);
await Assert.That(clone.Sides[1].Pokemon[0]).IsNotEqualTo(battle.Sides[1].Pokemon[0]); await Assert.That(clone.Sides[1].Pokemon[0]).IsNotEqualTo(battle.Sides[1].Pokemon[0]);
await Assert.That(clone.Sides[0].Pokemon[0]!.Species).IsEqualTo(battle.Sides[0].Pokemon[0]!.Species); await Assert.That(clone.Sides[0].Pokemon[0]!.Species).IsEqualTo(battle.Sides[0].Pokemon[0]!.Species);
await Assert.That(clone.Sides[1].Pokemon[0]!.Species).IsEqualTo(battle.Sides[1].Pokemon[0]!.Species); await Assert.That(clone.Sides[1].Pokemon[0]!.Species).IsEqualTo(battle.Sides[1].Pokemon[0]!.Species);
@ -137,4 +138,40 @@ public class DeepCloneTests
await Assert.That(pokemon.BattleData!.SeenOpponents).DoesNotContain(battle.Sides[1].Pokemon[0]!); await Assert.That(pokemon.BattleData!.SeenOpponents).DoesNotContain(battle.Sides[1].Pokemon[0]!);
await Assert.That(pokemon.StatBoost.Defense).IsEqualTo((sbyte)2); await Assert.That(pokemon.StatBoost.Defense).IsEqualTo((sbyte)2);
} }
/// <summary>
/// We have custom handling for the random number generator within the deep cloning handling. We need to ensure that
/// the random number generator is cloned correctly, so that the state of the random number generator is the same
/// in the clone as it is in the original, while still being a different instance.
/// </summary>
[Test]
public async Task DeepCloneIntegrationTestsBattleRandom()
{
var battleRandom = new BattleRandomImpl(0);
battleRandom.GetInt();
var clone = battleRandom.DeepClone();
await Assert.That(clone).IsNotEqualTo(battleRandom);
// We hack out way into the private fields of the random number generator to ensure that the state is the same
// in the clone as it is in the original. None of this is part of the public API, so we use reflection to
// access the private fields.
var pcgRandomField = typeof(RandomImpl).GetField("_random", BindingFlags.NonPublic | BindingFlags.Instance)!;
var pcgRandom = (PcgRandom)pcgRandomField.GetValue(battleRandom)!;
var clonePcgRandom = (PcgRandom)pcgRandomField.GetValue(clone)!;
await Assert.That(clonePcgRandom).IsNotEqualTo(pcgRandom);
var pcgRngField = typeof(PcgRandom).GetField("_rng", BindingFlags.NonPublic | BindingFlags.Instance)!;
var pcgRng = pcgRngField.GetValue(pcgRandom)!;
var clonePcgRng = pcgRngField.GetValue(clonePcgRandom)!;
var pcgStateField = pcgRng.GetType().GetField("_state", BindingFlags.NonPublic | BindingFlags.Instance)!;
var pcgState = pcgStateField.GetValue(pcgRng);
var clonePcgState = pcgStateField.GetValue(clonePcgRng);
await Assert.That(clonePcgState).IsEqualTo(pcgState);
var randomNumber = battleRandom.GetInt();
var cloneRandomNumber = clone.GetInt();
await Assert.That(cloneRandomNumber).IsEqualTo(randomNumber);
}
} }