@@ -28,6 +28,7 @@ import {
2828 TupleAssignmentNode ,
2929 NullaryOpNode ,
3030 SliceNode ,
31+ IntLiteralNode ,
3132} from '../ast/AST.js' ;
3233import AstTraversal from '../ast/AstTraversal.js' ;
3334import {
@@ -57,17 +58,13 @@ export default class TypeCheckTraversal extends AstTraversal {
5758 if ( ! ( node . tuple instanceof BinaryOpNode ) || node . tuple . operator !== BinaryOperator . SPLIT ) {
5859 throw new TupleAssignmentError ( node . tuple ) ;
5960 }
60- const tupleType = node . tuple . left . type ;
61- for ( const variable of [ node . var1 , node . var2 ] ) {
62- if ( ! implicitlyCastable ( tupleType , variable . type ) ) {
63- // Ignore if both are of type byte. problem: bytes16 can be typed to bytes32
64- if ( tupleType instanceof BytesType && variable . type instanceof BytesType ) {
65- return node ;
66- }
67- throw new AssignTypeError (
68- new VariableDefinitionNode ( variable . type , [ ] , variable . name , node . tuple ) ,
69- ) ;
70- }
61+
62+ const assignmentType = new TupleType ( node . left . type , node . right . type ) ;
63+
64+ if ( ! implicitlyCastable ( node . tuple . type , assignmentType ) ) {
65+ throw new AssignTypeError (
66+ new VariableDefinitionNode ( assignmentType , [ ] , node . left . name , node . tuple ) ,
67+ ) ;
7168 }
7269 return node ;
7370 }
@@ -240,12 +237,7 @@ export default class TypeCheckTraversal extends AstTraversal {
240237 case BinaryOperator . SPLIT :
241238 expectAnyOfTypes ( node , node . left . type , [ new BytesType ( ) , PrimitiveType . STRING ] ) ;
242239 expectInt ( node , node . right . type ) ;
243-
244- // Result of split are two unbounded bytes types (could be improved to do type inference)
245- node . type = new TupleType (
246- node . left . type instanceof BytesType ? new BytesType ( ) : PrimitiveType . STRING ,
247- node . left . type instanceof BytesType ? new BytesType ( ) : PrimitiveType . STRING ,
248- ) ;
240+ node . type = inferTupleType ( node ) ;
249241 return node ;
250242 default :
251243 return node ;
@@ -346,9 +338,7 @@ export default class TypeCheckTraversal extends AstTraversal {
346338type ExpectedNode = BinaryOpNode | UnaryOpNode | TimeOpNode | TupleIndexOpNode | SliceNode ;
347339function expectAnyOfTypes ( node : ExpectedNode , actual ?: Type , expectedTypes ?: Type [ ] ) : void {
348340 if ( ! expectedTypes || expectedTypes . length === 0 ) return ;
349- if ( expectedTypes . find ( ( expected ) => implicitlyCastable ( actual , expected ) ) ) {
350- return ;
351- }
341+ if ( expectedTypes . find ( ( expected ) => implicitlyCastable ( actual , expected ) ) ) return ;
352342
353343 throw new UnsupportedTypeError ( node , actual , expectedTypes [ 0 ] ) ;
354344}
@@ -392,3 +382,37 @@ function expectParameters(node: NodeWithParameters, actual: Type[], expected: Ty
392382 throw new InvalidParameterTypeError ( node , actual , expected ) ;
393383 }
394384}
385+
386+ // We only call this function for the split operator, so we assume that the node.op is SPLIT
387+ function inferTupleType ( node : BinaryOpNode ) : Type {
388+ // string.split() -> string, string
389+ if ( node . left . type === PrimitiveType . STRING ) {
390+ return new TupleType ( PrimitiveType . STRING , PrimitiveType . STRING ) ;
391+ }
392+
393+ // If the expression is not a bytes type, then it must be a different compatible type (e.g. sig/pubkey)
394+ // We treat this as an unbounded bytes type for the purposes of splitting
395+ const expressionType = node . left . type instanceof BytesType ? node . left . type : new BytesType ( ) ;
396+
397+ // bytes.split(variable) -> bytes, bytes
398+ if ( ! ( node . right instanceof IntLiteralNode ) ) {
399+ return new TupleType ( new BytesType ( ) , new BytesType ( ) ) ;
400+ }
401+
402+ const splitIndex = Number ( node . right . value ) ;
403+
404+ // bytes.split(NumberLiteral) -> bytes(NumberLiteral), bytes
405+ if ( expressionType . bound === undefined ) {
406+ return new TupleType ( new BytesType ( splitIndex ) , new BytesType ( ) ) ;
407+ }
408+
409+ if ( splitIndex > expressionType . bound ) {
410+ throw new IndexOutOfBoundsError ( node ) ;
411+ }
412+
413+ // bytesX.split(NumberLiteral) -> bytes(NumberLiteral), bytes(X - NumberLiteral)
414+ return new TupleType (
415+ new BytesType ( splitIndex ) ,
416+ new BytesType ( expressionType . bound ! - splitIndex ) ,
417+ ) ;
418+ }
0 commit comments