diff --git a/AI/AIRunner/AIRunner.csproj b/AI/AIRunner/AIRunner.csproj index 36b3bc9..6d0efc3 100644 --- a/AI/AIRunner/AIRunner.csproj +++ b/AI/AIRunner/AIRunner.csproj @@ -8,8 +8,10 @@ + + diff --git a/AI/AIRunner/TestCommandRunner.cs b/AI/AIRunner/TestCommandRunner.cs index 66db1a9..125379b 100644 --- a/AI/AIRunner/TestCommandRunner.cs +++ b/AI/AIRunner/TestCommandRunner.cs @@ -6,6 +6,8 @@ using PkmnLib.Plugin.Gen7; using PkmnLib.Static.Species; using PkmnLib.Static.Utils; using Serilog; +using Serilog.Events; +using ShellProgressBar; namespace AIRunner; @@ -20,6 +22,7 @@ public static class TestCommandRunner Log.Information("Running {Battles} battles between {AI1} and {AI2}", battles, ai1.Name, ai2.Name); var averageTimePerTurnPerBattle = new List(battles); + var turnsPerBattle = new ConcurrentBag(); var results = new ConcurrentBag(); var rootRandom = new RandomImpl(); @@ -31,27 +34,51 @@ public static class TestCommandRunner randoms[i] = new RandomImpl(rootRandom.GetInt()); battleTasks[i] = Task.CompletedTask; // Initialize tasks to avoid null references } - // Here you would implement the logic to run the AI scripts against each other. - // This is a placeholder for demonstration purposes. + + // Show a progress bar if debug logging is not enabled. + // This is to avoid weird console output where the progress bar is drawn in the middle of debug logs. + ProgressBar? pb = null; + if (!Log.IsEnabled(LogEventLevel.Debug)) + { + pb = new ProgressBar(battles, "Running battles...", new ProgressBarOptions + { + ShowEstimatedDuration = true, + ProgressBarOnBottom = true, + }); + pb.EstimatedDuration = TimeSpan.FromMilliseconds(battles); + } for (var i = 0; i < battles; i++) { var taskIndex = i % maxTasks; var index = i; var battleTask = Task.Run(async () => { - Log.Information("Battle {BattleNumber}: {AI1} vs {AI2}", index + 1, ai1.Name, ai2.Name); + Log.Debug("Battle {BattleNumber}: {AI1} vs {AI2}", index + 1, ai1.Name, ai2.Name); var random = randoms[taskIndex]; var battle = GenerateBattle(library, 3, random); var timePerTurn = new List(20); while (!battle.HasEnded) { + if (battle.CurrentTurnNumber > 1000) + { + Log.Warning("Battle {BattleNumber} exceeded 1000 turns, ending battle early", index + 1); + battle.ForceEndBattle(); + var last10Choices = battle.PreviousTurnChoices.TakeLast(10).ToList(); + Log.Warning("Last 10 choices: {Choices}", last10Choices); + + return; + } + var res = await GetAndSetChoices(battle, ai1, ai2); timePerTurn.Add(res.MsPerTurn); } var result = battle.Result; - Log.Information("Battle {BattleNumber} ended with result: {Result}", index + 1, result); + Log.Debug("Battle {BattleNumber} ended with result: {Result}", index + 1, result); averageTimePerTurnPerBattle.Add(timePerTurn.Average()); results.Add(result.Value); + turnsPerBattle.Add(battle.CurrentTurnNumber); + // ReSharper disable once AccessToDisposedClosure + pb?.Tick(); }); battleTasks[taskIndex] = battleTask; if (i % maxTasks == maxTasks - 1 || i == battles - 1) @@ -62,11 +89,13 @@ public static class TestCommandRunner Array.Fill(battleTasks, Task.CompletedTask); // Reset tasks for the next batch } } + pb?.Dispose(); var t2 = DateTime.UtcNow; Log.Information("{Amount} battles completed in {Duration} ms", battles, (t2 - t1).TotalMilliseconds); var averageTimePerTurn = averageTimePerTurnPerBattle.Average(); Log.Information("Average time per turn: {AverageTimePerTurn} ms", averageTimePerTurn); + Log.Information("Average turns per battle: {AverageTurnsPerBattle}", turnsPerBattle.Average(x => x)); var winCount1 = results.Count(x => x.WinningSide == 0); var winCount2 = results.Count(x => x.WinningSide == 1); @@ -123,26 +152,32 @@ public static class TestCommandRunner private static async Task GetAndSetChoices(BattleImpl battle, PokemonAI ai1, PokemonAI ai2) { var pokemon1 = battle.Sides[0].Pokemon[0]; - if (pokemon1 is null) + while (pokemon1 is null && !battle.HasEnded) { pokemon1 = battle.Parties[0].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable); if (pokemon1 is null) throw new InvalidOperationException("No usable Pokémon found in party 1."); battle.Sides[0].SwapPokemon(0, pokemon1); + pokemon1 = battle.Sides[0].Pokemon[0]; } var pokemon2 = battle.Sides[1].Pokemon[0]; - if (pokemon2 is null) + while (pokemon2 is null && !battle.HasEnded) { pokemon2 = battle.Parties[1].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable); if (pokemon2 is null) throw new InvalidOperationException("No usable Pokémon found in party 2."); battle.Sides[1].SwapPokemon(0, pokemon2); + pokemon2 = battle.Sides[1].Pokemon[0]; + } + if (pokemon1 is null || pokemon2 is null) + { + throw new InvalidOperationException("Both Pokémon must be non-null to proceed with the battle."); } - var taskAiOne = !battle.HasForcedTurn(pokemon1, out var choice1) + var taskAiOne = !battle.HasForcedTurn(pokemon1!, out var choice1) ? Task.Run(() => ai1.GetChoice(battle, pokemon1)) : Task.FromResult(choice1); - var taskAiTwo = !battle.HasForcedTurn(pokemon2, out var choice2) + var taskAiTwo = !battle.HasForcedTurn(pokemon2!, out var choice2) ? Task.Run(() => ai2.GetChoice(battle, pokemon2)) : Task.FromResult(choice2); await Task.WhenAll(taskAiOne, taskAiTwo); diff --git a/Directory.Packages.props b/Directory.Packages.props index 582e4fe..fca5803 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -4,9 +4,11 @@ + + diff --git a/PkmnLib.Dynamic/AI/HighestDamageAI.cs b/PkmnLib.Dynamic/AI/HighestDamageAI.cs index 4ab6a1c..a9a9857 100644 --- a/PkmnLib.Dynamic/AI/HighestDamageAI.cs +++ b/PkmnLib.Dynamic/AI/HighestDamageAI.cs @@ -5,6 +5,9 @@ using PkmnLib.Static.Utils; namespace PkmnLib.Dynamic.AI; +/// +/// HighestDamageAI is an AI that selects the move that it expects to deal the highest damage. +/// public class HighestDamageAI : PokemonAI { /// diff --git a/PkmnLib.Dynamic/AI/PrescientAI.cs b/PkmnLib.Dynamic/AI/PrescientAI.cs new file mode 100644 index 0000000..4d63f7a --- /dev/null +++ b/PkmnLib.Dynamic/AI/PrescientAI.cs @@ -0,0 +1,87 @@ +using PkmnLib.Dynamic.Models; +using PkmnLib.Dynamic.Models.Choices; +using PkmnLib.Static.Utils; + +namespace PkmnLib.Dynamic.AI; + +/// +/// PrescientAI is an AI that predicts the best move based on the current state of the battle. +/// This is slightly cheaty, as it simulates the battle with each possible move to find the best one. +/// +public class PrescientAI : PokemonAI +{ + private static readonly PokemonAI OpponentAI = new HighestDamageAI(); + + /// + public PrescientAI() : base("Prescient") + { + } + + /// + public override ITurnChoice GetChoice(IBattle battle, IPokemon pokemon) + { + var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0; + var moves = pokemon.Moves.WhereNotNull().Where(x => battle.CanUse(new MoveChoice(pokemon, x, opponentSide, 0))) + .ToList(); + + var choices = ScoreChoices(battle, moves, pokemon).OrderByDescending(x => x.Score).ToList(); + if (choices.Count == 0) + { + return battle.Library.MiscLibrary.ReplacementChoice(pokemon, opponentSide, 0); + } + var bestChoice = choices.First().Choice; + return bestChoice; + } + + private static IEnumerable<(ITurnChoice Choice, float Score)> ScoreChoices(IBattle battle, + IReadOnlyList moves, IPokemon pokemon) + { + var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0; + foreach (var learnedMoveOriginal in moves.WhereNotNull()) + { + var battleClone = battle.DeepClone(); + var pokemonClone = battleClone.Sides[pokemon.BattleData!.SideIndex].Pokemon[pokemon.BattleData.Position]!; + var learnedMove = pokemonClone.Moves.WhereNotNull() + .First(m => m.MoveData.Name == learnedMoveOriginal.MoveData.Name); + var choice = new MoveChoice(pokemonClone, learnedMove, opponentSide, 0); + var opponentChoice = GetOpponentChoice(battleClone, pokemonClone); + if (!battleClone.TrySetChoice(opponentChoice)) + { + var replacementChoice = + battleClone.Library.MiscLibrary.ReplacementChoice(pokemonClone, opponentSide, 0); + if (!battleClone.TrySetChoice(replacementChoice)) + { + throw new InvalidOperationException( + "Could not set opponent choice or replacement choice in battle clone."); + } + } + if (battleClone.TrySetChoice(choice)) + { + var score = CalculateScore(battleClone.Parties[pokemon.BattleData.SideIndex], + battleClone.Parties[opponentSide]); + var realChoice = new MoveChoice(pokemon, learnedMoveOriginal, opponentSide, 0); + yield return (realChoice, score); + } + } + } + + private static ITurnChoice GetOpponentChoice(IBattle battle, IPokemon pokemon) + { + var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0; + var opponent = battle.Sides[opponentSide].Pokemon[0]; + if (opponent is null) + { + throw new InvalidOperationException("Opponent Pokemon is null."); + } + if (battle.HasForcedTurn(opponent, out var forcedChoice)) + { + return forcedChoice; + } + + return OpponentAI.GetChoice(battle, opponent); + } + + private static float CalculateScore(IBattleParty ownParty, IBattleParty opponentParty) => + ownParty.Party.WhereNotNull().Sum(x => x.CurrentHealth / (float)x.MaxHealth) - + opponentParty.Party.WhereNotNull().Sum(x => x.CurrentHealth / (float)x.MaxHealth); +} \ No newline at end of file diff --git a/PkmnLib.Dynamic/Models/Battle.cs b/PkmnLib.Dynamic/Models/Battle.cs index 977d938..0b243f2 100644 --- a/PkmnLib.Dynamic/Models/Battle.cs +++ b/PkmnLib.Dynamic/Models/Battle.cs @@ -325,7 +325,7 @@ public class BattleImpl : ScriptSource, IBattle // Always allow moves such as Struggle. If we block this, we can run into an infinite loop if (Library.MiscLibrary.IsReplacementChoice(choice)) return true; - if (HasForcedTurn(choice.User, out var forcedChoice) && !Equals(choice, forcedChoice)) + if (HasForcedTurn(choice.User, out var forcedChoice) && !IsValidForForcedTurn(forcedChoice, choice)) return false; if (choice is IMoveChoice moveChoice) @@ -343,6 +343,20 @@ public class BattleImpl : ScriptSource, IBattle } return true; + + bool IsValidForForcedTurn(ITurnChoice forcedChoice, ITurnChoice choiceToCheck) + { + // If the forced choice is a move choice, we can only use it if the move is the same + if (forcedChoice is IMoveChoice forcedMove && choiceToCheck is IMoveChoice moveChoice) + { + return forcedMove.ChosenMove.MoveData.Name == moveChoice.ChosenMove.MoveData.Name; + } + if (forcedChoice is IPassChoice && choiceToCheck is IPassChoice) + { + return true; // Both are pass choices, so they are valid + } + return forcedChoice.Equals(choiceToCheck); + } } /// diff --git a/PkmnLib.Dynamic/ScriptHandling/ScriptExecution.cs b/PkmnLib.Dynamic/ScriptHandling/ScriptExecution.cs index 639c77d..b00e5bd 100644 --- a/PkmnLib.Dynamic/ScriptHandling/ScriptExecution.cs +++ b/PkmnLib.Dynamic/ScriptHandling/ScriptExecution.cs @@ -155,7 +155,8 @@ public static class ScriptExecution Action hook) { List? suppressedCategories = null; - foreach (var container in source.SelectMany(x => x)) + var iterator = new ScriptIterator(source); + foreach (var container in iterator) { if (container.IsEmpty) continue; @@ -163,7 +164,7 @@ public static class ScriptExecution if (script is IScriptOnBeforeAnyHookInvoked onBeforeAnyHookInvoked) onBeforeAnyHookInvoked.OnBeforeAnyHookInvoked(ref suppressedCategories); } - foreach (var container in source.SelectMany(x => x)) + foreach (var container in iterator) { if (container.IsEmpty) continue; diff --git a/PkmnLib.Dynamic/ScriptHandling/ScriptSet.cs b/PkmnLib.Dynamic/ScriptHandling/ScriptSet.cs index efbbdb8..83223f3 100644 --- a/PkmnLib.Dynamic/ScriptHandling/ScriptSet.cs +++ b/PkmnLib.Dynamic/ScriptHandling/ScriptSet.cs @@ -10,7 +10,7 @@ namespace PkmnLib.Dynamic.ScriptHandling; /// We can add, remove, and clear scripts from the set. /// This is generally used for volatile scripts. /// -public interface IScriptSet : IEnumerable +public interface IScriptSet : IEnumerable, IDeepCloneable { /// /// Adds a script to the set. If the script with that name already exists in this set, this @@ -97,15 +97,7 @@ public class ScriptSet : IScriptSet } /// - public IEnumerator GetEnumerator() - { - var currentIndex = 0; - while (currentIndex < _scripts.Count) - { - yield return _scripts[currentIndex]; - currentIndex++; - } - } + public IEnumerator GetEnumerator() => _scripts.GetEnumerator(); /// IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); diff --git a/PkmnLib.Tests/Static/DeepCloneTests.cs b/PkmnLib.Tests/Static/DeepCloneTests.cs index 570c868..b0ce4ff 100644 --- a/PkmnLib.Tests/Static/DeepCloneTests.cs +++ b/PkmnLib.Tests/Static/DeepCloneTests.cs @@ -2,6 +2,8 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection; using Pcg; using PkmnLib.Dynamic.Models; +using PkmnLib.Plugin.Gen7.Scripts.Moves; +using PkmnLib.Plugin.Gen7.Scripts.Pokemon; using PkmnLib.Static; using PkmnLib.Static.Species; using PkmnLib.Static.Utils; @@ -113,6 +115,7 @@ public class DeepCloneTests battle.Sides[1].SwapPokemon(0, party2[0]); party1[0]!.ChangeStatBoost(Statistic.Defense, 2, true, false); await Assert.That(party1[0]!.StatBoost.Defense).IsEqualTo((sbyte)2); + party1[0]!.Volatile.Add(new ChargeBounceEffect(party1[0]!)); var clone = battle.DeepClone(); await Assert.That(clone).IsNotEqualTo(battle); @@ -125,12 +128,26 @@ public class DeepCloneTests await Assert.That(clone.Library).IsEqualTo(battle.Library); var pokemon = clone.Sides[0].Pokemon[0]!; + await Assert.That(pokemon).IsNotNull(); + await Assert.That(pokemon).IsNotEqualTo(battle.Sides[0].Pokemon[0]); await Assert.That(pokemon.BattleData).IsNotNull(); await Assert.That(pokemon.BattleData).IsNotEqualTo(battle.Sides[0].Pokemon[0]!.BattleData); await Assert.That(pokemon.BattleData!.Battle).IsEqualTo(clone); await Assert.That(pokemon.BattleData!.SeenOpponents).Contains(clone.Sides[1].Pokemon[0]!); await Assert.That(pokemon.BattleData!.SeenOpponents).DoesNotContain(battle.Sides[1].Pokemon[0]!); await Assert.That(pokemon.StatBoost.Defense).IsEqualTo((sbyte)2); + await Assert.That(pokemon.Volatile.Get()).IsNotNull(); + await Assert.That(pokemon.Volatile.Get()).IsNotEqualTo( + battle.Sides[0].Pokemon[0]!.Volatile.Get()); + + var ownerGetter = + typeof(ChargeBounceEffect).GetField("_owner", BindingFlags.NonPublic | BindingFlags.Instance)!; + var owner = ownerGetter.GetValue(pokemon.Volatile.Get()!); + await Assert.That(owner).IsEqualTo(pokemon); + pokemon.Volatile.Remove(); + + await Assert.That(pokemon.Volatile.Get()).IsNull(); + await Assert.That(battle.Sides[0].Pokemon[0]!.Volatile.Get()).IsNotNull(); } ///