DH3RVE6JRQBOSRUBVIIMKSJYFS4CK5PI25KV2NAJ7EP4Q7NNKSJQC KZX6RVSGAYRMCPKHROF3L2CHELFYLEAAN4Z7HN4MW2AAEIJJGBCQC IFWHYAZTPKL7ZVOV3AFZNIUTRLMVSOACYACA3CI4DSEEZU77YQEQC LECAQCCL2UKTMH3LQSGEFVBOF6CYCGJGJZM2I2T4R7QPTBVW4VYQC UV2RTMNSNI3AJFVVD66OT3CEF7WU7TLAPJKACGX7BBSSWAEBAAAAC SIHQ3OG5GIBQLSMOBKSD2IHFTGFXU4VDCX45BH2ABS42BV3N2UIQC G5C3RSELXRNXJFQQ6ZVCXLQE4AJNXYB32RULKHWLWF5SGEFXYTWQC
/// Smoke test for `Array::over` and `l2_dist`.
///
/// Builds two `Array`s over independent copies of the same 32 datums (`sin(0..32)`),
/// checks that every element reads back as the `f32` it was built from, and that the
/// two arrays are at distance zero from each other.
#[pg_test]
fn test_hello_pgvector_rs() {
    let expected: Vec<f32> = (0..32).map(|x| (x as f32).sin()).collect();
    let mut datums: Vec<pg_sys::Datum> = expected
        .iter()
        .map(|x| x.into_datum().expect("an f32 always converts to a datum"))
        .collect();
    let mut nulls = vec![false; datums.len()];
    let mut datums_2 = datums.clone();
    let mut nulls_2 = nulls.clone();

    // SAFETY: each `Array` gets a datum buffer and a null-flag buffer of exactly `len`
    // elements. Both buffers are locals of this function, so they outlive the `Array`,
    // and they are not read or written again while the `Array` is alive.
    let vec1 = unsafe {
        let len = datums.len();
        Array::over(datums.as_mut_ptr(), nulls.as_mut_ptr(), len)
    };
    // SAFETY: same invariants as above, for the second, independent copy.
    let vec2 = unsafe {
        let len = datums_2.len();
        Array::over(datums_2.as_mut_ptr(), nulls_2.as_mut_ptr(), len)
    };

    // Every element must round-trip unchanged through the datum representation.
    for (x, y) in vec1.iter().zip(expected.iter()) {
        assert_eq!(x, Some(*y));
    }
    for (x, y) in vec2.iter().zip(expected.iter()) {
        assert_eq!(x, Some(*y));
    }

    // Identical inputs must be zero distance apart; the tolerance only guards against
    // floating-point summation noise.
    let d = l2_dist(vec1, vec2);
    assert!(d.abs() < 1e-6, "distance between identical vectors was {}", d);
}
// #[pg_test]// fn test_hello_pgvector_rs() {// let arr0 = (0..32)// .map(|x| x as f32)// .map(|x| x.sin())// .collect::<Vec<_>>();// let mut arr = arr0// .iter()// .map(|x| x.into_datum())// .map(|x| x.unwrap())// .collect::<Vec<pg_sys::Datum>>();// let mut nulls = arr.iter().map(|_| false).collect::<Vec<bool>>();// let mut arr_2 = arr.clone();// let mut nulls_2 = nulls.clone();// let vec1 = unsafe {// let size = arr.len();// let b = arr.as_mut_ptr();// Array::over(b, nulls.as_mut_ptr(), size)// };// vec1.iter().for_each(|x| {});// let vec2 = unsafe {// let size = arr_2.len();// Array::over(arr_2.as_mut_ptr(), nulls_2.as_mut_ptr(), size)// };// vec2.iter().zip(arr0.iter()).for_each(|(x, y)| {});// let d = l2_dist(vec1, vec2);//// assert!(false);// }
pub const CREATE_STR: &str = r#"CREATE TABLE IF NOT EXISTS index_has_allocation (id serial8 not null primary key,index_table_name text NOT NULL,table_name text NOT NULL,allocation integer UNIQUE NOT NULL,dim integer NOT NULL);"#;pub fn init_index_alloc_table() {Spi::connect(|mut client| {client.update(CREATE_STR, None, None);Ok(Some(()))});}fn get_ord_opt<T: FromDatum>(heap_tup_data: &SpiHeapTupleData, ord: usize) -> Option<T> {heap_tup_data.by_ordinal(ord).ok().map(|datum| datum.value::<T>()).flatten()}/// This function clears allocation blocks that we've stored by can't usepub fn clear_non_existent_allocs<const SIZE: usize>(shmem_blocks: &'static ShMemBlocks<SIZE>,) -> Option<()> {let size = shmem_blocks.get_size();Spi::connect(|mut client| {client.update("DELETE FROM index_has_allocation WHERE allocation > $1;",None,Some(vec![(PgBuiltInOids::INT4OID.oid(), size.into_datum())]),);let bad_rows = client.select("SELECT id, allocation, dim, index_table_name, table_name FROM index_has_allocation;",None,None,).filter_map(|spi_heap_tuple_data| {let id_opt: Option<i64> = get_ord_opt(&spi_heap_tuple_data, 1);let alloc_opt: Option<i32> = get_ord_opt(&spi_heap_tuple_data, 2);let dim_opt: Option<i32> = get_ord_opt(&spi_heap_tuple_data, 3);let index_table_opt: Option<String> = get_ord_opt(&spi_heap_tuple_data, 4);let table_opt: Option<String> = get_ord_opt(&spi_heap_tuple_data, 5);if let Some(id) = id_opt {if let (Some(allocation_index), Some(dim), Some(index_table_name), Some(table_name)) =(alloc_opt, dim_opt, index_table_opt, table_opt){if let Some(allocation_lock) =shmem_blocks.get_block(allocation_index as usize){let mut allocation = allocation_lock.exclusive();allocation.clear();allocation.set_dim(dim as u16);load_vecs_into_mem(&mut client,allocation_lock,allocation_index,dim as u16,&index_table_name,&table_name,);None} else {Some(id)}} else {Some(id)}} else {None}}).collect::<Vec<_>>();client.update("DELETE FROM index_has_allocation WHERE id in 
$1;",None,Some(vec![(PgBuiltInOids::INT8ARRAYOID.oid(),bad_rows.into_datum(),)]),);Ok(Some(()))})}#[derive(Clone, Copy)]struct ReIndex {idx_table_id: i64,lookup: i64,}fn load_vecs_into_mem<const SIZE: usize>(client: &mut SpiClient,lock: &PgLwLock<FiniteDimAllocator<SIZE>>,allocation_index: i32,dim: u16,index_table_name: &str,table_name: &str,) {let tup_table = client.update("SELECTidx_table.id,idx_table.lookup_index,data_table.vec_dataFROM $1 idx_tableJOIN $2 data_tableON idx_table.vec_id = data_table.vec_id,WHERE allocation_index = $2ORDER BY lookup_index;",None,Some(vec![(PgBuiltInOids::REGCLASSOID.oid(),index_table_name.into_datum(),),(PgBuiltInOids::REGCLASSOID.oid(), table_name.into_datum()),(PgBuiltInOids::INT4OID.oid(),(allocation_index as i32).into_datum(),),]),);let mut allocation = lock.exclusive();let max_entries = allocation.get_max_entries();let mut current_lookup = 0;let (reindex, delete): (Vec<_>, Vec<_>) = tup_table.filter_map(|tup_desc_data| {let idx_table_id = get_ord_opt::<i64>(&tup_desc_data, 1)?;if let Some(vec_data) = get_ord_opt::<Vec<f32>>(&tup_desc_data, 3) {allocation.add_vec(&vec_data);if let Some(lookup_idx) = get_ord_opt::<i64>(&tup_desc_data, 2) {if current_lookup == lookup_idx {None} else {let reindex = Some(Either::Left(ReIndex {idx_table_id,lookup: current_lookup,}));current_lookup += 1;reindex}} else {let reindex = Some(Either::Left(ReIndex {idx_table_id,lookup: current_lookup,}));current_lookup += 1;reindex}} else {let delete = Some(Either::Right(idx_table_id));current_lookup += 1;delete}}).partition_map(|x| x);for ReIndex {idx_table_id,lookup,} in reindex.iter(){client.update("UPDATE $1SET lookup_index = $2,WHERE id = $3",None,Some(vec![(PgBuiltInOids::REGCLASSOID.oid(),index_table_name.into_datum(),),(PgBuiltInOids::INT4OID.oid(), lookup.into_datum()),(PgBuiltInOids::INT8OID.oid(), idx_table_id.into_datum()),]),);}client.update("DELETE FROM $1 WHERE id in 
$2",None,Some(vec![(PgBuiltInOids::REGCLASSOID.oid(),index_table_name.into_datum(),),(PgBuiltInOids::INT8ARRAYOID.oid(), delete.into_datum()),]),);}
#[cfg(test)]mod test {// We want to try out something like// ```sql// CREATE TABLE public.vecs_2//(// idx bigserial,// vec real[] CONSTRAINT vecs_2_dim512 CHECK (cardinality(vec) = 512),// PRIMARY KEY (idx)//);//// ```fn test_constraints() {todo!() // YOLO
#[cfg(any(test, feature = "pg_test"))]mod tests {use crate::init_index_alloc_table;use crate::mem_cache_index::handlers::AllocatorWithCount;use crate::mem_cache_index::handlers::MemCacheVectorsBuilder;use crate::mem_cache_index::handlers::SharedMemHandler;use crate::mem_cache_index::handlers::Vectors;use crate::mem_cache_index::handlers::VectorsBuilder;use std::ops::Deref;use std::ops::DerefMut;use std::sync::RwLock;use crate::MAX_SPACE;use crate::SHMEM_BLOCKS;use pgx::*;use rand::*;use rand_distr::{Distribution, Poisson, Uniform};use serde::{Deserialize, Serialize};pub fn postgresql_conf_options() -> Vec<&'static str> {// return any postgresql.conf settings that are required for your testsvec!["","shared_buffers = 512MB","pgvector.allocation_blocks = 128",]
#[pg_test]fn test_rebuild() {let mut rng = thread_rng();let poisson = Poisson::new(1024.0).unwrap();let n_vecs = poisson.sample(&mut rng) as i32;let dim = 512u16;let mut shmem_handler = SharedMemHandler {};init_index_alloc_table();println!("Building vectors");let vectors = VectorsBuilder::new().init(dim);for _ in 0..n_vecs {let unif = Uniform::new(-1.0, 1.0);let vec = (0..dim).map(|_| unif.sample(&mut rng)).collect::<Vec<_>>();Spi::connect(|mut client| {client.update("INSERT INTO vectors (vec) VALUES ($1)",None,Some(vec![(PgBuiltInOids::FLOAT4ARRAYOID.oid(),vec.into_datum(),)]),);Ok(Some(()))});}println!("DONE Building vectors");let mut mem_cache_vectors = MemCacheVectorsBuilder::new().init();mem_cache_vectors.push_batch(&mut shmem_handler, &vectors);let AllocatorWithCount {allocator_index,count,} = Spi::connect(|client| -> Result<Option<AllocatorWithCount>, SpiError> {let (allocator, count) = client.select("SELECTallocator,COUNT(*) n_vecsFROM memcache_index_vectorsGROUP BY allocatorORDER BY n_vecs DESC",Some(1),None,).first().get_two::<i32, i32>();Ok(allocator.map(|allocator_index| AllocatorWithCount {allocator_index: allocator_index as u32,count: count.expect("Missing count") as u32,}))}).expect("Failed to get vec count");let end = SHMEM_BLOCKS.get_block(allocator_index as usize).expect("Couldn't get allocator").share().len;assert_eq!(end, count as usize);// let c = Spi::get_one::<i32>("SELECT sum(array_length(vec, 1)) FROM vectors");// assert_eq!(Some((dim as i32) * n_vecs), c);}
use crate::MAX_SPACE;use crate::SHMEM_BLOCKS;use pgx::*;use serde::{Deserialize, Serialize};use std::ops::Deref;pub struct VectorsBuilder {}impl VectorsBuilder {pub(crate) fn new() -> Self {VectorsBuilder {}}pub(crate) fn init(self, dim: u16) -> Vectors {Spi::run(&format!("CREATE TABLE vectors (id serial8 not null primary key,vec real[{}])",dim));Vectors { dim }}}pub(crate) struct Vectors {dim: u16,}impl Vectors {pub(crate) fn get_vecs_block<'a>(&'a self,client: &'a SpiClient,n_vectors: u32,) -> impl Iterator<Item = VecWithId> + 'a {let dim = self.dim as usize;client.select("SELECTvectors.id,vectors.vecFROM vectorsLEFT JOIN memcache_index_vectors mivON miv.vector_id = vectors.idWHERE allocator_index IS NULL",Some(n_vectors as i64),None,).map(move |tuple| {let id = tuple.by_ordinal(1).expect("The table schema disallows null vector_ids. This shouldn't be NULL").value::<i64>().expect("The value should definitely not be null");let vec = tuple.by_ordinal(2).expect("The vector argument can't be null").value::<Vec<f32>>().expect("There should definitely be a vector here");assert!(vec.len() <= dim);VecWithId { id, vec }})}}pub(crate) struct SharedMemHandler {}impl SharedMemHandler {pub fn get_new_block(&mut self,client: &mut SpiClient,vec_table: &Vectors,vec_table_name: &str,memcache_table: &str,) -> Option<AllocatorWithCount> {let size = (SHMEM_BLOCKS.get_size() as i32) - 1;client.select("SELECTiFROM generate_series(0, $1) range(i)LEFT JOIN index_has_allocation ihaON iha.allocation = range.iWHERE iha.allocation IS NULLORDER BY i",Some(1),Some(vec![(PgBuiltInOids::INT4OID.oid(), size.into_datum())]),).first().get_one::<i32>().map(|index| {let lock = SHMEM_BLOCKS.get_block(index as usize).expect("The lock must exist by passing size above");lock.exclusive().set_dim(vec_table.dim);let inserts = client.update("INSERT INTO index_has_allocation (index_table_name,table_name,allocation,dim)VALUES ($1,$2,$3,$4)",None,Some(vec![(PgBuiltInOids::TEXTOID.oid(), 
memcache_table.into_datum()),(PgBuiltInOids::TEXTOID.oid(), vec_table_name.into_datum()),(PgBuiltInOids::INT4OID.oid(), index.into_datum()),(PgBuiltInOids::INT4OID.oid(),(vec_table.dim as i32).into_datum(),),]),).count();let max_vecs_per_block = MAX_SPACE as u32 / vec_table.dim as u32;AllocatorWithCount {allocator_index: index as u32,count: max_vecs_per_block,}})}}pub(crate) struct MemCacheVectorsBuilder {}impl MemCacheVectorsBuilder {pub(crate) fn new() -> Self {MemCacheVectorsBuilder {}}pub(crate) fn init(self) -> MemCacheVectors {Spi::run("CREATE TABLE memcache_index_vectors (id serial8 not null primary key,vector_id bigint UNIQUE NOT NULL REFERENCES vectors(id),allocator integer NOT NULL REFERENCES index_has_allocation(allocation),allocator_index integer NOT NULL,UNIQUE (allocator, allocator_index))",);MemCacheVectors {}}}pub(crate) struct MemCacheVectors {}#[derive(Serialize, Deserialize, PostgresType, Debug)]pub struct AllocatorWithCount {pub allocator_index: u32,pub count: u32,}#[derive(Deserialize, Serialize, PostgresType, Debug)]pub struct VecWithId {pub id: i64,pub vec: Vec<f32>,}impl MemCacheVectors {pub(crate) fn push_batch(&mut self, shmem: &mut SharedMemHandler, vec: &Vectors) {Spi::connect(|mut client| {let max_vecs_per_block = (MAX_SPACE / (vec.dim as usize)) as u32;let AllocatorWithCount {allocator_index,count,} = self.get_smallest_existing_alloc(&client, max_vecs_per_block).or_else(|| {shmem.get_new_block(&mut client, vec, "vectors", "memcache_index_vectors")}).ok_or(SpiError::Transaction)?; // Probably unwrap is betterlet entries = client.select("SELECT COUNT(*) FROM memcache_index_vectors", None, None).first().get_one::<i64>();let vecs_with_ids = vec.get_vecs_block(&mut client, count).collect::<Vec<_>>();vecs_with_ids.iter().for_each(|vec_with_id| {self.push_vector(allocator_index as usize, &mut client, vec_with_id).expect("Failed to push vec");}); // todo if one of these fails, there should be a recovery step.Ok(Some(()))});}fn push_vector(&mut 
self,allocator_number: usize,client: &mut SpiClient,vec_with_id: &VecWithId,) -> Result<(), String> {let lock = SHMEM_BLOCKS.get_block(allocator_number).ok_or(format!("How did we get an allocator_index {}, when we have { } allocators",allocator_number,SHMEM_BLOCKS.get_size()))?;let mut allocator = lock.exclusive();let VecWithId { id, vec } = vec_with_id;let allocator_index = allocator.len as i32;println!("PUTTING VECTOR");allocator.add_vec(vec).map_err(|err| format!("ERROR : {:?}", err))?;let next_index = allocator.len as i32;println!("PUT VECTOR");println!("VEC_ID : {:?} ALLOCATOR: {:?}, ALLOCATOR_INDEX: {:?}, NEXT_INDEX: {:?}",id, allocator_number, allocator_index, next_index);let inserts = client.update("INSERT INTO memcache_index_vectors (vector_id,allocator,allocator_index) VALUES ($1,$2,$3);",Some(1),Some(vec![(PgBuiltInOids::INT8OID.oid(), id.into_datum()),(PgBuiltInOids::INT4OID.oid(), allocator_number.into_datum()),(PgBuiltInOids::INT4OID.oid(), allocator_index.into_datum()),]),).count();assert!(inserts == 1);println!("UPDATE VECTOR TABLE");Ok(())}fn get_smallest_existing_alloc<T: Deref<Target = SpiClient>>(&self,client: T,vecs_per_block: u32,) -> Option<AllocatorWithCount> {let (allocator_opt, count_opt) = client.select("SELECTallocator,COUNT(vector_id) n_vecsFROM memcache_index_vectorsGROUP BY allocatorHAVING 2 <= $1ORDER BY n_vecs",None,Some(vec![(PgBuiltInOids::INT4OID.oid(),vecs_per_block.into_datum(),)]),).first().get_two::<u32, u32>();let count = count_opt?;allocator_opt.map(|allocator_index| AllocatorWithCount {allocator_index,count,})}}
-- Generated SQL (presumably emitted by pgx's schema generator for `#[pg_test]`
-- functions); regenerate it rather than editing by hand.
-- Schema that holds the test entry points.
CREATE SCHEMA IF NOT EXISTS "tests";
-- ./src/mem_cache_index/mod.rs:269:4
-- Exposes the Rust `test_rebuild` test through its C wrapper symbol.
CREATE OR REPLACE FUNCTION tests."test_rebuild"() RETURNS void LANGUAGE c AS 'MODULE_PATHNAME', 'test_rebuild_wrapper';
mem_cache_index_handlers.generated.sql