@@ -7705,6 +7705,41 @@ def testEmbeddingVariableForSharedEmbeddingColumnsMultiCol(self):
77057705 for j in range (3 ):
77067706 self .assertAlmostEqual (emb_r [i ][j ], emb_right [i ][j ])
77077707
7708+ def testEmbeddingVariableForSharedPartitionedEmbeddingColumnsMultiCol (self ):
7709+ columns_list = []
7710+ columns_list .append (fc .categorical_column_with_embedding ("col_emb" , dtype = dtypes .string ))
7711+ columns_list .append (fc .categorical_column_with_embedding ("col_emb2" , dtype = dtypes .string ))
7712+ W = fc .shared_embedding_columns (columns_list ,
7713+ dimension = 3 ,
7714+ initializer = init_ops .ones_initializer (dtypes .float32 ),
7715+ shared_embedding_collection_name = "xxxxx_shared" )
7716+
7717+ ids = {}
7718+ ids ["col_emb" ] = sparse_tensor .SparseTensor (indices = [[0 ,0 ],[1 ,0 ],[2 ,0 ],[3 ,0 ],[4 ,0 ]], values = ["aaaa" ,"bbbbb" ,"ccc" ,"4nn" ,"5b" ], dense_shape = [5 , 5 ])
7719+ ids ["col_emb2" ] = sparse_tensor .SparseTensor (indices = [[0 ,0 ],[1 ,0 ],[2 ,0 ],[3 ,0 ],[4 ,0 ]], values = ["aaaa" ,"bbbbb" ,"ccc" ,"4nn" ,"5b" ], dense_shape = [5 , 5 ])
7720+ with variable_scope .variable_scope ("scope" ,partitioner = partitioned_variables .fixed_size_partitioner (4 )):
7721+ emb = fc_old .input_layer (ids , W )
7722+ fun = math_ops .multiply (emb , 2.0 , name = 'multiply' )
7723+ loss = math_ops .reduce_sum (fun , name = 'reduce_sum' )
7724+ opt = ftrl .FtrlOptimizer (0.1 , l1_regularization_strength = 2.0 , l2_regularization_strength = 0.00001 )
7725+ g_v = opt .compute_gradients (loss )
7726+ train_op = opt .apply_gradients (g_v )
7727+ init = variables_lib .global_variables_initializer ()
7728+
7729+ with self .test_session () as sess :
7730+ sess .run (init )
7731+ sess .run ([emb , train_op ,loss ])
7732+ sess .run ([emb , train_op ,loss ])
7733+ emb_r , _ , _ = sess .run ([emb , train_op ,loss ])
7734+ emb_right = [[0.7221214 , 0.7221214 , 0.7221214 ],
7735+ [0.7221214 , 0.7221214 , 0.7221214 ],
7736+ [0.7221214 , 0.7221214 , 0.7221214 ],
7737+ [0.7221214 , 0.7221214 , 0.7221214 ],
7738+ [0.7221214 , 0.7221214 , 0.7221214 ]]
7739+ for i in range (5 ):
7740+ for j in range (3 ):
7741+ self .assertAlmostEqual (emb_r [i ][j ], emb_right [i ][j ])
7742+
77087743 @test_util .run_deprecated_v1
77097744 def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum (self ):
77107745 columns_list = []
0 commit comments