280 lines
12 KiB
C++
280 lines
12 KiB
C++
#ifndef PKMNLIB_AI_DEPTHSEARCHAI_HPP
|
|
#define PKMNLIB_AI_DEPTHSEARCHAI_HPP
|
|
#include <CreatureLib/Battling/TurnChoices/AttackTurnChoice.hpp>
|
|
#include <CreatureLib/Battling/TurnChoices/PassTurnChoice.hpp>
|
|
#include <CreatureLib/Battling/TurnChoices/SwitchTurnChoice.hpp>
|
|
#include <PkmnLib/Battling/Pokemon/LearnedMove.hpp>
|
|
#include <algorithm>
|
|
#include <angelscript.h>
|
|
#include <future>
|
|
#include <thread>
|
|
#include "../cmake-build-release/PkmnLib/src/pkmnlib/src/Battling/Pokemon/PokemonParty.hpp"
|
|
#include "NaiveAI.hpp"
|
|
#include "PokemonAI.hpp"
|
|
|
|
namespace PkmnLibAI {
|
|
class DepthSearchAI : public PokemonAI {
|
|
NaiveAI _naive;
|
|
|
|
private:
|
|
class SimulatedResult {
|
|
public:
|
|
std::optional<float> Score;
|
|
CreatureLib::Battling::CreatureIndex Target;
|
|
|
|
SimulatedResult(std::optional<float> score, CreatureLib::Battling::CreatureIndex target)
|
|
: Score(score), Target(target) {}
|
|
SimulatedResult() {}
|
|
};
|
|
|
|
class ScoredOption {
|
|
public:
|
|
u8 OptionIndex;
|
|
std::optional<float> Score;
|
|
CreatureLib::Battling::CreatureIndex Target;
|
|
|
|
ScoredOption(u8 optionIndex, SimulatedResult result)
|
|
: OptionIndex(optionIndex), Score(result.Score), Target(result.Target) {}
|
|
ScoredOption() {}
|
|
};
|
|
|
|
float ScoreBattle(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user) {
|
|
auto side = user->GetBattleSide().GetValue();
|
|
if (battle->HasEnded()) {
|
|
if (battle->GetResult().GetWinningSide() == side->GetSideIndex()) {
|
|
return std::numeric_limits<float>::max();
|
|
} else {
|
|
return std::numeric_limits<float>::min();
|
|
}
|
|
}
|
|
float score = 0.0;
|
|
auto opposite = GetOppositeIndex(user);
|
|
for (auto* party : battle->GetParties()) {
|
|
if (party->IsResponsibleForIndex(side->GetSideIndex(), 0)) {
|
|
for (auto& mon : party->GetParty()->GetParty()) {
|
|
score += mon->GetCurrentHealth() / (float)mon->GetMaxHealth();
|
|
}
|
|
}
|
|
if (party->IsResponsibleForIndex(opposite)) {
|
|
for (auto& mon : party->GetParty()->GetParty()) {
|
|
score -= (mon->GetCurrentHealth() / (float)mon->GetMaxHealth());
|
|
}
|
|
}
|
|
}
|
|
return score;
|
|
}
|
|
|
|
struct BattlePointerWrapper {
|
|
PkmnLib::Battling::Battle* Battle;
|
|
|
|
BattlePointerWrapper(PkmnLib::Battling::Battle* battle) : Battle(battle) {}
|
|
|
|
inline PkmnLib::Battling::Battle* operator->() const noexcept { return Battle; }
|
|
|
|
~BattlePointerWrapper() {
|
|
std::vector<const CreatureLib::Battling::CreatureParty*> parties;
|
|
for (auto party : Battle->GetParties()) {
|
|
parties.push_back(party->GetParty().GetRaw());
|
|
}
|
|
delete Battle;
|
|
for (auto party : parties) {
|
|
delete party;
|
|
}
|
|
}
|
|
};
|
|
|
|
std::vector<ScoredOption> ScoreChoicesThreaded(PkmnLib::Battling::Battle* battle,
|
|
PkmnLib::Battling::Pokemon* user, uint8_t depth) {
|
|
std::vector<std::future<ScoredOption>> threadPool;
|
|
auto side = user->GetBattleSide().GetValue();
|
|
for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) {
|
|
auto move = user->GetMoves()[moveIndex];
|
|
if (!move.HasValue()) {
|
|
continue;
|
|
}
|
|
if (move.GetValue()->GetRemainingUses() == 0) {
|
|
continue;
|
|
}
|
|
threadPool.push_back(std::async([this, battle, side, moveIndex, depth] {
|
|
auto v = ScoredOption(moveIndex, SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth));
|
|
asThreadCleanup();
|
|
return v;
|
|
}));
|
|
}
|
|
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
|
|
for (u8 i = 0; i < party.Count(); ++i) {
|
|
auto mon = party[i];
|
|
if (!mon.HasValue()) {
|
|
continue;
|
|
}
|
|
if (mon.GetValue()->IsFainted()) {
|
|
continue;
|
|
}
|
|
threadPool.push_back(std::async([this, battle, side, i, depth] {
|
|
auto v = ScoredOption((u8)(i + 4), SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth));
|
|
asThreadCleanup();
|
|
return v;
|
|
}));
|
|
}
|
|
std::vector<ScoredOption> results(threadPool.size());
|
|
for (size_t i = 0; i < threadPool.size(); ++i) {
|
|
results[i] = threadPool[i].get();
|
|
}
|
|
return results;
|
|
}
|
|
|
|
std::vector<ScoredOption> ScoreChoices(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user,
|
|
uint8_t depth) {
|
|
std::vector<ScoredOption> scoredMoves;
|
|
auto side = user->GetBattleSide().GetValue();
|
|
for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) {
|
|
auto move = user->GetMoves()[moveIndex];
|
|
if (!move.HasValue()) {
|
|
continue;
|
|
}
|
|
|
|
if (move.GetValue()->GetRemainingUses() == 0) {
|
|
continue;
|
|
}
|
|
auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth);
|
|
scoredMoves.emplace_back(moveIndex, scored);
|
|
}
|
|
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
|
|
for (size_t i = 0; i < party.Count(); ++i) {
|
|
auto mon = party[i];
|
|
if (!mon.HasValue()) {
|
|
continue;
|
|
}
|
|
if (mon.GetValue()->IsFainted()) {
|
|
continue;
|
|
}
|
|
auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth);
|
|
scoredMoves.emplace_back(i + 4, scored);
|
|
}
|
|
return scoredMoves;
|
|
}
|
|
|
|
std::optional<CreatureLib::Battling::CreatureIndex>
|
|
GetTarget(PkmnLib::Battling::Pokemon* user, ArbUt::BorrowedPtr<const CreatureLib::Library::AttackData> move) {
|
|
switch (move->GetTarget()) {
|
|
|
|
case CreatureLib::Library::AttackTarget::Adjacent: return {};
|
|
case CreatureLib::Library::AttackTarget::AdjacentAlly: return {};
|
|
case CreatureLib::Library::AttackTarget::AdjacentAllySelf: return user->GetBattleIndex();
|
|
case CreatureLib::Library::AttackTarget::AdjacentOpponent: return GetOppositeIndex(user);
|
|
case CreatureLib::Library::AttackTarget::All: return GetOppositeIndex(user);
|
|
case CreatureLib::Library::AttackTarget::AllAdjacent: return GetOppositeIndex(user);
|
|
case CreatureLib::Library::AttackTarget::AllAdjacentOpponent: return GetOppositeIndex(user);
|
|
case CreatureLib::Library::AttackTarget::AllAlly: return user->GetBattleIndex();
|
|
case CreatureLib::Library::AttackTarget::AllOpponent: return GetOppositeIndex(user);
|
|
case CreatureLib::Library::AttackTarget::Any: return GetOppositeIndex(user);
|
|
case CreatureLib::Library::AttackTarget::RandomOpponent: return GetOppositeIndex(user);
|
|
case CreatureLib::Library::AttackTarget::Self: return user->GetBattleIndex();
|
|
}
|
|
return {};
|
|
}
|
|
|
|
SimulatedResult SimulateTurn(PkmnLib::Battling::Battle* originalBattle, u8 sideIndex, u8 pokemonIndex, u8 index,
|
|
u8 depth) {
|
|
auto battle = BattlePointerWrapper(originalBattle->Clone());
|
|
auto user =
|
|
dynamic_cast<PkmnLib::Battling::Pokemon*>(battle->GetCreature(sideIndex, pokemonIndex).GetValue());
|
|
CreatureLib::Battling::CreatureIndex target;
|
|
if (index < 4) {
|
|
auto move = user->GetMoves()[index];
|
|
if (!move.HasValue()) {
|
|
return {};
|
|
}
|
|
if (move.GetValue()->GetRemainingUses() <= 0) {
|
|
return {};
|
|
}
|
|
auto targetOption = GetTarget(user, move.GetValue()->GetAttack());
|
|
if (!targetOption.has_value()) {
|
|
return {};
|
|
}
|
|
target = targetOption.value();
|
|
auto choice = new CreatureLib::Battling::AttackTurnChoice(user, move.GetValue(), target);
|
|
if (!battle->TrySetChoice(choice)) {
|
|
delete choice;
|
|
return {};
|
|
}
|
|
} else {
|
|
auto mon = battle->GetParties()[sideIndex]->GetParty()->GetParty().At(index - 4);
|
|
auto choice = new CreatureLib::Battling::SwitchTurnChoice(user, mon);
|
|
if (!battle->TrySetChoice(choice)) {
|
|
delete choice;
|
|
return {};
|
|
}
|
|
}
|
|
battle->TrySetChoice(_naive.GetChoice(
|
|
battle.Battle,
|
|
dynamic_cast<PkmnLib::Battling::Pokemon*>(battle->GetCreature(GetOppositeIndex(user)).GetValue())));
|
|
|
|
float score;
|
|
if (depth <= 1) {
|
|
score = ScoreBattle(battle.Battle, user);
|
|
} else {
|
|
auto scoredChoices = ScoreChoices(battle.Battle, user, depth - 1);
|
|
float summedScore = 0;
|
|
size_t amount = 0;
|
|
for (auto& option : scoredChoices) {
|
|
auto v = option.Score;
|
|
if (!v.has_value()) {
|
|
continue;
|
|
}
|
|
summedScore += v.value();
|
|
amount++;
|
|
}
|
|
if (amount == 0) {
|
|
return {};
|
|
}
|
|
|
|
score = summedScore / amount;
|
|
}
|
|
return SimulatedResult(score, target);
|
|
}
|
|
|
|
public:
|
|
std::string GetName() const noexcept override { return "depthsearch"; }
|
|
|
|
CreatureLib::Battling::BaseTurnChoice* GetChoice(PkmnLib::Battling::Battle* battle,
|
|
PkmnLib::Battling::Pokemon* user) override {
|
|
auto scoredChoices = ScoreChoicesThreaded(battle, user, 2);
|
|
auto side = user->GetBattleSide().GetValue();
|
|
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
|
|
|
|
if (scoredChoices.empty()) {
|
|
return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, GetOppositeIndex(user));
|
|
}
|
|
|
|
i32 highest = -1;
|
|
float highestScore = -std::numeric_limits<float>::infinity();
|
|
CreatureLib::Battling::CreatureIndex highestTarget;
|
|
|
|
for (auto& option : scoredChoices) {
|
|
auto v = option.Score;
|
|
if (!v.has_value()) {
|
|
continue;
|
|
}
|
|
if (v.value() > highestScore) {
|
|
highestScore = v.value();
|
|
highest = option.OptionIndex;
|
|
highestTarget = option.Target;
|
|
}
|
|
}
|
|
if (highest == -1) {
|
|
return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, GetOppositeIndex(user));
|
|
}
|
|
|
|
if (highest < 4) {
|
|
return new CreatureLib::Battling::AttackTurnChoice(user, user->GetMoves()[highest].GetValue(),
|
|
highestTarget);
|
|
} else {
|
|
return new CreatureLib::Battling::SwitchTurnChoice(user, party.At(highest - 4));
|
|
}
|
|
}
|
|
};
|
|
}
|
|
|
|
#endif // PKMNLIB_AI_DEPTHSEARCHAI_HPP
|