Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion packages/typegpu/src/tgsl/generationHelpers.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
import { $internal, $resolve } from '../../src/shared/symbols.ts';
import { type AnyData, UnknownData } from '../data/dataTypes.ts';
import { abstractFloat, abstractInt, bool, f32, i32 } from '../data/numeric.ts';
import { isRef } from '../data/ref.ts';
import { isSnippet, snip, type Snippet } from '../data/snippet.ts';
import {
isEphemeralSnippet,
isSnippet,
type ResolvedSnippet,
snip,
type Snippet,
} from '../data/snippet.ts';
import {
type AnyWgslData,
type F32,
type I32,
isMatInstance,
isNaturallyEphemeral,
isVecInstance,
WORKAROUND_getSchema,
} from '../data/wgslTypes.ts';
import {
type FunctionScopeLayer,
getOwnSnippet,
type ResolutionCtx,
type SelfResolvable,
} from '../types.ts';
import type { ShelllessRepository } from './shellless.ts';
import { stitch } from '../../src/core/resolve/stitch.ts';
import { WgslTypeError } from '../../src/errors.ts';

export function numericLiteralToSnippet(value: number): Snippet {
if (value >= 2 ** 63 || value < -(2 ** 63)) {
Expand Down Expand Up @@ -127,3 +138,47 @@ export function coerceToSnippet(value: unknown): Snippet {

return snip(value, UnknownData, /* origin */ 'constant');
}

// defers the resolution of array expressions
export class ArrayExpression implements SelfResolvable {
readonly [$internal] = true;

constructor(
public readonly elementType: AnyWgslData,
public readonly type: AnyWgslData,
public readonly elements: Snippet[],
) {
}

toString(): string {
return 'ArrayExpression';
}

[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
for (const elem of this.elements) {
// We check if there are no references among the elements
if (
(elem.origin === 'argument' &&
!isNaturallyEphemeral(elem.dataType)) ||
!isEphemeralSnippet(elem)
) {
const snippetStr = ctx.resolve(elem.value, elem.dataType).value;
const snippetType =
ctx.resolve(concretize(elem.dataType as AnyData)).value;
throw new WgslTypeError(
`'${snippetStr}' reference cannot be used in an array constructor.\n-----\nTry '${snippetType}(${snippetStr})' or 'arrayOf(${snippetType}, count)([...])' to copy the value instead.\n-----`,
);
}
}

const arrayType = `array<${
ctx.resolve(this.elementType).value
}, ${this.elements.length}>`;

return snip(
stitch`${arrayType}(${this.elements})`,
this.type,
/* origin */ 'runtime',
);
}
}
96 changes: 65 additions & 31 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
tryConvertSnippet,
} from './conversion.ts';
import {
ArrayExpression,
concretize,
type GenerationCtx,
numericLiteralToSnippet,
Expand Down Expand Up @@ -498,21 +499,21 @@ ${this.ctx.pre}}`;
const [_, calleeNode, argNodes] = expression;
const callee = this.expression(calleeNode);

if (wgsl.isWgslStruct(callee.value) || wgsl.isWgslArray(callee.value)) {
// Struct/array schema call.
if (wgsl.isWgslStruct(callee.value)) {
// Struct schema call.
if (argNodes.length > 1) {
throw new WgslTypeError(
'Array and struct schemas should always be called with at most 1 argument',
'Struct schemas should always be called with at most 1 argument',
);
}

// No arguments `Struct()`, resolve struct name and return.
if (!argNodes[0]) {
// the schema becomes the data type
// The schema becomes the data type.
return snip(
`${this.ctx.resolve(callee.value).value}()`,
callee.value,
// A new struct, so not a reference
// A new struct, so not a reference.
/* origin */ 'runtime',
);
}
Expand All @@ -527,7 +528,53 @@ ${this.ctx.pre}}`;
return snip(
this.ctx.resolve(arg.value, callee.value).value,
callee.value,
// A new struct, so not a reference
// A new struct, so not a reference.
/* origin */ 'runtime',
);
}

if (wgsl.isWgslArray(callee.value)) {
// Array schema call.
if (argNodes.length > 1) {
throw new WgslTypeError(
'Array schemas should always be called with at most 1 argument',
);
}

// No arguments `array<...>()`, resolve array type and return.
if (!argNodes[0]) {
// The schema becomes the data type.
return snip(
`${this.ctx.resolve(callee.value).value}()`,
callee.value,
// A new array, so not a reference.
/* origin */ 'runtime',
);
}

const arg = this.typedExpression(
argNodes[0],
callee.value,
);

// `d.arrayOf(...)([...])`.
// We resolve each element separately.
if (arg.value instanceof ArrayExpression) {
return snip(
stitch`${
this.ctx.resolve(callee.value).value
}(${arg.value.elements})`,
arg.dataType,
/* origin */ 'runtime',
);
}

// `d.arrayOf(...)(otherArr)`.
// We just let the argument resolve everything.
return snip(
this.ctx.resolve(arg.value, callee.value).value,
callee.value,
// A new array, so not a reference.
/* origin */ 'runtime',
);
}
Expand Down Expand Up @@ -720,24 +767,9 @@ ${this.ctx.pre}}`;
}
} else {
// The array is not typed, so we try to guess the types.
const valuesSnippets = valueNodes.map((value) => {
const snippet = this.expression(value as tinyest.Expression);
// We check if there are no references among the elements
if (
(snippet.origin === 'argument' &&
!wgsl.isNaturallyEphemeral(snippet.dataType)) ||
!isEphemeralSnippet(snippet)
) {
const snippetStr =
this.ctx.resolve(snippet.value, snippet.dataType).value;
const snippetType =
this.ctx.resolve(concretize(snippet.dataType as AnyData)).value;
throw new WgslTypeError(
`'${snippetStr}' reference cannot be used in an array constructor.\n-----\nTry '${snippetType}(${snippetStr})' or 'arrayOf(${snippetType}, count)([...])' to copy the value instead.\n-----`,
);
}
return snippet;
});
const valuesSnippets = valueNodes.map((value) =>
this.expression(value as tinyest.Expression)
);

if (valuesSnippets.length === 0) {
throw new WgslTypeError(
Expand All @@ -756,16 +788,18 @@ ${this.ctx.pre}}`;
elemType = concretize(values[0]?.dataType as wgsl.AnyWgslData);
}

const arrayType = `array<${
this.ctx.resolve(elemType).value
}, ${values.length}>`;
const arrayType = arrayOf[$internal].jsImpl(
elemType as wgsl.AnyWgslData,
values.length,
);

return snip(
stitch`${arrayType}(${values})`,
arrayOf[$internal].jsImpl(
new ArrayExpression(
elemType as wgsl.AnyWgslData,
values.length,
) as wgsl.AnyWgslData,
arrayType,
values,
),
arrayType,
/* origin */ 'runtime',
);
}
Expand Down
101 changes: 96 additions & 5 deletions packages/typegpu/tests/array.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,24 @@ describe('array', () => {
);
});

it('throws when invalid number of arguments during code generation', () => {
const ArraySchema = d.arrayOf(d.u32, 2);

const f = () => {
'use gpu';
// @ts-expect-error
const arr = ArraySchema([1, 1], [6, 7]);
return;
};

expect(() => tgpu.resolve([f])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn*:f
- fn*:f(): Array schemas should always be called with at most 1 argument]
`);
});

it('can be called to create a default value', () => {
const ArraySchema = d.arrayOf(d.vec3f, 2);

Expand Down Expand Up @@ -188,16 +206,28 @@ describe('array', () => {
it('generates correct code when array clone is used', () => {
const ArraySchema = d.arrayOf(d.u32, 1);

const testFn = tgpu.fn([])(() => {
const f = (arr: d.Infer<typeof ArraySchema>) => {
'use gpu';
const clone = ArraySchema(arr);
};

const testFn = () => {
'use gpu';
const myArray = ArraySchema([d.u32(10)]);
const myClone = ArraySchema(myArray);
f(myArray);
return;
});
};

expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(`
"fn testFn() {
"fn f(arr: array<u32, 1>) {
var clone = arr;
}

fn testFn() {
var myArray = array<u32, 1>(10u);
var myClone = myArray;
f(myArray);
return;
}"
`);
Expand All @@ -221,6 +251,65 @@ describe('array', () => {
`);
});

it('generates correct code when array expression with ephemeral element type clone is used', () => {
const f = () => {
'use gpu';
const arr = d.arrayOf(d.f32, 2)([6, 7]);
return;
};

expect(tgpu.resolve([f])).toMatchInlineSnapshot(`
"fn f() {
var arr = array<f32, 2>(6f, 7f);
return;
}"
`);
});

it('generates correct code when array expression with reference element type clone is used', () => {
const f = (v: d.v4f) => {
'use gpu';
const v2 = d.vec4f(3);
const v3 = v2;
const arr = d.arrayOf(d.vec4f, 3)([v, v2, v3]);
};

const main = tgpu.fn([])(() => {
const v1 = d.vec4f(7);
f(v1);
return;
});

expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn f(v: vec4f) {
var v2 = vec4f(3);
let v3 = (&v2);
var arr = array<vec4f, 3>(v, v2, (*v3));
}

fn main() {
var v1 = vec4f(7);
f(v1);
return;
}"
`);
});

it('generates correct code when array expression with mixed element types clone is used', () => {
const f = () => {
'use gpu';
const arr = d.arrayOf(d.f32, 3)([5, 6.7, 8.0]);
return;
};

expect(tgpu.resolve([f])).toMatchInlineSnapshot(`
"fn f() {
var arr = array<f32, 3>(5f, 6.7f, 8f);
return;
}"
`);
});

it('can be immediately-invoked in TGSL', () => {
const foo = tgpu.fn([])(() => {
const result = d.arrayOf(d.f32, 4)();
Expand Down Expand Up @@ -339,7 +428,8 @@ describe('array', () => {
expect(() => tgpu.resolve([foo])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:foo: 'myVec' reference cannot be used in an array constructor.
- fn:foo
- ArrayExpression: 'myVec' reference cannot be used in an array constructor.
-----
Try 'vec2f(myVec)' or 'arrayOf(vec2f, count)([...])' to copy the value instead.
-----]
Expand All @@ -354,7 +444,8 @@ describe('array', () => {
expect(() => tgpu.resolve([foo])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:foo: 'myVec' reference cannot be used in an array constructor.
- fn:foo
- ArrayExpression: 'myVec' reference cannot be used in an array constructor.
-----
Try 'vec2f(myVec)' or 'arrayOf(vec2f, count)([...])' to copy the value instead.
-----]
Expand Down
23 changes: 23 additions & 0 deletions packages/typegpu/tests/struct.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,29 @@ describe('struct', () => {
);
});

it('throws when invalid number of arguments during code generation', () => {
const Boid = struct({
pos: vec2f,
vel: vec2f,
});

const f = () => {
'use gpu';
const b1 = Boid({ pos: vec2f(6), vel: vec2f(7) });

// @ts-expect-error
const b2 = Boid(b1, b1);
return;
};

expect(() => tgpu.resolve([f])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn*:f
- fn*:f(): Struct schemas should always be called with at most 1 argument]
`);
});

it('allows builtin names as struct props', () => {
const myStruct = struct({
min: u32,
Expand Down
Loading