208 lines
8.9 KiB
C++
208 lines
8.9 KiB
C++
#ifndef PKMNLIB_AI_DEPTHSEARCHAI_HPP
|
|
#define PKMNLIB_AI_DEPTHSEARCHAI_HPP
|
|
#include <CreatureLib/Battling/TurnChoices/AttackTurnChoice.hpp>
#include <CreatureLib/Battling/TurnChoices/PassTurnChoice.hpp>
#include <CreatureLib/Battling/TurnChoices/SwitchTurnChoice.hpp>
#include <PkmnLib/Battling/Pokemon/LearnedMove.hpp>
#include <algorithm>
#include <future>
#include <limits>
#include <optional>
#include <string>
#include <thread>
#include <tuple>
#include <vector>
// NOTE(review): this path points into a local CMake build directory and will not resolve in other checkouts or build
// configurations; consider the PkmnLib include root used above (e.g. <PkmnLib/Battling/Pokemon/PokemonParty.hpp>) —
// verify that header is exported.
#include "../cmake-build-release/PkmnLib/src/pkmnlib/src/Battling/Pokemon/PokemonParty.hpp"
#include "NaiveAI.hpp"
#include "PokemonAI.hpp"
|
|
|
|
namespace PkmnLibAI {
|
|
class DepthSearchAI : public PokemonAI {
|
|
NaiveAI _naive;
|
|
|
|
private:
|
|
float ScoreBattle(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user) {
|
|
auto side = user->GetBattleSide().GetValue();
|
|
if (battle->HasEnded()) {
|
|
if (battle->GetResult().GetWinningSide() == side->GetSideIndex()) {
|
|
return std::numeric_limits<float>::max();
|
|
} else {
|
|
return std::numeric_limits<float>::min();
|
|
}
|
|
}
|
|
float score = 0.0;
|
|
auto opposite = GetOppositeIndex(user);
|
|
for (auto* party : battle->GetParties()) {
|
|
if (party->IsResponsibleForIndex(side->GetSideIndex(), 0)) {
|
|
for (auto& mon : party->GetParty()->GetParty()) {
|
|
score += mon->GetCurrentHealth() / (float)mon->GetMaxHealth();
|
|
}
|
|
}
|
|
if (party->IsResponsibleForIndex(opposite)) {
|
|
for (auto& mon : party->GetParty()->GetParty()) {
|
|
score -= (mon->GetCurrentHealth() / (float)mon->GetMaxHealth());
|
|
}
|
|
}
|
|
}
|
|
return score;
|
|
}
|
|
|
|
struct BattlePointerWrapper {
|
|
PkmnLib::Battling::Battle* Battle;
|
|
|
|
BattlePointerWrapper(PkmnLib::Battling::Battle* battle) : Battle(battle) {}
|
|
|
|
inline PkmnLib::Battling::Battle* operator->() const noexcept { return Battle; }
|
|
|
|
~BattlePointerWrapper() {
|
|
std::vector<const CreatureLib::Battling::CreatureParty*> parties;
|
|
for (auto party : Battle->GetParties()) {
|
|
parties.push_back(party->GetParty().GetRaw());
|
|
}
|
|
delete Battle;
|
|
for (auto party : parties) {
|
|
delete party;
|
|
}
|
|
}
|
|
};
|
|
|
|
std::vector<std::tuple<u8, float>> ScoreChoicesThreaded(PkmnLib::Battling::Battle* battle,
|
|
PkmnLib::Battling::Pokemon* user, uint8_t depth) {
|
|
std::vector<std::future<std::tuple<u8, float>>> threadPool;
|
|
auto side = user->GetBattleSide().GetValue();
|
|
for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) {
|
|
auto move = user->GetMoves()[moveIndex];
|
|
if (move->GetRemainingUses() == 0) {
|
|
continue;
|
|
}
|
|
threadPool.push_back(std::async([=] {
|
|
auto v = std::tuple(moveIndex, SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth));
|
|
asThreadCleanup();
|
|
return v;
|
|
}));
|
|
}
|
|
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
|
|
for (u8 i = 0; i < party.Count(); ++i) {
|
|
auto mon = party[i];
|
|
if (!mon.HasValue()) {
|
|
continue;
|
|
}
|
|
if (mon.GetValue()->IsFainted()) {
|
|
continue;
|
|
}
|
|
threadPool.push_back(std::async([=] {
|
|
auto v = std::tuple((u8)(i + 4), SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth));
|
|
asThreadCleanup();
|
|
return v;
|
|
}));
|
|
}
|
|
std::vector<std::tuple<u8, float>> results(threadPool.size());
|
|
for (int i = 0; i < threadPool.size(); ++i) {
|
|
results[i] = threadPool[i].get();
|
|
}
|
|
return results;
|
|
}
|
|
|
|
std::vector<std::tuple<int, float>> ScoreChoices(PkmnLib::Battling::Battle* battle,
|
|
PkmnLib::Battling::Pokemon* user, uint8_t depth) {
|
|
std::vector<std::tuple<int, float>> scoredMoves;
|
|
auto side = user->GetBattleSide().GetValue();
|
|
for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) {
|
|
auto move = user->GetMoves()[moveIndex];
|
|
if (move->GetRemainingUses() == 0) {
|
|
continue;
|
|
}
|
|
auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth);
|
|
scoredMoves.emplace_back(moveIndex, scored);
|
|
}
|
|
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
|
|
for (int i = 0; i < party.Count(); ++i) {
|
|
auto mon = party[i];
|
|
if (!mon.HasValue()) {
|
|
continue;
|
|
}
|
|
if (mon.GetValue()->IsFainted()) {
|
|
continue;
|
|
}
|
|
auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth);
|
|
scoredMoves.emplace_back(i + 4, scored);
|
|
}
|
|
return scoredMoves;
|
|
}
|
|
|
|
float SimulateTurn(PkmnLib::Battling::Battle* originalBattle, u8 sideIndex, u8 pokemonIndex, u8 index,
|
|
u8 depth) {
|
|
auto battle = BattlePointerWrapper(originalBattle->Clone());
|
|
auto user =
|
|
dynamic_cast<PkmnLib::Battling::Pokemon*>(battle->GetCreature(sideIndex, pokemonIndex).GetValue());
|
|
auto target = GetOppositeIndex(user);
|
|
if (index < 4) {
|
|
auto move = user->GetMoves()[index];
|
|
auto choice = new CreatureLib::Battling::AttackTurnChoice(user, move, target);
|
|
if (!battle->TrySetChoice(choice)) {
|
|
delete choice;
|
|
return std::numeric_limits<float>::min();
|
|
}
|
|
} else {
|
|
auto mon = battle->GetParties()[sideIndex]->GetParty()->GetParty().At(index - 4);
|
|
auto choice = new CreatureLib::Battling::SwitchTurnChoice(user, mon);
|
|
if (!battle->TrySetChoice(choice)) {
|
|
delete choice;
|
|
return std::numeric_limits<float>::min();
|
|
}
|
|
}
|
|
battle->TrySetChoice(_naive.GetChoice(
|
|
battle.Battle, dynamic_cast<PkmnLib::Battling::Pokemon*>(battle->GetCreature(target).GetValue())));
|
|
|
|
float score;
|
|
if (depth <= 1) {
|
|
score = ScoreBattle(battle.Battle, user);
|
|
} else {
|
|
auto scoredChoices = ScoreChoices(battle.Battle, user, depth - 1);
|
|
float summedScore = 0;
|
|
size_t amount = 0;
|
|
for (auto& option : scoredChoices) {
|
|
summedScore += std::get<1>(option);
|
|
amount++;
|
|
}
|
|
if (amount == 0) {
|
|
return std::numeric_limits<float>::min();
|
|
}
|
|
|
|
score = summedScore / amount;
|
|
}
|
|
return score;
|
|
}
|
|
|
|
public:
|
|
std::string GetName() const noexcept override { return "depthsearch"; }
|
|
|
|
CreatureLib::Battling::BaseTurnChoice* GetChoice(PkmnLib::Battling::Battle* battle,
|
|
PkmnLib::Battling::Pokemon* user) override {
|
|
auto scoredChoices = ScoreChoicesThreaded(battle, user, 2);
|
|
auto side = user->GetBattleSide().GetValue();
|
|
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
|
|
|
|
auto target = GetOppositeIndex(user);
|
|
if (scoredChoices.empty()) {
|
|
return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, target);
|
|
}
|
|
|
|
i32 highest = -1;
|
|
float highestScore = -std::numeric_limits<float>::infinity();
|
|
for (auto& option : scoredChoices) {
|
|
if (std::get<1>(option) > highestScore) {
|
|
highestScore = std::get<1>(option);
|
|
highest = std::get<0>(option);
|
|
}
|
|
}
|
|
if (highest == -1) {
|
|
return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, target);
|
|
}
|
|
|
|
if (highest < 4) {
|
|
return new CreatureLib::Battling::AttackTurnChoice(user, user->GetMoves()[highest], target);
|
|
} else {
|
|
return new CreatureLib::Battling::SwitchTurnChoice(user, party.At(highest - 4));
|
|
}
|
|
}
|
|
};
|
|
}
|
|
|
|
#endif // PKMNLIB_AI_DEPTHSEARCHAI_HPP
|