diff --git a/src/contexts/component_factory.rs b/src/contexts/component_factory.rs index a3c51e7..cb3d74a 100644 --- a/src/contexts/component_factory.rs +++ b/src/contexts/component_factory.rs @@ -73,7 +73,7 @@ mod tests { Ok(from_context!(context; A, B)) } - let mut context = Context::new(); + let context = Context::new(); context.set(A("Test")); context.set(B(67)); @@ -103,7 +103,7 @@ mod tests { // Given derive_ccf!(CFactory; A, B); let c_factory = CFactory; - let mut context = Context::new(); + let context = Context::new(); context.set(A("test")); context.set(B(67)); diff --git a/src/contexts/model_i.rs b/src/contexts/model_i.rs index 82135c5..d1c9144 100644 --- a/src/contexts/model_i.rs +++ b/src/contexts/model_i.rs @@ -1,11 +1,11 @@ use std::any::Any; +use std::cell::RefCell; use super::component::*; use super::component_map::*; - #[allow(dead_code)] pub struct Context<'t> { - map: ComponentMap, + map: RefCell, base_context: Option<&'t Context<'t>>, } @@ -13,13 +13,13 @@ pub struct Context<'t> { impl<'t> Context<'t> { pub fn new() -> Context<'t> { Context { - map: ComponentMap::new(), + map: RefCell::new(ComponentMap::new()), base_context: None, } } pub fn get(&self) -> Option> { - match self.map.get() { + match self.map.borrow().get() { component @ Some(_) => component, _ => match self.base_context { Some(context) => context.get(), @@ -28,8 +28,8 @@ impl<'t> Context<'t> { } } - pub fn set(&mut self, component: T) { - self.map.set(component) + pub fn set(&self, component: T) { + self.map.borrow_mut().set(component) } pub fn subcontext(&self) -> Context { @@ -58,11 +58,8 @@ mod tests { // Given #[derive(Debug, Eq, PartialEq)] struct TestStruct(&'static str); - let context = { - let mut tmp = Context::new(); - tmp.set(TestStruct("hello")); - tmp - }; + let context = Context::new(); + context.set(TestStruct("hello")); // When let test_struct = context.get::().unwrap(); @@ -76,11 +73,8 @@ mod tests { // Given #[derive(Debug, Eq, PartialEq)] struct TestStruct(&'static str); - let context = { - let mut tmp = Context::new(); - tmp.set(TestStruct("hello")); - tmp - }; + let context = Context::new(); + context.set(TestStruct("hello")); let subcontext = context.subcontext(); // When @@ -96,11 +90,8 @@ mod tests { #[derive(Debug, Eq, PartialEq)] struct TestStruct(&'static str); let context = Context::new(); - let subcontext = { - let mut tmp = context.subcontext(); - tmp.set(TestStruct("hello")); - tmp - }; + let subcontext = context.subcontext(); + subcontext.set(TestStruct("hello")); // When let test_struct = subcontext.get::().unwrap();