mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-11-14 07:35:10 +01:00
[docx] split commit for file 400
Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>
This commit is contained in:
parent
6c4587804f
commit
3c586de8ec
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/doubleArray.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/doubleArray.docx
Normal file
Binary file not shown.
@ -1,61 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class doubleArray {
|
||||
private transient long swigCPtr;
|
||||
protected transient boolean swigCMemOwn;
|
||||
|
||||
protected doubleArray(long cPtr, boolean cMemoryOwn) {
|
||||
swigCMemOwn = cMemoryOwn;
|
||||
swigCPtr = cPtr;
|
||||
}
|
||||
|
||||
protected static long getCPtr(doubleArray obj) {
|
||||
return (obj == null) ? 0 : obj.swigCPtr;
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if (swigCPtr != 0) {
|
||||
if (swigCMemOwn) {
|
||||
swigCMemOwn = false;
|
||||
swigfaissJNI.delete_doubleArray(swigCPtr);
|
||||
}
|
||||
swigCPtr = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public doubleArray(int nelements) {
|
||||
this(swigfaissJNI.new_doubleArray(nelements), true);
|
||||
}
|
||||
|
||||
public double getitem(int index) {
|
||||
return swigfaissJNI.doubleArray_getitem(swigCPtr, this, index);
|
||||
}
|
||||
|
||||
public void setitem(int index, double value) {
|
||||
swigfaissJNI.doubleArray_setitem(swigCPtr, this, index, value);
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_double cast() {
|
||||
long cPtr = swigfaissJNI.doubleArray_cast(swigCPtr, this);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_double(cPtr, false);
|
||||
}
|
||||
|
||||
public static doubleArray frompointer(SWIGTYPE_p_double t) {
|
||||
long cPtr = swigfaissJNI.doubleArray_frompointer(SWIGTYPE_p_double.getCPtr(t));
|
||||
return (cPtr == 0) ? null : new doubleArray(cPtr, false);
|
||||
}
|
||||
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/floatArray.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/floatArray.docx
Normal file
Binary file not shown.
@ -1,61 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class floatArray {
|
||||
private transient long swigCPtr;
|
||||
protected transient boolean swigCMemOwn;
|
||||
|
||||
protected floatArray(long cPtr, boolean cMemoryOwn) {
|
||||
swigCMemOwn = cMemoryOwn;
|
||||
swigCPtr = cPtr;
|
||||
}
|
||||
|
||||
protected static long getCPtr(floatArray obj) {
|
||||
return (obj == null) ? 0 : obj.swigCPtr;
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if (swigCPtr != 0) {
|
||||
if (swigCMemOwn) {
|
||||
swigCMemOwn = false;
|
||||
swigfaissJNI.delete_floatArray(swigCPtr);
|
||||
}
|
||||
swigCPtr = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public floatArray(int nelements) {
|
||||
this(swigfaissJNI.new_floatArray(nelements), true);
|
||||
}
|
||||
|
||||
public float getitem(int index) {
|
||||
return swigfaissJNI.floatArray_getitem(swigCPtr, this, index);
|
||||
}
|
||||
|
||||
public void setitem(int index, float value) {
|
||||
swigfaissJNI.floatArray_setitem(swigCPtr, this, index, value);
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_float cast() {
|
||||
long cPtr = swigfaissJNI.floatArray_cast(swigCPtr, this);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public static floatArray frompointer(SWIGTYPE_p_float t) {
|
||||
long cPtr = swigfaissJNI.floatArray_frompointer(SWIGTYPE_p_float.getCPtr(t));
|
||||
return (cPtr == 0) ? null : new floatArray(cPtr, false);
|
||||
}
|
||||
|
||||
}
|
Binary file not shown.
@ -1,133 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class float_maxheap_array_t {
|
||||
private transient long swigCPtr;
|
||||
protected transient boolean swigCMemOwn;
|
||||
|
||||
protected float_maxheap_array_t(long cPtr, boolean cMemoryOwn) {
|
||||
swigCMemOwn = cMemoryOwn;
|
||||
swigCPtr = cPtr;
|
||||
}
|
||||
|
||||
protected static long getCPtr(float_maxheap_array_t obj) {
|
||||
return (obj == null) ? 0 : obj.swigCPtr;
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if (swigCPtr != 0) {
|
||||
if (swigCMemOwn) {
|
||||
swigCMemOwn = false;
|
||||
swigfaissJNI.delete_float_maxheap_array_t(swigCPtr);
|
||||
}
|
||||
swigCPtr = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public void setNh(long value) {
|
||||
swigfaissJNI.float_maxheap_array_t_nh_set(swigCPtr, this, value);
|
||||
}
|
||||
|
||||
public long getNh() {
|
||||
return swigfaissJNI.float_maxheap_array_t_nh_get(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void setK(long value) {
|
||||
swigfaissJNI.float_maxheap_array_t_k_set(swigCPtr, this, value);
|
||||
}
|
||||
|
||||
public long getK() {
|
||||
return swigfaissJNI.float_maxheap_array_t_k_get(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void setIds(LongVector value) {
|
||||
swigfaissJNI.float_maxheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
|
||||
}
|
||||
|
||||
public LongVector getIds() {
|
||||
return new LongVector(swigfaissJNI.float_maxheap_array_t_ids_get(swigCPtr, this), false);
|
||||
}
|
||||
|
||||
public void setVal(SWIGTYPE_p_float value) {
|
||||
swigfaissJNI.float_maxheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_float.getCPtr(value));
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_float getVal() {
|
||||
long cPtr = swigfaissJNI.float_maxheap_array_t_val_get(swigCPtr, this);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_float get_val(long key) {
|
||||
long cPtr = swigfaissJNI.float_maxheap_array_t_get_val(swigCPtr, this, key);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public LongVector get_ids(long key) {
|
||||
return new LongVector(swigfaissJNI.float_maxheap_array_t_get_ids(swigCPtr, this, key), false);
|
||||
}
|
||||
|
||||
public void heapify() {
|
||||
swigfaissJNI.float_maxheap_array_t_heapify(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0, long ni) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0, ni);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_float vin, long j0) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_float vin) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0, long ni) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin) {
|
||||
swigfaissJNI.float_maxheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
|
||||
}
|
||||
|
||||
public void reorder() {
|
||||
swigfaissJNI.float_maxheap_array_t_reorder(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void per_line_extrema(SWIGTYPE_p_float vals_out, LongVector idx_out) {
|
||||
swigfaissJNI.float_maxheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_float.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
|
||||
}
|
||||
|
||||
public float_maxheap_array_t() {
|
||||
this(swigfaissJNI.new_float_maxheap_array_t(), true);
|
||||
}
|
||||
|
||||
}
|
Binary file not shown.
@ -1,133 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class float_minheap_array_t {
|
||||
private transient long swigCPtr;
|
||||
protected transient boolean swigCMemOwn;
|
||||
|
||||
protected float_minheap_array_t(long cPtr, boolean cMemoryOwn) {
|
||||
swigCMemOwn = cMemoryOwn;
|
||||
swigCPtr = cPtr;
|
||||
}
|
||||
|
||||
protected static long getCPtr(float_minheap_array_t obj) {
|
||||
return (obj == null) ? 0 : obj.swigCPtr;
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if (swigCPtr != 0) {
|
||||
if (swigCMemOwn) {
|
||||
swigCMemOwn = false;
|
||||
swigfaissJNI.delete_float_minheap_array_t(swigCPtr);
|
||||
}
|
||||
swigCPtr = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public void setNh(long value) {
|
||||
swigfaissJNI.float_minheap_array_t_nh_set(swigCPtr, this, value);
|
||||
}
|
||||
|
||||
public long getNh() {
|
||||
return swigfaissJNI.float_minheap_array_t_nh_get(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void setK(long value) {
|
||||
swigfaissJNI.float_minheap_array_t_k_set(swigCPtr, this, value);
|
||||
}
|
||||
|
||||
public long getK() {
|
||||
return swigfaissJNI.float_minheap_array_t_k_get(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void setIds(LongVector value) {
|
||||
swigfaissJNI.float_minheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
|
||||
}
|
||||
|
||||
public LongVector getIds() {
|
||||
return new LongVector(swigfaissJNI.float_minheap_array_t_ids_get(swigCPtr, this), false);
|
||||
}
|
||||
|
||||
public void setVal(SWIGTYPE_p_float value) {
|
||||
swigfaissJNI.float_minheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_float.getCPtr(value));
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_float getVal() {
|
||||
long cPtr = swigfaissJNI.float_minheap_array_t_val_get(swigCPtr, this);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_float get_val(long key) {
|
||||
long cPtr = swigfaissJNI.float_minheap_array_t_get_val(swigCPtr, this, key);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public LongVector get_ids(long key) {
|
||||
return new LongVector(swigfaissJNI.float_minheap_array_t_get_ids(swigCPtr, this, key), false);
|
||||
}
|
||||
|
||||
public void heapify() {
|
||||
swigfaissJNI.float_minheap_array_t_heapify(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0, long ni) {
|
||||
swigfaissJNI.float_minheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0, ni);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_float vin, long j0, long i0) {
|
||||
swigfaissJNI.float_minheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0, i0);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_float vin, long j0) {
|
||||
swigfaissJNI.float_minheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), j0);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_float vin) {
|
||||
swigfaissJNI.float_minheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0, long ni) {
|
||||
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride, long i0) {
|
||||
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in, long id_stride) {
|
||||
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin, LongVector id_in) {
|
||||
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_float vin) {
|
||||
swigfaissJNI.float_minheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_float.getCPtr(vin));
|
||||
}
|
||||
|
||||
public void reorder() {
|
||||
swigfaissJNI.float_minheap_array_t_reorder(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void per_line_extrema(SWIGTYPE_p_float vals_out, LongVector idx_out) {
|
||||
swigfaissJNI.float_minheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_float.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
|
||||
}
|
||||
|
||||
public float_minheap_array_t() {
|
||||
this(swigfaissJNI.new_float_minheap_array_t(), true);
|
||||
}
|
||||
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/intArray.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/intArray.docx
Normal file
Binary file not shown.
@ -1,61 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class intArray {
|
||||
private transient long swigCPtr;
|
||||
protected transient boolean swigCMemOwn;
|
||||
|
||||
protected intArray(long cPtr, boolean cMemoryOwn) {
|
||||
swigCMemOwn = cMemoryOwn;
|
||||
swigCPtr = cPtr;
|
||||
}
|
||||
|
||||
protected static long getCPtr(intArray obj) {
|
||||
return (obj == null) ? 0 : obj.swigCPtr;
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if (swigCPtr != 0) {
|
||||
if (swigCMemOwn) {
|
||||
swigCMemOwn = false;
|
||||
swigfaissJNI.delete_intArray(swigCPtr);
|
||||
}
|
||||
swigCPtr = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public intArray(int nelements) {
|
||||
this(swigfaissJNI.new_intArray(nelements), true);
|
||||
}
|
||||
|
||||
public int getitem(int index) {
|
||||
return swigfaissJNI.intArray_getitem(swigCPtr, this, index);
|
||||
}
|
||||
|
||||
public void setitem(int index, int value) {
|
||||
swigfaissJNI.intArray_setitem(swigCPtr, this, index, value);
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_int cast() {
|
||||
long cPtr = swigfaissJNI.intArray_cast(swigCPtr, this);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
|
||||
}
|
||||
|
||||
public static intArray frompointer(SWIGTYPE_p_int t) {
|
||||
long cPtr = swigfaissJNI.intArray_frompointer(SWIGTYPE_p_int.getCPtr(t));
|
||||
return (cPtr == 0) ? null : new intArray(cPtr, false);
|
||||
}
|
||||
|
||||
}
|
Binary file not shown.
@ -1,133 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class int_maxheap_array_t {
|
||||
private transient long swigCPtr;
|
||||
protected transient boolean swigCMemOwn;
|
||||
|
||||
protected int_maxheap_array_t(long cPtr, boolean cMemoryOwn) {
|
||||
swigCMemOwn = cMemoryOwn;
|
||||
swigCPtr = cPtr;
|
||||
}
|
||||
|
||||
protected static long getCPtr(int_maxheap_array_t obj) {
|
||||
return (obj == null) ? 0 : obj.swigCPtr;
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if (swigCPtr != 0) {
|
||||
if (swigCMemOwn) {
|
||||
swigCMemOwn = false;
|
||||
swigfaissJNI.delete_int_maxheap_array_t(swigCPtr);
|
||||
}
|
||||
swigCPtr = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public void setNh(long value) {
|
||||
swigfaissJNI.int_maxheap_array_t_nh_set(swigCPtr, this, value);
|
||||
}
|
||||
|
||||
public long getNh() {
|
||||
return swigfaissJNI.int_maxheap_array_t_nh_get(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void setK(long value) {
|
||||
swigfaissJNI.int_maxheap_array_t_k_set(swigCPtr, this, value);
|
||||
}
|
||||
|
||||
public long getK() {
|
||||
return swigfaissJNI.int_maxheap_array_t_k_get(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void setIds(LongVector value) {
|
||||
swigfaissJNI.int_maxheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
|
||||
}
|
||||
|
||||
public LongVector getIds() {
|
||||
return new LongVector(swigfaissJNI.int_maxheap_array_t_ids_get(swigCPtr, this), false);
|
||||
}
|
||||
|
||||
public void setVal(SWIGTYPE_p_int value) {
|
||||
swigfaissJNI.int_maxheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_int.getCPtr(value));
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_int getVal() {
|
||||
long cPtr = swigfaissJNI.int_maxheap_array_t_val_get(swigCPtr, this);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_int get_val(long key) {
|
||||
long cPtr = swigfaissJNI.int_maxheap_array_t_get_val(swigCPtr, this, key);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
|
||||
}
|
||||
|
||||
public LongVector get_ids(long key) {
|
||||
return new LongVector(swigfaissJNI.int_maxheap_array_t_get_ids(swigCPtr, this, key), false);
|
||||
}
|
||||
|
||||
public void heapify() {
|
||||
swigfaissJNI.int_maxheap_array_t_heapify(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0, long ni) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0, ni);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_int vin, long j0) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_int vin) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0, long ni) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin) {
|
||||
swigfaissJNI.int_maxheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
|
||||
}
|
||||
|
||||
public void reorder() {
|
||||
swigfaissJNI.int_maxheap_array_t_reorder(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void per_line_extrema(SWIGTYPE_p_int vals_out, LongVector idx_out) {
|
||||
swigfaissJNI.int_maxheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_int.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
|
||||
}
|
||||
|
||||
public int_maxheap_array_t() {
|
||||
this(swigfaissJNI.new_int_maxheap_array_t(), true);
|
||||
}
|
||||
|
||||
}
|
Binary file not shown.
@ -1,133 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class int_minheap_array_t {
|
||||
private transient long swigCPtr;
|
||||
protected transient boolean swigCMemOwn;
|
||||
|
||||
protected int_minheap_array_t(long cPtr, boolean cMemoryOwn) {
|
||||
swigCMemOwn = cMemoryOwn;
|
||||
swigCPtr = cPtr;
|
||||
}
|
||||
|
||||
protected static long getCPtr(int_minheap_array_t obj) {
|
||||
return (obj == null) ? 0 : obj.swigCPtr;
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if (swigCPtr != 0) {
|
||||
if (swigCMemOwn) {
|
||||
swigCMemOwn = false;
|
||||
swigfaissJNI.delete_int_minheap_array_t(swigCPtr);
|
||||
}
|
||||
swigCPtr = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public void setNh(long value) {
|
||||
swigfaissJNI.int_minheap_array_t_nh_set(swigCPtr, this, value);
|
||||
}
|
||||
|
||||
public long getNh() {
|
||||
return swigfaissJNI.int_minheap_array_t_nh_get(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void setK(long value) {
|
||||
swigfaissJNI.int_minheap_array_t_k_set(swigCPtr, this, value);
|
||||
}
|
||||
|
||||
public long getK() {
|
||||
return swigfaissJNI.int_minheap_array_t_k_get(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void setIds(LongVector value) {
|
||||
swigfaissJNI.int_minheap_array_t_ids_set(swigCPtr, this, SWIGTYPE_p_long_long.getCPtr(value.data()), value);
|
||||
}
|
||||
|
||||
public LongVector getIds() {
|
||||
return new LongVector(swigfaissJNI.int_minheap_array_t_ids_get(swigCPtr, this), false);
|
||||
}
|
||||
|
||||
public void setVal(SWIGTYPE_p_int value) {
|
||||
swigfaissJNI.int_minheap_array_t_val_set(swigCPtr, this, SWIGTYPE_p_int.getCPtr(value));
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_int getVal() {
|
||||
long cPtr = swigfaissJNI.int_minheap_array_t_val_get(swigCPtr, this);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_int get_val(long key) {
|
||||
long cPtr = swigfaissJNI.int_minheap_array_t_get_val(swigCPtr, this, key);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
|
||||
}
|
||||
|
||||
public LongVector get_ids(long key) {
|
||||
return new LongVector(swigfaissJNI.int_minheap_array_t_get_ids(swigCPtr, this, key), false);
|
||||
}
|
||||
|
||||
public void heapify() {
|
||||
swigfaissJNI.int_minheap_array_t_heapify(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0, long ni) {
|
||||
swigfaissJNI.int_minheap_array_t_addn__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0, ni);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_int vin, long j0, long i0) {
|
||||
swigfaissJNI.int_minheap_array_t_addn__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0, i0);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_int vin, long j0) {
|
||||
swigfaissJNI.int_minheap_array_t_addn__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), j0);
|
||||
}
|
||||
|
||||
public void addn(long nj, SWIGTYPE_p_int vin) {
|
||||
swigfaissJNI.int_minheap_array_t_addn__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0, long ni) {
|
||||
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_0(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0, ni);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride, long i0) {
|
||||
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_1(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride, i0);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in, long id_stride) {
|
||||
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_2(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in, id_stride);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin, LongVector id_in) {
|
||||
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_3(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin), SWIGTYPE_p_long_long.getCPtr(id_in.data()), id_in);
|
||||
}
|
||||
|
||||
public void addn_with_ids(long nj, SWIGTYPE_p_int vin) {
|
||||
swigfaissJNI.int_minheap_array_t_addn_with_ids__SWIG_4(swigCPtr, this, nj, SWIGTYPE_p_int.getCPtr(vin));
|
||||
}
|
||||
|
||||
public void reorder() {
|
||||
swigfaissJNI.int_minheap_array_t_reorder(swigCPtr, this);
|
||||
}
|
||||
|
||||
public void per_line_extrema(SWIGTYPE_p_int vals_out, LongVector idx_out) {
|
||||
swigfaissJNI.int_minheap_array_t_per_line_extrema(swigCPtr, this, SWIGTYPE_p_int.getCPtr(vals_out), SWIGTYPE_p_long_long.getCPtr(idx_out.data()), idx_out);
|
||||
}
|
||||
|
||||
public int_minheap_array_t() {
|
||||
this(swigfaissJNI.new_int_minheap_array_t(), true);
|
||||
}
|
||||
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/longArray.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/longArray.docx
Normal file
Binary file not shown.
@ -1,61 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class longArray {
|
||||
private transient long swigCPtr;
|
||||
protected transient boolean swigCMemOwn;
|
||||
|
||||
protected longArray(long cPtr, boolean cMemoryOwn) {
|
||||
swigCMemOwn = cMemoryOwn;
|
||||
swigCPtr = cPtr;
|
||||
}
|
||||
|
||||
protected static long getCPtr(longArray obj) {
|
||||
return (obj == null) ? 0 : obj.swigCPtr;
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if (swigCPtr != 0) {
|
||||
if (swigCMemOwn) {
|
||||
swigCMemOwn = false;
|
||||
swigfaissJNI.delete_longArray(swigCPtr);
|
||||
}
|
||||
swigCPtr = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public longArray(int nelements) {
|
||||
this(swigfaissJNI.new_longArray(nelements), true);
|
||||
}
|
||||
|
||||
public long getitem(int index) {
|
||||
return swigfaissJNI.longArray_getitem(swigCPtr, this, index);
|
||||
}
|
||||
|
||||
public void setitem(int index, long value) {
|
||||
swigfaissJNI.longArray_setitem(swigCPtr, this, index, value);
|
||||
}
|
||||
|
||||
public SWIGTYPE_p_long_long cast() {
|
||||
long cPtr = swigfaissJNI.longArray_cast(swigCPtr, this);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_long_long(cPtr, false);
|
||||
}
|
||||
|
||||
public static longArray frompointer(SWIGTYPE_p_long_long t) {
|
||||
long cPtr = swigfaissJNI.longArray_frompointer(SWIGTYPE_p_long_long.getCPtr(t));
|
||||
return (cPtr == 0) ? null : new longArray(cPtr, false);
|
||||
}
|
||||
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/swigfaiss.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/swigfaiss.docx
Normal file
Binary file not shown.
@ -1,575 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public class swigfaiss implements swigfaissConstants {
|
||||
public static void bitvec_print(SWIGTYPE_p_unsigned_char b, long d) {
|
||||
swigfaissJNI.bitvec_print(SWIGTYPE_p_unsigned_char.getCPtr(b), d);
|
||||
}
|
||||
|
||||
public static void fvecs2bitvecs(SWIGTYPE_p_float x, SWIGTYPE_p_unsigned_char b, long d, long n) {
|
||||
swigfaissJNI.fvecs2bitvecs(SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_unsigned_char.getCPtr(b), d, n);
|
||||
}
|
||||
|
||||
public static void bitvecs2fvecs(SWIGTYPE_p_unsigned_char b, SWIGTYPE_p_float x, long d, long n) {
|
||||
swigfaissJNI.bitvecs2fvecs(SWIGTYPE_p_unsigned_char.getCPtr(b), SWIGTYPE_p_float.getCPtr(x), d, n);
|
||||
}
|
||||
|
||||
public static void fvec2bitvec(SWIGTYPE_p_float x, SWIGTYPE_p_unsigned_char b, long d) {
|
||||
swigfaissJNI.fvec2bitvec(SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_unsigned_char.getCPtr(b), d);
|
||||
}
|
||||
|
||||
public static void bitvec_shuffle(long n, long da, long db, SWIGTYPE_p_int order, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b) {
|
||||
swigfaissJNI.bitvec_shuffle(n, da, db, SWIGTYPE_p_int.getCPtr(order), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b));
|
||||
}
|
||||
|
||||
public static void setHamming_batch_size(long value) {
|
||||
swigfaissJNI.hamming_batch_size_set(value);
|
||||
}
|
||||
|
||||
public static long getHamming_batch_size() {
|
||||
return swigfaissJNI.hamming_batch_size_get();
|
||||
}
|
||||
|
||||
public static int popcount64(long x) {
|
||||
return swigfaissJNI.popcount64(x);
|
||||
}
|
||||
|
||||
public static void hammings(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, long nbytespercode, SWIGTYPE_p_int dis) {
|
||||
swigfaissJNI.hammings(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, nbytespercode, SWIGTYPE_p_int.getCPtr(dis));
|
||||
}
|
||||
|
||||
public static void hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long ncodes, int ordered) {
|
||||
swigfaissJNI.hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, ncodes, ordered);
|
||||
}
|
||||
|
||||
public static void hammings_knn(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long ncodes, int ordered) {
|
||||
swigfaissJNI.hammings_knn(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, ncodes, ordered);
|
||||
}
|
||||
|
||||
public static void hammings_knn_mc(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, long k, long ncodes, SWIGTYPE_p_int distances, LongVector labels) {
|
||||
swigfaissJNI.hammings_knn_mc(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, k, ncodes, SWIGTYPE_p_int.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels);
|
||||
}
|
||||
|
||||
public static void hamming_range_search(SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long na, long nb, int radius, long ncodes, RangeSearchResult result) {
|
||||
swigfaissJNI.hamming_range_search(SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), na, nb, radius, ncodes, RangeSearchResult.getCPtr(result), result);
|
||||
}
|
||||
|
||||
public static void hamming_count_thres(SWIGTYPE_p_unsigned_char bs1, SWIGTYPE_p_unsigned_char bs2, long n1, long n2, int ht, long ncodes, SWIGTYPE_p_unsigned_long nptr) {
|
||||
swigfaissJNI.hamming_count_thres(SWIGTYPE_p_unsigned_char.getCPtr(bs1), SWIGTYPE_p_unsigned_char.getCPtr(bs2), n1, n2, ht, ncodes, SWIGTYPE_p_unsigned_long.getCPtr(nptr));
|
||||
}
|
||||
|
||||
public static long match_hamming_thres(SWIGTYPE_p_unsigned_char bs1, SWIGTYPE_p_unsigned_char bs2, long n1, long n2, int ht, long ncodes, LongVector idx, SWIGTYPE_p_int dis) {
|
||||
return swigfaissJNI.match_hamming_thres(SWIGTYPE_p_unsigned_char.getCPtr(bs1), SWIGTYPE_p_unsigned_char.getCPtr(bs2), n1, n2, ht, ncodes, SWIGTYPE_p_long_long.getCPtr(idx.data()), idx, SWIGTYPE_p_int.getCPtr(dis));
|
||||
}
|
||||
|
||||
public static void crosshamming_count_thres(SWIGTYPE_p_unsigned_char dbs, long n, int ht, long ncodes, SWIGTYPE_p_unsigned_long nptr) {
|
||||
swigfaissJNI.crosshamming_count_thres(SWIGTYPE_p_unsigned_char.getCPtr(dbs), n, ht, ncodes, SWIGTYPE_p_unsigned_long.getCPtr(nptr));
|
||||
}
|
||||
|
||||
public static int get_num_gpus() {
|
||||
return swigfaissJNI.get_num_gpus();
|
||||
}
|
||||
|
||||
public static String get_compile_options() {
|
||||
return swigfaissJNI.get_compile_options();
|
||||
}
|
||||
|
||||
public static double getmillisecs() {
|
||||
return swigfaissJNI.getmillisecs();
|
||||
}
|
||||
|
||||
public static long get_mem_usage_kb() {
|
||||
return swigfaissJNI.get_mem_usage_kb();
|
||||
}
|
||||
|
||||
public static long get_cycles() {
|
||||
return swigfaissJNI.get_cycles();
|
||||
}
|
||||
|
||||
public static void fvec_madd(long n, SWIGTYPE_p_float a, float bf, SWIGTYPE_p_float b, SWIGTYPE_p_float c) {
|
||||
swigfaissJNI.fvec_madd(n, SWIGTYPE_p_float.getCPtr(a), bf, SWIGTYPE_p_float.getCPtr(b), SWIGTYPE_p_float.getCPtr(c));
|
||||
}
|
||||
|
||||
public static int fvec_madd_and_argmin(long n, SWIGTYPE_p_float a, float bf, SWIGTYPE_p_float b, SWIGTYPE_p_float c) {
|
||||
return swigfaissJNI.fvec_madd_and_argmin(n, SWIGTYPE_p_float.getCPtr(a), bf, SWIGTYPE_p_float.getCPtr(b), SWIGTYPE_p_float.getCPtr(c));
|
||||
}
|
||||
|
||||
public static void reflection(SWIGTYPE_p_float u, SWIGTYPE_p_float x, long n, long d, long nu) {
|
||||
swigfaissJNI.reflection(SWIGTYPE_p_float.getCPtr(u), SWIGTYPE_p_float.getCPtr(x), n, d, nu);
|
||||
}
|
||||
|
||||
public static void matrix_qr(int m, int n, SWIGTYPE_p_float a) {
|
||||
swigfaissJNI.matrix_qr(m, n, SWIGTYPE_p_float.getCPtr(a));
|
||||
}
|
||||
|
||||
public static void ranklist_handle_ties(int k, LongVector idx, SWIGTYPE_p_float dis) {
|
||||
swigfaissJNI.ranklist_handle_ties(k, SWIGTYPE_p_long_long.getCPtr(idx.data()), idx, SWIGTYPE_p_float.getCPtr(dis));
|
||||
}
|
||||
|
||||
public static long ranklist_intersection_size(long k1, LongVector v1, long k2, LongVector v2) {
|
||||
return swigfaissJNI.ranklist_intersection_size(k1, SWIGTYPE_p_long_long.getCPtr(v1.data()), v1, k2, SWIGTYPE_p_long_long.getCPtr(v2.data()), v2);
|
||||
}
|
||||
|
||||
public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1, boolean keep_min, long translation) {
|
||||
return swigfaissJNI.merge_result_table_with__SWIG_0(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1), keep_min, translation);
|
||||
}
|
||||
|
||||
public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1, boolean keep_min) {
|
||||
return swigfaissJNI.merge_result_table_with__SWIG_1(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1), keep_min);
|
||||
}
|
||||
|
||||
public static long merge_result_table_with(long n, long k, LongVector I0, SWIGTYPE_p_float D0, LongVector I1, SWIGTYPE_p_float D1) {
|
||||
return swigfaissJNI.merge_result_table_with__SWIG_2(n, k, SWIGTYPE_p_long_long.getCPtr(I0.data()), I0, SWIGTYPE_p_float.getCPtr(D0), SWIGTYPE_p_long_long.getCPtr(I1.data()), I1, SWIGTYPE_p_float.getCPtr(D1));
|
||||
}
|
||||
|
||||
public static double imbalance_factor(int n, int k, LongVector assign) {
|
||||
return swigfaissJNI.imbalance_factor__SWIG_0(n, k, SWIGTYPE_p_long_long.getCPtr(assign.data()), assign);
|
||||
}
|
||||
|
||||
public static double imbalance_factor(int k, SWIGTYPE_p_int hist) {
|
||||
return swigfaissJNI.imbalance_factor__SWIG_1(k, SWIGTYPE_p_int.getCPtr(hist));
|
||||
}
|
||||
|
||||
public static void fvec_argsort(long n, SWIGTYPE_p_float vals, SWIGTYPE_p_unsigned_long perm) {
|
||||
swigfaissJNI.fvec_argsort(n, SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_unsigned_long.getCPtr(perm));
|
||||
}
|
||||
|
||||
public static void fvec_argsort_parallel(long n, SWIGTYPE_p_float vals, SWIGTYPE_p_unsigned_long perm) {
|
||||
swigfaissJNI.fvec_argsort_parallel(n, SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_unsigned_long.getCPtr(perm));
|
||||
}
|
||||
|
||||
public static int ivec_hist(long n, SWIGTYPE_p_int v, int vmax, SWIGTYPE_p_int hist) {
|
||||
return swigfaissJNI.ivec_hist(n, SWIGTYPE_p_int.getCPtr(v), vmax, SWIGTYPE_p_int.getCPtr(hist));
|
||||
}
|
||||
|
||||
public static void bincode_hist(long n, long nbits, SWIGTYPE_p_unsigned_char codes, SWIGTYPE_p_int hist) {
|
||||
swigfaissJNI.bincode_hist(n, nbits, SWIGTYPE_p_unsigned_char.getCPtr(codes), SWIGTYPE_p_int.getCPtr(hist));
|
||||
}
|
||||
|
||||
public static long ivec_checksum(long n, SWIGTYPE_p_int a) {
|
||||
return swigfaissJNI.ivec_checksum(n, SWIGTYPE_p_int.getCPtr(a));
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x, boolean verbose, long seed) {
|
||||
long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_0(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x), verbose, seed);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x, boolean verbose) {
|
||||
long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_1(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x), verbose);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_float fvecs_maybe_subsample(long d, SWIGTYPE_p_unsigned_long n, long nmax, SWIGTYPE_p_float x) {
|
||||
long cPtr = swigfaissJNI.fvecs_maybe_subsample__SWIG_2(d, SWIGTYPE_p_unsigned_long.getCPtr(n), nmax, SWIGTYPE_p_float.getCPtr(x));
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public static void binary_to_real(long d, SWIGTYPE_p_unsigned_char x_in, SWIGTYPE_p_float x_out) {
|
||||
swigfaissJNI.binary_to_real(d, SWIGTYPE_p_unsigned_char.getCPtr(x_in), SWIGTYPE_p_float.getCPtr(x_out));
|
||||
}
|
||||
|
||||
public static void real_to_binary(long d, SWIGTYPE_p_float x_in, SWIGTYPE_p_unsigned_char x_out) {
|
||||
swigfaissJNI.real_to_binary(d, SWIGTYPE_p_float.getCPtr(x_in), SWIGTYPE_p_unsigned_char.getCPtr(x_out));
|
||||
}
|
||||
|
||||
public static long hash_bytes(SWIGTYPE_p_unsigned_char bytes, long n) {
|
||||
return swigfaissJNI.hash_bytes(SWIGTYPE_p_unsigned_char.getCPtr(bytes), n);
|
||||
}
|
||||
|
||||
public static boolean check_openmp() {
|
||||
return swigfaissJNI.check_openmp();
|
||||
}
|
||||
|
||||
public static float kmeans_clustering(long d, long n, long k, SWIGTYPE_p_float x, SWIGTYPE_p_float centroids) {
|
||||
return swigfaissJNI.kmeans_clustering(d, n, k, SWIGTYPE_p_float.getCPtr(x), SWIGTYPE_p_float.getCPtr(centroids));
|
||||
}
|
||||
|
||||
public static void setIndexPQ_stats(IndexPQStats value) {
|
||||
swigfaissJNI.indexPQ_stats_set(IndexPQStats.getCPtr(value), value);
|
||||
}
|
||||
|
||||
public static IndexPQStats getIndexPQ_stats() {
|
||||
long cPtr = swigfaissJNI.indexPQ_stats_get();
|
||||
return (cPtr == 0) ? null : new IndexPQStats(cPtr, false);
|
||||
}
|
||||
|
||||
public static void setIndexIVF_stats(IndexIVFStats value) {
|
||||
swigfaissJNI.indexIVF_stats_set(IndexIVFStats.getCPtr(value), value);
|
||||
}
|
||||
|
||||
public static IndexIVFStats getIndexIVF_stats() {
|
||||
long cPtr = swigfaissJNI.indexIVF_stats_get();
|
||||
return (cPtr == 0) ? null : new IndexIVFStats(cPtr, false);
|
||||
}
|
||||
|
||||
public static short[] getHamdis_tab_ham_bytes() {
|
||||
return swigfaissJNI.hamdis_tab_ham_bytes_get();
|
||||
}
|
||||
|
||||
public static int generalized_hamming_64(long a) {
|
||||
return swigfaissJNI.generalized_hamming_64(a);
|
||||
}
|
||||
|
||||
public static void generalized_hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long code_size, int ordered) {
|
||||
swigfaissJNI.generalized_hammings_knn_hc__SWIG_0(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, code_size, ordered);
|
||||
}
|
||||
|
||||
public static void generalized_hammings_knn_hc(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t ha, SWIGTYPE_p_unsigned_char a, SWIGTYPE_p_unsigned_char b, long nb, long code_size) {
|
||||
swigfaissJNI.generalized_hammings_knn_hc__SWIG_1(SWIGTYPE_p_faiss__HeapArrayT_faiss__CMaxT_int_int64_t_t_t.getCPtr(ha), SWIGTYPE_p_unsigned_char.getCPtr(a), SWIGTYPE_p_unsigned_char.getCPtr(b), nb, code_size);
|
||||
}
|
||||
|
||||
public static void check_compatible_for_merge(Index index1, Index index2) {
|
||||
swigfaissJNI.check_compatible_for_merge(Index.getCPtr(index1), index1, Index.getCPtr(index2), index2);
|
||||
}
|
||||
|
||||
public static IndexIVF extract_index_ivf(Index index) {
|
||||
long cPtr = swigfaissJNI.extract_index_ivf__SWIG_0(Index.getCPtr(index), index);
|
||||
return (cPtr == 0) ? null : new IndexIVF(cPtr, false);
|
||||
}
|
||||
|
||||
public static IndexIVF try_extract_index_ivf(Index index) {
|
||||
long cPtr = swigfaissJNI.try_extract_index_ivf__SWIG_0(Index.getCPtr(index), index);
|
||||
return (cPtr == 0) ? null : new IndexIVF(cPtr, false);
|
||||
}
|
||||
|
||||
public static void merge_into(Index index0, Index index1, boolean shift_ids) {
|
||||
swigfaissJNI.merge_into(Index.getCPtr(index0), index0, Index.getCPtr(index1), index1, shift_ids);
|
||||
}
|
||||
|
||||
public static void search_centroid(Index index, SWIGTYPE_p_float x, int n, LongVector centroid_ids) {
|
||||
swigfaissJNI.search_centroid(Index.getCPtr(index), index, SWIGTYPE_p_float.getCPtr(x), n, SWIGTYPE_p_long_long.getCPtr(centroid_ids.data()), centroid_ids);
|
||||
}
|
||||
|
||||
public static void search_and_return_centroids(Index index, long n, SWIGTYPE_p_float xin, int k, SWIGTYPE_p_float distances, LongVector labels, LongVector query_centroid_ids, LongVector result_centroid_ids) {
|
||||
swigfaissJNI.search_and_return_centroids(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(xin), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, SWIGTYPE_p_long_long.getCPtr(query_centroid_ids.data()), query_centroid_ids, SWIGTYPE_p_long_long.getCPtr(result_centroid_ids.data()), result_centroid_ids);
|
||||
}
|
||||
|
||||
public static ArrayInvertedLists get_invlist_range(Index index, int i0, int i1) {
|
||||
long cPtr = swigfaissJNI.get_invlist_range(Index.getCPtr(index), index, i0, i1);
|
||||
return (cPtr == 0) ? null : new ArrayInvertedLists(cPtr, false);
|
||||
}
|
||||
|
||||
public static void set_invlist_range(Index index, int i0, int i1, ArrayInvertedLists src) {
|
||||
swigfaissJNI.set_invlist_range(Index.getCPtr(index), index, i0, i1, ArrayInvertedLists.getCPtr(src), src);
|
||||
}
|
||||
|
||||
public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis, SWIGTYPE_p_double ms_per_stage) {
|
||||
swigfaissJNI.search_with_parameters__SWIG_0(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis), SWIGTYPE_p_double.getCPtr(ms_per_stage));
|
||||
}
|
||||
|
||||
public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis) {
|
||||
swigfaissJNI.search_with_parameters__SWIG_1(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis));
|
||||
}
|
||||
|
||||
public static void search_with_parameters(Index index, long n, SWIGTYPE_p_float x, long k, SWIGTYPE_p_float distances, LongVector labels, IVFSearchParameters params) {
|
||||
swigfaissJNI.search_with_parameters__SWIG_2(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), k, SWIGTYPE_p_float.getCPtr(distances), SWIGTYPE_p_long_long.getCPtr(labels.data()), labels, IVFSearchParameters.getCPtr(params), params);
|
||||
}
|
||||
|
||||
public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis, SWIGTYPE_p_double ms_per_stage) {
|
||||
swigfaissJNI.range_search_with_parameters__SWIG_0(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis), SWIGTYPE_p_double.getCPtr(ms_per_stage));
|
||||
}
|
||||
|
||||
public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params, SWIGTYPE_p_unsigned_long nb_dis) {
|
||||
swigfaissJNI.range_search_with_parameters__SWIG_1(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params, SWIGTYPE_p_unsigned_long.getCPtr(nb_dis));
|
||||
}
|
||||
|
||||
public static void range_search_with_parameters(Index index, long n, SWIGTYPE_p_float x, float radius, RangeSearchResult result, IVFSearchParameters params) {
|
||||
swigfaissJNI.range_search_with_parameters__SWIG_2(Index.getCPtr(index), index, n, SWIGTYPE_p_float.getCPtr(x), radius, RangeSearchResult.getCPtr(result), result, IVFSearchParameters.getCPtr(params), params);
|
||||
}
|
||||
|
||||
public static void setHnsw_stats(HNSWStats value) {
|
||||
swigfaissJNI.hnsw_stats_set(HNSWStats.getCPtr(value), value);
|
||||
}
|
||||
|
||||
public static HNSWStats getHnsw_stats() {
|
||||
long cPtr = swigfaissJNI.hnsw_stats_get();
|
||||
return (cPtr == 0) ? null : new HNSWStats(cPtr, false);
|
||||
}
|
||||
|
||||
public static void setPrecomputed_table_max_bytes(long value) {
|
||||
swigfaissJNI.precomputed_table_max_bytes_set(value);
|
||||
}
|
||||
|
||||
public static long getPrecomputed_table_max_bytes() {
|
||||
return swigfaissJNI.precomputed_table_max_bytes_get();
|
||||
}
|
||||
|
||||
public static void initialize_IVFPQ_precomputed_table(SWIGTYPE_p_int use_precomputed_table, Index quantizer, ProductQuantizer pq, SWIGTYPE_p_AlignedTableT_float_32_t precomputed_table, boolean verbose) {
|
||||
swigfaissJNI.initialize_IVFPQ_precomputed_table(SWIGTYPE_p_int.getCPtr(use_precomputed_table), Index.getCPtr(quantizer), quantizer, ProductQuantizer.getCPtr(pq), pq, SWIGTYPE_p_AlignedTableT_float_32_t.getCPtr(precomputed_table), verbose);
|
||||
}
|
||||
|
||||
public static void setIndexIVFPQ_stats(IndexIVFPQStats value) {
|
||||
swigfaissJNI.indexIVFPQ_stats_set(IndexIVFPQStats.getCPtr(value), value);
|
||||
}
|
||||
|
||||
public static IndexIVFPQStats getIndexIVFPQ_stats() {
|
||||
long cPtr = swigfaissJNI.indexIVFPQ_stats_get();
|
||||
return (cPtr == 0) ? null : new IndexIVFPQStats(cPtr, false);
|
||||
}
|
||||
|
||||
public static Index downcast_index(Index index) {
|
||||
long cPtr = swigfaissJNI.downcast_index(Index.getCPtr(index), index);
|
||||
return (cPtr == 0) ? null : new Index(cPtr, false);
|
||||
}
|
||||
|
||||
public static VectorTransform downcast_VectorTransform(VectorTransform vt) {
|
||||
long cPtr = swigfaissJNI.downcast_VectorTransform(VectorTransform.getCPtr(vt), vt);
|
||||
return (cPtr == 0) ? null : new VectorTransform(cPtr, false);
|
||||
}
|
||||
|
||||
public static IndexBinary downcast_IndexBinary(IndexBinary index) {
|
||||
long cPtr = swigfaissJNI.downcast_IndexBinary(IndexBinary.getCPtr(index), index);
|
||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, false);
|
||||
}
|
||||
|
||||
public static Index upcast_IndexShards(IndexShards index) {
|
||||
long cPtr = swigfaissJNI.upcast_IndexShards(IndexShards.getCPtr(index), index);
|
||||
return (cPtr == 0) ? null : new Index(cPtr, false);
|
||||
}
|
||||
|
||||
public static void write_index(Index idx, String fname) {
|
||||
swigfaissJNI.write_index__SWIG_0(Index.getCPtr(idx), idx, fname);
|
||||
}
|
||||
|
||||
public static void write_index(Index idx, SWIGTYPE_p_FILE f) {
|
||||
swigfaissJNI.write_index__SWIG_1(Index.getCPtr(idx), idx, SWIGTYPE_p_FILE.getCPtr(f));
|
||||
}
|
||||
|
||||
public static void write_index(Index idx, SWIGTYPE_p_faiss__IOWriter writer) {
|
||||
swigfaissJNI.write_index__SWIG_2(Index.getCPtr(idx), idx, SWIGTYPE_p_faiss__IOWriter.getCPtr(writer));
|
||||
}
|
||||
|
||||
public static void write_index_binary(IndexBinary idx, String fname) {
|
||||
swigfaissJNI.write_index_binary__SWIG_0(IndexBinary.getCPtr(idx), idx, fname);
|
||||
}
|
||||
|
||||
public static void write_index_binary(IndexBinary idx, SWIGTYPE_p_FILE f) {
|
||||
swigfaissJNI.write_index_binary__SWIG_1(IndexBinary.getCPtr(idx), idx, SWIGTYPE_p_FILE.getCPtr(f));
|
||||
}
|
||||
|
||||
public static void write_index_binary(IndexBinary idx, SWIGTYPE_p_faiss__IOWriter writer) {
|
||||
swigfaissJNI.write_index_binary__SWIG_2(IndexBinary.getCPtr(idx), idx, SWIGTYPE_p_faiss__IOWriter.getCPtr(writer));
|
||||
}
|
||||
|
||||
public static int getIO_FLAG_READ_ONLY() {
|
||||
return swigfaissJNI.IO_FLAG_READ_ONLY_get();
|
||||
}
|
||||
|
||||
public static int getIO_FLAG_ONDISK_SAME_DIR() {
|
||||
return swigfaissJNI.IO_FLAG_ONDISK_SAME_DIR_get();
|
||||
}
|
||||
|
||||
public static int getIO_FLAG_SKIP_IVF_DATA() {
|
||||
return swigfaissJNI.IO_FLAG_SKIP_IVF_DATA_get();
|
||||
}
|
||||
|
||||
public static int getIO_FLAG_MMAP() {
|
||||
return swigfaissJNI.IO_FLAG_MMAP_get();
|
||||
}
|
||||
|
||||
public static Index read_index(String fname, int io_flags) {
|
||||
long cPtr = swigfaissJNI.read_index__SWIG_0(fname, io_flags);
|
||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
||||
}
|
||||
|
||||
public static Index read_index(String fname) {
|
||||
long cPtr = swigfaissJNI.read_index__SWIG_1(fname);
|
||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
||||
}
|
||||
|
||||
public static Index read_index(SWIGTYPE_p_FILE f, int io_flags) {
|
||||
long cPtr = swigfaissJNI.read_index__SWIG_2(SWIGTYPE_p_FILE.getCPtr(f), io_flags);
|
||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
||||
}
|
||||
|
||||
public static Index read_index(SWIGTYPE_p_FILE f) {
|
||||
long cPtr = swigfaissJNI.read_index__SWIG_3(SWIGTYPE_p_FILE.getCPtr(f));
|
||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
||||
}
|
||||
|
||||
public static Index read_index(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
|
||||
long cPtr = swigfaissJNI.read_index__SWIG_4(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
|
||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
||||
}
|
||||
|
||||
public static Index read_index(SWIGTYPE_p_faiss__IOReader reader) {
|
||||
long cPtr = swigfaissJNI.read_index__SWIG_5(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
|
||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
||||
}
|
||||
|
||||
public static IndexBinary read_index_binary(String fname, int io_flags) {
|
||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_0(fname, io_flags);
|
||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
||||
}
|
||||
|
||||
public static IndexBinary read_index_binary(String fname) {
|
||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_1(fname);
|
||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
||||
}
|
||||
|
||||
public static IndexBinary read_index_binary(SWIGTYPE_p_FILE f, int io_flags) {
|
||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_2(SWIGTYPE_p_FILE.getCPtr(f), io_flags);
|
||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
||||
}
|
||||
|
||||
public static IndexBinary read_index_binary(SWIGTYPE_p_FILE f) {
|
||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_3(SWIGTYPE_p_FILE.getCPtr(f));
|
||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
||||
}
|
||||
|
||||
public static IndexBinary read_index_binary(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
|
||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_4(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
|
||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
||||
}
|
||||
|
||||
public static IndexBinary read_index_binary(SWIGTYPE_p_faiss__IOReader reader) {
|
||||
long cPtr = swigfaissJNI.read_index_binary__SWIG_5(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
|
||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
||||
}
|
||||
|
||||
public static void write_VectorTransform(VectorTransform vt, String fname) {
|
||||
swigfaissJNI.write_VectorTransform(VectorTransform.getCPtr(vt), vt, fname);
|
||||
}
|
||||
|
||||
public static VectorTransform read_VectorTransform(String fname) {
|
||||
long cPtr = swigfaissJNI.read_VectorTransform(fname);
|
||||
return (cPtr == 0) ? null : new VectorTransform(cPtr, true);
|
||||
}
|
||||
|
||||
public static ProductQuantizer read_ProductQuantizer(String fname) {
|
||||
long cPtr = swigfaissJNI.read_ProductQuantizer__SWIG_0(fname);
|
||||
return (cPtr == 0) ? null : new ProductQuantizer(cPtr, true);
|
||||
}
|
||||
|
||||
public static ProductQuantizer read_ProductQuantizer(SWIGTYPE_p_faiss__IOReader reader) {
|
||||
long cPtr = swigfaissJNI.read_ProductQuantizer__SWIG_1(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
|
||||
return (cPtr == 0) ? null : new ProductQuantizer(cPtr, true);
|
||||
}
|
||||
|
||||
public static void write_ProductQuantizer(ProductQuantizer pq, String fname) {
|
||||
swigfaissJNI.write_ProductQuantizer__SWIG_0(ProductQuantizer.getCPtr(pq), pq, fname);
|
||||
}
|
||||
|
||||
public static void write_ProductQuantizer(ProductQuantizer pq, SWIGTYPE_p_faiss__IOWriter f) {
|
||||
swigfaissJNI.write_ProductQuantizer__SWIG_1(ProductQuantizer.getCPtr(pq), pq, SWIGTYPE_p_faiss__IOWriter.getCPtr(f));
|
||||
}
|
||||
|
||||
public static void write_InvertedLists(InvertedLists ils, SWIGTYPE_p_faiss__IOWriter f) {
|
||||
swigfaissJNI.write_InvertedLists(InvertedLists.getCPtr(ils), ils, SWIGTYPE_p_faiss__IOWriter.getCPtr(f));
|
||||
}
|
||||
|
||||
public static InvertedLists read_InvertedLists(SWIGTYPE_p_faiss__IOReader reader, int io_flags) {
|
||||
long cPtr = swigfaissJNI.read_InvertedLists__SWIG_0(SWIGTYPE_p_faiss__IOReader.getCPtr(reader), io_flags);
|
||||
return (cPtr == 0) ? null : new InvertedLists(cPtr, false);
|
||||
}
|
||||
|
||||
public static InvertedLists read_InvertedLists(SWIGTYPE_p_faiss__IOReader reader) {
|
||||
long cPtr = swigfaissJNI.read_InvertedLists__SWIG_1(SWIGTYPE_p_faiss__IOReader.getCPtr(reader));
|
||||
return (cPtr == 0) ? null : new InvertedLists(cPtr, false);
|
||||
}
|
||||
|
||||
public static Index index_factory(int d, String description, MetricType metric) {
|
||||
long cPtr = swigfaissJNI.index_factory__SWIG_0(d, description, metric.swigValue());
|
||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
||||
}
|
||||
|
||||
public static Index index_factory(int d, String description) {
|
||||
long cPtr = swigfaissJNI.index_factory__SWIG_1(d, description);
|
||||
return (cPtr == 0) ? null : new Index(cPtr, true);
|
||||
}
|
||||
|
||||
public static void setIndex_factory_verbose(int value) {
|
||||
swigfaissJNI.index_factory_verbose_set(value);
|
||||
}
|
||||
|
||||
public static int getIndex_factory_verbose() {
|
||||
return swigfaissJNI.index_factory_verbose_get();
|
||||
}
|
||||
|
||||
public static IndexBinary index_binary_factory(int d, String description) {
|
||||
long cPtr = swigfaissJNI.index_binary_factory(d, description);
|
||||
return (cPtr == 0) ? null : new IndexBinary(cPtr, true);
|
||||
}
|
||||
|
||||
public static void simd_histogram_8(SWIGTYPE_p_uint16_t data, int n, SWIGTYPE_p_uint16_t min, int shift, SWIGTYPE_p_int hist) {
|
||||
swigfaissJNI.simd_histogram_8(SWIGTYPE_p_uint16_t.getCPtr(data), n, SWIGTYPE_p_uint16_t.getCPtr(min), shift, SWIGTYPE_p_int.getCPtr(hist));
|
||||
}
|
||||
|
||||
public static void simd_histogram_16(SWIGTYPE_p_uint16_t data, int n, SWIGTYPE_p_uint16_t min, int shift, SWIGTYPE_p_int hist) {
|
||||
swigfaissJNI.simd_histogram_16(SWIGTYPE_p_uint16_t.getCPtr(data), n, SWIGTYPE_p_uint16_t.getCPtr(min), shift, SWIGTYPE_p_int.getCPtr(hist));
|
||||
}
|
||||
|
||||
public static void setPartition_stats(PartitionStats value) {
|
||||
swigfaissJNI.partition_stats_set(PartitionStats.getCPtr(value), value);
|
||||
}
|
||||
|
||||
public static PartitionStats getPartition_stats() {
|
||||
long cPtr = swigfaissJNI.partition_stats_get();
|
||||
return (cPtr == 0) ? null : new PartitionStats(cPtr, false);
|
||||
}
|
||||
|
||||
public static float CMin_float_partition_fuzzy(SWIGTYPE_p_float vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
||||
return swigfaissJNI.CMin_float_partition_fuzzy(SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out));
|
||||
}
|
||||
|
||||
public static float CMax_float_partition_fuzzy(SWIGTYPE_p_float vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
||||
return swigfaissJNI.CMax_float_partition_fuzzy(SWIGTYPE_p_float.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out));
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_uint16_t CMax_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
||||
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMax_uint16_partition_fuzzy__SWIG_0(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_uint16_t CMin_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, LongVector ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
||||
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMin_uint16_partition_fuzzy__SWIG_0(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_long_long.getCPtr(ids.data()), ids, n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_uint16_t CMax_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, SWIGTYPE_p_int ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
||||
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMax_uint16_partition_fuzzy__SWIG_1(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_int.getCPtr(ids), n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_uint16_t CMin_uint16_partition_fuzzy(SWIGTYPE_p_uint16_t vals, SWIGTYPE_p_int ids, long n, long q_min, long q_max, SWIGTYPE_p_unsigned_long q_out) {
|
||||
return new SWIGTYPE_p_uint16_t(swigfaissJNI.CMin_uint16_partition_fuzzy__SWIG_1(SWIGTYPE_p_uint16_t.getCPtr(vals), SWIGTYPE_p_int.getCPtr(ids), n, q_min, q_max, SWIGTYPE_p_unsigned_long.getCPtr(q_out)), true);
|
||||
}
|
||||
|
||||
public static void omp_set_num_threads(int num_threads) {
|
||||
swigfaissJNI.omp_set_num_threads(num_threads);
|
||||
}
|
||||
|
||||
public static int omp_get_max_threads() {
|
||||
return swigfaissJNI.omp_get_max_threads();
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_void memcpy(SWIGTYPE_p_void dest, SWIGTYPE_p_void src, long n) {
|
||||
long cPtr = swigfaissJNI.memcpy(SWIGTYPE_p_void.getCPtr(dest), SWIGTYPE_p_void.getCPtr(src), n);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_void(cPtr, false);
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_float cast_integer_to_float_ptr(int x) {
|
||||
long cPtr = swigfaissJNI.cast_integer_to_float_ptr(x);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_float(cPtr, false);
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_long cast_integer_to_long_ptr(int x) {
|
||||
long cPtr = swigfaissJNI.cast_integer_to_long_ptr(x);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_long(cPtr, false);
|
||||
}
|
||||
|
||||
public static SWIGTYPE_p_int cast_integer_to_int_ptr(int x) {
|
||||
long cPtr = swigfaissJNI.cast_integer_to_int_ptr(x);
|
||||
return (cPtr == 0) ? null : new SWIGTYPE_p_int(cPtr, false);
|
||||
}
|
||||
|
||||
public static void ignore_SIGTTIN() {
|
||||
swigfaissJNI.ignore_SIGTTIN();
|
||||
}
|
||||
|
||||
}
|
Binary file not shown.
@ -1,15 +0,0 @@
|
||||
/* ----------------------------------------------------------------------------
|
||||
* This file was automatically generated by SWIG (http://www.swig.org).
|
||||
* Version 4.0.2
|
||||
*
|
||||
* Do not make changes to this file unless you know what you are doing--modify
|
||||
* the SWIG interface file instead.
|
||||
* ----------------------------------------------------------------------------- */
|
||||
|
||||
package com.twitter.ann.faiss;
|
||||
|
||||
public interface swigfaissConstants {
|
||||
public final static int FAISS_VERSION_MAJOR = swigfaissJNI.FAISS_VERSION_MAJOR_get();
|
||||
public final static int FAISS_VERSION_MINOR = swigfaissJNI.FAISS_VERSION_MINOR_get();
|
||||
public final static int FAISS_VERSION_PATCH = swigfaissJNI.FAISS_VERSION_PATCH_get();
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/swigfaissJNI.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/faiss/swig/swigfaissJNI.docx
Normal file
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@ -1,18 +0,0 @@
|
||||
java_library(
|
||||
sources = ["*.java"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
platform = "java8",
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/com/google/guava",
|
||||
"3rdparty/jvm/com/google/inject:guice",
|
||||
"3rdparty/jvm/com/twitter/bijection:core",
|
||||
"3rdparty/jvm/commons-lang",
|
||||
"3rdparty/jvm/org/apache/thrift",
|
||||
"ann/src/main/scala/com/twitter/ann/common",
|
||||
"ann/src/main/thrift/com/twitter/ann/common:ann-common-java",
|
||||
"mediaservices/commons/src/main/scala:futuretracker",
|
||||
"scrooge/scrooge-core",
|
||||
"src/java/com/twitter/search/common/file",
|
||||
],
|
||||
)
|
BIN
ann/src/main/java/com/twitter/ann/hnsw/BUILD.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/BUILD.docx
Normal file
Binary file not shown.
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistanceFunction.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistanceFunction.docx
Normal file
Binary file not shown.
@ -1,8 +0,0 @@
|
||||
package com.twitter.ann.hnsw;
|
||||
|
||||
public interface DistanceFunction<T, Q> {
|
||||
/**
|
||||
* Distance between two items.
|
||||
*/
|
||||
float distance(T t, Q q);
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistancedItem.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistancedItem.docx
Normal file
Binary file not shown.
@ -1,23 +0,0 @@
|
||||
package com.twitter.ann.hnsw;
|
||||
|
||||
/**
|
||||
* An item associated with a float distance
|
||||
* @param <T> The type of the item.
|
||||
*/
|
||||
public class DistancedItem<T> {
|
||||
private final T item;
|
||||
private final float distance;
|
||||
|
||||
public DistancedItem(T item, float distance) {
|
||||
this.item = item;
|
||||
this.distance = distance;
|
||||
}
|
||||
|
||||
public T getItem() {
|
||||
return item;
|
||||
}
|
||||
|
||||
public float getDistance() {
|
||||
return distance;
|
||||
}
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistancedItemQueue.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/DistancedItemQueue.docx
Normal file
Binary file not shown.
@ -1,196 +0,0 @@
|
||||
package com.twitter.ann.hnsw;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.PriorityQueue;
|
||||
|
||||
/**
|
||||
* Container for items with their distance.
|
||||
*
|
||||
* @param <U> Type of origin/reference element.
|
||||
* @param <T> Type of element that the queue will hold
|
||||
*/
|
||||
public class DistancedItemQueue<U, T> implements Iterable<DistancedItem<T>> {
|
||||
private final U origin;
|
||||
private final DistanceFunction<U, T> distFn;
|
||||
private final PriorityQueue<DistancedItem<T>> queue;
|
||||
private final boolean minQueue;
|
||||
/**
|
||||
* Creates ontainer for items with their distances.
|
||||
*
|
||||
* @param origin Origin (reference) point
|
||||
* @param initial Initial list of elements to add in the structure
|
||||
* @param minQueue True for min queue, False for max queue
|
||||
* @param distFn Distance function
|
||||
*/
|
||||
public DistancedItemQueue(
|
||||
U origin,
|
||||
List<T> initial,
|
||||
boolean minQueue,
|
||||
DistanceFunction<U, T> distFn
|
||||
) {
|
||||
this.origin = origin;
|
||||
this.distFn = distFn;
|
||||
this.minQueue = minQueue;
|
||||
final Comparator<DistancedItem<T>> cmp;
|
||||
if (minQueue) {
|
||||
cmp = (o1, o2) -> Float.compare(o1.getDistance(), o2.getDistance());
|
||||
} else {
|
||||
cmp = (o1, o2) -> Float.compare(o2.getDistance(), o1.getDistance());
|
||||
}
|
||||
this.queue = new PriorityQueue<>(cmp);
|
||||
enqueueAll(initial);
|
||||
new DistancedItemQueue<>(origin, distFn, queue, minQueue);
|
||||
}
|
||||
|
||||
private DistancedItemQueue(
|
||||
U origin,
|
||||
DistanceFunction<U, T> distFn,
|
||||
PriorityQueue<DistancedItem<T>> queue,
|
||||
boolean minQueue
|
||||
) {
|
||||
this.origin = origin;
|
||||
this.distFn = distFn;
|
||||
this.queue = queue;
|
||||
this.minQueue = minQueue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enqueues all the items into the queue.
|
||||
*/
|
||||
public void enqueueAll(List<T> list) {
|
||||
for (T t : list) {
|
||||
enqueue(t);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return if queue is non empty or not
|
||||
*
|
||||
* @return true if queue is not empty else false
|
||||
*/
|
||||
public boolean nonEmpty() {
|
||||
return !queue.isEmpty();
|
||||
}
|
||||
|
||||
/**
|
||||
* Return root of the queue
|
||||
*
|
||||
* @return root of the queue i.e min/max element depending upon min-max queue
|
||||
*/
|
||||
public DistancedItem<T> peek() {
|
||||
return queue.peek();
|
||||
}
|
||||
|
||||
/**
|
||||
* Dequeue root of the queue.
|
||||
*
|
||||
* @return remove and return root of the queue i.e min/max element depending upon min-max queue
|
||||
*/
|
||||
public DistancedItem<T> dequeue() {
|
||||
return queue.poll();
|
||||
}
|
||||
|
||||
/**
|
||||
* Dequeue all the elements from queueu with ordering mantained
|
||||
*
|
||||
* @return remove all the elements in the order of the queue i.e min/max queue.
|
||||
*/
|
||||
public List<DistancedItem<T>> dequeueAll() {
|
||||
final List<DistancedItem<T>> list = new ArrayList<>(queue.size());
|
||||
while (!queue.isEmpty()) {
|
||||
list.add(queue.poll());
|
||||
}
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert queue to list
|
||||
*
|
||||
* @return list of elements of queue with distance and without any specific ordering
|
||||
*/
|
||||
public List<DistancedItem<T>> toList() {
|
||||
return new ArrayList<>(queue);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert queue to list
|
||||
*
|
||||
* @return list of elements of queue without any specific ordering
|
||||
*/
|
||||
List<T> toListWithItem() {
|
||||
List<T> list = new ArrayList<>(queue.size());
|
||||
Iterator<DistancedItem<T>> itr = iterator();
|
||||
while (itr.hasNext()) {
|
||||
list.add(itr.next().getItem());
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enqueue an item into the queue
|
||||
*/
|
||||
public void enqueue(T item) {
|
||||
queue.add(new DistancedItem<>(item, distFn.distance(origin, item)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Enqueue an item into the queue with its distance.
|
||||
*/
|
||||
public void enqueue(T item, float distance) {
|
||||
queue.add(new DistancedItem<>(item, distance));
|
||||
}
|
||||
|
||||
/**
|
||||
* Size
|
||||
*
|
||||
* @return size of the queue
|
||||
*/
|
||||
public int size() {
|
||||
return queue.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Is Min queue
|
||||
*
|
||||
* @return true if min queue else false
|
||||
*/
|
||||
public boolean isMinQueue() {
|
||||
return minQueue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns origin (base element) of the queue
|
||||
*
|
||||
* @return origin of the queue
|
||||
*/
|
||||
public U getOrigin() {
|
||||
return origin;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new queue with ordering reversed.
|
||||
*/
|
||||
public DistancedItemQueue<U, T> reverse() {
|
||||
final PriorityQueue<DistancedItem<T>> rqueue =
|
||||
new PriorityQueue<>(queue.comparator().reversed());
|
||||
if (queue.isEmpty()) {
|
||||
return new DistancedItemQueue<>(origin, distFn, rqueue, !isMinQueue());
|
||||
}
|
||||
|
||||
final Iterator<DistancedItem<T>> itr = iterator();
|
||||
while (itr.hasNext()) {
|
||||
rqueue.add(itr.next());
|
||||
}
|
||||
|
||||
return new DistancedItemQueue<>(origin, distFn, rqueue, !isMinQueue());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<DistancedItem<T>> iterator() {
|
||||
return queue.iterator();
|
||||
}
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswIndex.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswIndex.docx
Normal file
Binary file not shown.
@ -1,711 +0,0 @@
|
||||
package com.twitter.ann.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReadWriteLock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
import java.util.function.Function;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import org.apache.thrift.TException;
|
||||
|
||||
import com.twitter.ann.common.IndexOutputFile;
|
||||
import com.twitter.ann.common.thriftjava.HnswInternalIndexMetadata;
|
||||
import com.twitter.bijection.Injection;
|
||||
import com.twitter.logging.Logger;
|
||||
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec;
|
||||
import com.twitter.search.common.file.AbstractFile;
|
||||
|
||||
/**
|
||||
* Typed multithreaded HNSW implementation supporting creation/querying of approximate nearest neighbour
|
||||
* Paper: https://arxiv.org/pdf/1603.09320.pdf
|
||||
* Multithreading impl based on NMSLIB version : https://github.com/nmslib/hnsw/blob/master/hnswlib/hnswalg.h
|
||||
*
|
||||
* @param <T> The type of items inserted / searched in the HNSW index.
|
||||
* @param <Q> The type of KNN query.
|
||||
*/
|
||||
public class HnswIndex<T, Q> {
|
||||
private static final Logger LOG = Logger.get(HnswIndex.class);
|
||||
private static final String METADATA_FILE_NAME = "hnsw_internal_metadata";
|
||||
private static final String GRAPH_FILE_NAME = "hnsw_internal_graph";
|
||||
private static final int MAP_SIZE_FACTOR = 5;
|
||||
|
||||
private final DistanceFunction<T, T> distFnIndex;
|
||||
private final DistanceFunction<Q, T> distFnQuery;
|
||||
private final int efConstruction;
|
||||
private final int maxM;
|
||||
private final int maxM0;
|
||||
private final double levelMultiplier;
|
||||
private final AtomicReference<HnswMeta<T>> graphMeta = new AtomicReference<>();
|
||||
private final Map<HnswNode<T>, ImmutableList<T>> graph;
|
||||
// To take lock on vertex level
|
||||
private final ConcurrentHashMap<T, ReadWriteLock> locks;
|
||||
// To take lock on whole graph only if vertex addition is on layer above the current maxLevel
|
||||
private final ReentrantLock globalLock;
|
||||
private final Function<T, ReadWriteLock> lockProvider;
|
||||
|
||||
private final RandomProvider randomProvider;
|
||||
|
||||
// Probability of reevaluating connections of an element in the neighborhood during an update
|
||||
// Can be used as a knob to adjust update_speed/search_speed tradeoff.
|
||||
private final float updateNeighborProbability;
|
||||
|
||||
/**
|
||||
* Creates instance of hnsw index.
|
||||
*
|
||||
* @param distFnIndex Any distance metric/non metric that specifies similarity between two items for indexing.
|
||||
* @param distFnQuery Any distance metric/non metric that specifies similarity between item for which nearest neighbours queried for and already indexed item.
|
||||
* @param efConstruction Provide speed vs index quality tradeoff, higher the value better the quality and higher the time to create index.
|
||||
* Valid range of efConstruction can be anywhere between 1 and tens of thousand. Typically, it should be set so that a search of M
|
||||
* neighbors with ef=efConstruction should end in recall>0.95.
|
||||
* @param maxM Maximum connections per layer except 0th level.
|
||||
* Optimal values between 5-48.
|
||||
* Smaller M generally produces better result for lower recalls and/ or lower dimensional data,
|
||||
* while bigger M is better for high recall and/ or high dimensional, data on the expense of more memory/disk usage
|
||||
* @param expectedElements Approximate number of elements to be indexed
|
||||
*/
|
||||
protected HnswIndex(
|
||||
DistanceFunction<T, T> distFnIndex,
|
||||
DistanceFunction<Q, T> distFnQuery,
|
||||
int efConstruction,
|
||||
int maxM,
|
||||
int expectedElements,
|
||||
RandomProvider randomProvider
|
||||
) {
|
||||
this(distFnIndex,
|
||||
distFnQuery,
|
||||
efConstruction,
|
||||
maxM,
|
||||
expectedElements,
|
||||
new HnswMeta<>(-1, Optional.empty()),
|
||||
new ConcurrentHashMap<>(MAP_SIZE_FACTOR * expectedElements),
|
||||
randomProvider
|
||||
);
|
||||
}
|
||||
|
||||
private HnswIndex(
|
||||
DistanceFunction<T, T> distFnIndex,
|
||||
DistanceFunction<Q, T> distFnQuery,
|
||||
int efConstruction,
|
||||
int maxM,
|
||||
int expectedElements,
|
||||
HnswMeta<T> graphMeta,
|
||||
Map<HnswNode<T>, ImmutableList<T>> graph,
|
||||
RandomProvider randomProvider
|
||||
) {
|
||||
this.distFnIndex = distFnIndex;
|
||||
this.distFnQuery = distFnQuery;
|
||||
this.efConstruction = efConstruction;
|
||||
this.maxM = maxM;
|
||||
this.maxM0 = 2 * maxM;
|
||||
this.levelMultiplier = 1.0 / Math.log(1.0 * maxM);
|
||||
this.graphMeta.set(graphMeta);
|
||||
this.graph = graph;
|
||||
this.locks = new ConcurrentHashMap<>(MAP_SIZE_FACTOR * expectedElements);
|
||||
this.globalLock = new ReentrantLock();
|
||||
this.lockProvider = key -> new ReentrantReadWriteLock();
|
||||
this.randomProvider = randomProvider;
|
||||
this.updateNeighborProbability = 1.0f;
|
||||
}
|
||||
|
||||
/**
|
||||
* wireConnectionForAllLayers finds connections for a new element and creates bi-direction links.
|
||||
* The method assumes using a reentrant lock to link list reads.
|
||||
*
|
||||
* @param entryPoint the global entry point
|
||||
* @param item the item for which the connections are found
|
||||
* @param itemLevel the level of the added item (maximum layer in which we wire the connections)
|
||||
* @param maxLayer the level of the entry point
|
||||
*/
|
||||
private void wireConnectionForAllLayers(final T entryPoint, final T item, final int itemLevel,
|
||||
final int maxLayer, final boolean isUpdate) {
|
||||
T curObj = entryPoint;
|
||||
if (itemLevel < maxLayer) {
|
||||
curObj = bestEntryPointUntilLayer(curObj, item, maxLayer, itemLevel, distFnIndex);
|
||||
}
|
||||
for (int level = Math.min(itemLevel, maxLayer); level >= 0; level--) {
|
||||
final DistancedItemQueue<T, T> candidates =
|
||||
searchLayerForCandidates(item, curObj, efConstruction, level, distFnIndex, isUpdate);
|
||||
curObj = mutuallyConnectNewElement(item, candidates, level, isUpdate);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Insert the item into HNSW index.
|
||||
*/
|
||||
public void insert(final T item) throws IllegalDuplicateInsertException {
|
||||
final Lock itemLock = locks.computeIfAbsent(item, lockProvider).writeLock();
|
||||
itemLock.lock();
|
||||
try {
|
||||
final HnswMeta<T> metadata = graphMeta.get();
|
||||
// If the graph already have the item, should not re-insert it again
|
||||
// Need to check entry point in case we reinsert first item where is are no graph
|
||||
// but only a entry point
|
||||
if (graph.containsKey(HnswNode.from(0, item))
|
||||
|| (metadata.getEntryPoint().isPresent()
|
||||
&& Objects.equals(metadata.getEntryPoint().get(), item))) {
|
||||
throw new IllegalDuplicateInsertException(
|
||||
"Duplicate insertion is not supported: " + item);
|
||||
}
|
||||
final int curLevel = getRandomLevel();
|
||||
Optional<T> entryPoint = metadata.getEntryPoint();
|
||||
// The global lock prevents two threads from making changes to the entry point. This lock
|
||||
// should get taken very infrequently. Something like log-base-levelMultiplier(num items)
|
||||
// For a full explanation of locking see this document: http://go/hnsw-locking
|
||||
int maxLevelCopy = metadata.getMaxLevel();
|
||||
if (curLevel > maxLevelCopy) {
|
||||
globalLock.lock();
|
||||
// Re initialize the entryPoint and maxLevel in case these are changed by any other thread
|
||||
// No need to check the condition again since,
|
||||
// it is already checked at the end before updating entry point struct
|
||||
// No need to unlock for optimization and keeping as is if condition fails since threads
|
||||
// will not be entering this section a lot.
|
||||
final HnswMeta<T> temp = graphMeta.get();
|
||||
entryPoint = temp.getEntryPoint();
|
||||
maxLevelCopy = temp.getMaxLevel();
|
||||
}
|
||||
|
||||
if (entryPoint.isPresent()) {
|
||||
wireConnectionForAllLayers(entryPoint.get(), item, curLevel, maxLevelCopy, false);
|
||||
}
|
||||
|
||||
if (curLevel > maxLevelCopy) {
|
||||
Preconditions.checkState(globalLock.isHeldByCurrentThread(),
|
||||
"Global lock not held before updating entry point");
|
||||
graphMeta.set(new HnswMeta<>(curLevel, Optional.of(item)));
|
||||
}
|
||||
} finally {
|
||||
if (globalLock.isHeldByCurrentThread()) {
|
||||
globalLock.unlock();
|
||||
}
|
||||
itemLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* set connections of an element with synchronization
|
||||
* The only other place that should have the lock for writing is during
|
||||
* the element insertion
|
||||
*/
|
||||
private void setConnectionList(final T item, int layer, List<T> connections) {
|
||||
final Lock candidateLock = locks.computeIfAbsent(item, lockProvider).writeLock();
|
||||
candidateLock.lock();
|
||||
try {
|
||||
graph.put(
|
||||
HnswNode.from(layer, item),
|
||||
ImmutableList.copyOf(connections)
|
||||
);
|
||||
} finally {
|
||||
candidateLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reinsert the item into HNSW index.
|
||||
* This method updates the links of an element assuming
|
||||
* the element's distance function is changed externally (e.g. by updating the features)
|
||||
*/
|
||||
|
||||
public void reInsert(final T item) {
|
||||
final HnswMeta<T> metadata = graphMeta.get();
|
||||
|
||||
Optional<T> entryPoint = metadata.getEntryPoint();
|
||||
|
||||
Preconditions.checkState(entryPoint.isPresent(),
|
||||
"Update cannot be performed if entry point is not present");
|
||||
|
||||
// This is a check for the single element case
|
||||
if (entryPoint.get().equals(item) && graph.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
Preconditions.checkState(graph.containsKey(HnswNode.from(0, item)),
|
||||
"Graph does not contain the item to be updated at level 0");
|
||||
|
||||
int curLevel = 0;
|
||||
|
||||
int maxLevelCopy = metadata.getMaxLevel();
|
||||
|
||||
for (int layer = maxLevelCopy; layer >= 0; layer--) {
|
||||
if (graph.containsKey(HnswNode.from(layer, item))) {
|
||||
curLevel = layer;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Updating the links of the elements from the 1-hop radius of the updated element
|
||||
|
||||
for (int layer = 0; layer <= curLevel; layer++) {
|
||||
|
||||
// Filling the element sets for candidates and updated elements
|
||||
final HashSet<T> setCand = new HashSet<T>();
|
||||
final HashSet<T> setNeigh = new HashSet<T>();
|
||||
final List<T> listOneHop = getConnectionListForRead(item, layer);
|
||||
|
||||
if (listOneHop.isEmpty()) {
|
||||
LOG.debug("No links for the updated element. Empty dataset?");
|
||||
continue;
|
||||
}
|
||||
|
||||
setCand.add(item);
|
||||
|
||||
for (T elOneHop : listOneHop) {
|
||||
setCand.add(elOneHop);
|
||||
if (randomProvider.get().nextFloat() > updateNeighborProbability) {
|
||||
continue;
|
||||
}
|
||||
setNeigh.add(elOneHop);
|
||||
final List<T> listTwoHop = getConnectionListForRead(elOneHop, layer);
|
||||
|
||||
if (listTwoHop.isEmpty()) {
|
||||
LOG.debug("No links for the updated element. Empty dataset?");
|
||||
}
|
||||
|
||||
for (T oneHopEl : listTwoHop) {
|
||||
setCand.add(oneHopEl);
|
||||
}
|
||||
}
|
||||
// No need to update the item itself, so remove it
|
||||
setNeigh.remove(item);
|
||||
|
||||
// Updating the link lists of elements from setNeigh:
|
||||
for (T neigh : setNeigh) {
|
||||
final HashSet<T> setCopy = new HashSet<T>(setCand);
|
||||
setCopy.remove(neigh);
|
||||
int keepElementsNum = Math.min(efConstruction, setCopy.size());
|
||||
final DistancedItemQueue<T, T> candidates = new DistancedItemQueue<>(
|
||||
neigh,
|
||||
ImmutableList.of(),
|
||||
false,
|
||||
distFnIndex
|
||||
);
|
||||
for (T cand : setCopy) {
|
||||
final float distance = distFnIndex.distance(neigh, cand);
|
||||
if (candidates.size() < keepElementsNum) {
|
||||
candidates.enqueue(cand, distance);
|
||||
} else {
|
||||
if (distance < candidates.peek().getDistance()) {
|
||||
candidates.dequeue();
|
||||
candidates.enqueue(cand, distance);
|
||||
}
|
||||
}
|
||||
}
|
||||
final ImmutableList<T> neighbours = selectNearestNeighboursByHeuristic(
|
||||
candidates,
|
||||
layer == 0 ? maxM0 : maxM
|
||||
);
|
||||
|
||||
final List<T> temp = getConnectionListForRead(neigh, layer);
|
||||
if (temp.isEmpty()) {
|
||||
LOG.debug("existing linkslist is empty. Corrupt index");
|
||||
}
|
||||
if (neighbours.isEmpty()) {
|
||||
LOG.debug("predicted linkslist is empty. Corrupt index");
|
||||
}
|
||||
setConnectionList(neigh, layer, neighbours);
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
wireConnectionForAllLayers(metadata.getEntryPoint().get(), item, curLevel, maxLevelCopy, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method can be used to get the graph statistics, specifically
|
||||
* it prints the histogram of inbound connections for each element.
|
||||
*/
|
||||
private String getStats() {
|
||||
int histogramMaxBins = 50;
|
||||
int[] histogram = new int[histogramMaxBins];
|
||||
HashMap<T, Integer> mmap = new HashMap<T, Integer>();
|
||||
for (HnswNode<T> key : graph.keySet()) {
|
||||
if (key.level == 0) {
|
||||
List<T> linkList = getConnectionListForRead(key.item, key.level);
|
||||
for (T node : linkList) {
|
||||
int a = mmap.computeIfAbsent(node, k -> 0);
|
||||
mmap.put(node, a + 1);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (T key : mmap.keySet()) {
|
||||
int ind = mmap.get(key) < histogramMaxBins - 1 ? mmap.get(key) : histogramMaxBins - 1;
|
||||
histogram[ind]++;
|
||||
}
|
||||
int minNonZeroIndex;
|
||||
for (minNonZeroIndex = histogramMaxBins - 1; minNonZeroIndex >= 0; minNonZeroIndex--) {
|
||||
if (histogram[minNonZeroIndex] > 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
String output = "";
|
||||
for (int i = 0; i <= minNonZeroIndex; i++) {
|
||||
output += "" + i + "\t" + histogram[i] / (0.01f * mmap.keySet().size()) + "\n";
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
private int getRandomLevel() {
|
||||
return (int) (-Math.log(randomProvider.get().nextDouble()) * levelMultiplier);
|
||||
}
|
||||
|
||||
/**
|
||||
* Note that to avoid deadlocks it is important that this method is called after all the searches
|
||||
* of the graph have completed. If you take a lock on any items discovered in the graph after
|
||||
* this, you may get stuck waiting on a thread that is waiting for item to be fully inserted.
|
||||
* <p>
|
||||
* Note: when using concurrent writers we can miss connections that we would otherwise get.
|
||||
* This will reduce the recall.
|
||||
* <p>
|
||||
* For a full explanation of locking see this document: http://go/hnsw-locking
|
||||
* The method returns the closest nearest neighbor (can be used as an enter point)
|
||||
*/
|
||||
private T mutuallyConnectNewElement(
|
||||
final T item,
|
||||
final DistancedItemQueue<T, T> candidates, // Max queue
|
||||
final int level,
|
||||
final boolean isUpdate
|
||||
) {
|
||||
|
||||
// Using maxM here. Its implementation is ambiguous in HNSW paper,
|
||||
// so using the way it is getting used in Hnsw lib.
|
||||
final ImmutableList<T> neighbours = selectNearestNeighboursByHeuristic(candidates, maxM);
|
||||
setConnectionList(item, level, neighbours);
|
||||
final int M = level == 0 ? maxM0 : maxM;
|
||||
for (T nn : neighbours) {
|
||||
if (nn.equals(item)) {
|
||||
continue;
|
||||
}
|
||||
final Lock curLock = locks.computeIfAbsent(nn, lockProvider).writeLock();
|
||||
curLock.lock();
|
||||
try {
|
||||
final HnswNode<T> key = HnswNode.from(level, nn);
|
||||
final ImmutableList<T> connections = graph.getOrDefault(key, ImmutableList.of());
|
||||
final boolean isItemAlreadyPresent =
|
||||
isUpdate && connections.indexOf(item) != -1 ? true : false;
|
||||
|
||||
// If `item` is already present in the neighboring connections,
|
||||
// then no need to modify any connections or run the search heuristics.
|
||||
if (isItemAlreadyPresent) {
|
||||
continue;
|
||||
}
|
||||
|
||||
final ImmutableList<T> updatedConnections;
|
||||
if (connections.size() < M) {
|
||||
final List<T> temp = new ArrayList<>(connections);
|
||||
temp.add(item);
|
||||
updatedConnections = ImmutableList.copyOf(temp.iterator());
|
||||
} else {
|
||||
// Max Queue
|
||||
final DistancedItemQueue<T, T> queue = new DistancedItemQueue<>(
|
||||
nn,
|
||||
connections,
|
||||
false,
|
||||
distFnIndex
|
||||
);
|
||||
queue.enqueue(item);
|
||||
updatedConnections = selectNearestNeighboursByHeuristic(queue, M);
|
||||
}
|
||||
if (updatedConnections.isEmpty()) {
|
||||
LOG.debug("Internal error: predicted linkslist is empty");
|
||||
}
|
||||
|
||||
graph.put(key, updatedConnections);
|
||||
} finally {
|
||||
curLock.unlock();
|
||||
}
|
||||
}
|
||||
return neighbours.get(0);
|
||||
}
|
||||
|
||||
/*
|
||||
* bestEntryPointUntilLayer starts the graph search for item from the entry point
|
||||
* until the searches reaches the selectedLayer layer.
|
||||
* @return a point from selectedLayer layer, was the closest on the (selectedLayer+1) layer
|
||||
*/
|
||||
private <K> T bestEntryPointUntilLayer(
|
||||
final T entryPoint,
|
||||
final K item,
|
||||
int maxLayer,
|
||||
int selectedLayer,
|
||||
DistanceFunction<K, T> distFn
|
||||
) {
|
||||
T curObj = entryPoint;
|
||||
if (selectedLayer < maxLayer) {
|
||||
float curDist = distFn.distance(item, curObj);
|
||||
for (int level = maxLayer; level > selectedLayer; level--) {
|
||||
boolean changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
final List<T> list = getConnectionListForRead(curObj, level);
|
||||
for (T nn : list) {
|
||||
final float tempDist = distFn.distance(item, nn);
|
||||
if (tempDist < curDist) {
|
||||
curDist = tempDist;
|
||||
curObj = nn;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return curObj;
|
||||
}
|
||||
|
||||
|
||||
@VisibleForTesting
|
||||
protected ImmutableList<T> selectNearestNeighboursByHeuristic(
|
||||
final DistancedItemQueue<T, T> candidates, // Max queue
|
||||
final int maxConnections
|
||||
) {
|
||||
Preconditions.checkState(!candidates.isMinQueue(),
|
||||
"candidates in selectNearestNeighboursByHeuristic should be a max queue");
|
||||
|
||||
final T baseElement = candidates.getOrigin();
|
||||
if (candidates.size() <= maxConnections) {
|
||||
List<T> list = candidates.toListWithItem();
|
||||
list.remove(baseElement);
|
||||
return ImmutableList.copyOf(list);
|
||||
} else {
|
||||
final List<T> resSet = new ArrayList<>(maxConnections);
|
||||
// Min queue for closest elements first
|
||||
final DistancedItemQueue<T, T> minQueue = candidates.reverse();
|
||||
while (minQueue.nonEmpty()) {
|
||||
if (resSet.size() >= maxConnections) {
|
||||
break;
|
||||
}
|
||||
final DistancedItem<T> candidate = minQueue.dequeue();
|
||||
|
||||
// We do not want to creates loops:
|
||||
// While heuristic is used only for creating the links
|
||||
if (candidate.getItem().equals(baseElement)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
boolean toInclude = true;
|
||||
for (T e : resSet) {
|
||||
// Do not include candidate if the distance from candidate to any of existing item in
|
||||
// resSet is closer to the distance from the candidate to the item. By doing this, the
|
||||
// connection of graph will be more diverse, and in case of highly clustered data set,
|
||||
// connections will be made between clusters instead of all being in the same cluster.
|
||||
final float dist = distFnIndex.distance(e, candidate.getItem());
|
||||
if (dist < candidate.getDistance()) {
|
||||
toInclude = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (toInclude) {
|
||||
resSet.add(candidate.getItem());
|
||||
}
|
||||
}
|
||||
return ImmutableList.copyOf(resSet);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Search the index for the neighbours.
|
||||
*
|
||||
* @param query Query
|
||||
* @param numOfNeighbours Number of neighbours to search for.
|
||||
* @param ef This param controls the accuracy of the search.
|
||||
* Bigger the ef better the accuracy on the expense of latency.
|
||||
* Keep it atleast number of neighbours to find.
|
||||
* @return Neighbours
|
||||
*/
|
||||
public List<DistancedItem<T>> searchKnn(final Q query, final int numOfNeighbours, final int ef) {
|
||||
final HnswMeta<T> metadata = graphMeta.get();
|
||||
if (metadata.getEntryPoint().isPresent()) {
|
||||
T entryPoint = bestEntryPointUntilLayer(metadata.getEntryPoint().get(),
|
||||
query, metadata.getMaxLevel(), 0, distFnQuery);
|
||||
// Get the actual neighbours from 0th layer
|
||||
final List<DistancedItem<T>> neighbours =
|
||||
searchLayerForCandidates(query, entryPoint, Math.max(ef, numOfNeighbours),
|
||||
0, distFnQuery, false).dequeueAll();
|
||||
Collections.reverse(neighbours);
|
||||
return neighbours.size() > numOfNeighbours
|
||||
? neighbours.subList(0, numOfNeighbours) : neighbours;
|
||||
} else {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
// This method is currently not used
|
||||
// It is needed for debugging purposes only
|
||||
private void checkIntegrity(String message) {
|
||||
final HnswMeta<T> metadata = graphMeta.get();
|
||||
for (HnswNode<T> node : graph.keySet()) {
|
||||
List<T> linkList = graph.get(node);
|
||||
|
||||
for (T el : linkList) {
|
||||
if (el.equals(node.item)) {
|
||||
LOG.debug(message);
|
||||
throw new RuntimeException("integrity check failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private <K> DistancedItemQueue<K, T> searchLayerForCandidates(
|
||||
final K item,
|
||||
final T entryPoint,
|
||||
final int ef,
|
||||
final int level,
|
||||
final DistanceFunction<K, T> distFn,
|
||||
boolean isUpdate
|
||||
) {
|
||||
// Min queue
|
||||
final DistancedItemQueue<K, T> cQueue = new DistancedItemQueue<>(
|
||||
item,
|
||||
Collections.singletonList(entryPoint),
|
||||
true,
|
||||
distFn
|
||||
);
|
||||
// Max Queue
|
||||
final DistancedItemQueue<K, T> wQueue = cQueue.reverse();
|
||||
final Set<T> visited = new HashSet<>();
|
||||
float lowerBoundDistance = wQueue.peek().getDistance();
|
||||
visited.add(entryPoint);
|
||||
|
||||
while (cQueue.nonEmpty()) {
|
||||
final DistancedItem<T> candidate = cQueue.peek();
|
||||
if (candidate.getDistance() > lowerBoundDistance) {
|
||||
break;
|
||||
}
|
||||
|
||||
cQueue.dequeue();
|
||||
final List<T> list = getConnectionListForRead(candidate.getItem(), level);
|
||||
for (T nn : list) {
|
||||
if (!visited.contains(nn)) {
|
||||
visited.add(nn);
|
||||
final float distance = distFn.distance(item, nn);
|
||||
if (wQueue.size() < ef || distance < wQueue.peek().getDistance()) {
|
||||
cQueue.enqueue(nn, distance);
|
||||
|
||||
if (isUpdate && item.equals(nn)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
wQueue.enqueue(nn, distance);
|
||||
if (wQueue.size() > ef) {
|
||||
wQueue.dequeue();
|
||||
}
|
||||
|
||||
lowerBoundDistance = wQueue.peek().getDistance();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return wQueue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize hnsw index
|
||||
*/
|
||||
public void toDirectory(IndexOutputFile indexOutputFile, Injection<T, byte[]> injection)
|
||||
throws IOException, TException {
|
||||
final int totalGraphEntries = HnswIndexIOUtil.saveHnswGraphEntries(
|
||||
graph,
|
||||
indexOutputFile.createFile(GRAPH_FILE_NAME).getOutputStream(),
|
||||
injection);
|
||||
|
||||
HnswIndexIOUtil.saveMetadata(
|
||||
graphMeta.get(),
|
||||
efConstruction,
|
||||
maxM,
|
||||
totalGraphEntries,
|
||||
injection,
|
||||
indexOutputFile.createFile(METADATA_FILE_NAME).getOutputStream());
|
||||
}
|
||||
|
||||
/**
|
||||
* Load hnsw index
|
||||
*/
|
||||
public static <T, Q> HnswIndex<T, Q> loadHnswIndex(
|
||||
DistanceFunction<T, T> distFnIndex,
|
||||
DistanceFunction<Q, T> distFnQuery,
|
||||
AbstractFile directory,
|
||||
Injection<T, byte[]> injection,
|
||||
RandomProvider randomProvider) throws IOException, TException {
|
||||
final AbstractFile graphFile = directory.getChild(GRAPH_FILE_NAME);
|
||||
final AbstractFile metadataFile = directory.getChild(METADATA_FILE_NAME);
|
||||
final HnswInternalIndexMetadata metadata = HnswIndexIOUtil.loadMetadata(metadataFile);
|
||||
final Map<HnswNode<T>, ImmutableList<T>> graph =
|
||||
HnswIndexIOUtil.loadHnswGraph(graphFile, injection, metadata.numElements);
|
||||
final ByteBuffer entryPointBB = metadata.entryPoint;
|
||||
final HnswMeta<T> graphMeta = new HnswMeta<>(
|
||||
metadata.maxLevel,
|
||||
entryPointBB == null ? Optional.empty()
|
||||
: Optional.of(injection.invert(ArrayByteBufferCodec.decode(entryPointBB)).get())
|
||||
);
|
||||
return new HnswIndex<>(
|
||||
distFnIndex,
|
||||
distFnQuery,
|
||||
metadata.efConstruction,
|
||||
metadata.maxM,
|
||||
metadata.numElements,
|
||||
graphMeta,
|
||||
graph,
|
||||
randomProvider
|
||||
);
|
||||
}
|
||||
|
||||
private List<T> getConnectionListForRead(T node, int level) {
|
||||
final Lock curLock = locks.computeIfAbsent(node, lockProvider).readLock();
|
||||
curLock.lock();
|
||||
final List<T> list;
|
||||
try {
|
||||
list = graph
|
||||
.getOrDefault(HnswNode.from(level, node), ImmutableList.of());
|
||||
} finally {
|
||||
curLock.unlock();
|
||||
}
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
AtomicReference<HnswMeta<T>> getGraphMeta() {
|
||||
return graphMeta;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Map<T, ReadWriteLock> getLocks() {
|
||||
return locks;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Map<HnswNode<T>, ImmutableList<T>> getGraph() {
|
||||
return graph;
|
||||
}
|
||||
|
||||
public interface RandomProvider {
|
||||
/**
|
||||
* RandomProvider interface made public for scala 2.12 compat
|
||||
*/
|
||||
Random get();
|
||||
}
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswIndexIOUtil.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswIndexIOUtil.docx
Normal file
Binary file not shown.
@ -1,133 +0,0 @@
|
||||
package com.twitter.ann.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import org.apache.thrift.TDeserializer;
|
||||
import org.apache.thrift.TException;
|
||||
import org.apache.thrift.TSerializer;
|
||||
import org.apache.thrift.protocol.TBinaryProtocol;
|
||||
import org.apache.thrift.protocol.TProtocol;
|
||||
import org.apache.thrift.transport.TIOStreamTransport;
|
||||
import org.apache.thrift.transport.TTransportException;
|
||||
|
||||
import com.twitter.ann.common.thriftjava.HnswGraphEntry;
|
||||
import com.twitter.ann.common.thriftjava.HnswInternalIndexMetadata;
|
||||
import com.twitter.bijection.Injection;
|
||||
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec;
|
||||
import com.twitter.search.common.file.AbstractFile;
|
||||
|
||||
public final class HnswIndexIOUtil {
|
||||
private HnswIndexIOUtil() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Save thrift object in file
|
||||
*/
|
||||
public static <T> void saveMetadata(
|
||||
HnswMeta<T> graphMeta,
|
||||
int efConstruction,
|
||||
int maxM,
|
||||
int numElements,
|
||||
Injection<T, byte[]> injection,
|
||||
OutputStream outputStream
|
||||
) throws IOException, TException {
|
||||
final int maxLevel = graphMeta.getMaxLevel();
|
||||
final HnswInternalIndexMetadata metadata = new HnswInternalIndexMetadata(
|
||||
maxLevel,
|
||||
efConstruction,
|
||||
maxM,
|
||||
numElements
|
||||
);
|
||||
|
||||
if (graphMeta.getEntryPoint().isPresent()) {
|
||||
metadata.setEntryPoint(injection.apply(graphMeta.getEntryPoint().get()));
|
||||
}
|
||||
final TSerializer serializer = new TSerializer(new TBinaryProtocol.Factory());
|
||||
outputStream.write(serializer.serialize(metadata));
|
||||
outputStream.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Load Hnsw index metadata
|
||||
*/
|
||||
public static HnswInternalIndexMetadata loadMetadata(AbstractFile file)
|
||||
throws IOException, TException {
|
||||
final HnswInternalIndexMetadata obj = new HnswInternalIndexMetadata();
|
||||
final TDeserializer deserializer = new TDeserializer(new TBinaryProtocol.Factory());
|
||||
deserializer.deserialize(obj, file.getByteSource().read());
|
||||
return obj;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load Hnsw graph entries from file
|
||||
*/
|
||||
public static <T> Map<HnswNode<T>, ImmutableList<T>> loadHnswGraph(
|
||||
AbstractFile file,
|
||||
Injection<T, byte[]> injection,
|
||||
int numElements
|
||||
) throws IOException, TException {
|
||||
final InputStream stream = file.getByteSource().openBufferedStream();
|
||||
final TProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(stream));
|
||||
final Map<HnswNode<T>, ImmutableList<T>> graph =
|
||||
new HashMap<>(numElements);
|
||||
while (true) {
|
||||
try {
|
||||
final HnswGraphEntry entry = new HnswGraphEntry();
|
||||
entry.read(protocol);
|
||||
final HnswNode<T> node = HnswNode.from(entry.level,
|
||||
injection.invert(ArrayByteBufferCodec.decode(entry.key)).get());
|
||||
final List<T> list = entry.getNeighbours().stream()
|
||||
.map(bb -> injection.invert(ArrayByteBufferCodec.decode(bb)).get())
|
||||
.collect(Collectors.toList());
|
||||
graph.put(node, ImmutableList.copyOf(list.iterator()));
|
||||
} catch (TException e) {
|
||||
if (e instanceof TTransportException
|
||||
&& TTransportException.class.cast(e).getType() == TTransportException.END_OF_FILE) {
|
||||
stream.close();
|
||||
break;
|
||||
}
|
||||
stream.close();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save hnsw graph in file
|
||||
*
|
||||
* @return number of keys in the graph
|
||||
*/
|
||||
public static <T> int saveHnswGraphEntries(
|
||||
Map<HnswNode<T>, ImmutableList<T>> graph,
|
||||
OutputStream outputStream,
|
||||
Injection<T, byte[]> injection
|
||||
) throws IOException, TException {
|
||||
final TProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(outputStream));
|
||||
final Set<HnswNode<T>> nodes = graph.keySet();
|
||||
for (HnswNode<T> node : nodes) {
|
||||
final HnswGraphEntry entry = new HnswGraphEntry();
|
||||
entry.setLevel(node.level);
|
||||
entry.setKey(injection.apply(node.item));
|
||||
final List<ByteBuffer> nn = graph.getOrDefault(node, ImmutableList.of()).stream()
|
||||
.map(t -> ByteBuffer.wrap(injection.apply(t)))
|
||||
.collect(Collectors.toList());
|
||||
entry.setNeighbours(nn);
|
||||
entry.write(protocol);
|
||||
}
|
||||
|
||||
outputStream.close();
|
||||
return nodes.size();
|
||||
}
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswMeta.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswMeta.docx
Normal file
Binary file not shown.
@ -1,45 +0,0 @@
|
||||
package com.twitter.ann.hnsw;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
class HnswMeta<T> {
|
||||
private final int maxLevel;
|
||||
private final Optional<T> entryPoint;
|
||||
|
||||
HnswMeta(int maxLevel, Optional<T> entryPoint) {
|
||||
this.maxLevel = maxLevel;
|
||||
this.entryPoint = entryPoint;
|
||||
}
|
||||
|
||||
public int getMaxLevel() {
|
||||
return maxLevel;
|
||||
}
|
||||
|
||||
public Optional<T> getEntryPoint() {
|
||||
return entryPoint;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
HnswMeta<?> hnswMeta = (HnswMeta<?>) o;
|
||||
return maxLevel == hnswMeta.maxLevel
|
||||
&& Objects.equals(entryPoint, hnswMeta.entryPoint);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(maxLevel, entryPoint);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "HnswMeta{maxLevel=" + maxLevel + ", entryPoint=" + entryPoint + '}';
|
||||
}
|
||||
}
|
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswNode.docx
Normal file
BIN
ann/src/main/java/com/twitter/ann/hnsw/HnswNode.docx
Normal file
Binary file not shown.
@ -1,45 +0,0 @@
|
||||
package com.twitter.ann.hnsw;
|
||||
|
||||
import org.apache.commons.lang.builder.EqualsBuilder;
|
||||
import org.apache.commons.lang.builder.HashCodeBuilder;
|
||||
|
||||
public class HnswNode<T> {
|
||||
public final int level;
|
||||
public final T item;
|
||||
|
||||
public HnswNode(int level, T item) {
|
||||
this.level = level;
|
||||
this.item = item;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a hnsw node.
|
||||
*/
|
||||
public static <T> HnswNode<T> from(int level, T item) {
|
||||
return new HnswNode<>(level, item);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (o == this) {
|
||||
return true;
|
||||
}
|
||||
if (!(o instanceof HnswNode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
HnswNode<?> that = (HnswNode<?>) o;
|
||||
return new EqualsBuilder()
|
||||
.append(this.item, that.item)
|
||||
.append(this.level, that.level)
|
||||
.isEquals();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return new HashCodeBuilder()
|
||||
.append(item)
|
||||
.append(level)
|
||||
.toHashCode();
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,7 +0,0 @@
|
||||
package com.twitter.ann.hnsw;
|
||||
|
||||
public class IllegalDuplicateInsertException extends Exception {
|
||||
public IllegalDuplicateInsertException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
resources(
|
||||
name = "sql",
|
||||
sources = ["bq.sql"],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "faiss_indexing",
|
||||
sources = ["**/*.py"],
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
":sql",
|
||||
"3rdparty/python/apache-beam:default",
|
||||
"3rdparty/python/faiss-gpu:default",
|
||||
"3rdparty/python/gcsfs:default",
|
||||
"3rdparty/python/google-cloud-bigquery:default",
|
||||
"3rdparty/python/google-cloud-storage",
|
||||
"3rdparty/python/numpy:default",
|
||||
"3rdparty/python/pandas:default",
|
||||
"3rdparty/python/pandas-gbq:default",
|
||||
"3rdparty/python/pyarrow:default",
|
||||
"src/python/twitter/ml/common/apache_beam",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "faiss_indexing_bin",
|
||||
sources = ["faiss_index_bq_dataset.py"],
|
||||
platforms = [
|
||||
"current",
|
||||
"linux_x86_64",
|
||||
],
|
||||
tags = ["no-mypy"],
|
||||
zip_safe = False,
|
||||
dependencies = [
|
||||
":faiss_indexing",
|
||||
"3rdparty/python/_closures/ann/src/main/python/dataflow:faiss_indexing_bin",
|
||||
],
|
||||
)
|
BIN
ann/src/main/python/dataflow/BUILD.docx
Normal file
BIN
ann/src/main/python/dataflow/BUILD.docx
Normal file
Binary file not shown.
BIN
ann/src/main/python/dataflow/bq.docx
Normal file
BIN
ann/src/main/python/dataflow/bq.docx
Normal file
Binary file not shown.
@ -1,6 +0,0 @@
|
||||
WITH maxts as (SELECT as value MAX(ts) as ts FROM `twttr-recos-ml-prod.ssedhain.twhin_tweet_avg_embedding`)
|
||||
SELECT entityId, embedding
|
||||
FROM `twttr-recos-ml-prod.ssedhain.twhin_tweet_avg_embedding`
|
||||
WHERE ts >= (select max(maxts) from maxts)
|
||||
AND DATE(TIMESTAMP_MILLIS(createdAt)) <= (select max(maxts) from maxts)
|
||||
AND DATE(TIMESTAMP_MILLIS(createdAt)) >= DATE_SUB((select max(maxts) from maxts), INTERVAL 1 DAY)
|
BIN
ann/src/main/python/dataflow/faiss_index_bq_dataset.docx
Normal file
BIN
ann/src/main/python/dataflow/faiss_index_bq_dataset.docx
Normal file
Binary file not shown.
@ -1,232 +0,0 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import apache_beam as beam
|
||||
from apache_beam.options.pipeline_options import PipelineOptions
|
||||
import faiss
|
||||
|
||||
|
||||
def parse_d6w_config(argv=None):
|
||||
"""Parse d6w config.
|
||||
:param argv: d6w config
|
||||
:return: dictionary containing d6w config
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="See https://docbird.twitter.biz/d6w/model.html for any parameters inherited from d6w job config"
|
||||
)
|
||||
parser.add_argument("--job_name", dest="job_name", required=True, help="d6w attribute")
|
||||
parser.add_argument("--project", dest="project", required=True, help="d6w attribute")
|
||||
parser.add_argument(
|
||||
"--staging_location", dest="staging_location", required=True, help="d6w attribute"
|
||||
)
|
||||
parser.add_argument("--temp_location", dest="temp_location", required=True, help="d6w attribute")
|
||||
parser.add_argument(
|
||||
"--output_location",
|
||||
dest="output_location",
|
||||
required=True,
|
||||
help="GCS bucket and path where resulting artifacts are uploaded",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--service_account_email", dest="service_account_email", required=True, help="d6w attribute"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--factory_string",
|
||||
dest="factory_string",
|
||||
required=False,
|
||||
help="FAISS factory string describing index to build. See https://github.com/facebookresearch/faiss/wiki/The-index-factory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metric",
|
||||
dest="metric",
|
||||
required=True,
|
||||
help="Metric used to compute distance between embeddings. Valid values are 'l2', 'ip', 'l1', 'linf'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gpu",
|
||||
dest="gpu",
|
||||
required=True,
|
||||
help="--use_gpu=yes if you want to use GPU during index building",
|
||||
)
|
||||
|
||||
known_args, unknown_args = parser.parse_known_args(argv)
|
||||
d6w_config = vars(known_args)
|
||||
d6w_config["gpu"] = d6w_config["gpu"].lower() == "yes"
|
||||
d6w_config["metric"] = parse_metric(d6w_config)
|
||||
|
||||
"""
|
||||
WARNING: Currently, d6w (a Twitter tool used to deploy Dataflow jobs to GCP) and
|
||||
PipelineOptions.for_dataflow_runner (a helper method in twitter.ml.common.apache_beam) do not
|
||||
play nicely together. The helper method will overwrite some of the config specified in the d6w
|
||||
file using the defaults in https://sourcegraph.twitter.biz/git.twitter.biz/source/-/blob/src/python/twitter/ml/common/apache_beam/__init__.py?L24.'
|
||||
However, the d6w output message will still report that the config specified in the d6w file was used.
|
||||
"""
|
||||
logging.warning(
|
||||
f"The following d6w config parameters will be overwritten by the defaults in "
|
||||
f"https://sourcegraph.twitter.biz/git.twitter.biz/source/-/blob/src/python/twitter/ml/common/apache_beam/__init__.py?L24\n"
|
||||
f"{str(unknown_args)}"
|
||||
)
|
||||
return d6w_config
|
||||
|
||||
|
||||
def get_bq_query():
|
||||
"""
|
||||
Query is expected to return rows with unique entityId
|
||||
"""
|
||||
return pkgutil.get_data(__name__, "bq.sql").decode("utf-8")
|
||||
|
||||
|
||||
def parse_metric(config):
|
||||
metric_str = config["metric"].lower()
|
||||
if metric_str == "l2":
|
||||
return faiss.METRIC_L2
|
||||
elif metric_str == "ip":
|
||||
return faiss.METRIC_INNER_PRODUCT
|
||||
elif metric_str == "l1":
|
||||
return faiss.METRIC_L1
|
||||
elif metric_str == "linf":
|
||||
return faiss.METRIC_Linf
|
||||
else:
|
||||
raise Exception(f"Unknown metric: {metric_str}")
|
||||
|
||||
|
||||
def run_pipeline(argv=[]):
|
||||
config = parse_d6w_config(argv)
|
||||
argv_with_extras = argv
|
||||
if config["gpu"]:
|
||||
argv_with_extras.extend(["--experiments", "use_runner_v2"])
|
||||
argv_with_extras.extend(
|
||||
["--experiments", "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver"]
|
||||
)
|
||||
argv_with_extras.extend(
|
||||
[
|
||||
"--worker_harness_container_image",
|
||||
"gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7",
|
||||
]
|
||||
)
|
||||
|
||||
options = PipelineOptions(argv_with_extras)
|
||||
output_bucket_name = urlsplit(config["output_location"]).netloc
|
||||
|
||||
with beam.Pipeline(options=options) as p:
|
||||
input_data = p | "Read from BigQuery" >> beam.io.ReadFromBigQuery(
|
||||
method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
|
||||
query=get_bq_query(),
|
||||
use_standard_sql=True,
|
||||
)
|
||||
|
||||
index_built = input_data | "Build and upload index" >> beam.CombineGlobally(
|
||||
MergeAndBuildIndex(
|
||||
output_bucket_name,
|
||||
config["output_location"],
|
||||
config["factory_string"],
|
||||
config["metric"],
|
||||
config["gpu"],
|
||||
)
|
||||
)
|
||||
|
||||
# Make linter happy
|
||||
index_built
|
||||
|
||||
|
||||
class MergeAndBuildIndex(beam.CombineFn):
|
||||
def __init__(self, bucket_name, gcs_output_path, factory_string, metric, gpu):
|
||||
self.bucket_name = bucket_name
|
||||
self.gcs_output_path = gcs_output_path
|
||||
self.factory_string = factory_string
|
||||
self.metric = metric
|
||||
self.gpu = gpu
|
||||
|
||||
def create_accumulator(self):
|
||||
return []
|
||||
|
||||
def add_input(self, accumulator, element):
|
||||
accumulator.append(element)
|
||||
return accumulator
|
||||
|
||||
def merge_accumulators(self, accumulators):
|
||||
merged = []
|
||||
for accum in accumulators:
|
||||
merged.extend(accum)
|
||||
return merged
|
||||
|
||||
def extract_output(self, rows):
|
||||
# Reimports are needed on workers
|
||||
import glob
|
||||
import subprocess
|
||||
|
||||
import faiss
|
||||
from google.cloud import storage
|
||||
import numpy as np
|
||||
|
||||
client = storage.Client()
|
||||
bucket = client.get_bucket(self.bucket_name)
|
||||
|
||||
logging.info("Building FAISS index")
|
||||
logging.info(f"There are {len(rows)} rows")
|
||||
|
||||
ids = np.array([x["entityId"] for x in rows]).astype("long")
|
||||
embeds = np.array([x["embedding"] for x in rows]).astype("float32")
|
||||
dimensions = len(embeds[0])
|
||||
N = ids.shape[0]
|
||||
logging.info(f"There are {dimensions} dimensions")
|
||||
|
||||
if self.factory_string is None:
|
||||
M = 48
|
||||
|
||||
divideable_dimensions = (dimensions // M) * M
|
||||
if divideable_dimensions != dimensions:
|
||||
opq_prefix = f"OPQ{M}_{divideable_dimensions}"
|
||||
else:
|
||||
opq_prefix = f"OPQ{M}"
|
||||
|
||||
clusters = N // 20
|
||||
self.factory_string = f"{opq_prefix},IVF{clusters},PQ{M}"
|
||||
|
||||
logging.info(f"Factory string is {self.factory_string}, metric={self.metric}")
|
||||
|
||||
if self.gpu:
|
||||
logging.info("Using GPU")
|
||||
|
||||
res = faiss.StandardGpuResources()
|
||||
cpu_index = faiss.index_factory(dimensions, self.factory_string, self.metric)
|
||||
cpu_index = faiss.IndexIDMap(cpu_index)
|
||||
gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
|
||||
gpu_index.train(embeds)
|
||||
gpu_index.add_with_ids(embeds, ids)
|
||||
cpu_index = faiss.index_gpu_to_cpu(gpu_index)
|
||||
else:
|
||||
logging.info("Using CPU")
|
||||
|
||||
cpu_index = faiss.index_factory(dimensions, self.factory_string, self.metric)
|
||||
cpu_index = faiss.IndexIDMap(cpu_index)
|
||||
cpu_index.train(embeds)
|
||||
cpu_index.add_with_ids(embeds, ids)
|
||||
|
||||
logging.info("Built faiss index")
|
||||
|
||||
local_path = "/indices"
|
||||
logging.info(f"Writing indices to local {local_path}")
|
||||
subprocess.run(f"mkdir -p {local_path}".strip().split())
|
||||
local_index_path = os.path.join(local_path, "result.index")
|
||||
|
||||
faiss.write_index(cpu_index, local_index_path)
|
||||
logging.info(f"Done writing indices to local {local_path}")
|
||||
|
||||
logging.info(f"Uploading to GCS with path {self.gcs_output_path}")
|
||||
assert os.path.isdir(local_path)
|
||||
for local_file in glob.glob(local_path + "/*"):
|
||||
remote_path = os.path.join(
|
||||
self.gcs_output_path.split("/")[-1], local_file[1 + len(local_path) :]
|
||||
)
|
||||
blob = bucket.blob(remote_path)
|
||||
blob.upload_from_filename(local_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
run_pipeline(sys.argv)
|
@ -1,34 +0,0 @@
|
||||
FROM --platform=linux/amd64 nvidia/cuda:11.2.2-cudnn8-runtime-ubuntu20.04
|
||||
|
||||
RUN \
|
||||
# Add Deadsnakes repository that has a variety of Python packages for Ubuntu.
|
||||
# See: https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa
|
||||
apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F23C5A6CF475977595C89F51BA6932366A755776 \
|
||||
&& echo "deb http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal main" >> /etc/apt/sources.list.d/custom.list \
|
||||
&& echo "deb-src http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal main" >> /etc/apt/sources.list.d/custom.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y curl \
|
||||
python3.7 \
|
||||
# With python3.8 package, distutils need to be installed separately.
|
||||
python3.7-distutils \
|
||||
python3-dev \
|
||||
python3.7-dev \
|
||||
libpython3.7-dev \
|
||||
python3-apt \
|
||||
gcc \
|
||||
g++ \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.7 10
|
||||
RUN rm -f /usr/bin/python3 && ln -s /usr/bin/python3.7 /usr/bin/python3
|
||||
RUN \
|
||||
curl https://bootstrap.pypa.io/get-pip.py | python \
|
||||
&& pip3 install pip==22.0.3 \
|
||||
&& python3 -m pip install --no-cache-dir apache-beam[gcp]==2.39.0
|
||||
# Verify that there are no conflicting dependencies.
|
||||
RUN pip3 check
|
||||
|
||||
# Copy the Apache Beam worker dependencies from the Beam Python 3.7 SDK image.
|
||||
COPY --from=apache/beam_python3.7_sdk:2.39.0 /opt/apache/beam /opt/apache/beam
|
||||
|
||||
# Set the entrypoint to Apache Beam SDK worker launcher.
|
||||
ENTRYPOINT [ "/opt/apache/beam/boot" ]
|
BIN
ann/src/main/python/dataflow/worker_harness/Dockerfile.docx
Normal file
BIN
ann/src/main/python/dataflow/worker_harness/Dockerfile.docx
Normal file
Binary file not shown.
BIN
ann/src/main/python/dataflow/worker_harness/cloudbuild.docx
Normal file
BIN
ann/src/main/python/dataflow/worker_harness/cloudbuild.docx
Normal file
Binary file not shown.
@ -1,6 +0,0 @@
|
||||
steps:
|
||||
- name: 'gcr.io/cloud-builders/docker'
|
||||
args: ['build', '-t', 'gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7', '.']
|
||||
- name: 'gcr.io/cloud-builders/docker'
|
||||
args: ['push', 'gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7']
|
||||
images: ['gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7']
|
BIN
ann/src/main/scala/com/twitter/ann/annoy/AnnoyCommon.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/annoy/AnnoyCommon.docx
Normal file
Binary file not shown.
@ -1,44 +0,0 @@
|
||||
package com.twitter.ann.annoy
|
||||
|
||||
import com.twitter.ann.common.RuntimeParams
|
||||
import com.twitter.ann.common.thriftscala.AnnoyIndexMetadata
|
||||
import com.twitter.bijection.Injection
|
||||
import com.twitter.mediaservices.commons.codec.ThriftByteBufferCodec
|
||||
import com.twitter.ann.common.thriftscala.{AnnoyRuntimeParam, RuntimeParams => ServiceRuntimeParams}
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
object AnnoyCommon {
|
||||
private[annoy] lazy val MetadataCodec = new ThriftByteBufferCodec(AnnoyIndexMetadata)
|
||||
private[annoy] val IndexFileName = "annoy_index"
|
||||
private[annoy] val MetaDataFileName = "annoy_index_metadata"
|
||||
private[annoy] val IndexIdMappingFileName = "annoy_index_id_mapping"
|
||||
|
||||
val RuntimeParamsInjection: Injection[AnnoyRuntimeParams, ServiceRuntimeParams] =
|
||||
new Injection[AnnoyRuntimeParams, ServiceRuntimeParams] {
|
||||
override def apply(scalaParams: AnnoyRuntimeParams): ServiceRuntimeParams = {
|
||||
ServiceRuntimeParams.AnnoyParam(
|
||||
AnnoyRuntimeParam(
|
||||
scalaParams.nodesToExplore
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
override def invert(thriftParams: ServiceRuntimeParams): Try[AnnoyRuntimeParams] =
|
||||
thriftParams match {
|
||||
case ServiceRuntimeParams.AnnoyParam(annoyParam) =>
|
||||
Success(
|
||||
AnnoyRuntimeParams(annoyParam.numOfNodesToExplore)
|
||||
)
|
||||
case p => Failure(new IllegalArgumentException(s"Expected AnnoyRuntimeParams got $p"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class AnnoyRuntimeParams(
|
||||
/* Number of vectors to evaluate while searching. A larger value will give more accurate results, but will take longer time to return.
|
||||
* Default value would be numberOfTrees*numberOfNeigboursRequested
|
||||
*/
|
||||
nodesToExplore: Option[Int])
|
||||
extends RuntimeParams {
|
||||
override def toString: String = s"AnnoyRuntimeParams( nodesToExplore = $nodesToExplore)"
|
||||
}
|
@ -1,23 +0,0 @@
|
||||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
platform = "java8",
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/com/spotify:annoy-java",
|
||||
"3rdparty/jvm/com/spotify:annoy-snapshot",
|
||||
"3rdparty/jvm/com/twitter/storehaus:core",
|
||||
"ann/src/main/scala/com/twitter/ann/common",
|
||||
"ann/src/main/scala/com/twitter/ann/file_store",
|
||||
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
|
||||
"mediaservices/commons",
|
||||
"src/java/com/twitter/search/common/file",
|
||||
"src/scala/com/twitter/ml/api/embedding",
|
||||
],
|
||||
exports = [
|
||||
"ann/src/main/scala/com/twitter/ann/common",
|
||||
"src/java/com/twitter/common_internal/hadoop",
|
||||
"src/java/com/twitter/search/common/file",
|
||||
"src/scala/com/twitter/ml/api/embedding",
|
||||
],
|
||||
)
|
BIN
ann/src/main/scala/com/twitter/ann/annoy/BUILD.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/annoy/BUILD.docx
Normal file
Binary file not shown.
Binary file not shown.
@ -1,123 +0,0 @@
|
||||
package com.twitter.ann.annoy
|
||||
|
||||
import com.spotify.annoy.jni.base.{Annoy => AnnoyLib}
|
||||
import com.twitter.ann.annoy.AnnoyCommon.IndexFileName
|
||||
import com.twitter.ann.annoy.AnnoyCommon.MetaDataFileName
|
||||
import com.twitter.ann.annoy.AnnoyCommon.MetadataCodec
|
||||
import com.twitter.ann.common.EmbeddingType._
|
||||
import com.twitter.ann.common._
|
||||
import com.twitter.ann.common.thriftscala.AnnoyIndexMetadata
|
||||
import com.twitter.concurrent.AsyncSemaphore
|
||||
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
|
||||
import com.twitter.search.common.file.AbstractFile
|
||||
import com.twitter.search.common.file.LocalFile
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.FuturePool
|
||||
import java.io.File
|
||||
import java.nio.file.Files
|
||||
import org.apache.beam.sdk.io.fs.ResourceId
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
private[annoy] object RawAnnoyIndexBuilder {
|
||||
private[annoy] def apply[D <: Distance[D]](
|
||||
dimension: Int,
|
||||
numOfTrees: Int,
|
||||
metric: Metric[D],
|
||||
futurePool: FuturePool
|
||||
): RawAppendable[AnnoyRuntimeParams, D] with Serialization = {
|
||||
val indexBuilder = AnnoyLib.newIndex(dimension, annoyMetric(metric))
|
||||
new RawAnnoyIndexBuilder(dimension, numOfTrees, metric, indexBuilder, futurePool)
|
||||
}
|
||||
|
||||
private[this] def annoyMetric(metric: Metric[_]): AnnoyLib.Metric = {
|
||||
metric match {
|
||||
case L2 => AnnoyLib.Metric.EUCLIDEAN
|
||||
case Cosine => AnnoyLib.Metric.ANGULAR
|
||||
case _ => throw new RuntimeException("Not supported: " + metric)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[this] class RawAnnoyIndexBuilder[D <: Distance[D]](
|
||||
dimension: Int,
|
||||
numOfTrees: Int,
|
||||
metric: Metric[D],
|
||||
indexBuilder: AnnoyLib.Builder,
|
||||
futurePool: FuturePool)
|
||||
extends RawAppendable[AnnoyRuntimeParams, D]
|
||||
with Serialization {
|
||||
private[this] var counter = 0
|
||||
// Note: Only one thread can access the underlying index, multithreaded index building not supported
|
||||
private[this] val semaphore = new AsyncSemaphore(1)
|
||||
|
||||
override def append(embedding: EmbeddingVector): Future[Long] =
|
||||
semaphore.acquireAndRun({
|
||||
counter += 1
|
||||
indexBuilder.addItem(
|
||||
counter,
|
||||
embedding.toArray
|
||||
.map(float => float2Float(float))
|
||||
.toList
|
||||
.asJava
|
||||
)
|
||||
|
||||
Future.value(counter)
|
||||
})
|
||||
|
||||
override def toQueryable: Queryable[Long, AnnoyRuntimeParams, D] = {
|
||||
val tempDirParent = Files.createTempDirectory("raw_annoy_index").toFile
|
||||
tempDirParent.deleteOnExit
|
||||
val tempDir = new LocalFile(tempDirParent)
|
||||
this.toDirectory(tempDir)
|
||||
RawAnnoyQueryIndex(
|
||||
dimension,
|
||||
metric,
|
||||
futurePool,
|
||||
tempDir
|
||||
)
|
||||
}
|
||||
|
||||
override def toDirectory(directory: ResourceId): Unit = {
|
||||
toDirectory(new IndexOutputFile(directory))
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the annoy index in a directory.
|
||||
* @param directory: Directory to save to.
|
||||
*/
|
||||
override def toDirectory(directory: AbstractFile): Unit = {
|
||||
toDirectory(new IndexOutputFile(directory))
|
||||
}
|
||||
|
||||
private def toDirectory(directory: IndexOutputFile): Unit = {
|
||||
val indexFile = directory.createFile(IndexFileName)
|
||||
saveIndex(indexFile)
|
||||
|
||||
val metaDataFile = directory.createFile(MetaDataFileName)
|
||||
saveMetadata(metaDataFile)
|
||||
}
|
||||
|
||||
private[this] def saveIndex(indexFile: IndexOutputFile): Unit = {
|
||||
val index = indexBuilder
|
||||
.build(numOfTrees)
|
||||
val temp = new LocalFile(File.createTempFile(IndexFileName, null))
|
||||
index.save(temp.getPath)
|
||||
indexFile.copyFrom(temp.getByteSource.openStream())
|
||||
temp.delete()
|
||||
}
|
||||
|
||||
private[this] def saveMetadata(metadataFile: IndexOutputFile): Unit = {
|
||||
val numberOfVectorsIndexed = counter
|
||||
val metadata = AnnoyIndexMetadata(
|
||||
dimension,
|
||||
Metric.toThrift(metric),
|
||||
numOfTrees,
|
||||
numberOfVectorsIndexed
|
||||
)
|
||||
val bytes = ArrayByteBufferCodec.decode(MetadataCodec.encode(metadata))
|
||||
val temp = new LocalFile(File.createTempFile(MetaDataFileName, null))
|
||||
temp.getByteSink.write(bytes)
|
||||
metadataFile.copyFrom(temp.getByteSource.openStream())
|
||||
temp.delete()
|
||||
}
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/annoy/RawAnnoyQueryIndex.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/annoy/RawAnnoyQueryIndex.docx
Normal file
Binary file not shown.
@ -1,142 +0,0 @@
|
||||
package com.twitter.ann.annoy
|
||||
|
||||
import com.spotify.annoy.{ANNIndex, IndexType}
|
||||
import com.twitter.ann.annoy.AnnoyCommon._
|
||||
import com.twitter.ann.common._
|
||||
import com.twitter.ann.common.EmbeddingType._
|
||||
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
|
||||
import com.twitter.search.common.file.{AbstractFile, LocalFile}
|
||||
import com.twitter.util.{Future, FuturePool}
|
||||
import java.io.File
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
private[annoy] object RawAnnoyQueryIndex {
|
||||
private[annoy] def apply[D <: Distance[D]](
|
||||
dimension: Int,
|
||||
metric: Metric[D],
|
||||
futurePool: FuturePool,
|
||||
directory: AbstractFile
|
||||
): Queryable[Long, AnnoyRuntimeParams, D] = {
|
||||
val metadataFile = directory.getChild(MetaDataFileName)
|
||||
val indexFile = directory.getChild(IndexFileName)
|
||||
val metadata = MetadataCodec.decode(
|
||||
ArrayByteBufferCodec.encode(metadataFile.getByteSource.read())
|
||||
)
|
||||
|
||||
val existingDimension = metadata.dimension
|
||||
assert(
|
||||
existingDimension == dimension,
|
||||
s"Dimensions do not match. requested: $dimension existing: $existingDimension"
|
||||
)
|
||||
|
||||
val existingMetric = Metric.fromThrift(metadata.distanceMetric)
|
||||
assert(
|
||||
existingMetric == metric,
|
||||
s"DistanceMetric do not match. requested: $metric existing: $existingMetric"
|
||||
)
|
||||
|
||||
val index = loadIndex(indexFile, dimension, annoyMetric(metric))
|
||||
new RawAnnoyQueryIndex[D](
|
||||
dimension,
|
||||
metric,
|
||||
metadata.numOfTrees,
|
||||
index,
|
||||
futurePool
|
||||
)
|
||||
}
|
||||
|
||||
private[this] def annoyMetric(metric: Metric[_]): IndexType = {
|
||||
metric match {
|
||||
case L2 => IndexType.EUCLIDEAN
|
||||
case Cosine => IndexType.ANGULAR
|
||||
case _ => throw new RuntimeException("Not supported: " + metric)
|
||||
}
|
||||
}
|
||||
|
||||
private[this] def loadIndex(
|
||||
indexFile: AbstractFile,
|
||||
dimension: Int,
|
||||
indexType: IndexType
|
||||
): ANNIndex = {
|
||||
var localIndexFile = indexFile
|
||||
|
||||
// If not a local file copy to local, so that it can be memory mapped.
|
||||
if (!indexFile.isInstanceOf[LocalFile]) {
|
||||
val tempFile = File.createTempFile(IndexFileName, null)
|
||||
tempFile.deleteOnExit()
|
||||
|
||||
val temp = new LocalFile(tempFile)
|
||||
indexFile.copyTo(temp)
|
||||
localIndexFile = temp
|
||||
}
|
||||
|
||||
new ANNIndex(
|
||||
dimension,
|
||||
localIndexFile.getPath(),
|
||||
indexType
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private[this] class RawAnnoyQueryIndex[D <: Distance[D]](
|
||||
dimension: Int,
|
||||
metric: Metric[D],
|
||||
numOfTrees: Int,
|
||||
index: ANNIndex,
|
||||
futurePool: FuturePool)
|
||||
extends Queryable[Long, AnnoyRuntimeParams, D]
|
||||
with AutoCloseable {
|
||||
override def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbours: Int,
|
||||
runtimeParams: AnnoyRuntimeParams
|
||||
): Future[List[Long]] = {
|
||||
queryWithDistance(embedding, numOfNeighbours, runtimeParams)
|
||||
.map(_.map(_.neighbor))
|
||||
}
|
||||
|
||||
override def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbours: Int,
|
||||
runtimeParams: AnnoyRuntimeParams
|
||||
): Future[List[NeighborWithDistance[Long, D]]] = {
|
||||
futurePool {
|
||||
val queryVector = embedding.toArray
|
||||
val neigboursToRequest = neighboursToRequest(numOfNeighbours, runtimeParams)
|
||||
val neigbours = index
|
||||
.getNearestWithDistance(queryVector, neigboursToRequest)
|
||||
.asScala
|
||||
.take(numOfNeighbours)
|
||||
.map { nn =>
|
||||
val id = nn.getFirst.toLong
|
||||
val distance = metric.fromAbsoluteDistance(nn.getSecond)
|
||||
NeighborWithDistance(id, distance)
|
||||
}
|
||||
.toList
|
||||
|
||||
neigbours
|
||||
}
|
||||
}
|
||||
|
||||
// Annoy java lib do not expose param for numOfNodesToExplore.
|
||||
// Default number is numOfTrees*numOfNeigbours.
|
||||
// Simple hack is to artificially increase the numOfNeighbours to be requested and then just cap it before returning.
|
||||
private[this] def neighboursToRequest(
|
||||
numOfNeighbours: Int,
|
||||
annoyParams: AnnoyRuntimeParams
|
||||
): Int = {
|
||||
annoyParams.nodesToExplore match {
|
||||
case Some(nodesToExplore) => {
|
||||
val neigboursToRequest = nodesToExplore / numOfTrees
|
||||
if (neigboursToRequest < numOfNeighbours)
|
||||
numOfNeighbours
|
||||
else
|
||||
neigboursToRequest
|
||||
}
|
||||
case _ => numOfNeighbours
|
||||
}
|
||||
}
|
||||
|
||||
// To close the memory map based file resource.
|
||||
override def close(): Unit = index.close()
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/annoy/TypedAnnoyIndex.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/annoy/TypedAnnoyIndex.docx
Normal file
Binary file not shown.
@ -1,55 +0,0 @@
|
||||
package com.twitter.ann.annoy
|
||||
|
||||
import com.twitter.ann.common._
|
||||
import com.twitter.bijection.Injection
|
||||
import com.twitter.search.common.file.AbstractFile
|
||||
import com.twitter.util.FuturePool
|
||||
|
||||
// Class to provide Annoy based ann index.
|
||||
object TypedAnnoyIndex {
|
||||
|
||||
/**
|
||||
* Create Annoy based typed index builder that serializes index to a directory (HDFS/Local file system).
|
||||
* It cannot be used in scalding as it leverage C/C++ jni bindings, whose build conflicts with version of some libs installed on hadoop.
|
||||
* You can use it on aurora or with IndexBuilding job which triggers scalding job but then streams data to aurora machine for building index.
|
||||
* @param dimension dimension of embedding
|
||||
* @param numOfTrees builds a forest of numOfTrees trees.
|
||||
* More trees gives higher precision when querying at the cost of increased memory and disk storage requirement at the build time.
|
||||
* At runtime the index will be memory mapped, so memory wont be an issue but disk storage would be needed.
|
||||
* @param metric distance metric for nearest neighbour search
|
||||
* @param injection Injection to convert bytes to Id.
|
||||
* @tparam T Type of Id for embedding
|
||||
* @tparam D Typed Distance
|
||||
* @return Serializable AnnoyIndex
|
||||
*/
|
||||
def indexBuilder[T, D <: Distance[D]](
|
||||
dimension: Int,
|
||||
numOfTrees: Int,
|
||||
metric: Metric[D],
|
||||
injection: Injection[T, Array[Byte]],
|
||||
futurePool: FuturePool
|
||||
): Appendable[T, AnnoyRuntimeParams, D] with Serialization = {
|
||||
TypedAnnoyIndexBuilderWithFile(dimension, numOfTrees, metric, injection, futurePool)
|
||||
}
|
||||
|
||||
/**
|
||||
* Load Annoy based queryable index from a directory
|
||||
* @param dimension dimension of embedding
|
||||
* @param metric distance metric for nearest neighbour search
|
||||
* @param injection Injection to convert bytes to Id.
|
||||
* @param futurePool FuturePool
|
||||
* @param directory Directory (HDFS/Local file system) where serialized index is stored.
|
||||
* @tparam T Type of Id for embedding
|
||||
* @tparam D Typed Distance
|
||||
* @return Typed Queryable AnnoyIndex
|
||||
*/
|
||||
def loadQueryableIndex[T, D <: Distance[D]](
|
||||
dimension: Int,
|
||||
metric: Metric[D],
|
||||
injection: Injection[T, Array[Byte]],
|
||||
futurePool: FuturePool,
|
||||
directory: AbstractFile
|
||||
): Queryable[T, AnnoyRuntimeParams, D] = {
|
||||
TypedAnnoyQueryIndexWithFile(dimension, metric, injection, futurePool, directory)
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,55 +0,0 @@
|
||||
package com.twitter.ann.annoy
|
||||
|
||||
import com.twitter.ann.annoy.AnnoyCommon.IndexIdMappingFileName
|
||||
import com.twitter.ann.common._
|
||||
import com.twitter.ann.file_store.WritableIndexIdFileStore
|
||||
import com.twitter.bijection.Injection
|
||||
import com.twitter.search.common.file.AbstractFile
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.FuturePool
|
||||
import org.apache.beam.sdk.io.fs.ResourceId
|
||||
|
||||
private[annoy] object TypedAnnoyIndexBuilderWithFile {
|
||||
private[annoy] def apply[T, D <: Distance[D]](
|
||||
dimension: Int,
|
||||
numOfTrees: Int,
|
||||
metric: Metric[D],
|
||||
injection: Injection[T, Array[Byte]],
|
||||
futurePool: FuturePool
|
||||
): Appendable[T, AnnoyRuntimeParams, D] with Serialization = {
|
||||
val index = RawAnnoyIndexBuilder(dimension, numOfTrees, metric, futurePool)
|
||||
val writableFileStore = WritableIndexIdFileStore(injection)
|
||||
new TypedAnnoyIndexBuilderWithFile[T, D](index, writableFileStore)
|
||||
}
|
||||
}
|
||||
|
||||
private[this] class TypedAnnoyIndexBuilderWithFile[T, D <: Distance[D]](
|
||||
indexBuilder: RawAppendable[AnnoyRuntimeParams, D] with Serialization,
|
||||
store: WritableIndexIdFileStore[T])
|
||||
extends Appendable[T, AnnoyRuntimeParams, D]
|
||||
with Serialization {
|
||||
private[this] val transformedIndex = IndexTransformer.transformAppendable(indexBuilder, store)
|
||||
|
||||
override def append(entity: EntityEmbedding[T]): Future[Unit] = {
|
||||
transformedIndex.append(entity)
|
||||
}
|
||||
|
||||
override def toDirectory(directory: ResourceId): Unit = {
|
||||
indexBuilder.toDirectory(directory)
|
||||
toDirectory(new IndexOutputFile(directory))
|
||||
}
|
||||
|
||||
override def toDirectory(directory: AbstractFile): Unit = {
|
||||
indexBuilder.toDirectory(directory)
|
||||
toDirectory(new IndexOutputFile(directory))
|
||||
}
|
||||
|
||||
private def toDirectory(directory: IndexOutputFile): Unit = {
|
||||
val indexIdFile = directory.createFile(IndexIdMappingFileName)
|
||||
store.save(indexIdFile)
|
||||
}
|
||||
|
||||
override def toQueryable: Queryable[T, AnnoyRuntimeParams, D] = {
|
||||
transformedIndex.toQueryable
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,42 +0,0 @@
|
||||
package com.twitter.ann.annoy
|
||||
|
||||
import com.twitter.ann.annoy.AnnoyCommon._
|
||||
import com.twitter.ann.common._
|
||||
import com.twitter.ann.file_store.ReadableIndexIdFileStore
|
||||
import com.twitter.bijection.Injection
|
||||
import com.twitter.search.common.file.AbstractFile
|
||||
import com.twitter.util.FuturePool
|
||||
|
||||
private[annoy] object TypedAnnoyQueryIndexWithFile {
|
||||
private[annoy] def apply[T, D <: Distance[D]](
|
||||
dimension: Int,
|
||||
metric: Metric[D],
|
||||
injection: Injection[T, Array[Byte]],
|
||||
futurePool: FuturePool,
|
||||
directory: AbstractFile
|
||||
): Queryable[T, AnnoyRuntimeParams, D] = {
|
||||
val deserializer =
|
||||
new TypedAnnoyQueryIndexWithFile(dimension, metric, futurePool, injection)
|
||||
deserializer.fromDirectory(directory)
|
||||
}
|
||||
}
|
||||
|
||||
private[this] class TypedAnnoyQueryIndexWithFile[T, D <: Distance[D]](
|
||||
dimension: Int,
|
||||
metric: Metric[D],
|
||||
futurePool: FuturePool,
|
||||
injection: Injection[T, Array[Byte]])
|
||||
extends QueryableDeserialization[
|
||||
T,
|
||||
AnnoyRuntimeParams,
|
||||
D,
|
||||
Queryable[T, AnnoyRuntimeParams, D]
|
||||
] {
|
||||
override def fromDirectory(directory: AbstractFile): Queryable[T, AnnoyRuntimeParams, D] = {
|
||||
val index = RawAnnoyQueryIndex(dimension, metric, futurePool, directory)
|
||||
|
||||
val indexIdFile = directory.getChild(IndexIdMappingFileName)
|
||||
val readableFileStore = ReadableIndexIdFileStore(indexIdFile, injection)
|
||||
IndexTransformer.transformQueryable(index, readableFileStore)
|
||||
}
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
platform = "java8",
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"ann/src/main/scala/com/twitter/ann/common",
|
||||
"ann/src/main/scala/com/twitter/ann/serialization",
|
||||
"ann/src/main/thrift/com/twitter/ann/serialization:serialization-scala",
|
||||
"src/java/com/twitter/search/common/file",
|
||||
],
|
||||
)
|
BIN
ann/src/main/scala/com/twitter/ann/brute_force/BUILD.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/brute_force/BUILD.docx
Normal file
Binary file not shown.
Binary file not shown.
@ -1,64 +0,0 @@
|
||||
package com.twitter.ann.brute_force
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting
|
||||
import com.twitter.ann.common.{Distance, EntityEmbedding, Metric, QueryableDeserialization}
|
||||
import com.twitter.ann.serialization.{PersistedEmbeddingInjection, ThriftIteratorIO}
|
||||
import com.twitter.ann.serialization.thriftscala.PersistedEmbedding
|
||||
import com.twitter.search.common.file.{AbstractFile, LocalFile}
|
||||
import com.twitter.util.FuturePool
|
||||
import java.io.File
|
||||
|
||||
/**
|
||||
* @param factory creates a BruteForceIndex from the arguments. This is only exposed for testing.
|
||||
* If for some reason you pass this arg in make sure that it eagerly consumes the
|
||||
* iterator. If you don't you might close the input stream that the iterator is
|
||||
* using.
|
||||
* @tparam T the id of the embeddings
|
||||
*/
|
||||
class BruteForceDeserialization[T, D <: Distance[D]] @VisibleForTesting private[brute_force] (
|
||||
metric: Metric[D],
|
||||
embeddingInjection: PersistedEmbeddingInjection[T],
|
||||
futurePool: FuturePool,
|
||||
thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding],
|
||||
factory: (Metric[D], FuturePool, Iterator[EntityEmbedding[T]]) => BruteForceIndex[T, D])
|
||||
extends QueryableDeserialization[T, BruteForceRuntimeParams.type, D, BruteForceIndex[T, D]] {
|
||||
import BruteForceIndex._
|
||||
|
||||
def this(
|
||||
metric: Metric[D],
|
||||
embeddingInjection: PersistedEmbeddingInjection[T],
|
||||
futurePool: FuturePool,
|
||||
thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding]
|
||||
) = {
|
||||
this(
|
||||
metric,
|
||||
embeddingInjection,
|
||||
futurePool,
|
||||
thriftIteratorIO,
|
||||
factory = BruteForceIndex.apply[T, D]
|
||||
)
|
||||
}
|
||||
|
||||
override def fromDirectory(
|
||||
serializationDirectory: AbstractFile
|
||||
): BruteForceIndex[T, D] = {
|
||||
val file = File.createTempFile(DataFileName, "tmp")
|
||||
file.deleteOnExit()
|
||||
val temp = new LocalFile(file)
|
||||
val dataFile = serializationDirectory.getChild(DataFileName)
|
||||
dataFile.copyTo(temp)
|
||||
val inputStream = temp.getByteSource.openBufferedStream()
|
||||
try {
|
||||
val iterator: Iterator[PersistedEmbedding] = thriftIteratorIO.fromInputStream(inputStream)
|
||||
|
||||
val embeddings = iterator.map { thriftEmbedding =>
|
||||
embeddingInjection.invert(thriftEmbedding).get
|
||||
}
|
||||
|
||||
factory(metric, futurePool, embeddings)
|
||||
} finally {
|
||||
inputStream.close()
|
||||
temp.delete()
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,162 +0,0 @@
|
||||
package com.twitter.ann.brute_force
|
||||
|
||||
import com.twitter.ann.common.Appendable
|
||||
import com.twitter.ann.common.Distance
|
||||
import com.twitter.ann.common.EmbeddingType._
|
||||
import com.twitter.ann.common.EntityEmbedding
|
||||
import com.twitter.ann.common.IndexOutputFile
|
||||
import com.twitter.ann.common.Metric
|
||||
import com.twitter.ann.common.NeighborWithDistance
|
||||
import com.twitter.ann.common.Queryable
|
||||
import com.twitter.ann.common.RuntimeParams
|
||||
import com.twitter.ann.common.Serialization
|
||||
import com.twitter.ann.serialization.PersistedEmbeddingInjection
|
||||
import com.twitter.ann.serialization.ThriftIteratorIO
|
||||
import com.twitter.ann.serialization.thriftscala.PersistedEmbedding
|
||||
import com.twitter.search.common.file.AbstractFile
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.FuturePool
|
||||
import java.util.concurrent.ConcurrentLinkedQueue
|
||||
import org.apache.beam.sdk.io.fs.ResourceId
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
object BruteForceRuntimeParams extends RuntimeParams
|
||||
|
||||
object BruteForceIndex {
|
||||
val DataFileName = "BruteForceFileData"
|
||||
|
||||
def apply[T, D <: Distance[D]](
|
||||
metric: Metric[D],
|
||||
futurePool: FuturePool,
|
||||
initialEmbeddings: Iterator[EntityEmbedding[T]] = Iterator()
|
||||
): BruteForceIndex[T, D] = {
|
||||
val linkedQueue = new ConcurrentLinkedQueue[EntityEmbedding[T]]
|
||||
initialEmbeddings.foreach(embedding => linkedQueue.add(embedding))
|
||||
new BruteForceIndex(metric, futurePool, linkedQueue)
|
||||
}
|
||||
}
|
||||
|
||||
class BruteForceIndex[T, D <: Distance[D]] private (
|
||||
metric: Metric[D],
|
||||
futurePool: FuturePool,
|
||||
// visible for serialization
|
||||
private[brute_force] val linkedQueue: ConcurrentLinkedQueue[EntityEmbedding[T]])
|
||||
extends Appendable[T, BruteForceRuntimeParams.type, D]
|
||||
with Queryable[T, BruteForceRuntimeParams.type, D] {
|
||||
|
||||
override def append(embedding: EntityEmbedding[T]): Future[Unit] = {
|
||||
futurePool {
|
||||
linkedQueue.add(embedding)
|
||||
}
|
||||
}
|
||||
|
||||
override def toQueryable: Queryable[T, BruteForceRuntimeParams.type, D] = this
|
||||
|
||||
override def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbours: Int,
|
||||
runtimeParams: BruteForceRuntimeParams.type
|
||||
): Future[List[T]] = {
|
||||
queryWithDistance(embedding, numOfNeighbours, runtimeParams).map { neighborsWithDistance =>
|
||||
neighborsWithDistance.map(_.neighbor)
|
||||
}
|
||||
}
|
||||
|
||||
override def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbours: Int,
|
||||
runtimeParams: BruteForceRuntimeParams.type
|
||||
): Future[List[NeighborWithDistance[T, D]]] = {
|
||||
futurePool {
|
||||
// Use the reverse ordering so that we can call dequeue to remove the largest element.
|
||||
val ordering = Ordering.by[NeighborWithDistance[T, D], D](_.distance)
|
||||
val priorityQueue =
|
||||
new mutable.PriorityQueue[NeighborWithDistance[T, D]]()(ordering)
|
||||
linkedQueue
|
||||
.iterator()
|
||||
.asScala
|
||||
.foreach { entity =>
|
||||
val neighborWithDistance =
|
||||
NeighborWithDistance(entity.id, metric.distance(entity.embedding, embedding))
|
||||
priorityQueue.+=(neighborWithDistance)
|
||||
if (priorityQueue.size > numOfNeighbours) {
|
||||
priorityQueue.dequeue()
|
||||
}
|
||||
}
|
||||
val reverseList: List[NeighborWithDistance[T, D]] =
|
||||
priorityQueue.dequeueAll
|
||||
reverseList.reverse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object SerializableBruteForceIndex {
|
||||
def apply[T, D <: Distance[D]](
|
||||
metric: Metric[D],
|
||||
futurePool: FuturePool,
|
||||
embeddingInjection: PersistedEmbeddingInjection[T],
|
||||
thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding]
|
||||
): SerializableBruteForceIndex[T, D] = {
|
||||
val bruteForceIndex = BruteForceIndex[T, D](metric, futurePool)
|
||||
|
||||
new SerializableBruteForceIndex(bruteForceIndex, embeddingInjection, thriftIteratorIO)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This is a class that wrapps a BruteForceIndex and provides a method for serialization.
|
||||
*
|
||||
* @param bruteForceIndex all queries and updates are sent to this index.
|
||||
* @param embeddingInjection injection that can convert embeddings to thrift embeddings.
|
||||
* @param thriftIteratorIO class that provides a way to write PersistedEmbeddings to disk
|
||||
*/
|
||||
class SerializableBruteForceIndex[T, D <: Distance[D]](
|
||||
bruteForceIndex: BruteForceIndex[T, D],
|
||||
embeddingInjection: PersistedEmbeddingInjection[T],
|
||||
thriftIteratorIO: ThriftIteratorIO[PersistedEmbedding])
|
||||
extends Appendable[T, BruteForceRuntimeParams.type, D]
|
||||
with Queryable[T, BruteForceRuntimeParams.type, D]
|
||||
with Serialization {
|
||||
import BruteForceIndex._
|
||||
|
||||
override def append(entity: EntityEmbedding[T]): Future[Unit] =
|
||||
bruteForceIndex.append(entity)
|
||||
|
||||
override def toQueryable: Queryable[T, BruteForceRuntimeParams.type, D] = this
|
||||
|
||||
override def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbours: Int,
|
||||
runtimeParams: BruteForceRuntimeParams.type
|
||||
): Future[List[T]] =
|
||||
bruteForceIndex.query(embedding, numOfNeighbours, runtimeParams)
|
||||
|
||||
override def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbours: Int,
|
||||
runtimeParams: BruteForceRuntimeParams.type
|
||||
): Future[List[NeighborWithDistance[T, D]]] =
|
||||
bruteForceIndex.queryWithDistance(embedding, numOfNeighbours, runtimeParams)
|
||||
|
||||
override def toDirectory(serializationDirectory: ResourceId): Unit = {
|
||||
toDirectory(new IndexOutputFile(serializationDirectory))
|
||||
}
|
||||
|
||||
override def toDirectory(serializationDirectory: AbstractFile): Unit = {
|
||||
toDirectory(new IndexOutputFile(serializationDirectory))
|
||||
}
|
||||
|
||||
private def toDirectory(serializationDirectory: IndexOutputFile): Unit = {
|
||||
val outputStream = serializationDirectory.createFile(DataFileName).getOutputStream()
|
||||
val thriftEmbeddings =
|
||||
bruteForceIndex.linkedQueue.iterator().asScala.map { embedding =>
|
||||
embeddingInjection(embedding)
|
||||
}
|
||||
try {
|
||||
thriftIteratorIO.toOutputStream(thriftEmbeddings, outputStream)
|
||||
} finally {
|
||||
outputStream.close()
|
||||
}
|
||||
}
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/AnnInjections.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/AnnInjections.docx
Normal file
Binary file not shown.
@ -1,28 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.bijection.{Bijection, Injection}
|
||||
|
||||
// Class providing commonly used injections that can be used directly with ANN apis.
|
||||
// Injection prefixed with `J` can be used in java directly with ANN apis.
|
||||
object AnnInjections {
|
||||
val LongInjection: Injection[Long, Array[Byte]] = Injection.long2BigEndian
|
||||
|
||||
def StringInjection: Injection[String, Array[Byte]] = Injection.utf8
|
||||
|
||||
def IntInjection: Injection[Int, Array[Byte]] = Injection.int2BigEndian
|
||||
|
||||
val JLongInjection: Injection[java.lang.Long, Array[Byte]] =
|
||||
Bijection.long2Boxed
|
||||
.asInstanceOf[Bijection[Long, java.lang.Long]]
|
||||
.inverse
|
||||
.andThen(LongInjection)
|
||||
|
||||
val JStringInjection: Injection[java.lang.String, Array[Byte]] =
|
||||
StringInjection
|
||||
|
||||
val JIntInjection: Injection[java.lang.Integer, Array[Byte]] =
|
||||
Bijection.int2Boxed
|
||||
.asInstanceOf[Bijection[Int, java.lang.Integer]]
|
||||
.inverse
|
||||
.andThen(IntInjection)
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/Api.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/Api.docx
Normal file
Binary file not shown.
@ -1,150 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
|
||||
import com.twitter.ml.api.embedding.Embedding
|
||||
import com.twitter.ml.api.embedding.EmbeddingMath
|
||||
import com.twitter.ml.api.embedding.EmbeddingSerDe
|
||||
import com.twitter.util.Future
|
||||
|
||||
object EmbeddingType {
|
||||
type EmbeddingVector = Embedding[Float]
|
||||
val embeddingSerDe = EmbeddingSerDe.apply[Float]
|
||||
private[common] val math = EmbeddingMath.Float
|
||||
}
|
||||
|
||||
/**
|
||||
* Typed entity with an embedding associated with it.
|
||||
* @param id : Unique Id for an entity.
|
||||
* @param embedding : Embedding/Vector of an entity.
|
||||
* @tparam T: Type of id.
|
||||
*/
|
||||
case class EntityEmbedding[T](id: T, embedding: EmbeddingVector)
|
||||
|
||||
// Query interface for ANN
|
||||
trait Queryable[T, P <: RuntimeParams, D <: Distance[D]] {
|
||||
|
||||
/**
|
||||
* ANN query for ids.
|
||||
* @param embedding: Embedding/Vector to be queried with.
|
||||
* @param numOfNeighbors: Number of neighbours to be queried for.
|
||||
* @param runtimeParams: Runtime params associated with index to control accuracy/latency etc.
|
||||
* @return List of approximate nearest neighbour ids.
|
||||
*/
|
||||
def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[T]]
|
||||
|
||||
/**
|
||||
* ANN query for ids with distance.
|
||||
* @param embedding: Embedding/Vector to be queried with.
|
||||
* @param numOfNeighbors: Number of neighbours to be queried for.
|
||||
* @param runtimeParams: Runtime params associated with index to control accuracy/latency etc.
|
||||
* @return List of approximate nearest neighbour ids with distance from the query embedding.
|
||||
*/
|
||||
def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[NeighborWithDistance[T, D]]]
|
||||
}
|
||||
|
||||
// Query interface for ANN over indexes that are grouped
|
||||
trait QueryableGrouped[T, P <: RuntimeParams, D <: Distance[D]] extends Queryable[T, P, D] {
|
||||
|
||||
/**
|
||||
* ANN query for ids.
|
||||
* @param embedding: Embedding/Vector to be queried with.
|
||||
* @param numOfNeighbors: Number of neighbours to be queried for.
|
||||
* @param runtimeParams: Runtime params associated with index to control accuracy/latency etc.
|
||||
* @param key: Optional key to lookup specific ANN index and perform query there
|
||||
* @return List of approximate nearest neighbour ids.
|
||||
*/
|
||||
def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P,
|
||||
key: Option[String]
|
||||
): Future[List[T]]
|
||||
|
||||
/**
|
||||
* ANN query for ids with distance.
|
||||
* @param embedding: Embedding/Vector to be queried with.
|
||||
* @param numOfNeighbors: Number of neighbours to be queried for.
|
||||
* @param runtimeParams: Runtime params associated with index to control accuracy/latency etc.
|
||||
* @param key: Optional key to lookup specific ANN index and perform query there
|
||||
* @return List of approximate nearest neighbour ids with distance from the query embedding.
|
||||
*/
|
||||
def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P,
|
||||
key: Option[String]
|
||||
): Future[List[NeighborWithDistance[T, D]]]
|
||||
}
|
||||
|
||||
/**
|
||||
* Runtime params associated with index to control accuracy/latency etc while querying.
|
||||
*/
|
||||
trait RuntimeParams {}
|
||||
|
||||
/**
|
||||
* ANN query result with distance.
|
||||
* @param neighbor : Id of the neighbours
|
||||
* @param distance: Distance of neighbour from query ex: D: CosineDistance, L2Distance, InnerProductDistance
|
||||
*/
|
||||
case class NeighborWithDistance[T, D <: Distance[D]](neighbor: T, distance: D)
|
||||
|
||||
/**
|
||||
* ANN query result with seed entity for which this neighbor was provided.
|
||||
* @param seed: Seed Id for which ann query was called
|
||||
* @param neighbor : Id of the neighbours
|
||||
*/
|
||||
case class NeighborWithSeed[T1, T2](seed: T1, neighbor: T2)
|
||||
|
||||
/**
|
||||
* ANN query result with distance with seed entity for which this neighbor was provided.
|
||||
* @param seed: Seed Id for which ann query was called
|
||||
* @param neighbor : Id of the neighbours
|
||||
* @param distance: Distance of neighbour from query ex: D: CosineDistance, L2Distance, InnerProductDistance
|
||||
*/
|
||||
case class NeighborWithDistanceWithSeed[T1, T2, D <: Distance[D]](
|
||||
seed: T1,
|
||||
neighbor: T2,
|
||||
distance: D)
|
||||
|
||||
trait RawAppendable[P <: RuntimeParams, D <: Distance[D]] {
|
||||
|
||||
/**
|
||||
* Append an embedding in an index.
|
||||
* @param embedding: Embedding/Vector
|
||||
* @return Future of long id associated with embedding autogenerated.
|
||||
*/
|
||||
def append(embedding: EmbeddingVector): Future[Long]
|
||||
|
||||
/**
|
||||
* Convert an Appendable to Queryable interface to query an index.
|
||||
*/
|
||||
def toQueryable: Queryable[Long, P, D]
|
||||
}
|
||||
|
||||
// Index building interface for ANN.
|
||||
trait Appendable[T, P <: RuntimeParams, D <: Distance[D]] {
|
||||
|
||||
/**
|
||||
* Append an entity with embedding in an index.
|
||||
* @param entity: Entity with its embedding
|
||||
*/
|
||||
def append(entity: EntityEmbedding[T]): Future[Unit]
|
||||
|
||||
/**
|
||||
* Convert an Appendable to Queryable interface to query an index.
|
||||
*/
|
||||
def toQueryable: Queryable[T, P, D]
|
||||
}
|
||||
|
||||
// Updatable index interface for ANN.
|
||||
trait Updatable[T] {
|
||||
def update(entity: EntityEmbedding[T]): Future[Unit]
|
||||
}
|
@ -1,21 +0,0 @@
|
||||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
platform = "java8",
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/com/google/guava",
|
||||
"3rdparty/jvm/com/twitter/bijection:core",
|
||||
"3rdparty/jvm/com/twitter/storehaus:core",
|
||||
"3rdparty/jvm/org/apache/beam:beam-sdks-java-io-google-cloud-platform",
|
||||
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
|
||||
"finatra/inject/inject-mdc/src/main/scala",
|
||||
"mediaservices/commons/src/main/scala:futuretracker",
|
||||
"src/java/com/twitter/search/common/file",
|
||||
"src/scala/com/twitter/ml/api/embedding",
|
||||
"stitch/stitch-core",
|
||||
],
|
||||
exports = [
|
||||
"3rdparty/jvm/com/twitter/bijection:core",
|
||||
],
|
||||
)
|
BIN
ann/src/main/scala/com/twitter/ann/common/BUILD.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/BUILD.docx
Normal file
Binary file not shown.
BIN
ann/src/main/scala/com/twitter/ann/common/EmbeddingProducer.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/EmbeddingProducer.docx
Normal file
Binary file not shown.
@ -1,13 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.stitch.Stitch
|
||||
|
||||
trait EmbeddingProducer[T] {
|
||||
|
||||
/**
|
||||
* Produce an embedding from type T. Implementations of this could do a lookup from an id to an
|
||||
* embedding. Or they could run a deep model on features that output and embedding.
|
||||
* @return An embedding Stitch. See go/stitch for details on how to use the Stitch API.
|
||||
*/
|
||||
def produceEmbedding(input: T): Stitch[Option[EmbeddingType.EmbeddingVector]]
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/IndexOutputFile.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/IndexOutputFile.docx
Normal file
Binary file not shown.
@ -1,226 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.google.common.io.ByteStreams
|
||||
import com.twitter.ann.common.thriftscala.AnnIndexMetadata
|
||||
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
|
||||
import com.twitter.mediaservices.commons.codec.ThriftByteBufferCodec
|
||||
import com.twitter.search.common.file.AbstractFile
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.io.OutputStream
|
||||
import java.nio.channels.Channels
|
||||
import org.apache.beam.sdk.io.FileSystems
|
||||
import org.apache.beam.sdk.io.fs.MoveOptions
|
||||
import org.apache.beam.sdk.io.fs.ResolveOptions
|
||||
import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions
|
||||
import org.apache.beam.sdk.io.fs.ResourceId
|
||||
import org.apache.beam.sdk.util.MimeTypes
|
||||
import org.apache.hadoop.io.IOUtils
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
/**
|
||||
* This class creates a wrapper around GCS filesystem and HDFS filesystem for the index
|
||||
* generation job. It implements the basic methods required by the index generation job and hides
|
||||
* the logic around handling HDFS vs GCS.
|
||||
*/
|
||||
class IndexOutputFile(val abstractFile: AbstractFile, val resourceId: ResourceId) {
|
||||
|
||||
// Success file name
|
||||
private val SUCCESS_FILE = "_SUCCESS"
|
||||
private val INDEX_METADATA_FILE = "ANN_INDEX_METADATA"
|
||||
private val MetadataCodec = new ThriftByteBufferCodec[AnnIndexMetadata](AnnIndexMetadata)
|
||||
|
||||
/**
|
||||
* Constructor for ResourceId. This is used for GCS filesystem
|
||||
* @param resourceId
|
||||
*/
|
||||
def this(resourceId: ResourceId) = {
|
||||
this(null, resourceId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor for AbstractFile. This is used for HDFS and local filesystem
|
||||
* @param abstractFile
|
||||
*/
|
||||
def this(abstractFile: AbstractFile) = {
|
||||
this(abstractFile, null)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if this instance is around an AbstractFile.
|
||||
* @return
|
||||
*/
|
||||
def isAbstractFile(): Boolean = {
|
||||
abstractFile != null
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a _SUCCESS file in the current directory.
|
||||
*/
|
||||
def createSuccessFile(): Unit = {
|
||||
if (isAbstractFile()) {
|
||||
abstractFile.createSuccessFile()
|
||||
} else {
|
||||
val successFile =
|
||||
resourceId.resolve(SUCCESS_FILE, ResolveOptions.StandardResolveOptions.RESOLVE_FILE)
|
||||
val successWriterChannel = FileSystems.create(successFile, MimeTypes.BINARY)
|
||||
successWriterChannel.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns whether the current instance represents a directory
|
||||
* @return True if the current instance is a directory
|
||||
*/
|
||||
def isDirectory(): Boolean = {
|
||||
if (isAbstractFile()) {
|
||||
abstractFile.isDirectory
|
||||
} else {
|
||||
resourceId.isDirectory
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the current path of the file represented by the current instance
|
||||
* @return The path string of the file/directory
|
||||
*/
|
||||
def getPath(): String = {
|
||||
if (isAbstractFile()) {
|
||||
abstractFile.getPath.toString
|
||||
} else {
|
||||
if (resourceId.isDirectory) {
|
||||
resourceId.getCurrentDirectory.toString
|
||||
} else {
|
||||
resourceId.getCurrentDirectory.toString + resourceId.getFilename
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new file @param fileName in the current directory.
|
||||
* @param fileName
|
||||
* @return A new file inside the current directory
|
||||
*/
|
||||
def createFile(fileName: String): IndexOutputFile = {
|
||||
if (isAbstractFile()) {
|
||||
// AbstractFile treats files and directories the same way. Hence, not checking for directory
|
||||
// here.
|
||||
new IndexOutputFile(abstractFile.getChild(fileName))
|
||||
} else {
|
||||
if (!resourceId.isDirectory) {
|
||||
// If this is not a directory, throw exception.
|
||||
throw new IllegalArgumentException(getPath() + " is not a directory.")
|
||||
}
|
||||
new IndexOutputFile(
|
||||
resourceId.resolve(fileName, ResolveOptions.StandardResolveOptions.RESOLVE_FILE))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new directory @param directoryName in the current directory.
|
||||
* @param directoryName
|
||||
* @return A new directory inside the current directory
|
||||
*/
|
||||
def createDirectory(directoryName: String): IndexOutputFile = {
|
||||
if (isAbstractFile()) {
|
||||
// AbstractFile treats files and directories the same way. Hence, not checking for directory
|
||||
// here.
|
||||
val dir = abstractFile.getChild(directoryName)
|
||||
dir.mkdirs()
|
||||
new IndexOutputFile(dir)
|
||||
} else {
|
||||
if (!resourceId.isDirectory) {
|
||||
// If this is not a directory, throw exception.
|
||||
throw new IllegalArgumentException(getPath() + " is not a directory.")
|
||||
}
|
||||
val newResourceId =
|
||||
resourceId.resolve(directoryName, ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)
|
||||
|
||||
// Create a tmp file and delete in order to trigger directory creation
|
||||
val tmpFile =
|
||||
newResourceId.resolve("tmp", ResolveOptions.StandardResolveOptions.RESOLVE_FILE)
|
||||
val tmpWriterChannel = FileSystems.create(tmpFile, MimeTypes.BINARY)
|
||||
tmpWriterChannel.close()
|
||||
FileSystems.delete(List(tmpFile).asJava, MoveOptions.StandardMoveOptions.IGNORE_MISSING_FILES)
|
||||
|
||||
new IndexOutputFile(newResourceId)
|
||||
}
|
||||
}
|
||||
|
||||
def getChild(fileName: String, isDirectory: Boolean = false): IndexOutputFile = {
|
||||
if (isAbstractFile()) {
|
||||
new IndexOutputFile(abstractFile.getChild(fileName))
|
||||
} else {
|
||||
val resolveOption = if (isDirectory) {
|
||||
StandardResolveOptions.RESOLVE_DIRECTORY
|
||||
} else {
|
||||
StandardResolveOptions.RESOLVE_FILE
|
||||
}
|
||||
new IndexOutputFile(resourceId.resolve(fileName, resolveOption))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an OutputStream for the underlying file.
|
||||
* Note: Close the OutputStream after writing
|
||||
* @return
|
||||
*/
|
||||
def getOutputStream(): OutputStream = {
|
||||
if (isAbstractFile()) {
|
||||
abstractFile.getByteSink.openStream()
|
||||
} else {
|
||||
if (resourceId.isDirectory) {
|
||||
// If this is a directory, throw exception.
|
||||
throw new IllegalArgumentException(getPath() + " is a directory.")
|
||||
}
|
||||
val writerChannel = FileSystems.create(resourceId, MimeTypes.BINARY)
|
||||
Channels.newOutputStream(writerChannel)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an InputStream for the underlying file.
|
||||
* Note: Close the InputStream after reading
|
||||
* @return
|
||||
*/
|
||||
def getInputStream(): InputStream = {
|
||||
if (isAbstractFile()) {
|
||||
abstractFile.getByteSource.openStream()
|
||||
} else {
|
||||
if (resourceId.isDirectory) {
|
||||
// If this is a directory, throw exception.
|
||||
throw new IllegalArgumentException(getPath() + " is a directory.")
|
||||
}
|
||||
val readChannel = FileSystems.open(resourceId)
|
||||
Channels.newInputStream(readChannel)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Copies content from the srcIn into the current file.
|
||||
* @param srcIn
|
||||
*/
|
||||
def copyFrom(srcIn: InputStream): Unit = {
|
||||
val out = getOutputStream()
|
||||
try {
|
||||
IOUtils.copyBytes(srcIn, out, 4096)
|
||||
out.close()
|
||||
} catch {
|
||||
case ex: IOException =>
|
||||
IOUtils.closeStream(out);
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
|
||||
def writeIndexMetadata(annIndexMetadata: AnnIndexMetadata): Unit = {
|
||||
val out = createFile(INDEX_METADATA_FILE).getOutputStream()
|
||||
val bytes = ArrayByteBufferCodec.decode(MetadataCodec.encode(annIndexMetadata))
|
||||
out.write(bytes)
|
||||
out.close()
|
||||
}
|
||||
|
||||
def loadIndexMetadata(): AnnIndexMetadata = {
|
||||
val in = ByteStreams.toByteArray(getInputStream())
|
||||
MetadataCodec.decode(ArrayByteBufferCodec.encode(in))
|
||||
}
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/IndexTransformer.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/IndexTransformer.docx
Normal file
Binary file not shown.
@ -1,118 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
|
||||
import com.twitter.storehaus.{ReadableStore, Store}
|
||||
import com.twitter.util.Future
|
||||
|
||||
// Utility to transform raw index to typed index using Store
|
||||
object IndexTransformer {
|
||||
|
||||
/**
|
||||
* Transform a long type queryable index to Typed queryable index
|
||||
* @param index: Raw Queryable index
|
||||
* @param store: Readable store to provide mappings between Long and T
|
||||
* @tparam T: Type to transform to
|
||||
* @tparam P: Runtime params
|
||||
* @return Queryable index typed on T
|
||||
*/
|
||||
def transformQueryable[T, P <: RuntimeParams, D <: Distance[D]](
|
||||
index: Queryable[Long, P, D],
|
||||
store: ReadableStore[Long, T]
|
||||
): Queryable[T, P, D] = {
|
||||
new Queryable[T, P, D] {
|
||||
override def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[T]] = {
|
||||
val neighbors = index.query(embedding, numOfNeighbors, runtimeParams)
|
||||
neighbors
|
||||
.flatMap(nn => {
|
||||
val ids = nn.map(id => store.get(id).map(_.get))
|
||||
Future
|
||||
.collect(ids)
|
||||
.map(_.toList)
|
||||
})
|
||||
}
|
||||
|
||||
override def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[NeighborWithDistance[T, D]]] = {
|
||||
val neighbors = index.queryWithDistance(embedding, numOfNeighbors, runtimeParams)
|
||||
neighbors
|
||||
.flatMap(nn => {
|
||||
val ids = nn.map(obj =>
|
||||
store.get(obj.neighbor).map(id => NeighborWithDistance(id.get, obj.distance)))
|
||||
Future
|
||||
.collect(ids)
|
||||
.map(_.toList)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform a long type appendable index to Typed appendable index
|
||||
* @param index: Raw Appendable index
|
||||
* @param store: Writable store to store mappings between Long and T
|
||||
* @tparam T: Type to transform to
|
||||
* @return Appendable index typed on T
|
||||
*/
|
||||
def transformAppendable[T, P <: RuntimeParams, D <: Distance[D]](
|
||||
index: RawAppendable[P, D],
|
||||
store: Store[Long, T]
|
||||
): Appendable[T, P, D] = {
|
||||
new Appendable[T, P, D]() {
|
||||
override def append(entity: EntityEmbedding[T]): Future[Unit] = {
|
||||
index
|
||||
.append(entity.embedding)
|
||||
.flatMap(id => store.put((id, Some(entity.id))))
|
||||
}
|
||||
|
||||
override def toQueryable: Queryable[T, P, D] = {
|
||||
transformQueryable(index.toQueryable, store)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform a long type appendable and queryable index to Typed appendable and queryable index
|
||||
* @param index: Raw Appendable and queryable index
|
||||
* @param store: Store to provide/store mappings between Long and T
|
||||
* @tparam T: Type to transform to
|
||||
* @tparam Index: Index
|
||||
* @return Appendable and queryable index typed on T
|
||||
*/
|
||||
def transform1[
|
||||
Index <: RawAppendable[P, D] with Queryable[Long, P, D],
|
||||
T,
|
||||
P <: RuntimeParams,
|
||||
D <: Distance[D]
|
||||
](
|
||||
index: Index,
|
||||
store: Store[Long, T]
|
||||
): Queryable[T, P, D] with Appendable[T, P, D] = {
|
||||
val queryable = transformQueryable(index, store)
|
||||
val appendable = transformAppendable(index, store)
|
||||
|
||||
new Queryable[T, P, D] with Appendable[T, P, D] {
|
||||
override def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
) = queryable.query(embedding, numOfNeighbors, runtimeParams)
|
||||
|
||||
override def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
) = queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams)
|
||||
|
||||
override def append(entity: EntityEmbedding[T]) = appendable.append(entity)
|
||||
|
||||
override def toQueryable: Queryable[T, P, D] = appendable.toQueryable
|
||||
}
|
||||
}
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/MemoizedInEpochs.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/MemoizedInEpochs.docx
Normal file
Binary file not shown.
@ -1,37 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.util.Return
|
||||
import com.twitter.util.Throw
|
||||
import com.twitter.util.Try
|
||||
import com.twitter.util.logging.Logging
|
||||
|
||||
// Memoization with a twist
|
||||
// New epoch reuse K:V pairs from previous and recycle everything else
|
||||
class MemoizedInEpochs[K, V](f: K => Try[V]) extends Logging {
|
||||
private var memoizedCalls: Map[K, V] = Map.empty
|
||||
|
||||
def epoch(keys: Seq[K]): Seq[V] = {
|
||||
val newSet = keys.toSet
|
||||
val keysToBeComputed = newSet.diff(memoizedCalls.keySet)
|
||||
val computedKeysAndValues = keysToBeComputed.map { key =>
|
||||
info(s"Memoize ${key}")
|
||||
(key, f(key))
|
||||
}
|
||||
val keysAndValuesAfterFilteringFailures = computedKeysAndValues
|
||||
.flatMap {
|
||||
case (key, Return(value)) => Some((key, value))
|
||||
case (key, Throw(e)) =>
|
||||
warn(s"Calling f for ${key} has failed", e)
|
||||
|
||||
None
|
||||
}
|
||||
val keysReusedFromLastEpoch = memoizedCalls.filterKeys(newSet.contains)
|
||||
memoizedCalls = keysReusedFromLastEpoch ++ keysAndValuesAfterFilteringFailures
|
||||
|
||||
debug(s"Final memoization is ${memoizedCalls.keys.mkString(", ")}")
|
||||
|
||||
keys.flatMap(memoizedCalls.get)
|
||||
}
|
||||
|
||||
def currentEpochKeys: Set[K] = memoizedCalls.keySet
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/Metric.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/Metric.docx
Normal file
Binary file not shown.
@ -1,290 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.google.common.collect.ImmutableBiMap
|
||||
import com.twitter.ann.common.EmbeddingType._
|
||||
import com.twitter.ann.common.thriftscala.DistanceMetric
|
||||
import com.twitter.ann.common.thriftscala.{CosineDistance => ServiceCosineDistance}
|
||||
import com.twitter.ann.common.thriftscala.{Distance => ServiceDistance}
|
||||
import com.twitter.ann.common.thriftscala.{InnerProductDistance => ServiceInnerProductDistance}
|
||||
import com.twitter.ann.common.thriftscala.{EditDistance => ServiceEditDistance}
|
||||
import com.twitter.ann.common.thriftscala.{L2Distance => ServiceL2Distance}
|
||||
import com.twitter.bijection.Injection
|
||||
import scala.util.Failure
|
||||
import scala.util.Success
|
||||
import scala.util.Try
|
||||
|
||||
// Ann distance metrics
|
||||
trait Distance[D] extends Any with Ordered[D] {
|
||||
def distance: Float
|
||||
}
|
||||
|
||||
case class L2Distance(distance: Float) extends AnyVal with Distance[L2Distance] {
|
||||
override def compare(that: L2Distance): Int =
|
||||
Ordering.Float.compare(this.distance, that.distance)
|
||||
}
|
||||
|
||||
case class CosineDistance(distance: Float) extends AnyVal with Distance[CosineDistance] {
|
||||
override def compare(that: CosineDistance): Int =
|
||||
Ordering.Float.compare(this.distance, that.distance)
|
||||
}
|
||||
|
||||
case class InnerProductDistance(distance: Float)
|
||||
extends AnyVal
|
||||
with Distance[InnerProductDistance] {
|
||||
override def compare(that: InnerProductDistance): Int =
|
||||
Ordering.Float.compare(this.distance, that.distance)
|
||||
}
|
||||
|
||||
case class EditDistance(distance: Float) extends AnyVal with Distance[EditDistance] {
|
||||
override def compare(that: EditDistance): Int =
|
||||
Ordering.Float.compare(this.distance, that.distance)
|
||||
}
|
||||
|
||||
object Metric {
|
||||
private[this] val thriftMetricMapping = ImmutableBiMap.of(
|
||||
L2,
|
||||
DistanceMetric.L2,
|
||||
Cosine,
|
||||
DistanceMetric.Cosine,
|
||||
InnerProduct,
|
||||
DistanceMetric.InnerProduct,
|
||||
Edit,
|
||||
DistanceMetric.EditDistance
|
||||
)
|
||||
|
||||
def fromThrift(metric: DistanceMetric): Metric[_ <: Distance[_]] = {
|
||||
thriftMetricMapping.inverse().get(metric)
|
||||
}
|
||||
|
||||
def toThrift(metric: Metric[_ <: Distance[_]]): DistanceMetric = {
|
||||
thriftMetricMapping.get(metric)
|
||||
}
|
||||
|
||||
def fromString(metricName: String): Metric[_ <: Distance[_]]
|
||||
with Injection[_, ServiceDistance] = {
|
||||
metricName match {
|
||||
case "Cosine" => Cosine
|
||||
case "L2" => L2
|
||||
case "InnerProduct" => InnerProduct
|
||||
case "EditDistance" => Edit
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"No Metric with the name $metricName")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait Metric[D <: Distance[D]] {
|
||||
def distance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): D
|
||||
def absoluteDistance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): Float
|
||||
def fromAbsoluteDistance(distance: Float): D
|
||||
}
|
||||
|
||||
case object L2 extends Metric[L2Distance] with Injection[L2Distance, ServiceDistance] {
|
||||
override def distance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): L2Distance = {
|
||||
fromAbsoluteDistance(MetricUtil.l2distance(embedding1, embedding2).toFloat)
|
||||
}
|
||||
|
||||
override def fromAbsoluteDistance(distance: Float): L2Distance = {
|
||||
L2Distance(distance)
|
||||
}
|
||||
|
||||
override def absoluteDistance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): Float = distance(embedding1, embedding2).distance
|
||||
|
||||
override def apply(scalaDistance: L2Distance): ServiceDistance = {
|
||||
ServiceDistance.L2Distance(ServiceL2Distance(scalaDistance.distance))
|
||||
}
|
||||
|
||||
override def invert(serviceDistance: ServiceDistance): Try[L2Distance] = {
|
||||
serviceDistance match {
|
||||
case ServiceDistance.L2Distance(l2Distance) =>
|
||||
Success(L2Distance(l2Distance.distance.toFloat))
|
||||
case distance =>
|
||||
Failure(new IllegalArgumentException(s"Expected an l2 distance but got $distance"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case object Cosine extends Metric[CosineDistance] with Injection[CosineDistance, ServiceDistance] {
|
||||
override def distance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): CosineDistance = {
|
||||
fromAbsoluteDistance(1 - MetricUtil.cosineSimilarity(embedding1, embedding2))
|
||||
}
|
||||
|
||||
override def fromAbsoluteDistance(distance: Float): CosineDistance = {
|
||||
CosineDistance(distance)
|
||||
}
|
||||
|
||||
override def absoluteDistance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): Float = distance(embedding1, embedding2).distance
|
||||
|
||||
override def apply(scalaDistance: CosineDistance): ServiceDistance = {
|
||||
ServiceDistance.CosineDistance(ServiceCosineDistance(scalaDistance.distance))
|
||||
}
|
||||
|
||||
override def invert(serviceDistance: ServiceDistance): Try[CosineDistance] = {
|
||||
serviceDistance match {
|
||||
case ServiceDistance.CosineDistance(cosineDistance) =>
|
||||
Success(CosineDistance(cosineDistance.distance.toFloat))
|
||||
case distance =>
|
||||
Failure(new IllegalArgumentException(s"Expected a cosine distance but got $distance"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case object InnerProduct
|
||||
extends Metric[InnerProductDistance]
|
||||
with Injection[InnerProductDistance, ServiceDistance] {
|
||||
override def distance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): InnerProductDistance = {
|
||||
fromAbsoluteDistance(1 - MetricUtil.dot(embedding1, embedding2))
|
||||
}
|
||||
|
||||
override def fromAbsoluteDistance(distance: Float): InnerProductDistance = {
|
||||
InnerProductDistance(distance)
|
||||
}
|
||||
|
||||
override def absoluteDistance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): Float = distance(embedding1, embedding2).distance
|
||||
|
||||
override def apply(scalaDistance: InnerProductDistance): ServiceDistance = {
|
||||
ServiceDistance.InnerProductDistance(ServiceInnerProductDistance(scalaDistance.distance))
|
||||
}
|
||||
|
||||
override def invert(
|
||||
serviceDistance: ServiceDistance
|
||||
): Try[InnerProductDistance] = {
|
||||
serviceDistance match {
|
||||
case ServiceDistance.InnerProductDistance(cosineDistance) =>
|
||||
Success(InnerProductDistance(cosineDistance.distance.toFloat))
|
||||
case distance =>
|
||||
Failure(
|
||||
new IllegalArgumentException(s"Expected a inner product distance but got $distance")
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case object Edit extends Metric[EditDistance] with Injection[EditDistance, ServiceDistance] {
|
||||
|
||||
private def intDistance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector,
|
||||
pos1: Int,
|
||||
pos2: Int,
|
||||
precomputedDistances: scala.collection.mutable.Map[(Int, Int), Int]
|
||||
): Int = {
|
||||
// return the remaining characters of other String
|
||||
if (pos1 == 0) return pos2
|
||||
if (pos2 == 0) return pos1
|
||||
|
||||
// To check if the recursive tree
|
||||
// for given n & m has already been executed
|
||||
precomputedDistances.getOrElse(
|
||||
(pos1, pos2), {
|
||||
// We might want to change this so that capitals are considered the same.
|
||||
// Also maybe some characters that look similar should also be the same.
|
||||
val computed = if (embedding1(pos1 - 1) == embedding2(pos2 - 1)) {
|
||||
intDistance(embedding1, embedding2, pos1 - 1, pos2 - 1, precomputedDistances)
|
||||
} else { // If characters are nt equal, we need to
|
||||
// find the minimum cost out of all 3 operations.
|
||||
val insert = intDistance(embedding1, embedding2, pos1, pos2 - 1, precomputedDistances)
|
||||
val del = intDistance(embedding1, embedding2, pos1 - 1, pos2, precomputedDistances)
|
||||
val replace =
|
||||
intDistance(embedding1, embedding2, pos1 - 1, pos2 - 1, precomputedDistances)
|
||||
1 + Math.min(insert, Math.min(del, replace))
|
||||
}
|
||||
precomputedDistances.put((pos1, pos2), computed)
|
||||
computed
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
override def distance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): EditDistance = {
|
||||
val editDistance = intDistance(
|
||||
embedding1,
|
||||
embedding2,
|
||||
embedding1.length,
|
||||
embedding2.length,
|
||||
scala.collection.mutable.Map[(Int, Int), Int]()
|
||||
)
|
||||
EditDistance(editDistance)
|
||||
}
|
||||
|
||||
override def fromAbsoluteDistance(distance: Float): EditDistance = {
|
||||
EditDistance(distance.toInt)
|
||||
}
|
||||
|
||||
override def absoluteDistance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): Float = distance(embedding1, embedding2).distance
|
||||
|
||||
override def apply(scalaDistance: EditDistance): ServiceDistance = {
|
||||
ServiceDistance.EditDistance(ServiceEditDistance(scalaDistance.distance.toInt))
|
||||
}
|
||||
|
||||
override def invert(
|
||||
serviceDistance: ServiceDistance
|
||||
): Try[EditDistance] = {
|
||||
serviceDistance match {
|
||||
case ServiceDistance.EditDistance(cosineDistance) =>
|
||||
Success(EditDistance(cosineDistance.distance.toFloat))
|
||||
case distance =>
|
||||
Failure(
|
||||
new IllegalArgumentException(s"Expected a inner product distance but got $distance")
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object MetricUtil {
|
||||
private[ann] def dot(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): Float = {
|
||||
math.dotProduct(embedding1, embedding2)
|
||||
}
|
||||
|
||||
private[ann] def l2distance(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): Double = {
|
||||
math.l2Distance(embedding1, embedding2)
|
||||
}
|
||||
|
||||
private[ann] def cosineSimilarity(
|
||||
embedding1: EmbeddingVector,
|
||||
embedding2: EmbeddingVector
|
||||
): Float = {
|
||||
math.cosineSimilarity(embedding1, embedding2).toFloat
|
||||
}
|
||||
|
||||
private[ann] def norm(
|
||||
embedding: EmbeddingVector
|
||||
): EmbeddingVector = {
|
||||
math.normalize(embedding)
|
||||
}
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/QueryableById.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/QueryableById.docx
Normal file
Binary file not shown.
@ -1,41 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.stitch.Stitch
|
||||
|
||||
/**
|
||||
* This is a trait that allows you to query for nearest neighbors given an arbitrary type T1. This is
|
||||
* in contrast to a regular com.twitter.ann.common.Appendable, which takes an embedding as the input
|
||||
* argument.
|
||||
*
|
||||
* This interface uses the Stitch API for batching. See go/stitch for details on how to use it.
|
||||
*
|
||||
* @tparam T1 type of the query.
|
||||
* @tparam T2 type of the result.
|
||||
* @tparam P runtime parameters supported by the index.
|
||||
* @tparam D distance function used in the index.
|
||||
*/
|
||||
trait QueryableById[T1, T2, P <: RuntimeParams, D <: Distance[D]] {
|
||||
def queryById(
|
||||
id: T1,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Stitch[List[T2]]
|
||||
|
||||
def queryByIdWithDistance(
|
||||
id: T1,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Stitch[List[NeighborWithDistance[T2, D]]]
|
||||
|
||||
def batchQueryById(
|
||||
ids: Seq[T1],
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Stitch[List[NeighborWithSeed[T1, T2]]]
|
||||
|
||||
def batchQueryWithDistanceById(
|
||||
ids: Seq[T1],
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Stitch[List[NeighborWithDistanceWithSeed[T1, T2, D]]]
|
||||
}
|
Binary file not shown.
@ -1,91 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.stitch.Stitch
|
||||
|
||||
/**
|
||||
* Implementation of QueryableById that composes an EmbeddingProducer and a Queryable so that we
|
||||
* can get nearest neighbors given an id of type T1
|
||||
* @param embeddingProducer provides an embedding given an id.
|
||||
* @param queryable provides a list of neighbors given an embedding.
|
||||
* @tparam T1 type of the query.
|
||||
* @tparam T2 type of the result.
|
||||
* @tparam P runtime parameters supported by the index.
|
||||
* @tparam D distance function used in the index.
|
||||
*/
|
||||
class QueryableByIdImplementation[T1, T2, P <: RuntimeParams, D <: Distance[D]](
|
||||
embeddingProducer: EmbeddingProducer[T1],
|
||||
queryable: Queryable[T2, P, D])
|
||||
extends QueryableById[T1, T2, P, D] {
|
||||
override def queryById(
|
||||
id: T1,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Stitch[List[T2]] = {
|
||||
embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
|
||||
embeddingOption
|
||||
.map { embedding =>
|
||||
Stitch.callFuture(queryable.query(embedding, numOfNeighbors, runtimeParams))
|
||||
}.getOrElse {
|
||||
Stitch.value(List.empty)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def queryByIdWithDistance(
|
||||
id: T1,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Stitch[List[NeighborWithDistance[T2, D]]] = {
|
||||
embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
|
||||
embeddingOption
|
||||
.map { embedding =>
|
||||
Stitch.callFuture(queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
|
||||
}.getOrElse {
|
||||
Stitch.value(List.empty)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def batchQueryById(
|
||||
ids: Seq[T1],
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Stitch[List[NeighborWithSeed[T1, T2]]] = {
|
||||
Stitch
|
||||
.traverse(ids) { id =>
|
||||
embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
|
||||
embeddingOption
|
||||
.map { embedding =>
|
||||
Stitch
|
||||
.callFuture(queryable.query(embedding, numOfNeighbors, runtimeParams)).map(
|
||||
_.map(neighbor => NeighborWithSeed(id, neighbor)))
|
||||
}.getOrElse {
|
||||
Stitch.value(List.empty)
|
||||
}.handle { case _ => List.empty }
|
||||
}
|
||||
}.map { _.toList.flatten }
|
||||
}
|
||||
|
||||
override def batchQueryWithDistanceById(
|
||||
ids: Seq[T1],
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Stitch[List[NeighborWithDistanceWithSeed[T1, T2, D]]] = {
|
||||
Stitch
|
||||
.traverse(ids) { id =>
|
||||
embeddingProducer.produceEmbedding(id).flatMap { embeddingOption =>
|
||||
embeddingOption
|
||||
.map { embedding =>
|
||||
Stitch
|
||||
.callFuture(queryable.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
|
||||
.map(_.map(neighbor =>
|
||||
NeighborWithDistanceWithSeed(id, neighbor.neighbor, neighbor.distance)))
|
||||
}.getOrElse {
|
||||
Stitch.value(List.empty)
|
||||
}.handle { case _ => List.empty }
|
||||
}
|
||||
}.map {
|
||||
_.toList.flatten
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,26 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
|
||||
import com.twitter.util.Future
|
||||
|
||||
object QueryableOperations {
|
||||
implicit class Map[T, P <: RuntimeParams, D <: Distance[D]](
|
||||
val q: Queryable[T, P, D]) {
|
||||
def mapRuntimeParameters(f: P => P): Queryable[T, P, D] = {
|
||||
new Queryable[T, P, D] {
|
||||
def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[T]] = q.query(embedding, numOfNeighbors, f(runtimeParams))
|
||||
|
||||
def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[NeighborWithDistance[T, D]]] =
|
||||
q.queryWithDistance(embedding, numOfNeighbors, f(runtimeParams))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -1,29 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
import com.google.common.annotations.VisibleForTesting
|
||||
import com.twitter.util.{Future, FuturePool}
|
||||
|
||||
trait ReadWriteFuturePool {
|
||||
def read[T](f: => T): Future[T]
|
||||
def write[T](f: => T): Future[T]
|
||||
}
|
||||
|
||||
object ReadWriteFuturePool {
|
||||
def apply(readPool: FuturePool, writePool: FuturePool): ReadWriteFuturePool = {
|
||||
new ReadWriteFuturePoolANN(readPool, writePool)
|
||||
}
|
||||
|
||||
def apply(commonPool: FuturePool): ReadWriteFuturePool = {
|
||||
new ReadWriteFuturePoolANN(commonPool, commonPool)
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
private[ann] class ReadWriteFuturePoolANN(readPool: FuturePool, writePool: FuturePool)
|
||||
extends ReadWriteFuturePool {
|
||||
def read[T](f: => T): Future[T] = {
|
||||
readPool.apply(f)
|
||||
}
|
||||
def write[T](f: => T): Future[T] = {
|
||||
writePool.apply(f)
|
||||
}
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/Serialization.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/Serialization.docx
Normal file
Binary file not shown.
@ -1,28 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.search.common.file.AbstractFile
|
||||
import org.apache.beam.sdk.io.fs.ResourceId
|
||||
|
||||
/**
|
||||
* Interface for writing an Appendable to a directory.
|
||||
*/
|
||||
trait Serialization {
|
||||
def toDirectory(
|
||||
serializationDirectory: AbstractFile
|
||||
): Unit
|
||||
|
||||
def toDirectory(
|
||||
serializationDirectory: ResourceId
|
||||
): Unit
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for reading a Queryable from a directory
|
||||
* @tparam T the id of the embeddings
|
||||
* @tparam Q type of the Queryable that is deserialized.
|
||||
*/
|
||||
trait QueryableDeserialization[T, P <: RuntimeParams, D <: Distance[D], Q <: Queryable[T, P, D]] {
|
||||
def fromDirectory(
|
||||
serializationDirectory: AbstractFile
|
||||
): Q
|
||||
}
|
Binary file not shown.
@ -1,64 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.ann.common.EmbeddingType._
|
||||
import com.twitter.ann.common.thriftscala.{
|
||||
NearestNeighborQuery,
|
||||
NearestNeighborResult,
|
||||
Distance => ServiceDistance,
|
||||
RuntimeParams => ServiceRuntimeParams
|
||||
}
|
||||
import com.twitter.bijection.Injection
|
||||
import com.twitter.finagle.Service
|
||||
import com.twitter.mediaservices.commons.codec.ArrayByteBufferCodec
|
||||
import com.twitter.util.Future
|
||||
|
||||
class ServiceClientQueryable[T, P <: RuntimeParams, D <: Distance[D]](
|
||||
service: Service[NearestNeighborQuery, NearestNeighborResult],
|
||||
runtimeParamInjection: Injection[P, ServiceRuntimeParams],
|
||||
distanceInjection: Injection[D, ServiceDistance],
|
||||
idInjection: Injection[T, Array[Byte]])
|
||||
extends Queryable[T, P, D] {
|
||||
override def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[T]] = {
|
||||
service
|
||||
.apply(
|
||||
NearestNeighborQuery(
|
||||
embeddingSerDe.toThrift(embedding),
|
||||
withDistance = false,
|
||||
runtimeParamInjection(runtimeParams),
|
||||
numOfNeighbors
|
||||
)
|
||||
)
|
||||
.map { result =>
|
||||
result.nearestNeighbors.map { nearestNeighbor =>
|
||||
idInjection.invert(ArrayByteBufferCodec.decode(nearestNeighbor.id)).get
|
||||
}.toList
|
||||
}
|
||||
}
|
||||
|
||||
override def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[NeighborWithDistance[T, D]]] =
|
||||
service
|
||||
.apply(
|
||||
NearestNeighborQuery(
|
||||
embeddingSerDe.toThrift(embedding),
|
||||
withDistance = true,
|
||||
runtimeParamInjection(runtimeParams),
|
||||
numOfNeighbors
|
||||
)
|
||||
)
|
||||
.map { result =>
|
||||
result.nearestNeighbors.map { nearestNeighbor =>
|
||||
NeighborWithDistance(
|
||||
idInjection.invert(ArrayByteBufferCodec.decode(nearestNeighbor.id)).get,
|
||||
distanceInjection.invert(nearestNeighbor.distance.get).get
|
||||
)
|
||||
}.toList
|
||||
}
|
||||
}
|
BIN
ann/src/main/scala/com/twitter/ann/common/ShardApi.docx
Normal file
BIN
ann/src/main/scala/com/twitter/ann/common/ShardApi.docx
Normal file
Binary file not shown.
@ -1,87 +0,0 @@
|
||||
package com.twitter.ann.common
|
||||
|
||||
import com.twitter.ann.common.EmbeddingType.EmbeddingVector
|
||||
import com.twitter.util.Future
|
||||
import scala.util.Random
|
||||
|
||||
trait ShardFunction[T] {
|
||||
|
||||
/**
|
||||
* Shard function to shard embedding based on total shards and embedding data.
|
||||
* @param shards
|
||||
* @param entity
|
||||
* @return Shard index, from 0(Inclusive) to shards(Exclusive))
|
||||
*/
|
||||
def apply(shards: Int, entity: EntityEmbedding[T]): Int
|
||||
}
|
||||
|
||||
/**
|
||||
* Randomly shards the embeddings based on number of total shards.
|
||||
*/
|
||||
class RandomShardFunction[T] extends ShardFunction[T] {
|
||||
def apply(shards: Int, entity: EntityEmbedding[T]): Int = {
|
||||
Random.nextInt(shards)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sharded appendable to shard the embedding into different appendable indices
|
||||
* @param indices: Sequence of appendable indices
|
||||
* @param shardFn: Shard function to shard data into different indices
|
||||
* @param shards: Total shards
|
||||
* @tparam T: Type of id.
|
||||
*/
|
||||
class ShardedAppendable[T, P <: RuntimeParams, D <: Distance[D]](
|
||||
indices: Seq[Appendable[T, P, D]],
|
||||
shardFn: ShardFunction[T],
|
||||
shards: Int)
|
||||
extends Appendable[T, P, D] {
|
||||
override def append(entity: EntityEmbedding[T]): Future[Unit] = {
|
||||
val shard = shardFn(shards, entity)
|
||||
val index = indices(shard)
|
||||
index.append(entity)
|
||||
}
|
||||
|
||||
override def toQueryable: Queryable[T, P, D] = {
|
||||
new ComposedQueryable[T, P, D](indices.map(_.toQueryable))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Composition of sequence of queryable indices, it queries all the indices,
|
||||
* and merges the result in memory to return the K nearest neighbours
|
||||
* @param indices: Sequence of queryable indices
|
||||
* @tparam T: Type of id
|
||||
* @tparam P: Type of runtime param
|
||||
* @tparam D: Type of distance metric
|
||||
*/
|
||||
class ComposedQueryable[T, P <: RuntimeParams, D <: Distance[D]](
|
||||
indices: Seq[Queryable[T, P, D]])
|
||||
extends Queryable[T, P, D] {
|
||||
private[this] val ordering =
|
||||
Ordering.by[NeighborWithDistance[T, D], D](_.distance)
|
||||
override def query(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[T]] = {
|
||||
val neighbours = queryWithDistance(embedding, numOfNeighbors, runtimeParams)
|
||||
neighbours.map(list => list.map(nn => nn.neighbor))
|
||||
}
|
||||
|
||||
override def queryWithDistance(
|
||||
embedding: EmbeddingVector,
|
||||
numOfNeighbors: Int,
|
||||
runtimeParams: P
|
||||
): Future[List[NeighborWithDistance[T, D]]] = {
|
||||
val futures = Future.collect(
|
||||
indices.map(index => index.queryWithDistance(embedding, numOfNeighbors, runtimeParams))
|
||||
)
|
||||
futures.map { list =>
|
||||
list.flatten
|
||||
.sorted(ordering)
|
||||
.take(numOfNeighbors)
|
||||
.toList
|
||||
}
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user