PkmnLib/src/ScriptResolving/WASM/WebAssemblyScriptResolver.cpp

176 lines
6.3 KiB
C++

#include "WebAssemblyScriptResolver.hpp"
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <wasmer.h>
#include "InterfaceMethods/CoreMethods.hpp"
#include "InterfaceMethods/Library/LibraryMethods.hpp"
#include "WebAssemblyBattleScript.hpp"
#include "WebAssemblyFunctionCall.hpp"
#include "wasm.h"
// Factory hook: this build of the library resolves scripts through WebAssembly.
// Returns a heap-allocated resolver; presumably the caller takes ownership — confirm against BattleLibrary.
PkmnLib::Battling::ScriptResolver* PkmnLib::Battling::BattleLibrary::CreateScriptResolver() {
    auto* resolver = new WebAssemblyScriptResolver();
    return resolver;
}
// Creates a fresh Wasmer engine and a store bound to that engine.
// NOTE(review): _store is initialised from _engine, which is only valid if _engine is declared before _store
// in the class definition — confirm in WebAssemblyScriptResolver.hpp.
WebAssemblyScriptResolver::WebAssemblyScriptResolver()
    : _engine(wasm_engine_new()),
      _store(wasm_store_new(_engine)) {}
// Releases every WASM object held by the resolver, in the same order as before:
// exports, registered host functions, instance, module, store, and finally the engine.
WebAssemblyScriptResolver::~WebAssemblyScriptResolver() {
    // The exports vector is filled by wasm_instance_exports() in Finalize().
    wasm_extern_vec_delete(&_exports);

    // Each function registered in _imports is deleted here, so the resolver owns them.
    for (auto& entry : _imports) {
        wasm_func_delete(entry.second);
    }

    // Instance and module only exist once loading/finalisation has happened.
    if (_instance) {
        wasm_instance_delete(_instance);
    }
    if (_module) {
        wasm_module_delete(_module);
    }

    // The store was created from _engine in the constructor, so it is released first.
    wasm_store_delete(_store);
    wasm_engine_delete(_engine);
}
// Reads a binary .wasm file from disk and compiles it into _module.
// Returns:
//   0 on success,
//   1 if the file could not be opened,
//   2 if its size could not be determined or its contents could not be read,
//   3 if wasm_module_new rejected the bytes.
// The file handle and the temporary byte vector are released on every path.
u8 WebAssemblyScriptResolver::LoadWasmFromFile(const std::string& path) {
    FILE* file = fopen(path.c_str(), "rb");
    if (file == nullptr) {
        return 1;
    }

    // Determine the file size; ftell reports failure as -1, which must not be turned into a size_t.
    if (fseek(file, 0L, SEEK_END) != 0) {
        fclose(file);
        return 2;
    }
    const long fileSize = ftell(file);
    if (fileSize < 0 || fseek(file, 0L, SEEK_SET) != 0) {
        fclose(file);
        return 2;
    }

    const auto byteCount = static_cast<size_t>(fileSize);
    wasm_byte_vec_t wasm_bytes;
    wasm_byte_vec_new_uninitialized(&wasm_bytes, byteCount);
    const bool readOk = fread(wasm_bytes.data, byteCount, 1, file) == 1;
    // Close the file before checking the read result so the handle is not leaked on failure.
    fclose(file);
    if (!readOk) {
        wasm_byte_vec_delete(&wasm_bytes);
        return 2;
    }

    _module = wasm_module_new(_store, &wasm_bytes);
    wasm_byte_vec_delete(&wasm_bytes);
    return _module == nullptr ? 3 : 0;
}
// Compiles an in-memory .wasm binary into _module.
// Returns 0 on success and 3 if wasm_module_new rejected the bytes.
// The byte vector is only borrowed for the call: the std::vector keeps owning the storage,
// so no wasm_byte_vec_delete is performed here.
u8 WebAssemblyScriptResolver::LoadWasmFromBytes(std::vector<u8> wasm_bytes) {
    wasm_byte_vec_t binary = {wasm_bytes.size(), reinterpret_cast<char*>(wasm_bytes.data())};
    _module = wasm_module_new(_store, &binary);
    return _module != nullptr ? 0 : 3;
}
// Converts WebAssembly text format (WAT) to a binary with wat2wasm and compiles it into _module.
// Returns 0 on success and 3 if the resulting bytes could not be compiled.
// Both temporary byte vectors are released on every path.
u8 WebAssemblyScriptResolver::LoadWatFromString(const std::string& data) {
    wasm_byte_vec_t wat;
    wasm_byte_vec_new(&wat, data.size(), data.c_str());

    // Start from an empty vector so the later delete is safe even if wat2wasm leaves it untouched.
    wasm_byte_vec_t wasm_bytes = {};
    wat2wasm(&wat, &wasm_bytes);
    wasm_byte_vec_delete(&wat);

    _module = wasm_module_new(_store, &wasm_bytes);
    // Free the compiled bytes before reporting the result; previously they leaked on failure.
    wasm_byte_vec_delete(&wasm_bytes);
    return _module == nullptr ? 3 : 0;
}
// Currently a no-op. The default host imports are registered by RegisterDefaultMethods().
void WebAssemblyScriptResolver::RegisterFunction() {
}
// Populates _imports with the built-in host functions: first the core methods, then the library methods.
// Finalize() calls this before it resolves the module's imports by name against _imports.
// Order is kept as-is in case later registrations are meant to override earlier names — TODO confirm.
void WebAssemblyScriptResolver::RegisterDefaultMethods() {
WebAssemblyCoreMethods::Register(_store, _imports, this);
LibraryMethods::Register(_store, _imports, this);
}
void WebAssemblyScriptResolver::Finalize() {
RegisterDefaultMethods();
auto imports = ArbUt::List<wasm_extern_t*>();
wasm_importtype_vec_t import_types;
wasm_module_imports(_module, &import_types);
for (size_t i = 0; i < import_types.size; ++i) {
auto importType = import_types.data[i];
auto nameWasm = wasm_importtype_name(importType);
auto name = std::string(nameWasm->data, nameWasm->size);
auto exportFunc = _imports.TryGet(name);
if (!exportFunc.has_value()) {
THROW("Missing imported WASM function: ", name);
}
imports.Append(wasm_func_as_extern(exportFunc.value()));
}
wasm_extern_vec_t import_object = {imports.Count(), const_cast<wasm_extern_t**>(imports.RawData())};
wasm_trap_t* trap = nullptr;
_instance = wasm_instance_new(_store, _module, &import_object, &trap);
wasm_importtype_vec_delete(&import_types);
if (_instance == nullptr) {
char* err = new char[wasmer_last_error_length()];
wasmer_last_error_message(err, wasmer_last_error_length());
std::cout << err << std::endl;
delete[] err;
}
EnsureNotNull(_instance);
wasm_exporttype_vec_t export_types;
wasm_module_exports(_module, &export_types);
wasm_instance_exports(_instance, &_exports);
for (size_t i = 0; i < export_types.size; ++i) {
auto t = wasm_externtype_kind(wasm_exporttype_type(export_types.data[i]));
if (t == WASM_EXTERN_FUNC) {
const auto* name = wasm_exporttype_name(export_types.data[i]);
_exportedFunctions.Insert(ArbUt::StringView(name->data, name->size), wasm_extern_as_func(_exports.data[i]));
} else if (t == WASM_EXTERN_MEMORY) {
_memory = wasm_extern_as_memory(_exports.data[i]);
}
}
wasm_exporttype_vec_delete(&export_types);
if (_memory != nullptr) {
wasm_memory_grow(_memory, 100);
}
}
// Asks the WASM module to instantiate the script `scriptName` of `category`.
// Returns nullptr when the module does not export `load_script` or `load_script` returns 0.
// Otherwise it returns a new WebAssemblyBattleScript wrapping the handle returned by the module,
// and records the script in _loadedScripts.
// The script's capabilities are fetched once per (category, name) via `get_script_capabilities` and cached
// in _scriptCapabilities. The script holds a pointer into that cache.
// The cache is presumed to be a node-based map with stable element addresses — confirm in the header.
// Throws (via THROW) if `get_script_capabilities` returns a buffer outside the module's linear memory.
CreatureLib::Battling::BattleScript*
WebAssemblyScriptResolver::LoadScript(const ArbUt::OptionalBorrowedPtr<void>& owner, ScriptCategory category,
                                      const ArbUt::StringView& scriptName) {
    auto loadScriptOpt = GetFunction<2, 1>("load_script"_cnc);
    if (!loadScriptOpt.has_value()) {
        return nullptr;
    }
    auto& loadScriptFunc = loadScriptOpt.value();
    loadScriptFunc.Loadi32(0, static_cast<i32>(category));
    loadScriptFunc.LoadExternRef(1, &scriptName);
    loadScriptFunc.Call();
    auto result = loadScriptFunc.GetResultAsi32();
    if (result == 0) {
        return nullptr;
    }

    auto key = std::pair<ScriptCategory, ArbUt::StringView>(category, scriptName);
    std::unordered_set<WebAssemblyScriptCapabilities>* capabilities = nullptr;
    auto findCapabilities = _scriptCapabilities.find(key);
    if (findCapabilities != _scriptCapabilities.end()) {
        // Reuse the cached entry directly instead of copying it and looking it up again.
        capabilities = &findCapabilities->second;
    } else {
        std::unordered_set<WebAssemblyScriptCapabilities> fetched;
        auto getCapabilitiesOpt = GetFunction<1, 2>("get_script_capabilities"_cnc);
        if (getCapabilitiesOpt.has_value()) {
            auto& getCapabilitiesFunc = getCapabilitiesOpt.value();
            getCapabilitiesFunc.Loadi32(0, result);
            getCapabilitiesFunc.Call();
            const auto* rawResult = getCapabilitiesFunc.GetRawResults();
            // Results are (pointer, element count) into linear memory. WASM addresses are unsigned 32-bit,
            // so reinterpret the i32 bits rather than sign-extending them.
            const auto offset = static_cast<size_t>(static_cast<std::uint32_t>(rawResult[0].of.i32));
            const auto count = static_cast<size_t>(static_cast<std::uint32_t>(rawResult[1].of.i32));
            if (count > 0) {
                const size_t memorySize = _memory != nullptr ? wasm_memory_data_size(_memory) : 0;
                if (offset > memorySize ||
                    count > (memorySize - offset) / sizeof(WebAssemblyScriptCapabilities)) {
                    THROW("get_script_capabilities returned an out-of-bounds buffer for script category ",
                          static_cast<i32>(category));
                }
                const auto* begin =
                    reinterpret_cast<const WebAssemblyScriptCapabilities*>(wasm_memory_data(_memory) + offset);
                fetched.insert(begin, begin + count);
            }
        }
        // Cache the result, including an empty set when the module exposes no capabilities function.
        auto& slot = _scriptCapabilities[key];
        slot = std::move(fetched);
        capabilities = &slot;
    }

    auto script = new WebAssemblyBattleScript(owner, result, capabilities, this, scriptName);
    _loadedScripts.Insert(result, script);
    return script;
}