diff --git a/core/index_method/toy_vector_sparse_ivf.rs b/core/index_method/toy_vector_sparse_ivf.rs index 12dca96c1..a3dc20ff2 100644 --- a/core/index_method/toy_vector_sparse_ivf.rs +++ b/core/index_method/toy_vector_sparse_ivf.rs @@ -229,15 +229,17 @@ enum VectorSparseInvertedIndexSearchState { Init, CollectComponentsSeek { sum: f64, - positions: Option>, - components: Option>, + vector: Option>, + idx: usize, + components: Option>, limit: i64, key: Option, }, CollectComponentsRead { sum: f64, - positions: Option>, - components: Option>, + vector: Option>, + idx: usize, + components: Option>, limit: i64, }, Seek { @@ -290,9 +292,17 @@ enum VectorSparseInvertedIndexSearchState { }, } +#[derive(Debug, PartialEq)] +pub enum ScanOrder { + DatasetFrequencyAsc, + QueryWeightDesc, +} + pub struct VectorSparseInvertedIndexMethodCursor { configuration: IndexMethodConfiguration, delta: f64, + scan_portion: f64, + scan_order: ScanOrder, scratch_btree: String, scratch_cursor: Option, stats_btree: String, @@ -348,9 +358,24 @@ impl VectorSparseInvertedIndexMethodCursor { Some(&Value::Float(delta)) => delta, _ => 0.0, }; + let scan_portion = match configuration.parameters.get("scan_portion") { + Some(&Value::Float(scan_portion)) => scan_portion, + _ => 1.0, + }; + let scan_order = match configuration.parameters.get("scan_order") { + Some(Value::Text(scan_order)) if scan_order.as_str() == "dataset_frequency_asc" => { + ScanOrder::DatasetFrequencyAsc + } + Some(Value::Text(scan_order)) if scan_order.as_str() == "query_weight_desc" => { + ScanOrder::QueryWeightDesc + } + _ => ScanOrder::QueryWeightDesc, + }; Self { configuration, delta, + scan_portion, + scan_order, scratch_btree, scratch_cursor: None, stats_btree, @@ -929,7 +954,7 @@ impl IndexMethodCursor for VectorSparseInvertedIndexMethodCursor { "second value must be i64 limit parameter".to_string(), )); }; - let vector = Vector::from_slice(vector)?; + let vector = Vector::from_vec(vector.to_vec())?; if !matches!(vector.vector_type, VectorType::Float32Sparse) { return Err(LimboError::InternalError( "first value must be sparse vector".to_string(), @@ -940,7 +965,8 @@ impl IndexMethodCursor for VectorSparseInvertedIndexMethodCursor { self.search_state = VectorSparseInvertedIndexSearchState::CollectComponentsSeek { sum, - positions: Some(sparse.idx.to_vec().into()), + vector: Some(vector), + idx: 0, components: Some(Vec::new()), key: None, limit, @@ -948,21 +974,39 @@ impl IndexMethodCursor for VectorSparseInvertedIndexMethodCursor { } VectorSparseInvertedIndexSearchState::CollectComponentsSeek { sum, - positions, + vector, + idx, components, limit, key, } => { - let p = positions.as_ref().unwrap(); + let p = &vector.as_ref().unwrap().as_f32_sparse().idx[*idx..]; if p.is_empty() && key.is_none() { let mut components = components.take().unwrap(); - // order by cnt ASC in order to check low-cardinality components first - components.sort_by_key(|c| c.cnt); + match self.scan_order { + ScanOrder::DatasetFrequencyAsc => { + // order by cnt ASC in order to check low-cardinality components first + components.sort_by_key(|(c, _)| c.cnt); + } + ScanOrder::QueryWeightDesc => { + // order by weight DESC in order to check high-impact components first + components + .sort_by_key(|(_, w)| std::cmp::Reverse(FloatOrd(*w as f64))); + } + } + let take = (components.len() as f64 * self.scan_portion).ceil() as usize; + let components = components + .into_iter() + .take(take) + .map(|(c, _)| c) + .collect::>(); tracing::debug!( - "query_start: components: {:?}, delta: {}", + "query_start: components: {:?}, delta: {}, scan_portion: {}, scan_order: {:?}", components, - self.delta + self.delta, + self.scan_portion, + self.scan_order, ); self.search_state = VectorSparseInvertedIndexSearchState::Seek { sum: *sum, @@ -977,7 +1021,7 @@ impl IndexMethodCursor for VectorSparseInvertedIndexMethodCursor { continue; } if key.is_none() { - let position = positions.as_mut().unwrap().pop_front().unwrap(); + let position = vector.as_ref().unwrap().as_f32_sparse().idx[*idx]; *key = Some(ImmutableRecord::from_values( &[Value::Integer(position as i64)], 1, @@ -992,7 +1036,8 @@ impl IndexMethodCursor for VectorSparseInvertedIndexMethodCursor { self.search_state = VectorSparseInvertedIndexSearchState::CollectComponentsRead { sum: *sum, - positions: positions.take(), + vector: vector.take(), + idx: *idx, components: components.take(), limit: *limit, }; @@ -1002,7 +1047,8 @@ impl IndexMethodCursor for VectorSparseInvertedIndexMethodCursor { VectorSparseInvertedIndexSearchState::CollectComponentsSeek { sum: *sum, components: components.take(), - positions: positions.take(), + vector: vector.take(), + idx: *idx + 1, limit: *limit, key: None, }; @@ -1011,18 +1057,21 @@ impl IndexMethodCursor for VectorSparseInvertedIndexMethodCursor { } VectorSparseInvertedIndexSearchState::CollectComponentsRead { sum, - positions, + vector, + idx, components, limit, } => { let record = return_if_io!(stats.record()); + let v = vector.as_ref().unwrap().as_f32_sparse().values[*idx]; let component = parse_stat_row(record.as_deref())?; - components.as_mut().unwrap().push(component); + components.as_mut().unwrap().push((component, v)); self.search_state = VectorSparseInvertedIndexSearchState::CollectComponentsSeek { sum: *sum, components: components.take(), - positions: positions.take(), + vector: vector.take(), + idx: *idx + 1, limit: *limit, key: None, };