Skip to content

Commit a675b3e

Browse files
h-joocopybara-github
authored andcommitted
Change the iteration order for functions containing async for and yield from.
The iteration order is determined by pytype constructing control flow information by looking at the bytecode instructions. It determines which instructions to elide, and how to connect BB with each other to decide how it should run within pytype's VM. The thing is, the implementation before was a bit incomplete in terms of detecting all control flow including exceptions. It was making some assumptions on what instructions or group of instructions comes after another, which did not hold anymore for python 3.12. In python 3.12 the instruction order around async construct has changed, also some new instructions were added (END_SEND) and how the instructions jump to one another too has changed. Due to this reason, pytype starts to break in 3.12 because of the iteration order being different compared to the real runtime, and it fails due to the wrong order of execution, and the result is that it fails due to insufficient stack elements when it's expecting some elements to be present at a moment. We can try to fix it to make pytype comprehend the full control graph, but I think that's going to take a bit longer to implement. Rather than doing that, with this change we group the basic blocks which are coming from async constructs into a single basic block, to prevent from getting split by the regular BB analyzer so that it runs sequentially without accidentally following the wrong control flow which never happens in the python runtime. PiperOrigin-RevId: 748587745
1 parent cbab356 commit a675b3e

4 files changed

Lines changed: 265 additions & 36 deletions

File tree

pytype/blocks/blocks.py

Lines changed: 170 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Functions for computing the execution order of bytecode."""
22

33
from collections.abc import Iterator
4-
from typing import Any, cast
4+
from typing import Any, Sequence, cast
55
from pycnite import bytecode as pyc_bytecode
66
from pycnite import marshal as pyc_marshal
77
import pycnite.types
@@ -316,7 +316,9 @@ def add_pop_block_targets(bytecode: list[opcodes.Opcode]) -> None:
316316
todo.append((op.next, block_stack))
317317

318318

319-
def _split_bytecode(bytecode: list[opcodes.Opcode]) -> list[Block]:
319+
def _split_bytecode(
320+
bytecode: list[opcodes.Opcode], processed_blocks: set[Block], python_version
321+
) -> list[Block]:
320322
"""Given a sequence of bytecodes, return basic blocks.
321323
322324
This will split the code at "basic block boundaries". These occur at
@@ -333,21 +335,175 @@ def _split_bytecode(bytecode: list[opcodes.Opcode]) -> list[Block]:
333335
targets = {op.target for op in bytecode if op.target}
334336
blocks = []
335337
code = []
336-
for op in bytecode:
338+
prev_block: Block = None
339+
i = 0
340+
while i < len(bytecode):
341+
op = bytecode[i]
342+
# SEND is only used in the context of async for and `yield from`.
343+
# These instructions are not used in other context, so it's safe to process
344+
# it assuming that these are the only constructs they're being used.
345+
if python_version >= (3, 12) and isinstance(op, opcodes.SEND):
346+
if code:
347+
prev_block = Block(code)
348+
blocks.append(prev_block)
349+
code = []
350+
new_blocks, i = _preprocess_async_for_and_yield(
351+
i, bytecode, prev_block, processed_blocks
352+
)
353+
blocks.extend(new_blocks)
354+
prev_block = blocks[-1]
355+
continue
356+
337357
code.append(op)
338358
if (
339359
op.no_next()
340360
or op.does_jump()
341361
or op.pops_block()
342362
or op.next is None
343-
or op.next in targets
363+
or (op.next in targets)
364+
and (
365+
not isinstance(op.next, opcodes.GET_ANEXT)
366+
or python_version < (3, 12)
367+
)
344368
):
345-
blocks.append(Block(code))
369+
prev_block = Block(code)
370+
blocks.append(prev_block)
346371
code = []
372+
i += 1
373+
347374
return blocks
348375

349376

350-
def compute_order(bytecode: list[opcodes.Opcode]) -> list[Block]:
377+
def _preprocess_async_for_and_yield(
378+
idx: int,
379+
bytecode: Sequence[opcodes.Opcode],
380+
prev_block: Block,
381+
processed_blocks: set[Block],
382+
) -> tuple[list[Block], int]:
383+
"""Process bytecode instructions for yield and async for in a way that pytype can iterate correctly.
384+
385+
'Async for' and yield statements, contains instructions that starts with SEND
386+
and ends with END_SEND.
387+
388+
The reason why we need to pre process async for is because the control flow of
389+
async for is drastically different from regular control flows also due to the
390+
fact that the termination of the loop happens by STOP_ASYNC_ITERATION
391+
exception, not a regular control flow. So we need to split (or merge) the
392+
basic blocks in a way that pytype executes in the order that what'd happen in
393+
the runtime, so that it doesn't fail with wrong order of execution, which can
394+
result in a stack underrun.
395+
396+
Args:
397+
idx: The index of the SEND instruction.
398+
bytecode: A list of instances of opcodes.Opcode
399+
prev_block: The previous block that we want to connect the new blocks to.
400+
processed_blocks: Blocks that has been processed so that it doesn't get
401+
processed again by compute_order.
402+
403+
Returns:
404+
A tuple of (list[Block], int), where the Block is the block containing the
405+
iteration part of the async for construct, and the int is the index of the
406+
END_SEND instruction.
407+
"""
408+
assert isinstance(bytecode[idx], opcodes.SEND)
409+
i = next(
410+
i
411+
for i in range(idx + 1, len(bytecode))
412+
if isinstance(bytecode[i], opcodes.JUMP_BACKWARD_NO_INTERRUPT)
413+
)
414+
415+
end_block_idx = i + 1
416+
# In CLEANUP_THROW can be present after JUMP_BACKWARD_NO_INTERRUPT
417+
# depending on how the control flow graph is constructed.
418+
# Usually, CLEANUP_THROW comes way after
419+
if isinstance(bytecode[end_block_idx], opcodes.CLEANUP_THROW):
420+
end_block_idx += 1
421+
422+
# Somehow pytype expects the SEND and YIELD_VALUE to be in different
423+
# blocks, so we need to split.
424+
send_block = Block(bytecode[idx : idx + 1])
425+
yield_value_block = Block(bytecode[idx + 1 : end_block_idx])
426+
prev_block.connect_outgoing(send_block)
427+
send_block.connect_outgoing(yield_value_block)
428+
processed_blocks.update(send_block, yield_value_block)
429+
return [send_block, yield_value_block], end_block_idx
430+
431+
432+
def _remove_jmp_to_get_anext_and_merge(
433+
blocks: list[Block], processed_blocks: set[Block]
434+
) -> list[Block]:
435+
"""Remove JUMP_BACKWARD instructions to GET_ANEXT instructions.
436+
437+
And also merge the block that contains the END_ASYNC_FOR which is part of the
438+
same loop of the GET_ANEXT and JUMP_BACKWARD construct, to the JUMP_BACKWARD
439+
instruction. This is to ignore the JUMP_BACKWARD because in pytype's eyes it's
440+
useless (as it'll jump back to block that it already executed), and also
441+
this is the way to make pytype run the code of END_ASYNC_FOR and whatever
442+
comes afterwards.
443+
444+
Args:
445+
blocks: A list of Block instances.
446+
447+
Returns:
448+
A list of Block instances after the removal and merge.
449+
"""
450+
op_to_block = {}
451+
merge_list = []
452+
for block_idx, block in enumerate(blocks):
453+
for code in block.code:
454+
op_to_block[code] = block_idx
455+
456+
for block_idx, block in enumerate(blocks):
457+
for code in block.code:
458+
if code.end_async_for_target:
459+
merge_list.append((block_idx, op_to_block[code.end_async_for_target]))
460+
map_target = {}
461+
for block_idx, block_idx_to_merge in merge_list:
462+
# Remove JUMP_BACKWARD instruction as we don't want to execute it.
463+
jump_back_op = blocks[block_idx].code.pop()
464+
blocks[block_idx].code.extend(blocks[block_idx_to_merge].code)
465+
map_target[jump_back_op] = blocks[block_idx_to_merge].code[0]
466+
467+
if block_idx_to_merge < len(blocks) - 1:
468+
blocks[block_idx].connect_outgoing(blocks[block_idx_to_merge + 1])
469+
processed_blocks.add(blocks[block_idx])
470+
471+
to_delete = sorted({to_idx for _, to_idx in merge_list}, reverse=True)
472+
473+
for block_idx in to_delete:
474+
del blocks[block_idx]
475+
476+
for block in blocks:
477+
replace_op = map_target.get(block.code[-1].target, None)
478+
if replace_op:
479+
block.code[-1].target = replace_op
480+
481+
return blocks
482+
483+
484+
def _remove_jump_back_block(blocks: list[Block]):
485+
"""Remove JUMP_BACKWARD instructions which are exception handling for async for.
486+
487+
These are not used during the regular pytype control flow analysis.
488+
"""
489+
new_blocks = []
490+
for block in blocks:
491+
last_op = block.code[-1]
492+
if (
493+
isinstance(last_op, opcodes.JUMP_BACKWARD)
494+
and isinstance(last_op.target, opcodes.END_SEND)
495+
and len(block.code) >= 2
496+
and isinstance(block.code[-2], opcodes.CLEANUP_THROW)
497+
):
498+
continue
499+
new_blocks.append(block)
500+
501+
return new_blocks
502+
503+
504+
def compute_order(
505+
bytecode: list[opcodes.Opcode], python_version
506+
) -> list[Block]:
351507
"""Split bytecode into blocks and order the blocks.
352508
353509
This builds an "ancestor first" ordering of the basic blocks of the bytecode.
@@ -359,10 +515,16 @@ def compute_order(bytecode: list[opcodes.Opcode]) -> list[Block]:
359515
Returns:
360516
A list of Block instances.
361517
"""
362-
blocks = _split_bytecode(bytecode)
518+
processed_blocks = set()
519+
blocks = _split_bytecode(bytecode, processed_blocks, python_version)
520+
if python_version >= (3, 12):
521+
blocks = _remove_jump_back_block(blocks)
522+
blocks = _remove_jmp_to_get_anext_and_merge(blocks, processed_blocks)
363523
first_op_to_block = {block.code[0]: block for block in blocks}
364524
for i, block in enumerate(blocks):
365525
next_block = blocks[i + 1] if i < len(blocks) - 1 else None
526+
if block in processed_blocks:
527+
continue
366528
first_op, last_op = block.code[0], block.code[-1]
367529
if next_block and not last_op.no_next():
368530
block.connect_outgoing(next_block)
@@ -390,7 +552,7 @@ def _order_code(dis_code: pycnite.types.DisassembledCode) -> OrderedCode:
390552
"""
391553
ops = opcodes.build_opcodes(dis_code)
392554
add_pop_block_targets(ops)
393-
blocks = compute_order(ops)
555+
blocks = compute_order(ops, dis_code.python_version)
394556
return OrderedCode(dis_code.code, ops, blocks)
395557

396558

pytype/pyc/opcodes.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class Opcode:
4848
"prev",
4949
"next",
5050
"target",
51+
"end_async_for_target",
5152
"block_target",
5253
"code",
5354
"annotation",
@@ -67,6 +68,9 @@ def __init__(self, index, line, endline=None, col=None, endcol=None):
6768
self.prev = None
6869
self.next = None
6970
self.target = None
71+
# The END_ASYNC_FOR instruction of which we want to make pytype jump to for
72+
# this instruction.
73+
self.end_async_for_target = None
7074
self.block_target = None
7175
self.code = None # If we have a CodeType or OrderedCode parent
7276
self.annotation = None
@@ -1306,30 +1310,6 @@ def _should_elide_opcode(
13061310
and isinstance(op_items[i + 1][1], END_ASYNC_FOR)
13071311
)
13081312

1309-
# In 3.12 all generators are compiled into infinite loops, too. In addition,
1310-
# YIELD_VALUE inserts exception handling instructions:
1311-
# CLEANUP_THROW
1312-
# JUMP_BACKWARD
1313-
# These can appear on their own or they can be inserted between JUMP_BACKWARD
1314-
# and END_ASYNC_FOR, possibly many times. We keep eliding the `async for` jump
1315-
# and also elide the exception handling cleanup codes because they're not
1316-
# relevant for pytype and complicate the block graph.
1317-
if python_version == (3, 12):
1318-
return (
1319-
isinstance(op, CLEANUP_THROW)
1320-
or (
1321-
isinstance(op, JUMP_BACKWARD)
1322-
and i >= 1
1323-
and isinstance(op_items[i - 1][1], CLEANUP_THROW)
1324-
)
1325-
or (
1326-
isinstance(op, JUMP_BACKWARD)
1327-
and isinstance(
1328-
_get_opcode_following_cleanup_throw_jump_pairs(op_items, i + 1),
1329-
END_ASYNC_FOR,
1330-
)
1331-
)
1332-
)
13331313
return False
13341314

13351315

@@ -1372,13 +1352,44 @@ def _add_jump_targets(ops, offset_to_index):
13721352
op.target = ops[op.arg]
13731353

13741354

1355+
def _add_async_for_jump_back_targets(
1356+
ops: list[Opcode],
1357+
offset_to_op: dict[int, Opcode],
1358+
exc_table: pycnite.types.ExceptionTable,
1359+
):
1360+
"""Find the END_ASYNC_FOR target of which is related to a JUMP_BACKWARD instruction.
1361+
1362+
Also, assign them in a attribute end_async_for_target so that we can process
1363+
it later.
1364+
"""
1365+
1366+
get_anext_incoming: dict[JUMP_BACKWARD, set[GET_ANEXT]] = {}
1367+
for op in ops:
1368+
if isinstance(op, JUMP_BACKWARD) and isinstance(op.target, GET_ANEXT):
1369+
if op.target not in get_anext_incoming:
1370+
get_anext_incoming[op.target] = set()
1371+
get_anext_incoming[op.target].add(op)
1372+
1373+
for e in exc_table.entries:
1374+
if e.start in offset_to_op and isinstance(offset_to_op[e.start], GET_ANEXT):
1375+
get_anext = offset_to_op[e.start]
1376+
if get_anext not in get_anext_incoming:
1377+
continue
1378+
for jump_backward in get_anext_incoming[get_anext]:
1379+
jump_backward.end_async_for_target = offset_to_op[e.target]
1380+
1381+
13751382
def build_opcodes(dis_code: pycnite.types.DisassembledCode) -> list[Opcode]:
13761383
"""Build a list of opcodes from pycnite opcodes."""
13771384
offset_to_op = _make_opcodes(dis_code.opcodes, dis_code.python_version)
13781385
if dis_code.exception_table:
13791386
_add_setup_except(offset_to_op, dis_code.exception_table)
13801387
ops, offset_to_idx = _make_opcode_list(offset_to_op, dis_code.python_version)
13811388
_add_jump_targets(ops, offset_to_idx)
1389+
if dis_code.python_version >= (3, 12):
1390+
_add_async_for_jump_back_targets(
1391+
ops, offset_to_op, dis_code.exception_table
1392+
)
13821393
return ops
13831394

13841395

pytype/state.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,6 @@ def merge_into(self, other):
206206
self.data_stack,
207207
other.data_stack,
208208
)
209-
assert len(self.block_stack) == len(other.block_stack), (
210-
self.block_stack,
211-
other.block_stack,
212-
)
213209
both = list(zip(self.data_stack, other.data_stack))
214210
if any(v1 is not v2 for v1, v2 in both):
215211
for v, o in both:

pytype/tests/test_async_generators.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,66 @@ async def gen():
406406
x4: Coroutine[Any, Any, None] = gen().aclose()
407407
""")
408408

409+
@test_utils.skipBeforePy((3, 11), "New in 3.11")
410+
def test_async_gen_coroutines_error(self):
411+
"""Test whether the async for within async with does not fail at runtime."""
412+
self.Check("""
413+
def outer(f):
414+
async def wrapper(t, *args, **kwargs):
415+
if t is None:
416+
async with f():
417+
async for c in f():
418+
yield c
419+
else:
420+
async for c in f():
421+
yield c
422+
return wrapper
423+
""")
424+
425+
@test_utils.skipBeforePy((3, 11), "New in 3.11")
426+
def test_async_for(self):
427+
self.Check("""
428+
async def iterate(num):
429+
try:
430+
async for s in range(num): # pytype: disable=attribute-error
431+
if s > 3:
432+
yield ''
433+
except ValueError as e:
434+
yield ''
435+
yield ''
436+
""")
437+
438+
@test_utils.skipBeforePy((3, 11), "New in 3.11")
439+
def test_async_for_with_control_flow(self):
440+
self.Check("""
441+
from typing import Any
442+
import random
443+
async def iterate(stream: Any):
444+
async for _ in stream:
445+
if (random.randint(0, 100) != 30 or random.randint(0, 100) != 40):
446+
continue
447+
yield random.randint(0, 100)
448+
""")
449+
450+
@test_utils.skipBeforePy((3, 11), "New in 3.11")
451+
def test_async_double_for_loop(self):
452+
self.Check("""
453+
def outer(f):
454+
async def wrapper(t, *args, **kwargs):
455+
if t is None:
456+
async with f():
457+
async for c in f():
458+
async for d in f():
459+
yield c + d
460+
yield c
461+
else:
462+
async for c in f():
463+
async for d in f():
464+
yield c + d
465+
yield c
466+
return wrapper
467+
""")
468+
409469

410470
if __name__ == "__main__":
411471
test_base.main()

0 commit comments

Comments
 (0)