diff --git a/__mocks__/typedData/example_enumNested.json b/__mocks__/typedData/example_enumNested.json new file mode 100644 index 000000000..6ee6a75bd --- /dev/null +++ b/__mocks__/typedData/example_enumNested.json @@ -0,0 +1,33 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [{ "name": "someEnum", "type": "enum", "contains": "EnumA" }], + "EnumA": [ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(u128,StructA)" } + ], + "StructA": [{ "name": "nestedEnum", "type": "enum", "contains": "EnumB" }], + "EnumB": [ + { "name": "Variant A", "type": "()" }, + { "name": "Variant B", "type": "(StructB*)" } + ], + "StructB": [{ "name": "flag", "type": "bool" }] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "someEnum": { + "Variant 2": [2, { "nestedEnum": { "Variant B": [[{ "flag": true }, { "flag": false }]] } }] + } + } +} diff --git a/__tests__/utils/typedData.test.ts b/__tests__/utils/typedData.test.ts index 945f12075..c21ea251d 100644 --- a/__tests__/utils/typedData.test.ts +++ b/__tests__/utils/typedData.test.ts @@ -3,6 +3,7 @@ import * as starkCurve from '@scure/starknet'; import typedDataExample from '../../__mocks__/typedData/baseExample.json'; import exampleBaseTypes from '../../__mocks__/typedData/example_baseTypes.json'; import exampleEnum from '../../__mocks__/typedData/example_enum.json'; +import exampleEnumNested from '../../__mocks__/typedData/example_enumNested.json'; import examplePresetTypes from '../../__mocks__/typedData/example_presetTypes.json'; import typedDataStructArrayExample from '../../__mocks__/typedData/mail_StructArray.json'; import typedDataSessionExample from '../../__mocks__/typedData/session_MerkleTree.json'; @@ -64,7 +65,11 @@ describe('typedData', () => { ); encoded = encodeType(exampleEnum.types, 'Example', TypedDataRevision.ACTIVE); expect(encoded).toMatchInlineSnapshot( - `"\\"Example\\"(\\"someEnum1\\":\\"EnumA\\",\\"someEnum2\\":\\"EnumB\\")\\"EnumA\\"(\\"Variant 1\\":(),\\"Variant 2\\":(\\"u128\\",\\"u128*\\"),\\"Variant 3\\":(\\"u128\\"))\\"EnumB\\"(\\"Variant 1\\":(),\\"Variant 2\\":(\\"u128\\"))"` + `"\\"Example\\"(\\"someEnum1\\":\\"EnumA\\",\\"someEnum2\\":\\"EnumB\\")\\"EnumA\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\",\\"u128*\\"),\\"Variant 3\\"(\\"u128\\"))\\"EnumB\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\"))"` + ); + encoded = encodeType(exampleEnumNested.types, 'Example', TypedDataRevision.ACTIVE); + expect(encoded).toMatchInlineSnapshot( + `"\\"Example\\"(\\"someEnum\\":\\"EnumA\\")\\"EnumA\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\",\\"StructA\\"))\\"EnumB\\"(\\"Variant A\\"(),\\"Variant B\\"(\\"StructB*\\"))\\"StructA\\"(\\"nestedEnum\\":\\"EnumB\\")\\"StructB\\"(\\"flag\\":\\"bool\\")"` ); }); @@ -104,7 +109,11 @@ describe('typedData', () => { ); typeHash = getTypeHash(exampleEnum.types, 'Example', TypedDataRevision.ACTIVE); expect(typeHash).toMatchInlineSnapshot( - `"0x8eb4aeac64b707f3e843284c4258df6df1f0f7fd38dcffdd8a153a495cd351"` + `"0x393bf83422ca8626a2932696cfa0acb19dcad6de2fe84a2dd2ca7607ea5329a"` + ); + typeHash = getTypeHash(exampleEnumNested.types, 'Example', TypedDataRevision.ACTIVE); + expect(typeHash).toMatchInlineSnapshot( + `"0x267f739fd83d30528a0fafb23df33b6c35ca0a5adbcfb32152721478fa9d0ce"` ); }); @@ -326,7 +335,12 @@ describe('typedData', () => { messageHash = getMessageHash(exampleEnum, exampleAddress); expect(messageHash).toMatchInlineSnapshot( - `"0x6e61abaf480b1370bbf231f54e298c5f4872f40a6d2dd409ff30accee5bbd1e"` + `"0x150a589bb56a4fbf4ee01f52e44fd5adde6af94c02b37e383413fed185321a2"` + ); + + messageHash = getMessageHash(exampleEnumNested, exampleAddress); + expect(messageHash).toMatchInlineSnapshot( + `"0x6e70eb4ef625dda451094716eee7f31fa81ca0ba99d390885e9c7b0d64cd22"` ); expect(spyPedersen).not.toHaveBeenCalled(); diff --git a/src/utils/typedData.ts b/src/utils/typedData.ts index 1d462283a..d16f34257 100644 --- a/src/utils/typedData.ts +++ b/src/utils/typedData.ts @@ -167,36 +167,46 @@ export function getDependencies( contains: string = '', revision: Revision = Revision.LEGACY ): string[] { + let dependencyTypes: string[] = [type]; + // Include pointers (struct arrays) if (type[type.length - 1] === '*') { - type = type.slice(0, -1); + dependencyTypes = [type.slice(0, -1)]; } else if (revision === Revision.ACTIVE) { // enum base if (type === 'enum') { - type = contains; + dependencyTypes = [contains]; } // enum element types else if (type.match(/^\(.*\)$/)) { - type = type.slice(1, -1); + dependencyTypes = type + .slice(1, -1) + .split(',') + .map((depType) => (depType[depType.length - 1] === '*' ? depType.slice(0, -1) : depType)); } } - if (dependencies.includes(type) || !types[type]) { - return dependencies; - } - - return [ - type, - ...(types[type] as StarknetEnumType[]).reduce( - (previous, t) => [ - ...previous, - ...getDependencies(types, t.type, previous, t.contains, revision).filter( - (dependency) => !previous.includes(dependency) - ), + return dependencyTypes + .filter((t) => !dependencies.includes(t) && types[t]) + .reduce( + // This comment prevents prettier from rolling everything here into a single line. + (p, depType) => [ + ...p, + ...[ + depType, + ...(types[depType] as StarknetEnumType[]).reduce( + (previous, t) => [ + ...previous, + ...getDependencies(types, t.type, previous, t.contains, revision).filter( + (dependency) => !previous.includes(dependency) + ), + ], + [] + ), + ].filter((dependency) => !p.includes(dependency)), ], [] - ), - ]; + ); } function getMerkleTreeType(types: TypedData['types'], ctx: Context) { @@ -266,8 +276,8 @@ export function encodeType( .split(',') .map((e) => (e ? esc(e) : e)) .join(',')})` - : esc(targetType); - return `${esc(t.name)}:${typeString}`; + : `:${esc(targetType)}`; + return `${esc(t.name)}${typeString}`; }); return `${esc(dependency)}(${dependencyElements})`; }) @@ -357,11 +367,13 @@ export function encodeValue( if (revision === Revision.ACTIVE) { const [variantKey, variantData] = Object.entries(data as TypedData['message'])[0]; - const parentType = types[ctx.parent as string].find((t) => t.name === ctx.key); - const enumType = types[(parentType as StarknetEnumType).contains]; + const parentType = types[ctx.parent as string].find((t) => t.name === ctx.key)!; + const enumName = (parentType as StarknetEnumType).contains; + const enumType = types[enumName]; const variantType = enumType.find((t) => t.name === variantKey) as StarknetType; const variantIndex = enumType.indexOf(variantType); + const typeHash = getTypeHash(types, enumName, revision); const encodedSubtypes = variantType.type .slice(1, -1) .split(',') @@ -372,7 +384,7 @@ export function encodeValue( }); return [ type, - revisionConfiguration[revision].hashMethod([variantIndex, ...encodedSubtypes]), + revisionConfiguration[revision].hashMethod([typeHash, variantIndex, ...encodedSubtypes]), ]; } // else fall through to default return [type, getHex(data as string)];