@@ -8,76 +8,78 @@ namespace Tensorflow.NumPy
88{
99 public partial class NDArray
1010 {
11- public NDArray this [ int index ]
11+ public NDArray this [ params int [ ] index ]
1212 {
13- get
13+ get => _tensor [ index . Select ( x => new Slice
1414 {
15- return _tensor [ index ] ;
16- }
15+ Start = x ,
16+ Stop = x + 1 ,
17+ IsIndex = true
18+ } ) . ToArray ( ) ] ;
1719
18- set
20+ set => SetData ( index . Select ( x => new Slice
1921 {
22+ Start = x ,
23+ Stop = x + 1 ,
24+ IsIndex = true
25+ } ) , value ) ;
26+ }
2027
21- }
28+ public NDArray this [ params Slice [ ] slices ]
29+ {
30+ get => _tensor [ slices ] ;
31+ set => SetData ( slices , value ) ;
2232 }
2333
24- public NDArray this [ params int [ ] index ]
34+ public NDArray this [ NDArray mask ]
2535 {
2636 get
2737 {
28- return _tensor [ index . Select ( x => new Slice ( x , x + 1 ) ) . ToArray ( ) ] ;
38+ throw new NotImplementedException ( "" ) ;
2939 }
3040
3141 set
3242 {
33- var offset = ShapeHelper . GetOffset ( shape , index ) ;
34- unsafe
35- {
36- if ( dtype == TF_DataType . TF_BOOL )
37- * ( ( bool * ) data + offset ) = value;
38- else if ( dtype == TF_DataType . TF_UINT8 )
39- * ( ( byte * ) data + offset ) = value;
40- else if ( dtype == TF_DataType . TF_INT32 )
41- * ( ( int * ) data + offset ) = value;
42- else if ( dtype == TF_DataType . TF_INT64 )
43- * ( ( long * ) data + offset ) = value;
44- else if ( dtype == TF_DataType . TF_FLOAT )
45- * ( ( float * ) data + offset ) = value;
46- else if ( dtype == TF_DataType . TF_DOUBLE )
47- * ( ( double * ) data + offset ) = value;
48- }
43+ throw new NotImplementedException ( "" ) ;
4944 }
5045 }
5146
52- public NDArray this [ params Slice [ ] slices]
47+ void SetData ( IEnumerable < Slice > slices , NDArray array )
48+ => SetData ( slices , array , - 1 , slices . Select ( x => 0 ) . ToArray ( ) ) ;
49+
50+ void SetData ( IEnumerable < Slice > slices , NDArray array , int currentNDim , int [ ] indices )
5351 {
54- get
55- {
56- return _tensor[ slices ] ;
57- }
52+ if ( dtype != array . dtype )
53+ throw new ArrayTypeMismatchException ( $ "Required dtype { dtype } but { array . dtype } is assigned.") ;
5854
59- set
55+ if ( ! slices . Any ( ) )
56+ return ;
57+
58+ var slice = slices . First ( ) ;
59+
60+ if ( slices . Count ( ) == 1 )
6061 {
61- var pos = _tensor [ slices ] ;
62- var len = value . bytesize ;
62+
63+ if ( slice . Step != 1 )
64+ throw new NotImplementedException ( "" ) ;
65+
66+ indices [ indices . Length - 1 ] = slice . Start ?? 0 ;
67+ var offset = ( ulong ) ShapeHelper . GetOffset ( shape , indices ) ;
68+ var bytesize = array . bytesize ;
6369 unsafe
6470 {
65- System . Buffer . MemoryCopy ( value . data . ToPointer ( ) , pos . TensorDataPointer . ToPointer ( ) , len , len ) ;
71+ var dst = ( byte * ) data + offset * dtypesize ;
72+ System . Buffer . MemoryCopy ( array . data . ToPointer ( ) , dst , bytesize , bytesize ) ;
6673 }
67- // _tensor[slices].assign(constant_op.constant(value));
68- }
69- }
7074
71- public NDArray this[ NDArray mask ]
72- {
73- get
74- {
75- throw new NotImplementedException ( "" ) ;
75+ return ;
7676 }
7777
78- set
78+ currentNDim ++ ;
79+ for ( var i = slice . Start ?? 0 ; i < slice . Stop ; i ++ )
7980 {
80-
81+ indices [ currentNDim ] = i ;
82+ SetData ( slices . Skip ( 1 ) , array , currentNDim , indices ) ;
8183 }
8284 }
8385 }
0 commit comments