Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/glua.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,20 @@ pub fn function(f: fn(List(Value)) -> Action(List(Value), Never)) -> Value {
/// to encourage using `glua.error` instead since `glua.failure` wouldn't make sense in that case.
pub type Never

pub fn function_decoder() -> decode.Decoder(
fn(List(Value)) -> Action(List(Value), e),
) {
decode.new_primitive_decoder("LuaFunction", decode_lua_function)
}

@external(erlang, "glua_ffi", "decode_fun")
fn decode_lua_function(
v: dynamic.Dynamic,
) -> Result(
fn(List(Value)) -> Action(List(Value), e),
fn(List(Value)) -> Action(List(Value), e),
)

pub fn list(encoder: fn(a) -> Value, values: List(a)) -> List(Value) {
list.map(values, encoder)
}
Expand Down
29 changes: 19 additions & 10 deletions src/glua_ffi.erl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
-import(ttdict, [fold/3]).
-include_lib("luerl/include/luerl.hrl").

-export([get_stacktrace/1, dereference/2, coerce/1, coerce_nil/0, wrap_fun/1, sandbox_fun/1, get_table_keys/2,
-export([get_stacktrace/1, dereference/2, coerce/1, coerce_nil/0, wrap_fun/1, decode_fun/1, sandbox_fun/1, get_table_keys/2,
get_private/2, set_table_keys/3, load/2, load_file/2, eval/2, eval_file/2,
eval_chunk/2, call_function/3]).


%% helper to convert luerl return values to a format
%% that is more suitable for use in Gleam code
to_gleam(Value) ->
Expand Down Expand Up @@ -39,11 +40,11 @@ dereference(#usdref{}=U, St, _In) ->
{#userdata{d=Data},_} = luerl_heap:get_userdata(U, St),
Data;
dereference(#funref{}=Fun, _St, _In) ->
dereference_fun(Fun);
dereference_fun(fun(Args, State) -> luerl_emul:functioncall(Fun, Args, State) end);
dereference(#erl_func{code=Fun}, _St, _In) ->
Fun; %Just the bare fun
dereference(#erl_mfa{m=M, f=F}, _St, _In) ->
dereference_fun(fun(Args, St0) -> M:F(nil, Args, St0) end);
dereference_fun(fun(Args, State) -> M:F(nil, Args, State) end);
dereference(Lua, _, _) -> error({badarg,Lua}). %Shouldn't have anything else

dereference_table(#tref{i=N}=T, St, In0) ->
Expand All @@ -70,13 +71,15 @@ dereference_table(#tref{i=N}=T, St, In0) ->
end.

dereference_fun(F) when is_function(F, 2) ->
{luafun, fun(St0, Args) ->
try
{Ret, St1} = F(Args, St0),
{ok, {St1, Ret}}
catch
error:{lua_error, _, _} = Err -> {error, map_error(Err)}
end
{luafun, fun(Args) ->
{action, fun(St0) ->
try
{Ret, St1} = F(Args, St0),
{ok, {St1, Ret}}
catch
error:{lua_error, _, _} = Err -> {error, map_error(Err)}
end
end}
end}.

map_error({error, Errors, _}) ->
Expand Down Expand Up @@ -229,6 +232,12 @@ sandbox_fun(Msg) ->
{error, map_error(lua_error({error_call, [Msg]}, State))}
end}.

decode_fun(Fun) ->
case Fun of
{luafun, F} -> {ok, F};
_ -> {error, fun(_) -> {action, fun(State) -> {ok, {State, nil}} end} end}
end.

get_table_keys(Lua, Keys) ->
case luerl:get_table_keys(Keys, Lua) of
{ok, nil, _} ->
Expand Down
143 changes: 143 additions & 0 deletions test/glua_test.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,146 @@ pub fn format_error_test() {
let assert Error(e) = glua.run(glua.new(), glua.failure(1))
assert glua.format_error(e) == "1"
}

pub fn decode_function_test() {
let lua_sqrt = {
use ref <- glua.then(glua.get(keys: ["math", "sqrt"]))
glua.dereference(ref:, using: glua.function_decoder())
}

let assert Ok(_) =
glua.run(glua.new(), {
use fun <- glua.then(lua_sqrt)
use ret <- glua.then(fun([glua.int(9)]))
use ref <- glua.try(list.first(ret))
use result <- glua.map(glua.dereference(ref:, using: decode.float))

assert result == 3.0
Nil
})

let assert Error(glua.LuaRuntimeException(_, _)) =
glua.run(glua.new(), {
use fun <- glua.then(lua_sqrt)
fun([])
})

let code =
"
local function fold(tbl, initial, cb)
local acc = initial
for _, v in ipairs(tbl) do
local new_acc = cb(acc, v)
acc = new_acc
end
return acc
end

return fold
"

let lua_fold = {
use ret <- glua.then(glua.eval(code:))
use ref <- glua.try(list.first(ret))
glua.dereference(ref:, using: glua.function_decoder())
}

let assert Ok(_) =
glua.run(glua.new(), {
use tbl <- glua.then(
list.index_map([3, 9, 27], fn(x, i) { #(glua.int(i + 1), glua.int(x)) })
|> glua.table,
)

let callback =
fn(args) {
use args <- glua.map(glua.fold(args, glua.dereference(_, decode.int)))
let assert [acc, v] = args
glua.int(acc * v)
|> list.wrap
}
|> glua.function

use fun <- glua.then(lua_fold)
use ret <- glua.then(fun([tbl, glua.int(1), callback]))
use ref <- glua.try(list.first(ret))
use result <- glua.map(glua.dereference(ref:, using: decode.int))
assert result == 729
Nil
})

let assert Error(glua.LuaRuntimeException(exn, _)) =
glua.run(glua.new(), {
use fun <- glua.then(lua_fold)
use tbl <- glua.then(
list.index_map([3, 9, 27], fn(x, i) { #(glua.int(i + 1), glua.int(x)) })
|> glua.table,
)
fun([tbl, glua.int(1)])
})
assert exn == glua.UndefinedFunction("nil")

let lua_table_unpack = {
use ref <- glua.then(glua.get(keys: ["table", "unpack"]))
glua.dereference(ref:, using: glua.function_decoder())
}

let assert Ok(_) =
glua.run(glua.new(), {
use fun <- glua.then(lua_table_unpack)
use tbl <- glua.then(
["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]
|> list.index_map(fn(l, i) { #(glua.int(i + 1), glua.string(l)) })
|> glua.table,
)

use ret <- glua.then(fun([tbl, glua.int(4), glua.int(8)]))
use result <- glua.then(
glua.fold(ret, glua.dereference(_, using: decode.string)),
)

assert result == ["d", "e", "f", "g", "h"]

use ret <- glua.then(fun([tbl, glua.int(8)]))
use result <- glua.then(
glua.fold(ret, glua.dereference(_, using: decode.string)),
)
assert result == ["h", "i", "j"]
glua.success(Nil)
})

let assert Error(glua.LuaRuntimeException(_, _)) =
glua.run(glua.new(), {
use fun <- glua.then(lua_table_unpack)
fun([])
})

let code = "return function() error('some error') end"

let assert Error(glua.LuaRuntimeException(exn, _)) =
glua.run(glua.new(), {
use ret <- glua.then(glua.eval(code:))
use ref <- glua.try(list.first(ret))
use fun <- glua.then(glua.dereference(
ref:,
using: glua.function_decoder(),
))
fun([])
})

assert exn == glua.ErrorCall("some error", option.None)

let code = "return function() return 3 * true end"
let assert Error(glua.LuaRuntimeException(exn, _)) =
glua.run(glua.new(), {
use ret <- glua.then(glua.eval(code:))
use ref <- glua.try(list.first(ret))
use fun <- glua.then(glua.dereference(
ref:,
using: glua.function_decoder(),
))
fun([])
})

assert exn == glua.BadArith("*", ["3", "true"])
}