engine

package
v0.10.2
Warning

This package is not in the latest version of its module.

Published: Mar 20, 2025 License: Apache-2.0 Imports: 4 Imported by: 0

README

Kwil Engine

This document outlines Kwil's engine, which is responsible for handling all database-related functionality.

The engine has the following responsibilities:

  • Accepting DDL (CREATE TABLE, CREATE ACTION, etc.) statements, converting them to a structured format which can be held in memory, and persisting them within the DB.
  • Accepting SQL statements, parsing them, rewriting them to be deterministic, and executing them against the database.
  • Executing actions that have been defined with a CREATE ACTION statement.
  • Storing roles and enforcing access control rules for them.
  • Managing developer-defined precompiles (extensions).
  • Making all of the above operations deterministic.

How The Engine Works

The engine has two entry points: execute and call. execute runs a raw statement, such as a SQL or DDL statement. call runs an action that was defined either with a CREATE ACTION statement or as a precompile. Both execute and call run against the *baseInterpreter struct, which holds important metadata in memory for the lifetime of the node.

The interpreter "interprets" a statement passed to execute by parsing it and traversing the resulting AST, executing node-specific logic as each node is encountered. It does this by converting statements (or parts of statements) into one of three function types: execFunc, stmtFunc, and exprFunc:

// execFunc is a block of code that can be called with a set of ordered inputs.
// For example, built-in SQL functions like ABS() and FORMAT(), or user-defined
// actions, all require arguments to be passed in a specific order.
type execFunc func(exec *executionContext, args []value, returnFn resultFunc) error

// stmtFunc is a block of code that executes a "statement" from the AST.
// "statements" are language features such as:
// - sql: INSERT/UPDATE/DELETE/SELECT
// - ddl: CREATE/ALTER/DROP
// - action logic: FOR loops / IF clauses / variable assignment
type stmtFunc func(exec *executionContext, fn resultFunc) error

// exprFunc is a function that returns a single value.
// It is used to represent pieces of logic that should evaluate to
// exactly one thing (e.g. arithmetic, comparison, etc.)
type exprFunc func(exec *executionContext) (value, error)
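
To make the composition concrete, here is a minimal, self-contained sketch of how compiled exprFuncs can nest inside stmtFuncs. The types are simplified stand-ins and are assumptions, not the engine's real definitions: value is just an int64, executionContext holds only a variable map, and resultFunc receives one row per call. The helpers literal, variable, add, assign, returnValue, and run are hypothetical names chosen for illustration.

```go
package main

import "fmt"

// Simplified stand-ins for the engine's internal types (assumptions).
type value = int64

type executionContext struct {
	vars map[string]value // variables visible to the running code
}

type resultFunc func(row []value) error
type exprFunc func(exec *executionContext) (value, error)
type stmtFunc func(exec *executionContext, fn resultFunc) error

// literal compiles a constant into an exprFunc.
func literal(v value) exprFunc {
	return func(*executionContext) (value, error) { return v, nil }
}

// variable compiles a variable reference into an exprFunc.
func variable(name string) exprFunc {
	return func(exec *executionContext) (value, error) {
		v, ok := exec.vars[name]
		if !ok {
			return 0, fmt.Errorf("unknown variable %q", name)
		}
		return v, nil
	}
}

// add compiles a "+" node from its two already-compiled operands.
func add(left, right exprFunc) exprFunc {
	return func(exec *executionContext) (value, error) {
		l, err := left(exec)
		if err != nil {
			return 0, err
		}
		r, err := right(exec)
		if err != nil {
			return 0, err
		}
		return l + r, nil
	}
}

// assign compiles a variable assignment into a stmtFunc that emits no rows.
func assign(name string, expr exprFunc) stmtFunc {
	return func(exec *executionContext, _ resultFunc) error {
		v, err := expr(exec)
		if err != nil {
			return err
		}
		exec.vars[name] = v
		return nil
	}
}

// returnValue compiles a return statement into a stmtFunc that emits one row.
func returnValue(expr exprFunc) stmtFunc {
	return func(exec *executionContext, fn resultFunc) error {
		v, err := expr(exec)
		if err != nil {
			return err
		}
		return fn([]value{v})
	}
}

// run executes compiled statements in order and collects every emitted row.
func run(stmts []stmtFunc, vars map[string]value) ([][]value, error) {
	exec := &executionContext{vars: vars}
	var rows [][]value
	collect := func(row []value) error {
		rows = append(rows, row)
		return nil
	}
	for _, s := range stmts {
		if err := s(exec, collect); err != nil {
			return nil, err
		}
	}
	return rows, nil
}

func main() {
	// Equivalent of: x = a + 2; return x + 1
	prog := []stmtFunc{
		assign("x", add(variable("a"), literal(2))),
		returnValue(add(variable("x"), literal(1))),
	}
	rows, err := run(prog, map[string]value{"a": 40})
	fmt.Println(rows, err) // [[43]] <nil>
}
```

In this sketch each AST node is compiled once into a closure, so executing the program only calls closures and never re-walks the tree.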

Notice that a resultFunc is passed to both execFunc and stmtFunc. The resultFunc lets the interpreter write results progressively as it executes. This means that if an action or statement returns many rows of data, the interpreter reads each row (and performs any subsequent execution logic) only as needed. Previous versions of Kwil relied on the PL/pgSQL interpreter, which read all requisite data from disk before processing it; the resultFunc avoids this.
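
A minimal sketch of the streaming behavior, assuming a simplified resultFunc that receives one row per call. scanTable is a hypothetical stand-in for a database read, and errStop is a hypothetical sentinel a consumer returns to stop early; the point is that rows after the stop are never read.

```go
package main

import (
	"errors"
	"fmt"
)

// resultFunc receives one row at a time (simplified, assumed signature).
type resultFunc func(row []any) error

// errStop is a hypothetical sentinel a consumer returns to stop iteration.
var errStop = errors.New("stop iteration")

// scanTable stands in for a query: it "reads" rows one by one and hands each
// to fn before reading the next. fetched counts how many rows were read.
func scanTable(n int, fetched *int, fn resultFunc) error {
	for i := 0; i < n; i++ {
		*fetched++ // simulate reading one row from disk
		if err := fn([]any{i}); err != nil {
			if errors.Is(err, errStop) {
				return nil // the consumer asked to stop; not a failure
			}
			return err
		}
	}
	return nil
}

// firstN consumes only the first limit rows of an n-row table.
func firstN(n, limit int) (rows [][]any, fetched int, err error) {
	err = scanTable(n, &fetched, func(row []any) error {
		rows = append(rows, row)
		if len(rows) == limit {
			return errStop
		}
		return nil
	})
	return rows, fetched, err
}

func main() {
	rows, fetched, err := firstN(1_000_000, 3)
	fmt.Println(len(rows), fetched, err) // 3 3 <nil>
}
```

Because the consumer controls iteration through its return value, only three rows are read here even though the simulated table has a million.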

Understanding execute

execute takes a statement (or a group of statements delimited by ;), converts each one into a stmtFunc, and executes them. The implementation of each stmtFunc depends on the statement passed, but stmtFuncs generally fall into four types:

  • SQL: an INSERT/UPDATE/DELETE/SELECT statement that is immediately executed against the database. More information on SQL statements is included below in this document.
  • DDL: any DDL that creates, alters, or drops a table, index, role, etc. The implementations of DDL statements differ from one another and are very implementation-specific, but each is fairly simple.
  • CREATE ACTION: while technically a DDL, creating an action involves several extra concepts not used in other DDL. It converts each statement within the action body into a reusable stmtFunc, and then wraps them into a single execFunc, which is cached and can be reused later.
  • USE : also technically a DDL, but performs special logic involving extra concepts to initialize a developer-defined extension. It wraps the user-defined behavior into an execFunc, which is cached and can be reused later.

Note: these four types of stmtFuncs are not explicitly defined anywhere in the code. They are simply a broad characterization, and some statements fall into none of them (e.g. SET CURRENT NAMESPACE).
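
The flow described above can be sketched as follows. This is a hedged illustration, not the engine's code: it assumes the parser has already produced a slice of statements (the statement, sqlStmt, and createActionStmt types are hypothetical), covers only two of the four categories, and records side effects in a log instead of touching a database. It shows each statement becoming a stmtFunc, and CREATE ACTION compiling its body once into an execFunc that is cached for later use.

```go
package main

import "fmt"

// Hypothetical, heavily simplified AST produced by a parser.
type statement interface{}

type sqlStmt struct{ query string }

type createActionStmt struct {
	name string
	body []statement
}

type executionContext struct {
	log []string // records what ran, standing in for real side effects
}

type resultFunc func(row []any) error
type stmtFunc func(exec *executionContext, fn resultFunc) error
type execFunc func(exec *executionContext, args []any, fn resultFunc) error

// interpreter holds the action cache that outlives individual statements.
type interpreter struct {
	actions map[string]execFunc
}

// compile turns one parsed statement into a stmtFunc.
func (i *interpreter) compile(s statement) (stmtFunc, error) {
	switch s := s.(type) {
	case sqlStmt:
		return func(exec *executionContext, _ resultFunc) error {
			exec.log = append(exec.log, "sql: "+s.query)
			return nil
		}, nil
	case createActionStmt:
		// Compile the action body once, up front.
		body := make([]stmtFunc, 0, len(s.body))
		for _, b := range s.body {
			f, err := i.compile(b)
			if err != nil {
				return nil, err
			}
			body = append(body, f)
		}
		return func(exec *executionContext, _ resultFunc) error {
			// Running CREATE ACTION caches an execFunc that wraps the body.
			i.actions[s.name] = func(exec *executionContext, _ []any, fn resultFunc) error {
				for _, f := range body {
					if err := f(exec, fn); err != nil {
						return err
					}
				}
				return nil
			}
			exec.log = append(exec.log, "created action "+s.name)
			return nil
		}, nil
	default:
		return nil, fmt.Errorf("unsupported statement %T", s)
	}
}

// execute compiles every statement first, then runs them in order.
func (i *interpreter) execute(exec *executionContext, stmts []statement, fn resultFunc) error {
	compiled := make([]stmtFunc, 0, len(stmts))
	for _, s := range stmts {
		f, err := i.compile(s)
		if err != nil {
			return err
		}
		compiled = append(compiled, f)
	}
	for _, f := range compiled {
		if err := f(exec, fn); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	interp := &interpreter{actions: map[string]execFunc{}}
	exec := &executionContext{}
	err := interp.execute(exec, []statement{
		sqlStmt{query: "INSERT INTO t VALUES (1)"},
		createActionStmt{name: "add_row", body: []statement{sqlStmt{query: "INSERT INTO t VALUES (2)"}}},
	}, nil)
	fmt.Println(err, exec.log)
	// The cached action can now be invoked without re-parsing its body.
	_ = interp.actions["add_row"](exec, nil, nil)
	fmt.Println(exec.log)
}
```

Statement splitting is deliberately left to the parser here: a naive split on ; would break on string literals that contain a semicolon.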

Understanding call

The interpreter's call functionality allows a user to execute either an action or an extension method. It does this by accessing the locally cached execFunc for the action or extension method.
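
A minimal sketch of call, assuming cached execFuncs are keyed by namespace and action name (the actionKey type is an assumption) and that a cache miss is reported by wrapping the package's ErrUnknownAction sentinel, which is mirrored locally so the example compiles on its own. The executionContext parameter is omitted for brevity.

```go
package main

import (
	"errors"
	"fmt"
)

// Mirrors the package's sentinel error so this example is self-contained.
var ErrUnknownAction = errors.New("unknown action")

type resultFunc func(row []any) error

// execFunc is simplified here: the executionContext parameter is dropped.
type execFunc func(args []any, fn resultFunc) error

// actionKey is an assumed cache key: actions live inside namespaces.
type actionKey struct{ namespace, name string }

type interpreter struct {
	cache map[actionKey]execFunc
}

// call looks up the cached execFunc and invokes it; nothing is re-parsed.
func (i *interpreter) call(namespace, name string, args []any, fn resultFunc) error {
	f, ok := i.cache[actionKey{namespace, name}]
	if !ok {
		return fmt.Errorf("%w: %s.%s", ErrUnknownAction, namespace, name)
	}
	return f(args, fn)
}

func main() {
	interp := &interpreter{cache: map[actionKey]execFunc{
		{"main", "echo"}: func(args []any, fn resultFunc) error { return fn(args) },
	}}
	_ = interp.call("main", "echo", []any{"hi"}, func(row []any) error {
		fmt.Println(row...) // hi
		return nil
	})
	err := interp.call("main", "missing", nil, nil)
	fmt.Println(errors.Is(err, ErrUnknownAction)) // true
}
```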

SQL Queries

There are special considerations that the engine takes into account when executing SQL queries. These considerations are applicable both for ad-hoc queries during execute, and for queries within an action / extension.

By default, SQL queries are not deterministic. I won't list all forms of non-determinism here, but one basic example is the order of returned results: SELECT * FROM table can return rows in any order. Since this violates Kwil's determinism requirements, Kwil rewrites queries to be deterministic (e.g. by guaranteeing ordering). It does this by converting SQL statements into a "logical plan".

A "logical plan" is a mathematical representation of operations applied to a "relation" (a table). It is based on relational algebra; if you need to understand Kwil's query planner, it is highly recommended that you first learn basic relational algebra. Kwil converts a SQL statement into its own modified version of relational algebra, identifies sources of non-determinism, applies extra logic to make them deterministic, and then regenerates Postgres-compliant SQL. If a query doesn't need to be deterministic (for example, when it is run outside of consensus by a read-only RPC call), this additional logic is not applied.
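
To illustrate the idea, here is a toy planner over a linear plan of Scan, Filter, Project, and Sort nodes. It is a sketch under stated assumptions, not Kwil's planner: the node types are hypothetical, and it uses one plausible strategy for guaranteeing ordering, namely sorting by the base table's primary key, or appending the primary key as a tie-breaker to an existing sort (ORDER BY on non-unique columns still leaves ties in arbitrary order). The sort is placed beneath the projection so the key columns remain available.

```go
package main

import (
	"fmt"
	"strings"
)

// Hypothetical relational-algebra nodes for a linear (join-free) plan.
type plan interface{}

type scanNode struct {
	table      string
	primaryKey []string
}
type filterNode struct {
	cond  string
	child plan
}
type projectNode struct {
	cols  []string
	child plan
}
type sortNode struct {
	keys  []string
	child plan
}

// primaryKeyOf returns the primary key of the table at the bottom of the plan.
func primaryKeyOf(p plan) []string {
	switch n := p.(type) {
	case scanNode:
		return n.primaryKey
	case filterNode:
		return primaryKeyOf(n.child)
	case projectNode:
		return primaryKeyOf(n.child)
	case sortNode:
		return primaryKeyOf(n.child)
	}
	return nil
}

// appendMissing returns keys followed by any extra columns not already present.
func appendMissing(keys, extra []string) []string {
	out := append([]string(nil), keys...)
	for _, e := range extra {
		found := false
		for _, k := range out {
			if k == e {
				found = true
				break
			}
		}
		if !found {
			out = append(out, e)
		}
	}
	return out
}

// makeDeterministic guarantees a total order on the plan's output rows.
func makeDeterministic(p plan) plan {
	switch n := p.(type) {
	case projectNode:
		n.child = makeDeterministic(n.child) // sort beneath the projection
		return n
	case sortNode:
		n.keys = appendMissing(n.keys, primaryKeyOf(n.child)) // break ties
		return n
	default:
		return sortNode{keys: primaryKeyOf(p), child: p}
	}
}

// toSQL renders a linear plan back into a single SELECT statement.
func toSQL(p plan) string {
	cols, table := "*", ""
	var where, orderBy []string
	for p != nil {
		switch n := p.(type) {
		case projectNode:
			cols, p = strings.Join(n.cols, ", "), n.child
		case sortNode:
			orderBy, p = n.keys, n.child
		case filterNode:
			where, p = append(where, n.cond), n.child
		case scanNode:
			table, p = n.table, nil
		default:
			panic(fmt.Sprintf("unsupported node %T", n))
		}
	}
	var b strings.Builder
	fmt.Fprintf(&b, "SELECT %s FROM %s", cols, table)
	if len(where) > 0 {
		b.WriteString(" WHERE " + strings.Join(where, " AND "))
	}
	if len(orderBy) > 0 {
		b.WriteString(" ORDER BY " + strings.Join(orderBy, ", "))
	}
	return b.String()
}

func main() {
	users := scanNode{table: "users", primaryKey: []string{"id"}}
	q := projectNode{cols: []string{"name"}, child: filterNode{cond: "age > 18", child: users}}
	fmt.Println(toSQL(q))                    // SELECT name FROM users WHERE age > 18
	fmt.Println(toSQL(makeDeterministic(q))) // SELECT name FROM users WHERE age > 18 ORDER BY id
	sorted := sortNode{keys: []string{"age"}, child: users}
	fmt.Println(toSQL(makeDeterministic(sorted))) // SELECT * FROM users ORDER BY age, id
}
```

In this sketch only the consensus path would call makeDeterministic; a read-only RPC query could be rendered straight from the original plan.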

Structure of the Engine Code

The engine code has the following structure:

  • /: The root directory (which this README is contained in) contains common pieces of code used throughout the rest of the subdirectories in the engine. These are primarily common types, errors, and lists of constants (e.g. functions).
  • /interpreter: The main entry-point for the engine is in the interpreter package. This defines the logic for how statements are interpreted, how data is stored on disk and represented in-memory, and how other packages are used and called. If you are new to this section of the code, I would recommend starting in /interpreter/interpreter.go, and branching out to other files and packages that are used within there.
  • /parse: The parse package implements the parser for SQL Smart Contracts. It uses ANTLR v4 as a parser generator, and defines the language's AST, grammar rules, and other basic syntax validations.
  • /pg_generate: The pg_generate package is a very simple package for generating Postgres-compatible SQL from Kwil's SQL AST.
  • /planner: The planner package implements Kwil's deterministic query planner. It has two sub-packages: logical and optimizer. The optimizer package is currently unused. The logical package contains the logical query planner.

Documentation

Constants

const (
	// ReservedKwilNamespacePrefix is the prefix that is reserved for namespaces that are not part of the engine
	// directly, but are built in by Kwil via extensions.
	ReservedKwilNamespacePrefix = "kwil_"
	ReservedPGNamespacePrefix   = "pg_"
)
const (
	DefaultNamespace       = "main"
	InfoNamespace          = "info"
	InternalEnginePGSchema = "kwild_engine"
)
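
A sketch of how these constants might be enforced. The helpers resolveNamespace, checkNewNamespace, and checkMutable are hypothetical, and treating namespace names case-insensitively is an assumption; the constants and the two sentinel errors are copied from this package so the example is self-contained. (PostgreSQL itself reserves schema names beginning with pg_.)

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// Copied from this package so the example compiles on its own.
const (
	ReservedKwilNamespacePrefix = "kwil_"
	ReservedPGNamespacePrefix   = "pg_"
	DefaultNamespace            = "main"
	InfoNamespace               = "info"
)

var (
	ErrReservedNamespacePrefix   = errors.New("namespace prefix is reserved")
	ErrCannotMutateInfoNamespace = errors.New(`cannot mutate the "info" namespace directly`)
)

// resolveNamespace falls back to DefaultNamespace when none is given (hypothetical).
func resolveNamespace(ns string) string {
	if ns == "" {
		return DefaultNamespace
	}
	return ns
}

// checkNewNamespace rejects user-created names with a reserved prefix (hypothetical).
func checkNewNamespace(ns string) error {
	lower := strings.ToLower(ns) // case-insensitivity is an assumption
	for _, p := range []string{ReservedKwilNamespacePrefix, ReservedPGNamespacePrefix} {
		if strings.HasPrefix(lower, p) {
			return fmt.Errorf("%w: %q", ErrReservedNamespacePrefix, ns)
		}
	}
	return nil
}

// checkMutable rejects direct writes to the built-in info namespace (hypothetical).
func checkMutable(ns string) error {
	if strings.EqualFold(resolveNamespace(ns), InfoNamespace) {
		return ErrCannotMutateInfoNamespace
	}
	return nil
}

func main() {
	fmt.Println(resolveNamespace(""))            // main
	fmt.Println(checkNewNamespace("kwil_erc20")) // namespace prefix is reserved: "kwil_erc20"
	fmt.Println(checkNewNamespace("my_app"))     // <nil>
	fmt.Println(checkMutable("info"))            // cannot mutate the "info" namespace directly
}
```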

Variables

var (
	// Errors that suggest a bug in a user's executing code. These are type errors,
	// issues in arithmetic, array indexing, etc.
	ErrType                    = errors.New("type error")
	ErrReturnShape             = errors.New("unexpected action/function return shape")
	ErrUnknownVariable         = errors.New("unknown variable")
	ErrInvalidVariable         = errors.New("invalid variable name")
	ErrLoop                    = errors.New("loop error")
	ErrArithmetic              = errors.New("arithmetic error")
	ErrComparison              = errors.New("comparison error")
	ErrCast                    = errors.New("type cast error")
	ErrUnary                   = errors.New("unary operation error")
	ErrIndexOutOfBounds        = errors.New("index out of bounds")
	ErrArrayDimensionality     = errors.New("array dimensionality error")
	ErrInvalidNull             = errors.New("invalid null value")
	ErrArrayTooSmall           = errors.New("array too small")
	ErrExtensionImplementation = errors.New("extension implementation error")
	ErrActionInvocation        = errors.New("action invocation error")

	// Errors that signal the existence or non-existence of an object.
	ErrUnknownAction     = errors.New("unknown action")
	ErrUnknownTable      = errors.New("unknown table")
	ErrNamespaceNotFound = errors.New("namespace not found")
	ErrNamespaceExists   = errors.New("namespace already exists")

	// Errors that likely are not the result of a user error, but instead are informing
	// the user of an operation that is not allowed in order to maintain the integrity of
	// the system.
	ErrCannotMutateState          = errors.New("connection is read-only and cannot mutate state")
	ErrIllegalFunctionUsage       = errors.New("illegal function usage")
	ErrQueryActive                = errors.New("a query is currently active. nested queries are not allowed")
	ErrCannotBeNamespaced         = errors.New("the selected object is global-only, and cannot be namespaced")
	ErrCannotMutateExtension      = errors.New("cannot mutate an extension's schema or data directly")
	ErrCannotMutateInfoNamespace  = errors.New(`cannot mutate the "info" namespace directly`)
	ErrCannotDropBuiltinNamespace = errors.New("cannot drop a built-in namespace")
	ErrBuiltInRole                = errors.New("invalid operation on built-in role")
	ErrInvalidTxCtx               = errors.New("invalid transaction context")
	ErrReservedNamespacePrefix    = errors.New("namespace prefix is reserved")
	ErrCannotAlterPrimaryKey      = errors.New("cannot drop or alter a table's primary key")

	// Errors that are the result of not having proper permissions or failing to meet a condition
	// that was programmed by the user.
	ErrActionOwnerOnly      = errors.New("action is owner-only")
	ErrActionPrivate        = errors.New("action is private")
	ErrActionSystemOnly     = errors.New("action is system-only")
	ErrDoesNotHavePrivilege = errors.New("user does not have privilege")

	// Errors that signal an error in a deeper layer, and originate from
	// somewhere deeper than the interpreter.
	ErrParse        = errors.New("parse error")
	ErrQueryPlanner = errors.New("query planner error")
	ErrPGGen        = errors.New("postgres SQL generation error")
)
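
The function definitions in this package wrap these sentinels with fmt.Errorf and the %w verb, which keeps them matchable with errors.Is no matter how much context is added on top. A minimal sketch; wrapType and classify are hypothetical helpers, and the category labels simply follow the grouping comments above.

```go
package main

import (
	"errors"
	"fmt"
)

// A few sentinels mirrored from the error variables above.
var (
	ErrType            = errors.New("type error")
	ErrUnknownTable    = errors.New("unknown table")
	ErrActionOwnerOnly = errors.New("action is owner-only")
)

// wrapType adds context to ErrType in the same "%w: ..." style (hypothetical).
func wrapType(expected, got string) error {
	return fmt.Errorf("%w: expected %s, got %s", ErrType, expected, got)
}

// classify maps an error back to its category; errors.Is unwraps every layer.
func classify(err error) string {
	switch {
	case err == nil:
		return "ok"
	case errors.Is(err, ErrType):
		return "user code error"
	case errors.Is(err, ErrUnknownTable):
		return "missing object"
	case errors.Is(err, ErrActionOwnerOnly):
		return "permission denied"
	default:
		return "other"
	}
}

func main() {
	err := fmt.Errorf("executing action: %w", wrapType("int", "text"))
	fmt.Println(err)           // executing action: type error: expected int, got text
	fmt.Println(classify(err)) // user code error
}
```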
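
Each entry in the Functions map that follows pairs a ValidateArgsFunc, which type-checks the arguments and returns the result type, with a PGFormatFunc, which renders the call as Postgres SQL. The sketch below mimics that shape with simplified stand-ins: dataType, scalarFunc, compileCall, the lowercasing of names before lookup, and the bodies of wrapErrArgumentNumber and wrapErrArgumentType are all assumptions, since the real helpers are not shown here. Two entries are reproduced loosely, including the ::INT4 cast that lpad applies to its length argument.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

var ErrType = errors.New("type error")

// dataType is a simplified stand-in for types.DataType (assumption).
type dataType struct{ Name string }

func (d dataType) String() string { return d.Name }

var (
	intType  = dataType{"int"}
	textType = dataType{"text"}
)

// scalarFunc mirrors the two-callback shape of ScalarFunctionDefinition.
type scalarFunc struct {
	ValidateArgsFunc func(args []dataType) (dataType, error)
	PGFormatFunc     func(inputs []string) (string, error)
}

// Hypothetical bodies for the helpers the real definitions call.
func wrapErrArgumentNumber(expected, got int) error {
	return fmt.Errorf("invalid number of arguments: expected %d, got %d", expected, got)
}

func wrapErrArgumentType(expected, got dataType) error {
	return fmt.Errorf("%w: expected %s, got %s", ErrType, expected, got)
}

var functions = map[string]scalarFunc{
	"upper": {
		ValidateArgsFunc: func(args []dataType) (dataType, error) {
			if len(args) != 1 {
				return dataType{}, wrapErrArgumentNumber(1, len(args))
			}
			if args[0] != textType {
				return dataType{}, wrapErrArgumentType(textType, args[0])
			}
			return textType, nil
		},
		PGFormatFunc: func(in []string) (string, error) { return "upper(" + strings.Join(in, ", ") + ")", nil },
	},
	"lpad": { // simplified to the two-argument form
		ValidateArgsFunc: func(args []dataType) (dataType, error) {
			if len(args) != 2 {
				return dataType{}, wrapErrArgumentNumber(2, len(args))
			}
			if args[0] != textType {
				return dataType{}, wrapErrArgumentType(textType, args[0])
			}
			if args[1] != intType {
				return dataType{}, wrapErrArgumentType(intType, args[1])
			}
			return textType, nil
		},
		// Cast the length to INT4, as the lpad entry in Functions does.
		PGFormatFunc: func(in []string) (string, error) { return fmt.Sprintf("lpad(%s, %s::INT4)", in[0], in[1]), nil },
	},
}

// compileCall validates argument types, then renders Postgres SQL.
func compileCall(name string, argTypes []dataType, argSQL []string) (string, dataType, error) {
	def, ok := functions[strings.ToLower(name)]
	if !ok {
		return "", dataType{}, fmt.Errorf("unknown function %q", name)
	}
	ret, err := def.ValidateArgsFunc(argTypes)
	if err != nil {
		return "", dataType{}, err
	}
	sql, err := def.PGFormatFunc(argSQL)
	if err != nil {
		return "", dataType{}, err
	}
	return sql, ret, nil
}

func main() {
	sql, ret, err := compileCall("lpad", []dataType{textType, intType}, []string{"$name", "10"})
	fmt.Println(sql, ret, err) // lpad($name, 10::INT4) text <nil>
	_, _, err = compileCall("upper", []dataType{intType}, []string{"1"})
	fmt.Println(err) // type error: expected text, got int
}
```

Validating before formatting means a type error is reported against Kwil's own types rather than surfacing later as a Postgres error.
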
var (
	Functions = map[string]FunctionDefinition{
		"abs": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.IntType) && args[0].Name != types.NumericStr {
					return nil, fmt.Errorf("%w: expected argument to be int or decimal, got %s", ErrType, args[0].String())
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("abs"),
		},
		"error": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.TextType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				def, err := defaultFormat("error")(inputs)
				if err != nil {
					return "", err
				}

				return def + "::text", nil
			},
		},
		"parse_unix_timestamp": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				if !args[1].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[1])
				}

				return decimal16_6, nil
			},
			PGFormatFunc: defaultFormat("parse_unix_timestamp"),
		},
		"format_unix_timestamp": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].Equals(decimal16_6) {
					return nil, wrapErrArgumentType(decimal16_6, args[0])
				}

				if !args[1].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[1])
				}

				return types.TextType, nil
			},
			PGFormatFunc: defaultFormat("format_unix_timestamp"),
		},
		"notice": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.NullType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				return "", fmt.Errorf(`%w: "notice" cannot be used in SQL statements`, ErrIllegalFunctionUsage)
			},
		},
		"uuid_generate_v5": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].Equals(types.UUIDType) {
					return nil, wrapErrArgumentType(types.UUIDType, args[0])
				}

				if !args[1].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[1])
				}

				return types.UUIDType, nil
			},
			PGFormatFunc: defaultFormat("uuid_generate_v5"),
		},

		"uuid_generate_kwil": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.UUIDType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				return "uuid_generate_v5('a247cac1-d817-4949-bac7-dc4b1dc41d09'::uuid," + inputs[0] + ")", nil
			},
		},
		"encode": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].Equals(types.ByteaType) {
					return nil, wrapErrArgumentType(types.ByteaType, args[0])
				}

				if !args[1].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[1])
				}

				return types.TextType, nil
			},
			PGFormatFunc: defaultFormat("encode"),
		},
		"decode": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				if !args[1].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[1])
				}

				return types.ByteaType, nil
			},
			PGFormatFunc: defaultFormat("decode"),
		},
		"digest": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].Equals(types.TextType) && !args[0].Equals(types.ByteaType) {
					return nil, fmt.Errorf("%w: expected first argument to be text or blob, got %s", ErrType, args[0].String())
				}

				if !args[1].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[1])
				}

				return types.ByteaType, nil
			},
			PGFormatFunc: defaultFormat("digest"),
		},

		"array_append": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].IsArray {
					return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, args[0].String())
				}

				if args[1].IsArray {
					return nil, fmt.Errorf("%w: expected second argument to be a scalar, got %s", ErrType, args[1].String())
				}

				if !strings.EqualFold(args[0].Name, args[1].Name) {
					return nil, fmt.Errorf("%w: append type must be equal to scalar array type. array type: %s append type: %s", ErrType, args[0].Name, args[1].Name)
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("array_append"),
		},
		"array_prepend": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if args[0].IsArray {
					return nil, fmt.Errorf("%w: expected first argument to be a scalar, got %s", ErrType, args[0].String())
				}

				if !args[1].IsArray {
					return nil, fmt.Errorf("%w: expected second argument to be an array, got %s", ErrType, args[1].String())
				}

				if !strings.EqualFold(args[0].Name, args[1].Name) {
					return nil, fmt.Errorf("%w: prepend type must be equal to scalar array type. array type: %s prepend type: %s", ErrType, args[1].Name, args[0].Name)
				}

				return args[1], nil
			},
			PGFormatFunc: defaultFormat("array_prepend"),
		},
		"array_cat": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].IsArray {
					return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, args[0].String())
				}

				if !args[1].IsArray {
					return nil, fmt.Errorf("%w: expected second argument to be an array, got %s", ErrType, args[1].String())
				}

				if !strings.EqualFold(args[0].Name, args[1].Name) {
					return nil, fmt.Errorf("%w: expected both arrays to be of the same scalar type, got %s and %s", ErrType, args[0].Name, args[1].Name)
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("array_cat"),
		},
		"array_length": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) < 1 || len(args) > 2 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", len(args))
				}

				if !args[0].IsArray {
					return nil, fmt.Errorf("%w: expected argument to be an array, got %s", ErrType, args[0].String())
				}

				if len(args) == 2 && !args[1].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, args[1])
				}

				return types.IntType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				if len(inputs) == 1 {
					return fmt.Sprintf("array_length(%s, 1)", inputs[0]), nil
				}
				return fmt.Sprintf("array_length(%s, %s)", inputs[0], inputs[1]), nil
			},
		},
		"array_remove": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[0].IsArray {
					return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, args[0].String())
				}

				if args[1].IsArray {
					return nil, fmt.Errorf("%w: expected second argument to be a scalar, got %s", ErrType, args[1].String())
				}

				if !strings.EqualFold(args[0].Name, args[1].Name) {
					return nil, fmt.Errorf("%w: remove type must be equal to scalar array type. array type: %s remove type: %s", ErrType, args[0].Name, args[1].Name)
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("array_remove"),
		},

		"bit_length": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.IntType, nil
			},
			PGFormatFunc: defaultFormat("bit_length"),
		},
		"char_length": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.IntType, nil
			},
			PGFormatFunc: defaultFormat("char_length"),
		},
		"character_length": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.IntType, nil
			},
			PGFormatFunc: defaultFormat("character_length"),
		},
		"length": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.IntType, nil
			},
			PGFormatFunc: defaultFormat("length"),
		},
		"lower": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.TextType, nil
			},
			PGFormatFunc: defaultFormat("lower"),
		},
		"lpad": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 2 || len(args) > 3 {
					return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				if !args[1].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, args[1])
				}

				if len(args) == 3 && !args[2].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[2])
				}

				return types.TextType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				str := strings.Builder{}
				str.WriteString("lpad(")
				str.WriteString(inputs[0])
				str.WriteString(", ")
				str.WriteString(inputs[1])
				str.WriteString("::INT4")
				if len(inputs) == 3 {
					str.WriteString(", ")
					str.WriteString(inputs[2])
				}
				str.WriteString(")")

				return str.String(), nil
			},
		},
		"ltrim": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 1 || len(args) > 2 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", len(args))
				}

				for _, arg := range args {
					if !arg.Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, arg)
					}
				}

				return types.TextType, nil
			},
			PGFormatFunc: defaultFormat("ltrim"),
		},
		"octet_length": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.IntType, nil
			},
			PGFormatFunc: defaultFormat("octet_length"),
		},
		"overlay": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 3 || len(args) > 4 {
					return nil, fmt.Errorf("invalid number of arguments: expected 3 or 4, got %d", len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				if !args[1].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[1])
				}

				if !args[2].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, args[2])
				}

				if len(args) == 4 && !args[3].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, args[3])
				}

				return types.TextType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				str := strings.Builder{}
				str.WriteString("overlay(")
				str.WriteString(inputs[0])
				str.WriteString(" placing ")
				str.WriteString(inputs[1])
				str.WriteString(" from ")
				str.WriteString(inputs[2])
				str.WriteString("::INT4")
				if len(inputs) == 4 {
					str.WriteString(" for ")
					str.WriteString(inputs[3])
					str.WriteString("::INT4")
				}
				str.WriteString(")")

				return str.String(), nil
			},
		},
		"position": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				for _, arg := range args {
					if !arg.Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, arg)
					}
				}

				return types.IntType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				return fmt.Sprintf("position(%s in %s)", inputs[0], inputs[1]), nil
			},
		},
		"rpad": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 2 || len(args) > 3 {
					return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				if !args[1].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, args[1])
				}

				if len(args) == 3 && !args[2].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[2])
				}

				return types.TextType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				str := strings.Builder{}
				str.WriteString("rpad(")
				str.WriteString(inputs[0])
				str.WriteString(", ")
				str.WriteString(inputs[1])
				str.WriteString("::INT4")
				if len(inputs) == 3 {
					str.WriteString(", ")
					str.WriteString(inputs[2])
				}
				str.WriteString(")")

				return str.String(), nil
			},
		},
		"rtrim": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 1 || len(args) > 2 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", len(args))
				}

				for _, arg := range args {
					if !arg.Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, arg)
					}
				}

				return types.TextType, nil
			},
			PGFormatFunc: defaultFormat("rtrim"),
		},
		"substring": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 2 || len(args) > 3 {
					return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				if !args[1].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, args[1])
				}

				if len(args) == 3 && !args[2].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, args[2])
				}

				return types.TextType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				str := strings.Builder{}
				str.WriteString("substring(")
				str.WriteString(inputs[0])
				str.WriteString(" from ")
				str.WriteString(inputs[1])
				str.WriteString("::INT4")
				if len(inputs) == 3 {
					str.WriteString(" for ")
					str.WriteString(inputs[2])
					str.WriteString("::INT4")
				}
				str.WriteString(")")

				return str.String(), nil
			},
		},
		"trim": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 1 || len(args) > 2 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", len(args))
				}

				for _, arg := range args {
					if !arg.Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, arg)
					}
				}

				return types.TextType, nil
			},
			PGFormatFunc: defaultFormat("trim"),
		},
		"upper": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.TextType, nil
			},
			PGFormatFunc: defaultFormat("upper"),
		},
		"format": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) < 1 {
					return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got %d", len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.TextType, nil
			},
			PGFormatFunc: defaultFormat("format"),
		},
		"coalesce": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) < 1 {
					return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got %d", len(args))
				}

				firstType := args[0]

				for i, arg := range args {
					if !firstType.Equals(arg) {
						return nil, fmt.Errorf("%w: all arguments must be the same type, but argument %d is %s and argument 1 is %s", ErrType, i+1, arg.String(), firstType.String())
					}
				}

				return firstType, nil
			},
			PGFormatFunc: defaultFormat("coalesce"),
		},
		"greatest": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) == 0 {
					return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got 0")
				}

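				// args[1] may not exist (a single argument is allowed), and the loop
				// below requires every argument to match args[0], so checking
				// args[0] alone covers all arguments.
				if !args[0].IsNumeric() {
					return nil, fmt.Errorf("%w: expected arguments to be numeric, got %s", ErrType, args[0].String())
				}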

				for i, arg := range args {
					if !args[0].Equals(arg) {
						return nil, fmt.Errorf("%w: all arguments must be the same type, but argument %d is %s and argument 1 is %s", ErrType, i+1, arg.String(), args[0].String())
					}
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("greatest"),
		},
		"least": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) == 0 {
					return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got 0")
				}

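				// args[1] may not exist (a single argument is allowed), and the loop
				// below requires every argument to match args[0], so checking
				// args[0] alone covers all arguments.
				if !args[0].IsNumeric() {
					return nil, fmt.Errorf("%w: expected arguments to be numeric, got %s", ErrType, args[0].String())
				}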

				for i, arg := range args {
					if !args[0].Equals(arg) {
						return nil, fmt.Errorf("%w: all arguments must be the same type, but argument %d is %s and argument 1 is %s", ErrType, i+1, arg.String(), args[0].String())
					}
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("least"),
		},
		"nullif": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 2 {
					return nil, fmt.Errorf("invalid number of arguments: expected 2, got %d", len(args))
				}

				if !args[0].Equals(args[1]) {
					return nil, fmt.Errorf("%w: both arguments must be the same type, but got %s and %s", ErrType, args[0].String(), args[1].String())
				}

				if args[0].EqualsStrict(types.NullType) && args[1].EqualsStrict(types.NullType) {
					return nil, fmt.Errorf("%w: arguments to NULLIF cannot both be null", ErrType)
				}

				if args[0].EqualsStrict(types.NullType) {
					return args[1], nil
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("nullif"),
		},

		"count": &AggregateFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) > 1 {
					return nil, fmt.Errorf("invalid number of arguments: expected at most 1, got %d", len(args))
				}

				return types.IntType, nil
			},
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if len(inputs) == 0 {
					if distinct {
						return "", fmt.Errorf("count(DISTINCT *) is not supported")
					}
					return "count(*)", nil
				}
				if distinct {
					return fmt.Sprintf("count(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("count(%s)", inputs[0]), nil
			},
		},
		"sum": &AggregateFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].IsNumeric() {
					return nil, fmt.Errorf("%w: expected argument to be numeric, got %s", ErrType, args[0].String())
				}

				var retType *types.DataType
				switch {
				case args[0].Equals(types.IntType):
					retType = decimal1000.Copy()
				case args[0].Name == types.NumericStr:
					retType = args[0].Copy()
					retType.Metadata[0] = 1000
				default:
					// retType is still nil here, so report the argument type instead.
					panic(fmt.Sprintf("unexpected numeric type: %s", args[0].String()))
				}

				return retType, nil
			},
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("sum(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("sum(%s)", inputs[0]), nil
			},
		},
		"min": &AggregateFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].IsNumeric() && !args[0].Equals(types.TextType) {
					return nil, fmt.Errorf("%w: expected argument to be numeric or text, got %s", ErrType, args[0].String())
				}

				return args[0], nil
			},
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("min(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("min(%s)", inputs[0]), nil
			},
		},
		"max": &AggregateFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].IsNumeric() && !args[0].Equals(types.TextType) {
					return nil, fmt.Errorf("%w: expected argument to be numeric or text, got %s", ErrType, args[0].String())
				}

				return args[0], nil
			},
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("max(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("max(%s)", inputs[0]), nil
			},
		},
		"array_agg": &AggregateFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if args[0].IsArray {
					return nil, fmt.Errorf("%w: expected argument to be a scalar, got %s", ErrType, args[0].String())
				}

				a2 := args[0].Copy()
				a2.IsArray = true
				return a2, nil
			},
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("array_agg(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("array_agg(%s ORDER BY %s)", inputs[0], inputs[0]), nil
			},
		},
		"avg": &AggregateFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !strings.EqualFold(args[0].Name, types.NumericStr) {
					return nil, fmt.Errorf("%w: expected argument to be numeric, got %s", ErrType, args[0].String())
				}

				return args[0], nil
			},
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("avg(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("avg(%s)", inputs[0]), nil
			},
		},

		"lag": &WindowFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 1 || len(args) > 3 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1-3, got %d", len(args))
				}

				if len(args) >= 2 {
					if !args[1].Equals(types.IntType) {
						return nil, wrapErrArgumentType(types.IntType, args[1])
					}
				}

				if len(args) == 3 {
					if !args[2].Equals(args[0]) {
						return nil, fmt.Errorf("%w: expected default value to be the same type as the value expression: %s != %s", ErrType, args[0].String(), args[2].String())
					}
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("lag"),
		},
		"lead": &WindowFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) < 1 || len(args) > 3 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1-3, got %d", len(args))
				}

				if len(args) >= 2 {
					if !args[1].Equals(types.IntType) {
						return nil, wrapErrArgumentType(types.IntType, args[1])
					}
				}

				if len(args) == 3 {
					if !args[2].Equals(args[0]) {
						return nil, fmt.Errorf("%w: expected default value to be the same type as the value expression: %s != %s", ErrType, args[0].String(), args[2].String())
					}
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("lead"),
		},
		"first_value": &WindowFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("first_value"),
		},
		"last_value": &WindowFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("last_value"),
		},
		"nth_value": &WindowFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 2 {
					return nil, wrapErrArgumentNumber(2, len(args))
				}

				if !args[1].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, args[1])
				}

				return args[0], nil
			},
			PGFormatFunc: defaultFormat("nth_value"),
		},
		"row_number": &WindowFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 0 {
					return nil, wrapErrArgumentNumber(0, len(args))
				}

				return types.IntType, nil
			},
			PGFormatFunc: defaultFormat("row_number"),
		},
	}
)

Functions

func MakeTypeCast

func MakeTypeCast(d *types.DataType) (string, error)

MakeTypeCast returns the string that type casts a value to the given type. It should be used when generating SQL. If the type is null, no type cast is returned.
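As a rough illustration of the kind of string such a function produces, here is a sketch assuming a cast is rendered as `::name`, optionally followed by `(precision,scale)` and `[]` for arrays. The `simpleType` struct and `makeTypeCast` helper are hypothetical; the package's actual formatting rules may differ.

```go
package main

import "fmt"

// simpleType is a hypothetical stand-in for *types.DataType.
type simpleType struct {
	Name     string
	IsArray  bool
	Metadata []uint16 // assumed: precision and scale for numeric types
}

// makeTypeCast renders a Postgres cast suffix such as "::int8" or
// "::numeric(10,2)[]". A "null" type yields no cast at all.
func makeTypeCast(t simpleType) (string, error) {
	if t.Name == "" {
		return "", fmt.Errorf("type name is empty")
	}
	if t.Name == "null" {
		return "", nil
	}
	s := "::" + t.Name
	if len(t.Metadata) == 2 {
		s += fmt.Sprintf("(%d,%d)", t.Metadata[0], t.Metadata[1])
	}
	if t.IsArray {
		s += "[]"
	}
	return s, nil
}

func main() {
	c, _ := makeTypeCast(simpleType{Name: "numeric", IsArray: true, Metadata: []uint16{10, 2}})
	fmt.Println("$1" + c) // $1::numeric(10,2)[]
}
```

Appending the cast to a placeholder (as in `$1::int8`) pins the parameter's type so Postgres does not have to infer it.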

Types

type AggregateFunctionDefinition

type AggregateFunctionDefinition struct {
	// ValidateArgsFunc checks the arguments passed to the function (their
	// types and count) and returns the function's result type.
	ValidateArgsFunc func(args []*types.DataType) (*types.DataType, error)
	// PGFormatFunc formats the inputs to the function in Postgres syntax.
	// For example, `sum` formats its input as `sum($1)`, or `sum(DISTINCT $1)`
	// when distinct is true. An empty inputs slice denotes `*`, as in `count(*)`.
	PGFormatFunc func(inputs []string, distinct bool) (string, error)
}

AggregateFunctionDefinition is a definition of an aggregate function.
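The `PGFormatFunc` contract above (pre-rendered input strings, a `distinct` flag, and an empty slice meaning `*`) can be captured by a small formatter factory. `aggFormat` is a hypothetical helper modeled on the `count` entry in the function table, not part of the package.

```go
package main

import (
	"fmt"
	"strings"
)

// aggFormat returns a formatter with the same shape as
// AggregateFunctionDefinition.PGFormatFunc. Zero inputs render as "*",
// and DISTINCT is rejected when there are no inputs.
func aggFormat(name string) func(inputs []string, distinct bool) (string, error) {
	return func(inputs []string, distinct bool) (string, error) {
		if len(inputs) == 0 {
			if distinct {
				return "", fmt.Errorf("%s(DISTINCT *) is not supported", name)
			}
			return name + "(*)", nil
		}
		args := strings.Join(inputs, ", ")
		if distinct {
			// Use Sprintf so the input is substituted, not left as a literal %s.
			return fmt.Sprintf("%s(DISTINCT %s)", name, args), nil
		}
		return fmt.Sprintf("%s(%s)", name, args), nil
	}
}

func main() {
	count := aggFormat("count")
	s, _ := count(nil, false)
	fmt.Println(s) // count(*)
	s, _ = count([]string{"$1"}, true)
	fmt.Println(s) // count(DISTINCT $1)
}
```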

func (*AggregateFunctionDefinition) ValidateArgs

func (a *AggregateFunctionDefinition) ValidateArgs(args []*types.DataType) (*types.DataType, error)

type Column

type Column struct {
	// Name is the name of the column.
	Name string
	// DataType is the data type of the column.
	DataType *types.DataType
	// Nullable is true if the column can be null.
	Nullable bool
	// IsPrimaryKey is true if the column is part of the primary key.
	IsPrimaryKey bool
}

Column is a column in a table.

func (*Column) Copy

func (c *Column) Copy() *Column

type Constraint

type Constraint struct {
	// Type is the type of the constraint.
	Type ConstraintType
	// Columns is a list of column names that the constraint is on.
	Columns []string
}

Constraint is a constraint in the schema.

func (*Constraint) ContainsColumn

func (c *Constraint) ContainsColumn(col string) bool

func (*Constraint) Copy

func (c *Constraint) Copy() *Constraint
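A plausible shape for these two methods is sketched below, assuming column names are already normalized to lower case (so an exact comparison suffices) and that `Copy` must duplicate the `Columns` slice to be a true deep copy. The method bodies are an assumption; only the signatures come from the package.

```go
package main

import "fmt"

type ConstraintType string

const (
	ConstraintUnique ConstraintType = "unique"
	ConstraintCheck  ConstraintType = "check"
	ConstraintFK     ConstraintType = "foreign_key"
)

// Constraint mirrors the exported struct documented above.
type Constraint struct {
	Type    ConstraintType
	Columns []string
}

// ContainsColumn reports whether col is one of the constraint's columns.
func (c *Constraint) ContainsColumn(col string) bool {
	for _, name := range c.Columns {
		if name == col {
			return true
		}
	}
	return false
}

// Copy returns a deep copy: the Columns slice is duplicated so that
// mutating the copy leaves the original untouched.
func (c *Constraint) Copy() *Constraint {
	cols := make([]string, len(c.Columns))
	copy(cols, c.Columns)
	return &Constraint{Type: c.Type, Columns: cols}
}

func main() {
	orig := &Constraint{Type: ConstraintUnique, Columns: []string{"email", "org_id"}}
	cp := orig.Copy()
	cp.Columns[0] = "username"
	fmt.Println(orig.ContainsColumn("email"), cp.ContainsColumn("email")) // true false
}
```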

type ConstraintType

type ConstraintType string
const (
	ConstraintUnique ConstraintType = "unique"
	ConstraintCheck  ConstraintType = "check"
	ConstraintFK     ConstraintType = "foreign_key"
)

type FormatFunc

type FormatFunc func(inputs []string) (string, error)

FormatFunc is a function that formats a slice of input strings into a call to a SQL function.
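The function table uses a `defaultFormat` helper whose body is not shown on this page. A likely minimal form, offered only as a hypothetical reconstruction, renders `name(arg1, arg2, ...)`:

```go
package main

import (
	"fmt"
	"strings"
)

// FormatFunc matches the exported type documented above.
type FormatFunc func(inputs []string) (string, error)

// defaultFormat is a hypothetical reconstruction: it renders the function
// name followed by its comma-separated inputs in parentheses.
func defaultFormat(name string) FormatFunc {
	return func(inputs []string) (string, error) {
		return fmt.Sprintf("%s(%s)", name, strings.Join(inputs, ", ")), nil
	}
}

func main() {
	s, _ := defaultFormat("coalesce")([]string{"$1", "'n/a'"})
	fmt.Println(s) // coalesce($1, 'n/a')
}
```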

type FunctionDefinition

type FunctionDefinition interface {
	// ValidateArgs is a function that checks the arguments passed to the function.
	// It can check the argument type and amount of arguments.
	// It returns the expected return type based on the arguments.
	ValidateArgs(args []*types.DataType) (*types.DataType, error)
	// contains filtered or unexported methods
}

FunctionDefinition is a definition of a function. It has three implementations: ScalarFunctionDefinition, AggregateFunctionDefinition, and WindowFunctionDefinition.
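Because `FunctionDefinition` has unexported methods, only the package's own types can implement it, and code that needs to know which kind of function it holds can use a type switch. The sketch below mimics that pattern with hypothetical local types (`scalarDef`, `aggregateDef`, `windowDef`, `kind`); it does not use the real package types.

```go
package main

import "fmt"

// dataType is a hypothetical stand-in for *types.DataType.
type dataType string

// functionDefinition mirrors the exported interface's ValidateArgs method.
type functionDefinition interface {
	ValidateArgs(args []dataType) (dataType, error)
}

type validateFn func(args []dataType) (dataType, error)

type scalarDef struct{ validate validateFn }
type aggregateDef struct{ validate validateFn }
type windowDef struct{ validate validateFn }

func (d *scalarDef) ValidateArgs(a []dataType) (dataType, error)    { return d.validate(a) }
func (d *aggregateDef) ValidateArgs(a []dataType) (dataType, error) { return d.validate(a) }
func (d *windowDef) ValidateArgs(a []dataType) (dataType, error)    { return d.validate(a) }

// kind classifies a definition; a planner might use such a check to decide
// whether a call needs GROUP BY handling or an OVER clause.
func kind(def functionDefinition) string {
	switch def.(type) {
	case *scalarDef:
		return "scalar"
	case *aggregateDef:
		return "aggregate"
	case *windowDef:
		return "window"
	default:
		return "unknown"
	}
}

func main() {
	returnsInt := func(args []dataType) (dataType, error) { return "int8", nil }
	defs := map[string]functionDefinition{
		"abs":        &scalarDef{returnsInt},
		"count":      &aggregateDef{returnsInt},
		"row_number": &windowDef{returnsInt},
	}
	for _, name := range []string{"abs", "count", "row_number"} {
		fmt.Println(name, kind(defs[name]))
	}
}
```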

type Index

type Index struct {
	Name    string    `json:"name"`
	Columns []string  `json:"columns"`
	Type    IndexType `json:"type"`
}

Index is an index on a table.

func (*Index) ContainsColumn

func (i *Index) ContainsColumn(col string) bool

func (*Index) Copy

func (i *Index) Copy() *Index

type IndexType

type IndexType string

IndexType is a type of index (e.g. BTREE, UNIQUE_BTREE, PRIMARY)

const (
	// BTREE is the default index type.
	BTREE IndexType = "BTREE"
	// UNIQUE_BTREE is a unique BTREE index.
	UNIQUE_BTREE IndexType = "UNIQUE_BTREE"
	// PRIMARY is a primary index.
	// Only one primary index is allowed per table.
	// A primary index cannot exist on a table that also has a primary key.
	PRIMARY IndexType = "PRIMARY"
)

index types

type NamedType

type NamedType struct {
	// Name is the name of the parameter.
	// It should always be lower case.
	// If it is an action parameter, it should begin with a $.
	Name string `json:"name"`
	// Type is the type of the parameter.
	Type *types.DataType `json:"type"`
}

NamedType is a parameter in an action.

type NamespaceRegister

type NamespaceRegister interface {
	Lock()
	Unlock()
	RegisterNamespace(ns string)
	UnregisterAllNamespaces()
}
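`NamespaceRegister` has no doc comment, but its method set suggests a lock-guarded registry. The sketch below assumes callers hold the lock (via `Lock`/`Unlock`) around a batch of register calls, so the mutating methods do not lock internally; `memRegister`, `newMemRegister`, and `reload` are hypothetical names.

```go
package main

import (
	"fmt"
	"sort"
	"sync"
)

// NamespaceRegister matches the exported interface documented above.
type NamespaceRegister interface {
	Lock()
	Unlock()
	RegisterNamespace(ns string)
	UnregisterAllNamespaces()
}

// memRegister is a hypothetical in-memory implementation. The embedded
// mutex supplies Lock and Unlock.
type memRegister struct {
	sync.Mutex
	namespaces map[string]struct{}
}

func newMemRegister() *memRegister {
	return &memRegister{namespaces: make(map[string]struct{})}
}

func (m *memRegister) RegisterNamespace(ns string) { m.namespaces[ns] = struct{}{} }
func (m *memRegister) UnregisterAllNamespaces()    { m.namespaces = make(map[string]struct{}) }

// names returns the registered namespaces in sorted order.
func (m *memRegister) names() []string {
	out := make([]string, 0, len(m.namespaces))
	for ns := range m.namespaces {
		out = append(out, ns)
	}
	sort.Strings(out)
	return out
}

// reload replaces the whole namespace set while holding the lock, so other
// users of the register never observe a half-updated set.
func reload(r NamespaceRegister, namespaces []string) {
	r.Lock()
	defer r.Unlock()
	r.UnregisterAllNamespaces()
	for _, ns := range namespaces {
		r.RegisterNamespace(ns)
	}
}

func main() {
	r := newMemRegister()
	reload(r, []string{"main", "info"})
	reload(r, []string{"main"})
	fmt.Println(r.names()) // [main]
}
```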

type ScalarFunctionDefinition

type ScalarFunctionDefinition struct {
	ValidateArgsFunc func(args []*types.DataType) (*types.DataType, error)
	PGFormatFunc     func(inputs []string) (string, error)
}

ScalarFunctionDefinition is a definition of a scalar function.

func (*ScalarFunctionDefinition) ValidateArgs

func (s *ScalarFunctionDefinition) ValidateArgs(args []*types.DataType) (*types.DataType, error)

type Table

type Table struct {
	// Name is the name of the table.
	Name string
	// Columns is a list of columns in the table.
	Columns []*Column
	// Indexes is a list of indexes on the table.
	Indexes []*Index
	// Constraints are constraints on the table.
	Constraints map[string]*Constraint
}

Table is a table in the schema.

func (*Table) Column

func (t *Table) Column(name string) (*Column, bool)

Column returns a column by name. If the column is not found, the second return value is false.
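A minimal sketch of this lookup, assuming a linear scan with exact name matching over a simplified `Column` (the `DataType` field is omitted here to keep the example self-contained):

```go
package main

import "fmt"

// Column and Table are trimmed-down versions of the exported structs.
type Column struct {
	Name         string
	Nullable     bool
	IsPrimaryKey bool
}

type Table struct {
	Name    string
	Columns []*Column
}

// Column returns the named column and true, or nil and false if the table
// has no such column. Tables are small, so a linear scan is assumed.
func (t *Table) Column(name string) (*Column, bool) {
	for _, c := range t.Columns {
		if c.Name == name {
			return c, true
		}
	}
	return nil, false
}

func main() {
	users := &Table{Name: "users", Columns: []*Column{
		{Name: "id", IsPrimaryKey: true},
		{Name: "email", Nullable: true},
	}}
	if col, ok := users.Column("email"); ok {
		fmt.Println("found", col.Name, "nullable:", col.Nullable)
	}
	_, ok := users.Column("age")
	fmt.Println(ok) // false
}
```

Returning a `(value, ok)` pair follows the comma-ok idiom used by Go map lookups, so callers can distinguish "missing" from a zero value.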

func (*Table) Copy

func (t *Table) Copy() *Table

Copy deep copies the table.

func (*Table) HasPrimaryKey

func (t *Table) HasPrimaryKey(col string) bool

HasPrimaryKey returns true if the column is part of the primary key.

func (*Table) PrimaryKeyCols

func (t *Table) PrimaryKeyCols() []*Column
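`PrimaryKeyCols` is undocumented. One plausible behavior, assumed here, is to return the columns whose `IsPrimaryKey` flag is set, in declaration order, with `HasPrimaryKey` built on top of it. This sketch ignores `PRIMARY` indexes entirely, which the real implementation may not.

```go
package main

import "fmt"

type Column struct {
	Name         string
	IsPrimaryKey bool
}

type Table struct {
	Columns []*Column
}

// PrimaryKeyCols returns the primary-key columns in declaration order.
// Assumption: membership is read from Column.IsPrimaryKey only.
func (t *Table) PrimaryKeyCols() []*Column {
	var pks []*Column
	for _, c := range t.Columns {
		if c.IsPrimaryKey {
			pks = append(pks, c)
		}
	}
	return pks
}

// HasPrimaryKey reports whether col is part of the primary key.
func (t *Table) HasPrimaryKey(col string) bool {
	for _, c := range t.PrimaryKeyCols() {
		if c.Name == col {
			return true
		}
	}
	return false
}

func main() {
	t := &Table{Columns: []*Column{
		{Name: "org_id", IsPrimaryKey: true},
		{Name: "user_id", IsPrimaryKey: true},
		{Name: "email"},
	}}
	for _, c := range t.PrimaryKeyCols() {
		fmt.Println(c.Name)
	}
	fmt.Println(t.HasPrimaryKey("email")) // false
}
```

Keeping declaration order matters for composite keys: `(org_id, user_id)` and `(user_id, org_id)` produce different index orderings.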

func (*Table) SearchConstraint

func (t *Table) SearchConstraint(column string, constraint ConstraintType) []*Constraint

SearchConstraint returns a list of constraints that match the given column and type.
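Since `Table.Constraints` is a map and Go randomizes map iteration order, a search over it has to impose an order if results must be deterministic. The sketch below sorts by constraint name before filtering; that ordering choice is an assumption of this example, not documented package behavior.

```go
package main

import (
	"fmt"
	"sort"
)

type ConstraintType string

const (
	ConstraintUnique ConstraintType = "unique"
	ConstraintCheck  ConstraintType = "check"
)

type Constraint struct {
	Type    ConstraintType
	Columns []string
}

type Table struct {
	Constraints map[string]*Constraint
}

// SearchConstraint returns every constraint of type typ that includes
// column, visiting constraints in sorted name order for determinism.
func (t *Table) SearchConstraint(column string, typ ConstraintType) []*Constraint {
	names := make([]string, 0, len(t.Constraints))
	for name := range t.Constraints {
		names = append(names, name)
	}
	sort.Strings(names)

	var out []*Constraint
	for _, name := range names {
		c := t.Constraints[name]
		if c.Type != typ {
			continue
		}
		for _, col := range c.Columns {
			if col == column {
				out = append(out, c)
				break
			}
		}
	}
	return out
}

func main() {
	t := &Table{Constraints: map[string]*Constraint{
		"uq_email":     {Type: ConstraintUnique, Columns: []string{"email"}},
		"uq_org_email": {Type: ConstraintUnique, Columns: []string{"org_id", "email"}},
		"ck_email":     {Type: ConstraintCheck, Columns: []string{"email"}},
	}}
	fmt.Println(len(t.SearchConstraint("email", ConstraintUnique))) // 2
}
```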

type WindowFunctionDefinition

type WindowFunctionDefinition struct {
	// ValidateArgsFunc checks the arguments passed to the function (their
	// types and count) and returns the function's result type.
	ValidateArgsFunc func(args []*types.DataType) (*types.DataType, error)
	// PGFormatFunc formats the inputs to the function in Postgres syntax.
	// Unlike AggregateFunctionDefinition, it takes no DISTINCT flag.
	PGFormatFunc func(inputs []string) (string, error)
}

WindowFunctionDefinition is a definition of a window function.

func (*WindowFunctionDefinition) ValidateArgs

func (w *WindowFunctionDefinition) ValidateArgs(args []*types.DataType) (*types.DataType, error)

Directories

Path        Synopsis
interpreter Package interpreter provides a basic interpreter for Kuneiform procedures.
parse       Package parse contains logic for parsing SQL, DDL, and actions.
gen
pggenerate  Package pggenerate is responsible for generating Postgres-compatible SQL from the AST.
planner
