PkmnLibAI/src/DepthSearchAI.hpp

280 lines
12 KiB
C++

#ifndef PKMNLIB_AI_DEPTHSEARCHAI_HPP
#define PKMNLIB_AI_DEPTHSEARCHAI_HPP
#include <CreatureLib/Battling/TurnChoices/AttackTurnChoice.hpp>
#include <CreatureLib/Battling/TurnChoices/PassTurnChoice.hpp>
#include <CreatureLib/Battling/TurnChoices/SwitchTurnChoice.hpp>
#include <PkmnLib/Battling/Pokemon/LearnedMove.hpp>
#include <algorithm>
#include <angelscript.h>
#include <future>
#include <thread>
#include "../cmake-build-release/PkmnLib/src/pkmnlib/src/Battling/Pokemon/PokemonParty.hpp"
#include "NaiveAI.hpp"
#include "PokemonAI.hpp"
namespace PkmnLibAI {
class DepthSearchAI : public PokemonAI {
NaiveAI _naive;
private:
class SimulatedResult {
public:
std::optional<float> Score;
CreatureLib::Battling::CreatureIndex Target;
SimulatedResult(std::optional<float> score, CreatureLib::Battling::CreatureIndex target)
: Score(score), Target(target) {}
SimulatedResult() {}
};
class ScoredOption {
public:
u8 OptionIndex;
std::optional<float> Score;
CreatureLib::Battling::CreatureIndex Target;
ScoredOption(u8 optionIndex, SimulatedResult result)
: OptionIndex(optionIndex), Score(result.Score), Target(result.Target) {}
ScoredOption() {}
};
float ScoreBattle(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user) {
auto side = user->GetBattleSide().GetValue();
if (battle->HasEnded()) {
if (battle->GetResult().GetWinningSide() == side->GetSideIndex()) {
return std::numeric_limits<float>::max();
} else {
return std::numeric_limits<float>::min();
}
}
float score = 0.0;
auto opposite = GetOppositeIndex(user);
for (auto* party : battle->GetParties()) {
if (party->IsResponsibleForIndex(side->GetSideIndex(), 0)) {
for (auto& mon : party->GetParty()->GetParty()) {
score += mon->GetCurrentHealth() / (float)mon->GetMaxHealth();
}
}
if (party->IsResponsibleForIndex(opposite)) {
for (auto& mon : party->GetParty()->GetParty()) {
score -= (mon->GetCurrentHealth() / (float)mon->GetMaxHealth());
}
}
}
return score;
}
struct BattlePointerWrapper {
PkmnLib::Battling::Battle* Battle;
BattlePointerWrapper(PkmnLib::Battling::Battle* battle) : Battle(battle) {}
inline PkmnLib::Battling::Battle* operator->() const noexcept { return Battle; }
~BattlePointerWrapper() {
std::vector<const CreatureLib::Battling::CreatureParty*> parties;
for (auto party : Battle->GetParties()) {
parties.push_back(party->GetParty().GetRaw());
}
delete Battle;
for (auto party : parties) {
delete party;
}
}
};
std::vector<ScoredOption> ScoreChoicesThreaded(PkmnLib::Battling::Battle* battle,
PkmnLib::Battling::Pokemon* user, uint8_t depth) {
std::vector<std::future<ScoredOption>> threadPool;
auto side = user->GetBattleSide().GetValue();
for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) {
auto move = user->GetMoves()[moveIndex];
if (!move.HasValue()) {
continue;
}
if (move.GetValue()->GetRemainingUses() == 0) {
continue;
}
threadPool.push_back(std::async([this, battle, side, moveIndex, depth] {
auto v = ScoredOption(moveIndex, SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth));
asThreadCleanup();
return v;
}));
}
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
for (u8 i = 0; i < party.Count(); ++i) {
auto mon = party[i];
if (!mon.HasValue()) {
continue;
}
if (mon.GetValue()->IsFainted()) {
continue;
}
threadPool.push_back(std::async([this, battle, side, i, depth] {
auto v = ScoredOption((u8)(i + 4), SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth));
asThreadCleanup();
return v;
}));
}
std::vector<ScoredOption> results(threadPool.size());
for (size_t i = 0; i < threadPool.size(); ++i) {
results[i] = threadPool[i].get();
}
return results;
}
std::vector<ScoredOption> ScoreChoices(PkmnLib::Battling::Battle* battle, PkmnLib::Battling::Pokemon* user,
uint8_t depth) {
std::vector<ScoredOption> scoredMoves;
auto side = user->GetBattleSide().GetValue();
for (u8 moveIndex = 0; moveIndex < (u8)user->GetMoves().Count(); ++moveIndex) {
auto move = user->GetMoves()[moveIndex];
if (!move.HasValue()) {
continue;
}
if (move.GetValue()->GetRemainingUses() == 0) {
continue;
}
auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, moveIndex, depth);
scoredMoves.emplace_back(moveIndex, scored);
}
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
for (size_t i = 0; i < party.Count(); ++i) {
auto mon = party[i];
if (!mon.HasValue()) {
continue;
}
if (mon.GetValue()->IsFainted()) {
continue;
}
auto scored = SimulateTurn(battle, side->GetSideIndex(), 0, i + 4, depth);
scoredMoves.emplace_back(i + 4, scored);
}
return scoredMoves;
}
std::optional<CreatureLib::Battling::CreatureIndex>
GetTarget(PkmnLib::Battling::Pokemon* user, ArbUt::BorrowedPtr<const CreatureLib::Library::AttackData> move) {
switch (move->GetTarget()) {
case CreatureLib::Library::AttackTarget::Adjacent: return {};
case CreatureLib::Library::AttackTarget::AdjacentAlly: return {};
case CreatureLib::Library::AttackTarget::AdjacentAllySelf: return user->GetBattleIndex();
case CreatureLib::Library::AttackTarget::AdjacentOpponent: return GetOppositeIndex(user);
case CreatureLib::Library::AttackTarget::All: return GetOppositeIndex(user);
case CreatureLib::Library::AttackTarget::AllAdjacent: return GetOppositeIndex(user);
case CreatureLib::Library::AttackTarget::AllAdjacentOpponent: return GetOppositeIndex(user);
case CreatureLib::Library::AttackTarget::AllAlly: return user->GetBattleIndex();
case CreatureLib::Library::AttackTarget::AllOpponent: return GetOppositeIndex(user);
case CreatureLib::Library::AttackTarget::Any: return GetOppositeIndex(user);
case CreatureLib::Library::AttackTarget::RandomOpponent: return GetOppositeIndex(user);
case CreatureLib::Library::AttackTarget::Self: return user->GetBattleIndex();
}
return {};
}
SimulatedResult SimulateTurn(PkmnLib::Battling::Battle* originalBattle, u8 sideIndex, u8 pokemonIndex, u8 index,
u8 depth) {
auto battle = BattlePointerWrapper(originalBattle->Clone());
auto user =
dynamic_cast<PkmnLib::Battling::Pokemon*>(battle->GetCreature(sideIndex, pokemonIndex).GetValue());
CreatureLib::Battling::CreatureIndex target;
if (index < 4) {
auto move = user->GetMoves()[index];
if (!move.HasValue()) {
return {};
}
if (move.GetValue()->GetRemainingUses() <= 0) {
return {};
}
auto targetOption = GetTarget(user, move.GetValue()->GetAttack());
if (!targetOption.has_value()) {
return {};
}
target = targetOption.value();
auto choice = new CreatureLib::Battling::AttackTurnChoice(user, move.GetValue(), target);
if (!battle->TrySetChoice(choice)) {
delete choice;
return {};
}
} else {
auto mon = battle->GetParties()[sideIndex]->GetParty()->GetParty().At(index - 4);
auto choice = new CreatureLib::Battling::SwitchTurnChoice(user, mon);
if (!battle->TrySetChoice(choice)) {
delete choice;
return {};
}
}
battle->TrySetChoice(_naive.GetChoice(
battle.Battle,
dynamic_cast<PkmnLib::Battling::Pokemon*>(battle->GetCreature(GetOppositeIndex(user)).GetValue())));
float score;
if (depth <= 1) {
score = ScoreBattle(battle.Battle, user);
} else {
auto scoredChoices = ScoreChoices(battle.Battle, user, depth - 1);
float summedScore = 0;
size_t amount = 0;
for (auto& option : scoredChoices) {
auto v = option.Score;
if (!v.has_value()) {
continue;
}
summedScore += v.value();
amount++;
}
if (amount == 0) {
return {};
}
score = summedScore / amount;
}
return SimulatedResult(score, target);
}
public:
std::string GetName() const noexcept override { return "depthsearch"; }
CreatureLib::Battling::BaseTurnChoice* GetChoice(PkmnLib::Battling::Battle* battle,
PkmnLib::Battling::Pokemon* user) override {
auto scoredChoices = ScoreChoicesThreaded(battle, user, 2);
auto side = user->GetBattleSide().GetValue();
auto& party = battle->GetParties()[side->GetSideIndex()]->GetParty()->GetParty();
if (scoredChoices.empty()) {
return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, GetOppositeIndex(user));
}
i32 highest = -1;
float highestScore = -std::numeric_limits<float>::infinity();
CreatureLib::Battling::CreatureIndex highestTarget;
for (auto& option : scoredChoices) {
auto v = option.Score;
if (!v.has_value()) {
continue;
}
if (v.value() > highestScore) {
highestScore = v.value();
highest = option.OptionIndex;
highestTarget = option.Target;
}
}
if (highest == -1) {
return battle->GetLibrary()->GetMiscLibrary()->ReplacementAttack(user, GetOppositeIndex(user));
}
if (highest < 4) {
return new CreatureLib::Battling::AttackTurnChoice(user, user->GetMoves()[highest].GetValue(),
highestTarget);
} else {
return new CreatureLib::Battling::SwitchTurnChoice(user, party.At(highest - 4));
}
}
};
}
#endif // PKMNLIB_AI_DEPTHSEARCHAI_HPP