From 96dd0746480aef3e83708b1a1e1ccaf06105a4a2 Mon Sep 17 00:00:00 2001
From: Luc Maisonobe <luc@orekit.org>
Date: Tue, 30 Nov 2021 11:38:14 +0100
Subject: [PATCH] Handle yield feature for additional equations.

---
 .../AbstractIntegratedPropagator.java         | 53 ++++++++----
 .../FieldAbstractIntegratedPropagator.java    | 53 ++++++++----
 .../integration/AdditionalEquationsTest.java  | 66 ++++++++++++++-
 .../FieldAdditionalEquationsTest.java         | 80 ++++++++++++++++++-
 4 files changed, 219 insertions(+), 33 deletions(-)

diff --git a/src/main/java/org/orekit/propagation/integration/AbstractIntegratedPropagator.java b/src/main/java/org/orekit/propagation/integration/AbstractIntegratedPropagator.java
index c23a0ab4b7..434175c8f2 100644
--- a/src/main/java/org/orekit/propagation/integration/AbstractIntegratedPropagator.java
+++ b/src/main/java/org/orekit/propagation/integration/AbstractIntegratedPropagator.java
@@ -21,8 +21,10 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Queue;
 
 import org.hipparchus.exception.MathRuntimeException;
 import org.hipparchus.ode.DenseOutputModel;
@@ -464,14 +466,15 @@ public abstract class AbstractIntegratedPropagator extends AbstractPropagator {
                                                         propagationType);
 
             if (!additionalEquations.isEmpty()) {
-                final double[] secondary = mathFinalState.getSecondaryState(1);
-                int offset = 0;
+                final double[] secondary            = mathFinalState.getSecondaryState(1);
+                final double[] secondaryDerivatives = mathFinalState.getSecondaryDerivative(1);
                 for (AdditionalEquations equations : additionalEquations) {
-                    finalState = finalState.addAdditionalState(equations.getName(),
-                                                               Arrays.copyOfRange(secondary,
-                                                                                  offset,
-                                                                                  offset + equations.getDimension()));
-                    offset += equations.getDimension();
+                    final String   name        = equations.getName();
+                    final int      offset      = secondaryOffsets.get(name);
+                    final int      dimension   = equations.getDimension();
+                    finalState = finalState.
+                                 addAdditionalState(name, Arrays.copyOfRange(secondary, offset, offset + dimension)).
+                                 addAdditionalStateDerivative(name, Arrays.copyOfRange(secondaryDerivatives, offset, offset + dimension));
                 }
             }
             finalState = updateAdditionalStates(finalState);
@@ -777,16 +780,36 @@ public abstract class AbstractIntegratedPropagator extends AbstractPropagator {
             // update space dynamics view
             // the integrable generators generate method will be called here,
             // according to the generators yield order
-            final SpacecraftState currentState = convert(t, primary, primaryDot, secondary);
+            SpacecraftState updated = convert(t, primary, primaryDot, secondary);
 
-            // gather the derivatives from all integrable generators
+            // set up queue for equations
+            final Queue<AdditionalEquations> pending = new LinkedList<>(additionalEquations);
+
+            // gather the derivatives from all additional equations, taking care of dependencies
             final double[] secondaryDot = new double[combinedDimension];
-            // FIXME: use yield to compute derivatives in dependencies order
-            for (final AdditionalEquations equations : additionalEquations) {
-                final String name      = equations.getName();
-                final int    offset    = secondaryOffsets.get(name);
-                final int    dimension = equations.getDimension();
-                System.arraycopy(equations.derivatives(currentState), 0, secondaryDot, offset, dimension);
+            int yieldCount = 0;
+            while (!pending.isEmpty()) {
+                final AdditionalEquations equations = pending.remove();
+                if (equations.yield(updated)) {
+                    // these equations have to wait for another set,
+                    // we put them again in the pending queue
+                    pending.add(equations);
+                    if (++yieldCount >= pending.size()) {
+                        // all pending equations yielded!, they probably need data not yet initialized
+                        // we let the propagation proceed, if these data are really needed right now
+                        // an appropriate exception will be triggered when caller tries to access them
+                        break;
+                    }
+                } else {
+                    // we can use these equations right now
+                    final String   name        = equations.getName();
+                    final int      offset      = secondaryOffsets.get(name);
+                    final int      dimension   = equations.getDimension();
+                    final double[] derivatives = equations.derivatives(updated);
+                    System.arraycopy(derivatives, 0, secondaryDot, offset, dimension);
+                    updated = updated.addAdditionalStateDerivative(name, derivatives);
+                    yieldCount = 0;
+                }
             }
 
             return secondaryDot;
diff --git a/src/main/java/org/orekit/propagation/integration/FieldAbstractIntegratedPropagator.java b/src/main/java/org/orekit/propagation/integration/FieldAbstractIntegratedPropagator.java
index 5388401419..6abaf241cb 100644
--- a/src/main/java/org/orekit/propagation/integration/FieldAbstractIntegratedPropagator.java
+++ b/src/main/java/org/orekit/propagation/integration/FieldAbstractIntegratedPropagator.java
@@ -21,8 +21,10 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Queue;
 
 import org.hipparchus.CalculusFieldElement;
 import org.hipparchus.Field;
@@ -456,14 +458,15 @@ public abstract class FieldAbstractIntegratedPropagator<T extends CalculusFieldE
                                                         mathFinalState.getPrimaryDerivative(),
                                                         propagationType);
             if (!additionalEquations.isEmpty()) {
-                final T[] secondary = mathFinalState.getSecondaryState(1);
-                int offset = 0;
+                final T[] secondary            = mathFinalState.getSecondaryState(1);
+                final T[] secondaryDerivatives = mathFinalState.getSecondaryDerivative(1);
                 for (FieldAdditionalEquations<T> equations : additionalEquations) {
-                    finalState = finalState.addAdditionalState(equations.getName(),
-                                                               Arrays.copyOfRange(secondary,
-                                                                                  offset,
-                                                                                  offset + equations.getDimension()));
-                    offset += equations.getDimension();
+                    final String   name        = equations.getName();
+                    final int      offset      = secondaryOffsets.get(name);
+                    final int      dimension   = equations.getDimension();
+                    finalState = finalState.
+                                 addAdditionalState(name, Arrays.copyOfRange(secondary, offset, offset + dimension)).
+                                 addAdditionalStateDerivative(name, Arrays.copyOfRange(secondaryDerivatives, offset, offset + dimension));
                 }
             }
             finalState = updateAdditionalStates(finalState);
@@ -773,16 +776,36 @@ public abstract class FieldAbstractIntegratedPropagator<T extends CalculusFieldE
             // update space dynamics view
             // the integrable generators generate method will be called here,
             // according to the generators yield order
-            final FieldSpacecraftState<T> currentState = convert(t, primary, primaryDot, secondary);
+            FieldSpacecraftState<T> updated = convert(t, primary, primaryDot, secondary);
 
-            // gather the derivatives from all integrable generators
+            // set up queue for equations
+            final Queue<FieldAdditionalEquations<T>> pending = new LinkedList<>(additionalEquations);
+
+            // gather the derivatives from all additional equations, taking care of dependencies
             final T[] secondaryDot = MathArrays.buildArray(t.getField(), combinedDimension);
-            // FIXME: use yield to compute derivatives in dependencies order
-            for (final FieldAdditionalEquations<T> equations : additionalEquations) {
-                final String name      = equations.getName();
-                final int    offset    = secondaryOffsets.get(name);
-                final int    dimension = equations.getDimension();
-                System.arraycopy(equations.derivatives(currentState), 0, secondaryDot, offset, dimension);
+            int yieldCount = 0;
+            while (!pending.isEmpty()) {
+                final FieldAdditionalEquations<T> equations = pending.remove();
+                if (equations.yield(updated)) {
+                    // these equations have to wait for another set,
+                    // we put them again in the pending queue
+                    pending.add(equations);
+                    if (++yieldCount >= pending.size()) {
+                        // all pending equations yielded!, they probably need data not yet initialized
+                        // we let the propagation proceed, if these data are really needed right now
+                        // an appropriate exception will be triggered when caller tries to access them
+                        break;
+                    }
+                } else {
+                    // we can use these equations right now
+                    final String name        = equations.getName();
+                    final int    offset      = secondaryOffsets.get(name);
+                    final int    dimension   = equations.getDimension();
+                    final T[]    derivatives = equations.derivatives(updated);
+                    System.arraycopy(derivatives, 0, secondaryDot, offset, dimension);
+                    updated = updated.addAdditionalStateDerivative(name, derivatives);
+                    yieldCount = 0;
+                }
             }
 
             return secondaryDot;
diff --git a/src/test/java/org/orekit/propagation/integration/AdditionalEquationsTest.java b/src/test/java/org/orekit/propagation/integration/AdditionalEquationsTest.java
index b5a9d79380..165e413bda 100644
--- a/src/test/java/org/orekit/propagation/integration/AdditionalEquationsTest.java
+++ b/src/test/java/org/orekit/propagation/integration/AdditionalEquationsTest.java
@@ -132,6 +132,36 @@ public class AdditionalEquationsTest {
 
     }
 
+    @Test
+    public void testYield() {
+
+        // setup
+        final double init1 = 1.0;
+        final double init2 = 2.0;
+        final double rate  = 0.5;
+        final double dt    = 600;
+        Yield yield1 = new Yield(null, "yield-1", rate);
+        Yield yield2 = new Yield(yield1.getName(), "yield-2", Double.NaN);
+
+        // action
+        AdaptiveStepsizeIntegrator integrator = new DormandPrince853Integrator(0.001, 200, tolerance[0], tolerance[1]);
+        integrator.setInitialStepSize(60);
+        NumericalPropagator propagatorNumerical = new NumericalPropagator(integrator);
+        propagatorNumerical.setInitialState(initialState.
+                                            addAdditionalState(yield1.getName(), init1).
+                                            addAdditionalState(yield2.getName(), init2));
+        propagatorNumerical.addAdditionalEquations(yield2); // we intentionally register yield2 before yield 1 to check reordering
+        propagatorNumerical.addAdditionalEquations(yield1);
+        SpacecraftState finalState = propagatorNumerical.propagate(initDate.shiftedBy(dt));
+
+        // verify
+        Assert.assertEquals(init1 + dt * rate, finalState.getAdditionalState(yield1.getName())[0],           1.0e-10);
+        Assert.assertEquals(init2 + dt * rate, finalState.getAdditionalState(yield2.getName())[0],           1.0e-10);
+        Assert.assertEquals(rate,              finalState.getAdditionalStateDerivative(yield1.getName())[0], 1.0e-10);
+        Assert.assertEquals(rate,              finalState.getAdditionalStateDerivative(yield2.getName())[0], 1.0e-10);
+
+    }
+
     @Before
     public void setUp() {
         Utils.setDataRoot("regular-data:potential/shm-format");
@@ -153,7 +183,7 @@ public class AdditionalEquationsTest {
         tolerance    = null;
     }
 
-    public static class Linear implements AdditionalEquations {
+    private static class Linear implements AdditionalEquations {
 
         private String  name;
         private double  expectedAtInit;
@@ -194,4 +224,38 @@ public class AdditionalEquationsTest {
 
     }
 
+    private static class Yield implements AdditionalEquations {
+
+        private String dependency;
+        private String name;
+        private double rate;
+
+        public Yield(final String dependency, final String name, final double rate) {
+            this.dependency = dependency;
+            this.name       = name;
+            this.rate       = rate;
+        }
+
+        @Override
+        public double[] derivatives(final SpacecraftState s) {
+            return dependency == null ? new double[] { rate } : s.getAdditionalStateDerivative(dependency);
+        }
+
+        @Override
+        public boolean yield(final SpacecraftState state) {
+            return dependency != null && !state.hasAdditionalStateDerivative(dependency);
+        }
+
+        @Override
+        public int getDimension() {
+            return 1;
+        }
+
+        @Override
+        public String getName() {
+            return name;
+        }
+
+    }
+
 }
diff --git a/src/test/java/org/orekit/propagation/integration/FieldAdditionalEquationsTest.java b/src/test/java/org/orekit/propagation/integration/FieldAdditionalEquationsTest.java
index 17f85f1681..8bcbf325c8 100644
--- a/src/test/java/org/orekit/propagation/integration/FieldAdditionalEquationsTest.java
+++ b/src/test/java/org/orekit/propagation/integration/FieldAdditionalEquationsTest.java
@@ -65,10 +65,15 @@ public class FieldAdditionalEquationsTest {
     }
 
     @Test
-    public void testResetStateT() {
+    public void testResetState() {
         doTestResetState(Decimal64Field.getInstance());
     }
 
+    @Test
+    public void testYield() {
+        doTestYield(Decimal64Field.getInstance());
+    }
+
     private <T extends CalculusFieldElement<T>> void doTestInitNumerical(Field<T> field) {
         // setup
         final double reference = 1.25;
@@ -149,6 +154,36 @@ public class FieldAdditionalEquationsTest {
 
     }
 
+    private <T extends CalculusFieldElement<T>> void doTestYield(Field<T> field) {
+
+        // setup
+        final double init1 = 1.0;
+        final double init2 = 2.0;
+        final double rate  = 0.5;
+        final double dt    = 600;
+        Yield<T> yield1 = new Yield<>(null, "yield-1", rate);
+        Yield<T> yield2 = new Yield<>(yield1.getName(), "yield-2", Double.NaN);
+
+        // action
+        AdaptiveStepsizeFieldIntegrator<T> integrator = new DormandPrince853FieldIntegrator<>(field, 0.001, 200,
+                        tolerance[0], tolerance[1]);
+        integrator.setInitialStepSize(60);
+        FieldNumericalPropagator<T> propagatorNumerical = new FieldNumericalPropagator<>(field, integrator);
+        propagatorNumerical.setInitialState(new FieldSpacecraftState<>(field, initialState).
+                                            addAdditionalState(yield1.getName(), field.getZero().newInstance(init1)).
+                                            addAdditionalState(yield2.getName(), field.getZero().newInstance(init2)));
+        propagatorNumerical.addAdditionalEquations(yield2); // we intentionally register yield2 before yield 1 to check reordering
+        propagatorNumerical.addAdditionalEquations(yield1);
+        FieldSpacecraftState<T> finalState = propagatorNumerical.propagate(new FieldAbsoluteDate<>(field, initDate).shiftedBy(dt));
+
+        // verify
+        Assert.assertEquals(init1 + dt * rate, finalState.getAdditionalState(yield1.getName())[0].getReal(),           1.0e-10);
+        Assert.assertEquals(init2 + dt * rate, finalState.getAdditionalState(yield2.getName())[0].getReal(),           1.0e-10);
+        Assert.assertEquals(rate,              finalState.getAdditionalStateDerivative(yield1.getName())[0].getReal(), 1.0e-10);
+        Assert.assertEquals(rate,              finalState.getAdditionalStateDerivative(yield2.getName())[0].getReal(), 1.0e-10);
+
+    }
+
     @Before
     public void setUp() {
         Utils.setDataRoot("regular-data:potential/shm-format");
@@ -170,7 +205,7 @@ public class FieldAdditionalEquationsTest {
         tolerance    = null;
     }
 
-    public static class Linear<T extends CalculusFieldElement<T>> implements FieldAdditionalEquations<T> {
+    private static class Linear<T extends CalculusFieldElement<T>> implements FieldAdditionalEquations<T> {
 
         private String  name;
         private double  expectedAtInit;
@@ -213,4 +248,45 @@ public class FieldAdditionalEquationsTest {
 
     }
 
+    private static class Yield<T extends CalculusFieldElement<T>> implements FieldAdditionalEquations<T> {
+
+        private String dependency;
+        private String name;
+        private double rate;
+
+        public Yield(final String dependency, final String name, final double rate) {
+            this.dependency = dependency;
+            this.name       = name;
+            this.rate       = rate;
+        }
+
+        @Override
+        public T[] derivatives(final FieldSpacecraftState<T> s) {
+            final T[] pDot;
+            if (dependency == null) {
+                pDot = MathArrays.buildArray(s.getDate().getField(), 1);
+                pDot[0] = s.getDate().getField().getZero().newInstance(rate);
+            } else {
+                pDot = s.getAdditionalStateDerivative(dependency);
+            }
+            return pDot;
+        }
+
+        @Override
+        public boolean yield(final FieldSpacecraftState<T> state) {
+            return dependency != null && !state.hasAdditionalStateDerivative(dependency);
+        }
+
+        @Override
+        public int getDimension() {
+            return 1;
+        }
+
+        @Override
+        public String getName() {
+            return name;
+        }
+
+    }
+
 }
-- 
GitLab