diff --git a/core-relations/src/table/mod.rs b/core-relations/src/table/mod.rs index 4628a25b1..3166577f7 100644 --- a/core-relations/src/table/mod.rs +++ b/core-relations/src/table/mod.rs @@ -401,6 +401,10 @@ impl Table for SortedWritesTable { fn as_any(&self) -> &dyn Any { self } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } fn clear(&mut self) { self.pending_state.clear(); if self.data.data.len() == 0 { @@ -620,6 +624,12 @@ impl Table for SortedWritesTable { } } +impl SortedWritesTable { + pub fn set_merge(&mut self, merge: Box) { + self.merge = merge.into(); + } +} + impl SortedWritesTable { /// Create a new [`SortedWritesTable`] with the given number of keys, /// columns, and an optional sort column. diff --git a/core-relations/src/table_spec.rs b/core-relations/src/table_spec.rs index dc50ce360..6bae2c7d4 100644 --- a/core-relations/src/table_spec.rs +++ b/core-relations/src/table_spec.rs @@ -27,7 +27,7 @@ use crate::{ offsets::{RowId, Subset, SubsetRef}, pool::{with_pool_set, PoolSet, Pooled}, row_buffer::{RowBuffer, TaggedRowBuffer}, - QueryEntry, TableId, Variable, + DisplacedTable, DisplacedTableWithProvenance, QueryEntry, TableId, Variable, }; define_id!(pub ColumnId, u32, "a particular column in a table"); @@ -183,6 +183,9 @@ pub trait Table: Any + Send + Sync { /// `self`. fn as_any(&self) -> &dyn Any; + /// A mutable variant of [`Table::as_any`] for downcasting. + fn as_any_mut(&mut self) -> &mut dyn Any; + /// The schema of the table. /// /// These are immutable properties of the table; callers can assume they @@ -542,8 +545,17 @@ impl<'de> Deserialize<'de> for WrappedTable { D: serde::Deserializer<'de>, { let inner: Box = Deserialize::deserialize(deserializer)?; - - let wrapper = wrapper::(); // todo: different kind of wrapper? + let wrapper = if inner.as_any().is::() { + wrapper::() + } else if inner.as_any().is::() { + wrapper::() + } else if inner.as_any().is::() { + wrapper::() + } else { + return Err(serde::de::Error::custom( + "unknown table type for WrappedTable", + )); + }; Ok(WrappedTable { inner, wrapper }) } @@ -564,6 +576,10 @@ impl WrappedTable { } } + pub fn as_any_mut(&mut self) -> &mut dyn Any { + self.inner.as_any_mut() + } + pub(crate) fn as_ref(&self) -> WrappedTableRef<'_> { WrappedTableRef { inner: &*self.inner, diff --git a/core-relations/src/uf/mod.rs b/core-relations/src/uf/mod.rs index 5688ddb9e..ca589c775 100644 --- a/core-relations/src/uf/mod.rs +++ b/core-relations/src/uf/mod.rs @@ -279,6 +279,10 @@ impl Table for DisplacedTable { fn as_any(&self) -> &dyn Any { self } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } fn spec(&self) -> TableSpec { let mut uncacheable_columns = DenseIdMap::default(); // The second column of this table is determined dynamically by the union-find. @@ -882,6 +886,10 @@ impl Table for DisplacedTableWithProvenance { fn as_any(&self) -> &dyn Any { self } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } fn clear(&mut self) { self.base.clear() } diff --git a/egglog-bridge/src/lib.rs b/egglog-bridge/src/lib.rs index 7232d5def..b17177a93 100644 --- a/egglog-bridge/src/lib.rs +++ b/egglog-bridge/src/lib.rs @@ -768,6 +768,7 @@ impl EGraph { incremental_rebuild_rules: Default::default(), nonincremental_rebuild_rule: RuleId::new(!0), default_val: default, + merge: merge.clone(), can_subsume, name, }); @@ -1135,6 +1136,43 @@ impl EGraph { updated } + pub fn restore_deserialized_runtime(&mut self) { + let funcs = self + .funcs + .iter() + .map(|(func, info)| { + ( + func, + info.table, + info.schema.clone(), + info.can_subsume, + info.name.clone(), + info.merge.clone(), + ) + }) + .collect::>(); + for (func, table_id, schema, can_subsume, name, merge) in funcs { + let schema_math = SchemaMath { + tracing: self.tracing, + subsume: can_subsume, + func_cols: schema.len(), + }; + let merge_fn = merge.to_callback(schema_math, &name, self); + let table = self + .db + .get_table_mut(table_id) + .as_any_mut() + .downcast_mut::() + .expect("function tables must use SortedWritesTable"); + table.set_merge(merge_fn); + let incremental_rebuild_rules = self.incremental_rebuild_rules(func, &schema); + let nonincremental_rebuild_rule = self.nonincremental_rebuild(func, &schema); + let info = &mut self.funcs[func]; + info.incremental_rebuild_rules = incremental_rebuild_rules; + info.nonincremental_rebuild_rule = nonincremental_rebuild_rule; + } + } + pub fn set_report_level(&mut self, level: ReportLevel) { self.report_level = level; } @@ -1163,6 +1201,7 @@ struct FunctionInfo { incremental_rebuild_rules: Vec, nonincremental_rebuild_rule: RuleId, default_val: DefaultVal, + merge: MergeFn, can_subsume: bool, name: Arc, } @@ -1185,6 +1224,7 @@ pub enum DefaultVal { } /// How to resolve FD conflicts for a table. +#[derive(Clone, Serialize, Deserialize)] pub enum MergeFn { /// Panic if the old and new values don't match. AssertEq, diff --git a/infra/nightly-resources/web/chart.js b/infra/nightly-resources/web/chart.js index 466b69975..c96d8426e 100644 --- a/infra/nightly-resources/web/chart.js +++ b/infra/nightly-resources/web/chart.js @@ -197,7 +197,7 @@ function initializeCharts() { }, }, y: { - stacked: false, + stacked: true, title: { display: true, text: "Time (ms)", diff --git a/infra/nightly-resources/web/data.js b/infra/nightly-resources/web/data.js index 61578765c..c6a711e95 100644 --- a/infra/nightly-resources/web/data.js +++ b/infra/nightly-resources/web/data.js @@ -59,6 +59,7 @@ function initializeGlobalData() { GLOBAL_DATA.extractChart = null; GLOBAL_DATA.differenceChart = null; GLOBAL_DATA.minedChart = null; + GLOBAL_DATA.mineData = {}; return fetch("data/data.json") .then((response) => response.json()) @@ -86,22 +87,11 @@ function processRawData(blob) { if (!GLOBAL_DATA.data[suite]) { return; } - // Aggregate commands across all timelines const times = { benchmark, - ...Object.fromEntries(CMDS.map((cmd) => [cmd, []])), - other: [], + ...aggregateTimelinesByCommand(timelines), }; - timelines.forEach(({ events, sexps }) => { - events.forEach((time_micros, idx) => { - const cmd = getCmd(sexps[idx]); - - // Group times by command type - times[getCmdType(cmd)].push(time_micros / 1000); // we measure microseconds, but for charts, it's nicer to show in ms - }); - }); - GLOBAL_DATA.data[suite][runMode][benchmark] = times; }); } diff --git a/infra/nightly-resources/web/mined.html b/infra/nightly-resources/web/mined.html index 8dbe0e7fe..1609288de 100644 --- a/infra/nightly-resources/web/mined.html +++ b/infra/nightly-resources/web/mined.html @@ -29,7 +29,21 @@

POACH vs Vanilla Egglog

+ +

Extraction Cost Summary (easteregg)

+ + + + + + + + + + + +
benchmark name# extractsavg initial costavg final costavg cost difference
- \ No newline at end of file + diff --git a/infra/nightly-resources/web/mined.js b/infra/nightly-resources/web/mined.js index 56cccded6..b0920dbd2 100644 --- a/infra/nightly-resources/web/mined.js +++ b/infra/nightly-resources/web/mined.js @@ -1,45 +1,137 @@ function initialize() { - initializeGlobalData().then(initializeCharts).then(plotMine); + Promise.all([initializeGlobalData(), loadMineData()]) + .then(initializeCharts) + .then(() => { + plotMine(); + renderMineSummaryTable(); + }); } -function plotMine() { - const mega_mined = GLOBAL_DATA.data.easteregg["mine-mega"]; - const indiv_mined = GLOBAL_DATA.data.easteregg["mine-indiv"]; - const baseline = GLOBAL_DATA.data.easteregg.timeline; +function loadMineData() { + return fetch("data/mine-data.json") + .then((response) => response.json()) + .then((data) => { + GLOBAL_DATA.mineData = data; + }) + .catch((error) => { + console.error("Failed to load mine-data.json", error); + GLOBAL_DATA.mineData = {}; + }); +} +function plotMine() { if (GLOBAL_DATA.minedChart === null) { return; } - const benchmarks = Object.keys(baseline); - - const data = {}; - - benchmarks.forEach((b) => { - data[b] = {}; - - data[b].baseline = benchmarkTotalTime(baseline[b]); - data[b].mega_mined = benchmarkTotalTime(mega_mined[b]); - data[b].indiv_mined = benchmarkTotalTime(indiv_mined[b]); - }); + const benchmarks = Object.keys(GLOBAL_DATA.mineData).sort(); + const byBenchmark = Object.fromEntries( + benchmarks.map((b) => { + const baseline = aggregateTimelinesByCommand( + GLOBAL_DATA.mineData[b].baseline_timeline, + ); + const mega = aggregateTimelinesByCommand( + GLOBAL_DATA.mineData[b].mine_mega_timeline, + ); + const indiv = aggregateTimelinesByCommand( + GLOBAL_DATA.mineData[b].mine_indiv_timeline, + ); + return [ + b, + { + baseline: { + run: aggregate(baseline.run, "total"), + extract: aggregate(baseline.extract, "total"), + }, + mega: { + run: aggregate(mega.run, "total"), + extract: aggregate(mega.extract, "total"), + }, + indiv: { + run: aggregate(indiv.run, "total"), + extract: aggregate(indiv.extract, "total"), + }, + }, + ]; + }), + ); GLOBAL_DATA.minedChart.data = { labels: benchmarks, datasets: [ { - label: "baseline", - data: Object.values(data).map((d) => d.baseline), + label: "baseline: run", + stack: "baseline", + backgroundColor: "#1e3a8a", + data: benchmarks.map((b) => byBenchmark[b].baseline.run), + }, + { + label: "baseline: extract", + stack: "baseline", + backgroundColor: "#60a5fa", + data: benchmarks.map((b) => byBenchmark[b].baseline.extract), + }, + { + label: "mined (mega): run", + stack: "mega", + backgroundColor: "#b91c1c", + data: benchmarks.map((b) => byBenchmark[b].mega.run), + }, + { + label: "mined (mega): extract", + stack: "mega", + backgroundColor: "#f472b6", + data: benchmarks.map((b) => byBenchmark[b].mega.extract), }, { - label: "mined (mega)", - data: Object.values(data).map((d) => d.mega_mined), + label: "mined (indiv): run", + stack: "indiv", + backgroundColor: "#166534", + data: benchmarks.map((b) => byBenchmark[b].indiv.run), }, { - label: "mined (indiv)", - data: Object.values(data).map((d) => d.indiv_mined), + label: "mined (indiv): extract", + stack: "indiv", + backgroundColor: "#86efac", + data: benchmarks.map((b) => byBenchmark[b].indiv.extract), }, ], }; GLOBAL_DATA.minedChart.update(); } + +function renderMineSummaryTable() { + const tableBody = document.getElementById("mine-summary-body"); + if (!tableBody) { + return; + } + + const rows = Object.entries(GLOBAL_DATA.mineData) + .sort(([a], [b]) => a.localeCompare(b)) + .map(([benchmarkName, data]) => { + const entries = data.mine_mega_extracts || []; + const extractCount = entries.length; + const initialTotal = entries.reduce( + (sum, entry) => sum + entry.initial_cost, + 0, + ); + const finalTotal = entries.reduce( + (sum, entry) => sum + entry.final_cost, + 0, + ); + const avgInitialCost = + extractCount === 0 ? 0 : initialTotal / extractCount; + const avgFinalCost = extractCount === 0 ? 0 : finalTotal / extractCount; + return ` + + ${benchmarkName} + ${extractCount} + ${avgInitialCost} + ${avgFinalCost} + + `; + }); + + tableBody.innerHTML = rows.join("\n"); +} diff --git a/infra/nightly-resources/web/stylesheet.css b/infra/nightly-resources/web/stylesheet.css index 4e591be49..1331367e4 100644 --- a/infra/nightly-resources/web/stylesheet.css +++ b/infra/nightly-resources/web/stylesheet.css @@ -77,4 +77,16 @@ body { #on-error { display: none; -} \ No newline at end of file +} + +#mine-summary-table { + border-collapse: collapse; + margin-top: 16px; +} + +#mine-summary-table th, +#mine-summary-table td { + border: 1px solid #ddd; + padding: 6px 10px; + text-align: left; +} diff --git a/infra/nightly-resources/web/util.js b/infra/nightly-resources/web/util.js index 13fc415af..39022a86d 100644 --- a/infra/nightly-resources/web/util.js +++ b/infra/nightly-resources/web/util.js @@ -53,3 +53,19 @@ function getCmdType(cmd) { return "other"; } } + +function aggregateTimelinesByCommand(timelines) { + const times = { + ...Object.fromEntries(CMDS.map((cmd) => [cmd, []])), + other: [], + }; + + (timelines || []).forEach(({ events, sexps }) => { + (events || []).forEach((timeMicros, idx) => { + const cmd = getCmd((sexps || [])[idx] || ""); + times[getCmdType(cmd)].push(timeMicros / 1000); + }); + }); + + return times; +} diff --git a/infra/nightly.py b/infra/nightly.py index 3e833356a..675ec8ec2 100644 --- a/infra/nightly.py +++ b/infra/nightly.py @@ -4,6 +4,7 @@ from pathlib import Path import transform import glob +import json ############################################################################### # IMPORTANT: @@ -44,13 +45,12 @@ def add_benchmark_data(aggregator, timeline_file, benchmark_key): if timeline_file.exists(): aggregator.add_file(timeline_file, benchmark_key) -def remove_file(path): - if path.exists(): - path.unlink() - -def cleanup_benchmark_files(*paths): - for path in paths: - remove_file(path) +def cleanup_benchmark_files(tmp_dir): + for path in tmp_dir.iterdir(): + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() def benchmark_files(input_dir, recursive = False): pattern = "**/*.egg" if recursive else "*.egg" @@ -63,7 +63,7 @@ def run_timeline_experiments(resource_dir, tmp_dir, aggregator): timeline_file = tmp_dir / f"{benchmark.stem}-timeline.json" run_poach(benchmark, tmp_dir, "timeline-only") add_benchmark_data(aggregator, timeline_file, f"{suite}/timeline/{benchmark.stem}/timeline.json") - cleanup_benchmark_files(timeline_file, tmp_dir / "summary.json") + cleanup_benchmark_files(tmp_dir) def run_no_io_experiments(resource_dir, tmp_dir, aggregator): no_io_suites = ["easteregg", "herbie-hamming", "herbie-math-rewrite"] # herbie-math-taylor runs out of memory @@ -72,7 +72,7 @@ def run_no_io_experiments(resource_dir, tmp_dir, aggregator): timeline_file = tmp_dir / f"{benchmark.stem}-timeline.json" run_poach(benchmark, tmp_dir, "no-io") add_benchmark_data(aggregator, timeline_file, f"{suite}/no-io/{benchmark.stem}/timeline.json") - cleanup_benchmark_files(timeline_file, tmp_dir / "summary.json") + cleanup_benchmark_files(tmp_dir) def run_test_experiments(top_dir, tmp_dir, aggregator): test_modes = [ @@ -87,39 +87,59 @@ def run_test_experiments(top_dir, tmp_dir, aggregator): timeline_file = tmp_dir / f"{benchmark.stem}-timeline.json" run_poach(benchmark, tmp_dir, run_mode) add_benchmark_data(aggregator, timeline_file, f"tests/{benchmark_name}/{benchmark.stem}/timeline.json") - extra_files = { - "sequential-round-trip": [tmp_dir / f"{benchmark.stem}-serialize1.json"], - "old-serialize": [ - tmp_dir / f"{benchmark.stem}-serialize-poach.json", - tmp_dir / f"{benchmark.stem}-serialize-old.json", - ], - }.get(run_mode, []) - cleanup_benchmark_files(timeline_file, tmp_dir / "summary.json", *extra_files) + cleanup_benchmark_files(tmp_dir) def run_mined_experiments(resource_dir, tmp_dir, aggregator): - mega_serialize_file = tmp_dir / "mega-easteregg-serialize.json" - mega_timeline_file = tmp_dir / "mega-easteregg-timeline.json" - run_poach(resource_dir / "mega-easteregg.egg", tmp_dir, "serialize") - add_benchmark_data(aggregator, mega_timeline_file, "easteregg/serialize/mega-easteregg/timeline.json") - cleanup_benchmark_files(mega_timeline_file, tmp_dir / "summary.json") + mine_data = {} + mined_timeline_aggregator = transform.TimelineAggregator(tmp_dir) + mega_seed_dir = tmp_dir.parent / "cache" + mega_seed_dir.mkdir(exist_ok = True) + mega_serialize_file = mega_seed_dir / "mega-easteregg-serialize.json" + mega_timeline_file = mega_seed_dir / "mega-easteregg-timeline.json" + run_poach(resource_dir / "mega-easteregg.egg", mega_seed_dir, "serialize") + for benchmark in benchmark_files(resource_dir / "test-files" / "easteregg"): + benchmark_name = benchmark.stem timeline_file = tmp_dir / f"{benchmark.stem}-timeline.json" - serialize_file = tmp_dir / f"{benchmark.stem}-serialize.json" + mine_extract_file = tmp_dir / "mine-extracts.json" + baseline_key = f"{benchmark_name}/baseline" + mine_indiv_key = f"{benchmark_name}/mine-indiv" + mine_mega_key = f"{benchmark_name}/mine-mega" + + # First, make sure we have a serialized e-graph for the benchmark run_poach(benchmark, tmp_dir, "serialize") + mined_timeline_aggregator.add_file(timeline_file, baseline_key) + baseline_timeline = mined_timeline_aggregator.aggregated[baseline_key] add_benchmark_data(aggregator, timeline_file, f"easteregg/serialize/{benchmark.stem}/timeline.json") - cleanup_benchmark_files(timeline_file, tmp_dir / "summary.json") + # Mine Individual: Run the file starting from the serialized e-graph for the benchmark run_poach(benchmark, tmp_dir, "mine", - ["--initial-egraph=" + str(tmp_dir)]) + ["--initial-egraph=" + str(tmp_dir / f"{benchmark.stem}-serialize.json")]) + mined_timeline_aggregator.add_file(timeline_file, mine_indiv_key) + mine_indiv_timeline = mined_timeline_aggregator.aggregated[mine_indiv_key] + with open(mine_extract_file) as file: + mine_indiv_extracts = json.load(file)[benchmark_name] add_benchmark_data(aggregator, timeline_file, f"easteregg/mine-indiv/{benchmark.stem}/timeline.json") - cleanup_benchmark_files(timeline_file, serialize_file, tmp_dir / "summary.json") + # Mine Mega: Run the file starting from the mega e-graph for all of easteregg run_poach(benchmark, tmp_dir, "mine", ["--initial-egraph=" + str(mega_serialize_file)]) + mined_timeline_aggregator.add_file(timeline_file, mine_mega_key) + mine_mega_timeline = mined_timeline_aggregator.aggregated[mine_mega_key] + with open(mine_extract_file) as file: + mine_mega_extracts = json.load(file)[benchmark_name] add_benchmark_data(aggregator, timeline_file, f"easteregg/mine-mega/{benchmark.stem}/timeline.json") - cleanup_benchmark_files(timeline_file, tmp_dir / "summary.json") - cleanup_benchmark_files(mega_serialize_file, tmp_dir / "summary.json") + mine_data[benchmark_name] = { + "baseline_timeline": baseline_timeline, + "mine_indiv_timeline": mine_indiv_timeline, + "mine_mega_timeline": mine_mega_timeline, + "mine_indiv_extracts": mine_indiv_extracts, + "mine_mega_extracts": mine_mega_extracts, + } + cleanup_benchmark_files(tmp_dir) + + transform.save_json(aggregator.output_dir / "mine-data.json", mine_data) if __name__ == "__main__": print("Beginning poach nightly") diff --git a/src/lib.rs b/src/lib.rs index f5a0f85a8..05ac98b7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -373,9 +373,9 @@ pub enum CommandOutput { /// The best function found after extracting ExtractBest(TermDag, DefaultCost, TermId), /// The variants of a function found after extracting - ExtractVariants(TermDag, Vec), + ExtractVariants(TermDag, Vec<(DefaultCost, TermId)>), /// The variants of multiple functions found after extracting - MultiExtractVariants(TermDag, Vec>), + MultiExtractVariants(TermDag, Vec>), /// The report from all runs OverallStatistics(RunReport), /// A printed function and all its values @@ -402,7 +402,7 @@ impl std::fmt::Display for CommandOutput { } CommandOutput::ExtractVariants(termdag, terms) => { writeln!(f, "(")?; - for expr in terms { + for (_, expr) in terms { writeln!(f, " {}", termdag.to_string(*expr))?; } writeln!(f, ")") @@ -411,7 +411,7 @@ impl std::fmt::Display for CommandOutput { writeln!(f, "(")?; for variants in terms { writeln!(f, " (")?; - for expr in variants { + for (_, expr) in variants { writeln!(f, " {}", termdag.to_string(*expr))?; } writeln!(f, " )")?; @@ -1140,6 +1140,69 @@ impl EGraph { Ok(RunReport::singleton(ruleset, iteration_report)) } + fn restore_deserialized_runtime(&mut self) -> Result<(), Error> { + self.type_info.restore_deserialized_runtime_metadata(); + self.type_info.clear_primitives(); + + let mut sorts = self + .type_info + .all_sorts() + .into_iter() + .map(|sort| (sort.name().to_owned(), sort)) + .collect::>(); + for builtin in ["Unit", "String", "bool", "i64", "f64", "BigInt", "BigRat"] { + if let Some(sort) = sorts.remove(builtin) { + sort.register_primitives(self); + } + } + add_primitive!(self, "!=" = |a: #, b: #| -?> () { + (a != b).then_some(()) + }); + add_primitive!(self, "value-eq" = |a: #, b: #| -?> () { + (a == b).then_some(()) + }); + add_primitive!(self, "ordering-min" = |a: #, b: #| -> # { + if a < b { a } else { b } + }); + add_primitive!(self, "ordering-max" = |a: #, b: #| -> # { + if a > b { a } else { b } + }); + for (_, sort) in sorts { + sort.register_primitives(self); + } + self.backend.restore_deserialized_runtime(); + + let mut restored_rulesets = IndexMap::default(); + let rulesets = std::mem::take(&mut self.rulesets); + + for (ruleset_name, ruleset) in rulesets { + let restored = match ruleset { + Ruleset::Rules(rules) => { + let mut restored_rules = IndexMap::default(); + for (rule_name, (core_rule, _)) in rules { + let rule_id = { + let mut translator = BackendRule::new( + self.backend.new_rule(&rule_name, self.seminaive), + &self.functions, + &self.type_info, + ); + translator.query(&core_rule.body, false); + translator.actions(&core_rule.head)?; + translator.build() + }; + restored_rules.insert(rule_name, (core_rule, rule_id)); + } + Ruleset::Rules(restored_rules) + } + Ruleset::Combined(sub_rulesets) => Ruleset::Combined(sub_rulesets), + }; + restored_rulesets.insert(ruleset_name, restored); + } + + self.rulesets = restored_rulesets; + Ok(()) + } + fn add_rule(&mut self, rule: ast::ResolvedRule) -> Result { // Disable union_to_set optimization in proof or term encoding mode, since // it expects only `union` on constructors (not set). @@ -1444,11 +1507,7 @@ impl EGraph { if n < 0 { panic!("Cannot extract negative number of variants"); } - let terms: Vec = extractor - .extract_variants(self, &mut termdag, x, n as usize) - .iter() - .map(|e| e.1) - .collect(); + let terms = extractor.extract_variants(self, &mut termdag, x, n as usize); if log_enabled!(Level::Info) { let expr_str = expr.to_string(); log::info!("extracted {} variants for {expr_str}", terms.len()); @@ -1481,17 +1540,13 @@ impl EGraph { .iter() .zip(exprs.iter()) .map(|(x, expr)| { - let variants: Vec<_> = extractor - .extract_variants_with_sort( - self, - &mut termdag, - *x, - n as usize, - expr.output_type(), - ) - .iter() - .map(|e| e.1.clone()) - .collect(); + let variants = extractor.extract_variants_with_sort( + self, + &mut termdag, + *x, + n as usize, + expr.output_type(), + ); if log_enabled!(Level::Info) { let expr_str = expr.to_string(); log::info!("extracted {} variants for {expr_str}", variants.len()); @@ -2484,7 +2539,11 @@ impl TimedEgraph { let file = File::open(path).expect("failed to open egraph file"); let reader = BufReader::new(file); - let egraph: EGraph = serde_json::from_reader(reader).expect("failed to parse egraph JSON"); + let mut egraph: EGraph = + serde_json::from_reader(reader).expect("failed to parse egraph JSON"); + egraph + .restore_deserialized_runtime() + .expect("failed to restore deserialized runtime"); Self { egraphs: vec![egraph], @@ -2525,6 +2584,17 @@ impl TimedEgraph { Ok(outputs) } + pub fn run_from_string(&mut self, program_text: &str) -> Result> { + let parsed_commands = self + .egraphs + .last_mut() + .expect("There are no egraphs") + .parser + .get_program_from_string(None, program_text)?; + + Ok(self.run_program_with_timeline(parsed_commands, program_text)?) + } + pub fn run_program_with_timeline( &mut self, program: Vec, @@ -2651,8 +2721,11 @@ impl TimedEgraph { time_micros: self.timer.elapsed().as_micros(), }); - let egraph: EGraph = + let mut egraph: EGraph = serde_json::from_value(value).context("Failed to decode egraph from json")?; + egraph + .restore_deserialized_runtime() + .context("Failed to restore deserialized runtime")?; timeline.evts.push(EgraphEvent { sexp_idx: 0, @@ -2691,7 +2764,7 @@ impl TimedEgraph { let file = fs::File::create(path) .with_context(|| format!("failed to create file {}", path.display()))?; - serde_json::to_writer(BufWriter::new(file), &value) + serde_json::to_writer_pretty(BufWriter::new(file), &value) .context("Failed to write value to file")?; timeline.evts.push(EgraphEvent { @@ -2732,7 +2805,10 @@ impl TimedEgraph { time_micros: self.timer.elapsed().as_micros(), }); - let egraph: EGraph = serde_json::from_value(value)?; + let mut egraph: EGraph = serde_json::from_value(value)?; + egraph + .restore_deserialized_runtime() + .context("Failed to restore deserialized runtime")?; timeline.evts.push(EgraphEvent { sexp_idx: 1, diff --git a/src/poach.rs b/src/poach.rs index f8a1facd5..5e1f15967 100644 --- a/src/poach.rs +++ b/src/poach.rs @@ -4,6 +4,7 @@ use egglog::ast::{ all_sexps, GenericAction, GenericCommand, GenericExpr, GenericFact, GenericRunConfig, GenericSchedule, Sexp, SexpParser, }; +use egglog::extract::DefaultCost; use egglog::{CommandOutput, EGraph, TimedEgraph}; use env_logger::Env; use hashbrown::HashMap; @@ -67,8 +68,10 @@ enum RunMode { // Requires initial-egraph to be provided via Args // For each egg file under the input path, - // Deserialize the initial egraph - // Run the egglog program, skipping declarations of Sorts and Rules + // Run the egglog program from a fresh egraph and record extract outputs. + // Deserialize the initial egraph. + // Run the egglog program, skipping declarations of Sorts and Rules. + // Compare extract outputs between the two runs. // Save the completed timeline, for consumption by the nightly frontend Mine, } @@ -99,10 +102,7 @@ struct Args { output_dir: PathBuf, run_mode: RunMode, - // If this is a single file, it will be used as the initial egraph for - // every file in the input_path directory - // If it is a directory, we will look for a file matching the name of each - // file in the input_path directory + // Path to an initial serialized egraph file to load before running each benchmark. #[arg(long)] initial_egraph: Option, } @@ -173,11 +173,7 @@ where .unwrap_or("unknown"); let mut timed_egraph = if let Some(path) = initial_egraph { - if path.is_file() { - TimedEgraph::new_from_file(path) - } else { - TimedEgraph::new_from_file(&path.join(format!("{name}-serialize.json"))) - } + TimedEgraph::new_from_file(path) } else { TimedEgraph::new() }; @@ -204,42 +200,125 @@ where (successes, failures) } -fn compare_extracts( - initial_extracts: &[CommandOutput], - final_extracts: &[CommandOutput], -) -> Result<()> { - if initial_extracts.len() != final_extracts.len() { - anyhow::bail!("extract lengths mismatch") - } +fn collect_extract_outputs(outputs: Vec) -> Vec { + outputs + .into_iter() + .filter(|output| { + matches!( + output, + CommandOutput::ExtractBest(_, _, _) + | CommandOutput::ExtractVariants(_, _) + | CommandOutput::MultiExtractVariants(_, _) + ) + }) + .collect() +} - for (x, y) in initial_extracts.iter().zip(final_extracts) { - match (x, y) { - (CommandOutput::ExtractBest(_, _, term1), CommandOutput::ExtractBest(_, _, term2)) => { - if term1 != term2 { - anyhow::bail!("No match : {:?} {:?}", x, y) - } +#[derive(Serialize)] +struct MineExtractComparison { + initial_term: String, + initial_cost: DefaultCost, + final_term: String, + final_cost: DefaultCost, +} + +fn extract_term_costs_from_output(outputs: &[CommandOutput]) -> Result> { + let mut pairs = Vec::new(); + for output in outputs { + match output { + CommandOutput::ExtractBest(dag, cost, term) => { + pairs.push((dag.to_string(*term), *cost)); } - ( - CommandOutput::ExtractVariants(_, terms1), - CommandOutput::ExtractVariants(_, terms2), - ) => { - if terms1 != terms2 { - anyhow::bail!("No match : {:?} {:?}", x, y) - } + CommandOutput::ExtractVariants(dag, variants) => { + pairs.extend( + variants + .iter() + .map(|(cost, term)| (dag.to_string(*term), *cost)), + ); } - ( - CommandOutput::MultiExtractVariants(_, items1), - CommandOutput::MultiExtractVariants(_, items2), - ) => { - if items1 != items2 { - anyhow::bail!("No match : {:?} {:?}", x, y) - } + CommandOutput::MultiExtractVariants(dag, groups) => { + pairs.extend(groups.iter().flat_map(|group| { + group + .iter() + .map(|(cost, term)| (dag.to_string(*term), *cost)) + })); } - _ => anyhow::bail!("No match : {:?} {:?}", x, y), + _ => anyhow::bail!("Not an extract output: {output:?}"), } } + Ok(pairs) +} - Ok(()) +#[derive(Default)] +struct Namespace { + map: HashMap, +} + +impl Namespace { + fn add(&mut self, name: String) -> String { + if self.map.contains_key(&name) { + panic!("duplicate variable names") + } else { + let namespaced = format!("{name}@@"); + self.map.insert(name.clone(), namespaced.clone()); + namespaced + } + } + + fn get(&self, name: String) -> String { + self.map.get(&name).unwrap_or(&name).to_string() + } + + fn replace_expr(&self, expr: GenericExpr) -> GenericExpr { + match expr { + GenericExpr::Var(span, n) => GenericExpr::Var(span, self.get(n)), + GenericExpr::Call(span, h, generic_exprs) => GenericExpr::Call( + span, + self.get(h), + generic_exprs + .into_iter() + .map(|x| self.replace_expr(x)) + .collect(), + ), + GenericExpr::Lit(span, literal) => GenericExpr::Lit(span, literal), + } + } + + fn replace_fact(&self, fact: GenericFact) -> GenericFact { + match fact { + GenericFact::Eq(span, e1, e2) => { + GenericFact::Eq(span, self.replace_expr(e1), self.replace_expr(e2)) + } + GenericFact::Fact(e) => GenericFact::Fact(self.replace_expr(e)), + } + } + + fn replace_sched( + &self, + schedule: GenericSchedule, + ) -> GenericSchedule { + match schedule { + GenericSchedule::Saturate(span, sched) => { + GenericSchedule::Saturate(span, Box::new(self.replace_sched(*sched))) + } + GenericSchedule::Repeat(span, n, sched) => { + GenericSchedule::Repeat(span, n, Box::new(self.replace_sched(*sched))) + } + GenericSchedule::Run(span, config) => GenericSchedule::Run( + span, + GenericRunConfig { + ruleset: config.ruleset, + until: config + .until + .map(|facts| facts.into_iter().map(|f| self.replace_fact(f)).collect()), + }, + ), + GenericSchedule::Sequence(span, scheds) => GenericSchedule::Sequence( + span, + scheds.into_iter().map(|x| self.replace_sched(x)).collect(), + ), + } + } } fn poach( @@ -268,7 +347,17 @@ fn poach( initial_egraph.as_deref(), |egg_file, out_dir, timed_egraph| { let name = benchmark_name(egg_file); - timed_egraph.run_from_file(egg_file)?; + let outputs = timed_egraph.run_from_file(egg_file)?; + for output in outputs { + if matches!( + output, + CommandOutput::ExtractBest(_, _, _) + | CommandOutput::ExtractVariants(_, _) + | CommandOutput::MultiExtractVariants(_, _) + ) { + print!("{output}"); + } + } timed_egraph.to_file(&out_dir.join(format!("{name}-serialize.json")))?; timed_egraph.write_timeline(&out_dir.join(format!("{name}-timeline.json")))?; Ok(()) @@ -398,19 +487,8 @@ fn poach( initial_egraph.as_deref(), |egg_file, out_dir, timed_egraph| { let name = benchmark_name(egg_file); - let initial_outputs = timed_egraph.run_from_file(egg_file)?; - - let initial_extracts: Vec = initial_outputs - .into_iter() - .filter(|x| { - matches!( - x, - CommandOutput::ExtractBest(_, _, _) - | CommandOutput::ExtractVariants(_, _) - | CommandOutput::MultiExtractVariants(_, _) - ) - }) - .collect(); + let initial_extracts = + collect_extract_outputs(timed_egraph.run_from_file(egg_file)?); let program_string = &read_to_string(egg_file)?; @@ -454,10 +532,19 @@ fn poach( check_egraph_number(&timed_egraph, 2)?; - let final_extracts = - timed_egraph.run_program_with_timeline(extract_cmds, &extracts)?; - - compare_extracts(&initial_extracts, &final_extracts)?; + let final_extracts = collect_extract_outputs( + timed_egraph.run_program_with_timeline(extract_cmds, &extracts)?, + ); + + let initial_pairs = extract_term_costs_from_output(&initial_extracts)?; + let final_pairs = extract_term_costs_from_output(&final_extracts)?; + if initial_pairs != final_pairs { + anyhow::bail!( + "extract outputs differ:\ninitial: {:?}\nfinal: {:?}", + initial_pairs, + final_pairs + ); + } timed_egraph.write_timeline(&out_dir.join(format!("{name}-timeline.json")))?; @@ -470,102 +557,18 @@ fn poach( initial_egraph.is_some(), "initial_egraph must be provided via CLI args for Mine run mode" ); - process_files( + let mut extract_report: HashMap> = HashMap::new(); + let result = process_files( &files, out_dir, initial_egraph.as_deref(), |egg_file, out_dir, timed_egraph| { - let name = benchmark_name(egg_file); - // Namespace to avoid shadowing - #[derive(Default)] - struct Namespace { - map: HashMap, - } - - impl Namespace { - fn add(&mut self, name: String) -> String { - if self.map.contains_key(&name) { - panic!("duplicate variable names") - } else { - let namespaced = format!("@@{name}"); - self.map.insert(name.clone(), namespaced.clone()); - namespaced - } - } - - fn get(&self, name: String) -> String { - self.map.get(&name).unwrap_or(&name).to_string() - } - - fn replace_expr( - &self, - expr: GenericExpr, - ) -> GenericExpr { - match expr { - GenericExpr::Var(span, n) => GenericExpr::Var(span, self.get(n)), - GenericExpr::Call(span, h, generic_exprs) => GenericExpr::Call( - span, - self.get(h), - generic_exprs - .into_iter() - .map(|x| self.replace_expr(x)) - .collect(), - ), - GenericExpr::Lit(span, literal) => GenericExpr::Lit(span, literal), - } - } + // First, run the file on a blank e-graph and track extracts + let mut fresh_egraph = TimedEgraph::new(); + let fresh_extracts = + collect_extract_outputs(fresh_egraph.run_from_file(egg_file)?); - fn replace_fact( - &self, - fact: GenericFact, - ) -> GenericFact { - match fact { - GenericFact::Eq(span, e1, e2) => GenericFact::Eq( - span, - self.replace_expr(e1), - self.replace_expr(e2), - ), - GenericFact::Fact(e) => GenericFact::Fact(self.replace_expr(e)), - } - } - - fn replace_sched( - &self, - schedule: GenericSchedule, - ) -> GenericSchedule { - match schedule { - GenericSchedule::Saturate(span, sched) => { - GenericSchedule::Saturate( - span, - Box::new(self.replace_sched(*sched)), - ) - } - GenericSchedule::Repeat(span, n, sched) => GenericSchedule::Repeat( - span, - n, - Box::new(self.replace_sched(*sched)), - ), - GenericSchedule::Run(span, config) => GenericSchedule::Run( - span, - GenericRunConfig { - ruleset: config.ruleset, - until: config.until.map(|facts| { - facts - .into_iter() - .map(|f| self.replace_fact(f)) - .collect() - }), - }, - ), - GenericSchedule::Sequence(span, scheds) => { - GenericSchedule::Sequence( - span, - scheds.into_iter().map(|x| self.replace_sched(x)).collect(), - ) - } - } - } - } + let name = benchmark_name(egg_file); let mut namespace = Namespace::default(); let program_string = &read_to_string(egg_file)?; @@ -581,19 +584,16 @@ fn poach( let (filtered_cmds, filtered_sexps): (Vec<_>, Vec<_>) = all_cmds .into_iter() .zip(all_sexps) - .filter(|(c, _)| { - match c { - GenericCommand::Action(GenericAction::Let(..)) => true, - egglog::ast::GenericCommand::Extract(..) => true, - egglog::ast::GenericCommand::MultiExtract(..) => true, - // TODO: Running rules on a deserialized egraph currently does not work - // | egglog::ast::GenericCommand::RunSchedule(_) - egglog::ast::GenericCommand::PrintOverallStatistics(..) => true, - egglog::ast::GenericCommand::Check(..) => true, - egglog::ast::GenericCommand::PrintFunction(..) => true, - egglog::ast::GenericCommand::PrintSize(..) => true, - _ => false, - } + .filter(|(c, _)| match c { + GenericCommand::Action(GenericAction::Let(..)) => true, + egglog::ast::GenericCommand::Extract(..) => true, + egglog::ast::GenericCommand::MultiExtract(..) => true, + egglog::ast::GenericCommand::RunSchedule(..) => true, + egglog::ast::GenericCommand::PrintOverallStatistics(..) => true, + egglog::ast::GenericCommand::Check(..) => true, + egglog::ast::GenericCommand::PrintFunction(..) => true, + egglog::ast::GenericCommand::PrintSize(..) => true, + _ => false, }) .map(|(cmd, sexp)| { ( @@ -645,7 +645,8 @@ fn poach( }) .unzip(); - timed_egraph.run_program_with_timeline( + // Run program on the mined e-graph + let mined_outputs = timed_egraph.run_program_with_timeline( filtered_cmds, &filtered_sexps .iter() @@ -654,11 +655,44 @@ fn poach( .join("\n"), )?; + let mined_extracts = collect_extract_outputs(mined_outputs); + + let initial_pairs = extract_term_costs_from_output(&fresh_extracts)?; + let final_pairs = extract_term_costs_from_output(&mined_extracts)?; + if initial_pairs.len() != final_pairs.len() { + anyhow::bail!( + "extract lengths mismatch for {}: {} != {}", + name, + initial_pairs.len(), + final_pairs.len() + ); + } + extract_report.insert( + name.to_string(), + initial_pairs + .into_iter() + .zip(final_pairs) + .map(|((initial_term, initial_cost), (final_term, final_cost))| { + MineExtractComparison { + initial_term, + initial_cost, + final_term, + final_cost, + } + }) + .collect(), + ); + timed_egraph.write_timeline(&out_dir.join(format!("{name}-timeline.json")))?; Ok(()) }, - ) + ); + let file = File::create(out_dir.join("mine-extracts.json")) + .expect("failed to create mine-extracts.json"); + serde_json::to_writer_pretty(BufWriter::new(file), &extract_report) + .expect("failed to write mine-extracts.json"); + result } } } @@ -685,7 +719,6 @@ fn main() { WalkDir::new(input_path) .into_iter() .filter_map(|entry| entry.ok()) - .filter(|entry| !entry.path().to_string_lossy().contains("fail")) .filter(|entry| entry.file_type().is_file()) .filter(|entry| entry.path().extension().and_then(|s| s.to_str()) == Some("egg")) .map(|entry| entry.path().to_path_buf()) @@ -705,3 +738,43 @@ fn main() { File::create(output_dir.join("summary.json")).expect("Failed to create summary.json"); serde_json::to_writer_pretty(BufWriter::new(file), &out).expect("failed to write summary.json"); } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn run_rules_after_deserialize() { + let mut timed_egraph = TimedEgraph::new(); + let program = r#" + (function fib (i64) i64 :no-merge) + (set (fib 0) 0) + (set (fib 1) 1) + + (rule ((= f0 (fib x)) + (= f1 (fib (+ x 1)))) + ((set (fib (+ x 2)) (+ f0 f1)))) + + (run 2) + + (check (= (fib 3) 2)) + (fail (check (= (fib 4) 3))) + "#; + + let res = timed_egraph.run_from_string(program); + assert!(res.is_ok()); + + // round trip serialize + let v = timed_egraph.to_value().expect("failed to serialize"); + timed_egraph.from_value(v).expect("failed to deserialize"); + + let second_run = r#" + (fail (check (= (fib 4) 3))) + (run 1) + (check (= (fib 4) 3)) + "#; + + let res2 = timed_egraph.run_from_string(second_run); + assert!(res2.is_ok()); + } +} diff --git a/src/typechecking.rs b/src/typechecking.rs index 2f39e97a0..859f464e5 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -165,7 +165,7 @@ impl<'de> Deserialize<'de> for TypeInfo { primitives: helper.primitives, func_types: helper.func_types, global_sorts: helper.global_sorts, - }) // TODO: this is a bogus default value + }) } } @@ -443,6 +443,24 @@ impl EGraph { } impl TypeInfo { + pub(crate) fn restore_deserialized_runtime_metadata(&mut self) { + self.mksorts = Default::default(); + self.reserved_primitives = Default::default(); + self.add_presort::(span!()).unwrap(); + self.add_presort::(span!()).unwrap(); + self.add_presort::(span!()).unwrap(); + self.add_presort::(span!()).unwrap(); + self.add_presort::(span!()).unwrap(); + } + + pub(crate) fn clear_primitives(&mut self) { + self.primitives.clear(); + } + + pub(crate) fn all_sorts(&self) -> Vec { + self.sorts.values().cloned().collect() + } + /// Adds a sort constructor to the typechecker's known set of types. pub fn add_presort(&mut self, span: Span) -> Result<(), TypeError> { let name = S::presort_name();