209 lines
9.2 KiB
C#
209 lines
9.2 KiB
C#
using System.Collections.Concurrent;
using System.Diagnostics;
using PkmnLib.Dynamic.AI;
using PkmnLib.Dynamic.Libraries;
using PkmnLib.Dynamic.Models;
using PkmnLib.Plugin.Gen7;
using PkmnLib.Static.Species;
using PkmnLib.Static.Utils;
using Serilog;
using Serilog.Events;
using ShellProgressBar;
|
|
|
|
namespace AIRunner;
|
|
|
|
/// <summary>
/// Runs a series of automated battles between two AIs and logs aggregate statistics
/// (duration, time per turn, turns per battle and win/draw rates).
/// </summary>
public static class TestCommandRunner
{
    /// <summary>Maximum number of battles that run concurrently within one batch.</summary>
    private const int MaxConcurrentBattles = 10;

    /// <summary>Number of Pokémon generated for each side's party.</summary>
    private const int PartySize = 3;

    /// <summary>Turn number above which a battle is forcibly ended.</summary>
    private const uint MaxTurnsPerBattle = 1000;

    /// <summary>
    /// Runs <paramref name="battles"/> randomly generated battles between the two AIs, in batches of at most
    /// <see cref="MaxConcurrentBattles"/> concurrent tasks, and logs the aggregated results.
    /// </summary>
    /// <param name="ai1">The AI that makes the choices for side 0.</param>
    /// <param name="ai2">The AI that makes the choices for side 1.</param>
    /// <param name="battles">The number of battles to run.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="battles"/> is negative.</exception>
    internal static async Task RunTestCommand(PokemonAI ai1, PokemonAI ai2, int battles)
    {
        if (battles < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(battles), battles,
                "The number of battles cannot be negative.");
        }

        var totalTimer = Stopwatch.StartNew();
        var library = DynamicLibraryImpl.Create([
            new Gen7Plugin(),
        ]);

        Log.Information("Running {Battles} battles between {AI1} and {AI2}", battles, ai1.Name, ai2.Name);

        // These collections are written to from concurrently running battle tasks, so all of them must be
        // thread-safe.
        var averageTimePerTurnPerBattle = new ConcurrentBag<double>();
        var turnsPerBattle = new ConcurrentBag<uint>();
        var results = new ConcurrentBag<BattleResult>();

        // One random generator per batch slot. A slot is used by a single task at a time, because a batch is
        // fully awaited before the next one starts.
        var rootRandom = new RandomImpl();
        var randoms = new IRandom[MaxConcurrentBattles];
        for (var slot = 0; slot < randoms.Length; slot++)
        {
            randoms[slot] = new RandomImpl(rootRandom.GetInt());
        }

        // Show a progress bar if debug logging is not enabled.
        // This is to avoid weird console output where the progress bar is drawn in the middle of debug logs.
        ProgressBar? pb = null;
        if (!Log.IsEnabled(LogEventLevel.Debug))
        {
            pb = new ProgressBar(battles, "Running battles...", new ProgressBarOptions
            {
                ShowEstimatedDuration = true,
                ProgressBarOnBottom = true,
            });
            pb.EstimatedDuration = TimeSpan.FromMilliseconds(battles);
        }

        try
        {
            var started = 0;
            while (started < battles)
            {
                var batchSize = Math.Min(MaxConcurrentBattles, battles - started);
                var batchTasks = new Task[batchSize];
                for (var slot = 0; slot < batchSize; slot++)
                {
                    var battleNumber = started + slot + 1;
                    var random = randoms[slot];
                    batchTasks[slot] = Task.Run(async () =>
                    {
                        try
                        {
                            Log.Debug("Battle {BattleNumber}: {AI1} vs {AI2}", battleNumber, ai1.Name, ai2.Name);
                            var battle = GenerateBattle(library, PartySize, random);
                            var timePerTurn = new List<double>(20);
                            while (!battle.HasEnded)
                            {
                                if (battle.CurrentTurnNumber > MaxTurnsPerBattle)
                                {
                                    Log.Warning(
                                        "Battle {BattleNumber} exceeded {MaxTurns} turns, ending battle early",
                                        battleNumber, MaxTurnsPerBattle);
                                    battle.ForceEndBattle();
                                    var last10Choices = battle.PreviousTurnChoices.TakeLast(10).ToList();
                                    Log.Warning("Last 10 choices: {Choices}", last10Choices);

                                    // Aborted battles contribute no statistics.
                                    return;
                                }

                                var res = await GetAndSetChoices(battle, ai1, ai2);
                                timePerTurn.Add(res.MsPerTurn);
                            }

                            var result = battle.Result;
                            Log.Debug("Battle {BattleNumber} ended with result: {Result}", battleNumber, result);
                            if (timePerTurn.Count > 0)
                            {
                                averageTimePerTurnPerBattle.Add(timePerTurn.Average());
                            }

                            results.Add(result.Value);
                            turnsPerBattle.Add(battle.CurrentTurnNumber);
                        }
                        finally
                        {
                            // Tick for every finished battle, including aborted ones, so the bar can complete.
                            // ReSharper disable once AccessToDisposedClosure
                            pb?.Tick();
                        }
                    });
                }

                Log.Debug("Awaiting batch of {TaskCount} tasks", batchSize);
                await Task.WhenAll(batchTasks);
                Log.Debug("Batch of {TaskCount} tasks completed", batchSize);
                started += batchSize;
            }
        }
        finally
        {
            // Dispose before the summary is logged, and also when a battle failed.
            pb?.Dispose();
        }

        totalTimer.Stop();
        LogSummary(ai1, ai2, battles, totalTimer.Elapsed, averageTimePerTurnPerBattle, turnsPerBattle, results);
    }

    /// <summary>
    /// Logs the aggregated statistics of a finished test run.
    /// </summary>
    /// <param name="ai1">The AI of side 0; only its name is used.</param>
    /// <param name="ai2">The AI of side 1; only its name is used.</param>
    /// <param name="battles">The number of battles that were requested; used as the rate denominator.</param>
    /// <param name="elapsed">The wall-clock duration of the whole run.</param>
    /// <param name="averageTimePerTurnPerBattle">The average time per turn of every completed battle.</param>
    /// <param name="turnsPerBattle">The final turn number of every completed battle.</param>
    /// <param name="results">The result of every completed battle.</param>
    private static void LogSummary(PokemonAI ai1, PokemonAI ai2, int battles, TimeSpan elapsed,
        IReadOnlyCollection<double> averageTimePerTurnPerBattle, IReadOnlyCollection<uint> turnsPerBattle,
        IReadOnlyCollection<BattleResult> results)
    {
        Log.Information("{Amount} battles completed in {Duration} ms", battles, elapsed.TotalMilliseconds);

        // DefaultIfEmpty makes the averages 0 instead of throwing when no battle completed.
        var averageTimePerTurn = averageTimePerTurnPerBattle.DefaultIfEmpty().Average();
        Log.Information("Average time per turn: {AverageTimePerTurn} ms", averageTimePerTurn);
        var averageTurnsPerBattle = turnsPerBattle.Select(turns => (double)turns).DefaultIfEmpty().Average();
        Log.Information("Average turns per battle: {AverageTurnsPerBattle}", averageTurnsPerBattle);

        var winCount1 = results.Count(x => x.WinningSide == 0);
        var winCount2 = results.Count(x => x.WinningSide == 1);
        var drawCount = results.Count(x => x.WinningSide == null);

        Log.Information("AI {AI1} win rate: {WinRate1:F3}% ({WinCount1} wins)", ai1.Name,
            Percentage(winCount1), winCount1);
        Log.Information("AI {AI2} win rate: {WinRate2:F3}% ({WinCount2} wins)", ai2.Name,
            Percentage(winCount2), winCount2);
        Log.Information("Draw rate: {DrawRate:F3}% ({DrawCount} draws)", Percentage(drawCount), drawCount);

        var abortedCount = battles - results.Count;
        if (abortedCount > 0)
        {
            Log.Warning("{AbortedCount} battles were ended early and are not counted as a win or a draw",
                abortedCount);
        }

        // Rates are relative to the requested number of battles; avoids NaN when that number is 0.
        double Percentage(int count) => battles == 0 ? 0 : count / (double)battles * 100;
    }

    /// <summary>
    /// Generates a party whose first <paramref name="length"/> slots are filled with random level 50 Pokémon.
    /// Each Pokémon uses the default form of a random species, a random nature, a random non-hidden ability and
    /// up to four randomly picked distinct level moves.
    /// </summary>
    /// <param name="library">The library the species and natures are taken from.</param>
    /// <param name="length">The number of Pokémon to generate.</param>
    /// <param name="random">The random generator that drives every random pick.</param>
    /// <returns>The generated party.</returns>
    private static PokemonPartyImpl GenerateParty(IDynamicLibrary library, int length, IRandom random)
    {
        // NOTE(review): 6 is presumably the party capacity - confirm against PokemonPartyImpl.
        var party = new PokemonPartyImpl(6);
        for (var i = 0; i < length; i++)
        {
            var species = library.StaticLibrary.Species.GetRandom(random);
            var nature = library.StaticLibrary.Natures.GetRandom(random);
            const byte level = 50;
            var defaultForm = species.GetDefaultForm();
            var abilityIndex = (byte)random.GetInt(0, defaultForm.Abilities.Count);
            var mon = new PokemonImpl(library, species, defaultForm, new AbilityIndex
            {
                IsHidden = false,
                Index = abilityIndex,
            }, level, 0, species.GetRandomGender(random), 0, nature.Name);

            // Ordering by a random key shuffles the moves before the first four are taken.
            var moves = defaultForm.Moves.GetDistinctLevelMoves().OrderBy(_ => random.GetInt()).Take(4);
            foreach (var move in moves)
            {
                // NOTE(review): 255 presumably means "no specific move slot" - confirm against LearnMove.
                mon.LearnMove(move, MoveLearnMethod.LevelUp, 255);
            }

            party.SwapInto(mon, i);
        }

        Log.Debug("Generated party: {Party}", party);
        return party;
    }

    /// <summary>
    /// Creates a battle between two freshly generated random parties, where each party is responsible for
    /// position 0 of its own side.
    /// </summary>
    /// <param name="library">The library used to generate the parties and to run the battle.</param>
    /// <param name="partyLength">The number of Pokémon in each party.</param>
    /// <param name="random">The random generator used for party generation.</param>
    /// <returns>The new battle.</returns>
    private static BattleImpl GenerateBattle(IDynamicLibrary library, int partyLength, IRandom random)
    {
        var parties = new[]
        {
            new BattlePartyImpl(GenerateParty(library, partyLength, random), [
                new ResponsibleIndex(0, 0),
            ]),
            new BattlePartyImpl(GenerateParty(library, partyLength, random), [
                new ResponsibleIndex(1, 0),
            ]),
        };

        // NOTE(review): 2 and 1 are presumably the number of sides and positions per side - confirm against
        // the BattleImpl constructor.
        return new BattleImpl(library, parties, false, 2, 1, false, "test");
    }

    /// <summary>The outcome of a single call to <see cref="GetAndSetChoices"/>.</summary>
    /// <param name="MsPerTurn">The time in milliseconds it took to submit both choices to the battle.</param>
    private readonly record struct GetAndSetChoicesResult(double MsPerTurn);

    /// <summary>
    /// Makes sure both sides have a Pokémon at position 0, gathers the choice of each AI (or the forced choice,
    /// when the battle reports one) and submits both choices to the battle.
    /// </summary>
    /// <param name="battle">The battle to advance.</param>
    /// <param name="ai1">The AI that chooses for side 0.</param>
    /// <param name="ai2">The AI that chooses for side 1.</param>
    /// <returns>The time it took to submit the choices; gathering the AI choices is not included.</returns>
    /// <exception cref="InvalidOperationException">
    /// A side has no usable Pokémon, a position is still empty, or neither the AI's choice nor the replacement
    /// choice was accepted by the battle.
    /// </exception>
    private static async Task<GetAndSetChoicesResult> GetAndSetChoices(BattleImpl battle, PokemonAI ai1, PokemonAI ai2)
    {
        // Fill an empty position on side 0 with the first usable Pokémon of party 0.
        var pokemon1 = battle.Sides[0].Pokemon[0];
        while (pokemon1 is null && !battle.HasEnded)
        {
            pokemon1 = battle.Parties[0].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable);
            if (pokemon1 is null)
            {
                throw new InvalidOperationException("No usable Pokémon found in party 1.");
            }

            battle.Sides[0].SwapPokemon(0, pokemon1);
            pokemon1 = battle.Sides[0].Pokemon[0];
        }

        // Same for side 1 and party 1.
        var pokemon2 = battle.Sides[1].Pokemon[0];
        while (pokemon2 is null && !battle.HasEnded)
        {
            pokemon2 = battle.Parties[1].Party.WhereNotNull().FirstOrDefault(x => x.IsUsable);
            if (pokemon2 is null)
            {
                throw new InvalidOperationException("No usable Pokémon found in party 2.");
            }

            battle.Sides[1].SwapPokemon(0, pokemon2);
            pokemon2 = battle.Sides[1].Pokemon[0];
        }

        if (pokemon1 is null || pokemon2 is null)
        {
            throw new InvalidOperationException("Both Pokémon must be non-null to proceed with the battle.");
        }

        // Both AIs pick their choice in parallel; a forced turn bypasses the AI.
        var taskAiOne = !battle.HasForcedTurn(pokemon1, out var choice1)
            ? Task.Run(() => ai1.GetChoice(battle, pokemon1))
            : Task.FromResult(choice1);
        var taskAiTwo = !battle.HasForcedTurn(pokemon2, out var choice2)
            ? Task.Run(() => ai2.GetChoice(battle, pokemon2))
            : Task.FromResult(choice2);
        await Task.WhenAll(taskAiOne, taskAiTwo);
        choice1 = await taskAiOne;
        choice2 = await taskAiTwo;
        Log.Debug("Turn {Turn}: AI {AI1} choice: {Choice1}, AI {AI2} choice: {Choice2}", battle.CurrentTurnNumber,
            ai1.Name, choice1, ai2.Name, choice2);

        // Only the submission of the choices is timed.
        var turnTimer = Stopwatch.StartNew();
        if (!battle.TrySetChoice(choice1))
        {
            // Fall back to the library's replacement choice, aimed at side 1, position 0.
            var replacementChoice = battle.Library.MiscLibrary.ReplacementChoice(pokemon1, 1, 0);
            if (!battle.TrySetChoice(replacementChoice))
            {
                throw new InvalidOperationException($"AI {ai1.Name} failed to set a valid choice: {choice1}");
            }
        }

        if (!battle.TrySetChoice(choice2))
        {
            // Fall back to the library's replacement choice, aimed at side 0, position 0.
            var replacementChoice = battle.Library.MiscLibrary.ReplacementChoice(pokemon2, 0, 0);
            if (!battle.TrySetChoice(replacementChoice))
            {
                throw new InvalidOperationException($"AI {ai2.Name} failed to set a valid choice: {choice2}");
            }
        }

        turnTimer.Stop();
        return new GetAndSetChoicesResult(turnTimer.Elapsed.TotalMilliseconds);
    }
}