Some initial work on prescient AI, AI runner, and some random fixes
All checks were successful
Build / Build (push) Successful in 1m3s

This commit is contained in:
Deukhoofd 2025-07-05 17:48:51 +02:00
parent 7b25161a8d
commit d57076374f
Signed by: Deukhoofd
GPG Key ID: F63E044490819F6F
9 changed files with 174 additions and 21 deletions

View File

@ -8,8 +8,10 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="iluvadev.ConsoleProgressBar"/>
<PackageReference Include="Serilog"/> <PackageReference Include="Serilog"/>
<PackageReference Include="Serilog.Sinks.Console"/> <PackageReference Include="Serilog.Sinks.Console"/>
<PackageReference Include="ShellProgressBar"/>
<PackageReference Include="System.CommandLine"/> <PackageReference Include="System.CommandLine"/>
</ItemGroup> </ItemGroup>

View File

@ -6,6 +6,8 @@ using PkmnLib.Plugin.Gen7;
using PkmnLib.Static.Species; using PkmnLib.Static.Species;
using PkmnLib.Static.Utils; using PkmnLib.Static.Utils;
using Serilog; using Serilog;
using Serilog.Events;
using ShellProgressBar;
namespace AIRunner; namespace AIRunner;
@ -20,6 +22,7 @@ public static class TestCommandRunner
Log.Information("Running {Battles} battles between {AI1} and {AI2}", battles, ai1.Name, ai2.Name); Log.Information("Running {Battles} battles between {AI1} and {AI2}", battles, ai1.Name, ai2.Name);
var averageTimePerTurnPerBattle = new List<double>(battles); var averageTimePerTurnPerBattle = new List<double>(battles);
var turnsPerBattle = new ConcurrentBag<uint>();
var results = new ConcurrentBag<BattleResult>(); var results = new ConcurrentBag<BattleResult>();
var rootRandom = new RandomImpl(); var rootRandom = new RandomImpl();
@ -31,27 +34,51 @@ public static class TestCommandRunner
randoms[i] = new RandomImpl(rootRandom.GetInt()); randoms[i] = new RandomImpl(rootRandom.GetInt());
battleTasks[i] = Task.CompletedTask; // Initialize tasks to avoid null references battleTasks[i] = Task.CompletedTask; // Initialize tasks to avoid null references
} }
// Here you would implement the logic to run the AI scripts against each other.
// This is a placeholder for demonstration purposes. // Show a progress bar if debug logging is not enabled.
// This is to avoid weird console output where the progress bar is drawn in the middle of debug logs.
ProgressBar? pb = null;
if (!Log.IsEnabled(LogEventLevel.Debug))
{
pb = new ProgressBar(battles, "Running battles...", new ProgressBarOptions
{
ShowEstimatedDuration = true,
ProgressBarOnBottom = true,
});
pb.EstimatedDuration = TimeSpan.FromMilliseconds(battles);
}
for (var i = 0; i < battles; i++) for (var i = 0; i < battles; i++)
{ {
var taskIndex = i % maxTasks; var taskIndex = i % maxTasks;
var index = i; var index = i;
var battleTask = Task.Run(async () => var battleTask = Task.Run(async () =>
{ {
Log.Information("Battle {BattleNumber}: {AI1} vs {AI2}", index + 1, ai1.Name, ai2.Name); Log.Debug("Battle {BattleNumber}: {AI1} vs {AI2}", index + 1, ai1.Name, ai2.Name);
var random = randoms[taskIndex]; var random = randoms[taskIndex];
var battle = GenerateBattle(library, 3, random); var battle = GenerateBattle(library, 3, random);
var timePerTurn = new List<double>(20); var timePerTurn = new List<double>(20);
while (!battle.HasEnded) while (!battle.HasEnded)
{ {
if (battle.CurrentTurnNumber > 1000)
{
Log.Warning("Battle {BattleNumber} exceeded 1000 turns, ending battle early", index + 1);
battle.ForceEndBattle();
var last10Choices = battle.PreviousTurnChoices.TakeLast(10).ToList();
Log.Warning("Last 10 choices: {Choices}", last10Choices);
return;
}
var res = await GetAndSetChoices(battle, ai1, ai2); var res = await GetAndSetChoices(battle, ai1, ai2);
timePerTurn.Add(res.MsPerTurn); timePerTurn.Add(res.MsPerTurn);
} }
var result = battle.Result; var result = battle.Result;
Log.Information("Battle {BattleNumber} ended with result: {Result}", index + 1, result); Log.Debug("Battle {BattleNumber} ended with result: {Result}", index + 1, result);
averageTimePerTurnPerBattle.Add(timePerTurn.Average()); averageTimePerTurnPerBattle.Add(timePerTurn.Average());
results.Add(result.Value); results.Add(result.Value);
turnsPerBattle.Add(battle.CurrentTurnNumber);
// ReSharper disable once AccessToDisposedClosure
pb?.Tick();
}); });
battleTasks[taskIndex] = battleTask; battleTasks[taskIndex] = battleTask;
if (i % maxTasks == maxTasks - 1 || i == battles - 1) if (i % maxTasks == maxTasks - 1 || i == battles - 1)
@ -62,11 +89,13 @@ public static class TestCommandRunner
Array.Fill(battleTasks, Task.CompletedTask); // Reset tasks for the next batch Array.Fill(battleTasks, Task.CompletedTask); // Reset tasks for the next batch
} }
} }
pb?.Dispose();
var t2 = DateTime.UtcNow; var t2 = DateTime.UtcNow;
Log.Information("{Amount} battles completed in {Duration} ms", battles, (t2 - t1).TotalMilliseconds); Log.Information("{Amount} battles completed in {Duration} ms", battles, (t2 - t1).TotalMilliseconds);
var averageTimePerTurn = averageTimePerTurnPerBattle.Average(); var averageTimePerTurn = averageTimePerTurnPerBattle.Average();
Log.Information("Average time per turn: {AverageTimePerTurn} ms", averageTimePerTurn); Log.Information("Average time per turn: {AverageTimePerTurn} ms", averageTimePerTurn);
Log.Information("Average turns per battle: {AverageTurnsPerBattle}", turnsPerBattle.Average(x => x));
var winCount1 = results.Count(x => x.WinningSide == 0); var winCount1 = results.Count(x => x.WinningSide == 0);
var winCount2 = results.Count(x => x.WinningSide == 1); var winCount2 = results.Count(x => x.WinningSide == 1);
@ -123,26 +152,32 @@ public static class TestCommandRunner
private static async Task<GetAndSetChoicesResult> GetAndSetChoices(BattleImpl battle, PokemonAI ai1, PokemonAI ai2) private static async Task<GetAndSetChoicesResult> GetAndSetChoices(BattleImpl battle, PokemonAI ai1, PokemonAI ai2)
{ {
var pokemon1 = battle.Sides[0].Pokemon[0]; var pokemon1 = battle.Sides[0].Pokemon[0];
if (pokemon1 is null) while (pokemon1 is null && !battle.HasEnded)
{ {
pokemon1 = battle.Parties[0].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable); pokemon1 = battle.Parties[0].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable);
if (pokemon1 is null) if (pokemon1 is null)
throw new InvalidOperationException("No usable Pokémon found in party 1."); throw new InvalidOperationException("No usable Pokémon found in party 1.");
battle.Sides[0].SwapPokemon(0, pokemon1); battle.Sides[0].SwapPokemon(0, pokemon1);
pokemon1 = battle.Sides[0].Pokemon[0];
} }
var pokemon2 = battle.Sides[1].Pokemon[0]; var pokemon2 = battle.Sides[1].Pokemon[0];
if (pokemon2 is null) while (pokemon2 is null && !battle.HasEnded)
{ {
pokemon2 = battle.Parties[1].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable); pokemon2 = battle.Parties[1].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable);
if (pokemon2 is null) if (pokemon2 is null)
throw new InvalidOperationException("No usable Pokémon found in party 2."); throw new InvalidOperationException("No usable Pokémon found in party 2.");
battle.Sides[1].SwapPokemon(0, pokemon2); battle.Sides[1].SwapPokemon(0, pokemon2);
pokemon2 = battle.Sides[1].Pokemon[0];
}
if (pokemon1 is null || pokemon2 is null)
{
throw new InvalidOperationException("Both Pokémon must be non-null to proceed with the battle.");
} }
var taskAiOne = !battle.HasForcedTurn(pokemon1, out var choice1) var taskAiOne = !battle.HasForcedTurn(pokemon1!, out var choice1)
? Task.Run(() => ai1.GetChoice(battle, pokemon1)) ? Task.Run(() => ai1.GetChoice(battle, pokemon1))
: Task.FromResult(choice1); : Task.FromResult(choice1);
var taskAiTwo = !battle.HasForcedTurn(pokemon2, out var choice2) var taskAiTwo = !battle.HasForcedTurn(pokemon2!, out var choice2)
? Task.Run(() => ai2.GetChoice(battle, pokemon2)) ? Task.Run(() => ai2.GetChoice(battle, pokemon2))
: Task.FromResult(choice2); : Task.FromResult(choice2);
await Task.WhenAll(taskAiOne, taskAiTwo); await Task.WhenAll(taskAiOne, taskAiTwo);

View File

@ -4,9 +4,11 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageVersion Include="CommandLineParser" Version="2.9.1"/> <PackageVersion Include="CommandLineParser" Version="2.9.1"/>
<PackageVersion Include="iluvadev.ConsoleProgressBar" Version="1.1.0"/>
<PackageVersion Include="PcgRandom" Version="1.2.0"/> <PackageVersion Include="PcgRandom" Version="1.2.0"/>
<PackageVersion Include="Serilog" Version="4.3.0"/> <PackageVersion Include="Serilog" Version="4.3.0"/>
<PackageVersion Include="Serilog.Sinks.Console" Version="6.0.0"/> <PackageVersion Include="Serilog.Sinks.Console" Version="6.0.0"/>
<PackageVersion Include="ShellProgressBar" Version="5.2.0"/>
<PackageVersion Include="System.Collections.Immutable" Version="8.0.0"/> <PackageVersion Include="System.Collections.Immutable" Version="8.0.0"/>
<PackageVersion Include="System.CommandLine" Version="2.0.0-beta5.25306.1"/> <PackageVersion Include="System.CommandLine" Version="2.0.0-beta5.25306.1"/>
<PackageVersion Include="System.Text.Json" Version="8.0.5"/> <PackageVersion Include="System.Text.Json" Version="8.0.5"/>

View File

@ -5,6 +5,9 @@ using PkmnLib.Static.Utils;
namespace PkmnLib.Dynamic.AI; namespace PkmnLib.Dynamic.AI;
/// <summary>
/// HighestDamageAI is an AI that selects the move that it expects to deal the highest damage.
/// </summary>
public class HighestDamageAI : PokemonAI public class HighestDamageAI : PokemonAI
{ {
/// <inheritdoc /> /// <inheritdoc />

View File

@ -0,0 +1,87 @@
using PkmnLib.Dynamic.Models;
using PkmnLib.Dynamic.Models.Choices;
using PkmnLib.Static.Utils;
namespace PkmnLib.Dynamic.AI;
/// <summary>
/// PrescientAI is an AI that predicts the best move based on the current state of the battle.
/// This is slightly cheaty, as it simulates the battle with each possible move to find the best one.
/// </summary>
public class PrescientAI : PokemonAI
{
private static readonly PokemonAI OpponentAI = new HighestDamageAI();
/// <inheritdoc />
public PrescientAI() : base("Prescient")
{
}
/// <inheritdoc />
public override ITurnChoice GetChoice(IBattle battle, IPokemon pokemon)
{
var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0;
var moves = pokemon.Moves.WhereNotNull().Where(x => battle.CanUse(new MoveChoice(pokemon, x, opponentSide, 0)))
.ToList();
var choices = ScoreChoices(battle, moves, pokemon).OrderByDescending(x => x.Score).ToList();
if (choices.Count == 0)
{
return battle.Library.MiscLibrary.ReplacementChoice(pokemon, opponentSide, 0);
}
var bestChoice = choices.First().Choice;
return bestChoice;
}
private static IEnumerable<(ITurnChoice Choice, float Score)> ScoreChoices(IBattle battle,
IReadOnlyList<ILearnedMove> moves, IPokemon pokemon)
{
var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0;
foreach (var learnedMoveOriginal in moves.WhereNotNull())
{
var battleClone = battle.DeepClone();
var pokemonClone = battleClone.Sides[pokemon.BattleData!.SideIndex].Pokemon[pokemon.BattleData.Position]!;
var learnedMove = pokemonClone.Moves.WhereNotNull()
.First(m => m.MoveData.Name == learnedMoveOriginal.MoveData.Name);
var choice = new MoveChoice(pokemonClone, learnedMove, opponentSide, 0);
var opponentChoice = GetOpponentChoice(battleClone, pokemonClone);
if (!battleClone.TrySetChoice(opponentChoice))
{
var replacementChoice =
battleClone.Library.MiscLibrary.ReplacementChoice(pokemonClone, opponentSide, 0);
if (!battleClone.TrySetChoice(replacementChoice))
{
throw new InvalidOperationException(
"Could not set opponent choice or replacement choice in battle clone.");
}
}
if (battleClone.TrySetChoice(choice))
{
var score = CalculateScore(battleClone.Parties[pokemon.BattleData.SideIndex],
battleClone.Parties[opponentSide]);
var realChoice = new MoveChoice(pokemon, learnedMoveOriginal, opponentSide, 0);
yield return (realChoice, score);
}
}
}
private static ITurnChoice GetOpponentChoice(IBattle battle, IPokemon pokemon)
{
var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0;
var opponent = battle.Sides[opponentSide].Pokemon[0];
if (opponent is null)
{
throw new InvalidOperationException("Opponent Pokemon is null.");
}
if (battle.HasForcedTurn(opponent, out var forcedChoice))
{
return forcedChoice;
}
return OpponentAI.GetChoice(battle, opponent);
}
private static float CalculateScore(IBattleParty ownParty, IBattleParty opponentParty) =>
ownParty.Party.WhereNotNull().Sum(x => x.CurrentHealth / (float)x.MaxHealth) -
opponentParty.Party.WhereNotNull().Sum(x => x.CurrentHealth / (float)x.MaxHealth);
}

View File

@ -325,7 +325,7 @@ public class BattleImpl : ScriptSource, IBattle
// Always allow moves such as Struggle. If we block this, we can run into an infinite loop // Always allow moves such as Struggle. If we block this, we can run into an infinite loop
if (Library.MiscLibrary.IsReplacementChoice(choice)) if (Library.MiscLibrary.IsReplacementChoice(choice))
return true; return true;
if (HasForcedTurn(choice.User, out var forcedChoice) && !Equals(choice, forcedChoice)) if (HasForcedTurn(choice.User, out var forcedChoice) && !IsValidForForcedTurn(forcedChoice, choice))
return false; return false;
if (choice is IMoveChoice moveChoice) if (choice is IMoveChoice moveChoice)
@ -343,6 +343,20 @@ public class BattleImpl : ScriptSource, IBattle
} }
return true; return true;
bool IsValidForForcedTurn(ITurnChoice forcedChoice, ITurnChoice choiceToCheck)
{
// If the forced choice is a move choice, we can only use it if the move is the same
if (forcedChoice is IMoveChoice forcedMove && choiceToCheck is IMoveChoice moveChoice)
{
return forcedMove.ChosenMove.MoveData.Name == moveChoice.ChosenMove.MoveData.Name;
}
if (forcedChoice is IPassChoice && choiceToCheck is IPassChoice)
{
return true; // Both are pass choices, so they are valid
}
return forcedChoice.Equals(choiceToCheck);
}
} }
/// <inheritdoc /> /// <inheritdoc />

View File

@ -155,7 +155,8 @@ public static class ScriptExecution
Action<TScriptHook> hook) Action<TScriptHook> hook)
{ {
List<ScriptCategory>? suppressedCategories = null; List<ScriptCategory>? suppressedCategories = null;
foreach (var container in source.SelectMany(x => x)) var iterator = new ScriptIterator(source);
foreach (var container in iterator)
{ {
if (container.IsEmpty) if (container.IsEmpty)
continue; continue;
@ -163,7 +164,7 @@ public static class ScriptExecution
if (script is IScriptOnBeforeAnyHookInvoked onBeforeAnyHookInvoked) if (script is IScriptOnBeforeAnyHookInvoked onBeforeAnyHookInvoked)
onBeforeAnyHookInvoked.OnBeforeAnyHookInvoked(ref suppressedCategories); onBeforeAnyHookInvoked.OnBeforeAnyHookInvoked(ref suppressedCategories);
} }
foreach (var container in source.SelectMany(x => x)) foreach (var container in iterator)
{ {
if (container.IsEmpty) if (container.IsEmpty)
continue; continue;

View File

@ -10,7 +10,7 @@ namespace PkmnLib.Dynamic.ScriptHandling;
/// We can add, remove, and clear scripts from the set. /// We can add, remove, and clear scripts from the set.
/// This is generally used for volatile scripts. /// This is generally used for volatile scripts.
/// </summary> /// </summary>
public interface IScriptSet : IEnumerable<ScriptContainer> public interface IScriptSet : IEnumerable<ScriptContainer>, IDeepCloneable
{ {
/// <summary> /// <summary>
/// Adds a script to the set. If the script with that name already exists in this set, this /// Adds a script to the set. If the script with that name already exists in this set, this
@ -97,15 +97,7 @@ public class ScriptSet : IScriptSet
} }
/// <inheritdoc /> /// <inheritdoc />
public IEnumerator<ScriptContainer> GetEnumerator() public IEnumerator<ScriptContainer> GetEnumerator() => _scripts.GetEnumerator();
{
var currentIndex = 0;
while (currentIndex < _scripts.Count)
{
yield return _scripts[currentIndex];
currentIndex++;
}
}
/// <inheritdoc /> /// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

View File

@ -2,6 +2,8 @@ using System.Diagnostics.CodeAnalysis;
using System.Reflection; using System.Reflection;
using Pcg; using Pcg;
using PkmnLib.Dynamic.Models; using PkmnLib.Dynamic.Models;
using PkmnLib.Plugin.Gen7.Scripts.Moves;
using PkmnLib.Plugin.Gen7.Scripts.Pokemon;
using PkmnLib.Static; using PkmnLib.Static;
using PkmnLib.Static.Species; using PkmnLib.Static.Species;
using PkmnLib.Static.Utils; using PkmnLib.Static.Utils;
@ -113,6 +115,7 @@ public class DeepCloneTests
battle.Sides[1].SwapPokemon(0, party2[0]); battle.Sides[1].SwapPokemon(0, party2[0]);
party1[0]!.ChangeStatBoost(Statistic.Defense, 2, true, false); party1[0]!.ChangeStatBoost(Statistic.Defense, 2, true, false);
await Assert.That(party1[0]!.StatBoost.Defense).IsEqualTo((sbyte)2); await Assert.That(party1[0]!.StatBoost.Defense).IsEqualTo((sbyte)2);
party1[0]!.Volatile.Add(new ChargeBounceEffect(party1[0]!));
var clone = battle.DeepClone(); var clone = battle.DeepClone();
await Assert.That(clone).IsNotEqualTo(battle); await Assert.That(clone).IsNotEqualTo(battle);
@ -125,12 +128,26 @@ public class DeepCloneTests
await Assert.That(clone.Library).IsEqualTo(battle.Library); await Assert.That(clone.Library).IsEqualTo(battle.Library);
var pokemon = clone.Sides[0].Pokemon[0]!; var pokemon = clone.Sides[0].Pokemon[0]!;
await Assert.That(pokemon).IsNotNull();
await Assert.That(pokemon).IsNotEqualTo(battle.Sides[0].Pokemon[0]);
await Assert.That(pokemon.BattleData).IsNotNull(); await Assert.That(pokemon.BattleData).IsNotNull();
await Assert.That(pokemon.BattleData).IsNotEqualTo(battle.Sides[0].Pokemon[0]!.BattleData); await Assert.That(pokemon.BattleData).IsNotEqualTo(battle.Sides[0].Pokemon[0]!.BattleData);
await Assert.That(pokemon.BattleData!.Battle).IsEqualTo(clone); await Assert.That(pokemon.BattleData!.Battle).IsEqualTo(clone);
await Assert.That(pokemon.BattleData!.SeenOpponents).Contains(clone.Sides[1].Pokemon[0]!); await Assert.That(pokemon.BattleData!.SeenOpponents).Contains(clone.Sides[1].Pokemon[0]!);
await Assert.That(pokemon.BattleData!.SeenOpponents).DoesNotContain(battle.Sides[1].Pokemon[0]!); await Assert.That(pokemon.BattleData!.SeenOpponents).DoesNotContain(battle.Sides[1].Pokemon[0]!);
await Assert.That(pokemon.StatBoost.Defense).IsEqualTo((sbyte)2); await Assert.That(pokemon.StatBoost.Defense).IsEqualTo((sbyte)2);
await Assert.That(pokemon.Volatile.Get<ChargeBounceEffect>()).IsNotNull();
await Assert.That(pokemon.Volatile.Get<ChargeBounceEffect>()).IsNotEqualTo(
battle.Sides[0].Pokemon[0]!.Volatile.Get<ChargeBounceEffect>());
var ownerGetter =
typeof(ChargeBounceEffect).GetField("_owner", BindingFlags.NonPublic | BindingFlags.Instance)!;
var owner = ownerGetter.GetValue(pokemon.Volatile.Get<ChargeBounceEffect>()!);
await Assert.That(owner).IsEqualTo(pokemon);
pokemon.Volatile.Remove<ChargeBounceEffect>();
await Assert.That(pokemon.Volatile.Get<ChargeBounceEffect>()).IsNull();
await Assert.That(battle.Sides[0].Pokemon[0]!.Volatile.Get<ChargeBounceEffect>()).IsNotNull();
} }
/// <summary> /// <summary>