engine

package
v0.10.0-rc.1 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Feb 27, 2025 License: Apache-2.0 Imports: 4 Imported by: 0

Documentation

Index

Constants

View Source
// Namespace prefixes that are reserved; see ErrReservedNamespacePrefix.
const (
	// ReservedKwilNamespacePrefix is the prefix that is reserved for namespaces that are not part of the engine
	// directly, but are built in by Kwil via extensions.
	ReservedKwilNamespacePrefix = "kwil_"
	// ReservedPGNamespacePrefix is the "pg_" namespace prefix.
	// NOTE(review): presumably reserved because Postgres uses it for its own objects — confirm.
	ReservedPGNamespacePrefix   = "pg_"
)
View Source
// Well-known namespace and schema names.
const (
	// DefaultNamespace is the namespace name "main".
	DefaultNamespace       = "main"
	// InfoNamespace is the namespace name "info". ErrCannotMutateInfoNamespace
	// reports attempts to mutate it directly.
	InfoNamespace          = "info"
	// InternalEnginePGSchema is the name "kwild_engine".
	// NOTE(review): presumably the Postgres schema that holds the engine's own state — confirm.
	InternalEnginePGSchema = "kwild_engine"
)

Variables

View Source
// Sentinel errors exposed by the engine package. Each is created with
// errors.New; code in this package wraps them with %w (for example ErrType in
// the Functions validators), so callers can match them with errors.Is.
var (
	// Errors that suggest a bug in a user's executing code. These are type errors,
	// issues in arithmetic, array indexing, etc.
	ErrType                    = errors.New("type error")
	ErrReturnShape             = errors.New("unexpected action/function return shape")
	ErrUnknownVariable         = errors.New("unknown variable")
	ErrInvalidVariable         = errors.New("invalid variable name")
	ErrLoop                    = errors.New("loop error")
	ErrArithmetic              = errors.New("arithmetic error")
	ErrComparison              = errors.New("comparison error")
	ErrCast                    = errors.New("type cast error")
	ErrUnary                   = errors.New("unary operation error")
	ErrIndexOutOfBounds        = errors.New("index out of bounds")
	ErrArrayDimensionality     = errors.New("array dimensionality error")
	ErrInvalidNull             = errors.New("invalid null value")
	ErrArrayTooSmall           = errors.New("array too small")
	ErrExtensionImplementation = errors.New("extension implementation error")
	ErrActionInvocation        = errors.New("action invocation error")

	// Errors that signal the existence or non-existence of an object.
	ErrUnknownAction     = errors.New("unknown action")
	ErrUnknownTable      = errors.New("unknown table")
	ErrNamespaceNotFound = errors.New("namespace not found")
	ErrNamespaceExists   = errors.New("namespace already exists")

	// Errors that likely are not the result of a user error, but instead are informing
	// the user of an operation that is not allowed in order to maintain the integrity of
	// the system.
	ErrCannotMutateState          = errors.New("connection is read-only and cannot mutate state")
	ErrIllegalFunctionUsage       = errors.New("illegal function usage")
	ErrQueryActive                = errors.New("a query is currently active. nested queries are not allowed")
	ErrCannotBeNamespaced         = errors.New("the selected object is global-only, and cannot be namespaced")
	ErrCannotMutateExtension      = errors.New("cannot mutate an extension's schema or data directly")
	ErrCannotMutateInfoNamespace  = errors.New(`cannot mutate the "info" namespace directly`)
	ErrCannotDropBuiltinNamespace = errors.New("cannot drop a built-in namespace")
	ErrBuiltInRole                = errors.New("invalid operation on built-in role")
	ErrInvalidTxCtx               = errors.New("invalid transaction context")
	ErrReservedNamespacePrefix    = errors.New("namespace prefix is reserved")
	ErrCannotAlterPrimaryKey      = errors.New("cannot drop or alter a table's primary key")

	// Errors that are the result of not having proper permissions or failing to meet a condition
	// that was programmed by the user.
	ErrActionOwnerOnly      = errors.New("action is owner-only")
	ErrActionPrivate        = errors.New("action is private")
	ErrActionSystemOnly     = errors.New("action is system-only")
	ErrDoesNotHavePrivilege = errors.New("user does not have privilege")

	// Errors that signal an error in a deeper layer, and originate from
	// somewhere deeper than the interpreter.
	ErrParse        = errors.New("parse error")
	ErrQueryPlanner = errors.New("query planner error")
	ErrPGGen        = errors.New("postgres SQL generation error")
)
View Source
var (
	Functions = map[string]FunctionDefinition{
		"abs": &ScalarFunctionDefinition{
			// Accepts exactly one argument that is either the int type or a type
			// named types.NumericStr; the result has the argument's own type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 {
					return nil, wrapErrArgumentNumber(1, n)
				}

				operand := argTypes[0]
				if !(operand.Equals(types.IntType) || operand.Name == types.NumericStr) {
					return nil, fmt.Errorf("%w: expected argument to be int or decimal, got %s", ErrType, operand.String())
				}
				return operand, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("abs"),
		},
		"error": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields the null type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) != 1:
					return nil, wrapErrArgumentNumber(1, len(argTypes))
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				}
				return types.NullType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("error"),
		},
		"parse_unix_timestamp": &ScalarFunctionDefinition{
			// Takes (text, text) and yields decimal16_6.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if len(argTypes) != 2 {
					return nil, wrapErrArgumentNumber(2, len(argTypes))
				}

				// Both positions must be text; the first offender is reported.
				for _, argType := range argTypes {
					if !argType.Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, argType)
					}
				}
				return decimal16_6, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("parse_unix_timestamp"),
		},
		"format_unix_timestamp": &ScalarFunctionDefinition{
			// Takes (decimal16_6, text) and yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) != 2:
					return nil, wrapErrArgumentNumber(2, len(argTypes))
				case !argTypes[0].Equals(decimal16_6):
					return nil, wrapErrArgumentType(decimal16_6, argTypes[0])
				case !argTypes[1].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[1])
				}
				return types.TextType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("format_unix_timestamp"),
		},
		"notice": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields the null type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 {
					return nil, wrapErrArgumentNumber(1, n)
				}
				if message := argTypes[0]; !message.Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, message)
				}
				return types.NullType, nil
			},
			// notice has no SQL form: formatting always fails with
			// ErrIllegalFunctionUsage, regardless of the inputs.
			PGFormatFunc: func(_ []string) (string, error) {
				err := fmt.Errorf("%w: notice cannot be used in SQL statements", ErrIllegalFunctionUsage)
				return "", err
			},
		},
		"uuid_generate_v5": &ScalarFunctionDefinition{
			// Takes (uuid, text) and yields uuid.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if len(argTypes) != 2 {
					return nil, wrapErrArgumentNumber(2, len(argTypes))
				}

				uuidArg, textArg := argTypes[0], argTypes[1]
				if !uuidArg.Equals(types.UUIDType) {
					return nil, wrapErrArgumentType(types.UUIDType, uuidArg)
				}
				if !textArg.Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, textArg)
				}
				return types.UUIDType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("uuid_generate_v5"),
		},

		"uuid_generate_kwil": &ScalarFunctionDefinition{
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {

				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, args[0])
				}

				return types.UUIDType, nil
			},
			PGFormatFunc: func(inputs []string) (string, error) {
				return "uuid_generate_v5('a247cac1-d817-4949-bac7-dc4b1dc41d09'::uuid," + inputs[0] + ")", nil
			},
		},
		"encode": &ScalarFunctionDefinition{
			// Takes (bytea, text) and yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) != 2:
					return nil, wrapErrArgumentNumber(2, len(argTypes))
				case !argTypes[0].Equals(types.ByteaType):
					return nil, wrapErrArgumentType(types.ByteaType, argTypes[0])
				case !argTypes[1].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[1])
				}
				return types.TextType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("encode"),
		},
		"decode": &ScalarFunctionDefinition{
			// Takes (text, text) and yields bytea.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 2 {
					return nil, wrapErrArgumentNumber(2, n)
				}

				// Both positions must be text; the first offender is reported.
				for i := range argTypes {
					if !argTypes[i].Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, argTypes[i])
					}
				}
				return types.ByteaType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("decode"),
		},
		"digest": &ScalarFunctionDefinition{
			// Takes (text or bytea, text) and yields bytea.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if len(argTypes) != 2 {
					return nil, wrapErrArgumentNumber(2, len(argTypes))
				}

				payload, algorithm := argTypes[0], argTypes[1]
				if !(payload.Equals(types.TextType) || payload.Equals(types.ByteaType)) {
					return nil, fmt.Errorf("%w: expected first argument to be text or blob, got %s", ErrType, payload.String())
				}
				if !algorithm.Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, algorithm)
				}
				return types.ByteaType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("digest"),
		},

		"array_append": &ScalarFunctionDefinition{
			// Takes (array, scalar) whose type names match case-insensitively and
			// yields the array's type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if len(argTypes) != 2 {
					return nil, wrapErrArgumentNumber(2, len(argTypes))
				}

				arr, elem := argTypes[0], argTypes[1]
				switch {
				case !arr.IsArray:
					return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, arr.String())
				case elem.IsArray:
					return nil, fmt.Errorf("%w: expected second argument to be a scalar, got %s", ErrType, elem.String())
				case !strings.EqualFold(arr.Name, elem.Name):
					return nil, fmt.Errorf("%w: append type must be equal to scalar array type. array type: %s append type: %s", ErrType, arr.Name, elem.Name)
				}
				return arr, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("array_append"),
		},
		"array_prepend": &ScalarFunctionDefinition{
			// Takes (scalar, array) whose type names match case-insensitively and
			// yields the array's type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if len(argTypes) != 2 {
					return nil, wrapErrArgumentNumber(2, len(argTypes))
				}

				elem, arr := argTypes[0], argTypes[1]
				switch {
				case elem.IsArray:
					return nil, fmt.Errorf("%w: expected first argument to be a scalar, got %s", ErrType, elem.String())
				case !arr.IsArray:
					return nil, fmt.Errorf("%w: expected second argument to be an array, got %s", ErrType, arr.String())
				case !strings.EqualFold(elem.Name, arr.Name):
					return nil, fmt.Errorf("%w: prepend type must be equal to scalar array type. array type: %s prepend type: %s", ErrType, arr.Name, elem.Name)
				}
				return arr, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("array_prepend"),
		},
		"array_cat": &ScalarFunctionDefinition{
			// Takes two arrays whose type names match case-insensitively and
			// yields the first array's type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 2 {
					return nil, wrapErrArgumentNumber(2, n)
				}

				left, right := argTypes[0], argTypes[1]
				if !left.IsArray {
					return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, left.String())
				}
				if !right.IsArray {
					return nil, fmt.Errorf("%w: expected second argument to be an array, got %s", ErrType, right.String())
				}
				if !strings.EqualFold(left.Name, right.Name) {
					return nil, fmt.Errorf("%w: expected both arrays to be of the same scalar type, got %s and %s", ErrType, left.Name, right.Name)
				}
				return left, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("array_cat"),
		},
		"array_length": &ScalarFunctionDefinition{
			// Takes an array and an optional int second argument; yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch n := len(argTypes); {
				case n != 1 && n != 2:
					return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", n)
				case !argTypes[0].IsArray:
					return nil, fmt.Errorf("%w: expected argument to be an array, got %s", ErrType, argTypes[0].String())
				case n == 2 && !argTypes[1].Equals(types.IntType):
					return nil, wrapErrArgumentType(types.IntType, argTypes[1])
				}
				return types.IntType, nil
			},
			// With a single input the second SQL argument is the literal 1;
			// otherwise the second input is passed through.
			PGFormatFunc: func(inputs []string) (string, error) {
				if len(inputs) != 1 {
					return fmt.Sprintf("array_length(%s, %s)", inputs[0], inputs[1]), nil
				}
				return fmt.Sprintf("array_length(%s, 1)", inputs[0]), nil
			},
		},
		"array_remove": &ScalarFunctionDefinition{
			// Takes (array, scalar) whose type names match case-insensitively and
			// yields the array's type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 2 {
					return nil, wrapErrArgumentNumber(2, n)
				}

				arr, target := argTypes[0], argTypes[1]
				if !arr.IsArray {
					return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, arr.String())
				}
				if target.IsArray {
					return nil, fmt.Errorf("%w: expected second argument to be a scalar, got %s", ErrType, target.String())
				}
				if !strings.EqualFold(arr.Name, target.Name) {
					return nil, fmt.Errorf("%w: remove type must be equal to scalar array type. array type: %s remove type: %s", ErrType, arr.Name, target.Name)
				}
				return arr, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("array_remove"),
		},

		"bit_length": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) != 1:
					return nil, wrapErrArgumentNumber(1, len(argTypes))
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				}
				return types.IntType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("bit_length"),
		},
		"char_length": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 {
					return nil, wrapErrArgumentNumber(1, n)
				}
				if text := argTypes[0]; !text.Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, text)
				}
				return types.IntType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("char_length"),
		},
		"character_length": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) != 1:
					return nil, wrapErrArgumentNumber(1, len(argTypes))
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				}
				return types.IntType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("character_length"),
		},
		"length": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 {
					return nil, wrapErrArgumentNumber(1, n)
				}
				if text := argTypes[0]; !text.Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, text)
				}
				return types.IntType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("length"),
		},
		"lower": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) != 1:
					return nil, wrapErrArgumentNumber(1, len(argTypes))
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				}
				return types.TextType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("lower"),
		},
		"lpad": &ScalarFunctionDefinition{
			// Takes (text, int) plus an optional text third argument; yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch n := len(argTypes); {
				case n != 2 && n != 3:
					return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", n)
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				case !argTypes[1].Equals(types.IntType):
					return nil, wrapErrArgumentType(types.IntType, argTypes[1])
				case n == 3 && !argTypes[2].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[2])
				}
				return types.TextType, nil
			},
			// Renders lpad(<0>, <1>::INT4[, <2>]); the second input is always
			// cast to INT4 in the generated SQL.
			PGFormatFunc: func(inputs []string) (string, error) {
				sql := "lpad(" + inputs[0] + ", " + inputs[1] + "::INT4"
				if len(inputs) == 3 {
					sql += ", " + inputs[2]
				}
				return sql + ")", nil
			},
		},
		"ltrim": &ScalarFunctionDefinition{
			// Takes one or two text arguments and yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 && n != 2 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", n)
				}

				// Every supplied argument must be text; the first offender is reported.
				for i := range argTypes {
					if !argTypes[i].Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, argTypes[i])
					}
				}
				return types.TextType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("ltrim"),
		},
		"octet_length": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 {
					return nil, wrapErrArgumentNumber(1, n)
				}
				if text := argTypes[0]; !text.Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, text)
				}
				return types.IntType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("octet_length"),
		},
		"overlay": &ScalarFunctionDefinition{
			// Takes (text, text, int) plus an optional int fourth argument; yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch n := len(argTypes); {
				case n != 3 && n != 4:
					return nil, fmt.Errorf("invalid number of arguments: expected 3 or 4, got %d", n)
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				case !argTypes[1].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[1])
				case !argTypes[2].Equals(types.IntType):
					return nil, wrapErrArgumentType(types.IntType, argTypes[2])
				case n == 4 && !argTypes[3].Equals(types.IntType):
					return nil, wrapErrArgumentType(types.IntType, argTypes[3])
				}
				return types.TextType, nil
			},
			// Renders overlay(<0> placing <1> from <2>::INT4[ for <3>::INT4]);
			// the positional inputs are cast to INT4 in the generated SQL.
			PGFormatFunc: func(inputs []string) (string, error) {
				sql := "overlay(" + inputs[0] + " placing " + inputs[1] + " from " + inputs[2] + "::INT4"
				if len(inputs) == 4 {
					sql += " for " + inputs[3] + "::INT4"
				}
				return sql + ")", nil
			},
		},
		"position": &ScalarFunctionDefinition{
			// Takes (text, text) and yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 2 {
					return nil, wrapErrArgumentNumber(2, n)
				}

				// Both positions must be text; the first offender is reported.
				for i := range argTypes {
					if !argTypes[i].Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, argTypes[i])
					}
				}
				return types.IntType, nil
			},
			// Renders the keyword form position(<0> in <1>).
			PGFormatFunc: func(inputs []string) (string, error) {
				first, second := inputs[0], inputs[1]
				return fmt.Sprintf("position(%s in %s)", first, second), nil
			},
		},
		"rpad": &ScalarFunctionDefinition{
			// Takes (text, int) plus an optional text third argument; yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				n := len(argTypes)
				if n != 2 && n != 3 {
					return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", n)
				}

				if source := argTypes[0]; !source.Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, source)
				}
				if width := argTypes[1]; !width.Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, width)
				}
				if n == 3 && !argTypes[2].Equals(types.TextType) {
					return nil, wrapErrArgumentType(types.TextType, argTypes[2])
				}
				return types.TextType, nil
			},
			// Renders rpad(<0>, <1>::INT4[, <2>]); the second input is always
			// cast to INT4 in the generated SQL.
			PGFormatFunc: func(inputs []string) (string, error) {
				sql := "rpad(" + inputs[0] + ", " + inputs[1] + "::INT4"
				if len(inputs) == 3 {
					sql += ", " + inputs[2]
				}
				return sql + ")", nil
			},
		},
		"rtrim": &ScalarFunctionDefinition{
			// Takes one or two text arguments and yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 && n != 2 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", n)
				}

				// Every supplied argument must be text; the first offender is reported.
				for i := range argTypes {
					if !argTypes[i].Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, argTypes[i])
					}
				}
				return types.TextType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("rtrim"),
		},
		"substring": &ScalarFunctionDefinition{
			// Takes (text, int) plus an optional int third argument; yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch n := len(argTypes); {
				case n != 2 && n != 3:
					return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", n)
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				case !argTypes[1].Equals(types.IntType):
					return nil, wrapErrArgumentType(types.IntType, argTypes[1])
				case n == 3 && !argTypes[2].Equals(types.IntType):
					return nil, wrapErrArgumentType(types.IntType, argTypes[2])
				}
				return types.TextType, nil
			},
			// Renders substring(<0> from <1>::INT4[ for <2>::INT4]); the
			// positional inputs are cast to INT4 in the generated SQL.
			PGFormatFunc: func(inputs []string) (string, error) {
				sql := "substring(" + inputs[0] + " from " + inputs[1] + "::INT4"
				if len(inputs) == 3 {
					sql += " for " + inputs[2] + "::INT4"
				}
				return sql + ")", nil
			},
		},
		"trim": &ScalarFunctionDefinition{
			// Takes one or two text arguments and yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 && n != 2 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", n)
				}

				// Every supplied argument must be text; the first offender is reported.
				for i := range argTypes {
					if !argTypes[i].Equals(types.TextType) {
						return nil, wrapErrArgumentType(types.TextType, argTypes[i])
					}
				}
				return types.TextType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("trim"),
		},
		"upper": &ScalarFunctionDefinition{
			// Takes exactly one text argument and yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) != 1:
					return nil, wrapErrArgumentNumber(1, len(argTypes))
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				}
				return types.TextType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("upper"),
		},
		"format": &ScalarFunctionDefinition{
			// Requires at least one argument, the first of which must be text;
			// any further arguments are not type-checked here. Yields text.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) < 1:
					return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got %d", len(argTypes))
				case !argTypes[0].Equals(types.TextType):
					return nil, wrapErrArgumentType(types.TextType, argTypes[0])
				}
				return types.TextType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("format"),
		},
		"coalesce": &ScalarFunctionDefinition{
			// Requires at least one argument and that every argument equals the
			// type of the first; yields that type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if len(argTypes) == 0 {
					return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got %d", len(argTypes))
				}

				reference := argTypes[0]
				for pos := 0; pos < len(argTypes); pos++ {
					// Positions in the error message are 1-based.
					if candidate := argTypes[pos]; !reference.Equals(candidate) {
						return nil, fmt.Errorf("%w: all arguments must be the same type, but argument %d is %s and argument 1 is %s", ErrType, pos+1, candidate.String(), reference.String())
					}
				}
				return reference, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("coalesce"),
		},

		"count": &AggregateFunctionDefinition{
			// Accepts zero or one argument of any type and yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n > 1 {
					return nil, fmt.Errorf("invalid number of arguments: expected at most 1, got %d", n)
				}
				return types.IntType, nil
			},
			// No inputs renders count(*), which cannot be combined with DISTINCT;
			// otherwise the first input is counted, optionally with DISTINCT.
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				star := len(inputs) == 0
				switch {
				case star && distinct:
					return "", fmt.Errorf("count(DISTINCT *) is not supported")
				case star:
					return "count(*)", nil
				case distinct:
					return fmt.Sprintf("count(DISTINCT %s)", inputs[0]), nil
				default:
					return fmt.Sprintf("count(%s)", inputs[0]), nil
				}
			},
		},
		"sum": &AggregateFunctionDefinition{
			// Accepts exactly one numeric argument. An int argument yields a copy of
			// decimal1000; an argument named types.NumericStr yields a copy of the
			// argument type with its first metadata entry set to 1000.
			// NOTE(review): Metadata[0] is presumably the precision — confirm.
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].IsNumeric() {
					return nil, fmt.Errorf("%w: expected argument to be numeric, got %s", ErrType, args[0].String())
				}

				var retType *types.DataType
				switch {
				case args[0].Equals(types.IntType):
					retType = decimal1000.Copy()
				case args[0].Name == types.NumericStr:
					retType = args[0].Copy()
					retType.Metadata[0] = 1000
				default:
					// retType is still nil on this path, so the message has to be
					// built from the argument itself rather than from retType.
					panic(fmt.Sprintf("unexpected numeric type: %s", args[0].String()))
				}

				return retType, nil
			},
			// Renders sum(<0>) or sum(DISTINCT <0>). Both forms must interpolate
			// the input; returning the bare format string would put a literal
			// "%s" into the generated SQL.
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("sum(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("sum(%s)", inputs[0]), nil
			},
		},
		"min": &AggregateFunctionDefinition{
			// Accepts exactly one numeric or text argument and yields the
			// argument's own type.
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].IsNumeric() && !args[0].Equals(types.TextType) {
					return nil, fmt.Errorf("%w: expected argument to be numeric or text, got %s", ErrType, args[0].String())
				}

				return args[0], nil
			},
			// Renders min(<0>) or min(DISTINCT <0>). Both forms must interpolate
			// the input; returning the bare format string would put a literal
			// "%s" into the generated SQL.
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("min(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("min(%s)", inputs[0]), nil
			},
		},
		"max": &AggregateFunctionDefinition{
			// Accepts exactly one numeric or text argument and yields the
			// argument's own type.
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !args[0].IsNumeric() && !args[0].Equals(types.TextType) {
					return nil, fmt.Errorf("%w: expected argument to be numeric or text, got %s", ErrType, args[0].String())
				}

				return args[0], nil
			},
			// Renders max(<0>) or max(DISTINCT <0>). Both forms must interpolate
			// the input; returning the bare format string would put a literal
			// "%s" into the generated SQL.
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("max(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("max(%s)", inputs[0]), nil
			},
		},
		"array_agg": &AggregateFunctionDefinition{
			// Accepts exactly one scalar argument and yields a copy of its type
			// with IsArray set.
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if args[0].IsArray {
					return nil, fmt.Errorf("%w: expected argument to be a scalar, got %s", ErrType, args[0].String())
				}

				// Copy before flagging as an array so the caller's type is not mutated.
				a2 := args[0].Copy()
				a2.IsArray = true
				return a2, nil
			},
			// Renders array_agg(<0> ORDER BY <0>), or array_agg(DISTINCT <0>) when
			// distinct is set. The DISTINCT form must interpolate the input;
			// returning the bare format string would put a literal "%s" into the
			// generated SQL.
			// NOTE(review): the DISTINCT form carries no ORDER BY, unlike the
			// plain form — confirm whether element order matters to callers.
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("array_agg(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("array_agg(%s ORDER BY %s)", inputs[0], inputs[0]), nil
			},
		},
		"avg": &AggregateFunctionDefinition{
			// Accepts exactly one argument whose type name matches
			// types.NumericStr case-insensitively, and yields the argument's type.
			ValidateArgsFunc: func(args []*types.DataType) (*types.DataType, error) {
				if len(args) != 1 {
					return nil, wrapErrArgumentNumber(1, len(args))
				}

				if !strings.EqualFold(args[0].Name, types.NumericStr) {
					return nil, fmt.Errorf("%w: expected argument to be numeric, got %s", ErrType, args[0].String())
				}

				return args[0], nil
			},
			// Renders avg(<0>) or avg(DISTINCT <0>). Both forms must interpolate
			// the input; returning the bare format string would put a literal
			// "%s" into the generated SQL.
			PGFormatFunc: func(inputs []string, distinct bool) (string, error) {
				if distinct {
					return fmt.Sprintf("avg(DISTINCT %s)", inputs[0]), nil
				}

				return fmt.Sprintf("avg(%s)", inputs[0]), nil
			},
		},

		"lag": &WindowFunctionDefinition{
			// Takes one to three arguments: a value of any type, an optional int,
			// and an optional default that must equal the value's type. Yields the
			// value's type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				n := len(argTypes)
				if n < 1 || n > 3 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1-3, got %d", n)
				}

				if n >= 2 && !argTypes[1].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, argTypes[1])
				}
				if n == 3 && !argTypes[2].Equals(argTypes[0]) {
					return nil, fmt.Errorf("%w: expected default value to be the same type as the value expression: %s != %s", ErrType, argTypes[0].String(), argTypes[2].String())
				}
				return argTypes[0], nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("lag"),
		},
		"lead": &WindowFunctionDefinition{
			// Takes one to three arguments: a value of any type, an optional int,
			// and an optional default that must equal the value's type. Yields the
			// value's type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				n := len(argTypes)
				if n < 1 || n > 3 {
					return nil, fmt.Errorf("invalid number of arguments: expected 1-3, got %d", n)
				}

				if n >= 2 && !argTypes[1].Equals(types.IntType) {
					return nil, wrapErrArgumentType(types.IntType, argTypes[1])
				}
				if n == 3 && !argTypes[2].Equals(argTypes[0]) {
					return nil, fmt.Errorf("%w: expected default value to be the same type as the value expression: %s != %s", ErrType, argTypes[0].String(), argTypes[2].String())
				}
				return argTypes[0], nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("lead"),
		},
		"first_value": &WindowFunctionDefinition{
			// Takes exactly one argument of any type and yields that type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 {
					return nil, wrapErrArgumentNumber(1, n)
				}
				return argTypes[0], nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("first_value"),
		},
		"last_value": &WindowFunctionDefinition{
			// Takes exactly one argument of any type and yields that type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 1 {
					return nil, wrapErrArgumentNumber(1, n)
				}
				return argTypes[0], nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("last_value"),
		},
		"nth_value": &WindowFunctionDefinition{
			// Takes (any, int) and yields the first argument's type.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				switch {
				case len(argTypes) != 2:
					return nil, wrapErrArgumentNumber(2, len(argTypes))
				case !argTypes[1].Equals(types.IntType):
					return nil, wrapErrArgumentType(types.IntType, argTypes[1])
				}
				return argTypes[0], nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("nth_value"),
		},
		"row_number": &WindowFunctionDefinition{
			// Takes no arguments and yields int.
			ValidateArgsFunc: func(argTypes []*types.DataType) (*types.DataType, error) {
				if n := len(argTypes); n != 0 {
					return nil, wrapErrArgumentNumber(0, n)
				}
				return types.IntType, nil
			},
			// SQL rendering is delegated to defaultFormat.
			PGFormatFunc: defaultFormat("row_number"),
		},
	}
)

Functions

func MakeTypeCast

func MakeTypeCast(d *types.DataType) (string, error)

MakeTypeCast returns the string that type casts a value to the given type. It should be used when generating SQL. If the type is null, no type cast is returned.

Types

type AggregateFunctionDefinition

// AggregateFunctionDefinition is a definition of an aggregate function.
type AggregateFunctionDefinition struct {
	// ValidateArgsFunc is a function that checks the arguments passed to the function.
	// It can check the argument type and amount of arguments.
	// It returns a data type together with an error.
	ValidateArgsFunc func(args []*types.DataType) (*types.DataType, error)
	// PGFormatFunc is a function that formats the inputs to the function in Postgres format.
	// For example, the function `sum` would format the inputs as `sum($1)`.
	// It can also format the inputs with DISTINCT. If no inputs are given, it is a *.
	PGFormatFunc func(inputs []string, distinct bool) (string, error)
}

AggregateFunctionDefinition is a definition of an aggregate function.

func (*AggregateFunctionDefinition) ValidateArgs

func (a *AggregateFunctionDefinition) ValidateArgs(args []*types.DataType) (*types.DataType, error)

type Column

// Column is a column in a table.
type Column struct {
	// Name is the name of the column.
	Name string
	// DataType is the data type of the column.
	DataType *types.DataType
	// Nullable is true if the column can be null.
	Nullable bool
	// IsPrimaryKey is true if the column is part of the primary key.
	IsPrimaryKey bool
}

Column is a column in a table.

func (*Column) Copy

func (c *Column) Copy() *Column

type Constraint

// Constraint is a constraint in the schema.
type Constraint struct {
	// Type is the type of the constraint.
	Type ConstraintType
	// Columns is a list of column names that the constraint is on.
	Columns []string
}

Constraint is a constraint in the schema.

func (*Constraint) ContainsColumn

func (c *Constraint) ContainsColumn(col string) bool

func (*Constraint) Copy

func (c *Constraint) Copy() *Constraint

type ConstraintType

// ConstraintType is the string-backed kind of a Constraint; it is the type of
// the Constraint.Type field.
type ConstraintType string
const (
	// ConstraintUnique is the "unique" constraint type.
	ConstraintUnique ConstraintType = "unique"
	// ConstraintCheck is the "check" constraint type.
	ConstraintCheck  ConstraintType = "check"
	// ConstraintFK is the "foreign_key" constraint type.
	ConstraintFK     ConstraintType = "foreign_key"
)

type FormatFunc

// FormatFunc is a function that formats a slice of input strings for a SQL
// function, returning the formatted string or an error.
type FormatFunc func(inputs []string) (string, error)

FormatFunc is a function that formats a slice of input strings for a SQL function.

type FunctionDefinition

// FunctionDefinition is a definition of a function.
type FunctionDefinition interface {
	// ValidateArgs is a function that checks the arguments passed to the function.
	// It can check the argument type and amount of arguments.
	// It returns the expected return type based on the arguments.
	ValidateArgs(args []*types.DataType) (*types.DataType, error)
	// contains filtered or unexported methods
}

FunctionDefinition is a definition of a function. Its implementations include ScalarFunctionDefinition and AggregateFunctionDefinition.

type Index

// Index is an index on a table.
//
// Its fields serialize to the JSON keys "name", "columns", and "type".
// NOTE(review): Name is presumably the index's name and Columns the names of
// the indexed columns (see ContainsColumn) — confirm. Type is the IndexType
// of the index.
type Index struct {
	Name    string    `json:"name"`
	Columns []string  `json:"columns"`
	Type    IndexType `json:"type"`
}

Index is an index on a table.

func (*Index) ContainsColumn

func (i *Index) ContainsColumn(col string) bool

func (*Index) Copy

func (i *Index) Copy() *Index

type IndexType

// IndexType is a type of index (e.g. BTREE, UNIQUE_BTREE, PRIMARY).
type IndexType string

IndexType is a type of index (e.g. BTREE, UNIQUE_BTREE, PRIMARY)

// IndexType values declared by this package.
const (
	// BTREE is the default index type.
	BTREE IndexType = "BTREE"
	// UNIQUE_BTREE is a unique BTREE index.
	UNIQUE_BTREE IndexType = "UNIQUE_BTREE"
	// PRIMARY is a primary index.
	// Only one primary index is allowed per table.
	// A primary index cannot exist on a table that also has a primary key.
	PRIMARY IndexType = "PRIMARY"
)

index types

type NamedType

// NamedType is a parameter in an action: a name paired with a data type.
type NamedType struct {
	// Name is the name of the parameter.
	// It should always be lower case.
	// If it is an action parameter, it should begin with a $.
	Name string `json:"name"`
	// Type is the type of the parameter.
	Type *types.DataType `json:"type"`
}

NamedType is a parameter in an action.

type NamespaceRegister

// NamespaceRegister is an interface exposing namespace registration methods
// together with Lock and Unlock.
// NOTE(review): the per-method descriptions below are inferred from the
// method names only; confirm them against the implementation.
type NamespaceRegister interface {
	// Lock presumably acquires exclusive access to the register — confirm.
	Lock()
	// Unlock presumably releases the access acquired by Lock — confirm.
	Unlock()
	// RegisterNamespace presumably registers the namespace named ns — confirm.
	RegisterNamespace(ns string)
	// UnregisterAllNamespaces presumably removes every registered namespace — confirm.
	UnregisterAllNamespaces()
}

type ScalarFunctionDefinition

// ScalarFunctionDefinition is a definition of a scalar function.
//
// ValidateArgsFunc receives the data types of the arguments and returns a
// data type or an error.
// PGFormatFunc receives the inputs as strings and returns a single string or
// an error; unlike AggregateFunctionDefinition.PGFormatFunc it takes no
// distinct flag.
// NOTE(review): presumably these fields play the same roles as the
// identically named fields on AggregateFunctionDefinition (argument
// validation and Postgres formatting) — confirm.
type ScalarFunctionDefinition struct {
	ValidateArgsFunc func(args []*types.DataType) (*types.DataType, error)
	PGFormatFunc     func(inputs []string) (string, error)
}

ScalarFunctionDefinition is a definition of a scalar function.

func (*ScalarFunctionDefinition) ValidateArgs

func (s *ScalarFunctionDefinition) ValidateArgs(args []*types.DataType) (*types.DataType, error)

type Table

// Table is a table in the schema.
type Table struct {
	// Name is the name of the table.
	Name string
	// Columns is a list of columns in the table.
	Columns []*Column
	// Indexes is a list of indexes on the table.
	Indexes []*Index
	// Constraints are constraints on the table.
	// NOTE(review): the map key is presumably the constraint's name — confirm.
	Constraints map[string]*Constraint
}

Table is a table in the schema.

func (*Table) Column

func (t *Table) Column(name string) (*Column, bool)

Column returns a column by name. If the column is not found, the second return value is false.

func (*Table) Copy

func (t *Table) Copy() *Table

Copy deep copies the table.

func (*Table) HasPrimaryKey

func (t *Table) HasPrimaryKey(col string) bool

HasPrimaryKey returns true if the column is part of the primary key.

func (*Table) PrimaryKeyCols

func (t *Table) PrimaryKeyCols() []*Column

func (*Table) SearchConstraint

func (t *Table) SearchConstraint(column string, constraint ConstraintType) []*Constraint

SearchConstraint returns a list of constraints that match the given column and type.

type WindowFunctionDefinition

// WindowFunctionDefinition holds the argument-validation and formatting
// functions for a function.
// NOTE(review): by its name this is the definition of a window function — confirm.
type WindowFunctionDefinition struct {
	// ValidateArgsFunc is a function that checks the arguments passed to the function.
	// It can check the argument types and the number of arguments.
	ValidateArgsFunc func(args []*types.DataType) (*types.DataType, error)
	// PGFormatFunc is a function that formats the inputs to the function in Postgres format.
	// For example, the function `sum` would format the inputs as `sum($1)`.
	// NOTE(review): the previous comment also said "It can also format the
	// inputs with DISTINCT. If no inputs are given, it is a *.", which looks
	// copied from AggregateFunctionDefinition. This signature has no distinct
	// parameter, so DISTINCT cannot be requested here; confirm how empty
	// inputs are formatted.
	PGFormatFunc func(inputs []string) (string, error)
}

func (*WindowFunctionDefinition) ValidateArgs

func (w *WindowFunctionDefinition) ValidateArgs(args []*types.DataType) (*types.DataType, error)

Directories

Path Synopsis
package interpreter provides a basic interpreter for Kuneiform procedures.
package interpreter provides a basic interpreter for Kuneiform procedures.
package parse contains logic for parsing SQL, DDL, and Actions.
package parse contains logic for parsing SQL, DDL, and Actions.
gen
pggenerate package is responsible for generating the Postgres-compatible SQL from the AST.
pggenerate package is responsible for generating the Postgres-compatible SQL from the AST.
planner

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL