// PkmnLibAI/src/DepthSearchAI.hpp
//
// 208 lines
// 8.9 KiB
// C++

#ifndef PKMNLIB_AI_DEPTHSEARCHAI_HPP
#define PKMNLIB_AI_DEPTHSEARCHAI_HPP
#include <CreatureLib/Battling/TurnChoices/AttackTurnChoice.hpp>
#include <CreatureLib/Battling/TurnChoices/PassTurnChoice.hpp>
#include <CreatureLib/Battling/TurnChoices/SwitchTurnChoice.hpp>
#include <PkmnLib/Battling/Pokemon/LearnedMove.hpp>
#include <algorithm>
#include <cstddef>
#include <future>
#include <limits>
#include <optional>
#include <string>
#include <thread>
#include <tuple>
#include <vector>
// NOTE(review): this header is pulled out of a local build directory; confirm whether it should be
// included through the regular PkmnLib include path instead.
#include "../cmake-build-release/PkmnLib/src/pkmnlib/src/Battling/Pokemon/PokemonParty.hpp"
#include "NaiveAI.hpp"
#include "PokemonAI.hpp"
namespace PkmnLibAI {
class DepthSearchAI : public PokemonAI {
NaiveAI _naive;
private:
float ScoreBattle(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user) {
auto side = user->GetBattleSide().GetValue();
if (battle->HasEnded()) {
if (battle->GetResult().GetWinningSide() == side->GetSideIndex()) {
return std::numeric_limits<float>::max();
} else {
return std::numeric_limits<float>::min();
}
}
float score = 0.0;
auto opposite = GetOppositeIndex(user);
for (auto* party : battle->GetParties()) {
if (party->IsResponsibleForIndex(side->GetSideIndex(), 0)) {
for (auto& mon : party->GetParty()->GetParty()) {
score += mon->GetCurrentHealth() / (float)mon->GetMaxHealth();
}
}
if (party->IsResponsibleForIndex(opposite)) {
for (auto& mon : party->GetParty()->GetParty()) {
score -= (mon->GetCurrentHealth() / (float)mon->GetMaxHealth());
}
}
}
return score;
}
struct BattlePointerWrapper {
PkmnLib::Battling::Battle* Battle;
BattlePointerWrapper(PkmnLib::Battling::Battle* battle) : Battle(battle) {}
inline PkmnLib::Battling::Battle* operator->() const noexcept { return Battle; }
~BattlePointerWrapper() {
std::vector<const CreatureLib::Battling::CreatureParty*> parties;
for (auto party : Battle->GetParties()) {
parties.push_back(party->GetParty().GetRaw());
}
delete Battle;
for (auto party : parties) {
delete party;
}
}
};
std::vector<std::tuple<u8, float>> ScoreChoicesThreaded(PkmnLib::Battling::Battle* battle,
PkmnLib::Battling::Pokemon* user, uint8_t depth) {
std::vector<std::future<std::tuple<u8, float>>> threadPool;
auto side = user->GetBattleSide().GetValue();
for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) {
auto move = user->GetMoves()[moveIndex];
if (move->GetRemainingUses() == 0) {
continue;
}
threadPool.push_back(std::async([=] {
auto v = std::tuple(moveIndex, SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth));
asThreadCleanup();
return v;
}));
}
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
for (u8 i = 0; i < party.Count(); ++i) {
auto mon = party[i];
if (!mon.HasValue()) {
continue;
}
if (mon.GetValue()->IsFainted()) {
continue;
}
threadPool.push_back(std::async([=] {
auto v = std::tuple((u8)(i + 4), SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth));
asThreadCleanup();
return v;
}));
}
std::vector<std::tuple<u8, float>> results(threadPool.size());
for (int i = 0; i < threadPool.size(); ++i) {
results[i] = threadPool[i].get();
}
return results;
}
std::vector<std::tuple<int, float>> ScoreChoices(PkmnLib::Battling::Battle* battle,
PkmnLib::Battling::Pokemon* user, uint8_t depth) {
std::vector<std::tuple<int, float>> scoredMoves;
auto side = user->GetBattleSide().GetValue();
for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) {
auto move = user->GetMoves()[moveIndex];
if (move->GetRemainingUses() == 0) {
continue;
}
auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth);
scoredMoves.emplace_back(moveIndex, scored);
}
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
for (int i = 0; i < party.Count(); ++i) {
auto mon = party[i];
if (!mon.HasValue()) {
continue;
}
if (mon.GetValue()->IsFainted()) {
continue;
}
auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth);
scoredMoves.emplace_back(i + 4, scored);
}
return scoredMoves;
}
float SimulateTurn(PkmnLib::Battling::Battle* originalBattle, u8 sideIndex, u8 pokemonIndex, u8 index,
u8 depth) {
auto battle = BattlePointerWrapper(originalBattle->Clone());
auto user =
dynamic_cast<PkmnLib::Battling::Pokemon*>(battle->GetCreature(sideIndex, pokemonIndex).GetValue());
auto target = GetOppositeIndex(user);
if (index < 4) {
auto move = user->GetMoves()[index];
auto choice = new CreatureLib::Battling::AttackTurnChoice(user, move, target);
if (!battle->TrySetChoice(choice)) {
delete choice;
return std::numeric_limits<float>::min();
}
} else {
auto mon = battle->GetParties()[sideIndex]->GetParty()->GetParty().At(index - 4);
auto choice = new CreatureLib::Battling::SwitchTurnChoice(user, mon);
if (!battle->TrySetChoice(choice)) {
delete choice;
return std::numeric_limits<float>::min();
}
}
battle->TrySetChoice(_naive.GetChoice(
battle.Battle, dynamic_cast<PkmnLib::Battling::Pokemon*>(battle->GetCreature(target).GetValue())));
float score;
if (depth <= 1) {
score = ScoreBattle(battle.Battle, user);
} else {
auto scoredChoices = ScoreChoices(battle.Battle, user, depth - 1);
float summedScore = 0;
size_t amount = 0;
for (auto& option : scoredChoices) {
summedScore += std::get<1>(option);
amount++;
}
if (amount == 0) {
return std::numeric_limits<float>::min();
}
score = summedScore / amount;
}
return score;
}
public:
std::string GetName() const noexcept override { return "depthsearch"; }
CreatureLib::Battling::BaseTurnChoice* GetChoice(PkmnLib::Battling::Battle* battle,
PkmnLib::Battling::Pokemon* user) override {
auto scoredChoices = ScoreChoicesThreaded(battle, user, 2);
auto side = user->GetBattleSide().GetValue();
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
auto target = GetOppositeIndex(user);
if (scoredChoices.empty()) {
return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, target);
}
i32 highest = -1;
float highestScore = -std::numeric_limits<float>::infinity();
for (auto& option : scoredChoices) {
if (std::get<1>(option) > highestScore) {
highestScore = std::get<1>(option);
highest = std::get<0>(option);
}
}
if (highest == -1) {
return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, target);
}
if (highest < 4) {
return new CreatureLib::Battling::AttackTurnChoice(user, user->GetMoves()[highest], target);
} else {
return new CreatureLib::Battling::SwitchTurnChoice(user, party.At(highest - 4));
}
}
};
}
#endif // PKMNLIB_AI_DEPTHSEARCHAI_HPP