
User-facing API #2

@klamike

Description


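Currently the functionality exists, but it is rather convoluted to use. Evaluating the DLL layer and its gradient with respect to `y` requires three manual steps: flattening `y`, unflattening it inside the closure passed to DifferentiationInterface, and unflattening the resulting gradient again.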

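From `L2ODLL.jl/test/runtests.jl`, lines 92 to 100 at commit fe32671:

```julia
# Build the cache for the bound decomposition of model `m`
blp_cache = L2ODLL.build_cache(m, L2ODLL.BoundDecomposition(m));
# Random prediction with the same structure as `y`
blp_y_pred = randn_like(L2ODLL.get_y(blp_cache));
# Evaluating the layer directly is simple...
dobj1 = blp_cache.dll_layer(blp_y_pred, param_value)
# ...but differentiating it requires flattening `y` and unflattening it inside the closure
dobj, dobj_wrt_y = DifferentiationInterface.value_and_gradient(
    (y, p) -> blp_cache.dll_layer(L2ODLL.unflatten_y(y, L2ODLL.y_shape(blp_cache)), p),
    DifferentiationInterface.AutoForwardDiff(),
    L2ODLL.flatten_y(blp_y_pred), DifferentiationInterface.Constant(param_value)
)
# ...and unflattening the gradient back to the structure of `y`
dobj_wrt_y = L2ODLL.unflatten_y(dobj_wrt_y, L2ODLL.y_shape(blp_cache))
```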
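One possible shape for a user-facing API is a single function that hides the flatten/unflatten bookkeeping. The sketch below is hypothetical: `dual_objective_and_gradient` and its `backend` keyword are illustrative names, not part of L2ODLL. It uses only the L2ODLL calls and DifferentiationInterface calls that appear in the test excerpt, and it assumes ForwardDiff is installed so that `AutoForwardDiff()` works.

```julia
using L2ODLL
import DifferentiationInterface as DI
import ForwardDiff  # loading ForwardDiff activates DI's ForwardDiff backend

# Hypothetical convenience wrapper (not part of L2ODLL): returns the DLL layer value
# and its gradient w.r.t. `y`, with the gradient in the same structure as `y_pred`.
function dual_objective_and_gradient(cache, y_pred, param_value;
                                     backend = DI.AutoForwardDiff())
    shape = L2ODLL.y_shape(cache)
    # Closure over the flat vector; the parameter is passed as a non-differentiated context
    f(y, p) = cache.dll_layer(L2ODLL.unflatten_y(y, shape), p)
    dobj, grad_flat = DI.value_and_gradient(
        f, backend, L2ODLL.flatten_y(y_pred), DI.Constant(param_value)
    )
    return dobj, L2ODLL.unflatten_y(grad_flat, shape)
end
```

With `blp_cache`, `blp_y_pred`, and `param_value` defined as in the excerpt, the whole gradient computation would reduce to `dobj, dobj_wrt_y = dual_objective_and_gradient(blp_cache, blp_y_pred, param_value)`. Exposing `backend` as a keyword keeps ForwardDiff as the default while leaving room for other DifferentiationInterface backends.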
