DBCRYUN5Q3DIFGBVVCOWRNBMMWCGILT4TY2OY22IYVH6U7TQIYTAC
ZEBVMUTTIRFH5GXANSVDMMYNA5NAFQ3YJZOLVWUD5RBPCQ5QMYYQC
DH3RVE6JRQBOSRUBVIIMKSJYFS4CK5PI25KV2NAJ7EP4Q7NNKSJQC
IFWHYAZTPKL7ZVOV3AFZNIUTRLMVSOACYACA3CI4DSEEZU77YQEQC
UV2RTMNSNI3AJFVVD66OT3CEF7WU7TLAPJKACGX7BBSSWAEBAAAAC
SIHQ3OG5GIBQLSMOBKSD2IHFTGFXU4VDCX45BH2ABS42BV3N2UIQC
LECAQCCL2UKTMH3LQSGEFVBOF6CYCGJGJZM2I2T4R7QPTBVW4VYQC
G5C3RSELXRNXJFQQ6ZVCXLQE4AJNXYB32RULKHWLWF5SGEFXYTWQC
KZX6RVSGAYRMCPKHROF3L2CHELFYLEAAN4Z7HN4MW2AAEIJJGBCQC
D6U4ZZRN6VJKM4WLEQ3B5E23O5GTSURIRVIV4FPHNEXCSUA63VGAC
V5WTYF3HH6YLBBT3C2IH4D5XTY6W5U7AX6JYUOOJPOEZR7PBWKZAC
id serial8 not null primary key,
index_table_name text NOT NULL,
table_name text NOT NULL,
allocation integer UNIQUE NOT NULL,
dim integer NOT NULL
id serial8 NOT NULL PRIMARY KEY,
memcache_index_id bigint NOT NULL REFERENCES memcache_index(id) ON DELETE CASCADE,
allocation integer UNIQUE NOT NULL
for block in SHMEM_BLOCKS.iter() {
block.exclusive().clear();
}
Ok(Some(()))
});
}
pub fn init_index_table() {
Spi::connect(|mut client| {
client.update(
"
CREATE TABLE IF NOT EXISTS memcache_index (
id serial8 NOT NULL PRIMARY KEY,
index_table_name text UNIQUE NOT NULL,
table_name text NOT NULL,
dim integer NOT NULL
)
",
None,
None,
);
if let (Some(allocation_index), Some(dim), Some(index_table_name), Some(table_name)) =
(alloc_opt, dim_opt, index_table_opt, table_opt)
if let (
Some(allocation_index),
Some(dim),
Some(index_table_name),
Some(table_name),
) = (alloc_opt, dim_opt, index_table_opt, table_opt)
use serde::{Deserialize, Serialize};
/// Inserts `n_vecs` random vectors into `vectors`.
///
/// Every vector has `vectors.dim` components, each drawn from
/// `Uniform::new(-1.0, 1.0)` using `rng`. The vectors are produced lazily and
/// handed to `Vectors::insert_vec_iter` one at a time.
fn insert_random_vecs<R: Rng>(rng: &mut R, vectors: &Vectors, n_vecs: usize) {
    // The distribution never changes, so build it once for all vectors.
    let unif = Uniform::new(-1.0, 1.0);
    let random_vecs = std::iter::repeat_with(|| {
        (0..vectors.dim)
            .map(|_| unif.sample(rng))
            .collect::<Vec<_>>()
    })
    .take(n_vecs);
    vectors.insert_vec_iter(random_vecs);
}
let vectors = VectorsBuilder::new().init(dim);
for _ in 0..n_vecs {
let unif = Uniform::new(-1.0, 1.0);
let vec = (0..dim).map(|_| unif.sample(&mut rng)).collect::<Vec<_>>();
Spi::connect(|mut client| {
client.update(
"INSERT INTO vectors (vec) VALUES ($1)",
None,
Some(vec![(
PgBuiltInOids::FLOAT4ARRAYOID.oid(),
vec.into_datum(),
)]),
);
Ok(Some(()))
});
}
let vectors = VectorsBuilder::new().init("vectors_2", dim);
insert_random_vecs(&mut rng, &vectors, n_vecs);
let mut mem_cache_vectors = MemCacheVectorsBuilder::new().init();
mem_cache_vectors.push_batch(&mut shmem_handler, &vectors);
let mut mem_cache_vectors = MemCacheVectorsBuilder::new().init(&vectors);
mem_cache_vectors.push_all_possible(&mut shmem_handler);
// let c = Spi::get_one::<i32>("SELECT sum(array_length(vec, 1)) FROM vectors");
// assert_eq!(Some((dim as i32) * n_vecs), c);
impl RandDim {
/// Draws a random vector dimension.
///
/// Samples a number of trials from `self.samples`. For each trial it samples
/// `self.sample_on_or_off`, and when that yields `true` it adds one draw of
/// `self.contribution_per_sample` to a running total that starts at 10.
///
/// As written this always returns `Ok`.
fn get_dim<R: Rng>(&self, rng: &mut R) -> Result<u16, rand::Error> {
    let n_trials = self.samples.sample(rng) as usize;
    let dim = (0..n_trials).fold(10u16, |acc, _| {
        // The on/off draw always happens; the contribution draw only when "on".
        if self.sample_on_or_off.sample(rng) {
            acc + self.contribution_per_sample.sample(rng) as u16
        } else {
            acc
        }
    });
    Ok(dim)
}
/// Smoke test for `rebuild`.
///
/// Creates a random number of vector tables with random dimensions, fills each
/// with random vectors, pushes them through a `MemCacheVectors` index, then
/// clears the allocations and calls `rebuild`. There is no explicit assertion:
/// the test passes when none of the steps panics.
#[pg_test]
fn test_rebuild() {
    // Build phase: random generators and the shared memory handle.
    let mut rng = thread_rng();
    let tables_dist = Poisson::new(6.0).unwrap();
    let vecs_dist = Poisson::new(50.0).expect("Bad poisson for n_vecs");
    let dim_gen = RandDim {
        samples: Poisson::new(20.0).expect("Bad Poisson for n_samples"),
        contribution_per_sample: Poisson::new(10.0)
            .expect("Bad Poisson for contribution_per_sample"),
        sample_on_or_off: Bernoulli::new(0.5).expect("Bad Bernouli for on or off."),
    };
    let mut shmem_handler = SharedMemHandler {};

    println!("INITIALISING TABLES");
    init_index_table();
    clear_allocations();

    // One vector table per iteration, named vector_table_1, vector_table_2, ...
    let n_tables = tables_dist.sample(&mut rng) as i32;
    let mut vector_tables = Vec::new();
    for i in 0..n_tables {
        let dim = dim_gen.get_dim(&mut rng).expect("Failed to gen dim.");
        let table_name = format!("vector_table_{}", i + 1);
        let table = VectorsBuilder::new().init(&table_name, dim);
        let n_vecs = 100 * vecs_dist.sample(&mut rng) as usize;
        println!("N_vecs: {}, dim: {}", n_vecs, dim);
        insert_random_vecs(&mut rng, &table, n_vecs);
        vector_tables.push(table);
    }

    // Index every table; the handles are kept alive until the end of the test.
    let mut mem_caches = Vec::new();
    for table in &vector_tables {
        let mut mem_cache = MemCacheVectorsBuilder::new().init(table);
        mem_cache.push_all_possible(&mut shmem_handler);
        mem_caches.push(mem_cache);
    }

    println!("Resetting memcache");
    clear_allocations();
    rebuild(&mut shmem_handler);
}
// todo: test querying
pub(crate) fn init(self, dim: u16) -> Vectors {
Spi::run(&format!(
"CREATE TABLE vectors (
id serial8 not null primary key,
vec real[{}]
)",
dim
));
Vectors { dim }
pub(crate) fn init(self, name: &str, dim: u16) -> Vectors {
Spi::connect(|mut client| {
client.update(
&format!(
"CREATE TABLE {} (
id serial8 not null primary key,
vec real[{}])",
&name, dim
),
None,
None,
);
Ok(Some(()))
});
//.expect(format!("Unable to create vector table {}", name).as_str());
let name = String::from(name);
Vectors { name, dim }
impl Vectors {
pub(crate) fn get_vecs_block<'a>(
&'a self,
client: &'a SpiClient,
n_vectors: u32,
) -> impl Iterator<Item = VecWithId> + 'a {
let dim = self.dim as usize;
client
pub(crate) fn rebuild(shmem: &mut SharedMemHandler) {
Spi::connect(|client| {
let vec_tables_with_index = client
.expect("The table schema disallows null vector_ids. This shouldn't be NULL")
.value::<i64>()
.expect("The value should definitely not be null");
let vec = tuple
.expect("Unable to read vector table name")
.value::<String>()
.expect("Unable to cast vector table name to String");
let index_table_name = heap_tuple
.expect("The vector argument can't be null")
.value::<Vec<f32>>()
.expect("There should definitely be a vector here");
assert!(vec.len() <= dim);
VecWithId { id, vec }
.expect("Unabled to read index table name")
.value::<String>()
.expect("Unable to cast index table name to String");
let dim = heap_tuple
.by_ordinal(3)
.expect("Unable to read dimension")
.value::<i32>()
.expect("Unable to cast dimension to int");
let mut vec = Vectors {
name: vec_table_name,
dim: dim as u16,
};
(vec, index_table_name)
})
.collect::<Vec<_>>();
vec_tables_with_index
.into_iter()
.fold(client, |client, (vec, index_table_name)| {
let mut memcache_index = MemCacheVectors {
vector_table: &vec,
name: index_table_name,
};
memcache_index
.push_until_done(client, shmem)
.expect("Failed to return client")
.spi_client
});
Ok(Some(()))
});
}
/// Newtype around a borrowed slice, used so that a slice can be given its own
/// `IntoDatum` implementation.
struct WrapSlice<'a, T>(&'a [T]);
impl<'a, T: Copy + IntoDatum> IntoDatum for WrapSlice<'a, T> {
fn into_datum(self) -> Option<pg_sys::Datum> {
let WrapSlice(slice) = self;
let mut state = unsafe {
pg_sys::initArrayResult(
T::type_oid(),
PgMemoryContexts::CurrentMemoryContext.value(),
false,
)
};
for s in slice {
let datum = s.into_datum();
let isnull = datum.is_none();
unsafe {
state = pg_sys::accumArrayResult(
state,
datum.unwrap_or(0usize),
isnull,
T::type_oid(),
PgMemoryContexts::CurrentMemoryContext.value(),
);
}
}
if state.is_null() {
// shouldn't happen
None
} else {
Some(unsafe {
pg_sys::makeArrayResult(state, PgMemoryContexts::CurrentMemoryContext.value())
impl Vectors {
    /// Inserts a single vector into this table's `vec` column.
    ///
    /// # Panics
    /// Panics when `vec.len()` differs from `self.dim`.
    pub(crate) fn insert_slice<'a>(&self, vec: &'a [f32]) {
        assert!(vec.len() == (self.dim as usize));
        // todo: figure out sql formatting and schemata
        let statement = format!("INSERT INTO {} (vec) VALUES ($1)", &self.name);
        Spi::connect(|mut client| {
            // The datum is built inside the SPI connection, as a float4 array.
            let args = vec![(
                PgBuiltInOids::FLOAT4ARRAYOID.oid(),
                WrapSlice(vec).into_datum(),
            )];
            client.update(&statement, None, Some(args));
            Ok(Some(()))
        });
    }

    /// Inserts every vector of `vectors`, in order.
    pub(crate) fn insert_vecs<'a>(&self, vectors: &'a [Vec<f32>]) {
        self.insert_slice_iter(vectors.iter().map(Vec::as_slice));
    }

    /// Inserts every slice yielded by `vectors`, one `insert_slice` call each.
    pub(crate) fn insert_slice_iter<'a, Vecs>(&self, vectors: Vecs)
    where
        Vecs: Iterator<Item = &'a [f32]>,
    {
        vectors.for_each(|slice| self.insert_slice(slice));
    }

    /// Inserts every owned vector yielded by `vectors`, one `insert_slice`
    /// call each.
    pub(crate) fn insert_vec_iter<Vecs>(&self, vectors: Vecs)
    where
        Vecs: Iterator<Item = Vec<f32>>,
    {
        vectors.for_each(|owned| self.insert_slice(&owned));
    }
}
i
FROM generate_series(0, $1) range(i)
LEFT JOIN index_has_allocation iha
ON iha.allocation = range.i
WHERE iha.allocation IS NULL
ORDER BY i",
i
FROM generate_series(0, $1) range(i)
LEFT JOIN index_has_allocation iha
ON iha.allocation = range.i
WHERE iha.allocation IS NULL
ORDER BY i",
let inserts = client
let table_id = client
.select(
"
SELECT
id
FROM memcache_index
WHERE index_table_name = $1
",
Some(1),
Some(vec![(
PgBuiltInOids::TEXTOID.oid(),
memcache_table.name.as_str().into_datum(),
)]),
)
.first()
.get_one::<i64>()
.expect(&format!(
"Unable to find entry for table {}",
&memcache_table.name
));
client
pub(crate) fn init(self) -> MemCacheVectors {
Spi::run(
"CREATE TABLE memcache_index_vectors (
id serial8 not null primary key,
vector_id bigint UNIQUE NOT NULL REFERENCES vectors(id),
allocator integer NOT NULL REFERENCES index_has_allocation(allocation),
allocator_index integer NOT NULL,
UNIQUE (allocator, allocator_index)
)",
);
MemCacheVectors {}
pub(crate) fn init(self, vectors: &Vectors) -> MemCacheVectors {
let name = format!("memcache_table_{}", Uuid::new_v4()).replace("-", "_");
Spi::execute(|mut client| {
let id = client
.update(
"
INSERT INTO memcache_index (
index_table_name,
table_name,
dim
) VALUES (
$1,
$2,
$3
) RETURNING id
",
Some(1),
Some(vec![
(PgBuiltInOids::TEXTOID.oid(), name.as_str().into_datum()),
(
PgBuiltInOids::TEXTOID.oid(),
vectors.name.as_str().into_datum(),
),
(
PgBuiltInOids::INT4OID.oid(),
(vectors.dim as i32).into_datum(),
),
]),
)
.first()
.get_one::<i64>()
.expect("Couldn't create an index entry.");
client.update(
// todo: figure out sql formatting and schemata
&format!(
"CREATE TABLE {} (
id serial8 not null primary key,
vector_id bigint UNIQUE NOT NULL REFERENCES {}(id) ON DELETE CASCADE,
allocator integer NOT NULL REFERENCES index_has_allocation(allocation) ON DELETE CASCADE,
allocator_index integer NOT NULL,
UNIQUE (allocator, allocator_index)
)",
&name,
&vectors.name
),
None,
None
);
});
MemCacheVectors {
vector_table: vectors,
name,
}
impl MemCacheVectors {
pub(crate) fn push_batch(&mut self, shmem: &mut SharedMemHandler, vec: &Vectors) {
Spi::connect(|mut client| {
let max_vecs_per_block = (MAX_SPACE / (vec.dim as usize)) as u32;
let AllocatorWithCount {
allocator_index,
count,
} = self
.get_smallest_existing_alloc(&client, max_vecs_per_block)
.or_else(|| {
shmem.get_new_block(&mut client, vec, "vectors", "memcache_index_vectors")
})
.ok_or(SpiError::Transaction)?; // Probably unwrap is better
/// Pairs an `SpiClient` with a value of type `T`, so a function that takes the
/// client by value can hand it back together with its result.
pub(crate) struct SpiClientWithReturn<T> {
// The client being handed back to the caller.
spi_client: SpiClient,
// The value produced alongside the client.
return_val: T,
}
/// Extension trait that bundles any sized value with an `SpiClient`.
trait ReturnWithSpi: Sized {
    /// Wraps `self` and `spi_client` into a `SpiClientWithReturn`.
    fn wrap_with_client(self, spi_client: SpiClient) -> SpiClientWithReturn<Self> {
        let return_val = self;
        SpiClientWithReturn {
            spi_client,
            return_val,
        }
    }
}

/// Blanket implementation: every sized type gets `wrap_with_client`.
impl<T> ReturnWithSpi for T {}
/// Result whose success value carries both an `SpiClient` and a `T`, and whose
/// error is an `SpiError`.
type ResultWithSpi<T> = Result<SpiClientWithReturn<T>, SpiError>;
impl<'a> MemCacheVectors<'a> {
/// Counts the rows of the vector table that have no matching `vector_id` row
/// in this index table (LEFT JOIN filtered on a NULL right-hand side).
///
/// # Panics
/// Panics when the query does not yield a value.
fn get_n_unindexed(&self, client: &SpiClient) -> i64 {
    let query = format!(
        "
SELECT
COUNT(vec.id)
FROM {} vec
LEFT JOIN {} miv
ON miv.vector_id = vec.id
WHERE miv.vector_id IS NULL
",
        self.vector_table.name.as_str(),
        self.name.as_str()
    );
    let unindexed = client
        .select(&query, Some(1), None)
        .first()
        .get_one::<i64>();
    unindexed.expect("Query number of indexed vectors must return one row.")
}
let entries = client
.select("SELECT COUNT(*) FROM memcache_index_vectors", None, None)
.first()
.get_one::<i64>();
let vecs_with_ids = vec.get_vecs_block(&mut client, count).collect::<Vec<_>>();
vecs_with_ids.iter().for_each(|vec_with_id| {
self.push_vector(allocator_index as usize, &mut client, vec_with_id)
.expect("Failed to push vec");
}); // todo if one of these fails, there should be a recovery step.
Ok(Some(()))
/// Selects up to `n_vectors` rows of the vector table that have no matching
/// `vector_id` row in this index table, yielding each as a `VecWithId`.
///
/// # Panics
/// The returned iterator panics when a row's id or vector cannot be read, or
/// when a vector is longer than the table's `dim`.
pub(crate) fn get_vecs_block(
    &self,
    client: &SpiClient,
    n_vectors: u32,
) -> impl Iterator<Item = VecWithId> + 'a {
    let dim = self.vector_table.dim as usize;
    // todo: figure out sql formatting and schemata
    let query = format!(
        "SELECT
vec.id,
vec.vec
FROM {} vec
LEFT JOIN {} miv
ON miv.vector_id = vec.id
WHERE miv.vector_id IS NULL",
        &self.vector_table.name, self.name
    );
    let row_limit = Some(n_vectors as i64);
    client.select(&query, row_limit, None).map(move |row| {
        // Column 1 is the vector id, column 2 the vector itself.
        let id = row
            .by_ordinal(1)
            .expect("The table schema disallows null vector_ids. This shouldn't be NULL")
            .value::<i64>()
            .expect("The value should definitely not be null");
        let vec = row
            .by_ordinal(2)
            .expect("The vector argument can't be null")
            .value::<Vec<f32>>()
            .expect("There should definitely be a vector here");
        assert!(vec.len() <= dim);
        VecWithId { id, vec }
    })
}
/// Calls `push_batch` repeatedly until a batch reports zero inserted vectors,
/// then hands the `SpiClient` back to the caller.
///
/// # Errors
/// Returns the first error reported by `push_batch`.
pub(crate) fn push_until_done(
    &mut self,
    mut client: SpiClient,
    shmem: &mut SharedMemHandler,
) -> ResultWithSpi<()> {
    loop {
        let SpiClientWithReturn {
            spi_client,
            return_val: pushed,
        } = self.push_batch(shmem, client)?;
        if pushed == 0 {
            return Ok(().wrap_with_client(spi_client));
        }
        // More vectors may remain: go again with the client we got back.
        client = spi_client;
    }
}
pub(crate) fn push_all_possible(&mut self, shmem: &mut SharedMemHandler) {
Spi::connect(|mut client| {
let count = self.get_n_unindexed(&client);
if count == 0 {
Ok(Some(()))
} else {
self.push_until_done(client, shmem)
.expect("Failed to push vecs to shared buffers");
Ok(Some(()))
}
}
fn push_batch(
&mut self,
shmem: &mut SharedMemHandler,
mut client: SpiClient,
) -> ResultWithSpi<usize> {
let max_vecs_per_block = (MAX_SPACE / (self.vector_table.dim as usize)) as u32;
let AllocatorWithCount {
allocator_index,
count,
} = self
.get_smallest_existing_alloc(&mut client, max_vecs_per_block)
.or_else(|| {
println!("Couldn't find existing block. Looking for new block");
shmem.get_new_block(&mut client, self)
})
.ok_or({
println!("Doing some Stuff");
SpiError::Transaction
})?; // Probably unwrap is better
let vecs_with_ids = self.get_vecs_block(&client, count).collect::<Vec<_>>();
let inserts = vecs_with_ids
.iter()
.map(|vec_with_id| self.push_vector(allocator_index as usize, &mut client, vec_with_id))
.collect::<Result<Vec<()>, String>>()
.map_err(|_| SpiError::Transaction)?
.len(); // todo if one of these fails, there should be a recovery step.
println!(
"Insert n_vecs {}, in allocator {} with dim {}",
vecs_with_ids.len(),
allocator_index,
self.vector_table.dim
);
Ok(inserts.wrap_with_client(client))
/// we can create our own schemas, which are just Rust `mod`s. Anything defined in this module
/// will be created in a Postgres schema of the same name
#[pg_schema]
mod some_schema {
use pgx::*;
use serde::{Deserialize, Serialize};
/// Newtype around a `String`, derived as a `PostgresType` with serde
/// serialization support.
#[derive(PostgresType, Serialize, Deserialize)]
pub struct MySomeSchemaType(pub(crate) String);
/// Returns the fixed greeting `"Hello from some_schema"`.
#[pg_extern]
fn hello_some_schema() -> &'static str {
"Hello from some_schema"
}
}
-- The three functions below each take two real[] arguments and return real.
-- Each is bound to a C symbol in the extension library ('MODULE_PATHNAME').
-- STRICT: Postgres returns NULL without calling the function when any
-- argument is NULL.
-- ./src/vector_query/opclass.rs:5:0
CREATE OR REPLACE FUNCTION "cosine_similarity"("a" real[], "b" real[]) RETURNS real STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'cosine_similarity_wrapper';
-- ./src/vector_query/opclass.rs:27:0
CREATE OR REPLACE FUNCTION "l2_dist"("a" real[], "b" real[]) RETURNS real STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'l2_dist_wrapper';
-- ./src/vector_query/opclass.rs:42:0
CREATE OR REPLACE FUNCTION "ip"("a" real[], "b" real[]) RETURNS real STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'ip_wrapper';
-- Index access method handler function, bound to the C symbol
-- 'amhandler_wrapper'.
CREATE OR REPLACE FUNCTION amhandler(internal) RETURNS index_am_handler PARALLEL SAFE IMMUTABLE STRICT COST 0.0001 LANGUAGE c AS 'MODULE_PATHNAME', 'amhandler_wrapper';
-- Registers an index access method named "vector" that uses the handler above.
CREATE ACCESS METHOD vector TYPE INDEX HANDLER amhandler;
-- Schema holding the SQL entry points for the extension's tests.
CREATE SCHEMA IF NOT EXISTS "tests";
-- ./src/mem_cache_index/mod.rs:269:4
-- SQL entry point for the test_rebuild test, bound to 'test_rebuild_wrapper'.
CREATE OR REPLACE FUNCTION tests."test_rebuild"() RETURNS void LANGUAGE c AS 'MODULE_PATHNAME', 'test_rebuild_wrapper';
-- ./src/mem_cache_index/mod.rs:212:0
-- Shell type: declares the name "vector" so the input/output functions below
-- can reference it before the full definition.
CREATE TYPE vector;
-- Text input and output functions for the type, bound to C symbols in the
-- extension library.
CREATE OR REPLACE FUNCTION vector_in(cstring) RETURNS vector IMMUTABLE STRICT PARALLEL SAFE LANGUAGE C AS 'MODULE_PATHNAME', 'vector_in_wrapper';
CREATE OR REPLACE FUNCTION vector_out(vector) RETURNS cstring IMMUTABLE STRICT PARALLEL SAFE LANGUAGE C AS 'MODULE_PATHNAME', 'vector_out_wrapper';
-- Full definition: variable-length type with extended storage, using the
-- input/output functions declared above.
CREATE TYPE vector (
INTERNALLENGTH = variable,
INPUT = vector_in,
OUTPUT = vector_out,
STORAGE = extended
);
-- Functions that take a "block" integer as their first argument; each is bound
-- to a C symbol in the extension library. STRICT: a NULL argument yields NULL
-- without the function being called.
-- NOTE(review): the parameter name "dimemsion" looks like a typo of
-- "dimension". This SQL appears to be generated from the Rust sources, so any
-- rename presumably belongs there - confirm before editing.
CREATE OR REPLACE FUNCTION "set_allocator_dim"("block" integer, "dimemsion" smallint) RETURNS void STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'set_allocator_dim_wrapper';
CREATE OR REPLACE FUNCTION "fill_vec_allocator"("block" integer, "vecs" integer) RETURNS void STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'fill_vec_allocator_wrapper';
CREATE OR REPLACE FUNCTION "push_vec_to_mem"("block" integer, "vec" real[]) RETURNS void STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'push_vec_to_mem_wrapper';
-- Returns a set of ("index" bigint, "score" real) rows.
CREATE OR REPLACE FUNCTION "query_vec"("block" integer, "vec" real[], "ids" bigint[]) RETURNS TABLE ("index" bigint, "score" real) STRICT LANGUAGE c AS 'MODULE_PATHNAME', 'query_vec_wrapper';
-- ./src/lib.rs:195:0
-- ./src/lib.rs:182:0
-- ./src/lib.rs:167:0
-- ./src/lib.rs:160:0
[build]
# Postgres symbols won't be available until runtime
rustflags = ["-C", "link-args=-Wl,-undefined,dynamic_lookup"]
# Auto-generated by pgx. You may edit this, or delete it to have a new one created.
[target.x86_64-unknown-linux-gnu]
linker = "./.cargo/pgx-linker-script.sh"
[target.aarch64-unknown-linux-gnu]
linker = "./.cargo/pgx-linker-script.sh"
[target.x86_64-apple-darwin]
linker = "./.cargo/pgx-linker-script.sh"
[target.aarch64-apple-darwin]
linker = "./.cargo/pgx-linker-script.sh"
[target.x86_64-unknown-freebsd]
linker = "./.cargo/pgx-linker-script.sh"