id serial8 NOT NULL PRIMARY KEY,
index_table_name text NOT NULL,
table_name text NOT NULL,
allocation integer UNIQUE NOT NULL,
dim integer NOT NULL
id serial8 NOT NULL PRIMARY KEY,
memcache_index_id bigint NOT NULL REFERENCES memcache_index(id) ON DELETE CASCADE,
allocation integer UNIQUE NOT NULL
    for block in SHMEM_BLOCKS.iter() {
        block.exclusive().clear();
    }
    Ok(Some(()))
});
}

pub fn init_index_table() {
    Spi::connect(|mut client| {
        client.update(
            "CREATE TABLE IF NOT EXISTS memcache_index (
                id serial8 NOT NULL PRIMARY KEY,
                index_table_name text UNIQUE NOT NULL,
                table_name text NOT NULL,
                dim integer NOT NULL
            )",
            None,
            None,
        );
if let (Some(allocation_index), Some(dim), Some(index_table_name), Some(table_name)) =
    (alloc_opt, dim_opt, index_table_opt, table_opt)
if let (
    Some(allocation_index),
    Some(dim),
    Some(index_table_name),
    Some(table_name),
) = (alloc_opt, dim_opt, index_table_opt, table_opt)
use serde::{Deserialize, Serialize};
fn insert_random_vecs<R: Rng>(rng: &mut R, vectors: &Vectors, n_vecs: usize) {
    let slice_iter = (0..n_vecs).map(|_| {
        let unif = Uniform::new(-1.0, 1.0);
        (0..vectors.dim).map(|_| unif.sample(rng)).collect::<Vec<_>>()
    });
    vectors.insert_vec_iter(slice_iter);
}
let vectors = VectorsBuilder::new().init(dim);
for _ in 0..n_vecs {
    let unif = Uniform::new(-1.0, 1.0);
    let vec = (0..dim).map(|_| unif.sample(&mut rng)).collect::<Vec<_>>();
    Spi::connect(|mut client| {
        client.update(
            "INSERT INTO vectors (vec) VALUES ($1)",
            None,
            Some(vec![(
                PgBuiltInOids::FLOAT4ARRAYOID.oid(),
                vec.into_datum(),
            )]),
        );
        Ok(Some(()))
    });
}
let vectors = VectorsBuilder::new().init("vectors_2", dim);insert_random_vecs(&mut rng, &vectors, n_vecs);
let mut mem_cache_vectors = MemCacheVectorsBuilder::new().init();
mem_cache_vectors.push_batch(&mut shmem_handler, &vectors);
let mut mem_cache_vectors = MemCacheVectorsBuilder::new().init(&vectors);
mem_cache_vectors.push_all_possible(&mut shmem_handler);
// let c = Spi::get_one::<i32>("SELECT sum(array_length(vec, 1)) FROM vectors");
// assert_eq!(Some((dim as i32) * n_vecs), c);
impl RandDim {
    fn get_dim<R: Rng>(&self, rng: &mut R) -> Result<u16, rand::Error> {
        let n_samples = self.samples.sample(rng) as usize;
        let mut total = 10;
        for _ in 0..n_samples {
            let on = self.sample_on_or_off.sample(rng);
            if on {
                total += self.contribution_per_sample.sample(rng) as u16;
            }
        }
        Ok(total)
    }
#[pg_test]
fn test_rebuild() {
    // First we build
    let mut rng = thread_rng();
    let n_tables_poisson = Poisson::new(6.0).unwrap();
    let n_vecs_poisson = Poisson::new(50.0).expect("Bad poisson for n_vecs");
    let rand_dim_gen = RandDim {
        samples: Poisson::new(20.0).expect("Bad Poisson for n_samples"),
        contribution_per_sample: Poisson::new(10.0)
            .expect("Bad Poisson for contribution_per_sample"),
        sample_on_or_off: Bernoulli::new(0.5).expect("Bad Bernoulli for on or off."),
    };
    let mut shmem_handler = SharedMemHandler {};
    println!("INITIALISING TABLES");
    init_index_table();
    clear_allocations();
    let n_tables = n_tables_poisson.sample(&mut rng) as i32;
    let vector_tables = (0..n_tables)
        .map(|i| {
            let dim = rand_dim_gen.get_dim(&mut rng).expect("Failed to gen dim.");
            let vector_table =
                VectorsBuilder::new().init(&format!("vector_table_{}", i + 1), dim);
            let n_vecs = 100 * n_vecs_poisson.sample(&mut rng) as usize;
            println!("N_vecs: {}, dim: {}", n_vecs, dim);
            insert_random_vecs(&mut rng, &vector_table, n_vecs);
            vector_table
        })
        .collect::<Vec<_>>();
    let mem_cache_vectors = vector_tables
        .iter()
        .map(|vector_table| {
            let mut mem_cache_vectors = MemCacheVectorsBuilder::new().init(&vector_table);
            mem_cache_vectors.push_all_possible(&mut shmem_handler);
            mem_cache_vectors
        })
        .collect::<Vec<_>>();
    println!("Resetting memcache");
    clear_allocations();
    rebuild(&mut shmem_handler);
    assert!(true)
}
// todo: test querying
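// A minimal sketch (not taken from the original source) of the `RandDim` struct that
// `test_rebuild` constructs above and that `get_dim` reads from: two Poisson
// distributions plus a Bernoulli gate. The field names match the test; the concrete
// types and the `rand_distr` import are assumptions.
use rand_distr::{Bernoulli, Poisson};

struct RandDim {
    // Expected number of samples drawn per dimension roll.
    samples: Poisson<f64>,
    // Size of each sample's contribution to the dimension.
    contribution_per_sample: Poisson<f64>,
    // Whether a given sample contributes at all.
    sample_on_or_off: Bernoulli,
}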
pub(crate) fn init(self, dim: u16) -> Vectors {
    Spi::run(&format!(
        "CREATE TABLE vectors (
            id serial8 not null primary key,
            vec real[{}]
        )",
        dim
    ));
    Vectors { dim }
pub(crate) fn init(self, name: &str, dim: u16) -> Vectors {
    Spi::connect(|mut client| {
        client.update(
            &format!(
                "CREATE TABLE {} (
                    id serial8 not null primary key,
                    vec real[{}]
                )",
                &name, dim
            ),
            None,
            None,
        );
        Ok(Some(()))
    });
    //.expect(format!("Unable to create vector table {}", name).as_str());
    let name = String::from(name);
    Vectors { name, dim }
impl Vectors {
    pub(crate) fn get_vecs_block<'a>(
        &'a self,
        client: &'a SpiClient,
        n_vectors: u32,
    ) -> impl Iterator<Item = VecWithId> + 'a {
        let dim = self.dim as usize;
        client
pub(crate) fn rebuild(shmem: &mut SharedMemHandler) {
    Spi::connect(|client| {
        let vec_tables_with_index = client
.expect("The table schema disallows null vector_ids. This shouldn't be NULL").value::<i64>().expect("The value should definitely not be null");let vec = tuple
.expect("Unable to read vector table name").value::<String>().expect("Unable to cast vector table name to String");let index_table_name = heap_tuple
.expect("The vector argument can't be null").value::<Vec<f32>>().expect("There should definitely be a vector here");assert!(vec.len() <= dim);VecWithId { id, vec }
.expect("Unabled to read index table name").value::<String>().expect("Unable to cast index table name to String");let dim = heap_tuple.by_ordinal(3).expect("Unable to read dimension").value::<i32>().expect("Unable to cast dimension to int");let mut vec = Vectors {name: vec_table_name,dim: dim as u16,};(vec, index_table_name)}).collect::<Vec<_>>();vec_tables_with_index.into_iter().fold(client, |client, (vec, index_table_name)| {let mut memcache_index = MemCacheVectors {vector_table: &vec,name: index_table_name,};memcache_index.push_until_done(client, shmem).expect("Failed to return client").spi_client});Ok(Some(()))});}struct WrapSlice<'a, T>(&'a [T]);impl<'a, T: Copy + IntoDatum> IntoDatum for WrapSlice<'a, T> {fn into_datum(self) -> Option<pg_sys::Datum> {let WrapSlice(slice) = self;let mut state = unsafe {pg_sys::initArrayResult(T::type_oid(),PgMemoryContexts::CurrentMemoryContext.value(),false,)};for s in slice {let datum = s.into_datum();let isnull = datum.is_none();unsafe {state = pg_sys::accumArrayResult(state,datum.unwrap_or(0usize),isnull,T::type_oid(),PgMemoryContexts::CurrentMemoryContext.value(),);}}if state.is_null() {// shoudln't happenNone} else {Some(unsafe {pg_sys::makeArrayResult(state, PgMemoryContexts::CurrentMemoryContext.value())
impl Vectors {
    pub(crate) fn insert_slice<'a>(&self, vec: &'a [f32]) {
        assert!(vec.len() == (self.dim as usize));
        let slice = WrapSlice(vec);
        Spi::connect(|mut client| {
            client.update(
                // todo: figure out sql formatting and schemata
                &format!("INSERT INTO {} (vec) VALUES ($1)", &self.name),
                None,
                Some(vec![(
                    PgBuiltInOids::FLOAT4ARRAYOID.oid(),
                    slice.into_datum(),
                )]),
            );
            Ok(Some(()))
        });
    }
    pub(crate) fn insert_vecs<'a>(&self, vectors: &'a [Vec<f32>]) {
        self.insert_slice_iter(vectors.iter().map(|x| x.as_slice()));
    }

    pub(crate) fn insert_slice_iter<'a, Vecs>(&self, vectors: Vecs)
    where
        Vecs: Iterator<Item = &'a [f32]>,
    {
        for vec in vectors {
            self.insert_slice(vec);
        }
    }

    pub(crate) fn insert_vec_iter<Vecs>(&self, vectors: Vecs)
    where
        Vecs: Iterator<Item = Vec<f32>>,
    {
        for vec in vectors {
            self.insert_slice(vec.as_slice());
        }
    }
}
            i
        FROM generate_series(0, $1) range(i)
        LEFT JOIN index_has_allocation iha
            ON iha.allocation = range.i
        WHERE iha.allocation IS NULL
        ORDER BY i",
let inserts = client
let table_id = client
    .select(
        "SELECT
            id
        FROM memcache_index
        WHERE index_table_name = $1",
        Some(1),
        Some(vec![(
            PgBuiltInOids::TEXTOID.oid(),
            memcache_table.name.as_str().into_datum(),
        )]),
    )
    .first()
    .get_one::<i64>()
    .expect(&format!(
        "Unable to find entry for table {}",
        &memcache_table.name
    ));
client
pub(crate) fn init(self) -> MemCacheVectors {
    Spi::run(
        "CREATE TABLE memcache_index_vectors (
            id serial8 not null primary key,
            vector_id bigint UNIQUE NOT NULL REFERENCES vectors(id),
            allocator integer NOT NULL REFERENCES index_has_allocation(allocation),
            allocator_index integer NOT NULL,
            UNIQUE (allocator, allocator_index)
        )",
    );
    MemCacheVectors {}
pub(crate) fn init(self, vectors: &Vectors) -> MemCacheVectors {
    let name = format!("memcache_table_{}", Uuid::new_v4()).replace("-", "_");
    Spi::execute(|mut client| {
        let id = client
            .update(
                "INSERT INTO memcache_index (index_table_name, table_name, dim) VALUES ($1, $2, $3) RETURNING id",
                Some(1),
                Some(vec![
                    (PgBuiltInOids::TEXTOID.oid(), name.as_str().into_datum()),
                    (
                        PgBuiltInOids::TEXTOID.oid(),
                        vectors.name.as_str().into_datum(),
                    ),
                    (
                        PgBuiltInOids::INT4OID.oid(),
                        (vectors.dim as i32).into_datum(),
                    ),
                ]),
            )
            .first()
            .get_one::<i64>()
            .expect("Couldn't create an index entry.");
        client.update(
            // todo: figure out sql formatting and schemata
            &format!(
                "CREATE TABLE {} (
                    id serial8 not null primary key,
                    vector_id bigint UNIQUE NOT NULL REFERENCES {}(id) ON DELETE CASCADE,
                    allocator integer NOT NULL REFERENCES index_has_allocation(allocation) ON DELETE CASCADE,
                    allocator_index integer NOT NULL,
                    UNIQUE (allocator, allocator_index)
                )",
                &name, &vectors.name
            ),
            None,
            None,
        );
    });
    MemCacheVectors {
        vector_table: vectors,
        name,
    }
impl MemCacheVectors {
    pub(crate) fn push_batch(&mut self, shmem: &mut SharedMemHandler, vec: &Vectors) {
        Spi::connect(|mut client| {
            let max_vecs_per_block = (MAX_SPACE / (vec.dim as usize)) as u32;
            let AllocatorWithCount {
                allocator_index,
                count,
            } = self
                .get_smallest_existing_alloc(&client, max_vecs_per_block)
                .or_else(|| {
                    shmem.get_new_block(&mut client, vec, "vectors", "memcache_index_vectors")
                })
                .ok_or(SpiError::Transaction)?; // Probably unwrap is better
pub(crate) struct SpiClientWithReturn<T> {
    spi_client: SpiClient,
    return_val: T,
}

trait ReturnWithSpi
where
    Self: Sized,
{
    fn wrap_with_client(self, spi_client: SpiClient) -> SpiClientWithReturn<Self> {
        SpiClientWithReturn {
            spi_client,
            return_val: self,
        }
    }
}

impl<T: Sized> ReturnWithSpi for T {}

type ResultWithSpi<T> = Result<SpiClientWithReturn<T>, SpiError>;

impl<'a> MemCacheVectors<'a> {
    fn get_n_unindexed(&self, client: &SpiClient) -> i64 {
        client
            .select(
                &format!(
                    "SELECT
                        COUNT(vec.id)
                    FROM {} vec
                    LEFT JOIN {} miv
                        ON miv.vector_id = vec.id
                    WHERE miv.vector_id IS NULL",
                    self.vector_table.name.as_str(),
                    self.name.as_str()
                ),
                Some(1),
                None,
            )
            .first()
            .get_one::<i64>()
            .expect("Query for the number of unindexed vectors must return one row.")
    }
let entries = client
    .select("SELECT COUNT(*) FROM memcache_index_vectors", None, None)
    .first()
    .get_one::<i64>();
let vecs_with_ids = vec.get_vecs_block(&mut client, count).collect::<Vec<_>>();
vecs_with_ids.iter().for_each(|vec_with_id| {
    self.push_vector(allocator_index as usize, &mut client, vec_with_id)
        .expect("Failed to push vec");
}); // todo: if one of these fails, there should be a recovery step.
Ok(Some(()))
    pub(crate) fn get_vecs_block(
        &self,
        client: &SpiClient,
        n_vectors: u32,
    ) -> impl Iterator<Item = VecWithId> + 'a {
        let dim = self.vector_table.dim as usize;
        client
            .select(
                // todo: figure out sql formatting and schemata
                &format!(
                    "SELECT
                        vec.id,
                        vec.vec
                    FROM {} vec
                    LEFT JOIN {} miv
                        ON miv.vector_id = vec.id
                    WHERE miv.vector_id IS NULL",
                    &self.vector_table.name, self.name
                ),
                Some(n_vectors as i64),
                None,
            )
            .map(move |tuple| {
                let id = tuple
                    .by_ordinal(1)
                    .expect("The table schema disallows null vector_ids. This shouldn't be NULL")
                    .value::<i64>()
                    .expect("The value should definitely not be null");
                let vec = tuple
                    .by_ordinal(2)
                    .expect("The vector argument can't be null")
                    .value::<Vec<f32>>()
                    .expect("There should definitely be a vector here");
                assert!(vec.len() <= dim);
                VecWithId { id, vec }
            })
    }

    pub(crate) fn push_until_done(
        &mut self,
        mut client: SpiClient,
        shmem: &mut SharedMemHandler,
    ) -> ResultWithSpi<()> {
        let SpiClientWithReturn {
            spi_client,
            return_val,
        } = self.push_batch(shmem, client)?;
        if return_val == 0 {
            Ok(().wrap_with_client(spi_client))
        } else {
            self.push_until_done(spi_client, shmem)
        }
    }

    pub(crate) fn push_all_possible(&mut self, shmem: &mut SharedMemHandler) {
        Spi::connect(|mut client| {
            let count = self.get_n_unindexed(&client);
            if count == 0 {
                Ok(Some(()))
            } else {
                self.push_until_done(client, shmem)
                    .expect("Failed to push vecs to shared buffers");
                Ok(Some(()))
            }
    }

    fn push_batch(
        &mut self,
        shmem: &mut SharedMemHandler,
        mut client: SpiClient,
    ) -> ResultWithSpi<usize> {
        let max_vecs_per_block = (MAX_SPACE / (self.vector_table.dim as usize)) as u32;
        let AllocatorWithCount {
            allocator_index,
            count,
        } = self
            .get_smallest_existing_alloc(&mut client, max_vecs_per_block)
            .or_else(|| {
                println!("Couldn't find existing block. Looking for new block");
                shmem.get_new_block(&mut client, self)
            })
            .ok_or({
                println!("Doing some Stuff");
                SpiError::Transaction
            })?; // Probably unwrap is better
        let vecs_with_ids = self.get_vecs_block(&client, count).collect::<Vec<_>>();
        let inserts = vecs_with_ids
            .iter()
            .map(|vec_with_id| self.push_vector(allocator_index as usize, &mut client, vec_with_id))
            .collect::<Result<Vec<()>, String>>()
            .map_err(|_| SpiError::Transaction)?
            .len(); // todo: if one of these fails, there should be a recovery step.
        println!(
            "Inserted {} vecs in allocator {} with dim {}",
            vecs_with_ids.len(),
            allocator_index,
            self.vector_table.dim
        );
        Ok(inserts.wrap_with_client(client))
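// A hedged usage sketch of the `ReturnWithSpi` helper defined earlier:
// `count_memcache_rows` is hypothetical (not in the source) and only illustrates
// how a return value is bundled with the `SpiClient` so the client can be threaded
// through call chains such as `push_batch`/`push_until_done` above.
fn count_memcache_rows(client: SpiClient) -> ResultWithSpi<i64> {
    let n = client
        .select("SELECT COUNT(*) FROM memcache_index", Some(1), None)
        .first()
        .get_one::<i64>()
        .unwrap_or(0);
    // Hand both the count and the client back to the caller.
    Ok(n.wrap_with_client(client))
}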
/// we can create our own schemas, which are just Rust `mod`s. Anything defined in this module
/// will be created in a Postgres schema of the same name
#[pg_schema]
mod some_schema {
    use pgx::*;
    use serde::{Deserialize, Serialize};

    #[derive(PostgresType, Serialize, Deserialize)]
    pub struct MySomeSchemaType(pub(crate) String);

    #[pg_extern]
    fn hello_some_schema() -> &'static str {
        "Hello from some_schema"
    }
}
-- ./src/vector_query/opclass.rs:5:0
CREATE OR REPLACE FUNCTION "cosine_similarity"("a" real[], "b" real[]) RETURNS real STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'cosine_similarity_wrapper';
-- ./src/vector_query/opclass.rs:27:0
CREATE OR REPLACE FUNCTION "l2_dist"("a" real[], "b" real[]) RETURNS real STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'l2_dist_wrapper';
-- ./src/vector_query/opclass.rs:42:0
CREATE OR REPLACE FUNCTION "ip"("a" real[], "b" real[]) RETURNS real STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'ip_wrapper';
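// Hedged usage sketch (not from the source): calling the generated
// "cosine_similarity" function from Rust through SPI, in the style of the
// commented-out Spi::get_one check in the tests above. The literal vectors are
// illustrative only; orthogonal unit vectors should score 0.0 under the usual
// definition of cosine similarity.
let sim = Spi::get_one::<f32>(
    "SELECT cosine_similarity(ARRAY[1.0, 0.0]::real[], ARRAY[0.0, 1.0]::real[])",
);
assert_eq!(Some(0.0), sim);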
CREATE OR REPLACE FUNCTION amhandler(internal) RETURNS index_am_handler PARALLEL SAFE IMMUTABLE STRICT COST 0.0001 LANGUAGE c AS 'MODULE_PATHNAME', 'amhandler_wrapper';
CREATE ACCESS METHOD vector TYPE INDEX HANDLER amhandler;
CREATE SCHEMA IF NOT EXISTS "tests";
-- ./src/mem_cache_index/mod.rs:269:4
CREATE OR REPLACE FUNCTION tests."test_rebuild"() RETURNS void LANGUAGE c AS 'MODULE_PATHNAME', 'test_rebuild_wrapper';
-- ./src/mem_cache_index/mod.rs:212:0
CREATE TYPE vector;
CREATE OR REPLACE FUNCTION vector_in(cstring) RETURNS vector IMMUTABLE STRICT PARALLEL SAFE LANGUAGE C AS 'MODULE_PATHNAME', 'vector_in_wrapper';
CREATE OR REPLACE FUNCTION vector_out(vector) RETURNS cstring IMMUTABLE STRICT PARALLEL SAFE LANGUAGE C AS 'MODULE_PATHNAME', 'vector_out_wrapper';
CREATE TYPE vector (
    INTERNALLENGTH = variable,
    INPUT = vector_in,
    OUTPUT = vector_out,
    STORAGE = extended
);
CREATE OR REPLACE FUNCTION "set_allocator_dim"("block" integer, "dimemsion" smallint) RETURNS void STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'set_allocator_dim_wrapper';
CREATE OR REPLACE FUNCTION "fill_vec_allocator"("block" integer, "vecs" integer) RETURNS void STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'fill_vec_allocator_wrapper';
CREATE OR REPLACE FUNCTION "push_vec_to_mem"("block" integer, "vec" real[]) RETURNS void STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'push_vec_to_mem_wrapper';
CREATE OR REPLACE FUNCTION "query_vec"("block" integer, "vec" real[], "ids" bigint[]) RETURNS TABLE ("index" bigint, "score" real) STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'query_vec_wrapper';
-- ./src/lib.rs:195:0
-- ./src/lib.rs:182:0
-- ./src/lib.rs:167:0
-- ./src/lib.rs:160:0
[build]
# Postgres symbols won't be available until runtime
rustflags = ["-C", "link-args=-Wl,-undefined,dynamic_lookup"]
# Auto-generated by pgx. You may edit this, or delete it to have a new one created.
[target.x86_64-unknown-linux-gnu]
linker = "./.cargo/pgx-linker-script.sh"

[target.aarch64-unknown-linux-gnu]
linker = "./.cargo/pgx-linker-script.sh"

[target.x86_64-apple-darwin]
linker = "./.cargo/pgx-linker-script.sh"

[target.aarch64-apple-darwin]
linker = "./.cargo/pgx-linker-script.sh"

[target.x86_64-unknown-freebsd]
linker = "./.cargo/pgx-linker-script.sh"