33using Tensorflow . Keras . ArgsDefinition ;
44using static Tensorflow . Binding ;
55using Tensorflow . Keras . Utils ;
6+ using Tensorflow . Util ;
7+ using Tensorflow . Framework ;
68
79namespace Tensorflow . Keras . Engine . DataAdapters
810{
@@ -24,6 +26,7 @@ public class DataHandler
2426 long _steps_per_execution_value ;
2527 int _initial_epoch => args . InitialEpoch ;
2628 int _epochs => args . Epochs ;
29+ NDArray _sample_weight => args . SampleWeight ;
2730 IVariableV1 _steps_per_execution ;
2831
2932 public DataHandler ( DataHandlerArgs args )
@@ -75,10 +78,75 @@ public DataHandler(DataHandlerArgs args)
7578 }
7679
7780 _dataset = _adapter . GetDataset ( ) ;
78- _inferred_steps = _infer_steps ( args . StepsPerEpoch , _dataset ) ;
7981 _current_step = 0 ;
8082 _step_increment = _steps_per_execution_value - 1 ;
8183 _insufficient_data = false ;
84+ _configure_dataset_and_inferred_steps ( args . X , args . ClassWeight ) ;
85+ }
86+
87+ void _configure_dataset_and_inferred_steps ( Tensors x , Dictionary < int , float > class_weight )
88+ {
89+ if ( _dataset == null )
90+ {
91+ _dataset = _adapter . GetDataset ( ) ;
92+ _inferred_steps = _infer_steps ( args . StepsPerEpoch , _dataset ) ;
93+ }
94+
95+ if ( class_weight != null )
96+ {
97+ _dataset = _dataset . map ( _make_class_weight_map_fn ( class_weight ) ) ;
98+ }
99+ _inferred_steps = _infer_steps ( args . StepsPerEpoch , _dataset ) ;
100+ }
101+
102+
103+ Func < Tensors , Tensors > _make_class_weight_map_fn ( Dictionary < int , float > class_weight )
104+ {
105+ var class_ids = class_weight . Keys . OrderBy ( key => key ) . ToList ( ) ;
106+ var expected_class_ids = range ( class_ids [ 0 ] , class_ids [ class_ids . Count - 1 ] + 1 ) ;
107+ if ( ! class_ids . SequenceEqual ( expected_class_ids ) )
108+ {
109+ throw new ValueError ( "Expected `class_weight` to be a dict with keys from 0 to one less " +
110+ $ "than the number of classes, found { class_weight } ") ;
111+ }
112+
113+ var class_weight_list = new List < float > ( ) ;
114+ foreach ( var class_id in class_ids )
115+ {
116+ class_weight_list . Add ( class_weight [ class_id ] ) ;
117+ }
118+ var class_weight_tensor = tf . convert_to_tensor ( class_weight_list . ToArray ( ) ) ;
119+
120+ Func < Tensors , Tensors > _class_weight_map_fn = ( Tensors data ) =>
121+ {
122+ var x = data [ 0 ] ;
123+ var y = data [ 1 ] ;
124+ var sw = _sample_weight == null ? null : ops . convert_to_tensor ( _sample_weight ) ;
125+
126+ if ( y . shape . rank > 2 )
127+ {
128+ throw new ValueError ( "`class_weight` not supported for 3+ dimensional targets." ) ;
129+ }
130+
131+ var y_classes = smart_module . smart_cond (
132+ y . shape . rank == 2 && y . shape [ 1 ] > 1 ,
133+ ( ) => math_ops . argmax ( y , dimension : 1 ) ,
134+ ( ) => math_ops . cast ( tf . reshape ( y , ( - 1 ) ) , TF_DataType . TF_INT64 ) ) ;
135+
136+ var cw = array_ops . gather ( class_weight_tensor , y_classes ) ;
137+ if ( sw != null )
138+ {
139+ cw = tf . cast ( cw , sw . dtype ) ;
140+ cw *= sw ;
141+ }
142+ else
143+ {
144+ sw = cw ;
145+ }
146+ return new Tensors { x , y , sw } ;
147+ } ;
148+
149+ return _class_weight_map_fn ;
82150 }
83151
84152 long _infer_steps ( int steps_per_epoch , IDatasetV2 dataset )
0 commit comments