1+ using Microsoft . VisualStudio . TestTools . UnitTesting ;
2+ using System . Linq ;
3+ using Tensorflow ;
4+ using Tensorflow . Keras . Engine ;
5+ using static Tensorflow . Binding ;
6+ using static Tensorflow . KerasApi ;
7+ using Tensorflow . NumPy ;
8+
9+ namespace TensorFlowNET . Keras . UnitTest ;
10+
11+ [ TestClass ]
12+ public class GradientTest
13+ {
14+ public Model get_actor ( int num_states )
15+ {
16+ var inputs = keras . layers . Input ( shape : num_states ) ;
17+ var outputs = keras . layers . Dense ( 1 , activation : keras . activations . Tanh ) . Apply ( inputs ) ;
18+
19+ Model model = keras . Model ( inputs , outputs ) ;
20+
21+ return model ;
22+ }
23+
24+ public Model get_critic ( int num_states , int num_actions )
25+ {
26+ // State as input
27+ var state_input = keras . layers . Input ( shape : num_states ) ;
28+
29+ // Action as input
30+ var action_input = keras . layers . Input ( shape : num_actions ) ;
31+
32+ var concat = keras . layers . Concatenate ( axis : 1 ) . Apply ( new Tensors ( state_input , action_input ) ) ;
33+
34+ var outputs = keras . layers . Dense ( 1 ) . Apply ( concat ) ;
35+
36+ Model model = keras . Model ( new Tensors ( state_input , action_input ) , outputs ) ;
37+ model . summary ( ) ;
38+
39+ return model ;
40+ }
41+
42+ [ TestMethod ]
43+ public void GetGradient_Test ( )
44+ {
45+ var numStates = 3 ;
46+ var numActions = 1 ;
47+ var batchSize = 64 ;
48+ var gamma = 0.99f ;
49+
50+ var target_actor_model = get_actor ( numStates ) ;
51+ var target_critic_model = get_critic ( numStates , numActions ) ;
52+ var critic_model = get_critic ( numStates , numActions ) ;
53+
54+ Tensor state_batch = tf . convert_to_tensor ( np . zeros ( ( batchSize , numStates ) ) , TF_DataType . TF_FLOAT ) ;
55+ Tensor action_batch = tf . convert_to_tensor ( np . zeros ( ( batchSize , numActions ) ) , TF_DataType . TF_FLOAT ) ;
56+ Tensor reward_batch = tf . convert_to_tensor ( np . zeros ( ( batchSize , 1 ) ) , TF_DataType . TF_FLOAT ) ;
57+ Tensor next_state_batch = tf . convert_to_tensor ( np . zeros ( ( batchSize , numStates ) ) , TF_DataType . TF_FLOAT ) ;
58+
59+ using ( var tape = tf . GradientTape ( ) )
60+ {
61+ var target_actions = target_actor_model . Apply ( next_state_batch , training : true ) ;
62+ var target_critic_value = target_critic_model . Apply ( new Tensors ( new Tensor [ ] { next_state_batch , target_actions } ) , training : true ) ;
63+
64+ var y = reward_batch + tf . multiply ( gamma , target_critic_value ) ;
65+
66+ var critic_value = critic_model . Apply ( new Tensors ( new Tensor [ ] { state_batch , action_batch } ) , training : true ) ;
67+
68+ var critic_loss = math_ops . reduce_mean ( math_ops . square ( y - critic_value ) ) ;
69+
70+ var critic_grad = tape . gradient ( critic_loss , critic_model . TrainableVariables ) ;
71+
72+ Assert . IsNotNull ( critic_grad ) ;
73+ Assert . IsNotNull ( critic_grad . First ( ) ) ;
74+ }
75+ }
76+ }
0 commit comments