Cleans up the WASM environment handling.

Instead of passing the entire script resolver as a pointer to the WebAssemblyEnv, we have now split off the data actually relevant to the environment into its own class, which is stored inside an Arc.
This commit is contained in:
2022-07-18 13:18:11 +02:00
parent 9472c1cec2
commit 0961b199ff
9 changed files with 178 additions and 121 deletions

View File

@@ -56,7 +56,7 @@ fn _error(env: &WebAssemblyEnv, message: u32, message_len: u32, file: u32, file_
} }
fn move_library_get_move_by_hash(env: &WebAssemblyEnv, lib: ExternRef<MoveLibrary>, hash: u32) -> ExternRef<MoveData> { fn move_library_get_move_by_hash(env: &WebAssemblyEnv, lib: ExternRef<MoveLibrary>, hash: u32) -> ExternRef<MoveData> {
let lib = lib.value(env); let lib = lib.value(env).unwrap();
let m = lib.get_by_hash(hash); let m = lib.get_by_hash(hash);
if let Some(v) = m { if let Some(v) = m {
ExternRef::new(env, v) ExternRef::new(env, v)
@@ -66,20 +66,20 @@ fn move_library_get_move_by_hash(env: &WebAssemblyEnv, lib: ExternRef<MoveLibrar
} }
fn move_data_get_name(env: &WebAssemblyEnv, move_data: ExternRef<MoveData>) -> ExternRef<StringKey> { fn move_data_get_name(env: &WebAssemblyEnv, move_data: ExternRef<MoveData>) -> ExternRef<StringKey> {
let move_data = move_data.value(env); let move_data = move_data.value(env).unwrap();
ExternRef::new(env, move_data.name()) ExternRef::new(env, move_data.name())
} }
fn move_data_get_base_power(env: &WebAssemblyEnv, move_data: ExternRef<MoveData>) -> u8 { fn move_data_get_base_power(env: &WebAssemblyEnv, move_data: ExternRef<MoveData>) -> u8 {
move_data.value(env).base_power() move_data.value(env).unwrap().base_power()
} }
fn const_string_get_hash(env: &WebAssemblyEnv, string_key: ExternRef<StringKey>) -> u32 { fn const_string_get_hash(env: &WebAssemblyEnv, string_key: ExternRef<StringKey>) -> u32 {
string_key.value(env).hash() string_key.value(env).unwrap().hash()
} }
fn const_string_get_str(env: &WebAssemblyEnv, string_key: ExternRef<StringKey>) -> u32 { fn const_string_get_str(env: &WebAssemblyEnv, string_key: ExternRef<StringKey>) -> u32 {
let string_key = string_key.value(env).str(); let string_key = string_key.value(env).unwrap().str();
let s: CString = CString::new(string_key.as_bytes()).unwrap(); let s: CString = CString::new(string_key.as_bytes()).unwrap();
let wasm_string_ptr = env let wasm_string_ptr = env
.resolver() .resolver()
@@ -93,9 +93,9 @@ fn battle_library_get_data_library(
env: &WebAssemblyEnv, env: &WebAssemblyEnv,
dynamic_lib: ExternRef<DynamicLibrary>, dynamic_lib: ExternRef<DynamicLibrary>,
) -> ExternRef<StaticData> { ) -> ExternRef<StaticData> {
ExternRef::new(env, dynamic_lib.value(env).static_data()) ExternRef::new(env, dynamic_lib.value(env).unwrap().static_data())
} }
fn data_library_get_move_library(env: &WebAssemblyEnv, data_library: ExternRef<StaticData>) -> ExternRef<MoveLibrary> { fn data_library_get_move_library(env: &WebAssemblyEnv, data_library: ExternRef<StaticData>) -> ExternRef<MoveLibrary> {
ExternRef::new(env, data_library.value(env).moves()) ExternRef::new(env, data_library.value(env).unwrap().moves())
} }

View File

@@ -3,7 +3,7 @@ use std::marker::PhantomData;
use unique_type_id::UniqueTypeId; use unique_type_id::UniqueTypeId;
use wasmer::FromToNativeWasmType; use wasmer::FromToNativeWasmType;
use crate::script_implementations::wasm::script_resolver::WebAssemblyEnv; use crate::script_implementations::wasm::script_resolver::{WebAssemblyEnv, WebAssemblyScriptResolver};
pub(crate) struct ExternRef<T: UniqueTypeId<u64>> { pub(crate) struct ExternRef<T: UniqueTypeId<u64>> {
index: u32, index: u32,
@@ -18,13 +18,15 @@ impl<T: UniqueTypeId<u64>> ExternRef<T> {
} }
} }
pub fn from_index(index: u32) -> Self { /// Creates an ExternRef with a given resolver. This can be used in cases where we do not have an environment variable.
pub(crate) fn new_with_resolver(resolver: &WebAssemblyScriptResolver, value: &T) -> Self {
Self { Self {
index, index: resolver.environment_data().get_extern_ref_index(value),
_phantom: Default::default(), _phantom: Default::default(),
} }
} }
/// An empty value ExternRef.
pub fn null() -> Self { pub fn null() -> Self {
Self { Self {
index: 0, index: 0,
@@ -32,9 +34,11 @@ impl<T: UniqueTypeId<u64>> ExternRef<T> {
} }
} }
pub fn value<'a, 'b>(&'a self, env: &'b WebAssemblyEnv) -> &'b T { /// Returns the real value for a given ExternRef. Note that the requested type must be the same as the type of the
/// value when it was passed before. If these types do not match, this will panic.
pub fn value<'a, 'b>(&'a self, env: &'b WebAssemblyEnv) -> Option<&'b T> {
let ptr = env.resolver().get_extern_ref_value(self.index) as *const T; let ptr = env.resolver().get_extern_ref_value(self.index) as *const T;
unsafe { ptr.as_ref().unwrap() } unsafe { ptr.as_ref() }
} }
} }

View File

@@ -10,6 +10,7 @@ pub mod script_resolver;
/// us to not call a function if we do not need to. /// us to not call a function if we do not need to.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
#[allow(missing_docs)] #[allow(missing_docs)]
#[allow(clippy::missing_docs_in_private_items)]
pub enum WebAssemblyScriptCapabilities { pub enum WebAssemblyScriptCapabilities {
None = 0, None = 0,
Initialize = 1, Initialize = 1,

View File

@@ -1,12 +1,14 @@
use std::fmt::{Debug, Formatter}; use std::fmt::{Debug, Formatter};
use std::ops::DerefMut;
use std::sync::Arc; use std::sync::Arc;
use hashbrown::{HashMap, HashSet}; use hashbrown::{HashMap, HashSet};
use parking_lot::RwLock; use parking_lot::lock_api::{MappedRwLockReadGuard, RwLockReadGuard};
use unique_type_id::UniqueTypeId; use parking_lot::{RawRwLock, RwLock};
use unique_type_id::{TypeId, UniqueTypeId};
use wasmer::{ use wasmer::{
Cranelift, Exports, Extern, Features, Function, ImportObject, Instance, Memory, Module, NativeFunc, Store, Cranelift, Exports, Extern, Features, Function, ImportObject, Instance, Memory, Module, NativeFunc, Store,
Universal, UniversalEngine, Value, WasmerEnv, Universal, Value, WasmerEnv,
}; };
use crate::dynamic_data::{ItemScript, Script, ScriptResolver}; use crate::dynamic_data::{ItemScript, Script, ScriptResolver};
@@ -19,26 +21,32 @@ use crate::{PkmnResult, ScriptCategory, StringKey};
/// A WebAssembly script resolver implements the dynamic scripts functionality with WebAssembly. /// A WebAssembly script resolver implements the dynamic scripts functionality with WebAssembly.
pub struct WebAssemblyScriptResolver { pub struct WebAssemblyScriptResolver {
engine: UniversalEngine, /// The global state storage of WASM.
store: Store, store: Store,
module: Option<Module>, /// The WASM modules we have loaded.
instance: Option<Instance>, modules: Vec<Module>,
memory: Option<Memory>, /// Our currently loaded WASM instances. Empty until finalize() is called, after which the loaded modules get turned
imports: HashMap<String, Function>, /// into actual instances.
exports: Exports, instances: Vec<Instance>,
/// This is a map of all the functions that WASM gives us.
exported_functions: HashMap<StringKey, Function>, exported_functions: HashMap<StringKey, Function>,
/// This is the WASM function to load a script.
load_script_fn: Option<NativeFunc<(u8, ExternRef<StringKey>), u32>>, load_script_fn: Option<NativeFunc<(u8, ExternRef<StringKey>), u32>>,
allocate_mem_fn: Option<NativeFunc<(u32, u32), u32>>,
/// Script capabilities tell us which functions are implemented on a given script. This allows us to skip unneeded
/// WASM calls.
script_capabilities: RwLock<HashMap<ScriptCapabilitiesKey, HashSet<WebAssemblyScriptCapabilities>>>, script_capabilities: RwLock<HashMap<ScriptCapabilitiesKey, HashSet<WebAssemblyScriptCapabilities>>>,
extern_ref_pointers: RwLock<Vec<*const u8>>, environment_data: Arc<WebAssemblyEnvironmentData>,
extern_ref_pointers_lookup: RwLock<HashMap<*const u8, u32>>,
extern_ref_type_lookup: RwLock<HashMap<*const u8, u64>>,
} }
/// This struct allows us to index a hashmap with both a category and name of a script.
#[derive(Debug, Clone, Eq, PartialEq, Hash)] #[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct ScriptCapabilitiesKey { struct ScriptCapabilitiesKey {
/// The category for the script we're looking for capabilities for.
category: ScriptCategory, category: ScriptCategory,
/// The name of the script we're looking for capabilities for.
script_key: StringKey, script_key: StringKey,
} }
@@ -53,20 +61,13 @@ impl WebAssemblyScriptResolver {
let engine = universal.engine(); let engine = universal.engine();
let store = Store::new(&engine); let store = Store::new(&engine);
let s = Self { let s = Self {
engine,
store, store,
module: Default::default(), modules: Default::default(),
instance: Default::default(), instances: Default::default(),
memory: Default::default(),
imports: Default::default(),
exports: Default::default(),
exported_functions: Default::default(), exported_functions: Default::default(),
load_script_fn: None, load_script_fn: None,
allocate_mem_fn: None,
script_capabilities: Default::default(), script_capabilities: Default::default(),
extern_ref_pointers: Default::default(), environment_data: Arc::new(Default::default()),
extern_ref_pointers_lookup: Default::default(),
extern_ref_type_lookup: Default::default(),
}; };
Box::new(s) Box::new(s)
} }
@@ -75,97 +76,53 @@ impl WebAssemblyScriptResolver {
pub fn load_wasm_from_bytes(&mut self, bytes: &[u8]) { pub fn load_wasm_from_bytes(&mut self, bytes: &[u8]) {
// FIXME: Error handling // FIXME: Error handling
let module = Module::new(&self.store, bytes).unwrap(); let module = Module::new(&self.store, bytes).unwrap();
self.module = Some(module); self.modules.push(module);
self.finalize();
} }
/// Initialise all the data we need. /// Initialise all the data we need.
fn finalize(&mut self) { pub fn finalize(&mut self) {
let mut imports = ImportObject::new(); let mut imports = ImportObject::new();
let mut exports = Exports::new(); let mut exports = Exports::new();
let env = WebAssemblyEnv { let env = WebAssemblyEnv {
resolver: self as *const WebAssemblyScriptResolver, resolver: self.environment_data.clone(),
}; };
register_webassembly_funcs(&mut exports, &self.store, env); register_webassembly_funcs(&mut exports, &self.store, env);
imports.register("env", exports); imports.register("env", exports);
self.instance = Some(Instance::new(&self.module.as_ref().unwrap(), &imports).unwrap()); for module in &self.modules {
let exports = &self.instance.as_ref().unwrap().exports; let instance = Instance::new(module, &imports).unwrap();
for export in exports.iter() { let exports = &instance.exports;
match export.1 { for export in exports.iter() {
Extern::Function(f) => { match export.1 {
self.exported_functions.insert(export.0.as_str().into(), f.clone()); Extern::Function(f) => {
self.exported_functions.insert(export.0.as_str().into(), f.clone());
}
Extern::Memory(m) => {
self.environment_data.memory.write().insert(m.clone());
}
_ => {}
} }
Extern::Memory(m) => {
self.memory = Some(m.clone());
}
_ => {}
} }
} if let Some(m) = &self.environment_data.memory.read().as_ref() {
if let Some(m) = &self.memory { m.grow(32).unwrap();
m.grow(32).unwrap(); }
} if let Some(f) = self.exported_functions.get(&"load_script".into()) {
if let Some(f) = self.exported_functions.get(&"load_script".into()) { self.load_script_fn = Some(f.native().unwrap())
self.load_script_fn = Some(f.native().unwrap()) }
} if let Some(f) = self.exported_functions.get(&"allocate_mem".into()) {
if let Some(f) = self.exported_functions.get(&"allocate_mem".into()) { self.environment_data
self.allocate_mem_fn = Some(f.native().unwrap()) .allocate_mem_fn
.write()
.insert(f.native().unwrap());
}
self.instances.push(instance);
} }
} }
/// Gets the internal WASM memory. /// Gets the data passed to every function as environment data.
pub fn memory(&self) -> &Memory { pub fn environment_data(&self) -> &Arc<WebAssemblyEnvironmentData> {
self.memory.as_ref().unwrap() &self.environment_data
}
/// Get a numeric value from any given value. This is not a true Extern Ref from WASM, as this
/// is not supported by our current WASM platform (Rust). Instead, this is simply a way to not
/// have to send arbitrary pointer values back and forth with WASM. Only values WASM can actually
/// access can be touched through this, and we ensure the value is the correct type. In the future,
/// when extern refs get actually properly implemented at compile time we might want to get rid
/// of this code.
pub fn get_extern_ref_index<T: UniqueTypeId<u64>>(&self, value: &T) -> u32 {
let ptr = value as *const T as *const u8;
if let Some(v) = self.extern_ref_pointers_lookup.read().get(&ptr) {
return *v as u32;
}
let index = {
let mut extern_ref_guard = self.extern_ref_pointers.write();
extern_ref_guard.push(ptr);
extern_ref_guard.len() as u32
};
self.extern_ref_pointers_lookup.write().insert(ptr, index);
self.extern_ref_type_lookup.write().insert(ptr, T::id().0);
index
}
/// Gets a value from the extern ref lookup. This turns an earlier registered index back into
/// its proper value, validates its type, and returns the value.
pub fn get_extern_ref_value<T: UniqueTypeId<u64>>(&self, index: u32) -> &T {
let read_guard = self.extern_ref_pointers.read();
let ptr = read_guard.get((index - 1) as usize).unwrap();
let expected_type_id = self.extern_ref_type_lookup.read()[&ptr];
if expected_type_id != T::id().0 {
panic!("Extern ref was accessed with wrong type");
}
unsafe { (*ptr as *const T).as_ref().unwrap() }
}
/// Allocates memory inside the WASM container with a given size and alignment. This memory is
/// owned by WASM, and is how we can pass memory references that the host allocated to WASM.
/// The return is a tuple containing both the actual pointer to the memory (usable by the host),
/// and the WASM offset to the memory (usable by the client).
pub fn allocate_mem(&self, size: u32, align: u32) -> (*const u8, u32) {
let wasm_ptr = self.allocate_mem_fn.as_ref().unwrap().call(size, align).unwrap();
unsafe {
(
self.memory.as_ref().unwrap().data_ptr().offset(wasm_ptr as isize),
wasm_ptr,
)
}
} }
} }
@@ -180,10 +137,7 @@ impl ScriptResolver for WebAssemblyScriptResolver {
.load_script_fn .load_script_fn
.as_ref() .as_ref()
.unwrap() .unwrap()
.call( .call(category as u8, ExternRef::new_with_resolver(self, script_key))
category as u8,
ExternRef::from_index(self.get_extern_ref_index(script_key)),
)
.unwrap(); .unwrap();
if script == 0 { if script == 0 {
return Ok(None); return Ok(None);
@@ -199,11 +153,12 @@ impl ScriptResolver for WebAssemblyScriptResolver {
unsafe { unsafe {
if let Some(get_cap) = self.exported_functions.get(&"get_script_capabilities".into()) { if let Some(get_cap) = self.exported_functions.get(&"get_script_capabilities".into()) {
let res = get_cap.call(&[Value::I32(script as i32)]).unwrap(); let res = get_cap.call(&[Value::I32(script as i32)]).unwrap();
let ptr = (self.memory.as_ref().unwrap().data_ptr() as *const WebAssemblyScriptCapabilities) let ptr = (self.environment_data.memory.read().as_ref().unwrap().data_ptr()
as *const WebAssemblyScriptCapabilities)
.offset(res[0].i32().unwrap() as isize); .offset(res[0].i32().unwrap() as isize);
let length = res[1].i32().unwrap() as usize; let length = res[1].i32().unwrap() as usize;
for i in 0..length { for i in 0..length {
capabilities.insert(*ptr.offset(i as isize)); capabilities.insert(*ptr.add(i));
} }
} }
} }
@@ -234,14 +189,104 @@ impl Debug for WebAssemblyScriptResolver {
} }
} }
/// This data is what is passed to every function that requires access to the global runtime context.
#[derive(Default)]
pub struct WebAssemblyEnvironmentData {
/// We currently have a hacky implementation of extern refs while we're waiting for ExternRef support to hit the
/// wasm32-unknown-unknown target of Rust. As we don't want to pass raw memory pointers to WASM for security reasons,
/// we instead keep track of all the data we've sent to WASM, and pass the ID of that data to WASM. This allows us
/// to only operate on data we know WASM owns. We currently store this data in this continuous Vec, and give the index
/// of the data as the ID.
extern_ref_pointers: RwLock<Vec<*const u8>>,
/// To make sure we send the same identifier to WASM when we send the same piece of data multiple times, we have a
/// backwards lookup on extern_ref_pointers. This allows us to get the index for a given piece of data.
extern_ref_pointers_lookup: RwLock<HashMap<*const u8, u32>>,
/// As an added security measure on our extern refs, we keep track of the types of the extern ref data we've sent.
/// This prevents illegal arbitrary memory operations, where we expect type X, but the actual type is Y, which would
/// allow for modifying memory we might not want to. If we get a type mismatch, we will panic, preventing this.
extern_ref_type_lookup: RwLock<HashMap<*const u8, TypeId<u64>>>,
/// The memory inside of the WASM container.
memory: RwLock<Option<Memory>>,
/// This is the WASM function to allocate memory inside the WASM container.
allocate_mem_fn: RwLock<Option<NativeFunc<(u32, u32), u32>>>,
}
impl WebAssemblyEnvironmentData {
/// This returns the memory of the WASM container.
pub fn memory(&self) -> MappedRwLockReadGuard<'_, RawRwLock, Memory> {
RwLockReadGuard::map(self.memory.read(), |a| a.as_ref().unwrap())
}
/// Allocates memory inside the WASM container with a given size and alignment. This memory is
/// owned by WASM, and is how we can pass memory references that the host allocated to WASM.
/// The return is a tuple containing both the actual pointer to the memory (usable by the host),
/// and the WASM offset to the memory (usable by the client).
pub fn allocate_mem(&self, size: u32, align: u32) -> (*const u8, u32) {
let wasm_ptr = self.allocate_mem_fn.read().as_ref().unwrap().call(size, align).unwrap();
unsafe {
(
self.memory
.read()
.as_ref()
.unwrap()
.data_ptr()
.offset(wasm_ptr as isize),
wasm_ptr,
)
}
}
/// Get a numeric value from any given value. This is not a true Extern Ref from WASM, as this
/// is not supported by our current WASM platform (Rust). Instead, this is simply a way to not
/// have to send arbitrary pointer values back and forth with WASM. Only values WASM can actually
/// access can be touched through this, and we ensure the value is the correct type. In the future,
/// when extern refs get actually properly implemented at compile time we might want to get rid
/// of this code.
pub fn get_extern_ref_index<T: UniqueTypeId<u64>>(&self, value: &T) -> u32 {
let ptr = value as *const T as *const u8;
if let Some(v) = self.extern_ref_pointers_lookup.read().get(&ptr) {
return *v as u32;
}
let index = {
let mut extern_ref_guard = self.extern_ref_pointers.write();
extern_ref_guard.push(ptr);
extern_ref_guard.len() as u32
};
self.extern_ref_pointers_lookup.write().insert(ptr, index);
self.extern_ref_type_lookup.write().insert(ptr, T::id());
index
}
/// Gets a value from the extern ref lookup. This turns an earlier registered index back into
/// its proper value, validates its type, and returns the value.
pub fn get_extern_ref_value<T: UniqueTypeId<u64>>(&self, index: u32) -> &T {
let read_guard = self.extern_ref_pointers.read();
let ptr = read_guard.get((index - 1) as usize).unwrap();
let expected_type_id = &self.extern_ref_type_lookup.read()[ptr];
if expected_type_id.0 != T::id().0 {
panic!(
"Extern ref was accessed with wrong type. Requested type {}, but this was not the type the extern ref was stored with.",
std::any::type_name::<T>()
);
}
unsafe { (*ptr as *const T).as_ref().unwrap() }
}
}
/// The runtime environment for script execution. This is passed to most of the host functions being called.
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct WebAssemblyEnv { pub(crate) struct WebAssemblyEnv {
pub resolver: *const WebAssemblyScriptResolver, /// A pointer to the WebAssemblyScriptResolver belonging to the current script environment.
pub resolver: Arc<WebAssemblyEnvironmentData>,
} }
impl WebAssemblyEnv { impl WebAssemblyEnv {
pub fn resolver(&self) -> &WebAssemblyScriptResolver { /// Get the WebAssemblyScriptResolver belonging to the current context.
unsafe { self.resolver.as_ref().unwrap() } pub fn resolver(&self) -> &Arc<WebAssemblyEnvironmentData> {
&self.resolver
} }
} }

View File

@@ -28,6 +28,7 @@ pub struct Species {
/// A cached String Key to get the default form. /// A cached String Key to get the default form.
static DEFAULT_KEY: conquer_once::OnceCell<StringKey> = conquer_once::OnceCell::uninit(); static DEFAULT_KEY: conquer_once::OnceCell<StringKey> = conquer_once::OnceCell::uninit();
/// Gets the StringKey for "default". Initialises it if it does not exist.
fn get_default_key() -> StringKey { fn get_default_key() -> StringKey {
DEFAULT_KEY.get_or_init(|| StringKey::new("default")).clone() DEFAULT_KEY.get_or_init(|| StringKey::new("default")).clone()
} }

View File

@@ -1,5 +1,6 @@
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::{Arc, Mutex, Weak}; use std::sync::{Arc, Mutex, Weak};
use conquer_once::OnceCell; use conquer_once::OnceCell;
@@ -110,7 +111,7 @@ impl Equivalent<StringKey> for u32 {
impl Display for StringKey { impl Display for StringKey {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(&*self.str) f.write_str(self.str.deref())
} }
} }

View File

@@ -281,6 +281,7 @@ fn load_wasm(path: &String, library: &mut WebAssemblyScriptResolver) {
let mut buffer = Vec::new(); let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).unwrap(); reader.read_to_end(&mut buffer).unwrap();
library.load_wasm_from_bytes(&buffer); library.load_wasm_from_bytes(&buffer);
library.finalize();
} }
fn parse_form(name: StringKey, value: &Value, library: &mut StaticData) -> Form { fn parse_form(name: StringKey, value: &Value, library: &mut StaticData) -> Form {

View File

@@ -8,3 +8,8 @@ MoveLibrary = 1
StaticData = 2 StaticData = 2
MoveData = 3 MoveData = 3
StringKey = 4 StringKey = 4
DynamicLibrary = 0
MoveLibrary = 1
StaticData = 2
MoveData = 3
StringKey = 4