Code Review Questions — MobileStyleGAN.pytorch #66

@bes-dev

Automated Code Review — Questions

This issue was auto-generated by a code review pipeline.

Code Review Questions

Architecture & Design

  1. train.py:22 (checkpoint_callback) — pl.callbacks.ModelCheckpoint is constructed with the deprecated filepath parameter (removed in PyTorch-Lightning ≥ 1.2, replaced by dirpath/filename). Which version of PL is this targeting, and will this silently discard checkpoints on newer PL installs?

  2. core/distiller.py:validation_epoch_end — A # TODO: add all_gather for distributed mode comment is left in the validation aggregation code, yet the README advertises support for up to 8 GPUs. Without all_gather, KID and val-loss metrics will be computed only over the local rank's subset during multi-GPU runs — how does the training remain reliable across all the advertised GPU configurations?

  3. core/distiller.py:make_sample (else-branch, ~line 93) — When coin < stylemix_p[0], the style tensor is constructed as view(1, self.wsize, ...) (batch size hard-coded to 1) while the rest of the batch has batch_size elements. Does this batch-dimension mismatch silently broadcast or cause an error when batch_size > 1?

  4. core/distiller.py:configure_optimizers — self.cfg.mode.split(',') is evaluated twice (once assigned to a variable named mode, once as the loop iterable), yet the outer mode variable is never used and the loop variable shadows it. Could a config mode string that lists the same mode twice (e.g. "g,g") accidentally add duplicate optimizers, corrupting the opt_to_mode mapping and training_step routing?

  5. core/model_zoo.py:model_zoo — The JSON zoo file is opened via json.load(open(zoo_path)) with no with block, leaking the file descriptor. For a training loop that calls model_zoo repeatedly (e.g. during checkpoint conversion), could this exhaust OS file-descriptor limits?
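
For the file-handle question in item 5, a minimal sketch of the context-manager form follows. The helper name load_zoo is hypothetical and the real model_zoo function may do more than parse JSON; the only point shown is that the handle is closed as soon as parsing finishes.

```python
import json


def load_zoo(zoo_path):
    """Parse the zoo JSON file and close the handle deterministically.

    Hypothetical helper: stands in for the json.load(open(zoo_path)) call
    quoted in the review, nothing more.
    """
    with open(zoo_path, "r", encoding="utf-8") as f:
        return json.load(f)
```

With the with block the descriptor is released even if json.load raises, so repeated calls cannot accumulate open handles regardless of how the interpreter collects garbage.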

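For the duplicate-mode question in item 4, one possible guard is to normalise the mode string once and drop repeats before any optimizer is built. This is a sketch under assumptions: parse_modes and build_opt_to_mode are hypothetical names, and the mode names "g" and "d" are only illustrative.

```python
def parse_modes(mode_string):
    """Split a comma-separated mode string, dropping blanks and duplicates.

    Order of first appearance is preserved (dicts keep insertion order).
    """
    names = [m.strip() for m in mode_string.split(",") if m.strip()]
    return list(dict.fromkeys(names))


def build_opt_to_mode(mode_string):
    """Map optimizer index -> mode name, one entry per distinct mode."""
    return dict(enumerate(parse_modes(mode_string)))


# Illustrative check with a repeated mode.
mapping = build_opt_to_mode("g,d,g")
```

Here mapping holds two entries, so an index-based routing table built from it stays one-to-one even when the config repeats a mode.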

Security

  1. core/model_zoo.py:model_zoo / core/utils.py:download_ckpt — Both paths ultimately call torch.load(path, map_location="cpu") on user-supplied or remotely downloaded files. PyTorch's default pickle-based loader can execute arbitrary code embedded in a checkpoint. Why is there no use of weights_only=True (available since PyTorch 1.13) or any signature verification for the downloaded .ckpt files?

  2. core/utils.py:download_ckpt — Downloaded checkpoints are cached at the fixed path /tmp/{name}. On a shared multi-user machine two concurrent training jobs using the same pretrained model name would collide, and a malicious local user could pre-place a crafted file at that path. Was this path chosen deliberately, and should it use a user-specific temp directory instead?

  3. train.py:build_logger — getattr(pl_loggers, cfg.type)(**cfg.params) instantiates a logger class by name read directly from the JSON config with no allowlist. If a user (or automated sweep tool) supplies an arbitrary cfg.type, could this resolve to unintended classes in the pl_loggers namespace?
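
For the logger question in item 3, an allowlist check in front of getattr might look like the sketch below. The function name resolve_logger_class and the contents of ALLOWED_LOGGERS are assumptions chosen for illustration; the namespace argument stands in for pl_loggers so the snippet runs without PyTorch-Lightning installed.

```python
import types

# Assumption: the set of logger class names the project wants to permit.
ALLOWED_LOGGERS = frozenset({"TensorBoardLogger", "CSVLogger"})


def resolve_logger_class(namespace, type_name, allowed=ALLOWED_LOGGERS):
    """Return namespace.<type_name> only if the name is explicitly allowed."""
    if type_name not in allowed:
        raise ValueError(
            f"Unsupported logger type {type_name!r}; expected one of {sorted(allowed)}"
        )
    return getattr(namespace, type_name)


# Stand-in namespace for demonstration; real code would pass pl_loggers.
fake_loggers = types.SimpleNamespace(CSVLogger=dict, TensorBoardLogger=list, Other=set)
logger_cls = resolve_logger_class(fake_loggers, "CSVLogger")
```

Any name outside the allowlist fails with a clear ValueError before getattr is reached, so helper classes or re-exported symbols in the namespace cannot be instantiated from config.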

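For the cache-path question in item 2, one option is a per-user cache directory instead of a shared /tmp/{name}. The sketch below is an assumption, not the project's code: user_cache_path, the "mobilestylegan" subdirectory name, and the XDG_CACHE_HOME fallback are all illustrative choices.

```python
import os
from pathlib import Path


def user_cache_path(name, base=None):
    """Return a checkpoint cache path inside a per-user directory."""
    if base is None:
        base = os.environ.get("XDG_CACHE_HOME") or os.path.join(
            os.path.expanduser("~"), ".cache"
        )
    cache_dir = Path(base) / "mobilestylegan"
    # mode applies to the final directory only and is subject to the umask.
    cache_dir.mkdir(parents=True, exist_ok=True, mode=0o700)
    # Keep only the last path component so "../x" cannot escape cache_dir.
    safe_name = Path(name).name
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid checkpoint name: {name!r}")
    return cache_dir / safe_name
```

A directory under the user's home is not writable by other local users in a default setup, which removes both the collision and the pre-placed-file scenario described above. It does not replace integrity checks on the downloaded file itself.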

Performance

  1. core/distiller.py:validation_step (~lines 56–60) — self.student(style, noise=gt["noise"]) is called twice back-to-back: once to get pred_inc for KID features, and again to compute val_loss. Since this code runs under torch.no_grad() and the student is deterministic given fixed noise, is the second forward pass intentional, or is it an accidental duplication that doubles validation compute?

  2. core/distiller.py:make_sample — The teacher (self.synthesis_net) runs a full forward pass on every training step to generate ground-truth targets. Because the teacher is frozen (eval(), no grad), could its activations be pre-computed and cached in a dataset (similar to how NoiseDataset pre-samples noise), rather than regenerating them on the fly each step?

  3. core/models/modules/modulated_conv2d.py:ModulatedConv2d.get_demodulation (and ModulatedDWConv2d) — The demodulation norm uses self.style_inv, a randomly-initialised register_buffer that is never updated during training, rather than the per-sample style vector. StyleGAN2's demodulation is supposed to use the actual style modulation to whiten the convolution weights. Is this a deliberate approximation trading quality for speed, and if so, what is the measured FID impact compared to proper per-sample demodulation?
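
For the duplicated forward pass in item 1, the reuse pattern can be sketched without any framework. Everything here is a stand-in: CountingStudent replaces the real student network, the "img" key and the mean-absolute-error loss are assumptions, and the actual validation_step may need additional outputs.

```python
class CountingStudent:
    """Toy deterministic model that records how often it is called."""

    def __init__(self):
        self.calls = 0

    def __call__(self, style, noise=None):
        self.calls += 1
        return {"img": [s * 2 for s in style]}


def validation_step_sketch(student, style, noise, target):
    """Run one forward pass and reuse its output for both consumers."""
    pred = student(style, noise=noise)  # the only forward pass
    kid_input = pred["img"]  # would feed the KID feature extractor
    # Mean absolute error against the target, computed from the same output.
    val_loss = sum(abs(p - t) for p, t in zip(pred["img"], target)) / len(target)
    return kid_input, val_loss


student = CountingStudent()
kid_input, val_loss = validation_step_sketch(student, [1.0, 2.0], None, [2.0, 5.0])
```

After the call, student.calls is 1 while both the KID input and the loss are available, which is the saving the question is pointing at.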


Code Quality

  1. core/models/modules/idwt_upsample.py:IDWTUpsaplme — The class name is misspelled (IDWTUpsaplme instead of IDWTUpsample). This typo propagates into mobile_synthesis_block.py imports and the public module API. Why was this never corrected, and does it affect downstream code that might try to import the class by its correct name?

  2. train.py:main (~line 46) — raise "Unknown export format." raises a bare string. In Python 3, strings are not exceptions, so the statement itself fails with TypeError: exceptions must derive from BaseException. Execution still stops, but the intended message never reaches the user. Shouldn't this be raise ValueError("Unknown export format.")?

  3. core/distiller.py:make_sample — The stylemix_p config value is treated as a two-element probability threshold list ([p0, p1]), but there is no validation that 0 <= p0 < p1 <= 1.0. What happens if a user accidentally sets stylemix_p: [0.9, 0.1] — does the style-mixing branch become unreachable with no warning?

  4. core/models/mobile_synthesis_network.py:forward (~line 43) — The W+ style slicing for each block uses style[:, m.wsize()*i + 1 : m.wsize()*i + m.wsize() + 1, :]. With wsize()=3 per block and the initial +1 offset, block 0 receives style[:,1:4,:], block 1 receives style[:,4:7,:], etc. The total wsize() is len(layers)*3 + 2, which matches the slice's final index. However, index 0 of the W+ vector is used for the initial conv1 but index 1 is also included in block 0's slice. Is this overlap intentional, or does it mean one style vector is shared between the stem conv and the first block's upsample, potentially entangling their style degrees of freedom?
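
For the raise statement in item 2, Python 3's behaviour can be checked directly. The helper names below are hypothetical; only the two raise statements mirror the review text.

```python
def raise_bare_string():
    # A str is not an exception, so the interpreter raises TypeError here.
    raise "Unknown export format."


def raise_value_error():
    raise ValueError("Unknown export format.")


def describe(fn):
    """Call fn and report the exception type name and message it raised."""
    try:
        fn()
    except Exception as exc:
        return type(exc).__name__, str(exc)
    return None


bare = describe(raise_bare_string)
proper = describe(raise_value_error)
```

The first call reports a TypeError whose message is about exception types rather than export formats, while the second carries the intended text.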

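For the stylemix_p question in item 3, the missing check is small. The sketch assumes the two-element [p0, p1] layout described above; validate_stylemix_p is a hypothetical helper and where it would be called (config load or __init__) is left open.

```python
def validate_stylemix_p(p):
    """Check that p is [p0, p1] with 0 <= p0 < p1 <= 1 and return it as floats."""
    if len(p) != 2:
        raise ValueError(f"stylemix_p must have exactly two entries, got {len(p)}")
    p0, p1 = p
    if not (0.0 <= p0 < p1 <= 1.0):
        raise ValueError(
            f"stylemix_p must satisfy 0 <= p0 < p1 <= 1, got [{p0}, {p1}]"
        )
    return float(p0), float(p1)


thresholds = validate_stylemix_p([0.1, 0.9])
```

A swapped pair such as [0.9, 0.1] is rejected up front, so a misconfiguration surfaces as an error instead of a silently unreachable branch.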

Tests & Coverage

  1. Repository-wide — No test files (test_*.py, *_test.py, pytest.ini, setup.cfg [tool:pytest], etc.) are present anywhere in the repository. Given the numerical-correctness-critical code in modulated_conv2d.py (custom demodulation), idwt_upsample.py (DWT-based upsampling), and the W+ style-index slicing in mobile_synthesis_network.py:forward, how are regressions in image quality or numerical stability detected when these modules are modified?

  2. core/distiller.py:compute_mean_style — The style mean is computed with a hard-coded sample of 4096 at __init__ time using the mapping network in its initial state (before any potential fine-tuning). Is the resulting style_mean buffer ever validated, and should there be a test confirming it is a reasonable centroid that produces non-degenerate truncated samples?
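
As one example of the kind of regression test item 1 asks about, the W+ slice arithmetic quoted in the Code Quality section can be pinned down in pure Python. This is a sketch built only from the quoted expression style[:, wsize*i + 1 : wsize*i + wsize + 1, :] with wsize = 3 and a total of len(layers)*3 + 2; the helper name and the block count of 4 are assumptions, and the real network is not imported.

```python
def block_style_indices(num_blocks, block_wsize=3):
    """List the W+ indices each block would receive under the quoted slice."""
    return [
        list(range(block_wsize * i + 1, block_wsize * i + block_wsize + 1))
        for i in range(num_blocks)
    ]


def test_block_slices_are_contiguous_and_in_range():
    num_blocks = 4  # hypothetical depth, chosen only for the test
    total = num_blocks * 3 + 2  # total wsize() as stated in the review
    slices = block_style_indices(num_blocks)
    flat = [i for s in slices for i in s]
    assert slices[0] == [1, 2, 3]
    assert slices[1] == [4, 5, 6]
    # Slices tile indices 1..3*num_blocks with no gaps and no repeats.
    assert flat == list(range(1, 3 * num_blocks + 1))
    assert max(flat) < total


test_block_slices_are_contiguous_and_in_range()
```

Under these figures, indices 0 and 3*num_blocks + 1 fall outside every block slice. Whether other layers consume them is something the sketch cannot confirm and would need a test against the actual forward method.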


Reviewed commit: a9776ff8f05a868b2d3b637bda14eca4c074d2a3 | Files scanned: 42
