@@ -565,14 +565,19 @@ impl MinPointState {
565565 true
566566 } else if ( new_depth, new_size) > ( self . best_depth , self . best_size ) {
567567 self . count += 1 ;
568- true
568+ // Limit to 5 iterations without improvement
569+ // and back track to previous best
570+ self . count < 5
569571 } else if ( new_depth, new_size) < ( self . best_depth , self . best_size ) {
570572 self . count = 1 ;
571573 self . best_depth = new_depth;
572574 self . best_size = new_size;
575+ self . best_dag = dag. clone ( ) ;
573576 true
574577 } else {
575- ( new_depth, new_size) != ( self . best_depth , self . best_size )
578+ // When depth and size are unchanged, we reach a fixed point
579+ // and can stop the optimization loop
580+ false
576581 }
577582 }
578583}
@@ -847,4 +852,115 @@ mod tests {
847852 assert ! ( result. 1 . output_permutation( ) . is_some( ) ) ;
848853 }
849854 }
855+
856+ #[ test]
857+ fn test_update_best_dag ( ) {
858+ let circuit1 = CircuitData :: from_packed_operations (
859+ 1 ,
860+ 1 ,
861+ vec ! [ Ok ( (
862+ StandardGate :: H . into( ) ,
863+ smallvec![ ] ,
864+ vec![ Qubit ( 0 ) ] ,
865+ vec![ ] ,
866+ ) ) ] ,
867+ Param :: Float ( 0. ) ,
868+ )
869+ . unwrap ( ) ;
870+
871+ let dag1 = DAGCircuit :: from_circuit_data ( & circuit1, false , None , None , None , None ) . unwrap ( ) ;
872+ let circuit2 = CircuitData :: from_packed_operations ( 1 , 1 , vec ! [ ] , Param :: Float ( 0. ) ) . unwrap ( ) ;
873+ let dag2 = DAGCircuit :: from_circuit_data ( & circuit2, false , None , None , None , None ) . unwrap ( ) ;
874+
875+ let mut state = MinPointState :: new ( & dag1) ;
876+ assert ! ( state. update_with( & dag1) ) ;
877+ assert_eq ! ( state. count, 0 ) ;
878+ assert ! ( state. update_with( & dag2) ) ;
879+ assert_eq ! ( state. count, 1 ) ;
880+ assert_eq ! (
881+ state. best_dag. depth( false ) . unwrap( ) ,
882+ dag2. depth( false ) . unwrap( )
883+ ) ;
884+ assert_eq ! (
885+ state. best_dag. size( false ) . unwrap( ) ,
886+ dag2. size( false ) . unwrap( )
887+ ) ;
888+ }
889+
890+ #[ test]
891+ fn test_backtrack_limit_stops_loop ( ) {
892+ let circuit1 = CircuitData :: from_packed_operations ( 1 , 1 , vec ! [ ] , Param :: Float ( 0. ) ) . unwrap ( ) ;
893+ let dag1 = DAGCircuit :: from_circuit_data ( & circuit1, false , None , None , None , None ) . unwrap ( ) ;
894+ let circuit2 = CircuitData :: from_packed_operations (
895+ 1 ,
896+ 1 ,
897+ vec ! [ Ok ( (
898+ StandardGate :: H . into( ) ,
899+ smallvec![ ] ,
900+ vec![ Qubit ( 0 ) ] ,
901+ vec![ ] ,
902+ ) ) ] ,
903+ Param :: Float ( 0. ) ,
904+ )
905+ . unwrap ( ) ;
906+
907+ let dag2 = DAGCircuit :: from_circuit_data ( & circuit2, false , None , None , None , None ) . unwrap ( ) ;
908+ let mut state = MinPointState :: new ( & dag1) ;
909+
910+ state. update_with ( & dag1) ;
911+ for i in 0 ..5 {
912+ let continue_loop = state. update_with ( & dag2) ;
913+ if i < 4 {
914+ assert ! ( continue_loop) ;
915+ } else {
916+ assert ! ( !continue_loop) ;
917+ }
918+ }
919+ }
920+
921+ #[ test]
922+ fn test_backtrack_resets_on_improvement ( ) {
923+ let circuit1 = CircuitData :: from_packed_operations (
924+ 1 ,
925+ 1 ,
926+ vec ! [
927+ Ok ( ( StandardGate :: H . into( ) , smallvec![ ] , vec![ Qubit ( 0 ) ] , vec![ ] ) ) ,
928+ Ok ( ( StandardGate :: H . into( ) , smallvec![ ] , vec![ Qubit ( 0 ) ] , vec![ ] ) ) ,
929+ ] ,
930+ Param :: Float ( 0. ) ,
931+ )
932+ . unwrap ( ) ;
933+ let dag_worst =
934+ DAGCircuit :: from_circuit_data ( & circuit1, false , None , None , None , None ) . unwrap ( ) ;
935+ let circuit2 = CircuitData :: from_packed_operations (
936+ 1 ,
937+ 1 ,
938+ vec ! [ Ok ( (
939+ StandardGate :: H . into( ) ,
940+ smallvec![ ] ,
941+ vec![ Qubit ( 0 ) ] ,
942+ vec![ ] ,
943+ ) ) ] ,
944+ Param :: Float ( 0. ) ,
945+ )
946+ . unwrap ( ) ;
947+ let dag_better =
948+ DAGCircuit :: from_circuit_data ( & circuit2, false , None , None , None , None ) . unwrap ( ) ;
949+ let circuit3 = CircuitData :: from_packed_operations ( 1 , 1 , vec ! [ ] , Param :: Float ( 0. ) ) . unwrap ( ) ;
950+ let dag_best =
951+ DAGCircuit :: from_circuit_data ( & circuit3, false , None , None , None , None ) . unwrap ( ) ;
952+ let mut state = MinPointState :: new ( & dag_worst) ;
953+
954+ state. update_with ( & dag_worst) ;
955+ state. update_with ( & dag_better) ;
956+ for _i in 0 ..3 {
957+ state. update_with ( & dag_worst) ;
958+ }
959+ state. update_with ( & dag_best) ;
960+ assert_eq ! ( state. count, 1 ) ;
961+ // After updating to the dag_best the state tracked dag should
962+ // be empty
963+ assert_eq ! ( state. best_dag. depth( false ) . unwrap( ) , 0 ) ;
964+ assert_eq ! ( state. best_dag. size( false ) . unwrap( ) , 0 ) ;
965+ }
850966}
0 commit comments