@@ -11,7 +11,7 @@ PyObject *NA_OBJ = NULL;
1111 * Internal helper to create new instances
1212 */
1313PyObject *
14- new_stringdtype_instance (PyObject * na_object )
14+ new_stringdtype_instance (PyObject * na_object , int coerce )
1515{
1616 PyObject * new =
1717 PyArrayDescr_Type .tp_new ((PyTypeObject * )& StringDType , NULL , NULL );
@@ -22,6 +22,7 @@ new_stringdtype_instance(PyObject *na_object)
2222
2323 Py_INCREF (na_object );
2424 ((StringDTypeObject * )new )-> na_object = na_object ;
25+ ((StringDTypeObject * )new )-> coerce = coerce ;
2526
2627 PyArray_Descr * base = (PyArray_Descr * )new ;
2728 base -> elsize = sizeof (ss );
@@ -67,23 +68,32 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
6768}
6869
6970// returns a new reference to the string "value" of
70- // `scalar`. If scalar is not already a string, __str__
71- // is called to convert it to a string. If the scalar
72- // is the na_object for the dtype class, return
73- // a new reference to the na_object.
71+ // `scalar`. If scalar is not already a string and
72+ // coerce is nonzero, __str__ is called to convert it
73+ // to a string. If coerce is zero, raises an error for
74+ // non-string or non-NA input. If the scalar is the
75+ // na_object for the dtype class, return a new
76+ // reference to the na_object.
7477
7578static PyObject *
76- get_value (PyObject * scalar )
79+ get_value (PyObject * scalar , int coerce )
7780{
7881 PyTypeObject * scalar_type = Py_TYPE (scalar );
7982 if (!((scalar_type == & PyUnicode_Type ) ||
8083 (scalar_type == StringScalar_Type ))) {
81- // attempt to coerce to str
82- scalar = PyObject_Str (scalar );
83- if (scalar == NULL ) {
84- // __str__ raised an exception
84+ if (coerce == 0 ) {
85+ PyErr_SetString (PyExc_ValueError ,
86+ "StringDType only allows string data" );
8587 return NULL ;
8688 }
89+ else {
90+ // attempt to coerce to str
91+ scalar = PyObject_Str (scalar );
92+ if (scalar == NULL ) {
93+ // __str__ raised an exception
94+ return NULL ;
95+ }
96+ }
8797 }
8898 // attempt to decode as UTF8
8999 return PyUnicode_AsUTF8String (scalar );
@@ -93,12 +103,12 @@ static PyArray_Descr *
93103string_discover_descriptor_from_pyobject (PyTypeObject * NPY_UNUSED (cls ),
94104 PyObject * obj )
95105{
96- PyObject * val = get_value (obj );
106+ PyObject * val = get_value (obj , 1 );
97107 if (val == NULL ) {
98108 return NULL ;
99109 }
100110
101- PyArray_Descr * ret = (PyArray_Descr * )new_stringdtype_instance (NA_OBJ );
111+ PyArray_Descr * ret = (PyArray_Descr * )new_stringdtype_instance (NA_OBJ , 1 );
102112 if (ret == NULL ) {
103113 return NULL ;
104114 }
@@ -126,7 +136,7 @@ stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
126136 // so it already contains a NA value
127137 }
128138 else {
129- PyObject * val_obj = get_value (obj );
139+ PyObject * val_obj = get_value (obj , descr -> coerce );
130140
131141 if (val_obj == NULL ) {
132142 return -1 ;
@@ -334,21 +344,23 @@ static PyType_Slot StringDType_Slots[] = {
334344static PyObject *
335345stringdtype_new (PyTypeObject * NPY_UNUSED (cls ), PyObject * args , PyObject * kwds )
336346{
337- static char * kwargs_strs [] = {"size" , "na_object" , NULL };
347+ static char * kwargs_strs [] = {"size" , "na_object" , "coerce" , NULL };
338348
339349 long size = 0 ;
340350 PyObject * na_object = NULL ;
351+ int coerce = 1 ;
341352
342- if (!PyArg_ParseTupleAndKeywords (args , kwds , "|lO:StringDType" ,
343- kwargs_strs , & size , & na_object )) {
353+ if (!PyArg_ParseTupleAndKeywords (args , kwds , "|lOp:StringDType" ,
354+ kwargs_strs , & size , & na_object ,
355+ & coerce )) {
344356 return NULL ;
345357 }
346358
347359 if (na_object == NULL ) {
348360 na_object = NA_OBJ ;
349361 }
350362
351- PyObject * ret = new_stringdtype_instance (na_object );
363+ PyObject * ret = new_stringdtype_instance (na_object , coerce );
352364
353365 return ret ;
354366}
@@ -365,11 +377,18 @@ stringdtype_repr(StringDTypeObject *self)
365377 PyObject * ret = NULL ;
366378 // borrow reference
367379 PyObject * na_object = self -> na_object ;
380+ int coerce = self -> coerce ;
368381
369382 // TODO: handle non-default NA
370- if (na_object != NA_OBJ ) {
371- ret = PyUnicode_FromFormat ("StringDType(na_object=%R)" ,
372- self -> na_object );
383+ if (na_object != NA_OBJ && coerce == 0 ) {
384+ ret = PyUnicode_FromFormat ("StringDType(na_object=%R, coerce=False)" ,
385+ na_object );
386+ }
387+ else if (na_object != NA_OBJ ) {
388+ ret = PyUnicode_FromFormat ("StringDType(na_object=%R)" , na_object );
389+ }
390+ else if (coerce == 0 ) {
391+ ret = PyUnicode_FromFormat ("StringDType(coerce=False)" , coerce );
373392 }
374393 else {
375394 ret = PyUnicode_FromString ("StringDType()" );
@@ -378,7 +397,7 @@ stringdtype_repr(StringDTypeObject *self)
378397 return ret ;
379398}
380399
381- static int PICKLE_VERSION = 1 ;
400+ static int PICKLE_VERSION = 2 ;
382401
383402static PyObject *
384403stringdtype__reduce__ (StringDTypeObject * self )
@@ -405,9 +424,9 @@ stringdtype__reduce__(StringDTypeObject *self)
405424
406425 PyTuple_SET_ITEM (ret , 0 , obj );
407426
408- PyTuple_SET_ITEM (
409- ret , 1 ,
410- Py_BuildValue ( "(NO)" , PyLong_FromLong ( 0 ), self -> na_object ));
427+ PyTuple_SET_ITEM (ret , 1 ,
428+ Py_BuildValue ( "(NOi)" , PyLong_FromLong ( 0 ) ,
429+ self -> na_object , self -> coerce ));
411430
412431 PyTuple_SET_ITEM (ret , 2 , Py_BuildValue ("(l)" , PICKLE_VERSION ));
413432
@@ -456,9 +475,39 @@ static PyMemberDef StringDType_members[] = {
456475 {"na_object" , T_OBJECT_EX , offsetof(StringDTypeObject , na_object ),
457476 READONLY ,
458477 "The missing value object associated with the dtype instance" },
478+ {"coerce" , T_INT , offsetof(StringDTypeObject , coerce ), READONLY ,
479+ "Controls hether non-string values should be coerced to string" },
459480 {NULL , 0 , 0 , 0 , NULL },
460481};
461482
483+ static PyObject *
484+ StringDType_richcompare (PyObject * self , PyObject * other , int op )
485+ {
486+ if (!((op == Py_EQ ) || (op == Py_NE )) ||
487+ (Py_TYPE (other ) != Py_TYPE (self ))) {
488+ Py_INCREF (Py_NotImplemented );
489+ return Py_NotImplemented ;
490+ }
491+
492+ // we know both are instances of StringDType so this is safe
493+ StringDTypeObject * sself = (StringDTypeObject * )self ;
494+ StringDTypeObject * sother = (StringDTypeObject * )other ;
495+
496+ int eq = (sself -> na_object == sother -> na_object ) &&
497+ (sself -> coerce == sother -> coerce );
498+
499+ PyObject * ret = Py_NotImplemented ;
500+ if ((op == Py_EQ && eq ) || (op == Py_NE && !eq )) {
501+ ret = Py_True ;
502+ }
503+ else {
504+ ret = Py_False ;
505+ }
506+
507+ Py_INCREF (ret );
508+ return ret ;
509+ }
510+
462511/*
463512 * This is the basic things that you need to create a Python Type/Class in C.
464513 * However, there is a slight difference here because we create a
@@ -476,6 +525,7 @@ StringDType_type StringDType = {
476525 .tp_str = (reprfunc )stringdtype_repr ,
477526 .tp_methods = StringDType_methods ,
478527 .tp_members = StringDType_members ,
528+ .tp_richcompare = StringDType_richcompare ,
479529 }}},
480530 /* rest, filled in during DTypeMeta initialization */
481531};
0 commit comments