From 6e3251f865a96662d9c36f6b98ca4cbe8df5841b Mon Sep 17 00:00:00 2001 From: Felix Wang Date: Thu, 23 Apr 2026 15:17:56 -0500 Subject: [PATCH] Add OpenReview ICLR 2025 pilot benchmark and LLM-judge eval - Pilot JSONL, scripts, docs, locked eval report and eval_history - evaluate_openreview.py; .gitignore for local OpenReview paths --- .gitignore | 3 + benchmarks/openreview_benchmark/OPENREVIEW.md | 120 +++++++++ benchmarks/openreview_benchmark/REPORT.md | 74 ++++++ .../data/openreview_benchmark.jsonl | 10 + .../openreview_benchmark/eval_history.jsonl | 1 + .../reports/eval_20260423T104302Z.json | 150 +++++++++++ .../scripts/collect_openreview.py | 148 +++++++++++ .../scripts/download_openreview_pdfs.py | 102 ++++++++ .../scripts/evaluate_openreview_benchmark.py | 241 ++++++++++++++++++ .../scripts/filter_candidates.py | 230 +++++++++++++++++ .../scripts/normalize_openreview.py | 222 ++++++++++++++++ .../scripts/openreview_http.py | 62 +++++ .../scripts/validate_openreview_benchmark.py | 124 +++++++++ src/reviewer/evaluate_openreview.py | 220 ++++++++++++++++ 14 files changed, 1707 insertions(+) create mode 100644 benchmarks/openreview_benchmark/OPENREVIEW.md create mode 100644 benchmarks/openreview_benchmark/REPORT.md create mode 100644 benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl create mode 100644 benchmarks/openreview_benchmark/eval_history.jsonl create mode 100644 benchmarks/openreview_benchmark/reports/eval_20260423T104302Z.json create mode 100644 benchmarks/openreview_benchmark/scripts/collect_openreview.py create mode 100644 benchmarks/openreview_benchmark/scripts/download_openreview_pdfs.py create mode 100644 benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py create mode 100644 benchmarks/openreview_benchmark/scripts/filter_candidates.py create mode 100644 benchmarks/openreview_benchmark/scripts/normalize_openreview.py create mode 100644 benchmarks/openreview_benchmark/scripts/openreview_http.py create mode 100644 benchmarks/openreview_benchmark/scripts/validate_openreview_benchmark.py create mode 100644 src/reviewer/evaluate_openreview.py diff --git a/.gitignore b/.gitignore index 972e636..4f24c63 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,9 @@ results/ benchmarks/results/ data/raw_html/ benchmarks/data/raw_html/ +benchmarks/openreview_benchmark/data/openreview_pdfs/ +benchmarks/openreview_benchmark/data/openreview_raw/ +benchmarks/openreview_benchmark/results/ # Review output review_results/ diff --git a/benchmarks/openreview_benchmark/OPENREVIEW.md b/benchmarks/openreview_benchmark/OPENREVIEW.md new file mode 100644 index 0000000..9e0dac8 --- /dev/null +++ b/benchmarks/openreview_benchmark/OPENREVIEW.md @@ -0,0 +1,120 @@ +# OpenReview benchmark track (pilot) + +This track complements the Refine-based benchmark in `benchmarks/data/benchmark.jsonl`. It uses **public OpenReview** threads (reviews, author replies, meta-review, decision) from ML venues. It is **not** paragraph-anchored like Refine; evaluation should use **semantic overlap** (e.g. LLM-as-judge) between model comments and human review text, not paragraph-location metrics. + +All OpenReview-specific assets live under **`benchmarks/openreview_benchmark/`** (data, scripts, this doc). + +## Pilot scope (ICLR 2025) + +The pilot includes **10 papers** from **ICLR 2025**, chosen from a random sample of accepted papers by **longest average** review length (sum of `summary`, `strengths`, `weaknesses`, and `questions` per official review, averaged across reviewers). Papers also had to have **at least three official reviews** and **at least one author reply** in the thread. + +| Forum ID | Title | +|----------|--------| +| `jj7b3p5kLY` | The AdEMAMix Optimizer: Better, Faster, Older | +| `kOJf7Dklyv` | Air Quality Prediction with Physics-Guided Dual Neural ODEs in Open Systems | +| `ajxAJ8GUX4` | Learning Geometric Reasoning Networks For Robot Task And Motion Planning | +| `XMOaOigOQo` | ContraDiff: Planning Towards High Return States via Contrastive Learning | +| `SFNqrHQTEP` | NExUME: Adaptive Training and Inference for DNNs under Intermittent Power Environments | +| `BC4lIvfSzv` | Generative Representational Instruction Tuning | +| `M992mjgKzI` | OGBench: Benchmarking Offline Goal-Conditioned RL | +| `BM9qfolt6p` | LucidPPN: Unambiguous Prototypical Parts Network for User-centric Interpretable Computer Vision | +| `7b2JrzdLhA` | Graph Neural Ricci Flow: Evolving Feature from a Curvature Perspective | +| `d4qMoUSMLT` | Efficient Training of Neural Stochastic Differential Equations by Matching Finite Dimensional Distributions | + +## Data files + +| Path | Description | +|------|-------------| +| `benchmarks/openreview_benchmark/data/openreview_raw/.json` | Raw API response: all notes in the forum (`GET /notes?forum=`). **Gitignored**; produce with `collect_openreview.py` if you need to re-run `normalize_openreview.py`. Not required to run eval (committed JSONL is enough). | +| `benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl` | One JSON object per line: normalized paper metadata, reviews, discussions, meta-review, decision. **Committed**; this is what the eval script reads. | + +Optional: `filter_candidates.py` can write a ranked list (e.g. `candidate_papers.json`) while discovering the pilot; that file is **not** required to use the benchmark once `openreview_benchmark.jsonl` exists. + +### Locked evaluation artifacts (committed) + +| Path | Description | +|------|-------------| +| `benchmarks/openreview_benchmark/reports/` | Frozen full-eval JSON copies for git (`eval_.json`); use **repo-relative** `benchmark` / `results_dir` paths inside each file. | +| `benchmarks/openreview_benchmark/REPORT.md` | Human-readable pilot report (tables, caveats, how to reproduce). | + +## Scripts + +Shared HTTP helpers (Cloudflare session) live in **`benchmarks/openreview_benchmark/scripts/openreview_http.py`** and are imported by the fetch/download scripts below. + +| Script | Purpose | +|--------|---------| +| `benchmarks/openreview_benchmark/scripts/collect_openreview.py` | Fetch forums by venue or explicit `--forum-ids`; writes `data/openreview_raw/`. Uses a browser session (visit `openreview.net` first) so API requests are not blocked. | +| `benchmarks/openreview_benchmark/scripts/normalize_openreview.py` | Convert raw forum JSON to `data/openreview_benchmark.jsonl`. | +| `benchmarks/openreview_benchmark/scripts/filter_candidates.py` | List accepted papers for ICLR 2025 + NeurIPS 2025, random sample, rank by review text length; optional pilot discovery. | +| `benchmarks/openreview_benchmark/scripts/validate_openreview_benchmark.py` | Check JSONL schema; optional `--parse-one` downloads the first paper’s PDF and runs `parse_document` (no LLM). Use before a full review run. | +| `benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py` | LLM-judge **precision / recall / F1**; optional `--save-full-report` / `--output`; appends to `eval_history.jsonl` unless `--no-eval-history`. | +| `benchmarks/openreview_benchmark/scripts/download_openreview_pdfs.py` | Download PDFs for papers in `openreview_benchmark.jsonl` into `data/openreview_pdfs/` (gitignored) for `openaireview review `. | + +## Schema (normalized JSONL) + +Each line is one paper. Main fields: + +- **Paper:** `paper_id`, `forum_url`, `venue`, `year`, `title`, `authors`, `abstract`, `keywords`, `primary_area`, `pdf_url`, `decision` +- **Reviews:** `reviews[]` — each item has `review_id`, `reviewer`, `rating`, `confidence`, `soundness`, `presentation`, `contribution`, `summary`, `strengths`, `weaknesses`, `questions` +- **Discussion:** `discussions[]` — `comment_id`, `replyto`, `author_type`, `comment` (and optional `reviewer` for reviewer comments) +- **Meta-review:** `meta_review` (object or null) + +## Evaluation (implemented) + +Module: `src/reviewer/evaluate_openreview.py`. CLI: `benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py`. + +OpenAIReview outputs **discrete comments** (title, quote, explanation). Human ground truth is **official reviews** with separate fields. Scores are **LLM-as-judge** (configurable model, default `gpt-4o-mini` via `OPENREVIEW_JUDGE_MODEL`). + +**Precision** (per paper): among model comments, the fraction for which the judge answers **YES** to: “Does this comment overlap **any** substantive critique or question in the **pooled** human review text (all reviewers combined)?” + +**Recall** (per paper): for each official review with non-empty text, the judge answers **YES** if **at least one** model comment addresses a substantive issue in **that** review. **Recall** = (number of YES) / (number of non-empty official reviews). Macro-averaged over papers in the CLI summary. + +**F1** = harmonic mean of precision and recall per paper; the script prints per-paper and **mean** P/R/F1. + +**API keys:** use the same stack as the rest of the package (e.g. `OPENAI_API_KEY` and `REVIEW_PROVIDER=openai` for the judge). Review runs and judge calls can share the provider. + +**Get PDFs locally** (the CLI does not fetch OpenReview PDF URLs like arXiv): + +```bash +python benchmarks/openreview_benchmark/scripts/download_openreview_pdfs.py +# Writes benchmarks/openreview_benchmark/data/openreview_pdfs/.pdf (gitignored) +``` + +**Run a review** (keep outputs under this track; `results/` is gitignored except you can commit summaries separately): + +```bash +openaireview review benchmarks/openreview_benchmark/data/openreview_pdfs/jj7b3p5kLY.pdf \ + --name jj7b3p5kLY --method zero_shot \ + --output-dir benchmarks/openreview_benchmark/results/reviews +``` + +**Run evaluation** — `--results-dir` must match where review JSON lives: + +```bash +python benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py \ + --results-dir benchmarks/openreview_benchmark/results/reviews \ + --save-full-report +``` + +That writes a **timestamped** full report under `benchmarks/openreview_benchmark/results/eval_.json` and **appends one line** to **`benchmarks/openreview_benchmark/eval_history.jsonl`** (mean P/R/F1, judge model, paper ids, optional pointer to the full report). Commit `eval_history.jsonl` when you want a paper-trail for a written report; use `--no-eval-history` to skip the append. Use `--output ` instead of `--save-full-report` if you want a fixed report path. + +For a **PR-ready snapshot**, copy that JSON into **`reports/`**, normalize paths to repo-relative strings, and extend **`REPORT.md`** (see the existing locked run there). + +Do **not** use paragraph-index metrics from `evaluate.py` as the primary signal for this track unless human spans are aligned to the paper in a future version. + +**Next steps (optional):** atomic human bullets; rebuttal–point linkage; cheaper embedding baselines. + +## Local-only files (gitignored: `results/`, `data/openreview_pdfs/`, `data/openreview_raw/`) + +| Path | Needed for git / PR? | When you can delete | +|------|----------------------|----------------------| +| `data/openreview_raw/.json` | No | Only for **regenerating** `openreview_benchmark.jsonl` via `normalize_openreview.py`. Eval and the committed pilot do **not** need these files on disk. | +| `results/reviews/.json` | No (local LLM outputs) | Never required for the **committed** scorecard; keep if you want to **re-run eval** without paying for reviews again. | +| `results/eval_.json` | No | **Redundant** after you copy metrics into `reports/` (same numbers; `reports/` is the committed snapshot). | +| `data/openreview_pdfs/*.pdf` | No | Safe to remove to save disk if you no longer run `openaireview review` locally; download again with `download_openreview_pdfs.py` if needed. | + +## Limitations + +- OpenReview is **ML/AI-heavy**; diversity is mostly via topic area within venues. +- API access may require the same session pattern as in `collect_openreview.py` (Cloudflare). +- Review quality and length vary by reviewer; the pilot biased toward **longer** average reviews for denser supervision. diff --git a/benchmarks/openreview_benchmark/REPORT.md b/benchmarks/openreview_benchmark/REPORT.md new file mode 100644 index 0000000..9668ac3 --- /dev/null +++ b/benchmarks/openreview_benchmark/REPORT.md @@ -0,0 +1,74 @@ +# OpenReview pilot benchmark — locked evaluation report + +**Run:** `generated_at` = `2026-04-23T10:43:02.136255+00:00` (UTC) +**Committed scorecard:** [`reports/eval_20260423T104302Z.json`](reports/eval_20260423T104302Z.json) +**History line:** [`eval_history.jsonl`](eval_history.jsonl) (same run; `full_report` points at the committed JSON under `reports/`) + +This report summarizes one completed **LLM-as-judge** pass over all **10** ICLR 2025 pilot papers. It is meant to be cited before a PR; raw review outputs stay under `results/` (gitignored). + +--- + +## What was evaluated + +| Role | Model | Notes | +|------|--------|--------| +| **Paper review** (predictions) | `claude-opus-4-6` | `openaireview review … --method zero_shot`; method key `zero_shot__claude-opus-4-6` in each `.json`. | +| **Judge** (precision / recall) | `claude-sonnet-4-6` | Same API stack as reviews (`REVIEW_PROVIDER=openai` + gateway). Judge calls use `temperature=0.0`, `max_tokens=8`, YES/NO prompts per `src/reviewer/evaluate_openreview.py`. | + +Metrics are **not** comparable to the Refine benchmark in `benchmarks/REPORT.md` (different ground truth: paragraph-anchored Refine comments vs OpenReview review text overlap). + +--- + +## Metric definitions (short) + +See **`OPENREVIEW.md`** and **`src/reviewer/evaluate_openreview.py`** for the exact prompts. + +- **Precision:** fraction of model comments the judge says overlap **any** substantive critique or question in **pooled** official review text (all reviewers). +- **Recall:** for each official review with non-empty formatted text, the judge says whether **at least one** model comment addresses a substantive issue in **that** review; recall = YES count / number of such reviews. +- **F1:** harmonic mean of precision and recall **per paper**; the table below matches the committed JSON. **Means** in the JSON are unweighted averages across the 10 papers. + +--- + +## Aggregate results (n = 10) + +| Mean precision | Mean recall | Mean F1 | +|----------------|-------------|---------| +| 0.377 | 0.745 | 0.464 | + +--- + +## Per-paper results + +| `paper_id` | Precision | Recall | F1 | Predictions | Reviews covered / non-empty | +|------------|-----------|--------|-----|-------------|----------------------------| +| 7b2JrzdLhA | 0.500 | 0.750 | 0.600 | 12 | 3 / 4 | +| ajxAJ8GUX4 | 0.250 | 1.000 | 0.400 | 8 | 4 / 4 | +| BC4lIvfSzv | 0.300 | 1.000 | 0.462 | 10 | 4 / 4 | +| BM9qfolt6p | 0.111 | 0.750 | 0.194 | 9 | 3 / 4 | +| d4qMoUSMLT | 0.500 | 0.750 | 0.600 | 8 | 3 / 4 | +| jj7b3p5kLY | 0.500 | 0.600 | 0.545 | 8 | 3 / 5 | +| kOJf7Dklyv | 0.750 | 0.600 | 0.667 | 8 | 3 / 5 | +| M992mjgKzI | 0.000 | 0.000 | 0.000 | 8 | 0 / 4 | +| SFNqrHQTEP | 0.556 | 1.000 | 0.714 | 9 | 4 / 4 | +| XMOaOigOQo | 0.300 | 1.000 | 0.462 | 10 | 3 / 3 | + +--- + +## Interpretation and caveats + +1. **LLM judge variance:** A second run with the same inputs can change YES/NO edges; treat means as **point estimates**, not ground truth. +2. **Strict overlap:** The judge is asked for overlap with **substantive** human critiques. Model comments that are mostly notation or internal consistency may score **no** overlap when humans emphasized contribution, novelty, or positioning (see **`M992mjgKzI`**: all NO in this run despite substantive model comments). +3. **Review vs judge model mismatch:** Reviews used **Opus**, judge **Sonnet**; both are valid for an end-to-end pipeline but should be stated in any write-up. +4. **Infrastructure:** Gateway retries (including higher retry count on judge calls in `evaluate_openreview.py` during this workstream) absorbed intermittent 503 / Bedrock errors; long runs are still sensitive to outages. + +--- + +## Reproducing (after PDFs and review JSON exist) + +```bash +python benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py \ + --results-dir benchmarks/openreview_benchmark/results/reviews \ + --save-full-report +``` + +Copy the new `eval_.json` into `reports/` with **repo-relative** `benchmark` and `results_dir` fields if you want another locked row for git. You can then delete the duplicate under `results/` to save space; the committed snapshot lives only in `reports/`. diff --git a/benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl b/benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl new file mode 100644 index 0000000..639ea69 --- /dev/null +++ b/benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl @@ -0,0 +1,10 @@ +{"paper_id": "7b2JrzdLhA", "forum_url": "https://openreview.net/forum?id=7b2JrzdLhA", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "Graph Neural Ricci Flow: Evolving Feature from a Curvature Perspective", "authors": ["Jialong Chen", "Bowen Deng", "Zhen WANG", "Chuan Chen", "Zibin Zheng"], "abstract": "Differential equations provide a dynamical perspective for understanding and designing graph neural networks (GNNs). By generalizing the discrete Ricci flow (DRF) to attributed graphs, we can leverage a new paradigm for the evolution of node features with the help of curvature. We show that in the attributed graphs, DRF guarantees a vital property: The curvature of each edge concentrates toward zero over time. This property leads to two interesting consequences: 1) graph Dirichlet energy with bilateral bounds and 2) data-independent curvature decay rate. Based on these theoretical results, we propose the Graph Neural Ricci Flow (GNRF), a novel curvature-aware continuous-depth GNN. Compared to traditional curvature-based graph learning methods, GNRF is not limited to a specific curvature definition. It computes and adjusts time-varying curvature efficiently in linear time. We also empirically illustrate the operating mechanism of GNRF and verify that it performs excellently on diverse datasets.", "keywords": ["Graph neural network", "Differential equation", "Curvature", "Ricci flow"], "primary_area": "learning on graphs and other geometries & topologies", "pdf_url": "https://openreview.net/pdf?id=7b2JrzdLhA", "decision": "Accept (Poster)", "num_reviews": 4, "num_discussions": 27, "reviews": [{"review_id": "UFjaTsDjtf", "reviewer": "Reviewer_tn6E", "rating": 6, "confidence": 2, "soundness": 3, "presentation": 2, "contribution": 3, "summary": "The paper introduces the dynamical system Attribute Discrete Ricci Flow (Attri-DRF) and incorporates this to propose the Graph Neural Ricci Flow (GNRF), a curvature-aware continuous GNN. This ensures that the graph Dirichlet energy can be bilaterally bounded and that the curvature decay to 0 independent of data. Using an auxiliary network (EdgeNet), the model can theoretically incorporate different types of curvature definition. GNRF has excellent performance on many data sets against a variety of discrete and continuous GNNs.", "strengths": "Theoretically, the paper provide several interesting results. \n1. Section 3 provides guarantees on the curvature decay rate and the stable curvature limit of Attri-DRF when certain conditions are met, along with providing a bound on the Dirichlet energy when the curvature stabilizes. This indicates it may be able to avoid over-smoothing/over-squashing. \n2. Incorporating recent results, the paper uses an auxiliary network (EdgeNet), which is capable of approximating arbitrary edge curvature with high precision.\n\nExperimentally, the paper performs well on a variety of popular node classification tasks against a number of old and new discrete and continuous GNN architectures. Section 5.2 and 5.1 provides good evidence that theoretical guarantees hold in reality.", "weaknesses": "1. It is not clear to the reviewer how the theoretical results tie together/what assumptions are made at each step of the way.\n2. The design of EdgeNet is glossed over within the paper, with only a few formulas mentioned within either the main paper or the appendix to explain it. There's also no comparison between EdgeNet's curvature values compared against any other type of curvature that it supposedly can approximate.\n3. The datasets used in the experiments are relatively small datasets.", "questions": "1) Why does having a data-independent curvature decay rate a good thing?\n2) Where does equation (3) come from? I checked the Ollivier paper and I can't find this equation there.\n3) Does GNRF satisfy the theoretical results in Section 3. If it does, it would be great if the authors can clarify this a bit more within the paper.\n4) Over-smoothing and over-squashing are problems caused by the message-passing design in GNNs. Considering that GNRF does not strictly adhere to the message passing design, is it appropriate to mention these problems in this work?\n5) Does GNRF work on larger datasets?"}, {"review_id": "6P1ZSqJETu", "reviewer": "Reviewer_tPnw", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 2, "summary": "his paper introduces Graph Neural Ricci Flow (GNRF), designed to model a dynamic system called Attribute Discrete Ricci Flow (Attri-DRF) on graph.\nUnlike traditional GNNs, which use multiple layers and pass outputs from one layer as inputs to the next, the model in this work employs only a single layer. Instead, it iteratively updates node features and curvatures—treated similarly to edge weights—over discrete time steps according to the Attri-DRF ODE.", "strengths": "1. The model dynamically learns the curvature instead of relying on precomputed values, enabling it to adapt in response to both internal hidden features and the graph’s topology. \n2. This approach is closely aligned with the heat flow equation.\n3. As shown in Figure 3, the proposed framework achieves stable curvature over sufficient time steps. At the same time, the curvature concentrates around zero that can facilitate smoother information flow across the graph.", "weaknesses": "1. The network architecture used in this framework is relatively general. And it would be valuable to discuss the potential benefits of using more complex GNN architectures. Moreover, exploring the motivation behind the proposed framework with other related works [1, 2, 3] that focus on graph curvatures could provide additional insights.\n\n2. The experiment primarily focuses on node classification tasks on small-scale graphs. To better validate the effectiveness of the proposed framework, applying it to larger graphs would be beneficial. For example, the ogbn-arxiv dataset could serve as a graph classification dataset with GNNs as baselines. Additionally, for non-homophilous graph datasets, larger datasets and relevant baselines are available in [4].\n\n3. An efficiency study would be helpful. The computational cost of applying the ODE method should be explicitly discussed so readers can better understand its applicability. For instance, comparing parameters, training time, and GPU memory usage between this approach and other GNNs, such as GCN and GAT, would clarify its potential advantages and trade-offs.\n\n[1] Curvdrop: A ricci curvature based approach to prevent graph neural networks from over-smoothing and over-squashing.\n[2] Curvature Graph Neural Network.\n[3] Hyperbolic variational graph neural network for modeling dynamic graphs\n[4] Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods", "questions": "1. Could you share your thoughts on the relationship and differences between the curvature predicted in this architecture and edge attention mechanisms? And do you think it is possible to apply attention mechanism in the framework?\n\n2. In EdgeNet, are the edge features from the previous time step used as input for the current time step? Equation 13 suggests that the previous edge features should be used, but the code appears to rely only on the previous node features without incorporating the prior edge features.\n\n3. What is the formulation for updating node features at each subsequent time step? As a suggestion, including an illustrative figure or a pseudo-algorithm would help readers gain a clearer understanding of the overall framework.\n\n4. While there is only one layer in the implementation, is it possible to apply a multiple layer GNN? And is it possible to connect layers with time steps in this case?"}, {"review_id": "kwM2qZC4PU", "reviewer": "Reviewer_QrP8", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "The paper proposes a continuous GNN dynamics, namely GNRF by incorporating curvature based on the Ricci flow. In particular, by expressing edge weights as a function of node features, GNRF propagate features following a discrete (graph) Ricci flow. In order to avoid the costly computation of Ricci curvature on graphs, the paper proposes an auxiliary network for modeling curvature and learned end-to-end. The paper provides several theoretical guarantees in terms of bounded Dirichlet energy and fast curvature decay. The experiments support the effectiveness of the method.", "strengths": "1. Compared to curvature-based graph rewiring, it is interesting and natural to incorporate Ricci curvature into the propagation of node features.\n\n2. Theoretical developments are supportive of the claims.", "weaknesses": "1. It is unclear how EdgeNet approximates edge curvatures? In particular, given there are trainable parameters and in the experiments, EdgeNet is trained end-to-end with supervision only from the task, instead of actual curvature. How to ensure EdgeNet approximates the curvature in this case? \n\n2. Theorem 5 is unclear. Does this mean there exists some network \\phi_1, \\phi_2 such that the network can approximate any curvature? Please give more explanations.\n\n3. Even though the theory is well-developed, the main GNN algorithm in (14) seems to resemble GRAND, especially the EdgeNet seems to act like a re-weighting term as in graph attention. How does EdgeNet differ to the graph attention module? Can you add experiments to verify the difference?", "questions": "1. In Line 285, the paper claims the sign of EdgeNet_ij aligns with the sign of k_ij. I am not sure how this is achieved without the supervision from the actual curvature. \n\n2. In Line 175 of Theorem 2, w_ij should be changed to k_ij?\n\n3. In Section 5.1, the curvature seems to be computed from the EdgeNet? What about the actual curvature?\n\n4. I am also curious whether there could be improvements when EdgeNet is replaced with actual curvature? This could be part of ablation study."}, {"review_id": "vw1bWn0nIW", "reviewer": "Reviewer_1v62", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 2, "contribution": 3, "summary": "The paper introduces Graph Neural Ricci Flow (GNRF), a novel method for evolving node features in Graph Neural Networks (GNNs) using a differential equation-inspired approach based on the Ricci flow. The authors generalize the Discrete Ricci Flow to attributed graphs, where each edge's curvature converges toward zero over time. This has two key consequences: it bounds the graph's Dirichlet energy and provides a data-independent curvature decay rate. GNRF is unique because it computes time-varying curvature efficiently in linear time, unlike traditional curvature-based methods, which are typically precomputed and limited to specific curvature definitions. \n\nThe motivation behind the GNRF stems from the limitations of existing GNNs that rely on heat diffusion equations, which often lead to over-smoothing. Instead, the paper explores an alternative differential equation—Ricci flow—to mitigate over-smoothing and create more stable, non-smooth node representations. This innovative approach contrasts with traditional methods that view curvature as a static, precomputed property tied only to graph topology. By allowing curvature to evolve with node features, GNRF enables more dynamic and flexible graph learning.", "strengths": "- Novelty: The paper introduces an interesting and innovative approach with GNRF, applying Discrete Ricci Flow to attributed graphs in a novel way. By allowing edge curvature to evolve dynamically with node attributes, GNRF moves beyond the limitations of traditional methods that rely on static, precomputed curvature. Its ability to work with any curvature definition and to compute curvature in linear time addresses key concerns around scalability and efficiency. This flexibility offers a practical and effective way to handle common challenges such as over-smoothing and over-squashing in graph neural networks. Overall, GNRF presents a meaningful advancement in curvature-aware graph learning.\n- Theoretical results: The paper offers solid theoretical contributions that help establish the soundness of the Attribute Discrete Ricci Flow framework. One key result is the demonstration that edge curvature naturally converges toward zero, ensuring a stable evolution of node features and addressing potential issues like over-smoothing and over-squashing. The bounding of the Dirichlet energy provides additional assurance that node representations maintain a balance between being too homogeneous or too distinct.", "weaknesses": "- Limited experimental results: The paper's experimental evaluation has some limitations. The focus is exclusively on node classification tasks, raising the question of why the method wasn't tested on other common tasks like graph classification or regression, which would provide a broader view of its applicability. Additionally, of the seven node classification datasets used, three (Cornell, Wisconsin, and Texas) are notably small, making it difficult to draw definitive conclusions about the method’s performance on more challenging or larger-scale data. Furthermore, on two of the larger datasets (Cora_Full and PubMed), the proposed method performs only within the statistical margin of error compared to the baselines, which limits its ability to demonstrate a clear and significant improvement over existing methods. Overall, I believe the paper would benefit significantly if the authors added some experimental results on graph classification/ regressions tasks, for example the LRGB datasets [1].\n- Baseline comparisons: The paper’s comparison with baseline methods raises some concerns regarding its evaluation methodology. Specifically, on the Tolokers and Roman Empire datasets, the authors use a 60/20/20 train/validation/test split, but the results reported for baseline models like GCN are significantly lower than what is found in the original work, which used a 50/25/25 split. When compared with the results in “A Critical Look at the Evaluation of GNNs under Heterophily” (2023), the proposed method (82.55 on Tolokers) does not seem to outperform a simple baseline like GCN on Tolokers (83.64), raising questions about whether the method truly offers improvements in these settings.\n\nOverall, I would be happy to increase my score to a 5 or 6 if the authors can convincingly address the above two points and show the practical usefulness of their method.\n\n[1] Dwivedi, Vijay Prakash, et al. \"Long range graph benchmark.\" Advances in Neural Information Processing Systems 35 (2022): 22326-22340.", "questions": "Could the authors explain why they are only using the Tolokers and Roman Empire datasets from “A Critical Look at the Evaluation of GNNs under Heterophily” and not the three other node classification datasets?"}], "discussions": [{"comment_id": "fOgDxsYjT9", "replyto": "kakKDml5VA", "author_type": "authors", "reviewer": null, "comment": "We are grateful that our efforts have finally been recognized by you. We will continue to work to improve the quality of our papers."}, {"comment_id": "kakKDml5VA", "replyto": "uiWJgnOtUr", "author_type": "reviewer", "reviewer": "Reviewer_tPnw", "comment": "Thank you for your efforts in revising and addressing my concerns in the paper. I have updated my score to 6 to support the acceptance."}, {"comment_id": "uiWJgnOtUr", "replyto": "6P1ZSqJETu", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nWe greatly appreciate your suggestions, and **we have made corresponding revisions to the latest paper**. \n\nFirstly, we moved the original Table 3 to the appendix (it is now Table 8), and created a new Table 3. In this new table, we modified the originally uniform model depth from 4 to 3, and then presented the results for hidden layer sizes of 16, 64, and 256, respectively. The advantage of this setting is that when the hidden layer size is 256, the model capacity of GCN aligns with the official recommendations from OGB, ensuring fairness in comparison. Based on this new table, we believe that GNRF still achieves a meaningful trade-off between efficiency and performance. We have extracted the table for your reference below:\n\n\n| **Model** | **#Param** | **Time** | **Acc.** | **#Param** | **Time** | **Acc.** | **#Param** | **Time** | **Acc.** |\n|-----------|------------|----------|----------|------------|----------|----------|------------|----------|----------|\n| | #Hidden=16 | #Hidden=16 | #Hidden=16 | #Hidden=64 |#Hidden=64 | #Hidden=64 | #Hidden=256 | #Hidden=256 | #Hidden=256|\n| GCN(Depth=3) | 3.15k | 0.12s | 60.95 | 15.2k | 0.14s | 68.55 | 110k | 0.21s | **71.65** |\n| GAT(Depth=3) | 15.0k | 0.17s | 59.52 | 86.8k | 0.25s | 64.39 | 788k | OOM | OOM |\n| ACMP(Depth=3) | 4.18k | 3.29s | 61.03 | 32.0k | 6.35s | 68.89 | 374k | OOM | OOM |\n | GNRF(Depth=3) | 5.50k | 0.31s | **62.11** | 52.6k | 0.78s | **69.33** | 701k | OOM | OOM |\n\n\nIn addition, in Table 8 (the original Table 3), we deleted GCN and GAT and added APPNP and GCNII (both are deep GNNs) to ensure that our discussion in this scenario is meaningful. Our observation is similar to the original one. When the depth is relatively shallow (4 or 16), GNRF still has advantages over APPNP and GCNII, but as the depth deepens, the performance of GNRF declines.\n\nWe hope this meets your suggestions well and look forward to further discussions."}, {"comment_id": "2XjEVo39nF", "replyto": "6P1ZSqJETu", "author_type": "authors", "reviewer": null, "comment": "Thank you for your response. We are glad to see that our previous efforts have clarified most of your concerns. Regarding your additional suggestion, we continue to respond as follows:\n\n+ Hyper-parameters for GCN and GAT: We checked the hyperparameter settings on the OGB website and found two potential hyperparameter differences that may affect performance. First is the hidden layer size. We set the hidden layer size to 64, while OGB uses 256. As noted in Table 4 of the appendix, when using the OGBN-Arxiv dataset, we fixed the hidden size of GNRF to 64 (to avoid OOM). To ensure consistent model capacity, we also fixed the hidden layer size of all comparison methods to 64, but obviously, larger hidden layers generally lead to better performance. The second important parameter is the number of layers. We used 4, while OGB uses 3. We believe these two parameters have the most significant impact. Other parameters include: lr=0.001, epoch=2000, dropout=0.5. As for the design of GCN and GNN, we reviewed the OGB source code and found that they are essentially the same, with no residual connections used.\n\n+ Suggestion on efficiency comparison: Yes, we agree with your point. Here, the deeper comparisons are mainly used to highlight two unique features of GNRF. First, as a continuous depth GNN, the number of parameters and memory usage of GNRF is independent of depth. Second, because GNRF uses a fixed-step solver, it is much faster than other popular continuous deep GNNs. We believe it is beneficial to show these two points to the readers. However, we also admit that GCN and GAT are not the most suitable choices for deep GNNs. Therefore we changed the comparison method here. You can see our reply (6). And we have implemented these improvements in the paper.\n\nIf possible, we would appreciate your quick feedback. Thank you."}, {"comment_id": "E0qSbLDPTt", "replyto": "6P1ZSqJETu", "author_type": "reviewer", "reviewer": "Reviewer_tPnw", "comment": "I have read the revised paper and your responses—thank you for your hard work in addressing my concerns. I find that my previous questions and the weaknesses have been well addressed. But I have two additional suggestions for the revised paper, and I will appreciate it if you can consider them.\n* **Hyperparameters for Baseline GCN and GAT Models.**\nWould you mind sharing the hyperparameters used for the baseline GCN and GAT models? The performance reported on the OGB official website for GCN on the Ogb-Arxiv dataset is ranked 64, with a validation accuracy of 0.7300 ± 0.0017 and a test accuracy of 0.7174 ± 0.0029. It is a little different compared to the reported results in the revised paper. Providing the hyperparameters and clarifying the reasons for any differences in results (e.g., whether residual connections were used) would enhance the credibility of your results and help readers better understand the differences.\n\n* **Suggestions for the Efficiency Comparison in Table 3.**\nWhen comparing the efficiency of the proposed method with the baseline GCN in Table 3, GNRF requires more time for continuous-depth 4 (0.79s vs. 0.17s). In my opinion, both models are relatively fast. However, for deeper networks, the performance of both GNRF and GCN tends to degrade, which highlights another perspective on the over-smoothing issue in GNNs. Therefore, I believe it may not be suitable to include comparisons with deeper cases here. \nWhile such efficiency comparisons do demonstrate the training time of proposed GNRF will be the same with increasing depth, GNRF does not show performance improvements with increasing depth. Therefore, focusing the efficiency study on 4-layer networks might be sufficient. Additionally, if deeper networks are to be compared, models like DeeperGCN, which can maintain or improve performance with increasing depth, might be more appropriate for this context.\n\nBy the way, I delete the minor issue about GNN training cost, since the training time complexity for GNN full batch training is linearly increased with the number of layers shown in LADIES[1]. Sorry for any inconvenience.\n\n[1] Layer-Dependent Importance Sampling for Training Deep and Large Graph Convolutional Networks\n."}, {"comment_id": "FAeXFlE8Sj", "replyto": "6P1ZSqJETu", "author_type": "authors", "reviewer": null, "comment": "Dear reviewer, we agree with your point in Weakness 1 that \"it would be valuable to discuss the potential benefits of using more complex GNN architectures\". Therefore, we have added a new section, **\"Future direction\" in the appendix C.3** to explore the possibility of applying our proposed Attri-DRF to Graph Transformer, another commonly used GNN architecture. We excerpt the original text for you as follows:\n\n\"In this paper, we focus on the application of Attribute Discrete Ricci flow in message passing-based discrete/continuous-depth GNNs. However, in view of the excellent performance of Graph Transformers (GTs), especially on graph-level tasks, we believe that it is also meaningful to consider the application of Attri-DRF in this type of method. We believe that this generalization may be feasible, based on two observations: \n\n(1) In theory, there are certain curvatures that can be defined on any node pair $(i, j)$ without requiring $i$ and $j$ to be adjacent (For example, Ollivier Ricci curvature). This is in GTs is very useful because GTs directly aggregates information from the entire graph. \n\n(2) In practice, a significant difference between our model GNRF and GARND is that the aggregation weight replaces the attention coefficient with a curvature-aware coefficient. Since attention is widely used in GTs, this replacement is likely natural.\n\nWe also provide a possible promotion here. Let $\\mathsf{PE}(\\cdot)$ be some position encoding function and $\\mathsf{sim}(\\cdot,\\cdot)$ be some similarity function. We can let $w_{ij}(t) \\equiv \\mathsf{sim}(\\mathsf{PE}(i,t), \\mathsf{PE}(j,t))$ to get the generalization of Attri-DRF:\n\n$$\n\\frac{\\partial \\mathsf{sim}(\\mathsf{PE}(i,t), \\mathsf{PE}(j,t))}{\\partial t} = -\\kappa_{ij}(t)\\mathsf{sim}(\\mathsf{PE}(i,t), \\mathsf{PE}(j,t)).\n$$\n\n\nWe leave application on GTs of this definition to future work.\""}, {"comment_id": "exdV6hgHe3", "replyto": "Xbq76XQKsg", "author_type": "authors", "reviewer": null, "comment": "We are very grateful to the reviewers for their seriousness and responsibility, and we are happy to see that you recognized our work. We will continue to work hard to continuously improve the quality of this paper."}, {"comment_id": "Xbq76XQKsg", "replyto": "Pulh6XlyOn", "author_type": "reviewer", "reviewer": "Reviewer_tn6E", "comment": "I thank the authors for their responses. I think this paper presents an interesting idea, and would like to keep my score as is."}, {"comment_id": "qA9uxWv9l7", "replyto": "CmsB66R1vS", "author_type": "authors", "reviewer": null, "comment": "We are so grateful that the reviewer recognized our efforts! We will continue to improve our paper in the future!"}, {"comment_id": "CmsB66R1vS", "replyto": "8OOSNwipsz", "author_type": "reviewer", "reviewer": "Reviewer_QrP8", "comment": "I thank the authors for providing the detailed responses. My main concerns are well addressed and thus I have increased the score accordingly."}, {"comment_id": "UbQgtBZysu", "replyto": "6P1ZSqJETu", "author_type": "authors", "reviewer": null, "comment": "We present experimental data here that may address your concerns for your review. \n\n## Table 3 \nWe first present the experimental results for OGBN-Arxiv and OGBN-Year. We observe that, when maintaining the same depth, our method shows significant improvements over classical models, both in homophilious and heterophilious settings.\n| |OGBN-Arxiv|OGBN-Year |\n|---|---|---|\n|GCN(depth=4)|67.85|46.22|\n|GAT(head=3,depth=4)| 66.71|44.51|\n|ACMP(depth=4)|67.16|47.55|\n|GNRF(depth=4)|69.25|48.55|\n\nNext, we report the parameter count, storage, and average runtime per epoch based on OGBN-Arxiv. We extract the scenario from Table 3 where the depth is set to 64. In this depth setting, scalability becomes a significant challenge for the model.\n\n| |#Param|Mem.|Time|\n|---|---|---|---|\n|GCN|273k|12.9k|0.93s|\n|GAT|OOM|N/A|N/A|\n|ACMP|19.5k|7.15G|17.6s|\n|GNRF|35.9k|11.5G|0.79s|\n\nIn the main text, we explained that GNRF has computational complexity comparable to that of GCN. However, in this experiment, we found that discrete-depth GNNs (GCN/GAT) require different parameters for each layer, causing their parameter count to increase linearly with the number of layers. In contrast, GNRF and ACMP, as continuous-depth GNNs, maintain a constant number of parameters regardless of depth. Additionally, thanks to GNRF's use of a fixed-step ODE solver, it is significantly faster than ACMP, which uses an adaptive-step solver, when facing long-duration evolution processes. As a result, GNRF achieves a favorable balance across various resource consumption metrics.\n\n## Table 7\nWe conducted experiments on two graph task datasets with over 1 million nodes. These datasets are highly challenging for general message-passing GNNs and are often used to validate the adversarial robustness of models under over-squashing. The results show that our method significantly outperforms GCN, and even competes with SAN (a Graph Transformer-based model with much higher complexity than GNRF). This also demonstrates that GNRF is well-suited for large-scale datasets.\n\n| | GCN | GatedGCN+RWSE | SAN+LapPE | SAN+RWSE | GNRF | GNRF+LapPE | GNRF+RWSE |\n|-----------|--------------------|---------------------|---------------------|---------------------|---------------------|---------------------|---------------------|\n| Peptides-func AP(↑) | 0.5930±0.0023 | 0.6069±0.0035 | 0.6384±0.0121 | 0.6439±0.0075 | 0.6233±0.0080 | 0.6455±0.0062 | 0.6480±0.0056 |\n| Peptides-struct MAE(↓) | 0.3496±0.0013 | 0.3357±0.0006 | 0.2683±0.0043 | 0.2545±0.0012 | 0.3166±0.0053 | 0.2675±0.0044 | 0.2811±0.0031 |"}, {"comment_id": "8OOSNwipsz", "replyto": "kwM2qZC4PU", "author_type": "authors", "reviewer": null, "comment": "We would like to provide a more detailed explanation regarding the three weaknesses you mentioned.\n\n## Weakness 1 (On how GNRF approximates edge curvature)\nAs explained in the previous comment, EdgeNet does not directly approximate any specific real-world definition of curvature during end-to-end training; rather, it acts as a dataset-adaptive curvature proxy. We support this approach both theoretically and experimentally. Theoretically, our results do not depend on any specific definition, and experimentally, we found that (1) using a specific curvature definition does not consistently perform well across all datasets (as shown in the figure below), and (2) even when using adaptive curvature, GNRF's performance aligns with that of Ricci flow (as shown in Sections 5.2 and 5.3 of the paper).\n\n| Dataset | Corn. | Wisc. | Texas | R. Emp. | Tolo. | Mine. | Ques. | A.-rat. | C._Full | PubM. | DBLP | C._ML |\n|------------|-------|-------|-------|---------|-------|-------|-------|---------|---------|-------|------|-------|\n| GNRF | 87.28 | 88.00 | 87.39 | 86.25 | 83.96 | 95.03 | 73.86 | 46.89 | 72.12 | 90.37 | 85.73| 89.18 |\n| GNRF_FRC | 85.59 | 84.00 | 82.08 | 75.23 | 76.17 | 81.61 | 61.78 | 41.22 | 67.51 | 88.96 | 82.55| 87.29 |\n| GNRF_ARC | 86.49 | 88.00 | 81.90 | 76.52 | 78.14 | 87.25 | 64.55 | 41.74 | 70.17 | 88.21 | 83.33| 89.43 |\n\n## Weakness 2 (On the explanation of Theorem 5)\nWe know that there are many definitions of curvature for edges in a graph. We found that there is a network architecture (EdgeNet) where, when a specific curvature definition (e.g., Forman-Ricci Curvature or others) is specified, we can always find appropriate parameters for this EdgeNet, such that it takes the neighborhood information of an edge as input and outputs the Forman-Ricci Curvature value. As shown in Appendix C.1, EdgeNet is actually composed of several MLPs, and its ability to approximate curvature comes from the universal approximation theorem of MLPs.\n\n## Weakness 3 (On the difference between GNRF and GRAND)\nThe neighbor aggregation weights in GRAND are actually attention coefficients, meaning they satisfy two constraints: normalization and non-negativity. However, for GNRF, the aggregation weights do not have these constraints; they can be negative, and negative weights yield significant benefits in heterophilious graphs. Another point is that attention coefficients come with an implicit bias: node pairs with similar features often receive higher weights. While this bias is often shown to be beneficial, we found that removing it can lead to unexpected results. As shown in Figure 5 of the main text, we observed that GNRF tends to reject pairs of nodes that are very similar, which in turn leads to smoother boundaries, exhibiting behavior that is quite different from that of GRAND. Finally, we present results from an ablation study. Here, \\(d\\) denotes the damping factor, and the difference between the models GRAND+d and GNRF lies only in the aggregation weight calculation. We observed that GNRF significantly outperforms GRAND+d, particularly on heterophilious graphs (Roman-Empire and Tolokers).\n\n| | Roman-Empire | Tolokers | Cora Full |\n|---|---|---|---|\n| GRAND | 60.12 | 79.01 | 67.66 |\n| GRAND+d | 58.57 | 78.78 | 67.31 |\n| GNRF | 86.26 (+26.14) | 83.96 (+4.95) | 72.12 (+4.46) |"}, {"comment_id": "15JU6nONWN", "replyto": "vw1bWn0nIW", "author_type": "authors", "reviewer": null, "comment": "We are so grateful that the reviewer recognized our efforts! We will continue to improve our paper in the future!"}, {"comment_id": "NAOVAsqaZD", "replyto": "xlnc7VRKw0", "author_type": "reviewer", "reviewer": "Reviewer_1v62", "comment": "I would like to thank the authors for addressing my concerns and especially for providing a large number of additional experimental results. I find the results convincing and will therefore adjust my score accordingly."}, {"comment_id": "xlnc7VRKw0", "replyto": "vw1bWn0nIW", "author_type": "authors", "reviewer": null, "comment": "Dear reviewers, all planned changes have now been included in our latest version of the paper. In particular, we conduct richer experiments to enable readers to more comprehensively evaluate the performance of GNRF. They can be found in Tables 1-3 in the main text of the paper and Tables 5-7 in Appendix C.2. At the same time, I also excerpt and record it for you:\n\n(Table 5) We first performed experiments on three commonly used graph classification data sets. Our experimental results were based on 80%/10%/10% division (after our research, this is a commonly used ratio), and reported 10 results. We found that the effect based on continuous depth GNN is generally better than the classic model. We speculate that this may be because the graph-level task requires fusing information from all node information in the entire graph, which is a challenge for discrete GNNs, but is easier for continuous GNNs. This is because in order to achieve sufficiently high accuracy, the ODE solver often needs to perform many time step within [0, T], and it is usually much more than the common layer setting of discrete GNN (for example, within 5). GNRF performs better than the current advanced continuous depth GNN, namely ACMP.\n\n| Pooling | NCI1 | NCI1 | DD |DD | PROTEINS | PROTEINS |\n|---------|--------------|--------------|--------------|--------------|--------------|-------------|\n| | Sum | Mean | Sum | Mean |Sum | Mean |\n| GCN+res | 75.28 ± 1.33 | 76.26 ± 1.05 | 74.81 ± 0.96 | 76.12 ± 0.57 | 75.42 ± 1.30 | 75.82 ± 0.35 |\n| GAT+res | 73.25 ± 2.11 | 73.65 ± 1.35 | 76.68 ± 0.88 | 77.26 ± 2.01 | 74.44 ± 1.35 | 74.51 ± 0.96 |\n| GRAND | 76.54 ± 1.51 | 77.82 ± 0.68 | 75.56 ± 0.55 | 78.51 ± 0.87 | 77.12 ± 0.53 | 78.25 ± 1.14 |\n| ACMP | 74.42 ± 0.60 | 79.09 ± 0.77 | 75.82 ± 1.83 | 78.44 ± 0.53 | 78.88 ± 0.33 | 78.34 ± 0.66 |\n| GNRF | 79.59 ± 0.69 | 81.67 ± 0.54 | 78.52 ± 0.64 | 79.08 ± 0.88 | 78.59 ± 2.12 | 80.12 ± 0.54 |\n\n(Table 6) According to your request, we have supplemented two datasets from the Long Range Graph Benchmark (LRGB) in Table 7. Our dataset partitioning and statistical methods are fully consistent with the official LRGB. Additionally, we directly cite data from the LRGB LeaderBoard for comparison to ensure the fairness of the results. Following the convention of LRGB, we also tested the gains provided by GNRF after using two common positional/structural encodings: LapPE and RWSE. Based on our results, we find that GNRF shows significant improvements over classic message-passing-based GNNs. Without additional encoding, GNRF improves performance by at least 3% over GCN on both Peptides-func and Peptides-struct. When additional encodings are used, GNRF's performance can rival that of SAN (a Transformer-based architecture). However, we acknowledge that GNRF still struggles to match the state-of-the-art Graph Transformer methods on the LRGB dataset. Nevertheless, we believe this is forgivable because GNRF remains a fully message-passing architecture, where first-order neighbors are the only direct source of information for feature updates. Compared to Graph Transformer methods, GNRF has much lower computational complexity and is more suitable for large-scale single-graph scenarios.\n\n\n| | GCN | GatedGCN+RWSE | SAN+LapPE | SAN+RWSE | GNRF | GNRF+LapPE | GNRF+RWSE |\n|-----------|--------------------|---------------------|---------------------|---------------------|---------------------|---------------------|---------------------|\n| Peptides-func AP(↑) | 0.5930±0.0023 | 0.6069±0.0035 | 0.6384±0.0121 | 0.6439±0.0075 | 0.6233±0.0080 | 0.6455±0.0062 | 0.6480±0.0056 |\n| Peptides-struct MAE(↓) | 0.3496±0.0013 | 0.3357±0.0006 | 0.2683±0.0043 | 0.2545±0.0012 | 0.3166±0.0053 | 0.2675±0.0044 | 0.2811±0.0031 |\n\nWe hope that the additional experiments will address your concerns. We also look forward to your feedback to help us further improve the paper."}, {"comment_id": "07JXuy9jdc", "replyto": "kwM2qZC4PU", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer, we have now completed all the planned revisions. You can find details of these changes in the Official Comment and the latest version of the paper. We are eagerly awaiting your positive feedback."}, {"comment_id": "9aN86EmQpM", "replyto": "6P1ZSqJETu", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer, we have completed all the planned revisions as scheduled. Specifically, we have supplemented the content with graph classification tasks (see Table 6 in Appendix C.2), and conducted classification and regression tasks on two datasets each containing over one million nodes (also documented in Table 7 of Appendix C.2). We are eagerly looking forward to your positive feedback."}, {"comment_id": "Pulh6XlyOn", "replyto": "UFjaTsDjtf", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer, We have now completed all the planned revisions. Specifically, we have added more references related to discrete Ricci flow in the main text; and introduced more diverse and larger datasets, which are detailed in Appendix C.2, Tables 6 and 7. We are eagerly awaiting your positive response."}, {"comment_id": "WFgVyKIvMK", "replyto": "7b2JrzdLhA", "author_type": "authors", "reviewer": null, "comment": "We have conducted a second comprehensive revision of the paper, incorporating all planned changes. Specifically, these include:\n\n1. **Supplementing Missing References**: Reviewer tn6E raised concerns about the unclear origin of discrete Ricci flow. We have now provided more relevant references.\n\n2. **Adding Pseudocode**: Reviewer tPnw suggested including pseudocode. We have addressed this by adding pseudocode in Appendix C.1.\n\n3. **Providing Additional Experiments**: Reviewer 1v62 believed that our experimental results were limited. In response, we have included additional experiments in Section C.2. Table 6 presents the performance of GNRF on three commonly used graph classification datasets, while Table 7 evaluates GNRF on long-range graph benchmarks. These benchmarks involve datasets with over one million nodes, and the results demonstrate the strong performance of our method.\n\n4. **Future Direction**: Reviewer tPnw thought it would be beneficial to discuss the application of our framework to a wider range of GNNs. We now show in Appendix C.3 an intuition of how to apply Attri-DRF on Graph Transformer. We also show why this intuition is reasonable.\n\n**Additional Changes**: \n5. We have further developed Theorem 2 by providing additional proof that the Dirichlet energy lower bound obtained in Theorem 2 is strictly greater than zero. This ensures that our conclusions are non-trivial.\n\n6. We highlighted all theorems to make the paper look better."}, {"comment_id": "WLiN09V9kx", "replyto": "UFjaTsDjtf", "author_type": "authors", "reviewer": null, "comment": "## Weakness 3\nThe results for OGBN-Arxiv and OGBN-Year are reported in **Table 3 (resource consumption experiments)**. We focus on the scalability of RNRF on these larger datasets while also reporting performance. Below is an excerpt from the table (where $d$ denotes the depth of the model):\n\n|Dataset|GCN(d=4)|GCN(d=16)|GCN(d=64)|GAT(d=4)|GAT(d=16)|GAT(d=64)|ACMP(d=4)|ACMP(d=16)|ACMP(d=64)|GNRF(d=4)|GNRF(d=16)|GNRF(d=64)|\n|---|---|---|---|---|---|---|---|---|---|---|---|---|\n|OGBN-Arxiv|67.85|56.81|33.09|66.71|OOM|OOM|67.16|65.72|51.72|69.25|65.14|55.23|\n|OGBN-Year|46.22|42.94|38.01|44.51|OOM|OOM|47.55|43.53|42.31|48.55|44.13|40.15|\n\n## Question 2\nWe referenced Ollivier's paper [1] in our work, where Ollivier defined the Coarse Ricci curvature (now known as Ollivier-Ricci Curvature, abbreviated as ORC). Subsequently, [1] introduced a continuous-time version of ORC defined as follows:\n$$\n\\kappa(x,y) = -\\frac{d}{dt}\\frac{W_1(m_x^t,m_y^t)}{d(x,y)}\n$$\n\nHere, $m_x^t$ epresents the probability distribution of a random walk at point $x$ at time $t$. Ollivier also discussed how this formula could be extended to graphs by treating $m_x^t$ as the probability distribution of a random walk starting at node $x$ and transitioning to its first-order neighbors (with probabilities determined by edge weights). \nPlease note that this actually contains the idea of ​​Ricci flow: edge weights change over time, making $x$ time-dependent and, consequently, curvature time-dependent.\n\nA few months later, Ollivier significantly expanded upon this in his paper (2010, [2]). In Section 2.3.5, “Problem N,” Ollivier formally proposed Discrete Ricci flow, using the following equation:\n$$\n\\frac{d}{dt}d(x,y)=-\\kappa(x,y)d(x,y)\n$$\nThis equation is nearly identical to the one we used. In [2], Ollivier explained that this equation was inspired by results in continuous Riemannian geometry. The first application of this formula in graph learning was in [3]. Our innovation lies in applying Discrete Ricci flow to other curvature definitions.\n\nIt is worth noting that Discrete Ricci flow had already been widely used in the field of computer graphics before [3]. For example, in [4], the following equation was used:\n$$\n\\frac{dg_{ij}(t)}{dt}=-2K(t)g_{ij}(t)\n$$\nHere, $g_{ij}(t)$ is a distance metric on the manifold, and $K(t)$ is the corresponding Gaussian curvature. Although [4] did not mention Ollivier’s work, the formulas are formally identical. We will add citations to the above works, especially [2], in our paper to avoid confusion for readers.\n\n## Question 4\nWe respectfully offer a different perspective on this matter. We believe that GNRF is, in fact, a fully message-passing framework. Specifically, the equation we used in the paper is as follows:\n\n$$\n\\frac{\\partial\\boldsymbol{h}_i(t)}{\\partial t} = \\sum -{\\rm EdgeNet}(t) [\\boldsymbol{h}_j(t) - {\\cos\\big(\\boldsymbol{h}_j(t), \\boldsymbol{h}_i(t)\\big)}\\boldsymbol{h}_i(t)]\n$$\n\nUsing the simplest ODE solver (i.e., the forward Euler method), we derive the following explicit update process:\n\n$$\n{\\boldsymbol{h}_i(t+1)} = \\boldsymbol{h}_i(t) - \\eta\\sum -{\\rm EdgeNet}(t) [\\boldsymbol{h}_j(t) - {\\cos\\big(\\boldsymbol{h}_j(t), \\boldsymbol{h}_i(t)\\big)}\\boldsymbol{h}_i(t)]\n$$\n\nThis formula fully aligns with the three-stage message-passing paradigm—Message, Aggregation, and Update:\n\nMessage function:\n\n$$\nM_{ij}(t) = \\boldsymbol{h}_j(t) - {\\cos\\big(\\boldsymbol{h}_j(t), \\boldsymbol{h}_i(t)\\big)}\\boldsymbol{h}_i(t)\n$$\n\nAggregate function:\n\n$$\nh^\\prime_i(t) = \\sum -{\\rm EdgeNet}(t)M_{ij}(t)\n$$\n\nUpdate function:\n$$\nh_i(t+1) = h_i(t) - \\eta h^\\prime_i(t)\n$$\n\nMore advanced ODE solvers only modify the Update function. Therefore, GNRF still entirely fits within the message-passing framework.\n\nWe sincerely hope this addresses your concerns.\n\n[1] Ricci curvature of markov chains on metric spaces.\n\n[2] A survey of Ricci curvature for metric spaces and Markov chains\n\n[3] Network Alignment by Discrete Ollivier-Ricci Flow\n\n[4] Discrete Surface Ricci Flow"}, {"comment_id": "l295dfVHea", "replyto": "GlLeHNIYJk", "author_type": "reviewer", "reviewer": "Reviewer_tn6E", "comment": "Thank you for responding to my review.\n\nRegarding weakness 3, where's the performance comparison on the OBGN datasets?\n\nRegarding question 2, what is the dynamic process similar to curvature flow that you mentioned? I can't find this in Ollivier's paper either. I don't yet see how this 'curvature' equation is justified.\n\nRegarding question 4: Let me clarify my question. Over-smoothing and over-squashing are caused by the discrete/continuous message-passing design in GNNs. The curvature is just a proxy to measure how connected/bottlenecked the graph is, which is only relevant since message passing utilizes the graph topology to propagate information. The curvature itself doesn't have a direct relevant to the learning task. Considering that GNRF does not strictly adhere to the message passing design, is it appropriate to mention these problems in this work?"}, {"comment_id": "ADTmD6sHCU", "replyto": "6P1ZSqJETu", "author_type": "authors", "reviewer": null, "comment": "## Question 1\nFor this question, you can refer to our discussion on the differences between GNRF and GRAND in the paper. GRAND is a classic continuous-depth GNN model that directly uses attention coefficients as aggregation weights. We summarize the main difference as follows: a significant distinction lies in the sign of the aggregation weights. Attention coefficients are typically positive, whereas GNRF allows for negative weights. Specifically, when an edge has positive curvature, we use negative weights, and vice versa for negative curvature. This leads to completely different behavior between GNRF and GRAND. GRAND (attention) tends to smooth all node pairs, while GNRF only smooths node pairs with negative curvature and repels those with positive curvature. In our ablation study, we analyzed the impact when GNRF and GRAND differ only in the aggregation weights. The results showed that GRAND performed poorly on heterophilious graphs (e.g., Tolokers and Roman-Empire), which supports our view on the importance of negative weights/attention coefficients. You may further ask what would happen if we introduced the ability to use negative weights in the attention mechanism—this is precisely what another model, ACMP, does. We also provided a detailed comparison in the paper, and the experimental results show that ACMP still performs significantly worse than GNRF.\n\n## Question 2\nThere may be some ambiguity in our description of EdgeNet, which led to a misunderstanding, and we apologize for that. For datasets commonly used in graph deep learning (such as the node classification datasets we used), attributes are often only present on nodes, not edges. At each time step, we first generate an attribute for each edge (specifically, on edge i~j, we use h_i(t) || h_j(t) for concatenation). Then, in that time step, we obtain the aggregation weights through several layers of EdgeNet. The edge attributes are cleared after that time step, and the process is repeated in the next time step. Therefore, Equation 13 (which has been moved to the appendix in the updated paper) describes how to **obtain aggregation weights using a multi-layer network within a single time step**.\n\n## Question 3\nIn Appendix C.1, we added an explicit update formula for GNRF under forward differentiation method. Please note that this is only for illustration. In actual practice, the update formula for features is more complex due to our use of a more advanced ODE solver. In the field of deep learning, we focus more on describing a novel partial differential equation without discussing the internal workings of the ODE solver in detail, which is consistent with almost all related works on GNNs based on differential equations, such as [4] and [5]. Nevertheless, we highly value your feedback and will provide pseudocode in the next version of the paper.\n\n## Question 4\nIn fact, calling GNRF a single \"layer\" is inaccurate. In the paper, we refer to it as \"continuous depth,\" and in the code, we call it a \"block.\" We avoid using the term \"layer.\" The GNRF implementation in the code has only one ODE **block**. However, it is important to note that an ODE block can simulate arbitrary depth (or, less rigorously, any number of layers) of a GNN by appropriately setting the evolution end time T. For example, if the first ODE block evolves the system from T = t_0 to T = t_1, and the second ODE block continues to evolve the system from T = t_1 to T = t_2, this is essentially equivalent to using a single ODE block to evolve the system from T = t_0 to T = t_2. Within the same ODE block, the feature update fomula is executed multiple times (the specific number and manner of execution are determined by the ODE solver and are not explicitly shown in the code). Therefore, the answer is yes—a single ODE block can effectively approximate any multi-layer discrete-depth GNN; we only need to increase the termination time T.\n\nWe hope this response effectively addresses your concerns, and we are more than willing to provide further details regarding any other questions you may have and to update the paper accordingly. Thank you once again for your valuable feedback and suggestions!\n\n[5] GRAND: Graph Neural Diffusion\n\n[6] ACMP: Allen-Cahn Message Passing for Graph Neural Networks with Particle Phase Transition"}, {"comment_id": "IzHkHbVx2e", "replyto": "6P1ZSqJETu", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nThank you for your professional comments and your recognition of our work. We have also noted your concerns, and below is our detailed response:\n\n## Weakness 1 \nWe have comprehensively updated the paper to include more in-depth discussions. Specifically, we have added discussions on curvature-based edge sampling [1] and weighted aggregation [2] to provide a more holistic understanding of curvature graph learning. Additionally, other branches of Riemannian graph learning, such as hyperbolic graph learning [3], are discussed in the related work section of the appendix. In addition, we have added a new appendix section C.3 to discuss the potential application of our proposed Attri-DRF in other GNN architectures, especially Graph Transformer. You can see our detailed response to this in reply (4).\n\n## Weakness 2 \nIn the latest version, we have added 5 more datasets to the main experiments, bringing the total to 12, with the largest containing nearly 50,000 nodes. Additionally, for the resource overhead experiments, we have evaluated two larger datasets, OGBN-Arxiv and OGBN-Year [4], both of which have over 100,000 nodes. The results indicate that our method still shows stable improvements on these larger datasets.\n\n## Weakness 3 \nWe have now supplemented the relevant experiments. In the resource overhead evaluation, we additionally report the number of trainable parameters, peak memory usage, and average training time per epoch. The results demonstrate that GNRF achieves a good balance across multiple metrics, particularly in deep model settings, where it incurs less overhead compared to traditional models like GCN. Furthermore, we have added a discussion on computational complexity in the main text. Based on widely recognized computation methods, the results show that GNRF has the same complexity as GCN.\n\n[1] Curvdrop: A ricci curvature based approach to prevent graph neural networks from over-smoothing and over-squashing.\n\n [2] Curvature Graph Neural Network. \n\n[3] Hyperbolic variational graph neural network for modeling dynamic graphs \n\n[4] Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods"}, {"comment_id": "KuLOfPEqvo", "replyto": "kwM2qZC4PU", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nThank you very much for your professional review comments. We have noticed that your main concerns focus on the model design, especially regarding EdgeNet. Below is our detailed response:\n\n## Weakness 1\nIndeed, EdgeNet does not use a specific curvature definition. This is because, although there are multiple ways to define curvature, as far as we know, there is no theoretical guidance on how to choose the appropriate one in practice. Furthermore, based on experimental data from existing literature ([1], [2]) as well as additional experiments we conducted, we observed that the impact of using different curvature definitions on model performance is quite significant, yet no generalizable guidelines can be formed. Therefore, we believe that using an adaptive curvature definition may be a better choice. Additionally, as shown in Section 5.2 of the revised paper, even though EdgeNet does not rely on a specific curvature, it still exhibits behavior consistent with Ricci flow. This is because our theorem itself is independent of any particular curvature, providing a theoretical foundation for the introduction of EdgeNet.\n\n## Weakness 2\nYour understanding is mostly correct. We implemented $\\phi_1$ and $\\phi_2$ as a two-layer MLP. In Theorem 5, we state that it is always possible to find suitable parameters for these MLPs such that the network output can approximate any curvature (i.e., the network structure is fixed). You can refer to Appendix B.5 in the revised paper for a rigorous proof of Theorem 5, as well as Appendix C.1 for the implementation details of EdgeNet.\n\n## Weakness 3\nIn the original version, we provided an intuitive and experimental discussion on the difference between GNRF and GRAND, and we have now further deepened this discussion. One major distinction lies in the sign of the aggregation weights. GRAND is derived from a heat diffusion model, resulting in all positive aggregation weights (i.e., the attention coefficients are always positive). In contrast, GNRF allows negative weights—specifically, when an edge has positive curvature, we use negative weights, and vice versa for negative curvature. This leads to fundamentally different behavior between GNRF and GRAND. While GRAND tends to smooth all node pairs, GNRF only smooths negative curvature node pairs while repelling positive curvature ones. In experiments, we demonstrated through ablation studies the effects when GNRF and GRAND differ only in aggregation weights. We found that GRAND performs poorly on heterophilious graphs (such as Tolokers and Roman-Empire), supporting our view on the importance of negative weights/attention coefficients.\n\n## Question 1\nThe original statement was indeed not precise, and we have corrected it. What we meant was that the aggregation weights of GNRF ($\\kappa^\\prime$) have the same sign as the curvature ($\\kappa$), and EdgeNet’s role is to approximate $\\kappa^\\prime$. As mentioned in our response to Weakness 1, when using EdgeNet, the model actually utilizes a dataset-specific personalized curvature rather than a pre-defined curvature. The experiments in Section 5.2 of the revised paper confirm that this approach is feasible—using EdgeNet still adheres to the characteristics of Ricci flow, and our theoretical results also apply to EdgeNet.\n\n## Question 2\nYes, you are correct. This was indeed a typographical error, which we have fixed in the latest version of the paper. We have also updated the statements of all theorems with more detailed descriptions to ensure their rigor.\n\n## Question 3\nDear reviewer, please refer to our responses to Weakness 1 and Question 1.\n\n## Question 4\nWe have added this experiment in the latest version of the paper. Specifically, we replaced the curvature in GNRF with two real curvatures: Forman-Ricci Curvature and approximate resistance curvature, resulting in two new models: GNRF_FRC and GNRF_ARC. We validated these models on 12 datasets, and although the two variants have their own strengths and weaknesses, they generally perform worse than GNRF with EdgeNet, which further supports our belief that adaptive curvature is more advantageous. \n\nWe hope that our response effectively addresses your concerns, and we are more than willing to provide further details on any other questions you may have and update the paper accordingly. Once again, thank you for your valuable feedback and suggestions!\n\n\n[1] Curvature filtrations for graph generative model evaluation \n[2] Curvature constrained mpnns: Improving message passing with local structural properties"}, {"comment_id": "GlLeHNIYJk", "replyto": "UFjaTsDjtf", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nThank you for your professional and detailed review of our paper. We are very pleased to see your recognition of our work, and we have carefully noted your valuable comments. Below, we provide a detailed response to your feedback.\n\n## Weakness 1\n\nIn the latest version of the paper, we have added a more detailed formal version for each theorem, including all necessary assumptions in the statement. These improvements can be found in the appendix section. The main text retains an informal version that is more intuitive and easy to understand, to allow readers to quickly grasp the key conclusions.\n\n## Weakness 2\n\nWe have added detailed descriptions of the models used in the experiments in Section C.1 of the appendix, including a description of EdgeNet. Additionally, we introduced two GNRF variants in Section 5.1 of the main text, which do not use EdgeNet but instead rely on explicit curvature calculations. Our comparisons reveal that specific curvature definitions may not always perform well across different datasets (as reflected in experiments from other papers such as [1] and [2]). Therefore, using EdgeNet for adaptive curvature appears to be a better choice.\n\n## Weakness 3\n\nWe added five new datasets to the main experiments, with the largest containing about 50,000 nodes. Furthermore, the resource consumption experiments include two even larger datasets (with over 100,000 nodes). The experimental results demonstrate that our model still achieves consistent improvements on these datasets.\n\n## Question 1\n\nWe have provided a more detailed explanation of this point in the main text, covering two aspects: 1. It serves as an extension of the theoretical results. Lemma 1 describes the state of Attri-DRF \"when reaching equilibrium,\" while Theorem 3 further explains \"whether equilibrium can be reached.\" 2. On a practical level, it ensures consistency in the evolution process, meaning that within a finite time, all edges evolve sufficiently, ensuring synchronized evolution of the overall graph structure without worrying about parts of the graph being insufficiently developed.\n\n## Question 2\n\nOllivier's paper does not directly present this equation, primarily because Ollivier's work is early research, and its notation differs somewhat from today's conventions. However, Ollivier indeed first explored a dynamic process very similar to curvature flow in his paper. Other related works also adopt a similar perspective to ours, recognizing that Ollivier's paper was the first to introduce Ricci flow on graphs (e.g., see page 5 of [3]).\n\n## Question 3\n\nWe highly value the consistency of GNRF with the theoretical results. In Sections 5.2 and 5.3 of the updated paper, we conducted detailed experiments to investigate this. The results show that GNRF does align well with the theory, including properties like curvature approaching zero (Lemma 1), uniform decay (Theorem 3), and bounded energy (Theorem 2). We appreciate your feedback and will clarify this point further in the paper.\n\n## Question 4\n\nIndeed, the issues of over-smoothing and over-squashing were first raised in the context of discrete-depth GNNs. However, as we have added in Section 3 of the updated paper, classic continuous-depth GNNs (such as GRAND) fully adhere to the design principle of heat diffusion, one of whose fundamental characteristics is reaching thermal equilibrium, i.e., nodes becoming completely uniform. This is consistent with the concept of over-smoothing, and we have validated this in the experiments presented in Section 5.3 of the updated paper. Regarding the over-squashing problem, to the best of our knowledge, there has not yet been dedicated research on this challenge in the context of continuous-depth GNNs. However, as stated in the paper, we have found that many current methods aimed at solving the over-squashing problem share a striking consistency: reducing the influence of edges with extreme positive/negative curvature. We also note that these methods typically treat curvature as a static, topology-dependent attribute. While our approach is conceptually similar to these methods, the way we utilize curvature is entirely different, and we believe this offers a new perspective for formally addressing this challenge in the future.\n\n## Question 5\n\nBased on the experiments added in Section 5.1 of the main text, GNRF has proven effective even on larger datasets. Moreover, we observed that GNNs based on differential equations often demonstrate stronger advantages when dealing with very large-scale model settings.\n\nWe hope that our responses effectively address your questions, and we are very willing to provide more detailed replies to any further questions you may have and promptly update the paper. Thank you again for your valuable comments!\n\n[1] Curvature filtrations for graph generative model evaluation\n[2] Curvature constrained mpnns: Improving message passing with local structural properties\n[3] Graph Pooling via Ricci Flow"}, {"comment_id": "MPMc9M085R", "replyto": "7b2JrzdLhA", "author_type": "authors", "reviewer": null, "comment": "Sorry to keep the reviewers waiting so long! We highly valued your professional comments and revised our paper comprehensively, which took a couple of days because the workload was a bit much. Specifically, our revisions are as follows:\n\n# More rigorous statement of the theorem\n1. We add a new formal version of all theorems in the appendix that is more detailed than before, and retain the more intuitive and accessible informal version in the main text (tn6E, QrP8)\n\n# More in-depth discussion\n1. We add a detailed design description of EdgeNet in the appendix section (tn6E, tPnw) .\n2. We add a description of the computational complexity of the algorithm in the main text (tPnw)\n3. We further explain the advantages of data-independent decay rates in the main text (tn6E)\n4. We make the differences with GRAND more explicit in Section 4.1 (QrP8)\n5. We analyze more work related to graph curvature, Riemannian graph learning (tPnw)\n6. We further discuss the advantages of using EdgeNet to approximate curvature in Section 4.2 (tn6E, QrP8)\n7. We re-organize the formulation of the ablation study to illustrate the differences with existing work such as GRAND (QrP8 )\n\n# Richer experiments\n1. We report results on larger datasets (OGBN-Arxiv and OGBN-Year) (tn6E, tPnw, 1v62), along with the number of trainable parameters, the peak memory footprint, and the average single-round training time (tPnw)\n2. The main experiment includes a larger dataset and a more strong baseline, while using more rational evaluation metrics (e.g., ROC-AUC in Tolokers' evaluation) (tn6E, 1v62)\n3. We perform ablation experiments for the case where real curvature is used instead of EdgeNet (tn6E, QrP8)\n\nWe are eager to discuss the current updated version with reviewers as soon as possible, and we are willing to continue to make rapid adjustments to the paper in response to further feedback."}, {"comment_id": "DRUtZVhQNW", "replyto": "vw1bWn0nIW", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nThank you for your detailed review and for recognizing the innovativeness of our approach. We also acknowledge your concerns regarding the limitations of our experimental results, and we would like to address them as follows:\n\n## Regarding Limited Experimental Results\nIn the original manuscript, we aimed to provide a more diverse set of experiments to offer a comprehensive understanding of our method. However, we acknowledge that the experimental results on the main task (node classification) were somewhat limited. To address this, we have added the remaining three datasets from HeterophilousGraphs: **Minesweeper**, **Questions**, **Amazon-ratings**, and two datasets from CitationFull: **DBLP** and **Cora_ML** as supplementary results, which can be found below. Additionally, in response to other reviewers' requests, we will report results on the **OGBN-Arxiv** and **OGBN-Year** datasets (both with over 100k nodes) in the coming days. Lastly, regarding the LRGB dataset, experiments are ongoing, and we will report results on **Peptides-func** and **Peptides-struct** shortly. We appreciate your patience.\n\n## Regarding Baseline Comparisons\nWe revisited the paper and code from [1] and identified two key factors:\n\n1. We reported results on Tolokers based on **accuracy**, whereas the original paper used **ROC-AUC**. Cross-metric comparisons are inappropriate, and we acknowledge that ROC-AUC is a better metric for binary classification. We will re-evaluate the results on Tolokers and update the paper accordingly. You will be notified once the updated results are available.\n\n2. We noticed that in [1], they included **residual connections and an additional linear layer** for GCN (referred to as GCN+Res), while we reported results for the **vanilla GCN**. We found that residual connections had a significant effect on the HeterophilousGraphs benchmark but did not improve performance on DBLP and Cora_ML. Given that residual connections are designed for layered neural networks, we have not incorporated this module into our continuous deep neural network based on differential equations. Therefore, we feel that considering GCN+Res as a baseline for continuous deep GNNs may not be entirely fair. Nonetheless, our model (GNRF) still performs competitively.\n\n## Regarding the Choice of Tolokers and Roman Empire\nThis choice was random. We have now included the remaining datasets from our benchmark, so this should no longer be a concern.\n\n| | Minesweeper | Questions | A.-ratings | DBLP | Cora_ML |\n|---------------|----------------------|--------------------|-------------------|-------------------|------------------|\n| **GCN** | 74.79 ± 1.78 | 50.21 ± 2.24 | 37.99 ± 0.61 | 83.93 ± 0.34 | 87.07 ± 1.21 |\n| **GCN+Res** | 90.13 ± 0.70 | 75.45 ± 2.31 | 48.17 ± 0.55 | 82.64 ± 0.51 | 85.62 ± 0.72 |\n| **GRAND** | 80.56 ± 3.12 | 54.90 ± 2.12 | 37.53 ± 0.36 | 84.60 ± 0.99 | 88.49 ± 0.81 |\n| **GNRF** | 95.03 ± 0.20 | 73.86 ± 1.18 | 47.89 ± 1.08 | 85.73 ± 0.76 | 89.18 ± 0.19 |\n\nThe above improvements will be updated in the paper soon. Once again, thank you for your professional review.\n\n[1] A Critical Look at the Evaluation of GNNs under Heterophily” (2023)"}], "meta_review": {"metareview": "In the paper, the authors introduce the dynamical system Attribute Discrete Ricci Flow (Attri-DRF) and incorporate it into a novel framework called Graph Neural Ricci Flow (GNRF), a continuous graph neural network that is curvature-aware. \n\nAfter the rebuttal, most of the concerns were addressed. There are several strengths of the current paper: (1) The proposed framework is novel and interesting. Theoretically, the results are sound and solid (e.g., guarantees on the curvature decay rate and the stable curvature limit of Attri-DRF in Section 3).\n\nWhile there are still some concerns about limited experiments and evaluations, in my opinion the strengths outweigh the weaknesses. As a consequence, I recommend accepting the paper. The authors are encouraged to incorporate the suggestions and feedback of the reviewers into the revision of their manuscript.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "ajxAJ8GUX4", "forum_url": "https://openreview.net/forum?id=ajxAJ8GUX4", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "Learning Geometric Reasoning Networks For Robot Task And Motion Planning", "authors": ["Smail Ait Bouhsain", "Rachid Alami", "Thierry Simeon"], "abstract": "Task and Motion Planning (TAMP) is a computationally challenging robotics problem due to the tight coupling of discrete symbolic planning and continuous geometric planning of robot motions. In particular, planning manipulation tasks in complex 3D environments leads to a large number of costly geometric planner queries to verify the feasibility of considered actions and plan their motions. To address this issue, we propose Geometric Reasoning Networks (GRN), a graph neural network (GNN)-based model for action and grasp feasibility prediction, designed to significantly reduce the dependency on the geometric planner. Moreover, we introduce two key interpretability mechanisms: inverse kinematics (IK) feasibility prediction and grasp obstruction (GO) estimation. These modules not only improve feasibility predictions accuracy, but also explain why certain actions or grasps are infeasible, thus allowing a more efficient search for a feasible solution. Through extensive experimental results, we show that our model outperforms state-of-the-art methods, while maintaining generalizability to more complex environments, diverse object shapes, multi-robot settings, and real-world robots.", "keywords": ["Graph Neural Networks", "Deep Learning for Robotics", "Task and Motion Planning", "Robot Manipulation Planning"], "primary_area": "applications to robotics, autonomy, planning", "pdf_url": "https://openreview.net/pdf?id=ajxAJ8GUX4", "decision": "Accept (Poster)", "num_reviews": 4, "num_discussions": 27, "reviews": [{"review_id": "b38m95GMVt", "reviewer": "Reviewer_LeQs", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 2, "summary": "This paper presents Geometric Reasoning Networks (GRN), a graph neural network-based approach to enhance efficiency in Task and Motion Planning (TAMP) for robotic manipulation. An incremental GNN-based model is introduced with two mechanisms: inverse kinematics (IK) feasibility prediction, and grasp obstruction (GO) estimation using full-state knowledge. Experimental validation demonstrates GRN's applicability to various environments and configurations.", "strengths": "- The study tackles a significant challenge in TAMP arising from the bottleneck in geometric planning within complex 3D environments and proposes a novel, though incremental, architecture to predict action and grasp feasibility to reduce the need for costly geometric planning queries. \n\n- The model is more generalizable in complex environments compared to prior works (Table 4).\n\n- The proposed model is evaluated and compared against prior works and an ablation study is conducted.\n\n- The authors conducted real-robot experiments, and shared their code.", "weaknesses": "Below is the list of the issues I observed, and my suggestions for this work:\n\n1. **Strong Assumptions on Object Knowledge and Static Scene Conditions**: The model assumes complete knowledge of object shape, dimensions, and pose, as well as static conditions where objects remain stationary unless moved by the robot (lines 153–154). This is a considerable simplification for real-world scenarios, where partial observability and dynamic conditions are typical. By contrast, some baselines used for comparison assume partial observability and predict feasibility based on images rather than full state knowledge, making these comparisons less directly fair. Addressing this by either relaxing these assumptions or comparing only with baselines under similar assumptions would improve applicability and fairness.\n\n2. **Limited Integration and Short Planning Horizons**: While some baselines are fully integrated within TAMP solvers, GRN’s integration is relatively basic and tested only on shorter-horizon problems. The scenarios addressed in the experiments are limited in complexity, lacking tasks that demand multi-step, long-horizon planning (e.g., removal of obstructing objects or inter-robot handovers). Expanding the evaluation to more complex, long-horizon planning tasks would better demonstrate GRN's practical utility in handling realistic TAMP challenges.\n\n3. **Improvement for TAMP**: Since the paper do not include a full-fledge TAMP solver, it remains unclear how it improves traditional/existing solvers. It'd have been really useful to see how many geometrical computations were skipped (so computational time was gained) when GRN was used vs. not used. Inference time comparison to some of the baselines is not fair, as they use image-based input while this work relies on a significantly smaller full-state information of the environment. \n\n4. **Incremental Advancement and Limited Impact of IK Infeasibility Module**: While GRN introduces small adjustments on Edge-Featured Graph Attention Network (EGAT), these additions are incremental and rely heavily on prior EGAT methods. Furthermore, as shown in Table 3, the IK infeasibility module does not significantly impact overall performance. Adding a deeper analysis or improvement in these areas could strengthen the contribution. It seems like GO by itself is already a good estimator of feasibility. This could also be due to problem settings where there might not be (enough) cases when there is no grasp obstruction IK was still infeasible (e.g., due to reachability). A more nuanced analysis of scenarios where IK feasibility alone might be crucial (e.g., environments with tighter spaces or more complex robot configurations) that would help clarify when this module provides substantial benefits.\n\n5. **Restricted Grasp Representation**: The model simplifies grasp feasibility by considering only five grasp types, which may limit its effectiveness in real-world settings where grasp variety and adaptability are crucial. Even for grasping box-shaped objects requires 24 (or 20 if it lies flat on a surface) different configurations with a parallel jaw gripper, much higher than 5 as used in this work. Exploring a more flexible grasp representation that could handle diverse object shapes and placements would make the model more versatile for practical robotics applications.\n\n6. **Limited Interpretability Support**: While the authors claim that GRN allows interpretation of feasibility (or infeasibility) of actions, this claim lacks follow-up discussion or specific experimental support. Including experiments or qualitative analysis to illustrate interpretability (e.g., showing how grasp obstruction information directly affects planning decisions) would make this feature more convincing and actionable. \n\n7. **Clarity and Organization of Paper Structure**: The organization could be improved for readability and logical flow. For example, Figure 2 is never referenced in the text, and Figure 1, which appears on page 3, is not mentioned until Section 6.4 on page 10. Figures in the Appendix (e.g., Figures 7 and 8) are not referenced or explained, leaving the reader without context for understanding their relevance.", "questions": "I've listed my main concerns under _weaknesses_ section, and here are some further questions (some overlapping with the points above):\n\n- **Self-loop Edge Representation in 4.1**: What is the rationale for using the self-loop edge to store IK feasibility instead of treating it as a node feature? How does this choice affect the overall model performance or its ability to generalize to unseen environments?\n\n- **Effectiveness and Necessity of the IK Module**: In the ablation study (Table 3), the results show minimal difference between the full model and the one without the IK module. Given this, is the IK module truly necessary? What are the specific cases where IK adds value? Also, could you provide the inference time for the IK module to understand its overhead?\n\n- **Access Task Success Rate in Section 6.4**: The success rate for the Access task in Table 5 is slightly lower than Bouhsain et al. (2024). What are the main reasons for this reduction? Are there particular scenarios where GRN struggles, or is this a result of experimental variance?\n\n- **Inference Time in Table 5**: Could you provide more detailed inference times for Table 5, including how the times break down between GRN predictions and geometric planning? Understanding the computational cost of each component would provide a clearer picture of where time savings are realized.\n\n- **Handling Dynamic Scenes**: The current assumptions restrict the model to static scenes, which is not always practical in real-world tasks. How would this approach be extended to handle dynamic environments?\n \n- **Real-world Evaluation of Long-horizon Tasks**: The paper would benefit from showing how GRN scales to more challenging multi-step tasks, such as multi-robot collaboration or handling non-prehensile manipulation. Can the authors comment on future steps for validating GRN on these scenarios?\n\n- **Simplified Grasp Types**: The use of only five grasp types is limiting. Are there plans to extend the grasp representation, or can the authors justify why five types were sufficient for the tasks tested in this paper?"}, {"review_id": "QhtxzDAmIb", "reviewer": "Reviewer_edpJ", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "The paper presents a graph neural network model for efficient Task and Motion Planning (TAMP). It uses a scene graph to represent 3D environments, incorporating Inverse Kinematics (IK) and Grasp Obstruction (GO) modules. Experiments show its superiority over state-of-the-art methods in accuracy, generalization, and speed. It reduces the reliance on computationally expensive geometric planners while remaining robust across diverse robots and complex scenes.", "strengths": "1. It incorporates explainable modules within the GNN, such as inverse kinematics and grasp obstruction estimations, to predict feasibility. This interpretability enhances decision-making by providing insights into why certain actions are infeasible.\n\n2. Experiments demonstrate that GRN significantly lowers computational costs compared to traditional geometric planners, making it a more efficient solution for real-time TAMP applications.", "weaknesses": "1. The method's reliance on simplified bounding box representations and discrete grasp feasibility may limit its effectiveness in real-world TAMP problems, especially in dynamic, cluttered environments with irregular objects or sensor noise.\n\n2. The method only assesses discrete feasibility without motion planning, restricting its ability to handle complex cases where motion feasibility [1] is crucial. \n\n[1] Scaling Multi-Modal Planning: Using Experience and Informing Discrete Search", "questions": "1. How would the performance of GRN be affected if the Inverse Kinematics (IK) and Grasp Obstruction (GO) modules were replaced with traditional methods, such as standard inverse kinematics solvers or simpler distance-based obstruction checks? Specifically, how would this change impact prediction accuracy, interpretability, and computational efficiency in various environments?\n\n2. What is the maximum level of environmental complexity that GRN can handle effectively? For example, how many fixed nodes or objects can it support while maintaining accuracy and efficiency? Could GRN scale to environments with thousands of nodes, akin to point cloud representations, and still perform well in highly cluttered scenes?"}, {"review_id": "u7aFKw2RAC", "reviewer": "Reviewer_FvER", "rating": 8, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 2, "summary": "The paper introduces Geometric Reasoning Networks (GRN), a GNN model for predicting action and grasp feasibility. Technical Innovations of the paper includes:\n- Novel scene representation using directed graphs\n- Edge-Enhanced Graph Attention Network (EGAT) that leverages both node and edge features\n- Ability to explain why actions are infeasible, enabling more efficient planning", "strengths": "1. Outperforms state-of-the-art methods in both action and grasp feasibility prediction\n2. Lower inference time compared to traditional geometric planning\n3. Generalization to more complex environments and different robot types\n4. Writing is very clear and detailed", "weaknesses": "1. Some details for the experiment setting needed to be discussed further to correctly evaluate the approach. See the questions.", "questions": "1a. How imbalanced is the dataset? Is it nearly composed of 50% feasible cases and 50% infeasible cases? It would be appreciated if some ratio of the dataset could be shared as a table.\n\n1b. Since RRT is a random algorithm, I think it doesn't necessarily mean one task is truly infeasible, if RRT couldn't generate a plan. What is the definition of the feasibility in this paper's scope?\n\n2. Regarding training, why is pre-trained needed before jointly fine-tuning? Was there any training instability if we directly train all the modules?\n\n3. The generalization experiment is interesting. What I'm expecting is that the graph representation should be powerful enough to achieve a significantly better generalization. However, it seems that the MLP method is also good enough to generalize to 20 obstacles. Is there any specific reason why this is the case? Or are we also facing the imbalanced dataset issue here? Sampling a feasible case for 20 obstacles sounds challenging, right?\n\n4. I am not very familiar with TAMP problem. Just as a discussion, in general, when would a learning-based feasibility checker be preferred compared to a learning-based end-to-end planner, that directly generates a plan? And when vice versa?"}, {"review_id": "B05laaJQGc", "reviewer": "Reviewer_WQrW", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 2, "contribution": 3, "summary": "This paper proposes to use GNN-based methods to predict action and grasp feasibility, which can speed up the geometric planner in task and motion planning. Furthermore, the authors propose to predict the inverse kinematics feasibility and grasp obstruction to improve the interpretability. Through quantitative and qualitative experiments, the authors verify the effectiveness of the proposed approach.", "strengths": "1: This paper is well-motivated. The efficiency of task and motion planning is a significant issue and how to improve it is an important research question. \n\n2: Predicting geometric dependencies and feasibility with a graph neural network makes sense to me. \n\n3: This paper contains an extensive experiments section and compares the method to several baselines, which is nice. \n\n4: The paper shows some real-world demos.", "weaknesses": "1: My first concern is the limited discussion and comparison to related works in geometric feasibility reasoning [1, 2, 3] and graph neural networks in task and motion planning [4, 5, 6]. In particular, [2] also uses feasibility-related predicates to improve the efficiency of the task planning. \n\n[1]: K. Lin, C. Agia, T. Migimatsu, M. Pavone, and J. Bohg. Text2motion: From natural language instructions to feasible plans. Autonomous Robots, 47(8):1345–1365, 2023.\n\n[2]: Y. Huang, C. Agia, J. Wu, T. Hermans, and J. Bohg. Points2Plans: From Point Clouds to Long-Horizon Plans with Composable Relational Dynamics, ArXiv, 2024. \n\n[3]: C. Agia, T. Migimatsu, J. Wu, and J. Bohg. Stap: Sequencing task-agnostic policies. In 2023 IEEE International Conference on Robotics and Automation (ICRA), pages 7951–7958. IEEE, 2023.\n\n[4]: Planning for Multi-Object Manipulation with Graph Neural Network Relational Classifiers. In IEEE International Conference on Robotics and Automation (ICRA), 2023.\n\n[5]: H. Shi, H. Xu, Z. Huang, Y. Li, and J. Wu. Robocraft: Learning to see, simulate, and shape elastoplastic objects in 3d with graph networks. The International Journal of Robotics Research. \n\n[6]: H. Chen, Y. Niu, K. Hong, S. Liu, Y. Wang, Y. Li, and K. R. Driggs-Campbell, “Predicting object interactions with behavior primitives: An application in stowing tasks,” in 7th Annual Conference on Robot Learning, 2023.\n\n2: The videos and figures for this paper are hard to follow. For example, in Figure 1, I can only understand green represents feasible and red represents feasible. However, there are other colored objects like blue ones. Furthermore, why robots are sometimes grey and sometimes red? \n\n3: This paper makes several strong assumptions including known object shapes and poses, and all objects remain static except the robot grasps them.", "questions": "1: how would your proposed approach compare to related works in geometric feasibility reasoning [1, 2, 3] and graph neural networks in task and motion planning [4, 5, 6]? \n\n2: Is there any noise when you estimate the object's shape and pose in the real world? If there is noise, would your proposed approach be robust to the noise? \n\n3: Could you train one more for all robots? Training one model for each robot limits the generalization ability of your proposed system. \n\n4: For the Panda-3D-4 dataset, the MLP achieves a 0.558 F1 score but the Panda-3D-20 achieves a 0.609 F1 score, why does MLP perform better in a more complex dataset? This is weird. \n\n5: In Table 5, why does your proposed approach perform worse than the baseline in the “Access” problem? Why does the baseline always achieve a 100% success rate? I guess your task is too easy."}], "discussions": [{"comment_id": "BGLWc2bRK9", "replyto": "2wQnpYo9rY", "author_type": "authors", "reviewer": null, "comment": "Thank you very much for your reassessment. We are glad that our revisions and additional experiments satisfactorily addressed many of your concerns. Your detailed feedback and suggestions have been instrumental in enhancing the clarity and quality of our work."}, {"comment_id": "2wQnpYo9rY", "replyto": "L4lfRcMsWn", "author_type": "reviewer", "reviewer": "Reviewer_LeQs", "comment": "Thanks again for the clarifications and additional analysis. I appreciate the effort and authors' dedication to improve their work. With these updates, I increased my score. However, I’d like to note that without a more tightly integrated solution to address TAMP problems comprehensively, the use of 'TAMP' in the title may not fully align with the work’s primary focus. Perhaps a title emphasizing the feasibility prediction framework or interpretability in manipulation planning could more accurately capture the scope of this contribution."}, {"comment_id": "juzGmn8hlm", "replyto": "ccwzUpjNb5", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nAs the discussion phase is nearing its conclusion, we kindly request your feedback and reassessment based on the revisions and additional experiments we have provided, particularly the evaluation of GRN’s interpretability in Appendix D and the robustness analysis under noisy inputs in Appendix E. We hope these additions address your remaining concerns and demonstrate the broader applicability and interpretability of our approach.\n\nIf you have any further questions, we would be more than happy to address them. Thank you once again for your time and valuable contributions to this review process."}, {"comment_id": "UL62NtPaGq", "replyto": "HvUPuls2Tq", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nAs the discussion phase is nearing its conclusion, we kindly request your feedback and reassessment based on the revisions and additional experiments we have provided, particularly the robustness analysis under noisy inputs in Appendix E. We hope these additions address your remaining concerns and demonstrate the broader applicability and generalization of our approach.\n\nIf you have any further questions, we would be more than happy to address them. Thank you once again for your time and valuable contributions to this review process."}, {"comment_id": "L4lfRcMsWn", "replyto": "cDQsDfGZQH", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nAs the discussion phase is nearing its conclusion, we kindly request your feedback and reassessment based on the revisions and additional experiments we have provided, particularly the evaluation of GRN’s interpretability in Appendix D and the robustness analysis under noisy inputs in Appendix E. We hope these additions address your remaining concerns and demonstrate the broader applicability and interpretability of our approach.\n\nIf you have any further questions, we would be more than happy to address them. Thank you once again for your time and valuable contributions to this review process."}, {"comment_id": "HvUPuls2Tq", "replyto": "QhtSY6KEg1", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank you for your valuable feedback and your reassessment of our work. \n\nTo address your remaining concerns regarding generalization to unstructured, noisy environments, we conducted an additional experiment to evaluate GRN's robustness to input noise. This experiment, detailed in Appendix E of the newly revised manuscript, introduces varying levels of Gaussian noise to objects' poses and bounding box dimensions during inference. Results on the Panda-3D-4 test set demonstrate that GRN retains a good performance under different noise levels. Also, even under the highest noise level, our model still outperforms baselines evaluated on noise-free data, especially on grasp feasibility prediction. \n\nWe highlight that in addition to state-of-art performance on action and grasp feasibility prediction, GRN provides richer information thanks to its two interpretation mechanisms. We also provide an analysis of the impact of noise on IK feasibility and grasp obstruction predictions, which remain reliable for guiding downstream planning tasks.\n\n**Performance of GRN on the Panda-3D-4 test set under different levels of noise**\n| Noise Level \t| Action (F1) | Grasp (F1) \t| IK (F1) \t| GO (MAE) \t|\n|--------------------|-------------|-----------------------|--------------------|---------------------|\n| No Noise \t| 0.939 \t| 0.940 (± 0.009) \t| 0.995 (± 0.001) | 0.028 (± 0.003)\t|\n| 1 cm, 1° \t| 0.912 \t| 0.891 (± 0.020) \t| 0.985 (± 0.001) | 0.039 (± 0.003)\t|\n| 2 cm, 2° \t| 0.864 \t| 0.820 (± 0.029) \t| 0.971 (± 0.002) | 0.057 (± 0.003)\t|\n\nWe hope these additional results and analysis further address your concerns about GRN's robustness to noise. We appreciate your constructive comments and remain open to any further discussion."}, {"comment_id": "ccwzUpjNb5", "replyto": "8aANW2Lli8", "author_type": "authors", "reviewer": null, "comment": "Dear reviewer edpJ,\n\nIn order to further address your concerns regarding the effectiveness of our method on noisy input data, we conducted an additional experiment evaluating GRN’s robustness to sensor noise, as detailed in Appendix E of the latest manuscript. This experiment introduces varying levels of Gaussian noise to objects' poses and bounding box dimensions during inference. \n\nDespite the added uncertainty, GRN demonstrates strong performance, maintaining high F1 scores for action and grasp feasibility predictions under different noise levels. Notably, GRN’s performance under noisy conditions still surpasses the noise-free performance of baseline methods, highlighting its ability to handle real-world uncertainties effectively. It also retains a high accuracy on IK feasibility and grasp obstruction predictions.\n\n**Performance of GRN on the Panda-3D-4 test set under different levels of noise**\n| Noise Level \t| Action (F1) | Grasp (F1) \t| IK (F1) \t| GO (MAE) \t|\n|--------------------|-------------|-----------------------|--------------------|---------------------|\n| No Noise \t| 0.939 \t| 0.940 (± 0.009) \t| 0.995 (± 0.001) | 0.028 (± 0.003)\t|\n| 1 cm, 1° \t| 0.912 \t| 0.891 (± 0.020) \t| 0.985 (± 0.001) | 0.039 (± 0.003)\t|\n| 2 cm, 2° \t| 0.864 \t| 0.820 (± 0.029) \t| 0.971 (± 0.002) | 0.057 (± 0.003)\t|\n\nOur results show that even with a simplified bounding box representation and a classification of grasps into grasp types, our method achieves state-of-the-art performance on action and grasp feasibility prediction, while providing richer information through our interpretation mechanisms, better generalization capabilities to complex environments, multi-robot settings and noisy inputs, as well as the ability to solve in a single shot difficult TAMP problems efficiently.\n\nWe appreciate your valuable insights and hope the additional results and proposed extensions help address your remaining concerns."}, {"comment_id": "cDQsDfGZQH", "replyto": "b5FllXvC2C", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank you for your thoughtful feedback and reassessment of our work. We deeply value your constructive comments and fully understand your remaining reservations.\n\nTo further address your concerns about interpretability validation, we conduct a qualitative evaluation of GRN's interpretability in Appendix D of the latest manuscript. In this experiment, we evaluate GRN’s predictions on an example 3D scene. By systematically modifying the environment based on GRN’s predicted infeasibility causes, we observe how the model's predictions change in response. This evaluation demonstrates that GRN’s interpretability mechanisms effectively explain why certain actions or grasps are infeasible and how modifying the scene can render them feasible. For instance, when an obstruction is removed or an object becomes reachable, the action and relevant grasps become feasible, validating the interpretability of our approach.\n\nWe also conducted an additional experiment to assess GRN’s robustness to noisy inputs, detailed in Appendix E of the updated manuscript. In this study, we introduced translational and rotational noise to the objects’ poses and bounding box dimensions in the Panda-3D-4 test set. The results show that GRN’s performance under noise remains superior to the performance of baseline models on noise-free data. These findings highlight the robustness of GRN to different levels of noise, supporting its applicability in real-world scenarios with uncertain inputs.\n\nWhile a full TAMP integration and the application to dynamic scenes are beyond the scope of this work, we believe these additional experiments further strengthen our contributions and demonstrate the applicability of GRN in practical settings, as an accurate and efficient geometric reasoning module that can be integrated with various off-the-shelf TAMP planners.\n\nOur results show that GRN achieves state-of-the-art performance on action and grasp feasibility prediction, while providing richer information through our interpretation mechanisms, better generalization capabilities to complex environments, multi-robot settings and noisy inputs, as well as the ability to solve in a single shot difficult TAMP problems efficiently.\n\nWe appreciate your engagement and the opportunity to address your concerns, and we hope these additional clarifications and experiments further highlight the strengths and applicability of our approach. Thank you once again for your valuable insights."}, {"comment_id": "b5FllXvC2C", "replyto": "Qexe1RA2M7", "author_type": "reviewer", "reviewer": "Reviewer_LeQs", "comment": "I thank the authors for their detailed explanations, and additional analysis. The revised manuscript and added experiments improve the work. The authors’ responses address many of my concerns satisfactorily, particularly those regarding assumptions, the IK module, and grasp representation. However, some limitations persist, especially in TAMP integration, broader applicability (dynamic scenes) and the depth of interpretability validation. So, I updated my score accordingly."}, {"comment_id": "QhtSY6KEg1", "replyto": "cTiuCTJ0ma", "author_type": "reviewer", "reviewer": "Reviewer_WQrW", "comment": "Thank you for the detailed responses. Most of my concerns have been addressed, and I have updated my rating accordingly.\n\nHowever, I still have concerns about the generalization of this work to unstructured, uncertain, real-world environments due to the strong assumptions about known object shape and pose. While the current real-world experiments appear to perform well, they do so only under limited noise and in a simplified experimental setup. Therefore, I cannot further improve my rating."}, {"comment_id": "Hme8xcnHyl", "replyto": "WwbUU6IkxT", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank you for your updated assessment. We are glad that our revisions and additional clarifications addressed your concerns effectively. We greatly appreciate your constructive feedback, which has been invaluable in improving the clarity and depth of our work."}, {"comment_id": "WwbUU6IkxT", "replyto": "y8Rda1LTlP", "author_type": "reviewer", "reviewer": "Reviewer_FvER", "comment": "Thank you. I think all my concerns have been addressed, and the additional table helps explain when MLP could fail and GNN does well. While the novelty might be incremental, I don't see any reason to reject it."}, {"comment_id": "cJ67DwtCvt", "replyto": "B05laaJQGc", "author_type": "area_chair", "reviewer": null, "comment": "Dear Reviewer,\n\nPlease provide feedback to the authors before the end of the discussion period, and in case of additional concerns, give them a chance to respond.\n\nTimeline: As a reminder, the review timeline is as follows:\n\nNovember 26: Last day for reviewers to ask questions to authors.\n\nNovember 27: Last day for authors to respond to reviewers."}, {"comment_id": "y8Rda1LTlP", "replyto": "u7aFKw2RAC", "author_type": "area_chair", "reviewer": null, "comment": "Dear Reviewer,\n\nPlease provide feedback to the authors before the end of the discussion period, and in case of additional concerns, give them a chance to respond.\n\nTimeline: As a reminder, the review timeline is as follows:\n\nNovember 26: Last day for reviewers to ask questions to authors.\n\nNovember 27: Last day for authors to respond to reviewers."}, {"comment_id": "GVZd9KV3tl", "replyto": "b38m95GMVt", "author_type": "area_chair", "reviewer": null, "comment": "Dear Reviewer,\n\nPlease provide feedback to the authors before the end of the discussion period, and in case of additional concerns, give them a chance to respond.\n\nTimeline: As a reminder, the review timeline is as follows:\n\nNovember 26: Last day for reviewers to ask questions to authors.\n\nNovember 27: Last day for authors to respond to reviewers."}, {"comment_id": "8aANW2Lli8", "replyto": "JBsxJLjYP1", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank you for your prompt response and engagement. We greatly value your thoughtful feedback and understand your remaining reservations.\n\nTo clarify, while this paper aims to showcase the interpretability of GRN and the possibilities it offers by proposing a GRN-based planner that circumvents the need for a task planner, the ultimate goal of our approach is not to replace the graph search module in task planning. Instead, our method is designed to provide search heuristics to this module that accelerate the planning process while maintaining completeness.\n\nIn this work, we focus on demonstrating the efficiency, generalizability, and interpretability of GRN as an independent neural feasibility checker that can be seamlessly integrated into traditional task and motion planning algorithms. Furthermore, we are actively working on a full integration of GRN into a TAMP algorithm that we plan to submit as a future contribution to a more robotics-oriented conference.\n\nWe hope this response, along with constructive discussions with other reviewers, will help address your remaining concerns. We deeply appreciate your continued engagement and contribution to the review process."}, {"comment_id": "mUPdQ0c6eZ", "replyto": "ajxAJ8GUX4", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank all reviewers for their constructive feedback and valuable insights. We appreciate the reviewers' acknowledgment of the **novelty and clarity of our work**, as well as the **efficiency, generalizability, and interpretability** of our approach.\n\nBased on the reviewers’ feedback, we have made the following modifications to our manuscript:\n\n- **Improved Figures:** We revised the manuscript to make our figures easier to understand and properly referenced them in relevant sections.\n- **Updated Related Works:** We updated our related works section to include recent research in geometric feasibility reasoning and GNNs for task and motion planning, as suggested by Reviewer WQrW.\n- **Enhanced Table 5:** We included the number of geometric planner queries in Table 5, addressing Reviewer LeQs’s suggestion.\n- **Inference Time Analysis:** We added an inference time decomposition of our method in Appendix A.\n- **Dataset Details:** We provided additional details about our datasets and data annotation in Appendix B, including label distributions for the Panda-3D-4 and PR2-3D-4 datasets.\n- **Problems complexity:** We emphasize the difficulty of the **Access** and **Clutter** problems in Appendix C.2.\n- **New Experiments:** We added two new experiments in Appendix C.3, demonstrating our method's ability to handle **complex long-horizon problems** and **inter-robot handover tasks**.\n\nWe believe these changes significantly enhance the strengths and clarity of our paper while further showcasing the impact of our contributions. We hope these improvements align with the reviewers’ expectations.\n\nAs the discussion phase nears its conclusion, we kindly ask all reviewers for their reassessment and encourage them to share any remaining concerns they might have."}, {"comment_id": "JBsxJLjYP1", "replyto": "cs0BNUiK9b", "author_type": "reviewer", "reviewer": "Reviewer_edpJ", "comment": "Thank you for providing details about the experimental setup. While I appreciate the work's effort to replace the graph search module in task planning with GNN and to accelerate feasibility checks using neural IK and collision modules (correct me if my summary is not correct), this approach does not sufficiently address my concerns about simplified setup to give a higher score. However, I hope other reviewers could actively engage in this discussion to help clarify the issue further.\n\nOn a broader note, the lack of reviewer engagement has been highlighted as a concern in this year’s ICLR. I hope more discussion to ensure a constructive review process."}, {"comment_id": "cs0BNUiK9b", "replyto": "faJEnqhKQ7", "author_type": "authors", "reviewer": null, "comment": "Thank you very much for your response and the opportunity to clarify your remaining concerns. Below, we address the raised points:\n\n**Traditional solver details:**\n\nWe use MoveIt Task Constructor (MTC) [1] as our geometric planner, which leverages the widely used KDL plugin [2] for inverse kinematics and FCL library [3] for collision checking. These methods are neither parallelized nor GPU-accelerated in our implementation. For each object, we first uniformly sample a set of end-effector grasps based on object size. For each sampled grasp, the KDL solver computes up to 8 inverse kinematics (IK) solutions. Each IK solution is then checked for collisions using the FCL collision checker. We revised Appendix B to make these details clearer.\n \n**Batch size and computation time:**\n\nIn the Panda-3D-4 dataset, an average of 150 IK solutions is computed and tested for collisions per movable object. This can be considered the batch size for our traditional methods. The reported 250 ms includes all computations for this batch size. \n\nWe agree that recent GPU-accelerated methods like cuRobo [4] can significantly reduce computation times. For example, cuRobo reports 14.7 ms for a batch size of 100, implying it could handle our average batch size of 150 in approximately 22 ms, which is 11 times faster than our chosen traditional methods.\n\n**Comparison to our IK feasibility and GO modules:**\n\ncuRobo and other traditional methods verify IK feasibility by actually computing IK solutions and checking for collisions. Many of these solutions may be infeasible, resulting in computation time unnecessarily wasted on infeasible solutions. Our method avoids this inefficiency by predicting which grasps (end-effector poses) are promising, allowing traditional geometric planners to focus only on potentially feasible solutions. \n\nFurthermore, our IK feasibility and GO modules, in contrast to traditional methods, take as input object features (e.g., dimensions, poses) rather than individual sampled grasps (i.e end-effector poses), and simultaneously predict IK feasibility and grasp obstruction ratios for 5 grasp types (i.e sets of grasps associated with each face of the object's bounding box). This design enables our modules to reason over batches of objects instead of batches of grasps.\n\nWhile cuRobo significantly reduces computation time thanks to GPU-acceleration, our IK feasibility and GO modules compute predictions for 4 movable objects in under 1 ms (cf. Appendix A). Based on the previous observations, running cuRobo on the same environment with 150 IK computations per object, would require 88 ms (22 ms × 4). Thus, our IK feasibility and GO modules are **88 times faster**. \n\nIn summary, while GPU-accelerated methods like cuRobo offer significant improvements over other traditional solvers, GRN operates at a fraction of their computational cost while predicting feasibility directly from object features without sampling individual grasps. Our method offers a pre-selection step for deciding which grasps traditional methods should focus on.\n\n---\n\n[1] Görner, Michael, et al. \"Moveit! task constructor for task-level motion planning.\" 2019 International Conference on Robotics and Automation (ICRA). IEEE, 2019.\n\n[2] https://www.orocos.org/kdl.html\n\n[3] Pan, Jia, Sachin Chitta, and Dinesh Manocha. \"FCL: A general purpose library for collision and proximity queries.\" 2012 IEEE International Conference on Robotics and Automation. IEEE, 2012.\n\n[4] Sundaralingam, Balakumar, et al. \"CuRobo: Parallelized collision-free minimum-jerk robot motion generation.\" arXiv preprint arXiv:2310.17274 (2023)."}, {"comment_id": "faJEnqhKQ7", "replyto": "xzB6oow3sh", "author_type": "reviewer", "reviewer": "Reviewer_edpJ", "comment": "Thanks for your response, and I appreciate the explanation regarding IK and collision-checking times. However, I find the reported 250ms for traditional methods somewhat high compared to benchmarks like cuRobo [1], which achieves collision-free computation in about 5ms for batch size 10 with GPU acceleration. Could you clarify the following points to provide more context?\n\n1. Batch Size: What batch size was used for both MLP-based IK and traditional IK computations?\n\n2. Parallelization: Were traditional IK computations or collision checking parallelized or GPU-accelerated?\n\n3. Solver Details: What IK solver and collision-checking methods were used?\n\nAddressing these points would help align the benchmarks and clarify the comparison. Thank you.\n\n[1] cuRobo: Parallelized Collision-Free Minimum-Jerk Robot Motion Generation"}, {"comment_id": "cTiuCTJ0ma", "replyto": "DL07smVn97", "author_type": "authors", "reviewer": null, "comment": "**3. Training one model for all robots:**\n\nOur current implementation does not support training a single model for all robots. This limitation is common among existing methods for feasibility prediction, as kinematics and collision models can vary significantly across different robots. Training a reliable unified model would require incorporating a representation of the kinematics and collision model for each robot, which can be complex to encode effectively. That said, this is an interesting direction for future work, and we aim to explore ways to address this challenge in subsequent research.\n\n**4. MLP generalizability to a larger number of objects:**\n\nWe appreciate the reviewer’s observation. In our problem, IK feasibility prediction is the simplest task (cf. Table 3), with a simple MLP (IK module) achieving an F1 score exceeding 99%. Environments with more objects naturally lead to an increase in infeasibility cases due to unreachability, which depends solely on the object’s position and not its neighborhood. This explains the seemingly good generalization of the MLP baseline in larger environments, as it performs well at predicting unreachability.\n\nHowever, the MLP baseline does not account for neighboring obstacles and performs poorly in grasp obstruction cases, where its predictions are effectively random. Comparing confusion matrices for Panda-3D-20 below, the MLP baseline yields a high number of false positives. In contrast, GRN incorporates both IK and GO considerations, leading to more balanced and accurate predictions, with a notable improvement in F1 scores 0.2 higher for action feasibility and 0.3 higher for grasp feasibility on Panda-3D-20 w.r.t MLP.\n\n---\n\n**MLP Confusion Matrix on Panda-3D-20**\n\n| Actual / Predicted | Infeasible | Feasible |\n|---------------------|------------|----------|\n| **Infeasible** | 65129 | 23065 |\n| **Feasible** | 4679 | 27127 |\n\n---\n\n**GRN Confusion Matrix on Panda-3D-20**\n\n| Actual / Predicted | Infeasible | Feasible |\n|---------------------|------------|----------|\n| **Infeasible** | 85956 | 2288 |\n| **Feasible** | 3856 | 27910 | \n\n---\n\n**5. Performance on the Access problem:**\n\nThe baseline TAMP planner is a complete, closed-loop algorithm that iteratively evaluates multiple task plans until a feasible one is found. In contrast, our GRN-based planner is a one-shot open-loop algorithm: if the initial plan generated using GRN’s predictions is infeasible, the planning process terminates, and no alternative plans are considered. In one instance of the Access problem, GRN produces an incorrect prediction, resulting in a geometrically infeasible plan, and thus a planning failure.\n\nThe considered problems are not particularly easy (see Appendix C.2). However, we deliberately limit the number of objects to ensure that the baseline planner can solve the problems in a reasonable amount of time, allowing for a fair comparison of planning time. To further demonstrate the capabilities of GRN, we include an additional experiment in Appendix C.3, featuring a more challenging 28-object Access problem. The baseline planner fails completely on this task, while the GRN-based planner solves it in under 15 seconds.\n\n---\n**References** \n[1]: K. Lin, C. Agia, T. Migimatsu, M. Pavone, and J. Bohg. Text2motion: From natural language instructions to feasible plans. Autonomous Robots, 2023.\n\n[2]: Y. Huang, C. Agia, J. Wu, T. Hermans, and J. Bohg. Points2Plans: From Point Clouds to Long-Horizon Plans with Composable Relational Dynamics, ArXiv, 2024.\n\n[3]: C. Agia, T. Migimatsu, J. Wu, and J. Bohg. Stap: Sequencing task-agnostic policies. In 2023 IEEE International Conference on Robotics and Automation, ICRA 2023.\n\n[4]: Planning for Multi-Object Manipulation with Graph Neural Network Relational Classifiers. In IEEE International Conference on Robotics and Automation, ICRA 2023.\n\n[5]: H. Shi, H. Xu, Z. Huang, Y. Li, and J. Wu. Robocraft: Learning to see, simulate, and shape elastoplastic objects in 3d with graph networks. IJRR 2024.\n\n[6]: H. Chen, Y. Niu, K. Hong, S. Liu, Y. Wang, Y. Li, and K. R. Driggs-Campbell, “Predicting object interactions with behavior primitives: An application in stowing tasks,” CoRL 2023.\n\n[7] Garrett, Caelan Reed, et al. \"Integrated task and motion planning.\" Annual review of control, robotics, and autonomous systems 4.1 (2021).\n\n[8] Garrett, Caelan Reed, Tomas Lozano-Perez, and Leslie Pack Kaelbling. \"Ffrob: Leveraging symbolic planning for efficient task and motion planning.\" IJRR 2018.\n\n[9] Wells, Andrew M., et al. \"Learning feasibility for task and motion planning in tabletop environments.\" RAL 2019.\n\n[10] Driess, Danny, et al. \"Deep visual heuristics: Learning feasibility of mixed-integer programs for manipulation planning.\" ICRA 2020.\n\n[11] Bouhsain, Smail Ait, et al . \"Simultaneous Action and Grasp Feasibility Prediction for Task and Motion Planning through Multi-Task Learning.\" IROS 2023.\n\n---"}, {"comment_id": "DL07smVn97", "replyto": "B05laaJQGc", "author_type": "authors", "reviewer": null, "comment": "We greatly appreciate the reviewer’s comprehensive feedback and address their concerns in detail below.\n\n---\n\n### **Comments on raised weaknesses**\n\n**1. Comparison to related works in geometric feasibility reasoning and GNNs in TAMP:**\n\nWe thank the reviewer for bringing these works to our attention. We have carefully reviewed the suggested references and included a discussion of these works in the related works section of the revised manuscript.\n\n**2. Figures clarity:**\n\nWe appreciate the reviewer’s feedback on the clarity of the figures. In the revised manuscript, we have improved the caption for Figure 1 to make it easier to interpret. Additionally, we have added a detailed explanation and commentary for the figures in Appendix D. To clarify, the visualization figures show the three predictions of GRN: (i) Action feasibility, represented by the object color in the corresponding plots (green for feasible, red for infeasible). (ii) Grasp type feasibility, represented by the coloration of the corresponding object faces. (iii) Reasons of infeasibility: Shown for a single object for clarity., which is depicted in blue (exceptionally in pink in Figure 9). Neighboring objects, shown in full opacity, are color-coded to represent the grasp obstruction ratio. We believe these updates and explanations significantly improve the clarity and comprehensibility of the figures and welcome any further suggestions.\n\n**3. Scene knowledge and static environments assumptions:**\n\nWe acknowledge the reviewer’s concern regarding the assumptions of known object shapes, poses, and static scenes. However, such assumptions are standard in offline TAMP research, as demonstrated in both traditional and learning-based works [7, 8, 9, 10, 11]. These works typically decouple the perception and planning problems, focusing on solving the manipulation planning task with fully defined scenes. This decoupling is a common practice in offline manipulation planning to allow methods to focus on geometric reasoning. Our approach aligns with these standards, making the assumptions consistent with the existing literature.\n\n---\n\n### **Answers to raised questions**\n\n**1. Comparison to geometric feasibility reasoning and GNNs in TAMP:**\n\n[1] and [2] propose learning-based methods for TAMP using large language models. Unlike [1], which relies solely on the probability of feasibility of learned skill primitives, our method aligns more closely with [2] by explaining the reasons behind infeasibility. This interpretability allows downstream planners to bypass infeasible actions more effectively. Points2Plans ([2]) trains relational dynamics models for single-step transitions, whereas GRN simplifies feasibility checks into interpretable modules, making it straightforward to train and integrate with traditional TAMP planners. GRN could also serve as an independent feasibility checker for [1] and [2].\n\n[3] uses RL to learn task-agnostic robot skills and verify their feasibility for manipulation tasks. While this method requires a separate model per skill and a single query per action, GRN learns a unified model for both pick and place actions, predicting feasibility for all objects simultaneously. Additionally, GRN’s interpretability mechanisms provide insights into infeasibility, enhancing decision-making.\n\n[4] and [6] utilize GNNs for different aspects of TAMP. [4] learns inter-robot relations to verify subgoal satisfaction with a scene graph representation similar to ours. [6] leverages GNNs for dynamic object interactions in stowing tasks. However, neither focuses specifically on action and grasp feasibility prediction. GRN could be integrated into these methods to predict feasibility-related inter-robot relations or evaluate the feasibility of stowing action sequences.\n\n[5] addresses shaping deformable objects, a task unrelated to action and grasp feasibility prediction. As such, it does not align closely with our focus.\n\nIn summary, GRN distinguishes itself through its focus on feasibility prediction, interpretability, and modularity, making it a complementary tool for integration into various TAMP frameworks. The mentioned works, however, tackle non-prehensile actions as well, which is a limitation of our method we aim to address in future work.\n\n**2. Robustness to noise in real-world experiments:**\n\nIn our real-world experiments, the estimated object poses are indeed noisy, while the object shapes are exact as they are obtained through object recognition, primarily using tags. Despite this, the observed success of the GRN-based planner on real-world tasks suggests that our approach is robust to relatively small noise in pose estimation.\n\nIt is important to note that GRN predicts the success of a sampling-based geometric planner (defined as feasibility). While larger noise levels would reduce GRN’s performance, they would similarly undermine the success of the geometric planner itself, as both rely on accurate environment representations."}, {"comment_id": "G06bN9Hlv5", "replyto": "u7aFKw2RAC", "author_type": "authors", "reviewer": null, "comment": "We sincerely value the reviewer’s input and have addressed their questions in the sections below:\n\n**1. Dataset Imbalance:**\n\nWe have included a detailed description of the dataset distribution in Appendix B for the Panda-3D-4 and PR2-3D-4 datasets. Action feasibility labels are well balanced, with roughly equal numbers of feasible and infeasible cases. Grasp feasibility labels, however, are imbalanced, with a higher proportion of infeasible cases. Importantly, the reasons for infeasibility are balanced between IK infeasibility and grasp obstructions in the Panda-3D-4 dataset, while the PR2-3D-4 dataset contains more IK infeasibility cases. Our proposed data augmentation methods help mitigate these imbalances. Motion planning infeasibility cases are rare across both datasets.\n\n**2. RRT and definition of feasibility:**\n\nWe agree with the reviewer’s observation regarding RRT. However, since RRT is probabilistically complete, running RRT for an infinite amount of time would be needed to guarantee that a plan will be found if it exists. In practice, we set a timeout during data annotation, and consider an action as infeasible if the geometric planner fails to find a solution within this timeout. Thus, feasibility is defined as the success of the chosen geometric planner in finding a solution within a user-defined timeout.\n\nIt is worth noting that in our datasets, as shown in our updated Appendix B, most infeasibility cases arise from inverse kinematics (IK) constraints or grasp obstructions (GO), rather than from motion planning (MP) failures.\n\n**3. Pros of pre-training modules independently:**\n\nWe observed no instability when training all modules jointly. However, pre-training ensures that each module effectively learns its specific task given its designated inputs. Without pre-training, jointly training all modules leads to convergence to spurious local optima due to the greedy nature of weight optimization (cf. ablation study). Pre-training helps mitigate this issue by providing a strong initialization, which ensures more effective fine-tuning during joint training.\n\n**4. MLP’s generalizability to larger environments:**\n\nWe appreciate the reviewer’s observation. The simplest task in our problem is IK feasibility prediction, which even a simple MLP achieves with an F1 score exceeding 99% (cf. ablation study). Generating environments with more objects increases cases where infeasibility is due to unreachability, which depends solely on the object’s position rather than its neighborhood. This explains the seemingly good generalization of the MLP baseline in larger environments, as it excels at predicting unreachability.\n\nHowever, the MLP baseline does not account for neighboring obstacles and struggles in cases of grasp obstruction, where its predictions are no better than random. Comparing confusion matrices obtained on Panda-3D-20 shown below, the MLP baseline frequently predicts actions and grasps as feasible, resulting in a high number of false positives. GRN, by contrast, incorporates both IK and GO considerations, producing more balanced and informed predictions, which explains the significant difference in F1 score of 0.2 (resp. 0.3), for action (resp. grasp) feasibility prediction.\n\n---\n\n**MLP Confusion Matrix on Panda-3D-20**\n\n| Actual / Predicted | Infeasible | Feasible |\n|---------------------|------------|----------|\n| **Infeasible** | 65129 | 23065 |\n| **Feasible** | 4679 | 27127 |\n\n---\n\n**GRN Confusion Matrix on Panda-3D-20**\n\n| Actual / Predicted | Infeasible | Feasible |\n|---------------------|------------|----------|\n| **Infeasible** | 85956 | 2288 |\n| **Feasible** | 3856 | 27910 | \n\n---\n\n**5. Learned feasibility checker Vs. Learned end-to-end planner:**\n\nOur solution based on a learned feasibility checker aims to limit the calls to a costly geometric planner. It ensures completeness, generalizability to unseen environments and diverse robots, interpretability for geometric reasoning, and safety through traditional collision-free motion planning, all with lower data generation costs as only binary feasibility labels (and GO ratios) are needed. But it still requires a geometric planner to plan motions for actions predicted as feasible, increasing planning time. In contrast, a learned end-to-end planner could be faster, as it completely avoids geometric queries. However, it sacrifices key requirements such as safety since collision-free motions are not guaranteed, it struggles with generalization to new environments and high-DOF robots, and demands higher data generation costs, requiring full motion plans for training.\nIn offline manipulation planning where completeness, safety and generalizability are desirable, learned feasibility checkers are therefore a preferable option.\n\n---"}, {"comment_id": "xzB6oow3sh", "replyto": "QhtxzDAmIb", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank the reviewer for their valuable feedback and address their concerns below:\n\n---\n\n### **Comments on raised weaknesses**\n\n**1. Reliance on simplified bounding box representations and discrete grasp feasibility:**\n\nWe acknowledge the reviewer’s concern and emphasize that our use of bounding box representations and discrete grasp feasibility strikes a balance between computational efficiency and representational capability. These simplifications have proven effective in predicting action and grasp feasibility for everyday objects [1][2][3].\n\nFor more complex objects, finer representations—such as multiple bounding boxes, alternative primitives (cylinders, spheres), or advanced object shape encoders—could be explored. Similarly, continuous grasp feasibility predictions could be achieved with additional data annotation and training efforts, allowing for greater adaptability. We aim to investigate these directions in future work to enhance GRN’s applicability to more challenging scenarios.\n\n**2. Assessing motion planning feasibility:**\n\nWe clarify that motion planning feasibility is considered during dataset annotation and training. However, in the generated datasets, such cases are relatively rare (cf. Appendix B in the updated manuscript) and are therefore under-represented in our experiments.\n\nFrom our experience in TAMP and observations from the generated datasets, the vast majority of infeasibility cases arise from inverse kinematics or grasp obstructions. These dominate failure modes in practical scenarios, making our focus on these aspects both relevant and impactful.\n\nThat said, we recognize the importance of addressing complex cases where motion feasibility is a limiting factor. As mentioned in the future work section, we aim to develop a new interpretation module for motion feasibility detection and explore methods to estimate the configuration space’s connectivity using our graph representation, further enhancing GRN’s capabilities in these scenarios.\n\n---\n\n### **Answers to raised questions**\n\n**1. Using traditional methods as IK and GO modules:**\n\nReplacing the IK and GO modules with traditional methods, such as standard IK solvers or distance-based obstruction checks, would significantly impact computational efficiency. On average, for a pick action, IK computation and collision checking take approximately 250 ms (includes computations for all considered grasps), making these methods substantially slower than our learned modules which take 0.5ms each (cf. Appendix A in the updated manuscript).\n\nIn terms of accuracy and interpretability, we observed that the pretrained GNN (before fine-tuning) using ground truth IK feasibility and grasp obstruction values achieves F1 scores of 0.95 and 0.96 for action and grasp feasibility predictions, respectively. However, we believe that the slight improvement in accuracy does not justify the significant computational overhead introduced by traditional methods, especially in time-sensitive TAMP applications.\n\n**2. Maximum environmental complexity:**\n\nWhile it is difficult to quantify an exact maximum level of environmental complexity that GRN can handle, we believe it could scale effectively to environments with a high number of objects. This scalability would stem from the localized nature of feasibility predictions: only the distance-based neighborhood of a specific object is considered for each prediction. Even in environments containing thousands of objects, the number of objects in any single neighborhood is inherently limited, ensuring manageable computational requirements and maintaining high prediction quality.\n\nIn scenarios involving point clouds, the environment can be decomposed into separate objects, support surfaces, and obstacles. This decomposition allows GRN to operate on a structured representation rather than directly on the raw point cloud. By leveraging this structured approach, GRN could preserve prediction accuracy and efficiency, even in highly cluttered and complex scenes.\n\n---\n\n**References** \n[1] Wells, Andrew M., et al. \"Learning feasibility for task and motion planning in tabletop environments.\" IEEE robotics and automation letters 4.2 (2019): 1255-1262.\n\n[2] Bouhsain, Smail Ait,et al . \"Extending Task and Motion Planning with Feasibility Prediction: Towards Multi-Robot Manipulation Planning of Realistic Objects.\" 2024 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS). IEEE, 2024.\n\n[3] Khodeir, Mohamed, et al. \"Policy-guided lazy search with feedback for task and motion planning.\" 2023 IEEE International Conference on Robotics and Automation (ICRA). IEEE, 2023."}, {"comment_id": "Qexe1RA2M7", "replyto": "p8agLJOpmQ", "author_type": "authors", "reviewer": null, "comment": "### **Answers to raised questions**\n\n**Self-loop Edge Representation in 4.1:**\n\nThis design decision was primarily motivated by considerations of consistency and simplicity in the graph representation. Since IK feasibility is computed exclusively for nodes corresponding to movable objects, treating it as a node feature would require adding padding or a mask to the features of nodes representing fixed objects, which do not require IK feasibility information. By storing IK feasibility predictions in self-loop edges, we avoid this issue and maintain consistent dimensionality across all node and edge features.\n\nIn our experiments, we observed that both approaches—storing IK feasibility as a node feature versus a self-loop edge feature—achieved very similar results in terms of predictive performance. As such, the decision to use self-loop edges was driven by a preference for a more “elegant” and streamlined graph representation rather than by any performance considerations.\n\n**Effectiveness and Necessity of the IK Module:**\n\nSee answer 4.b.\n\n**Access Task Success Rate in Section 6.4**\n\nThe slightly lower success rate observed for the Access task is due to a misclassification made by GRN on one of the ten instances of the problem. The GRN-based planner is not complete, if the initial plan found is infeasible, planning stops.\n\n\nIt is important to note that this limitation is specific to the current implementation of the GRN-based planner. A complete integration of GRN into a TAMP solver would mitigate this issue by incorporating mechanisms to recover from misclassifications, such as re-evaluating feasibility predictions at each step of the planning process or exploring alternative solutions.\nGiven that this misclassification occurred in only one out of ten instances of the same problem, the slightly lower success rate can be attributed to experimental variance.\n\n**Inference Time in Table 5**\n\nThe inference time reported in Table 5 reflects the combined duration of GRN queries and geometric planning. Notably, GRN is queried only once per problem on the initial scene, which takes a few milliseconds only. The large majority of the reported time is attributable to the geometric planning phase, which dominates the computation.\n\n**Handling Dynamic Scenes**\n\nWe appreciate the reviewer’s interest in extending our approach to dynamic environments. While the current work focuses on static scenes for computational tractability and alignment with baseline methods, we agree that handling dynamic scenes is an important direction for future research. One promising direction to explore is the use of Temporal Graph Neural Networks [6], which could model the evolution of the environment over time by incorporating temporal dependencies into the graph representation. This extension would allow GRN to reason about changes in object configurations and dynamic constraints, enabling its application to a wider range of real-world tasks involving dynamic scenes.\n\n**Simplified Grasp Types:**\n\nSee Answer 5.\n\n---\n\n**References** \n[1] Garrett, Caelan Reed, et al. \"Integrated task and motion planning.\" Annual review of control, robotics, and autonomous systems 4.1 (2021): 265-293.\n\n[2] Garrett, Caelan Reed, Tomas Lozano-Perez, and Leslie Pack Kaelbling. \"Ffrob: Leveraging symbolic planning for efficient task and motion planning.\" The International Journal of Robotics Research 37.1 (2018): 104-136.\n\n[3] Wells, Andrew M., et al. \"Learning feasibility for task and motion planning in tabletop environments.\" IEEE robotics and automation letters 4.2 (2019): 1255-1262.\n\n[4] Driess, Danny, et al. \"Deep visual heuristics: Learning feasibility of mixed-integer programs for manipulation planning.\" 2020 IEEE international conference on robotics and automation (ICRA). IEEE, 2020.\n\n[5] Bouhsain, Smail Ait, et al . \"Simultaneous Action and Grasp Feasibility Prediction for Task and Motion Planning through Multi-Task Learning.\" 2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS). IEEE, 2023.\n\n[6] Rossi, Emanuele, et al. \"Temporal graph networks for deep learning on dynamic graphs.\" arXiv preprint arXiv:2006.10637 (2020).\n\n\n---"}, {"comment_id": "p8agLJOpmQ", "replyto": "AR6gP9pRkp", "author_type": "authors", "reviewer": null, "comment": "**4. Incremental Advancement and Limited Impact of IK Infeasibility Module:**\n\n**a. Adjustments to EGAT:**\n\nWhile it is true that the adjustments to EGAT are incremental, they are not presented as a contribution of the paper. Rather, our work leverages EGAT in a novel application within the domain of action and grasp feasibility prediction, showcasing its suitability for this task while building upon prior advancements in GNNs.\n\n**b. IK Feasibility Module Impact:**\n\nThe impact of the IK feasibility module on overall performance is indeed limited. However, this is not due to a lack of IK infeasibility cases. Figures 5 and 6 in Appendix B of the updated manuscript show a balanced distribution of infeasibility causes in the datasets between IK infeasibility and grasp obstructions. The observed limited performance gain is due to the AGF module’s ability to implicitly learn and address these cases, effectively replacing the IK module’s role during feasibility prediction if the latter is omitted.\n\nThe IK infeasibility module is essential for interpretability and planning efficiency, providing explicit reasoning about infeasibility causes. For instance, if an object is both obstructed and unreachable due to IK constraints, GRN without the IK module will attribute infeasibility solely to the obstruction. The IK module, however, identifies unreachability as the root cause, distinguishing cases where clearing the obstruction resolves infeasibility from those where the object remains unreachable. This clarity benefits TAMP planners by avoiding unnecessary computations, such as attempting to clear obstructions for unreachable objects. As detailed in Appendix A, the IK module adds minimal overhead of 0.5 ms.\n\n\n**5. Restricted Grasp Representation:**\n\nWhile our approach defines 5 grasp types, it is important to clarify that these types are not single grasps but represent an infinite set of grasps. Each grasp type defines the face of the object from which it is grasped, leaving several DOFs as free parameters (e.g. position of the grasp along the face, the rotation of the end effector w.r.t. the approach axis, and the depth of the grasp). This definition of grasp types is widely adopted in the field of action and grasp feasibility prediction [3, 4, 5] and has proven effective for representing diverse grasping strategies while maintaining computational efficiency. During data annotation, we evaluate an average of 150 sampled grasp configurations per object (reaching ~600 grasps for larger objects), ensuring that our representation captures a rich set of grasp possibilities within each type.\n\n\n**6. Limited Interpretability Support:**\n\nWe highlight that the application to TAMP (section 6.4) serves as our experimental support for the interpretability of our predictions. Specifically, the proposed planner leverages the predicted feasibility and the reasons for infeasibility derived from the initial state of the environment to compute a complete feasible plan. Notably, the planner relies solely on this interpretability information, without iterative feedback or additional inputs during the planning process. The success of the resulting plans, as demonstrated in the Access and Clutter tasks, directly showcases the interpretability and utility of the predicted feasibility and infeasibility explanations. For instance, the planner uses these insights to identify and resolve obstructions or determine that certain actions are fundamentally infeasible, enabling it to generate efficient and accurate solutions. This experimental validation underscores that our model’s interpretability mechanisms provide actionable insights that support both feasibility prediction and effective task and motion planning.\n\n\n**7. Clarity and Organization of Paper Structure:**\n\nWe apologize for these oversights and thank the reviewer for bringing these issues to our attention. We have carefully reviewed the manuscript and made the necessary corrections to ensure a more logical flow and improved readability.\n\n---"}, {"comment_id": "AR6gP9pRkp", "replyto": "b38m95GMVt", "author_type": "authors", "reviewer": null, "comment": "We deeply appreciate the reviewer's detailed feedback and provide detailed responses to their concerns below.\n\n---\n\n### **Comments on raised weaknesses**\n\n**1. Strong Assumptions on Object Knowledge and Static Scene Conditions:**\n\n**a. Method Assumptions:**\n\nWe acknowledge the reviewer's concern about the assumptions of full knowledge of objects shape, pose, and static scene conditions. However, we emphasize that such assumptions are standard in offline TAMP research, as demonstrated in various traditional and learning-based works [1, 2, 3, 4, 5]. These works decouple the planning component from perception, focusing on offline manipulation planning with fully defined scenes, mainly because the main challenge in TAMP is the combinatorial complexity of combining discrete symbolic search and continuous geometric planning.\n\n**b. Fairness with Respect to Image-Based Baselines:**\n\nThe baselines used in our experiments similarly assume full scene knowledge. For instance, the approach presented in [5], while image-based, internally constructs depth images from full scene knowledge, and does not rely on sensor-acquired images. Similarly, [4] operates in simplified environments (i.e., tabletop setups with box-shaped objects). During planning, a TAMP algorithm considers many states of the environment, which are \"imagined\" hypothetical states based on the actions considered by the planner, rather than sensor-perceived states. The use of image-based methods on these states requires the input images to be internally built. Hence, although [4] can use images obtained from depth cameras, its application to TAMP requires internally building these images based on scene knowledge.\n\nIn this paper, we ensure fairness by training all methods on the same 3D scenarios, comparing them within the same offline manipulation planning context, and making sure all inputs are constructed using the same full state knowledge. Furthermore, we include comparisons on tabletop scenarios to further improve fairness with [3] and [4].\n\n\n**2. Limited Integration and Short Planning Horizons:**\n\nBoth the Access and Clutter problems require multi-step reasoning to identify and remove obstructing objects in a specific order before accessing the desired object. We emphasize the challenges posed by these problems in Appendix C.2 of the updated manuscript. To further highlight the ability of our method to handle significantly more complex problems, we have included a challenging 28-object version of the Access problem in Appendix C.3, on which the baseline planner completely fails, while the GRN-based planner solves it in under 15s. Moreover, to showcase our method’s ability to handle inter-robot handover tasks, we added an experiment in Appendix C.3 demonstrating how GRN can be leveraged in this type of problem.\n\n\n**3. Improvement for TAMP:**\n\n**a. Improvement of Traditional/Existing TAMP Solvers:**\n\nWe appreciate the reviewer’s interest in the integration of GRN into a full-fledged TAMP solver. While this is a valuable direction, we deliberately chose not to include such integration within this paper to avoid overshadowing the contributions of GRN itself, and sacrificing key experiments and analysis crucial to demonstrating the capabilities of GRN. Instead, we plan to submit the integration of GRN into a TAMP algorithm as a separate contribution to a more robotics-oriented conference.\n\nThe planner provided in the paper is intended to highlight the potential of GRN in accelerating TAMP by leveraging its unique interpretation capabilities. However, this planner is not complete. A full integration of GRN into a TAMP solver that is complete would involve querying GRN iteratively at each step of the planning process, building and maintaining a set of relaxed constraints representing the predicted infeasibility causes, and using the predictions to compute heuristics that guide the search for a solution. Importantly, this process must ensure that no potential solution is excluded, preserving the completeness of the algorithm. These complexities highlight the scope of GRN’s integration as a standalone future contribution.\n\nTo address concerns about computational improvements, we have updated Table 5 in the manuscript to include the number of geometric planner calls, illustrating the reduction achieved by using GRN.\n\n**b. Inference Time Comparison Fairness:**\n\nWe kindly reiterate that both image-based and feature-based methods use the same underlying environment information. Both input types require independent perception and knowledge management modules, and our comparisons are conducted under these equivalent conditions."}], "meta_review": {"metareview": "This paper proposes a novel framework for learning Geometric Reasoning Networks (GRN) to address task and motion planning (TAMP) challenges in robotic manipulation. The method combines geometric reasoning with deep learning techniques to improve efficiency and reduce dependency on traditional geometric planners. The authors integrate the GRN within a TAMP framework, focusing on long-horizon planning problems and demonstrating their approach's efficacy in both simulated and real-world environments.\n\nThe paper makes strong claims about the generalizability of its learned representations and the ability to handle complex reasoning tasks that were previously challenging for geometric planners. The experimental results show improvements over baseline methods in terms of success rates, planning efficiency, and computational cost, making a convincing case for the practical utility of the proposed method.\n\n**Strengths:**\n\n1. The proposed GRN introduces a novel combination of geometric reasoning with task planning, a critical step forward in addressing the limitations of traditional TAMP approaches.\n\n2. The authors provide extensive evaluations, including comparisons to baseline geometric planners and ablations that demonstrate the contributions of individual components. The experiments are performed across diverse environments, adding credibility to the results.\n\n3. The integration of GRN within a broader TAMP framework demonstrates its scalability and applicability to real-world tasks, showcasing both simulation and real-world results.\n\n**Weaknesses:**\n\n1. While the results are strong in structured and static scenarios, the paper does not sufficiently explore the method's applicability to dynamic environments or tasks involving significant real-time changes.\n\n2. The reviewers noted that the paper's claims about generalizability are not fully substantiated, as the evaluations focus primarily on specific planning scenarios. Broader validations would further strengthen the impact.\n\n3. Although the paper includes comparisons with traditional geometric planners, additional comparisons with alternative learning-based TAMP methods could provide a more comprehensive evaluation.\n\n4. The real-world experiments, while impactful, are limited in scale and complexity. Expanding these evaluations to larger or more diverse scenarios would better support the paper's claims.\n\n**Reasons for Acceptance:**\n\nThe paper addresses a critical challenge in TAMP, offering a novel solution that integrates geometric reasoning with deep learning. The proposed framework shows strong empirical results and demonstrates its potential applicability to real-world tasks. While there are some limitations in terms of generalizability and dynamic task exploration, the strengths outweigh these concerns. The method is a valuable contribution to the field, and the revisions made during the rebuttal significantly improved the paper's quality and clarity.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "BC4lIvfSzv", "forum_url": "https://openreview.net/forum?id=BC4lIvfSzv", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "Generative Representational Instruction Tuning", "authors": ["Niklas Muennighoff", "Hongjin SU", "Liang Wang", "Nan Yang", "Furu Wei", "Tao Yu", "Amanpreet Singh", "Douwe Kiela"], "abstract": "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM-7B is among the top models on the Massive Text Embedding Benchmark (MTEB) and outperforms various models up to its size on a range of generative tasks. By scaling up further, GritLM-8x7B achieves even stronger generative performance while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.", "keywords": ["large language models", "instruction tuning", "text embedding"], "primary_area": "foundation or frontier models, including LLMs", "pdf_url": "https://openreview.net/pdf?id=BC4lIvfSzv", "decision": "Accept (Poster)", "num_reviews": 4, "num_discussions": 11, "reviews": [{"review_id": "IyvHbxSJN5", "reviewer": "Reviewer_XQzj", "rating": 8, "confidence": 4, "soundness": 3, "presentation": 4, "contribution": 3, "summary": "The paper presents generative representational instruction tuning (GRIT), a unified model for embedding and generative tasks in text. GRIT learns embedding representations with a bidirectional attention followed by mean pooling and a instruction tuning with causal attention. The experiments show that GRITLM outperforms various prior open models on the MTEB benchmark and matches the performance of several instruction tuning models. Furthermore, the unified model speeds up retrieval augmented generation by 60%.", "strengths": "- The paper is very well written. It is easy to follow the main motivation of the paper. The related work positions the paper well. \n- The paper presents large-scale experiments. The model matches the performance of strong baselines on challenging benchmarks such as MTEB and instruction tuning datasets. \n- The caching mechanism reduces latency for RAG, especially longer sequences. \n- The GRITLM model will be useful for practitioners.", "weaknesses": "**Mixed results**\nThe main contribution of the unified method is to reduce the latency for generating the output. However, Table 4 shows a tradeoff between performance and latency. In Doc-Query and Query-Doc experiments, we see that GRITLM speeds up RAG but at the cost of overall performance. Furthermore, GRITLM does not show significant speed-ups on GPUs. Finally, I would be curious to see if a smaller embedding model (besides a smaller GRITLM) shows improved performance compared to the RAG performance in Table 4.\n\n**Modularity.**\nOne of the main advantages of RAG is that it is modular. The separation of the embedding model and the generative model makes it easy to swap out either one of the components. With a unified embedding and generative model, the entire model has to be retrained which can be computationally expensive. \n\n**Include more recent work**\nThe authors have acknowledged that the more recent embedding models, such as NV-Embed, show improved performance over GRITLM. It would be awesome if the authors cited more recent work [a, b] and more. \n\n[a] ​​SFR-Embedding-Mistral:Enhance Text Retrieval with Transfer Learning.\n\n[b] Towards General Text Embeddings with Multi-stage Contrastive Learning", "questions": "Please see the weaknesses."}, {"review_id": "YVi200zIHO", "reviewer": "Reviewer_vUnt", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 2, "summary": "This paper introduces GritLM, a language model designed to excel at both text generation and embedding tasks. Current large language models (LLMs) typically specialize in one or the other, requiring separate models for applications that need both functionalities. GritLM addresses this limitation by employing a joint training approach.\n\nThe model architecture leverages a standard autoregressive generation head for text generation, trained with a next-token prediction cross-entropy loss. For embedding tasks, GritLM uses bidirectional encoding of the input prompt and mean pooling of the final hidden layer representations. A contrastive loss with in-batch negatives is applied to these embedding representations. The overall training objective combines these two losses, allowing the model to learn both tasks concurrently.\n\nExperimental results demonstrate that GritLM achieves competitive performance on both generation and embedding benchmarks, comparable to similarly-sized specialized models. Furthermore, the authors explore the benefits of this unified architecture in two specific scenarios: (1) reranking, where GritLM improves its own generated text through its embedding capabilities, and (2) retrieval-augmented generation (RAG), where the unified model serves as both retriever and reader, significantly reducing inference costs.", "strengths": "* GritLM effectively demonstrates strong performance in both generation and embedding tasks within a single model.\n\n* The paper presents a thorough experimental evaluation, including reranking and RAG scenarios, showcasing the practical advantages of the unified architecture.", "weaknesses": "The scalability of the proposed method raises some concerns. The practicality of training and deploying a single model for both retrieval and generation may be limited to certain model sizes. In real-world applications, employing a smaller, faster embedding model alongside a potentially much larger generation model is often preferred. A smaller embedding model typically suffices for retrieval, while larger generation models are crucial for high-quality text generation. The paper would benefit from a discussion addressing the impact of model scale on the effectiveness of the unified approach and whether it remains advantageous when using vastly different-sized models for retrieval and generation. Specifically, quantifying the trade-off between performance and efficiency in such mixed-size scenarios would strengthen the paper's claims.\n\n(An alternative approach for using different sizes of embedder and generator is to use the output of the N-th layer (where N is relatively small) for embeddings instead of the last layer.)", "questions": "N/A"}, {"review_id": "TO9kRkCXZC", "reviewer": "Reviewer_2rfu", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "This work introduces GRIT, a method to train a single large language model to excel at both generative and embedding tasks through differentiated instructions. They proposed GRITLM models achieve SOTA performance on the Massive Text Embedding Benchmark and surpass other models in generative tasks. GRIT unifies the two tasks without compromising performance, offering efficiency gains such as over 60% faster Retrieval-Augmented Generation for long documents. The unified model simplifies infrastructure by handling both embedding and generative tasks, reducing the need for separate models.", "strengths": "1. GRIT introduces a novel approach that enables a single large language model to excel at both generative and embedding tasks, traditionally handled separately.\n2. By eliminating the need for separate retrieval and generation models, GRIT speeds up RAG by more than 60% for long documents, which is a substantial improvement in processing time and resource management", "weaknesses": "1. The unified model requires more training resource, while there is no comparison of the performance between separate generation models and embedding models under the same resource consumption as shown in Table 1 and Table 2.\n2. The paper uses the Mistral model as the base model. I think it would also be necessary to conduct experiments on the LLaMA series models to verify the robustness of the method.", "questions": "Please refer to the above."}, {"review_id": "vaPUFZ4eP4", "reviewer": "Reviewer_xDdr", "rating": 8, "confidence": 3, "soundness": 4, "presentation": 4, "contribution": 4, "summary": "The paper introduces a new framework called GRIT, which aims to unify text generation and embedding tasks within a single LLM, GRITLM. The model handles both tasks efficiently by distinguishing between them through instructions, which streamline use in multi-task applications like RAG. The authors demonstrate that GRITLM performs strongly on text representation and generation benchmarks, achieving competitive performance on the Massive Text Embedding Benchmark (MTEB) while also excelling in generative tasks. \n\nContributions:\n1. Unified Generative and Embedding Model: The GRIT framework combines generative and embedding tasks within a single LLM. By using instructional prompts to distinguish between tasks, GRIT allows both generation and embedding without sacrificing performance. GRIT also reduces the need for separate models and complex infrastructure setups. This unification could simplify real-world deployments, particularly for applications that traditionally require both retrieval and generation components, such as search engines, recommendation systems, and conversational AI.\n2. Efficient RAG catching design: The paper proposes innovative caching techniques, like Doc-Query Caching, and Query-Doc Caching, that significantly speed up RAG processes by reducing the number of forward passes required for long document processing. This approach reduces computational load for RAG tasks, enhancing efficiency in applications that rely on fast, context-sensitive retrieval and generation.\n3. Competitive Performance Across Generative and Embedding Benchmarks: GRITLM achieves strong results on both the MTEB and several generative tasks, outperforming other open models of comparable size. This dual-task proficiency demonstrates that GRIT can match or exceed task-specific models, marking a significant step toward a general-purpose language model that handles both types of tasks seamlessly.\n4. Task-Specific Performance Optimization: GRIT introduces several improvements, such as bidirectional attention with mean pooling for embedding tasks and mixed token-sample level loss aggregation for generative tasks. These innovations contribute to the model's performance across diverse tasks and offer insights into optimizing large language models for multi-task functionality.", "strengths": "Originality: The GRIT framework presents an original approach by unifying generative and representational capabilities within a single model, GRITLM, that can seamlessly switch between tasks based on instructional prompts. This concept is innovative as it directly addresses a long-standing limitation in language models: the need for distinct models optimized separately for generation and embedding. Previous work has focused on either generation or embedding, often leading to complex infrastructures where multiple models must be managed, synchronized, and deployed separately. GRIT’s unified approach not only simplifies these workflows but also brings both task types under one architecture without compromising performance. Additionally, GRIT’s application of caching techniques to accelerate RAG showcases an innovative use of model design to enhance efficiency, a departure from traditional RAG approaches that rely on separate models.\n\nQuality: The paper demonstrates a strong methodological foundation, supported by comprehensive experimentation and ablation studies. The authors provide detailed evaluations on major benchmarks, contrasting GRITLM’s performance with task-specific models to validate its efficacy as a multi-task solution. The robustness of the results is further confirmed through comparisons with proprietary models and current open-source alternatives, evidencing GRIT’s strong performance in both generative and embedding tasks. The use of ablations to explore trade-offs in task prioritization, loss aggregation, and memory efficiency contributes to the overall rigor, allowing readers to clearly understand how GRIT’s dual-objective structure was optimized. The paper’s experiments on efficiency gains with caching also underscore the quality of its findings, providing quantitative backing for its claims regarding speed improvements in RAG tasks.\n\nClarity: The paper is well-organized and clear in its presentation, guiding the reader through complex ideas with a logical flow. Key concepts, GRIT’s caching mechanisms, and instructional tuning, are introduced with adequate background and broken down into understandable segments. Figures effectively support comprehension, making the technical details more accessible. The thorough presentation of results, including detailed tables and ablation analyses, provides clarity around GRIT’s performance relative to baselines, demonstrating where it excels and where there may be trade-offs. Additionally, the inclusion of an in-depth Appendix suggests a commitment to transparency and accessibility, ensuring that interested readers have the resources to delve deeper into implementation specifics and experiment configurations. \n\nSignificance: The significance of GRIT lies in its potential to impact the field of NLP by simplifying multi-task language model deployment and reducing reliance on separate models for embedding and generation tasks. Furthermore, GRIT’s caching innovations for RAG tasks significantly reduce computational overhead and latency, especially in long-document settings, which is valuable for any application that relies on fast, context-aware responses. Moreover, GRIT’s design choices and improvements, such as mean pooling for embeddings and loss aggregation, may inspire further research into architectural unification across other language model tasks.", "weaknesses": "Storage Costs for Caching: The paper proposes innovative caching strategies to speed up RAG, but for example, Doc Caching, in particular, requires 30TB of storage for key-value states on GRITLM 7B. Such high storage demands are prohibitive in many real-world scenarios, limiting the practical usefulness of these techniques.\n\nInstruction Dependence: The model’s reliance on instruction-based differentiation between tasks could lead to inconsistent performance if instructions are poorly structured or if the model misinterprets the intended task. Instruction-based models can sometimes be sensitive to variations in phrasing, and such dependency on clear, well-defined instructions might limit GRIT’s robustness in noisy or ambiguous real-world applications.\n\nComplexity of Caching Mechanisms and Trade-offs: While the caching mechanisms offer significant speed-ups, they introduce substantial complexity to the model's architecture and inference workflow. The paper acknowledges that Query-Doc Caching can result in degraded performance due to mismatches in attention patterns. This complexity could make it challenging for practitioners to implement GRITLM optimally and may lead to inconsistent performance across different tasks and input types.", "questions": "Storage Costs for Caching: The caching techniques provide impressive speed improvements, but the storage requirements (e.g., 30TB for Doc Caching) are substantial. Can the authors discuss potential strategies to make these techniques more feasible for real-world use, particularly in terms of storage optimization?\n\nInstruction Dependence: Including experimental results on GRITLM’s sensitivity to instruction phrasing and format would provide valuable insights into its robustness and areas for improvement."}], "discussions": [{"comment_id": "Ky4TiZaO9r", "replyto": "IUVZoUdcZL", "author_type": "reviewer", "reviewer": "Reviewer_xDdr", "comment": "Thank you for following up! Yes, your response has addressed my concerns. I appreciate the detailed clarifications provided."}, {"comment_id": "bRShXXWVr0", "replyto": "WAkwuWxbAk", "author_type": "reviewer", "reviewer": "Reviewer_XQzj", "comment": "Thank you for your response. I appreciate the authors adding additional results in Appendix F. This is great! I also want to thank the authors for pointing out that GritLM is still modular. \n\nFor these reasons, I will be increasing my score to 8 and confidence to 4."}, {"comment_id": "IUVZoUdcZL", "replyto": "sS8u6rAd1r", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nWe'd appreciate it if you'd let us know if our response has addressed your concerns.\n\nThank you!"}, {"comment_id": "80blEWPhm8", "replyto": "jcV7TQf74g", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nWe'd appreciate it if you'd let us know if our response has addressed your concerns.\n\nThank you!"}, {"comment_id": "ENjTs3tsVa", "replyto": "GT3ETWeGyl", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nWe'd appreciate it if you'd let us know if our response has addressed your concerns.\n\nThank you!"}, {"comment_id": "WAkwuWxbAk", "replyto": "dBVDY6teRK", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer,\n\nWe'd appreciate it if you'd let us know if our response has addressed your concerns.\n\nThanks!"}, {"comment_id": "sS8u6rAd1r", "replyto": "vaPUFZ4eP4", "author_type": "authors", "reviewer": null, "comment": "Thanks a lot for your extensive review and highlighting the originality and novelty of the approach.\n\n**Storage Costs:** Query caching does not require any additional storage, only doc caching requires storing additional key-value states. For doc caching the KV cache can be fully offloaded to disk and does not need to be kept in memory. Disk storage is generally cheap. One can also store only part of the cache, e.g. only cache the key-value states of the first N layers. This will still lead to speed-ups but to a lesser extent. Thus, practitioners can flexibly choose the amount of storage they want to use depending on their specific setup.\n\n**Instruction Dependence:** GritLM can produce reliable embeddings without instructions – in fact, for embedding documents during retrieval we do not use instructions. The generative part, however, indeed needs instructions. There has been prior work investigating the robustness of generative instruction-tuned models [1] and while it can be problematic at small scale, these issues generally go away at larger scales.\n\n**Complexity of Caching:** Note that caching does not change the architecture of the model. It does add some additional steps at inference, however, these are only ~3 lines of additional code to extract the key-value cache and repass it to the model. We provide open-source example code for running the caching in the supplementary material. Overall, we think that the caching is less complex than having to load and serve a second embedding/generative model as is necessary for current RAG setups that do not use GRIT.\n\n[1] Evaluating the zero-shot robustness of instruction-tuned language models by J Sun, C Shaib, BC Wallace, URL https://arxiv.org/abs/2306.11270"}, {"comment_id": "jcV7TQf74g", "replyto": "TO9kRkCXZC", "author_type": "authors", "reviewer": null, "comment": "Thanks a lot for your review and notes on GritLMs strong performance.\n\n**Comparison of training resources:** Great point! As we write in Lines 114-118, we believe that finetuning is so cheap compared to pretraining, that the additional training resources for GRIT don’t make a big difference. However, it may still matter in resource-constrained scenarios, thus we have added precise information on the GPU hours for each approach. Specifically: We used 72 GPU hours for the gen-only model, 1760 GPU hours for the emb-only model and 3072 GPU hours for GRIT. The GRIT number was already in “Appendix P: Hardware” and we have added the other two numbers there, too. Increasing the GPU hours for the gen-/emb-only model to match GRIT is unlikely to improve performance as all models have converged; Especially for the gen-only model, it would probably just lead to an excessive number of epochs. Nonetheless, we acknowledge that efficiency is a limitation and we have added more discussion on this in our “Appendix Q: Limitations and Future Work”, where we mention that packing and reusing the same samples for both the embedding and generative losses could significantly improve efficiency. We have uploaded the revised paper with the resource numbers and additional discussion - thank you for bringing this up!\n\n**Other base models:** We experimented with different base models (Llama2 and GPT-J) in Appendix A, Table 5, where we found that the approach works just as well but Mistral delivers better performance. In addition, we have also finetuned Mixtral using GRIT. All of these variants will be open-sourced."}, {"comment_id": "GT3ETWeGyl", "replyto": "YVi200zIHO", "author_type": "authors", "reviewer": null, "comment": "Thank you for your detailed review and highlighting the extensive experiments!\n\n**Scalability:** GritLM-7B is faster than a 7B generative model + a tiny embedding model for RAG when using the caching techniques we introduce. This is because the caching techniques (e.g. doc caching) will only require a single forward pass of GritLM-7B at inference, while in the other case, a forward pass for both the 7B model and the tiny embedding model is required. Without the caching techniques, speed indeed matters. We like your idea of using an intermediate layer and would expect it to lead to a performance drop while improving speed. In fact, we performed a similar experiment to reduce storage costs in Appendix A, Table 5 (e), where we find that we can downproject the embeddings to a 4x smaller dimension (->1024) at a small reduction in performance. Similarly, if we cannot use caching, we could increase speed 2x by taking the output of the middle layer at a slight reduction in performance. We have added a short note on this in Appendix A, thanks a lot for bringing this to our attention!"}, {"comment_id": "dBVDY6teRK", "replyto": "IyvHbxSJN5", "author_type": "authors", "reviewer": null, "comment": "Thank you for your detailed review. We are glad that you think the model will be useful and the work is well-positioned!\n\n**Mixed results:** As highlighted in the text we generally recommend doc caching (or query caching) but not the combined doc-query / query-doc caching mechanisms. We mostly present the query-doc and doc-query variants to inspire future work and are working on improving their performance in follow-up work. In Figure 5, we show that caching reduces latency on GPUs by around half compared to traditional RAG, which can be quite significant in time-sensitive applications. We note that the speed-up from doc (query) caching correlates directly with the length of documents (queries). E.g. for a book retrieval service where books are retrieved given user queries and each book has on the order of 10,000 or more tokens, the speed-up via doc caching would be significantly more than 2x, probably closer to 10x (depending on the query lengths). We ran RAG with a smaller model in Table 4, specifically, we ran using BGE as the embedding model which we also compare retrieval performance with in Table 1. The generative model is still GritLM-7B. Below are the match scores on NQ:\n\n- BGE Large 0.34B: 10.39\n- BGE Base 0.11B: 10.31\n- BGE Small 0.03B: 10.17\n\nFrom the paper:\n- GritLM 7B: 30.50\n\nWe find performance to be significantly worse than with GritLM. Based on a manual inspection of samples, it appears that the embedding models commonly retrieve irrelevant passages that confuse the generative model. There may be other smaller embedding models or other generative models that may perform better, but overall we expect the RAG performance to be a function of the embedding and generative performance of the individual components (e.g. if an embedding model performs better than GritLM, we would expect it to lead to better RAG performance; BGE generally does not perform better on embedding as shown in Table 1). We have added this in Appendix F, thank you for raising it!\n\n**Modularity:** This is an interesting topic, thanks for bringing it up! One can only use the embedding/generative part of GritLM, thus it is still modular. However, in that case, some advantages to having the unified model are gone, such as e.g. query caching. The doc caching technique we introduce, however, still works even if embedding and generative models are separate. In that case, however, the entire corpus needs to be passed through the generative model once during index construction. From the compute perspective, retraining a GRIT model can be cheaper than a traditional RAG model. For GRIT, the pretraining and finetuning is both done using the same model, whereas for traditional RAG models, the embedding and generative model need to be pretrained and finetuned separately thus incurring more compute. We have added a note on this in the paper by rephrasing the end of the introduction, thanks for bringing it up!\n\n**More recent work:** Thank you for pointing us to these great works. We have added citations to them and several other recent embedding papers. Please let us know if there are any other works we should be citing."}, {"comment_id": "int62mVV1M", "replyto": "BC4lIvfSzv", "author_type": "authors", "reviewer": null, "comment": "We thank all reviewers for their detailed reviews and great feedback! Below is a summary of all changes we have made to the paper in a new uploaded revision:\n1. Added RAG results with BGE in “Appendix F: Additional RAG results” in response to Reviewer XQzj.\n2. Rephrased the end of the Introduction to better motivate that GritLM requires less compute than separate generative and embedding models when considering pretraining thanks to a pointer from Reviewer XQzj.\n3. Added citations to more recent work in Section 3.2 thanks to pointers from Reviewer XQzj.\n4. Added more discussion on the potential speed-performance trade-off of using a smaller and faster embedding model by using the embedding from intermediate layers of GritLM in “Appendix A: Ablations” together with our embedding head ablation that explores the cost-performance trade-off of smaller embedding dimensions.\n5. Added resources used by Gen.-only and Emb.-only baselines in “Appendix P: Hardware” thanks to the comment by Reviewer 2rfu.\n6. Elaborated more on training discussion and potential avenues for improvement in “Appendix Q: Limitations and Future Work” in response to Reviewer 2rfu.\n\nOverall, we are glad reviewers have found the paper to be well-positioned and the methods to be original and novel. Reviewers have also pointed to the strong performance of the model and its usefulness. There was a lot of interest in using GRIT for RAG and the caching variants proposed - We are excited about further pushing these approaches together with the broader community."}], "meta_review": {"metareview": "Previous embedding models and generative models are typically learned separately. This paper proposes to learn them together through massive multi-task training and different tasks are separated through instructions. Experimental results are strong, demonstrating the performance of this joint model can match the best performance from both worlds. One unique advantage of this model is the improved efficiency in RAG applications, where the model can reuse the encodings of the query in an RAG pipeline. The reviewers unanimously vote by acceptance of this paper with scores 6,6,8,8.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "BM9qfolt6p", "forum_url": "https://openreview.net/forum?id=BM9qfolt6p", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "LucidPPN: Unambiguous Prototypical Parts Network for User-centric Interpretable Computer Vision", "authors": ["Mateusz Pach", "Koryna Lewandowska", "Jacek Tabor", "Bartosz Michał Zieliński", "Dawid Damian Rymarczyk"], "abstract": "Prototypical parts networks combine the power of deep learning with the explainability of case-based reasoning to make accurate, interpretable decisions. They follow the this looks like that reasoning, representing each prototypical part with patches from training images. However, a single image patch comprises multiple visual features, such as color, shape, and texture, making it difficult for users to identify which feature is important to the model.\n\nTo reduce this ambiguity, we introduce the Lucid Prototypical Parts Network (LucidPPN), a novel prototypical parts network that separates color prototypes from other visual features. Our method employs two reasoning branches: one for non-color visual features, processing grayscale images, and another focusing solely on color information. This separation allows us to clarify whether the model's decisions are based on color, shape, or texture. Additionally, LucidPPN identifies prototypical parts corresponding to semantic parts of classified objects, making comparisons between data classes more intuitive, e.g., when two bird species might differ primarily in belly color.\n\nOur experiments demonstrate that the two branches are complementary and together achieve results comparable to baseline methods. More importantly, LucidPPN generates less ambiguous prototypical parts, enhancing user understanding.", "keywords": ["xai", "interpretability", "prototypical parts"], "primary_area": "interpretability and explainable AI", "pdf_url": "https://openreview.net/pdf?id=BM9qfolt6p", "decision": "Accept (Poster)", "num_reviews": 4, "num_discussions": 22, "reviews": [{"review_id": "Kydg8p3P6R", "reviewer": "Reviewer_Zw4S", "rating": 6, "confidence": 3, "soundness": 2, "presentation": 2, "contribution": 2, "summary": "Summary Of Contributions:\n1.Introduction of LucidPPN: This novel architecture separates color features from other visual components during inference, enabling clearer identification of feature importance in the decision-making process.\n2.Consistent Object-Part Mapping: A mechanism ensures that prototypes within each class consistently correspond to the same object parts, improving interpretability.\n3.Enhanced Visualization Method: A more intuitive visualization type is introduced, optimized for fine-grained classification.\n4.Comprehensive Analysis: The paper provides an in-depth examination of LucidPPN's usefulness and limitations, particularly identifying cases where color may or may not be a critical feature in fine-grained classification.", "strengths": "1.The LucidPPN in the paper consists of two branches, one for color and the other for shape/texture, which effectively decouples different features. This method can reduce the ambiguity of traditional prototype networks and enable users to better understand the reasons behind the model's decisions.\n\n2.Compared to existing methods, LucidPPN achieves a more detailed analysis of Prototypical Parts, making it easier for users to understand the features that the model is focusing on.\n\n3.Through user studies, it was proven that the explanations provided by LucidPPN are clearer and easier for users to understand than those of other models such as PIP-Net. This empirical result helps to enhance the persuasiveness of the method.", "weaknesses": "Weakness\n\n1.The Section 3 has a lot of paragraphs but lacks subheadings, making it difficult to follow the logical flow of the different parts.\n\n2.There was no noticeable advantage in accuracy. The model was compared on four datasets in total, and its accuracy was lower than that of PIP-Net on two of the datasets, especially on the CUB dataset, where its accuracy was lower than that of all three methods, and no explanation was given for this gap.", "questions": "Concerns:\n1.It is recommended to add subheadings to each key step or method description to make it easier for readers to understand and locate the content.\n\n2.Consider further improving the accuracy of LucidPPN to enhance its explainability while maintaining a minimal loss of performance."}, {"review_id": "USxI5SFIHA", "reviewer": "Reviewer_na4o", "rating": 8, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "In this paper, the authors proposed a Lucid Prototypical Parts Network (LucidPPN), a novel prototypical parts network that separates color prototypes from other visual features. A LucidPPN has two branches: a ShapeTexNet and a ColorNet. Given an input image, the ShapeTexNet is a convolutional neural network (CNN) that takes a gray-scale version of the image as input and outputs a set of feature maps, and the ColorNet is another CNN that takes a down-sampled version of the image as input and outputs another set of feature maps. Since the last layer of both the ShapeTexNet and the ColorNet is a 1x1 convolutional layer with KM filters, we can interpret the last convolutional layer as a prototype layer with KM prototypes, where K is the number of prototypes per class and M is the number of classes, and the output of the last layer as prototype activation maps. The output feature maps (aka prototype activation maps) from the ShapeTexNet and the ColorNet are fused using element-wise products, and then max-pooled to yield a prototype similarity score for each prototype. The predicted class score is simply an average of the prototype similarity scores over all prototypes of the class. In a LucidPPN, each of the K prototypes in each of the M classes corresponds to consistent image parts (e.g., the first prototype of each class corresponds to head of a bird, etc.). This is achieved by aligning the fused output feature maps (prototype activation maps) with segmentation masks produced by a pre-trained PDiscoNet (an object part segmentation model) using a prototypical-object part correspondence loss. In addition to a loss function to improve the classification accuracy of the entire model, the authors also introduced a loss function to improve the classification accuracy of the ShapeTexNet alone and to disentangle color from other visual features. The authors evaluated their LucidPPN models on 4 commonly used fine-grained classification benchmarks (CUB-200-2011, Stanford Cars, Stanford Dogs, and Oxford Flowers), and found that their LucidPPN models achieved competitive test accuracy compared to other interpretable models. The authors also did a user study to evaluate the influence of disentangling color from other visual attributes on interpretability.", "strengths": "- Originality: The paper introduced a novel idea of disentangling color from shape and texture, so that the visual attribute of each prototype is more clearly defined (compared to prior work).\n- Quality: The authors did show that their LucidPPN could maintain a reasonable accuracy while providing less ambiguous prototypes.\n- Clarity: The paper is clearly written.\n- Significance: Interpretability is a significant area of research in machine learning.", "weaknesses": "- Quality: There seems to be no prototype projection in this work. Without prototype projections, it is unclear if the prototypes can be faithfully visualized using training images (because the closest training images to a prototype could still be far away from the prototype in the latent space).\n- Clarity: Page 6, Lines 314-315. I am confused as to whether you are aligning the segmentation masks from PDiscoNet with prototype activation maps from the ShapeTexNet or the aggregated feature maps.", "questions": "- My main concern is that I did not see prototype projections in this work. Without prototype projections, how could you conclusively visualize prototypes using training images? The closest training images to a prototype could still be far away from the prototype in the latent space.\n- During training, are the segmentation masks from PDiscoNet aligned with the ShapeTexNet feature maps or the aggregated feature maps? \n- I am also not clear as to why binary cross entropy is used instead of multi-class cross entropy for training?"}, {"review_id": "Ck1szKWb7j", "reviewer": "Reviewer_9gD3", "rating": 6, "confidence": 5, "soundness": 4, "presentation": 4, "contribution": 4, "summary": "This paper propose to disentangle color prototypes from other visual features in ProtoPNets, by introducing a novel network architecture, named LucidPPN. The proposed method clarifies feature importance and aligns prototypical parts with object semantics, enhancing interpretability. Experiments show that LucidPPN achieves competitive accuracy while producing clearer and less ambiguous explanations for users.", "strengths": "* This paper explicitly decouple prototypes into specific semantic types, such as color and shape, whereas existing methods have overlooked this aspect of information. And I believe this paper could serve as a significant inspiration for future research.\n* This paper provides sufficient cases and visualizations to validate the semantic information of the learned prototypes.\n* The paper is well-written and easy to follow.\n* The authors provide code for reproducibility check.", "weaknesses": "[Major]\n\n1. **Quantitative evaluation of the interpretability:** In previous work, Huang et al. [1] have discussed the inconsistency of traditional ProtoPNets. Does this issue exists within the proposed method? Please provide qualitative or quantitative evaluations.\n2. **Experiments:** Please supplement the missing results for baseline methods on datasets like DOGS and FLOWERS in Table 1, as adapting to these datasets, which were not covered in the original papers, seems quite straightforward.\n3. **Experiments:** This paper only implement the proposed method on several CNNs. However, vision Transformers are introduced to the realm of CV for several years, and have also been implemented as the backbone of ProtoPNets [2]. Please provide additional experimental results using ViT [3-4] or even CLIP [5] as the backbone.\n4. **Related Work:** In XAI, introducing human understandable semantics as evidences for prediction has been explored by concept bottleneck models (CBMs) [6]. What is the relationship between the proposed method and CBMs. Can concepts be introduced into the realm of ProtoPNet for higher interpretability?\n\n\n> [1] Huang, Qihan, et al. \"Evaluation and improvement of interpretability for self-explainable part-prototype networks.\" Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023.\n> \n> [2] Xue, Mengqi, et al. \"Protopformer: Concentrating on prototypical parts in vision transformers for interpretable image recognition.\" arXiv preprint arXiv:2208.10431 (2022).\n>\n> [3] Dosovitskiy, Alexey, et al. \"An image is worth 16x16 words: Transformers for image recognition at scale.\" International Conference on Learning Representations. 2021.\n>\n> [4] Touvron, Hugo, et al. \"Training data-efficient image transformers & distillation through attention.\" International conference on machine learning. PMLR, 2021.\n>\n> [5] Radford, Alec, et al. \"Learning transferable visual models from natural language supervision.\" International conference on machine learning. PMLR, 2021.\n>\n> [6] Koh, Pang Wei, et al. \"Concept bottleneck models.\" International conference on machine learning. PMLR, 2020.\n\n[Minor]\n\n1. **Experiments:** What is the computational cost of inference and training? Please provide a comparison with baseline methods, including metrics such as training time, FLOPs, and memory usage.", "questions": "My questions are listed in \"Weaknesses\" section."}, {"review_id": "FIunVysmJR", "reviewer": "Reviewer_9WXj", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "The manuscript presents the Lucid Prototypical Parts Network (LucidPPN), designed to identify key visual features—specifically color, shape, and texture—based on the prototypical parts networks. The proposed LucidPPN utilizes a non-color branch to process grayscale images alongside a color branch that focuses on color information, thereby clarifying the model's decisions based on these visual attributes. Experimental results demonstrate that the proposed method exhibits advantages over baseline approaches and generates more interpretable prototype parts.", "strengths": "(1)\tThe methodology is well-structured, with intuitive design in the separation of color and non-color network branches, making it accessible and easy to understand.\n\n(2)\tThe experiments are comprehensive, with a substantial number of visualization results provided in the appendix, enhancing the manuscript's depth.", "weaknesses": "(1)\tWhile analyzing \"color,\" \"shape,\" and \"texture\" offers a valuable perspective, these features have been extensively studied in the field of visual perception. Given that the shallow layers of deep networks are capable of extracting low-level features, the necessity for additional processing and analysis from prototypical parts raises concerns on the novelty and contribution of this work.\n\n(2)\tThe improvements demonstrated by the proposed method appear to be limited because its performance on some instances is lower than that of the compared methods. For example, in Table 1, the proposed method underperforms other prototypical parts networks on some datasets. While color, shape, and texture are indeed significant visual features in interpretability, they may not be sufficiently critical in this context.\n\n(3)\tThe organization of the experimental section appears somewhat unbalanced. While the results and visualizations presented are commendable, an excessive amount of content is relegated to the appendix, which may hinder the reader’s ability to grasp key insights and maintain a coherent narrative.", "questions": "(1)\tThe manuscript focuses on interpretability through the lenses of color, shape, and texture. However, other low-level features such as edges, contrast, and spatial frequency are also relevant. Have alternative low-level features also been considered in the analysis?\n\n(2)\tThe datasets utilized in the experiments are relatively small in size. How will the proposed method perform on larger datasets, such as ImageNet? Some insights into performance scalability would be beneficial.\n\n(3)\tThe manuscript primarily presents visualization results for the prototypical parts identified by the proposed method. How do these results compare with other prototypical parts-based models? A comparative analysis would enhance the understanding of the method's effectiveness.\n\n(4)\tIn global feature visualizations, such as Figure 14, the manuscript illustrates the ability of the proposed method to detect shape and color. How does this compare with traditional edge detection operators (e.g., Sobel) for shape extraction and color feature extraction methods (e.g., color histogram)? Additionally, how does it compare with the direct visualizations of shallow layer attention to texture and color using techniques like Grad-CAM?"}], "discussions": [{"comment_id": "wS4mEqwhu9", "replyto": "ytXUeATpfQ", "author_type": "authors", "reviewer": null, "comment": "We sincerely appreciate your reassessment and are thrilled that the revisions addressed your concerns. Thank you for your time and constructive feedback!"}, {"comment_id": "ytXUeATpfQ", "replyto": "0kSeWoa1Ws", "author_type": "reviewer", "reviewer": "Reviewer_na4o", "comment": "Thank you for your clarification. I have raised my score to 8: accept, good paper."}, {"comment_id": "0kSeWoa1Ws", "replyto": "DayXG6p7hS", "author_type": "authors", "reviewer": null, "comment": "Thank you for your valuable feedback and for raising the score.\n\nIt is indeed interesting to see the effects of pruning the prototypes with less faithful representation (those with resemblance scores < 0.9). Therefore, we investigate it in the newly added section \"Pruning prototypes with less faithful visualizations\" of the Supplementary Materials.\nIt contains Table 13, which shows that LucidPPN accuracy after pruning drops only by around 2% (from 81.6% to 79.3%). However, interestingly, the accuracy stays the same for $L_C=0.05$.\nIt suggests that combination of $L_C$ and pruning allows to enforce high resemblance scores (>0.9) of visualized patches without sacrificing on the accuracy.\n\nWhen it comes to the discussion on choosing sigmoid and binary cross entropy, we added it as section \"Reason behind using the Binary Cross Entropy with Sigmoid instead of the Cross Entropy with Softmax\" of the Supplementary Materials."}, {"comment_id": "DayXG6p7hS", "replyto": "USxI5SFIHA", "author_type": "reviewer", "reviewer": "Reviewer_na4o", "comment": "Thank you for your clarification. I have raised my score to 6: marginally above the acceptance threshold.\n\nTo further improve the paper, it would be interesting to see what happens if you remove all prototypes whose self-resemblance score is not close to 1.\n\nAlso, it would be helpful to add a discussion on why you chose sigmoid (instead of softmax) and binary cross entropy (instead of multi-class cross entropy) for your method."}, {"comment_id": "WjPFVoNe2s", "replyto": "euB784j0rm", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank you for your valuable insights, which have significantly contributed to enhancing our manuscript."}, {"comment_id": "yNO5ePaqle", "replyto": "4JAn53qJ06", "author_type": "authors", "reviewer": null, "comment": "Thank you for your valuable feedback and for raising the score of our manuscript. We are pleased to hear that we were able to address most of your concerns.\n\nTo better clarify the unique advantages of our approach compared to shallow network visualization methods, we have revised the section \"Usage of low-level vision features for image classification\" in the Related Works.\n\nWe greatly appreciate your insights, which have helped us improve the manuscript."}, {"comment_id": "LfV7E91rFL", "replyto": "GMan0iRRXx", "author_type": "authors", "reviewer": null, "comment": "To answer this question, we provide an additional section \"Faithfullness of patch visualizations\" in the Supplementary Materials. It contains Figure 20 with a distribution of the sigmoid function values obtained for patches used in prototype visualization. For LucidPPN trained on the CUB dataset (blue curve), 61.04% of those patches have values above 0.9, which indicates that prototype visualizations are relatively faithful.\n\nMoreover, we show that higher faithfulness can be obtained when training with an additional loss component $L_C$ that punishes the model if the sigmoid function value for a given prototype is smaller than $1$ for all samples in the batch (see green and yellow curves in Figure 20)."}, {"comment_id": "4JAn53qJ06", "replyto": "tdH6OUUzZk", "author_type": "reviewer", "reviewer": "Reviewer_9WXj", "comment": "Thanks for the prompt response. \n\nThe concerns regarding accuracy have been largely addressed. We agree with the authors on their explanation that by decoupling color, shape, and texture in the early stages and fusing them later for more accurate interpretation, there is an inherent trade-off in terms of accuracy.\n \nRegarding concerns about the contribution, some doubts still remain. While we acknowledge that the authors have made valuable contributions to the prototype network by introducing inherently interpretable mechanisms at the low-level feature level, similar research [1] on shallow network visualizations has already provided substantial insights into the impact of features such as color, shape, and texture on classification network results. This somewhat limits the scope of the contribution of the proposed method in the context of low-level feature interpretation. The authors are still expected to further clarify the unique advantages of their approach compared to shallow network visualization in the revised version, which would definitely better highlight the contribution and innovation of the manuscript.\n\nBut overall, after the above rounds of feedback and discussion, I think this work have adequate technical merits and contributions, and I would be happy to raise the score to 6 (marginally above accept).\n \n1. Zeiler, M. D. (2014). Visualizing and Understanding Convolutional Networks. In European Conference on Computer Vision."}, {"comment_id": "euB784j0rm", "replyto": "0RuuCW6gws", "author_type": "reviewer", "reviewer": "Reviewer_Zw4S", "comment": "Thanks for the response ! I got the explanation for limited improvement. Considering the overall quality, I vote for borderline accept."}, {"comment_id": "GMan0iRRXx", "replyto": "USxI5SFIHA", "author_type": "reviewer", "reviewer": "Reviewer_na4o", "comment": "Thank you for your response!\n\nI see why you chose sigmoid (instead of softmax) to be applied to the latent feature maps.\n\n\"As a result, one can easily verify if the image patches selected for visualization are faithful because such patches should have a resemblance score close to 1.\"\n\nIn order to establish that your prototype visualizations are faithful, did you verify that every image patch selected for the visualization of a prototype has a self-resemblance score close to 1 with the prototype itself? Is it true? What would you do if you found a prototype whose visualized patch did not have a self-resemblance score close to 1?"}, {"comment_id": "0RuuCW6gws", "replyto": "X6bK7F0o2z", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer Zw4S,\n\nAs the deadline for the discussion period is approaching quickly, we would like to kindly remind the reviewer that we are waiting for your response.\n\nIn particular, we have provided point-by-point responses to all of your questions to address your concerns and provided the revision that reflects such changes. Therefore, your timely feedback and change in the score if applicable would be highly appreciated.\n\nBest,\n\nAuthors"}, {"comment_id": "bHyXYonvpo", "replyto": "wG4iN4El3q", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer na4o,\n\nAs the deadline for the discussion period is approaching quickly, we would like to kindly remind the reviewer that we are waiting for your response.\n\nIn particular, we have provided point-by-point responses to all of your questions to address your concerns and provided the revision that reflects such changes. Therefore, your timely feedback and change in the score if applicable would be highly appreciated.\n\nBest,\n\nAuthors"}, {"comment_id": "LRwQ7ksX5U", "replyto": "9o7ElZbTFg", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer 9gD3,\n\nAs the deadline for the discussion period is approaching quickly, we would like to kindly remind the reviewer that we are waiting for your response.\n\nIn particular, we have provided point-by-point responses to all of your questions to address your concerns and provided the revision that reflects such changes. Therefore, your timely feedback and change in the score if applicable would be highly appreciated.\n\nBest, \n\nAuthors"}, {"comment_id": "tdH6OUUzZk", "replyto": "Uqw4oNwb30", "author_type": "authors", "reviewer": null, "comment": "As noted in the Review, we agree that there are techniques to visualize the shallow layers of neural networks, and these methods provide a certain level of interpretability. However, LucidPPN focuses on high-level concepts (represented with prototypical parts) from deeper layers that encode complex information. Our goal is to enhance the transparency of these concepts by disentangling the color from the remaining visual features. While there is a wealth of research on prototypical parts-based interpretability (e.g., Chen et al., 2019; Nauta et al., 2021; 2023b; Rymarczyk et al., 2021; 2022; 2023; Wang et al., 2024), none of these works aim to introduce an inherently interpretable mechanism into the network at the level of low-level visual features.\n\nRegarding accuracy, we note that the results for the late fusion method (LucidPPN) and the earliest fusion method (single branch) are provided in Table 4. For the CUB dataset, LucidPPN scored 81.5% while single branch 86.6%. To investigate the influence of earlier fusion on the accuracy, we also experimented with fusion applied after the second block of ConvNeXt Tiny. This configuration achieved an accuracy of 84.1%, indicating that earlier fusion can indeed enhance the model's accuracy. However, this comes at the cost of explanation granularity: when fusion occurs earlier, it becomes difficult to disentangle the influence of shape with texture and color on prototypical parts.\n\nWe kindly ask the Reviewer to evaluate whether our responses address their concerns. If not, we would appreciate clarification on two points. First, regarding shallow layer visualization, could you specify the techniques or works you had in mind so we can reference them more precisely? Second, in terms of fusion analysis, what specific types of evaluations or comparisons would you find most informative? If no further concerns remain, we kindly request a reevaluation of your score."}, {"comment_id": "Uqw4oNwb30", "replyto": "AIyjYRqj2A", "author_type": "reviewer", "reviewer": "Reviewer_9WXj", "comment": "Thanks for the detailed response from the authors. While some concerns have been addressed, there remain several points that may require further clarification.\n\n1.\tAbout the contributions:\nThe authors' response on the novelty and contribution of analyzing \"color,\" \"shape,\" and \"texture\" didn’t seem to fully resolve the concern. Visualizing the shallow layers of the neural network has indeed demonstrated the network's focus on the features of \"color,\" \"shape,\" and \"texture,\" and has visualized \"shape\" and \"texture\" more clearly, which contributes to making the network processing more transparent from these perspectives. However, the authors state that \"conventional neural networks often entangle visual features in ways that make it difficult to disentangle and present them in an understandable format for users\", which may not be entirely consistent with the explanatory results of existing methods that visualize shallow layers of networks. These existing methods also provide a certain level of interpretability. \n\n2.\tAbout performance:\nWe agree with the authors' explanation that the performance decrease may due to the delayed fusion of texture and shape features with color. This delay prevents the network from effectively associating shape and texture with color in the early layers, leading to incomplete detection of certain features. However, there is a lack of experimental results and analysis to support this explanation. For instance, comparisons between shallow fusion and deep fusion under the same conditions would help verify whether the accuracy drop is indeed caused by the location of fusions. Given that the performance decline is likely due to insufficient attention to important features (such as texture and shape), the limitations of the interpretability method itself, and other factors, further experiments related to this topic would be beneficial. A more comprehensive analysis of the limitations and the true advantages of the method may be expected."}, {"comment_id": "AIyjYRqj2A", "replyto": "FIunVysmJR", "author_type": "authors", "reviewer": null, "comment": "**Q1. The manuscript focuses on interpretability through the lenses of color, shape, and texture. However, other low-level features such as edges, contrast, and spatial frequency are also relevant. Have alternative low-level features also been considered in the analysis?**\n\nThis work represents the first step toward disentangling low-level features. As mentioned in the Limitations Section, we plan to explore low-level features in more detail in future work. For now, we focus on extracting color information and correlating prototypical parts with semantic object parts. These contributions already make the work comprehensive. Introducing additional low-level feature extraction and integration at this stage would complicate the model further and make it more difficult to communicate.\n\n\n**Q2. The datasets utilized in the experiments are relatively small in size. How will the proposed method perform on larger datasets, such as ImageNet? Some insights into performance scalability would be beneficial.** \n\nMost benchmarking of prototypical part-based methods has been conducted on fine-grained datasets such as CUB and Stanford Cars Chen et al. (2019); Nauta et al. (2021; 2023b); Rymarczyk et al. (2021; 2022; 2023); Wang et al. (2024). Scaling these architectures to ImageNet-sized datasets is an important but orthogonal research direction that remains unsolved. However, to assess whether LucidPPN generalizes to broader classification tasks (beyond fine-grained datasets), we added results on PartImageNet He et al. (2022) in the Supplementary Materials. For this dataset, LucidPPN achieves an accuracy of 84.1%, outperforming PIPNet, which achieves 82.8%.\n\n**Q3. The manuscript primarily presents visualization results for the prototypical parts identified by the proposed method. How do these results compare with other prototypical parts-based models? A comparative analysis would enhance the understanding of the method’s effectiveness.**\n\nIn Supplementary Figures 15-19, we added explanations from various prototypical-part-based methods. Additionally, our user studies demonstrate the effectiveness of our explanations from the user’s perspective.\n\n\n**Q4. In global feature visualizations, such as Figure 14, the manuscript illustrates the ability of the proposed method to detect shape and color. How does this compare with traditional edge detection operators (e.g., Sobel) for shape extraction and color feature extraction methods (e.g., color histogram)? Additionally, how does it compare with the direct visualizations of shallow layer attention to texture and color using techniques like Grad-CAM?**\n\nEdge detectors and color histograms are not trainable and do not represent high-level features, unlike prototypical parts. Even our low-level color prototypical representation captures a higher-level concept, such as a ”red tail,” which is then further decomposed into its components—color (red) and shape/texture (tail). Regarding GradCAM, we want to highlight that post-hoc methods can often be unreliable, as demonstrated in multiple studies Adebayo et al. (2018); Kim et al. (2022); Rudin (2019); Tomsett et al. (2020). This underscores the need for developing inherently interpretable models Rudin (2019); Rudin et al. (2022) such as our LucidPPN."}, {"comment_id": "SbOhGqyso3", "replyto": "FIunVysmJR", "author_type": "authors", "reviewer": null, "comment": "**W1. While analyzing ”color,” ”shape,” and ”texture” offers a valuable perspective, these features have been extensively studied in the field of visual perception. Given that the shallow layers of deep networks are capable of extracting low-level features, the necessity for additional processing and analysis from prototypical parts raises concerns on the novelty and contribution of this work.** \n\nIt is true that shallow layers of neural networks are capable of extracting low-level features. The goal of LucidPPN, however, is to make this processing more transparent to the user, which aligns with the broader objective of inherently interpretable models Chen et al. (2019); Rudin (2019). Thanks to LucidPPN we can analyze which colors were important for classification. Conventional neural networks often entangle visual features in ways that make it difficult to disentangle and present them in an understandable format for users. This is why we believe our work offers a novel contribution to the field of interpretable AI, particularly in the context of prototypical parts.\n\n\n**W2. The improvements demonstrated by the proposed method appear to be limited because its performance on some instances is lower than that of the compared methods**\n\nWe respond to this comment in **shared remarks** in paragraphs *The improvements demonstrated by the proposed method appear to be limited because its performance on some instances is lower than that of the compared methods.* and *There was no noticeable advantage in accuracy. Why?*
\n\n**W3. The organization of the experimental section appears somewhat unbalanced. While the results and visualizations presented are commendable, an excessive amount of content is relegated to the appendix, which may hinder the reader’s ability to grasp key insights and maintain a coherent narrative.**\n\nWe agree and in the revised version of the manuscript, we have reorganized the experimental section. However, due to space constraints and to adhere to the ICLR template, some content has been moved to the appendix."}, {"comment_id": "9o7ElZbTFg", "replyto": "Ck1szKWb7j", "author_type": "authors", "reviewer": null, "comment": "**W1. In previous work, Huang et al. (2023) have discussed the inconsistency of traditional ProtoPNets. Does this issue exist within the proposed method? Please provide qualitative or quantitative evaluations.** \n\nTo answer this question we have calculated consistency and stability of our method and compared it in Supplementary Table 11. One can observe that LucidPPN achieves a consistency score comparable to the method proposed by Huang et al. (2023) while outperforming other prototypical-parts-based methods. Regarding stability, LucidPPN demonstrates results on par with other methods. This improvement is likely due to the correspondence of prototypical parts to semantic parts of the classified objects.\n\n**W2. Please supplement the missing results for baseline methods on datasets like DOGS and FLOWERS in Table 1, as adapting to these datasets, which were not covered in the original papers, seems quite straightforward.** \n\nThank you for your comment. We have added results for baselines on additional datasets, except for the ProtoTree. This exception is due to the model’s tendency to exhibit instability during training on these datasets. Upon reviewing the issues section of the ProtoTree GitHub repository [https://github.com/M-Nauta/ProtoTree/issues](https://github.com/M-Nauta/ProtoTree/issues), we noticed that others have faced similar challenges in applying this model to different datasets.\n\n**W3. This paper only implement the proposed method on several CNNs. However, vision Transformers are introduced to the realm of CV for several years, and have also been implemented as the backbone of ProtoPNets Xue et al. (2022). Please provide additional experimental results using ViT or even CLIP as the backbone.**\n\nThank you for your comment. Unfortunately, we are unable to run LucidPPN with a ViT backbone for the following reasons: \n* **Incompatibility with PIPNet-Based Prototypical Part Definition**: ProtoPFormer Xue et al. (2022) is built on the ProtoPNet-based definition of prototypical parts Chen et al. (2019), whereas our method relies on the PIPNet-based definition Nauta et al. (2023b). Currently, there is no adaptation of the PIPNet-based definition for the ViT backbone. This limitation is why we opted for the ConvNeXt backbone, which has demonstrated comparable performance to ViTs. \n* **Challenges with Self-Attention**: Adapting a ViT backbone to the PIPNet-based definition is not straightforward due to the nature of self-attention. Unlike convolutions, self-attention lacks the properties of locality and a direct correspondence between the input and feature map Chen et al. (2019). This discrepancy makes it difficult to visualize prototypical parts faithfully. \n* **Orthogonal Research Direction**: Adapting the ViT backbone to prototypical parts represents a separate research direction. Both ProtoPFormer and recent works in this area Ma et al. (2024a), which were unavailable at the time of submission, highlight that integrating a ViT backbone with prototypical parts is a non-trivial task requiring substantial architectural/training changes. For these reasons, we chose to use the ConvNeXt backbone in our work.\n\n\n**W4. In XAI, introducing human understandable semantics as evidence for prediction has been explored by concept bottleneck models (CBMs) Koh et al. (2020). What is the relationship between the proposed method and CBMs. Can concepts be introduced into the realm of ProtoPNet for higher interpretability?** \n\nThank you for pointing out the connection between CBMs and our work. I’d like to clarify that both concept bottlenecks and prototypical parts can be considered concept-based models Bontempelli et al. (2022). However, there is a key distinction: concept bottlenecks use predefined intermediate classes (named concepts) that are directly associated with the image. While, prototypical parts, aim to identify relevant classification concepts during model training without any additional labels. A potential future research direction could involve combining concept bottlenecks with prototypical parts.\n\n**w1. What is the computational cost of inference and training? Please provide a comparison with baseline methods, including metrics such as training time, FLOPs, and memory usage.**\n\nIn Supplementary Table 12 we provide information about training time, GFLOPs needed, and average memory usage during training for LucidPPN, PIPNet, ProtoPool, and ProtoPNet. One can observe that LucidPPN is faster and uses less memory than PIPNet. However, ProtoPNet and ProtoPool require much less memory to train while having longer training times."}, {"comment_id": "wG4iN4El3q", "replyto": "USxI5SFIHA", "author_type": "authors", "reviewer": null, "comment": "**Q1. My main concern is that I did not see prototype projections in this work. Without prototype projections, how could you conclusively visualize prototypes using training images? The closest training images to a prototype could still be far away from the prototype in the latent space.**\n\nOur work builds on PIPNet’s definition of prototypical parts, which is why it lacks projection, which can lead to less faithful visualizations. Despite this drawback, PIPNet-based architectures have been successfully applied in various works De Santi et al. (2024a;b); Nauta et al. (2023a), and further developed, e.g. Wang et al. (2024) improving the interpretability. \n\nMoreover, LucidPPN introduces a key difference in the definition of prototypical parts compared to PIPNet. While PIPNet employs Softmax across channels in the latent feature map, LucidPPN uses the sigmoid activation function. The sigmoid function allows each channel’s activation to be learned independently, not influenced by the relative activations of other channels. While, Softmax normalization can distort activations by emphasizing values that are only relatively high compared to others, even if they are low in absolute terms. \n\nTo build an intuition for this statement, let us consider $i$-th pixel of a feature map with activation values $z_i = [−2, 5, −0.2, −0.1]$, and $j$-th pixel with $z_j = [10, 300, 30, 10]$, the Softmax output $\\theta$ for both would be $\\theta_i = \\theta_j = [0, 1, 0, 0]$. This implies that PIPNet would treat both pixels as equally important, despite the activations differing by a factor of 60. In contrast, with sigmoid activation $\\sigma$ used in LucidPPN, the outputs would be $\\sigma_i = [0.1192, 0.9933, 0.4502, 0.4750]$ and $\\sigma_j = [1.0000, 1.0000, 1.0000, 1.0000]$, preserving the distinction in activation magnitudes. As a result, one can easily verify if the image patches selected for visualization are faithful because such patches should have a resemblance score close to 1.\n\n\n**Q2. During training, are the segmentation masks from PDiscoNet aligned with the ShapeTexNet feature maps or the aggregated feature maps?** \n\nLoss $L_D$ is applied only to the ShapeTexNet feature maps as we directly align them with masks from PDiscoNet. Indirectly, it also causes alignment of masks with the aggregated feature maps which are computed from the ShapeTexNet feature maps. To the Supplementary Materials (Figure 14), we added an image illustrating this process more concisely.\n\n\n**Q3. I am also not clear as to why binary cross entropy is used instead of multi-class cross entropy for training?** \n\nThe intuition behind BCE usage is rooted from multilabel classification. To some degree ShapeTexNet operates in a multilabel setting from the prototypical parts perspective as they may match multiple classes. Hence, to enable multiple classes having high similarity to the same prototypical parts, we use sigmoid instead of softmax when computing the feature maps. This necessitates a shift from Cross-Entropy (CE) to Binary Cross-Entropy (BCE) because CE would then solely maximize the activation of the correct class while ignoring crucial signals from negative classes. Another reason behind our choice is to make it easier to verify the faithfulness of visualizations (see the answer about prototype projection)."}, {"comment_id": "X6bK7F0o2z", "replyto": "Kydg8p3P6R", "author_type": "authors", "reviewer": null, "comment": "**1. The Section 3 has a lot of paragraphs but lacks subheadings, making it difficult to follow the logical flow of the different parts.** \n\nWe agree that Section 3 was dense. Therefore, we have revised it by introducing subsections that clarify the methodology of LucidPPN more clearly.\n\n**2. There was no noticeable advantage in accuracy. Why?**\n\nWe answer this question in **shared remarks** in paragraphs *There was no noticeable advantage in accuracy. Why?* and *The improvements demonstrated by the proposed method appear to be limited because its performance on some instances is lower than that of the compared methods.*"}, {"comment_id": "bid0A5GPFT", "replyto": "CsxZvYM0H0", "author_type": "authors", "reviewer": null, "comment": "### **References**\n\nJ. Adebayo et al. Sanity checks for saliency maps. NeurIPS 2018. \n\nA. Bontempelli et al. Concept-level debugging of part-prototype networks. arXiv 2022. \n\nC. Chen et al. This looks like that: deep learning for interpretable image recognition. NeurIPS 2019. \n\nL. A. De Santi et al. Patch-based intuitive multimodal prototypes network (pimpnet) for alzheimer’s disease classification. arXiv 2024a. \n\nL. A. De Santi et al. Pipnet3d: Interpretable detection of alzheimer in mri scans. arXiv 2024b. \n\nJ. He et al. Partimagenet: A large, high-quality dataset of parts. ECCV 2022.\n\nQ. Huang et al. Evaluation and improvement of interpretability for self-explainable part-prototype networks. ICCV 2023.\n\nS. SY Kim et al. Hive: Evaluating the human interpretability of visual explanations. ECCV 2022. \n\nP. W. Koh et al. Concept bottleneck models. ICML 2020.\n\nC. Ma et al. Interpretable image classification with adaptive prototype-based vision transformers. arXiv 2024a. \n\nC. Ma et al. This looks like those: Illuminating prototypical concepts using multiple visualizations. NeurIPS 2024b. \n\nA. Nagrani et al. Attention bottlenecks for multimodal fusion. NeurIPS 2021. \n\nM. Nauta et al. Neural prototype trees for interpretable fine-grained image recognition. CVPR 2021.\n\nM. Nauta et al. Interpreting and correcting medical image classification with pip-net. ECAI 2023a. \n\nM. Nauta et al. Pip-net: Patch-based intuitive prototypes for interpretable image classification. CVPR 2023b. \n\nC. Rudin. Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead. Nature machine intelligence 2019. \n\nC. Rudin et al. Interpretable machine learning: Fundamental principles and 10 grand challenges. Statistic Surveys 2022. \n\nD. Rymarczyk et al. Protopshare: Prototypical parts sharing for similarity discovery in interpretable image classification. ACM SIGKDD 2021. \n\nD. Rymarczyk et al. Interpretable image classification with differentiable prototypes assignment. ECCV 2022. \n\nD. Rymarczyk et al. Icicle: Interpretable class incremental continual learning. ICCV 2023\n\nR. Tomsett et al. Sanity checks for saliency metrics. AAAI 2020.\n\nBS Wang et al. Mcpnet: An interpretable classifier via multi-level concept prototypes. CVPR 2024.\n\nM. Xue et al. Protopformer: Concentrating on prototypical parts in vision transformers for interpretable image recognition. arXiv 2022."}, {"comment_id": "CsxZvYM0H0", "replyto": "BM9qfolt6p", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank the Reviewers for their positive and encouraging feedback. They have recognized that our work can *serve as a significant inspiration for future research* (R9gD3), introduces *a novel idea of disentangling color from shape and texture* (R-na4o), and addresses a *significant field of research in machine learning* (R-na4o). \n\nThe clarity and interpretability of LucidPPN have been particularly appreciated. Reviewers have highlighted that it *can reduce the ambiguity of traditional prototype networks and enable users to better understand the reasons behind the model’s decisions* (R-Zw4S) and that the model makes *it easier for users to understand the features that the model is focusing on* (R-Zw4S). Furthermore, it was emphasized that *the explanations provided by LucidPPN are clearer and easier for users to understand.* (R-Zw4S). \n\nLucidPPN has also been noted for its presentation, described as *clearly written* (R-na4o), and *easy to follow* (R-9gD3). Reviewers appreciated that *this paper provides sufficient cases and visualizations to validate the semantic information of the learned prototypes* (R9gD3) and that *the methodology is well-structured, with intuitive design* (R-9WXj). The experiments were recognized as *comprehensive, with a substantial number of visualization results.* (R-9WXj). \n\nWe have carefully addressed the Reviewers’ comments and incorporated their suggestions to strengthen our manuscript. We kindly ask the Reviewers to consider increasing their rating if they find our responses satisfactory. Responses to remarks shared among Reviewers are provided below, and followed by replies to specific comments. Additionally, we have attached a revised version of the work with all changes highlighted in blue.\n\n\n### **Shared remarks**\n\n**(R-Zw4S, R-jXW9) There was no noticeable advantage in accuracy. Why?**\nThe primary goal of this work was not to surpass PIPNet in accuracy but to reduce the ambiguity of prototypical parts through color disentanglement and correspondence to semantic parts of classified objects. Multiple works Adebayo et al. (2018); Huang et al. (2023); Kim et al. (2022); Ma et al. (2024b) show that explanations are ambiguous for a user and can cause overconfidence. That is why one should consider user study as the main result that shows that LucidPPN enabled significantly better user scores than PIPNet, even on the CUB dataset where LucidPPN’s accuracy was lower. \n\n**(R-Zw4S, R-jXW9) The improvements demonstrated by the proposed method appear to be limited because its performance on some instances is lower than that of the compared methods.**\nThe accuracy drop stems from a late-stage fusion of texture and shape features with color. This delay prevents the network from correlating shape and texture with color effectively in earlier layers, causing some features to go undetected. It can be seen as a multimodal scenario, where early fusion (in our case PIPNet) achieves higher accuracy than late fusion (in our case LucidPPN), just like in Nagrani et al. (2021). Nonetheless, increasing the disambiguation of prototypical parts can improve accuracy over PIPNet, like in 3 of the 5 datasets, including PartImageNet added in the rebuttal phase."}], "meta_review": {"metareview": "In this paper, a Lucid Prototypical Parts Network (LucidPPN) prototypical parts network is presented, which has two branches: a ShapeTexNet and a ColorNet. Given an input image, the ShapeTexNet is a convolutional neural network (CNN) that takes a gray-scale version of the image as input and outputs a set of feature maps, and the ColorNet is another CNN that takes a down-sampled version of the image as input and outputs another set of feature maps. Evaluation is carried out on 4 commonly used fine-grained classification benchmarks (CUB-200-2011, Stanford Cars, Stanford Dogs, and Oxford Flowers), and found the LucidPPN models achieved competitive test accuracy compared to other interpretable models.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "d4qMoUSMLT", "forum_url": "https://openreview.net/forum?id=d4qMoUSMLT", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "Efficient Training of Neural Stochastic Differential Equations by Matching Finite Dimensional Distributions", "authors": ["Jianxin Zhang", "Josh Viktorov", "Doosan Jung", "Emily Pitler"], "abstract": "Neural Stochastic Differential Equations (Neural SDEs) have emerged as powerful mesh-free generative models for continuous stochastic processes, with critical applications in fields such as finance, physics, and biology. Previous state-of-the-art methods have relied on adversarial training, such as GANs, or on minimizing distance measures between processes using signature kernels. However, GANs suffer from issues like instability, mode collapse, and the need for specialized training techniques, while signature kernel-based methods require solving linear PDEs and backpropagating gradients through the solver, whose computational complexity scales quadratically with the discretization steps. In this paper, we identify a novel class of strictly proper scoring rules for comparing continuous Markov processes. This theoretical finding naturally leads to a novel approach called Finite Dimensional Matching (FDM) for training Neural SDEs. Our method leverages the Markov property of SDEs to provide a computationally efficient training objective. This scoring rule allows us to bypass the computational overhead associated with signature kernels and reduces the training complexity from $O(D^2)$ to $O(D)$ per epoch, where $D$ represents the number of discretization steps of the process. We demonstrate that FDM achieves superior performance, consistently outperforming existing methods in terms of both computational efficiency and generative quality.", "keywords": ["neural stochastic differential equations", "Markov process", "scoring rule"], "primary_area": "learning on time series and dynamical systems", "pdf_url": "https://openreview.net/pdf?id=d4qMoUSMLT", "decision": "Accept (Poster)", "num_reviews": 4, "num_discussions": 17, "reviews": [{"review_id": "JXEcfWl4J4", "reviewer": "Reviewer_GLJM", "rating": 3, "confidence": 3, "soundness": 2, "presentation": 1, "contribution": 2, "summary": "This paper proposes a new method called Finite Dimensional Matching (FDM) for training Neural SDEs by identifying a class of strictly proper scoring rules for comparing continuous Markov processes. Using this scoring rule, they claim to reduce the training complexity from quadratic in discretization timesteps to linear.", "strengths": "This paper addresses an important problem: the high computational complexity (quadratic in time steps) associated with training Neural SDEs using scoring rules. The authors propose a reduced complexity method, aiming for linear complexity to enhance performance. They also provide a theoretically grounded approach in designing a new scoring rule. However, despite this strong motivation, the results are somewhat unconvincing due to certain weaknesses noted below.", "weaknesses": "This paper has several notable weaknesses. While the authors propose a reduced complexity approach for training Neural SDEs, they do not adequately explain key concepts, making the paper difficult to follow, especially for readers without a strong background in this area. For example, despite experience with SDEs in score-based generative models and deep learning theory, I found the explanations lacking in detail and context and some results are not convincing.\n\nThe authors do not provide sufficient preliminary material on scoring rules or background on how these rules are used to measure divergence between two Markov processes. A concrete example of the scoring rule $s(P, z)$ with an RBF kernel, presented early on, would have improved clarity.\n\nAdditionally, the complexity reduction claim is unconvincing. For example, at the beginning, the authors claim they reduce the complexity to linear using $D$ to denote the time steps, however, this notation $D$ is never used again in the rest of the paper. Instead, in Section 4.2 Algorithm, they compare two stochastic processes and use $B$ to denote the total number of time steps. However, the nested summations in the top equation on page 5 suggest quadratic rather than linear complexity, i.e., $B^2$.\n\nFinally, the paper uses the **ICLR 2024 format rather than the ICLR 2025** format.", "questions": "In Theorem 2, does the result hold for any scoring rule $s$, or does $s$ also need to be strictly proper? Could you clarify if there are specific conditions on $s$ required for Theorem 2 to apply?"}, {"review_id": "LJD22sR0Oz", "reviewer": "Reviewer_SGBZ", "rating": 8, "confidence": 4, "soundness": 4, "presentation": 2, "contribution": 3, "summary": "This paper introduces a very simple yet elegant way of comparing the similarity of the laws of two homogeneous Markov processes. The idea follows from the fact that the distribution of the process is completely determined by the transition kernel. Hence the ability to compare the transition distribution should give enable a comparison between the laws. \n\nThe authors then use this finding as a scoring rule to learn to simulate from a neural SDE, given the observations from a ground truth process. Then, they test their scoring rule and their FDD matching procedure in numerical experiments, showing that their method outperforms existing law matching techniques.", "strengths": "The method is very simple and elegant and the implication of this result, **if true**, could be impactful to the neural SDE and generative modeling community.\n\n**Update: I was wrong about the validity of the proof. The result is sound. **", "weaknesses": "I think there is a significant issue in the proof that could invalidate your main theorem 2. In the middle part of page 14, you have the integral over $t_1,t_2$ of the $S(P_{t_1,t_2}, P_{t_1,t_2}')$ scoring rules being equal, but then you conclude that $$S(P_{t_1,t_2},P_{t_1,t_2}') = S(P_{t_1,t_2},P_{t_1,t_2})$$ a.e.. This is not true or I am missing something? \n\n**Update: I was wrong about the validity of the proof. The result is sound. **\n\nMy intuition is that this is not an easy fix: You want to match the distribution over all $t_1,t_2$, so, in some sense, you want the expectation equality to hold for all test measures $\\nu$ instead of just one particular $\\nu$. I would be very happy to raise the score and rewrite my review if I am wrong. However, it does seem that proving a result like yours is possible, given that the generator or the resolvent will determine the law of the Markov process (see Either and Kurtz). \n\nI appreciate the rigor in defining the math notations and results. But the writing and explanations in the paper need improvements. For example, what do you mean by \"Update the model parameters $\\theta$ through backpropogation to maximize $\\hat S$\" (inside the algorithm)? Also, there is no explanation of what the data is. What are the \"Average KS test scores\"? Are you repeating the experiment across multiple batches and producing the percentage of rejection \"chance of rejecting the null hypothesis (%) at 5%-significance level on marginal\"? \n\nOther minor issues include: \n1. I think you need $\\mathcal{E}$ to be Polish. \n2. The radial basis function (RBF) kernel is not defined. \n3. The $\\pi$ notation seems a little distracting, why not just use $(x_t,x_s)$ for $\\pi_{t,s}(x)$. \n4. It would be nice to emphasize that the Markov processes you define are homogeneous Markov processes.", "questions": "The main concern I had was in regard to the proof. Please help me to understand or fix the issue. \n\nIf this and the writing issues I raised above can be addressed, I will happily give you at least a 6. However, since I feel that your main result is wrong, I have to give a low score for now.\n\n**Update: I was wrong about the validity of the proof. The result is sound. **\n\nIf this (or a variant of the) scoring rule for Markov processes is indeed correct, I think the authors could improve the paper by exploring the sensitivity properties of this scoring rule; for example, what kernel to use and how the score behaves when P and Q are close? Is there a simple formula to compute the gradient?"}, {"review_id": "tzblZicoc8", "reviewer": "Reviewer_DrUo", "rating": 8, "confidence": 4, "soundness": 4, "presentation": 2, "contribution": 3, "summary": "The authors propose a scoring rule for continuous time stochastic processes that is directly derived from a scoring rule on a generic space. They show that this rule is proper (i.e. injective from the space of paths to the space of laws), which is a non-trivial contribution. Experiments show that this method outperforms existing concurrent methods based on signature kernels and SDE-GANs. I stress that the experiments are carried out on a vast array of datasets.", "strengths": "* The paper is well written and structured, making it easy to read and to follow. The introduction is sound and features a thorough literature review.\n \n* The main technical contribution is Theorem 2, which allows to convert any scoring rule on a generic space. I believe this contribution to be non trivial and novel, although simple to prove. \n\n* The experiments are carried out on a vast array of time series datasets and show overall superior performance of the proposed approach. The authors compare themselves to all relevant baselines to the best of my knowledge. Large generative models such as diffusion models are not included ; however, I do believe that they do not belong to the same class of models and do not require a comparable computational budget.", "weaknesses": "* While the contribution made in Theorem 2 is elegant and novel, one could object that it is slightly insufficient --- this is not my point of view, but another reviewer might disagree and I am willing to discuss this point. In order to strengthen their theoretical part, I encourage the authors to consider for instance the sample complexity of their kernel i.e. how fast does the empirical divergence they define through a kernel on $\\mathcal{E}$ converge through the expected divergence ? Do sample complexities in $\\mathcal{E}$ carry over to the space of paths ? See Gretton (2012) for an example. \n\n* A blind spot of the paper is in my opinion the choice of the kernel. The authors do not seem to consider other kernels for $\\mathbb{R}^d$ valued processes, which could yield an interesting extension. This might especially be interesting since some kernels are sometimes used because of their specific properties (invariances, ...). I would suggest that the authors consider this point, at least in the Appendix. \n\n* This approach nicely extends to kernels defined on any space - this could be graphs, images, etc. and could allow to generate time series in these spaces. This would provide, in my sens, a extremely valuable extension to the paper. \n\n* The experimental section is hard to read and a tad unstructured. I encourage the authors to use less tables, add more comments and broaden the analysis of their experiments. Also, a notable restriction of this part is that experiments are only carried out on $2$ dimensional time series. I strongly encourage the authors to extend their experiments to high dimensional datasets. Also, there are no confidence intervals in the tables. \n\n* Regarding this last point, I believe that a valuable extension could be to consider random feature approximations to the kernels for high dimensional generation, which is still a major hurdle in the field. Similarly, the authors could consider sliced kernels on $\\mathcal{E}$ when the dimension is high. \n\n* Concerning experiments, an interesting task to consider could be the augmentation of a time-series dataset, and the analysis of the gain in performance for any model trained on this dataset. \n\n* Concerning experiments, I believe that it would be highly valuable to extend applications beyond finance. Generating time series is a major hurdle in many domains with great social impact such as neuroscience, healthcare, biology, climatology, economics ... \n\n* A valuable extension of this work would be the investigate the use of the devised score for other purposes than training generative models, such as two-sample tests for instance.", "questions": "* Please include confidence intervals in your tables: variability of your results is a very important aspect.\n\n* Could you please include vizualizations of the full generated time series, rather than only plotting the time marginals ?\n\n* Could you add experiments on at least one high dimensional and one non euclidian real-world dataset ?\n\nI would considered increasing my score if a significant number of concerns on the experiments are addressed. Similarly, I could consider lowering my score during the discussion phase if other reviewers relativise the strength of the theoretical contribution --- which again seems sufficient to me."}, {"review_id": "hfTFiU93yU", "reviewer": "Reviewer_4f46", "rating": 5, "confidence": 3, "soundness": 1, "presentation": 2, "contribution": 2, "summary": "This paper presents a novel approach to training neural stochastic differential equations by introducing Finite Dimensional Matching, a new scoring rule designed for continuous Markov processes. The proposed method leverages the Markov property of stochastic processes to reduce the computational complexity of training neural SDEs, with the goal of enhancing both efficiency and generative quality. Theoretical contributions establish that the new scoring rule provides a strictly proper method for comparing two-time joint distributions. Experimental results demonstrate improved performance in training efficiency and generative quality when evaluated against competing methods.", "strengths": "The paper tackles a crucial challenge in training neural SDEs by introducing a new scoring rule that optimizes efficiency for continuous Markov processes. The proposed FDM method is backed by mathematical proofs, providing a strictly proper scoring rule that extends from finite-dimensional distributions to continuous Markov processes. This theoretical contribution is valuable to the literature on neural SDE training methods. Experiments show that FDM offers computational efficiency gains, reducing training complexity from quadratic to linear in the number of discretization steps. The approach outperforms prior methods in computational efficiency, as shown in multiple experimental benchmarks.", "weaknesses": "The paper’s main theorem relies on strong assumptions regarding the Markovian properties and continuity of the processes involved. These assumptions may limit the applicability of the FDM algorithm in more complex, non-Markovian stochastic processes or those with jumps, which are common in real-world scenarios. Explicitly discussing these limitations and potential ways to relax these assumptions would make the contributions more transparent.\n\nAlso, the writing template is for ICLR 2024, not ICLR 2025.", "questions": "Please see weaknesses."}], "discussions": [{"comment_id": "K6s8IHGrP2", "replyto": "d4qMoUSMLT", "author_type": "reviewer", "reviewer": "Reviewer_DrUo", "comment": "As the end of the discussion period is approaching, I would like to stress once again to fellow reviewers and chairs that this paper substantiality contributes IMO to the field of neural SDEs and time series. The theoretical result presented in this paper hugely simplifies the current go-to pipeline, which often relies on the signature kernel whose computation is relatively expensive and does not scale to high dimensions. Applications to generative modelling display strong improvements over the current SOTA. \n\nThe authors have addressed various points during the discussion period and have improved their paper. \n\nThe potential applications, while not fully investigated by the authors, are considerable and include generative models, two-sample tests, kernel methods, transfer learning and the extension of all the previous to sequential data that lives in any space that can be endowed with a characteristic kernel. \n\nI strongly believe that this paper should be accepted."}, {"comment_id": "NFqXFHSTE5", "replyto": "vRFBxRJMoP", "author_type": "reviewer", "reviewer": "Reviewer_GLJM", "comment": "Thank you for your response. However, I did not notice any updates in the revised version. Based on the current write-up, I am not comfortable recommending acceptance for this paper. While the paper addresses an interesting problem, the current presentation is not ready for publication. I suggest that the authors undertake at least one round of revisions to significantly improve the clarity and quality of the presentation."}, {"comment_id": "s8gQN83Xr9", "replyto": "54A1L7tRHv", "author_type": "reviewer", "reviewer": "Reviewer_DrUo", "comment": "Many thanks for your answers. I look forward to reading your revision. Let me know if you want to discuss any other point of your paper."}, {"comment_id": "yCs0QMMYkc", "replyto": "82j5dxGHWA", "author_type": "reviewer", "reviewer": "Reviewer_SGBZ", "comment": "It is your decision, but then you need to argue things through separability. I thought the focus of this paper should be providing a divergence for neural SDE applications. How would jump processes of graphs fit into the neural SDE context?"}, {"comment_id": "54A1L7tRHv", "replyto": "VLrcyC8238", "author_type": "authors", "reviewer": null, "comment": "Thank you again for the insightful review.\n\n**Concerning point 3**, thank you for the reference. We'll try to include experiments on non-Euclidean datasets in the final revision.\n\n**Concerning points 5 and 6**, I agree these are interesting directions. However, they seem to be beyond the scope of neural SDEs and may deserve a separate paper.\n\n**Concerning point 7**, thank you for the insight. The sample complexity bound is currently on our to-do list. We hope to integrate it into the final revision.\n\n**Concerning point 8**, we appreciate this advice and agree to add error bars at least for some smaller datasets. However, we do not have access to our cluster in the next few days, so the error bars will be included in a later revision toward the end of the discussion period.\n\nAgain, we appreciate your time and insights."}, {"comment_id": "VLrcyC8238", "replyto": "NLIQ3tTsZw", "author_type": "reviewer", "reviewer": "Reviewer_DrUo", "comment": "Thank you for your answers on points **1,2 and 4**.\n\n**Concerning point 3**, you can apply your method to any space on which you have a suitable kernel if I am not mistaken. This paper develops neural SDEs for graphs for instance https://arxiv.org/pdf/2308.12316 and could give you some inspiration. You could also think of shapes evolving through time, a topic of interest in healthcare (think about organs and cells deforming over time). Finally, you could also consider measures evolving over time - see for instance https://proceedings.mlr.press/v151/bunne22a/bunne22a.pdf. \n\n**Concerning point 5**, yes this is exactly what I mean. I believe that this would allow you to show that generating time series for training a downstream model with your method has real benefits (which is something that I don't believe to be straightforward, see for instance https://arxiv.org/abs/2402.07712v1). \n\n**Concerning point 6**: even if your method simplifies the optimization of the neural SDE through a simpler kernel (compared to the signature kernel which can be technical to compute with high precision, or the GAN-like training of Kidger et al 2021), you still have to evaluate this very kernel. To simply such a computation, a long lasting line of research going back to Rahimi (2007) has resorted to randomised approximations. I believe that this framework could allow you to train neural SDEs for generating high dimensional time series. \n\n**Concerning point 7**: my intuition would be that the sample complexity is similar. But then you need to approximate the expectation in definition 1, so you might have to deal with a concentration bound on this stochastic approximation. Hence intuitively I believe that you should end up with the classical sample complexity of kernel mean embeddings + a term that depends on the number of times you draw times to compute your scoring function. Could you maybe provide an empirical assessment of this ? A plot of your metric vs the number of samples used and a plot of your metric vs the number of times drawn would suffice to convince me. \n\n**Concerning point 8**: I sincerely believe that this is not an acceptable answer. Reporting confidence intervals is a common and well-established scientific practice. NeurIPS, for instance, requires it in the paper checklist: https://neurips.cc/Conferences/2022/PaperInformation/PaperChecklist. Providing confidence intervals on even smaller datasets is a bare minimum."}, {"comment_id": "NLIQ3tTsZw", "replyto": "tzblZicoc8", "author_type": "authors", "reviewer": null, "comment": "Hi,\n\nThank you for the detailed and constructive review. We're happy to address the raised concerns.\n\n**1. Experiments are only carried out on 2-dimensional time series**\n\nWe apologize for the confusion, but this is not the case. For each dataset, we model all the features jointly, i.e., the dimension is $2$ for the metal price and exchange rate dataset, $5$ for the stock indices dataset, $4$ for the energy price dataset, $3$ for the bonds dataset, and $16$ and $32$ for the Rough Bergomi Model data. We present the two-dimensional joint marginals only because plotting higher-dimensional distributions on paper is challenging.\n\n**2. Visualizations of the full generated time series**\n\nThank you for pointing this out. We'll add them in the forthcoming revision.\n\n**3. Non-Euclidean real-world dataset**\n\nWe're not familiar with this type of dataset that is suitable to be modeled as an SDE. We would appreciate any recommendations you might have.\n\n**4. Choice of kernels**\n\nWe can include a study on the kernel choice in the appendix in the forthcoming revision.\n\n**5. Augmentation of a time-series dataset**\n\nWe're not sure what this means; could you please elaborate? Do you mean using the generative method for data augmentation and applying it to a downstream task?\n\n**6. Random feature approximations to the kernels for high-dimensional generation**\n\nCould you please elaborate on this as well? Do you mean computing the kernels with random feature approximation? It seems like this is just another way to compute/approximate the kernel.\n\n**7. Sample complexity**\n\nWe agree that this is an interesting aspect and will look into it. We'll try to integrate sample complexity analysis in the final revision by the end of the discussion period, but we cannot guarantee.\n\n**8. Confidence intervals**\n\nRepeating some models on higher-dimensional datasets or datasets with longer sequences can be computationally expensive, so we omitted the confidence intervals for consistency. The previous state-of-the-art [Issa et al., 2023](https://arxiv.org/pdf/2305.16274) in this area also didn't report error bars, so we believed this was acceptable. That being said, we're happy to include confidence intervals for some smaller datasets in the appendix in the forthcoming revision.\n\nWe thank you again for your insights, and please let us know if there is anything unclear; we're more than happy to clarify!"}, {"comment_id": "EwipoFhaHh", "replyto": "gJmdr7hXQ9", "author_type": "reviewer", "reviewer": "Reviewer_DrUo", "comment": "In my opinion, the \"strong assumptions regarding the Markovian properties and continuity of the processes involved\" mentioned by reviewer 4f46 are relatively acceptable for many applications - for instance stochastic modelling in biology and healthcare, with which I am fairly familiar. While I am not familiar with non-Markovian processes, I have never encountered them in my research. Do you have any examples in mind ? I am genuinely interested in this question. \n\nI do agree however that including jumps is of high interest. The authors could maybe at least empirically verify that their scoring rules are effective for training neural jump SDE (see https://arxiv.org/pdf/1905.10403) on a simple example. The code of this paper is available (https://github.com/000Justin000/torchdiffeq/tree/jj585) and seems to be well-built. However, I am unsure whether such an extension can be reasonably implemented during the short reviewing period."}, {"comment_id": "kWX9RZzXNE", "replyto": "82j5dxGHWA", "author_type": "reviewer", "reviewer": "Reviewer_DrUo", "comment": "Thanks to reviewer SGBZ and the authors for this discussion. I have myself also read carefully read the proof, and do no see any issues. This strengthens my opinion that the contributions of the paper are strong and of high interest."}, {"comment_id": "82j5dxGHWA", "replyto": "2sDvfSGPUv", "author_type": "authors", "reviewer": null, "comment": "Got it. I agree wo do need the separability for the last paragraph in the proof of Theorem 5. I'll add the Polish assumption for sure. \n\nThe reviewer DrUo mentioned $\\mathcal{E}$ could be a space of graphs, so would $\\mathcal{E}$ being Euclidian a bit too restrictive? \n\nThank you again for the insightful and timely comments."}, {"comment_id": "2sDvfSGPUv", "replyto": "LJD22sR0Oz", "author_type": "reviewer", "reviewer": "Reviewer_SGBZ", "comment": "I misinterpreted your $X,Y$ taking value in $\\mathcal E$. I thought you were saying that $\\omega\\rightarrow X(\\omega,\\cdot )$ is in $\\mathcal E$. In any case, you do need the processes to be separable (continuity would suffice), i.e. the FDD is determined by a dense subset of $[0,T)$, to argue for the last part of Theorem 5. This is not explicitly stated in the assumption. I would suggest you just go with $X\\in C_{\\mathcal E}[0,T]$ or $D_{\\mathcal E}[0,T]$ where $\\mathcal E$ is a Euclidian space.\n\nFor the sensitivity, I am wondering if the scoring rule is smooth (and how smooth) in a change in the parameter of the neural SDE. This will affect the convergence rates of your algorithm. I think this is interesting especially for the volatility part, as we know that the KL is infinity if you perturb the volatility (the two laws are not absolutely continuous)."}, {"comment_id": "gJmdr7hXQ9", "replyto": "hfTFiU93yU", "author_type": "authors", "reviewer": null, "comment": "Hi,\n\nThank you for the review, and we're happy to address the raised concerns.\n\n1. We're going to fix the template in the forthcoming revision.\n\n2. *Continuity and Markov Assumption*\n\nAs mentioned in the response to reviewer SBGZ ([link](https://openreview.net/forum?id=d4qMoUSMLT¬eId=zKB258o4r6)), the proof can be directly extended to cadlag Markov processes defined on an open interval $[0, T)$. The only modification needed in the proof is to construct the sequences to converge from the right in the last paragraph of the proof of Theorem 5. If you and the other reviewers deem it necessary, we are happy to extend the proof to the cadlag case. This allows our theory to cover jump processes.\n\nRegarding the Markov assumption, we believe this can be partially addressed by augmenting the paths with more information or leveraging a hidden Markov model. However, these are beyond the scope of neural SDEs and may warrant a separate paper. If you feel it is necessary, we can add additional discussion on the relaxation of these assumptions.\n\nWe thank you again for your insights, and please let us know if there is anything unclear; we're more than happy to clarify!"}, {"comment_id": "u94J6DHCbk", "replyto": "NyY4O36qZg", "author_type": "authors", "reviewer": null, "comment": "For the sensitive properties, could you recommend any literature? We'd deeply appreciate it."}, {"comment_id": "zKB258o4r6", "replyto": "NyY4O36qZg", "author_type": "authors", "reviewer": null, "comment": "Thank you for the quick response!\n\nIndeed, the proof can be directly extended to cadlag Markov processes defined on an open interval $[0, T)$. The only modification needed in the proof is to construct the sequences to converge from the right in the last paragraph of the proof of Theorem 5. As the focus of the paper is on neural SDEs, we did not include the proof for the cadlag case. That being said, if you and the other reviewers feel it is necessary, we are happy to extend the proof to cover the cadlag case."}, {"comment_id": "NyY4O36qZg", "replyto": "RALdUgwNyd", "author_type": "reviewer", "reviewer": "Reviewer_SGBZ", "comment": "I see, I forgot about the property of the scoring rule. \n\nBy the way, in responding to the concerns of the other reviewer about the jumps, I think this should work for (at least a large subclass of) Feller processes. Because essentially the law of the Feller process will be determined by the generator. This essentially means that you only need to match the joint distribution over an infinitesimal interval. So, the a.e. Lebesgue would suffice. \n\nIn this case, I will raise my score to 5 for now. If you can clarify your paper as suggested by me and other reviews, I will further improve the score to 6. \n\nAs the reviewer DrUo helpfully pointed out, this would be a better submission if you could say something about the qualities/properties of the scoring rule, especially sensitivity properties, i.e. how perturbations in $\\mu$ and $\\sigma$ would translate to $S$."}, {"comment_id": "RALdUgwNyd", "replyto": "LJD22sR0Oz", "author_type": "authors", "reviewer": null, "comment": "Hi,\n\nThank you for the detailed review and we're happy to address the raise concerns.\n\n1. For the proof, it is important that we choose $\\nu$ to be an equivalent measure to the Lebesgue measure $\\mu$ on $[0, T] \\times [0, T]$. So there exists a $\\mu$-a.e. positive function $\\lambda(t_1, t_2)$ such that $d\\nu = \\lambda d\\mu$. (in case you need a proof, see https://math.stackexchange.com/questions/1393425/equivalent-finite-measures-if-and-only-if-strictly-positive-radon-nikodym-deriv)\n\nThen \n$$E_{(t_1, t_2) \\sim \\nu} S(P_{\\pi_{t_1, t_2} (X)}, P_{\\pi_{t_1, t_2} (Y)}) = E_{(t_1, t_2) \\sim \\nu} S(P_{\\pi_{t_1, t_2} (Y)}, P_{\\pi_{t_1, t_2} (Y)})$$ \nsuggests\n$$E_{(t_1, t_2) \\sim \\mu} \\lambda(t_1, t_2) S(P_{\\pi_{t_1, t_2} (X)}, P_{\\pi_{t_1, t_2} (Y)}) = E_{(t_1, t_2) \\sim \\mu} \\lambda(t_1, t_2) S(P_{\\pi_{t_1, t_2} (Y)}, P_{\\pi_{t_1, t_2} (Y)}),$$\nwhich further implies \n$$E_{(t_1, t_2) \\sim \\mu} \\lambda(t_1, t_2) [ S(P_{\\pi_{t_1, t_2} (Y)}, P_{\\pi_{t_1, t_2} (Y)}) - S(P_{\\pi_{t_1, t_2} (X)}, P_{\\pi_{t_1, t_2} (Y)})] = 0.$$\n\nRecall that by definition of the proper scoring rule $S(P_{\\pi_{t_1, t_2} (Y)}, P_{\\pi_{t_1, t_2} (Y)}) \\geq S(P_{\\pi_{t_1, t_2} (X)}, P_{\\pi_{t_1, t_2} (Y)})$ and $\\lambda(t_1, t_2) > 0$ $\\mu$-a.e. due to the equivalence of $\\nu$ and $\\mu$. This makes the integrand to be non-negative $\\mu$-a.e., forcing $\\lambda(t_1, t_2) [ S(P_{\\pi_{t_1, t_2} (Y)}, P_{\\pi_{t_1, t_2} (Y)}) - S(P_{\\pi_{t_1, t_2} (X)}, P_{\\pi_{t_1, t_2} (Y)})] = 0$ $\\mu$-a.e.. \n\nAgain, due to the fact that $\\lambda(t_1, t_2) > 0$ $\\mu$-a.e., $S(P_{\\pi_{t_1, t_2} (Y)}, P_{\\pi_{t_1, t_2} (Y)}) - S(P_{\\pi_{t_1, t_2} (X)}, P_{\\pi_{t_1, t_2} (Y)}) = 0$ $\\mu$-a.e.. \n\nI think two important points here are the equivalence of between $\\mu$ and $\\nu$ and that $s$ is a strictly proper scoring rule.\n\n2. I'm checking where we need $\\mathcal{E}$ to be Polish. The disintegration theorem we used (Theorem 8.5 of Kallenberg (2021)) only requires the value space to be Borel. We'd sincerely appreciate it you can help us to point it out.\n\n3. For the other minor issues, we can fix them in the upcoming revision.\n\nWe thank you again for the insights and please let us know if you feel there is anything unclear; we're more than happy to clarify!"}, {"comment_id": "vRFBxRJMoP", "replyto": "JXEcfWl4J4", "author_type": "authors", "reviewer": null, "comment": "Hi,\n\nWe deeply appreciate the detailed reviews and are happy to address the concerns.\n\n*1. In Theorem 2, does the result hold for any scoring rule $s$, or does $s$ also need to be strictly proper?* \nYes. $s$ has to be strictly proper. This is stated in the first paragraph of section 4.1: \"*Let $s$ be any strictly proper scoring rule defined on ...*\". We agree this is probably not sufficiently clear and will move this condition to the theorem statement in the forthcoming revision.\n\n*2. The complexity reduction claim* \nWe apologize for this confusion. $D$ refers to the discretization step in the numerical integration of the SDE, i.e., for the SDE\n$$\ndZ_t = \\mu^{\\theta}(t, Z_t) dt + \\sigma^{\\theta}(t, Z_t) dW_t,\n$$\nwe evaluate the integral\n$$\nZ_T = Z_0 + \\int_{0}^{T} \\mu^{\\theta}(t, Z_t) dt + \\int_{0}^{T} \\sigma^{\\theta}(t, Z_t) dW_t\n$$\nby numerical integration with $D$ discretization steps using the Euler-Maruyama method:\n$$\nZ_{t_{k+1}} = Z_{t_k} + \\mu^{\\theta}(t_k, Z_{t_k}) \\Delta t + \\sigma^{\\theta}(t_k, Z_{t_k}) \\Delta W_k\n$$\nwhere $\\Delta t = T/D$, $t_{k+1}=t_k + \\Delta t$, and $\\Delta W_k$ is the Wiener increment over the interval $[t_k, t_{k+1}]$.\n\n$B$ in Algorithm 1 refers to the batch size of the SGD optimizer and is different from $D$. \n\nThe $O(D^2)$ complexity comes from the previous state-of-the-art Neural SDE training method proposed in [Issa et al., 2023](https://arxiv.org/pdf/2305.16274), which requires solving a PDE\n$$\nf(s, t) = 1 + \\int_{0}^{s} \\int_{0}^{t} f(u, v) \\langle dx_u, dy_v \\rangle_1 dv du\n$$\n(see (2) in the linked paper) and backpropagate the gradients through the PDE solver. The double integral is typically numerically approximated using a rectangular rule with $D$ discretization steps:\n$$\n\\int_{0}^{T} \\int_{0}^{T} f(u, v) \\langle dx_u, dy_v \\rangle_1 dv du \\approx \\sum_{i=1}^{D} \\sum_{j=1}^{D} f(u_i, v_j) \\langle dx_{u_i}, dy_{v_j} \\rangle \\Delta u \\Delta v,\n$$\nwhere $\\Delta u = T/D$, $\\Delta v = T/D$, and $u_i = i \\Delta u$, $v_j = j \\Delta v$ for $i, j = 1, \\dots, D$. The double sum requires $O(D^2)$. Also, the method of [Issa et al., 2023](https://arxiv.org/pdf/2305.16274) does not have a better complexity on $B$ as their objective also involves a double sum over $B$ (see (4) in the linked paper, our $B$ is their $m$, the double integral occurs in their $k_{sig}$). \n\nSo overall, our method reduces the complexity from $O(D^2)$ to $O(D)$ (or, from $O(D^2 B^2)$ to $O(D B^2)$ if you prefer to also include $B$ ) as we don't need to solve the PDE with the double integral. \n\n*3. Not sufficient preliminary material on scoring rules or background on how these rules are used to measure divergence between two Markov processes.* \nThank you for pointing out the lack of clarity. We're happy to add more explanation to the background of scoring rules in the forthcoming revision. We'll move the RBF kernel example earlier.\n\nWe thank you again for the insights and please let us know if you feel there is anything unclear; we're more than happy to clarify!"}], "meta_review": {"metareview": "In this work, the authors show that proper scoring rules on distributions can be extended to proper scoring rules on processes through their finite dimensional distributions. Using this result, a method called Finite Dimensional Matching (FDM) is proposed to bypass pain points with fitting stochastic processes to data. In a large number of numerical experiments, the authors demonstrate the advantages of this approach over competing methods. Reviewer opinions are mixed: two are positive, one mixed, and one negative. The reviewers stating positive opinions have stood firm on their stance that the paper should be accepted, while the more negative reviewers have failed to adequately engage in the discussion period. This is unfortunate, as the discussion has been considerable!\n\nMy own stance agrees with the positive feedback: the results are simple in retrospect, but profound in development and application. This is a strong contribution to the literature, and I believe it is worthwhile to the community even in its current state. **However**, in addition to reviewer feedback, I object to several aspects of the presentation: text in figures is unacceptably small (should be close to same size as surrounding text), avoid repeating the same reference multiple times in a paragraph, values in tables might be better separated by brackets e.g. .137 (17.0), some mistaken capitalization, placement of footnote 1 is strange, missing spacing around some references, no space between tables and table legends, paragraph discussing figures is too far away from the figures themselves. I implore the authors to address these points in their next revision. I also agree that providing the most general possible form of Theorem 2, even if in supplementary material, would be ideal. \n\nWith this in mind, I give a tentative recommendation for acceptance.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "jj7b3p5kLY", "forum_url": "https://openreview.net/forum?id=jj7b3p5kLY", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "The AdEMAMix Optimizer: Better, Faster, Older", "authors": ["Matteo Pagliardini", "Pierre Ablin", "David Grangier"], "abstract": "Momentum based optimizers are central to a wide range of machine learning applications. These typically rely on an Exponential Moving Average (EMA) of gradients, which decays exponentially the present contribution of older gradients. This accounts for gradients being local linear approximations which lose their relevance as the iterate moves along the loss landscape. This work questions the use of a single EMA to accumulate past gradients and empirically demonstrates how this choice can be sub-optimal: a single EMA cannot simultaneously give a high weight to the immediate past, and a non-negligible weight to older gradients. Building on this observation, we propose AdEMAMix, a simple modification of the Adam optimizer with a mixture of two EMAs to better take advantage of past gradients. Our experiments on language modeling and image classification show---quite surprisingly---that gradients can stay relevant for tens of thousands of steps. They help to converge faster, and often to lower minima: e.g., a $1.3$B parameter AdEMAMix LLM trained on $101$B tokens performs comparably to an AdamW model trained on $197$B tokens ($+95\\%$). Moreover, our method significantly slows-down model forgetting during training. Our work motivates further exploration of different types of functions to leverage past gradients, beyond EMAs.", "keywords": ["Optimization", "LLM", "Deep Learning", "Momentum"], "primary_area": "foundation or frontier models, including LLMs", "pdf_url": "https://openreview.net/pdf?id=jj7b3p5kLY", "decision": "Accept (Poster)", "num_reviews": 5, "num_discussions": 15, "reviews": [{"review_id": "n0FXVCyfdP", "reviewer": "Reviewer_zVt6", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 4, "contribution": 3, "summary": "The paper propose a new optimizer called AdEMAMix, which add an additional EMA sequence to Adam. The intuition is to to keep a high sensitivity to recent gradients (using m1), while also incorporating information from older gradients (using m2).", "strengths": "The prerformance of AdaEMAMix is quite impressive. The writing is excellent. The ablation studies are thorough.", "weaknesses": "The motivation is rather vague. See below.", "questions": "1. The motivation on the Rosenbrock function is rather weak. Why do you consider the Rosenbrock function? How does it relate to LLM training? Why do you think the implication on the Rosenbrock function can be transferred to LLM? To me, there is a huge gap between these two.\n\n2. Need more evidence that \"beta1 = 0.9 is indeed suboptimal for LLMs\". While I fully understand Figure 3, i still did not fully convinced that \"LLM training really needs to use more historical gradients, in one way or another\". More LLM-related evidence is needed to convince the that it is indeed a major bottleneck. Despite the impressive performance, the overall motivation still seems quite vague to me. I suggest the authors put more effort into designing more experiments for better motivation.\n\n3. It is difficult to understand the highlighted sentence in the introduction: \"While changing the direction of the slow momentum is difficult, any adjustment orthogonal to that direction is easy—which favors fast progress in sinuous canyon-like landscapes.\" I don't understand what you mean by \"change the direction of the slow momentum\" or \"any adjustment orthogonal to that direction\". Please explain more. Also, the authors mentioned \"canyon-like landscapes (in Rosenbork function).\" but never connected it to LLM in the script. This also makes the motivation rather weak.\n\n4. I don't quite understand how to read Figure 2. For instance: \n\n -- in Figure 2 (a), is AdaEMAMix considered better than Adam? It is not clear whether we can draw such conclusion based on the figure. In my opinion, at least 10 x more iterations are needed to draw valid conclusions.\n\n -- in Figure 2 (c), did AdaEMAMix converge to the optimal solution? \n\n5. The proposed method boosts performance by using extra memory (to store an additional copy of momentum). Though many readers might regard it as a drawback, I personally think such a trade-off is acceptable as long as the performance gain is significant. Further, I think AdaEMAMix can be combined with some orthogonal methods to reduce memory. For instance, AdaEMAMix can be combined with the recent method Adam-mini [1] to reduce the memory for V. I suggest the authors try it out.\n\n\n\n6. Some missing related works: [2] proves that vanilla Adam can converge under a wide range of beta1 = 0, 0.5, 0.9, 0.99, etc., as opposed to the divergence result in Reddi et al. 2018. This result lays down a preliminary foundation of this work. Without the theoretical guarantee, it would be dangerous to play with beta1 of Adam.\n\n[1] Zhang, Y., Chen, C., Li, Z., Ding, T., Wu, C., Ye, Y., ... & Sun, R. (2024). Adam-mini: Use fewer learning rates to gain more. *arXiv preprint arXiv:2406.16793*.\n\n[2] Zhang, Y., Chen, C., Shi, N., Sun, R., & Luo, Z. Q. (2022). Adam can converge without any modification on update rules. *Advances in neural information processing systems*, *35*, 28386-28399."}, {"review_id": "HZGoTxkYT6", "reviewer": "Reviewer_H5Kk", "rating": 10, "confidence": 4, "soundness": 4, "presentation": 4, "contribution": 4, "summary": "This paper proposes AdEMAMix, a new optimizer which outperforms Adam on language model training and ViT training. Their empirical results show large benefits (~50% reduction) over Adam in the regime of noisy gradients i.e. small batch size or longer runs. The main idea behind the optimizer is to maintain two momentum terms and combine them to get the final movement direction. The coefficient of momentum and their combination are also dynamically adapted.", "strengths": "These are strong results on a very important problem. They also provide many optimizer ablations in the Appendix showing the robustness of their proposed optimizer.", "weaknesses": "Since many of the experiments are with small batch size it would have been interesting to explore the effect of weight averaging. For example, is it the case that weight averaging helps AdamW and AdEMAMix equally? Or not?", "questions": "The authors state “While no answer to those questions is given in this work, we\nprovide a toy justification which indicates that large momentums can have a positive impact in\nnoise-free non-convex settings (see Fig. 2)—indicating the improvement of our approach is at least partially explainable without considering variance-reduction effects.” Is there empirical support for this in the LLM experiments? looking at Figure 17 the benefit seems to drop with increasing batch size. Note that the maximum batch size used here (512k) is smaller than that used for LLMs like Llama (4m), though I agree that the model size is also small here. Could the authors provide an experiment with 2m batch size to see the trend?"}, {"review_id": "NEAO3UGAgL", "reviewer": "Reviewer_GCtm", "rating": 5, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "This paper proposes to use an additional momentum buffer for the Adam optimizer. The motivation behind is that a single momentum buffer may not be able to utilize both past and current gradient information efficiently. To this end, one buffer with large momentum parameter $\\beta$ is for slow-changing directions, and the other one with small $\\beta$ is for fast adaptation to current gradient. Then, the two momentum buffers are mixed according to some fixed ratio ($\\alpha$), and the rest of updates (i.e., pre-conditioning and weight-decay) follow those of AdamW. However, a large $\\beta$ (for slow-changing momentum buffer) can cause training instabilities in the initial stage. To this end, the paper proposes to gradually increase this parameter until the target value is reached. Then, experiments are performed on various language modelling tasks to show that the proposed algorithm is faster than AdamW given a fixed computation budget.", "strengths": "- The paper is well written and easy to follow. The experiments on the Rosenbrock function are convincing. There are some similar (loss landscape) models proposed recently to analyze learning rate schemes [1]. It maybe interesting to draw some theory/experimental connections in the case of momentum. \n\n- The experiments seem to be comprehensive covering different settings and tasks. The improvement over baseline is shown.\n\n[1] Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective.", "weaknesses": "- There are some additional hyperparameters introduced. Noticeably, it seems that $\\alpha$ (that controls the mixing ratio) is important for the algorithm. It would be important to study the sensitivity of the algorithm to this hyperparameter, given that it essentially controls the contribution of each momentum buffer to the current update. \n\n- It would be better if some convergence guarantees of the algorithm can be provided even in the convex setting. For example, what is the relationship between the two momentum parameters that would guarantee convergence?", "questions": "Overall, I think the idea introduced in this paper is interesting. The paper can be improved if the above weaknesses can be addressed."}, {"review_id": "lw3RysqbTt", "reviewer": "Reviewer_98Nr", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 4, "contribution": 3, "summary": "This paper introduces AdEMAMix, an optimizer modifying Adam by incorporating a mixture of two Exponential Moving Averages (EMAs) to better leverage past gradients. Traditional Adam or AdamW optimizers rely on a single EMA to smooth gradients, which prioritizes recent gradients over older ones. However, the authors argue that this approach is suboptimal because it cannot simultaneously emphasize both recent and very old gradients. They conduct experiments across language (LLMs and Mamba) and vision (ViTs), showing that AdEMAMix outperforms and tends to forget training data slower than Adam.", "strengths": "- The paper is written and organized well. The optimizer is a simple modification to Adam and is conceptually easy to understand, with the algorithmic differences presented clearly.\n- The authors extensively benchmark their proposed optimizer against Adam across architectures, dataset domains and hyperparameters.\n- The spirit of AdEMAMix (having a ‘fast’ and ‘slow’ tracking of gradients) could be useful for mitigating forgetting in continual learning regimes, and is also shown to be relevant for current pretraining regimes where we often train much longer past eg. Chinchilla-optimal tokens.\n- The authors have experiments which address one’s immediate concerns of the proposed optimizer, like hyperparameter tuning and memory overhead.", "weaknesses": "- The proposed optimizer is not that novel, with several existing approaches in maintaining a longer horizon at the cost of memory which have been mentioned by authors; in particular, AggMo is essentially the same as AdEMAMix but applied to gradient descent, and the only difference in the setting of this work is applying the same principle to Adam.\n- Although the authors have performed experiments on the hyperparameter stability of AdEMAMix, they do introduce two new parameters beta_3 and alpha which require the use of schedulers for larger values to avoid divergence in early iterations, which may require more extensive tuning depending on the setting (eg. fast domain shifts).\n- There is an additional memory overhead in the order of the model size to incorporate the second EMA, which the authors propose can be mitigated by setting beta_1 to be 0; however, these runs exhibit instabilities and sometimes large spikes in training especially at larger batch sizes (Figure 22). These spikes are generally undesirable even though the loss seems to ‘recover’, it is unclear to me that such spikes wouldn’t have more detrimental effects at larger scales.", "questions": "1. Doesn’t the introduction of the additional beta_3 term potentially lead to an exploding step size if the incoming gradients are extremely small for many steps (the additional term in the numerator causes it to decay slower than the denominator in the update)? This isn’t unforeseeable in language model training, where the model could encounter multiple documents of rare tokens consecutively.\n2. Do the authors have any experiments adding an EMA to other optimizers like Adafactor or Signum? Does the additional EMA always provide a gain, even with factored gradients?\n3. Related to the previous question to address the memory overhead, does storing the slow-moving EMA in lower precision affect AdEMAMix performance?\n4. Why was the optimal Adam learning rate being too high for AdEMAMix only manifesting at higher scales? Does this perhaps suggest AdEMAMix being more difficult to tune at larger scales? I’m wondering if there was there further investigation into this eg. Was there a certain part of the network that was destabilizing?\n5. Did the authors use certain stabilization strategies, eg. QK-LayerNorm in your experiments? I’m wondering if this could be used for the author’s benefit, to stabilize training and offer greater ease in hyperparameter tuning. \n6. Why was alpha set to 2 and 1 for the switching Adam -> AdEMAMix experiments whereas the optimal values found in hyperparameter tuning were much higher?"}, {"review_id": "uHi5Cjthff", "reviewer": "Reviewer_A3u9", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 2, "contribution": 3, "summary": "This work proposes a new optimization algorithm named AdEMAMix, a variant of Adam(W). They add an aditional moving average of gradients to the algorithm, on top of the existing first and second moment moment averages. In contrast to the traditional $\\beta_1 = 0.9 $ moving average typically used in Adam, the new moving average that updates much slower ( $\\beta_3 = 0.9999 $). This additional moving average is then used with the traditional moving average to update the parameters, in a similar fashion to the Adam update. The authors argue this assists in preventing the “forgetting” of data and better uses past information than the traditional Adam update, and show strong empiracle performance on some vision and language settings.", "strengths": "The primary strength of this paper apperas to be an algorithm with strong empiracle performance. The authors conduct a thorough set of experiments and ablation in the setting of language modelling and vision using multiple architectures and configuarations. The idea is interesting and combines other ideas from the optimization literature (the Adam update, multiple momentum terms). I am particularly interested in the authors study of how the loss with respect to different batches changes over time depending on the order in which they where trained on. I think this direction is a promising area of research for understanding why some optimizers work better than others. The paper also contains many interesting questions that are of interest to the optimization community, although struggles to answer them (as does the rest of the field).", "weaknesses": "Some specific concerns are addressed in the questions section, but I think the fundamental challenge this paper faces is explaining why the algorithm is better. As is typical with deep learning optimizers, it would be too hard for any provable reason why the algorithm may be superior, but I do think some more effort in the paper could be focused on explaining why this is a good idea, especially with the increased memory requirments. As said in the strengths, I do like the amount of questions being asked but it would be nice to see if the answer to any of those questions is leading to a better optimizer. I think it’s possible that the “forgetting slower” direction may lead to an answer, but the majority of the experiments and argumentation for that is relegated to the appendix. While practically showing plots with better performance is good, from a scientific perspective understanding why that performance is better is important.\n\nI’m slightly concernerd by the number of additional hyperparameters being introduced, which appear to be somewhat sensitive given the use of schedules for most of them. This one can make the algorithm more difficult to use in practice, and two runs the risk of having some lucky configurations for a given problem that works great but in general may not be great. I also feel that the additional memory requirment is a non trivial drawback. While it does not increase communication costs, just running in FSDP is not an easy option for many people who do not have easy access to multi-gpu servers.\n\nA more cosmetic note, the figures in this work can lead to confusion at times. In multiple of the figures, there are three subplots where some and perhaps all are unrelated. In some cases, Some subplots are meant to be related, but use different colours for the same thing. This is also likely the reason many of the figure captions are $\\approx $ half a page. I think it would be valuable for the authors to give the plots another pass before any camera ready version. \n\nOverall despite being a potentially effective algorithm, I do feel that this paper is somewhat lacking in the motivation and justification of that. I would be open to raising my score in the discussion period but in it’s current state I’m not sure this work is ready. I would be interested in why the authors think this is a good idea and if there are better ways to show that to be the case. This does not need to be large experiments or rigorous theory. Well designed small experiments or simple derivations can be just as effective give an intuition to explain the performance on minimizing the loss or solving some problem existing optimizers have.\n\nMore specific questions can be found below.", "questions": "- In figure 2, authors claim the proposed algorithm does not exhibit oscilations when compared to Adam. However, in both 2 (a) and 2 (c) it appears the AdEMAMix does in fact overshoot the minimizer, further than Adam at that. It looks to me that Adam with “default” hypterparamters appears to be the most stable of all the options. In general this is a little concerning as modern loss surfaces will be much more complex than the Rosenbrock function and these large oscilations/overshooting minimizers may be difficult to deal with or even diagnose. A seperate point on this figure but potentially contributing to my confusion is that the colour scheme in the three subplots is not consistent, if the authors could use a shared legend or something along those lines the results may be more interpretable.\n\n- Given this method requires an additional buffer for the second EMA, this authors discuss it potentially being slower and mention that it is also possible to set $\\beta_1 $ to zero to compensate for this. I’m curious if any of the main experiments or the wall time comparison are using this $\\beta_1= 0 $. The hyperparameter configuration for figure 1 uses $\\beta_1 =0.9 $ so I would just want ensure the wall time comparison is too to ensure figure 1 doesnt change substantially if the x axis is changed from steps to hours.\n\n- What does older gradients being outdated mean? Is there any way to quantify this? Any intuition? All the data is weighted equally as far as minimizing the loss is concerned despite being at different points in the loss surface when you evaluate it. In general this whole idea of older/newer gradients being less/more relevant seems vaugly okay, but it would be nice to see a more scientific defintion or explanation. This doesn’t need to be totally rigorus and have proof but the current explanations seem hand-wavey and leads to far more questions than answers.\n\n- Why not just use a normal non exponential moving average? You can have a uniformly weighted moving average that should not diverge, has this been tried?\n\n- Is there any intuition for why $\\alpha $ is so large (in particlar $\\ge1 $?) Does this lead to a smaller step size being necessary? Can you also set it to zero until some point to fight early instability?\n\n- Why is there no bias correction on $m_2 $? I don’t really think it matters much much and most analysis ignores it but I’m curious why that decision was made in comparison to the typical Adam bias correction step.\n\n- Regarding figure 4, while it does appear that the proposed algorithm “forgets” slightly less quickly than AdamW, is the claim that this is why it works better? It appears to me that the biggest predictor of “forgetting” is the injection time $t_B $ rather than the algorithm, although perhaps if this minor difference is compounded across batches this leads to the superior performance of the algorithm. \n\n- In many experiments, notably figure 5, the x axis is limited to not show the optimization performance in earlier parts of training. Given that instability is cited as a challenge for this algorithm, to a point where complex schedules need to be used, I’m curious how unstable the well tuned algorithm is. Are there instability issues in the 0-200k range in figure 5 (b) with AdEMAMix from 0 for example?\n\n- Recent analysis in [1] has shown Adam(W) should work well if the gradient and Hessian become correlated, and [2] provided a mechanism where that shows up in language models. Can the authors reconcile this with their modifications to the algorithm? It seems to me those works would suggest this new variant is not a great idea so I’m wondering if the authors have any thoughts for why that isn’t the case.\n\n- Any idea why there is less good performance on imagenet 1k? It seems AdamW is performing better there although all the plots are fairly close.\n\n\n[1]\nMichael Crawshaw, Mingrui Liu, Francesco Orabona, Wei Zhang, Zhenxun Zhuang\n\nRobustness to Unbounded Smoothness of Generalized SignSGD\n\nhttps://openreview.net/forum?id=8oj_2Ypp0j\n\n[2]\nFrederik Kunstner, Robin Yadav, Alan Milligan, Mark Schmidt, Alberto Bietti\n\nHeavy-Tailed Class Imbalance and Why Adam Outperforms Gradient Descent on Language Models\n\nhttps://arxiv.org/abs/2402.19449"}], "discussions": [{"comment_id": "IudLCzC0DW", "replyto": "aqsN4kN0Et", "author_type": "authors", "reviewer": null, "comment": "Now answering other points raised by the reviewer:\n\n1. Our Rosenbrock experiments only allow us to gain a bit of intuition by illustrating how (i) mixing slow and fast changing signals visibly results in less oscillations, (ii) using a single EMA with a large beta is not working, and (iii) AdEMAMix finds \"good\" solution relatively early. Beyond those observations, as mentioned before, the behavior on Rosenbrock is not necessarily an accurate representation of a high dimensional optimization landscape. The focus of this work is to develop a better optimizer for deep learning. \n2. In line 249, we are talking about any exponential moving average over gradients. We are conveying that—in order to use an EMA with a very large beta—it is important to have another term which keeps its sensitivity to recent gradients. \n\n3 & 5. We thank the reviewer for their interaction and providing suggestions on how to further improve our work."}, {"comment_id": "aqsN4kN0Et", "replyto": "VDwpDjPeJs", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for their reply. We first would like to focus on addressing a key misunderstanding:\n\n> In rebuttal, the authors mentioned that \"we propose a simpler solution which consists in setting beta1=0. \". This makes me quite confused. When beta1= 0, how is AdEMAMix different from AdamW with beta1 = 0.9999? They seem to be the same up to a constant alpha, right?\n\nThis is not the case. When $\\beta_1=0$, $m_1=g$, with $g$ being the gradient. Therefore, $m_1+\\alpha m_2$ in AdEMAMix becomes $g + \\alpha m_2$. This is quite different from Adam, and still fits our narrative of balancing two signals, one sensitive to local variations of the loss landscape ($g$), and one incorporating information from older gradients ($m_2$). \n\n> Further, if AdEMAmix (beta1 = 0) works as well as beta1 >0, then why bother introducing the idea of \"balancing fast and slow changing signals\"? It seems that we do not need to balance anything at all. This further makes me confused about the motivation of the paper.\n\nNow that we have established that AdEMAMix with $\\beta_1=0$ is different from Adam, we can ask if there is an advantage to using $\\beta_1>0$ instead of always $\\beta_1 = 0$? The answer is yes. In App.C.1.8, we show that $\\beta_1 = 0$ can be less stable. For this reason, we kept the two EMA formulation in our method, as there might be cases where $\\beta_1>0$ is preferred. In general, we recommend using $\\beta_1>0$ unless memory is an issue."}, {"comment_id": "ZfFwcqATU5", "replyto": "ZeTAVNidfE", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for their reply. \n\n1. We would like to bring some final clarifications on the following point:\n\n> if the proposed solution for reducing memory overhead by setting beta1 = 0 is sufficient, then why do the authors propose generally having the two momentum signals? Is there an explanation for why setting beta1=0 works?\n\nWe provide an explanation in our introduction, line 85 to 90, quoted here for convenience:\n\n> We observe that a single EMA cannot both give a significant weight to recent gradients, and give a non-negligible weight to older gradients (see Fig. 3a). However, a linear combination between a “fast-changing” ***(e.g. β1 = 0.9 or β1 = 0)*** and a “slow-changing” (e.g. β = 0.9999) EMA allows the iterate to beneficiate from (i) the great speedup provided by the larger (slow-changing) momentum, while (ii) still being reactive to small changes in the loss landscape (fast-changing). \n\nThe core of the method is not to require two EMAs. It is to combine a term gathering information from old gradients ($m_2$), with a term that remains sensitive to local loss variations ($m_1$). When $\\beta_1 = 0$, $m_1$ is simply the gradient $g$, so $m_1+\\alpha m_2$ becomes $g+\\alpha m_2$. Following only the direction of $m_2$ (removing entirely $m_1$), would not work, as shown in our Rosenbrock experiments (Fig.2) and language modeling experiments increasing $\\beta_1$ in Adam (Fig.20 and Fig.21). \n\nNow, is there an advantage to using $\\beta_1 = 0.9$ instead of $\\beta_1 = 0$? The answer is yes. In App.C.1.8, we show how $\\beta_1 = 0$ can be more unstable. For this reason, we kept the two EMA formulation in our method, as there might be cases where $\\beta_1 > 0$ is preferred. In general, we recommend using $\\beta_1 > 0$ unless memory is an issue. \n\n2. Concerning your second point:\n\n> There are also other barriers for its use more generally, eg. in fast domain-shift settings where the question of setting a performant scheduler will likely be more difficult.\n\nOur work focuses on a very standard and widely adopted optimization setting, which consists of training models on static datasets. This setting is the one currently used throughout the industry to train models such as large LLMs, state of the art vision models and more. We agree studying our method in domain-shift settings is an interesting direction for future work, yet this is outside of the scope of our work."}, {"comment_id": "ZeTAVNidfE", "replyto": "n7oWHvF6Qp", "author_type": "reviewer", "reviewer": "Reviewer_98Nr", "comment": "I thank the authors for their response and for answering my questions. I have read the other reviews and responses and I do recommend acceptance for the paper. I agree that while the notion of adding more momentum terms is not new, there is value in demonstrating an algorithm which is performant in more practical settings. I will maintain my score as weak accept due to similar comments brought up by the other reviewers; for instance, if the proposed solution for reducing memory overhead by setting beta1 = 0 is sufficient, then why do the authors propose generally having the two momentum signals? Is there an explanation for why setting beta1=0 works? There are also other barriers for its use more generally, eg. in fast domain-shift settings where the question of setting a performant scheduler will likely be more difficult."}, {"comment_id": "VDwpDjPeJs", "replyto": "n0FXVCyfdP", "author_type": "reviewer", "reviewer": "Reviewer_zVt6", "comment": "Thanks for the rebuttal and I sincerely apologize for the late reply. I have carefully read the rebuttal. I still think this is a **professionally-written paper with strong empirical evidence, yet the motivation is still a bit weak**. I vote for acceptance and I will keep my score. \n\nHere are some follow-up comments. I think they would be helpful to improve the paper quality. \n\n1. **Regarding my Q4:** I am clearly aware of \"oscillations of momentum methods\". This is not my comment. My comment is \"it is not clear whether we can draw such a conclusion based on the figure. In my opinion, at least 10 x more iterations are needed to draw valid conclusions.\". Anyhow, the authors did not provide new experiments. \n\n2. **Regarding paper presentation**. Some notations are unclear and are not consistent. For instance, in line 249 \"This allows for the use of much larger beta values e.g. 0.9999\" what is beta here? Is it beta1 or beta3?\n3. **Regarding motivation.** I still think it is too weak to motivate using Rosenbrock function. Here are my suggestions: please try linear model + 2-class cross-entropy loss classification, which also have \"river-valley\" landscape. Show that AdEMAmix has an advantage on this task (perhaps a bit more theory would be cool), then generalize to 1-hidden-layer-NN + cross-entropy or 1-layer-Transformer + cross-entropy. These experiments and discussions will provide much stronger insight and motivation. At least much better than Rosenbrock.\n\n3. **Regarding your rebuttal.** In rebuttal, the authors mentioned that \"we propose a simpler solution which consists in setting beta1=0. \". This makes me quite confused. When beta1= 0, how is AdEMAmix different from AdamW with beta1 = 0.9999? They seem to be the same up to a constant alpha, right? If so, then does it mean that AdamW (beta1 = 0.9999) works better than AdamW (beta1 = 0.9)? However, the authors also claim that AdamW (beta1 = 0.9999) does not work well. So I am confused what is going on here.\n\n Further, if AdEMAmix (beta1 = 0) works as well as beta1 >0, then why bother introducing the idea of \"balancing fast and slow changing signals\"? It seems that we do not need to balance anything at all. This further makes me confused about the motivation of the paper.\n\n4. **Missing important discussions.** It is good to have the discussion on line 407 \"Why not simply increase AdamW's beta1?\" But this discussion is rather numerically guided. I suggest adding a rigorous math clarification that \"increasing beta1\" is NOT equivalent to \"linear combination over an additional EMA copy\". (I assume they are indeed different, right? I didn't have time to carefully check)"}, {"comment_id": "1Ed9HLmwr2", "replyto": "Jy2NhsGg3c", "author_type": "reviewer", "reviewer": "Reviewer_H5Kk", "comment": "Thank you for the response, I maintain my positive assement of the work."}, {"comment_id": "uKj0eRSAZi", "replyto": "uHi5Cjthff", "author_type": "reviewer", "reviewer": "Reviewer_A3u9", "comment": "I thank authors for their responses. In regards to their question about why the work I cited may not agree with their approach, I preface  by saying this is my  inutition and not rigorous. In [1], it is shown that Adam will benefit when the norm of the gradients and the trace of the Hessians become correlated. Then, [2] shows a situation where this occurs naturally where Adam  is superior to SGD. My intuition is that, since now we have two EMAs (with the slower changing one being more dominant due to $\\alpha\\ge1$) for the gradients but only one for the squared  gradient, this correlation will be weakened. In future work, it may be a good experiment to run in order to figue out if this algorithm is effective for similar  reasons to Adam. If the authors  have any thoughts on this I would be interested.\n\nOverall I do still feel that there could more investigation of why this algorithm works better even with the response probvided by the authors. I will  raise my score to a weak accept based on the authors responses and hopw that that question will be addressed in future work. I would again like to point out that some of the figures a bit confusing as is stated  in my original review,  so if this work ends up getting accepted I reccommend the authors take  a  second  pass on  them."}, {"comment_id": "mPvWwAr0fo", "replyto": "KJGzAu95vU", "author_type": "authors", "reviewer": null, "comment": "8. In Fig.5, the curves partially shown for AdamW and “AdEMAMix from scratch” are similar to the ones in Fig.1. Curves for Fig.5.b reach similar losses as the curves in Fig.1.a trained for 500k iterations. Similarly, the two curves from Fig.5.c reach similar losses as the curves in Fig.1.c trained for 770k iterations. For historical reasons, Fig.5 relies on a larger clipping, which does not affect the performance but smoothes the curves. We thank the reviewer for raising this point and will update the description of our experimental protocol to clarify this point. \n9. We thank the reviewer for sharing those interesting works. We would appreciate it if the reviewer could elaborate on which part of those works suggests that our method should not work well? \n10. On imageNet 1k, Fig.31 might paint a clearer picture of what is happening. Looking at the test and training losses on the last column, we see that while the training and test losses decrease faster for AdEMAMix, the test loss shows a clear overfitting pattern. We believe AdEMAMix to be best suited in cases where enough data is available w.r.t. the capacity of the model. \n\nWe hope we have addressed the concerns of the reviewer. We stay at their disposal for any further questions and hope the above response would bring the reviewer to raise their score. \n\n[1] Wen et al., 2024: “Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective”\n[2] Flammarion et al., 2015: \"From averaging to acceleration, there is only a step-size.\""}, {"comment_id": "KJGzAu95vU", "replyto": "uHi5Cjthff", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for the time taken to review our work. On the intuition behind our method. We provide an intuition in our Rosenbrock experiments. In those experiments, the fast EMA (m1) helps the algorithm stay in the valley, and the slow EMA (m2) helps the algorithm crawl along the valley, remembering the valley's direction and accelerating. As pointed out by reviewer GCtm, a recent line of work [1] suggests a “river-valley” structure of the loss landscape of LLMs, which does share some similarities with the Rosenbrock function. We provide another intuition through studying the forgetting of the training data during training (Fig.4). This shows how the larger momentum can help the algorithm to be more data-efficient. \n\nNow answering the reviewer’s questions:\n\n1. In Fig.2.a, we indeed observe the distance to the solution going down and then up again. This is due to the momentum values causing the iterates to overshoot the solution. When we mention the lack of oscillations for AdEMAMix, we refer to the trajectory not oscillating around the bottom of the valley. The inability of Adam to follow a similar trajectory when using a large $\\beta_1$ is what we want to focus on. It shows that—for larger momentum values—the iterate fails to respond to local variations of the loss landscape. This motivates us to also include a momentum term with a small $\\beta$ as in AdEMAMix. The oscillations of momentum methods around the solution on deterministic objectives are well documented, see e.g. https://francisbach.com/continuized-acceleration/, https://distill.pub/2017/momentum/, or [2, fig5]. In practice, to the best of our knowledge, oscillations due to momentum have never been an issue in deep learning training (i.e. the phenomenon of overshooting the solution seems to be only observed in toy settings, see Fig.20 and Fig.21 to see the lack of oscillations even with very large $\\beta_1$ in a real world setting using Adam). This is why we focus on the left part of the curves, corresponding to the first time required to arrive in the vicinity of the solution. \n2. The times in Fig.5.a use $\\beta_1=0.9$.\n3. Given a past timestep $t$, a sample $x$, the gradient $\\nabla_\\theta \\ell(x, \\theta_{t})$ is outdated at time $t+K$ if its inner product with $\\nabla_\\theta \\ell(x, \\theta_{t+K})$ is non-positive. Such a gradient would be outdated as it pushes the iterate in a direction that is no longer relevant to decrease the loss on $x$. A simple illustration could be overshooting the minimum of a quadratic function. Gradients are local approximations, we mention outdated gradients as it is surprising for us that one can take advantage of very old gradients. Our intuition was that those would no longer be relevant given the local loss landscape. We believe that the fact this raises many questions is a good thing.\n4. In a limited setting, we also tried with uniform averages, we found those to also improve over the baseline but much less. We therefore decided to focus on EMAs. \n5. We do use a linear scheduler on alpha. Starting from $0$, alpha grows up to its final value over the same number of iterations as $\\beta_3$. Setting it to $0$ during the early iterations can work, this corresponds to our experiments on switching from AdamW to AdEMAMix (Fig.5.b and Fig.5.c). Depending on when the switch occurs, the final loss—while better than the baseline—might not be as good as training AdEMAMix from scratch. In Fig.5.b and Fig.5.c, we do not use any schedulers on beta3 nor alpha. \n5. The bias correction is there to compensate for the initialization of the buffers to $0$. In our case, it is desirable to not do the correction as it implies the values of m2 are initially very small, and increase little by little. Intuitively, this means that, for the early steps, AdEMAMix behaves similarly to AdamW. Moreover, when switching from AdamW to AdEMAMix, the effective learning rate increases, which we believe explains the small increase in loss that can be observed right after the switch in Fig.5.b and Fig.5.c. Not doing any bias correction might help to smooth the transition.\n7. We were curious on what was explaining the correlation between forgetting and iterations and ran additional experiments in App.C.1.3 (see Fig.11). In that experiment we use a WSD learning rate scheduler instead of a cosine decay. This allowed us to show that the forgetting is tied not only to the optimizer, but also to the learning rate. Setting this dependency aside, it is clear from Fig.4 that AdEMAMix forgets training batches slower than AdamW."}, {"comment_id": "n7oWHvF6Qp", "replyto": "BwbjzTiqBM", "author_type": "authors", "reviewer": null, "comment": "Now answering the reviewer’s questions:\n\n1. While this is possible, we did not observe this in practice. Given the large values of $\\beta_2$ and $\\beta_3$, the contribution of each individual gradient is very small. For an exploding step size to occur, small gradients would need to be observed over many thousand consecutive steps, which is unlikely if the data is shuffled. Moreover, the epsilon on the denominator can be used to prevent the updates from growing too much. Another related problem is the sensitivity to large gradient norms. Given the large $\\beta_3$, outlier gradients can have a long lasting and detrimental effect. In our experiments, we found that using gradient-clipping can be important to mitigate this issue. We thank the reviewer for raising this point and will add a paragraph detailing these two cases in our next revision. \n2. We did not try adding an additional EMA to other optimizers like Adafactor or Signum. Our focus was on showing convincingly that our approach works well in standard settings. \n3. This is an interesting suggestion. Our work aims to introduce a novel method and show it performs well in standard settings. While we provide some direction to reduce the memory footprint, we do not believe the memory overhead is entirely challenging the usability of the method in many cases. As such, we leave further memory optimizations as future work. \n4. For our 1B parameter experiments, when using a learning rate of 5e-4 for AdEMAMix, we observed instabilities in the form of gradient norm spikes. We did not deep dive into our model to check if a certain subpart of the model was causing this. We believe those to be caused by the slight effective learning rate increase resulting from adding alpha*m2 to m1. Our results at a larger 3B scale in App.C.1.2 uses the same learning rate for both AdamW and AdEMAMix. \n5. This is a great suggestion. Given our goal to showcase our optimizer in standard settings, we used vanilla architectures in all of our experiments, which do not include QK-LayerNorms. In practice, we did not find it difficult to tune the hyperparameters for AdEMAMix. For instance, results on ViT models in Fig.6 (or Fig.31), show AdEMAMix models trained with the same hyperparameters as the best baseline, using different beta3 and alpha. Given enough data, it is easy to find a combination outperforming Adam. \n6. When switching from Adam to AdEMAMix, the effective learning rate suddenly increases, which we believe explains the bump that can be seen right after the switch. Larger alphas can imply larger bumps, which—in our experiments—take longer to recover from. Given more iterations, the model recovers. The values used were small enough to allow the model to recover well given the remaining number of iterations. Interestingly, App.C.1.4 details the reverse experiment, switching from AdEMAMix to AdamW. We observe a drop in loss immediately after the switch, likely explained by a drop in effective learning rate. \n\nWe hope we have addressed the reviewer’s concerns. We would appreciate it if the reviewer could consider raising their score and stay at their disposal in case any further questions need to be answered."}, {"comment_id": "BwbjzTiqBM", "replyto": "lw3RysqbTt", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for the time taken to review our work. First, responding to the weaknesses identified by the reviewer:\n\n1. On the novelty of our approach compared to AggMo. While the AggMo method introduced by Lucas et al. also adds additional EMAs with larger beta values, we strongly disagree that it implies our work lacks novelty. As discussed in the related work, the claims from Lucas et al. are that additional momentum terms speed up convergence and hyperparameter stability. In our work we show our optimizer is not only converging faster, but reaches better solutions (lower loss). Moreover, Lucas et al., relies on small-scale settings where SGD works well, using MNIST, CIFAR-10 and CIFAR-100, and training LSTM LMs on the Penn Treebank dataset. In contrast, we focus on significantly larger real-world settings. Those larger settings heavily favor Adam over SGD. Getting our optimizer to work in those settings was not as easy as adding an additional momentum buffer, it required tackling training instabilities through deriving schedulers, without which the method would not work. Finally, AggMo is suggesting using many (e.g. 4) momentum terms. We show in App.C.3.3 that using more than two momentum terms in AdEMAMix is not providing any benefit. \n2. On using AdEMAMix in fast domain-shift settings. This is a valid concern for that setting. We did not study non-stationary training distributions and leave that question for future work. Concerning schedulers, we observed that our design-choice of always setting $T_{\\alpha,\\beta_3}=T$ yielded stable training runs without additional validation, even when switching across model size, to state space models and vision models. \n3. We will change the text to say that using beta_1=0 needs to be confirmed at larger scale. We want to emphasize that, in many cases, the memory overhead is not a critical issue. Especially, in the case of distributed training, the memory overhead can be mitigated by sharding the optimizer state across devices."}, {"comment_id": "PDGujsJmj5", "replyto": "NEAO3UGAgL", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for the time taken to review our work.\n\n1. App.C.1.5 studies extensively the sensitivity to hyperparameters. The sensitivity to the mixing ratio alpha is shown in Fig.13.a. It is shown that the range of alpha values outperforming the Adam baseline is very wide. \n2. While a theoretical proof of convergence of our method would be welcome, deriving such proof poses significant challenges. We can take as an example Adam, which convergence has been challenged by Reddi et al. [1], which still failed to explain convergence for hyperparameter values used in practice as discussed in [2]. Despite the enduring gap between theory and practice, Adam remained the workhorse of deep learning optimization. Understanding it from a theoretical standpoint is a line of research on its own. As such, we believe providing convergence guarantees—albeit desirable—lies outside the scope of our work. This being said, convergence bounds in convex settings for a simpler method (AggMo) combining GD with a linear combination of EMAs have been shown in [3]. \n\nWe hope we have addressed the reviewer’s concerns. Sensitivity to hyperparameters is studied in App.C.1.5, and providing theoretical backing seems tedious given a theoretical understanding of even Adam is still an active field of research. We hope the reviewer will consider raising their score or provide further details justifying their rejection of our work.\n\n[1] Reddi et al., 2019: On the Convergence of Adam and Beyond\n[2] Zhang et al., 2022: Adam can converge without any modification on update rules.\n[3] Lucas et al., 2019: Aggregated momentum: Stability through passive damping"}, {"comment_id": "Jy2NhsGg3c", "replyto": "HZGoTxkYT6", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for the positive appreciation of our work. \n\nIn Fig.17 we keep the total number of training tokens constant, and vary the number of steps (in {$32k, 64k, 128k, 256k$}). To keep the number of training tokens constant, this forces us to increase the batch size as we decrease the number of steps. Looking at Fig.17.a and Fig.17.b, we observe that both methods suffer when we trade a large number of iterations for a larger batch size and fewer steps. AdEMAMix—while still outperforming Adam—is more affected. To understand this phenomenon, we notice that when we do fewer steps, we have fewer gradients to accumulate in $m_2$. Given our very large $\\beta_3$ values (e.g. $0.9999$), this can become a problem. We show in Fig.18 that the problem is mitigated by reducing the $\\beta_3$ value to e.g. $0.999$. Interestingly, comparing AdamW in Fig.17.a and AdEMAMix with $\\beta_3=0.999$ in Fig.18.b, we can see that AdEMAMix is now less affected than AdamW when we trade a large number of iterations for a larger batch size and fewer steps (the increase in final loss is smaller for AdEMAMix when increasing the batch size). In general, from our experiments on both images and text, we did not notice any disadvantage of ADEMAMix over AdamW when increasing the batch size. Our largest experiments, in App.C.1.2, train 3B parameter models using a batch size of $1024^2$ tokens. We still observe improvements. \n\nWe thank again the reviewer for appreciating our work and stay at their disposal to answer any further questions."}, {"comment_id": "ykGtrKdIAX", "replyto": "5ycVPKvn1k", "author_type": "authors", "reviewer": null, "comment": "5. Combining AdEMAMix with Adam-mini is indeed an interesting research direction. In our work, we propose a simpler solution which consists in setting beta1=0. This means that m1 is replaced by the gradient, and the memory cost of AdEMAMix is then equal to Adam. We show in App. C.1.8 that this strategy works in most cases. Investigating more elaborated memory saving strategies is an interesting future work direction. We will cite Adam-mini in our next revision. \n6. We thank the reviewer for sharing this work, we will cite it in our next revision. \n\nWe hope we have addressed the reviewer’s concerns and conveyed the intuition behind our method. We hope the reviewer will consider increasing their score and we stay at their disposal to answer any further questions. \n\n[1] Wen et al., 2024: Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective\n[2] Flammarion et al., 2015: \"From averaging to acceleration, there is only a step-size.\" Conference on learning theory."}, {"comment_id": "5ycVPKvn1k", "replyto": "n0FXVCyfdP", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for taking the time to review our work. \n\n1. We agree that the gap between Rosenbrock and LLM training is important. We do not claim that this toy setting captures the complexity of LLM training dynamics. Instead, it serves as an illustration, showing that—quite remarkably—our method can be motivated even in a noise-free toy setting. Moreover, while understanding the training dynamics of LLMs is still an active line of research, a recent line of work [1] (also pointed out by reviewer GCtm) suggests a “river-valley” structure of the loss landscape of LLMs, which does share some similarities with the Rosenbrock function. We will cite this work in our next revision. \n2. We are not exactly sure where our \"effort into designing more experiments\" should concentrate. Could the reviewer specify the type of experiments that would be needed? We provide two key intuitions that help understand the reason behind the superiority of our approach. First, while limited, our Rosenbrock example shows that using an EMA with a large beta can be beneficial yet needs adjustment of the optimizer (it doesn’t work to simply increase beta1 in Adam). Saying that “$\\beta_1=0.9$ is suboptimal for LLMs” is ambiguous, and this is not what we aim to convey. It seems $\\beta_1=0.9$ is optimal for Adam (see Fig.20). Using Adam with $\\beta_1=0.9999$ does not work, which makes sense as $m_1$ is no longer responsive to local variations of loss landscape, therefore Adam with $\\beta_1=0.9999$ fails to optimize the underlying function (Fig.3.a). Our core message is that this can be solved by combining it with a term that stays responsive to the local loss landscape (e.g. an EMA with a small beta). Then, it is possible to gain from using much older gradients. In the Appendix, section C.1.7 is entirely devoted to limitations of using a single EMA in Adam. In that section, we show increasing $\\beta_1$ in Adam does not work, even if we bypass the initial training instabilities and start from a pretrained AdamW checkpoint (Fig.21), even when adding schedulers on $\\beta_1$ as for AdEMAMix (Fig.20.b). Even when increasing $\\beta_2$ from $0.999$ to $0.9999$ to stabilize training (again Fig.20.b). The second intuition we provide relates to our analysis of forgetting (Fig.4). We show that larger beta values as in AdEMAMix can be more data-efficient, improving the final loss on training samples when compared to Adam. While we concede that our understanding of the phenomenon and motivations are not exhaustive, we nonetheless prove empirically that more historical gradients can be used efficiently, while providing several possible justifications. \n3. The larger the beta, the smaller the contribution of each gradient, and therefore many gradients are needed to change the direction of $m_2$. In contrast, changing $m_1$ only requires a few iterations. Therefore, updating the weights in a direction going against $m_2$ requires pushing against $m_2$ for many iterations. In contrast, updating the weights in a direction orthogonal to $m_2$ is easy. In essence, what we are trying to convey is: the fast EMA ($m_1$) helps the algorithm stay in the valley, and the slow EMA ($m_2$) helps the algorithm crawl along the valley, remembering the valley's direction and accelerating. We will clarify in our next revision. \n4. Fig.2 aims to show that AdEMAMix can use larger momentum values to reach good solutions faster, with less oscillations. In Fig.2.a, we observe the distance to the solution going down and then up again. This is due to the momentum causing the iterates to overshoot the solution. The oscillations of momentum methods on deterministic objectives are well documented, see e.g. https://francisbach.com/continuized-acceleration/, https://distill.pub/2017/momentum/, or [2, Fig.5]. In practice, to the best of our knowledge, oscillations due to momentum have never been an issue in deep learning training (i.e. the phenomenon of overshooting the solution seems to be only observed in toy settings, see Fig.20 and Fig.21 to see the lack of oscillations even with very large $\\beta_1$ in a real world setting). This is why we focus on the left part of the curves, corresponding to the first time required to arrive in the vicinity of the solution. Ultimately, all the methods tested converge to the solution given enough iterations, we only claim that AdEMAMix can find good solutions relatively fast."}], "meta_review": {"metareview": "This paper proposes a new heuristic to improve the momentum-based methods. In particular, it introduces another moving average sequence of stochastic gradient with a larger momentum parameter and adds it to the standard EMA sequence with a large weight. The paper has demonstrated superior performance in various settings. All reviewers agree that the paper has done a great work in demonstrating the effectiveness of the paper. However, a concern is that the paper does not provide any convergence guarantee of the proposed method. Hence, I will recommend a weak acceptance. The authors should take reviewers' comments into account for improving their paper.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "kOJf7Dklyv", "forum_url": "https://openreview.net/forum?id=kOJf7Dklyv", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "Air Quality Prediction with Physics-Guided Dual Neural ODEs in Open Systems", "authors": ["Jindong Tian", "Yuxuan Liang", "Ronghui Xu", "Peng Chen", "Chenjuan Guo", "Aoying Zhou", "Lujia Pan", "Zhongwen Rao", "Bin Yang"], "abstract": "Air pollution significantly threatens human health and ecosystems, necessitating effective air quality prediction to inform public policy. Traditional approaches are generally categorized into physics-based and data-driven models. Physics-based models usually struggle with high computational demands and closed-system assumptions, while data-driven models may overlook essential physical dynamics, confusing the capturing of spatiotemporal correlations. Although some physics-guided approaches combine the strengths of both models, they often face a mismatch between explicit physical equations and implicit learned representations. To address these challenges, we propose Air-DualODE, a novel physics-guided approach that integrates dual branches of Neural ODEs for air quality prediction. The first branch applies open-system physical equations to capture spatiotemporal dependencies for learning physics dynamics, while the second branch identifies the dependencies not addressed by the first in a fully data-driven way. These dual representations are temporally aligned and fused to enhance prediction accuracy. Our experimental results demonstrate that Air-DualODE achieves state-of-the-art performance in predicting pollutant concentrations across various spatial scales, thereby offering a promising solution for real-world air quality challenges.", "keywords": ["Air Quality Prediction; Physics-guided Deep Learning"], "primary_area": "learning on time series and dynamical systems", "pdf_url": "https://openreview.net/pdf?id=kOJf7Dklyv", "decision": "Accept (Poster)", "num_reviews": 5, "num_discussions": 38, "reviews": [{"review_id": "Ys1pxLTDIm", "reviewer": "Reviewer_6Tsm", "rating": 6, "confidence": 5, "soundness": 2, "presentation": 2, "contribution": 2, "summary": "This paper proposes Air-DualODE, a novel approach for air quality prediction that combines physics-based and data-driven methods using dual Neural ODEs. The physics branch implements a modified diffusion-advection equation with a correction term for open systems (BA-DAE). In contrast, the data-driven branch employs masked attention-based Neural ODEs to capture unknown dynamics. The two branches are temporally aligned using a decaying contrastive learning scheme and fused in latent space using GNN, demonstrating superior performance on city-scale (Beijing) and national-scale (KnowAir) datasets.", "strengths": "1. The paper addresses the limitations of pure physics-based and pure data-driven approaches by proposing a hybrid framework that attempts to leverage the advantages of both methods.\n\n2. The introduction of BA-DAE with a correction term represents an attempt to model open system dynamics, which is more realistic for air quality prediction than traditional closed system assumptions.\n\n3. The model achieves state-of-the-art performance across different spatial scales while maintaining some level of interpretability through its physics branch.", "weaknesses": "1. The paper oversimplifies complex air pollution dynamics using a linear correction term (βX) without proper theoretical justification, undermining its claim of accurate open system modeling.\n\n2. The approach loses physical interpretability when projecting to latent space and violates conservation laws, raising concerns about numerical stability and contradicting the paper's emphasis on physics-informed modeling.\n\n3. The computational efficiency claims are questionable as the dual branch architecture with multiple ODE solvers likely increases computational burden rather than reducing it.\n\n4. The experimental validation is limited, with case studies confined to Beijing data and lacking crucial analyses such as parameter sensitivity testing and solver comparisons.\n\n5. The technical documentation is incomplete, with key mathematical elements missing from figures and insufficient details about architectural choices, making reproducibility challenging.", "questions": "Q1. Given that real air pollution sources (industrial activity, vehicle emissions) and sinks (forests, lakes) exhibit complex non-linear relationships, why did you simplify the correction term $\\beta X$ as a linear term? What is the physical justification for setting $\\beta$'s range to $[−1, +∞)$?\n\nQ2. While you claim that the Physics branch explicitly models physical phenomena, how is this physical interpretability preserved when projecting into latent space?\n\nQ3. Regarding the temporal alignment process using Decay-TCL, how do the chosen values of $\\lambda_1 = 1$ and $\\lambda_2 =0.8$ guarantee physically meaningful alignment? What is the physical significance of using time-decaying weights?\n\nQ4. Why specifically choose Spatial-MSA in the Data-Driven branch? How does this align with a physics-informed approach?\n\nQ5. The authors justify GNN fusion based on 'distance-dependent influence', but isn't this characteristic already considered in the Physics branch?\n\nQ6. How can the authors justify the performance on the national-scale KnowAir dataset when case studies are limited to the Beijing dataset?\n\nQ7. How does the visualization of $\\beta$ values correspond to actual observed pollution source/sink data?\"\n\nQ8. Can authors perform sensitivity analysis for different ranges of $\\beta$ values?\n\nQ9. The authors seem to only consider the DOPRI5 ODE solver. Could they analyze performance and runtime differences when using simpler methods like Euler or RK4?\n\nQ10. The paper should reference and compare with recent work on climate modeling using diffusion and diffusion-advection equations in neural ODE frameworks [1,2]. Can authors clarify their position by analyzing similarities and differences in their approach to diffusion and advection?\n\nQ11. How does the intentional violation of conservation law in BA-DAE affect numerical stability, particularly for ODE solvers?\n\nQ12. How do you ensure that the BA-DAE in the Physics branch and Neural ODE in the Data-driven branch operate in the same state space?\n\nQ14. Is 'Physics-Informed' appropriate in the title? Would 'Physics-guided' or 'Physics-inspired' be more accurate, given that this might be confused with traditional PINN approaches?\n\nQ15. Can authors provide more details about the RNN used in the Coefficient Estimator?\n\nQ16. Figure 2 lacks several elements mentioned in the text, particularly α from equation 6. The relationship between Dynamics fusion and Section 3.4 equations needs clarification.\n\nQ17. Can authors provide visualizations or distribution analyses showing how Gdiff and Gadv change dynamically with wind speed and direction?\n\nQ18. While criticizing the computational cost of existing physics-based methods, how does your dual branch architecture with a complex fusion mechanism improve efficiency? Doesn't using two ODE solvers increase computational burden?\n\nQ19. Can the authors include the number of forward evaluations (NFE) comparisons in Table 2's ablation studies?\n\nQ20. The authors should cover related work on GNNs that redesign the diffusion equation [3,4] and its variations[5,6] using NODE. Can you discuss more about what the authors' methods have in common and what they differ from?\n\n> [1] Choi, Hwangyong, et al. \"Climate modeling with neural advection-diffusion equation.\" Knowledge and Information Systems 65.6 (2023): 2403-2427.\n> \n> [2] Hwang, Jeehyun, et al. \"Climate modeling with neural diffusion equations.\" 2021 IEEE International Conference on Data Mining (ICDM). IEEE, 2021.\n>\n> [3] Wang, Yifei, et al. \"Dissecting the diffusion process in linear graph convolutional networks.\" Advances in Neural Information Processing Systems 34 (2021): 5758-5769.\n>\n> [4] Chamberlain, Ben, et al. \"Grand: Graph neural diffusion.\" International conference on machine learning. PMLR, 2021.\n>\n> [5] Thorpe, Matthew, et al. \"GRAND++: Graph neural diffusion with a source term.\" ICLR (2022).\n>\n> [6] Choi, Jeongwhan, et al. \"Gread: Graph neural reaction-diffusion networks.\" International Conference on Machine Learning. PMLR, 2023."}, {"review_id": "bFILgrO9MG", "reviewer": "Reviewer_NWR1", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "This paper introduces Air-DualODE, a novel physics-informed model for predicting air quality that combines the strengths of physics-based and data-driven approaches. Traditional physics-based models are computationally intensive and rely on closed-system assumptions, while data-driven models often lack critical physical insights. Air-DualODE addresses these limitations with a dual-branch design using Neural Ordinary Differential Equations (Neural ODEs): one branch incorporates open-system physical equations to model spatiotemporal dependencies, while the second captures additional dependencies in a data-driven manner. The two branches are temporally synchronized and fused for enhanced predictive accuracy. Experimental results show that Air-DualODE outperforms existing methods in predicting pollutant levels across different spatial areas, making it a robust tool for air quality forecasting.", "strengths": "1) Very well written paper combined with appropriate plots.\n\n2) Recognizing that air pollution transport occurs in an open system, the authors redefine the diffusion-advection equation in explicit spaces. This new formulation, termed BA-DAE, aligns the physical equations more accurately with real-world pollutant transport in open air environments, enhancing the model's applicability and reliability.\n\n3) The proposed model, Air-DualODE, uniquely integrates both Physics Dynamics and Data-Driven Dynamics. This dual-branch structure leverages the advantages of physical models (to capture foundational physical behaviors) and data-driven insights (to adapt to complex patterns not covered by physics alone). This approach represents the first dual-dynamics deep learning model specifically tailored for air quality prediction in open systems.\n\n4) Experimental results demonstrate that Air-DualODE outperforms existing models, achieving state-of-the-art accuracy in forecasting pollutant concentrations across diverse spatial scales, from city-wide to national levels.", "weaknesses": "The dual-branch structure in Air-DualODE, while innovative, adds complexity to the model, potentially making it less interpretable than simpler models. This may pose challenges for stakeholders, such as policymakers, who require clear explanations of how predictions are made. But this is not a very important point neither does it offer any reason to not accept this paper.", "questions": "1) How does Air-DualODE perform in regions with sparse or inconsistent air quality data? Are there mechanisms in place to handle data gaps, or do you recommend a minimum data density for effective predictions?\n\n2) Has Air-DualODE been tested on pollutants other than those mentioned in the paper? Could this framework be adapted to predict other types of environmental data, such as water quality?"}, {"review_id": "uXTRkVcV9N", "reviewer": "Reviewer_v3np", "rating": 8, "confidence": 4, "soundness": 4, "presentation": 3, "contribution": 4, "summary": "To enhance air quality prediction in open systems, this paper proposes a dual Neural ODE architecture named Air-DualODE. It combines a boundary-aware diffusion-advection Equation (BA-DAE) for physical dynamics with a Neural ODE employing masked spatial attention for data-driven dynamics. The framework also incorporates a fusion mechanism that temporally aligns and merges outputs from these dynamics in a shared latent space. Experimental results on both city- and national-level datasets demonstrate state-of-the-art performance.", "strengths": "1.This paper addresses a significant problem, i.e. air quality prediction in open systems. In particular, the model thoroughly considers the non-conservation of pollutant concentration within the region of interest, introducing the BA-DAE to effectively model sources and sinks within the area.\n2.The author propose an interesting Air-DualODE framework that models known physical equations and unknown spatiotemporal dependencies separately, then aligns and fuses them in the latent space. This approach highlights the guiding role of physical knowledge within the model while allowing it to capture unknown spatiotemporal dependencies that may not be described by physical equations.\n3.In experiments, the model achieves superior performance across multiple metrics on different spatial scales datasets. Besides, the provided code enhance this paepr’s reproducibility.\n4.The writing and structure of this paper are clear and easy to understand.", "weaknesses": "1.The discrete diffusion and advection equations are not thoroughly described in this paper. Highly recommend fully explaining in the appendix.\n2.Please elaborate the role of spatial-MSA in data-driven branch presented in Section 3.3.\n3.Check some typos in this paper. For example, in Section 4.2 Table 1, \"Beijing1718\" should be “Beijing”.", "questions": "1.Why using the GNN Fusion after temporal alignment? What about other simple structure like MLP?\n2.Is there other data should be included for geospatial graph construction?"}, {"review_id": "ui3DPmcmG0", "reviewer": "Reviewer_joi5", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "The authors propose a model called Air-DualODE that integrates dual branches of Neural Ordinary Differential Equations (ODEs) to enhance air quality predictions. To address the mismatch between explicit physical equations and implicit learned representations, the paper introduces a hybrid model consisting of two branches.\n\nThe Physics Branch utilizes the Boundary-Aware Diffusion-Advection Equation (BA-DAE) to capture the spatiotemporal dependencies of pollutants. Meanwhile, the Data-Driven Branch captures dependencies not addressed by the physical equations. These dual representations are skillfully fused using a dynamics fusion module with decaying temporal alignment, which enhances prediction accuracy. \n\nThe effectiveness of this method is demonstrated on two widely-used air quality prediction datasets, showing that the model outperforms other existing methods while offering strong interpretability compared to purely data-driven approaches.", "strengths": "- Enhanced Physical Modeling: This paper innovatively transforms closed-system physical equations into open-system physical equations based on existing Physics-Informed Neural Networks (PINNs). This modification aligns more closely with real-world conditions and improves interpretability in boundary regions.\n\n- Effective Integration of Knowledge: The DualODE model skillfully integrates physical knowledge with data-driven methods to bridge the gap between physical equations and real-world data. This dual approach ensures that the model captures dependencies that are not addressed by physical equations alone.\n\n- Advanced Fusion Technique: By employing dynamic fusion to combine the representations from the two ODE models, and further refining the output through a Graph Neural Network (GNN), the model effectively captures real-world patterns. This approach outperforms the simplistic method of directly concatenating the two representations, leading to more accurate and realistic predictions.", "weaknesses": "- Insufficient Detail on Data-Driven Neural ODE: The paper does not provide a detailed description of the data-driven Neural Ordinary Differential Equations (ODE). Additionally, the necessity of using a Neural ODE model as the data-driven approach is not clearly justified, which may leave readers questioning its selection over other potential methods.\n\n- Limited Results Presentation: The results presented in the paper are limited to predictions for 3 days and scenarios of sudden changes. Predictions for shorter time frames, such as 1-day and 2-day forecasts, are not included. Furthermore, there are discrepancies between the results reported for the reproduction of other works and those in the original publications, which could undermine the credibility of the comparative analysis.\n\n- Language Errors in Figures and Tables: There are minor language errors in the figures and tables that could cause confusion. These errors necessitate careful cross-referencing with the accompanying text to ensure clarity and accurate interpretation of the data.", "questions": "- Errors in Figure and Table Text: There are textual and informational errors in the figures and tables. In Figure 2, \"Pollutant Contentration\" should be corrected to \"Pollutant Concentration,\" and there is an error in the expression for Wind speed. Additionally, the last row in Table 2 should be labeled as \"Air-DualODE.\" without “w/o”.\n\n- Incomplete and Redundant Equations in the Appendix: The equations following lines 690-692 in the appendix are incomplete and appear to be redundant with the content in lines 219-220 of the main text. Could these be revised for clarity and completeness?\n\n- Detailed Derivation of Equation 3 and $F^D$: It would be beneficial to provide a more detailed derivation of Equation 3 in the appendix to enhance understanding. Additionally, the $F^D$ formula is not provided. Could you include a more thorough description to facilitate reader comprehension?"}, {"review_id": "lQ2v56zcFV", "reviewer": "Reviewer_iuiU", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 2, "summary": "This paper presents Air-DualODE, a hybrid model that integrates physics-based and data-driven techniques. The proposed model addresses the limitations of traditional physics-based and data-driven models, particularly the assumptions about closed systems and the mismatch between explicit physical equations and learned representations. The model involves two main components, where one component applies open-system physical equations to model spatiotemporal dependencies and learn physical dynamics, while the second component identifies the unmodeled dependencies using a fully data-driven approach. The model achieves superior performance on two real-world datasets at city scale and national scale, compared to the other baselines in this domain.", "strengths": "1. The paper addresses a significant and timely problem.\n2. The write-up style and the narrative of the paper is clear, structured, and professional. \n3. The reported results are very strong compared to the existing baselines.", "weaknesses": "1. The paper does not adequately address the scalability of the proposed model. Given that the approach involves multiple components like ODE solvers and GNNs, the computational cost could become prohibitive when applied to larger datasets with thousands of locations. (ex: Dataset used in AirFormer (Liang et al., 2023)). \n\n2. The overall approach and the methodology of this work closely resembles prior work like AirPhyNet (Hettige et al., 2024), which also integrates physics-based diffusion-advection processes with neural networks for air quality prediction, raising concerns about novelty and significant contributions.\n\n3. Eventhough the work emphasizes about the interpretability of the model, there is limited explanation about how the model’s output could be interpreted in terms of real-world physical phenomena.\n\n4. All the baselines used in the experiments are either fully data driven or hybrid models. Advanced physics based models are not included as baselines and the performance is not compared with the proposed model. Can you evaluate your model’s performance against specific physics-based models, such as Community Multiscale Air Quality (CMAQ) , Weather Reaserch Forecasting (WRF), AERMOD to provide a more comprehensive understanding of how your hybrid model compares in terms of accuracy and computational efficiency?", "questions": "1. The introduction and role of the $\\beta$ term in Equation 6 is not fully clear. Could you provide more details including mathematical justifcation on how this term is derived and its exact significance within the framework? \n\n2. Could you provide more theoretical justification and intuition behind the Decay-TCL mechanism? Specifically, how do the hyperparameters influence the alignment and fusion process between the physics-based and data-driven dynamics?\n\n3. Could you elaborate on how the GNN fusion mechanism effectively balances the two components?\n\n4. How scalable is Air-DualODE to larger datasets with thousands of nodes? Could you provide a detailed complexity analysis of your model as the number of locations (nodes) increases ?\n\n5. Could you provide more clarification on how the predictions can be interpreted and provide specific examples on how these outputs help policy makers or environmental scientists in their decision making process?\n\n6. The proposed model and AirPhyNet (Hettige et al., 2024) employ a similar hybrid architecture of physics based and data driven models , combining GNNs and differential equation solvers. Additionally, both approaches emphasize interpretability through case studies that link predictions to real-world physical phenomena, such as pollutant dispersion influenced by environmental factors like wind speed and direction. Could you elaborate on how your model differs methodologically and in its interpretive capabilities and what specific aspects of your work provide novelty beyond what is addressed in AirPhyNet?"}], "discussions": [{"comment_id": "Ol3WCHvyqY", "replyto": "yj8Qyl8GE4", "author_type": "authors", "reviewer": null, "comment": "We are happy that you are finally satisfied with our revision, and deeply value your engagement and dedication as a reviewer for ICLR 2025. If our revision fully meets your expectations, would you consider giving us an accept instead of marginally above accept?😉"}, {"comment_id": "yj8Qyl8GE4", "replyto": "Ys1pxLTDIm", "author_type": "reviewer", "reviewer": "Reviewer_6Tsm", "comment": "I have read your final response. Thank you for addressing my questions and showing dedication until the very end. Your earnest efforts throughout our discussion have been greatly appreciated.​​​​​​​​​​​​​​​​"}, {"comment_id": "FOvTzUK3W8", "replyto": "h0cufoORNz", "author_type": "authors", "reviewer": null, "comment": "Thank you sincerely once again for raising your score and for recognizing our efforts to improve the manuscript based on your comments during the rebuttal period. Your recognition is the greatest encouragement for our work.\n\n### **Q1**: Could you explain how this difference affects the interpretation of results in the three tables?\n\nSorry for the confusion. Considering the new definition of sudden change, we have completely abandoned the original sudden change settings. Specifically, the original sudden change definition only considered cases where $\\text{PM2.5} \\ge 75$μg/m³. Conversely, the two new sudden change definitions account for the entire range of $\\text{PM2.5}$ levels, including cases where $\\text{PM2.5} \\le 75$μg/m³, which results in the value range appearing smaller compared to the original definition (e.g., in the Beijing dataset, the original definition’s MAE value range is around 67, while the new definition’s MAE value range is around 45). However, for consistency and alignment with the original sudden change’s value range, we provide the change point detection (with the constraint $\\text{PM2.5} \\ge 75$μg/m³) results as follows.\n\n`Table1` The original sudden change definition (cf. lines 399 to 416): \n\n| | Beijing | | | | KnowAir | | | |\n| ----------- | --------- | --------- | -------- | -------- | --------- | --------- | -------- | -------- |\n| | MAE | RMSE | SMAPE | MDA | MAE | RMSE | SMAPE | MDA |\n| Airformer | 68.80 | 91.16 | 0.75 | 0.53 | 39.99 | 55.35 | **0.49** | 0.55 |\n| AirPhyNet | 70.03 | 94.60 | 0.78 | 0.51 | 43.23 | 58.79 | 0.50 | 0.51 |\n| Air-DualODE | **66.40** | **90.31** | **0.73** | **0.57** | **39.79** | **54.61** | **0.49** | **0.57** |\n\n`Table2` The new sudden change definition based on the **Binary Segmentation detection**: \n\n| $\\text{PM2.5} \\ge 75$μg/m³ | Beijing | | | | KnowAir | | | |\n| -------------------------- | --------- | ---------- | -------- | -------- | --------- | --------- | -------- | -------- |\n| | MAE | RMSE | SMAPE | MDA | MAE | RMSE | SMAPE | MDA |\n| Airformer | 85.47 | 105.57 | 0.82 | 0.41 | **50.70** | 66.79 | **0.53** | **0.42** |\n| AirPhyNet | 88.73 | 107.97 | 0.88 | 0.37 | 52.99 | 68.00 | 0.57 | 0.34 |\n| Air-DualODE | **84.29** | **105.31** | **0.81** | **0.42** | **50.67** | **66.13** | **0.53** | **0.42** |\n\n`Table3` The new sudden change definition based on the **Kernel Change Point Detection**: \n\n| $\\text{PM2.5} \\ge 75$μg/m³ | Beijing | | | | KnowAir | | | |\n| -------------------------- | --------- | ---------- | -------- | -------- | --------- | --------- | -------- | -------- |\n| | MAE | RMSE | SMAPE | MDA | MAE | RMSE | SMAPE | MDA |\n| Airformer | 87.96 | **107.54** | 0.84 | 0.44 | 51.74 | 66.95 | **0.50** | **0.41** |\n| AirPhyNet | 91.35 | 110.68 | 0.89 | 0.37 | 52.86 | 67.01 | 0.54 | 0.35 |\n| Air-DualODE | **86.80** | **107.51** | **0.83** | **0.45** | **50.98** | **66.79** | **0.50** | **0.41** |\n\nFrom the above three tables, Air-DualODE is indeed better than the baseline models under different definitions of sudden changes and a variety of evaluation metrics.\n\n\n\n### **Q2**: Can the authors further discuss the potential threats to validity that may arise when using these metrics and two kinds of evaluation methods?\n\nFrom our perspective, sudden change remains a context-dependent concept and does not have a well-recognized definition in the air quality. Essentially, sudden changes represent a subset of the test set that represent the stakeholder’s interests. The subset defined by stakeholders, based on their specific criteria for sudden changes, often includes challenging or noteworthy scenarios. Based on the results of the three subsets mentioned in Q1, we believe that Air-DualODE demonstrates greater robustness compared to the baselines under five different sudden change settings."}, {"comment_id": "h0cufoORNz", "replyto": "Ys1pxLTDIm", "author_type": "reviewer", "reviewer": "Reviewer_6Tsm", "comment": "Thank you for your response. I value your dedication to addressing the limitation I pointed out about sudden changes by incorporating 2 new change point detection approaches and evaluating them using the MDA metric.\n\nIn line 375, you define sudden changes as \"cases where PM2.5 levels exceed 75 μg/m³ and fluctuate by more than ±20 μg/m³ in the next three hours,\" but in your note, you mention, \"we do not impose the constraint of PM2.5 being greater than or equal to 75 when using the algorithm to identify change points.\" \n*Could you explain how this difference affects the interpretation of results in the three tables?*\n\nThe comprehensive analysis provided by the authors confirms that Air-DualODE performs well in terms of sudden changes. *Can the authors further discuss the potential threats to validity that may arise when using these metrics and two kinds of evaluation methods (i.e., binary segmentation and kernel change point detection)?*\n\n---\n\nInitially, I rated this paper at `3`. Through our extensive discussion and your responses to numerous feedback points, I raised my rating to `5`. \n\nNow, with this response showing your readiness to go beyond the conventional evaluation approaches used by Airformer and AirPhyNet by incorporating more sophisticated change point detection evaluation methods and the MDA metric, I am increasing my rating to `6`. This last improvement demonstrates your commitment to advancing the field beyond existing limitations and provides a more rigorous evaluation framework for future research in this area."}, {"comment_id": "g17dLYr5z6", "replyto": "cEMb5d1RXh", "author_type": "authors", "reviewer": null, "comment": "> Given the extended rebuttal period, I encourage you to either: a) strengthen your experimental analysis using these suggested metrics or b) expand your limitations section to acknowledge these evaluation challenges and discuss potential alternative approaches.\n\nRegarding Q6: To follow your suggestion on conducting additional experiments on sudden changes (approach (a)), we use two new sudden change definitions based on two different change point detection algorithms (Binary Segmentation [1] and Kernel Change Point Detection [2]) and use the new evaluation metric `MDA`, as detailed below. \n\n`Table1` The original sudden change definition (cf. lines 399 to 416): \n\n| | Beijing | | | | KnowAir | | | |\n| ----------- | --------- | --------- | -------- | -------- | --------- | --------- | -------- | -------- |\n| | MAE | RMSE | SMAPE | MDA | MAE | RMSE | SMAPE | MDA |\n| Airformer | 68.80 | 91.16 | 0.75 | 0.53 | 39.99 | 55.35 | **0.49** | 0.55 |\n| AirPhyNet | 70.03 | 94.60 | 0.78 | 0.51 | 43.23 | 58.79 | 0.50 | 0.51 |\n| Air-DualODE | **66.40** | **90.31** | **0.73** | **0.57** | **39.79** | **54.61** | **0.49** | **0.57** |\n\n`Table2` The new sudden change definition based on the **Binary Segmentation detection**: \n\n| | Beijing | | | | KnowAir | | | |\n| ----------- | --------- | --------- | -------- | -------- | --------- | --------- | -------- | -------- |\n| | MAE | RMSE | SMAPE | MDA | MAE | RMSE | SMAPE | MDA |\n| Airformer | 45.23 | 65.58 | 0.78 | 0.40 | 20.59 | 32.25 | **0.45** | **0.44** |\n| AirPhyNet | 46.63 | 67.71 | 0.80 | 0.38 | 23.00 | 33.65 | 0.50 | 0.34 |\n| Air-DualODE | **43.47** | **65.35** | **0.76** | **0.45** | **20.13** | **31.56** | **0.45** | **0.44** |\n\n`Table3` The new sudden change definition based on the **Kernel Change Point Detection**: \n\n| | Beijing | | | | KnowAir | | | |\n| ----------- | --------- | --------- | -------- | -------- | --------- | --------- | -------- | -------- |\n| | MAE | RMSE | SMAPE | MDA | MAE | RMSE | SMAPE | MDA |\n| Airformer | 45.43 | 67.33 | 0.74 | **0.46** | 20.50 | 32.09 | 0.45 | **0.42** |\n| AirPhyNet | 47.27 | 69.75 | 0.76 | 0.39 | 22.87 | 33.53 | 0.49 | 0.36 |\n| Air-DualODE | **43.92** | **66.92** | **0.72** | **0.46** | **20.04** | **31.41** | **0.44** | **0.42** |\n\n**Note:** In these two new sudden change definitions, we do not impose the constraint of PM2.5 being greater than or equal to 75 when using the algorithm to identify change points. Therefore, the value ranges differ from those in the original definition.\n\nFrom the above three tables, Air-DualODE is indeed better than the baseline models under different definitions of sudden changes and a variety of evaluation metrics.\n\n> [1]. Scott, Andrew Jhon, and Martin Knott. \"A cluster analysis method for grouping means in the analysis of variance.\" *Biometrics* (1974): 507-512.\n>\n> [2]. Harchaoui, Zaid, Eric Moulines, and Francis Bach. \"Kernel change-point analysis.\" *Advances in neural information processing systems* 21 (2008)."}, {"comment_id": "cEMb5d1RXh", "replyto": "fvS9bZgKmP", "author_type": "authors", "reviewer": null, "comment": "Thank you sincerely for raising your score and for recognizing our efforts on improving the manuscript based on your comments.\n\nLayer normalization discussions: Since the last submission deadline has already passed, we are unable to submit an updated version to OpenReview anymore. Therefore, we plan to include additional detailed discussions according to your suggestions in the final version's Appendix.\n\nWe provide the training times as follows (one epoch).\n\n| Training time | Beijing | KnowAir |\n| ------------------ | ------- | ------- |\n| Air-DualODE-Euler | 53s | 108s |\n| Air-DualODE-RK4 | 68s | 274s |\n| Air-DualODE-Dopri5 | 81s | 378s |\n\nThis result shows that dopri5 requires longer training time, but considering that training is conducted offline, sacrificing training time for improved model accuracy is acceptable. However, in scenarios where training time is a critical factor, Euler or RK4 solvers can be used at the expense of model accuracy. Air-DualODE allows different types of users to select the ODE solver based on their specific requirements. That is a trade-off for different scenarios.\n\nRegarding Q6: Based on the requirements mentioned in your original Q6,\n\n> I encourage you to either: a) strengthen your experimental analysis using these suggested metrics or b) expand your limitations section to acknowledge these evaluation challenges and discuss potential alternative approaches.\n\nWe chose to follow your second approach and believe we have addressed this issue according to the suggestions in Q6. You can find the following limitation discussions about sudden changes definitions and evaluations in the main text (cf. lines 428–431).\n\n> Discussions: We use the current definition of sudden changes and the corresponding metrics to ensure consistency with existing studies (Liang et al., 2023; Hettige et al., 2024). However, they may not be the most appropriate ones. In future work, we plan to explore alternative evaluation metrics, such as mean directional accuracy (MDA) (Van den Burg & Williams, 2020), and adopt new sudden change definitions based on change point detection algorithms (Witzke et al., 2023).\n\nWe truly appreciate that you raise the score, but we would also like to understand if there are any remaining concerns preventing you from giving us an acceptance score. From your most recent reply, we feel that the only unresolved concern might be our decision not to follow your previously suggested first approach (`strengthen your experimental analysis using these suggested metrics`). We would like to confirm if this is the key reason for your score being below the acceptance bar. Although the remaining rebuttal period is limited, we are now trying our best to provide a small demonstration experiment of sudden changes’ detection algorithms and evaluation metrics. We hope to be able to offer this experiment before the end of the discussion phase."}, {"comment_id": "fvS9bZgKmP", "replyto": "Ys1pxLTDIm", "author_type": "reviewer", "reviewer": "Reviewer_6Tsm", "comment": "Thank you for your dedicated engagement in our discussion. I am pleased to recognize your efforts and proactive improvements by raising my initial `rating` by 2 points and `confidence`, and `contribution` by 1 point each. While I spent considerable time reviewing your responses and aimed to reply promptly, I appreciate your understanding if you think I am late.\n\nYour responses to Q1 and Q2 are satisfied, providing detailed experimental results and clearly distinguishing between $D$-Mean-Pooling Fusion and $N$-Mean-Pooling Fusion, along with well-explained differences between datasets.\n\nRegarding Q3, you have appropriately addressed my concerns with relevant results.\n\nWhile you note that layer normalization is a common component in various models, I believe its significance in Neural ODEs could benefit from additional interpretation [1]. Based on [1], I recommend expanding your discussion to include:\nFirst, layer normalization is well-suited for NODEs as it normalizes features rather than batches, aligning with NODEs' parameter sharing across continuous time steps. Second, it demonstrates strong compatibility with NODEs' continuous nature by avoiding batch statistics dependency. Additionally, layer normalization plays a crucial role in stabilizing the learning process by ensuring smoother dynamics of hidden representations, especially given NODEs' sensitivity to normalization technique selection.\n\n\nGiven that you've demonstrated functionality without Layer Normalization using fixed-step solvers, I suggest a more detailed discussion in your Appendix. From my experience, Dopri5 solving time in your model will likely be quite substantial. Could you report training times rather than inference times?\n\nRegarding Q6, I find it somewhat unsatisfactory that your paper remains constrained within the conventional boundaries set by previous studies such as Airformer and AirPhyNet. During this extended rebuttal period, I would be delighted to see even a small demonstration of the potential to better these existing limitations. \nAlthough the authors cannot cover this perfectly due to time constraints, I suggest this because I believe that if the authors can differentiate from existing methods in this aspect, your research value will be more distinct.\n\n> [1] Gusak, Julia, et al. \"Towards understanding normalization in neural odes.\" ICLR 2020."}, {"comment_id": "sGhHABRGpS", "replyto": "0s0Ii6y3gY", "author_type": "authors", "reviewer": null, "comment": "We truly appreciate your 49 (25+8+7+3+6) insightful and constructive suggestions and questions, which helps us improve our manuscript significantly. If you find that our manuscript, after incorporating all 49 of your valuable feedback points, meets your expectations, we would be truly grateful if you could kindly consider raising your score."}, {"comment_id": "0s0Ii6y3gY", "replyto": "Ys1pxLTDIm", "author_type": "authors", "reviewer": null, "comment": "### **Q3**: simpler solvers without layer normalization\n\nWe conduct several experiments using simpler fixed-step solvers, such as RK4 and Euler, without Layer Normalization. The results are as follows.\n\n| w/o Layer Normalization | MAE | RMSE | SMAPE |\n| ----------------------- | ----- | ----- | ----- |\n| Beijing-euler | 41.52 | 63.78 | 0.75 |\n| Beijing-rk4 | 41.32 | 63.22 | 0.75 |\n| Beijing-dopri5 | / | / | / |\n| KnowAir-euler | 18.99 | 30.78 | 0.42 |\n| KnowAir-rk4 | 18.98 | 30.45 | 0.42 |\n| KnowAir-dopri5 | / | / | / |\n\nThe results demonstrate that both RK4 and Euler can train stably without Layer Normalization. These empirical findings suggest that the dependency on Layer Normalization is related to whether using fixed-step solvers, as Dopri5 is a ODE solver using adaptive steps. The dual dynamics, with their inconsistent numerical ranges, lead to instability during adaptive step sizes in the forward and backward propagation. However, Layer Normalization ensures consistency of the numerical ranges in the dual dynamics’ adaptive solving during each iteration, thereby ensuring numerical stability.\n\n| w/ Layer Normalization | MAE | RMSE | SMAPE |\n| --------------------- | ----- | ----- | ----- |\n| Beijing-euler | 41.23 | 63.09 | 0.74 |\n| Beijing-rk4 | 40.80 | 62.90 | 0.74 |\n| Beijing-dopri5 | 40.32 | 62.04 | 0.74 |\n| KnowAir-euler | 18.92 | 30.77 | 0.42 |\n| KnowAir-rk4 | 18.94 | 30.42 | 0.42 |\n| KnowAir-dopri5 | 18.64 | 29.37 | 0.42 |\n\nThe above table provides the other simpler solvers' results (with layer normalization). The results demonstrate that w/o Layer Normalization slightly decreases model performance. Additionally, using the adaptive ODE solver Dopri5 yields the best results. To summarize, in settings where layer normalization is not available, we can go for fix-step solvers. Otherwise, we go for Dopri5 with layer normalization, which gives the best accuracy.\n\n### **Q4**: layer normalization dependency\n\nLayer Normalization is a common component used to constrain numerical ranges in deep learning architectures like Transformers. It is not a new, additional part introduced by our proposal. In addition, as indicated in the previous comment, if there exist settings where layer normalization are unavailable, we can use fix step size based solvers.\n\n\n\n### **Q5**: including other reviewer's limitations\n\nWe include limitations and discussions mentioned by other reviewers in the Appendix A.17 and A.18. However, we have sufficiently addressed all other reviewers' questions, and they all maintain positive outlook on our paper.\n\n\n\n### **Q6**: sudden changes definition and evaluation \n\nAlthough the rebuttal period has been extended, the revision submission deadline remains unchanged. From the time we received your questions and concerns to the final submission deadline, the time available was very limited (i.e., around 8 hours). Therefore, we chose your second approach. To clarify, both our sudden change definition and evaluation metrics follow previous works (e.g., Airformer and AirPhyNet) to ensure empirical setting consistency, leading to fair comparisons. We include these evaluation challenges and discuss potential alternative approaches you suggested in both the main text (cf. lines 428-431) and Appendix A.18. In future work, we will consider your suggestions regarding the definition and evaluation of sudden changes."}, {"comment_id": "kBlXI0Z0Hv", "replyto": "N0UbqLvxwi", "author_type": "authors", "reviewer": null, "comment": "Your follow-up questions and concerns are addressed as follows.\n\n### **Q1**: layers begin to show degradation\n\nWe conduct experiments with additional layers in the GNN Fusion module, and observe that when the number of layers increases to 30, degradation begins to occur. When the number of layers is fewer than 30, no significant degradation is observed. This is because we utilized residual connections between GNN layers. Previous studies have shown that residual connections enable deeper GNN architectures by mitigating issues such as over-smoothing [1, 2].\n\n| KnowAir-n | MAE | RMSE | SMAPE |\n| --------- | ----- | ----- | ----- |\n| 1 | 18.94 | 29.71 | 0.42 |\n| 3 | 18.64 | 29.37 | 0.42 |\n| 5 | 18.78 | 29.76 | 0.42 |\n| 10 | 18.75 | 29.74 | 0.42 |\n| 15 | 18.88 | 29.98 | 0.42 |\n| 20 | 19.04 | 30.05 | 0.43 |\n| 30 | 19.27 | 30.12 | 0.43 |\n| 40 | 24.81 | 39.80 | 0.55 |\n| 50 | 25.38 | 71.36 | 0.55 |\n\n> [1]. Li G, Muller M, Thabet A, et al. Deepgcns: Can gcns go as deep as cnns?[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2019: 9267-9276.\n>\n> [2]. Li G, Müller M, Ghanem B, et al. Training graph neural networks with 1000 layers[C]//International conference on machine learning. PMLR, 2021: 6437-6449.\n\n\n### **Q2**: clarification of mean pooling fusion\n\nTo clarify, this process does not incorporate representations from other nodes (cf. lines 1030–1031). Specifically, the dimension of $Z^{(0)}_t$ in GNN Fusion (Section 3.4 Dynamics Fusion) is $\\tau \\times N \\times D$, where $\\tau$ represents the prediction length, $N$ denotes the number of stations, and $D$ means the dimensions of the latent representations of each station after concatenation. Considering that the prediction results have the shape of $\\tau \\times N \\times 1$, we apply mean pooling fusion along the feature dimension $D$ ($D$-Mean-Pooling Fusion). Therefore, it does not contradict our explanation that distant node representations act as noise.\n\nTo avoid confusion, we also include experiments on mean pooling fusion along the station dimension $N$ of node representations ($N$-Mean-Pooling Fusion. This process incorporates representations from some other nodes. Specifically, these nodes' representations are aggregated based on the order of stations in the representation instead of distances. The results are as follows.\n\n| Model | Beijing | | | KnowAir | | |\n| ------------------------------------ | ------- | ----- | ----- | ------- | ----- | ----- |\n| | MAE | RMSE | SMAPE | MAE | RMSE | SMAPE |\n| MLP Fusion (one layer) | 41.49 | 63.14 | 0.76 | 21.50 | 32.92 | 0.56 |\n| Equivalent MLP Fusion (three layers) | 41.45 | 63.23 | 0.75 | 21.42 | 32.83 | 0.56 |\n| $D$-Mean-Pooling Fusion | 43.23 | 68.88 | 0.80 | 20.03 | 31.22 | 0.45 |\n| $N$-Mean-Pooling Fusion | 41.71 | 64.12 | 0.76 | 20.89 | 32.09 | 0.51 |\n| Air-DualODE | 40.32 | 62.04 | 0.74 | 18.64 | 29.37 | 0.42 |\n\nFor the Beijing dataset, stations are very close to each other, so incorporating features from other nodes helps improve the final prediction results ( see MLP-based Fusion and $N$-Mean-Pooling Fusion). Among these, the experiments show that GNN Fusion performs the best. Conversely, $D$-Mean-Pooling Fusion, which does not incorporate features from other nodes, performs the worst.\n\nFor the KnowAir dataset, stations are relatively far from the target station. The significant contrast between MLP-based Fusion and $D$-Mean-Pooling Fusion suggests that incorporating representations from distant nodes introduces noise. Additionally, $N$-Mean-Pooling Fusion also performs poorly because it does not necessarily incorporate representations from adjacent nodes. \n\nIn summary, these results do not contradict our view."}, {"comment_id": "N0UbqLvxwi", "replyto": "Ys1pxLTDIm", "author_type": "reviewer", "reviewer": "Reviewer_6Tsm", "comment": "Thank you for your thorough review of the revised paper. I have several follow-up questions and concerns regarding your modifications:\n\n---\n\nI appreciate the analysis in `Section A.16`, which helps justify the use of GNN. However, there seems to be a contradiction in your explanations:\n\n**Q1.** For KnowAir, `Figure 10` shows minimal change in error metrics as the number of layers increases, which appears inconsistent with your explanation about performance degradation due to incorporating distant node representations (`lines 1046-1048`). Could you clarify at which point the layers begin to show degradation in `Figure 10` via an additional sensitivity study?\n\n**Q2.** Additionally, mean pooling incorporates representations from all nodes, including distant ones, yet performs better than MLP. This seems to contradict your explanation about distant node representations acting as noise. Could you reconcile these apparent contradictions?\n\n\n---\n\nWhile your explanation about layer normalization dependency is reasonable when using Dopri5, I have two questions:\n\n**Q3.** Have you tested whether NaN values occur when using simpler solvers like RK4 or Euler method without layer normalization?\n\n**Q4.** How can you claim that layer normalization dependency doesn't affect the model's practicality and scalability, especially when the model fails without it? This seems to contradict your statement that \"these limitations do not affect these contributions.\" Could you refine and expand on this?\n\n**Q5.** While `A.17` is a good addition, I suggest expanding it to include limitations identified via discussions with other reviewers, discussing broader challenges in open system modeling and hybrid DualODE framework design, and providing more concrete future research directions for addressing these limitations.\n\n---\n\n**Q6.** Your current evaluation metrics (MAE, RMSE, MAPE) may not be the most appropriate for evaluating sudden change predictions. I suggest below aspects:\nThe authors need to consider including [mean directional accuracy (MDA)](https://en.wikipedia.org/wiki/Mean_directional_accuracy) to evaluate prediction direction (up/down) accuracy [2,3] and consider change point detection evaluation metrics [1]. Either the authors supplement your experiments with these metrics or acknowledge this limitation. The authors must add this discussion in the main text about the challenges of evaluating sudden change predictions.\n\nGiven the extended rebuttal period, I encourage you to either:\na) strengthen your experimental analysis using these suggested metrics or\nb) expand your limitations section to acknowledge these evaluation challenges and discuss potential alternative approaches.\n\n---\n\n> [1] Van den Burg, Gerrit JJ, and Christopher KI Williams. \"An evaluation of change point detection algorithms.\" arXiv preprint arXiv:2003.06222 (2020).\n>\n> [2] Witzke, Simon, et al. \"Mobility data improve forecasting of COVID-19 incidence trends using graph neural networks.\" epiDAMIK 6.0: The 6th International workshop on Epidemiology meets Data Mining and Knowledge Discovery at KDD 2023. 2023.\n>\n> [3] Blaskowitz, Oliver, and Helmut Herwartz. \"On economic evaluation of directional forecasts.\" International journal of forecasting 27.4 (2011): 1058-1065."}, {"comment_id": "mLrcUHZSYD", "replyto": "MiTq8hpQrR", "author_type": "authors", "reviewer": null, "comment": "Many thanks for your constructive suggestions to improve our manuscript. We have now revised our paper in the updated version. The summary of changes is as follows:\n\n1. We revise the explanation of GNN Fusion (cf. lines 351-355) and add an effective analysis of GNN Fusion in the $\\underline{\\text{Appendix A.16 Ablation study on GNN Fusion}}$.\n2. We update the content of $\\underline{\\text{Appendix A.15 Visualization of sudden changes' results}}$ to provide a more balanced analysis of Figure 12.\n3. We include $\\underline{\\text{Appendix A.17 Limitations Discussion}}$ to discuss Air-DualODE’s current limitations more clearly, and future work is included in $\\underline{\\text{Section 5: Conclusion and Future Work}}$.\n\nDespite these limitations, we believe that we have made considerable contributions on modeling open systems and designing a hybrid DualODE framework, and these limitations do not affect these contributions."}, {"comment_id": "yjcSEhqGAi", "replyto": "7abDnq6cPa", "author_type": "authors", "reviewer": null, "comment": "Thank you sincerely for your thoughtful feedback and for recognizing our work. Best wishes!"}, {"comment_id": "7abDnq6cPa", "replyto": "QCHfm2Hr32", "author_type": "reviewer", "reviewer": "Reviewer_joi5", "comment": "Thank you for your detailed response and for addressing the concerns raised in my initial review. I keep my current Accept score."}, {"comment_id": "1bLyJVXDYs", "replyto": "ZgDWEhOOWJ", "author_type": "authors", "reviewer": null, "comment": "Thank you sincerely for your thoughtful feedback and for recognizing our work. Best wishes!"}, {"comment_id": "ZgDWEhOOWJ", "replyto": "bY2ixFbWlq", "author_type": "reviewer", "reviewer": "Reviewer_iuiU", "comment": "Thank you for your clarifications to my questions/comments. I updated my rating to a 6 and hold a positive outlook on the paper."}, {"comment_id": "MiTq8hpQrR", "replyto": "Ys1pxLTDIm", "author_type": "reviewer", "reviewer": "Reviewer_6Tsm", "comment": "Thank you for your detailed responses. I have several suggestions that I believe would strengthen the paper. \n\nFirst, I notice that both Sum-Pooling and Mean-Pooling actually outperform MLP Fusion on the KnowAir dataset. Could you update the paper to provide a more nuanced justification for the \"necessity of using GNN\" through a comprehensive comparison of these results? This would make your argument more convincing.\n\nRegarding layer normalization and stability, I suggest addressing the dependency on layer normalization as a limitation section. Specifically, the fact that the model cannot work without layer normalization due to NaN values is an important limitation that should be explicitly discussed.\n\nWhile I understand that sudden change prediction is not the primary claim of your paper, I believe the current presentation could be improved. Instead of claiming to \"effectively handle\" sudden changes, I suggest providing a more balanced analysis of Figure 12, explicitly discussing both successful and unsuccessful predictions. This could include acknowledging the limitations in capturing sudden changes and reframing this analysis as an exploration of the model's capabilities rather than a definitive solution.\n\nI recommend expanding the limitations and future work sections to include a clear discussion of the model's current limitations in handling sudden changes, the dependency on layer normalization and its implications, and potential future research directions for addressing these limitations.\n\nBased on how these suggestions are incorporated into the paper, I would be willing to reconsider my evaluation."}, {"comment_id": "QCHfm2Hr32", "replyto": "ui3DPmcmG0", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer joi5,\n\nThank you for your valuable and constructive reviews, which has inspired further improvements to our paper. We have made an extensive effort to address your questions and concerns by providing additional results for 1-day and 2-day forecasts, revising our paper and complementing the details of the equations in the Appendix. We hope our response can effectively address your concerns, If you have any further concerns or questions, please do not hesitate to let us know, and we will respond timely.\n\nBest regards,\n\nAuthors"}, {"comment_id": "hg5y7hh8B6", "replyto": "2tXplAJRcB", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer NWR1,\n\nThank you for your valuable and constructive reviews, which has inspired further improvements to our paper. We have made an extensive effort to try to successfully address your concerns, by conducting experiments on sparse data. We hope our response can effectively address your concerns, If you have any further concerns or questions, please do not hesitate to let us know, and we will respond timely.\n\nBest regards,\n\nAuthors"}, {"comment_id": "bY2ixFbWlq", "replyto": "lQ2v56zcFV", "author_type": "authors", "reviewer": null, "comment": "## Scalability\n\n### **Q1**: complexity analysis\nWe provide a formal complexity analysis for the major components of Air-DualODE. In physical branch, the complexity is $O(I_P \\cdot |E|)$. In the data-driven branch with Spatial-MSA, the complexity is $O(I_D \\cdot M^2 \\cdot C)$. Here, $I_P$ and $I_D$ refer to the iteration time for the physics and data-driven dynamics, respectively. $|E|$ is the number of edge in the $G$. $C$ represents the dimension of latent variables. $M$ represents number of neighboring stations around the current station. When scaling with an increasing number of nodes, the complexity growth primarily depends on the number of neighboring nodes for each node rather than the total number of nodes.\n\n## Interpretability\n\n### **Q2**: computation of $\\boldsymbol{\\beta}$ and its physical meaning\n\nThe $\\boldsymbol{\\beta}$ is estimated according to the historical data: $T \\times N \\times D$, where $T$ means the historical timestamps, $N$ means the number of stations and $D$ means the number of input variables. The mapping of $\\boldsymbol{\\beta}$ can be defined as $\\mathbb{R}^{T \\times N \\times D} \\to \\mathbb{R}^{N \\times 1}$ using the RNN-based coefficient estimator, which assigns a specific $\\beta_i \\in \\mathbb{R}$ to each station $i$.\n\nRegarding the physical meaning: $\\beta_i > 0$ indicates that station $i$ is likely influenced by a source nearby like industrial activity and vehicular emissions , while $\\beta_i < 0$ suggests that station $i$ is likely influenced by a sink like forest absorption.\n\n### **Q3**: Beyond identifying sources and sinks, how can the model’s outputs inform specific policies or interventions? How are these outputs communicated to stakeholders in practice?\n\nBecause each time window estimates $\\boldsymbol{\\beta}$, if the estimation consistently results in $\\beta_i < 0$ over a continuous period, it indicates the presence of a sink near station $i$ that absorbs pollutants during that time. This provides stakeholders with insights or suggestions on how to create sinks to absorb pollutants. Conversely, if the estimation consistently results in $\\beta_i > 0$ over a continuous period, it signifies the presence of a source near station $i$ generating pollutants. In such cases, policymakers may need to implement regulations to reduce $\\beta_i$, mitigating pollutant generation.\n\n\n## GNN Fusion and Decay-TCL\n\n### **Q4**: alternative fusion approaches\n\nYes, we conducted experiments with alternative fusion approaches, see the Table below. The results show that other fusion approaches are not as effective as our GNN Fusion.\n\n| Model | Beijing | | | KnowAir | | |\n| ------------------------------------ | ------- | ----- | ----- | ------- | ------ | ----- |\n| | MAE | RMSE | SMAPE | MAE | RMSE | SMAPE |\n| MLP Fusion (one layer) | 41.49 | 63.14 | 0.76 | 21.5 | 32.919 | 0.56 |\n| Equivalent MLP Fusion (three layers) | 41.45 | 63.23 | 0.75 | 21.42 | 32.83 | 0.56 |\n| Mean-Pooling Fusion | 43.23 | 68.88 | 0.80 | 20.03 | 31.22 | 0.45 |\n| Sum-Pooling Fusion | 43.12 | 68.30 | 0.80 | 19.97 | 31.10 | 0.44 |\n| Air-DualODE | 40.32 | 62.04 | 0.74 | 18.64 | 29.37 | 0.42 |\n\n**Note:** MLP fusion introduces representations from distant stations, which act as noise for the current station. As a result, performance degradation occurs when GNN fusion is replaced with MLP fusion.\n\n### **Q5**: generalization to traffic forecasting or climate modeling\n\nTo clarify, Decay-TCL and the GNN fusion mechanism are specifically designed for Air-DualODE due to the physical and data-driven dynamics. The BA-DAE in the physical branch may not align with traffic forecasting or climate modeling because their governing equations are different. Therefore, Air-DualODE is not directly applicable to these tasks. However, by revising the Air-DualODE's physical branch with domain specific equations (e.g., Navier-Stokes equations for climate), we believe it has the potential to generalize to those domains."}, {"comment_id": "VkirI50SzW", "replyto": "vklpMGppyx", "author_type": "authors", "reviewer": null, "comment": "We would like to clarify that the paper's contribution mainly lie in open system modeling and the DualODE design to model both physical and data-driven dynamics, which is acknowledged by you in the original review. In addition, we hope that you appreciate our efforts on elaborating the technical details and that you are satisfied with our responses. If so, we kindly ask you to reconsider your score."}, {"comment_id": "vklpMGppyx", "replyto": "Ys1pxLTDIm", "author_type": "authors", "reviewer": null, "comment": "## Questions about model stability and implementation\n\n### **Q1.** exploration on alternative tolerance\n\nWe experimented with alternative tolerance, such as $10^{-2}$ and $10^{-5}$, but the issue of `NaN` still occurs. We believe that layer normalization plays an important role in stabilizing the training process.\n\n### **Q2.** The authors claim that Decay-TCL \"effectively reduces\" NFE. However, the numerical results show relatively small differences. Could you elaborate on what constitutes an \"effective reduction\" in this context?\n\nSorry for the confusion. Compared to `Cross-Space Fusion` as shown in the table w.r.t Q19 in the original review, `Air-DualODE` (which uses Decay-TCL) reduces both the three error metrics and the NFE, which is highly desirable. Thus, we think that Decay-TCL is effective.\n\nTo avoid confusion, we rephrase that sentence as follows: Therefore, we use Decay-TCL to align the representations in the same space, which reduces both the errors and the NFE, suggesting that it is effective.\n\n\n### **Q3, Q4:** comparison with equivalent numbers of MLP layers, more metrics in Fig.10 and more ablation studies about GNN Fusion\n\n**comparison with equivalent numbers of MLP layers**\n\nWe provide the experimental results for the equivalent MLP layers. The results are as follows:\n\n| Model | Beijing | | | KnowAir | | |\n| ------------------------------------ | ------- | ----- | ----- | ------- | ------ | ----- |\n| | MAE | RMSE | SMAPE | MAE | RMSE | SMAPE |\n| MLP Fusion (one layer) | 41.49 | 63.14 | 0.76 | 21.50 | 32.919 | 0.56 |\n| Equivalent MLP Fusion (three layers) | 41.45 | 63.23 | 0.75 | 21.42 | 32.83 | 0.56 |\n| Mean-Pooling Fusion | 43.23 | 68.88 | 0.80 | 20.03 | 31.22 | 0.45 |\n| Sum-Pooling Fusion | 43.12 | 68.30 | 0.80 | 19.97 | 31.10 | 0.44 |\n| Air-DualODE | 40.32 | 62.04 | 0.74 | 18.64 | 29.37 | 0.42 |\n\n**more metrics in Figure 10**\n\nFollowing your suggestions, we now include the RMSE and SMAPE metrics to prove the Air-DualODE's robustness in Fig.10.\n\n**more ablation study about GNN Fusion**\n\nWe also provide the results of mean pooling fusion and sum pooling fusion in the above table. The results suggest that Pooling Fusion performs worse than MLP Fusion, and GNN Fusion. Therefore, this justify the necessity of using GNN.\n\n\n## Questions about sudden changes prediction\n\n### **Q5, Q6, Q7**: sudden changes\n\nTo clarify, we include the experimental setting of sudden changes is mainly to ensure empirical fairness and consistency with the two SOTA methods --- both AirPhyNet and Airformer have this setting in their experiments. In our proposal, we do not have a specific design goal to capture sudden changes, neither does AirPhyNet nor Airformer. From both the aggregated results in Table1 in the original paper and the step-wise metrics shown in Table5 to respond your question Q4 in the second round review, it all demonstrates that Air-DualODE outperforms AirPhyNet and Airformer under sudden changes, though we cannot guarantee that our proposal can well capture every sudden change. We leave specific designs on better handling sudden changes as a promising future research direction."}, {"comment_id": "TsPRbeSX2Y", "replyto": "2CPGonR3Lv", "author_type": "reviewer", "reviewer": "Reviewer_iuiU", "comment": "Thank you for your detailed and thorough elaborations on my concerns. I have a few follow-up questions:\n\n**Scalability:**\n1. While you emphasize optimizations like Spatial-MSA, the explanation does not include a formal complexity analysis of how the computational cost scales with the number of nodes. Could you provide a formal complexity analysis (e.g., \\( O(n) \\) or \\( O(n^2) \\)) for the major components of Air-DualODE, particularly the ODE solvers and Spatial-MSA? How do these scale with increasing numbers of nodes?\n\n**Interpretability:**\n1. Could you elaborate on how $\\beta$ is computed and provide a clearer explanation of its physical meaning? \n2. Beyond identifying sources and sinks, how can the model’s outputs inform specific policies or interventions? How are these outputs communicated to stakeholders in practice?\n\n**GNN Fusion and Decay-TCL:**\n1. Have you conducted experiments with alternative fusion approaches, such as weighted averaging or attention mechanisms, in addition to GNN fusion? \n2. Can Decay-TCL and the GNN fusion mechanism be generalized to other spatiotemporal tasks, such as traffic forecasting or climate modeling? If so, what modifications, if any, would be required to adapt them to different domains?"}, {"comment_id": "RhXFS7JH2Z", "replyto": "ls1SBYJQUs", "author_type": "reviewer", "reviewer": "Reviewer_v3np", "comment": "Thanks the authors for the efforts to address my comments. \nMy questions have been clarified and I remain my positive view on this paper. I keep my current Accept score."}, {"comment_id": "v9kaJwC4Yi", "replyto": "Ys1pxLTDIm", "author_type": "reviewer", "reviewer": "Reviewer_6Tsm", "comment": "Thank you for your response. I have remaining questions for clarification.\n\n---\n\n**Questions about model stability and implementation**\n\n**Q1.** Regarding the observed behavior when removing layer normalization, I'm curious about the distinction between training instability and numerical stability. Additionally, could the NaN occurrences be related to the Dopri5's relative and absolute tolerance settings to 1e-3? Have alternative tolerance values been explored?\n\n**Q2.** The authors claim that Decay-TCL \"effectively reduces\" NFE. However, the numerical results show relatively small differences. Could you elaborate on what constitutes an \"effective reduction\" in this context?\n\n**Q3.** About the ablation study of GNN Fusion, would it be possible to provide experimental results comparing equivalent numbers of GNN layers and mapping (i.e., MLP) layers? Additionally, could RMSE and SMAPE metrics be included alongside the MAE values reported in Figure 10?\n\n**Q4.** For the mapping by MLP from $\\mathbb{R}^{N\\times 2D} \\rightarrow \\mathbb{R}^{N\\times 1}$, have alternative approaches such as mean pooling or sum pooling been considered alongside the MLP implementation? Could the authors further justify the necessity of GNN by comparing the results using these pooling methods?\n\n**Questions about sudden changes prediction**\n\n*The visualization of sudden changes prediction raises several questions about the model's capabilities:*\n\n**Q5**. In Figure 12, there appear to be challenges in predicting significant changes: Figures 12(a) and 12(b) show limitations in capturing sharp increases during time steps 20-30, while (c) and (d) show difficulties in representing post-decrease volatility. Could you clarify how these results demonstrate effective handling of trend changes?\n\n**Q6.** The predictions seem to follow smoothed trends rather than capturing sudden changes. Could you explain more about how this aligns with the goal of sudden change prediction?\n\n**Q7.** Would it be possible to provide more comprehensive quantitative metrics for sudden change prediction?"}, {"comment_id": "TufDtH0Zzm", "replyto": "Ys1pxLTDIm", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer 6Tsm,\n\nWe hope that we have clarified the new questions you have raised. If we have satisfactorily addressed all your questions, we kindly ask you to reconsider your score.\n\nThanks for your valuable suggestions and comments to improve our work !\n\nBest Regards,\n\nAuthors"}, {"comment_id": "yRFpd82B37", "replyto": "Ys1pxLTDIm", "author_type": "authors", "reviewer": null, "comment": "We are glad that our responses resolve most of your questions and concerns from the original review. Your new questions are addressed as follows.\n\n\n### **Q1**: I'm not fully convinced that training loss convergence alone can completely prove numerical stability.\n\nWe have not observed any numerical instability issues in our experiments. If you have more specific suggestions on investigating numerical stability, we are happy to provide more information according to your suggestions.\n\n### **Q2**: relationship between layer normalization and numerical stability\n\nWhen removing all layer normalization from Air-DualODE, gradient explosion and `NAN` occurred after a few epochs, preventing us to have a stable model.\n\n\n### **Q3**: sudden change visualization\n\nWe provide a visualization of the results for sudden changes in Appendix $\\underline{\\text{A.15 Visualization of sudden changes' results}}$ (Fig.12). The visualizations demonstrate that Air-DualODE can effectively predict upward and downward trends during periods of sudden change.\n\n### **Q4**: each step error comparison\n\nConsidering that a comparison of errors for all 24 steps would be overwhelming, we have provided the error comparison for the 1st, 3rd, 5th, …, 21st, and 23rd steps in Appendix $\\underline{\\text{A.10 More experiment results}}$ (Table 5) of the updated version.\n\n\n### **Q5**: How does NFE change in each branch of Air-DualODE's dual branch structure?\n\nNFE changes in each branch depend on their structures because, similar to solving ODEs, different ODEs require different iteration steps (NFE) to achieve numerical precision. From the table in Q19, Physics dynamics' NFE is larger than Data-Driven dynamics' NFE.\n\n### **Q6**: You claimed that Decay-TCL \"effectively reduces NFE\" -- could you explain more about what you mean by \"effectively reduces NFE\" in your result?\n\n\"Effectively reduces NFE\" refers to the advantage of `Air-DualODE` over `Cross-Space Fusion`. In `Cross-Space Fusion`, the mismatch between the explicit physical equation and the latent data-driven representations leads to large NFE. By placing both of them into the latent space and aligning them using Decay-TCL, `Air-DualODE` achieves smaller NFE, thus effectively reducing NFE compared to `Cross-Space Fusion`.\n\n### **Q7**: removing Spatial-MSA slightly decreases NFE\n\nBecause Spatial-MSA prevents the Data-Driven branch from focusing on distant stations, the adaptive ODE solver needs to extract more useful representations from fewer stations within the Data-Driven branch. As a result, Spatial-MSA slightly increases the NFE of Air-DualODE.\n\n\n### **Q8**: notation confusion and ablation study of GNN Fusion\n\n**notation confusion**\n\nWe revise this notation confusion in the updated version.\n\n**ablation study of GNN Fusion**\n\nWe could provide the results for both datasets when no GNN layers are used. To enable fusion across stations, we replaced the GNN Fusion with an MLP structure, which maps $\\mathbb{R}^{N \\times 2D} \\to \\mathbb{R}^{N \\times 1}$ for pollutant prediction. The results are as follows.\n\n| Model | Beijing | | | KnowAir | | |\n| -------------- | ------- | ----- | ----- | ------- | ------ | ----- |\n| | MAE | RMSE | SMAPE | MAE | RMSE | SMAPE |\n| w/o GNN Fusion | 41.49 | 63.14 | 0.76 | 21.5 | 32.919 | 0.56 |\n| Air-DualODE | 40.32 | 62.04 | 0.74 | 18.64 | 29.37 | 0.42 |\n\nIt can be observed that w/o GNN Fusion results in a performance drop compared to Air-DualODE on both datasets. This drop is particularly significant on the larger, nationwide KnowAir dataset. Due to the large distances between some stations, using an MLP structure for fusion introduces representations from distant stations, which act as noise for predicting pollutant concentrations at the current station. Consequently, the degradation is more obvious on larger KnowAir dataset."}, {"comment_id": "alSxDC83qG", "replyto": "Ys1pxLTDIm", "author_type": "reviewer", "reviewer": "Reviewer_6Tsm", "comment": "Thank you for your responses. While most of your answers address my questions and concerns, I still have remaining questions:\n\n**Regarding reviewer's response to \"W2 & Q11\" and the newly added Fig. 8**\n1. I'm not fully convinced that training loss convergence alone can completely prove numerical stability.\n2. You mentioned that Layer Normalization contributes to numerical stability -- could you provide additional ablation studies on Layer Normalization to support this claim?\n\n**Regarding sudden changes prediction**\n\n3. How accurate are your predictions when visualized for sudden changes? Specifically, if there's an upward trend in a sudden change period, does your model correctly predict this upward movement?\n4. Since you're forecasting 24 steps, could you provide each step error comparison between your model and representative baselines? Currently, Table 4 only shows errors by period.\n\n**Regarding NFE analysis**\n\n5. How does NFE change in each branch of Air-DualODE's dual branch structure?\n6. You claimed that Decay-TCL \"effectively reduces NFE\" -- could you explain more about what you mean by \"effectively reduces NFE\" in your result?\n7. How do you explain that removing Spatial-MSA slightly decreases NFE?\n\n**Regarding GNN implementation**\n\n8. I've reviewed your GNN layer sensitivity study. There appears to be notation confusion between '$n$' in the main paper and '$n$' in Appendix A.6. Your sensitivity analysis shows little variation from 1 to 5 layers. Could you show results for both datasets when using no GNN layer ($n=0$), i.e., using just $Z_t^{(0)}$ as mentioned in line 354? *I will carefully review how this response aligns with and makes sense to the authors' claim to reviewer `v3np` that \"GNN Fusion is a natural choice\".*"}, {"comment_id": "2CPGonR3Lv", "replyto": "lQ2v56zcFV", "author_type": "authors", "reviewer": null, "comment": "### **Q2**: theory and intuition behind Decay-TCL mechanism and the hyperparameters' influence to the alignment and fusion process\n\n**theory and intuition behind Decay-TCL mechanism**\n\nDecay-TCL is derived from temporal contrastive learning techniques. In Section 3.4, $\\lambda_1$ controls the soft weight distribution for intra-dynamics latent representations, while $\\lambda_2$ controls the soft weight distribution for inter-dynamics latent representations. These hyperparameters determine the sharpness of $w(t, s)$. The similarity of intra-dynamics latent representations should decay along the time axis, while the decay of inter-dynamics latent representations should account for the differences between two types of dynamics, which satisfies the rule1 and rule2 (rf. line323-324). It is reasonable to enforce $\\lambda_1 > \\lambda_2$, as intra-dynamics latent representations should exhibit greater similarity than inter-dynamics, satisfying Rule 3 (ref. line 325). By adhering to these rules, the soft InfoNCE $\\ell(t)$ in Eq.11 adjusts the similarity of the two dynamics’ latent representations effectively.\n\n**the hyperparameters' influence to the alignment and fusion process**\n\nIf both $\\lambda$ are very large, the distribution of $w(t, s)$ becomes sharper, significantly reducing the number of soft positive samples. This forces the model to focus more on learning the differences between the dynamics’ latent representations. However, dynamics inherently exhibit continuity, and emphasizing large differences can disrupt this continuity. This negatively impacts ODE solver-based learning and breaks the alignment between latent representations.\n\nConversely, if both $\\lambda$ are very small, the distribution of $w(t, s)$ becomes smoother, reducing the number of negative samples. This causes the model to focus more on the commonalities between the dynamics’ latent representations. However, since physics-based dynamics already capture spatiotemporal dependencies, this approach may prevent data-driven dynamics from learning additional spatiotemporal dependencies in the latent space, rendering the subsequent fusion process ineffective.\n\nTherefore, considering these impacts and the constraint $\\lambda_1 > \\lambda_2$, we conducted a grid search (restricted to the upper or lower triangular region) rather than relying on experience, ultimately obtaining $\\lambda_1 = 1$ and $\\lambda_2 = 0.8$.\n\n\n\n### **Q3**: how the GNN fusion mechanism effectively balances the two components?\n\nFirst, before applying the GNN fusion, the Decay-TCL are used to map the representations into a shared latent space as much as possible, aligning physics and data-driven dynamics latent representations. Second, the GNN layers adjust the influence of individual nodes based on their spatial connectivity, allowing the model to balance information across geographically connected stations."}, {"comment_id": "RQHMimYzD8", "replyto": "lQ2v56zcFV", "author_type": "authors", "reviewer": null, "comment": "### **W3, Q5**: the prediction's interpretability for policy maker\n\nOur model can provide stakeholders with explanations to some extent. In Section 4.4, the second case study presents a visualization of $\\boldsymbol{\\beta}$ under two scenarios in Beijing. In our view, $\\boldsymbol{\\beta}$ can serve as an important reference indicator for stakeholders: $\\beta_i < 0$ indicates a sink (e.g., out-of-boundary diffusion/advection or forest absorption), while $\\beta_i > 0$ indicates a source (e.g., industrial activity or vehicular emissions). This provides stakeholders with insights into how the model makes predictions to some degree.\n\n\n\n### **W4**: Air-DualODE Compared with CMAQ, WRF and AERMOD\n\nWe appreciate your observation regarding the absence of a direct comparison with physics-based models such as CMAQ (Community Multiscale Air Quality) and WRF (Weather Research and Forecasting). These models require numerous input variables, including land use data, boundary conditions, meteorological parameters (e.g., boundary layer height) which are not available for our dataset. Thus, we are unable to run such physics based models for comparisons.\n\n\n### **Q1**: mathematical intuition on $\\boldsymbol{\\beta} X$ and significance within the framework\n\n**mathematical intuition on $\\boldsymbol{\\beta} X$**\n\nThis term is motivated from the Infectious differential equation in SIR model.\n$$\n\\begin{align*}\n\\frac{dI}{dt} &= \\eta \\frac{SI}{N} - \\gamma I \\\\\\\\\n\\frac{dR}{dt} &= \\gamma I\n\\end{align*}\n$$\n\n* **Susceptible (S), Infectious (I), Recovered (R)** \n* $N = S + I + R$ is the total population.\n* $\\eta$ is the transmission rate and $\\gamma$ is the recovery rate.\n\nIn these differential equations, the growth rate of recovery is modeled as a proportion to infected population. However, in reality, their relationship exhibits complex nonlinear dependencies. Motivated by this design and the principle of Occam’s razor, we also use $\\boldsymbol{\\beta} X$ to model sources or sinks dynamics, which are crucial in an open air system.\n\n**significance within the framework**\n\nCombined with Eq.5 in Section 3.2.1, it can be observed that BA-DAE no longer satisfies mass conservation, which is more realistic for an open air system. As mentioned in Section 3.2.2, the dissipation of pollutants cannot exceed the current pollutant concentration; therefore, $\\boldsymbol{\\beta}$ has a minimum value of -1, i.e., all pollutants are dissipated. However, the generation of pollutants may far exceed the current concentration. Thus, $\\boldsymbol{\\beta}$ can theoretically have a maximum value of $+\\infty$. Besides, we perform an ablation study to evaluate the significance of $\\boldsymbol{\\beta} X$ in open air system. In Fig.5 DAE-c without $\\boldsymbol{\\beta} X$ is for closed system modeling, while BA-DAE-o with $\\boldsymbol{\\beta} X$ consider the open air system. The results highlights the significance of $\\boldsymbol{\\beta} X$ in our framework."}, {"comment_id": "VdQkYO5Dz1", "replyto": "lQ2v56zcFV", "author_type": "authors", "reviewer": null, "comment": "Thank you for the insightful review. We appreciate the acknowledgment of our novel architecture and strong results. In response to your feedback, the noted concerns are addressed below.\n\n\n### **W1, Q4**: the scalability of Air-DualODE and concerns about the computational cost\n\nThat is a good question! This is precisely why we conducted experiments on the KnowAir dataset, which represents a national-level air quality prediction scenario. From our perspective, the computational cost of Air-DualODE should remain manageable even when applied to larger datasets. Unfortunately, there does not exist publicly available datasets with larger sizes. While Airformer utilizes a nationwide dataset with 1,085 stations, this dataset is not publicly accessible online. To address this, we generate a synthetic dataset with 1085 nodes to evaluate the method’s scalability on large-scale datasets.\n\nWe selected AirPhyNet as a competitor because both AirPhyNet and Air-DualODE belong to the family of physics-guided machine learning approaches. Here is their training and inference time costs (one epoch): \n\n* Training Time\n\n| | AirPhyNet | Air-DualODE |\n| --------------- | ----------------------- | ----------- |\n| Beijing(35) | 90s | 81s |\n| KnowAir(184) | 3870s | 378s |\n| Synthetic(1085) | larger than three hours | 1235s |\n\n* Inference Time\n\n| | AirPhyNet | Air-DualODE |\n| --------------- | ---------------------- | ----------- |\n| Beijing(35) | 2.75s | 1.98s |\n| KnowAir(184) | 540s | 22.5s |\n| Synthetic(1085) | larger than 30 minutes | 54s |\n\nFor fair comparison, these experiments are conducted on **a single** NVIDIA A800 Tensor Core GPU. \n\nAirPhyNet involves many MLP structures (e.g., Gated Fusion) tied to the number of stations during the solving process. The gradients of these parameters are tracked and recorded in the GPU during training, leading to an increase in the computational cost of the ODE solver as the number of stations grows. In contrast, Air-DualODE incorporates several optimizations, such as Spatial-MSA, which effectively reduce the overall computational complexity. Therefore, despite utilizing two ODE solvers, the overall computational cost of Air-DualODE remains manageable.\n\n### **W2, Q6**: the difference between AirPhyNet and Air-DualODE\n\nThe main differences are as follows:\n\n| Difference | AirPhyNet | Air-DualODE |\n| ------------------------------------------------ | ------------------------------------------ | ----------------------------------------------------------- |\n| **Computational efficiency** | Slow | Fast |\n| **Air system modeling** | Closed system | Open system |\n| **Physics-guided approaches** | Explicit physical equation on latent space | Explicit physical equation on explicit space and GNN fusion |\n| **Matching physics with latent representations** | No mechanism to ensure | Decay-TCL |\n| **Spatial module** | GNN-based differential equation | NODE with Spatial-MSA and BA-DAE |\n| **Temporal module** | GNN-based differential equation | NODE and BA-DAE |\n\n1. As discussed in W1, Q4, Air-DualODE outperforms AirPhyNet in terms of computational efficiency and scalability to a larger number of nodes.\n2. Air-DualODE models open air systems, making it more realistic compared to the closed system assumption of AirPhyNet.\n3. Air-DualODE applies physical equations directly to explicit variables, which is more reasonable than applying them to latent variables that no longer hold actual physical meanings.\n4. To address the mismatch between explicit physical equations and data-driven latent representations, Air-DualODE employs Decay-TCL for alignment, whereas AirPhyNet overlooks this issue.\n5. AirPhyNet relies only on GNN-based differential equations to model spatiotemporal dependencies, whereas Air-DualODE incorporates multiple components, including BA-DAE for modeling in the explicit space and NODE with Spatial-MSA for capturing additional dependencies in the latent space."}, {"comment_id": "lKFIDQpp98", "replyto": "ui3DPmcmG0", "author_type": "authors", "reviewer": null, "comment": "We would like to sincerely thank Reviewer joi5 for acknowledging our work's novelty and contributions. Your curiosity about Data-Driven Neural ODE and our experiment results are resolved in following contents. Some incomplete contents and typos have been revised in the updated version.\n\n\n### **W1, Q3**: details about Data-Driven Neural ODE and other potential methods to replace\n\n**Details about Data-Driven Neural ODE**\n\n$F^D$ is a data-driven branch's ODE function with Spatial-MSA Structure. We have provided the details of $F^D$ in the Appendix $\\underline{\\text{A.14 The details of $\\mathbf{F^D}$}}$. The following contents are about $\\mathbf{F^D}$ formula.\n$$\n\\begin{align*}\n Q, K, V &= \\text{Projection}(Z^D), \\\\\\\\\n Z^D &= Z^D + \\text{Spatial-MSA}(Q, K, V), \\\\\\\\\n \\frac{d Z^D}{dt} &= \\text{LN}(Z^D + \\text{MLP}(\\text{LN}(Z^D))) = \\mathbf{F^D}.\n\\end{align*}\n$$\nAmong these, $Q$, $K$, and $V$ are obtained through a linear projection of $Z^D$. The spatiotemporal dependencies are then captured using the carefully designed Spatial-MSA and a residual connection in the form of a data-driven derivative. Layer normalization is applied to ensure numerical stability.\n\n**other potential methods to replace**\n\nWe consider air pollutant propagation as a spatiotemporal dynamic system. Neural ODE, serving as a bridge between dynamic systems and neural networks, is better suited for modeling the spatiotemporal dynamics of pollutants compared to traditional sequence models like RNN and Transformer.\n\n\n### **W2**: more results on 1-day and 2-day forecasts and discrepancies between our results and previous works\n\nWe have provided the 1-day and 2-day results in the updated version's Appendix $\\underline{\\text{A10. More experiment results}}$.\n\n**discrepancies between our results and previous works**\n\nAt the dataset level, the preprocessing methods used in other studies differ from ours. Due to missing values in the raw data of the Beijing dataset, previous work used linear interpolation to fill in the gaps over time. However, this approach neither incorporates data from other stations at the same time nor adequately addresses the unrealistic assumption of linearity in cases of large consecutive missing intervals. Our preprocessing approach considering these two problems is described in detail in Appendix $\\underline{\\text{A.5 Dataset Description}}$. At the experimental design level, the settings are also different. For example, in the KnowAir dataset, PM25GNN uses wind speed and direction at the next time step as covariates for all models. However, we believe this design leaks future data, which may deviate from real-world pollutant concentration prediction scenarios. To ensure reproducibility, we have provided an anonymous repository link containing our code (cf. abstract line 027).\n\n### **W3, Q1**: Language Errors in Figures and Tables\n\nThank you for pointing out these issues. We have revised in the updated version.\n\n\n\n### **Q2**: Incomplete and Redundant Equations in the Appendix\n\nWe have revised the math part in the Appendix and given more details about diffusion equation and advection equation in the updated version's Appendix $\\underline{\\text{A.2 The Derivation of Diffusion-Advection Equation}}$."}, {"comment_id": "ls1SBYJQUs", "replyto": "uXTRkVcV9N", "author_type": "authors", "reviewer": null, "comment": "Thank you very much for recognizing our idea and contributions. The shortcomings you mentioned have been addressed in the updated version.\n\n\n\n### **W1**: necessary mathematic description in Appendix\n\nThank you for your suggestion. We have added discrete diffusion and advection equations and their explanations in the updated version's Appendix $\\underline{\\text{A.2 The Derivation of Diffusion-Advection Equation}}$ and $\\underline{\\text{A.3 Proof of closed system's conservation property}}$.\n\n\n\n### **W2**: the role of spatial-MSA in data-driven branch\n\nThe purpose of Spatial-MSA is to exclude distant nodes from the computation of attention scores, preventing interference and noise from affecting the current station. This is because we aim for the data-driven branch to focus solely on capturing the spatiotemporal dependencies of adjacent nodes.\n\n\n\n### **W3**: some typos in our paper\n\nWe have corrected all typos in the updated version.\n\n\n\n### **Q1**: Why using the GNN Fusion after temporal alignment? What about other simple structure like MLP? \n\nGNN Fusion leverages a distance-based graph structure. While this design may not significantly improve performance on the Beijing dataset (because these stations are close to each other), it proves highly effective for the national-level KnowAir dataset. Intuitively, nodes that are closer together should interact with other stations' dynamics latent representations, whereas distant nodes do not require such interactions. Therefore, GNN Fusion is a natural choice.\n\n**other network structure**\n\nFirst, considering that the spatial distribution of stations forms a non-Euclidean structure, CNNs are inherently unsuitable for handling such data. Second, using MLP for fusion introduces two approaches: site-level and global-level. A site-level MLP can only fuse the two types of dynamics latent representations for a given station, without enabling interactions with features from other stations. Conversely, a global-level MLP introduces features from distant stations, which serve as irrelevant noise for the current station. Also, this contrasts with Spatial-MSA, where we aim to let the data-driven model capture features from nearby stations.\n\n\n\n### **Q2**: Is there other data should be included for geospatial graph construction?\n\nYes, we utilized SRTM data to construct the GeoSpatial Graph. Naturally, pollutants cannot propagate over large distances or when geographical barriers, such as mountains, exist between stations. The methodology for constructing the GeoSpatial graph is described in detail in Appendix $\\underline{\\text{A.4 GeoSpatial Graph}}$."}, {"comment_id": "2tXplAJRcB", "replyto": "bFILgrO9MG", "author_type": "authors", "reviewer": null, "comment": "We would like to sincerely thank Reviewer NWR1 for acknowledging our work's novelty and contributions. Your concerns about sparse data's performance and usage in other domains are resolved in following contents. \n\n\n### **W1**: stakeholders require clear explanations of how predictions are made. (Although not very important point, we still want to elaborate.)\n\nThank you for recognizing our work. Actually, our model can provide stakeholders with explanations to some extent. In Section 4.4, the second case study presents a visualization of $\\boldsymbol{\\beta}$ under two scenarios in Beijing. In our view, $\\boldsymbol{\\beta}$ can serve as an important reference indicator for stakeholders: $\\beta_i < 0$ indicates a sink (e.g., out-of-boundary diffusion/advection or forest absorption around station $i$), while $\\beta_i > 0$ indicates a source (e.g., industrial activity or vehicular emissions around station $i$). This provides stakeholders with insights into how the model makes predictions to some degree. We hope this response addresses your concerns.\n\n\n### **Q1**: Air-DualODE's performance on spare data and minimum data density for effective predictions\n\n**Air-DualODE's performance on spare data**\n\nTo study how Air-DualODE performs in regions with sparse or inconsistent air quality data, we vary the size of the training set from 10% to 100%, while keeping the validation and test sets fixed. The results are presented in the table below:\n\n| Beijing | MAE | RMSE | SMAPE |\n| ------- | ------- | ------- | ------ |\n| 10% | 44.0587 | 67.6437 | 0.8279 |\n| 20% | 43.5233 | 66.664 | 0.8004 |\n| 30% | 42.5117 | 65.9448 | 0.7685 |\n| 70% | 41.7993 | 63.9566 | 0.7672 |\n| ALL | 40.3208 | 62.0407 | 0.7388 |\n\n| KnowAir | MAE | RMSE | SMAPE |\n| ------- | ------- | ------- | ------ |\n| 10% | 20.0789 | 31.4796 | 0.4493 |\n| 20% | 19.8123 | 30.8212 | 0.4454 |\n| 30% | 19.3563 | 30.5257 | 0.4365 |\n| 70% | 18.9352 | 29.9471 | 0.429 |\n| ALL | 18.6431 | 29.3657 | 0.4213 |\n\nThe following two conclusions can be drawn from the tables below:\n\n1. As the amount of training data increases, the performance of Air-DualODE improves correspondingly (scaling law in data sizes).\n2. Even with only 30% of the training set available, Air-DualODE still performs well and does not degrade significantly.\n\nThis demonstrates that our model can still achieve strong performance even with sparse data.\n\n**minimum data density for effective predictions**\n\nAccording to the experiment results, minimum data density for effective predictions is around 30%.\n\n### **Q2**: tested on other pollutants and usage in other domain\n\n**tested on other pollutants**\n\nIn the experiments, we focused on PM2.5 due to the availability of the dataset (KnowAir only provides three-hour interval data for PM2.5) and to maintain consistency with previous studies. However, we believe that other pollutants, being essentially fine particles, can also be effectively modeled by Air-DualODE.\n\n**usage in other domain**\n\nThanks for your suggestion. Air-DualODE is specifically designed for air pollutant prediction, and the BA-DAE in its physical branch may not align with the requirements for water quality prediction. Therefore, Air-DualODE is not directly applicable to water quality prediction tasks. However, by revising the Air-DualODE's physical branch with water quality-specific equations, we believe it has the potential to effectively address water quality prediction tasks."}, {"comment_id": "7V67sQKxCV", "replyto": "Ys1pxLTDIm", "author_type": "authors", "reviewer": null, "comment": "### **Q14**: the suggestion of appropriate title\n\nThank for your suggestion. We believe you are correct, and we have revised our phrasing in the updated version to avoid potential misunderstandings. We plan to change the title to \"**Air Quality Prediction with Physics-Guided Dual Neural ODEs in Open Systems**\", if the paper is accepted and the chair approves this title.\n\n\n### **Q15**: the details of RNN Coefficient Estimator\n\nThis is a sequence modeling approach for estimating the coefficients $\\boldsymbol{k}, \\boldsymbol{\\beta} \\ \\in R^{N \\times 1}$. In the implementation, we use a GRU block to construct the RNN, enabling the processing of historical data $(X_i, A_i)_{i = 1}^T$.\n\n\n### **Q16**: information supplementation in Fig. 2 and the relationship between Dynamics Fusion and Section 3.4\n\nIn Eq. (6), $\\alpha$ represents a weighted sum of diffusion and advection. This information has been supplemented in Fig. 2 in the updated version. The Dynamics Fusion in Fig. 2 illustrates the inequalities that must be satisfied between the two dynamics latent representations. These relationships are enforced by the Decay-TCL loss, as discussed in Section 3.4.\n\n### **Q17**: the visualization of $G_{\\text{diff}}$ and $G_{\\text{adv}}$ change dynamically\n\nIt is important to clarify that $G_{\\text{diff}}$ is not dynamically changing, as described in Section 3.2.1 Diffusion Graph. As for $G_{\\text{adv}}$, we have provided additional visualizations in updated version's Appendix $\\underline{\\text{A.13 Visualization of Advection graph}}$.\n\n### **Q19**: NFE comparisons in Table 2's ablation studies\n\nThe following table contains the NFE comparison of Table 2.\n\n| Model | Beijing | | | | KnowAir | | | |\n| ------------------------ | ------- | ----- | ----- | ------- | ------- | ----- | ----- | ------- |\n| | MAE | RMSE | SMAPE | NFE | MAE | RMSE | SMAPE | NFE |\n| w/o Physics Dynamics | 42.75 | 63.45 | 0.75 | 14.00 | 19.69 | 30.65 | 0.45 | 14.00 |\n| w/o Data-Driven Dynamics | 44.33 | 65.6 | 0.82 | 20.00 | 21.21 | 33.09 | 0.56 | 19.73 |\n| Explicit Fusion | 41.32 | 63.08 | 0.76 | 29.27 | 19.06 | 30.28 | 0.43 | 30.40 |\n| Cross-Space Fusion | 41.97 | 63.12 | 0.77 | 33.64 | 19.27 | 30.54 | 0.44 | 36.00 |\n| w/o Decay-TCL | 42.34 | 66.56 | 0.82 | 30.73 | 19.10 | 29.79 | 0.43 | 33.73 |\n| w/o Spatial-MSA | 40.52 | 62.49 | 0.75 | 32.91 | 18.97 | 30.01 | 0.43 | 34.00 |\n| Air-DualODE | 40.32 | 62.04 | 0.74 | 32.64 | 18.64 | 29.37 | 0.42 | 35.46 |\n\nFrom the table above, we can draw the following conclusions:\n\n1. Dual branches result in a slightly higher NFE but deliver better performance compared to a single branch, owing to the use of two ODE solvers.\n2. Cross-Space Fusion increases the NFE because of the mismatch between explicit physical equations and latent data-driven representations. Therefore, we use Decay-TCL to align the representations in the same space, which reduces both the errors and the NFE, suggesting that it is effective.\n\n### **Q20**: coverage of advanced GNNs in Related Work and their similarities and differences with Air-DualODE\n\n**coverage of advanced GNNs in Related Work**\n\nThank you for providing these articles. We believe that some of the GNN designs referenced in these works draw inspiration from the diffusion equation in PDEs. Therefore, we have included them in the related work section of our updated version.\n\n**their differences with Air-DualODE**\n\nIn [3, 4, 5, 6], the diffusion equation is used to model different objects compared to Air-DualODE. In these GNNs, the diffusion equation serves as a motivation for improving classical GNNs, primarily addressing challenges such as over-smoothing and gradient vanishing. In contrast, Air-DualODE’s diffusion refers to the physical phenomenon of pollutant transport caused by uneven concentration distributions. In Air-DualODE, we use the diffusion equation to model pollutant transport's dynamics. In summary, the use of diffusion equation in [3, 4, 5, 6] versus in our Air-DualODE are largely orthogonal."}, {"comment_id": "roxnBrNQuX", "replyto": "Ys1pxLTDIm", "author_type": "authors", "reviewer": null, "comment": "### **Q7**: the relationship between visualized $\\boldsymbol{\\beta}$ and pollution source/sink.\n\nIn the case study presented in Section 4.4, we use two specific time points to illustrate the source/sink scenarios in Beijing (dissipation out of boundary and generation from the boundary can also be seen as special sink and source). When $\\beta_i > 0$, it indicates that the pollutant growth rate at the station, beyond diffusion and advection, is positive, signifying the presence of a source phenomenon. Conversely, when $\\beta_i < 0$ , it represents a sink phenomenon.\n\nIn the left part of Fig. 7, there is a northwest wind, leading to an advection out-of-boundary effect at the stations highlighted by the red circles. This manifests as a sink phenomenon, causing their $\\beta$ to trend negative. Conversely, some stations exhibit positive $\\beta$, which can be attributed to pollutants transported by the wind from outside the boundary to these stations, representing a source phenomenon.\n\nIn the right part of Fig. 7, the wind effect is minimal. For boundary stations, diffusion out of the boundary acts as a sink phenomenon. Meanwhile, for internal stations, the morning rush hour in the city contributes to a source phenomenon due to vehicle emissions. In general, the source/sink nature at different stations can help validate the signs of their $\\beta$. For instance, stations located in areas with abundant trees often experience sink phenomena.\n\n### **Q8**: the sensitivity analysis of $\\boldsymbol{\\beta}$\n\n$\\boldsymbol{\\beta}$ is not a hyperparameter, so sensitivity analysis cannot be performed. Rather, it is a value estimated from historical data $( X_i, A_i )_{i = 1}^T$ by the coefficient estimator, which is a GRU based RNN, as shown in the Fig.2's \"Physics Dynamics\" part.\n\n\n### **Q10**: comparison with climate model and the difference between Air-DualODE and \\[1\\]\\[2\\]\n\n**comparison with climate model**\n\nWe sincerely appreciate you sharing these two articles. [1, 2] primarily focus on climate forecasting scenarios, introducing the Neural Diffusion Equation (NDE) and Neural Diffusion-Advection Equation (NADE). In contrast, Air-DualODE is specifically designed for spatiotemporal sequence modeling of air pollutant propagation, whereas NDE and NADE target climate prediction tasks. Due to concerns about the differing domains of application, we did not include them as baselines for comparison. However, as both NDE and NADE fall under the category of physics-guided machine learning, we have incorporated them into the related work section in the updated version to provide a more comprehensive review of relevant studies.\n\n**the difference between Air-DualODE and [1, 2]**\n\nNADE maps the current values to a latent space through an encoder, then iteratively solves the diffusion-advection equation to obtain future latent states, and finally gets predictions using a decoder. Additionally, NADE models uncertainty, which is crucial in the field of climate forecasting. Both NADE and ClimODE(Verma et al., 2024) are pioneering works in physics-guided machine learning and are inspiring.\n\nIn terms of similarity, both NADE and Air-DualODE utilize the diffusion-advection equation, derived from the mass conservation law, as it plays a vital role even in meteorological domains. However, there are significant differences between NADE and Air-DualODE:\n\n1. **Modeling Objects**: In Air-DualODE, the diffusion-advection equation primarily models pollutant propagation, while in NADE, it is used to model latent variables in meteorology.\n2. **Different Equation**: Air-DualODE’s BA-DAE considers an open system, while NADE’s diffusion-advection equation, adhering to the mass conservation law, primarily applies in a closed system.\n3. **Discretization Method**: Air-DualODE discretizes the equation into ODEs using the Method of Lines (MOL) and approximates the Laplacian operator with GNNs and Laplacian matrices. In contrast, NADE directly replaces the Laplacian operator with a precomputed Laplacian matrix.\n4. **Physics-Inspired Approach**: Air-DualODE applies the equation explicitly to pollutants in a physical space and combines it with a data-driven NODE for the final results. On the other hand, NADE applies the equation to latent variables and derives the final predictions via a decoder.\n\n### **Q12**: concerns of two dynamics' latent representations in the different state space\n\nThat’s an excellent question. For this very reason, we introduce Decay-TCL, incorporating a decaying weight $w(t, s)$ based on timestamp differences and dynamics types, to regulate the similarity between latent representations. This approach aims to map the two representations into a shared latent space as much as possible, facilitating subsequent fusion."}, {"comment_id": "HOqXD4vWh6", "replyto": "Ys1pxLTDIm", "author_type": "authors", "reviewer": null, "comment": "### **W4, Q6, Q9:** Case studies and hyperparameter, solver sensitivity analysis\n\n**Case studies about KnowAir**\n\nIn Table 1, we present comprehensive experiments to demonstrate that our model performs effectively even on a larger and coarser-grained air pollution dataset (KnowAir). Additionally, we have provided visualizations of several case studies on the KnowAir dataset in the updated version's Appendix $\\underline{\\text{A.9 The case study of KnowAir}}$.\n\n**hyperparameter and solver sensitive analysis**\n\nThese sensitivity analyses are provided in the updated version's Appendix $\\underline{\\text{A.11 Hyperparameter sensitive analysis}}$ and $\\underline{\\text{A.12 ODE solver sensitive analysis}}$. From the experiments, it is evident that Air-DualODE demonstrates robustness to hyperparameters and solver types. At the ODE solver level, we recommend the adaptive Dopri5 algorithm. Although its computational cost is higher than the other two solvers, it also delivers greater accuracy.\n\n\n### **W5**: incomplete technical documentation and concerns on reproducibility.\n\nTo ensure reproducibility, we have made our code public available, where the link to the code is in the original submission paper's abstract (cf. abstract line 027). We also revised the mathematical part to provide clear access to the details of our architecture.\n\n\n\n### **Q2**: how to preserve the physical interpretability when projecting into latent space\n\nOur primary goal is to accurately predict pollutants (e.g., PM2.5) rather than to develop a fully interpretable model. Our approach remains within the domain of deep learning. We applied physical modeling to the inputs (PM2.5) and claim that the physical branch explicitly models physical phenomena. However, after projecting into the latent space, the latent variables retain some physical information but do not preserve full physical interpretability.\n\n### **Q3**: the chosen values of $\\lambda_1$ and $\\lambda_2$ in Decay-TCL\n\nBoth $\\lambda_1$ and $\\lambda_2$ are used to constrain the latent representations of the two dynamics according to the third rule mentioned in Section 3.4: representations from the same dynamics should be more similar than those from different dynamics. To meet this condition, we must ensure $\\lambda_1 > \\lambda_2$. A grid search is performed to determine the appropriate values for $\\lambda_1$ and $\\lambda_2$ while enforcing $\\lambda_1 > \\lambda_2$.\n\n\n### **Q4**: the reason of choosing Spatial-MSA and aligning with a physics-informed approach\n\n**the reason of choosing Spatial-MSA**\n\nTo capture spatial correlations across different time steps within the dynamic system, a masked self-attention mechanism is incorporated into $\\mathbf{F^D}$ as we described in Section 3.3. Specifically, the adjacency matrix of $G$ is employed as a mask in the Spatial-MSA. Intuitively, if no potential transport pathway exists between two stations, their representations should not be correlated. This approach not only improves computational efficiency but also enables the model to focus on relevant information, thereby enhancing its effectiveness across various spatial granularities in prediction scenarios because every station just attends to their nearby, accessible stations.\n\n**aligning with a physics-informed approach**\n\nWe could say this is physics-guided because it leverages geospatial intuition (First Law of Geography: Everything is related to everything else, but near things are more related to each other.) rather than explicit physical equations.\n\n\n### **Q5**: the difference between GNN fusion and Physics branch\n\nAlthough both branches consider 'distance-dependent influence', they have essentially two distinct targets. In the physics branch, $G_{\\text{diff}}$ and $G_{\\text{adv}}$ are used exclusively by the GNN to approximate the Laplacian operator. However, during the fusion process, our goal is to account for the interactive influence of the two types of dynamics representations across different nodes. The degree of this interaction also needs to be defined based on spatial distance. To achieve better fusion of the dynamics representations, we again utilize a graph structure based on spatial distances. Both designs are essential to the Air-DualODE framework."}, {"comment_id": "TVnt1bykx4", "replyto": "Ys1pxLTDIm", "author_type": "authors", "reviewer": null, "comment": "We would like to sincerely thank Reviewer 6Tsm for providing a detailed review and insightful comments regarding BA-DAE, hyperparameter's sensitivity, typos and other important components in our model. We have revised our paper accordingly.\n\n### **W1, Q1**: more discussions about $\\boldsymbol{\\beta} X$\n\n**1. the reason of linear term.**\n\nFrom the well-known SIR model, in the infectious equation ($\\frac{dI}{dt} = \\eta \\frac{SI}{N} - \\gamma I$) and the recovery equation ($\\frac{d R}{dt} = \\gamma I$), the growth rate of recovery is modeled as a proportion to infected population. However, in reality, their relationship exhibits complex nonlinear dependencies. Motivated by this design and the principle of Occam’s razor, we also use $\\boldsymbol{\\beta} X$ to model sources (generation of pollutants) and sinks (dissipation of pollutants). Besides, $\\boldsymbol{\\beta}$ is estimated by the RNN based coefficient estimator (cf. Fig.2's \"Physics Dynamics\" part), which accounts for the nonlinear dependencies of sources and sinks.\n\n**2. physical justification for $\\boldsymbol{\\beta}$'s range.**\n\nAs mentioned in Section 3.2.2, the dissipation of pollutants cannot exceed the current pollutant concentration; therefore, $\\boldsymbol{\\beta}$ has a minimum value of -1, i.e., all pollutants are dissipated. However, the generation of pollutants may far exceed the current concentration. Thus, $\\boldsymbol{\\beta}$ can theoretically have a maximum value of $+\\infty$.\n\n\n### **W2, Q11**: numerical stability and physical interpretability\n\n**numerical stability**\n\nWe did not observe numerical instability during our experiments. Moreover, the training loss curve (now available in the updated version's Appendix $\\underline{\\text{A.8 Numerical stability}}$) demonstrates that our method converges. Besides, Air-DualODE incorporates normalization techniques, such as Layer Normalization, which contributes to numerical stability.\n\n**physical interpretability**\n\nTo clarify, our proposal is not a fully-explainable new physical model but rather a hybrid model or a physics-guided/inspired method (as you mentioned in Q14). The goal of our work is to enable accurate air quality prediction, rather than having fully-explainability. We integrate physical equations into deep learning models, enabling existing physical knowledge to enhance the model’s learning and inference, thus improving prediction accuracy.\n\n### **W3, Q18**: computational concerns\n\nOur computational efficiency claims are in comparisons with physics-based methods that solve PDEs on a large grid, incurring significant computational costs. While our physical branch also solves the ODEs derived from the Method of Lines (MOL), it is noteworthy that our spatial discretization is observation station-based. This means that the number of ODEs corresponds to the number of observation stations, which effectively reduces the computational time. We theoretically provide a comparison of computational complexity as follows.\n\n| | theoretical analysis of time cost |\n| --------------------- | ------------------------------------------------------------ |\n| Air-DualODE | $O(\\text{NFE}^{\\text{P}} \\cdot \\| E \\| + \\text{NFE}^{\\text{D}} \\cdot N^2 \\cdot C + \\| E \\|)$ |\n| Physics-based methods | $O(T_{grid} \\cdot N_{grid}^{\\frac{3}{2}})$ |\n\n\nHere, $\\text{NFE}^{\\text{P}}$ and $\\text{NFE}^{\\text{D}}$ refer to the Number of Forward Evaluations for the physics and data-driven branches, respectively. $|E|$ is the number of edge in $G$. $C$ represents the dimension of latent variables. $N$ represents number of neighboring stations around the current station, while $N_{\\text{grid}}$ and $T_{\\text{grid}}$ denote the grid points in PDE solving. In practical scenarios, the number of grid points is much larger than the number of stations: $N_{\\text{grid}} \\gg N$ and the number of grid time step is also greater than NFE: $T_{\\text{grid}} \\gg \\text{NFE}$. For example, on the Beijing dataset with 35 observation stations, predicting 24 steps means $N \\leq 35$ and $\\text{NFE} \\geq 24$ (depending on ODE-solver). In contrast, physics-based methods would require $N_{\\text{grid}} \\approx 1000,000 ( 1000^2 )$ and $T_{\\text{grid}} \\approx 1,000$ to keep accuracy. Therefore, the computational cost of Air-DualODE is lower than that of physics-based methods."}], "meta_review": {"metareview": "This paper presents a spatiotemporal forecasting method specialized to the air quality prediction which is an important problem in some countries. They combine a physics-based modeling and a data-based modeling to have better predictions. However, they failed to consider some recent methods as pointed by the reviewers and during the rebuttal phase, they added more discussion on them. For the physics modeling, they use the diffusion-advection system, which is a quite standard method for this task. Later, they may want to extend to the diffusion-advection-reaction system in my personal opinion. They showed SOTA performance in various datasets.\n\nI recommend accepting this paper, but I also recommend that the authors put more comparisons with more spatiotemporal forecasting baselines. There are many baselines for traffic forecasting. Since they are technically similar to each other, some audiences may be interested in those comparisons.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "M992mjgKzI", "forum_url": "https://openreview.net/forum?id=M992mjgKzI", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "OGBench: Benchmarking Offline Goal-Conditioned RL", "authors": ["Seohong Park", "Kevin Frans", "Benjamin Eysenbach", "Sergey Levine"], "abstract": "Offline goal-conditioned reinforcement learning (GCRL) is a major problem in reinforcement learning (RL) because it provides a simple, unsupervised, and domain-agnostic way to acquire diverse behaviors and representations from unlabeled data without rewards. Despite the importance of this setting, we lack a standard benchmark that can systematically evaluate the capabilities of offline GCRL algorithms. In this work, we propose OGBench, a new, high-quality benchmark for algorithms research in offline goal-conditioned RL. OGBench consists of 8 types of environments, 85 datasets, and reference implementations of 6 representative offline GCRL algorithms. We have designed these challenging and realistic environments and datasets to directly probe different capabilities of algorithms, such as stitching, long-horizon reasoning, and the ability to handle high-dimensional inputs and stochasticity. While representative algorithms may rank similarly on prior benchmarks, our experiments reveal stark strengths and weaknesses in these different capabilities, providing a strong foundation for building new algorithms. Project page: https://seohong.me/projects/ogbench", "keywords": ["reinforcement learning"], "primary_area": "reinforcement learning", "pdf_url": "https://openreview.net/pdf?id=M992mjgKzI", "decision": "Accept (Poster)", "num_reviews": 4, "num_discussions": 10, "reviews": [{"review_id": "JE4Gj3ndF6", "reviewer": "Reviewer_jjTu", "rating": 8, "confidence": 4, "soundness": 4, "presentation": 4, "contribution": 4, "summary": "In this work the authors present OGBench, a novel benchmark for offline goal-conditioned reinforcement learning (GCRL). The authors motivate the need for such benchmark, due to the lack of a standardized evaluation suite for works in offline GCRL and the increasing interest from the research community on the topic. The authors describe the main challenges tackled with OGBench: learning from sub-optimal data, goal stitching, long-horizon reasoning and stochastic environments. To tackle these challenges, the authors describe the design principles of the benchmark and consider a wide range tasks (locomotion, manipulation and drawing) across multiple environments. Additionally, the authors collect several different datasets per environment, considering different sub-optimality conditions of the behavior policy. Finally, the authors evaluate several literature-standard algorithms in OGBench, contributing their implementation as well, and discuss their performance and other findings.", "strengths": "- **Originality**:\n - Recently, there has been an extensive push towards the development of benchmarks that focus on different aspects of reinforcement learning (e.g., for offline RL [1], [2]). However, as noted by the authors, there existed no centralized evaluation suite for algorithms in offline GCRL and the community often resorted to different set of tasks (not tailored for, making difficult to access the true performance of the proposed methods. As such, the development of a benchmark that tackles specific research issues in offline GCRL is most welcomed and, to the best of my knowledge, novel.\n\n- **Quality**:\n - The work presented here is of high quality: the proposed benchmark contains a set of diverse tasks that address specific problems in offline GCRL research and high-quality, fine-tuned implementations of recently proposed algorithms for offline GCRL (that often surpass the performance of the original versions). Furthermore, the authors collect a wide range of datasets for each task, with different levels of sub-optimality.\n\n- **Clarity**:\n - The current version of the work is well-written, with no major typos (that I could detect). Furthermore, the document is easy to read, with self-explanatory figures and tables. I would, however, recommend some refrain when using adjectives such as \"cool\" (Section 6, page 4) or \"exciting\" (Section 6, page 4), due to their intrinsic subjectivity.\n\n- **Significance**:\n - This work proposes a novel benchmark for offline GCRL. Given the relevance of the topic for RL research, the work can have substantial impact by providing a standardized suite for evaluation of novel algorithms.\n\n\n**References**: \n\n- [1] Seno, Takuma, and Michita Imai. \"d3rlpy: An offline deep reinforcement learning library.\" Journal of Machine Learning Research 23.315 (2022): 1-20.\n- [2] Fu, Justin, et al. \"D4rl: Datasets for deep data-driven reinforcement learning.\" arXiv preprint arXiv:2004.07219 (2020).", "weaknesses": "My concerns with the current version of the work are the following (none of them particularly major):\n- Currently, it is challenging to assess the varying levels of novelty present in the environments available in OGBench. Throughout Section 7, I found myself questioning whether each environment was accessible elsewhere (in another benchmark or repository) or if it was unique to this benchmark. I do understand that some tasks (e.g., AntMaze) are heavily inspired by/available in other benchmarks. It would strengthen the originality of the benchmark if the authors clarified the level of novelty in each environment.\n- Currently the paper lacks any discussion about the limitations of the current benchmark and plans for future extension (for example, one potential extension could tackle robustness to data corruption as explored in [1], or integrating scenarios with multiple modalities).\n- Is not clear how the five evaluation goals per task were defined or what criteria was used for their selection. Additionally, why only 5 goals? Since GC policies should be able to reach any (viable) state from any state, it is unclear why the authors selected such a set of small goals.\n\n\n**References**\n\n- [1] Yang, Rui, et al. \"Towards Robust Offline Reinforcement Learning under Diverse Data Corruption.\" The Twelfth International Conference on Learning Representations.", "questions": "Please, also refer to the questions brought up in the \"Weaknesses\" section.\n\n- **1** - Why not also collect an \"exploratory\" dataset for the AntSoccer task? Is it due to the lack of coverage of states where the agent is dribbling the ball?\n- **2** - In the locomotion and manipulation tasks, how exactly is the goal state provided to the agent? Does it also include the final pose of the agent/robot? In that case, the task goal becomes a convolution of an environmental change (for example, a set of objects in a specific location in the Scene task) and a final pose of the agent?\n- **3** - I found it quite interesting the comment at the end of Section 8.2 (\"this suggests that we may need to prioritize coverage much more than optimality when collecting datasets for offline GCRL in the real world\") as this seems to go against the current trends in large scale collection of interaction data in robotics, where the data are collected across multiple environments/embodiments but always by expert-level demonstrators (e.g., [1]). Can the authors make any parallelism between the case of offline GCRL and other large-scale data collection efforts?\n\n\n**Reference**:\n- [1] Collaboration, Open-X. Embodiment, et al. \"Open X-Embodiment: Robotic learning datasets and RT-X models.\" arXiv preprint arXiv:2310.08864 1.2 (2023)."}, {"review_id": "Z48hE4lU6Z", "reviewer": "Reviewer_t6ud", "rating": 8, "confidence": 4, "soundness": 3, "presentation": 4, "contribution": 3, "summary": "The paper proposes a new suite of tasks for benchmarking goal-conditioned offline RL algorithms. The proposed benchmarks are built with consideration specific to the goal-conditioned navigation tasks, with an emphasis on evaluating relevant task capabilities such as stitching, long-horizon reasoning, dealing with stochastic and visual input. The tasks cover a range of difficulty levels and types, including ant and humanoid locomotion, multi-step manipulation and a painting task with an high-dimension state space entailing combinatorial search. Offline policy data are also collected with various policies to emulate the data sub optimality. The benchmark is shown to be challenging for state-of-the-art and helpful in analysing strengths and limitations on investigated task capabilities. Code and policy scripted are provided and with a minimum dependencies on open-sourced mujoco and PyTorch libraries.", "strengths": "* Well-motivated research efforts and a solid execution with many considerations customised to goal-conditioned offline RL.\n\n* A good coverage of task variations and challenges to state-of-the-art algorithms based on different principles.\n\n* Proper task difficulties demonstrated from the results. Could be promising to spur new algorithmic research.\n\n* Decent writing with clear presentation and well organised flow of narratives.", "weaknesses": "* The benchmark is motivated for the generality of goal-conditioned navigation tasks, aiming at learning transferrable representations for down-stream tasks. However, this is not reflected in the task design and experimental results. It would be better to include some functionalities to facilitate analysing and transferring the learned latent representations, with a report on a few SotA algorithms' performance on this aspect.\n\n* The tasks seem to only challenge policies with different initial states but not the shift of transition dynamics.", "questions": "* The expert data are collected from policies trained by RL. How can we assume the policies provide sufficient optimality? Since the all the tasks are essentially planning towards a goal, would running motion planning/search algorithms with certain guarantees give better data quality?\n\n* Can the benchmark add a study on the representation and the transferability to assess the idea of using goal-conditioned tasks as a general-purpose representation learning?"}, {"review_id": "pDlSjcMeeS", "reviewer": "Reviewer_sCze", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 4, "summary": "This paper introduces OGBench, a comprehensive benchmark for the standardized evaluation of offline GCRL. OGBench provides a diverse range of environments and datasets, covering areas such as locomotion, manipulation, and drawing, to assess key capabilities like trajectory stitching, long-horizon reasoning, image-input handling, and managing stochasticity. The authors evaluate six offline GCRL algorithms using this benchmark, designed to highlight each algorithm's unique strengths and weaknesses across different tasks. OGBench is optimized for computational efficiency and user-friendliness, allowing researchers to easily test and refine their ideas, thereby advancing the field of offline GCRL.", "strengths": "* Offline GCRL is an important research direction, and this work addresses the current lack of a challenging and comprehensive benchmark.\n\n* This benchmark evaluates GCRL algorithms across diverse aspects, including stitching, stochastic environments, high-dimensional control, and long-horizon control.\n\n* The benchmark includes popular baseline results, highlighting some unsolved tasks and areas for improvement.", "weaknesses": "Some benchmark designs need further discussion or improvement.\n\n* The evaluation seems limited to five tasks. Does this mean there are only five evaluation goals? If so, this number is restricted compared to several prior goal-conditioned tasks such as the Fetch environments.\n\n* The design of the transparent arm in pixel-based manipulation tasks seems unrealistic. It would be better to let users choose between a transparent or solid option. Additionally, demonstrating how transparency affects performance could be helpful.\n\n* The goal information is missing in Table 6. Is the goal in image-based tasks another image provided to the agent, or is it just a transparent bot as shown in the manipulation task figures? \n\n* Can the benchmark support language instructions as a special type of \"goal\"? This would be interesting given the current research focus on LLMs.\n\n* The stochastic task appears confined to the teleport domain. More stochastic settings that might be relevant to real-world scenarios should be explored, such as random observation perturbation during both data collection and evaluation. \n\n* Current baselines still use MLP networks as backbones. Could the authors consider implementing more advanced architectures, such as transformers? Are there any limitations in the current tasks that more sophisticated network architectures might address?", "questions": "See the weakness part."}, {"review_id": "qZa8BUYWbr", "reviewer": "Reviewer_BGac", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "This paper contributes to offline GCRL by introducing a suite of high-quality benchmarks. These benchmarks systematically assess offline GCRL algorithms' performance in learning from suboptimal data, goal stitching, long-horizon reasoning, and handling stochasticity, providing a useful tool for advancing offline GCRL research.", "strengths": "1. They are very useful benchmarks to evaluate offline GCRL methods. \n2. The authors evaluate $6$ standard offline GCRL methods on the proposed benchmarks, establishing a solid baseline for comparison and further development.", "weaknesses": "The positioning of this paper is very similar to a previous paper - D4RL (Fu et al., 2021), which was well-known but rejected by ICLR 2021 (https://openreview.net/forum?id=px0-N3_KjA). \n\nThis paper makes a clear contribution by providing valuable benchmarks, but it does not introduce novel ideas. However, given its potential importance to the field, I lean towards assigning a slightly positive score.", "questions": "1. It would be better to retain one or two decimal places in Table 2. In benchmarks such as D4RL (Fu et al., 2021) and offline robotics (Yang et al., 2022), the success rate/score provided by decimal places is sometimes critical for comparisons. \n2. Could the authors clarify in the paper whether a rollout is considered successful if the goal is achieved at the last state, or if any state within the rollout reaches the goal? \n3. The authors should include a mention of GOPlan (Wang et al., 2024), which applies model-based planning with a generative adversarial network as the prior policy, in the related work. Additionally, GOPlan and GoFar (Ma et al., 2022) evaluate algorithm robustness in stochastic offline GCRL settings and should be discussed in Section 5.4.\n\n\n**References**\n\nFu, J., Kumar, A., Nachum, O., Tucker, G., \\& Levine, S. (2021). D4RL: Datasets for Deep Data-Driven Reinforcement Learning. \n\nMa, J. Y., Yan, J., Jayaraman, D., \\& Bastani, O. (2022). Offline Goal-Conditioned Reinforcement Learning via f-Advantage Regression. NeurIPS.\n\nWang, M., Yang, R., Chen, X., Sun, H., Fang, M., \\& Montana, G. (2024). GOPlan: Goal-conditioned Offline Reinforcement Learning by Planning with Learned Models. TMLR.\n\nYang, R., Lu, Y., Li, W., Sun, H., Fang, M., Du, Y., … Zhang, C. (2022). Rethinking Goal-conditioned Supervised Learning and Its Connection to Offline RL. ICLR."}], "discussions": [{"comment_id": "qAgxOth7ky", "replyto": "TQNMjaIQY9", "author_type": "reviewer", "reviewer": "Reviewer_BGac", "comment": "Thank you for your reply. This could be an influential work for the offline GCRL community. I maintain my current positive rating."}, {"comment_id": "L7hAkDGrfd", "replyto": "LixuWEJwVW", "author_type": "reviewer", "reviewer": "Reviewer_sCze", "comment": "Thank you for your response.Overall, I think this work is valuable to the RL community, and I will maintain my current positive rating."}, {"comment_id": "stlcEqoZJj", "replyto": "0WfjhHYEe5", "author_type": "reviewer", "reviewer": "Reviewer_t6ud", "comment": "Thank you for the discussion and texts for addressing my comments. It is good to know the literature on the off-dynamics benchmark. I maintain my opinion that this is a good and timely work piece and the original score."}, {"comment_id": "5blHXSAXAx", "replyto": "KVNuxXCp0Z", "author_type": "reviewer", "reviewer": "Reviewer_jjTu", "comment": "Thank you for the insightfull rebuttal. In line with my previous comments, I think the authors did a great job with this work and, as such, I maintain my score."}, {"comment_id": "TQNMjaIQY9", "replyto": "qZa8BUYWbr", "author_type": "authors", "reviewer": null, "comment": "Thank you for the detailed review and constructive feedback about this work. We especially appreciate the reviewer's clarification question about our goal success criterion. Please find our response below.\n\n* **“This paper makes a clear contribution by providing valuable benchmarks, but it does not introduce novel ideas.”**\n\nThanks for the positive comments! Indeed, our goal in this work is not to propose a novel method but to provide a new benchmark. Just for clarification, we would like to note that “datasets and benchmarks” is one of the ICLR subject areas ([ICLR 2025 Call for Papers](https://iclr.cc/Conferences/2025/CallForPapers)).\n\n* **Similarity to D4RL**\n\nWhile D4RL primarily focuses on single-task RL, our benchmark aims to complement D4RL to provide a thorough suite of tasks and algorithms for studying offline goal-conditioned RL. Our benchmark also addresses some key limitations of D4RL (e.g., task performances are saturated [1], it only supports single-goal evaluation, and most tasks do not support pixel-based observations). Unlike D4RL, OGBench provides much more challenging tasks and datasets specifically designed for offline goal-conditioned RL (e.g., HumanoidMaze, AntSoccer, Scene, Puzzle, Powderworld, etc.), while posing diverse challenges for goal-conditioned RL, including long-horizon reasoning, pixel-based control, goal stitching, and stochastic control.\n\n* **“It would be better to retain one or two decimal places in Table 2.”**\n\nThanks for the suggestion! Following the suggestion, we have revised Tables 2 and 10-16 to show one decimal place; please find them in the revised PDF.\n\n* **How is a rollout considered successful at goal-reaching?**\n\nThanks for asking this clarification question. During evaluation, the trajectory ends immediately after the agent reaches the goal, meaning that we consider a trajectory to be successful if it contains at least one state (not necessarily the final one) that achieves the goal. It seems we were not entirely clear about our evaluation protocol in the initial draft, and we have clarified this point by making a separate paragraph (titled “Evaluation”) at the beginning of Section 7 of the revised PDF.\n\n* **Missing related works: GOPlan and GoFAR**\n\nThanks for pointing out these relevant works! We have revised the draft to mention GOPlan and GoFAR (please refer to Section 3 and Appendix C of the revised PDF).\n\n---\n\nWe would like to thank the reviewer again for raising important questions about OGBench. We believe the added clarification about the success criterion as well as the discussions about further related works have improved the quality of the paper. **Have we sufficiently addressed the reviewer’s main concerns?** Please feel free to let us know if there are additional concerns or questions.\n\n[1] Tarasov et al., Revisiting the minimalist approach to offline reinforcement learning, NeurIPS 2023."}, {"comment_id": "LixuWEJwVW", "replyto": "pDlSjcMeeS", "author_type": "authors", "reviewer": null, "comment": "* **Can the benchmark support language instructions as a special type of \"goal\"?**\n\nThanks for the suggestion. While language-conditioned RL/control is an important problem, we consider it to be beyond the scope of this benchmark. However, we believe it may not be very difficult to convert many manipulation tasks into language-conditioned tasks (e.g., we can describe evaluation tasks by “swap the cubes”, “turn all buttons into blue”, “open the drawer and window”, etc.). We leave the integration of language instructions and OGBench tasks as an interesting potential extension of this work.\n\n* **About stochastic tasks / “The stochastic task appears confined to the teleport domain.”**\n\nIn OGBench, we provide two different types of stochastic environments: (1) the “teleport” maze in navigation environments and (2) Powderworld, whose transition dynamics are highly stochastic and unpredictable. As the reviewer mentioned, our manipulation environments have deterministic dynamics (though the poses/colors of objects are slightly randomized/permuted each time), as in many previous simulated tasks in robotic manipulation (e.g., Fetch [4], Kitchen [5], CALVIN [6], etc.). This is mainly because it is not entirely straightforward to design realistic manipulation tasks with truly stochastic dynamics. That said, we believe the new option for choosing an opaque arm (please see our response above) can provide an additional challenge regarding stochasticity (via partial observability) in manipulation domains.\n\n* **Current baselines still use MLP networks as backbones.**\n\nAs the reviewer pointed out, our baselines are based on MLPs and CNNs. While we were unable to add and tune Transformer-based baselines within this short discussion period due to our limited computing resources, we believe complex tasks in OGBench (e.g., $\\texttt{visual-humanoidmaze-giant}$, $\\texttt{puzzle-4x6}$, $\\texttt{cube-quadruple}$) do require scalable and expressive networks and thus provide a ground for future research involving modern architectures. We leave this extension as an exciting future research opportunity.\n\n---\n\nWe would like to thank the reviewer again for raising important questions about OGBench. We believe the new clarifications and the additional features have substantially improved the quality of this work. Have we sufficiently addressed the reviewer’s main concerns? Please feel free to let us know if there are additional concerns or questions.\n\n---\n\n[1] Fang et al., Planning to Practice: Efficient Online Fine-Tuning by Composing Goals in Latent Space, IROS 2022. \\\n[2] Zheng et al., Stabilizing Contrastive RL: Techniques for Robotic Goal Reaching from Offline Data, ICLR 2024. \\\n[3] Ghugare et al., Closing the Gap between TD Learning and Supervised Learning -- A Generalisation Point of View, ICLR 2024. \\\n[4] Plappert et al., Multi-Goal Reinforcement Learning: Challenging Robotics Environments and Request for Research, 2018. \\\n[5] Gupta et al., Relay Policy Learning: Solving Long-Horizon Tasks via Imitation and Reinforcement Learning, CoRL 2019. \\\n[6] Mees et al., CALVIN - A benchmark for Language-Conditioned Policy Learning for Long-Horizon Robot Manipulation Tasks, RA-L 2022."}, {"comment_id": "QQHNcp9PCl", "replyto": "pDlSjcMeeS", "author_type": "authors", "reviewer": null, "comment": "Thank you for the detailed review and constructive feedback about this work. We especially appreciate the reviewer's feedback about evaluation goals, the use of the transparent arm, and several clarification questions. Following the reviewer's suggestion, we have added an option to reproduce datasets with an opaque arm. Please find our response below.\n\n* **Evaluation is limited to five goals. Does this mean there are only five evaluation goals?**\n\nThe reviewer’s understanding is correct — we provide five evaluation goals for each task in the benchmark, although their positions, poses, color orders, etc., are slightly randomized each time, and we perform multiple ($50$) rollouts for each evaluation goal. We would first like to note that this type of evaluation (i.e., having a fixed set of evaluation tasks) is quite commonly used in prior works in offline goal/language-conditioned RL and robotics [1, 2, 3]. There are two main reasons for this choice: (1) using five pre-defined goals reduces computational cost for evaluation (which sometimes takes even longer than training in long-horizon environments!), as evaluation with fully randomized goals typically requires many more evaluation rollouts to derive statistically significant results and (2) providing diverse *types* of pre-defined goals enables the practitioner to better analyze their algorithm’s strengths and weaknesses. When designing the tasks, we put substantial effort into choosing a small number of *representative* evaluation goals to maximize research signals while minimizing computational cost. For example, in $\\texttt{cube-double}$ tasks, we curated five evaluation goals to cover single pick-and-place, double pick-and-place, swapping, stacking, etc. (Figure 6). The results in Table 14 show that some methods are good at single pick-and-place but struggle with tasks involving double pick-and-place, which can guide practitioners on areas for improvement in their methods.\n\n* **The use of the transparent arm**\n\nThanks for asking this question! First of all, following the suggestion, we have added an option (`pixel_transparent_arm={False, True}` in `manipspace_env.py`) for the manipulation environments to make the transparency of the arm configurable. Since we have provided the exact commands to reproduce all the datasets, users who wish to use a non-transparent arm can easily generate and use opaque versions of the visual manipulation datasets.\n\nAs for why we chose the transparent arm as the default, we also debated this point; there are pros and cons. A solid arm looks more natural, but due to occlusion in some tasks (e.g., Puzzle), it requires the use of a memory component (e.g., RNNs, Transformers) or multiple viewpoints, which can significantly increase computational burdens. On the other hand, a transparent arm looks less natural, but it enables using a single $64 \\times 64 \\times 3$ image and does not require any additional memory components. Among these two choices, we decided to choose the latter, following our design principle (design principle 4 in Section 6): **minimize unnecessary computation and focus on algorithmic challenges**. After all, OGBench is a benchmark mainly for *algorithms research*, so we prioritized minimizing other orthogonal challenges so that researchers can more quickly iterate on their algorithmic ideas. To further clarify this, we have mentioned this trade-off as a limitation in the newly added limitation section (Appendix A) of the revised PDF. Nevertheless, we believe the new option to configure the transparency of the arm provides additional flexibility for users.\n\n* **“The goal information is missing in Table 6.” / How is a goal given for image-based tasks?**\n\nThanks for asking the clarification question. In all tasks in our benchmark, the goal space is the same as the state space (L100); in other words, a goal is simply yet another state. As such, in image-based tasks, the agent receives as a goal $g$ the rendering of the desired state (with the transparent arm by default). It seems we were not entirely clear about our evaluation protocol in the initial draft, and we have clarified this point by making a separate paragraph (titled “Evaluation”) at the beginning of Section 7 of the revised PDF."}, {"comment_id": "0WfjhHYEe5", "replyto": "Z48hE4lU6Z", "author_type": "authors", "reviewer": null, "comment": "Thank you for the detailed review and positive and constructive feedback about this work. We especially appreciate the reviewer's questions about representation learning and the optimality of datasets. Please find our response below.\n\n* **“Can the benchmark add a study on the representation and the transferability to assess the idea of using goal-conditioned tasks as a general-purpose representation learning?”**\n\nThanks for the suggestion! We completely agree that evaluating the transferability of representations to different types of downstream tasks is an important problem. In fact, we are currently working on this very problem (i.e., how to pre-train and fine-tune GCRL representations trained on diverse, unlabeled data for downstream tasks) using the OGBench locomotion/manipulation datasets. However, instead of adding our initial results on representation transfer to the final version of this paper, we hope to separate it into a different study to have more comprehensive analyses and experiments, given the difference between the two problem settings (offline goal-conditioned RL vs. representation-based downstream adaptation). Nevertheless, we would like to highlight that OGBench can indeed facilitate studies on representation learning by providing a variety of unlabeled datasets, and we hope to provide extensive analyses and results (as a separate work) as a more thorough answer to this question in the near future.\n\n* **“The tasks seem to only challenge policies with different initial states but not the shift of transition dynamics.”**\n\nAs the reviewer mentioned, we do not specifically address challenges with distributional shifts in transition dynamics in this benchmark. While this is an important issue, we consider it beyond the scope of this work, and would like to refer to other great benchmarks specialized in off-dynamics learning, such as ODRL [1]. We have mentioned this as a limitation in the newly added limitation section (Appendix A) of the revised PDF.\n\n* **“How can we assume the policies provide sufficient optimality?”**\n\nThanks for the question. In locomotion environments, we trained a low-level directional policy (using RL) to move as far as possible in a commanded direction, and combined it with an oracle BFS-based high-level planner to generate diverse maze-navigation trajectories. Since the objective of the low-level policy is to simply maximize traveled distances, we can easily estimate its performance (optimality). When creating the datasets and tasks, we confirmed that the trained low-level policies were reasonably optimal and that the evaluation tasks were solvable by the expert policy and oracle planner within the maximum episode length. For manipulation tasks, we used manually scripted policies to collect datasets (rather than RL), which we found sufficient for generating datasets.\n\nWe would like to thank the reviewer again for raising important questions about OGBench. We hope that our response has addressed the reviewer's questions, and please feel free to let us know if there are any additional concerns or questions.\n\n[1] Lyu et al., ODRL: A Benchmark for Off-Dynamics Reinforcement Learning, NeurIPS 2024."}, {"comment_id": "KVNuxXCp0Z", "replyto": "JE4Gj3ndF6", "author_type": "authors", "reviewer": null, "comment": "* **Why are there no exploratory datasets for AntSoccer?**\n\nWe found that $\\texttt{explore}$ datasets are already very challenging for AntMaze (Table 2), so we decided not to provide separate $\\texttt{explore}$ datasets for AntSoccer, which is much harder than AntMaze. That being said, since we provide the exact commands to reproduce all the datasets, we believe it wouldn’t be difficult for a user to generate custom $\\texttt{explore}$ datasets for AntSoccer with the provided data-generation scripts.\n\n* **How exactly is the goal state provided to the agent in locomotion/manipulation tasks?**\n\nIn all tasks in our benchmark, the goal space is the same as the state space (L100); in other words, a goal is simply yet another state. As such, a goal state contains both the proprioceptive states (e.g., joint angles) and object positions (if they exist). However, when measuring success during evaluation, we consider only the $x$-$y$ position (in locomotion environments) or object positions (in manipulation environments) (though the goal state still contains the full information), following prior practices in offline goal-conditioned RL [3, 5, 6]. We have clarified this point by making a separate paragraph at the beginning of Section 7, and have described the detailed task success criteria in Appendix E.1.\n\n* **Regarding dataset optimality vs. coverage for offline RL/GCRL**\n\nThanks for asking this question! We were also a bit surprised by the results in Figure 3, which show that insufficient state coverage can lead to a *complete* failure of otherwise solvable tasks. While the current trend in robotics seems to favor collecting large, expert datasets (as the reviewer mentioned), there has also recently been an increasing number of studies investigating the role of dataset coverage for offline RL [7, 8, 9], which have made similar/relevant observations to ours. For instance, Park et al. [7] show that offline RL is mainly bottlenecked by test-time generalization rather than data quality, highlighting the importance of sufficient noise/coverage. In robotics, data coverage aspects are often considered in the form of having “recovery” transitions from suboptimal states in the dataset [8, 9], which can teach the agent how to recover from mistakes during evaluation. For example, the recent $\\pi_0$ generalist robot policy from Physical Intelligence (a robotics startup) uses diverse yet suboptimal data that contains recovery behaviors for post-training [9], and its technical report states that “training on only this high-quality data results in a brittle model” [9], which directly aligns with our observation in Figure 3. We hope that our controllable and reproducible data-generation scripts enable systematic and scientific studies of diverse aspects (e.g., coverage, suboptimality, scalability, diversity, etc.) of datasets for offline RL/GCRL.\n\nWe would like to thank the reviewer again for raising important questions about OGBench, and we believe the added limitation section as well as the clarifications have improved the paper substantially. Please let us know if there are any additional concerns or questions.\n\n[1] Frans et al., Powderworld: A Platform for Understanding Generalization via Rich Task Distributions, ICLR 2023. \\\n[2] Fang et al., Planning to Practice: Efficient Online Fine-Tuning by Composing Goals in Latent Space, IROS 2022. \\\n[3] Zheng et al., Stabilizing Contrastive RL: Techniques for Robotic Goal Reaching from Offline Data, ICLR 2024. \\\n[4] Ghugare et al., Closing the Gap between TD Learning and Supervised Learning -- A Generalisation Point of View, ICLR 2024. \\\n[5] Park et al., HIQL: Offline Goal-Conditioned RL with Latent States as Actions, NeurIPS 2023. \\\n[6] Myers et al., Learning Temporal Distances: Contrastive Successor Features Can Provide a Metric Structure for Decision-Making, ICML 2024. \\\n[7] Park et al., Is Value Learning Really the Main Bottleneck in Offline RL?, NeurIPS 2024. \\\n[8] Ke et al., CCIL: Continuity-based Data Augmentation for Corrective Imitation Learning, ICLR 2024. \\\n[9] Physical Intelligence, π0: A Vision-Language-Action Flow Model for General Robot Control, 2024, https://www.physicalintelligence.company/download/pi0.pdf"}, {"comment_id": "xY1VLt9h9Q", "replyto": "JE4Gj3ndF6", "author_type": "authors", "reviewer": null, "comment": "Thank you for the detailed review and positive and constructive feedback about this work. We especially appreciate the reviewer’s questions about limitations and the importance of coverage in data collection. Please find our response below.\n\n* **How new are the OGBench tasks?**\n\nWe provide $7$ types of tasks in this benchmark: AntMaze, HumanoidMaze, AntSoccer, Cube, Scene, Puzzle, and Powderworld. OGBench AntMaze is a direct extension of D4RL AntMaze, where we provide more types of datasets, mazes, and observation modalities (states/pixels) to pose further challenges. Powderworld is an existing environment for online RL [1], but we designed evaluation tasks and datasets to make it offline and goal-conditioned.\n\nThe other five types of tasks (HumanoidMaze, AntSoccer, Cube, Scene, Puzzle) were designed by us; especially, the manipulation tasks were entirely built from scratch using publicly available MuJoCo assets for objects and robots. That said, some tasks were conceptually inspired by existing online RL tasks (e.g., DMC quadruped-fetch for AntSoccer), as mentioned in the paper (L328).\n\n* **Limitations of OGBench**\n\nThanks for pointing this out. Following the suggestion, we have added a limitation section in the revised version of the paper (Appendix A):\n\n> **Limitations.** While OGBench covers a number of challenges in offline goal-conditioned RL, such as long-horizon reasoning, goal stitching, and stochastic control, there exist other challenges that our benchmark does not address. For example, all OGBench tasks assume that the environment dynamics remain the same between the training and evaluation environments. Also, although several OGBench tasks (e.g., Cube, Puzzle, and Powderworld) require unseen goal generalization to some degree, our tasks do not specifically test visual generalization to entirely new objects. Finally, we have made several trade-offs to reduce computational cost and to focus the benchmark on algorithms research at the expense of sacrificing realism to some degree (e.g., the use of the transparent arm in manipulation environments, the use of synthetic (yet fully controllable) datasets, etc.). Nonetheless, we believe OGBench can spur the development of performant offline GCRL *algorithms*, which can then help researchers develop scalable data-driven unsupervised RL pre-training methods for real-world tasks.\n\n* **How are the evaluation goals selected? / Evaluation limited to five goals**\n\nThanks for the question. For each task, we curated a small number (5) of *representative* evaluation goals that cover diverse goal-reaching behaviors of varying difficulty. For example, in $\\texttt{cube-double}$, we curated five evaluation goals to cover single pick-and-place, double pick-and-place, swapping, stacking, etc., and in $\\texttt{puzzle-4x6}$, we chose five goals with increasing levels of difficulty (from $6$ to $24$, in terms of the minimum number of button presses). On each of the five goals, we perform $50$ rollouts with slightly randomized poses/color orders/initial states to obtain the results.\n\nAs the reviewer pointed out, it is indeed possible to evaluate a goal-conditioned policy with any arbitrary goals. However, we chose to restrict to five evaluation goals in this benchmark (similarly to several prior works in offline goal/language-conditioned RL and robotics [2, 3, 4]). There are two main reasons for this choice: (1) using five pre-defined goals reduces computational cost for evaluation (which sometimes takes even longer than training in long-horizon environments!), as evaluation with fully randomized goals typically requires many more evaluation rollouts to derive statistically significant results, and (2) providing diverse *types* of pre-defined goals enables the practitioner to better analyze their algorithm’s strengths and weaknesses. For example, the results in Table 14 show that some methods are good at single pick-and-place but struggle with tasks involving double pick-and-place, which can guide practitioners on areas for improvement in their methods."}], "meta_review": {"metareview": "This paper introduces a dataset for goal-conditioned RL. Reviewers generally agree that it is an important problem and that the dataset could be useful. Reviewers also found the paper to be well-written. I agree with this sentiment and think that the dataset can complement popular RL datasets like D4RL. One concern raised was that only simple model classes (e.g., MLP and CNN) were raised. I think this is okay for a benchmark paper although it would have been great to see more common transformer-style architectures. I would urge authors to consider adding this for revision.\n\nOverall, I recommend acceptance.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "SFNqrHQTEP", "forum_url": "https://openreview.net/forum?id=SFNqrHQTEP", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "NExUME: Adaptive Training and Inference for DNNs under Intermittent Power Environments", "authors": ["Cyan Subhra Mishra", "Deeksha Chaudhary", "Jack Sampson", "Mahmut Kandemir", "Chita R. Das"], "abstract": "The deployment of Deep Neural Networks (DNNs) in energy-constrained environments, such as Energy Harvesting Wireless Sensor Networks (EH-WSNs), introduces significant challenges due to the intermittent nature of power availability. This study introduces NExUME, a novel training methodology designed specifically for DNNs operating under such constraints. We propose a dynamic adjustment of training parameters—dropout rates and quantization levels—that adapt in real-time to the available energy, which varies in energy harvesting scenarios.\n\nThis approach utilizes a model that integrates the characteristics of the network architecture and the specific energy harvesting profile. It dynamically adjusts training strategies, such as the intensity and timing of dropout and quantization, based on predictions of energy availability. This method not only conserves energy but also enhances the network’s adaptability, ensuring robust learning and inference capabilities even under stringent power constraints. Our results show a 6% to 22% improvement in accuracy over current methods, with an increase of less than 5% in computational overhead. This paper details the development of the adaptive training framework, describes the integration of energy profiles with dropout and quantization adjustments, and presents a comprehensive evaluation using real-world data. Additionally, we introduce a novel dataset aimed at furthering the application of energy harvesting in computational settings.", "keywords": ["Intermittent Computing", "Energy Harvesting", "Intermittency Aware Training", "Hardware-Software codesign"], "primary_area": "infrastructure, software libraries, hardware, systems, etc.", "pdf_url": "https://openreview.net/pdf?id=SFNqrHQTEP", "decision": "Accept (Poster)", "num_reviews": 4, "num_discussions": 12, "reviews": [{"review_id": "9gOJluAI8W", "reviewer": "Reviewer_P7Bk", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "The paper introduces NExUME, a framework addressing issues with DNN training for energy-constrained environments which can't guarantee a sufficient amount of power at all times, such as Energy Harvesting Wireless Sensor Networks. In order to optimise DNN-Training for these unstable conditions, NExUME relies on an estimation of the available resources. For this, a first-of-its-kind dataset containing energy harvesting traces and available computation libraries is introduced.\nThe training process reduces intermittency-related failures by treating the number of loop iterations as learnable parameters and task-fusions to meet energy budgets. Additionally, dynamic dropouts during execution ensure the completion of layers and dynamic quantization balances out the accuracy degradation. An adaptive regularization strategy prevents weights from being undertrained. Lastly, the authors introduce a task-scheduler that adjusts in real-time to the energy conditions estimated.", "strengths": "- Embedding energy variability into the training process is a novel idea.\n- Extends existing Neural Architecture Search for intermittent computing systems\n- Evaluation on SOTA datasets and DNNs (in the context of intermittent computing)", "weaknesses": "- Evaluation on SOTA datasets and DNNs (in the context of intermittent computing). While also a strength, it also raises a questions. The ML community has moved on to Attention and Transformers and large scale datasets. This paper does not discuss how such modern architectures can be deployed in the intermittent setting. \n\n- contribution over SOTA remains unclear. The paper cites numerous NAS frameworks for intermittent/MCU computing (and there are more like [1]) and a large body of work on intermittent execution of DNNs such as [2, 3] and numerous papers cited in the introduction. To me, the contribution over these remains unclear, as many aspects are also in these papers. \n\n- lack of SOTA baselines: The paper should compare to SOTA approaches. \n\n- statement such as \"Since we are the first work to propose a new training approach targeted for intermittent devices and inference optimizations\" should be toned down. Instead, please carefully explain your contributions over SOTA and compare them to SOTA baseline. \n\n- ablation study: accuracy and overhead if full power is available \n\n- Overall: this paper is better suited at a system conference, such as SenSys or MobiSys\n\n- the title is misleading, the paper is about more than DNN training. \n\n- BLE board does not have FeRAM and thereby not a classic board for intermittent computing. Why do the authors choose it? How does the intermittent part work here, especially QuantaTask?\n\n\n- On several occurrences, the paper is written (too) vaguely. E.g. \n - There is no hint as to how tasks are \"fused\" when executing multiple quanta would exceed the energy budget. To my understanding, the function in only explained in the appendix as part of the source code but not the text.\n - Dropout rates are adjusted on \"specific\" criteria. Even though the appendix provides details, this vague style of writing reads weirdly and is better served with examples\n\n- As mentioned on several occasions, a big part of the presented work is the availability of the database of DynAgent which also contains hardware-information, yet only 2 different microcontrollers are used for evaluating the framework. Testing a broader variety of systems seems sensible here\n\n- Drawbacks like up to 34% increased instruction count and up to 17% increased memory bandwidth usage are stated but hardly discussed or put into perspective. While this is not surprising for intermittent computing, the numbers should still be discussed and be compared to other approaches.\n\n- Typos: Figure 3: Sensitivity and ablation study. DN is DynNAS, DF is FynFit, and DI is DynInfer: FynFit -> DynFit\n\n- with a Pixel-5 phone as the host device: does this matter? Any device should do the job.\n\n- the paper has a section LIMITATIONS AND DISCUSSION which also discussed limitations, such as the runtime overhead. The last two sentences, however, read a bit bumpy and should be streamlined and also discussed (and not just stated). \n\n\n[1] Edgar Liberis, Łukasz Dudziak, and Nicholas D. Lane. 2021. ΜNAS: Constrained Neural Architecture Search for Microcontrollers. In Proceedings of the 1st Workshop on Machine Learning and Systems (EuroMLSys '21). Association for Computing Machinery, New York, NY, USA, 70–79. https://doi.org/10.1145/3437984.3458836\n\n[2] Chih-Hsuan Yen, Hashan Roshantha Mendis, Tei-Wei Kuo, and Pi-Cheng Hsiu. 2023. Keep in Balance: Runtime-reconfigurable Intermittent Deep Inference. ACM Trans. Embed. Comput. Syst. 22, 5s, Article 124 (October 2023), 25 pages. https://doi.org/10.1145/3607918\n\n[3] C. -H. Yen, H. R. Mendis, T. -W. Kuo and P. -C. Hsiu, \"Stateful Neural Networks for Intermittent Systems,\" in IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems, vol. 41, no. 11, pp. 4229-4240, Nov. 2022, doi: 10.1109/TCAD.2022.3197513", "questions": "* the BLE board does not have FeRAM and thereby not a classic board for intermittent computing. Why do the authors choose it? How does the intermittent part work here, especially QuantaTask?\n\n* what are the contributions over SOTA?\n\n* what is the performance compared SOTA baselines?\n\n* can you consider a different title?"}, {"review_id": "5Wej5X5dcF", "reviewer": "Reviewer_qDus", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "The paper introduces a framework designed to enable consistent and accurate deep neural network inference on energy-harvesting wireless sensor networks that operate under intermittent power conditions. This framework addresses the challenges of unreliable energy supply and computational limitations in such environments. The proposed framework uses energy variability aware network architecture search, dynamic training optimizations, and an intermittency-aware task scheduler to adapt DNN computations based on real-time energy availability, in order to meet service level objectives (SLOs) in resource-constrained settings.", "strengths": "The paper studies an interesting and important problem of enabling reliable DNN inference in energy-harvesting wireless sensor networks. The writing is overall clear and well-organized. The motivations, methods, and experimental findings are easy to follow.", "weaknesses": "The proposed method relies on detailed profiling of the hardware to model energy consumption, computational capabilities, and memory footprint. This process can be time-consuming and complex, requiring extensive micro-profiling. \n\nIn DynInfer, an energy-aware priority scheduling heuristic is used. With no theoretical analysis of its performance compared to optimal scheduling solution, its scheduling optimality is hard to estimate.\n\nThe explanations of some techniques in the methods section, particularly within the DynFit and DynInfer components, remain at a high level, lacking depth in technical specifics. For example, while the dynamic dropout and quantization strategies in DynFit are introduced, there is limited detail on how dropout rates and quantization levels are adjusted based on energy profiles or how these adjustments differ from standard implementations. Additionally, the methods used in each component lack a sense of innovation, as they seem to be a simple use of existing techniques without substantial enhancements. \n\nThe impact of under-trained and overfitting weights requires further examination. More frequent updates of certain weights do not necessarily lead to \"overfitting,\" and, conversely, infrequent updates do not inherently imply \"underfitting.\" From a layer perspective, the effect of varying update frequencies on individual weights may be limited, suggesting that this issue may be less impactful than indicated. \n\nThe experiments mainly focus on accuracy improvement. Other performance metrics, such as energy consumption, latency, computational overhead, and the number of power failures or SLO violations, are not extensively analysed. \n\nThe experiments are conducted on relatively small datasets and models.", "questions": "1. How much of the resource of the method is used or what is its time complexity?\n\n2. As a sub-optimal scheduling solution, how would its scheduling performance to be ensured?\n\n3. What is the rationale behind the overall method design?\n\n4. Whether the accuracy is more important than the other performance metrics in your design?"}, {"review_id": "4ddLqWwXxZ", "reviewer": "Reviewer_GUp2", "rating": 6, "confidence": 4, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "The paper presents NExUME, a novel framework for training and deploying deep neural networks (DNNs) on energy-harvesting micro-computers with intermittent power. The authors introduce dynamic dropout rates and quantization levels that adapt based on real-time energy availability, improving the accuracy and robustness of DNNs in constrained power settings. The paper demonstrates NExUME's efficacy in optimizing both training and inference phases through extensive experiments, showcasing significant accuracy gains over traditional approaches in intermittently powered environments. Additionally, the introduction of a unique dataset to facilitate further research on energy-harvesting applications is a noteworthy contribution.", "strengths": "1. The paper addresses an important challenge in deploying DNNs on resource-constrained, intermittently powered devices, an area that is underexplored in current literature. By incorporating real-time energy-aware adaptations, this work proposes a unique and valuable solution.\n2. The work is thorough, presenting a well-structured methodology, clearly defined optimization functions, and a series of experiments across various datasets and hardware platforms. The choice of energy-aware dropout and quantization strategies tailored to intermittent environments is both innovative and well-validated.\n3. The paper is well-written, with each component of the proposed framework (DynFit, DynInfer) and the optimization strategies clearly explained. Figures and tables effectively support the results and comparisons.\n4. The proposed approach has broad implications for real-world applications in energy-limited environments, such as remote monitoring and IoT systems, where consistent power is unavailable. The improvements in accuracy (6-22%) and the novel dataset enhance the significance and impact of the research.", "weaknesses": "1. The experiments, while comprehensive, rely on specific hardware configurations that may not be accessible for replication. The reliance on components like MSP430FR5994 and certain energy-harvesting setups may limit reproducibility.\n2. While the paper compares NExUME to iNAS and other energy-aware methods, it lacks a detailed comparison with additional state-of-the-art adaptive or intermittent DNN training techniques. Including broader comparisons could enhance the validation of its claims.\n3. The paper mentions challenges with larger networks and datasets. However, there is limited discussion on potential approaches to address these limitations, which would be valuable for practitioners aiming to scale this approach.\n4. The profiling is based on conservative estimates, which, while practical, may not be universally applicable. Further analysis of the impact of profiling variations on model performance could strengthen the evaluation.", "questions": "1. Can the authors clarify the robustness of NExUME across various hardware platforms beyond those tested? Would modifications be required for different types of microcontrollers or energy-harvesting setups?\n2. How does NExUME handle environments with extremely low or sporadic energy levels, where consistent dropout and quantization adjustments may not be feasible?\n3. Can the authors provide more detail on the potential effects of overfitting introduced by DynFit’s dropout variations? Would techniques like dropout scheduling help mitigate this?"}, {"review_id": "FrTQYkrVub", "reviewer": "Reviewer_tANu", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 2, "summary": "This paper presents NExUME, a training methodology that is designed to cater to intermittent power energy harvesting systems. The authors proposed two key contributions which are (1) a new method for training where one can dynamically adjust dropout rate and quantization levels to cater to varying energy availability in the EH system, and (2) a task scheduler that optimizes task completion in EH systems. The authors also contribute a machine status monitoring dataset. NExUME shows a 6 to 22 % accuracy improvement over existing baselines on simple ML tasks. However, at the same time, NExUME incurs a 5% overhead in computation, an increase in the number of instructions ranging between 11.4%- 34.2%, and an increase in memory bandwidth from 6 to 17%.", "strengths": "+ This paper is the first to present novel training and inference methods that take dropouts and quantization into account in the context of energy harvesting systems. \n+ A new machine monitoring dataset \n+ The results shown are good when compared to the baselines presented in the paper.\n+ Writing and the presentation of the work are clear. \n+ Choice of datasets is appropriate given that the work is designed for resource-constrained embedded systems. \n+ Decent ablation studies \n+ Actual implementation of such a system is not trivial.", "weaknesses": "- The idea of dynamically adjusting dropout rates and quantization levels is not novel. It is novel in the context of EH systems. \n- Energy-aware scheduling is not novel. \n- The quantification of overheads is done. However, its implications are not discussed. The range in terms of % is indicated. However, how does it vary with the datasets? \n- Some of the existing work in intermittent systems are not compared such as ePerceptive: energy reactive embedded intelligence for batteryless sensors and Zygarde: Time-Sensitive On-Device Deep Inference and Adaptation on Intermittently-Powered Systems\n- Details on the machine status monitoring dataset are missing in Sec 4.3 How are R1, R2, and R3 different? What are their RPM speeds? What is S1 and S2? \n- Only accuracy results are shown. It is also important to know the latency/inference. and memory requirements of the system. \n- DynFit comprises adjusting quantization levels and dropouts. In the ablation studies, it is unclear which of these is bringing more benefits to the system.", "questions": "Details on the machine status monitoring dataset are missing in Sec 4.3 How are R1, R2, and R3 different? What are the RPM speeds? What is S1 and S2?\n\nEnergy-aware scheduling is not novel. Can you clarify your novelty with respect to existing scheduling algorithms? \n\nThe authors claim that their machine status monitoring is the first of its kind. Can they clarify what datasets already exist and how the dataset introduced in the paper is different?"}], "discussions": [{"comment_id": "vu3ztv9Apd", "replyto": "SFNqrHQTEP", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank all the reviewers for their thoughtful, constructive and encouraging feedback, which has greatly helped us improve our work. We thanks the PC members and area chairs for their dedicated efforts. We have carefully considered all comments and made revisions to the manuscript, which the reviewers have graciously accepted. Below, we summarize the key points:\n\n**1. Our Novelty**\n- **Integration of Energy Variability into Training and Inference**: We introduce NExUME, a novel framework that uniquely integrates energy variability awareness directly into both the training (DynFit) and inference (DynInfer) processes. Unlike existing methods, our approach enables DNNs to adapt dynamically to real-time energy conditions in energy-harvesting (EH) environments.\n- **Adaptive Regularization Strategy**: We introduce an adaptive regularization technique to prevent underfitting and overfitting caused by uneven weight updates due to dynamic dropout, ensuring model robustness in EH settings.\n- **Novel Task Fusion Mechanism**: Our scheduler includes a novel task fusion mechanism that combines smaller tasks into larger atomic units, optimizing execution under intermittent power and minimizing checkpointing overhead—a challenge unique to EH systems.\n\n**2. Additional Evaluation**\n- **Expanded Hardware Platforms**: We have extended our experiments to include additional microcontrollers such as ESP32 S3 Eye, STM32H7, and Raspberry Pi Pico. This demonstrates NExUME's applicability and robustness across various hardware configurations.\n- **Comparison with SOTA Methods**: We have included comparisons with recent SOTA methods, including Keep in Balance, Stateful Neural Networks, ePerceptive, and DynBal. Our results show that NExUME consistently outperforms these methods in terms of accuracy and energy efficiency across multiple datasets and platforms.\n- **Detailed Dataset Information**: We have provided more details about the machine status monitoring dataset, clarifying the experimental setup and the significance of each class in the dataset. This dataset is the first of its kind to capture machine status monitoring data across multiple operating conditions using EH sensors.\n- **Comprehensive Metrics**: In addition to accuracy, we have included evaluations on energy efficiency (MOps/Joule), overheads during training and inferece, and memory requirements. We have also conducted an ablation study.\n- **Overhead Analysis**: We have provided a detailed discussion of the computational overhead introduced by our methods, noting that while there is an increase in instruction count and memory bandwidth usage, the trade-off results in significant improvements in accuracy and energy efficiency, which is acceptable given the constraints of intermittent computing.\n\n**3. Why ICLR is the Right Venue**\n- **Interdisciplinary Contribution**: Our work lies at the intersection of machine learning and systems engineering, addressing challenges that require interdisciplinary solutions. We believe that the ICLR community values such contributions that push the boundaries of ML deployment in real-world environments.\n- **Promoting Sustainable and Green Computing**: As the ML community continues to develop large-scale models, there is a growing emphasis on sustainability and reducing the environmental impact of computing. Our research contributes to this goal by enabling intelligent computations on devices powered by renewable energy sources, which aligns with the interests of the ICLR community.\n- **Encouraging Co-Design of Algorithms and Hardware**: By co-designing algorithms, models, and hardware, we can achieve more efficient and effective solutions for highly dynamic and constrained systems. This philosophy aligns with the broader ML community's interest in holistic approaches to problem-solving.\n- **Relevance to ML Deployment and Practical Applications**: Our work introduces novel machine learning methodologies tailored for emerging hardware platforms. We anticipate that presenting our research at ICLR will inspire others in the community to explore and develop algorithms that are not only theoretically sound but also practically deployable on next-generation, sustainable hardware.\n- **Stimulating Further Research**: We believe that our work will encourage the ML community to build better algorithms, models, and hardware that are optimized for deployment in energy-harvesting environments, contributing to the advancement of sustainable, intelligent systems.\n---\nWe appreciate the reviewers' recognition of our efforts and their acknowledgment of the contributions our work makes to the field. We believe that our paper offers valuable insights and advancements that are highly relevant to the ICLR audience. We look forward to the opportunity to share our work with the community and take the research on intermittent/sustainable AI forward."}, {"comment_id": "9ziKu65jSS", "replyto": "7ZdWBRFupN", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for their constructive and encourging feedback. We will work towards addressing the minor changes in the final version of the paper. We look forward to the opportunity to share our work with the ICLR community."}, {"comment_id": "7ZdWBRFupN", "replyto": "5Wej5X5dcF", "author_type": "reviewer", "reviewer": "Reviewer_qDus", "comment": "I am happy with the responses provided by the authors. Compared to the previous version, the new draft shows significant improvements. As a result, I have increased my scores across various items and the overall ranking to acknowledge the authors' efforts."}, {"comment_id": "CQXqA1sGu5", "replyto": "cpfnU3bKEP", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank the reviewer for their encouraging feedback. \nTo address the remaining minor points:\n\n1. **Figure 3 Font Size:**\n\n - We will rework Figure 3 to increase the font sizes and enhance readability. This adjustment will ensure that all details are easily discernible in the final version of the paper.\n\n2. **Appendix Details:**\n\n - We will expand the appendix to include details about the new baselines, the profiling mechanisms adapted for the new hardware platforms, and provide general guidelines on micro-profiling. This addition will help readers understand the process of benchmarking new hardware and grasp its characteristics more effectively.\n\n3. **Energy Consumption Ablation Study:**\n\n - We acknowledge the importance of comparing accuracy at matching energy consumption levels. In the final version, we will include an ablation study that provides a breakdown of energy consumption, illustrating how each component contributes to the total energy usage. This will allow for a direct comparison of accuracy relative to energy consumption, aligning with the baselines' energy profiles.\n\nRegarding your observation on energy consumption:\n\n- While our approach shows higher energy efficiency—achieving a higher number of operations per Joule—it does not necessarily result in higher overall energy consumption. The overhead introduced by our methods is less than 5% in terms of operations. Higher energy efficiency in our context means performing more effective computations with less energy, thereby improving the ratio of useful work to energy consumed. We apologize for any confusion and will clarify these metrics and insights more explicitly in the final paper.\n\n---\n\n**Why ICLR Is the Right Venue for This Work**\n\nWe firmly believe that ICLR is an impactful venue for this research for several reasons:\n\n- **Bridging Machine Learning and Systems:** Our work lies at the intersection of machine learning and systems engineering. It demonstrates how advances in ML training and inference can be effectively adapted for deployment in resource-constrained, intermittent environments—a challenge that requires interdisciplinary solutions.\n\n- **Promoting Sustainable and Green Computing:** As the ML community continues to develop large-scale models, there is a growing emphasis on sustainability and reducing the environmental impact of computing. Energy-harvesting systems, especially when deployed at scale, can significantly reduce embodied carbon. Our research contributes to this goal by enabling intelligent computations on devices powered by renewable energy sources.\n\n- **Encouraging Co-Design of Algorithms and Hardware:** We believe that optimizing software or hardware in isolation is insufficient for the challenges posed by highly dynamic and constrained systems. By co-designing algorithms, models, and hardware, we can achieve more efficient and effective solutions. This philosophy aligns with the broader ML community's interest in holistic approaches to problem-solving.\n\n- **Relevance to the ICLR Community:** Our work introduces novel machine learning methodologies tailored for emerging hardware platforms. We anticipate that presenting our research at ICLR will inspire others in the community to explore and develop algorithms that are not only theoretically sound but also practically deployable on next-generation, sustainable hardware.\n\n---\n\nWe thank the reviewer again for their constructive feedback and look forward to the opportunity to share our work with the ICLR community."}, {"comment_id": "cpfnU3bKEP", "replyto": "9gOJluAI8W", "author_type": "reviewer", "reviewer": "Reviewer_P7Bk", "comment": "Firstly, many thanks to the authors for their hard work. I very much appreciate the changes and how detailed they address my comments. Overall, my concerns have now been addressed, and I am raising my score accordingly. \n\nOnly, minor things remain (and can wait for a camera-ready version / later submission from my perspective):\n* Figure 3: font size is too small\n* Appendix: still limited to the original baselines\n* The paper shows higher accuracy but also higher energy consumption than baselines. I am wondering, if, as ablation study, it is possible to see the accuracy for an energy consumption that matches the consumption of the baselines. \n\nNonetheless, despite the strong and interesting system results, I still unsure if the ICLR community will appreciate this paper or if the paper is better suited at a systems or computer engineering conference. For example, out of the baselines only one paper with published at an AI conference, all others are at system or computer engineering venues if I am not mistaken."}, {"comment_id": "X0unI3CUAz", "replyto": "FrTQYkrVub", "author_type": "authors", "reviewer": null, "comment": "We thank the reviewer for the encouraging feedback. \nWe used Million Ops/Joule (MOPs/J) to measure the energy efficiency and have mentioned it in the text. To keep the caption in one line, we removed it from the caption. In the final draft of the paper, we'll add the unit to the table caption."}, {"comment_id": "tCGSb5rW85", "replyto": "FrTQYkrVub", "author_type": "reviewer", "reviewer": "Reviewer_tANu", "comment": "I am happy with the rebuttal. The authors have done a diligent job in clarifying the novelty and did extra experiments to compare against other SOTA baselines. It is more of an engineering type of work but is important in the context of embedded ML on intermittent systems. I am happy to raise my score to 6. \n\nMinor comment: It would be good to add the units of measurement in the caption of Table 2."}, {"comment_id": "oY7llL6BKs", "replyto": "FrTQYkrVub", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank the reviewer for the valuable feedback. Below, we address your questions and concerns. Please go over the revised manuscript for a detailed overview of the changes.\n\n1. **Details on the Machine Status Monitoring Dataset:**\n - **Differences between R1, R2, and R3:** R1, R2, and R3 correspond to three different spindle rotation speeds of the Bridgeport machine without any load (no job). Specifically, R1 is 100 RPM, R2 is 200 RPM, and R3 is 300 RPM.\n - **Explanation of S1 and S2:** We apologize for the oversight. There was a typographical error; S1 and S2 should have been SJ and SI. SJ stands for \"Spindle under Job,\" where the machine is operating with an active job, and SI stands for \"Spindle Idle,\" where the machine is powered on but not performing any operation.\n \n2. **Novelty of Energy-Aware Scheduling:**\n - **Our Novel Contributions:**\n - **Integration with Intermittent Environments:** While energy-aware scheduling exists, our scheduler (DynInfer) is specifically designed for intermittent power environments, accounting for atomicity constraints due to unpredictable power failures.\n - **Task Fusion Mechanism:** We introduce a novel task fusion strategy that combines smaller tasks into larger atomic units to optimize execution within the available energy budget, minimizing checkpointing overhead—a challenge unique to intermittent systems.\n - **Real-Time Energy Adaptation:** Our scheduler dynamically adjusts to real-time energy availability, which is critical in energy-harvesting systems with fluctuating power, and is not typically addressed in traditional energy-aware schedulers.\n - **Differentiation from Existing Algorithms:** Existing schedulers often assume stable energy supply and do not account for the atomic execution requirements of intermittent systems. Our approach addresses these gaps, providing a scheduling solution tailored to the unique challenges of EH environments.\n\n3. **Clarification on the Machine Status Monitoring Dataset:**\n - **Existing Datasets:** Prior datasets, such as the Case Western Reserve University Bearing Data Center dataset, focus on fault detection in machinery but do not provide data for predictive maintenance under varying operational conditions.\n - **Our Dataset's Novelty:**\n - **First of Its Kind:** Our dataset is the first to capture machine status monitoring data across multiple operating conditions (different RPMs, idle, under load) using energy-harvesting sensors.\n - **Multi-Sensor Data:** We include data from multiple types of sensors (e.g., accelerometers with different sampling rates) placed at various locations on the machine, providing a rich dataset for developing and testing algorithms in EH environments.\n - **Facilitating Research:** This dataset fills a gap in the field by enabling research on predictive maintenance and monitoring in industrial settings with intermittent power.\n\n4. **Novelty of Dynamically Adjusting Dropout Rates and Quantization Levels:**\n - **Our Contribution:** While the concept of adjusting dropout rates and quantization levels is known, our novelty lies in integrating these adjustments directly into both the training and inference processes based on real-time energy availability in EH systems.\n - **Energy Variability Awareness:** Existing methods do not incorporate energy profiles into the training loop. Our approach trains the DNN to adapt to energy fluctuations, which is critical for intermittent environments.\n - **Adaptive Regularization Strategy:** We introduce an adaptive regularization technique to handle the challenges posed by dynamic adjustments, ensuring model robustness—a contribution not present in prior work.\n\n5. **Implications of Overheads Not Discussed:**\n - **Overhead Analysis:** In the revised manuscript, we have added a detailed discussion on how the overheads vary with different datasets and models (Section 4.1).\n - **Dataset Variations:** The increase in instruction count and memory bandwidth usage varies depending on the complexity of the dataset and the model size. We provide a breakdown of these variations in the updated results.\n - **Trade-Off Justification:** Despite the overheads, the significant improvements in accuracy and energy efficiency justify the trade-off, especially in the context of resource-constrained EH systems.\n\n6. **Comparison with Existing Work:**\n - **Included Comparisons:** We have now included ePerceptive, DynBal, Stateful in our baselines and have compared our approach against these methods in the revised manuscript (Section 4.2).\n - **Results:** Our experiments demonstrate that NExUME outperforms these methods in terms of accuracy and energy efficiency, highlighting the effectiveness of our proposed techniques.\n---\n\nWe hope that these revisions and clarifications address your concerns."}, {"comment_id": "y5pMXYFDgO", "replyto": "4ddLqWwXxZ", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank the reviewer for the thoughtful and constructive feedback. Your comments have been invaluable in improving our work. Below, we address your questions and concerns. Please go over the revised manuscript for a detailed overview of the changes.\n\n1. **Robustness Across Various Hardware Platforms:**\n\n - **Applicability:** NExUME is designed to be hardware-agnostic and can be applied to various microcontrollers and energy-harvesting setups.\n - **Modifications:** Minimal adjustments are needed when deploying on different hardware; primarily, profiling the new hardware to obtain energy and computational characteristics.\n - **Expanded Experiments:** In the revised manuscript, we have included additional microcontrollers like ESP32 S3 Eye, STM32H7, and Raspberry Pi Pico to demonstrate NExUME's robustness across different platforms.\n\n2. **Handling Extremely Low or Sporadic Energy Levels:**\n\n - **Minimum Viable Configuration:** NExUME implements a configuration with maximum dropout rates and minimum quantization bit-widths to operate under very low energy conditions.\n - **Task Prioritization:** Essential tasks are prioritized, and non-critical computations are deferred or skipped.\n - **Predictive Models:** We employ predictive energy harvesting models to anticipate energy availability and adjust computations proactively.\n - **Low-Power Modes:** In extreme cases, the system enters a low-power standby mode until sufficient energy is available.\n\n3. **Overfitting Due to DynFit's Dropout Variations:**\n\n - **Adaptive Regularization Strategy:** We monitor weight update frequencies to detect and mitigate overfitting or underfitting.\n - **Dropout Scheduling:** We adjust dropout rates over time based on training progress and energy profiles, similar to dropout scheduling techniques.\n - **Mitigation:** These strategies help prevent overfitting introduced by dynamic dropout variations, maintaining model performance.\n\n\n4. **Limited Hardware Configurations and Reproducibility:**\n\n - **Expanded Hardware Platforms:** We have included additional microcontrollers (ESP32 S3 Eye, STM32H7, Raspberry Pi Pico) in our experiments to demonstrate applicability across various hardware.\n - **Open-Source Tools:** We provide guidelines, profiling tools, and code in our supplementary materials to facilitate replication on different hardware setups.\n - **Generalizability:** NExUME's design is modular and adaptable, requiring only minimal profiling for new platforms.\n\n5. **Comparison with Additional State-of-the-Art Methods:**\n\n - **Included Comparisons:** We have added comparisons with recent state-of-the-art methods, including Keep in Balance, Stateful Neural Networks, and ePerceptive.\n - **Results:** Our revised manuscript includes detailed evaluations showing that NExUME outperforms these methods in accuracy and energy efficiency across various datasets and platforms.\n\n6. **Challenges with Larger Networks and Datasets:**\n\n - **Acknowledgment:** We recognize the limitations when scaling to larger networks and datasets due to resource constraints in intermittent environments.\n - **Potential Solutions:** In the revised manuscript, we discuss strategies like advanced model compression, lightweight architectures, and hierarchical models to address scaling challenges.\n - **Future Work:** Extending NExUME to support larger models and datasets is identified as an area for future research.\n\n7. **Impact of Profiling Variations on Model Performance:**\n\n - **Robustness to Profiling Variations:** We conducted sensitivity analyses to assess how variations in profiling affect performance.\n - **Conservative Estimates Justification:** Using conservative estimates ensures that tasks complete within worst-case energy scenarios, enhancing system reliability.\n - **Adaptability:** NExUME can adjust to different profiling data, and we provide methods to update models if significant discrepancies are observed between estimated and actual performance.\n\n---\n\nWe hope that these revisions and clarifications address your concerns."}, {"comment_id": "mhPaIOv5HR", "replyto": "5Wej5X5dcF", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank the reviewer for the valuable feedback. Below, we address your questions and concerns. Please go over the revised manuscript for a detailed overview of the changes. \n\n1. **Resource Usage and Time Complexity:**\n - **DynFit Complexity:** Time complexity is \\( O(N . T) \\), where \\( N \\) is the number of weights and \\( T \\) is the number of training iterations. Overhead from monitoring and adjusting dropout rates is minimal. We have detailed the analysis in the revised version of the paper. \n - **DynInfer Complexity:** Scheduling algorithm has time complexity \\( O(N log N) \\) due to sorting tasks. This is acceptable for real-time applications.\n - **Resource Usage:** Additional computational overhead is within < 5\\% of standard training and inference, as detailed in Section 4.1.\n\n2. **Ensuring Scheduling Performance of the Sub-Optimal Solution:**\n - **Empirical Validation:** We compared our scheduler against optimal solutions on smaller instances; it achieves within 95% of the optimal task completion rate.\n - **Theoretical Analysis:** Scheduler prioritizes tasks based on an effective priority metric that balances task importance and energy consumption. This is an emprical metric, and can be tweaked as per the application need. Infact, each kernel can be assigned importance and the scheduler will pick them in the order of importance. Scheduler adapts to real-time energy fluctuations, ensuring critical tasks are prioritized and executed within energy constraints. Many suboptimal versions of this have been used in EH systems.\n\n3. **Rationale Behind the Overall Method Design:**\n - **Integration of Energy Variability:** Our method integrates energy variability awareness into both training (DynFit) and inference (DynInfer), enabling DNNs to adapt dynamically to real-time energy conditions.\n - **Holistic Approach:** By addressing both training and inference phases, we provide a comprehensive solution for intermittent environments.\n - **Task Atomicity and Scheduling:** QuantaTasks and the energy-aware scheduler ensure computations complete without interruption, minimizing checkpointing overhead.\n\n4. **Balancing Accuracy and Other Performance Metrics:**\n - While accuracy is crucial, we also optimize energy consumption, latency, and computational overhead. We show Energy efficiency as a metric in Table 2.\n - We ensure the DNN meets SLOs, balancing accuracy with timely and efficient execution under energy constraints.\n\n5. **Reliance on Detailed Hardware Profiling:**\n - Profiling is a one-time process per hardware platform, essential for accurate energy modeling in intermittent environments. This becomes even important where we are designing targeted hardware and applications. In the appendix we show some of the kernels we use to profile the micro-controllers. We are going include a version of the profiling code into our repo, but for microcontrollers the design, profiling need and metrics would require expert analysis and design. \n\n6. **High-Level Explanations and Lack of Technical Specifics:**\n - Expanded technical details in Sections 3.1 and 3.2.\n - **DynFit:** Provided equations showing how dropout rates and quantization levels are adjusted based on energy availability; explained the adaptive regularization strategy.\n - **DynInfer:** Included a formal definition of task fusion; provided examples to illustrate the scheduling process.\n\n7. **Perceived Lack of Innovation in Methods:**\n - Our contributions include:\n - **Dynamic Adjustment Based on Energy Profiles:** Methods adjust dropout and quantization in real-time based on energy availability, integrating energy variability into training.\n - **Adaptive Regularization Strategy:** Introduced a new strategy to prevent underfitting and overfitting caused by uneven weight updates due to dynamic dropout.\n - **Task Fusion in Scheduling:** Scheduler includes a novel task fusion mechanism to optimize execution under intermittent power.\n\n8. **Impact of Under-Trained and Overfitting Weights:**\n - Elaborated on the adaptive regularization strategy in 3.1.1\n - **Monitoring Update Frequencies:** We monitor weight updates to identify under-trained or overfitting weights.\n - **Adjustment Mechanisms:** Adjust dropout rates and apply L2 regularization to ensure balanced training.\n - **Effectiveness:** This strategy helps maintain model performance despite dynamic adjustments during training.\n\n9. **Experiments on Small Datasets and Models:**\n - The datasets used are representative of real-world applications in resource-constrained environments, EH devices typically operate with lightweight models due to hardware/energy limitations.\n - Extending our methods to larger datasets and models is an area for future research, especially for EH applicarion with large energy footprint like solar powered urban mobility.\n\nWe hope these clarifications address your concerns. Thank you again for your valuable feedback."}, {"comment_id": "C2X42M0tZ0", "replyto": "9gOJluAI8W", "author_type": "authors", "reviewer": null, "comment": "We sincerely thank the reviewer for the thoughtful and constructive feedback, which has greatly improved our work. Below, we address each of your points in detail. Please go over the revised manuscript for a detailed overview of the changes.\n\n**Contribution over State-of-the-Art:**\n- We have compared NExUME with recent state-of-the-art methods, including DynBal, Keep in Balance, and Stateful Neural Networks, and included detailed discussions on how our approach differs from and improves upon these methods.\n- We highlighted the unique integration of energy variability awareness directly into both the training and inference processes, which is not addressed by existing methods. While iNAS focuses on constrained NAS for microcontrollers in intermittent settings, it does not account for real-time energy fluctuations during training and inference. Similarly, Keep in Balance and Stateful Neural Networks primarily address inference optimizations without integrating energy variability into the training process.\n- We emphasized our novel adaptive training mechanisms (DynFit) and intermittency-aware scheduling with task fusion (DynInfer), which collectively provide a holistic solution for intermittent environments.\n- We have faithfully re-implemented the aofrementioned methods and included them in baseline. Results on accuracy and energy efficiency (which we believe is a crucial metric) are shown in the Results section.\n\n**Evaluation on SOTA Datasets and Modern Architectures:**\nWe acknowledge that modern architectures like Transformers have gained prominence. However, deploying such architectures on ultra-low-power, energy-harvesting devices is currently impractical due to their high computational and memory demands. Our work focuses on enabling efficient and reliable deployment of lightweight DNNs in intermittent environments.\nIn the Limitations and Discussion section.\nAs IoT scales, the embodied carbon from silicon manufacturing and battery usage poses significant challenges. Addressing this needs perennial sustainable/EH devices running compact DNNs for specific tasks, emphasizing EH-aware DNN architectures, EH-aware training strategies and EH aware inference schduling specifically optimized for tiny devices, which is our core focus.\n\n**BLE Board Does Not Have FeRAM; Why Choose It?**\n- We selected the Arduino Nano 33 BLE Sense to demonstrate NExUME's applicability on a widely used microcontroller platform, even though it lacks FeRAM. We store intermediate data in flash memory, which is non-volatile, showing that our methods can be applied to general-purpose microcontrollers commonly used in IoT applications.\n- For devices without FeRAM, we utilize flash memory for checkpointing during power interruptions. While this introduces additional overhead, our scheduling algorithm (DynInfer) and task design (QuantaTask) minimize this by ensuring tasks can complete within the available energy budget.\n- QuantaTasks are carefully profiled to be atomic units of computation that can complete without interruption, even on devices without FeRAM, by adjusting task sizes based on the energy harvesting profile and the device's energy storage capacity.\nWe have updated the manuscript in Implementation Details to explain this in more detail.\n\n**Writing:**\n- In DynInfer, we have added a formal definition of task fusion and provided an example diagram to clearly explain how tasks are fused.\n- We have elaborated on how dropout rates are adjusted based on energy availability, providing equations and explanations.\n\n**Limited Evaluation:**\nWe agree that evaluating on a broader range of hardware platforms strengthens the work. We have:\n- Expanded our experiments to include additional microcontrollers: ESP32 S3 Eye, STM32H7, and Raspberry Pi Pico.\n- Provided results on these platforms in Table 3, demonstrating NExUME’s applicability across various hardware configurations.\n\n**Increased Instruction Count and Bandwidth:**\nWe have:\n- Provided a detailed discussion of these overheads in Section 4.1.\n- Compared the overheads to other approaches, noting that while there is an increase in instruction count and memory usage, the trade-off results in significant improvements in accuracy and energy efficiency under intermittent power conditions.\n- Emphasized that the overhead is acceptable given the constraints of intermittent computing and the benefits provided by our methods.\n\n**Limitations and Discussion Section:**\nWe have revised the Limitations and Discussion section to improve clarity, coherence, and provide more insight into the importance of smaller DNNs and sustainability. \n\n**Different Title**\nWe considering chaing the title to:\n**\"NExUME: Adaptive Training and Inference for Deep Neural Networks under Intermittent Power Environments.\"**\nWe believe this title better reflcts our idea and we will discuss the possibility of officially changing it with the committee.\n\nWe hope that these revisions address your concerns."}, {"comment_id": "082ISe4sSz", "replyto": "9gOJluAI8W", "author_type": "reviewer", "reviewer": "Reviewer_P7Bk", "comment": "After reading the other reviews, I see that my review has the most negative scores. However, I see that other reviews point out similar issues and for me, these do not justify a higher score. Thus, for now, I am sticking to my scores and hope for a rebuttal that answers my open questions."}], "meta_review": {"metareview": "**Summary:** The paper proposes NExUME, a framework for training DNNs in energy-harvesting environments with intermittent power. It integrates real-time energy variability into DNN training via dynamic dropout rates, quantization adjustments, and a task scheduler to optimize computations under energy constraints. Results show 6–22% accuracy improvement over baselines with modest computational overhead.\n\n\n**Strength:** \n1. The integration of energy variability into the DNN training process, enhancing adaptability in energy-harvesting systems, is novel.\n\n\n2. The proposed framework achieves significant accuracy improvements under specific energy constraints, as validated across multiple datasets and hardware platforms.\n\n\n3. In addition to the framework, this work also introduces a first-of-its-kind machine monitoring dataset for research in intermittent computing.\n\n\n**Weakness:**\n1. The techniques like dynamic dropout and energy-aware scheduling are extensions of existing concepts.\n\n2. The proposed method is mainly evaluated on scale-scale models and datasets, lacking sufficient evaluation of its scalability to larger models and modern architectures, such as transformers.\n\n3. SOTA training methods and systems like ePerceptive and Zygarde are not benchmarked. Additionally, the efficiency metrics are not sufficiently analyzed in the original manuscript.\n\n\n**Reasons for the decision:**\n\nWhile the paper's technical contributions are relatively incremental, it addresses a critical and underexplored challenge in training DNNs in energy-harvesting environments. In addition, the author's response provided more comprehensive experiments and efficiency analyses, effectively addressing most of the reviewers’ concerns. Therefore, I am inclined to accept this paper.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} +{"paper_id": "XMOaOigOQo", "forum_url": "https://openreview.net/forum?id=XMOaOigOQo", "venue": "ICLR.cc/2025/Conference", "year": 2025, "title": "ContraDiff: Planning Towards High Return States via Contrastive Learning", "authors": ["Yixiang Shan", "Zhengbang Zhu", "Ting Long", "Liang Qifan", "Yi Chang", "Weinan Zhang", "Liang Yin"], "abstract": "The performance of offline reinforcement learning (RL) is sensitive to the proportion of high-return trajectories in the offline dataset. However, in many simulation environments and real-world scenarios, there are large ratios of low-return trajectories rather than high-return trajectories, which makes learning an efficient policy challenging. In this paper, we propose a method called Contrastive Diffuser (ContraDiff) to make full use of low-return trajectories and improve the performance of offline RL algorithms. Specifically, ContraDiff groups the states of trajectories in the offline dataset into high-return states and low-return states and treats them as positive and negative samples correspondingly. Then, it designs a contrastive mechanism to pull the planned trajectory of an agent toward high-return states and push them away from low-return states. Through the contrast mechanism, trajectories with low returns can serve as negative examples for policy learning, guiding the agent to avoid areas associated with low returns and achieve better performance. Through the contrast mechanism, trajectories with low returns provide a ``counteracting force'' guides the agent to avoid areas associated with low returns and achieve better performance.\nExperiments on 27 sub-optimal datasets demonstrate the effectiveness of our proposed method. Our code is publicly available at https://github.com/Looomo/contradiff.", "keywords": ["Offline Reinforcement Learning", "Decision Making", "Diffusion Models", "Machine Learning"], "primary_area": "reinforcement learning", "pdf_url": "https://openreview.net/pdf?id=XMOaOigOQo", "decision": "Accept (Poster)", "num_reviews": 3, "num_discussions": 29, "reviews": [{"review_id": "fjzVQ7SjMb", "reviewer": "Reviewer_2zbW", "rating": 5, "confidence": 4, "soundness": 2, "presentation": 3, "contribution": 2, "summary": "This paper introduces ContraDiff, a novel offline reinforcement learning method that leverages contrastive learning to make better use of low-return trajectories by pulling states towards high-return regions and pushing them away from low-return areas. While the approach shows promising results across various environments and demonstrates advantages in handling sub-optimal datasets, several limitations remain, including unclear theoretical justification for the relationship between contrastive and RL losses, heavy reliance on hyperparameters for positive/negative sample selection, and lack of trajectory-level temporal information in the contrastive learning design. Despite these limitations, the paper presents a well-structured contribution with comprehensive experiments, providing a new perspective on utilizing sub-optimal data in offline RL.", "strengths": "1. The paper is well-structured and easy to follow, with clear motivation and comprehensive explanations of each component in the proposed method.\n\n2. The empirical evaluation is thorough, covering 27 sub-optimal datasets and including extensive ablation studies, hyper-parameter analysis, and visualizations that help understand the method's behavior.\n\n3. The proposed method provides a novel perspective on utilizing low-return trajectories in offline RL, and the implementation is relatively simple with only a few additional hyperparameters compared to the base diffusion model.", "weaknesses": "1. The relationship between contrastive loss and RL loss needs further detailed justification. Intuitively, these two losses appear highly correlated, especially when considering low-return trajectories as negative samples. A numerical analysis comparing these losses across trajectories with different returns would be valuable to verify if they indeed show similar trends. If so, the contrastive loss might merely serve as an enhancement to the RL loss rather than a meaningful regularization term that needs to be balanced against the primary objective.\n2. The methodological differences illustrated in Figure 1(d)(e)(f) lack sufficient justification.This figure fails to demonstrate why pushing away from low-return states is fundamentally superior to upweighting high-return trajectories, or why return-based contrastive learning outperforms traditional trajectory-based approaches.\n3. The paper's method of determining positive and negative samples relies heavily on hyperparameters (ξ, ζ, σ) without theoretical justification for their value ranges. The paper lacks a principled approach to determine these thresholds across different environments. A more systematic study on how different return thresholds affect the distribution of positive/negative samples and their impact on learning dynamics would strengthen the method's foundation.\n4. The current state-level contrastive learning approach breaks the temporal correlation between states within trajectories, as it samples positive and negative states purely based on return values (SR) or clustering results (SRD). This design might lose important sequential information that could be better captured through trajectory-level contrastive learning. I encourage authors to explore trajectory-aligned state sampling strategies and conduct comparative experiments between state-level and trajectory-level contrastive learning to provide more solid empirical evidence for the design choices.\n5. The competitive performance of baselines like CDE and HD raises questions about the necessity of the proposed contrastive mechanism, especially considering these methods achieve comparable results without any specific designs for handling low-return trajectories. While ContraDiff shows improvements in certain scenarios, the performance overlap with these simpler methods suggests that the advantages of explicitly handling low-return trajectories might be less significant than claimed. This observation calls for a more thorough analysis of what unique benefits, if any, are brought by the contrastive mechanism for low-return trajectory utilization.", "questions": "See weakness"}, {"review_id": "S6HtbgwfIZ", "reviewer": "Reviewer_DdY7", "rating": 6, "confidence": 4, "soundness": 2, "presentation": 2, "contribution": 2, "summary": "The paper presents a method called ContraDiff, which leverages contrastive learning to guide an agent towards high-return states and pushing agents away from low-return states. The proposed method builds upon diffusion-based trajectory planning models and aims to solve offline reinforcement learning problems especially when high-return trajectories are limited in the offline dataset. Experiments are conducted on standard D4RL tasks and exhibit better performance as compared to baselines.", "strengths": "The combination of contrastive learning and diffusion-based trajectory planning model is a novel idea, allowing the model to learn from both high- and low-return samples. The main figure is clear and well-organized, making it easy to understand the method. Experiments are extensively conducted across many test environments.", "weaknesses": "1. While the idea of using contrastive learning is well motivated, the section on how \"positive\" and \"negative\" examples are identified requires additional explanations. The paper did not specify how the clustering is performed -- if the clustering is to enforce \"dynamic consistency\", I would assume the information is already available in the offline dataset? For me it is unclear how well the clustering captures reachability or state transitions beyond just grouping next states together. This could be a limitation in environments with complex or non-linear dynamics.\n2. The paper lacks discussion on the complexity or runtime implications introduced by k-means clustering, especially for large offline RL datasets.\n3. The increases in performance are mostly marginal. For many tasks the increase falls within one standard deviation of the baseline methods. \n4. The authors mentioned \"comparing ContraDiff with other regular methods\" but didn't include specific indication of where the results are presented. I later found the results in the appendix, but the authors may consider additional illustrations and explanations on the results. Furthermore methods like CQL are only tested on the standard D4RL datasets but not the mixture dataset proposed by the authors. More experiment results would be needed here to better demonstrate model performance. \n5. It is unclear how the baseline results were obtained -- whether they came from existing code, were reimplemented by the authors, or taken from other sources. \n6. The writing, especially in the experiments section, is not very clear to understand. Some sentences and typos make the paper harder to follow. For example: \n* L290: \"...high-return sample sparsity situations\"\n* L301: \"...which are focus on addressing...\"\n* L364: \"...declines with the ratio declines\"\n* Table3: \"Walekr2d\"", "questions": "Aside from mentioned in the Weakness section, my questions mostly come from experiments presented by the paper:\n1. When comparing Table 3 with Table 1, it seems like introducing the expert trajectory leads to a decrease in performance, unexpectedly. For instance, DT achieves a score of 36.6 at original HalfCheetah-MR dataset but 7.5, 6.7, 6.1 in three conditions where expert data is introduced. I was wondering if the same codebase is used in achieving these results, and if so, could the authors share insights on why DT’s performance drops so drastically with the added expert data?\n2. In most tasks ContraDiff-SR outperforms ContraDiff-SRD -- this seems unexpected given the author's intuition mentioned in section 3 that ContraDiff-SR may ignore the transition dynamics in its sampling process? \n3. How would the proposed model perform on sparse-reward tasks like AntMaze? This benchmark is commonly used in offline RL evaluations but results for it are missing here."}, {"review_id": "rIVYiUny7q", "reviewer": "Reviewer_Ku5d", "rating": 6, "confidence": 3, "soundness": 3, "presentation": 3, "contribution": 3, "summary": "This paper introduces a novel approach, ContraDiff, designed to enhance offline reinforcement learning (RL) by leveraging low-return trajectories in a contrastive learning framework. Offline RL's performance often depends on the presence of high-return trajectories, which are frequently sparse in datasets. ContraDiff addresses this challenge by classifying trajectories into high- and low-return groups, treating them as positive and negative samples, respectively. The approach applies a contrastive mechanism that pulls the agent's trajectories toward high-return states while pushing them away from low-return ones. This strategy enables low-return trajectories to serve as a guiding force to avoid suboptimal regions, enhancing policy performance.", "strengths": "The paper presents a compelling approach to enhancing offline reinforcement learning (RL) through a novel contrastive framework, which incorporates low-return trajectories in ways that previous approaches overlook. ContraDiff’s use of a contrastive learning framework to distinguish between high- and low-return trajectories is an innovative application of contrastive learning within the offline RL domain. Instead of merely reweighting high-return samples, the method leverages low-return trajectories as guiding forces, creating a “counteracting force” that helps the policy avoid suboptimal states. This approach represents a shift from traditional methods that focus primarily on high-return trajectories, expanding the value of underutilized low-return data.", "weaknesses": "* While the paper presents contrastive learning as a means to exploit low-return data, it could better clarify why this mechanism is theoretically optimal for avoiding low-return states compared to other potential approaches, such as weighting adjustments or imitation-based filtering. \n\n* The authors are not clear enough in describing the use of weighted contrast loss to constrain trajectory generation", "questions": "* In the 3.3 model learning section, I noticed that the optimized trajectory generation by minimizing the Mean Square Error between the ground truth and neat trajectory predicted. Therefore, the diffusion should denoise the data from the noisy data completely. I think it's more expensive to train like this. Can you show some experiments comparing the cost of training?\n* ContraDiff planning towards high return states, leading policy improvements. It is very similar to some offline RL methods, such as SAW[1], A2PR[2], LAPO[3]. Can you add some discussion with these methods in the related works or more experiments comparison? \n* Did you experiment with setting different thresholds for what qualifies as “low” or “high” return? If so, how did varying these thresholds impact the learned policy or the model’s ability to generalize to novel tasks?\n\nReferences:\n[1] Lyu, Jiafei, et al. \"State advantage weighting for offline RL.\" arXiv preprint arXiv:2210.04251 (2022).\n\n[2] Liu, Tenglong, et al. \"Adaptive Advantage-Guided Policy Regularization for Offline Reinforcement Learning.\" In International Conference on Machine Learning (ICML). PMLR, 2024.\n\n[3] Chen, Xi, et al. \"Latent-variable advantage-weighted policy optimization for offline rl.\" arXiv preprint arXiv:2203.08949 (2022)."}], "discussions": [{"comment_id": "0fv1Ae8bQZ", "replyto": "XMOaOigOQo", "author_type": "authors", "reviewer": null, "comment": "Dear PC, SAC, AC and Reviewers,\n\nWe would like to express our sincere gratitude to all the reviewers for their efforts in reviewing. We are delighted that we have addressed the concerns of Reviewers Ku5d and DdY7, and they have increased their scores toward the positive side. \n\nWe also greatly appreciate the reviewers' recognition of our work. All the reviewers found our approach to be sufficiently novel and considered our experiments thorough and comprehensive. In particular, Reviewer Ku5d mentioned, \"Instead of merely reweighting high-return samples, the method leverages low-return trajectories as guiding forces, creating a 'counteracting force' that helps the policy avoid suboptimal states.\" We believe this is an excellent summary of our work. With the introduction of the contrastive learning mechanism, ContraDiff is able to (1) fully leverage the information of both high and low-return samples, (2) actively escape from low-return regions.\n\n\nWe are also glad that we have addressed most of the concerns of Reviewer 2zbW. For unclosed questions of Reviewer 2zbW, we have provided a comprehensive analysis and additional experimental evidence. We believe our response can addresses the reviewer' concern.\n\n\nWe would like to once again express our sincere gratitude to all of the reviewers, the comments and suggestions from reviewers have been very pertinent and valuable, and we believe that this discussion has been highly efficient and fruitful.\n\nBest Regards, \\\nThe Authors"}, {"comment_id": "FBJ3sXv3ab", "replyto": "fjzVQ7SjMb", "author_type": "authors", "reviewer": null, "comment": "**Q3**. The difference between our method ContraDiff and QGPO [1]. \\\n**R3**. We are glad that we have reached a consensus on this point that, as we mentioned in **R3** of our last response, our method and QGPO utilize contrastive learning in **different ways** and are applied to solve **different problems**.\n\nNevertheless, following your suggestion, we conducted experiments of QGPO on Walker2d-Rand-Exp, and the results are as follows. As can be observed, ContraDiff outperforms QGPO in 2 out of 3 settings, indicating that ContraDiff, leveraging contrastive learning to pull the generated samples closer to high-return samples and actively push them away from low-return samples, achieves better results compared to QGPO, which uses contrastive learning to learn higher guidance. We will also add the relevant experimental results to our paper.\n\n| Mix Ratio | QGPO | ContraDiff |\n|:--------:|:--------:|:--------:|\n| 0.1 | **22.5 ± 2.1** | 20.2 ± 1.3 |\n| 0.2 | 44.6 ± 1.9 | **57.4 ± 0.7** |\n| 0.3 | 61.1 ± 3.3 | **78.4 ± 1.2** |\n\n\nAs we had included a section in Related Works (Section 5.3) that provides a detailed overview of methods using contrastive learning in RL, given the different approaches and objectives in utilizing contrastive learning, we have given an appropriate introduction to QGPO in Section 5.3.\n\n[1] Lu, Cheng, et al. \"Contrastive energy prediction for exact energy-guided diffusion sampling in offline reinforcement learning.\" International Conference on Machine Learning. PMLR, 2023."}, {"comment_id": "nNbgKkANz3", "replyto": "fjzVQ7SjMb", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer 2zbW, \\\nThank you for your further comment. \\\n**Q1**. About the understanding of RL loss in diffusion-based methods. \\\n**R1**. The diffusion model-based methods like Diffuser [1] are typically optimized with two losses: the diffusion loss and the guidance loss (i.e., the guidance you have mentioned). The diffusion loss is usually defined as MSE. For example, in Diffuser, the diffusion model loss is defined as (in page 4 of diffuser paper [1]):\n\n$$\n\\\\mathcal{L}(\\\\theta) = \\\\mathbb{E}\\_{i,\\\\epsilon, \\\\tau^0}[ || \\\\epsilon - \\\\epsilon\\_{\\\\theta}(\\\\tau^i, i) ||^2 ]\n$$\n\nThe guidance loss is used to predict the value of each action, which can be observed in the official repo of Diffuser. We agree with your opinion that previous methods adopt value-guided sampling can generate better actions (or trajectories). However, the capability of value-guided sampling is actually learned by the guidance loss instead of the diffusion loss. As can be seen from the formula, \nreturn is not utilized in the diffusion loss, and therefore the diffusion loss essentially just **simply fits the trajectory distribution**. \n\nThe effect of diffusion losses can be validated in our experimental results. In Figure 6, we visualize the probability density functions of the rewards obtained by ContraDiff and ContraDiff-C **when no guidance is used**, where ContraDiff-C denotes the setting of removing the contrastive learning loss from ContraDiff. It is evident that the probability density obtained by ContraDiff-C is closer to the distribution of the dataset than ContraDiff, indicating the diffusion loss essentially just **simply fits the trajectory distribution**. Moreover, ContraDiff is more concentrated in the high-reward region, indicating that our proposed contrastive learning indeed **actively moves away from low-return samples, and moves toward the high-return samples**. \n\n[1] Janner, Michael, et al. \"Planning with diffusion for flexible behavior synthesis.\" arXiv preprint arXiv:2205.09991 (2022).\n\n\n**Q2.** Traditional actor-critic methods like IQL and CQL inherently achieve a similar \"pushing away\" effect that our contrastive learning aims to accomplish. \\\n**R2.** Generally, the RL loss of advantage-based methods adopts a form similar to IQL, where the RL loss of IQL is as follows (Eq.7 from the IQL paper [1]):\n\n$$\n\\\\begin{equation}\n \\\\mathcal{L}\\_{\\\\pi}(\\\\phi) = \\\\mathbb{E}\\_{(s,a)\\\\sim \\\\mathcal{D}}[ exp( \\\\beta( Q\\_{\\\\hat{\\\\theta}}(s,a) - V\\_{\\\\psi}(s) ) )\\\\text{log}\\pi\\_{\\\\phi}(a|s) ]\n\\\\end{equation}\n$$\n\nAs can be seen, for samples with low advantage, the weight is closer to 0, This means that when a sample has a very low return, these methods choose to neglect the information in those samples. Hence, essentially speaking, these methods mainly focus on learning from higher-return samples (i.e., how to achieve a high return), and traditional actor-critic methods like IQL and CQL do not exhibit the 'pushing away' effect we mentioned. Ablation studies in Section 4.3.1 show that ignoring the information from low-return samples leads to poor performance in most cases.\n\n\nIn contrast, our method learns from both low-return and high-return samples (i.e., how to achieve a high return and how to avoid the low-return), making our method able to **actively avoid the low-return area (our unique advantage)**. \nAs shown in Figure 3, we initialize ContraDiff and the baseline methods in a low-return state to explore their ability to escape from the low-return state. As can be observed, even though AW, a method focused on learning from high-return samples, takes a few more steps than Diffuser, it eventually receives terminal signals. In contrast, ContraDiff goes further and ultimately maintains a healthy posture to keep moving forward. In other words, with the contrastive loss, ContraDiff makes it better to escape from the low-return region compared to methods based on advantage-weighting.\n\n\n\nOverall, **The key difference** between our method and previous works is previous methods focus on learning from high-return samples, but our method learns from both high and low-return samples. \n\n\n[1] Kostrikov, Ilya, Ashvin Nair, and Sergey Levine. \"Offline reinforcement learning with implicit q-learning.\" arXiv preprint arXiv:2110.06169 (2021)."}, {"comment_id": "kDHHCHjxJB", "replyto": "F5nhWmQ6gc", "author_type": "reviewer", "reviewer": "Reviewer_2zbW", "comment": "Thank you for the detailed response. However, I remain unconvinced by your arguments for several fundamental reasons.\n\n1.Your characterization of the relationship between RL and contrastive losses reveals concerning misunderstandings. The claim that RL loss in diffusion-based methods \"simply fits trajectory distribution\" is inaccurate - these methods already incorporate mechanisms to differentiate between high and low-return trajectories through value-guided sampling. Similarly, traditional actor-critic methods like IQL and CQL effectively handle trajectory relationships through value estimation and advantage weighting, inherently achieving a similar \"pushing away\" effect that your contrastive learning aims to accomplish. A deeper issue is that these approaches - Q/V values in actor-critic methods, value-guided sampling in diffusion models, and your proposed contrastive learning - are all fundamentally forms of trajectory guidance that bias learning towards high-return behaviors.\n\n2.The claimed independence between RL loss and contrastive loss therefore requires much stronger theoretical justification. This connection has been well explored in works like QGPO, which leads to my second concern. **In the revision, your response that QGPO \"addresses a different problem\" is insufficient. This requires further comparison and discussion, as it is the work most closely related to this paper.** Both works use contrastive learning to improve diffusion-based offline RL, warranting performance comparisons on common benchmarks, detailed technical analysis of similarities and differences, and clear positioning of your contributions relative to QGPO. The current **one-line addition to Related Work** and lack of comparative analysis make it difficult to evaluate your work's novelty and contribution to the field. I believe these issues need to be thoroughly addressed to support your claims about the method's uniqueness and effectiveness.\n\nI believe this paper falls below acceptable standards on the several key issues, so I will maintain my score of 5."}, {"comment_id": "jIFjupSnQl", "replyto": "4J8FW4gXKV", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer Ku5d, \n\n\nThank you for your effort in reviewing our paper, and thank you for acknowledging our work! Your valuable comments and suggestions have significantly contributed to improving our paper.\n\n\nBest regards, \\\nThe Authors"}, {"comment_id": "4J8FW4gXKV", "replyto": "y0dm8M7jNZ", "author_type": "reviewer", "reviewer": "Reviewer_Ku5d", "comment": "Thanks for your additional experiments and effort. My concerns have been addressed. I have updated my score."}, {"comment_id": "y0dm8M7jNZ", "replyto": "rIVYiUny7q", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer Ku5d,\n\nThank you for your effort in reviewing our paper! We are glad that we have addressed most of you concerns. We have provided the performance of ContraDiff without expert data accordingly several days ago, and we hope our responses have adequately addressed your additional concern. As the discussion phase is coming to an end, we sincerely request your further responses.\n\nIf we have resolved your issue, please consider raising your score to the positive side. If you have any further questions, please feel free to share with us! We would appreciate your further discussion on our paper.\n\nBest regards, \\\nThe Authors"}, {"comment_id": "F5nhWmQ6gc", "replyto": "fjzVQ7SjMb", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer 2zbW,\n\nThank you for your effort in reviewing our paper! We have provided detailed responses to your concerns several days ago, and we hope our responses have adequately addressed your additional concerns. As the discussion phase is coming to an end, we sincerely request your further responses.\n\nIf we have resolved your concerns, please consider raising your score to the positive side. If you have any further questions, please feel free to share with us! We would appreciate your further discussion on our paper.\n\nBest regards, \\\nThe Authors"}, {"comment_id": "yMZGtDvtXq", "replyto": "X6CqqRA8hp", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer 2zbW,\n\nWe have provided detailed responses to your concerns days ago, and we have also fixed the issue in the previous response where formulas could not render properly. We hope our responses have adequately addressed your additional concerns. As the discussion phase is coming to an end, we sincerely request your further responses.\n\nIf we have resolved your concerns, please consider raising your score to the positive side. If you have any further questions, please feel free to share with us! We would appreciate your further discussion on our paper.\n\nBest regards,\nThe Authors"}, {"comment_id": "tBd520KSyF", "replyto": "rIVYiUny7q", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer Ku5d,\n\nWe have provided detailed responses to your concerns days ago, and have provided the performance of ContraDiff without the expert dataset accordingly. We hope our responses have adequately addressed your additional concern. As the discussion phase is coming to an end, we sincerely request your further responses.\n\nIf we have resolved your issue, please consider raising your score to the positive side. If you have any further questions, please feel free to share with us! We would appreciate your further discussion on our paper.\n\nBest regards,\nThe Authors"}, {"comment_id": "h0H9taxLnO", "replyto": "SgGyCwUcVA", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer DdY7, \n\n\nThank you for your effort in reviewing our paper, and thank you for acknowledging our work! Your valuable comments and suggestions have significantly contributed to improving our paper.\n\n\nBest regards, \\\nThe Authors"}, {"comment_id": "SgGyCwUcVA", "replyto": "lasyGWV0iS", "author_type": "reviewer", "reviewer": "Reviewer_DdY7", "comment": "Thank you to the authors for the great efforts to address my concerns with the additional experiments! Most of my questions have been resolved. I have raised my score accordingly."}, {"comment_id": "jAFNNxtOPL", "replyto": "sJy6EtNVH1", "author_type": "authors", "reviewer": null, "comment": "Dear reviewer, \\\nWe are glad that we have addressed most of your concerns. Here are our responses:\n\n\n**Q1.** The performance of ContraDiff without the expert dataset. \\\n**R1.** Thank you for your positive feedback. We introduced a small portion of expert data into each dataset to evaluate ContraDiff's performance in handling suboptimal datasets, in which high-return samples are limited. We provided the performance of ContraDiff without additional expert samples in the Mujoco, Maze2d, and Kitchen datasets in Table 3 and Table 4 from Appendix A.3. It can be observed that without expert samples, ContraDiff achieved the best or second-best results in 4 out of 6 Mujoco tasks and in all Maze2d and Kitchen tasks. This more comprehensively demonstrates the advantages of ContraDiff and indicates that it also performs well in sparse reward scenarios. Please refer to Table 3 and Table 4 in Appendix A.3 in our paper for details."}, {"comment_id": "T9aahQYZ5J", "replyto": "X6CqqRA8hp", "author_type": "authors", "reviewer": null, "comment": "**Q3.** Inadequate positioning relative to important related work \"Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline RL\". \\\n**R3.** Thank you for your further comment. Although QGPO [1] also employs contrastive learning, it addresses a different problem compared to ContraDiff. Generally speaking, QGPO [1] uses contrastive learning to learn a more accurate energy function as guidance, focusing on proposing a new energy-based guidance for Diffusion models. In contrast, as mentioned in R1, ContraDiff focuses on leveraging contrastive learning to address the data imbalance problem, so the two do not overlap. However, thank you for your suggestion, we have enriched our Related Works with QGPO. Please refer to our revised paper for details.\n\n[1] Lu, Cheng, et al. \"Contrastive energy prediction for exact energy-guided diffusion sampling in offline reinforcement learning.\" International Conference on Machine Learning. PMLR, 2023."}, {"comment_id": "r5UelzKyZR", "replyto": "fjzVQ7SjMb", "author_type": "authors", "reviewer": null, "comment": "Dear reviewer, \\\nWe greatly appreciate you letting us know that some of your concerns have not been fully addressed. \nAlso, we found that the formulas in our previous response (R1) were not rendered correctly. We sincerely apologize for any inconvenience this may have caused you. We have fixed the issue, and the formulas are now properly rendered. \nHere are our further responses:\n\n\n**Q1.** The unclear relationship between contrastive and RL losses. The claimed independence between contrastive and RL losses is therefore questionable, as both inherently bias towards high-return trajectories. \\\n**R1.** Thank you for your further comment. \\\nFirstly, we want to clarify that the RL loss mentioned in our previous reply refers to the loss used in Diffusion-based methods such as Diffuser. In the works based on Diffusion for decision-making, such as Diffuser, they apply an MSE loss between real data and generated trajectories to train the Diffusion model. Only during evaluation do they adopt value-guided sampling to generate trajectories with higher returns. Therefore, the learned diffusion model simply fits the trajectory distribution. \\\nHowever, the contrastive loss used in our paper is designed to better leverage the information embedded in negative samples (i.e., low-return trajectories). By maximizing the distance between generated trajectories and negative samples while minimizing the distance between generated trajectories and positive samples(i.e., high-return trajectories), the contrastive loss constrains the generated trajectories to a high-return area.\n\nSecondly, even methods based on the Actor-Critic framework, such as IQL and CQL, their loss basically assign weights to different ($s$, $a$) pairs based on the advantage of $a$, which fails in actively utilizing the information from low-return samples, for example, they do not actively distance the generated samples from low-return trajectories. In contrast, by employing a contrastive loss, our model can fully leverage the information from low-return trajectories, and avoid low-return samples. The comparisons in Table 2 demonstrate the advantages of our proposed contrastive loss over these reweighting methods. \\\n\nFurthermore, the difference between the contrastive loss and the RL loss can be visually demonstrated through the experiments. Figure 5 and Figure 6 in Section 4.3 of our paper intuitively demonstrate the advantages of our contrastive loss. In Figure 5, we can observe that compared to without using contrastive loss (Figure 5(b)), employing contrastive loss (Figure 5(c)) results in more high-reward states, as indicated by the warmer-colored sample points in the figure. To further illustrate the advantages of the contrastive loss, Figure 6 presents the reward distributions when value guidance is not applied, comparing cases with and without contrastive loss. As shown, without contrastive loss (orange region), the rewards obtained by the model closely align with the reward distribution of the dataset itself (green region). However, when using contrastive loss (blue region), the reward distribution of the model is highly concentrated in the high-reward region, with minimal presence in the low-reward region. Figures 5 and 6 intuitively demonstrate the benefits brought by the proposed contrastive learning loss. \n\n\n**Q2.** Insufficient comparison with critical baseline methods. \\\n**R2.** Thank you for your further comment.  \\\nFirst of all, we want to clarify that, as mentioned in the Introduction, we focus on addressing the data imbalance problem by making better use of low-return samples in the dataset. Therefore, it is more fair to compare our method with other approaches aimed at solving data imbalance. \n\nAs is shown in Table 2, we compare ContraDiff with data imbalance works that were recently published on ICLR 2023 and NIPS 2023, including advantage-weighting (AW) [1], return-weighting (RW) [1],  density-ratio weighting with advantage (AW-DW) [2] and density-ratio weighting with uniform (U-DW) [2]. The results show that ContraDiff achieves optimal or sub-optimal results in 25 out of 27 situations.  Besides, we have compared ContraDiff with plenty of regular SOTA RL methods in Table 1, in which ContraDiff achieves optimal or sub-optimal results in 25 out of 27 situations. \n\n[1] Hong, Zhang-Wei, et al. \"Harnessing Mixed Offline Reinforcement Learning Datasets via Trajectory Weighting.\" The Eleventh International Conference on Learning Representations. \\\n[2] Hong, Zhang-Wei, et al. \"Beyond uniform sampling: Offline reinforcement learning with imbalanced datasets.\" Advances in Neural Information Processing Systems 36 (2023): 4985-5009."}, {"comment_id": "sJy6EtNVH1", "replyto": "T651PlQTxJ", "author_type": "reviewer", "reviewer": "Reviewer_Ku5d", "comment": "Thanks for the effort of the authors. Most of my concerns have been addressed, but I have one more question. I see that the experiments in the paper all need to include expert data, and I'm curious how the performance of ContraDiff is without the expert dataset."}, {"comment_id": "X6CqqRA8hp", "replyto": "fjzVQ7SjMb", "author_type": "reviewer", "reviewer": "Reviewer_2zbW", "comment": "Thank you for the authors' rebuttal. However, I regret that my fundamental concerns remain unaddressed:\n\n1. The authors' characterization of RL loss as 'simply fitting trajectory distribution' is fundamentally incorrect and misleading. This description fits vanilla behavioral cloning, not RL losses, which are specifically designed to learn from high-return trajectories through value estimation and advantage weighting. The claimed independence between contrastive and RL losses is therefore questionable, as both inherently bias towards high-return trajectories.\n\n2. The response merely differentiates the losses mathematically without demonstrating their fundamental independence or complementarity rather than redundancy.\n\n3. Most critically, the authors fail to acknowledge or compare their work with 'Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline RL', which presents a more theoretically complete treatment of combining contrastive learning with diffusion models in offline RL.\n\nGiven these unresolved issues - the unclear relationship between contrastive and RL losses, insufficient comparison with critical baseline methods, and inadequate positioning relative to important related work - I believe this work does not yet meet the bar for publication."}, {"comment_id": "T651PlQTxJ", "replyto": "rIVYiUny7q", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer Ku5d, \n\nWe have provided detailed responses to your concerns many days ago, we hope these responses have adequately addressed your concerns. As the discussion phase is coming to an end, we sincerely request your further responses. \n\nIf we have resolved your issues, please consider raising your score to the positive side. If you have any further questions, please feel free to share them with us! Any valuable feedback is crucial for improving our work. \n\nBest regards, \\\nThe Authors"}, {"comment_id": "lasyGWV0iS", "replyto": "S6HtbgwfIZ", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer DdY7, \n\nWe have provided detailed responses to your concerns many days ago, we hope these responses have adequately addressed your concerns. As the discussion phase is coming to an end, we sincerely request your further responses. \n\nIf we have resolved your issues, please consider raising your score to the positive side. If you have any further questions, please feel free to share them with us! Any valuable feedback is crucial for improving our work. \n\nBest regards, \\\nThe Authors"}, {"comment_id": "d67kamZ2KC", "replyto": "fjzVQ7SjMb", "author_type": "authors", "reviewer": null, "comment": "Dear Reviewer 2zbW, \n\nWe have provided detailed responses to your concerns many days ago, we hope these responses have adequately addressed your concerns. As the discussion phase is coming to an end, we sincerely request your further responses. \n\nIf we have resolved your issues, please consider raising your score to the positive side. If you have any further questions, please feel free to share them with us! Any valuable feedback is crucial for improving our work. \n\nBest regards, \\\nThe Authors"}, {"comment_id": "NCBZGgNJla", "replyto": "rIVYiUny7q", "author_type": "authors", "reviewer": null, "comment": "**Q5.** Did you experiment with setting different thresholds for what qualifies as “low” or “high” return? If so, how did varying these thresholds impact the learned policy or the model’s ability to generalize to novel tasks? \\\n**R5.** Thank you for your question. \\ \n(1) We have conducted the experiments of thresholds for positive and negative samples (i.e., the thresholds for what qualifies as “low” or “high” return) and hyperparameters (ξ, ζ, σ) in Appendix A.11. In summary, analysis of the hyper-parameters indicates that their impact on model performance is moderate, and searching for the optimal hyper-parameters is easy. Please refer to Appendix A.11 for more detailed discussions. \n(2) The choice of threshold on novel tasks can be made a priori based on the state return in the dataset. For example, one can visualize the state-return distribution of the dataset, and take the return with the most dramatic return change as the threshold. For example, thresholds can be identified as turning points by calculating the rate of change (e.g., differences) or the second-order rate of change (acceleration)[1,2]. After that, we can fix the threshold and adjust $\\sigma$ for better results. \\\n\n\n[1] Satopaa, V., Albrecht, J., Irwin, D., & Raghavan, B. (2011, June). Finding a\" kneedle\" in a haystack: Detecting knee points in system behavior. In 2011 31st international conference on distributed computing systems workshops (pp. 166-171). IEEE. \\\n[2] Silverman, B. W. (2018). Density estimation for statistics and data analysis. Routledge."}, {"comment_id": "thep1WoTPQ", "replyto": "rIVYiUny7q", "author_type": "authors", "reviewer": null, "comment": "**Q1.** While the paper presents contrastive learning as a means to exploit low-return data, it could better clarify why this mechanism is theoretically optimal for avoiding low-return states compared to other potential approaches, such as weighting adjustments or imitation-based filtering. \\\n**R1.** Thank you for your comment. RL methods based on weighting adjustments or imitation-based filtering usually assign weights to states or trajectories according to their returns, aiming to prioritize high-return samples during the optimization process and ignore low-return samples. In other words, these methods do not actively attempt to distance the model from low-return samples. In contrast, our approach uses contrastive learning, treating low-return samples as negative examples and high-return samples as positive examples. This not only brings generated trajectories closer to high-return trajectories but also actively pushes them away from low-return samples. The comparison with the reweighting methods (summarized in Table 2) demonstrates the effectiveness of our proposed approach. We also designed ablation studies in Section 4.3.1 to demonstrate the effectiveness of utilizing low-return states. \n\n**Q2.** The authors are not clear enough in describing the use of weighted contrast loss to constrain trajectory generation. \\\n**R2.** Thank you for your comment. The weighted contrast loss is formulated in Eq.(13), which is further used for contractive learning, as formulated in Eq.(14). Details of Eq.(13) and Eq.(14) are discussed in Section 3.3. Specifically, at time step $ t_1 $, the agent obtains the observation from the environment and generates a future trajectory of $ H $ steps, $[ (s_h, a_h) ]_{t \\leq h \\leq h+H}$. Considering that the importance of predictions decreases as the time step extends further into the future, we apply a weighting scheme to the contrastive learning loss for the next $ H $ steps. For the contrastive learning loss at time step $ h $, we assign a weight of $ 1/(h+1) $. \n \n**Q3.** In the 3.3 model learning section, I noticed that the optimized trajectory generation by minimizing the Mean Square Error between the ground truth and neat trajectory predicted. Therefore, the diffusion should denoise the data from the noisy data completely. I think it's more expensive to train like this. Can you show some experiments comparing the cost of training? \\\n**R3.** Thank you for your comment. The cost of training with the Mean Square Error between the ground truth and neat trajectory predicted (i.e., Eq.(11)) is the same as predicting the noise. Following EDP [1] and the official repo of Diffuser [2], we adopt the one-step denoising method to directly predict the original trajectory within one step rather than denoising the data from the noisy data completely, for additional supervisory information during the training phase. However, during the testing phase, we still generate trajectories (i.e., future plans) through the complete reverse denoising process.\n\n[1] Kang, B., Ma, X., Du, C., Pang, T., & Yan, S. (2024). Efficient diffusion policies for offline reinforcement learning. Advances in Neural Information Processing Systems, 36. \\\n[2] Janner, M., Du, Y., Tenenbaum, J. B., & Levine, S. (2022). Planning with diffusion for flexible behavior synthesis. arXiv preprint arXiv:2205.09991.\n\n**Q4.** ContraDiff planning towards high return states, leading policy improvements. It is very similar to some offline RL methods, such as SAW[1], A2PR[2], LAPO[3]. Can you add some discussion with these methods in the related works or more experiments comparison?\n**R4.** Thank you for your suggestion! We have added discussions of SAW[1], A2PR[2], and LAPO[3] to enrich the Related Works section.\nWe ported SAW to Diffuser (Diffuser-SAW) for experimentation. Since A2PR is not applicable to Diffuser, we conducted experiments with A2PR on our suboptimal datasets. We were unable to perform experiments related to LAPO because the code for LAPO was not provided in the paper. The comparison of ContraDiff with Diffuser-SAW and LAPO on Walker2d-Rand-Exp is shown in the table below. It can be observed that ContraDiff outperforms Diffuser-SAW and A2PR in all cases, demonstrating the advantage of ContraDiff.\n\n\n| Mix Ratio | Diffuser-SAW | A2PR | ContraDiff |\n|-----------|--------------|-------|---------------------|\n| 0.1 | 18.6 ± 1.8 | 13.6 ± 4.3 | **20.2 ± 1.3** |\n| 0.2 | 52.9 ± 3.3 | 21.8 ± 0.9 | **57.4 ± 0.7** |\n| 0.3 | 61.3 ± 1.6 | 22.9 ± 1.4 | **78.4 ± 1.2** |"}, {"comment_id": "dE1OA20UpB", "replyto": "S6HtbgwfIZ", "author_type": "authors", "reviewer": null, "comment": "**Q5.** It is unclear how the baseline results were obtained -- whether they came from existing code, were reimplemented by the authors, or taken from other sources. \\\n**R5.** Thank you for your comment. The codes of other baselines were all obtained from their official repo. For the results of baselines in Table 1 and Table 2, we obtain the results by running the existing code of their official repo with default hyper-parameters. The results of baseline methods in Table 3 are adopted from their own papers and Diffuser paper. \n\n**Q6.** The writing, especially in the experiments section, is not very clear to understand. Some sentences and typos make the paper harder to follow. \\\n**R6.** Thank you for your comment! We have revised our paper based on your suggestions. Please refer to our revised paper for details.\n\n**Q7.** When comparing Table 3 with Table 1, it seems like introducing the expert trajectory leads to a decrease in performance, unexpectedly. For instance, DT achieves a score of 36.6 at original HalfCheetah-MR dataset but 7.5, 6.7, 6.1 in three conditions where expert data is introduced. I was wondering if the same codebase is used in achieving these results, and if so, could the authors share insights on why DT’s performance drops so drastically with the added expert data? \\\n**R7.** Thank you for your comment. We have rechecked the code we used, and we confirm that the code we are using is from DT's official repo (as long as the other baselines). Note that since the DT paper demonstrated that the target return-to-go value has minimal impact on the model, we initially selected the maximum value from the dataset as the target return-to-go (i.e., RTG) for better performance. We pulled the code again from the official repository of DT and conducted relevant experiments, with results consistent with those reported in our paper. \n \nStill, we experimented with different return-to-go values on Halfcheetah-Rand-Exp, and the results are shown in the table below. It can be observed that excessively large target return-to-go values caused a performance decline. When the target return-to-go was reduced to 6000, we achieved comparable results. As DT achieves a score of 36.6 at the original HalfCheetah-MR dataset on the same setting, it can be concluded that expert data helps in improving DT's performance. Even though, ContraDiff still demonstrates a significantly better trend compared to DT.\n\n| Mix Ratio | DT (RTG=12000) | DT (RTG=6000) | $\\text{ContraDiff}$ |\n|-----------|----------------|----------------|----------------------|\n| 0.1 | 7.5 ± 2.1 | 37.1 ± 0.9 | 39.0 ± 0.5 |\n| 0.2 | 6.7 ± 4.5 | 37.4 ± 0.8 | 58.4 ± 0.9 |\n| 0.3 | 6.1 ± 0.1 | 39.6 ± 2.2 | 60.3 ± 2.7 |\n\n**Q8.** In most tasks ContraDiff-SR outperforms ContraDiff-SRD -- this seems unexpected given the author's intuition mentioned in section 3 that ContraDiff-SR may ignore the transition dynamics in its sampling process? \\\n**R8.** Thank you for your question. Ensuring temporal consistency is not always appropriate in all cases. In some environments and datasets, enforcing temporal consistency based on existing data may limit the model's exploration of states with higher returns. Although these states are not explicitly indicated as reachable in the dataset, they can be reached through the generalization capability of diffusion models, making ContraDiff-SR show better performance. Details of the scenarios where ensuring temporal consistency based on existing data can yield better results are discussed in Appendix A.4, please refer to our paper for details.\n\n**Q9.** How would the proposed model perform on sparse-reward tasks like AntMaze? This benchmark is commonly used in offline RL evaluations but results for it are missing here. \\\n**R9.** Thank you for your suggestion! Following Diffuser and other baselines, we primarily evaluate ContraDiff's performance on locomotion, Maze2D, and Kitchen environments, in which Maze2D and Kitchen are sparse-reward tasks. Nevertheless, we have conducted experiments on Antmaze datasets, and the results are as follows. As can be observed, although Antmaze datasets are reward-diverse, ContraDiff outperforms Diffuser in most situations. \n\n| Datasets | Diffuser | ContraDiff |\n|------------------------------|------------------|-----------------|\n| Antmaze-umaze-v2 | 73.3 ± 3.2 | **78.0 ± 2.1** |\n| Antmaze-umaze-diverse-v2 | 64.0 ± 1.1 | **68.5 ± 1.9** |\n| Antmaze-large-play-v2 | 36.4 ± 1.3 | **72.7 ± 2.3** |\n| Antmaze-large-diverse-v2 | 13.6 ± 0.7 | **20.0 ± 2.1** |\n| Antmaze-medium-play-v2 | 45.0 ± 4.3 | **60.0 ± 0.9** |\n| Antmaze-medium-diverse-v2 | **11.1 ± 2.7** | 10.0 ± 1.1 |"}, {"comment_id": "SvYwEunfcf", "replyto": "S6HtbgwfIZ", "author_type": "authors", "reviewer": null, "comment": "**Q3.** The increases in performance are mostly marginal. For many tasks the increase falls within one standard deviation of the baseline methods. \\\n**R3.** Thank you for your comment. In Section 4, we demonstrate the unique advantages of ContraDiff from multiple perspectives. Specifically, \\\n(1)Table 1 presents the comparison between ContraDiff and several recent SOTA baselines on suboptimal datasets, where ContraDiff achieves the best or second-best performance on 25 out of 27 datasets. Even after accounting for the potential impact of variance, ContraDiff delivers outstanding results. For example, on HalfCheetah-Rand-Exp-0.1, ContraDiff outperforms the second-best method by 28.1%. This preliminarily demonstrates the advantage of ContraDiff in effectively leveraging negative samples. \\\n(2)Furthermore, to better showcase the advantages of ContraDiff, we compare it with SOTA methods designed to handle imbalanced data. We initialize the environment with **random low-return states** to create more challenging scenarios, aiming to evaluate the ability of ContraDiff and SOTA methods to handle situations where the agent is trapped in low-return states. Specifically, we investigate the capability of ContraDiff to transition from low-return states to high-return states. As is shown in Table 2, ContraDiff achieves the best performance on almost all (23 out of 27) datasets compared to other SOTA baselines. \nFor example, ContraDiff outperforms the second-best method by 41.4% on Hopper-Rand-Exp-0.1.\nSince the difference between ContraDiff and the baseline lies in leveraging negative sample information effectively through contrastive learning, results in Table 2 further highlight that utilizing low-return samples can effectively help the model escape from low-return areas. \\\n(3)To provide a more intuitive explanation, we visualize in Figure 3 the ability of ContraDiff, AW+Diffuser, and their baseline, Diffuser, to escape from low-return states. As shown, when facing similar low-return states, AW+Diffuser achieves better results through its reweighting mechanism but still fails to fully escape the low-return region (terminating after 40 steps). In contrast, with the help of contrastive learning, ContraDiff effectively leverages negative sample information and maintains a healthy trajectory even after 60 steps, successfully escaping the low-return state. This visualization intuitively demonstrates the advantage of our proposed contrastive learning framework in fully utilizing negative sample information.\n\n\n**Q4.** The authors mentioned \"comparing ContraDiff with other regular methods\" but didn't include specific indication of where the results are presented. I later found the results in the appendix, but the authors may consider additional illustrations and explanations on the results. Furthermore methods like CQL are only tested on the standard D4RL datasets but not the mixture dataset proposed by the authors. More experiment results would be needed here to better demonstrate model performance. \\\n**R4.** Thank you for your comment. \n(1)We have revised the description based on your comment to avoid potential misunderstandings. We summarize the results comparing ContraDiff with newly proposed SOTA methods in Table 1. However, due to the paper length limitation, we present the results compared with traditional methods in Table 3. The relevant analysis for comparisons with newly proposed SOTA methods is in Section 4.2.1, while the analysis for comparison with diffusion-free methods is in Appendix A.3. \\\n(2) Nevertheless, based on your suggestion, we evaluated the performance of CQL, IQL, and MOPO on HalfCheetah-Rand-Exp, Walker2d-Rand-Exp, and Hopper-Rand-Exp, and compared them with ContraDiff. The results are shown in the table below. As can be observed, ContraDiff achieves the best performance in almost all scenarios.\n\n| Environment | Dataset | Mix Ratio | CQL | IQL | MOPO | ContraDiff |\n|-------------|------------|-----------|-------------|-------------|-------------|-------------|\n| Halfcheetah | Rand-Exp | 0.1 | 31.1 ± 1.6 | 3.9 ± 1.1 | 23.8 ± 2.1 | **48.0 ± 2.9** |\n| | | 0.2 | 39.4 ± 3.2 | 70.4 ± 2.8 | 30.6 ± 1.1 | **72.3 ± 0.7** |\n| | | 0.3 | 41.1 ± 0.9 | 72.7 ± 1.3 | 33.1 ± 1.9 | **88.7 ± 0.9** |\n| Hopper | Rand-Exp | 0.1 | 9.2 ± 2.2 | 36.3 ± 2.9 | 12.9 ± 2.7 | **52.0 ± 0.7** |\n| | | 0.2 | 20.8 ± 5.1 | 37.2 ± 0.9 | 17.0 ± 0.4 | **75.3 ± 1.0** |\n| | | 0.3 | 50.7 ± 3.1 | 80.6 ± 1.1 | 36.8 ± 1.3 | **86.4 ± 1.5** |\n| Walker2d | Rand-Exp | 0.1 | 11.6 ± 2.7 | **26.4 ± 2.1** | 14.9 ± 2.4 | 20.2 ± 1.3 |\n| | | 0.2 | 13.9 ± 1.1 | 39.9 ± 1.7 | 30.7 ± 1.6 | **57.4 ± 0.7** |\n| | | 0.3 | 34.2 ± 2.4 | 74.8 ± 2.2 | 44.9 ± 2.1 | **78.4 ± 1.2** |"}, {"comment_id": "Jqvm2H5GtV", "replyto": "S6HtbgwfIZ", "author_type": "authors", "reviewer": null, "comment": "**Q1.** While the idea of using contrastive learning is well motivated, the section on how \"positive\" and \"negative\" examples are identified requires additional explanations. The paper did not specify how the clustering is performed -- if the clustering is to enforce \"dynamic consistency\", I would assume the information is already available in the offline dataset? For me it is unclear how well the clustering captures reachability or state transitions beyond just grouping next states together. This could be a limitation in environments with complex or non-linear dynamics. \\\n**R1.** Thank you for your comment. We believe there is a misunderstanding regarding the details of clustering. \\\n(1) We simply group all the states in the dataset into clusters. Specifically, for a state $s_t$, we first find the cluster it belongs to, marked as $C_t$. Next, we treat all the next states of states in $C_t$ as the positive candidates of $s_t$. As all the states in $C_t$ are grouped into the same cluster as $s_t$, we consider these states to be similar, and thus the next state of any in-cluster state of $C_t$ is also considered reachable from $s_t$. Thus, selecting the positive sample of $s_t$ from the next states of in-cluster states ensures the temporal consistency of the constructed contrastive samples. Note that all the required info is available in the dataset. \\\n(2) However, as we understand it, you are referring to treating the same-cluster states of $s_t$'s next state as $s_t$'s next state. This represents a fundamental difference in exploration ability: In ContraDiff, the next states of same-cluster states are not necessarily within the same cluster, which allows ContraDiff to explore a broader range of potential next states. However, if we treat the same-cluster states of $s_t$'s next state as $s_t$'s next state since these states are all within the same cluster, they are inherently similar, which limits ContraDiff's exploration capacity. \nOverall, selecting contrastive samples from the next states of $s_t$'s same-cluster states is not the same as selecting contrastive samples from the same-cluster states of $s_t$'s next state. The next states of same-cluster states used in ContraDiff provide more diverse contrastive learning samples than the same-cluster states of $s_t$'s next state, and as a result, ContraDiff has greater exploration ability. \\\n(3) We believe you are referring to scenarios where the environment dynamics may change over time. For these cases, we can use sub-segments of the trajectories as clustering units, and take into account both $s_t$​ and its historical information when searching for reachable future states in the training phase. Since historical information contains dynamics of the environment, clustering that incorporates historical information can effectively overcome the challenges posed by complex or non-linear dynamics. Exploration in non-linear dynamic scenarios is a promising direction, and we will consider it as part of our future work.\n\n\n\n**Q2.** The paper lacks discussion on the complexity or runtime implications introduced by k-means clustering, especially for large offline RL datasets. \\\n**R2.** Thank you for your comment. Clustering is a preprocessing step, performed only once and stored in a file. During training, we only need to look up the clustering results, so clustering will not impact the training speed of the model. Moreover, the need for clustering information exists only during the training phase and is not required during the testing phase. \\\nNevertheless, we measured the time spent on clustering across different datasets, and the results are shown in the table below. Please note that we have taken several measures to optimize the clustering process. For example, we set the number of clusters to the square root of the number of states, and configure appropriate clustering batch sizes. It can be observed that even during the preprocessing stage, the clustering operation does not consume significant time.\n\n| Datasets | Number of States | Time Cost (Seconds) |\n|---------------------------|------------------|----------------------|\n| Walker2d-Medium-V2 | 1,000,000 | 116 |\n| Walker2d-Medium-Replay-V2 | 302,000 | 75 |\n| Walker2d-Random-V2 | 1,000,000 | 97 |"}, {"comment_id": "NTvXlKsvCD", "replyto": "fjzVQ7SjMb", "author_type": "authors", "reviewer": null, "comment": "**Q5.** The competitive performance of baselines like CDE and HD raises questions about the necessity of the proposed contrastive mechanism, especially considering these methods achieve comparable results without any specific designs for handling low-return trajectories. While ContraDiff shows improvements in certain scenarios, the performance overlap with these simpler methods suggests that the advantages of explicitly handling low-return trajectories might be less significant than claimed. This observation calls for a more thorough analysis of what unique benefits, if any, are brought by the contrastive mechanism for low-return trajectory utilization. \\\n**R5.** Thank you for your comment. \\\n(1)Table 1 presents the comparison between ContraDiff and several recent SOTA baselines on suboptimal datasets, where ContraDiff achieves the best or second-best performance on 25 out of 27 datasets. This preliminarily demonstrates the advantage of ContraDiff in effectively leveraging negative samples. \\\n(2)Furthermore, to better showcase the advantages of ContraDiff, we compare it with SOTA methods designed to handle imbalanced data. We initialize the environment with **random low-return states** to create more challenging scenarios, aiming to evaluate the ability of ContraDiff and SOTA methods to handle situations where the agent is trapped in low-return states. Specifically, we investigate the capability of ContraDiff to transition from low-return states to high-return states. As is shown in Table 2, ContraDiff achieves the best performance on almost all (23 out of 27) datasets compared to other SOTA baselines. Since the difference between ContraDiff and the baseline lies in leveraging negative sample information effectively through contrastive learning, results in Table 2 further highlight that utilizing low-return samples can effectively help the model escape from low-return areas. \\\n(3)To provide a more intuitive explanation, we visualize in Figure 3 the ability of ContraDiff, AW+Diffuser, and their baseline, Diffuser, to escape from low-return states. As shown, when facing similar low-return states, AW+Diffuser achieves better results through its reweighting mechanism but still fails to fully escape the low-return region (terminating after 40 steps). In contrast, with the help of contrastive learning, ContraDiff effectively leverages negative sample information and maintains a healthy trajectory even after 60 steps, successfully escaping the low-return state. This visualization intuitively demonstrates the advantage of our proposed contrastive learning framework in fully utilizing negative sample information."}, {"comment_id": "yqjfRhhNaD", "replyto": "fjzVQ7SjMb", "author_type": "authors", "reviewer": null, "comment": "**Q3.** The paper's method of determining positive and negative samples relies heavily on hyperparameters (ξ, ζ, σ) without theoretical justification for their value ranges. The paper lacks a principled approach to determine these thresholds across different environments. A more systematic study on how different return thresholds affect the distribution of positive/negative samples and their impact on learning dynamics would strengthen the method's foundation. \\\n**R3.** Thank you for your comment. \\\n(1) The main idea of our paper is to introduce contrastive learning to enhance decision-making. Therefore, we focus on evaluating the effectiveness of the idea instead of justifying the hyperparameters. As shown in Table 1 and Table 2, even using the most rudimentary positive and negative sample selection method (i.e., dividing solely based on thresholds), ContraDiff still demonstrates significant advantages in most scenarios.\\\nAs for the choice of threshold, a priori can be made based on the state return in the dataset. For example, one can visualize the state-return distribution of the dataset, and take the return with the most dramatic return change as the threshold. For example, thresholds can be identified as turning points by calculating the rate of change (e.g., differences) or the second-order rate of change (acceleration)[1,2]. After that, we can fix the threshold and adjust $\\sigma$ for better results. \\\n(2) We have conducted the experiments of thresholds for positive and negative samples and hyperparameters (ξ, ζ, σ)  in Appendix A.11.  In summary,  analysis of the hyper-parameters indicates that their impact on model performance is moderate, and searching for the optimal hyper-parameters is easy. Please refer to Appendix A.11 for more detailed discussions. \\\n[1] Satopaa, V., Albrecht, J., Irwin, D., & Raghavan, B. (2011, June). Finding a\" kneedle\" in a haystack: Detecting knee points in system behavior. In 2011 31st international conference on distributed computing systems workshops (pp. 166-171). IEEE. \\\n[2] Silverman, B. W. (2018). Density estimation for statistics and data analysis. Routledge.\\\n\n**Q4.** The current state-level contrastive learning approach breaks the temporal correlation between states within trajectories, as it samples positive and negative states purely based on return values (SR) or clustering results (SRD). This design might lose important sequential information that could be better captured through trajectory-level contrastive learning. I encourage authors to explore trajectory-aligned state sampling strategies and conduct comparative experiments between state-level and trajectory-level contrastive learning to provide more solid empirical evidence for the design choices. \\\n**R4.** Thank you for your comment and suggestion. \n(1) We believe by using \"the important sequential information\", you are referring to the dynamic consistency. We want to clarify that we designed the cluster-based approach (i.e., SRD) to ensure dynamic consistency, as is described in Section 3.2.1. Specifically, for a state $s_t$, we first find the cluster it belongs to, marked as $C_t$. Next, we treat all the next states of states in $C_t$ as the positive candidates of $s_t$. As all the states in $C_t$ are grouped into the same cluster as $s_t$, we consider these states to be similar, and thus the next state of any in-cluster state of $C_t$ is also considered reachable from $s_t$. Thus, selecting the positive sample of $s_t$ from the next states of in-cluster states ensures the temporal consistency of the constructed contrastive samples. \\\n(2) Nevertheless, following your suggestion, we conducted experiments on trajectory-level contrastive learning (denoted as $\\text{ContraDiff}^{\\text{t}}$ ) on Walker2d-Rand-Exp, and the results are as follows. As can be observed, ContraDiff outperforms $\\text{ContraDiff}^{\\text{t}}$ in all scenarios, indicating that our state-level contrastive learning method does not significantly affect the sequential information.\n\n\n| Mix Ratio | $\\text{ContraDiff}^{~\\text{t}}$ | $\\text{ContraDiff}$ |\n|-----------|--------------------------------|---------------------|\n| 0.1 | 19.2 ± 2.2 | 20.2 ± 1.3 |\n| 0.2 | 27.7 ± 1.9 | 57.4 ± 0.7 |\n| 0.3 | 45.1 ± 2.9 | 78.4 ± 1.2 |"}, {"comment_id": "XPagqs7Oxo", "replyto": "fjzVQ7SjMb", "author_type": "authors", "reviewer": null, "comment": "**Q1.** The relationship between contrastive loss and RL loss needs further detailed justification. Intuitively, these two losses appear highly correlated, especially when considering low-return trajectories as negative samples. A numerical analysis comparing these losses across trajectories with different returns would be valuable to verify if they indeed show similar trends. If so, the contrastive loss might merely serve as an enhancement to the RL loss rather than a meaningful regularization term that needs to be balanced against the primary objective. \\\n**R1.** Thank you for your comments. \\\n(1) The contrastive loss and RL loss are independent. The contrastive loss (i.e., Eq.(13) and Eq.(9)) is derived from InfoNCE, aiming to increase the expected return of generated trajectories by maximizing the distance from low-return trajectories while minimizing the distance to high-return trajectories. This effectively pulls the distribution of generated trajectories toward high-return trajectories and pushes it away from low-return trajectories. In contrast, the RL loss (i.e., Eq.(11)) simply attempts to fit the distribution of trajectories across the entire dataset. \\\n(2) Nevertheless, following your suggestion, we visualized the magnitude of the contrastive loss with and without contrastive learning on Walker2d-Rand-Exp-0.1, and presented this visualization in Appendix A.14 in our revised paper. Note that:\n$$\n Eq.(13) = \\\\mathbb{E}_{t>0, i \\sim [1, N]} [\\\\sum\\_{h=t}^{t+H}\\\\frac{1}{h+1}\\\\mathcal{L}\\_{h}^i ],\n$$\nand\n$$\n\\\\mathcal{L}\\_{h}^i =-\\\\log \\\\frac{ \\\\sum\\_{k=0}^{\\\\kappa} \\\\exp( \\\\text{sim}( {f}(\\\\hat{s}\\_h^{i,0}), {f}({s}\\_h^{+}) ) / T ) }{ \\\\sum\\_{k=0}^{\\\\kappa} \\\\exp( \\\\text{sim}( {f}(\\\\hat{s}\\_h^{i,0}), {f}({s}\\_h^{-}) ) / T ) }.\n$$\nThe closer the distance to high-return states and the farther the distance to low-return states, the smaller the value of Eq. (13). It can be observed that, regardless of the use of contrastive learning, the RL loss shows a downward trend and eventually stabilizes. However, for the value of Eq.(13), when contrastive learning is not used, the contrastive loss remains large and fluctuates significantly. On the other hand, when contrastive learning is applied, the contrastive loss decreases. Moreover, ContraDiff shows a lower value on RL Loss + Eq.(13) than ContrasDiff w/o CL. This demonstrates that (1) the RL loss and contrastive loss are independent, with the contrastive loss serving as a meaningful regularization term; and (2) contrastive learning indeed maximizes the distance from low-return trajectories while minimizing the distance to high-return trajectories, as further corroborated by the analyses in Sections 4.3.\n\n**Q2.** The methodological differences illustrated in Figure 1(d)(e)(f) lack sufficient justification.This figure fails to demonstrate why pushing away from low-return states is fundamentally superior to upweighting high-return trajectories, or why return-based contrastive learning outperforms traditional trajectory-based approaches. \\\n**R2.** Thank you for your comment. As shown in Figure 1 (d), up-weighting high-return trajectories can only achieve the high-return trajectory ($s$ -> $G$) that exists within the dataset. However, when the agent is trapped in low-return states, e.g., $s_t$ in Figure 1 (d), there is no corresponding high-return example in the dataset (e.g., $s_t$ -> $G$), and up-weighting high-return trajectories fails. Our method, as shown in Figure 1 (e), by constraining the future states, takes the state from the high-return trajectory ($s$ -> $G$) as the target state at $s_t$, thereby enabling a transition from $s_t$ to a high-return trajectory. Besides, there are also some other methods that use contrastive learning in reinforcement learning (Figure (f)). However, they focus on learning better representations. For example, they aim to minimize the similarity between representations of adjacent states within the same trajectory while maximizing the distance between representations of states from different trajectories, as illustrated in Figure 1(f)."}, {"comment_id": "Wf8yEdFEy4", "replyto": "XMOaOigOQo", "author_type": "authors", "reviewer": null, "comment": "We would like to express our sincere appreciation to all reviewers for your constructive feedback. We have made the following adjustments and conducted additional experiments based on the your suggestions, all changes are marked in blue in the revised paper:\n1. We have enhanced the descriptions of Figure 1.\n2. We have enhanced the discussions of Related Works.\n3. We have revised several sentences in our paper for better understanding, and have corrected the typos in our paper."}], "meta_review": {"metareview": "This paper proposes a method for offline RL that focuses on improving performance in regimes where there is a substantial proportion of data present from sub-optimal trajectories (a large ratio of low-return trajectories to high-return trajectories). The paper proposes a contrastive learning method to \"constrain\" (penalize in a loss function) the learned policy to high-return states and away from low-return states. The main claim is that the proposed method results in a performance improvement over comparable related work on sub-optimal datasets. The reviewers mostly agreed that the paper provided sufficient evidence to support this claim, although one reviewer pointed out a potentially relevant missing performance comparison (see additional comments below for more details). The authors responded by explaining the differences, including a brief mention of the differences in an updated submission, and providing a small amount of additional experiments that compared the results against the proposed work.\n\nI concur with the reviewer that the mostly-missing performance comparison was plausibly relevant, and also observed that the preliminary experimental evidence that the authors followed up with provides a bit more evidence for the claim. I also found that the author's qualitative explanation of the difference plausible -- the related work is indeed applying contrastive learning to, at least on the surface, a different, (although potentially more general) purpose. I think this paper is acceptable, but would strongly benefit from a more comprehensive comparison on the remaining settings (the authors included additional experiments of the related work, QGPO, on 3/27 of the experimental settings in Table 2.) This evidence would provide a more complete performance comparison on the settings studied in the paper.", "justification_for_why_not_higher_score": null, "justification_for_why_not_lower_score": null}} diff --git a/benchmarks/openreview_benchmark/eval_history.jsonl b/benchmarks/openreview_benchmark/eval_history.jsonl new file mode 100644 index 0000000..aa3e22b --- /dev/null +++ b/benchmarks/openreview_benchmark/eval_history.jsonl @@ -0,0 +1 @@ +{"generated_at": "2026-04-23T10:43:02.139100+00:00", "benchmark": "benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl", "results_dir": "benchmarks/openreview_benchmark/results/reviews", "judge_model": "claude-sonnet-4-6", "judge_provider": "openai", "method_key": null, "paper_ids_evaluated": ["7b2JrzdLhA", "ajxAJ8GUX4", "BC4lIvfSzv", "BM9qfolt6p", "d4qMoUSMLT", "jj7b3p5kLY", "kOJf7Dklyv", "M992mjgKzI", "SFNqrHQTEP", "XMOaOigOQo"], "num_papers": 10, "mean": {"precision": 0.37667, "recall": 0.745, "f1": 0.4643}, "full_report": "reports/eval_20260423T104302Z.json"} diff --git a/benchmarks/openreview_benchmark/reports/eval_20260423T104302Z.json b/benchmarks/openreview_benchmark/reports/eval_20260423T104302Z.json new file mode 100644 index 0000000..563ee5d --- /dev/null +++ b/benchmarks/openreview_benchmark/reports/eval_20260423T104302Z.json @@ -0,0 +1,150 @@ +{ + "generated_at": "2026-04-23T10:43:02.136255+00:00", + "benchmark": "benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl", + "results_dir": "benchmarks/openreview_benchmark/results/reviews", + "judge_model": "claude-sonnet-4-6", + "judge_provider": "openai", + "method_key": null, + "num_papers": 10, + "mean": { + "precision": 0.37667, + "recall": 0.745, + "f1": 0.4643 + }, + "per_paper": [ + { + "precision": 0.5, + "recall": 0.75, + "f1": 0.6, + "num_predictions": 12, + "num_human_reviews": 4, + "num_predictions_matched": 6, + "num_reviews_covered": 3, + "num_nonempty_reviews": 4, + "judge_model": "claude-sonnet-4-6", + "paper_id": "7b2JrzdLhA", + "title": "Graph Neural Ricci Flow: Evolving Feature from a Curvature P" + }, + { + "precision": 0.25, + "recall": 1.0, + "f1": 0.4, + "num_predictions": 8, + "num_human_reviews": 4, + "num_predictions_matched": 2, + "num_reviews_covered": 4, + "num_nonempty_reviews": 4, + "judge_model": "claude-sonnet-4-6", + "paper_id": "ajxAJ8GUX4", + "title": "Learning Geometric Reasoning Networks For Robot Task And Mot" + }, + { + "precision": 0.3, + "recall": 1.0, + "f1": 0.4615, + "num_predictions": 10, + "num_human_reviews": 4, + "num_predictions_matched": 3, + "num_reviews_covered": 4, + "num_nonempty_reviews": 4, + "judge_model": "claude-sonnet-4-6", + "paper_id": "BC4lIvfSzv", + "title": "Generative Representational Instruction Tuning" + }, + { + "precision": 0.1111, + "recall": 0.75, + "f1": 0.1935, + "num_predictions": 9, + "num_human_reviews": 4, + "num_predictions_matched": 1, + "num_reviews_covered": 3, + "num_nonempty_reviews": 4, + "judge_model": "claude-sonnet-4-6", + "paper_id": "BM9qfolt6p", + "title": "LucidPPN: Unambiguous Prototypical Parts Network for User-ce" + }, + { + "precision": 0.5, + "recall": 0.75, + "f1": 0.6, + "num_predictions": 8, + "num_human_reviews": 4, + "num_predictions_matched": 4, + "num_reviews_covered": 3, + "num_nonempty_reviews": 4, + "judge_model": "claude-sonnet-4-6", + "paper_id": "d4qMoUSMLT", + "title": "Efficient Training of Neural Stochastic Differential Equatio" + }, + { + "precision": 0.5, + "recall": 0.6, + "f1": 0.5455, + "num_predictions": 8, + "num_human_reviews": 5, + "num_predictions_matched": 4, + "num_reviews_covered": 3, + "num_nonempty_reviews": 5, + "judge_model": "claude-sonnet-4-6", + "paper_id": "jj7b3p5kLY", + "title": "The AdEMAMix Optimizer: Better, Faster, Older" + }, + { + "precision": 0.75, + "recall": 0.6, + "f1": 0.6667, + "num_predictions": 8, + "num_human_reviews": 5, + "num_predictions_matched": 6, + "num_reviews_covered": 3, + "num_nonempty_reviews": 5, + "judge_model": "claude-sonnet-4-6", + "paper_id": "kOJf7Dklyv", + "title": "Air Quality Prediction with Physics-Guided Dual Neural ODEs " + }, + { + "precision": 0.0, + "recall": 0.0, + "f1": 0.0, + "num_predictions": 8, + "num_human_reviews": 4, + "num_predictions_matched": 0, + "num_reviews_covered": 0, + "num_nonempty_reviews": 4, + "judge_model": "claude-sonnet-4-6", + "paper_id": "M992mjgKzI", + "title": "OGBench: Benchmarking Offline Goal-Conditioned RL" + }, + { + "precision": 0.5556, + "recall": 1.0, + "f1": 0.7143, + "num_predictions": 9, + "num_human_reviews": 4, + "num_predictions_matched": 5, + "num_reviews_covered": 4, + "num_nonempty_reviews": 4, + "judge_model": "claude-sonnet-4-6", + "paper_id": "SFNqrHQTEP", + "title": "NExUME: Adaptive Training and Inference for DNNs under Inter" + }, + { + "precision": 0.3, + "recall": 1.0, + "f1": 0.4615, + "num_predictions": 10, + "num_human_reviews": 3, + "num_predictions_matched": 3, + "num_reviews_covered": 3, + "num_nonempty_reviews": 3, + "judge_model": "claude-sonnet-4-6", + "paper_id": "XMOaOigOQo", + "title": "ContraDiff: Planning Towards High Return States via Contrast" + } + ], + "lock": { + "locked_for_repo": "2026-04-23", + "notes": "Duplicate of the eval run identified by generated_at; paths are repo-relative for portability. Raw per-paper review JSON under results/reviews/ remains gitignored; this file is the committed scorecard." + } +} diff --git a/benchmarks/openreview_benchmark/scripts/collect_openreview.py b/benchmarks/openreview_benchmark/scripts/collect_openreview.py new file mode 100644 index 0000000..4a3c0b0 --- /dev/null +++ b/benchmarks/openreview_benchmark/scripts/collect_openreview.py @@ -0,0 +1,148 @@ +"""Collect paper forums (reviews, rebuttals, decisions) from the OpenReview API. + +Usage: + python benchmarks/openreview_benchmark/scripts/collect_openreview.py --venue ICLR.cc/2025/Conference --limit 10 + python benchmarks/openreview_benchmark/scripts/collect_openreview.py --forum-ids PwxYoMvmvy,HX5ujdsSon +""" + +import argparse +import json +import sys +import time +from pathlib import Path + +_SCRIPT_DIR = Path(__file__).resolve().parent +_TRACK_ROOT = _SCRIPT_DIR.parent +_DATA_DIR = _TRACK_ROOT / "data" +OUTPUT_DIR = _DATA_DIR / "openreview_raw" + +if str(_SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(_SCRIPT_DIR)) + +from openreview_http import API_BASE_URL, create_openreview_session + + +def fetch_forum(session, forum_id: str) -> dict | None: + """Fetch all notes for a single paper forum.""" + resp = session.get( + f"{API_BASE_URL}/notes", + params={"forum": forum_id}, + timeout=30, + ) + if resp.status_code != 200: + print(f" ERROR {resp.status_code} for forum {forum_id}: {resp.text[:200]}") + return None + return resp.json() + + +def fetch_accepted_paper_ids(session, venue_id: str, limit: int) -> list[str]: + """Fetch forum IDs of accepted papers for a venue.""" + paper_ids = [] + offset = 0 + batch_size = min(limit, 50) + + while len(paper_ids) < limit: + resp = session.get( + f"{API_BASE_URL}/notes", + params={ + "content.venueid": venue_id, + "limit": batch_size, + "offset": offset, + }, + timeout=30, + ) + if resp.status_code != 200: + print(f" ERROR {resp.status_code} fetching paper list: {resp.text[:200]}") + break + + data = resp.json() + notes = data.get("notes", []) + if not notes: + break + + for note in notes: + paper_ids.append(note["id"]) + if len(paper_ids) >= limit: + break + + offset += batch_size + time.sleep(0.5) + + return paper_ids + + +def save_forum(forum_data: dict, forum_id: str, output_dir: Path) -> Path: + """Save raw forum JSON to disk.""" + output_dir.mkdir(parents=True, exist_ok=True) + out_path = output_dir / f"{forum_id}.json" + with open(out_path, "w", encoding="utf-8") as f: + json.dump(forum_data, f, indent=2, ensure_ascii=False) + return out_path + + +def main(): + parser = argparse.ArgumentParser(description="Collect OpenReview paper forums.") + parser.add_argument( + "--venue", + type=str, + help="Venue ID, e.g. ICLR.cc/2025/Conference", + ) + parser.add_argument( + "--forum-ids", + type=str, + help="Comma-separated forum IDs to fetch directly", + ) + parser.add_argument( + "--limit", + type=int, + default=10, + help="Max number of papers to fetch when using --venue (default: 10)", + ) + parser.add_argument( + "--output-dir", + type=str, + default=str(OUTPUT_DIR), + help=f"Output directory for raw JSON files (default: {OUTPUT_DIR})", + ) + parser.add_argument( + "--delay", + type=float, + default=1.0, + help="Seconds to wait between API requests (default: 1.0)", + ) + args = parser.parse_args() + + if not args.venue and not args.forum_ids: + parser.error("Provide either --venue or --forum-ids") + + output_dir = Path(args.output_dir) + session = create_openreview_session(mode="api", warmup_timeout=15.0) + print("Session established.") + + if args.forum_ids: + forum_ids = [fid.strip() for fid in args.forum_ids.split(",")] + else: + print(f"Fetching up to {args.limit} accepted paper IDs from {args.venue}...") + forum_ids = fetch_accepted_paper_ids(session, args.venue, args.limit) + print(f"Found {len(forum_ids)} paper IDs.") + + print(f"Collecting {len(forum_ids)} forums...\n") + collected = 0 + for i, fid in enumerate(forum_ids): + print(f"[{i + 1}/{len(forum_ids)}] Fetching {fid}...") + data = fetch_forum(session, fid) + if data and data.get("notes"): + path = save_forum(data, fid, output_dir) + n_notes = len(data["notes"]) + print(f" Saved {n_notes} notes to {path}") + collected += 1 + else: + print(f" Skipped (no data)") + if i < len(forum_ids) - 1: + time.sleep(args.delay) + + print(f"\nDone. Collected {collected}/{len(forum_ids)} forums in {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/openreview_benchmark/scripts/download_openreview_pdfs.py b/benchmarks/openreview_benchmark/scripts/download_openreview_pdfs.py new file mode 100644 index 0000000..6c33216 --- /dev/null +++ b/benchmarks/openreview_benchmark/scripts/download_openreview_pdfs.py @@ -0,0 +1,102 @@ +"""Download OpenReview PDFs for benchmark papers (local files for openaireview review). + +OpenReview PDF URLs are not handled by ``parse_document`` like arXiv; this script +fetches PDFs using the shared session helper in ``openreview_http``. + +PDFs are written to ``benchmarks/openreview_benchmark/data/openreview_pdfs/`` by default (gitignored). +If OpenReview returns HTTP 429, rerun for the missing ``--forum-ids`` with a larger ``--delay``. + +Usage: + python benchmarks/openreview_benchmark/scripts/download_openreview_pdfs.py + python benchmarks/openreview_benchmark/scripts/download_openreview_pdfs.py --forum-ids jj7b3p5kLY,kOJf7Dklyv +""" + +import argparse +import json +import sys +import time +from pathlib import Path + +_SCRIPT_DIR = Path(__file__).resolve().parent +_TRACK_ROOT = _SCRIPT_DIR.parent +_DATA_DIR = _TRACK_ROOT / "data" +DEFAULT_BENCHMARK = _DATA_DIR / "openreview_benchmark.jsonl" +DEFAULT_OUT = _DATA_DIR / "openreview_pdfs" + +if str(_SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(_SCRIPT_DIR)) + +from openreview_http import create_openreview_session, pdf_download_url + + +def download_pdf(session, forum_id: str, out_path: Path) -> bool: + url = pdf_download_url(forum_id) + resp = session.get(url, timeout=120) + if resp.status_code != 200: + print(f" ERROR {resp.status_code} for {forum_id}") + return False + ct = resp.headers.get("Content-Type", "") + if "pdf" not in ct.lower() and resp.content[:4] != b"%PDF": + print(f" WARN {forum_id}: response may not be PDF (Content-Type: {ct})") + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_bytes(resp.content) + print(f" Saved {out_path} ({len(resp.content)} bytes)") + return True + + +def main() -> None: + parser = argparse.ArgumentParser(description="Download OpenReview PDFs for benchmark papers") + parser.add_argument( + "--benchmark", + type=Path, + default=DEFAULT_BENCHMARK, + help=f"JSONL with paper_id field (default: {DEFAULT_BENCHMARK})", + ) + parser.add_argument( + "--forum-ids", + type=str, + default=None, + help="Comma-separated forum ids (overrides --benchmark)", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=DEFAULT_OUT, + help=f"Directory for PDF files (default: {DEFAULT_OUT})", + ) + parser.add_argument( + "--delay", + type=float, + default=3.0, + help="Seconds between downloads (default: 3; increase if you see HTTP 429)", + ) + args = parser.parse_args() + + if args.forum_ids: + ids = [x.strip() for x in args.forum_ids.split(",") if x.strip()] + else: + if not args.benchmark.exists(): + print(f"Missing {args.benchmark}", file=sys.stderr) + sys.exit(1) + ids = [] + with open(args.benchmark, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + ids.append(json.loads(line)["paper_id"]) + + print(f"Downloading {len(ids)} PDFs to {args.output_dir}\n") + session = create_openreview_session(mode="pdf", warmup_timeout=60.0) + ok = 0 + for i, fid in enumerate(ids): + out = args.output_dir / f"{fid}.pdf" + if download_pdf(session, fid, out): + ok += 1 + if i < len(ids) - 1: + time.sleep(args.delay) + print(f"\nDone: {ok}/{len(ids)} OK") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py b/benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py new file mode 100644 index 0000000..d01af0a --- /dev/null +++ b/benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py @@ -0,0 +1,241 @@ +"""Evaluate OpenAIReview outputs against openreview_benchmark.jsonl using LLM judge. + +Metrics (per paper): precision, recall, F1 — see ``reviewer.evaluate_openreview``. + +Usage: + # After running reviews with matching slugs (use --name ): + # openaireview review paper.pdf --name jj7b3p5kLY --method zero_shot + + python benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py \\ + --benchmark benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl \\ + --results-dir ./review_results + + # Full JSON under the track (gitignored) + append a line to eval_history.jsonl: + python benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py \\ + --results-dir benchmarks/openreview_benchmark/results/reviews \\ + --save-full-report + + # Or set an explicit report path (--save-full-report not needed): + python benchmarks/openreview_benchmark/scripts/evaluate_openreview_benchmark.py \\ + --results-dir ./review_results \\ + --output benchmarks/openreview_benchmark/results/my_run.json + +Environment: + OPENAI_API_KEY Native OpenAI (set REVIEW_PROVIDER=openai if multiple keys) + REVIEW_PROVIDER e.g. openai + OPENREVIEW_JUDGE_MODEL Judge model (default: gpt-4o-mini) +""" + +import argparse +import json +import os +import sys +from datetime import datetime, timezone +from pathlib import Path + +from dotenv import load_dotenv + +load_dotenv() + +_SCRIPT_DIR = Path(__file__).resolve().parent +_TRACK_ROOT = _SCRIPT_DIR.parent +_DATA_DIR = _TRACK_ROOT / "data" +_REPO_ROOT = _SCRIPT_DIR.parents[3] +_EVAL_HISTORY = _TRACK_ROOT / "eval_history.jsonl" +sys.path.insert(0, str(_REPO_ROOT / "src")) + +from reviewer.evaluate_openreview import ( # noqa: E402 + comments_from_results_json, + evaluate_openreview_pooled, +) + + +def _append_eval_history( + *, + benchmark: Path, + results_dir: Path, + judge_model: str, + judge_provider: str | None, + method_key: str | None, + paper_ids: list[str], + mean_p: float, + mean_r: float, + mean_f: float, + full_report_rel: str | None, +) -> None: + """One JSON line per run for REPORT-style tables (file lives under the track root).""" + row = { + "generated_at": datetime.now(timezone.utc).isoformat(), + "benchmark": str(benchmark.resolve()), + "results_dir": str(results_dir.resolve()), + "judge_model": judge_model, + "judge_provider": judge_provider, + "method_key": method_key, + "paper_ids_evaluated": paper_ids, + "num_papers": len(paper_ids), + "mean": { + "precision": round(mean_p, 6), + "recall": round(mean_r, 6), + "f1": round(mean_f, 6), + }, + "full_report": full_report_rel, + } + with open(_EVAL_HISTORY, "a", encoding="utf-8") as f: + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + +def main(): + parser = argparse.ArgumentParser(description="LLM-judge evaluation for OpenReview track") + parser.add_argument( + "--benchmark", + type=Path, + default=_DATA_DIR / "openreview_benchmark.jsonl", + help="Path to openreview_benchmark.jsonl", + ) + parser.add_argument( + "--results-dir", + type=Path, + default=Path("./review_results"), + help="Directory with .json from openaireview review", + ) + parser.add_argument( + "--method-key", + default=None, + help="Key under methods in result JSON (default: first method)", + ) + parser.add_argument( + "--judge-model", + default=os.environ.get("OPENREVIEW_JUDGE_MODEL", "gpt-4o-mini"), + help="Model for LLM judge", + ) + parser.add_argument( + "--provider", + default=os.environ.get("REVIEW_PROVIDER"), + help="Provider for judge API (e.g. openai)", + ) + parser.add_argument( + "--papers", + nargs="*", + default=None, + help="Optional paper_ids to evaluate (default: all in benchmark with results)", + ) + parser.add_argument( + "--output", + type=Path, + default=None, + help="Write full eval report JSON (per-paper metrics + means + run metadata)", + ) + parser.add_argument( + "--save-full-report", + action="store_true", + help=f"Write full report to {_TRACK_ROOT / 'results' / 'eval_.json'} (implies --results-dir is usually under results/reviews/)", + ) + parser.add_argument( + "--no-eval-history", + action="store_true", + help="Do not append a summary line to benchmarks/openreview_benchmark/eval_history.jsonl", + ) + args = parser.parse_args() + + papers = [] + with open(args.benchmark, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + papers.append(json.loads(line)) + + if args.papers: + want = set(args.papers) + papers = [p for p in papers if p["paper_id"] in want] + + rows = [] + for paper in papers: + pid = paper["paper_id"] + result_path = args.results_dir / f"{pid}.json" + if not result_path.exists(): + print(f" SKIP {pid}: no {result_path}") + continue + data = json.loads(result_path.read_text(encoding="utf-8")) + try: + preds = comments_from_results_json(data, method_key=args.method_key) + except KeyError as e: + print(f" ERROR {pid}: {e}") + continue + + m = evaluate_openreview_pooled( + preds, + paper, + judge_model=args.judge_model, + judge_provider=args.provider, + ) + m["paper_id"] = pid + m["title"] = (paper.get("title") or "")[:60] + rows.append(m) + print( + f" {pid}: P={m['precision']:.2f} R={m['recall']:.2f} F1={m['f1']:.2f} " + f"(preds={m['num_predictions']}, covered_reviews={m['num_reviews_covered']}/{m['num_nonempty_reviews']})" + ) + + if not rows: + print("No papers evaluated. Place review JSON files as .json in results-dir.") + sys.exit(1) + + avg_p = sum(r["precision"] for r in rows) / len(rows) + avg_r = sum(r["recall"] for r in rows) / len(rows) + avg_f = sum(r["f1"] for r in rows) / len(rows) + print(f"\nMean over {len(rows)} papers: precision={avg_p:.3f} recall={avg_r:.3f} f1={avg_f:.3f}") + + report = { + "generated_at": datetime.now(timezone.utc).isoformat(), + "benchmark": str(args.benchmark.resolve()), + "results_dir": str(args.results_dir.resolve()), + "judge_model": args.judge_model, + "judge_provider": args.provider, + "method_key": args.method_key, + "num_papers": len(rows), + "mean": { + "precision": round(avg_p, 6), + "recall": round(avg_r, 6), + "f1": round(avg_f, 6), + }, + "per_paper": rows, + } + + full_report_path: Path | None = None + if args.output is not None: + full_report_path = args.output + elif args.save_full_report: + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + full_report_path = _TRACK_ROOT / "results" / f"eval_{ts}.json" + + full_report_rel: str | None = None + if full_report_path is not None: + full_report_path.parent.mkdir(parents=True, exist_ok=True) + full_report_path.write_text( + json.dumps(report, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + try: + full_report_rel = str(full_report_path.resolve().relative_to(_TRACK_ROOT.resolve())) + except ValueError: + full_report_rel = str(full_report_path.resolve()) + print(f"\nWrote full report to {full_report_path}") + + if not args.no_eval_history: + _append_eval_history( + benchmark=args.benchmark, + results_dir=args.results_dir, + judge_model=args.judge_model, + judge_provider=args.provider, + method_key=args.method_key, + paper_ids=[r["paper_id"] for r in rows], + mean_p=avg_p, + mean_r=avg_r, + mean_f=avg_f, + full_report_rel=full_report_rel, + ) + print(f"Appended summary to {_EVAL_HISTORY}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/openreview_benchmark/scripts/filter_candidates.py b/benchmarks/openreview_benchmark/scripts/filter_candidates.py new file mode 100644 index 0000000..82e8127 --- /dev/null +++ b/benchmarks/openreview_benchmark/scripts/filter_candidates.py @@ -0,0 +1,230 @@ +"""Find good benchmark candidate papers by sampling and ranking review quality. + +Steps: +1. Fetch all accepted paper IDs from given venues (lightweight metadata only) +2. Randomly sample N papers from the full set +3. Fetch full forums for the sampled papers +4. Rank by average review text length across structured fields +5. Print a ranked table for manual curation + +Usage: + python benchmarks/openreview_benchmark/scripts/filter_candidates.py + python benchmarks/openreview_benchmark/scripts/filter_candidates.py --sample-size 30 --seed 42 +""" + +import argparse +import json +import random +import sys +import time +from pathlib import Path + +_SCRIPT_DIR = Path(__file__).resolve().parent +_TRACK_ROOT = _SCRIPT_DIR.parent +_DATA_DIR = _TRACK_ROOT / "data" + +if str(_SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(_SCRIPT_DIR)) + +from openreview_http import API_BASE_URL, create_openreview_session, rewarm_session + +VENUES = [ + "ICLR.cc/2025/Conference", + "NeurIPS.cc/2025/Conference", +] + +REVIEW_FIELDS = ["summary", "strengths", "weaknesses", "questions"] + + +def fetch_all_paper_ids(session, venue_id: str) -> list[dict]: + """Fetch all accepted paper IDs and basic metadata for a venue.""" + papers = [] + offset = 0 + batch_size = 200 + + while True: + resp = session.get( + f"{API_BASE_URL}/notes", + params={ + "content.venueid": venue_id, + "limit": batch_size, + "offset": offset, + }, + timeout=30, + ) + if resp.status_code != 200: + print(f" ERROR {resp.status_code} at offset {offset}") + break + + data = resp.json() + notes = data.get("notes", []) + if not notes: + break + + for note in notes: + content = note.get("content", {}) + title_field = content.get("title", {}) + area_field = content.get("primary_area", {}) + papers.append({ + "paper_id": note["id"], + "venue": venue_id, + "title": title_field.get("value", "") if isinstance(title_field, dict) else str(title_field), + "primary_area": area_field.get("value", "") if isinstance(area_field, dict) else str(area_field), + }) + + offset += batch_size + print(f" Fetched {len(papers)} IDs so far (offset={offset})...") + time.sleep(0.5) + + return papers + + +def fetch_forum_notes(session, forum_id: str) -> list[dict] | None: + """Fetch all notes for a forum, retrying once with rewarm on 403.""" + for attempt in range(2): + resp = session.get( + f"{API_BASE_URL}/notes", + params={"forum": forum_id}, + timeout=30, + ) + if resp.status_code == 200: + return resp.json().get("notes", []) + if resp.status_code == 403 and attempt == 0: + rewarm_session(session, timeout=15.0) + time.sleep(1) + continue + break + return None + + +def score_paper(notes: list[dict]) -> dict: + """Compute review quality metrics from forum notes.""" + reviews = [] + has_author_response = False + + for note in notes: + invitations = " ".join(note.get("invitations", [])) + if "Official_Review" in invitations: + reviews.append(note) + elif "Official_Comment" in invitations: + sigs = " ".join(note.get("signatures", [])) + if "Authors" in sigs: + has_author_response = True + + if not reviews: + return {"num_reviews": 0, "avg_review_length": 0, "has_author_response": False} + + review_lengths = [] + for review in reviews: + content = review.get("content", {}) + total = 0 + for field in REVIEW_FIELDS: + val = content.get(field, {}) + text = val.get("value", "") if isinstance(val, dict) else str(val or "") + total += len(text) + review_lengths.append(total) + + return { + "num_reviews": len(reviews), + "avg_review_length": sum(review_lengths) // len(review_lengths), + "min_review_length": min(review_lengths), + "max_review_length": max(review_lengths), + "has_author_response": has_author_response, + } + + +def main(): + parser = argparse.ArgumentParser(description="Find benchmark candidate papers.") + parser.add_argument( + "--sample-size", + type=int, + default=50, + help="Number of papers to sample per venue (default: 50)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility (default: 42)", + ) + parser.add_argument( + "--delay", + type=float, + default=1.0, + help="Seconds between forum fetches (default: 1.0)", + ) + args = parser.parse_args() + + random.seed(args.seed) + session = create_openreview_session(mode="api", warmup_timeout=15.0) + print("Session established.\n") + + # Step 1: Fetch all paper IDs from each venue + all_papers = [] + for venue in VENUES: + print(f"Fetching paper IDs from {venue}...") + papers = fetch_all_paper_ids(session, venue) + print(f" Total: {len(papers)} papers\n") + all_papers.extend(papers) + + print(f"Total papers across all venues: {len(all_papers)}\n") + + # Step 2: Random sample + sample_size = min(args.sample_size * len(VENUES), len(all_papers)) + sampled = random.sample(all_papers, sample_size) + print(f"Randomly sampled {len(sampled)} papers.\n") + + # Step 3: Fetch forums and score + print("Fetching forums and scoring review quality...\n") + candidates = [] + for i, paper in enumerate(sampled): + pid = paper["paper_id"] + print(f" [{i + 1}/{len(sampled)}] {pid}: {paper['title'][:50]}...") + notes = fetch_forum_notes(session, pid) + if notes is None: + print(f" SKIPPED (fetch failed)") + continue + scores = score_paper(notes) + paper.update(scores) + candidates.append(paper) + if i < len(sampled) - 1: + time.sleep(args.delay) + + # Step 4: Filter and rank + # Require: at least 1 author response, at least 3 reviews + filtered = [ + c for c in candidates + if c["has_author_response"] and c["num_reviews"] >= 3 + ] + filtered.sort(key=lambda x: x["avg_review_length"], reverse=True) + + # Step 5: Print results + print(f"\n{'=' * 120}") + print(f"TOP CANDIDATES (filtered: {len(filtered)} of {len(candidates)} sampled)") + print(f"{'=' * 120}") + print( + f"{'Rank':<5} {'Venue':<15} {'Reviews':<8} {'AvgLen':<8} {'MinLen':<8} " + f"{'Primary Area':<35} {'Title'}" + ) + print("-" * 120) + + for i, c in enumerate(filtered[:30]): + venue_short = c["venue"].split("/")[0] + area = (c["primary_area"] or "")[:33] + title = c["title"][:55] + print( + f"{i + 1:<5} {venue_short:<15} {c['num_reviews']:<8} " + f"{c['avg_review_length']:<8} {c['min_review_length']:<8} " + f"{area:<35} {title}" + ) + + # Save full results for reference + out_path = _DATA_DIR / "candidate_papers.json" + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w", encoding="utf-8") as f: + json.dump(filtered, f, indent=2, ensure_ascii=False) + print(f"\nFull results saved to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/openreview_benchmark/scripts/normalize_openreview.py b/benchmarks/openreview_benchmark/scripts/normalize_openreview.py new file mode 100644 index 0000000..4e732bd --- /dev/null +++ b/benchmarks/openreview_benchmark/scripts/normalize_openreview.py @@ -0,0 +1,222 @@ +"""Normalize raw OpenReview forum JSON into benchmark JSONL format. + +Reads raw forum JSON files (from collect_openreview.py) and produces a single +JSONL file with one paper per line, following a schema designed for the +OpenReview benchmark track. + +Usage: + python benchmarks/openreview_benchmark/scripts/normalize_openreview.py + python benchmarks/openreview_benchmark/scripts/normalize_openreview.py --raw-dir benchmarks/openreview_benchmark/data/openreview_raw --output benchmarks/openreview_benchmark/data/openreview_benchmark.jsonl +""" + +import argparse +import json +from pathlib import Path + +_SCRIPT_DIR = Path(__file__).resolve().parent +_TRACK_ROOT = _SCRIPT_DIR.parent +_DATA_DIR = _TRACK_ROOT / "data" +RAW_DIR = _DATA_DIR / "openreview_raw" +OUTPUT_PATH = _DATA_DIR / "openreview_benchmark.jsonl" + + +def classify_note(note: dict) -> str: + """Classify a note by its invitation type.""" + if note.get("replyto") is None: + return "submission" + invitations = " ".join(note.get("invitations", [])) + if "Official_Review" in invitations: + return "review" + if "Meta_Review" in invitations: + return "meta_review" + if "Decision" in invitations: + return "decision" + if "Official_Comment" in invitations: + return "comment" + if "Rebuttal" in invitations: + return "rebuttal" + return "other" + + +def extract_value(field) -> str | list | None: + """Extract the 'value' from an OpenReview content field.""" + if isinstance(field, dict): + return field.get("value") + return field + + +def extract_reviewer_id(note: dict) -> str: + """Extract anonymous reviewer ID from signatures.""" + sigs = note.get("signatures", []) + if sigs: + return sigs[0].split("/")[-1] + return "Unknown" + + +def extract_author_type(note: dict) -> str: + """Determine whether a comment is from authors, a reviewer, or an AC.""" + sigs = " ".join(note.get("signatures", [])) + if "Authors" in sigs: + return "authors" + if "Area_Chair" in sigs: + return "area_chair" + if "Reviewer" in sigs: + return "reviewer" + return "other" + + +def normalize_review(note: dict) -> dict: + """Convert a raw review note into the benchmark review schema.""" + content = note.get("content", {}) + return { + "review_id": note["id"], + "reviewer": extract_reviewer_id(note), + "rating": extract_value(content.get("rating", {})), + "confidence": extract_value(content.get("confidence", {})), + "soundness": extract_value(content.get("soundness", {})), + "presentation": extract_value(content.get("presentation", {})), + "contribution": extract_value(content.get("contribution", {})), + "summary": extract_value(content.get("summary", {})), + "strengths": extract_value(content.get("strengths", {})), + "weaknesses": extract_value(content.get("weaknesses", {})), + "questions": extract_value(content.get("questions", {})), + } + + +def normalize_comment(note: dict) -> dict: + """Convert a raw comment/rebuttal note into the benchmark discussion schema.""" + content = note.get("content", {}) + text = extract_value(content.get("comment", {})) or extract_value( + content.get("rebuttal", {}) + ) + return { + "comment_id": note["id"], + "replyto": note.get("replyto"), + "author_type": extract_author_type(note), + "reviewer": extract_reviewer_id(note) if "Reviewer" in " ".join(note.get("signatures", [])) else None, + "comment": text or "", + } + + +def normalize_forum(raw_data: dict) -> dict | None: + """Convert a full raw forum into the benchmark paper schema.""" + notes = raw_data.get("notes", []) + if not notes: + return None + + submission = None + reviews = [] + discussions = [] + meta_review = None + decision = None + + for note in notes: + note_type = classify_note(note) + if note_type == "submission": + submission = note + elif note_type == "review": + reviews.append(normalize_review(note)) + elif note_type in ("comment", "rebuttal"): + discussions.append(normalize_comment(note)) + elif note_type == "meta_review": + content = note.get("content", {}) + meta_review = { + "metareview": extract_value(content.get("metareview", {})), + "justification_for_why_not_higher_score": extract_value( + content.get("justification_for_why_not_higher_score", {}) + ), + "justification_for_why_not_lower_score": extract_value( + content.get("justification_for_why_not_lower_score", {}) + ), + } + elif note_type == "decision": + content = note.get("content", {}) + decision = extract_value(content.get("decision", {})) + + if submission is None: + return None + + content = submission.get("content", {}) + venue_id = extract_value(content.get("venueid", {})) or "" + + # Derive venue and year from venue_id (e.g. "ICLR.cc/2025/Conference") + parts = venue_id.split("/") + year = None + for part in parts: + if part.isdigit() and len(part) == 4: + year = int(part) + break + + paper = { + "paper_id": submission["id"], + "forum_url": f"https://openreview.net/forum?id={submission['id']}", + "venue": venue_id, + "year": year, + "title": extract_value(content.get("title", {})), + "authors": extract_value(content.get("authors", {})), + "abstract": extract_value(content.get("abstract", {})), + "keywords": extract_value(content.get("keywords", {})), + "primary_area": extract_value(content.get("primary_area", {})), + "pdf_url": f"https://openreview.net/pdf?id={submission['id']}", + "decision": decision, + "num_reviews": len(reviews), + "num_discussions": len(discussions), + "reviews": reviews, + "discussions": discussions, + "meta_review": meta_review, + } + + return paper + + +def main(): + parser = argparse.ArgumentParser( + description="Normalize raw OpenReview forums into benchmark JSONL." + ) + parser.add_argument( + "--raw-dir", + type=str, + default=str(RAW_DIR), + help=f"Directory with raw forum JSON files (default: {RAW_DIR})", + ) + parser.add_argument( + "--output", + type=str, + default=str(OUTPUT_PATH), + help=f"Output JSONL path (default: {OUTPUT_PATH})", + ) + args = parser.parse_args() + + raw_dir = Path(args.raw_dir) + output_path = Path(args.output) + + raw_files = sorted(raw_dir.glob("*.json")) + if not raw_files: + print(f"No JSON files found in {raw_dir}") + return + + print(f"Normalizing {len(raw_files)} forums from {raw_dir}...") + papers = [] + for raw_file in raw_files: + with open(raw_file, encoding="utf-8") as f: + raw_data = json.load(f) + paper = normalize_forum(raw_data) + if paper: + papers.append(paper) + n_rev = paper["num_reviews"] + n_disc = paper["num_discussions"] + title = paper["title"] or "(no title)" + print(f" {raw_file.name}: {title[:60]}... ({n_rev} reviews, {n_disc} discussions)") + else: + print(f" {raw_file.name}: SKIPPED (no submission note found)") + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + for paper in papers: + f.write(json.dumps(paper, ensure_ascii=False) + "\n") + + print(f"\nWrote {len(papers)} papers to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/openreview_benchmark/scripts/openreview_http.py b/benchmarks/openreview_benchmark/scripts/openreview_http.py new file mode 100644 index 0000000..2c5dd92 --- /dev/null +++ b/benchmarks/openreview_benchmark/scripts/openreview_http.py @@ -0,0 +1,62 @@ +"""Shared HTTP helpers for OpenReview (API + PDF) behind Cloudflare. + +Visit ``openreview.net`` first so ``api2.openreview.net`` and PDF URLs succeed. +""" + +from __future__ import annotations + +import time +from typing import Literal + +import requests + +WEB_ORIGIN = "https://openreview.net" +API_BASE_URL = "https://api2.openreview.net" + +BROWSER_HEADERS = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/125.0.0.0 Safari/537.36" + ), + "Accept-Language": "en-US,en;q=0.5", +} + + +def create_openreview_session( + *, + mode: Literal["api", "pdf"] = "api", + warmup_timeout: float = 30.0, + warmup_max_attempts: int = 1, +) -> requests.Session: + """Session that passes Cloudflare: warm up on the main site, then set Accept. + + mode ``api``: JSON API calls to ``api2.openreview.net``. + mode ``pdf``: fetches from ``openreview.net/pdf?...``. + """ + session = requests.Session() + session.headers.update(BROWSER_HEADERS) + for attempt in range(warmup_max_attempts): + try: + session.get(WEB_ORIGIN, timeout=warmup_timeout) + break + except requests.RequestException as e: + last_err = e + if attempt == warmup_max_attempts - 1: + raise last_err + time.sleep(1.0) + if mode == "api": + session.headers["Accept"] = "application/json" + else: + session.headers["Accept"] = "application/pdf,*/*" + return session + + +def rewarm_session(session: requests.Session, *, timeout: float = 15.0) -> None: + """Call after HTTP 403 on the API to refresh the Cloudflare session.""" + session.get(WEB_ORIGIN, timeout=timeout) + + +def pdf_download_url(forum_id: str) -> str: + """HTTPS URL for the venue PDF for a forum id.""" + return f"{WEB_ORIGIN}/pdf?id={forum_id}" diff --git a/benchmarks/openreview_benchmark/scripts/validate_openreview_benchmark.py b/benchmarks/openreview_benchmark/scripts/validate_openreview_benchmark.py new file mode 100644 index 0000000..6e16755 --- /dev/null +++ b/benchmarks/openreview_benchmark/scripts/validate_openreview_benchmark.py @@ -0,0 +1,124 @@ +"""Validate the OpenReview benchmark file and optionally verify PDF download + parsing. + +Does not call the review LLM — safe to run without API keys. + +Usage: + python benchmarks/openreview_benchmark/scripts/validate_openreview_benchmark.py + python benchmarks/openreview_benchmark/scripts/validate_openreview_benchmark.py --parse-one +""" + +import argparse +import json +import sys +import tempfile +from pathlib import Path + +_SCRIPT_DIR = Path(__file__).resolve().parent +_TRACK_ROOT = _SCRIPT_DIR.parent +_DATA_DIR = _TRACK_ROOT / "data" +_REPO_ROOT = _SCRIPT_DIR.parents[3] +BENCHMARK = _DATA_DIR / "openreview_benchmark.jsonl" + +if str(_SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(_SCRIPT_DIR)) + +from openreview_http import create_openreview_session + +sys.path.insert(0, str(_REPO_ROOT / "src")) + + +def main(): + parser = argparse.ArgumentParser(description="Validate OpenReview benchmark JSONL") + parser.add_argument( + "--parse-one", + action="store_true", + help="Download first paper PDF and run reviewer.parsers.parse_document (no LLM)", + ) + parser.add_argument( + "--benchmark", + type=Path, + default=BENCHMARK, + help=f"Path to JSONL (default: {BENCHMARK})", + ) + args = parser.parse_args() + + path = args.benchmark + if not path.exists(): + print(f"Missing file: {path}") + sys.exit(1) + + papers = [] + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + papers.append(json.loads(line)) + + print(f"Loaded {len(papers)} papers from {path}\n") + + required = ( + "paper_id", + "title", + "pdf_url", + "reviews", + "venue", + ) + errors = [] + for i, p in enumerate(papers): + for key in required: + if key not in p: + errors.append(f"Line {i + 1} ({p.get('paper_id', '?')}): missing {key}") + if p.get("reviews") is not None and len(p["reviews"]) == 0: + errors.append(f"Paper {p.get('paper_id')}: no reviews") + + if errors: + print("Validation errors:") + for e in errors: + print(f" {e}") + sys.exit(1) + + for p in papers: + n = len(p["reviews"]) + d = p.get("num_discussions", len(p.get("discussions", []))) + print(f" OK {p['paper_id']}: {p['title'][:55]}... | {n} reviews | {d} discussions") + + print("\nSchema check passed.") + + if not args.parse_one: + print( + "\nTo smoke-test PDF parsing for one paper, run:\n" + " python benchmarks/openreview_benchmark/scripts/validate_openreview_benchmark.py --parse-one\n" + "\nTo run an actual review (needs API keys), download a PDF or use a local path:\n" + " openaireview review --method zero_shot\n" + "(OpenReview PDF URLs are not parsed as URLs by the CLI; use a downloaded file.)\n" + ) + return + + first = papers[0] + pdf_url = first["pdf_url"] + print(f"\n--parse-one: fetching {pdf_url}") + + sess = create_openreview_session(mode="pdf", warmup_timeout=60.0, warmup_max_attempts=3) + r = sess.get(pdf_url, timeout=120) + if r.status_code != 200: + print(f"HTTP {r.status_code}: {r.text[:200]}") + sys.exit(1) + + from reviewer.parsers import parse_document + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp: + tmp.write(r.content) + tmp_path = Path(tmp.name) + + try: + title, text, was_ocr = parse_document(tmp_path, ocr="pymupdf") + print(f" Parsed title: {title[:80]}...") + print(f" Text length: {len(text)} chars, was_ocr={was_ocr}") + print(" PDF parse OK.") + finally: + tmp_path.unlink(missing_ok=True) + + +if __name__ == "__main__": + main() diff --git a/src/reviewer/evaluate_openreview.py b/src/reviewer/evaluate_openreview.py new file mode 100644 index 0000000..4759311 --- /dev/null +++ b/src/reviewer/evaluate_openreview.py @@ -0,0 +1,220 @@ +"""Evaluation for the OpenReview benchmark track (pooled human text + LLM judge). + +Unlike ``evaluate.py`` (Refine), ground truth is not paragraph-anchored. We define: + +- **Precision:** fraction of model comments that the judge says overlap with *any* + substantive issue in the **pooled** human review text (all reviewers). +- **Recall:** for each **official review** separately, fraction of reviews for which + the judge says *at least one* model comment addresses a substantive critique or + question in that review; then average across reviewers (macro recall). + +**F1:** harmonic mean of precision and recall on [0, 1]. + +Requires API access for ``reviewer.client.chat`` (e.g. ``OPENAI_API_KEY`` + judge model). +""" + +from __future__ import annotations + +import os +from typing import Any + +from .client import chat +from .models import Comment + +DEFAULT_OPENREVIEW_JUDGE_MODEL = os.environ.get( + "OPENREVIEW_JUDGE_MODEL", "gpt-4o-mini" +) + +# Keep prompts within context limits +MAX_POOLED_CHARS = 28_000 +MAX_SINGLE_REVIEW_CHARS = 12_000 +MAX_PRED_LIST_CHARS = 24_000 + +# Bedrock-backed judges sometimes return 400 "Operation not allowed" transiently; extra retries help. +_JUDGE_CHAT_RETRIES = 8 + + +def format_official_review_text(review: dict[str, Any]) -> str: + """Build one string from OpenReview-style review fields.""" + parts: list[str] = [] + for key in ("summary", "strengths", "weaknesses", "questions"): + val = review.get(key) + if val and str(val).strip(): + parts.append(f"## {key.replace('_', ' ').title()}\n{val.strip()}") + return "\n\n".join(parts) + + +def pool_human_reviews(paper: dict[str, Any]) -> str: + """Concatenate all official reviews with separators.""" + reviews = paper.get("reviews") or [] + blocks = [format_official_review_text(r) for r in reviews if format_official_review_text(r)] + return "\n\n==========\n\n".join(blocks) + + +def _truncate(s: str, max_len: int) -> str: + if len(s) <= max_len: + return s + return s[: max_len - 20] + "\n...[truncated]" + + +def _yes_no_from_response(text: str) -> bool: + t = text.strip().upper() + return t.startswith("YES") + + +def llm_precision_vs_pooled( + pred: Comment, + pooled_human_text: str, + model: str, + provider: str | None = None, +) -> bool: + """True if the judge says this prediction overlaps any substantive human issue.""" + pooled = _truncate(pooled_human_text, MAX_POOLED_CHARS) + prompt = f"""You compare one model-generated review comment to the combined text of human peer reviews (multiple reviewers; sections may include summary, strengths, weaknesses, questions). + +Human reviews (combined): +{pooled} + +Predicted comment: +Title: {pred.title} +Quote from paper: {pred.quote[:1200]} +Explanation: {pred.explanation[:2000]} + +Does this predicted comment identify or substantially overlap with ANY substantive critique, limitation, or question raised in the human reviews? (Agreement with strengths-only praise alone does not count as YES.) + +Reply with exactly one word: YES or NO.""" + response, _ = chat( + messages=[{"role": "user", "content": prompt}], + model=model, + temperature=0.0, + max_tokens=8, + provider=provider, + retries=_JUDGE_CHAT_RETRIES, + ) + return _yes_no_from_response(response) + + +def llm_recall_one_review( + review_text: str, + preds: list[Comment], + model: str, + provider: str | None = None, +) -> bool: + """True if the judge says at least one prediction addresses an issue in this review.""" + if not review_text.strip(): + return False + rev = _truncate(review_text, MAX_SINGLE_REVIEW_CHARS) + lines: list[str] = [] + for i, p in enumerate(preds, 1): + lines.append( + f"{i}. Title: {p.title}\n Quote: {p.quote[:600]}\n Explanation: {p.explanation[:1200]}" + ) + pred_block = _truncate("\n\n".join(lines), MAX_PRED_LIST_CHARS) + + prompt = f"""Human review from ONE reviewer (sections may include summary, strengths, weaknesses, questions): + +{rev} + +Predicted comments from a model (numbered): +{pred_block} + +Does at least ONE predicted comment address the same substantive issue as ANY criticism, limitation, or question in this human review? (Focus on weaknesses and questions; matching only generic strengths without addressing a concern does not count.) + +Reply with exactly one word: YES or NO.""" + response, _ = chat( + messages=[{"role": "user", "content": prompt}], + model=model, + temperature=0.0, + max_tokens=8, + provider=provider, + retries=_JUDGE_CHAT_RETRIES, + ) + return _yes_no_from_response(response) + + +def evaluate_openreview_pooled( + predictions: list[Comment], + paper: dict[str, Any], + judge_model: str | None = None, + judge_provider: str | None = None, +) -> dict[str, Any]: + """Compute precision, recall, F1 for one paper. + + ``paper`` is one line from ``openreview_benchmark.jsonl`` (dict). + """ + model = judge_model or DEFAULT_OPENREVIEW_JUDGE_MODEL + reviews = paper.get("reviews") or [] + pooled = pool_human_reviews(paper) + + if not predictions: + return { + "precision": 0.0, + "recall": 0.0, + "f1": 0.0, + "num_predictions": 0, + "num_human_reviews": len(reviews), + "num_predictions_matched": 0, + "num_reviews_covered": 0, + "judge_model": model, + } + + matched_preds = 0 + for pred in predictions: + if llm_precision_vs_pooled(pred, pooled, model, provider=judge_provider): + matched_preds += 1 + + precision = matched_preds / len(predictions) + + covered = 0 + for r in reviews: + text = format_official_review_text(r) + if not text.strip(): + continue + if llm_recall_one_review(text, predictions, model, provider=judge_provider): + covered += 1 + + nonempty = sum(1 for r in reviews if format_official_review_text(r).strip()) + recall = (covered / nonempty) if nonempty else 0.0 + + if precision + recall > 0: + f1 = 2 * precision * recall / (precision + recall) + else: + f1 = 0.0 + + return { + "precision": round(precision, 4), + "recall": round(recall, 4), + "f1": round(f1, 4), + "num_predictions": len(predictions), + "num_human_reviews": len(reviews), + "num_predictions_matched": matched_preds, + "num_reviews_covered": covered, + "num_nonempty_reviews": nonempty, + "judge_model": model, + } + + +def comments_from_results_json(data: dict[str, Any], method_key: str | None = None) -> list[Comment]: + """Load Comment list from a ``review_results`` JSON file (viz format).""" + methods = data.get("methods") or {} + if not methods: + return [] + if method_key is not None: + key = method_key + if key not in methods: + raise KeyError(f"method key not found: {key!r}. Available: {list(methods)}") + else: + key = next(iter(methods)) + block = methods[key] + out: list[Comment] = [] + for c in block.get("comments") or []: + out.append( + Comment( + title=c.get("title", ""), + quote=c.get("quote", ""), + explanation=c.get("explanation", ""), + comment_type=c.get("comment_type", "technical"), + paragraph_index=c.get("paragraph_index"), + ) + ) + return out