From 373928a77ccf748220fc0831d5650d74444088b2 Mon Sep 17 00:00:00 2001 From: selenil Date: Thu, 19 Feb 2026 23:40:18 -0500 Subject: [PATCH] add function decoder implement function decoder remove whitespace remove whitespace --- src/glua.gleam | 14 +++++ src/glua_ffi.erl | 29 ++++++--- test/glua_test.gleam | 143 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 176 insertions(+), 10 deletions(-) diff --git a/src/glua.gleam b/src/glua.gleam index 8872ed8..24c2f1e 100644 --- a/src/glua.gleam +++ b/src/glua.gleam @@ -556,6 +556,20 @@ pub fn function(f: fn(List(Value)) -> Action(List(Value), Never)) -> Value { /// to encourage using `glua.error` instead since `glua.failure` wouldn't make sense in that case. pub type Never +pub fn function_decoder() -> decode.Decoder( + fn(List(Value)) -> Action(List(Value), e), +) { + decode.new_primitive_decoder("LuaFunction", decode_lua_function) +} + +@external(erlang, "glua_ffi", "decode_fun") +fn decode_lua_function( + v: dynamic.Dynamic, +) -> Result( + fn(List(Value)) -> Action(List(Value), e), + fn(List(Value)) -> Action(List(Value), e), +) + pub fn list(encoder: fn(a) -> Value, values: List(a)) -> List(Value) { list.map(values, encoder) } diff --git a/src/glua_ffi.erl b/src/glua_ffi.erl index d585b28..acf808c 100644 --- a/src/glua_ffi.erl +++ b/src/glua_ffi.erl @@ -3,10 +3,11 @@ -import(ttdict, [fold/3]). -include_lib("luerl/include/luerl.hrl"). --export([get_stacktrace/1, dereference/2, coerce/1, coerce_nil/0, wrap_fun/1, sandbox_fun/1, get_table_keys/2, +-export([get_stacktrace/1, dereference/2, coerce/1, coerce_nil/0, wrap_fun/1, decode_fun/1, sandbox_fun/1, get_table_keys/2, get_private/2, set_table_keys/3, load/2, load_file/2, eval/2, eval_file/2, eval_chunk/2, call_function/3]). + %% helper to convert luerl return values to a format %% that is more suitable for use in Gleam code to_gleam(Value) -> @@ -39,11 +40,11 @@ dereference(#usdref{}=U, St, _In) -> {#userdata{d=Data},_} = luerl_heap:get_userdata(U, St), Data; dereference(#funref{}=Fun, _St, _In) -> - dereference_fun(Fun); + dereference_fun(fun(Args, State) -> luerl_emul:functioncall(Fun, Args, State) end); dereference(#erl_func{code=Fun}, _St, _In) -> Fun; %Just the bare fun dereference(#erl_mfa{m=M, f=F}, _St, _In) -> - dereference_fun(fun(Args, St0) -> M:F(nil, Args, St0) end); + dereference_fun(fun(Args, State) -> M:F(nil, Args, State) end); dereference(Lua, _, _) -> error({badarg,Lua}). %Shouldn't have anything else dereference_table(#tref{i=N}=T, St, In0) -> @@ -70,13 +71,15 @@ dereference_table(#tref{i=N}=T, St, In0) -> end. dereference_fun(F) when is_function(F, 2) -> - {luafun, fun(St0, Args) -> - try - {Ret, St1} = F(Args, St0), - {ok, {St1, Ret}} - catch - error:{lua_error, _, _} = Err -> {error, map_error(Err)} - end + {luafun, fun(Args) -> + {action, fun(St0) -> + try + {Ret, St1} = F(Args, St0), + {ok, {St1, Ret}} + catch + error:{lua_error, _, _} = Err -> {error, map_error(Err)} + end + end} end}. map_error({error, Errors, _}) -> @@ -229,6 +232,12 @@ sandbox_fun(Msg) -> {error, map_error(lua_error({error_call, [Msg]}, State))} end}. +decode_fun(Fun) -> + case Fun of + {luafun, F} -> {ok, F}; + _ -> {error, fun(_) -> {action, fun(State) -> {ok, {State, nil}} end} end} + end. + get_table_keys(Lua, Keys) -> case luerl:get_table_keys(Keys, Lua) of {ok, nil, _} -> diff --git a/test/glua_test.gleam b/test/glua_test.gleam index 3fc8e42..d28b2ae 100644 --- a/test/glua_test.gleam +++ b/test/glua_test.gleam @@ -763,3 +763,146 @@ pub fn format_error_test() { let assert Error(e) = glua.run(glua.new(), glua.failure(1)) assert glua.format_error(e) == "1" } + +pub fn decode_function_test() { + let lua_sqrt = { + use ref <- glua.then(glua.get(keys: ["math", "sqrt"])) + glua.dereference(ref:, using: glua.function_decoder()) + } + + let assert Ok(_) = + glua.run(glua.new(), { + use fun <- glua.then(lua_sqrt) + use ret <- glua.then(fun([glua.int(9)])) + use ref <- glua.try(list.first(ret)) + use result <- glua.map(glua.dereference(ref:, using: decode.float)) + + assert result == 3.0 + Nil + }) + + let assert Error(glua.LuaRuntimeException(_, _)) = + glua.run(glua.new(), { + use fun <- glua.then(lua_sqrt) + fun([]) + }) + + let code = + " + local function fold(tbl, initial, cb) + local acc = initial + for _, v in ipairs(tbl) do + local new_acc = cb(acc, v) + acc = new_acc + end + return acc + end + + return fold +" + + let lua_fold = { + use ret <- glua.then(glua.eval(code:)) + use ref <- glua.try(list.first(ret)) + glua.dereference(ref:, using: glua.function_decoder()) + } + + let assert Ok(_) = + glua.run(glua.new(), { + use tbl <- glua.then( + list.index_map([3, 9, 27], fn(x, i) { #(glua.int(i + 1), glua.int(x)) }) + |> glua.table, + ) + + let callback = + fn(args) { + use args <- glua.map(glua.fold(args, glua.dereference(_, decode.int))) + let assert [acc, v] = args + glua.int(acc * v) + |> list.wrap + } + |> glua.function + + use fun <- glua.then(lua_fold) + use ret <- glua.then(fun([tbl, glua.int(1), callback])) + use ref <- glua.try(list.first(ret)) + use result <- glua.map(glua.dereference(ref:, using: decode.int)) + assert result == 729 + Nil + }) + + let assert Error(glua.LuaRuntimeException(exn, _)) = + glua.run(glua.new(), { + use fun <- glua.then(lua_fold) + use tbl <- glua.then( + list.index_map([3, 9, 27], fn(x, i) { #(glua.int(i + 1), glua.int(x)) }) + |> glua.table, + ) + fun([tbl, glua.int(1)]) + }) + assert exn == glua.UndefinedFunction("nil") + + let lua_table_unpack = { + use ref <- glua.then(glua.get(keys: ["table", "unpack"])) + glua.dereference(ref:, using: glua.function_decoder()) + } + + let assert Ok(_) = + glua.run(glua.new(), { + use fun <- glua.then(lua_table_unpack) + use tbl <- glua.then( + ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] + |> list.index_map(fn(l, i) { #(glua.int(i + 1), glua.string(l)) }) + |> glua.table, + ) + + use ret <- glua.then(fun([tbl, glua.int(4), glua.int(8)])) + use result <- glua.then( + glua.fold(ret, glua.dereference(_, using: decode.string)), + ) + + assert result == ["d", "e", "f", "g", "h"] + + use ret <- glua.then(fun([tbl, glua.int(8)])) + use result <- glua.then( + glua.fold(ret, glua.dereference(_, using: decode.string)), + ) + assert result == ["h", "i", "j"] + glua.success(Nil) + }) + + let assert Error(glua.LuaRuntimeException(_, _)) = + glua.run(glua.new(), { + use fun <- glua.then(lua_table_unpack) + fun([]) + }) + + let code = "return function() error('some error') end" + + let assert Error(glua.LuaRuntimeException(exn, _)) = + glua.run(glua.new(), { + use ret <- glua.then(glua.eval(code:)) + use ref <- glua.try(list.first(ret)) + use fun <- glua.then(glua.dereference( + ref:, + using: glua.function_decoder(), + )) + fun([]) + }) + + assert exn == glua.ErrorCall("some error", option.None) + + let code = "return function() return 3 * true end" + let assert Error(glua.LuaRuntimeException(exn, _)) = + glua.run(glua.new(), { + use ret <- glua.then(glua.eval(code:)) + use ref <- glua.try(list.first(ret)) + use fun <- glua.then(glua.dereference( + ref:, + using: glua.function_decoder(), + )) + fun([]) + }) + + assert exn == glua.BadArith("*", ["3", "true"]) +}