forked from ruby-numo/numo-narray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmath.c
More file actions
147 lines (128 loc) · 4.2 KB
/
math.c
File metadata and controls
147 lines (128 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
/*
math.c
Ruby/Numo::NArray - Numerical Array class for Ruby
Copyright (C) 1999-2020 Masahiro TANAKA
*/
#include <ruby.h>
#include "numo/narray.h"
/* The Numo::NMath module object; created in Init_nary_math. */
VALUE numo_mNMath;

/* Type-specific Math modules that NMath forwards to (defined in other files). */
extern VALUE numo_mDFloatMath, numo_mDComplexMath;
extern VALUE numo_mSFloatMath, numo_mSComplexMath;

/* Interned method/constant names; assigned in Init_nary_math. */
static ID id_send, id_UPCAST, id_DISPATCH, id_extract;
/*
 * Resolve the common type of two classes via their UPCAST tables.
 *
 * Identical classes resolve to themselves. Otherwise type1's UPCAST
 * constant (treated as a Hash) is indexed by type2. If that yields nil
 * and type2 is itself an NArray class, type2's UPCAST table is indexed
 * by type1 instead. Returns nil when no entry is found.
 */
static VALUE
nary_type_s_upcast(VALUE type1, VALUE type2)
{
    VALUE found;

    if (type1 == type2) {
        return type1;
    }

    found = rb_hash_aref(rb_const_get(type1, id_UPCAST), type2);
    if (!NIL_P(found)) {
        return found;
    }

    /* Reverse lookup is only meaningful when type2 is an NArray class;
       the T_CLASS test guards rb_class_inherited_p against non-classes. */
    if (TYPE(type2) != T_CLASS ||
        !RTEST(rb_class_inherited_p(type2, cNArray))) {
        return Qnil;
    }
    return rb_hash_aref(rb_const_get(type2, id_UPCAST), type1);
}
/*
 * Pick the common type for a pair of operand types.
 *
 * - If type1 is an NArray class, defer to nary_type_s_upcast(type1, type2);
 *   otherwise, if type2 is one, to nary_type_s_upcast(type2, type1).
 * - If both are Numeric classes: Complex when either is Complex,
 *   otherwise Float.
 * - In every other case type2 is returned as-is.
 * The result may be nil when the UPCAST tables have no entry.
 */
static VALUE nary_math_cast2(VALUE type1, VALUE type2)
{
    if (RTEST(rb_class_inherited_p(type1, cNArray))) {
        return nary_type_s_upcast(type1, type2);
    }
    if (RTEST(rb_class_inherited_p(type2, cNArray))) {
        return nary_type_s_upcast(type2, type1);
    }

    /* Anything that is not a pair of Ruby numerics passes type2 through. */
    if (!RTEST(rb_class_inherited_p(type1, rb_cNumeric)) ||
        !RTEST(rb_class_inherited_p(type2, rb_cNumeric))) {
        return type2;
    }

    if (RTEST(rb_class_inherited_p(type1, rb_cComplex)) ||
        RTEST(rb_class_inherited_p(type2, rb_cComplex))) {
        return rb_cComplex;
    }
    return rb_cFloat;
}
VALUE na_ary_composition_dtype(VALUE);

/*
 * Reduce the dtypes of argv[0..argc-1] to a single common type.
 *
 * Each argument's dtype comes from na_ary_composition_dtype and is
 * combined left to right with nary_math_cast2. Raises TypeError as soon
 * as a combination step yields nil. argv[0] is read unconditionally, so
 * callers must pass argc >= 1.
 */
static VALUE nary_mathcast(int argc, VALUE *argv)
{
    VALUE acc = na_ary_composition_dtype(argv[0]);
    int k;

    for (k = 1; k < argc; ++k) {
        acc = nary_math_cast2(acc, na_ary_composition_dtype(argv[k]));
        if (NIL_P(acc)) {
            rb_raise(rb_eTypeError,"includes unknown DataType for upcast");
        }
    }
    return acc;
}
/*
  Forwards an unknown NMath method to the Math module that DISPATCH
  registers for the upcast type of the arguments, eg, Numo::DFloat::Math.
  @overload method_missing(name,x,...)
  @param [Symbol] name method name.
  @param [NArray,Numeric] x input array.
  @return [NArray,Numeric] result.
*/
static VALUE nary_math_method_missing(int argc, VALUE *argv, VALUE mod)
{
    VALUE type, target, result;

    /* Need the method name plus at least one operand. */
    if (argc <= 1) {
        rb_raise(rb_eArgError,"argument or method missing");
        return Qnil; /* not reached */
    }

    /* argv[0] is the method name; the operands start at argv[1]. */
    type = nary_mathcast(argc - 1, argv + 1);
    target = rb_hash_aref(rb_const_get(mod, id_DISPATCH), type);
    if (NIL_P(target)) {
        rb_raise(rb_eTypeError,"%s is unknown for Numo::NMath",
                 rb_class2name(type));
    }

    /* Re-send the full argument list (name included) to the target module. */
    result = rb_funcall2(target, id_send, argc, argv);

    /* For non-NArray operand types, an NArray result is passed through
       #extract (presumably to unwrap scalar results -- confirm). */
    if (!RTEST(rb_class_inherited_p(type, cNArray)) && IsNArray(result)) {
        result = rb_funcall(result, id_extract, 0);
    }
    return result;
}
/*
 * Initialize Numo::NMath.
 *
 * Defines the module, installs its singleton method_missing (its only
 * entry point), and publishes NMath::DISPATCH: a Hash mapping an operand
 * type to the Math module that nary_math_method_missing forwards to.
 */
void
Init_nary_math(void)
{
    VALUE hCast;

    /*
     * Intern the IDs used by the dispatcher before anything is published.
     * Any Ruby hook triggered by the define calls below (method/constant
     * added callbacks) then can never reach the dispatcher with unset IDs.
     */
    id_send     = rb_intern("send");
    id_UPCAST   = rb_intern("UPCAST");
    id_DISPATCH = rb_intern("DISPATCH");
    id_extract  = rb_intern("extract");

    numo_mNMath = rb_define_module_under(mNumo, "NMath");
    rb_define_singleton_method(numo_mNMath, "method_missing", nary_math_method_missing, -1);

    hCast = rb_hash_new();
    rb_define_const(numo_mNMath, "DISPATCH", hCast);

    /* Integer NArray types are dispatched to DFloat::Math. */
    rb_hash_aset(hCast, numo_cInt64,  numo_mDFloatMath);
    rb_hash_aset(hCast, numo_cInt32,  numo_mDFloatMath);
    rb_hash_aset(hCast, numo_cInt16,  numo_mDFloatMath);
    rb_hash_aset(hCast, numo_cInt8,   numo_mDFloatMath);
    rb_hash_aset(hCast, numo_cUInt64, numo_mDFloatMath);
    rb_hash_aset(hCast, numo_cUInt32, numo_mDFloatMath);
    rb_hash_aset(hCast, numo_cUInt16, numo_mDFloatMath);
    rb_hash_aset(hCast, numo_cUInt8,  numo_mDFloatMath);

    /* Float/complex NArray types map to their matching Math module
       (each key registered exactly once). */
    rb_hash_aset(hCast, numo_cDFloat,   numo_mDFloatMath);
    rb_hash_aset(hCast, numo_cDComplex, numo_mDComplexMath);
    rb_hash_aset(hCast, numo_cSFloat,   numo_mSFloatMath);
    rb_hash_aset(hCast, numo_cSComplex, numo_mSComplexMath);

    /* Plain Ruby numerics: reals use Ruby's ::Math, Complex uses DComplex::Math. */
#ifdef RUBY_INTEGER_UNIFICATION
    rb_hash_aset(hCast, rb_cInteger, rb_mMath);
#else
    rb_hash_aset(hCast, rb_cFixnum, rb_mMath);
    rb_hash_aset(hCast, rb_cBignum, rb_mMath);
#endif
    rb_hash_aset(hCast, rb_cFloat,   rb_mMath);
    rb_hash_aset(hCast, rb_cComplex, numo_mDComplexMath);
}