#ifndef PKMNLIB_AI_DEPTHSEARCHAI_HPP #define PKMNLIB_AI_DEPTHSEARCHAI_HPP #include #include #include #include #include #include #include #include "../cmake-build-release/PkmnLib/src/pkmnlib/src/Battling/Pokemon/PokemonParty.hpp" #include "NaiveAI.hpp" #include "PokemonAI.hpp" namespace PkmnLibAI { class DepthSearchAI : public PokemonAI { NaiveAI _naive; private: float ScoreBattle(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user) { auto side = user->GetBattleSide().GetValue(); if (battle->HasEnded()) { if (battle->GetResult().GetWinningSide() == side->GetSideIndex()) { return std::numeric_limits::max(); } else { return std::numeric_limits::min(); } } float score = 0.0; auto opposite = GetOppositeIndex(user); for (auto* party : battle->GetParties()) { if (party->IsResponsibleForIndex(side->GetSideIndex(), 0)) { for (auto& mon : party->GetParty()->GetParty()) { score += mon->GetCurrentHealth() / (float)mon->GetMaxHealth(); } } if (party->IsResponsibleForIndex(opposite)) { for (auto& mon : party->GetParty()->GetParty()) { score -= (mon->GetCurrentHealth() / (float)mon->GetMaxHealth()); } } } return score; } struct BattlePointerWrapper { PkmnLib::Battling::Battle* Battle; BattlePointerWrapper(PkmnLib::Battling::Battle* battle) : Battle(battle) {} inline PkmnLib::Battling::Battle* operator->() const noexcept { return Battle; } ~BattlePointerWrapper() { std::vector parties; for (auto party : Battle->GetParties()) { parties.push_back(party->GetParty().GetRaw()); } delete Battle; for (auto party : parties) { delete party; } } }; std::vector> ScoreChoicesThreaded(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user, uint8_t depth) { std::vector>> threadPool; auto side = user->GetBattleSide().GetValue(); for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) { auto move = user->GetMoves()[moveIndex]; if (move->GetRemainingUses() == 0) { continue; } threadPool.push_back(std::async([=] { auto v = std::tuple(moveIndex, SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth)); asThreadCleanup(); return v; })); } auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty(); for (u8 i = 0; i < party.Count(); ++i) { auto mon = party[i]; if (!mon.HasValue()) { continue; } if (mon.GetValue()->IsFainted()) { continue; } threadPool.push_back(std::async([=] { auto v = std::tuple((u8)(i + 4), SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth)); asThreadCleanup(); return v; })); } std::vector> results(threadPool.size()); for (int i = 0; i < threadPool.size(); ++i) { results[i] = threadPool[i].get(); } return results; } std::vector> ScoreChoices(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user, uint8_t depth) { std::vector> scoredMoves; auto side = user->GetBattleSide().GetValue(); for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) { auto move = user->GetMoves()[moveIndex]; if (move->GetRemainingUses() == 0) { continue; } auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth); scoredMoves.emplace_back(moveIndex, scored); } auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty(); for (int i = 0; i < party.Count(); ++i) { auto mon = party[i]; if (!mon.HasValue()) { continue; } if (mon.GetValue()->IsFainted()) { continue; } auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth); scoredMoves.emplace_back(i + 4, scored); } return scoredMoves; } float SimulateTurn(PkmnLib::Battling::Battle* originalBattle, u8 sideIndex, u8 pokemonIndex, u8 index, u8 depth) { auto battle = BattlePointerWrapper(originalBattle->Clone()); auto user = dynamic_cast(battle->GetCreature(sideIndex, pokemonIndex).GetValue()); auto target = GetOppositeIndex(user); if (index < 4) { auto move = user->GetMoves()[index]; auto choice = new CreatureLib::Battling::AttackTurnChoice(user, move, target); if (!battle->TrySetChoice(choice)) { delete choice; return std::numeric_limits::min(); } } else { auto mon = battle->GetParties()[sideIndex]->GetParty()->GetParty().At(index - 4); auto choice = new CreatureLib::Battling::SwitchTurnChoice(user, mon); if (!battle->TrySetChoice(choice)) { delete choice; return std::numeric_limits::min(); } } battle->TrySetChoice(_naive.GetChoice( battle.Battle, dynamic_cast(battle->GetCreature(target).GetValue()))); float score; if (depth <= 1) { score = ScoreBattle(battle.Battle, user); } else { auto scoredChoices = ScoreChoices(battle.Battle, user, depth - 1); float summedScore = 0; size_t amount = 0; for (auto& option : scoredChoices) { summedScore += std::get<1>(option); amount++; } if (amount == 0) { return std::numeric_limits::min(); } score = summedScore / amount; } return score; } public: std::string GetName() const noexcept override { return "depthsearch"; } CreatureLib::Battling::BaseTurnChoice* GetChoice(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user) override { auto scoredChoices = ScoreChoicesThreaded(battle, user, 2); auto side = user->GetBattleSide().GetValue(); auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty(); auto target = GetOppositeIndex(user); if (scoredChoices.empty()) { return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, target); } i32 highest = -1; float highestScore = -std::numeric_limits::infinity(); for (auto& option : scoredChoices) { if (std::get<1>(option) > highestScore) { highestScore = std::get<1>(option); highest = std::get<0>(option); } } if (highest == -1) { return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, target); } if (highest < 4) { return new CreatureLib::Battling::AttackTurnChoice(user, user->GetMoves()[highest], target); } else { return new CreatureLib::Battling::SwitchTurnChoice(user, party.At(highest - 4)); } } }; } #endif // PKMNLIB_AI_DEPTHSEARCHAI_HPP