From 707edaed5c4198a11b9611d379143e75950984fa Mon Sep 17 00:00:00 2001
From: zxq82lm <231263183+zxq82lm@users.noreply.github.com>
Date: Sat, 15 Nov 2025 13:35:07 -0800
Subject: [PATCH] fix: discocirc sandwich #200

---
 lambeq/experimental/discocirc/reader.py | 34 +++++++++++++++++++++++++
 tests/test_discocirc_reader.py          | 23 +++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/lambeq/experimental/discocirc/reader.py b/lambeq/experimental/discocirc/reader.py
index 1e2476cd..bcab1f1a 100644
--- a/lambeq/experimental/discocirc/reader.py
+++ b/lambeq/experimental/discocirc/reader.py
@@ -174,12 +174,20 @@ def _tree2sandwiches_rec(self,
             noun_box = Box(node.word, Ty(), NOUN)
             return Id(NOUN), [noun_box], [node.ind], node.ind
 
+        if not isinstance(pruned_ids, set):
+            pruned_ids = set(pruned_ids)
+
         subdiags = []
         nouns = []
         noun_inds = []
         noun2wire = {}
         noun_cursor = previous_noun
 
+        if node.typ == NOUN and node.children:
+            local_head = self._find_head_noun(node, pruned_ids)
+            if local_head is not None:
+                noun_cursor = local_head
+
         bigdiag = Id()
 
         if NOUN.l in node.typ or NOUN.r in node.typ:
@@ -213,6 +221,13 @@ def _tree2sandwiches_rec(self,
                     ancilla_nouns = ancilla_nouns - {noun2wire[nid]}
 
                 wire_ids.append(noun2wire[nid])
+                # Swap an unnamed box on this wire for the child's named one.
+                idx = noun2wire.get(nid)
+                if (idx is not None
+                        and idx < len(nouns)
+                        and c_nouns[j].name
+                        and not nouns[idx].name):
+                    nouns[idx] = c_nouns[j]
 
             wire_ids = list(sorted(ancilla_nouns))+wire_ids
 
@@ -262,6 +277,25 @@ def _tree2sandwiches_rec(self,
 
         return bigdiag, nouns, noun_inds, noun_cursor
 
+    def _find_head_noun(self, node, pruned_ids: set[int]) -> int | None:
+        """Return the index of the first unpruned leaf noun under `node`.
+
+        The subtree is searched depth-first in child order. A childless
+        node of type `NOUN` yields its own index unless that index is
+        in `pruned_ids`; `None` is returned when nothing qualifies.
+        """
+        if node.typ == NOUN and not node.children:
+            if node.ind in pruned_ids:
+                return None
+            return node.ind
+
+        for child in node.children:
+            head = self._find_head_noun(child, pruned_ids)
+            if head is not None:
+                return head
+
+        return None
+
     def _get_index(self, s, pnoun):
         for j, w in enumerate(s):
             if w == pnoun:
diff --git a/tests/test_discocirc_reader.py b/tests/test_discocirc_reader.py
index 42354bc0..7d8163a3 100644
--- a/tests/test_discocirc_reader.py
+++ b/tests/test_discocirc_reader.py
@@ -32,6 +32,17 @@
     (Box('Bob', Ty(), n) @ Box('Claire', Ty(), n)) >> Box('hates', n @ n, n @ n)
 ]
 
+def _copular_tree():
+    return PregroupTreeNode('is', 1, Ty('s'), children=[
+        PregroupTreeNode('He', 0, n),
+        PregroupTreeNode('', 2, n, children=[
+            PregroupTreeNode('very', 3, n, children=[
+                PregroupTreeNode('talented', 4, n @ n.l),
+                PregroupTreeNode('programmer', 5, n)
+            ])
+        ])
+    ])
+
 
 class MockBobcatParser(BobcatParser):
     def __init__(self):
@@ -98,3 +109,15 @@ def test_discocirc_reader_w_different_parsers(monkeypatch):
     parser.sentence2diagram.assert_called_once_with(
         sentence, tokenised=True,
     )
+
+
+def test_sandwich_prefers_local_head_noun():
+    parser = MockBobcatParser()
+    r = DisCoCircReader(parser=parser,
+                        coref_resolver=MockCorefResolver())
+
+    tree = _copular_tree()
+    _, nouns, nids, _ = r._tree2sandwiches_rec(tree, pruned_ids=set())
+
+    assert [box.name for box in nouns] == ['He', 'programmer']
+    assert nids == [0, 5]