Skip to content

Commit a7b1d13

Browse files
authored
Merge pull request #51 from rajithv/function_symbol
Function symbol Ruby Wrappers
2 parents 23e1ba1 + 6c0fe5d commit a7b1d13

File tree

8 files changed

+109
-6
lines changed

8 files changed

+109
-6
lines changed

ext/symengine/ruby_function.c

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "ruby_function.h"
22

3+
typedef struct CVecBasic CVecBasic;
4+
35
#define IMPLEMENT_ONE_ARG_FUNC(func) \
46
VALUE cfunction_ ## func(VALUE self, VALUE operand1) { \
57
return function_onearg(basic_ ## func, operand1); \
@@ -37,3 +39,36 @@ IMPLEMENT_ONE_ARG_FUNC(gamma);
3739

3840
#undef IMPLEMENT_ONE_ARG_FUNC
3941

42+
VALUE cfunction_functionsymbol_init(VALUE self, VALUE args)
43+
{
44+
int argc = RARRAY_LEN(args);
45+
if (argc == 0) {
46+
rb_raise(rb_eTypeError, "Arguments Expected");
47+
}
48+
49+
VALUE first = rb_ary_shift(args);
50+
if (TYPE(first) != T_STRING) {
51+
rb_raise(rb_eTypeError, "String expected as first argument");
52+
}
53+
char *name = StringValueCStr(first);
54+
55+
CVecBasic *cargs = vecbasic_new();
56+
57+
basic x;
58+
basic_new_stack(x);
59+
int i;
60+
for (i = 0; i < argc-1; i++) {
61+
sympify(rb_ary_shift(args), x);
62+
vecbasic_push_back(cargs, x);
63+
}
64+
65+
basic_struct *this;
66+
Data_Get_Struct(self, basic_struct, this);
67+
function_symbol_set(this, name, cargs);
68+
69+
vecbasic_free(cargs);
70+
basic_free_stack(x);
71+
72+
return self;
73+
}
74+

ext/symengine/ruby_function.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,6 @@ VALUE cfunction_dirichlet_eta(VALUE self, VALUE operand1);
3737
VALUE cfunction_zeta(VALUE self, VALUE operand1);
3838
VALUE cfunction_gamma(VALUE self, VALUE operand1);
3939

40+
VALUE cfunction_functionsymbol_init(VALUE self, VALUE args);
41+
4042
#endif //RUBY_FUNCTION_H_

ext/symengine/symengine.c

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ void Init_symengine() {
110110
rb_define_const(m_symengine, "I", cconstant_i());
111111
rb_define_const(m_symengine, "HAVE_MPFR", cconstant_have_mpfr());
112112
rb_define_const(m_symengine, "HAVE_MPC", cconstant_have_mpc());
113+
114+
//Subs class
115+
c_subs = rb_define_class_under(m_symengine, "Subs", c_basic);
113116

114117
//Add class
115118
c_add = rb_define_class_under(m_symengine, "Add", c_basic);
@@ -130,10 +133,12 @@ void Init_symengine() {
130133
c_dirichlet_eta = rb_define_class_under(m_symengine, "Dirichlet_eta", c_function);
131134
c_zeta = rb_define_class_under(m_symengine, "Zeta", c_function);
132135
c_gamma = rb_define_class_under(m_symengine, "Gamma", c_function);
133-
134-
//Abs Class
135136
c_abs = rb_define_class_under(m_symengine, "Abs", c_function);
136137

138+
//FunctionSymbol Class
139+
c_function_symbol = rb_define_class_under(m_symengine, "FunctionSymbol", c_function);
140+
rb_define_method(c_function_symbol, "initialize", cfunction_functionsymbol_init, -2);
141+
137142
//TrigFunction SubClasses
138143
c_sin = rb_define_class_under(m_symengine, "Sin", c_trig_function);
139144
c_cos = rb_define_class_under(m_symengine, "Cos", c_trig_function);

ext/symengine/symengine.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ VALUE c_real_mpfr;
2121
VALUE c_complex_mpc;
2222
#endif //HAVE_SYMENGINE_MPC
2323
VALUE c_constant;
24+
VALUE c_subs;
2425
VALUE c_add;
2526
VALUE c_mul;
2627
VALUE c_pow;
2728
VALUE c_function;
29+
VALUE c_function_symbol;
2830
VALUE c_trig_function;
2931
VALUE c_hyperbolic_function;
3032
VALUE c_lambertw;

ext/symengine/symengine_utils.c

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ void sympify(VALUE operand2, basic_struct *cbasic_operand2) {
88
VALUE a, b;
99
double f;
1010
char *c;
11-
rb_cBigDecimal = CLASS_OF(rb_eval_string("BigDecimal.new('0.0001')"));
1211

1312
switch(TYPE(operand2)) {
1413
case T_FIXNUM:
@@ -63,7 +62,7 @@ void sympify(VALUE operand2, basic_struct *cbasic_operand2) {
6362
case T_DATA:
6463
c = rb_obj_classname(operand2);
6564
#ifdef HAVE_SYMENGINE_MPFR
66-
if(CLASS_OF(operand2) == rb_cBigDecimal){
65+
if (strcmp(c, "BigDecimal") == 0) {
6766
c = RSTRING_PTR( rb_funcall(operand2, rb_intern("to_s"), 1, rb_str_new2("F")) );
6867
real_mpfr_set_str(cbasic_operand2, c, 200);
6968
break;
@@ -113,12 +112,16 @@ VALUE Klass_of_Basic(const basic_struct *basic_ptr) {
113112
#endif //HAVE_SYMENGINE_MPFR
114113
case SYMENGINE_CONSTANT:
115114
return c_constant;
115+
case SYMENGINE_SUBS:
116+
return c_subs;
116117
case SYMENGINE_ADD:
117118
return c_add;
118119
case SYMENGINE_MUL:
119120
return c_mul;
120121
case SYMENGINE_POW:
121122
return c_pow;
123+
case SYMENGINE_FUNCTIONSYMBOL:
124+
return c_function_symbol;
122125
case SYMENGINE_ABS:
123126
return c_abs;
124127
case SYMENGINE_SIN:

lib/symengine.rb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module SymEngine
22
class << self
3-
43
# Defines a shortcut for SymEngine::Symbol.new() allowing multiple symbols
54
# to be created all at once.
65
#
@@ -27,7 +26,10 @@ def symbols ary_or_string, *params
2726
ary_or_string.map do |symbol_or_string|
2827
SymEngine::Symbol.new(symbol_or_string)
2928
end
30-
end
29+
end
30+
def Function(n)
31+
return SymEngine::UndefFunction.new(n)
32+
end
3133
end
3234
end
3335

@@ -37,3 +39,4 @@ def symbols ary_or_string, *params
3739
require 'symengine/integer'
3840
require 'symengine/complex'
3941
require 'symengine/complex_double'
42+
require 'symengine/undef_function'

lib/symengine/undef_function.rb

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module SymEngine
2+
class UndefFunction
3+
4+
def initialize(n)
5+
@name = n
6+
end
7+
8+
def call(*args)
9+
SymEngine::FunctionSymbol.new(@name, *args)
10+
end
11+
12+
end
13+
end

spec/function_symbol_spec.rb

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
describe SymEngine::FunctionSymbol do
2+
3+
let(:x) { sym('x') }
4+
let(:y) { sym('y') }
5+
let(:z) { sym('z') }
6+
7+
describe '.new' do
8+
context 'with symbols' do
9+
subject { SymEngine::FunctionSymbol.new('f', x, y, z) }
10+
it { is_expected.to be_a SymEngine::FunctionSymbol }
11+
end
12+
13+
context 'with compound arguments' do
14+
subject { SymEngine::FunctionSymbol.new('f', 2*x, y, SymEngine::sin(z)) }
15+
it { is_expected.to be_a SymEngine::FunctionSymbol }
16+
end
17+
end
18+
19+
context '#diff' do
20+
let(:fun) { (SymEngine::FunctionSymbol.new('f', 2*x, y, SymEngine::sin(z))) }
21+
context 'by variable' do
22+
subject { fun.diff(x)/2 }
23+
it { is_expected.to be_a SymEngine::Subs }
24+
end
25+
end
26+
27+
context 'Initializing with UndefFunctions' do
28+
let(:fun) { SymEngine::Function('f') }
29+
context 'UndefFunction' do
30+
subject { fun }
31+
it { is_expected.to be_a SymEngine::UndefFunction }
32+
end
33+
context 'using call method for UndefFunction' do
34+
subject { fun.(x, y, z) }
35+
it { is_expected.to be_a SymEngine::FunctionSymbol }
36+
end
37+
end
38+
end
39+
40+

0 commit comments

Comments
 (0)