diff --git a/src/main/java/org/orekit/rugged/atmosphericrefraction/AtmosphericRefraction.java b/src/main/java/org/orekit/rugged/atmosphericrefraction/AtmosphericRefraction.java
index ac5a16114bcef6089142db1e7e0123ea0043c4eb..21a545bf19dc08e8ff417b48ca7d81147f096bb4 100644
--- a/src/main/java/org/orekit/rugged/atmosphericrefraction/AtmosphericRefraction.java
+++ b/src/main/java/org/orekit/rugged/atmosphericrefraction/AtmosphericRefraction.java
@@ -19,7 +19,7 @@ package org.orekit.rugged.atmosphericrefraction;
 
 import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
 import org.orekit.rugged.errors.RuggedException;
-import org.orekit.rugged.raster.Tile;
+import org.orekit.rugged.intersection.IntersectionAlgorithm;
 import org.orekit.rugged.utils.NormalizedGeodeticPoint;
 
 /**
@@ -28,6 +28,8 @@ import org.orekit.rugged.utils.NormalizedGeodeticPoint;
  */
 public interface AtmosphericRefraction {
 
-    NormalizedGeodeticPoint getPointOnGround(Vector3D pos, Vector3D los, Vector3D zenith, double altitude, Tile tile) throws RuggedException;
+    NormalizedGeodeticPoint applyCorrection(Vector3D satPos, Vector3D satLos, NormalizedGeodeticPoint rawIntersection,
+                                             IntersectionAlgorithm algorithm)
+            throws RuggedException;
 
 }
diff --git a/src/main/java/org/orekit/rugged/atmosphericrefraction/MultiLayerModel.java b/src/main/java/org/orekit/rugged/atmosphericrefraction/MultiLayerModel.java
index e7a440430fd1381309401ab5f7106d1bc1263ebd..f6b79e7aa00f38e5c51f55c553016edb0353c147 100644
--- a/src/main/java/org/orekit/rugged/atmosphericrefraction/MultiLayerModel.java
+++ b/src/main/java/org/orekit/rugged/atmosphericrefraction/MultiLayerModel.java
@@ -33,8 +33,10 @@ import org.orekit.rugged.utils.NormalizedGeodeticPoint;
 import org.orekit.utils.Constants;
 import org.orekit.utils.IERSConventions;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
 
@@ -44,86 +46,113 @@ import java.util.TreeMap;
  */
 public class MultiLayerModel implements AtmosphericRefraction {
 
-    // maps altitude (lower bound) to refraction index
-    private static Map<Double, Double> meanAtmosphericRefractions;
-
-    private static Map<Double, ExtendedEllipsoid> atmosphericEllipsoids;
-
-    public MultiLayerModel() throws OrekitException {
-        meanAtmosphericRefractions = new TreeMap(Collections.reverseOrder());
-        meanAtmosphericRefractions.put(-1000.00000000000, 1.00030600000);
-        meanAtmosphericRefractions.put(0.00000000000, 1.00027800000);
-        meanAtmosphericRefractions.put(1000.00000000000, 1.00025200000);
-        meanAtmosphericRefractions.put(3000.00000000000, 1.00020600000);
-        meanAtmosphericRefractions.put(5000.00000000000, 1.00016700000);
-        meanAtmosphericRefractions.put(7000.00000000000, 1.00013400000);
-        meanAtmosphericRefractions.put(9000.00000000000, 1.00010600000);
-        meanAtmosphericRefractions.put(11000.00000000000, 1.00008300000);
-        meanAtmosphericRefractions.put(14000.00000000000, 1.00005200000);
-        meanAtmosphericRefractions.put(18000.00000000000, 1.00002800000);
-        meanAtmosphericRefractions.put(23000.00000000000, 1.00001200000);
-        meanAtmosphericRefractions.put(30000.00000000000, 1.00000400000);
-        meanAtmosphericRefractions.put(40000.00000000000, 1.00000100000);
-        meanAtmosphericRefractions.put(50000.00000000000, 1.00000000000);
-        meanAtmosphericRefractions.put(100000.00000000000, 1.00000000000);
-
-        atmosphericEllipsoids = new HashMap<Double, ExtendedEllipsoid>();
-        for (Double altitude : meanAtmosphericRefractions.keySet()) {
-            OneAxisEllipsoid ellipsoid = new OneAxisEllipsoid(Constants.WGS84_EARTH_EQUATORIAL_RADIUS + altitude,
-                    Constants.WGS84_EARTH_FLATTENING, FramesFactory.getITRF(IERSConventions.IERS_2010, true));
-            ellipsoid = new ExtendedEllipsoid(ellipsoid.getEquatorialRadius(), ellipsoid.getFlattening(),
-                    ellipsoid.getBodyFrame());
-            atmosphericEllipsoids.put(altitude, (ExtendedEllipsoid) ellipsoid);
-        }
+    /** Observed body ellipsoid. */
+    private final ExtendedEllipsoid ellipsoid;
+
+    /** Constant refraction layers */
+    private final List<ConstantRefractionLayer> refractionLayers;
+
+    public MultiLayerModel(final ExtendedEllipsoid ellipsoid)
+            throws OrekitException {
+        this.ellipsoid = ellipsoid;
+
+        refractionLayers = new ArrayList<ConstantRefractionLayer>(15);
+        refractionLayers.add(new ConstantRefractionLayer(100000.00, 1.000000));
+        refractionLayers.add(new ConstantRefractionLayer( 50000.00, 1.000000));
+        refractionLayers.add(new ConstantRefractionLayer( 40000.00, 1.000001));
+        refractionLayers.add(new ConstantRefractionLayer( 30000.00, 1.000004));
+        refractionLayers.add(new ConstantRefractionLayer( 23000.00, 1.000012));
+        refractionLayers.add(new ConstantRefractionLayer( 18000.00, 1.000028));
+        refractionLayers.add(new ConstantRefractionLayer( 14000.00, 1.000052));
+        refractionLayers.add(new ConstantRefractionLayer( 11000.00, 1.000083));
+        refractionLayers.add(new ConstantRefractionLayer(  9000.00, 1.000106));
+        refractionLayers.add(new ConstantRefractionLayer(  7000.00, 1.000134));
+        refractionLayers.add(new ConstantRefractionLayer(  5000.00, 1.000167));
+        refractionLayers.add(new ConstantRefractionLayer(  3000.00, 1.000206));
+        refractionLayers.add(new ConstantRefractionLayer(  1000.00, 1.000252));
+        refractionLayers.add(new ConstantRefractionLayer(     0.00, 1.000278));
+        refractionLayers.add(new ConstantRefractionLayer( -1000.00, 1.000306));
+    }
+
+    public MultiLayerModel(final ExtendedEllipsoid ellipsoid, final List<ConstantRefractionLayer> refractionLayers)
+            throws OrekitException {
+        this.ellipsoid = ellipsoid;
+        // TODO guarantee that list is already ordered by altitude?
+        this.refractionLayers = refractionLayers;
     }
 
     @Override
-    public NormalizedGeodeticPoint getPointOnGround(Vector3D initialPos, Vector3D initialLos, Vector3D initialZenith,
-                                                    double altitude, Tile tile) throws RuggedException {
-
-        Vector3D pos = initialPos;
-        Vector3D los = initialLos;
-        Vector3D zenith = initialZenith;
-        double theta1 = Vector3D.angle(los, zenith), theta2;
-        double previousRefractionIndex = -1;
-        NormalizedGeodeticPoint gp = null;
-        for (Map.Entry<Double, Double> entry : meanAtmosphericRefractions.entrySet()) {
-            if (pos.getZ() < entry.getKey()) {
-                continue;
-            }
+    public NormalizedGeodeticPoint applyCorrection(final Vector3D satPos, final Vector3D satLos,
+                                                   final NormalizedGeodeticPoint rawIntersection,
+                                                   final IntersectionAlgorithm algorithm)
+            throws RuggedException {
 
-            if (previousRefractionIndex > 0) {
-                theta2 = FastMath.asin(previousRefractionIndex * FastMath.sin(theta1) / entry.getValue());
+        try {
 
-                // get new los
-                double a = FastMath.sqrt((1 - FastMath.pow(FastMath.cos(theta2), 2)) /
-                        (1 - FastMath.pow(FastMath.cos(theta1), 2)));
-                double b = a * FastMath.cos(theta1) - FastMath.cos(theta2);
-                los = new Vector3D(a, los, b, zenith);
+            Vector3D pos = satPos;
+            Vector3D los = satLos;
+            Vector3D zenith = null;
+            double previousRefractionIndex = -1;
+            GeodeticPoint gp = ellipsoid.transform(satPos, ellipsoid.getBodyFrame(), null);
 
-                theta1 = theta2;
-            }
+            for(ConstantRefractionLayer refractionLayer : refractionLayers) {
+
+                if(refractionLayer.getLowestAltitude() > gp.getAltitude()) {
+                    continue;
+                }
+
+                if (previousRefractionIndex > 0) {
+
+                    // get new los
+                    final double theta1 = Vector3D.angle(los, zenith);
+                    final double theta2 = FastMath.asin(previousRefractionIndex * FastMath.sin(theta1) /
+                            refractionLayer.getRefractionIndex());
+
+                    final double cosTheta1     = FastMath.cos(theta1);
+                    final double cosTheta2     = FastMath.cos(theta2);
 
-            if (altitude > entry.getKey()) {
-                break;
+                    final double a = FastMath.sqrt((1 - cosTheta2 * cosTheta2) / (1 - cosTheta1 * cosTheta1));
+                    final double b = a * cosTheta1 - cosTheta2;
+                    los = new Vector3D(a, los, b, zenith);
+                }
+
+                if (rawIntersection.getAltitude() > refractionLayer.getLowestAltitude()) {
+                    break;
+                }
+
+                // get intersection point
+                pos = ellipsoid.pointAtAltitude(pos, los, refractionLayer.getLowestAltitude());
+                gp = ellipsoid.transform(pos, ellipsoid.getBodyFrame(), null);
+                zenith = gp.getZenith();
+
+                previousRefractionIndex = refractionLayer.getRefractionIndex();
             }
 
-            // get intersection point
-            ExtendedEllipsoid ellipsoid = atmosphericEllipsoids.get(entry.getKey());
-            gp = ellipsoid.pointOnGround(pos, los, 0.0);
-            gp = new NormalizedGeodeticPoint(gp.getLatitude(), gp.getLongitude(), entry.getKey(), 0.0);
+            final NormalizedGeodeticPoint newGeodeticPoint  =
+                    algorithm.refineIntersection(ellipsoid, pos, los, rawIntersection);
 
-            pos = ellipsoid.transform(gp);
-            zenith = gp.getZenith();
+            return newGeodeticPoint;
 
-            previousRefractionIndex = entry.getValue();
+        } catch (OrekitException oe) {
+            throw new RuggedException(oe, oe.getSpecifier(), oe.getParts());
         }
+    }
+}
+
+class ConstantRefractionLayer {
+    private double lowestAltitude;
+    private double refractionIndex;
 
+    public ConstantRefractionLayer(double lowestAltitude, double refractionIndex) {
+        this.lowestAltitude = lowestAltitude;
+        this.refractionIndex = refractionIndex;
+    }
 
-        // gp = new NormalizedGeodeticPoint(gp.getLatitude(), gp.getLongitude(), 16, 0.0);
-        NormalizedGeodeticPoint newGeodeticPoint = tile.cellIntersection(gp, los,
-                tile.getFloorLatitudeIndex(gp.getLatitude()), tile.getFloorLongitudeIndex(gp.getLongitude()));
+    public double getLowestAltitude() {
+        return lowestAltitude;
+    }
 
-        return newGeodeticPoint;
+    public double getRefractionIndex() {
+        return refractionIndex;
     }
 }
diff --git a/src/test/java/org/orekit/rugged/atmosphericrefraction/MultiLayerModelTest.java b/src/test/java/org/orekit/rugged/atmosphericrefraction/MultiLayerModelTest.java
index 7b9bf81323b386371869615a075aae0a73bf7d2f..9667a7c1ca6ca926c9b57fd3f1c00e7905958e1c 100644
--- a/src/test/java/org/orekit/rugged/atmosphericrefraction/MultiLayerModelTest.java
+++ b/src/test/java/org/orekit/rugged/atmosphericrefraction/MultiLayerModelTest.java
@@ -30,40 +30,29 @@ import org.orekit.rugged.intersection.duvenhage.MinMaxTreeTileFactory;
 import org.orekit.rugged.raster.Tile;
 import org.orekit.rugged.raster.TileUpdater;
 import org.orekit.rugged.raster.TilesCache;
+import org.orekit.rugged.utils.NormalizedGeodeticPoint;
 
 public class MultiLayerModelTest extends AbstractAlgorithmTest {
 
     @Test
-    public void testGetPointOnGround() throws OrekitException, RuggedException {
+    public void testApplyCorrection() throws OrekitException, RuggedException {
 
         setUpMayonVolcanoContext();
         final IntersectionAlgorithm algorithm = createAlgorithm(updater, 8);
         Vector3D position = new Vector3D(-3787079.6453602533, 5856784.405679551, 1655869.0582939098);
         Vector3D los = new Vector3D( 0.5127552821932051, -0.8254313129088879, -0.2361041470463311);
-        GeodeticPoint intersection = algorithm.refineIntersection(earth, position, los,
+        NormalizedGeodeticPoint rawIntersection = algorithm.refineIntersection(earth, position, los,
                 algorithm.intersection(earth, position, los));
 
-        Assert.assertNotNull(intersection);
+        MultiLayerModel model = new MultiLayerModel(earth);
+        GeodeticPoint correctedIntersection = model.applyCorrection(position, los, rawIntersection, algorithm);
 
-        // intersection {lat: 13.4045888388 deg, lon: 123.0160362249 deg, alt: 16}
+        double distance = Vector3D.distance(earth.transform(rawIntersection), earth.transform(correctedIntersection));
 
+        System.out.println("DISTANCE: " + distance);
 
-//        MinMaxTreeTile tile = new MinMaxTreeTileFactory().createTile();
-//        updater.updateTile(intersection.getLatitude(), intersection.getLongitude(), tile);
-//        tile.interpolateElevation(intersection.getLatitude(), intersection.getLongitude());
-
-
-        TilesCache cache = new TilesCache<MinMaxTreeTile>(new MinMaxTreeTileFactory(), updater, 8);
-        // locate the entry tile along the line-of-sight
-        Tile tile = cache.getTile(intersection.getLatitude(), intersection.getLongitude());
-
-
-        MultiLayerModel model = new MultiLayerModel();
-        GeodeticPoint gp = model.getPointOnGround(position, los, intersection.getZenith(), intersection.getAltitude(),
-                tile);
-
-        Assert.assertNotNull(gp);
-
+        // with the current code, this check fails, the distance is about 800m instead of a couple meters
+        Assert.assertEquals(0.0, distance, 2.0);
     }
 
     @Override