548548
549549MOI. get (model:: Optimizer , :: ModelConstructor ) = model. model_constructor
550550
551- function reverse_differentiate! (model:: Optimizer )
551+ MOI. supports (:: Optimizer , :: ReverseDifferentiate ) = true
552+
553+ function MOI. set (model:: Optimizer , attr:: ReverseDifferentiate , value)
552554 st = MOI. get (model. optimizer, MOI. TerminationStatus ())
553555 if ! in (st, (MOI. LOCALLY_SOLVED, MOI. OPTIMAL))
554556 error (
@@ -562,17 +564,22 @@ function reverse_differentiate!(model::Optimizer)
562564 " Set `DiffOpt.AllowObjectiveAndSolutionInput()` to `true` to silence this warning."
563565 end
564566 end
565- diff = _diff (model)
566- MOI. set (
567- diff,
568- NonLinearKKTJacobianFactorization (),
569- model. input_cache. factorization,
570- )
571- MOI. set (
572- diff,
573- AllowObjectiveAndSolutionInput (),
574- model. input_cache. allow_objective_and_solution_input,
575- )
567+ diff = _diff (model, attr)
568+ # Native differentiation interface are allow to not support it
569+ if MOI. supports (diff, NonLinearKKTJacobianFactorization ())
570+ MOI. set (
571+ diff,
572+ NonLinearKKTJacobianFactorization (),
573+ model. input_cache. factorization,
574+ )
575+ end
576+ if MOI. supports (diff, AllowObjectiveAndSolutionInput ())
577+ MOI. set (
578+ diff,
579+ AllowObjectiveAndSolutionInput (),
580+ model. input_cache. allow_objective_and_solution_input,
581+ )
582+ end
576583 for (vi, value) in model. input_cache. dx
577584 MOI. set (diff, ReverseVariablePrimal (), model. index_map[vi], value)
578585 end
@@ -593,7 +600,7 @@ function reverse_differentiate!(model::Optimizer)
593600 end
594601 end
595602 end
596- return reverse_differentiate! (diff)
603+ return MOI . set (diff, attr, value )
597604end
598605
599606# Gradient evaluation functions for objective sensitivity fallbacks
@@ -639,13 +646,12 @@ function _eval_gradient(
639646end
640647
641648function _fallback_set_reverse_objective_sensitivity (model:: Optimizer , val)
642- diff = _diff (model)
643649 obj_type = MOI. get (model, MOI. ObjectiveFunctionType ())
644650 obj_func = MOI. get (model, MOI. ObjectiveFunction {obj_type} ())
645651 grad = _eval_gradient (model, obj_func)
646652 for (xi, df_dxi) in grad
647653 MOI. set (
648- diff,
654+ model . diff,
649655 ReverseVariablePrimal (),
650656 model. index_map[xi],
651657 df_dxi * val,
@@ -666,24 +672,30 @@ function _copy_forward_in_constraint(diff, index_map, con_map, constraints)
666672 return
667673end
668674
669- function forward_differentiate! (model:: Optimizer )
675+ MOI. supports (model:: Optimizer , attr:: ForwardDifferentiate ) = true
676+
677+ function MOI. set (model:: Optimizer , attr:: ForwardDifferentiate , value)
670678 st = MOI. get (model. optimizer, MOI. TerminationStatus ())
671679 if ! in (st, (MOI. LOCALLY_SOLVED, MOI. OPTIMAL))
672680 error (
673681 " Trying to compute the forward differentiation on a model with termination status $(st) " ,
674682 )
675683 end
676- diff = _diff (model)
677- MOI. set (
678- diff,
679- NonLinearKKTJacobianFactorization (),
680- model. input_cache. factorization,
681- )
682- MOI. set (
683- diff,
684- AllowObjectiveAndSolutionInput (),
685- model. input_cache. allow_objective_and_solution_input,
686- )
684+ diff = _diff (model, attr)
685+ if MOI. supports (diff, NonLinearKKTJacobianFactorization ())
686+ MOI. set (
687+ diff,
688+ NonLinearKKTJacobianFactorization (),
689+ model. input_cache. factorization,
690+ )
691+ end
692+ if MOI. supports (diff, AllowObjectiveAndSolutionInput ())
693+ MOI. set (
694+ diff,
695+ AllowObjectiveAndSolutionInput (),
696+ model. input_cache. allow_objective_and_solution_input,
697+ )
698+ end
687699 T = Float64
688700 list = MOI. get (
689701 model,
@@ -700,7 +712,7 @@ function forward_differentiate!(model::Optimizer)
700712 MOI. Parameter (value),
701713 )
702714 end
703- return forward_differentiate! (diff)
715+ return MOI . set (diff, attr, value )
704716 end
705717 # @show "func mode"
706718 if model. input_cache. objective != = nothing
@@ -733,7 +745,7 @@ function forward_differentiate!(model::Optimizer)
733745 model. input_cache. vector_constraints[F, S],
734746 )
735747 end
736- return forward_differentiate! (diff)
748+ return MOI . set (diff, attr, value )
737749end
738750
739751function empty_input_sensitivities! (model:: Optimizer )
@@ -782,8 +794,24 @@ function _instantiate_diff(model::Optimizer, constructor)
782794 return model_bridged
783795end
784796
785- function _diff (model:: Optimizer )
786- if model. diff === nothing
797+ # Find the native differentiation solver in the optimizer chain.
798+ # Cached in `model.diff` to avoid repeated unwrapping.
799+ function _native_diff_solver (model:: Optimizer )
800+ if isnothing (model. diff)
801+ model. diff = model. optimizer
802+ model. index_map = MOI. Utilities. identity_index_map (model. optimizer)
803+ end
804+ return model. diff
805+ end
806+
807+ function _diff (
808+ model:: Optimizer ,
809+ attr:: Union{ForwardDifferentiate,ReverseDifferentiate} ,
810+ )
811+ if MOI. supports (model. optimizer, attr)
812+ model. diff = model. optimizer
813+ model. index_map = MOI. Utilities. identity_index_map (model. optimizer)
814+ elseif isnothing (model. diff)
787815 _check_termination_status (model)
788816 model_constructor = MOI. get (model, ModelConstructor ())
789817 if isnothing (model_constructor)
@@ -1128,15 +1156,7 @@ function MOI.get(model::Optimizer, attr::DifferentiateTimeSec)
11281156 return MOI. get (model. diff, attr)
11291157end
11301158
1131- function MOI. supports (
1132- :: Optimizer ,
1133- :: NonLinearKKTJacobianFactorization ,
1134- :: Function ,
1135- )
1136- return true
1137- end
1138-
1139- function MOI. supports (:: Optimizer , :: AllowObjectiveAndSolutionInput , :: Bool )
1159+ function MOI. supports (:: Optimizer , :: NonLinearKKTJacobianFactorization )
11401160 return true
11411161end
11421162
0 commit comments