@@ -170,36 +170,33 @@ MutableTerm::compare(const MutableTerm &other, RewriteContext &ctx) const {
170170 return shortlexCompare (begin (), end (), other.begin (), other.end (), ctx);
171171}
172172
173- // / Replace the subterm in the range [from,to) with \p rhs. The subrange must
174- // / be part of this term itself.
175- // /
176- // / Note that \p rhs must precede [from,to) in the linear order on terms.
173+ // / Replace the subterm in the range [from,to) of this term with \p rhs.
177174void MutableTerm::rewriteSubTerm (Symbol *from, Symbol *to, Term rhs) {
178175 auto oldSize = size ();
179176 unsigned lhsLength = (unsigned )(to - from);
180- assert (rhs. size () <= lhsLength);
181-
182- // Overwrite the occurrence of the left hand side with the
183- // right hand side.
184- auto newIter = std::copy (rhs. begin (), rhs. end (), from);
185-
186- // If the right hand side is shorter than the left hand side,
187- // then newIter will point to a location before oldIter, eg
188- // if this term is 'T.A.B.C', lhs is 'A.B' and rhs is 'X',
189- // then we now have:
190- //
191- // T.X .C
192- // ^--- oldIter
193- // ^--- newIter
194- //
195- // Shift everything over to close the gap (by one location,
196- // in this case).
197- if (newIter != to) {
198- auto newEnd = std::copy (to, end (), newIter );
199-
200- // Now, we've moved the gap to the end of the term; close
201- // it by shortening the term.
202- Symbols.erase (newEnd, end ());
177+
178+ if (lhsLength == rhs. size ()) {
179+ // Copy the RHS to the LHS.
180+ auto newTo = std::copy (rhs. begin (), rhs. end (), from);
181+
182+ // The RHS has the same length as the LHS, so we're done.
183+ assert (newTo == to);
184+ ( void ) newTo;
185+ } else if (lhsLength > rhs. size ()) {
186+ // Copy the RHS to the LHS.
187+ auto newTo = std::copy (rhs. begin (), rhs. end (), from);
188+
189+ // Shorten the term.
190+ Symbols. erase (newTo, to);
191+ } else {
192+ assert (lhsLength < rhs. size ());
193+
194+ // Copy the LHS-sized prefix of RHS to the LHS.
195+ auto newTo = std::copy (rhs. begin (), rhs. begin () + lhsLength, from );
196+ assert (newTo == to);
197+
198+ // Insert the remainder of the RHS term.
199+ Symbols.insert (to, rhs. begin () + lhsLength, rhs. end ());
203200 }
204201
205202 assert (size () == oldSize - lhsLength + rhs.size ());
0 commit comments