Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions pkg/ir/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,122 @@ func applyPseudoInverseGadget[F field.Element[F]](e air.Term[F], module air.Modu
return term.FieldAccess[F, air.Term[F]](index, 0)
}

// IsZeroIndicator returns an AIR term that holds the "is zero" indicator of
// e — 1 when e evaluates to zero, 0 otherwise. This is the value
// 1 - e*inv(e), realised as a shared computed column so that every caller
// that asks for the same e gets a single FieldAccess back instead of
// reconstructing a fresh 4-node subtree.
//
// The column is constrained by two vanishings:
//
// e * (1 - e*inv) == 0 (the existing inverse constraint, emitted by
// applyPseudoInverseGadget)
// iz + e*inv - 1 == 0 (defining vanishing: iz = 1 - e*inv)
//
// Together with iz's 1-bit declared width (range constraint to {0,1}) these
// force iz to be the correct indicator value: 1 iff e == 0.
func IsZeroIndicator[F field.Element[F]](e air.Term[F], module air.ModuleBuilder[F]) air.Term[F] {
// Ensure the inv column exists (creates it + its defining vanishing on
// first call for this e).
inv_e := applyPseudoInverseGadget(e, module)
//
return applyIsZeroGadget(e, inv_e, module)
}

// applyIsZeroGadget creates (or reuses) the iz column for e and returns a
// FieldAccess to it. Sharing-by-name follows the same pattern as
// applyPseudoInverseGadget.
func applyIsZeroGadget[F field.Element[F]](
e, inv_e air.Term[F], module air.ModuleBuilder[F],
) air.Term[F] {
var (
// Construct iz indicator term (used for name + padding only).
iz = &pseudoIz[F]{Expr: e}
// Determine computed column name.
name = iz.Lisp(true, module).String(false)
// Look up existing column.
index, ok = module.HasRegister(name)
// Padding value (0 or 1 depending on whether e is zero in padding).
padding = ir.PaddingFor[F](iz, module)
)
// Add new column (if it does not already exist).
if !ok {
// iz ∈ {0,1}: width 1 doubles as a range constraint.
index = module.NewRegister(register.NewComputed(name, 1, padding))
target := register.NewRef(module.Id(), index)
// Trace-filling assignment.
module.AddAssignment(assignment.NewPseudoIz(target, e))
// Defining vanishing: iz + e*inv - 1 == 0
var iz_access air.Term[F] = term.FieldAccess[F, air.Term[F]](index, 0)

e_inv := term.Product[F, air.Term[F]](e, inv_e)
defn := term.Subtract(
term.Sum[F, air.Term[F]](iz_access, e_inv),
term.Const64[F, air.Term[F]](1),
)
l_name := fmt.Sprintf("%s <=", name)
module.AddConstraint(air.NewVanishingConstraint(l_name, module.Id(), util.None[int](), defn))
}
//
return term.FieldAccess[F, air.Term[F]](index, 0)
}

// pseudoIz mirrors pseudoInverse but represents the "is zero" indicator
// (1 if Expr is zero, 0 otherwise). Used for the column name and for the
// padding-value computation; the row-by-row trace fill lives in
// assignment.PseudoIz.
type pseudoIz[F field.Element[F]] struct {
Expr air.Term[F]
}

// EvalAt returns 1 when Expr evaluates to zero, 0 otherwise.
func (e *pseudoIz[F]) EvalAt(k int, tr trace.Module[F], sc register.Map) (F, error) {
val, err := e.Expr.EvalAt(k, tr, sc)
if err != nil {
return val, err
}
//
var one F
if val.IsZero() {
return one.SetUint64(1), nil
}
//
var zero F

return zero, nil
}

// Bounds delegates to the underlying expression.
func (e *pseudoIz[F]) Bounds() util.Bounds { return e.Expr.Bounds() }

// RequiredRegisters delegates to the underlying expression.
func (e *pseudoIz[F]) RequiredRegisters() *set.SortedSet[uint] {
return e.Expr.RequiredRegisters()
}

// RequiredCells delegates to the underlying expression.
func (e *pseudoIz[F]) RequiredCells(row int, mid trace.ModuleId) *set.AnySortedSet[trace.CellRef] {
return e.Expr.RequiredCells(row, mid)
}

// Lisp encodes the iz indicator as (iz <expr>) for naming purposes.
func (e *pseudoIz[F]) Lisp(global bool, mapping register.Map) sexp.SExp {
return sexp.NewList([]sexp.SExp{
sexp.NewSymbol("iz"),
e.Expr.Lisp(global, mapping),
})
}

// Substitute implementation for Substitutable interface.
func (e *pseudoIz[F]) Substitute(mapping map[string]F) {
panic("unreachable")
}

// ValueRange implementation for Term interface. iz is always in {0,1}.
func (e *pseudoIz[F]) ValueRange() util_math.Interval {
return util_math.NewInterval64(0, 1)
}

// pseudoInverse represents a computation which computes the multiplicative
// inverse of a given expression. This is only needed now for the padding
// computation.
Expand Down
157 changes: 157 additions & 0 deletions pkg/ir/assignment/pseudo_iz.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Copyright Consensys Software Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package assignment

import (
"fmt"

"github.com/LFDT-Lineth/zkc/pkg/ir/air"
"github.com/LFDT-Lineth/zkc/pkg/schema"
"github.com/LFDT-Lineth/zkc/pkg/schema/register"
"github.com/LFDT-Lineth/zkc/pkg/trace"
"github.com/LFDT-Lineth/zkc/pkg/util"
"github.com/LFDT-Lineth/zkc/pkg/util/collection/array"
"github.com/LFDT-Lineth/zkc/pkg/util/collection/set"
"github.com/LFDT-Lineth/zkc/pkg/util/field"
"github.com/LFDT-Lineth/zkc/pkg/util/source/sexp"
)

// PseudoIz represents a computation which produces the "is zero" indicator
// of a given expression: 1 when the expression evaluates to zero, 0
// otherwise. It is the sibling of PseudoInverse and exists so that the
// MIR→AIR lowering can CSE the 1 - e*inv(e) subtree as a single shared
// computed column referenced from every NotEqual constraint over the same
// e — see [pkg/ir/air/gadgets/normalisation.go].
type PseudoIz[F field.Element[F]] struct {
// Target index for the computed column.
Target register.Ref
// Expression whose "is zero" indicator this column holds.
Expr air.Term[F]
}

// NewPseudoIz constructs a new "is zero" indicator assignment for the given
// target register and expression.
func NewPseudoIz[F field.Element[F]](target register.Ref, expr air.Term[F]) *PseudoIz[F] {
return &PseudoIz[F]{Target: target, Expr: expr}
}

// Bounds determines the well-definedness bounds for this assignment. It is
// the same as that of the expression whose value is being checked.
func (e *PseudoIz[F]) Bounds(mid schema.ModuleId) util.Bounds {
if mid == e.Target.Module() {
return e.Expr.Bounds()
}
//
return util.EMPTY_BOUND
}

// Compute fills the target column with 1 where the expression evaluates to
// zero and 0 elsewhere.
func (e *PseudoIz[F]) Compute(tr trace.Trace[F], schema schema.AnySchema[F]) ([]array.MutArray[F], error) {
var (
trModule = tr.Module(e.Target.Module())
scModule = schema.Module(e.Target.Module())
height = trModule.Height()
// 1-bit storage: each cell is 0 or 1.
data = tr.Builder().NewArray(height, 1)
one F
)
//
one = one.SetUint64(1)
//
for i := range height {
val, err := e.Expr.EvalAt(int(i), trModule, scModule)
if err != nil {
return nil, err
}
//
if val.IsZero() {
data = data.Set(i, one)
}
}
//
return []array.MutArray[F]{data}, nil
}

// Consistent performs some simple checks that the given assignment is
// consistent with its enclosing schema.
func (e *PseudoIz[F]) Consistent(schema.AnySchema[F]) []error {
return nil
}

// RegistersExpanded identifies registers expanded by this assignment.
func (e *PseudoIz[F]) RegistersExpanded() []register.Ref {
return nil
}

// RegistersRead returns the set of columns that this assignment depends upon.
func (e *PseudoIz[F]) RegistersRead() []register.Ref {
var (
module = e.Target.Module()
regs = e.Expr.RequiredRegisters()
rids = make([]register.Ref, regs.Iter().Count())
)
//
for i, iter := 0, regs.Iter(); iter.HasNext(); i++ {
rid := register.NewId(iter.Next())
rids[i] = register.NewRef(module, rid)
}
// Allow recursive definitions: never read the target itself.
return array.RemoveMatching(rids, func(r register.Ref) bool {
return r == e.Target
})
}

// RegistersWritten identifies registers assigned by this assignment.
func (e *PseudoIz[F]) RegistersWritten() []register.Ref {
return []register.Ref{e.Target}
}

// Lisp converts this assignment into an S-Expression.
//
//nolint:revive
func (e *PseudoIz[F]) Lisp(schema schema.AnySchema[F]) sexp.SExp {
var (
module = schema.Module(e.Target.Module())
target = module.Register(e.Target.Register())
datatype = "𝔽"
)
//
if !target.IsNative() {
datatype = fmt.Sprintf("u%d", target.Width())
}
//
return sexp.NewList(
[]sexp.SExp{sexp.NewSymbol("iz"),
sexp.NewList([]sexp.SExp{
sexp.NewSymbol(target.QualifiedName(module)),
sexp.NewSymbol(datatype),
}),
e.Expr.Lisp(false, module),
})
}

// RequiredRegisters returns the set of registers on which this term depends.
func (e *PseudoIz[F]) RequiredRegisters() *set.SortedSet[uint] {
return e.Expr.RequiredRegisters()
}

// RequiredCells returns the set of trace cells on which this term depends.
func (e *PseudoIz[F]) RequiredCells(row int, mid trace.ModuleId) *set.AnySortedSet[trace.CellRef] {
return e.Expr.RequiredCells(row, mid)
}

// Substitute implementation for Substitutable interface.
func (e *PseudoIz[F]) Substitute(map[string]F) {
panic("unreachable")
}
86 changes: 47 additions & 39 deletions pkg/ir/mir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,18 +550,59 @@ func (p *AirLowering[F]) lowerEqualityTo(e *Equal[F], airModule air.ModuleBuilde

func (p *AirLowering[F]) lowerNonEqualityTo(e *NotEqual[F], airModule air.ModuleBuilder[F], bitwidths []uint,
) []air.Term[F] {
// //
var (
lhs air.Term[F] = p.lowerTermTo(e.Lhs, airModule)
rhs air.Term[F] = p.lowerTermTo(e.Rhs, airModule)
eq = term.Subtract(lhs, rhs)
)
//
one := term.Const64[F, air.Term[F]](1)
// construct norm(eq)
norm_eq := p.normalise(eq, airModule)
// construct 1 - norm(eq)
return []air.Term[F]{term.Subtract(one, norm_eq)}
return []air.Term[F]{p.lowerIsZeroIndicator(eq, airModule)}
}

// lowerIsZeroIndicator returns an AIR term that evaluates to 1 when arg is
// zero and 0 otherwise. This is the value that lowerNonEqualityTo asserts
// must vanish: 0 means "arg is non-zero", which encodes lhs != rhs.
//
// When arg's value range is small enough we use a cheap arithmetic form
// instead of materialising a column:
//
// arg ∈ {0,1} ⇒ 1 - arg
// arg ∈ {-1,0,1} ⇒ 1 - arg*arg
//
// Otherwise we delegate to air_gadgets.IsZeroIndicator, which CSEs the
// indicator as a shared computed column so every NotEqual over the same
// arg returns a single FieldAccess instead of rebuilding the
// 1 - arg*inv(arg) subtree.
func (p *AirLowering[F]) lowerIsZeroIndicator(arg air.Term[F], airModule air.ModuleBuilder[F]) air.Term[F] {
var (
bounds = arg.ValueRange()
one = term.Const64[F, air.Term[F]](1)
)
// Cheap shortcuts when no inverse is needed.
if p.config.InverseEliminiationLevel > 0 && bounds.Within(util_math.NewInterval64(0, 1)) {
return term.Subtract(one, arg)
} else if p.config.InverseEliminiationLevel > 0 && bounds.Within(util_math.NewInterval64(-1, 1)) {
return term.Subtract(one, term.Product(arg, arg))
}
// Determine an appropriate row-shift so the column is keyed on the
// canonical (shift-zero) form of arg; this keeps CSE working when the
// same expression appears at different relative shifts.
shift := 0
//
if p.config.ShiftNormalisation {
minS, maxS := arg.ShiftRange()
//
if maxS < 0 {
shift = maxS
} else if minS > 0 {
shift = minS
}
}
//
shifted := arg.ApplyShift(-shift).Simplify(false)
iz := air_gadgets.IsZeroIndicator(shifted, airModule)
//
return iz.ApplyShift(shift)
}

// Inner form is used for recursive calls and does not repeat the constant
Expand Down Expand Up @@ -634,39 +675,6 @@ func shiftTerm[F field.Element[F]](expr air.Term[F], width uint) air.Term[F] {
return term.Product(term.Const[F, air.Term[F]](n), expr)
}

func (p *AirLowering[F]) normalise(arg air.Term[F], airModule air.ModuleBuilder[F]) air.Term[F] {
bounds := arg.ValueRange()
// Check whether normalisation actually required. For example, if the
// argument is just a binary column then a normalisation is not actually
// required.
if p.config.InverseEliminiationLevel > 0 && bounds.Within(util_math.NewInterval64(0, 1)) {
// arg ∈ {0,1} ==> normalised already :)
return arg
} else if p.config.InverseEliminiationLevel > 0 && bounds.Within(util_math.NewInterval64(-1, 1)) {
// arg ∈ {-1,0,1} ==> (arg*arg) ∈ {0,1}
return term.Product(arg, arg)
}
// Determine appropriate shift
shift := 0
// Apply shift normalisation (if enabled)
if p.config.ShiftNormalisation {
// Determine shift ranges
min, max := arg.ShiftRange()
// determine shift amount
if max < 0 {
shift = max
} else if min > 0 {
shift = min
}
}
// Construct an expression representing the normalised value of e. That is,
// an expression which is 0 when e is 0, and 1 when e is non-zero.
arg = arg.ApplyShift(-shift).Simplify(false)
norm := air_gadgets.Normalise(arg, airModule)
//
return norm.ApplyShift(shift)
}

// Simplify a bunch of logical terms
func simplify[F field.Element[F]](terms []air.Term[F]) []air.Term[F] {
var nterms []air.Term[F] = make([]air.Term[F], len(terms))
Expand Down
Loading
Loading