@@ -8,16 +8,16 @@ namespace Tensorflow.NumPy
88{
99 public partial class NDArray
1010 {
11- public NDArray this [ params int [ ] index ]
11+ public NDArray this [ params int [ ] indices ]
1212 {
13- get => GetData ( index . Select ( x => new Slice
13+ get => GetData ( indices . Select ( x => new Slice
1414 {
1515 Start = x ,
1616 Stop = x + 1 ,
1717 IsIndex = true
1818 } ) ) ;
1919
20- set => SetData ( index . Select ( x =>
20+ set => SetData ( indices . Select ( x =>
2121 {
2222 if ( x < 0 )
2323 x = ( int ) dims [ 0 ] + x ;
@@ -57,12 +57,37 @@ public NDArray this[NDArray mask]
5757
5858 NDArray GetData ( IEnumerable < Slice > slices )
5959 {
60- var tensor = _tensor [ slices . ToArray ( ) ] ;
61- return new NDArray ( tensor ) ;
60+ if ( shape . IsScalar )
61+ return GetScalar ( ) ;
62+
63+ var tensor = base [ slices . ToArray ( ) ] ;
64+ if ( tensor . Handle == null )
65+ tensor = tf . defaultSession . eval ( tensor ) ;
66+ return new NDArray ( tensor . Handle ) ;
67+ }
68+
69+ unsafe T GetAtIndex < T > ( params int [ ] indices ) where T : unmanaged
70+ {
71+ var offset = ( ulong ) ShapeHelper . GetOffset ( shape , indices ) ;
72+ return * ( ( T * ) data + offset ) ;
73+ }
74+
75+ NDArray GetScalar ( )
76+ {
77+ var array = new NDArray ( Shape . Scalar , dtype : dtype ) ;
78+ unsafe
79+ {
80+ var src = ( byte * ) data + dtypesize ;
81+ System . Buffer . MemoryCopy ( src , array . buffer . ToPointer ( ) , bytesize , bytesize ) ;
82+ }
83+ return array ;
6284 }
6385
6486 NDArray GetData ( int [ ] indices , int axis = 0 )
6587 {
88+ if ( shape . IsScalar )
89+ return GetScalar ( ) ;
90+
6691 if ( axis == 0 )
6792 {
6893 var dims = shape . as_int_list ( ) ;
0 commit comments