Skip to content

Commit cf3714d

Browse files
committed
Add Post-CFG SHIFT model patcher
1 parent 8912fc7 commit cf3714d

File tree

3 files changed

+141
-44
lines changed

3 files changed

+141
-44
lines changed

README.md

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# FreeU Advanced Plus
1+
# FreeU Advanced Plus (And Post-CFG SHIFT)
22
Let's say you and I grab dinner, and movie after lunch? 🌃📺😏
33

44
![image](https://github.com/WASasquatch/FreeU_Advanced/assets/1151589/c1dc2ec9-e6a3-4d2d-bf81-697e5d5aabcb)
@@ -52,3 +52,36 @@ Let's say you and I grab dinner, and movie after lunch? 🌃📺😏
5252
## :newspaper_roll: License
5353

5454
Distributed under the MIT License. See `LICENSE` for more information.
55+
56+
---
57+
58+
## Post-CFG SHIFT (Flux)
59+
60+
Post-CFG Stepwise Hybrid Inject + Fourier Tuning.
61+
62+
- Runs after classifier-free guidance (CFG) merges cond/uncond each sampler step.
63+
- Modifies the sampler’s current denoised tensor (in VAE latent space in typical pipelines), not model weights.
64+
- Applies a hybrid blend between the denoised tensor and a scaled version, with optional frequency-domain shaping.
65+
66+
### How it works
67+
1) Model predicts noise; CFG produces a denoised tensor for the current step.
68+
2) SHIFT blends `denoised` with `denoised * b` using the chosen `mode` and `blend`.
69+
3) Optionally applies `Fourier_filter` with per-scale controls.
70+
4) Applies a final `force_gain` multiplier.
71+
72+
### Parameters
73+
- `mode` (combo): Blend strategy for `denoised` vs `denoised*b`.
74+
- Useful: `inject` (strong), `stable_slerp` (smooth), `lerp` (linear), etc.
75+
- `blend` (float): Blend amount between base and scaled tensors.
76+
- `b` (float): Scale factor for the injected path. Higher = stronger effect.
77+
- `apply_fourier` (bool): Enable frequency-domain shaping.
78+
- `multiscale_mode` (combo): Preset shaping curves. Use stable options (e.g., Default, Pass-Through, Sharpen).
79+
- `multiscale_strength` (float): Intensity of multi-scale shaping.
80+
- `threshold` (int): Base radius in frequency mask.
81+
- `s` (float): Base scale value applied at `threshold` radius.
82+
- `force_gain` (float): Final multiplier to boost or attenuate the overall effect.
83+
- `debug_log` (bool): Prints one-time registration and periodic fire logs.
84+
85+
### Notes
86+
- SHIFT is always-on in Flux; attention/forward-timestep/wrapper paths are disabled for stability.
87+
- If a multiscale preset yields flat/gray output, switch to a stable preset (e.g., Sharpen, Pass-Through) or tune `threshold`/`s`.

nodes.py

Lines changed: 106 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
import torch as th
55
import torch.fft as fft
6+
import torch.nn as nn
67
import math
78

89
def normalize(latent, target_min=None, target_max=None):
@@ -129,17 +130,9 @@ def stable_slerp(a, b, t: float, eps: float = 1e-6):
129130

130131
mscales = {
131132
"Default": None,
132-
"Bandpass": [
133-
(5, 0.0), # Low-pass filter
134-
(15, 1.0), # Pass-through filter (allows mid-range frequencies)
135-
(25, 0.0), # High-pass filter
136-
],
137133
"Low-Pass": [
138134
(10, 1.0), # Allows low-frequency components, suppresses high-frequency components
139135
],
140-
"High-Pass": [
141-
(10, 0.0), # Suppresses low-frequency components, allows high-frequency components
142-
],
143136
"Pass-Through": [
144137
(10, 1.0), # Passes all frequencies unchanged, no filtering
145138
],
@@ -228,27 +221,27 @@ class WAS_FreeU:
228221
def INPUT_TYPES(s):
229222
return {"required": {
230223
"model": ("MODEL",),
231-
"target_block": (["output_block", "middle_block", "input_block", "all"],),
232-
"multiscale_mode": (list(mscales.keys()),),
233-
"multiscale_strength": ("FLOAT", {"default": 1.0, "max": 1.0, "min": 0, "step": 0.001}),
234-
"slice_b1": ("INT", {"default": 640, "min": 64, "max": 1280, "step": 1}),
235-
"slice_b2": ("INT", {"default": 320, "min": 64, "max": 640, "step": 1}),
236-
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.001}),
237-
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.001}),
238-
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.001}),
239-
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.001}),
224+
"target_block": (["output_block", "middle_block", "input_block", "all"], {"tooltip": "Which UNet block(s) to patch."}),
225+
"multiscale_mode": (list(mscales.keys()), {"tooltip": "Frequency shaping preset used by Fourier_filter."}),
226+
"multiscale_strength": ("FLOAT", {"default": 1.0, "max": 1.0, "min": 0, "step": 0.001, "tooltip": "Intensity of multi-scale shaping [0-1]."}),
227+
"slice_b1": ("INT", {"default": 640, "min": 64, "max": 1280, "step": 1, "tooltip": "Slice width (channels) affected in 1280-channel features."}),
228+
"slice_b2": ("INT", {"default": 320, "min": 64, "max": 640, "step": 1, "tooltip": "Slice width (channels) affected in 640-channel features."}),
229+
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Gain multiplier applied to the b1 slice (1280-ch)."}),
230+
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Gain multiplier applied to the b2 slice (640-ch)."}),
231+
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Fourier scale at threshold for 1280-ch features."}),
232+
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Fourier scale at threshold for 640-ch features."}),
240233
},
241234
"optional": {
242-
"b1_mode": (list(blending_modes.keys()),),
243-
"b1_blend": ("FLOAT", {"default": 1.0, "max": 100, "min": 0, "step": 0.001}),
244-
"b2_mode": (list(blending_modes.keys()),),
245-
"b2_blend": ("FLOAT", {"default": 1.0, "max": 100, "min": 0, "step": 0.001}),
246-
"threshold": ("INT", {"default": 1.0, "max": 10, "min": 1, "step": 1}),
247-
"use_override_scales": (["false", "true"],),
235+
"b1_mode": (list(blending_modes.keys()), {"tooltip": "Blending mode for b1 path."}),
236+
"b1_blend": ("FLOAT", {"default": 1.0, "max": 100, "min": 0, "step": 0.001, "tooltip": "Blend strength for b1 path."}),
237+
"b2_mode": (list(blending_modes.keys()), {"tooltip": "Blending mode for b2 path."}),
238+
"b2_blend": ("FLOAT", {"default": 1.0, "max": 100, "min": 0, "step": 0.001, "tooltip": "Blend strength for b2 path."}),
239+
"threshold": ("INT", {"default": 1.0, "max": 10, "min": 1, "step": 1, "tooltip": "Base radius for the Fourier mask."}),
240+
"use_override_scales": (["false", "true"], {"tooltip": "Enable manual override of scale presets."}),
248241
"override_scales": ("STRING", {"default": '''# OVERRIDE SCALES
249242
250243
# Sharpen
251-
# 10, 1.5''', "multiline": True}),
244+
# 10, 1.5''', "multiline": True, "tooltip": "Custom scale lines: '<radius>, <scale>'. Comments with #,//,!"}),
252245
}
253246
}
254247

@@ -307,6 +300,7 @@ def block_patch_hsp(h, hsp, transformer_options):
307300
print(f"Patching {target_block}")
308301

309302
m = model.clone()
303+
310304
if target_block == "all" or target_block == "output_block":
311305
m.set_model_output_block_patch(block_patch_hsp)
312306
if target_block == "all" or target_block == "input_block":
@@ -315,30 +309,97 @@ def block_patch_hsp(h, hsp, transformer_options):
315309
m.set_model_patch(block_patch, "middle_block_patch")
316310
return (m, )
317311

312+
313+
314+
class WAS_PostCFGShift:
315+
@classmethod
316+
def INPUT_TYPES(s):
317+
return {"required": {
318+
"model": ("MODEL",),
319+
"steps": ("INT", {"default": 20, "min": 1, "max": 1000, "step": 1, "tooltip": "Number of steps to apply SHIFT."}),
320+
"mode": (list(blending_modes.keys()), {"tooltip": "Blend strategy for denoised vs denoised*b (e.g., inject, stable_slerp)."}),
321+
"blend": ("FLOAT", {"default": 1.0, "max": 100.0, "min": 0.0, "step": 0.001, "tooltip": "Blend amount between base and scaled tensors."}),
322+
"b": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Scale factor for the injected path (higher = stronger)."}),
323+
"apply_fourier": ("BOOLEAN", {"default": False, "tooltip": "Apply frequency-domain shaping (Fourier_filter)."}),
324+
"multiscale_mode": (list(mscales.keys()), {"tooltip": "Preset shaping curves for Fourier_filter."}),
325+
"multiscale_strength": ("FLOAT", {"default": 1.0, "max": 1.0, "min": 0.0, "step": 0.001, "tooltip": "Intensity of multi-scale shaping [0-1]."}),
326+
"threshold": ("INT", {"default": 1, "min": 1, "max": 10, "step": 1, "tooltip": "Base radius for frequency mask."}),
327+
"s": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Base scale value applied at threshold radius."}),
328+
"force_gain": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Final multiplier to boost or attenuate effect."}),
329+
}
330+
}
331+
332+
RETURN_TYPES = ("MODEL",)
333+
FUNCTION = "patch"
334+
335+
CATEGORY = "_for_testing"
336+
337+
def patch(self, model, steps, mode, blend, b, apply_fourier, multiscale_mode, multiscale_strength, threshold, s, force_gain):
338+
339+
scales = mscales[multiscale_mode]
340+
steps = max(1, min(1000, steps))
341+
current_step = 0
342+
343+
print(
344+
"[FluxU] inputs:",
345+
f"mode={mode}", f"blend={blend}", f"b={b}",
346+
f"apply_fourier={apply_fourier}", f"multiscale_mode={multiscale_mode}", f"multiscale_strength={multiscale_strength}",
347+
f"threshold={threshold}", f"s={s}", f"force_gain={force_gain}"
348+
)
349+
350+
m = model.clone()
351+
352+
def post_cfg_function(args):
353+
354+
nonlocal current_step
355+
current_step += 1
356+
if current_step > steps:
357+
return args.get("denoised")
358+
359+
denoised = args.get("denoised")
360+
eff_blend = float(blend)
361+
t_scaled = denoised * b
362+
y = blending_modes[mode](denoised, t_scaled, eff_blend)
363+
364+
if apply_fourier:
365+
y = Fourier_filter(y, threshold=threshold, scale=s, scales=scales, strength=multiscale_strength)
366+
367+
if force_gain != 1.0:
368+
y = y * float(force_gain)
369+
return y
370+
371+
372+
try:
373+
m.set_model_sampler_post_cfg_function(post_cfg_function)
374+
print("[FluxU] set_model_sampler_post_cfg_function registered")
375+
except Exception as e:
376+
print(f"[FluxU] set_model_sampler_post_cfg_function failed: {e}")
377+
return (m, )
378+
318379
class WAS_FreeU_V2:
319380
@classmethod
320381
def INPUT_TYPES(s):
321382
return {"required": {
322383
"model": ("MODEL",),
323-
"input_block": ("BOOLEAN", {"default": False}),
324-
"middle_block": ("BOOLEAN", {"default": False}),
325-
"output_block": ("BOOLEAN", {"default": False}),
326-
"multiscale_mode": (list(mscales.keys()),),
327-
"multiscale_strength": ("FLOAT", {"default": 1.0, "max": 1.0, "min": 0, "step": 0.001}),
328-
"slice_b1": ("INT", {"default": 640, "min": 64, "max": 1280, "step": 1}),
329-
"slice_b2": ("INT", {"default": 320, "min": 64, "max": 640, "step": 1}),
330-
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.001}),
331-
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.001}),
332-
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.001}),
333-
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.001}),
384+
"input_block": ("BOOLEAN", {"default": False, "tooltip": "Enable patching on the input block."}),
385+
"middle_block": ("BOOLEAN", {"default": False, "tooltip": "Enable patching on the middle block."}),
386+
"output_block": ("BOOLEAN", {"default": False, "tooltip": "Enable patching on the output block."}),
387+
"multiscale_mode": (list(mscales.keys()), {"tooltip": "Frequency shaping preset used by Fourier_filter."}),
388+
"multiscale_strength": ("FLOAT", {"default": 1.0, "max": 1.0, "min": 0, "step": 0.001, "tooltip": "Intensity of multi-scale shaping [0-1]."}),
389+
"slice_b1": ("INT", {"default": 640, "min": 64, "max": 1280, "step": 1, "tooltip": "Slice width (channels) affected in 1280-channel features."}),
390+
"slice_b2": ("INT", {"default": 320, "min": 64, "max": 640, "step": 1, "tooltip": "Slice width (channels) affected in 640-channel features."}),
391+
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Gain multiplier for 1280-channel slice."}),
392+
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Gain multiplier for 640-channel slice."}),
393+
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Fourier scale at threshold for 1280-ch features."}),
394+
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Fourier scale at threshold for 640-ch features."}),
334395
},
335396
"optional": {
336-
"threshold": ("INT", {"default": 1.0, "max": 10, "min": 1, "step": 1}),
337-
"use_override_scales": (["false", "true"],),
397+
"threshold": ("INT", {"default": 1.0, "max": 10, "min": 1, "step": 1, "tooltip": "Base radius for the Fourier mask."}),
398+
"use_override_scales": (["false", "true"], {"tooltip": "Enable manual override of scale presets."}),
338399
"override_scales": ("STRING", {"default": '''# OVERRIDE SCALES
339400
340401
# Sharpen
341-
# 10, 1.5''', "multiline": True}),
402+
# 10, 1.5''', "multiline": True, "tooltip": "Custom scale lines: '<radius>, <scale>'. Comments with #,//,!"}),
342403
}
343404
}
344405

@@ -410,11 +471,14 @@ def block_patch_hsp(h, hsp, transformer_options):
410471
return (m, )
411472

412473
NODE_CLASS_MAPPINGS = {
413-
"FreeU (Advanced)": WAS_FreeU,
414-
"FreeU_V2 (Advanced)": WAS_FreeU_V2,
474+
"WAS_FreeU": WAS_FreeU,
475+
"WAS_FreeU_V2": WAS_FreeU_V2,
476+
"WAS_PostCFGShift": WAS_PostCFGShift,
415477
}
416478

417479
NODE_DISPLAY_NAME_MAPPINGS = {
418-
"FreeU (Advanced)": "FreeU (Advanced Plus)",
419-
"FreeU_V2 (Advanced)": "FreeU V2 (Advanced Plus)",
480+
"WAS_FreeU": "FreeU (Advanced Plus)",
481+
"WAS_FreeU_V2": "FreeU V2 (Advanced Plus)",
482+
"WAS_PostCFGShift": "Post-CFG SHIFT",
420483
}
484+

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "freeu_advanced"
33
description = "This custom node provides advanced settings for FreeU."
4-
version = "1.1.0"
4+
version = "1.2.0"
55
license = { file = "LICENSE" }
66

77
[project.urls]

0 commit comments

Comments
 (0)