diff --git a/src/script_implementations/wasm/export_registry/mod.rs b/src/script_implementations/wasm/export_registry/mod.rs index e001740..ae2c71c 100644 --- a/src/script_implementations/wasm/export_registry/mod.rs +++ b/src/script_implementations/wasm/export_registry/mod.rs @@ -56,7 +56,7 @@ fn _error(env: &WebAssemblyEnv, message: u32, message_len: u32, file: u32, file_ } fn move_library_get_move_by_hash(env: &WebAssemblyEnv, lib: ExternRef, hash: u32) -> ExternRef { - let lib = lib.value(env); + let lib = lib.value(env).unwrap(); let m = lib.get_by_hash(hash); if let Some(v) = m { ExternRef::new(env, v) @@ -66,20 +66,20 @@ fn move_library_get_move_by_hash(env: &WebAssemblyEnv, lib: ExternRef) -> ExternRef { - let move_data = move_data.value(env); + let move_data = move_data.value(env).unwrap(); ExternRef::new(env, move_data.name()) } fn move_data_get_base_power(env: &WebAssemblyEnv, move_data: ExternRef) -> u8 { - move_data.value(env).base_power() + move_data.value(env).unwrap().base_power() } fn const_string_get_hash(env: &WebAssemblyEnv, string_key: ExternRef) -> u32 { - string_key.value(env).hash() + string_key.value(env).unwrap().hash() } fn const_string_get_str(env: &WebAssemblyEnv, string_key: ExternRef) -> u32 { - let string_key = string_key.value(env).str(); + let string_key = string_key.value(env).unwrap().str(); let s: CString = CString::new(string_key.as_bytes()).unwrap(); let wasm_string_ptr = env .resolver() @@ -93,9 +93,9 @@ fn battle_library_get_data_library( env: &WebAssemblyEnv, dynamic_lib: ExternRef, ) -> ExternRef { - ExternRef::new(env, dynamic_lib.value(env).static_data()) + ExternRef::new(env, dynamic_lib.value(env).unwrap().static_data()) } fn data_library_get_move_library(env: &WebAssemblyEnv, data_library: ExternRef) -> ExternRef { - ExternRef::new(env, data_library.value(env).moves()) + ExternRef::new(env, data_library.value(env).unwrap().moves()) } diff --git a/src/script_implementations/wasm/extern_ref.rs b/src/script_implementations/wasm/extern_ref.rs index 092db80..ff18adb 100644 --- a/src/script_implementations/wasm/extern_ref.rs +++ b/src/script_implementations/wasm/extern_ref.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use unique_type_id::UniqueTypeId; use wasmer::FromToNativeWasmType; -use crate::script_implementations::wasm::script_resolver::WebAssemblyEnv; +use crate::script_implementations::wasm::script_resolver::{WebAssemblyEnv, WebAssemblyScriptResolver}; pub(crate) struct ExternRef> { index: u32, @@ -18,13 +18,15 @@ impl> ExternRef { } } - pub fn from_index(index: u32) -> Self { + /// Creates an ExternRef with a given resolver. This can be used in cases where we do not have an environment variable. + pub(crate) fn new_with_resolver(resolver: &WebAssemblyScriptResolver, value: &T) -> Self { Self { - index, + index: resolver.environment_data().get_extern_ref_index(value), _phantom: Default::default(), } } + /// An empty value ExternRef. pub fn null() -> Self { Self { index: 0, @@ -32,9 +34,11 @@ impl> ExternRef { } } - pub fn value<'a, 'b>(&'a self, env: &'b WebAssemblyEnv) -> &'b T { + /// Returns the real value for a given ExternRef. Note that the requested type must be the same as the type of the + /// value when it was passed before. If these types do not match, this will panic. + pub fn value<'a, 'b>(&'a self, env: &'b WebAssemblyEnv) -> Option<&'b T> { let ptr = env.resolver().get_extern_ref_value(self.index) as *const T; - unsafe { ptr.as_ref().unwrap() } + unsafe { ptr.as_ref() } } } diff --git a/src/script_implementations/wasm/mod.rs b/src/script_implementations/wasm/mod.rs index 4a50b1a..2ce7f35 100644 --- a/src/script_implementations/wasm/mod.rs +++ b/src/script_implementations/wasm/mod.rs @@ -10,6 +10,7 @@ pub mod script_resolver; /// us to not call a function if we do not need to. #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] #[allow(missing_docs)] +#[allow(clippy::missing_docs_in_private_items)] pub enum WebAssemblyScriptCapabilities { None = 0, Initialize = 1, diff --git a/src/script_implementations/wasm/script_capabilities.rs b/src/script_implementations/wasm/script_capabilities.rs deleted file mode 100644 index 8b13789..0000000 --- a/src/script_implementations/wasm/script_capabilities.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/script_implementations/wasm/script_resolver.rs b/src/script_implementations/wasm/script_resolver.rs index a06f8b0..e869231 100644 --- a/src/script_implementations/wasm/script_resolver.rs +++ b/src/script_implementations/wasm/script_resolver.rs @@ -1,12 +1,14 @@ use std::fmt::{Debug, Formatter}; +use std::ops::DerefMut; use std::sync::Arc; use hashbrown::{HashMap, HashSet}; -use parking_lot::RwLock; -use unique_type_id::UniqueTypeId; +use parking_lot::lock_api::{MappedRwLockReadGuard, RwLockReadGuard}; +use parking_lot::{RawRwLock, RwLock}; +use unique_type_id::{TypeId, UniqueTypeId}; use wasmer::{ Cranelift, Exports, Extern, Features, Function, ImportObject, Instance, Memory, Module, NativeFunc, Store, - Universal, UniversalEngine, Value, WasmerEnv, + Universal, Value, WasmerEnv, }; use crate::dynamic_data::{ItemScript, Script, ScriptResolver}; @@ -19,26 +21,32 @@ use crate::{PkmnResult, ScriptCategory, StringKey}; /// A WebAssembly script resolver implements the dynamic scripts functionality with WebAssembly. pub struct WebAssemblyScriptResolver { - engine: UniversalEngine, + /// The global state storage of WASM. store: Store, - module: Option, - instance: Option, - memory: Option, - imports: HashMap, - exports: Exports, + /// The WASM modules we have loaded. + modules: Vec, + /// Our currently loaded WASM instances. Empty until finalize() is called, after which the loaded modules get turned + /// into actual instances. + instances: Vec, + /// This is a map of all the functions that WASM gives us. exported_functions: HashMap, + + /// This is the WASM function to load a script. load_script_fn: Option), u32>>, - allocate_mem_fn: Option>, + + /// Script capabilities tell us which functions are implemented on a given script. This allows us to skip unneeded + /// WASM calls. script_capabilities: RwLock>>, - extern_ref_pointers: RwLock>, - extern_ref_pointers_lookup: RwLock>, - extern_ref_type_lookup: RwLock>, + environment_data: Arc, } +/// This struct allows us to index a hashmap with both a category and name of a script. #[derive(Debug, Clone, Eq, PartialEq, Hash)] struct ScriptCapabilitiesKey { + /// The category for the script we're looking for capabilities for. category: ScriptCategory, + /// The name of the script we're looking for capabilities for. script_key: StringKey, } @@ -53,20 +61,13 @@ impl WebAssemblyScriptResolver { let engine = universal.engine(); let store = Store::new(&engine); let s = Self { - engine, store, - module: Default::default(), - instance: Default::default(), - memory: Default::default(), - imports: Default::default(), - exports: Default::default(), + modules: Default::default(), + instances: Default::default(), exported_functions: Default::default(), load_script_fn: None, - allocate_mem_fn: None, script_capabilities: Default::default(), - extern_ref_pointers: Default::default(), - extern_ref_pointers_lookup: Default::default(), - extern_ref_type_lookup: Default::default(), + environment_data: Arc::new(Default::default()), }; Box::new(s) } @@ -75,97 +76,53 @@ impl WebAssemblyScriptResolver { pub fn load_wasm_from_bytes(&mut self, bytes: &[u8]) { // FIXME: Error handling let module = Module::new(&self.store, bytes).unwrap(); - self.module = Some(module); - - self.finalize(); + self.modules.push(module); } /// Initialise all the data we need. - fn finalize(&mut self) { + pub fn finalize(&mut self) { let mut imports = ImportObject::new(); let mut exports = Exports::new(); let env = WebAssemblyEnv { - resolver: self as *const WebAssemblyScriptResolver, + resolver: self.environment_data.clone(), }; register_webassembly_funcs(&mut exports, &self.store, env); imports.register("env", exports); - self.instance = Some(Instance::new(&self.module.as_ref().unwrap(), &imports).unwrap()); - let exports = &self.instance.as_ref().unwrap().exports; - for export in exports.iter() { - match export.1 { - Extern::Function(f) => { - self.exported_functions.insert(export.0.as_str().into(), f.clone()); + for module in &self.modules { + let instance = Instance::new(module, &imports).unwrap(); + let exports = &instance.exports; + for export in exports.iter() { + match export.1 { + Extern::Function(f) => { + self.exported_functions.insert(export.0.as_str().into(), f.clone()); + } + Extern::Memory(m) => { + self.environment_data.memory.write().insert(m.clone()); + } + _ => {} } - Extern::Memory(m) => { - self.memory = Some(m.clone()); - } - _ => {} } - } - if let Some(m) = &self.memory { - m.grow(32).unwrap(); - } - if let Some(f) = self.exported_functions.get(&"load_script".into()) { - self.load_script_fn = Some(f.native().unwrap()) - } - if let Some(f) = self.exported_functions.get(&"allocate_mem".into()) { - self.allocate_mem_fn = Some(f.native().unwrap()) + if let Some(m) = &self.environment_data.memory.read().as_ref() { + m.grow(32).unwrap(); + } + if let Some(f) = self.exported_functions.get(&"load_script".into()) { + self.load_script_fn = Some(f.native().unwrap()) + } + if let Some(f) = self.exported_functions.get(&"allocate_mem".into()) { + self.environment_data + .allocate_mem_fn + .write() + .insert(f.native().unwrap()); + } + self.instances.push(instance); } } - /// Gets the internal WASM memory. - pub fn memory(&self) -> &Memory { - self.memory.as_ref().unwrap() - } - - /// Get a numeric value from any given value. This is not a true Extern Ref from WASM, as this - /// is not supported by our current WASM platform (Rust). Instead, this is simply a way to not - /// have to send arbitrary pointer values back and forth with WASM. Only values WASM can actually - /// access can be touched through this, and we ensure the value is the correct type. In the future, - /// when extern refs get actually properly implemented at compile time we might want to get rid - /// of this code. - pub fn get_extern_ref_index>(&self, value: &T) -> u32 { - let ptr = value as *const T as *const u8; - if let Some(v) = self.extern_ref_pointers_lookup.read().get(&ptr) { - return *v as u32; - } - let index = { - let mut extern_ref_guard = self.extern_ref_pointers.write(); - extern_ref_guard.push(ptr); - extern_ref_guard.len() as u32 - }; - self.extern_ref_pointers_lookup.write().insert(ptr, index); - self.extern_ref_type_lookup.write().insert(ptr, T::id().0); - index - } - - /// Gets a value from the extern ref lookup. This turns an earlier registered index back into - /// its proper value, validates its type, and returns the value. - pub fn get_extern_ref_value>(&self, index: u32) -> &T { - let read_guard = self.extern_ref_pointers.read(); - let ptr = read_guard.get((index - 1) as usize).unwrap(); - let expected_type_id = self.extern_ref_type_lookup.read()[&ptr]; - if expected_type_id != T::id().0 { - panic!("Extern ref was accessed with wrong type"); - } - - unsafe { (*ptr as *const T).as_ref().unwrap() } - } - - /// Allocates memory inside the WASM container with a given size and alignment. This memory is - /// owned by WASM, and is how we can pass memory references that the host allocated to WASM. - /// The return is a tuple containing both the actual pointer to the memory (usable by the host), - /// and the WASM offset to the memory (usable by the client). - pub fn allocate_mem(&self, size: u32, align: u32) -> (*const u8, u32) { - let wasm_ptr = self.allocate_mem_fn.as_ref().unwrap().call(size, align).unwrap(); - unsafe { - ( - self.memory.as_ref().unwrap().data_ptr().offset(wasm_ptr as isize), - wasm_ptr, - ) - } + /// Gets the data passed to every function as environment data. + pub fn environment_data(&self) -> &Arc { + &self.environment_data } } @@ -180,10 +137,7 @@ impl ScriptResolver for WebAssemblyScriptResolver { .load_script_fn .as_ref() .unwrap() - .call( - category as u8, - ExternRef::from_index(self.get_extern_ref_index(script_key)), - ) + .call(category as u8, ExternRef::new_with_resolver(self, script_key)) .unwrap(); if script == 0 { return Ok(None); @@ -199,11 +153,12 @@ impl ScriptResolver for WebAssemblyScriptResolver { unsafe { if let Some(get_cap) = self.exported_functions.get(&"get_script_capabilities".into()) { let res = get_cap.call(&[Value::I32(script as i32)]).unwrap(); - let ptr = (self.memory.as_ref().unwrap().data_ptr() as *const WebAssemblyScriptCapabilities) + let ptr = (self.environment_data.memory.read().as_ref().unwrap().data_ptr() + as *const WebAssemblyScriptCapabilities) .offset(res[0].i32().unwrap() as isize); let length = res[1].i32().unwrap() as usize; for i in 0..length { - capabilities.insert(*ptr.offset(i as isize)); + capabilities.insert(*ptr.add(i)); } } } @@ -234,14 +189,104 @@ impl Debug for WebAssemblyScriptResolver { } } +/// This data is what is passed to every function that requires access to the global runtime context. +#[derive(Default)] +pub struct WebAssemblyEnvironmentData { + /// We currently have a hacky implementation of extern refs while we're waiting for ExternRef support to hit the + /// wasm32-unknown-unknown target of Rust. As we don't want to pass raw memory pointers to WASM for security reasons, + /// we instead keep track of all the data we've sent to WASM, and pass the ID of that data to WASM. This allows us + /// to only operate on data we know WASM owns. We currently store this data in this continuous Vec, and give the index + /// of the data as the ID. + extern_ref_pointers: RwLock>, + /// To make sure we send the same identifier to WASM when we send the same piece of data multiple times, we have a + /// backwards lookup on extern_ref_pointers. This allows us to get the index for a given piece of data. + extern_ref_pointers_lookup: RwLock>, + /// As an added security measure on our extern refs, we keep track of the types of the extern ref data we've sent. + /// This prevents illegal arbitrary memory operations, where we expect type X, but the actual type is Y, which would + /// allow for modifying memory we might not want to. If we get a type mismatch, we will panic, preventing this. + extern_ref_type_lookup: RwLock>>, + + /// The memory inside of the WASM container. + memory: RwLock>, + + /// This is the WASM function to allocate memory inside the WASM container. + allocate_mem_fn: RwLock>>, +} + +impl WebAssemblyEnvironmentData { + /// This returns the memory of the WASM container. + pub fn memory(&self) -> MappedRwLockReadGuard<'_, RawRwLock, Memory> { + RwLockReadGuard::map(self.memory.read(), |a| a.as_ref().unwrap()) + } + + /// Allocates memory inside the WASM container with a given size and alignment. This memory is + /// owned by WASM, and is how we can pass memory references that the host allocated to WASM. + /// The return is a tuple containing both the actual pointer to the memory (usable by the host), + /// and the WASM offset to the memory (usable by the client). + pub fn allocate_mem(&self, size: u32, align: u32) -> (*const u8, u32) { + let wasm_ptr = self.allocate_mem_fn.read().as_ref().unwrap().call(size, align).unwrap(); + unsafe { + ( + self.memory + .read() + .as_ref() + .unwrap() + .data_ptr() + .offset(wasm_ptr as isize), + wasm_ptr, + ) + } + } + + /// Get a numeric value from any given value. This is not a true Extern Ref from WASM, as this + /// is not supported by our current WASM platform (Rust). Instead, this is simply a way to not + /// have to send arbitrary pointer values back and forth with WASM. Only values WASM can actually + /// access can be touched through this, and we ensure the value is the correct type. In the future, + /// when extern refs get actually properly implemented at compile time we might want to get rid + /// of this code. + pub fn get_extern_ref_index>(&self, value: &T) -> u32 { + let ptr = value as *const T as *const u8; + if let Some(v) = self.extern_ref_pointers_lookup.read().get(&ptr) { + return *v as u32; + } + let index = { + let mut extern_ref_guard = self.extern_ref_pointers.write(); + extern_ref_guard.push(ptr); + extern_ref_guard.len() as u32 + }; + self.extern_ref_pointers_lookup.write().insert(ptr, index); + self.extern_ref_type_lookup.write().insert(ptr, T::id()); + index + } + + /// Gets a value from the extern ref lookup. This turns an earlier registered index back into + /// its proper value, validates its type, and returns the value. + pub fn get_extern_ref_value>(&self, index: u32) -> &T { + let read_guard = self.extern_ref_pointers.read(); + let ptr = read_guard.get((index - 1) as usize).unwrap(); + let expected_type_id = &self.extern_ref_type_lookup.read()[ptr]; + if expected_type_id.0 != T::id().0 { + panic!( + "Extern ref was accessed with wrong type. Requested type {}, but this was not the type the extern ref was stored with.", + std::any::type_name::() + ); + } + + unsafe { (*ptr as *const T).as_ref().unwrap() } + } +} + +/// The runtime environment for script execution. This is passed to most of the host functions being called. #[derive(Clone)] pub(crate) struct WebAssemblyEnv { - pub resolver: *const WebAssemblyScriptResolver, + /// A pointer to the WebAssemblyScriptResolver belonging to the current script environment. + pub resolver: Arc, } impl WebAssemblyEnv { - pub fn resolver(&self) -> &WebAssemblyScriptResolver { - unsafe { self.resolver.as_ref().unwrap() } + /// Get the WebAssemblyScriptResolver belonging to the current context. + pub fn resolver(&self) -> &Arc { + &self.resolver } } diff --git a/src/static_data/species_data/species.rs b/src/static_data/species_data/species.rs index e120e66..38d94bb 100644 --- a/src/static_data/species_data/species.rs +++ b/src/static_data/species_data/species.rs @@ -28,6 +28,7 @@ pub struct Species { /// A cached String Key to get the default form. static DEFAULT_KEY: conquer_once::OnceCell = conquer_once::OnceCell::uninit(); +/// Gets the StringKey for "default". Initialises it if it does not exist. fn get_default_key() -> StringKey { DEFAULT_KEY.get_or_init(|| StringKey::new("default")).clone() } diff --git a/src/utils/string_key.rs b/src/utils/string_key.rs index 356fc31..b8437ab 100644 --- a/src/utils/string_key.rs +++ b/src/utils/string_key.rs @@ -1,5 +1,6 @@ use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; +use std::ops::Deref; use std::sync::{Arc, Mutex, Weak}; use conquer_once::OnceCell; @@ -110,7 +111,7 @@ impl Equivalent for u32 { impl Display for StringKey { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(&*self.str) + f.write_str(self.str.deref()) } } diff --git a/tests/common/library_loader.rs b/tests/common/library_loader.rs index b0e4dcb..f28577f 100644 --- a/tests/common/library_loader.rs +++ b/tests/common/library_loader.rs @@ -281,6 +281,7 @@ fn load_wasm(path: &String, library: &mut WebAssemblyScriptResolver) { let mut buffer = Vec::new(); reader.read_to_end(&mut buffer).unwrap(); library.load_wasm_from_bytes(&buffer); + library.finalize(); } fn parse_form(name: StringKey, value: &Value, library: &mut StaticData) -> Form { diff --git a/types.toml b/types.toml index 58ae822..60c4e0e 100644 --- a/types.toml +++ b/types.toml @@ -8,3 +8,8 @@ MoveLibrary = 1 StaticData = 2 MoveData = 3 StringKey = 4 +DynamicLibrary = 0 +MoveLibrary = 1 +StaticData = 2 +MoveData = 3 +StringKey = 4