@@ -119,13 +119,69 @@ def test_retry_on_remote_disconnected(self):
119119
120120 def test_flush_stats_with_tags (self ):
121121 lambda_stats = ThreadStatsWriter (True )
122+ original_constant_tags = lambda_stats .thread_stats .constant_tags .copy ()
122123 tags = ["tag1:value1" , "tag2:value2" ]
123- lambda_stats .flush (tags )
124- self .mock_threadstats_flush_distributions .assert_called_once_with (
125- lambda_stats .thread_stats ._get_aggregate_metrics_and_dists (float ("inf" ))[1 ]
126- )
127- for tag in tags :
128- self .assertTrue (tag in lambda_stats .thread_stats .constant_tags )
124+
125+ # Add a metric to be flushed
126+ lambda_stats .distribution ("test.metric" , 1 , tags = ["metric:tag" ])
127+
128+ with patch .object (
129+ lambda_stats .thread_stats .reporter , "flush_distributions"
130+ ) as mock_flush_distributions :
131+ lambda_stats .flush (tags )
132+ mock_flush_distributions .assert_called_once ()
133+ # Verify that after flush, constant_tags is reset to original
134+ self .assertEqual (
135+ lambda_stats .thread_stats .constant_tags , original_constant_tags
136+ )
137+
138+ def test_flush_temp_constant_tags (self ):
139+ lambda_stats = ThreadStatsWriter (flush_in_thread = True )
140+ lambda_stats .thread_stats .constant_tags = ["initial:tag" ]
141+ original_constant_tags = lambda_stats .thread_stats .constant_tags .copy ()
142+
143+ lambda_stats .distribution ("test.metric" , 1 , tags = ["metric:tag" ])
144+ flush_tags = ["flush:tag1" , "flush:tag2" ]
145+
146+ with patch .object (
147+ lambda_stats .thread_stats .reporter , "flush_distributions"
148+ ) as mock_flush_distributions :
149+ lambda_stats .flush (tags = flush_tags )
150+ mock_flush_distributions .assert_called_once ()
151+ flushed_dists = mock_flush_distributions .call_args [0 ][0 ]
152+
153+ # Expected tags: original constant_tags + flush_tags + metric tags
154+ expected_tags = original_constant_tags + flush_tags + ["metric:tag" ]
155+
156+ # Verify the tags on the metric
157+ self .assertEqual (len (flushed_dists ), 1 )
158+ metric = flushed_dists [0 ]
159+ self .assertEqual (sorted (metric ["tags" ]), sorted (expected_tags ))
160+
161+ # Verify that constant_tags is reset after flush
162+ self .assertEqual (
163+ lambda_stats .thread_stats .constant_tags , original_constant_tags
164+ )
165+
166+ # Repeat to ensure tags do not accumulate over multiple flushes
167+ new_flush_tags = ["flush:tag3" ]
168+ lambda_stats .distribution ("test.metric2" , 2 , tags = ["metric2:tag" ])
169+
170+ with patch .object (
171+ lambda_stats .thread_stats .reporter , "flush_distributions"
172+ ) as mock_flush_distributions :
173+ lambda_stats .flush (tags = new_flush_tags )
174+ mock_flush_distributions .assert_called_once ()
175+ flushed_dists = mock_flush_distributions .call_args [0 ][0 ]
176+ # Expected tags for the new metric
177+ expected_tags = original_constant_tags + new_flush_tags + ["metric2:tag" ]
178+
179+ self .assertEqual (len (flushed_dists ), 1 )
180+ metric = flushed_dists [0 ]
181+ self .assertEqual (sorted (metric ["tags" ]), sorted (expected_tags ))
182+ self .assertEqual (
183+ lambda_stats .thread_stats .constant_tags , original_constant_tags
184+ )
129185
130186 def test_flush_stats_without_context (self ):
131187 flush_stats (lambda_context = None )
0 commit comments