PkmnLib/src/ScriptResolving/WASM/WebAssemblyScriptResolver.hpp

93 lines
3.4 KiB
C++

#ifndef PKMNLIB_WEBASSEMBLYSCRIPTRESOLVER_HPP
#define PKMNLIB_WEBASSEMBLYSCRIPTRESOLVER_HPP
#include <Arbutils/Collections/Dictionary.hpp>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <wasm.h>
#include "../../Battling/Library/ScriptResolver.hpp"
#include "WebAssemblyBattleScript.hpp"
#include "WebAssemblyFunctionCall.hpp"
#include "WebAssemblyScriptCapabilities.hpp"
class WebAssemblyScriptResolver : public PkmnLib::Battling::ScriptResolver {
public:
WebAssemblyScriptResolver();
~WebAssemblyScriptResolver();
u8 LoadWasmFromFile(const std::string& path);
u8 LoadWasmFromBytes(std::vector<u8>);
u8 LoadWatFromString(const std::string& data);
void RegisterFunction();
void Finalize();
template <u32 argsCount, u32 returnsCount>
inline std::optional<WebAssemblyFunctionCall<argsCount, returnsCount>>
GetFunction(const ArbUt::StringView& name) const {
auto res = _exportedFunctions.TryGet(name);
if (!res.has_value()) {
return {};
}
return std::make_optional<WebAssemblyFunctionCall<argsCount, returnsCount>>(
ArbUt::BorrowedPtr<wasm_func_t>(res.value()));
}
std::pair<u8*, i32> AllocateMemory(u32 size, u32 align) const {
auto funcOpt = GetFunction<2, 1>("allocate_mem");
auto& func = funcOpt.value();
func.Loadi32(0, size);
func.Loadi32(1, align);
func.Call();
auto memoryOffset = func.GetResultAsi32();
return std::make_pair(reinterpret_cast<u8*>(wasm_memory_data(_memory) + memoryOffset), memoryOffset);
}
[[nodiscard]] inline wasm_memory_t* GetMemory() const noexcept { return _memory; }
CreatureLib::Battling::BattleScript* LoadScript(const ArbUt::OptionalBorrowedPtr<void>& owner,
ScriptCategory category,
const ArbUt::StringView& scriptName) nullable override;
[[nodiscard]] inline wasm_store_t* GetStore() const noexcept { return _store; }
inline void RemoveRegisteredScript(i32 wasmPtr) { _loadedScripts.Remove(wasmPtr); }
template <typename T>
inline void MarkLoadedPointer(T* ptr){
_loadedPointers.Set((void*)ptr, typeid(T));
}
template <typename T>
inline bool ValidateLoadedPointer(void* ptr){
const auto& opt = _loadedPointers.TryGet(ptr);
return opt.has_value() && opt.value() == typeid(T);
}
private:
wasm_engine_t* _engine;
wasm_store_t* _store;
wasm_module_t* _module = nullptr;
wasm_instance_t* _instance = nullptr;
wasm_memory_t* _memory = nullptr;
ArbUt::Dictionary<std::string, wasm_func_t*> _imports;
wasm_extern_vec_t _exports = {0, nullptr};
ArbUt::Dictionary<ArbUt::StringView, wasm_func_t*> _exportedFunctions;
ArbUt::Dictionary<i32, WebAssemblyBattleScript*> _loadedScripts;
void RegisterDefaultMethods();
typedef std::pair<ScriptCategory, ArbUt::StringView> scriptCapabilitiesKey;
struct pair_hash {
template <class T1, class T2> std::size_t operator()(const std::pair<T1, T2>& pair) const {
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
}
};
std::unordered_map<scriptCapabilitiesKey, std::unordered_set<WebAssemblyScriptCapabilities>, pair_hash>
_scriptCapabilities;
ArbUt::Dictionary<void*, std::type_info> _loadedPointers;
};
#endif // PKMNLIB_WEBASSEMBLYSCRIPTRESOLVER_HPP