#ifndef PKMNLIB_AI_DEPTHSEARCHAI_HPP #define PKMNLIB_AI_DEPTHSEARCHAI_HPP #include #include #include #include #include #include #include #include #include "../cmake-build-release/PkmnLib/src/pkmnlib/src/Battling/Pokemon/PokemonParty.hpp" #include "NaiveAI.hpp" #include "PokemonAI.hpp" namespace PkmnLibAI { class DepthSearchAI : public PokemonAI { NaiveAI _naive; private: class SimulatedResult { public: std::optional Score; CreatureLib::Battling::CreatureIndex Target; SimulatedResult(std::optional score, CreatureLib::Battling::CreatureIndex target) : Score(score), Target(target) {} SimulatedResult() {} }; class ScoredOption { public: u8 OptionIndex; std::optional Score; CreatureLib::Battling::CreatureIndex Target; ScoredOption(u8 optionIndex, SimulatedResult result) : OptionIndex(optionIndex), Score(result.Score), Target(result.Target) {} ScoredOption() {} }; float ScoreBattle(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user) { auto side = user->GetBattleSide().GetValue(); if (battle->HasEnded()) { if (battle->GetResult().GetWinningSide() == side->GetSideIndex()) { return std::numeric_limits::max(); } else { return std::numeric_limits::min(); } } float score = 0.0; auto opposite = GetOppositeIndex(user); for (auto* party : battle->GetParties()) { if (party->IsResponsibleForIndex(side->GetSideIndex(), 0)) { for (auto& mon : party->GetParty()->GetParty()) { score += mon->GetCurrentHealth() / (float)mon->GetMaxHealth(); } } if (party->IsResponsibleForIndex(opposite)) { for (auto& mon : party->GetParty()->GetParty()) { score -= (mon->GetCurrentHealth() / (float)mon->GetMaxHealth()); } } } return score; } struct BattlePointerWrapper { PkmnLib::Battling::Battle* Battle; BattlePointerWrapper(PkmnLib::Battling::Battle* battle) : Battle(battle) {} inline PkmnLib::Battling::Battle* operator->() const noexcept { return Battle; } ~BattlePointerWrapper() { std::vector parties; for (auto party : Battle->GetParties()) { parties.push_back(party->GetParty().GetRaw()); } delete Battle; for (auto party : parties) { delete party; } } }; std::vector ScoreChoicesThreaded(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user, uint8_t depth) { std::vector> threadPool; auto side = user->GetBattleSide().GetValue(); for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) { auto move = user->GetMoves()[moveIndex]; if (!move.HasValue()) { continue; } if (move.GetValue()->GetRemainingUses() == 0) { continue; } threadPool.push_back(std::async([this, battle, side, moveIndex, depth] { auto v = ScoredOption(moveIndex, SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth)); asThreadCleanup(); return v; })); } auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty(); for (u8 i = 0; i < party.Count(); ++i) { auto mon = party[i]; if (!mon.HasValue()) { continue; } if (mon.GetValue()->IsFainted()) { continue; } threadPool.push_back(std::async([this, battle, side, i, depth] { auto v = ScoredOption((u8)(i + 4), SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth)); asThreadCleanup(); return v; })); } std::vector results(threadPool.size()); for (size_t i = 0; i < threadPool.size(); ++i) { results[i] = threadPool[i].get(); } return results; } std::vector ScoreChoices(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user, uint8_t depth) { std::vector scoredMoves; auto side = user->GetBattleSide().GetValue(); for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) { auto move = user->GetMoves()[moveIndex]; if (!move.HasValue()) { continue; } if (move.GetValue()->GetRemainingUses() == 0) { continue; } auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth); scoredMoves.emplace_back(moveIndex, scored); } auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty(); for (size_t i = 0; i < party.Count(); ++i) { auto mon = party[i]; if (!mon.HasValue()) { continue; } if (mon.GetValue()->IsFainted()) { continue; } auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth); scoredMoves.emplace_back(i + 4, scored); } return scoredMoves; } std::optional GetTarget(PkmnLib::Battling::Pokemon* user, ArbUt::BorrowedPtr move) { switch (move->GetTarget()) { case CreatureLib::Library::AttackTarget::Adjacent: return {}; case CreatureLib::Library::AttackTarget::AdjacentAlly: return {}; case CreatureLib::Library::AttackTarget::AdjacentAllySelf: return user->GetBattleIndex(); case CreatureLib::Library::AttackTarget::AdjacentOpponent: return GetOppositeIndex(user); case CreatureLib::Library::AttackTarget::All: return GetOppositeIndex(user); case CreatureLib::Library::AttackTarget::AllAdjacent: return GetOppositeIndex(user); case CreatureLib::Library::AttackTarget::AllAdjacentOpponent: return GetOppositeIndex(user); case CreatureLib::Library::AttackTarget::AllAlly: return user->GetBattleIndex(); case CreatureLib::Library::AttackTarget::AllOpponent: return GetOppositeIndex(user); case CreatureLib::Library::AttackTarget::Any: return GetOppositeIndex(user); case CreatureLib::Library::AttackTarget::RandomOpponent: return GetOppositeIndex(user); case CreatureLib::Library::AttackTarget::Self: return user->GetBattleIndex(); } return {}; } SimulatedResult SimulateTurn(PkmnLib::Battling::Battle* originalBattle, u8 sideIndex, u8 pokemonIndex, u8 index, u8 depth) { auto battle = BattlePointerWrapper(originalBattle->Clone()); auto user = dynamic_cast(battle->GetCreature(sideIndex, pokemonIndex).GetValue()); CreatureLib::Battling::CreatureIndex target; if (index < 4) { auto move = user->GetMoves()[index]; if (!move.HasValue()) { return {}; } if (move.GetValue()->GetRemainingUses() <= 0) { return {}; } auto targetOption = GetTarget(user, move.GetValue()->GetAttack()); if (!targetOption.has_value()) { return {}; } target = targetOption.value(); auto choice = new CreatureLib::Battling::AttackTurnChoice(user, move.GetValue(), target); if (!battle->TrySetChoice(choice)) { delete choice; return {}; } } else { auto mon = battle->GetParties()[sideIndex]->GetParty()->GetParty().At(index - 4); auto choice = new CreatureLib::Battling::SwitchTurnChoice(user, mon); if (!battle->TrySetChoice(choice)) { delete choice; return {}; } } battle->TrySetChoice(_naive.GetChoice( battle.Battle, dynamic_cast(battle->GetCreature(GetOppositeIndex(user)).GetValue()))); float score; if (depth <= 1) { score = ScoreBattle(battle.Battle, user); } else { auto scoredChoices = ScoreChoices(battle.Battle, user, depth - 1); float summedScore = 0; size_t amount = 0; for (auto& option : scoredChoices) { auto v = option.Score; if (!v.has_value()) { continue; } summedScore += v.value(); amount++; } if (amount == 0) { return {}; } score = summedScore / amount; } return SimulatedResult(score, target); } public: std::string GetName() const noexcept override { return "depthsearch"; } CreatureLib::Battling::BaseTurnChoice* GetChoice(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user) override { auto scoredChoices = ScoreChoicesThreaded(battle, user, 2); auto side = user->GetBattleSide().GetValue(); auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty(); if (scoredChoices.empty()) { return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, GetOppositeIndex(user)); } i32 highest = -1; float highestScore = -std::numeric_limits::infinity(); CreatureLib::Battling::CreatureIndex highestTarget; for (auto& option : scoredChoices) { auto v = option.Score; if (!v.has_value()) { continue; } if (v.value() > highestScore) { highestScore = v.value(); highest = option.OptionIndex; highestTarget = option.Target; } } if (highest == -1) { return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, GetOppositeIndex(user)); } if (highest < 4) { return new CreatureLib::Battling::AttackTurnChoice(user, user->GetMoves()[highest].GetValue(), highestTarget); } else { return new CreatureLib::Battling::SwitchTurnChoice(user, party.At(highest - 4)); } } }; } #endif // PKMNLIB_AI_DEPTHSEARCHAI_HPP