diff --git a/AI/AIRunner/AIRunner.csproj b/AI/AIRunner/AIRunner.csproj
index 36b3bc9..6d0efc3 100644
--- a/AI/AIRunner/AIRunner.csproj
+++ b/AI/AIRunner/AIRunner.csproj
@@ -8,8 +8,10 @@
+
+
diff --git a/AI/AIRunner/TestCommandRunner.cs b/AI/AIRunner/TestCommandRunner.cs
index 66db1a9..125379b 100644
--- a/AI/AIRunner/TestCommandRunner.cs
+++ b/AI/AIRunner/TestCommandRunner.cs
@@ -6,6 +6,8 @@ using PkmnLib.Plugin.Gen7;
using PkmnLib.Static.Species;
using PkmnLib.Static.Utils;
using Serilog;
+using Serilog.Events;
+using ShellProgressBar;
namespace AIRunner;
@@ -20,6 +22,7 @@ public static class TestCommandRunner
Log.Information("Running {Battles} battles between {AI1} and {AI2}", battles, ai1.Name, ai2.Name);
var averageTimePerTurnPerBattle = new List(battles);
+ var turnsPerBattle = new ConcurrentBag();
var results = new ConcurrentBag();
var rootRandom = new RandomImpl();
@@ -31,27 +34,51 @@ public static class TestCommandRunner
randoms[i] = new RandomImpl(rootRandom.GetInt());
battleTasks[i] = Task.CompletedTask; // Initialize tasks to avoid null references
}
- // Here you would implement the logic to run the AI scripts against each other.
- // This is a placeholder for demonstration purposes.
+
+ // Show a progress bar if debug logging is not enabled.
+ // This is to avoid weird console output where the progress bar is drawn in the middle of debug logs.
+ ProgressBar? pb = null;
+ if (!Log.IsEnabled(LogEventLevel.Debug))
+ {
+ pb = new ProgressBar(battles, "Running battles...", new ProgressBarOptions
+ {
+ ShowEstimatedDuration = true,
+ ProgressBarOnBottom = true,
+ });
+ pb.EstimatedDuration = TimeSpan.FromMilliseconds(battles);
+ }
for (var i = 0; i < battles; i++)
{
var taskIndex = i % maxTasks;
var index = i;
var battleTask = Task.Run(async () =>
{
- Log.Information("Battle {BattleNumber}: {AI1} vs {AI2}", index + 1, ai1.Name, ai2.Name);
+ Log.Debug("Battle {BattleNumber}: {AI1} vs {AI2}", index + 1, ai1.Name, ai2.Name);
var random = randoms[taskIndex];
var battle = GenerateBattle(library, 3, random);
var timePerTurn = new List(20);
while (!battle.HasEnded)
{
+ if (battle.CurrentTurnNumber > 1000)
+ {
+ Log.Warning("Battle {BattleNumber} exceeded 1000 turns, ending battle early", index + 1);
+ battle.ForceEndBattle();
+ var last10Choices = battle.PreviousTurnChoices.TakeLast(10).ToList();
+ Log.Warning("Last 10 choices: {Choices}", last10Choices);
+
+ return;
+ }
+
var res = await GetAndSetChoices(battle, ai1, ai2);
timePerTurn.Add(res.MsPerTurn);
}
var result = battle.Result;
- Log.Information("Battle {BattleNumber} ended with result: {Result}", index + 1, result);
+ Log.Debug("Battle {BattleNumber} ended with result: {Result}", index + 1, result);
averageTimePerTurnPerBattle.Add(timePerTurn.Average());
results.Add(result.Value);
+ turnsPerBattle.Add(battle.CurrentTurnNumber);
+ // ReSharper disable once AccessToDisposedClosure
+ pb?.Tick();
});
battleTasks[taskIndex] = battleTask;
if (i % maxTasks == maxTasks - 1 || i == battles - 1)
@@ -62,11 +89,13 @@ public static class TestCommandRunner
Array.Fill(battleTasks, Task.CompletedTask); // Reset tasks for the next batch
}
}
+ pb?.Dispose();
var t2 = DateTime.UtcNow;
Log.Information("{Amount} battles completed in {Duration} ms", battles, (t2 - t1).TotalMilliseconds);
var averageTimePerTurn = averageTimePerTurnPerBattle.Average();
Log.Information("Average time per turn: {AverageTimePerTurn} ms", averageTimePerTurn);
+ Log.Information("Average turns per battle: {AverageTurnsPerBattle}", turnsPerBattle.Average(x => x));
var winCount1 = results.Count(x => x.WinningSide == 0);
var winCount2 = results.Count(x => x.WinningSide == 1);
@@ -123,26 +152,32 @@ public static class TestCommandRunner
private static async Task GetAndSetChoices(BattleImpl battle, PokemonAI ai1, PokemonAI ai2)
{
var pokemon1 = battle.Sides[0].Pokemon[0];
- if (pokemon1 is null)
+ while (pokemon1 is null && !battle.HasEnded)
{
pokemon1 = battle.Parties[0].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable);
if (pokemon1 is null)
throw new InvalidOperationException("No usable Pokémon found in party 1.");
battle.Sides[0].SwapPokemon(0, pokemon1);
+ pokemon1 = battle.Sides[0].Pokemon[0];
}
var pokemon2 = battle.Sides[1].Pokemon[0];
- if (pokemon2 is null)
+ while (pokemon2 is null && !battle.HasEnded)
{
pokemon2 = battle.Parties[1].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable);
if (pokemon2 is null)
throw new InvalidOperationException("No usable Pokémon found in party 2.");
battle.Sides[1].SwapPokemon(0, pokemon2);
+ pokemon2 = battle.Sides[1].Pokemon[0];
+ }
+ if (pokemon1 is null || pokemon2 is null)
+ {
+ throw new InvalidOperationException("Both Pokémon must be non-null to proceed with the battle.");
}
- var taskAiOne = !battle.HasForcedTurn(pokemon1, out var choice1)
+ var taskAiOne = !battle.HasForcedTurn(pokemon1!, out var choice1)
? Task.Run(() => ai1.GetChoice(battle, pokemon1))
: Task.FromResult(choice1);
- var taskAiTwo = !battle.HasForcedTurn(pokemon2, out var choice2)
+ var taskAiTwo = !battle.HasForcedTurn(pokemon2!, out var choice2)
? Task.Run(() => ai2.GetChoice(battle, pokemon2))
: Task.FromResult(choice2);
await Task.WhenAll(taskAiOne, taskAiTwo);
diff --git a/Directory.Packages.props b/Directory.Packages.props
index 582e4fe..fca5803 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -4,9 +4,11 @@
+
+
diff --git a/PkmnLib.Dynamic/AI/HighestDamageAI.cs b/PkmnLib.Dynamic/AI/HighestDamageAI.cs
index 4ab6a1c..a9a9857 100644
--- a/PkmnLib.Dynamic/AI/HighestDamageAI.cs
+++ b/PkmnLib.Dynamic/AI/HighestDamageAI.cs
@@ -5,6 +5,9 @@ using PkmnLib.Static.Utils;
namespace PkmnLib.Dynamic.AI;
+///
+/// HighestDamageAI is an AI that selects the move that it expects to deal the highest damage.
+///
public class HighestDamageAI : PokemonAI
{
///
diff --git a/PkmnLib.Dynamic/AI/PrescientAI.cs b/PkmnLib.Dynamic/AI/PrescientAI.cs
new file mode 100644
index 0000000..4d63f7a
--- /dev/null
+++ b/PkmnLib.Dynamic/AI/PrescientAI.cs
@@ -0,0 +1,87 @@
+using PkmnLib.Dynamic.Models;
+using PkmnLib.Dynamic.Models.Choices;
+using PkmnLib.Static.Utils;
+
+namespace PkmnLib.Dynamic.AI;
+
+///
+/// PrescientAI is an AI that predicts the best move based on the current state of the battle.
+/// This is slightly cheaty, as it simulates the battle with each possible move to find the best one.
+///
+public class PrescientAI : PokemonAI
+{
+ private static readonly PokemonAI OpponentAI = new HighestDamageAI();
+
+ ///
+ public PrescientAI() : base("Prescient")
+ {
+ }
+
+ ///
+ public override ITurnChoice GetChoice(IBattle battle, IPokemon pokemon)
+ {
+ var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0;
+ var moves = pokemon.Moves.WhereNotNull().Where(x => battle.CanUse(new MoveChoice(pokemon, x, opponentSide, 0)))
+ .ToList();
+
+ var choices = ScoreChoices(battle, moves, pokemon).OrderByDescending(x => x.Score).ToList();
+ if (choices.Count == 0)
+ {
+ return battle.Library.MiscLibrary.ReplacementChoice(pokemon, opponentSide, 0);
+ }
+ var bestChoice = choices.First().Choice;
+ return bestChoice;
+ }
+
+ private static IEnumerable<(ITurnChoice Choice, float Score)> ScoreChoices(IBattle battle,
+ IReadOnlyList moves, IPokemon pokemon)
+ {
+ var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0;
+ foreach (var learnedMoveOriginal in moves.WhereNotNull())
+ {
+ var battleClone = battle.DeepClone();
+ var pokemonClone = battleClone.Sides[pokemon.BattleData!.SideIndex].Pokemon[pokemon.BattleData.Position]!;
+ var learnedMove = pokemonClone.Moves.WhereNotNull()
+ .First(m => m.MoveData.Name == learnedMoveOriginal.MoveData.Name);
+ var choice = new MoveChoice(pokemonClone, learnedMove, opponentSide, 0);
+ var opponentChoice = GetOpponentChoice(battleClone, pokemonClone);
+ if (!battleClone.TrySetChoice(opponentChoice))
+ {
+ var replacementChoice =
+ battleClone.Library.MiscLibrary.ReplacementChoice(pokemonClone, opponentSide, 0);
+ if (!battleClone.TrySetChoice(replacementChoice))
+ {
+ throw new InvalidOperationException(
+ "Could not set opponent choice or replacement choice in battle clone.");
+ }
+ }
+ if (battleClone.TrySetChoice(choice))
+ {
+ var score = CalculateScore(battleClone.Parties[pokemon.BattleData.SideIndex],
+ battleClone.Parties[opponentSide]);
+ var realChoice = new MoveChoice(pokemon, learnedMoveOriginal, opponentSide, 0);
+ yield return (realChoice, score);
+ }
+ }
+ }
+
+ private static ITurnChoice GetOpponentChoice(IBattle battle, IPokemon pokemon)
+ {
+ var opponentSide = pokemon.BattleData!.SideIndex == 0 ? (byte)1 : (byte)0;
+ var opponent = battle.Sides[opponentSide].Pokemon[0];
+ if (opponent is null)
+ {
+ throw new InvalidOperationException("Opponent Pokemon is null.");
+ }
+ if (battle.HasForcedTurn(opponent, out var forcedChoice))
+ {
+ return forcedChoice;
+ }
+
+ return OpponentAI.GetChoice(battle, opponent);
+ }
+
+ private static float CalculateScore(IBattleParty ownParty, IBattleParty opponentParty) =>
+ ownParty.Party.WhereNotNull().Sum(x => x.CurrentHealth / (float)x.MaxHealth) -
+ opponentParty.Party.WhereNotNull().Sum(x => x.CurrentHealth / (float)x.MaxHealth);
+}
\ No newline at end of file
diff --git a/PkmnLib.Dynamic/Models/Battle.cs b/PkmnLib.Dynamic/Models/Battle.cs
index 977d938..0b243f2 100644
--- a/PkmnLib.Dynamic/Models/Battle.cs
+++ b/PkmnLib.Dynamic/Models/Battle.cs
@@ -325,7 +325,7 @@ public class BattleImpl : ScriptSource, IBattle
// Always allow moves such as Struggle. If we block this, we can run into an infinite loop
if (Library.MiscLibrary.IsReplacementChoice(choice))
return true;
- if (HasForcedTurn(choice.User, out var forcedChoice) && !Equals(choice, forcedChoice))
+ if (HasForcedTurn(choice.User, out var forcedChoice) && !IsValidForForcedTurn(forcedChoice, choice))
return false;
if (choice is IMoveChoice moveChoice)
@@ -343,6 +343,20 @@ public class BattleImpl : ScriptSource, IBattle
}
return true;
+
+ bool IsValidForForcedTurn(ITurnChoice forcedChoice, ITurnChoice choiceToCheck)
+ {
+ // If the forced choice is a move choice, we can only use it if the move is the same
+ if (forcedChoice is IMoveChoice forcedMove && choiceToCheck is IMoveChoice moveChoice)
+ {
+ return forcedMove.ChosenMove.MoveData.Name == moveChoice.ChosenMove.MoveData.Name;
+ }
+ if (forcedChoice is IPassChoice && choiceToCheck is IPassChoice)
+ {
+ return true; // Both are pass choices, so they are valid
+ }
+ return forcedChoice.Equals(choiceToCheck);
+ }
}
///
diff --git a/PkmnLib.Dynamic/ScriptHandling/ScriptExecution.cs b/PkmnLib.Dynamic/ScriptHandling/ScriptExecution.cs
index 639c77d..b00e5bd 100644
--- a/PkmnLib.Dynamic/ScriptHandling/ScriptExecution.cs
+++ b/PkmnLib.Dynamic/ScriptHandling/ScriptExecution.cs
@@ -155,7 +155,8 @@ public static class ScriptExecution
Action hook)
{
List? suppressedCategories = null;
- foreach (var container in source.SelectMany(x => x))
+ var iterator = new ScriptIterator(source);
+ foreach (var container in iterator)
{
if (container.IsEmpty)
continue;
@@ -163,7 +164,7 @@ public static class ScriptExecution
if (script is IScriptOnBeforeAnyHookInvoked onBeforeAnyHookInvoked)
onBeforeAnyHookInvoked.OnBeforeAnyHookInvoked(ref suppressedCategories);
}
- foreach (var container in source.SelectMany(x => x))
+ foreach (var container in iterator)
{
if (container.IsEmpty)
continue;
diff --git a/PkmnLib.Dynamic/ScriptHandling/ScriptSet.cs b/PkmnLib.Dynamic/ScriptHandling/ScriptSet.cs
index efbbdb8..83223f3 100644
--- a/PkmnLib.Dynamic/ScriptHandling/ScriptSet.cs
+++ b/PkmnLib.Dynamic/ScriptHandling/ScriptSet.cs
@@ -10,7 +10,7 @@ namespace PkmnLib.Dynamic.ScriptHandling;
/// We can add, remove, and clear scripts from the set.
/// This is generally used for volatile scripts.
///
-public interface IScriptSet : IEnumerable
+public interface IScriptSet : IEnumerable, IDeepCloneable
{
///
/// Adds a script to the set. If the script with that name already exists in this set, this
@@ -97,15 +97,7 @@ public class ScriptSet : IScriptSet
}
///
- public IEnumerator GetEnumerator()
- {
- var currentIndex = 0;
- while (currentIndex < _scripts.Count)
- {
- yield return _scripts[currentIndex];
- currentIndex++;
- }
- }
+ public IEnumerator GetEnumerator() => _scripts.GetEnumerator();
///
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
diff --git a/PkmnLib.Tests/Static/DeepCloneTests.cs b/PkmnLib.Tests/Static/DeepCloneTests.cs
index 570c868..b0ce4ff 100644
--- a/PkmnLib.Tests/Static/DeepCloneTests.cs
+++ b/PkmnLib.Tests/Static/DeepCloneTests.cs
@@ -2,6 +2,8 @@ using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using Pcg;
using PkmnLib.Dynamic.Models;
+using PkmnLib.Plugin.Gen7.Scripts.Moves;
+using PkmnLib.Plugin.Gen7.Scripts.Pokemon;
using PkmnLib.Static;
using PkmnLib.Static.Species;
using PkmnLib.Static.Utils;
@@ -113,6 +115,7 @@ public class DeepCloneTests
battle.Sides[1].SwapPokemon(0, party2[0]);
party1[0]!.ChangeStatBoost(Statistic.Defense, 2, true, false);
await Assert.That(party1[0]!.StatBoost.Defense).IsEqualTo((sbyte)2);
+ party1[0]!.Volatile.Add(new ChargeBounceEffect(party1[0]!));
var clone = battle.DeepClone();
await Assert.That(clone).IsNotEqualTo(battle);
@@ -125,12 +128,26 @@ public class DeepCloneTests
await Assert.That(clone.Library).IsEqualTo(battle.Library);
var pokemon = clone.Sides[0].Pokemon[0]!;
+ await Assert.That(pokemon).IsNotNull();
+ await Assert.That(pokemon).IsNotEqualTo(battle.Sides[0].Pokemon[0]);
await Assert.That(pokemon.BattleData).IsNotNull();
await Assert.That(pokemon.BattleData).IsNotEqualTo(battle.Sides[0].Pokemon[0]!.BattleData);
await Assert.That(pokemon.BattleData!.Battle).IsEqualTo(clone);
await Assert.That(pokemon.BattleData!.SeenOpponents).Contains(clone.Sides[1].Pokemon[0]!);
await Assert.That(pokemon.BattleData!.SeenOpponents).DoesNotContain(battle.Sides[1].Pokemon[0]!);
await Assert.That(pokemon.StatBoost.Defense).IsEqualTo((sbyte)2);
+ await Assert.That(pokemon.Volatile.Get()).IsNotNull();
+ await Assert.That(pokemon.Volatile.Get()).IsNotEqualTo(
+ battle.Sides[0].Pokemon[0]!.Volatile.Get());
+
+ var ownerGetter =
+ typeof(ChargeBounceEffect).GetField("_owner", BindingFlags.NonPublic | BindingFlags.Instance)!;
+ var owner = ownerGetter.GetValue(pokemon.Volatile.Get()!);
+ await Assert.That(owner).IsEqualTo(pokemon);
+ pokemon.Volatile.Remove();
+
+ await Assert.That(pokemon.Volatile.Get()).IsNull();
+ await Assert.That(battle.Sides[0].Pokemon[0]!.Volatile.Get()).IsNotNull();
}
///