Alexandria  2.22.0
Please provide a description of the project.
GridInterpolation.icpp
Go to the documentation of this file.
1 #ifndef GRIDINTERPOLATION_IMPL
2 #error Please, include "MathUtils/interpolation/GridInterpolation.h"
3 #endif
4 
5 #include "AlexandriaKernel/Tuples.h"
6 #include "MathUtils/interpolation/interpolation.h"
7 
8 namespace Euclid {
9 namespace MathUtils {
10 
/**
 * Primary template of the per-axis interpolation trait.
 *
 * It is only declared, never defined: the partial specializations that follow
 * pick the implementation for a given coordinate type through the second
 * template parameter.
 *
 * @tparam T
 *  Type of the coordinates of one axis
 * @tparam Enable
 *  Hook for std::enable_if based selection; callers leave it at its default
 */
template <class T, class Enable = void>
struct InterpolationImpl;
13 
14 /**
15  * Trait for continuous types
16  */
17 template <typename T>
18 struct InterpolationImpl<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
19  static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values, bool extrapolate) {
20  return simple_interpolation(x, knots, values, extrapolate);
21  }
22 
23  template <typename... Rest>
24  static double interpolate(const T x, const std::vector<T>& knots,
25  const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool extrapolate,
26  const Rest... rest) {
27  // If no extrapolation, and the value if out-of-bounds, just clip at 0
28  if ((x < knots.front() || x > knots.back()) && !extrapolate) {
29  return 0.;
30  }
31 
32  if (knots.size() == 1) {
33  return (*interpolators[0])(rest...);
34  }
35 
36  std::size_t x2i = std::lower_bound(knots.begin(), knots.end(), x) - knots.begin();
37  if (x2i == 0) {
38  ++x2i;
39  } else if (x2i == knots.size()) {
40  --x2i;
41  }
42  std::size_t x1i = x2i - 1;
43 
44  double y1 = (*interpolators[x1i])(rest...);
45  double y2 = (*interpolators[x2i])(rest...);
46 
47  return simple_interpolation(x, {knots[x1i], knots[x2i]}, {y1, y2}, extrapolate);
48  }
49 
50  static void checkOrder(const std::vector<T>& knots) {
51  if (!std::is_sorted(knots.begin(), knots.end())) {
52  throw InterpolationException("coordinates must be sorted");
53  }
54  }
55 };
56 
57 /**
58  * Trait for discrete types
59  */
60 template <typename T>
61 struct InterpolationImpl<T, typename std::enable_if<!std::is_floating_point<T>::value>::type> {
62  static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values, bool /*extrapolate*/) {
63  std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
64  if (i >= knots.size() || knots[i] != x)
65  return 0.;
66  return values[i];
67  }
68 
69  template <typename... Rest>
70  static double interpolate(const T x, const std::vector<T>& knots,
71  const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool, const Rest... rest) {
72  std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
73  if (i >= knots.size() || knots[i] != x)
74  return 0.;
75  return (*interpolators[i])(rest...);
76  }
77 
78  static void checkOrder(const std::vector<T>&) {
79  // Discrete axes do not need to be in order
80  }
81 };
82 
83 /**
84  * Specialization (and end of the recursion) for a 1-dimensional interpolation.
85  */
86 template <typename T>
87 class InterpN<T> {
88 public:
89  /**
90  * Constructor
91  * @param grid
92  * A 1-dimensional grid
93  * @param values
94  * @param type
95  * @param extrapolate
96  */
97  InterpN(const std::tuple<std::vector<T>>& grid, const NdArray::NdArray<double>& values, bool extrapolate)
98  : m_knots(std::get<0>(grid)), m_values(values.begin(), values.end()), m_extrapolate(extrapolate) {
99  if (values.shape().size() != 1) {
100  throw InterpolationException() << "values and coordinates dimensionalities must match: " << values.shape().size() << " != 1";
101  }
102  if (m_knots.size() != values.size()) {
103  throw InterpolationException() << "The size of the grid and the size of the values do not match: " << m_knots.size()
104  << " != " << m_values.size();
105  }
106  }
107 
108  /**
109  * Call as a function
110  * @param x
111  * Coordinate value
112  * @return
113  * Interpolated value
114  */
115  double operator()(const T x) const {
116  return InterpolationImpl<T>::interpolate(x, m_knots, m_values, m_extrapolate);
117  }
118 
119  /// Copy constructor
120  InterpN(const InterpN&) = default;
121 
122  /// Move constructor
123  InterpN(InterpN&&) = default;
124 
125 private:
126  std::vector<T> m_knots;
127  std::vector<double> m_values;
128  bool m_extrapolate;
129 };
130 
131 /**
132  * Recursive specialization of an N-Dimensional interpolator
133  * @tparam N Dimensionality (N > 1)
134  * @tparam F The first element of the index sequence
135  * @tparam Rest The rest of the elements from the index sequence
136  */
137 template <typename T, typename... Rest>
138 class InterpN<T, Rest...> {
139 public:
140  /**
141  * Constructor
142  * @param grid
143  * @param values
144  * @param type
145  * @param extrapolate
146  */
147  InterpN(const std::tuple<std::vector<T>, std::vector<Rest>...>& grid, const NdArray::NdArray<double>& values, bool extrapolate)
148  : m_extrapolate(extrapolate) {
149  constexpr std::size_t N = sizeof...(Rest) + 1;
150 
151  if (values.shape().size() != N) {
152  throw InterpolationException() << "values and coordinates dimensionality must match: " << values.shape().size()
153  << " != " << N;
154  }
155  m_knots = std::get<0>(grid);
156  InterpolationImpl<T>::checkOrder(m_knots);
157  if (m_knots.size() != values.shape().back()) {
158  throw InterpolationException("coordinates and value sizes must match");
159  }
160  // Build nested interpolators
161  auto subgrid = Tuple::Tail(std::move(grid));
162  m_interpolators.resize(m_knots.size());
163  for (size_t i = 0; i < m_knots.size(); ++i) {
164  auto subvalues = values.rslice(i);
165  m_interpolators[i].reset(new InterpN<Rest...>(subgrid, subvalues, extrapolate));
166  }
167  }
168 
169  /**
170  * Call as a function
171  * @param x Value for the axis for the first dimension
172  * @param rest Values for the next set of axes
173  * @return The interpolated value
174  * @details
175  * Doubles<Rest>... is used to expand into (N-1) doubles
176  * x is used to find the interpolators for x1 and x2 s.t. x1 <= x <=x2
177  * Those two interpolators are used to compute y1 for x1, and y2 for x2 (based on the rest of the parameters)
178  * A final linear interpolator is used to get the value of y at the position x
179  */
180  double operator()(T x, Rest... rest) const {
181  return InterpolationImpl<T>::interpolate(x, m_knots, m_interpolators, m_extrapolate, rest...);
182  }
183 
184  /// Copy constructor
185  InterpN(const InterpN& other) : m_knots(other.m_knots), m_extrapolate(other.m_extrapolate) {
186  m_interpolators.resize(m_knots.size());
187  for (size_t i = 0; i < m_interpolators.size(); ++i) {
188  m_interpolators[i].reset(new InterpN<Rest...>(*other.m_interpolators[i]));
189  }
190  }
191 
192 private:
193  std::vector<T> m_knots;
194  std::vector<std::unique_ptr<InterpN<Rest...>>> m_interpolators;
195  bool m_extrapolate;
196 };
197 
198 } // namespace MathUtils
199 } // namespace Euclid