Skip to content

Commit 084612a

Browse files
authored
Merge pull request #91 from WaveSpeedAI/dev
fix: register MultiplyBeta custom layer for ESRGAN medium/thick models
2 parents c1fe0ba + d92b91a commit 084612a

1 file changed

Lines changed: 55 additions & 0 deletions

File tree

src/workers/upscaler.worker.ts

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,59 @@
11
import Upscaler from "upscaler";
2+
import * as tf from "@tensorflow/tfjs";
3+
4+
// Register custom layer used by ESRGAN medium/thick models (loaded from CDN)
5+
// Without this, TF.js throws "Unknown layer: MultiplyBeta" when loading these models
6+
class MultiplyBeta extends tf.layers.Layer {
7+
static className = "MultiplyBeta";
8+
private beta: number;
9+
10+
constructor(config: Record<string, unknown> = {}) {
11+
super(config);
12+
this.beta = (config.beta as number) ?? 0.2;
13+
}
14+
15+
call(inputs: tf.Tensor | tf.Tensor[]): tf.Tensor {
16+
const input = Array.isArray(inputs) ? inputs[0] : inputs;
17+
return tf.mul(input, tf.scalar(this.beta));
18+
}
19+
20+
getConfig() {
21+
return { ...super.getConfig(), beta: this.beta };
22+
}
23+
}
24+
tf.serialization.registerClass(MultiplyBeta);
25+
26+
// PixelShuffle layer used by ESRGAN thick models — does depth-to-space rearrangement
27+
function createPixelShuffleClass(scale: number) {
28+
class PixelShuffle extends tf.layers.Layer {
29+
static className = `PixelShuffle${scale}x`;
30+
private scale: number;
31+
32+
constructor(config: Record<string, unknown> = {}) {
33+
super(config);
34+
this.scale = scale;
35+
}
36+
37+
computeOutputShape(inputShape: Array<number | null>): Array<number | null> {
38+
return [inputShape[0], inputShape[1], inputShape[2], 3];
39+
}
40+
41+
call(inputs: tf.Tensor | tf.Tensor[]): tf.Tensor {
42+
const input = Array.isArray(inputs) ? inputs[0] : inputs;
43+
return tf.depthToSpace(input as tf.Tensor4D, this.scale, "NHWC");
44+
}
45+
46+
getConfig() {
47+
return { ...super.getConfig(), scale: this.scale };
48+
}
49+
}
50+
return PixelShuffle;
51+
}
52+
53+
// Register PixelShuffle for all supported scales
54+
[2, 3, 4].forEach((s) => {
55+
tf.serialization.registerClass(createPixelShuffleClass(s));
56+
});
257

358
type ModelType = "slim" | "medium" | "thick";
459
type ScaleType = "2x" | "3x" | "4x";

0 commit comments

Comments
 (0)