Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ export type TableExpression =
| TableExpression.SubQuery
| TableExpression.CrossJoin
| TableExpression.QualifiedJoin
| TableExpression.FunctionCall

export namespace TableExpression {
export type Table = {
Expand All @@ -473,6 +474,19 @@ export namespace TableExpression {
return { kind: 'Table', table, as }
}

export type FunctionCall = {
kind: 'FunctionCall'
func: Expression.FunctionCall
as: string | null
}

export function createFunctionCall(
func: Expression.FunctionCall,
as: string | null
): FunctionCall {
return { kind: 'FunctionCall', func, as }
}

export type SubQuery = {
kind: 'SubQuery'
query: AST
Expand Down Expand Up @@ -537,6 +551,7 @@ export namespace TableExpression {
tableExpr: TableExpression,
handlers: {
table: (node: Table) => T
functionCall: (node: FunctionCall) => T
subQuery: (node: SubQuery) => T
crossJoin: (node: CrossJoin) => T
qualifiedJoin: (node: QualifiedJoin) => T
Expand All @@ -545,6 +560,8 @@ export namespace TableExpression {
switch (tableExpr.kind) {
case 'Table':
return handlers.table(tableExpr)
case 'FunctionCall':
return handlers.functionCall(tableExpr)
case 'SubQuery':
return handlers.subQuery(tableExpr)
case 'CrossJoin':
Expand Down
112 changes: 80 additions & 32 deletions src/infer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ function inferSelectBodyOutput(
function inferSelectListOutput(
client: SchemaClient,
outsideCTEs: VirtualTable[],
sourceColumns: SourceColumn[],
sourceColumns: (SourceColumn | VirtualField)[],
paramNullability: ParamNullability[],
conditions: Array<ast.Expression | null>,
selectList: ast.SelectListItem[]
Expand Down Expand Up @@ -350,7 +350,7 @@ function inferSelectListOutput(
function inferSelectListItemOutput(
client: SchemaClient,
outsideCTEs: VirtualTable[],
sourceColumns: SourceColumn[],
sourceColumns: (SourceColumn | VirtualField)[],
paramNullability: ParamNullability[],
nonNullExpressions: ast.Expression[],
selectListItem: ast.SelectListItem
Expand All @@ -368,7 +368,7 @@ function inferSelectListItemOutput(
),
Either.map((columns) =>
columns.map((column) => ({
name: column.columnName,
name: isSourceColumn(column) ? column.columnName : column.name,
nullability: column.nullability,
}))
)
Expand All @@ -384,7 +384,7 @@ function inferSelectListItemOutput(
),
Either.map((columns) =>
columns.map((column) => ({
name: column.columnName,
name: isSourceColumn(column) ? column.columnName : column.name,
nullability: column.nullability,
}))
)
Expand Down Expand Up @@ -416,19 +416,21 @@ type NonNullableColumn = { tableName: string | null; columnName: string }

function isColumnNonNullable(
nonNullableColumns: NonNullableColumn[],
sourceColumn: SourceColumn
sourceColumn: SourceColumn | VirtualField
): boolean {
return nonNullableColumns.some((nonNull) =>
nonNull.tableName
? sourceColumn.tableAlias === nonNull.tableName
: true && sourceColumn.columnName === nonNull.columnName
)
return isSourceColumn(sourceColumn)
? nonNullableColumns.some((nonNull) =>
nonNull.tableName
? sourceColumn.tableAlias === nonNull.tableName
: true && sourceColumn.columnName === nonNull.columnName
)
: !sourceColumn.nullability.nullable // TODO correct?
}

function applyExpressionNonNullability(
nonNullableExpressions: ast.Expression[],
sourceColumns: SourceColumn[]
): SourceColumn[] {
sourceColumns: (SourceColumn | VirtualField)[]
): (SourceColumn | VirtualField)[] {
const nonNullableColumns = pipe(
nonNullableExpressions,
R.map((expr) =>
Expand Down Expand Up @@ -464,7 +466,7 @@ function inferExpressionName(expression: ast.Expression): string {
function inferExpressionNullability(
client: SchemaClient,
outsideCTEs: VirtualTable[],
sourceColumns: SourceColumn[],
sourceColumns: (SourceColumn | VirtualField)[],
paramNullability: ParamNullability[],
nonNullExprs: ast.Expression[],
expression: ast.Expression
Expand All @@ -485,7 +487,13 @@ function inferExpressionNullability(
// have a NOT NULL constraint
tableColumnRef: ({ table, column }) =>
pipe(
InferM.fromEither(findSourceTableColumn(table, column, sourceColumns)),
InferM.fromEither(
findSourceTableColumn(
table,
column,
sourceColumns.filter(isSourceColumn)
)
),
InferM.map((column) => column.nullability)
),

Expand Down Expand Up @@ -1173,7 +1181,7 @@ function getSourceColumnsForTableExpr(
paramNullability: ParamNullability[],
tableExpr: ast.TableExpression | null,
setNullable = false
): InferM.InferM<SourceColumn[]> {
): InferM.InferM<(SourceColumn | VirtualField)[]> {
if (!tableExpr) {
return InferM.right([])
}
Expand All @@ -1182,6 +1190,22 @@ function getSourceColumnsForTableExpr(
ast.TableExpression.walk(tableExpr, {
table: ({ table, as }) =>
getSourceColumnsForTable(client, ctes, table, as),
functionCall: ({ func, as }): InferM.InferM<VirtualField[]> =>
InferM.map((nullability: FieldNullability) => [
{
name: as || inferExpressionName(func),
nullability,
},
])(
inferExpressionNullability(
client,
ctes,
[], // TODO?
paramNullability,
[], // TODO?
func
)
),
subQuery: ({ query, as }) =>
getSourceColumnsForSubQuery(client, ctes, paramNullability, query, as),
crossJoin: ({ left, right }) =>
Expand Down Expand Up @@ -1248,21 +1272,31 @@ function getSourceColumnsForSubQuery(
}

function setSourceColumnsAsNullable(
sourceColumns: SourceColumn[]
): SourceColumn[] {
return sourceColumns.map((col) => ({
...col,
nullability: { ...col.nullability, nullable: true },
}))
sourceColumns: (SourceColumn | VirtualField)[]
): (SourceColumn | VirtualField)[] {
return sourceColumns.map((col) =>
isSourceColumn(col)
? {
...col,
nullability: { ...col.nullability, nullable: true },
}
: {
...col,
nullability: { ...col.nullability, nullable: true },
}
)
}

function combineSourceColumns(
...sourceColumns: Array<InferM.InferM<SourceColumn[]>>
): InferM.InferM<SourceColumn[]> {
...sourceColumns: Array<InferM.InferM<(SourceColumn | VirtualField)[]>>
): InferM.InferM<(SourceColumn | VirtualField)[]> {
return pipe(
sourceColumns,
sequenceAIM,
InferM.map<SourceColumn[][], SourceColumn[]>(R.flatten)
InferM.map<
(SourceColumn | VirtualField)[][],
(SourceColumn | VirtualField)[]
>(R.flatten)
)
}

Expand All @@ -1272,11 +1306,19 @@ function isConstantExprOf(expectedValue: string, expr: ast.Expression) {
})
}

function isSourceColumn(c: SourceColumn | VirtualField): c is SourceColumn {
return !!(c as SourceColumn).tableAlias
}

function findNonHiddenSourceColumns(
sourceColumns: SourceColumn[]
): Either.Either<string, SourceColumn[]> {
sourceColumns: (SourceColumn | VirtualField)[]
): Either.Either<string, (SourceColumn | VirtualField)[]> {
return pipe(
sourceColumns.filter((col) => !col.hidden),
sourceColumns.filter(
(col) =>
(isSourceColumn(col) && !col.hidden) ||
!isSourceColumn(col) /* VirtualField's are never hidden */
),
Either.fromPredicate(
(result) => result.length > 0,
() => `No columns`
Expand All @@ -1286,12 +1328,14 @@ function findNonHiddenSourceColumns(

function findNonHiddenSourceTableColumns(
tableName: string,
sourceColumns: SourceColumn[]
sourceColumns: (SourceColumn | VirtualField)[]
): Either.Either<string, SourceColumn[]> {
return pipe(
findNonHiddenSourceColumns(sourceColumns),
Either.map((sourceColumns) =>
sourceColumns.filter((col) => col.tableAlias === tableName)
sourceColumns
.filter(isSourceColumn)
.filter((col) => col.tableAlias === tableName)
),
Either.chain((result) =>
result.length > 0
Expand All @@ -1317,10 +1361,14 @@ function findSourceTableColumn(

function findSourceColumn(
columnName: string,
sourceColumns: SourceColumn[]
): Either.Either<string, SourceColumn> {
sourceColumns: (SourceColumn | VirtualField)[]
): Either.Either<string, SourceColumn | VirtualField> {
return pipe(
sourceColumns.find((col) => col.columnName === columnName),
sourceColumns.find(
(col) =>
(isSourceColumn(col) && col.columnName === columnName) ||
(!isSourceColumn(col) && col.name === columnName)
),
Either.fromNullable(`Unknown column ${columnName}`)
)
}
Expand Down
18 changes: 18 additions & 0 deletions src/parser/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,24 @@ const tableExpression: Parser<TableExpression> = seq(
as
)((stmt, as) => TableExpression.createSubQuery(stmt, as))
),
attempt(
seq(
seq(
identifier,
optional(seq2(symbol('.'), identifier)),
functionArguments
)((ident, ident2, argList) =>
Expression.createFunctionCall(
ident2 ? ident : null,
ident2 ? ident2 : ident,
argList,
null,
null
)
),
optional(as)
)((fnCall, as) => TableExpression.createFunctionCall(fnCall, as))
),
table
),
many(oneOf(crossJoin, qualifiedJoin, naturalJoin))
Expand Down
33 changes: 28 additions & 5 deletions src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ export interface SchemaClient {
): TaskEither.TaskEither<string, Table>
getEnums(): Promise<Enum[]>
getArrayTypes(): Promise<ArrayType[]>
getFunction(
schemaName: string | null,
functionName: string
): TaskEither.TaskEither<string, SqlFunction>
functionNullSafety(
schemaName: string | null,
functionName: string
Expand All @@ -54,9 +58,7 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient {
tableName,
})
if (result.length === 0) {
return Either.left(
`No such table: ${fullTableName(schemaName, tableName)}`
)
return Either.left(`No such table: ${fullName(schemaName, tableName)}`)
}
return Either.right({
name: tableName,
Expand Down Expand Up @@ -86,6 +88,27 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient {
}))
)

// TODO: handle overloaded functions
const getFunction = (
schemaName: string | null,
functionName: string
): TaskEither.TaskEither<string, SqlFunction> => async () => {
const allFunctions = await getFunctions()
const res = allFunctions.find(
(f) =>
schemaName !== null &&
f.schema === schemaName &&
f.name === functionName
)
if (res) {
return Either.right(res)
} else {
return Either.left(
`No such function: ${fullName(schemaName, functionName)}`
)
}
}

const getFunctions = asyncCached(
async (): Promise<SqlFunction[]> =>
(await sql.functions(postgresClient)).map((row) => ({
Expand Down Expand Up @@ -117,9 +140,9 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient {
)
}

return { getTable, getEnums, getArrayTypes, functionNullSafety }
return { getTable, getEnums, getArrayTypes, getFunction, functionNullSafety }
}

function fullTableName(schemaName: string | null, tableName: string): string {
function fullName(schemaName: string | null, tableName: string): string {
return (schemaName ? schemaName + '.' : '') + tableName
}