diff --git a/fhevm/params.go b/fhevm/params.go index 3f3bab7..24f8064 100644 --- a/fhevm/params.go +++ b/fhevm/params.go @@ -66,18 +66,19 @@ type GasCosts struct { FheGetCiphertext map[tfhe.FheUintType]uint64 // TEE Operations - TeeAddSub map[tfhe.FheUintType]uint64 - TeeMul map[tfhe.FheUintType]uint64 - TeeDiv map[tfhe.FheUintType]uint64 - TeeRem map[tfhe.FheUintType]uint64 - TeeEncrypt map[tfhe.FheUintType]uint64 - TeeDecrypt map[tfhe.FheUintType]uint64 - TeeComparison map[tfhe.FheUintType]uint64 - TeeShift map[tfhe.FheUintType]uint64 - TeeNot map[tfhe.FheUintType]uint64 - TeeNeg map[tfhe.FheUintType]uint64 - TeeBitwiseOp map[tfhe.FheUintType]uint64 - TeeCast uint64 + TeeAddSub map[tfhe.FheUintType]uint64 + TeeMul map[tfhe.FheUintType]uint64 + TeeDiv map[tfhe.FheUintType]uint64 + TeeRem map[tfhe.FheUintType]uint64 + TeeEncrypt map[tfhe.FheUintType]uint64 + TeeDecrypt map[tfhe.FheUintType]uint64 + TeeComparison map[tfhe.FheUintType]uint64 + TeeShift map[tfhe.FheUintType]uint64 + TeeNot map[tfhe.FheUintType]uint64 + TeeNeg map[tfhe.FheUintType]uint64 + TeeBitwiseOp map[tfhe.FheUintType]uint64 + TeeVerifyCiphertext map[tfhe.FheUintType]uint64 + TeeCast uint64 } func DefaultGasCosts() GasCosts { @@ -308,6 +309,14 @@ func DefaultGasCosts() GasCosts { tfhe.FheUint32: 150, tfhe.FheUint64: 189, }, + TeeVerifyCiphertext: map[tfhe.FheUintType]uint64{ + tfhe.FheBool: 60, + tfhe.FheUint4: 60, + tfhe.FheUint8: 60, + tfhe.FheUint16: 70, + tfhe.FheUint32: 90, + tfhe.FheUint64: 120, + }, } } diff --git a/fhevm/tee_cast.go b/fhevm/tee_cast.go index 72d9dd2..3d4bbc6 100644 --- a/fhevm/tee_cast.go +++ b/fhevm/tee_cast.go @@ -22,11 +22,11 @@ func teeCastTo(ciphertext *tfhe.TfheCiphertext, castToType tfhe.FheUintType) (*t value := big.NewInt(0).SetBytes(result.Value).Uint64() - resultBz, err := marshalTfheType(value, castToType) - + resultBz, err := tee.MarshalTfheType(value, castToType) if err != nil { return nil, errors.New("marshalling failed") } + teePlaintext := tee.NewTeePlaintext(resultBz, castToType, common.Address{}) resultCt, err := tee.Encrypt(teePlaintext) diff --git a/fhevm/tee_comparison.go b/fhevm/tee_comparison.go index 4d5cf5b..78b7b27 100644 --- a/fhevm/tee_comparison.go +++ b/fhevm/tee_comparison.go @@ -109,7 +109,7 @@ func teeSelectRun(environment EVMEnvironment, caller common.Address, addr common } else { result.Set(t) } - resultBz, err := marshalTfheType(&result, p2.FheUintType) + resultBz, err := tee.MarshalTfheType(&result, p2.FheUintType) if err != nil { return nil, err } diff --git a/fhevm/tee_crypto.go b/fhevm/tee_crypto.go index 46c5042..b05d8a8 100644 --- a/fhevm/tee_crypto.go +++ b/fhevm/tee_crypto.go @@ -2,6 +2,7 @@ package fhevm import ( "bytes" + "encoding/binary" "encoding/hex" "errors" "math/big" @@ -88,3 +89,75 @@ func teeDecryptRun(environment EVMEnvironment, caller common.Address, addr commo copy(ret[32-len(plaintext):], plaintext) return ret, nil } + +func teeVerifyCiphertextRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) { + logger := environment.GetLogger() + // first 32 bytes of the payload is offset, then 32 bytes are size of byte array + if len(input) <= 68 { + err := errors.New("verifyCiphertext(bytes) must contain at least 68 bytes for selector, byte offset and size") + logger.Error("fheLib precompile error", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + bytesPaddingSize := 32 + bytesSizeSlotSize := 32 + // read only last 4 bytes of padded number for byte array size + sizeStart := bytesPaddingSize + bytesSizeSlotSize - 4 + sizeEnd := sizeStart + 4 + bytesSize := binary.BigEndian.Uint32(input[sizeStart:sizeEnd]) + bytesStart := bytesPaddingSize + bytesSizeSlotSize + bytesEnd := bytesStart + int(bytesSize) + input = input[bytesStart:minInt(bytesEnd, len(input))] + + if len(input) <= 1 { + msg := "verifyCiphertext Run() input needs to contain a ciphertext and one byte for its type" + logger.Error(msg, "len", len(input)) + return nil, errors.New(msg) + } + + ctBytes := input[:len(input)-1] + ctTypeByte := input[len(input)-1] + if !tfhe.IsValidFheType(ctTypeByte) { + msg := "verifyCiphertext Run() ciphertext type is invalid" + logger.Error(msg, "type", ctTypeByte) + return nil, errors.New(msg) + } + ctType := tfhe.FheUintType(ctTypeByte) + otelDescribeOperandsFheTypes(runSpan, ctType) + + expectedSize, found := tee.GetTeeCiphertextSize(ctType) + if !found || expectedSize != uint(len(ctBytes)) { + msg := "verifyCiphertext Run() compact ciphertext size is invalid" + logger.Error(msg, "type", ctTypeByte, "size", len(ctBytes), "expectedSize", expectedSize) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !environment.IsCommitting() && !environment.IsEthCall() { + return importRandomCiphertext(environment, ctType), nil + } + + ct := new(tfhe.TfheCiphertext) + ct.Serialization = ctBytes + ct.FheUintType = ctType + + plaintext, err := tee.Decrypt(ct) + if err != nil { + msg := "verifyCiphertext Run() compact ciphertext is invalid" + return nil, errors.New(msg) + } + + if plaintext.FheUintType != ctType { + msg := "verifyCiphertext Run() compact type mismatch" + logger.Error(msg, "type", plaintext.FheUintType, "expectedType", ctType) + return nil, errors.New(msg) + } + + ctHash := ct.GetHash() + importCiphertext(environment, ct) + if environment.IsCommitting() { + logger.Info("verifyCiphertext success", + "ctHash", ctHash.Hex(), + "ctBytes64", hex.EncodeToString(ctBytes[:minInt(len(ctBytes), 64)])) + } + return ctHash.Bytes(), nil +} diff --git a/fhevm/tee_crypto_gas.go b/fhevm/tee_crypto_gas.go index 29bdc41..3ae84eb 100644 --- a/fhevm/tee_crypto_gas.go +++ b/fhevm/tee_crypto_gas.go @@ -34,3 +34,21 @@ func teeDecryptRequiredGas(environment EVMEnvironment, input []byte) uint64 { } return environment.FhevmParams().GasCosts.TeeDecrypt[ct.fheUintType()] } + +func teeVerifyCiphertextRequiredGas(environment EVMEnvironment, input []byte) uint64 { + logger := environment.GetLogger() + + if len(input) <= 68 { + logger.Error("verifyCiphertext(bytes) must contain at least 68 bytes for selector, byte offset and size") + return 0 + } + ctTypeByte := input[len(input)-1] + if !tfhe.IsValidFheType(ctTypeByte) { + msg := "verifyCiphertext Run() ciphertext type is invalid" + logger.Error(msg, "type", ctTypeByte) + return 0 + } + + ctType := tfhe.FheUintType(ctTypeByte) + return environment.FhevmParams().GasCosts.TeeVerifyCiphertext[ctType] +} diff --git a/fhevm/tee_crypto_test.go b/fhevm/tee_crypto_test.go index 36b6106..2f52558 100644 --- a/fhevm/tee_crypto_test.go +++ b/fhevm/tee_crypto_test.go @@ -6,6 +6,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/zama-ai/fhevm-go/fhevm/tfhe" + "github.com/zama-ai/fhevm-go/tee" "pgregory.net/rapid" ) @@ -46,3 +47,148 @@ func TestTeeDecryptRun(t *testing.T) { } }) } + +func TestTeeVerifyCiphertext4(t *testing.T) { + TeeVerifyCiphertext(t, tfhe.FheUint4) +} + +func TestTeeVerifyCiphertext8(t *testing.T) { + TeeVerifyCiphertext(t, tfhe.FheUint8) +} + +func TestTeeVerifyCiphertext16(t *testing.T) { + TeeVerifyCiphertext(t, tfhe.FheUint16) +} + +func TestTeeVerifyCiphertext32(t *testing.T) { + TeeVerifyCiphertext(t, tfhe.FheUint32) +} + +func TestTeeVerifyCiphertext64(t *testing.T) { + TeeVerifyCiphertext(t, tfhe.FheUint64) +} + +func TestTeeVerifyCiphertext4BadType(t *testing.T) { + TeeVerifyCiphertextBadType(t, tfhe.FheUint4, tfhe.FheUint8) + TeeVerifyCiphertextBadType(t, tfhe.FheUint4, tfhe.FheUint16) + TeeVerifyCiphertextBadType(t, tfhe.FheUint4, tfhe.FheUint32) + TeeVerifyCiphertextBadType(t, tfhe.FheUint4, tfhe.FheUint64) +} + +func TestTeeVerifyCiphertext8BadType(t *testing.T) { + TeeVerifyCiphertextBadType(t, tfhe.FheUint8, tfhe.FheUint4) + TeeVerifyCiphertextBadType(t, tfhe.FheUint8, tfhe.FheUint16) + TeeVerifyCiphertextBadType(t, tfhe.FheUint8, tfhe.FheUint32) + TeeVerifyCiphertextBadType(t, tfhe.FheUint8, tfhe.FheUint64) +} + +func TestTeeVerifyCiphertext16BadType(t *testing.T) { + TeeVerifyCiphertextBadType(t, tfhe.FheUint16, tfhe.FheUint4) + TeeVerifyCiphertextBadType(t, tfhe.FheUint16, tfhe.FheUint8) + TeeVerifyCiphertextBadType(t, tfhe.FheUint16, tfhe.FheUint32) + TeeVerifyCiphertextBadType(t, tfhe.FheUint16, tfhe.FheUint64) +} + +func TestTeeVerifyCiphertext32BadType(t *testing.T) { + TeeVerifyCiphertextBadType(t, tfhe.FheUint32, tfhe.FheUint4) + TeeVerifyCiphertextBadType(t, tfhe.FheUint32, tfhe.FheUint8) + TeeVerifyCiphertextBadType(t, tfhe.FheUint32, tfhe.FheUint16) + TeeVerifyCiphertextBadType(t, tfhe.FheUint32, tfhe.FheUint64) +} + +func TestTeeVerifyCiphertext64BadType(t *testing.T) { + TeeVerifyCiphertextBadType(t, tfhe.FheUint64, tfhe.FheUint4) + TeeVerifyCiphertextBadType(t, tfhe.FheUint64, tfhe.FheUint8) + TeeVerifyCiphertextBadType(t, tfhe.FheUint64, tfhe.FheUint16) + TeeVerifyCiphertextBadType(t, tfhe.FheUint64, tfhe.FheUint32) +} + +func TeeVerifyCiphertext(t *testing.T, fheUintType tfhe.FheUintType) { + var value uint64 + switch fheUintType { + case tfhe.FheBool: + value = 1 + case tfhe.FheUint4: + value = 2 + case tfhe.FheUint8: + value = 234 + case tfhe.FheUint16: + value = 4283 + case tfhe.FheUint32: + value = 1333337 + case tfhe.FheUint64: + value = 13333377777777777 + } + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + resultBz, err := tee.MarshalTfheType(value, fheUintType) + if err != nil { + t.Fatalf(err.Error()) + } + + plaintext := tee.NewTeePlaintext(resultBz, fheUintType, addr) + ct, err := tee.Encrypt(plaintext) + if err != nil { + t.Fatalf(err.Error()) + } + + input := prepareInputForVerifyCiphertext(append(ct.Serialization, byte(fheUintType))) + out, err := teeVerifyCiphertextRun(environment, addr, addr, input, readOnly, nil) + if err != nil { + t.Fatalf(err.Error()) + } + + if common.BytesToHash(out) != ct.GetHash() { + t.Fatalf("output hash in verifyCipertext is incorrect") + } + res := getVerifiedCiphertextFromEVM(environment, ct.GetHash()) + if res == nil { + t.Fatalf("verifyCiphertext must have verified given ciphertext") + } +} + +func TeeVerifyCiphertextBadType(t *testing.T, actualType tfhe.FheUintType, metadataType tfhe.FheUintType) { + var value uint64 + switch actualType { + case tfhe.FheUint4: + value = 2 + case tfhe.FheUint8: + value = 2 + case tfhe.FheUint16: + value = 4283 + case tfhe.FheUint32: + value = 1333337 + case tfhe.FheUint64: + value = 13333377777777777 + } + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + + resultBz, err := tee.MarshalTfheType(value, actualType) + if err != nil { + t.Fatalf(err.Error()) + } + + plaintext := tee.NewTeePlaintext(resultBz, actualType, addr) + ct, err := tee.Encrypt(plaintext) + if err != nil { + t.Fatalf(err.Error()) + } + + input := prepareInputForVerifyCiphertext(append(ct.Serialization, byte(metadataType))) + _, err = teeVerifyCiphertextRun(environment, addr, addr, input, readOnly, nil) + if err == nil { + t.Fatalf("verifyCiphertext must have failed on type mismatch") + } + + if len(environment.FhevmData().verifiedCiphertexts) != 0 { + t.Fatalf("verifyCiphertext mustn't have verified given ciphertext") + } +} diff --git a/fhevm/tee_interpreter.go b/fhevm/tee_interpreter.go index 186b6ef..4788a35 100644 --- a/fhevm/tee_interpreter.go +++ b/fhevm/tee_interpreter.go @@ -1,7 +1,6 @@ package fhevm import ( - "encoding/binary" "encoding/hex" "errors" "fmt" @@ -55,7 +54,7 @@ func doOp( result := operator(l, r) var resultBz []byte - resultBz, err = marshalTfheType(result, lp.FheUintType) + resultBz, err = tee.MarshalTfheType(result, lp.FheUintType) if err != nil { logger.Error(op, "failed", "err", err) return nil, err @@ -118,7 +117,7 @@ func doEqNeOp( result := operator(l, r) - resultBz, err := marshalTfheType(result, lp.FheUintType) + resultBz, err := tee.MarshalTfheType(result, lp.FheUintType) if err != nil { logger.Error(op, "failed", "err", err) return nil, err @@ -187,7 +186,7 @@ func doShiftOp( return nil, err } var resultBz []byte - resultBz, err = marshalTfheType(result, lp.FheUintType) + resultBz, err = tee.MarshalTfheType(result, lp.FheUintType) if err != nil { logger.Error(op, "failed", "err", err) return nil, err @@ -252,7 +251,7 @@ func doNegNotOp( result := operator(c) var resultBz []byte - resultBz, err = marshalTfheType(result, cp.FheUintType) + resultBz, err = tee.MarshalTfheType(result, cp.FheUintType) if err != nil { logger.Error(op, "failed", "err", err) return nil, err @@ -388,58 +387,6 @@ func extract3Operands(op string, environment EVMEnvironment, input []byte, runSp return &fp, &sp, &tp, fhs, shs, ths, nil } -// marshalTfheType converts a any to a byte slice -func marshalTfheType(value any, typ tfhe.FheUintType) ([]byte, error) { - switch value := any(value).(type) { - case uint64: - switch typ { - case tfhe.FheBool: - resultBz := make([]byte, 1) - resultBz[0] = byte(value) - return resultBz, nil - case tfhe.FheUint4: - value = value & 0x0f - return []byte{byte(value)}, nil - case tfhe.FheUint8: - resultBz := []byte{byte(value)} - return resultBz, nil - case tfhe.FheUint16: - resultBz := make([]byte, 2) - binary.BigEndian.PutUint16(resultBz, uint16(value)) - return resultBz, nil - case tfhe.FheUint32: - resultBz := make([]byte, 4) - binary.BigEndian.PutUint32(resultBz, uint32(value)) - return resultBz, nil - case tfhe.FheUint64: - resultBz := make([]byte, 8) - binary.BigEndian.PutUint64(resultBz, value) - return resultBz, nil - case tfhe.FheUint160: - resultBz := make([]byte, 8) - binary.BigEndian.PutUint64(resultBz, value) - return resultBz, nil - default: - return nil, - fmt.Errorf("unsupported FheUintType: %s", typ) - } - case bool: - resultBz := make([]byte, 1) - if value { - resultBz[0] = 1 - } else { - resultBz[0] = 0 - } - return resultBz, nil - case *big.Int: - resultBz := value.Bytes() - return resultBz, nil - default: - return nil, - fmt.Errorf("unsupported value type: %s", value) - } -} - func boolToUint64(b bool) uint64 { if b { return 1 // true converts to 1 diff --git a/fhevm/tee_test.go b/fhevm/tee_test.go index 757ee9a..6e62184 100644 --- a/fhevm/tee_test.go +++ b/fhevm/tee_test.go @@ -31,7 +31,7 @@ func teeOperationHelper(t *testing.T, fheUintType tfhe.FheUintType, lhs, rhs, ex } input = toLibPrecompileInput(signature, false, lhsCt.GetHash(), rhsCt.GetHash()) } else { - valueBz, err := marshalTfheType(rhs, fheUintType) + valueBz, err := tee.MarshalTfheType(rhs, fheUintType) if err != nil { t.Fatalf(err.Error()) } @@ -151,7 +151,7 @@ func teeNegNotOperationHelper(t *testing.T, fheUintType tfhe.FheUintType, chs, e } func importTeePlaintextToEVM(environment EVMEnvironment, depth int, value any, typ tfhe.FheUintType) (tfhe.TfheCiphertext, error) { - valueBz, err := marshalTfheType(value, typ) + valueBz, err := tee.MarshalTfheType(value, typ) if err != nil { return tfhe.TfheCiphertext{}, err } diff --git a/fhevm/teelib.go b/fhevm/teelib.go index 85e155b..78dd1f9 100644 --- a/fhevm/teelib.go +++ b/fhevm/teelib.go @@ -167,6 +167,12 @@ var teelibMethods = []*FheLibMethod{ requiredGasFunction: teeCastRequiredGas, runFunction: teeCastRun, }, + { + name: "teeVerifyCiphertext", + argTypes: "(bytes)", + requiredGasFunction: teeVerifyCiphertextRequiredGas, + runFunction: teeVerifyCiphertextRun, + }, } func init() { diff --git a/tee/tee_mock.go b/tee/tee_mock.go index 8b59065..caf2d2e 100644 --- a/tee/tee_mock.go +++ b/tee/tee_mock.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" @@ -23,6 +24,7 @@ func init() { panic(err) } key = ecies.ImportECDSA(ecdsaKey) + initCiphertextSizes() } type TeePlaintext struct { @@ -118,3 +120,89 @@ func Decrypt(ct *tfhe.TfheCiphertext) (TeePlaintext, error) { return plaintext, nil } + +func GetTeeCiphertextSize(t tfhe.FheUintType) (size uint, found bool) { + size, found = teeCiphertextSize[t] + return +} + +// Compact TFHE ciphertext sizes by type, in bytes. +var teeCiphertextSize map[tfhe.FheUintType]uint + +func initCiphertextSizes() { + teeCiphertextSize = make(map[tfhe.FheUintType]uint) + + teeCiphertextSize[tfhe.FheBool] = uint(len(TeeEncryptAndSerialize(0, tfhe.FheBool))) + teeCiphertextSize[tfhe.FheUint4] = uint(len(TeeEncryptAndSerialize(0, tfhe.FheUint4))) + teeCiphertextSize[tfhe.FheUint8] = uint(len(TeeEncryptAndSerialize(0, tfhe.FheUint8))) + teeCiphertextSize[tfhe.FheUint16] = uint(len(TeeEncryptAndSerialize(0, tfhe.FheUint16))) + teeCiphertextSize[tfhe.FheUint32] = uint(len(TeeEncryptAndSerialize(0, tfhe.FheUint32))) + teeCiphertextSize[tfhe.FheUint64] = uint(len(TeeEncryptAndSerialize(0, tfhe.FheUint64))) + teeCiphertextSize[tfhe.FheUint160] = uint(len(TeeEncryptAndSerialize(0, tfhe.FheUint160))) +} + +func TeeEncryptAndSerialize(value uint64, fheUintType tfhe.FheUintType) []byte { + resultBz, err := MarshalTfheType(value, fheUintType) + if err != nil { + panic(err) + } + + ciphertext, err := Encrypt(NewTeePlaintext(resultBz, fheUintType, common.Address{})) + if err != nil { + panic(err) + } + + return ciphertext.Serialization +} + +// MarshalTfheType converts a any to a byte slice +func MarshalTfheType(value any, typ tfhe.FheUintType) ([]byte, error) { + switch value := any(value).(type) { + case uint64: + switch typ { + case tfhe.FheBool: + resultBz := make([]byte, 1) + resultBz[0] = byte(value) + return resultBz, nil + case tfhe.FheUint4: + value = value & 0x0f + return []byte{byte(value)}, nil + case tfhe.FheUint8: + resultBz := []byte{byte(value)} + return resultBz, nil + case tfhe.FheUint16: + resultBz := make([]byte, 2) + binary.BigEndian.PutUint16(resultBz, uint16(value)) + return resultBz, nil + case tfhe.FheUint32: + resultBz := make([]byte, 4) + binary.BigEndian.PutUint32(resultBz, uint32(value)) + return resultBz, nil + case tfhe.FheUint64: + resultBz := make([]byte, 8) + binary.BigEndian.PutUint64(resultBz, value) + return resultBz, nil + case tfhe.FheUint160: + resultBz := make([]byte, 8) + binary.BigEndian.PutUint64(resultBz, value) + return resultBz, nil + default: + return nil, + fmt.Errorf("unsupported FheUintType: %s", typ) + } + case bool: + resultBz := make([]byte, 1) + if value { + resultBz[0] = 1 + } else { + resultBz[0] = 0 + } + return resultBz, nil + case *big.Int: + resultBz := value.Bytes() + return resultBz, nil + default: + return nil, + fmt.Errorf("unsupported value type: %s", value) + } +}