1 /++
2 This module contains simple numeric algorithms.
3 License: $(HTTP www.apache.org/licenses/LICENSE-2.0, Apache-2.0)
4 Authors: Ilya Yaroshenko
5 Copyright: 2020 Ilya Yaroshenko, Kaleidic Associates Advisory Limited, Symmetry Investments
6 Sponsors: This work has been sponsored by $(SUBREF http://symmetryinvestments.com, Symmetry Investments) and Kaleidic Associates.
7 +/
8 module mir.math.numeric;
9 
10 import mir.math.common;
11 import mir.primitives;
12 import mir.primitives: isInputRange;
13 import std.traits: CommonType, Unqual, isIterable, ForeachType, isPointer;
14 import mir.internal.utility: isFloatingPoint;
15 
/++
Accumulator for floating point products.

The running product is stored as a mantissa `x` and a separate binary exponent
`exp`, i.e. the represented value is `x * 2^^exp`. Because the exponent lives in
a `long`, multiplying many very large and very small factors does not overflow
or underflow the intermediate result; the final value is only materialized by
`prod`.

The initial state (`x == 0.5`, `exp == 1`) represents `1`.
+/
struct ProdAccumulator(T)
    if (isFloatingPoint!T)
{
    alias F = Unqual!T;

    /// Binary exponent of the accumulated product.
    long exp = 1L;
    /// Mantissa of the accumulated product.
    F x = cast(F) 0.5;
    /// Alias of `x`.
    alias mantissa = x;

    /// Starts the accumulator at `value`, split into mantissa and exponent by `frexp`.
    @safe pure @nogc nothrow
    this(F value)
    {
        import mir.math.ieee: frexp;

        int valueExp;
        x = frexp(value, valueExp);
        exp = valueExp;
    }

    /// Starts the accumulator at `x * 2^^exp`; `x` is stored as given, without normalization.
    @safe pure @nogc nothrow
    this(long exp, F x)
    {
        this.x = x;
        this.exp = exp;
    }

    // After a multiplication, doubles a mantissa whose magnitude dropped below
    // 0.5 and compensates in the exponent. A single step is enough when both
    // factors had magnitude in [0.5, 1), because their product is then in [0.25, 1).
    private void renormalize() @safe pure @nogc nothrow
    {
        if (x.fabs < 0.5f)
        {
            x += x;
            exp--;
        }
    }

    /// Multiplies the accumulated product by `e`; values implicitly convertible to `T` are converted first.
    @safe pure @nogc nothrow
    void put(U)(U e)
        if (is(U : T))
    {
        static if (!is(U == T))
        {
            return put(cast(T) e);
        }
        else
        {
            import mir.math.ieee: frexp;

            int factorExp;
            x *= frexp(e, factorExp);
            exp += factorExp;
            renormalize();
        }
    }

    /// Multiplies the accumulated product by the product held in another accumulator.
    @safe pure @nogc nothrow
    void put(ProdAccumulator!T value)
    {
        x *= value.x;
        exp += value.exp;
        renormalize();
    }

    /// Multiplies the accumulated product by every element of an iterable.
    void put(Range)(Range r)
        if (isIterable!Range)
    {
        foreach (ref item; r)
            put(item);
    }

    import mir.ndslice.slice;

    /// ditto
    void put(Range: Slice!(Iterator, N, kind), Iterator, size_t N, SliceKind kind)(Range r)
    {
        static if (kind == Contiguous && N > 1)
        {
            // Multidimensional contiguous slices are processed through a flattened view.
            import mir.ndslice.topology: flattened;
            this.put(r.flattened);
        }
        else static if (kind == Contiguous && isPointer!Iterator)
        {
            // Contiguous pointer-backed slices expose their elements directly via `field`.
            this.put(r.field);
        }
        else
        {
            foreach (item; r)
                this.put(item);
        }
    }

    /++
    Returns: the accumulated product `mantissa * 2^^exp` as a `T`.
    The exponent is clamped to the `int` range before calling `ldexp`.
    +/
    @safe pure @nogc nothrow
    T prod() const scope @property
    {
        import mir.math.ieee: ldexp;

        int clampedExp = cast(int) exp;
        if (exp > int.max)
            clampedExp = int.max;
        else if (exp < int.min)
            clampedExp = int.min;
        return ldexp(mantissa, clampedExp);
    }

    /// Returns a copy of this accumulator whose exponent is increased by `exp`.
    @safe pure @nogc nothrow
    ProdAccumulator!T ldexp(long exp) const
    {
        return ProdAccumulator!T(this.exp + exp, x);
    }

    /// `acc *= value` forwards to `acc.put(value)`.
    alias opOpAssign(string op : "*") = put;

    /// Negation: flips the sign of the mantissa, keeps the exponent.
    @safe pure @nogc nothrow
    ProdAccumulator!T opUnary(string op : "-")() const
    {
        return ProdAccumulator!T(exp, -x);
    }

    /// Unary plus: returns an unchanged copy.
    @safe pure @nogc nothrow
    ProdAccumulator!T opUnary(string op : "+")() const
    {
        return ProdAccumulator!T(exp, +x);
    }
}
148 
/// Accumulating a slice and then a single value
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;

    ProdAccumulator!float acc;

    acc.put([1, 2, 3].sliced); // 1 * 2 * 3
    assert(acc.prod == 6f);

    acc.put(4); // 6 * 4
    assert(acc.prod == 24f);
}
162 
// Accumulating a static immutable array needs no GC allocation, so this test is @nogc.
version(mir_test)
@safe pure @nogc nothrow
unittest
{
    static immutable a = [1, 2, 3];
    ProdAccumulator!float x;
    x.put(a);
    assert(x.prod == 6f);
    x.put(4);
    assert(x.prod == 24f);
    static assert(is(typeof(x.prod) == float));
}
177 
// A double accumulator fed with a dynamic array and then a single value.
version(mir_test)
@safe pure nothrow
unittest
{
    ProdAccumulator!double x;
    x.put([1.0, 2.0, 3.0]);
    assert(x.prod == 6.0);
    x.put(4.0);
    assert(x.prod == 24.0);
    static assert(is(typeof(x.prod) == double));
}
191 
/++
Result type of `prod` for an iterable `T`.

The element type of `T` must support both binary `*` and `*=` on its product.
The result is `statType!(V, false)`, where `V` is the type of multiplying two
elements. Otherwise a compile-time error is issued.
+/
package template prodType(T)
{
    import mir.math.sum: elementType;

    alias U = elementType!T;

    // True when elements can be multiplied and the product can be multiplied in place.
    enum bool canMultiply = __traits(compiles, {
        auto temp = U.init * U.init;
        temp *= U.init;
    });

    static if (!canMultiply)
    {
        static assert(0, "prodType: Can't prod elements of type " ~ U.stringof);
    }
    else
    {
        import mir.math.stat: statType;

        alias V = typeof(U.init * U.init);
        alias prodType = statType!(V, false);
    }
}
210 
/++
Calculates the product of the elements of the input.

The product is accumulated with `ProdAccumulator`, which keeps the mantissa and
the binary exponent separately, so large and small factors can be mixed without
intermediate overflow or underflow. A consequence is that the result must be a
floating point type `F`. To calculate the product of a type that is not
implicitly convertible to a floating point type, use
$(MREF mir, algorithm, iteration, reduce) or $(MREF mir, algorithm, iteration, fold).

Params:
    r = finite iterable range
Returns:
    The product of all the elements in `r`, as `F`

See_also:
$(MREF mir, algorithm, iteration, reduce)
$(MREF mir, algorithm, iteration, fold)
+/
F prod(F, Range)(Range r)
    if (isFloatingPoint!F && isIterable!Range)
{
    import core.lifetime: move;

    // Named `acc` rather than `prod` so the local does not shadow this function.
    ProdAccumulator!F acc;
    acc.put(r.move);
    return acc.prod;
}
239 
/++
Calculates the product of the elements of `r` in split form.

Params:
    r = finite iterable range
    exp = receives the binary exponent of the product
Returns:
    The mantissa, such that the product equals the mantissa times 2^^exp
+/
F prod(F, Range)(Range r, ref long exp)
    if (isFloatingPoint!F && isIterable!Range)
{
    import core.lifetime: move;

    ProdAccumulator!F accumulator;
    accumulator.put(r.move);
    exp = accumulator.exp;
    return accumulator.mantissa;
}
257 
/++
Calculates the product of the elements of `r`.

The result type is `prodType!Range`, derived from the element type of `Range`.

Params:
    r = finite iterable range
Returns:
    The product of all the elements in `r`
+/
prodType!Range prod(Range)(Range r)
    if (isIterable!Range)
{
    import core.lifetime: move;

    alias F = typeof(return);
    return .prod!(F, Range)(r.move);
}
272 
/++
Calculates the product of the elements of `r` in split form, with the result
type `prodType!Range`.

Params:
    r = finite iterable range
    exp = receives the binary exponent of the product
Returns:
    The mantissa, such that the product equals the mantissa times 2^^exp
+/
prodType!Range prod(Range)(Range r, ref long exp)
    if (isIterable!Range)
{
    import core.lifetime: move;

    return .prod!(prodType!Range, Range)(r.move, exp);
}
288 
/++
Calculates the product of a variadic list (or array) of values.

Params:
    ar = values
Returns:
    The product of all the elements in `ar`, as `prodType!T`
+/
prodType!T prod(T)(scope const T[] ar...)
{
    alias F = typeof(return);

    // Named `acc` rather than `prod` so the local does not shadow this function.
    ProdAccumulator!F acc;
    acc.put(ar);
    return acc.prod;
}
302 
/// Product of arbitrary inputs
version(mir_test)
@safe pure @nogc nothrow
unittest
{
    // Mixed floating point and integer arguments.
    assert(prod(1.0, 3, 4) == 12.0);

    // The element type can be given explicitly.
    immutable float expected = 12f;
    assert(prod!float(1, 3, 4) == expected);
}
311 
/// Product of arrays and ranges
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.math.common: approxEqual;

    // Multiplying these huge powers of two one after another would overflow a
    // plain `double`. The tiny ones cancel them exactly.
    enum huge = 2.0 ^^ (double.max_exp - 1);
    enum tiny = 2.0 ^^ -(double.max_exp - 1);
    auto values = [huge, huge, huge, tiny, tiny, tiny, 0.8 * 2.0 ^^ 10];

    assert(values.prod == 0.8 * 2.0 ^^ 10);

    // The mantissa and the binary exponent can be obtained separately.
    long exponent;
    auto mantissa = values.prod(exponent);
    assert(mantissa.approxEqual(0.8));
    assert(exponent == 10);
}
330 
/// Product of vector
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: reduce;
    import mir.math.common: approxEqual;

    enum huge = 2.0 ^^ (double.max_exp - 1);
    enum tiny = 2.0 ^^ -(double.max_exp - 1);
    double mant = 0.8;
    double scaled = mant * 2.0 ^^ 10;
    auto vec = [huge, huge, huge, tiny, tiny, tiny, scaled, scaled, scaled].sliced;

    // The huge and tiny factors cancel, leaving scaled^^3.
    assert(vec.prod == reduce!"a * b"(1.0, [scaled, scaled, scaled]));

    long exponent;
    assert(vec.prod(exponent).approxEqual(reduce!"a * b"(1.0, [mant, mant, mant])));
    assert(exponent == 30);
}
352 
/// Product of matrix
version(mir_test)
@safe pure
unittest
{
    import mir.ndslice.fuse: fuse;
    import mir.algorithm.iteration: reduce;

    enum huge = 2.0 ^^ (double.max_exp - 1);
    enum tiny = 2.0 ^^ -(double.max_exp - 1);
    double mant = 0.8;
    double scaled = mant * 2.0 ^^ 10;

    // One row of huge factors, one of tiny factors, one of scaled values.
    auto matrix = [
        [huge, huge, huge],
        [tiny, tiny, tiny],
        [scaled, scaled, scaled]
    ].fuse;

    assert(matrix.prod == reduce!"a * b"(1.0, [scaled, scaled, scaled]));

    long exponent;
    assert(matrix.prod(exponent) == reduce!"a * b"(1.0, [mant, mant, mant]));
    assert(exponent == 30);
}
377 
/// Column prod of matrix
version(mir_test)
@safe pure
unittest
{
    import mir.ndslice.fuse: fuse;
    import mir.algorithm.iteration: all;
    import mir.math.common: approxEqual;
    import mir.ndslice.topology: alongDim, byDim, map;

    auto x = [
        [2.0, 1.0, 1.5, 2.0, 3.5, 4.25],
        [2.0, 7.5, 5.0, 1.0, 1.5, 5.0]
    ].fuse;

    auto result = [4.0, 7.5, 7.5, 2.0, 5.25, 21.25];

    // Use byDim or alongDim with map to compute the product of each column.
    assert(x.byDim!1.map!prod.all!approxEqual(result));
    assert(x.alongDim!0.map!prod.all!approxEqual(result));

    // FIXME: without `map`, `prod` computes the product of the whole slice
    // instead of one product per column, so these checks are disabled.
    // assert(x.byDim!1.prod.all!approxEqual(result));
    // assert(x.alongDim!0.prod.all!approxEqual(result));
}
404 
/// Can also set output type
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;

    auto x = [1, 2, 3].sliced;
    assert(x.prod!float == 6f);
    static assert(is(typeof(x.prod!float) == float));
}
417 
/// The product of values implicitly convertible to an integer type (including via `alias this`) has type double
version(mir_test)
@safe pure nothrow
unittest
{
    static struct Wrapped
    {
        int x;
        alias x this;
    }

    auto fromInts = prod(1, 2, 3);
    static assert(is(typeof(fromInts) == double));
    assert(fromInts == 6.0);

    auto fromWrapped = prod([Wrapped(1), Wrapped(2), Wrapped(3)]);
    static assert(is(typeof(fromWrapped) == double));
    assert(fromWrapped == 6.0);
}
437 
// Extreme double powers of two cancel; static immutable data keeps the test @nogc.
version(mir_test)
@safe pure @nogc nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: reduce;
    import mir.math.common: approxEqual;

    enum huge = 2.0 ^^ (double.max_exp - 1);
    enum tiny = 2.0 ^^ -(double.max_exp - 1);
    enum mant = 0.8;
    enum scaled = mant * 2.0 ^^ 10;
    static immutable data = [huge, huge, huge, tiny, tiny, tiny, scaled, scaled, scaled];
    static immutable scaledCubed = [scaled, scaled, scaled];
    static immutable mantCubed = [mant, mant, mant];

    assert(data.sliced.prod.approxEqual(reduce!"a * b"(1.0, scaledCubed)));

    long exponent;
    assert(data.sliced.prod(exponent).approxEqual(reduce!"a * b"(1.0, mantCubed)));
    assert(exponent == 30);
}
460 
// Factors at the edge of the float exponent range, accumulated explicitly as double.
version(mir_test)
@safe pure @nogc nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: reduce;
    import mir.math.common: approxEqual;

    enum huge = 2.0 ^^ (float.max_exp - 1);
    enum tiny = 2.0 ^^ -(float.max_exp - 1);
    enum mant = 0.8;
    enum scaled = mant * 2.0 ^^ 10;
    static immutable data = [huge, huge, huge, tiny, tiny, tiny, scaled, scaled, scaled];
    static immutable scaledCubed = [scaled, scaled, scaled];
    static immutable mantCubed = [mant, mant, mant];

    assert(data.sliced.prod!double.approxEqual(reduce!"a * b"(1.0, scaledCubed)));

    long exponent;
    assert(data.sliced.prod!double(exponent).approxEqual(reduce!"a * b"(1.0, mantCubed)));
    assert(exponent == 30);
}
483 
/++
Compute the sum of binary logarithms of the input range $(D r).

Instead of adding up `log2` of every element, the elements are multiplied with
`prod`, which returns the mantissa and the binary exponent of the product
separately. The logarithm is then taken only once, as
`exponent + log2(mantissa)`. The error of this method is much smaller than with
a naive sum of log2.

Params:
    r = input range of floating point values
Returns:
    The sum of `log2` of the elements: `0` for an empty range, `-infinity` if
    an element is zero, and NaN if the product is negative or an element is NaN.
+/
Unqual!(DeepElementType!Range) sumOfLog2s(Range)(Range r)
    if (isFloatingPoint!(DeepElementType!Range))
{
    long binaryExponent;
    auto mantissa = .prod(r, binaryExponent);
    // log2(mantissa * 2^^e) == e + log2(mantissa)
    return binaryExponent + log2(mantissa);
}
495 
///
version(mir_test)
@safe pure
unittest
{
    alias isNaN = x => x != x;

    // Empty input: the empty product is 1 and log2(1) == 0.
    assert(sumOfLog2s(new double[0]) == 0);

    // Zeros of either sign give -infinity.
    assert(sumOfLog2s([0.0L]) == -real.infinity);
    assert(sumOfLog2s([-0.0L]) == -real.infinity);

    // Exact powers of two and positive infinity.
    assert(sumOfLog2s([2.0L]) == 1);
    assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9);
    assert(sumOfLog2s([real.infinity]) == real.infinity);

    // Negative products and NaN inputs give NaN.
    assert(isNaN(sumOfLog2s([-2.0L])));
    assert(isNaN(sumOfLog2s([real.nan])));
    assert(isNaN(sumOfLog2s([-real.nan])));
    assert(isNaN(sumOfLog2s([-real.infinity])));
}