Mirror of https://github.com/twitter/the-algorithm.git, synced 2024-11-14 15:45:13 +01:00

commit 3c586de8ec (parent 6c4587804f)

[docx] split commit for file 400

Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/doubleArray.docx
Normal file
Binary file not shown.
@@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class doubleArray {
  private transient long swigCPtr;
  protected transient boolean swigCMemOwn;

  protected doubleArray(long cPtr, boolean cMemoryOwn) {
    swigCMemOwn = cMemoryOwn;
    swigCPtr = cPtr;
  }

  protected static long getCPtr(doubleArray obj) {
    return (obj == null) ? 0 : obj.swigCPtr;
  }

  @SuppressWarnings("deprecation")
  protected void finalize() {
    delete();
  }

  public synchronized void delete() {
    if (swigCPtr != 0) {
      if (swigCMemOwn) {
        swigCMemOwn = false;
        swigfaissJNI.delete_doubleArray(swigCPtr);
      }
      swigCPtr = 0;
    }
  }

  public doubleArray(int nelements) {
    this(swigfaissJNI.new_doubleArray(nelements), true);
  }

  public double getitem(int index) {
    return swigfaissJNI.doubleArray_getitem(swigCPtr, this, index);
  }

  public void setitem(int index, double value) {
    swigfaissJNI.doubleArray_setitem(swigCPtr, this, index, value);
  }

  public SWIGTYPE_p_double cast() {
    long cPtr = swigfaissJNI.doubleArray_cast(swigCPtr, this);
    return (cPtr == 0) ? null : new SWIGTYPE_p_double(cPtr, false);
  }

  public static doubleArray frompointer(SWIGTYPE_p_double t) {
    long cPtr = swigfaissJNI.doubleArray_frompointer(SWIGTYPE_p_double.getCPtr(t));
    return (cPtr == 0) ? null : new doubleArray(cPtr, false);
  }

}
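For context: these SWIG carray wrappers expose a raw C array to Java through swigfaissJNI. A minimal usage sketch, assuming the native library backing swigfaissJNI has already been loaded (for example via System.loadLibrary("swigfaiss")); the variable names are illustrative:

  // Hypothetical sketch: allocate a native double[4], fill it, read it back,
  // and obtain the raw double* for APIs that expect a pointer.
  doubleArray buf = new doubleArray(4);    // allocates 4 doubles on the C heap
  for (int i = 0; i < 4; i++) {
    buf.setitem(i, i * 0.5);               // each call crosses JNI into native memory
  }
  double third = buf.getitem(2);           // reads back 1.0
  SWIGTYPE_p_double raw = buf.cast();      // non-owning double* view of the buffer
  buf.delete();                            // frees the native allocation; raw now dangles

The same pattern applies to the floatArray, intArray, and longArray wrappers later in this diff; delete() must be called (or left to finalize()) because the memory lives outside the Java heap.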
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/floatArray.docx
Normal file
Binary file not shown.
@@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class floatArray {
  private transient long swigCPtr;
  protected transient boolean swigCMemOwn;

  protected floatArray(long cPtr, boolean cMemoryOwn) {
    swigCMemOwn = cMemoryOwn;
    swigCPtr = cPtr;
  }

  protected static long getCPtr(floatArray obj) {
    return (obj == null) ? 0 : obj.swigCPtr;
  }

  @SuppressWarnings("deprecation")
  protected void finalize() {
    delete();
  }

  public synchronized void delete() {
    if (swigCPtr != 0) {
      if (swigCMemOwn) {
        swigCMemOwn = false;
        swigfaissJNI.delete_floatArray(swigCPtr);
      }
      swigCPtr = 0;
    }
  }

  public floatArray(int nelements) {
    this(swigfaissJNI.new_floatArray(nelements), true);
  }

  public float getitem(int index) {
    return swigfaissJNI.floatArray_getitem(swigCPtr, this, index);
  }

  public void setitem(int index, float value) {
    swigfaissJNI.floatArray_setitem(swigCPtr, this, index, value);
  }

  public SWIGTYPE_p_float cast() {
    long cPtr = swigfaissJNI.floatArray_cast(swigCPtr, this);
    return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
  }

  public static floatArray frompointer(SWIGTYPE_p_float t) {
    long cPtr = swigfaissJNI.floatArray_frompointer(SWIGTYPE_p_float.getCPtr(t));
    return (cPtr == 0) ? null : new floatArray(cPtr, false);
  }

}
Binary file not shown.
@@ -1,133 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class float_maxheap_array_t {
  private transient long swigCPtr;
  protected transient boolean swigCMemOwn;

  protected float_maxheap_array_t(long cPtr, boolean cMemoryOwn) {
    swigCMemOwn = cMemoryOwn;
    swigCPtr = cPtr;
  }

  protected static long getCPtr(float_maxheap_array_t obj) {
    return (obj == null) ? 0 : obj.swigCPtr;
  }

  @SuppressWarnings("deprecation")
  protected void finalize() {
    delete();
  }

  public synchronized void delete() {
    if (swigCPtr != 0) {
      if (swigCMemOwn) {
        swigCMemOwn = false;
        swigfaissJNI.delete_float_maxheap_array_t(swigCPtr);
      }
      swigCPtr = 0;
    }
  }

  public void setNh(long value) {
    swigfaissJNI.float_maxheap_array_t_nh_set(swigCPtr, this, value);
  }

  public long getNh() {
    return swigfaissJNI.float_maxheap_array_t_nh_get(swigCPtr, this);
  }

  public void setK(long value) {
    swigfaissJNI.float_maxheap_array_t_k_set(swigCPtr, this, value);
  }

  public long getK() {
    return swigfaissJNI.float_maxheap_array_t_k_get(swigCPtr, this);
  }

  public void setIds(LongVector value) {
    swigfaissJNI.float_maxheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
  }

  public LongVector getIds() {
    return new LongVector(swigfaissJNI.float_maxheap_array_t_ids_get(swigCPtr, this), false);
  }

  public void setVal(SWIGTYPE_p_float value) {
    swigfaissJNI.float_maxheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_float.getCPtr(value));
  }

  public SWIGTYPE_p_float getVal() {
    long cPtr = swigfaissJNI.float_maxheap_array_t_val_get(swigCPtr, this);
    return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
  }

  public SWIGTYPE_p_float get_val(long key) {
    long cPtr = swigfaissJNI.float_maxheap_array_t_get_val(swigCPtr, this, key);
    return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
  }

  public LongVector get_ids(long key) {
    return new LongVector(swigfaissJNI.float_maxheap_array_t_get_ids(swigCPtr, this, key), false);
  }

  public void heapify() {
    swigfaissJNI.float_maxheap_array_t_heapify(swigCPtr, this);
  }

  public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0, long ni) {
    swigfaissJNI.float_maxheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0, ni);
  }

  public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0) {
    swigfaissJNI.float_maxheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0);
  }

  public void addn(long nj, SWIGTYPE_p_float vin, long j0) {
    swigfaissJNI.float_maxheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0);
  }

  public void addn(long nj, SWIGTYPE_p_float vin) {
    swigfaissJNI.float_maxheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0, long ni) {
    swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0) {
    swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride) {
    swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in) {
    swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin) {
    swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
  }

  public void reorder() {
    swigfaissJNI.float_maxheap_array_t_reorder(swigCPtr, this);
  }

  public void per_line_extrema(SWIGTYPE_p_float vals_out, LongVector idx_out) {
    swigfaissJNI.float_maxheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_float.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
  }

  public float_maxheap_array_t() {
    this(swigfaissJNI.new_float_maxheap_array_t(), true);
  }

}
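For context: float_maxheap_array_t wraps faiss's HeapArray<CMax<float, int64_t>>, an array of nh max-heaps of capacity k that collects the k smallest distances per query during a scan. A hedged sketch of the usual fill-and-reorder cycle; the buffer names and counts are illustrative assumptions, and the LongVector ids buffer is assumed to come from elsewhere in the binding since its allocation API is not part of this file:

  // Hypothetical sketch: 2 query heaps, top-5 results each.
  long nh = 2, k = 5;
  floatArray valBuf = new floatArray((int) (nh * k));  // distance storage, nh * k floats

  float_maxheap_array_t heaps = new float_maxheap_array_t();
  heaps.setNh(nh);                    // number of per-query heaps
  heaps.setK(k);                      // capacity of each heap
  heaps.setVal(valBuf.cast());        // float* backing store for distances
  // heaps.setIds(ids);               // LongVector backing store for result labels
  heaps.heapify();                    // put all nh heaps into their initial state
  // ... push candidates during a scan with addn() / addn_with_ids() ...
  heaps.reorder();                    // sort each heap's k entries by distance
  heaps.delete();
  valBuf.delete();

float_minheap_array_t below is the CMin counterpart, used where the k largest scores are kept (e.g. inner-product search), and the int_* variants hold int distances for binary (Hamming) indexes.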
Binary file not shown.
@@ -1,133 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class float_minheap_array_t {
  private transient long swigCPtr;
  protected transient boolean swigCMemOwn;

  protected float_minheap_array_t(long cPtr, boolean cMemoryOwn) {
    swigCMemOwn = cMemoryOwn;
    swigCPtr = cPtr;
  }

  protected static long getCPtr(float_minheap_array_t obj) {
    return (obj == null) ? 0 : obj.swigCPtr;
  }

  @SuppressWarnings("deprecation")
  protected void finalize() {
    delete();
  }

  public synchronized void delete() {
    if (swigCPtr != 0) {
      if (swigCMemOwn) {
        swigCMemOwn = false;
        swigfaissJNI.delete_float_minheap_array_t(swigCPtr);
      }
      swigCPtr = 0;
    }
  }

  public void setNh(long value) {
    swigfaissJNI.float_minheap_array_t_nh_set(swigCPtr, this, value);
  }

  public long getNh() {
    return swigfaissJNI.float_minheap_array_t_nh_get(swigCPtr, this);
  }

  public void setK(long value) {
    swigfaissJNI.float_minheap_array_t_k_set(swigCPtr, this, value);
  }

  public long getK() {
    return swigfaissJNI.float_minheap_array_t_k_get(swigCPtr, this);
  }

  public void setIds(LongVector value) {
    swigfaissJNI.float_minheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
  }

  public LongVector getIds() {
    return new LongVector(swigfaissJNI.float_minheap_array_t_ids_get(swigCPtr, this), false);
  }

  public void setVal(SWIGTYPE_p_float value) {
    swigfaissJNI.float_minheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_float.getCPtr(value));
  }

  public SWIGTYPE_p_float getVal() {
    long cPtr = swigfaissJNI.float_minheap_array_t_val_get(swigCPtr, this);
    return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
  }

  public SWIGTYPE_p_float get_val(long key) {
    long cPtr = swigfaissJNI.float_minheap_array_t_get_val(swigCPtr, this, key);
    return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
  }

  public LongVector get_ids(long key) {
    return new LongVector(swigfaissJNI.float_minheap_array_t_get_ids(swigCPtr, this, key), false);
  }

  public void heapify() {
    swigfaissJNI.float_minheap_array_t_heapify(swigCPtr, this);
  }

  public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0, long ni) {
    swigfaissJNI.float_minheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0, ni);
  }

  public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0) {
    swigfaissJNI.float_minheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0);
  }

  public void addn(long nj, SWIGTYPE_p_float vin, long j0) {
    swigfaissJNI.float_minheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0);
  }

  public void addn(long nj, SWIGTYPE_p_float vin) {
    swigfaissJNI.float_minheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0, long ni) {
    swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0) {
    swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride) {
    swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in) {
    swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_float vin) {
    swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
  }

  public void reorder() {
    swigfaissJNI.float_minheap_array_t_reorder(swigCPtr, this);
  }

  public void per_line_extrema(SWIGTYPE_p_float vals_out, LongVector idx_out) {
    swigfaissJNI.float_minheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_float.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
  }

  public float_minheap_array_t() {
    this(swigfaissJNI.new_float_minheap_array_t(), true);
  }

}
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/intArray.docx
Normal file
Binary file not shown.
@@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class intArray {
  private transient long swigCPtr;
  protected transient boolean swigCMemOwn;

  protected intArray(long cPtr, boolean cMemoryOwn) {
    swigCMemOwn = cMemoryOwn;
    swigCPtr = cPtr;
  }

  protected static long getCPtr(intArray obj) {
    return (obj == null) ? 0 : obj.swigCPtr;
  }

  @SuppressWarnings("deprecation")
  protected void finalize() {
    delete();
  }

  public synchronized void delete() {
    if (swigCPtr != 0) {
      if (swigCMemOwn) {
        swigCMemOwn = false;
        swigfaissJNI.delete_intArray(swigCPtr);
      }
      swigCPtr = 0;
    }
  }

  public intArray(int nelements) {
    this(swigfaissJNI.new_intArray(nelements), true);
  }

  public int getitem(int index) {
    return swigfaissJNI.intArray_getitem(swigCPtr, this, index);
  }

  public void setitem(int index, int value) {
    swigfaissJNI.intArray_setitem(swigCPtr, this, index, value);
  }

  public SWIGTYPE_p_int cast() {
    long cPtr = swigfaissJNI.intArray_cast(swigCPtr, this);
    return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
  }

  public static intArray frompointer(SWIGTYPE_p_int t) {
    long cPtr = swigfaissJNI.intArray_frompointer(SWIGTYPE_p_int.getCPtr(t));
    return (cPtr == 0) ? null : new intArray(cPtr, false);
  }

}
Binary file not shown.
@@ -1,133 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class int_maxheap_array_t {
  private transient long swigCPtr;
  protected transient boolean swigCMemOwn;

  protected int_maxheap_array_t(long cPtr, boolean cMemoryOwn) {
    swigCMemOwn = cMemoryOwn;
    swigCPtr = cPtr;
  }

  protected static long getCPtr(int_maxheap_array_t obj) {
    return (obj == null) ? 0 : obj.swigCPtr;
  }

  @SuppressWarnings("deprecation")
  protected void finalize() {
    delete();
  }

  public synchronized void delete() {
    if (swigCPtr != 0) {
      if (swigCMemOwn) {
        swigCMemOwn = false;
        swigfaissJNI.delete_int_maxheap_array_t(swigCPtr);
      }
      swigCPtr = 0;
    }
  }

  public void setNh(long value) {
    swigfaissJNI.int_maxheap_array_t_nh_set(swigCPtr, this, value);
  }

  public long getNh() {
    return swigfaissJNI.int_maxheap_array_t_nh_get(swigCPtr, this);
  }

  public void setK(long value) {
    swigfaissJNI.int_maxheap_array_t_k_set(swigCPtr, this, value);
  }

  public long getK() {
    return swigfaissJNI.int_maxheap_array_t_k_get(swigCPtr, this);
  }

  public void setIds(LongVector value) {
    swigfaissJNI.int_maxheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
  }

  public LongVector getIds() {
    return new LongVector(swigfaissJNI.int_maxheap_array_t_ids_get(swigCPtr, this), false);
  }

  public void setVal(SWIGTYPE_p_int value) {
    swigfaissJNI.int_maxheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_int.getCPtr(value));
  }

  public SWIGTYPE_p_int getVal() {
    long cPtr = swigfaissJNI.int_maxheap_array_t_val_get(swigCPtr, this);
    return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
  }

  public SWIGTYPE_p_int get_val(long key) {
    long cPtr = swigfaissJNI.int_maxheap_array_t_get_val(swigCPtr, this, key);
    return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
  }

  public LongVector get_ids(long key) {
    return new LongVector(swigfaissJNI.int_maxheap_array_t_get_ids(swigCPtr, this, key), false);
  }

  public void heapify() {
    swigfaissJNI.int_maxheap_array_t_heapify(swigCPtr, this);
  }

  public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0, long ni) {
    swigfaissJNI.int_maxheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0, ni);
  }

  public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0) {
    swigfaissJNI.int_maxheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0);
  }

  public void addn(long nj, SWIGTYPE_p_int vin, long j0) {
    swigfaissJNI.int_maxheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0);
  }

  public void addn(long nj, SWIGTYPE_p_int vin) {
    swigfaissJNI.int_maxheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0, long ni) {
    swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0) {
    swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride) {
    swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in) {
    swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin) {
    swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
  }

  public void reorder() {
    swigfaissJNI.int_maxheap_array_t_reorder(swigCPtr, this);
  }

  public void per_line_extrema(SWIGTYPE_p_int vals_out, LongVector idx_out) {
    swigfaissJNI.int_maxheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_int.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
  }

  public int_maxheap_array_t() {
    this(swigfaissJNI.new_int_maxheap_array_t(), true);
  }

}
Binary file not shown.
@@ -1,133 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class int_minheap_array_t {
  private transient long swigCPtr;
  protected transient boolean swigCMemOwn;

  protected int_minheap_array_t(long cPtr, boolean cMemoryOwn) {
    swigCMemOwn = cMemoryOwn;
    swigCPtr = cPtr;
  }

  protected static long getCPtr(int_minheap_array_t obj) {
    return (obj == null) ? 0 : obj.swigCPtr;
  }

  @SuppressWarnings("deprecation")
  protected void finalize() {
    delete();
  }

  public synchronized void delete() {
    if (swigCPtr != 0) {
      if (swigCMemOwn) {
        swigCMemOwn = false;
        swigfaissJNI.delete_int_minheap_array_t(swigCPtr);
      }
      swigCPtr = 0;
    }
  }

  public void setNh(long value) {
    swigfaissJNI.int_minheap_array_t_nh_set(swigCPtr, this, value);
  }

  public long getNh() {
    return swigfaissJNI.int_minheap_array_t_nh_get(swigCPtr, this);
  }

  public void setK(long value) {
    swigfaissJNI.int_minheap_array_t_k_set(swigCPtr, this, value);
  }

  public long getK() {
    return swigfaissJNI.int_minheap_array_t_k_get(swigCPtr, this);
  }

  public void setIds(LongVector value) {
    swigfaissJNI.int_minheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
  }

  public LongVector getIds() {
    return new LongVector(swigfaissJNI.int_minheap_array_t_ids_get(swigCPtr, this), false);
  }

  public void setVal(SWIGTYPE_p_int value) {
    swigfaissJNI.int_minheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_int.getCPtr(value));
  }

  public SWIGTYPE_p_int getVal() {
    long cPtr = swigfaissJNI.int_minheap_array_t_val_get(swigCPtr, this);
    return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
  }

  public SWIGTYPE_p_int get_val(long key) {
    long cPtr = swigfaissJNI.int_minheap_array_t_get_val(swigCPtr, this, key);
    return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
  }

  public LongVector get_ids(long key) {
    return new LongVector(swigfaissJNI.int_minheap_array_t_get_ids(swigCPtr, this, key), false);
  }

  public void heapify() {
    swigfaissJNI.int_minheap_array_t_heapify(swigCPtr, this);
  }

  public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0, long ni) {
    swigfaissJNI.int_minheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0, ni);
  }

  public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0) {
    swigfaissJNI.int_minheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0);
  }

  public void addn(long nj, SWIGTYPE_p_int vin, long j0) {
    swigfaissJNI.int_minheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0);
  }

  public void addn(long nj, SWIGTYPE_p_int vin) {
    swigfaissJNI.int_minheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0, long ni) {
    swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0) {
    swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride) {
    swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in) {
    swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
  }

  public void addn_with_ids(long nj, SWIGTYPE_p_int vin) {
    swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
  }

  public void reorder() {
    swigfaissJNI.int_minheap_array_t_reorder(swigCPtr, this);
  }

  public void per_line_extrema(SWIGTYPE_p_int vals_out, LongVector idx_out) {
    swigfaissJNI.int_minheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_int.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
  }

  public int_minheap_array_t() {
    this(swigfaissJNI.new_int_minheap_array_t(), true);
  }

}
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/longArray.docx
Normal file
Binary file not shown.
@@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class longArray {
  private transient long swigCPtr;
  protected transient boolean swigCMemOwn;

  protected longArray(long cPtr, boolean cMemoryOwn) {
    swigCMemOwn = cMemoryOwn;
    swigCPtr = cPtr;
  }

  protected static long getCPtr(longArray obj) {
    return (obj == null) ? 0 : obj.swigCPtr;
  }

  @SuppressWarnings("deprecation")
  protected void finalize() {
    delete();
  }

  public synchronized void delete() {
    if (swigCPtr != 0) {
      if (swigCMemOwn) {
        swigCMemOwn = false;
        swigfaissJNI.delete_longArray(swigCPtr);
      }
      swigCPtr = 0;
    }
  }

  public longArray(int nelements) {
    this(swigfaissJNI.new_longArray(nelements), true);
  }

  public long getitem(int index) {
    return swigfaissJNI.longArray_getitem(swigCPtr, this, index);
  }

  public void setitem(int index, long value) {
    swigfaissJNI.longArray_setitem(swigCPtr, this, index, value);
  }

  public SWIGTYPE_p_long_long cast() {
    long cPtr = swigfaissJNI.longArray_cast(swigCPtr, this);
    return (cPtr == 0) ? null : new SWIGTYPE_p_long_long(cPtr, false);
  }

  public static longArray frompointer(SWIGTYPE_p_long_long t) {
    long cPtr = swigfaissJNI.longArray_frompointer(SWIGTYPE_p_long_long.getCPtr(t));
    return (cPtr == 0) ? null : new longArray(cPtr, false);
  }

}
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/swigfaiss.docx
Normal file
Binary file not shown.
@@ -1,575 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public class swigfaiss implements swigfaissConstants {
  public static void bitvec_print(SWIGTYPE_p_unsigned_char b, long d) {
    swigfaissJNI.bitvec_print(SWIGTYPE_p_unsigned_char.getCPtr(b), d);
  }

  public static void fvecs2bitvecs(SWIGTYPE_p_float x, SWIGTYPE_p_unsigned_char b, long d, long n) {
    swigfaissJNI.fvecs2bitvecs(SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_unsigned_char.getCPtr(b), d, n);
  }

  public static void bitvecs2fvecs(SWIGTYPE_p_unsigned_char b, SWIGTYPE_p_float x, long d, long n) {
    swigfaissJNI.bitvecs2fvecs(SWIGTYPE_p_unsigned_char.getCPtr(b), SWIGTYPE_p_float.getCPtr(x), d, n);
  }

  public static void fvec2bitvec(SWIGTYPE_p_float x, SWIGTYPE_p_unsigned_char b, long d) {
    swigfaissJNI.fvec2bitvec(SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_unsigned_char.getCPtr(b), d);
  }

  public static void bitvec_shuffle(long n, long da, long db, SWIGTYPE_p_int order, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b) {
    swigfaissJNI.bitvec_shuffle(n, da, db, SWIGTYPE_p_int.getCPtr(order), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b));
  }

  public static void setHamming_batch_size(long value) {
    swigfaissJNI.hamming_batch_size_set(value);
  }

  public static long getHamming_batch_size() {
    return swigfaissJNI.hamming_batch_size_get();
  }

  public static int popcount64(long x) {
    return swigfaissJNI.popcount64(x);
  }

  public static void hammings(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, long nbytespercode, SWIGTYPE_p_int dis) {
    swigfaissJNI.hammings(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, nbytespercode, SWIGTYPE_p_int.getCPtr(dis));
  }

  public static void hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long ncodes, int ordered) {
    swigfaissJNI.hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, ncodes, ordered);
  }

  public static void hammings_knn(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long ncodes, int ordered) {
    swigfaissJNI.hammings_knn(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, ncodes, ordered);
  }

  public static void hammings_knn_mc(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, long k, long ncodes, SWIGTYPE_p_int distances, LongVector labels) {
    swigfaissJNI.hammings_knn_mc(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, k, ncodes, SWIGTYPE_p_int.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels);
  }

  public static void hamming_range_search(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, int radius, long ncodes, RangeSearchResult result) {
    swigfaissJNI.hamming_range_search(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, radius, ncodes, RangeSearchResult.getCPtr(result), result);
  }

  public static void hamming_count_thres(SWIGTYPE_p_unsigned_char bs1, SWIGTYPE_p_unsigned_char bs2, long n1, long n2, int ht, long ncodes, SWIGTYPE_p_unsigned_long nptr) {
    swigfaissJNI.hamming_count_thres(SWIGTYPE_p_unsigned_char.getCPtr(bs1), SWIGTYPE_p_unsigned_char.getCPtr(bs2), n1, n2, ht, ncodes, SWIGTYPE_p_unsigned_long.getCPtr(nptr));
  }

  public static long match_hamming_thres(SWIGTYPE_p_unsigned_char bs1, SWIGTYPE_p_unsigned_char bs2, long n1, long n2, int ht, long ncodes, LongVector idx, SWIGTYPE_p_int dis) {
    return swigfaissJNI.match_hamming_thres(SWIGTYPE_p_unsigned_char.getCPtr(bs1), SWIGTYPE_p_unsigned_char.getCPtr(bs2), n1, n2, ht, ncodes, SWIGTYPE_p_long_long.getCPtr(idx.data()), idx, SWIGTYPE_p_int.getCPtr(dis));
  }

  public static void crosshamming_count_thres(SWIGTYPE_p_unsigned_char dbs, long n, int ht, long ncodes, SWIGTYPE_p_unsigned_long nptr) {
    swigfaissJNI.crosshamming_count_thres(SWIGTYPE_p_unsigned_char.getCPtr(dbs), n, ht, ncodes, SWIGTYPE_p_unsigned_long.getCPtr(nptr));
  }

  public static int get_num_gpus() {
    return swigfaissJNI.get_num_gpus();
  }

  public static String get_compile_options() {
    return swigfaissJNI.get_compile_options();
  }

  public static double getmillisecs() {
    return swigfaissJNI.getmillisecs();
  }

  public static long get_mem_usage_kb() {
    return swigfaissJNI.get_mem_usage_kb();
  }

  public static long get_cycles() {
    return swigfaissJNI.get_cycles();
  }

  public static void fvec_madd(long n, SWIGTYPE_p_float a, float bf, SWIGTYPE_p_float b, SWIGTYPE_p_float c) {
    swigfaissJNI.fvec_madd(n, SWIGTYPE_p_float.getCPtr(a), bf, SWIGTYPE_p_float.getCPtr(b), SWIGTYPE_p_float.getCPtr(c));
  }

  public static int fvec_madd_and_argmin(long n, SWIGTYPE_p_float a, float bf, SWIGTYPE_p_float b, SWIGTYPE_p_float c) {
    return swigfaissJNI.fvec_madd_and_argmin(n, SWIGTYPE_p_float.getCPtr(a), bf, SWIGTYPE_p_float.getCPtr(b), SWIGTYPE_p_float.getCPtr(c));
  }

  public static void reflection(SWIGTYPE_p_float u, SWIGTYPE_p_float x, long n, long d, long nu) {
    swigfaissJNI.reflection(SWIGTYPE_p_float.getCPtr(u), SWIGTYPE_p_float.getCPtr(x), n, d, nu);
  }

  public static void matrix_qr(int m, int n, SWIGTYPE_p_float a) {
    swigfaissJNI.matrix_qr(m, n, SWIGTYPE_p_float.getCPtr(a));
  }

  public static void ranklist_handle_ties(int k, LongVector idx, SWIGTYPE_p_float dis) {
    swigfaissJNI.ranklist_handle_ties(k, SWIGTYPE_p_long_long.getCPtr(idx.data()), idx, SWIGTYPE_p_float.getCPtr(dis));
  }

  public static long ranklist_intersection_size(long k1, LongVector v1, long k2, LongVector v2) {
    return swigfaissJNI.ranklist_intersection_size(k1, SWIGTYPE_p_long_long.getCPtr(v1.data()), v1, k2, SWIGTYPE_p_long_long.getCPtr(v2.data()), v2);
  }

  public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1, boolean keep_min, long translation) {
    return swigfaissJNI.merge_result_table_with__SWIG_0(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1), keep_min, translation);
  }

  public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1, boolean keep_min) {
    return swigfaissJNI.merge_result_table_with__SWIG_1(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1), keep_min);
  }

  public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1) {
    return swigfaissJNI.merge_result_table_with__SWIG_2(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1));
  }

  public static double imbalance_factor(int n, int k, LongVector assign) {
    return swigfaissJNI.imbalance_factor__SWIG_0(n, k, SWIGTYPE_p_long_long.getCPtr(assign.data()), assign);
  }

  public static double imbalance_factor(int k, SWIGTYPE_p_int hist) {
    return swigfaissJNI.imbalance_factor__SWIG_1(k, SWIGTYPE_p_int.getCPtr(hist));
  }

  public static void fvec_argsort(long n, SWIGTYPE_p_float vals, SWIGTYPE_p_unsigned_long perm) {
    swigfaissJNI.fvec_argsort(n, SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_unsigned_long.getCPtr(perm));
  }

  public static void fvec_argsort_parallel(long n, SWIGTYPE_p_float vals, SWIGTYPE_p_unsigned_long perm) {
    swigfaissJNI.fvec_argsort_parallel(n, SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_unsigned_long.getCPtr(perm));
  }

  public static int ivec_hist(long n, SWIGTYPE_p_int v, int vmax, SWIGTYPE_p_int hist) {
    return swigfaissJNI.ivec_hist(n, SWIGTYPE_p_int.getCPtr(v), vmax, SWIGTYPE_p_int.getCPtr(hist));
  }

  public static void bincode_hist(long n, long nbits, SWIGTYPE_p_unsigned_char codes, SWIGTYPE_p_int hist) {
    swigfaissJNI.bincode_hist(n, nbits, SWIGTYPE_p_unsigned_char.getCPtr(codes), SWIGTYPE_p_int.getCPtr(hist));
  }

  public static long ivec_checksum(long n, SWIGTYPE_p_int a) {
    return swigfaissJNI.ivec_checksum(n, SWIGTYPE_p_int.getCPtr(a));
  }

  public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x, boolean verbose, long seed) {
    long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_0(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x), verbose, seed);
    return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
  }

  public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x, boolean verbose) {
    long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_1(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x), verbose);
    return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
  }

  public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x) {
    long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_2(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x));
    return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
  }

  public static void binary_to_real(long d, SWIGTYPE_p_unsigned_char x_in, SWIGTYPE_p_float x_out) {
    swigfaissJNI.binary_to_real(d, SWIGTYPE_p_unsigned_char.getCPtr(x_in), SWIGTYPE_p_float.getCPtr(x_out));
  }

  public static void real_to_binary(long d, SWIGTYPE_p_float x_in, SWIGTYPE_p_unsigned_char x_out) {
    swigfaissJNI.real_to_binary(d, SWIGTYPE_p_float.getCPtr(x_in), SWIGTYPE_p_unsigned_char.getCPtr(x_out));
  }

  public static long hash_bytes(SWIGTYPE_p_unsigned_char bytes, long n) {
    return swigfaissJNI.hash_bytes(SWIGTYPE_p_unsigned_char.getCPtr(bytes), n);
  }

  public static boolean check_openmp() {
    return swigfaissJNI.check_openmp();
  }

  public static float kmeans_clustering(long d, long n, long k, SWIGTYPE_p_float x, SWIGTYPE_p_float centroids) {
    return swigfaissJNI.kmeans_clustering(d, n, k, SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_float.getCPtr(centroids));
  }

  public static void setIndexPQ_stats(IndexPQStats value) {
    swigfaissJNI.indexPQ_stats_set(IndexPQStats.getCPtr(value), value);
  }

  public static IndexPQStats getIndexPQ_stats() {
    long cPtr = swigfaissJNI.indexPQ_stats_get();
    return (cPtr == 0) ? null : new IndexPQStats(cPtr, false);
  }

  public static void setIndexIVF_stats(IndexIVFStats value) {
    swigfaissJNI.indexIVF_stats_set(IndexIVFStats.getCPtr(value), value);
  }

  public static IndexIVFStats getIndexIVF_stats() {
    long cPtr = swigfaissJNI.indexIVF_stats_get();
    return (cPtr == 0) ? null : new IndexIVFStats(cPtr, false);
  }

  public static short[] getHamdis_tab_ham_bytes() {
    return swigfaissJNI.hamdis_tab_ham_bytes_get();
  }

  public static int generalized_hamming_64(long a) {
    return swigfaissJNI.generalized_hamming_64(a);
  }

  public static void generalized_hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long code_size, int ordered) {
    swigfaissJNI.generalized_hammings_knn_hc__SWIG_0(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, code_size, ordered);
  }

  public static void generalized_hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long code_size) {
    swigfaissJNI.generalized_hammings_knn_hc__SWIG_1(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, code_size);
  }

  public static void check_compatible_for_merge(Index index1, Index index2) {
    swigfaissJNI.check_compatible_for_merge(Index.getCPtr(index1), index1, Index.getCPtr(index2), index2);
  }

  public static IndexIVF extract_index_ivf(Index index) {
    long cPtr = swigfaissJNI.extract_index_ivf__SWIG_0(Index.getCPtr(index), index);
    return (cPtr == 0) ? null : new IndexIVF(cPtr, false);
  }

  public static IndexIVF try_extract_index_ivf(Index index) {
    long cPtr = swigfaissJNI.try_extract_index_ivf__SWIG_0(Index.getCPtr(index), index);
    return (cPtr == 0) ? null : new IndexIVF(cPtr, false);
  }

  public static void merge_into(Index index0, Index index1, boolean shift_ids) {
    swigfaissJNI.merge_into(Index.getCPtr(index0), index0, Index.getCPtr(index1), index1, shift_ids);
  }

  public static void search_centroid(Index index, SWIGTYPE_p_float x, int n, LongVector centroid_ids) {
    swigfaissJNI.search_centroid(Index.getCPtr(index), index, SWIGTYPE_p_float.getCPtr(x), n, SWIGTYPE_p_long_long.getCPtr(centroid_ids.data()), centroid_ids);
  }

  public static void search_and_return_centroids(Index index, long n, SWIGTYPE_p_float xin, int k, SWIGTYPE_p_float distances, LongVector labels, LongVector query_centroid_ids, LongVector result_centroid_ids) {
    swigfaissJNI.search_and_return_centroids(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(xin), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, SWIGTYPE_p_long_long.getCPtr(query_centroid_ids.data()), query_centroid_ids, SWIGTYPE_p_long_long.getCPtr(result_centroid_ids.data()), result_centroid_ids);
  }

  public static ArrayInvertedLists get_invlist_range(Index index, int i0, int i1) {
    long cPtr = swigfaissJNI.get_invlist_range(Index.getCPtr(index), index, i0, i1);
    return (cPtr == 0) ? null : new ArrayInvertedLists(cPtr, false);
  }

  public static void set_invlist_range(Index index, int i0, int i1, ArrayInvertedLists src) {
    swigfaissJNI.set_invlist_range(Index.getCPtr(index), index, i0, i1, ArrayInvertedLists.getCPtr(src), src);
  }

  public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis, SWIGTYPE_p_double ms_per_stage) {
    swigfaissJNI.search_with_parameters__SWIG_0(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis), SWIGTYPE_p_double.getCPtr(ms_per_stage));
  }

  public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis) {
    swigfaissJNI.search_with_parameters__SWIG_1(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis));
  }

  public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params) {
    swigfaissJNI.search_with_parameters__SWIG_2(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params);
  }

  public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis, SWIGTYPE_p_double ms_per_stage) {
    swigfaissJNI.range_search_with_parameters__SWIG_0(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis), SWIGTYPE_p_double.getCPtr(ms_per_stage));
  }

  public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis) {
    swigfaissJNI.range_search_with_parameters__SWIG_1(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis));
  }

  public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params) {
    swigfaissJNI.range_search_with_parameters__SWIG_2(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params);
  }

  public static void setHnsw_stats(HNSWStats value) {
    swigfaissJNI.hnsw_stats_set(HNSWStats.getCPtr(value), value);
  }

  public static HNSWStats getHnsw_stats() {
    long cPtr = swigfaissJNI.hnsw_stats_get();
    return (cPtr == 0) ? null : new HNSWStats(cPtr, false);
  }

  public static void setPrecomputed_table_max_bytes(long value) {
    swigfaissJNI.precomputed_table_max_bytes_set(value);
  }

  public static long getPrecomputed_table_max_bytes() {
    return swigfaissJNI.precomputed_table_max_bytes_get();
  }

  public static void initialize_IVFPQ_precomputed_table(SWIGTYPE_p_int use_precomputed_table, Index quantizer, ProductQuantizer pq, SWIGTYPE_p_AlignedTableT_float_32_t precomputed_table, boolean verbose) {
    swigfaissJNI.initialize_IVFPQ_precomputed_table(SWIGTYPE_p_int.getCPtr(use_precomputed_table), Index.getCPtr(quantizer), quantizer, ProductQuantizer.getCPtr(pq), pq, SWIGTYPE_p_AlignedTableT_float_32_t.getCPtr(precomputed_table), verbose);
  }

  public static void setIndexIVFPQ_stats(IndexIVFPQStats value) {
    swigfaissJNI.indexIVFPQ_stats_set(IndexIVFPQStats.getCPtr(value), value);
  }

  public static IndexIVFPQStats getIndexIVFPQ_stats() {
    long cPtr = swigfaissJNI.indexIVFPQ_stats_get();
    return (cPtr == 0) ? null : new IndexIVFPQStats(cPtr, false);
  }

  public static Index downcast_index(Index index) {
    long cPtr = swigfaissJNI.downcast_index(Index.getCPtr(index), index);
    return (cPtr == 0) ? null : new Index(cPtr, false);
  }

  public static VectorTransform downcast_VectorTransform(VectorTransform vt) {
    long cPtr = swigfaissJNI.downcast_VectorTransform(VectorTransform.getCPtr(vt), vt);
    return (cPtr == 0) ? null : new VectorTransform(cPtr, false);
  }

  public static IndexBinary downcast_IndexBinary(IndexBinary index) {
    long cPtr = swigfaissJNI.downcast_IndexBinary(IndexBinary.getCPtr(index), index);
    return (cPtr == 0) ? null : new IndexBinary(cPtr, false);
  }

  public static Index upcast_IndexShards(IndexShards index) {
    long cPtr = swigfaissJNI.upcast_IndexShards(IndexShards.getCPtr(index), index);
    return (cPtr == 0) ? null : new Index(cPtr, false);
  }

  public static void write_index(Index idx, String fname) {
    swigfaissJNI.write_index__SWIG_0(Index.getCPtr(idx), idx, fname);
  }

  public static void write_index(Index idx, SWIGTYPE_p_FILE f) {
    swigfaissJNI.write_index__SWIG_1(Index.getCPtr(idx), idx, SWIGTYPE_p_FILE.getCPtr(f));
  }

  public static void write_index(Index idx, SWIGTYPE_p_faiss__IOWriter writer) {
    swigfaissJNI.write_index__SWIG_2(Index.getCPtr(idx), idx, SWIGTYPE_p_faiss__IOWriter.getCPtr(writer));
  }

  public static void write_index_binary(IndexBinary idx, String fname) {
    swigfaissJNI.write_index_binary__SWIG_0(IndexBinary.getCPtr(idx), idx, fname);
  }

  public static void write_index_binary(IndexBinary idx, SWIGTYPE_p_FILE f) {
    swigfaissJNI.write_index_binary__SWIG_1(IndexBinary.getCPtr(idx), idx, SWIGTYPE_p_FILE.getCPtr(f));
  }

  public static void write_index_binary(IndexBinary idx, SWIGTYPE_p_faiss__IOWriter writer) {
    swigfaissJNI.write_index_binary__SWIG_2(IndexBinary.getCPtr(idx), idx, SWIGTYPE_p_faiss__IOWriter.getCPtr(writer));
  }

  public static int getIO_FLAG_READ_ONLY() {
    return swigfaissJNI.IO_FLAG_READ_ONLY_get();
  }

  public static int getIO_FLAG_ONDISK_SAME_DIR() {
    return swigfaissJNI.IO_FLAG_ONDISK_SAME_DIR_get();
  }

  public static int getIO_FLAG_SKIP_IVF_DATA() {
    return swigfaissJNI.IO_FLAG_SKIP_IVF_DATA_get();
  }

  public static int getIO_FLAG_MMAP() {
    return swigfaissJNI.IO_FLAG_MMAP_get();
  }

  public static Index read_index(String fname, int io_flags) {
    long cPtr = swigfaissJNI.read_index__SWIG_0(fname, io_flags);
    return (cPtr == 0) ? null : new Index(cPtr, true);
  }

  public static Index read_index(String fname) {
    long cPtr = swigfaissJNI.read_index__SWIG_1(fname);
    return (cPtr == 0) ? null : new Index(cPtr, true);
  }

  public static Index read_index(SWIGTYPE_p_FILE f, int io_flags) {
    long cPtr = swigfaissJNI.read_index__SWIG_2(SWIGTYPE_p_FILE.getCPtr(f), io_flags);
    return (cPtr == 0) ? null : new Index(cPtr, true);
  }

  public static Index read_index(SWIGTYPE_p_FILE f) {
    long cPtr = swigfaissJNI.read_index__SWIG_3(SWIGTYPE_p_FILE.getCPtr(f));
    return (cPtr == 0) ? null : new Index(cPtr, true);
  }

  public static Index read_index(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
    long cPtr = swigfaissJNI.read_index__SWIG_4(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
    return (cPtr == 0) ? null : new Index(cPtr, true);
  }

  public static Index read_index(SWIGTYPE_p_faiss__IOReader reader) {
    long cPtr = swigfaissJNI.read_index__SWIG_5(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
    return (cPtr == 0) ? null : new Index(cPtr, true);
  }
|
|
||||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static IndexBinary read_index_binary(String fname, int io_flags) {
|
|
||||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_0(fname, io_flags);
|
|
||||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static IndexBinary read_index_binary(String fname) {
|
|
||||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_1(fname);
|
|
||||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static IndexBinary read_index_binary(SWIGTYPE_p_FILE f, int io_flags) {
|
|
||||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_2(SWIGTYPE_p_FILE.getCPtr(f), io_flags);
|
|
||||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static IndexBinary read_index_binary(SWIGTYPE_p_FILE f) {
|
|
||||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_3(SWIGTYPE_p_FILE.getCPtr(f));
|
|
||||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static IndexBinary read_index_binary(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
|
|
||||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_4(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
|
|
||||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static IndexBinary read_index_binary(SWIGTYPE_p_faiss__IOReader reader) {
|
|
||||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_5(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
|
|
||||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void write_VectorTransform(VectorTransform vt, String fname) {
|
|
||||||
swigfaissJNI.write_VectorTransform(VectorTransform.getCPtr(vt), vt, fname);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static VectorTransform read_VectorTransform(String fname) {
|
|
||||||
long cPtr = swigfaissJNI.read_VectorTransform(fname);
|
|
||||||
return (cPtr == 0) ? null : new VectorTransform(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ProductQuantizer read_ProductQuantizer(String fname) {
|
|
||||||
long cPtr = swigfaissJNI.read_ProductQuantizer__SWIG_0(fname);
|
|
||||||
return (cPtr == 0) ? null : new ProductQuantizer(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ProductQuantizer read_ProductQuantizer(SWIGTYPE_p_faiss__IOReader reader) {
|
|
||||||
long cPtr = swigfaissJNI.read_ProductQuantizer__SWIG_1(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
|
|
||||||
return (cPtr == 0) ? null : new ProductQuantizer(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void write_ProductQuantizer(ProductQuantizer pq, String fname) {
|
|
||||||
swigfaissJNI.write_ProductQuantizer__SWIG_0(ProductQuantizer.getCPtr(pq), pq, fname);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void write_ProductQuantizer(ProductQuantizer pq, SWIGTYPE_p_faiss__IOWriter f) {
|
|
||||||
swigfaissJNI.write_ProductQuantizer__SWIG_1(ProductQuantizer.getCPtr(pq), pq, SWIGTYPE_p_faiss__IOWriter.getCPtr(f));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void write_InvertedLists(InvertedLists ils, SWIGTYPE_p_faiss__IOWriter f) {
|
|
||||||
swigfaissJNI.write_InvertedLists(InvertedLists.getCPtr(ils), ils, SWIGTYPE_p_faiss__IOWriter.getCPtr(f));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static InvertedLists read_InvertedLists(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
|
|
||||||
long cPtr = swigfaissJNI.read_InvertedLists__SWIG_0(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
|
|
||||||
return (cPtr == 0) ? null : new InvertedLists(cPtr, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static InvertedLists read_InvertedLists(SWIGTYPE_p_faiss__IOReader reader) {
|
|
||||||
long cPtr = swigfaissJNI.read_InvertedLists__SWIG_1(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
|
|
||||||
return (cPtr == 0) ? null : new InvertedLists(cPtr, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Index index_factory(int d, String description, MetricType metric) {
|
|
||||||
long cPtr = swigfaissJNI.index_factory__SWIG_0(d, description, metric.swigValue());
|
|
||||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Index index_factory(int d, String description) {
|
|
||||||
long cPtr = swigfaissJNI.index_factory__SWIG_1(d, description);
|
|
||||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void setIndex_factory_verbose(int value) {
|
|
||||||
swigfaissJNI.index_factory_verbose_set(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static int getIndex_factory_verbose() {
|
|
||||||
return swigfaissJNI.index_factory_verbose_get();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static IndexBinary index_binary_factory(int d, String description) {
|
|
||||||
long cPtr = swigfaissJNI.index_binary_factory(d, description);
|
|
||||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void simd_histogram_8(SWIGTYPE_p_uint16_t data, int n, SWIGTYPE_p_uint16_t min, int shift, SWIGTYPE_p_int hist) {
|
|
||||||
swigfaissJNI.simd_histogram_8(SWIGTYPE_p_uint16_t.getCPtr(data), n, SWIGTYPE_p_uint16_t.getCPtr(min), shift, SWIGTYPE_p_int.getCPtr(hist));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void simd_histogram_16(SWIGTYPE_p_uint16_t data, int n, SWIGTYPE_p_uint16_t min, int shift, SWIGTYPE_p_int hist) {
|
|
||||||
swigfaissJNI.simd_histogram_16(SWIGTYPE_p_uint16_t.getCPtr(data), n, SWIGTYPE_p_uint16_t.getCPtr(min), shift, SWIGTYPE_p_int.getCPtr(hist));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void setPartition_stats(PartitionStats value) {
|
|
||||||
swigfaissJNI.partition_stats_set(PartitionStats.getCPtr(value), value);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static PartitionStats getPartition_stats() {
|
|
||||||
long cPtr = swigfaissJNI.partition_stats_get();
|
|
||||||
return (cPtr == 0) ? null : new PartitionStats(cPtr, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static float CMin_float_partition_fuzzy(SWIGTYPE_p_float vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
|
||||||
return swigfaissJNI.CMin_float_partition_fuzzy(SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static float CMax_float_partition_fuzzy(SWIGTYPE_p_float vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
|
||||||
return swigfaissJNI.CMax_float_partition_fuzzy(SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SWIGTYPE_p_uint16_t CMax_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
|
||||||
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMax_uint16_partition_fuzzy__SWIG_0(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SWIGTYPE_p_uint16_t CMin_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
|
||||||
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMin_uint16_partition_fuzzy__SWIG_0(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SWIGTYPE_p_uint16_t CMax_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, SWIGTYPE_p_int ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
|
||||||
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMax_uint16_partition_fuzzy__SWIG_1(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_int.getCPtr(ids), n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SWIGTYPE_p_uint16_t CMin_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, SWIGTYPE_p_int ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
|
||||||
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMin_uint16_partition_fuzzy__SWIG_1(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_int.getCPtr(ids), n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void omp_set_num_threads(int num_threads) {
|
|
||||||
swigfaissJNI.omp_set_num_threads(num_threads);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static int omp_get_max_threads() {
|
|
||||||
return swigfaissJNI.omp_get_max_threads();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SWIGTYPE_p_void memcpy(SWIGTYPE_p_void dest, SWIGTYPE_p_void src, long n) {
|
|
||||||
long cPtr = swigfaissJNI.memcpy(SWIGTYPE_p_void.getCPtr(dest), SWIGTYPE_p_void.getCPtr(src), n);
|
|
||||||
return (cPtr == 0) ? null : new SWIGTYPE_p_void(cPtr, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SWIGTYPE_p_float cast_integer_to_float_ptr(int x) {
|
|
||||||
long cPtr = swigfaissJNI.cast_integer_to_float_ptr(x);
|
|
||||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SWIGTYPE_p_long cast_integer_to_long_ptr(int x) {
|
|
||||||
long cPtr = swigfaissJNI.cast_integer_to_long_ptr(x);
|
|
||||||
return (cPtr == 0) ? null : new SWIGTYPE_p_long(cPtr, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SWIGTYPE_p_int cast_integer_to_int_ptr(int x) {
|
|
||||||
long cPtr = swigfaissJNI.cast_integer_to_int_ptr(x);
|
|
||||||
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void ignore_SIGTTIN() {
|
|
||||||
swigfaissJNI.ignore_SIGTTIN();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
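The methods above are thin static wrappers over the JNI layer. A minimal usage sketch (not part of the diff), assuming the enclosing generated class is named swigfaiss as in upstream faiss's SWIG output and that MetricType.METRIC_L2 is available in these bindings:

  // Hypothetical round trip through the factory and I/O helpers above.
  Index index = swigfaiss.index_factory(64, "IVF256,PQ8", MetricType.METRIC_L2);
  swigfaiss.write_index(index, "/tmp/ivfpq.faiss");          // serialize to a file path
  Index restored = swigfaiss.read_index("/tmp/ivfpq.faiss"); // returned wrapper owns the native pointer
  swigfaiss.omp_set_num_threads(8);                          // cap OpenMP parallelism used by faiss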
Binary file not shown.
@ -1,15 +0,0 @@
/* ----------------------------------------------------------------------------
 * This file was automatically generated by SWIG (http://www.swig.org).
 * Version 4.0.2
 *
 * Do not make changes to this file unless you know what you are doing--modify
 * the SWIG interface file instead.
 * ----------------------------------------------------------------------------- */

package com.twitter.ann.faiss;

public interface swigfaissConstants {
  public final static int FAISS_VERSION_MAJOR = swigfaissJNI.FAISS_VERSION_MAJOR_get();
  public final static int FAISS_VERSION_MINOR = swigfaissJNI.FAISS_VERSION_MINOR_get();
  public final static int FAISS_VERSION_PATCH = swigfaissJNI.FAISS_VERSION_PATCH_get();
}
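These constants resolve through JNI at class-load time, so reading them is enough to confirm the native library loaded. A trivial sketch (not part of the diff):

  // Prints the version of the bundled native faiss build, e.g. "faiss 1.x.y".
  System.out.println("faiss " + swigfaissConstants.FAISS_VERSION_MAJOR
      + "." + swigfaissConstants.FAISS_VERSION_MINOR
      + "." + swigfaissConstants.FAISS_VERSION_PATCH);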
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/swigfaissJNI.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/swigfaissJNI.docx
Normal file
Binary file not shown.
File diff suppressed because it is too large
@ -1,18 +0,0 @@
java_library(
    sources = ["*.java"],
    compiler_option_sets = ["fatal_warnings"],
    platform = "java8",
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/com/google/guava",
        "3rdparty/jvm/com/google/inject:guice",
        "3rdparty/jvm/com/twitter/bijection:core",
        "3rdparty/jvm/commons-lang",
        "3rdparty/jvm/org/apache/thrift",
        "ann/src/main/scala/com/twitter/ann/common",
        "ann/src/main/thrift/com/twitter/ann/common:ann-common-java",
        "mediaservices/commons/src/main/scala:futuretracker",
        "scrooge/scrooge-core",
        "src/java/com/twitter/search/common/file",
    ],
)
BIN
ann/src/main/java/com/twitter/ann/hnsw/BUILD.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/BUILD.docx
Normal file
Binary file not shown.
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistanceFunction.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistanceFunction.docx
Normal file
Binary file not shown.
@ -1,8 +0,0 @@
package com.twitter.ann.hnsw;

public interface DistanceFunction<T, Q> {
  /**
   * Distance between two items.
   */
  float distance(T t, Q q);
}
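A minimal sketch (not in the original tree) of a typical implementation of this interface, here squared Euclidean distance; both type parameters are float[] since items and queries share a representation:

  // Hypothetical implementation for illustration only.
  public class SquaredEuclideanDistance implements DistanceFunction<float[], float[]> {
    @Override
    public float distance(float[] a, float[] b) {
      float sum = 0f;
      for (int i = 0; i < a.length; i++) {
        float d = a[i] - b[i];
        sum += d * d; // squared L2 preserves the nearest-neighbour ordering of true L2
      }
      return sum;
    }
  }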
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistancedItem.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistancedItem.docx
Normal file
Binary file not shown.
@ -1,23 +0,0 @@
package com.twitter.ann.hnsw;

/**
 * An item associated with a float distance
 * @param <T> The type of the item.
 */
public class DistancedItem<T> {
  private final T item;
  private final float distance;

  public DistancedItem(T item, float distance) {
    this.item = item;
    this.distance = distance;
  }

  public T getItem() {
    return item;
  }

  public float getDistance() {
    return distance;
  }
}
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistancedItemQueue.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistancedItemQueue.docx
Normal file
Binary file not shown.
@ -1,196 +0,0 @@
package com.twitter.ann.hnsw;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Container for items with their distance.
 *
 * @param <U> Type of origin/reference element.
 * @param <T> Type of element that the queue will hold
 */
public class DistancedItemQueue<U, T> implements Iterable<DistancedItem<T>> {
  private final U origin;
  private final DistanceFunction<U, T> distFn;
  private final PriorityQueue<DistancedItem<T>> queue;
  private final boolean minQueue;

  /**
   * Creates a container for items with their distances.
   *
   * @param origin Origin (reference) point
   * @param initial Initial list of elements to add in the structure
   * @param minQueue True for min queue, False for max queue
   * @param distFn Distance function
   */
  public DistancedItemQueue(
      U origin,
      List<T> initial,
      boolean minQueue,
      DistanceFunction<U, T> distFn
  ) {
    this.origin = origin;
    this.distFn = distFn;
    this.minQueue = minQueue;
    final Comparator<DistancedItem<T>> cmp;
    if (minQueue) {
      cmp = (o1, o2) -> Float.compare(o1.getDistance(), o2.getDistance());
    } else {
      cmp = (o1, o2) -> Float.compare(o2.getDistance(), o1.getDistance());
    }
    this.queue = new PriorityQueue<>(cmp);
    enqueueAll(initial);
  }

  private DistancedItemQueue(
      U origin,
      DistanceFunction<U, T> distFn,
      PriorityQueue<DistancedItem<T>> queue,
      boolean minQueue
  ) {
    this.origin = origin;
    this.distFn = distFn;
    this.queue = queue;
    this.minQueue = minQueue;
  }

  /**
   * Enqueues all the items into the queue.
   */
  public void enqueueAll(List<T> list) {
    for (T t : list) {
      enqueue(t);
    }
  }

  /**
   * Returns whether the queue is non-empty.
   *
   * @return true if the queue is not empty, else false
   */
  public boolean nonEmpty() {
    return !queue.isEmpty();
  }

  /**
   * Returns the root of the queue.
   *
   * @return root of the queue, i.e. the min/max element depending on whether this is a min or max queue
   */
  public DistancedItem<T> peek() {
    return queue.peek();
  }

  /**
   * Dequeues the root of the queue.
   *
   * @return remove and return the root of the queue, i.e. the min/max element depending on whether this is a min or max queue
   */
  public DistancedItem<T> dequeue() {
    return queue.poll();
  }

  /**
   * Dequeues all the elements from the queue with ordering maintained.
   *
   * @return all the elements, removed in the order of the queue, i.e. min/max queue order
   */
  public List<DistancedItem<T>> dequeueAll() {
    final List<DistancedItem<T>> list = new ArrayList<>(queue.size());
    while (!queue.isEmpty()) {
      list.add(queue.poll());
    }

    return list;
  }

  /**
   * Converts the queue to a list.
   *
   * @return list of elements of the queue with distances and without any specific ordering
   */
  public List<DistancedItem<T>> toList() {
    return new ArrayList<>(queue);
  }

  /**
   * Converts the queue to a list of bare items.
   *
   * @return list of the items of the queue without distances and without any specific ordering
   */
  List<T> toListWithItem() {
    List<T> list = new ArrayList<>(queue.size());
    Iterator<DistancedItem<T>> itr = iterator();
    while (itr.hasNext()) {
      list.add(itr.next().getItem());
    }
    return list;
  }

  /**
   * Enqueues an item into the queue, computing its distance from the origin.
   */
  public void enqueue(T item) {
    queue.add(new DistancedItem<>(item, distFn.distance(origin, item)));
  }

  /**
   * Enqueues an item into the queue with a precomputed distance.
   */
  public void enqueue(T item, float distance) {
    queue.add(new DistancedItem<>(item, distance));
  }

  /**
   * Size
   *
   * @return size of the queue
   */
  public int size() {
    return queue.size();
  }

  /**
   * Is Min queue
   *
   * @return true if min queue else false
   */
  public boolean isMinQueue() {
    return minQueue;
  }

  /**
   * Returns the origin (base element) of the queue
   *
   * @return origin of the queue
   */
  public U getOrigin() {
    return origin;
  }

  /**
   * Return a new queue with ordering reversed.
   */
  public DistancedItemQueue<U, T> reverse() {
    final PriorityQueue<DistancedItem<T>> rqueue =
        new PriorityQueue<>(queue.comparator().reversed());
    if (queue.isEmpty()) {
      return new DistancedItemQueue<>(origin, distFn, rqueue, !isMinQueue());
    }

    final Iterator<DistancedItem<T>> itr = iterator();
    while (itr.hasNext()) {
      rqueue.add(itr.next());
    }

    return new DistancedItemQueue<>(origin, distFn, rqueue, !isMinQueue());
  }

  @Override
  public Iterator<DistancedItem<T>> iterator() {
    return queue.iterator();
  }
}
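A short usage sketch (hypothetical, not part of the original source): a max queue of strings keyed by a toy length-difference distance from the origin, then drained nearest-first via reverse():

  DistanceFunction<String, String> lengthDist =
      (u, t) -> Math.abs(u.length() - t.length()); // toy distance for illustration
  DistancedItemQueue<String, String> maxQueue = new DistancedItemQueue<>(
      "query", java.util.Arrays.asList("a", "queries", "que"), false, lengthDist);
  DistancedItem<String> farthest = maxQueue.peek(); // root of a max queue is the farthest item
  java.util.List<DistancedItem<String>> nearestFirst = maxQueue.reverse().dequeueAll();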
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswIndex.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswIndex.docx
Normal file
Binary file not shown.
@ -1,711 +0,0 @@
package com.twitter.ann.hnsw;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Function;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import org.apache.thrift.TException;

import com.twitter.ann.common.IndexOutputFile;
import com.twitter.ann.common.thriftjava.HnswInternalIndexMetadata;
import com.twitter.bijection.Injection;
import com.twitter.logging.Logger;
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec;
import com.twitter.search.common.file.AbstractFile;

/**
 * Typed multithreaded HNSW implementation supporting creation/querying of an approximate nearest neighbour index.
 * Paper: https://arxiv.org/pdf/1603.09320.pdf
 * Multithreading impl based on the NMSLIB version: https://github.com/nmslib/hnsw/blob/master/hnswlib/hnswalg.h
 *
 * @param <T> The type of items inserted / searched in the HNSW index.
 * @param <Q> The type of KNN query.
 */
public class HnswIndex<T, Q> {
  private static final Logger LOG = Logger.get(HnswIndex.class);
  private static final String METADATA_FILE_NAME = "hnsw_internal_metadata";
  private static final String GRAPH_FILE_NAME = "hnsw_internal_graph";
  private static final int MAP_SIZE_FACTOR = 5;

  private final DistanceFunction<T, T> distFnIndex;
  private final DistanceFunction<Q, T> distFnQuery;
  private final int efConstruction;
  private final int maxM;
  private final int maxM0;
  private final double levelMultiplier;
  private final AtomicReference<HnswMeta<T>> graphMeta = new AtomicReference<>();
  private final Map<HnswNode<T>, ImmutableList<T>> graph;
  // To take a lock at the vertex level
  private final ConcurrentHashMap<T, ReadWriteLock> locks;
  // To take a lock on the whole graph, only if vertex addition is on a layer above the current maxLevel
  private final ReentrantLock globalLock;
  private final Function<T, ReadWriteLock> lockProvider;

  private final RandomProvider randomProvider;

  // Probability of reevaluating connections of an element in the neighborhood during an update.
  // Can be used as a knob to adjust the update_speed/search_speed tradeoff.
  private final float updateNeighborProbability;

  /**
   * Creates an instance of an hnsw index.
   *
   * @param distFnIndex Any distance metric/non-metric that specifies similarity between two indexed items.
   * @param distFnQuery Any distance metric/non-metric that specifies similarity between the item whose nearest neighbours are queried for and an already indexed item.
   * @param efConstruction Provides a speed vs index quality tradeoff; the higher the value, the better the quality and the higher the time to create the index.
   * Valid values of efConstruction can be anywhere between 1 and tens of thousands. Typically, it should be set so that a search of M
   * neighbors with ef=efConstruction ends in recall>0.95.
   * @param maxM Maximum connections per layer except the 0th level.
   * Optimal values are between 5-48.
   * Smaller M generally produces better results for lower recalls and/or lower dimensional data,
   * while bigger M is better for high recall and/or high dimensional data, at the expense of more memory/disk usage.
   * @param expectedElements Approximate number of elements to be indexed
   */
  protected HnswIndex(
      DistanceFunction<T, T> distFnIndex,
      DistanceFunction<Q, T> distFnQuery,
      int efConstruction,
      int maxM,
      int expectedElements,
      RandomProvider randomProvider
  ) {
    this(distFnIndex,
        distFnQuery,
        efConstruction,
        maxM,
        expectedElements,
        new HnswMeta<>(-1, Optional.empty()),
        new ConcurrentHashMap<>(MAP_SIZE_FACTOR * expectedElements),
        randomProvider
    );
  }

  private HnswIndex(
      DistanceFunction<T, T> distFnIndex,
      DistanceFunction<Q, T> distFnQuery,
      int efConstruction,
      int maxM,
      int expectedElements,
      HnswMeta<T> graphMeta,
      Map<HnswNode<T>, ImmutableList<T>> graph,
      RandomProvider randomProvider
  ) {
    this.distFnIndex = distFnIndex;
    this.distFnQuery = distFnQuery;
    this.efConstruction = efConstruction;
    this.maxM = maxM;
    this.maxM0 = 2 * maxM;
    this.levelMultiplier = 1.0 / Math.log(1.0 * maxM);
    this.graphMeta.set(graphMeta);
    this.graph = graph;
    this.locks = new ConcurrentHashMap<>(MAP_SIZE_FACTOR * expectedElements);
    this.globalLock = new ReentrantLock();
    this.lockProvider = key -> new ReentrantReadWriteLock();
    this.randomProvider = randomProvider;
    this.updateNeighborProbability = 1.0f;
  }

  /**
   * wireConnectionForAllLayers finds connections for a new element and creates bi-directional links.
   * The method assumes a reentrant lock is used to guard link list reads.
   *
   * @param entryPoint the global entry point
   * @param item the item for which the connections are found
   * @param itemLevel the level of the added item (maximum layer in which we wire the connections)
   * @param maxLayer the level of the entry point
   */
  private void wireConnectionForAllLayers(final T entryPoint, final T item, final int itemLevel,
                                          final int maxLayer, final boolean isUpdate) {
    T curObj = entryPoint;
    if (itemLevel < maxLayer) {
      curObj = bestEntryPointUntilLayer(curObj, item, maxLayer, itemLevel, distFnIndex);
    }
    for (int level = Math.min(itemLevel, maxLayer); level >= 0; level--) {
      final DistancedItemQueue<T, T> candidates =
          searchLayerForCandidates(item, curObj, efConstruction, level, distFnIndex, isUpdate);
      curObj = mutuallyConnectNewElement(item, candidates, level, isUpdate);
    }
  }

  /**
   * Inserts the item into the HNSW index.
   */
  public void insert(final T item) throws IllegalDuplicateInsertException {
    final Lock itemLock = locks.computeIfAbsent(item, lockProvider).writeLock();
    itemLock.lock();
    try {
      final HnswMeta<T> metadata = graphMeta.get();
      // If the graph already has the item, it should not be re-inserted.
      // The entry point must be checked as well, in case we reinsert the first item,
      // where there is no graph yet but only an entry point.
      if (graph.containsKey(HnswNode.from(0, item))
          || (metadata.getEntryPoint().isPresent()
          && Objects.equals(metadata.getEntryPoint().get(), item))) {
        throw new IllegalDuplicateInsertException(
            "Duplicate insertion is not supported: " + item);
      }
      final int curLevel = getRandomLevel();
      Optional<T> entryPoint = metadata.getEntryPoint();
      // The global lock prevents two threads from making changes to the entry point. This lock
      // should get taken very infrequently, something like log-base-levelMultiplier(num items) times.
      // For a full explanation of locking see this document: http://go/hnsw-locking
      int maxLevelCopy = metadata.getMaxLevel();
      if (curLevel > maxLevelCopy) {
        globalLock.lock();
        // Re-initialize the entryPoint and maxLevel in case they were changed by another thread.
        // No need to check the condition again, since it is checked at the end before updating
        // the entry point struct. As an optimization, we do not unlock if the condition fails,
        // since threads will rarely enter this section.
        final HnswMeta<T> temp = graphMeta.get();
        entryPoint = temp.getEntryPoint();
        maxLevelCopy = temp.getMaxLevel();
      }

      if (entryPoint.isPresent()) {
        wireConnectionForAllLayers(entryPoint.get(), item, curLevel, maxLevelCopy, false);
      }

      if (curLevel > maxLevelCopy) {
        Preconditions.checkState(globalLock.isHeldByCurrentThread(),
            "Global lock not held before updating entry point");
        graphMeta.set(new HnswMeta<>(curLevel, Optional.of(item)));
      }
    } finally {
      if (globalLock.isHeldByCurrentThread()) {
        globalLock.unlock();
      }
      itemLock.unlock();
    }
  }

  /**
   * Sets the connections of an element with synchronization.
   * The only other place that should hold the lock for writing is during
   * element insertion.
   */
  private void setConnectionList(final T item, int layer, List<T> connections) {
    final Lock candidateLock = locks.computeIfAbsent(item, lockProvider).writeLock();
    candidateLock.lock();
    try {
      graph.put(
          HnswNode.from(layer, item),
          ImmutableList.copyOf(connections)
      );
    } finally {
      candidateLock.unlock();
    }
  }

  /**
   * Reinserts the item into the HNSW index.
   * This method updates the links of an element assuming
   * the element's distance function has changed externally (e.g. by updating the features).
   */
  public void reInsert(final T item) {
    final HnswMeta<T> metadata = graphMeta.get();

    Optional<T> entryPoint = metadata.getEntryPoint();

    Preconditions.checkState(entryPoint.isPresent(),
        "Update cannot be performed if entry point is not present");

    // This is a check for the single element case
    if (entryPoint.get().equals(item) && graph.isEmpty()) {
      return;
    }

    Preconditions.checkState(graph.containsKey(HnswNode.from(0, item)),
        "Graph does not contain the item to be updated at level 0");

    int curLevel = 0;

    int maxLevelCopy = metadata.getMaxLevel();

    for (int layer = maxLevelCopy; layer >= 0; layer--) {
      if (graph.containsKey(HnswNode.from(layer, item))) {
        curLevel = layer;
        break;
      }
    }

    // Updating the links of the elements from the 1-hop radius of the updated element
    for (int layer = 0; layer <= curLevel; layer++) {

      // Filling the element sets for candidates and updated elements
      final HashSet<T> setCand = new HashSet<T>();
      final HashSet<T> setNeigh = new HashSet<T>();
      final List<T> listOneHop = getConnectionListForRead(item, layer);

      if (listOneHop.isEmpty()) {
        LOG.debug("No links for the updated element. Empty dataset?");
        continue;
      }

      setCand.add(item);

      for (T elOneHop : listOneHop) {
        setCand.add(elOneHop);
        if (randomProvider.get().nextFloat() > updateNeighborProbability) {
          continue;
        }
        setNeigh.add(elOneHop);
        final List<T> listTwoHop = getConnectionListForRead(elOneHop, layer);

        if (listTwoHop.isEmpty()) {
          LOG.debug("No links for the updated element. Empty dataset?");
        }

        for (T oneHopEl : listTwoHop) {
          setCand.add(oneHopEl);
        }
      }
      // No need to update the item itself, so remove it
      setNeigh.remove(item);

      // Updating the link lists of elements from setNeigh:
      for (T neigh : setNeigh) {
        final HashSet<T> setCopy = new HashSet<T>(setCand);
        setCopy.remove(neigh);
        int keepElementsNum = Math.min(efConstruction, setCopy.size());
        final DistancedItemQueue<T, T> candidates = new DistancedItemQueue<>(
            neigh,
            ImmutableList.of(),
            false,
            distFnIndex
        );
        for (T cand : setCopy) {
          final float distance = distFnIndex.distance(neigh, cand);
          if (candidates.size() < keepElementsNum) {
            candidates.enqueue(cand, distance);
          } else {
            if (distance < candidates.peek().getDistance()) {
              candidates.dequeue();
              candidates.enqueue(cand, distance);
            }
          }
        }
        final ImmutableList<T> neighbours = selectNearestNeighboursByHeuristic(
            candidates,
            layer == 0 ? maxM0 : maxM
        );

        final List<T> temp = getConnectionListForRead(neigh, layer);
        if (temp.isEmpty()) {
          LOG.debug("Existing links list is empty. Corrupt index");
        }
        if (neighbours.isEmpty()) {
          LOG.debug("Predicted links list is empty. Corrupt index");
        }
        setConnectionList(neigh, layer, neighbours);
      }
    }
    wireConnectionForAllLayers(metadata.getEntryPoint().get(), item, curLevel, maxLevelCopy, true);
  }

  /**
   * This method can be used to get the graph statistics; specifically,
   * it builds the histogram of inbound connections for each element
   * and returns it as a string.
   */
  private String getStats() {
    int histogramMaxBins = 50;
    int[] histogram = new int[histogramMaxBins];
    HashMap<T, Integer> mmap = new HashMap<T, Integer>();
    for (HnswNode<T> key : graph.keySet()) {
      if (key.level == 0) {
        List<T> linkList = getConnectionListForRead(key.item, key.level);
        for (T node : linkList) {
          int a = mmap.computeIfAbsent(node, k -> 0);
          mmap.put(node, a + 1);
        }
      }
    }

    for (T key : mmap.keySet()) {
      int ind = mmap.get(key) < histogramMaxBins - 1 ? mmap.get(key) : histogramMaxBins - 1;
      histogram[ind]++;
    }
    int minNonZeroIndex;
    for (minNonZeroIndex = histogramMaxBins - 1; minNonZeroIndex >= 0; minNonZeroIndex--) {
      if (histogram[minNonZeroIndex] > 0) {
        break;
      }
    }

    String output = "";
    for (int i = 0; i <= minNonZeroIndex; i++) {
      output += "" + i + "\t" + histogram[i] / (0.01f * mmap.keySet().size()) + "\n";
    }

    return output;
  }

  private int getRandomLevel() {
    return (int) (-Math.log(randomProvider.get().nextDouble()) * levelMultiplier);
  }
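  // Sketch of the math behind getRandomLevel (added note, not in the original source):
  // with levelMultiplier = 1 / ln(maxM) and U ~ Uniform(0, 1), the assigned level is
  // floor(-ln(U) / ln(maxM)), so an item reaches level l with probability roughly maxM^-l.
  // This is the exponentially decaying layer assignment from the HNSW paper.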
  /**
   * Note that to avoid deadlocks it is important that this method is called after all the searches
   * of the graph have completed. If you take a lock on any items discovered in the graph after
   * this, you may get stuck waiting on a thread that is waiting for the item to be fully inserted.
   * <p>
   * Note: when using concurrent writers we can miss connections that we would otherwise get.
   * This will reduce the recall.
   * <p>
   * For a full explanation of locking see this document: http://go/hnsw-locking
   * The method returns the closest nearest neighbour (which can be used as an entry point).
   */
  private T mutuallyConnectNewElement(
      final T item,
      final DistancedItemQueue<T, T> candidates, // Max queue
      final int level,
      final boolean isUpdate
  ) {

    // Using maxM here. Its use is ambiguous in the HNSW paper,
    // so following the way it is used in the hnswlib implementation.
    final ImmutableList<T> neighbours = selectNearestNeighboursByHeuristic(candidates, maxM);
    setConnectionList(item, level, neighbours);
    final int M = level == 0 ? maxM0 : maxM;
    for (T nn : neighbours) {
      if (nn.equals(item)) {
        continue;
      }
      final Lock curLock = locks.computeIfAbsent(nn, lockProvider).writeLock();
      curLock.lock();
      try {
        final HnswNode<T> key = HnswNode.from(level, nn);
        final ImmutableList<T> connections = graph.getOrDefault(key, ImmutableList.of());
        final boolean isItemAlreadyPresent = isUpdate && connections.contains(item);

        // If `item` is already present in the neighbouring connections,
        // then there is no need to modify any connections or run the search heuristics.
        if (isItemAlreadyPresent) {
          continue;
        }

        final ImmutableList<T> updatedConnections;
        if (connections.size() < M) {
          final List<T> temp = new ArrayList<>(connections);
          temp.add(item);
          updatedConnections = ImmutableList.copyOf(temp.iterator());
        } else {
          // Max Queue
          final DistancedItemQueue<T, T> queue = new DistancedItemQueue<>(
              nn,
              connections,
              false,
              distFnIndex
          );
          queue.enqueue(item);
          updatedConnections = selectNearestNeighboursByHeuristic(queue, M);
        }
        if (updatedConnections.isEmpty()) {
          LOG.debug("Internal error: predicted links list is empty");
        }

        graph.put(key, updatedConnections);
      } finally {
        curLock.unlock();
      }
    }
    return neighbours.get(0);
  }

  /*
   * bestEntryPointUntilLayer starts the graph search for the item from the entry point
   * until the search reaches the selectedLayer layer.
   * @return a point from the selectedLayer layer that was the closest on the (selectedLayer+1) layer
   */
  private <K> T bestEntryPointUntilLayer(
      final T entryPoint,
      final K item,
      int maxLayer,
      int selectedLayer,
      DistanceFunction<K, T> distFn
  ) {
    T curObj = entryPoint;
    if (selectedLayer < maxLayer) {
      float curDist = distFn.distance(item, curObj);
      for (int level = maxLayer; level > selectedLayer; level--) {
        boolean changed = true;
        while (changed) {
          changed = false;
          final List<T> list = getConnectionListForRead(curObj, level);
          for (T nn : list) {
            final float tempDist = distFn.distance(item, nn);
            if (tempDist < curDist) {
              curDist = tempDist;
              curObj = nn;
              changed = true;
            }
          }
        }
      }
    }

    return curObj;
  }

  @VisibleForTesting
  protected ImmutableList<T> selectNearestNeighboursByHeuristic(
      final DistancedItemQueue<T, T> candidates, // Max queue
      final int maxConnections
  ) {
    Preconditions.checkState(!candidates.isMinQueue(),
        "candidates in selectNearestNeighboursByHeuristic should be a max queue");

    final T baseElement = candidates.getOrigin();
    if (candidates.size() <= maxConnections) {
      List<T> list = candidates.toListWithItem();
      list.remove(baseElement);
      return ImmutableList.copyOf(list);
    } else {
      final List<T> resSet = new ArrayList<>(maxConnections);
      // Min queue for closest elements first
      final DistancedItemQueue<T, T> minQueue = candidates.reverse();
      while (minQueue.nonEmpty()) {
        if (resSet.size() >= maxConnections) {
          break;
        }
        final DistancedItem<T> candidate = minQueue.dequeue();

        // We do not want to create loops;
        // the heuristic is used only for creating the links
        if (candidate.getItem().equals(baseElement)) {
          continue;
        }

        boolean toInclude = true;
        for (T e : resSet) {
          // Do not include the candidate if its distance to any existing item in
          // resSet is closer than its distance to the item. By doing this, the
          // connections of the graph will be more diverse, and in the case of a highly clustered
          // data set, connections will be made between clusters instead of all staying in the
          // same cluster.
          final float dist = distFnIndex.distance(e, candidate.getItem());
          if (dist < candidate.getDistance()) {
            toInclude = false;
            break;
          }
        }

        if (toInclude) {
          resSet.add(candidate.getItem());
        }
      }
      return ImmutableList.copyOf(resSet);
    }
  }

  /**
   * Search the index for the neighbours.
   *
   * @param query Query
   * @param numOfNeighbours Number of neighbours to search for.
   * @param ef This param controls the accuracy of the search.
   * The bigger the ef, the better the accuracy, at the expense of latency.
   * Keep it at least the number of neighbours to find.
   * @return Neighbours
   */
  public List<DistancedItem<T>> searchKnn(final Q query, final int numOfNeighbours, final int ef) {
    final HnswMeta<T> metadata = graphMeta.get();
    if (metadata.getEntryPoint().isPresent()) {
      T entryPoint = bestEntryPointUntilLayer(metadata.getEntryPoint().get(),
          query, metadata.getMaxLevel(), 0, distFnQuery);
      // Get the actual neighbours from the 0th layer
      final List<DistancedItem<T>> neighbours =
          searchLayerForCandidates(query, entryPoint, Math.max(ef, numOfNeighbours),
              0, distFnQuery, false).dequeueAll();
      Collections.reverse(neighbours);
      return neighbours.size() > numOfNeighbours
          ? neighbours.subList(0, numOfNeighbours) : neighbours;
    } else {
      return Collections.emptyList();
    }
  }

  // This method is currently not used.
  // It is needed for debugging purposes only.
  private void checkIntegrity(String message) {
    final HnswMeta<T> metadata = graphMeta.get();
    for (HnswNode<T> node : graph.keySet()) {
      List<T> linkList = graph.get(node);

      for (T el : linkList) {
        if (el.equals(node.item)) {
          LOG.debug(message);
          throw new RuntimeException("integrity check failed");
        }
      }
    }
  }

  private <K> DistancedItemQueue<K, T> searchLayerForCandidates(
      final K item,
      final T entryPoint,
      final int ef,
      final int level,
      final DistanceFunction<K, T> distFn,
      boolean isUpdate
  ) {
    // Min queue
    final DistancedItemQueue<K, T> cQueue = new DistancedItemQueue<>(
        item,
        Collections.singletonList(entryPoint),
        true,
        distFn
    );
    // Max Queue
    final DistancedItemQueue<K, T> wQueue = cQueue.reverse();
    final Set<T> visited = new HashSet<>();
    float lowerBoundDistance = wQueue.peek().getDistance();
    visited.add(entryPoint);

    while (cQueue.nonEmpty()) {
      final DistancedItem<T> candidate = cQueue.peek();
      if (candidate.getDistance() > lowerBoundDistance) {
        break;
      }

      cQueue.dequeue();
      final List<T> list = getConnectionListForRead(candidate.getItem(), level);
      for (T nn : list) {
        if (!visited.contains(nn)) {
          visited.add(nn);
          final float distance = distFn.distance(item, nn);
          if (wQueue.size() < ef || distance < wQueue.peek().getDistance()) {
            cQueue.enqueue(nn, distance);

            if (isUpdate && item.equals(nn)) {
              continue;
            }

            wQueue.enqueue(nn, distance);
            if (wQueue.size() > ef) {
              wQueue.dequeue();
            }

            lowerBoundDistance = wQueue.peek().getDistance();
          }
        }
      }
    }

    return wQueue;
  }

  /**
   * Serialize the hnsw index.
   */
  public void toDirectory(IndexOutputFile indexOutputFile, Injection<T, byte[]> injection)
      throws IOException, TException {
    final int totalGraphEntries = HnswIndexIOUtil.saveHnswGraphEntries(
        graph,
        indexOutputFile.createFile(GRAPH_FILE_NAME).getOutputStream(),
        injection);

    HnswIndexIOUtil.saveMetadata(
        graphMeta.get(),
        efConstruction,
        maxM,
        totalGraphEntries,
        injection,
        indexOutputFile.createFile(METADATA_FILE_NAME).getOutputStream());
  }

  /**
   * Load an hnsw index.
   */
  public static <T, Q> HnswIndex<T, Q> loadHnswIndex(
      DistanceFunction<T, T> distFnIndex,
      DistanceFunction<Q, T> distFnQuery,
      AbstractFile directory,
      Injection<T, byte[]> injection,
      RandomProvider randomProvider) throws IOException, TException {
    final AbstractFile graphFile = directory.getChild(GRAPH_FILE_NAME);
    final AbstractFile metadataFile = directory.getChild(METADATA_FILE_NAME);
    final HnswInternalIndexMetadata metadata = HnswIndexIOUtil.loadMetadata(metadataFile);
    final Map<HnswNode<T>, ImmutableList<T>> graph =
        HnswIndexIOUtil.loadHnswGraph(graphFile, injection, metadata.numElements);
    final ByteBuffer entryPointBB = metadata.entryPoint;
    final HnswMeta<T> graphMeta = new HnswMeta<>(
        metadata.maxLevel,
        entryPointBB == null ? Optional.empty()
            : Optional.of(injection.invert(ArrayByteBufferCodec.decode(entryPointBB)).get())
    );
    return new HnswIndex<>(
        distFnIndex,
        distFnQuery,
        metadata.efConstruction,
        metadata.maxM,
        metadata.numElements,
        graphMeta,
        graph,
        randomProvider
    );
  }

  private List<T> getConnectionListForRead(T node, int level) {
    final Lock curLock = locks.computeIfAbsent(node, lockProvider).readLock();
    curLock.lock();
    final List<T> list;
    try {
      list = graph
          .getOrDefault(HnswNode.from(level, node), ImmutableList.of());
    } finally {
      curLock.unlock();
    }

    return list;
  }

  @VisibleForTesting
  AtomicReference<HnswMeta<T>> getGraphMeta() {
    return graphMeta;
  }

  @VisibleForTesting
  Map<T, ReadWriteLock> getLocks() {
    return locks;
  }

  @VisibleForTesting
  Map<HnswNode<T>, ImmutableList<T>> getGraph() {
    return graph;
  }

  public interface RandomProvider {
    /**
     * RandomProvider interface made public for scala 2.12 compat
     */
    Random get();
  }
}
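A usage sketch (hypothetical, not part of the original source). The constructor is protected, so this assumes calling code in the com.twitter.ann.hnsw package (or a subclass), inside a method that declares throws IllegalDuplicateInsertException; items and queries are both float[] embeddings:

  DistanceFunction<float[], float[]> l2 = (a, b) -> {
    float s = 0f;
    for (int i = 0; i < a.length; i++) {
      float d = a[i] - b[i];
      s += d * d;
    }
    return s; // squared L2 keeps the same neighbour ordering as true L2
  };
  HnswIndex<float[], float[]> index = new HnswIndex<>(
      l2, l2, 200 /* efConstruction */, 16 /* maxM */, 100_000 /* expectedElements */, Random::new);
  index.insert(new float[] {0.1f, 0.2f}); // throws IllegalDuplicateInsertException on repeats
  List<DistancedItem<float[]>> nn =
      index.searchKnn(new float[] {0.1f, 0.2f}, 10 /* neighbours */, 50 /* ef */);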
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswIndexIOUtil.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswIndexIOUtil.docx
Normal file
Binary file not shown.
@ -1,133 +0,0 @@
package com.twitter.ann.hnsw;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import com.google.common.collect.ImmutableList;

import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.TSerializer;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TTransportException;

import com.twitter.ann.common.thriftjava.HnswGraphEntry;
import com.twitter.ann.common.thriftjava.HnswInternalIndexMetadata;
import com.twitter.bijection.Injection;
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec;
import com.twitter.search.common.file.AbstractFile;

public final class HnswIndexIOUtil {
  private HnswIndexIOUtil() {
  }

  /**
   * Save thrift object in file
   */
  public static <T> void saveMetadata(
      HnswMeta<T> graphMeta,
      int efConstruction,
      int maxM,
      int numElements,
      Injection<T, byte[]> injection,
      OutputStream outputStream
  ) throws IOException, TException {
    final int maxLevel = graphMeta.getMaxLevel();
    final HnswInternalIndexMetadata metadata = new HnswInternalIndexMetadata(
        maxLevel,
        efConstruction,
        maxM,
        numElements
    );

    if (graphMeta.getEntryPoint().isPresent()) {
      metadata.setEntryPoint(injection.apply(graphMeta.getEntryPoint().get()));
    }
    final TSerializer serializer = new TSerializer(new TBinaryProtocol.Factory());
    outputStream.write(serializer.serialize(metadata));
    outputStream.close();
  }

  /**
   * Load Hnsw index metadata
   */
  public static HnswInternalIndexMetadata loadMetadata(AbstractFile file)
      throws IOException, TException {
    final HnswInternalIndexMetadata obj = new HnswInternalIndexMetadata();
    final TDeserializer deserializer = new TDeserializer(new TBinaryProtocol.Factory());
    deserializer.deserialize(obj, file.getByteSource().read());
    return obj;
  }

  /**
   * Load Hnsw graph entries from file
   */
  public static <T> Map<HnswNode<T>, ImmutableList<T>> loadHnswGraph(
      AbstractFile file,
      Injection<T, byte[]> injection,
      int numElements
  ) throws IOException, TException {
    final InputStream stream = file.getByteSource().openBufferedStream();
    final TProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(stream));
    final Map<HnswNode<T>, ImmutableList<T>> graph =
        new HashMap<>(numElements);
    while (true) {
      try {
        final HnswGraphEntry entry = new HnswGraphEntry();
        entry.read(protocol);
        final HnswNode<T> node = HnswNode.from(entry.level,
            injection.invert(ArrayByteBufferCodec.decode(entry.key)).get());
        final List<T> list = entry.getNeighbours().stream()
            .map(bb -> injection.invert(ArrayByteBufferCodec.decode(bb)).get())
            .collect(Collectors.toList());
        graph.put(node, ImmutableList.copyOf(list.iterator()));
      } catch (TException e) {
        if (e instanceof TTransportException
            && TTransportException.class.cast(e).getType() == TTransportException.END_OF_FILE) {
          stream.close();
          break;
        }
        stream.close();
        throw e;
      }
    }

    return graph;
  }

  /**
   * Save hnsw graph in file
   *
   * @return number of keys in the graph
   */
  public static <T> int saveHnswGraphEntries(
      Map<HnswNode<T>, ImmutableList<T>> graph,
      OutputStream outputStream,
      Injection<T, byte[]> injection
  ) throws IOException, TException {
    final TProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(outputStream));
    final Set<HnswNode<T>> nodes = graph.keySet();
    for (HnswNode<T> node : nodes) {
      final HnswGraphEntry entry = new HnswGraphEntry();
      entry.setLevel(node.level);
      entry.setKey(injection.apply(node.item));
      final List<ByteBuffer> nn = graph.getOrDefault(node, ImmutableList.of()).stream()
          .map(t -> ByteBuffer.wrap(injection.apply(t)))
          .collect(Collectors.toList());
      entry.setNeighbours(nn);
      entry.write(protocol);
    }

    outputStream.close();
    return nodes.size();
  }
}
|
|
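For orientation, a minimal save-side sketch (not from the original sources; it assumes caller code living in the same com.twitter.ann.hnsw package, since HnswMeta's constructor is package-private, and it borrows JLongInjection from AnnInjections further down in this commit; the numeric parameters are illustrative):

    // Scala sketch: persist metadata for a toy Long-keyed graph.
    import java.io.ByteArrayOutputStream
    import java.util.Optional
    import com.twitter.ann.common.AnnInjections

    val out = new ByteArrayOutputStream()
    HnswIndexIOUtil.saveMetadata(
      new HnswMeta[java.lang.Long](3, Optional.of(java.lang.Long.valueOf(42L))),
      200,   // efConstruction (illustrative)
      16,    // maxM (illustrative)
      10000, // numElements (illustrative)
      AnnInjections.JLongInjection,
      out
    )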
BIN ann/src/main/java/com/twitter/ann/hnsw/HnswMeta.docx (new file; binary file not shown)
@@ -1,45 +0,0 @@
package com.twitter.ann.hnsw;

import java.util.Objects;
import java.util.Optional;

class HnswMeta<T> {
  private final int maxLevel;
  private final Optional<T> entryPoint;

  HnswMeta(int maxLevel, Optional<T> entryPoint) {
    this.maxLevel = maxLevel;
    this.entryPoint = entryPoint;
  }

  public int getMaxLevel() {
    return maxLevel;
  }

  public Optional<T> getEntryPoint() {
    return entryPoint;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    HnswMeta<?> hnswMeta = (HnswMeta<?>) o;
    return maxLevel == hnswMeta.maxLevel
        && Objects.equals(entryPoint, hnswMeta.entryPoint);
  }

  @Override
  public int hashCode() {
    return Objects.hash(maxLevel, entryPoint);
  }

  @Override
  public String toString() {
    return "HnswMeta{maxLevel=" + maxLevel + ", entryPoint=" + entryPoint + '}';
  }
}
BIN ann/src/main/java/com/twitter/ann/hnsw/HnswNode.docx (new file; binary file not shown)
@@ -1,45 +0,0 @@
package com.twitter.ann.hnsw;

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;

public class HnswNode<T> {
  public final int level;
  public final T item;

  public HnswNode(int level, T item) {
    this.level = level;
    this.item = item;
  }

  /**
   * Create a hnsw node.
   */
  public static <T> HnswNode<T> from(int level, T item) {
    return new HnswNode<>(level, item);
  }

  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (!(o instanceof HnswNode)) {
      return false;
    }

    HnswNode<?> that = (HnswNode<?>) o;
    return new EqualsBuilder()
        .append(this.item, that.item)
        .append(this.level, that.level)
        .isEquals();
  }

  @Override
  public int hashCode() {
    return new HashCodeBuilder()
        .append(item)
        .append(level)
        .toHashCode();
  }
}
Binary file not shown.
@@ -1,7 +0,0 @@
package com.twitter.ann.hnsw;

public class IllegalDuplicateInsertException extends Exception {
  public IllegalDuplicateInsertException(String message) {
    super(message);
  }
}
@@ -1,38 +0,0 @@
resources(
    name = "sql",
    sources = ["bq.sql"],
)

python3_library(
    name = "faiss_indexing",
    sources = ["**/*.py"],
    tags = ["bazel-compatible"],
    dependencies = [
        ":sql",
        "3rdparty/python/apache-beam:default",
        "3rdparty/python/faiss-gpu:default",
        "3rdparty/python/gcsfs:default",
        "3rdparty/python/google-cloud-bigquery:default",
        "3rdparty/python/google-cloud-storage",
        "3rdparty/python/numpy:default",
        "3rdparty/python/pandas:default",
        "3rdparty/python/pandas-gbq:default",
        "3rdparty/python/pyarrow:default",
        "src/python/twitter/ml/common/apache_beam",
    ],
)

python37_binary(
    name = "faiss_indexing_bin",
    sources = ["faiss_index_bq_dataset.py"],
    platforms = [
        "current",
        "linux_x86_64",
    ],
    tags = ["no-mypy"],
    zip_safe = False,
    dependencies = [
        ":faiss_indexing",
        "3rdparty/python/_closures/ann/src/main/python/dataflow:faiss_indexing_bin",
    ],
)
BIN ann/src/main/python/dataflow/BUILD.docx (new file; binary file not shown)
BIN ann/src/main/python/dataflow/bq.docx (new file; binary file not shown)
@@ -1,6 +0,0 @@
-- Latest snapshot timestamp in the embedding table.
WITH maxts as (SELECT as value MAX(ts) as ts FROM `twttr-recos-ml-prod.ssedhain.twhin_tweet_avg_embedding`)
-- Embeddings from the newest snapshot, restricted to tweets created within the day before it.
SELECT entityId, embedding
FROM `twttr-recos-ml-prod.ssedhain.twhin_tweet_avg_embedding`
WHERE ts >= (select max(maxts) from maxts)
AND DATE(TIMESTAMP_MILLIS(createdAt)) <= (select max(maxts) from maxts)
AND DATE(TIMESTAMP_MILLIS(createdAt)) >= DATE_SUB((select max(maxts) from maxts), INTERVAL 1 DAY)
BIN ann/src/main/python/dataflow/faiss_index_bq_dataset.docx (new file; binary file not shown)
@@ -1,232 +0,0 @@
import argparse
import logging
import os
import pkgutil
import sys
from urllib.parse import urlsplit

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import faiss


def parse_d6w_config(argv=None):
    """Parse d6w config.

    :param argv: d6w config
    :return: dictionary containing d6w config
    """

    parser = argparse.ArgumentParser(
        description="See https://docbird.twitter.biz/d6w/model.html for any parameters inherited from d6w job config"
    )
    parser.add_argument("--job_name", dest="job_name", required=True, help="d6w attribute")
    parser.add_argument("--project", dest="project", required=True, help="d6w attribute")
    parser.add_argument(
        "--staging_location", dest="staging_location", required=True, help="d6w attribute"
    )
    parser.add_argument("--temp_location", dest="temp_location", required=True, help="d6w attribute")
    parser.add_argument(
        "--output_location",
        dest="output_location",
        required=True,
        help="GCS bucket and path where resulting artifacts are uploaded",
    )
    parser.add_argument(
        "--service_account_email", dest="service_account_email", required=True, help="d6w attribute"
    )
    parser.add_argument(
        "--factory_string",
        dest="factory_string",
        required=False,
        help="FAISS factory string describing index to build. See https://github.com/facebookresearch/faiss/wiki/The-index-factory",
    )
    parser.add_argument(
        "--metric",
        dest="metric",
        required=True,
        help="Metric used to compute distance between embeddings. Valid values are 'l2', 'ip', 'l1', 'linf'",
    )
    parser.add_argument(
        "--use_gpu",
        dest="gpu",
        required=True,
        help="--use_gpu=yes if you want to use GPU during index building",
    )

    known_args, unknown_args = parser.parse_known_args(argv)
    d6w_config = vars(known_args)
    d6w_config["gpu"] = d6w_config["gpu"].lower() == "yes"
    d6w_config["metric"] = parse_metric(d6w_config)

    # WARNING: Currently, d6w (a Twitter tool used to deploy Dataflow jobs to GCP) and
    # PipelineOptions.for_dataflow_runner (a helper method in twitter.ml.common.apache_beam) do not
    # play nicely together. The helper method will overwrite some of the config specified in the d6w
    # file using the defaults in https://sourcegraph.twitter.biz/git.twitter.biz/source/-/blob/src/python/twitter/ml/common/apache_beam/__init__.py?L24.
    # However, the d6w output message will still report that the config specified in the d6w file was used.
    logging.warning(
        f"The following d6w config parameters will be overwritten by the defaults in "
        f"https://sourcegraph.twitter.biz/git.twitter.biz/source/-/blob/src/python/twitter/ml/common/apache_beam/__init__.py?L24\n"
        f"{str(unknown_args)}"
    )
    return d6w_config


def get_bq_query():
    """The query is expected to return rows with a unique entityId."""
    return pkgutil.get_data(__name__, "bq.sql").decode("utf-8")


def parse_metric(config):
    metric_str = config["metric"].lower()
    if metric_str == "l2":
        return faiss.METRIC_L2
    elif metric_str == "ip":
        return faiss.METRIC_INNER_PRODUCT
    elif metric_str == "l1":
        return faiss.METRIC_L1
    elif metric_str == "linf":
        return faiss.METRIC_Linf
    else:
        raise Exception(f"Unknown metric: {metric_str}")


def run_pipeline(argv=[]):
    config = parse_d6w_config(argv)
    argv_with_extras = argv
    if config["gpu"]:
        argv_with_extras.extend(["--experiments", "use_runner_v2"])
        argv_with_extras.extend(
            ["--experiments", "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver"]
        )
        argv_with_extras.extend(
            [
                "--worker_harness_container_image",
                "gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7",
            ]
        )

    options = PipelineOptions(argv_with_extras)
    output_bucket_name = urlsplit(config["output_location"]).netloc

    with beam.Pipeline(options=options) as p:
        input_data = p | "Read from BigQuery" >> beam.io.ReadFromBigQuery(
            method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
            query=get_bq_query(),
            use_standard_sql=True,
        )

        index_built = input_data | "Build and upload index" >> beam.CombineGlobally(
            MergeAndBuildIndex(
                output_bucket_name,
                config["output_location"],
                config["factory_string"],
                config["metric"],
                config["gpu"],
            )
        )

        # Make linter happy
        index_built


class MergeAndBuildIndex(beam.CombineFn):
    def __init__(self, bucket_name, gcs_output_path, factory_string, metric, gpu):
        self.bucket_name = bucket_name
        self.gcs_output_path = gcs_output_path
        self.factory_string = factory_string
        self.metric = metric
        self.gpu = gpu

    def create_accumulator(self):
        return []

    def add_input(self, accumulator, element):
        accumulator.append(element)
        return accumulator

    def merge_accumulators(self, accumulators):
        merged = []
        for accum in accumulators:
            merged.extend(accum)
        return merged

    def extract_output(self, rows):
        # Reimports are needed on workers
        import glob
        import subprocess

        import faiss
        from google.cloud import storage
        import numpy as np

        client = storage.Client()
        bucket = client.get_bucket(self.bucket_name)

        logging.info("Building FAISS index")
        logging.info(f"There are {len(rows)} rows")

        ids = np.array([x["entityId"] for x in rows]).astype("long")
        embeds = np.array([x["embedding"] for x in rows]).astype("float32")
        dimensions = len(embeds[0])
        N = ids.shape[0]
        logging.info(f"There are {dimensions} dimensions")

        if self.factory_string is None:
            # Default index recipe: an OPQ pre-rotation, an IVF coarse quantizer with
            # roughly 20 points per cluster, and 48-subquantizer PQ codes. For example,
            # dimensions=200 and N=1,000,000 yields "OPQ48_192,IVF50000,PQ48".
            M = 48

            divideable_dimensions = (dimensions // M) * M
            if divideable_dimensions != dimensions:
                opq_prefix = f"OPQ{M}_{divideable_dimensions}"
            else:
                opq_prefix = f"OPQ{M}"

            clusters = N // 20
            self.factory_string = f"{opq_prefix},IVF{clusters},PQ{M}"

        logging.info(f"Factory string is {self.factory_string}, metric={self.metric}")

        if self.gpu:
            logging.info("Using GPU")

            res = faiss.StandardGpuResources()
            cpu_index = faiss.index_factory(dimensions, self.factory_string, self.metric)
            cpu_index = faiss.IndexIDMap(cpu_index)
            gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
            gpu_index.train(embeds)
            gpu_index.add_with_ids(embeds, ids)
            cpu_index = faiss.index_gpu_to_cpu(gpu_index)
        else:
            logging.info("Using CPU")

            cpu_index = faiss.index_factory(dimensions, self.factory_string, self.metric)
            cpu_index = faiss.IndexIDMap(cpu_index)
            cpu_index.train(embeds)
            cpu_index.add_with_ids(embeds, ids)

        logging.info("Built faiss index")

        local_path = "/indices"
        logging.info(f"Writing indices to local {local_path}")
        subprocess.run(f"mkdir -p {local_path}".strip().split())
        local_index_path = os.path.join(local_path, "result.index")

        faiss.write_index(cpu_index, local_index_path)
        logging.info(f"Done writing indices to local {local_path}")

        logging.info(f"Uploading to GCS with path {self.gcs_output_path}")
        assert os.path.isdir(local_path)
        for local_file in glob.glob(local_path + "/*"):
            remote_path = os.path.join(
                self.gcs_output_path.split("/")[-1], local_file[1 + len(local_path) :]
            )
            blob = bucket.blob(remote_path)
            blob.upload_from_filename(local_file)


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    run_pipeline(sys.argv)
@@ -1,34 +0,0 @@
FROM --platform=linux/amd64 nvidia/cuda:11.2.2-cudnn8-runtime-ubuntu20.04

RUN \
  # Add Deadsnakes repository that has a variety of Python packages for Ubuntu.
  # See: https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa
  apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F23C5A6CF475977595C89F51BA6932366A755776 \
  && echo "deb http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal main" >> /etc/apt/sources.list.d/custom.list \
  && echo "deb-src http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal main" >> /etc/apt/sources.list.d/custom.list \
  && apt-get update \
  && apt-get install -y curl \
  python3.7 \
  # With the python3.7 package, distutils needs to be installed separately.
  python3.7-distutils \
  python3-dev \
  python3.7-dev \
  libpython3.7-dev \
  python3-apt \
  gcc \
  g++ \
  && rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.7 10
RUN rm -f /usr/bin/python3 && ln -s /usr/bin/python3.7 /usr/bin/python3
RUN \
  curl https://bootstrap.pypa.io/get-pip.py | python \
  && pip3 install pip==22.0.3 \
  && python3 -m pip install --no-cache-dir apache-beam[gcp]==2.39.0
# Verify that there are no conflicting dependencies.
RUN pip3 check

# Copy the Apache Beam worker dependencies from the Beam Python 3.7 SDK image.
COPY --from=apache/beam_python3.7_sdk:2.39.0 /opt/apache/beam /opt/apache/beam

# Set the entrypoint to Apache Beam SDK worker launcher.
ENTRYPOINT [ "/opt/apache/beam/boot" ]
BIN ann/src/main/python/dataflow/worker_harness/Dockerfile.docx (new file; binary file not shown)
BIN ann/src/main/python/dataflow/worker_harness/cloudbuild.docx (new file; binary file not shown)
@@ -1,6 +0,0 @@
steps:
- name: 'gcr.io/cloud-builders/docker'
  args: ['build', '-t', 'gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7', '.']
- name: 'gcr.io/cloud-builders/docker'
  args: ['push', 'gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7']
images: ['gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7']
BIN ann/src/main/scala/com/twitter/ann/annoy/AnnoyCommon.docx (new file; binary file not shown)
@@ -1,44 +0,0 @@
package com.twitter.ann.annoy

import com.twitter.ann.common.RuntimeParams
import com.twitter.ann.common.thriftscala.AnnoyIndexMetadata
import com.twitter.bijection.Injection
import com.twitter.mediaservices.commons.codec.ThriftByteBufferCodec
import com.twitter.ann.common.thriftscala.{AnnoyRuntimeParam, RuntimeParams => ServiceRuntimeParams}
import scala.util.{Failure, Success, Try}

object AnnoyCommon {
  private[annoy] lazy val MetadataCodec = new ThriftByteBufferCodec(AnnoyIndexMetadata)
  private[annoy] val IndexFileName = "annoy_index"
  private[annoy] val MetaDataFileName = "annoy_index_metadata"
  private[annoy] val IndexIdMappingFileName = "annoy_index_id_mapping"

  val RuntimeParamsInjection: Injection[AnnoyRuntimeParams, ServiceRuntimeParams] =
    new Injection[AnnoyRuntimeParams, ServiceRuntimeParams] {
      override def apply(scalaParams: AnnoyRuntimeParams): ServiceRuntimeParams = {
        ServiceRuntimeParams.AnnoyParam(
          AnnoyRuntimeParam(
            scalaParams.nodesToExplore
          )
        )
      }

      override def invert(thriftParams: ServiceRuntimeParams): Try[AnnoyRuntimeParams] =
        thriftParams match {
          case ServiceRuntimeParams.AnnoyParam(annoyParam) =>
            Success(
              AnnoyRuntimeParams(annoyParam.numOfNodesToExplore)
            )
          case p => Failure(new IllegalArgumentException(s"Expected AnnoyRuntimeParams got $p"))
        }
    }
}

case class AnnoyRuntimeParams(
  /* Number of vectors to evaluate while searching. A larger value gives more accurate results
   * but takes longer to return. Defaults to numberOfTrees * numberOfNeighboursRequested.
   */
  nodesToExplore: Option[Int])
    extends RuntimeParams {
  override def toString: String = s"AnnoyRuntimeParams(nodesToExplore = $nodesToExplore)"
}
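As a quick illustration (a sketch, not from the original sources; the 5000 is an arbitrary example value), the injection round-trips runtime params through their thrift form:

    // Scala sketch: to thrift and back; invert returns a Try.
    val thriftParams = AnnoyCommon.RuntimeParamsInjection(AnnoyRuntimeParams(Some(5000)))
    val roundTripped = AnnoyCommon.RuntimeParamsInjection.invert(thriftParams) // Success(AnnoyRuntimeParams(Some(5000)))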
@@ -1,23 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    platform = "java8",
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/com/spotify:annoy-java",
        "3rdparty/jvm/com/spotify:annoy-snapshot",
        "3rdparty/jvm/com/twitter/storehaus:core",
        "ann/src/main/scala/com/twitter/ann/common",
        "ann/src/main/scala/com/twitter/ann/file_store",
        "ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
        "mediaservices/commons",
        "src/java/com/twitter/search/common/file",
        "src/scala/com/twitter/ml/api/embedding",
    ],
    exports = [
        "ann/src/main/scala/com/twitter/ann/common",
        "src/java/com/twitter/common_internal/hadoop",
        "src/java/com/twitter/search/common/file",
        "src/scala/com/twitter/ml/api/embedding",
    ],
)
BIN ann/src/main/scala/com/twitter/ann/annoy/BUILD.docx (new file; binary file not shown)
Binary file not shown.
@@ -1,123 +0,0 @@
package com.twitter.ann.annoy

import com.spotify.annoy.jni.base.{Annoy => AnnoyLib}
import com.twitter.ann.annoy.AnnoyCommon.IndexFileName
import com.twitter.ann.annoy.AnnoyCommon.MetaDataFileName
import com.twitter.ann.annoy.AnnoyCommon.MetadataCodec
import com.twitter.ann.common.EmbeddingType._
import com.twitter.ann.common._
import com.twitter.ann.common.thriftscala.AnnoyIndexMetadata
import com.twitter.concurrent.AsyncSemaphore
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
import com.twitter.search.common.file.AbstractFile
import com.twitter.search.common.file.LocalFile
import com.twitter.util.Future
import com.twitter.util.FuturePool
import java.io.File
import java.nio.file.Files
import org.apache.beam.sdk.io.fs.ResourceId
import scala.collection.JavaConverters._

private[annoy] object RawAnnoyIndexBuilder {
  private[annoy] def apply[D <: Distance[D]](
    dimension: Int,
    numOfTrees: Int,
    metric: Metric[D],
    futurePool: FuturePool
  ): RawAppendable[AnnoyRuntimeParams, D] with Serialization = {
    val indexBuilder = AnnoyLib.newIndex(dimension, annoyMetric(metric))
    new RawAnnoyIndexBuilder(dimension, numOfTrees, metric, indexBuilder, futurePool)
  }

  private[this] def annoyMetric(metric: Metric[_]): AnnoyLib.Metric = {
    metric match {
      case L2 => AnnoyLib.Metric.EUCLIDEAN
      case Cosine => AnnoyLib.Metric.ANGULAR
      case _ => throw new RuntimeException("Not supported: " + metric)
    }
  }
}

private[this] class RawAnnoyIndexBuilder[D <: Distance[D]](
  dimension: Int,
  numOfTrees: Int,
  metric: Metric[D],
  indexBuilder: AnnoyLib.Builder,
  futurePool: FuturePool)
    extends RawAppendable[AnnoyRuntimeParams, D]
    with Serialization {
  private[this] var counter = 0
  // Note: only one thread can access the underlying index; multithreaded index building is not supported.
  private[this] val semaphore = new AsyncSemaphore(1)

  override def append(embedding: EmbeddingVector): Future[Long] =
    semaphore.acquireAndRun({
      counter += 1
      indexBuilder.addItem(
        counter,
        embedding.toArray
          .map(float => float2Float(float))
          .toList
          .asJava
      )

      Future.value(counter)
    })

  override def toQueryable: Queryable[Long, AnnoyRuntimeParams, D] = {
    val tempDirParent = Files.createTempDirectory("raw_annoy_index").toFile
    tempDirParent.deleteOnExit
    val tempDir = new LocalFile(tempDirParent)
    this.toDirectory(tempDir)
    RawAnnoyQueryIndex(
      dimension,
      metric,
      futurePool,
      tempDir
    )
  }

  override def toDirectory(directory: ResourceId): Unit = {
    toDirectory(new IndexOutputFile(directory))
  }

  /**
   * Serialize the annoy index in a directory.
   * @param directory: Directory to save to.
   */
  override def toDirectory(directory: AbstractFile): Unit = {
    toDirectory(new IndexOutputFile(directory))
  }

  private def toDirectory(directory: IndexOutputFile): Unit = {
    val indexFile = directory.createFile(IndexFileName)
    saveIndex(indexFile)

    val metaDataFile = directory.createFile(MetaDataFileName)
    saveMetadata(metaDataFile)
  }

  private[this] def saveIndex(indexFile: IndexOutputFile): Unit = {
    val index = indexBuilder
      .build(numOfTrees)
    val temp = new LocalFile(File.createTempFile(IndexFileName, null))
    index.save(temp.getPath)
    indexFile.copyFrom(temp.getByteSource.openStream())
    temp.delete()
  }

  private[this] def saveMetadata(metadataFile: IndexOutputFile): Unit = {
    val numberOfVectorsIndexed = counter
    val metadata = AnnoyIndexMetadata(
      dimension,
      Metric.toThrift(metric),
      numOfTrees,
      numberOfVectorsIndexed
    )
    val bytes = ArrayByteBufferCodec.decode(MetadataCodec.encode(metadata))
    val temp = new LocalFile(File.createTempFile(MetaDataFileName, null))
    temp.getByteSink.write(bytes)
    metadataFile.copyFrom(temp.getByteSource.openStream())
    temp.delete()
  }
}
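The append path above serializes writers with an AsyncSemaphore(1) rather than a lock, so callers queue up without blocking threads. A minimal sketch of that discipline in isolation (illustrative only):

    // Scala sketch: one logical writer at a time, non-blocking callers.
    import com.twitter.concurrent.AsyncSemaphore
    import com.twitter.util.Future

    val writer = new AsyncSemaphore(1)
    def guardedAppend(work: => Long): Future[Long] =
      writer.acquireAndRun(Future.value(work)) // `work` runs only once the permit is held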
BIN ann/src/main/scala/com/twitter/ann/annoy/RawAnnoyQueryIndex.docx (new file; binary file not shown)
@@ -1,142 +0,0 @@
package com.twitter.ann.annoy

import com.spotify.annoy.{ANNIndex, IndexType}
import com.twitter.ann.annoy.AnnoyCommon._
import com.twitter.ann.common._
import com.twitter.ann.common.EmbeddingType._
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
import com.twitter.search.common.file.{AbstractFile, LocalFile}
import com.twitter.util.{Future, FuturePool}
import java.io.File
import scala.collection.JavaConverters._

private[annoy] object RawAnnoyQueryIndex {
  private[annoy] def apply[D <: Distance[D]](
    dimension: Int,
    metric: Metric[D],
    futurePool: FuturePool,
    directory: AbstractFile
  ): Queryable[Long, AnnoyRuntimeParams, D] = {
    val metadataFile = directory.getChild(MetaDataFileName)
    val indexFile = directory.getChild(IndexFileName)
    val metadata = MetadataCodec.decode(
      ArrayByteBufferCodec.encode(metadataFile.getByteSource.read())
    )

    val existingDimension = metadata.dimension
    assert(
      existingDimension == dimension,
      s"Dimensions do not match. requested: $dimension existing: $existingDimension"
    )

    val existingMetric = Metric.fromThrift(metadata.distanceMetric)
    assert(
      existingMetric == metric,
      s"DistanceMetric do not match. requested: $metric existing: $existingMetric"
    )

    val index = loadIndex(indexFile, dimension, annoyMetric(metric))
    new RawAnnoyQueryIndex[D](
      dimension,
      metric,
      metadata.numOfTrees,
      index,
      futurePool
    )
  }

  private[this] def annoyMetric(metric: Metric[_]): IndexType = {
    metric match {
      case L2 => IndexType.EUCLIDEAN
      case Cosine => IndexType.ANGULAR
      case _ => throw new RuntimeException("Not supported: " + metric)
    }
  }

  private[this] def loadIndex(
    indexFile: AbstractFile,
    dimension: Int,
    indexType: IndexType
  ): ANNIndex = {
    var localIndexFile = indexFile

    // If not a local file, copy to local so that it can be memory mapped.
    if (!indexFile.isInstanceOf[LocalFile]) {
      val tempFile = File.createTempFile(IndexFileName, null)
      tempFile.deleteOnExit()

      val temp = new LocalFile(tempFile)
      indexFile.copyTo(temp)
      localIndexFile = temp
    }

    new ANNIndex(
      dimension,
      localIndexFile.getPath(),
      indexType
    )
  }
}

private[this] class RawAnnoyQueryIndex[D <: Distance[D]](
  dimension: Int,
  metric: Metric[D],
  numOfTrees: Int,
  index: ANNIndex,
  futurePool: FuturePool)
    extends Queryable[Long, AnnoyRuntimeParams, D]
    with AutoCloseable {
  override def query(
    embedding: EmbeddingVector,
    numOfNeighbours: Int,
    runtimeParams: AnnoyRuntimeParams
  ): Future[List[Long]] = {
    queryWithDistance(embedding, numOfNeighbours, runtimeParams)
      .map(_.map(_.neighbor))
  }

  override def queryWithDistance(
    embedding: EmbeddingVector,
    numOfNeighbours: Int,
    runtimeParams: AnnoyRuntimeParams
  ): Future[List[NeighborWithDistance[Long, D]]] = {
    futurePool {
      val queryVector = embedding.toArray
      val neigboursToRequest = neighboursToRequest(numOfNeighbours, runtimeParams)
      val neigbours = index
        .getNearestWithDistance(queryVector, neigboursToRequest)
        .asScala
        .take(numOfNeighbours)
        .map { nn =>
          val id = nn.getFirst.toLong
          val distance = metric.fromAbsoluteDistance(nn.getSecond)
          NeighborWithDistance(id, distance)
        }
        .toList

      neigbours
    }
  }

  // The Annoy Java lib does not expose a param for numOfNodesToExplore; its default is
  // numOfTrees * numOfNeighbours. The workaround is to artificially inflate the number of
  // neighbours requested and cap the result before returning. For example, with
  // numOfTrees = 10 and nodesToExplore = Some(5000), a query for 20 neighbours internally
  // requests 5000 / 10 = 500 candidates and then keeps the first 20.
  private[this] def neighboursToRequest(
    numOfNeighbours: Int,
    annoyParams: AnnoyRuntimeParams
  ): Int = {
    annoyParams.nodesToExplore match {
      case Some(nodesToExplore) => {
        val neigboursToRequest = nodesToExplore / numOfTrees
        if (neigboursToRequest < numOfNeighbours)
          numOfNeighbours
        else
          neigboursToRequest
      }
      case _ => numOfNeighbours
    }
  }

  // To close the memory-mapped file resource.
  override def close(): Unit = index.close()
}
BIN ann/src/main/scala/com/twitter/ann/annoy/TypedAnnoyIndex.docx (new file; binary file not shown)
@@ -1,55 +0,0 @@
package com.twitter.ann.annoy

import com.twitter.ann.common._
import com.twitter.bijection.Injection
import com.twitter.search.common.file.AbstractFile
import com.twitter.util.FuturePool

// Object providing an Annoy-based ann index.
object TypedAnnoyIndex {

  /**
   * Create an Annoy-based typed index builder that serializes the index to a directory (HDFS/local file system).
   * It cannot be used in scalding as it leverages C/C++ JNI bindings, whose build conflicts with versions of some libs installed on hadoop.
   * You can use it on aurora, or with the IndexBuilding job, which triggers a scalding job but then streams data to an aurora machine for building the index.
   * @param dimension dimension of embedding
   * @param numOfTrees builds a forest of numOfTrees trees.
   *                   More trees give higher precision when querying, at the cost of increased memory and disk storage at build time.
   *                   At runtime the index will be memory mapped, so memory won't be an issue, but disk storage is needed.
   * @param metric distance metric for nearest neighbour search
   * @param injection Injection to convert bytes to Id.
   * @tparam T Type of Id for embedding
   * @tparam D Typed Distance
   * @return Serializable AnnoyIndex
   */
  def indexBuilder[T, D <: Distance[D]](
    dimension: Int,
    numOfTrees: Int,
    metric: Metric[D],
    injection: Injection[T, Array[Byte]],
    futurePool: FuturePool
  ): Appendable[T, AnnoyRuntimeParams, D] with Serialization = {
    TypedAnnoyIndexBuilderWithFile(dimension, numOfTrees, metric, injection, futurePool)
  }

  /**
   * Load an Annoy-based queryable index from a directory.
   * @param dimension dimension of embedding
   * @param metric distance metric for nearest neighbour search
   * @param injection Injection to convert bytes to Id.
   * @param futurePool FuturePool
   * @param directory Directory (HDFS/local file system) where the serialized index is stored.
   * @tparam T Type of Id for embedding
   * @tparam D Typed Distance
   * @return Typed Queryable AnnoyIndex
   */
  def loadQueryableIndex[T, D <: Distance[D]](
    dimension: Int,
    metric: Metric[D],
    injection: Injection[T, Array[Byte]],
    futurePool: FuturePool,
    directory: AbstractFile
  ): Queryable[T, AnnoyRuntimeParams, D] = {
    TypedAnnoyQueryIndexWithFile(dimension, metric, injection, futurePool, directory)
  }
}
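A hedged end-to-end sketch of the build side (not from the original sources; it assumes Long ids, that Cosine in com.twitter.ann.common is a Metric[CosineDistance], an unbounded FuturePool, and caller-provided vectors whose length matches the dimension):

    // Scala sketch: build a typed Annoy index from caller-provided vectors.
    import com.twitter.ann.common.EmbeddingType.EmbeddingVector
    import com.twitter.ann.common.{AnnInjections, Cosine, CosineDistance, EntityEmbedding, Queryable}
    import com.twitter.util.{Await, FuturePool}

    def buildIndex(vectors: Seq[(Long, EmbeddingVector)]): Queryable[Long, AnnoyRuntimeParams, CosineDistance] = {
      val builder = TypedAnnoyIndex.indexBuilder[Long, CosineDistance](
        dimension = 128, // must match the vectors' dimension
        numOfTrees = 20,
        metric = Cosine,
        injection = AnnInjections.LongInjection,
        futurePool = FuturePool.unboundedPool
      )
      vectors.foreach { case (id, vec) => Await.result(builder.append(EntityEmbedding(id, vec))) }
      builder.toQueryable
    }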
Binary file not shown.
@@ -1,55 +0,0 @@
package com.twitter.ann.annoy

import com.twitter.ann.annoy.AnnoyCommon.IndexIdMappingFileName
import com.twitter.ann.common._
import com.twitter.ann.file_store.WritableIndexIdFileStore
import com.twitter.bijection.Injection
import com.twitter.search.common.file.AbstractFile
import com.twitter.util.Future
import com.twitter.util.FuturePool
import org.apache.beam.sdk.io.fs.ResourceId

private[annoy] object TypedAnnoyIndexBuilderWithFile {
  private[annoy] def apply[T, D <: Distance[D]](
    dimension: Int,
    numOfTrees: Int,
    metric: Metric[D],
    injection: Injection[T, Array[Byte]],
    futurePool: FuturePool
  ): Appendable[T, AnnoyRuntimeParams, D] with Serialization = {
    val index = RawAnnoyIndexBuilder(dimension, numOfTrees, metric, futurePool)
    val writableFileStore = WritableIndexIdFileStore(injection)
    new TypedAnnoyIndexBuilderWithFile[T, D](index, writableFileStore)
  }
}

private[this] class TypedAnnoyIndexBuilderWithFile[T, D <: Distance[D]](
  indexBuilder: RawAppendable[AnnoyRuntimeParams, D] with Serialization,
  store: WritableIndexIdFileStore[T])
    extends Appendable[T, AnnoyRuntimeParams, D]
    with Serialization {
  private[this] val transformedIndex = IndexTransformer.transformAppendable(indexBuilder, store)

  override def append(entity: EntityEmbedding[T]): Future[Unit] = {
    transformedIndex.append(entity)
  }

  override def toDirectory(directory: ResourceId): Unit = {
    indexBuilder.toDirectory(directory)
    toDirectory(new IndexOutputFile(directory))
  }

  override def toDirectory(directory: AbstractFile): Unit = {
    indexBuilder.toDirectory(directory)
    toDirectory(new IndexOutputFile(directory))
  }

  private def toDirectory(directory: IndexOutputFile): Unit = {
    val indexIdFile = directory.createFile(IndexIdMappingFileName)
    store.save(indexIdFile)
  }

  override def toQueryable: Queryable[T, AnnoyRuntimeParams, D] = {
    transformedIndex.toQueryable
  }
}
Binary file not shown.
@@ -1,42 +0,0 @@
package com.twitter.ann.annoy

import com.twitter.ann.annoy.AnnoyCommon._
import com.twitter.ann.common._
import com.twitter.ann.file_store.ReadableIndexIdFileStore
import com.twitter.bijection.Injection
import com.twitter.search.common.file.AbstractFile
import com.twitter.util.FuturePool

private[annoy] object TypedAnnoyQueryIndexWithFile {
  private[annoy] def apply[T, D <: Distance[D]](
    dimension: Int,
    metric: Metric[D],
    injection: Injection[T, Array[Byte]],
    futurePool: FuturePool,
    directory: AbstractFile
  ): Queryable[T, AnnoyRuntimeParams, D] = {
    val deserializer =
      new TypedAnnoyQueryIndexWithFile(dimension, metric, futurePool, injection)
    deserializer.fromDirectory(directory)
  }
}

private[this] class TypedAnnoyQueryIndexWithFile[T, D <: Distance[D]](
  dimension: Int,
  metric: Metric[D],
  futurePool: FuturePool,
  injection: Injection[T, Array[Byte]])
    extends QueryableDeserialization[
      T,
      AnnoyRuntimeParams,
      D,
      Queryable[T, AnnoyRuntimeParams, D]
    ] {
  override def fromDirectory(directory: AbstractFile): Queryable[T, AnnoyRuntimeParams, D] = {
    val index = RawAnnoyQueryIndex(dimension, metric, futurePool, directory)

    val indexIdFile = directory.getChild(IndexIdMappingFileName)
    val readableFileStore = ReadableIndexIdFileStore(indexIdFile, injection)
    IndexTransformer.transformQueryable(index, readableFileStore)
  }
}
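The matching load-side sketch (hedged; serializedIndexDir and queryVector stand in for a caller-supplied AbstractFile directory previously written by toDirectory and a query embedding, and the public entry point is TypedAnnoyIndex rather than this private class):

    // Scala sketch: reload a previously serialized index and query it.
    val loaded = TypedAnnoyIndex.loadQueryableIndex[Long, CosineDistance](
      dimension = 128,
      metric = Cosine,
      injection = AnnInjections.LongInjection,
      futurePool = FuturePool.unboundedPool,
      directory = serializedIndexDir
    )
    val neighbours = loaded.query(queryVector, 20, AnnoyRuntimeParams(Some(5000)))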
@@ -1,12 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    platform = "java8",
    tags = ["bazel-compatible"],
    dependencies = [
        "ann/src/main/scala/com/twitter/ann/common",
        "ann/src/main/scala/com/twitter/ann/serialization",
        "ann/src/main/thrift/com/twitter/ann/serialization:serialization-scala",
        "src/java/com/twitter/search/common/file",
    ],
)
BIN ann/src/main/scala/com/twitter/ann/brute_force/BUILD.docx (new file; binary file not shown)
Binary file not shown.
@@ -1,64 +0,0 @@
package com.twitter.ann.brute_force

import com.google.common.annotations.VisibleForTesting
import com.twitter.ann.common.{Distance, EntityEmbedding, Metric, QueryableDeserialization}
import com.twitter.ann.serialization.{PersistedEmbeddingInjection, ThriftIteratorIO}
import com.twitter.ann.serialization.thriftscala.PersistedEmbedding
import com.twitter.search.common.file.{AbstractFile, LocalFile}
import com.twitter.util.FuturePool
import java.io.File

/**
 * @param factory creates a BruteForceIndex from the arguments. This is only exposed for testing.
 *                If for some reason you pass this arg in, make sure that it eagerly consumes the
 *                iterator; if it doesn't, it might close the input stream that the iterator is
 *                using.
 * @tparam T the id of the embeddings
 */
class BruteForceDeserialization[T, D <: Distance[D]] @VisibleForTesting private[brute_force] (
  metric: Metric[D],
  embeddingInjection: PersistedEmbeddingInjection[T],
  futurePool: FuturePool,
  thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding],
  factory: (Metric[D], FuturePool, Iterator[EntityEmbedding[T]]) => BruteForceIndex[T, D])
    extends QueryableDeserialization[T, BruteForceRuntimeParams.type, D, BruteForceIndex[T, D]] {
  import BruteForceIndex._

  def this(
    metric: Metric[D],
    embeddingInjection: PersistedEmbeddingInjection[T],
    futurePool: FuturePool,
    thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding]
  ) = {
    this(
      metric,
      embeddingInjection,
      futurePool,
      thriftIteratorIO,
      factory = BruteForceIndex.apply[T, D]
    )
  }

  override def fromDirectory(
    serializationDirectory: AbstractFile
  ): BruteForceIndex[T, D] = {
    val file = File.createTempFile(DataFileName, "tmp")
    file.deleteOnExit()
    val temp = new LocalFile(file)
    val dataFile = serializationDirectory.getChild(DataFileName)
    dataFile.copyTo(temp)
    val inputStream = temp.getByteSource.openBufferedStream()
    try {
      val iterator: Iterator[PersistedEmbedding] = thriftIteratorIO.fromInputStream(inputStream)

      val embeddings = iterator.map { thriftEmbedding =>
        embeddingInjection.invert(thriftEmbedding).get
      }

      factory(metric, futurePool, embeddings)
    } finally {
      inputStream.close()
      temp.delete()
    }
  }
}
Binary file not shown.
@@ -1,162 +0,0 @@
package com.twitter.ann.brute_force

import com.twitter.ann.common.Appendable
import com.twitter.ann.common.Distance
import com.twitter.ann.common.EmbeddingType._
import com.twitter.ann.common.EntityEmbedding
import com.twitter.ann.common.IndexOutputFile
import com.twitter.ann.common.Metric
import com.twitter.ann.common.NeighborWithDistance
import com.twitter.ann.common.Queryable
import com.twitter.ann.common.RuntimeParams
import com.twitter.ann.common.Serialization
import com.twitter.ann.serialization.PersistedEmbeddingInjection
import com.twitter.ann.serialization.ThriftIteratorIO
import com.twitter.ann.serialization.thriftscala.PersistedEmbedding
import com.twitter.search.common.file.AbstractFile
import com.twitter.util.Future
import com.twitter.util.FuturePool
import java.util.concurrent.ConcurrentLinkedQueue
import org.apache.beam.sdk.io.fs.ResourceId
import scala.collection.JavaConverters._
import scala.collection.mutable

object BruteForceRuntimeParams extends RuntimeParams

object BruteForceIndex {
  val DataFileName = "BruteForceFileData"

  def apply[T, D <: Distance[D]](
    metric: Metric[D],
    futurePool: FuturePool,
    initialEmbeddings: Iterator[EntityEmbedding[T]] = Iterator()
  ): BruteForceIndex[T, D] = {
    val linkedQueue = new ConcurrentLinkedQueue[EntityEmbedding[T]]
    initialEmbeddings.foreach(embedding => linkedQueue.add(embedding))
    new BruteForceIndex(metric, futurePool, linkedQueue)
  }
}

class BruteForceIndex[T, D <: Distance[D]] private (
  metric: Metric[D],
  futurePool: FuturePool,
  // visible for serialization
  private[brute_force] val linkedQueue: ConcurrentLinkedQueue[EntityEmbedding[T]])
    extends Appendable[T, BruteForceRuntimeParams.type, D]
    with Queryable[T, BruteForceRuntimeParams.type, D] {

  override def append(embedding: EntityEmbedding[T]): Future[Unit] = {
    futurePool {
      linkedQueue.add(embedding)
    }
  }

  override def toQueryable: Queryable[T, BruteForceRuntimeParams.type, D] = this

  override def query(
    embedding: EmbeddingVector,
    numOfNeighbours: Int,
    runtimeParams: BruteForceRuntimeParams.type
  ): Future[List[T]] = {
    queryWithDistance(embedding, numOfNeighbours, runtimeParams).map { neighborsWithDistance =>
      neighborsWithDistance.map(_.neighbor)
    }
  }

  override def queryWithDistance(
    embedding: EmbeddingVector,
    numOfNeighbours: Int,
    runtimeParams: BruteForceRuntimeParams.type
  ): Future[List[NeighborWithDistance[T, D]]] = {
    futurePool {
      // Use the reverse ordering so that we can call dequeue to remove the largest element.
      val ordering = Ordering.by[NeighborWithDistance[T, D], D](_.distance)
      val priorityQueue =
        new mutable.PriorityQueue[NeighborWithDistance[T, D]]()(ordering)
      linkedQueue
        .iterator()
        .asScala
        .foreach { entity =>
          val neighborWithDistance =
            NeighborWithDistance(entity.id, metric.distance(entity.embedding, embedding))
          priorityQueue.+=(neighborWithDistance)
          if (priorityQueue.size > numOfNeighbours) {
            priorityQueue.dequeue()
          }
        }
      val reverseList: List[NeighborWithDistance[T, D]] =
        priorityQueue.dequeueAll
      reverseList.reverse
    }
  }
}

object SerializableBruteForceIndex {
  def apply[T, D <: Distance[D]](
    metric: Metric[D],
    futurePool: FuturePool,
    embeddingInjection: PersistedEmbeddingInjection[T],
    thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding]
  ): SerializableBruteForceIndex[T, D] = {
    val bruteForceIndex = BruteForceIndex[T, D](metric, futurePool)

    new SerializableBruteForceIndex(bruteForceIndex, embeddingInjection, thriftIteratorIO)
  }
}

/**
 * This is a class that wraps a BruteForceIndex and provides a method for serialization.
 *
 * @param bruteForceIndex all queries and updates are sent to this index.
 * @param embeddingInjection injection that can convert embeddings to thrift embeddings.
 * @param thriftIteratorIO class that provides a way to write PersistedEmbeddings to disk
 */
class SerializableBruteForceIndex[T, D <: Distance[D]](
  bruteForceIndex: BruteForceIndex[T, D],
  embeddingInjection: PersistedEmbeddingInjection[T],
  thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding])
    extends Appendable[T, BruteForceRuntimeParams.type, D]
    with Queryable[T, BruteForceRuntimeParams.type, D]
    with Serialization {
  import BruteForceIndex._

  override def append(entity: EntityEmbedding[T]): Future[Unit] =
    bruteForceIndex.append(entity)

  override def toQueryable: Queryable[T, BruteForceRuntimeParams.type, D] = this

  override def query(
    embedding: EmbeddingVector,
    numOfNeighbours: Int,
    runtimeParams: BruteForceRuntimeParams.type
  ): Future[List[T]] =
    bruteForceIndex.query(embedding, numOfNeighbours, runtimeParams)

  override def queryWithDistance(
    embedding: EmbeddingVector,
    numOfNeighbours: Int,
    runtimeParams: BruteForceRuntimeParams.type
  ): Future[List[NeighborWithDistance[T, D]]] =
    bruteForceIndex.queryWithDistance(embedding, numOfNeighbours, runtimeParams)

  override def toDirectory(serializationDirectory: ResourceId): Unit = {
    toDirectory(new IndexOutputFile(serializationDirectory))
  }

  override def toDirectory(serializationDirectory: AbstractFile): Unit = {
    toDirectory(new IndexOutputFile(serializationDirectory))
  }

  private def toDirectory(serializationDirectory: IndexOutputFile): Unit = {
    val outputStream = serializationDirectory.createFile(DataFileName).getOutputStream()
    val thriftEmbeddings =
      bruteForceIndex.linkedQueue.iterator().asScala.map { embedding =>
        embeddingInjection(embedding)
      }
    try {
      thriftIteratorIO.toOutputStream(thriftEmbeddings, outputStream)
    } finally {
      outputStream.close()
    }
  }
}
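queryWithDistance above keeps only the k best candidates with a bounded max-heap: each new candidate is pushed, and once the heap exceeds k the current worst is evicted. The same pattern in isolation (illustrative distances only):

    // Scala sketch: bounded top-k via a max-heap of size k.
    import scala.collection.mutable

    val k = 2
    val heap = new mutable.PriorityQueue[Double]() // head is the largest distance
    Seq(0.9, 0.1, 0.5, 0.3).foreach { d =>
      heap += d
      if (heap.size > k) heap.dequeue() // evict the current worst candidate
    }
    // heap now holds the two smallest distances, returned in ascending order:
    assert(heap.dequeueAll.reverse == Seq(0.1, 0.3))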
BIN ann/src/main/scala/com/twitter/ann/common/AnnInjections.docx (new file; binary file not shown)
@@ -1,28 +0,0 @@
package com.twitter.ann.common

import com.twitter.bijection.{Bijection, Injection}

// Object providing commonly used injections that can be used directly with the ANN apis.
// Injections prefixed with `J` can be used from Java directly with the ANN apis.
object AnnInjections {
  val LongInjection: Injection[Long, Array[Byte]] = Injection.long2BigEndian

  def StringInjection: Injection[String, Array[Byte]] = Injection.utf8

  def IntInjection: Injection[Int, Array[Byte]] = Injection.int2BigEndian

  val JLongInjection: Injection[java.lang.Long, Array[Byte]] =
    Bijection.long2Boxed
      .asInstanceOf[Bijection[Long, java.lang.Long]]
      .inverse
      .andThen(LongInjection)

  val JStringInjection: Injection[java.lang.String, Array[Byte]] =
    StringInjection

  val JIntInjection: Injection[java.lang.Integer, Array[Byte]] =
    Bijection.int2Boxed
      .asInstanceOf[Bijection[Int, java.lang.Integer]]
      .inverse
      .andThen(IntInjection)
}
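Injections are invertible serializers: apply goes to bytes, and invert comes back as a Try. A tiny sketch (illustrative value):

    // Scala sketch: round-trip an id through its byte encoding.
    val bytes: Array[Byte] = AnnInjections.LongInjection(123L)
    val back: scala.util.Try[Long] = AnnInjections.LongInjection.invert(bytes) // Success(123L)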
BIN ann/src/main/scala/com/twitter/ann/common/Api.docx (new file; binary file not shown)
@ -1,150 +0,0 @@
|
|||||||
package com.twitter.ann.common
|
|
||||||
|
|
||||||
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
|
|
||||||
import com.twitter.ml.api.embedding.Embedding
|
|
||||||
import com.twitter.ml.api.embedding.EmbeddingMath
|
|
||||||
import com.twitter.ml.api.embedding.EmbeddingSerDe
|
|
||||||
import com.twitter.util.Future
|
|
||||||
|
|
||||||
object EmbeddingType {
|
|
||||||
type EmbeddingVector = Embedding[Float]
|
|
||||||
val embeddingSerDe = EmbeddingSerDe.apply[Float]
|
|
||||||
private[common] val math = EmbeddingMath.Float
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * Typed entity with an embedding associated with it.
 * @param id: Unique id for an entity.
 * @param embedding: Embedding/Vector of an entity.
 * @tparam T: Type of id.
 */
case class EntityEmbedding[T](id: T, embedding: EmbeddingVector)

// Query interface for ANN
trait Queryable[T, P <: RuntimeParams, D <: Distance[D]] {

  /**
   * ANN query for ids.
   * @param embedding: Embedding/Vector to be queried with.
   * @param numOfNeighbors: Number of neighbours to be queried for.
   * @param runtimeParams: Runtime params associated with the index to control accuracy/latency etc.
   * @return List of approximate nearest neighbour ids.
   */
  def query(
    embedding: EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Future[List[T]]

  /**
   * ANN query for ids with distance.
   * @param embedding: Embedding/Vector to be queried with.
   * @param numOfNeighbors: Number of neighbours to be queried for.
   * @param runtimeParams: Runtime params associated with the index to control accuracy/latency etc.
   * @return List of approximate nearest neighbour ids with distance from the query embedding.
   */
  def queryWithDistance(
    embedding: EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Future[List[NeighborWithDistance[T, D]]]
}

// Query interface for ANN over indexes that are grouped
trait QueryableGrouped[T, P <: RuntimeParams, D <: Distance[D]] extends Queryable[T, P, D] {

  /**
   * ANN query for ids.
   * @param embedding: Embedding/Vector to be queried with.
   * @param numOfNeighbors: Number of neighbours to be queried for.
   * @param runtimeParams: Runtime params associated with the index to control accuracy/latency etc.
   * @param key: Optional key to look up a specific ANN index and perform the query there.
   * @return List of approximate nearest neighbour ids.
   */
  def query(
    embedding: EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: P,
    key: Option[String]
  ): Future[List[T]]

  /**
   * ANN query for ids with distance.
   * @param embedding: Embedding/Vector to be queried with.
   * @param numOfNeighbors: Number of neighbours to be queried for.
   * @param runtimeParams: Runtime params associated with the index to control accuracy/latency etc.
   * @param key: Optional key to look up a specific ANN index and perform the query there.
   * @return List of approximate nearest neighbour ids with distance from the query embedding.
   */
  def queryWithDistance(
    embedding: EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: P,
    key: Option[String]
  ): Future[List[NeighborWithDistance[T, D]]]
}

/**
 * Runtime params associated with an index to control accuracy/latency etc. while querying.
 */
trait RuntimeParams {}

/**
 * ANN query result with distance.
 * @param neighbor: Id of the neighbour.
 * @param distance: Distance of the neighbour from the query, e.g. CosineDistance, L2Distance, InnerProductDistance.
 */
case class NeighborWithDistance[T, D <: Distance[D]](neighbor: T, distance: D)

/**
 * ANN query result with the seed entity for which this neighbor was provided.
 * @param seed: Seed id for which the ANN query was called.
 * @param neighbor: Id of the neighbour.
 */
case class NeighborWithSeed[T1, T2](seed: T1, neighbor: T2)

/**
 * ANN query result with distance, with the seed entity for which this neighbor was provided.
 * @param seed: Seed id for which the ANN query was called.
 * @param neighbor: Id of the neighbour.
 * @param distance: Distance of the neighbour from the query, e.g. CosineDistance, L2Distance, InnerProductDistance.
 */
case class NeighborWithDistanceWithSeed[T1, T2, D <: Distance[D]](
  seed: T1,
  neighbor: T2,
  distance: D)

trait RawAppendable[P <: RuntimeParams, D <: Distance[D]] {

  /**
   * Append an embedding to an index.
   * @param embedding: Embedding/Vector
   * @return Future of the autogenerated Long id associated with the embedding.
   */
  def append(embedding: EmbeddingVector): Future[Long]

  /**
   * Convert an Appendable to the Queryable interface to query an index.
   */
  def toQueryable: Queryable[Long, P, D]
}

// Index building interface for ANN.
trait Appendable[T, P <: RuntimeParams, D <: Distance[D]] {

  /**
   * Append an entity with its embedding to an index.
   * @param entity: Entity with its embedding
   */
  def append(entity: EntityEmbedding[T]): Future[Unit]

  /**
   * Convert an Appendable to the Queryable interface to query an index.
   */
  def toQueryable: Queryable[T, P, D]
}

// Updatable index interface for ANN.
trait Updatable[T] {
  def update(entity: EntityEmbedding[T]): Future[Unit]
}
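For orientation, a minimal brute-force implementation of the Queryable contract could look like the sketch below (BruteForceQueryable and NoOpParams are hypothetical names; it assumes the L2 metric defined in Metric.scala further down, and real indexes would wrap HNSW or faiss instead of scanning):

import com.twitter.util.Future

object NoOpParams extends RuntimeParams

// Scans all entities and returns the k closest by L2 distance.
class BruteForceQueryable[T](entities: Seq[EntityEmbedding[T]])
    extends Queryable[T, NoOpParams.type, L2Distance] {

  override def query(
    embedding: EmbeddingType.EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: NoOpParams.type
  ): Future[List[T]] =
    queryWithDistance(embedding, numOfNeighbors, runtimeParams).map(_.map(_.neighbor))

  override def queryWithDistance(
    embedding: EmbeddingType.EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: NoOpParams.type
  ): Future[List[NeighborWithDistance[T, L2Distance]]] =
    Future.value(
      entities
        .map(e => NeighborWithDistance(e.id, L2.distance(embedding, e.embedding)))
        .sortBy(_.distance.distance) // ascending: closest first
        .take(numOfNeighbors)
        .toList
    )
}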
@ -1,21 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    platform = "java8",
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/com/google/guava",
        "3rdparty/jvm/com/twitter/bijection:core",
        "3rdparty/jvm/com/twitter/storehaus:core",
        "3rdparty/jvm/org/apache/beam:beam-sdks-java-io-google-cloud-platform",
        "ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
        "finatra/inject/inject-mdc/src/main/scala",
        "mediaservices/commons/src/main/scala:futuretracker",
        "src/java/com/twitter/search/common/file",
        "src/scala/com/twitter/ml/api/embedding",
        "stitch/stitch-core",
    ],
    exports = [
        "3rdparty/jvm/com/twitter/bijection:core",
    ],
)
BIN ann/src/main/scala/com/twitter/ann/common/BUILD.docx Normal file
Binary file not shown.
BIN ann/src/main/scala/com/twitter/ann/common/EmbeddingProducer.docx Normal file
Binary file not shown.
@ -1,13 +0,0 @@
package com.twitter.ann.common

import com.twitter.stitch.Stitch

trait EmbeddingProducer[T] {

  /**
   * Produce an embedding from type T. Implementations of this could do a lookup from an id to an
   * embedding. Or they could run a deep model on features that outputs an embedding.
   * @return An embedding Stitch. See go/stitch for details on how to use the Stitch API.
   */
  def produceEmbedding(input: T): Stitch[Option[EmbeddingType.EmbeddingVector]]
}
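A lookup table is the simplest instance of this trait; a minimal sketch (MapEmbeddingProducer is a hypothetical name, and real producers might hit a feature store or run a model):

import com.twitter.stitch.Stitch

class MapEmbeddingProducer[T](table: Map[T, EmbeddingType.EmbeddingVector])
    extends EmbeddingProducer[T] {
  override def produceEmbedding(input: T): Stitch[Option[EmbeddingType.EmbeddingVector]] =
    Stitch.value(table.get(input)) // pure in-memory lookup; no batching needed
}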
BIN ann/src/main/scala/com/twitter/ann/common/IndexOutputFile.docx Normal file
Binary file not shown.
@ -1,226 +0,0 @@
package com.twitter.ann.common

import com.google.common.io.ByteStreams
import com.twitter.ann.common.thriftscala.AnnIndexMetadata
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
import com.twitter.mediaservices.commons.codec.ThriftByteBufferCodec
import com.twitter.search.common.file.AbstractFile
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.channels.Channels
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.io.fs.MoveOptions
import org.apache.beam.sdk.io.fs.ResolveOptions
import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions
import org.apache.beam.sdk.io.fs.ResourceId
import org.apache.beam.sdk.util.MimeTypes
import org.apache.hadoop.io.IOUtils
import scala.collection.JavaConverters._

/**
 * This class creates a wrapper around the GCS filesystem and the HDFS filesystem for the index
 * generation job. It implements the basic methods required by the index generation job and hides
 * the logic around handling HDFS vs GCS.
 */
class IndexOutputFile(val abstractFile: AbstractFile, val resourceId: ResourceId) {

  // Success file name
  private val SUCCESS_FILE = "_SUCCESS"
  private val INDEX_METADATA_FILE = "ANN_INDEX_METADATA"
  private val MetadataCodec = new ThriftByteBufferCodec[AnnIndexMetadata](AnnIndexMetadata)

  /**
   * Constructor for a ResourceId. This is used for the GCS filesystem.
   * @param resourceId
   */
  def this(resourceId: ResourceId) = {
    this(null, resourceId)
  }

  /**
   * Constructor for an AbstractFile. This is used for HDFS and the local filesystem.
   * @param abstractFile
   */
  def this(abstractFile: AbstractFile) = {
    this(abstractFile, null)
  }

  /**
   * Returns true if this instance wraps an AbstractFile.
   * @return
   */
  def isAbstractFile(): Boolean = {
    abstractFile != null
  }

  /**
   * Creates a _SUCCESS file in the current directory.
   */
  def createSuccessFile(): Unit = {
    if (isAbstractFile()) {
      abstractFile.createSuccessFile()
    } else {
      val successFile =
        resourceId.resolve(SUCCESS_FILE, ResolveOptions.StandardResolveOptions.RESOLVE_FILE)
      val successWriterChannel = FileSystems.create(successFile, MimeTypes.BINARY)
      successWriterChannel.close()
    }
  }

  /**
   * Returns whether the current instance represents a directory.
   * @return True if the current instance is a directory
   */
  def isDirectory(): Boolean = {
    if (isAbstractFile()) {
      abstractFile.isDirectory
    } else {
      resourceId.isDirectory
    }
  }

  /**
   * Returns the current path of the file represented by this instance.
   * @return The path string of the file/directory
   */
  def getPath(): String = {
    if (isAbstractFile()) {
      abstractFile.getPath.toString
    } else {
      if (resourceId.isDirectory) {
        resourceId.getCurrentDirectory.toString
      } else {
        resourceId.getCurrentDirectory.toString + resourceId.getFilename
      }
    }
  }

  /**
   * Creates a new file named `fileName` in the current directory.
   * @param fileName
   * @return A new file inside the current directory
   */
  def createFile(fileName: String): IndexOutputFile = {
    if (isAbstractFile()) {
      // AbstractFile treats files and directories the same way. Hence, not checking for directory
      // here.
      new IndexOutputFile(abstractFile.getChild(fileName))
    } else {
      if (!resourceId.isDirectory) {
        // If this is not a directory, throw an exception.
        throw new IllegalArgumentException(getPath() + " is not a directory.")
      }
      new IndexOutputFile(
        resourceId.resolve(fileName, ResolveOptions.StandardResolveOptions.RESOLVE_FILE))
    }
  }

  /**
   * Creates a new directory named `directoryName` in the current directory.
   * @param directoryName
   * @return A new directory inside the current directory
   */
  def createDirectory(directoryName: String): IndexOutputFile = {
    if (isAbstractFile()) {
      // AbstractFile treats files and directories the same way. Hence, not checking for directory
      // here.
      val dir = abstractFile.getChild(directoryName)
      dir.mkdirs()
      new IndexOutputFile(dir)
    } else {
      if (!resourceId.isDirectory) {
        // If this is not a directory, throw an exception.
        throw new IllegalArgumentException(getPath() + " is not a directory.")
      }
      val newResourceId =
        resourceId.resolve(directoryName, ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)

      // Create a tmp file and delete it in order to trigger directory creation.
      val tmpFile =
        newResourceId.resolve("tmp", ResolveOptions.StandardResolveOptions.RESOLVE_FILE)
      val tmpWriterChannel = FileSystems.create(tmpFile, MimeTypes.BINARY)
      tmpWriterChannel.close()
      FileSystems.delete(List(tmpFile).asJava, MoveOptions.StandardMoveOptions.IGNORE_MISSING_FILES)

      new IndexOutputFile(newResourceId)
    }
  }

  def getChild(fileName: String, isDirectory: Boolean = false): IndexOutputFile = {
    if (isAbstractFile()) {
      new IndexOutputFile(abstractFile.getChild(fileName))
    } else {
      val resolveOption = if (isDirectory) {
        StandardResolveOptions.RESOLVE_DIRECTORY
      } else {
        StandardResolveOptions.RESOLVE_FILE
      }
      new IndexOutputFile(resourceId.resolve(fileName, resolveOption))
    }
  }

  /**
   * Returns an OutputStream for the underlying file.
   * Note: Close the OutputStream after writing.
   * @return
   */
  def getOutputStream(): OutputStream = {
    if (isAbstractFile()) {
      abstractFile.getByteSink.openStream()
    } else {
      if (resourceId.isDirectory) {
        // If this is a directory, throw an exception.
        throw new IllegalArgumentException(getPath() + " is a directory.")
      }
      val writerChannel = FileSystems.create(resourceId, MimeTypes.BINARY)
      Channels.newOutputStream(writerChannel)
    }
  }

  /**
   * Returns an InputStream for the underlying file.
   * Note: Close the InputStream after reading.
   * @return
   */
  def getInputStream(): InputStream = {
    if (isAbstractFile()) {
      abstractFile.getByteSource.openStream()
    } else {
      if (resourceId.isDirectory) {
        // If this is a directory, throw an exception.
        throw new IllegalArgumentException(getPath() + " is a directory.")
      }
      val readChannel = FileSystems.open(resourceId)
      Channels.newInputStream(readChannel)
    }
  }

  /**
   * Copies content from srcIn into the current file.
   * @param srcIn
   */
  def copyFrom(srcIn: InputStream): Unit = {
    val out = getOutputStream()
    try {
      IOUtils.copyBytes(srcIn, out, 4096)
      out.close()
    } catch {
      case ex: IOException =>
        IOUtils.closeStream(out)
        throw ex
    }
  }

  def writeIndexMetadata(annIndexMetadata: AnnIndexMetadata): Unit = {
    val out = createFile(INDEX_METADATA_FILE).getOutputStream()
    val bytes = ArrayByteBufferCodec.decode(MetadataCodec.encode(annIndexMetadata))
    out.write(bytes)
    out.close()
  }

  def loadIndexMetadata(): AnnIndexMetadata = {
    val in = ByteStreams.toByteArray(getInputStream())
    MetadataCodec.decode(ArrayByteBufferCodec.encode(in))
  }
}
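A usage sketch of the write path, using only the methods defined above (`writeIndex` and `baseDir` are hypothetical names for an index-generation step and its output root):

import com.twitter.search.common.file.AbstractFile

def writeIndex(baseDir: AbstractFile, indexBytes: Array[Byte]): Unit = {
  val out = new IndexOutputFile(baseDir)
  val indexFile = out.createFile("index")
  val stream = indexFile.getOutputStream()
  stream.write(indexBytes)
  stream.close()
  out.createSuccessFile() // marks the directory complete for downstream readers
}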
BIN ann/src/main/scala/com/twitter/ann/common/IndexTransformer.docx Normal file
Binary file not shown.
@ -1,118 +0,0 @@
package com.twitter.ann.common

import com.twitter.ann.common.EmbeddingType.EmbeddingVector
import com.twitter.storehaus.{ReadableStore, Store}
import com.twitter.util.Future

// Utility to transform a raw (Long-keyed) index into a typed index using a Store.
object IndexTransformer {

  /**
   * Transform a Long-keyed queryable index into a typed queryable index.
   * @param index: Raw Queryable index
   * @param store: Readable store providing the mapping between Long and T
   * @tparam T: Type to transform to
   * @tparam P: Runtime params
   * @return Queryable index typed on T
   */
  def transformQueryable[T, P <: RuntimeParams, D <: Distance[D]](
    index: Queryable[Long, P, D],
    store: ReadableStore[Long, T]
  ): Queryable[T, P, D] = {
    new Queryable[T, P, D] {
      override def query(
        embedding: EmbeddingVector,
        numOfNeighbors: Int,
        runtimeParams: P
      ): Future[List[T]] = {
        val neighbors = index.query(embedding, numOfNeighbors, runtimeParams)
        neighbors
          .flatMap(nn => {
            val ids = nn.map(id => store.get(id).map(_.get))
            Future
              .collect(ids)
              .map(_.toList)
          })
      }

      override def queryWithDistance(
        embedding: EmbeddingVector,
        numOfNeighbors: Int,
        runtimeParams: P
      ): Future[List[NeighborWithDistance[T, D]]] = {
        val neighbors = index.queryWithDistance(embedding, numOfNeighbors, runtimeParams)
        neighbors
          .flatMap(nn => {
            val ids = nn.map(obj =>
              store.get(obj.neighbor).map(id => NeighborWithDistance(id.get, obj.distance)))
            Future
              .collect(ids)
              .map(_.toList)
          })
      }
    }
  }

  /**
   * Transform a Long-keyed appendable index into a typed appendable index.
   * @param index: Raw Appendable index
   * @param store: Writable store that records the mapping between Long and T
   * @tparam T: Type to transform to
   * @return Appendable index typed on T
   */
  def transformAppendable[T, P <: RuntimeParams, D <: Distance[D]](
    index: RawAppendable[P, D],
    store: Store[Long, T]
  ): Appendable[T, P, D] = {
    new Appendable[T, P, D]() {
      override def append(entity: EntityEmbedding[T]): Future[Unit] = {
        index
          .append(entity.embedding)
          .flatMap(id => store.put((id, Some(entity.id))))
      }

      override def toQueryable: Queryable[T, P, D] = {
        transformQueryable(index.toQueryable, store)
      }
    }
  }

  /**
   * Transform a Long-keyed appendable and queryable index into a typed appendable and queryable index.
   * @param index: Raw appendable and queryable index
   * @param store: Store that provides/records the mapping between Long and T
   * @tparam T: Type to transform to
   * @tparam Index: Type of the index, which must be both appendable and queryable
   * @return Appendable and queryable index typed on T
   */
  def transform1[
    Index <: RawAppendable[P, D] with Queryable[Long, P, D],
    T,
    P <: RuntimeParams,
    D <: Distance[D]
  ](
    index: Index,
    store: Store[Long, T]
  ): Queryable[T, P, D] with Appendable[T, P, D] = {
    val queryable = transformQueryable(index, store)
    val appendable = transformAppendable(index, store)

    new Queryable[T, P, D] with Appendable[T, P, D] {
      override def query(
        embedding: EmbeddingVector,
        numOfNeighbors: Int,
        runtimeParams: P
      ) = queryable.query(embedding, numOfNeighbors, runtimeParams)

      override def queryWithDistance(
        embedding: EmbeddingVector,
        numOfNeighbors: Int,
        runtimeParams: P
      ) = queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams)

      override def append(entity: EntityEmbedding[T]) = appendable.append(entity)

      override def toQueryable: Queryable[T, P, D] = appendable.toQueryable
    }
  }
}
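A wiring sketch under stated assumptions: `rawIndex` is a hypothetical Long-keyed index (e.g. an HNSW index), and storehaus's in-memory JMapStore is assumed to be available to hold the Long-to-id mapping:

import com.twitter.storehaus.JMapStore

def typedIndex[P <: RuntimeParams, D <: Distance[D]](
  rawIndex: RawAppendable[P, D]
): Appendable[String, P, D] =
  IndexTransformer.transformAppendable(rawIndex, new JMapStore[Long, String])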
BIN ann/src/main/scala/com/twitter/ann/common/MemoizedInEpochs.docx Normal file
Binary file not shown.
@ -1,37 +0,0 @@
package com.twitter.ann.common

import com.twitter.util.Return
import com.twitter.util.Throw
import com.twitter.util.Try
import com.twitter.util.logging.Logging

// Memoization with a twist:
// each new epoch reuses the K:V pairs still requested in that epoch and recycles everything else.
class MemoizedInEpochs[K, V](f: K => Try[V]) extends Logging {
  private var memoizedCalls: Map[K, V] = Map.empty

  def epoch(keys: Seq[K]): Seq[V] = {
    val newSet = keys.toSet
    val keysToBeComputed = newSet.diff(memoizedCalls.keySet)
    val computedKeysAndValues = keysToBeComputed.map { key =>
      info(s"Memoize ${key}")
      (key, f(key))
    }
    val keysAndValuesAfterFilteringFailures = computedKeysAndValues
      .flatMap {
        case (key, Return(value)) => Some((key, value))
        case (key, Throw(e)) =>
          warn(s"Calling f for ${key} has failed", e)
          None
      }
    val keysReusedFromLastEpoch = memoizedCalls.filterKeys(newSet.contains)
    memoizedCalls = keysReusedFromLastEpoch ++ keysAndValuesAfterFilteringFailures

    debug(s"Final memoization is ${memoizedCalls.keys.mkString(", ")}")

    keys.flatMap(memoizedCalls.get)
  }

  def currentEpochKeys: Set[K] = memoizedCalls.keySet
}
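A small usage example of the epoch semantics (the key-length function stands in for an expensive loader):

import com.twitter.util.Try

val loader = new MemoizedInEpochs[String, Int]((key: String) => Try(key.length))

loader.epoch(Seq("a", "bb"))   // computes "a" and "bb"
loader.epoch(Seq("bb", "ccc")) // reuses "bb", computes "ccc", recycles "a"
assert(loader.currentEpochKeys == Set("bb", "ccc"))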
BIN ann/src/main/scala/com/twitter/ann/common/Metric.docx Normal file
Binary file not shown.
@ -1,290 +0,0 @@
package com.twitter.ann.common

import com.google.common.collect.ImmutableBiMap
import com.twitter.ann.common.EmbeddingType._
import com.twitter.ann.common.thriftscala.DistanceMetric
import com.twitter.ann.common.thriftscala.{CosineDistance => ServiceCosineDistance}
import com.twitter.ann.common.thriftscala.{Distance => ServiceDistance}
import com.twitter.ann.common.thriftscala.{InnerProductDistance => ServiceInnerProductDistance}
import com.twitter.ann.common.thriftscala.{EditDistance => ServiceEditDistance}
import com.twitter.ann.common.thriftscala.{L2Distance => ServiceL2Distance}
import com.twitter.bijection.Injection
import scala.util.Failure
import scala.util.Success
import scala.util.Try

// ANN distance metrics
trait Distance[D] extends Any with Ordered[D] {
  def distance: Float
}

case class L2Distance(distance: Float) extends AnyVal with Distance[L2Distance] {
  override def compare(that: L2Distance): Int =
    Ordering.Float.compare(this.distance, that.distance)
}

case class CosineDistance(distance: Float) extends AnyVal with Distance[CosineDistance] {
  override def compare(that: CosineDistance): Int =
    Ordering.Float.compare(this.distance, that.distance)
}

case class InnerProductDistance(distance: Float)
    extends AnyVal
    with Distance[InnerProductDistance] {
  override def compare(that: InnerProductDistance): Int =
    Ordering.Float.compare(this.distance, that.distance)
}

case class EditDistance(distance: Float) extends AnyVal with Distance[EditDistance] {
  override def compare(that: EditDistance): Int =
    Ordering.Float.compare(this.distance, that.distance)
}

object Metric {
  private[this] val thriftMetricMapping = ImmutableBiMap.of(
    L2,
    DistanceMetric.L2,
    Cosine,
    DistanceMetric.Cosine,
    InnerProduct,
    DistanceMetric.InnerProduct,
    Edit,
    DistanceMetric.EditDistance
  )

  def fromThrift(metric: DistanceMetric): Metric[_ <: Distance[_]] = {
    thriftMetricMapping.inverse().get(metric)
  }

  def toThrift(metric: Metric[_ <: Distance[_]]): DistanceMetric = {
    thriftMetricMapping.get(metric)
  }

  def fromString(metricName: String): Metric[_ <: Distance[_]]
    with Injection[_, ServiceDistance] = {
    metricName match {
      case "Cosine" => Cosine
      case "L2" => L2
      case "InnerProduct" => InnerProduct
      case "EditDistance" => Edit
      case _ =>
        throw new IllegalArgumentException(s"No Metric with the name $metricName")
    }
  }
}

sealed trait Metric[D <: Distance[D]] {
  def distance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): D
  def absoluteDistance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): Float
  def fromAbsoluteDistance(distance: Float): D
}

case object L2 extends Metric[L2Distance] with Injection[L2Distance, ServiceDistance] {
  override def distance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): L2Distance = {
    fromAbsoluteDistance(MetricUtil.l2distance(embedding1, embedding2).toFloat)
  }

  override def fromAbsoluteDistance(distance: Float): L2Distance = {
    L2Distance(distance)
  }

  override def absoluteDistance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): Float = distance(embedding1, embedding2).distance

  override def apply(scalaDistance: L2Distance): ServiceDistance = {
    ServiceDistance.L2Distance(ServiceL2Distance(scalaDistance.distance))
  }

  override def invert(serviceDistance: ServiceDistance): Try[L2Distance] = {
    serviceDistance match {
      case ServiceDistance.L2Distance(l2Distance) =>
        Success(L2Distance(l2Distance.distance.toFloat))
      case distance =>
        Failure(new IllegalArgumentException(s"Expected an L2 distance but got $distance"))
    }
  }
}

case object Cosine extends Metric[CosineDistance] with Injection[CosineDistance, ServiceDistance] {
  override def distance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): CosineDistance = {
    fromAbsoluteDistance(1 - MetricUtil.cosineSimilarity(embedding1, embedding2))
  }

  override def fromAbsoluteDistance(distance: Float): CosineDistance = {
    CosineDistance(distance)
  }

  override def absoluteDistance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): Float = distance(embedding1, embedding2).distance

  override def apply(scalaDistance: CosineDistance): ServiceDistance = {
    ServiceDistance.CosineDistance(ServiceCosineDistance(scalaDistance.distance))
  }

  override def invert(serviceDistance: ServiceDistance): Try[CosineDistance] = {
    serviceDistance match {
      case ServiceDistance.CosineDistance(cosineDistance) =>
        Success(CosineDistance(cosineDistance.distance.toFloat))
      case distance =>
        Failure(new IllegalArgumentException(s"Expected a cosine distance but got $distance"))
    }
  }
}

case object InnerProduct
    extends Metric[InnerProductDistance]
    with Injection[InnerProductDistance, ServiceDistance] {
  override def distance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): InnerProductDistance = {
    fromAbsoluteDistance(1 - MetricUtil.dot(embedding1, embedding2))
  }

  override def fromAbsoluteDistance(distance: Float): InnerProductDistance = {
    InnerProductDistance(distance)
  }

  override def absoluteDistance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): Float = distance(embedding1, embedding2).distance

  override def apply(scalaDistance: InnerProductDistance): ServiceDistance = {
    ServiceDistance.InnerProductDistance(ServiceInnerProductDistance(scalaDistance.distance))
  }

  override def invert(
    serviceDistance: ServiceDistance
  ): Try[InnerProductDistance] = {
    serviceDistance match {
      case ServiceDistance.InnerProductDistance(innerProductDistance) =>
        Success(InnerProductDistance(innerProductDistance.distance.toFloat))
      case distance =>
        Failure(
          new IllegalArgumentException(s"Expected an inner product distance but got $distance")
        )
    }
  }
}

case object Edit extends Metric[EditDistance] with Injection[EditDistance, ServiceDistance] {

  private def intDistance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector,
    pos1: Int,
    pos2: Int,
    precomputedDistances: scala.collection.mutable.Map[(Int, Int), Int]
  ): Int = {
    // If either prefix is empty, the distance is the length of the other prefix.
    if (pos1 == 0) return pos2
    if (pos2 == 0) return pos1

    // Check whether the recursive subproblem for the given (pos1, pos2)
    // has already been computed.
    precomputedDistances.getOrElse(
      (pos1, pos2), {
        // We might want to change this so that capitals are considered the same.
        // Also maybe some characters that look similar should also be the same.
        val computed = if (embedding1(pos1 - 1) == embedding2(pos2 - 1)) {
          intDistance(embedding1, embedding2, pos1 - 1, pos2 - 1, precomputedDistances)
        } else { // If the characters are not equal, we need to
          // find the minimum cost out of all 3 operations.
          val insert = intDistance(embedding1, embedding2, pos1, pos2 - 1, precomputedDistances)
          val del = intDistance(embedding1, embedding2, pos1 - 1, pos2, precomputedDistances)
          val replace =
            intDistance(embedding1, embedding2, pos1 - 1, pos2 - 1, precomputedDistances)
          1 + Math.min(insert, Math.min(del, replace))
        }
        precomputedDistances.put((pos1, pos2), computed)
        computed
      }
    )
  }

  override def distance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): EditDistance = {
    val editDistance = intDistance(
      embedding1,
      embedding2,
      embedding1.length,
      embedding2.length,
      scala.collection.mutable.Map[(Int, Int), Int]()
    )
    EditDistance(editDistance)
  }

  override def fromAbsoluteDistance(distance: Float): EditDistance = {
    EditDistance(distance.toInt)
  }

  override def absoluteDistance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): Float = distance(embedding1, embedding2).distance

  override def apply(scalaDistance: EditDistance): ServiceDistance = {
    ServiceDistance.EditDistance(ServiceEditDistance(scalaDistance.distance.toInt))
  }

  override def invert(
    serviceDistance: ServiceDistance
  ): Try[EditDistance] = {
    serviceDistance match {
      case ServiceDistance.EditDistance(editDistance) =>
        Success(EditDistance(editDistance.distance.toFloat))
      case distance =>
        Failure(
          new IllegalArgumentException(s"Expected an edit distance but got $distance")
        )
    }
  }
}

object MetricUtil {
  private[ann] def dot(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): Float = {
    math.dotProduct(embedding1, embedding2)
  }

  private[ann] def l2distance(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): Double = {
    math.l2Distance(embedding1, embedding2)
  }

  private[ann] def cosineSimilarity(
    embedding1: EmbeddingVector,
    embedding2: EmbeddingVector
  ): Float = {
    math.cosineSimilarity(embedding1, embedding2).toFloat
  }

  private[ann] def norm(
    embedding: EmbeddingVector
  ): EmbeddingVector = {
    math.normalize(embedding)
  }
}
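A standalone worked example of the similarity-to-distance conventions above (plain float arrays rather than the internal EmbeddingVector type): both Cosine and InnerProduct map a similarity s to the distance 1 - s, so sorting any of these distances in ascending order puts the most similar neighbors first.

val a = Array(1.0f, 0.0f)
val b = Array(0.6f, 0.8f) // both unit-norm

val dot = a.zip(b).map { case (x, y) => x * y }.sum        // 0.6
val cosineDistance = 1 - dot                               // 0.4 (unit vectors)
val innerProductDistance = 1 - dot                         // 0.4
val l2 = scala.math.sqrt(
  a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum.toDouble
)                                                          // ~0.894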
BIN ann/src/main/scala/com/twitter/ann/common/QueryableById.docx Normal file
Binary file not shown.
@ -1,41 +0,0 @@
package com.twitter.ann.common

import com.twitter.stitch.Stitch

/**
 * This is a trait that allows you to query for nearest neighbors given an arbitrary type T1. This is
 * in contrast to a regular com.twitter.ann.common.Queryable, which takes an embedding as the input
 * argument.
 *
 * This interface uses the Stitch API for batching. See go/stitch for details on how to use it.
 *
 * @tparam T1 type of the query.
 * @tparam T2 type of the result.
 * @tparam P runtime parameters supported by the index.
 * @tparam D distance function used in the index.
 */
trait QueryableById[T1, T2, P <: RuntimeParams, D <: Distance[D]] {
  def queryById(
    id: T1,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Stitch[List[T2]]

  def queryByIdWithDistance(
    id: T1,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Stitch[List[NeighborWithDistance[T2, D]]]

  def batchQueryById(
    ids: Seq[T1],
    numOfNeighbors: Int,
    runtimeParams: P
  ): Stitch[List[NeighborWithSeed[T1, T2]]]

  def batchQueryWithDistanceById(
    ids: Seq[T1],
    numOfNeighbors: Int,
    runtimeParams: P
  ): Stitch[List[NeighborWithDistanceWithSeed[T1, T2, D]]]
}
Binary file not shown.
@ -1,91 +0,0 @@
package com.twitter.ann.common

import com.twitter.stitch.Stitch

/**
 * Implementation of QueryableById that composes an EmbeddingProducer and a Queryable so that we
 * can get nearest neighbors given an id of type T1.
 * @param embeddingProducer provides an embedding given an id.
 * @param queryable provides a list of neighbors given an embedding.
 * @tparam T1 type of the query.
 * @tparam T2 type of the result.
 * @tparam P runtime parameters supported by the index.
 * @tparam D distance function used in the index.
 */
class QueryableByIdImplementation[T1, T2, P <: RuntimeParams, D <: Distance[D]](
  embeddingProducer: EmbeddingProducer[T1],
  queryable: Queryable[T2, P, D])
    extends QueryableById[T1, T2, P, D] {
  override def queryById(
    id: T1,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Stitch[List[T2]] = {
    embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
      embeddingOption
        .map { embedding =>
          Stitch.callFuture(queryable.query(embedding, numOfNeighbors, runtimeParams))
        }.getOrElse {
          Stitch.value(List.empty)
        }
    }
  }

  override def queryByIdWithDistance(
    id: T1,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Stitch[List[NeighborWithDistance[T2, D]]] = {
    embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
      embeddingOption
        .map { embedding =>
          Stitch.callFuture(queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
        }.getOrElse {
          Stitch.value(List.empty)
        }
    }
  }

  override def batchQueryById(
    ids: Seq[T1],
    numOfNeighbors: Int,
    runtimeParams: P
  ): Stitch[List[NeighborWithSeed[T1, T2]]] = {
    Stitch
      .traverse(ids) { id =>
        embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
          embeddingOption
            .map { embedding =>
              Stitch
                .callFuture(queryable.query(embedding, numOfNeighbors, runtimeParams)).map(
                  _.map(neighbor => NeighborWithSeed(id, neighbor)))
            }.getOrElse {
              Stitch.value(List.empty)
            }.handle { case _ => List.empty }
        }
      }.map { _.toList.flatten }
  }

  override def batchQueryWithDistanceById(
    ids: Seq[T1],
    numOfNeighbors: Int,
    runtimeParams: P
  ): Stitch[List[NeighborWithDistanceWithSeed[T1, T2, D]]] = {
    Stitch
      .traverse(ids) { id =>
        embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
          embeddingOption
            .map { embedding =>
              Stitch
                .callFuture(queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
                .map(_.map(neighbor =>
                  NeighborWithDistanceWithSeed(id, neighbor.neighbor, neighbor.distance)))
            }.getOrElse {
              Stitch.value(List.empty)
            }.handle { case _ => List.empty }
        }
      }.map {
        _.toList.flatten
      }
  }
}
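A wiring sketch for the composition (`producer` and `index` are hypothetical instances of the two traits):

import com.twitter.stitch.Stitch

def byIdQueries[T1, T2, P <: RuntimeParams, D <: Distance[D]](
  producer: EmbeddingProducer[T1],
  index: Queryable[T2, P, D]
): QueryableById[T1, T2, P, D] =
  new QueryableByIdImplementation(producer, index)

// At a call site the resulting Stitch can be turned into a Future:
// Stitch.run(byIdQueries(producer, index).queryById(id, 10, params))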
Binary file not shown.
@ -1,26 +0,0 @@
package com.twitter.ann.common

import com.twitter.ann.common.EmbeddingType.EmbeddingVector
import com.twitter.util.Future

object QueryableOperations {
  implicit class Map[T, P <: RuntimeParams, D <: Distance[D]](
    val q: Queryable[T, P, D]) {
    def mapRuntimeParameters(f: P => P): Queryable[T, P, D] = {
      new Queryable[T, P, D] {
        def query(
          embedding: EmbeddingVector,
          numOfNeighbors: Int,
          runtimeParams: P
        ): Future[List[T]] = q.query(embedding, numOfNeighbors, f(runtimeParams))

        def queryWithDistance(
          embedding: EmbeddingVector,
          numOfNeighbors: Int,
          runtimeParams: P
        ): Future[List[NeighborWithDistance[T, D]]] =
          q.queryWithDistance(embedding, numOfNeighbors, f(runtimeParams))
      }
    }
  }
}
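One way this combinator might be used: rewriting every query's runtime params before they reach the index. HnswParams below is a hypothetical RuntimeParams case class with an `ef` search-width field, introduced only for illustration:

import com.twitter.ann.common.QueryableOperations._

case class HnswParams(ef: Int) extends RuntimeParams

def withMinEf[T, D <: Distance[D]](
  index: Queryable[T, HnswParams, D]
): Queryable[T, HnswParams, D] =
  index.mapRuntimeParameters(p => p.copy(ef = p.ef.max(50))) // enforce a search-width floor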
Binary file not shown.
@ -1,29 +0,0 @@
package com.twitter.ann.common

import com.google.common.annotations.VisibleForTesting
import com.twitter.util.{Future, FuturePool}

trait ReadWriteFuturePool {
  def read[T](f: => T): Future[T]
  def write[T](f: => T): Future[T]
}

object ReadWriteFuturePool {
  def apply(readPool: FuturePool, writePool: FuturePool): ReadWriteFuturePool = {
    new ReadWriteFuturePoolANN(readPool, writePool)
  }

  def apply(commonPool: FuturePool): ReadWriteFuturePool = {
    new ReadWriteFuturePoolANN(commonPool, commonPool)
  }
}

@VisibleForTesting
private[ann] class ReadWriteFuturePoolANN(readPool: FuturePool, writePool: FuturePool)
    extends ReadWriteFuturePool {
  def read[T](f: => T): Future[T] = {
    readPool.apply(f)
  }
  def write[T](f: => T): Future[T] = {
    writePool.apply(f)
  }
}
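A minimal usage sketch with separate executors, so slow index writes cannot starve reads (the pool sizes are illustrative):

import com.twitter.util.FuturePool
import java.util.concurrent.Executors

val readPool = FuturePool(Executors.newFixedThreadPool(8))
val writePool = FuturePool(Executors.newFixedThreadPool(1))
val pools = ReadWriteFuturePool(readPool, writePool)

val readResult = pools.read(42)           // runs on the read pool
val writeResult = pools.write(println(".")) // runs on the write pool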
BIN ann/src/main/scala/com/twitter/ann/common/Serialization.docx Normal file
Binary file not shown.
@ -1,28 +0,0 @@
package com.twitter.ann.common

import com.twitter.search.common.file.AbstractFile
import org.apache.beam.sdk.io.fs.ResourceId

/**
 * Interface for writing an Appendable to a directory.
 */
trait Serialization {
  def toDirectory(
    serializationDirectory: AbstractFile
  ): Unit

  def toDirectory(
    serializationDirectory: ResourceId
  ): Unit
}

/**
 * Interface for reading a Queryable from a directory.
 * @tparam T type of the id of the embeddings.
 * @tparam Q type of the Queryable that is deserialized.
 */
trait QueryableDeserialization[T, P <: RuntimeParams, D <: Distance[D], Q <: Queryable[T, P, D]] {
  def fromDirectory(
    serializationDirectory: AbstractFile
  ): Q
}
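A round-trip sketch of how the two interfaces are meant to pair up (`index`, `deserializer`, and `dir` are hypothetical; a concrete index type would mix Serialization into its Appendable):

import com.twitter.search.common.file.AbstractFile

def roundTrip[T, P <: RuntimeParams, D <: Distance[D], Q <: Queryable[T, P, D]](
  index: Appendable[T, P, D] with Serialization,
  deserializer: QueryableDeserialization[T, P, D, Q],
  dir: AbstractFile
): Q = {
  index.toDirectory(dir)        // persist the built index
  deserializer.fromDirectory(dir) // load it back as a Queryable
}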
Binary file not shown.
@ -1,64 +0,0 @@
package com.twitter.ann.common

import com.twitter.ann.common.EmbeddingType._
import com.twitter.ann.common.thriftscala.{
  NearestNeighborQuery,
  NearestNeighborResult,
  Distance => ServiceDistance,
  RuntimeParams => ServiceRuntimeParams
}
import com.twitter.bijection.Injection
import com.twitter.finagle.Service
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
import com.twitter.util.Future

class ServiceClientQueryable[T, P <: RuntimeParams, D <: Distance[D]](
  service: Service[NearestNeighborQuery, NearestNeighborResult],
  runtimeParamInjection: Injection[P, ServiceRuntimeParams],
  distanceInjection: Injection[D, ServiceDistance],
  idInjection: Injection[T, Array[Byte]])
    extends Queryable[T, P, D] {
  override def query(
    embedding: EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Future[List[T]] = {
    service
      .apply(
        NearestNeighborQuery(
          embeddingSerDe.toThrift(embedding),
          withDistance = false,
          runtimeParamInjection(runtimeParams),
          numOfNeighbors
        )
      )
      .map { result =>
        result.nearestNeighbors.map { nearestNeighbor =>
          idInjection.invert(ArrayByteBufferCodec.decode(nearestNeighbor.id)).get
        }.toList
      }
  }

  override def queryWithDistance(
    embedding: EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Future[List[NeighborWithDistance[T, D]]] =
    service
      .apply(
        NearestNeighborQuery(
          embeddingSerDe.toThrift(embedding),
          withDistance = true,
          runtimeParamInjection(runtimeParams),
          numOfNeighbors
        )
      )
      .map { result =>
        result.nearestNeighbors.map { nearestNeighbor =>
          NeighborWithDistance(
            idInjection.invert(ArrayByteBufferCodec.decode(nearestNeighbor.id)).get,
            distanceInjection.invert(nearestNeighbor.distance.get).get
          )
        }.toList
      }
}
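For the `idInjection` parameter, bijection-core ships ready-made byte-array injections that should fit here; for Long ids, one option (assuming the standard bijection API) is:

import com.twitter.bijection.Injection

// Big-endian byte encoding of Long ids, usable as `idInjection` above.
val longIdInjection: Injection[Long, Array[Byte]] = Injection.long2BigEndian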
BIN ann/src/main/scala/com/twitter/ann/common/ShardApi.docx Normal file
Binary file not shown.
@ -1,87 +0,0 @@
package com.twitter.ann.common

import com.twitter.ann.common.EmbeddingType.EmbeddingVector
import com.twitter.util.Future
import scala.util.Random

trait ShardFunction[T] {

  /**
   * Shard function to shard an embedding based on the total shard count and the embedding data.
   * @param shards
   * @param entity
   * @return Shard index, from 0 (inclusive) to shards (exclusive)
   */
  def apply(shards: Int, entity: EntityEmbedding[T]): Int
}

/**
 * Randomly shards the embeddings based on the number of total shards.
 */
class RandomShardFunction[T] extends ShardFunction[T] {
  def apply(shards: Int, entity: EntityEmbedding[T]): Int = {
    Random.nextInt(shards)
  }
}

/**
 * Sharded appendable that shards the embeddings into different appendable indices.
 * @param indices: Sequence of appendable indices
 * @param shardFn: Shard function to shard data into different indices
 * @param shards: Total shards
 * @tparam T: Type of id.
 */
class ShardedAppendable[T, P <: RuntimeParams, D <: Distance[D]](
  indices: Seq[Appendable[T, P, D]],
  shardFn: ShardFunction[T],
  shards: Int)
    extends Appendable[T, P, D] {
  override def append(entity: EntityEmbedding[T]): Future[Unit] = {
    val shard = shardFn(shards, entity)
    val index = indices(shard)
    index.append(entity)
  }

  override def toQueryable: Queryable[T, P, D] = {
    new ComposedQueryable[T, P, D](indices.map(_.toQueryable))
  }
}

/**
 * Composition of a sequence of queryable indices: it queries all the indices
 * and merges the results in memory to return the k nearest neighbours.
 * @param indices: Sequence of queryable indices
 * @tparam T: Type of id
 * @tparam P: Type of runtime param
 * @tparam D: Type of distance metric
 */
class ComposedQueryable[T, P <: RuntimeParams, D <: Distance[D]](
  indices: Seq[Queryable[T, P, D]])
    extends Queryable[T, P, D] {
  private[this] val ordering =
    Ordering.by[NeighborWithDistance[T, D], D](_.distance)
  override def query(
    embedding: EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Future[List[T]] = {
    val neighbours = queryWithDistance(embedding, numOfNeighbors, runtimeParams)
    neighbours.map(list => list.map(nn => nn.neighbor))
  }

  override def queryWithDistance(
    embedding: EmbeddingVector,
    numOfNeighbors: Int,
    runtimeParams: P
  ): Future[List[NeighborWithDistance[T, D]]] = {
    val futures = Future.collect(
      indices.map(index => index.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
    )
    futures.map { list =>
      list.flatten
        .sorted(ordering)
        .take(numOfNeighbors)
        .toList
    }
  }
}
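A scatter-gather usage sketch: fan appends out over four shards, then query the composed view, which asks every shard for k neighbors and keeps the global top-k. `newShardIndex` is a hypothetical factory for a single-shard Appendable; the shard count is illustrative.

def sharded[T, P <: RuntimeParams, D <: Distance[D]](
  newShardIndex: () => Appendable[T, P, D]
): Appendable[T, P, D] = {
  val shards = 4
  new ShardedAppendable(
    Seq.fill(shards)(newShardIndex()),
    new RandomShardFunction[T],
    shards
  )
  // sharded(...).toQueryable merges per-shard results via ComposedQueryable
}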
Some files were not shown because too many files have changed in this diff.