From ea1dd7f1f2fd30309ae420bd959804265daf4227 Mon Sep 17 00:00:00 2001
From: Luc Maisonobe <luc@orekit.org>
Date: Tue, 30 Nov 2021 10:20:00 +0100
Subject: [PATCH] Fixed dimension error in resetState with additional
 equations.

---
 .../AbstractIntegratedPropagator.java         | 13 ++-
 .../integration/AdditionalEquationsTest.java  | 90 ++++++++++++++-----
 2 files changed, 74 insertions(+), 29 deletions(-)

diff --git a/src/main/java/org/orekit/propagation/integration/AbstractIntegratedPropagator.java b/src/main/java/org/orekit/propagation/integration/AbstractIntegratedPropagator.java
index 22d3dadff4..c23a0ab4b7 100644
--- a/src/main/java/org/orekit/propagation/integration/AbstractIntegratedPropagator.java
+++ b/src/main/java/org/orekit/propagation/integration/AbstractIntegratedPropagator.java
@@ -531,7 +531,6 @@ public abstract class AbstractIntegratedPropagator extends AbstractPropagator {
         }
 
         final double[][] secondary = new double[1][secondaryOffsets.get(SECONDARY_DIMENSION)];
-        // FIXME: use yield to extract states in dependencies order
         for (final AdditionalEquations equations : additionalEquations) {
             final String   name       = equations.getName();
             final int      offset     = secondaryOffsets.get(name);
@@ -555,7 +554,6 @@ public abstract class AbstractIntegratedPropagator extends AbstractPropagator {
         }
 
         final double[][] secondaryDerivative = new double[1][secondaryOffsets.get(SECONDARY_DIMENSION)];
-        // FIXME: use yield to extract derivatives in dependencies order
         for (final AdditionalEquations equations : additionalEquations) {
             final String   name       = equations.getName();
             final int      offset     = secondaryOffsets.get(name);
@@ -635,7 +633,6 @@ public abstract class AbstractIntegratedPropagator extends AbstractPropagator {
         if (os.getNumberOfSecondaryStates() > 0) {
             final double[] secondary           = os.getSecondaryState(1);
             final double[] secondaryDerivative = os.getSecondaryDerivative(1);
-            // FIXME: use yield to add states in dependencies order
             for (final AdditionalEquations equations : additionalEquations) {
                 final String name      = equations.getName();
                 final int    offset    = secondaryOffsets.get(name);
@@ -808,7 +805,6 @@ public abstract class AbstractIntegratedPropagator extends AbstractPropagator {
 
             SpacecraftState initialState = stateMapper.mapArrayToState(t, primary, primaryDot, PropagationType.MEAN);
 
-            // FIXME: use yield to add states in dependencies order
             for (final AdditionalEquations equations : additionalEquations) {
                 final String name      = equations.getName();
                 final int    offset    = secondaryOffsets.get(name);
@@ -879,9 +875,12 @@ public abstract class AbstractIntegratedPropagator extends AbstractPropagator {
             stateMapper.mapStateToArray(newState, primary, null);
 
             // secondary part
-            final double[][] secondary    = new double[additionalEquations.size()][];
-            for (int i = 0; i < additionalEquations.size(); ++i) {
-                secondary[i] = newState.getAdditionalState(additionalEquations.get(i).getName());
+            final double[][] secondary = new double[1][secondaryOffsets.get(SECONDARY_DIMENSION)];
+            for (final AdditionalEquations equations : additionalEquations) {
+                final String name      = equations.getName();
+                final int    offset    = secondaryOffsets.get(name);
+                final int    dimension = equations.getDimension();
+                System.arraycopy(newState.getAdditionalState(name), 0, secondary[0], offset, dimension);
             }
 
             return new ODEState(newState.getDate().durationFrom(getStartDate()),
diff --git a/src/test/java/org/orekit/propagation/integration/AdditionalEquationsTest.java b/src/test/java/org/orekit/propagation/integration/AdditionalEquationsTest.java
index 20db4bfcbd..b5a9d79380 100644
--- a/src/test/java/org/orekit/propagation/integration/AdditionalEquationsTest.java
+++ b/src/test/java/org/orekit/propagation/integration/AdditionalEquationsTest.java
@@ -26,11 +26,13 @@ import org.junit.Test;
 import org.orekit.Utils;
 import org.orekit.forces.gravity.potential.GravityFieldFactory;
 import org.orekit.forces.gravity.potential.SHMFormatReader;
+import org.orekit.forces.maneuvers.ImpulseManeuver;
 import org.orekit.frames.FramesFactory;
 import org.orekit.orbits.EquinoctialOrbit;
 import org.orekit.orbits.Orbit;
 import org.orekit.orbits.OrbitType;
 import org.orekit.propagation.SpacecraftState;
+import org.orekit.propagation.events.DateDetector;
 import org.orekit.propagation.numerical.NumericalPropagator;
 import org.orekit.propagation.semianalytical.dsst.DSSTPropagator;
 import org.orekit.time.AbsoluteDate;
@@ -50,19 +52,22 @@ public class AdditionalEquationsTest {
 
         // setup
         final double reference = 1.25;
-        InitCheckerEquations checker = new InitCheckerEquations(reference);
-        Assert.assertFalse(checker.wasCalled());
+        final double rate      = 1.5;
+        final double dt        = 600.0;
+        Linear linear = new Linear("linear", reference, rate);
+        Assert.assertFalse(linear.wasCalled());
 
         // action
         AdaptiveStepsizeIntegrator integrator = new DormandPrince853Integrator(0.001, 200, tolerance[0], tolerance[1]);
         integrator.setInitialStepSize(60);
         NumericalPropagator propagatorNumerical = new NumericalPropagator(integrator);
-        propagatorNumerical.setInitialState(initialState.addAdditionalState(checker.getName(), reference));
-        propagatorNumerical.addAdditionalEquations(checker);
-        propagatorNumerical.propagate(initDate.shiftedBy(600));
+        propagatorNumerical.setInitialState(initialState.addAdditionalState(linear.getName(), reference));
+        propagatorNumerical.addAdditionalEquations(linear);
+        SpacecraftState finalState = propagatorNumerical.propagate(initDate.shiftedBy(dt));
 
         // verify
-        Assert.assertTrue(checker.wasCalled());
+        Assert.assertTrue(linear.wasCalled());
+        Assert.assertEquals(reference + dt * rate, finalState.getAdditionalState(linear.getName())[0], 1.0e-10);
 
     }
 
@@ -73,19 +78,57 @@ public class AdditionalEquationsTest {
 
         // setup
         final double reference = 3.5;
-        InitCheckerEquations checker = new InitCheckerEquations(reference);
-        Assert.assertFalse(checker.wasCalled());
+        final double rate      = 1.5;
+        final double dt        = 600.0;
+        Linear linear = new Linear("linear", reference, rate);
+        Assert.assertFalse(linear.wasCalled());
 
         // action
         AdaptiveStepsizeIntegrator integrator = new DormandPrince853Integrator(0.001, 200, tolerance[0], tolerance[1]);
         integrator.setInitialStepSize(60);
         DSSTPropagator propagatorDSST = new DSSTPropagator(integrator);
-        propagatorDSST.setInitialState(initialState.addAdditionalState(checker.getName(), reference));
-        propagatorDSST.addAdditionalEquations(checker);
-        propagatorDSST.propagate(initDate.shiftedBy(600));
+        propagatorDSST.setInitialState(initialState.addAdditionalState(linear.getName(), reference));
+        propagatorDSST.addAdditionalEquations(linear);
+        SpacecraftState finalState = propagatorDSST.propagate(initDate.shiftedBy(dt));
 
         // verify
-        Assert.assertTrue(checker.wasCalled());
+        Assert.assertTrue(linear.wasCalled());
+        Assert.assertEquals(reference + dt * rate, finalState.getAdditionalState(linear.getName())[0], 1.0e-10);
+
+    }
+
+    @Test
+    public void testResetState() {
+
+        // setup
+        final double reference1 = 3.5;
+        final double rate1      = 1.5;
+        Linear linear1 = new Linear("linear-1", reference1, rate1);
+        Assert.assertFalse(linear1.wasCalled());
+        final double reference2 = 4.5;
+        final double rate2      = 1.25;
+        Linear linear2 = new Linear("linear-2", reference2, rate2);
+        Assert.assertFalse(linear2.wasCalled());
+        final double dt = 600;
+
+        // action
+        AdaptiveStepsizeIntegrator integrator = new DormandPrince853Integrator(0.001, 200, tolerance[0], tolerance[1]);
+        integrator.setInitialStepSize(60);
+        NumericalPropagator propagatorNumerical = new NumericalPropagator(integrator);
+        propagatorNumerical.setInitialState(initialState.
+                                            addAdditionalState(linear1.getName(), reference1).
+                                            addAdditionalState(linear2.getName(), reference2));
+        propagatorNumerical.addAdditionalEquations(linear1);
+        propagatorNumerical.addAdditionalEquations(linear2);
+        propagatorNumerical.addEventDetector(new ImpulseManeuver<>(new DateDetector(initDate.shiftedBy(dt / 2.0)),
+                                                                   new Vector3D(0.1, 0.2, 0.3), 350.0));
+        SpacecraftState finalState = propagatorNumerical.propagate(initDate.shiftedBy(dt));
+
+        // verify
+        Assert.assertTrue(linear1.wasCalled());
+        Assert.assertTrue(linear2.wasCalled());
+        Assert.assertEquals(reference1 + dt * rate1, finalState.getAdditionalState(linear1.getName())[0], 1.0e-10);
+        Assert.assertEquals(reference2 + dt * rate2, finalState.getAdditionalState(linear2.getName())[0], 1.0e-10);
 
     }
 
@@ -110,26 +153,29 @@ public class AdditionalEquationsTest {
         tolerance    = null;
     }
 
-    public static class InitCheckerEquations implements AdditionalEquations {
+    public static class Linear implements AdditionalEquations {
 
-        private double expected;
+        private String  name;
+        private double  expectedAtInit;
+        private double  rate;
         private boolean called;
 
-        public InitCheckerEquations(final double expected) {
-            this.expected = expected;
-            this.called   = false;
+        public Linear(final String name, final double expectedAtInit, final double rate) {
+            this.name           = name;
+            this.expectedAtInit = expectedAtInit;
+            this.rate           = rate;
+            this.called         = false;
         }
 
         @Override
         public void init(SpacecraftState initiaState, AbsoluteDate target) {
-            Assert.assertEquals(expected, initiaState.getAdditionalState(getName())[0], 1.0e-15);
+            Assert.assertEquals(expectedAtInit, initiaState.getAdditionalState(getName())[0], 1.0e-15);
             called = true;
         }
 
         @Override
-        public double[] computeDerivatives(SpacecraftState s, double[] pDot) {
-            pDot[0] = 1.5;
-            return null;
+        public double[] derivatives(SpacecraftState s) {
+            return new double[] { rate };
         }
 
         @Override
@@ -139,7 +185,7 @@ public class AdditionalEquationsTest {
 
         @Override
         public String getName() {
-            return "linear";
+            return name;
         }
 
         public boolean wasCalled() {
-- 
GitLab