Skip to content

Commit 64f1727

Browse files
authored
Interface for solver supporting diff natively (#356)
* Interface for solver supporting diff natively * Fixes * Fix format * Claude improves coverage * Fix format * Add tests * fix format * Rename * Use the attributes in inner models too * Simplify * Fix format * Fix * Restore keyword option for nonlinear * Fix * Fix * Fixes * Fix * Fix * Fix * Fix * Simpler * Fix format * Simplify handling of POI layer * Fix format
1 parent 5aff2b0 commit 64f1727

File tree

11 files changed

+1304
-114
lines changed

11 files changed

+1304
-114
lines changed

src/ConicProgram/ConicProgram.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,9 @@ function _gradient_cache(model::Model)
253253
return model.gradient_cache
254254
end
255255

256-
function DiffOpt.forward_differentiate!(model::Model)
256+
MOI.supports(::Model, ::DiffOpt.ForwardDifferentiate) = true
257+
258+
function MOI.set(model::Model, ::DiffOpt.ForwardDifferentiate, ::Nothing)
257259
model.diff_time = @elapsed begin
258260
gradient_cache = _gradient_cache(model)
259261
M = gradient_cache.M
@@ -333,7 +335,9 @@ function DiffOpt.forward_differentiate!(model::Model)
333335
# return -dx, -dy, -ds
334336
end
335337

336-
function DiffOpt.reverse_differentiate!(model::Model)
338+
MOI.supports(::Model, ::DiffOpt.ReverseDifferentiate) = true
339+
340+
function MOI.set(model::Model, ::DiffOpt.ReverseDifferentiate, ::Nothing)
337341
model.diff_time = @elapsed begin
338342
gradient_cache = _gradient_cache(model)
339343
M = gradient_cache.M

src/NonLinearProgram/NonLinearProgram.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,13 @@ function _cache_evaluator!(model::Model)
512512
return model.cache
513513
end
514514

515+
MOI.supports(::Model, ::DiffOpt.ForwardDifferentiate) = true
516+
517+
function MOI.set(model::Model, ::DiffOpt.ForwardDifferentiate, value)
518+
kws = something(value, (;))
519+
return DiffOpt.forward_differentiate!(model; kws...)
520+
end
521+
515522
function DiffOpt.forward_differentiate!(model::Model; tol = 1e-6)
516523
model.diff_time = @elapsed begin
517524
cache = _cache_evaluator!(model)
@@ -544,6 +551,13 @@ function DiffOpt.forward_differentiate!(model::Model; tol = 1e-6)
544551
return nothing
545552
end
546553

554+
MOI.supports(::Model, ::DiffOpt.ReverseDifferentiate) = true
555+
556+
function MOI.set(model::Model, ::DiffOpt.ReverseDifferentiate, value)
557+
kws = something(value, (;))
558+
return DiffOpt.reverse_differentiate!(model; kws...)
559+
end
560+
547561
function DiffOpt.reverse_differentiate!(model::Model; tol = 1e-6)
548562
model.diff_time = @elapsed begin
549563
cache = _cache_evaluator!(model)

src/QuadraticProgram/QuadraticProgram.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ function DiffOpt._get_db(model::Model, ci::EQ)
313313
return model.back_grad_cache.dν[ci.value]
314314
end
315315

316-
function DiffOpt.reverse_differentiate!(model::Model)
316+
MOI.supports(::Model, ::DiffOpt.ReverseDifferentiate) = true
317+
318+
function MOI.set(model::Model, ::DiffOpt.ReverseDifferentiate, ::Nothing)
317319
model.diff_time = @elapsed begin
318320
gradient_cache = _gradient_cache(model)
319321
LHS = gradient_cache.lhs
@@ -354,7 +356,9 @@ end
354356
struct _QPSets end
355357
MOI.Utilities.rows(::_QPSets, ci::MOI.ConstraintIndex) = ci.value
356358

357-
function DiffOpt.forward_differentiate!(model::Model)
359+
MOI.supports(::Model, ::DiffOpt.ForwardDifferentiate) = true
360+
361+
function MOI.set(model::Model, ::DiffOpt.ForwardDifferentiate, ::Nothing)
358362
model.diff_time = @elapsed begin
359363
gradient_cache = _gradient_cache(model)
360364
LHS = gradient_cache.lhs

src/diff_opt.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ with respect to the solution set with the [`ReverseVariablePrimal`](@ref) attrib
5454
The output problem data differentials can be queried with the
5555
attributes [`ReverseObjectiveFunction`](@ref) and [`ReverseConstraintFunction`](@ref).
5656
"""
57-
function reverse_differentiate! end
57+
function reverse_differentiate!(model)
58+
return MOI.set(model, ReverseDifferentiate(), nothing)
59+
end
5860

5961
"""
60-
forward_differentiate!(model::Optimizer)
62+
forward_differentiate!(model::Union{MOI.ModelLike,JuMP.AbstractModel})
6163
6264
Wrapper method for the forward pass.
6365
This method will consider as input a currently solved problem and
@@ -66,7 +68,9 @@ the [`ForwardObjectiveFunction`](@ref) and [`ForwardConstraintFunction`](@ref)
6668
The output solution differentials can be queried with the attribute
6769
[`ForwardVariablePrimal`](@ref).
6870
"""
69-
function forward_differentiate! end
71+
function forward_differentiate!(model)
72+
return MOI.set(model, ForwardDifferentiate(), nothing)
73+
end
7074

7175
"""
7276
empty_input_sensitivities!(model::MOI.ModelLike)
@@ -314,6 +318,40 @@ the differentiation information.
314318
"""
315319
struct DifferentiateTimeSec <: MOI.AbstractModelAttribute end
316320

321+
"""
322+
ReverseDifferentiate <: MOI.AbstractOptimizerAttribute
323+
324+
An `MOI.AbstractOptimizerAttribute` that triggers reverse differentiation
325+
on the solver. If `MOI.supports(optimizer, DiffOpt.ReverseDifferentiate())`
326+
returns `true`, then the solver natively supports reverse differentiation
327+
through the DiffOpt attribute interface, and DiffOpt will delegate
328+
differentiation directly to the solver instead of using its own
329+
differentiation backend.
330+
331+
Trigger the computation with:
332+
```julia
333+
MOI.set(optimizer, DiffOpt.ReverseDifferentiate(), nothing)
334+
```
335+
"""
336+
struct ReverseDifferentiate <: MOI.AbstractOptimizerAttribute end
337+
338+
"""
339+
ForwardDifferentiate <: MOI.AbstractOptimizerAttribute
340+
341+
An `MOI.AbstractOptimizerAttribute` that triggers forward differentiation
342+
on the solver. If `MOI.supports(optimizer, DiffOpt.ForwardDifferentiate())`
343+
returns `true`, then the solver natively supports forward differentiation
344+
through the DiffOpt attribute interface, and DiffOpt will delegate
345+
differentiation directly to the solver instead of using its own
346+
differentiation backend.
347+
348+
Trigger the computation with:
349+
```julia
350+
MOI.set(optimizer, DiffOpt.ForwardDifferentiate(), nothing)
351+
```
352+
"""
353+
struct ForwardDifferentiate <: MOI.AbstractOptimizerAttribute end
354+
317355
MOI.attribute_value_type(::DifferentiateTimeSec) = Float64
318356

319357
MOI.is_set_by_optimize(::DifferentiateTimeSec) = true
@@ -434,6 +472,14 @@ function MOI.set(model::AbstractModel, ::ForwardObjectiveFunction, objective)
434472
return
435473
end
436474

475+
function MOI.supports(::AbstractModel, ::NonLinearKKTJacobianFactorization)
476+
return true
477+
end
478+
479+
function MOI.supports(::AbstractModel, ::AllowObjectiveAndSolutionInput)
480+
return true
481+
end
482+
437483
function MOI.set(
438484
model::AbstractModel,
439485
::NonLinearKKTJacobianFactorization,

src/jump_moi_overloads.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,20 @@ end
251251

252252
MOI.constant(func::IndexMappedFunction) = MOI.constant(func.func)
253253

254+
# Support JuMP.coefficient on plain MOI functions returned by native solvers
255+
function JuMP.coefficient(
256+
func::MOI.ScalarAffineFunction{T},
257+
vi::MOI.VariableIndex,
258+
) where {T}
259+
coef = zero(T)
260+
for term in func.terms
261+
if term.variable == vi
262+
coef += term.coefficient
263+
end
264+
end
265+
return coef
266+
end
267+
254268
function JuMP.coefficient(func::IndexMappedFunction, vi::MOI.VariableIndex)
255269
return JuMP.coefficient(func.func, func.index_map[vi])
256270
end

src/moi_wrapper.jl

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,9 @@ end
548548

549549
MOI.get(model::Optimizer, ::ModelConstructor) = model.model_constructor
550550

551-
function reverse_differentiate!(model::Optimizer)
551+
MOI.supports(::Optimizer, ::ReverseDifferentiate) = true
552+
553+
function MOI.set(model::Optimizer, attr::ReverseDifferentiate, value)
552554
st = MOI.get(model.optimizer, MOI.TerminationStatus())
553555
if !in(st, (MOI.LOCALLY_SOLVED, MOI.OPTIMAL))
554556
error(
@@ -562,17 +564,22 @@ function reverse_differentiate!(model::Optimizer)
562564
"Set `DiffOpt.AllowObjectiveAndSolutionInput()` to `true` to silence this warning."
563565
end
564566
end
565-
diff = _diff(model)
566-
MOI.set(
567-
diff,
568-
NonLinearKKTJacobianFactorization(),
569-
model.input_cache.factorization,
570-
)
571-
MOI.set(
572-
diff,
573-
AllowObjectiveAndSolutionInput(),
574-
model.input_cache.allow_objective_and_solution_input,
575-
)
567+
diff = _diff(model, attr)
568+
# Native differentiation interface are allow to not support it
569+
if MOI.supports(diff, NonLinearKKTJacobianFactorization())
570+
MOI.set(
571+
diff,
572+
NonLinearKKTJacobianFactorization(),
573+
model.input_cache.factorization,
574+
)
575+
end
576+
if MOI.supports(diff, AllowObjectiveAndSolutionInput())
577+
MOI.set(
578+
diff,
579+
AllowObjectiveAndSolutionInput(),
580+
model.input_cache.allow_objective_and_solution_input,
581+
)
582+
end
576583
for (vi, value) in model.input_cache.dx
577584
MOI.set(diff, ReverseVariablePrimal(), model.index_map[vi], value)
578585
end
@@ -593,7 +600,7 @@ function reverse_differentiate!(model::Optimizer)
593600
end
594601
end
595602
end
596-
return reverse_differentiate!(diff)
603+
return MOI.set(diff, attr, value)
597604
end
598605

599606
# Gradient evaluation functions for objective sensitivity fallbacks
@@ -639,13 +646,12 @@ function _eval_gradient(
639646
end
640647

641648
function _fallback_set_reverse_objective_sensitivity(model::Optimizer, val)
642-
diff = _diff(model)
643649
obj_type = MOI.get(model, MOI.ObjectiveFunctionType())
644650
obj_func = MOI.get(model, MOI.ObjectiveFunction{obj_type}())
645651
grad = _eval_gradient(model, obj_func)
646652
for (xi, df_dxi) in grad
647653
MOI.set(
648-
diff,
654+
model.diff,
649655
ReverseVariablePrimal(),
650656
model.index_map[xi],
651657
df_dxi * val,
@@ -666,24 +672,30 @@ function _copy_forward_in_constraint(diff, index_map, con_map, constraints)
666672
return
667673
end
668674

669-
function forward_differentiate!(model::Optimizer)
675+
MOI.supports(model::Optimizer, attr::ForwardDifferentiate) = true
676+
677+
function MOI.set(model::Optimizer, attr::ForwardDifferentiate, value)
670678
st = MOI.get(model.optimizer, MOI.TerminationStatus())
671679
if !in(st, (MOI.LOCALLY_SOLVED, MOI.OPTIMAL))
672680
error(
673681
"Trying to compute the forward differentiation on a model with termination status $(st)",
674682
)
675683
end
676-
diff = _diff(model)
677-
MOI.set(
678-
diff,
679-
NonLinearKKTJacobianFactorization(),
680-
model.input_cache.factorization,
681-
)
682-
MOI.set(
683-
diff,
684-
AllowObjectiveAndSolutionInput(),
685-
model.input_cache.allow_objective_and_solution_input,
686-
)
684+
diff = _diff(model, attr)
685+
if MOI.supports(diff, NonLinearKKTJacobianFactorization())
686+
MOI.set(
687+
diff,
688+
NonLinearKKTJacobianFactorization(),
689+
model.input_cache.factorization,
690+
)
691+
end
692+
if MOI.supports(diff, AllowObjectiveAndSolutionInput())
693+
MOI.set(
694+
diff,
695+
AllowObjectiveAndSolutionInput(),
696+
model.input_cache.allow_objective_and_solution_input,
697+
)
698+
end
687699
T = Float64
688700
list = MOI.get(
689701
model,
@@ -700,7 +712,7 @@ function forward_differentiate!(model::Optimizer)
700712
MOI.Parameter(value),
701713
)
702714
end
703-
return forward_differentiate!(diff)
715+
return MOI.set(diff, attr, value)
704716
end
705717
# @show "func mode"
706718
if model.input_cache.objective !== nothing
@@ -733,7 +745,7 @@ function forward_differentiate!(model::Optimizer)
733745
model.input_cache.vector_constraints[F, S],
734746
)
735747
end
736-
return forward_differentiate!(diff)
748+
return MOI.set(diff, attr, value)
737749
end
738750

739751
function empty_input_sensitivities!(model::Optimizer)
@@ -782,8 +794,24 @@ function _instantiate_diff(model::Optimizer, constructor)
782794
return model_bridged
783795
end
784796

785-
function _diff(model::Optimizer)
786-
if model.diff === nothing
797+
# Find the native differentiation solver in the optimizer chain.
798+
# Cached in `model.diff` to avoid repeated unwrapping.
799+
function _native_diff_solver(model::Optimizer)
800+
if isnothing(model.diff)
801+
model.diff = model.optimizer
802+
model.index_map = MOI.Utilities.identity_index_map(model.optimizer)
803+
end
804+
return model.diff
805+
end
806+
807+
function _diff(
808+
model::Optimizer,
809+
attr::Union{ForwardDifferentiate,ReverseDifferentiate},
810+
)
811+
if MOI.supports(model.optimizer, attr)
812+
model.diff = model.optimizer
813+
model.index_map = MOI.Utilities.identity_index_map(model.optimizer)
814+
elseif isnothing(model.diff)
787815
_check_termination_status(model)
788816
model_constructor = MOI.get(model, ModelConstructor())
789817
if isnothing(model_constructor)
@@ -1128,15 +1156,7 @@ function MOI.get(model::Optimizer, attr::DifferentiateTimeSec)
11281156
return MOI.get(model.diff, attr)
11291157
end
11301158

1131-
function MOI.supports(
1132-
::Optimizer,
1133-
::NonLinearKKTJacobianFactorization,
1134-
::Function,
1135-
)
1136-
return true
1137-
end
1138-
1139-
function MOI.supports(::Optimizer, ::AllowObjectiveAndSolutionInput, ::Bool)
1159+
function MOI.supports(::Optimizer, ::NonLinearKKTJacobianFactorization)
11401160
return true
11411161
end
11421162

src/parameters.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -479,8 +479,11 @@ function empty_input_sensitivities!(model::POI.Optimizer{T}) where {T}
479479
return
480480
end
481481

482-
function forward_differentiate!(model::POI.Optimizer{T}) where {T}
483-
empty_input_sensitivities!(model.optimizer)
482+
function MOI.set(
483+
model::POI.Optimizer{T},
484+
attr::ForwardDifferentiate,
485+
value,
486+
) where {T}
484487
ctr_types = MOI.get(model, POI.ListOfParametricConstraintTypesPresent())
485488
for (F, S, P) in ctr_types
486489
dict = MOI.get(
@@ -497,7 +500,7 @@ function forward_differentiate!(model::POI.Optimizer{T}) where {T}
497500
elseif obj_type <: POI.ParametricCubicFunction
498501
_cubic_objective_set_forward!(model)
499502
end
500-
forward_differentiate!(model.optimizer)
503+
MOI.set(model.optimizer, attr, value)
501504
return
502505
end
503506

@@ -699,8 +702,8 @@ function _quadratic_objective_get_reverse!(model::POI.Optimizer{T}) where {T}
699702
return
700703
end
701704

702-
function reverse_differentiate!(model::POI.Optimizer)
703-
reverse_differentiate!(model.optimizer)
705+
function MOI.set(model::POI.Optimizer, attr::ReverseDifferentiate, value)
706+
MOI.set(model.optimizer, attr, value)
704707
sensitivity_data = _get_sensitivity_data(model)
705708
empty!(sensitivity_data.parameter_output_backward)
706709
sizehint!(

test/jump.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,11 +688,9 @@ function test_conic_supports()
688688
DiffOpt.ForwardConstraintSet(),
689689
MOI.ConstraintIndex{MOI.VariableIndex,MOI.Parameter{Float64}},
690690
)
691-
function some end
692691
@test MOI.supports(
693692
backend(model),
694693
DiffOpt.NonLinearKKTJacobianFactorization(),
695-
some,
696694
)
697695
MOI.is_set_by_optimize(DiffOpt.ReverseConstraintFunction())
698696
MOI.is_set_by_optimize(DiffOpt.ReverseConstraintSet())

0 commit comments

Comments
 (0)