// vmkit-core/vmkit/examples/binarytrees.rs
//
// Repository web-view metadata: 264 lines, 7.2 KiB, Rust.
// (history view timestamp: 2025-02-12 20:42:03 +07:00)
use mmtk::util::Address;
// (history view timestamp: 2025-02-13 22:32:44 +07:00)
use mmtk::vm::slot::UnimplementedMemorySlice;
// (history view timestamp: 2025-02-12 20:42:03 +07:00)
use mmtk::{util::options::PlanSelector, vm::slot::SimpleSlot, AllocationSemantics, MMTKBuilder};
use std::cell::RefCell;
use std::mem::offset_of;
use std::sync::Arc;
use std::sync::OnceLock;
use vmkit::threading::parked_scope;
use vmkit::{
mm::{traits::Trace, MemoryManager},
object_model::{
metadata::{GCMetadata, TraceCallback},
object::VMKitObject,
},
sync::Monitor,
threading::{GCBlockAdapter, Thread, ThreadContext},
VMKit, VirtualMachine,
};
/// In-heap layout of one binary-tree node: two child references.
///
/// `#[repr(C)]` fixes the field order; `NodeRef::new` writes the fields by
/// `offset_of!` and the `METADATA` trace callback reads them through a
/// `Node` reference, so all three must agree on this layout.
#[repr(C)]
struct Node {
    /// Left child, or `NodeRef::null()` for a leaf.
    left: NodeRef,
    /// Right child, or `NodeRef::null()` for a leaf.
    right: NodeRef,
}
/// GC metadata shared by every `Node` object allocated in this benchmark.
///
/// * `trace` visits both child references of the node.
/// * `instance_size` is the fixed size of `Node`; no dynamic size or
///   alignment computation is installed (`compute_size` and
///   `compute_alignment` are `None`).
///
/// NOTE(review): `alignment` is 16 here while `NodeRef::new` passes `32` as
/// the third argument of `MemoryManager::allocate` (presumably the requested
/// alignment) — confirm which value is intended.
static METADATA: GCMetadata<BenchVM> = GCMetadata {
    // SAFETY: this metadata is only attached to objects created by
    // `NodeRef::new`, which allocates `size_of::<Node>()` bytes, so the
    // object's address is assumed to point at a valid `Node`.
    trace: TraceCallback::TraceObject(|object, tracer| unsafe {
        let node = object.as_address().as_mut_ref::<Node>();
        node.left.0.trace_object(tracer);
        node.right.0.trace_object(tracer);
    }),
    instance_size: size_of::<Node>(),
    compute_size: None,
    alignment: 16,
    compute_alignment: None,
};
/// The virtual machine type used by this benchmark; it only wraps the
/// `VMKit` runtime instance.
struct BenchVM {
    /// Runtime state handed out by `VirtualMachine::vmkit`.
    vmkit: VMKit<BenchVM>,
}

/// Process-wide `BenchVM` instance; set once in `main` and read through
/// `VirtualMachine::get`.
static VM: OnceLock<BenchVM> = OnceLock::new();
/// Per-thread context for the benchmark. It carries no state and every
/// hook below is a no-op.
struct ThreadBenchContext;

impl ThreadContext<BenchVM> for ThreadBenchContext {
    /// Creates the (stateless) context; the flag argument is ignored.
    fn new(_flag: bool) -> Self {
        ThreadBenchContext
    }

    /// Nothing to save: the context holds no thread state.
    fn save_thread_state(&self) {}

    /// Reports no precise roots.
    fn scan_roots(
        &self,
        _factory: impl mmtk::vm::RootsWorkFactory<<BenchVM as VirtualMachine>::Slot>,
    ) {
    }

    /// Adds no extra conservative roots.
    fn scan_conservative_roots(
        &self,
        _croots: &mut vmkit::mm::conservative_roots::ConservativeRoots,
    ) {
    }
}
impl VirtualMachine for BenchVM {
type BlockAdapterList = (GCBlockAdapter, ());
type Metadata = &'static GCMetadata<Self>;
type Slot = SimpleSlot;
type ThreadContext = ThreadBenchContext;
2025-02-13 22:32:44 +07:00
type MemorySlice = UnimplementedMemorySlice;
2025-02-12 20:42:03 +07:00
fn get() -> &'static Self {
VM.get().unwrap()
}
fn vmkit(&self) -> &VMKit<Self> {
&self.vmkit
}
fn prepare_for_roots_re_scanning() {}
fn notify_initial_thread_scan_complete(partial_scan: bool, tls: mmtk::util::VMWorkerThread) {
let _ = partial_scan;
let _ = tls;
}
fn forward_weak_refs(
_worker: &mut mmtk::scheduler::GCWorker<vmkit::mm::MemoryManager<Self>>,
_tracer_context: impl mmtk::vm::ObjectTracerContext<vmkit::mm::MemoryManager<Self>>,
) {
}
fn scan_roots_in_mutator_thread(
_tls: mmtk::util::VMWorkerThread,
_mutator: &'static mut mmtk::Mutator<vmkit::mm::MemoryManager<Self>>,
_factory: impl mmtk::vm::RootsWorkFactory<
<vmkit::mm::MemoryManager<Self> as mmtk::vm::VMBinding>::VMSlot,
>,
) {
}
fn scan_vm_specific_roots(
_tls: mmtk::util::VMWorkerThread,
_factory: impl mmtk::vm::RootsWorkFactory<
<vmkit::mm::MemoryManager<Self> as mmtk::vm::VMBinding>::VMSlot,
>,
) {
}
}
/// Copyable handle to a `Node` in the managed heap.
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped
/// `VMKitObject`; equality compares the wrapped object values.
#[derive(Copy, Clone, Eq, PartialEq)]
#[repr(transparent)]
struct NodeRef(VMKitObject);
impl NodeRef {
pub fn new(thread: &Thread<BenchVM>, left: NodeRef, right: NodeRef) -> Self {
let node = MemoryManager::<BenchVM>::allocate(
thread,
size_of::<Node>(),
2025-02-13 17:24:08 +07:00
32,
2025-02-12 20:42:03 +07:00
&METADATA,
AllocationSemantics::Default,
);
node.set_field_object::<BenchVM, false>(offset_of!(Node, left), left.0);
node.set_field_object::<BenchVM, false>(offset_of!(Node, right), right.0);
Self(node)
}
pub fn left(self) -> NodeRef {
unsafe {
let node = self.0.as_address().as_ref::<Node>();
node.left
}
}
pub fn right(self) -> NodeRef {
unsafe {
let node = self.0.as_address().as_ref::<Node>();
node.right
}
}
pub fn null() -> Self {
Self(VMKitObject::NULL)
}
pub fn item_check(&self) -> usize {
if self.left() == NodeRef::null() {
1
} else {
1 + self.left().item_check() + self.right().item_check()
}
}
pub fn leaf(thread: &Thread<BenchVM>) -> Self {
Self::new(thread, NodeRef::null(), NodeRef::null())
}
}
/// Builds a complete binary tree with `depth` levels below the root and
/// returns its root.
///
/// Each call first polls the thread's yieldpoint flag and enters the
/// yieldpoint when it is non-zero. A `depth` of 0 yields a single leaf.
fn bottom_up_tree(thread: &Thread<BenchVM>, depth: usize) -> NodeRef {
    let yield_requested = thread.take_yieldpoint() != 0;
    if yield_requested {
        Thread::<BenchVM>::yieldpoint(0, Address::ZERO);
    }

    if depth == 0 {
        return NodeRef::leaf(thread);
    }

    // Left subtree is built before the right one, then the parent.
    let child_depth = depth - 1;
    let left_subtree = bottom_up_tree(thread, child_depth);
    let right_subtree = bottom_up_tree(thread, child_depth);
    NodeRef::new(thread, left_subtree, right_subtree)
}
const MIN_DEPTH: usize = 4;
fn main() {
env_logger::init();
let nthreads = std::env::var("THREADS")
.unwrap_or("4".to_string())
.parse::<usize>()
.unwrap();
let mut builder = MMTKBuilder::new();
builder.options.plan.set(PlanSelector::Immix);
builder.options.threads.set(nthreads);
builder
.options
.gc_trigger
.set(mmtk::util::options::GCTriggerSelector::DynamicHeapSize(
4 * 1024 * 1024 * 1024,
16 * 1024 * 1024 * 1024,
));
VM.set(BenchVM {
vmkit: VMKit::new(&mut builder),
})
.unwrap_or_else(|_| panic!());
Thread::<BenchVM>::main(ThreadBenchContext, || {
let thread = Thread::<BenchVM>::current();
let start = std::time::Instant::now();
let n = std::env::var("DEPTH")
.unwrap_or("18".to_string())
.parse::<usize>()
.unwrap();
let max_depth = if n < MIN_DEPTH + 2 { MIN_DEPTH + 2 } else { n };
let stretch_depth = max_depth + 1;
println!("stretch tree of depth {stretch_depth}");
let _ = bottom_up_tree(&thread, stretch_depth);
let duration = start.elapsed();
println!("time: {duration:?}");
let results = Arc::new(Monitor::new(vec![
RefCell::new(String::new());
(max_depth - MIN_DEPTH) / 2 + 1
]));
let mut handles = Vec::new();
for d in (MIN_DEPTH..=max_depth).step_by(2) {
let depth = d;
let thread = Thread::<BenchVM>::for_mutator(ThreadBenchContext);
let results = results.clone();
let handle = thread.start(move || {
let thread = Thread::<BenchVM>::current();
let mut check = 0;
let iterations = 1 << (max_depth - depth + MIN_DEPTH);
for _ in 1..=iterations {
let tree_node = bottom_up_tree(&thread, depth);
check += tree_node.item_check();
}
*results.lock_with_handshake::<BenchVM>()[(depth - MIN_DEPTH) / 2].borrow_mut() =
format!("{iterations}\t trees of depth {depth}\t check: {check}");
});
handles.push(handle);
}
println!("created {} threads", handles.len());
parked_scope::<(), BenchVM>(|| {
while let Some(handle) = handles.pop() {
handle.join().unwrap();
}
});
for result in results.lock_with_handshake::<BenchVM>().iter() {
println!("{}", result.borrow());
}
println!(
"long lived tree of depth {max_depth}\t check: {}",
bottom_up_tree(&thread, max_depth).item_check()
);
let duration = start.elapsed();
println!("time: {duration:?}");
});
}