DH3RVE6JRQBOSRUBVIIMKSJYFS4CK5PI25KV2NAJ7EP4Q7NNKSJQC
KZX6RVSGAYRMCPKHROF3L2CHELFYLEAAN4Z7HN4MW2AAEIJJGBCQC
IFWHYAZTPKL7ZVOV3AFZNIUTRLMVSOACYACA3CI4DSEEZU77YQEQC
LECAQCCL2UKTMH3LQSGEFVBOF6CYCGJGJZM2I2T4R7QPTBVW4VYQC
UV2RTMNSNI3AJFVVD66OT3CEF7WU7TLAPJKACGX7BBSSWAEBAAAAC
SIHQ3OG5GIBQLSMOBKSD2IHFTGFXU4VDCX45BH2ABS42BV3N2UIQC
G5C3RSELXRNXJFQQ6ZVCXLQE4AJNXYB32RULKHWLWF5SGEFXYTWQC
/// Smoke test for `l2_dist` over two pgx `Array` views built directly from raw datums.
///
/// The input is `sin(x)` for `x` in `0..32`; two identical copies are wrapped with
/// `Array::over`, their contents are printed, and the distance between them is printed.
///
/// NOTE(review): the trailing `assert!(false)` makes this test always fail — presumably so
/// the `println!` output shows up in the test log. TODO confirm before relying on it in CI.
#[pg_test]
fn test_hello_pgvector_rs() {
// Reference values: sin(0), sin(1), ..., sin(31) as f32.
let arr0 = (0..32)
.map(|x| x as f32)
.map(|x| x.sin())
.collect::<Vec<_>>();
// Convert every f32 to a Datum; `into_datum` returns an Option, unwrapped here.
let mut arr = arr0
.iter()
.map(|x| x.into_datum())
.map(|x| x.unwrap())
.collect::<Vec<pg_sys::Datum>>();
// No element is NULL.
let mut nulls = arr.iter().map(|_| false).collect::<Vec<bool>>();
// Separate backing buffers for the second Array view.
let mut arr_2 = arr.clone();
let mut nulls_2 = nulls.clone();
println!("{:?}", arr);
// SAFETY(review): `Array::over` receives raw pointers into `arr` and `nulls`, so both Vecs
// must stay alive and must not be moved or resized while `vec1` is in use. Neither is touched
// again in this function. TODO confirm against `Array::over`'s documented contract.
let vec1 = unsafe {
let size = arr.len();
let b = arr.as_mut_ptr();
println!("{:?}", b);
Array::over(b, nulls.as_mut_ptr(), size)
};
// Print every element of the first view.
vec1.iter().for_each(|x| {
println!("{:?}", x);
});
// SAFETY(review): same requirement as above, for `arr_2` / `nulls_2` backing `vec2`.
let vec2 = unsafe {
let size = arr_2.len();
Array::over(arr_2.as_mut_ptr(), nulls_2.as_mut_ptr(), size)
};
// Difference between each element of `vec2` and the source f32; both come from the same
// values, so these are expected to print as 0.
vec2.iter().zip(arr0.iter()).for_each(|(x, y)| {
println!("{:?}", x.unwrap() - y);
});
// Distance between two views built from identical data.
let d = l2_dist(vec1, vec2);
println!("distance: {}", d);
assert!(false);
}
// #[pg_test]
// fn test_hello_pgvector_rs() {
// let arr0 = (0..32)
// .map(|x| x as f32)
// .map(|x| x.sin())
// .collect::<Vec<_>>();
// let mut arr = arr0
// .iter()
// .map(|x| x.into_datum())
// .map(|x| x.unwrap())
// .collect::<Vec<pg_sys::Datum>>();
// let mut nulls = arr.iter().map(|_| false).collect::<Vec<bool>>();
// let mut arr_2 = arr.clone();
// let mut nulls_2 = nulls.clone();
// let vec1 = unsafe {
// let size = arr.len();
// let b = arr.as_mut_ptr();
// Array::over(b, nulls.as_mut_ptr(), size)
// };
// vec1.iter().for_each(|x| {});
// let vec2 = unsafe {
// let size = arr_2.len();
// Array::over(arr_2.as_mut_ptr(), nulls_2.as_mut_ptr(), size)
// };
// vec2.iter().zip(arr0.iter()).for_each(|(x, y)| {});
// let d = l2_dist(vec1, vec2);
//
// assert!(false);
// }
/// DDL for the bookkeeping table that records which shared-memory allocation block each
/// index owns.
///
/// `allocation` is the shared-memory block number (UNIQUE, so a block has at most one owner),
/// `dim` is the vector dimension stored in that block, and `index_table_name` / `table_name`
/// name the tables the block's contents are loaded from.
pub const CREATE_STR: &str = r#"
CREATE TABLE IF NOT EXISTS index_has_allocation (
id serial8 not null primary key,
index_table_name text NOT NULL,
table_name text NOT NULL,
allocation integer UNIQUE NOT NULL,
dim integer NOT NULL
);
"#;
/// Creates the `index_has_allocation` bookkeeping table (see [`CREATE_STR`]).
///
/// The DDL uses `CREATE TABLE IF NOT EXISTS`, so calling this repeatedly is harmless.
pub fn init_index_alloc_table() {
    let _ = Spi::connect(|mut client| {
        let _ = client.update(CREATE_STR, None, None);
        Ok(Some(()))
    });
}
/// Reads the column at 1-based ordinal `ord` of an SPI result row as a `T`.
///
/// Returns `None` when the ordinal lookup fails (the error itself is discarded) or when the
/// entry's `value` yields `None` (e.g. SQL `NULL`).
fn get_ord_opt<T: FromDatum>(heap_tup_data: &SpiHeapTupleData, ord: usize) -> Option<T> {
    heap_tup_data.by_ordinal(ord).ok()?.value::<T>()
}
/// This function clears allocation blocks that we've stored by can't use
pub fn clear_non_existent_allocs<const SIZE: usize>(
shmem_blocks: &'static ShMemBlocks<SIZE>,
) -> Option<()> {
let size = shmem_blocks.get_size();
Spi::connect(|mut client| {
client.update(
"DELETE FROM index_has_allocation WHERE allocation > $1;",
None,
Some(vec![(PgBuiltInOids::INT4OID.oid(), size.into_datum())]),
);
let bad_rows = client
.select(
"SELECT id, allocation, dim, index_table_name, table_name FROM index_has_allocation;",
None,
None,
)
.filter_map(|spi_heap_tuple_data| {
let id_opt: Option<i64> = get_ord_opt(&spi_heap_tuple_data, 1);
let alloc_opt: Option<i32> = get_ord_opt(&spi_heap_tuple_data, 2);
let dim_opt: Option<i32> = get_ord_opt(&spi_heap_tuple_data, 3);
let index_table_opt: Option<String> = get_ord_opt(&spi_heap_tuple_data, 4);
let table_opt: Option<String> = get_ord_opt(&spi_heap_tuple_data, 5);
if let Some(id) = id_opt {
if let (Some(allocation_index), Some(dim), Some(index_table_name), Some(table_name)) =
(alloc_opt, dim_opt, index_table_opt, table_opt)
{
if let Some(allocation_lock) =
shmem_blocks.get_block(allocation_index as usize)
{
let mut allocation = allocation_lock.exclusive();
allocation.clear();
allocation.set_dim(dim as u16);
load_vecs_into_mem(
&mut client,
allocation_lock,
allocation_index,
dim as u16,
&index_table_name,
&table_name,
);
None
} else {
Some(id)
}
} else {
Some(id)
}
} else {
None
}
})
.collect::<Vec<_>>();
client.update(
"DELETE FROM index_has_allocation WHERE id in $1;",
None,
Some(vec![(
PgBuiltInOids::INT8ARRAYOID.oid(),
bad_rows.into_datum(),
)]),
);
Ok(Some(()))
})
}
/// A row of an index table whose stored `lookup_index` has to be rewritten.
#[derive(Copy, Clone)]
struct ReIndex {
    /// Position the row's vector actually occupies in its shared-memory block.
    lookup: i64,
    /// Primary key (`id`) of the row in the index table.
    idx_table_id: i64,
}
fn load_vecs_into_mem<const SIZE: usize>(
client: &mut SpiClient,
lock: &PgLwLock<FiniteDimAllocator<SIZE>>,
allocation_index: i32,
dim: u16,
index_table_name: &str,
table_name: &str,
) {
let tup_table = client.update(
"SELECT
idx_table.id,
idx_table.lookup_index,
data_table.vec_data
FROM $1 idx_table
JOIN $2 data_table
ON idx_table.vec_id = data_table.vec_id,
WHERE allocation_index = $2
ORDER BY lookup_index;",
None,
Some(vec![
(
PgBuiltInOids::REGCLASSOID.oid(),
index_table_name.into_datum(),
),
(PgBuiltInOids::REGCLASSOID.oid(), table_name.into_datum()),
(
PgBuiltInOids::INT4OID.oid(),
(allocation_index as i32).into_datum(),
),
]),
);
let mut allocation = lock.exclusive();
let max_entries = allocation.get_max_entries();
let mut current_lookup = 0;
let (reindex, delete): (Vec<_>, Vec<_>) = tup_table
.filter_map(|tup_desc_data| {
let idx_table_id = get_ord_opt::<i64>(&tup_desc_data, 1)?;
if let Some(vec_data) = get_ord_opt::<Vec<f32>>(&tup_desc_data, 3) {
allocation.add_vec(&vec_data);
if let Some(lookup_idx) = get_ord_opt::<i64>(&tup_desc_data, 2) {
if current_lookup == lookup_idx {
None
} else {
let reindex = Some(Either::Left(ReIndex {
idx_table_id,
lookup: current_lookup,
}));
current_lookup += 1;
reindex
}
} else {
let reindex = Some(Either::Left(ReIndex {
idx_table_id,
lookup: current_lookup,
}));
current_lookup += 1;
reindex
}
} else {
let delete = Some(Either::Right(idx_table_id));
current_lookup += 1;
delete
}
})
.partition_map(|x| x);
for ReIndex {
idx_table_id,
lookup,
} in reindex.iter()
{
client.update(
"UPDATE $1
SET lookup_index = $2,
WHERE id = $3",
None,
Some(vec![
(
PgBuiltInOids::REGCLASSOID.oid(),
index_table_name.into_datum(),
),
(PgBuiltInOids::INT4OID.oid(), lookup.into_datum()),
(PgBuiltInOids::INT8OID.oid(), idx_table_id.into_datum()),
]),
);
}
client.update(
"DELETE FROM $1 WHERE id in $2",
None,
Some(vec![
(
PgBuiltInOids::REGCLASSOID.oid(),
index_table_name.into_datum(),
),
(PgBuiltInOids::INT8ARRAYOID.oid(), delete.into_datum()),
]),
);
}
#[cfg(test)]
mod test {
    // We want to try out something like
    // ```sql
    // CREATE TABLE public.vecs_2
    // (
    //     idx bigserial,
    //     vec real[] CONSTRAINT vecs_2_dim512 CHECK (cardinality(vec) = 512),
    //     PRIMARY KEY (idx)
    // );
    // ```

    /// Placeholder for exercising a per-table dimension `CHECK` constraint like the SQL sketch
    /// above. It is deliberately not registered as a `#[test]` yet because it only panics.
    #[allow(dead_code)]
    fn test_constraints() {
        todo!() // YOLO
    }
}
#[cfg(any(test, feature = "pg_test"))]
mod tests {
use crate::init_index_alloc_table;
use crate::mem_cache_index::handlers::AllocatorWithCount;
use crate::mem_cache_index::handlers::MemCacheVectorsBuilder;
use crate::mem_cache_index::handlers::SharedMemHandler;
use crate::mem_cache_index::handlers::Vectors;
use crate::mem_cache_index::handlers::VectorsBuilder;
use std::ops::Deref;
use std::ops::DerefMut;
use std::sync::RwLock;
use crate::MAX_SPACE;
use crate::SHMEM_BLOCKS;
use pgx::*;
use rand::*;
use rand_distr::{Distribution, Poisson, Uniform};
use serde::{Deserialize, Serialize};
/// Extra `postgresql.conf` lines the pgx test harness needs for these tests.
///
/// It sets `shared_buffers` and the number of `pgvector.allocation_blocks`. The leading empty
/// string is kept as-is; presumably it just becomes a blank line. TODO confirm it is not needed.
pub fn postgresql_conf_options() -> Vec<&'static str> {
    // return any postgresql.conf settings that are required for your tests
    vec![
        "",
        "shared_buffers = 512MB",
        "pgvector.allocation_blocks = 128",
    ]
}
/// End-to-end check of `MemCacheVectors::push_batch`.
///
/// 1. Inserts a Poisson(1024)-distributed number of random `dim`-dimensional vectors into a
///    fresh `vectors` table.
/// 2. Pushes one batch into shared memory.
/// 3. Asserts that the in-memory length of the allocator holding the most vectors equals the
///    number of `memcache_index_vectors` rows recorded for it.
#[pg_test]
fn test_rebuild() {
    let mut rng = thread_rng();
    let poisson = Poisson::new(1024.0).unwrap();
    let n_vecs = poisson.sample(&mut rng) as i32;
    let dim = 512u16;
    let mut shmem_handler = SharedMemHandler {};
    init_index_alloc_table();
    println!("Building vectors");
    let vectors = VectorsBuilder::new().init(dim);
    // The parameter is declared FLOAT4ARRAYOID for a `real[]` column, so sample f32. Untyped
    // float literals would default to f64 and produce a float8[] datum.
    let unif = Uniform::new(-1.0f32, 1.0f32);
    // Insert every vector over a single SPI connection.
    Spi::connect(|mut client| {
        for _ in 0..n_vecs {
            let vec = (0..dim)
                .map(|_| unif.sample(&mut rng))
                .collect::<Vec<f32>>();
            client.update(
                "INSERT INTO vectors (vec) VALUES ($1)",
                None,
                Some(vec![(
                    PgBuiltInOids::FLOAT4ARRAYOID.oid(),
                    vec.into_datum(),
                )]),
            );
        }
        Ok(Some(()))
    });
    println!("DONE Building vectors");
    let mut mem_cache_vectors = MemCacheVectorsBuilder::new().init();
    mem_cache_vectors.push_batch(&mut shmem_handler, &vectors);
    let AllocatorWithCount {
        allocator_index,
        count,
    } = Spi::connect(|client| -> Result<Option<AllocatorWithCount>, SpiError> {
        // `allocator` is an integer column; COUNT(*) is bigint, so read it as i64.
        let (allocator, count) = client
            .select(
                "SELECT
                    allocator,
                    COUNT(*) n_vecs
                 FROM memcache_index_vectors
                 GROUP BY allocator
                 ORDER BY n_vecs DESC",
                Some(1),
                None,
            )
            .first()
            .get_two::<i32, i64>();
        Ok(allocator.map(|allocator_index| AllocatorWithCount {
            allocator_index: allocator_index as u32,
            count: count.expect("Missing count") as u32,
        }))
    })
    .expect("Failed to get vec count");
    let end = SHMEM_BLOCKS
        .get_block(allocator_index as usize)
        .expect("Couldn't get allocator")
        .share()
        .len;
    assert_eq!(end, count as usize);
    // let c = Spi::get_one::<i32>("SELECT sum(array_length(vec, 1)) FROM vectors");
    // assert_eq!(Some((dim as i32) * n_vecs), c);
}
use crate::MAX_SPACE;
use crate::SHMEM_BLOCKS;
use pgx::*;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
/// Creates the `vectors` test table and hands back a [`Vectors`] handle for it.
pub struct VectorsBuilder {}

impl VectorsBuilder {
    /// Returns a new, stateless builder.
    pub(crate) fn new() -> Self {
        Self {}
    }

    /// Runs `CREATE TABLE vectors (...)` with a `real[dim]` column and returns a handle that
    /// remembers `dim`.
    pub(crate) fn init(self, dim: u16) -> Vectors {
        let ddl = format!(
            "CREATE TABLE vectors (\nid serial8 not null primary key,\nvec real[{}]\n)",
            dim
        );
        Spi::run(&ddl);
        Vectors { dim }
    }
}
/// Handle to the `vectors` table created by `VectorsBuilder::init`.
pub(crate) struct Vectors {
    /// Declared dimension of the `vec` column.
    dim: u16,
}

impl Vectors {
    /// Returns up to `n_vectors` rows of `vectors` that have no matching row in
    /// `memcache_index_vectors`, i.e. vectors not yet assigned to an allocator.
    ///
    /// # Panics
    /// Panics while iterating if a row's `id` or `vec` is missing, or if a vector has more
    /// than `dim` elements.
    pub(crate) fn get_vecs_block<'a>(
        &'a self,
        client: &'a SpiClient,
        n_vectors: u32,
    ) -> impl Iterator<Item = VecWithId> + 'a {
        const UNASSIGNED_VECTORS_SQL: &str = concat!(
            "SELECT\n",
            "vectors.id,\n",
            "vectors.vec\n",
            "FROM vectors\n",
            "LEFT JOIN memcache_index_vectors miv\n",
            "ON miv.vector_id = vectors.id\n",
            "WHERE allocator_index IS NULL"
        );
        let dim = usize::from(self.dim);
        let limit = i64::from(n_vectors);
        let rows = client.select(UNASSIGNED_VECTORS_SQL, Some(limit), None);
        rows.map(move |tuple| {
            let id_entry = tuple
                .by_ordinal(1)
                .expect("The table schema disallows null vector_ids. This shouldn't be NULL");
            let id = id_entry
                .value::<i64>()
                .expect("The value should definitely not be null");
            let vec_entry = tuple
                .by_ordinal(2)
                .expect("The vector argument can't be null");
            let vec = vec_entry
                .value::<Vec<f32>>()
                .expect("There should definitely be a vector here");
            assert!(vec.len() <= dim);
            VecWithId { id, vec }
        })
    }
}
/// Hands out shared-memory allocation blocks and records their ownership in
/// `index_has_allocation`.
pub(crate) struct SharedMemHandler {}

impl SharedMemHandler {
    /// Claims the lowest-numbered shared-memory block with no row in `index_has_allocation`.
    ///
    /// The block is registered for `memcache_table` / `vec_table_name`, emptied, and set to
    /// `vec_table.dim`. The return value pairs the block number with how many
    /// `vec_table.dim`-dimensional vectors fit in a block.
    ///
    /// Returns `None` when every block is already taken.
    pub fn get_new_block(
        &mut self,
        client: &mut SpiClient,
        vec_table: &Vectors,
        vec_table_name: &str,
        memcache_table: &str,
    ) -> Option<AllocatorWithCount> {
        let last_block = (SHMEM_BLOCKS.get_size() as i32) - 1;
        let index = client
            .select(
                "SELECT i
                 FROM generate_series(0, $1) range(i)
                 LEFT JOIN index_has_allocation iha
                   ON iha.allocation = range.i
                 WHERE iha.allocation IS NULL
                 ORDER BY i",
                Some(1),
                Some(vec![(PgBuiltInOids::INT4OID.oid(), last_block.into_datum())]),
            )
            .first()
            .get_one::<i32>()?;

        // Record ownership before touching shared memory. If the INSERT fails (e.g. another
        // backend claimed the same block and the UNIQUE constraint fires), the error aborts
        // here and the block, which would then belong to someone else, is left alone.
        client.update(
            "INSERT INTO index_has_allocation (index_table_name, table_name, allocation, dim)
             VALUES ($1, $2, $3, $4)",
            None,
            Some(vec![
                (PgBuiltInOids::TEXTOID.oid(), memcache_table.into_datum()),
                (PgBuiltInOids::TEXTOID.oid(), vec_table_name.into_datum()),
                (PgBuiltInOids::INT4OID.oid(), index.into_datum()),
                (
                    PgBuiltInOids::INT4OID.oid(),
                    (vec_table.dim as i32).into_datum(),
                ),
            ]),
        );

        let lock = SHMEM_BLOCKS
            .get_block(index as usize)
            .expect("The lock must exist by passing size above");
        {
            // A free block may still hold vectors from a previous owner whose row was
            // removed; the new owner starts from an empty block.
            let mut block = lock.exclusive();
            block.clear();
            block.set_dim(vec_table.dim);
        }

        // Same capacity arithmetic as `MemCacheVectors::push_batch`.
        let max_vecs_per_block = (MAX_SPACE / (vec_table.dim as usize)) as u32;
        Some(AllocatorWithCount {
            allocator_index: index as u32,
            count: max_vecs_per_block,
        })
    }
}
/// Creates the `memcache_index_vectors` test table and hands back a [`MemCacheVectors`] handle.
pub(crate) struct MemCacheVectorsBuilder {}

impl MemCacheVectorsBuilder {
    /// Returns a new, stateless builder.
    pub(crate) fn new() -> Self {
        Self {}
    }

    /// Creates `memcache_index_vectors`, which maps each `vectors.id` to an allocator
    /// (an `index_has_allocation.allocation`) and a unique position inside that allocator.
    pub(crate) fn init(self) -> MemCacheVectors {
        const DDL: &str = concat!(
            "CREATE TABLE memcache_index_vectors (\n",
            "id serial8 not null primary key,\n",
            "vector_id bigint UNIQUE NOT NULL REFERENCES vectors(id),\n",
            "allocator integer NOT NULL REFERENCES index_has_allocation(allocation),\n",
            "allocator_index integer NOT NULL,\n",
            "UNIQUE (allocator, allocator_index)\n",
            ")"
        );
        Spi::run(DDL);
        MemCacheVectors {}
    }
}
/// Handle for the `memcache_index_vectors` table, which maps rows of `vectors` to a
/// shared-memory allocator and a position inside it.
pub(crate) struct MemCacheVectors {}
/// A shared-memory allocator number paired with a vector count.
///
/// NOTE(review): `count` means different things at different call sites. It is the number of
/// vectors currently recorded for an allocator in `get_smallest_existing_alloc` and
/// `test_rebuild`. It is the number of vectors a freshly claimed block can hold in
/// `get_new_block`. TODO confirm the intended meaning.
///
/// Field order is significant: the serde derives serialize fields in declaration order.
#[derive(Serialize, Deserialize, PostgresType, Debug)]
pub struct AllocatorWithCount {
// Shared-memory block number (`index_has_allocation.allocation`).
pub allocator_index: u32,
pub count: u32,
}
/// One row of the `vectors` table: its primary key and its vector data.
#[derive(Deserialize, Serialize, PostgresType, Debug)]
pub struct VecWithId {
pub id: i64,
pub vec: Vec<f32>,
}
impl MemCacheVectors {
    /// Moves one batch of not-yet-cached rows from `vec` into shared memory.
    ///
    /// 1. Picks the least-filled allocator that still has room, or claims a new block through
    ///    `shmem`.
    /// 2. Fetches at most as many unassigned vectors as that allocator can still hold.
    /// 3. Pushes each of them with [`Self::push_vector`].
    ///
    /// # Panics
    /// Panics if pushing an individual vector fails.
    pub(crate) fn push_batch(&mut self, shmem: &mut SharedMemHandler, vec: &Vectors) {
        Spi::connect(|mut client| {
            let max_vecs_per_block = (MAX_SPACE / (vec.dim as usize)) as u32;
            // An existing allocator reports how many vectors it already holds, while a freshly
            // claimed block reports its capacity; turn both into the number of free slots.
            let (allocator_index, free_slots) =
                match self.get_smallest_existing_alloc(&client, max_vecs_per_block) {
                    Some(AllocatorWithCount {
                        allocator_index,
                        count,
                    }) => (allocator_index, max_vecs_per_block.saturating_sub(count)),
                    None => {
                        let AllocatorWithCount {
                            allocator_index,
                            count,
                        } = shmem
                            .get_new_block(&mut client, vec, "vectors", "memcache_index_vectors")
                            .ok_or(SpiError::Transaction)?; // Probably unwrap is better
                        (allocator_index, count)
                    }
                };
            // An SPI row limit of 0 means "no limit", so never ask for a block of 0 rows.
            if free_slots == 0 {
                return Ok(Some(()));
            }
            let vecs_with_ids = vec
                .get_vecs_block(&client, free_slots)
                .collect::<Vec<_>>();
            for vec_with_id in &vecs_with_ids {
                self.push_vector(allocator_index as usize, &mut client, vec_with_id)
                    .expect("Failed to push vec");
            } // todo if one of these fails, there should be a recovery step.
            Ok(Some(()))
        });
    }

    /// Appends `vec_with_id.vec` to allocator `allocator_number` and records its position in
    /// `memcache_index_vectors`.
    ///
    /// # Errors
    /// Returns a message if the allocator does not exist or rejects the vector.
    ///
    /// # Panics
    /// Panics if the INSERT does not report exactly one row.
    fn push_vector(
        &mut self,
        allocator_number: usize,
        client: &mut SpiClient,
        vec_with_id: &VecWithId,
    ) -> Result<(), String> {
        let lock = SHMEM_BLOCKS.get_block(allocator_number).ok_or_else(|| {
            format!(
                "How did we get an allocator_index {}, when we have {} allocators",
                allocator_number,
                SHMEM_BLOCKS.get_size()
            )
        })?;
        let VecWithId { id, vec } = vec_with_id;
        // Append under the exclusive lock, then release it before running SQL; the position
        // has already been fixed by then.
        let (allocator_index, next_index) = {
            let mut allocator = lock.exclusive();
            let allocator_index = allocator.len as i32;
            println!("PUTTING VECTOR");
            allocator
                .add_vec(vec)
                .map_err(|err| format!("ERROR : {:?}", err))?;
            (allocator_index, allocator.len as i32)
        };
        println!("PUT VECTOR");
        println!(
            "VEC_ID : {:?} ALLOCATOR: {:?}, ALLOCATOR_INDEX: {:?}, NEXT_INDEX: {:?}",
            id, allocator_number, allocator_index, next_index
        );
        // `allocator` is an integer column and `allocator_number` is a valid block number, so
        // bind it as an i32 to match the declared INT4 parameter.
        let allocator_column = allocator_number as i32;
        let inserts = client
            .update(
                "INSERT INTO memcache_index_vectors (vector_id, allocator, allocator_index)
                 VALUES ($1, $2, $3);",
                Some(1),
                Some(vec![
                    (PgBuiltInOids::INT8OID.oid(), (*id).into_datum()),
                    (PgBuiltInOids::INT4OID.oid(), allocator_column.into_datum()),
                    (PgBuiltInOids::INT4OID.oid(), allocator_index.into_datum()),
                ]),
            )
            .count();
        assert!(inserts == 1);
        println!("UPDATE VECTOR TABLE");
        Ok(())
    }

    /// Returns the allocator recorded in `memcache_index_vectors` that holds the fewest
    /// vectors while still holding fewer than `vecs_per_block`, together with its current
    /// vector count.
    ///
    /// Returns `None` when no such allocator exists.
    fn get_smallest_existing_alloc<T: Deref<Target = SpiClient>>(
        &self,
        client: T,
        vecs_per_block: u32,
    ) -> Option<AllocatorWithCount> {
        let (allocator_opt, count_opt) = client
            .select(
                "SELECT allocator, COUNT(vector_id) n_vecs
                 FROM memcache_index_vectors
                 GROUP BY allocator
                 HAVING COUNT(vector_id) < $1
                 ORDER BY n_vecs, allocator",
                None,
                Some(vec![(
                    PgBuiltInOids::INT8OID.oid(),
                    i64::from(vecs_per_block).into_datum(),
                )]),
            )
            .first()
            // `allocator` is integer and COUNT(...) is bigint.
            .get_two::<i32, i64>();
        // The HAVING clause bounds the count by `vecs_per_block`, so it fits in a u32.
        let count = count_opt? as u32;
        // Allocator numbers are non-negative block numbers.
        allocator_opt.map(|allocator| AllocatorWithCount {
            allocator_index: allocator as u32,
            count,
        })
    }
}
CREATE SCHEMA IF NOT EXISTS "tests";
-- ./src/mem_cache_index/mod.rs:269:4
CREATE OR REPLACE FUNCTION tests."test_rebuild"() RETURNS void LANGUAGE c AS 'MODULE_PATHNAME', 'test_rebuild_wrapper';
mem_cache_index_handlers.generated.sql