Skip to content

HHH-18973, HHH-19679 hibernate-vector module enhancements #10685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@ The Hibernate ORM Vector module contains support for mathematical vector types a
This is useful for AI/ML topics like vector similarity search and Retrieval-Augmented Generation (RAG).
The module comes with support for a special `vector` data type that essentially represents an array of bytes, floats, or doubles.

So far, both the PostgreSQL extension `pgvector` and the Oracle database 23ai+ `AI Vector Search` feature are supported, but in theory,
the vector specific functions could be implemented to work with every database that supports arrays.
Currently, the following databases are supported:

For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation] or the https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[AI Vector Search documentation].
* PostgreSQL 13+ through the https://github.com/pgvector/pgvector#querying[`pgvector` extension]
* https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[Oracle database 23ai+]
* https://mariadb.com/docs/server/reference/sql-structure/vectors/vector-overview[MariaDB 11.7+]
* https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html[MySQL 9.0+]

In theory, the vector-specific functions could be implemented to work with every database that supports arrays.

[WARNING]
====
Per the https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html#function_distance[MySQL documentation],
the various vector distance functions for MySQL only work on MySQL cloud offerings like
https://dev.mysql.com/doc/heatwave/en/mys-hw-about-heatwave.html[HeatWave MySQL on OCI].
====

[[vector-module-setup]]
=== Setup
Expand Down Expand Up @@ -57,7 +68,7 @@ As Oracle AI Vector Search supports different types of elements (to ensure bette
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=usage-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=usage-example]
----
====

Expand Down Expand Up @@ -113,7 +124,7 @@ which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 )
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=cosine-distance-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=cosine-distance-example]
----
====

Expand All @@ -128,7 +139,7 @@ The `l2_distance()` function is an alias.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=euclidean-distance-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-distance-example]
----
====

Expand All @@ -143,7 +154,7 @@ The `l1_distance()` function is an alias.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=taxicab-distance-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=taxicab-distance-example]
----
====

Expand All @@ -158,7 +169,7 @@ and the `inner_product()` function as well, but multiplies the result time `-1`.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=inner-product-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=inner-product-example]
----
====

Expand All @@ -171,7 +182,7 @@ Determines the dimensions of a vector.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=vector-dims-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-dims-example]
----
====

Expand All @@ -185,7 +196,7 @@ which is `sqrt( sum( v_i^2 ) )`.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=vector-norm-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-norm-example]
----
====

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ public class OracleTypes {
public static final int VECTOR_INT8 = -106;
public static final int VECTOR_FLOAT32 = -107;
public static final int VECTOR_FLOAT64 = -108;
public static final int VECTOR_BINARY = -109;
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,14 @@ public void render(
renderCastArrayToString( sqlAppender, arguments.get( 0 ), dialect, walker );
}
else {
new PatternRenderer( dialect.castPattern( sourceType, targetType ) )
.render( sqlAppender, arguments, walker );
String castPattern = targetJdbcMapping.getJdbcType().castFromPattern( sourceMapping );
if ( castPattern == null ) {
castPattern = sourceMapping.getJdbcType().castToPattern( targetJdbcMapping );
if ( castPattern == null ) {
castPattern = dialect.castPattern( sourceType, targetType );
}
}
new PatternRenderer( castPattern ).render( sqlAppender, arguments, walker );
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ public ReturnableType<?> resolveFunctionReturnType(
case NUMERIC:
return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? bigIntegerType : bigDecimalType;
case VECTOR:
case VECTOR_BINARY:
case VECTOR_INT8:
case VECTOR_FLOAT16:
case VECTOR_FLOAT32:
case VECTOR_FLOAT64:
case SPARSE_VECTOR_INT8:
case SPARSE_VECTOR_FLOAT32:
case SPARSE_VECTOR_FLOAT64:
return basicType;
}
return bigDecimalType;
Expand Down Expand Up @@ -123,6 +131,14 @@ public BasicValuedMapping resolveFunctionReturnType(
final Class<?> argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass();
return BigInteger.class.isAssignableFrom( argTypeClass ) ? bigIntegerType : bigDecimalType;
case VECTOR:
case VECTOR_BINARY:
case VECTOR_INT8:
case VECTOR_FLOAT16:
case VECTOR_FLOAT32:
case VECTOR_FLOAT64:
case SPARSE_VECTOR_INT8:
case SPARSE_VECTOR_FLOAT32:
case SPARSE_VECTOR_FLOAT64:
return (BasicValuedMapping) jdbcMapping;
}
return bigDecimalType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.sql.SQLException;
import java.sql.Types;
import java.util.Locale;
import java.util.Objects;

import org.hibernate.HibernateException;
import org.hibernate.boot.model.relational.Database;
Expand Down Expand Up @@ -288,4 +289,16 @@ public String getFriendlyName() {
public String toString() {
return "OracleArrayTypeDescriptor(" + typeName + ")";
}

@Override
public boolean equals(Object that) {
return super.equals( that )
&& that instanceof OracleArrayJdbcType jdbcType
&& Objects.equals( typeName, jdbcType.typeName );
}

@Override
public int hashCode() {
return Objects.hashCode( typeName ) + 31 * super.hashCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ public BasicCollectionType(
this.name = determineName( collectionTypeDescriptor, baseDescriptor );
}

public BasicCollectionType(
BasicType<E> baseDescriptor,
JdbcType arrayJdbcType,
JavaType<C> collectionTypeDescriptor,
String typeName) {
super( arrayJdbcType, collectionTypeDescriptor );
this.baseDescriptor = baseDescriptor;
this.name = typeName;
}

private static String determineName(BasicCollectionJavaType<?, ?> collectionTypeDescriptor, BasicType<?> baseDescriptor) {
final String elementTypeName = determineElementTypeName( baseDescriptor );
switch ( collectionTypeDescriptor.getSemantics().getCollectionClassification() ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import org.hibernate.internal.CoreMessageLogger;
import org.hibernate.internal.util.StringHelper;
import org.hibernate.internal.util.collections.CollectionHelper;
import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation;
import org.hibernate.type.descriptor.converter.spi.BasicValueConverter;
import org.hibernate.type.descriptor.java.BasicPluralJavaType;
import org.hibernate.type.descriptor.java.ImmutableMutabilityPlan;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.DelegatingJdbcTypeIndicators;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.internal.BasicTypeImpl;
import org.hibernate.type.internal.ConvertedBasicTypeImpl;
Expand Down Expand Up @@ -166,8 +168,48 @@ private <E> BasicType<?> resolvedType(ArrayJdbcType arrayType, BasicPluralJavaTy
typeConfiguration,
typeConfiguration.getCurrentBaseSqlTypeIndicators().getDialect(),
elementType,
null,
typeConfiguration.getCurrentBaseSqlTypeIndicators()
new ColumnTypeInformation() {
@Override
public Boolean getNullable() {
return null;
}

@Override
public int getTypeCode() {
return arrayType.getDefaultSqlTypeCode();
}

@Override
public String getTypeName() {
return null;
}

@Override
public int getColumnSize() {
return 0;
}

@Override
public int getDecimalDigits() {
return 0;
}
},
new DelegatingJdbcTypeIndicators( typeConfiguration.getCurrentBaseSqlTypeIndicators() ) {
@Override
public Integer getExplicitJdbcTypeCode() {
return arrayType.getDefaultSqlTypeCode();
}

@Override
public int getPreferredSqlTypeCodeForArray() {
return arrayType.getDefaultSqlTypeCode();
}

@Override
public int getPreferredSqlTypeCodeForArray(int elementSqlTypeCode) {
return arrayType.getDefaultSqlTypeCode();
}
}
);
if ( resolvedType instanceof BasicPluralType<?,?> ) {
register( resolvedType );
Expand Down
35 changes: 32 additions & 3 deletions hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -681,10 +681,10 @@ public class SqlTypes {


/**
* A type code representing an {@code embedding vector} type for databases
* A type code representing a {@code vector} type for databases
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL},
* {@link org.hibernate.dialect.OracleDialect Oracle 23ai} and {@link org.hibernate.dialect.MariaDBDialect MariaDB}.
* An embedding vector essentially is a {@code float[]} with a fixed size.
* A vector essentially is a {@code float[]} with a fixed length.
*
* @since 6.4
*/
Expand All @@ -701,10 +701,39 @@ public class SqlTypes {
public static final int VECTOR_FLOAT32 = 10_002;

/**
* A type code representing a double-precision floating-point type for Oracle 23ai database.
* A type code representing a double-precision floating-point vector type for Oracle 23ai database.
*/
public static final int VECTOR_FLOAT64 = 10_003;

/**
* A type code representing a bit precision vector type for databases
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and
* {@link org.hibernate.dialect.OracleDialect Oracle 23ai}.
*/
public static final int VECTOR_BINARY = 10_004;

/**
* A type code representing a half-precision floating-point vector type for databases
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL}.
*/
public static final int VECTOR_FLOAT16 = 10_005;

/**
* A type code representing a sparse single-byte integer vector type for Oracle 23ai database.
*/
public static final int SPARSE_VECTOR_INT8 = 10_006;

/**
* A type code representing a sparse single-precision floating-point vector type for Oracle 23ai database.
*/
public static final int SPARSE_VECTOR_FLOAT32 = 10_007;

/**
* A type code representing a sparse double-precision floating-point vector type for Oracle 23ai database.
*/
public static final int SPARSE_VECTOR_FLOAT64 = 10_008;


private SqlTypes() {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,14 @@ private StandardBasicTypes() {
"byte_vector", byte[].class, SqlTypes.VECTOR_INT8
);

/**
* The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_FLOAT16 VECTOR_FLOAT16},
* specifically for embedding half-precision floating-point (16-bits) vectors like provided by the PostgreSQL extension pgvector.
*/
public static final BasicTypeReference<float[]> VECTOR_FLOAT16 = new BasicTypeReference<>(
"float16_vector", float[].class, SqlTypes.VECTOR_FLOAT16
);

/**
* The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
* specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai.
Expand All @@ -765,6 +773,38 @@ private StandardBasicTypes() {
"double_vector", double[].class, SqlTypes.VECTOR_FLOAT64
);

/**
* The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_BINARY VECTOR_BIT},
* specifically for embedding bit vectors like provided by Oracle 23ai.
*/
public static final BasicTypeReference<byte[]> VECTOR_BINARY = new BasicTypeReference<>(
"binary_vector", byte[].class, SqlTypes.VECTOR_BINARY
);

// /**
// * The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_INT8 VECTOR_INT8},
// * specifically for embedding integer vectors (8-bits) like provided by Oracle 23ai.
// */
// public static final BasicTypeReference<byte[]> SPARSE_VECTOR_INT8 = new BasicTypeReference<>(
// "sparse_byte_vector", byte[].class, SqlTypes.SPARSE_VECTOR_INT8
// );
//
// /**
// * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
// * specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai.
// */
// public static final BasicTypeReference<float[]> SPARSE_VECTOR_FLOAT32 = new BasicTypeReference<>(
// "sparse_float_vector", float[].class, SqlTypes.SPARSE_VECTOR_FLOAT32
// );
//
// /**
// * The standard Hibernate type for mapping {@code double[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
// * specifically for embedding double-precision floating-point (64-bits) vectors like provided by Oracle 23ai.
// */
// public static final BasicTypeReference<double[]> SPARSE_VECTOR_FLOAT64 = new BasicTypeReference<>(
// "sparse_double_vector", double[].class, SqlTypes.SPARSE_VECTOR_FLOAT64
// );


public static void prime(TypeConfiguration typeConfiguration) {
BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
Expand Down Expand Up @@ -1286,6 +1326,34 @@ public static void prime(TypeConfiguration typeConfiguration) {
"byte_vector"
);

handle(
VECTOR_BINARY,
null,
basicTypeRegistry,
"bit_vector"
);

// handle(
// SPARSE_VECTOR_FLOAT32,
// null,
// basicTypeRegistry,
// "sparse_float_vector"
// );
//
// handle(
// SPARSE_VECTOR_FLOAT64,
// null,
// basicTypeRegistry,
// "sparse_double_vector"
// );
//
// handle(
// SPARSE_VECTOR_INT8,
// null,
// basicTypeRegistry,
// "sparse_byte_vector"
// );


// Specialized version handlers

Expand Down
Loading